In [1]:
import numpy as np 
import pandas as pd 

import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn.parameter import Parameter
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader

from math import log

In [2]:
#some parameters of this model
N, D_in, H = 64, 20, 10
safety_margin_size = 3
Max_sampleNum = 50
total_item_num = 500

In [3]:
def euclidean_distance(a, b):
    return (a-b).pow(2).sum(1).sqrt()

In [4]:
value_memory = torch.randn(H, D_in, requires_grad=True)
key_memory = torch.randn(H, D_in, requires_grad=True)

In [5]:
x = torch.randn(500, D_in)
y = torch.randn(500, D_in)

In [6]:
class CandidateDataset(Dataset):
    def __init__(self, x, y):
        self.len = x.shape[0]
        self.x_data = torch.as_tensor(x,dtype=torch.float)
        self.y_data = torch.as_tensor(y,dtype=torch.float)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

In [7]:
my_dataset = CandidateDataset(x, y)
train_loader = DataLoader(dataset = my_dataset, batch_size = 32, shuffle=True)

In [34]:
class generator(nn.Module):
    def __init__(self, D_in):
        super(generator, self).__init__()

        self.fc1 = nn.Linear(2*D_in, 32)
        self.fc2 = nn.Linear(32, 16)
        self.fc3 = nn.Linear(16, 1)  # Prob of Left

    def forward(self, x, y):
        cat_xy = torch.cat([x,y],1)
        hidden1 = F.relu(self.fc1(cat_xy))
        hidden2 = F.relu(self.fc2(hidden1))
        output = torch.sigmoid(self.fc3(hidden2))
        return output

In [9]:
class Discriminator(torch.nn.Module):
    def __init__(self, D_in, H, N, key_memory, value_memory):
        super(Discriminator, self).__init__()
        self.key_memory = Parameter(torch.Tensor(H, D_in))
        self.value_memory = Parameter(torch.Tensor(H, D_in))
        self.key_memory.data = key_memory
        self.value_memory.data = value_memory
        #self.value_memory.requires_grad=True
        self.fc_erase = nn.Linear(D_in, D_in)
        self.fc_update = nn.Linear(D_in, D_in)
        self.D_in = D_in
        self.N = N
        self.H = H
    def forward(self, user, item):
        output_list = []
        e_I = torch.ones(self.H, self.D_in)
        
        for i in range(self.N):
            #Memory Adderssing
            Attention_weight = torch.empty(self.H)

            for j in range(self.H):
                Attention_weight[j] = -euclidean_distance(user[i].unsqueeze(0), self.key_memory[j,:])


            #select value memory by attention

            Attention_weight = Attention_weight.softmax(0)
            s = Attention_weight.matmul(self.value_memory)

            output = euclidean_distance(s, item[i].unsqueeze(0))
            output_list.append(output)
            
            #update value memory by item vector
            e_t = self.fc_erase(item[i].unsqueeze(0)).sigmoid()
            self.value_memory.data = self.value_memory * (e_I - Attention_weight.unsqueeze(0).t().matmul(e_t))

            a_t = self.fc_update(item[i].unsqueeze(0)).tanh()
            self.value_memory.data = self.value_memory + Attention_weight.unsqueeze(0).t().matmul(a_t)
        
        return output_list

In [10]:
def ngtv_spl_by_unifm_dist(user):
    ngtv_item = torch.randn(1, D_in)
    return ngtv_item

In [62]:
def ngtv_spl_by_generator(user):
    ngtv_item_id = list(torch.multinomial(prob_ngtv_item, 1, replacement=True))[0]
    ngtv_item = dict_item_id2vec[candidate_ngtv_item[ngtv_item_id].item()]
    return ngtv_item_id, ngtv_item

In [63]:
def G_pretrain_per_datapoint(user, generator, model):
    ngtv_item_id, ngtv_item = generator(user)
    neg_rslt = model(user.unsqueeze(0), ngtv_item)[0]
    return -neg_rslt, ngtv_item_id

