In [6]:
import numpy as np
import tensorflow as tf
import random
from collections import defaultdict
import tqdm

def load_data(data_path):
    '''
    user_id, item_id, item_id,...
    '''
    user_ratings = defaultdict(set)
    max_u_id = -1
    max_i_id = -1
    with open(data_path, 'r') as f:
        f.readline()
        for idx, line in enumerate(f):
            u, i, _, _ = line.split(",")
            u = int(u)
            i = int(i)
            user_ratings[u].add(i)
            max_u_id = max(u, max_u_id)
            max_i_id = max(i, max_i_id)
            if idx == 1000:
                break
    return max_u_id, max_i_id, user_ratings
    

data_path = "./ml-20m/ratings.csv"
user_count, item_count, user_ratings = load_data(data_path)


def generate_test(user_ratings):
    user_test = dict()
    for u, i_list in user_ratings.items():
        user_test[u] = random.sample(i_list, 1)[0]
    return user_test

user_ratings_test = generate_test(user_ratings)

In [2]:
def generate_train_batch(user_ratings, user_ratings_test, item_count, batch_size=512):
    t = []
    for _ in range(batch_size):
        u = random.sample(user_ratings.keys(), 1)[0]
        i = random.sample(user_ratings[u], 1)[0]
        while i == user_ratings_test[u]:
            i = random.sample(user_ratings[u], 1)[0]
        
        j = random.randint(1, item_count)
        while j in user_ratings[u]:
            j = random.randint(1, item_count)
        t.append([u, i, j])
    return np.asarray(t)

def generate_test_batch(user_ratings, user_ratings_test, item_count):
    '''
    for an user u and an item i rated by u, 
    generate pairs (u,i,j) for all item j which u has't rated
    it's convinent for computing AUC score for u
    '''
    for u in user_ratings.keys():
        t = []
        i = user_ratings_test[u]
        for j in range(1, item_count+1):
            if not (j in user_ratings[u]):
                t.append([u, i, j])
        yield np.asarray(t)

In [3]:
def weight_variable(shape):
    return tf.Variable(tf.random_normal(shape, mean=0.0, stddev=0.01))

def bias_variable(shape):
    return tf.Variable(tf.random_normal(shape, mean=0.0, stddev=0.01))

In [4]:
def bpr(user_count, item_count, hidden_dim, batch_size=512):
    
    u = tf.placeholder(tf.int32, [None])
    i = tf.placeholder(tf.int32, [None])
    j = tf.placeholder(tf.int32, [None])

    user_emb_w = weight_variable([user_count+1, hidden_dim])
    item_emb_w = weight_variable([item_count+1, hidden_dim])
    item_b = bias_variable([item_count+1, 1])
        
        
    u_emb = tf.nn.embedding_lookup(user_emb_w, u)
        
    i_emb = tf.nn.embedding_lookup(item_emb_w, i)
    i_b = tf.nn.embedding_lookup(item_b, i)
        
    j_emb = tf.nn.embedding_lookup(item_emb_w, j)
    j_b = tf.nn.embedding_lookup(item_b, j)
    
    # MF 
    x = i_b - j_b + tf.reduce_sum(tf.matmul(u_emb, tf.transpose((i_emb - j_emb))), 1, keep_dims=True)
    
    auc_per_user = tf.reduce_mean(tf.cast(x > 0,"float"))
    
    l2_norm = tf.add_n([
            tf.reduce_sum(tf.norm(u_emb)), 
            tf.reduce_sum(tf.norm(i_emb)),
            tf.reduce_sum(tf.norm(j_emb))
        ])
    
    regu_rate = 0.0001
    loss = - tf.reduce_mean(tf.log(tf.sigmoid(x))) + regu_rate * l2_norm
    
    train_op = tf.train.AdamOptimizer(0.01).minimize(loss)
    return u, i, j, auc_per_user, loss, train_op

