### Imports

In [None]:
# import the pytorch library into environment and check its version
import os
import torch
import numpy as np
print("Using torch", torch.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.1+cu118.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git

### Data Preprocessing

In [None]:
from torch_geometric.data import Data, NeighborSampler
from torch_geometric.utils import train_test_split_edges
from torch_geometric.utils import to_networkx
import re

#### Metadata Preprocessing

In [None]:
# Read the raw item metadata: one record per line, whitespace-trimmed, blank lines dropped.
file_path = "metadata.txt"
with open(file_path, 'r') as file:
    metadata = [record for record in map(str.strip, file) if record]
len(metadata)

In [None]:
# Build the lookup from original item id to the remapped integer id.
# The file has a header row followed by "<org_id> <remap_id>" lines.
import csv
file_path = 'item_list.txt'
org_remap_dict = {}
with open(file_path, 'r') as file:
    next(file, None)  # skip the header row; the default avoids StopIteration on an empty file
    for line in file:
        fields = line.split()
        if not fields:
            continue  # ignore blank lines (e.g. a trailing newline at the end of the file)
        org_id, remap_id = fields
        org_remap_dict[org_id] = int(remap_id)
len(org_remap_dict)

In [None]:
# Extract title, description and price from each metadata record, keyed by remapped item id.
descriptions = {}
titles = {}
prices = {}
for line in metadata:
    item_id = re.search(r"'asin'\s*:\s*'([^']+)'", line).group(1)
    new_id = org_remap_dict[item_id]

    # Price: only values written with a decimal point are matched by the pattern.
    price = re.search(r"'price': (\d+\.\d+)", line)
    if price is not None:
        price = float(price.group(1))
        prices[new_id] = price

    # Title: text up to the next single quote, capped at 300 characters.
    title = re.search(r"'title': '([^']*)'", line)
    if title is not None:
        title = title.group(1)[:300]
        titles[new_id] = title

    # Description: the raw text between the 'description' and 'title' keys,
    # trimmed and capped at 300 characters.
    descrip_pos = line.find("'description':")
    title_pos = line.find("'title':")
    if -1 not in (descrip_pos, title_pos):
        start = descrip_pos + len("'description':")
        description = line[start:title_pos].strip()[:300]
        descriptions[new_id] = description
for parsed in (descriptions, titles, prices):
    print(len(parsed))

In [None]:
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')


def embed_texts(texts, tokenizer, model, max_chars=200):
    """Embed every text with BERT and return its [CLS] vector.

    Args:
        texts: dict mapping an item id to its text.
        tokenizer: tokenizer called as ``tokenizer(text, return_tensors="pt")``.
        model: encoder whose output exposes ``last_hidden_state``.
        max_chars: every text is cut to this many characters before tokenizing.

    Returns:
        dict mapping each item id to the [CLS] embedding as a nested list
        (one row per input sentence, i.e. a list containing a single vector).
    """
    embeddings = {}
    # Inference only: no autograd graph is needed for feature extraction.
    with torch.no_grad():
        for key, sentence in texts.items():
            # Slice unconditionally so short texts are embedded as themselves
            # (slicing a shorter string returns it unchanged).
            small = sentence[:max_chars]
            inputs = tokenizer(small, return_tensors="pt")
            outputs = model(**inputs)
            cls_embedding = outputs.last_hidden_state[:, 0, :]
            embeddings[key] = cls_embedding.tolist()
    return embeddings


embed_descriptions = embed_texts(descriptions, tokenizer, model)

# with open('descriptions.pkl', 'wb') as file:
#     pickle.dump(embed_descriptions, file)

embed_titles = embed_texts(titles, tokenizer, model)

# with open('titles.pkl', 'wb') as file:
#     pickle.dump(embed_titles, file)

In [None]:
import pickle
# Replace the in-memory embeddings with the copies cached on disk.
file_path = 'descriptions.pkl'
file_path2 = 'titles.pkl'
with open(file_path, 'rb') as file:
    embed_descriptions = data = pickle.load(file)
with open(file_path2, 'rb') as file2:
    embed_titles = data2 = pickle.load(file2)

In [None]:
# Mean of every feature over the items that have it (price, title embedding, description embedding).
from statistics import mean
descrip_array = np.array(list(embed_descriptions.values()))
title_array = np.array(list(embed_titles.values()))
mean_descrip = descrip_array.mean(axis=0)
mean_title = title_array.mean(axis=0)
mean_value = mean(prices.values())
mean_price = np.array([mean_value])

In [None]:
# gets features for each node
NUM_ITEMS = 91599  # number of item nodes; every feature matrix has one row per item


def impute_with_mean(mean_vector, known_values, num_rows=NUM_ITEMS):
    """Build a (num_rows, dim) feature matrix: mean everywhere, real value where known.

    Args:
        mean_vector: NumPy array tiled ``num_rows`` times as the default row.
        known_values: dict mapping a row index to that row's actual value.
        num_rows: number of rows of the returned matrix.

    Returns:
        float32 tensor. ``torch.tensor`` on a NumPy float64 array would give float64,
        while torch.nn layers default to float32, so the dtype is set explicitly.
    """
    features = torch.tensor(np.tile(mean_vector, (num_rows, 1)), dtype=torch.float32)
    for index, value in known_values.items():
        features[index] = torch.tensor(value, dtype=torch.float32)
    return features


# Known prices are written too; otherwise the price column is the same mean for every item.
mean_price_tensor = impute_with_mean(mean_price, prices)
mean_title_tensor = impute_with_mean(mean_title, embed_titles)
mean_descrip_tensor = impute_with_mean(mean_descrip, embed_descriptions)

In [None]:
# INITIAL NODE EMBEDDING: per item, [title features | description features | price]
feature_blocks = [mean_title_tensor, mean_descrip_tensor, mean_price_tensor]
X = torch.cat(feature_blocks, dim=1)
# torch.save(X, 'X.pt')

#### Training Preprocessing

In [None]:
user_train_edge_index = []
item_train_edge_source = []
item_train_edge_target = []
seen_edges = set()

# Every line of train.txt is "<user> <item> <item> ..."; one (user, item) pair is emitted per item.
# The item-item edge lists above stay empty here: those edges are read back from
# 'item_train_edges.pt' in a later cell instead of being rebuilt.
file_path = 'train.txt'
with open(file_path, 'r') as file:
    for line in file:
        values = [int(token) for token in line.split()]
        user, items = values[0], values[1:]
        user_train_edge_index.extend((user, item) for item in items)


In [None]:
# Turn the list of (user, item) pairs into a 2 x num_edges long tensor and persist it.
user_train_edge_index_np = np.asarray(user_train_edge_index, dtype=np.int64).transpose()
user_train_edge_index = torch.tensor(user_train_edge_index_np, dtype=torch.long).contiguous()
torch.save(user_train_edge_index, 'user_train_edge_index.pt')

# item_train_edge_index = torch.tensor([item_train_edge_source, item_train_edge_target], dtype=torch.long)

# torch.save(item_train_edge_index, 'item_train_edges.pt')
# The item-item training edges are read back from disk rather than rebuilt.
item_train_edge_index = torch.load('item_train_edges.pt')

In [None]:
user_train_edge_index = torch.load('user_train_edge_index.pt')
user_train_edge_index.shape

#### Test Preprocessing

In [None]:
user_test_edge_index = []
item_test_edge_source = []
item_test_edge_target = []
seen_edges = set()

# Every line of test.txt is "<user> <item> <item> ...".
file_path = 'test.txt'
with open(file_path, 'r') as file:
    for line in file:
        values = [int(token) for token in line.split()]
        user, items = values[0], values[1:]
        # Link every pair of items listed for the same user; each pair is stored once,
        # with the smaller id as the source.
        for i, item1 in enumerate(items):
            for item2 in items[i+1:]:
                pair = (item1, item2) if item1 < item2 else (item2, item1)
                if pair in seen_edges:
                    continue
                seen_edges.add(pair)
                item_test_edge_source.append(pair[0])
                item_test_edge_target.append(pair[1])
        user_test_edge_index.extend((user, item) for item in items)


In [None]:
# The commented lines below show how the test edge tensors were built and saved;
# only the reload of the item-item test edges is executed.
# user_test_edge_index_np = np.array(user_test_edge_index, dtype=np.int64).T
# user_test_edge_index = torch.tensor(user_test_edge_index_np, dtype=torch.long).contiguous()

# item_test_edge_index = torch.tensor([item_test_edge_source, item_test_edge_target], dtype=torch.long)

# torch.save(item_test_edge_index, 'item_test_edges.pt')
item_test_edge_index = torch.load('item_test_edges.pt')

#### Graph Dataset

In [None]:
# split between training message edges, training supervision edges, and validation edges

# item_train_split = Data(edge_index=item_train_edge_index)
# # user_train_split = Data(edge_index = user_train_edge_index)
# item_train_message, item_train_super, item_val = train_test_split_edges(item_train_split, val_ratio = 0.1, test_ratio = 0.1)
# # user_train_message_index, user_train_super_index, user_val_index = train_test_split_edges(user_train_split, val_ratio = 0.1, test_ratio = 0.1)
# item_train_message_index = item_train_message.edge_index
# item_train_super_index = item_train_super.edge_index
# item_val_index = item_val.edge_index

# num_edges = item_train_edge_index.size(1)
# edge_indices = np.arange(num_edges)
# np.random.shuffle(edge_indices)
# split1_size = int(0.1 * num_edges)
# split2_size = int(0.1 * num_edges)
# split3_size = num_edges - split1_size - split2_size
# split1_indices = edge_indices[:split1_size]
# split2_indices = edge_indices[split1_size:split1_size + split2_size]
# split3_indices = edge_indices[split1_size + split2_size:]

# item_train_message_index = item_train_edge_index[:, split1_indices]
# item_train_super_index = item_train_edge_index[:, split2_indices]
# item_val_index = item_train_edge_index[:, split3_indices]

# torch.save(item_train_message_index, 'item_train_message_index.pt')
# torch.save(item_train_super_index, 'item_train_super_index.pt')
# torch.save(item_val_index, 'item_val_index.pt')



In [None]:
# import numpy as np

# item_train_super_index = torch.load('item_train_super_index.pt')
# num_columns = item_train_super_index.shape[1]
# num_indices_to_select = 150000
# indices = np.random.choice(num_columns, size=num_indices_to_select, replace=False)
# item_train_super_index_random = item_train_super_index[:, indices]
# torch.save(item_train_super_index_random,'item_train_super_index_random.pt')

In [None]:
X = torch.load('X.pt')
# item_test_edge_index = torch.load('item_test_edges.pt')
# item_train_edge_index = torch.load('item_train_edges.pt')
# The three edge splits below are used by the following cells, so they have to be
# loaded for the notebook to run top to bottom on a fresh kernel.
item_train_message_index = torch.load('item_train_message_index.pt')
item_train_super_index = torch.load('item_train_super_index.pt')
item_val_index = torch.load('item_val_index.pt')

In [None]:
from torch_geometric.utils import to_undirected
# item_train_message_index = to_undirected(item_train_message_index)
# item_train_super2_index = to_undirected(item_train_super_index)

# Validation message passing uses the training message + supervision edges;
# test message passing additionally uses the (undirected) validation edges.
item_val_message_index = torch.cat([item_train_message_index, item_train_super_index], dim=1)
item_val_2_index = to_undirected(item_val_index)
item_test_message_index = torch.cat([item_val_message_index, item_val_2_index], dim=1)

# user_val_message_index = torch.cat((user_train_message_index, user_train_super_index), dim=1)

# torch.save(item_train_message_index, 'item_train_message_index.pt')

In [None]:
import numpy as np

# Keep a random subset of 1000 test edges (columns drawn without replacement).
item_test_edge_index = torch.load('item_test_edges.pt')
num_columns = item_test_edge_index.size(1)
num_indices_to_select = 1000
indices = np.random.choice(num_columns, num_indices_to_select, replace=False)
item_test_edge_index_random = item_test_edge_index[:, indices]

In [None]:
# Later cells read item_train_data and item_val_data, so all three splits are built here.
# The training supervision edges are the random subset stored in
# 'item_train_super_index_random.pt' (presumably written by the commented-out
# subsampling cell above; verify the file exists).
item_train_super_index_random = torch.load('item_train_super_index_random.pt')

# Each split: node features X, message-passing edges in edge_index, edges to score in edge_label_index.
item_train_data = Data(x = X, edge_index=item_train_message_index, y = None, edge_label_index = item_train_super_index_random).to(device)
item_val_data = Data(x = X, edge_index=item_val_message_index, y = None, edge_label_index = item_val_index)
item_test_data = Data(x = X, edge_index=item_test_message_index, y = None, edge_label_index = item_test_edge_index_random)


In [None]:
print("Number of the nodes in training, validation and test data are", item_train_data.num_nodes, item_val_data.num_nodes, item_test_data.num_nodes)
print("Number of the edges in training, validation and test data are", item_train_data.num_edges, item_val_data.num_edges, item_test_data.num_edges)
print("Number of features:", item_train_data.num_features)

In [None]:
from torch_geometric.loader import LinkNeighborLoader

# Mini-batches of 150 supervision edges, one sampled negative per positive,
# with 2 neighbours sampled per node in each of the 3 sampling iterations.
neighbor_loader = LinkNeighborLoader(
    item_train_data,
    num_neighbors=[2] * 3,
    edge_label_index=item_train_data.edge_label_index,
    neg_sampling_ratio=1.0,
    batch_size=150,
    shuffle=True,
    subgraph_type='bidirectional',
)


In [None]:
sampled_data = next(iter(neighbor_loader))
print(sampled_data)

In [None]:
len(neighbor_loader)

In [None]:
!nvidia-smi

### Model with Attention Diffusion

#### Model Architecture (Edge attributes not included)

In [None]:
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv

class GATModel(torch.nn.Module):
    """Three stacked GATv2Conv layers with ReLU and dropout between them.

    The first two layers use ``num_heads`` attention heads; the next layer's input width
    is ``hidden_channels * num_heads``. The last layer uses a single head and emits
    ``out_channels`` features per node.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads = 8, dropout=0.3):
        super().__init__()
        widened = hidden_channels * num_heads
        self.gat1 = GATv2Conv(in_channels, hidden_channels, heads=num_heads)
        self.gat2 = GATv2Conv(widened, hidden_channels, heads=num_heads)
        self.gat3 = GATv2Conv(widened, out_channels, heads=1)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x, edge_index):
        """Return one embedding per node for features ``x`` and graph ``edge_index``."""
        # nn.Dropout is already the identity in eval mode, so no explicit
        # self.training check is needed around it.
        hidden = self.dropout(x)
        hidden = self.relu(self.gat1(hidden, edge_index))
        hidden = self.dropout(hidden)
        hidden = self.relu(self.gat2(hidden, edge_index))
        hidden = self.dropout(hidden)
        return self.gat3(hidden, edge_index)

In [None]:
model = GATModel(in_channels = item_train_data.num_features, hidden_channels=256, out_channels=128).to(device)
model

In [None]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.00005)

#### Helper Functions

##### Similarity (inner product)

In [None]:
def compute_similarity(node_embs, edge_index):
    """Inner product of the two endpoint embeddings of every edge.

    Args:
        node_embs: 2-D tensor with one embedding row per node.
        edge_index: tensor whose row 0 holds source node indices and row 1 target indices.

    Returns:
        Tensor of shape (num_edges, 1) with one score per edge.
    """
    source_embs = node_embs[edge_index[0], :]
    target_embs = node_embs[edge_index[1], :]
    return torch.sum(source_embs * target_embs, dim=1, keepdim=True)

##### Negative Sampling (Random vs Hard Negative)

In [None]:
from torch_geometric.utils import negative_sampling
import networkx as nx

## G_item_train = nx.Graph(user_train_edge_index)

# def pos_sample(edges, nodes, num_samples, batch):


def neg_sample(edges, nodes, num_samples, type = "random", batch = None):
    """Sample negative edges for link prediction.

    Args:
        edges: edge index of the existing edges (forwarded as ``edge_index``).
        nodes: number of nodes in the graph (forwarded as ``num_nodes``).
        num_samples: number of negatives requested (forwarded as ``num_neg_samples``).
        type: sampling strategy; only "random" is implemented.
        batch: currently unused.

    Returns:
        Whatever ``negative_sampling`` returns for the given arguments.

    Raises:
        ValueError: if ``type`` names a strategy that is not implemented. Failing here
            is clearer than returning None and breaking later in the caller.
    """
    if type != "random":
        # TODO: hard-negative sampling is not implemented yet.
        raise ValueError(f"unsupported negative sampling type: {type!r}")
    return negative_sampling(edge_index=edges, num_nodes=nodes, num_neg_samples=num_samples)
        
        

##### Loss Functions

In [None]:
loss_fn1 = torch.nn.BCEWithLogitsLoss()
margin = 1.0
loss_fn2 = torch.nn.MarginRankingLoss(margin=margin)

#### Model Training

In [None]:
def train(model, loader, optimizer, loss_fn):
    """Run one training epoch over ``loader`` and return the mean batch loss.

    Args:
        model: network called as ``model(x, edge_index)``; returns node embeddings.
        loader: iterable of mini-batches exposing ``x``, ``edge_index``,
            ``edge_label_index``, ``edge_label`` and ``to(device)``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(scores, labels)`` returning a scalar tensor.

    Returns:
        float: loss averaged over all batches of the epoch (0.0 if the loader is empty).
        A plain float can be formatted and plotted directly, unlike a tensor that is
        still attached to the autograd graph.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data in loader:
        # The batch must be on the same device as the model before the forward pass.
        data = data.to(device)
        optimizer.zero_grad()
        node_embs = model(data.x, data.edge_index)
        similarity = compute_similarity(node_embs, data.edge_label_index)
        loss = loss_fn(similarity, data.edge_label.view(-1, 1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)

#### Model Testing

In [None]:
from sklearn.metrics import roc_auc_score

@torch.no_grad()
def test(model, data):
    """Return the ROC AUC of the model's link scores on ``data``.

    ``data.edge_index`` drives message passing; the edges in ``data.edge_label_index``
    are scored (sigmoid of the inner product) and compared against ``data.edge_label``.
    """
    model.eval()
    node_embs = model(data.x, data.edge_index)
    scores = compute_similarity(node_embs, data.edge_label_index).view(-1).sigmoid()
    labels = data.edge_label.cpu().numpy()
    return roc_auc_score(labels, scores.cpu().numpy())

#### Model and Plots

In [None]:
epochs = 7

# Train the GAT encoder for `epochs` passes over the loader, then checkpoint its weights.
for epoch in range(1, epochs + 1):
    loss = train(model, neighbor_loader, optimizer, loss_fn1)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
trained_weights = model.state_dict()
torch.save(trained_weights, 'saved_model.pth')

In [None]:
model = GATModel(in_channels = item_train_data.num_features, hidden_channels=256, out_channels=128)
# map_location keeps the checkpoint loadable when the device it was saved from
# (e.g. a GPU) is not available in the current session.
state_dict = torch.load('saved_model.pth', map_location=device)
model.load_state_dict(state_dict)
model.to(device)

In [None]:
from torch_geometric.loader import LinkNeighborLoader

# Loader over the sampled test edges: batches of 50 with one negative per positive.
# Note that this rebinds `neighbor_loader`, which previously held the training loader.
neighbor_loader = LinkNeighborLoader(
    item_test_data,
    num_neighbors=[1] * 3,
    edge_label_index=item_test_data.edge_label_index,
    neg_sampling_ratio=1.0,
    batch_size=50,
    shuffle=True,
    subgraph_type='induced',
)
datas = next(iter(neighbor_loader))
print(datas)

In [None]:
datas = next(iter(neighbor_loader))
datas.to(device)
test_auc = test(model, datas)
test_auc

In [None]:
# Plotting function
def plot_curves(curves):
    """Plot the training loss and the validation/test metrics side by side.

    Args:
        curves: dict with keys "train", "valid" and "test", each a per-epoch sequence
            of numbers. Scalar tensors are accepted; every value is converted with
            ``float`` before plotting.
    """
    # Imported locally: the notebook uses `plt` but never imports matplotlib.
    import matplotlib.pyplot as plt

    train_values = [float(value) for value in curves["train"]]
    valid_values = [float(value) for value in curves["valid"]]
    test_values = [float(value) for value in curves["test"]]
    epochs = range(1, len(train_values) + 1)

    fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(10, 5))

    # Left panel: training loss
    loss_ax.plot(epochs, train_values, label='Training Loss')
    loss_ax.set_title('Training Loss over Epochs')
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_xticks(epochs)
    loss_ax.legend()

    # Right panel: validation and test metrics
    metric_ax.plot(epochs, valid_values, label='Validation Metric', color='orange')
    metric_ax.plot(epochs, test_values, label='Test Metric', color='green')
    metric_ax.set_title('Validation and Test Metrics over Epochs')
    metric_ax.set_xlabel('Epochs')
    metric_ax.set_ylabel('Metric')
    metric_ax.set_xticks(epochs)
    metric_ax.legend()

    fig.tight_layout()
    plt.show()

# curves
# Per-epoch history: training loss, validation AUC and test AUC.
train_curve = []
valid_curve = []
test_curve = []

# Running MODEL
epochs = 10

# Keeps the test AUC of the epoch with the best validation AUC.
# NOTE(review): `train` iterates over its second argument as a loader of mini-batches,
# but a Data object (item_train_data) is passed here — confirm which loader was intended
# (`neighbor_loader` has been rebound to the test loader by this point).
# NOTE(review): `test` reads `data.edge_label`, while item_val_data / item_test_data are
# constructed with x, edge_index, y and edge_label_index only — labels (and negative
# edges) need to be provided before these calls can work.
best_val_auc = final_test_auc = 0
for epoch in range(1, epochs + 1):
    loss = train(model, item_train_data, optimizer, loss_fn1)
    valid_auc = test(model, item_val_data)
    test_auc = test(model, item_test_data)
    if valid_auc > best_val_auc:
        best_val_auc = valid_auc
        final_test_auc = test_auc
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {valid_auc:.4f}, Test: {test_auc:.4f}')
    train_curve.append(loss)
    valid_curve.append(valid_auc)
    test_curve.append(test_auc)
    
curves = {"train": train_curve, "valid": valid_curve, "test": test_curve}
print('Best Validation Metric: {}'.format(best_val_auc))
print('Test Metric: {}'.format(final_test_auc))

# plot
plot_curves(curves)

### Model with Collaborative Filtering

#### Message Passing for User Embedding

In [None]:
import random
def randUser(numUsers, numSamples=100):
    """Draw random user ids uniformly, with replacement, from ``0 .. numUsers - 1``.

    Args:
        numUsers: number of users; ids are drawn from ``[0, numUsers - 1]``.
        numSamples: how many ids to draw (defaults to the previously hard-coded 100).

    Returns:
        1-D long tensor of length ``numSamples``.

    Raises:
        ValueError: raised by ``random.randint`` when ``numUsers`` is not positive.
    """
    random_numbers = [random.randint(0, numUsers - 1) for _ in range(numSamples)]
    # An explicit dtype keeps the result an index tensor even when no ids are drawn.
    tensor_users = torch.tensor(random_numbers, dtype=torch.long)
    return tensor_users

In [None]:
user_batch = randUser(52642)

In [None]:
item_train_edge_index = torch.load('item_train_message_index.pt')
X = torch.load('X.pt')

In [None]:
source_nodes = []
target_nodes = []

In [None]:
from collections import defaultdict
import random
import pickle
file_path1 = 'user_train_message.pkl'
file_path2 = 'user_train_super.pkl'
file_path3 = 'user_train_neg.pkl'
NUM_ITEMS = 91599  # size of the item id space negatives are drawn from
user_train_message = defaultdict(list)
user_train_supervision = defaultdict(list)
user_train_neg = defaultdict(list)

# Every line of train.txt is "<user> <item> <item> ...". Per user, 80% of the items become
# message-passing edges, the rest supervision positives, and one negative item is drawn
# per supervision positive.
file_path = 'train.txt'
with open(file_path, 'r') as file:
    for line in file:
        values = list(map(int, line.strip().split()))
        if not values:
            continue  # blank line
        user = values[0]
        items = values[1:]
        item_set = set(items)
        size_first_list = int(len(items) * 0.8)
        message_list = random.sample(items, size_first_list)
        super_list = list(item_set - set(message_list))

        # Negatives must not be items the user interacted with, and their number matches
        # the supervision positives so the two can be paired one-to-one in the ranking
        # loss. The cap guarantees the rejection loop below terminates.
        num_negatives = min(len(super_list), NUM_ITEMS - len(item_set))
        negatives = set()
        while len(negatives) < num_negatives:
            candidate = random.randrange(NUM_ITEMS)
            if candidate not in item_set:
                negatives.add(candidate)

        user_train_message[user] = message_list
        user_train_supervision[user] = super_list
        user_train_neg[user] = list(negatives)
with open(file_path1, 'wb') as file:
    pickle.dump(user_train_message, file)
with open(file_path2, 'wb') as file:
    pickle.dump(user_train_supervision, file)
with open(file_path3, 'wb') as file:
    pickle.dump(user_train_neg, file)

In [None]:
len(user_train_message)

In [None]:
from torch_geometric.loader import NeighborLoader
# Item graph on the compute device, served in batches of 128 seed nodes with
# 8 neighbours sampled per node in each of the 3 sampling iterations.
item_train_data = Data(x=X, edge_index=item_train_edge_index, y=None).to(device)
loader = NeighborLoader(item_train_data, num_neighbors=[8, 8, 8], batch_size=128)

In [None]:
print(len(loader))
sampled_data = next(iter(loader))
print(sampled_data)
# print(sampled_data.n_id)
# print(sampled_data.input_id)
print(sampled_data.n_id[sampled_data.input_id])

In [None]:
# from torch_geometric.nn import knn_graph
# edge_index_2nd_order = knn_graph(item_data.x, 10, loop=False)

In [None]:
model = GATModel(in_channels = X.shape[1], hidden_channels=256, out_channels=128).to(device)
# map_location keeps the checkpoint loadable when the device it was saved from
# (e.g. a GPU) is not available in the current session.
state_dict = torch.load('saved_model.pth', map_location=device)
model.load_state_dict(state_dict)

In [None]:
final_embeddings = torch.zeros((91599, 128))

In [None]:
from collections import defaultdict
# Inference: switch dropout off so the stored embeddings are deterministic.
model.eval()
with torch.no_grad():
    for batch_idx, data in enumerate(loader):
        data = data.to(device)
        item_embed = model(data.x, data.edge_index)
        # In a NeighborLoader batch the seed nodes come first: rows [0, batch_size) of the
        # output belong to them, and n_id translates local row positions to global node ids.
        num_seeds = data.batch_size
        seed_ids = data.n_id[:num_seeds].cpu()
        final_embeddings[seed_ids] = item_embed[:num_seeds].cpu()
        del item_embed
        print(batch_idx)

In [None]:
torch.save(final_embeddings, 'final_embeddings.pt')

In [None]:
item_embeddings = torch.load('final_embeddings.pt')

In [None]:
import pickle
file_path1 = 'user_train_message.pkl'
with open(file_path1, 'rb') as file:
    loaded_dict = pickle.load(file)
# Flatten {user: [items]} into two parallel lists: one (user, item) entry per interaction.
source_nodes = []
target_nodes = []
for user, items in loaded_dict.items():
    source_nodes.extend([user] * len(items))
    target_nodes.extend(items)

In [None]:
import torch
num_users = 52643
num_items = 91599
user_train_message_index = torch.tensor([source_nodes, target_nodes], dtype=torch.long)
# Dense user x item interaction matrix: 1 where the user has a message edge to the item.
# With num_users * num_items (about 4.8e9) entries this is a very large allocation.
adj = torch.zeros((num_users, num_items))
user_rows, item_cols = user_train_message_index[0], user_train_message_index[1]
adj[user_rows, item_cols] = 1

In [None]:
# NOTE(review): Tensor.to() returns a tensor instead of modifying `adj` in place, and the
# result is not assigned here, so `adj` itself is left unchanged by this cell (the
# returned tensor is only displayed as the cell output). Confirm whether a move to
# `device` was actually intended.
adj.to(device)

In [None]:
# Symmetric degree normalisation: adj[u, i] / (sqrt(deg(u)) * sqrt(deg(i))).
# Degrees of 0 are replaced by 1 so isolated users/items do not cause a division by zero.
user_degree = adj.sum(dim=1, keepdim=True).clamp(min=1)
item_degree = adj.sum(dim=0, keepdim=True).clamp(min=1)
adj_norm = adj / user_degree.sqrt() / item_degree.sqrt()
# adj_norm_3 = torch.matmul(torch.matmul(adj_norm, adj_norm.t()),adj_norm)
# adj_norm_5 = torch.matmul(torch.matmul(adj_norm_3, adj_nom.t()),adj_norm)

In [None]:
torch.save(adj_norm, 'adj_norm.pt')

In [None]:
# A user's embedding is the normalised aggregate of the embeddings of the items they interacted with.
user_embeddings = adj_norm @ item_embeddings
# user_embeddings_3 = torch.matmul(adj_norm_3, item_embeddings)
# user_embeddings_5 = torch.matmul(adj_norm_5, item_embeddings)
# final_user_embedding = torch.cat((user_embeddings, user_embeddings_3), dim = 1)
torch.save(user_embeddings, 'final_user_embedding.pt')

#### Model Architecture 

In [None]:
import pickle
# Inputs of the link predictor: per-user supervision positives and sampled negatives,
# plus the precomputed user and item embeddings.
file_path1 = 'user_train_super.pkl'
file_path2 = 'user_train_neg.pkl'
with open(file_path1, 'rb') as file:
    train_pos = pickle.load(file)
with open(file_path2, 'rb') as file:
    train_neg = pickle.load(file)
item_embeddings = torch.load('final_embeddings.pt')
user_embeddings = torch.load('final_user_embedding.pt')

In [None]:
import torch
class LinkPredict(torch.nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        in_channels: width of the input vectors.
        hidden_channels: width of the hidden layer.
        out_channels: width of the output (1 by default).
    """

    def __init__(self, in_channels, hidden_channels = 256, out_channels = 1):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_channels, hidden_channels)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x):
        """Map a batch of input vectors to a batch of ``out_channels``-wide outputs."""
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)


In [None]:
model = LinkPredict(in_channels = user_embeddings.shape[1] * 2).to(device)
model

In [None]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0025)

##### Loss Function (Bayesian Personalized Ranking)

In [None]:
from torch.nn.modules.loss import _Loss
class BPRLoss(_Loss):
    """Bayesian Personalized Ranking loss: mean of -log(sigmoid(positive - negative)).

    ``lambda_reg`` is stored on the instance, but no regularisation term is added to
    the returned value.
    """

    def __init__(self, lambda_reg: float = 0.0001, **kwargs):
        super().__init__(size_average=None, reduce=None, reduction="sum", **kwargs)
        self.lambda_reg = lambda_reg

    def forward(self, positives, negatives, parameters, num_users):
        """Return the BPR loss; ``parameters`` and ``num_users`` are accepted but not used."""
        score_gap = positives - negatives
        mean_log_prob = torch.nn.functional.logsigmoid(score_gap).mean()
        return torch.neg(mean_log_prob)
loss_fn = BPRLoss()

#### Model Training

In [None]:
def train(model, users, optimizer, user_embedding, item_embedding, train_pos, train_neg, loss_fn):
    """Run one optimisation step of the link predictor with a BPR objective.

    For every user, each positive and negative item embedding is concatenated with the
    user's embedding and scored by ``model``. The per-user loss is
    ``-mean(logsigmoid(pos_score - neg_score))``, and the sum over users is
    back-propagated in a single step.

    Args:
        model: scorer mapping a batch of ``[user | item]`` vectors to one score each.
        users: iterable of user ids.
        optimizer: optimizer over ``model``'s parameters.
        user_embedding: tensor indexed by user id.
        item_embedding: tensor indexed by item id.
        train_pos: mapping user id -> list of positive item ids.
        train_neg: mapping user id -> list of negative item ids, paired element-wise
            with the positives.
        loss_fn: accepted for interface compatibility; the loss is computed inline.

    Returns:
        Scalar tensor (detached) with the summed loss; zero if no user had both
        positives and negatives.
    """
    model.train()
    optimizer.zero_grad()
    # Score on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    loss = None
    for user in users:
        pos_items = train_pos[user]
        neg_items = train_neg[user]
        # The mean over an empty set of scores is NaN and would poison the summed loss.
        if len(pos_items) == 0 or len(neg_items) == 0:
            continue
        pos_embed = item_embedding[pos_items]
        neg_embed = item_embedding[neg_items]
        user_embed = user_embedding[user, :]
        pos_embeds = torch.cat((user_embed.expand(pos_embed.size(0), -1), pos_embed), dim = 1)
        neg_embeds = torch.cat((user_embed.expand(neg_embed.size(0), -1), neg_embed), dim = 1)
        pos_values = model(pos_embeds.to(target_device))
        neg_values = model(neg_embeds.to(target_device))

        user_loss = -(torch.nn.functional.logsigmoid(pos_values - neg_values).mean())
        loss = user_loss if loss is None else loss + user_loss

    if loss is None:
        # Nothing to optimise: no user contributed a positive/negative pair.
        return torch.zeros((), device=target_device)
    loss.backward()
    optimizer.step()
    return loss.detach()

#### Model Testing

In [None]:
# from sklearn.metrics import roc_auc_score

# @torch.no_grad()
# def test(model, data):
#     model.eval()
#     out = model(data.x, data.edge_index)  # use `edge_index` to perform message passing
#     out = compute_similarity(out, data.edge_label_index).view(-1).sigmoid()  # use `edge_label_index` to compute the loss
#     return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())

#### Model Run

In [None]:
# Running MODEL
# Each epoch performs one full-batch update over all 52643 users; only the loss is
# reported (validation/test evaluation is not wired in for this model).
epochs = 10
all_users = list(range(52643))
for epoch in range(1, epochs + 1):
    loss = train(model, all_users, optimizer, user_embeddings, item_embeddings, train_pos, train_neg, loss_fn)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    
# curves = {"train": train_curve, "valid": valid_curve, "test": test_curve}
# print('Best Validation Metric: {}'.format(best_val_auc)
# print('Test Metric: {}'.format(final_test_auc)

# # plot
# plot_curves(curves)