In [1]:
%load_ext autoreload
%autoreload 1

In [2]:
%aimport utils
%aimport Networks

In [3]:
import numpy as np
import pandas as pd
import random
import os
import time
import tensorflow as tf
from tqdm import tqdm
from os import path
from scipy.spatial.distance import pdist, squareform
from sklearn.decomposition import PCA
from math import ceil
from utils import load_embeddings_and_ids, User

In [4]:
# use a single GPU because we want to be nice with other people :)
os.environ["CUDA_VISIBLE_DEVICES"]="1"

###  Load pre-trained ResNet50 image embeddings

In [5]:
resnet50_embeddings,\
artwork_ids,\
artwork_id2index = load_embeddings_and_ids(
'/mnt/workspace/Ugallery/ResNet50/', 'flatten_1.npy', 'ids')

In [6]:
n_artworks = len(artwork_ids)
n_artworks

13297

###  Load transactions

In [7]:
sales_df = pd.read_csv('./valid_sales.csv')
artworks_df = pd.read_csv('./valid_artworks.csv')

In [8]:
artist_ids = np.full((n_artworks,), -1, dtype=int)
for _artworkId, _artistId in zip(artworks_df.id, artworks_df.artist_id):
    i = artwork_id2index[_artworkId]
    artist_ids[i] = _artistId

In [9]:
artistId2artworkIndexes = dict()
for i, _artistId in enumerate(artist_ids):
    if _artistId == -1:
        continue
    try:
        artistId2artworkIndexes[_artistId].append(i)
    except KeyError:
        artistId2artworkIndexes[_artistId] = [i]

### Collect transactions per user (making sure we hide the last nonfirst purchase basket per user)

#### create list of users

In [10]:
user_ids = sales_df.customer_id.unique()
user_id2index = { _id:i for i,_id in enumerate(user_ids) }
users = [User(uid) for uid in user_ids]
n_users = len(user_ids)

#### collect and sanity check transactions per user

In [11]:
sorted_sales_df = sales_df.sort_values('order_date')

In [12]:
# clear structures to prevent possible duplicate elements
for user in users:
    user.clear()

# collect transactions per user sorted by timestamp
for uid, aid, t in zip(sorted_sales_df.customer_id,
                       sorted_sales_df.artwork_id,
                       sorted_sales_df.order_date):
    users[user_id2index[uid]].append_transaction(aid,t,artwork_id2index,artist_ids)
    assert users[user_id2index[uid]]._uid == uid
    
# bin transctions with same timestamps into purchase baskets
for user in users:
    user.build_purchase_baskets()
    user.sanity_check_purchase_baskets()
    user.remove_last_nonfirst_purchase_basket(artwork_id2index, artist_ids)
    user.sanity_check_purchase_baskets()

### Compute minimun cosine distance from each user profile to each item in the dataset
\* using R200 vectors obtained with PCA(200) over ResNet50 embeddings

In [13]:
resnet50_PCA200 = PCA(n_components=200).fit_transform(resnet50_embeddings)

In [14]:
resnet50_PCA200.shape

(13297, 200)

In [15]:
distmat = squareform(pdist(resnet50_PCA200, 'cosine'))

In [16]:
user2artwork_mindist = np.empty((n_users, n_artworks))

In [17]:
for ui in tqdm(range(n_users)):
    for ai in range(n_artworks):
        user2artwork_mindist[ui][ai] = min(distmat[ai][j] for j in users[ui].artwork_idxs)

100%|██████████| 2919/2919 [00:41<00:00, 69.89it/s]


### Load images

### Generate training data

In [18]:
def hash(ui, pi, ni):
    return  ((pi * n_artworks) + ni) * n_users + ui

In [19]:
def sanity_check_instance(instance, pos_is_purchased=True, not_sharing_artist=False):
    ui, pi, ni = instance    
    try:
        assert 0 <= ui < n_users
        assert 0 <= pi < n_artworks
        assert 0 <= ni < n_artworks
        assert pi != ni
        assert user2artwork_mindist[ui][ni] > 0
        user = users[ui]
        if pos_is_purchased is True:
            assert pi in user.artwork_idxs_set
        else:
            assert pi not in user.artwork_idxs_set
        assert ni not in user.artwork_idxs_set
        if not_sharing_artist:
            assert artist_ids[ni] not in user.artist_ids_set
    except AssertionError:
        print('ui = ', ui)
        print('pi = ', pi)
        print('ni = ', ni)
        raise

