In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('/detanet-md/')
from detanet_model.detanet_pbc import *
from detanet_model.metrics import *
from torch_geometric.loader import DataLoader
from e3nn import o3
import os
from torch_geometric.data import Data
import warnings
warnings.filterwarnings("ignore")
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

  impl_abstract(
  impl_abstract(


In [3]:
# DetaNet hyper-parameters collected in one dict so they can be inspected or
# logged alongside checkpoints. NOTE(review): units of the cutoffs are not
# established in this notebook — confirm against the DetaNet documentation.
detanet_config = dict(
    num_features=128,
    act='swish',
    maxl=3,
    num_block=3,
    radial_type='trainable_bessel',
    num_radial=128,
    attention_head=8,
    cutoff_lower=0.0,
    cutoff_upper=3.5,
    max_num_neighbors=150,
    strategy="brute",
    check_errors=True,
    dropout=0.0,
    use_cutoff=False,
    max_atomic_number=35,
    atom_ref=None,
    scale=1.0,
    scalar_outsize=1,
    irreps_out=None,
    summation=True,
    norm=False,
    out_type='scalar',
    grad_type=None,
    device=torch.device('cuda'),
)
model = DetaNet(**detanet_config)

<All keys matched successfully>

In [4]:
class Trainer:
    """Step-based training loop for a DetaNet-style model on PyG batches.

    Each batch must expose ``pos``, ``z``, ``batch`` and ``box`` plus the target
    attribute named by ``targ``; the model is called as
    ``model(pos=..., z=..., batch=..., box=...)``.

    Args:
        model: the ``torch.nn.Module`` to train.
        train_loader: iterable of training batches.
        val_loader: optional iterable of validation batches. One batch is drawn
            every ``val_per_train`` steps; ``None`` disables validation.
        loss_function: callable ``(prediction, target) -> scalar tensor``.
        device: device the batch tensors are moved to before the forward pass.
        optimizer: one of ``'AdamW'``, ``'AdamW_amsgrad'``, ``'Adam'``,
            ``'Adam_amsgrad'``, ``'Adadelta'``, ``'RMSprop'``, ``'SGD'``.
        lr: learning rate passed to the optimizer.
        weight_decay: weight decay passed to the optimizer.

    Raises:
        KeyError: if ``optimizer`` is not one of the names above.
    """

    # Factories rather than instances, so only the requested optimizer is built.
    _OPTIMIZER_FACTORIES = {
        'AdamW': lambda params, lr, wd: torch.optim.AdamW(params, lr=lr, amsgrad=False, weight_decay=wd),
        'AdamW_amsgrad': lambda params, lr, wd: torch.optim.AdamW(params, lr=lr, amsgrad=True, weight_decay=wd),
        'Adam': lambda params, lr, wd: torch.optim.Adam(params, lr=lr, amsgrad=False, weight_decay=wd),
        'Adam_amsgrad': lambda params, lr, wd: torch.optim.Adam(params, lr=lr, amsgrad=True, weight_decay=wd),
        'Adadelta': lambda params, lr, wd: torch.optim.Adadelta(params, lr=lr, weight_decay=wd),
        'RMSprop': lambda params, lr, wd: torch.optim.RMSprop(params, lr=lr, weight_decay=wd),
        'SGD': lambda params, lr, wd: torch.optim.SGD(params, lr=lr, weight_decay=wd),
    }

    def __init__(self, model, train_loader, val_loader=None, loss_function=l2loss, device=torch.device('cuda'),
                 optimizer='Adam_amsgrad', lr=5e-4, weight_decay=0):
        if optimizer not in self._OPTIMIZER_FACTORIES:
            raise KeyError(f"Unknown optimizer {optimizer!r}; expected one of "
                           f"{sorted(self._OPTIMIZER_FACTORIES)}")
        self.opt_type = optimizer
        self.device = device
        self.model = model
        self.train_data = train_loader
        self.val_data = val_loader
        self.optimizer = self._OPTIMIZER_FACTORIES[optimizer](self.model.parameters(), lr, weight_decay)
        self.loss_function = loss_function
        # Global step counter; incremented before each batch, so the first batch is step 0.
        self.step = -1

    def _forward(self, batch):
        """Move the inputs of ``batch`` to ``self.device`` and run the model on them."""
        # Always copy the positions into a fresh, detached float32 tensor so that
        # nothing the model does to ``pos`` can touch the stored dataset tensors.
        pos = batch.pos.detach().to(device=self.device, dtype=torch.float32, copy=True)
        return self.model(pos=pos,
                          z=batch.z.to(self.device),
                          batch=batch.batch.to(self.device),
                          box=batch.box.to(self.device))

    def _next_val_batch(self, val_iter):
        """Return ``(batch, iterator)``, restarting the validation loader when it is exhausted."""
        try:
            return next(val_iter), val_iter
        except StopIteration:
            val_iter = iter(self.val_data)
            return next(val_iter), val_iter

    def _validate(self, val_batch, targ, data_scale):
        """Score one validation batch.

        Returns:
            ``(loss, mae, R2, prediction, target)``; the three metrics are Python
            floats and the prediction is detached so its autograd graph is freed.
        """
        val_target = val_batch[targ].to(self.device).reshape(-1) * data_scale
        # Evaluate in eval mode (e.g. dropout disabled), then restore training mode.
        # Gradients are deliberately left enabled in case the model differentiates
        # internally. NOTE(review): if it never does, wrap this in torch.no_grad().
        self.model.eval()
        try:
            val_out = self._forward(val_batch).reshape(val_target.shape)
        finally:
            self.model.train()
        val_loss = self.loss_function(val_out, val_target).item()
        val_mae = l1loss(val_out, val_target).item()
        val_R2 = R2(val_out, val_target).item()
        return val_loss, val_mae, val_R2, val_out.detach(), val_target

    def train(self, num_epoch, targ, stop_loss=1e-8, element=None, q9=None, loss_area=(100000, 1e-8, 1e8),
              val_per_train=10, view_data=False, print_per_epoch=10, data_scale=1,
              checkpoint_prefix=None, save_per_train=500):
        """Train for ``num_epoch`` passes over the training loader.

        Args:
            num_epoch: number of passes over the training loader.
            targ: name of the batch attribute holding the regression target.
            stop_loss: training raises ``AssertionError`` once both the current
                training loss and the most recent validation loss are
                ``<= stop_loss``.
            element, q9: unused; kept for call compatibility.
            loss_area: ``(warmup_steps, low, high)``. For the first
                ``warmup_steps`` global steps every batch is back-propagated;
                afterwards only batches with ``low < loss < high`` are.
            val_per_train: evaluate one validation batch every this many steps.
            view_data: also print the first validation prediction/target pair.
            print_per_epoch: print progress every this many steps (the counter
                is the global step, not the epoch).
            data_scale: factor applied to the targets before the loss.
            checkpoint_prefix: filename prefix for periodic checkpoints; when
                ``None`` the notebook-level ``path`` variable is used.
            save_per_train: save ``<prefix><step>.pth`` every this many steps.

        Raises:
            AssertionError: when the ``stop_loss`` criterion is met.
        """
        self.model.train()
        total_steps = num_epoch * len(self.train_data)
        val_loss = float('inf')  # last validation loss; inf until one is computed
        for _ in range(num_epoch):
            val_iter = iter(self.val_data) if self.val_data is not None else None
            for batch in self.train_data:
                self.step += 1
                # Releases cached CUDA memory every step. NOTE(review): costly;
                # presumably a workaround for out-of-memory errors.
                torch.cuda.empty_cache()
                self.optimizer.zero_grad()
                validated = False
                try:
                    out = self._forward(batch)
                    target = batch[targ].to(self.device) * data_scale
                    loss = self.loss_function(out.reshape(target.shape), target)
                    # Step only when gradients were computed for this batch: if
                    # zero_grad leaves zero-filled grads (set_to_none=False), an
                    # extra step would still apply momentum and weight decay.
                    if self.step < loss_area[0] or loss_area[1] < loss.item() < loss_area[2]:
                        loss.backward()
                        self.optimizer.step()
                    if val_iter is not None and self.step % val_per_train == 0:
                        val_batch, val_iter = self._next_val_batch(val_iter)
                        val_loss, val_mae, val_R2, val_out, val_target = self._validate(val_batch, targ, data_scale)
                        validated = True
                        if self.step % print_per_epoch == 0:
                            print('Epoch[{}/{}],loss:{:.8f},val_loss:{:.8f},val_mae:{:.8f},val_R2:{:.8f}'
                                  .format(self.step, total_steps, loss.item(), val_loss, val_mae, val_R2))
                            if view_data:
                                print('valout:{:.8f},valtarget:{:.8f}'.format(val_out.flatten()[0].item(),
                                                                               val_target.flatten()[0].item()))
                except RuntimeError as e:
                    # Best-effort: skip batches that fail, e.g. CUDA errors.
                    print(f"Encountered RuntimeError: {e}. Skipping current batch.")
                    continue
                except Exception as e:
                    print(f"Encountered unexpected error: {e}. Skipping current batch.")
                    continue

                # Explicit raise (not `assert`) so the stop criterion survives `python -O`.
                if not (loss.item() > stop_loss or val_loss > stop_loss):
                    raise AssertionError('Training and prediction Loss is less than cut-off Loss, so training stops')
                if self.step % print_per_epoch == 0 and not validated:
                    print('Epoch[{}/{}],loss:{:.8f}'.format(self.step, total_steps, loss.item()))
                if self.step % save_per_train == 0:
                    prefix = checkpoint_prefix if checkpoint_prefix is not None else path
                    self.save_param(prefix + str(self.step) + '.pth')

    def load_state_and_optimizer(self, state_path=None, optimizer_path=None):
        """Restore model weights and/or optimizer state from disk.

        Args:
            state_path: file written by ``save_param`` (a model ``state_dict``).
            optimizer_path: file written by ``save_opt`` (a whole pickled
                optimizer) or a plain optimizer ``state_dict``. Full unpickling
                is required, so only load trusted files.
        """
        if state_path is not None:
            self.model.load_state_dict(torch.load(state_path))
        if optimizer_path is not None:
            loaded = torch.load(optimizer_path, weights_only=False)
            if isinstance(loaded, torch.optim.Optimizer):
                # An unpickled optimizer owns fresh copies of the parameters it was
                # saved with, not this model's parameters. Rebuild one of the same
                # type over the live parameters and transfer its state instead.
                optimizer = type(loaded)(self.model.parameters(), **loaded.defaults)
                optimizer.load_state_dict(loaded.state_dict())
            else:
                # Assumed to be ``optimizer.state_dict()``. NOTE(review): confirm.
                optimizer = self.optimizer
                optimizer.load_state_dict(loaded)
            self.optimizer = optimizer

    def save_param(self, path):
        """Save the model ``state_dict`` to ``path``."""
        torch.save(self.model.state_dict(), path)

    def save_model(self, path):
        """Pickle the whole model object to ``path``."""
        torch.save(self.model, path)

    def save_opt(self, path):
        """Pickle the whole optimizer object to ``path``."""
        torch.save(self.optimizer, path)

    def params(self):
        """Return the model ``state_dict``."""
        return self.model.state_dict()

In [6]:
# The file holds pickled PyG `Data` objects rather than plain tensors, so it
# presumably cannot be read by the restricted `weights_only=True` loader; full
# unpickling executes arbitrary code, so only load files you trust.
datasets = torch.load('/H9C8NO2.pt', weights_only=False)
# Show one sample and the dataset size (this is the output saved below the cell).
print(datasets[0])
print(len(datasets))

Data(pos=[180, 3], z=[180], atomization_energy=-860.875, force=[180, 3], dipole=[1, 3], box=[1, 3, 3])
10001


In [7]:
# Split the raw frames: every HOLDOUT_EVERY-th frame goes to the held-out pool
# (split into validation/test in the next cell), the rest is used for training.
# Only the fields the model needs are kept; the target is stored as `targ`.
targ = 'atomization_energy'
HOLDOUT_EVERY = 10
train_datasets = []
val_datasets1 = []
for i, data in enumerate(datasets):
    data_ = Data(pos=data.pos, z=data.z, box=data.box, targ=data[targ])
    if i % HOLDOUT_EVERY == 0:
        val_datasets1.append(data_)
    else:
        train_datasets.append(data_)
print(len(train_datasets))
print(train_datasets[0])
print(len(val_datasets1))
print(val_datasets1[0])

9000
Data(pos=[180, 3], z=[180], box=[1, 3, 3], targ=-863.40625)
1001
Data(pos=[180, 3], z=[180], box=[1, 3, 3], targ=-860.875)


In [8]:
# Alternate the held-out frames: even positions -> validation, odd -> test.
val_datasets = val_datasets1[::2]
test_datasets = val_datasets1[1::2]
print(len(val_datasets))
print(len(test_datasets))

501
500


In [10]:
# Mini-batch size shared by both loaders.
bathes = 4
# Both loaders reshuffle on every pass; the trainer draws a single validation
# batch every `val_per_train` steps, so shuffling varies which frames it sees.
trainloader, valloader = (
    DataLoader(split, batch_size=bathes, shuffle=True)
    for split in (train_datasets, val_datasets)
)

In [11]:
# Cast the model to single precision and move it to the GPU in one call.
device = torch.device('cuda')
dtype = torch.float32
model = model.to(device=device, dtype=dtype)

In [12]:
# Checkpoint filename prefix. Trainer.train reads this notebook-level name and
# writes '<path><step>.pth' whenever step % 500 == 0 (e.g. '/energy/energy500.pth').
# Each new Trainer restarts its step counter, so later stages overwrite these files.
path = '/energy/energy'

In [13]:
# Stage 1: AdamW at lr=1e-3.
trainer = Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    lr=1e-3,
    weight_decay=0,
    optimizer='AdamW',
)

In [14]:
# Stage 1: 10 epochs; after step 100 only batches with 1e-8 < loss < 1000 are back-propagated.
trainer.train(
    num_epoch=10,
    targ='targ',
    stop_loss=0,
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
    loss_area=[100, 1e-8, 1000],
)

Epoch[0/22500],loss:36510.09375000,val_loss:2711.44165039,val_mae:52.02972412,val_R2:-57904.34765625
Epoch[10/22500],loss:988.55487061
Epoch[20/22500],loss:271.89422607
Epoch[30/22500],loss:11.29458332
Epoch[40/22500],loss:15.66808128
Epoch[50/22500],loss:3.34381437,val_loss:10.93019676,val_mae:3.24391174,val_R2:-8.96485901
Epoch[60/22500],loss:0.66468245
Epoch[70/22500],loss:1.59649670
Epoch[80/22500],loss:2.04603720
Epoch[90/22500],loss:3.63477325
Epoch[100/22500],loss:3.69887304,val_loss:2.92483377,val_mae:1.52284241,val_R2:-7.08690548
Epoch[110/22500],loss:1.08616877
Epoch[120/22500],loss:3.57676768
Epoch[130/22500],loss:1.72857392
Epoch[140/22500],loss:5.31986237
Epoch[150/22500],loss:1.54593611,val_loss:0.56087959,val_mae:0.63282776,val_R2:-3.42598486
Epoch[160/22500],loss:2.05665207
Epoch[170/22500],loss:2.11616564
Epoch[180/22500],loss:5.23220301
Epoch[190/22500],loss:3.57376480
Epoch[200/22500],loss:0.88918978,val_loss:1.29340339,val_mae:1.12524414,val_R2:-15.41132355
Epoch[21

In [15]:
# Persist the stage-1 weights.
trainer.save_param(path='/energy/energy1.pth')

In [16]:
# Reload the stage-1 checkpoint from the same absolute path save_param wrote to;
# a relative path would depend on the kernel's current working directory.
model.load_state_dict(torch.load('/energy/energy1.pth'))

<All keys matched successfully>

In [18]:
# Stage 2: a new Trainer builds a fresh AdamW (previous optimizer state is dropped), lr=4e-4.
trainer = Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    lr=4e-4,
    weight_decay=0,
    optimizer='AdamW',
)

In [19]:
# Stage 2: 15 epochs; after step 100 only batches with 1e-8 < loss < 100 are back-propagated.
trainer.train(
    num_epoch=15,
    targ='targ',
    stop_loss=0,
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
    loss_area=[100, 1e-8, 100],
)

Epoch[0/33750],loss:2567.87207031,val_loss:60.25006104,val_mae:7.75842285,val_R2:-25.11940002
Epoch[10/33750],loss:42.06017303
Epoch[20/33750],loss:32.77682495
Epoch[30/33750],loss:9.04299927
Epoch[40/33750],loss:3.54441237
Epoch[50/33750],loss:2.73933029,val_loss:0.58964384,val_mae:0.76011658,val_R2:-5.41402531
Epoch[60/33750],loss:0.26888546
Epoch[70/33750],loss:0.42977718
Epoch[80/33750],loss:0.12205605
Epoch[90/33750],loss:0.34078217
Epoch[100/33750],loss:0.28859296,val_loss:0.56452858,val_mae:0.53843689,val_R2:0.88871318
Epoch[110/33750],loss:0.49825174
Epoch[120/33750],loss:0.38463685
Epoch[130/33750],loss:0.33318913
Epoch[140/33750],loss:0.19094245
Epoch[150/33750],loss:0.09422119,val_loss:0.17538261,val_mae:0.30776978,val_R2:0.37752336
Epoch[160/33750],loss:0.18862128
Epoch[170/33750],loss:0.10157289
Epoch[180/33750],loss:0.14331195
Epoch[190/33750],loss:0.17463130
Epoch[200/33750],loss:0.02755473,val_loss:0.32682708,val_mae:0.50048828,val_R2:0.05212301
Epoch[210/33750],loss:0.

In [20]:
# Persist the stage-2 weights.
trainer.save_param(path='/energy/energy2.pth')

In [21]:
# Reload the stage-2 checkpoint (a round-trip check of the file just written).
model.load_state_dict(state_dict=torch.load(f='/energy/energy2.pth'))

<All keys matched successfully>

In [23]:
# Stage 3: fresh AdamW at lr=2e-4.
trainer = Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    lr=2e-4,
    weight_decay=0,
    optimizer='AdamW',
)

