[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rslab-ntua/MSc_GBDA/2022/Lab3.ipynb)

In [None]:
!wget -P data/oxford-iiit-pet "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz"
!wget -P data/oxford-iiit-pet "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz"

!tar -xf data/oxford-iiit-pet/annotations.tar.gz --directory data/oxford-iiit-pet
!tar -xf data/oxford-iiit-pet/images.tar.gz --directory data/oxford-iiit-pet

In [None]:
!pip install albumentations
!pip install pytorch-lightning

%load_ext tensorboard

In [None]:
from typing import Optional, List
import torch
import matplotlib.pyplot as plt

# Utility function for image grid plots
def display_image_grid(images: List[torch.Tensor], masks: List[torch.Tensor], predicted_masks: Optional[List[torch.Tensor]] = None, figsize: Optional[tuple] = None):
    """Plot each image next to its ground-truth mask and, optionally, a predicted mask.

    One row is drawn per image: column 0 shows the image, column 1 the
    ground-truth mask and, when ``predicted_masks`` is given, column 2 the
    predicted mask. Masks are drawn with nearest-neighbour interpolation.

    Args:
        images: Channel-first image tensors (C, H, W); they are moved to CPU and
            transposed to (H, W, C) for ``imshow``. A batched tensor also works,
            since it is indexed one row at a time.
        masks: Ground-truth masks aligned with ``images``.
        predicted_masks: Optional predicted masks aligned with ``images``.
        figsize: Optional matplotlib figure size; by default the width is 10
            and the height grows with the number of rows.
    """
    cols = 3 if predicted_masks is not None else 2
    rows = len(images)
    if rows == 0:
        # plt.subplots cannot build a figure with zero rows; nothing to show.
        return
    if figsize is None:
        figsize = (10, 5 * rows)
    # squeeze=False keeps ``ax`` a 2-D array even for a single row, so the
    # ax[i, j] indexing below also works when only one image is passed.
    _fig, ax = plt.subplots(nrows=rows, ncols=cols, figsize=figsize, squeeze=False)
    for i in range(rows):
        ax[i, 0].imshow(images[i].cpu().numpy().transpose(1, 2, 0))
        ax[i, 1].imshow(masks[i].cpu().numpy(), interpolation="nearest")

        ax[i, 0].set_title("Image")
        ax[i, 1].set_title("Ground truth mask")

        ax[i, 0].set_axis_off()
        ax[i, 1].set_axis_off()

        if predicted_masks is not None:
            predicted_mask = predicted_masks[i]
            ax[i, 2].imshow(predicted_mask.cpu().numpy(), interpolation="nearest")
            ax[i, 2].set_title("Predicted mask")
            ax[i, 2].set_axis_off()
    plt.tight_layout()
    plt.show()

In [None]:
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.io import read_image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import glob
import os
import copy
import numpy as np
import cv2

# Dataset root: the download cell at the top of the notebook extracts the
# Oxford-IIIT Pet archives into this directory.
DATA_ROOT = "data/oxford-iiit-pet"

