# Deep Learning Project

In [None]:
# Install the third-party packages this notebook asks pip for.
%pip install ftfy regex tqdm ultralytics wandb albumentations

In [None]:
import gc
import gzip
import html
import json
import math
import os
import pickle
import random
import re
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from pprint import pprint
from typing import Any, Callable, Dict, List, Optional, OrderedDict, Tuple, Union

import albumentations as A
import cv2
import ftfy
import gdown
import matplotlib.pyplot as plt
import numpy as np
import pkg_resources as p
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch import Tensor, device, tensor
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image
from torchvision.ops import box_convert, box_iou, generalized_box_iou_loss
from tqdm.notebook import tqdm

import wandb

In [None]:
# Work from /content/ so that the relative data/ paths below resolve there.
%cd /content/
# Download dataset and save under data/raw/ only if not already downloaded
url = "https://drive.google.com/uc?id=1xijq32XfEm6FPhUb7RsZYWHc2UuwVkiq"
if not os.path.exists("data/raw/refcocog.tar.gz"):
    print("Downloading dataset...")
    gdown.download(url=url, output="data/raw/", quiet=False, resume=True)
# Unpack the archive only when the extracted folder is not there yet.
if not os.path.exists("data/raw/refcocog/"):
    print("Extracting dataset...")
    !tar -xf data/raw/refcocog.tar.gz -C data/raw

The preprocessed samples can be downloaded from Google Drive by executing the following cell. Otherwise, the preprocessing will be done at run time, and the resulting files are saved only temporarily in the Colab environment.

In [None]:
# Download preprocessed dataset
# The whole Google Drive folder is fetched into data/.
url = "https://drive.google.com/drive/folders/1jaJV40dneOckZn7WHMQyd2jBh7A8534N"
gdown.download_folder(url=url, output="data/", quiet=False)

In [None]:
# Fetch config.json and sweep_config.json from the project repository,
# the CLIP RN50.pt checkpoint, and the gzipped BPE vocabulary file.
!wget https://raw.githubusercontent.com/ManuelaCorte/DLProject/decoder/config.json
!wget https://raw.githubusercontent.com/ManuelaCorte/DLProject/decoder/sweep_config.json
!wget https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt
!wget https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz

In [None]:
# Create src/ and move into it; later cells use paths relative to this
# directory (e.g. "../config.json", "../data/processed/").
%mkdir src
%cd src

In [None]:
class Sample:
    """A single dataset entry: image location, caption text and bounding box."""

    def __init__(self, image_path: str, caption: str, bounding_box: Tensor) -> None:
        self.image_path, self.caption, self.bounding_box = (
            image_path,
            caption,
            bounding_box,
        )

    def as_dict(self) -> dict[str, Any]:
        """Return a JSON-serialisable view; the box tensor becomes a nested list."""
        return dict(
            image_path=self.image_path,
            caption=self.caption,
            bounding_box=self.bounding_box.tolist(),
        )

    @staticmethod
    def fromJSON(json: dict[str, Any]) -> Any:
        """Rebuild a ``Sample`` from the mapping produced by :meth:`as_dict`.

        The signature takes a single dict so it can be used as a
        ``json.load(..., object_hook=...)`` callback.
        """
        path, text = json["image_path"], json["caption"]
        return Sample(path, text, Tensor(json["bounding_box"]))


class BatchSample:
    def __init__(self, image: Tensor, caption: Tensor) -> None:
        self.image: Tensor = image
        self.caption: Tensor = caption

    def to(self, device: device | str) -> Any:
        return self.__class__(self.image.to(device), self.caption.to(device))

    def __str__(self) -> str:
        return f"BatchSample(image={self.image.shape}, caption={self.caption.shape})"


class Split(Enum):
    """Dataset partition; each value is the split name used in the annotations.

    This is deliberately a plain ``Enum``. Decorating an Enum with
    ``@dataclass`` generates a field-less ``__eq__``/``__hash__``, which makes
    every member compare equal to (and hash like) every other member.
    """

    TRAIN = "train"
    VAL = "val"
    TEST = "test"


@dataclass(frozen=True)
class Result:
    """Immutable record holding a bounding box tensor and a score tensor."""

    bounding_box: Tensor
    score: Tensor


# Supported bounding box layouts.
class BboxType(Enum):
    """Bounding box coordinate layouts.

    XYXY: top left and bottom right corners
    XYWH: top left corner, width and height
    CXCWH: center coordinates, width and height

    This is deliberately a plain ``Enum``. Decorating an Enum with
    ``@dataclass`` generates a field-less ``__eq__``/``__hash__``, which makes
    every member compare equal to (and hash like) every other member.

    NOTE(review): the value of ``CXCWH`` is ``"cxcwh"``, whereas torchvision's
    ``box_convert`` spells that layout ``"cxcywh"``; the value is left as-is
    because it may be stored or looked up elsewhere.
    """

    XYXY = "xyxy"
    XYWH = "xywh"
    CXCWH = "cxcwh"

    def __str__(self) -> str:
        # Explicitly keep Enum's "ClassName.MEMBER" rendering.
        return super().__str__()


@dataclass
class Model:
    """Model settings; ``Config`` builds this from the "model" section of config.json."""

    clip_embed_dim: int
    clip_ctx_length: int
    embed_dim: int
    mlp_hidden_dim: int
    img_size: int
    proj_img_size: int
    decoder_layers: int
    decoder_heads: int
    decoder_dim_feedforward: int


@dataclass
class Train:
    """Training settings; ``Config`` builds this from the "train" section of config.json."""

    batch_size: int
    lr: float
    lr_backbone: float
    gamma: float
    l1: float
    l2: float
    sweep: bool


@dataclass
class Logging:
    """Logging settings; ``Config`` builds this from the "logging" section of config.json."""

    path: str
    save: bool
    resume: bool
    wandb: bool


