In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as opt
# from tqdm import tqdm
from tqdm import tqdm_notebook as tqdm
from heapq import heappush, heappop
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

import evaluation
import data_loader
# from model import MUD
import pdb

In [2]:
params = dict()
params['lr'] = 1e-1
params['batch_size'] = 64
params['epoch_limit'] = 20
params['w_decay'] = 0
params['negNum_test'] = 100
params['epsilon'] = 1e-2
params['negNum_train'] = 1
params['l_size'] = 16
params['gpu']= False


# params['lr_cf'] = 1e-1
# params['negNum_train'] = 4
# params["negNum_test"] = 1000
# params['epoch_limit_cf'] = 20
# params['w_decay_cf'] = 0
# params['batch_size'] = 64
# params['gpu'] = False
# params['l_size_cf'] = 16
# params['epsilon'] = 0.01


In [3]:
category = 'Baby'
print('Start loading data...')
train, val, test = data_loader.read_data(category)
item_price = data_loader.get_price(category)
item_related = data_loader.get_related(category)
distribution = data_loader.get_distribution(category)
trainset = data_loader.TransactionData(train, item_related, \
            item_price, distribution)
valset = data_loader.UserTransactionData(val, item_price, \
            trainset.itemNum, trainset.userHist)
testset = data_loader.UserTransactionData(test, item_price, \
            trainset.itemNum, trainset.userHist)
avg_rating = trainset.get_avgRating()
print('Finish loading data. Average rating score of training set: %.2f' %avg_rating)

Start loading data...
Finish loading data. Average rating score of training set: 4.20


In [4]:
class MF(nn.Module):
    """
        - userLen: the number of users
        - itemLen: the number of items
        - params: the parameters dict used for constructing model
            - l_size: latent dimension size
            - gpu: True/False, whether using GPU
            
    """
    def __init__(self, userLen, itemLen, params):
        super(MF, self).__init__()
        self.userNum = userLen
        self.itemNum = itemLen
        self.params = params
        if 'gpu' in params and params['gpu'] == True:
            self.device = 'cuda:0'
        else:
            self.device = 'cpu'

        l_size = params['l_size']
        
        """
            Initialize  global bias,
                        user bias,
                        item bias,
                        user embedding,
                        item embedding
        """
#         self.globalBias = avg_rating.to(self.device)
        
        self.uBias = nn.Embedding(userLen,1).to(self.device)
        self.uBias.weight.data = torch.zeros_like(self.uBias.weight.data)
        
        self.itemBias = nn.Embedding(itemLen,1).to(self.device)
        self.itemBias.weight.data = torch.zeros_like(self.itemBias.weight.data)
        
#         a = math.sqrt(2.0 / params['l_size'])
        self.uEmbed = nn.Embedding(userLen, l_size).to(self.device)
        self.uEmbed.weight.data.uniform_(-1, 1)
        
        self.itemEmbed = nn.Embedding(itemLen, l_size).to(self.device)
        self.itemEmbed.weight.data.uniform_(-1, 1)
    
    def forward(self, users, items):
        uE = self.uEmbed(users)
#         print('uE:',uE.shape)
        uB = self.uBias(users)
#         print('uB:',uB.shape)
        iE = self.itemEmbed(items)
#         print('iE:',iE.shape)
        iB = self.itemBias(items)
#         print('iB:',iB.shape)
#         gB = self.globalBias.expand(users.shape[0],1)
#         print('mul(uE, iE):',(torch.mul(uE, iE).sum(1)).shape)
        score = uB + iB + torch.mul(uE, iE).sum(1).view(-1,1)

#         score = self.globalBias + uB + iB + torch.mul(uE, iE).sum(1).view(-1,1)
#         print((self.globalBias + uB + iB).shape)
#         print(((uE*iE).sum(1)).shape)
#         score = (self.globalBias + uB + iB).view(-1) + (uE*iE).sum(1)
#         score = (uE*iE).sum(1)
#         print(score.shape)
        return score
        

In [5]:
print("Start training the rating model...")
model = MF(userLen = trainset.userNum, itemLen = trainset.itemNum, params = params)
optimizer = opt.SGD(model.parameters(), lr = params['lr'], weight_decay = params['w_decay'])

criterion_bpr = nn.BCELoss()


trainset.set_negN(params['negNum_train'])
trainLoader = DataLoader(trainset, batch_size = params['batch_size'], \
                        shuffle = True, num_workers = 0)
valset.set_negN(params['negNum_test'])
valLoader = DataLoader(valset, batch_size = 1, \
                        shuffle = True, num_workers = 0)
testset.set_negN(params['negNum_test'])
testLoader = DataLoader(testset, batch_size = 1, \
                        shuffle = True, num_workers = 0)

epsilon = params['epsilon']
epoch = 0
error = np.float('inf')

trainErrorList = []
valErrorList = []
valHistory = []
explodeTempreture = 3
convergenceTempreture = 3

Start training the rating model...


In [6]:
print("Starting training BPR model...")
epoch = 0
runningLoss = []
while epoch < params['epoch_limit']:
    epoch += 1
    print("Epoch " + str(epoch) + " training...")
    L = len(trainLoader.dataset)
    pbar = tqdm(total = L)
    for i, batchData in enumerate(trainLoader):
        optimizer.zero_grad()
        # get input
        users = torch.LongTensor(batchData['user']).to(model.device)
#         print(users.shape)
        items = torch.LongTensor(batchData['item']).to(model.device)
#         print(items.shape)
        pOut = model.forward(users,items).view(-1)
