In [205]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torchvision.transforms import Resize, InterpolationMode, ToPILImage
import torchmetrics
from torchmetrics import JaccardIndex, Precision, Recall, F1Score
import segmentation_models_pytorch as smp

from src.evaluation.evaluate_result import evaluate_result
from src.datasets.UAVidSemanticSegmentationDataset import (
    UAVidSemanticSegmentationDataset,
)

## Prepare environment

In [173]:
torch.cuda.is_available()

True

In [174]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# if you get some cryptic CUDA error, set device to "cpu" and try again

# DEBUG
# device = torch.device("cpu")
print(device)

cuda


In [175]:
VAL_SIZE = 0.2
BATCH_SIZE = 8
SEED = 42
UAVID_DATASET_PATH = "data/UAVidSemanticSegmentationDataset"

# IMAGE_SIZE = 576

IMAGE_WIDTH = 1024
IMAGE_HEIGHT = 576

In [176]:
train_dataset = UAVidSemanticSegmentationDataset(
    UAVID_DATASET_PATH,
    split="train",
    transforms=[
        # Resize(IMAGE_SIZE),
        Resize((IMAGE_HEIGHT, IMAGE_WIDTH), interpolation=InterpolationMode.NEAREST_EXACT),
        # Resize(IMAGE_SIZE, interpolation=InterpolationMode.NEAREST_EXACT),
        # ResizeToDivisibleBy32()
    ],
)
print(len(train_dataset))

200


In [177]:
val_dataset = UAVidSemanticSegmentationDataset(
    UAVID_DATASET_PATH,
    split="valid",
    transforms=[
        # Resize(IMAGE_SIZE),
        Resize((IMAGE_HEIGHT, IMAGE_WIDTH), interpolation=InterpolationMode.NEAREST_EXACT),
        # Resize(IMAGE_SIZE, interpolation=InterpolationMode.NEAREST_EXACT),
        # ResizeToDivisibleBy32()
    ],
)
print(len(val_dataset))

70


In [178]:
test_dataset = UAVidSemanticSegmentationDataset(
    UAVID_DATASET_PATH,
    split="test",
    transforms=[
        # Resize(IMAGE_SIZE),
        Resize((IMAGE_HEIGHT, IMAGE_WIDTH), interpolation=InterpolationMode.NEAREST_EXACT),
        # Resize(IMAGE_SIZE, interpolation=InterpolationMode.NEAREST_EXACT),
        # ResizeToDivisibleBy32()
    ],
)
print(len(test_dataset))

10


## Sanity check data

In [179]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

In [180]:
for images, masks in train_loader:
    print(images.shape)
    print(masks.shape)
    break

torch.Size([8, 3, 576, 1024])
torch.Size([8, 1, 576, 1024])


In [181]:
to_pil_transform = ToPILImage()

In [182]:
if BATCH_SIZE == 1:
    img = to_pil_transform(images.squeeze())

In [183]:
# img.show()

In [184]:
if BATCH_SIZE == 1:
    msk = to_pil_transform(masks.squeeze())

In [185]:
# msk.show()

## Basic training loop

In [186]:
import gc

gc.collect()

2177

In [187]:
torch.cuda.empty_cache()

In [208]:

def jaccard_loss(preds, targets, smooth=1e-6, num_classes=8):
    print(targets.shape)
    targets_one_hot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2)
    intersection = (preds * targets_one_hot).sum(dim=(1, 2, 3))
    union = preds.sum(dim=(1, 2, 3)) + targets_one_hot.sum(dim=(1, 2, 3)) - intersection
    jaccard = (intersection + smooth) / (union + smooth)
    return 1 - jaccard.mean()

In [210]:
batch_size = 8
height = 256
width = 256
num_classes = 8  # Adjust this to your actual number of classes

In [212]:
targets = torch.randint(0, num_classes, (batch_size, height, width))
print(targets.shape)

torch.Size([8, 256, 256])


In [209]:
preds = torch.tensor([[[[1, 1], [1, 1]], [[1, 1], [1, 1]]]])
targets = torch.tensor([[[[1, 1], [1, 1]], [[1, 1], [1, 1]]]])

loss = jaccard_loss(preds, targets)
print(loss.item())  # Output: 0.0

torch.Size([1, 2, 2, 2])


RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 5 is not equal to len(dims) = 4

