In [None]:
import os
from typing import Literal, Optional

import random
import pandas as pd
import PIL
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from torchvision.transforms import Compose

PIL.Image.MAX_IMAGE_PIXELS = 1000000000


class Flickr30kDataset(Dataset):
    """
    Caption-to-image retrieval dataset built from a Flickr30k-style CSV file.

    The CSV is read with ``|`` as separator and must provide the columns
    ``image_name``, ``comment_number`` and ``comment`` (column names are
    stripped of surrounding whitespace after loading). Only rows whose
    ``comment_number`` equals ``' 0'`` are kept, and captions with more than
    50 whitespace-separated words are dropped.

    Each item consists of one caption, the image of the same row (the target)
    and ``num_candidates`` images drawn at random from *other* rows.

    Args:
        path: Root directory containing the CSV file and the images folder.
        csv_file: File name of the CSV, relative to ``path``.
        images_folder: Folder with the image files, relative to ``path``.
        transform: Optional callable applied to every loaded PIL image.
            ``__getitem__`` stacks the results with ``torch.stack``, so in
            practice the transform has to return tensors of equal shape.
        num_candidates: Number of distractor images returned per item.

    Raises:
        ValueError: If ``num_candidates`` is smaller than 1.
    """

    def __init__(
        self,
        path: str,
        csv_file: str,
        images_folder: str,
        transform: Optional[Compose] = None,
        num_candidates: int = 9,
    ) -> None:
        if num_candidates < 1:
            raise ValueError("num_candidates must be at least 1")

        self.path = path
        self.images_folder = images_folder
        self.transform = transform
        self.num_candidates = num_candidates

        df = pd.read_csv(os.path.join(path, csv_file), on_bad_lines='skip', sep='|')
        # Header cells carry padding around the separator; normalise them.
        df.columns = [column.strip() for column in df.columns]
        # Values keep their leading space, hence the ' 0' literal.
        df = df[df['comment_number'] == ' 0']
        df = df.drop('comment_number', axis=1)
        df = df[df['comment'].apply(lambda x: len(x.split())) <= 50]
        self.df = df

    def __len__(self) -> int:
        """Return the number of rows left after filtering."""
        return len(self.df)

    def _load_image(self, image_name: str):
        """Open ``image_name`` as RGB and apply the optional transform."""
        image_path = os.path.join(self.path, self.images_folder, image_name)
        # The context manager releases the file handle deterministically;
        # convert() returns an independent, fully loaded image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

    def _sample_other_index(self, idx: int) -> int:
        """Draw a row position uniformly from all rows except ``idx``."""
        num_rows = len(self.df)
        if num_rows < 2:
            raise ValueError(
                "at least two rows are required to sample a distractor image"
            )
        position = idx % num_rows  # map negative indices accepted by iloc
        # Sample from the num_rows - 1 remaining slots and skip over `position`.
        drawn = random.randrange(num_rows - 1)
        return drawn + 1 if drawn >= position else drawn

    def __getitem__(self, idx: int) -> dict:
        """
        Build one retrieval sample.

        Args:
            idx: Positional index into the filtered rows.

        Returns:
            dict: ``"context"`` (the caption string), ``"target"`` (the
            transformed image of row ``idx``) and ``"candidate_images"``
            (``num_candidates`` transformed images from other rows, stacked
            along a new first dimension).

        Raises:
            ValueError: If the dataset has fewer than two rows, so no
                distractor can be drawn.
        """
        row = self.df.iloc[idx]
        target_image = self._load_image(row["image_name"])

        candidate_images = torch.stack(
            [
                self._load_image(
                    self.df.iloc[self._sample_other_index(idx)]["image_name"]
                )
                for _ in range(self.num_candidates)
            ]
        )

        return {
            "context": row["comment"],
            "target": torch.Tensor(target_image),
            "candidate_images": candidate_images,
        }


In [None]:
import os
import random
from typing import Literal

import numpy as np
import torch
from sklearn.metrics import f1_score, precision_score, recall_score
from torch.utils.data import DataLoader
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Resize,
    ToTensor,
)

