## Imports

In [15]:
import pathlib
import os
import os.path as osp
import sys
import argparse
parent_path = pathlib.Path(os.getcwd()).parent.absolute()
sys.path.append(str(parent_path))
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.nn import Linear
import torch_geometric.transforms as T
from torch_geometric.datasets import MovieLens
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.loader import NeighborLoader, LinkNeighborLoader

from utils.Neo4jMovieLensMetaData import Neo4jMovieLensMetaData
# from utils.gnn_simple import Model
from utils.train_test import train_test
from utils.visualize import plot_loss, plot_train, plot_val, plot_test, plot_results

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### Read the corresponsing csv, store the dataset to the DB, preprocess it, and get it as a pytorch graph object

In [17]:
path = osp.join(osp.dirname(osp.abspath('')), '../../data/MovieLensNeo4j')
dataset = Neo4jMovieLensMetaData(
    path,
    model_name='all-MiniLM-L6-v2',
    database_url="bolt://localhost:7687",
    database_username="neo4j",
    database_password="admin",
    force_pre_process=True,
    force_db_restore=False,
    text_features=["title"],
    list_features=[],
    fastRP_features=[],
    numeric_features=[],
)
data = dataset[0].to(device)

Processing...


Movies have features...
Encoding title...


Batches: 100%|██████████| 565/565 [00:15<00:00, 36.70it/s]


[torch.Size([18062, 64])]


Done!


#### Preprocess the dataset

In [18]:
# Add user node features for message passing:
data['user'].x = torch.eye(data['user'].num_nodes, device=device)
del data['user'].num_nodes

# Add a reverse ('movie', 'rev_rates', 'user') relation for message passing:
data = T.ToUndirected()(data)
del data['movie', 'rev_rates', 'user'].edge_label  # Remove "reverse" label.

# Perform a link-level split into training, validation, and test edges:
train_data, val_data, test_data = T.RandomLinkSplit(
    num_val=0.7,
    num_test=0.1,
    neg_sampling_ratio=0.0,
    edge_types=[('user', 'rates', 'movie')],
    rev_edge_types=[('movie', 'rev_rates', 'user')],
)(data)

In [19]:
train_data['user', 'movie'].edge_label

tensor([4, 4, 4,  ..., 3, 4, 4])

In [21]:
# train_loader = LinkNeighborLoader(
#     data=train_data,
#     # Sample ALL neighbors for each node and each edge type for 2 iterations:
#     num_neighbors=[-1],
#     # Use a batch size of 128 for sampling training nodes of type "paper":
#     batch_size=1024,
#     edge_label_index = (("user", "movie"), None),
#     edge_label = train_data['user', 'movie'].edge_label,
# )
# train_batch = next(iter(train_loader))


edge_label_index = train_data["user", "rates", "movie"].edge_label_index
edge_label = train_data["user", "rates", "movie"].edge_label

