In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('/detanet-md/')
from detanet_model.detanet_pbc import *
from detanet_model.metrics import *

from torch_geometric.loader import DataLoader
from e3nn import o3
import os
from torch_geometric.data import Data
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

  impl_abstract(
  impl_abstract(


In [2]:
# Keyword arguments forwarded unchanged to the DetaNet constructor.
# The network is configured for a dipole output (out_type='dipole', irreps_out='1o').
detanet_config=dict(
    num_features=128,
    act='swish',
    maxl=3,
    num_block=3,
    radial_type='trainable_bessel',
    num_radial=128,
    attention_head=8,
    cutoff_lower=0.0,
    cutoff_upper=3.5,
    max_num_neighbors=120,
    strategy="brute",
    check_errors=True,
    box_vecs=None,
    dropout=0.0,
    use_cutoff=False,
    max_atomic_number=35,
    atom_ref=None,
    scale=1.0,
    scalar_outsize=1,
    irreps_out='1o',
    summation=True,
    norm=False,
    out_type='dipole',
    grad_type=None,
    device=torch.device('cuda'),
)
model=DetaNet(**detanet_config)


In [3]:
class Trainer:
    """Step-based training loop for a model fed with torch_geometric batches.

    The trainer owns the model, the data loaders, one optimizer and a global
    step counter (``self.step``, starting at -1 and incremented once per
    training batch).
    """

    def __init__(self,model,train_loader,val_loader=None,loss_function=l2loss,device=torch.device('cuda'),
                 optimizer='Adam_amsgrad',lr=5e-4,weight_decay=0):
        """Store the collaborators and build the optimizer selected by name.

        Args:
            model: network called as ``model(pos=..., z=..., batch=..., box=...)``.
            train_loader: iterable of training batches.
            val_loader: optional iterable of validation batches; ``None`` disables validation.
            loss_function: callable ``(prediction, target) -> scalar tensor``.
            device: device every batch tensor is moved to.
            optimizer: key of ``self.opts`` ('AdamW', 'AdamW_amsgrad', 'Adam',
                'Adam_amsgrad', 'Adadelta', 'RMSprop' or 'SGD').
            lr: learning rate passed to the optimizer.
            weight_decay: weight decay passed to the optimizer.

        Raises:
            KeyError: if ``optimizer`` is not one of the supported names.
        """
        self.opt_type=optimizer
        self.device=device
        self.model=model
        self.train_data=train_loader
        self.val_data=val_loader
        self.opts={'AdamW':torch.optim.AdamW(self.model.parameters(),lr=lr,amsgrad=False,weight_decay=weight_decay),
              'AdamW_amsgrad':torch.optim.AdamW(self.model.parameters(),lr=lr,amsgrad=True,weight_decay=weight_decay),
              'Adam':torch.optim.Adam(self.model.parameters(),lr=lr,amsgrad=False,weight_decay=weight_decay),
              'Adam_amsgrad':torch.optim.Adam(self.model.parameters(),lr=lr,amsgrad=True,weight_decay=weight_decay),
              'Adadelta':torch.optim.Adadelta(self.model.parameters(),lr=lr,weight_decay=weight_decay),
              'RMSprop':torch.optim.RMSprop(self.model.parameters(),lr=lr,weight_decay=weight_decay),
              'SGD':torch.optim.SGD(self.model.parameters(),lr=lr,weight_decay=weight_decay)
        }
        self.optimizer=self.opts[self.opt_type]
        self.loss_function=loss_function
        self.step=-1

    def _forward(self,batch):
        """Move one batch to ``self.device`` and run the model on it."""
        # as_tensor avoids the copy-construct warning torch.tensor() emits for tensor input.
        return self.model(pos=torch.as_tensor(batch.pos,device=self.device,dtype=torch.float32),
                          z=batch.z.to(self.device),
                          batch=batch.batch.to(self.device),
                          box=batch.box.to(self.device))

    def train(self,num_epoch,targ,stop_loss=1e-8,element=None,q9=None,loss_area=(100000,1e-8,1e8),
              val_per_train=10,view_data=False,print_per_epoch=10,data_scale=1,
              checkpoint_prefix=None,checkpoint_every=100):
        """Run ``num_epoch`` passes over the training loader.

        Args:
            num_epoch: number of passes over ``self.train_data``.
            targ: name of the batch attribute holding the regression target.
            stop_loss: training aborts when neither the training loss nor the
                validation loss exceeds this value at a validation step.
            element: unused; kept for backward compatibility.
            q9: unused; kept for backward compatibility.
            loss_area: ``(warmup_steps, low, high)``. Before ``warmup_steps`` every
                batch is back-propagated; from then on only batches whose loss
                lies strictly between ``low`` and ``high`` are.
            val_per_train: validate on one validation batch every this many steps.
            view_data: also print the first predicted / target validation value.
            print_per_epoch: print progress every this many steps.
            data_scale: factor applied to training and validation targets.
            checkpoint_prefix: prefix of the periodic checkpoint files
                (``prefix + str(step) + '.pth'``). When ``None`` the module-level
                variable ``path`` is used if it exists; otherwise no periodic
                checkpoint is written.
            checkpoint_every: write a checkpoint every this many steps.

        Raises:
            AssertionError: when the stop criterion described for ``stop_loss`` is met.
        """
        self.model.train()
        len_train=len(self.train_data)
        total_steps=num_epoch*len_train
        if checkpoint_prefix is None:
            # Backward compatibility: earlier versions read the global ``path`` directly.
            checkpoint_prefix=globals().get('path')
        warmup_steps,loss_floor,loss_ceiling=loss_area[0],loss_area[1],loss_area[2]
        has_val=self.val_data is not None
        for _ in range(num_epoch):
            val_batches=iter(self.val_data) if has_val else None
            for batch in self.train_data:
                self.step=self.step+1
                torch.cuda.empty_cache()
                self.optimizer.zero_grad()
                out=self._forward(batch)
                target=batch[targ].to(self.device)*data_scale
                loss=self.loss_function(out.reshape(target.shape),target)
                # Outlier batches after the warm-up are skipped entirely, so the
                # optimizer is only stepped when gradients were actually computed.
                if self.step<warmup_steps or loss_floor<loss.item()<loss_ceiling:
                    loss.backward()
                    self.optimizer.step()
                if has_val and self.step%val_per_train==0:
                    try:
                        val_batch=next(val_batches)
                    except StopIteration:
                        # The validation loader ran out within this epoch: start over.
                        val_batches=iter(self.val_data)
                        val_batch=next(val_batches)
                    val_target=val_batch[targ].to(self.device).reshape(-1)*data_scale
                    # NOTE(review): gradients are deliberately not disabled here in case the
                    # model derives outputs through autograd (cf. ``grad_type``) - confirm.
                    self.model.eval()
                    try:
                        val_out=self._forward(val_batch).reshape(val_target.shape)
                    finally:
                        self.model.train()
                    val_loss=self.loss_function(val_out,val_target).item()
                    val_mae=l1loss(val_out,val_target).item()
                    val_R2=R2(val_out,val_target).item()
                    if self.step%print_per_epoch==0:
                        print('Epoch[{}/{}],loss:{:.8f},val_loss:{:.8f},val_mae:{:.8f},val_R2:{:.8f}'
                              .format(self.step,total_steps,loss.item(),val_loss,val_mae,val_R2))
                        if view_data:
                            print('valout:{:.8f},valtarget:{:.8f}'.format(val_out.flatten()[0].item()
                                                                                   , val_target.flatten()[0].item()))
                    # Raised explicitly so the stop criterion also works under ``python -O``.
                    if not ((loss.item()>stop_loss) or (val_loss>stop_loss)):
                        raise AssertionError('Training and prediction Loss is less'
                                             ' than cut-off Loss, so training stops')
                elif self.step%print_per_epoch==0:
                    print('Epoch[{}/{}],loss:{:.8f}'.format(self.step,total_steps, loss.item()))
                if checkpoint_prefix is not None and self.step%checkpoint_every==0:
                    self.save_param(checkpoint_prefix+str(self.step)+'.pth')

    def load_state_and_optimizer(self,state_path=None,optimizer_path=None):
        """Restore model weights and/or optimizer state from disk.

        Args:
            state_path: file written by ``save_param``; skipped when ``None``.
            optimizer_path: file written by ``save_opt`` (a pickled optimizer) or a
                plain optimizer ``state_dict``; skipped when ``None``.
        """
        if state_path is not None:
            state_dict=torch.load(state_path,map_location=self.device,weights_only=True)
            self.model.load_state_dict(state_dict)
        if optimizer_path is not None:
            # SECURITY: a pickled optimizer needs the full unpickler - trusted files only.
            loaded=torch.load(optimizer_path,map_location=self.device,weights_only=False)
            if isinstance(loaded,torch.optim.Optimizer):
                loaded=loaded.state_dict()
            # Load into the existing optimizer so it keeps updating this model's
            # parameters rather than the parameter copies stored inside the pickle.
            self.optimizer.load_state_dict(loaded)

    def save_param(self,path):
        """Write the model ``state_dict`` to ``path``."""
        torch.save(self.model.state_dict(),path)

    def save_model(self,path):
        """Pickle the whole model object to ``path``."""
        torch.save(self.model, path)

    def save_opt(self,path):
        """Pickle the whole optimizer object to ``path``."""
        torch.save(self.optimizer,path)

    def params(self):
        """Return the model ``state_dict``."""
        return self.model.state_dict()


In [5]:
# The file holds pickled torch_geometric ``Data`` objects, which the restricted
# ``weights_only`` unpickler rejects; request the full unpickler explicitly and
# only load files from a trusted source.
datasets=torch.load('/H9C8NO2.pt',weights_only=False)

  datasets1=torch.load('/home/huwei22/jishengjiao/detanet/cp2k_data/paracetamol_dipole/H9C8NO2_force_energy_dipole_35495_warp.pt')


Data(pos=[80, 3], z=[80], force=[80, 3], atomization_energy=-312.33984375, dipole=[1, 3], box=[1, 3, 3])
20000


In [6]:
# Attribute of the raw samples that is copied into the generic ``targ`` field.
targ='dipole'
# Every HOLDOUT_STRIDE-th sample is held out; all other samples are used for training.
HOLDOUT_STRIDE=10

train_datasets=[]
val_datasets1=[]
for i in range(len(datasets)):
    data=datasets[i]
    # Keep only the fields the model and the loss need.
    data_=Data(pos=data.pos,z=data.z,box=data.box,targ=data[targ])
    if i%HOLDOUT_STRIDE==0:
        val_datasets1.append(data_)
    else:
        train_datasets.append(data_)
print(len(train_datasets))
print(train_datasets[0])
print(len(val_datasets1))
print(val_datasets1[0])

18000
Data(pos=[80, 3], z=[80], box=[1, 3, 3], targ=[1, 3])
2000
Data(pos=[80, 3], z=[80], box=[1, 3, 3], targ=[1, 3])


In [7]:
# Split the hold-out list by position: even indices -> validation, odd indices -> test.
val_datasets=val_datasets1[0::2]
test_datasets=val_datasets1[1::2]
print(len(val_datasets))
print(len(test_datasets))

1000
1000


In [8]:
# Mini-batch size shared by both loaders.
bathes=8
# Both loaders reshuffle their samples on every pass.
loader_options=dict(batch_size=bathes,shuffle=True)
trainloader=DataLoader(train_datasets,**loader_options)
valloader=DataLoader(val_datasets,**loader_options)

In [9]:
# Cast the model to float32 and move it to the GPU in one call.
device=torch.device('cuda')
dtype=torch.float32
model=model.to(device=device,dtype=dtype)

In [10]:
# Checkpoint file prefix. Trainer.train reads this module-level name and writes
# path + str(step) + '.pth' every 100 steps.
path = '/dipole/dipole'

In [11]:
# Stage 1: AdamW at lr=1e-3.
trainer=Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    optimizer='AdamW',
    lr=1e-3,
    weight_decay=0,
)

In [12]:
# Validate every 50 steps, print every 10. After step 100 a batch is only
# back-propagated when its loss lies between 1e-8 and 1000.
trainer.train(
    num_epoch=10,
    targ='targ',
    stop_loss=0,
    loss_area=[100,1e-8,1000],
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
)

Epoch[0/22500],loss:15.64441967,val_loss:10.98642921,val_mae:2.68701172,val_R2:-3.44252872
Epoch[10/22500],loss:4.78204536
Epoch[20/22500],loss:2.33159089
Epoch[30/22500],loss:3.00386381
Epoch[40/22500],loss:1.77478731
Epoch[50/22500],loss:1.47266960,val_loss:3.32507467,val_mae:1.52013493,val_R2:-0.09204590
Epoch[60/22500],loss:1.12271309
Epoch[70/22500],loss:1.12787187
Epoch[80/22500],loss:1.78789651
Epoch[90/22500],loss:1.40306127
Epoch[100/22500],loss:1.80816364,val_loss:0.96379203,val_mae:0.79195583,val_R2:0.75518978
Epoch[110/22500],loss:1.49523580
Epoch[120/22500],loss:1.20040131
Epoch[130/22500],loss:1.04037642
Epoch[140/22500],loss:1.96573114
Epoch[150/22500],loss:0.48985982,val_loss:0.94197750,val_mae:0.81768340,val_R2:0.45332527
Epoch[160/22500],loss:0.46416086
Epoch[170/22500],loss:0.98008966
Epoch[180/22500],loss:1.15993071
Epoch[190/22500],loss:1.23218083
Epoch[200/22500],loss:0.88605654,val_loss:1.04364586,val_mae:0.82574815,val_R2:0.78907239
Epoch[210/22500],loss:0.59421

In [14]:
# Persist the stage-1 weights (state_dict only).
trainer.save_param(path='/dipole/dipole1.pth')

In [15]:
# Reload the stage-1 weights. A plain state_dict can be read with the restricted
# unpickler, and map_location keeps the checkpoint usable on another device.
model.load_state_dict(torch.load('/dipole/dipole1.pth',map_location=device,weights_only=True))

  model.load_state_dict(torch.load('/home/huwei22/jishengjiao/detanet/trained_param_paracetamol/dipole_20000/dipole1.pth'))


<All keys matched successfully>

In [17]:
# Stage 2: new AdamW optimizer at lr=4e-4 (the step counter restarts with the new Trainer).
trainer=Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    optimizer='AdamW',
    lr=4e-4,
    weight_decay=0,
)

