In [1]:
import cv2 as cv 
import numpy as np
from pycocotools.coco import COCO
import os
import json
import seaborn as sns
import random
import matplotlib.pyplot as plt
import torch 
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import albumentations as album 
from sklearn.model_selection import train_test_split
from typing import Callable
from albumentations.pytorch import ToTensorV2
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

  check_for_updates()


In [2]:
DATASET_PATH = "data"
coco = COCO(f"{DATASET_PATH}/annotations.json")
# The split table lives next to the annotations, so derive its path from DATASET_PATH too.
df = pd.read_csv(f"{DATASET_PATH}/metadata_splits.csv")
# One vectorised filter per split (row order preserved) instead of a Python loop over rows;
# rows whose split_open is none of these three values are ignored, as before.
train_ids, valid_ids, test_ids = (
    df.loc[df['split_open'] == split, 'id'].tolist()
    for split in ('train', 'valid', 'test')
)
print(len(train_ids), len(valid_ids), len(test_ids))

loading annotations into memory...
Done (t=1.50s)
creating index...
index created!
5303 1118 2308


In [3]:
def get_augmentation():
    """Return the training-time augmentation pipeline.

    Three geometric transforms (horizontal flip, vertical flip, 90-degree
    rotation), each applied independently with probability 0.5.
    """
    return album.Compose([
        album.HorizontalFlip(p=0.5),
        album.VerticalFlip(p=0.5),
        album.RandomRotate90(p=0.5),
    ])

