In [1]:
import os
# import argparse
import numpy as np
import PIL.Image as Image
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms
from torchsummary import summary

from dataset.uav_segmentation import UAVSegmentation
from model.vanilla_unet import VanillaUNetDoubleConv as VanillaUNet
from trainer import train_one_epoch
from eval import evaluate
from utils import EarlyStopper
from utils import saveModel

In [2]:
# --- Data locations -------------------------------------------------------
# Root folder holding the train/val/test splits. The default is the original
# machine-specific location; set the UAV_DATASET_ROOT environment variable to
# run the notebook on another machine without editing this cell.
DATASET_ROOT = os.environ.get('UAV_DATASET_ROOT', '/mnt/hdd/dataset/uav_dataset')
DATASET_PATH_TRAIN = f"{DATASET_ROOT.rstrip('/')}/train/"
DATASET_PATH_VAL = f"{DATASET_ROOT.rstrip('/')}/val/"
DATASET_PATH_TEST = f"{DATASET_ROOT.rstrip('/')}/test/"

# --- Training hyper-parameters --------------------------------------------
BATCH_SIZE = 16         # samples per DataLoader batch
NUM_EPOCHS = 3          # upper bound on training epochs (early stopping may end sooner)
LEARNING_RATE = 0.0001  # initial learning rate handed to the optimizer
GPU_ID = 0              # CUDA device index used when a GPU is available

NUM_CLASSES = 2         # passed to the dataset and used as the model's out_channels
NUM_WORKER = 30         # DataLoader worker processes

# --- ReduceLROnPlateau scheduler settings ---------------------------------
SCH_FACTOR = 0.15
SCH_PATIENCE = 15
SCH_COOLDOWN = 5

# --- Early stopping settings ----------------------------------------------
ES_PATIENCE = 30
ES_MIN_DELTA = 0.001
# NOTE(review): ES_MODE is not referenced by any cell visible in this notebook;
# confirm whether EarlyStopper should receive it.
ES_MODE = "min"

# Best losses seen so far; compared against each epoch's loss when deciding
# whether to write a checkpoint. Start at +inf so the first epoch qualifies.
BEST_TRAIN_LOSS = float('inf')
BEST_VAL_LOSS = float('inf')

# --- Component selection (resolved in the model-setup cell) ---------------
SEL_CRITERION = "CrossEntropyLoss"
SEL_OPTIMIZER = "AdamW"
SEL_SCHEDULER = "ReduceLROnPlateau"

In [None]:
# Preprocessing pipeline handed to every dataset split: resize to 256x256,
# then convert to a tensor.
# NOTE(review): whether UAVSegmentation applies this to the mask as well as
# the image is decided inside the dataset class - confirm there.
preprocessing_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# Build the three dataset splits with identical settings; only the folder differs.
train_dataset, val_dataset, test_dataset = (
    UAVSegmentation(split_path, NUM_CLASSES, transforms=transform)
    for split_path in (DATASET_PATH_TRAIN, DATASET_PATH_VAL, DATASET_PATH_TEST)
)

# Report how many samples each split contains.
for size_label, split_dataset in (
    ('Train dataset size:', train_dataset),
    ('Val dataset size:', val_dataset),
    ('Test dataset size:', test_dataset),
):
    print(size_label, len(split_dataset))



In [None]:
def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the notebook-wide batch size and worker count."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=NUM_WORKER)


# Only the training split is shuffled; val/test keep a fixed order.
train_dataloader = _make_loader(train_dataset, True)
val_dataloader = _make_loader(val_dataset, False)
test_dataloader = _make_loader(test_dataset, False)

# Sanity check: pull the first training batch (if any) and show its image tensor size.
first_train_batch = next(iter(train_dataloader), None)
if first_train_batch is not None:
    images, masks = first_train_batch
    print('Train batch size:', images.size())

# Same check for the validation loader.
first_val_batch = next(iter(val_dataloader), None)
if first_val_batch is not None:
    images, masks = first_val_batch
    print('Val batch size:', images.size())

