[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rslab-ntua/MSc_GBDA/blob/master/2022/Lab5.ipynb)

In [None]:
# Download the dataset archive from Google Drive (--fuzzy lets gdown extract the file id from a share/"view" link)
!gdown --fuzzy https://drive.google.com/file/d/1pgTcsgGwogtc4EPy1I2mOtMiB-Hl13Iy/view?usp=sharing
# Extract the archive into the current working directory (presumably creates ./dataset, i.e. DATA_ROOT — verify)
!tar -xf dataset.tar.gz --directory ./

In [None]:
# Install PyTorch Lightning, which provides the LightningModule / Trainer used for training below
!pip install pytorch_lightning

# Data preparation and feeding pipeline

In [None]:
from torch.utils.data import Dataset, DataLoader, default_collate
from torchvision.io import read_image, ImageReadMode
from torchvision.transforms.functional import convert_image_dtype
from torchvision.utils import draw_bounding_boxes
from matplotlib import pyplot as plt
import torch
import numpy as np
import pandas as pd
import os

# Root directory of the extracted dataset; every ODDataset below is built from it
DATA_ROOT = "./dataset"

# Define a class to handle data loading from disk
class ODDataset(Dataset):
    '''
    Object-detection dataset reading RGB images and KITTI-format label files.

    On-disk layout (as built by the paths below):
        <data_root>/<mode>_test.txt            one image file name per line for the split
        <data_root>/<mode>/images/<name>       image files
        <data_root>/<mode>/labels/<stem>.txt   KITTI label file matching each image

    Each item is a tuple (image, target):
        -- image : float32 tensor (3, H, W) with values in [0, 1]
        -- target: dict with
            -- "boxes" : float32 tensor (N, 4), xyxy format (left, top, right, bottom)
            -- "labels": int64 tensor (N,), 0-based ids from `self.categories`
    An empty label file yields N = 0 (boxes of shape (0, 4)).
    '''

    # Valid dataset splits
    MODES = ("train", "val", "test")

    def __init__(self, data_root, mode="train"):
        '''
        data_root: path to the root directory of the dataset
        mode: ["train"(default)/"val"/"test"] available modes/splits of the dataset

        Raises ValueError for an unknown mode.
        '''
        super().__init__()

        # Explicit check instead of `assert`, which is stripped under `python -O`
        if mode not in self.MODES:
            raise ValueError(f"mode must be one of {self.MODES}, got {mode!r}")

        self.mode = mode
        self.root = data_root

        # Define dataset nomenclature (category name -> 0-based category id)
        self.categories = {
            "battery": 0,
            "dice": 1,
            "toycar": 2,
            "candle": 3,
            "highlighter": 4,
            "spoon": 5
        }

        # Build dataset
        self._build_db()

    @property
    def reverse_categories(self) -> dict:
        '''
        Returns a category id to category name mapping as a dict
        '''
        return {v: k for k, v in self.categories.items()}

    def _build_db(self) -> None:
        '''
        Collect a database of all available samples: the image file names listed
        (one per line, no header) in `<root>/<mode>_test.txt`.
        '''
        # NOTE(review): every split, including "train", reads a "<mode>_test.txt" list file;
        # confirm this matches the file names actually shipped in the dataset archive.
        list_file = os.path.join(self.root, f"{self.mode}_test.txt")
        self.db = pd.read_csv(list_file, header=None).values.flatten().tolist()

    def _parse_label_file(self, label_file: str) -> dict:
        '''
        Parse one KITTI label file (one object per line, whitespace separated).

        Column 0 is the category name and columns 4:8 are the 2D box
        (left, top, right, bottom) in pixels.

        Returns {"boxes": float32 (N, 4), "labels": int64 (N,)}. An empty file
        gives N = 0, the form torchvision detectors use for images without objects.
        Raises KeyError for a category name that is not in `self.categories`.
        '''
        try:
            rows = pd.read_csv(label_file, sep=r"\s+", header=None).values
        except pd.errors.EmptyDataError:
            # Image without any annotated object
            rows = np.empty((0, 8), dtype=object)

        # Keep sub-pixel precision: detection models expect floating point boxes
        boxes = torch.as_tensor(rows[:, 4:8].astype(np.float32))

        try:
            category_ids = [self.categories[name] for name in rows[:, 0]]
        except KeyError as err:
            raise KeyError(f"Unknown category {err.args[0]!r} in {label_file}") from err
        labels = torch.as_tensor(category_ids, dtype=torch.int64)

        return {"boxes": boxes, "labels": labels}

    def _parse_sample(self, sample: str) -> tuple:
        '''
        Read image and object detection label data for one image file name.

        Raises FileNotFoundError if the image or its label file is missing.
        '''
        im_file = os.path.join(self.root, self.mode, "images", sample)
        if not os.path.exists(im_file):
            raise FileNotFoundError(f"Image file not found: {im_file}")

        # The label file shares the image's stem; splitext only swaps the real extension
        stem = os.path.splitext(sample)[0]
        label_file = os.path.join(self.root, self.mode, "labels", stem + ".txt")
        if not os.path.exists(label_file):
            raise FileNotFoundError(f"Label file not found: {label_file}")

        # read image from disk and convert to [0,1] float32
        im = convert_image_dtype(read_image(im_file, mode=ImageReadMode.RGB), dtype=torch.float32)

        return im, self._parse_label_file(label_file)

    def __getitem__(self, index) -> tuple:
        '''
        Retrieve a specific sample from the dataset
        '''
        sample = self.db[index]
        return self._parse_sample(sample)

    def __len__(self):
        '''
        Returns the total number of samples in the dataset
        '''
        return len(self.db)

