In [45]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from torchvision.transforms import Resize, InterpolationMode, ToPILImage
from torchmetrics import JaccardIndex, Precision, Recall, F1Score
import segmentation_models_pytorch as smp

from src.evaluation.evaluate_result import evaluate_result
from src.callbacks.SaveRandomImagesCallback import SaveRandomImagesCallback
from src.callbacks.SaveTestPreds import SaveTestPreds
from src.datasets.INRIAAerialImageLabellingDatasetPatches import (
    INRIAAerialImageLabellingDatasetPatches,
)
from src.utils import view_and_save_images_shapes
from src.datasets.utils.ResizeToDivisibleBy32 import ResizeToDivisibleBy32

## Prepare environment

In [46]:
torch.cuda.is_available()

True

In [47]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# if you get some cryptic CUDA error, set device to "cpu" and try again
print(device)

cuda


In [48]:
VAL_SIZE = 0.2
BATCH_SIZE = 16
SEED = 42
IMAGE_SIZE = 576
SAVE_VAL_DIR = "outputs/INRIA-patches/val"
SAVE_TEST_DIR = "outputs/INRIA-patches/test"
INRIA_PATCHES_DATASET_PATH = "data/INRIAAerialImageLabellingDatasetPatches"  # home PC

In [49]:
labeled_dataset = INRIAAerialImageLabellingDatasetPatches(
    INRIA_PATCHES_DATASET_PATH,
    split="train",
    transforms=[ResizeToDivisibleBy32()],
)
print(len(labeled_dataset))

18000


In [50]:
test_dataset = INRIAAerialImageLabellingDatasetPatches(
    INRIA_PATCHES_DATASET_PATH,
    split="test",
    transforms=[ResizeToDivisibleBy32()],
)
print(len(test_dataset))

14400


## Prepare data

In [51]:
sanity_check_dataset = INRIAAerialImageLabellingDatasetPatches(
    INRIA_PATCHES_DATASET_PATH,
    split="train",
    transforms=[ResizeToDivisibleBy32()],
)
print(len(sanity_check_dataset))

18000


In [52]:
sanity_check_dataloader = DataLoader(
    sanity_check_dataset, batch_size=BATCH_SIZE, shuffle=True
)

In [53]:
for images, masks in sanity_check_dataloader:
    print("images")
    print(images.shape)
    unique, counts = np.unique(images, return_counts=True)
    print(dict(zip(unique, counts)))
    print()
    print("masks")
    print(masks.shape)
    unique, counts = np.unique(masks, return_counts=True)
    print(dict(zip(unique, counts)))
    break

images
torch.Size([16, 3, 512, 512])
{0.0: 10306, 0.003921569: 16872, 0.007843138: 8924, 0.011764706: 3758, 0.015686275: 2456, 0.019607844: 2610, 0.023529412: 2772, 0.02745098: 2495, 0.03137255: 2383, 0.03529412: 2151, 0.039215688: 2190, 0.043137256: 2658, 0.047058824: 3352, 0.050980393: 3261, 0.05490196: 3454, 0.05882353: 3791, 0.0627451: 4469, 0.06666667: 5606, 0.07058824: 8516, 0.07450981: 13283, 0.078431375: 17363, 0.08235294: 18285, 0.08627451: 18593, 0.09019608: 18903, 0.09411765: 21447, 0.09803922: 23591, 0.101960786: 23067, 0.105882354: 22800, 0.10980392: 22157, 0.11372549: 25089, 0.11764706: 26111, 0.12156863: 27646, 0.1254902: 31428, 0.12941177: 34024, 0.13333334: 39943, 0.13725491: 44055, 0.14117648: 45839, 0.14509805: 47214, 0.14901961: 50342, 0.15294118: 52682, 0.15686275: 53640, 0.16078432: 55762, 0.16470589: 56878, 0.16862746: 58460, 0.17254902: 61373, 0.1764706: 63883, 0.18039216: 65593, 0.18431373: 68307, 0.1882353: 71636, 0.19215687: 72752, 0.19607843: 76906, 0.2: 804

In [54]:
# view_and_save_images_shapes(
#     sanity_check_dataloader, "inria_patches_images_shapes", verbose=True
# )

### Train, validation and test data

In [55]:
train_size = int(0.8 * len(labeled_dataset))
val_size = len(labeled_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    labeled_dataset, [train_size, val_size]
)
print(len(train_dataset), len(val_dataset), len(test_dataset))

14400 3600 14400


In [56]:
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    persistent_workers=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    persistent_workers=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    persistent_workers=True,
)

In [57]:
for images, masks in train_loader:
    print(images.shape)
    print(masks.shape)
    break

