In [None]:
import numpy as np

### fix old numpy code in LVISEval (np.float is deprecated)
if not hasattr(np, "float"):
    np.float = float
import torch
import torchvision
import requests
import sys
import os
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from typing import List, Tuple, Literal

In [None]:
### add this if on colab:
# !pip install lvis 
### remove if on colab:
from constants import *
import lvis

## Dataset

In [None]:
from torch.utils.data import Dataset
from torchvision.io import decode_image
from torchvision import tv_tensors
from torchvision.ops import box_convert
from PIL import Image
from io import BytesIO


### with help of https://docs.pytorch.org/tutorials/intermediate/torchvision_tutorial.html


class LVISDataset(Dataset):
    """
    Dataset serving LVIS images together with their instance annotations.

    Images are read from the local directories when ``img_dirs`` is a non-empty list of
    existing directories; otherwise they are downloaded from the ``coco_url`` stored in
    the annotations. Annotations are converted into the target dict used by the
    torchvision detection models (labels, area, boxes in XYXY, binary masks).

    img_dirs : List of paths where images are stored.
    lvis_gt : LVIS object containing annotations (ground truth).
    transforms (optional) : callable applied as ``transforms(img, target)``.
    cat_ids (optional) : category ids to keep; every category is kept when None.
    """
    def __init__(self, img_dirs: List[str], lvis_gt: "lvis.LVIS", transforms=None, cat_ids=None) -> None:
        self.img_dirs = img_dirs
        self.lvis_gt = lvis_gt
        self.transforms = transforms
        self._create_index(cat_ids)

    def _create_index(self, cat_ids):
        """
        Builds the dataset index -> LVIS image id list, the category id <-> label
        mappings, and selects the image loading strategy (disk or URL).
        """
        # Filter images based on the provided ids or get all images.
        if cat_ids is None:
            self.img_ids = self.lvis_gt.get_img_ids()
            self.cat_ids = self.lvis_gt.get_cat_ids()
        else:
            self.img_ids = self._get_img_ids(cat_ids)
            self.cat_ids = cat_ids

        # Mapping id to label in both ways.
        # i+1 because label 0 is reserved for the background.
        self.cat_id_to_label = {cat_id: i + 1 for i, cat_id in enumerate(self.cat_ids)}
        self.label_to_cat_id = {label: cat_id for cat_id, label in self.cat_id_to_label.items()}

        # Load from local files only when every entry is an existing directory.
        # An empty list cannot serve any file, so it falls back to URLs as well.
        use_files = bool(self.img_dirs) and all(
            isinstance(img_dir, str) and os.path.isdir(img_dir) for img_dir in self.img_dirs
        )
        if use_files:
            self._get_image = self._get_image_from_file
            print("will load images from files")
        else:
            self._get_image = self._get_image_from_url
            print("will load images from urls")

    def _get_img_ids(self, cat_ids):
        """
        Returns the ids of the images containing at least one of the given categories,
        without duplicates and in ascending order (stable across runs).
        """
        return sorted({
            iid for cat_id in cat_ids
            for iid in self.lvis_gt.cat_img_map[cat_id]
        })

    def _get_image_from_file(self, img_id):
        """
        Loads an image from the disk and returns a torch.Tensor.
        The file ``<img_id zero-padded to 12 digits>.jpg`` is searched in every
        directory of ``img_dirs``; the first match is decoded.

        Raises FileNotFoundError when no directory contains the image.
        """
        file_name = f'{str(img_id).zfill(12)}.jpg'
        image_paths = [os.path.join(images_dir, file_name) for images_dir in self.img_dirs]
        for image_path in image_paths:
            if os.path.isfile(image_path):
                return decode_image(image_path)
        # Raise instead of sys.exit(): inside a DataLoader worker an exit would kill the
        # process without telling which image is missing.
        raise FileNotFoundError(f"image not found: {image_paths}")

    def _get_image_from_url(self, img_id):
        """
        Downloads an image from its ``coco_url`` and returns it as an RGB PIL image.

        Raises requests.HTTPError on a non-success HTTP status.
        """
        url = self.lvis_gt.imgs[img_id]['coco_url']
        # A timeout prevents a stalled connection from blocking the data loading forever.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as pil_img:
            return pil_img.convert("RGB")

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx: int):
        """
        Returns the image (as a Tensor) and targets (as a Dict) for the specified index idx.

        The target holds 'image_id', 'labels' (int64), 'area' (float32), 'boxes'
        (XYXY BoundingBoxes) and 'masks' (uint8 Mask of shape (N, H, W)). An image
        without any kept annotation yields empty tensors with N = 0.
        """
        ### get image
        img_id = self.img_ids[idx]
        img = tv_tensors.Image(self._get_image(img_id))
        _, h, w = img.shape

        ### get annotations (only those of the kept categories)
        annot_ids = self.lvis_gt.get_ann_ids(img_ids=[img_id])
        annots = [annot for annot in self.lvis_gt.load_anns(annot_ids)
                  if annot['category_id'] in self.cat_id_to_label]
        # labels: explicit dtype so that an empty list does not become float32
        labels = torch.tensor([self.cat_id_to_label[annot['category_id']] for annot in annots],
                              dtype=torch.int64)
        # area
        areas = torch.tensor([annot['area'] for annot in annots], dtype=torch.float32)
        # boxes: reshape keeps a (0, 4) shape when there is no annotation
        boxes = torch.tensor([annot['bbox'] for annot in annots], dtype=torch.float32).reshape(-1, 4)
        boxes_xyxy = box_convert(boxes, in_fmt='xywh', out_fmt='xyxy')
        boxes_tv = tv_tensors.BoundingBoxes(boxes_xyxy, format='XYXY', canvas_size=(h, w))  # type: ignore
        # masks, shape: (N, H, W); torch.stack cannot stack an empty list
        if annots:
            masks = torch.stack([torch.from_numpy(self.lvis_gt.ann_to_mask(annot)) for annot in annots])
        else:
            masks = torch.zeros((0, h, w), dtype=torch.uint8)
        mask_tv = tv_tensors.Mask(masks)

        target = {
            'image_id': img_id,
            'labels': labels,
            'area': areas,
            'boxes': boxes_tv,
            'masks': mask_tv,
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

## Model

    Args of maskrcnn_resnet50_fpn / MaskRCNN:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): Images are rescaled before feeding them to the backbone:
            we attempt to preserve the aspect ratio and scale the shorter edge
            to ``min_size``. If the resulting longer edge exceeds ``max_size``,
            then downscale so that the longer edge does not exceed ``max_size``.
            This may result in the shorter edge being lower than ``min_size``.
        max_size (int): See ``min_size``.
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes
        mask_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
             the locations indicated by the bounding boxes, which will be used for the mask head.
        mask_head (nn.Module): module that takes the cropped feature maps as input
        mask_predictor (nn.Module): module that takes the output of the mask_head and returns the
            segmentation mask logits