In [19]:
# Same schedule as before, with the upper loss bound for back-propagation lowered to 100.
trainer.train(
    num_epoch=10,
    targ='targ',
    stop_loss=0,
    loss_area=[100,1e-8,100],
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
)

Epoch[700/22500],loss:0.00336081,val_loss:0.00701603,val_mae:0.06320903,val_R2:0.99669975
Epoch[710/22500],loss:0.00493425
Epoch[720/22500],loss:0.00512856
Epoch[730/22500],loss:0.00785509
Epoch[740/22500],loss:0.00450343
Epoch[750/22500],loss:0.00521910,val_loss:0.00918136,val_mae:0.06559668,val_R2:0.99765962
Epoch[760/22500],loss:0.00503974
Epoch[770/22500],loss:0.00614424
Epoch[780/22500],loss:0.00726833
Epoch[790/22500],loss:0.00755586
Epoch[800/22500],loss:0.00488436,val_loss:0.01038789,val_mae:0.08364391,val_R2:0.99576449
Epoch[810/22500],loss:0.00505720
Epoch[820/22500],loss:0.00257850
Epoch[830/22500],loss:0.00830008
Epoch[840/22500],loss:0.00359837
Epoch[850/22500],loss:0.00499033,val_loss:0.00543663,val_mae:0.05109635,val_R2:0.99765176
Epoch[860/22500],loss:0.00770784
Epoch[870/22500],loss:0.00449685
Epoch[880/22500],loss:0.00454321
Epoch[890/22500],loss:0.00661914
Epoch[900/22500],loss:0.00483799,val_loss:0.01703405,val_mae:0.09123532,val_R2:0.99044073
Epoch[910/22500],loss:

