In [1]:
import numpy as np
import pandas as pd
import heapq
import random
import os
import time
import tensorflow as tf
from tqdm import tqdm
from os import path
from scipy.spatial.distance import pdist, squareform
from sklearn.decomposition import PCA

In [2]:
# use a single GPU because we want to be nice with other people :)
os.environ["CUDA_VISIBLE_DEVICES"]="1"

###  Load pre-trained ResNet50 image embeddings

In [4]:
def load_embeddings_and_ids(dirpath, embedding_file, ids_file):
    embeddings = np.load(path.join(dirpath, embedding_file))
    with open(path.join(dirpath, ids_file)) as f:
        ids = [int(x) for x in f.readlines()]
        id2index = { _id:i for i,_id in enumerate(ids) }    
    assert (embeddings.shape[0] == len(ids))
    return embeddings, ids, id2index

In [5]:
resnet50_embeddings,\
artwork_ids,\
artwork_id2index = load_embeddings_and_ids(
'/mnt/workspace/Ugallery/ResNet50/', 'flatten_1.npy', 'ids')

In [7]:
n_artworks = len(artwork_ids)
n_artworks

13297

###  Load transactions

In [37]:
sales_df = pd.read_csv('./valid_sales.csv')
artworks_df = pd.read_csv('./valid_artworks.csv')

In [38]:
artist_ids = np.full((n_artworks,), -1, dtype=int)
for _artworkId, _artistId in zip(artworks_df.id, artworks_df.artist_id):
    i = artwork_id2index[_artworkId]
    artist_ids[i] = _artistId

In [39]:
artistId2artworkIndexes = dict()
for i, _artistId in enumerate(artist_ids):
    if _artistId == -1:
        continue
    try:
        artistId2artworkIndexes[_artistId].append(i)
    except KeyError:
        artistId2artworkIndexes[_artistId] = [i]

### Collect transactions per user (making sure we hide the last nonfirst purchase basket per user)

In [40]:
class User:
    def __init__(self, uid):
        self._uid = uid
        self.artwork_ids = []
        self.artwork_idxs = []
        self.artwork_idxs_set = set()
        self.timestamps = []
        self.artist_ids_set = set()
        
    def clear(self):
        self.artwork_ids.clear()
        self.artwork_idxs.clear()
        self.artwork_idxs_set.clear()        
        self.artist_ids_set.clear()
        self.timestamps.clear()
        
    def append_transaction(self, artwork_id, timestamp):
        aidx = artwork_id2index[artwork_id]
        self.artwork_ids.append(artwork_id)
        self.artwork_idxs.append(aidx)
        self.artwork_idxs_set.add(aidx)
        self.artist_ids_set.add(artist_ids[aidx])
        self.timestamps.append(timestamp)
    
    def remove_last_nonfirst_purchase_basket(self):
        baskets = self.baskets
        if len(baskets) >= 2:
            last_b = baskets.pop()            
            artwork_ids = self.artwork_ids[:last_b[0]]
            timestamps = self.timestamps[:last_b[0]]
            self.clear()
            for aid, t in zip(artwork_ids, timestamps):
                self.append_transaction(aid, t)
        
    def build_purchase_baskets(self):
        baskets = []
        prev_t = None
        offset = 0
        count = 0
        for i, t in enumerate(self.timestamps):
            if t != prev_t:
                if prev_t is not None:
                    baskets.append((offset, count))
                    offset = i
                count = 1
            else:
                count += 1
            prev_t = t
        baskets.append((offset, count))
        self.baskets = baskets
        
    def sanity_check_purchase_baskets(self):
        ids = self.artwork_ids
        ts = self.timestamps
        baskets = self.baskets        
        n = len(ts)
        assert(len(ids) == len(ts))
        assert(len(baskets) > 0)
        assert (n > 0)
        for b in baskets:
            for j in range(b[0], b[0] + b[1] - 1):
                assert(ts[j] == ts[j+1])
        for i in range(1, len(baskets)):
            b1 = baskets[i-1]
            b2 = baskets[i]
            assert(b1[0] + b1[1] == b2[0])
        assert(baskets[0][0] == 0)
        assert(baskets[-1][0] + baskets[-1][1] == n)

#### create list of users

In [41]:
user_ids = sales_df.customer_id.unique()
user_id2index = { _id:i for i,_id in enumerate(user_ids) }
users = [User(uid) for uid in user_ids]
n_users = len(user_ids)