class Config:
    """Run configuration loaded from ``../config.json``.

    Top-level keys ``dataset_path`` and ``epochs`` are stored directly; the
    ``model``, ``train`` and ``logging`` sections are unpacked into their
    dataclasses.
    """

    def __init__(self) -> None:
        # Use a context manager so the file handle is closed deterministically.
        with open("../config.json", "r") as config_file:
            cfg: Dict[str, Any] = json.load(config_file)
        self.dataset_path: str = cfg["dataset_path"]
        self.epochs: int = cfg["epochs"]
        self.model = Model(**cfg["model"])
        self.train = Train(**cfg["train"])
        self.logging = Logging(**cfg["logging"])

    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration as nested dicts (the logging section is omitted)."""
        return {
            "dataset_path": self.dataset_path,
            "epochs": self.epochs,
            "model": self.model.__dict__,
            "train": self.train.__dict__,
        }

    # if in other dict there are keys equal to the keys in self, update them
    def update(self, other: Dict[str, Any]):
        """Overwrite every existing attribute whose name appears in ``other``.

        A key is looked up on the config itself and on each of the three
        sections; it is updated wherever it exists. Unknown keys are ignored.
        """
        for key, value in other.items():
            for section in (self, self.model, self.train, self.logging):
                if key in section.__dict__:
                    section.__dict__[key] = value

In [None]:
class Metric(Enum):
    """Names of the quantities tracked during training and evaluation.

    This is deliberately a plain ``Enum``. Decorating an Enum with
    ``@dataclass`` generates a field-less ``__eq__``/``__hash__``, so all
    members would be equal and share one hash - which collapses any dict
    keyed by ``Metric`` into a single entry.
    """

    LOSS = "loss"
    ACCURACY_50 = "accuracy"  # IoU > 0.5 -> 1 else 0
    ACCURACY_75 = "accuracy75"  # IoU > 0.75 -> 1 else 0
    ACCURACY_90 = "accuracy90"  # IoU > 0.9 -> 1 else 0
    IOU = "iou"
    COSINE_SIMILARITY = "cosine_similarity"


class Reduction(Enum):
    """How a list of recorded metric values is collapsed when it is read back.

    This is deliberately a plain ``Enum``. Decorating an Enum with
    ``@dataclass`` generates a field-less ``__eq__``/``__hash__``, which makes
    every member compare equal to (and hash like) every other member.
    """

    MEAN = "mean"
    SUM = "sum"
    NONE = "none"


class MetricsLogger:
    """Accumulates per-metric lists of float values."""

    def __init__(self, metrics: Dict[Metric, List[float]] | None = None) -> None:
        """Start from ``metrics`` if given, otherwise from one empty list per ``Metric``."""
        if metrics is None:
            metrics = {metric: [] for metric in Metric}
        self.metrics: Dict[Metric, List[float]] = metrics

    def update_metric(self, metrics: Dict[Metric, float]) -> None:
        """Append each value in ``metrics`` to the list of its metric.

        Raises:
            KeyError: if a metric has no list in ``self.metrics``.
        """
        for metric, value in metrics.items():
            self.metrics[metric].append(value)

    def get_metric(
        self, metric: Metric, red: Reduction = Reduction.NONE
    ) -> float | List[float]:
        """Return the recorded values of ``metric``, reduced according to ``red``.

        Raises:
            ZeroDivisionError: for ``Reduction.MEAN`` when nothing was recorded.
            ValueError: if ``red`` is not a known reduction.
        """
        values: List[float] = self.metrics[metric]
        # Compare by identity: enum members are singletons.
        if red is Reduction.MEAN:
            return sum(values) / len(values)
        if red is Reduction.SUM:
            return sum(values)
        if red is Reduction.NONE:
            return values
        raise ValueError(f"Reduction {red.name} doesn't exists")

    def __str__(self) -> str:
        """Render the mean of every metric, one per line.

        Metrics with no recorded values are shown as ``nan`` instead of
        raising ``ZeroDivisionError``.
        """
        lines = ["Metrics:"]
        for metric, values in self.metrics.items():
            mean = sum(values) / len(values) if values else float("nan")
            lines.append(f"{metric.value}: {mean}")
        return "\n".join(lines) + "\n"

In [None]:
@lru_cache()
def default_bpe():
    """Return the path of the BPE vocabulary inside the first ``_dh`` directory.

    ``_dh`` is read from this module's globals; NOTE(review): presumably the
    IPython directory history, so this only works inside a notebook session.
    """
    directory_history = globals()["_dh"]
    return os.path.join(directory_history[0], "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Build the reversible byte -> unicode-character table used by the BPE code.

    Bytes that are already printable, non-whitespace characters map to
    themselves; every other byte value is mapped to a code point starting at
    256, so no byte ends up as a whitespace or control character.
    The insertion order (printable bytes first, then the remaining bytes in
    increasing order) is significant: the vocabulary ids are derived from it.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in printable}
    shifted = 0
    for byte in range(2**8):
        if byte in table:
            continue
        # Non-printable byte: park it above the single-byte range.
        table[byte] = chr(2**8 + shifted)
        shifted += 1
    return table


def get_pairs(word):
    """Return the set of adjacent symbol pairs in ``word``.

    ``word`` is a sequence of symbols (each symbol a variable-length string).
    An empty ``word`` raises ``IndexError``.
    """
    leading = word[0]
    following = word[1:]
    # Pair every symbol with its successor; zip stops at the shorter sequence.
    return set(zip((leading, *following), following))


def basic_clean(text):
    """Fix the text with ftfy, unescape HTML entities twice, and strip the ends."""
    repaired = ftfy.fix_text(text)
    once_unescaped = html.unescape(repaired)
    return html.unescape(once_unescaped).strip()


def whitespace_clean(text):
    """Collapse every run of whitespace into one space and trim both ends."""
    return re.sub(r"\s+", " ", text).strip()


class SimpleTokenizer(object):
    """Byte-pair-encoding tokenizer driven by a gzipped merges file.

    The vocabulary is: the 256 byte symbols, the same symbols with an
    end-of-word marker ``</w>``, one entry per merge rule, and the two special
    tokens ``<|startoftext|>`` and ``<|endoftext|>``.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        """Load the merge rules from ``bpe_path`` and build the lookup tables."""
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Close the gzip handle as soon as the content has been read.
        with gzip.open(bpe_path) as bpe_file:
            merges_lst = bpe_file.read().decode("utf-8").split("\n")
        # Skip the header line and keep exactly as many merges as the
        # 49152-entry vocabulary has room for (minus bytes and special tokens).
        merges_lst = merges_lst[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges_lst]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {
            "<|startoftext|>": "<|startoftext|>",
            "<|endoftext|>": "<|endoftext|>",
        }
        # Splits text into: special tokens, English contractions, runs of
        # letters, single digits, and runs of other non-space characters.
        # The stdlib ``re`` module has no ``\p{L}`` / ``\p{N}`` classes, so they
        # are spelled with ``\w``-based equivalents:
        #   letters -> [^\W\d_]   digits -> \d   everything else -> [^\s\w] or "_"
        # (Non-decimal numerics such as superscripts fall in the letter class.)
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\W\d_]+|\d|(?:[^\s\w]|_)+""",
            re.IGNORECASE,
        )

    def bpe(self, token):
        """Apply the merge rules to ``token`` and return its space-joined symbols.

        Results are memoised in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        # Start from single characters, marking the last one as end-of-word.
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            # Pick the adjacent pair with the best (lowest) merge rank.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # ``first`` does not occur again: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lower-case ``text`` and return its list of vocabulary ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in self.pat.findall(text):
            # Map the token's UTF-8 bytes onto the printable byte alphabet.
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        """Turn vocabulary ids back into text; ``</w>`` markers become spaces."""
        text = "".join([self.decoder[token] for token in tokens])
        text = (
            bytearray([self.byte_decoder[c] for c in text])
            .decode("utf-8", errors="replace")
            .replace("</w>", " ")
        )
        return text

In [None]:
_tokenizer = SimpleTokenizer()


def tokenize(
    texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False
) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Tokenize one string or a list of strings into a padded id matrix.

    Parameters
    ----------
    texts : Union[str, List[str]]
        A single input string or a list of input strings

    context_length : int
        Number of columns of the result (77 by default)

    truncate: bool
        If True, encodings longer than ``context_length`` are cut and their
        last id is replaced by the end-of-text id; if False they raise

    Returns
    -------
    A zero-padded tensor of shape [number of input strings, context_length].
    The dtype is long for torch < 1.8.0 and int otherwise.

    Raises
    ------
    RuntimeError
        When an encoding is too long and ``truncate`` is False
    """
    batch = [texts] if isinstance(texts, str) else texts

    start_id = _tokenizer.encoder["<|startoftext|>"]
    end_id = _tokenizer.encoder["<|endoftext|>"]
    encoded = [[start_id, *_tokenizer.encode(text), end_id] for text in batch]

    is_legacy_torch = p.parse_version(torch.__version__) < p.parse_version("1.8.0")
    id_dtype = torch.long if is_legacy_torch else torch.int
    result: Union[torch.IntTensor, torch.LongTensor]
    result = torch.zeros(len(encoded), context_length, dtype=id_dtype)  # type: ignore

    for row, ids in enumerate(encoded):
        if len(ids) > context_length:
            if not truncate:
                raise RuntimeError(
                    f"Input {batch[row]} is too long for context length {context_length}"
                )
            ids = ids[:context_length]
            ids[-1] = end_id
        result[row, : len(ids)] = torch.tensor(ids)

    return result

In [None]:
# The Dataset contains samples with an image with a bounding box and a caption associated with the bounding box.


class VGDataset(Dataset[Tuple[BatchSample, Tensor]]):
    """Dataset yielding ``(BatchSample(image, tokenized caption), bounding box)``.

    Samples are either read from the preprocessed JSON files
    (``preprocessed=True``) or built from the raw annotation files under
    ``dir_path``.

    NOTE(review): ``get_caption`` here prefixes the sentence with a fixed
    prompt, while captions in the preprocessed JSON files are produced by the
    module-level ``get_caption``, which returns the bare sentence - confirm
    which form is intended.
    """

    def __init__(
        self,
        dir_path: str,
        split: Split,
        output_bbox_type: BboxType,
        augment: bool,
        transform: bool = True,
        preprocessed: bool = False,
        preprocessed_path: str = "../data/processed/",
    ) -> None:
        super().__init__()
        self.dir_path: str = dir_path
        self.split: Split = split
        self.output_bbox_type: BboxType = output_bbox_type
        self.augment: bool = augment
        self.transform: bool = transform
        self.device: device = torch.device(
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
        if preprocessed:
            # Creates the JSON files only when they are missing.
            preprocess(dir_path, preprocessed_path, output_bbox_type)
            with open(
                preprocessed_path + f"{self.split.value}_samples.json", "rb"
            ) as samples:
                self.samples: List[Sample] = json.load(
                    samples, object_hook=Sample.fromJSON
                )
        else:
            self.samples: List[Sample] = self.get_samples()  # type: ignore

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, ref_id: int) -> Tuple[BatchSample, Tensor]:
        """Return the ``ref_id``-th sample.

        With ``transform=True`` the image and box go through
        ``transform_sample``; otherwise the raw image is read and the stored
        box tensor is returned as a copy.
        """
        sample = self.samples[ref_id]
        caption: Tensor = tokenize(sample.caption, truncate=True)  # type: ignore
        if self.transform:
            # The context manager closes the image file once it has been decoded.
            with Image.open(sample.image_path) as raw_image:
                image, bbox = transform_sample(
                    raw_image,
                    sample.bounding_box,
                    self.augment,
                )
        else:
            image = read_image(sample.image_path)
            # The box is already a tensor; wrapping it in ``torch.tensor([...])``
            # fails for multi-element tensors, so copy it instead.
            bbox = torch.as_tensor(sample.bounding_box).clone()
        return BatchSample(image, caption), bbox

    def get_samples(self) -> List[Sample]:
        """Build the samples of ``self.split`` from the raw annotation files.

        Raises:
            KeyError: if a reference points to an unknown image or annotation id.
        """
        with open(self.dir_path + "annotations/instances.json", "r") as inst, open(
            self.dir_path + "annotations/refs(umd).p", "rb"
        ) as refs:
            instances = json.load(inst)
            # NOTE: pickle.load can execute arbitrary code; only use trusted files.
            references = pickle.load(refs)

        # Index images and annotations by id once instead of scanning the
        # full lists for every reference. ``setdefault`` keeps the first
        # occurrence of a duplicated id.
        file_names: Dict[Any, str] = {}
        for image in instances["images"]:
            file_names.setdefault(image["id"], image["file_name"])
        raw_boxes: Dict[Any, Any] = {}
        for ann in instances["annotations"]:
            raw_boxes.setdefault(ann["id"], ann["bbox"])

        samples: List[Sample] = []
        for ref in references:
            if self.split.value != ref["split"]:
                continue
            image_path = self._image_path(file_names[ref["image_id"]])
            caption = self.get_caption(ref["sentences"])
            bbox = self._convert_bbox(raw_boxes[ref["ann_id"]])
            samples.append(Sample(image_path, caption, bbox))
        return samples

    def _image_path(self, image_name: str) -> str:
        """Return the path of ``image_name`` inside the dataset's images folder."""
        return self.dir_path + "images/" + image_name

    def get_image_path(self, img_id: int, instances: Dict[str, Any]) -> str:
        """Return the path of the first image in ``instances`` whose id is ``img_id``."""
        image_name = next(
            image["file_name"] for image in instances["images"] if image["id"] == img_id
        )
        return self._image_path(image_name)

    def get_caption(self, captions: List[Dict[str, Any]]) -> str:
        """Return the prompt-prefixed longest sentence (first one wins on ties)."""
        longest_caption = max(captions, key=lambda caption: len(caption["sent"]))
        return f"find the region that corresponds to the description {longest_caption['sent']}"

    def _convert_bbox(self, bbox: List[float]) -> Tensor:
        """Convert one ``xywh`` box into ``self.output_bbox_type`` (shape ``(1, 4)``).

        Returns an empty tensor for an unrecognised box type.
        """
        # Keyed by member name, mirroring how enum members are compared in
        # the rest of this file. torchvision spells the centre format "cxcywh".
        torchvision_formats = {
            BboxType.XYXY.name: "xyxy",
            BboxType.XYWH.name: "xywh",
            BboxType.CXCWH.name: "cxcywh",
        }
        out_fmt = torchvision_formats.get(self.output_bbox_type.name)
        if out_fmt is None:
            return tensor([])
        return box_convert(tensor([bbox]), in_fmt="xywh", out_fmt=out_fmt)

    # Bounding boxed converted to format compatible with yolo or torchvision
    def get_bounding_box(self, ann_id: int, instances: Dict[str, Any]) -> Tensor:
        """Return the converted box of the first annotation whose id is ``ann_id``."""
        bbox = next(
            ann["bbox"] for ann in instances["annotations"] if ann["id"] == ann_id
        )
        return self._convert_bbox(bbox)

