### Testing multiple simple GNNs

#### Installations

In [None]:
# Install the packages listed in the repository-level requirements file
# (two directories above this notebook) using the %pip magic.
%pip install -r ../../requirements.txt

#### Imports

In [23]:
# general
import pathlib
import os
import sys
from tabulate import tabulate

# for py2neo and utils
from py2neo import Graph, Relationship
parent_path = pathlib.Path(os.getcwd()).parent.absolute()
sys.path.append(str(parent_path))

# for PyG
import torch
from torch_geometric.data import HeteroData
from torch_geometric.transforms import RandomLinkSplit, ToUndirected
from torch_geometric.nn import to_hetero
import torch.nn.functional as F
from utils_draft.pyg import load_node, load_edge, SequenceEncoder, IdentityEncoder, ListEncoder
from torch.nn import Linear
from torch_geometric.nn import SAGEConv, GATv2Conv, GCNConv, TransformerConv

#### Connecting to the existing neo4j instance

In [4]:
# Connection settings can be overridden through environment variables so that
# real credentials never need to be committed with the notebook. The fallback
# values are the ones previously hard-coded for the local instance.
graph = Graph(
    os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
    auth=(
        os.environ.get("NEO4J_USER", "neo4j"),
        os.environ.get("NEO4J_PASSWORD", "admin"),
    ),
)

#### Observing the pre-stored graph

In [5]:
# Relationship type used to filter the rating edges below.
RATES = Relationship.type("RATES")

# Node counts per label (movies first, then users) and the number of ratings.
movies_num, users_num = (
    graph.nodes.match(label).count() for label in ("Movie", "User")
)
ratings_num = graph.relationships.match(r_type=RATES).count()

# One [name, count] row per entity type, rendered as a GitHub-style table.
table = list(map(list, zip(
    ("Movies", "Users", "Ratings"),
    (movies_num, users_num, ratings_num),
)))
print(tabulate(table, headers=["Type", "Count"], tablefmt="github"))

| Type    |   Count |
|---------|---------|
| Movies  |    9460 |
| Users   |     610 |
| Ratings |   96150 |


#### Building the PyG graph

In [6]:
# Use the GPU when one is available; later cells move the data and the models
# onto this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Movie nodes: title, genres and year are each passed through an encoder.
# NOTE(review): load_node presumably returns (feature tensor, id -> row-index
# mapping) -- confirm in utils_draft.pyg.
movies_x, movies_mapping = load_node(
    graph=graph,
    query="""
        MATCH (m: Movie)
        return m.movieId as movieId, m.title as title, m.genres as genres, m.year as year
    """,
    index_col="movieId",
    encoders={
        "title": SequenceEncoder(),
        "genres": ListEncoder(sep="|"),
        "year": SequenceEncoder(),
    }
)

# User nodes: the query also aggregates avg_rating and ratings per user, but
# only the username is currently encoded (the other encoders are disabled).
users_x, users_mapping = load_node(
    graph=graph,
    query="""
        MATCH (u:User)-[r:RATES]-(m:Movie)
        return u.userId as userId, u.username as username, avg(r.rating) as avg_rating, count(r) as ratings;
    """,
    index_col="userId",
    encoders={
        # "avg_rating": IdentityEncoder(dtype=torch.float16),
        # "ratings": IdentityEncoder(dtype=torch.int64),
        "username": SequenceEncoder(),
    }
)

# Rating edges from users to movies. The node mappings built above are handed
# to load_edge for the source (userId) and destination (movieId) columns.
# The rating is requested as torch.long; the training cells call
# torch.bincount on this label, which needs integer values.
# NOTE(review): if ratings can be fractional, the cast to long presumably
# truncates them -- confirm against IdentityEncoder.
edge_index, edge_label = load_edge(
    graph=graph,
    query="""
        MATCH (u:User)-[r:RATES]-(m:Movie)
        return u.userId as userId, r.rating as rating, r.datetime as datetime, m.movieId as movieId;
    """,
    src_index_col="userId",
    src_mapping=users_mapping,
    dst_index_col="movieId",
    dst_mapping=movies_mapping,
    encoders={
        "rating": IdentityEncoder(dtype=torch.long),
        # "datetime": IdentityEncoder(dtype=torch.long),
    }
)