In [None]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


### with help of https://docs.pytorch.org/tutorials/intermediate/torchvision_tutorial.html

# TODO: add box_score_thresh as a variable of the model
def get_model_instance_segmentation(num_classes, box_score_thresh=0.05, min_size=800, max_size=1333):
    """
    Builds a Mask R-CNN (ResNet-50 FPN, "DEFAULT" pretrained weights) whose box and
    mask predictors are replaced by fresh heads with ``num_classes`` outputs.

    num_classes : number of output classes, background included.
    box_score_thresh, min_size, max_size : forwarded to ``maskrcnn_resnet50_fpn``.
    """
    mask_hidden_channels = 256
    detector_kwargs = {
        "weights": "DEFAULT",
        "box_score_thresh": box_score_thresh,
        "min_size": min_size,
        "max_size": max_size,
    }
    detector = torchvision.models.detection.maskrcnn_resnet50_fpn(**detector_kwargs)
    heads = detector.roi_heads

    # Box head: same input width as the pretrained classifier, new number of classes.
    box_in_features = heads.box_predictor.cls_score.in_features  # type: ignore
    heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # Mask head: same input channels as the pretrained mask classifier.
    mask_in_channels = heads.mask_predictor.conv5_mask.in_channels  # type: ignore
    heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, mask_hidden_channels, num_classes)
    return detector


# TODO: remove this; I think it cannot be used as is and, given our results, it is not needed
class MaskRCNNWrapper(torch.nn.Module):
    """
    Thin wrapper around the Mask R-CNN built by get_model_instance_segmentation,
    adding helpers to switch backbone training on and off.
    """
    def __init__(self, num_classes):
        super().__init__()
        self.model = get_model_instance_segmentation(num_classes)

    def forward(self, images, targets=None):
        """Delegates to the wrapped model."""
        return self.model(images, targets)

    def freeze(self):
        """
        Stops gradient computation for every backbone parameter,
        so the features already learned by the backbone are not altered.
        """
        self.model.backbone.requires_grad_(False)

    def unfreeze(self):
        """
        Re-enables gradient computation for every backbone parameter.
        """
        self.model.backbone.requires_grad_(True)