class OxfordPets(Dataset):
    """Oxford-IIIT Pet images paired with binary masks derived from the trimaps.

    Samples are discovered as ``<data_root>/images/*.jpg`` (sorted by path) and
    paired with ``<data_root>/annotations/trimaps/<name>.png``. Pairs whose
    image OpenCV cannot decode, or whose trimap file is missing, are skipped.

    Args:
        data_root: Dataset root directory.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            and returning a dict with ``"image"`` and ``"mask"`` keys
            (e.g. an albumentations ``Compose``).
        indices: Optional sequence of positions into the list of *valid*
            samples, i.e. into ``range(len(OxfordPets(data_root)))``. When
            given, only those samples are exposed, in the given order.
    """

    # Valid {"im_file", "mask_file"} lists keyed by absolute data_root.
    # Validating requires decoding every image, so the scan is done once and
    # shared by every instance (full set, train split, val split). Re-run the
    # class cell to reset it if the files on disk change.
    _valid_db_cache = {}

    def __init__(self, data_root, transform=None, indices=None):
        super().__init__()
        self.indices = indices
        self._build_db(data_root)

        self.transform = transform

    @staticmethod
    def _is_valid_sample(im_file, mask_file):
        """Return True if the trimap file exists and OpenCV can decode the image."""
        return os.path.isfile(mask_file) and cv2.imread(im_file) is not None

    def _scan_valid_samples(self, data_root):
        """List sample dicts for every usable image under ``data_root``, sorted by path."""
        samples = []
        for im_file in sorted(glob.glob(os.path.join(data_root, "images", "*.jpg"))):
            im_name = os.path.splitext(os.path.basename(im_file))[0]
            mask_file = os.path.join(data_root, "annotations", "trimaps", im_name + ".png")
            if self._is_valid_sample(im_file, mask_file):
                samples.append({"im_file": im_file, "mask_file": mask_file})
        return samples

    def _build_db(self, data_root):
        """Populate ``self.db`` with the valid samples, optionally selected by ``self.indices``.

        ``self.indices`` address the *filtered* list of valid samples, which is
        the same list whose length ``len(OxfordPets(data_root))`` reports. Indices
        drawn from ``range(len(full_dataset))`` therefore map onto the intended
        files; they no longer refer to positions in the unfiltered glob result.
        """
        key = os.path.abspath(data_root)
        if key not in self._valid_db_cache:
            self._valid_db_cache[key] = self._scan_valid_samples(data_root)
        valid_samples = self._valid_db_cache[key]

        if self.indices is None:
            selected = valid_samples
        else:
            selected = [valid_samples[idx] for idx in self.indices]
        # Copy each dict so instances never share mutable sample entries.
        self.db = [dict(sample) for sample in selected]

    def preprocess_mask(self, mask):
        """Binarise a trimap array into float32 values.

        Pixels equal to 2 become 0.0; pixels equal to 1 or 3 become 1.0. Any
        other value is left unchanged. NOTE(review): in the Oxford-IIIT Pet
        trimaps, 2 is presumably background, so pet and border pixels both
        count as foreground; confirm against the dataset README.
        """
        mask = mask.astype(np.float32)
        mask[mask == 2.0] = 0.0
        mask[(mask == 1.0) | (mask == 3.0)] = 1.0
        return mask

    def _load_data(self, sample):
        """Return a shallow copy of ``sample`` with the RGB image ("im") and binary mask ("mask") added."""
        s = copy.copy(sample)
        s.update({
            # OpenCV decodes as BGR; convert to RGB for plotting and the model.
            "im": cv2.cvtColor(cv2.imread(s["im_file"]), cv2.COLOR_BGR2RGB),
            "mask": self.preprocess_mask(cv2.imread(s["mask_file"], cv2.IMREAD_UNCHANGED))
        })

        return s

    def __getitem__(self, index):
        """Load sample ``index`` as a dict with keys "im_file", "mask_file", "im", "mask".

        When a transform is set, it is applied jointly to image and mask, so
        geometric augmentations stay aligned.
        """
        sample = self._load_data(self.db[index])

        if self.transform is not None:
            transformed = self.transform(image=sample["im"], mask=sample["mask"])
            sample["im"] = transformed["image"]
            sample["mask"] = transformed["mask"]

        return sample

    def __len__(self):
        return len(self.db)

# Preview transform: tensor conversion only (no resizing or normalisation), so
# the first samples are shown at their native resolution and colours.
transform = A.Compose(
    [
        ToTensorV2(),
    ]
)

dset = OxfordPets(DATA_ROOT, transform=transform)

# Fetch each sample once: image and mask then come from the same __getitem__
# call (consistent even under random transforms), and each file is decoded once.
preview_samples = [dset[i] for i in range(3)]
ims = [sample["im"] for sample in preview_samples]
masks = [sample["mask"] for sample in preview_samples]

display_image_grid(ims, masks)

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split

# Split positions of the valid samples 70/30 into train/val. A fixed
# random_state makes the split identical on every run, so metrics and
# checkpoints from different sessions are comparable.
SPLIT_SEED = 42
dset = OxfordPets(DATA_ROOT, transform=transform)
all_indices = np.arange(len(dset))

train_indices, val_indices = train_test_split(all_indices, test_size=0.3, random_state=SPLIT_SEED)


# Training: resize, random affine/colour augmentation, then normalisation with
# the standard ImageNet mean/std and tensor conversion.
train_transform = A.Compose(
    [
        A.Resize(256, 256),
        A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.5),
        A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)
train_dset = OxfordPets(DATA_ROOT, transform=train_transform, indices=train_indices)

# Validation: deterministic resize + the same normalisation, no augmentation.
val_transform = A.Compose(
    [A.Resize(256, 256),
     A.Normalize(mean=(0.485, 0.456, 0.406),
                 std=(0.229, 0.224, 0.225)),
     ToTensorV2()]
)
val_dset = OxfordPets(DATA_ROOT, transform=val_transform, indices=val_indices)


# Preview two validation samples. Each sample is fetched once, so image and
# mask come from the same __getitem__ call.
preview_samples = [val_dset[i] for i in range(2)]
ims = [sample["im"] for sample in preview_samples]
masks = [sample["mask"] for sample in preview_samples]

print(len(train_dset), len(val_dset))
display_image_grid(ims, masks)

# Configure DataLoaders
train_dloader = DataLoader(train_dset, batch_size=4, shuffle=True, num_workers=2)
val_dloader = DataLoader(val_dset, batch_size=4, shuffle=False, num_workers=2)

In [None]:
from torchvision.models.vgg import vgg11


# Instantiate ImageNet-pretrained VGG11 without a download progress bar and print
# its module tree; the printed indices of `features` show where the FCN
# encoder below is truncated.
vgg = vgg11(progress=False, pretrained=True)
print(vgg)