In [20]:
def append_instance(container, instance, **kwargs):
    global _hash_collisions, _visual_collisions
    h = hash(*instance)
    if h in used_hashes:
        _hash_collisions += 1
        return False
    ui, ni = instance[0], instance[2]
    if user2artwork_mindist[ui][ni] == 0:
        _visual_collisions += 1
        return False
    sanity_check_instance(instance, **kwargs)
    container.append(instance)
    used_hashes.add(h)
    return True

In [21]:
used_hashes = set()
_hash_collisions = 0
_visual_collisions = 0
train_instances = []
test_instances = []

##### 1) Given a user, his purchased items should trivially be ranked higher than any of his non-purchased items

In [22]:
def sample_artwork_index__nonpurchased(purchased_artwork_idxs):
    while True:
        i = random.randint(0, n_artworks-1)
        if i not in purchased_artwork_idxs:
            return i

In [23]:
def generate_samples__rank_purchased_above_nonpurchased(instances_container, n_samples_per_user):    
    for ui, user in enumerate(users):
        profile = user.artwork_idxs
        profile_set = user.artwork_idxs_set
        for _ in range(n_samples_per_user):
            for __ in range(5):
                pi = random.choice(profile)
                ni = sample_artwork_index__nonpurchased(profile_set)
                if append_instance(instances_container, (ui, pi, ni)):
                    break

In [24]:
print('sampling train instances ...')
generate_samples__rank_purchased_above_nonpurchased(train_instances, n_samples_per_user=800)
print('sampling test instances ...')
generate_samples__rank_purchased_above_nonpurchased(test_instances, n_samples_per_user=60)
print(len(train_instances), len(test_instances))
print('hash_collisions = ', _hash_collisions)
print('visual_collisions = ', _visual_collisions)

sampling train instances ...
sampling test instances ...
2335200 175140
hash_collisions =  70254
visual_collisions =  3


##### 2) Given a user, any non-purchased item sharing the same artist with one of his purchased items should be ranked higher than any item of a non-purchased artist as long as ResNet50 doesn't disagree by much

In [25]:
def sample_artwork_index__notsharingartist(profile_artist_ids):
    while True:
        i = random.randint(0, n_artworks-1)
        if artist_ids[i] not in profile_artist_ids:
            return i

In [26]:
def sample_artwork_index__nonpurchased_sharingartist(artist_id, purchased_artwork_idxs):
    candidate_idxs = artistId2artworkIndexes[artist_id]
    for _ in range(10): # try at most 10 times
        i = random.choice(candidate_idxs)
        if i not in purchased_artwork_idxs:
            return i
    return None # failed to find

In [27]:
def reject_ui_pi_ni_triplet(ui, pi, ni, threshold=0.55):
    dp = user2artwork_mindist[ui][pi]
    dn = user2artwork_mindist[ui][ni]
    return (dp + dn) == 0 or dp / (dp + dn) > threshold

In [28]:
def sample_artwork_index__notsharingartist_tripletacceptable(ui, pi, threshold):
    while True:
        ni = sample_artwork_index__notsharingartist(users[ui].artist_ids_set)
        if not reject_ui_pi_ni_triplet(ui, pi, ni, threshold=threshold):
            return ni

In [29]:
def generate_samples__rank_purchased_artist_above_nonpurchased_artist(instances_container, n_samples_per_user):
    for ui, user in enumerate(users):
        profile = user.artwork_idxs
        profile_set = user.artwork_idxs_set
        for _ in range(n_samples_per_user):
            for __ in range(5):
                aid = artist_ids[random.choice(profile)]
                assert aid != -1
                pi = sample_artwork_index__nonpurchased_sharingartist(aid, profile_set)
                if pi is None:
                    continue
                ni = sample_artwork_index__notsharingartist_tripletacceptable(ui, pi, 0.55)
                if append_instance(instances_container, (ui, pi, ni), pos_is_purchased=False, not_sharing_artist=True):
                    break

In [30]:
print('sampling train instances ...')
generate_samples__rank_purchased_artist_above_nonpurchased_artist(train_instances, n_samples_per_user=800)
print('sampling test instances ...')
generate_samples__rank_purchased_artist_above_nonpurchased_artist(test_instances, n_samples_per_user=60)
print(len(train_instances), len(test_instances))
print('hash_collisions = ', _hash_collisions)
print('visual_collisions = ', _visual_collisions)

