In [None]:
import os
import time
import torch
import random
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
from torchvision.transforms import Resize, ToTensor
from PIL import Image
from abc import ABC, abstractmethod
import torch.nn.functional as F

import torch
import torch.nn as nn

from transformers import CLIPModel, CLIPProcessor
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score

In [None]:
class Flickr8kDataset(Dataset):
    """Image/caption pairs described by a captions CSV file.

    The CSV is read with ``pandas.read_csv`` and must provide an ``image``
    column (file names relative to ``image_folder``) and a ``caption`` column.
    Every CSV row is one sample.

    Attributes:
        image_folder: Directory the file names from the CSV are joined onto.
        transform: Optional callable applied to the decoded RGB image.
        df: The captions table, with ``image`` rewritten to
            ``os.path.join(image_folder, <file name>)``.
    """

    def __init__(self, image_folder, captions_file, transform=None):
        """
        Args:
            image_folder: Directory containing the image files.
            captions_file: Path of the CSV with ``image`` and ``caption`` columns.
            transform: Optional callable applied to each PIL image.
        """
        self.image_folder = image_folder
        self.transform = transform
        self.df = pd.read_csv(captions_file)
        self.df['image'] = self.df['image'].apply(lambda x: os.path.join(self.image_folder, x))

    def __len__(self):
        """Return the number of (image, caption) rows."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for the row at position ``idx``.

        The image is decoded to RGB and passed through ``transform`` when one
        was supplied; otherwise the PIL image itself is returned.
        """
        row = self.df.iloc[idx]
        image_path, caption = row['image'], row['caption']
        # The context manager releases the file handle even when decoding fails;
        # convert() returns an independent in-memory copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        # Explicit None check: a falsy-but-valid callable must still be applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, caption

In [None]:
def get_metrics(targets: list, ranks: list) -> tuple[float, float, float, float, float]:
    """Aggregate per-sample hits and target ranks into summary metrics.

    F1, precision and recall are computed against a reference vector of all
    ones (every sample has a target that should have been picked). For 0/1
    ``targets`` this means recall coincides with accuracy, and precision is
    1.0 as soon as at least one sample is a hit.

    Args:
        targets (list): 1 where the target was predicted, 0 otherwise.
        ranks (list): 1-based rank of the target for every sample.

    Returns:
        tuple: ``(accuracy, f1, precision, recall, mrr)``.

    Raises:
        ValueError: If ``targets`` is empty (the metrics are undefined).
    """
    if len(targets) == 0:
        raise ValueError("get_metrics() needs at least one sample; `targets` is empty.")

    accuracy = sum(targets) / len(targets)

    true_targets = [1] * len(targets)

    # zero_division=0 keeps the value sklearn already returns when nothing was
    # predicted correctly (0.0) but without emitting UndefinedMetricWarning.
    f1 = f1_score(true_targets, targets, zero_division=0)
    prec = precision_score(true_targets, targets, zero_division=0)
    rec = recall_score(true_targets, targets, zero_division=0)

    # Mean reciprocal rank; ranks are expected to start at 1.
    mrr = np.mean([1 / rank for rank in ranks])

    return accuracy, f1, prec, rec, mrr


