# PGAN data: dump SSENSE items (MongoDB) and their photos (GridFS) to PNG + JSON files

In [None]:
import json
import os
import pickle
from copy import deepcopy

import gridfs
import imageio
import numpy as np
import pymongo
from bson.binary import Binary
from fuel.datasets.hdf5 import H5PYDataset
from pymongo import MongoClient
from tqdm import tqdm


In [None]:
# MongoDB instance holding the item documents (read by get_itemsdb).
_host_adress = 'mongo_m'
_db_name = 'ssense_items'
_items_collection = 'items'

# MongoDB instance holding the pickled photos in GridFS (read by get_photosdb).
# NOTE(review): _photos_collection is not referenced anywhere in this notebook.
_host_adress_p = 'mongo_p'
_db_photos = 'ssense_photo'
_photos_collection = 'photo'

# Both instances are reached on the same port.
_host_port = 27017

In [None]:
def get_itemsdb(host=None, port=None):
    """Connect to the items MongoDB server and return the items collection.

    Args:
        host: server host name; ``None`` uses the module-level ``_host_adress``.
        port: server port; ``None`` uses the module-level ``_host_port``.

    Returns:
        the ``_items_collection`` collection of the ``_db_name`` database.
    """
    # Defaults are resolved at call time so later edits to the module
    # constants are still honoured, as in the original zero-argument form.
    client = MongoClient(_host_adress if host is None else host,
                         _host_port if port is None else port)
    # NOTE(review): each call creates a new MongoClient that is never closed;
    # call once and reuse the returned collection.
    return client[_db_name][_items_collection]


def get_photosdb(host=None, port=None):
    """Connect to the photos MongoDB server and return its GridFS store.

    Args:
        host: server host name; ``None`` uses the module-level ``_host_adress_p``.
        port: server port; ``None`` uses the module-level ``_host_port``.

    Returns:
        a ``gridfs.GridFS`` over the ``_db_photos`` database, using GridFS's
        default collection prefix (``_photos_collection`` is not used here).
    """
    client = MongoClient(_host_adress_p if host is None else host,
                         _host_port if port is None else port)
    # NOTE(review): each call creates a new MongoClient that is never closed;
    # call once and reuse the returned store.
    return gridfs.GridFS(client[_db_photos])

In [None]:
# Module-level handle to the items collection, used by the cells below.
items_db = get_itemsdb()

In [None]:
# Module-level GridFS handle for the pickled photos, used by the cells below.
photos_db = get_photosdb()

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Visual spot-check: show every 100000th photo stored in GridFS.
for idx, grid_out in enumerate(photos_db.find()):
    if idx % 100000:
        continue
    # NOTE(review): pickle.loads runs arbitrary code embedded in the stored
    # bytes; only safe while the photo store is trusted.
    preview = pickle.loads(grid_out.read())
    plt.imshow(preview)
    plt.show()

In [None]:
# Grab a sample document: stop on index 10, leaving `item` bound to it (or to
# the last document if the collection is smaller). `item` is reused below.
for idx, item in enumerate(items_db.find()):
    if idx >= 10:
        break
item

In [None]:
def fill(dtype):
    """Return the placeholder text used to encode a missing value of ``dtype``.

    The caller encodes the result and converts it with
    ``np.array([...], dtype=dtype)``, so the placeholder must be representable
    in that dtype.

    Args:
        dtype: numpy dtype string such as ``'S100'``, ``'int32'`` or ``'float32'``.

    Returns:
        ``''`` for byte-string dtypes, the dtype's largest value (as a
        "missing" sentinel) for integer dtypes, ``'nan'`` for float dtypes.

    Raises:
        ValueError: if ``dtype`` is none of the above.
    """
    if 'S' in dtype:
        return ''
    if 'int' in dtype:
        # Use the largest representable value: any bigger literal overflows
        # the conversion to the target integer dtype.
        return str(np.iinfo(np.dtype(dtype)).max)
    if 'float' in dtype:
        return 'nan'
    raise ValueError('no fill value for dtype %r' % (dtype,))

