In [1]:
import scipy.optimize
import random
from math import exp
from math import log
from collections import defaultdict
import gzip
import numpy as np
import time
import sys

def readGz(f):
    for l in gzip.open(f):
        yield eval(l)

data = []
for l in readGz('train.json.gz'):
    data.append(l)
random.shuffle(data)

In [2]:
class KNN(object):

    def __init__(self, numUsers, numItems, lamI = 6e-2, lamJ = 6e-3, learningRate = 0.1):
        self._numUsers = numUsers
        self._numItems = numItems
        self._lamI = lamI
        self._lamJ = lamJ
        self._learningRate = learningRate
        self._users = set()
        self._items = set()
        self._Iu = defaultdict(set)
        
        
    def sigmoid(self, x):
        return 1/(1+np.exp(-x))

    def train(self, trainData, epochs=30, batchSize=500):
        
        # correlation matrix
        self.C =np.random.rand(self._numItems, self._numItems)  
        for l in xrange(self._numItems):
            self.C[l][l] = 0
            for n in xrange(l, self._numItems):
                self.C[l][n] = self.C[n][l]
              
        # change batch_size to min(batch-size, len(train))
        if len(trainData) < batchSize:
            sys.stderr.write("WARNING: Batch size is greater than number of training samples, switching to a batch size of %s\n" % str(len(trainData)))
            batchSize = len(trainData)
                  
        self._trainDict, self._users, self._items = self._dataPretreatment(trainData)
        N = len(trainData) * epochs
        users, pItems, nItems = self._sampling(N)
        itr = 0
        t2 = t0 = time.time()
        while (itr+1)*batchSize < N:
      
            self._mbgd(
                users[itr*batchSize: (itr+1)*batchSize],
                pItems[itr*batchSize: (itr+1)*batchSize],
                nItems[itr*batchSize: (itr+1)*batchSize]
            )
            
            itr += 1
            t2 = time.time()
            sys.stderr.write("\rProcessed %s ( %.3f%% ) in %.1f seconds" %(str(itr*batchSize), 100.0 * float(itr*batchSize)/N, t2 - t0))
            sys.stderr.flush()
        if N > 0:
            sys.stderr.write("\nTotal training time %.2f seconds; %.2f samples per second\n" % (t2 - t0, N*1.0/(t2 - t0)))
            sys.stderr.flush()
            
            
    def _mbgd(self, users, pItems, nItems):
        
        prev = -2**10
        for _ in xrange(30):
            
            gradientC = defaultdict(float)
            obj = 0

            for ind in xrange(len(users)):
                u, i, j = users[ind], pItems[ind], nItems[ind]
                x_ui = sum([self.C[i][l] for l in self._Iu[u] if i != l])
                x_uj = sum([self.C[j][l] for l in self._Iu[u]])
                x_uij =  x_ui - x_uj
                
                for l in self._Iu[u]:
                    if l != i:
                        gradientC[(i,l)] += (1-self.sigmoid(x_uij)) + self._lamI * self.C[i][l]**2
                        gradientC[(l,i)] += (1-self.sigmoid(x_uij)) + self._lamI * self.C[l][i]**2
                    gradientC[(j,l)] += -(1-self.sigmoid(x_uij)) + self._lamJ * self.C[j][l]**2
                    gradientC[(l,j)] += -(1-self.sigmoid(x_uij)) + self._lamJ * self.C[l][j]**2
                    
                    obj -= 2*self._lamI * self.C[i][l]**2 + 2*self._lamJ * self.C[j][l]**2
                    
                obj += log(self.sigmoid(x_uij))
            
            #print 'OBJ: ', obj
            if prev > obj: 
                break
            prev = obj
            
            for a,b in gradientC:
                self.C[a][b] += self._learningRate * gradientC[(a,b)]
            
        #print _, '\n'
        
    def _sampling(self, N):
        
        sys.stderr.write("Generating %s random training samples\n" % str(N))
        userList = list(self._users)
        userIndex = np.random.randint(0, len(self._users), N)
        pItems, nItems = [], []
        cnt = 0
        for index in userIndex:
            u = userList[index]
            i = self._trainDict[u][np.random.randint(len(self._Iu[u]))]
            pItems.append(i)
            j = np.random.randint(self._numItems)
            while j in self._Iu[u]:
                j = np.random.randint(self._numItems)
            nItems.append(j)
            
            cnt += 1
            if not cnt %10000:
                sys.stderr.write("\rGenerated %s" %(str(cnt)))
                sys.stderr.flush()
        return userIndex, pItems, nItems

    def predictionsKNN(self, K, u):
        #slow
        if K >= self._Iu[u]:
            res = np.sum([self.C[:,l] for l in self._Iu[u]], 0)
        else:
            res = []
            for i in xrange(self._numItems):
                res.append(sum(sorted([self.C[i][l] for l in self._Iu[u]], reverse=True)[:K]))
        return res

    def predictionsAll(self, u):
        
        res = np.sum([self.C[:,l] for l in self._Iu[u]], 0)
        return res

    def prediction(self, u, i):
        
        scores = self.predictions(u)
        return scores[i] > sorted(scores)[self._numItem*0.8]

    def _dataPretreatment(self, data):
        dataDict = defaultdict(list)
        items = set()
        for u, i in data:
            self._Iu[u].add(i)
            dataDict[u].append(i)
            items.add(i)
        return dataDict, set(dataDict.keys()), items

