## imports

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
import itertools
import pandas as pd
import _pickle as cPickle
from pandas import DataFrame
from pandas import Series
import os

## function definitions

In [None]:
class Loader:
    """
    Common interface shared by the data loader wrappers.

    Concrete loaders subclass this and override ``load_data``; the base
    implementation only raises.
    """
    def load_data(self, file_name):
        """
        Read the given file into memory.
        :param file_name: name of the file to load
        :return: pandas object
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

In [None]:
import gzip
class JSONLoader(Loader):
    """
    Loader for gzip-compressed files that hold one record per line.

    Every line is evaluated as a Python expression (see ``_parse``) and the
    resulting records are collected into a pandas DataFrame.
    """
    def __init__(self, read_mode='rb'):
        """
        :param read_mode: mode string handed to ``gzip.open``
        """
        self.read_mode = read_mode

    def load_data(self, file_name):
        """
        :param file_name: complete path to open
        :return: pandas dataframe with one row per line of the file, indexed
                 0..n-1 in file order (empty dataframe for an empty file)
        """
        # Key each record by its line position so that from_dict(orient='index')
        # builds one row per record. Errors from opening or parsing the file
        # propagate to the caller unchanged.
        records = dict(enumerate(self._parse(file_name)))
        return pd.DataFrame.from_dict(records, orient='index')

    def _parse(self, file_name):
        """
        Lazily yield one evaluated record per line of the gzip file.
        :param file_name: complete path to open
        """
        # The context manager closes the gzip handle once the generator is
        # exhausted or discarded, so the file descriptor is no longer leaked.
        with gzip.open(file_name, self.read_mode) as g:
            for l in g:
                # SECURITY NOTE(review): eval() executes arbitrary code found in
                # the file. Only load files from a trusted source; if the lines
                # are plain literals, ast.literal_eval (or json.loads for strict
                # JSON) would be the safe replacement -- confirm the data format
                # before switching.
                yield eval(l)

In [None]:
def flatten(l):
    """
    Concatenate the items of every sub-iterable of ``l`` into one flat list.
    Only one level of nesting is removed.
    :param l: iterable of iterables
    :return: new list with the items in their original order
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

In [None]:
def get_cat(l):
    """
    Pick the second-to-last element of ``l``.
    :param l: indexable sequence with at least two elements
    :return: the element at position -2
    :raises IndexError: if ``l`` has fewer than two elements
    """
    second_to_last = l[-2]
    return second_to_last

## data loading

In [None]:
# product_path is the path of the gzip-compressed (.gz) file that holds the records to load.
product_path = '/mnt/share/datasets/product-classification/meta_Electronics.json.gz'
# alternative location of the same file on a Windows machine:
#product_path = 'D:\\TUM\\courses\\sem_3\\practical DM\\datasets\\meta_Electronics.json.gz'
loader = JSONLoader()
# parse every line of the archive and collect the records into one DataFrame
product = loader.load_data(product_path)

In [None]:
# preview the first rows of the loaded DataFrame
product.head()

In [None]:
# Reduce every row's nested category lists to a single label: flatten the
# nesting first, then keep the second-to-last entry of the flattened list.
product['categories'] = product['categories'].apply(flatten).apply(get_cat)

# Note kept from an earlier experiment: building a plain asin -> category dict
# ran in sub-second time.
#sin_cat_dict = Series(product.categories.values,index=product.asin).to_dict()

In [None]:
# count how many rows carry each category label
product.categories.value_counts()

## creating dataset

In [None]:
# change the threshold to experiment
# A category is kept only if at least `threshold` rows carry its label.
threshold = 7000
percent = []   # share (in %) of all rows covered by each kept category
cats = []      # labels of the kept categories
counter = 0    # number of kept categories
counts = product.categories.value_counts()
total_rows = product.shape[0]
# Series.items() is used instead of Series.iteritems(): iteritems() was
# deprecated in pandas 1.5 and removed in pandas 2.0, where it raises
# AttributeError. items() yields the same (label, count) pairs.
for key, val in counts.items():
    if val < threshold:
        continue
    counter += 1
    percent.append((val / total_rows) * 100)
    cats.append(key)

In [None]:
# Keep only the rows whose category label is one of the selected `cats`,
# then display how many rows remain per category.
product_cat_subset = product.loc[product['categories'].isin(cats)]
product_cat_subset['categories'].value_counts()

In [None]:
# number of categories that were kept, i.e. entries in `cats`
len(cats)

In [None]:
# percent of data we are using: the per-category shares in `percent` added up
sum(percent)

## Downloading

In [None]:
# Set project_path to the root of the repository. Make sure that the
# 'datasets' folder is listed in .gitignore so downloaded files are not committed.
project_path = '.'
# folder below the project root that receives one sub-folder per category
datasets_path = os.path.join(project_path, 'datasets')

In [None]:
# Create the datasets folder (and any missing parents). exist_ok=True makes
# this a no-op when the folder is already there and avoids the gap between a
# separate existence check and the creation call.
os.makedirs(datasets_path, exist_ok=True)

In [None]:
import wget
import random
import time

# switch off the download_cutoff_activate if you want to download all images in the category.
download_cutoff_activate = True

# change the download cutoff if required, minimum it should be 7k
download_cutoff = 10000

# Resolve the datasets folder to an absolute path once. The previous version
# called os.chdir() into every category folder; with a relative datasets_path
# that made every category after the first resolve *inside* the previous
# category's folder. Images are now written via wget's `out` argument and the
# working directory is left untouched.
datasets_root = os.path.abspath(datasets_path)

for cat in cats:
    product_cat_subset_subset = product_cat_subset[product_cat_subset.categories == cat]
    cat_path = os.path.join(datasets_root, cat)
    os.makedirs(cat_path, exist_ok=True)

    # De-duplicate the urls and drop missing values (NaN): a NaN can never be
    # downloaded and is not reliably collapsed by set()-based de-duplication.
    imurls = product_cat_subset_subset.imUrl.dropna().unique().tolist()
    if download_cutoff_activate and len(imurls) > download_cutoff:
        # randomly sample `download_cutoff` urls from categories that contain more to reduce download time
        imurls = random.sample(imurls, download_cutoff)

        # sanity check that sampling worked correctly
        assert len(imurls) == download_cutoff

    print('number of urls to be downloaded for category: ' + cat + ' is: ' + str(len(imurls)))
    failed = 0
    for idx, url in enumerate(imurls):
        # logging (idx counts attempted urls, including failed ones)
        if idx % 1000 == 0:
            print('images downloaded: ' + str(idx))

        try:
            # download the image using wget into cat_path
            wget.download(url, out=cat_path)
        except Exception:
            # Best effort: dead or malformed urls are expected, so keep going,
            # but count the failures instead of hiding them completely.
            failed += 1

    if failed:
        print('failed downloads for category: ' + cat + ' is: ' + str(failed))