Import libraries

In [324]:
import os
import re
# When the notebook is started from a sub-folder of the repository (i.e. the
# working directory continues past "protein-reconstruction"), step up one
# directory. Presumably this lets the project packages imported in the next
# cell (models, preprocessing) resolve from the repository root.
cwd = os.getcwd()
launched_from_subfolder = re.search("protein-reconstruction.+", cwd) is not None
if launched_from_subfolder:
    os.chdir("..")
print(os.getcwd())

C:\Users\attor\Desktop\Python\protein-reconstruction


In [325]:
import random
from typing import Final, final

import networkx as nx
import numpy as np
import plotly.graph_objects as go
import torch
import torch_geometric as tg
import torch_geometric.utils.convert as tgc
import torchinfo
from torch_geometric.loader import DataLoader
from torch_geometric.utils import negative_sampling

from models.classification.classifiers import ProtMotionNet
from models.layers import GATConvBlock, SAGEConvBlock, GCN2ConvBlock, GCNConvBlock
from models.pretraining.encoders import SimpleGCNEncoder, ResGCN2ConvEncoder, RevSAGEConvEncoder, RevGATConvEncoder, \
    ResGCN2ConvEncoderV2, RevGCNEncoder
from models.pretraining.gae import GAEv2
from models.pretraining.vgae import VGAEv2, VGEncoder
from preprocessing.constants import PRETRAIN_CLEANED_TRAIN, PRETRAIN_CLEANED_VAL
from preprocessing.dataset import load_dataset


Define the toy graph and print its nodes and edges

In [326]:
# Constants used when plotting the graph: the node attribute that stores 2-D
# plot coordinates, and the bounding box those coordinates are drawn from.
# `Final` (not the `@final` decorator) is the typing construct for constants.
_POSITION_ATTRIBUTE: Final = "pos"
_X_MIN: Final = 0
_X_MAX: Final = 2
_Y_MIN: Final = 0
_Y_MAX: Final = 2

# Toy undirected graph: 5 nodes with 6-dimensional features stored under "x",
# and weighted edges stored under "edge_weight". from_networkx later turns
# these into Data.x and Data.edge_weight.
g = nx.Graph()
g.add_node(0, x=[1., 0., 1.2, 1.1, 0.2, 0.1])
g.add_node(1, x=[0., 1., 0., 1.2, 1.1, 0.2])
g.add_node(2, x=[0.4, 1., 0.1, 0.2, 0.7, 0.3])
g.add_node(3, x=[1., 1.2, 0.9, 0.9, 0.5, 0.4])
g.add_node(4, x=[1., 1.3, 0.4, 0.3, 1.8, 0.45])
g.add_edge(1, 0, edge_weight=1.)
g.add_edge(1, 2, edge_weight=2.)
g.add_edge(2, 0, edge_weight=1.)
g.add_edge(3, 2, edge_weight=1.)
g.add_edge(4, 2, edge_weight=1.)

print(g.nodes(data=True))
print(g.edges(data=True))

[(0, {'x': [1.0, 0.0, 1.2, 1.1, 0.2, 0.1]}), (1, {'x': [0.0, 1.0, 0, 1.2, 1.1, 0.2]}), (2, {'x': [0.4, 1.0, 0.1, 0.2, 0.7, 0.3]}), (3, {'x': [1.0, 1.2, 0.9, 0.9, 0.5, 0.4]}), (4, {'x': [1.0, 1.3, 0.4, 0.3, 1.8, 0.45]})]
[(0, 1, {'edge_weight': 1.0}), (0, 2, {'edge_weight': 1.0}), (1, 2, {'edge_weight': 2.0}), (2, 3, {'edge_weight': 1.0}), (2, 4, {'edge_weight': 1.0})]
ciao


Plot graph

In [327]:
# Give every node a 2-D plotting position unless the graph already has one.
# _X_MIN/_X_MAX/_Y_MIN/_Y_MAX are bounds, so coordinates are drawn uniformly
# inside that box. (Passing them to random.gauss treated them as mean/std.)
# A dedicated seeded RNG keeps the layout reproducible and leaves the global
# `random` state untouched.
_LAYOUT_SEED = 42
if not nx.get_node_attributes(g, _POSITION_ATTRIBUTE):
    layout_rng = random.Random(_LAYOUT_SEED)
    pos = {
        node: (layout_rng.uniform(_X_MIN, _X_MAX), layout_rng.uniform(_Y_MIN, _Y_MAX))
        for node in g.nodes
    }
    nx.set_node_attributes(g, pos, _POSITION_ATTRIBUTE)

# Edge segments: one (start, end, None) triple per edge. The None breaks the
# line so each edge is drawn as its own segment.
edge_x = []
edge_y = []
for u, v in g.edges():
    x0, y0 = g.nodes[u][_POSITION_ATTRIBUTE]
    x1, y1 = g.nodes[v][_POSITION_ATTRIBUTE]
    edge_x.extend((x0, x1, None))
    edge_y.extend((y0, y1, None))

edge_trace = go.Scatter(
    x=edge_x, y=edge_y,
    line=dict(width=0.5, color='#888'),
    hoverinfo='none',
    mode='lines')

node_x = [g.nodes[node][_POSITION_ATTRIBUTE][0] for node in g.nodes()]
node_y = [g.nodes[node][_POSITION_ATTRIBUTE][1] for node in g.nodes()]

# Marker colour and hover text encode each node's number of neighbours.
# g.adjacency() yields the real node label, instead of assuming labels
# coincide with the enumeration index.
node_adjacencies = []
node_text = []
for node, neighbours in g.adjacency():
    node_adjacencies.append(len(neighbours))
    node_text.append(f'node {node}, # of connections: {len(neighbours)}')

node_trace = go.Scatter(
    x=node_x, y=node_y,
    mode='markers',
    hoverinfo='text',
    text=node_text,
    marker=dict(
        showscale=True,
        # colorscale options
        #'Greys' | 'YlGnBu' | 'Greens' | 'YlOrRd' | 'Bluered' | 'RdBu' |
        #'Reds' | 'Blues' | 'Picnic' | 'Rainbow' | 'Portland' | 'Jet' |
        #'Hot' | 'Blackbody' | 'Earth' | 'Electric' | 'Viridis' |
        colorscale='YlGnBu',
        reversescale=True,
        color=node_adjacencies,
        size=10,
        colorbar=dict(
            thickness=15,
            # `titleside` was deprecated (removed in Plotly 6); `title.side` replaces it.
            title=dict(text='Node Connections', side='right'),
            xanchor='left'
        ),
        line_width=2))

