In [119]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.linalg import orthogonal_procrustes
from sklearn.metrics import mean_squared_error
from helpers_pmf import *
from helpers_similarity import *
from helpers_optimization import *

In [120]:
mu = 0
sigma_u = 1
sigma_v = 1
d_dim = 2
n_users = 3
n_movies = 4

In [123]:
def generate_U_V_X(mu, sigma_u, sigma_v, d_dim, n_users, n_movies):
    
    U = np.random.normal(mu, sigma_u, size = (d_dim,n_users))
    V = np.random.normal(mu, sigma_v, size = (d_dim,n_movies))
    X = np.matmul(U.T,V) 
    
    return U, V, X 

Calculate the difference between each pair of elements in the 2D array X. Then, generate the probability matrix by applying the logistic function $$P(x) = \frac{e^{x}}{1 + e^{x}}$$ element-wise to the difference matrix, where x is the difference between two elements :

In [124]:
def generate_P_BT_Luce(X):
    
    diff = np.subtract.outer(X, X) 
    diff_3D= np.array([diff[i, :, i, :] for i in range(n_users)])
    prob_matrix = np.exp(diff_3D) / (1 + np.exp(diff_3D))
    
    return prob_matrix, diff_3D


Generate pairwise comparison data $Y_{ijk} = \pm 1$ for each user and item. The output Y is a $3D$ tensor with the shape $(n\_users, n\_items, n\_items)$. Each entry $Y[i,j,k]$ corresponds to whether user i prefers item j over item k. For example, line 0 of Y corresponds to the pairwise comparisons of user 0 with all items: $Y[0,0,:]$ represents whether user 0 prefers item 0 over all other items $(item\_0, item\_1, item\_2, etc..)$, and so on.

In [125]:
def pairwise_comparisons(prob_matrix):
    
    Y = np.random.binomial(n=1, p=prob_matrix, size=prob_matrix.shape) 
    Y = np.where(Y == 0, -1, Y)
    
    return Y

In [126]:
U, V, X = generate_U_V_X(mu, sigma_u, sigma_v, d_dim, n_users, n_movies)

In [132]:
prob_matrix, diff = generate_P_BT_Luce(X)
Y = pairwise_comparisons(prob_matrix)

print('X :\n ',X)
print('P :\n ',P)
print('Y :\n ',Y)

X :
  [[ 0.87277673 -1.2750381  -0.90293949  0.36307287]
 [-0.45016557 -0.52031421  0.1595789   0.25205107]
 [ 2.9848343   0.33565863 -1.86747763 -0.50975766]]
P :
  [[[0.5        0.92598551 0.76025225 0.98740195]
  [0.07401449 0.5        0.20221063 0.86234886]
  [0.23974775 0.79778937 0.5        0.9611145 ]
  [0.01259805 0.13765114 0.0388855  0.5       ]]

 [[0.5        0.85102733 0.70249005 0.88594198]
  [0.14897267 0.5        0.29245347 0.576217  ]
  [0.29750995 0.70754653 0.5        0.76687707]
  [0.11405802 0.423783   0.23312293 0.5       ]]

 [[0.5        0.63339127 0.5194305  0.97214697]
  [0.36660873 0.5        0.38484537 0.95283422]
  [0.4805695  0.61515463 0.5        0.96996231]
  [0.02785303 0.04716578 0.03003769 0.5       ]]]
Y :
  [[[ 1  1  1 -1]
  [-1  1 -1 -1]
  [-1  1  1 -1]
  [-1  1  1  1]]

 [[ 1  1 -1 -1]
  [ 1 -1 -1 -1]
  [ 1 -1  1 -1]
  [ 1 -1 -1  1]]

 [[ 1  1  1  1]
  [-1 -1  1  1]
  [-1 -1  1 -1]
  [-1  1  1 -1]]]