sampling train instances ...
sampling test instances ...
4627404 347053
hash_collisions =  78304
visual_collisions =  3


### Training Model

In [31]:
def generate_minibatches(tuples, batch_size):
    n_tuples = len(tuples)
    n_batches = ceil(n_tuples / batch_size)
    
    assert n_batches * batch_size >= n_tuples
    assert (n_batches - 1) * batch_size < n_tuples
    
    indexes = list(range(n_tuples))
    random.shuffle(indexes)
    
    print('n_tuples = ', n_tuples)
    print('n_batches = ', n_batches)
    
    user_index_batches = [None] * n_batches
    pos_index_batches = [None] * n_batches
    neg_index_batches = [None] * n_batches
    
    for i in range(n_batches):
        jmin = i * batch_size
        jmax = min(jmin + batch_size, n_tuples)
        actual_batch_size = jmax - jmin
        
        user_index_batch = np.empty((actual_batch_size,), dtype=int)
        pos_index_batch = np.empty((actual_batch_size,), dtype=int)
        neg_index_batch = np.empty((actual_batch_size,), dtype=int)
        
        for j in range(actual_batch_size):
            t = tuples[indexes[jmin+j]]
            user_index_batch[j] = t[0]
            pos_index_batch[j] = t[1]
            neg_index_batch[j] = t[2]

        user_index_batches[i] = user_index_batch
        pos_index_batches[i] = pos_index_batch
        neg_index_batches[i] = neg_index_batch
        
    return dict(
        user_index_batches = user_index_batches,
        pos_index_batches  = pos_index_batches,
        neg_index_batches  = neg_index_batches,
        n_batches               = n_batches,
    )

In [32]:
def sanity_check_minibatches(minibatches):
    user_index_batches = minibatches['user_index_batches']
    pos_index_batches = minibatches['pos_index_batches']
    neg_index_batches = minibatches['neg_index_batches']
    n_batches = minibatches['n_batches']
    assert n_batches == len(user_index_batches)
    assert n_batches == len(pos_index_batches)
    assert n_batches == len(neg_index_batches)
    assert n_batches > 0
    
    for user_index, pos_index, neg_index in zip(
        user_index_batches,
        pos_index_batches,
        neg_index_batches
    ):
        n = user_index.shape[0]
        assert n == pos_index.shape[0]
        assert n == neg_index.shape[0]
        
        for i in range(n):
            ui = user_index[i]
            pi = pos_index[i]
            ni = neg_index[i]
            assert pi != ni
            assert ni not in users[ui].artwork_idxs_set

In [33]:
MODEL_PATH = '/mnt/workspace/pamessina_models/ugallery/VBPR/v3_hidinglast/'

In [34]:
import tensorflow as tf
from Networks import VBPR_Network, TrainLogger