In [24]:
# Stage 3: 15 epochs; after step 100 only batches with 1e-8 < loss < 100 are back-propagated.
trainer.train(
    num_epoch=15,
    targ='targ',
    stop_loss=0,
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
    loss_area=[100, 1e-8, 100],
)

Epoch[0/33750],loss:979.46252441,val_loss:104.04863739,val_mae:10.19630432,val_R2:-285.09201050
Epoch[10/33750],loss:70.46316528
Epoch[20/33750],loss:34.23699188
Epoch[30/33750],loss:5.66805172
Epoch[40/33750],loss:0.23979101
Epoch[50/33750],loss:1.61485052,val_loss:2.49138236,val_mae:1.53421021,val_R2:-1.47724533
Epoch[60/33750],loss:0.49838018
Epoch[70/33750],loss:0.44850454
Epoch[80/33750],loss:0.18890834
Epoch[90/33750],loss:0.04814738
Epoch[100/33750],loss:0.18511349,val_loss:0.13351756,val_mae:0.28935242,val_R2:0.33994329
Epoch[110/33750],loss:0.11686207
Epoch[120/33750],loss:0.11363351
Epoch[130/33750],loss:0.00634240
Epoch[140/33750],loss:0.09129632
Epoch[150/33750],loss:0.20458420,val_loss:0.08376733,val_mae:0.20713806,val_R2:0.74880630
Epoch[160/33750],loss:0.16026548
Epoch[170/33750],loss:0.09468378
Epoch[180/33750],loss:0.02205468
Epoch[190/33750],loss:0.02242927
Epoch[200/33750],loss:0.14033741,val_loss:0.08464833,val_mae:0.26203918,val_R2:-1.20598984
Epoch[210/33750],loss