In [20]:
# Persist the stage-2 weights (state_dict only).
trainer.save_param(path='/dipole/dipole2.pth')

In [21]:
# Reload the stage-2 weights. A plain state_dict can be read with the restricted
# unpickler, and map_location keeps the checkpoint usable on another device.
model.load_state_dict(torch.load('/dipole/dipole2.pth',map_location=device,weights_only=True))

  model.load_state_dict(torch.load('/home/huwei22/jishengjiao/detanet/trained_param_paracetamol/dipole_20000/dipole2.pth'))


<All keys matched successfully>

In [23]:
# Stage 3: new AdamW optimizer at lr=2e-4.
trainer=Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    optimizer='AdamW',
    lr=2e-4,
    weight_decay=0,
)

In [24]:
# Upper loss bound for back-propagation after step 100 lowered to 10.
trainer.train(
    num_epoch=10,
    targ='targ',
    stop_loss=0,
    loss_area=[100,1e-8,10],
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
)

Epoch[0/22500],loss:0.00370792,val_loss:0.00184948,val_mae:0.03608812,val_R2:0.99930936
Epoch[10/22500],loss:0.00181620
Epoch[20/22500],loss:0.00183627
Epoch[30/22500],loss:0.00602491
Epoch[40/22500],loss:0.00161142
Epoch[50/22500],loss:0.00059322,val_loss:0.00521785,val_mae:0.03860682,val_R2:0.99862784
Epoch[60/22500],loss:0.00134095
Epoch[70/22500],loss:0.00145398
Epoch[80/22500],loss:0.00221953
Epoch[90/22500],loss:0.00198979
Epoch[100/22500],loss:0.00157993,val_loss:0.00172121,val_mae:0.03387316,val_R2:0.99950314
Epoch[110/22500],loss:0.00175256
Epoch[120/22500],loss:0.00158701
Epoch[130/22500],loss:0.00105410
Epoch[140/22500],loss:0.00116432
Epoch[150/22500],loss:0.00211244,val_loss:0.00265290,val_mae:0.03920975,val_R2:0.99904996
Epoch[160/22500],loss:0.00132579
Epoch[170/22500],loss:0.00358164
Epoch[180/22500],loss:0.00160592
Epoch[190/22500],loss:0.00115087
Epoch[200/22500],loss:0.00388620,val_loss:0.00115360,val_mae:0.02821822,val_R2:0.99963868
Epoch[210/22500],loss:0.00138606


