In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
%matplotlib inline

In [2]:
train_df = pd.read_csv('../data/raw/train.csv')

In [3]:
train_photo_to_biz = pd.read_csv('../data/raw/train_photo_to_biz_ids.csv')

In [4]:
test_photo_to_biz = pd.read_csv('../data/raw/test_photo_to_biz.csv')

In [21]:
attribute_id_to_label = {
    0: 'good_for_lunch',
    1: 'good_for_dinner',
    2: 'takes_reservations',
    3: 'outdoor_seating',
    4: 'restaurant_is_expensive',
    5: 'has_alcohol',
    6: 'has_table_service',
    7: 'ambience_is_classy',
    8: 'good_for_kids'}

In [5]:
train_photo_id_to_biz_id = dict(zip(train_photo_to_biz.photo_id, train_photo_to_biz.business_id))

In [6]:
train_df_cleaned = train_df.dropna()

In [9]:
biz_id_to_labels_str = dict(zip(train_df_cleaned.business_id, train_df_cleaned['labels']))

In [10]:
biz_id_to_labels = dict()
for biz_id, labels_str in biz_id_to_labels_str.items():
    biz_id_to_labels[biz_id] = [int(l) for l in labels_str.split()]

# Split in train/val

We want to predict label for biz id, so I will also split the traning set on buisiness id.

In [13]:
seed = 42

In [14]:
np.random.seed(42)

In [15]:
train_biz_id = set(train_photo_to_biz.business_id)

In [16]:
len(train_biz_id)

2000

In [17]:
train_ratio = 0.7

In [18]:
train_biz_id_permuted = np.random.permutation(list(train_biz_id))

In [19]:
n_train_biz_id = int(len(train_biz_id) * train_ratio)

In [20]:
train_biz_id_cv = set(train_biz_id_permuted[:n_train_biz_id])
val_biz_id_cv = set(train_biz_id_permuted[n_train_biz_id:])

In [21]:
train_photos_ids_cv = {photo_id for photo_id, biz_id in train_photo_id_to_biz_id.items() if biz_id in train_biz_id_cv}
val_photos_ids_cv = {photo_id for photo_id, biz_id in train_photo_id_to_biz_id.items() if biz_id in val_biz_id_cv}

# Create caffe file for each attribute

In [37]:
def create_caffe_file(photos_ids, att, output_filename):
    with open(output_filename, 'w') as cfile:
        for photo_id in photos_ids:
            biz_id = train_photo_id_to_biz_id[photo_id]
            if biz_id in biz_id_to_labels:
                labels = biz_id_to_labels[biz_id]
                r = 0
                if att in labels:
                    r = 1
                cfile.write("{id}.jpg {label}\n".format(id=photo_id, label=r))

In [38]:
output_dir = '../data/db/'
for att in range(9):
    train_caffe_filename = os.path.join(output_dir, 'caffe_train_{att}.txt'.format(att=att))
    create_caffe_file(train_photos_ids_cv, att, train_caffe_filename)
    val_caffe_filename = os.path.join(output_dir, 'caffe_val_{att}.txt'.format(att=att))
    create_caffe_file(val_photos_ids_cv, att, val_caffe_filename)

In [41]:
!wc -l ../data/db/caffe_train_0.txt

  162905 ../data/db/caffe_train_0.txt


In [42]:
!wc -l ../data/db/caffe_val_0.txt

   71640 ../data/db/caffe_val_0.txt


In [45]:
!head -n 20 ../data/db/caffe_train_0.txt

262144.jpg 0
462639.jpg 0
2.jpg 0
436907.jpg 0
460541.jpg 0
5.jpg 0
462636.jpg 0
8.jpg 0
262154.jpg 0
392807.jpg 0
12.jpg 1
13.jpg 0
262158.jpg 0
262159.jpg 1
262160.jpg 0
262161.jpg 1
262162.jpg 0
21.jpg 1
262166.jpg 0
349529.jpg 0