#### collect and sanity check transactions per user

In [42]:
sorted_sales_df = sales_df.sort_values('order_date')

In [43]:
# clear structures to prevent possible duplicate elements
for user in users:
    user.clear()

# collect transactions per user sorted by timestamp
for uid, aid, t in zip(sorted_sales_df.customer_id,
                       sorted_sales_df.artwork_id,
                       sorted_sales_df.order_date):
    users[user_id2index[uid]].append_transaction(aid,t)
    assert users[user_id2index[uid]]._uid == uid
    
# bin transctions with same timestamps into purchase baskets
for user in users:
    user.build_purchase_baskets()
    user.sanity_check_purchase_baskets()
#     user.remove_last_nonfirst_purchase_basket()
#     user.sanity_check_purchase_baskets()

### Compute minimun cosine distance from each user profile to each item in the dataset
\* using R200 vectors obtained with PCA(200) over ResNet50 embeddings

In [44]:
resnet50_PCA200 = PCA(n_components=200).fit_transform(resnet50_embeddings)

In [45]:
resnet50_PCA200.shape

(13297, 200)

In [46]:
distmat = squareform(pdist(resnet50_PCA200, 'cosine'))

In [47]:
user2artwork_mindist = np.empty((n_users, n_artworks))

In [48]:
for ui in tqdm(range(n_users)):
    for ai in range(n_artworks):
        user2artwork_mindist[ui][ai] = min(distmat[ai][j] for j in users[ui].artwork_idxs)

100%|██████████| 2919/2919 [00:45<00:00, 64.73it/s]


### Generate training data

In [49]:
train_instances = []

In [50]:
test_instances = []

In [51]:
def sanity_check_instance(instance, pos_is_purchased=True):
    profile = instance[0]
    pos = instance[1]
    neg = instance[2]    
    assert neg != pos
    
    if instance[3] == -1:        
        return # fake user
    
    user = users[instance[3]]
    try:
        assert all(i in user.artwork_idxs_set for i in profile)
        if pos_is_purchased is not None:
            if pos_is_purchased:
                assert pos in user.artwork_idxs_set
            else:
                assert pos not in user.artwork_idxs_set
        assert neg not in user.artwork_idxs_set
        assert artist_ids[neg] not in user.artist_ids_set
    except AssertionError:
        print(t)
        print('user._uid = ', user._uid)
        print('user.artwork_idxs = ', user.artwork_idxs)
        raise

In [52]:
def append_instance(container, instance, **kwargs):
    sanity_check_instance(instance, **kwargs)
    container.append(instance)

##### 1) Given a list of purchased items, each purchased item should trivially be ranked higher than any item of non-purchased artists

In [53]:
def sample_artwork_index__notsharingartist(profile_artist_ids):
    while True:
        i = random.randint(0, n_artworks-1)
        if artist_ids[i] not in profile_artist_ids:
            return i
    
def sample_artwork_index__notsharingartist__notinprofile(profile_artist_ids, profile_artwork_idxs):
    while True:
        i = random.randint(0, n_artworks-1)
        if i not in profile_artwork_idxs and artist_ids[i] not in profile_artist_ids:
            return i

In [54]:
def generate_samples__rank_purchased_above_nonpurchased__real_users(n_test_samples=5000, n_reps=2):
    
    # --- train instances
    print('sampling train instances ....')
    for ui, user in enumerate(users):
        u_artwork_idxs = user.artwork_idxs
        u_artist_ids = user.artist_ids_set
        n = len(u_artwork_idxs)        
        
        for k in range(1,n+1):
            for _ in range(n_reps):
                sample_profile = random.sample(u_artwork_idxs, k)
                for pi in sample_profile:
                    ni = sample_artwork_index__notsharingartist(u_artist_ids)
                    append_instance(train_instances, (sample_profile, pi, ni, ui))
        
    # --- test instances
    print('sampling test instances ....')
    for _ in range(n_test_samples):
        ui = random.randint(0,n_users-1)
        user = users[ui]
        k = random.randint(1, len(user.artwork_idxs))
        sample_profile = random.sample(user.artwork_idxs, k)
        pi = random.choice(sample_profile)
        ni = sample_artwork_index__notsharingartist(user.artist_ids_set)
        append_instance(test_instances, (sample_profile, pi, ni, ui))        