In [None]:
def get_samples(
    dir_path: str, bbox_type: BboxType
) -> Tuple[List[Sample], List[Sample], List[Sample]]:
    """Build the train, validation and test samples from the raw annotations.

    Args:
        dir_path: dataset root (with trailing slash) containing ``annotations/``.
        bbox_type: layout the bounding boxes are converted to.

    Returns:
        ``(train_samples, val_samples, test_samples)``.

    Raises:
        KeyError: if a reference points to an unknown image or annotation id.
        ValueError: if a reference has a split other than train/val/test.
    """
    with open(dir_path + "annotations/instances.json", "r") as inst, open(
        dir_path + "annotations/refs(umd).p", "rb"
    ) as refs:
        instances = json.load(inst)
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        references = pickle.load(refs)

    # Index images and annotations by id once, so each reference is resolved
    # in constant time instead of scanning both lists from the start.
    # ``setdefault`` keeps the first occurrence of a duplicated id.
    images_by_id: Dict[Any, Dict[str, Any]] = {}
    for image in instances["images"]:
        images_by_id.setdefault(image["id"], image)
    annotations_by_id: Dict[Any, Dict[str, Any]] = {}
    for ann in instances["annotations"]:
        annotations_by_id.setdefault(ann["id"], ann)

    buckets: Dict[str, List[Sample]] = {
        Split.TRAIN.value: [],
        Split.VAL.value: [],
        Split.TEST.value: [],
    }
    for ref in tqdm(references, desc=f"Processing dataset"):
        image_id, ann_id = ref["image_id"], ref["ann_id"]
        # The helpers search whatever list they are given, so hand them a
        # one-element view holding just the record that was looked up.
        image_path = get_image_path(
            dir_path, image_id, {"images": [images_by_id[image_id]]}
        )
        caption = get_caption(ref["sentences"])
        bbox = get_bounding_box(
            ann_id, {"annotations": [annotations_by_id[ann_id]]}, bbox_type
        )
        split = ref["split"]
        if split not in buckets:
            raise ValueError(f"Invalid split: {split}")
        buckets[split].append(Sample(image_path, caption, bbox))
    return (
        buckets[Split.TRAIN.value],
        buckets[Split.VAL.value],
        buckets[Split.TEST.value],
    )


def get_image_path(dir_path: str, img_id: int, instances: Dict[str, Any]) -> str:
    """Return ``dir_path`` + "images/" + file name of the first image with id ``img_id``.

    Raises ``StopIteration`` when no image has that id.
    """
    matching_names = (
        entry["file_name"] for entry in instances["images"] if entry["id"] == img_id
    )
    file_name = next(matching_names)
    return dir_path + "images/" + file_name


def get_caption(captions: List[Dict[str, Any]]) -> str:
    """Return the longest ``"sent"`` text among ``captions`` (first one wins on ties)."""
    best_index = 0
    for index, entry in enumerate(captions):
        if len(entry["sent"]) > len(captions[best_index]["sent"]):
            best_index = index
    return captions[best_index]["sent"]


# Bounding boxed converted to format compatible with yolo or torchvision
def get_bounding_box(
    ann_id: int, instances: Dict[str, Any], bbox_type: BboxType
) -> Tensor:
    """Return the box of annotation ``ann_id`` converted from ``xywh`` to ``bbox_type``.

    Args:
        ann_id: id of the annotation to look up in ``instances["annotations"]``.
        instances: parsed annotation file.
        bbox_type: target layout.

    Returns:
        A ``(1, 4)`` tensor, or an empty tensor for an unrecognised ``bbox_type``.

    Raises:
        StopIteration: if no annotation has id ``ann_id``.
    """
    bbox = next(ann["bbox"] for ann in instances["annotations"] if ann["id"] == ann_id)
    # Keyed by member name, mirroring how enum members are compared in the
    # rest of this file. torchvision spells the centre format "cxcywh";
    # passing "cxcwh" to box_convert is rejected.
    torchvision_formats = {
        BboxType.XYXY.name: "xyxy",
        BboxType.XYWH.name: "xywh",
        BboxType.CXCWH.name: "cxcywh",
    }
    out_fmt = torchvision_formats.get(bbox_type.name)
    if out_fmt is None:
        return tensor([])
    return box_convert(tensor([bbox]), in_fmt="xywh", out_fmt=out_fmt)


# If the files already exist, don't preprocess again
def preprocess(in_path: str, out_path: str, bbox_type: BboxType) -> None:
    """Write ``{train,val,test}_samples.json`` under ``out_path`` if any is missing.

    Args:
        in_path: raw dataset root passed on to ``get_samples``.
        out_path: output prefix; file names are appended to it directly, so a
            directory must end with a path separator.
        bbox_type: layout the bounding boxes are converted to.
    """
    targets = [
        f"{out_path}{split_name}_samples.json"
        for split_name in ("train", "val", "test")
    ]
    if all(os.path.exists(target) for target in targets):
        return

    # Order matches ``targets``: train, val, test.
    splits = get_samples(in_path, bbox_type)

    # Make sure the destination folder exists before writing into it.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    for target, samples in zip(targets, splits):
        # Closing each file guarantees the JSON is flushed before it is read back.
        with open(target, "w") as out_file:
            json.dump(samples, out_file, default=Sample.as_dict)


# Smoke test: build the preprocessed files if needed, then load them back.
if __name__ == "__main__":
    preprocess("../data/raw/refcocog/", "../data/processed/", BboxType.XYWH)

    # Context managers close the three files as soon as they are parsed.
    with open("../data/processed/train_samples.json", "r") as train_file, open(
        "../data/processed/val_samples.json", "r"
    ) as val_file, open("../data/processed/test_samples.json", "r") as test_file:
        train: List[Sample] = json.load(train_file, object_hook=Sample.fromJSON)
        val: List[Sample] = json.load(val_file, object_hook=Sample.fromJSON)
        test: List[Sample] = json.load(test_file, object_hook=Sample.fromJSON)
    print(len(train), len(val), len(test))
    print(train[0].image_path, train[0].caption, train[0].bounding_box.shape)

In [None]:
def custom_collate(
    batch: List[Tuple[BatchSample, torch.Tensor]]
) -> Tuple[List[BatchSample], torch.Tensor]:
    """Collate ``(sample, bbox)`` pairs into a list of samples and one stacked box tensor.

    Each sample is re-wrapped in a fresh ``BatchSample``; the boxes are
    combined with ``torch.stack``.
    """
    samples: List[BatchSample] = [
        BatchSample(item.image, item.caption) for item, _ in batch
    ]
    bboxes: List[torch.Tensor] = [box for _, box in batch]
    return samples, torch.stack(bboxes)