In [3]:
length = len(data)
users = set()
items = set()
visited = set()
businessCount = defaultdict(int)

for l in data[:length]:
    user,business = l['userID'],l['businessID']
    users.add(user)  
    items.add(business)
    visited.add((user, business))  #visited pair
    businessCount[business] += 1
    
mostPopular = [(businessCount[x], x) for x in businessCount]
mostPopular.sort()
mostPopular.reverse()

return1 = set()
count = 0
for ic, i in mostPopular:
    count += ic
    return1.add(i)
    if count*1.0/length > 0.571: 
        break

users = list(users)
items = list(items)
Iu = defaultdict(set)
Ui = defaultdict(set)
for l in data[:length]:
    Iu[l['userID']].add(l['businessID'])
    Ui[l['businessID']].add(l['userID'])
    
users = {value:key for key, value in enumerate(users)}
items = {value:key for key, value in enumerate(items)}

In [4]:
train = [(users[l['userID']], items[l['businessID']]) for l in data[:length]]   
bpr = KNN(len(users), len(items), lamI = 4e-2, lamJ = 4e-3, learningRate = 0.1)
bpr.train(train, epochs=2, batchSize=500)

Generating 400000 random training samples
Generated 400000

OBJ:  -714.082881515
OBJ:  -468.384842843
OBJ:  -391.671572317
OBJ:  -351.099722502
OBJ:  -327.150577218
OBJ:  -312.116722227
OBJ:  -302.373538446
OBJ:  -295.976460405
OBJ:  

Processed 500 ( 0.125% ) in 1.6 seconds

-291.80315476
OBJ:  -289.169868832
OBJ:  -287.640893566
OBJ:  -286.927760733
OBJ:  -286.833161225
OBJ:  -287.218336557
13 

OBJ:  -668.410453935
OBJ:  -444.003951036
OBJ:  -359.352568517
OBJ:  -314.944098242

Processed 1000 ( 0.250% ) in 2.9 seconds


OBJ:  -288.838263517
OBJ:  -272.404876317
OBJ:  -261.630551634
OBJ:  -254.404298503
OBJ:  -249.525142659
OBJ:  -246.267206199
OBJ:  -244.169787664
OBJ:  -242.928485129
OBJ:  -242.335302281
OBJ:  -242.244032625
OBJ:  -242.54939527
14 

