In [1]:
from admet_prediction.datamodules.featurizers import OGBFeaturizer
from admet_prediction.datamodules.transforms import RandomWalkGenerator
from admet_prediction.encoders.gps.encoders import NodeEncoder, EdgeEncoder

from rdkit import Chem

# Featurize one example molecule and attach random-walk structural encodings.
# The random-walk step spec and embedding width are defined once so the generator and
# the encoders cannot silently drift apart.
# NOTE(review): assumes NodeEncoder's `ksteps` must describe the same encodings that
# RandomWalkGenerator produces — confirm.
RW_KSTEPS = [1, 17]
EMB_DIM = 384

featurizer = OGBFeaturizer()
rw = RandomWalkGenerator(
    ksteps=RW_KSTEPS,
    space_dim=0,
)
node = NodeEncoder(
    dim_in=1,
    dim_posenc=20,
    dim_emb=EMB_DIM,
    ksteps=RW_KSTEPS,
    expand_x=False,
    batch_norm=False,
)
edge = EdgeEncoder(
    dim_emb=EMB_DIM,
    batch_norm=False,
)

smi = 'c1ccc(CCC)c(CC)c1'
data = featurizer(smi)
data = rw(data)

# NOTE(review): `node` / `edge` are constructed but deliberately not applied. The cells
# below feed the raw featurized `data` to GNNEncoder, which carries its own atom/bond
# encoders (see its printed module tree). To apply these encoders instead:
#   data.x = node(data.x, data.rwse)
#   data.edge_attr = edge(data.edge_attr)

  warn(f"Failed to load image Python extension: {e}")


In [2]:
from admet_prediction.encoders.gps.gnn import GNNEncoder
import torch
import copy

In [16]:
# GATv2-based GNN encoder under test. Seed first so parameter initialisation (and hence
# the unused-parameter check below) is reproducible after a kernel restart.
torch.manual_seed(0)
encoder = GNNEncoder(
    model='gatv2',
    d_model=128,
    nhead=8,
    dropout=0.1,
    layer_norm=False,
    num_layer=2,
)

In [17]:
# Sanity check: run a few SGD steps on a dummy loss and verify that every trainable
# parameter of `encoder` changes. A trainable parameter whose value never changes got no
# gradient, i.e. it is not on the forward path.
#
# Parameters are compared via `named_parameters()` rather than `state_dict()`: the state
# dict also contains buffers (e.g. BatchNorm running statistics, `num_batches_tracked`),
# which SGD does not optimise and which would be misreported as unused parameters.
# Frozen parameters (requires_grad=False) are skipped for the same reason.
STEPS = 10
torch.manual_seed(0)  # reproducible dropout masks across re-runs

encoder.train()
params_before = {
    name: param.detach().clone()
    for name, param in encoder.named_parameters()
    if param.requires_grad
}
optimizer = torch.optim.SGD(encoder.parameters(), lr=0.1)
for _ in range(STEPS):
    optimizer.zero_grad()
    output = encoder(
        x=data.x,
        edge_index=data.edge_index,
        edge_attr=data.edge_attr,
        batch=data.batch,
    )
    # Dummy loss. NOTE(review): `output[0]` is either the first element of a returned
    # tuple or only the first row of a returned tensor, depending on GNNEncoder's
    # return type — confirm which is intended.
    (output[0] ** 2).sum().backward()
    optimizer.step()

used_params = []
unused_param_names = []
for name, param in encoder.named_parameters():
    if name not in params_before:
        continue  # frozen parameter, not expected to change
    if torch.equal(params_before[name], param.detach()):
        unused_param_names.append(name)  # never updated during training
    else:
        used_params.append(name)
assert not unused_param_names, f"Unused parameters: {unused_param_names}"

In [18]:
# Module tree of the encoder, shown via the notebook's display of the last expression.
encoder

