In [None]:
# Start from a clean checkout: move to /content/, delete any previous copy of
# the repository, then clone it again from GitHub.
%cd /content/
!rm -rf DLProject/
!git clone https://github.com/ManuelaCorte/DLProject.git

In [None]:
# Work from the root of the freshly cloned repository; later cells use paths
# relative to this directory (data/raw/, src/).
%cd /content/DLProject/

In [None]:
# Install the third-party packages into the kernel's environment; CLIP is
# installed directly from its GitHub repository.
# NOTE(review): versions are not pinned, and `gdown` (imported in a later cell)
# is not installed here — presumably it is preinstalled in the runtime; confirm.
%pip install ftfy regex tqdm ultralytics
%pip install git+https://github.com/openai/CLIP.git

In [None]:
import sys

# Absolute location of the package sources inside the clone at /content/DLProject/.
# A relative entry such as "DLProject/src/" is resolved against the current
# working directory, which is already /content/DLProject/ when this cell runs,
# so it would point at a nested DLProject/DLProject/src/ folder instead.
SRC_DIR = "/content/DLProject/src"

# Guard against piling up duplicate entries when the cell is re-run.
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [None]:
import os

import gdown

# Download dataset and save under data/raw/ only if not already downloaded
url = "https://drive.google.com/uc?id=1xijq32XfEm6FPhUb7RsZYWHc2UuwVkiq"
if not os.path.exists("data/raw/refcocog.tar.gz"):
    print("Downloading dataset...")
    gdown.download(url=url, output="data/raw/", quiet=False, resume=True)
if not os.path.exists("data/raw/refcocog/"):
    print("Extracting dataset...")
    !tar -xf data/raw/refcocog.tar.gz -C data/raw/ --verbose

In [None]:
# Switch into src/ before the next cell imports the `vgproject` package.
# NOTE(review): presumably the relative paths inside Config also expect this
# working directory — confirm.
%cd src/

In [None]:
from typing import Any

import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.ops import box_iou
from tqdm import tqdm
from vgproject.data.dataset import VGDataset
from vgproject.models.baseline import Baseline
from vgproject.utils.config import Config
from vgproject.utils.data_types import BboxType, Split
from vgproject.utils.misc import custom_collate

# Evaluate the baseline on the test split and report the mean IoU between the
# predicted and the ground-truth bounding boxes.
cfg = Config()
test_data = VGDataset(
    dir_path=cfg.dataset_path, split=Split.TEST, output_bbox_type=BboxType.XYXY
)

# Evaluation must see every test sample, so the final (possibly smaller) batch
# is kept: drop_last=True would silently exclude it from the reported metric.
dataloader: DataLoader[Any] = DataLoader(
    test_data,
    batch_size=cfg.train.batch_size,
    shuffle=False,
    collate_fn=custom_collate,
    drop_last=False,
)

baseline = Baseline()

batches_acc = []  # mean IoU of each batch, kept for inspection
samples_iou = []  # IoU of every individual sample, one 1-D tensor per batch

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for batch, bboxes in tqdm(dataloader):
        prediction = baseline.predict(batch)
        bbox_pred = torch.stack([p.bounding_box for p in prediction]).to(baseline.device)
        # assumes the ground-truth boxes arrive with a singleton dim 1 — TODO confirm
        bbox_gt = bboxes.detach().squeeze(1).to(baseline.device)
        # box_iou returns the pairwise matrix between both sets of boxes; the
        # diagonal pairs each prediction with its own ground-truth box.
        iou: Tensor = torch.diagonal(box_iou(bbox_pred, bbox_gt))
        samples_iou.append(iou)
        batches_acc.append(iou.mean())

if not samples_iou:
    raise RuntimeError("The test dataloader produced no batches; cannot compute IoU.")

# Average over samples rather than over batches, so a smaller last batch is
# not over-weighted in the final figure.
accuracy: float = torch.cat(samples_iou).mean().cpu().item()
print("Iou: ", accuracy)