In [5]:
import json
import os
import os.path
import re

from PIL import Image
#import h5py
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np

In [6]:
# Execute the helper notebooks with IPython's %run magic so that the names they
# define become available to the cells below.
# NOTE(review): `path_for`, `batch_size` and `data_workers` are used by
# `get_loader` but are not defined in this notebook -- presumably they come
# from utils.ipynb / config.ipynb; confirm.
%run utils.ipynb
# Plain-module alternative, kept for reference:
#from utils import *
%run config.ipynb
# Plain-module alternative, kept for reference:
#from config import *

In [7]:
# category_dict = {'1':'car', '2':'airplane',...}
#with open('categories_dict.json', 'r') as fd:
#    category_dict = json.load(fd)

# Create dataset instances and generate the DataLoader

In [5]:
def get_loader(train=False, val=False, test=False, image=False, sequence=False):
    """Return a ``DataLoader`` for the requested split and input modality.

    The modality flags select the dataset class:

    * ``image`` and ``sequence`` -> ``JointSet`` (image + sequence per sample)
    * ``image`` only             -> ``ImageSet``
    * ``sequence`` only          -> ``SequenceSet``

    Args:
        train, val, test: split selectors, forwarded unchanged to ``path_for``
            to resolve the label/data file paths.
        image: include the image modality.
        sequence: include the sequence modality.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected dataset. Samples
        are shuffled only when ``train`` is truthy.

    Raises:
        ValueError: if neither ``image`` nor ``sequence`` is requested, since
            there is then no dataset to build.

    Note:
        ``path_for``, ``batch_size`` and ``data_workers`` are looked up in the
        enclosing namespace -- presumably provided by utils.ipynb /
        config.ipynb; confirm.
    """
    # Fail early with a clear message; without this guard no branch below
    # would bind `split` and the DataLoader call would die with an
    # UnboundLocalError.
    if not (image or sequence):
        raise ValueError(
            "get_loader requires at least one modality: "
            "pass image=True and/or sequence=True"
        )

    # Joint Model Dataloader
    if image and sequence:
        # The joint set takes its labels from the image label file.
        split = JointSet(
            path_for(train=train, val=val, test=test, label=True, image=True),
            path_for(train=train, val=val, test=test, data=True, image=True),
            path_for(train=train, val=val, test=test, data=True, sequence=True),
        )
    # Image Model Dataloader
    elif image:
        split = ImageSet(
            path_for(train=train, val=val, test=test, label=True, image=True),
            path_for(train=train, val=val, test=test, data=True, image=True),
        )
    # Sequence Model Dataloader
    else:
        split = SequenceSet(
            path_for(train=train, val=val, test=test, label=True, sequence=True),
            path_for(train=train, val=val, test=test, data=True, sequence=True),
        )

    loader = torch.utils.data.DataLoader(
        split,
        batch_size=batch_size,
        shuffle=train,  # only shuffle the data in training
        pin_memory=True,
        num_workers=data_workers,
    )
    return loader

In [6]:
class ImageSet(data.Dataset):
    """Image-only dataset: one 28x28 float32 image and one integer label per sample.

    Samples are returned as ``(image, 0, label)``; the constant ``0`` fills the
    middle slot of the tuple.
    """

    def __init__(self, label_path, data_path):
        """Load the labels (JSON) from ``label_path`` and the image array
        (``np.load``) from ``data_path``."""
        super(ImageSet, self).__init__()
        with open(label_path, 'r') as label_file:
            loaded_labels = json.load(label_file)
        self.label_json = loaded_labels
        self.image_data = np.load(data_path)

    def __getitem__(self, item):
        """Return ``(image, 0, label)`` for the sample at index ``item``."""
        label = int(self.label_json[item])
        raw_pixels = self.image_data[item]
        # Each stored entry is reshaped to 28x28 and cast to float32.
        image = raw_pixels.reshape(28, 28)
        image = image.astype(np.float32)
        return image, 0, label

    def __len__(self):
        """Number of samples, taken from the number of labels."""
        sample_count = len(self.label_json)
        return sample_count

In [7]:
class SequenceSet(data.Dataset):
    """Sequence-only dataset: one int64 ``(MAX_LENGTH, 3)`` array and one
    integer label per sample.

    Samples are returned as ``(0, sequence, label)``; the constant ``0`` fills
    the first slot of the tuple.
    """

    def __init__(self, label_path, data_path):
        """Load the labels (JSON) from ``label_path`` and the sequence array
        (``np.load``) from ``data_path``."""
        super(SequenceSet, self).__init__()
        with open(label_path, 'r') as label_file:
            loaded_labels = json.load(label_file)
        self.label_json = loaded_labels
        # Stored array layout per the original note: [batch, MAX_LENGTH, ...],
        # values 1-254, with 0 = default and 255 = end of stroke.
        # NOTE(review): the original note described (x, y) pairs, but samples
        # are reshaped to 3 columns in __getitem__ -- confirm the third column.
        self.sequence_data = np.load(data_path)
        self.MAX_LENGTH = 200

    def __getitem__(self, item):
        """Return ``(0, sequence, label)`` for the sample at index ``item``."""
        label = int(self.label_json[item])
        raw_steps = self.sequence_data[item]
        steps = raw_steps.reshape(self.MAX_LENGTH, 3)
        sequence = steps.astype(np.int64)
        return 0, sequence, label

    def __len__(self):
        """Number of samples, taken from the number of labels."""
        sample_count = len(self.label_json)
        return sample_count

In [None]:
class JointSet(data.Dataset):
    """Joint dataset: a 28x28 float32 image, an int64 ``(MAX_LENGTH, 3)``
    sequence and an integer label per sample.

    Samples are returned as ``(image, sequence, label)``.
    """

    def __init__(self, label_path, img_path, sequence_path):
        """Load the labels (JSON) from ``label_path``, then the image and
        sequence arrays (``np.load``) from ``img_path`` and ``sequence_path``."""
        super(JointSet, self).__init__()
        with open(label_path, 'r') as label_file:
            loaded_labels = json.load(label_file)
        self.label_json = loaded_labels
        self.image_data = np.load(img_path)
        # Sequence array layout per the original note: [batch, MAX_LENGTH, ...],
        # values 1-254, with 0 = default and 255 = end of stroke.
        # NOTE(review): the original note described (x, y) pairs, but samples
        # are reshaped to 3 columns in __getitem__ -- confirm the third column.
        self.sequence_data = np.load(sequence_path)
        self.MAX_LENGTH = 200

    def __getitem__(self, item):
        """Return ``(image, sequence, label)`` for the sample at index ``item``."""
        label = int(self.label_json[item])
        raw_pixels = self.image_data[item]
        image = raw_pixels.reshape(28, 28).astype(np.float32)
        raw_steps = self.sequence_data[item]
        sequence = raw_steps.reshape(self.MAX_LENGTH, 3).astype(np.int64)
        return image, sequence, label

    def __len__(self):
        """Number of samples, taken from the number of labels."""
        sample_count = len(self.label_json)
        return sample_count