# Imports

In [1]:
!pip install torch torchvision segment-anything pycocotools pillow tqdm

Collecting pycocotools
  Downloading pycocotools-2.0.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.1 kB)
Downloading pycocotools-2.0.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (427 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m427.8/427.8 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pycocotools
Successfully installed pycocotools-2.0.8


In [2]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from segment_anything import sam_model_registry
from pycocotools.coco import COCO
from PIL import Image
import numpy as np
import torchvision.transforms as T
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import logging
from pathlib import Path
from tqdm import tqdm

# Config

In [3]:
CHECKPOINT_PATH = "/kaggle/input/sam/other/default/1/sam_vit_l_0b3195.pth"
TRAIN_ROOT = '/kaggle/input/underwaterimageinstancesegmentation/UIIS/UDW/train'
TRAIN_ANN = '/kaggle/input/underwaterimageinstancesegmentation/UIIS/UDW/annotations/train.json'
VAL_ROOT = '/kaggle/input/underwaterimageinstancesegmentation/UIIS/UDW/val'
VAL_ANN = '/kaggle/input/underwaterimageinstancesegmentation/UIIS/UDW/annotations/val.json'

# Dataset Setup

In [4]:
class UnderwaterInstanceDataset(Dataset):
    def __init__(self, root_dir, ann_file, transform=None, image_size=(1024, 1024)):
        self.root = root_dir
        self.coco = COCO(ann_file)
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.image_size = image_size
        
    def __len__(self):
        return len(self.ids)
    
    def __getitem__(self, idx):
        img_id = self.ids[idx]
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)
        
        # Load image
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.root, img_info['file_name'])
        image = Image.open(img_path).convert('RGB')
        
        # Resize image
        image = image.resize(self.image_size, Image.Resampling.BILINEAR)
        
        # Create mask
        masks = []
        labels = []
        for ann in anns:
            mask = self.coco.annToMask(ann)
            # Resize mask
            mask = Image.fromarray(mask)
            mask = mask.resize(self.image_size, Image.Resampling.NEAREST)
            mask = np.array(mask)
            masks.append(mask)
            labels.append(ann['category_id'])
            
        # Stack masks along new dimension
        if len(masks) > 0:
            masks = np.stack(masks, axis=0)
        else:
            masks = np.zeros((1, *self.image_size), dtype=np.uint8)
            labels = [0]  # Add dummy label
            
        # Convert to tensors
        if self.transform:
            image = self.transform(image)
            masks = torch.from_numpy(masks).float()
            
        return image, masks, torch.tensor(labels)

In [5]:
def collate_fn(batch):
    images = torch.stack([item[0] for item in batch])
    
    # Find max number of masks across batch
    max_masks = max(item[1].shape[0] for item in batch)
    
    # Pad masks and labels to max size
    padded_masks = []
    padded_labels = []
    
    for _, masks, labels in batch:
        num_masks = masks.shape[0]
        
        # Create padding for masks
        if num_masks < max_masks:
            padding = torch.zeros((max_masks - num_masks, masks.shape[1], masks.shape[2]))
            padded_mask = torch.cat([masks, padding], dim=0)
            
            # Create padding for labels 
            label_padding = torch.zeros(max_masks - num_masks)
            padded_label = torch.cat([labels, label_padding])
        else:
            padded_mask = masks
            padded_label = labels
            
        padded_masks.append(padded_mask)
        padded_labels.append(padded_label)
    
    # Stack padded masks and labels
    masks = torch.stack(padded_masks)
    labels = torch.stack(padded_labels)
    
    return images, masks, labels

# Model Training

In [6]:
import gc
import torch.cuda.amp as amp
from torch.utils.checkpoint import checkpoint