GNNEncoder(
  (atom): AtomEncoder(
    (atom_embedding_list): ModuleList(
      (0): Embedding(119, 128)
      (1): Embedding(19, 128)
      (2): Embedding(7, 128)
      (3): Embedding(5, 128)
      (4-5): 2 x Embedding(12, 128)
      (6): Embedding(10, 128)
      (7-8): 2 x Embedding(6, 128)
      (9-10): 2 x Embedding(3, 128)
    )
  )
  (bond): EdgeEncoder(
    (encoder): BondEncoder(
      (bond_embedding_list): ModuleList(
        (0): Embedding(5, 128)
        (1): Embedding(7, 128)
        (2): Embedding(3, 128)
      )
    )
  )
  (model): GATv2(128, 128, num_layers=2)
)


In [19]:
# NOTE(review): `used_params1` is never assigned in any cell of this notebook, so this is
# hidden state from an earlier kernel session and raises NameError after Restart & Run All.
# Per the output below it included `model.norms.0.*`, so it was presumably a saved copy of
# `used_params` from a run with a different encoder configuration.
# TODO: recreate it explicitly in a cell.
print(used_params1)

['atom.atom_embedding_list.0.weight', 'atom.atom_embedding_list.1.weight', 'atom.atom_embedding_list.2.weight', 'atom.atom_embedding_list.3.weight', 'atom.atom_embedding_list.4.weight', 'atom.atom_embedding_list.5.weight', 'atom.atom_embedding_list.6.weight', 'atom.atom_embedding_list.7.weight', 'atom.atom_embedding_list.8.weight', 'atom.atom_embedding_list.9.weight', 'atom.atom_embedding_list.10.weight', 'bond.encoder.bond_embedding_list.0.weight', 'bond.encoder.bond_embedding_list.1.weight', 'bond.encoder.bond_embedding_list.2.weight', 'model.convs.0.att', 'model.convs.0.bias', 'model.convs.0.lin_l.weight', 'model.convs.0.lin_l.bias', 'model.convs.0.lin_r.weight', 'model.convs.0.lin_r.bias', 'model.convs.0.lin_edge.weight', 'model.convs.1.att', 'model.convs.1.bias', 'model.convs.1.lin_l.weight', 'model.convs.1.lin_l.bias', 'model.convs.1.lin_r.weight', 'model.convs.1.lin_r.bias', 'model.convs.1.lin_edge.weight', 'model.norms.0.weight', 'model.norms.0.bias']


In [20]:
# Names of the parameters whose values changed during the dummy training loop.
used_params

['atom.atom_embedding_list.0.weight', 'atom.atom_embedding_list.1.weight', 'atom.atom_embedding_list.2.weight', 'atom.atom_embedding_list.3.weight', 'atom.atom_embedding_list.4.weight', 'atom.atom_embedding_list.5.weight', 'atom.atom_embedding_list.6.weight', 'atom.atom_embedding_list.7.weight', 'atom.atom_embedding_list.8.weight', 'atom.atom_embedding_list.9.weight', 'atom.atom_embedding_list.10.weight', 'bond.encoder.bond_embedding_list.0.weight', 'bond.encoder.bond_embedding_list.1.weight', 'bond.encoder.bond_embedding_list.2.weight', 'model.convs.0.att', 'model.convs.0.bias', 'model.convs.0.lin_l.weight', 'model.convs.0.lin_l.bias', 'model.convs.0.lin_r.weight', 'model.convs.0.lin_r.bias', 'model.convs.0.lin_edge.weight', 'model.convs.1.att', 'model.convs.1.bias', 'model.convs.1.lin_l.weight', 'model.convs.1.lin_l.bias', 'model.convs.1.lin_r.weight', 'model.convs.1.lin_r.bias', 'model.convs.1.lin_edge.weight']


In [22]:
# Parameters updated in the earlier run (`used_params1`) but not in the current one
# (`used_params`).
# NOTE(review): `used_params1` is not assigned in any cell of this notebook (hidden kernel
# state), so this cell raises NameError on a fresh kernel.
set(used_params1).difference(set(used_params))

{'model.norms.0.bias', 'model.norms.0.weight'}

In [31]:
from admet_prediction.encoders.gps.gnn import GATv2
import torch.nn as nn