OBJ:  -712.313241301
OBJ:  -460.292409492
OBJ:  -375.308153065
OBJ:  -332.102394055


Processed 1500 ( 0.375% ) in 4.3 seconds

OBJ:  -307.085623416
OBJ:  -291.561739827
OBJ:  -281.530469577
OBJ:  -274.918838182
OBJ:  -270.558670303
OBJ:  -267.748375961
OBJ:  -266.04467188
OBJ:  -265.155638438
OBJ:  -264.882312161
OBJ:  -265.085111075
13 

OBJ:  -723.039065288
OBJ:  -473.349994958
OBJ:  -392.169837437


Processed 2000 ( 0.500% ) in 5.9 seconds

OBJ:  -349.022883654
OBJ:  -323.820100127
OBJ:  -308.305900126
OBJ:  -298.379634076
OBJ:  -291.904973525
OBJ:  -287.689676606
OBJ:  -285.024529337
OBJ:  -283.46463751
OBJ:  -282.718443816
OBJ:  -282.58785169
OBJ:  -282.934155909
13 

OBJ:  -675.603812991
OBJ:  -444.027694074
OBJ:  -367.63578217
OBJ:  

Processed 2500 ( 0.625% ) in 7.4 seconds

-328.645657584
OBJ:  -305.788277798
OBJ:  -291.503041594
OBJ:  -282.261772768
OBJ:  -276.20309901
OBJ:  -272.26206341
OBJ:  -269.790826998
OBJ:  -268.375812516
OBJ:  -267.742464913
OBJ:  -267.702514524
OBJ:  -268.123288739
13 

OBJ:  -679.855976447
OBJ:  -443.998090015
OBJ:  -369.106210963
OBJ:  -329.810594399
OBJ:  -306.546003607

Processed 3000 ( 0.750% ) in 8.7 seconds


OBJ:  -291.947333342
OBJ:  -282.460856302
OBJ:  -276.194766555
OBJ:  -272.067360693
OBJ:  -269.422734418
OBJ:  -267.843394195
OBJ:  -267.052725374
OBJ:  -266.861228057
OBJ:  -267.135380404
13 

OBJ:  -636.10988695
OBJ:  -431.825569975
OBJ:  -354.365435601
OBJ:  -315.248139274


Processed 3500 ( 0.875% ) in 10.0 seconds

OBJ:  -292.688755225
OBJ:  -278.631915581
OBJ:  -269.462752852
OBJ:  -263.349623183
OBJ:  -259.268968644
OBJ:  -256.605203842
OBJ:  -254.966794328
OBJ:  -254.093536199
OBJ:  -253.806290982
OBJ:  -253.978105579
13 

OBJ:  -721.153086855
OBJ:  -482.404589257
OBJ:  -400.459566224
OBJ: 

Processed 4000 ( 1.000% ) in 11.6 seconds

 -359.209768733
OBJ:  -335.602703982
OBJ:  -321.157397759
OBJ:  -311.958780088
OBJ:  -306.002194032
OBJ:  -302.17360696
OBJ:  -299.810468247
OBJ:  -298.496424764
OBJ:  -297.957901202
OBJ:  -298.008469566
12 

OBJ:  -689.134417231
OBJ:  -460.6529214
OBJ:  -384.1727043
OBJ:  

Processed 4500 ( 1.125% ) in 13.2 seconds

-345.666852618
OBJ:  -323.457379511
OBJ:  -309.695239132
OBJ:  -300.833715111
OBJ:  -295.041825806
OBJ:  -291.28527677
OBJ:  -288.940179715
OBJ:  -287.610443301
OBJ:  -287.033948892
OBJ:  -287.031184147
OBJ:  -287.475628909
13 

OBJ:  -656.849147214
OBJ:  -436.894460307
OBJ:  -358.047998156
OBJ:  -319.070425529
OBJ:  -296.48863923
OBJ:  -282.461713275
OBJ:  -273.406185411
OBJ:  -267.462042146