#         print(pOut.shape)


        negItems = torch.LongTensor(batchData['negItem']).reshape(-1).to(model.device)
#         print(negItems.shape)        
        nusers = users.view(-1,1) 
        nusers = nusers.expand(nusers.shape[0], params['negNum_train']).reshape(-1)
#         print(nusers.shape)
        nOut = model.forward(nusers, negItems).view(-1)
#         nOut = nOut.reshape(-1,params["negNum_train"])
#         print(nOut.shape)

        totalOut = torch.cat((pOut,nOut))
#         print (totalOut.shape)
        
        target = torch.FloatTensor([1]*len(pOut)+[0]*len(nOut))
#         print(target.shape)
#         break
        
        m = nn.Sigmoid()
        loss = criterion_bpr(m(totalOut),target)
        runningLoss.append(loss.item())
#         print(loss)
        loss.backward()
        optimizer.step()
        if (i+1) >= 50:
            pbar.set_postfix({'loss' : '{0:1.5f}'.format(np.mean(np.array(runningLoss[-50:])))})

        pbar.update(users.shape[0])
    pbar.close()

 

    #validation
    print("Epoch " + str(epoch) + " validating...")
    with torch.no_grad():
        L = len(valLoader.dataset)
        pbar = tqdm(total = L)
#         model.eval()
        scoreDict = dict()
        for i, batchData in enumerate(valLoader):
            if i > 1000:
                  break
            user = torch.LongTensor(batchData['user'])#.to(model.device)
            posItems = torch.LongTensor(batchData['posItem'])#.to(model.device)
            negItems = torch.LongTensor(batchData['negItem'])#.to(model.device)
            budget = torch.FloatTensor(batchData['budget'])#.to(model.device)
            posPrices = torch.FloatTensor(batchData['posPrice'])#.to(model.device)
            negPrices = torch.FloatTensor(batchData['negPrice'])#.to(model.device)

            items = torch.cat((posItems, negItems),1).view(-1)
            prices = torch.cat((posPrices, negPrices),1).view(-1)
            users = user.expand(items.shape[0])

            out = model.forward(users,items)
            scoreHeap = list()
            for j in range(out.shape[0]):
                gt = False
                if j < posItems.shape[1]:
                    gt = True
#                 if prices[j] > budget:
#                     heappush(scoreHeap, (100, (0 + items[j].cpu().numpy(), gt)))
#                 else:
                heappush(scoreHeap, (1 - out[j].cpu().numpy(), (0 + items[j].cpu().numpy(), gt)))
            scores = list()
            candidate = len(scoreHeap)
            for k in range(candidate):
                scores.append(heappop(scoreHeap))
            pbar.update(1)
            scoreDict[user[0]] = (scores, posItems.shape[1])
        pbar.close()

    valHistory.append(evaluation.ranking_performance(scoreDict,10))
#     valError = 1 - valHistory[-1]["avg_ndcg"][0]
#     valErrorList.append(valError)
#     improvement = np.abs(error - valError)
#     error = valError
#     if improvement < epsilon:
#         print("stop early")
#         break



Starting training BPR model...
Epoch 1 training...


torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size

torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size

torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size

torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size

torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size

torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size

torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size

torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size

torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size

torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([24])
torch.Size([24])

Epoch 1 validating...



	Precision@: {1:0.01998001998; 5: 0.0141858141858; 10: 0.0143856143856}
	Recall@: {1:0.0138187570006; 5: 0.0498795216328; 10: 0.100132082924}
	NDCG@: {1:0.01998001998; 5: 0.0382657248515; 10: 0.0553858964535}
Epoch 2 training...


torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size

torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size

torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size

torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size

torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size

KeyboardInterrupt: 

In [None]:
# test
print("starting test...")
L = len(testLoader.dataset)
pbar = tqdm(total = L)
scoreDict = dict()
with torch.no_grad(): 
    for i, batchData in enumerate(testLoader):
#         if i>1000:
#             break
        user = torch.LongTensor(batchData['user']).to(model.device)
        posItems = torch.LongTensor(batchData['posItem']).to(model.device)
        negItems = torch.LongTensor(batchData['negItem']).to(model.device)
        budget = torch.FloatTensor(batchData['budget']).to(model.device)
        posPrices = torch.FloatTensor(batchData['posPrice']).to(model.device)
        negPrices = torch.FloatTensor(batchData['negPrice']).to(model.device)

        items = torch.cat((posItems, negItems),1).view(-1)
        prices = torch.cat((posPrices, negPrices),1).view(-1)
        users = user.expand(items.shape[0])

        out = model.forward(users,items)
        scoreHeap = list()
        for j in range(out.shape[0]):
            gt = False
            if j < posItems.shape[1]:
                gt = True
#             if prices[j] > budget:
#                 heappush(scoreHeap, (1000, (0 + items[j].cpu().numpy(), gt)))
#             else:
#                 heappush(scoreHeap, (1 - out[j].cpu().numpy(), (0 + items[j].cpu().numpy(), gt)))
            heappush(scoreHeap, (1 - out[j].cpu().numpy(), (0 + items[j].cpu().numpy(), gt)))
        scores = list()
        candidate = len(scoreHeap)
        for k in range(candidate):
            scores.append(heappop(scoreHeap))
        pbar.update(1)
        scoreDict[user[0]] = (scores, posItems.shape[1])
pbar.close()
testResult = evaluation.ranking_performance(scoreDict,10)

In [None]:
len(scoreDict)

In [None]:
m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(m(input), target)
output.backward()

In [None]:
input

In [None]:
target

In [None]:
criterion_MUD(m(input),target)