In [1]:
from __future__ import print_function
from PIL import Image
import torch.utils.data as data
import os
import sys
import pickle
import numpy as np
import lmdb
import torch


In [2]:
def default_loader(path):
    """Load the image at ``path`` as an RGB ``PIL.Image``.

    Best-effort: if the file cannot be opened or decoded, the exception type,
    its arguments and the offending path are written to stderr and a 224x224
    white RGB placeholder is returned instead, so a single bad file does not
    abort a data-loading run.

    Args:
        path: Filesystem path of the image to load.

    Returns:
        The loaded image converted to RGB, or the white placeholder on failure.
    """
    try:
        # Context manager closes the underlying file handle once converted.
        with Image.open(path) as im:
            return im.convert('RGB')
    except Exception as ex:  # not bare: let KeyboardInterrupt/SystemExit through
        template = "An exception of type {0} occurred. Arguments:\n{1!r}"
        print(template.format(type(ex).__name__, ex.args), file=sys.stderr)
        print('Here\'s the path {}'.format(path), file=sys.stderr)
        return Image.new('RGB', (224, 224), 'white')

In [3]:
class ImageLoader(data.Dataset):
    """Recipe/image pair dataset backed by an LMDB store.

    Each item pairs one recipe, read from ``<data_path>/<partition>_lmdb``,
    with either one of its own images (``target == 1``) or an image taken from
    a different, randomly chosen recipe (``target == -1``). In the ``'train'``
    partition a pair is a match with probability ``1 - self.mismtch``; in
    ``'val'``/``'test'`` every pair is a match.

    Image files are looked up at
    ``<img_path>/<partition>/<c0>/<c1>/<c2>/<c3>/<image id>`` where ``c0..c3``
    are the first four characters of the image id.

    Args:
        img_path: Root directory of the image tree.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the +1/-1 target.
        loader: Callable mapping an image path to an image object.
        square: Falsy to keep the loaded size; otherwise passed to
            ``img.resize`` before ``transform`` (for PIL images this is a
            ``(width, height)`` size).
        data_path: Directory containing ``<partition>_lmdb`` and
            ``<partition>_keys.pkl``.
        partition: Partition name; ``__getitem__`` accepts 'train', 'val' or
            'test'.

    Raises:
        ValueError: If ``data_path`` or ``partition`` is None.
    """

    def __init__(self, img_path, transform=None, target_transform=None,
                 loader=default_loader, square=False, data_path=None, partition=None):
        if data_path is None:
            raise ValueError('No data path specified.')
        if partition is None:
            raise ValueError('No partition specified.')
        self.partition = partition

        # Open the LMDB files.
        # NOTE(review): pickle.load / pickle.loads can execute arbitrary code;
        # only point this at trusted, locally generated data files.
        self.env = lmdb.open(os.path.join(data_path, partition + '_lmdb'),
                             max_readers=1, readonly=True, lock=False,
                             readahead=False, meminit=False)
        with open(os.path.join(data_path, partition + '_keys.pkl'), 'rb') as f:
            self.ids = pickle.load(f)

        self.square = square
        self.imgPath = img_path
        self.mismtch = 0.8  # probability that a training pair is a mismatch
        self.maxInst = 20   # NOTE(review): not read anywhere in this class

        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def _load_sample(self, key):
        """Read and unpickle the LMDB record stored under ``key``."""
        with self.env.begin(write=False) as txn:
            serialized = txn.get(key)
        if serialized is None:
            # txn.get returns None for a missing key; fail with a clear error.
            raise KeyError(key)
        return pickle.loads(serialized)

    def _random_other_index(self, index):
        """Return a uniformly random dataset index different from ``index``."""
        n = len(self.ids)
        if n < 2:
            raise ValueError('Need at least two recipes to draw a mismatch.')
        # Draw from n-1 slots and skip over ``index``: uniform, no retry loop.
        r = int(np.random.randint(n - 1))
        return r + 1 if r >= index else r

    def _image_path(self, imgs):
        """Pick one image record from ``imgs`` and build its on-disk path."""
        if self.partition == 'train':
            # We only use the first five images per recipe during training.
            img_idx = np.random.randint(min(5, len(imgs)))
        else:
            img_idx = 0
        img_name = imgs[img_idx]['id']
        parts = ([self.imgPath, self.partition]
                 + [img_name[i] for i in range(4)]
                 + [img_name])
        return os.path.join(*parts)

    def __getitem__(self, index):
        """Return one (recipe, image) example.

        Returns:
            For 'train': ``([img, ingrs, igr_ln], [target])``.
            Otherwise: ``([img, ingrs, igr_ln], [target, img_id, rec_id])``,
            where ``rec_id`` is the LMDB key of the anchor recipe and
            ``img_id`` the key of the recipe the image was taken from.

        Raises:
            ValueError: If the partition is not 'train', 'val' or 'test'.
        """
        if self.partition == 'train':
            # We force ~80 percent (self.mismtch) of pairs to be a mismatch.
            match = np.random.uniform() > self.mismtch
        elif self.partition in ('val', 'test'):
            match = True
        else:
            raise ValueError('Partition name not well defined')
        target = 1 if match else -1

        rec_id = self.ids[index]
        sample = self._load_sample(rec_id)

        # image: from this recipe on a match, from another recipe otherwise
        if match:
            img_id = rec_id
            imgs = sample['imgs']
        else:
            img_id = self.ids[self._random_other_index(index)]
            imgs = self._load_sample(img_id)['imgs']
        path = self._image_path(imgs)

        # ingredients; length is the position of the last nonzero entry + 1
        ingrs = torch.LongTensor(sample['ingrs'].astype(int))
        igr_ln = max(np.nonzero(sample['ingrs'])[0]) + 1

        # TODO: sample['instruct'] is indexed as [0]/[1] by the in-progress
        # hierarchical code (presumably sentence- and word-level vectors —
        # confirm); it is not yet part of the returned example.

        # load image
        img = self.loader(path)
        if self.square:
            img = img.resize(self.square)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        if self.partition == 'train':
            return [img, ingrs, igr_ln], [target]
        return [img, ingrs, igr_ln], [target, img_id, rec_id]

    def __len__(self):
        """Number of recipes (LMDB keys) in this partition."""
        return len(self.ids)

