# Deep Learning Project

In [None]:
%pip install ftfy regex tqdm ultralytics optuna albumentations torchviz
%pip install git+https://github.com/openai/CLIP.git

In [None]:
import gc
import json
import os
import pickle
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, OrderedDict, Tuple

import albumentations as A
import cv2
import gdown
import matplotlib.pyplot as plt
import numpy as np
import optuna
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from albumentations.pytorch import ToTensorV2
from clip import clip
from clip.model import CLIP, ModifiedResNet
from optuna.trial import Trial
from optuna.visualization import plot_optimization_history
from PIL import Image
from torch import Tensor, device, tensor
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image
from torchvision.ops import box_convert, box_iou, generalized_box_iou_loss
from torchvision.utils import draw_bounding_boxes
from tqdm.notebook import tqdm

In [None]:
# Download dataset and save under data/raw/ only if not already downloaded
%cd /content/
url = "https://drive.google.com/uc?id=1xijq32XfEm6FPhUb7RsZYWHc2UuwVkiq"
if not os.path.exists("data/raw/refcocog.tar.gz"):
    print("Downloading dataset...")
    gdown.download(url=url, output="data/raw/", quiet=False, resume=True)
if not os.path.exists("data/raw/refcocog/"):
    print("Extracting dataset...")
    !tar -xf data/raw/refcocog.tar.gz -C data/raw

The preprocessed samples can be downloaded from Google Drive by executing the following cell. Otherwise, preprocessing will run on the fly, and the resulting files will be saved only temporarily in the Colab environment.

In [None]:
# Download preprocessed dataset
# Optional: fetches the preprocessed sample files from a shared Drive folder
# into data/ relative to the current directory (/content when the cells above
# ran in order).
# NOTE(review): VGDataset's default preprocessed_path is ../data/processed/
# (resolved from /content/src) — confirm the Drive folder ends up there.
url = "https://drive.google.com/drive/folders/1jaJV40dneOckZn7WHMQyd2jBh7A8534N"
gdown.download_folder(url=url, output="data/", quiet=False)

In [None]:
# Copy configuration file from repository
%cd /content/
!wget https://raw.githubusercontent.com/ManuelaCorte/DLProject/master/config.json

In [None]:
# Move forlder simply for consistentcy with repository implementation
%mkdir src
%cd src

Utility class definitions

