### Testing multiple simple GNNs

#### Installations

In [None]:
# Install the packages listed in the requirements file located two directories
# above the working directory. `%pip` (rather than `!pip`) installs into the
# environment of the running kernel.
%pip install -r ../../requirements.txt

#### Imports

In [18]:
# general
import pathlib
import os
import sys
from tabulate import tabulate

# for py2neo and utils
from py2neo import Graph, Relationship
parent_path = pathlib.Path(os.getcwd()).parent.absolute()
sys.path.append(str(parent_path))

# for PyG
import torch
from torch_geometric.data import HeteroData
from torch_geometric.transforms import RandomLinkSplit, ToUndirected
from torch_geometric.nn import to_hetero
import torch.nn.functional as F
from utils_draft.pyg import load_node, load_edge, SequenceEncoder, IdentityEncoder, ListEncoder
from torch.nn import Linear
from torch_geometric.nn import SAGEConv

#### Connecting to the existing Neo4j instance

In [4]:
# Connect to the running Neo4j instance over Bolt.
# The URI and credentials can be overridden through the NEO4J_URI, NEO4J_USER
# and NEO4J_PASSWORD environment variables, so real secrets never have to be
# stored in the notebook. The fallbacks are the values this notebook originally
# hard-coded for the local development database.
graph = Graph(
    os.environ.get("NEO4J_URI", "bolt://localhost:11005"),
    auth=(
        os.environ.get("NEO4J_USER", "neo4j"),
        os.environ.get("NEO4J_PASSWORD", "admin"),
    ),
)

#### Observing the pre-stored graph

In [13]:
# Count the nodes and RATES relationships already stored in the database and
# show them as a small summary table.
RATES = Relationship.type("RATES")

movies_num = graph.nodes.match("Movie").count()
users_num = graph.nodes.match("User").count()
ratings_num = graph.relationships.match(r_type=RATES).count()

table = [
    [label, count]
    for label, count in (
        ("Movies", movies_num),
        ("Users", users_num),
        ("Ratings", ratings_num),
    )
]
print(tabulate(table, headers=["Type", "Count"], tablefmt="github"))

| Type    |   Count |
|---------|---------|
| Movies  |   58528 |
| Users   |    7801 |
| Ratings | 1028167 |


#### Building the PyG graph

In [16]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Movie nodes -----------------------------------------------------------
# Each returned column listed here is turned into features by its encoder.
movie_encoders = {
    "title": SequenceEncoder(),
    "genres": ListEncoder(sep="|"),
    "year": SequenceEncoder(),
}
movies_x, movies_mapping = load_node(
    graph=graph,
    query="""
        MATCH (m: Movie)
        return m.movieId as movieId, m.title as title, m.genres as genres, m.year as year
    """,
    index_col="movieId",
    encoders=movie_encoders,
)

# --- User nodes ------------------------------------------------------------
# The query also returns avg_rating and the number of ratings per user, but
# only the username column is encoded at the moment.
user_encoders = {
    "username": SequenceEncoder(),
}
users_x, users_mapping = load_node(
    graph=graph,
    query="""
        MATCH (u:User)-[r:RATES]-(m:Movie)
        return u.userId as userId, u.username as username, avg(r.rating) as avg_rating, count(r) as ratings;
    """,
    index_col="userId",
    encoders=user_encoders,
)

# --- Rating edges ----------------------------------------------------------
# The node mappings produced above are handed to load_edge for the source
# (user) and destination (movie) id columns. The rating is the only edge
# attribute that is encoded; datetime is returned by the query but unused.
rating_encoders = {
    "rating": IdentityEncoder(dtype=torch.long),
}
edge_index, edge_label = load_edge(
    graph=graph,
    query="""
        MATCH (u:User)-[r:RATES]-(m:Movie)
        return u.userId as userId, r.rating as rating, r.datetime as datetime, m.movieId as movieId;
    """,
    src_index_col="userId",
    src_mapping=users_mapping,
    dst_index_col="movieId",
    dst_mapping=movies_mapping,
    encoders=rating_encoders,
)