Processed 5000 ( 1.250% ) in 14.5 seconds

OBJ:  -263.578744733
OBJ:  -261.123536734
OBJ:  -259.695152368
OBJ:  -259.028187473
OBJ:  -258.940889137
OBJ:  -259.305132192
13 

OBJ:  -648.330328012
OBJ:  -447.943548082
OBJ:  

Processed 5500 ( 1.375% ) in 15.9 seconds

-377.616797902
OBJ:  -341.43071857
OBJ:  -320.109722158
OBJ:  -306.758280252
OBJ:  -298.104109243
OBJ:  -292.419528672
OBJ:  -288.719383146
OBJ:  -286.405584267
OBJ:  -285.095989777
OBJ:  -284.536068196
OBJ:  -284.550408552
12 

OBJ:  -682.525201412
OBJ:  -457.434229095
OBJ:  -379.837614635
OBJ:  -341.542889221
OBJ:  

Processed 6000 ( 1.500% ) in 17.3 seconds

-319.797219683
OBJ:  -306.525956568
OBJ:  -298.084552856
OBJ:  -292.632076818
OBJ:  -289.149659553
OBJ:  -287.030943588
OBJ:  -285.893913561
OBJ:  -285.486732696
OBJ:  -285.637187005
12 

OBJ:  -619.981393004
OBJ:  -421.752233064
OBJ:  -351.465108934
OBJ:  

Processed 6500 ( 1.625% ) in 18.7 seconds

-314.795438272
OBJ:  -293.410676174
OBJ:  -280.089810512
OBJ:  -271.465292808
OBJ:  -265.78732121
OBJ:  -262.066610743
OBJ:  -259.706068042
OBJ:  -258.326219777
OBJ:  -257.675620918
OBJ:  -257.58183356
OBJ:  -257.923156426
13 

OBJ:  -608.489867225
OBJ:  -419.4590677
OBJ:  -350.934642978
OBJ:  -315.167039898


Processed 7000 ( 1.750% ) in 20.1 seconds

OBJ:  -294.106932796
OBJ:  -280.960287504
OBJ:  -272.489526082
OBJ:  -266.970136208
OBJ:  -263.413820927
OBJ:  -261.220705829
OBJ:  -260.00884038
OBJ:  -259.525081813
OBJ:  -259.595922023
12 

OBJ:  -670.149726596
OBJ:  -449.660417842
OBJ:  -378.713485975
OBJ:  -343.791736111
OBJ:  -323.977541944
OBJ:  -311.833168939
OBJ:  -304.109466519
OBJ:  -299.15308709
OBJ:  -296.036267666
OBJ:  

Processed 7500 ( 1.875% ) in 21.7 seconds

-294.200653085
OBJ:  -293.291587332
OBJ:  -293.073886031
OBJ:  -293.38602292
12 

OBJ:  -640.249018972
OBJ:  -441.125420842
OBJ:  -370.572530703

Processed 8000 ( 2.000% ) in 23.1 seconds


OBJ:  -333.62628836
OBJ:  -311.924860789
OBJ:  -298.360828389
OBJ:  -289.588143737
OBJ:  -283.842864286
OBJ:  -280.117342855
OBJ:  -277.799217494
OBJ:  -276.497464022
OBJ:  -275.952102117
OBJ:  -275.98440671
12 

OBJ:  -662.542422597
OBJ:  -447.314482591
OBJ:  -378.024263551
OBJ:  -342.23809585

Processed 8500 ( 2.125% ) in 24.7 seconds


OBJ:  -321.612801937
OBJ:  -308.981690683
OBJ:  -300.993905208
OBJ:  -295.909447033
OBJ:  -292.746801519
OBJ:  -290.915176529
OBJ:  -290.040089458
OBJ:  -289.874031007
OBJ:  -290.247703286
12 

