# Deep Learning Project

In [None]:
%pip install ftfy regex tqdm ultralytics wandb albumentations
%pip install git+https://github.com/openai/CLIP.git

In [None]:
import gc
import gzip
import html
import json
import math
import os
import pickle
import random
import re
from dataclasses import dataclass
from enum import Enum
from pprint import pprint
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Optional,
                    OrderedDict, Tuple, Union)

import albumentations as A
import clip
import cv2
import ftfy
import gdown
import matplotlib.pyplot as plt
import numpy as np
import pkg_resources as p
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
from albumentations.pytorch import ToTensorV2
from clip import tokenize
from clip.model import CLIP, ModifiedResNet
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from PIL import Image
from torch import Tensor, device, tensor
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image
from torchvision.ops import box_convert, box_iou, generalized_box_iou_loss
from tqdm.notebook import tqdm

import wandb

In [None]:
# Notebook cell: fetch and unpack the RefCOCOg archive under data/raw/.
# The %cd magic makes every relative path below resolve from /content/.
%cd /content/
# Download dataset and save under data/raw/ only if not already downloaded
url = "https://drive.google.com/uc?id=1xijq32XfEm6FPhUb7RsZYWHc2UuwVkiq"
if not os.path.exists("data/raw/refcocog.tar.gz"):
    print("Downloading dataset...")
    # resume=True asks gdown to continue a partially downloaded file.
    gdown.download(url=url, output="data/raw/", quiet=False, resume=True)
# Extraction is skipped when the unpacked folder is already present.
if not os.path.exists("data/raw/refcocog/"):
    print("Extracting dataset...")
    !tar -xf data/raw/refcocog.tar.gz -C data/raw

The preprocessed samples can be downloaded from Google Drive by executing the following cell. Otherwise, the preprocessing will be done at runtime, and the resulting files are saved only temporarily in the Colab environment.

In [None]:
# Download the preprocessed dataset: copy the shared Google Drive folder into
# data/. Optional: VGDataset can also build the sample files itself via
# preprocess().
# NOTE(review): VGDataset looks for "<split>_samples.json" under
# ../data/processed/ by default — confirm the Drive folder unpacks to that layout.
url = "https://drive.google.com/drive/folders/1jaJV40dneOckZn7WHMQyd2jBh7A8534N"
gdown.download_folder(url=url, output="data/", quiet=False)

In [None]:
!wget https://raw.githubusercontent.com/ManuelaCorte/DLProject/develop/config.json
!wget https://raw.githubusercontent.com/ManuelaCorte/DLProject/develop/sweep_config.json

In [None]:
%mkdir src
%cd src

In [None]:
class Sample:
    """One dataset example: an image path, its caption and the target bounding box."""

    def __init__(self, image_path: str, caption: str, bounding_box: Tensor) -> None:
        self.image_path, self.caption, self.bounding_box = (
            image_path,
            caption,
            bounding_box,
        )

    def as_dict(self) -> dict[str, Any]:
        """Return a JSON-serialisable mapping; the box tensor becomes nested lists."""
        serialised: dict[str, Any] = {}
        serialised["image_path"] = self.image_path
        serialised["caption"] = self.caption
        serialised["bounding_box"] = self.bounding_box.tolist()
        return serialised

    @staticmethod
    def fromJSON(json: dict[str, Any]) -> Any:
        """Rebuild a ``Sample`` from a mapping shaped like the output of ``as_dict``.

        The signature (one positional mapping) makes it usable as a
        ``json.load(..., object_hook=...)`` callback.
        """
        path, text, box = (
            json[key] for key in ("image_path", "caption", "bounding_box")
        )
        return Sample(path, text, Tensor(box))


class BatchSample:
    def __init__(self, image: Tensor, caption: Tensor) -> None:
        self.image: Tensor = image
        self.caption: Tensor = caption

    def to(self, device: device | str) -> Any:
        return self.__class__(self.image.to(device), self.caption.to(device))

    def __str__(self) -> str:
        return f"BatchSample(image={self.image.shape}, caption={self.caption.shape})"


class Split(Enum):
    """Dataset partition; each value is the split label string stored in the annotations.

    Deliberately a plain ``Enum`` with no ``@dataclass`` decorator: the
    decorator would generate a field-less ``__eq__``/``__hash__`` that makes
    every member compare equal to every other member.
    """

    TRAIN = "train"
    VAL = "val"
    TEST = "test"


@dataclass(frozen=True)
class Result:
    """Immutable pairing of a bounding box tensor with a score tensor.

    NOTE(review): no code in this part of the file constructs or reads a
    ``Result``, so the box format and the meaning of ``score`` cannot be
    confirmed from here.
    """

    bounding_box: Tensor
    score: Tensor


# XYXY: top left and bottom right corners
# XYWH: top left corner, width and height
# CXCWH: center coordinates, width and height
class BboxType(Enum):
    """Bounding box coordinate layouts supported by the dataset code.

    Deliberately a plain ``Enum`` with no ``@dataclass`` decorator: the
    decorator would generate a field-less ``__eq__``/``__hash__`` that makes
    every member compare equal to every other member.

    NOTE(review): ``torchvision.ops.box_convert`` names the centre-based
    format ``"cxcywh"``; the ``"cxcwh"`` value is kept unchanged here because
    it is part of this enum's public interface.
    """

    XYXY = "xyxy"
    XYWH = "xywh"
    CXCWH = "cxcwh"

    def __str__(self) -> str:
        # Explicitly keep the default Enum rendering ("BboxType.<NAME>").
        return super().__str__()


@dataclass
class Model:
    """Architecture hyper-parameters; ``Config`` builds it from the ``"model"`` config section."""

    # Passed to FusionModule as ``clip_emb_dim`` (channel width of the CLIP features).
    clip_embed_dim: int
    # Caption context length; sizes the text-side tensors in TextEncoder and Decoder.
    clip_ctx_length: int
    # Working embedding width shared by FusionModule, Decoder and the MLP head.
    embed_dim: int
    # Hidden width of the regression MLP head.
    mlp_hidden_dim: int
    # Not read by any code visible here — presumably the input image side; TODO confirm.
    img_size: int
    # Spatial side length given to FusionModule and to the Decoder's 2-D positional encoding.
    proj_img_size: int
    # Number of stacked transformer decoder layers.
    decoder_layers: int
    # Attention heads per decoder layer.
    decoder_heads: int
    # Feed-forward width inside each decoder layer.
    decoder_dim_feedforward: int


@dataclass
class Train:
    """Training hyper-parameters; ``Config`` builds it from the ``"train"`` config section."""

    # Also handed to TextEncoder to size its placeholder transformer output.
    batch_size: int
    # Presumably the optimizer learning rate — the optimizer is not visible here; TODO confirm.
    lr: float
    # Presumably the weights of the two terms combined by ``Loss`` — TODO confirm at the call site.
    l1: float
    l2: float
    # Presumably toggles a hyper-parameter sweep run — TODO confirm.
    sweep: bool


@dataclass
class Logging:
    """Logging/checkpoint options; ``Config`` builds it from the ``"logging"`` config section.

    NOTE(review): none of these fields is read by code visible in this part of
    the file; the per-field notes below are assumptions to confirm.
    """

    # Presumably the directory or file used for saved artefacts.
    path: str
    # Presumably whether to write checkpoints.
    save: bool
    # Presumably whether to resume from a previous checkpoint.
    resume: bool
    # Presumably whether to log to Weights & Biases.
    wandb: bool


