In [6]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import json
from tqdm.notebook import tqdm


In [7]:
class CocoCaptionDataset(Dataset):
    """Map-style dataset yielding one (image, caption) pair per caption annotation.

    The annotation file is a JSON document that must contain:
      * ``images``: a list of records with ``id`` and ``file_name`` keys
      * ``annotations``: a list of records with ``image_id`` and ``caption`` keys

    The dataset length is the number of annotation records, so an image that
    has several captions is returned once per caption.

    Args:
        img_folder: Directory that the ``file_name`` entries are joined onto.
        annotation_file: Path to the JSON annotation file described above.
        transform: Optional callable applied to the loaded RGB PIL image.
        vocab: Optional mapping from word to token id. When truthy it must
            contain the ``'<start>'``, ``'<end>'`` and ``'<unk>'`` keys.
    """

    def __init__(self, img_folder, annotation_file, transform=None, vocab=None):
        self.img_folder = img_folder
        self.transform = transform

        # JSON is UTF-8 by specification; do not depend on the platform default.
        with open(annotation_file, 'r', encoding='utf-8') as f:
            self.coco_data = json.load(f)

        # Image id -> file name lookup.
        self.img_id_to_filename = {
            img['id']: img['file_name'] for img in self.coco_data['images']
        }

        # Every caption record (each carries the id of the image it describes).
        self.captions = self.coco_data['annotations']

        # Vocabulary used to tokenize captions (None -> return raw strings).
        self.vocab = vocab

    def __len__(self):
        """Return the number of caption annotations."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Load the image and caption for annotation ``idx``.

        Returns:
            ``(image, caption_tokens)`` when a vocabulary is set, where
            ``caption_tokens`` is a 1-D tensor holding ``<start>``, one id per
            lower-cased whitespace-separated word (``<unk>`` for words missing
            from the vocabulary) and ``<end>``. Otherwise ``(image, caption)``
            with the raw caption string.

        Raises:
            KeyError: if the annotation's ``image_id`` has no entry in
                ``images``, or a required special token is absent from ``vocab``.
        """
        annotation = self.captions[idx]
        caption = annotation['caption']
        img_filename = self.img_id_to_filename[annotation['image_id']]
        img_path = os.path.join(self.img_folder, img_filename)

        # Image.open keeps the file open until the pixel data is read;
        # convert() loads it, and the with-block then closes the file handle
        # deterministically instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        if not self.vocab:
            return image, caption

        unk_id = self.vocab['<unk>']
        caption_tokens = (
            [self.vocab['<start>']]
            + [self.vocab.get(word, unk_id) for word in caption.lower().split()]
            + [self.vocab['<end>']]
        )
        return image, torch.tensor(caption_tokens)

    # The accessor was originally (mis)named ``__get_item__``, which Python
    # never calls for ``dataset[idx]``. Keep the old name as an alias so any
    # code that invoked it explicitly continues to work.
    __get_item__ = __getitem__