In [None]:
def _encode_field(value, dtype):
    """Encode one metadata value as a (1, 1) array of ``dtype``; None if its type is unsupported."""
    if isinstance(value, str):
        payload = value.encode('latin')
    elif isinstance(value, list):
        payload = ','.join(str(el) for el in value)
    elif value is None or (isinstance(value, (float, np.floating)) and np.isnan(value)):
        # Missing value: store the per-dtype placeholder.
        payload = fill(dtype).encode('latin')
    elif isinstance(value, (int, float, np.number)):
        # A real number is stored as-is instead of being dropped.
        payload = value
    else:
        return None
    return np.array([payload], dtype=dtype)[:, None]


def _load_image(photos_db, photo_id, size):
    """Fetch, unpickle and normalise one photo to a 3-channel array; None if the id is not stored."""
    im_file = photos_db.find_one({'_id': photo_id})
    if im_file is None:
        return None
    # NOTE(review): pickle.loads runs arbitrary code embedded in the stored
    # bytes; only safe while the photo store is trusted.
    im = pickle.loads(im_file.read())
    if im.ndim == 2:
        # Grayscale: replicate the single channel into three.
        im = np.repeat(im[:, :, None], 3, axis=2)
    elif im.shape[2] == 1:
        im = np.repeat(im, 3, axis=2)
    if im.shape[2] > 3:
        # Keep only the first three channels (drops e.g. an alpha channel).
        im = im[:, :, :3]
    if size is not None:
        im = imresize(im, (size, size))
    return im


def get_features(item_keys, dtypes, item, photos_db, size=512):
    """Build one feature record per stored photo of an item document.

    Each field in ``item_keys`` is encoded with the matching dtype from
    ``dtypes`` as a (1, 1) numpy array. Every key of ``item`` containing
    ``'gridfs'`` is treated as a photo id in ``photos_db`` (the value ``404``
    means "no photo"). Each photo becomes its own record: a copy of the shared
    metadata plus ``'image'`` (batch of one image) and ``'pose'`` (the key).

    Args:
        item_keys: metadata field names to read from ``item``.
        dtypes: numpy dtype strings, parallel to ``item_keys``.
        item: the item document (a mapping).
        photos_db: store exposing ``find_one({'_id': ...})`` that returns an
            object with ``read()`` or None (e.g. the GridFS from get_photosdb).
        size: side length to resize images to, or None to keep their size.

    Returns:
        list of record dicts, one per photo found in ``photos_db``.

    Raises:
        KeyError: if a field of ``item_keys`` is absent from ``item``.
    """
    metadata = {}
    for key, dtype in zip(item_keys, dtypes):
        encoded = _encode_field(item[key], dtype)
        if encoded is None:
            # Unsupported value type: report it and leave the field out.
            print(item)
            print(key, item[key])
        else:
            metadata[key] = encoded

    records = []
    for key in item:
        if 'gridfs' not in key or item[key] == 404:
            continue
        im = _load_image(photos_db, item[key], size)
        if im is None:
            continue
        record = deepcopy(metadata)
        record['image'] = im[None, ...]
        record['pose'] = np.array([key], dtype='S40')[:, None]
        records.append(record)
    return records

In [None]:
from scipy.misc import imresize

In [None]:
# Metadata fields copied from each item document, each paired with the numpy
# dtype get_features stores it as. Keeping them paired prevents the two lists
# from drifting out of alignment.
_FEATURE_SPEC = [
    ('productID', 'int32'),
    ('description', 'S400'),
    ('name', 'S100'),
    ('brand', 'S100'),
    ('category', 'S100'),
    ('composition', 'S200'),
    ('department', 'S100'),
    ('gender', 'S30'),
    ('msrpUSD', 'float32'),
    ('season', 'S10'),
    ('subcategory', 'S100'),
    ('concat_description', 'S800'),
    ('matchedProductID', 'S100'),
]
item_keys = [field for field, _ in _FEATURE_SPEC]
dtypes = [dtype for _, dtype in _FEATURE_SPEC]

In [None]:
# Smoke test: build the per-photo records for the sample `item` bound by the
# sampling loop above, using get_features' default size of 512.
dict_test = get_features(item_keys, dtypes, item, photos_db)

