In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from detanet_model import *
from torch_geometric.loader import DataLoader

In [2]:
'''First, the pytorch library can be used to load the dataset, which consists of 130K molecules, after importing the pytorch library.'''
dataset=torch.load('qm9s.pt')

len(dataset)

129817

In [3]:
'''where each molecule is a gemetric.data.Data. containing coordinates, atomic numbers, edge indices, and various properties.'''
dataset[0]

Data(edge_index=[2, 20], pos=[5, 3], number=1, smile='C', z=[5], quadrupole=[1, 3, 3], octapole=[1, 3, 3, 3], npacharge=[5], dipole=[1, 3], polar=[1, 3, 3], hyperpolar=[1, 3, 3, 3], energy=[1, 1], Hij=[20, 3, 3], Hi=[5, 3, 3], dedipole=[5, 3, 3], depolar=[5, 3, 6], tran_dipole=[1, 10, 3], tran_energy=[1, 10])

In [4]:
'''We divided the dataset evenly and used 5% of the data for testing and other for training:'''
train_datasets=[]
val_datasets=[]
for i in range(len(dataset)):
    if i%20==0:
        val_datasets.append(dataset[i])
    else:
        train_datasets.append(dataset[i])
        
len(train_datasets),len(val_datasets)

(123326, 6491)

In [5]:
'''Using torch_Geometric.dataloader.DataLoader Converts a dataset into a batch of 64 molecules of training data.'''
bathes=64
trainloader=DataLoader(train_datasets,batch_size=bathes,shuffle=True)
valloader=DataLoader(val_datasets,batch_size=bathes,shuffle=True)

In [6]:
'''After loading the dataset, we train a model using NPA charge as an example.
 	Firstly, construct an untrained model:'''
model=DetaNet(num_features=128,
                 act='swish',
                 maxl=3,
                 num_block=3,
                 radial_type='trainable_bessel',
                 num_radial=32,
                 attention_head=8,
                 rc=5.0,
                 dropout=0.0,
                 use_cutoff=False,
                 max_atomic_number=9,
                 atom_ref=None,
                 scale=None,
                 scalar_outsize=1,
                 irreps_out=None,
                 summation=False,
                 norm=False,
                 out_type='scalar',
                 grad_type=None ,
                 device=torch.device('cuda'))

model.train()