# Build a training-split dataset to sanity-check image loading and label parsing
dset = ODDataset(DATA_ROOT, mode="train")

# Fetch sample #0, an (image, target) pair
sample = dset[0]
sample_image, sample_target = sample

# Translate numeric category ids back to names for the box captions
id_to_name = dset.reverse_categories
caption_list = [id_to_name[int(cat_id)] for cat_id in sample_target["labels"]]

# Convert the [0, 1] float image to uint8 for drawing, then overlay the boxes
drawn_image = draw_bounding_boxes(
    image=convert_image_dtype(sample_image, torch.uint8),
    boxes=sample_target["boxes"],
    labels=caption_list,
)
# Matplotlib expects channels-last (H, W, C) arrays
plt.imshow(drawn_image.permute(1, 2, 0).numpy())

In [None]:
from typing import List, Tuple



# One dataset instance per split, created in train -> val -> test order
train_dset, val_dset, test_dset = (ODDataset(DATA_ROOT, mode=split) for split in ("train", "val", "test"))

# Define an appropriate DataLoader: images may have different sizes and each one
# carries its own number of boxes, so a batch is kept as parallel Python lists
# rather than being stacked into tensors.
def custom_collate(samples: List[Tuple[torch.Tensor, dict]]) -> Tuple[List[torch.Tensor], List[dict]]:
    '''
    Collate (image, target) pairs into a pair of parallel lists.

    samples: list of tuples; each holds an RGB image (torch.Tensor) and a dict with
                    -- "boxes" : Nx4 boxes found in the image
                    -- "labels": N, categories of each object found

    Returns (tuple):
        --  list of the images, in input order (variable sizes allowed)
        --  list of the target dicts, in input order
    '''
    batch_images: List[torch.Tensor] = []
    batch_targets: List[dict] = []
    for sample in samples:
        batch_images.append(sample[0])
        batch_targets.append(sample[1])
    return batch_images, batch_targets

