In [None]:
%pip install ftfy regex tqdm spacy ultralytics
%pip install git+https://github.com/openai/CLIP.git

In [1]:
import json
import logging
import os
import pickle
import sys
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, OrderedDict, Tuple

import gdown
import numpy as np
import optuna
import spacy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from clip import clip
from clip.model import CLIP, ModifiedResNet
from optuna.trial import Trial
from optuna.visualization import plot_optimization_history
from PIL import Image
from spacy.tokens import Doc, Span
from torch import Tensor, device, tensor
from torch.utils.data import DataLoader, Dataset
from torchvision.io import ImageReadMode, read_image
from torchvision.ops import box_convert, box_iou, generalized_box_iou_loss
from tqdm.notebook import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Download dataset and save under data/raw/ only if not already downloaded
url = "https://drive.google.com/uc?id=1xijq32XfEm6FPhUb7RsZYWHc2UuwVkiq"
if not os.path.exists("data/refcocog.tar.gz"):
    print("Downloading dataset...")
    gdown.download(url=url, output="data/raw/", quiet=False, resume=True)
if not os.path.exists("data/refcocog/"):
    print("Extracting dataset...")
    !tar -xf data/raw/refcocog.tar.gz -C data/raw/ --verbose

Utility classes definition

In [2]:
@dataclass(frozen=True)
class Sample:
    """Immutable record describing one dataset example.

    Attributes:
        image_path: Path of the image file on disk.
        caption: Text associated with the example.
        bounding_box: Ground-truth bounding box tensor.
    """

    image_path: str
    caption: str
    bounding_box: Tensor


class BatchSample:
    def __init__(self, image: Tensor, caption: Tensor) -> None:
        self.image: Tensor = image
        self.caption: Tensor = caption

    def to(self, device: device | str) -> Any:
        return self.__class__(self.image.to(device), self.caption.to(device))

    def __str__(self) -> str:
        return f"BatchSample(image={self.image.shape}, caption={self.caption.shape})"


class Split(Enum):
    """Dataset partition; each value is the split name used in the annotations.

    Deliberately NOT decorated with ``@dataclass``: a dataclass built on an Enum
    has no fields, so the generated ``__eq__`` makes every member compare equal
    to every other member (and gives them all the same hash).
    """

    TRAIN = "train"
    VAL = "val"
    TEST = "test"


@dataclass(frozen=True)
class Result:
    """Immutable pair of a bounding box tensor and its score tensor.

    Attributes:
        bounding_box: Bounding box tensor.
        score: Score tensor associated with the box.
    """

    bounding_box: Tensor
    score: Tensor


class BboxType(Enum):
    """Bounding box layouts; each value is the matching torchvision format string.

    Deliberately NOT decorated with ``@dataclass``: the generated field-less
    ``__eq__`` makes all members compare equal, so a
    ``match ... case BboxType.XYXY`` would succeed for every member.
    """

    XYXY = "xyxy"
    XYWH = "xywh"
    # torchvision spells the centre/size layout "cxcywh"; the previous value
    # "cxcwh" is rejected by `box_convert`.
    CXCWH = "cxcywh"

    @classmethod
    def _missing_(cls, value: object) -> Any:
        """Keep accepting the legacy (misspelt) value ``"cxcwh"`` in lookups."""
        if value == "cxcwh":
            return cls.CXCWH
        return None

    def __str__(self) -> str:
        return super().__str__()

In [None]:
class Singleton:
    """Class decorator that exposes a single, lazily created instance.

    The decorated class can no longer be called directly; the shared object is
    obtained through `get_instance()`.
    """

    def __init__(self, decorated_class: Any) -> None:
        self._decorated = decorated_class

    def get_instance(self) -> Any:
        """
        Return the shared instance.

        The first call builds it by calling the decorated class without
        arguments; later calls hand back that same object.
        """
        if not hasattr(self, "_instance"):
            self._instance = self._decorated()
        return self._instance

    def __call__(self) -> None:
        # Direct construction is forbidden on purpose.
        raise TypeError("Singletons must be accessed through get_instance() method.")

    def __instancecheck__(self, inst: Any) -> bool:
        # Lets isinstance(obj, <decorated name>) keep working after decoration.
        return isinstance(inst, self._decorated)


# All configuration values are stored in a JSON file and loaded here only once.
# Every later call to get_instance() returns the same instance.


