In [None]:
import torch
import pickle
from torch import nn, optim
from torch_geometric.nn import LGConv
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE

import utils
from models import (
    LightGCN,
    LightGCNPlus0,
    LightGCNPlus1,
    LightGCNPlus2,
    LightGCNPlus3
)

In [None]:
# Apply matplotlib's 'seaborn-v0_8' style sheet to the figures created after this cell.
plt.style.use('seaborn-v0_8')

In [None]:
# Dataset identifier: interpolated into the dataset path and the checkpoint
# paths, and passed to utils.training_routine.
DATASET_NAME = 'book_crossing'

In [None]:
# Unpickle the dataset file that belongs to DATASET_NAME.
# NOTE(review): pickle can execute arbitrary code while loading -- only open
# dataset files that come from a trusted source.
dataset_path = f'datasets/{DATASET_NAME}_dataset.bin'
with open(dataset_path, 'rb') as dataset_file:
    dataset = pickle.load(dataset_file)

In [None]:
# Pull the four entries this notebook works with out of the loaded dataset.
users_features, items_features, train_edge_index, val_edge_index = (
    dataset[key]
    for key in (
        'users_features',
        'items_features',
        'train_edge_index',
        'val_edge_index',
    )
)

In [None]:
# Hyper-parameters shared by the experiments in this notebook.
SEED = 42  # same value the t-SNE cell passes as random_state
# Seed PyTorch's global RNG so the random projections / embeddings drawn
# afterwards can be reproduced after "Restart Kernel -> Run All".
torch.manual_seed(SEED)

K = 20                    # passed to utils.training_routine; presumably the top-K cut-off of Recall@K / NDCG@K
LAMBDA = 1e-6             # passed to utils.training_routine; presumably a regularisation weight -- TODO confirm
BATCH_SIZE = 1024
# Number of full batches per epoch. Integer floor division avoids the round trip
# through float that int(a / b) makes.
# NOTE(review): this is 0 when there are fewer than BATCH_SIZE training edges.
N_BATCH = train_edge_index.shape[1] // BATCH_SIZE
N_EPOCHS = 30
EMBEDDING_DIMENSION = 64

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Move the edge indices and the feature matrices onto the selected device.
train_edge_index, val_edge_index, users_features, items_features = (
    tensor.to(device)
    for tensor in (train_edge_index, val_edge_index, users_features, items_features)
)

In [None]:
# Row counts of the two feature matrices; used below as the sizes of the
# embedding tables (assumes one feature row per user / item -- TODO confirm).
num_users, num_items = users_features.shape[0], items_features.shape[0]

In [None]:
# Random (uniform in [0, 1)) matrices of shape (n_features, EMBEDDING_DIMENSION),
# handed to the LightGCN+ models stored in MODELS.
# They are allocated directly on `device`, i.e. on the same device the feature
# matrices were moved to, instead of being left on the CPU.
USER_PROJ = torch.rand(users_features.shape[1], EMBEDDING_DIMENSION, device=device)
ITEM_PROJ = torch.rand(items_features.shape[1], EMBEDDING_DIMENSION, device=device)

In [None]:
# One embedding table per node type, EMBEDDING_DIMENSION columns each.
EMB_USERS = nn.Embedding(num_users, EMBEDDING_DIMENSION)
EMB_ITEMS = nn.Embedding(num_items, EMBEDDING_DIMENSION)

# Re-initialise both tables in place from a normal distribution with std 0.01.
for embedding in (EMB_USERS, EMB_ITEMS):
    nn.init.normal_(embedding.weight, std=0.01)

In [None]:
# Train LightGCNPlus0 five times, each time with freshly drawn random projection
# matrices, and collect whatever utils.training_routine returns for every run.
# NOTE(review): all five models are built from the same EMB_USERS / EMB_ITEMS
# modules. If LightGCNPlus0 keeps a reference to them (not visible from this
# notebook), every run after the first would start from embeddings that were
# already trained by the previous run -- confirm against models.py.
results_list = []
for i in range(5):
    # New uniform [0, 1) projections for this run (created on the CPU; the
    # model as a whole is moved to `device` right after construction).
    user_proj = torch.rand(users_features.shape[1], EMBEDDING_DIMENSION)
    item_proj = torch.rand(items_features.shape[1], EMBEDDING_DIMENSION)
    model = LightGCNPlus0(EMB_USERS, EMB_ITEMS, users_features, items_features, user_proj, item_proj).to(device)
    # The run index is embedded in the name string handed to the routine.
    result = utils.training_routine(
        model,
        f'LightGCN+ (solution 0) [{i}]',
        DATASET_NAME,
        train_edge_index,
        val_edge_index,
        N_EPOCHS,
        N_BATCH,
        BATCH_SIZE,
        LAMBDA,
        K
    )
    # The plotting cell reads result['val_recall'] and result['val_ndcg'].
    results_list.append(result)