In [25]:
# Persist the stage-3 weights (state_dict only).
trainer.save_param(path='/dipole/dipole3.pth')

In [26]:
# Reload the stage-3 weights. A plain state_dict can be read with the restricted
# unpickler, and map_location keeps the checkpoint usable on another device.
model.load_state_dict(torch.load('/dipole/dipole3.pth',map_location=device,weights_only=True))

  model.load_state_dict(torch.load('/home/huwei22/jishengjiao/detanet/trained_param_paracetamol/dipole_20000/dipole3.pth'))


<All keys matched successfully>

In [28]:
# Stage 4: new AdamW optimizer at lr=1e-4.
trainer=Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    optimizer='AdamW',
    lr=1e-4,
    weight_decay=0,
)

In [29]:
# Same settings as the previous stage: batches with a loss of 10 or more are skipped after step 100.
trainer.train(
    num_epoch=10,
    targ='targ',
    stop_loss=0,
    loss_area=[100,1e-8,10],
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
)

Epoch[0/22500],loss:0.00095683,val_loss:0.00181543,val_mae:0.03372373,val_R2:0.99928153
Epoch[10/22500],loss:0.00059237
Epoch[20/22500],loss:0.00036598
Epoch[30/22500],loss:0.00111374
Epoch[40/22500],loss:0.00046543
Epoch[50/22500],loss:0.00094576,val_loss:0.00087917,val_mae:0.02445080,val_R2:0.99981588
Epoch[60/22500],loss:0.00077822
Epoch[70/22500],loss:0.00125163
Epoch[80/22500],loss:0.00052941
Epoch[90/22500],loss:0.00139898
Epoch[100/22500],loss:0.00023284,val_loss:0.00184260,val_mae:0.03427681,val_R2:0.99927759
Epoch[110/22500],loss:0.00026100
Epoch[120/22500],loss:0.00048493
Epoch[130/22500],loss:0.00044712
Epoch[140/22500],loss:0.00141518
Epoch[150/22500],loss:0.00042165,val_loss:0.00080209,val_mae:0.02240311,val_R2:0.99964958
Epoch[160/22500],loss:0.00039338
Epoch[170/22500],loss:0.00017220
Epoch[180/22500],loss:0.00048187
Epoch[190/22500],loss:0.00043696
Epoch[200/22500],loss:0.00066158,val_loss:0.00042352,val_mae:0.01756713,val_R2:0.99983722
Epoch[210/22500],loss:0.00136176