# Transform image according to CLIP preprocess function
# Normalize bounding box coordinates to be independent of image size
@lru_cache(maxsize=None)
def _build_transform(augment: bool, target_size: int) -> A.Compose:
    """Build (once per configuration) the albumentations pipeline for a sample."""
    steps: List[Any] = [
        A.Resize(target_size, target_size, interpolation=cv2.INTER_CUBIC, p=1),
        A.CenterCrop(
            target_size,
            target_size,
            always_apply=True,
        ),
    ]
    if augment:
        # Colour jitter must see the raw pixel values: on float images it
        # assumes a [0, 1] range and clips, which would destroy the
        # mean/std-normalised values if it ran after Normalize.
        steps.append(
            A.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0)
        )
    # Mean and std of the CLIP preprocessing.
    steps.append(
        A.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
            max_pixel_value=255.0,
        )
    )
    if augment:
        steps.extend(
            [
                A.GaussianBlur(p=1),
                A.PixelDropout(dropout_prob=0.02),
                A.Rotate(limit=20),
            ]
        )
    steps.append(ToTensorV2())
    return A.Compose(
        transforms=steps,
        bbox_params=A.BboxParams(format="coco", label_fields=[]),
    )


def transform_sample(
    image: Image.Image,
    box: Tensor,
    augment: bool,
    target_size: int = 224,
) -> Tuple[Tensor, Tensor]:
    """Resize/normalise ``image`` and rescale ``box`` to the ``[0, 1]`` range.

    Args:
        image: input picture; converted to RGB when needed.
        box: boxes in COCO ``xywh`` pixel coordinates; only the first
            transformed box is returned.
        augment: whether to add the random augmentations.
        target_size: side of the square output image.

    Returns:
        ``(image tensor, float32 box tensor divided by target_size)``.

    Raises:
        IndexError: if no bounding box survives the transformation.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    pixels = np.array(image)
    boxes = box.tolist()

    transformed_sample: Dict[str, Any] = _build_transform(augment, target_size)(
        image=pixels, bboxes=boxes
    )
    if augment and not transformed_sample["bboxes"]:
        # The random rotation can push the box out of the frame; fall back to
        # the deterministic pipeline rather than failing on the lookup below.
        transformed_sample = _build_transform(False, target_size)(
            image=pixels, bboxes=boxes
        )
    if not transformed_sample["bboxes"]:
        raise IndexError("no bounding box left after transforming the sample")

    bbox_tensor: Tensor = torch.tensor(transformed_sample["bboxes"][0]) / target_size
    return transformed_sample["image"], bbox_tensor.to(torch.float32)


def init_torch(seed: int = 41) -> None:
    """Seed Python, NumPy and PyTorch (CPU and CUDA) and make cuDNN deterministic."""
    # Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch generators: CPU, current CUDA device, all CUDA devices.
    for seed_fn in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

In [None]:
from collections import OrderedDict
from typing import Any, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv -> 3x3 conv -> optional avgpool -> 1x1 conv.

    The output has ``planes * expansion`` channels. Down-sampling is done with
    average pooling instead of strided convolutions.
    """

    # Channel multiplier of the last 1x1 convolution.
    expansion = 4

    def __init__(self, inplanes: int, planes: int, stride: int = 1) -> None:
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        # Shortcut projection; stays None when the input can be added as-is.
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(
                OrderedDict(
                    [
                        ("-1", nn.AvgPool2d(stride)),
                        (
                            "0",
                            nn.Conv2d(
                                inplanes,
                                planes * self.expansion,
                                1,
                                stride=1,
                                bias=False,
                            ),
                        ),
                        ("1", nn.BatchNorm2d(planes * self.expansion)),
                    ]
                )
            )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the three convolutions and add the (possibly projected) input."""
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        # No activation here: ReLU is applied after the residual sum.
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out


class AttentionPool2d(nn.Module):
    """Self-attention over the spatial positions of a feature map, with a residual branch.

    Unlike a pooling layer, the output keeps the spatial layout of the input:
    ``forward`` reshapes the attention output back to ``(B, -1, H, W)``.
    """

    def __init__(
        self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int
    ) -> None:
        super().__init__()
        self.spacial_dim = spacial_dim
        # One row per spatial position plus one extra leading row.
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
        )
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads
        # residual
        # NOTE(review): unlike c_proj, this branch has no fallback for a falsy output_dim.
        self.connect = nn.Sequential(
            nn.Conv2d(embed_dim, output_dim, 1),
            nn.BatchNorm2d(output_dim),
        )

    def resize_pos_embed(self, pos_embed, input_shape):
        """Return the last ``spacial_dim**2`` embedding rows as a ``(1, C, HW)`` tensor.

        ``input_shape`` is currently unused: the interpolation code below the
        ``return`` is commented out, so no resizing actually happens.
        """
        pos_h = pos_w = self.spacial_dim
        # Skip the first position embedding as it is the CLS token
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w) :, :]  # 1 HW C
        pos_embed_weight = pos_embed_weight.permute(0, 2, 1)
        return pos_embed_weight

        # pos_embed_weight = pos_embed_weight.reshape(
        #     1, pos_h, pos_w, pos_embed.shape[2]
        # ).permute(
        #     0, 3, 1, 2
        # )  # 1 C H W
        # # pos_embed_weight_int = F.interpolate(
        # #     pos_embed_weight, size=input_shape, align_corners=False, mode="bicubic"
        # # )
        # pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)  # 1 HW C
        # return pos_embed_weight.transpose(-2, -1)  # 1 C HW

    def forward(self, x):
        """Attend over the ``H*W`` positions of ``x`` (``B, C, H, W``) and add the residual.

        The positional embedding is added by broadcasting, so ``H * W`` has to
        be compatible with ``spacial_dim**2``.
        """
        B, C, H, W = x.size()
        # Residual connection
        residual = self.connect(x)

        x = x.reshape(B, C, -1)  # B C HW
        pos_embed = self.positional_embedding.unsqueeze(0)
        pos_embed = self.resize_pos_embed(pos_embed, (H, W))  # 1 C HW
        x = x + pos_embed.to(x.dtype)  # B C HW
        x = x.permute(2, 0, 1)  # HW B C (seq_len, batch, embed_dim)
        # Functional attention call using this module's separate q/k/v
        # projections and c_proj as the output projection.
        x, _ = F.multi_head_attention_forward(
            query=x,
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
            ),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        x = x.permute(1, 2, 0).reshape(B, -1, H, W)  # B C H W
        x = x + residual
        # Second positional argument is ``inplace``.
        x = F.relu(x, True)

        return x


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool

    ``forward`` returns the outputs of ``layer2`` and ``layer3`` together with
    the attention-pooled output of ``layer4``.
    """

    def __init__(
        self,
        layers: List[int] | Tuple[int, int, int, int],
        output_dim: int,
        heads: int,
        input_resolution: int = 224,
        width: int = 64,
    ) -> None:
        """Build the stem, the four residual stages (``layers`` blocks each) and the attention pool."""
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(
            3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            width // 2, width // 2, kernel_size=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(
            input_resolution // 32, embed_dim, heads, output_dim
        )

    def _make_layer(self, planes: int, blocks: int, stride: int = 1) -> nn.Sequential:
        """Stack ``blocks`` bottlenecks; only the first one uses ``stride``."""
        layers = [Bottleneck(self._inplanes, planes, stride)]

        # Later blocks (and the next stage) take the expanded channel count as input.
        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Return ``(layer2 output, layer3 output, attention-pooled layer4 output)``."""
        # The stem and the four residual stages run without gradient
        # tracking; only the attention pool below is outside this context.
        with torch.no_grad():

            def stem(x: Tensor) -> Tensor:
                x = self.relu1(self.bn1(self.conv1(x)))
                x = self.relu2(self.bn2(self.conv2(x)))
                x = self.relu3(self.bn3(self.conv3(x)))
                x = self.avgpool(x)
                return x

            # Match the input dtype to the weights' dtype.
            x = x.type(self.conv1.weight.dtype)
            x_stem: Tensor = stem(x)
            x1: Tensor = self.layer1(x_stem)
            x2: Tensor = self.layer2(x1)
            x3: Tensor = self.layer3(x2)
            x4: Tensor = self.layer4(x3)

        x_pooled: Tensor = self.attnpool(x4)

        return (
            x2,
            x3,
            x_pooled,
        )  # B 512 H/8 W/8 (B, 1024, H/16, W/16) (B, 1024, H/32, W/32)


class LayerNorm(nn.LayerNorm):
    """LayerNorm that always normalises in float32 so fp16 inputs are handled."""

    def forward(self, input: Tensor) -> Tensor:
        # Compute in float32, then hand the result back in the caller's dtype.
        normalized = super().forward(input.to(torch.float32))
        return normalized.to(input.dtype)


class QuickGELU(nn.Module):
    """Sigmoid-gated activation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate


