In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import random
import torch

from models.light_gcn import LightGCNStack
from utils.light_gcn_utils import bpr_loss, evaluate, build_user_item_interactions, get_positive_negative_ratings, recall_at_k, precision_at_k

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from utils.preprocess import load_dataset

# Load the dataset
dataset = 'goodbooks-10k'
users, items, train_ratings, test_ratings, items_features_tensor, user_features_tensor = load_dataset(dataset)

In [3]:
num_users = users['userid'].nunique()
num_items = items['itemid'].nunique()
print(f"num_users: {num_users}, num_items: {num_items}")

num_users: 2913, num_items: 5700


In [4]:
# Create edge index for bipartite graph for train set
train_user_ids = train_ratings['userid'].values
train_item_ids = train_ratings['itemid'].values + num_users 
train_edge_index = torch.tensor([train_user_ids, train_item_ids], dtype=torch.long)

# Create edge index for bipartite graph for test set
test_user_ids = test_ratings['userid'].values  
test_item_ids = test_ratings['itemid'].values + num_users  
test_edge_index = torch.tensor([test_user_ids, test_item_ids], dtype=torch.long)

  train_edge_index = torch.tensor([train_user_ids, train_item_ids], dtype=torch.long)


In [5]:
train_user_item_dict = build_user_item_interactions(train_ratings)
test_user_item_dict = build_user_item_interactions(test_ratings)

In [6]:
positive_threshold = 5
negative_threshold = 3

In [7]:
train_user_ratings = get_positive_negative_ratings(train_user_item_dict, positive_threshold, negative_threshold)
test_user_ratings = get_positive_negative_ratings(test_user_item_dict, positive_threshold, negative_threshold)

In [8]:
for i, user in enumerate(train_user_ratings):
    train_user_ratings[i] = (user[0], [item + num_users for item in user[1]], [item + num_users for item in user[2]])

for i, user in enumerate(test_user_ratings):
    test_user_ratings[i] = (user[0], [item + num_users for item in user[1]], [item + num_users for item in user[2]])

In [9]:
embedding_dim = 64
num_nodes = num_users + num_items
no_user_features = user_features_tensor.size(1)
no_item_features = items_features_tensor.size(1)

num_layers = 6
num_epochs = 50
learning_rate = 0.0005
k = 10

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

user_features_tensor = user_features_tensor.to(device)
items_features_tensor = items_features_tensor.to(device)
train_edge_index = train_edge_index.to(device)
test_edge_index = test_edge_index.to(device)

model = LightGCNStack(num_nodes, no_user_features, no_item_features, embedding_dim, num_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [11]:
embeddings = model(user_features_tensor, items_features_tensor, train_edge_index)
recall = recall_at_k(train_user_ratings, embeddings, k=k, device=device)
precision = precision_at_k(train_user_ratings, embeddings, k=k, device=device)

print("Base recall:", recall)
print("Base precision:", precision)

Base recall: 0.09321948188924047
Base precision: 0.22871322226160937


In [12]:
calc_metrics_every = 1

model.train()
for epoch in range(num_epochs):
    total_loss = 0
    num_batches = 0
    pbar = tqdm(train_user_ratings, desc=f'Epoch {epoch+1}/{num_epochs}')
    embeddings = model(user_features_tensor, items_features_tensor, train_edge_index)

    for user_id, pos_items, neg_items in pbar:
        no_sample = min(len(pos_items), len(neg_items))
        users = torch.tensor([user_id] * no_sample, dtype=torch.long).to(device)
        pos_samples = random.sample(pos_items, no_sample)
        pos_samples = torch.tensor(pos_samples, dtype=torch.long).to(device)
        neg_samples = random.sample(neg_items, no_sample)
        neg_samples = torch.tensor(neg_samples, dtype=torch.long).to(device)
        
        loss = bpr_loss(embeddings, users, pos_samples, neg_samples)
        total_loss += loss
        num_batches += 1
        avg_loss = total_loss.item() / num_batches

        pbar.set_postfix({'Avg Loss': f'{avg_loss:.4f}'})

    total_loss.backward()
    optimizer.step()
    
    if (epoch + 1) % calc_metrics_every == 0:
        recall = recall_at_k(train_user_ratings, embeddings, k=k, device=device)
        precision = precision_at_k(train_user_ratings, embeddings, k=k, device=device)
        avg_loss = total_loss / len(train_user_ratings)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Recall@{k}: {recall:.4f}, Precision@{k}: {precision:.4f}')
    else:
        avg_loss = total_loss / len(train_user_ratings)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}')
    