@Singleton
class Config:
    """Settings read from ``config.json``; each top-level key becomes an attribute."""

    def __init__(self) -> None:
        with open(file="config.json", mode="r") as fp:
            cfg: Dict[str, Any] = json.load(fp=fp)
        # Mirror every top-level JSON entry onto the instance.
        for name, value in cfg.items():
            setattr(self, name, value)

In [None]:
# Each item of the dataset is a tuple of (BatchSample, Tensor) where the tensor is the ground truth bounding box
class VGDataset(Dataset[Tuple[BatchSample, Tensor]]):
    """Visual-grounding dataset built from an annotation directory.

    Every item is a ``(BatchSample, Tensor)`` pair: the image together with the
    tokenized caption, and the ground-truth bounding box.

    Args:
        dir_path: Dataset root containing ``annotations/`` and ``images/``.
        split: Partition to load.
        output_bbox_type: Layout of the returned bounding boxes.
        transform_image: Optional callable ``(image, box, device=...)`` returning
            the transformed image tensor and box tensor.
        dependencies: When True, captions are shortened with a spaCy dependency
            parse instead of being wrapped in the fixed prompt.
    """

    # torchvision `box_convert` format for each BboxType member. Keyed by member
    # *name* so the lookup does not rely on enum equality/hashing or on the
    # enum's raw value being a valid torchvision format string.
    _BOX_FORMATS: Dict[str, str] = {
        "XYXY": "xyxy",
        "XYWH": "xywh",
        "CXCWH": "cxcywh",
    }

    def __init__(
        self,
        dir_path: str,
        split: Split,
        output_bbox_type: BboxType,
        transform_image: Any = None,
        dependencies: bool = False,
    ) -> None:
        super().__init__()
        self.dir_path: str = dir_path
        self.split: Split = split
        self.output_bbox_type: BboxType = output_bbox_type
        self.transform_image = transform_image
        # The spaCy pipeline is only needed for dependency-based captions, so it
        # is not loaded (nor required to be installed) otherwise.
        self.text_processor: Any = (
            spacy.load(name="en_core_web_lg") if dependencies else None
        )
        self.device: device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.samples: List[Sample] = self.get_samples(dependencies=dependencies)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, ref_id: int) -> Tuple[BatchSample, Tensor]:
        """Load, optionally transform and return the sample at index ``ref_id``."""
        record = self.samples[ref_id]
        # truncate=True keeps over-long prompts from raising inside the tokenizer.
        caption: Tensor = clip.tokenize(record.caption, truncate=True)

        if self.transform_image is not None:
            # Close the file handle deterministically and force three channels:
            # grayscale files would otherwise yield 1-channel tensors that can
            # neither be stacked with RGB ones nor fed to the visual backbone.
            with Image.open(record.image_path) as raw_image:
                image_trans, bbox_trans = self.transform_image(
                    raw_image.convert("RGB"),
                    record.bounding_box,
                    device=self.device,
                )
            sample_trans: BatchSample = BatchSample(image_trans, caption).to(
                self.device
            )
            return sample_trans, bbox_trans

        image: Tensor = read_image(record.image_path, mode=ImageReadMode.RGB)
        bbox: Tensor = record.bounding_box.to(device=self.device)
        sample: BatchSample = BatchSample(image, caption).to(self.device)
        return sample, bbox

    def get_samples(self, dependencies: bool = False) -> List[Sample]:
        """Read the annotation files and build the samples of ``self.split``.

        Raises:
            KeyError: If a reference points to an unknown image or annotation id.
        """
        instances_path = os.path.join(self.dir_path, "annotations", "instances.json")
        refs_path = os.path.join(self.dir_path, "annotations", "refs(umd).p")
        with open(instances_path, "r") as inst, open(refs_path, "rb") as refs:
            instances = json.load(inst)
            # NOTE: pickle can execute arbitrary code while loading; only use
            # annotation files from a trusted source.
            references = pickle.load(refs)

        # Index both tables once instead of scanning them for every reference.
        # setdefault keeps the first occurrence, matching a first-match scan.
        file_names: Dict[int, str] = {}
        for image in instances["images"]:
            file_names.setdefault(image["id"], image["file_name"])
        boxes: Dict[int, Any] = {}
        for ann in instances["annotations"]:
            boxes.setdefault(ann["id"], ann["bbox"])

        samples: List[Sample] = []
        for ref in references:
            if ref["split"] != self.split.value:
                continue
            image_path = self._image_path_from_name(file_names[ref["image_id"]])
            caption = self.get_caption(ref["sentences"], dependencies)
            bbox = self._convert_bbox(boxes[ref["ann_id"]])
            samples.append(Sample(image_path, caption, bbox))
        return samples

    def _image_path_from_name(self, image_name: str) -> str:
        """Join the dataset root, the images folder and ``image_name``."""
        return os.path.join(self.dir_path, "images", image_name)

    def _convert_bbox(self, bbox: Any) -> Tensor:
        """Convert one ``xywh`` box into a ``(1, 4)`` tensor in the output layout."""
        out_fmt = self._BOX_FORMATS[self.output_bbox_type.name]
        return box_convert(tensor([bbox]), in_fmt="xywh", out_fmt=out_fmt)

    def get_image_path(self, img_id: int, instances: Dict[str, Any]) -> str:
        """Return the path of the image whose id is ``img_id``."""
        image_name = next(
            image["file_name"] for image in instances["images"] if image["id"] == img_id
        )
        return self._image_path_from_name(image_name)

    # Longest caption inserted into the prompt
    def get_caption(self, captions: List[Dict[str, Any]], dependencies: bool) -> str:
        """Pick the longest sentence and turn it into the text given to the model."""
        # max() returns the first of several equally long sentences.
        longest_caption = max(captions, key=lambda caption: len(caption["sent"]))
        if dependencies:
            if self.text_processor is None:
                # Loaded lazily when the dataset was built without dependencies.
                self.text_processor = spacy.load(name="en_core_web_lg")
            return self.get_relevant_caption(
                doc=self.text_processor(longest_caption["sent"])
            )
        return f"find the region that corresponds to the description {longest_caption['sent']}"

    # Bounding boxes converted to a format compatible with yolo or torchvision
    def get_bounding_box(self, ann_id: int, instances: Dict[str, Any]) -> Tensor:
        """Return the box of annotation ``ann_id`` in ``self.output_bbox_type`` layout."""
        bbox = next(
            ann["bbox"] for ann in instances["annotations"] if ann["id"] == ann_id
        )
        return self._convert_bbox(bbox)

    # We tried extracting the most relevant part of the caption to use in the baseline implementation
    # However, the results were not better than simply using the whole caption
    def get_relevant_caption(self, doc: Doc) -> str:
        """Return a shortened caption derived from the dependency parse of ``doc``.

        Falls back to the full text when no rule yields more than one token.
        """
        # First "acl" dependencies (mainly -ing verbs), then "relcl" ones
        # (subject which/that something): keep the text preceding the clause.
        for dep_label in ("acl", "relcl"):
            for token in doc:
                if dep_label in token.dep_:
                    clause_start: int = next(iter(token.subtree)).i
                    head: Span = doc[0:clause_start]
                    if len(head) > 1:
                        return str(head)

        # Subjects: keep the whole subtree of the subject token.
        for token in doc:
            if "subj" in token.dep_:
                subtree = list(token.subtree)
                start: int = subtree[0].i
                end: int = subtree[-1].i + 1
                sent_subj: Span = doc[start:end]
                if len(sent_subj) > 1:
                    return str(sent_subj)
        return str(doc)