## Utils

#### Data

In [None]:
from torchvision.transforms import v2 as T


def get_transform(training=True):
    """
    Builds the torchvision.transforms (v2) pipeline used during data preparation.
    training (optional, True as default) : when True a random horizontal flip is
    prepended; the conversion to scaled float32 and to pure tensors is always applied.
    """
    pipeline = [T.RandomHorizontalFlip()] if training else []
    pipeline += [
        T.ToDtype(torch.float32, scale=True),
        T.ToPureTensor(),
    ]
    return T.Compose(pipeline)


def custom_collate_fn(batch):
    """
    Collate function that regroups the samples field by field (all images, all targets)
    as tuples instead of stacking them, since images may have different shapes.
    """
    per_field = zip(*batch)
    return tuple(per_field)


def get_filtered_cat_ids(cats, names):
    """
    Returns the ids of the categories of ``cats`` whose 'name' is listed in ``names``
    (in the iteration order of ``cats``) and prints which names were / were not found.
    """
    matches = [(cat_id, cat['name']) for cat_id, cat in cats.items() if cat['name'] in names]
    matched_names = [cat_name for _, cat_name in matches]

    found = [name for name in names if name in matched_names]
    missing = [name for name in names if name not in matched_names]
    print(f'category found for {found}')
    print(f'category NOT found for {missing}\n')
    return [cat_id for cat_id, _ in matches]

#### IoU and Dice

Inspired by https://docs.pytorch.org/vision/main/_modules/torchvision/ops/boxes.html#box_iou

In [None]:
def compute_iou_matrix(p, t):
    """
    Computes the mask IoU between every prediction and every target of one image.

    p : prediction dict with 'masks' of shape (N, 1, H, W) and 'scores' of shape (N,).
    t : target dict with 'masks' of shape (M, H, W).
    Returns a (N, M) Tensor whose rows are ordered by decreasing prediction score,
    or None when there is no prediction or no target.
    """
    predicted = p['masks'].squeeze(1)  # (N, 1, H, W) -> (N, H, W)
    truth = t['masks']

    if predicted.size(0) == 0 or truth.size(0) == 0:
        print(f"compute_iou_matrix : pas de predictions ou pas de targets")
        return None

    n_pred, n_true = predicted.size(0), truth.size(0)

    # Binarize the predictions and flatten every mask to one row: (K, H, W) -> (K, H*W).
    pred_flat = (predicted > 0.5).float().reshape(n_pred, -1)
    true_flat = truth.float().reshape(n_true, -1)

    # For binary rows the dot product counts the shared pixels: (N, H*W) x (H*W, M) = (N, M).
    overlap = pred_flat @ true_flat.t()

    # Pixel count of each mask, broadcast to (N, M) as area(pred) + area(target).
    area_sum = pred_flat.sum(dim=1, keepdim=True) + true_flat.sum(dim=1, keepdim=True).t()

    # union = A + B - intersection; the epsilon avoids a division by zero.
    ious = overlap / (area_sum - overlap + 1e-6)

    # Rows are reordered by decreasing score so that the most confident predictions
    # are the first ones considered when counting tp / fp / fn.
    ranking = torch.argsort(p['scores'], descending=True)
    return ious[ranking]


def count_tp_fp_fn(p,t, iou_threshold=0.5):
    """
    Counts the true positives, false positives and false negatives of one image.
    Based on the IoU matrix returned by compute_iou_matrix(p, t), whose rows are the
    predictions ordered by decreasing score.

    A prediction is a true positive when the IoU with its best target is above
    ``iou_threshold`` and that target has not been matched by an earlier prediction.
    Every other prediction is a false positive; every unmatched target is a false
    negative.

    Returns (tp, fp, fn) as ints.
    """
    iou_matrix = compute_iou_matrix(p, t)
    if iou_matrix is None:
        # No prediction or no target: nothing can match, so every prediction is a
        # false positive and every target a false negative.
        return 0, int(p['masks'].shape[0]), int(t['masks'].shape[0])

    N, M = iou_matrix.shape
    if M == 0:
        return 0, N, 0

    # Best target candidate (and its IoU) for each prediction, computed in one call
    # instead of one argmax/item round trip per prediction.
    best_ious, best_targets = iou_matrix.max(dim=1)

    # Targets already matched: a second detection of the same target is a false positive.
    detected_matches = set()
    tp = 0
    for best_iou, best_target in zip(best_ious.tolist(), best_targets.tolist()):
        if best_iou > iou_threshold and best_target not in detected_matches:
            detected_matches.add(best_target)
            tp += 1

    fp = N - tp
    fn = M - tp
    return tp, fp, fn