In [5]:
with tf.Session() as session:
    u, i, j, auc, loss, train_op = bpr(user_count, item_count, 20)
    session.run(tf.global_variables_initializer())
    for epoch in range(10):
        _batch_loss = 0
        for index in tqdm.tqdm(range(5000)): 
            uij = generate_train_batch(user_ratings, user_ratings_test, item_count)
            _loss, _ = session.run([loss, train_op], feed_dict={u:uij[:,0], i:uij[:,1], j:uij[:,2]})
            _batch_loss += _loss
                   
        print("epoch: ", epoch, ", loss: ", _batch_loss / (index+1))

        user_count = 0
        _auc_sum = 0.0

        #each batch will return only one user's auc
        for t_uij in tqdm.tqdm(generate_test_batch(user_ratings, user_ratings_test, item_count)):
            _auc, _test_loss = session.run([auc, loss],feed_dict={u:t_uij[:,0], i:t_uij[:,1], j:t_uij[:,2]})
            user_count += 1
            _auc_sum += _auc
        print("test_loss: ", _test_loss, ", test_auc: ", _auc_sum/user_count)

100%|██████████| 5000/5000 [01:11<00:00, 71.32it/s]
0it [00:00, ?it/s]

epoch:  0 , loss:  0.00123147875081


11it [00:18,  1.72s/it]
  0%|          | 6/5000 [00:00<01:29, 55.87it/s]

test_loss:  0.00582718 , test_auc:  1.0


100%|██████████| 5000/5000 [01:22<00:00, 60.96it/s]
0it [00:00, ?it/s]

epoch:  1 , loss:  0.00077263317887


11it [00:19,  1.81s/it]
  0%|          | 7/5000 [00:00<01:19, 62.71it/s]

test_loss:  0.00573505 , test_auc:  0.727272727273


100%|██████████| 5000/5000 [01:21<00:00, 61.60it/s]
0it [00:00, ?it/s]

epoch:  2 , loss:  0.000748916212423


11it [00:20,  1.83s/it]
  0%|          | 6/5000 [00:00<01:36, 51.65it/s]

test_loss:  0.00532791 , test_auc:  0.909090909091


100%|██████████| 5000/5000 [01:22<00:00, 60.67it/s]
0it [00:00, ?it/s]

epoch:  3 , loss:  0.000755578781117


11it [00:20,  1.78s/it]
  0%|          | 6/5000 [00:00<01:31, 54.67it/s]

test_loss:  0.00570867 , test_auc:  0.909090909091


100%|██████████| 5000/5000 [01:23<00:00, 60.20it/s]
0it [00:00, ?it/s]

epoch:  4 , loss:  0.000756793994037


11it [00:19,  1.70s/it]
  0%|          | 6/5000 [00:00<01:27, 56.83it/s]

test_loss:  0.00540385 , test_auc:  0.909090909091


100%|██████████| 5000/5000 [01:22<00:00, 60.46it/s]
0it [00:00, ?it/s]

epoch:  5 , loss:  0.000758263415296


11it [00:19,  1.79s/it]
  0%|          | 6/5000 [00:00<01:37, 51.41it/s]

test_loss:  0.00570118 , test_auc:  0.727272727273


100%|██████████| 5000/5000 [01:21<00:00, 61.02it/s]
0it [00:00, ?it/s]

epoch:  6 , loss:  0.000758073269262


11it [00:18,  1.69s/it]
  0%|          | 6/5000 [00:00<01:36, 51.67it/s]

test_loss:  0.00548648 , test_auc:  0.909090909091


100%|██████████| 5000/5000 [01:20<00:00, 62.08it/s]
0it [00:00, ?it/s]

epoch:  7 , loss:  0.000757788489305


11it [00:19,  1.69s/it]
  0%|          | 6/5000 [00:00<01:29, 55.99it/s]

test_loss:  0.00522403 , test_auc:  0.818181818182


100%|██████████| 5000/5000 [01:22<00:00, 57.56it/s]
0it [00:00, ?it/s]

epoch:  8 , loss:  0.000758971285739


11it [00:19,  1.73s/it]
  0%|          | 6/5000 [00:00<01:25, 58.57it/s]

test_loss:  0.00515164 , test_auc:  0.818181818182


100%|██████████| 5000/5000 [01:20<00:00, 62.41it/s]
0it [00:00, ?it/s]

epoch:  9 , loss:  0.000757699703961


11it [00:18,  1.73s/it]

test_loss:  0.00554141 , test_auc:  0.909090909091