class Config:
    """Run configuration loaded from ``../config.json`` (relative to the working directory).

    The top-level ``dataset_path`` and ``epochs`` entries become attributes;
    the ``model``, ``train`` and ``logging`` sections are wrapped in their
    dataclasses.

    Raises:
        FileNotFoundError: if ``../config.json`` does not exist.
        KeyError / TypeError: if an expected entry is missing or a section
            does not match its dataclass fields.
    """

    def __init__(self) -> None:
        # Context manager so the file handle is closed as soon as parsing ends.
        with open("../config.json", "r") as config_file:
            cfg: Dict[str, Any] = json.load(config_file)
        self.dataset_path: str = cfg["dataset_path"]
        self.epochs: int = cfg["epochs"]
        self.model = Model(**cfg["model"])
        self.train = Train(**cfg["train"])
        self.logging = Logging(**cfg["logging"])

    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain dict (the ``logging`` section is not included).

        The ``model`` and ``train`` entries are the live ``__dict__`` of the
        section objects, not copies.
        """
        return {
            "dataset_path": self.dataset_path,
            "epochs": self.epochs,
            "model": vars(self.model),
            "train": vars(self.train),
        }

    # if in other dict there are keys equal to the keys in self, update them
    def update(self, other: Dict[str, Any]):
        """Overwrite every existing setting whose name appears in ``other``.

        Each key is checked against the top-level attributes and against all
        three sections, so a name present in several places is updated in each.
        Keys that match nothing are ignored.
        """
        for key, value in other.items():
            if key in vars(self):
                vars(self)[key] = value
            if key in vars(self.model):
                vars(self.model)[key] = value
            if key in vars(self.train):
                vars(self.train)[key] = value
            if key in vars(self.logging):
                vars(self.logging)[key] = value

In [None]:
class Metric(Enum):
    """Names of the tracked metrics; each value is the key used in ``MetricsLogger.metrics``.

    Deliberately a plain ``Enum`` with no ``@dataclass`` decorator: the
    decorator would generate a field-less ``__eq__``/``__hash__`` that makes
    every member compare equal to every other member.
    """

    LOSS = "loss"
    ACCURACY_50 = "accuracy"  # IoU > 0.5 -> 1 else 0
    ACCURACY_75 = "accuracy75"  # IoU > 0.75 -> 1 else 0
    ACCURACY_90 = "accuracy90"  # IoU > 0.9 -> 1 else 0
    IOU = "iou"
    COSINE_SIMILARITY = "cosine_similarity"


class Reduction(Enum):
    """How ``MetricsLogger.get_metric`` collapses a metric's recorded values.

    Deliberately a plain ``Enum`` with no ``@dataclass`` decorator: the
    decorator would generate a field-less ``__eq__``/``__hash__`` that makes
    every member compare equal to every other member.
    """

    MEAN = "mean"
    SUM = "sum"
    NONE = "none"


class MetricsLogger:
    def __init__(self, metrics: Dict[str, List[float]] | None = None) -> None:
        self.metrics: Dict[str, List[float]] = {}
        if metrics is None:
            for metric in Metric:
                self.metrics[metric.value] = []
        else:
            self.metrics = metrics

    def update_metric(self, metrics: Dict[str, float]) -> None:
        for metric, value in metrics.items():
            self.metrics[metric].append(value)

    def get_metric(
        self, metric: Metric, red: Reduction = Reduction.NONE
    ) -> float | List[float]:
        values: List[float] = self.metrics[metric.value]
        match red.name:
            case Reduction.MEAN.name:
                return sum(values) / len(values)
            case Reduction.SUM.name:
                return sum(values)
            case Reduction.NONE.name:
                return values
            case _:
                raise ValueError(f"Reduction {red.name} doesn't exists")

    def __str__(self) -> str:
        res = "Metrics:\n"
        for metric, values in self.metrics.items():
            res += f"{metric}: {sum(values) / len(values)}\n"
        return res

In [None]:
# The Dataset contains samples with an image with a bounding box and a caption associated with the bounding box.


class VGDataset(Dataset[Tuple[BatchSample, Tensor]]):
    """Visual-grounding dataset yielding ``(BatchSample, bounding_box)`` pairs.

    Samples come either from preprocessed ``<split>_samples.json`` files
    (``preprocessed=True``) or are built on the fly from the raw annotation
    files under ``dir_path``.

    Args:
        dir_path: dataset root; paths are built by string concatenation, so it
            must end with a path separator.
        split: which partition to load.
        output_bbox_type: box layout produced by ``get_bounding_box``.
        augment: forwarded to ``transform_sample``.
        transform: when True, images and boxes go through ``transform_sample``;
            otherwise the raw image tensor and stored box are returned.
        preprocessed: load samples from JSON (generating the files if missing).
        preprocessed_path: directory prefix of the JSON sample files.

    NOTE(review): captions built by this class's ``get_caption`` carry a
    "find the region ..." prefix, whereas the preprocessed JSON files hold
    whatever caption text ``preprocess`` wrote — confirm the two paths are
    meant to differ.
    """

    def __init__(
        self,
        dir_path: str,
        split: Split,
        output_bbox_type: BboxType,
        augment: bool,
        transform: bool = True,
        preprocessed: bool = False,
        preprocessed_path: str = "../data/processed/",
    ) -> None:
        super().__init__()
        self.dir_path: str = dir_path
        self.split: Split = split
        self.output_bbox_type: BboxType = output_bbox_type
        self.augment: bool = augment
        self.transform: bool = transform
        self.device: device = torch.device(
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
        self.samples: List[Sample]
        if preprocessed:
            # preprocess() returns immediately when all three split files exist.
            preprocess(dir_path, preprocessed_path, output_bbox_type)
            with open(
                preprocessed_path + f"{self.split.value}_samples.json", "rb"
            ) as samples:
                self.samples = json.load(samples, object_hook=Sample.fromJSON)
        else:
            self.samples = self.get_samples()

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, ref_id: int) -> Tuple[BatchSample, Tensor]:
        """Return the tokenized caption and image of sample ``ref_id`` plus its box."""
        sample = self.samples[ref_id]
        caption: Tensor = tokenize(sample.caption, truncate=True)  # type: ignore
        if self.transform:
            # Context manager releases the image file handle once the pixels
            # have been converted by transform_sample.
            with Image.open(sample.image_path) as pil_image:
                image, bbox = transform_sample(
                    pil_image,
                    sample.bounding_box,
                    self.augment,
                )
        else:
            image = read_image(sample.image_path)
            # torch.tensor([t]) cannot wrap a multi-element Tensor, so add the
            # leading axis explicitly on a copy of the stored box.
            # NOTE(review): confirm the shape callers expect on this path.
            bbox = sample.bounding_box.clone().unsqueeze(0)
        return BatchSample(image, caption), bbox

    def get_samples(self) -> List[Sample]:
        """Build the samples of ``self.split`` from the raw annotation files."""
        with open(self.dir_path + "annotations/instances.json", "r") as inst, open(
            self.dir_path + "annotations/refs(umd).p", "rb"
        ) as refs:
            instances = json.load(inst)
            # NOTE: pickle executes arbitrary code on load; only use with the
            # trusted dataset archive.
            references = pickle.load(refs)
        samples: List[Sample] = []
        for ref in references:
            if self.split.value != ref["split"]:
                continue
            image_path = self.get_image_path(ref["image_id"], instances)
            caption = self.get_caption(ref["sentences"])
            bbox = self.get_bounding_box(ref["ann_id"], instances)
            samples.append(Sample(image_path, caption, bbox))
        return samples

    def get_image_path(self, img_id: int, instances: Dict[str, Any]) -> str:
        """Return the path of the image with id ``img_id`` (StopIteration if absent)."""
        image_name = next(
            image["file_name"] for image in instances["images"] if image["id"] == img_id
        )
        return self.dir_path + "images/" + image_name

    def get_caption(self, captions: List[Dict[str, Any]]) -> str:
        """Return the longest sentence (first one on ties) wrapped in the prompt prefix."""
        longest_caption = captions[0]
        for caption in captions:
            if len(caption["sent"]) > len(longest_caption["sent"]):
                longest_caption = caption
        return f"find the region that corresponds to the description {longest_caption['sent']}"

    # Bounding boxes converted to a format compatible with yolo or torchvision
    def get_bounding_box(self, ann_id: int, instances: Dict[str, Any]) -> Tensor:
        """Return the annotation's box as a ``(1, 4)`` tensor in ``self.output_bbox_type`` layout.

        The stored annotation box is treated as ``xywh``. Raises StopIteration
        if ``ann_id`` is not present.
        """
        bbox = next(
            ann["bbox"] for ann in instances["annotations"] if ann["id"] == ann_id
        )
        type_name = self.output_bbox_type.name
        if type_name == BboxType.CXCWH.name:
            # torchvision spells the centre-based format "cxcywh"; passing the
            # enum value "cxcwh" would be rejected by box_convert.
            out_fmt = "cxcywh"
        elif type_name in (BboxType.XYXY.name, BboxType.XYWH.name):
            out_fmt = self.output_bbox_type.value
        else:
            return tensor([])
        return box_convert(tensor([bbox]), in_fmt="xywh", out_fmt=out_fmt)

In [None]:
def get_samples(
    dir_path: str, bbox_type: BboxType
) -> Tuple[List[Sample], List[Sample], List[Sample]]:
    """Read the raw annotations and return ``(train, val, test)`` sample lists.

    Raises:
        ValueError: if a reference carries a split label other than the three
            ``Split`` values.
    """
    instances_file = dir_path + "annotations/instances.json"
    references_file = dir_path + "annotations/refs(umd).p"
    with open(instances_file, "r") as inst:
        with open(references_file, "rb") as refs:
            instances = json.load(inst)
            references = pickle.load(refs)

    train_samples: List[Sample] = []
    val_samples: List[Sample] = []
    test_samples: List[Sample] = []
    for ref in tqdm(references, desc="Processing dataset"):
        image_path = get_image_path(dir_path, ref["image_id"], instances)
        caption = get_caption(ref["sentences"])
        bbox = get_bounding_box(ref["ann_id"], instances, bbox_type)
        split = ref["split"]
        # Route the sample to the list matching its split label.
        if split == Split.TRAIN.value:
            destination = train_samples
        elif split == Split.VAL.value:
            destination = val_samples
        elif split == Split.TEST.value:
            destination = test_samples
        else:
            raise ValueError(f"Invalid split: {split}")
        destination.append(Sample(image_path, caption, bbox))
    return train_samples, val_samples, test_samples


def get_image_path(dir_path: str, img_id: int, instances: Dict[str, Any]) -> str:
    """Return ``dir_path + "images/" + file_name`` for the first image whose id is ``img_id``.

    Raises StopIteration when no image entry has that id.
    """
    for entry in instances["images"]:
        if entry["id"] == img_id:
            return dir_path + "images/" + entry["file_name"]
    raise StopIteration


def get_caption(captions: List[Dict[str, Any]]) -> str:
    """Return the longest ``"sent"`` string among ``captions`` (the first one on ties)."""
    best_sentence = captions[0]["sent"]
    for entry in captions[1:]:
        candidate = entry["sent"]
        # Strict comparison keeps the earliest sentence when lengths are equal.
        if len(candidate) > len(best_sentence):
            best_sentence = candidate
    return best_sentence


# Bounding boxes converted to a format compatible with yolo or torchvision
def get_bounding_box(
    ann_id: int, instances: Dict[str, Any], bbox_type: BboxType
) -> Tensor:
    """Return the box of annotation ``ann_id`` as a ``(1, 4)`` tensor in ``bbox_type`` layout.

    The stored annotation box is treated as ``xywh`` and converted with
    ``torchvision.ops.box_convert``.

    Raises:
        StopIteration: if no annotation has id ``ann_id``.
    """
    bbox = next(ann["bbox"] for ann in instances["annotations"] if ann["id"] == ann_id)
    type_name = bbox_type.name
    if type_name == BboxType.CXCWH.name:
        # torchvision spells the centre-based format "cxcywh"; passing the enum
        # value "cxcwh" would be rejected by box_convert.
        out_fmt = "cxcywh"
    elif type_name in (BboxType.XYXY.name, BboxType.XYWH.name):
        out_fmt = bbox_type.value
    else:
        # Unknown layout: keep the historical fallback of an empty tensor.
        return tensor([])
    return box_convert(tensor([bbox]), in_fmt="xywh", out_fmt=out_fmt)


# If the files already exist, don't preprocess again
def preprocess(in_path: str, out_path: str, bbox_type: BboxType) -> None:
    """Write ``train``/``val``/``test`` ``*_samples.json`` files under the ``out_path`` prefix.

    Does nothing when all three files already exist. Otherwise the raw dataset
    at ``in_path`` is parsed with ``get_samples`` and each split is serialised
    through ``Sample.as_dict``. The target directory is created if missing and
    every file is written inside a context manager so it is flushed and closed.
    """
    split_names = ("train", "val", "test")
    out_files = [f"{out_path}{name}_samples.json" for name in split_names]
    if all(os.path.exists(path) for path in out_files):
        return

    split_samples = get_samples(in_path, bbox_type)  # (train, val, test)
    for path, samples in zip(out_files, split_samples):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(path, "w") as out_file:
            json.dump(samples, out_file, default=Sample.as_dict)


if __name__ == "__main__":
    # Smoke test: build the preprocessed files if needed, then reload every
    # split and print basic statistics.
    preprocess("../data/raw/refcocog/", "../data/processed/", BboxType.XYWH)

    # Each file is read inside a context manager so its handle is closed promptly.
    with open("../data/processed/train_samples.json", "r") as train_file:
        train: List[Sample] = json.load(train_file, object_hook=Sample.fromJSON)
    with open("../data/processed/val_samples.json", "r") as val_file:
        val: List[Sample] = json.load(val_file, object_hook=Sample.fromJSON)
    with open("../data/processed/test_samples.json", "r") as test_file:
        test: List[Sample] = json.load(test_file, object_hook=Sample.fromJSON)
    print(len(train), len(val), len(test))
    print(train[0].image_path, train[0].caption, train[0].bounding_box.shape)

In [None]:
def custom_collate(
    batch: List[Tuple[BatchSample, torch.Tensor]]
) -> Tuple[List[BatchSample], torch.Tensor]:
    """Collate ``(sample, bbox)`` pairs into a list of fresh BatchSamples and one stacked box tensor."""
    samples: List[BatchSample] = [
        BatchSample(item.image, item.caption) for item, _ in batch
    ]
    bboxes: List[torch.Tensor] = [target for _, target in batch]
    return samples, torch.stack(bboxes)


# Transform image according to CLIP preprocess function
# Normalize bounding box coordinates to be independent of image size
def transform_sample(
    image: Image.Image,
    box: Tensor,
    augment: bool,
    target_size: int = 224,
) -> Tuple[Tensor, Tensor]:
    """Resize/normalise an image and transform its bounding box accordingly.

    Args:
        image: PIL image; converted to RGB when it has another mode.
        box: boxes passed to albumentations as ``box.tolist()`` with
            ``format="coco"`` — assumes a 2-D tensor of (x, y, w, h) rows;
            TODO confirm for datasets built with another ``BboxType``.
        augment: when True, colour jitter, blur, pixel dropout and rotation
            are appended to the base pipeline.
        target_size: side length of the square output image.

    Returns:
        The image tensor produced by ``ToTensorV2`` and the first transformed
        box divided by ``target_size``, as float32.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Same transformation as in the CLIP preprocess function
    # NOTE(review): Resize already outputs target_size x target_size, so the
    # CenterCrop of the same size should leave the image unchanged.
    if augment:
        # NOTE(review): the colour/blur/dropout augmentations run AFTER
        # Normalize, i.e. on mean/std-normalised values — confirm this order
        # is intended.
        trans = A.Compose(
            transforms=[
                A.Resize(target_size, target_size, interpolation=cv2.INTER_CUBIC, p=1),
                A.CenterCrop(
                    target_size,
                    target_size,
                    always_apply=True,
                ),
                A.Normalize(
                    mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711),
                    max_pixel_value=255.0,
                    always_apply=True,
                ),
                A.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0),
                A.GaussianBlur(p=1),
                A.PixelDropout(dropout_prob=0.02),
                A.Rotate(limit=20),
                ToTensorV2(),
            ],
            bbox_params=A.BboxParams(format="coco", label_fields=[]),
        )
    else:
        trans = A.Compose(
            transforms=[
                A.Resize(target_size, target_size, interpolation=cv2.INTER_CUBIC, p=1),
                A.CenterCrop(
                    target_size,
                    target_size,
                    always_apply=True,
                ),
                A.Normalize(
                    mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711),
                    max_pixel_value=255.0,
                ),
                ToTensorV2(),
            ],
            bbox_params=A.BboxParams(format="coco", label_fields=[]),
        )

    transformed_sample: Dict[str, Any] = trans(
        image=np.array(image), bboxes=box.tolist()
    )

    # Only the first box is kept; an empty "bboxes" list would raise IndexError
    # here (presumably possible if a transform discards the box — TODO confirm).
    # Dividing by target_size makes the coordinates relative to the output image.
    bbox_tensor: Tensor = torch.tensor(transformed_sample["bboxes"][0]) / target_size
    # print(bbox_tensor)
    return transformed_sample["image"], bbox_tensor.to(torch.float32)