In [30]:
# Persist the stage-4 weights (state_dict only).
trainer.save_param(path='/dipole/dipole4.pth')

In [31]:
# Reload the stage-4 weights. A plain state_dict can be read with the restricted
# unpickler, and map_location keeps the checkpoint usable on another device.
model.load_state_dict(torch.load('/dipole/dipole4.pth',map_location=device,weights_only=True))

  model.load_state_dict(torch.load('/home/huwei22/jishengjiao/detanet/trained_param_paracetamol/dipole_20000/dipole4.pth'))


<All keys matched successfully>

In [33]:
# Stage 5: new AdamW optimizer at lr=4e-5.
trainer=Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    optimizer='AdamW',
    lr=4e-5,
    weight_decay=0,
)

In [34]:
# Upper loss bound for back-propagation after step 100 lowered to 1.
trainer.train(
    num_epoch=10,
    targ='targ',
    stop_loss=0,
    loss_area=[100,1e-8,1],
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
)

Epoch[0/22500],loss:0.00036650,val_loss:0.00054917,val_mae:0.01765863,val_R2:0.99977326
Epoch[10/22500],loss:0.00031945
Epoch[20/22500],loss:0.00013046
Epoch[30/22500],loss:0.00023408
Epoch[40/22500],loss:0.00075830
Epoch[50/22500],loss:0.00055685,val_loss:0.00139193,val_mae:0.01953620,val_R2:0.99970049
Epoch[60/22500],loss:0.00011359
Epoch[70/22500],loss:0.00031529
Epoch[80/22500],loss:0.00019651
Epoch[90/22500],loss:0.00015736
Epoch[100/22500],loss:0.00028800,val_loss:0.00014351,val_mae:0.00999182,val_R2:0.99996883
Epoch[110/22500],loss:0.00052051
Epoch[120/22500],loss:0.00020176
Epoch[130/22500],loss:0.00040226
Epoch[140/22500],loss:0.00017248
Epoch[150/22500],loss:0.00016672,val_loss:0.00020356,val_mae:0.01129131,val_R2:0.99985456
Epoch[160/22500],loss:0.00009404
Epoch[170/22500],loss:0.00029262
Epoch[180/22500],loss:0.00032313
Epoch[190/22500],loss:0.00014890
Epoch[200/22500],loss:0.00019236,val_loss:0.00452620,val_mae:0.02685896,val_R2:0.99876183
Epoch[210/22500],loss:0.00287776