In [None]:
class Sample:
    """One referring-expression sample: an image file, a caption and a box.

    ``bounding_box`` is kept as a tensor; boxes built elsewhere in this
    notebook are one-row tensors, presumably of shape (1, 4).
    """

    def __init__(self, image_path: str, caption: str, bounding_box: Tensor) -> None:
        self.image_path = image_path
        self.caption = caption
        self.bounding_box = bounding_box

    def as_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict (box converted to nested lists)."""
        serialized: dict[str, Any] = dict(
            image_path=self.image_path,
            caption=self.caption,
            bounding_box=self.bounding_box.tolist(),
        )
        return serialized

    @staticmethod
    def fromJSON(json: dict[str, Any]) -> Any:
        """Rebuild a Sample from the dict produced by ``as_dict``.

        Suitable as a ``json.load`` ``object_hook``.
        """
        path, text, box = json["image_path"], json["caption"], json["bounding_box"]
        return Sample(path, text, Tensor(box))


class BatchSample:
    def __init__(self, image: Tensor, caption: Tensor) -> None:
        self.image: Tensor = image
        self.caption: Tensor = caption

    def to(self, device: device | str) -> Any:
        return self.__class__(self.image.to(device), self.caption.to(device))

    def __str__(self) -> str:
        return f"BatchSample(image={self.image.shape}, caption={self.caption.shape})"


# Dataset split names, matching the "split" field of the RefCOCOg references.
# Deliberately NOT decorated with @dataclass: on an Enum without fields,
# dataclass installs __eq__/__hash__ over an empty field tuple, which makes
# every member compare and hash equal to every other member.
class Split(Enum):
    TRAIN = "train"
    VAL = "val"
    TEST = "test"


# Used in the baseline implementation
@dataclass(frozen=True, eq=True)
class Result:
    """Immutable (bounding box, score) pair produced by the baseline."""

    # Predicted box; its coordinate format is set by the producer — TODO confirm.
    bounding_box: Tensor
    # Confidence score associated with bounding_box.
    score: Tensor


# Supported bounding-box layouts:
# XYXY: top left and bottom right corners
# XYWH: top left corner, width and height
# CXCWH: center coordinates, width and height
# NOTE: the CXCWH value "cxcwh" is not a torchvision box_convert format name
# (torchvision calls this layout "cxcywh"), so it cannot be passed there as-is.
# Not decorated with @dataclass: on an Enum without fields, dataclass would
# install an __eq__/__hash__ that makes all members compare and hash equal.
class BboxType(Enum):
    XYXY = "xyxy"
    XYWH = "xywh"
    CXCWH = "cxcwh"

In [None]:
class Singleton:
    """Class decorator that makes the decorated class a lazily created singleton.

    The decorated name becomes a Singleton wrapper: calling it raises
    TypeError, ``get_instance()`` returns the shared instance, and
    ``isinstance(obj, Decorated)`` still works.
    """

    def __init__(self, decorated_class: Any) -> None:
        self._decorated = decorated_class

    def get_instance(self) -> Any:
        """
        Return the shared instance, constructing it (with no arguments) on the
        first call and reusing it on every later call.
        """
        if not hasattr(self, "_instance"):
            self._instance = self._decorated()
        return self._instance  # type: ignore

    def __call__(self) -> None:
        raise TypeError("Singletons must be accessed through get_instance() method.")

    def __instancecheck__(self, inst: Any) -> bool:
        # Delegate isinstance checks to the wrapped class.
        return isinstance(inst, self._decorated)


# The configuration JSON is read only once; every later get_instance() call
# returns the same Config object.
@Singleton
class Config:
    """Attribute bag exposing each top-level key of ``../config.json``.

    The path is relative, so it is resolved against the working directory at
    the moment the instance is first created.
    """

    def __init__(self) -> None:
        with open(file="../config.json", mode="r") as config_file:
            settings: Dict[str, Any] = json.load(fp=config_file)
        for key, value in settings.items():
            setattr(self, key, value)

Dataset preprocessing

In [None]:
def get_samples(dir_path: str) -> Tuple[List[Sample], List[Sample], List[Sample]]:
    """Parse the raw RefCOCOg (umd) annotations into train/val/test samples.

    Args:
        dir_path: dataset root containing ``annotations/`` and ``images/``;
            used as a string prefix, so it must end with a separator.

    Returns:
        ``(train, val, test)`` sample lists. Captions are the longest sentence
        of each reference; boxes are XYXY pixel coordinates of shape (1, 4).

    Raises:
        ValueError: if a reference has an unknown split name.
        KeyError: if a reference points to an image or annotation id missing
            from ``instances.json``.
    """
    with open(dir_path + "annotations/instances.json", "r") as inst, open(
        dir_path + "annotations/refs(umd).p", "rb"
    ) as refs:
        instances = json.load(inst)
        # NOTE(security): pickle.load can execute arbitrary code; only run this
        # on a trusted copy of the dataset.
        references = pickle.load(refs)

    # Build id -> value lookups once. Scanning the image/annotation lists for
    # every reference (as get_image_path / get_bounding_box do) is quadratic.
    image_names = _first_value_by_id(instances["images"], "file_name")
    ann_boxes = _first_value_by_id(instances["annotations"], "bbox")

    samples_by_split: Dict[str, List[Sample]] = {
        Split.TRAIN.value: [],
        Split.VAL.value: [],
        Split.TEST.value: [],
    }
    for ref in tqdm(references, desc="Processing dataset"):
        split = ref["split"]
        if split not in samples_by_split:
            raise ValueError(f"Invalid split: {split}")
        image_path = dir_path + "images/" + image_names[ref["image_id"]]
        caption = get_caption(ref["sentences"])
        # COCO boxes are [x, y, w, h]; convert to corner (pascal_voc) format.
        bbox = box_convert(
            tensor([ann_boxes[ref["ann_id"]]]), in_fmt="xywh", out_fmt=BboxType.XYXY.value
        )
        samples_by_split[split].append(Sample(image_path, caption, bbox))
    return (
        samples_by_split[Split.TRAIN.value],
        samples_by_split[Split.VAL.value],
        samples_by_split[Split.TEST.value],
    )


def _first_value_by_id(entries: List[Dict[str, Any]], field: str) -> Dict[int, Any]:
    """Map each entry's ``id`` to its ``field``, keeping the first entry per id."""
    index: Dict[int, Any] = {}
    for entry in entries:
        if entry["id"] not in index:
            index[entry["id"]] = entry[field]
    return index


def get_image_path(dir_path: str, img_id: int, instances: Dict[str, Any]) -> str:
    """Return ``<dir_path>images/<file_name>`` for the image whose id is ``img_id``.

    ``dir_path`` is used as a raw prefix, so it must end with a separator. The
    first matching entry of ``instances["images"]`` wins; StopIteration is
    raised when no entry matches.
    """
    matching_names = (
        entry["file_name"] for entry in instances["images"] if entry["id"] == img_id
    )
    return "".join((dir_path, "images/", next(matching_names)))


def get_caption(captions: List[Dict[str, Any]]) -> str:
    """Return the longest ``"sent"`` among ``captions`` (earliest wins ties).

    Raises IndexError when ``captions`` is empty.
    """
    best_sentence: str = captions[0]["sent"]
    for entry in captions[1:]:
        sentence = entry["sent"]
        if len(sentence) > len(best_sentence):
            best_sentence = sentence
    return best_sentence


# Bounding boxes are converted to the corner (XYXY / pascal_voc) format used by
# torchvision and by the albumentations pipeline.
def get_bounding_box(ann_id: int, instances: Dict[str, Any]) -> Tensor:
    """Return the XYXY box of annotation ``ann_id`` as a (1, 4) tensor.

    The first matching entry of ``instances["annotations"]`` is used; COCO
    stores it as ``[x, y, w, h]``. Raises StopIteration if no entry matches.
    """
    coco_box = next(
        annotation["bbox"]
        for annotation in instances["annotations"]
        if annotation["id"] == ann_id
    )
    return box_convert(tensor([coco_box]), in_fmt="xywh", out_fmt=BboxType.XYXY.value)