# Default image preprocessing used by get_loaders(): resize with bicubic
# interpolation (size 224), center-crop to 224, then convert to a tensor.
transform = Compose(
    [
        Resize(224, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(224),
        ToTensor(),
    ]
)


def get_loaders(
    path: str,
    csv_file: str,
    images_folder: str,
    transform: Compose = transform,
    batch_size: int = 1,
    num_workers: int = 0,
    shuffle: bool = True,
) -> DataLoader:
    """
    Build the evaluation ``DataLoader`` over a ``Flickr30kDataset``.

    Args:
        path: Root directory containing the CSV file and the images folder.
        csv_file: CSV file name, relative to ``path``.
        images_folder: Images folder name, relative to ``path``.
        transform: Transform applied to every image by the dataset.
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes used by the loader.
        shuffle: Whether the loader reshuffles the samples on every pass.
            The flag is forwarded to ``DataLoader``; pass ``shuffle=False``
            for a fixed iteration order.

    Returns:
        DataLoader: Loader yielding batches of dataset samples.
    """
    eval_dataset = Flickr30kDataset(
        path=path,
        csv_file=csv_file,
        images_folder=images_folder,
        transform=transform,
    )
    return DataLoader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


def get_metrics(
    targets: list, ranks: list
) -> tuple[float, float, float, float, float]:
    """
    Compute retrieval metrics from per-sample outcomes.

    Args:
        targets: One entry per sample, ``1`` if the target was predicted
            correctly and ``0`` otherwise.
        ranks: One entry per sample, the 1-based rank of the target.

    Returns:
        tuple: ``(accuracy, f1, precision, recall, mrr)``. F1, precision and
        recall are computed against an all-ones ground truth, and ``mrr`` is
        the mean of the reciprocal ranks.

    Raises:
        ValueError: If ``targets`` is empty.
    """
    if len(targets) == 0:
        raise ValueError("get_metrics() requires at least one prediction")

    accuracy = sum(targets) / len(targets)

    # Every sample has a correct answer, so the reference labels are all 1.
    true_targets = [1] * len(targets)

    # zero_division=0 keeps the value sklearn falls back to when nothing was
    # predicted correctly, without emitting UndefinedMetricWarning.
    f1 = f1_score(true_targets, targets, zero_division=0)
    prec = precision_score(true_targets, targets, zero_division=0)
    rec = recall_score(true_targets, targets)

    mrr = np.mean([1 / rank for rank in ranks])

    return accuracy, f1, prec, rec, mrr

In [None]:
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class BaseModel(ABC, nn.Module):
    """
    Abstract interface shared by the models evaluated on the Visual-WSD dataset.

    Concrete subclasses are expected to fill in the two attributes below and
    to implement every abstract method.

    Attributes:
        model: the pretrained model (``None`` until a subclass sets it).
        processor: a single processor wrapping the model's image processor
            and tokenizer (``None`` until a subclass sets it).
    """

    def __init__(self) -> None:
        super().__init__()
        self.processor = None
        self.model = None

    @abstractmethod
    def process_image(self, images: torch.Tensor) -> torch.Tensor:
        """
        Prepare visual input for the model.

        Args:
            images (torch.Tensor): either a single image or several images
                stacked into one tensor.

        Returns:
            torch.Tensor: the processed images.
        """

    @abstractmethod
    def process_text(self, texts: list[str]) -> torch.Tensor:
        """
        Prepare textual input for the model.

        Args:
            texts (list[str]): image descriptions.

        Returns:
            torch.Tensor: the processed text.
        """

    @abstractmethod
    def forward(self, images: torch.Tensor, texts: list[str]) -> torch.Tensor:
        """
        Run the model on paired visual and textual input.

        Implementations must score every image against its text and place
        the logit of the target image first.

        Args:
            images (torch.Tensor): visual content.
            texts (list[str]): textual content.

        Returns:
            torch.Tensor: logits of size [batch_size, 10].
        """

In [None]:
import torch
from PIL import Image
from torchvision import transforms
from transformers import AlignProcessor, AlignModel


class ALIGNMODEL(BaseModel):
    """
    ALIGN wrapper that scores a set of images against a text.

    https://huggingface.co/docs/transformers/model_doc/align

    Attributes:
        model: ``AlignModel`` loaded from ``model_name`` and moved to ``device``.
        processor: ``AlignProcessor`` loaded from ``model_name``.
        device: ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
    """

    def __init__(self, model_name):
        """
        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``.
        """
        super().__init__()
        self.model = AlignModel.from_pretrained(model_name)
        self.processor = AlignProcessor.from_pretrained(model_name)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)

    def process_image(self, images: torch.Tensor) -> torch.Tensor:
        """
        Run the processor's image pipeline and move the result to ``device``.

        Args:
            images: One image or a stack of images.

        Returns:
            The processor output; ``forward`` reads ``pixel_values`` from it.
        """
        # Floating-point tensors are assumed to be already scaled to [0, 1]
        # (e.g. by torchvision's ToTensor); rescaling them a second time
        # would squash the pixel values towards zero.
        already_scaled = torch.is_tensor(images) and images.is_floating_point()
        processed_images = self.processor(
            images=images, return_tensors="pt", do_rescale=not already_scaled
        ).to(self.device)
        return processed_images

    def process_text(self, texts: list[str]) -> torch.Tensor:
        """
        Tokenize ``texts`` with the processor and move the result to ``device``.

        Args:
            texts: Text to tokenize.

        Returns:
            The processor output; ``forward`` reads ``input_ids`` and
            ``attention_mask`` from it.
        """
        processed_texts = self.processor(text=texts, return_tensors="pt").to(
            self.device
        )
        return processed_texts

    def forward(self, images: torch.Tensor, texts: list[str]) -> torch.Tensor:
        """
        Score every image of each sample against that sample's text.

        Args:
            images: Tensor whose first dimension indexes samples and whose
                second dimension indexes the images of a sample.
            texts: One text per sample.

        Returns:
            torch.Tensor: Logits of shape ``[images.shape[0], images.shape[1]]``.
        """
        images = images.to(self.device)
        logits = torch.zeros(images.shape[0], images.shape[1])

        for idx, sample_images in enumerate(images):
            processed_sample_images = self.process_image(sample_images)
            processed_phrase = self.process_text(texts[idx])

            output = self.model(
                input_ids=processed_phrase.input_ids,
                # Without the mask the padded positions produced by the
                # tokenizer would be attended to like real tokens.
                attention_mask=processed_phrase.get("attention_mask"),
                pixel_values=processed_sample_images.pixel_values,
                return_dict=True,
            )
            # logits_per_image has one column per text; only one text is given.
            logits[idx] = output.logits_per_image.squeeze(1)

        return logits