def init_torch(seed: int = 41) -> None:
    """Seed the torch (CPU and CUDA), ``random`` and NumPy RNGs and request deterministic cuDNN."""
    torch.backends.cudnn.deterministic = True
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
        np.random.seed,
    )
    for seed_rng in seeders:
        seed_rng(seed)

In [None]:
# CLIP transformer encoder
class TextEncoder(nn.Module):
    """Frozen CLIP (RN50) text encoder that also exposes the per-token transformer output.

    A forward hook on ``pretrained_model.transformer`` stores the transformer's
    output on every call; ``forward`` returns it together with the result of
    ``encode_text``.

    Args:
        batch_size, clip_ctx_length, embed_dim: only used to shape the
            placeholder ``transformer_output`` tensor that exists until the
            first forward pass overwrites it.
    """

    def __init__(self, batch_size: int, clip_ctx_length: int, embed_dim: int) -> None:
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pretrained_model: CLIP = clip.load("RN50", device=self.device)[0]
        # Cast the CLIP weights to float32.
        self.pretrained_model.float()

        # Freeze the backbone
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        self.pretrained_model.transformer.register_forward_hook(self.hook_fn())
        # Uninitialised placeholder; replaced by the hook on the first forward pass.
        self.transformer_output: Tensor = torch.empty(
            (batch_size, clip_ctx_length, embed_dim),
            requires_grad=True,
            device=self.device,
        )

    # @torch.no_grad()
    def forward(self, tokenized_caption: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(token_sequence, text_features)`` for a batch of tokenized captions.

        ``token_sequence`` is the hooked transformer output (B L D after the
        permute in the hook); ``text_features`` is what ``encode_text`` returns.
        """
        # encode_text runs the transformer, which triggers the hook and
        # refreshes self.transformer_output as a side effect.
        out: Tensor = self.pretrained_model.encode_text(tokenized_caption).to(
            self.device
        )
        # .unsqueeze(1)
        return (
            self.transformer_output,
            out,
        )

    def hook_fn(self) -> Callable[[nn.Module, Tensor, Tensor], None]:
        """Build the forward hook that captures the transformer output."""

        def hook(module: nn.Module, input: Tensor, output: Tensor) -> None:
            self.transformer_output = output.permute(1, 0, 2)  # L B D -> B L D

        return hook

In [None]:
# Class that gets the output of all layers of the backbone
# CLIP backbone is a modified ResNet with an attention layer for global pooling
class VisualEncoder(nn.Module):
    """Frozen CLIP RN50 visual backbone exposing the outputs of ``layer1``..``layer4``.

    Forward hooks on the four residual stages write each stage's output into
    ``layers_outputs``; ``forward`` runs the backbone and returns that dict.
    """

    def __init__(self) -> None:
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pretrained_model: ModifiedResNet = clip.load("RN50", device=self.device)[
            0
        ].visual  # type: ignore
        # Cast the CLIP weights to float32.
        self.pretrained_model.float()
        assert isinstance(self.pretrained_model, ModifiedResNet)

        # Freeze the backbone
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        # Register hooks to get the output of all layers
        self.layers_outputs: OrderedDict[str, Tensor] = OrderedDict()
        self.pretrained_model.layer1.register_forward_hook(self.hook_fn("layer1"))  # type: ignore
        self.pretrained_model.layer2.register_forward_hook(self.hook_fn("layer2"))  # type: ignore
        self.pretrained_model.layer3.register_forward_hook(self.hook_fn("layer3"))  # type: ignore
        self.pretrained_model.layer4.register_forward_hook(self.hook_fn("layer4"))  # type: ignore

    # @torch.no_grad()
    def forward(self, batch: Tensor) -> OrderedDict[str, Tensor]:
        """Run the backbone on ``batch`` and return the per-layer outputs.

        The returned object is the same dict on every call; its entries are
        overwritten by the hooks each time the backbone runs.
        """
        # The pooled output itself is unused; the call is needed to fire the hooks.
        out: Tensor = self.pretrained_model(batch)
        return self.layers_outputs

    def hook_fn(self, layer: str) -> Callable[[nn.Module, Tensor, Tensor], None]:
        """Build a forward hook that stores a stage's output under the key ``layer``."""

        def hook(module: nn.Module, input: Tensor, output: Tensor) -> None:
            # print(f"Module: {[module for  module in module.modules()]}")
            self.layers_outputs[layer] = output

        return hook


In [None]:
class FusionModule(nn.Module):
    """Fuses multi-scale visual features with the global text feature.

    The text feature gates the projected ``layer4`` map (element-wise product),
    the result is merged top-down with ``layer3`` and pooled ``layer2``, the
    three fused maps are aggregated by a 1x1 conv, and coordinate channels are
    added by ``CoordConv``.

    Args:
        emb_dim: channel width of the fused output.
        clip_emb_dim: channel width of the text feature and of the ``layer3``
            map; ``layer4`` is expected to have twice as many channels.
        proj_img_size: accepted but not used by this module.

    The shape comments in ``forward`` assume clip_emb_dim=1024, emb_dim=512 and
    a 224x224 input — TODO confirm against config.json.
    """

    def __init__(self, emb_dim: int, clip_emb_dim: int, proj_img_size: int) -> None:
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.text_projection = nn.Sequential(
            nn.Linear(in_features=clip_emb_dim, out_features=clip_emb_dim),
            nn.BatchNorm1d(clip_emb_dim, device=self.device),
        ).to(self.device)

        # Slicing the Sequential keeps Conv2d + BatchNorm2d only; the activation
        # is applied after the text gating, in norm_layer.
        self.vis_l4_projection = _conv_layer(
            input_dim=clip_emb_dim * 2,
            output_dim=clip_emb_dim,
            kernel_size=3,
            padding=1,
            device=self.device,
        )[
            :2
        ]  # Remove ReLU
        self.norm_layer = nn.Sequential(
            nn.BatchNorm2d(
                clip_emb_dim,
                device=self.device,
            ),
            nn.ReLU(),
        ).to(self.device)
        # Input: layer3 features concatenated with the upsampled fused layer4 map.
        self.vis_l3_projection = _conv_layer(
            input_dim=clip_emb_dim + clip_emb_dim,
            output_dim=emb_dim,
            kernel_size=3,
            padding=1,
            device=self.device,
        )
        # Input: fused layer3 map concatenated with the pooled layer2 features.
        self.vis_l2_projection = _conv_layer(
            input_dim=emb_dim + emb_dim,
            output_dim=emb_dim,
            kernel_size=3,
            padding=1,
            device=self.device,
        )
        # 1x1 conv over the concatenation of the three fused maps.
        self.aggregation = _conv_layer(
            input_dim=clip_emb_dim + emb_dim + emb_dim,
            output_dim=emb_dim,
            kernel_size=1,
            padding=0,
            device=self.device,
        )

        # CoordConv appends 2 coordinate channels before its conv, hence emb_dim + 2.
        self.coord_conv = nn.Sequential(
            CoordConv(emb_dim + 2, emb_dim),
            _conv_layer(
                input_dim=emb_dim,
                output_dim=emb_dim,
                kernel_size=3,
                padding=1,
                device=self.device,
            ),
        )

    def forward(
        self, visual_features: OrderedDict[str, Tensor], text_features: Tensor
    ) -> Tensor:
        """Return the fused feature map.

        Args:
            visual_features: mapping with ``"layer2"``, ``"layer3"`` and
                ``"layer4"`` feature maps.
            text_features: global text feature, one vector per batch element.
        """
        visual_l2_features, visual_l3_features, visual_l4_features = (
            visual_features["layer2"],
            visual_features["layer3"],
            visual_features["layer4"],
        )
        # Visual and text features projection
        text_features_proj: Tensor = (
            self.text_projection(text_features).unsqueeze(-1).unsqueeze(-1)
        )  # B 1024 1 1
        visual_l4_features_proj: Tensor = self.vis_l4_projection(
            visual_l4_features
        )  # B 1024 7 7

        # First fusion l4 (B 1024 7 7) and text (B 1024)
        fused_l4: Tensor = self.norm_layer(
            visual_l4_features_proj * text_features_proj
        )  # B 1024 7 7

        # Second fusion l3 (B 1024 14 14) and l4 (B 1024 7 7)
        fused_l4_upsample: Tensor = nn.Upsample(scale_factor=2, mode="nearest")(
            fused_l4
        )  # B 1024 14 14
        cat_features: Tensor = torch.cat([visual_l3_features, fused_l4_upsample], dim=1)
        fused_l3: Tensor = self.vis_l3_projection(cat_features)  # B 512 14 14

        # Third fusion l2 (B 512 28 28) and l3 (B 512 14 14)
        visual_l2_pooling = nn.MaxPool2d(kernel_size=2, stride=2)(
            visual_l2_features
        )  # B 512 14 14
        fused_l2: Tensor = self.vis_l2_projection(
            torch.cat([fused_l3, visual_l2_pooling], dim=1)
        )  # B 512 14 14

        # Aggregate features
        cat_visual_features: Tensor = torch.cat(
            [fused_l2, fused_l3, fused_l4_upsample], dim=1
        )  # B 2048 14 14
        aggregated_features: Tensor = self.aggregation(
            cat_visual_features
        )  # B 512 14 14

        # Add coordinate features
        final_features: Tensor = self.coord_conv(aggregated_features)  # B 512 14 14
        return final_features


class CoordConv(nn.Module):
    """Convolution block that first appends normalised x/y coordinate channels.

    Args:
        in_channels: channel count seen by the conv, i.e. the input channels
            plus the 2 coordinate channels added by ``add_coord``.
        out_channels: channel count of the output map.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.conv: nn.Sequential = _conv_layer(
            input_dim=in_channels,
            output_dim=out_channels,
            kernel_size=3,
            padding=1,
            device=self.device,
        )

    def add_coord(self, input: Tensor) -> Tensor:
        """Concatenate two channels holding x and y positions in ``[-1, 1]``.

        For a ``(B, C, H, W)`` input the result is ``(B, C + 2, H, W)``: channel
        ``C`` varies along the width, channel ``C + 1`` along the height.
        """
        b, _, h, w = input.size()
        # Build the grid on the input's own device so the concatenation below
        # works wherever the input tensor lives.
        x_range = torch.linspace(-1, 1, w, device=input.device)
        y_range = torch.linspace(-1, 1, h, device=input.device)

        # Explicit "ij" indexing (rows = y, columns = x) is the layout this
        # code relies on; passing it avoids the meshgrid deprecation warning.
        y, x = torch.meshgrid(y_range, x_range, indexing="ij")
        y = y.expand([b, 1, -1, -1])
        x = x.expand([b, 1, -1, -1])
        coord_feat = torch.cat([x, y], 1)
        return torch.cat([input, coord_feat], 1)

    def forward(self, x: Tensor) -> Tensor:
        """Append the coordinate channels, then apply the conv block."""
        return self.conv(self.add_coord(x))


def _conv_layer(
    input_dim: int,
    output_dim: int,
    kernel_size: int,
    padding: int,
    device: device,
) -> nn.Sequential:
    module = nn.Sequential(
        nn.Conv2d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=kernel_size,
            padding=padding,
            device=device,
        ),
        nn.BatchNorm2d(output_dim, device=device),
        nn.ReLU(),
    )
    nn.init.xavier_uniform_(module[0].weight)
    return module


In [None]:

class Decoder(nn.Module):
    """Transformer decoder that reads a learnable regression token out of the fused features.

    The visual map is flattened to a token sequence, prefixed with the
    regression token and used as the decoder target; the text sequence is the
    decoder memory. The output at the regression-token position is returned.

    Args:
        d_model: token width.
        img_size: side length used for the 2-D positional encoding.
        clip_ctx_length: maximum text length used for the 1-D positional encoding.
        nheads: attention heads per layer.
        nlayers: number of decoder layers.
        dim_feedforward: feed-forward width inside each layer.

    Note:
        ``reg_token`` is created directly on the target device. Wrapping an
        ``nn.Parameter`` in ``.to(device)`` returns a plain tensor whenever a
        copy is made (e.g. on CUDA), which would leave the token unregistered:
        absent from ``parameters()`` and ``state_dict()`` and never optimised.
        Checkpoints saved from such an unregistered token lack the
        ``reg_token`` key.
    """

    def __init__(
        self,
        d_model: int,
        img_size: int,
        clip_ctx_length: int,
        nheads: int,
        nlayers: int,
        dim_feedforward: int,
    ) -> None:
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.d_model = d_model
        self.pos_embedding_1d = PositionalEncoding1D(d_model, clip_ctx_length).to(
            self.device
        )
        self.pos_embedding_2d = PositionalEncoding2D(d_model, img_size, img_size).to(
            self.device
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(
                d_model=d_model,
                nhead=nheads,
                dim_feedforward=dim_feedforward,
                batch_first=True,
                norm_first=True,
                device=self.device,
            ),
            num_layers=nlayers,
            norm=nn.LayerNorm(d_model, device=self.device),
        )
        # Allocate on the device first, then wrap, so the attribute stays a
        # registered nn.Parameter.
        self.reg_token = nn.Parameter(
            torch.randn((1, 1, d_model), device=self.device)
        )
        nn.init.kaiming_normal_(self.reg_token, nonlinearity="relu", mode="fan_out")

    def forward(self, vis: Tensor, text: Tensor) -> Tensor:
        """Return the decoded regression token, shape ``(B, d_model)``.

        Args:
            vis: fused visual map, ``(B, d_model, H, W)``.
            text: text token sequence, ``(B, L, d_model)``.
        """
        text_features: Tensor = self.pos_embedding_1d(text)

        visual_features: Tensor = self.pos_embedding_2d(vis)

        visual_features = visual_features.flatten(2).permute(0, 2, 1)  # B HW D

        # Prepend one copy of the regression token per batch element.
        visual_features = torch.cat(
            [self.reg_token.expand((vis.shape[0], -1, -1)), visual_features], dim=1
        )
        # Visual tokens are the target sequence, text tokens the memory.
        x: Tensor = self.decoder(visual_features, text_features)
        reg_token: Tensor = x[:, 0, :]
        return reg_token


# Positional encodings implemented in separate classes if we want to change them and use learnable positional encodings instead
# Dropout added following the original transformer implementation
# https://github.com/wzlxjtu/PositionalEncoding2D
class PositionalEncoding1D(nn.Module):
    """Fixed sinusoidal positional encoding for ``(B, L, d_model)`` token sequences.

    The table is stored as the buffer ``text_pos_encoding`` and ``forward``
    reads that buffer, so the encoding follows the module through ``.to()``,
    dtype casts and ``load_state_dict``. ``pos_encoding`` remains available as
    a read-only alias of the buffer.

    Args:
        d_model: token width (must be even for the sin/cos interleaving).
        window_len: maximum sequence length supported.
    """

    def __init__(self, d_model: int, window_len: int) -> None:
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.dropout = nn.Dropout(0.1).to(self.device)
        table = torch.zeros(window_len, d_model, device=self.device)
        position = torch.arange(0, window_len, device=self.device).unsqueeze(1)
        div_term = torch.exp(
            (
                torch.arange(0, d_model, 2, dtype=torch.float, device=self.device)
                * -(math.log(10000.0) / d_model)
            )
        )
        # Even channels carry the sine, odd channels the cosine.
        table[:, 0::2] = torch.sin(position.float() * div_term)
        table[:, 1::2] = torch.cos(position.float() * div_term)

        self.register_buffer("text_pos_encoding", table)

    @property
    def pos_encoding(self) -> Tensor:
        """Read-only alias of the ``text_pos_encoding`` buffer."""
        return self.text_pos_encoding

    def forward(self, token_embedding: Tensor) -> Tensor:
        """Add the first ``L`` positional rows to ``token_embedding`` and apply dropout."""
        seq_len = token_embedding.size(1)
        return self.dropout(token_embedding + self.text_pos_encoding[:seq_len, :])


# The first half of the channels encodes the column (width) position and the
# second half the row (height) position.
class PositionalEncoding2D(nn.Module):
    """Fixed 2-D sinusoidal positional encoding for ``(B, d_model, H, W)`` feature maps.

    The table is stored as the buffer ``visual_pos_encoding`` and ``forward``
    reads that buffer, so the encoding follows the module through ``.to()``,
    dtype casts and ``load_state_dict``. ``pe`` remains available as a
    read-only alias of the buffer.

    Args:
        d_model: channel count of the feature map.
        width: maximum map width supported.
        height: maximum map height supported.
    """

    def __init__(self, d_model: int, width: int, height: int) -> None:
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.dropout = nn.Dropout(0.1).to(self.device)
        table = torch.zeros(d_model, height, width, device=self.device)
        # Each dimension use half of d_model
        d_model = int(d_model / 2)
        div_term = torch.exp(
            torch.arange(0.0, d_model, 2, device=self.device)
            * -(math.log(10000.0) / d_model)
        )
        pos_w = torch.arange(0.0, width, device=self.device).unsqueeze(1)
        pos_h = torch.arange(0.0, height, device=self.device).unsqueeze(1)
        # Column encodings: (width, n_freq) -> (n_freq, 1, width), repeated over rows.
        table[0:d_model:2, :, :] = (
            torch.sin(pos_w * div_term)
            .transpose(0, 1)
            .unsqueeze(1)
            .repeat(1, height, 1)
        )
        table[1:d_model:2, :, :] = (
            torch.cos(pos_w * div_term)
            .transpose(0, 1)
            .unsqueeze(1)
            .repeat(1, height, 1)
        )
        # Row encodings: (height, n_freq) -> (n_freq, height, 1), repeated over columns.
        table[d_model::2, :, :] = (
            torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
        )
        table[d_model + 1 :: 2, :, :] = (
            torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
        )

        self.register_buffer("visual_pos_encoding", table)

    @property
    def pe(self) -> Tensor:
        """Read-only alias of the ``visual_pos_encoding`` buffer."""
        return self.visual_pos_encoding

    def forward(self, x):
        """Add the encoding to ``x`` (``B, d_model, H, W``) and apply dropout.

        The table is cropped to the input's spatial size, so maps up to the
        configured ``height`` x ``width`` are accepted.
        """
        encoding = self.visual_pos_encoding[:, : x.size(2), : x.size(3)]
        return self.dropout(x + encoding)


In [None]:
class VGModel(nn.Module):
    """End-to-end visual grounding model: encoders -> fusion -> decoder -> regression head.

    Given a list of ``BatchSample`` (image + tokenized caption), ``forward``
    returns one 4-value prediction per sample from the MLP head.
    """

    def __init__(
        self,
        cfg: Config,
    ) -> None:
        super().__init__()
        self.cfg: Config = cfg
        embed_dim: int = cfg.model.embed_dim
        mlp_hidden_dim: int = cfg.model.mlp_hidden_dim

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Frozen CLIP backbones (each loads its own copy of the RN50 checkpoint).
        self.visual_encoder = VisualEncoder()

        self.text_encoder = TextEncoder(
            cfg.train.batch_size, cfg.model.clip_ctx_length, embed_dim
        )

        self.fusion_module: FusionModule = FusionModule(
            embed_dim, cfg.model.clip_embed_dim, cfg.model.proj_img_size
        ).to(self.device)

        self.decoder: Decoder = Decoder(
            embed_dim,
            cfg.model.proj_img_size,
            cfg.model.clip_ctx_length,
            cfg.model.decoder_heads,
            cfg.model.decoder_layers,
            cfg.model.decoder_dim_feedforward,
        ).to(self.device)

        # Maps the decoded regression token to the 4 box values.
        self.reg_head: MLP = MLP(
            input_dim=embed_dim, output_dim=4, hidden_dim_1=mlp_hidden_dim
        ).to(self.device)

    def forward(self, batch: List[BatchSample]) -> Tensor:
        """Predict one box (4 values) per sample in ``batch``."""
        # Get text features
        # Captions are stacked and their singleton dim 1 is squeezed out before
        # encoding — assumes each caption is shaped (1, L); TODO confirm.
        text_sequence, global_text_features = self.text_encoder(
            torch.stack([sample.caption for sample in batch]).squeeze(1).to(self.device)
        )

        # Get image features
        visual_features: OrderedDict[str, Tensor] = self.visual_encoder(
            torch.stack([sample.image for sample in batch]).to(self.device)
        )

        # Fuse features
        fused_visual_features: Tensor = self.fusion_module(
            visual_features, global_text_features
        )

        # Transformer decoder
        reg_token: Tensor = self.decoder(fused_visual_features, text_sequence)

        # Regression head
        out: Tensor = self.reg_head(reg_token)
        return out


class MLP(nn.Module):
    """Small perceptron: Linear -> BatchNorm1d -> ReLU -> Linear."""

    def __init__(self, input_dim: int, output_dim: int, hidden_dim_1: int) -> None:
        """Create the layer stack.

        Args:
            input_dim: Number of input features.
            output_dim: Number of output features.
            hidden_dim_1: Width of the single hidden layer.
        """
        super().__init__()
        # Layers are created in execution order and stored under `mlp`, so
        # parameter names are `mlp.0.*`, `mlp.1.*` and `mlp.3.*`.
        stages = (
            nn.Linear(input_dim, hidden_dim_1),
            nn.BatchNorm1d(hidden_dim_1),
            nn.ReLU(),
            nn.Linear(hidden_dim_1, output_dim),
        )
        self.mlp = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the layer stack to ``x`` and return the result."""
        projected: Tensor = self.mlp(x)
        return projected

In [None]:
class Loss:
    """Weighted sum of a Smooth-L1 term and a generalized-IoU term.

    The most recently computed value is kept in ``self.loss`` so it can be
    read back as a Python float with ``to_float``.
    """

    def __init__(self, l1: float, l2: float) -> None:
        """Store the two weights and the two loss callables.

        Args:
            l1: Weight applied to the Smooth-L1 term.
            l2: Weight applied to the generalized-IoU term.
        """
        self.l1: float = l1
        self.l2: float = l2
        self.l1_loss = nn.SmoothL1Loss(reduction="mean")
        self.giou_loss = generalized_box_iou_loss
        # Declared only; it is assigned by the first call to `compute`.
        self.loss: Tensor

    # Both bounding box tensors are in xyxy format
    def compute(self, prediction: Tensor, gt_bbox: Tensor) -> Tensor:
        """Compute, remember and return the weighted loss.

        Args:
            prediction: Predicted boxes.
            gt_bbox: Ground-truth boxes.

        Returns:
            ``l1 * smooth_l1 + l2 * giou``, both terms mean-reduced.
        """
        smooth_l1_term = self.l1 * self.l1_loss(gt_bbox, prediction)
        giou_term = self.l2 * self.giou_loss(gt_bbox, prediction, reduction="mean")
        self.loss = smooth_l1_term + giou_term
        return self.loss

    def to_float(self) -> float:
        """Return the last value produced by ``compute`` as a Python float."""
        last_value: Tensor = self.loss
        return last_value.item()

In [None]:
def plot_grad_flow(named_parameters: Iterable[Tuple[str, nn.Parameter]]) -> None:
    """Plots the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.

    Usage: Plug this function in Trainer class after loss.backwards() as
    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow

    Args:
        named_parameters: ``(name, parameter)`` pairs, e.g. the result of
            ``model.named_parameters()``. Any iterable is accepted.

    Only parameters with ``requires_grad`` set and without "bias" in their
    name are plotted; a parameter whose ``grad`` is ``None`` is reported with
    a printed message and left out of the chart.
    """
    # The annotation uses `Iterable` because the file's typing import does
    # not provide `Iterator`, which made this definition raise NameError.
    ave_grads: List[float] = []
    max_grads: List[float] = []
    layers: List[str] = []
    # `name` / `param` instead of `n` / `p`: `p` is also the alias of the
    # `pkg_resources` import at the top of the file.
    for name, param in named_parameters:
        if not param.requires_grad or "bias" in name:
            continue
        if param.grad is None:
            print(f"None gradient for {name}")
            continue
        grad_magnitude = param.grad.abs()
        layers.append(name)
        ave_grads.append(grad_magnitude.mean().item())
        max_grads.append(grad_magnitude.max().item())

    positions = np.arange(len(max_grads))
    plt.bar(positions, max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(positions, ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=0.02)  # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend(
        [
            Line2D([0], [0], color="c", lw=4),
            Line2D([0], [0], color="b", lw=4),
            Line2D([0], [0], color="k", lw=4),
        ],
        ["max-gradient", "mean-gradient", "zero-gradient"],
    )
    plt.tight_layout()
    plt.show()

In [None]:
def train_one_epoch(
    epoch: int,
    dataloader: DataLoader[Tuple[BatchSample, Tensor]],
    model: VGModel,
    loss: Loss,
    optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler.OneCycleLR,
    device: device,
    cfg: Config,
    plot_gradients: bool = False,
) -> Dict[str, float]:
    """Train ``model`` for one pass over ``dataloader`` and return mean metrics.

    Args:
        epoch: Index of the current epoch (not used inside this function).
        dataloader: Yields ``(batch, bbox)`` pairs; ``bbox`` holds the
            ground-truth boxes in ``xywh`` format.
        model: Network to optimise; it is switched to training mode.
        loss: Loss object whose ``compute`` receives boxes in ``xyxy`` format.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        device: Device the samples and boxes are moved to.
        cfg: Project configuration (not used inside this function).
        plot_gradients: When true, call ``plot_grad_flow`` after every
            backward pass. Off by default because it opens one figure per
            batch.

    Returns:
        Mapping from each ``Metric`` value to its mean over all batches.
    """
    model.train()
    loss_list: List[Tensor] = []
    iou_list: List[Tensor] = []
    acc_50: List[Tensor] = []
    acc_75: List[Tensor] = []
    acc_90: List[Tensor] = []

    for idx, (batch, bbox) in enumerate(tqdm(dataloader, desc="Batches")):
        optimizer.zero_grad()
        # Move to gpu. Rebinding the loop variable discarded the moved
        # sample, so the batch is rebuilt instead.
        # NOTE(review): assumes `sample.to` returns the moved sample - confirm.
        batch = [sample.to(device) for sample in batch]
        bbox = bbox.to(device)

        # Forward pass
        out: Tensor = model(batch)

        # Loss is computed on xyxy boxes
        out_xyxy = box_convert(out, in_fmt="xywh", out_fmt="xyxy")
        bbox_xyxy = box_convert(bbox, in_fmt="xywh", out_fmt="xyxy")
        batch_loss: Tensor = loss.compute(out_xyxy, bbox_xyxy)

        # Backward pass
        batch_loss.backward()
        if plot_gradients:
            plot_grad_flow(model.named_parameters())
        optimizer.step()
        scheduler.step()

        # Metrics use detached tensors; the diagonal of the pairwise IoU
        # matrix pairs each prediction with its own ground-truth box.
        out_xyxy = out_xyxy.detach()
        bbox_xyxy = bbox_xyxy.detach()
        batch_iou: Tensor = torch.diagonal(box_iou(out_xyxy, bbox_xyxy))
        detached_loss: Tensor = batch_loss.detach()
        mean_iou: Tensor = batch_iou.mean()

        loss_list.append(detached_loss)
        iou_list.append(mean_iou)
        acc_50.append(accuracy(batch_iou, 0.5))
        acc_75.append(accuracy(batch_iou, 0.75))
        acc_90.append(accuracy(batch_iou, 0.9))

        # Periodic progress report, based on the number of samples seen.
        if (idx * len(batch)) % 4096 == 0:
            report: Dict[str, float] = {
                "Train loss": detached_loss.item(),
                "Train accurracy": mean_iou.item(),
            }
            pprint(f"Batches: {idx}, {report}")

    return {
        Metric.LOSS.value: torch.stack(loss_list).mean().item(),
        Metric.IOU.value: torch.stack(iou_list).mean().item(),
        Metric.ACCURACY_50.value: torch.stack(acc_50).mean().item(),
        Metric.ACCURACY_75.value: torch.stack(acc_75).mean().item(),
        Metric.ACCURACY_90.value: torch.stack(acc_90).mean().item(),
    }


@torch.no_grad()
def validate(
    dataloader: DataLoader[Tuple[BatchSample, Tensor]],
    model: VGModel,
    loss: Loss,
    device: torch.device,
) -> Dict[str, float]:
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Args:
        dataloader: Yields ``(batch, bbox)`` pairs; ``bbox`` holds the
            ground-truth boxes in ``xywh`` format.
        model: Network to evaluate; it is switched to evaluation mode.
        loss: Loss object whose ``compute`` receives boxes in ``xyxy`` format.
        device: Device the samples and boxes are moved to.

    Returns:
        Mapping from each ``Metric`` value to its mean over all batches.
    """
    # As accuracy we take the average IoU
    model.eval()
    loss_list: List[Tensor] = []
    iou_list: List[Tensor] = []
    acc_50: List[Tensor] = []
    acc_75: List[Tensor] = []
    acc_90: List[Tensor] = []

    for batch, bbox in tqdm(dataloader, desc="Batches"):
        # Move to gpu. The result of `sample.to(device)` used to be thrown
        # away, so the batch is rebuilt from the moved samples.
        # NOTE(review): assumes `sample.to` returns the moved sample - confirm.
        batch = [sample.to(device) for sample in batch]
        bbox = bbox.to(device)

        # Forward pass
        out: Tensor = model(batch)

        # No `.detach()` needed: the whole function runs under `no_grad`.
        out = box_convert(out, in_fmt="xywh", out_fmt="xyxy")
        bbox = box_convert(bbox, in_fmt="xywh", out_fmt="xyxy")

        batch_loss: Tensor = loss.compute(out, bbox)
        # Diagonal of the pairwise IoU matrix: prediction i vs. ground truth i.
        batch_iou: Tensor = torch.diagonal(box_iou(out, bbox))

        loss_list.append(batch_loss)
        iou_list.append(batch_iou.mean())
        acc_50.append(accuracy(batch_iou, 0.5))
        acc_75.append(accuracy(batch_iou, 0.75))
        acc_90.append(accuracy(batch_iou, 0.9))

    return {
        Metric.LOSS.value: torch.stack(loss_list).mean().item(),
        Metric.IOU.value: torch.stack(iou_list).mean().item(),
        Metric.ACCURACY_50.value: torch.stack(acc_50).mean().item(),
        Metric.ACCURACY_75.value: torch.stack(acc_75).mean().item(),
        Metric.ACCURACY_90.value: torch.stack(acc_90).mean().item(),
    }


def accuracy(iou: Tensor, threshold: float) -> Tensor:
    """Return the share of IoU values that reach ``threshold``.

    Args:
        iou: IoU values, one per sample.
        threshold: Minimum IoU for a sample to count as a hit.

    Returns:
        A new scalar tensor holding ``hits / len(iou)``.
    """
    hits: int = int((iou >= threshold).sum().item())
    return torch.tensor(hits / len(iou))


In [None]:
def train(
    train_dataloader: DataLoader[Tuple[BatchSample, Tensor]],
    val_dataloader: DataLoader[Tuple[BatchSample, Tensor]],
    device: torch.device,
    cfg: Config,
) -> Tuple[MetricsLogger, MetricsLogger]:
    """Build the model and run the full train / validate loop.

    For every epoch the model is trained on ``train_dataloader``, evaluated
    on ``val_dataloader``, optionally logged to wandb and optionally saved
    as a checkpoint under ``cfg.logging.path``.

    Args:
        train_dataloader: Batches used for optimisation.
        val_dataloader: Batches used for validation after each epoch.
        device: Device handed to the per-epoch train / validate functions.
        cfg: Project configuration (learning rate, epochs, logging flags...).

    Returns:
        The training and validation ``MetricsLogger`` objects, in that order.
    """
    train_metrics: MetricsLogger = MetricsLogger()
    val_metrics: MetricsLogger = MetricsLogger()

    # Loss is the weighted sum of the smooth l1 loss and the GIoU
    loss_func = Loss(cfg.train.l1, cfg.train.l2)

    model: VGModel = VGModel(cfg).train()

    # Only parameters that require gradients are handed to the optimizer.
    params: List[nn.Parameter] = [
        param for param in model.parameters() if param.requires_grad
    ]

    optimizer = optim.AdamW(
        params=[
            {"params": params, "lr": cfg.train.lr, "weight_decay": 1e-4},
        ]
    )
    lr_scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=[cfg.train.lr],
        epochs=cfg.epochs,
        steps_per_epoch=len(train_dataloader),
    )

    if cfg.logging.wandb:
        wandb.watch(model, loss_func, log="all", log_freq=100, log_graph=True)

    # wandb panel title -> metric shown in it (train and val side by side).
    logged_metrics = (
        ("Loss", Metric.LOSS),
        ("Average IOU", Metric.IOU),
        ("Accuracy@50", Metric.ACCURACY_50),
        ("Accuracy@75", Metric.ACCURACY_75),
        ("Accuracy@90", Metric.ACCURACY_90),
    )

    for epoch in tqdm(range(cfg.epochs), desc="Epochs"):
        print("-------------------- Training --------------------------")
        epoch_train_metrics: Dict[str, float] = train_one_epoch(
            epoch=epoch,
            dataloader=train_dataloader,
            model=model,
            loss=loss_func,
            optimizer=optimizer,
            scheduler=lr_scheduler,
            device=device,
            cfg=cfg,
        )
        train_metrics.update_metric(epoch_train_metrics)
        print("Training metrics at epoch ", epoch)
        print(epoch_train_metrics)

        # Evaluate on validation set for hyperparameter tuning
        print("-------------------- Validation ------------------------")
        epoch_val_metrics: Dict[str, float] = validate(
            val_dataloader, model, loss_func, device
        )
        val_metrics.update_metric(epoch_val_metrics)
        print("Validation metrics at epoch ", epoch)
        print(epoch_val_metrics)

        # Log metrics to wandb putting train and val metrics together
        if cfg.logging.wandb:
            wandb.log(
                {
                    title: {
                        "train": epoch_train_metrics[metric.value],
                        "val": epoch_val_metrics[metric.value],
                    }
                    for title, metric in logged_metrics
                },
                commit=True,
            )

        # Save model after each epoch
        if cfg.logging.save:
            checkpoint_dir: str = cfg.logging.path
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(
                obj={
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "lr_scheduler_state_dict": lr_scheduler.state_dict(),
                    "loss": epoch_train_metrics[Metric.LOSS.value],
                },
                # os.path.join keeps the file inside the directory even when
                # the configured path has no trailing separator.
                f=os.path.join(checkpoint_dir, f"model{epoch}.pth"),
            )

        # Collect garbage first so that memory held by unreachable tensors is
        # returned to the caching allocator before the cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()

    return train_metrics, val_metrics


def initialize_run(sweep: bool = True) -> None:
    """Set up wandb, the datasets and the dataloaders, then train once.

    After training, the collected train / validation metrics are written to
    ``../train_metrics.json`` and ``../val_metrics.json``.

    Args:
        sweep: When true, the run is initialised as part of a wandb sweep and
            the configuration is updated from ``wandb.config``. Otherwise
            wandb is only initialised if ``config.logging.wandb`` is set.
    """
    config = Config()
    if sweep:
        wandb.login(key=os.getenv("WANDB_API_KEY"))
        wandb.init(project="vgproject")
        wandb_cfg = wandb.config
        config.update(wandb_cfg)
    elif config.logging.wandb:
        wandb.login(key=os.getenv("WANDB_API_KEY"))
        wandb.init(project="vgproject", config=config.as_dict())

    train_dataset: VGDataset = VGDataset(
        dir_path=config.dataset_path,
        split=Split.TRAIN,
        output_bbox_type=BboxType.XYWH,
        augment=True,
        preprocessed=True,
    )
    print("Train dataset created. Dataset length ", len(train_dataset))

    val_dataset: VGDataset = VGDataset(
        dir_path=config.dataset_path,
        split=Split.VAL,
        output_bbox_type=BboxType.XYWH,
        augment=False,
        preprocessed=True,
    )
    print("Validation dataset created. Dataset length: ", len(val_dataset))

    train_dataloader: DataLoader[Tuple[BatchSample, Tensor]] = DataLoader(
        dataset=train_dataset,
        batch_size=config.train.batch_size,
        collate_fn=custom_collate,
        num_workers=2,
        shuffle=True,
        drop_last=True,
    )

    val_dataloader: DataLoader[Tuple[BatchSample, Tensor]] = DataLoader(
        dataset=val_dataset,
        batch_size=config.train.batch_size,
        collate_fn=custom_collate,
        shuffle=True,
        drop_last=True,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_metrics, val_metrics = train(train_dataloader, val_dataloader, device, config)

    # Context managers make sure both files are flushed and closed.
    with open("../train_metrics.json", "w") as train_file:
        json.dump(train_metrics.metrics, train_file)
    with open("../val_metrics.json", "w") as val_file:
        json.dump(val_metrics.metrics, val_file)

    if config.logging.wandb:
        wandb.finish()


def main() -> None:
    """Entry point: start a wandb sweep or a single training run.

    When ``cfg.train.sweep`` is set, the sweep definition is read from
    ``../sweep_config.json`` and a wandb agent runs ``initialize_run`` up to
    10 times. Otherwise ``initialize_run`` is called once directly.
    """
    init_torch()
    cfg = Config()
    if cfg.train.sweep:
        # Context manager so the config file handle is closed after reading.
        with open("../sweep_config.json", "r") as sweep_file:
            sweep_configuration: Dict[str, Any] = json.load(sweep_file)
        sweep_id: str = wandb.sweep(sweep_configuration, project="vgproject")
        wandb.agent(sweep_id, function=initialize_run, count=10)
    else:
        initialize_run(cfg.train.sweep)

In [None]:
# Run the pipeline: `main` starts either a wandb sweep or a single training run.
main()

In [None]:
# Clean the environment: close the wandb run, then free memory.
wandb.finish()
# Collect garbage first so memory held by unreachable tensors goes back to the
# caching allocator, which `empty_cache` can then release.
gc.collect()
torch.cuda.empty_cache()