In [None]:
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from collections import namedtuple
from fuel.datasets.hdf5 import H5PYDataset


# Number of examples requested per get_data() call in the batch loops below.
# (The "BATH" spelling is kept because the name is referenced throughout.)
BATH_SIZE = 100
# Substituted twice into DATA_TEMPLATE to pick the HDF5 file to open.
IMG_SIZE = 256

# Img resizing stuff
# Ratio between the big and the small resize targets computed below.
LR_HR_RATIO = 4
# With IMG_SIZE = 256 this evaluates to 304.
BIG_SIZE = int(IMG_SIZE * 76 / 64)
# With the values above this evaluates to 76.
SMALL_SIZE = int(BIG_SIZE / LR_HR_RATIO)

# HDF5 dataset path; the two %i placeholders are filled with IMG_SIZE.
DATA_TEMPLATE =             '/fashion/ssense_%i_%i.h5'
# Text file whose first space-separated field on every line is used as a vocabulary word.
LANGUAGE_MODEL_FILE =       '/models/glove/glove.6B.300d.txt'
# Not referenced by the code visible in this notebook.
LANGUAGE_MODEL_VOCABULARY = '/models/glove/glove.6B.vocab'

# Input lines for the fastText preprocessing cell.
FASTTEXT_DATA =             '/data/fashion/txt/fashion.dedup.txt'
# Output: de-duplicated, shuffled lines with the label kept.
FASTTEXT_DATA_TRAIN =       '/data/fashion/txt/fashion-train.txt'
# Output: every input line with its label removed.
FASTTEXT_DATA_CLEAN =       '/data/fashion/txt/fashion-clean.txt'

# Destination file of dump_for_emmbeddings.
OUTPUT =                    '/data/classes_and_texts.txt'

In [None]:
# Open the HDF5 file named by DATA_TEMPLATE (filled with IMG_SIZE twice) and
# select its 'all' split.
data_set = H5PYDataset(DATA_TEMPLATE % (IMG_SIZE, IMG_SIZE), which_sets=('all',))

In [None]:
def list_keys(data_set):
    '''
    Returns the short names of the dataset sources.

    The source names (keys of data_set.axis_labels) are visited in sorted
    order and, for each one, the second '_'-separated piece is kept.
    '''
    short_names = []
    for source_name in sorted(data_set.axis_labels.keys()):
        pieces = source_name.split('_')
        short_names.append(pieces[1])
    return short_names

In [None]:
# Show every source short name next to its positional index.
for position, source in enumerate(list_keys(data_set)):
    print("%d %s" % (position, source))

In [None]:
# Record type with one field per dataset source, in list_keys() order.
Item = namedtuple('Item', list_keys(data_set))

In [None]:
import string
from collections import Counter

# Characters removed from texts by default: all ASCII punctuation except '%',
# plus the newline character.
BLACK_LIST = string.punctuation.replace('%', '') + '\n'


def _delete_chars(text, chars):
    '''
    Returns `text` with every character listed in `chars` removed.
    '''
    try:
        # Python 2 byte strings: two-argument form, no translation table.
        return text.translate(None, chars)
    except TypeError:
        # unicode / Python 3 str only accept a code-point mapping.
        return text.translate(dict.fromkeys(map(ord, chars)))


def normalize(text_array,
              black_list = BLACK_LIST,
              vocab=None, lowercase =  True, tokenize = False):
    '''
    Cleans the text stored in the first element of `text_array`.

    text_array -- indexable whose element 0 is the raw text.
    black_list -- characters to delete; a falsy value skips the deletion.
    vocab      -- optional container of allowed words; when truthy, words
                  outside of it are dropped.
    lowercase  -- lowercase the text when true.
    tokenize   -- return a list of whitespace-separated words instead of
                  a single string.
    '''
    text = text_array[0]
    if black_list:
        # Use the argument, not the module constant, so callers can override it.
        text = _delete_chars(text, black_list)
    if lowercase:
        text = text.lower()
    if vocab:
        text = ' '.join([word for word in text.split() if word in vocab])
    if tokenize:
        return text.split()
    return text

def encode_fast_text(label, text):
    '''
    Builds one fastText training line: "__label__<label> <text>\\n".

    The label is lowercased and its spaces become dashes; the text is read
    from text[0][0] and its whitespace runs are collapsed to single spaces.
    NOTE(review): other cells read the text one level shallower
    (text_array[0]) -- confirm the nesting expected here.
    '''
    tag = "-".join(label.lower().split(" "))
    words = text[0][0].split()
    return "__label__%s %s\n" % (tag, ' '.join(words))