items: dict_items([('title', <utils.pyg.SequenceEncoder object at 0x00000226309C5240>), ('genres', <utils.pyg.ListEncoder object at 0x000002262FA2B580>), ('year', <utils.pyg.SequenceEncoder object at 0x00000226309D0CA0>)])


Batches:   0%|          | 0/296 [00:00<?, ?it/s]

Batches:   0%|          | 0/296 [00:00<?, ?it/s]

items: dict_items([('username', <utils.pyg.SequenceEncoder object at 0x00000226309C5240>)])


Batches:   0%|          | 0/20 [00:00<?, ?it/s]

#### Building the dataset

In [8]:
# Assemble the heterogeneous graph: user and movie feature matrices plus the
# ("user", "reviews", "movie") edges, with the rating stored as the edge label.
data = HeteroData()
data["user"].x = users_x
data["movie"].x = movies_x
data["user", "reviews", "movie"].edge_index = edge_index
data["user", "reviews", "movie"].edge_label = edge_label
# The return value is not captured here.
# NOTE(review): this relies on HeteroData.to() moving the tensors in place -- confirm.
data.to(device, non_blocking=True)
# Add the reverse ("movie", "rev_reviews", "user") edges.
data = ToUndirected()(data)
# Drop the label on the reverse edge type so that only the
# ("user", "reviews", "movie") edges carry a rating label.
del data["movie", "rev_reviews", "user"].edge_label

# Split the rating edges into train/validation/test sets with 10% each for
# validation and test. neg_sampling_ratio=0.0 requests no negative edges.
# NOTE(review): no random seed is set, so the split presumably differs
# between runs -- confirm whether reproducibility is required.
transform = RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    neg_sampling_ratio=0.0,
    edge_types=[("user", "reviews", "movie")],
    rev_edge_types=[("movie", "rev_reviews", "user")],
)
train_data, val_data, test_data = transform(data)

#### Building the model

In [24]:
# Registry mapping a short layer name to the PyG convolution class it selects.
layers = dict(zip(
    ("SAGE", "GAT", "GCN", "Transformer"),
    (SAGEConv, GATv2Conv, GCNConv, TransformerConv),
))

In [None]:
class GNNEncoder1(torch.nn.Module):
    """Four stacked graph convolutions with ReLU between them.

    Args:
        in_channels: Input feature size passed to the first convolution.
        hidden_channels: Output size of the first three convolutions.
        out_channels: Output size of the last convolution.
        layer_name: Key into the module-level ``layers`` registry; an unknown
            name falls back to ``SAGEConv``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, layer_name="SAGE"):
        super().__init__()
        conv_cls = layers.get(layer_name) or SAGEConv
        # Consecutive pairs of widths give the (input, output) size of conv1..conv4.
        widths = (in_channels, hidden_channels, hidden_channels, hidden_channels, out_channels)
        for position, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"conv{position}", conv_cls(fan_in, fan_out))

    def forward(self, x, edge_index):
        """Apply conv1..conv3 with ReLU, then conv4 without an activation."""
        hidden = x
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = conv(hidden, edge_index).relu()
        return self.conv4(hidden, edge_index)

class GNNEncoder2(torch.nn.Module):
    """Stack of ``num_layers`` graph convolutions with ReLU between them.

    Args:
        in_channels: Input feature size passed to the first convolution.
        hidden_channels: Feature size between consecutive convolutions.
        out_channels: Output size of the last convolution.
        layer_name: Key into the module-level ``layers`` registry; an unknown
            name falls back to ``SAGEConv``.
        num_layers: Total number of convolutions (at least 1). With a single
            layer the convolution maps ``in_channels`` straight to
            ``out_channels``.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, layer_name="SAGE", num_layers=4):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        layer = layers.get(layer_name) or SAGEConv
        # Widths at the boundaries of the stack. Pairing neighbours yields
        # exactly `num_layers` convolutions; the previous construction always
        # built at least two layers, even for num_layers < 2.
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = torch.nn.ModuleList(
            [layer(dim_in, dim_out) for dim_in, dim_out in zip(dims, dims[1:])]
        )

    def forward(self, x, edge_index):
        """Apply every convolution; all but the last are followed by ReLU."""
        for i in range(len(self.convs) - 1):
            x = self.convs[i](x, edge_index).relu()
        x = self.convs[-1](x, edge_index)
        return x