class SAMFinetune(nn.Module):
    def __init__(self, checkpoint_path, model_type="vit_l"):
        super().__init__()
        self.sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
        
        # Enable gradient checkpointing
        self.sam.image_encoder.use_gradient_checkpointing = True
        self.sam.mask_decoder.transformer.use_gradient_checkpointing = True
        
        # Freeze image encoder
        for param in self.sam.image_encoder.parameters():
            param.requires_grad = False
            
        # Freeze prompt encoder
        for param in self.sam.prompt_encoder.parameters():
            param.requires_grad = False
            
        # Only train mask decoder
        for param in self.sam.mask_decoder.parameters():
            param.requires_grad = True

    def forward(self, image, masks=None):
        # Clear cache before forward pass
        torch.cuda.empty_cache()
        gc.collect()
        
        batch_size = image.shape[0]
        num_instances = masks.shape[1] if masks is not None else 1
        
        # Get image embeddings with gradient checkpointing
        with torch.cuda.amp.autocast():
            image_embeddings = checkpoint(self.sam.image_encoder, image)
        
        final_masks = []
        
        for b in range(batch_size):
            curr_masks = masks[b] if masks is not None else None
            curr_embeddings = image_embeddings[b:b+1]
            
            instance_masks = []
            
            for i in range(num_instances):
                if curr_masks is not None:
                    mask = curr_masks[i]
                    if mask.sum() > 0:
                        y_indices, x_indices = torch.where(mask > 0)
                        center_y = y_indices.float().mean()
                        center_x = x_indices.float().mean()
                    else:
                        center_y = torch.rand(1, device=image.device) * image.shape[2]
                        center_x = torch.rand(1, device=image.device) * image.shape[3]
                    
                    point_coords = torch.tensor([[center_x, center_y]], device=image.device)
                    point_labels = torch.ones(1, device=image.device)
                    
                    with torch.cuda.amp.autocast():
                        sparse_embeddings, dense_embeddings = self.sam.prompt_encoder(
                            points=(point_coords.unsqueeze(0), point_labels.unsqueeze(0)),
                            boxes=None,
                            masks=None
                        )
                        
                        pos_encoding = self.sam.prompt_encoder.get_dense_pe()
                        
                        low_res_masks, _ = self.sam.mask_decoder(
                            image_embeddings=curr_embeddings,
                            image_pe=pos_encoding,
                            sparse_prompt_embeddings=sparse_embeddings,
                            dense_prompt_embeddings=dense_embeddings,
                            multimask_output=False,
                        )
                        
                        curr_mask = self.sam.postprocess_masks(
                            low_res_masks,
                            input_size=image.shape[-2:],
                            original_size=image.shape[-2:]
                        )
                    
                    instance_masks.append(curr_mask.squeeze())
                    
                    # Clear cache after each instance
                    torch.cuda.empty_cache()
                    gc.collect()
            
            if instance_masks:
                image_masks = torch.stack(instance_masks, dim=0)
            else:
                image_masks = torch.zeros((num_instances, *image.shape[-2:]), device=image.device)
            
            final_masks.append(image_masks)
            
        output_masks = torch.stack(final_masks, dim=0)
        return output_masks

In [7]:
# def train_sam():
#     # Setup
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     torch.backends.cudnn.benchmark = True
    
#     # Use smaller image size
#     image_size = (512, 512)  # Reduced from 1024x1024
    
#     # Transform
#     transform = T.Compose([
#         T.ToTensor(),
#         T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#     ])

#     # Datasets
#     train_dataset = UnderwaterInstanceDataset(
#         root_dir=TRAIN_ROOT,
#         ann_file=TRAIN_ANN,
#         transform=transform,
#         image_size=image_size
#     )
    
#     val_dataset = UnderwaterInstanceDataset(
#         root_dir=VAL_ROOT,
#         ann_file=VAL_ANN,
#         transform=transform,
#         image_size=image_size
#     )

#     # Smaller batch size
#     train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2, collate_fn=collate_fn)
#     val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=2, collate_fn=collate_fn)

#     model = SAMFinetune(CHECKPOINT_PATH).to(device)
    
#     # Initialize mixed precision training
#     scaler = amp.GradScaler()
    
#     optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
#     scheduler = CosineAnnealingLR(optimizer, T_max=50)
#     dice_loss = nn.BCEWithLogitsLoss()
#     iou_loss = nn.BCEWithLogitsLoss()

#     # Training loop
#     best_val_loss = float('inf')
#     num_epochs = 50
    
#     for epoch in range(num_epochs):
#         model.train()
#         total_loss = 0
        
#         for batch_idx, (images, masks, labels) in enumerate(tqdm(train_loader)):
#             images = images.to(device)
#             masks = masks.to(device)
            
#             # Clear cache before each batch
#             torch.cuda.empty_cache()
#             gc.collect()
            
#             optimizer.zero_grad()
            
#             # Use mixed precision training
#             with amp.autocast():
#                 pred_masks = model(images, masks)
#                 loss = dice_loss(pred_masks, masks) + iou_loss(pred_masks, masks)
            
#             scaler.scale(loss).backward()
#             scaler.step(optimizer)
#             scaler.update()
            
#             total_loss += loss.item()
            
#             # Free up memory
#             del pred_masks, loss
#             torch.cuda.empty_cache()
#             gc.collect()