Dataset / Dataloader utility functions

In [None]:
def custom_collate(
    batch: List[Tuple[BatchSample, torch.Tensor]]
) -> Tuple[List[BatchSample], torch.Tensor]:
    """Collate dataset items into a batch.

    The samples are returned as a list of fresh `BatchSample` objects, while the
    bounding boxes, which all share the same size, are stacked into one tensor.
    """
    samples: List[BatchSample] = [
        BatchSample(item.image, item.caption) for item, _ in batch
    ]
    stacked_boxes: torch.Tensor = torch.stack([box for _, box in batch])
    return samples, stacked_boxes


def transform_sample(
    image: Image.Image,
    box: Tensor,
    target_size: int = 224,
    device: device = torch.device("cpu"),
) -> Tuple[Tensor, Tensor]:
    """Resize an image to a square and rescale its bounding box accordingly.

    Args:
        image: PIL image to transform.
        box: Bounding box in ``xyxy`` format holding four values, optionally
            with a leading dimension of size one.
        target_size: Side length of the square output image.
        device: Device on which both returned tensors are placed.

    Returns:
        The image as a 3-channel tensor of spatial size
        ``target_size x target_size`` and the rescaled, rounded box as a tensor
        of four values.
    """
    # PIL reports the size as (width, height).
    width, height = image.size
    x_scale = target_size / width
    y_scale = target_size / height

    # Resize already produces exactly (target_size, target_size), so no crop is
    # needed afterwards.
    trans = T.Compose(
        transforms=[
            T.Resize((target_size, target_size)),
            T.PILToTensor(),
        ]
    )
    # Force three channels so grayscale/RGBA inputs give the same tensor shape.
    image_tensor: Tensor = trans(image.convert("RGB")).to(device)  # type: ignore

    xmin, ymin, xmax, ymax = box.squeeze(0)

    # Scale and round with torch so the values never leave the tensor world
    # (x coordinates use the horizontal factor, y coordinates the vertical one).
    scaled = [
        torch.round(coordinate * factor)
        for coordinate, factor in zip(
            (xmin, ymin, xmax, ymax), (x_scale, y_scale, x_scale, y_scale)
        )
    ]
    bbox_tensor: Tensor = torch.stack(scaled).to(device)
    return image_tensor, bbox_tensor

