In [None]:
#import requests
from PIL import Image, ImageOps
from transformers import BlipProcessor, BlipForConditionalGeneration
import io
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

In [None]:
import json
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.spice.spice import Spice

In [None]:
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Single source of truth for the checkpoint, so the processor and the model
# cannot accidentally be loaded from two different repositories.
MODEL_ID = "DinoDave/BLIP_finetuned_spatial_relations"

processor = BlipProcessor.from_pretrained(MODEL_ID)
model = BlipForConditionalGeneration.from_pretrained(MODEL_ID).to(device)

print("Successfully loaded processor and model")

In [None]:
# Only the 'train' split of the hub dataset is used; it is re-split locally below.
dataset = load_dataset("DinoDave/SpatialRelationsTennis_masked")

# Hold out 10% for evaluation. The seed is fixed so that the same examples land
# in the test set on every "Restart & Run All"; without it the reported metrics
# would be computed on a different random subset each run.
# NOTE(review): this split is independent of whatever split was used to
# fine-tune the checkpoint -- confirm the held-out examples were not trained on.
train_test_split = dataset['train'].train_test_split(test_size=0.1, seed=42)

# Separate train and test sets
train_dataset_raw = train_test_split['train']
test_dataset_raw = train_test_split['test']

print("Number of training examples:", len(train_dataset_raw))
print("Number of testing examples:", len(test_dataset_raw))

In [None]:
class ImageCaptioningDataset(Dataset):
    """Torch ``Dataset`` that converts raw image/caption records into processor encodings.

    Args:
        dataset: Indexable, sized collection whose items provide an ``"image"``
            entry and a ``"text"`` entry. The image must support ``.resize``
            whenever ``resize_to`` is set.
        processor: Callable invoked as
            ``processor(images=..., text=..., padding="max_length", return_tensors="pt")``
            that returns a mapping of tensors carrying a leading batch dimension.
        resize_to: Size handed to ``image.resize`` before encoding. Pass a falsy
            value (e.g. ``None``) to keep the image at its original size.
    """

    def __init__(self, dataset, processor, resize_to=(640, 640)):
        self.dataset = dataset
        self.processor = processor
        self.resize_to = resize_to

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Encode the record at ``idx`` and return a dict of un-batched tensors."""
        item = self.dataset[idx]
        image = item["image"]

        # Resize the image (same resampling filter as the inference cell uses).
        if self.resize_to:
            image = image.resize(self.resize_to, Image.Resampling.LANCZOS)

        encoding = self.processor(images=image, text=item["text"], padding="max_length", return_tensors="pt")
        # Remove ONLY the leading batch dimension. A bare ``squeeze()`` would also
        # collapse any other size-1 axis (e.g. a length-1 token sequence), which
        # silently changes the tensor rank and breaks default batch collation.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        return encoding

In [None]:
# Wrap the raw splits so that each item is returned as processor-encoded tensors.
train_dataset = ImageCaptioningDataset(train_dataset_raw, processor)
test_dataset = ImageCaptioningDataset(test_dataset_raw, processor)

# Training batches are shuffled; evaluation batches keep the dataset order so
# that test results are deterministic and can be matched back to their examples.
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)

In [None]:
# Parallel lists: generated_captions[i] is the model output for the example whose
# reference caption(s) are stored in reference_captions[i].
generated_captions = []
reference_captions = []

num_examples = len(test_dataset_raw)

# Create a figure with one row per test example. squeeze=False keeps `axes` a
# 2-D array even when there is a single example; otherwise matplotlib returns a
# bare Axes object and iterating over it below would fail.
fig, axes = plt.subplots(num_examples, 1, figsize=(5, 5 * num_examples), squeeze=False)

for ax, example_idx in zip(axes.ravel(), range(num_examples)):
    # Index the record once and reuse it for both the image and the reference text.
    example = test_dataset_raw[example_idx]

    image = example["image"].resize((640, 640), Image.Resampling.LANCZOS)
    inputs = processor(image, return_tensors="pt").to(device)

    out = model.generate(**inputs, max_length=50)

    caption = processor.decode(out[0], skip_special_tokens=True)

    # Each reference is wrapped in a list: one list of references per example.
    reference_captions.append([str(example["text"])])
    generated_captions.append(str(caption))

    ax.imshow(image)
    ax.set_title(caption)  # Display the generated caption as the title
    ax.axis('off')  # Hide the axes

plt.tight_layout()
plt.show()

In [None]:
# Both dictionaries are keyed by the example's position, so every hypothesis
# lines up with its own list of reference captions.
gts = dict(enumerate(reference_captions))
res = {idx: [hypothesis] for idx, hypothesis in enumerate(generated_captions)}

# (scorer instance, label used when printing) pairs, evaluated in this order.
scorers = [
    (Bleu(4), "BLEU"),
    (Meteor(), "METEOR"),
    (Rouge(), "ROUGE"),
    (Cider(), "CIDEr"),
    (Spice(), "SPICE"),
]

# Compute and print scores
for scorer, method in scorers:
    score, scores = scorer.compute_score(gts, res)
    # A scorer may return a list of values (one per sub-metric) or a single number.
    score_str = (
        ", ".join(f"{s:.4f}" for s in score)
        if isinstance(score, list)
        else f"{score:.4f}"
    )
    print(f"{method}: {score_str}")