In [1]:
import csv
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torchvision.transforms import Resize, InterpolationMode, ToTensor, ToPILImage, TenCrop, Compose, Lambda
from torchmetrics import JaccardIndex, Precision, Recall, F1Score
import segmentation_models_pytorch as smp

from src.models.BaselineModel import BaselineModel
from src.evaluation.evaluate_result import evaluate_result
from src.callbacks.SaveRandomImagesCallback import SaveRandomImagesCallback
from src.callbacks.SaveTestPredsMulticlass import SaveTestPredsMulticlass
from src.datasets.utils.Squeeze5DimIfNeeded import Squeeze5DimIfNeeded
from src.datasets.DubaiSemanticSegmentationDataset import (
    DubaiSemanticSegmentationDataset,
)
from src.datasets.utils.ResizeToDivisibleBy32 import ResizeToDivisibleBy32

  from .autonotebook import tqdm as notebook_tqdm


## Prepare environment

In [2]:
torch.cuda.is_available()

True

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# if you get some cryptic CUDA error, set device to "cpu" and try again
print(device)

cuda


In [4]:
VAL_SIZE = 0.2
BATCH_SIZE = 1
SEED = 42
SAVE_VAL_DIR = "outputs/Dubai/val"
SAVE_TEST_DIR = "outputs/Dubai/test"
DUBAI_DATASET_PATH = "data/DubaiSemanticSegmentationDataset"
IMAGE_SIZE = 576
CROP_IMAGE_SIZE = 224

In [5]:
labeled_dataset = DubaiSemanticSegmentationDataset(
    DUBAI_DATASET_PATH, 
    # transforms=[ResizeToDivisibleBy32()]
    # transforms=[TenCrop(CROP_IMAGE_SIZE, vertical_flip=True)]
    # transforms=Compose([
    #     TenCrop(size=CROP_IMAGE_SIZE, vertical_flip=True),
    #     Lambda(lambda crops: torch.stack([crop for crop in crops])),
    #     Squeeze5DimIfNeeded(),
    # ])
    # transforms=[
    #     TenCrop(CROP_IMAGE_SIZE, vertical_flip=True),
    #     Lambda(lambda crops: torch.stack([crop for crop in crops]))
    # ]
)
print(len(labeled_dataset))

72


## Data preparation

### Sanity check data

In [6]:
sanity_check_loader = DataLoader(labeled_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Train, validation and test split

In [7]:
train_size = int(0.7 * len(labeled_dataset))
val_size = int(0.2 * len(labeled_dataset))
test_size = len(labeled_dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
    labeled_dataset, [train_size, val_size, test_size]
)
print(len(train_dataset), len(val_dataset), len(test_dataset))

50 14 8


In [8]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [9]:
for images, masks in train_loader:
    print(images.shape)
    print(images.squeeze(0).shape)
    print()
    print(masks.shape)
    print(masks.squeeze(0).shape)
    break

# with TenCrop use:
# for batch in train_loader:
#     images, masks = batch
#     print(images.shape)
#     print(masks.shape)
#     break

Transforming
torch.Size([3, 1479, 2149])
torch.Size([1, 1479, 2149])
torch.Size([1, 10, 3, 224, 224])
torch.Size([10, 3, 224, 224])

torch.Size([1, 10, 1, 224, 224])
torch.Size([10, 1, 224, 224])


In [52]:
to_pil_transform = ToPILImage()

In [53]:
if BATCH_SIZE == 1:
    img = to_pil_transform(images.squeeze())

ValueError: pic should be 2/3 dimensional. Got 4 dimensions.

In [55]:
# img.show()

In [None]:
if BATCH_SIZE == 1:
    msk = to_pil_transform(masks.squeeze())

In [None]:
# msk.show()

# Training module

In [56]:
class SegmentationModel(pl.LightningModule):
    def __init__(self, model, learning_rate=1e-3):
        super(SegmentationModel, self).__init__()

        self.model = model
        self.learning_rate = learning_rate

        self.criterion = smp.losses.JaccardLoss(mode="multiclass", from_logits=False)
        self.train_iou = JaccardIndex(num_classes=6, task="multiclass")
        self.val_iou = JaccardIndex(num_classes=6, task="multiclass")

    def forward(self, x):
        output = self.model(x.to(device))
        return output

    def training_step(self, batch, batch_idx):
        images, masks = batch
        preds = self(images)
        loss = self.criterion(preds, masks.squeeze(1).long())

        self.log("train_loss", loss, on_epoch=True, on_step=True)
        self.log(
            "train_iou",
            self.train_iou(preds, masks.squeeze(1)),
            on_epoch=True,
            on_step=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        images, masks = batch
        preds = self(images)
        loss = self.criterion(preds, masks.squeeze(1).long())

        self.log("val_loss", loss, on_epoch=True, on_step=True)
        self.log(
            "val_iou",
            self.val_iou(preds, masks.squeeze(1)),
            on_epoch=True,
            on_step=True,
        )

        return loss

    def test_step(self, batch, batch_idx):
        # just here to activate the test_epoch_end
        # callback SaveTestPreds starts on_test_epoch_end
        pass

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

### Model selection

In [57]:
model = smp.Unet(
    encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=6,  # model output channels (number of classes in your dataset)
    activation="softmax",
).to(device)

In [58]:
# model = smp.UnetPlusPlus(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,  # model output channels (number of classes in your dataset)
# ).to(device)

In [59]:
# model = smp.DeepLabV3(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,
# ).to(device)

In [60]:
# model = smp.DeepLabV3Plus(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,
# ).to(device)

In [61]:
segmentation_model = SegmentationModel(model)

### Configuration

In [62]:
model_checkpoint_callback = ModelCheckpoint(
    monitor="val_loss", save_top_k=-1, mode="min"
)

In [63]:
logger = CSVLogger("lightning_logs", name="dubai_segmentation_model")

In [64]:
save_callback = SaveRandomImagesCallback(save_dir=SAVE_VAL_DIR)

# Dubai test dataset is also labeled
# so we can use different logic for test set evaluation
save_test_preds_callback = SaveTestPredsMulticlass(save_dir=SAVE_TEST_DIR)

early_stopping_callback = EarlyStopping(monitor="val_loss", mode="min", patience=5)

In [65]:
trainer = pl.Trainer(
    max_epochs=50,
    log_every_n_steps=10,
    callbacks=[
        model_checkpoint_callback,
        save_callback,
        save_test_preds_callback,
        early_stopping_callback,
    ],
    logger=logger,
)

# DEBUG
# trainer = pl.Trainer(
#     max_epochs=1,
#     callbacks=[model_checkpoint_callback, save_callback],
#     logger=logger,
#     accelerator="cpu"
#     )

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


### Run

In [66]:
trainer.fit(
    segmentation_model, train_dataloaders=train_loader, val_dataloaders=val_loader
)

You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                   | Params
-----------------------------------------------------
0 | model     | Unet                   | 14.3 M
1 | criterion | JaccardLoss            | 0     
2 | train_iou | MulticlassJaccardIndex | 0     
3 | val_iou   | MulticlassJaccardIndex | 0     
-----------------------------------------------------
14.3 M    Trainable params
0         Non-trainable params
14.3 M    Total params
57.316    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

d:\__repos\aerial_segmentation\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 10, 3, 224, 224]

In [None]:
trainer.test(ckpt_path="best", dataloaders=test_loader)