In [35]:
# Persist the stage-5 weights (state_dict only).
trainer.save_param(path='/dipole/dipole5.pth')

In [36]:
# Reload the stage-5 weights. A plain state_dict can be read with the restricted
# unpickler, and map_location keeps the checkpoint usable on another device.
model.load_state_dict(torch.load('/dipole/dipole5.pth',map_location=device,weights_only=True))

  model.load_state_dict(torch.load('/home/huwei22/jishengjiao/detanet/trained_param_paracetamol/dipole_20000/dipole5.pth'))


<All keys matched successfully>

In [38]:
# Stage 6: new AdamW optimizer at lr=2e-5.
trainer=Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    optimizer='AdamW',
    lr=2e-5,
    weight_decay=0,
)

In [39]:
# Same settings as the previous stage: batches with a loss of 1 or more are skipped after step 100.
trainer.train(
    num_epoch=10,
    targ='targ',
    stop_loss=0,
    loss_area=[100,1e-8,1],
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
)

Epoch[0/22500],loss:0.00012649,val_loss:0.00029647,val_mae:0.01315670,val_R2:0.99984443
Epoch[10/22500],loss:0.00143480
Epoch[20/22500],loss:0.00017048
Epoch[30/22500],loss:0.00033764
Epoch[40/22500],loss:0.00014494
Epoch[50/22500],loss:0.00013119,val_loss:0.00020340,val_mae:0.01174533,val_R2:0.99993914
Epoch[60/22500],loss:0.00009977
Epoch[70/22500],loss:0.00021830
Epoch[80/22500],loss:0.00122221
Epoch[90/22500],loss:0.00063884
Epoch[100/22500],loss:0.00100717,val_loss:0.00165697,val_mae:0.02590880,val_R2:0.99937671
Epoch[110/22500],loss:0.00005915
Epoch[120/22500],loss:0.00023869
Epoch[130/22500],loss:0.00017366
Epoch[140/22500],loss:0.00013925
Epoch[150/22500],loss:0.00022882,val_loss:0.00024047,val_mae:0.01343922,val_R2:0.99990386
Epoch[160/22500],loss:0.00012094
Epoch[170/22500],loss:0.00017518
Epoch[180/22500],loss:0.00007539
Epoch[190/22500],loss:0.00011209
Epoch[200/22500],loss:0.00016209,val_loss:0.00019050,val_mae:0.01102495,val_R2:0.99993253
Epoch[210/22500],loss:0.00010630


