# 1. Load Data

In [43]:
import math
import os
import random

from PIL import Image
from pycocotools.coco import COCO
from torch.utils.data import Dataset, DataLoader


In [44]:
# Tunable settings for the notebook, collected in one place.
hyperparameters = dict(
    batch_size=16,       # passed as batch_size to every DataLoader built below
    embedding_dim=256,   # presumably the embedding width of the model -- TODO confirm
    lstm_out_dim=512,    # presumably the LSTM output width -- TODO confirm
    epochs=150,          # presumably the number of training passes -- TODO confirm
    learning_rate=1e-3,  # presumably the optimizer step size -- TODO confirm
)

In [45]:
# Root directory of the COCO dataset; the notebook reads `annotations/` and
# `train2017/` beneath it. Defaults to the shared cluster location, but can be
# redirected by setting the COCO_BASE_PATH environment variable before the
# kernel starts. os.path.join(..., '') guarantees a trailing separator, which
# the `BASE_PATH + '...'` string concatenations in later cells rely on.
BASE_PATH = os.path.join(
    os.environ.get('COCO_BASE_PATH') or '/scratch/lt2316-h18-resources/coco/',
    '',
)

In [3]:
# Load both train2017 annotation indexes: captions first, then instances.
# The generator is consumed by the tuple unpacking, so the files are parsed in
# the order listed and the loop variable does not leak into the notebook namespace.
coco_captions, coco_instances = (
    COCO(BASE_PATH + annotation_file)
    for annotation_file in (
        'annotations/captions_train2017.json',
        'annotations/instances_train2017.json',
    )
)

loading annotations into memory...
Done (t=1.18s)
creating index...
index created!
loading annotations into memory...
Done (t=21.90s)
creating index...
index created!


In [42]:
# Peek at the first image record to see which metadata fields are available
# (the `file_name` and `id` fields are the ones used when building samples).
list(coco_captions.imgs.values())[:1]

[{'license': 3,
  'file_name': '000000391895.jpg',
  'coco_url': 'http://images.cocodataset.org/train2017/000000391895.jpg',
  'height': 360,
  'width': 640,
  'date_captured': '2013-11-14 11:18:45',
  'flickr_url': 'http://farm9.staticflickr.com/8186/8119368305_4e622c8349_z.jpg',
  'id': 391895}]

In [55]:
class Sampler():
    """Draw a random subset of COCO images and split it into train/val/test.

    Every sample is a dict with two keys:
        'image':       ``list(image.getdata())`` -- the flattened pixel sequence
                       of the image as returned by Pillow.
        'annotations': list of the caption strings attached to that image.

    Attributes:
        train_samples: leading slice of the drawn samples (train share).
        val_samples: the slice that follows (validation share).
        test_samples: everything after the validation slice.
    """

    def __init__(self, coco_captions, number_of_samples=100, train_split=0.8, val_split=0.05, test_split=0.15, seed=None) -> None:
        """Sample images, load their pixels and captions, and split them.

        Args:
            coco_captions: object exposing ``imgs`` (mapping whose values are
                image-info dicts with 'file_name' and 'id') and ``imgToAnns``
                (mapping from image id to annotation dicts with 'caption').
            number_of_samples: how many images to draw without replacement.
            train_split: fraction of the drawn samples used for training.
            val_split: fraction used for validation.
            test_split: accepted for symmetry but not read; the test set is
                whatever remains after the train and validation slices.
            seed: optional seed for a private ``random.Random`` so the draw is
                reproducible. ``None`` (default) uses the global ``random``
                state, exactly as before.

        Raises:
            ValueError: from ``random.sample`` if ``number_of_samples`` exceeds
                the number of available images.
        """
        # A private generator keeps a seeded draw independent of global state.
        rng = random if seed is None else random.Random(seed)
        random_images = rng.sample(list(coco_captions.imgs.values()), number_of_samples)

        samples = []
        for image_info in random_images:
            # The context manager releases the file handle deterministically,
            # even if decoding the pixel data raises.
            with Image.open(BASE_PATH + 'train2017/' + image_info['file_name']) as image:
                # NOTE(review): pixel entries are whatever getdata() yields for
                # the image's mode, so they may differ between images -- confirm
                # whether a fixed mode/size is needed downstream.
                pixels = list(image.getdata())
            samples.append({
                'image': pixels,
                'annotations': [annotation['caption']
                                for annotation in coco_captions.imgToAnns[image_info['id']]]
            })

        # int() truncates, so each border rounds down to a whole sample count.
        train_border = int(train_split * number_of_samples)
        val_border = int((train_split + val_split) * number_of_samples)

        self.train_samples = samples[:train_border]
        self.val_samples = samples[train_border:val_border]
        self.test_samples = samples[val_border:]


In [56]:
class COCO_Dataset(Dataset):
    """Dataset wrapper that serves already-built samples from memory."""

    def __init__(self, samples) -> None:
        """Store ``samples``; it must support ``len()`` and integer indexing."""
        super().__init__()
        self.samples = samples

    def __len__(self):
        """Return how many samples are stored."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return the sample at position ``index`` without modification."""
        return self.samples[index]

In [57]:
# Draw the samples once, then wrap each split in its own shuffled DataLoader.
sampler = Sampler(coco_captions)

# The generator is consumed by the tuple unpacking, so the loaders are built in
# train -> val -> test order with identical settings.
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(COCO_Dataset(split_samples),
               batch_size=hyperparameters['batch_size'],
               shuffle=True)
    for split_samples in (sampler.train_samples,
                          sampler.val_samples,
                          sampler.test_samples)
)