#### Plot

In [None]:
def add_img_and_mask(ax, mask, name, color):
    """
    Draws ``mask`` on ``ax`` as a half-transparent overlay of the given color and
    writes ``name`` at the centroid of the mask pixels.
    """
    rows, cols = np.nonzero(mask)
    rgba = np.zeros((*mask.shape, 4))
    rgba[..., :3] = color[:3]  # RGB of the color, its alpha is ignored
    rgba[..., 3] = mask * 0.5  # alpha = 0.5 where the mask exists, 0 elsewhere
    ax.imshow(rgba)
    ax.text(cols.mean(), rows.mean(), name, color="white", ha="center", va="center")


def plot_images_with_anns(img, target, pred, label_to_name=None, score_thresh=0.3):
    """
    Shows the image twice side by side: target masks on the left, predicted masks
    on the right.

    img : image Tensor of shape (C, H, W).
    target : dict with 'masks' and 'labels'.
    pred : dict with 'masks' of shape (N, 1, H, W), 'scores' and 'labels'.
    label_to_name (optional) : mapping label -> text to display; the label itself is
    displayed when the mapping is missing or does not contain the label.
    score_thresh : predictions with a lower score are not drawn.
    """
    image_hwc = img.permute(1, 2, 0).cpu().numpy().astype(np.float32)
    pred_masks = (pred['masks'].cpu().numpy()[:, 0, :, :] > 0.5).astype(np.uint8)  # to binary mask
    pred_scores = pred["scores"].cpu().numpy()
    pred_labels = pred['labels'].cpu().numpy()
    true_masks = target['masks'].cpu().numpy()
    true_labels = target['labels'].cpu().numpy()

    def _overlay(ax, label, mask):
        # Text and color of one instance; the color cycles over the 20 entries of tab20.
        if label_to_name is not None and label in label_to_name:
            text = label_to_name[label]
        else:
            text = str(label)
        add_img_and_mask(ax, mask, text, plt.get_cmap("tab20")(label % 20))

    _, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(8, 8))
    ax_true.imshow(image_hwc)
    ax_pred.imshow(image_hwc)

    for label, mask in zip(true_labels, true_masks):
        if np.any(mask):  # skip empty mask
            _overlay(ax_true, label, mask)

    for label, mask, score in zip(pred_labels, pred_masks, pred_scores):
        if np.any(mask) and not score < score_thresh:  # skip empty mask or low score
            _overlay(ax_pred, label, mask)

    for ax in (ax_true, ax_pred):
        ax.set_axis_off()
    plt.tight_layout()
    plt.show()


def plot_losses(train_losses, val_losses,best_model_epoch=None,early_stop=None, filepath=None):
    """
    Plots the training and validation losses per epoch.

    best_model_epoch / early_stop (optional) : 0-based epoch indices marked on the
    validation curve with a cross.
    filepath (optional) : where the figure is saved before being shown.
    """
    _, ax = plt.subplots(figsize=((12, 6)))
    epoch_axis = np.arange(1, len(train_losses) + 1, 1)

    curves = ((train_losses, 'r', 'Training Loss'), (val_losses, 'g', 'Validation Loss'))
    for values, fmt, legend in curves:
        ax.plot(epoch_axis, values, fmt, label=legend)

    markers = ((early_stop, 'g', 'start of overfitting'), (best_model_epoch, 'b', 'Saved Model Epoch'))
    for position, colour, legend in markers:
        if position is not None:
            plt.scatter(epoch_axis[position], val_losses[position], marker='x', c=colour, label=legend)

    ax.set_title('Loss Plots')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend(loc="upper right")
    if filepath is not None:
        plt.savefig(filepath)
    plt.show()
        
