In [1]:
import pandas as pd
import torch
import torch_geometric
import dgl
import numpy as np
from tqdm import tqdm
from load_GODDESS_3d import GODDESSDataset, split_test_val
from run_with_val import run
from model_3d.Spherenet_NMR import SphereNet

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

##### 1, Load the GODDESS dataset for 3D GNN models

This code is modified from the 2D loading method.

We first randomly split the glycans into a training set and a test set.

We then randomly split off half of the test set to use as the validation set.

In [3]:
# Node attributes carried through the DGL -> NetworkX -> PyTorch Geometric conversion.
# Built once here rather than being re-created on every loop iteration.
node_attr_names = ['z', 'y', 'Carbon_Hydrogen_mask', 'train_mask',
                   'test_mask', 'train_carbon_mask', 'test_carbon_mask',
                   'train_hydrogen_mask', 'test_hydrogen_mask', 'pos']

dataset = GODDESSDataset()
train_data_godess = []
test_data_godess = []

# Each dataset item unpacks into (DGL graph, split id, third field); the third
# field (presumably the glycan name) is not needed for the conversion.
for dgl_graph, split_id, _graph_name in tqdm(dataset):
    nx_graph = dgl.to_networkx(dgl_graph, node_attrs=node_attr_names)
    pyg_graph = torch_geometric.utils.from_networkx(nx_graph)

    # Split id 0 -> training list, 1 -> test list; any other id is skipped,
    # exactly as in the original if/elif chain.
    if split_id == 0:
        train_data_godess.append(pyg_graph)
    elif split_id == 1:
        test_data_godess.append(pyg_graph)

# Divide the held-out graphs into the final test and validation lists.
test_data_list, valid_data_list = split_test_val(test_data_godess)

100%|███████████████████████████████████████| 2310/2310 [00:43<00:00, 52.64it/s]
100%|███████████████████████████████████████| 2310/2310 [00:59<00:00, 39.07it/s]


##### 2, Initialize the node embedding size

The node embedding size (the number of features per node) is the input size of the model.

In [4]:
# Take the width of the node feature matrix `z` from the first training graph;
# this value is passed to the model as its input embedding size.
first_train_graph = train_data_godess[0]
node_embedding_size = first_train_graph.z.shape[1]

##### 3, Initialize and train the SphereNet

Our implementation is modified from https://github.com/divelab/DIG/tree/dig-stable. 

To apply SphereNet to our tasks, we replaced the global pooling layer, which is needed for predicting the properties of whole molecules, and added a layer that maps the learned embedding of each atom to its NMR chemical shift.

In [5]:
# Build the SphereNet model. The keyword values are unchanged; they are listed
# one per line so each hyperparameter is easy to find and tune.
model = SphereNet(
    energy_and_force=False,
    # Input embedding width taken from the node features of the training graphs.
    in_embed_size=node_embedding_size,
    cutoff=5.0,
    num_layers=4,
    hidden_channels=128,
    # One output value per atom (the predicted NMR chemical shift).
    out_channels=1,
    int_emb_size=64,
    basis_emb_size_dist=8,
    basis_emb_size_angle=8,
    basis_emb_size_torsion=8,
    out_emb_channels=128,
    num_spherical=3,
    num_radial=6,
    envelope_exponent=5,
    num_before_skip=1,
    num_after_skip=2,
    num_output_layers=2,
)

In [6]:
# L1 loss, i.e. the mean absolute error between predictions and targets.
loss_func = torch.nn.L1Loss()

# Training driver imported from run_with_val.
run3d = run()

# Train on the training graphs, validating and testing after each epoch.
# NOTE(review): lr_decay_step_size (15) is larger than epochs (5), so presumably
# no learning-rate decay step is reached in this run — confirm against run_with_val.
train_loss_list, test_loss_list = run3d.run(
    device,
    train_data_godess,
    valid_data_list,
    test_data_list,
    model,
    loss_func,
    epochs=5,
    batch_size=4,
    vt_batch_size=4,
    lr=0.001,
    lr_decay_factor=0.5,
    lr_decay_step_size=15,
)

#Params: 1091910

=====Epoch 1

Training...


100%|█████████████████████████████████████████| 462/462 [00:32<00:00, 14.20it/s]



Evaluating...



100%|███████████████████████████████████████████| 58/58 [00:02<00:00, 23.82it/s]

use_tensor


Testing...



100%|███████████████████████████████████████████| 58/58 [00:02<00:00, 24.56it/s]

use_tensor

{'Train': 10.512698264349075, 'Valid': 5.5684648, 'Test': 5.3742557}

=====Epoch 2

Training...



100%|█████████████████████████████████████████| 462/462 [00:31<00:00, 14.53it/s]



Evaluating...



100%|███████████████████████████████████████████| 58/58 [00:02<00:00, 23.81it/s]

use_tensor


Testing...



100%|███████████████████████████████████████████| 58/58 [00:02<00:00, 24.59it/s]

use_tensor

{'Train': 2.4258238464206845, 'Valid': 4.0848293, 'Test': 3.9320168}

=====Epoch 3

Training...



100%|█████████████████████████████████████████| 462/462 [00:31<00:00, 14.50it/s]



Evaluating...



100%|███████████████████████████████████████████| 58/58 [00:02<00:00, 23.53it/s]

use_tensor


Testing...



100%|███████████████████████████████████████████| 58/58 [00:02<00:00, 24.04it/s]

use_tensor

{'Train': 2.085217180964235, 'Valid': 3.4811795, 'Test': 3.29}

=====Epoch 4

Training...



100%|█████████████████████████████████████████| 462/462 [00:32<00:00, 14.40it/s]



Evaluating...



100%|███████████████████████████████████████████| 58/58 [00:02<00:00, 23.73it/s]

use_tensor


Testing...



100%|███████████████████████████████████████████| 58/58 [00:02<00:00, 24.45it/s]

use_tensor

{'Train': 1.8979495300617053, 'Valid': 3.1246784, 'Test': 2.939757}

=====Epoch 5

Training...



100%|█████████████████████████████████████████| 462/462 [00:31<00:00, 14.44it/s]



Evaluating...



100%|███████████████████████████████████████████| 58/58 [00:02<00:00, 23.57it/s]

use_tensor


Testing...



100%|███████████████████████████████████████████| 58/58 [00:02<00:00, 24.40it/s]

use_tensor

{'Train': 1.8213000776189747, 'Valid': 2.44496, 'Test': 2.2442708}
Best validation RMSE so far: 2.444960117340088
Test MAE when got best validation result: 2.2442708015441895



