In [2]:
# autoreload mode 1: before each execution, reload only the modules registered with
# %aimport (utils, Networks in the next cell), so edits to them are picked up live.
%load_ext autoreload
%autoreload 1

In [3]:
%aimport utils, Networks

In [4]:
import numpy as np
import pandas as pd
import heapq
import random
import os
import time
from tqdm import tqdm
from os import path
from scipy.spatial.distance import pdist, squareform
from sklearn.decomposition import PCA
from math import ceil, floor
from utils import load_embeddings_and_ids, User

In [5]:
# use a single GPU because we want to be nice with other people :)
# Expose only GPU "1" to this process. This is set here, before tensorflow is imported
# further down in the notebook.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

###  Load pre-trained ResNet50 image embeddings

In [6]:
# Load the precomputed ResNet50 embeddings ('flatten_1.npy'), the artwork ids and an
# artwork-id -> row-index map via utils.load_embeddings_and_ids.
# NOTE(review): presumably one embedding row per artwork id, aligned with artwork_ids —
# confirm in utils.
resnet50_embeddings,\
artwork_ids,\
artwork_id2index = load_embeddings_and_ids(
'/mnt/workspace/Ugallery/ResNet50/', 'flatten_1.npy', 'ids')

In [7]:
# total number of artworks; every artwork index used below lies in [0, n_artworks)
n_artworks = len(artwork_ids)
n_artworks

13297

###  Load transactions

In [8]:
# Relative paths: the notebook must be run from the directory that contains these CSVs.
sales_df = pd.read_csv('./valid_sales.csv')
artworks_df = pd.read_csv('./valid_artworks.csv')

In [9]:
# Artist id per artwork row index; -1 marks artworks that do not appear in artworks_df.
artist_ids = np.full(n_artworks, -1, dtype=int)
for _artworkId, _artistId in zip(artworks_df['id'], artworks_df['artist_id']):
    artist_ids[artwork_id2index[_artworkId]] = _artistId

In [10]:
# Inverted index: artist id -> list of artwork row indexes (ascending), skipping the
# -1 "unknown artist" marker.
artistId2artworkIndexes = dict()
for i, _artistId in enumerate(artist_ids):
    if _artistId != -1:
        artistId2artworkIndexes.setdefault(_artistId, []).append(i)

### Collect transactions per user (note: hiding the last non-first purchase basket per user is currently disabled — see the commented-out lines in the cell below)

#### create list of users

In [11]:
# One User object per distinct customer, in order of first appearance in sales_df;
# user_id2index maps a customer id to its position in `users`.
user_ids = sales_df['customer_id'].unique()
user_id2index = {uid: idx for idx, uid in enumerate(user_ids)}
users = [User(uid) for uid in user_ids]
n_users = len(user_ids)

#### collect and sanity check transactions per user

In [12]:
# Ascending by order_date, so the next cell appends each user's transactions oldest-first.
# NOTE(review): order_date is not parsed as a datetime (plain read_csv), so this is likely a
# lexicographic sort — correct only for ISO-like date strings; confirm the column format.
sorted_sales_df = sales_df.sort_values('order_date')

In [13]:
# clear structures to prevent possible duplicate elements
# (makes this cell safe to re-run)
for user in users:
    user.clear()

# collect transactions per user sorted by timestamp
for uid, aid, t in zip(sorted_sales_df.customer_id,
                       sorted_sales_df.artwork_id,
                       sorted_sales_df.order_date):
    users[user_id2index[uid]].append_transaction(aid,t,artwork_id2index,artist_ids)
    assert users[user_id2index[uid]]._uid == uid
    
# bin transactions with same timestamps into purchase baskets
# NOTE(review): the step that hides the last non-first basket is commented out below, so
# no purchases are held out. The train and test instances sampled later both come from
# the same full purchase histories.
for user in users:
    user.build_purchase_baskets()
    user.sanity_check_purchase_baskets()
#     user.remove_last_nonfirst_purchase_basket(artwork_id2index, artist_ids)
#     user.sanity_check_purchase_baskets()

### Compute the minimum cosine distance from each user profile to each item in the dataset
- using 200-dimensional vectors ($\mathbb{R}^{200}$) obtained by applying PCA(200) to the ResNet50 embeddings

In [14]:
# Project the ResNet50 embeddings onto 200 principal components. random_state pins the
# (possibly randomized, depending on svd_solver='auto') SVD, so the projection, and
# everything derived from it (distmat, user2artwork_mindist, the sampled negatives), is
# reproducible across runs.
resnet50_PCA200 = PCA(n_components=200, random_state=0).fit_transform(resnet50_embeddings)

In [15]:
# expected: (n_artworks, 200)
resnet50_PCA200.shape

(13297, 200)

In [16]:
# Dense n_artworks x n_artworks matrix of pairwise cosine distances in PCA-200 space.
# A zero entry marks a visually identical pair; the samplers below reject such pairs.
distmat = squareform(pdist(resnet50_PCA200, 'cosine'))

In [17]:
# users x artworks matrix; np.empty leaves it uninitialised, and the next cell fills every row.
user2artwork_mindist = np.empty((n_users, n_artworks))

In [18]:
# Row ui: for every artwork, its minimum cosine distance to any artwork purchased by
# user ui. Gather the purchased columns of distmat and reduce over them.
for ui in tqdm(range(n_users)):
    user2artwork_mindist[ui] = distmat[:, users[ui].artwork_idxs].min(axis=1)

100%|██████████| 2919/2919 [00:45<00:00, 64.36it/s]


### Generate training data

In [19]:
def hash(profile, pi, ni, base=127, mod=1000000007):
    """Return an integer fingerprint of a (profile, pi, ni) triple for de-duplication.

    append_instance stores these fingerprints in the global `used_hashes` set, which is
    shared by the train and test containers, so the same triple is never sampled twice.
    NOTE: this definition shadows the builtin hash() for the rest of the notebook.

    The profile length is folded in first, and every index is shifted by +1 before being
    mixed in. This way artwork index 0 still changes the fingerprint. With a zero seed and
    unshifted values, a leading 0 in the profile was invisible, e.g. profile (0, 5) and
    profile (5,) produced the same fingerprint for the same (pi, ni). Distinct triples
    can still collide modulo `mod`, but only by coincidence.

    Args:
        profile: sequence (tuple or list) of artwork indexes; order matters.
        pi: positive artwork index.
        ni: negative artwork index.
        base, mod: polynomial rolling-hash parameters.

    Returns:
        int in [0, mod).
    """
    h = len(profile) % mod
    for x in (*profile, pi, ni):
        # int() keeps the arithmetic in Python ints when indexes come from numpy arrays
        h = (h * base + int(x) + 1) % mod
    return h

In [20]:
def sanity_check_instance(instance, pos_is_purchased=True, not_sharing_artist=True):
    """Validate one sampled instance (profile, pi, ni, ui); raise AssertionError if invalid.

    Always checked: pi and ni are valid artwork indexes, differ from each other, and are
    not visually identical (distmat[pi][ni] > 0). When ui != -1, it additionally checks that
    ui is a valid user index, that every profile item was purchased by that user, and that
    ni was not purchased. Depending on the flags:
      pos_is_purchased: True -> pi must be purchased; False -> pi must not be purchased;
                        None -> pi is not constrained.
      not_sharing_artist: if truthy, ni's artist must not be one of the user's artists.
    On failure the instance fields are printed and the AssertionError is re-raised.
    """
    profile, pi, ni, ui = instance
    try:
        for idx in (pi, ni):
            assert 0 <= idx < n_artworks
        assert pi != ni
        assert distmat[pi][ni] > 0
        if ui != -1:
            assert 0 <= ui < n_users
            user = users[ui]
            owned = user.artwork_idxs_set
            assert all(item in owned for item in profile)
            if pos_is_purchased is not None:
                assert (pi in owned) == bool(pos_is_purchased)
            assert ni not in owned
            if not_sharing_artist:
                assert artist_ids[ni] not in user.artist_ids_set
    except AssertionError:
        print('profile = ', profile)
        print('pi = ', pi)
        print('ni = ', ni)
        print('ui = ', ui)
        raise

In [74]:
def append_instance(container, instance, **kwargs):
    """Append instance = (profile, pi, ni, ui) to `container` unless it must be rejected.

    Returns False, without appending, when:
      * its (profile, pi, ni) fingerprint is already in the global `used_hashes`
        (counted in _hash_collisions), or
      * ni is at distance 0 from pi, or (for ui != -1) from one of user ui's purchases
        (counted in _visual_collisions).
    Otherwise the instance is checked with sanity_check_instance (kwargs are forwarded),
    appended, and its fingerprint recorded; returns True.
    """
    global _hash_collisions, _visual_collisions
    fingerprint = hash(*instance[:3])
    if fingerprint in used_hashes:
        _hash_collisions += 1
        return False
    pi, ni, ui = instance[1:4]
    identical_to_pos = distmat[pi][ni] == 0
    if identical_to_pos or (ui != -1 and user2artwork_mindist[ui][ni] == 0):
        _visual_collisions += 1
        return False
    sanity_check_instance(instance, **kwargs)
    container.append(instance)
    used_hashes.add(fingerprint)
    return True

In [76]:
# Reset the shared sampling state. used_hashes is shared by the train and test containers,
# so no (profile, pi, ni) triple is sampled into both sets; the collision counters are
# cumulative across all sampling cells below.
# The samplers below draw from the stdlib `random` module, which was never seeded. Seeding
# it here makes the sampled train/test instances reproducible on Restart & Run All.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
used_hashes = set()
_hash_collisions = 0
_visual_collisions = 0
train_instances = []
test_instances = []

##### 0) Given a profile of a single item, such item should be ranked higher than any other item

In [67]:
def sample_artwork_index__nonidentical(i):
    """Return a uniformly random artwork index j with distmat[i][j] > 0.

    Keeps drawing until such a j is found (no retry limit).
    """
    j = random.randint(0, n_artworks - 1)
    while not distmat[i][j] > 0:
        j = random.randint(0, n_artworks - 1)
    return j

In [68]:
def generate_samples__rank_single_item_above_anything_else(instances_container, n_samples_per_item):
    """For every artwork pi, add exactly n_samples_per_item instances ((pi,), pi, ni, -1).

    ni is a random artwork that is not visually identical to pi. Draws are repeated until
    append_instance accepts the requested number of instances for each pi.
    """
    for pi in range(n_artworks):
        profile = (pi,)
        accepted = 0
        while accepted < n_samples_per_item:
            ni = sample_artwork_index__nonidentical(pi)
            if append_instance(instances_container, (profile, pi, ni, -1)):
                accepted += 1

In [77]:
# Section 0: for every artwork, exactly 100 train / 4 test instances ((pi,), pi, ni).
print('sampling train instances ...')
generate_samples__rank_single_item_above_anything_else(train_instances, n_samples_per_item=100)
print('sampling test instances ...')
generate_samples__rank_single_item_above_anything_else(test_instances, n_samples_per_item=4)
# totals and collision counters are cumulative since the reset cell
print(len(train_instances), len(test_instances))
print('hash_collisions = ', _hash_collisions)
print('visual_collisions = ', _visual_collisions)

sampling train instances ...
sampling test instances ...
1329700 53188
hash_collisions =  5394
visual_collisions =  0


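##### 1) Given a profile of a single item, other items sharing the same artist should be ranked higher than items of different artists, as long as the ResNet50 (PCA-200) embeddings agree, i.e. the same-artist item is visually closer to the profile item than the other-artist item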

In [63]:
def sample_artwork_index__nonidentical_sharingartist(i):
    """Draw (at most 10 tries) another artwork by the same artist as i that is not identical to i.

    Returns the index j with distmat[i][j] > 0, or None if all 10 draws were rejected.
    Requires that artwork i has a known artist with at least 2 artworks.
    """
    aid = artist_ids[i]
    assert aid != -1
    same_artist = artistId2artworkIndexes[aid]
    assert len(same_artist) >= 2
    # lazy generator: stops drawing at the first acceptable candidate
    draws = (random.choice(same_artist) for _ in range(10))
    return next((j for j in draws if distmat[i][j] > 0), None)

In [50]:
def sample_artwork_index__notsharingartist_visuallyacceptable(i, pi):
    """Draw (at most 10 tries) a negative for profile (i,) given a same-artist positive pi.

    A candidate is accepted when its artist differs from i's artist and it is strictly
    farther from i than pi is (in distmat). Returns None after 10 rejected draws.
    """
    anchor_artist = artist_ids[i]
    assert anchor_artist != -1
    assert artist_ids[pi] == anchor_artist
    pos_dist = distmat[i][pi]
    for _attempt in range(10):
        candidate = random.randint(0, n_artworks - 1)
        if artist_ids[candidate] == anchor_artist:
            continue
        if pos_dist < distmat[i][candidate]:
            return candidate
    return None

In [54]:
# debug_i = 1500
# print(artist_ids[debug_i])
# debug_pi = sample_artwork_index__nonidentical_sharingartist(debug_i)
# debug_ni = sample_artwork_index__notsharingartist_visuallyacceptable(debug_i, debug_pi)
# artwork_ids[debug_i], artwork_ids[debug_pi], artwork_ids[debug_ni]

In [55]:
def generate_samples__rank_single_item_artist_above_other_artists(instances_container, n_samples_per_item):
    """For each artwork i whose artist has >= 2 artworks, add up to n_samples_per_item
    instances ((i,), pi, ni, -1).

    pi is another artwork by the same artist, and ni is an artwork by a different artist
    that is farther from i than pi. Each sample gets at most 5 attempts.
    """
    for i in range(n_artworks):
        aid = artist_ids[i]
        if aid == -1 or len(artistId2artworkIndexes[aid]) < 2:
            continue
        profile = (i,)
        for _ in range(n_samples_per_item):
            for __ in range(5):
                pi = sample_artwork_index__nonidentical_sharingartist(i)
                ni = None if pi is None else sample_artwork_index__notsharingartist_visuallyacceptable(i, pi)
                if ni is not None and append_instance(instances_container, (profile, pi, ni, -1)):
                    break

In [78]:
# Section 1: up to 100 train / 4 test same-artist-vs-other-artist instances per artwork.
print('sampling train instances ...')
generate_samples__rank_single_item_artist_above_other_artists(train_instances, n_samples_per_item=100)
print('sampling test instances ...')
generate_samples__rank_single_item_artist_above_other_artists(test_instances, n_samples_per_item=4)
# totals and collision counters are cumulative since the reset cell
print(len(train_instances), len(test_instances))
print('hash_collisions = ', _hash_collisions)
print('visual_collisions = ', _visual_collisions)

sampling train instances ...
sampling test instances ...
2088328 83534
hash_collisions =  12306
visual_collisions =  0


##### 2) Given a list of purchased items, each purchased item should trivially be ranked higher than any non-purchased item

In [70]:
def sample_artwork_index__nonpurchased(purchased_artwork_idxs):
    """Return a uniformly random artwork index not contained in purchased_artwork_idxs.

    Keeps drawing until one is found (no retry limit).
    """
    candidate = random.randint(0, n_artworks - 1)
    while candidate in purchased_artwork_idxs:
        candidate = random.randint(0, n_artworks - 1)
    return candidate

In [71]:
def generate_samples__rank_purchased_above_nonpurchased(instances_container, n_samples_per_user):
    """For each user, add up to n_samples_per_user instances (purchases, pi, ni, ui).

    The profile is the user's full purchase list, pi is a random purchased artwork, and ni
    is a random non-purchased artwork. Each sample gets at most 5 attempts. ni may share an
    artist with the user (not_sharing_artist=False in the sanity check).
    """
    for ui, user in enumerate(users):
        purchases = user.artwork_idxs
        purchases_set = user.artwork_idxs_set
        for _ in range(n_samples_per_user):
            attempts_left = 5
            while attempts_left > 0:
                attempts_left -= 1
                pi = random.choice(purchases)
                ni = sample_artwork_index__nonpurchased(purchases_set)
                if append_instance(instances_container, (purchases, pi, ni, ui), not_sharing_artist=False):
                    break

In [79]:
# Section 2: up to 600 train / 10 test purchased-vs-non-purchased instances per user.
print('sampling train instances ...')
generate_samples__rank_purchased_above_nonpurchased(train_instances, n_samples_per_user=600)
print('sampling test instances ...')
generate_samples__rank_purchased_above_nonpurchased(test_instances, n_samples_per_user=10)
# totals and collision counters are cumulative since the reset cell
print(len(train_instances), len(test_instances))
print('hash_collisions = ', _hash_collisions)
print('visual_collisions = ', _visual_collisions)

sampling train instances ...
sampling test instances ...
3839717 112721
hash_collisions =  66610
visual_collisions =  0


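##### 3) Given a user, any non-purchased item sharing the same artist with a purchased item should be ranked higher than any item of a non-purchased artist, as long as the ResNet50 (PCA-200) embeddings agree, i.e. the same-artist item is visually closer to the user's purchases than the other-artist item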

In [82]:
def sample_artwork_index__nonpurchased_sharingartist(artist_id, artwork_idxs_set):
    """Draw (at most 10 tries) an artwork by artist_id that is not in artwork_idxs_set.

    Returns the artwork index, or None if every draw hit an artwork in the set.
    """
    same_artist = artistId2artworkIndexes[artist_id]
    # lazy generator: stops drawing at the first acceptable candidate
    draws = (random.choice(same_artist) for _ in range(10))
    return next((c for c in draws if c not in artwork_idxs_set), None)

In [83]:
def sample_artwork_index__notsharingartist_visuallyacceptable_largeprofile(ui, pi):
    """Draw (at most 10 tries) a negative for user ui that is farther from the user's purchases than pi.

    A candidate is accepted when its artist is not among the user's artists and its
    user2artwork_mindist value is strictly larger than pi's. Returns None after 10
    rejected draws. Note that user2artwork_mindist is computed over the user's full
    purchase history, so a purchased pi has value 0 there.
    """
    user_artists = users[ui].artist_ids_set
    mindist_row = user2artwork_mindist[ui]
    pos_mindist = mindist_row[pi]
    for _attempt in range(10):
        candidate = random.randint(0, n_artworks - 1)
        if artist_ids[candidate] in user_artists:
            continue
        if pos_mindist < mindist_row[candidate]:
            return candidate
    return None

In [84]:
def generate_samples__rank_purchased_artist_above_nonpurchased_artist(instances_container, n_samples_per_user=500):
    """For each user, add up to n_samples_per_user instances (purchases, pi, ni, ui).

    pi is a non-purchased artwork by the artist of a random purchased artwork. ni is an
    artwork by none of the user's artists and farther from the user's purchases than pi.
    Each sample gets at most 5 attempts.
    """
    for ui, user in enumerate(users):
        profile = user.artwork_idxs
        profile_set = user.artwork_idxs_set
        for _ in range(n_samples_per_user):
            for __ in range(5):
                anchor_artist = artist_ids[random.choice(profile)]
                assert anchor_artist != -1
                pi = sample_artwork_index__nonpurchased_sharingartist(anchor_artist, profile_set)
                ni = None if pi is None else sample_artwork_index__notsharingartist_visuallyacceptable_largeprofile(ui, pi)
                if ni is not None and append_instance(instances_container, (profile, pi, ni, ui), pos_is_purchased=False):
                    break

In [85]:
# Section 3: up to 400 train / 10 test purchased-artist-vs-other-artist instances per user.
print('sampling train instances ...')
generate_samples__rank_purchased_artist_above_nonpurchased_artist(train_instances, n_samples_per_user=400)
print('sampling test instances ...')
generate_samples__rank_purchased_artist_above_nonpurchased_artist(test_instances, n_samples_per_user=10)
# totals and collision counters are cumulative since the reset cell
print(len(train_instances), len(test_instances))
print('hash_collisions = ', _hash_collisions)
print('visual_collisions = ', _visual_collisions)

sampling train instances ...
sampling test instances ...
4987538 141413
hash_collisions =  85437
visual_collisions =  0


##### 4) Given all previous purchases, rank each item of the next purchase basket higher than any item of non-purchased artists

In [86]:
def generate_samples__given_past_rank_next(instances_container, n_samples_per_user=600):
    """For each user with >= 2 baskets, rank items of the next basket given all earlier ones.

    For every basket transition b -> b+1, the profile is the items of baskets 0..b, pi is a
    random item of basket b+1, and ni comes from the large-profile negative sampler. The
    per-user budget is split evenly (rounded up) over the len(baskets)-1 transitions, and
    each sample gets at most 5 attempts.
    NOTE(review): the negative sampler compares against user2artwork_mindist, which is
    computed from the full purchase history, where a purchased pi has distance 0. Its
    visual condition therefore only excludes negatives identical to some purchase.
    """
    for ui, user in enumerate(users):
        baskets = user.baskets
        if len(baskets) <= 1:
            continue
        idxs = user.artwork_idxs
        per_basket = ceil(n_samples_per_user / (len(baskets) - 1))
        history = []
        for b in range(len(baskets) - 1):
            start, size = baskets[b][0], baskets[b][1]
            history.extend(idxs[j] for j in range(start, start + size))
            nxt = baskets[b + 1]
            lo, hi = nxt[0], nxt[0] + nxt[1] - 1
            profile = list(history)
            for _ in range(per_basket):
                for __ in range(5):
                    pi = idxs[random.randint(lo, hi)]
                    ni = sample_artwork_index__notsharingartist_visuallyacceptable_largeprofile(ui, pi)
                    if ni is not None and append_instance(instances_container, (profile, pi, ni, ui)):
                        break

In [87]:
# Section 4: per-user budgets of 1000 train / 60 test, split over basket transitions.
print('sampling train instances ...')
generate_samples__given_past_rank_next(train_instances, n_samples_per_user=1000)
print('sampling test instances ...')
generate_samples__given_past_rank_next(test_instances, n_samples_per_user=60)
# totals and collision counters are cumulative since the reset cell
print(len(train_instances), len(test_instances))
print('hash_collisions = ', _hash_collisions)
print('visual_collisions = ', _visual_collisions)

sampling train instances ...
sampling test instances ...
5715956 185475
hash_collisions =  118379
visual_collisions =  0


##### 5) Given only the present purchase basket, hide one and rank it higher than any item of non-purchased artists

In [89]:
def generate_samples__given_present_hide_rank_one(instances_container, n_samples_per_user=600):
    """For each multi-item basket, hide one item and rank it given the rest of that basket.

    Only baskets with >= 2 items are used. The per-user budget is split evenly (rounded
    up) over those baskets, and then over the items of each basket. For every hidden item
    pi, the profile is the other items of the same basket, and ni comes from the
    large-profile negative sampler. Each sample gets at most 5 attempts.
    """
    for ui, user in enumerate(users):
        multi_item = [b for b in user.baskets if b[1] >= 2]
        if not multi_item:
            continue
        per_basket = ceil(n_samples_per_user / len(multi_item))
        idxs = user.artwork_idxs
        for b in multi_item:
            bs, be = b[0], b[0] + b[1]
            per_item = ceil(per_basket / b[1])
            for hidden in range(bs, be):
                profile = [idxs[j] for j in range(bs, be) if j != hidden]
                assert len(profile) == be - bs - 1
                assert len(profile) > 0
                pi = idxs[hidden]
                for _ in range(per_item):
                    for __ in range(5):
                        ni = sample_artwork_index__notsharingartist_visuallyacceptable_largeprofile(ui, pi)
                        if ni is not None and append_instance(instances_container, (profile, pi, ni, ui)):
                            break

In [90]:
# Section 5: per-user budgets of 1500 train / 20 test, split over multi-item baskets.
print('sampling train instances ...')
generate_samples__given_present_hide_rank_one(train_instances, n_samples_per_user=1500)
print('sampling test instances ...')
generate_samples__given_present_hide_rank_one(test_instances, n_samples_per_user=20)
# totals and collision counters are cumulative since the reset cell
print(len(train_instances), len(test_instances))
print('hash_collisions = ', _hash_collisions)
print('visual_collisions = ', _visual_collisions)

sampling train instances ...
sampling test instances ...
6685275 199094
hash_collisions =  158381
visual_collisions =  0


##### 6) Given the past and the present, hide one and rank it higher than any item of non-purchased artists

In [91]:
def generate_samples__given_past_present_hide_rank_one(instances_container, n_samples_per_user=600):
    """For each basket after the first, hide one item of the history so far and rank it.

    At basket bi >= 1 the candidates are all items of baskets 0..bi. If basket bi has a
    single item, that item is not hidden, so only past items are. For each hidden item pi,
    the profile is every other purchased item so far, and ni comes from the large-profile
    negative sampler. The per-user budget is split evenly (rounded up) over the
    len(baskets)-1 baskets, then over the hideable items. Each sample gets at most 5 attempts.
    """
    for ui, user in enumerate(users):
        baskets = user.baskets
        if len(baskets) < 2:
            continue
        idxs = user.artwork_idxs
        per_basket = ceil(n_samples_per_user / (len(baskets) - 1))
        purchased = []
        for bi, basket in enumerate(baskets):
            start, size = basket[0], basket[1]
            purchased.extend(idxs[j] for j in range(start, start + size))
            if bi == 0:
                continue
            assert len(purchased) == start + size
            n_hideable = start + size if size >= 2 else start
            assert n_hideable > 0
            per_item = ceil(per_basket / n_hideable)
            for hidden in range(n_hideable):
                profile = purchased[:hidden] + purchased[hidden + 1:]
                pi = idxs[hidden]
                for _ in range(per_item):
                    for __ in range(5):
                        ni = sample_artwork_index__notsharingartist_visuallyacceptable_largeprofile(ui, pi)
                        if ni is not None and append_instance(instances_container, (profile, pi, ni, ui)):
                            break

In [92]:
# Section 6: per-user budgets of 1000 train / 20 test, split over baskets after the first.
print('sampling train instances ...')
generate_samples__given_past_present_hide_rank_one(train_instances, n_samples_per_user=1000)
print('sampling test instances ...')
generate_samples__given_past_present_hide_rank_one(test_instances, n_samples_per_user=20)
# totals and collision counters are cumulative since the reset cell
print(len(train_instances), len(test_instances))
print('hash_collisions = ', _hash_collisions)
print('visual_collisions = ', _visual_collisions)

sampling train instances ...
sampling test instances ...
7445804 251137
hash_collisions =  187506
visual_collisions =  0


#### sort train and test instances by profile size

In [93]:
# Stable in-place sort by profile length. generate_minibatches groups consecutive
# instances and zero-pads to the longest profile in each batch, so this keeps padding small.
# NOTE(review): train_network iterates the minibatches in this fixed order every epoch (no
# shuffling), always from the shortest to the longest profiles. Consider shuffling the
# minibatch order per epoch.
train_instances.sort(key=lambda x: len(x[0]))
test_instances.sort(key=lambda x: len(x[0]))

### Train Model

In [94]:
def generate_minibatches(tuples, batch_size):
    """Pack (profile, pi, ni, ...) tuples into consecutive, zero-padded numpy minibatches.

    Batch b holds tuples[b*batch_size : (b+1)*batch_size]; the last one may be smaller.
    Each profile row is right-padded with 0 up to the longest profile in its batch.

    Returns a dict with:
        profile_indexes_batches: list of int arrays, shape (batch, max_profile_len)
        profile_size_batches:    list of float arrays, shape (batch,) (true profile lengths)
        positive_index_batches:  list of int arrays, shape (batch,)
        negative_index_batches:  list of int arrays, shape (batch,)
        n_batches:               number of batches
    """
    n_tuples = len(tuples)
    n_batches = ceil(n_tuples / batch_size)

    print('n_tuples = ', n_tuples)
    print('n_batches = ', n_batches)

    profile_indexes_batches, profile_size_batches = [], []
    positive_index_batches, negative_index_batches = [], []

    for b in range(n_batches):
        start = b * batch_size
        chunk = [tuples[j] for j in range(start, min(start + batch_size, n_tuples))]

        maxlen = max(len(t[0]) for t in chunk)
        profiles = np.zeros((len(chunk), maxlen), dtype=int)
        for row, t in enumerate(chunk):
            profiles[row, :len(t[0])] = t[0]

        profile_indexes_batches.append(profiles)
        profile_size_batches.append(np.array([len(t[0]) for t in chunk], dtype=float))
        positive_index_batches.append(np.array([t[1] for t in chunk], dtype=int))
        negative_index_batches.append(np.array([t[2] for t in chunk], dtype=int))

    return dict(
        profile_indexes_batches = profile_indexes_batches,
        profile_size_batches    = profile_size_batches,
        positive_index_batches  = positive_index_batches,
        negative_index_batches  = negative_index_batches,
        n_batches               = n_batches,
    )

In [95]:
def sanity_check_minibatches(minibatches):
    """Assert that a dict produced by generate_minibatches is internally consistent.

    Checks that:
      * all four batch lists have n_batches (> 0) entries;
      * within each batch, all arrays have the same number of rows;
      * positive and negative indexes differ in every row;
      * every profile row is zero beyond its recorded profile size.
    Raises AssertionError on the first violation.
    """
    field_names = ('profile_indexes_batches', 'profile_size_batches',
                   'positive_index_batches', 'negative_index_batches')
    columns = [minibatches[name] for name in field_names]
    n_batches = minibatches['n_batches']
    for column in columns:
        assert n_batches == len(column)
    assert n_batches > 0

    for profile_indexes, profile_size, positive_index, negative_index in zip(*columns):
        n_rows = profile_size.shape[0]
        assert profile_indexes.shape[0] == n_rows
        assert positive_index.shape[0] == n_rows
        assert negative_index.shape[0] == n_rows

        for row in range(n_rows):
            assert positive_index[row] != negative_index[row]
            used = int(profile_size[row])
            padded = profile_indexes[row]
            width = padded.shape[0]
            assert used <= width
            assert all(padded[k] == 0 for k in range(used, width))

In [96]:
# Checkpoint location. train_network restores the latest checkpoint from here and also
# passes this string directly to saver.save as the save path (note the trailing slash).
MODEL_PATH = '/mnt/workspace/pamessina_models/ugallery/youtube_like/v5_usermodelbigger/'

In [98]:
import tensorflow as tf
from Networks import ContentBasedLearn2RankNetwork_Train, TrainLogger

In [99]:
def _evaluate_test_accuracy(sess, network, minibatches, n_instances):
    """Sum network.get_test_accuracy over all minibatches and divide by n_instances.

    NOTE(review): normalising the summed per-minibatch values by the instance count
    assumes get_test_accuracy returns a per-minibatch count of correctly ranked pairs
    (not a mean) — confirm in Networks.
    """
    total = 0.
    for profile_indexes, profile_size, positive_index, negative_index in zip(
        minibatches['profile_indexes_batches'],
        minibatches['profile_size_batches'],
        minibatches['positive_index_batches'],
        minibatches['negative_index_batches']
    ):
        total += network.get_test_accuracy(
            sess, resnet50_embeddings, profile_indexes, profile_size, positive_index, negative_index)
    return total / n_instances


def train_network(train_minibatches, test_minibatches,
                  n_train_instances, n_test_instances, batch_size,
                  max_seconds_training=3600,
                  min_seconds_to_check_improvement=60,
                  early_stopping_checks=4,
                  learning_rates=None):
    """Train (or resume) a ContentBasedLearn2RankNetwork_Train with checkpointing and early stopping.

    Behaviour:
      * Restore from the latest checkpoint under MODEL_PATH if there is one; otherwise
        initialise the variables randomly.
      * Cycle over the train minibatches in their given order.
      * Whenever at least `min_seconds_to_check_improvement` seconds of optimisation time
        have accumulated, measure test accuracy. Save the model if accuracy improved, or
        if it tied and the train-loss EMA is lower than at the last improvement.
      * After `early_stopping_checks` consecutive non-improving checks, switch to the next
        learning rate; when none are left, stop.
      * Stop once `max_seconds_training` seconds of optimisation time have been spent.
        Test-evaluation time is not counted, and the budget is checked after every
        minibatch.

    Args:
        train_minibatches, test_minibatches: dicts as returned by generate_minibatches.
        n_train_instances: only written to the train log.
        n_test_instances: normalises the test accuracy; also written to the train log.
        batch_size: only written to the train log.
        max_seconds_training: optimisation-time budget in seconds.
        min_seconds_to_check_improvement: optimisation seconds between test checks.
        early_stopping_checks: consecutive non-improving checks before lowering the
            learning rate or stopping.
        learning_rates: learning-rate schedule, used in order; defaults to [1e-3].
    """
    if learning_rates is None:
        learning_rates = [1e-3]
    n_train_batches = train_minibatches['n_batches']

    print('learning_rates = ', learning_rates)

    if n_train_batches == 0:
        # without this guard the outer while-loop would spin forever (no time accrues)
        print('no training minibatches -> nothing to train')
        return

    with tf.Graph().as_default():
        network = ContentBasedLearn2RankNetwork_Train(user_model_mode='BIGGER')
        with tf.Session() as sess:
            # A single Saver is reused for restoring and for every save. Constructing
            # tf.train.Saver() again on each improvement would add new save/restore ops
            # to the graph every time.
            saver = tf.train.Saver()
            try:
                saver.restore(sess, tf.train.latest_checkpoint(MODEL_PATH))
                print('model successfully restored from checkpoint!')
            except ValueError:
                print('no checkpoint found: initializing variables with random values')
                os.makedirs(MODEL_PATH, exist_ok=True)
                sess.run(tf.global_variables_initializer())
            trainlogger = TrainLogger(MODEL_PATH + 'train_logs.csv')

            # ========= BEFORE TRAINING ============

            initial_test_acc = _evaluate_test_accuracy(sess, network, test_minibatches, n_test_instances)
            print("Before training: test_accuracy = %f" % initial_test_acc)

            best_test_acc = initial_test_acc
            seconds_training = 0
            elapsed_seconds_from_last_check = 0
            checks_with_no_improvement = 0
            last_improvement_loss = None

            # ========= TRAINING ============

            print ('Starting training ...')
            lr_i = 0
            train_loss_ema = 0. # exponential moving average (starts at 0, not bias-corrected)

            while seconds_training < max_seconds_training:

                for train_i, (profile_indexes, profile_size, positive_index, negative_index) in enumerate(zip(
                    train_minibatches['profile_indexes_batches'],
                    train_minibatches['profile_size_batches'],
                    train_minibatches['positive_index_batches'],
                    train_minibatches['negative_index_batches']
                )):
                    # optimize and get training loss; only this call is timed
                    start_t = time.time()
                    _, minibatch_train_loss = network.optimize_and_get_train_loss(
                        sess, learning_rates[lr_i], resnet50_embeddings, profile_indexes,
                        profile_size, positive_index, negative_index)
                    delta_t = time.time() - start_t

                    train_loss_ema = 0.999 * train_loss_ema + 0.001 * minibatch_train_loss
                    seconds_training += delta_t
                    elapsed_seconds_from_last_check += delta_t

                    # check for improvements using the test set if it's time to do so
                    if elapsed_seconds_from_last_check >= min_seconds_to_check_improvement:

                        test_acc = _evaluate_test_accuracy(sess, network, test_minibatches, n_test_instances)

                        print(("train_i=%d, train_loss = %.12f, test_accuracy = %.5f,"
                               " check_secs = %.2f, total_secs = %.2f") % (
                                train_i, train_loss_ema, test_acc, elapsed_seconds_from_last_check, seconds_training))

                        # a tie in accuracy counts as an improvement if the loss EMA dropped
                        improved = (test_acc > best_test_acc) or (
                            test_acc == best_test_acc
                            and last_improvement_loss is not None
                            and last_improvement_loss > train_loss_ema
                        )
                        if improved:
                            last_improvement_loss = train_loss_ema
                            best_test_acc = test_acc
                            checks_with_no_improvement = 0
                            save_path = saver.save(sess, MODEL_PATH)
                            print("   ** improvement detected: model saved to path ", save_path)
                        else:
                            checks_with_no_improvement += 1

                        trainlogger.log_update(
                            train_loss_ema, test_acc, n_train_instances, n_test_instances,
                            elapsed_seconds_from_last_check, batch_size, learning_rates[lr_i], 't' if improved else 'f')

                        # --- learning-rate decay / early stopping
                        if checks_with_no_improvement >= early_stopping_checks:
                            if lr_i + 1 < len(learning_rates):
                                lr_i += 1
                                checks_with_no_improvement = 0
                                print("   *** %d checks with no improvements -> using a smaller learning_rate = %f" % (
                                    early_stopping_checks, learning_rates[lr_i]))
                            else:
                                print("   *** %d checks with no improvements -> early stopping :(" % early_stopping_checks)
                                return

                        elapsed_seconds_from_last_check = 0

                    # Enforce the time budget after every minibatch, not only between full
                    # passes over the training set, so it is not overshot by up to one epoch.
                    if seconds_training >= max_seconds_training:
                        break

            print("   *** max_seconds_training = %s reached -> stopping" % max_seconds_training)

In [100]:
# Pack the (profile-size-sorted) train instances into zero-padded minibatches and verify them.
train_batch_size = 2048
train_minibatches = generate_minibatches(train_instances, train_batch_size)
sanity_check_minibatches(train_minibatches)

n_tuples =  7445804
n_batches =  3636


In [101]:
# Same packing for the (profile-size-sorted) test instances.
test_batch_size = 2048
test_minibatches = generate_minibatches(test_instances, test_batch_size)
sanity_check_minibatches(test_minibatches)

n_tuples =  251137
n_batches =  123


In [102]:
# Train for up to 4h of optimisation time (restoring from MODEL_PATH if a checkpoint
# exists). The test set is checked every 180 optimisation seconds, and the learning rate
# steps down the schedule after 4 non-improving checks.
train_network(
    train_minibatches, test_minibatches,
    len(train_instances), len(test_instances), train_batch_size,
    max_seconds_training=3600 * 4,
    min_seconds_to_check_improvement=180,
    early_stopping_checks=4,
    learning_rates=[1e-4, 3.33e-5, 1e-5, 3.33e-6, 1e-6])

learning_rates =  [0.0001, 3.33e-05, 1e-05, 3.33e-06, 1e-06]
INFO:tensorflow:Restoring parameters from /mnt/workspace/pamessina_models/ugallery/youtube_like/v5_usermodelbigger/
model successfully restored from checkpoint!
Before training: test_accuracy = 0.998129
Starting training ...
train_i=1915, train_loss = 0.005370066157, test_accuracy = 0.99640, check_secs = 180.01, total_secs = 180.01
train_i=138, train_loss = 0.031004245594, test_accuracy = 0.99750, check_secs = 180.07, total_secs = 360.08
train_i=2057, train_loss = 0.007360850512, test_accuracy = 0.99739, check_secs = 180.09, total_secs = 540.17
train_i=278, train_loss = 0.007656113307, test_accuracy = 0.99795, check_secs = 180.00, total_secs = 720.18
   *** 4 checks with no improvements -> using a smaller learning_rate = 0.000033
train_i=2197, train_loss = 0.002553315991, test_accuracy = 0.99845, check_secs = 180.03, total_secs = 900.21
   ** improvement detected: model saved to path  /mnt/workspace/pamessina_models/ugallery/