In [25]:
# Persist the stage-3 weights.
trainer.save_param(path='/energy/energy3.pth')

In [26]:
# Reload the stage-3 checkpoint (a round-trip check of the file just written).
model.load_state_dict(state_dict=torch.load(f='/energy/energy3.pth'))

<All keys matched successfully>

In [28]:
# Stage 4: fresh AdamW at lr=1e-4.
trainer = Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    lr=1e-4,
    weight_decay=0,
    optimizer='AdamW',
)

In [29]:
# Stage 4: 15 epochs; after step 100 only batches with 1e-8 < loss < 100 are back-propagated.
trainer.train(
    num_epoch=15,
    targ='targ',
    stop_loss=0,
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
    loss_area=[100, 1e-8, 100],
)

Epoch[0/33750],loss:0.03988346,val_loss:10.26032829,val_mae:3.20315552,val_R2:-74.63151550
Epoch[10/33750],loss:1.65648270
Epoch[20/33750],loss:0.79550040
Epoch[30/33750],loss:0.42775118
Epoch[40/33750],loss:0.09317501
Epoch[50/33750],loss:0.04180795,val_loss:0.02429127,val_mae:0.13305664,val_R2:0.43240845
Epoch[60/33750],loss:0.03668416
Epoch[70/33750],loss:0.01901150
Epoch[80/33750],loss:0.00372264
Epoch[90/33750],loss:0.01115309
Epoch[100/33750],loss:0.00261906,val_loss:0.01188937,val_mae:0.10093689,val_R2:0.88609701
Epoch[110/33750],loss:0.01381767
Epoch[120/33750],loss:0.01431766
Epoch[130/33750],loss:0.02923242
Epoch[140/33750],loss:0.00444002
Epoch[150/33750],loss:0.02752128,val_loss:0.03040086,val_mae:0.17297363,val_R2:0.97594619
Epoch[160/33750],loss:0.00086333
Epoch[170/33750],loss:0.00256232
Epoch[180/33750],loss:0.00730331
Epoch[190/33750],loss:0.00585898
Epoch[200/33750],loss:0.00295504,val_loss:0.00824652,val_mae:0.07774353,val_R2:0.98449230
Epoch[210/33750],loss:0.008442