# Define dataloaders for train/val/test splits; they share every setting except shuffling
dloader_kwargs = dict(batch_size=8, collate_fn=custom_collate, num_workers=8, prefetch_factor=2)
train_dloader = DataLoader(train_dset, shuffle=True, **dloader_kwargs)
val_dloader = DataLoader(val_dset, shuffle=False, **dloader_kwargs)
test_dloader = DataLoader(test_dset, shuffle=False, **dloader_kwargs)

# Define an object detection (OD) model

In [None]:
import pytorch_lightning as pl
from typing import Optional
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn

# Define a Faster RCNN model
class FasterRCNNModel(pl.LightningModule):
    '''
    LightningModule wrapping torchvision's Faster R-CNN (ResNet-50 FPN), trained with
    the 4-step alternating scheme selected through `training_phase` (0-3).

    Label convention: torchvision detection models reserve label 0 for the background
    class. Their `num_classes` includes it, and class-0 scores are dropped from the
    predictions. The dataset ids fed to this module start at 0, so the wrapper
    - builds the detector with `num_classes + 1` outputs,
    - shifts target labels up by one before computing the losses, and
    - shifts predicted labels back down in `forward`.
    Callers therefore always work with the dataset's 0-based ids.
    '''

    # Offset between dataset category ids and detector labels (detector label 0 = background)
    _LABEL_OFFSET = 1

    def __init__(self, num_classes):
        '''
        num_classes: Number of target (object) classes in the dataset, excluding background
        '''
        super().__init__()

        # Use a pretrained backbone with "new" ROI head (one extra output for background).
        # NOTE(review): `pretrained`/`pretrained_backbone` are torchvision's legacy arguments,
        # superseded by `weights`/`weights_backbone` in torchvision >= 0.13.
        self.model = fasterrcnn_resnet50_fpn(
            pretrained=False,
            progress=False,
            num_classes=num_classes + self._LABEL_OFFSET,
            pretrained_backbone=True,
            # Number of trainable ResNet stages counted from the last one (valid range 0-5).
            # 1 keeps only the last stage trainable; raise it to fine-tune more of the backbone.
            trainable_backbone_layers=1)
        # Current step of the alternating training scheme (0-3); must be set before fitting
        self.training_phase: Optional[int] = None

        self.save_hyperparameters()

    def forward(self, x: List[torch.Tensor]) -> List[dict]:
        '''
        x: List of images (any size)

        Returns:
        - List of dicts with keys (one dict per input image):
            -- "boxes" : Nx4 boxes found in each image
            -- "labels" : N, 0-based dataset categories of each object found
            -- "scores" : N, prediction scores for each object found
        '''
        assert not self.training, "Use forward only for inference!"

        detections = self.model(x)
        # Map detector labels (1..num_classes) back to dataset ids (0..num_classes-1)
        return [{**det, "labels": det["labels"] - self._LABEL_OFFSET} for det in detections]

    def training_step(self, batch, batch_idx):
        '''
        Training logic: compute the Faster R-CNN losses and return the subset that
        belongs to the current training phase (RPN losses for even phases, ROI-head
        losses for odd phases).
        '''
        images, targets = batch

        # Shift dataset ids so that no real object collides with the background label 0.
        # New dicts are built so the batch passed in is not mutated.
        detector_targets = [{**t, "labels": t["labels"] + self._LABEL_OFFSET} for t in targets]

        # When in train mode FRCNN returns a dict of losses
        #    -- loss_objectness (RPN)
        #    -- loss_rpn_box_reg (RPN)
        #    -- loss_classifier (ROI Heads)
        #    -- loss_box_reg (ROI Heads)
        losses = self.model(images, detector_targets)

        # Reduce by sum and return the appropriate composite loss function according to the current training phase
        assert self.training_phase is not None
        if self.training_phase % 2 == 0:
            # Phase 0 or 2
            phase_losses = {
                "loss/objectness": losses["loss_objectness"],
                "loss/rpn_box_reg": losses["loss_rpn_box_reg"],
            }
        else:
            # Phase 1 or 3
            phase_losses = {
                "loss/classifier": losses["loss_classifier"],
                "loss/box_reg": losses["loss_box_reg"],
            }

        for name, value in phase_losses.items():
            self.log(name, value, batch_size=len(images), on_epoch=True, on_step=False)
        return sum(phase_losses.values())

    def configure_optimizers(self):
        '''
        Define optimizer and trainable parameters according to current training phase:
            -- 0: backbone + RPN
            -- 1: backbone + ROI heads
            -- 2: RPN only
            -- 3: ROI heads only
        All phases use Adam with lr = 1e-4.

        Raises AssertionError if the phase is unset or outside 0-3.
        '''
        print(f"Configuring optimizers for training phase : {self.training_phase} / 3")
        assert self.training_phase is not None, "Set training phase first!"
        lr = 1e-4
        if self.training_phase == 0:
            # Train only RPN + backbone
            return torch.optim.Adam(
                params=[
                    {"params": self.model.backbone.parameters()},
                    {"params": self.model.rpn.parameters()}
                ],
                lr=lr
            )
        elif self.training_phase == 1:
            # Train only RoiHeads + backbone
            return torch.optim.Adam(
                params=[
                    {"params": self.model.backbone.parameters()},
                    {"params": self.model.roi_heads.parameters()}
                ],
                lr=lr
            )
        elif self.training_phase == 2:
            # Train only RPN
            return torch.optim.Adam(
                params=self.model.rpn.parameters(),  # type: ignore
                lr=lr
            )
        elif self.training_phase == 3:
            # Train only RoiHeads
            return torch.optim.Adam(
                params=self.model.roi_heads.parameters(),  # type: ignore
                lr=lr
            )
        raise AssertionError("Invalid training phase")