# Finally inspect a single un-batched sample straight from the dataset.
first_sample = next(iter(train_dataset), None)
if first_sample is not None:
    image, mask = first_sample
    print('Image shape:', image.shape)
    print('Mask shape:', mask.shape)

In [3]:
# Define the model: single-channel input, one output channel per class.
model = VanillaUNet(in_channels=1, out_channels=NUM_CLASSES)

# Resolve the loss function. Unknown names fail here, with a clear message,
# instead of surfacing later as a NameError on `criterion`.
if SEL_CRITERION == 'CrossEntropyLoss':
    criterion = torch.nn.CrossEntropyLoss()
elif SEL_CRITERION == 'BCEWithLogitsLoss':
    criterion = torch.nn.BCEWithLogitsLoss()
else:
    raise ValueError(
        f"Unsupported SEL_CRITERION {SEL_CRITERION!r}; "
        "expected 'CrossEntropyLoss' or 'BCEWithLogitsLoss'"
    )

# Resolve the optimizer; all options share the same learning rate.
if SEL_OPTIMIZER == 'Adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
elif SEL_OPTIMIZER == 'AdamW':
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
elif SEL_OPTIMIZER == 'SGD':
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
else:
    raise ValueError(
        f"Unsupported SEL_OPTIMIZER {SEL_OPTIMIZER!r}; expected 'Adam', 'AdamW' or 'SGD'"
    )

# Resolve the LR scheduler. The training loop calls lr_scheduler.step(val_loss),
# so a scheduler must exist; reject anything this cell cannot build.
if SEL_SCHEDULER == 'ReduceLROnPlateau':
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=SCH_FACTOR,
        patience=SCH_PATIENCE,
        cooldown=SCH_COOLDOWN,
    )
else:
    raise ValueError(
        f"Unsupported SEL_SCHEDULER {SEL_SCHEDULER!r}; expected 'ReduceLROnPlateau'"
    )

# Use the configured GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device(f'cuda:{GPU_ID}' if torch.cuda.is_available() else 'cpu')

# NOTE(review): ES_MODE from the config cell is not passed here - confirm that
# EarlyStopper's default comparison direction is the intended one.
early_stopper = EarlyStopper(patience=int(ES_PATIENCE),
                             min_delta=float(ES_MIN_DELTA))

# Last expression: moves the model to the device and displays its structure.
model.to(device)