def plot_iou(val_iou, best_model_epoch=None,early_stop=None, filepath=None):
    """
    Plots the validation IoU per epoch.

    best_model_epoch / early_stop (optional) : 0-based epoch indices marked on the
    curve with a cross (both use the same legend text).
    filepath (optional) : where the figure is saved before being shown.
    """
    _, ax = plt.subplots(figsize=((12, 6)))
    epoch_axis = np.arange(1, len(val_iou) + 1, 1)
    ax.plot(epoch_axis, val_iou, 'r', label='val iou')

    markers = ((early_stop, 'g', 'Saved Model Epoch'), (best_model_epoch, 'b', 'Saved Model Epoch'))
    for position, colour, legend in markers:
        if position is not None:
            plt.scatter(epoch_axis[position], val_iou[position], marker='x', c=colour, label=legend)

    ax.set_title('iou Plot')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('iou')
    ax.legend(loc="upper right")
    if filepath is not None:
        plt.savefig(filepath)
    plt.show()

#### Train

pred = dict_keys(['boxes', 'labels', 'scores', 'masks'])
pred masks torch.Size([79, 1, 453, 640])

target : {'image_id': 38438, 'labels': tensor([1]), 'area': tensor([51464.3008]), 'boxes': tensor([[ 16.2200,  16.1600, 316.9400, 345.2500]]), 'masks': tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
...
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]], dtype=torch.uint8)}

In [None]:
import math