In [35]:
def train_network(train_minibatches, test_minibatches,
                  n_train_instances, n_test_instances, batch_size,
                  max_seconds_training=3600,
                  min_seconds_to_check_improvement=60,
                  early_stopping_checks=4,
                  learning_rates=[1e-3]):
    
    n_train_batches = train_minibatches['n_batches']
    
    print('learning_rates = ', learning_rates)
    
    with tf.Graph().as_default():
        network = VBPR_Network(
            n_users=n_users,
            n_items=n_artworks,
            user_latent_dim=128,
            item_latent_dim=64,
            item_visual_dim=64,
            pretrained_dim=2048,
        )
        with tf.Session() as sess:
            try:
                saver = tf.train.Saver()            
                saver.restore(sess, tf.train.latest_checkpoint(MODEL_PATH))
                print('model successfully restored from checkpoint!')
            except ValueError:
                print('no checkpoint found: initializing variables with random values')
                os.makedirs(MODEL_PATH, exist_ok=True)
                sess.run(tf.global_variables_initializer())            
            trainlogger = TrainLogger(MODEL_PATH + 'train_logs.csv')

            # ========= BEFORE TRAINING ============
            
            initial_test_acc = 0.            
            for user_index, pos_index, neg_index in zip(
                test_minibatches['user_index_batches'],
                test_minibatches['pos_index_batches'],
                test_minibatches['neg_index_batches']
            ):
                minibatch_test_acc = network.get_test_accuracy(
                    sess, resnet50_embeddings, user_index, pos_index, neg_index)
                initial_test_acc += minibatch_test_acc
            initial_test_acc /= n_test_instances

            print("Before training: test_accuracy = %f" % initial_test_acc)
            
            best_test_acc = initial_test_acc
            seconds_training = 0
            elapsed_seconds_from_last_check = 0
            checks_with_no_improvement = 0
            last_improvement_loss = None
            
            # ========= TRAINING ============
            
            print ('Starting training ...')
            n_lr = len(learning_rates)
            lr_i = 0
            train_loss_ema = None # exponential moving average
            
            while seconds_training < max_seconds_training:
                
                for train_i, (user_index, pos_index, neg_index) in enumerate(zip(
                    train_minibatches['user_index_batches'],
                    train_minibatches['pos_index_batches'],
                    train_minibatches['neg_index_batches']
                )):
                    # optimize and get traing loss
                    start_t = time.time()
                    _, minibatch_train_loss = network.optimize_and_get_train_loss(
                        sess, resnet50_embeddings, user_index, pos_index, neg_index, learning_rates[lr_i])
                    delta_t = time.time() - start_t
                    
                    # update train loss exponential moving average
                    train_loss_ema = minibatch_train_loss if train_loss_ema is None else\
                                    0.999 * train_loss_ema + 0.001 * minibatch_train_loss
                    
                    # update time tracking variables
                    seconds_training += delta_t
                    elapsed_seconds_from_last_check += delta_t
                    
                    # check for improvements using test set if it's time to do so
                    if elapsed_seconds_from_last_check >= min_seconds_to_check_improvement:
                        
                        # --- testing                        
                        test_acc = 0.
                        for _user_index, _pos_index, _neg_index in zip(
                            test_minibatches['user_index_batches'],
                            test_minibatches['pos_index_batches'],
                            test_minibatches['neg_index_batches']
                        ):
                            minibatch_test_acc = network.get_test_accuracy(
                                sess, resnet50_embeddings, _user_index, _pos_index, _neg_index)
                            test_acc += minibatch_test_acc
                        test_acc /= n_test_instances
                    
                        print(("train_i=%d, train_loss = %.12f, test_accuracy = %.5f,"
                               " check_secs = %.2f, total_secs = %.2f") % (
                                train_i, train_loss_ema, test_acc, elapsed_seconds_from_last_check, seconds_training))                        
                        
                        # check for improvements
                        if (test_acc > best_test_acc) or (
                            test_acc == best_test_acc and (
                                last_improvement_loss is not None and\
                                last_improvement_loss > train_loss_ema
                            )
                        ):  
                            last_improvement_loss = train_loss_ema
                            best_test_acc = test_acc
                            checks_with_no_improvement = 0
                            saver = tf.train.Saver()
                            save_path = saver.save(sess, MODEL_PATH)                    
                            print("   ** improvement detected: model saved to path ", save_path)
                            model_updated = True
                        else:
                            checks_with_no_improvement += 1                            
                            model_updated = False

                        # --- logging ---                        
                        trainlogger.log_update(
                            train_loss_ema, test_acc, n_train_instances, n_test_instances,
                            elapsed_seconds_from_last_check, batch_size, learning_rates[lr_i], 't' if model_updated else 'f')
                        
                        # --- check for early stopping
                        if checks_with_no_improvement >= early_stopping_checks:
                            if lr_i + 1 < len(learning_rates):
                                lr_i += 1
                                checks_with_no_improvement = 0
                                print("   *** %d checks with no improvements -> using a smaller learning_rate = %f" % (
                                    early_stopping_checks, learning_rates[lr_i]))
                            else:
                                print("   *** %d checks with no improvements -> early stopping :(" % early_stopping_checks)
                                return
                        
                        # --- reset check variables
                        elapsed_seconds_from_last_check = 0                        
            print('====== TIMEOUT ======')

In [36]:
train_batch_size = 120000
train_minibatches = generate_minibatches(train_instances, train_batch_size)
sanity_check_minibatches(train_minibatches)

n_tuples =  4627404
n_batches =  39


In [37]:
test_batch_size = 120000
test_minibatches = generate_minibatches(test_instances, test_batch_size)
sanity_check_minibatches(test_minibatches)

n_tuples =  347053
n_batches =  3


In [None]:
train_network(
    train_minibatches, test_minibatches,
    len(train_instances), len(test_instances), train_batch_size,
    max_seconds_training=3600 * 3,
    min_seconds_to_check_improvement=120,
    early_stopping_checks=3,
    learning_rates=[1e-4, 3.33e-5, 1e-5])