VanillaUNetDoubleConv(
  (stage1): Sequential(
    (0): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (1): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=T

In [4]:
# Print a per-layer table of output shapes and parameter counts for a
# single-channel 224x224 input.
# NOTE(review): the preprocessing cell resizes to 256x256, so the shapes listed
# here are for a smaller input than the one used in training - confirm this is intended.
summary_input_shape = (1, 224, 224)  # (channels, height, width)
summary(model, summary_input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]             576
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
            Conv2d-4         [-1, 64, 224, 224]          36,864
       BatchNorm2d-5         [-1, 64, 224, 224]             128
              ReLU-6         [-1, 64, 224, 224]               0
        DoubleConv-7         [-1, 64, 224, 224]               0
            Conv2d-8        [-1, 128, 224, 224]          73,728
       BatchNorm2d-9        [-1, 128, 224, 224]             256
             ReLU-10        [-1, 128, 224, 224]               0
           Conv2d-11        [-1, 128, 224, 224]         147,456
      BatchNorm2d-12        [-1, 128, 224, 224]             256
             ReLU-13        [-1, 128, 224, 224]               0
       DoubleConv-14        [-1, 128, 2

In [None]:
# Make sure the checkpoint folder exists before the first save attempt.
os.makedirs('checkpoints', exist_ok=True)

for epoch in range(NUM_EPOCHS):
    print(f'Epoch {epoch + 1}/{NUM_EPOCHS}')

    # One optimisation pass over the training split.
    train_loss, train_dice_loss, train_dice_metrics = train_one_epoch(model, train_dataloader, criterion, optimizer, device)
    print(f'Train Loss: {train_loss} | Dice loss: {train_dice_loss} | Dice metrics: {train_dice_metrics}')

    # Evaluate on the validation split.
    val_loss, val_dice_loss, val_dice_metrics = evaluate(model, val_dataloader, criterion, device)
    print(f'Val Loss: {val_loss} | Dice loss: {val_dice_loss} | Dice metrics: {val_dice_metrics}')

    # The scheduler is driven by the validation loss.
    lr_scheduler.step(val_loss)

    # Save the model only when a loss actually improves. The running best must
    # be updated on every save; otherwise each epoch compares against +inf and
    # overwrites the "best" checkpoint regardless of quality.
    if train_loss < BEST_TRAIN_LOSS:
        BEST_TRAIN_LOSS = train_loss
        saveModel(model, optimizer, lr_scheduler, epoch, train_loss, 'checkpoints/best_train.pth')
    if val_loss < BEST_VAL_LOSS:
        BEST_VAL_LOSS = val_loss
        saveModel(model, optimizer, lr_scheduler, epoch, val_loss, 'checkpoints/best_val.pth')

    # Stop once the early stopper reports that validation loss has stalled.
    if early_stopper.early_stop(val_loss):
        print('Early stopping')
        break


In [None]:
def compute_confusion_matrix(pred, target, num_classes):
    """Count (target, prediction) label pairs into a square matrix.

    Args:
        pred: NumPy array of predicted class indices.
        target: NumPy array of ground-truth class indices, same number of
            elements as ``pred``.
        num_classes: Number of valid classes; labels are expected in
            ``[0, num_classes)``.

    Returns:
        A ``(num_classes, num_classes)`` integer array whose entry ``[t, p]``
        is the number of elements with ground truth ``t`` and prediction ``p``.
        Elements whose ground-truth label lies outside ``[0, num_classes)``
        are ignored.
    """
    flat_pred = pred.reshape(-1)
    flat_target = target.reshape(-1)

    # Keep only positions whose ground-truth label is a valid class index.
    in_range = np.logical_and(flat_target >= 0, flat_target < num_classes)
    rows = flat_target[in_range].astype(int)
    cols = flat_pred[in_range].astype(int)

    # Encode each (row, col) pair as one flat index, histogram, then unfold.
    pair_counts = np.bincount(num_classes * rows + cols, minlength=num_classes * num_classes)
    return pair_counts.reshape(num_classes, num_classes)

def calculate_metrics(confusion_matrix):
    """Derive segmentation metrics from a square confusion matrix.

    Assumes rows index ground-truth classes and columns index predicted
    classes (the layout ``compute_confusion_matrix`` produces).

    Args:
        confusion_matrix: ``(C, C)`` array of non-negative counts.

    Returns:
        Tuple ``(pixel_accuracy, mean_pixel_accuracy, iou, mean_iou)`` where
        ``iou`` is a length-``C`` array and the others are scalars.

    Every denominator is clamped to at least 1, so an all-zero matrix yields
    0.0 everywhere instead of NaN. A class that never appears in either the
    ground truth or the predictions therefore contributes 0 to the means.
    """
    tp = np.diag(confusion_matrix)
    sum_rows = confusion_matrix.sum(axis=1)
    sum_cols = confusion_matrix.sum(axis=0)
    total_pixels = confusion_matrix.sum()

    # Clamp the total like the per-class denominators below, so an empty
    # matrix does not produce a 0/0 NaN and a RuntimeWarning.
    pixel_accuracy = tp.sum() / np.maximum(total_pixels, 1)
    mean_pixel_accuracy = np.mean(tp / np.maximum(sum_rows, 1))
    # Union = ground truth + predicted - intersection.
    iou = tp / np.maximum(sum_rows + sum_cols - tp, 1)
    mean_iou = np.mean(iou)

    return pixel_accuracy, mean_pixel_accuracy, iou, mean_iou

In [None]:
def evaluate_model(model, val_loader, device, num_classes):
    """Run the model over a loader and report confusion-matrix based metrics.

    Puts ``model`` in eval mode, accumulates one confusion matrix over every
    sample yielded by ``val_loader``, prints the metrics and returns them.

    Args:
        model: Network producing per-class scores along dim 1.
        val_loader: Iterable of ``(images, masks)`` batches. Masks are reduced
            with ``argmax`` over dim 1 - presumably one-hot encoded; verify
            against the dataset class.
        device: Device the batches are moved to before the forward pass.
        num_classes: Number of classes (side length of the confusion matrix).

    Returns:
        ``(pixel_accuracy, mean_pixel_accuracy, iou, mean_iou)``.
    """
    model.eval()
    total_cm = np.zeros((num_classes, num_classes), dtype=np.int64)

    with torch.no_grad():
        for batch_images, batch_masks in val_loader:
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)

            scores = model(batch_images)

            # Collapse the class dimension to label maps, then bring each
            # batch to the host once rather than once per sample.
            pred_labels = scores.argmax(dim=1).cpu().numpy()
            true_labels = batch_masks.argmax(dim=1).cpu().numpy()

            for pred_map, true_map in zip(pred_labels, true_labels):
                total_cm += compute_confusion_matrix(
                    pred_map,
                    true_map,
                    num_classes=num_classes,
                )

    metrics = calculate_metrics(total_cm)
    pixel_accuracy, mean_pixel_accuracy, iou, mean_iou = metrics

    print(f"Pixel Accuracy: {pixel_accuracy:.4f}")
    print(f"Mean Pixel Accuracy: {mean_pixel_accuracy:.4f}")
    print(f"Mean IoU: {mean_iou:.4f}")
    print(f"IoU per Class: {iou}")

    return pixel_accuracy, mean_pixel_accuracy, iou, mean_iou

In [None]:
# Score the in-memory model on the validation split; the returned metric
# tuple is shown as the cell output.
evaluate_model(
    model=model,
    val_loader=val_dataloader,
    device=device,
    num_classes=NUM_CLASSES,
)

In [None]:
# Rebuild the architecture exactly as it was trained. The training model is
# created with in_channels=1, so a 3-channel copy would make load_state_dict
# fail with a shape mismatch on the first convolution.
model_trained = VanillaUNet(in_channels=1, out_channels=NUM_CLASSES)

# map_location lets a checkpoint saved on one device load on whatever
# `device` is here (e.g. a GPU-trained file on a CPU-only machine).
# NOTE(review): assumes saveModel stores the weights under 'model_state_dict' - confirm in utils.
checkpoint = torch.load('checkpoints/best_val.pth', map_location=device)
model_trained.load_state_dict(checkpoint['model_state_dict'])

# Switch to eval mode so BatchNorm uses its running statistics during the
# inference cells below; .eval() returns the module, so it is still displayed.
model_trained.to(device).eval()

In [None]:
# Run the reloaded model on the first validation batch only (no gradients).
# `images`, `masks` and `outputs` are reused by the visualisation cells below.
with torch.no_grad():
    first_val_batch = next(iter(val_dataloader), None)
    if first_val_batch is not None:
        images, masks = (tensor.to(device) for tensor in first_val_batch)
        outputs = model_trained(images)

In [None]:
# Show the shape of the raw model output for the batch above
# (presumably batch x classes x height x width - confirm against the model).
outputs.size()

In [None]:
# Pick one sample from the batch and convert it for plotting.
sample_idx = 0

# Channels-first -> channels-last so matplotlib can draw the image.
image = images[sample_idx].cpu().permute(1, 2, 0)

# Collapse dim 0 of the mask and of the model output to per-pixel labels.
mask = torch.argmax(masks[sample_idx], dim=0).cpu().numpy()
predicted_mask = torch.argmax(outputs[sample_idx], dim=0).cpu().numpy()

In [None]:
# Draw input, ground truth and prediction side by side in ONE figure.
# The panels used to be spread over three cells; with the inline backend each
# cell renders and closes its own figure, so later plt.subplot calls landed on
# fresh default-sized figures instead of the 12x12 one created first.
fig, axes = plt.subplots(1, 3, figsize=(12, 12))

panels = (
    (image, 'Image'),
    (mask, 'Mask'),
    (predicted_mask, 'Predicted Mask'),
)
for ax, (panel_data, panel_title) in zip(axes, panels):
    ax.imshow(panel_data)
    ax.axis('off')
    ax.set_title(panel_title)

# Render the figure without echoing the last call's return value.
plt.show()