# If the files already exist, don't preprocess again
def preprocess(in_path: str, out_path: str) -> None:
    """Write ``{train,val,test}_samples.json`` under ``out_path`` from the raw dataset.

    Does nothing when all three files already exist. ``out_path`` is used as
    a string prefix, so it must end with a separator; the directory is created
    if missing.
    """
    out_files = [f"{out_path}{name}_samples.json" for name in ("train", "val", "test")]
    if all(os.path.exists(path) for path in out_files):
        return
    samples_per_split = get_samples(in_path)

    if out_path:
        os.makedirs(out_path, exist_ok=True)
    # `with` guarantees each file is flushed and closed before returning.
    for path, samples in zip(out_files, samples_per_split):
        with open(path, "w") as fp:
            json.dump(samples, fp, default=Sample.as_dict)

In [None]:
# The Dataset contains samples with an image with a bounding box and a caption associated with the bounding box.


class VGDataset(Dataset[Tuple[BatchSample, Tensor]]):
    """RefCOCOg (umd) visual-grounding dataset.

    Each item is ``(BatchSample(image, caption_tokens), bbox)``.

    Args:
        dir_path: dataset root containing ``annotations/`` and ``images/``;
            used as a plain string prefix, so it must end with a separator.
        split: split to load (matched against each reference's "split").
        output_bbox_type: box format built from the raw annotations. It is
            ignored when ``preprocessed`` is True (those files are XYXY), and
            ``transform_sample`` reads boxes as pascal_voc (XYXY), so only
            XYXY is consistent with ``transform=True``.
        augment: forwarded to ``transform_sample`` to enable augmentations.
        transform: if True, return the normalized image and a box scaled to
            [0, 1]; otherwise the raw image tensor and a copy of the stored box.
        preprocessed: load ``<split>_samples.json`` from ``preprocessed_path``
            (generating the files first if missing).
        preprocessed_path: directory prefix of the preprocessed JSON files.

    NOTE(review): captions differ between loading paths — ``get_caption``
    below prefixes "find the region that corresponds to the description",
    whereas files generated by ``preprocess()`` store the plain sentence.
    Confirm which form the model should see.
    """

    # torchvision.ops.box_convert format names, keyed by BboxType value.
    # Keyed by string so dispatch never relies on enum equality; torchvision
    # spells the center format "cxcywh".
    _BOX_CONVERT_FORMATS: Dict[str, str] = {
        "xyxy": "xyxy",
        "xywh": "xywh",
        "cxcwh": "cxcywh",
    }

    def __init__(
        self,
        dir_path: str,
        split: Split,
        output_bbox_type: BboxType,
        augment: bool,
        transform: bool = True,
        preprocessed: bool = False,
        preprocessed_path: str = "../data/processed/",
    ) -> None:
        super().__init__()
        self.dir_path: str = dir_path
        self.split: Split = split
        self.output_bbox_type: BboxType = output_bbox_type
        self.augment: bool = augment
        self.transform: bool = transform
        self.device: device = torch.device(
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
        self.samples: List[Sample]
        if preprocessed:
            preprocess(dir_path, preprocessed_path)
            with open(
                preprocessed_path + f"{self.split.value}_samples.json", "rb"
            ) as samples:
                # object_hook turns every decoded JSON object into a Sample.
                self.samples = json.load(samples, object_hook=Sample.fromJSON)
        else:
            self.samples = self.get_samples()

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, ref_id: int) -> Tuple[BatchSample, Tensor]:
        sample = self.samples[ref_id]
        # truncate=True: without it clip.tokenize raises for captions longer
        # than CLIP's 77-token context instead of clipping them.
        caption: Tensor = clip.tokenize(sample.caption, truncate=True)  # type: ignore
        if self.transform:
            # The context manager closes the image file once transform_sample
            # has converted it to an array.
            with Image.open(sample.image_path) as pil_image:
                image, bbox = transform_sample(
                    pil_image,
                    sample.bounding_box,
                    self.augment,
                    device=self.device,
                )
        else:
            image = read_image(sample.image_path)
            # The stored box is already a one-row (1, 4) tensor; wrapping it in
            # torch.tensor([...]) raises for multi-element tensors, so copy it.
            bbox = sample.bounding_box.clone()
        return BatchSample(image, caption), bbox

    def get_samples(self) -> List[Sample]:
        """Build this split's samples from the raw annotation files."""
        with open(self.dir_path + "annotations/instances.json", "r") as inst, open(
            self.dir_path + "annotations/refs(umd).p", "rb"
        ) as refs:
            instances = json.load(inst)
            # NOTE(security): pickle.load can execute arbitrary code; only use
            # a trusted copy of the dataset.
            references = pickle.load(refs)

        # Index images/annotations once instead of scanning them per reference.
        image_names = self._first_value_by_id(instances["images"], "file_name")
        ann_boxes = self._first_value_by_id(instances["annotations"], "bbox")

        samples: List[Sample] = []
        for ref in references:
            if self.split.value != ref["split"]:
                continue
            image_path = self.dir_path + "images/" + image_names[ref["image_id"]]
            caption = self.get_caption(ref["sentences"])
            bbox = self._convert_bbox(ann_boxes[ref["ann_id"]])
            samples.append(Sample(image_path, caption, bbox))
        return samples

    def get_image_path(self, img_id: int, instances: Dict[str, Any]) -> str:
        """Return the image path for ``img_id`` (linear scan of ``instances``)."""
        image_name = next(
            image["file_name"] for image in instances["images"] if image["id"] == img_id
        )
        path = self.dir_path + "images/" + image_name
        return path

    def get_caption(self, captions: List[Dict[str, Any]]) -> str:
        """Return the longest sentence, prefixed with a grounding prompt."""
        longest_caption = captions[0]
        for caption in captions:
            if len(caption["sent"]) > len(longest_caption["sent"]):
                longest_caption = caption
        return f"find the region that corresponds to the description {longest_caption['sent']}"

    def get_bounding_box(self, ann_id: int, instances: Dict[str, Any]) -> Tensor:
        """Return annotation ``ann_id``'s box in ``output_bbox_type`` format, shape (1, 4)."""
        bbox = next(
            ann["bbox"] for ann in instances["annotations"] if ann["id"] == ann_id
        )
        return self._convert_bbox(bbox)

    def _convert_bbox(self, bbox: List[float]) -> Tensor:
        """Convert one COCO ``[x, y, w, h]`` box to ``output_bbox_type`` (shape (1, 4))."""
        out_fmt = self._BOX_CONVERT_FORMATS[self.output_bbox_type.value]
        return box_convert(tensor([bbox]), in_fmt="xywh", out_fmt=out_fmt)

    @staticmethod
    def _first_value_by_id(entries: List[Dict[str, Any]], field: str) -> Dict[int, Any]:
        """Map each entry's ``id`` to its ``field``, keeping the first entry per id."""
        index: Dict[int, Any] = {}
        for entry in entries:
            if entry["id"] not in index:
                index[entry["id"]] = entry[field]
        return index

