In [None]:
!pip install segmentation_models_pytorch

import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
from torchmetrics.classification import MulticlassJaccardIndex, MulticlassAccuracy, MulticlassConfusionMatrix
import timm
from PIL import Image
from tqdm import tqdm
import random
import glob
from einops import rearrange
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

# Set random seed for reproducibility
def set_seed(seed=42):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and current CUDA
    device), then switches cuDNN to deterministic mode with the autotuner
    disabled.

    Args:
        seed: Integer seed handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade speed for run-to-run repeatability of cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed()

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Mixed precision setup
# torch.cuda.amp only has an effect on CUDA; requesting it on a CPU-only
# machine merely produces warnings, so enable it only when a GPU is in use.
use_amp = device.type == 'cuda'
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
print(f"Using mixed precision training (F16+F32): {use_amp}")

# Constants
BATCH_SIZE = 32
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 1024
# The colour palette used to decode the masks has 20 entries: index 0 is
# background and indices 1-19 are the Cityscapes evaluation classes. Label
# ids therefore span 0..19, and every consumer of NUM_CLASSES (one-hot
# encoding, the decoder head, the metrics) needs room for all 20 of them.
NUM_CLASSES = 20
NUM_EPOCHS = 20
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
BASE_PATH = '/kaggle/input'
CITYSCAPES_IMAGES = f'{BASE_PATH}/cityscapes-leftimg8bit-trainvaltest/leftImg8bit'
CITYSCAPES_MASKS = f'{BASE_PATH}/yrealdataset/gtFine'

# Cityscapes colour palette: the position of each RGB triple in this list is
# the class index used throughout the script (0 = background, 1..19 = classes).
cityscapes_classes = [
    (0, 0, 0),         # 0  background
    (128, 64, 128),    # 1  road
    (244, 35, 232),    # 2  sidewalk
    (70, 70, 70),      # 3  building
    (102, 102, 156),   # 4  wall
    (190, 153, 153),   # 5  fence
    (153, 153, 153),   # 6  pole
    (250, 170, 30),    # 7  traffic light
    (220, 220, 0),     # 8  traffic sign
    (107, 142, 35),    # 9  vegetation
    (152, 251, 152),   # 10 terrain
    (70, 130, 180),    # 11 sky
    (220, 20, 60),     # 12 person
    (255, 0, 0),       # 13 rider
    (0, 0, 142),       # 14 car
    (0, 0, 70),        # 15 truck
    (0, 60, 100),      # 16 bus
    (0, 80, 100),      # 17 train
    (0, 0, 230),       # 18 motorcycle
    (119, 11, 32),     # 19 bicycle
]


def create_label_mapping():
    """Build the RGB-triple -> class-index lookup from ``cityscapes_classes``.

    Returns:
        A new dict keyed by ``(r, g, b)`` tuples whose values are the
        position of that colour in ``cityscapes_classes``.
    """
    return {rgb: index for index, rgb in enumerate(cityscapes_classes)}


color_to_class = create_label_mapping()