In [55]:
all_artwork_indexes = list(range(n_artworks))

In [56]:
def generate_samples__rank_purchased_above_nonpurchased__fake_users(
        instances_container, n_samples=10000, profile_size=1):
    for _ in range(n_samples):
        profile = random.sample(all_artwork_indexes, profile_size)
        profile_idxs_set = set(profile)
        profile_artist_ids_set = set(artist_ids[i] for i in profile if artist_ids[i] != -1)
        pi = random.choice(profile)
        ni = sample_artwork_index__notsharingartist__notinprofile(profile_artist_ids_set, profile_idxs_set)
        append_instance(instances_container, (profile, pi, ni, -1))

In [57]:
generate_samples__rank_purchased_above_nonpurchased__real_users(n_test_samples=10000, n_reps=3)
len(train_instances), len(test_instances)

sampling train instances ....
sampling test instances ....


(187680, 10000)

In [58]:
for x in range(1,6):
    generate_samples__rank_purchased_above_nonpurchased__fake_users(
        train_instances, n_samples=35000, profile_size=x)
    generate_samples__rank_purchased_above_nonpurchased__fake_users(
        test_instances, n_samples=3000, profile_size=x)
len(train_instances), len(test_instances)

(362680, 25000)

##### 2) Given a list of purchased items, any non-purchased item sharing the same artist with a purchased item should be ranked higher than any item of a non-purchased artist as long as ResNet50 doesn't disagree by much

In [59]:
def sample_artwork_index__nonpurchased_sharingartist(artist_id, artwork_idxs_set):
    candidate_idxs = artistId2artworkIndexes[artist_id]
    for _ in range(10): # try at most 10 times
        i = random.choice(candidate_idxs)
        if i not in artwork_idxs_set:
            return i
    return None # failed to find

In [60]:
def reject_user_positive_negative_triplet(ui, pi, ni, threshold=0.55):
    dp = user2artwork_mindist[ui][pi]
    dn = user2artwork_mindist[ui][ni]
    assert dp + dn > 0
    return dp / (dp + dn) > threshold

In [61]:
def sample_artwork_index__notsharingartist_tripletacceptable(ui, pi, threshold):
    while True:
        ni = sample_artwork_index__notsharingartist(users[ui].artist_ids_set)
        if not reject_user_positive_negative_triplet(ui, pi, ni, threshold=threshold):
            return ni

In [62]:
def generate_samples__rank_purchased_artist_above_nonpurchased_artist(instances_container, n_samples):
    while n_samples > 0:
        ui = random.randint(0,n_users-1)
        user = users[ui]
        k = random.randint(1, len(user.artwork_idxs))
        sample_profile = random.sample(user.artwork_idxs, k)
        for _ in range(3):
            aid = artist_ids[random.choice(sample_profile)]
            assert aid != -1
            pi = sample_artwork_index__nonpurchased_sharingartist(aid, user.artwork_idxs_set)
            if pi is None:
                continue
            ni = sample_artwork_index__notsharingartist_tripletacceptable(ui, pi, 0.55)
            append_instance(instances_container, (sample_profile, pi, ni, ui), pos_is_purchased=False)
            n_samples -= 1
            break

In [63]:
print('sampling train instances ...')
generate_samples__rank_purchased_artist_above_nonpurchased_artist(train_instances, n_samples=150000)
print('sampling test instances ...')
generate_samples__rank_purchased_artist_above_nonpurchased_artist(test_instances, n_samples=10000)
len(train_instances), len(test_instances)

sampling train instances ...
sampling test instances ...


(512680, 35000)

##### 3) Given all previous purchases, rank each  item of the next purchase basket higher than any item of non-purchased artists as long as ResNet50 doesn't disagree by much

In [64]:
def generate_samples__given_past_rank_next(n_neg_per_pos=10):
    for ui, user in enumerate(users):
        past_items = []
        n = len(user.baskets)
        for i in range(n-1):
            cur_b = user.baskets[i]
            for j in range(cur_b[0], cur_b[0] + cur_b[1]):
                past_items.append(user.artwork_idxs[j])
            next_b  = user.baskets[i+1]
            profile = past_items.copy()
            for p in range(next_b[0], next_b[0] + next_b[1]):
                
                pi = user.artwork_idxs[p]
                
                # train instances
                for _ in range(n_neg_per_pos):
                    ni = sample_artwork_index__notsharingartist_tripletacceptable(ui, pi, 0.54)
                    append_instance(train_instances, (profile, pi, ni, ui))

                # test instance
                ni = sample_artwork_index__notsharingartist_tripletacceptable(ui, pi, 0.54)
                append_instance(test_instances, (profile, pi, ni, ui))

