In [1]:
%load_ext autoreload
%autoreload 1

In [2]:
# hack based on https://stackoverflow.com/a/33532002
# Purpose: put the parent of this notebook's directory at the front of sys.path,
# presumably so that project modules living one level up (FeatureUtils, Classes, ...)
# can be imported by the cells below -- confirm against the repository layout.
from inspect import getsourcefile
import os.path as path, sys
# getsourcefile() on a throwaway lambda yields the "file" this code was defined in;
# abspath + dirname turn that into an absolute directory.
current_dir = path.dirname(path.abspath(getsourcefile(lambda:0)))
# Cut everything from the last path separator onwards, i.e. keep the parent directory,
# and give it top priority in the module search path.
sys.path.insert(0, current_dir[:current_dir.rfind(path.sep)])

In [3]:
%aimport FeatureUtils
%aimport ExperimentUtils
%aimport Networks

In [4]:
import os
import numpy as np
import FeatureUtils as featils
from Classes import Customer, ProfileBase
from ExperimentUtils import sanity_check_purchase_upload_events, recommendations_to_csv,\
        run_personalized_recommendation_experiment
from TransactionsUtils import TransactionsHandler

In [5]:
# Restrict CUDA to device 0. This is set before TensorFlow is imported (a later cell),
# so the restriction is already in the environment when TensorFlow starts.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [6]:
# Directory containing the exported VBPR artifacts loaded below.
# The trailing slash matters: file names are appended by plain string concatenation.
dirpath = '/mnt/workspace/Ugallery/VBPR/vbpr_resnet_10m/'
# Item-side parameters. The recorded outputs further down show shapes
# (13297, 200) for the vectors and (13297,) for the biases.
vbpr_item_vectors = np.load(dirpath + 'item_vectors.npy')
vbpr_item_biases = np.load(dirpath + 'item_biases.npy')
# read_ids_file's result is unpacked into an (index -> id, id -> index) pair;
# the exact container types are defined in FeatureUtils -- TODO confirm.
# NOTE(review): the item file stem is 'items_ids' (plural) while the user one is
# 'user_ids'; presumably this matches the files on disk -- confirm.
vbpr_item_index2id,\
vbpr_item_id2index = featils.read_ids_file(dirpath, 'items_ids')
# User-side parameters; the recorded output shows shape (2919, 200).
vbpr_user_vectors = np.load(dirpath + 'user_vectors.npy')
vbpr_user_index2id,\
vbpr_user_id2index = featils.read_ids_file(dirpath, 'user_ids')

In [7]:
ids_with_features = set(vbpr_item_index2id)
len(ids_with_features)

13297

In [8]:
vbpr_item_biases.shape, vbpr_item_vectors.shape

((13297,), (13297, 200))

In [9]:
vbpr_user_vectors.shape

(2919, 200)

In [10]:
artworks_dict = TransactionsHandler.artworks_dict

In [11]:
customers_dict = { cid : Customer(cid) for cid in TransactionsHandler.valid_sales_df.customer_id.unique() }

In [12]:
# Builds the chronological event stream used by the experiment and attaches every
# purchase session to the customer that made it.

# ---- upload events -----
upload_events = TransactionsHandler.upload_events

# ---- purchase events -----
purchase_session_events = TransactionsHandler.purchase_session_events

# distribute purchases among customers
# NOTE(review): this mutates the Customer objects in customers_dict; re-running this
# cell without rebuilding customers_dict would presumably register every session a
# second time -- confirm append_purchase_session semantics.
for pe in purchase_session_events:
    customers_dict[pe.customer_id].append_purchase_session(pe)

# --- join events and sort by timestamp ----
# Assuming both operands are plain lists, `+` creates a new list, so the in-place
# sort below does not reorder upload_events or purchase_session_events themselves.
time_events = upload_events + purchase_session_events
time_events.sort(key=lambda x : x.timestamp)

In [13]:
print("len(upload_events) = ", len(upload_events))
print("len(purchase_session_events) = ", len(purchase_session_events))
print("len(time_events) = ", len(time_events))

len(upload_events) =  7742
len(purchase_session_events) =  4897
len(time_events) =  12639


In [14]:
sanity_check_purchase_upload_events(time_events, artworks_dict)

CHECK: event types are correct
CHECK: events ordered by timestamp
CHECK: products are only uploaded once
CHECK: products can only be purchased if present in inventory


In [15]:
REC_SIZE = 20

In [16]:
import tensorflow as tf