In [4]:
class TurtlesDataset(Dataset):
    """COCO-backed semantic-segmentation dataset.

    Each item is ``(image, mask)``:
      * image -- the RGB image resized to ``image_size x image_size`` and turned
        into a float32 CHW tensor by ``ToTensorV2`` (pixel values unnormalized);
      * mask  -- a float32 ``(image_size, image_size)`` tensor whose values are
        COCO ``category_id``s, 0 where no annotation covers the pixel.

    Where annotations overlap, the category with the higher ``label_priority``
    wins (head > flipper > turtle); categories missing from that table get
    priority 0 and only fill still-unlabeled pixels.

    Args:
        coco: a loaded ``pycocotools.coco.COCO`` index.
        image_ids: COCO image ids exposed by this dataset, in order.
        augmentation: optional albumentations transform applied jointly to
            image and mask after resizing.
        image_size: side length in pixels of the square output.
    """

    def __init__(self, coco, image_ids, augmentation=None, image_size=512):
        self.coco = coco
        self.image_ids = image_ids
        self.augmentation = augmentation
        self.image_size = image_size
        self.catIds = coco.getCatIds()
        self.label_priority = {
            'head': 3,
            'flipper': 2,
            'turtle': 1
        }
        # Category id -> name, resolved once rather than once per annotation.
        self._cat_names = {cat['id']: cat['name'] for cat in coco.loadCats(self.catIds)}
        # Stateless tensor conversion, built once instead of per item.
        self._to_tensor = album.Compose([ToTensorV2()])
        # NOTE(review): mask values are raw category ids; this presumably assumes
        # they are small contiguous ints below the model's num_classes -- confirm.

    def __len__(self):
        return len(self.image_ids)

    def _build_mask(self, image_id, height, width):
        """Rasterize every annotation of ``image_id`` into one int32 label mask.

        A pixel takes an annotation's ``category_id`` when it is still
        unlabeled (priority 0) or when the annotation's category has strictly
        higher priority than the one currently recorded for that pixel.
        """
        mask = np.zeros((height, width), dtype=np.int32)
        priority_map = np.zeros_like(mask)

        ann_ids = self.coco.getAnnIds(imgIds=image_id, catIds=self.catIds, iscrowd=None)
        for ann in self.coco.loadAnns(ann_ids):
            category_id = ann['category_id']
            priority = self.label_priority.get(self._cat_names[category_id], 0)
            covered = self.coco.annToMask(ann) == 1

            # Claim unlabeled pixels and pixels held by a lower-priority category.
            claimed = covered & ((priority_map == 0) | (priority > priority_map))
            mask[claimed] = category_id
            priority_map[covered] = np.maximum(priority_map[covered], priority)
        return mask

    def __getitem__(self, idx):
        """Load, label, resize, optionally augment and tensorize one sample.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        image_id = self.image_ids[idx]
        file_name = self.coco.loadImgs([image_id])[0]['file_name']
        path = f"{DATASET_PATH}/{file_name}"
        image = cv.imread(path)
        # cv.imread reports failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image {path!r} (COCO image id {image_id})")
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)

        mask = self._build_mask(image_id, image.shape[0], image.shape[1])

        size = (self.image_size, self.image_size)
        # NOTE(review): pixels stay in [0, 255]; a pretrained backbone is usually
        # trained on normalized input -- consider album.Normalize.
        image = cv.resize(image, size).astype(np.float32)
        # Nearest-neighbour keeps every mask value a valid category id (no blending).
        mask = cv.resize(mask, size, interpolation=cv.INTER_NEAREST).astype(np.float32)

        if self.augmentation:
            augmented = self.augmentation(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]

        tensors = self._to_tensor(image=image, mask=mask)
        return tensors['image'], tensors['mask']

train_dataset = TurtlesDataset(coco, train_ids,
                               augmentation=get_augmentation())
# Validation stays un-augmented: random flips/rotations would make the validation
# loss that drives the LR scheduler and best-checkpoint selection non-deterministic.
val_dataset = TurtlesDataset(coco, valid_ids, augmentation=None)

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiClassDiceFocalLoss(nn.Module):
    """Weighted sum of multi-class soft Dice loss and focal loss.

    ``loss = dice_weight * dice + (1 - dice_weight) * focal``

    Args:
        smooth: additive smoothing in the Dice ratio (avoids 0/0 for classes
            absent from both prediction and target).
        alpha: constant scale on the focal term (the same for every class).
        gamma: focusing exponent; larger values down-weight well-classified pixels.
        dice_weight: mixing weight given to the Dice term (intended in [0, 1]).
        reduction: 'mean' or 'sum' -- how the per-class Dice values and the
            per-pixel focal values are each reduced to a scalar.
        apply_softmax: if True, ``inputs`` are logits; otherwise they are
            treated as probabilities already.

    Raises:
        ValueError: if ``reduction`` is not 'mean' or 'sum' (the Dice and focal
            terms have incompatible unreduced shapes, so they cannot be combined).
    """

    _REDUCTIONS = ('mean', 'sum')

    def __init__(self, smooth=1e-6, alpha=0.25, gamma=2.0, dice_weight=0.5, reduction='mean', apply_softmax=True):
        super(MultiClassDiceFocalLoss, self).__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.smooth = smooth
        self.alpha = alpha
        self.gamma = gamma
        self.dice_weight = dice_weight
        self.reduction = reduction
        self.apply_softmax = apply_softmax

    def _reduce(self, loss):
        """Reduce a loss tensor to a scalar according to ``self.reduction``."""
        return loss.mean() if self.reduction == 'mean' else loss.sum()

    def forward(self, inputs, targets):
        """Compute the combined loss.

        Args:
            inputs: (B, C, H, W) logits, or probabilities when ``apply_softmax=False``.
            targets: (B, H, W) int64 class indices in ``[0, C)``.

        Returns:
            A scalar tensor.

        Raises:
            ValueError: if the one-hot targets do not match the shape of ``inputs``.
        """
        num_classes = inputs.size(1)

        if self.apply_softmax:
            probs = torch.softmax(inputs, dim=1)
            # log_softmax stays finite and keeps a gradient of ~1 even when a class
            # probability underflows to 0; log(p + eps) would flatten the gradient
            # for confidently-wrong pixels exactly where it is needed most.
            log_probs = F.log_softmax(inputs, dim=1)
        else:
            probs = inputs
            log_probs = torch.log(inputs + 1e-8)

        targets_one_hot = (
            F.one_hot(targets, num_classes)
            .permute(0, 3, 1, 2)
            .to(device=probs.device, dtype=probs.dtype)
        )
        if probs.shape != targets_one_hot.shape:
            raise ValueError(f"Shape mismatch: inputs.shape = {inputs.shape}, targets_one_hot.shape = {targets_one_hot.shape}")

        # Soft Dice per class, pooled over batch and spatial dims.
        dims = (0, 2, 3)
        intersection = torch.sum(probs * targets_one_hot, dims)
        cardinality = torch.sum(probs + targets_one_hot, dims)
        dice_loss = self._reduce(1 - (2. * intersection + self.smooth) / (cardinality + self.smooth))

        # Focal loss per pixel: the one-hot mask keeps only the true-class term.
        focal = self.alpha * (1 - probs) ** self.gamma * (-targets_one_hot * log_probs)
        focal_loss = self._reduce(focal.sum(dim=1))

        return self.dice_weight * dice_loss + (1 - self.dice_weight) * focal_loss


In [6]:
def calculate_iou(pred, target, num_classes):
    """Per-class intersection-over-union for one prediction/target pair.

    Class 0 is skipped, so the result has ``num_classes - 1`` entries ordered
    from class 1 upward. A class that appears in neither ``pred`` nor
    ``target`` yields ``nan`` so callers can drop it with a nan-aware mean.
    """
    def _class_iou(label):
        in_pred = pred == label
        in_true = target == label
        union = (in_pred | in_true).sum().item()
        if union == 0:
            return float('nan')
        return (in_pred & in_true).sum().item() / union

    return [_class_iou(label) for label in range(1, num_classes)]


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork

class SwinSegmentation(nn.Module):
    """Swin Transformer backbone + FPN + 1x1 classifier for semantic segmentation.

    Args:
        num_classes: number of output logit maps.
        backbone_name: timm model name used as the multi-scale feature extractor.
        pretrained: whether timm loads pretrained backbone weights.
        img_size: input resolution the backbone is built for.
        fpn_channels: channel width of every FPN output level.
    """

    def __init__(self, num_classes, backbone_name='swin_small_patch4_window7_224',
                 pretrained=True, img_size=512, fpn_channels=64):
        super(SwinSegmentation, self).__init__()

        self.backbone = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            features_only=True,
            img_size=img_size,
        )

        # Take the per-stage channel counts from the backbone instead of
        # hard-coding them (for swin_small they are [96, 192, 384, 768]).
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=self.backbone.feature_info.channels(),
            out_channels=fpn_channels,
        )

        self.classifier = nn.Conv2d(fpn_channels, num_classes, kernel_size=1)

    def forward(self, x):
        """Return (B, num_classes, H, W) logits for a (B, 3, H, W) input batch."""
        # The backbone feature maps are permuted from channels-last (B, H, W, C)
        # to (B, C, H, W), which is what the FPN consumes.
        # NOTE(review): assumes the chosen backbone emits NHWC features, as Swin does here.
        features = [f.permute(0, 3, 1, 2) for f in self.backbone(x)]
        fpn_out = self.fpn(dict(enumerate(features)))

        # Classify the finest FPN level (key 0) and upsample back to input size.
        logits = self.classifier(fpn_out[0])
        return F.interpolate(logits, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=False)


In [8]:
# Pick the best available accelerator so the notebook also runs off Apple silicon.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Seed every RNG in play (model init, DataLoader shuffling, augmentations) for reproducibility.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

num_classes = 4
model = SwinSegmentation(num_classes)
model.to(device)
# dice_weight=0 disables the Dice term, so training optimizes the focal term only.
criterion = MultiClassDiceFocalLoss(smooth=1e-6, alpha=0.25, gamma=2.0, dice_weight=0)
# Per-class weights kept for a weighted cross-entropy alternative
# (WeightedCrossEntropyLoss is not defined in the visible notebook).
weights = [1.0, 1.3, 1.3, 1.3]
optimizer = optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-5, betas=(0.9, 0.999))

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [None]:
from tqdm.auto import tqdm

num_epochs = 25
best_val_loss = float('inf')
# `verbose=True` is deprecated for LR schedulers; LR reductions are printed manually below.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
train_losses = []
val_losses = []
iou_per_epoch = []


def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over `loader` and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    progress = tqdm(loader, desc=desc, leave=False)
    for images, masks in progress:
        images = images.to(device, dtype=torch.float32)
        masks = masks.to(device, dtype=torch.long)

        optimizer.zero_grad()
        loss = criterion(model(images), masks)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # a single device->host sync per step
        running_loss += batch_loss
        progress.set_postfix(loss=batch_loss)
    return running_loss / len(loader)


def _evaluate(model, loader, criterion, device, num_classes, desc):
    """Return (mean per-batch loss, list of per-image class-wise IoUs) over `loader`."""
    model.eval()
    running_loss = 0.0
    iou_scores = []
    progress = tqdm(loader, desc=desc, leave=False)
    with torch.no_grad():
        for images, masks in progress:
            images = images.to(device, dtype=torch.float32)
            masks = masks.to(device, dtype=torch.long)

            outputs = model(images)
            batch_loss = criterion(outputs, masks).item()
            running_loss += batch_loss

            preds = outputs.argmax(dim=1).cpu()
            for pred, target in zip(preds, masks.cpu()):
                iou_scores.append(calculate_iou(pred, target, num_classes))
            progress.set_postfix(val_loss=batch_loss)
    return running_loss / len(loader), iou_scores


for epoch in range(num_epochs):
    tag = f"Epoch [{epoch+1}/{num_epochs}]"

    avg_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device, f"{tag} - Training")
    train_losses.append(avg_loss)
    print(f'{tag}, Training Loss: {avg_loss:.4f}')

    avg_val_loss, iou_scores = _evaluate(model, val_loader, criterion, device, num_classes, f"{tag} - Validation")
    # One row per validation image, one column per non-background class; the reshape
    # keeps the table 2-D even when no images were scored.
    iou_table = torch.tensor(iou_scores, dtype=torch.float32).reshape(-1, num_classes - 1)
    # NaN marks a class absent from both prediction and target; nanmean skips it.
    # NOTE(review): this averages per-image IoUs instead of accumulating a
    # dataset-level intersection/union -- confirm this is the intended metric.
    mean_iou = iou_table.nanmean(dim=0).tolist()
    mean_iou_overall = iou_table.nanmean().item()

    previous_lrs = [group['lr'] for group in optimizer.param_groups]
    scheduler.step(avg_val_loss)
    for group_idx, (old_lr, group) in enumerate(zip(previous_lrs, optimizer.param_groups)):
        if group['lr'] != old_lr:
            print(f"{tag}: reducing learning rate of group {group_idx} to {group['lr']:.4e}.")

    val_losses.append(avg_val_loss)
    iou_per_epoch.append(mean_iou_overall)

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), 'best_model.pth')

    print(f'{tag}, Validation Loss: {avg_val_loss:.4f}')
    print(f'Class-wise IOU: {mean_iou}')
    print(f'Mean IOU (Overall): {mean_iou_overall:.4f}')
    if (epoch + 1) % 5 == 0:
        # Named after the model actually being trained (previously mislabeled "NestedUnet").
        torch.save(model.state_dict(), f'SwinSegmentation_epoch_{epoch+1}.pth')
        print(f"Model saved at epoch {epoch+1}")


Epoch [1/25] - Training:   1%|    | 39/5303 [00:14<30:30,  2.88it/s, loss=0.327]

In [None]:
num_epochs=[i+1 for i in range(num_epochs)]
plt.figure(figsize=(20, 12))
plt.plot(num_epochs, train_losses, label='Training Loss')
plt.plot(num_epochs, val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss over Epochs')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
plt.figure(figsize=(20, 12))
plt.plot(num_epochs, iou_per_epoch, label='Mean IoU', color='tab:blue')
plt.xlabel('Epoch')
plt.ylabel('Mean IoU')
plt.title('Mean IoU over Epochs')
plt.legend()
plt.grid(True)
plt.show()