# Standalone GATv2 using the same width / heads / dropout / depth as the GNNEncoder above,
# with norm='batchnorm'. The hyperparameters stay module-level names because a later
# cell reuses them.
d_model, nhead, dropout, num_layer = 128, 8, 0.1, 2
gat = GATv2(
    in_channels=d_model,
    hidden_channels=d_model,
    heads=nhead,
    dropout=dropout,
    num_layers=num_layer,
    act=nn.SiLU(),
    norm='batchnorm',
    edge_dim=d_model,
)

In [35]:
# Print the weight and bias of the first normalisation layer; with norm='batchnorm'
# they are registered under `norms.0.module.*`.
target_names = ('norms.0.module.weight', 'norms.0.module.bias')
for name, param in gat.named_parameters():
    if name in target_names:
        print(param)

Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

In [65]:
# Normalisation layer class names found in `torch_geometric.nn.norm` (reference only —
# the `norm` argument below is passed as a string). Previously this was a bare list
# expression whose value was discarded.
PYG_NORM_NAMES = [
    'MessageNorm', 'HeteroBatchNorm', 'GraphNorm', 'BatchNorm', 'LayerNorm',
    'HeteroLayerNorm', 'GraphSizeNorm', 'DiffGroupNorm', 'MeanSubtractionNorm',
    'PairNorm', 'InstanceNorm',
]

# Rebuild `gat` with a smaller width and LayerNorm.
# NOTE: this overwrites the global `d_model` set in an earlier cell; `nhead`, `dropout`
# and `num_layer` are reused from that cell, so it must have run first.
d_model = 32
gat = GATv2(
    in_channels=d_model,
    hidden_channels=d_model,
    heads=nhead,
    dropout=dropout,
    num_layers=num_layer,
    act=nn.SiLU(),
    norm='LayerNorm',
    edge_dim=d_model,
)

In [62]:
# Gather every parameter that lives under the `norms` submodule of `gat`, then show each.
a = [
    param
    for name, param in gat.named_parameters()
    if name.partition('.')[0] == 'norms'
]
for norm_param in a:
    print(norm_param)

Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)
Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       requires_grad=True)


In [54]:
# Class names of the normalisation layers in torch_geometric.nn.norm (displayed as a tuple).
(
    'MessageNorm',
    'HeteroBatchNorm',
    'GraphNorm',
    'BatchNorm',
    'LayerNorm',
    'HeteroLayerNorm',
    'GraphSizeNorm',
    'DiffGroupNorm',
    'MeanSubtractionNorm',
    'PairNorm',
    'InstanceNorm',
)

In [55]:
# NOTE(review): `norms` is not assigned in any cell of this notebook — it is left over from
# a deleted cell or an earlier kernel session, so this cell raises NameError after
# Restart & Run All. Judging from the output below, it held the normalisation classes of
# `torch_geometric.nn.norm`. TODO: define it explicitly.
norms

[torch_geometric.nn.norm.batch_norm.BatchNorm,
 torch_geometric.nn.norm.batch_norm.HeteroBatchNorm,
 torch_geometric.nn.norm.instance_norm.InstanceNorm,
 torch_geometric.nn.norm.layer_norm.LayerNorm,
 torch_geometric.nn.norm.layer_norm.HeteroLayerNorm,
 torch_geometric.nn.norm.graph_norm.GraphNorm,
 torch_geometric.nn.norm.graph_size_norm.GraphSizeNorm,
 torch_geometric.nn.norm.pair_norm.PairNorm,
 torch_geometric.nn.norm.mean_subtraction_norm.MeanSubtractionNorm,
 torch_geometric.nn.norm.msg_norm.MessageNorm,
 torch_geometric.nn.norm.diff_group_norm.DiffGroupNorm]

In [46]:
# List the fully-qualified name of every parameter registered on `gat`.
param_names = [name for name, _ in gat.named_parameters()]
for name in param_names:
    print(name)

convs.0.att
convs.0.bias
convs.0.lin_l.weight
convs.0.lin_l.bias
convs.0.lin_r.weight
convs.0.lin_r.bias
convs.0.lin_edge.weight
convs.1.att
convs.1.bias
convs.1.lin_l.weight
convs.1.lin_l.bias
convs.1.lin_r.weight
convs.1.lin_r.bias
convs.1.lin_edge.weight
norms.0.weight
norms.0.bias