OBJ:  -683.038644143
OBJ:  -441.455028062
OBJ:  -374.319394056
OBJ:  -341.552511087
OBJ:  -322.718341967
OBJ:  -311.178345753


Processed 9000 ( 2.250% ) in 26.2 seconds

OBJ:  -303.908088434
OBJ:  -299.330761617
OBJ:  -296.548098889
OBJ:  -295.014672749
OBJ:  -294.382508114
OBJ:  -294.420556584
11 

OBJ:  -658.935688588
OBJ:  -437.049172267
OBJ:  -377.217769834
OBJ:  -346.583787174
OBJ:  -328.514722646
OBJ:  -317.246127044
OBJ: 

Processed 9500 ( 2.375% ) in 27.7 seconds

 -310.031066361
OBJ:  -305.405468727
OBJ:  -302.525087789
OBJ:  -300.872363269
OBJ:  -300.114116478
OBJ:  -300.027181553
OBJ:  -300.457116376
12 

OBJ:  -632.319938098
OBJ:  -424.550749169
OBJ:  -360.3903743


Processed 10000 ( 2.500% ) in 28.9 seconds

OBJ:  -328.581472884
OBJ:  -310.331286882
OBJ:  -299.160746623
OBJ:  -292.121825077
OBJ:  -287.68241006
OBJ:  -284.973875632
OBJ:  -283.470833848
OBJ:  -282.83892045
OBJ:  -282.856416784
11 

OBJ:  -654.733004041
OBJ:  -444.053617419
OBJ:  -375.363602268
OBJ:  -338.954435006
OBJ: 

Processed 10500 ( 2.625% ) in 30.3 seconds

 -317.809935237
OBJ:  -304.769564391
OBJ:  -296.449041884
OBJ:  -291.082276122
OBJ:  -287.674762945
OBJ:  -285.628798787
OBJ:  -284.565516669
OBJ:  -284.234160863
OBJ:  -284.462805483
12 

OBJ:  -654.143299715
OBJ:  -434.07459793
OBJ:  -365.776596479
OBJ:  -331.584284397


Processed 11000 ( 2.750% ) in 31.6 seconds

OBJ:  -311.889373691
OBJ:  -299.69749757
OBJ:  -291.876682458
OBJ:  -286.814110734
OBJ:  -283.596439734
OBJ:  -281.669802534
OBJ:  -280.680330362
OBJ:  -280.392384694
OBJ:  -280.643685362
12 

OBJ:  -638.484728405
OBJ:  -431.133705362
OBJ:  -361.682237759
OBJ: 

Processed 11500 ( 2.875% ) in 32.8 seconds

 -326.396672092
OBJ:  -305.625698014
OBJ:  -292.59030868
OBJ:  -284.131784242
OBJ:  -278.583429423
OBJ:  -274.989213632
OBJ:  -272.765109736
OBJ:  -271.536228115
OBJ:  -271.052123078
OBJ:  -271.13991235
12 

OBJ:  -662.739626435
OBJ:  -435.863079796
OBJ:  -372.626009143
OBJ: 

Processed 12000 ( 3.000% ) in 34.2 seconds

 -340.526146336
OBJ:  -322.013833476
OBJ:  -310.624143519
OBJ:  -303.393432648
OBJ:  -298.788343576
OBJ:  -295.94085817
OBJ:  -294.325095694
OBJ:  -293.605227011
OBJ:  -293.557595149
OBJ:  -294.028075765
12 

OBJ:  -625.410200664
OBJ:  -426.472287651
OBJ:  -364.215216055
OBJ: 

Processed 12500 ( 3.125% ) in 35.4 seconds

 -332.119043614
OBJ:  -313.483599991
OBJ:  -302.081524792
OBJ:  -294.926818108
OBJ:  -290.44171393
OBJ:  -287.73035238
OBJ:  -286.252389523
OBJ:  -285.664922151
OBJ:  -285.741036417
11 

