Copyright 2021-2022 @ Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd

This code is a part of Cybertron package.

The Cybertron is open-source software based on the AI-framework:
MindSpore (https://www.mindspore.cn/)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

See the License for the specific language governing permissions and
limitations under the License.

Cybertron tutorial 06: Multi-task with multiple readouts (example 2)

In [1]:
import sys
import time
import numpy as np
import mindspore as ms
from mindspore import nn
from mindspore import Tensor
from mindspore import context
from mindspore import dataset as ds
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

from cybertron import Cybertron
from cybertron import MolCT
from cybertron.train import MAE, MLoss
from cybertron.train import WithLabelLossCell, WithLabelEvalCell
from cybertron.train import TrainMonitor
from cybertron.train import TransformerLR

# Run MindSpore in graph mode, targeting the GPU backend.
context.set_context(
    device_target="GPU",
    mode=context.GRAPH_MODE,
)

In [2]:
# Common path prefix of the normalized QM9 dataset files, built from the
# first entry of ``sys.path``.
data_name = sys.path[0] + '/dataset_qm9_normed_'
train_file, valid_file = (
    data_name + 'trainset_1024.npz',
    data_name + 'validset_128.npz',
)

train_data = np.load(train_file)
valid_data = np.load(valid_file)

# Label columns selected as training targets:
# dipole, polarizability, HOMO, LUMO, gap, R2, zpve, capacity
idx = [0, 1, 2, 3, 4, 5, 6, 11]

num_atom = int(train_data['num_atoms'])


def _as_float32(values):
    """Wrap ``values`` in a MindSpore float32 ``Tensor``."""
    return Tensor(values, ms.float32)


# Normalization constants and per-atom-type reference values stored with
# the training set, restricted to the selected label columns.
scale = _as_float32(train_data['scale'][idx])
shift = _as_float32(train_data['shift'][idx])
ref = _as_float32(train_data['type_ref'][:, idx])

In [3]:
# Build the MolCT deep molecular model. Lengths are expressed in nm, so the
# cutoff below is a distance of 1 nm.
mod = MolCT(
    # units and geometry
    length_unit='nm',
    cutoff=1,
    # network size
    dim_feature=128,
    n_interaction=3,
    n_heads=8,
    max_cycles=1,
    # non-linearity
    activation='swish',
)

In [4]:
# Wrap the MolCT model in a Cybertron network with graph-type readout.
# The entries of `dim_output` sum to 8, the number of label columns in `idx`.
# NOTE(review): presumably each entry configures one readout head - confirm
# against the Cybertron documentation.
net = Cybertron(
    mod,
    readout='graph',
    dim_output=[1, 1, 3, 1, 1, 1],
    num_atoms=num_atom,
    length_unit='nm',
)
net.print_info()

Cybertron Engine, Ride-on!
--------------------------------------------------------------------------------
    Length unit: nm
    Input unit scale: 1
--------------------------------------------------------------------------------
    Deep molecular model:  MolCT
--------------------------------------------------------------------------------
       Length unit: nm
       Atom embedding size: 64
       Cutoff distance: 1.0 nm
       Radical basis function (RBF): LogGaussianBasis
          Minimum distance: 0.04 nm
          Maximum distance: 1.0 nm
          Reference distance: 1.0 nm
          Log Gaussian begin: -3.218876
          Log Gaussian end: 0.006724119
          Interval for log Gaussian: 0.0512
          Sigma for log gaussian: 0.3
          Number of basis functions: 64
          Rescale the range of RBF to (-1,1).
       Calculate distance: Yes
       Calculate bond: No
       Feature dimension: 128
-----------------------------------------------------------------------

In [5]:
# List every parameter of the network (index, name, shape), then report the
# total number of elements over all of them.
all_params = list(net.get_parameters())
for i, param in enumerate(all_params):
    print(i, param.name, param.shape)
tot_params = sum(p.size for p in all_params)
print('Total parameters: ', tot_params)

0 model.atom_embedding.embedding_table (64, 128)
1 model.dis_filter.linear.weight (128, 64)
2 model.dis_filter.linear.bias (128,)
3 model.dis_filter.residual.nonlinear.mlp.0.weight (128, 128)
4 model.dis_filter.residual.nonlinear.mlp.0.bias (128,)
5 model.dis_filter.residual.nonlinear.mlp.1.weight (128, 128)
6 model.dis_filter.residual.nonlinear.mlp.1.bias (128,)
7 model.interactions.0.positional_embedding.norm.gamma (128,)
8 model.interactions.0.positional_embedding.norm.beta (128,)
9 model.interactions.0.positional_embedding.x2q.weight (128, 128)
10 model.interactions.0.positional_embedding.x2k.weight (128, 128)
11 model.interactions.0.positional_embedding.x2v.weight (128, 128)
12 model.interactions.0.multi_head_attention.output.weight (128, 128)
13 model.interactions.1.positional_embedding.norm.gamma (128,)
14 model.interactions.1.positional_embedding.norm.beta (128,)
15 model.interactions.1.positional_embedding.x2q.weight (128, 128)
16 model.interactions.1.positional_embedding.x2k.

In [6]:
# Training schedule: number of epochs, dataset repeat count and batch size.
n_epoch, repeat_time, batch_size = 8, 1, 32

In [7]:
# Training pipeline: columns R, Z and the selected label columns E are
# shuffled, grouped into batches of `batch_size` (an incomplete final batch
# is dropped) and repeated `repeat_time` times.
ds_train = (
    ds.NumpySlicesDataset(
        {
            'R': train_data['R'],
            'Z': train_data['Z'],
            'E': train_data['E'][:, idx],
        },
        shuffle=True,
    )
    .batch(batch_size, drop_remainder=True)
    .repeat(repeat_time)
)

In [8]:
# Validation pipeline: same columns as the training set, kept in file order
# (no shuffling) and served in batches of 128.
# NOTE(review): the file name suggests the validation set holds 128 samples,
# i.e. a single batch - confirm.
ds_valid = (
    ds.NumpySlicesDataset(
        {
            'R': valid_data['R'],
            'Z': valid_data['Z'],
            'E': valid_data['E'][:, idx],
        },
        shuffle=False,
    )
    .batch(128)
    .repeat(1)
)

In [9]:
# Training cell: network plus MAE loss, fed with the 'RZE' dataset columns.
loss_network = WithLabelLossCell(
    'RZE',
    net,
    nn.MAELoss(),
)
# Evaluation cell: same inputs and loss, additionally given the scale, shift
# and per-atom-type reference values loaded from the training set.
eval_network = WithLabelEvalCell(
    'RZE',
    net,
    nn.MAELoss(),
    scale=scale,
    shift=shift,
    type_ref=ref,
)

WithLabelLossCell with input type: RZE
WithLabelEvalCell with input type: RZE
   with scaleshift for training and evaluate dataset:
   Output.            Scale           Shift        Mode
   0:        1.503183e+00    2.672827e+00       graph
   1:        8.173762e+00    7.528103e+01       graph
   2:        5.767056e+01   -6.306687e+02       graph
   3:        1.229870e+02    3.108809e+01       graph
   4:        1.238940e+02    6.617564e+02       graph
   5:        2.804632e+02    1.189402e+03       graph
   6:        8.699451e+01    3.914438e+02       graph
   7:        6.082039e+00   -2.213512e+01       graph
   with reference value for atom types:
   Type     Label0    Label1    Label2    Label3    Label4    Label5    Label6    Label7
   0:        0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00
   1:        0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  2.98e+00
   2:        0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00

In [10]:
# Learning-rate schedule with 4000 warm-up steps, passed to Adam over all
# trainable parameters of the network.
# NOTE(review): presumably the Transformer ("Noam") warm-up/decay schedule
# scaled by `dimension` - confirm against cybertron.train.TransformerLR.
lr = TransformerLR(
    learning_rate=1.,
    warmup_steps=4000,
    dimension=128,
)
optim = nn.Adam(
    params=net.trainable_params(),
    learning_rate=lr,
)

In [11]:
# Names under which the evaluation metrics are registered.
eval_mae, atom_mae, eval_loss = 'EvalMAE', 'AtomMAE', 'Evalloss'

# MindSpore Model: trains through `loss_network` and evaluates through
# `eval_network` with three metrics.
# NOTE(review): the integer lists passed to MAE/MLoss presumably index the
# outputs of the evaluation cell - confirm against cybertron.train.
model = Model(
    loss_network,
    optimizer=optim,
    eval_network=eval_network,
    metrics={
        eval_mae: MAE([1, 2], reduce_all_dims=False),
        atom_mae: MAE([1, 2, 3], reduce_all_dims=False, averaged_by_atoms=True),
        eval_loss: MLoss(0),
    },
)


In [12]:
# Output directory, and the file-name prefix derived from it and the model name.
outdir = 'Tutorial_C06'
outname = '_'.join((outdir, net.model_name))

# Monitoring callback configured with a 32-step interval and a 32-step
# averaging window; it is given the validation set and the 'Evalloss' metric
# name.
# NOTE(review): presumably `best_ckpt_metrics` selects the metric used to
# keep the best checkpoint - confirm against cybertron.train.TrainMonitor.
record_cb = TrainMonitor(
    model,
    outname,
    per_step=32,
    avg_steps=32,
    directory=outdir,
    eval_dataset=ds_valid,
    best_ckpt_metrics=eval_loss,
)

In [13]:
# Checkpoint policy: save every 32 steps, keep at most 64 checkpoint files,
# and append the network's hyper-parameters to each checkpoint.
config_ck = CheckpointConfig(
    save_checkpoint_steps=32,
    keep_checkpoint_max=64,
    append_info=[net.hyper_param],
)
ckpoint_cb = ModelCheckpoint(
    config=config_ck,
    directory=outdir,
    prefix=outname,
)

In [14]:
# Widen NumPy's print width so arrays printed during training are wrapped
# less often.
np.set_printoptions(linewidth=200)

print("Start training ...")
beg_time = time.time()
# Train for `n_epoch` epochs with the monitoring and checkpoint callbacks;
# dataset sink mode is switched off.
model.train(
    n_epoch,
    ds_train,
    callbacks=[record_cb, ckpoint_cb],
    dataset_sink_mode=False,
)
end_time = time.time()

# Split the elapsed wall-clock seconds into hours, minutes and seconds.
used_time = end_time - beg_time
m, s = divmod(used_time, 60)
h, m = divmod(m, 60)
# Fixed the misspelled completion message ("Fininshed" -> "Finished").
print("Training Finished!")
print("Training Time: %02d:%02d:%02d" % (h, m, s))



Start training ...
Epoch: 1, Step: 32, Learning_rate: 1.0830951e-05, Last_Loss: 1.0130372, Avg_loss: 1.1669597309082747, EvalMAE: [  1.3362747   7.703095   62.00464   211.84177   127.35531   165.66887    89.63798     5.1856503], AtomMAE: [ 0.08019408  0.43752763  3.7588108  11.509322    7.733257    9.761714    4.9098487   0.32709664], Evalloss: 1.0163307189941406
Epoch: 2, Step: 64, Learning_rate: 2.2011289e-05, Last_Loss: 0.74949694, Avg_loss: 0.8905524872243404, EvalMAE: [  1.1407021   6.509083   49.054626  114.357376  111.911865  163.76294    63.18687     4.500634 ], AtomMAE: [0.06583923 0.40110385 2.9585197  6.637281   6.680557   9.658892   3.8237557  0.27761894], Evalloss: 0.7861423492431641
Epoch: 3, Step: 96, Learning_rate: 3.3191627e-05, Last_Loss: 0.8207214, Avg_loss: 0.74864068813622, EvalMAE: [  1.0361898   5.9250317  47.802467  101.428925   95.26276   159.91887    53.29        3.5662723], AtomMAE: [0.05967522 0.36806786 2.8738587  6.0552225  5.4488387  9.483607   3.2927709 