In [2]:
import torch
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from PIL import Image
import os

In [17]:

class Flickr8kDataset(Dataset):
    def __init__(self, image_dir, captions_file, transform=None):
        self.image_dir = image_dir
        self.captions_file = captions_file
        self.transform = transform
        self.annotations = self.load_annotations()

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.annotations[idx][0])
        image = Image.open(img_name).convert("RGB")
        if self.transform:
            image = self.transform(image)
        caption = self.annotations[idx][1]
        return image, caption

    def load_annotations(self):
        annotations = []
        with open(self.captions_file, "r") as f:
            for line in f:
                parts = line.strip().split(",")
                if len(parts) == 2:
                    image_name, caption = parts
                    annotations.append((image_name, caption))
        return annotations


image_dir = r"archive\Images\\"
captions_file = r"archive\captions.txt"

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


flickr8k_dataset = Flickr8kDataset(image_dir=image_dir, captions_file=captions_file, transform=transform)

print(len(flickr8k_dataset))


flickr8k_loader = DataLoader(flickr8k_dataset, batch_size=32, shuffle=True)

for images, captions in flickr8k_loader:
    print("Batch Images shape", images.shape)
    print("Batch captions shape:", len(captions))
    break  

38009
Batch Images shape torch.Size([32, 3, 224, 224])
Batch captions shape: 32