In [None]:
# Simple forward pass on the pretrained CLIP text encoder
class TextEncoder(nn.Module):
    """Frozen forward pass through the pretrained CLIP ("RN50") text encoder."""

    def __init__(self) -> None:
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pretrained_model: CLIP = clip.load("RN50", device=self.device)[0]

    @torch.no_grad()
    def forward(self, tokenized_caption: Tensor) -> Tensor:
        """Encode tokenized captions into float32 text features.

        Args:
            tokenized_caption: Token ids as produced by ``clip.tokenize``.
        """
        # Convert the dtype only and keep the tokens on the model's device.
        # `.to(torch.IntTensor())` also adopts the device of that (CPU) tensor,
        # which breaks the embedding lookup when the model lives on the GPU.
        model_device = next(self.pretrained_model.parameters()).device
        tokens = tokenized_caption.to(device=model_device, dtype=torch.int32)
        out: Tensor = self.pretrained_model.encode_text(tokens)
        # NOTE(review): the CLIP weights may be half precision on CUDA; casting
        # keeps the features compatible with float32 layers downstream (no-op
        # when the output is already float32).
        return out.float()

In [None]:
# CLIP backbone is a modified ResNet with an attention layer for global pooling
# We get all layers as outputs and project them to the same dimension as the text encodings
class VisualEncoder(nn.Module):
    """Frozen CLIP ResNet backbone with one trainable projection per stage.

    Forward hooks capture the outputs of ``layer1`` .. ``layer4``; each one is
    pooled, flattened and linearly projected to ``cfg["output_dim"]``. The
    backbone's own output is returned unprojected under the key ``"output"``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pretrained_model: ModifiedResNet = clip.load("RN50", device=self.device)[
            0
        ].visual  # type: ignore
        assert isinstance(self.pretrained_model, ModifiedResNet)

        # Freeze the backbone
        for param in self.pretrained_model.parameters():
            param.requires_grad = False
        self.pretrained_model.eval()

        # Register hooks to get the output of all layers
        self.layers_outputs: OrderedDict[str, Tensor] = OrderedDict()
        for layer_name in ("layer1", "layer2", "layer3", "layer4"):
            stage = getattr(self.pretrained_model, layer_name)
            stage.register_forward_hook(self.hook_fn(layer_name))  # type: ignore

        # Project the output of each layer to the same dimensionality as the text features
        cfg = Config.get_instance().visual_encoder  # type: ignore
        resnet_resolution = cfg["resnet_resolution"]

        # Each layer is projected through a different sequence of layers since the initial dimensionality is different.
        # nn.ModuleList (not a plain list) registers the projections as
        # sub-modules, so their weights are returned by parameters(), saved in
        # state_dict() and moved by .to(). Checkpoints written while they were
        # kept in a plain list do not contain these weights.
        self.layers_projections: nn.ModuleList = nn.ModuleList()
        for _ in range(4):
            resnet_resolution //= 2
            layer_projection: nn.Sequential = nn.Sequential(
                nn.AdaptiveAvgPool2d(resnet_resolution),
                nn.Flatten(start_dim=1),
                nn.LazyLinear(cfg["output_dim"], device=self.device),
                nn.ReLU(),
            )
            self.layers_projections.append(layer_projection)

        # Dry run: materialise the lazy linear layers so an optimizer built from
        # parameters() only ever sees initialised weights.
        with torch.no_grad():
            resolution = self.pretrained_model.input_resolution
            self.forward(torch.zeros(1, 3, resolution, resolution, device=self.device))
        self.layers_outputs = OrderedDict()

    def train(self, mode: bool = True) -> "VisualEncoder":
        """Switch the projections' mode while keeping the frozen backbone in eval mode.

        Without this, ``.train()`` would also put the backbone's normalisation
        layers in training mode and change the supposedly frozen features.
        """
        super().train(mode)
        self.pretrained_model.eval()
        return self

    def forward(self, batch: Tensor) -> OrderedDict[str, Tensor]:
        """Return the projected stage outputs plus the backbone output.

        Args:
            batch: Batch of images accepted by the CLIP visual backbone.

        Returns:
            Ordered mapping with one entry per hooked layer and a final
            ``"output"`` entry, all cast to float32.
        """
        # Reset the dictionary
        self.layers_outputs = OrderedDict()

        with torch.no_grad():
            out: Tensor = self.pretrained_model(batch)

        # Iterate over a snapshot because the entries are replaced in place.
        for idx, (layer_name, layer_output) in enumerate(
            list(self.layers_outputs.items())
        ):
            # NOTE(review): the backbone may emit half precision on CUDA while
            # the projections are float32, hence the cast (no-op for float32).
            self.layers_outputs[layer_name] = self.layers_projections[idx](
                layer_output.float()
            )
        self.layers_outputs["output"] = out.float()

        return self.layers_outputs

    def hook_fn(self, layer: str) -> Callable[[nn.Module, Tensor, Tensor], None]:
        """Build a forward hook storing a module's output under the key ``layer``."""

        def hook(module: nn.Module, input: Tensor, output: Tensor) -> None:
            self.layers_outputs[layer] = output

        return hook

