In [1]:
import numpy as np 
import pandas as pd 
import torch 

In [2]:
import torch.nn as nn
import torch.nn.functional as F

In [3]:
from torch.optim import Adam
from torch.nn.parameter import Parameter

In [4]:
# N: rows per batch (not referenced by the cells below; the model is built with N=1).
# D_in: size of user / item / memory vectors.  H: number of memory slots.
N, D_in, H = 64, 20, 10
# Hinge-loss margin between positive and negative distances.
safety_margin_size = 3
# Seed torch's RNG so "Restart & Run All" reproduces memories, data and losses.
SEED = 0
torch.manual_seed(SEED)

In [5]:
def euclidean_distance(a, b):
    """Row-wise L2 distance between ``a`` and ``b``.

    ``a - b`` is broadcast and must have at least two dimensions. Squares are
    summed over dim 1, so for 2-D inputs the result has one entry per row.

    NOTE: the gradient of ``sqrt`` is infinite at 0, so identical rows produce
    NaN gradients.
    """
    squared = torch.pow(a - b, 2)
    return torch.sqrt(torch.sum(squared, dim=1))

In [6]:
def generator(user):
    """Placeholder negative sampler: ignores ``user`` and returns Gaussian noise.

    Returns a fresh ``(1, D_in)`` tensor from ``torch.randn``, where ``D_in`` is
    the notebook-global feature size. It is used as the negative item in
    ``train_per_datapoint`` until a real generator is implemented.
    """
    noise_shape = (1, D_in)
    return torch.randn(noise_shape)

In [7]:
# Initial memory handed to the Discriminator: H slots of D_in features each.
# value_memory is drawn first, then key_memory (draw order matters once seeded).
value_memory, key_memory = (
    torch.randn((H, D_in), requires_grad=True) for _slot in range(2)
)

In [27]:
# Synthetic pairs: 500 user rows (x) and 500 item rows (y). They are drawn
# independently, so there is no real user/item relationship to learn.
x, y = (torch.randn(500, D_in) for _draw in range(2))

In [28]:
class Discriminator(torch.nn.Module):
    """Memory-network scorer: distance between an item and a memory read for a user.

    For each of the first ``N`` rows of ``user`` / ``item``:

    1. Address the key memory with a softmax over negative Euclidean distances
       between the user row and every key slot.
    2. Read ``s`` as the attention-weighted sum of the value-memory slots.
    3. Emit ``||s - item||_2`` (smaller = item closer to the memory read).
    4. Write the item into the value memory: a sigmoid erase gate, then a tanh
       add gate, both scaled by the same attention weights.

    Args:
        D_in: size of user / item / memory vectors.
        H: number of memory slots.
        N: number of rows of ``user`` / ``item`` processed per ``forward`` call.
        key_memory: initial key slots, shape ``(H, D_in)``; copied, not aliased.
        value_memory: initial value slots, shape ``(H, D_in)``; copied, not aliased.

    Raises:
        ValueError: if either memory tensor is not shaped ``(H, D_in)``.

    NOTE(review): step 4 overwrites ``value_memory.data``. This has three effects:

    (a) Memory persists across ``forward`` calls, so a second call sees memory
        already written by the first.
    (b) ``fc_erase`` / ``fc_update`` never receive gradients, because their
        outputs only reach the detached ``.data``.
    (c) ``.data`` is swapped without autograd tracking, so backward presumably
        differentiates the step-2 read against the *updated* memory.

    This behavior is kept as-is. TODO: decide whether memory should be per-call state.
    """

    def __init__(self, D_in, H, N, key_memory, value_memory):
        super().__init__()
        expected_shape = (H, D_in)
        for label, memory in (("key_memory", key_memory), ("value_memory", value_memory)):
            if tuple(memory.shape) != expected_shape:
                raise ValueError(
                    f"{label} must have shape {expected_shape}, got {tuple(memory.shape)}"
                )
        # Copy so optimizer steps cannot mutate the caller's tensors. Otherwise,
        # re-running the constructor cell would silently start from trained keys.
        self.key_memory = Parameter(key_memory.detach().clone())
        self.value_memory = Parameter(value_memory.detach().clone())
        self.fc_erase = nn.Linear(D_in, D_in)
        self.fc_update = nn.Linear(D_in, D_in)
        self.D_in = D_in
        self.N = N
        self.H = H

    def forward(self, user, item):
        """Score the first ``self.N`` (user, item) row pairs.

        Args:
            user: indexed as ``user[i]``; each row is assumed to be ``(D_in,)``.
            item: indexed as ``item[i]``; each row is assumed to be ``(D_in,)``.

        Returns:
            List of ``self.N`` tensors. Each has shape ``(1,)`` and holds the
            distance between the memory read for ``user[i]`` and ``item[i]``.
        """
        output_list = []
        e_I = torch.ones(self.H, self.D_in)

        for i in range(self.N):
            user_row = user[i].unsqueeze(0)  # (1, D_in)
            item_row = item[i].unsqueeze(0)  # (1, D_in)

            # Memory addressing: one broadcast distance against all H key slots.
            Attention_weight = (-euclidean_distance(user_row, self.key_memory)).softmax(0)

            # Memory read: attention-weighted sum of value slots.
            s = Attention_weight.matmul(self.value_memory)
            output_list.append(euclidean_distance(s, item_row))

            # Memory write (erase, then add). The (H, 1) attention column times a
            # (1, D_in) gate gives an (H, D_in) per-slot update.
            attention_column = Attention_weight.unsqueeze(1)
            e_t = self.fc_erase(item_row).sigmoid()
            self.value_memory.data = self.value_memory * (e_I - attention_column.matmul(e_t))

            a_t = self.fc_update(item_row).tanh()
            self.value_memory.data = self.value_memory + attention_column.matmul(a_t)

        return output_list

