In [1]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import timm
import json
from sklearn.metrics import jaccard_score
import albumentations as A
from albumentations.pytorch import ToTensorV2
import scipy.io
import torch.utils.checkpoint as checkpoint
import matplotlib.pyplot as plt
import random
from tqdm import tqdm

  check_for_updates()


In [2]:
# Set environment variable to help with memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Define paths
LVIS_BASE_PATH = '/kaggle/input/lvis-v1'
BERKELEY_BASE_PATH = '/kaggle/input/berkeley-segmentation-dataset-500-bsds500'
OUTPUT_DIR = '/kaggle/working/output'
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [3]:
# Dataset class for LVIS (training only)
class LVISDataset(Dataset):
    def __init__(self, image_dir, ann_file, transform=None):
        self.image_dir = image_dir
        with open(ann_file, 'r') as f:
            self.annotations = json.load(f)['annotations']
        self.transform = transform

        # Filter image_ids to only include images that exist in the directory
        all_image_ids = list(set(ann['image_id'] for ann in self.annotations))
        self.image_ids = []
        for img_id in all_image_ids:
            img_path = os.path.join(self.image_dir, f'{img_id:012d}.jpg')
            if os.path.exists(img_path):
                self.image_ids.append(img_id)
            else:
                print(f"Warning: Image {img_path} not found, skipping.")

        # Use only 1/20th of the dataset
        self.image_ids = self.image_ids[:len(self.image_ids) // 20]
        print(f"Total images available for training (1/4th): {len(self.image_ids)}")

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_path = os.path.join(self.image_dir, f'{img_id:012d}.jpg')
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Image not found at {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Create mask (binary: foreground=1, background=0)
        mask = np.zeros(image.shape[:2], dtype=np.uint8)
        for ann in self.annotations:
            if ann['image_id'] == img_id:
                seg = ann['segmentation']
                for poly in seg:
                    poly = np.array(poly).reshape(-1, 2).astype(np.int32)
                    cv2.fillPoly(mask, [poly], 1)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']

        return image, mask

# Dataset class for Berkeley (test only)
class BerkeleyDataset(Dataset):
    def __init__(self, base_path, split='test', transform=None):
        self.split = split
        self.image_dir = os.path.join(base_path, 'images', split)
        self.mask_dir = os.path.join(base_path, 'ground_truth', split)
        self.transform = transform
        self.images = [f for f in os.listdir(self.image_dir) if f.endswith('.jpg')]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_name = img_name.replace('.jpg', '.mat')
        mask_path = os.path.join(self.mask_dir, mask_name)
        
        # Load image
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Image not found at {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Load mask from .mat file
        mat = scipy.io.loadmat(mask_path)
        mask = mat['groundTruth'][0, 0]['Segmentation'][0, 0]
        mask = (mask > 0).astype(np.uint8)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']

        return image, mask, img_name

In [4]:
# Data transforms (224x224 to match ViT-B expectation)
train_transform = A.Compose([
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=30, p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

test_transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# Probability map modulation
def modulate_probability_map(prob_map, gamma=2.0, threshold=0.7):
    modulated = prob_map.clone()
    fg_mask = prob_map > threshold
    modulated[fg_mask] = torch.pow(modulated[fg_mask], 1/gamma)
    bg_mask = prob_map < (1 - threshold)
    modulated[bg_mask] = torch.pow(modulated[bg_mask], gamma)
    return modulated

# MFP Network (ViT-B only)
class MFPNet(nn.Module):
    def __init__(self):
        super(MFPNet, self).__init__()
        # ViT-B backbone
        self.encoder = timm.create_model('vit_base_patch16_224', pretrained=True)
        self.feature_channels = 768  # ViT-B has 768 channels
        self.patch_size = 16

        # Probability feature extractor
        self.prob_conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        # Fusion and segmentation head
        self.fusion = nn.Sequential(
            nn.Conv2d(self.feature_channels + 64, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.seg_head = nn.Conv2d(256, 1, kernel_size=1)

    def forward(self, image, prob_map=None):
        # Backbone features (ViT-B)
        x = self.encoder.forward_features(image)  # [batch_size, 197, 768]
        x = x[:, 1:]  # Remove CLS token: [batch_size, 196, 768]
        # Reshape to feature map: 196 patches = 14x14 grid for 224x224
        grid_size = int((x.shape[1]) ** 0.5)  # 14 for 224x224
        x = x.transpose(1, 2).reshape(-1, self.feature_channels, grid_size, grid_size)  # [batch_size, 768, 14, 14]

        # If no prob_map provided, initialize with zeros
        if prob_map is None:
            prob_map = torch.zeros((image.shape[0], 1, image.shape[2], image.shape[3]), device=image.device)

        # Modulate probability map
        modulated_prob = modulate_probability_map(prob_map)

        # Extract probability features
        prob_features = checkpoint.checkpoint_sequential(self.prob_conv, segments=2, input=modulated_prob)  # [batch_size, 64, 224, 224]

        # Upsample backbone features to match prob_features size (224x224)
        x = F.interpolate(x, size=prob_features.shape[2:], mode='bilinear', align_corners=False)

        # Late fusion
        fused = torch.cat([x, prob_features], dim=1)
        fused = checkpoint.checkpoint_sequential(self.fusion, segments=2, input=fused)

        # Segmentation head
        logits = self.seg_head(fused)
        prob = torch.sigmoid(logits)
        return prob

# Dice loss
class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, pred, target, smooth=1):
        pred = pred.contiguous().view(-1)
        target = target.contiguous().view(-1)
        intersection = (pred * target).sum()
        return 1 - ((2. * intersection + smooth) / (pred.sum() + target.sum() + smooth))

# Training function (with accuracy tracking and tqdm progress bar)
def train_model(model, train_loader, num_epochs=5, device='cuda'):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion_ce = nn.BCELoss()
    criterion_dice = DiceLoss()

    train_accuracies = []
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        train_acc = 0
        # Wrap the train_loader with tqdm for a progress bar
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)
        for images, masks in progress_bar:
            images, masks = images.to(device), masks.to(device).float()
            masks = masks.unsqueeze(1)

            prob = model(images)
            prob = model(images, prob)

            loss = 0.5 * criterion_ce(prob, masks) + 0.5 * criterion_dice(prob, masks)
            train_loss += loss.item()

            # Compute accuracy
            pred = (prob > 0.5).float()
            acc = (pred == masks).float().mean().item()
            train_acc += acc

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update progress bar with current loss and accuracy
            progress_bar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{acc:.4f}'})

        epoch_acc = train_acc / len(train_loader)
        train_accuracies.append(epoch_acc)
        print(f'Epoch {epoch+1}/{num_epochs}, Train Accuracy: {epoch_acc:.4f}')

        # Save model
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, f'model_vit_b_epoch{epoch+1}.pth'))

    return train_accuracies

# Evaluation and visualization function
def evaluate_and_visualize(model, test_loader, device='cuda'):
    model.eval()
    miou = 0
    acc = 0
    # Select a random image for visualization
    test_dataset = test_loader.dataset
    random_idx = random.randint(0, len(test_dataset) - 1)
    image, mask, img_name = test_dataset[random_idx]
    image = image.unsqueeze(0).to(device)  # Add batch dimension
    mask = mask.unsqueeze(0).to(device).float()  # Add batch dimension

    with torch.no_grad():
        # Compute metrics on the entire test set
        for images, masks, _ in test_loader:
            images, masks = images.to(device), masks.to(device).float()
            masks = masks.unsqueeze(1)

            prob = model(images)
            pred = (prob > 0.5).float()
            miou += jaccard_score(masks.cpu().numpy().flatten(), pred.cpu().numpy().flatten(), average='binary')
            acc += (pred == masks).float().mean().item()

        # Predict on the random image
        prob = model(image)
        pred = (prob > 0.5).float()

    test_miou = miou / len(test_loader)
    test_acc = acc / len(test_loader)

    # Visualize the results
    # Denormalize the image for display
    image = image.squeeze(0).cpu().numpy().transpose(1, 2, 0)  # [3, 224, 224] -> [224, 224, 3]
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = (image * std + mean) * 255
    image = image.astype(np.uint8)

    mask = mask.squeeze().cpu().numpy()  # [1, 224, 224] -> [224, 224]
    pred = pred.squeeze().cpu().numpy()  # [1, 224, 224] -> [224, 224]

    # Plot
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(image)
    axes[0].set_title(f"Input Image: {img_name}")
    axes[0].axis('off')

    axes[1].imshow(mask, cmap='gray')
    axes[1].set_title("Ground Truth Mask")
    axes[1].axis('off')

    axes[2].imshow(pred, cmap='gray')
    axes[2].set_title("Predicted Mask (ViT-B)")
    axes[2].axis('off')

    plt.savefig(os.path.join(OUTPUT_DIR, f'segmentation_vit_b_{img_name}.png'))
    plt.close()

    return test_miou, test_acc

In [5]:
# Main execution
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Clear GPU memory cache
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    # Load datasets
    train_dataset = LVISDataset(
        image_dir=os.path.join(LVIS_BASE_PATH, 'train2017', 'train2017'),
        ann_file=os.path.join(LVIS_BASE_PATH, '/kaggle/input/lvis-v1/lvis_v1_train.json/lvis_v1_train.json'),
        transform=train_transform
    )
    test_dataset = BerkeleyDataset(
        base_path=BERKELEY_BASE_PATH,
        split='test',
        transform=test_transform
    )

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

    # Train and evaluate with ViT-B
    print('\nTraining with ViT-B backbone')
    model_vit = MFPNet()
    vit_train_accuracies = train_model(model_vit, train_loader, num_epochs=10, device=device)

    print('\nEvaluating with ViT-B backbone')
    vit_test_miou, vit_test_acc = evaluate_and_visualize(model_vit, test_loader, device=device)
    print(f'ViT-B Test mIoU: {vit_test_miou:.4f}, Test Accuracy: {vit_test_acc:.4f}')

    # Print training accuracies
    print("\nViT-B Training Accuracies per Epoch:")
    for epoch, acc in enumerate(vit_train_accuracies, 1):
        print(f"Epoch {epoch}: {acc:.4f}")

if __name__ == '__main__':
    main()

Total images available for training (1/4th): 4969

Training with ViT-B backbone


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

                                                                                        

Epoch 1/10, Train Accuracy: 0.7689


                                                                                        

Epoch 2/10, Train Accuracy: 0.7601


                                                                                        

Epoch 3/10, Train Accuracy: 0.7640


                                                                                        

Epoch 4/10, Train Accuracy: 0.7676


                                                                                        

Epoch 5/10, Train Accuracy: 0.7707


                                                                                        

Epoch 6/10, Train Accuracy: 0.7738


                                                                                        

Epoch 7/10, Train Accuracy: 0.7752


                                                                                        

Epoch 8/10, Train Accuracy: 0.7769


                                                                                        

Epoch 9/10, Train Accuracy: 0.7782


                                                                                         

Epoch 10/10, Train Accuracy: 0.7801

Evaluating with ViT-B backbone




ViT-B Test mIoU: 0.2326, Test Accuracy: 0.2326

ViT-B Training Accuracies per Epoch:
Epoch 1: 0.7689
Epoch 2: 0.7601
Epoch 3: 0.7640
Epoch 4: 0.7676
Epoch 5: 0.7707
Epoch 6: 0.7738
Epoch 7: 0.7752
Epoch 8: 0.7769
Epoch 9: 0.7782
Epoch 10: 0.7801