In [None]:
# The model is composed of the visual encoder and the text encoder described above
# The attention layer is used to compute the attention between the text and each layer of visual features
# The attentions are then concatenated and passed through a regression head to predict the bounding box
class VGModel(nn.Module):
    """Visual-grounding model: encoders, a shared attention layer and a regression head.

    Every visual feature map attends to the text features; the attended features
    are concatenated and regressed to four bounding box values.

    Args:
        mlp_hidden_dim_1: Width of the regression head's first hidden layer.
        mlp_hidden_dim_2: Width of the regression head's second hidden layer.
    """

    def __init__(
        self,
        mlp_hidden_dim_1: int,
        mlp_hidden_dim_2: int,
    ) -> None:
        super().__init__()
        emb_dim = Config.get_instance().model["emb_dim"]  # type: ignore
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.visual_backbone: VisualEncoder = VisualEncoder().to(self.device)
        self.text_encoder: TextEncoder = TextEncoder().to(self.device)
        self.attention_layer = nn.MultiheadAttention(
            embed_dim=emb_dim,
            num_heads=4,
            batch_first=True,
            device=self.device,
        )
        # The head's input width is emb_dim * 5; the visual backbone returns four
        # projected layers plus its final output.
        head = MLP(
            emb_dim * 5, 4, hidden_dim_1=mlp_hidden_dim_1, hidden_dim_2=mlp_hidden_dim_2
        )
        self.reg_head: MLP = head.to(self.device)

    def forward(self, batch: List[BatchSample]) -> Tensor:
        """Predict one bounding box (four values) for every sample of ``batch``."""
        stacked_captions = torch.stack([item.caption for item in batch])
        text_features: Tensor = self.text_encoder(stacked_captions.squeeze(1))

        stacked_images: Tensor = torch.stack([item.image for item in batch])
        visual_features: OrderedDict[str, Tensor] = self.visual_backbone(stacked_images)

        # NOTE(review): the features passed here look 2-D (batch, emb_dim), and
        # nn.MultiheadAttention treats 2-D inputs as one unbatched sequence, so
        # attention may be computed across the samples of the batch - confirm
        # that this is intended.
        attended_features: List[Tensor] = [
            self.attention_layer(feature, text_features, text_features)[0]
            for feature in visual_features.values()
        ]

        return self.reg_head(torch.cat(attended_features, dim=1))