Epoch 1/50: 100%|██████████| 2830/2830 [00:05<00:00, 516.58it/s, Avg Loss=0.6936]


Epoch 1/50, Loss: 0.6936, Recall@10: 0.0932, Precision@10: 0.2287


Epoch 2/50: 100%|██████████| 2830/2830 [00:06<00:00, 442.55it/s, Avg Loss=0.6920]


Epoch 2/50, Loss: 0.6920, Recall@10: 0.0948, Precision@10: 0.2327


Epoch 3/50: 100%|██████████| 2830/2830 [00:05<00:00, 564.84it/s, Avg Loss=0.6915]


Epoch 3/50, Loss: 0.6915, Recall@10: 0.0957, Precision@10: 0.2349


Epoch 4/50: 100%|██████████| 2830/2830 [00:05<00:00, 524.38it/s, Avg Loss=0.6906]


Epoch 4/50, Loss: 0.6906, Recall@10: 0.0972, Precision@10: 0.2384


Epoch 5/50: 100%|██████████| 2830/2830 [00:05<00:00, 472.83it/s, Avg Loss=0.6895]


Epoch 5/50, Loss: 0.6895, Recall@10: 0.0985, Precision@10: 0.2417


Epoch 6/50: 100%|██████████| 2830/2830 [00:05<00:00, 538.66it/s, Avg Loss=0.6886]


Epoch 6/50, Loss: 0.6886, Recall@10: 0.0999, Precision@10: 0.2452


Epoch 7/50: 100%|██████████| 2830/2830 [00:05<00:00, 526.90it/s, Avg Loss=0.6874]


Epoch 7/50, Loss: 0.6874, Recall@10: 0.1006, Precision@10: 0.2469


Epoch 8/50: 100%|██████████| 2830/2830 [00:05<00:00, 545.30it/s, Avg Loss=0.6870]


Epoch 8/50, Loss: 0.6870, Recall@10: 0.1019, Precision@10: 0.2499


Epoch 9/50: 100%|██████████| 2830/2830 [00:05<00:00, 536.58it/s, Avg Loss=0.6851]


Epoch 9/50, Loss: 0.6851, Recall@10: 0.1029, Precision@10: 0.2525


Epoch 10/50: 100%|██████████| 2830/2830 [00:04<00:00, 567.89it/s, Avg Loss=0.6842]


Epoch 10/50, Loss: 0.6842, Recall@10: 0.1039, Precision@10: 0.2550


Epoch 11/50: 100%|██████████| 2830/2830 [00:05<00:00, 509.63it/s, Avg Loss=0.6833]


Epoch 11/50, Loss: 0.6833, Recall@10: 0.1045, Precision@10: 0.2563


Epoch 12/50: 100%|██████████| 2830/2830 [00:05<00:00, 559.33it/s, Avg Loss=0.6824]


Epoch 12/50, Loss: 0.6824, Recall@10: 0.1050, Precision@10: 0.2577


Epoch 13/50: 100%|██████████| 2830/2830 [00:04<00:00, 582.43it/s, Avg Loss=0.6806]


Epoch 13/50, Loss: 0.6806, Recall@10: 0.1052, Precision@10: 0.2582


Epoch 14/50: 100%|██████████| 2830/2830 [00:04<00:00, 568.35it/s, Avg Loss=0.6803]


Epoch 14/50, Loss: 0.6803, Recall@10: 0.1056, Precision@10: 0.2591


Epoch 15/50: 100%|██████████| 2830/2830 [00:04<00:00, 572.97it/s, Avg Loss=0.6787]