Dataset / Dataloader utility functions

In [None]:
def custom_collate(
    batch: List[Tuple[BatchSample, torch.Tensor]]
) -> Tuple[List[BatchSample], torch.Tensor]:
    """Collate ``(BatchSample, bbox)`` pairs for a DataLoader.

    Samples are returned as a plain list (each re-wrapped in a fresh
    BatchSample that shares the original tensors); the boxes are stacked
    along a new leading dimension.
    """
    samples: List[BatchSample] = [
        BatchSample(item.image, item.caption) for item, _ in batch
    ]
    stacked_boxes = torch.stack([box for _, box in batch])
    return samples, stacked_boxes


# CLIP's per-channel image normalization statistics.
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def _clip_transforms(augment: bool, target_size: int) -> List[Any]:
    """Build the albumentations transform list used by ``transform_sample``.

    Pixel-level augmentations run on the raw uint8 image *before* Normalize.
    ColorJitter and similar ops assume raw or [0, 1] pixel values, and
    albumentations clips float images to [0, 1] in them, which would wipe out
    negative normalized values. Normalize is therefore the last numeric step
    before ToTensorV2.
    """
    transforms: List[Any] = [
        # Square resize (aspect ratio is NOT preserved, unlike CLIP's own
        # shorter-side resize); the center crop is then a no-op.
        A.Resize(target_size, target_size, interpolation=cv2.INTER_CUBIC, p=1),
        A.CenterCrop(target_size, target_size, p=1.0),
    ]
    if augment:
        transforms += [
            A.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0),
            A.GaussianBlur(p=1),
            A.PixelDropout(dropout_prob=0.02),
            A.Rotate(limit=20),
        ]
    transforms += [
        A.Normalize(mean=_CLIP_MEAN, std=_CLIP_STD, max_pixel_value=255.0, p=1.0),
        ToTensorV2(),
    ]
    return transforms


