In [13]:
import pickle
import os
import numpy as np
import torch
import torch.optim as optim
import datasets
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
from datasets import Forest, ToTensor
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from torchvision.models import ResNet50_Weights
from torch.autograd import Variable
from tqdm import tqdm

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [14]:
# Root directory of the cleaned dataset; it must contain `train/` and `eval/`.
# Set the FOREST_DATA_DIR environment variable to run on another machine; the
# default is the location this notebook was originally run from.
DATA_DIR = os.environ.get(
    "FOREST_DATA_DIR",
    "/home/k45848/multispectral-imagery-segmentation/data/clean_data",
)

# Mini-batch size shared by the training and evaluation loaders.
BATCH_SIZE = 16

# `ToTensor` is the project's own transform imported from `datasets`.
transform = transforms.Compose([
    ToTensor(),
])

train_dataset = Forest(root_dir=os.path.join(DATA_DIR, "train"), label=True, transform=transform)
eval_dataset = Forest(root_dir=os.path.join(DATA_DIR, "eval"), label=True, transform=transform)

# Only the training loader shuffles; evaluation keeps the dataset order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE)

# Sanity check: print the shapes of the first training batch only.
for X, Y in train_loader:
    print(X.shape)
    print(Y.shape)
    break

torch.Size([16, 11, 64, 64])
torch.Size([16, 64, 64])


In [15]:
# Number of classes the segmentation head predicts per pixel.
num_classes = 6

# torchvision's default pretrained weights for the ResNet-50 backbone.
weights_backbone = ResNet50_Weights.DEFAULT

# NOTE(review): `weights` holds the weights enum class itself and is not passed
# to the builder below, which receives `weights=None`.
weights = DeepLabV3_ResNet50_Weights

# DeepLabV3 with no pretrained segmentation weights; only the backbone
# receives pretrained weights.
model = deeplabv3_resnet50(
    num_classes=num_classes,
    weights=None,
    weights_backbone=weights_backbone,
)


In [16]:
# Short alias for the model's backbone module; no other visible cell reads `b`.
b = model.backbone

In [17]:
# Adapt the backbone's first convolution to the 11-channel input batches.
# The number of output classes is set when the model is built, not here.
# The replacement is a newly constructed layer, so it starts from fresh
# initialisation rather than from the weights of the layer it replaces.
input_channels = 11
model.backbone.conv1 = nn.Conv2d(
    in_channels=input_channels,
    out_channels=64,
    kernel_size=(7, 7),
    stride=(2, 2),
    padding=(3, 3),
    bias=False,
)

# Move the model to the selected device; as the last expression of the cell
# this also displays the module structure.
model.to(device)

DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (conv1): Conv2d(11, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): S

In [18]:
# Number of full passes over the training set.
num_epochs = 40

# Cross-entropy loss between the model's output logits and the label maps.
criterion = nn.CrossEntropyLoss()

# Plain SGD over every model parameter with a fixed learning rate.
optimiser = optim.SGD(params=model.parameters(), lr=1e-3)


In [19]:
def compute_iou(preds, labels, num_classes):
    """Mean intersection-over-union across the classes present in the inputs.

    For each class id in ``range(num_classes)`` the IoU is the number of
    positions where both ``preds`` and ``labels`` equal that id, divided by
    the number of positions where either does. A class that appears in
    neither input has no defined IoU and is left out of the average, so it
    does not pull the mean towards zero.

    Args:
        preds: Array-like of predicted class ids (e.g. a tensor) supporting
            element-wise ``==``, ``&``, ``|`` and ``.sum()``.
        labels: Ground-truth class ids, same shape as ``preds``.
        num_classes: Number of class ids to evaluate (``0 .. num_classes-1``).

    Returns:
        float: Mean IoU over the classes that occur in ``preds`` or
        ``labels``; ``0.0`` when none of the classes occur at all.
    """
    per_class_iou = []
    for cls in range(num_classes):
        pred_mask = preds == cls
        label_mask = labels == cls
        union = float((pred_mask | label_mask).sum())
        if union == 0:
            # Class absent from both prediction and target: IoU is undefined,
            # so it must not be counted in the denominator of the mean.
            continue
        intersection = float((pred_mask & label_mask).sum())
        per_class_iou.append(intersection / union)

    if not per_class_iou:
        return 0.0
    return sum(per_class_iou) / len(per_class_iou)

