In [50]:
from collections import defaultdict
import csv
import scipy
import scipy.optimize
import random
import numpy as np
import time

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable

import pandas as pd 
from IPython.display import display

In [2]:
class MF(nn.Module):
    """Matrix-factorisation rating model with global, per-user and per-item biases.

    For a (user, item) pair the prediction is::

        bias + bias_user[user] + bias_item[item] + dot(user[user], item[item])

    Args:
        n_user: number of rows in the user embedding tables.
        n_item: number of rows in the item embedding tables.
        k: dimension of the user/item latent vectors.
        c_vector: weight applied to the regulariser of the latent vectors.
        c_bias: weight applied to the regulariser of the per-user/per-item biases.
        writer: optional logger exposing ``add_scalar(name, value, step)``;
            when ``None`` nothing is logged.
    """

    # Step index handed to ``writer.add_scalar``. Nothing in this class
    # advances it; presumably the training loop updates it -- TODO confirm.
    itr = 0

    def __init__(self, n_user, n_item, k=1, c_vector=1.0, c_bias=1.0, writer=None):
        super().__init__()
        self.writer = writer
        self.k = k
        self.n_user = n_user
        self.n_item = n_item
        self.c_bias = c_bias
        self.c_vector = c_vector

        # gammas (users and items)
        self.user = nn.Embedding(n_user, k)
        self.item = nn.Embedding(n_item, k)

        # alpha and betas (users and items)
        self.bias_user = nn.Embedding(n_user, 1)
        self.bias_item = nn.Embedding(n_item, 1)
        self.bias = nn.Parameter(torch.ones(1))

    def forward(self, train_x):
        """Predict one rating per row of ``train_x``.

        Args:
            train_x: 2-D integer tensor; column 0 holds user indices and
                column 1 holds item indices.

        Returns:
            1-D tensor with one prediction per input row.
        """
        user_id = train_x[:, 0]
        item_id = train_x[:, 1]
        vector_user = self.user(user_id)
        vector_item = self.item(item_id)

        # Pull out biases. squeeze(-1) drops only the trailing size-1 embedding
        # axis, so a one-row batch keeps its batch axis instead of collapsing
        # to a 0-d tensor as a bare squeeze() would.
        bias_user = self.bias_user(user_id).squeeze(-1)
        bias_item = self.bias_item(item_id).squeeze(-1)
        biases = self.bias + bias_user + bias_item

        ui_interaction = torch.sum(vector_user * vector_item, dim=1)

        # Add bias prediction to the interaction prediction
        return ui_interaction + biases

    def loss(self, prediction, target):
        """Return the MSE between prediction and target plus the weighted regularisers.

        The regularisers come from ``l2_regularize`` (defined outside this
        class) applied to the bias and latent-vector embedding weights, scaled
        by ``c_bias`` and ``c_vector`` respectively. When a writer is attached,
        each term and the total are logged under the step ``self.itr``.

        Args:
            prediction: model output, e.g. the result of ``forward``.
            target: observed ratings; size-1 axes are removed before comparison.

        Returns:
            The total loss tensor.
        """
        target_aligned = target.squeeze()
        # A bare squeeze() also removes the batch axis of a one-row batch, which
        # makes mse_loss broadcast and warn about mismatched sizes. Realign the
        # target whenever it holds exactly as many elements as the prediction.
        if (target_aligned.shape != prediction.shape
                and target_aligned.numel() == prediction.numel()):
            target_aligned = target_aligned.reshape(prediction.shape)
        loss_mse = F.mse_loss(prediction, target_aligned)

        # Add new regularization to the biases
        prior_bias_user = l2_regularize(self.bias_user.weight) * self.c_bias
        prior_bias_item = l2_regularize(self.bias_item.weight) * self.c_bias

        prior_user = l2_regularize(self.user.weight) * self.c_vector
        prior_item = l2_regularize(self.item.weight) * self.c_vector
        total = loss_mse + prior_user + prior_item + prior_bias_user + prior_bias_item

        if self.writer is not None:
            # Explicit mapping instead of walking locals(): the logged names are
            # the same, but the loop no longer depends on frame internals and no
            # longer logs the prediction/target arguments by accident when they
            # happen to hold a single element.
            scalars = {
                "loss_mse": loss_mse,
                "prior_bias_user": prior_bias_user,
                "prior_bias_item": prior_bias_item,
                "prior_user": prior_user,
                "prior_item": prior_item,
                "total": total,
            }
            for name, value in scalars.items():
                if isinstance(value, torch.Tensor) and value.nelement() == 1:
                    self.writer.add_scalar(name, value, self.itr)
        return total

## user_id and place_id --> user_idx, place_idx

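We need to map each raw `gPlusUserId` and `gPlusPlaceId` to an integer index (`user_idx`, `place_idx`) into the user–place interaction matrix, because the model's embedding tables are looked up by integer index rather than by the original ID.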

### Reformat Interactions based on (user_idx, place_idx, rating, time)

In [70]:
# Load the raw review table.
data = pd.read_csv("../datasets/google_local/reviews.csv")
# DataFrame.info() prints its own report and returns None, so it is called
# bare; wrapping it in display() only adds a stray "None" to the cell output.
data.info()
display(data.head())

# Number of distinct users / places. dropna=False counts a missing ID as a
# value, which matches what len(Series.unique()) reports.
n_user = data['gPlusUserId'].nunique(dropna=False)
n_place = data['gPlusPlaceId'].nunique(dropna=False)

print(n_user,n_place)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 11453845 entries, 0 to 11453844
Data columns (total 4 columns):
gPlusPlaceId      object
gPlusUserId       object
rating            float64
unixReviewTime    object
dtypes: float64(1), object(3)
memory usage: 349.5+ MB


None

Unnamed: 0,gPlusPlaceId,gPlusUserId,rating,unixReviewTime
0,108103314380004200232,100000010817154263736,3.0,1372686659
1,102194128241608748649,100000013500285534661,5.0,1342870724
2,101409858828175402384,100000021336848867366,5.0,1390653513
3,101477177500158511502,100000021336848867366,5.0,1389187706
4,106994170641063333085,100000021336848867366,4.0,1390486279


5054567 3116785


In [71]:
# Load the reformatted review table. Judging by the file name, its ID columns
# have already been replaced by integer indices; the step that wrote this file
# is not shown here.
data = pd.read_csv("../datasets/google_local/reviews_reformatted.csv")
# DataFrame.info() prints its own report and returns None, so it is called
# bare; wrapping it in display() only adds a stray "None" to the cell output.
data.info()
display(data.head())

# Number of distinct users / places. dropna=False counts a missing ID as a
# value, which matches what len(Series.unique()) reports.
n_user = data['gPlusUserId'].nunique(dropna=False)
n_place = data['gPlusPlaceId'].nunique(dropna=False)

print(n_user,n_place)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 11453845 entries, 0 to 11453844
Data columns (total 4 columns):
gPlusPlaceId      int64
gPlusUserId       int64
rating            float64
unixReviewTime    object
dtypes: float64(1), int64(2), object(1)
memory usage: 349.5+ MB


None

Unnamed: 0,gPlusPlaceId,gPlusUserId,rating,unixReviewTime
0,1368311,0,3.0,1372686659
1,370282,1,5.0,1342870724
2,237940,2,5.0,1390653513
3,249417,2,5.0,1389187706
4,1181533,2,4.0,1390486279


5054567 3116785