# Transform the image with CLIP-style resizing/normalization and scale the
# bounding box coordinates to [0, 1] so they are independent of image size.
def transform_sample(
    image: Image.Image,
    box: Tensor,
    augment: bool,
    device: device,
    target_size: int = 224,
) -> Tuple[Tensor, Tensor]:
    """Resize and normalize ``image`` and rescale its XYXY ``box`` to [0, 1].

    Args:
        image: PIL image; converted to RGB if needed.
        box: pascal_voc (XYXY, pixel) box(es); the first box after the
            transform is returned.
        augment: apply the training augmentations.
        device: unused; kept for API compatibility.
        target_size: output side length in pixels.

    Returns:
        ``(image_tensor, bbox)`` where ``bbox`` is a float32 tensor of four
        coordinates divided by ``target_size``. The target does not require
        grad.

    Raises:
        ValueError: if the augmentation pipeline dropped the box.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    pipeline = A.Compose(
        transforms=_clip_transforms(augment, target_size),
        bbox_params=A.BboxParams(format="pascal_voc", label_fields=[]),
    )
    transformed_sample: Dict[str, Any] = pipeline(
        image=np.array(image), bboxes=box.tolist()
    )

    boxes = transformed_sample["bboxes"]
    if len(boxes) == 0:
        raise ValueError("bounding box was dropped by the augmentation pipeline")
    # The output image is target_size x target_size, so dividing normalizes.
    bbox_tensor = torch.as_tensor(boxes[0], dtype=torch.float32) / target_size
    return transformed_sample["image"], bbox_tensor

In [None]:
# Frozen CLIP RN50 text tower: a plain forward pass with no trainable weights.


class TextEncoder(nn.Module):
    """Encode tokenized captions with a frozen, float32 CLIP RN50 text encoder."""

    def __init__(self) -> None:
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        clip_model: CLIP = clip.load("RN50", device=self.device)[0]
        self.pretrained_model: CLIP = clip_model
        self.pretrained_model.float()

        # Freeze every CLIP weight; requires_grad_ applies to all parameters.
        self.pretrained_model.requires_grad_(False)

    @torch.no_grad()
    def forward(self, tokenized_caption: Tensor) -> Tensor:
        """Return one embedding per caption.

        The result is computed without autograd and then marked as a leaf
        that requires grad.
        """
        embeddings: Tensor = self.pretrained_model.encode_text(tokenized_caption)
        return embeddings.to(self.device).requires_grad_(True)

In [None]:
# Class that gets the output of every stage of the backbone.
# The CLIP backbone is a modified ResNet with an attention layer for global pooling.


class VisualEncoder(nn.Module):
    """Frozen CLIP RN50 image tower exposing projected multi-stage features.

    ``forward`` returns an OrderedDict with keys "layer1".."layer4" (each
    backbone stage's output, pooled, flattened and linearly projected to
    ``visual_encoder.output_dim``) followed by "output" (the backbone's final
    pooled embedding). Only the projection layers are trainable.
    """

    def __init__(self) -> None:
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pretrained_model: ModifiedResNet = clip.load("RN50", device=self.device)[
            0
        ].visual  # type: ignore
        self.pretrained_model.float()
        assert isinstance(self.pretrained_model, ModifiedResNet)

        # Freeze the backbone
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        # Forward hooks capture the output of each backbone stage.
        self.layers_outputs: OrderedDict[str, Tensor] = OrderedDict()
        self.pretrained_model.layer1.register_forward_hook(self.hook_fn("layer1"))  # type: ignore
        self.pretrained_model.layer2.register_forward_hook(self.hook_fn("layer2"))  # type: ignore
        self.pretrained_model.layer3.register_forward_hook(self.hook_fn("layer3"))  # type: ignore
        self.pretrained_model.layer4.register_forward_hook(self.hook_fn("layer4"))  # type: ignore

        # Project each stage's output to the text-feature dimensionality.
        # Stage i is assumed to have resnet_channels * 2**(i-1) channels and is
        # pooled to resnet_resolution / 2**i; presumably the config matches
        # RN50's stage shapes — TODO confirm against config.json.
        cfg = Config.get_instance().visual_encoder  # type: ignore
        resnet_resolution: int = cfg["resnet_resolution"]
        resnet_channels: int = cfg["resnet_channels"]

        self.layers_projections = nn.ModuleList()
        for _ in range(4):
            resnet_resolution //= 2
            in_features: int = resnet_channels * resnet_resolution * resnet_resolution
            resnet_channels *= 2
            layer_projection: nn.Sequential = nn.Sequential(
                nn.AdaptiveAvgPool2d(resnet_resolution),
                nn.Flatten(start_dim=1),
                nn.Linear(in_features, cfg["output_dim"], device=self.device),
            )
            self.layers_projections.append(layer_projection)

    def forward(self, batch: Tensor) -> OrderedDict[str, Tensor]:
        """Run the backbone and return projected stage features plus the final embedding."""
        # Reset the dictionary; the hooks repopulate it during the backbone call.
        self.layers_outputs = OrderedDict()

        # Only the frozen backbone runs without autograd. The projections below
        # are trainable, so they must run with gradients enabled — wrapping the
        # whole method in torch.no_grad() would stop them from ever learning.
        with torch.no_grad():
            out: Tensor = self.pretrained_model(batch)

        for idx, (layer_name, layer_output) in enumerate(self.layers_outputs.items()):
            self.layers_outputs[layer_name] = self.layers_projections[idx](layer_output)
        self.layers_outputs["output"] = out

        return self.layers_outputs

    def hook_fn(self, layer: str) -> Callable[[nn.Module, Tensor, Tensor], None]:
        """Return a forward hook that stores a stage's output under ``layer``."""

        def hook(module: nn.Module, input: Tensor, output: Tensor) -> None:
            # No grad is needed for the frozen activation itself; the projection
            # weights applied to it still receive gradients.
            self.layers_outputs[layer] = output

        return hook