Training

In [20]:
# Per-epoch metric history, one entry appended per epoch.
train_losses = []
train_ious = []
eval_losses = []
eval_ious = []

# Each epoch runs one optimisation pass over the training loader followed by
# one gradient-free pass over the evaluation loader.
for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()  # Set the model to training mode
    running_loss = 0   # sum of (batch mean loss * batch size) over the epoch
    running_iou = 0    # sum of per-batch mean IoU over the epoch
    samples_seen = 0   # samples processed so far in this epoch

    # Iterating the tqdm wrapper advances the progress bar by itself, so the
    # bar must not be updated manually inside the loop.
    with tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch') as t:
        for images, labels in t:
            # Move data to device (e.g., GPU if available)
            images = images.to(device)
            labels = labels.to(device)

            # Zero the parameter gradients
            optimiser.zero_grad()

            # Forward pass; the model returns a mapping and the segmentation
            # logits are stored under the 'out' key.
            outputs = model(images)['out']

            # Calculate loss
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimiser.step()

            # Predicted class per pixel = argmax over the class dimension
            preds = torch.argmax(outputs, dim=1)
            batch_iou = compute_iou(preds, labels, num_classes)
            running_iou += batch_iou

            # Weight the batch loss by its real size so that a smaller final
            # batch is accounted for correctly.
            n_batch = images.size(0)
            running_loss += loss.item() * n_batch
            samples_seen += n_batch

            # Show the running mean loss per sample for this epoch.
            t.set_postfix(loss=running_loss / samples_seen)

    epoch_loss = running_loss / len(train_dataset)
    epoch_iou = running_iou / len(train_loader)

    # Append training metrics to lists
    train_losses.append(epoch_loss)
    train_ious.append(epoch_iou)

    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.4f}, IoU: {epoch_iou:.4f}")

    # ---- Evaluation pass ----
    model.eval()
    eval_loss = 0
    eval_iou = 0

    # No gradients are needed while evaluating.
    with torch.no_grad():
        for images, labels in eval_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)['out']

            # Calculate loss
            loss = criterion(outputs, labels)

            # Calculate IoU
            preds = torch.argmax(outputs, dim=1)
            batch_iou = compute_iou(preds, labels, num_classes)
            eval_iou += batch_iou

            # Accumulate loss, weighted by the batch size
            eval_loss += loss.item() * images.size(0)

    # Average loss per sample and IoU per batch over the evaluation set
    eval_loss /= len(eval_dataset)
    eval_iou /= len(eval_loader)

    # Append evaluation metrics to lists
    eval_losses.append(eval_loss)
    eval_ious.append(eval_iou)

    print(f"Evaluation Loss: {eval_loss:.4f}, Evaluation IoU: {eval_iou:.4f}")



print('Training finished.')

# Learning curves: loss on the left panel, IoU on the right, one point per epoch.
fig, (loss_ax, iou_ax) = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: training vs. validation loss
loss_ax.plot(train_losses, label='Training Loss')
loss_ax.plot(eval_losses, label='Validation Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()

# Right panel: training vs. validation IoU
iou_ax.plot(train_ious, label='Training IoU')
iou_ax.plot(eval_ious, label='Validation IoU')
iou_ax.set_xlabel('Epoch')
iou_ax.set_ylabel('IoU')
iou_ax.set_title('Training and Validation IoU')
iou_ax.legend()

fig.tight_layout()
plt.show()

Epoch 1/40:   0%|          | 1/738 [00:00<02:38,  4.66batch/s, loss=1.71]

In [None]:
# Scratch cell (never executed): recomputes the mean per-batch training IoU
# from `running_iou`, using the same expression as the training loop.
epoch_iou = running_iou / len(train_loader)