items: dict_items([('title', <utils.pyg.SequenceEncoder object at 0x000002D6692ED450>), ('genres', <utils.pyg.ListEncoder object at 0x000002D679DA9870>), ('year', <utils.pyg.SequenceEncoder object at 0x000002D6108A2BC0>)])


Batches:   0%|          | 0/1829 [00:00<?, ?it/s]

Batches:   0%|          | 0/1829 [00:00<?, ?it/s]

items: dict_items([('username', <utils.pyg.SequenceEncoder object at 0x000002D6692ED990>)])


Batches:   0%|          | 0/244 [00:00<?, ?it/s]

#### Building the dataset

In [17]:
# Edge-type keys used throughout this cell: the user -> movie rating relation
# and the reverse relation that exists after ToUndirected() has been applied.
review_edge = ("user", "reviews", "movie")
rev_review_edge = ("movie", "rev_reviews", "user")

# Assemble the heterogeneous graph from the tensors loaded above.
data = HeteroData()
data["user"].x = users_x
data["movie"].x = movies_x
data[review_edge].edge_index = edge_index
data[review_edge].edge_label = edge_label
data.to(device, non_blocking=True)

# Add the reverse edges, then drop the labels on the reverse relation so the
# ratings are only kept on the user -> movie edges.
to_undirected = ToUndirected()
data = to_undirected(data)
del data[rev_review_edge].edge_label

# Split the rating edges into train / validation / test sets (10% each for
# validation and test) with a negative sampling ratio of 0.
transform = RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    neg_sampling_ratio=0.0,
    edge_types=[review_edge],
    rev_edge_types=[rev_review_edge],
)
train_data, val_data, test_data = transform(data)

#### Building the model

In [40]:
class GNNEncoder(torch.nn.Module):
    """Stack of four SAGEConv layers with a ReLU after all but the last one.

    Args:
        hidden_channels: output size of conv1-conv3 (and input size of
            conv2 and conv3).
        out_channels: output size of the final layer, conv4.

    conv1 and conv4 are declared with ``(-1, -1)`` input sizes; conv2 and
    conv3 take ``hidden_channels`` inputs.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        lazy_in = (-1, -1)
        self.conv1 = SAGEConv(lazy_in, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, hidden_channels)
        self.conv4 = SAGEConv(lazy_in, out_channels)

    def forward(self, x, edge_index):
        """Run the four convolutions over ``edge_index`` and return the result."""
        h = x
        # The first three layers are each followed by a ReLU.
        for conv in (self.conv1, self.conv2, self.conv3):
            h = conv(h, edge_index).relu()
        # The last layer is left without an activation.
        return self.conv4(h, edge_index)
    
class EdgeDecoder(torch.nn.Module):
    """MLP that turns a (user, movie) embedding pair into one scalar per edge.

    Args:
        hidden_channels: width of each node embedding and of the hidden
            layers; the first layer takes the two embeddings concatenated.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        pair_width = 2 * hidden_channels
        self.lin1 = Linear(pair_width, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)
        self.lin3 = Linear(hidden_channels, hidden_channels)
        self.lin4 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        """Score every edge in ``edge_label_index``.

        Args:
            z_dict: mapping with 'user' and 'movie' embedding tensors.
            edge_label_index: pair of index rows (user indices, movie indices).

        Returns:
            A flat tensor with one prediction per edge.
        """
        user_idx, movie_idx = edge_label_index
        user_emb = z_dict['user'][user_idx]
        movie_emb = z_dict['movie'][movie_idx]
        h = torch.cat([user_emb, movie_emb], dim=-1)

        # Three hidden layers with ReLU, then a linear read-out.
        for layer in (self.lin1, self.lin2, self.lin3):
            h = torch.relu(layer(h))
        return self.lin4(h).view(-1)