#         # Validation with reduced memory usage
#         model.eval()
#         val_loss = 0
        
#         with torch.no_grad():
#             for images, masks, labels in val_loader:
#                 images = images.to(device)
#                 masks = masks.to(device)
                
#                 with amp.autocast():
#                     pred_masks = model(images)
#                     loss = dice_loss(pred_masks, masks) + iou_loss(pred_masks, masks)
#                 val_loss += loss.item()
                
#                 del pred_masks, loss
#                 torch.cuda.empty_cache()
#                 gc.collect()

#         avg_loss = total_loss / len(train_loader)
#         avg_val_loss = val_loss / len(val_loader)
        
#         print(f'Epoch {epoch+1}/{num_epochs}:')
#         print(f'Training Loss: {avg_loss:.4f}')
#         print(f'Validation Loss: {avg_val_loss:.4f}')

#         if avg_val_loss < best_val_loss:
#             best_val_loss = avg_val_loss
#             torch.save(model.state_dict(), 'best_model.pth')

#         scheduler.step()

# Training Model

# Evaluation

In [8]:
# def evaluate_sam(model):
#     model.eval()
    
#     # Load validation dataset
#     val_dataset = UnderwaterDataset(
#         coco_annotation='/kaggle/input/underwaterimageinstancesegmentation/UIIS/UDW/annotations/val.json',
#         img_dir='/kaggle/input/underwaterimageinstancesegmentation/UIIS/UDW/val',
#         transform=transforms.Compose([
#             transforms.ToTensor(),
#             transforms.Normalize(mean=[0.485, 0.456, 0.406], 
#                                std=[0.229, 0.224, 0.225])
#         ])
#     )
    
#     val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    
#     # Metrics
#     iou_scores = []
    
#     with torch.no_grad():
#         for images, masks in tqdm(val_loader):
#             images = images.to(DEVICE)
#             masks = masks.to(DEVICE)
            
#             # Get predictions
#             mask_predictions, _ = model(images, None)
#             pred_masks = (torch.sigmoid(mask_predictions) > 0.5).float()
            
#             # Calculate IoU
#             intersection = (pred_masks * masks).sum()
#             union = pred_masks.sum() + masks.sum() - intersection
#             iou = (intersection + 1e-6) / (union + 1e-6)
#             iou_scores.append(iou.item())
    
#     mean_iou = np.mean(iou_scores)
#     print(f"Mean IoU: {mean_iou:.4f}")
    
#     return mean_iou

# Driver

In [9]:
# trained_model = train_sam()
# torch.save(trained_model.state_dict(), 'underwater_sam.pth')
# mean_iou = evaluate_sam(trained_model)

In [10]:
def evaluate_sam(model):
    model.eval()
    
    # Load validation dataset
    val_dataset = UnderwaterInstanceDataset(
        root_dir=VAL_ROOT,
        ann_file=VAL_ANN,
        transform=T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], 
                       std=[0.229, 0.224, 0.225])
        ]),
        image_size=(1024, 1024)
    )
    
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4, collate_fn=collate_fn)
    
    # Metrics
    iou_scores = []
    
    with torch.no_grad():
        for images, masks, _ in tqdm(val_loader):
            images = images.to(device)
            masks = masks.to(device)
            
            # Get predictions
            pred_masks = model(images, masks)
            pred_masks = (torch.sigmoid(pred_masks) > 0.5).float()
            
            # Calculate IoU for each instance
            for i in range(masks.shape[1]):
                if masks[:, i].sum() > 0:  # Only evaluate non-empty masks
                    intersection = (pred_masks[:, i] * masks[:, i]).sum()
                    union = pred_masks[:, i].sum() + masks[:, i].sum() - intersection
                    iou = (intersection + 1e-6) / (union + 1e-6)
                    iou_scores.append(iou.item())
    
    mean_iou = np.mean(iou_scores)
    print(f"Mean IoU: {mean_iou:.4f}")
    print(f"Number of instances evaluated: {len(iou_scores)}")
    
    return mean_iou

In [11]:
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Create and load model
    model = SAMFinetune(CHECKPOINT_PATH).to(device)
    
    # Evaluate model
    print("Evaluating pretrained SAM model...")
    evaluate_sam(model)

  state_dict = torch.load(f)


Evaluating pretrained SAM model...
loading annotations into memory...
Done (t=0.49s)
creating index...
index created!


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
  with torch.cuda.amp.autocast():
100%|██████████| 691/691 [14:03<00:00,  1.22s/it]

Mean IoU: 0.6363
Number of instances evaluated: 3784





0.6362520139296935