Epoch 15/50, Loss: 0.6787, Recall@10: 0.1054, Precision@10: 0.2587


Epoch 16/50: 100%|██████████| 2830/2830 [00:04<00:00, 577.45it/s, Avg Loss=0.6772]


Epoch 16/50, Loss: 0.6772, Recall@10: 0.1056, Precision@10: 0.2592


Epoch 17/50: 100%|██████████| 2830/2830 [00:05<00:00, 479.73it/s, Avg Loss=0.6776]


Epoch 17/50, Loss: 0.6776, Recall@10: 0.1062, Precision@10: 0.2605


Epoch 18/50: 100%|██████████| 2830/2830 [00:05<00:00, 541.55it/s, Avg Loss=0.6760]


Epoch 18/50, Loss: 0.6760, Recall@10: 0.1064, Precision@10: 0.2610


Epoch 19/50: 100%|██████████| 2830/2830 [00:05<00:00, 550.53it/s, Avg Loss=0.6759]


Epoch 19/50, Loss: 0.6759, Recall@10: 0.1064, Precision@10: 0.2610


Epoch 20/50: 100%|██████████| 2830/2830 [00:04<00:00, 591.11it/s, Avg Loss=0.6749]


Epoch 20/50, Loss: 0.6749, Recall@10: 0.1067, Precision@10: 0.2618


Epoch 21/50: 100%|██████████| 2830/2830 [00:06<00:00, 461.63it/s, Avg Loss=0.6731]


Epoch 21/50, Loss: 0.6731, Recall@10: 0.1068, Precision@10: 0.2620


Epoch 22/50: 100%|██████████| 2830/2830 [00:05<00:00, 542.31it/s, Avg Loss=0.6729]


Epoch 22/50, Loss: 0.6729, Recall@10: 0.1071, Precision@10: 0.2628


Epoch 23/50: 100%|██████████| 2830/2830 [00:06<00:00, 459.86it/s, Avg Loss=0.6736]


Epoch 23/50, Loss: 0.6736, Recall@10: 0.1073, Precision@10: 0.2634


Epoch 24/50: 100%|██████████| 2830/2830 [00:05<00:00, 474.49it/s, Avg Loss=0.6725]


Epoch 24/50, Loss: 0.6725, Recall@10: 0.1071, Precision@10: 0.2629


Epoch 25/50: 100%|██████████| 2830/2830 [00:05<00:00, 514.22it/s, Avg Loss=0.6718]


In [None]:
total_loss = 0
num_batches = 0
pbar = tqdm(test_user_ratings)

embeddings = model(user_features_tensor, items_features_tensor, test_edge_index)

for user_id, pos_items, neg_items in pbar:
    no_sample = min(len(pos_items), len(neg_items))
    users = torch.tensor([user_id] * no_sample, dtype=torch.long).to(device)
    pos_samples = random.sample(pos_items, no_sample)
    pos_samples = torch.tensor(pos_samples, dtype=torch.long).to(device)
    neg_samples = random.sample(neg_items, no_sample)
    neg_samples = torch.tensor(neg_samples, dtype=torch.long).to(device)
    loss = bpr_loss(embeddings, users, pos_samples, neg_samples)
    total_loss += loss
    num_batches += 1
    avg_loss = total_loss / num_batches

    # Update progress bar with average loss
    pbar.set_postfix({'Avg Loss': f'{avg_loss:.4f}'})
    
recall = recall_at_k(train_user_ratings, embeddings, k=k, device=device)
precision = precision_at_k(train_user_ratings, embeddings, k=k, device=device)
avg_loss = total_loss / len(test_user_ratings)
print(f'Test Loss: {avg_loss:.4f}, Test Recall@{k}: {recall:.4f}, Test Precision@{k}: {precision:.4f}')

100%|██████████| 2653/2653 [00:08<00:00, 326.42it/s, Avg Loss=0.6775]


Test Loss: 0.6775, Test Recall@10: 0.1136, Test Precision@10: 0.2788
