# **GCN-based Heterogeneous Graph Recommendation System**

In [1]:
!pip install -q torch torch-geometric

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.3/1.3 MB[0m [31m54.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m33.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [16]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import ast
from collections import defaultdict
import zipfile
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
import torch
from torch_geometric.data import HeteroData
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import HeteroConv, SAGEConv, RGCNConv, Linear, GraphConv
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    ndcg_score,
)
from typing import Dict, Any, List
from tqdm.auto import tqdm

sns.set_palette(sns.color_palette('CMRmap'))

path = 'graph_data'

with zipfile.ZipFile('graph_data.zip', 'r') as zf:
    zf.extractall(path)


## **Data Preparation**

Load all node features and edge lists for heterogeneous graph construction

In [3]:
df_connections_user_stats = pd.read_csv(f'{path}/user_stats.csv')
df_connections_movie_features = pd.read_csv(f'{path}/movie_features.csv')
df_connections_movie_genre = pd.read_csv(f'{path}/movie_genre_edges.csv')
df_connections_movie_actor = pd.read_csv(f'{path}/movie_actor_edges.csv')
df_connections_movie_director = pd.read_csv(f'{path}/movie_director_edges.csv')
df_connections_user_movie = pd.read_csv(f'{path}/user_movie_edges.csv')
df_connections_actors = pd.read_csv(f'{path}/actors.csv')
df_connections_directors = pd.read_csv(f'{path}/directors.csv')
df_connections_genres = pd.read_csv(f'{path}/genres.csv')


In [4]:
print(f'Movies: {df_connections_movie_features.shape}')
print(f'Users: {df_connections_user_stats.shape}')
print(f'Directors: {df_connections_directors.shape}')
print(f'Actors: {df_connections_actors.shape}')
print(f'Genres: {df_connections_genres.shape}')
print(f'Movie-genre links: {df_connections_movie_genre.shape}')
print(f'Movie-actor links: {df_connections_movie_actor.shape}')
print(f'Movie-director links: {df_connections_movie_director.shape}')
print(f'User-movie links: {df_connections_user_movie.shape}')


Movies: (16409, 4)
Users: (75680, 4)
Directors: (1741, 3)
Actors: (2517, 4)
Genres: (20, 2)
Movie-genre links: (38491, 2)
Movie-actor links: (48106, 4)
Movie-director links: (10570, 2)
User-movie links: (2505904, 3)


Create mapping dictionaries: original IDs -> consecutive indices for all node types

In [5]:
movie_id_to_idx = {movie_id: idx for idx, movie_id in enumerate(df_connections_movie_features['movie_id'].unique())}

user_id_to_idx = {user_id: idx for idx, user_id in enumerate(df_connections_user_stats['user_id'].unique())}

director_id_to_idx = {director_id: idx for idx, director_id in
                      enumerate(df_connections_directors['director_id'].unique())}

actor_id_to_idx = {actor_id: idx for idx, actor_id in enumerate(df_connections_actors['actor_id'].unique())}

genre_id_to_idx = {genre_id: idx for idx, genre_id in enumerate(df_connections_genres['genre_id'].unique())}

print(f'Number of movies: {len(movie_id_to_idx)}')
print(f'Number of users: {len(user_id_to_idx)}')
print(f'Number of directors: {len(director_id_to_idx)}')
print(f'Number of actors: {len(actor_id_to_idx)}')
print(f'Number of genres: {len(genre_id_to_idx)}')


Number of movies: 16409
Number of users: 75680
Number of directors: 1741
Number of actors: 2517
Number of genres: 20


## **Subgraph Construction Pipeline**

This function builds a heterogeneous graph subgraph centered around selected users, preparing data for GraphSAGE training.

### **Main Parameters**
- **`user_ids`**: Target users to build the subgraph around (central nodes)
- **`num_negatives`**: Number of negative samples for training (200 by default) - affects training data balance
- **DataFrames**: Feature tables and edge lists for all entity types

### **Pipeline Stages**

#### **1. Data Filtering & Propagation**
- **Step 1**: Filter user-related data based on `user_ids`
- **Step 2**: Propagate filtering to connected movies via user-movie edges
- **Step 3**: Further propagate to related entities (actors, directors, genres)
- **Result**: A connected subgraph containing only relevant entities

#### **2. Feature Engineering & Normalization**
- **Numerical features**: Standardized using `StandardScaler` for:
  - Movie metrics (`vote_average`, `vote_count`, `popularity`)
  - User statistics (`num_ratings`, `avg_rating`, `activity_days`)
  - Actor/director metrics (`total_films`, `avg_order`)
- **Categorical encoding**:
  - Gender (directors/actors): One-hot encoded (3 categories: 0,1,2)
  - Genres: One-hot encoded per unique genre ID

#### **3. ID Remapping & Graph Initialization**
- Creates sequential index mappings for each node type
- Initializes `HeteroData` object with:
  - `num_nodes` defined for each entity type
  - Maintains original ID to index mapping for reference

#### **4. Node Feature Assignment**
- For each entity type, extracts normalized features
- Maps features to corresponding node indices

#### **5. Edge Construction**
- Builds five relationship types with optional weights:
  1. `('user', 'rates', 'movie')` - with rating weights
  2. `('movie', 'has_genre', 'genre')`
  3. `('movie', 'has_director', 'director')`
  4. `('movie', 'has_actor', 'actor')` - with role weight

#### **6. Bidirectional Edge Addition**
- Creates reverse edges for message passing in both directions
- Reverse edges follow pattern: `('genre', 'rev_has_genre', 'movie')`

### **Output Structure**
Returns a tuple containing:
1. **`data`**: PyG HeteroData object
2. **`mappings`**: ID-to-index dictionaries for all entity types
3. **`df_user_movie_sub`**: Filtered user-movie interactions for training

In [6]:
def build_subgraph_for_users(
        user_ids,
        df_user_movie,
        df_movie_features,
        df_user_stats,
        df_movie_genre,
        df_movie_actor,
        df_movie_director,
        df_actors,
        df_directors,
        df_genres,
        num_negatives=200
):
    selected_users = set(user_ids)
    df_user_stats_sub = df_user_stats[df_user_stats['user_id'].isin(selected_users)].copy()

    df_user_movie_sub = df_user_movie[df_user_movie['user_id'].isin(selected_users)].copy()
    selected_movie_ids = set(df_user_movie_sub['movie_id'].unique())

    df_movie_features_sub = df_movie_features[df_movie_features['movie_id'].isin(selected_movie_ids)].copy()
    df_movie_genre_sub = df_movie_genre[df_movie_genre['movie_id'].isin(selected_movie_ids)].copy()
    df_movie_actor_sub = df_movie_actor[df_movie_actor['movie_id'].isin(selected_movie_ids)].copy()
    df_movie_director_sub = df_movie_director[df_movie_director['movie_id'].isin(selected_movie_ids)].copy()

    selected_genre_ids = set(df_movie_genre_sub['genre_id'].unique())
    selected_actor_ids = set(df_movie_actor_sub['actor_id'].unique())
    selected_director_ids = set(df_movie_director_sub['director_id'].unique())

    df_genres_sub = df_genres[df_genres['genre_id'].isin(selected_genre_ids)].copy()
    df_actors_sub = df_actors[df_actors['actor_id'].isin(selected_actor_ids)].copy()
    df_directors_sub = df_directors[df_directors['director_id'].isin(selected_director_ids)].copy()

    print(f'\nSizes of filtered data:')
    print(f'Users: {df_user_stats_sub.shape[0]}')
    print(f'Movies: {df_movie_features_sub.shape[0]}')
    print(f'User-movie links: {df_user_movie_sub.shape[0]}')
    print(f'Actors: {df_actors_sub.shape[0]}')
    print(f'Directors: {df_directors_sub.shape[0]}')
    print(f'Genres: {df_genres_sub.shape[0]}')

    scaler = StandardScaler()

    movie_num_cols = ['vote_average', 'vote_count', 'popularity']
    if len(df_movie_features_sub) > 0:
        df_movie_features_sub[movie_num_cols] = scaler.fit_transform(
            df_movie_features_sub[movie_num_cols].fillna(0)
        )

    user_num_cols = ['num_ratings', 'avg_rating', 'activity_days']
    if len(df_user_stats_sub) > 0:
        df_user_stats_sub[user_num_cols] = scaler.fit_transform(
            df_user_stats_sub[user_num_cols].fillna(0)
        )

    if len(df_directors_sub) > 0:
        df_directors_sub['total_films'] = scaler.fit_transform(
            df_directors_sub[['total_films']].fillna(0)
        )

    actor_num_cols = ['total_films', 'avg_order']
    if len(df_actors_sub) > 0:
        df_actors_sub[actor_num_cols] = scaler.fit_transform(
            df_actors_sub[actor_num_cols].fillna(0)
        )

    gender_encoder = OneHotEncoder(sparse_output=False, categories=[[0, 1, 2]], handle_unknown='ignore')
    if len(df_directors_sub) > 0:
        gender_dir_encoded = gender_encoder.fit_transform(df_directors_sub[['gender']])
        gender_dir_df = pd.DataFrame(gender_dir_encoded,
                                     columns=['gender_0', 'gender_1', 'gender_2'])
        df_directors_sub = pd.concat([df_directors_sub.reset_index(drop=True),
                                      gender_dir_df], axis=1)
        df_directors_sub = df_directors_sub.drop('gender', axis=1)

    if len(df_actors_sub) > 0:
        gender_act_encoded = gender_encoder.transform(df_actors_sub[['gender']])
        gender_act_df = pd.DataFrame(gender_act_encoded,
                                     columns=['gender_0', 'gender_1', 'gender_2'])
        df_actors_sub = pd.concat([df_actors_sub.reset_index(drop=True),
                                   gender_act_df], axis=1)
        df_actors_sub = df_actors_sub.drop('gender', axis=1)

    if len(df_genres_sub) > 0:
        genre_encoder = OneHotEncoder(sparse_output=False)
        genre_encoded = genre_encoder.fit_transform(df_genres_sub[['genre_id']])
        genre_cols = [f'genre_{int(i)}' for i in genre_encoder.categories_[0]]
        genre_df = pd.DataFrame(genre_encoded, columns=genre_cols)
        df_genres_sub = pd.concat([df_genres_sub.reset_index(drop=True),
                                   genre_df], axis=1)
        df_genres_sub = df_genres_sub.drop('genre_name', axis=1)

    mappings = {}
    mappings['user'] = {user_id: idx for idx, user_id in enumerate(df_user_stats_sub['user_id'].unique())}
    mappings['movie'] = {movie_id: idx for idx, movie_id in enumerate(df_movie_features_sub['movie_id'].unique())}
    mappings['actor'] = {actor_id: idx for idx, actor_id in enumerate(df_actors_sub['actor_id'].unique())}
    mappings['director'] = {director_id: idx for idx, director_id in
                            enumerate(df_directors_sub['director_id'].unique())}
    mappings['genre'] = {genre_id: idx for idx, genre_id in enumerate(df_genres_sub['genre_id'].unique())}

    data = HeteroData()
    data['user'].num_nodes = len(mappings['user'])
    data['movie'].num_nodes = len(mappings['movie'])
    data['actor'].num_nodes = len(mappings['actor'])
    data['director'].num_nodes = len(mappings['director'])
    data['genre'].num_nodes = len(mappings['genre'])

    def add_node_features(df, node_type, id_col, mapping, data_obj):
        if len(df) == 0:
            return

        features_list = []
        feature_cols = [col for col in df.columns if col != id_col]
        sorted_ids = sorted(mapping.keys(), key=lambda x: mapping[x])
        for node_id in sorted_ids:
            row = df[df[id_col] == node_id]
            if len(row) > 0:
                features = row[feature_cols].values[0].astype(np.float32)
                features_list.append(features)
            else:
                features_list.append(np.zeros(len(feature_cols), dtype=np.float32))
        data_obj[node_type].x = torch.tensor(np.array(features_list), dtype=torch.float32)

    add_node_features(df_user_stats_sub, 'user', 'user_id', mappings['user'], data)
    add_node_features(df_movie_features_sub, 'movie', 'movie_id', mappings['movie'], data)
    add_node_features(df_actors_sub, 'actor', 'actor_id', mappings['actor'], data)
    add_node_features(df_directors_sub, 'director', 'director_id', mappings['director'], data)
    add_node_features(df_genres_sub, 'genre', 'genre_id', mappings['genre'], data)

    def add_edges(df, source_type, target_type, source_col, target_col,
                  edge_type_name, data_obj, mappings_dict, weight_col=None):
        if len(df) == 0:
            return

        edges = []
        weights = [] if weight_col is not None else None
        for _, row in df.iterrows():
            source_idx = mappings_dict[source_type].get(row[source_col])
            target_idx = mappings_dict[target_type].get(row[target_col])
            if source_idx is not None and target_idx is not None:
                edges.append([source_idx, target_idx])
                if weight_col is not None:
                    weights.append(row[weight_col])

        if edges:
            edge_tensor = torch.tensor(edges, dtype=torch.long).t().contiguous()
            data_obj[source_type, edge_type_name, target_type].edge_index = edge_tensor
            if weights is not None:
                data_obj[source_type, edge_type_name, target_type].edge_weight = torch.tensor(
                    weights, dtype=torch.float32
                )

    add_edges(df_user_movie_sub, 'user', 'movie', 'user_id', 'movie_id',
              'rates', data, mappings, weight_col='rating')
    add_edges(df_movie_genre_sub, 'movie', 'genre', 'movie_id', 'genre_id',
              'has_genre', data, mappings)
    add_edges(df_movie_director_sub, 'movie', 'director', 'movie_id', 'director_id',
              'has_director', data, mappings)
    add_edges(df_movie_actor_sub, 'movie', 'actor', 'movie_id', 'actor_id',
              'has_actor', data, mappings, weight_col='weight')

    edge_types_to_reverse = [
        ('movie', 'has_genre', 'genre'),
        ('movie', 'has_director', 'director'),
        ('movie', 'has_actor', 'actor'),
        ('user', 'rates', 'movie')
    ]

    for src, rel, dst in edge_types_to_reverse:
        if hasattr(data[src, rel, dst], 'edge_index'):
            reverse_edges = data[src, rel, dst].edge_index[[1, 0]]
            rev_rel = f'rev_{rel}'
            data[dst, rev_rel, src].edge_index = reverse_edges
            if hasattr(data[src, rel, dst], 'edge_weight'):
                data[dst, rev_rel, src].edge_weight = data[src, rel, dst].edge_weight

    print(f'\nConstructed graph with the following structure:')
    print(f'Nodes: {data.node_types}')
    print(f'Edges: {data.edge_types}')
    print(f'\nChecking node features:')

    for node_type in data.node_types:
        if hasattr(data[node_type], 'x'):
            print(f'  {node_type}: {data[node_type].x.shape}')

    print(f'\nChecking edges:')
    for edge_type in data.edge_types:
        if hasattr(data[edge_type], 'edge_index'):
            print(f'  {edge_type}: {data[edge_type].edge_index.shape[1]} edges')
    return data, mappings, df_user_movie_sub


## **Active User Selection**

This function identifies the most active users by rating count, ensuring the subgraph contains users with sufficient interaction data for reliable embeddings. Returns the top N user IDs sorted by activity level.

In [7]:
def get_most_active_users(df_user_movie, df_user_stats, n_users=1000):
    user_activity = df_user_movie.groupby('user_id').size().reset_index(name='rating_count')
    user_activity = user_activity.sort_values('rating_count', ascending=False)
    active_user_ids = user_activity.head(n_users)['user_id'].tolist()

    print(f'Selected {len(active_user_ids)} most active users')
    print(f'Minimum number of ratings: {user_activity.head(n_users)['rating_count'].min()}')
    print(f'Maximum number of ratings: {user_activity.head(n_users)['rating_count'].max()}')
    print(f'Average number of ratings: {user_activity.head(n_users)['rating_count'].mean():.2f}')

    return active_user_ids


## **Stratified Train/Val/Test Split**

Performs user-level stratified splitting while maintaining the positive (rating ≥4) vs negative interaction ratio for each user. This ensures consistent label distribution across all splits and prevents data leakage between users.

In [8]:
def split_data_stratified(df_user_movie, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_state=42):
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-9
    train_dfs, val_dfs, test_dfs = [], [], []

    for user_id, group in df_user_movie.groupby('user_id'):
        positive = group[group['rating'] >= 4].copy()
        negative = group[group['rating'] < 4].copy()

        if len(positive) > 0:
            pos_train, pos_temp = train_test_split(
                positive, test_size=(val_ratio + test_ratio) / (val_ratio + test_ratio + train_ratio),
                random_state=random_state
            )
            pos_val, pos_test = train_test_split(
                pos_temp, test_size=test_ratio / (val_ratio + test_ratio),
                random_state=random_state
            )
        else:
            pos_train, pos_val, pos_test = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

        if len(negative) > 0:
            neg_train, neg_temp = train_test_split(
                negative, test_size=(val_ratio + test_ratio) / (val_ratio + test_ratio + train_ratio),
                random_state=random_state
            )
            neg_val, neg_test = train_test_split(
                neg_temp, test_size=test_ratio / (val_ratio + test_ratio),
                random_state=random_state
            )
        else:
            neg_train, neg_val, neg_test = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

        train_dfs.append(pd.concat([pos_train, neg_train]))
        val_dfs.append(pd.concat([pos_val, neg_val]))
        test_dfs.append(pd.concat([pos_test, neg_test]))

    train_df = pd.concat(train_dfs, ignore_index=True)
    val_df = pd.concat(val_dfs, ignore_index=True)
    test_df = pd.concat(test_dfs, ignore_index=True)

    print(f'Train: {len(train_df)} links ({len(train_df) / len(df_user_movie) * 100:.1f}%)')
    print(f'Val: {len(val_df)} links ({len(val_df) / len(df_user_movie) * 100:.1f}%)')
    print(f'Test: {len(test_df)} links ({len(test_df) / len(df_user_movie) * 100:.1f}%)')

    return train_df, val_df, test_df


## **Graph Split Mask Assignment**

Maps the stratified user-movie splits back to the graph structure by creating binary masks for each edge in the `('user', 'rates', 'movie')` relation.

In [9]:
def add_split_masks_to_graph(graph, train_df, val_df, test_df, mappings):
    train_set = set(zip(train_df['user_id'], train_df['movie_id']))
    val_set = set(zip(val_df['user_id'], val_df['movie_id']))
    test_set = set(zip(test_df['user_id'], test_df['movie_id']))

    if ('user', 'rates', 'movie') in graph.edge_types:
        edge_index = graph['user', 'rates', 'movie'].edge_index
        num_edges = edge_index.shape[1]
        train_mask = torch.zeros(num_edges, dtype=torch.bool)
        val_mask = torch.zeros(num_edges, dtype=torch.bool)
        test_mask = torch.zeros(num_edges, dtype=torch.bool)

        for i in range(num_edges):
            user_idx = edge_index[0, i].item()
            movie_idx = edge_index[1, i].item()
            user_id = None
            movie_id = None

            for uid, uidx in mappings['user'].items():
                if uidx == user_idx:
                    user_id = uid
                    break

            for mid, midx in mappings['movie'].items():
                if midx == movie_idx:
                    movie_id = mid
                    break

            if user_id is not None and movie_id is not None:
                edge_tuple = (user_id, movie_id)

                if edge_tuple in train_set:
                    train_mask[i] = True
                elif edge_tuple in val_set:
                    val_mask[i] = True
                elif edge_tuple in test_set:
                    test_mask[i] = True

        graph['user', 'rates', 'movie'].train_mask = train_mask
        graph['user', 'rates', 'movie'].val_mask = val_mask
        graph['user', 'rates', 'movie'].test_mask = test_mask

        print(f'Masks added:')
        print(f'Train: {train_mask.sum().item()} edges')
        print(f'Val: {val_mask.sum().item()} edges')
        print(f'Test: {test_mask.sum().item()} edges')

    return graph


## **Optimized Split Mask Assignment**

Creates reverse ID mappings for efficient lookup and applies the same train/val/test masks to both directional and reverse user-movie edges, ensuring consistency during bidirectional message passing in GraphSAGE.

In [10]:
def add_split_masks_fast(graph, train_df, val_df, test_df, mappings):
    reverse_mappings = {}
    for node_type, mapping in mappings.items():
        reverse_mappings[node_type] = {v: k for k, v in mapping.items()}

    train_set = set(zip(train_df['user_id'], train_df['movie_id']))
    val_set = set(zip(val_df['user_id'], val_df['movie_id']))
    test_set = set(zip(test_df['user_id'], test_df['movie_id']))
    if ('user', 'rates', 'movie') in graph.edge_types:
        edge_index = graph['user', 'rates', 'movie'].edge_index
        num_edges = edge_index.shape[1]
        train_mask = torch.zeros(num_edges, dtype=torch.bool)
        val_mask = torch.zeros(num_edges, dtype=torch.bool)
        test_mask = torch.zeros(num_edges, dtype=torch.bool)

        for i in range(num_edges):
            user_idx = edge_index[0, i].item()
            movie_idx = edge_index[1, i].item()
            user_id = reverse_mappings['user'].get(user_idx)
            movie_id = reverse_mappings['movie'].get(movie_idx)

            if user_id is not None and movie_id is not None:
                edge_tuple = (user_id, movie_id)
                if edge_tuple in train_set:
                    train_mask[i] = True
                elif edge_tuple in val_set:
                    val_mask[i] = True
                elif edge_tuple in test_set:
                    test_mask[i] = True

        graph['user', 'rates', 'movie'].train_mask = train_mask
        graph['user', 'rates', 'movie'].val_mask = val_mask
        graph['user', 'rates', 'movie'].test_mask = test_mask

        if ('movie', 'rev_rates', 'user') in graph.edge_types:
            graph['movie', 'rev_rates', 'user'].train_mask = train_mask
            graph['movie', 'rev_rates', 'user'].val_mask = val_mask
            graph['movie', 'rev_rates', 'user'].test_mask = test_mask

    return graph


## **Graph Preparation**

Performs a complete pipeline: splits user-movie interactions into train/val/test sets, assigns corresponding masks to graph edges, and precomputes negative sampling candidates for each user by excluding their rated movies.

In [11]:
def prepare_graph_with_splits(subgraph, mappings, user_movie_edges, num_negatives=200):
    train_df, val_df, test_df = split_data_stratified(
        user_movie_edges,
        train_ratio=0.7,
        val_ratio=0.15,
        test_ratio=0.15
    )

    reverse_mappings = {}

    for node_type, mapping in mappings.items():
        reverse_mappings[node_type] = {v: k for k, v in mapping.items()}

    graph = add_split_masks_fast(subgraph, train_df, val_df, test_df, mappings)
    user_train_movies = {}

    for _, row in train_df.iterrows():
        user_id = row['user_id']
        movie_id = row['movie_id']

        if user_id not in user_train_movies:
            user_train_movies[user_id] = set()
        user_train_movies[user_id].add(movie_id)

    all_movie_ids = set(mappings['movie'].keys())
    negative_candidates = {}

    for user_id, rated_movies in user_train_movies.items():
        candidate_movies = all_movie_ids - rated_movies
        negative_candidates[user_id] = list(candidate_movies)

    splits = {
        'train_df': train_df,
        'val_df': val_df,
        'test_df': test_df,
        'user_train_movies': user_train_movies,
        'negative_candidates': negative_candidates,
        'all_movie_ids': all_movie_ids,
        'reverse_mappings': reverse_mappings
    }

    print(f'   Graph contains {graph.num_edges} edges')
    print(f'   {len(negative_candidates)} users available for negative sampling')

    return graph, splits


## **Simple Train/Val/Test Split**

Performs a user-level random split of interactions without stratifying by rating. Ensures each user has at least one interaction in the training set.

In [12]:
def split_data_simple(df_user_movie, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_state=42):
    np.random.seed(random_state)
    train_dfs, val_dfs, test_dfs = [], [], []

    for user_id, group in df_user_movie.groupby('user_id'):
        n = len(group)
        if n == 0:
            continue

        n_test = max(1, int(n * test_ratio))
        n_val = max(1, int(n * val_ratio))
        n_train = n - n_val - n_test

        if n_train <= 0:
            n_train = 1
            n_val = min(n - 1, n_val)
            n_test = n - n_train - n_val

        indices = np.random.permutation(n)

        train_idx = indices[:n_train]
        val_idx = indices[n_train:n_train + n_val]
        test_idx = indices[n_train + n_val:]

        train_dfs.append(group.iloc[train_idx])
        val_dfs.append(group.iloc[val_idx])
        test_dfs.append(group.iloc[test_idx])

    train_df = pd.concat(train_dfs, ignore_index=True)
    val_df = pd.concat(val_dfs, ignore_index=True)
    test_df = pd.concat(test_dfs, ignore_index=True)

    print(f'Train: {len(train_df)} links ({len(train_df) / len(df_user_movie) * 100:.1f}%)')
    print(f'Val: {len(val_df)} links ({len(val_df) / len(df_user_movie) * 100:.1f}%)')
    print(f'Test: {len(test_df)} links ({len(test_df) / len(df_user_movie) * 100:.1f}%)')

    return train_df, val_df, test_df


In [13]:
active_user_ids = get_most_active_users(
    df_connections_user_movie,
    df_connections_user_stats,
    n_users=5000
)

subgraph, mappings, user_movie_edges = build_subgraph_for_users(
    user_ids=active_user_ids,
    df_user_movie=df_connections_user_movie,
    df_movie_features=df_connections_movie_features,
    df_user_stats=df_connections_user_stats,
    df_movie_genre=df_connections_movie_genre,
    df_movie_actor=df_connections_movie_actor,
    df_movie_director=df_connections_movie_director,
    df_actors=df_connections_actors,
    df_directors=df_connections_directors,
    df_genres=df_connections_genres,
    num_negatives=500
)

train_df, val_df, test_df = split_data_simple(
    user_movie_edges,
    train_ratio=0.7,
    val_ratio=0.15,
    test_ratio=0.15
)

reverse_mappings = {}
for node_type, mapping in mappings.items():
    reverse_mappings[node_type] = {v: k for k, v in mapping.items()}

train_set = set(zip(train_df['user_id'], train_df['movie_id']))
val_set = set(zip(val_df['user_id'], val_df['movie_id']))
test_set = set(zip(test_df['user_id'], test_df['movie_id']))

if ('user', 'rates', 'movie') in subgraph.edge_types:
    edge_index = subgraph['user', 'rates', 'movie'].edge_index
    num_edges = edge_index.shape[1]
    train_mask = torch.zeros(num_edges, dtype=torch.bool)
    val_mask = torch.zeros(num_edges, dtype=torch.bool)
    test_mask = torch.zeros(num_edges, dtype=torch.bool)

    for i in range(num_edges):
        user_idx = edge_index[0, i].item()
        movie_idx = edge_index[1, i].item()
        user_id = reverse_mappings['user'].get(user_idx)
        movie_id = reverse_mappings['movie'].get(movie_idx)

        if user_id is not None and movie_id is not None:
            edge_tuple = (user_id, movie_id)
            if edge_tuple in train_set:
                train_mask[i] = True
            elif edge_tuple in val_set:
                val_mask[i] = True
            elif edge_tuple in test_set:
                test_mask[i] = True
    subgraph['user', 'rates', 'movie'].train_mask = train_mask
    subgraph['user', 'rates', 'movie'].val_mask = val_mask
    subgraph['user', 'rates', 'movie'].test_mask = test_mask

    if ('movie', 'rev_rates', 'user') in subgraph.edge_types:
        subgraph['movie', 'rev_rates', 'user'].train_mask = train_mask
        subgraph['movie', 'rev_rates', 'user'].val_mask = val_mask
        subgraph['movie', 'rev_rates', 'user'].test_mask = test_mask

user_train_movies = {}

for _, row in train_df.iterrows():
    user_id = row['user_id']
    movie_id = row['movie_id']

    if user_id not in user_train_movies:
        user_train_movies[user_id] = set()
    user_train_movies[user_id].add(movie_id)

all_movie_ids = set(mappings['movie'].keys())
negative_candidates = {}

for user_id in active_user_ids:
    if user_id in user_train_movies:
        rated_movies = user_train_movies[user_id]
        candidate_movies = all_movie_ids - rated_movies
        negative_candidates[user_id] = list(candidate_movies)
    else:
        negative_candidates[user_id] = list(all_movie_ids)

splits_info = {
    'train_df': train_df,
    'val_df': val_df,
    'test_df': test_df,
    'user_train_movies': user_train_movies,
    'negative_candidates': negative_candidates,
    'all_movie_ids': all_movie_ids,
    'reverse_mappings': reverse_mappings,
    'active_user_ids': active_user_ids
}

torch.save({
    'graph': subgraph,
    'splits': splits_info,
    'mappings': mappings
}, 'prepared_graph.pt')

print(f'Saved to file: \'prepared_graph.pt\'')


Selected 5000 most active users
Minimum number of ratings: 52
Maximum number of ratings: 56
Average number of ratings: 54.28

Sizes of filtered data:
Users: 5000
Movies: 8579
User-movie links: 271376
Actors: 2512
Directors: 1600
Genres: 20

Constructed graph with the following structure:
Nodes: ['user', 'movie', 'actor', 'director', 'genre']
Edges: [('user', 'rates', 'movie'), ('movie', 'has_genre', 'genre'), ('movie', 'has_director', 'director'), ('movie', 'has_actor', 'actor'), ('genre', 'rev_has_genre', 'movie'), ('director', 'rev_has_director', 'movie'), ('actor', 'rev_has_actor', 'movie'), ('movie', 'rev_rates', 'user')]

Checking node features:
  user: torch.Size([5000, 3])
  movie: torch.Size([8579, 3])
  actor: torch.Size([2512, 5])
  director: torch.Size([1600, 4])
  genre: torch.Size([20, 20])

Checking edges:
  ('user', 'rates', 'movie'): 271376 edges
  ('movie', 'has_genre', 'genre'): 21323 edges
  ('movie', 'has_director', 'director'): 6605 edges
  ('movie', 'has_actor', '

## **Recommendation Dataset**

A dataset class that stores negative candidates as pre-mapped indices for faster batch generation.

In [17]:
class RecDataset(Dataset):
    def __init__(self, df, negative_candidates, mappings, num_negatives=10, rng_seed: int = None):
        self.df = df
        self.negative_candidates = negative_candidates
        self.mappings = mappings
        self.num_negatives = num_negatives

        self.rng = np.random.default_rng(rng_seed)

        self.data = []
        for _, row in df.iterrows():
            user_idx = mappings['user'][row['user_id']]
            movie_idx = mappings['movie'][row['movie_id']]
            self.data.append({
                'user_idx': int(user_idx),
                'movie_idx': int(movie_idx),
                'rating': float(row['rating'])
            })

        self.neg_candidates_idx = {}
        for user_id, movie_ids in negative_candidates.items():
            user_idx = mappings['user'][user_id]
            self.neg_candidates_idx[int(user_idx)] = [
                int(mappings['movie'][mid]) for mid in movie_ids
            ]

        self._all_movie_indices = list(mappings['movie'].values())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        user_idx = item['user_idx']

        if user_idx in self.neg_candidates_idx and len(self.neg_candidates_idx[user_idx]) > 0:
            candidates = self.neg_candidates_idx[user_idx]
            if len(candidates) >= self.num_negatives:
                neg_indices = self.rng.choice(candidates, self.num_negatives, replace=False)
            else:
                neg_indices = self.rng.choice(candidates, self.num_negatives, replace=True)
        else:
            if len(self._all_movie_indices) >= self.num_negatives:
                neg_indices = self.rng.choice(self._all_movie_indices, self.num_negatives, replace=False)
            else:
                neg_indices = self.rng.choice(self._all_movie_indices, self.num_negatives, replace=True)

        return {
            'user_idx': torch.tensor(user_idx, dtype=torch.long),
            'pos_movie_idx': torch.tensor(item['movie_idx'], dtype=torch.long),
            'neg_movie_indices': torch.tensor(neg_indices, dtype=torch.long),
            'rating': torch.tensor(item['rating'], dtype=torch.float)
        }


## **Heterogeneous GraphConv Model**

A multi-layer heterogeneous graph neural network that replaces SAGEConv with standard GraphConv layers. The architecture supports variable number of layers, with separate linear projections for each node type and final user/movie embedding refinement layers.

In [20]:
class HeteroGraphConv(nn.Module):
    def __init__(self,
                 user_feat_dim,
                 movie_feat_dim,
                 actor_feat_dim,
                 director_feat_dim,
                 genre_feat_dim,
                 hidden_channels=64,
                 out_channels=32,
                 num_layers=2):
        super().__init__()

        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_layers = num_layers

        self.user_lin = Linear(user_feat_dim, hidden_channels)
        self.movie_lin = Linear(movie_feat_dim, hidden_channels)
        self.actor_lin = Linear(actor_feat_dim, hidden_channels)
        self.director_lin = Linear(director_feat_dim, hidden_channels)
        self.genre_lin = Linear(genre_feat_dim, hidden_channels)

        self.convs = nn.ModuleList()
        for i in range(num_layers):
            in_channels = hidden_channels
            out_channels_layer = out_channels if i == num_layers - 1 else hidden_channels

            conv = HeteroConv({
                ('user', 'rates', 'movie'): GraphConv(in_channels, out_channels_layer),
                ('movie', 'rev_rates', 'user'): GraphConv(in_channels, out_channels_layer),

                ('movie', 'has_genre', 'genre'): GraphConv(in_channels, out_channels_layer),
                ('genre', 'rev_has_genre', 'movie'): GraphConv(in_channels, out_channels_layer),

                ('movie', 'has_director', 'director'): GraphConv(in_channels, out_channels_layer),
                ('director', 'rev_has_director', 'movie'): GraphConv(in_channels, out_channels_layer),

                ('movie', 'has_actor', 'actor'): GraphConv(in_channels, out_channels_layer),
                ('actor', 'rev_has_actor', 'movie'): GraphConv(in_channels, out_channels_layer),
            }, aggr='sum')

            self.convs.append(conv)

        self.user_final = Linear(out_channels, out_channels)
        self.movie_final = Linear(out_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        for module in self.modules():
            if isinstance(module, (nn.Linear, GraphConv)):
                module.reset_parameters()

    def forward(self, x_dict, edge_index_dict):
        x_dict = {
            'user': self.user_lin(x_dict['user']),
            'movie': self.movie_lin(x_dict['movie']),
            'actor': self.actor_lin(x_dict['actor']),
            'director': self.director_lin(x_dict['director']),
            'genre': self.genre_lin(x_dict['genre']),
        }

        for i, conv in enumerate(self.convs):
            x_dict = conv(x_dict, edge_index_dict)

            x_dict = {key: F.relu(x) for key, x in x_dict.items()}

        user_emb = self.user_final(x_dict['user'])
        movie_emb = self.movie_final(x_dict['movie'])

        return user_emb, movie_emb


## **Training Pipeline**

### **Core Components**

**Model Architecture:**
- **HeteroGraphConv**: Multi-layer heterogeneous GNN with GraphConv layers
- **Feature Projection**: Separate linear layers for each node type (user, movie, actor, director, genre)
- **Mask-Aware Embeddings**: Uses train/val edge masks during message passing
- **Final Embeddings**: Additional linear layers refine user and movie embeddings


**Training:**
- **Loss Function**: BCEWithLogitsLoss comparing positive vs negative movie scores

**Evaluation:**
- **Comprehensive Metrics**:
  - Classification: Accuracy, Precision, Recall, F1, AUC-ROC
  - Ranking: Precision@K, Recall@K, F1@K, NDCG@K, MAP@K (for K=5,10,20)

In [21]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.metrics import ndcg_score

class RecDataset(Dataset):
    def __init__(self, df, negative_candidates, mappings, num_negatives=10, rng_seed=None):
        self.df = df
        self.negative_candidates = negative_candidates
        self.mappings = mappings
        self.num_negatives = num_negatives
        self.rng = np.random.default_rng(rng_seed)
        self.data = []

        for _, row in df.iterrows():
            user_idx = mappings['user'][row['user_id']]
            movie_idx = mappings['movie'][row['movie_id']]
            self.data.append({
                'user_idx': int(user_idx),
                'movie_idx': int(movie_idx),
                'rating': float(row['rating'])
            })

        self.neg_candidates_idx = {}
        for user_id, movie_ids in negative_candidates.items():
            user_idx = mappings['user'][user_id]
            self.neg_candidates_idx[int(user_idx)] = [
                int(mappings['movie'][mid]) for mid in movie_ids
            ]

        self._all_movie_indices = list(mappings['movie'].values())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        user_idx = item['user_idx']

        if user_idx in self.neg_candidates_idx and len(self.neg_candidates_idx[user_idx]) > 0:
            candidates = self.neg_candidates_idx[user_idx]
            if len(candidates) >= self.num_negatives:
                neg_indices = self.rng.choice(candidates, self.num_negatives, replace=False)
            else:
                neg_indices = self.rng.choice(candidates, self.num_negatives, replace=True)
        else:
            if len(self._all_movie_indices) >= self.num_negatives:
                neg_indices = self.rng.choice(self._all_movie_indices, self.num_negatives, replace=False)
            else:
                neg_indices = self.rng.choice(self._all_movie_indices, self.num_negatives, replace=True)

        return {
            'user_idx': torch.tensor(user_idx, dtype=torch.long),
            'pos_movie_idx': torch.tensor(item['movie_idx'], dtype=torch.long),
            'neg_movie_indices': torch.tensor(neg_indices, dtype=torch.long),
            'rating': torch.tensor(item['rating'], dtype=torch.float)
        }

def dict_collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    collated = {}
    if len(batch) == 0:
        return collated
    keys = batch[0].keys()
    for k in keys:
        collated[k] = torch.stack([d[k] for d in batch], dim=0)
    return collated

def create_dataloaders(splits, mappings, batch_size=32, num_negatives=10, num_workers=2, pin_memory: bool = True):
    train_dataset = RecDataset(splits['train_df'], splits['negative_candidates'], mappings, num_negatives)
    val_dataset = RecDataset(splits['val_df'], splits['negative_candidates'], mappings, num_negatives)
    test_dataset = RecDataset(splits['test_df'], splits['negative_candidates'], mappings, num_negatives)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        collate_fn=dict_collate_fn, num_workers=num_workers, pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        collate_fn=dict_collate_fn, num_workers=num_workers, pin_memory=pin_memory
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        collate_fn=dict_collate_fn, num_workers=num_workers, pin_memory=pin_memory
    )

    return train_loader, val_loader, test_loader

def _model_device(model: torch.nn.Module) -> torch.device:
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device('cpu')

def get_embeddings_with_masks(model, graph, use_train_mask=True):
    device = _model_device(model)
    x_dict_on_device = {}
    for k, v in graph.x_dict.items():
        x_dict_on_device[k] = v.to(device)
    edge_index_dict = {}
    for edge_type in graph.edge_types:
        data_edge = graph[edge_type]
        if use_train_mask and hasattr(data_edge, 'train_mask'):
            edge_index = data_edge.edge_index[:, data_edge.train_mask]
            edge_index_dict[edge_type] = edge_index.to(device)
        else:
            edge_index_dict[edge_type] = data_edge.edge_index.to(device)
    user_emb, movie_emb = model(x_dict_on_device, edge_index_dict)
    user_emb = user_emb.to(device)
    movie_emb = movie_emb.to(device)
    return user_emb, movie_emb

def batch_loss(model, graph, batch, use_train_mask=True):
    user_emb, movie_emb = get_embeddings_with_masks(model, graph, use_train_mask)
    user_emb_batch = user_emb[batch['user_idx']]
    pos_movie_emb = movie_emb[batch['pos_movie_idx']]
    neg_movie_emb = movie_emb[batch['neg_movie_indices']]

    pos_scores = torch.sum(user_emb_batch * pos_movie_emb, dim=1)
    neg_scores = torch.sum(user_emb_batch.unsqueeze(1) * neg_movie_emb, dim=2)

    pos_targets = torch.ones_like(pos_scores)
    neg_targets = torch.zeros_like(neg_scores)

    all_scores = torch.cat([pos_scores.unsqueeze(1), neg_scores], dim=1)
    all_targets = torch.cat([pos_targets.unsqueeze(1), neg_targets], dim=1)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    return loss_fn(all_scores, all_targets)

def simple_train_epoch(model, graph, train_loader, optimizer, device):
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in train_loader:
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        optimizer.zero_grad()
        loss = batch_loss(model, graph, batch, use_train_mask=True)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(1, num_batches)

def validate(model, graph, val_loader, device):
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            loss = batch_loss(model, graph, batch, use_train_mask=False)
            total_loss += loss.item()
            num_batches += 1
    return total_loss / max(1, num_batches)

def compute_ranking_metrics(
        model,
        graph,
        splits,
        mappings,
        k_list: List[int] = [5, 10, 20],
        max_users: int = 200,
        max_negatives_per_user: int = 100,
        positive_threshold: float = 4.0
):
    model.eval()
    device = _model_device(model)
    user_emb, movie_emb = get_embeddings_with_masks(model, graph, use_train_mask=False)
    val_users = splits['val_df']['user_id'].unique()
    if len(val_users) > max_users:
        val_users = np.random.choice(val_users, max_users, replace=False)
    metrics_accum = {f'precision@{k}': [] for k in k_list}
    metrics_accum.update({f'recall@{k}': [] for k in k_list})
    metrics_accum.update({f'f1@{k}': [] for k in k_list})
    metrics_accum.update({f'ndcg@{k}': [] for k in k_list})
    metrics_accum.update({f'map@{k}': [] for k in k_list})
    all_true = []
    all_pred = []
    all_pred_proba = []
    with torch.no_grad():
        for user_id in val_users:
            if user_id not in mappings['user']:
                continue
            user_idx = mappings['user'][user_id]
            user_val_rows = splits['val_df'][splits['val_df']['user_id'] == user_id]
            user_val_positives = user_val_rows[user_val_rows['rating'] >= positive_threshold]['movie_id'].tolist()
            if len(user_val_positives) == 0:
                continue
            if user_id in splits['negative_candidates']:
                neg_cands = splits['negative_candidates'][user_id]
                if len(neg_cands) > max_negatives_per_user:
                    neg_cands = list(np.random.choice(neg_cands, max_negatives_per_user, replace=False))
            else:
                neg_cands = [mid for mid in splits['all_movie_ids'] if mid not in splits['user_train_movies'].get(user_id, set())]
                if len(neg_cands) > max_negatives_per_user:
                    neg_cands = list(np.random.choice(neg_cands, max_negatives_per_user, replace=False))
            if len(neg_cands) == 0:
                continue
            candidate_movies = list(neg_cands) + list(user_val_positives)
            candidate_movies = [mid for mid in candidate_movies if mid in mappings['movie']]
            if len(candidate_movies) == 0:
                continue
            candidate_indices = [mappings['movie'][mid] for mid in candidate_movies]
            user_emb_single = user_emb[user_idx:user_idx + 1]
            candidate_emb = movie_emb[candidate_indices]
            scores = torch.matmul(user_emb_single, candidate_emb.T).squeeze(0)
            probs = torch.sigmoid(scores).cpu().numpy()
            relevance = np.array([1 if mid in user_val_positives else 0 for mid in candidate_movies], dtype=int)
            binary_preds = (probs > 0.5).astype(int)
            all_true.extend(relevance.tolist())
            all_pred.extend(binary_preds.tolist())
            all_pred_proba.extend(probs.tolist())
            sorted_idx = np.argsort(-scores.cpu().numpy())
            sorted_relevance = relevance[sorted_idx]
            sorted_scores_for_ndcg = scores.cpu().numpy()[sorted_idx]
            for k in k_list:
                if len(sorted_relevance) < 1:
                    continue
                topk = sorted_relevance[:k]
                precision_at_k = topk.sum() / k
                recall_at_k = topk.sum() / max(1, relevance.sum())
                if precision_at_k + recall_at_k > 0:
                    f1_at_k = 2 * precision_at_k * recall_at_k / (precision_at_k + recall_at_k)
                else:
                    f1_at_k = 0.0
                metrics_accum[f'precision@{k}'].append(precision_at_k)
                metrics_accum[f'recall@{k}'].append(recall_at_k)
                metrics_accum[f'f1@{k}'].append(f1_at_k)
                try:
                    ndcg_at_k = ndcg_score([relevance], [scores.cpu().numpy()], k=k)
                except Exception:
                    ndcg_at_k = 0.0
                metrics_accum[f'ndcg@{k}'].append(ndcg_at_k)
                map_at_k = 0.0
                num_relevant = 0
                for i in range(min(k, len(sorted_relevance))):
                    if sorted_relevance[i] == 1:
                        num_relevant += 1
                        map_at_k += num_relevant / (i + 1)
                denom = min(relevance.sum(), k) if relevance.sum() > 0 else 1
                map_at_k = map_at_k / denom
                metrics_accum[f'map@{k}'].append(map_at_k)
    avg_metrics = {}
    for k in k_list:
        for mname in ['precision', 'recall', 'f1', 'ndcg', 'map']:
            key = f'{mname}@{k}'
            vals = metrics_accum.get(key, [])
            avg_metrics[key] = float(np.mean(vals)) if len(vals) > 0 else 0.0
    if len(all_true) > 0:
        avg_metrics['accuracy'] = float(accuracy_score(all_true, all_pred))
        avg_metrics['precision'] = float(precision_score(all_true, all_pred, zero_division=0))
        avg_metrics['recall'] = float(recall_score(all_true, all_pred, zero_division=0))
        avg_metrics['f1'] = float(f1_score(all_true, all_pred, zero_division=0))
        try:
            avg_metrics['auc_roc'] = float(roc_auc_score(all_true, all_pred_proba))
        except Exception:
            avg_metrics['auc_roc'] = 0.0
    else:
        avg_metrics.update({
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1': 0.0,
            'auc_roc': 0.0
        })
    return avg_metrics

def evaluate(model, graph, splits, mappings, k_list=[5, 10, 20], max_users=500, max_negatives_per_user=100):
    metrics = compute_ranking_metrics(model, graph, splits, mappings, k_list=k_list, max_users=max_users, max_negatives_per_user=max_negatives_per_user)
    print('\nClassification metrics:')
    print(f'Accuracy:  {metrics["accuracy"]:.4f}')
    print(f'Precision: {metrics["precision"]:.4f}')
    print(f'Recall:    {metrics["recall"]:.4f}')
    print(f'F1-score:  {metrics["f1"]:.4f}')
    print(f'AUC-ROC:   {metrics["auc_roc"]:.4f}')
    print('\nRanking metrics:')
    for k in k_list:
        print(f"@k={k}: Precision: {metrics[f'precision@{k}']:.4f}, Recall: {metrics[f'recall@{k}']:.4f}")
        print(f"F1: {metrics[f'f1@{k}']:.4f}, NDCG: {metrics[f'ndcg@{k}']:.4f}, MAP: {metrics[f'map@{k}']:.4f}")
    return metrics

def train(prepared_path='prepared_graph.pt', epochs=5, batch_size=32, eval_every=2, k_list=[5, 10, 20], hidden_channels=64, out_channels=32):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_pin_memory = device.type == 'cuda'
    print(f'Device: {device}, pin_memory={use_pin_memory}')
    loaded_data = torch.load(prepared_path, map_location='cpu', weights_only=False)
    graph = loaded_data['graph']
    splits = loaded_data['splits']
    mappings = loaded_data['mappings']
    user_feat_dim = graph['user'].x.shape[1]
    movie_feat_dim = graph['movie'].x.shape[1]
    actor_feat_dim = graph['actor'].x.shape[1]
    director_feat_dim = graph['director'].x.shape[1]
    genre_feat_dim = graph['genre'].x.shape[1]
    model = HeteroGraphConv(
        user_feat_dim=user_feat_dim,
        movie_feat_dim=movie_feat_dim,
        actor_feat_dim=actor_feat_dim,
        director_feat_dim=director_feat_dim,
        genre_feat_dim=genre_feat_dim,
        hidden_channels=hidden_channels,
        out_channels=out_channels,
        num_layers=2
    )
    train_loader, val_loader, test_loader = create_dataloaders(
        splits, mappings, batch_size=batch_size, num_negatives=10,
        pin_memory=use_pin_memory
    )
    model = model.to(device)
    try:
        graph = graph.to(device)
    except Exception:
        print('graph.to(device) not available')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    history = {'train_loss': [], 'val_loss': [], 'metrics': []}
    for epoch in tqdm(range(epochs)):
        train_loss = simple_train_epoch(model, graph, train_loader, optimizer, device)
        val_loss = validate(model, graph, val_loader, device)
        if (epoch + 1) % eval_every == 0 or epoch == epochs - 1:
            print(f'\nEvaluating metrics at epoch {epoch + 1}...')
            metrics = evaluate(
                model,
                graph,
                splits,
                mappings,
                k_list=k_list,
                max_users=500,
                max_negatives_per_user=200)
            history['metrics'].append(metrics)
        else:
            metrics = None
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        if metrics:
            print(f'Epoch {epoch + 1}/{epochs}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
            print(f"Sample: Precision@10: {metrics.get('precision@10', 0):.4f}, NDCG@10: {metrics.get('ndcg@10', 0):.4f}")
        else:
            print(f'Epoch {epoch + 1}/{epochs}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
    print('\nFinal evaluation on TEST set...')
    final_metrics = compute_ranking_metrics(
        model, graph,
        {**splits, 'val_df': splits['test_df']},
        mappings,
        k_list=k_list,
        max_users=min(500, len(splits['test_df']['user_id'].unique())),
        max_negatives_per_user=200
    )
    print('Final test metrics (subset):')
    print(f"  Precision@10: {final_metrics.get('precision@10', 0):.4f}")
    print(f"  Recall@10:    {final_metrics.get('recall@10', 0):.4f}")
    print(f"  NDCG@10:      {final_metrics.get('ndcg@10', 0):.4f}")
    print(f"  AUC-ROC:      {final_metrics.get('auc_roc', 0):.4f}")
    return model, history, final_metrics

model, history, final_metrics = train(
    prepared_path='prepared_graph.pt',
    epochs=5,
    batch_size=32,
    eval_every=2,
    k_list=[5, 10, 20],
    hidden_channels=32,
    out_channels=16
)


Device: cuda, pin_memory=True


 20%|██        | 1/5 [03:46<15:05, 226.35s/it]

Epoch 1/5: Train Loss: 75.0007, Val Loss: 1.8791

Evaluating metrics at epoch 2...


 40%|████      | 2/5 [07:55<11:59, 239.77s/it]


Classification metrics:
Accuracy:  0.9772
Precision: 0.0000
Recall:    0.0000
F1-score:  0.0000
AUC-ROC:   0.4950

Ranking metrics:
@k=5: Precision: 0.0004, Recall: 0.0003
F1: 0.0004, NDCG: 0.0253, MAP: 0.0004
@k=10: Precision: 0.0004, Recall: 0.0006
F1: 0.0005, NDCG: 0.0361, MAP: 0.0004
@k=20: Precision: 0.0003, Recall: 0.0009
F1: 0.0005, NDCG: 0.0559, MAP: 0.0004
Epoch 2/5: Train Loss: 0.4054, Val Loss: 0.3393
Sample: Precision@10: 0.0004, NDCG@10: 0.0361


 60%|██████    | 3/5 [11:03<07:12, 216.07s/it]

Epoch 3/5: Train Loss: 0.3055, Val Loss: 0.3180

Evaluating metrics at epoch 4...


 80%|████████  | 4/5 [14:14<03:26, 206.39s/it]


Classification metrics:
Accuracy:  0.9770
Precision: 0.0000
Recall:    0.0000
F1-score:  0.0000
AUC-ROC:   0.5035

Ranking metrics:
@k=5: Precision: 0.0004, Recall: 0.0003
F1: 0.0003, NDCG: 0.0254, MAP: 0.0001
@k=10: Precision: 0.0002, Recall: 0.0003
F1: 0.0002, NDCG: 0.0361, MAP: 0.0001
@k=20: Precision: 0.0003, Recall: 0.0010
F1: 0.0005, NDCG: 0.0560, MAP: 0.0001
Epoch 4/5: Train Loss: 0.3048, Val Loss: 0.3157
Sample: Precision@10: 0.0002, NDCG@10: 0.0361

Evaluating metrics at epoch 5...


100%|██████████| 5/5 [17:25<00:00, 209.17s/it]


Classification metrics:
Accuracy:  0.9758
Precision: 0.0000
Recall:    0.0000
F1-score:  0.0000
AUC-ROC:   0.5024

Ranking metrics:
@k=5: Precision: 0.0016, Recall: 0.0013
F1: 0.0014, NDCG: 0.0263, MAP: 0.0007
@k=10: Precision: 0.0010, Recall: 0.0017
F1: 0.0013, NDCG: 0.0368, MAP: 0.0006
@k=20: Precision: 0.0013, Recall: 0.0050
F1: 0.0021, NDCG: 0.0570, MAP: 0.0008
Epoch 5/5: Train Loss: 0.3048, Val Loss: 0.3146
Sample: Precision@10: 0.0010, NDCG@10: 0.0368

Final evaluation on TEST set...





Final test metrics (subset):
  Precision@10: 0.0014
  Recall@10:    0.0024
  NDCG@10:      0.0362
  AUC-ROC:      0.5003


In [22]:
df = (pd.Series(final_metrics)
        .rename_axis('metric_at_k')
        .reset_index(name='value'))
df.sort_values('metric_at_k', ascending=False)

Unnamed: 0,metric_at_k,value
1,recall@5,0.001603
11,recall@20,0.003292
6,recall@10,0.002405
17,recall,0.0
0,precision@5,0.002041
10,precision@20,0.00102
5,precision@10,0.001429
16,precision,0.0
3,ndcg@5,0.025598
13,ndcg@20,0.056112
