In [245]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torchvision import transforms
from torchvision.transforms import Resize, InterpolationMode
from torchmetrics import JaccardIndex, Precision, Recall, F1Score
import segmentation_models_pytorch as smp

from src.models.BaselineModel import BaselineModel
from src.evaluation.evaluate_result import evaluate_result
from src.datasets.INRIAAerialImageLabellingDataset import (
    INRIAAerialImageLabellingDataset,
)

from src.datasets.utils.ResizeToDivisibleBy32 import ResizeToDivisibleBy32

## Prepare environment

In [246]:
torch.cuda.is_available()

True

In [247]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# if you get some cryptic CUDA error, set device to "cpu" and try again
print(device)

cuda


In [248]:
VAL_SIZE = 0.2
BATCH_SIZE = 1
SEED = 42
IMAGE_SIZE = 576
INRIA_DATASET_PATH = "data/INRIAAerialImageLabellingDataset"  # home PC
# INRIA_DATASET_PATH = "data/TestSubsets/INRIAAerialImageLabellingDataset"  # laptop

In [249]:
labeled_dataset = INRIAAerialImageLabellingDataset(
    INRIA_DATASET_PATH,
    split="train",
    transforms=[
        # Resize(IMAGE_SIZE),
        Resize(IMAGE_SIZE, interpolation=InterpolationMode.NEAREST_EXACT),
        # ResizeToDivisibleBy32()
    ],
)
print(len(labeled_dataset))

data/INRIAAerialImageLabellingDataset\train
180


In [250]:
test_dataset = INRIAAerialImageLabellingDataset(
    INRIA_DATASET_PATH,
    split="test",
    transforms=[
        # Resize(IMAGE_SIZE),
        Resize(IMAGE_SIZE, interpolation=InterpolationMode.NEAREST_EXACT),
        # ResizeToDivisibleBy32()
    ],
)
print(len(test_dataset))

data/INRIAAerialImageLabellingDataset\test
144


## Sanity check data

In [251]:
train_size = int(0.8 * len(labeled_dataset))
val_size = len(labeled_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(labeled_dataset, [train_size, val_size])

In [252]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [253]:
for images, masks in train_loader:
    print(images.shape)
    print(masks.shape)
    break

torch.Size([1, 3, 576, 576])
torch.Size([1, 1, 576, 576])


In [254]:
to_pil_transform = transforms.ToPILImage()

In [255]:
img = to_pil_transform(images.squeeze())

In [256]:
# img.show()

In [257]:
msk = to_pil_transform(masks.squeeze())

In [258]:
# msk.show()

# Training module

In [259]:
class SegmentationModel(pl.LightningModule):
    def __init__(self, model, learning_rate=1e-3):
        super(SegmentationModel, self).__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.criterion = torch.nn.BCELoss()

        # Metrics
        self.train_iou = JaccardIndex(num_classes=2, task="binary")
        self.val_iou = JaccardIndex(num_classes=2, task="binary")
        # self.train_precision = Precision(num_classes=2, task="binary")
        # self.val_precision = Precision(num_classes=2, task="binary")
        # self.train_recall = Recall(num_classes=2, task="binary")
        # self.val_recall = Recall(num_classes=2, task="binary")
        # self.train_f1 = F1Score(num_classes=2, task="binary")
        # self.val_f1 = F1Score(num_classes=2, task="binary")

    def forward(self, x):
        x = self.model(x.to(device))
        output = (x > 0.5).float()
        print("forward")
        print(output.shape)
        print(type(output))
        print(output.dtype)
        print(output.max())
        print(output.min())
        return output

    def training_step(self, batch, batch_idx):
        
        images, masks = batch
        masks = torch.div(masks, 255).float()
        preds = self(images)
        print("training_step")
        
        print("preds")
        print(preds.shape)
        print(type(preds))
        print(preds.dtype)
        print(preds.max())
        print(preds.min())

        print("masks")
        print(masks.shape)
        print(type(masks))
        print(masks.dtype)
        print(masks.max())
        print(masks.min())
        print(masks.unique())
        loss = self.criterion(preds, masks)
        loss.requires_grad = True
        
        # preds = torch.sigmoid(y_hat).round()
        
        self.log('train_loss', loss, on_epoch=True)
        self.log('train_iou', self.train_iou(preds, masks), on_epoch=True)
        # self.log('train_precision', self.train_precision(preds, y), on_epoch=True)
        # self.log('train_recall', self.train_recall(preds, y), on_epoch=True)
        # self.log('train_f1', self.train_f1(preds, y), on_epoch=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        images, masks = batch
        masks = torch.div(masks, 255).float()
        preds = self(images)
        print("validation_step")

        print("preds")
        print(preds.shape)
        print(type(preds))
        print(preds.dtype)
        print(preds.max())
        print(preds.min())
    
        print("masks")
        print(masks.shape)
        print(type(masks))
        print(masks.dtype)
        print(masks.max())
        print(masks.min())
        print(masks.unique())
        loss = self.criterion(preds, masks)
        loss.requires_grad = True
        
        # preds = torch.sigmoid(masks_hat).round()
        
        self.log('val_loss', loss, on_epoch=True)
        self.log('val_iou', self.val_iou(preds, masks), on_epoch=True)
        # self.log('val_precision', self.val_precision(preds, y), on_epoch=True)
        # self.log('val_recall', self.val_recall(preds, y), on_epoch=True)
        # self.log('val_f1', self.val_f1(preds, y), on_epoch=True)
        
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [260]:
model = smp.Unet(
    encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,  # model output channels (number of classes in your dataset)
).to(device)

In [261]:
# model = smp.UnetPlusPlus(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,  # model output channels (number of classes in your dataset)
# ).to(device)

In [262]:
# model = smp.DeepLabV3(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,
# ).to(device)

In [263]:
# model = smp.DeepLabV3Plus(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,
# ).to(device)

In [264]:
segmentation_model = SegmentationModel(model)

In [None]:
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss_epoch', save_top_k=-1, mode='min'
    )