def train_one_epoch(model, optimizer, data_loader, epoch, device, scaler=None,
                    warmup=True,
                    print_freq: None | int = 10):
    """
    Executes a training loop for a single epoch.
    """
    model.train()

    lr_warmup = None
    if epoch==0 and warmup:
        warmup_iters = min(1000, len(data_loader) - 1)
        warmup_factor = 1.0 / warmup_iters
        lr_warmup = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    total_loss = 0
    for i, (images, targets) in enumerate(tqdm(data_loader, desc="TRAIN EPOCH (/batches)")):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]

        # We use automatic mixed precision for better performence.
        with torch.amp.autocast(device_type=device.type, enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()  # type: ignore
        total_loss += loss_value

        optimizer.zero_grad()

        # Since we use amp we need a scaler to work in float16 without risking losing information.
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()  # type: ignore
            optimizer.step()

        if print_freq is not None and i % print_freq == 0:
            tqdm.write(f"[batch {i + 1}/{len(data_loader)}] loss: {loss_value}")
        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training \nLoss dict:\n{loss_dict}")
            sys.exit(1)
        if lr_warmup is not None:
            lr_warmup.step()

    return total_loss / len(data_loader)


@torch.no_grad()
def evaluate(model, data_loader, device, iou_thresh=0.5):
    """
    Evaluates the model.
    Contains two loops over the data: the first one computes the loss (the model is put
    in train mode because that is the mode in which it returns its losses), the second
    one computes the detection metrics from the predictions (eval mode).

    model : detection model; it is left in eval mode when the function returns.
    data_loader : iterable of (images, targets) batches supporting len().
    device : torch.device the batches are moved to.
    iou_thresh : IoU above which a prediction counts as a true positive.

    Returns (val_loss, iou, precision, recall, f1) where val_loss is the mean loss per
    batch and iou is tp / (tp + fp + fn). All values are 0.0 for an empty loader.
    """
    num_batches = len(data_loader)

    model.train()   # in train to output validation loss
    total_loss = 0
    for images, targets in tqdm(data_loader, desc="VALIDATION loss (/batches)"):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
        # get loss
        loss_dict = model(images, targets)
        total_loss += sum(loss.item() for loss in loss_dict.values())

    model.eval()    # in eval to output predictions
    # true positives, false positives, false negatives
    total_tp, total_fp, total_fn = 0, 0, 0
    for images, targets in tqdm(data_loader, desc="VALIDATION preds (/batches)"):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
        preds = model(images)
        # Count the tp, fp, fn for each image masks predicted based on associated target
        for p, t in zip(preds, targets):
            tp, fp, fn = count_tp_fp_fn(p, t, iou_thresh)
            total_tp += tp
            total_fp += fp
            total_fn += fn

    eps = 1e-6  # epsilon to avoid division with zero
    iou = total_tp / (total_tp + total_fp + total_fn + eps)
    precision = total_tp / (total_tp + total_fp + eps)
    recall = total_tp / (total_tp + total_fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    # max(..., 1) keeps an empty loader from dividing by zero.
    val_loss = total_loss / max(num_batches, 1)
    return val_loss, iou, precision, recall, f1


def train(
        model, optimizer, lr_scheduler,
        train_loader, val_loader, lvis_gt_val,
        epochs, patience, warmup,
        save_model_path, device, scaler=None,
        print_freq: None | int = 10
):
    """
    Main training loop.
    """
    train_losses = []
    val_losses = []
    val_iou = []

    # Tracking
    early_stoping_epoch = 0
    best_model_epoch = 0
    epochs_no_improvement = 0
    best_iou = 0

    # Example images to display during training
    images_vis, targets_vis = next(iter(val_loader))
    images_vis = list(image.to(device) for image in images_vis)
    targets_vis = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets_vis]
    label_to_name = {k: lvis_gt_val.load_cats([v])[0]['name'] for k, v in
                                 val_loader.dataset.dataset.label_to_cat_id.items()}

    for epoch in tqdm(range(epochs), desc='TRAINING LOOP (/epochs)'):

        loss_train = train_one_epoch(model, optimizer, train_loader, epoch, device, scaler, warmup, print_freq)
        train_losses.append(loss_train)

        val_loss, iou, precision, recall, f1 = evaluate(model, val_loader, device)
        val_losses.append(val_loss)
        val_iou.append(iou)
        tqdm.write(f"[epoch {str(epoch + 1).zfill(2)}/{str(epochs).zfill(2)}]: train loss = {loss_train:.4f} | val loss = {val_loss:.4f}")
        tqdm.write(f"               iou = {iou:.4f} | precision = {precision:.4f} | recall = {recall:.4f} | f1 = {f1:.4f}")

        #To display an example
        #TODO : Change %1 (test)
        if (epoch + 1) % 1 == 0:
            with torch.no_grad():
                model.eval()
                preds_vis = model(images_vis)  
            plot_images_with_anns(images_vis[0], targets_vis[0], preds_vis[0], label_to_name, 0.5)
        
        if iou >= best_iou: 
            best_iou = iou
            best_model_epoch = epoch
            epochs_no_improvement = 0
            torch.save(model.state_dict(), save_model_path)
        else:
            epochs_no_improvement += 1
            tqdm.write(f"NO improvement [{epochs_no_improvement}/{patience}]")
            if epochs_no_improvement >= patience:
                print("Patience reached, stopping training")
                break
        if lr_scheduler is not None:
            lr_scheduler.step()
    return train_losses, val_losses,val_iou, best_model_epoch, early_stoping_epoch

#### Evaluation

In [None]:
from lvis import LVISResults, LVISEval
import pycocotools.mask as maskUtils


# TODO masks do not seem to be in the right format
@torch.inference_mode()
def get_predictions(model, data_loader, device, score_thresh=0.05):
    """
    Runs the model over ``data_loader`` and returns its detections as a list of dicts
    ('image_id', 'category_id', 'bbox', 'segmentation', 'score').

    Detections with a score below ``score_thresh`` are dropped. Masks are binarized at
    0.5 and run-length encoded; boxes are converted from XYXY to [x, y, width, height].
    Labels are mapped back to category ids through the dataset wrapped by the loader's
    subset (``data_loader.dataset.dataset.label_to_cat_id``).
    """
    model.eval()
    results = []
    for images, targets in tqdm(data_loader, desc="BATCHES"):
        batch = [image.to(device) for image in images]
        for output, target in zip(model(batch), targets):
            img_id = target["image_id"]
            det_boxes = output["boxes"].cpu().numpy()
            det_scores = output["scores"].cpu().numpy()
            det_labels = output["labels"].cpu().numpy()
            det_masks = output["masks"].cpu().numpy()[:, 0, :, :]
            for box, score, label, mask in zip(det_boxes, det_scores, det_labels, det_masks):
                if not score < score_thresh:
                    # Binary mask (0.5 threshold), then Run Length Encoding; the encoder
                    # is given a Fortran-ordered array.
                    binary = (mask > 0.5).astype(np.uint8)
                    encoded = maskUtils.encode(np.asfortranarray(binary))
                    encoded['counts'] = encoded['counts'].decode('utf-8')
                    left, top, right, bottom = box
                    results.append({
                        "image_id": img_id,
                        # dict inside dataset in subset in loader
                        "category_id": data_loader.dataset.dataset.label_to_cat_id[int(label)],
                        # [x, y, width, height]
                        "bbox": [float(left), float(top), float(right - left), float(bottom - top)],
                        "segmentation": encoded,
                        "score": float(score)
                    })
    return results


@torch.inference_mode()
def run_lvis_eval(model, data_loader, lvis_gt, cat_ids, device, iou_type: Literal["bbox", "segm"] = "segm"):
    """
    Runs the LVIS evaluation of the model on ``data_loader`` restricted to ``cat_ids``
    and prints the results. Nothing is evaluated when the model returns no detection.
    """
    # very low threshold for lviseval
    predictions = get_predictions(model, data_loader, device, 0.0001)
    if len(predictions) == 0:
        print("No detections — skipping LVIS evaluation.")
        return

    evaluator = LVISEval(lvis_gt, LVISResults(lvis_gt, predictions), iou_type)
    evaluator.params.cat_ids = cat_ids
    evaluator.run()
    evaluator.print_results()

## Pipeline

#### Arguments

In [None]:
# Run configuration: every value a reader may want to tune is gathered in this cell.

### Device
NO_AMP: bool = False  # True disables automatic mixed precision (no gradient scaler is created)
PIN_MEMORY: bool = True  # passed to the DataLoaders
NUM_WORKERS: int = 4  # int (0: main process)

### Data
MAX_IMAGES: int = 10000  # 80/20 split (train: 0.8*max_img | val: 0.2*max_img)
# Names of the LVIS categories to keep; names without a matching category are reported and ignored.
CATEGORIES: List[str] = ['cat', 'dog', 'cow', 'bull', 'pigeon', 'giraffe', 'bear', 'elephant', 'rabbit', 'horse']

### Model
MIN_IMG_SIZE: int = 400  # 800 default
MAX_IMG_SIZE: int = 800 # 1333 default
BOX_SCORE_THRESH: float = 0.1   # 0.05 default

### Learning
LR: float = 1e-3  # learning rate of the SGD optimizer
MOMENTUM: float = 0.9  # momentum of the SGD optimizer
EPOCHS: int = 25  # maximum number of epochs
BATCH_SIZE: int = 10
PATIENCE: int = 5  # epochs without validation IoU improvement before training stops
WARMUP: bool = True      # whether the learning rate should start small during the first epoch

### Others
BATCH_PRINT_FREQ: int|None = None  # None: no print inside epoch
OUTPUT_DIR: str = 'output' # '/kaggle/working/'

#### Data Preparation

In [None]:
### load annotations
from lvis import LVIS

# Ground-truth annotations of the train and validation splits.
# TRAIN_ANNOT_PATH / VAL_ANNOT_PATH are presumably provided by `from constants import *` — confirm.
lvis_gt_train = LVIS(TRAIN_ANNOT_PATH)
lvis_gt_val = LVIS(VAL_ANNOT_PATH)

In [None]:
### create datasets/dataloaders
from torch.utils.data import DataLoader, Subset

cat_ids = get_filtered_cat_ids(lvis_gt_train.cats, CATEGORIES)
num_classes = len(cat_ids) + 1  # +1 for background

# Seeded generator: the random train/val subsets are the same on every
# Restart & Run All, which keeps the reported metrics comparable between runs.
SUBSET_SEED = 0
subset_rng = torch.Generator().manual_seed(SUBSET_SEED)

# train
dataset_train = LVISDataset([COCO2017_TRAIN_PATH], lvis_gt_train, get_transform(training=True), cat_ids)
train_indices = torch.randperm(len(dataset_train), generator=subset_rng)[:int(MAX_IMAGES * 0.8)].tolist()
subset_train = Subset(dataset_train, train_indices)
train_loader = DataLoader(subset_train,
                          batch_size=BATCH_SIZE,
                          collate_fn=custom_collate_fn,
                          shuffle=True,
                          num_workers=NUM_WORKERS,
                          pin_memory=PIN_MEMORY)
print(f"Size of train dataset: {len(dataset_train)}")
print(f"Size of train subset: {len(subset_train)}\n")

# val (both image directories are searched for the validation images)
dataset_val = LVISDataset([COCO2017_VAL_PATH, COCO2017_TRAIN_PATH], lvis_gt_val,
                          get_transform(training=False), cat_ids)
val_indices = torch.randperm(len(dataset_val), generator=subset_rng)[:int(MAX_IMAGES * 0.2)].tolist()
subset_val = Subset(dataset_val, val_indices)
val_loader = DataLoader(subset_val,
                        batch_size=BATCH_SIZE,
                        collate_fn=custom_collate_fn,
                        num_workers=NUM_WORKERS,
                        pin_memory=PIN_MEMORY)
print(f"Size of validation dataset: {len(dataset_val)}")
print(f"Size of validation subset: {len(subset_val)}")

# for faster evaluation (when instancing LVISResults):
lvis_gt_val.cats = {k: v for k, v in lvis_gt_val.cats.items() if k in cat_ids}
lvis_gt_train.cats = {k: v for k, v in lvis_gt_train.cats.items() if k in cat_ids}

#### Training

In [None]:
### Initialization
# Use the GPU when available; the gradient scaler (and with it mixed precision in
# train_one_epoch) is only created on CUDA and when NO_AMP is False.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' and not NO_AMP else None

model = get_model_instance_segmentation(num_classes, box_score_thresh=BOX_SCORE_THRESH, min_size=MIN_IMG_SIZE, max_size=MAX_IMG_SIZE)
model.to(device)
print(f'Device used: {device.type}')

# Only the parameters that require gradients are handed to the optimizer.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=0.0005
)
# The learning rate is multiplied by 0.8 at every scheduler step
# (train() steps the scheduler once per epoch).
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=1,
    gamma=0.8
)