OBJ:  -643.777484195
OBJ:  -442.134016525
OBJ: 

Processed 13000 ( 3.250% ) in 36.8 seconds

 -377.497767769
OBJ:  -343.472592501
OBJ:  -323.596378636
OBJ:  -311.399164351
OBJ:  -303.684463013
OBJ:  -298.775462358
OBJ:  -295.728356549
OBJ:  -293.975111081
OBJ:  -293.155225093
OBJ:  -293.030071501
OBJ:  -293.436186918
12 

OBJ:  -615.382839225
OBJ:  -437.102037592
OBJ:  

Processed 13500 ( 3.375% ) in 38.2 seconds

-377.362105514
OBJ:  -345.964023346
OBJ:  -327.461075101
OBJ:  -316.067209297
OBJ:  -308.903574613
OBJ:  -304.415735403
OBJ:  -301.711175719
OBJ:  -300.248657481
OBJ:  -299.683981564
OBJ:  -299.789236337
11 

OBJ:  -624.592465757
OBJ:  -430.092584515
OBJ:  -367.38683844

Processed 14000 ( 3.500% ) in 39.6 seconds


OBJ:  -335.214218603
OBJ:  -316.554649162
OBJ:  -305.102715126
OBJ:  -297.883644615
OBJ:  -293.337884291
OBJ:  -290.578231351
OBJ:  -289.066058432
OBJ:  -288.457274583
OBJ:  -288.52295264
11 

OBJ:  -591.71779846
OBJ:  -410.189077789
OBJ:  -345.064611012
OBJ:  -312.913395687
OBJ:  -294.234848588
OBJ:  -282.640421571
OBJ:  -275.220908091
OBJ:  -270.449194493
OBJ:  -267.451788455
OBJ:  -265.6963933
OBJ:  -264.842174055
OBJ:  -264.662009975


Processed 14500 ( 3.625% ) in 40.9 seconds

OBJ:  -264.999505267
12 

OBJ:  -650.600180725
OBJ:  -439.593121861
OBJ:  -382.224602157
OBJ:  -352.501145637
OBJ:  -335.329443722
OBJ:  -324.888753185
OBJ:  -318.415978623
OBJ:  -314.453131896

Processed 15000 ( 3.750% ) in 42.2 seconds


OBJ:  -312.169669173
OBJ:  -311.062240578
OBJ:  -310.810004797
OBJ:  -311.199847187
11 

OBJ:  -614.270973289
OBJ:  -418.590966246
OBJ:  -352.849263502
OBJ:  -319.60467794
OBJ:  -300.553994756
OBJ:  -288.98863518
OBJ:  -281.809528761
OBJ:  -277.402592354
OBJ:  -274.855598065


Processed 15500 ( 3.875% ) in 43.4 seconds

OBJ:  -273.620278901
OBJ:  -273.35149133
OBJ:  -273.824681421
11 

OBJ:  -645.488818926
OBJ:  -439.12196638
OBJ:  

Processed 16000 ( 4.000% ) in 45.0 seconds

-376.334196156
OBJ:  -343.99454809
OBJ:  -325.460117489
OBJ:  -314.21612577
OBJ:  -307.235798465
OBJ:  -302.937288498
OBJ:  -300.421903885
OBJ:  -299.146168057
OBJ:  -298.764704489
OBJ:  -299.048945002
11 

OBJ:  -583.102216988
OBJ:  -417.399898507
OBJ:  -360.208749641

Processed 16500 ( 4.125% ) in 46.2 seconds


OBJ:  -330.541760803
OBJ:  -313.23976153
OBJ:  -302.591102408
OBJ:  -295.897913139
OBJ:  -291.732819429
OBJ:  -289.27741063
OBJ:  -288.032821914
OBJ:  -287.679944072
OBJ:  -288.006698198
11 