In [None]:
import pytorch_lightning as pl
from torchsummary import summary
from torch import nn
from torchvision.models.vgg import vgg11
from torch.nn import functional as F
from torchmetrics import Accuracy


class FCN(pl.LightningModule):
    """Binary segmentation FCN: truncated pretrained VGG11 encoder + upsampling decoder.

    The encoder is the first 15 modules of ImageNet-pretrained
    ``vgg11().features``. Its 512-channel output is upsampled three times by a
    factor of 2 (x8 overall) by a small convolutional decoder that ends in a
    single logit channel, which is then flattened away. ``forward`` therefore
    returns per-pixel logits of shape (B, H', W'). NOTE(review): H', W'
    presumably equal the input size when H, W are divisible by 8, e.g. the
    256x256 inputs used here; confirm for other sizes.

    Args:
        lr: Learning rate for Adam.
    """

    def __init__(self, lr=1e-3):
        super().__init__()

        vgg = vgg11(pretrained=True)
        # Keep VGG11 feature modules 0..14; the decoder's first conv expects
        # the 512 channels they produce.
        self.encoder = nn.Sequential(*list(vgg.features.children())[:15])

        self.decoder = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
            # (B, 1, H, W) -> (B, H, W): merge the batch dim with the singleton
            # channel dim so logits line up with per-pixel masks.
            nn.Flatten(0, 1)
        )

        self.lr = lr
        self.pixel_accuracy = self._make_pixel_accuracy()
        self.save_hyperparameters()

    @staticmethod
    def _make_pixel_accuracy():
        """Build a binary accuracy metric that works across torchmetrics versions."""
        try:
            # Recent torchmetrics requires an explicit task; a bare Accuracy()
            # raises a TypeError there.
            return Accuracy(task="binary")
        except (TypeError, ValueError):
            # Older torchmetrics does not accept `task`; keep the original
            # no-argument construction.
            return Accuracy()

    def forward(self, x):
        x = self.encoder(x)
        return self.decoder(x)

    def _shared_step(self, batch):
        """Run the model on a batch; return (logits, target mask, BCE-with-logits loss).

        Assumes batch["mask"] is a float tensor with the same shape as the
        logits (presumably (B, H, W) with values in {0, 1}); TODO confirm.
        """
        im = batch["im"]
        mask = batch["mask"]

        preds = self(im)

        loss = F.binary_cross_entropy_with_logits(input=preds, target=mask)
        return preds, mask, loss

    def training_step(self, batch, batch_idx):
        _, _, loss = self._shared_step(batch)

        self.log("loss/train", loss, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        preds, mask, loss = self._shared_step(batch)

        self.log("loss/val", loss, on_step=False, on_epoch=True)

        # The metric receives raw logits plus integer targets.
        self.pixel_accuracy(preds, mask.type(torch.long))
        self.log("px_acc/val", self.pixel_accuracy, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
    
# summary() prints the per-layer table itself; wrapping it in print() only echoed
# its return value. The trailing ";" keeps IPython from displaying that value.
summary(FCN(), input_size=(3, 256, 256), device="cpu");

In [None]:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Seed python/numpy/torch (and dataloader workers) so decoder initialisation,
# shuffling and augmentations are reproducible across runs.
pl.seed_everything(42, workers=True)

callbacks = [
    # Stop when validation pixel accuracy has not improved for 3 validation checks.
    EarlyStopping(monitor="px_acc/val", mode="max", patience=3),
    # Keep the best checkpoint by validation pixel accuracy, plus the last one.
    ModelCheckpoint(monitor="px_acc/val", mode="max", save_last=True),
]

model = FCN()
trainer = pl.Trainer(
    # "auto" picks a GPU when one is available and otherwise falls back to CPU,
    # instead of failing outright on machines without CUDA.
    accelerator="auto",
    devices=1,
    max_epochs=20,
    callbacks=callbacks,
    default_root_dir="fcn",
)

trainer.fit(model, train_dataloaders=train_dloader, val_dataloaders=val_dloader)


In [None]:
# Inspect the loss/accuracy curves logged by the Trainer (default_root_dir="fcn").
%tensorboard --logdir fcn/lightning_logs

In [None]:
# Inference with the "best" checkpoint, i.e. the one with the highest validation
# pixel accuracy according to the ModelCheckpoint callback.
best_model = FCN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
best_model.eval()

batch = next(iter(val_dloader))

# The validation images were normalised with these mean/std values; undo that
# and clamp to [0, 1] so imshow shows natural colours instead of clipping.
NORM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
NORM_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

with torch.no_grad():
    # Feed the batch to whatever device the checkpoint was restored on;
    # positive logits are predicted foreground.
    pred_masks = best_model(batch["im"].to(best_model.device)) > 0
    shown_images = (batch["im"] * NORM_STD + NORM_MEAN).clamp(0, 1)
    display_image_grid(shown_images, batch["mask"], pred_masks)