In [66]:
generate_samples__given_past_rank_next(n_neg_per_pos=35)
len(train_instances), len(test_instances)

(605920, 40328)

##### 4) Given only the present purchase basket, hide one and rank it higher than any item of non-purchased artists as long as ResNet50 doesn't disagree by much

In [67]:
def generate_samples__given_present_hide_rank_one(n_neg_per_pos=10):
    for ui, user in enumerate(users):
        for b in user.baskets:
            if b[1] < 2:
                continue
            bs = b[0]
            be = b[0] + b[1]
            for i in range(bs, be):            
                profile = [user.artwork_idxs[j] for j in range(bs, be) if j != i]
                assert len(profile) == be - bs - 1
                assert len(profile) > 0
                
                pi = user.artwork_idxs[i]
                
                # train instances
                for _ in range(n_neg_per_pos):
                    ni = sample_artwork_index__notsharingartist_tripletacceptable(ui, pi, 0.55)
                    append_instance(train_instances, (profile, pi, ni, ui))
                    
                # test instance
                ni = sample_artwork_index__notsharingartist_tripletacceptable(ui, pi, 0.55)
                append_instance(test_instances, (profile, pi, ni, ui))

In [68]:
generate_samples__given_present_hide_rank_one(n_neg_per_pos=35)
len(train_instances), len(test_instances)

(694190, 42850)

##### 5) Given the past and the present, hide one and rank it higher than any item of non-purchased artists as long as ResNet50 doesn't disagree by much

In [69]:
def generate_samples__given_past_present_hide_rank_one(n_neg_per_pos=3):
    for ui, user in enumerate(users):
        if (len(user.baskets) < 2):
            continue
        u_baskets = user.baskets
        u_artwork_idxs = user.artwork_idxs
        n_baskets = len(u_baskets)
        purchased = []
        for i in range(n_baskets):
            b = u_baskets[i]
            purchased.extend(u_artwork_idxs[j] for j in range(b[0], b[0] + b[1]))
            if i == 0:
                continue
            assert len(purchased) == b[0] + b[1]
            jmax = b[0] + (b[1] if b[1] >= 2 else 0)
            for j in range(jmax):
                
                profile = [x for k,x in enumerate(purchased) if k != j]
                
                pi = u_artwork_idxs[j]
                
                # train instances
                for _ in range(n_neg_per_pos):
                    ni = sample_artwork_index__notsharingartist_tripletacceptable(ui, pi, 0.55)
                    append_instance(train_instances, (profile, pi, ni, ui))
                    
                # test instance
                ni = sample_artwork_index__notsharingartist_tripletacceptable(ui, pi, 0.55)
                append_instance(test_instances, (profile, pi, ni, ui))

In [70]:
generate_samples__given_past_present_hide_rank_one(n_neg_per_pos=3)
len(train_instances), len(test_instances)

(816197, 83519)

#### sort train and test instances by profile size

In [71]:
train_instances.sort(key=lambda x: len(x[0]))
test_instances.sort(key=lambda x: len(x[0]))

### Build Tensorflow Network Graph

