In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd

In [3]:
class embedding_mlp(nn.Module):
    """MLP that scores (user, item) pairs from embedded feature ids.

    Each row of ``user_df`` / ``item_df`` is looked up by label (``.loc``) and
    its values are used as indices into the corresponding ``nn.Embedding``
    table. The embeddings of every column of both frames are concatenated,
    flattened and passed through three ``Linear -> Tanh`` layers, and the
    result is squashed with a sigmoid.

    Args:
        n_user_features: Number of rows in the user embedding table.
        n_item_features: Number of rows in the item embedding table.
        user_df: DataFrame whose values are indices into the user embedding
            table; its column count fixes how many embeddings each user
            contributes.
        item_df: Same as ``user_df`` but for items.
        dim: Size of every embedding vector (and of the last hidden layer).
    """

    def __init__(self,n_user_features,n_item_features,user_df,item_df,dim=128):
        super().__init__()
        # max_norm=1 makes the embedding renormalise looked-up rows in place
        # so that their L2 norm never exceeds 1.
        self.user_features = nn.Embedding(n_user_features,dim,max_norm=1)
        self.item_features = nn.Embedding(n_item_features,dim,max_norm=1)

        self.user_df = user_df
        self.item_df = item_df

        # One embedding per column of each frame is fed to the MLP.
        total_neighbours = user_df.shape[1]+item_df.shape[1]
        flat_width = dim*total_neighbours

        self.dense1 = self.dense_layer(flat_width,flat_width//2)
        self.dense2 = self.dense_layer(flat_width//2,dim)
        self.dense3 = self.dense_layer(dim,1)
        self.sigmoid = nn.Sigmoid()

    def dense_layer(self,in_features,out_features):
        """Return a ``Linear(in_features, out_features)`` followed by ``Tanh``."""
        return nn.Sequential(
            nn.Linear(in_features,out_features),
            nn.Tanh()
        )

    def forward(self,u,i,isTrain=True):
        """Score a batch of (user, item) pairs.

        Args:
            u: List-like of labels present in ``user_df.index``.
            i: List-like of labels present in ``item_df.index``; must have
                the same length as ``u``.
            isTrain: When true, dropout (p=0.5) is applied to the output of
                the last dense layer; when false, dropout is a no-op.

        Returns:
            1-D tensor with one value in (0, 1) per (user, item) pair.
        """
        # Build the index tensors on the same device as the embedding tables
        # so the module keeps working after ``.to(device)``.
        device = self.user_features.weight.device
        user_ids = torch.as_tensor(self.user_df.loc[u].values,dtype=torch.long,device=device)
        item_ids = torch.as_tensor(self.item_df.loc[i].values,dtype=torch.long,device=device)

        user_features = self.user_features(user_ids)   # (batch, user_cols, dim)
        item_features = self.item_features(item_ids)   # (batch, item_cols, dim)

        # Stack along the per-row feature axis, then flatten each sample to
        # the dim * total_neighbours vector that dense1 expects.
        uv = torch.cat([user_features,item_features],dim=1)
        uv = uv.flatten(start_dim=1)

        uv = self.dense1(uv)
        uv = self.dense2(uv)
        uv = self.dense3(uv)

        # NOTE(review): dropout is applied to the final scalar score, and
        # dense3 ends in Tanh, so the sigmoid input is bounded to (-1, 1)
        # in eval mode. Both are kept as in the original design - confirm
        # they are intended.
        uv = F.dropout(uv,training=isTrain)

        # Drop only the trailing size-1 feature axis so that a batch of one
        # still yields a 1-D tensor.
        uv = uv.squeeze(-1)
        logit = self.sigmoid(uv)
        return logit