In [1]:
import numpy as np
from tqdm import tqdm, trange
import torch
from torch_geometric.data import Data
import torch
from torch_geometric.nn import GCNConv, GATConv
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import copy

# Seed PyTorch and NumPy so that random draws (features, weights, splits) are repeatable.
SEED = 12345
torch.manual_seed(SEED)
np.random.seed(SEED)
# Request deterministic cuDNN kernels and disable its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Fraction of the observed training ratings held out for validation
# (passed as `test_size` to train_test_split further down).
VALIDATION_SPLIT = 0.1

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# Rating matrices are indexed [user, movie]; NaN marks an unobserved rating
# (they are filtered with np.isnan below).
ratings_train_np = np.load('ratings_train.npy')
# One row per movie: (title, '|'-separated genre string).
namesngenre_np = np.load('namesngenre.npy')
ratings_test_np = np.load('ratings_test.npy')
ratings_train_np.shape, ratings_test_np.shape, namesngenre_np.shape, 

((610, 4980), (610, 4980), (4980, 2))

In [4]:
import re
# A four-digit year in parentheses, e.g. "(1995)"; the first such group in a title is used.
year_pattern = re.compile(r"\((\d{4})\)")
# Run the regex once per title, then keep only the titles where a year was found.
year_matches = (year_pattern.search(name.item()) for name, _ in namesngenre_np)
# Round each year down to its decade (1995 -> 1990) and collect the distinct decades.
all_decades = {int(found.group(1)) // 10 * 10 for found in year_matches if found}
num_decades = len(all_decades)
all_decades

{1900, 1910, 1920, 1930, 1940, 1950, 1960, 1970, 1980, 1990, 2000, 2010}

In [5]:
sorted_decades = sorted(all_decades)
# Decade -> dense index, in chronological order.
decade_to_index = dict(zip(sorted_decades, range(len(sorted_decades))))

# One (movie row index, decade index) pair per movie whose title carries a year;
# no node-id offsets are applied at this stage.
decade_pairs = []
for movie_idx, (title, _) in enumerate(namesngenre_np):
    found = year_pattern.search(title.item())
    if not found:
        continue  # title without a "(YYYY)" marker: no decade edge
    release_decade = int(found.group(1)) // 10 * 10
    decade_pairs.append((movie_idx, decade_to_index[release_decade]))

# Shape (2, num_edges); an explicit empty tensor keeps that shape when nothing matched.
if decade_pairs:
    decade_edges = torch.tensor(decade_pairs, dtype=torch.long).t()
else:
    decade_edges = torch.empty((2, 0), dtype=torch.long)
decade_edges

tensor([[   0,    1,    2,  ..., 4977, 4978, 4979],
        [   9,    9,    9,  ...,   10,   11,   11]])

In [6]:
# Every distinct genre label that appears in the '|'-separated genre strings,
# minus the placeholder used for movies without genres.
genre_vocabulary = {
    label
    for _, genres in namesngenre_np
    for label in genres.split('|')
} - {'(no genres listed)'}
# Alphabetical order fixes the genre -> index mapping built in the next cell.
all_genres = sorted(genre_vocabulary)
num_genres = len(all_genres)
all_genres

['Action',
 'Adventure',
 'Animation',
 'Children',
 'Comedy',
 'Crime',
 'Documentary',
 'Drama',
 'Fantasy',
 'Film-Noir',
 'Horror',
 'IMAX',
 'Musical',
 'Mystery',
 'Romance',
 'Sci-Fi',
 'Thriller',
 'War',
 'Western']

In [7]:
genre_to_index = {label: position for position, label in enumerate(all_genres)}

# One (movie row index, genre index) pair per known genre of each movie;
# no node-id offsets are applied at this stage. Labels are stripped before lookup,
# and labels absent from genre_to_index (e.g. the "no genres" placeholder) are skipped.
genre_pairs = [
    (movie_idx, genre_to_index[label])
    for movie_idx, (_, genres_str) in enumerate(namesngenre_np)
    for label in (piece.strip() for piece in genres_str.split('|'))
    if label in genre_to_index
]

# Shape (2, num_edges); an explicit empty tensor keeps that shape when nothing matched.
if genre_pairs:
    genre_edges = torch.tensor(genre_pairs, dtype=torch.long).t()
else:
    genre_edges = torch.empty((2, 0), dtype=torch.long)
genre_edges

tensor([[   0,    0,    0,  ..., 4978, 4978, 4979],
        [   1,    2,    3,  ...,    1,    8,    4]])

In [8]:
# genre_map = {genre: i for i, genre in enumerate(all_genres)}
# movie_features = torch.zeros(namesngenre_np.shape[0], len(all_genres))
# for i, (_, genres_str) in enumerate(namesngenre_np):
#     for genre in genres_str.split('|'):
#         if genre in genre_map:
#             movie_features[i, genre_map[genre]] = 1

# movie_features

In [9]:
# Rows of the rating matrix are users, columns are movies.
num_users, num_movies = ratings_train_np.shape
# user_features = torch.rand(num_users, len(all_genres))

In [10]:
# Observed (non-NaN) training ratings as (user, movie) index pairs plus their values.
train_observed = ~np.isnan(ratings_train_np)
user_ids, movie_ids = np.where(train_observed)
ratings = ratings_train_np[train_observed]

# Same extraction for the held-out test matrix.
test_observed = ~np.isnan(ratings_test_np)
test_user_ids, test_movie_ids = np.where(test_observed)
test_ratings = ratings_test_np[test_observed]

# Split the observed training pairs into train / validation; rows are (user, movie).
edges = np.stack((user_ids, movie_ids), axis=1)
train_edges, val_edges, train_ratings, val_ratings = train_test_split(
    edges,
    ratings,
    test_size=VALIDATION_SPLIT,
    random_state=SEED,
    shuffle=True,
)
# Transpose to the (2, num_edges) layout used for edge indices.
train_edges = train_edges.T
val_edges = val_edges.T

test_edges = np.stack((test_user_ids, test_movie_ids), axis=1).T
# Note: this cast happens after the split, so train_ratings / val_ratings keep their
# original dtype here; they are converted to float32 when turned into tensors later.
ratings = ratings.astype(np.float32)
test_ratings = test_ratings.astype(np.float32)
(
    train_edges.shape,
    train_ratings.shape,
    val_edges.shape,
    val_ratings.shape,
    test_edges.shape,
    test_ratings.shape,
)

((2, 28438), (28438,), (2, 3160), (3160,), (2, 31598), (31598,))

In [11]:
def offset_edges(edge_tensor, src_offset=0, dst_offset=0):
    """Return a copy of a (2, num_edges) edge tensor with both rows shifted.

    Args:
        edge_tensor: Tensor whose row 0 holds source ids and row 1 holds
            destination ids. It is never modified in place.
        src_offset: Value added to every entry of row 0.
        dst_offset: Value added to every entry of row 1.

    Returns:
        A new tensor with the offsets applied. For an input without elements,
        an empty ``(2, 0)`` long tensor on the same device as ``edge_tensor``.
    """
    if edge_tensor.numel() == 0:
        # Keep the caller's device so the result can be concatenated with other
        # edge tensors without a device mismatch.
        return torch.empty((2, 0), dtype=torch.long, device=edge_tensor.device)
    adjusted = edge_tensor.clone()
    adjusted[0, :] += src_offset
    adjusted[1, :] += dst_offset
    return adjusted

# All node types share one id space: users first, then movies, genres, decades.
movie_offset = num_users
genre_offset = movie_offset + num_movies
decade_offset = genre_offset + num_genres

# Rating edges are (user, movie): user ids stay as they are, movie ids are shifted.
train_edge_index = offset_edges(torch.tensor(train_edges, dtype=torch.long), 0, movie_offset)
val_edge_index = offset_edges(torch.tensor(val_edges, dtype=torch.long), 0, movie_offset)
test_edge_index = offset_edges(torch.tensor(test_edges, dtype=torch.long), 0, movie_offset)

# Target ratings, aligned column-by-column with the edge tensors above.
train_actual_ratings = torch.tensor(train_ratings, dtype=torch.float32)
val_actual_ratings = torch.tensor(val_ratings, dtype=torch.float32)
test_actual_ratings = torch.tensor(test_ratings, dtype=torch.float32)

# Metadata edges are (movie, genre) and (movie, decade): both endpoints are shifted.
genre_edge_index = offset_edges(genre_edges, movie_offset, genre_offset)
decade_edge_index = offset_edges(decade_edges, movie_offset, decade_offset)

train_edge_index.shape, val_edge_index.shape, test_edge_index.shape

(torch.Size([2, 28438]), torch.Size([2, 3160]), torch.Size([2, 31598]))

In [59]:
class GCN(torch.nn.Module):
    """Stack of ``GCNConv`` layers with ReLU between layers.

    Args:
        in_channels: Width of the input node features.
        hidden_channels_list: Output width of each layer, in order; the last
            entry is the width of the final layer.
        concat_outputs: When True, ``forward`` returns the outputs of all
            layers concatenated along the feature dimension instead of only
            the last layer's output.
    """

    def __init__(self, in_channels, hidden_channels_list, concat_outputs=False):
        super().__init__()
        self.concat_outputs = concat_outputs
        # Consecutive (input width, output width) pairs define the layers.
        widths = [in_channels, *hidden_channels_list]
        self.layers_list = torch.nn.ModuleList(
            [GCNConv(fan_in, fan_out) for fan_in, fan_out in zip(widths, widths[1:])]
        )

    def forward(self, x, edge_index):
        """Run every layer in order; ReLU follows all layers except the last."""
        final_idx = len(self.layers_list) - 1
        # Per-layer outputs are only retained when they will be concatenated.
        collected = [] if self.concat_outputs else None
        for position, conv in enumerate(self.layers_list):
            x = conv(x, edge_index)
            if position != final_idx:
                x = F.relu(x)
            if collected is not None:
                collected.append(x)
        if collected is None:
            return x
        return torch.cat(collected, dim=1)
    
class GAT(torch.nn.Module):
    """Stack of ``GATConv`` layers with ELU between layers.

    Every layer uses ``heads`` attention heads. All layers except the last are
    built with ``concat=True``, so the following layer's input width is
    ``hidden_channels * heads``; the last layer is built with ``concat=False``.

    Args:
        in_channels: Width of the input node features.
        hidden_channels_list: Per-head output width of each layer, in order.
        concat_outputs: When True, ``forward`` returns the outputs of all
            layers concatenated along the feature dimension instead of only
            the last layer's output.
        heads: Number of attention heads passed to each ``GATConv``.
        dropout: Dropout value passed to each ``GATConv``.
    """

    def __init__(self, in_channels, hidden_channels_list, concat_outputs=False, heads=8, dropout=0.1):
        super().__init__()
        self.concat_outputs = concat_outputs
        self.layers_list = torch.nn.ModuleList()
        final_idx = len(hidden_channels_list) - 1
        input_width = in_channels
        for position, width in enumerate(hidden_channels_list):
            self.layers_list.append(
                GATConv(
                    input_width,
                    width,
                    heads=heads,
                    dropout=dropout,
                    concat=position != final_idx,
                )
            )
            # The next layer consumes `heads` blocks of `width` features each.
            input_width = width * heads

    def forward(self, x, edge_index):
        """Run every layer in order; ELU follows all layers except the last."""
        final_idx = len(self.layers_list) - 1
        # Per-layer outputs are only retained when they will be concatenated.
        collected = [] if self.concat_outputs else None
        for position, conv in enumerate(self.layers_list):
            x = conv(x, edge_index)
            if position != final_idx:
                x = F.elu(x)
            if collected is not None:
                collected.append(x)
        if collected is None:
            return x
        return torch.cat(collected, dim=1)

In [None]:
# Experiment configuration read by the training cell below.
# Only the uncommented entries are active; the commented-out entries are a log of
# earlier settings together with the validation / test scores recorded for them.
CONFIG = {
    # General Training Parameters
    'num_epochs': 10000,
    # NOTE(review): 'batch_size' is not referenced by the training code visible in this
    # notebook (each epoch runs on the full graph) — confirm whether it is still needed.
    'batch_size': 2**12,
    'learning_rate': 0.001,
    'patience': 200,
    'scale_function': 'sigmoid',  # Options: 'sigmoid', 'clamp'; sigmoid works better
    # Model Parameters
    # 'MODEL_TYPE': 'GCN',  # Options: 'GCN', 'GAT'
    
    # 'concat_outputs': False,
    
    # 'num_features': 32,
    
    # 'hidden_dims': (32,) * 2,
    # Best Val Loss: 0.9446, during training was 0.9446
    # Test Loss: 0.9472

    # 'hidden_dims': (32,) * 3,
    # Best Val Loss: 0.9264, during training was 0.9264
    # Test Loss: 0.9252

    # 'hidden_dims': (32,) * 4,
    # Best Val Loss: 0.9220, during training was 0.9220
    # Test Loss: 0.9198

    # 'hidden_dims': (32,) * 5,
    # Best Val Loss: 0.9241, during training was 0.9241
    # Test Loss: 0.9275

    # 'hidden_dims': (32,) * 6,
    # Best Val Loss: 0.9117, during training was 0.9117
    # Test Loss: 0.9166
    
    # 'hidden_dims': (32,) * 7,
    # Best Val Loss: 0.9141, during training was 0.9141
    # Test Loss: 0.9163

    # 'hidden_dims': (32,) * 8,
    # Best Val Loss: 0.9098, during training was 0.9098
    # Test Loss: 0.9146
    
    # 'hidden_dims': (32,) * 9,
    # Best Val Loss: 0.9112, during training was 0.9112
    # Test Loss: 0.9120
    
    # 'hidden_dims': (32,) * 10,
    # Best Val Loss: 0.9050, during training was 0.9050
    # Test Loss: 0.9131
    
    # 'hidden_dims': (32,) * 11,
    # Best Val Loss: 0.9281, during training was 0.9281
    # Test Loss: 0.9350
    
    # 'hidden_dims': (32,) * 12,
    # Best Val Loss: 0.9277, during training was 0.9277
    # Test Loss: 0.9338
    
    # 'hidden_dims': (32,) * 20,
    # Best Val Loss: 0.9258, during training was 0.9258
    # Test Loss: 0.9333
    
    # 'concat_outputs': True,
    
    # 'hidden_dims': (32,) * 10,
    # Best Val Loss: 0.9320, during training was 0.9320
    # Test Loss: 0.9385
    
    # 'hidden_dims': (32,) * 20,
    # Best Val Loss: 0.9306, during training was 0.9306
    # Test Loss: 0.9346

    # Test Num Features:
    # 'hidden_dims': (32,) * 6,
    # 'num_features': 32,
    # Best Val Loss: 0.9258, during training was 0.9258
    # Test Loss: 0.9327
    # 'num_features': 64,
    # Best Val Loss: 0.9185, during training was 0.9185
    # Test Loss: 0.9164
    # 'num_features': 128,
    # Best Val Loss: 0.9260, during training was 0.9260
    # Test Loss: 0.9337
    
    # 'concat_outputs' : False,
    # 'num_features': 64,
    # optimal depth seems to be 10
    # concatenating outputs seems to worsen performance
    
    # testing decreasing hidden dims with depth 10
    # 'hidden_dims': (128, 112, 96, 80, 64, 48, 32, 24, 16, 8),
    # Best Val Loss: 0.9150, during training was 0.9150
    # Test Loss: 0.9206
    
    # testing a wider network with depth 10
    # 'hidden_dims': (64,) * 10,
    # Best Val Loss: 0.9142, during training was 0.9142
    # Test Loss: 0.9152

    # 'hidden_dims': (128,) * 10,
    # Best Val Loss: 0.9104, during training was 0.9104
    # Test Loss: 0.9156
    
    # 'hidden_dims': (256,) * 10,
    # Best Val Loss: 0.9073, during training was 0.9073
    # Test Loss: 0.9115
    
    # 'hidden_dims': (512,) * 10,
    # Not trainable
    
    'MODEL_TYPE': 'GAT',

    # 'concat_outputs' : False,
    
    'num_features': 64,
    # 'hidden_dims': (32,) * 10,
    # Best Val Loss: 0.8985, during training was 0.8985
    # Test Loss: 0.8973
    
    # 'hidden_dims': (16,) * 8,
    # Best Val Loss: 0.8998, during training was 0.8998
    # Test Loss: 0.9065
    
    # 'hidden_dims': (8,) * 8,
    # Best Val Loss: 0.9299, during training was 0.9299
    # Test Loss: 0.9303
    
    # 'concat_outputs': True,
    # 'hidden_dims': (16,) * 8,
    # Not trainable
    #  'hidden_dims': (16,) * 6,
    # Not trainable
    # 'hidden_dims': (16,) * 4,
    # Best Val Loss: 0.9500, during training was 0.9500
    # Test Loss: 0.9521
    
    # concatenating outputs seems to worsen performance for GAT as well
    
    'concat_outputs': False,
    'hidden_dims': (32,) * 12,
    # Best Val Loss: 0.8983, during training was 0.8983
    # Test Loss: 0.9028

        
    # FOR GAT:
    'heads': 8,
    'dropout': 0.1,
}

In [84]:
print(CONFIG)

# Re-seed PyTorch at the top of the experiment cell. This cell is re-executed for
# every configuration, and without a fresh seed each re-run would continue from
# wherever the global RNG was left (different features, weights and dropout masks).
# No earlier visible cell draws from torch's RNG after the initial seeding, so a
# fresh "Restart & Run All" produces the same numbers as before this change.
torch.manual_seed(SEED)

# Fixed random input features, one row per node (users, movies, genres, decades).
# They are a plain tensor, not a parameter, so the optimizer never updates them.
x = torch.rand(num_users + num_movies + num_genres + num_decades, CONFIG['num_features'])

train_edge_index = train_edge_index.to(device)
val_edge_index = val_edge_index.to(device)
test_edge_index = test_edge_index.to(device)
val_actual_ratings = val_actual_ratings.to(device)
test_actual_ratings = test_actual_ratings.to(device)

genre_edge_index = genre_edge_index.to(device)
decade_edge_index = decade_edge_index.to(device)

# The message-passing graph holds only training rating edges plus the genre and
# decade metadata edges; validation and test rating edges are left out of it.
# NOTE(review): every relation is stored in one direction only (user->movie,
# movie->genre, movie->decade). With PyG's default source->target message flow,
# user nodes would presumably receive nothing but their own self-loop — confirm this
# is intended, or consider making the edge index undirected.
graph_data = Data(
    x=x,
    edge_index=torch.cat([train_edge_index, genre_edge_index, decade_edge_index], dim=1),
    y=train_actual_ratings,
)
graph_data = graph_data.to(device)

# Map raw dot-product scores to the rating range [0.5, 5.0].
scale_functions = {
    'sigmoid': lambda x: torch.sigmoid(x) * 4.5 + .5,
    'clamp': lambda x: x.clamp(.5, 5.0),
}
scale_function = scale_functions[CONFIG['scale_function']]

# Build the network selected by CONFIG['MODEL_TYPE'].
if CONFIG['MODEL_TYPE'] == 'GCN':
    model = GCN(in_channels=x.shape[1], hidden_channels_list=CONFIG['hidden_dims'], concat_outputs=CONFIG['concat_outputs'])
elif CONFIG['MODEL_TYPE'] == 'GAT':
    model = GAT(in_channels=x.shape[1], hidden_channels_list=CONFIG['hidden_dims'], concat_outputs=CONFIG['concat_outputs'], heads=CONFIG['heads'], dropout=CONFIG['dropout'])
else:
    # Fail loudly: otherwise a typo would either raise a NameError further down or,
    # worse, silently reuse a `model` left over from a previous run of this cell.
    raise ValueError(f"Unknown MODEL_TYPE {CONFIG['MODEL_TYPE']!r}; expected 'GCN' or 'GAT'")


model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])
# Mean squared error; the log lines below print its square root.
criterion = torch.nn.MSELoss()
# Early-stopping state: best validation loss so far, a copy of the weights that
# achieved it, and the number of consecutive epochs without improvement.
best_val_loss = float('inf')
best_model_state = None
val_loss_not_improved_count = 0
early_stopping_patience = CONFIG['patience']