In [30]:
# Persist the stage-4 weights.
trainer.save_param(path='/energy/energy4.pth')

In [31]:
# Reload the stage-4 checkpoint (a round-trip check of the file just written).
model.load_state_dict(state_dict=torch.load(f='/energy/energy4.pth'))

<All keys matched successfully>

In [33]:
# Stage 5: fresh AdamW at lr=4e-5.
trainer = Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    lr=4e-5,
    weight_decay=0,
    optimizer='AdamW',
)

In [34]:
# Stage 5: 15 epochs; after step 100 only batches with 1e-8 < loss < 100 are back-propagated.
trainer.train(
    num_epoch=15,
    targ='targ',
    stop_loss=0,
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
    loss_area=[100, 1e-8, 100],
)

Epoch[0/33750],loss:0.07067419,val_loss:1.09375417,val_mae:1.04560852,val_R2:-1.00706363
Epoch[10/33750],loss:0.17376061
Epoch[20/33750],loss:0.10420706
Epoch[30/33750],loss:0.04392670
Epoch[40/33750],loss:0.02308108
Epoch[50/33750],loss:0.00552895,val_loss:0.00364730,val_mae:0.05270386,val_R2:0.94640601
Epoch[60/33750],loss:0.00270468
Epoch[70/33750],loss:0.00195190
Epoch[80/33750],loss:0.00053941
Epoch[90/33750],loss:0.00390923
Epoch[100/33750],loss:0.00069384,val_loss:0.01039459,val_mae:0.10075378,val_R2:0.99131525
Epoch[110/33750],loss:0.00203034
Epoch[120/33750],loss:0.00003229
Epoch[130/33750],loss:0.00109703
Epoch[140/33750],loss:0.00126139
Epoch[150/33750],loss:0.01089100,val_loss:0.00215080,val_mae:0.03724670,val_R2:0.99011397
Epoch[160/33750],loss:0.01034731
Epoch[170/33750],loss:0.00322720
Epoch[180/33750],loss:0.00242996
Epoch[190/33750],loss:0.01729254
Epoch[200/33750],loss:0.00978586,val_loss:0.00258761,val_mae:0.04718018,val_R2:0.99114954
Epoch[210/33750],loss:0.00517080