In [204]:
preds = torch.tensor([[[[0, 0], [0, 0]], [[0, 0], [0, 0]]]])
targets = torch.tensor([[[[1, 1], [1, 1]], [[1, 1], [1, 1]]]])

loss = jaccard_loss(preds, targets)
print(loss.item())  # Output: 1.0

NameError: name 'F' is not defined

In [198]:
preds = torch.tensor([[[[0, 0], [1, 1]], [[2, 2], [1, 0]]]])
targets = torch.tensor([[[[0, 2], [1, 1]], [[1, 1], [1, 2]]]])

loss = jaccard_loss(preds, targets)
print(loss.item())

0.2222222089767456


In [214]:
import torch
import torch.optim as optim
import segmentation_models_pytorch as smp
from tqdm import tqdm
import json
from time import perf_counter

# Define model, loss function, and optimizer
model = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet", in_channels=3, classes=8).to(device)
# jaccard_loss = torchmetrics.JaccardIndex(task="multiclass", num_classes=8).to(device)
# jaccard_loss = smp.losses.JaccardLoss(mode="multiclass", from_logits=False).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training and validation loop
num_epochs = 10
torch.cuda.empty_cache()
gc.collect()

for epoch in tqdm(range(num_epochs)):
    model.train()
    print("Starting epoch", epoch+1)
    t0 = perf_counter()
    train_loss = 0.0

    print("Training...")
    for images, masks in tqdm(train_loader):
        images = images.to(device)
        masks = masks.to(device).squeeze(1)

        optimizer.zero_grad()
        torch.cuda.empty_cache()
        # outputs = model(images).argmax(dim=1)
        outputs = model(images)
        print("outputs")
        print(outputs.shape)
        print(outputs.dtype)
        print("masks")
        print(masks.shape)
        print(masks.dtype)
        loss = jaccard_loss(preds=outputs.long(), targets=masks.long())
        loss.requires_grad = True
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()

    train_loss /= len(train_loader)
    
    print("Validating...")
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in tqdm(val_loader):
            images = images.to(device)
            masks = masks.to(device)
            
            # outputs = model(images).argmax(dim=1)
            outputs = model(images)
            print(outputs.shape, masks.shape)
            loss = jaccard_loss(preds=outputs.long(), targets=masks.long())

            val_loss += loss.item()

    val_loss /= len(val_loader)
    
    # Save metrics to JSON file
    metrics = {
        'epoch': epoch+1,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_iou': float(JaccardIndex(num_classes=8, task="multiclass").compute()),
        'val_iou': float(JaccardIndex(num_classes=8, task="multiclass").compute()),
        'train_precision': float(Precision(num_classes=8, task="multiclass").compute()),
        'val_precision': float(Precision(num_classes=8, task="multiclass").compute()),
        'train_recall': float(Recall(num_classes=8, task="multiclass").compute()),
        'val_recall': float(Recall(num_classes=8, task="multiclass").compute()),
        'train_f1': float(F1Score(num_classes=8, task="multiclass").compute()),
        'val_f1': float(F1Score(num_classes=8, task="multiclass").compute())
    }

    with open('metrics.json', 'a') as f:
        json.dump(metrics, f)
        f.write('\n')

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(f"Epoch finished in {perf_counter() - t0:.2f} seconds")
    print()
    model_path = f"model_{epoch+1}.pth"
    torch.save(model.state_dict(), model_path)





Starting epoch 1
Training...