class ResidualAttentionBlock(nn.Module):
    """Transformer block: LayerNorm -> self-attention and LayerNorm -> MLP, each added residually."""

    def __init__(
        self, d_model: int, n_head: int, attn_mask: Optional[torch.Tensor] = None
    ) -> None:
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)

        # Feed-forward part; the entry names become state-dict key components.
        hidden_dim = d_model * 4
        expand = nn.Linear(d_model, hidden_dim)
        activation = QuickGELU()
        contract = nn.Linear(hidden_dim, d_model)
        self.mlp = nn.Sequential(
            OrderedDict(
                [("c_fc", expand), ("gelu", activation), ("c_proj", contract)]
            )
        )

        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Self-attention over ``x``, applying the stored additive mask when one exists."""
        if self.attn_mask is not None:
            # Keep the cached mask on the same dtype/device as the activations.
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        attended, _ = self.attn(
            x, x, x, need_weights=False, attn_mask=self.attn_mask
        )
        return attended

    def forward(self, x: torch.Tensor):
        """Apply the attention and MLP sub-layers with residual connections."""
        after_attention = x + self.attention(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))


class Transformer(nn.Module):
    """A stack of ``layers`` ResidualAttentionBlock modules applied in sequence."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers

        # Every block is handed the same mask object.
        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(width, heads, attn_mask))
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        """Pass ``x`` through every residual attention block in order."""
        return self.resblocks(x)


class CLIP(nn.Module):
    """CLIP model pairing a ModifiedResNet image encoder with a transformer text encoder.

    Attribute names (``visual``, ``transformer``, ``token_embedding``,
    ``positional_embedding``, ``ln_final``, ``text_projection``, ``logit_scale``)
    are the keys ``build_model`` expects to find in the pretrained state dict,
    so they must not be renamed.
    """

    def __init__(
        self,
        embed_dim: int,
        # vision
        image_resolution: int,
        vision_layers: Tuple[int, int, int, int],
        vision_width: int,
        # text
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
    ):
        """Create both encoders and randomly initialise their parameters.

        Args:
            embed_dim: Output dimension passed to the visual encoder and the
                second dimension of ``text_projection``.
            image_resolution: Input resolution forwarded to ModifiedResNet.
            vision_layers: Number of blocks in each of the four residual stages.
            vision_width: Base channel width of the visual encoder.
            context_length: Number of token positions (size of the positional
                embedding and of the attention mask).
            vocab_size: Number of rows of the token embedding table.
            transformer_width: Hidden size of the text transformer.
            transformer_heads: Number of attention heads of the text transformer.
            transformer_layers: Number of residual attention blocks.
        """
        super().__init__()

        self.context_length = context_length

        # Number of attention-pooling heads derived from the visual width.
        vision_heads = vision_width * 32 // 64
        self.visual = ModifiedResNet(
            layers=vision_layers,
            output_dim=embed_dim,
            heads=vision_heads,
            input_resolution=image_resolution,
            width=vision_width,
        )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        # Allocated uninitialised here; filled in by initialize_parameters().
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, transformer_width)
        )
        self.ln_final = LayerNorm(transformer_width)

        # Allocated uninitialised here; filled in by initialize_parameters().
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.initialize_parameters()

    def initialize_parameters(self):
        """Randomly initialise embeddings, attention pooling, and transformer weights.

        The order of the ``nn.init`` calls determines how the RNG stream is
        consumed, so it is kept as is.
        """
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features**-0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
                # nn.init.uniform_(self.visual.attnpool.connect.weight)

            # Zero the weight of every parameter named "...bn3.weight" in the
            # four residual stages.
            for resnet_block in [
                self.visual.layer1,
                self.visual.layer2,
                self.visual.layer3,
                self.visual.layer4,
            ]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        # Standard deviations scaled by transformer width (and depth for the
        # output projections).
        proj_std = (self.transformer.width**-0.5) * (
            (2 * self.transformer.layers) ** -0.5
        )
        attn_std = self.transformer.width**-0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)

    def build_attention_mask(self):
        """Return a (context_length, context_length) additive causal mask.

        Entries strictly above the diagonal are ``-inf``; the diagonal and
        everything below it are zero.
        """
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # keep -inf strictly above the diagonal; zero the diagonal and below
        return mask

    @property
    def dtype(self):
        """Dtype of the first visual convolution's weight, used as the model's working dtype."""
        return self.visual.conv1.weight.dtype

    def encode_image(self, image) -> Tuple[Tensor, Tensor, Tensor]:
        """Cast ``image`` to the model dtype and return the visual encoder's three outputs."""
        return self.visual(image.type(self.dtype))

    @torch.no_grad()
    def encode_text(self, text) -> Tuple[Tensor, Tensor]:
        """Encode token ids without tracking gradients.

        Returns:
            A pair ``(x, global_repr)``: the per-token output after ``ln_final``
            and the projected embedding taken at the position of the largest
            token id in each sequence.
        """
        x = self.token_embedding(text).type(
            self.dtype
        )  # B L D (batch size, sequence length=77, embed dim=512)

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # L B D
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # B L D
        x = self.ln_final(x).type(self.dtype)  # B L D

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # Embedding corresponding to the higher value token for each element in the sequence (eot_token)
        global_repr: Tensor = (
            x[torch.arange(x.shape[0]), text.argmax(dim=-1), :] @ self.text_projection
        )  # [B, transformer_width] @ [transformer_width, embed_dim] = [B, embed_dim]

        # Transformer output, EOT embeddings
        return x, global_repr

    def forward(self, image, text) -> Tuple[Tensor, Tensor]:
        """Return scaled similarity logits ``(logits_per_image, logits_per_text)``.

        Uses the third output of ``encode_image`` and the second output of
        ``encode_text`` as the image and text features.
        """
        image_features: Tensor = self.encode_image(image)[2]
        text_features: Tensor = self.encode_text(text)[1]

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
    """Cast selected parameters of ``model`` (and all its sub-modules) to fp16 in place.

    Affected: weights/biases of Conv1d, Conv2d and Linear layers, the projection
    weights and biases of MultiheadAttention, and any ``text_projection`` /
    ``proj`` attribute that is not None. Everything else keeps its dtype.
    """
    mha_attr_names = (
        "in_proj_weight",
        "q_proj_weight",
        "k_proj_weight",
        "v_proj_weight",
        "in_proj_bias",
        "bias_k",
        "bias_v",
    )

    def _halve(module: nn.Module) -> None:
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        if isinstance(module, nn.MultiheadAttention):
            for attr_name in mha_attr_names:
                param = getattr(module, attr_name)
                if param is None:
                    continue
                param.data = param.data.half()

        # Bare projection parameters stored directly on a module.
        for attr_name in ("text_projection", "proj"):
            if not hasattr(module, attr_name):
                continue
            projection = getattr(module, attr_name)
            if projection is not None:
                projection.data = projection.data.half()

    # nn.Module.apply visits every sub-module as well as ``model`` itself.
    model.apply(_halve)


def build_model(state_dict: dict[str, Any]):
    """Build a CLIP model from a state dict

    All hyper-parameters are inferred from the tensor shapes and key names in
    ``state_dict``. Note that ``state_dict`` is modified in place: entries for
    ``visual.attnpool.connect`` are added and the ``input_resolution``,
    ``context_length`` and ``vocab_size`` entries are removed when present.
    The returned model has its weights loaded and is in eval mode.
    """

    def _distinct_indices(prefix: str) -> int:
        # Number of different values of the third dot-separated key component.
        return len({key.split(".")[2] for key in state_dict if key.startswith(prefix)})

    # ---- vision hyper-parameters ----
    counts: List[int] = [
        _distinct_indices(f"visual.layer{stage}") for stage in [1, 2, 3, 4]
    ]
    vision_layers: Tuple[int, int, int, int] = tuple(counts)  # type: ignore
    vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]

    # The positional embedding holds one row per grid cell plus one extra row.
    num_positions = state_dict["visual.attnpool.positional_embedding"].shape[0]
    output_width = round((num_positions - 1) ** 0.5)
    assert output_width**2 + 1 == num_positions
    image_resolution = output_width * 32

    # ---- text hyper-parameters ----
    embed_dim: int = state_dict["text_projection"].shape[1]
    context_length: int = state_dict["positional_embedding"].shape[0]
    vocab_size: int = state_dict["token_embedding.weight"].shape[0]
    transformer_width: int = state_dict["ln_final.weight"].shape[0]
    transformer_heads: int = transformer_width // 64
    transformer_layers: int = _distinct_indices("transformer.resblocks")

    model = CLIP(
        embed_dim,
        image_resolution,
        vision_layers,
        vision_width,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
    )

    # The ``connect`` layer is absent from the checkpoint, so its freshly
    # initialised weights are copied into the state dict before loading.
    connect_state = model.visual.attnpool.connect.state_dict(
        prefix="visual.attnpool.connect."
    )
    state_dict.update(connect_state)

    # Drop entries that are not parameters of the model.
    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)

    convert_weights(model)

    model.load_state_dict(state_dict)
    return model.eval()