In [35]:
# Persist the stage-5 weights.
trainer.save_param(path='/energy/energy5.pth')

In [36]:
# Reload the stage-5 checkpoint (a round-trip check of the file just written).
model.load_state_dict(state_dict=torch.load(f='/energy/energy5.pth'))

<All keys matched successfully>

In [38]:
# Stage 6: repeat lr=4e-5 with a fresh AdamW (optimizer state reset).
trainer = Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    lr=4e-5,
    weight_decay=0,
    optimizer='AdamW',
)

In [39]:
# Stage 6: 15 epochs; after step 100 only batches with 1e-8 < loss < 100 are back-propagated.
trainer.train(
    num_epoch=15,
    targ='targ',
    stop_loss=0,
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
    loss_area=[100, 1e-8, 100],
)

Epoch[0/33750],loss:0.00080467,val_loss:1.58069456,val_mae:1.25708008,val_R2:-64.64791107
Epoch[10/33750],loss:0.41582447
Epoch[20/33750],loss:0.11797006
Epoch[30/33750],loss:0.00752119
Epoch[40/33750],loss:0.00280455
Epoch[50/33750],loss:0.00224025,val_loss:0.00036074,val_mae:0.01679993,val_R2:0.99912292
Epoch[60/33750],loss:0.00102071
Epoch[70/33750],loss:0.00071286
Epoch[80/33750],loss:0.00053117
Epoch[90/33750],loss:0.00051311
Epoch[100/33750],loss:0.00266162,val_loss:0.00099031,val_mae:0.02938843,val_R2:0.99772024
Epoch[110/33750],loss:0.00032140
Epoch[120/33750],loss:0.00012444
Epoch[130/33750],loss:0.00028591
Epoch[140/33750],loss:0.00032149
Epoch[150/33750],loss:0.00036647,val_loss:0.00061039,val_mae:0.01625061,val_R2:0.99781251
Epoch[160/33750],loss:0.00154340
Epoch[170/33750],loss:0.00006441
Epoch[180/33750],loss:0.00030971
Epoch[190/33750],loss:0.00016096
Epoch[200/33750],loss:0.00012642,val_loss:0.00055697,val_mae:0.01324463,val_R2:0.99647081
Epoch[210/33750],loss:0.0007548

