In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torchvision.transforms import Resize, InterpolationMode, ToPILImage, RandomCrop
import torchmetrics
from torchmetrics import JaccardIndex, Precision, Recall, F1Score
import segmentation_models_pytorch as smp

from src.evaluation.evaluate_result import evaluate_result
from src.datasets.UAVidSemanticSegmentationDataset import (
    UAVidSemanticSegmentationDataset,
)

  from .autonotebook import tqdm as notebook_tqdm


## Prepare environment

In [2]:
torch.cuda.is_available()

True

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# if you get some cryptic CUDA error, set device to "cpu" and try again

# DEBUG
# device = torch.device("cpu")
print(device)

cuda


In [4]:
VAL_SIZE = 0.2
BATCH_SIZE = 8
SEED = 42
UAVID_DATASET_PATH = "data/UAVidSemanticSegmentationDataset"

# IMAGE_SIZE = 576

IMAGE_WIDTH = 1024
IMAGE_HEIGHT = 576

In [25]:
train_dataset = UAVidSemanticSegmentationDataset(
    UAVID_DATASET_PATH,
    split="train",
    transforms=[
        # Resize(IMAGE_SIZE),
        # Resize((IMAGE_HEIGHT, IMAGE_WIDTH), interpolation=InterpolationMode.NEAREST_EXACT),
        RandomCrop((IMAGE_HEIGHT, IMAGE_WIDTH)),
        # Resize(IMAGE_SIZE, interpolation=InterpolationMode.NEAREST_EXACT),
        # ResizeToDivisibleBy32()
    ],
)
print(len(train_dataset))

200


In [6]:
val_dataset = UAVidSemanticSegmentationDataset(
    UAVID_DATASET_PATH,
    split="valid",
    transforms=[
        # Resize(IMAGE_SIZE),
        Resize((IMAGE_HEIGHT, IMAGE_WIDTH), interpolation=InterpolationMode.NEAREST_EXACT),
        # Resize(IMAGE_SIZE, interpolation=InterpolationMode.NEAREST_EXACT),
        # ResizeToDivisibleBy32()
    ],
)
print(len(val_dataset))

70


In [7]:
test_dataset = UAVidSemanticSegmentationDataset(
    UAVID_DATASET_PATH,
    split="test",
    transforms=[
        # Resize(IMAGE_SIZE),
        Resize((IMAGE_HEIGHT, IMAGE_WIDTH), interpolation=InterpolationMode.NEAREST_EXACT),
        # Resize(IMAGE_SIZE, interpolation=InterpolationMode.NEAREST_EXACT),
        # ResizeToDivisibleBy32()
    ],
)
print(len(test_dataset))

10


## Sanity check data

In [8]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

In [9]:
for images, masks in train_loader:
    print(images.shape)
    print(masks.shape)
    break

torch.Size([8, 3, 576, 1024])
torch.Size([8, 1, 576, 1024])


In [10]:
to_pil_transform = ToPILImage()

In [11]:
if BATCH_SIZE == 1:
    img = to_pil_transform(images.squeeze())

In [12]:
# img.show()

In [13]:
if BATCH_SIZE == 1:
    msk = to_pil_transform(masks.squeeze())

In [14]:
# msk.show()

## Basic training loop

In [15]:
import gc

gc.collect()

20

In [16]:
torch.cuda.empty_cache()

In [35]:
import torch
import torch.optim as optim
import segmentation_models_pytorch as smp
from tqdm import tqdm
import json
from time import perf_counter

# Define model, loss function, and optimizer
model = smp.Unet(encoder_name="resnet18", encoder_weights="imagenet", in_channels=3, classes=8).to(device)
# jaccard_loss = torchmetrics.JaccardIndex(task="multiclass", num_classes=8).to(device)
jaccard_loss = smp.losses.JaccardLoss(mode="multiclass", from_logits=False).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training and validation loop
num_epochs = 10
torch.cuda.empty_cache()
gc.collect()

