In [2]:
# Semantic Segmentation with ResNet on Pascal VOC Dataset
# Simple notebook for training a ResNet-based segmentation model

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os

# Select the compute device: the GPU when CUDA is usable, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# =============================================================================
# 1. Dataset Loading and Visualization
# =============================================================================

# Preprocessing pipelines.
# Images: resize to 224x224, convert to a float tensor, then normalize each
# channel with the mean/std conventionally used for ImageNet-pretrained
# torchvision models.
train_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Masks: resize with nearest-neighbour sampling so label values are never
# blended. ToTensor is applied as well, so the code that consumes these
# targets multiplies by 255 to get back integer class indices.
target_transform = transforms.Compose([
    transforms.Resize(size=(224, 224), interpolation=Image.NEAREST),
    transforms.ToTensor(),
])

# Load Pascal VOC dataset
print("Loading Pascal VOC dataset...")

data_root = './data'
# torchvision's VOCSegmentation extracts the 2012 archive under
# <root>/VOCdevkit/VOC2012. Request a download only when that directory is
# missing, so reruns with the data already on disk neither touch the network
# nor unpack the archive again.
voc_dir = os.path.join(data_root, 'VOCdevkit', 'VOC2012')
need_download = not os.path.isdir(voc_dir)

train_dataset = torchvision.datasets.VOCSegmentation(
    root=data_root,
    year='2012',
    image_set='train',
    download=need_download,
    transform=train_transform,
    target_transform=target_transform
)