In [4]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# Root of the image tree; ImageLoader reads <IMG_PATH>/<partition>/...
# Override with the IMG_PATH environment variable.
IMG_PATH = os.environ.get(
    'IMG_PATH', '/home/yifu/Documents/Mycode/python/hierarchicalRNN/jasha/')
# Directory holding the LMDB files that were created
# (<partition>_lmdb and <partition>_keys.pkl). Override with DATA_PATH.
DATA_PATH = os.environ.get(
    'DATA_PATH', '/home/yifu/Documents/Mycode/python/hierarchicalRNN/jasha/lmdb/')
# Standard ImageNet per-channel mean/std, as used by torchvision's
# pretrained models.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
WORKERS = 30
BATCH_SIZE = 160
train_transform = transforms.Compose([
    # Resize replaces the deprecated Scale alias (removed in newer torchvision);
    # an int size rescales the shorter side, keeping the aspect ratio.
    transforms.Resize(256),
    transforms.CenterCrop(256),   # we get only the center of that rescaled
    transforms.RandomCrop(224),   # random crop within the center crop
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
i1 = ImageLoader(IMG_PATH, train_transform,
                 data_path=DATA_PATH,
                 partition='train')
train_loader = torch.utils.data.DataLoader(
    i1,
    batch_size=BATCH_SIZE, shuffle=True,
    num_workers=WORKERS, pin_memory=True)




In [5]:
#############Testing#########################
# Scratch check: open the training LMDB directly and draw a match/mismatch
# target for one index, mirroring the logic in ImageLoader.__getitem__.
data_path = '/home/yifu/Documents/Mycode/python/hierarchicalRNN/jasha/lmdb/'
partition = 'train'
env = lmdb.open(os.path.join(data_path, partition + '_lmdb'), max_readers=1,
                readonly=True, lock=False, readahead=False, meminit=False)
with open(os.path.join(data_path, partition + '_keys.pkl'), 'rb') as f:
    ids = pickle.load(f)
loader = default_loader
mismtch = 0.8
index = 100
recipId = ids[index]
# We force 80 percent of training pairs to be a mismatch; other partitions
# always match (previously `match` was undefined for non-train partitions).
if partition == 'train':
    match = np.random.uniform() > mismtch
else:
    match = True
target = 1 if match else -1

In [7]:
# Display the LMDB environment handle `env` opened by the testing cell above.
env

<Environment at 0x7fbae872c030>