In [3]:
import sys

# Make the project checkout importable (presumably this is what lets the
# `attribution_gnn1` / `src` imports in later cells resolve — confirm).
# The membership test keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path.
project_root = '/p/home/jusers/kotobi2/juwels/hida_project/'
if project_root not in sys.path:
    sys.path.append(project_root)

In [4]:
import os.path as osp
import numpy as np
import torch
from torch_geometric.loader import DataLoader

from attribution_gnn1.QM9_SpecData import QM9_SpecData
from attribution_gnn1.split import save_split

In [5]:
# Instantiate the QM9 spectra dataset; an empty list is passed for `spectra`.
# NOTE(review): the path names suggest `root` is the processed 10k-molecule
# graph file and `raw_dir` the raw QM9 inputs — confirm against QM9_SpecData.
root = '/p/home/jusers/kotobi2/juwels/data_qm9/all_graph_data/qm9_spec_10k.pt'
qm9_spec = QM9_SpecData(
    root=root,
    raw_dir='/p/home/jusers/kotobi2/juwels/data_qm9/raw/',
    spectra=[],
)

In [6]:
# Request an 8000 / 2000 / 0 train / validation / test split over the whole
# dataset. The result is indexed by 'train' and 'val' in the next cell.
# NOTE(review): `save_split=True` presumably writes the indices to `path` —
# confirm in attribution_gnn1.split.
idxs = save_split(
    path='/p/home/jusers/kotobi2/juwels/hida_project/data/split_files2/qm9_split_10k.npz',
    ndata=len(qm9_spec),
    ntrain=8000, nval=2000, ntest=0,
    shuffle=True,
    save_split=True,
    print_nsample=True,
)

In [7]:
# Materialise the train and validation subsets as plain lists of graphs,
# fetching one dataset item per split index.
train_idx, val_idx = idxs['train'], idxs['val']
train_qm9 = [qm9_spec[k] for k in train_idx]
val_qm9 = [qm9_spec[k] for k in val_idx]
# No test subset here: the split above was requested with ntest=0.
#test_qm9 = qm9_spec[idxs['test']]

In [8]:
# data loaders
# Only the training loader shuffles. Validation batches keep a fixed order so
# that per-batch results are reproducible between evaluation passes.
train_loader = DataLoader(train_qm9, batch_size=100, shuffle=True)
val_loader = DataLoader(val_qm9, batch_size=100, shuffle=False)
# Disabled: no test subset is built in this part of the notebook.
#test_loader = DataLoader(test_qm9, batch_size=100)

# Testing the GraphNet

In [9]:
from src.GraphNet.modules import GraphNetwork
from src.GraphNet.SpectraGraphNet import SpectraGraphNet

In [10]:
# Layer widths for the three sub-models: input width, hidden width and output
# width. The recorded module repr shows them as
# Linear(feat_in -> feat_hidd) followed by Linear(feat_hidd -> feat_out).
node_model_params = dict(feat_in=39, feat_hidd=64, feat_out=50)
edge_model_params = dict(feat_in=124, feat_hidd=64, feat_out=50)
global_model_params = dict(feat_in=120, feat_hidd=64, feat_out=50)

In [11]:
# Build a single graph-network block from the node, edge and global
# parameter dicts, passed positionally in that order.
graphnet = GraphNetwork(node_model_params, edge_model_params, global_model_params)

In [12]:
# Display the module tree so the layer sizes can be checked against the
# parameter dicts defined above.
graphnet