In [None]:
# The model is composed of the visual encoder and the text encoder described above.
# One attention layer per visual feature attends from the text to that feature.


class VGModel(nn.Module):
    """Visual-grounding model: frozen CLIP encoders + attention fusion + MLP box head.

    The caption embedding is the attention query against each of the five
    visual features from ``VisualEncoder`` (four projected stages plus the
    final pooled embedding). The five attended vectors are concatenated and
    regressed to four box coordinates.

    NOTE(review): each visual feature is a single vector, so key/value have
    sequence length 1 and the attention softmax is identically 1. The
    attended output then depends only on the visual feature (through the
    value/output projections), not on the text query. Confirm whether spatial
    tokens were intended.
    NOTE(review): the head output is unconstrained, while the targets are
    normalized XYXY boxes and the GIoU loss assumes x1 <= x2 and y1 <= y2.
    """

    def __init__(
        self,
        mlp_hidden_dim_1: int,
        mlp_hidden_dim_2: int,
    ) -> None:
        super().__init__()
        cfg = Config.get_instance().model  # type: ignore
        emb_dim: int = cfg["emb_dim"]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.visual_backbone: VisualEncoder = VisualEncoder().to(self.device)
        self.text_encoder: TextEncoder = TextEncoder().to(self.device)
        # One attention layer per visual feature (4 stages + final embedding).
        self.attention_layers: nn.ModuleList = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    embed_dim=emb_dim,
                    num_heads=4,
                    batch_first=True,
                    device=self.device,
                )
                for _ in range(5)
            ]
        )
        self.reg_head: MLP = MLP(
            emb_dim * 5, 4, hidden_dim_1=mlp_hidden_dim_1, hidden_dim_2=mlp_hidden_dim_2
        ).to(self.device)

    def train(self, mode: bool = True) -> "VGModel":
        """Set train/eval mode but always keep the frozen CLIP backbones in eval mode.

        The CLIP ResNet contains BatchNorm layers. In training mode they would
        keep updating their running statistics from small batches even though
        their weights are frozen, drifting away from the pretrained values.
        """
        super().train(mode)
        self.visual_backbone.pretrained_model.eval()
        self.text_encoder.pretrained_model.eval()
        return self

    def forward(self, batch: List[BatchSample]) -> Tensor:
        """Predict one 4-value box per sample; returns a (batch, 4) tensor."""
        # Each caption is (1, context_length); stack and drop the extra dim.
        captions: Tensor = (
            torch.stack([sample.caption for sample in batch]).squeeze(1).to(self.device)
        )
        text_features: Tensor = self.text_encoder(captions).unsqueeze(1)

        images: Tensor = torch.stack([sample.image for sample in batch]).to(self.device)
        visual_features: OrderedDict[str, Tensor] = self.visual_backbone(images)

        attended_features: List[Tensor] = []
        for i, visual_feature in enumerate(visual_features.values()):
            # (batch, dim) -> (batch, 1, dim): a single key/value token.
            visual_feature: Tensor = visual_feature.unsqueeze(1)
            attended_feature: Tensor = self.attention_layers[i](
                query=text_features, key=visual_feature, value=visual_feature
            )[0].squeeze(1)
            attended_features.append(attended_feature)

        aggregated_features: Tensor = torch.cat(attended_features, dim=1)

        return self.reg_head(aggregated_features)


class MLP(nn.Module):
    """Linear-ReLU-Linear-ReLU-Linear regression head (no final activation)."""

    def __init__(
        self, input_dim: int, output_dim: int, hidden_dim_1: int, hidden_dim_2: int
    ) -> None:
        super().__init__()
        widths = [input_dim, hidden_dim_1, hidden_dim_2, output_dim]
        layers: List[nn.Module] = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.pop()  # the output layer stays linear
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        return self.mlp(x)

In [None]:
class Loss:
    """Box regression loss: mean GIoU loss + ``l`` * mean Smooth-L1 loss."""

    def __init__(self, l: float) -> None:
        self.l1_loss = nn.SmoothL1Loss(reduction="mean")
        self.giou_loss = generalized_box_iou_loss
        self.l: float = l
        # Most recent value returned by compute(); unset until the first call.
        self.loss: Tensor

    def compute(self, out: Tensor, bbox: Tensor) -> Tensor:
        """Compute, remember and return the loss between predicted and target boxes."""
        giou_term = self.giou_loss(out, bbox, reduction="mean")
        smooth_l1_term = self.l1_loss(out, bbox)
        self.loss = giou_term + self.l * smooth_l1_term
        return self.loss

    def to_float(self) -> float:
        """Return the last computed loss as a Python float."""
        return self.loss.item()

Visualization utilities