# Custom Dataset
class CityscapesDataset(Dataset):
    """Cityscapes image / colour-mask pairs decoded to class-index maps.

    Images are discovered under ``CITYSCAPES_IMAGES/<split>/<city>/`` and
    paired with ``<name>_gtFine_color.png`` under
    ``CITYSCAPES_MASKS/<split>/<city>/``. Images without a mask are skipped
    with a warning, so ``self.images[i]`` and ``self.masks[i]`` always refer
    to the same sample.

    Args:
        split: Sub-directory name of the split (e.g. ``'train'``, ``'val'``).
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, split, transform=None):
        self.split = split
        self.transform = transform
        self.color_to_class = color_to_class

        self.images = []
        self.masks = []

        image_pattern = os.path.join(CITYSCAPES_IMAGES, split, '*', '*_leftImg8bit.png')
        print(f"Looking for images in: {image_pattern}")

        # Build both lists together. Removing entries from a list while
        # iterating over it skips elements and would leave images and masks
        # out of step with each other.
        for img_path in sorted(glob.glob(image_pattern)):
            city = os.path.basename(os.path.dirname(img_path))
            filename = os.path.basename(img_path).replace('_leftImg8bit.png', '')

            mask_path = os.path.join(CITYSCAPES_MASKS, split, city, f"{filename}_gtFine_color.png")
            if not os.path.exists(mask_path):
                print(f"Warning: Mask not found for {img_path}")
                continue

            self.images.append(img_path)
            self.masks.append(mask_path)

        print(f"Found {len(self.images)} image-mask pairs for {split} set")

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _rgb_to_class_indices(mask, palette=None, tolerance=30):
        """Convert an ``(H, W, 3)`` colour mask into an ``(H, W)`` index map.

        Each pixel receives the index of the *nearest* palette colour (L1
        distance over the RGB channels), provided that distance is below
        ``tolerance``; pixels matching nothing keep index 0. Picking the
        nearest colour matters because some palette entries are closer
        together than the tolerance (bus and train differ by only 20), so
        "last colour within tolerance wins" would mislabel exact matches.

        Args:
            mask: Array-like of shape ``(H, W, 3)`` holding RGB values.
            palette: Sequence of RGB triples; defaults to
                ``cityscapes_classes``.
            tolerance: Exclusive upper bound on the accepted L1 distance.

        Returns:
            ``np.ndarray`` of dtype ``int64`` and shape ``(H, W)``.
        """
        if palette is None:
            palette = cityscapes_classes

        # Signed type so the subtraction below cannot wrap around.
        rgb = np.asarray(mask).astype(np.int32)
        class_map = np.zeros(rgb.shape[:2], dtype=np.int64)
        best_distance = np.full(rgb.shape[:2], tolerance, dtype=np.int64)

        for class_idx, color in enumerate(palette):
            distance = np.abs(rgb - np.asarray(color, dtype=np.int32)).sum(axis=2)
            closer = distance < best_distance
            class_map[closer] = class_idx
            best_distance[closer] = distance[closer]

        return class_map

    def __getitem__(self, idx):
        img_path = self.images[idx]
        mask_path = self.masks[idx]

        # Load image and mask
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("RGB"))

        # Apply transformations if any
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Convert RGB mask to class indices
        mask_indices = self._rgb_to_class_indices(mask)

        return image, mask_indices

# Data Augmentation
def _normalize_to_tensor():
    """Return the steps shared by both pipelines: normalisation, then tensor conversion."""
    return [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]


# Training: random rescale, fixed-size random crop, flip and photometric jitter.
train_transform = A.Compose(
    [
        A.RandomScale(scale_limit=0.1),
        A.RandomCrop(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        *_normalize_to_tensor(),
    ]
)

# Validation / test: resize to the working resolution only.
val_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        *_normalize_to_tensor(),
    ]
)

# Initialize datasets
print("Initializing datasets...")
train_dataset, val_dataset, test_dataset = (
    CityscapesDataset(split=split_name, transform=pipeline)
    for split_name, pipeline in (
        ('train', train_transform),
        ('val', val_transform),
        ('test', val_transform),
    )
)

# Create data loaders (only the training loader shuffles)
_loader_options = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}, Test batches: {len(test_loader)}")

# CNN-ViT Hybrid Model
class CNNViTHybrid(nn.Module):
    """ResNet-50 features refined by ViT transformer blocks, then upsampled.

    Pipeline: ResNet-50 trunk (stride 32) -> 1x1 conv to the ViT embedding
    width -> ViT transformer blocks applied to the CNN feature map as a token
    sequence -> five stride-2 transposed convolutions (x32 upsampling) that
    produce per-pixel class logits.

    The ViT's own patch embedding is bypassed: it expects a ``[B, 3, H, W]``
    image, whereas here the tokens come from the CNN. The pretrained
    positional embedding is resampled to the CNN feature-map grid.

    Args:
        num_classes: Number of output channels (segmentation classes).
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()

        # CNN Backbone (ResNet50 without average pool and classification head)
        self.cnn_backbone = models.resnet50(pretrained=True)
        self.cnn_features = nn.Sequential(*list(self.cnn_backbone.children())[:-2])

        # ViT Encoder using timm
        self.vit = timm.create_model('vit_base_patch16_384', pretrained=True, num_classes=0)
        vit_embed_dim = self.vit.embed_dim

        # Adapter to match CNN feature channels to the ViT token width
        self.adapter = nn.Conv2d(2048, vit_embed_dim, kernel_size=1)

        # Decoder: each stage doubles the spatial resolution (2**5 == 32)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(vit_embed_dim, 512, kernel_size=2, stride=2),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, num_classes, kernel_size=2, stride=2),
        )

    def _positional_embedding(self, height, width):
        """Return the ViT patch positional embedding resampled to ``height x width``.

        Prefix (class-token) positions are dropped because no class token is
        prepended to the CNN-derived sequence.

        Returns:
            Tensor of shape ``[1, height * width, C]``.
        """
        grid_h, grid_w = self.vit.patch_embed.grid_size
        # Patch positions are the trailing grid_h * grid_w entries.
        patch_pos = self.vit.pos_embed[:, -grid_h * grid_w:]
        patch_pos = rearrange(patch_pos, 'b (h w) c -> b c h w', h=grid_h, w=grid_w)
        if (grid_h, grid_w) != (height, width):
            patch_pos = F.interpolate(
                patch_pos, size=(height, width), mode='bicubic', align_corners=False
            )
        return rearrange(patch_pos, 'b c h w -> b (h w) c')

    def forward(self, x):
        """Compute segmentation logits.

        Args:
            x: Image batch ``[B, 3, H, W]``; ``H`` and ``W`` should be
                multiples of 32 for the output to match the input size.

        Returns:
            Logits of shape ``[B, num_classes, H, W]``.
        """
        # Get CNN features
        cnn_features = self.cnn_features(x)  # [B, 2048, H/32, W/32]

        # Adapt CNN features for ViT
        adapted_features = self.adapter(cnn_features)  # [B, C, H/32, W/32]

        # Flatten the feature map into a token sequence
        _, _, feat_h, feat_w = adapted_features.shape
        tokens = rearrange(adapted_features, 'b c h w -> b (h w) c')

        # Run the transformer blocks directly on the tokens; forward_features
        # would first apply the image patch embedding, which cannot consume a
        # [B, N, C] sequence.
        tokens = tokens + self._positional_embedding(feat_h, feat_w)
        tokens = self.vit.pos_drop(tokens)
        tokens = self.vit.blocks(tokens)
        tokens = self.vit.norm(tokens)  # [B, (H/32)*(W/32), C]

        # Reshape back to spatial dimensions
        vit_features = rearrange(tokens, 'b (h w) c -> b c h w', h=feat_h, w=feat_w)

        # Decode to segmentation map
        output = self.decoder(vit_features)  # [B, num_classes, H, W]

        return output

