In [32]:
import torch
import os
import numpy as np
import json
import lmdb
import random

from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from io import BytesIO

In [23]:
def get_train_transforms(mean, std, resize=256, crop_size=224, rotation_degrees=20):
    """Build the augmentation pipeline applied to training images.

    The pipeline resizes, applies light hue/saturation jitter, random
    horizontal and vertical flips and a random rotation, then takes a random
    square crop and converts the result to a normalised tensor.

    Args:
        mean: Per-channel means passed to ``transforms.Normalize``.
        std: Per-channel standard deviations passed to ``transforms.Normalize``.
        resize: Size passed to ``transforms.Resize``. Defaults to 256.
        crop_size: Side length of the square ``RandomCrop``. Defaults to 224;
            must not exceed the resized image size.
        rotation_degrees: Maximum absolute angle (degrees) for
            ``transforms.RandomRotation``. Defaults to 20.

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a normalised tensor.
    """
    # Newer torchvision takes ``interpolation=InterpolationMode``; the old
    # ``resample=`` keyword was deprecated and later removed (TypeError on
    # current releases). Only fall back to it when InterpolationMode is absent.
    if hasattr(transforms, 'InterpolationMode'):
        rotation = transforms.RandomRotation(
            rotation_degrees, interpolation=transforms.InterpolationMode.BILINEAR)
    else:
        rotation = transforms.RandomRotation(rotation_degrees, resample=Image.BILINEAR)

    return transforms.Compose([
        transforms.Resize(resize),
        transforms.ColorJitter(hue=.05, saturation=.05),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        rotation,
        transforms.RandomCrop((crop_size, crop_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

def get_val_transforms(mean, std, resize=256, crop_size=224):
    """Build the deterministic preprocessing pipeline for validation/test images.

    Resizes, takes a centred square crop and converts to a normalised tensor.
    No random augmentation is applied, so the output for a given image is
    repeatable.

    Args:
        mean: Per-channel means passed to ``transforms.Normalize``.
        std: Per-channel standard deviations passed to ``transforms.Normalize``.
        resize: Size passed to ``transforms.Resize``. Defaults to 256.
        crop_size: Side length of the square ``CenterCrop``. Defaults to 224.

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a normalised tensor.
    """
    return transforms.Compose([
        transforms.Resize(resize),
        transforms.CenterCrop((crop_size, crop_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])


# Standard ImageNet per-channel (RGB) mean/std, as used with torchvision's
# pretrained models; see
# https://discuss.pytorch.org/t/how-can-torchvison-models-deal-with-image-whose-size-is-not-224-224/51077/3
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

train_transform = get_train_transforms(mean, std)
val_transform = get_val_transforms(mean, std)

# Course/project root; the dataset lives under ``{ROOT}/dataset``.
# Set the CS536_ROOT environment variable to run this notebook elsewhere
# instead of editing this absolute, user-specific path.
ROOT = os.environ.get('CS536_ROOT', '/common/users/as3503/courses/cs536')

In [24]:
# Print the kernel's current working directory (IPython shell escape).
! pwd

/common/users/as3503/courses/cs536/final_project/test_dir


In [37]:
class Recipe1MDataset(Dataset):
    """Map-style dataset of Recipe1M recipes stored in an LMDB database.

    Each item is a ``(txt, img)`` pair:

    * ``txt`` is the recipe's title, ingredients and instructions joined by
      newlines.
    * ``img`` is the stored image after ``transform`` has been applied.

    Recipe metadata is read from ``keys.json`` in the same directory as
    ``lmdb_file``. Only entries whose ``with_image`` flag is truthy are indexed.
    """

    def __init__(
        self,
        lmdb_file=f'{ROOT}/dataset/Recipe1M.lmdb',
        part='', food_type='',
        transform=None, resolution=256, return_image=True):
        """Index the requested recipes and open the LMDB environment.

        Args:
            lmdb_file: Path to the Recipe1M LMDB. ``keys.json`` is read from
                the same directory.
            part: Partition to keep ('train', 'val' or 'test'). '' keeps all.
            food_type: Keep only recipes whose title contains this word
                (case-insensitive). '' disables the filter.
            transform: Callable applied to each PIL image. Required.
            resolution: Prefix of the LMDB image keys (``f'{resolution}-{id}'``).
            return_image: Stored as ``self.return_image``. NOTE(review): it
                is not consulted anywhere in this class, and images are
                always loaded.

        Raises:
            AssertionError: If ``part``/``food_type`` is not recognised or
                ``transform`` is None.
            IOError: If the LMDB environment could not be opened.
        """
        # Validate arguments before any I/O, so a bad call fails fast and does
        # not leave an LMDB environment open behind it.
        assert part in ['', 'train', 'val', 'test'], "part has to be in ['', 'train', 'val', 'test']"
        assert food_type in ['', 'salad', 'cookie', 'muffin'], "food_type has to be in ['', 'salad', 'cookie', 'muffin']"
        assert transform is not None, 'transform can not be None!'

        path = os.path.join(os.path.dirname(lmdb_file), 'keys.json')
        with open(path, 'r') as f:
            # Unfiltered metadata for every recipe, kept for reference.
            self.all_keys = json.load(f)

        keys = [x for x in self.all_keys if x['with_image']]
        if part:
            keys = [x for x in keys if x['partition'] == part]
        if food_type:
            needle = food_type.lower()
            keys = [x for x in keys if needle in x['title'].lower()]
        self.keys = keys

        # NOTE(review): the environment is opened eagerly in the constructing
        # process. LMDB advises against using an environment across fork(),
        # which is what DataLoader(num_workers > 0) does. Consider opening it
        # lazily per worker if multi-process loading misbehaves.
        self.env = lmdb.open(
            lmdb_file,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        if not self.env:
            raise IOError('Cannot open lmdb dataset', lmdb_file)

        self.resolution = resolution
        self.transform = transform
        self.return_image = return_image

    def __len__(self):
        """Number of indexed recipes (those with an image, after filtering)."""
        return len(self.keys)

    @staticmethod
    def _get_required(txn, key):
        """Return the bytes stored under ``key`` in ``txn``, or raise KeyError."""
        value = txn.get(key.encode('utf-8'))
        if value is None:
            # txn.get returns None for a missing key. Fail here with the key
            # name instead of an opaque AttributeError / image-decode error.
            raise KeyError(f'{key} not found in lmdb dataset')
        return value

    def _load_recipe(self, rcp):
        """Read one recipe's text and image from LMDB.

        Args:
            rcp: Metadata entry from ``keys.json``. Only ``rcp['id']`` is used.

        Returns:
            ``(txt, img)``, where ``txt`` is the title, ingredients and
            instructions joined by newlines, and ``img`` is the transformed
            image.

        Raises:
            KeyError: If any of the recipe's LMDB entries is missing.
        """
        rcp_id = rcp['id']

        with self.env.begin(write=False) as txn:
            title, ingredients, instructions = [
                self._get_required(txn, f'{field}-{rcp_id}').decode('utf-8')
                for field in ('title', 'ingredients', 'instructions')
            ]
            img_bytes = self._get_required(txn, f'{self.resolution}-{rcp_id}')

        txt = '\n'.join((title, ingredients, instructions))

        # Convert to RGB so that grayscale/RGBA/palette images yield 3 channels.
        # A 3-value Normalize, like the one built in this notebook, needs that.
        # Image.open is lazy; convert() also forces the decode here.
        img = Image.open(BytesIO(img_bytes)).convert('RGB')
        img = self.transform(img)
        return txt, img

    def __getitem__(self, index):
        """Return the ``(txt, img)`` pair for the recipe at ``index``."""
        return self._load_recipe(self.keys[index])


In [38]:
# One dataset per split: augmentation for train, deterministic preprocessing
# for val/test. Built in train -> val -> test order.
train, val, test = (
    Recipe1MDataset(transform=split_transform, part=split_name)
    for split_name, split_transform in (
        ('train', train_transform),
        ('val', val_transform),
        ('test', val_transform),
    )
)

In [39]:
tuple(len(split) for split in (train, val, test))  # recipes with images per split

(452322, 97179, 97612)

In [40]:
len(train.all_keys), sum(len(split) for split in (train, val, test))  # all recipes vs. indexed

(1274073, 647113)