In [41]:
class Model(torch.nn.Module):
    """Rating predictor: a heterogeneous GNN encoder followed by an edge decoder.

    Args:
        hidden_channels: width used for the encoder layers, the encoder
            output and the decoder.
        metadata: graph metadata passed to ``to_hetero``. When omitted, the
            metadata of the notebook-level ``data`` object is used, which is
            what this class always did before the parameter existed.
    """

    def __init__(self, hidden_channels, metadata=None):
        super().__init__()
        if metadata is None:
            # Backward-compatible default: rely on the global `data` graph.
            metadata = data.metadata()
        # Build the homogeneous encoder and convert it so that it operates on
        # the per-type feature / edge-index dictionaries.
        homogeneous_encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(homogeneous_encoder, metadata, aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        """Encode all nodes, then decode one prediction per labelled edge."""
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)

#### Default Loss function

#### Training

In [42]:
# Per-rating loss weights: count how often each rating value occurs among the
# training edges and divide the largest count by each count, so the most
# frequent rating gets weight 1 and rarer ratings get larger weights.
label_counts = torch.bincount(train_data['user', 'movie'].edge_label)
weight = label_counts.max() / label_counts

def weighted_mse_loss(pred, target, weight=None):
    """Mean squared error with an optional per-target-value weight.

    Args:
        pred: predicted values.
        target: target values; used as an index into ``weight`` when a
            weight tensor is given.
        weight: optional tensor looked up by target value; ``None`` means
            every sample is weighted by 1.

    Returns:
        The mean of ``weight * (pred - target) ** 2`` in ``pred``'s dtype.
    """
    if weight is None:
        sample_weight = 1.
    else:
        sample_weight = weight[target].to(pred.dtype)
    squared_error = (pred - target.to(pred.dtype)).pow(2)
    return torch.mean(sample_weight * squared_error)

In [52]:
class ModelBuilderTrainerTester():
    """Builds a ``Model`` and bundles it with its optimizer, data splits and loops.

    Args:
        train_data: training split; also used to initialise the lazy encoder
            parameters and to compute the per-rating loss weights.
        val_data: validation split.
        test_data: test split.
        hidden_channels: hidden width passed on to ``Model``.
    """

    def __init__(self, train_data, val_data, test_data, hidden_channels=32):
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data

        model = Model(hidden_channels=hidden_channels).to(device)
        # Due to lazy initialization, one encoder pass is needed so the number
        # of parameters can be inferred. Do it once here, before the optimizer
        # is handed the parameters, instead of on every train_test() call.
        with torch.no_grad():
            model.encoder(train_data.collect('x'), train_data.edge_index_dict)
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        # Per-rating loss weights: largest rating count divided by each count.
        label_counts = torch.bincount(train_data['user', 'movie'].edge_label)
        self.weight = label_counts.max() / label_counts

    def train(self):
        """Run one full-batch optimisation step and return the training loss."""
        self.model.train()
        self.optimizer.zero_grad()
        pred = self.model(
            self.train_data.collect('x'),
            self.train_data.edge_index_dict,
            self.train_data['user', 'movie'].edge_label_index
        )
        target = self.train_data['user', 'movie'].edge_label
        # Use the weights computed for *this* instance's training split, not a
        # notebook-level global that may be stale or missing.
        loss = weighted_mse_loss(pred, target, self.weight)
        loss.backward()
        self.optimizer.step()
        return float(loss)

    @torch.no_grad()
    def test(self, data):
        """Return the RMSE of the predictions (clamped to [0, 5]) on ``data``."""
        self.model.eval()
        pred = self.model(
            data.collect('x'),
            data.edge_index_dict,
            data['user', 'movie'].edge_label_index,
        )
        pred = pred.clamp(min=0, max=5)
        target = data['user', 'movie'].edge_label.float()
        rmse = F.mse_loss(pred, target).sqrt()
        return float(rmse)

    def train_test(self, epochs, log_every=1):
        """Train for ``epochs`` epochs, evaluating all three splits each epoch.

        Args:
            epochs: number of optimisation steps to run (epochs are numbered
                1..epochs inclusive).
            log_every: print the metrics every ``log_every`` epochs and on the
                final epoch. The default of 1 prints every epoch.

        Raises:
            ValueError: if ``log_every`` is smaller than 1.
        """
        if log_every < 1:
            raise ValueError("log_every must be >= 1")

        # range(1, epochs + 1) so that exactly `epochs` steps are executed.
        for epoch in range(1, epochs + 1):
            loss = self.train()
            train_rmse = self.test(self.train_data)
            val_rmse = self.test(self.val_data)
            test_rmse = self.test(self.test_data)
            if epoch % log_every == 0 or epoch == epochs:
                print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
                    f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')