# Full-graph training: every epoch performs one optimizer step over all training
# edges, followed by one evaluation pass on the validation edges.
print("Starting training...")
for epoch in range(CONFIG['num_epochs']):
    model.train()
    optimizer.zero_grad()
    embeddings = model(graph_data.x, graph_data.edge_index)
    # A rating is scored as the dot product of the user and movie embeddings.
    user_embeds = embeddings[train_edge_index[0]]
    movie_embeds = embeddings[train_edge_index[1]]
    logits = (user_embeds * movie_embeds).sum(dim=1)
    # Only the sigmoid scaling is applied to the training predictions; with 'clamp'
    # the raw scores are used for the loss, while evaluation below always applies
    # scale_function.
    if CONFIG['scale_function'] == 'sigmoid':
        predictions = scale_function(logits)
    else:
        predictions = logits
    loss = criterion(predictions, graph_data.y)
    loss.backward()
    optimizer.step()
    
    model.eval()
    with torch.no_grad():
        # Embeddings are recomputed in eval mode on the same graph (training edges
        # plus metadata edges); validation edges are only used for scoring.
        eval_embeddings = model(graph_data.x, graph_data.edge_index)
        val_user_embeds = eval_embeddings[val_edge_index[0]]
        val_movie_embeds = eval_embeddings[val_edge_index[1]]
        val_logits = (val_user_embeds * val_movie_embeds).sum(dim=1)
        val_predictions = scale_function(val_logits)
        val_loss = criterion(val_predictions, val_actual_ratings)
        if epoch % 20 == 0:
            # The printed values are square roots of the MSE losses.
            print(f'Epoch: {epoch:03d}, Loss: {loss**.5:.4f}, Val Loss: {val_loss**.5:.4f}')
        # Early stopping: snapshot the weights on every improvement, and stop after
        # `early_stopping_patience` consecutive epochs without one.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = copy.deepcopy(model.state_dict())
            val_loss_not_improved_count = 0
        else:
            val_loss_not_improved_count += 1
            if val_loss_not_improved_count >= early_stopping_patience:
                print(f'Early stopping at epoch {epoch}')
                break
