In [1]:
import os
from PIL import Image
from transformers import GPT2TokenizerFast
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

# Module-level GPT-2 tokenizer used to encode captions and to pad batches.
# The end-of-sequence token is reused as the padding token.
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

class ImageCaptionDataset(Dataset):
    """
    Dataset class to load image-caption pairs.

    Each row of ``data_frame`` must provide an ``'image'`` entry (a path
    relative to ``image_directory``) and a ``'caption'`` entry. Items are
    returned as ``(image, input_ids, labels)``:

    * ``image`` is the RGB image passed through ``transform``;
    * ``input_ids`` is a 1-D tensor of caption token ids, with
      ``<|endoftext|>`` appended to the caption text before encoding;
    * ``labels`` is ``input_ids`` shifted one position to the left, with the
      final position set to -100 (ignored by the loss).

    Args:
        data_frame: Table of samples, accessed positionally via ``.iloc``.
        image_directory: Directory that the ``'image'`` paths are relative to.
        transform: Callable applied to the RGB PIL image.
        tokenizer: Optional tokenizer exposing ``encode(text, truncation=...)``.
            When ``None`` (the default), the module-level ``tokenizer`` is
            looked up at item-access time, which matches the previous behavior.
    """

    def __init__(self, data_frame, image_directory, transform, tokenizer=None):
        self.data_frame = data_frame
        self.image_directory = image_directory
        self.transform = transform
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data_frame)

    @staticmethod
    def _next_token_labels(input_ids):
        """Return next-token targets for a 1-D tensor of token ids.

        Position ``i`` holds ``input_ids[i + 1]``. The last position has no
        successor, so it is set to -100, the ignore index for loss computation.
        ``input_ids`` itself is not modified.

        NOTE(review): Hugging Face causal-LM heads (e.g. GPT2LMHeadModel)
        shift ``labels`` internally. If the downstream model does that, this
        pre-shift is applied twice — confirm against the training loop.
        """
        labels = input_ids.clone()
        labels[:-1] = input_ids[1:]
        labels[-1] = -100
        return labels

    def __getitem__(self, idx):
        sample = self.data_frame.iloc[idx]

        # Open the image inside a context manager so the file handle is
        # released deterministically. convert() loads the pixel data and
        # returns a new image, so it stays usable after the file is closed.
        image_path = os.path.join(self.image_directory, sample['image'])
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image = self.transform(image)

        # Process caption with end-of-text token; fall back to the
        # module-level tokenizer when none was injected.
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        caption = f"{sample['caption']}<|endoftext|>"
        input_ids = torch.tensor(tok.encode(caption, truncation=True))
        labels = self._next_token_labels(input_ids)
        return image, input_ids, labels

def collate_fn(batch):
    """
    Merge a list of ``(image, input_ids, labels)`` samples into one batch.

    Images are stacked along a new leading (batch) dimension. The two token
    sequences are right-padded to the longest sequence in the batch:
    ``input_ids`` with the tokenizer's pad token id, and ``labels`` with -100
    so that padded positions are ignored by the loss.
    """
    pad = torch.nn.utils.rnn.pad_sequence
    image_group, id_group, label_group = zip(*batch)

    stacked_images = torch.stack(image_group, dim=0)
    padded_ids = pad(id_group, batch_first=True,
                     padding_value=tokenizer.pad_token_id)
    padded_labels = pad(label_group, batch_first=True, padding_value=-100)

    return stacked_images, padded_ids, padded_labels

# Define data augmentation and normalization transformations.
# Order matters: the flip and resize steps work on the PIL image, and ToTensor
# must run before Normalize, because Normalize only accepts tensors (it raises
# TypeError when given a PIL image).
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize((224, 224)),  # Resizing for ViT input requirements
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # (x - 0.5) / 0.5 -> [-1, 1]
])
valid_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])