for epoch in tqdm(range(num_epochs)):
    model.train()
    print("Starting epoch", epoch+1)
    t0 = perf_counter()
    train_loss = 0.0

    print("Training...")
    for images, masks in tqdm(train_loader):
        images = images.to(device)
        masks = masks.to(device).squeeze(1)

        optimizer.zero_grad()
        torch.cuda.empty_cache()
        # outputs = model(images).argmax(dim=1)
        outputs = model(images)
        # print("outputs")
        # print(outputs.shape)
        # print(outputs.dtype)
        # print("masks")
        # print(masks.shape)
        # print(masks.dtype)
        loss = jaccard_loss(y_pred=outputs.long(), y_true=masks.long())
        loss.requires_grad = True
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()

    train_loss /= len(train_loader)
    
    print("Validating...")
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in tqdm(val_loader):
            images = images.to(device)
            masks = masks.to(device).squeeze(1)
            
            # outputs = model(images).argmax(dim=1)
            outputs = model(images)
            # print(outputs.shape, masks.shape)
            loss = jaccard_loss(y_pred=outputs.long(), y_true=masks.long())

            val_loss += loss.item()

    val_loss /= len(val_loader)
    
    # Save metrics to JSON file
    metrics = {
        'epoch': epoch+1,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_iou': float(JaccardIndex(num_classes=8, task="multiclass").compute()),
        'val_iou': float(JaccardIndex(num_classes=8, task="multiclass").compute()),
        'train_precision': float(Precision(num_classes=8, task="multiclass").compute()),
        'val_precision': float(Precision(num_classes=8, task="multiclass").compute()),
        'train_recall': float(Recall(num_classes=8, task="multiclass").compute()),
        'val_recall': float(Recall(num_classes=8, task="multiclass").compute()),
        'train_f1': float(F1Score(num_classes=8, task="multiclass").compute()),
        'val_f1': float(F1Score(num_classes=8, task="multiclass").compute())
    }

    with open('metrics.json', 'a') as f:
        json.dump(metrics, f)
        f.write('\n')

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(f"Epoch finished in {perf_counter() - t0:.2f} seconds")
    print()
    model_path = f"model_{epoch+1}.pth"
    torch.save(model.state_dict(), model_path)



  0%|          | 0/10 [00:00<?, ?it/s]

Starting epoch 1
Training...


100%|██████████| 25/25 [02:35<00:00,  6.23s/it]


Validating...


100%|██████████| 9/9 [00:52<00:00,  5.83s/it]
 10%|█         | 1/10 [03:28<31:16, 208.48s/it]

Epoch 1/10, Train Loss: 12474949885.4400, Val Loss: 4207777745.7778
Epoch finished in 208.39 seconds

Starting epoch 2
Training...


100%|██████████| 25/25 [02:40<00:00,  6.43s/it]


Validating...


100%|██████████| 9/9 [01:16<00:00,  8.49s/it]
 20%|██        | 2/10 [07:25<30:02, 225.37s/it]

Epoch 2/10, Train Loss: 11101049800.9600, Val Loss: 12206110979.5556
Epoch finished in 237.09 seconds

Starting epoch 3
Training...


100%|██████████| 25/25 [02:36<00:00,  6.28s/it]


Validating...


100%|██████████| 9/9 [01:01<00:00,  6.78s/it]
 30%|███       | 3/10 [11:03<25:54, 222.04s/it]

Epoch 3/10, Train Loss: 11748999756.8000, Val Loss: 17981250074.6667
Epoch finished in 218.00 seconds

Starting epoch 4
Training...


100%|██████████| 25/25 [02:42<00:00,  6.52s/it]


Validating...


100%|██████████| 9/9 [00:52<00:00,  5.85s/it]
 40%|████      | 4/10 [14:39<21:57, 219.53s/it]

Epoch 4/10, Train Loss: 13218399637.7600, Val Loss: 18498471960.8889
Epoch finished in 215.60 seconds

Starting epoch 5
Training...


100%|██████████| 25/25 [02:34<00:00,  6.16s/it]


Validating...


100%|██████████| 9/9 [01:29<00:00,  9.91s/it]
 50%|█████     | 5/10 [18:42<19:00, 228.12s/it]

Epoch 5/10, Train Loss: 15492700225.2800, Val Loss: 18392638298.6667
Epoch finished in 243.25 seconds

Starting epoch 6
Training...


100%|██████████| 25/25 [02:27<00:00,  5.90s/it]


Validating...


100%|██████████| 9/9 [01:16<00:00,  8.52s/it]
 60%|██████    | 6/10 [22:26<15:07, 226.79s/it]

Epoch 6/10, Train Loss: 15274400046.0800, Val Loss: 18713194142.2222
Epoch finished in 224.14 seconds

Starting epoch 7
Training...


100%|██████████| 25/25 [02:29<00:00,  5.97s/it]


Validating...