# noinspection PyTypeChecker
fig = go.Figure(data=[edge_trace, node_trace],
                layout=go.Layout(
                    # `titlefont_size` was deprecated (removed in Plotly 6); use `title.font.size`.
                    title=dict(text='<br>Network graph made with Python', font=dict(size=16)),
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=20, l=5, r=5, t=40),
                    annotations=[dict(
                        text="Python code: <a href='https://plotly.com/ipython-notebooks/network-graphs/'> https://plotly.com/ipython-notebooks/network-graphs/</a>",
                        showarrow=False,
                        xref="paper", yref="paper",
                        x=0.005, y=-0.002)],
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                )
fig.show()

Convert the graph to PyTorch Geometric

In [328]:
# from_networkx turns the node attributes ("x", "pos") and the "edge_weight"
# edge attribute into tensors. Each undirected edge appears once per direction,
# so 5 edges become 10 entries.
pyg = tgc.from_networkx(g)
for shown in (pyg, pyg.edge_weight):
    print(shown)

Data(x=[5, 6], edge_index=[2, 10], pos=[5, 2], edge_weight=[10])
tensor([1., 1., 1., 2., 1., 2., 1., 1., 1., 1.])


Instantiate layers and test their serialization

In [329]:
# Two configurations of each convolution block. All blocks take the toy
# graph's 6 input features; the GCN2 blocks keep the width at 6.
gat0 = GATConvBlock(6, 3, heads=2, edge_dim=1)
gat1 = GATConvBlock(6, 3, heads=1, edge_dim=1, dropout=0.3, concat=True)
sage0 = SAGEConvBlock(6, 3, project=True)
sage1 = SAGEConvBlock(6, 3, project=False)
gcn0 = GCNConvBlock(6, 3, normalize=False)
gcn1 = GCNConvBlock(6, 3, normalize=True)
gcn20 = GCN2ConvBlock(6, 0.6)
gcn21 = GCN2ConvBlock(6, 0.5)

for block in (gat0, gat1, sage0, sage1, gcn0, gcn1, gcn20, gcn21):
    print(block)

# Keyword arguments each block family is called with:
# - GAT gets edge weights as edge_attr.
# - SAGE is called without edge weights.
# - GCN and GCN2 get them as edge_weight; GCN2 also needs the initial features x0.
# gat1 was built with dropout=0.3, so its output varies between runs.
gat_kwargs = dict(edge_index=pyg.edge_index, edge_attr=pyg.edge_weight)
sage_kwargs = dict(edge_index=pyg.edge_index)
gcn_kwargs = dict(edge_index=pyg.edge_index, edge_weight=pyg.edge_weight)
gcn2_kwargs = dict(x0=pyg.x, edge_index=pyg.edge_index, edge_weight=pyg.edge_weight)

for block, block_kwargs in (
    (gat0, gat_kwargs), (gat1, gat_kwargs),
    (sage0, sage_kwargs), (sage1, sage_kwargs),
    (gcn0, gcn_kwargs), (gcn1, gcn_kwargs),
    (gcn20, gcn2_kwargs), (gcn21, gcn2_kwargs),
):
    print(block(pyg.x, **block_kwargs))

GATConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): GATv2Conv(6, 3, heads=2)
)
GATConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): GATv2Conv(6, 3, heads=1)
)
SAGEConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): SAGEConv(6, 3, aggr=mean)
)
SAGEConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): SAGEConv(6, 3, aggr=mean)
)
GCNConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): GCNConv(6, 3)
)
GCNConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): GCNConv(6, 3)
)
GCN2ConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): GCN2Conv(6, alpha=0.6, beta=0.6931471805599453)
)
GCN2ConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): GCN2Conv(6, alpha=0.5, beta=0.6931471805599453)
)
tensor([[ 0.3742,  0.1815,  0.1761],
        [ 0.3373,  0.1259,  0

In [330]:
# Serialization round-trip for one block of each type:
# 1. rebuild it from its serialized constructor params,
# 2. copy the weights,
# 3. require the rebuilt block to reproduce the original's output on the toy graph.
# Each entry: (original block, its class, keyword arguments for the forward pass).
layer_roundtrip_cases = [
    (gat0, GATConvBlock, dict(edge_index=pyg.edge_index, edge_attr=pyg.edge_weight)),
    (sage0, SAGEConvBlock, dict(edge_index=pyg.edge_index)),
    (gcn0, GCNConvBlock, dict(edge_index=pyg.edge_index, edge_weight=pyg.edge_weight)),
    (gcn20, GCN2ConvBlock, dict(x0=pyg.x, edge_index=pyg.edge_index, edge_weight=pyg.edge_weight)),
]

restored_layers = []
for original, block_cls, _ in layer_roundtrip_cases:
    state_dict = original.state_dict()
    restored = block_cls.from_constructor_params(original.serialize_constructor_params())
    # load_state_dict is strict by default: missing or unexpected keys raise here.
    restored.load_state_dict(state_dict)
    restored_layers.append(restored)

# Keep the per-block names used by the original version of this cell.
gat01, sage01, gcn01, gcn201 = restored_layers

for (original, _, _), restored in zip(layer_roundtrip_cases, restored_layers):
    print(original)
    print(restored)

for (original, _, forward_kwargs), restored in zip(layer_roundtrip_cases, restored_layers):
    out_original = original(pyg.x, **forward_kwargs)
    out_restored = restored(pyg.x, **forward_kwargs)
    print(out_original)
    print(out_restored)
    # Identical weights and inputs must give identical outputs; printing alone
    # would let a broken round-trip go unnoticed.
    assert torch.allclose(out_original, out_restored), (
        f"{type(original).__name__}: deserialized block output differs from the original"
    )

GATConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): GATv2Conv(6, 3, heads=2)
)
GATConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): GATv2Conv(6, 3, heads=2)
)
SAGEConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): SAGEConv(6, 3, aggr=mean)
)
SAGEConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): SAGEConv(6, 3, aggr=mean)
)
GCNConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): GCNConv(6, 3)
)
GCNConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): GCNConv(6, 3)
)
GCN2ConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): GCN2Conv(6, alpha=0.6, beta=0.6931471805599453)
)
GCN2ConvBlock(
  (norm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
  (conv): GCN2Conv(6, alpha=0.6, beta=0.6931471805599453)
)
tensor([[ 0.3742,  0.1815,  0.1761],
        [ 0.3373,  0.1259,  0

Instantiate encoders

In [331]:
# Build one instance of each encoder architecture for the toy graph's 6-dim
# node features. Then print its structure, a torchinfo summary and the output
# of one forward pass.
gat_enc = RevGATConvEncoder(
    in_channels=6,
    hidden_channels=4,
    out_channels=3,
    num_convs=3,
    dropout=0.0,
    version="v2",
    edge_dim=1,
    heads=8,
    num_groups=2,
    concat=False,
    normalize_hidden=True
)

sage_enc = RevSAGEConvEncoder(
    in_channels=6,
    hidden_channels=4,
    out_channels=3,
    num_convs=4,
    dropout=0.0,
    project=True,
    root_weight=True,
    num_groups=2,
    aggr='mean',
    normalize_hidden=True
)

gcn_enc = SimpleGCNEncoder(
    in_channels=6,
    hidden_channels=5,
    out_channels=3,
    conv_dims=[5, 5, 4, 4],
    dropout=0.0,
    improved=True
)

gcn2_enc = ResGCN2ConvEncoder(
    in_channels=6,
    hidden_channels=5,
    out_channels=3,
    alpha=0.3,
    num_convs=4,
    dropout=0.0
)

gcn22_enc = ResGCN2ConvEncoderV2(
    in_channels=6,
    hidden_channels=5,
    out_channels=3,
    alpha=0.3,
    num_convs=4,
    dropout=0.0
)

gcn_rev_enc = RevGCNEncoder(
    in_channels=6,
    hidden_channels=4,
    out_channels=3,
    num_convs=4,
    dropout=0.0,
    improved=True
)

# Each entry: (title, encoder, positional forward arguments). Every section
# runs its own encoder; the RevGCN section previously ran gcn_enc by mistake.
encoder_cases = [
    ("Reversible residual GAT Encoder", gat_enc, (pyg.x, pyg.edge_index, pyg.edge_weight)),
    ("Reversible residual SAGE Encoder", sage_enc, (pyg.x, pyg.edge_index)),
    ("Simple GCN Encoder", gcn_enc, (pyg.x, pyg.edge_index, pyg.edge_weight)),
    ("ResGCN2 Encoder", gcn2_enc, (pyg.x, pyg.edge_index)),
    ("ResGCN2 EncoderV2", gcn22_enc, (pyg.x, pyg.edge_index)),
    ("RevGCN Encoder", gcn_rev_enc, (pyg.x, pyg.edge_index, pyg.edge_weight)),
]

for title, encoder, forward_args in encoder_cases:
    print(title)
    print(encoder)
    print(torchinfo.summary(encoder))
    print(encoder(*forward_args))
    print("--------------------------------------------------------------------\n\n\n")

Reversible residual GAT Encoder
RevGATConvEncoder(
  (lin1): Linear(in_features=6, out_features=4, bias=True)
  (lin2): Linear(in_features=4, out_features=3, bias=True)
  (norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
  (convs): ModuleList(
    (0): GroupAddRev(GATConvBlock(
      (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
      (conv): GATv2Conv(2, 2, heads=8)
    ), num_groups=2)
    (1): GroupAddRev(GATConvBlock(
      (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
      (conv): GATv2Conv(2, 2, heads=8)
    ), num_groups=2)
    (2): GroupAddRev(GATConvBlock(
      (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
      (conv): GATv2Conv(2, 2, heads=8)
    ), num_groups=2)
  )
)
Layer (type:depth-idx)                                  Param #
RevGATConvEncoder                                       --
├─Linear: 1-1                                           28
├─Linear: 1-2                                           15
├─LayerNorm: 1-3

Test encoder serialization

In [332]:
# Serialization round-trip for every encoder:
# 1. rebuild it from its serialized constructor params,
# 2. copy the weights,
# 3. require the rebuilt encoder to produce the same embeddings as the original
#    on the toy graph.
# Each entry: (title, encoder, encoder class, positional forward arguments).
encoder_roundtrip_cases = [
    ("Reversible residual GAT Encoder", gat_enc, RevGATConvEncoder, (pyg.x, pyg.edge_index, pyg.edge_weight)),
    ("Reversible residual SAGE Encoder", sage_enc, RevSAGEConvEncoder, (pyg.x, pyg.edge_index)),
    ("Simple GCN Encoder", gcn_enc, SimpleGCNEncoder, (pyg.x, pyg.edge_index, pyg.edge_weight)),
    ("Residual GCN2 Encoder", gcn2_enc, ResGCN2ConvEncoder, (pyg.x, pyg.edge_index, pyg.edge_weight)),
    ("Residual GCN2 EncoderV2", gcn22_enc, ResGCN2ConvEncoderV2, (pyg.x, pyg.edge_index, pyg.edge_weight)),
    ("RevGCN Encoder", gcn_rev_enc, RevGCNEncoder, (pyg.x, pyg.edge_index, pyg.edge_weight)),
]

restored_encoders = []
for title, encoder, encoder_cls, forward_args in encoder_roundtrip_cases:
    print(title)
    constr_params = encoder.serialize_constructor_params()
    state_dict = encoder.state_dict()
    print(constr_params)
    restored = encoder_cls.from_constructor_params(constr_params)
    # load_state_dict is strict by default: missing or unexpected keys raise here.
    restored.load_state_dict(state_dict)
    print(restored)
    print(torchinfo.summary(restored))

    out_original = encoder(*forward_args)
    out_restored = restored(*forward_args)
    print("\n\nOriginal: ")
    print(out_original)
    print("\n\nDeserialized: ")
    print(out_restored)
    # The encoders were built with dropout=0.0, so identical weights must yield
    # identical embeddings. Printing alone would hide a broken round-trip.
    assert torch.allclose(out_original, out_restored), f"{title}: deserialized output differs from the original"

    restored_encoders.append(restored)
    print("--------------------------------------------------------------------\n\n\n")

# Keep the per-encoder names used by the original version of this cell.
gat_enc2, sage_enc2, gcn_enc2, gcn2_enc2, gcn22_enc2, gcn_rev_enc2 = restored_encoders

Reversible residual GAT Encoder
{'in_channels': 6, 'hidden_channels': 4, 'out_channels': 3, 'num_convs': 3, 'dropout': 0.0, 'version': 'v2', 'edge_dim': 1, 'heads': 8, 'concat': False, 'num_groups': 2, 'normalize_hidden': True}
RevGATConvEncoder(
  (lin1): Linear(in_features=6, out_features=4, bias=True)
  (lin2): Linear(in_features=4, out_features=3, bias=True)
  (norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
  (convs): ModuleList(
    (0): GroupAddRev(GATConvBlock(
      (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
      (conv): GATv2Conv(2, 2, heads=8)
    ), num_groups=2)
    (1): GroupAddRev(GATConvBlock(
      (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
      (conv): GATv2Conv(2, 2, heads=8)
    ), num_groups=2)
    (2): GroupAddRev(GATConvBlock(
      (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
      (conv): GATv2Conv(2, 2, heads=8)
    ), num_groups=2)
  )
)
Layer (type:depth-idx)                                  Param

Instantiate and test GAEv2

In [333]:
# Smoke-test GAEv2 on top of each encoder. Each case exercises:
# - forward() and forward_all(),
# - encode() and decode(),
# - the reconstruction loss,
# - AUC/AP against freshly sampled negative edges.
# Each entry: (title, encoder, edge keyword arguments that encoder expects).
gae_cases = [
    ("Reversible residual GAT GAE", gat_enc, {"edge_attr": pyg.edge_weight}),
    ("Reversible residual SAGE GAE", sage_enc, {}),
    ("Simple GCN GAE", gcn_enc, {"edge_weight": pyg.edge_weight}),
    ("Residual GCN2 GAE", gcn2_enc, {"edge_weight": pyg.edge_weight}),
    ("Rev GCN GAE", gcn_rev_enc, {"edge_weight": pyg.edge_weight}),
]

for title, gae_encoder, edge_kwargs in gae_cases:
    print(title)
    gae = GAEv2(encoder=gae_encoder)
    print(gae)
    print(torchinfo.summary(gae))
    print("Reconstruction forward()")
    print(gae(x=pyg.x, edge_index=pyg.edge_index, **edge_kwargs))
    print("Reconstruction forward_all()")
    print(gae.forward_all(x=pyg.x, edge_index=pyg.edge_index, **edge_kwargs))
    print("Latent space encoding")
    z = gae.encode(x=pyg.x, edge_index=pyg.edge_index, **edge_kwargs)
    print(z)
    print("Reconstruction decode()")
    print(gae.decode(z, pyg.edge_index))
    print("Reconstruction loss")
    print(gae.recon_loss(z, pyg.edge_index))
    print("AUC and precision metric test")
    neg_edge_index = negative_sampling(pyg.edge_index, z.size(0))
    print(gae.test(z, pyg.edge_index, neg_edge_index=neg_edge_index))
    print("--------------------------------------------------------------------\n\n\n")

Reversible residual GAT GAE
GAEv2(
  (encoder): RevGATConvEncoder(
    (lin1): Linear(in_features=6, out_features=4, bias=True)
    (lin2): Linear(in_features=4, out_features=3, bias=True)
    (norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (convs): ModuleList(
      (0): GroupAddRev(GATConvBlock(
        (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
        (conv): GATv2Conv(2, 2, heads=8)
      ), num_groups=2)
      (1): GroupAddRev(GATConvBlock(
        (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
        (conv): GATv2Conv(2, 2, heads=8)
      ), num_groups=2)
      (2): GroupAddRev(GATConvBlock(
        (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
        (conv): GATv2Conv(2, 2, heads=8)
      ), num_groups=2)
    )
  )
  (decoder): InnerProductDecoder()
)
Layer (type:depth-idx)                                       Param #
GAEv2                                                        --
├─RevGATConvEncoder: 1-1           

Test GAEv2 serialization

In [334]:
def _reseed_samplers(seed: int = 0) -> None:
    """Reset the Python, NumPy and torch RNGs so repeated negative sampling draws the same edges.

    All three are seeded because the negative-sampling implementation in the
    installed PyG version may draw from any of them.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# Serialization round-trip for GAEv2 with each encoder. The original and
# deserialized models must agree, so every comparison feeds both the same
# inputs:
# - each model decodes its own encoding,
# - both recon_loss calls run from the same RNG state (recon_loss may sample
#   negatives internally),
# - both test() calls share one negative-edge sample.
# Each entry: (title, encoder, encoder class, edge keyword arguments).
gae_roundtrip_cases = [
    ("Reversible residual GAT GAE", gat_enc, RevGATConvEncoder, {"edge_attr": pyg.edge_weight}),
    ("Reversible residual SAGE GAE", sage_enc, RevSAGEConvEncoder, {}),
    ("Simple GCN GAE", gcn_enc, SimpleGCNEncoder, {"edge_weight": pyg.edge_weight}),
    ("Residual GCN2 GAE", gcn2_enc, ResGCN2ConvEncoder, {"edge_weight": pyg.edge_weight}),
    ("Rev GCN GAE", gcn_rev_enc, RevGCNEncoder, {"edge_weight": pyg.edge_weight}),
]

for title, gae_encoder, encoder_cls, edge_kwargs in gae_roundtrip_cases:
    gae_inputs = dict(x=pyg.x, edge_index=pyg.edge_index, **edge_kwargs)

    print(title)
    gae = GAEv2(encoder=gae_encoder)
    constr_params = gae.serialize_constructor_params()
    state_dict = gae.state_dict()
    print("Constructor params: ")
    print(constr_params)
    gae2 = GAEv2.from_constructor_params(constr_params, encoder_cls)
    gae2.load_state_dict(state_dict)
    print(gae2)
    print(torchinfo.summary(gae2))

    print("Reconstruction forward() original")
    print(gae(**gae_inputs))
    print("Reconstruction forward() deserialized")
    print(gae2(**gae_inputs))
    print("Reconstruction forward_all() original")
    print(gae.forward_all(**gae_inputs))
    print("Reconstruction forward_all() deserialized")
    print(gae2.forward_all(**gae_inputs))

    print("Latent space encoding original")
    z = gae.encode(**gae_inputs)
    print(z)
    print("Latent space encoding deserialized")
    z2 = gae2.encode(**gae_inputs)
    print(z2)
    assert torch.allclose(z, z2), f"{title}: deserialized encoding differs from the original"

    print("Reconstruction decode() original")
    print(gae.decode(z, pyg.edge_index))
    print("Reconstruction decode() deserialized")
    print(gae2.decode(z2, pyg.edge_index))

    print("Reconstruction loss original")
    _reseed_samplers()
    print(gae.recon_loss(z, pyg.edge_index))
    print("Reconstruction loss deserialized")
    _reseed_samplers()
    print(gae2.recon_loss(z2, pyg.edge_index))

    # One shared negative sample so AUC/AP differences reflect the models only.
    neg_edge_index = negative_sampling(pyg.edge_index, z.size(0))
    print("AUC and precision metric test original")
    print(gae.test(z, pyg.edge_index, neg_edge_index=neg_edge_index))
    print("AUC and precision metric test deserialized")
    print(gae2.test(z2, pyg.edge_index, neg_edge_index=neg_edge_index))
    print("--------------------------------------------------------------------\n\n\n")

Reversible residual GAT GAE
Constructor params: 
{'encoder': {'state_dict': OrderedDict([('lin1.weight', tensor([[ 0.0982,  0.2437, -0.4080, -0.1093, -0.3962,  0.1662],
        [ 0.0552,  0.0723, -0.2609,  0.3955,  0.0599,  0.2873],
        [-0.2237, -0.0228, -0.1366,  0.2851,  0.0460,  0.0314],
        [ 0.3006, -0.1165, -0.1177,  0.1174, -0.3019, -0.2847]])), ('lin1.bias', tensor([ 0.1078,  0.0161,  0.2881, -0.1122])), ('lin2.weight', tensor([[-0.2132,  0.1202,  0.4042, -0.3074],
        [-0.1146, -0.1745, -0.4527,  0.1278],
        [-0.0321, -0.4272, -0.3950, -0.1019]])), ('lin2.bias', tensor([-0.4414, -0.4191,  0.3793])), ('norm.weight', tensor([1., 1., 1., 1.])), ('norm.bias', tensor([0., 0., 0., 0.])), ('convs.0.convs.0.norm.weight', tensor([1., 1.])), ('convs.0.convs.0.norm.bias', tensor([0., 0.])), ('convs.0.convs.0.conv.att', tensor([[[ 0.0029, -0.4917],
         [-0.0739, -0.4109],
         [ 0.6850, -0.1777],
         [ 0.7178,  0.7318],
         [-0.0716,  0.2772],
        

Instantiate VGAEv2 and test it

In [335]:
def _smoke_test_vgae(title, encoder_mu, graph_inputs):
    """Wrap ``encoder_mu`` in a ``VGAEv2`` and print the output of each of its methods on the toy graph ``pyg``.

    :param title: header printed before the section.
    :param encoder_mu: encoder passed to ``VGEncoder`` as ``encoder_mu``.
    :param graph_inputs: keyword arguments forwarded to every encoder call: ``x`` and ``edge_index``
        plus, depending on the encoder, ``edge_attr`` or ``edge_weight``.
    """
    print(title)
    vgae_enc = VGEncoder(encoder_mu=encoder_mu)
    vgae = VGAEv2(encoder=vgae_enc)
    print(vgae)
    print(torchinfo.summary(vgae))
    print("Reconstruction forward()")
    print(vgae(**graph_inputs))
    print("Reconstruction forward_all()")
    print(vgae.forward_all(**graph_inputs))
    print("Latent space encoding")
    z = vgae.encode(**graph_inputs)
    print(z)
    print("Mu")
    mu, logstd = vgae.encoder(**graph_inputs)
    print(mu)
    print("log(std)")
    print(logstd)
    print("Reconstruction decode()")
    print(vgae.decode(z, pyg.edge_index))
    print("Reconstruction loss")
    print(vgae.recon_loss(z, pyg.edge_index))
    print("AUC and precision metric test")
    neg_edge_index = negative_sampling(pyg.edge_index, z.size(0))
    print(vgae.test(z, pyg.edge_index, neg_edge_index=neg_edge_index))
    print("--------------------------------------------------------------------\n\n\n")


# The encoders take different edge-feature keywords: GAT uses edge_attr, SAGE none,
# and the GCN variants edge_weight.
plain_graph_inputs = dict(x=pyg.x, edge_index=pyg.edge_index)
edge_attr_graph_inputs = dict(plain_graph_inputs, edge_attr=pyg.edge_weight)
edge_weight_graph_inputs = dict(plain_graph_inputs, edge_weight=pyg.edge_weight)

_smoke_test_vgae("Reversible residual GAT VGAE", gat_enc, edge_attr_graph_inputs)
_smoke_test_vgae("Reversible residual SAGE VGAE", sage_enc, plain_graph_inputs)
_smoke_test_vgae("Simple GCN VGAE", gcn_enc, edge_weight_graph_inputs)
_smoke_test_vgae("Residual GCN2 VGAE", gcn2_enc, edge_weight_graph_inputs)
_smoke_test_vgae("Rev GCN VGAE", gcn_rev_enc, edge_weight_graph_inputs)

Reversible residual GAT VGAE
VGAEv2(
  (encoder): VGEncoder(
    (_encoder_mu): RevGATConvEncoder(
      (lin1): Linear(in_features=6, out_features=4, bias=True)
      (lin2): Linear(in_features=4, out_features=3, bias=True)
      (norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
      (convs): ModuleList(
        (0): GroupAddRev(GATConvBlock(
          (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
          (conv): GATv2Conv(2, 2, heads=8)
        ), num_groups=2)
        (1): GroupAddRev(GATConvBlock(
          (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
          (conv): GATv2Conv(2, 2, heads=8)
        ), num_groups=2)
        (2): GroupAddRev(GATConvBlock(
          (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
          (conv): GATv2Conv(2, 2, heads=8)
        ), num_groups=2)
      )
    )
    (_encoder_logstd): RevGATConvEncoder(
      (lin1): Linear(in_features=6, out_features=4, bias=True)
      (lin2): Linear(in_features

Test VGAE serialization

In [336]:
def _check_vgae_serialization(title, vgae_enc, graph_inputs, *constructor_args, **constructor_kwargs):
    """Round-trip a ``VGAEv2`` through its constructor params and state dict, printing original vs deserialized outputs.

    :param title: header printed before the section.
    :param vgae_enc: the ``VGEncoder`` wrapped by the model under test.
    :param graph_inputs: keyword arguments forwarded to every encoder call: ``x`` and ``edge_index``
        plus, depending on the encoder, ``edge_attr`` or ``edge_weight``.
    :param constructor_args: extra positional arguments forwarded to ``VGAEv2.from_constructor_params``
        after ``VGEncoder`` (e.g. an encoder class such as ``RevSAGEConvEncoder``).
    :param constructor_kwargs: extra keyword arguments forwarded to ``VGAEv2.from_constructor_params``.
    """
    random_note = " (should be ok if they are different because of the randomization)"
    print(title)
    vgae = VGAEv2(encoder=vgae_enc)
    print("Constructor params: ")
    constr_params = vgae.serialize_constructor_params()
    state_dict = vgae.state_dict()
    print(constr_params)
    vgae2 = VGAEv2.from_constructor_params(constr_params, VGEncoder, *constructor_args, **constructor_kwargs)
    vgae2.load_state_dict(state_dict)
    print(vgae2)
    print(torchinfo.summary(vgae2))

    print("forward() original")
    print(vgae(**graph_inputs))
    print("forward() deserialized" + random_note)
    print(vgae2(**graph_inputs))
    print("Reconstruction forward_all() original")
    print(vgae.forward_all(**graph_inputs))
    print("Reconstruction forward_all() deserialized" + random_note)
    print(vgae2.forward_all(**graph_inputs))
    print("Latent space encoding original")
    print(vgae.encode(**graph_inputs))
    print("Latent space encoding deserialized" + random_note)
    print(vgae2.encode(**graph_inputs))

    # Unlike encode(), the raw encoder outputs are expected to match the original exactly.
    for model, label in ((vgae, "original"), (vgae2, "deserialized (should be equal to original)")):
        print("Mu " + label)
        mu, logstd = model.encoder(**graph_inputs)
        print(mu)
        print("log(std) " + label)
        print(logstd)

    print("Reconstruction decode() original")
    print(vgae.decode(vgae.encode(**graph_inputs), pyg.edge_index))
    print("Reconstruction decode() deserialized" + random_note)
    print(vgae2.decode(vgae2.encode(**graph_inputs), pyg.edge_index))
    print("Reconstruction loss original")
    print(vgae.recon_loss(vgae.encode(**graph_inputs), pyg.edge_index))
    print("Reconstruction loss deserialized" + random_note)
    z = vgae2.encode(**graph_inputs)
    print(vgae2.recon_loss(z, pyg.edge_index))

    # One set of negative edges is shared by both models so the AUC/AP comparison is fair.
    neg_edge_index = negative_sampling(pyg.edge_index, z.size(0))
    print("AUC and precision metric test original")
    print(vgae.test(vgae.encode(**graph_inputs), pyg.edge_index, neg_edge_index=neg_edge_index))
    print("AUC and precision metric test deserialized" + random_note)
    print(vgae2.test(vgae2.encode(**graph_inputs), pyg.edge_index, neg_edge_index=neg_edge_index))
    print("--------------------------------------------------------------------\n\n\n")


# The encoders take different edge-feature keywords: GAT uses edge_attr, SAGE none,
# and the GCN variants edge_weight.
plain_graph_inputs = dict(x=pyg.x, edge_index=pyg.edge_index)
edge_attr_graph_inputs = dict(plain_graph_inputs, edge_attr=pyg.edge_weight)
edge_weight_graph_inputs = dict(plain_graph_inputs, edge_weight=pyg.edge_weight)

_check_vgae_serialization(
    "Reversible residual GAT VGAE",
    VGEncoder(
        shared_encoder=gat_enc,
        encoder_mu=GATConvBlock(in_channels=3, out_channels=3, heads=2, edge_dim=1),
        encoder_logstd=GATConvBlock(in_channels=3, out_channels=3, heads=3, edge_dim=1),
    ),
    edge_attr_graph_inputs,
    encoder_mu_constructor=GATConvBlock,
    shared_encoder_constructor=RevGATConvEncoder,
    encoder_logstd_constructor=GATConvBlock,
)
_check_vgae_serialization(
    "Reversible residual SAGE VGAE", VGEncoder(encoder_mu=sage_enc), plain_graph_inputs, RevSAGEConvEncoder
)
_check_vgae_serialization(
    "Simple GCN VGAE", VGEncoder(encoder_mu=gcn_enc), edge_weight_graph_inputs, SimpleGCNEncoder
)
_check_vgae_serialization(
    "Residual GCN2 VGAE", VGEncoder(encoder_mu=gcn2_enc), edge_weight_graph_inputs, ResGCN2ConvEncoder
)
_check_vgae_serialization(
    "Rev GCN VGAE", VGEncoder(encoder_mu=gcn_rev_enc), edge_weight_graph_inputs, RevGCNEncoder
)

Reversible residual GAT VGAE
Constructor params: 
{'encoder': {'state_dict': OrderedDict([('_encoder_mu.norm.weight', tensor([1., 1., 1.])), ('_encoder_mu.norm.bias', tensor([0., 0., 0.])), ('_encoder_mu.conv.att', tensor([[[-0.7339, -0.3626, -0.7060],
         [-0.1553,  1.0130, -1.0443]]])), ('_encoder_mu.conv.bias', tensor([0., 0., 0.])), ('_encoder_mu.conv.lin_l.weight', tensor([[ 0.6304, -0.2264,  0.6272],
        [ 0.5438,  0.7960,  0.2522],
        [ 0.4468,  0.1234, -0.3910],
        [-0.3899, -0.7737,  0.4645],
        [ 0.2371,  0.3295,  0.5477],
        [ 0.0060, -0.2315,  0.1780]])), ('_encoder_mu.conv.lin_l.bias', tensor([ 0.0296,  0.2697, -0.3883,  0.3916, -0.4433, -0.0455])), ('_encoder_mu.conv.lin_r.weight', tensor([[ 0.4510, -0.5296, -0.3083],
        [ 0.6905,  0.2204,  0.3145],
        [ 0.5668,  0.1087, -0.4290],
        [-0.4100,  0.3755,  0.5446],
        [-0.8026,  0.2371, -0.3495],
        [ 0.2396, -0.6471, -0.5171]])), ('_encoder_mu.conv.lin_r.bias', tensor([-

Instantiate classifier and test it

In [337]:
def _smoke_test_protnet(title, encoder, readout, dropout, graph_inputs):
    """Build a ``ProtMotionNet`` classifier on top of ``encoder``, print it and its output on the toy graph ``pyg``.

    :param title: header printed before the section.
    :param encoder: graph encoder used as the classifier backbone.
    :param readout: readout/aggregation name passed to ``ProtMotionNet`` (e.g. ``'mean_pool'``, ``'lstm'``).
    :param dropout: dropout rate passed to ``ProtMotionNet``.
    :param graph_inputs: keyword arguments for the forward call: ``x`` and ``edge_index`` plus,
        depending on the encoder, ``edge_attr`` or ``edge_weight``.
    """
    print(title)
    protnet = ProtMotionNet(
        encoder=encoder,
        # Read from the encoder itself so the dense head always matches the encoder in use.
        encoder_out_channels=encoder.out_channels,
        dense_units=[3, 3, 2, 2],
        dense_activations=['gelu', 'relu', 'sigmoid', 'softmax'],
        dropout=dropout,
        readout=readout
    )
    print(protnet)
    print(protnet(**graph_inputs))
    print("--------------------------------------------------------------------\n\n\n")


# The encoders take different edge-feature keywords: GAT uses edge_attr, SAGE none,
# and the GCN variants edge_weight.
plain_graph_inputs = dict(x=pyg.x, edge_index=pyg.edge_index)
edge_attr_graph_inputs = dict(plain_graph_inputs, edge_attr=pyg.edge_weight)
edge_weight_graph_inputs = dict(plain_graph_inputs, edge_weight=pyg.edge_weight)

_smoke_test_protnet("RevGAT encoder protnet", gat_enc,
                    readout='mean_pool', dropout=0.3, graph_inputs=edge_attr_graph_inputs)
_smoke_test_protnet("RevSAGE encoder protnet", sage_enc,
                    readout='add_pool', dropout=0.3, graph_inputs=plain_graph_inputs)
_smoke_test_protnet("SimpleGCN encoder protnet", gcn_enc,
                    readout='max_pool', dropout=0.3, graph_inputs=edge_weight_graph_inputs)
_smoke_test_protnet("ResGCN2 encoder protnet LSTM aggregation", gcn2_enc,
                    readout='lstm', dropout=0.0, graph_inputs=edge_weight_graph_inputs)
_smoke_test_protnet("ResGCN2 encoder protnet softmax aggregation", gcn2_enc,
                    readout='softmax', dropout=0.0, graph_inputs=edge_weight_graph_inputs)
_smoke_test_protnet("RevGCN encoder protnet", gcn_rev_enc,
                    readout='max_pool', dropout=0.3, graph_inputs=edge_weight_graph_inputs)

RevGAT encoder protnet
ProtMotionNet(
  (_encoder): RevGATConvEncoder(
    (lin1): Linear(in_features=6, out_features=4, bias=True)
    (lin2): Linear(in_features=4, out_features=3, bias=True)
    (norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (convs): ModuleList(
      (0): GroupAddRev(GATConvBlock(
        (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
        (conv): GATv2Conv(2, 2, heads=8)
      ), num_groups=2)
      (1): GroupAddRev(GATConvBlock(
        (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
        (conv): GATv2Conv(2, 2, heads=8)
      ), num_groups=2)
      (2): GroupAddRev(GATConvBlock(
        (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
        (conv): GATv2Conv(2, 2, heads=8)
      ), num_groups=2)
    )
  )
  (_readout_aggregation): MeanAggregation()
  (_dense_layers): ModuleList(
    (0): Linear(3, 3, bias=True)
    (1): Linear(3, 3, bias=True)
    (2): Linear(3, 2, bias=True)
    (3): Linear(2, 2, bias=

Test classifier serialization

In [338]:
def _check_protnet_serialization(title, encoder, encoder_constructor, readout, graph_inputs, dropout=0.0):
    """Round-trip a ``ProtMotionNet`` through its constructor params and state dict, printing both forward outputs.

    :param title: header printed before the section.
    :param encoder: graph encoder used as the classifier backbone.
    :param encoder_constructor: encoder class handed to ``ProtMotionNet.from_constructor_params``.
    :param readout: readout/aggregation name passed to ``ProtMotionNet``.
    :param graph_inputs: keyword arguments for the forward call: ``x`` and ``edge_index`` plus,
        depending on the encoder, ``edge_attr`` or ``edge_weight``.
    :param dropout: dropout rate passed to ``ProtMotionNet``. Defaults to 0.0 because freshly built
        torch modules are in training mode. With active dropout, the original and deserialized
        outputs would presumably differ randomly and could not be compared.
    """
    print(title)
    protnet = ProtMotionNet(
        encoder=encoder,
        # Read from the encoder itself so the dense head always matches the encoder in use.
        encoder_out_channels=encoder.out_channels,
        dense_units=[3, 3, 2, 2],
        dense_activations=['gelu', 'relu', 'sigmoid', 'softmax'],
        dropout=dropout,
        readout=readout
    )
    constr_params = protnet.serialize_constructor_params()
    state_dict = protnet.state_dict()
    protnet2 = ProtMotionNet.from_constructor_params(constr_params, encoder_constructor)
    protnet2.load_state_dict(state_dict)
    print(protnet2)
    print("forward() original")
    print(protnet(**graph_inputs))
    print("forward() deserialized")
    print(protnet2(**graph_inputs))
    print("--------------------------------------------------------------------\n\n\n")


# The encoders take different edge-feature keywords: GAT uses edge_attr, SAGE none,
# and the GCN variants edge_weight.
plain_graph_inputs = dict(x=pyg.x, edge_index=pyg.edge_index)
edge_attr_graph_inputs = dict(plain_graph_inputs, edge_attr=pyg.edge_weight)
edge_weight_graph_inputs = dict(plain_graph_inputs, edge_weight=pyg.edge_weight)

_check_protnet_serialization("RevGAT encoder protnet", gat_enc, RevGATConvEncoder,
                             readout='mean_pool', graph_inputs=edge_attr_graph_inputs)
_check_protnet_serialization("RevSAGE encoder protnet", sage_enc, RevSAGEConvEncoder,
                             readout='add_pool', graph_inputs=plain_graph_inputs)
_check_protnet_serialization("SimpleGCN encoder protnet", gcn_enc, SimpleGCNEncoder,
                             readout='max_pool', graph_inputs=edge_weight_graph_inputs)
_check_protnet_serialization("ResGCN2 encoder protnet LSTM aggregation", gcn2_enc, ResGCN2ConvEncoder,
                             readout='lstm', graph_inputs=edge_weight_graph_inputs)
_check_protnet_serialization("ResGCN2 encoder protnet softmax aggregation", gcn2_enc, ResGCN2ConvEncoder,
                             readout='softmax', graph_inputs=edge_weight_graph_inputs)
_check_protnet_serialization("RevGCN encoder protnet", gcn_rev_enc, RevGCNEncoder,
                             readout='max_pool', graph_inputs=edge_weight_graph_inputs)

RevGAT encoder protnet
ProtMotionNet(
  (_encoder): RevGATConvEncoder(
    (lin1): Linear(in_features=6, out_features=4, bias=True)
    (lin2): Linear(in_features=4, out_features=3, bias=True)
    (norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (convs): ModuleList(
      (0): GroupAddRev(GATConvBlock(
        (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
        (conv): GATv2Conv(2, 2, heads=8)
      ), num_groups=2)
      (1): GroupAddRev(GATConvBlock(
        (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
        (conv): GATv2Conv(2, 2, heads=8)
      ), num_groups=2)
      (2): GroupAddRev(GATConvBlock(
        (norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
        (conv): GATv2Conv(2, 2, heads=8)
      ), num_groups=2)
    )
  )
  (_readout_aggregation): MeanAggregation()
  (_dense_layers): ModuleList(
    (0): Linear(3, 3, bias=True)
    (1): Linear(3, 3, bias=True)
    (2): Linear(3, 2, bias=True)
    (3): Linear(2, 2, bias=

In [339]:

from torch_geometric.nn import Linear
from models.layers import SerializableModule
from torch.nn import ModuleList, LayerNorm
from torch import Tensor
from typing import Any, List, Optional
from torch_geometric.nn.models import GroupAddRev
import torch_geometric.nn as nn
import torch.nn.functional as F


class RevResWrapper(SerializableModule):
    """Wrap a single module as the only element of a ``ModuleList``.

    ``forward`` invokes the wrapped module by iterating over
    ``rev_res_module_list_wrapper`` and returns its output.

    NOTE(review): the original author reported "weird memory errors" on ``forward()``
    when a reversible-residual module was called directly rather than through this
    list, without identifying the cause. A possible (unconfirmed) explanation is that
    PyG's ``GroupAddRev`` frees the storage of its input tensor to save memory, which
    breaks any other reference to that tensor — TODO confirm before relying on this
    wrapper as a fix.
    """

    def serialize_constructor_params(self, *args, **kwargs) -> dict:
        """Constructor-parameter serialization is not supported for this wrapper.

        The wrapped module instance is not described by serializable constructor
        arguments here, so no dict can be produced.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not support serialize_constructor_params()"
        )

    def __init__(self, rev_res_module: torch.nn.Module):
        """
        :param rev_res_module: module to wrap; it is registered as the single
            element of ``rev_res_module_list_wrapper``.
        """
        super().__init__()
        self.rev_res_module_list_wrapper = ModuleList([rev_res_module])

    def forward(self, *args, **kwargs):
        """Call the wrapped module with all given arguments and return its output."""
        output = None
        for rev_res_module in self.rev_res_module_list_wrapper:
            output = rev_res_module(*args, **kwargs)
        return output


class Giggino(torch.nn.Module):
    """Scratch model used to debug forward() failures with reversible GNN encoders.

    Several encoders are instantiated (and therefore registered as submodules and
    counted in ``parameters()``), but only ``conv4``, ``ciao`` and ``ciao4`` are used
    in :meth:`forward`.
    """

    def __init__(self):
        super(Giggino, self).__init__()
        self.split_dim = -1
        self.num_groups = 2
        # Registered first so the submodule / state_dict key order stays as before.
        self.conv3 = RevSAGEConvEncoder(10, 10, 10, 2)
        self.conv = VGEncoder(
            shared_encoder=RevGATConvEncoder(10, 10, 10, 2),
            encoder_mu=GATConvBlock(10, 10, heads=2),
            encoder_logstd=GATConvBlock(10, 10, heads=2),
        )
        self.conv2 = RevGATConvEncoder(10, 10, 10, 2)
        self.conv5 = RevGATConvEncoder(10, 10, 10, 2)
        self.convs = ModuleList([RevGATConvEncoder(10, 10, 10, 2) for _ in range(2)])

        # The only encoder actually used in forward(), wrapped in RevResWrapper.
        self.conv4 = RevResWrapper(RevGATConvEncoder(10, 10, 10, 2))
        self.ciao = nn.Linear(10, 10)
        self.ciao4 = GCN2ConvBlock(10)

    def forward(self, x, edge_index):
        """Encode with ``conv4`` and apply two heads on the encoding.

        :param x: node features (the caller passes a batch's ``x``).
        :param edge_index: graph connectivity (the caller passes a batch's ``edge_index``).
        :return: tuple ``(a, b)`` with ``h = conv4(x, edge_index)``,
            ``a = ciao(h) + h`` and ``b = ciao4(h, x0=a, edge_index=edge_index)``.
        """
        # NOTE(review): the author observed that keeping a second variable bound to
        # the input of the reversible encoder triggered the memory errors; avoid
        # reusing aliases of tensors fed to reversible layers — TODO confirm cause.
        h = self.conv4(x, edge_index)
        a = self.ciao(h) + h
        b = self.ciao4(h, x0=a, edge_index=edge_index)
        return a, b


BATCH_SIZE = 10
N_DEBUG_BATCHES = 7  # number of training batches pushed through the debug model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset paths are relative (see the FileNotFoundError below), so this cell must run
# from the project root — the os.chdir cell at the top of the notebook ensures that.
ds_train = load_dataset(PRETRAIN_CLEANED_TRAIN, dataset_type="pretrain")
ds_val = load_dataset(PRETRAIN_CLEANED_VAL, dataset_type="pretrain")

dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True)
# Validation order does not need to be randomised.
dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False)

# Dedicated name: `g` already holds the example networkx graph defined at the top.
debug_model = Giggino().to(device)
for batch_idx, batch in enumerate(dl_train):
    batch = batch.to(device)
    # Call the module (not .forward()) so registered hooks are honoured.
    print(debug_model(batch.x, batch.edge_index))
    if batch_idx + 1 >= N_DEBUG_BATCHES:
        break
print(debug_model)


FileNotFoundError: [Errno 2] No such file or directory: 'data\\cleaned\\pretraining\\train\\params\\params.json'