In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
%matplotlib inline

In [2]:
# Training table; later cells read its `business_id` and `labels` columns.
train_df = pd.read_csv('../data/raw/train.csv')

In [3]:
# Photo-to-business mapping for the training set; later cells read its
# `photo_id` and `business_id` columns.
train_photo_to_biz = pd.read_csv('../data/raw/train_photo_to_biz_ids.csv')

In [4]:
# Photo-to-business table for the test set (presumably the same layout as the
# training one -- TODO confirm). Not referenced by the cells visible below.
test_photo_to_biz = pd.read_csv('../data/raw/test_photo_to_biz.csv')

In [5]:
# Name of each attribute, keyed by its integer id (0-8). The names are listed
# in id order, so `enumerate` assigns the matching key to each one.
attribute_id_to_label = dict(enumerate([
    'good_for_lunch',
    'good_for_dinner',
    'takes_reservations',
    'outdoor_seating',
    'restaurant_is_expensive',
    'has_alcohol',
    'has_table_service',
    'ambience_is_classy',
    'good_for_kids',
]))

In [7]:
# Lookup from photo id to the business id on the same row of the mapping table.
_photo_column = train_photo_to_biz['photo_id']
_business_column = train_photo_to_biz['business_id']
train_photo_id_to_biz_id = dict(zip(_photo_column, _business_column))

In [8]:
# Drop every row that has a missing value in any column (the `dropna` default).
# A NaN in `labels` could not be `.split()` when the label strings are parsed.
train_df_cleaned = train_df.dropna()

In [9]:
# Lookup from business id to its raw (still unparsed) label string.
biz_id_to_labels_str = dict(
    zip(train_df_cleaned['business_id'], train_df_cleaned['labels'])
)

In [10]:
# Parse each whitespace-separated label string into a list of integer label ids.
biz_id_to_labels = {
    business: [int(token) for token in raw_labels.split()]
    for business, raw_labels in biz_id_to_labels_str.items()
}

In [11]:
def OHE(labels, size=9):
    """Encode a collection of label indices as a multi-hot array.

    Args:
        labels: iterable of integer indices to switch on.
        size: length of the first axis of the result (default 9).

    Returns:
        A float ``numpy`` array of shape ``(size, 1, 1)`` that holds 1 at
        every index listed in ``labels`` and 0 everywhere else.

    Raises:
        IndexError: if a label falls outside the first axis.
    """
    # Honour `size`; the array length used to be hard-coded to 9, so the
    # parameter had no effect.
    ohe = np.zeros((size, 1, 1))
    for label in labels:
        ohe[label, 0, 0] = 1
    return ohe

In [12]:
# Encode every business's list of label ids with OHE (one array per business).
biz_id_to_ohe_labels = {
    business: OHE(label_ids)
    for business, label_ids in biz_id_to_labels.items()
}

In [13]:
# Give each photo the encoded labels of its business. Photos whose business
# has no encoded labels (it is absent from biz_id_to_ohe_labels) are skipped.
train_photo_id_to_ohe_labels = {
    photo: biz_id_to_ohe_labels[business]
    for photo, business in train_photo_id_to_biz_id.items()
    if business in biz_id_to_ohe_labels
}

# Create LMDB

In [15]:
import caffe
import lmdb

In [26]:
# Materialize the keys as a real list. On Python 3 `.keys()` returns a view
# object, which NumPy cannot convert to a 1-D array for shuffling; on Python 2
# this is just a copy of the list that `.keys()` already returned.
train_photo_ids = list(train_photo_id_to_ohe_labels.keys())

In [27]:
# Shuffle the photo ids once. The same permutation is passed to both the label
# LMDB writer and the image-list writer below, presumably so that record i in
# one matches line i in the other.
# NOTE(review): no random seed is set, so the order differs on every run.
train_photo_ids_perm = np.random.permutation(train_photo_ids)

In [29]:
def create_labels_lmdb(photo_ids_perm, photo_id_to_ohe_labels, lmdb_filename):
    """Write one serialized Caffe datum per photo id into an LMDB.

    Records are stored in the order of ``photo_ids_perm`` under keys that are
    the zero-padded, 10-digit position of the photo in that sequence.

    Args:
        photo_ids_perm: iterable of photo ids, in the order to be written.
        photo_id_to_ohe_labels: mapping from photo id to the array that is
            converted with ``caffe.io.array_to_datum``.
        lmdb_filename: path passed to ``lmdb.open``.

    Raises:
        KeyError: if a photo id is missing from ``photo_id_to_ohe_labels``.
    """
    # map_size has to be an integer; the literal 1e12 is a float.
    db_labels = lmdb.open(lmdb_filename, map_size=int(1e12))
    try:
        with db_labels.begin(write=True) as txn_label:
            for counter_label, pid in enumerate(photo_ids_perm):
                datum = caffe.io.array_to_datum(photo_id_to_ohe_labels[pid])
                # LMDB keys are bytes. Encoding is a no-op on Python 2 and
                # avoids a TypeError for text keys on Python 3.
                key = "{:0>10d}".format(counter_label).encode('ascii')
                txn_label.put(key, datum.SerializeToString())
    finally:
        # Release the environment even when a lookup or a write fails.
        db_labels.close()

In [31]:
# Write the encoded labels of every shuffled training photo to the label LMDB.
create_labels_lmdb(train_photo_ids_perm, train_photo_id_to_ohe_labels, '../data/db/all_train_labels_lmdb')

In [36]:
def create_caffe_image_file(photo_ids_perm, filename):
    """Write an image list file with one ``<photo_id>.jpg 0`` line per photo.

    Args:
        photo_ids_perm: iterable of photo ids, written in iteration order.
        filename: path of the text file to create (overwritten if present).
    """
    line_template = "{id}.jpg 0\n"
    with open(filename, 'w') as listing:
        # The trailing 0 is a constant second column written on every line.
        listing.writelines(line_template.format(id=pid) for pid in photo_ids_perm)

In [37]:
# Write the image list from the same shuffled id sequence that was passed to
# the label LMDB writer.
create_caffe_image_file(train_photo_ids_perm, '../data/db/all_train_caffe.txt')

In [38]:
# Peek at the first five lines of the generated image list.
!head -n 5 ../data/db/all_train_caffe.txt

420948.jpg 0
271159.jpg 0
284325.jpg 0
393119.jpg 0
58272.jpg 0


In [39]:
# Count the lines of the image list (the writer emits one line per photo id).
!wc -l ../data/db/all_train_caffe.txt

  234545 ../data/db/all_train_caffe.txt