In [None]:
%matplotlib inline

fig, ax = plt.subplots(1, 2, figsize=(12, 4))

for i, result in enumerate(results_list):
    label = f'LightGCN+ (solution 0) [run {i}]'
    ax[0].plot(result['val_recall'], label=label)
    ax[1].plot(result['val_ndcg'], label=label)
    ax[0].set_title('Validation Recall@20', fontweight='bold')
    ax[1].set_title('Validation NDCG@20', fontweight='bold')
    for i in range(2):
        ax[i].set_xlabel('epoch', fontweight='bold')
        ax[i].set_ylabel('value', fontweight='bold')
        ax[i].legend()
plt.tight_layout()

In [None]:
def _clone_embedding(template):
    """Return a new nn.Embedding with the same shape and weights as `template`.

    The copy owns its weight tensor, so loading a checkpoint into a model that
    was built from it cannot modify `template` or any other copy.
    """
    clone = nn.Embedding(template.num_embeddings, template.embedding_dim)
    clone.load_state_dict(template.state_dict())
    return clone


# LightGCN+ variants, in the order the comparison plots rely on.
_PLUS_VARIANTS = {
    'LightGCN+_scenario_0': LightGCNPlus0,
    'LightGCN+_scenario_1': LightGCNPlus1,
    'LightGCN+_scenario_2': LightGCNPlus2,
    'LightGCN+_scenario_3': LightGCNPlus3,
}

# Models to compare, keyed by the name that is also used in the checkpoint file
# names. 'LightGCN' must stay first: later cells index model_names[i + 1] to
# label the LightGCN+ variants.
#
# Every model gets its own copy of the embedding tables and projection
# matrices. NOTE(review): if the model classes keep references to the objects
# passed in (not visible from this notebook), sharing one EMB_USERS / USER_PROJ
# across all models would let each checkpoint load overwrite the weights of the
# models loaded before it.
#
# Each model is also moved to `device`, like the models in the training cell,
# because it is later called with tensors that live on `device`.
MODELS = {
    'LightGCN': LightGCN(
        _clone_embedding(EMB_USERS),
        _clone_embedding(EMB_ITEMS),
    ).to(device),
}
MODELS.update({
    name: model_cls(
        _clone_embedding(EMB_USERS),
        _clone_embedding(EMB_ITEMS),
        users_features,
        items_features,
        USER_PROJ.clone().to(device),
        ITEM_PROJ.clone().to(device),
    ).to(device)
    for name, model_cls in _PLUS_VARIANTS.items()
})

In [None]:
# Load the final checkpoint of every model in MODELS. The file name is derived
# from the dictionary key: models/final/<DATASET_NAME>_<model name>.bin
for model_name, model in MODELS.items():
    checkpoint_path = f'models/final/{DATASET_NAME}_{model_name}.bin'
    # map_location lets checkpoints that were saved from a GPU be loaded on a
    # CPU-only machine as well.
    # NOTE(review): weights_only=False unpickles arbitrary objects -- only load
    # checkpoints from a trusted source, or switch to weights_only=True if the
    # files contain plain state dicts.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(state_dict)

In [None]:
# 2-D t-SNE projection of the user and item embeddings produced by each model.
# embeddings_2d maps model name -> (user points, item points).
# NOTE(review): the forward pass is given val_edge_index (not train_edge_index)
# -- confirm that this is the intended graph for the visualisation.
embeddings_2d = {}

for model_name, model in MODELS.items():
    print(model_name)  # progress marker
    # Inference only: switch to eval mode and skip autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        # Call the module itself rather than .forward() so hooks are honoured.
        emb_users, emb_items = model(val_edge_index)
    emb_users = emb_users.detach().cpu().numpy()
    emb_items = emb_items.detach().cpu().numpy()
    # Users and items are embedded independently, each with its own estimator
    # configured identically (fixed random_state for repeatable layouts).
    emb_users_2d = TSNE(n_components=2, perplexity=30, random_state=42).fit_transform(emb_users)
    emb_items_2d = TSNE(n_components=2, perplexity=30, random_state=42).fit_transform(emb_items)
    embeddings_2d[model_name] = (emb_users_2d, emb_items_2d)



In [None]:
# Model names in the insertion order of MODELS; plot columns follow this order.
model_names = [*MODELS]

In [None]:
# Grid of t-SNE scatter plots: one column per model, top row = user
# embeddings, bottom row = item embeddings. The number of columns follows
# model_names instead of a hard-coded 5.
n_models = len(model_names)
fig, axs = plt.subplots(2, n_models, figsize=(14, 6), squeeze=False)