In [17]:
class Network:
    """Graph-mode TensorFlow scorer for pre-trained user/item latent vectors.

    For every candidate item the graph computes
    ``sum(user_vector * item_vector) + item_bias``, i.e. the dot product between
    the user vector and the item vector plus the item's bias.

    The graph nodes are created in whatever TensorFlow graph is the default at
    construction time, so instantiate this class inside the graph in which the
    session will run.
    """

    def __init__(self, latent_dim=200):
        """Create the placeholders and the match-score tensor.

        Args:
            latent_dim: length of the user vector and of each item vector.
                Defaults to 200, the value that used to be hard-coded, so
                ``Network()`` builds exactly the same graph as before.
        """
        # --- placeholders
        self._user_vector = tf.placeholder(shape=[latent_dim], dtype=tf.float32)
        self._item_vectors = tf.placeholder(shape=[None, latent_dim], dtype=tf.float32)
        self._item_biases = tf.placeholder(shape=[None], dtype=tf.float32)
        self._candidate_item_indexes = tf.placeholder(shape=[None], dtype=tf.int32)

        # ---- candidate item vectors
        # Select only the rows (and biases) of the items that must be scored.
        self._candidate_item_vectors = tf.gather(self._item_vectors, self._candidate_item_indexes)
        self._candidate_item_biases = tf.gather(self._item_biases, self._candidate_item_indexes)

        # ---- match scores
        # The user vector broadcasts against every candidate row; summing over
        # axis 1 yields one dot product per candidate.
        self._match_scores = tf.reduce_sum(
            tf.multiply(self._user_vector, self._candidate_item_vectors), 1
        ) + self._candidate_item_biases

    def get_match_scores(self, sess, user_vector, item_vectors, item_biases, candidate_items_indexes):
        """Evaluate the match scores of one user against a set of candidate items.

        Args:
            sess: session used to run the graph; must be bound to the graph in
                which this network was built.
            user_vector: values fed to the user-vector placeholder.
            item_vectors: values fed to the item-vectors placeholder (one row per item).
            item_biases: values fed to the item-biases placeholder (one entry per item).
            candidate_items_indexes: row indexes into ``item_vectors`` /
                ``item_biases`` selecting the items to score.

        Returns:
            Whatever ``sess.run`` returns for the match-score tensor: one score
            per entry of ``candidate_items_indexes``, in the same order.
        """
        return sess.run(
            self._match_scores, feed_dict={
            self._user_vector: user_vector,
            self._item_vectors: item_vectors,
            self._item_biases: item_biases,
            self._candidate_item_indexes: candidate_items_indexes,
        })

In [18]:
class VBPR_Profile(ProfileBase):
    """Customer profile that ranks inventory with a fixed, pre-trained user vector.

    The user vector is supplied at construction time and never updated, which is
    why the event handlers below are intentionally no-ops.
    """

    # --- global -----
    @classmethod
    def global_purchase_session_event_handler(cls, purch_sess):
        """No-op: this profile keeps no class-level state across purchase sessions."""

    # --- instance ----
    def __init__(self, artworks_dict, network, sess, user_vector):
        """Store the scoring network, its session and this customer's user vector.

        The first positional argument forwarded to ``ProfileBase.__init__`` is
        ``None``; its meaning is defined by ``ProfileBase`` -- TODO confirm.
        """
        ProfileBase.__init__(self, None, artworks_dict)
        self._network = network
        self._sess = sess
        self._user_vector = user_vector

    def ready(self):
        """Return True once at least one consumed artwork has been recorded."""
        has_consumed_artworks = len(self.consumed_artworks) > 0
        return has_consumed_artworks

    def handle_artwork_added(self, artwork):
        """No-op: nothing is recomputed when an artwork is added."""

    def handle_artwork_removed(self, artwork):
        """No-op: nothing is recomputed when an artwork is removed."""

    def rank_inventory_ids(self, inventory_artworks):
        """Return the ids of ``inventory_artworks`` ordered from best to worst match.

        Each artwork id is translated to its row index through the module-level
        ``vbpr_item_id2index`` (an id missing from that mapping raises), scored
        with ``Network.get_match_scores`` and mapped back to an id through
        ``vbpr_item_index2id``.
        """
        candidate_indexes = [
            vbpr_item_id2index[artwork.id] for artwork in inventory_artworks
        ]
        scores = self._network.get_match_scores(
            self._sess,
            self._user_vector,
            vbpr_item_vectors,
            vbpr_item_biases,
            candidate_indexes,
        )
        # Descending sort over (score, index) tuples: equal scores are ordered
        # by item index, larger index first.
        ranked = sorted(zip(scores, candidate_indexes), reverse=True)
        return [vbpr_item_index2id[index] for _, index in ranked]

In [19]:
def run_experiment(artworks_dict, customers_dict, time_events, version):
    """Run the personalized recommendation experiment with VBPR scores and save it.

    A fresh TensorFlow graph and session are created for the run; every customer
    profile shares the same ``Network`` and session and differs only in the user
    vector looked up from the module-level ``vbpr_user_vectors``.

    Args:
        artworks_dict: forwarded to the experiment runner and to each profile.
        customers_dict: forwarded to the experiment runner.
        time_events: forwarded to the experiment runner.
        version: label interpolated into the output file name.

    Returns:
        None. The recommendations are written through ``recommendations_to_csv``.
    """
    graph = tf.Graph()
    with graph.as_default():
        network = Network()
        session_config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=0.99,
                allow_growth=True,
            )
        )
        with tf.Session(config=session_config) as sess:

            def create_profile_func(cid):
                # A customer id missing from vbpr_user_id2index raises here.
                user_index = vbpr_user_id2index[cid]
                return VBPR_Profile(
                    artworks_dict, network, sess, vbpr_user_vectors[user_index])

            recommendations = run_personalized_recommendation_experiment(
                artworks_dict,
                customers_dict,
                time_events,
                create_profile_func,
                rec_size=REC_SIZE,
            )
            output_path = "/mnt/workspace/ugallery_experiment_results/@{}_vbpr-{}".format(
                REC_SIZE, version)
            recommendations_to_csv(recommendations, output_path)

In [20]:
run_experiment(artworks_dict, customers_dict, time_events,
               version='vbpr_resnet_10m')

---------- starting experiment ------------
500 tests done! elapsed time: 4.09 seconds
1000 tests done! elapsed time: 8.21 seconds
1500 tests done! elapsed time: 12.70 seconds
1978 tests done! elapsed time: 17.19 seconds
** recommendations successfully saved to /mnt/workspace/ugallery_experiment_results/@20_vbpr-vbpr_resnet_10m