# Simple MLP with two hidden layers
class MLP(nn.Module):
    """Fully connected network: input -> hidden 1 -> hidden 2 -> output, ReLU in between."""

    def __init__(
        self, input_dim: int, output_dim: int, hidden_dim_1: int, hidden_dim_2: int
    ) -> None:
        super().__init__()
        widths = (input_dim, hidden_dim_1, hidden_dim_2, output_dim)
        stack: List[nn.Module] = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # No activation after the output layer.
        stack.pop()
        self.mlp = nn.Sequential(*stack)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the network to ``x``."""
        return self.mlp(x)


In [None]:
class Loss:
    """Training loss: mean GIoU loss plus ``l`` times the mean L1 loss.

    The most recently computed value is kept in ``self.loss``.
    """

    def __init__(self, l: float) -> None:
        self.l1_loss = nn.L1Loss(reduction="mean")
        self.giou_loss = generalized_box_iou_loss
        self.l = l
        self.loss: Tensor

    def compute(self, out: Tensor, bbox: Tensor) -> Tensor:
        """Compute, store and return the loss between predictions and targets."""
        giou_term = self.giou_loss(out, bbox, reduction="mean")
        l1_term = self.l1_loss(out, bbox)
        self.loss = giou_term + self.l * l1_term
        return self.loss

    def to_float(self) -> float:
        """Return the last computed loss as a Python float."""
        detached = self.loss.detach()
        return detached.cpu().item()


Main training code

In [None]:
def objective(trial: Trial) -> float:
    """Optuna objective: train a `VGModel` and return its mean validation IoU.

    Batch size, loss weight, hidden sizes and learning rate are sampled from
    ``trial``. The accuracy of every epoch is reported to the trial so that
    unpromising runs can be pruned early.

    Returns:
        The validation accuracy averaged over the epochs that were run.

    Raises:
        optuna.TrialPruned: When the pruner decides to stop the trial.
        RuntimeError: When no epoch is run, so there is nothing to average.
    """
    config = Config.get_instance()  # type: ignore
    train_dataset: VGDataset = VGDataset(
        dir_path=config.dataset_path,
        split=Split.TRAIN,
        output_bbox_type=BboxType.XYXY,
        transform_image=transform_sample,
    )

    val_dataset: VGDataset = VGDataset(
        dir_path=config.dataset_path,
        split=Split.VAL,
        output_bbox_type=BboxType.XYXY,
        transform_image=transform_sample,
    )

    batch_size = trial.suggest_int(
        "batch_size",
        1,
        10,
    )
    train_dataloader: DataLoader[Tuple[BatchSample, Tensor]] = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        collate_fn=custom_collate,
        shuffle=True,
        drop_last=True,
    )

    # Validation must see every sample exactly once: no shuffling, and no
    # dropped tail whose size would depend on the sampled batch size.
    val_dataloader: DataLoader[Tuple[BatchSample, Tensor]] = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        collate_fn=custom_collate,
        shuffle=False,
        drop_last=False,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Loss is the GIoU loss plus the L1 loss weighted by `l`
    l = trial.suggest_float("l", 0.0, 1.0)
    loss = Loss(l)
    losses: List[float] = []
    accuracies: List[float] = []

    hidden_dim_1 = trial.suggest_int("hidden_dim_1", 512, 2048)
    hidden_dim_2 = trial.suggest_int("hidden_dim_2", 128, 512)
    checkpoint_path = f"{config.logging['path']}model.pth"
    if config.logging["resume"]:
        # map_location lets a checkpoint written on another device be restored.
        checkpoint: Dict[str, Any] = torch.load(checkpoint_path, map_location=device)
        # NOTE(review): the hidden sizes come from the trial, not from the
        # checkpoint; loading fails if they differ from the saved model.
        model = VGModel(hidden_dim_1, hidden_dim_2).to(device)
        model.load_state_dict(checkpoint["model_state_dict"])
        lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
        optimizer = optim.AdamW(model.parameters(), lr=lr)
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=config.model["gamma"]
        )
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
        # The checkpoint is written after an epoch has finished, so training
        # resumes with the following one.
        start_epoch: int = checkpoint["epoch"] + 1
        # float() accepts both a stored float and a stored 0-d tensor.
        losses.append(float(checkpoint["loss"]))
    else:
        model = VGModel(hidden_dim_1, hidden_dim_2).train()
        lr = trial.suggest_float("lr", 1e-5, 1e-2)
        optimizer = optim.AdamW(model.parameters(), lr=lr)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=config.model["gamma"]
        )
        start_epoch = 0

    for epoch in tqdm(range(start_epoch, config.epochs), desc="Epochs"):
        print("-------------------- Training --------------------------")
        loss_epoch = train_one_epoch(train_dataloader, model, loss, optimizer)
        losses.append(loss_epoch.detach().cpu().item())
        lr_scheduler.step()

        # Evaluate on validation set for hyperparameter tuning
        print("-------------------- Validation ------------------------")
        accuracy = validate(val_dataloader, model)
        accuracies.append(accuracy)
        trial.report(accuracy, epoch)

        # Early stopping for non promising trials
        if trial.should_prune():
            raise optuna.TrialPruned()

        # Save model after each epoch
        if config.logging["save_model"]:
            os.makedirs(config.logging["path"], exist_ok=True)

            torch.save(
                obj={
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "lr_scheduler_state_dict": lr_scheduler.state_dict(),
                    # Plain float: no tensor (or autograd graph) in the file.
                    "loss": losses[-1],
                },
                f=checkpoint_path,
            )

        torch.clear_autocast_cache()

    if not accuracies:
        raise RuntimeError(
            f"No epoch was run (start_epoch={start_epoch}, epochs={config.epochs}); "
            "there is no validation accuracy to report."
        )
    return sum(accuracies) / len(accuracies)


def train_one_epoch(
    dataloader: DataLoader[Tuple[BatchSample, Tensor]],
    model: VGModel,
    loss: Loss,
    optimizer: torch.optim.Optimizer,
) -> Tensor:
    """Run one optimisation pass over ``dataloader``.

    Args:
        dataloader: Yields ``(batch, bbox)`` pairs.
        model: Model called on each batch.
        loss: Loss object whose ``compute(out, bbox)`` gives the value to minimise.
        optimizer: Optimizer stepped once per batch.

    Returns:
        Mean of the per-batch losses as a tensor detached from the autograd graph.
    """
    train_loss: List[Tensor] = []

    for batch, bbox in tqdm(dataloader, desc="Batches"):
        # Forward pass
        out: Tensor = model(batch)

        # Loss and metrics
        batch_loss: Tensor = loss.compute(out, bbox)

        # Backward pass
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Keep only the value: storing the attached tensor would hold on to the
        # autograd graph of every batch until the end of the epoch.
        train_loss.append(batch_loss.detach())

    return torch.stack(train_loss).mean()


@torch.no_grad()
def validate(
    dataloader: DataLoader[Tuple[BatchSample, Tensor]], model: VGModel
) -> float:
    """Return the mean IoU between predicted and ground-truth boxes.

    The model is evaluated in eval mode; the training flag of every sub-module
    is restored afterwards exactly as it was found.

    Raises:
        RuntimeError: If ``dataloader`` yields no batch.
    """
    # Remember each module's own flag: model.train(...) would overwrite the flag
    # of sub-modules that were deliberately left in a different mode.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()

    ious: List[Tensor] = []
    try:
        for batch, bbox in tqdm(dataloader, desc="Batches"):
            # Forward pass
            out: Tensor = model(batch)

            # The diagonal pairs each prediction with its own ground truth.
            ious.append(torch.diagonal(box_iou(out, bbox)))
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    if not ious:
        raise RuntimeError("The validation dataloader produced no batches.")

    # Average over samples so a smaller final batch is not over-weighted.
    return torch.cat(ious).mean().cpu().item()


def main() -> None:
    """Run the Optuna study and report the best hyperparameters.

    The study is stored in a local SQLite file and reloaded if it already
    exists. Relies on the notebook-level ``logging`` and ``sys`` imports.
    """
    optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
    study_name = "test"
    study = optuna.create_study(
        study_name=study_name,
        storage=f"sqlite:///{study_name}",
        direction="maximize",
        load_if_exists=True,
    )
    study.optimize(objective, n_trials=3, timeout=600)

    # `study.best_trial` raises when no trial has completed (for instance when
    # every trial was pruned or the timeout elapsed first).
    completed = [
        t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE
    ]
    if not completed:
        print("No trial completed; there are no best hyperparameters to report.")
        return

    trial = study.best_trial
    print(f"Best hyperparameters: {trial.params}")
    fig = plot_optimization_history(study)
    fig.show()

In [None]:
# Launch the hyperparameter search defined in main().
main()