#### **Library imports**

In [5]:
# Library imports
import pyforest
import numpy as np
import pandas as pd
import os
from matplotlib import pyplot as plt
from tqdm import tqdm
from pprint import pprint
from time import sleep
import time

from turtle import forward
import torch.nn as nn
import torch.nn.functional as F
import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch.utils.data import DataLoader, TensorDataset

#### **Hyperparams and loading data**

In [2]:
train_edges = np.load('data/train_edges.npy')
users = torch.LongTensor(train_edges[:, 0])
items = torch.LongTensor(train_edges[:, 1])
ratings = torch.FloatTensor(train_edges[:, 2])

n_users = 943 
n_items = 1682
n_samples = len(ratings)

#### **Defining collaborative filtering**

In [3]:
class CollaborativeFiltering(Module):
    def __init__(self, n_users, n_items, n_factors):
        super(CollaborativeFiltering, self).__init__()
        self.user_emb = nn.Embedding(n_users, n_factors)
        self.item_emb = nn.Embedding(n_items, n_factors)

    def forward(self, user, item):
        u = self.user_emb(user)
        i = self.item_emb(item)
        dot = (u * i).sum(1)
        return torch.sigmoid(dot)

#### **Code for meta attack**

In [23]:
# start execution
start_time = time.time()

# some hyperparams
lr = 1
T = 10
Delta = 10
n_factors = 64

# store loss results in this list and later convert to dataframe 
results = []

# list of perturbations
perturbations = dict()
perturbations['edges'] = []
perturbations['metagrad'] = []

# print hyperparam config
print('-> Learning rate: ', lr)
print('-> T: ', T)
print('-> Delta: ', Delta)
print('-> Embedding size: ', n_factors)

for delta in range(Delta):
    # reload the users, items and ratings tensors
    users = torch.LongTensor(train_edges[:, 0])
    items = torch.LongTensor(train_edges[:, 1])
    ratings = torch.FloatTensor(train_edges[:, 2])

    # add those perturbations to "ratings"
    for index in perturbations['edges']:
        ratings[index] = 1

    # set requires_grad for ratings, to compute meta gradients
    ratings.requires_grad_()

    # makes code reproducible
    torch.manual_seed(0)

    # define model and loss
    model = CollaborativeFiltering(n_users, n_items, n_factors)
    p1, p2 = model.parameters()
    loss_fn = nn.BCELoss(reduction = 'mean')
    model.train()

    # inner loop training process
    for i in range(T):
        y_hat = model(users, items)
        loss = loss_fn(y_hat, ratings)
        results.append([delta, i, loss.item()])
        
        p1_grad = torch.autograd.grad(loss, p1, create_graph=True)
        p2_grad = torch.autograd.grad(loss, p2, create_graph=True)

        # compute inner parameter gradients
        with torch.no_grad():
            p1_new = p1 - lr * p1_grad[0]
            p2_new = p2 - lr * p2_grad[0]
            p1.copy_(p1_new)
            p2.copy_(p2_new)
    
    # compute meta gradient
    meta_grad = torch.autograd.grad(loss, ratings)[0]

    # select best edge to perturb
    max_meta_grad = -math.inf
    edge_to_add = -1
    for i in range(n_samples):
        if ratings[i] == 0: # search over only negative edges
            if meta_grad[i] > max_meta_grad:
                max_meta_grad = meta_grad[i]
                edge_to_add = i 

    perturbations['edges'].append(edge_to_add)
    perturbations['metagrad'].append(max_meta_grad.item())

# compute execution time
exec_time = int(time.time() - start_time)
exec_time = time.strftime("%Hh %Mm %Ss", time.gmtime(exec_time))
print('-> Execution time: {}'.format(exec_time))

# store results in CSV files
results = pd.DataFrame(results, columns = ['perturbs', 'iters', 'loss'])
results.to_csv('results/losses_Delta={}_T={}_LR={}_Factors={}'.format(Delta, T, lr, n_factors))

perturbations = pd.DataFrame(perturbations)
perturbations.to_csv('results/perturbations_Delta={}_T={}_LR={}_Factors={}'.format(Delta, T, lr, n_factors))

-> Learning rate:  1
-> T:  10
-> Delta:  10
-> Embedding size:  64
-> Execution time: 00h 00m 20s


#### **Checking stored results**

In [28]:
perturbations.head()

Unnamed: 0,edges,metagrad
0,141472,0.000193
1,152311,0.000175
2,173390,0.000174
3,5665,0.000164
4,112886,0.000164


In [29]:
results.head()

Unnamed: 0,perturbs,iters,loss
0,0,0,4.066596
1,0,1,4.065507
2,0,2,4.064427
3,0,3,4.062918
4,0,4,4.061419
