In [2]:
#import requests
from PIL import Image, ImageOps
from transformers import BlipProcessor, BlipForConditionalGeneration
import io
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

In [10]:
# Hub id of the fine-tuned BLIP checkpoint; the processor and the model are
# both loaded from it.
MODEL_ID = "DinoDave/BLIP_finetuned_spatial_relations"

# Use the GPU when torch reports one as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

processor = BlipProcessor.from_pretrained(MODEL_ID)

# Load the weights first, then move the module onto the selected device.
model = BlipForConditionalGeneration.from_pretrained(MODEL_ID)
model = model.to(device)

print("Sucessfully loaded processor and model")

Sucessfully loaded processor and model


In [3]:
# Load the masked tennis spatial-relations dataset from the Hugging Face Hub.
dataset = load_dataset("DinoDave/SpatialRelationsTennis_masked")

# Hold out 10% of the 'train' split for evaluation. The fixed seed pins the
# shuffle, so the same examples end up in the test set on every
# Restart & Run All instead of a fresh random selection each time.
train_test_split = dataset['train'].train_test_split(test_size=0.1, seed=42)

# Separate train and test sets
train_dataset_raw = train_test_split['train']
test_dataset_raw = train_test_split['test']

print("Number of training examples:", len(train_dataset_raw))
print("Number of testing examples:", len(test_dataset_raw))

Downloading readme: 100%|██████████| 24.0/24.0 [00:00<00:00, 93.2kB/s]
Downloading data: 100%|██████████| 267/267 [00:11<00:00, 22.44files/s]
Generating train split: 100%|██████████| 266/266 [00:00<00:00, 3722.34 examples/s]

Number of training examples: 239
Number of testing examples: 27





In [4]:
class ImageCaptioningDataset(Dataset):
    """Torch dataset that turns raw image/caption rows into processor encodings.

    Args:
        dataset: Indexable, sized collection whose items expose an ``"image"``
            entry (an object with a PIL-style ``resize`` method when
            ``resize_to`` is set) and a ``"text"`` entry.
        processor: Callable invoked as ``processor(images=..., text=...,
            padding="max_length", return_tensors="pt")`` that returns a mapping
            of tensors carrying a leading batch dimension.
        resize_to: ``(width, height)`` each image is resized to before it is
            handed to the processor, or a falsy value to skip resizing.
    """

    def __init__(self, dataset, processor, resize_to=(640, 640)):
        self.dataset = dataset
        self.processor = processor
        self.resize_to = resize_to

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the processor encoding for row ``idx`` without its batch dimension.

        Returns:
            dict mapping each key produced by the processor to its tensor with
            the leading batch dimension removed.
        """
        item = self.dataset[idx]
        image = item["image"]

        # Resize the image (same resampling filter spelling as the plotting cell).
        if self.resize_to:
            image = image.resize(self.resize_to, Image.Resampling.LANCZOS)

        encoding = self.processor(
            images=image,
            text=item["text"],
            padding="max_length",
            return_tensors="pt",
        )

        # Remove only the leading batch dimension. A bare squeeze() would also
        # drop every other size-1 axis (e.g. a length-1 sequence), which breaks
        # the default collate when items are stacked into batches.
        return {k: v.squeeze(0) for k, v in encoding.items()}

In [11]:
# Wrap the raw splits so every item is passed through the BLIP processor on access.
train_dataset = ImageCaptioningDataset(train_dataset_raw, processor)
test_dataset = ImageCaptioningDataset(test_dataset_raw, processor)

# Shuffle the training batches only. The evaluation loader keeps the dataset
# order so its batches are deterministic and line up with test_dataset_raw.
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)

In [12]:
# Caption every held-out image with the model and show the results stacked in
# a single column (one image per row, caption as the title).
num_examples = len(test_dataset)

# squeeze=False makes plt.subplots always return a 2-D array of axes, so the
# loop below also works when there is exactly one test example (otherwise a
# bare Axes object is returned and cannot be iterated).
fig, axes = plt.subplots(num_examples, 1, figsize=(5, 5 * num_examples), squeeze=False)

# `idx` instead of `id` so the builtin id() is not shadowed in the kernel namespace.
for idx, ax in enumerate(axes.ravel()):
    image_raw = test_dataset_raw[idx]["image"]
    image = image_raw.resize((640, 640), Image.Resampling.LANCZOS)

    # Image-only input: no text is passed to the processor here.
    inputs = processor(image, return_tensors="pt").to(device)

    out = model.generate(**inputs, max_length=50)

    ax.imshow(image)
    ax.set_title(processor.decode(out[0], skip_special_tokens=True))  # Generated caption as the title
    ax.axis('off')  # Hide the axes

plt.tight_layout()
plt.show()

: 