# TorchVision Object Detection Finetuning for IMPTOX Particles

## PyTorch Lightning Adaptation

### Imports


In [6]:
import os
import torch
import torch.nn as nn
import numpy as np
import torchmetrics
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

from PIL import Image
import matplotlib.pyplot as plt


from pycocotools.coco import COCO
import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule
from pytorch_lightning.loggers import TensorBoardLogger


import torchvision
from torchvision.transforms import v2 as T
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.ops.boxes import masks_to_boxes
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.functional import to_tensor 
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator



#os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/engine.py")
#os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/utils.py")
#os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_utils.py")
#os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_eval.py")
#os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/transforms.py")


import utils


# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
# Alternatives that were tried: DEVICE = [0,1] and DEVICE = torch.device('cpu')
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hand unused cached GPU memory back to the driver before starting.
torch.cuda.empty_cache()


# Number of worker processes used by every DataLoader in this notebook.
NUM_WORKERS = 4


### Dataset

In [7]:
class CustomDataset(Dataset):
    """Instance-segmentation dataset read from a flat directory of JPEGs and PNG masks.

    Every ``*.jpg`` file is one sample. Its annotation is the element-wise
    maximum of all ``*_mask.png`` files in the same directory whose name
    starts with the image's stem. Each connected foreground region of that
    combined mask becomes one object of the single foreground class
    (label ``1``).

    Samples are returned on the CPU: creating accelerator tensors inside a
    dataset conflicts with multi-process loading and ``pin_memory=True``, so
    device placement is left to the training loop.
    """

    def __init__(self, data_dir, transform=None, resize=(256, 256)):
        """
        Args:
            data_dir (string): Directory with all the images and masks.
            transform (callable, optional): Joint transform called as
                ``transform(img, target)`` and returning ``(img, target)``.
            resize (tuple): ``(width, height)`` that both the image and its
                masks are resized to before the target is built.
        """
        self.data_dir = data_dir
        self.transform = transform
        self.resize = resize
        # List the directory once (not once per sample) and sort it so that
        # an index always refers to the same file, whatever order the
        # filesystem reports.
        entries = sorted(os.listdir(data_dir))
        self.image_files = [f for f in entries if f.endswith('.jpg')]
        self.mask_files = [f for f in entries if f.endswith('_mask.png')]

    def __len__(self):
        return len(self.image_files)

    def _load_combined_mask(self, image_file):
        """Return the element-wise maximum of all masks belonging to *image_file*.

        Raises:
            ValueError: If no ``<stem>*_mask.png`` file exists for the image.
        """
        stem = os.path.splitext(image_file)[0]
        masks = []
        for mask_file in self.mask_files:
            if not mask_file.startswith(stem):
                continue
            with Image.open(os.path.join(self.data_dir, mask_file)) as raw:
                # NEAREST keeps the mask values discrete; the size follows
                # the image so both always share one canvas.
                mask = raw.convert('L').resize(self.resize, Image.NEAREST)
            masks.append(np.array(mask))

        if not masks:
            raise ValueError(
                f"No '*_mask.png' file found for image '{image_file}' in '{self.data_dir}'"
            )
        return np.maximum.reduce(masks)

    @staticmethod
    def _split_instances(combined_mask):
        """Split a 2-D mask into one binary mask per connected foreground region.

        Returns:
            A ``uint8`` tensor of shape ``(num_objects, H, W)``; ``num_objects``
            is 0 when the mask has no foreground.
        """
        from scipy.ndimage import label

        labeled_array, num_features = label(combined_mask)
        # Keep the labels in a wide integer type: casting them to uint8 would
        # wrap around after 255 regions and merge unrelated objects.
        labeled = torch.as_tensor(labeled_array, dtype=torch.int64)
        # label() numbers the regions 1..num_features; 0 is the background.
        obj_ids = torch.arange(1, num_features + 1, dtype=torch.int64)
        return (labeled == obj_ids[:, None, None]).to(dtype=torch.uint8)

    def __getitem__(self, idx):
        """Return ``(img, target)`` for sample *idx*.

        ``img`` is a float ``tv_tensors.Image``. ``target`` holds ``boxes``
        (XYXY), ``masks``, ``labels``, ``image_id``, ``area`` and ``iscrowd``;
        all per-object entries have the same length.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image_file = self.image_files[idx]
        with Image.open(os.path.join(self.data_dir, image_file)) as raw:
            image = raw.convert('RGB').resize(self.resize, Image.BILINEAR)
        img = tv_tensors.Image(to_tensor(image))

        masks = self._split_instances(self._load_combined_mask(image_file))

        # Bounding box of every instance mask.
        boxes = masks_to_boxes(masks)

        # Boxes with zero width or height are invalid for the detection
        # models. Drop them together with their masks so that boxes, masks,
        # labels and iscrowd stay aligned.
        keep = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])
        boxes = boxes[keep]
        masks = masks[keep]
        num_objs = boxes.shape[0]

        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        # Wrap sample and targets into torchvision tv_tensors. The canvas
        # size is (height, width), i.e. the last two dimensions of the image.
        target = {
            "boxes": tv_tensors.BoundingBoxes(
                boxes, format="XYXY", canvas_size=tuple(img.shape[-2:])
            ),
            "masks": tv_tensors.Mask(masks),
            "labels": labels,
            "image_id": idx,
            "area": area,
            "iscrowd": iscrowd,
        }

        if self.transform is not None:
            # Transform the wrapped image and target together and return the
            # transformed pair.
            # NOTE(review): "area" is computed before the transform and is
            # not updated by it.
            img, target = self.transform(img, target)

        return img, target

### DataModule

In [8]:
class CustomDataModule(pl.LightningDataModule):
    """Lightning data module that serves train/val/test splits of ``CustomDataset``.

    Args:
        data_dir: Directory passed on to ``CustomDataset``.
        batch_size: Number of samples per batch for all three loaders.
        train_val_test_split: Fractions of the dataset used for training and
            validation (first two entries); the test split receives whatever
            remains.
        seed: Seed of the random generator used for the split, so the same
            images end up in the same subset on every run.
    """

    def __init__(self, data_dir, batch_size=32, train_val_test_split=(0.7, 0.15, 0.15), seed=42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_val_test_split = train_val_test_split
        self.seed = seed
        # Filled in by the first call to setup().
        self.dataset = None

    @staticmethod
    def _collate_fn(batch):
        """Turn a list of ``(img, target)`` pairs into ``(images, targets)`` tuples.

        Targets hold a different number of objects per image, so they cannot
        be stacked by the default collation.
        """
        return tuple(zip(*batch))

    def setup(self, stage=None):
        """Create the dataset and split it once; later calls are no-ops.

        Repeated calls must not re-draw the random split, otherwise images
        could move between the training and test subsets from one call to
        the next.
        """
        if self.dataset is not None:
            return

        # Define transforms
        transform = T.Compose([
            T.Resize((256, 256)),  # Resize images to 256x256
            T.ToTensor(),
            ])

        # Load dataset
        self.dataset = CustomDataset(self.data_dir, transform=transform)

        # Split dataset into train, val, and test; the test split takes the
        # remainder so the three sizes always add up to the dataset length.
        total = len(self.dataset)
        train_size = int(self.train_val_test_split[0] * total)
        val_size = int(self.train_val_test_split[1] * total)
        test_size = total - train_size - val_size
        generator = torch.Generator().manual_seed(self.seed)
        self.train_data, self.val_data, self.test_data = random_split(
            self.dataset, [train_size, val_size, test_size], generator=generator
        )

    def _make_loader(self, subset, shuffle=False):
        """Build a DataLoader with the settings shared by all three splits."""
        return DataLoader(
            subset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            collate_fn=self._collate_fn,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_data, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_data)

    def test_dataloader(self):
        return self._make_loader(self.test_data)



# ----------------------------------------------------------
# Build the CustomDataModule and prepare its three subsets
# ----------------------------------------------------------

# 60 % train, 20 % validation, 20 % test
train_val_test_split = (0.6, 0.2, 0.2)
batch_size = 4
data_dir = '/mnt/remote/workspaces/thibault.schowing/0_DATA/IMPTOX/00_Dataset/uFTIR_CurSquareSemantic.v1i.png-mask-semantic/train'

dm = CustomDataModule(
    data_dir=data_dir,
    batch_size=batch_size,
    train_val_test_split=train_val_test_split,
)
dm.setup()

### Create the Model

In [None]:
class FasterRCNNModule(pl.LightningModule):
    def __init__(self, num_classes, learning_rate=1e-3):
        """Build a COCO-pretrained Mask R-CNN with heads sized for *num_classes*.

        Despite the class name, the wrapped network is
        ``maskrcnn_resnet50_fpn``: it predicts boxes and instance masks.

        Args:
            num_classes: Number of output classes including the background
                (2 for a single foreground class).
            learning_rate: Learning rate handed to the optimizer.
        """
        super().__init__()
        self.save_hyperparameters()
        # Also keep the learning rate as a plain attribute for code that
        # reads ``self.learning_rate``.
        self.learning_rate = learning_rate

        # load an instance segmentation model pre-trained on COCO
        self.model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

        # replace the pre-trained box head with one sized for num_classes
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        # and do the same for the mask head
        in_features_mask = self.model.roi_heads.mask_predictor.conv5_mask.in_channels
        hidden_layer = 256
        self.model.roi_heads.mask_predictor = MaskRCNNPredictor(
            in_features_mask,
            hidden_layer,
            num_classes
        )
    
    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of the module.

        The learning rate is read from ``self.hparams``, where
        ``save_hyperparameters()`` stored the constructor argument.
        """
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
    
    def forward(self, x, target=None):
        """Run the wrapped detection model on *x*, passing *target* through unchanged.

        NOTE(review): torchvision detection models are expected to return a
        loss dict in training mode and per-image predictions in eval mode -
        confirm against the installed torchvision version.
        """
        outputs = self.model(x, target)
        return outputs