### Import Packages

In [1]:
import pandas as pd
import numpy as np
from scipy.sparse import load_npz
from scipy import sparse
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader
import torch
import torch.nn.functional as F

### Read data

In [2]:
df_val = pd.read_csv('df_val.csv')

In [3]:
R = load_npz('R_train.npz')

In [4]:
n_users,n_movies = R.shape

### Constructing accuracy function

In [5]:
one_idxs=[]
minus_one_idxs=[]
hold_out=[]
hold_out_minus=[]

In [6]:
for i in range(n_users):
    one_idxs.append(np.where(R[[i],:].toarray()[0] == 1)[0])
    minus_one_idxs.append(np.where(R[[i],:].toarray()[0] == -1)[0])
    hold_out.append(df_val.query(f"userId=={i} & rating==1").movieId.values)
    hold_out_minus.append(df_val.query(f"userId=={i} & rating==-1").movieId.values)

In [7]:
def accuracy_func(model,k=10):
    B = model.linear.weight.detach().to('cpu').numpy()
    x = np.float32(R.toarray())
    S = np.matmul(x,B)
    accuracy = []
    for i in range(n_users):
        output = S[i]
        np.put(output,one_idxs[i],-np.inf)
        np.put(output,minus_one_idxs[i],-np.inf)
        c = len(np.intersect1d(np.argsort(output)[::-1][:k],hold_out[i]))
        nc = len(np.intersect1d(np.argsort(output)[::-1][:k],hold_out_minus[i]))
#         acc = np.max([0,(c-nc)/(np.min([k,len(hold_out[i])+1]))]) ## Recal@K
        acc = np.max([0,(c-nc)/k]) ## HR@K
        accuracy.append(acc)
    return np.mean(accuracy)

### Preparing dataset for pytorch model

In [8]:
class MovieDataset(Dataset):
    def __init__(self,utility_matrix):
        self.utility  = utility_matrix
        
    def __len__(self):
        return self.utility.shape[0]
    def __getitem__(self,idx):
        
        user_vector = self.utility[[idx],:].toarray()[0]
        user_vector = torch.tensor(user_vector,dtype=torch.float32)
        
        return user_vector

In [9]:
train_dataset = MovieDataset(R)

In [10]:
BATCH_SIZE = 512
train_loader= DataLoader(dataset=train_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=True,
                         )

### Pytorch model

In [11]:
class EASE(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
                              
        self.linear = nn.Linear(input_dim,input_dim,bias=False)
        nn.init.xavier_normal_(self.linear.weight)
        
        self.linear.weight.detach().fill_diagonal_(0) 
        
        
    def forward(self,x):
        x_rec = self.linear(x)
        
        return x_rec

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
model = EASE(n_movies).to(device)

In [14]:
optimizer = optim.Adam(model.parameters(),lr=0.001,weight_decay=1/5500)

In [15]:
loss_func = nn.MSELoss()

### Training Loop

In [16]:
epochs = 10
for epoch in range(epochs):
    model.train()
    train_losses = []
    for i,x in enumerate(train_loader):
        x = x.to(device).to(torch.float32)
        x_rec = model(x)

        cost = loss_func(x_rec,x)
        optimizer.zero_grad()
        ### for zero constrain on diagonal, we simply zero out gradients for diagonal parameters on each gradient update
        grads = torch.autograd.grad(cost,list(model.parameters())[0],retain_graph=True) 
        grads[0].fill_diagonal_(0)
        list(model.parameters())[0].backward(gradient=grads[0])
        optimizer.step()

        train_losses.append(cost.item())
    model.eval()
    acc = accuracy_func(model)
    

    print(f"Epoch {epoch + 1},train loss: {torch.tensor(train_losses).mean():.4f}, val accuracy:{acc:.4f}")


Epoch 1,train loss: 0.0522, val accuracy:0.0962
Epoch 2,train loss: 0.0329, val accuracy:0.1940
Epoch 3,train loss: 0.0273, val accuracy:0.2600
Epoch 4,train loss: 0.0255, val accuracy:0.2963
Epoch 5,train loss: 0.0249, val accuracy:0.3148
Epoch 6,train loss: 0.0248, val accuracy:0.3225
Epoch 7,train loss: 0.0247, val accuracy:0.3271
Epoch 8,train loss: 0.0247, val accuracy:0.3281
Epoch 9,train loss: 0.0249, val accuracy:0.3277
Epoch 10,train loss: 0.0248, val accuracy:0.3293