In [None]:
# Display the records produced by the smoke test.
dict_test

In [None]:
def dump_object(item):
    """Write one feature record to disk as a PNG image plus a JSON metadata file.

    Takes a single tuple argument so it can be used directly with ``Pool.map``.
    Existing files are never overwritten, so an interrupted dump can be resumed.

    Args:
        item: a ``(record, directory)`` tuple. ``record`` is one dict produced
            by get_features (each value indexed as ``value[0][0]``; ``'image'``
            holds a batch containing one image).

    Returns:
        1, so callers can sum the results to count processed records.
    """
    record, directory = item
    product_id = str(record['productID'][0][0])
    # NOTE(review): str() of a bytes pose gives "b'...'"; stripping the quotes
    # leaves a leading "b" (e.g. "bgridfs_front"). Kept so names match files
    # written by earlier runs.
    pose = str(record['pose'][0][0]).replace("'", "")
    image_path = os.path.join(directory, '_'.join([product_id, pose])) + '.png'
    if not os.path.exists(image_path):
        imageio.imwrite(image_path, record['image'][0])

    # All poses of a product share the same metadata, so the JSON is keyed by
    # productID only and written by whichever pose arrives first.
    # NOTE(review): str() of byte-string fields yields "b'...'" text in the JSON.
    metadata = {k: str(v[0][0]) for k, v in record.items() if k != 'image'}
    json_path = os.path.join(directory, product_id) + '.json'
    if not os.path.exists(json_path):
        with open(json_path, 'w') as f:
            json.dump(metadata, f)
    return 1

In [None]:
from multiprocessing import Pool

def build_image_dump(photos, items_db, directory='/data/images_png_dump',
                     from_point=0, max_items=100, size=1024, nb_processes=10):
    """Dump the photos and metadata of a window of item documents to disk.

    Documents are visited in cursor order. Only those whose index ``i``
    satisfies ``from_point <= i <= max_items`` are processed.
    Each one is turned into per-photo records by get_features (using the
    module-level ``item_keys``/``dtypes``). The records are written by
    dump_object in a process pool, in batches of at least ``2 * nb_processes``.

    Args:
        photos: photo store passed to get_features (e.g. from get_photosdb).
        items_db: items collection (e.g. from get_itemsdb).
        directory: output directory for the PNG/JSON files.
        from_point: first document index to process.
        max_items: last document index to process (inclusive).
        size: image side length passed to get_features.
        nb_processes: number of worker processes.

    Returns:
        number of records written (sum of dump_object results).
    """
    nb_inserted = 0
    items_buffer = []
    cursor = items_db.find(no_cursor_timeout=True)
    try:
        with Pool(nb_processes) as pool:
            # NOTE(review): Collection.count() is deprecated in pymongo 3.7 and
            # removed in 4.0 (count_documents({}) replaces it).
            for i, el in tqdm(enumerate(cursor), total=items_db.count()):
                if i < from_point:
                    continue
                if i > max_items:
                    # Past the requested window: nothing left to do.
                    break
                try:
                    items = get_features(item_keys, dtypes, el, photos, size=size)
                except Exception as e:
                    # Report the failing document and skip it, so records of
                    # the previous document are not dumped a second time.
                    print(el)
                    print(e)
                    continue
                items_buffer.extend((item, directory) for item in items)
                if len(items_buffer) >= nb_processes * 2:
                    nb_inserted += sum(pool.map(dump_object, items_buffer))
                    items_buffer = []
                if i % 1000 == 0:
                    # Periodic visual check of what is being dumped.
                    for item in items:
                        print(item['name'])
                        plt.imshow(item['image'][0])
                        plt.show()
            # Flush the final, partially filled batch.
            if items_buffer:
                nb_inserted += sum(pool.map(dump_object, items_buffer))
    finally:
        # A no_cursor_timeout cursor is never reaped by the server, so close it.
        cursor.close()
    return nb_inserted

In [None]:
# Full run. max_items is an inclusive document index bound (i <= max_items), so
# this covers indices 0..10000000, using build_image_dump's default 1024 px
# size and output directory.
build_image_dump(photos_db, items_db, max_items=10000000, nb_processes=10)