100%|██████████| 9/9 [01:17<00:00,  8.62s/it]
 70%|███████   | 7/10 [26:13<11:20, 226.82s/it]

Epoch 7/10, Train Loss: 8628150008.3200, Val Loss: 18680833623.1111
Epoch finished in 226.74 seconds

Starting epoch 8
Training...


100%|██████████| 25/25 [02:38<00:00,  6.35s/it]


Validating...


100%|██████████| 9/9 [01:27<00:00,  9.77s/it]
 80%|████████  | 8/10 [30:20<07:46, 233.22s/it]

Epoch 8/10, Train Loss: 11480449827.8400, Val Loss: 18853194531.5556
Epoch finished in 246.83 seconds

Starting epoch 9
Training...


100%|██████████| 25/25 [02:38<00:00,  6.35s/it]


Validating...


100%|██████████| 9/9 [01:20<00:00,  8.90s/it]
 90%|█████████ | 9/10 [34:19<03:55, 235.04s/it]

Epoch 9/10, Train Loss: 14186699907.8400, Val Loss: 18462777233.7778
Epoch finished in 238.90 seconds

Starting epoch 10
Training...


100%|██████████| 25/25 [02:38<00:00,  6.36s/it]


Validating...


100%|██████████| 9/9 [01:21<00:00,  9.07s/it]
100%|██████████| 10/10 [38:20<00:00, 230.05s/it]

Epoch 10/10, Train Loss: 9482599861.7600, Val Loss: 18797777674.6667
Epoch finished in 240.61 seconds






In [18]:
# compare random model with trained model (model_10.pth) on test set
random_model = smp.Unet(encoder_name="resnet18", encoder_weights=None, in_channels=3, classes=8).to(device)
random_model.eval()

# Load the trained model

# trained_model = smp.Unet(encoder_name="resnet18", encoder_weights=None, in_channels=3, classes=8).to(device)

model_path = "model_10.pth"
trained_model = smp.Unet(encoder_name="resnet34", encoder_weights=None, in_channels=3, classes=8).to(device)
trained_model.load_state_dict(torch.load(model_path))
trained_model.eval()

Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [24]:
import os

# Evaluate the random model on the test set
random_model_loss = 0.0
with torch.no_grad():
    for images in test_loader:
        images = images.to(device)
        random_outputs = random_model(images).argmax(dim=1)
        print(random_outputs.shape)
        # Save images as PNG files
        output_dir = "random_model_output"
        os.makedirs(output_dir, exist_ok=True)

        for i, image in enumerate(random_outputs):
            image_path = os.path.join(output_dir, f"image_{i}.png")
            # print(image.shape)
            # print(image.dtype)
            # unique, counts = np.unique(image.cpu().numpy(), return_counts=True)
            # print(dict(zip(unique, counts)))
            # plt.matshow(image.cpu().numpy())
            plt.imsave(image_path, image.cpu().numpy())
            # image = to_pil_transform(image)
            # image.save(image_path)

        print("random_outputs saved successfully.")

# Evaluate the trained model on the test set
trained_model_loss = 0.0
with torch.no_grad():
    for images in test_loader:
        images = images.to(device)
        trained_outputs = trained_model(images).argmax(dim=1)
        print(trained_outputs.shape)
        output_dir = "trained_output_images"
        os.makedirs(output_dir, exist_ok=True)

        for i, image in enumerate(trained_outputs):
            image_path = os.path.join(output_dir, f"image_{i}.png")
            # print(image.shape)
            # print(image.dtype)
            # unique, counts = np.unique(image.cpu().numpy(), return_counts=True)
            # print(dict(zip(unique, counts)))
            # plt.matshow(image.cpu().numpy())
            plt.imsave(image_path, image.cpu().numpy())
            # image = to_pil_transform(image)
            # image.save(image_path)

        print("trained_outputs saved successfully.")


torch.Size([8, 576, 1024])
random_outputs saved successfully.
torch.Size([2, 576, 1024])
random_outputs saved successfully.
torch.Size([8, 576, 1024])
trained_outputs saved successfully.
torch.Size([2, 576, 1024])
trained_outputs saved successfully.


# Training module