In [None]:
### Training Loop
# The checkpoint and the loss figure are written into OUTPUT_DIR: create it first,
# neither torch.save nor plt.savefig creates missing directories.
os.makedirs(OUTPUT_DIR, exist_ok=True)
train_losses, val_losses, val_iou, best_epoch, early_stop_epoch = train(model,
                                             optimizer, lr_scheduler,
                                             train_loader, val_loader, lvis_gt_val,
                                             EPOCHS, PATIENCE, WARMUP,
                                             os.path.join(OUTPUT_DIR, "best_model.pt"),
                                             device, scaler, BATCH_PRINT_FREQ)
plot_losses(train_losses, val_losses, best_model_epoch=best_epoch, early_stop=early_stop_epoch,
            filepath=os.path.join(OUTPUT_DIR, "losses.jpg"))
# The saved-model epoch is passed under its own keyword, as in the loss plot.
plot_iou(val_iou, best_model_epoch=best_epoch)

#### Evaluation & Visualisation

In [None]:
### load best model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Rebuild the network with the same input resizing as during training, otherwise the
# saved weights are evaluated on images of a different scale (defaults: 800 / 1333).
# box_score_thresh is left at its default so that low-score detections are still
# returned for the LVIS evaluation.
best_model = get_model_instance_segmentation(num_classes, min_size=MIN_IMG_SIZE, max_size=MAX_IMG_SIZE)
best_model.load_state_dict(torch.load(os.path.join(OUTPUT_DIR, "best_model.pt"), map_location=device))
best_model.to(device)
print('')