# Fine-tune on the new data!

In [None]:
%tensorboard --logdir=.
# Reload manually to update results (at first hit reload after 1st epoch is done)

In [None]:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from copy import deepcopy

# Initialize a Faster RCNN model (6 object categories in the dataset)
model = FasterRCNNModel(num_classes=6)


def _fit_phase(net, phase, root_dir, val_loader=None, **trainer_kwargs):
    '''
    Run one step of the alternating training scheme and return its Trainer.

    net: the FasterRCNNModel to train; its `training_phase` is set to `phase`
    phase: training phase id (0-3), selecting the losses and trainable parameters
    root_dir: `default_root_dir` of the Trainer (logs / checkpoints location)
    val_loader: optional validation DataLoader forwarded to `Trainer.fit`
    trainer_kwargs: remaining `pl.Trainer` arguments (epoch limits, validation frequency)
    '''
    net.training_phase = phase
    phase_trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        default_root_dir=root_dir,
        **trainer_kwargs
    )
    phase_trainer.fit(net, train_dataloaders=train_dloader, val_dataloaders=val_loader)
    return phase_trainer


# Train in 4 phases as described in the original paper

# Snapshot the initial backbone weights: phase 1 restarts from them
p0_backbone_state_dict = deepcopy(model.model.backbone.state_dict())

# Phase 0: RPN + backbone
trainer = _fit_phase(model, 0, "frcnn_p0", max_epochs=10)

# Phase 1: restore the initial backbone, then train ROI heads + backbone.
# NOTE(review): FasterRCNNModel defines no validation_step, so Lightning should skip the
# validation loop here (with a warning) and check_val_every_n_epoch has no effect.
model.model.backbone.load_state_dict(p0_backbone_state_dict)
trainer = _fit_phase(model, 1, "frcnn_p1", val_loader=val_dloader,
                     max_epochs=10, check_val_every_n_epoch=5)

# Phase 2: RPN only
trainer = _fit_phase(model, 2, "frcnn_p2", max_epochs=10)

# Phase 3: ROI heads only
trainer = _fit_phase(model, 3, "frcnn_p3", val_loader=val_dloader,
                     min_epochs=10, max_epochs=100, check_val_every_n_epoch=2)