In [72]:
class Network:
    def __init__(self, learning_rate=1e-4):
        
        print('Network::__init__: learning_rate = ', learning_rate)
        
        # --- placeholders
        self._pretrained_embeddings = tf.placeholder(shape=[None, 2048], dtype=tf.float32,
                                                     name='pretrained_embeddings')            
        self._profile_item_indexes = tf.placeholder(shape=[None,None], dtype=tf.int32,
                                                    name='profile_item_indexes')
        self._profile_sizes = tf.placeholder(shape=[None], dtype=tf.float32,
                                                   name='profile_sizes')        
        self._positive_item_index = tf.placeholder(shape=[None], dtype=tf.int32,
                                                   name='positive_item_index')
        self._negative_item_index = tf.placeholder(shape=[None], dtype=tf.int32,
                                                   name='negative_item_index')
            
        # ---- user profile vector
        
        # profile item embeddings average
        tmp = tf.gather(self._pretrained_embeddings, self._profile_item_indexes)
        self._profile_item_embeddings = self.trainable_item_embedding(tmp)
        self._profile_masks = tf.expand_dims(tf.sequence_mask(self._profile_sizes, dtype=tf.float32), -1)
        self._masked_profile_item_embeddings = tf.multiply(self._profile_item_embeddings, self._profile_masks)        
        self._profile_items_average =\
            tf.reduce_sum(self._masked_profile_item_embeddings, axis=1) /\
            tf.reshape(self._profile_sizes, [-1, 1])
            
        # user hidden layer
        self._user_hidden = tf.layers.dense(
            inputs=self._profile_items_average,
            units=128,
            activation=tf.nn.selu,
            name='user_hidden'
        )
        
        # user final vector
        self._user_vector = tf.layers.dense(
            inputs=self._user_hidden,
            units=128,
            activation=tf.nn.selu,
            name='user_vector'
        )
        
        # ---- positive item vector
        tmp = tf.gather(self._pretrained_embeddings, self._positive_item_index)
        self._positive_item_vector = self.trainable_item_embedding(tmp)
        
        # ---- negative item vector
        tmp = tf.gather(self._pretrained_embeddings, self._negative_item_index)
        self._negative_item_vector = self.trainable_item_embedding(tmp)
        
        # --- train loss
        dot_pos = tf.reduce_sum(tf.multiply(self._user_vector, self._positive_item_vector), 1)
        dot_neg = tf.reduce_sum(tf.multiply(self._user_vector, self._negative_item_vector), 1)
        dot_delta = dot_pos - dot_neg
        ones = tf.fill(tf.shape(self._user_vector)[:1], 1.0)
        loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=dot_delta, labels=ones)
        loss = tf.reduce_mean(loss, name='train_loss')
        self._train_loss = loss
        
        # --- test accuracy
        accuracy = tf.reduce_sum(tf.cast(dot_delta > .0, tf.float32), name = 'test_accuracy')
        self._test_accuracy = accuracy
        
        # --- optimizer
        self._optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self._train_loss)
        
    @staticmethod
    def trainable_item_embedding(X):
        with tf.variable_scope("trainable_item_embedding", reuse=tf.AUTO_REUSE):
            fc1 = tf.layers.dense( # None -> 256
                inputs=X,
                units=256,
                activation=tf.nn.selu,
                name='fc1'
            )
            fc2 = tf.layers.dense( # 256 -> 128
                inputs=fc1,
                units=128,
                activation=tf.nn.selu,
                name='fc2'
            )
            return fc2
    
    def optimize_and_get_train_loss(self, sess, pretrained_embeddings, profile_item_indexes, profile_sizes,
             positive_item_index, negative_item_index):
        return sess.run([
            self._optimizer,
            self._train_loss,
        ], feed_dict={
            self._pretrained_embeddings: pretrained_embeddings,
            self._profile_item_indexes: profile_item_indexes,
            self._profile_sizes: profile_sizes,
            self._positive_item_index: positive_item_index,
            self._negative_item_index: negative_item_index,
        })
    
    def get_train_loss(self, sess, pretrained_embeddings, profile_item_indexes, profile_sizes,
             positive_item_index, negative_item_index):
        return sess.run(
            self._train_loss, feed_dict={
            self._pretrained_embeddings: pretrained_embeddings,
            self._profile_item_indexes: profile_item_indexes,
            self._profile_sizes: profile_sizes,
            self._positive_item_index: positive_item_index,
            self._negative_item_index: negative_item_index,
        })
    
    def get_test_accuracy(self, sess, pretrained_embeddings, profile_item_indexes, profile_sizes,
             positive_item_index, negative_item_index):
        return sess.run(
            self._test_accuracy, feed_dict={
            self._pretrained_embeddings: pretrained_embeddings,
            self._profile_item_indexes: profile_item_indexes,
            self._profile_sizes: profile_sizes,
            self._positive_item_index: positive_item_index,
            self._negative_item_index: negative_item_index,
        })