OBJ:  -646.507073447
OBJ:  -436.095223518
OBJ:  -372.964173899
OBJ:  -342.322017814
OBJ:  

Processed 17000 ( 4.250% ) in 47.3 seconds

-325.100320477
OBJ:  -314.821156117
OBJ:  -308.616611583
OBJ:  -305.014238765
OBJ:  -303.184286769
OBJ:  -302.629533054
OBJ:  -303.039856788
10 

OBJ:  -605.082834752
OBJ:  -425.562727896
OBJ:  -367.28795076
OBJ: 

Processed 17500 ( 4.375% ) in 48.5 seconds

 -336.763472266
OBJ:  -318.993809915
OBJ:  -308.087407581
OBJ:  -301.24709576
OBJ:  -296.985625377
OBJ:  -294.45019594
OBJ:  -293.121688701
OBJ:  -292.668281577
OBJ:  -292.869190197
11 

OBJ:  -610.290444406
OBJ:  -423.170134095
OBJ:  -360.046228528
OBJ:  -327.710965622
OBJ:  -308.869855328
OBJ:  -297.248946321
OBJ:  -289.891739317
OBJ:  -285.231825275
OBJ:  -282.37334863
OBJ:  

Processed 18000 ( 4.500% ) in 49.9 seconds

-280.772786487
OBJ:  -280.083952017
OBJ:  -280.077241743
OBJ:  -280.595153385
12 

OBJ:  -597.072820419
OBJ:  -426.22799837
OBJ:  -369.371545007
OBJ: 

Processed 18500 ( 4.625% ) in 51.2 seconds

 -339.685710065
OBJ:  -322.678726406
OBJ:  -312.415402916
OBJ:  -306.094983945
OBJ:  -302.257105587
OBJ:  -300.074773864
OBJ:  -299.048165889
OBJ:  -298.859235653
OBJ:  -299.297063924
11 

OBJ:  -656.399708032
OBJ:  -435.147345861
OBJ:  -376.379396893
OBJ:  -347.771718398

Processed 19000 ( 4.750% ) in 52.6 seconds


OBJ:  -331.398323495
OBJ:  -321.479330678
OBJ:  -315.357289177
OBJ:  -311.641310109
OBJ:  -309.537692451
OBJ:  -308.563598029
OBJ:  -308.410353965
OBJ:  -308.872623558
11 



KeyboardInterrupt: 

In [None]:
predictions = open("predictions_Visit.txt", 'w')
cnt = 0
for l in open("pairs_Visit.txt"):
    cnt += 1
    if l.startswith("userID"):
    #header
        predictions.write(l)
        continue
    u,i = l.strip().split('-')
    if u not in users:  #no history of user, use popularity
        if i in return1:
            predictions.write(u + '-' + i + ",1\n")
        else:
            predictions.write(u + '-' + i + ",0\n")
        continue
        
    if i not in items:   #no one visited before
        predictions.write(u + '-' + i + ",0\n")
        continue
        
    #scores = bpr.predictionsKNN(10, users[u])
    scores = bpr.predictionsAll(users[u])
    #thrhd = sorted(scores)[len(items)*7/10]
    score = scores[items[i]]
    
    num = len([1 for x in scores if x > score])

    #if score >= thrhd:
    if num < len(scores)*30/100:
        predictions.write(u + '-' + i + ",1\n")
    else:
        predictions.write(u + '-' + i + ",0\n")
        
    sys.stderr.write("\rProcessed %s" %(str(cnt)))
    sys.stderr.flush()
predictions.close()

In [20]:
cnt = 0
for l in open("predictions_Visit.txt", 'r'):
    if l.startswith("userID"):
    #header
        continue
    _, y = l.strip().split(',')
    if y == '1':
        cnt += 1
cnt

18644