train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[15, 5],
    neg_sampling_ratio=2.0,
    edge_label_index=(("user", "rates", "movie"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

train_batch = next(iter(train_loader))

In [23]:
# test_loader = LinkNeighborLoader(
#     data=test_data,
#     # Sample ALL neighbors for each node and each edge type for 2 iterations:
#     num_neighbors=[-1],
#     # Use a batch size of 128 for sampling training nodes of type "paper":
#     batch_size=1024,
#     edge_label_index = (("user", "movie"), None),
#     edge_label = test_data['user', 'movie'].edge_label,
# )
# test_batch = next(iter(test_loader))

test_loader = LinkNeighborLoader(
    data=test_data,
    num_neighbors=[15],
    neg_sampling_ratio=2.0,
    edge_label_index=(("user", "rates", "movie"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

test_batch = next(iter(test_loader))

#### Define and train-test the model

In [24]:
import torch
from torch import Tensor
from torch.nn import Linear, LazyLinear, Sequential, BatchNorm1d, ReLU, Dropout
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, GATv2Conv, GCNConv, TransformerConv, GraphConv, GINConv, GINEConv, to_hetero, HeteroLinear, HeteroConv
from torch_geometric.nn.models import GIN, GraphSAGE
from torch_geometric.nn.aggr import MultiAggregation
from typing import Union
from torch_geometric.typing import Adj, OptPairTensor, Size


class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x


class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        global z_dict_test_lala, edge_label_index_test_lala
        row, col = edge_label_index
        z_dict_test_lala = z_dict
        edge_label_index_test_lala = edge_label_index
        movie = z_dict['movie'][col]
        user = z_dict['user'][row]
        z = torch.cat([user, movie], dim=-1)

        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)


class Model(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)

In [25]:
model = Model(hidden_channels=16).to(device)

In [26]:
with torch.no_grad():
    model.encoder(train_batch.x_dict, train_batch.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.012)

weight = None
# weight = torch.bincount(train_data['user', 'movie'].edge_label)
# weight = weight.max() / weight

def weighted_mse_loss(pred, target, weight=None):
    weight = 1. if weight is None else weight[target].to(pred.dtype)
    return (weight * (pred - target.to(pred.dtype)).pow(2)).mean()

In [27]:
def train(batch, log=False):
    pred = model(batch.x_dict, batch.edge_index_dict, batch['user', 'movie'].edge_label_index)
    pred = pred.clamp(min=0, max=5)
    target = batch['user', 'movie'].edge_label.float()
    loss = weighted_mse_loss(pred, target, weight)
    loss.backward()
    optimizer.step()
    return float(loss)

@torch.no_grad()
def test(batch, log=False):
    pred = model(batch.x_dict, batch.edge_index_dict, batch['user', 'movie'].edge_label_index)
    pred = pred.clamp(min=0, max=5)
    target = batch['user', 'movie'].edge_label.float()
    rmse = F.mse_loss(pred, target).sqrt()
    return float(rmse)

In [28]:
def train_epoch():
    model.train()
    total_loss = batch_index = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        batch = batch.to(device, 'edge_index')
        loss = train(batch)
        # print(f'Batch {batch_index:03d}, Train Loss: {loss:.4f}')
        total_loss += loss
        # train_rmse = test(batch)
        # test_rmse = test()
        # losses.append((loss, train_rmse, test_rmse))
        # losses.append((loss, train_rmse, test_rmse))
        # print(f'Batch {batch_index:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, Test: {test_rmse:.4f}')
        batch_index += 1
    return total_loss / batch_index if batch_index else None

In [29]:
def test_epoch():
    model.eval()
    total_loss = batch_index = 0
    for batch in tqdm(test_loader):
        batch = batch.to(device, 'edge_index')
        loss = test(batch)
        # print(f'Batch {batch_index:03d}, Test Loss: {loss:.4f}')
        total_loss += loss
        # train_rmse = test(batch)
        # test_rmse = test()
        # losses.append((loss, train_rmse, test_rmse))
        # losses.append((loss, train_rmse, test_rmse))
        # print(f'Batch {batch_index:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, Test: {test_rmse:.4f}')
        batch_index += 1
    return total_loss / batch_index if batch_index else None

In [30]:
epochs = 2
losses = []
epoch_index = 1
for epoch in range(1, epochs+1):
        train_loss = train_epoch()
        test_loss = test_epoch()
                # train_rmse = test(batch)
                # test_rmse = test()
                # losses.append((loss, train_rmse, test_rmse))
                # losses.append((loss, train_rmse, test_rmse))
        print(f'Batch {epoch_index:03d}, Loss: {train_loss:.4f}, Test: {test_loss:.4f}')
        epoch_index += 1

100%|██████████| 4158/4158 [23:16<00:00,  2.98it/s]
100%|██████████| 4158/4158 [06:45<00:00, 10.27it/s]


Batch 001, Loss: 7.1764, Test: 2.6784


100%|██████████| 4158/4158 [23:03<00:00,  3.01it/s]
100%|██████████| 4158/4158 [06:39<00:00, 10.42it/s]

Batch 002, Loss: 7.1766, Test: 2.6784