# Both splits live in the same archive, which is in place once the training
# split has been constructed, so the validation split never downloads.
# train_transform is reused here: it contains no random augmentation.
val_dataset = torchvision.datasets.VOCSegmentation(
    root=data_root,
    year='2012',
    image_set='val',
    download=False,
    transform=train_transform,
    target_transform=target_transform
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Create data loaders (only the training loader shuffles)
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Names of the 21 Pascal VOC segmentation classes (background first)
class_names = (
    'background aeroplane bicycle bird boat bottle bus '
    'car cat chair cow diningtable dog horse motorbike '
    'person pottedplant sheep sofa train tvmonitor'
).split()

# Color map for visualization
def create_colormap():
    """Build the RGB lookup table used to colour segmentation masks.

    Returns:
        A ``(21, 3)`` ``uint8`` NumPy array; row ``i`` holds the RGB colour
        used for class index ``i``.
    """
    palette = (
        (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128),
        (128, 0, 128), (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0),
        (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128),
        (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0),
        (0, 64, 128),
    )
    return np.array(palette, dtype=np.uint8)


colormap = create_colormap()

def visualize_sample(dataset, idx=0):
    """Show one dataset item: the image next to its colour-coded mask.

    Args:
        dataset: Indexable object whose items unpack as ``(image, target)``.
            The image is de-normalized with the mean/std below, and the
            target is multiplied by 255 to obtain class indices.
        idx: Position of the item to display.
    """
    image, target = dataset[idx]

    # Undo the per-channel normalization and clip into the displayable range
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    image_denorm = torch.clamp(image * channel_std + channel_mean, 0, 1)

    # Recover integer labels; the ignore value (255) is drawn as background (0)
    label_map = (target.squeeze(0) * 255).numpy().astype(np.uint8)
    label_map = np.where(label_map == 255, 0, label_map)
    target_colored = colormap[label_map]

    panels = (
        ('Original Image', image_denorm.permute(1, 2, 0)),
        ('Segmentation Mask', target_colored),
    )

    plt.figure(figsize=(12, 5))
    for position, (title, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(picture)
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Show the first three training samples
print("\nVisualizing sample images and their segmentation masks:")
for sample_idx in range(3):
    visualize_sample(train_dataset, sample_idx)

# =============================================================================
# 2. Model Definition
# =============================================================================

class ResNetSegmentation(nn.Module):
    """ResNet18 encoder followed by a transposed-convolution decoder.

    The encoder is torchvision's pretrained ResNet18 with its last two child
    modules (average pooling and the fully connected layer) removed. The
    decoder applies five stride-2 transposed convolutions, each followed by
    batch norm and ReLU, and ends with a 1x1 convolution that emits one logit
    map per class.

    Args:
        num_classes: Number of output channels of the logit map.
    """

    # (in_channels, out_channels) of each stride-2 decoder stage, in order.
    _DECODER_CHANNELS = ((512, 256), (256, 128), (128, 64), (64, 32), (32, 32))

    def __init__(self, num_classes=21):
        super().__init__()

        # Load pretrained ResNet18 (small ResNet)
        resnet = torchvision.models.resnet18(pretrained=True)

        # Remove the last fully connected layer and average pooling
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # The backbone reduces spatial size by 32x (224 -> 7); each decoder
        # stage doubles it again (7 -> 14 -> 28 -> 56 -> 112 -> 224).
        # The layers are kept in one flat nn.Sequential, in the same order as
        # before, so parameter names (upsample.0, upsample.1, ...) are unchanged.
        layers = []
        for in_channels, out_channels in self._DECODER_CHANNELS:
            layers += [
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3,
                                   stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Final classification layer
        layers.append(nn.Conv2d(32, num_classes, kernel_size=1))
        self.upsample = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-pixel class logits with the same height/width as ``x``.

        Args:
            x: Image batch fed to the backbone.

        Returns:
            Logit tensor with ``num_classes`` channels whose last two
            dimensions equal those of ``x``.
        """
        features = self.backbone(x)
        output = self.upsample(features)
        # When the input height/width is not a multiple of 32 the decoder
        # output does not line up with the input; resize the logits so they
        # can always be compared pixel-by-pixel with a same-sized target.
        if output.shape[-2:] != x.shape[-2:]:
            output = nn.functional.interpolate(
                output, size=x.shape[-2:], mode='bilinear', align_corners=False
            )
        return output

# Instantiate the segmentation network, then place it on the selected device
model = ResNetSegmentation(num_classes=21)
model = model.to(device)
print(f"\nModel created and moved to {device}")

# Print model summary
def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

trainable_params = count_parameters(model)
print(f"Total trainable parameters: {trainable_params:,}")

# =============================================================================
# 3. Training Setup
# =============================================================================

# Optimization setup:
#   - cross-entropy loss that skips target pixels equal to 255 (void/ignore)
#   - Adam with weight decay
#   - step schedule that multiplies the learning rate by 0.1 every 10 scheduler steps
criterion = nn.CrossEntropyLoss(ignore_index=255)
optimizer = optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=1e-4,
)
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=10,
    gamma=0.1,
)

# Training function
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Network to train; switched to training mode by this call.
        dataloader: Sized iterable of ``(images, targets)`` batches. Targets
            are squeezed on dim 1 and multiplied by 255 to obtain integer
            class indices.
        criterion: Loss, called as ``criterion(outputs, targets)``.
        optimizer: Optimizer that updates the model's parameters.
        device: Device each batch is moved to.

    Returns:
        The sum of the per-batch losses divided by the number of batches, or
        ``0.0`` when the dataloader contains no batches.
    """
    model.train()

    num_batches = len(dataloader)
    if num_batches == 0:
        # Nothing to train on; avoid dividing by zero in the mean below.
        return 0.0

    running_loss = 0.0

    for i, (images, targets) in enumerate(dataloader):
        images = images.to(device)
        # Convert to class indices. Rounding before the integer cast is a
        # safeguard: a value that lands just below an integer after the
        # rescale would otherwise be truncated to the class below.
        targets = (targets.squeeze(1) * 255).round().long().to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for both the sum and the log line
        batch_loss = loss.item()
        running_loss += batch_loss

        if i % 50 == 0:
            print(f'Batch {i}/{num_batches}, Loss: {batch_loss:.4f}')

    return running_loss / num_batches

# Validation function
def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    Args:
        model: Network to evaluate; switched to eval mode by this call.
        dataloader: Sized iterable of ``(images, targets)`` batches. Targets
            are squeezed on dim 1 and multiplied by 255 to obtain integer
            class indices.
        criterion: Loss, called as ``criterion(outputs, targets)``.
        device: Device each batch is moved to.

    Returns:
        A ``(mean_loss, accuracy)`` tuple. ``mean_loss`` is the sum of the
        per-batch losses divided by the number of batches (``0.0`` when there
        are no batches). ``accuracy`` is the fraction of pixels whose target
        is not 255 that were predicted correctly (``0`` when there are no
        such pixels).
    """
    model.eval()
    num_batches = len(dataloader)
    running_loss = 0.0
    correct_pixels = 0
    total_pixels = 0

    with torch.no_grad():
        for images, targets in dataloader:
            images = images.to(device)
            # Convert to class indices. Rounding before the integer cast is a
            # safeguard: a value that lands just below an integer after the
            # rescale would otherwise be truncated to the class below.
            targets = (targets.squeeze(1) * 255).round().long().to(device)

            outputs = model(images)
            loss = criterion(outputs, targets)
            running_loss += loss.item()

            # Calculate pixel accuracy
            predictions = torch.argmax(outputs, dim=1)
            mask = targets != 255  # Ignore void pixels
            correct_pixels += (predictions[mask] == targets[mask]).sum().item()
            total_pixels += mask.sum().item()

    # Both averages are guarded so an empty loader (or one containing only
    # void pixels) yields zeros instead of a ZeroDivisionError.
    accuracy = correct_pixels / total_pixels if total_pixels > 0 else 0
    mean_loss = running_loss / num_batches if num_batches > 0 else 0.0
    return mean_loss, accuracy

# =============================================================================
# 4. Training Loop
# =============================================================================

num_epochs = 100  # Number of full passes over the training set
train_losses = []
val_losses = []
val_accuracies = []

print(f"\nStarting training for {num_epochs} epochs...")

separator = '-' * 50
for epoch in range(num_epochs):
    print(f'\nEpoch {epoch+1}/{num_epochs}')
    print(separator)

    # One optimization pass, then an evaluation pass on the validation split
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)

    # Advance the learning-rate schedule once per epoch
    scheduler.step()

    # Record this epoch's metrics in the matching history list
    for history, value in (
        (train_losses, train_loss),
        (val_losses, val_loss),
        (val_accuracies, val_acc),
    ):
        history.append(value)

    print(
        f'Train Loss: {train_loss:.4f}\n'
        f'Val Loss: {val_loss:.4f}\n'
        f'Val Accuracy: {val_acc:.4f}'
    )

print("\nTraining completed!")

# =============================================================================
# 5. Results Visualization
# =============================================================================

# Three panels: loss curves, validation accuracy, and the last epoch's metrics
epoch_axis = range(1, num_epochs + 1)
plt.figure(figsize=(15, 5))

# Panel 1: training vs. validation loss per epoch
plt.subplot(1, 3, 1)
plt.plot(epoch_axis, train_losses, 'b-', label='Training Loss')
plt.plot(epoch_axis, val_losses, 'r-', label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

# Panel 2: validation pixel accuracy per epoch
plt.subplot(1, 3, 2)
plt.plot(epoch_axis, val_accuracies, 'g-', label='Validation Accuracy')
plt.title('Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

# Panel 3: bar chart of the metrics recorded for the final epoch
final_metrics = {
    'Train Loss': train_losses[-1],
    'Val Loss': val_losses[-1],
    'Val Accuracy': val_accuracies[-1],
}
plt.subplot(1, 3, 3)
plt.bar(list(final_metrics.keys()),
        list(final_metrics.values()),
        color=['blue', 'red', 'green'])
plt.title('Final Metrics')
plt.ylabel('Value')

plt.tight_layout()
plt.show()

# =============================================================================
# 6. Prediction Visualization
# =============================================================================

def visualize_predictions(model, dataset, device, num_samples=3):
    """Plot image, ground-truth mask and predicted mask for the first samples.

    Args:
        model: Network producing per-class logits; switched to eval mode here.
        dataset: Indexable object whose items unpack as ``(image, target)``.
        device: Device the image is moved to before the forward pass.
        num_samples: How many items, starting at index 0, to display.
    """
    model.eval()

    # Constants used to undo the per-channel normalization before display
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    with torch.no_grad():
        for sample_idx in range(num_samples):
            image, target = dataset[sample_idx]

            # Forward a single-image batch and take the arg-max class per pixel
            logits = model(image.unsqueeze(0).to(device))
            prediction = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

            image_denorm = torch.clamp(image * channel_std + channel_mean, 0, 1)

            # Recover integer labels; the ignore value (255) is drawn as background (0)
            label_map = (target.squeeze(0) * 255).numpy().astype(np.uint8)
            label_map = np.where(label_map == 255, 0, label_map)

            panels = (
                ('Original Image', image_denorm.permute(1, 2, 0)),
                ('Ground Truth', colormap[label_map]),
                ('Prediction', colormap[prediction]),
            )

            plt.figure(figsize=(15, 5))
            for position, (title, picture) in enumerate(panels, start=1):
                plt.subplot(1, 3, position)
                plt.imshow(picture)
                plt.title(title)
                plt.axis('off')

            plt.tight_layout()
            plt.show()

print("\nVisualizing predictions on validation samples:")
visualize_predictions(model, val_dataset, device, num_samples=3)

# =============================================================================
# 7. Model Saving
# =============================================================================

# Persist the trained weights together with the optimizer state and metric history
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_losses': train_losses,
    'val_losses': val_losses,
    'val_accuracies': val_accuracies,
}
torch.save(checkpoint, 'resnet_segmentation_model.pth')

print("\nModel saved as 'resnet_segmentation_model.pth'")

# Final report: last-epoch metrics plus the best validation accuracy seen
banner = '=' * 50
best_val_accuracy = max(val_accuracies)
print(f"\n{banner}")
print("TRAINING SUMMARY")
print(banner)
print(f"Final Training Loss: {train_losses[-1]:.4f}")
print(f"Final Validation Loss: {val_losses[-1]:.4f}")
print(f"Final Validation Accuracy: {val_accuracies[-1]:.4f}")
print(f"Best Validation Accuracy: {best_val_accuracy:.4f}")
print(banner)

Using device: cuda
Loading Pascal VOC dataset...


URLError: <urlopen error [Errno 110] Connection timed out>