In [None]:
### run full LVIS evaluation
# Evaluates the reloaded best checkpoint on the validation subset, for the selected
# categories only, using segmentation masks.
run_lvis_eval(best_model, val_loader, lvis_gt_val, cat_ids, device, iou_type="segm")

In [None]:
### evaluation metrics
# Report the metrics of the reloaded best checkpoint (best_model), not of `model`,
# which holds the weights of the last training epoch.
iou_threshs = [0.5, 0.75]
for iou_thresh in iou_threshs:
    val_loss, iou, precision, recall, f1 = evaluate(best_model, val_loader, device, iou_thresh)
    print(f'For IoU = {iou_thresh}')
    print(f'    Val loss  = {val_loss:.4f}')
    print(f'    IoU       = {iou:.4f}')
    print(f'    precision = {precision:.4f}')
    print(f'    recall    = {recall:.4f}')
    print(f'    f1        = {f1:.4f}')
    print()

In [None]:
### visualization
IMAGES_TO_SHOW = 5  # at most this many images of the first validation batch are plotted
SCORE_THRESH = 0.3  # higher score thresh for visualisation

# label (1..num categories) -> category name, through the dataset wrapped by the
# validation subset.
label_to_name = {k: lvis_gt_val.load_cats([v])[0]['name'] for k, v in
                 val_loader.dataset.dataset.label_to_cat_id.items()}

# Predict on the first validation batch with the best checkpoint.
images, targets = next(iter(val_loader))
images = [img.to(device) for img in images]
best_model.eval()
with torch.no_grad():
    preds = best_model(images)
for i in range(min(IMAGES_TO_SHOW, len(images))):
    plot_images_with_anns(images[i], targets[i], preds[i], label_to_name, SCORE_THRESH)
