In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
# Make the local DetaNet package importable. NOTE(review): absolute path — confirm it matches your checkout.
sys.path.append('/detanet-md/')
# Star imports: DetaNet and the loss/metric helpers used later (l2loss, l1loss, R2) are not
# defined anywhere in this notebook, so they presumably come from these two modules.
from detanet_model.detanet_pbc import *
from detanet_model.metrics import *

from torch_geometric.loader import DataLoader
from e3nn import o3
import os
from torch_geometric.data import Data
import warnings
# Hides every UserWarning for the rest of the kernel session (other warning categories still show).
warnings.filterwarnings("ignore", category=UserWarning)
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous — a debugging aid that
# slows training; it only takes effect if set before CUDA is initialised. Consider removing for real runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

  impl_abstract(
  impl_abstract(


In [2]:
# DetaNet force model: all hyper-parameters are passed explicitly so the configuration is
# fully visible here (out_type='scalar' with grad_type='force').
model = DetaNet(
    num_features=128,
    act='swish',
    maxl=3,
    num_block=3,
    radial_type='trainable_bessel',
    num_radial=128,
    attention_head=8,
    cutoff_lower=0.0,
    cutoff_upper=3.5,
    max_num_neighbors=120,
    strategy="brute",
    check_errors=True,
    box_vecs=None,
    dropout=0.0,
    use_cutoff=False,
    max_atomic_number=35,
    atom_ref=None,
    scale=1.0,
    scalar_outsize=1,
    irreps_out=None,
    summation=True,
    norm=False,
    out_type='scalar',
    grad_type='force',
    device=torch.device('cuda'),
)


In [4]:
class Trainer:
    """Mini-batch training loop for a DetaNet model on torch_geometric batches.

    Every batch must expose ``pos``, ``z``, ``batch`` and ``box`` plus the target
    attribute named by ``targ`` in :meth:`train`; the model is called as
    ``model(pos=..., z=..., batch=..., box=...)``.

    Args:
        model: network to optimise (assumed to already live on ``device``).
        train_loader: sized iterable of training batches.
        val_loader: optional iterable of validation batches; ``None`` disables validation.
        loss_function: ``(prediction, target) -> scalar tensor`` used for training and validation.
        device: device that batch tensors are moved to.
        optimizer: key of ``self.opts`` selecting the optimizer: 'AdamW', 'AdamW_amsgrad',
            'Adam', 'Adam_amsgrad', 'Adadelta', 'RMSprop' or 'SGD'.
        lr: learning rate.
        weight_decay: weight decay passed to the optimizer.

    Raises:
        KeyError: if ``optimizer`` is not one of the keys listed above.
    """

    def __init__(self, model, train_loader, val_loader=None, loss_function=l2loss, device=torch.device('cuda'),
                 optimizer='Adam_amsgrad', lr=5e-4, weight_decay=0):
        self.opt_type = optimizer
        self.device = device
        self.model = model
        self.train_data = train_loader
        self.val_data = val_loader
        # Every optimizer is constructed and kept in self.opts (existing attribute); only the
        # one selected by ``optimizer`` is actually used.
        self.opts = {
            'AdamW': torch.optim.AdamW(self.model.parameters(), lr=lr, amsgrad=False, weight_decay=weight_decay),
            'AdamW_amsgrad': torch.optim.AdamW(self.model.parameters(), lr=lr, amsgrad=True,
                                               weight_decay=weight_decay),
            'Adam': torch.optim.Adam(self.model.parameters(), lr=lr, amsgrad=False, weight_decay=weight_decay),
            'Adam_amsgrad': torch.optim.Adam(self.model.parameters(), lr=lr, amsgrad=True,
                                             weight_decay=weight_decay),
            'Adadelta': torch.optim.Adadelta(self.model.parameters(), lr=lr, weight_decay=weight_decay),
            'RMSprop': torch.optim.RMSprop(self.model.parameters(), lr=lr, weight_decay=weight_decay),
            'SGD': torch.optim.SGD(self.model.parameters(), lr=lr, weight_decay=weight_decay),
        }
        self.optimizer = self.opts[self.opt_type]
        self.loss_function = loss_function
        # Global step counter across calls to train(); -1 so the first processed batch is step 0.
        self.step = -1

    def _as_pos(self, pos):
        """Return a fresh, graph-detached float32 copy of ``pos`` on ``self.device``.

        Same result as ``torch.tensor(pos, device=..., dtype=torch.float32)`` but without
        copy-constructing a tensor from a tensor.
        """
        return torch.as_tensor(pos).detach().to(device=self.device, dtype=torch.float32, copy=True)

    def _next_val_batch(self, val_iter):
        """Return ``(iterator, batch)``, restarting ``self.val_data`` once it is exhausted.

        Raises StopIteration only if the validation loader yields no batches at all.
        """
        try:
            return val_iter, next(val_iter)
        except StopIteration:
            val_iter = iter(self.val_data)
            return val_iter, next(val_iter)

    def train(self, num_epoch, targ, stop_loss=1e-8, element=None, q9=None, loss_area=(100000, 1e-8, 1e8),
              val_per_train=10, view_data=False, print_per_epoch=10, data_scale=1,
              log_file=None, save_prefix=None, save_per_step=500):
        """Train for ``num_epoch`` passes over the training loader.

        Args:
            num_epoch: number of passes over the training loader.
            targ: name of the batch attribute holding the target (e.g. ``'targ'``).
            stop_loss: training aborts with AssertionError once the current training loss
                and the latest validation loss are both <= this value.
            element, q9: unused; kept for call compatibility.
            loss_area: ``(warmup_steps, low, high)``. While the global step is below
                ``warmup_steps`` every batch updates the model; afterwards a batch only
                updates the model when ``low < loss < high``.
            val_per_train: evaluate one validation batch every this many global steps.
            view_data: also print the first validation prediction/target pair.
            print_per_epoch: print progress every this many global steps (the counter is
                steps, not epochs, despite the name).
            data_scale: factor applied to targets before computing losses.
            log_file: optional writable text file receiving tab-separated
                ``step, val_loss, val_mae, val_R2`` lines. When ``None``, a notebook-level
                global named ``log_file`` is used if one exists (previous behaviour).
            save_prefix: checkpoint prefix; parameters are saved to
                ``save_prefix + str(step) + '.pth'``. When ``None``, the notebook-level global
                ``path`` is used if it exists (previous behaviour); otherwise no checkpoints.
            save_per_step: checkpoint every this many global steps; 0 disables checkpoints.

        Raises:
            AssertionError: when both losses have fallen to ``stop_loss`` or below.
        """
        self.model.train()
        len_train = len(self.train_data)
        total_steps = num_epoch * len_train
        # Backward compatibility: earlier versions read these two names as implicit globals.
        if log_file is None:
            log_file = globals().get('log_file')
        if save_prefix is None:
            save_prefix = globals().get('path')
        if save_prefix is None and save_per_step:
            print('No checkpoint prefix available; periodic checkpoints are disabled.')

        val_loss = float('inf')  # no validation yet, so only the training loss can trigger stop_loss
        val_iter = None
        for _ in range(num_epoch):
            if self.val_data is not None:
                val_iter = iter(self.val_data)  # validation batches restart every epoch
            for batch in self.train_data:
                self.step += 1
                torch.cuda.empty_cache()
                self.optimizer.zero_grad()
                do_val = val_iter is not None and self.step % val_per_train == 0
                try:
                    out = self.model(pos=self._as_pos(batch.pos),
                                     z=batch.z.to(self.device),
                                     batch=batch.batch.to(self.device),
                                     box=batch.box.to(self.device))
                    target = batch[targ].to(self.device) * data_scale
                    loss = self.loss_function(out.reshape(target.shape), target)
                    # Warm-up: always learn; afterwards skip out-of-range (outlier) batches.
                    # The optimizer only steps when gradients were actually computed.
                    if self.step < loss_area[0] or loss_area[1] < loss.item() < loss_area[2]:
                        loss.backward()
                        self.optimizer.step()

                    if do_val:
                        val_iter, val_batch = self._next_val_batch(val_iter)
                        val_target = val_batch[targ].to(self.device).reshape(-1) * data_scale
                        # Not wrapped in torch.no_grad(): the model may rely on autograd internally
                        # (e.g. forces as gradients w.r.t. pos) — TODO confirm before disabling grad.
                        # detach() releases the graph as soon as the metrics are computed.
                        val_out = self.model(pos=self._as_pos(val_batch.pos),
                                             z=val_batch.z.to(self.device),
                                             batch=val_batch.batch.to(self.device),
                                             box=val_batch.box.to(self.device),
                                             ).reshape(val_target.shape).detach()
                        val_loss = self.loss_function(val_out, val_target).item()
                        val_mae = l1loss(val_out, val_target).item()
                        val_R2 = R2(val_out, val_target).item()
                        if log_file is not None:
                            log_file.write(f"{self.step}\t{val_loss:.8f}\t{val_mae:.8f}\t{val_R2:.8f}\n")

                        if self.step % print_per_epoch == 0:
                            print('Epoch[{}/{}],loss:{:.8f},val_loss:{:.8f},val_mae:{:.8f},val_R2:{:.8f}'
                                  .format(self.step, total_steps, loss.item(), val_loss, val_mae, val_R2))
                            if view_data:
                                print('valout:{:.8f},valtarget:{:.8f}'.format(val_out.flatten()[0].item(),
                                                                               val_target.flatten()[0].item()))
                except RuntimeError as e:
                    # e.g. CUDA out-of-memory on an unusually large batch: report and move on.
                    print(f"Encountered RuntimeError: {e}. Skipping current batch.")
                    continue
                except Exception as e:
                    print(f"Encountered unexpected error: {e}. Skipping current batch.")
                    continue

                assert (loss > stop_loss) or (val_loss > stop_loss), 'Training and prediction Loss is less' \
                                                                     ' than cut-off Loss, so training stops'
                # Steps with validation already printed a combined line above.
                if self.step % print_per_epoch == 0 and not do_val:
                    print('Epoch[{}/{}],loss:{:.8f}'.format(self.step, total_steps, loss.item()))
                if save_prefix is not None and save_per_step and self.step % save_per_step == 0:
                    self.save_param(save_prefix + str(self.step) + '.pth')

    def load_state_and_optimizer(self, state_path=None, optimizer_path=None):
        """Restore model parameters and/or optimizer state from disk.

        Args:
            state_path: file written by :meth:`save_param` (a model ``state_dict``).
            optimizer_path: file written by :meth:`save_opt`. Both the current format (an
                optimizer ``state_dict``) and legacy whole-optimizer pickles are accepted.

        The state is loaded *into* ``self.optimizer``. Replacing ``self.optimizer`` with an
        unpickled optimizer (the previous behaviour) leaves it pointing at deserialized copies
        of the parameters rather than this model's, so later steps would not train the model.
        """
        if state_path is not None:
            state_dict = torch.load(state_path, map_location=self.device, weights_only=True)
            self.model.load_state_dict(state_dict)
        if optimizer_path is not None:
            # weights_only=False because legacy files hold a pickled Optimizer object.
            # Unpickling can execute arbitrary code: only load trusted files.
            saved = torch.load(optimizer_path, map_location=self.device, weights_only=False)
            if isinstance(saved, torch.optim.Optimizer):
                saved = saved.state_dict()
            self.optimizer.load_state_dict(saved)

    def save_param(self, path):
        """Save the model ``state_dict`` to ``path``."""
        torch.save(self.model.state_dict(), path)

    def save_model(self, path):
        """Pickle the whole model object to ``path``."""
        torch.save(self.model, path)

    def save_opt(self, path):
        """Save the optimizer ``state_dict`` to ``path`` (reload with :meth:`load_state_and_optimizer`)."""
        torch.save(self.optimizer.state_dict(), path)

    def params(self):
        """Return the model ``state_dict``."""
        return self.model.state_dict()

In [6]:
# The file is presumably a pickled sequence of torch_geometric Data objects (not a plain
# state_dict), so it cannot be read with weights_only=True; pass False explicitly.
# NOTE: unpickling can execute arbitrary code — only load trusted files.
datasets = torch.load('/H9C8NO2.pt', weights_only=False)
len(datasets)

  datasets=torch.load('/home/huwei22/jishengjiao/detanet/cp2k_data/paracetamol_dipole/H9C8NO2_force_energy_dipole_35495_warp.pt')


500


In [7]:
# Keep only what the model consumes (positions, atomic numbers, cell) plus the chosen target,
# stored under the generic key 'targ'. Every VAL_EVERY-th frame (index 0, 10, 20, ...) is held
# out for validation/test; all other frames are training data.
targ = 'force'
VAL_EVERY = 10
train_datasets = []
val_datasets1 = []
for i, data in enumerate(datasets):
    data_ = Data(pos=data.pos, z=data.z, box=data.box, targ=data[targ])
    if i % VAL_EVERY == 0:
        val_datasets1.append(data_)
    else:
        train_datasets.append(data_)
print(len(train_datasets))
print(train_datasets[0])
print(len(val_datasets1))
print(val_datasets1[0])

450
Data(pos=[80, 3], z=[80], box=[1, 3, 3], targ=[80, 3])
50
Data(pos=[80, 3], z=[80], box=[1, 3, 3], targ=[80, 3])


In [8]:
# Split the held-out frames alternately: even positions -> validation, odd positions -> test.
val_datasets = val_datasets1[0::2]
test_datasets = val_datasets1[1::2]
print(len(val_datasets))
print(len(test_datasets))

25
25


In [9]:
bathes = 16  # batch size shared by the training and validation loaders
trainloader, valloader = (
    DataLoader(split, batch_size=bathes, shuffle=True)
    for split in (train_datasets, val_datasets)
)

In [11]:
device = torch.device('cuda')
dtype = torch.float32
# Cast floating-point parameters/buffers to float32 and move the model to the GPU in one call.
model = model.to(device=device, dtype=dtype)

In [12]:
# Checkpoint filename prefix. Trainer.train reads this notebook global and saves
# path + str(step) + '.pth' every 500 global steps.
# NOTE(review): absolute path at the filesystem root. Each new Trainer restarts its step counter at -1
# and a 10-epoch stage here is only 290 steps, so every stage writes (and overwrites) '/force/force0.pth'.
path = '/force/force'

In [13]:
# Stage 1 of a manual step-down learning-rate schedule: a fresh Trainer (and therefore a
# freshly constructed optimizer) using AdamW at lr=1e-3 without weight decay.
trainer = Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    optimizer='AdamW',
    lr=1e-3,
    weight_decay=0,
)

In [14]:
# 10 passes over the training data; one validation batch every 50 steps, progress every 10.
# For the first 100 steps every batch updates the model, afterwards only batches whose loss
# lies strictly between 1e-8 and 10000.
trainer.train(
    num_epoch=10,
    targ='targ',
    stop_loss=0,
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
    loss_area=[100, 1e-8, 10000],
)

Epoch[0/290],loss:0.76018971,val_loss:1.51074564,val_mae:0.94692457,val_R2:-2.08769059
Epoch[10/290],loss:0.21273141
Epoch[20/290],loss:0.18780869
Epoch[30/290],loss:0.16790709
Epoch[40/290],loss:0.16112576
Epoch[50/290],loss:0.13113967,val_loss:0.13036703,val_mae:0.25381136,val_R2:0.71193433
Epoch[60/290],loss:0.13246863
Epoch[70/290],loss:0.13770199
Epoch[80/290],loss:0.11606736
Epoch[90/290],loss:0.12170154
Epoch[100/290],loss:0.10811009,val_loss:0.10781800,val_mae:0.23870926,val_R2:0.77184826
Epoch[110/290],loss:0.11641143
Epoch[120/290],loss:0.09998190
Epoch[130/290],loss:0.10679995
Epoch[140/290],loss:0.10076512
Epoch[150/290],loss:0.09589428,val_loss:0.09241914,val_mae:0.21891621,val_R2:0.80936545
Epoch[160/290],loss:0.10025460
Epoch[170/290],loss:0.09429752
Epoch[180/290],loss:0.09340759
Epoch[190/290],loss:0.09809297
Epoch[200/290],loss:0.08201345,val_loss:0.08098060,val_mae:0.20880041,val_R2:0.84436560
Epoch[210/290],loss:0.08444451
Epoch[220/290],loss:0.07287866
Epoch[230/29

In [15]:
# Persist the parameters reached at the end of stage 1.
trainer.save_param(path='/force/force1.pth')

In [16]:
# Reload the stage-1 checkpoint (a round-trip of the file just saved when run top to bottom).
# weights_only=True: the file is a plain state_dict, so unpickling is restricted to tensors and
# containers; passing it explicitly also avoids torch's FutureWarning about the default.
model.load_state_dict(torch.load('/force/force1.pth', weights_only=True))

  model.load_state_dict(torch.load('/home/huwei22/jishengjiao/detanet/trained_param_paracetamol/force_500/force1.pth'))


<All keys matched successfully>

In [17]:
#path = 'train_p/time_save/force/force2/force'

In [18]:
# Stage 2 of the step-down schedule: fresh Trainer/optimizer, AdamW at lr=4e-4.
trainer = Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    optimizer='AdamW',
    lr=4e-4,
    weight_decay=0,
)

In [19]:
# Same loop settings as stage 1 (warm-up 100 steps, then only losses in (1e-8, 10000)).
trainer.train(
    num_epoch=10,
    targ='targ',
    stop_loss=0,
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
    loss_area=[100, 1e-8, 10000],
)

Epoch[0/290],loss:0.07127935,val_loss:0.19025539,val_mae:0.31627303,val_R2:0.58523321
Epoch[10/290],loss:0.08295422
Epoch[20/290],loss:0.06956093
Epoch[30/290],loss:0.06932035
Epoch[40/290],loss:0.06396323
Epoch[50/290],loss:0.05834263,val_loss:0.05450312,val_mae:0.16900483,val_R2:0.88303685
Epoch[60/290],loss:0.05612916
Epoch[70/290],loss:0.05940564
Epoch[80/290],loss:0.05159123
Epoch[90/290],loss:0.05966023
Epoch[100/290],loss:0.05658167,val_loss:0.04798416,val_mae:0.16087058,val_R2:0.89873558
Epoch[110/290],loss:0.05666075
Epoch[120/290],loss:0.05446865
Epoch[130/290],loss:0.05126342
Epoch[140/290],loss:0.04911235
Epoch[150/290],loss:0.05166174,val_loss:0.04623088,val_mae:0.15920888,val_R2:0.90181774
Epoch[160/290],loss:0.04245992
Epoch[170/290],loss:0.04630632
Epoch[180/290],loss:0.05037106
Epoch[190/290],loss:0.04386312
Epoch[200/290],loss:0.04339999,val_loss:0.03992602,val_mae:0.14917356,val_R2:0.91550422
Epoch[210/290],loss:0.04126164
Epoch[220/290],loss:0.04092456
Epoch[230/290

In [20]:
# Persist the parameters reached at the end of stage 2.
trainer.save_param(path='/force/force2.pth')

In [21]:
# Reload the stage-2 checkpoint (a round-trip of the file just saved when run top to bottom).
# weights_only=True: the file is a plain state_dict, so unpickling is restricted to tensors and
# containers; passing it explicitly also avoids torch's FutureWarning about the default.
model.load_state_dict(torch.load('/force/force2.pth', weights_only=True))

  model.load_state_dict(torch.load('/home/huwei22/jishengjiao/detanet/trained_param_paracetamol/force_500/force2.pth'))


<All keys matched successfully>

In [23]:
# Stage 3 of the step-down schedule: fresh Trainer/optimizer, AdamW at lr=2e-4.
trainer = Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    optimizer='AdamW',
    lr=2e-4,
    weight_decay=0,
)

In [24]:
# From here on the post-warm-up loss window is tightened to (1e-8, 1000).
trainer.train(
    num_epoch=10,
    targ='targ',
    stop_loss=0,
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
    loss_area=[100, 1e-8, 1000],
)

Epoch[0/290],loss:0.03425516,val_loss:0.10426765,val_mae:0.22947307,val_R2:0.78632641
Epoch[10/290],loss:0.04797443
Epoch[20/290],loss:0.03527195
Epoch[30/290],loss:0.03601695
Epoch[40/290],loss:0.03057639
Epoch[50/290],loss:0.03272638,val_loss:0.02817344,val_mae:0.12517577,val_R2:0.94256431
Epoch[60/290],loss:0.02859977
Epoch[70/290],loss:0.03007060
Epoch[80/290],loss:0.02715768
Epoch[90/290],loss:0.02656628
Epoch[100/290],loss:0.03080537,val_loss:0.02711016,val_mae:0.12335425,val_R2:0.94256574
Epoch[110/290],loss:0.02717419
Epoch[120/290],loss:0.02699777
Epoch[130/290],loss:0.02575174
Epoch[140/290],loss:0.02245147
Epoch[150/290],loss:0.02535047,val_loss:0.02563528,val_mae:0.12045695,val_R2:0.94751728
Epoch[160/290],loss:0.02624718
Epoch[170/290],loss:0.02619005
Epoch[180/290],loss:0.02509644
Epoch[190/290],loss:0.02388595
Epoch[200/290],loss:0.02291097,val_loss:0.02161923,val_mae:0.11062668,val_R2:0.95331955
Epoch[210/290],loss:0.02315632
Epoch[220/290],loss:0.02175741
Epoch[230/290

In [25]:
# Persist the parameters reached at the end of stage 3.
trainer.save_param(path='/force/force3.pth')

In [26]:
# Reload the stage-3 checkpoint (a round-trip of the file just saved when run top to bottom).
# weights_only=True: the file is a plain state_dict, so unpickling is restricted to tensors and
# containers; passing it explicitly also avoids torch's FutureWarning about the default.
model.load_state_dict(torch.load('/force/force3.pth', weights_only=True))

  model.load_state_dict(torch.load('/home/huwei22/jishengjiao/detanet/trained_param_paracetamol/force_500/force3.pth'))


<All keys matched successfully>

In [28]:
# Stage 4 of the step-down schedule: fresh Trainer/optimizer, AdamW at lr=1e-4.
trainer = Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    optimizer='AdamW',
    lr=1e-4,
    weight_decay=0,
)

In [29]:
# Same loop settings as stage 3 (warm-up 100 steps, then only losses in (1e-8, 1000)).
trainer.train(
    num_epoch=10,
    targ='targ',
    stop_loss=0,
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
    loss_area=[100, 1e-8, 1000],
)

Epoch[0/290],loss:0.02218992,val_loss:0.04106823,val_mae:0.14933933,val_R2:0.91096538
Epoch[10/290],loss:0.02222051
Epoch[20/290],loss:0.01985919
Epoch[30/290],loss:0.02023347
Epoch[40/290],loss:0.02028009
Epoch[50/290],loss:0.01720220,val_loss:0.01936360,val_mae:0.10445476,val_R2:0.96155125
Epoch[60/290],loss:0.01712959
Epoch[70/290],loss:0.01801037
Epoch[80/290],loss:0.01791938
Epoch[90/290],loss:0.02028107
Epoch[100/290],loss:0.01790426,val_loss:0.01832422,val_mae:0.10246880,val_R2:0.96387547
Epoch[110/290],loss:0.01590887
Epoch[120/290],loss:0.01776044
Epoch[130/290],loss:0.01951667
Epoch[140/290],loss:0.01707743
Epoch[150/290],loss:0.01858935,val_loss:0.01762163,val_mae:0.10071072,val_R2:0.96421665
Epoch[160/290],loss:0.01628329
Epoch[170/290],loss:0.01783650
Epoch[180/290],loss:0.01874342
Epoch[190/290],loss:0.01769406
Epoch[200/290],loss:0.01757780,val_loss:0.01692088,val_mae:0.09882487,val_R2:0.96807152
Epoch[210/290],loss:0.01823177
Epoch[220/290],loss:0.01691562
Epoch[230/290

In [30]:
# Persist the parameters reached at the end of stage 4.
trainer.save_param(path='/force/force4.pth')

In [31]:
# Reload the stage-4 checkpoint (a round-trip of the file just saved when run top to bottom).
# weights_only=True: the file is a plain state_dict, so unpickling is restricted to tensors and
# containers; passing it explicitly also avoids torch's FutureWarning about the default.
model.load_state_dict(torch.load('/force/force4.pth', weights_only=True))

  model.load_state_dict(torch.load('/home/huwei22/jishengjiao/detanet/trained_param_paracetamol/force_500/force4.pth'))


<All keys matched successfully>

In [32]:
#path = 'train_p/time_save/force/force5/force'

In [33]:
# Stage 5 of the step-down schedule: fresh Trainer/optimizer, AdamW at lr=4e-5.
trainer = Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    optimizer='AdamW',
    lr=4e-5,
    weight_decay=0,
)

In [34]:
# Same loop settings as the previous stage (warm-up 100 steps, then only losses in (1e-8, 1000)).
trainer.train(
    num_epoch=10,
    targ='targ',
    stop_loss=0,
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
    loss_area=[100, 1e-8, 1000],
)

Epoch[0/290],loss:0.01445059,val_loss:0.01621512,val_mae:0.09749524,val_R2:0.96533018
Epoch[10/290],loss:0.01644113
Epoch[20/290],loss:0.01492754
Epoch[30/290],loss:0.01555077
Epoch[40/290],loss:0.01517836
Epoch[50/290],loss:0.01522112,val_loss:0.01434114,val_mae:0.09079600,val_R2:0.96871352
Epoch[60/290],loss:0.01447202
Epoch[70/290],loss:0.01471879
Epoch[80/290],loss:0.01358399
Epoch[90/290],loss:0.01299797
Epoch[100/290],loss:0.01443298,val_loss:0.01412449,val_mae:0.08954899,val_R2:0.96995491
Epoch[110/290],loss:0.01352593
Epoch[120/290],loss:0.01347463
Epoch[130/290],loss:0.01348770
Epoch[140/290],loss:0.01283206
Epoch[150/290],loss:0.01378454,val_loss:0.01415422,val_mae:0.09091183,val_R2:0.97075903
Epoch[160/290],loss:0.01322082
Epoch[170/290],loss:0.01412587
Epoch[180/290],loss:0.01301906
Epoch[190/290],loss:0.01258375
Epoch[200/290],loss:0.01374014,val_loss:0.01311411,val_mae:0.08677372,val_R2:0.97211665
Epoch[210/290],loss:0.01388073
Epoch[220/290],loss:0.01221910
Epoch[230/290

In [35]:
 # Persist the parameters reached at the end of stage 5.
 trainer.save_param(path='/force/force5.pth')

In [36]:
# Reload the stage-5 checkpoint (a round-trip of the file just saved when run top to bottom).
# weights_only=True: the file is a plain state_dict, so unpickling is restricted to tensors and
# containers; passing it explicitly also avoids torch's FutureWarning about the default.
model.load_state_dict(torch.load('/force/force5.pth', weights_only=True))

  model.load_state_dict(torch.load('/home/huwei22/jishengjiao/detanet/trained_param_paracetamol/force_500/force5.pth'))


<All keys matched successfully>

In [38]:
# Stage 6 of the step-down schedule: fresh Trainer/optimizer, AdamW at lr=2e-5.
trainer = Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    optimizer='AdamW',
    lr=2e-5,
    weight_decay=0,
)

In [39]:
# Same loop settings as the previous stage (warm-up 100 steps, then only losses in (1e-8, 1000)).
trainer.train(
    num_epoch=10,
    targ='targ',
    stop_loss=0,
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
    loss_area=[100, 1e-8, 1000],
)

Epoch[0/290],loss:0.01343032,val_loss:0.01324860,val_mae:0.08845792,val_R2:0.97224039
Epoch[10/290],loss:0.01379818
Epoch[20/290],loss:0.01273184
Epoch[30/290],loss:0.01274890
Epoch[40/290],loss:0.01214328
Epoch[50/290],loss:0.01241106,val_loss:0.01279065,val_mae:0.08684883,val_R2:0.97394383
Epoch[60/290],loss:0.01261852
Epoch[70/290],loss:0.01308259
Epoch[80/290],loss:0.01175358
Epoch[90/290],loss:0.01250714
Epoch[100/290],loss:0.01214932,val_loss:0.01315524,val_mae:0.08729521,val_R2:0.97226989
Epoch[110/290],loss:0.01331374
Epoch[120/290],loss:0.01352489
Epoch[130/290],loss:0.01088258
Epoch[140/290],loss:0.01279699
Epoch[150/290],loss:0.01286738,val_loss:0.01273606,val_mae:0.08639078,val_R2:0.97344536
Epoch[160/290],loss:0.01321439
Epoch[170/290],loss:0.01178176
Epoch[180/290],loss:0.01169926
Epoch[190/290],loss:0.01224225
Epoch[200/290],loss:0.01272799,val_loss:0.01297366,val_mae:0.08640943,val_R2:0.97373813
Epoch[210/290],loss:0.01184660
Epoch[220/290],loss:0.01162216
Epoch[230/290

In [40]:
# Persist the parameters reached at the end of stage 6.
trainer.save_param(path='/force/force6.pth')

In [41]:
# Reload the stage-6 checkpoint (a round-trip of the file just saved when run top to bottom).
# weights_only=True: the file is a plain state_dict, so unpickling is restricted to tensors and
# containers; passing it explicitly also avoids torch's FutureWarning about the default.
model.load_state_dict(torch.load('/force/force6.pth', weights_only=True))

  model.load_state_dict(torch.load('/home/huwei22/jishengjiao/detanet/trained_param_paracetamol/force_500/force6.pth'))


<All keys matched successfully>

In [43]:
# Stage 7 (final) of the step-down schedule: fresh Trainer/optimizer, AdamW at lr=1e-5.
trainer = Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    optimizer='AdamW',
    lr=1e-5,
    weight_decay=0,
)

In [44]:
# Same loop settings as the previous stage (warm-up 100 steps, then only losses in (1e-8, 1000)).
trainer.train(
    num_epoch=10,
    targ='targ',
    stop_loss=0,
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
    loss_area=[100, 1e-8, 1000],
)

Epoch[0/290],loss:0.01282605,val_loss:0.01206865,val_mae:0.08387186,val_R2:0.97381079
Epoch[10/290],loss:0.01265533
Epoch[20/290],loss:0.01164786
Epoch[30/290],loss:0.01336333
Epoch[40/290],loss:0.01278815
Epoch[50/290],loss:0.01124822,val_loss:0.01263928,val_mae:0.08572491,val_R2:0.97501832
Epoch[60/290],loss:0.01221141
Epoch[70/290],loss:0.01151027
Epoch[80/290],loss:0.01168785
Epoch[90/290],loss:0.01309227
Epoch[100/290],loss:0.01258653,val_loss:0.01240768,val_mae:0.08466619,val_R2:0.97543782
Epoch[110/290],loss:0.01105412
Epoch[120/290],loss:0.01193233
Epoch[130/290],loss:0.01152368
Epoch[140/290],loss:0.01116051
Epoch[150/290],loss:0.01067755,val_loss:0.01160874,val_mae:0.08232661,val_R2:0.97555006
Epoch[160/290],loss:0.01162575
Epoch[170/290],loss:0.01218070
Epoch[180/290],loss:0.01177443
Epoch[190/290],loss:0.01133983
Epoch[200/290],loss:0.01185577,val_loss:0.01204291,val_mae:0.08335063,val_R2:0.97530103
Epoch[210/290],loss:0.01053692
Epoch[220/290],loss:0.01248921
Epoch[230/290

In [45]:
# Persist the final parameters (end of stage 7).
trainer.save_param(path='/force/force7.pth')