DetaNet(
  (Embedding): Embedding(
    (act): Swish()
    (elec_emb): Linear(in_features=16, out_features=128, bias=False)
    (nuclare_emb): Embedding(10, 128)
    (ls): Linear(in_features=128, out_features=128, bias=True)
  )
  (Radial): Radial_Basis(
    (radial): Bessel_Function()
  )
  (blocks): Sequential(
    (0): Interaction_Block(
      (message): Message(
        (Attention): Edge_Attention(
          (actq): Swish()
          (actk): Swish()
          (actv): Swish()
          (acta): Swish()
          (softmax): Softmax(dim=-1)
          (lq): Linear(in_features=128, out_features=128, bias=True)
          (lk): Linear(in_features=128, out_features=128, bias=True)
          (lv): Linear(in_features=128, out_features=256, bias=True)
          (la): Linear(in_features=256, out_features=256, bias=True)
          (lrbf): Linear(in_features=32, out_features=128, bias=False)
          (lkrbf): Linear(in_features=128, out_features=128, bias=False)
          (lvrbf): Linear(in_featu

In [7]:
'''Next, define the trainer and the parameters used for training.'''
class Trainer:
    def __init__(self,model,train_loader,val_loader=None,loss_function=l2loss,device=torch.device('cuda'),
                 optimizer='Adam_amsgrad',lr=5e-4,weight_decay=0):
        self.opt_type=optimizer
        self.device=device
        self.model=model
        self.train_data=train_loader
        self.val_data=val_loader
        self.device=device
        self.opts={'AdamW':torch.optim.AdamW(self.model.parameters(),lr=lr,amsgrad=False,weight_decay=weight_decay),
              'AdamW_amsgrad':torch.optim.AdamW(self.model.parameters(),lr=lr,amsgrad=True,weight_decay=weight_decay),
              'Adam':torch.optim.Adam(self.model.parameters(),lr=lr,amsgrad=False,weight_decay=weight_decay),
              'Adam_amsgrad':torch.optim.Adam(self.model.parameters(),lr=lr,amsgrad=True,weight_decay=weight_decay),
              'Adadelta':torch.optim.Adadelta(self.model.parameters(),lr=lr,weight_decay=weight_decay),
              'RMSprop':torch.optim.RMSprop(self.model.parameters(),lr=lr,weight_decay=weight_decay),
              'SGD':torch.optim.SGD(self.model.parameters(),lr=lr,weight_decay=weight_decay)
        }
        self.optimizer=self.opts[self.opt_type]
        self.loss_function=loss_function
        self.step=-1
    def train(self,num_train,targ,stop_loss=1e-8, val_per_train=50, print_per_epoch=10):
        self.model.train()
        len_train=len(self.train_data)
        for i in range(num_train):
            val_datas=iter(self.val_data)
            for j,batch in enumerate(self.train_data):
                self.step=self.step+1
                torch.cuda.empty_cache()
                self.optimizer.zero_grad()
                out = self.model(pos=batch.pos.to(self.device), z=batch.z.to(self.device),
                                     batch=batch.batch.to(self.device))
                target = batch[targ].to(self.device)
                loss = self.loss_function(out.reshape(target.shape),target)
                loss.backward()
                self.optimizer.step()
                if (self.step%val_per_train==0) and (self.val_data is not None):
                    val_batch = next(val_datas)
                    val_target=val_batch[targ].to(self.device).reshape(-1)

                    val_out = self.model(pos=val_batch.pos.to(self.device), z=val_batch.z.to(self.device),
                                             batch=val_batch.batch.to(self.device)).reshape(val_target.shape)
                    val_loss = self.loss_function(val_out, val_target).item()
                    val_mae=l1loss(val_out, val_target).item()
                    val_R2=R2(val_out,val_target).item()
                    if self.step % print_per_epoch==0:
                        print('Epoch[{}/{}],loss:{:.8f},val_loss:{:.8f},val_mae:{:.8f},val_R2:{:.8f}'
                              .format(self.step,num_train*len_train,loss.item(),val_loss,val_mae,val_R2))

                    assert (loss > stop_loss) or (val_loss > stop_loss),'Training and prediction Loss is less' \
                                                                        ' than cut-off Loss, so training stops'
                elif (self.step % print_per_epoch == 0) and (self.step%val_per_train!=0):
                    print('Epoch[{}/{}],loss:{:.8f}'.format(self.step,num_train*len_train, loss.item()))
                    
    def load_state_and_optimizer(self,state_path=None,optimizer_path=None):
        if state_path is not None:
            state_dict=torch.load(state_path)
            self.model.load_state_dict(state_dict)
        if optimizer_path is not None:
            self.optimizer=torch.load(optimizer_path)

    def save_param(self,path):
        torch.save(self.model.state_dict(),path)

    def save_model(self,path):
        torch.save(self.model, path)

    def save_opt(self,path):
        torch.save(self.optimizer,path)

    def params(self):
        return self.model.state_dict()
    

In [8]:
'''Then, modify the data type and device type'''
device=torch.device('cuda')
dtype=torch.float32
model=model.to(dtype)
model=model.to(device)

In [9]:
'''Finally, using the trainer, training 20 times from a 5e-4 learning rate'''
trainer=Trainer(model,train_loader=trainloader,val_loader=valloader,loss_function=l2loss,lr=1e-5,weight_decay=0,optimizer='AdamW')

In [10]:
trainer.train(num_train=20,targ='npacharge')

Epoch[0/1927],loss:0.06453520,val_loss:0.09647646,val_mae:0.26150602,val_R2:0.05411941
Epoch[10/1927],loss:0.03383980
Epoch[20/1927],loss:0.02285156
Epoch[30/1927],loss:0.01284886
Epoch[40/1927],loss:0.01130290
Epoch[50/1927],loss:0.00883452,val_loss:0.01035036,val_mae:0.06626014,val_R2:0.90387326
Epoch[60/1927],loss:0.00643805
Epoch[70/1927],loss:0.00432105
Epoch[80/1927],loss:0.00278638
Epoch[90/1927],loss:0.00219296
Epoch[100/1927],loss:0.00170563,val_loss:0.00196205,val_mae:0.02748346,val_R2:0.98102534
Epoch[110/1927],loss:0.00193614
Epoch[120/1927],loss:0.00124189
Epoch[130/1927],loss:0.00122081
Epoch[140/1927],loss:0.00109711
Epoch[150/1927],loss:0.00093220,val_loss:0.00098367,val_mae:0.02152115,val_R2:0.99056077
Epoch[160/1927],loss:0.00097455
Epoch[170/1927],loss:0.00074644
Epoch[180/1927],loss:0.00097347
Epoch[190/1927],loss:0.00087645
Epoch[200/1927],loss:0.00071156,val_loss:0.00069464,val_mae:0.01792950,val_R2:0.99369150
Epoch[210/1927],loss:0.00066975
Epoch[220/1927],loss:0

Epoch[1860/1927],loss:0.00012250
Epoch[1870/1927],loss:0.00011756
Epoch[1880/1927],loss:0.00034985
Epoch[1890/1927],loss:0.00022203
Epoch[1900/1927],loss:0.00012808,val_loss:0.00011116,val_mae:0.00736343,val_R2:0.99896818
Epoch[1910/1927],loss:0.00011418
Epoch[1920/1927],loss:0.00008459


In [11]:
'''After the training is completed, take out the learnable parameters and save them as a .pth file.'''
state_dict=model.state_dict

state_dict()

OrderedDict([('Embedding.act.alpha',
              tensor([0.9962, 0.9973, 0.9957, 0.9963, 0.9970, 0.9955, 0.9992, 0.9982, 0.9989,
                      0.9986, 0.9998, 0.9981, 0.9985, 1.0003, 0.9984, 0.9977, 0.9996, 0.9986,
                      0.9977, 0.9981, 0.9999, 0.9982, 1.0006, 1.0010, 0.9995, 0.9943, 0.9993,
                      0.9981, 1.0002, 0.9980, 0.9988, 0.9987, 0.9986, 0.9991, 0.9979, 1.0001,
                      0.9977, 0.9989, 0.9933, 1.0012, 0.9985, 0.9986, 1.0006, 1.0005, 0.9991,
                      0.9958, 0.9984, 0.9982, 0.9971, 0.9978, 0.9963, 0.9995, 0.9972, 0.9947,
                      0.9982, 0.9947, 0.9979, 1.0004, 0.9943, 0.9986, 0.9994, 0.9996, 0.9954,
                      1.0031, 0.9999, 0.9978, 0.9997, 0.9979, 1.0004, 0.9998, 0.9987, 0.9994,
                      0.9980, 0.9966, 0.9999, 0.9983, 0.9955, 0.9986, 0.9960, 0.9936, 1.0008,
                      0.9938, 0.9968, 0.9962, 0.9958, 0.9984, 0.9957, 1.0001, 0.9967, 0.9959,
                      0

In [12]:
torch.save(model.state_dict(),'npacharge_param.pth')

In [13]:
'''When needed model , simply load parameters on the untrained model. The parameters we have trained are saved in trained_Param/.'''
state_dict=torch.load('npacharge_param.pth')
model.load_state_dict(state_dict)

<All keys matched successfully>