# Focal Tversky Loss
class FocalTverskyLoss(nn.Module):
    """Focal Tversky loss computed over all elements of the batch at once.

    ``loss = (1 - TI) ** gamma`` with
    ``TI = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)``,
    where TP / FP / FN are soft counts summed over every element of
    ``inputs`` and ``targets``.

    Args:
        alpha: Weight applied to false positives.
        beta: Weight applied to false negatives.
        gamma: Focal exponent applied to ``1 - TI``.
        smooth: Additive constant that keeps the ratio defined when all
            counts are zero.
    """

    def __init__(self, alpha=0.5, beta=0.5, gamma=2.0, smooth=1e-5):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return the scalar loss.

        Args:
            inputs: Predicted probabilities, any shape.
            targets: Binary / one-hot targets with the same shape as
                ``inputs``.
        """
        # reshape (not view): targets are typically a permuted one-hot tensor
        # and therefore non-contiguous, which view(-1) rejects. Accumulate in
        # float32 so the sums stay accurate under mixed precision.
        inputs = inputs.reshape(-1).float()
        targets = targets.reshape(-1).float()

        # True Positives, False Positives & False Negatives
        TP = (inputs * targets).sum()
        FP = ((1 - targets) * inputs).sum()
        FN = (targets * (1 - inputs)).sum()

        # Tversky index
        tversky = (TP + self.smooth) / (TP + self.alpha * FP + self.beta * FN + self.smooth)

        # Focal Tversky loss
        return (1 - tversky) ** self.gamma

# Build the network, its loss and the optimisation schedule
print("Initializing model...")
model = CNNViTHybrid(num_classes=NUM_CLASSES)
model = model.to(device)
criterion = FocalTverskyLoss(gamma=2.0, alpha=0.7, beta=0.3)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
# One cosine cycle spanning the whole training run
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# Metric objects, all sized for NUM_CLASSES and placed on the compute device
iou_metric, accuracy_metric, confusion_matrix = (
    metric_cls(num_classes=NUM_CLASSES).to(device)
    for metric_cls in (
        MulticlassJaccardIndex,
        MulticlassAccuracy,
        MulticlassConfusionMatrix,
    )
)

# Training and validation functions
def train_one_epoch(model, loader, optimizer, criterion, epoch):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; called as ``model(images)``.
        loader: Iterable yielding ``(images, masks)`` batches.
        optimizer: Optimizer stepped through the module-level ``scaler``.
        criterion: Loss called with softmax probabilities and one-hot targets.
        epoch: Zero-based epoch number, used only for the progress-bar label.

    Returns:
        Tuple ``(loss, iou, accuracy)``, each the mean of the per-batch values.
    """
    model.train()
    totals = {'loss': 0.0, 'iou': 0.0, 'acc': 0.0}

    progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]")

    for step, (images, masks) in enumerate(progress_bar, start=1):
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()

        # Forward pass under autocast
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(images)
            probabilities = F.softmax(outputs, dim=1)
            one_hot_targets = F.one_hot(masks, NUM_CLASSES).permute(0, 3, 1, 2).float()
            loss = criterion(probabilities, one_hot_targets)

        # Scaled backward pass and optimizer step
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Per-batch metrics on the hard predictions
        preds = outputs.argmax(dim=1)
        batch_iou = iou_metric(preds, masks)
        batch_acc = accuracy_metric(preds, masks)

        totals['loss'] += loss.item()
        totals['iou'] += batch_iou.item()
        totals['acc'] += batch_acc.item()

        # Running means so far
        progress_bar.set_postfix({name: total / step for name, total in totals.items()})

        # Occasional plain-text sample of the current batch metrics
        batch_idx = step - 1
        if batch_idx % 50 == 0:
            print(f"\nBatch {batch_idx}: Loss: {loss.item():.4f}, IoU: {batch_iou.item():.4f}, Acc: {batch_acc.item():.4f}")

    num_batches = len(loader)
    return (
        totals['loss'] / num_batches,
        totals['iou'] / num_batches,
        totals['acc'] / num_batches,
    )

