In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models
from torchmetrics.classification import Accuracy, Dice
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
import numpy as np

if torch.cuda.is_available():
    device = torch.device("cuda")  # Use the GPU
    print("GPU is available")
else:
    device = torch.device("cpu")   # Use the CPU
    print("GPU is not available")


  from .autonotebook import tqdm as notebook_tqdm


GPU is available


In [None]:
from torchvision.datasets import VOCSegmentation

# Define the transformations for the dataset
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the Pascal VOC dataset
dataset = VOCSegmentation(root='data', year='2012', image_set='train', download=True, transform=transform)

# Split the dataset into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create data loaders for training and validation sets
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

In [7]:
num_classes = 21

def choose_model(model_name):
    if model_name == 'deeplab':
        # Load the pre-trained DeepLabv3 model
        model = models.segmentation.deeplabv3_resnet101(weights_backbone="ResNet101_Weights.DEFAULT", num_classes=num_classes)
        model.to(device)
        return model
    elif model_name == 'segformer':
        # Load the pre-trained SegFormer feature extractor
        # feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

        # Load the pre-trained SegFormer model with the specified number of classes
        model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512", num_labels=num_classes, ignore_mismatched_sizes=True)
        model.to(device)
        return model
    elif model_name == 'lraspp':
        # Load the pre-trained LR-ASPP model
        model = models.segmentation.lraspp_mobilenet_v3_large(weights_backbone="MobileNet_V3_Large_Weights.DEFAULT", num_classes=num_classes)
        model.to(device)
        return model
    else:
        raise ValueError(f'Unknown model name: {model_name}')

In [8]:
model_options = ["deeplab", "segformer", "lraspp"]
model = choose_model(model_options[0])
print(model)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
accuracy = Accuracy(task="multiclass", num_classes=num_classes)
mDice = Dice(average='macro', num_classes=num_classes)

DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Se

In [9]:
# 3. Training and Evaluation

# Train and test functions
def train(model, loader):
    model.train()
    correct, total, running_loss = 0, 0, 0.0
    all_preds, all_labels = [], []

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)['out']
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

        # Store predictions and labels as tensors
        all_preds.append(predicted.cpu())
        all_labels.append(labels.cpu())
    
    # Concatenate all batches into single tensors
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    
    # Calculate metrics
    train_accuracy = accuracy(all_preds, all_labels)
    train_mDice = mDice(all_preds, all_labels)
    
    return running_loss / len(loader), train_accuracy, train_mDice

def test(model, loader):
    model.eval()
    correct, total, running_loss = 0, 0, 0.0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            # Store predictions and labels as tensors
            all_preds.append(predicted.cpu())
            all_labels.append(labels.cpu())
    
    # Concatenate all batches into single tensors
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    
    # Calculate metrics
    test_accuracy = accuracy(all_preds, all_labels)
    test_mDice = mDice(all_preds, all_labels)
    
    return running_loss / len(loader), test_accuracy, test_mDice

In [None]:
# Tracking metrics
train_losses, test_losses = [], []
train_accuracies, test_accuracies = [], []
train_mdice, test_mdice = [], []

# Main training Loop
for epoch in range(2):
    train_loss, train_accuracy, train_mDice = train(model, train_loader)
    test_loss, test_accuracy, test_mDice = test(model, val_loader)

    # Store metrics
    train_losses.append(train_loss)
    test_losses.append(test_loss)

    train_accuracies.append(train_accuracy)
    test_accuracies.append(test_accuracy)

    train_mdice.append(train_mDice)
    test_mdice.append(test_mDice)

    print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, Train mDice: {train_mDice:.4f} | Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.4f}, Test mDice: {test_mDice:.4f}")

In [None]:
# 4. Plot Results
plt.figure(figsize=(12, 4))

# Loss Plot
plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.title('Loss')
plt.legend()

# Accuracy Plot
plt.subplot(1, 3, 2)
plt.plot(train_accuracies, label='Train Accuracy')
plt.plot(test_accuracies, label='Test Accuracy')
plt.title('Accuracy')
plt.legend()

# mDice Plot
plt.subplot(1, 3, 3)
plt.plot(train_mdice, label='Train mDice')
plt.plot(test_mdice, label='Test mDice')
plt.title('mDice Score')
plt.legend()

plt.show()