[A

outputs
torch.Size([8, 8, 576, 1024])
torch.float32
masks
torch.Size([8, 576, 1024])
torch.uint8
torch.Size([8, 576, 1024])


  0%|          | 0/25 [00:52<?, ?it/s]
  0%|          | 0/10 [00:52<?, ?it/s]


RuntimeError: one_hot is only applicable to index tensor.

# Training module

In [None]:
class SegmentationModel(pl.LightningModule):
    def __init__(self, model, learning_rate=1e-3):
        super(SegmentationModel, self).__init__()
        self.model = model
        self.learning_rate = learning_rate
        # self.criterion = torch.nn.CrossEntropyLoss()
        self.criterion = smp.losses.JaccardLoss(mode="multiclass", from_logits=False)

        # Metrics
        self.train_iou = JaccardIndex(num_classes=8, task="multiclass")
        self.val_iou = JaccardIndex(num_classes=8, task="multiclass")
        self.train_precision = Precision(num_classes=8, task="multiclass")
        self.val_precision = Precision(num_classes=8, task="multiclass")
        self.train_recall = Recall(num_classes=8, task="multiclass")
        self.val_recall = Recall(num_classes=8, task="multiclass")
        self.train_f1 = F1Score(num_classes=8, task="multiclass")
        self.val_f1 = F1Score(num_classes=8, task="multiclass")

    def forward(self, x):
        output = self.model(x.to(device))
        # FIXME: not really x > 0.5
        # this is degenerated case for binary segmentation
        # output = torch.argmax(output, dim=1)
        return output

    def training_step(self, batch, batch_idx):
        
        images, masks = batch
        print("masks before")
        # unique, counts = np.unique(masks.cpu(), return_counts=True)
        # print(dict(zip(unique, counts)))
        # masks = masks.to(torch.int16).squeeze()
        masks = masks.squeeze()
        preds = self(images).squeeze()
        # preds = self(images).to(torch.int16).squeeze()
        print("masks")
        print(masks.shape)
        print(masks.dtype)
        # unique, counts = np.unique(masks.cpu(), return_counts=True)
        # print(dict(zip(unique, counts)))
        print("preds")
        print(preds.shape)
        print(preds.dtype)
        # unique, counts = np.unique(preds.cpu(), return_counts=True)
        # print(dict(zip(unique, counts)))
        loss = self.criterion(preds, masks)
        loss.requires_grad = True
        
        self.log('train_loss', loss, on_epoch=True)
        self.log('train_iou', self.train_iou(preds, masks), on_epoch=True)
        self.log('train_precision', self.train_precision(preds, masks), on_epoch=True)
        self.log('train_recall', self.train_recall(preds, masks), on_epoch=True)
        self.log('train_f1', self.train_f1(preds, masks), on_epoch=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        images, masks = batch
        print("masks before")
        # unique, counts = np.unique(masks.cpu(), return_counts=True)
        # print(dict(zip(unique, counts)))
        # masks = masks.to(torch.int16).squeeze()
        masks = masks.squeeze()
        preds = self(images).squeeze()
        # preds = self(images).to(torch.int16).squeeze()
        print("masks")
        print(masks.shape)
        print(masks.dtype)
        # unique, counts = np.unique(masks.cpu(), return_counts=True)
        # print(dict(zip(unique, counts)))
        print("preds")
        print(preds.shape)
        print(preds.dtype)
        # unique, counts = np.unique(preds.cpu(), return_counts=True)
        # print(dict(zip(unique, counts)))
        loss = self.criterion(preds, masks)
        loss.requires_grad = True
        
        self.log('val_loss', loss, on_epoch=True)
        self.log('val_iou', self.val_iou(preds, masks), on_epoch=True)
        self.log('val_precision', self.val_precision(preds, masks), on_epoch=True)
        self.log('val_recall', self.val_recall(preds, masks), on_epoch=True)
        self.log('val_f1', self.val_f1(preds, masks), on_epoch=True)
        
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [None]:
model = smp.Unet(
    encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=8,  # model output channels (number of classes in your dataset)
    activation="softmax",
).to(device)

In [None]:
# model = smp.UnetPlusPlus(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,  # model output channels (number of classes in your dataset)
# ).to(device)

In [None]:
# model = smp.DeepLabV3(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,
# ).to(device)

In [None]:
# model = smp.DeepLabV3Plus(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,
# ).to(device)

In [None]:
segmentation_model = SegmentationModel(model)

In [None]:
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss_epoch', save_top_k=-1, mode='min'
    )

In [None]:
logger = CSVLogger("logs", name="segmentation_model")

In [None]:
trainer = pl.Trainer(max_epochs=10, callbacks=[checkpoint_callback], logger=logger, accelerator=str(device))
# trainer = pl.Trainer(max_epochs=10, accelerator="cpu")

In [None]:
trainer.fit(segmentation_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# UNET

In [None]:
# TODO: load best model from checkpoint

In [None]:
model.eval()
with torch.no_grad():
    for images, masks in train_loader:
        print(images.shape)
        print(masks.shape)
        break
    output = model(images.to(device))
    output = torch.argmax(output, dim=1)