In [None]:
def dump_to_fastext_corpora(data_set, output, id2category,
                            batch_size = BATH_SIZE,
                            limit=None):
    '''
    Dumps the dataset to be consumed by fastText.

    One line per example is written to `output`; each line is produced by
    encode_fast_text from the category looked up in `id2category` and the
    example's text.

    data_set    -- dataset exposing num_examples, open(), get_data() and
                   close().
    output      -- path of the text file to (over)write.
    id2category -- mapping from the id in column 0 of the metadata batch
                   to a category label.
    batch_size  -- number of examples fetched per get_data() call.
    limit       -- maximum number of batches to dump; None dumps them all.

    NOTE(review): get_data() is unpacked into exactly three arrays here,
    while other cells of this notebook index rows[4] -- confirm which
    dataset layout this function is meant for.
    '''
    N = data_set.num_examples
    # Ceiling division so that a trailing partial batch is not dropped.
    num_batch = (N + batch_size - 1) // batch_size

    handle = data_set.open()
    try:
        with open(output, 'w') as f:
            processed = 0
            for i in itertools.islice(range(num_batch), limit):
                # fetch batch of data; the upper bound is clamped to the
                # dataset size only (it must not shrink as work progresses)
                low, high = i * batch_size, min((i + 1) * batch_size, N)
                text_batch, img_batch, metadata_batch = data_set.get_data(
                    handle, slice(low, high))

                # process batch
                lines = [encode_fast_text(id2category[id_], text) for
                         text, id_ in zip(text_batch, metadata_batch[:, 0])]

                # dump lines
                f.writelines(lines)

                # track progress
                processed += text_batch.shape[0]
                if i % 100 == 0:
                    print("Processing %i batch out of %i [%i processed]" % (i + 1, num_batch, processed))
    finally:
        # release the dataset handle even if a batch fails
        data_set.close(handle)

In [None]:
def dump_for_emmbeddings(
    data_set, output,
    batch_size = BATH_SIZE, limit=None, vocab=None):
    '''
    Dumps the hdf5 dataset to flat textfile

    One "<class> <text>" line is written per example: the class has its
    spaces replaced by underscores and the text goes through normalize().

    data_set   -- dataset exposing num_examples, open(), get_data() and
                  close(); element 1 of a fetched batch is read as the
                  classes and element 4 as the texts.
    output     -- path of the text file to (over)write.
    batch_size -- number of examples fetched per get_data() call.
    limit      -- maximum number of batches to dump; None dumps them all.
    vocab      -- forwarded to normalize(); None disables word filtering.
    '''
    N = data_set.num_examples
    # Ceiling division: covers a trailing partial batch without producing an
    # extra empty one when N is a multiple of batch_size.
    num_batch = (N + batch_size - 1) // batch_size

    handle = data_set.open()
    try:
        with open(output, 'w') as f:
            processed = 0
            for i in itertools.islice(range(num_batch), limit):
                # fetch batch of data
                low, high = i * batch_size, min((i + 1) * batch_size, N)
                rows = data_set.get_data(handle, slice(low, high))

                # process batch
                classes = [row[0].replace(" ", "_") for row in rows[1]]
                texts = [normalize(text, vocab=vocab) for text in rows[4]]
                lines = ["%s %s\n" % (c, t) for c, t in zip(classes, texts)]
                # dump lines
                f.writelines(lines)

                # progress is reported before the current batch is counted
                if i % 100 == 0:
                    print("Low: %d, high: %d" % (low, high))
                    print("Processing %i batch out of %i [%i processed]" % (i, num_batch, processed))

                # track progress
                processed += len(texts)
    finally:
        # release the dataset handle even if a batch fails
        data_set.close(handle)

In [None]:
import itertools
import math

# Word and class frequencies over the whole dataset.
tokens = Counter()
classes = Counter()

# The handle is left open on purpose: a later cell reuses it to fetch a sample.
handle = data_set.open()
# Ceiling division so that the final, possibly shorter, batch is counted too.
num_batch = (data_set.num_examples + BATH_SIZE - 1) // BATH_SIZE
for i in range(num_batch):
    if i % 100 == 0:
        percent = int(((100.0 * i )/ num_batch))
        print("Processing %i batch out of %i [%i percent]" % (i, num_batch, percent))
        print("Number of tokens in the dictionary: %i" % len(tokens))
        print("Number of classes in the dictionary: %i" % len(classes))

    rows = data_set.get_data(
        handle,
        slice(i*BATH_SIZE, min((i+1)*BATH_SIZE, data_set.num_examples))
    )

    # element 4 of a batch is read as the texts, element 1 as the classes
    tokens.update(itertools.chain.from_iterable(
        normalize(text, tokenize=True) for text in rows[4]
    ))
    classes.update(row[0] for row in rows[1])

In [None]:
# Sizes of the two counters filled by the counting loop.
distinct_words, distinct_classes = len(tokens), len(classes)
print("There are %s distinct words in the dataset" % distinct_words)
print("There are %s distinct classes in the dataset" % distinct_classes)

In [None]:
# The ten most frequent tokens together with their counts.
tokens.most_common(10)

In [None]:
# Bar chart of the class frequencies, most common class first.
labels, values = zip(*classes.most_common())

indexes = np.arange(len(labels))
width = 1

fig = plt.figure(figsize=(15, 5))
axes = fig.gca()
axes.bar(indexes, values, width)
# one tick per bar, shifted by half a bar width, labelled with the class name
axes.set_xticks(indexes + width * 0.5)
axes.set_xticklabels(labels, rotation=90)
plt.show()

