In [1]:
import json
import os
import os.path
import re

from PIL import Image
#import h5py
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np

In [2]:
%run utils.ipynb
#from utils import *
%run config.ipynb
#from config import *

In [3]:
# Lookup from label id (string key) to category name, e.g.
# category_dict = {'1':'car', '2':'airplane',...}
# JSON text is UTF-8, so pass the encoding explicitly; otherwise open() falls
# back to the platform's locale encoding (e.g. cp1252 on Windows).
with open('categories_dict.json', 'r', encoding='utf-8') as fd:
    category_dict = json.load(fd)

# Create dataset instances and generate DataLoaders

In [4]:
def get_loader(train=False, val=False, test=False, image=False, sequence=False):
    """Return a DataLoader for the requested split and data modality.

    ``image`` selects ``ImageSet`` and ``sequence`` selects ``SequenceSet``;
    if both are set, ``image`` takes precedence. All five flags are forwarded
    to ``path_for`` to resolve the label file and the data file for the split.

    Args:
        train: Load the training split. Also enables shuffling.
        val: Load the validation split.
        test: Load the test split.
        image: Load the image dataset (``ImageSet``).
        sequence: Load the sequence dataset (``SequenceSet``).

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected dataset. It uses the
        ``batch_size`` and ``data_workers`` globals, which are presumably defined
        by the ``%run`` config/utils cells (TODO confirm).

    Raises:
        ValueError: If neither ``image`` nor ``sequence`` is set.
    """
    if image:
        dataset_cls = ImageSet
    elif sequence:
        dataset_cls = SequenceSet
    else:
        # Fail fast with a clear message. Otherwise the dataset class name
        # below would be unbound and surface as an opaque UnboundLocalError.
        raise ValueError("get_loader requires image=True or sequence=True")

    # Both path lookups share the same split/modality flags.
    path_kwargs = dict(train=train, val=val, test=test, image=image, sequence=sequence)
    split = dataset_cls(
        path_for(label=True, **path_kwargs),
        path_for(data=True, **path_kwargs),
    )
    loader = torch.utils.data.DataLoader(
        split,
        batch_size=batch_size,
        shuffle=train,  # only shuffle the data in training
        pin_memory=True,
        num_workers=data_workers,
    )
    return loader

In [5]:
class ImageSet(data.Dataset):
    """Dataset that pairs image rows from a ``.npy`` file with JSON labels.

    Item ``i`` is ``(image, label)``:
      * ``image`` is row ``i`` of the ``.npy`` array, reshaped to
        ``image_shape`` and cast to ``np.float32``.
      * ``label`` is entry ``i`` of the label JSON array, converted with ``int``.

    The dataset length is the number of labels.

    Args:
        label_path: Path to a JSON array of labels, one per sample. Entries
            may be ints or numeric strings.
        data_path: Path to a ``.npy`` file whose i-th row holds the values of
            sample i.
        image_shape: Shape each row is reshaped to. Defaults to ``(28, 28)``.
    """
    def __init__(self, label_path, data_path, image_shape=(28, 28)):
        super(ImageSet, self).__init__()
        with open(label_path, 'r') as fd:
            self.label_json = json.load(fd)
        self.image_data = np.load(data_path)
        self.image_shape = tuple(image_shape)

    def __getitem__(self, item):
        """Return ``(image, label)`` for sample ``item``."""
        label = int(self.label_json[item])
        image = self.image_data[item].reshape(self.image_shape).astype(np.float32)
        return image, label

    def __len__(self):
        """Number of samples, taken from the label list."""
        return len(self.label_json)

In [None]:
class SequenceSet(data.Dataset):
    """Dataset that pairs fixed-length (x, y) sequences from a ``.npy`` file with JSON labels.

    Item ``i`` is ``(sequence, label)``:
      * ``sequence`` is row ``i`` of the ``.npy`` array, reshaped to
        ``(max_length, 2)`` and cast to ``np.int64``.
      * ``label`` is entry ``i`` of the label JSON array, converted with ``int``.

    The dataset length is the number of labels.

    Args:
        label_path: Path to a JSON array of labels, one per sample. Entries
            may be ints or numeric strings.
        data_path: Path to a ``.npy`` file whose i-th row holds
            ``max_length * 2`` values for sample i.
        max_length: Number of (x, y) points per sequence. Defaults to 200.
            Also exposed as ``self.MAX_LENGTH`` for existing callers.
    """
    def __init__(self, label_path, data_path, max_length=200):
        super(SequenceSet, self).__init__()
        with open(label_path, 'r') as fd:
            self.label_json = json.load(fd)
        # shape [num_samples, MAX_LENGTH, (x,y)] (first axis is indexed per sample)
        # value [1-254], 0: default, 255: end of stroke
        self.sequence_data = np.load(data_path)
        self.MAX_LENGTH = max_length

    def __getitem__(self, item):
        """Return ``(sequence, label)`` for sample ``item``."""
        label = int(self.label_json[item])
        sequence = self.sequence_data[item].reshape(self.MAX_LENGTH, 2).astype(np.int64)
        return sequence, label

    def __len__(self):
        """Number of samples, taken from the label list."""
        return len(self.label_json)