In [None]:
logger = CSVLogger("logs", name="segmentation_model")

In [267]:
trainer = pl.Trainer(max_epochs=10, callbacks=[checkpoint_callback], logger=logger, gpus=1)
# trainer = pl.Trainer(max_epochs=10, accelerator="cpu")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [268]:
trainer.fit(segmentation_model, train_dataloaders=train_loader, val_dataloaders=val_loader)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | Unet               | 14.3 M
1 | criterion | BCELoss            | 0     
2 | train_iou | BinaryJaccardIndex | 0     
3 | val_iou   | BinaryJaccardIndex | 0     
-------------------------------------------------
14.3 M    Trainable params
0         Non-trainable params
14.3 M    Total params
57.313    Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]forward
torch.Size([1, 1, 576, 576])
<class 'torch.Tensor'>
torch.float32
tensor(1., device='cuda:0')
tensor(0., device='cuda:0')
validation_step
preds
torch.Size([1, 1, 576, 576])
<class 'torch.Tensor'>
torch.float32
tensor(1., device='cuda:0')
tensor(0., device='cuda:0')
masks
torch.Size([1, 1, 576, 576])
<class 'torch.Tensor'>
torch.float32
tensor(1., device='cuda:0')
tensor(0., device='cuda:0')
tensor([0., 1.], device='cuda:0')
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:02<00:02,  0.47it/s]forward
torch.Size([1, 1, 576, 576])
<class 'torch.Tensor'>
torch.float32
tensor(1., device='cuda:0')
tensor(0., device='cuda:0')
validation_step
preds
torch.Size([1, 1, 576, 576])
<class 'torch.Tensor'>
torch.float32
tensor(1., device='cuda:0')
tensor(0., device='cuda:0')
masks
torch.Size([1, 1, 576, 576])
<class 'torch.Tensor'>
torch.float32
tensor(1., device='cuda:0')
tensor(0., device='cuda:0')
tensor([0., 1.], dev

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 144/144 [01:00<00:00,  2.39it/s, v_num=25]


# UNET

In [None]:
# TODO: load best model from checkpoint

In [None]:
model.eval()
with torch.no_grad():
    for images, masks in train_loader:
        print(images.shape)
        print(masks.shape)
        break
    output = model(images.to(device))
    output = (output > 0.5).float()