def seed_everything(seed: int) -> None:
    """Seed every random number generator used here and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), and exports ``PYTHONHASHSEED``. The environment variable
    is only read at interpreter start-up, so it affects processes launched
    afterwards rather than the running one.

    Args:
        seed (int): Seed value shared by all generators.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The autotuner may pick a different convolution algorithm on each run,
    # so it has to be off for reproducible results.
    torch.backends.cudnn.benchmark = False

In [None]:
class BaseModel(ABC, nn.Module):
    """
    Abstract interface shared by the models evaluated on the Visual-WSD dataset.

    Concrete subclasses are expected to populate both attributes and to
    implement the three abstract methods below.

    Attributes:
        model: the pretrained model (``None`` until a subclass sets it).
        processor: a single object wrapping the model's image processor and
            tokenizer (``None`` until a subclass sets it).
    """

    def __init__(self) -> None:
        super().__init__()
        self.model = self.processor = None

    @abstractmethod
    def process_image(self, images: torch.Tensor) -> torch.Tensor:
        """
        Prepare image input for the model.

        Args:
            images (torch.Tensor): A single image or several stacked images.

        Returns:
            torch.Tensor: The processed images.
        """

    @abstractmethod
    def process_text(self, texts: list[str]) -> torch.Tensor:
        """
        Prepare textual input for the model.

        Args:
            texts (list[str]): Textual content (descriptions of images).

        Returns:
            torch.Tensor: The processed text.
        """

    @abstractmethod
    def forward(self, images: torch.Tensor, texts: list[str]) -> torch.Tensor:
        """
        Run the model on paired visual and textual input.

        Implementations must place the logit of the target in the first
        position of every row.

        Args:
            images (torch.Tensor): Visual content.
            texts (list[str]): Textual content.

        Returns:
            torch.Tensor: Logits of size [batch_size, 10].
        """


In [None]:
class CLIPMODEL(BaseModel):
    """
    CLIP wrapper that scores candidate images against text.

    https://huggingface.co/docs/transformers/model_doc/clip

    Two input layouts are supported by ``forward``:

    * ``[batch, n_candidates, C, H, W]`` - every text comes with its own set of
      candidate images.
    * ``[batch, C, H, W]`` - one image per text (e.g. an image/caption
      dataset); the other images of the batch act as distractors.

    In both cases column 0 of a returned row is expected to hold the target.
    """

    def __init__(self, model_name):
        """
        Args:
            model_name: Checkpoint name or path understood by ``from_pretrained``.
        """
        super().__init__()
        self.model = CLIPModel.from_pretrained(model_name)
        # Rescaling is disabled in the processor, so callers must pass pixel
        # values that are already scaled (e.g. torchvision ``ToTensor`` output).
        self.processor = CLIPProcessor.from_pretrained(model_name, do_rescale=False)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)

    def process_image(self, images: torch.Tensor) -> torch.Tensor:
        """
        Run the CLIP image processor and move the result to ``self.device``.

        Args:
            images (torch.Tensor): One image or a stack of images.

        Returns:
            The processor output; ``pixel_values`` is cast to float32.
        """
        processed_images = self.processor(images=images, return_tensors="pt").to(self.device)
        # Item assignment updates the underlying mapping; attribute assignment
        # would only shadow it with an instance attribute.
        processed_images["pixel_values"] = processed_images["pixel_values"].float()
        return processed_images

    def process_text(self, texts: list[str]) -> torch.Tensor:
        """
        Tokenize the text(s) and move the result to ``self.device``.

        Args:
            texts (list[str]): A single string or a list of strings.

        Returns:
            The tokenizer output (padded to the longest text and truncated to
            the tokenizer's maximum length).
        """
        processed_texts = self.processor(
            text=texts, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        return processed_texts

    def _forward_paired(self, images: torch.Tensor, texts: list[str]) -> torch.Tensor:
        """Score every text against every image of a one-image-per-text batch.

        Row ``i`` of the result is rotated so that column ``k`` holds the score
        of text ``i`` against image ``(i + k) % batch``; the matching image is
        therefore always in column 0.
        """
        if len(texts) != images.shape[0]:
            raise ValueError(
                f"Got {images.shape[0]} images but {len(texts)} texts; "
                "a paired batch needs one text per image."
            )

        processed_images = self.process_image(images)
        processed_texts = self.process_text(list(texts))

        output = self.model(
            input_ids=processed_texts["input_ids"],
            attention_mask=processed_texts["attention_mask"],
            pixel_values=processed_images["pixel_values"],
            return_dict=True,
        )

        # scores[i, j]: similarity of text i and image j.
        scores = output.logits_per_text
        batch_size = scores.shape[0]
        positions = torch.arange(batch_size, device=scores.device)
        rotated_columns = (positions.unsqueeze(1) + positions.unsqueeze(0)) % batch_size
        return scores.gather(1, rotated_columns)

    def forward(self, images: torch.Tensor, texts: list[str]) -> torch.Tensor:
        """
        Compute image-text logits with the target in column 0.

        Args:
            images (torch.Tensor): Either ``[batch, n_candidates, C, H, W]``
                (candidate images per text) or ``[batch, C, H, W]`` (one
                image per text).
            texts (list[str]): One text per batch entry.

        Returns:
            torch.Tensor: ``[batch, n_candidates]`` logits for 5-D input, or
            ``[batch, batch]`` logits for 4-D input where the remaining
            columns are the other images of the batch.

        Raises:
            ValueError: If ``images`` is neither 4-D nor 5-D, or if a paired
                batch has a different number of images and texts.
        """
        images = images.to(self.device)

        if images.dim() == 4:
            # Without this branch the channel axis would be mistaken for the
            # candidate axis and every row would hold identical logits.
            return self._forward_paired(images, texts)
        if images.dim() != 5:
            raise ValueError(
                "images must have shape [batch, C, H, W] or "
                f"[batch, n_candidates, C, H, W], got {tuple(images.shape)}"
            )

        logits = torch.zeros(
            images.shape[0], images.shape[1], dtype=torch.float32, device=self.device
        )

        for idx, sample_images in enumerate(images):
            processed_sample_images = self.process_image(sample_images)
            processed_phrase = self.process_text(texts[idx])

            output = self.model(
                input_ids=processed_phrase["input_ids"],
                attention_mask=processed_phrase["attention_mask"],
                pixel_values=processed_sample_images["pixel_values"],
                return_dict=True,
            )

            # logits_per_image is [n_candidates, 1] for a single phrase.
            logits[idx] = output.logits_per_image.squeeze(1)
        return logits

In [None]:
def evaluate_model(
    model: torch.nn.Module, data_loader: DataLoader, return_details: bool = False
) -> dict[str, float | list]:
    """Evaluate a model whose forward pass puts the target logit in column 0.

    The model is switched to eval mode and run without gradients over every
    batch of ``data_loader``; each batch must unpack to ``(images, captions)``.

    Args:
        model: Callable as ``model(images, captions)`` returning a 2-D logits
            tensor with one row per sample.
        data_loader: Iterable of ``(images, captions)`` batches.
        return_details: When true, additionally return the per-sample
            ``"correct"``, ``"ranks"`` and ``"probs"`` lists.

    Returns:
        dict with ``accuracy``, ``f1``, ``precision``, ``recall``, ``mrr``,
        ``time`` (seconds for the whole call), ``phrases`` (input captions) and
        ``predictions`` (index of the top-scoring column per sample).
    """
    start_time = time.time()

    model.eval()

    predicted_images = []  # index of the top-scoring column per sample
    correct_preds = []  # 1 when column 0 (the target) was the top prediction
    all_target_ranks = []  # 1-based rank of the target per sample
    phrases = []  # input phrases for further analysis
    all_probs = []  # per-sample probabilities, only kept on request

    with torch.no_grad():
        for images, captions in tqdm(data_loader):
            phrases.extend(list(captions))

            logits = model(images, captions)
            probs = F.softmax(logits, dim=1)

            top_indices = probs.argmax(dim=1)
            # Rank = 1 + number of columns scoring strictly higher than the
            # target. Unlike a sort, this is deterministic under ties and
            # agrees with the top-1 decision (rank 1 <=> counted as correct).
            target_ranks = (probs > probs[:, :1]).sum(dim=1) + 1

            # One bulk transfer per batch instead of one .item() per sample.
            predicted_images.extend(top_indices.tolist())
            correct_preds.extend((top_indices == 0).long().tolist())
            all_target_ranks.extend(target_ranks.tolist())
            if return_details:
                all_probs.extend(probs.tolist())

    accuracy, f1, precision, recall, mrr = get_metrics(correct_preds, all_target_ranks)

    results = {
        "accuracy": accuracy,
        "f1": f1,
        "precision": precision,
        "recall": recall,
        "mrr": mrr,
        "time": time.time() - start_time,
        "phrases": phrases,
        "predictions": predicted_images,
    }
    if return_details:
        results["correct"] = correct_preds
        results["ranks"] = all_target_ranks
        results["probs"] = all_probs
    return results


In [None]:
# Checkpoint identifier handed to CLIPModel/CLIPProcessor.from_pretrained.
model_name = "openai/clip-vit-base-patch32"
# The wrapper loads the checkpoint and moves the model to CUDA when available.
model = CLIPMODEL(model_name=model_name)

In [None]:
# Resize every image to 224x224 and convert it to a float tensor.
transform = transforms.Compose([
    Resize((224, 224)),
    ToTensor(),
])

image_folder = '../data/flickr8k/Images/'
captions_file = '../data/flickr8k/captions.txt'
flickr_dataset = Flickr8kDataset(image_folder, captions_file, transform=transform)

# A dedicated, seeded generator makes the shuffled batch order identical after
# every "Restart Kernel -> Run All", so per-sample outputs are reproducible.
shuffle_generator = torch.Generator().manual_seed(42)
data_loader = DataLoader(
    flickr_dataset, batch_size=32, shuffle=True, generator=shuffle_generator
)

In [None]:
# Full evaluation pass; `res` holds the summary metrics plus per-sample phrases and predictions.
res = evaluate_model(model, data_loader)

In [None]:
# Display only the scalar metrics; the per-sample lists ("phrases",
# "predictions") remain available in `res` but would flood the cell output.
{name: value for name, value in res.items() if not isinstance(value, list)}