# Demonstrating the Combination of Neo4j & PyTorch (PyTorch Geometric)

In [None]:
from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from torch import Tensor
import torch
import numpy as np
import sys
import networkx as nx
sys.path.append('../scripts/')
from utils_neo4j import load_node, conn, get_db_query_result_as_networkx

In [None]:
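# Important configuration (both values are set further below in this notebook):
# - record_table: key of the node type that receives the label
# - record_target_column: name of the node property read as feature/label value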

In [None]:
# get all types of nodes inside the graph database
def get_node_types():
    """Fetch every distinct label list that occurs on nodes in the database.

    Returns:
        The result of ``conn.query`` for the Cypher statement; each record
        exposes the label list of a node under the ``node_type`` key.
    """
    return conn.query("MATCH (n) RETURN DISTINCT labels(n) as node_type")

# Print the first label of every distinct label combination found in the database.
print("🌎 Node types:")
for record in get_node_types():
    labels = record["node_type"]
    # Nodes may carry no label at all; avoid an IndexError on the empty list.
    print("-", labels[0] if labels else "(unlabeled)")
print()

def get_relationship_types():
    """Fetch every distinct relationship type that occurs in the database.

    Returns:
        The result of ``conn.query`` for the Cypher statement; each record
        exposes the type name under the ``relationship_type`` key.
    """
    return conn.query("MATCH ()-[r]->() RETURN DISTINCT type(r) as relationship_type")

# Print every distinct relationship type found in the database.
# The heading previously read "Node types", copied from the cell part above.
print("🌐 Relationship types:")
for record in get_relationship_types():
    print("-", record["relationship_type"])
print()

In [None]:
# Load the database query result as a graph, using the helper's default arguments
# (helper imported from utils_neo4j; the result is handed to networkx functions below).
G = get_db_query_result_as_networkx()

In [None]:
# Show the graph: draw it with networkx and annotate every node with its identifier.
nx.draw(G, with_labels=True)

In [None]:
# Relabel all nodes inside the networkx graph to monotonically increasing integers starting at 0.
# G is rebound to the relabeled graph; ordering="sorted" hands out the new integers
# following the sort order of the original node identifiers (per the networkx docs).
G = nx.convert_node_labels_to_integers(G, first_label=0, ordering="sorted")

In [None]:
# Inspect the node view of G (a notebook cell renders its last expression).
G.nodes

In [None]:
# Key of the node type that get_single_graph attaches the label to.
record_table = "entity"
# Node property that get_single_graph reads as feature and label value.
record_target_column = "feature_1"

# Run on the GPU whenever CUDA is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def get_single_graph(idx=0):
    """Build one ``HeteroData`` sample for index ``idx`` and move it to ``device``.

    The graph is queried anew on every call, copied into a ``HeteroData``
    object, passed through ``ToUndirected`` and transferred to the
    module-level ``device``.

    Args:
        idx: Index used to pick the entry that is read from the node data
            (default ``0``).

    Returns:
        The populated ``HeteroData`` object.

    NOTE(review): several statements below look unfinished; see the inline
    notes before relying on this function.
    """
    data_ = HeteroData()

    # Fresh query per call; this local G shadows the notebook-level G.
    G = get_db_query_result_as_networkx()
    # NOTE(review): iterating G.nodes yields node identifiers, yet the loop
    # variable is used as a node-type key of data_ — confirm the intended layout.
    for nodetype in G.nodes:
        # Wraps a single value in a one-element array before converting it to a tensor.
        data_[nodetype].x = torch.from_numpy(np.array([G.nodes[nodetype][idx][record_target_column]]))

    # The label is taken from the same property as the features above.
    # NOTE(review): G.nodes is indexed with record_table here, i.e. with a
    # node-type name rather than a node identifier — TODO confirm.
    data_[record_table].label = torch.tensor([G.nodes[record_table][idx][record_target_column]])

    # NOTE(review): three-way unpacking only works if G.edges yields triples
    # (presumably a multigraph with edge keys); `e` is never used.
    for n1, n2, e in G.edges:
        # The edge-type key is (first label of source, relationship type, first label of target).
        # NOTE(review): `edge_index_` is not defined in this function nor in the
        # visible part of the notebook — this raises NameError unless it is
        # defined elsewhere. The same value would be assigned for every edge.
        data_[list(G.nodes[n1]["labels"])[0],
              G.get_edge_data(n1, n2)["type"],
              list(G.nodes[n2]["labels"])[0]].edge_index = edge_index_
    data_ = ToUndirected()(data_)
    # The return value of .to() is discarded — presumably relies on the
    # transfer happening in place; TODO confirm.
    data_.to(device, non_blocking=True)
    return data_

In [None]:
class MyOwnDataset(Dataset):
    """On-the-fly dataset whose items are built by ``get_single_graph``.

    Nothing is downloaded or pre-processed: ``download`` and ``process`` are
    no-ops and every sample is created when it is requested.

    The sample indices of a split come from the notebook-level objects
    ``train``, ``validate`` and ``test`` via ``.index.tolist()``.
    NOTE(review): those objects are not defined in the visible part of the
    notebook — presumably pandas objects; confirm.

    Args:
        root: Root directory forwarded to the ``Dataset`` base class.
        transform: Forwarded unchanged to the base class.
        pre_transform: Forwarded unchanged to the base class.
        pre_filter: Forwarded unchanged to the base class.
        mode: Split to serve: ``"train"``, ``"val"`` or ``"test"``.

    Raises:
        ValueError: If ``mode`` is not one of the supported splits.
    """

    _MODES = ("train", "val", "test")

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, mode="train"):
        # Fail early: an unknown mode used to leave ``self.inds`` unset, which
        # only surfaced later as an AttributeError in len() / get().
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        super().__init__(root, transform, pre_transform, pre_filter)
        self.mode = mode
        # Only the requested split is looked up, so the other split objects
        # need not exist.
        if mode == "train":
            self.inds = train.index.tolist()
        elif mode == "val":
            self.inds = validate.index.tolist()
        else:  # mode == "test"
            self.inds = test.index.tolist()

    @property
    def raw_file_names(self):
        """Placeholder list of raw file names."""
        return [
            "some_file_1",
        ]

    @property
    def processed_file_names(self):
        """Placeholder list of processed file names."""
        return [
            "data_1.pt",
        ]

    def download(self):
        """No-op: there is nothing to download."""
        pass

    def process(self):
        """No-op: samples are built lazily in ``get``."""
        pass

    def len(self):
        """Return the number of samples in the selected split."""
        return len(self.inds)

    def get(self, idx):
        """Build and return the graph for the ``idx``-th index of the split."""
        data = get_single_graph(self.inds[idx])
        return data

In [None]:
# One dataset per split, created in the order train, val, test; all use an empty root.
dset_train, dset_val, dset_test = (
    MyOwnDataset(root="", mode=split) for split in ("train", "val", "test")
)

In [None]:
# Build the first training sample by calling get() directly, then add self-loops to it.
# NOTE(review): `g` is not used again in the visible part of the notebook — presumably a smoke test.
g = dset_train.get(0)
g = T.AddSelfLoops()(g)

In [None]:
from torch_geometric.nn import BatchNorm, GraphConv
from torch_geometric.nn import MultiAggregation
from torch_geometric.nn import Linear
import torch.nn.functional as F
from torch_geometric.data import Batch
from torch_geometric.nn import to_hetero


# Hyperparameters for the graph-level model defined below.
batch_size = 32  # NOTE(review): not referenced in the visible part of the notebook
hidden_size = 27  # hidden width; the final Linear layer takes hidden_size * 3 inputs
num_classes = 2  # number of outputs of the final Linear layer
learn_rate = 0.01  # learning rate handed to the Adam optimizer
aggr = "max"  # aggregation argument handed to to_hetero

class GraphLevelGNNRes(torch.nn.Module):
    """Three-layer residual ``GraphConv`` network with a graph-level readout.

    Every layer applies GraphConv -> BatchNorm -> leaky ReLU, adds the layer
    input back (residual connection) and applies dropout with ``p=0.1``.
    The node embeddings are then pooled per graph with mean/min/max
    aggregations (concatenated) and mapped to the outputs by a linear layer.

    All layer widths are derived from one value. Previously the
    convolutions and the batch norm hard-coded 27 while the linear layer
    used ``hidden_size * 3``, so changing ``hidden_size`` broke the model.

    Args:
        hidden_channels: Width of the hidden embeddings. ``None`` (default)
            uses the module-level ``hidden_size``.
        out_channels: Number of outputs. ``None`` (default) uses the
            module-level ``num_classes``.
    """

    def __init__(self, hidden_channels=None, out_channels=None):
        super().__init__()
        # Resolve the defaults at call time so the module-level settings are honoured.
        if hidden_channels is None:
            hidden_channels = hidden_size
        if out_channels is None:
            out_channels = num_classes

        # in_channels=-1: the input width is inferred on the first forward pass.
        self.conv1 = GraphConv(-1, hidden_channels)
        self.conv2 = GraphConv(-1, hidden_channels)
        self.conv3 = GraphConv(-1, hidden_channels)
        # NOTE(review): this single BatchNorm is shared by all three layers
        # (same parameters and running statistics) — confirm this is intended.
        self.batchnorm1 = BatchNorm(hidden_channels)
        # Concatenating three aggregations triples the width fed to the Linear layer.
        self.pool = MultiAggregation(aggrs=["mean", "min", "max"], mode="cat")
        self.lin = Linear(hidden_channels * 3, out_channels)

    def forward(self, x: Tensor, edge_index: Tensor, batch: Tensor) -> Tensor:
        """Compute one output row per graph in the batch.

        Args:
            x: Node features.
            edge_index: Graph connectivity.
            batch: Assignment of every node to its graph, used for pooling.

        Returns:
            The output of the final linear layer applied to the pooled embeddings.
        """
        h1 = self.conv1(x, edge_index)
        h1 = self.batchnorm1(h1)
        h1 = F.leaky_relu(h1)
        # NOTE(review): the residual needs x to be broadcastable to the hidden
        # width (e.g. a single feature column) — confirm the input shape.
        h1 = x + h1
        h1 = F.dropout(h1, p=0.1, training=self.training)

        h2 = self.conv2(h1, edge_index)
        h2 = self.batchnorm1(h2)
        h2 = F.leaky_relu(h2)
        h2 = h1 + h2
        h2 = F.dropout(h2, p=0.1, training=self.training)

        h3 = self.conv3(h2, edge_index)
        h3 = self.batchnorm1(h3)
        h3 = F.leaky_relu(h3)
        h3 = h2 + h3
        h3 = F.dropout(h3, p=0.1, training=self.training)

        # Graph-level readout followed by the output layer.
        h3 = self.pool(h3, batch)
        h3 = self.lin(h3)
        return h3


# Build two sample graphs and merge them into one mini-batch.
g1 = get_single_graph(1)
g2 = get_single_graph(2)

btch = Batch.from_data_list([g1, g2])

# Node and edge types of the sample; to_hetero needs them to convert the model.
metadata = g1.metadata()

model = GraphLevelGNNRes()
# get_single_graph moves every sample to `device`, so the model has to live
# there as well; otherwise the forward pass fails on CUDA machines.
model = to_hetero(model, metadata, aggr=aggr, debug=False).to(device)

# Dry run: the first forward pass resolves the lazily sized (-1) input widths.
out = model(btch.x_dict, btch.edge_index_dict, btch.batch_dict)

# Create the optimizer only after the lazy parameters have been materialized.
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)