# Are we missing something?

In [None]:
# The vocabulary is the first space-separated field of every line of the
# language-model file.
with open(LANGUAGE_MODEL_FILE, 'rt') as model_file:
    vocab = {entry.split(" ")[0] for entry in model_file}

In [None]:
# Words counted in `tokens` that the language-model vocabulary does not contain.
# (`tokens` is the word counter built by the counting loop; the previously
# referenced name `c` is not defined anywhere in this notebook.)
missing_words = Counter({word: tokens[word] for word in tokens if word.lower() not in vocab})
print("There are %i missing words out of %i" % (len(missing_words), len(tokens)))
print("Most common mising words")
missing_words.most_common(10)

# Sample data

In [None]:
import scipy.misc
import random

def clean_plot(img):
    '''
    Shows `img` with the tick marks of both axes removed.
    '''
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])
    plt.show()
    
def clean_plot_dpi(img, size, dpi=60):
    '''
    Draws `img` without ticks in a new square figure.

    The figure side, in inches, is size / dpi. The figure is not shown
    explicitly here.
    '''
    side_inches = float(size) / float(dpi)
    figure = plt.figure()
    figure.set_size_inches(side_inches, side_inches)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img)

In [None]:
# randrange excludes its upper bound, so `i` is always a valid example index
# (randint could return num_examples itself and yield an empty slice).
i = random.randrange(data_set.num_examples)
# fetch that single example, reusing the handle opened by the counting cell
item = Item._make(data_set.get_data(handle, slice(i, i + 1)))

In [None]:
# Show the sampled image under its category, next to three views of its text.
description = item.description[0]

plt.title(item.category[0][0])
plt.imshow(item.image[0]);
print("Raw text: \n%s\n" % description[0])

plain_tokens = normalize(description, tokenize=True)
print("Normalized text: \n%s\n" %  ' '.join(plain_tokens))

vocab_tokens = normalize(description, vocab=vocab, tokenize=True)
print("Normalized text with vocab: \n%s\n" %  ' '.join(vocab_tokens))
plt.show()

## To be consistent with StackGAN

In [None]:
# Bicubic resize of the sampled image to the two square target sizes.
# NOTE(review): scipy.misc.imresize is absent from recent SciPy releases --
# confirm the SciPy version this notebook is pinned to.
img = item.image[0]
small_shape = (SMALL_SIZE, SMALL_SIZE)
big_shape = (BIG_SIZE, BIG_SIZE)
img_76 = scipy.misc.imresize(img, small_shape, interp='bicubic')
img_304 = scipy.misc.imresize(img, big_shape, interp='bicubic')

In [None]:
# Draw both resized versions, each in a figure scaled to its pixel size.
clean_plot_dpi(img_76, SMALL_SIZE)
clean_plot_dpi(img_304, BIG_SIZE);

# Text preprocessing for fastText

In [None]:
from random import sample

def clean_line(text):
    '''
    Returns `text` lowercased, with every BLACK_LIST character removed.
    '''
    try:
        # Python 2 byte strings: two-argument form, no translation table.
        stripped = text.translate(None, BLACK_LIST)
    except TypeError:
        # unicode / Python 3 str only accept a code-point mapping.
        stripped = text.translate(dict.fromkeys(map(ord, BLACK_LIST)))
    return stripped.lower()

def process(line):
    '''
    Returns `line` with its label untouched and its text cleaned.

    The label is everything before the first space. Only the remainder is
    passed to clean_line, so the label is neither altered nor repeated
    inside the text. A newline is appended to the result.
    '''
    label, text = line.split(" ", 1)
    return "%s %s\n" % (label, clean_line(text))

def strip(line):
    '''
    Returns the cleaned text of `line` without its label.

    Everything up to the first space is discarded, the rest goes through
    clean_line and a newline is appended.
    '''
    label, text = line.split(" ", 1)
    cleaned = clean_line(text)
    return "%s\n" % cleaned

# Read the raw corpus once; `lines` is its de-duplicated version.
with open(FASTTEXT_DATA, 'rt') as f:
    all_lines = f.readlines()
lines = list(set(all_lines))

# Training file: the de-duplicated lines, processed and written in random order.
with open(FASTTEXT_DATA_TRAIN, 'wt') as f:
    processed_lines = [process(line) for line in lines]
    f.writelines(sample(processed_lines, len(lines)))

# Clean file: every original line (duplicates included) without its label.
with open(FASTTEXT_DATA_CLEAN, 'wt') as f:
    f.writelines(strip(line) for line in all_lines)

In [None]:
from tensorflow.python.client import device_lib

In [None]:
# Print whatever device_lib.list_local_devices() reports for this machine.
print(device_lib.list_local_devices())

# Dump data for text embedding with CNN

In [None]:
# Write one "<class> <normalized text>" line per example to OUTPUT;
# vocab=None means no vocabulary filtering is applied to the texts.
dump_for_emmbeddings(data_set, OUTPUT, batch_size = BATH_SIZE, vocab=None)