In [27]:
def D_pretrain_per_datapoint(user, item, generator):
    ps_rslt = model(user.unsqueeze(0),item.unsqueeze(0))[0]
    for i in range(Max_sampleNum):
        weight = log((total_item_num-1)//(i+1) + 1)
        neg_rslt = model(user.unsqueeze(0), generator(user))[0]
        loss = weighted_hinge_loss(ps_rslt, neg_rslt, safety_margin_size, weight)
        if(loss > 0):
            break;
    return loss

In [13]:
def weighted_hinge_loss(ps_rslt, neg_rslt, safety_margin_size, weight):
    return weight * F.relu(safety_margin_size + ps_rslt - neg_rslt)

In [69]:
# pre-training of discriminator
model = Discriminator(20, 10, 1, key_memory, value_memory)
total_loss = []
optimizer = Adam(model.parameters(), lr= 0.01)
for epoch in range(50):
    epoch_loss = 0.
    for batch_id, (x, y) in enumerate(train_loader):
        losses = torch.empty(x.shape[0])
        for i in range(x.shape[0]):
            losses[i] = D_pretrain_per_datapoint(x[i],y[i], ngtv_spl_by_unifm_dist)

        mini_batch_loss = torch.sum(losses)
        epoch_loss += float(mini_batch_loss.data)
        
        model.zero_grad()
        mini_batch_loss.backward()
        optimizer.step()
    print('epoch_' + str(epoch) + ':  loss=' + str(epoch_loss))
    total_loss.append(epoch_loss)    

epoch_0:  loss=9312.409759521484
epoch_1:  loss=9527.794250488281
epoch_2:  loss=9212.543060302734


KeyboardInterrupt: 

In [75]:
dict_user_id2vec = {}
dict_item_id2vec = {}
candidate_ngtv_item = torch.randn(100)
prob_ngtv_item = torch.randn(100)

for i in candidate_ngtv_item:
    dict_item_id2vec[i.item()] = torch.randn(1,20)

In [77]:
#pre-training of generator

gener = generator(20)
total_reward = []
total_loss = []
optimizer_G = Adam(gener.parameters(), lr= 0.01)
for epoch in range(50):
    epoch_reward = 0.
    epoch_loss = 0.
    for batch_id, (x, y) in enumerate(train_loader):
        reward = torch.empty(x.shape[0])
        mini_batch_loss = torch.empty(x.shape[0])
        
        for data_point in range(x.shape[0]):
            #state = x[i]
            #action_pool = []
            #reward_pool = []
            candidate_ngtv_item = torch.randn(100)
            prob_ngtv_item = torch.randn(100)
            for i in candidate_ngtv_item:
                dict_item_id2vec[i.item()] = torch.randn(1,20)
            
            for i in range(candidate_ngtv_item.shape[0]):
                prob_ngtv_item[i] = gener(x[data_point].unsqueeze(0), dict_item_id2vec[candidate_ngtv_item[i].item()])
            prob_ngtv_item = torch.softmax(prob_ngtv_item,0)
            reward[data_point], ngtv_item_id = G_pretrain_per_datapoint(x[data_point], ngtv_spl_by_generator, model)
            mini_batch_loss[data_point] = (-reward[data_point]) * torch.log(prob_ngtv_item[ngtv_item_id])
        rewards = torch.sum(reward)
        mini_batch_losses = torch.sum(mini_batch_loss)
        epoch_loss = epoch + float(mini_batch_losses.data)
        epoch_reward = epoch + float(rewards.data)
        
        gener.zero_grad()
        mini_batch_losses.backward()
        optimizer_G.step()
    
    print('epoch_' + str(epoch) + ':  loss=' + str(epoch_loss) + '  reward=  ' + str(epoch_reward))
    total_loss.append(epoch_loss)    

epoch_0:  loss=-422.4139099121094  reward=  -91.8388442993164
epoch_1:  loss=-456.72161865234375  reward=  -98.9530258178711
epoch_2:  loss=-460.8658752441406  reward=  -97.71232604980469
epoch_3:  loss=-449.4365234375  reward=  -96.67015075683594
epoch_4:  loss=-427.617919921875  reward=  -90.51976013183594
epoch_5:  loss=-444.9988708496094  reward=  -95.82398986816406
epoch_6:  loss=-379.0965270996094  reward=  -77.06893920898438
epoch_7:  loss=-435.7945861816406  reward=  -88.72418212890625
epoch_8:  loss=-393.2669677734375  reward=  -80.82926940917969
epoch_9:  loss=-440.8939514160156  reward=  -89.49713897705078
epoch_10:  loss=-410.2823181152344  reward=  -83.31441497802734
epoch_11:  loss=-446.33758544921875  reward=  -89.46939849853516
epoch_12:  loss=-450.8667297363281  reward=  -88.67179870605469
epoch_13:  loss=-398.5673828125  reward=  -78.1609878540039
epoch_14:  loss=-432.98089599609375  reward=  -83.81454467773438


KeyboardInterrupt: 