In [19]:
# predict with users' thrhd
predictions = open("predictions_Visit.txt", 'w')
cnt = 0
for l in open("pairs_Visit.txt"):
    cnt += 1
    if l.startswith("userID"):
    #header
        predictions.write(l)
        continue
    u,i = l.strip().split('-')
    if u not in users:  #no history of user, use popularity
        if i in return1:
            predictions.write(u + '-' + i + ",1\n")
        else:
            predictions.write(u + '-' + i + ",0\n")
        continue
        
    if i not in items:   #no one visited before
        predictions.write(u + '-' + i + ",0\n")
        continue
        
    scores = pred(users[u])
    thrhd = min([scores[items[it]] for it in Iu[u]])
    
    if scores[items[i]] >= thrhd:
        predictions.write(u + '-' + i + ",1\n")
    else:
        predictions.write(u + '-' + i + ",0\n")
        
    sys.stderr.write("\rProcessed %s" %(str(cnt)))
    sys.stderr.flush()
predictions.close()

Processed 40001

In [None]:
from theano_bpr import BPR
import theano
import theano.tensor as T
from theano import function, config, shared, sandbox  

In [None]:
length = len(data)*10/11
users = set()
items = set()
visited = set()
businessCount = defaultdict(int)

for l in data[:length]:
    user,business = l['userID'],l['businessID']
    users.add(user)  
    items.add(business)
    visited.add((user, business))  #visited pair
    businessCount[business] += 1
    
mostPopular = [(businessCount[x], x) for x in businessCount]
mostPopular.sort()
mostPopular.reverse()

return1 = set()
count = 0
for ic, i in mostPopular:
    count += ic
    return1.add(i)
    if count*1.0/length > 0.571: 
        break

users = list(users)
items = list(items)
Iu = defaultdict(set)
Ui = defaultdict(set)
for l in data[:length]:
    Iu[l['userID']].add(l['businessID'])
    Ui[l['businessID']].add(l['userID'])
    
unvisited = set()
while len(unvisited) < len(data)-length:    
    user = users[random.randint(0,len(users)-1)]
    item = items[random.randint(0,len(items)-1)]
    if item not in Iu[user]:
        unvisited.add((user, item))
    
users = {value:key for key, value in enumerate(users)}
items = {value:key for key, value in enumerate(items)}

In [None]:
from theano_bpr import BPR
train = [(users[l['userID']], items[l['businessID']]) for l in data[:length]]   
for U in [1e-4,1e-3,0.1]:
    for I in [1e-4,1e-3,0.1]:
        for J in [1e-4,1e-3,0.1]:
            for bias in [0.1,1,3]:
                bpr = BPR(5, len(users), len(items), lambda_u = U, lambda_i = I, lambda_j = J, lambda_bias = bias)
                bpr.train(train, epochs=150, batch_size=1000)
                
                print U,I,J,bias
                cnt = 0
                for l in data[length:]:
                    u,i = l['userID'], l['businessID']
                    if u not in users:  #no history of user, use popularity
                        if i in return1:
                            cnt += 1
                        continue

                    if i not in items:   #no one visited before
                        continue

                    scores = bpr.predictions(users[u])
                    thrhd = min([scores[items[it]] for it in Iu[u]])

                    if scores[items[i]] >= thrhd:
                        cnt += 1
                ac = cnt*1.0/(len(data)-length)
                print "acr posi: " + str(ac),
                cnt = 0
                for u,i in unvisited:
                    
                    if u not in users:  #no history of user, use popularity
                        if i not in return1:
                            cnt += 1
                        continue

                    if i not in items:   #no one visited before
                        cnt += 1
                        continue

                    scores = bpr.predictions(users[u])
                    thrhd = min([scores[items[it]] for it in Iu[u]])

                    if scores[items[i]] < thrhd:
                        cnt += 1
                print "acr nega: " + str(cnt*1.0/len(unvisited)),
                print 'avr: ' + str((ac+cnt*1.0/len(unvisited))/2.0)