In [40]:
# Persist the stage-6 weights.
trainer.save_param(path='/energy/energy6.pth')

In [41]:
# Reload the stage-6 checkpoint (a round-trip check of the file just written).
model.load_state_dict(state_dict=torch.load(f='/energy/energy6.pth'))

<All keys matched successfully>

In [43]:
# Stage 7: fresh AdamW at lr=1e-5.
trainer = Trainer(
    model,
    train_loader=trainloader,
    val_loader=valloader,
    loss_function=l2loss,
    lr=1e-5,
    weight_decay=0,
    optimizer='AdamW',
)

In [44]:
# Stage 7: 15 epochs; after step 100 only batches with 1e-8 < loss < 100 are back-propagated.
trainer.train(
    num_epoch=15,
    targ='targ',
    stop_loss=0,
    val_per_train=50,
    print_per_epoch=10,
    view_data=False,
    loss_area=[100, 1e-8, 100],
)

Epoch[0/33750],loss:0.00876392,val_loss:0.04245096,val_mae:0.19979858,val_R2:0.91498649
Epoch[10/33750],loss:0.00371255
Epoch[20/33750],loss:0.00046726
Epoch[30/33750],loss:0.00060651
Epoch[40/33750],loss:0.00062029
Epoch[50/33750],loss:0.00107560,val_loss:0.00039044,val_mae:0.01393127,val_R2:0.99754173
Epoch[60/33750],loss:0.00057786
Epoch[70/33750],loss:0.00034107
Epoch[80/33750],loss:0.00066259
Epoch[90/33750],loss:0.00063315
Epoch[100/33750],loss:0.00017221,val_loss:0.00014859,val_mae:0.00946045,val_R2:0.99971795
Epoch[110/33750],loss:0.00015199
Epoch[120/33750],loss:0.00028202
Epoch[130/33750],loss:0.00064308
Epoch[140/33750],loss:0.00162142
Epoch[150/33750],loss:0.00080147,val_loss:0.00136823,val_mae:0.03523254,val_R2:0.99551904
Epoch[160/33750],loss:0.00047931
Epoch[170/33750],loss:0.00046714
Epoch[180/33750],loss:0.00168755
Epoch[190/33750],loss:0.00065503
Epoch[200/33750],loss:0.00051242,val_loss:0.00285415,val_mae:0.05085754,val_R2:0.98236382
Epoch[210/33750],loss:0.00035381


In [45]:
# Persist the final (stage-7) weights.
trainer.save_param(path='/energy/energy7.pth')