In [None]:
class FusionModule(nn.Module):
    """Fuse three levels of visual features with a global text embedding.

    Args:
        emb_dim: Number of output channels of the fused feature map.
        clip_emb_dim: Channel size of the deepest visual features and of the
            text embedding.
        proj_img_size: Accepted for interface compatibility; not used here.
    """

    def __init__(self, emb_dim: int, clip_emb_dim: int, proj_img_size: int) -> None:
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dev = self.device

        # Text branch: linear projection followed by batch norm.
        text_layers = [
            nn.Linear(in_features=clip_emb_dim, out_features=clip_emb_dim),
            nn.BatchNorm1d(clip_emb_dim, device=dev),
        ]
        self.text_projection = nn.Sequential(*text_layers).to(dev)

        # 1x1 conv + batch norm; the trailing ReLU of the helper is dropped.
        l4_block = _conv_layer(
            input_dim=clip_emb_dim,
            output_dim=clip_emb_dim,
            kernel_size=1,
            padding=0,
            device=dev,
        )
        self.vis_l4_projection = l4_block[:2]

        # Normalisation + ReLU applied to the text/visual product.
        self.norm_layer = nn.Sequential(
            nn.BatchNorm2d(clip_emb_dim, device=dev),
            nn.ReLU(),
        ).to(dev)

        self.vis_l3_projection = _conv_layer(
            input_dim=2 * clip_emb_dim,
            output_dim=emb_dim,
            kernel_size=3,
            padding=1,
            device=dev,
        )
        self.vis_l2_projection = _conv_layer(
            input_dim=2 * emb_dim,
            output_dim=emb_dim,
            kernel_size=3,
            padding=1,
            device=dev,
        )
        self.aggregation = _conv_layer(
            input_dim=clip_emb_dim + 2 * emb_dim,
            output_dim=emb_dim,
            kernel_size=1,
            padding=0,
            device=dev,
        )

    def forward(
        self, visual_features: Tuple[Tensor, Tensor, Tensor], text_features: Tensor
    ) -> Tensor:
        """Return the aggregated multi-level feature map with ``emb_dim`` channels.

        Shape notes in the comments below assume clip_emb_dim=1024, emb_dim=512
        and a 224px input — TODO confirm against the config.
        """
        feats_l2, feats_l3, feats_l4 = visual_features

        # Project the text embedding and make it broadcastable over H and W.
        text_proj: Tensor = self.text_projection(text_features)
        text_proj = text_proj.unsqueeze(-1).unsqueeze(-1)  # B C 1 1
        l4_proj: Tensor = self.vis_l4_projection(feats_l4)  # B C 7 7

        # First fusion: element-wise product of deepest features and text.
        fused_l4: Tensor = self.norm_layer(l4_proj * text_proj)  # B C 7 7

        # Second fusion: upsample the fused map and merge it with layer 3.
        fused_l4_up: Tensor = F.interpolate(
            fused_l4, scale_factor=2.0, mode="nearest"
        )  # B C 14 14
        l3_input: Tensor = torch.cat([feats_l3, fused_l4_up], dim=1)  # B 2C 14 14
        fused_l3: Tensor = self.vis_l3_projection(l3_input)  # B emb_dim 14 14

        # Third fusion: downsample layer 2 and merge it with the previous result.
        l2_pooled: Tensor = F.max_pool2d(
            feats_l2, kernel_size=2, stride=2
        )  # B emb_dim 14 14
        l2_input: Tensor = torch.cat([fused_l3, l2_pooled], dim=1)
        fused_l2: Tensor = self.vis_l2_projection(l2_input)  # B emb_dim 14 14

        # Aggregate all three fused levels with a 1x1 convolution.
        stacked: Tensor = torch.cat(
            [fused_l2, fused_l3, fused_l4_up], dim=1
        )  # B (C + 2 * emb_dim) 14 14
        aggregated_features: Tensor = self.aggregation(stacked)  # B emb_dim 14 14
        # TODO: Add spatial coords?
        return aggregated_features


def _conv_layer(
    input_dim: int,
    output_dim: int,
    kernel_size: int,
    padding: int,
    device: device,
) -> nn.Sequential:
    module = nn.Sequential(
        nn.Conv2d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=kernel_size,
            padding=padding,
            device=device,
        ),
        nn.BatchNorm2d(output_dim, device=device),
        nn.ReLU(),
    )
    nn.init.xavier_uniform_(module[0].weight)
    return module

In [None]:
class Decoder(nn.Module):
    """Transformer decoder that attends from visual tokens (plus a learnable
    regression token) to the text token sequence.

    Args:
        d_model: Feature dimension of both visual and text tokens.
        img_size: Side length of the (square) visual feature map used to size
            the 2D positional encoding.
        clip_ctx_length: Maximum text sequence length used to size the 1D
            positional encoding.
        nheads: Number of attention heads per decoder layer.
        nlayers: Number of decoder layers.
        dim_feedforward: Hidden size of each decoder layer's feed-forward part.
    """

    def __init__(
        self,
        d_model: int,
        img_size: int,
        clip_ctx_length: int,
        nheads: int,
        nlayers: int,
        dim_feedforward: int,
    ) -> None:
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.d_model = d_model
        self.pos_embedding_1d = PositionalEncoding1D(d_model, clip_ctx_length).to(
            self.device
        )
        # NOTE: the misspelled attribute name is kept on purpose: it is part of
        # the state-dict keys of the buffer registered by PositionalEncoding2D.
        self.pos_embeddinf_2d = PositionalEncoding2D(d_model, img_size, img_size).to(
            self.device
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(
                d_model=d_model,
                nhead=nheads,
                dim_feedforward=dim_feedforward,
                batch_first=True,
                norm_first=True,  # Less prone to vanishing gradients??
                device=self.device,
            ),
            num_layers=nlayers,
            norm=nn.LayerNorm(d_model, device=self.device),
        )
        # Move the data to the device BEFORE wrapping it in nn.Parameter.
        # Calling ``.to(device)`` on a Parameter returns a plain, non-leaf
        # tensor whenever a copy is needed (i.e. on CUDA), which would not be
        # registered with the module and therefore never optimised or saved.
        self.reg_token = nn.Parameter(torch.randn(1, 1, d_model).to(self.device))
        nn.init.kaiming_normal_(self.reg_token, nonlinearity="relu", mode="fan_out")

    def forward(self, vis: Tensor, text: Tensor) -> Tensor:
        """Return the decoded regression token, one row of size ``d_model`` per sample.

        Args:
            vis: Visual feature map; flattened over its last two dimensions
                into a token sequence (assumed B x d_model x H x W).
            text: Text token features (assumed B x L x d_model), used as the
                decoder memory.
        """
        text_features: Tensor = self.pos_embedding_1d(text)

        visual_features: Tensor = self.pos_embeddinf_2d(vis)

        visual_features = visual_features.flatten(2).permute(0, 2, 1)  # B HW D

        # Prepend the regression token to every sample's visual sequence.
        visual_features = torch.cat(
            [self.reg_token.expand((vis.shape[0], -1, -1)), visual_features], dim=1
        )
        x = self.decoder(visual_features, text_features)
        # Only the output at the regression-token position is returned.
        return x[:, 0, :]


# Positional encodings implemented in separate classes if we want to change them and use learnable positional encodings instead
# Dropout added following the original transformer implementation
# https://github.com/wzlxjtu/PositionalEncoding2D
class PositionalEncoding1D(nn.Module):
    """Fixed sinusoidal positional encoding for token sequences, followed by dropout.

    The encoding table has shape ``(window_len, d_model)`` and is stored as the
    buffer ``text_pos_encoding`` so that it follows the module across devices
    and is included in the state dict.

    Args:
        d_model: Feature dimension of the tokens.
        window_len: Maximum sequence length supported.
    """

    def __init__(self, d_model: int, window_len: int) -> None:
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.dropout = nn.Dropout(0.1).to(self.device)
        pos_encoding = torch.zeros(window_len, d_model, device=self.device)
        position = torch.arange(0, window_len, device=self.device).unsqueeze(1)
        div_term = torch.exp(
            (
                torch.arange(0, d_model, 2, dtype=torch.float, device=self.device)
                * -(math.log(10000.0) / d_model)
            )
        )
        # Even feature indices hold sines, odd ones cosines.
        pos_encoding[:, 0::2] = torch.sin(position.float() * div_term)
        pos_encoding[:, 1::2] = torch.cos(position.float() * div_term)

        self.register_buffer("text_pos_encoding", pos_encoding)

    @property
    def pos_encoding(self) -> Tensor:
        """Alias of the registered buffer, kept for backward compatibility."""
        return self.text_pos_encoding

    def forward(self, token_embedding: Tensor) -> Tensor:
        """Add the positional encoding to ``token_embedding`` and apply dropout.

        The table is truncated to ``token_embedding.size(1)`` positions. The
        addition is not wrapped in ``torch.no_grad()`` so gradients can flow
        back to ``token_embedding`` when it requires them.
        """
        out = self.dropout(
            token_embedding + self.text_pos_encoding[: token_embedding.size(1), :]
        )
        return out