In [53]:
# Build the model together with its optimizer around the three link splits.
my_model = ModelBuilderTrainerTester(train_data, val_data, test_data, hidden_channels=32)

In [54]:
# First training run on the freshly built model; the loss and the
# train / validation / test RMSE are printed as training progresses.
my_model.train_test(epochs=100)

Epoch: 001, Loss: 17.9087, Train: 3.3317, Val: 3.3321, Test: 3.3337
Epoch: 002, Loss: 16.9785, Train: 3.1763, Val: 3.1769, Test: 3.1784
Epoch: 003, Loss: 15.4536, Train: 2.6599, Val: 2.6614, Test: 2.6627
Epoch: 004, Loss: 11.1610, Train: 1.2090, Val: 1.2096, Test: 1.2109
Epoch: 005, Loss: 6.5565, Train: 1.1178, Val: 1.1136, Test: 1.1152
Epoch: 006, Loss: 8.5515, Train: 1.7520, Val: 1.7542, Test: 1.7554
Epoch: 007, Loss: 6.6612, Train: 2.1790, Val: 2.1808, Test: 2.1821
Epoch: 008, Loss: 8.2629, Train: 2.2224, Val: 2.2242, Test: 2.2255
Epoch: 009, Loss: 8.4783, Train: 2.0221, Val: 2.0240, Test: 2.0252
Epoch: 010, Loss: 7.5580, Train: 1.5897, Val: 1.5915, Test: 1.5927
Epoch: 011, Loss: 6.3201, Train: 1.1421, Val: 1.1408, Test: 1.1424
Epoch: 012, Loss: 6.9470, Train: 1.1117, Val: 1.1093, Test: 1.1111
Epoch: 013, Loss: 7.4129, Train: 1.3300, Val: 1.3308, Test: 1.3322
Epoch: 014, Loss: 6.2193, Train: 1.6713, Val: 1.6730, Test: 1.6743
Epoch: 015, Loss: 6.4557, Train: 1.8490, Val: 1.8507, Test

In [55]:
# Continue training the same `my_model` instance (its weights are not reset).
# train_test() restarts its printed epoch counter at 1 on every call.
my_model.train_test(epochs=1000)


Epoch: 001, Loss: 5.3262, Train: 1.3472, Val: 1.3461, Test: 1.3489
Epoch: 002, Loss: 5.3191, Train: 1.3385, Val: 1.3375, Test: 1.3402
Epoch: 003, Loss: 5.3120, Train: 1.3320, Val: 1.3310, Test: 1.3337
Epoch: 004, Loss: 5.3053, Train: 1.3445, Val: 1.3436, Test: 1.3462
Epoch: 005, Loss: 5.2984, Train: 1.3274, Val: 1.3264, Test: 1.3290
Epoch: 006, Loss: 5.2906, Train: 1.3409, Val: 1.3400, Test: 1.3425
Epoch: 007, Loss: 5.2820, Train: 1.3314, Val: 1.3304, Test: 1.3330
Epoch: 008, Loss: 5.2730, Train: 1.3367, Val: 1.3357, Test: 1.3383
Epoch: 009, Loss: 5.2640, Train: 1.3302, Val: 1.3292, Test: 1.3317
Epoch: 010, Loss: 5.2544, Train: 1.3404, Val: 1.3394, Test: 1.3419
Epoch: 011, Loss: 5.2447, Train: 1.3088, Val: 1.3078, Test: 1.3102
Epoch: 012, Loss: 5.2376, Train: 1.4075, Val: 1.4067, Test: 1.4091
Epoch: 013, Loss: 5.2610, Train: 1.1687, Val: 1.1675, Test: 1.1695
Epoch: 014, Loss: 5.5114, Train: 1.7336, Val: 1.7329, Test: 1.7354
Epoch: 015, Loss: 5.9720, Train: 1.4048, Val: 1.4042, Test: 1.