row_labels = ('user embeddings', 'item embeddings')
for col, name in enumerate(model_names):
    for row, row_label in enumerate(row_labels):
        # embeddings_2d[name] is (user points, item points); each is (n, 2).
        points = embeddings_2d[name][row]
        panel = axs[row, col]
        panel.scatter(points[:, 0], points[:, 1], s=0.5, alpha=0.7)
        panel.tick_params(axis='both', labelsize=10)
        if row == 0:
            # Column titles only on the top row.
            panel.set_title(name.replace('_', ' '), fontweight='bold', fontsize=12)
        if col == 0:
            # Row captions only on the left-most column.
            panel.set_ylabel(row_label, fontweight='bold', fontsize=12)

plt.tight_layout()

In [None]:
def _to_numpy(tensor):
    """Detach `tensor` from autograd, move it to the CPU and return a NumPy array."""
    return tensor.detach().cpu().numpy()


_plus0, _plus1, _plus2, _plus3 = (
    MODELS[f'LightGCN+_scenario_{scenario}'] for scenario in range(4)
)

# Effective (user, item) projection of each LightGCN+ variant, in scenario
# order 0..3. The list is built in a single statement so that re-running the
# cell always yields exactly four entries (appending from several cells
# duplicated entries on re-run and misaligned the plot titles).
projections_list = [
    # Scenario 0: the raw projection matrices.
    (
        _to_numpy(_plus0.users_features_proj),
        _to_numpy(_plus0.items_features_proj),
    ),
    # Scenario 1: projection matrices scaled by sigmoid(alpha_*).
    (
        _to_numpy(_plus1.users_features_proj * nn.functional.sigmoid(_plus1.alpha_users)),
        _to_numpy(_plus1.items_features_proj * nn.functional.sigmoid(_plus1.alpha_items)),
    ),
    # Scenario 2: projection matrices scaled by sigmoid(*_coefs_vector).
    (
        _to_numpy(_plus2.users_features_proj * nn.functional.sigmoid(_plus2.users_coefs_vector)),
        _to_numpy(_plus2.items_features_proj * nn.functional.sigmoid(_plus2.items_coefs_vector)),
    ),
    # Scenario 3: output of the model's user_proj / item_proj applied to the
    # feature matrices stored on the model.
    (
        _to_numpy(_plus3.user_proj(_plus3.users_features)),
        _to_numpy(_plus3.item_proj(_plus3.items_features)),
    ),
]

In [None]:
# Switch back to matplotlib's default style for the figures created from here on.
plt.style.use('default')

In [None]:
# Global minimum / maximum over all user projections and over all item
# projections. Each candidate list starts with the neutral element (+inf for a
# minimum, -inf for a maximum), so an empty projections_list leaves those values.
# NOTE(review): none of the plotting code visible in this notebook reads these
# four bounds.
user_lows, user_highs = [float('inf')], [float('-inf')]
item_lows, item_highs = [float('inf')], [float('-inf')]

for user_proj, item_proj in projections_list:
    user_lows.append(user_proj.min())
    user_highs.append(user_proj.max())
    item_lows.append(item_proj.min())
    item_highs.append(item_proj.max())

user_min, user_max = min(user_lows), max(user_highs)
item_min, item_max = min(item_lows), max(item_highs)

In [None]:
# Heat maps of the projection matrices: one column per entry of
# projections_list, top row = user projection, bottom row = item projection.
# Every panel gets its own colour bar. The number of columns follows
# projections_list instead of a hard-coded 4.
n_variants = len(projections_list)
fig, axs = plt.subplots(2, n_variants, figsize=(14, 6), squeeze=False)

row_labels = ('user projection matrix', 'item projection matrix')
for col, projections in enumerate(projections_list):
    for row, row_label in enumerate(row_labels):
        panel = axs[row, col]
        # Each entry is a (user, item) pair; matrices are drawn transposed.
        image = panel.imshow(projections[row].T, aspect='auto', cmap='cool')
        # Disable the grid on every panel (plt.grid only affects the current axes).
        panel.grid(False)
        if row == 0:
            # Offset by one: model_names[0] is the first entry of MODELS,
            # which has no entry in projections_list.
            panel.set_title(model_names[col + 1].replace('_', ' '), fontweight='bold', fontsize=12)
        if col == 0:
            panel.set_ylabel(row_label, fontweight='bold', fontsize=12)
        fig.colorbar(image, ax=panel, fraction=0.046, pad=0.04)

plt.tight_layout()

In [None]:
# Close-up of the first user projection in projections_list: transposed, and
# limited to the first 1000 columns of the transposed matrix.
first_user_projection = projections_list[0][0]
plt.imshow(first_user_projection.T[:, :1000], aspect='auto', cmap='viridis')