class GCNEncoder(torch.nn.Module):
    """Four stacked ``GCNConv`` layers with ReLU between them.

    The first layer is created with an input size of ``-1``.

    NOTE(review): confirm that ``GCNConv`` behaves as intended once this
    encoder is converted with ``to_hetero`` on the user/movie edge types.

    Args:
        hidden_channels: Output size of the first three convolutions.
        out_channels: Output size of the last convolution.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # Consecutive pairs of widths give the (input, output) size of conv1..conv4.
        widths = (-1, hidden_channels, hidden_channels, hidden_channels, out_channels)
        for position, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"conv{position}", GCNConv(fan_in, fan_out))

    def forward(self, x, edge_index):
        """Apply conv1..conv3 with ReLU, then conv4 without an activation."""
        hidden = x
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = conv(hidden, edge_index).relu()
        return self.conv4(hidden, edge_index)
    
class EdgeDecoder(torch.nn.Module):
    """MLP that predicts one scalar per (user, movie) pair from node embeddings.

    For every column of ``edge_label_index`` the user and movie embeddings are
    concatenated and passed through four linear layers, with ReLU after the
    first three.

    Args:
        hidden_channels: Width of the hidden layers of the MLP.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        # The first layer consumes the concatenation of a user and a movie
        # embedding, so its input width is the sum of both embedding sizes,
        # not `hidden_channels`. A fixed Linear(hidden_channels, ...) only
        # worked when hidden_channels happened to equal that sum. The lazy
        # layer infers the input width on the first forward pass.
        self.lin1 = torch.nn.LazyLinear(hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)
        self.lin3 = Linear(hidden_channels, hidden_channels)
        self.lin4 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        """Score the edges listed in ``edge_label_index``.

        Args:
            z_dict: Mapping with ``'user'`` and ``'movie'`` embedding tensors.
            edge_label_index: Pair of index tensors (user rows, movie rows).

        Returns:
            1-D tensor with one prediction per edge.
        """
        row, col = edge_label_index
        z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)

        z = self.lin1(z).relu()
        z = self.lin2(z).relu()
        z = self.lin3(z).relu()
        z = self.lin4(z)
        return z.view(-1)

In [26]:
class Model(torch.nn.Module):
    """Heterogeneous GNN encoder followed by an edge-level decoder.

    Args:
        in_channels: Input size forwarded to ``GNNEncoder1`` (ignored for
            ``layer_name == "GCN"``, where ``GCNEncoder`` is used instead).
        hidden_channels: Hidden width of the encoder and the decoder.
        out_channels: Size of the node embeddings produced by the encoder.
        layer_name: ``"GCN"`` selects ``GCNEncoder``; any other value is
            forwarded to ``GNNEncoder1``.
        metadata: ``(node_types, edge_types)`` passed to ``to_hetero``. When
            omitted, the metadata of the notebook-level ``data`` object is
            used, as before.
    """

    def __init__(self, in_channels=(-1, -1), hidden_channels=32, out_channels=32, layer_name="SAGE", metadata=None):
        super().__init__()
        if layer_name == "GCN":
            encoder = GCNEncoder(
                hidden_channels=hidden_channels,
                out_channels=out_channels
            )
        else:
            encoder = GNNEncoder1(
                in_channels=in_channels,
                hidden_channels=hidden_channels,
                out_channels=out_channels,
                layer_name=layer_name,
            )
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        # Fall back to the notebook-level graph only when the caller did not
        # say which node/edge types the encoder should be expanded for.
        if metadata is None:
            metadata = data.metadata()
        self.encoder = to_hetero(encoder, metadata, aggr='sum')
        # NOTE(review): the decoder is sized from hidden_channels only; verify
        # that its first layer accepts the concatenated user/movie embedding
        # (2 * out_channels wide).
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        """Encode all nodes, then score the edges in ``edge_label_index``."""
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)

#### Default Loss function

#### Training

In [27]:
# Per-rating weights for the loss: count how often each integer rating occurs
# among the training edges, then divide the largest count by every count so
# that rarer ratings receive proportionally larger weights.
weight = train_data['user', 'movie'].edge_label.bincount()
weight = weight.max().div(weight)