In [None]:
class SegmentationModel(pl.LightningModule):
    def __init__(self, model, learning_rate=1e-3):
        super(SegmentationModel, self).__init__()
        self.model = model
        self.learning_rate = learning_rate
        # self.criterion = torch.nn.CrossEntropyLoss()
        self.criterion = smp.losses.JaccardLoss(mode="multiclass", from_logits=False)

        # Metrics
        self.train_iou = JaccardIndex(num_classes=8, task="multiclass")
        self.val_iou = JaccardIndex(num_classes=8, task="multiclass")
        self.train_precision = Precision(num_classes=8, task="multiclass")
        self.val_precision = Precision(num_classes=8, task="multiclass")
        self.train_recall = Recall(num_classes=8, task="multiclass")
        self.val_recall = Recall(num_classes=8, task="multiclass")
        self.train_f1 = F1Score(num_classes=8, task="multiclass")
        self.val_f1 = F1Score(num_classes=8, task="multiclass")

    def forward(self, x):
        output = self.model(x.to(device))
        # FIXME: not really x > 0.5
        # this is degenerated case for binary segmentation
        # output = torch.argmax(output, dim=1)
        return output

    def training_step(self, batch, batch_idx):
        
        images, masks = batch
        print("masks before")
        # unique, counts = np.unique(masks.cpu(), return_counts=True)
        # print(dict(zip(unique, counts)))
        # masks = masks.to(torch.int16).squeeze()
        masks = masks.squeeze()
        preds = self(images).squeeze()
        # preds = self(images).to(torch.int16).squeeze()
        print("masks")
        print(masks.shape)
        print(masks.dtype)
        # unique, counts = np.unique(masks.cpu(), return_counts=True)
        # print(dict(zip(unique, counts)))
        print("preds")
        print(preds.shape)
        print(preds.dtype)
        # unique, counts = np.unique(preds.cpu(), return_counts=True)
        # print(dict(zip(unique, counts)))
        loss = self.criterion(preds, masks)
        loss.requires_grad = True
        
        self.log('train_loss', loss, on_epoch=True)
        self.log('train_iou', self.train_iou(preds, masks), on_epoch=True)
        self.log('train_precision', self.train_precision(preds, masks), on_epoch=True)
        self.log('train_recall', self.train_recall(preds, masks), on_epoch=True)
        self.log('train_f1', self.train_f1(preds, masks), on_epoch=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        images, masks = batch
        print("masks before")
        # unique, counts = np.unique(masks.cpu(), return_counts=True)
        # print(dict(zip(unique, counts)))
        # masks = masks.to(torch.int16).squeeze()
        masks = masks.squeeze()
        preds = self(images).squeeze()
        # preds = self(images).to(torch.int16).squeeze()
        print("masks")
        print(masks.shape)
        print(masks.dtype)
        # unique, counts = np.unique(masks.cpu(), return_counts=True)
        # print(dict(zip(unique, counts)))
        print("preds")
        print(preds.shape)
        print(preds.dtype)
        # unique, counts = np.unique(preds.cpu(), return_counts=True)
        # print(dict(zip(unique, counts)))
        loss = self.criterion(preds, masks)
        loss.requires_grad = True
        
        self.log('val_loss', loss, on_epoch=True)
        self.log('val_iou', self.val_iou(preds, masks), on_epoch=True)
        self.log('val_precision', self.val_precision(preds, masks), on_epoch=True)
        self.log('val_recall', self.val_recall(preds, masks), on_epoch=True)
        self.log('val_f1', self.val_f1(preds, masks), on_epoch=True)
        
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [None]:
model = smp.Unet(
    encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=8,  # model output channels (number of classes in your dataset)
    activation="softmax",
).to(device)

In [None]:
# model = smp.UnetPlusPlus(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,  # model output channels (number of classes in your dataset)
# ).to(device)

In [None]:
# model = smp.DeepLabV3(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,
# ).to(device)

In [None]:
# model = smp.DeepLabV3Plus(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,
# ).to(device)

In [None]:
segmentation_model = SegmentationModel(model)

In [None]:
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss_epoch', save_top_k=-1, mode='min'
    )

In [None]:
logger = CSVLogger("logs", name="segmentation_model")

In [None]:
trainer = pl.Trainer(max_epochs=10, callbacks=[checkpoint_callback], logger=logger, accelerator=str(device))
# trainer = pl.Trainer(max_epochs=10, accelerator="cpu")

In [None]:
trainer.fit(segmentation_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# UNET

In [None]:
# TODO: load best model from checkpoint

In [None]:
model.eval()
with torch.no_grad():
    for images, masks in train_loader:
        print(images.shape)
        print(masks.shape)
        break
    output = model(images.to(device))
    output = torch.argmax(output, dim=1)