In [40]:
# Persist the stage-6 weights (state_dict only).
trainer.save_param(path='/dipole/dipole6.pth')

In [41]:
# Reload the stage-6 weights. A plain state_dict can be read with the restricted
# unpickler, and map_location keeps the checkpoint usable on another device.
# NOTE(review): the output recorded for this cell echoes a load of dipole1.pth, and
# the next training run starts at a loss of about 1e-2 instead of about 1e-4, so the
# saved outputs of the final stage appear to come from a different checkpoint than
# the one named here - re-run to confirm.
model.load_state_dict(torch.load('/dipole/dipole6.pth',map_location=device,weights_only=True))

  model.load_state_dict(torch.load('/home/huwei22/jishengjiao/detanet/trained_param_paracetamol/dipole_20000/dipole1.pth'))


<All keys matched successfully>

In [43]:
# Stage 7: new AdamW optimizer at lr=1e-5.
trainer=Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    optimizer='AdamW',
    lr=1e-5,
    weight_decay=0,
)

In [44]:
# Same settings as the previous stage: batches with a loss of 1 or more are skipped after step 100.
trainer.train(
    num_epoch=10,
    targ='targ',
    stop_loss=0,
    loss_area=[100,1e-8,1],
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
)

Epoch[0/22500],loss:0.01041024,val_loss:0.02179582,val_mae:0.11522888,val_R2:0.99227202
Epoch[10/22500],loss:0.01248213
Epoch[20/22500],loss:0.01658234
Epoch[30/22500],loss:0.00874178
Epoch[40/22500],loss:0.01009312
Epoch[50/22500],loss:0.01089474,val_loss:0.02238088,val_mae:0.11418888,val_R2:0.99573141
Epoch[60/22500],loss:0.00828633
Epoch[70/22500],loss:0.02055488
Epoch[80/22500],loss:0.00699626
Epoch[90/22500],loss:0.01170952
Epoch[100/22500],loss:0.00999121,val_loss:0.00984696,val_mae:0.07984612,val_R2:0.99631757
Epoch[110/22500],loss:0.01077893
Epoch[120/22500],loss:0.01004809
Epoch[130/22500],loss:0.01401439
Epoch[140/22500],loss:0.00943018
Epoch[150/22500],loss:0.01813119,val_loss:0.01512287,val_mae:0.08992314,val_R2:0.99508071
Epoch[160/22500],loss:0.01533284
Epoch[170/22500],loss:0.00875204
Epoch[180/22500],loss:0.00606777
Epoch[190/22500],loss:0.01513286
Epoch[200/22500],loss:0.00934399,val_loss:0.01870086,val_mae:0.10633253,val_R2:0.99395192
Epoch[210/22500],loss:0.00714893


In [45]:
# Persist the stage-7 weights (state_dict only).
trainer.save_param(path='/dipole/dipole7.pth')