# The first half of the channels encodes the horizontal (width) position and the second half the vertical (height) position
class PositionalEncoding2D(nn.Module):
    """Fixed 2D sinusoidal positional encoding for feature maps, followed by dropout.

    The encoding table has shape ``(d_model, height, width)`` and is stored as
    the buffer ``visual_pos_encoding`` so that it follows the module across
    devices and is included in the state dict.

    Args:
        d_model: Number of channels of the feature map.
        width: Maximum feature-map width supported.
        height: Maximum feature-map height supported.
    """

    def __init__(self, d_model: int, width: int, height: int) -> None:
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.dropout = nn.Dropout(0.1).to(self.device)
        pe = torch.zeros(d_model, height, width, device=self.device)
        # Each dimension use half of d_model
        d_model = int(d_model / 2)
        div_term = torch.exp(
            torch.arange(0.0, d_model, 2, device=self.device)
            * -(math.log(10000.0) / d_model)
        )
        pos_w = torch.arange(0.0, width, device=self.device).unsqueeze(1)
        pos_h = torch.arange(0.0, height, device=self.device).unsqueeze(1)
        # Width encoding: (W, n_freq) -> (n_freq, 1, W) -> repeated over H.
        pe[0:d_model:2, :, :] = (
            torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        )
        pe[1:d_model:2, :, :] = (
            torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        )
        # Height encoding: (H, n_freq) -> (n_freq, H, 1) -> repeated over W.
        pe[d_model::2, :, :] = (
            torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
        )
        pe[d_model + 1 :: 2, :, :] = (
            torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
        )
        self.register_buffer("visual_pos_encoding", pe)

    @property
    def pe(self) -> Tensor:
        """Alias of the registered buffer, kept for backward compatibility."""
        return self.visual_pos_encoding

    def forward(self, x):
        """Add the positional encoding to ``x`` and apply dropout.

        ``x`` is expected to end with (channels, height, width) dimensions. The
        table is cropped to the spatial size of ``x``. The addition is not
        wrapped in ``torch.no_grad()``: doing so would detach the output from
        the graph and stop gradients from reaching whatever produced ``x``.
        """
        x = x + self.visual_pos_encoding[:, : x.size(-2), : x.size(-1)]
        return self.dropout(x)

In [None]:
class VGModel(nn.Module):
    """Visual grounding model: CLIP encoders -> fusion module -> decoder -> box regression head.

    The CLIP weights are loaded from ``../RN50.pt`` and frozen, except for the
    visual attention-pooling layer which stays trainable.
    """

    def __init__(
        self,
        cfg: Config,
    ) -> None:
        super().__init__()
        self.cfg: Config = cfg
        model_cfg = cfg.model
        embed_dim: int = model_cfg.embed_dim
        mlp_hidden_dim: int = model_cfg.mlp_hidden_dim

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # The TorchScript archive is only needed for its state dict.
        jit_clip = torch.jit.load("../RN50.pt", map_location="cpu").eval()
        self.pretrained_model: CLIP = build_model(jit_clip.state_dict()).to(
            self.device
        )
        self.pretrained_model.float()
        del jit_clip

        # Freeze all clip parameters except the attention pooling layer
        self.pretrained_model.requires_grad_(False)
        self.pretrained_model.visual.attnpool.requires_grad_(True)

        self.fusion_module: FusionModule = FusionModule(
            embed_dim, model_cfg.clip_embed_dim, model_cfg.proj_img_size
        ).to(self.device)
        self.decoder: Decoder = Decoder(
            embed_dim,
            model_cfg.proj_img_size,
            model_cfg.clip_ctx_length,
            model_cfg.decoder_heads,
            model_cfg.decoder_layers,
            model_cfg.decoder_dim_feedforward,
        ).to(self.device)

        self.reg_head: MLP = MLP(
            input_dim=embed_dim, output_dim=4, hidden_dim_1=mlp_hidden_dim
        ).to(self.device)

    def forward(self, batch: List[BatchSample]) -> Tensor:
        """Predict one 4-value box per sample in ``batch``."""
        # Text branch: per-token features and one global embedding per caption.
        captions = torch.stack([sample.caption for sample in batch])
        captions = captions.squeeze(1).to(self.device)
        text_sequence, global_text_features = self.pretrained_model.encode_text(
            captions
        )

        # Image branch: multi-level visual features.
        images = torch.stack([sample.image for sample in batch]).to(self.device)
        visual_features = self.pretrained_model.encode_image(images)

        # Fuse visual features with the global text embedding.
        fused_visual_features: Tensor = self.fusion_module(
            visual_features, global_text_features
        )

        # Decode against the text token sequence, then regress the box.
        reg_token = self.decoder(fused_visual_features, text_sequence)
        return self.reg_head(reg_token)


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear."""

    def __init__(self, input_dim: int, output_dim: int, hidden_dim_1: int) -> None:
        super().__init__()
        hidden_layer = nn.Linear(input_dim, hidden_dim_1)
        activation = nn.ReLU()
        output_layer = nn.Linear(hidden_dim_1, output_dim)
        self.mlp = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, x: Tensor) -> Tensor:
        """Map ``x`` from ``input_dim`` to ``output_dim`` features."""
        out: Tensor = self.mlp(x)
        return out

In [None]:
class Loss:
    """Weighted sum of a smooth-L1 loss and a generalized-IoU loss on boxes.

    Args:
        l1: Weight of the smooth-L1 term.
        l2: Weight of the GIoU term.
    """

    def __init__(self, l1: float, l2: float) -> None:
        self.l1_loss = nn.SmoothL1Loss(reduction="mean")
        self.giou_loss = generalized_box_iou_loss
        self.l1: float = l1
        self.l2: float = l2
        # Set by compute(); holds the most recent loss value.
        self.loss: Tensor

    def compute(self, prediction: Tensor, gt_bbox: Tensor) -> Tensor:
        """Return (and remember) the combined loss.

        Both inputs are treated as ``xywh`` boxes and converted to ``xyxy``
        before either term is evaluated.
        """
        pred_xyxy = box_convert(prediction, in_fmt="xywh", out_fmt="xyxy")
        gt_xyxy = box_convert(gt_bbox, in_fmt="xywh", out_fmt="xyxy")

        smooth_l1_term = self.l1_loss(gt_xyxy, pred_xyxy)
        giou_term = self.giou_loss(gt_xyxy, pred_xyxy, reduction="mean")

        self.loss = self.l1 * smooth_l1_term + self.l2 * giou_term
        return self.loss

    def to_float(self) -> float:
        """Return the last computed loss as a Python float."""
        return self.loss.item()

In [None]:
def train_one_epoch(
    epoch: int,
    dataloader: DataLoader[Tuple[BatchSample, Tensor]],
    model: VGModel,
    loss: Loss,
    optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler.OneCycleLR,
    device: device,
    cfg: Config,
) -> Dict[Metric, float]:
    """Train ``model`` for one pass over ``dataloader``.

    The scheduler is stepped after every optimizer step. ``epoch`` and ``cfg``
    are accepted but not used here.

    Returns:
        The mean over all batches of the loss, the IoU, and the accuracy at
        IoU thresholds 0.5, 0.75 and 0.9.
    """
    model.train()

    accuracy_thresholds = (
        (Metric.ACCURACY_50, 0.5),
        (Metric.ACCURACY_75, 0.75),
        (Metric.ACCURACY_90, 0.9),
    )
    # One list of per-batch values for every reported metric.
    history: Dict[Metric, List[Tensor]] = {Metric.LOSS: [], Metric.IOU: []}
    for metric, _ in accuracy_thresholds:
        history[metric] = []

    for idx, (batch, bbox) in enumerate(tqdm(dataloader, desc="Batches")):
        optimizer.zero_grad()

        # Move to gpu
        for sample in batch:
            sample.to(device)
        bbox = bbox.to(device)

        # Forward pass and loss
        out: Tensor = model(batch)
        batch_loss: Tensor = loss.compute(out, bbox)

        # Backward pass and parameter / learning-rate update
        batch_loss.backward()
        optimizer.step()
        scheduler.step()

        # Metrics are computed on detached xyxy boxes.
        out = box_convert(out, in_fmt="xywh", out_fmt="xyxy").detach()
        bbox = box_convert(bbox, in_fmt="xywh", out_fmt="xyxy").detach()
        # The diagonal pairs every prediction with its own ground truth.
        batch_iou: Tensor = torch.diagonal(box_iou(out, bbox))

        history[Metric.LOSS].append(batch_loss.detach())
        history[Metric.IOU].append(batch_iou.mean())
        for metric, threshold in accuracy_thresholds:
            history[metric].append(accuracy(batch_iou, threshold))

        # Periodic progress report.
        if (idx * len(batch)) % 4096 == 0:
            report: Dict[str, float] = {
                "Train loss": batch_loss.detach().item(),
                "Train accurracy": batch_iou.mean().item(),
            }
            pprint(f"Batches: {idx}, {report}")

    return {
        metric: torch.stack(values).mean().item()
        for metric, values in history.items()
    }


@torch.no_grad()
def validate(
    dataloader: DataLoader[Tuple[BatchSample, Tensor]],
    model: VGModel,
    loss: Loss,
    device: torch.device,
) -> Dict[Metric, float]:
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Returns:
        The mean over all batches of the loss, the IoU, and the accuracy at
        IoU thresholds 0.5, 0.75 and 0.9.
    """
    # As accuracy we take the average IoU
    model.eval()
    loss_list: List[Tensor] = []
    iou_list: List[Tensor] = []
    acc_50: List[Tensor] = []
    acc_75: List[Tensor] = []
    acc_90: List[Tensor] = []

    for batch, bbox in tqdm(dataloader, desc="Batches"):
        # Move to gpu
        for sample in batch:
            sample.to(device)
        bbox = bbox.to(device)

        # Forward pass
        out: Tensor = model(batch)

        # The loss converts xywh -> xyxy itself, so it must be given the raw
        # xywh boxes. Passing already-converted boxes would convert them twice
        # and make the validation loss incomparable with the training loss.
        batch_loss: Tensor = loss.compute(out, bbox).detach()

        out = box_convert(out, in_fmt="xywh", out_fmt="xyxy").detach()
        bbox = box_convert(bbox, in_fmt="xywh", out_fmt="xyxy").detach()

        # The diagonal pairs every prediction with its own ground truth.
        batch_iou: Tensor = torch.diagonal(box_iou(out, bbox)).detach()

        loss_list.append(batch_loss)
        iou_list.append(batch_iou.mean())
        acc_50.append(accuracy(batch_iou, 0.5))
        acc_75.append(accuracy(batch_iou, 0.75))
        acc_90.append(accuracy(batch_iou, 0.9))

    return {
        Metric.LOSS: torch.stack(loss_list).mean().item(),
        Metric.IOU: torch.stack(iou_list).mean().item(),
        Metric.ACCURACY_50: torch.stack(acc_50).mean().item(),
        Metric.ACCURACY_75: torch.stack(acc_75).mean().item(),
        Metric.ACCURACY_90: torch.stack(acc_90).mean().item(),
    }