In [None]:
class generator(torch.nn.Module):
    """Work-in-progress adversarial negative-item generator (not implemented yet).

    NOTE(review): this cell was never executed (``In [None]``), and it had two problems:

    * ``forward`` was a verbatim copy of ``Discriminator.forward``. It read
      ``self.H``, ``self.D_in``, ``self.N``, ``self.key_memory``,
      ``self.value_memory``, ``self.fc_erase`` and ``self.fc_update``, none of
      which ``__init__`` defines, so every call raised ``AttributeError``.
    * The class name shadows the earlier ``generator(user)`` sampling function
      used by ``train_per_datapoint``. Once this cell runs, ``generator(user)``
      calls ``__init__`` with an extra argument and raises ``TypeError``.
      TODO: rename one of the two.

    ``forward`` now fails fast with ``NotImplementedError`` until a real design
    exists. The declared layers (4 -> 24 -> 36 -> 1) are kept unchanged.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 24)
        self.fc2 = nn.Linear(24, 36)
        self.fc3 = nn.Linear(36, 1)

    def forward(self, user, item):
        """Not implemented; raises ``NotImplementedError``."""
        raise NotImplementedError("generator.forward is not implemented yet")

In [29]:
# One (user, item) row per forward call (N=1). Sizes come from the config cell
# instead of repeating 20 / 10 as literals.
model = Discriminator(D_in=D_in, H=H, N=1, key_memory=key_memory, value_memory=value_memory)

In [30]:
def train_per_datapoint(user, item):
    """Score one real (user, item) pair and one generated negative for the same user.

    Uses the notebook-global ``model`` and ``generator``. ``user`` and ``item``
    are single rows; a leading axis of size 1 is added before scoring.

    Returns:
        ``(ps_rslt, neg_rslt)``: the first element of ``model``'s output list
        for the positive item and for the generated negative, respectively.

    NOTE(review): ``model`` is called twice. With the Discriminator as written,
    the negative pass therefore runs against memory already written by the
    positive item.
    """
    batched_user = user.unsqueeze(0)
    positive_score = model(batched_user, item.unsqueeze(0))[0]
    negative_score = model(batched_user, generator(user))[0]
    return positive_score, negative_score

In [31]:
# Plain (unweighted) hinge loss until the generator is implemented.
def weighted_hinge_loss(ps_rslt, neg_rslt, safety_margin_size):
    """Margin ranking loss ``max(0, margin + d_pos - d_neg)``.

    ``ps_rslt`` and ``neg_rslt`` are distances (lower = closer). The loss is
    zero once the negative item is at least ``safety_margin_size`` farther
    away than the positive item.
    """
    violation = safety_margin_size + ps_rslt - neg_rslt
    return F.relu(violation)

In [32]:
# Train for 50 epochs. Each epoch sums the hinge loss over 5 (user, item)
# pairs and takes one Adam step.
# NOTE: re-running this cell continues training the existing `model`;
# re-run the model-construction cell first to start fresh.
total_loss = []
optimizer = Adam(model.parameters(), lr=0.01)
for epoch in range(50):
    # NOTE(review): every epoch reuses rows 0-4 of x / y; the other 495 rows
    # are never used. TODO: confirm whether mini-batches should advance.
    losses = []
    for i in range(5):
        ps_rslt, neg_rslt = train_per_datapoint(x[i], y[i])
        # Use the configured margin rather than a duplicated literal.
        losses.append(weighted_hinge_loss(ps_rslt, neg_rslt, safety_margin_size))

    mini_batch_loss = torch.stack(losses).sum()
    total_loss.append(mini_batch_loss.item())

    optimizer.zero_grad()
    mini_batch_loss.backward()
    optimizer.step()

In [33]:
# Per-epoch hinge loss summed over the 5 training pairs.
# NOTE(review): the values below stay roughly between 11 and 18 with no clear
# downward trend. x and y are independent random draws, so there is likely no
# signal to learn.
total_loss

[14.893839836120605,
 12.60797119140625,
 13.931467056274414,
 13.747971534729004,
 10.971602439880371,
 15.503976821899414,
 11.689532279968262,
 14.63906192779541,
 12.685401916503906,
 13.13146686553955,
 14.879408836364746,
 15.143271446228027,
 11.764472961425781,
 14.215774536132812,
 15.724404335021973,
 13.589129447937012,
 14.119624137878418,
 13.775588035583496,
 12.62105655670166,
 14.591506958007812,
 12.706850051879883,
 11.98979663848877,
 14.428289413452148,
 15.281702995300293,
 14.493289947509766,
 18.07111930847168,
 13.574992179870605,
 15.360875129699707,
 14.06322193145752,
 13.372254371643066,
 13.445112228393555,
 14.595959663391113,
 16.539819717407227,
 12.95510482788086,
 14.256826400756836,
 14.360271453857422,
 13.74833869934082,
 14.389134407043457,
 11.892964363098145,
 12.730216979980469,
 13.607656478881836,
 13.281892776489258,
 12.306159973144531,
 13.009481430053711,
 13.639678955078125,
 13.911109924316406,
 13.194938659667969,
 12.223461151123047,
 