In [None]:
import time

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

# from utils import get_metrics


def evaluate_model(
    model: torch.nn.Module, data_loader: DataLoader
) -> dict[str, float | list]:
    """
    Evaluate a retrieval model on every batch of ``data_loader``.

    Each batch must provide ``"context"`` (texts), ``"target"`` (one image
    per sample) and ``"candidate_images"`` (distractor images per sample).
    The target is placed at position 0 in front of the candidates before the
    model is called, so a prediction is correct when no other image scores
    strictly higher than position 0.

    Args:
        model: Model called as ``model(images, texts)`` and returning one
            row of logits per sample.
        data_loader: Loader yielding the batches described above.

    Returns:
        dict: ``accuracy``, ``f1``, ``precision``, ``recall`` and ``mrr`` as
        computed by ``get_metrics``, plus ``time``, the elapsed wall-clock
        seconds.
    """
    start_time = time.time()
    model.eval()

    correct_preds = []  # 1 if the target was ranked first, else 0
    all_target_ranks = []  # 1-based rank of the target for every sample

    with torch.no_grad():
        for batch in tqdm(data_loader, total=len(data_loader)):
            texts = batch["context"]

            target, candidate_images = batch["target"], batch["candidate_images"]
            # Target goes first so that index 0 is always the correct answer.
            images = torch.cat([target.unsqueeze(1), candidate_images], dim=1)

            logits = model(images, texts)
            probs = F.softmax(logits, dim=1)

            # Rank = 1 + number of images scoring strictly higher than the
            # target. Unlike a sort-based lookup this is deterministic when
            # scores tie, and it always agrees with the accuracy flag.
            ranks = (probs > probs[:, :1]).sum(dim=1) + 1

            all_target_ranks.extend(ranks.tolist())
            correct_preds.extend((ranks == 1).long().tolist())

    accuracy, f1, precision, recall, mrr = get_metrics(correct_preds, all_target_ranks)

    return {
        "accuracy": accuracy,
        "f1": f1,
        "precision": precision,
        "recall": recall,
        "mrr": mrr,
        "time": time.time() - start_time,
    }

In [None]:
# Load the pretrained ALIGN checkpoint through the ALIGNMODEL wrapper.
model_name = "kakaobrain/align-base"
model = ALIGNMODEL(model_name=model_name)

# Show which device the wrapper selected ("cuda" if available, else "cpu").
print(model.device)

In [None]:
# Build the evaluation loader over the Flickr30k captions CSV and images,
# using the module-level preprocessing pipeline.
loader = get_loaders(
    path = '/kaggle/input/flickr-image-dataset/flickr30k_images',
    csv_file = 'results.csv',
    images_folder = 'flickr30k_images',
    transform = transform,
    batch_size = 128,
    num_workers = 2
)

In [None]:
# Run the evaluation over the whole loader and keep the metrics dictionary.
res = evaluate_model(model, loader)

In [None]:
# Display the metrics returned by evaluate_model.
res