In [2]:
from dig.threedgraph.dataset import QM93D
from dig.threedgraph.method import SphereNet, ComENet, DimeNetPP, ProNet, SchNet
from dig.threedgraph.evaluation import ThreeDEvaluator
from dig.threedgraph.method import run
from torch_geometric.data import Data
from torch_geometric.data import DataLoader
from tqdm import tqdm
import torch
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data
from rdkit import Chem
from rdkit.Chem import AllChem
from torch_sparse import SparseTensor
import utils
import tts

name_of_model = 'shere_100_df'

device = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")

dataset = utils.clean_dataset(pd.read_csv("dataset/md_dataset_full_v3.csv"))

X_train, X_val, X_test, y_train, y_val, y_test = tts.create_tts_from_df(dataset)

model = SphereNet(energy_and_force=True, cutoff=5.0, num_layers=4,
                  hidden_channels=128, out_channels=1, int_emb_size=64,
                  basis_emb_size_dist=8, basis_emb_size_angle=8, basis_emb_size_torsion=8, out_emb_channels=256,
                  num_spherical=3, num_radial=6, envelope_exponent=5,
                  num_before_skip=1, num_after_skip=2, num_output_layers=3)
loss_func = torch.nn.L1Loss()
evaluation = ThreeDEvaluator()

run3d = run()
run3d.run(device, X_train, X_val, X_test, model, loss_func, evaluation,
          epochs=100, batch_size=32, vt_batch_size=32, lr=0.0005, lr_decay_factor=0.5, lr_decay_step_size=15,
          save_dir=f"models/{name_of_model}", log_dir=f"logs/{name_of_model}")

Shape of old dataset: (34926, 56)
Processing...


100%|██████████| 56/56 [00:02<00:00, 23.72it/s]


Shape of new dataset: (34924, 56)


100%|██████████| 34924/34924 [00:11<00:00, 3148.03it/s]


#Params: 1890118

=====Epoch 1

Training...


Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at ../aten/src/ATen/native/Cross.cpp:62.)
  b = torch.cross(pos_ji, pos_jk).norm(dim=-1) # sin_angle * |pos_ji| * |pos_jk|
100%|██████████| 874/874 [04:37<00:00,  3.15it/s]



Evaluating...



100%|██████████| 175/175 [00:12<00:00, 14.50it/s]



Testing...



100%|██████████| 44/44 [00:03<00:00, 14.46it/s]



{'Train': 1.4125659120369285, 'Validation': 1.1079771518707275, 'Test': 1.1373628377914429}
Saving checkpoint...

=====Epoch 2

Training...


100%|██████████| 874/874 [04:01<00:00,  3.61it/s]



Evaluating...



100%|██████████| 175/175 [00:12<00:00, 14.27it/s]



Testing...



100%|██████████| 44/44 [00:03<00:00, 14.45it/s]



{'Train': 0.6938622261416721, 'Validation': 0.8240274786949158, 'Test': 0.8422085046768188}
Saving checkpoint...

=====Epoch 3

Training...


100%|██████████| 874/874 [04:02<00:00,  3.61it/s]



Evaluating...



100%|██████████| 175/175 [00:12<00:00, 14.51it/s]



Testing...



100%|██████████| 44/44 [00:03<00:00, 14.35it/s]


{'Train': 0.626109454533601, 'Validation': 1.227743148803711, 'Test': 1.221901297569275}

=====Epoch 4

Training...



100%|██████████| 874/874 [04:02<00:00,  3.60it/s]



Evaluating...



100%|██████████| 175/175 [00:12<00:00, 14.30it/s]



Testing...



100%|██████████| 44/44 [00:03<00:00, 13.93it/s]



{'Train': 0.5805333257403472, 'Validation': 0.7970637679100037, 'Test': 0.7865322232246399}
Saving checkpoint...

=====Epoch 5

Training...


100%|██████████| 874/874 [04:02<00:00,  3.60it/s]



Evaluating...



100%|██████████| 175/175 [00:12<00:00, 14.48it/s]



Testing...



100%|██████████| 44/44 [00:03<00:00, 14.35it/s]


{'Train': 0.5532679745002524, 'Validation': 0.859997034072876, 'Test': 0.876189112663269}

=====Epoch 6

Training...



100%|██████████| 874/874 [04:02<00:00,  3.60it/s]



Evaluating...



100%|██████████| 175/175 [00:12<00:00, 14.51it/s]



Testing...



100%|██████████| 44/44 [00:03<00:00, 14.58it/s]


{'Train': 0.5199756837818661, 'Validation': 0.9877182245254517, 'Test': 1.0003917217254639}

=====Epoch 7

Training...



100%|██████████| 874/874 [04:02<00:00,  3.60it/s]



Evaluating...



100%|██████████| 175/175 [00:12<00:00, 14.56it/s]



Testing...



100%|██████████| 44/44 [00:03<00:00, 14.61it/s]



{'Train': 0.5032599494713668, 'Validation': 0.48538732528686523, 'Test': 0.5098602771759033}
Saving checkpoint...

=====Epoch 8

Training...


100%|██████████| 874/874 [04:02<00:00,  3.61it/s]



Evaluating...



100%|██████████| 175/175 [00:12<00:00, 14.56it/s]



Testing...



100%|██████████| 44/44 [00:03<00:00, 14.60it/s]



{'Train': 0.4831688929204264, 'Validation': 0.4413105249404907, 'Test': 0.4707673192024231}
Saving checkpoint...

=====Epoch 9

Training...


100%|██████████| 874/874 [04:02<00:00,  3.61it/s]



Evaluating...



100%|██████████| 175/175 [00:12<00:00, 14.40it/s]



Testing...



100%|██████████| 44/44 [00:03<00:00, 14.23it/s]


{'Train': 0.46431559099212666, 'Validation': 0.4811950922012329, 'Test': 0.4884572923183441}

=====Epoch 10

Training...



100%|██████████| 874/874 [04:02<00:00,  3.60it/s]



Evaluating...



100%|██████████| 175/175 [00:12<00:00, 14.40it/s]



Testing...



100%|██████████| 44/44 [00:03<00:00, 14.53it/s]


{'Train': 0.4524567779196209, 'Validation': 0.45575737953186035, 'Test': 0.4859897792339325}

=====Epoch 11

Training...



 16%|█▌        | 136/874 [00:37<03:25,  3.59it/s]


RuntimeError: The expanded size of the tensor (1180) must match the existing size (1181) at non-singleton dimension 0.  Target sizes: [1180, 1].  Tensor sizes: [1181, 1]