def weighted_mse_loss(pred, target, weight=None):
    """Mean squared error with an optional per-target-value weight.

    Args:
        pred: Predicted values.
        target: Target values; also used to index ``weight`` when it is given.
        weight: Optional tensor indexed by ``target``; ``None`` weights every
            element equally.

    Returns:
        Scalar tensor holding the mean of the (weighted) squared errors.
    """
    if weight is None:
        scale = 1.
    else:
        # Look up the weight of each element's target value.
        scale = weight[target].to(pred.dtype)
    squared_error = (pred - target.to(pred.dtype)).pow(2)
    return (scale * squared_error).mean()

In [28]:
class ModelBuilderTrainerTester():
    """Builds a ``Model``, trains it on rating edges and reports RMSE.

    Args:
        train_data: Training split; also used to derive the per-rating loss
            weights.
        val_data: Validation split.
        test_data: Test split.
        in_channels: Forwarded to ``Model``.
        hidden_channels: Forwarded to ``Model``.
        out_channels: Forwarded to ``Model``.
        layer_name: Forwarded to ``Model``.
        lr: Learning rate of the Adam optimizer.
    """

    def __init__(self, train_data, val_data, test_data, in_channels=(-1, -1), hidden_channels=32, out_channels=32, layer_name="SAGE", lr=0.01):
        model = Model(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            layer_name=layer_name
        ).to(device)
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.layer_name = layer_name
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        # Weight each rating by (largest count / its own count) among the
        # training edges handed to this instance.
        weight = torch.bincount(train_data['user', 'movie'].edge_label)
        self.weight = weight.max() / weight

    def train(self):
        """Run one full-batch optimisation step and return the training loss."""
        self.model.train()
        self.optimizer.zero_grad()
        pred = self.model(
            self.train_data.collect('x'),
            self.train_data.edge_index_dict,
            self.train_data['user', 'movie'].edge_label_index
        )
        target = self.train_data['user', 'movie'].edge_label
        # Use the weights derived from this instance's training data rather
        # than whatever notebook-level `weight` happens to be defined.
        loss = weighted_mse_loss(pred, target, self.weight)
        loss.backward()
        self.optimizer.step()
        return float(loss)

    @torch.no_grad()
    def test(self, data):
        """Return the RMSE of the predictions on ``data``, clamped to [0, 5]."""
        self.model.eval()
        pred = self.model(
            data.collect('x'),
            data.edge_index_dict,
            data['user', 'movie'].edge_label_index,
        )
        pred = pred.clamp(min=0, max=5)
        target = data['user', 'movie'].edge_label.float()
        rmse = F.mse_loss(pred, target).sqrt()
        return float(rmse)

    def train_test(self, epochs):
        """Train for ``epochs`` epochs and print progress every 10th epoch.

        After every epoch the RMSE on the train, validation and test splits is
        computed; a summary of the last epoch is printed at the end.

        Args:
            epochs: Number of training epochs to run.
        """
        # Due to lazy initialization, we need to run one model step so the number
        # of parameters can be inferred:
        with torch.no_grad():
            self.model.encoder(self.train_data.collect('x'), self.train_data.edge_index_dict)

        last_epoch = 0
        loss = train_rmse = val_rmse = test_rmse = None
        # range(1, epochs + 1) runs exactly `epochs` iterations; the previous
        # range(1, epochs) stopped one epoch early.
        for epoch in range(1, epochs + 1):
            loss = self.train()
            train_rmse = self.test(self.train_data)
            val_rmse = self.test(self.val_data)
            test_rmse = self.test(self.test_data)
            last_epoch = epoch
            if not epoch%10:
                print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
                    f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')
        print("hidden_channels:", self.hidden_channels)
        print("layers:", self.layer_name)
        if last_epoch == 0:
            # Nothing was trained, so there are no metrics to report.
            print("No epochs were run.")
            return
        print("Final epoch:")
        print(f'Epoch: {last_epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
            f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')

#### Some tests

In [21]:
# Configurations to compare: one entry per convolution type, all using
# 64 hidden channels.
models = [
    {"layer_name": conv_name, "hidden_channels": 64}
    for conv_name in ("SAGE", "GAT")
]

In [35]:
# Train and evaluate one model per configuration in `models`, printing a
# header before and a separator after each run.
for model in models:
    layer_name, hidden_channels = model.get("layer_name"), model.get("hidden_channels")
    print(layer_name, "|", hidden_channels, "hidden_channels")
    my_model = ModelBuilderTrainerTester(
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        hidden_channels=hidden_channels,
        layer_name=layer_name,
    )
    my_model.train_test(epochs=1000)
    print("\n---------------------------------------------------------")

SAGE | 64 hidden_channels
Epoch: 010, Loss: 6.7702, Train: 1.0875, Val: 1.1037, Test: 1.0978
Epoch: 020, Loss: 6.9893, Train: 1.1583, Val: 1.1814, Test: 1.1729
Epoch: 030, Loss: 6.0863, Train: 1.3461, Val: 1.3711, Test: 1.3619
Epoch: 040, Loss: 5.8485, Train: 1.4209, Val: 1.4449, Test: 1.4367
Epoch: 050, Loss: 5.7502, Train: 1.4013, Val: 1.4246, Test: 1.4184
Epoch: 060, Loss: 5.5393, Train: 1.2708, Val: 1.2925, Test: 1.2896
Epoch: 070, Loss: 5.2944, Train: 1.3040, Val: 1.3250, Test: 1.3252
Epoch: 080, Loss: 5.0280, Train: 1.2662, Val: 1.2830, Test: 1.2886
Epoch: 090, Loss: 4.6796, Train: 1.2882, Val: 1.3028, Test: 1.3148
Epoch: 100, Loss: 4.4098, Train: 1.1301, Val: 1.1498, Test: 1.1591
Epoch: 110, Loss: 4.0095, Train: 1.2827, Val: 1.3205, Test: 1.3266
Epoch: 120, Loss: 3.8531, Train: 1.2115, Val: 1.2597, Test: 1.2648
Epoch: 130, Loss: 3.6745, Train: 1.2019, Val: 1.2621, Test: 1.2672
Epoch: 140, Loss: 3.4091, Train: 1.1702, Val: 1.2434, Test: 1.2432
Epoch: 150, Loss: 4.4678, Train: 1.3

In [29]:
# Single run with TransformerConv layers. The configuration is defined once
# and passed by name so the printed header always matches what is trained.
layer_name = "Transformer"
in_channels = -1
hidden_channels = 64
print(layer_name, "|", hidden_channels, "hidden_channels")
my_model = ModelBuilderTrainerTester(
    layer_name=layer_name,
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    in_channels=in_channels,
    hidden_channels=hidden_channels,
)
my_model.train_test(epochs=1000)

Transformer | 64 hidden_channels
Epoch: 010, Loss: 15.9526, Train: 3.1762, Val: 3.1707, Test: 3.1727
Epoch: 020, Loss: 6.6684, Train: 1.1358, Val: 1.1307, Test: 1.1330
Epoch: 030, Loss: 6.1654, Train: 1.3282, Val: 1.3329, Test: 1.3337
Epoch: 040, Loss: 6.2509, Train: 1.5258, Val: 1.5287, Test: 1.5298
Epoch: 050, Loss: 6.1423, Train: 1.3833, Val: 1.3861, Test: 1.3873
Epoch: 060, Loss: 6.1302, Train: 1.3290, Val: 1.3322, Test: 1.3333
Epoch: 070, Loss: 6.0765, Train: 1.3520, Val: 1.3562, Test: 1.3570
Epoch: 080, Loss: 5.9007, Train: 1.2967, Val: 1.3042, Test: 1.3038
Epoch: 090, Loss: 5.3109, Train: 1.1792, Val: 1.1922, Test: 1.1790
Epoch: 100, Loss: 4.9238, Train: 1.3878, Val: 1.3940, Test: 1.3828
Epoch: 110, Loss: 4.6709, Train: 1.2180, Val: 1.2282, Test: 1.2132
Epoch: 120, Loss: 4.3941, Train: 1.2368, Val: 1.2458, Test: 1.2340
Epoch: 130, Loss: 4.2831, Train: 1.3016, Val: 1.3078, Test: 1.2959
Epoch: 140, Loss: 3.9700, Train: 1.2541, Val: 1.2640, Test: 1.2522
Epoch: 150, Loss: 3.8447, Tr