print("Training finished.")

# Restore the weights from the epoch with the best validation loss, if any were saved.
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    model = model.to(device)

# Final evaluation of the restored model; printed values are square roots of the MSE.
model.eval()
with torch.no_grad():
    embeddings = model(graph_data.x, graph_data.edge_index)
    
    # Recompute the validation loss and print it next to the best value seen in training.
    val_user_embeds = embeddings[val_edge_index[0]]
    val_movie_embeds = embeddings[val_edge_index[1]]
    val_logits = (val_user_embeds * val_movie_embeds).sum(dim=1)
    val_predictions = scale_function(val_logits)
    val_loss = criterion(val_predictions, val_actual_ratings)
    print(f'Best Val Loss: {val_loss**.5:.4f}, during training was {best_val_loss**.5:.4f}')
    
    # Score the test rating edges with the same embeddings.
    test_user_embeds = embeddings[test_edge_index[0]]
    test_movie_embeds = embeddings[test_edge_index[1]]
    test_logits = (test_user_embeds * test_movie_embeds).sum(dim=1)
    test_predictions = scale_function(test_logits)
    test_loss = criterion(test_predictions, test_actual_ratings)
    print(f'Test Loss: {test_loss**.5:.4f}')

{'num_epochs': 10000, 'batch_size': 4096, 'learning_rate': 0.001, 'patience': 200, 'scale_function': 'sigmoid', 'MODEL_TYPE': 'GAT', 'num_features': 64, 'concat_outputs': False, 'hidden_dims': (32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32), 'heads': 8, 'dropout': 0.1}
Starting training...
Epoch: 000, Loss: 1.2541, Val Loss: 1.8023
Epoch: 020, Loss: 1.0314, Val Loss: 1.0214
Epoch: 040, Loss: 1.0173, Val Loss: 1.0167
Epoch: 060, Loss: 1.0042, Val Loss: 0.9926
Epoch: 080, Loss: 0.9951, Val Loss: 0.9853
Epoch: 100, Loss: 0.9908, Val Loss: 0.9799
Epoch: 120, Loss: 0.9875, Val Loss: 0.9718
Epoch: 140, Loss: 0.9709, Val Loss: 0.9667
Epoch: 160, Loss: 0.9647, Val Loss: 0.9699
Epoch: 180, Loss: 0.9598, Val Loss: 0.9641
Epoch: 200, Loss: 0.9679, Val Loss: 0.9494
Epoch: 220, Loss: 0.9500, Val Loss: 0.9465
Epoch: 240, Loss: 0.9477, Val Loss: 0.9434
Epoch: 260, Loss: 0.9455, Val Loss: 0.9561
Epoch: 280, Loss: 0.9501, Val Loss: 0.9419
Epoch: 300, Loss: 0.9377, Val Loss: 0.9476
Epoch: 320, Loss: 0.