In [30]:
# DEBUGGING 
# with tf.Graph().as_default():
#     network = Network()
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         tmp_debug = sess.run([
#             network._profile_item_embeddings,
#             network._profile_masks,
#             network._masked_profile_item_embeddings,
#             network._profile_items_average,
#         ], feed_dict={
#             network._pretrained_embeddings: resnet50_embeddings,
#             network._profile_item_indexes: [[0, 1], [2, 3]],
#             network._profile_sizes: [1, 2],
#         })

Network::__init__: learning_rate =  0.0001


### Training Network

In [78]:
def generate_minibatches(tuples, batch_size):
    n_tuples = len(tuples)
    n_batches = (n_tuples // batch_size) + int(n_tuples % batch_size > 0)
    
    print('n_tuples = ', n_tuples)
    print('n_batches = ', n_batches)
    
    profile_indexes_batches = [None] * n_batches
    profile_size_batches = [None] * n_batches
    positive_index_batches = [None] * n_batches
    negative_index_batches = [None] * n_batches
    
    for i in range(n_batches):
        jmin = i * batch_size
        jmax = min(jmin + batch_size, n_tuples)
        actual_batch_size = jmax - jmin
        
        profile_maxlen = max(len(tuples[j][0]) for j in range(jmin, jmax))
        profile_indexes_batch = np.full((actual_batch_size, profile_maxlen), 0, dtype=int)
        profile_size_batch = np.empty((actual_batch_size,))
        positive_index_batch = np.empty((actual_batch_size,), dtype=int)
        negative_index_batch = np.empty((actual_batch_size,), dtype=int)
        
        for j in range(actual_batch_size):            
            # profile indexes
            for k,v in enumerate(tuples[jmin+j][0]):
                profile_indexes_batch[j][k] = v
            # profile size
            profile_size_batch[j] = len(tuples[jmin+j][0])        
            # positive index
            positive_index_batch[j] = tuples[jmin+j][1]
            # negative index
            negative_index_batch[j] = tuples[jmin+j][2]
            
        profile_indexes_batches[i] = profile_indexes_batch
        profile_size_batches[i] = profile_size_batch
        positive_index_batches[i] = positive_index_batch
        negative_index_batches[i] = negative_index_batch
        
    return dict(
        profile_indexes_batches = profile_indexes_batches,
        profile_size_batches    = profile_size_batches,
        positive_index_batches  = positive_index_batches,
        negative_index_batches  = negative_index_batches,
        n_batches               = n_batches,
    )

In [79]:
def sanity_check_minibatches(minibatches):
    for profile_indexes, profile_size, positive_index, negative_index in zip(
        minibatches['profile_indexes_batches'],
        minibatches['profile_size_batches'],
        minibatches['positive_index_batches'],
        minibatches['negative_index_batches']
    ):
        n = profile_size.shape[0]
        assert n == profile_indexes.shape[0]
        assert n == positive_index.shape[0]
        assert n == negative_index.shape[0]
        
        for i in range(n):
            assert positive_index[i] != negative_index[i]
            psz = int(profile_size[i])
            m = profile_indexes[i].shape[0]
            assert psz <= m
            for j in range(psz, m):
                assert profile_indexes[i][j] == 0

In [80]:
MODEL_PATH = '/mnt/workspace/pamessina_models/ugallery/youtube_like/v1/'

In [76]:
def train_network(train_instances, test_instances, batch_size=64, max_epochs=60,
                  learning_rate=1e-4, early_stopping_epochs=4, session_config=None):
    
    train_minibatches = generate_minibatches(train_instances, batch_size)
    test_minibatches = generate_minibatches(test_instances, batch_size)    
    sanity_check_minibatches(train_minibatches)
    sanity_check_minibatches(test_minibatches)
    n_train_batches = train_minibatches['n_batches']
    n_test_batches = test_minibatches['n_batches']
    n_test_instances = len(test_instances)
    
    with tf.Graph().as_default():
        network = Network(learning_rate=learning_rate)
        with tf.Session(config=session_config) as sess:
            try:
                saver = tf.train.Saver()            
                saver.restore(sess, tf.train.latest_checkpoint(MODEL_PATH))
                print('model successfully restored from checkpoint!')
            except ValueError:
                print('no checkpoint found: initializing variables with random values')
                os.makedirs(MODEL_PATH, exist_ok=True)
                sess.run(tf.global_variables_initializer())

            # ========= BEFORE TRAINING ============
            
            initial_test_acc = 0.
            for profile_indexes, profile_size, positive_index, negative_index in zip(
                test_minibatches['profile_indexes_batches'],
                test_minibatches['profile_size_batches'],
                test_minibatches['positive_index_batches'],
                test_minibatches['negative_index_batches']
            ):
                minibatch_test_acc = network.get_test_accuracy(
                    sess, resnet50_embeddings, profile_indexes, profile_size, positive_index, negative_index)
                initial_test_acc += minibatch_test_acc
            initial_test_acc = (initial_test_acc / n_test_instances) * 100.

            print("Before training: test_accuracy = %f%%" % initial_test_acc)
            
            best_test_acc = initial_test_acc
            last_improvement_epoch = -1
            last_improvement_epoch_train_loss = None

            # ========= TRAINING ============
            
            print ('Starting training ...')

            for epoch in range(max_epochs):
                
                start_time = time.time()

                # --- training
                epoch_train_loss = 0.
                for profile_indexes, profile_size, positive_index, negative_index in zip(
                    train_minibatches['profile_indexes_batches'],
                    train_minibatches['profile_size_batches'],
                    train_minibatches['positive_index_batches'],
                    train_minibatches['negative_index_batches']
                ):
                    _, minibatch_train_loss = network.optimize_and_get_train_loss(
                        sess, resnet50_embeddings, profile_indexes, profile_size, positive_index, negative_index)                
                    epoch_train_loss += minibatch_train_loss
                epoch_train_loss /= n_train_batches

                # --- testing
                epoch_test_acc = 0.
                for profile_indexes, profile_size, positive_index, negative_index in zip(
                    test_minibatches['profile_indexes_batches'],
                    test_minibatches['profile_size_batches'],
                    test_minibatches['positive_index_batches'],
                    test_minibatches['negative_index_batches']
                ):
                    minibatch_test_acc = network.get_test_accuracy(
                        sess, resnet50_embeddings, profile_indexes, profile_size, positive_index, negative_index)
                    epoch_test_acc += minibatch_test_acc
                epoch_test_acc = (epoch_test_acc / n_test_instances) * 100.
                
                elapsed_seconds = time.time() - start_time
                
                # --- check for improvements and update best model if necessary
                print("epoch %d: train_loss = %f, test_accuracy = %f%%, elapsed_seconds = %f" % (
                        epoch, epoch_train_loss, epoch_test_acc, elapsed_seconds))                
                if (epoch_test_acc > best_test_acc) or (
                    epoch_test_acc == best_test_acc and (
                        last_improvement_epoch_train_loss is not None and\
                        epoch_train_loss < last_improvement_epoch_train_loss
                    )
                ):
                    saver = tf.train.Saver()
                    save_path = saver.save(sess, MODEL_PATH)
                    best_test_acc = epoch_test_acc
                    last_improvement_epoch = epoch
                    last_improvement_epoch_train_loss = epoch_train_loss
                    print("   ** improvement detected: model saved to path ", save_path)
                else:                    
                    if (epoch - last_improvement_epoch >= early_stopping_epochs):
                        print("   *** %d epochs with no improvements -> early stopping :(" % early_stopping_epochs)
                        return

In [84]:
train_network(train_instances, test_instances, batch_size=1500, max_epochs=60, learning_rate=1e-6, early_stopping_epochs=5)

n_tuples =  816197
n_batches =  545
n_tuples =  83519
n_batches =  56
Network::__init__: learning_rate =  1e-06
INFO:tensorflow:Restoring parameters from /mnt/workspace/pamessina_models/ugallery/youtube_like/v1/
model successfully restored from checkpoint!
Before training: test_accuracy = 99.261246%
Starting training ...
epoch 0: train_loss = 0.000027, test_accuracy = 99.263641%, elapsed_seconds = 75.409202
   ** improvement detected: model saved to path  /mnt/workspace/pamessina_models/ugallery/youtube_like/v1/
epoch 1: train_loss = 0.000026, test_accuracy = 99.267233%, elapsed_seconds = 75.264133
   ** improvement detected: model saved to path  /mnt/workspace/pamessina_models/ugallery/youtube_like/v1/
epoch 2: train_loss = 0.000025, test_accuracy = 99.264838%, elapsed_seconds = 75.024105
epoch 3: train_loss = 0.000024, test_accuracy = 99.262443%, elapsed_seconds = 75.121038
epoch 4: train_loss = 0.000023, test_accuracy = 99.263641%, elapsed_seconds = 75.021706
epoch 5: train_loss = 0