In [None]:
def visualize(samples: List[Sample], predictions: Tensor) -> None:
    """Plot each sample's image with its ground-truth (green) and predicted (red) box.

    Boxes are drawn as given. ``sample.bounding_box`` is in pixel XYXY
    coordinates, so ``predictions`` must be in the same space. Presumably the
    model predicts boxes scaled to [0, 1], which then need rescaling first
    (e.g. with ``unnormalize_bbox``).
    """
    if not samples:
        return
    ncols = 3
    # Ceil division so a partial last row still gets axes.
    nrows = (len(samples) + ncols - 1) // ncols
    # squeeze=False keeps `ax` 2-D even when there is a single row.
    _, ax = plt.subplots(nrows, ncols, figsize=(24, 24), squeeze=False)
    for i, sample in enumerate(samples):
        img: Tensor = read_image(sample.image_path)
        bboxes: Tensor = torch.stack(
            [
                sample.bounding_box,
                predictions[i],
            ]
        ).squeeze(1)
        result: Tensor = draw_bounding_boxes(
            img, bboxes, width=2, colors=[(0, 255, 0), (255, 0, 0)]
        )
        cell = ax[i // ncols, i % ncols]
        cell.imshow(result.permute(1, 2, 0))
        cell.set_title(sample.caption)
        cell.axis("off")
    # Hide the unused cells of a partially filled last row.
    for j in range(len(samples), nrows * ncols):
        ax[j // ncols, j % ncols].axis("off")
    plt.axis("off")
    plt.show()


def unnormalize_bbox(image: Tensor, bbox: Tensor) -> Tensor:
    """Scale a [0, 1] XYXY box back to the pixel size of a (C, H, W) ``image``.

    ``bbox`` must hold exactly four values after squeezing a leading size-1
    dimension; the result is a new 1-D tensor of four pixel coordinates.
    """
    height, width = image.shape[1], image.shape[2]
    xmin, ymin, xmax, ymax = (coord.item() for coord in bbox.squeeze(0))
    return torch.tensor([xmin * width, ymin * height, xmax * width, ymax * height])


def visualize_network(model: torch.nn.Module, batch: List[BatchSample]) -> None:
    """Render the autograd graph of ``model(batch).mean()`` as a PNG named model_graph under ../runs."""
    # Imported lazily: torchviz (and the Graphviz binary it drives) is only
    # needed for this optional diagnostic, and make_dot was never imported.
    from torchviz import make_dot

    output: Tensor = model(batch)
    make_dot(
        output.mean(), params=dict(model.named_parameters()), show_attrs=True
    ).render("model_graph", directory="../runs", format="png")

Main training code

In [None]:
def train_one_epoch(
    dataloader: DataLoader[Tuple[BatchSample, Tensor]],
    model: VGModel,
    loss: Loss,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> Tensor:
    """Run one optimization pass over ``dataloader``.

    The loss is Smooth-L1 + GIoU, as computed by ``loss``. The model's
    train/eval mode is not changed here; the caller is responsible for it.

    Returns:
        The mean batch loss as a detached 0-d tensor.

    Raises:
        RuntimeError: (from torch.stack) if the dataloader yields no batches.
    """
    epoch_loss_list: List[Tensor] = []

    for batch, bbox in tqdm(dataloader, desc="Batches"):
        # BatchSample.to returns a moved copy, so the result must be kept.
        batch = [sample.to(device) for sample in batch]
        bbox = bbox.to(device)

        # Forward pass
        out: Tensor = model(batch)

        # Loss. Store a detached copy: appending the graph-attached loss would
        # accumulate autograd history across the whole epoch.
        batch_loss: Tensor = loss.compute(out, bbox)
        epoch_loss_list.append(batch_loss.detach())

        # Backward pass
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    return torch.stack(epoch_loss_list).mean()


@torch.no_grad()
def validate(
    dataloader: DataLoader[Tuple[BatchSample, Tensor]],
    model: VGModel,
    device: torch.device,
) -> float:
    """Return the mean IoU between predicted and target boxes over ``dataloader``.

    The IoU is averaged over individual samples, so a smaller last batch is
    weighted correctly. Puts ``model`` in eval mode and leaves it there.

    Raises:
        RuntimeError: (from torch.cat) if the dataloader yields no batches.
    """
    model.eval()
    ious: List[Tensor] = []
    for batch, bbox in tqdm(dataloader, desc="Batches"):
        # BatchSample.to returns a moved copy, so the result must be kept.
        batch = [sample.to(device) for sample in batch]
        bbox = bbox.to(device)

        out: Tensor = model(batch)

        # box_iou is pairwise (N x N); the diagonal pairs prediction i with target i.
        ious.append(torch.diagonal(box_iou(out, bbox)))

    return torch.cat(ious).mean().item()

In [None]:
def objective(trial: Trial) -> float:
    """Train one Optuna trial and return its mean validation IoU.

    Samples ``batch_size``, ``l`` (Smooth-L1 weight in the loss),
    ``hidden_dim_1``, ``hidden_dim_2`` and ``lr``. Trains for epochs
    ``start_epoch .. cfg.epochs - 1`` and reports the validation IoU after
    each epoch so the study's pruner can stop the trial early. When
    ``cfg.logging["resume"]`` is set, model, optimizer and scheduler state
    are restored from ``<cfg.logging["path"]>model.pth``, and training
    continues with the epoch after the one stored there.

    Returns:
        Mean of the per-epoch validation IoUs, or NaN if no epoch was run.
        Optuna records a NaN objective as a failed trial instead of aborting
        the study.

    Raises:
        optuna.TrialPruned: when the pruner stops the trial.
    """
    cfg = Config.get_instance()  # type: ignore
    train_dataset: VGDataset = VGDataset(
        dir_path=cfg.dataset_path,
        split=Split.TRAIN,
        output_bbox_type=BboxType.XYXY,
        augment=True,
        preprocessed=True,
    )
    print("Train dataset created. Dataset length ", len(train_dataset))

    val_dataset: VGDataset = VGDataset(
        dir_path=cfg.dataset_path,
        split=Split.VAL,
        output_bbox_type=BboxType.XYXY,
        augment=False,
        preprocessed=True,
    )
    print("Validation dataset created. Dataset length: ", len(val_dataset))

    batch_size = trial.suggest_int(
        "batch_size",
        1,
        10,
    )
    train_dataloader: DataLoader[Tuple[BatchSample, Tensor]] = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        collate_fn=custom_collate,
        shuffle=True,
        drop_last=True,
    )

    # Validate on every sample in a fixed order so scores are comparable.
    val_dataloader: DataLoader[Tuple[BatchSample, Tensor]] = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        collate_fn=custom_collate,
        shuffle=False,
        drop_last=False,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Loss is the GIoU loss plus the l-weighted smooth L1 loss.
    l = trial.suggest_float("l", 0.0, 1.0)
    loss_func = Loss(l)
    losses_list: List[float] = []
    accuracies_list: List[float] = []

    hidden_dim_1 = trial.suggest_int("hidden_dim_1", 512, 2048)
    hidden_dim_2 = trial.suggest_int("hidden_dim_2", 128, 512)
    if cfg.logging["resume"]:
        # map_location lets a GPU checkpoint be resumed on a CPU-only machine.
        checkpoint: Dict[str, Any] = torch.load(
            cfg.logging["path"] + "model.pth", map_location=device
        )
        model = VGModel(hidden_dim_1, hidden_dim_2).to(device)
        model.load_state_dict(checkpoint["model_state_dict"])
        # Use the same distribution as the fresh-start branch: a parameter
        # should keep one distribution across the study (storages may reject
        # a changed one). The restored optimizer state supersedes the value.
        lr = trial.suggest_float("lr", 1e-5, 1e-2)
        optimizer = optim.AdamW(model.parameters(), lr=lr)
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=cfg.model["gamma"]
        )
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
        # The checkpoint stores the last *completed* epoch; continue after it.
        start_epoch: int = checkpoint["epoch"] + 1
        # float() accepts both a float and the 0-d tensor older checkpoints stored.
        losses_list.append(float(checkpoint["loss"]))
    else:
        model = VGModel(hidden_dim_1, hidden_dim_2).train()
        lr = trial.suggest_float("lr", 1e-5, 1e-2)
        optimizer = optim.AdamW(model.parameters(), lr=lr)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=cfg.model["gamma"]
        )
        start_epoch = 0

    for epoch in tqdm(range(start_epoch, cfg.epochs), desc="Epochs"):
        print("-------------------- Training --------------------------")
        epoch_loss = train_one_epoch(
            train_dataloader, model, loss_func, optimizer, device
        )
        epoch_loss_value = epoch_loss.item()
        losses_list.append(epoch_loss_value)
        lr_scheduler.step()

        # Evaluate on validation set for hyperparameter tuning
        print("-------------------- Validation ------------------------")
        accuracy = validate(val_dataloader, model, device)
        accuracies_list.append(accuracy)
        trial.report(accuracy, epoch)
        print(f"Accuracy: {accuracy} at epoch {epoch}")

        # Early stopping for non promising trials
        if trial.should_prune():
            raise optuna.TrialPruned()

        # Save a checkpoint after each epoch
        if cfg.logging["save_model"]:
            save_dir: str = cfg.logging["path"]
            os.makedirs(save_dir, exist_ok=True)

            torch.save(
                obj={
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "lr_scheduler_state_dict": lr_scheduler.state_dict(),
                    "loss": epoch_loss_value,
                },
                f=f"{save_dir}model.pth",
            )

        torch.cuda.empty_cache()
        gc.collect()

    if not accuracies_list:
        # Nothing was trained (e.g. resumed from an already finished run).
        return float("nan")
    return sum(accuracies_list) / len(accuracies_list)


def main() -> None:
    """Run one trial of the Optuna study persisted in ``train.db`` and report the best.

    The study is created or re-loaded from the SQLite file, maximizes the
    objective, and stops after one trial or 600 seconds.
    """
    db_name = "train.db"
    study = optuna.create_study(
        study_name=db_name,
        storage=f"sqlite:///{db_name}",
        direction="maximize",
        load_if_exists=True,
    )
    study.optimize(objective, n_trials=1, timeout=600)

    best = study.best_trial
    print(f"Best hyperparameters: {best.params}")
    history_figure = plot_optimization_history(study)
    history_figure.show()

In [None]:
# Entry point: run the hyperparameter search defined in main().
main()