def accuracy(iou: Tensor, threshold: float) -> Tensor:
    """Return, as a scalar tensor, the fraction of ``iou`` entries that are >= ``threshold``."""
    num_hits = int((iou >= threshold).sum())
    return torch.tensor(num_hits / len(iou))

In [None]:
def train(
    train_dataloader: DataLoader[Tuple[BatchSample, Tensor]],
    val_dataloader: DataLoader[Tuple[BatchSample, Tensor]],
    device: torch.device,
    cfg: Config,
) -> Tuple[MetricsLogger, MetricsLogger]:
    """Train a freshly built VGModel for ``cfg.epochs`` epochs.

    After every epoch the model is validated, metrics are optionally logged to
    wandb, and a checkpoint is optionally written to ``cfg.logging.path``.

    Returns:
        The training and validation metric loggers, in that order.
    """
    train_metrics: MetricsLogger = MetricsLogger()
    val_metrics: MetricsLogger = MetricsLogger()

    # Loss is the weighted sum of the smooth l1 loss and the GIoU
    loss_func = Loss(cfg.train.l1, cfg.train.l2)

    model: VGModel = VGModel(cfg).train()

    # Separate parameters to train
    backbone_params: List[nn.Parameter] = [
        p for p in model.pretrained_model.parameters() if p.requires_grad
    ]

    # All parameters except the backbone
    non_frozen_params: List[nn.Parameter] = []
    non_frozen_params.extend(model.fusion_module.parameters())
    non_frozen_params.extend(model.decoder.parameters())
    non_frozen_params.extend(model.reg_head.parameters())
    print(len(backbone_params), len(non_frozen_params))
    # Two parameter groups so the backbone can use its own learning rate.
    optimizer = optim.AdamW(
        params=[
            {"params": backbone_params, "lr": cfg.train.lr_backbone, "weight_decay": 0},
            {"params": non_frozen_params, "lr": cfg.train.lr, "weight_decay": 1e-4},
        ]
    )
    # The scheduler is stepped once per batch inside train_one_epoch.
    lr_scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=[cfg.train.lr_backbone, cfg.train.lr],
        epochs=cfg.epochs,
        steps_per_epoch=len(train_dataloader),
    )

    if cfg.logging.wandb:
        wandb.watch(model, loss_func, log="all", log_freq=100, log_graph=True)

    for epoch in tqdm(range(cfg.epochs), desc="Epochs"):
        print("-------------------- Training --------------------------")
        epoch_train_metrics: Dict[Metric, float] = train_one_epoch(
            epoch=epoch,
            dataloader=train_dataloader,
            model=model,
            loss=loss_func,
            optimizer=optimizer,
            scheduler=lr_scheduler,
            device=device,
            cfg=cfg,
        )
        train_metrics.update_metric(epoch_train_metrics)
        print("Training metrics at epoch ", epoch)
        pprint(epoch_train_metrics)

        # Evaluate on validation set for hyperparameter tuning
        print("-------------------- Validation ------------------------")
        epoch_val_metrics: Dict[Metric, float] = validate(
            val_dataloader, model, loss_func, device
        )
        val_metrics.update_metric(epoch_val_metrics)
        print("Validation metrics at epoch ", epoch)
        pprint(epoch_val_metrics)

        # Log metrics to wandb putting train and val metrics together
        if cfg.logging.wandb:
            wandb.log(
                {
                    "Loss": {
                        "train": epoch_train_metrics[Metric.LOSS],
                        "val": epoch_val_metrics[Metric.LOSS],
                    },
                    "Average IOU": {
                        "train": epoch_train_metrics[Metric.IOU],
                        "val": epoch_val_metrics[Metric.IOU],
                    },
                    "Accuracy@50": {
                        "train": epoch_train_metrics[Metric.ACCURACY_50],
                        "val": epoch_val_metrics[Metric.ACCURACY_50],
                    },
                    "Accuracy@75": {
                        "train": epoch_train_metrics[Metric.ACCURACY_75],
                        "val": epoch_val_metrics[Metric.ACCURACY_75],
                    },
                    "Accuracy@90": {
                        "train": epoch_train_metrics[Metric.ACCURACY_90],
                        "val": epoch_val_metrics[Metric.ACCURACY_90],
                    },
                },
                commit=True,
            )

        # Save model after each epoch
        if cfg.logging.save:
            checkpoint_dir: str = cfg.logging.path
            # exist_ok avoids the check-then-create race.
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(
                obj={
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "lr_scheduler_state_dict": lr_scheduler.state_dict(),
                    "loss": epoch_train_metrics[Metric.LOSS],
                },
                # os.path.join puts the file inside the directory whether or
                # not the configured path ends with a separator.
                f=os.path.join(checkpoint_dir, f"model{epoch}.pth"),
            )

        # Release cached GPU memory and unreachable objects between epochs.
        torch.cuda.empty_cache()
        gc.collect()

    return train_metrics, val_metrics


def initialize_run(sweep: bool = True) -> None:
    """Set up logging, datasets and dataloaders, run training, and dump the metrics.

    Args:
        sweep: When True the run is part of a wandb sweep and the configuration
            is updated from ``wandb.config``; otherwise wandb is initialised
            only if enabled in the configuration.
    """
    config = Config()
    if sweep:
        wandb.login(key=os.getenv("WANDB_API_KEY"))
        wandb.init(project="vgproject")
        wandb_cfg = wandb.config
        config.update(wandb_cfg)
    else:
        if config.logging.wandb:
            wandb.login(key=os.getenv("WANDB_API_KEY"))
            wandb.init(project="vgproject", config=config.as_dict())

    train_dataset: VGDataset = VGDataset(
        dir_path=config.dataset_path,
        split=Split.TRAIN,
        output_bbox_type=BboxType.XYWH,
        augment=True,
        preprocessed=True,
    )
    print("Train dataset created. Dataset length ", len(train_dataset))

    val_dataset: VGDataset = VGDataset(
        dir_path=config.dataset_path,
        split=Split.VAL,
        output_bbox_type=BboxType.XYWH,
        augment=False,
        preprocessed=True,
    )
    print("Validation dataset created. Dataset length: ", len(val_dataset))

    train_dataloader: DataLoader[Tuple[BatchSample, Tensor]] = DataLoader(
        dataset=train_dataset,
        batch_size=config.train.batch_size,
        collate_fn=custom_collate,
        num_workers=2,
        shuffle=True,
        drop_last=True,
    )

    # NOTE(review): shuffle/drop_last are also enabled for validation, so an
    # incomplete final batch is excluded from the metrics — confirm this is
    # intended.
    val_dataloader: DataLoader[Tuple[BatchSample, Tensor]] = DataLoader(
        dataset=val_dataset,
        batch_size=config.train.batch_size,
        collate_fn=custom_collate,
        shuffle=True,
        drop_last=True,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_metrics, val_metrics = train(train_dataloader, val_dataloader, device, config)

    # Context managers guarantee the files are flushed and closed.
    with open("../train_metrics.json", "w") as train_file:
        json.dump(train_metrics.metrics, train_file)
    with open("../val_metrics.json", "w") as val_file:
        json.dump(val_metrics.metrics, val_file)

    if config.logging.wandb:
        wandb.finish()


def main() -> None:
    """Entry point: start a wandb sweep or a single run, as selected by the config."""
    init_torch()
    cfg = Config()
    if cfg.train.sweep:
        # Context manager so the sweep configuration file is closed promptly.
        with open("../sweep_config.json", "r") as sweep_file:
            sweep_configuration: Dict[str, Any] = json.load(sweep_file)
        sweep: str = wandb.sweep(sweep_configuration, project="vgproject")
        # The agent calls initialize_run (with its default sweep=True) 10 times.
        wandb.agent(sweep, function=initialize_run, count=10)
    else:
        initialize_run(cfg.train.sweep)

In [None]:
# Run the pipeline: a wandb sweep or a single training run, depending on the config.
main()