def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable yielding ``(images, masks)`` batches.
        criterion: Loss called with softmax probabilities and one-hot targets.

    Returns:
        Tuple ``(loss, iou, accuracy)``, each the mean of the per-batch values.
    """
    model.eval()
    totals = {'val_loss': 0.0, 'val_iou': 0.0, 'val_acc': 0.0}

    progress_bar = tqdm(loader, desc="Validating")

    with torch.no_grad():
        for step, (images, masks) in enumerate(progress_bar, start=1):
            images, masks = images.to(device), masks.to(device)

            # Forward pass under autocast
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(images)
                probabilities = F.softmax(outputs, dim=1)
                one_hot_targets = F.one_hot(masks, NUM_CLASSES).permute(0, 3, 1, 2).float()
                loss = criterion(probabilities, one_hot_targets)

            # Per-batch metrics on the hard predictions
            preds = outputs.argmax(dim=1)
            batch_iou = iou_metric(preds, masks)
            batch_acc = accuracy_metric(preds, masks)

            totals['val_loss'] += loss.item()
            totals['val_iou'] += batch_iou.item()
            totals['val_acc'] += batch_acc.item()

            # Running means so far
            progress_bar.set_postfix({name: total / step for name, total in totals.items()})

    num_batches = len(loader)
    return (
        totals['val_loss'] / num_batches,
        totals['val_iou'] / num_batches,
        totals['val_acc'] / num_batches,
    )

# Function to calculate Dice coefficient
def dice_coefficient(pred, target, smooth=1e-5):
    """Return the smoothed Dice coefficient between two masks.

    ``dice = (2 * |pred * target| + smooth) / (|pred| + |target| + smooth)``
    with the sums taken over every element.

    Args:
        pred: Tensor of predicted memberships (binary or soft), any shape.
        target: Tensor of reference memberships, same shape as ``pred``.
        smooth: Additive constant; two empty masks therefore score 1.0.

    Returns:
        The Dice coefficient as a Python float.
    """
    # reshape (not view) so non-contiguous inputs, such as one channel sliced
    # out of an NCHW tensor, are accepted.
    pred = pred.reshape(-1).float()
    target = target.reshape(-1).float()

    intersection = (pred * target).sum()
    dice = (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

    return dice.item()

# Training loop
print("Starting training...")
# Per-epoch metric history, one list entry per completed epoch.
history = {
    'train_loss': [],
    'train_iou': [],
    'train_acc': [],
    'val_loss': [],
    'val_iou': [],
    'val_acc': [],
    'dice': []
}

# Start below any attainable IoU so the first epoch always writes a
# checkpoint. Starting at 0.0 with a strict '>' comparison would leave
# 'best_model.pth' missing if validation IoU never rose above zero.
best_val_iou = float('-inf')

for epoch in range(NUM_EPOCHS):
    # Train
    train_loss, train_iou, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, epoch)

    # Validate
    val_loss, val_iou, val_acc = validate(model, val_loader, criterion)

    # Calculate Dice on validation (separate pass over the validation set)
    model.eval()
    dice_scores = []
    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc="Calculating Dice"):
            images = images.to(device)
            masks = masks.to(device)

            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(images)

            preds = outputs.argmax(dim=1)

            # Convert to one-hot for dice calculation
            one_hot_preds = F.one_hot(preds, NUM_CLASSES).permute(0, 3, 1, 2).float()
            one_hot_masks = F.one_hot(masks, NUM_CLASSES).permute(0, 3, 1, 2).float()

            # Mean of the per-class Dice scores for this batch
            batch_dice = sum(
                dice_coefficient(one_hot_preds[:, cls], one_hot_masks[:, cls])
                for cls in range(NUM_CLASSES)
            ) / NUM_CLASSES

            dice_scores.append(batch_dice)

    # An empty validation loader yields no scores; report 0.0 instead of
    # raising ZeroDivisionError.
    val_dice = sum(dice_scores) / len(dice_scores) if dice_scores else 0.0

    # Update learning rate
    scheduler.step()

    # Save history
    history['train_loss'].append(train_loss)
    history['train_iou'].append(train_iou)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_iou'].append(val_iou)
    history['val_acc'].append(val_acc)
    history['dice'].append(val_dice)

    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}:")
    print(f"Train - Loss: {train_loss:.4f}, IoU: {train_iou:.4f}, Acc: {train_acc:.4f}")
    print(f"Val   - Loss: {val_loss:.4f}, IoU: {val_iou:.4f}, Acc: {val_acc:.4f}, Dice: {val_dice:.4f}")

    # Save best model (selected on validation IoU)
    if val_iou > best_val_iou:
        best_val_iou = val_iou
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"New best model saved with IoU: {best_val_iou:.4f}")

print("Training complete!")

# Plot metrics: a 2x2 grid with one panel per tracked quantity
plt.figure(figsize=(20, 15))

# (title, y-axis label, [(history key, legend label), ...]) in subplot order
_metric_panels = [
    ('Loss Evolution', 'Loss', [('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')]),
    ('IoU Evolution', 'IoU', [('train_iou', 'Train IoU'), ('val_iou', 'Val IoU')]),
    ('Accuracy Evolution', 'Accuracy', [('train_acc', 'Train Accuracy'), ('val_acc', 'Val Accuracy')]),
    ('Dice Coefficient Evolution', 'Dice', [('dice', 'Validation Dice')]),
]

for _slot, (_title, _ylabel, _curves) in enumerate(_metric_panels, start=1):
    plt.subplot(2, 2, _slot)
    for _history_key, _legend_label in _curves:
        plt.plot(history[_history_key], label=_legend_label)
    plt.title(_title)
    plt.xlabel('Epoch')
    plt.ylabel(_ylabel)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.savefig('training_metrics.png')
plt.show()

# Load best model for testing
print("Loading best model for testing...")
# map_location keeps the load working when the checkpoint was written on a
# different device from the one in use now.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

# Test the model and create confusion matrix
print("Testing model and creating confusion matrix...")
test_iou = 0.0
test_acc = 0.0
test_dice = 0.0

# Start from an empty confusion matrix so re-running this section does not
# add to counts left over from an earlier run.
confusion_matrix.reset()

with torch.no_grad():
    for images, masks in tqdm(test_loader, desc="Testing"):
        images = images.to(device)
        masks = masks.to(device)

        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(images)

        preds = outputs.argmax(dim=1)

        # Update metrics
        test_iou += iou_metric(preds, masks).item()
        test_acc += accuracy_metric(preds, masks).item()

        # Calculate dice (mean over classes for this batch)
        one_hot_preds = F.one_hot(preds, NUM_CLASSES).permute(0, 3, 1, 2).float()
        one_hot_masks = F.one_hot(masks, NUM_CLASSES).permute(0, 3, 1, 2).float()

        batch_dice = 0
        for cls in range(NUM_CLASSES):
            batch_dice += dice_coefficient(one_hot_preds[:, cls], one_hot_masks[:, cls])
        batch_dice /= NUM_CLASSES

        test_dice += batch_dice

        # The confusion matrix is accumulated batch by batch. Per-pixel
        # predictions are deliberately not stored: keeping every flattened
        # prediction and mask of the whole test set on the device uses a
        # large amount of memory and nothing reads them afterwards.
        confusion_matrix.update(preds, masks)

# Calculate final test metrics (guard against an empty loader)
num_test_batches = max(len(test_loader), 1)
test_iou /= num_test_batches
test_acc /= num_test_batches
test_dice /= num_test_batches

print(f"\nTest Results - IoU: {test_iou:.4f}, Acc: {test_acc:.4f}, Dice: {test_dice:.4f}")

# Pull the accumulated confusion matrix back to host memory
conf_matrix = confusion_matrix.compute().cpu().numpy()

# Render it as a heat map
plt.figure(figsize=(12, 10))
plt.imshow(conf_matrix, cmap='Blues')
plt.colorbar()
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')

# Annotate every cell with its count; cells darker than the midpoint get
# white text so the number stays readable.
thresh = conf_matrix.max() / 2
for row, col in np.ndindex(conf_matrix.shape[0], conf_matrix.shape[1]):
    cell_value = conf_matrix[row, col]
    text_color = "white" if cell_value > thresh else "black"
    plt.text(col, row, f'{int(cell_value)}', ha="center", va="center", color=text_color)

plt.tight_layout()
plt.savefig('confusion_matrix.png')
plt.show()

# Function to visualize predictions
def visualize_prediction(model, dataset, idx):
    """Plot input image, ground truth and prediction side by side.

    The figure is saved as ``prediction_example_<idx>.png`` and then shown.

    Args:
        model: Segmentation network; called on a single-image batch.
        dataset: Indexable dataset returning ``(image, mask)`` where
            ``image`` is a normalised ``[3, H, W]`` tensor and ``mask`` is an
            ``[H, W]`` class-index map (NumPy array or tensor).
        idx: Index of the sample to visualise.
    """
    image, mask = dataset[idx]
    image_tensor = image.unsqueeze(0).to(device)
    # The dataset returns the mask as a NumPy array, which has no .numpy();
    # accept either an array or a tensor.
    mask = mask.cpu().numpy() if torch.is_tensor(mask) else np.asarray(mask)

    # Get prediction
    model.eval()
    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(image_tensor)
        pred = output.argmax(dim=1).squeeze(0).cpu().numpy()

    # Undo the mean/std normalisation so the image displays with natural colours
    image = image.permute(1, 2, 0).cpu().numpy()
    image = image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    image = np.clip(image, 0, 1)

    # Create color masks
    mask_colored = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    pred_colored = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8)

    for class_idx, color in enumerate(cityscapes_classes):
        mask_colored[mask == class_idx] = color
        pred_colored[pred == class_idx] = color

    # Plot
    plt.figure(figsize=(15, 5))

    panels = (
        ('Input Image', image),
        ('Ground Truth', mask_colored),
        ('Prediction', pred_colored),
    )
    for position, (title, picture) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(picture)
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    plt.savefig(f'prediction_example_{idx}.png')
    plt.show()

# Render three fixed test samples as qualitative examples
for sample_idx in (0, 10, 20):
    visualize_prediction(model, test_dataset, sample_idx)

print("All done! Model has been trained, evaluated, and visualizations have been created.")