torch.Size([16, 3, 512, 512])
torch.Size([16, 1, 512, 512])


In [58]:
to_pil_transform = ToPILImage()

In [59]:
if BATCH_SIZE == 1:
    img = to_pil_transform(images.squeeze())

In [60]:
# img.show()

In [61]:
if BATCH_SIZE == 1:
    msk = to_pil_transform(masks.squeeze()).convert("L")

In [62]:
# msk.show()

# Training module

In [63]:
class SegmentationModel(pl.LightningModule):
    def __init__(self, model, learning_rate=1e-3):
        super(SegmentationModel, self).__init__()

        self.model = model
        self.learning_rate = learning_rate

        self.criterion = smp.losses.MCCLoss()
        self.train_iou = JaccardIndex(num_classes=2, task="binary")
        self.val_iou = JaccardIndex(num_classes=2, task="binary")

    def forward(self, x):
        output = self.model(x.to(device))
        return output

    def training_step(self, batch, batch_idx):
        images, masks = batch
        masks = torch.div(masks, 255).float()
        preds = self(images)
        loss = self.criterion(preds, masks)

        self.log("train_loss", loss, on_epoch=True, on_step=True)
        self.log("train_iou", self.train_iou(preds, masks), on_epoch=True, on_step=True)

        return loss

    def validation_step(self, batch, batch_idx):
        images, masks = batch
        masks = torch.div(masks, 255).float()
        preds = self(images)
        loss = self.criterion(preds, masks)

        self.log("val_loss", loss, on_epoch=True, on_step=True)
        self.log("val_iou", self.val_iou(preds, masks), on_epoch=True, on_step=True)

        return loss

    def test_step(self, batch, batch_idx):
        # just here to activate the test_epoch_end
        # callback SaveTestPreds starts on_test_epoch_end
        pass

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [79]:
model = smp.Unet(
    encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,  # model output channels (number of classes in your dataset)
    activation="sigmoid",
).to(device)

In [65]:
# model = smp.UnetPlusPlus(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,  # model output channels (number of classes in your dataset)
# ).to(device)

In [66]:
# model = smp.DeepLabV3(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,
#     activation="sigmoid"
# ).to(device)

In [67]:
# model = smp.DeepLabV3Plus(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,
# ).to(device)

In [80]:
# segmentation_model = SegmentationModel(model)
segmentation_model = SegmentationModel.load_from_checkpoint(
    r"lightning_logs\inria_patches_segmentation_model\version_1\checkpoints\epoch=29-step=27000.ckpt",
    model=model
    )

In [81]:
model_checkpoint_callback = ModelCheckpoint(
    monitor="val_loss", save_top_k=-1, mode="min"
)

In [82]:
logger = CSVLogger("lightning_logs", name="inria_patches_segmentation_model")

In [83]:
save_callback = SaveRandomImagesCallback(save_dir=SAVE_VAL_DIR)
save_test_preds_callback = SaveTestPreds(save_dir=SAVE_TEST_DIR)
early_stopping_callback = EarlyStopping(monitor="val_loss", mode="min", patience=5)

In [84]:
trainer = pl.Trainer(
    max_epochs=50,
    log_every_n_steps=10,
    callbacks=[
        model_checkpoint_callback,
        save_callback,
        save_test_preds_callback,
        early_stopping_callback
    ],
    logger=logger,
)

# DEBUG
# trainer = pl.Trainer(
#     max_epochs=1,
#     callbacks=[model_checkpoint_callback, save_callback],
#     logger=logger,
#     accelerator="cpu"
#     )

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [85]:
trainer.fit(
    segmentation_model, train_dataloaders=train_loader, val_dataloaders=val_loader
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



  | Name      | Type               | Params
-------------------------------------------------
0 | model     | Unet               | 14.3 M
1 | criterion | MCCLoss            | 0     
2 | train_iou | BinaryJaccardIndex | 0     
3 | val_iou   | BinaryJaccardIndex | 0     
-------------------------------------------------
14.3 M    Trainable params
0         Non-trainable params
14.3 M    Total params
57.313    Total estimated model params size (MB)


Epoch 9: 100%|██████████| 900/900 [07:57<00:00,  1.89it/s, v_num=3]        


In [86]:
trainer.test(ckpt_path="best", dataloaders=test_loader)

Restoring states from the checkpoint path at lightning_logs\inria_patches_segmentation_model\version_3\checkpoints\epoch=4-step=4500.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs\inria_patches_segmentation_model\version_3\checkpoints\epoch=4-step=4500.ckpt


Testing DataLoader 0: 100%|██████████| 900/900 [00:58<00:00, 15.47it/s]


[{}]