In [1]:
import torch
import os
import numpy as np
import json
import lmdb
import random

from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from io import BytesIO

In [2]:
def get_train_transforms(mean, std):
    return transforms.Compose([
        transforms.Resize(256),
        transforms.ColorJitter(hue=.05, saturation=.05),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(20, resample=Image.BILINEAR),
        transforms.RandomCrop((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

def get_val_transforms(mean, std):
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])


# https://discuss.pytorch.org/t/how-can-torchvison-models-deal-with-image-whose-size-is-not-224-224/51077/3
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

train_transform = get_train_transforms(mean, std)
val_transform = get_val_transforms(mean, std)
ROOT = '/common/users/as3503/courses/cs536'

In [3]:
! pwd

/common/users/as3503/courses/cs536/final_project/example_code


In [4]:
class Recipe1MDataset(Dataset):
    def __init__(
        self, 
        lmdb_file=f'{ROOT}/dataset/Recipe1M.lmdb',
        part='', food_type='',
        transform=None, resolution=256, return_image=True):

        assert part in ['', 'train', 'val', 'test'], "part has to be in ['', 'train', 'val', 'test']"
        assert food_type in ['', 'salad', 'cookie', 'muffin'], "part has to be in ['', 'salad', 'cookie', 'muffin']"

        dirname = os.path.dirname(lmdb_file)
        path = os.path.join(dirname, 'keys.json')

        with open(path, 'r') as f:
            self.keys = json.load(f)
        
        self.all_keys = self.keys
        self.keys = [x for x in self.keys if x['with_image']]
        if part:
            self.keys = [x for x in self.keys if x['partition']==part]
        if food_type:
            self.keys = [x for x in self.keys if food_type.lower() in x['title'].lower()]

        self.env = lmdb.open(
            lmdb_file,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', lmdb_file)

        # with self.env.begin(write=False) as txn:
        #     self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))

        self.resolution = resolution

        assert transform!=None, 'transform can not be None!'
        self.transform = transform

        self.return_image = return_image

    def __len__(self):
        return len(self.keys)

    def _load_recipe(self, rcp):
        rcp_id = rcp['id']

        with self.env.begin(write=False) as txn:
            # print("Loading recipe")
            key = f'title-{rcp_id}'.encode('utf-8')
            title = txn.get(key).decode('utf-8')

            key = f'ingredients-{rcp_id}'.encode('utf-8')
            ingredients = txn.get(key).decode('utf-8')

            key = f'instructions-{rcp_id}'.encode('utf-8')
            instructions = txn.get(key).decode('utf-8')

            key = f'{self.resolution}-{rcp_id}'.encode('utf-8')
            img_bytes = txn.get(key)

        txt = title
        txt += '\n'
        txt += ingredients
        txt += '\n'
        txt += instructions

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        img = self.transform(img)
        return txt, img

    def __getitem__(self, index):
        rcp_key = self.keys[index]
        txt, img = self._load_recipe(rcp_key)
        return txt, img


In [39]:
data = Recipe1MDataset(transform=val_transform, part='train')

In [44]:
train_ids = list(set('-'.join(x['id'].split('-')[:-1]) for x in data.keys))

In [45]:
len(train_ids)

281598

In [30]:
muffins = Recipe1MDataset(transform=val_transform, food_type='muffin', part='train')

In [31]:
muffins.keys[0]

{'id': '000320b7ce-0',
 'partition': 'train',
 'title': 'Tropical Banana Muffins',
 'with_image': 1}

In [32]:
muffin_ids = list(set('_'.join(x['id'].split('-')[:-1]) for x in muffins.keys))

In [33]:
muffin_ids

['274048bf64',
 '76132af775',
 'bc2a104de5',
 'd5681c8cef',
 'b7e3ace201',
 '4e274aa7e7',
 '508cb9b502',
 'd72a080891',
 '3139163450',
 '8c6390ec2d',
 'dd774731aa',
 '4df2195052',
 '9fc15c4484',
 'b7138eb764',
 '9618fdb41c',
 '109d7eafa8',
 '916673fc09',
 'e258e90d30',
 'f5e4ce0d83',
 '954eb39a60',
 '5d81159b2d',
 '2d251c518e',
 '63b8484102',
 '30ff8a9e41',
 'f53659d4c7',
 'c5128dc423',
 '307d0fb88e',
 '4e63d24c36',
 '7fd9700f29',
 'a818c7fb1f',
 '06d6123ef8',
 'e6992af880',
 'f171f19aea',
 '2c49cb7ea1',
 '9e59426e7f',
 '7b7a24d8ba',
 'f13c280ea9',
 '1e981a7a87',
 'b435afe5b2',
 '6f18885d9a',
 'b8fdfc0fbf',
 '1271ca6301',
 '512a6dd744',
 '607ae5f00a',
 'c26d635328',
 'b1f063d3dd',
 'ced12b262f',
 'e992bf3825',
 'f0db6f76a6',
 '776a39634f',
 'd758e1b73d',
 '4c9f39ef1f',
 'f5355bd1f1',
 'a6c2c9eed9',
 '63e8f8f9d2',
 'cdffc309f7',
 '9ab3e72531',
 'a3d282dc65',
 '440b260313',
 'eaeb65be8a',
 '00938d4caa',
 '477c5b1639',
 '7d0607e5e2',
 '7f9f69f061',
 '9abc3de705',
 'f419267f92',
 '69de59b0

In [34]:
salads = Recipe1MDataset(transform=val_transform, food_type='salad', part='train')

In [35]:
salads.keys[0]

{'id': '0005fc89f7-0',
 'partition': 'train',
 'title': 'Shrimp and Caper Salad',
 'with_image': 1}

In [36]:
salad_ids = list(set('_'.join(x['id'].split('-')[:-1]) for x in salads.keys))

In [37]:
salad_ids

['205127872b',
 'e59638cdc2',
 '7cbb72bc75',
 '6f27a50955',
 '69850b066f',
 '66e50b582b',
 '773ac16da9',
 '151202c8b5',
 'bac54afb16',
 '60964e47fe',
 'e52124213c',
 '036ca6acff',
 'bc137d15a3',
 '3e9d11a948',
 '999903f4f8',
 'eaa6b22344',
 '22a029f05b',
 '6b51b28c93',
 '848fb435ac',
 'b9a6962534',
 '813a288a06',
 '06c5ce19ca',
 'f3b9cd5e00',
 'cc4dadd9cd',
 'b7a660aab8',
 '1ea3322a84',
 '461269b9db',
 'b597374ff9',
 '2848731780',
 'e41108816f',
 'd51c0f8d82',
 '8174ae6216',
 '0739504d23',
 'a6abf1c0a8',
 '1e7e390444',
 '238a9ec805',
 '2585192efa',
 '38fb111adc',
 '49d328da2d',
 '71d238f364',
 '10e4db39fe',
 'a702e344fa',
 '48517eff7e',
 'a2e285e886',
 '9bcdccda3f',
 '588d285ced',
 '93861893db',
 'bdf0ab9f67',
 '9d8cd9be72',
 'c21b282372',
 'd1d8bd0c6e',
 '975ab9b220',
 '13f06456e6',
 '815d2088d0',
 '95799ae010',
 'a33abf4dea',
 'cbc2301cca',
 'f059560974',
 '5a3437905a',
 '0e73c1f344',
 '1b192a19be',
 '649f224b54',
 '8b6a186cb4',
 '0520e40340',
 '2afc726d51',
 '424fc3e2ad',
 '6e183718

In [38]:
with open('muffin_ids.txt', 'w') as f:
    f.write('\n'.join(muffin_ids))
    
with open('salad_ids.txt', 'w') as f:
    f.write('\n'.join(salad_ids))

In [38]:
train = Recipe1MDataset(transform=train_transform, part='train')
val = Recipe1MDataset(transform=val_transform, part='val')
test = Recipe1MDataset(transform=val_transform, part='test')

In [39]:
len(train), len(val), len(test)

(452322, 97179, 97612)

In [40]:
len(train.all_keys), (len(train) + len(val) + len(test))

(1274073, 647113)