GraphNetwork(
  (gatencoder): GATEncoder(
    (gats): ModuleList(
      (0): GATv2Conv(11, 64, heads=3)
      (1): ReLU(inplace=True)
      (2): GATv2Conv(192, 64, heads=3)
      (3): ReLU(inplace=True)
      (4): GATv2Conv(192, 64, heads=3)
      (5): ReLU(inplace=True)
      (6): GATv2Conv(192, 20, heads=1)
    )
  )
  (node_model): NodeModel(
    (mlp): Sequential(
      (0): Linear(in_features=39, out_features=64, bias=True)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=64, out_features=50, bias=True)
      (3): ReLU(inplace=True)
      (4): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
    )
  )
  (edge_model): EdgeModel(
    (mlp): Sequential(
      (0): Linear(in_features=124, out_features=64, bias=True)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=64, out_features=50, bias=True)
      (3): ReLU(inplace=True)
      (4): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
    )
  )
  (global_model): GlobalModel(
    (mlp): Sequential(
      (0

In [13]:
# Pull one mini-batch from the training loader to use as a smoke-test input.
batch_test = next(iter(train_loader))

In [14]:
# Forward pass on the sample batch; in the recorded output the returned batch
# carries 50-wide x, edge_attr and a per-graph global attribute `u`.
graphnet(batch_test)

DataBatch(x=[1586, 50], edge_index=[2, 3144], edge_attr=[3144, 50], pos=[1586, 3], z=[1586], spectrum=[10000], idx=[100], batch=[1586], ptr=[101], u=[100, 50])

In [15]:
# Inspect the global features. They are readable from `batch_test` itself,
# which suggests the forward pass above writes `u` onto its input batch —
# NOTE(review): confirm in GraphNetwork.forward.
batch_test.u

tensor([[-0.2609,  1.1374, -0.5603,  ..., -0.5603, -0.3861, -0.5603],
        [-0.2253,  1.1703, -0.5640,  ..., -0.5640, -0.3790, -0.5640],
        [-0.2379,  1.0912, -0.5694,  ..., -0.5694, -0.4212, -0.5694],
        ...,
        [-0.1369,  1.1702, -0.5700,  ..., -0.5700, -0.4880, -0.5700],
        [-0.2483,  1.0855, -0.5661,  ..., -0.5661, -0.4136, -0.5661],
        [-0.2671,  1.1663, -0.5594,  ..., -0.5594, -0.4140, -0.5594]],
       grad_fn=<NativeLayerNormBackward>)

In [16]:
# Per-layer configuration for SpectraGraphNet: one entry ("graphnet1") that
# bundles the node, edge and global parameter dicts.
all_params = dict(
    graphnet1=dict(
        node_model_params=node_model_params,
        edge_model_params=edge_model_params,
        global_model_params=global_model_params,
    ),
)

In [17]:
# Build the full model with a single graph-network layer.
# NOTE(review): n_layers=1 presumably has to match the number of entries in
# `all_params` (only "graphnet1" here) — confirm in SpectraGraphNet.
spectragraphnet = SpectraGraphNet(all_params=all_params, n_layers=1)

In [18]:
# Display the module tree of the assembled model.
spectragraphnet

SpectraGraphNet(
  (graphnets): ModuleList(
    (0): GraphNetwork(
      (gatencoder): GATEncoder(
        (gats): ModuleList(
          (0): GATv2Conv(11, 64, heads=3)
          (1): ReLU(inplace=True)
          (2): GATv2Conv(192, 64, heads=3)
          (3): ReLU(inplace=True)
          (4): GATv2Conv(192, 64, heads=3)
          (5): ReLU(inplace=True)
          (6): GATv2Conv(192, 20, heads=1)
        )
      )
      (node_model): NodeModel(
        (mlp): Sequential(
          (0): Linear(in_features=39, out_features=64, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=64, out_features=50, bias=True)
          (3): ReLU(inplace=True)
          (4): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
        )
      )
      (edge_model): EdgeModel(
        (mlp): Sequential(
          (0): Linear(in_features=124, out_features=64, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=64, out_features=50, bias=True)
          (3)

# Testing Jumping Knowledge

In [46]:
import os
import os.path as osp
import numpy as np

from torch_geometric.nn import GATv2Conv
import torch
from torch.nn import LSTM, Linear

from attribution_gnn1.QM9_SpecData import QM9_SpecData
from attribution_gnn1.split import save_split

In [65]:
# Three stacked GATv2 layers, built in order: 11 input features -> 64 -> 64 -> 64.
gat_layers = [GATv2Conv(in_dim, 64) for in_dim in (11, 64, 64)]

# Bidirectional 3-layer LSTM with hidden size 64. Both directions together
# give 2 * 64 = 128 output features, which `att` maps to one scalar score.
lstm = LSTM(64, 64, num_layers=3, bidirectional=True, batch_first=True)
att = Linear(128, 1)

In [66]:
# Replace the plain Python list with a torch.nn.ModuleList holding the same
# GAT layers, so they are tracked as submodules.
gat_layers = torch.nn.ModuleList(gat_layers)

In [8]:
# Instantiate the QM9 spectra dataset; an empty list is passed for `spectra`
# (the earlier argument, `broadened_spectra_stk`, is kept as a trailing note).
root = '/p/home/jusers/kotobi2/juwels/data_qm9/qm9_spec_10k.pt'
qm9_spec = QM9_SpecData(
    root=root,
    raw_dir='/p/home/jusers/kotobi2/juwels/data_qm9/raw/',
    spectra=[],  # broadened_spectra_stk
)

# Serialise the dataset object to `root` unless something already exists there.
# NOTE(review): `root` is both the dataset's `root` argument and the save
# target — confirm this double use is intended.
if not osp.exists(root):
    torch.save(qm9_spec, root)

In [9]:
# Request a split with 8000 training and 1500 validation samples. `ntest` is
# not passed, yet the next cell reads idxs['test'].
# NOTE(review): presumably save_split assigns the remaining samples to the
# test split — confirm in attribution_gnn1.split.
idxs = save_split(
    path='/p/home/jusers/kotobi2/juwels/hida_project/data/split_files/qm9_split_10k.npz',
    ndata=len(qm9_spec),
    ntrain=8000, nval=1500,
    shuffle=True,
    save_split=True,
    print_nsample=True,
)

In [10]:
# Slice the dataset into its three subsets by indexing it with each split's
# index collection, in the order train, val, test.
train_qm9, val_qm9, test_qm9 = (
    qm9_spec[idxs[split_name]] for split_name in ('train', 'val', 'test')
)

In [102]:
# Take the first training molecule as a single-graph test input.
molecule_graph = train_qm9[0]
x, edge_index = molecule_graph.x, molecule_graph.edge_index
# One zero per row of `x`: with a single graph, every row belongs to graph 0.
# Built directly in torch with an explicit long dtype instead of going through
# NumPy, whose default integer width is platform-dependent.
batch_seg = torch.zeros(x.shape[0], dtype=torch.long)

In [103]:
# Run the molecule through the GAT stack and keep every layer's output for
# the aggregation cell below.
# Always start from the molecule's own features: the loop rebinds `x`, so
# reusing the previous value of `x` on a re-run would feed 64-wide activations
# into the 11-wide first layer.
xs = []
x = molecule_graph.x
for layer in gat_layers:
    x = layer(x, edge_index).flatten(1)
    # The trailing singleton axis lets the per-layer outputs be concatenated.
    xs.append(x.unsqueeze(-1))

In [96]:
# LSTM-attention aggregation over the per-layer outputs collected above.
# Concatenate the layers on the last axis, then swap axes 1 and 2 so the
# layer axis becomes the sequence axis expected by the batch_first LSTM.
# This rebinds `xs` from a list to a tensor, so re-run the cell above before
# re-running this one.
xs = torch.cat(xs, dim=-1).transpose(1, 2)

# One scalar score per sequence step from the BiLSTM outputs, normalised with
# a softmax over the layer axis.
alpha, _ = lstm(xs)
alpha = torch.softmax(att(alpha).squeeze(-1), dim=-1)

# Attention-weighted sum of the layer outputs along the layer axis.
h = (xs * alpha.unsqueeze(-1)).sum(1)