In [1]:
import atomdnn

# Floating-point precision used for reading data and training.
# atomdnn uses 'float32' by default; this notebook switches to 'float64'.
atomdnn.data_type = 'float64'

# Forces and stresses are evaluated by default;
# set compute_force to False if only the potential energy is needed.
atomdnn.compute_force = True

# Conversion factor from eV/A^3 to GPa (1 eV/A^3 ~= 160.2176 GPa).
# Sign convention: a positive predicted stress means tension, a negative one compression.
# NOTE(review): this name is not referenced by any later cell in this notebook.
stress_unit_convert = 160.2176 

import numpy as np
import tensorflow as tf
import pickle
from atomdnn import data
from atomdnn import network
from atomdnn.data import Data
from atomdnn.data import *
from atomdnn.network import Network
# import importlib
# importlib.reload(atomdnn.data)
# importlib.reload(atomdnn.network)

# Load the `Data` object from a saved pickle file

In [5]:
# Load the preprocessed graphene `Data` object that was saved earlier with pickle.
# Edit this path for your machine.
# NOTE: pickle.load can execute arbitrary code -- only load pickle files you created or trust.
GRDATA_PICKLE = "/mnt/machine_learning/data_graphene/grdata.pickle"
# `with` guarantees the file handle is closed once loading finishes.
with open(GRDATA_PICKLE, "rb") as f:
    grdata = pickle.load(f)

# Shuffle and then split the data into training, validation and test sets

### split(self, train_data_percent=None, val_data_percent=None, test_data_percent=None, data_size=None)

- **train_pct**: fraction of the data used for training (e.g. 0.7 for 70%)

- **val_pct**: fraction of the data used for validation

- **test_pct**: fraction of the data used for testing

- **data_size**: amount of data to use; if not set, the whole dataset is used


In [6]:
# Randomly reorder the images before splitting (atomdnn spells this method `shuffel`).
# The return value is discarded, so this presumably shuffles `grdata` in place.
# NOTE(review): no random seed is passed here, so the split may differ between runs.
grdata.shuffel()

In [7]:
# 70% / 20% / 10% train / validation / test split; each part is an (inputs, outputs) pair.
train_frac, val_frac, test_frac = 0.7, 0.2, 0.1
(x_train, y_train), (x_val, y_val), (x_test, y_test) = grdata.split(
    train_frac, val_frac, test_frac)

Traning data: 407 images
Validation data: 116 images
Test data: 59 images


# Build Network object from class Network 

__init__(self, elements=None, num_fingerprints=None, arch=None,activation_function=None, data_type=None, import_dir=None)

- **elements:** element list, required

- **num_fingerprints:** number of fingerprints in data, required

- **std**: `[mean, standard_deviation]` of the fingerprints; if set, the fingerprints are standardized

- **norm**: `[min, max]` of the fingerprints; if set, the fingerprints are normalized

- **arch:** list giving the number of neurons in each layer, e.g. `[50, 50]`

- **activation_function:** if not set, default is 'tanh'

- **import_dir:** directory of a saved network to import; if given, all other parameters are ignored

In [8]:
# Carbon-only network with layers of 50 and 50 neurons; the fingerprints are
# standardized with the dataset mean and standard deviation.
fp_std = [grdata.mean_fp, grdata.dev_fp]
model = Network(
    elements=['C'],
    num_fingerprints=grdata.num_fingerprints,
    std=fp_std,
    arch=[50, 50],
)

activation function is set to tanh by default.


# Train the model

**train(self, train_input_dict, train_output_dict,
              batch_size=None, epochs=None, loss_fn=None, optimizer=None, lr=None, train_force=False, train_stress=False)**

- **train_input_dict**: input dictionary generated from build_dataset() for training
    
- **train_output_dict**: output dictionary generated from build_dataset() for training
    
- **batch_size**: if not set, use 30
    
- **epochs**: if not set, use 1
    
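- **optimizer**: if not set, Adam is used
    
- **lr**: learning rate; if not set, 0.01 is used
    
- **train_force**: whether forces are used for training
    
- **train_stress**: whether stresses are used for training

- **validation_data**, **pe_loss_weight**, **force_loss_weight**, **stress_loss_weight**: also accepted (see the call below). The reported total loss is the weighted sum of the energy, force and stress losses.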

In [10]:
# Fit energies, forces and stresses jointly. Each weight scales its term in the total loss.
# The raw stress loss is orders of magnitude larger than the others, hence its small weight.
model.train(
    x_train, y_train,
    validation_data=[x_val, y_val],
    batch_size=30,
    epochs=500,
    train_force=True,
    train_stress=True,
    pe_loss_weight=0.1,
    force_loss_weight=1,
    stress_loss_weight=0.0001,
)

Forces are used for training.
Stresses are used for training.

===> Epoch 1/500 - 2.365s/epoch
     training_loss    - pe: 144.518 - force: 21.407 - stress: 577030.049 - total: 93.562
     validation_loss  - pe: 116.097 - force: 11.010 - stress: 500661.376 - total: 72.686

===> Epoch 2/500 - 2.244s/epoch
     training_loss    - pe: 85.549 - force: 11.376 - stress: 307300.891 - total: 50.661
     validation_loss  - pe: 63.677 - force: 10.494 - stress: 397801.476 - total: 56.642

===> Epoch 3/500 - 2.365s/epoch
     training_loss    - pe: 150.463 - force: 11.176 - stress: 455915.298 - total: 71.814
     validation_loss  - pe: 107.334 - force: 10.221 - stress: 383299.662 - total: 59.284

===> Epoch 4/500 - 2.350s/epoch
     training_loss    - pe: 119.560 - force: 17.163 - stress: 524645.216 - total: 81.583
     validation_loss  - pe: 138.610 - force: 11.795 - stress: 432040.820 - total: 68.861

===> Epoch 5/500 - 2.385s/epoch
     training_loss    - pe: 122.708 - force: 11.755 - stress: 5


===> Epoch 40/500 - 2.000s/epoch
     training_loss    - pe: 135.809 - force: 14.301 - stress: 766307.001 - total: 104.513
     validation_loss  - pe: 105.165 - force: 9.264 - stress: 375017.909 - total: 57.283

===> Epoch 41/500 - 2.018s/epoch
     training_loss    - pe: 121.722 - force: 16.816 - stress: 885415.810 - total: 117.530
     validation_loss  - pe: 83.048 - force: 7.660 - stress: 449022.767 - total: 60.867

===> Epoch 42/500 - 2.046s/epoch
     training_loss    - pe: 139.156 - force: 25.887 - stress: 555298.632 - total: 95.333
     validation_loss  - pe: 47.655 - force: 8.568 - stress: 365202.644 - total: 49.854

===> Epoch 43/500 - 2.031s/epoch
     training_loss    - pe: 133.925 - force: 16.914 - stress: 594778.709 - total: 89.785
     validation_loss  - pe: 78.435 - force: 7.466 - stress: 336680.431 - total: 48.978

===> Epoch 44/500 - 2.035s/epoch
     training_loss    - pe: 81.986 - force: 9.251 - stress: 333121.879 - total: 50.762
     validation_loss  - pe: 101.832 


===> Epoch 80/500 - 2.078s/epoch
     training_loss    - pe: 80.519 - force: 9.713 - stress: 466223.797 - total: 64.387
     validation_loss  - pe: 68.835 - force: 6.261 - stress: 221172.517 - total: 35.262

===> Epoch 81/500 - 2.052s/epoch
     training_loss    - pe: 94.501 - force: 8.153 - stress: 350410.763 - total: 52.644
     validation_loss  - pe: 54.437 - force: 5.800 - stress: 440596.152 - total: 55.303

===> Epoch 82/500 - 2.069s/epoch
     training_loss    - pe: 123.724 - force: 11.138 - stress: 606594.287 - total: 84.170
     validation_loss  - pe: 33.256 - force: 5.879 - stress: 277944.534 - total: 36.999

===> Epoch 83/500 - 2.089s/epoch
     training_loss    - pe: 110.873 - force: 13.866 - stress: 556862.720 - total: 80.639
     validation_loss  - pe: 51.901 - force: 5.378 - stress: 326097.800 - total: 43.178

===> Epoch 84/500 - 2.047s/epoch
     training_loss    - pe: 126.770 - force: 4.878 - stress: 465744.810 - total: 64.129
     validation_loss  - pe: 65.879 - force


===> Epoch 120/500 - 1.981s/epoch
     training_loss    - pe: 83.691 - force: 3.465 - stress: 320787.611 - total: 43.913
     validation_loss  - pe: 34.456 - force: 5.454 - stress: 215566.948 - total: 30.457

===> Epoch 121/500 - 1.977s/epoch
     training_loss    - pe: 103.087 - force: 8.217 - stress: 577838.650 - total: 76.310
     validation_loss  - pe: 26.790 - force: 4.898 - stress: 419925.843 - total: 49.569

===> Epoch 122/500 - 1.992s/epoch
     training_loss    - pe: 72.062 - force: 9.449 - stress: 412318.841 - total: 57.887
     validation_loss  - pe: 17.328 - force: 5.183 - stress: 253656.421 - total: 32.282

===> Epoch 123/500 - 2.007s/epoch
     training_loss    - pe: 84.565 - force: 6.906 - stress: 328918.366 - total: 48.254
     validation_loss  - pe: 24.795 - force: 4.692 - stress: 298058.262 - total: 36.977

===> Epoch 124/500 - 1.991s/epoch
     training_loss    - pe: 109.610 - force: 8.791 - stress: 458263.844 - total: 65.578
     validation_loss  - pe: 31.612 - for


===> Epoch 160/500 - 2.035s/epoch
     training_loss    - pe: 86.958 - force: 7.833 - stress: 487673.278 - total: 65.296
     validation_loss  - pe: 9.973 - force: 4.680 - stress: 196012.566 - total: 25.279

===> Epoch 161/500 - 2.043s/epoch
     training_loss    - pe: 85.547 - force: 4.499 - stress: 362008.277 - total: 49.255
     validation_loss  - pe: 8.439 - force: 4.313 - stress: 414420.157 - total: 46.599

===> Epoch 162/500 - 2.068s/epoch
     training_loss    - pe: 76.558 - force: 3.593 - stress: 348855.117 - total: 46.135
     validation_loss  - pe: 7.056 - force: 4.486 - stress: 246734.317 - total: 29.865

===> Epoch 163/500 - 2.052s/epoch
     training_loss    - pe: 93.648 - force: 7.681 - stress: 540508.772 - total: 71.096
     validation_loss  - pe: 6.443 - force: 3.614 - stress: 282623.186 - total: 32.521

===> Epoch 164/500 - 2.870s/epoch
     training_loss    - pe: 66.144 - force: 8.934 - stress: 358080.337 - total: 51.356
     validation_loss  - pe: 8.725 - force: 4.5


===> Epoch 200/500 - 2.000s/epoch
     training_loss    - pe: 75.126 - force: 7.712 - stress: 401354.475 - total: 55.360
     validation_loss  - pe: 8.795 - force: 4.173 - stress: 219462.467 - total: 26.999

===> Epoch 201/500 - 2.019s/epoch
     training_loss    - pe: 76.882 - force: 6.073 - stress: 474011.410 - total: 61.162
     validation_loss  - pe: 8.582 - force: 3.862 - stress: 427053.564 - total: 47.426

===> Epoch 202/500 - 2.009s/epoch
     training_loss    - pe: 76.577 - force: 7.115 - stress: 467910.986 - total: 61.563
     validation_loss  - pe: 7.574 - force: 4.075 - stress: 229555.489 - total: 27.787

===> Epoch 203/500 - 2.031s/epoch
     training_loss    - pe: 76.213 - force: 4.053 - stress: 323937.727 - total: 44.068
     validation_loss  - pe: 6.857 - force: 3.337 - stress: 296335.386 - total: 33.656

===> Epoch 204/500 - 2.118s/epoch
     training_loss    - pe: 68.117 - force: 3.354 - stress: 293648.630 - total: 39.531
     validation_loss  - pe: 8.555 - force: 3.9


===> Epoch 240/500 - 2.013s/epoch
     training_loss    - pe: 51.942 - force: 5.316 - stress: 262019.266 - total: 36.713
     validation_loss  - pe: 7.426 - force: 3.635 - stress: 186372.367 - total: 23.015

===> Epoch 241/500 - 2.017s/epoch
     training_loss    - pe: 80.686 - force: 6.094 - stress: 378058.085 - total: 51.968
     validation_loss  - pe: 6.865 - force: 3.397 - stress: 395833.867 - total: 43.667

===> Epoch 242/500 - 2.047s/epoch
     training_loss    - pe: 66.710 - force: 7.104 - stress: 369796.935 - total: 50.755
     validation_loss  - pe: 6.912 - force: 3.596 - stress: 232447.746 - total: 27.532

===> Epoch 243/500 - 2.031s/epoch
     training_loss    - pe: 65.056 - force: 5.546 - stress: 517983.271 - total: 63.850
     validation_loss  - pe: 6.178 - force: 3.011 - stress: 272340.036 - total: 30.863

===> Epoch 244/500 - 2.025s/epoch
     training_loss    - pe: 65.903 - force: 6.427 - stress: 492704.710 - total: 62.287
     validation_loss  - pe: 8.272 - force: 3.5


===> Epoch 280/500 - 2.239s/epoch
     training_loss    - pe: 70.631 - force: 4.629 - stress: 446787.650 - total: 56.371
     validation_loss  - pe: 7.396 - force: 3.309 - stress: 191258.051 - total: 23.174

===> Epoch 281/500 - 2.066s/epoch
     training_loss    - pe: 68.167 - force: 8.234 - stress: 262300.214 - total: 41.281
     validation_loss  - pe: 7.800 - force: 3.188 - stress: 402483.450 - total: 44.216

===> Epoch 282/500 - 2.013s/epoch
     training_loss    - pe: 45.246 - force: 4.784 - stress: 259051.958 - total: 35.213
     validation_loss  - pe: 7.024 - force: 3.314 - stress: 229147.146 - total: 26.932

===> Epoch 283/500 - 2.078s/epoch
     training_loss    - pe: 69.477 - force: 4.987 - stress: 366374.196 - total: 48.573
     validation_loss  - pe: 6.075 - force: 2.731 - stress: 275625.440 - total: 30.901

===> Epoch 284/500 - 2.081s/epoch
     training_loss    - pe: 58.058 - force: 6.307 - stress: 347513.471 - total: 46.864
     validation_loss  - pe: 7.855 - force: 3.2


===> Epoch 320/500 - 2.091s/epoch
     training_loss    - pe: 59.216 - force: 5.916 - stress: 374455.564 - total: 49.283
     validation_loss  - pe: 7.544 - force: 2.868 - stress: 197304.826 - total: 23.353

===> Epoch 321/500 - 2.100s/epoch
     training_loss    - pe: 51.389 - force: 6.570 - stress: 367286.377 - total: 48.437
     validation_loss  - pe: 8.930 - force: 2.959 - stress: 400565.431 - total: 43.909

===> Epoch 322/500 - 2.145s/epoch
     training_loss    - pe: 60.695 - force: 3.503 - stress: 401939.832 - total: 49.767
     validation_loss  - pe: 8.757 - force: 3.105 - stress: 229596.275 - total: 26.940

===> Epoch 323/500 - 2.089s/epoch
     training_loss    - pe: 58.578 - force: 7.654 - stress: 259090.926 - total: 39.421
     validation_loss  - pe: 7.719 - force: 2.550 - stress: 275460.330 - total: 30.868

===> Epoch 324/500 - 2.093s/epoch
     training_loss    - pe: 38.053 - force: 4.772 - stress: 257404.823 - total: 34.318
     validation_loss  - pe: 8.661 - force: 2.9


===> Epoch 360/500 - 2.166s/epoch
     training_loss    - pe: 35.443 - force: 5.262 - stress: 344973.725 - total: 43.303
     validation_loss  - pe: 8.093 - force: 2.702 - stress: 186892.728 - total: 22.201

===> Epoch 361/500 - 2.108s/epoch
     training_loss    - pe: 40.230 - force: 4.730 - stress: 274278.300 - total: 36.181
     validation_loss  - pe: 7.056 - force: 2.787 - stress: 391668.866 - total: 42.659

===> Epoch 362/500 - 2.084s/epoch
     training_loss    - pe: 48.215 - force: 5.555 - stress: 415951.228 - total: 51.972
     validation_loss  - pe: 6.686 - force: 2.811 - stress: 227880.923 - total: 26.267

===> Epoch 363/500 - 2.108s/epoch
     training_loss    - pe: 42.048 - force: 6.324 - stress: 352036.093 - total: 45.732
     validation_loss  - pe: 5.343 - force: 2.393 - stress: 272455.689 - total: 30.172

===> Epoch 364/500 - 2.090s/epoch
     training_loss    - pe: 49.293 - force: 3.307 - stress: 405685.229 - total: 48.805
     validation_loss  - pe: 7.481 - force: 2.7


===> Epoch 400/500 - 2.131s/epoch
     training_loss    - pe: 31.416 - force: 2.783 - stress: 299387.693 - total: 35.863
     validation_loss  - pe: 7.224 - force: 2.563 - stress: 189619.638 - total: 22.247

===> Epoch 401/500 - 2.107s/epoch
     training_loss    - pe: 36.483 - force: 4.658 - stress: 563396.279 - total: 64.646
     validation_loss  - pe: 7.267 - force: 2.592 - stress: 393911.723 - total: 42.709

===> Epoch 402/500 - 2.130s/epoch
     training_loss    - pe: 27.909 - force: 6.073 - stress: 353387.339 - total: 44.202
     validation_loss  - pe: 6.880 - force: 2.688 - stress: 230684.978 - total: 26.444

===> Epoch 403/500 - 2.109s/epoch
     training_loss    - pe: 32.527 - force: 4.645 - stress: 284801.660 - total: 36.377
     validation_loss  - pe: 5.594 - force: 2.321 - stress: 268116.433 - total: 29.692

===> Epoch 404/500 - 2.142s/epoch
     training_loss    - pe: 37.694 - force: 5.558 - stress: 459722.051 - total: 55.299
     validation_loss  - pe: 7.522 - force: 2.4


===> Epoch 440/500 - 2.163s/epoch
     training_loss    - pe: 24.497 - force: 4.698 - stress: 448319.132 - total: 51.979
     validation_loss  - pe: 6.969 - force: 2.345 - stress: 188669.004 - total: 21.909

===> Epoch 441/500 - 2.160s/epoch
     training_loss    - pe: 26.423 - force: 2.920 - stress: 311680.709 - total: 36.730
     validation_loss  - pe: 8.913 - force: 2.453 - stress: 390147.136 - total: 42.360

===> Epoch 442/500 - 2.173s/epoch
     training_loss    - pe: 22.692 - force: 2.876 - stress: 295487.821 - total: 34.694
     validation_loss  - pe: 9.062 - force: 2.566 - stress: 226049.071 - total: 26.078

===> Epoch 443/500 - 2.170s/epoch
     training_loss    - pe: 26.110 - force: 4.601 - stress: 478879.356 - total: 55.100
     validation_loss  - pe: 7.851 - force: 2.243 - stress: 264089.348 - total: 29.437

===> Epoch 444/500 - 2.089s/epoch
     training_loss    - pe: 23.952 - force: 4.967 - stress: 288484.696 - total: 36.211
     validation_loss  - pe: 8.659 - force: 2.3


===> Epoch 480/500 - 1.983s/epoch
     training_loss    - pe: 18.173 - force: 4.863 - stress: 340286.124 - total: 40.709
     validation_loss  - pe: 7.145 - force: 2.269 - stress: 184725.737 - total: 21.456

===> Epoch 481/500 - 1.994s/epoch
     training_loss    - pe: 16.968 - force: 3.951 - stress: 439683.025 - total: 49.616
     validation_loss  - pe: 7.206 - force: 2.358 - stress: 391254.316 - total: 42.204

===> Epoch 482/500 - 2.025s/epoch
     training_loss    - pe: 15.776 - force: 4.182 - stress: 408018.892 - total: 46.561
     validation_loss  - pe: 6.963 - force: 2.467 - stress: 225362.421 - total: 25.699

===> Epoch 483/500 - 2.008s/epoch
     training_loss    - pe: 18.886 - force: 2.715 - stress: 308195.591 - total: 35.424
     validation_loss  - pe: 5.508 - force: 2.182 - stress: 266691.760 - total: 29.402

===> Epoch 484/500 - 1.990s/epoch
     training_loss    - pe: 15.001 - force: 2.613 - stress: 297121.687 - total: 33.826
     validation_loss  - pe: 7.423 - force: 2.2

In [None]:
# Evaluate the trained model on the held-out test split.
model.evaluate(x_test,y_test)

In [None]:
# Take a subset of the full dataset's inputs (start=0, end=60 -- presumably the first 60 images).
# NOTE(review): a later cell calls `data.slice_dict` for the same kind of slicing; confirm
# which helper atomdnn.data actually provides.
image = data.slice(grdata.input_dict,0,60)

In [None]:
# Model predictions for the sliced images.
model.predict(image)

In [None]:
# Reference outputs for the same slice, for comparison with the predictions above.
data.slice(grdata.output_dict,0,60)

# Prediction: compute potential energy, force and stress

**predict (self, input_dict, training=False,compute_force=True)**

- **input_dict**: input dictionary generated from build_dataset function
    
- **training**: leave as False (the default) for prediction
    
- **compute_force**: whether to compute forces (and stresses); if True, fingerprint-derivative data are required in `input_dict`

In [None]:
# Reduce the predicted test-set stresses from 9 components to the 6 independent ones.
# Assuming the 9 components are a row-major 3x3 tensor, the mask keeps xx, xy, xz, yy, yz, zz.
# The stresses are computed here from the model, so this cell does not depend on a
# `stress_predict` variable left over from a cell further down.
stress_predict = tf.convert_to_tensor(model.predict(x_test)['stress'])
# Flatten each image's stress to 9 components (a no-op if it is already shaped (batch, 9)).
stress_flat = tf.reshape(stress_predict, [-1, 9])
mask = [True,True,True,False,True,True,False,False,True]
tf.reshape(tf.boolean_mask(stress_flat, mask, axis=1), [-1, 6])

In [None]:
# Mean absolute error of the model on the test split, for each predicted quantity.
loss_fn = tf.keras.losses.get('mae')

# Run the model once and reuse its output for all three quantities.
test_predict = model.predict(x_test)
pe_predict = test_predict['pe']
force_predict = test_predict['force']
stress_predict = test_predict['stress']

# Keras losses take (y_true, y_pred). reduce_mean collapses any remaining per-image /
# per-atom axes, so each quantity is reported as a single number.
print('pe MAE:    ', tf.reduce_mean(loss_fn(y_test['pe'], pe_predict)))
print('force MAE: ', tf.reduce_mean(loss_fn(y_test['force'], force_predict)))
print('stress MAE:', tf.reduce_mean(loss_fn(y_test['stress'], stress_predict)))

In [None]:
# Raw model outputs (via __call__, not predict) for the test inputs.
# NOTE(review): `.__call__` is used explicitly throughout; kept as-is in case the loaded
# object exposes it as an attribute rather than through normal call syntax.
model.__call__(x_test)

# Save trained model

**save(obj, model_dir, descriptor=None)**

- **obj**: Network object

- **model_dir**: directory for saving the trained model

- **descriptor**: descriptor parameters used to generate fingerprints, if set, a parameters file is generated for LAMMPS simulation

In [None]:
# ACSF descriptor parameters that were used to generate the fingerprints.
# Passing them to network.save also writes a parameter file for LAMMPS.
descriptor = {
    'name': 'acsf',
    'cutoff': 6.5001,
    'etaG2': [0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5,
              0.6, 0.8, 1, 1.5, 2, 3, 5, 10],
    'etaG4': [0.01],
    'zeta': [0.08, 0.1, 0.15, 0.2, 0.3, 0.35, 0.5, 0.6, 0.8, 1., 1.5,
             2., 3.0, 4., 5.5, 7.0, 10.0, 25.0, 50.0, 100.0],
    'lambda': [1.0, -1.0],
}

save_dir = 'graphene_24atoms.tfdnn'
network.save(model, save_dir, descriptor=descriptor)

# Load the saved model for continued training and prediction

**load(model_dir)**

- **model_dir**: saved model directory

In [None]:
# Reload the network saved above (e.g. in a fresh kernel) for further training / prediction.
save_dir = 'graphene_24atoms.tfdnn'
model = network.load(save_dir)
# Several later cells refer to the reloaded network as `new_model`. Bind that name too,
# so those cells do not fail with a NameError.
new_model = model

In [None]:
# Print the signature information stored with the saved model.

network.print_signature(save_dir)

In [None]:
# Read the fingerprints and fingerprint derivatives of a single configuration from the
# `.200` LAMMPS dump files of the 96-atom graphene dataset.
fp_file = 'data_graphene_96atoms/dump_fingerprints.200'
der_file = 'data_graphene_96atoms/dump_fingerprints_der.200'
onedata = data.read_inputdata_from_lmp(
    batch_mode=False,
    fp_filename=fp_file,
    der_filename=der_file,
)

In [None]:
# Raw model outputs for the single configuration; 'atom_pe' holds the per-atom potential energies.
# Uses `model` (the name bound by the load cell); `new_model` is not assigned there.

model.__call__(onedata.input_dict)

In [None]:
# Compute the per-atom stress of the single configuration.
# The model's 'stress' output is per block. Multiplying by a one-hot matrix of each block's
# center atom presumably sums the block contributions onto their center atoms.
# Assumes center_atom_id is (batch, num_blocks) and stress is (batch, num_blocks, n) -- TODO confirm.

centerid = onedata.input_dict['center_atom_id']

center_one_hot = tf.one_hot(centerid, depth=onedata.num_blocks, axis=1, dtype=onedata.data_type)

stress_block = model.__call__(onedata.input_dict)['stress']

stress_peratom = tf.matmul(center_one_hot, stress_block)

# eV/A^3 -> bar: 1 eV/A^3 = 160.2176 GPa = 1.602176e6 bar.
# NOTE(review): assumes the per-atom stress is in eV/A^3.
evA2bar = 1602176

for i in range(0, onedata.num_atoms):
    print(i+1, "          ", stress_peratom[0][i].numpy()*evA2bar)
    print('\n')

In [None]:
# Per-atom potential energies of image 0 (the single configuration read above).
atom_pe = model.__call__(onedata.input_dict)['atom_pe'][0]

for i in range(0, onedata.num_atoms):
    # One value per line: 1-based atom id and its energy.
    print("%d:   %.6g" % (i+1, atom_pe[i].numpy()))

In [None]:
# Print the raw 'stress' output of the model for image 0, row by row.
# NOTE(review): this output appears to be per neighbor block (see the per-atom stress cell),
# so row i is not necessarily atom i+1. Slicing to num_atoms avoids an IndexError if
# there are fewer rows.
stress = model.__call__(onedata.input_dict)['stress'][0]

for i, stress_row in enumerate(stress[:onedata.num_atoms]):
    print("%d:" % (i+1), stress_row.numpy())

In [None]:
# Predicted outputs (including 'force', used in the next cell) for the single configuration.
model.predict(onedata.input_dict)

In [None]:
# Per-atom forces of image 0, printed as a table whose header aligns with its rows.
force = model.predict(onedata.input_dict)['force'][0]

print("%7s %15s %15s %15s" % ("atom_id", "f_x", "f_y", "f_z"))
for i in range(0, onedata.num_atoms):
    f_x, f_y, f_z = force[i].numpy()
    print("%7d %15.6f %15.6f %15.6f" % (i+1, f_x, f_y, f_z))

In [None]:
# Continue training the reloaded model. It uses the same train/validation split and the same
# energy/force/stress loss setup as the initial training run, so that forces and stresses
# are not dropped from the objective.

model.train(x_train, y_train, validation_data=[x_val, y_val], batch_size=30, epochs=100,
            train_force=True, train_stress=True,
            pe_loss_weight=0.1, force_loss_weight=1, stress_loss_weight=0.0001)

In [None]:
# Predictions of the further-trained model on the held-out test split.
model.predict(x_test)

# Check C_inference

In [None]:
# Python-side prediction for a single image, as a reference for the C-inference check in this section.
# Slices start=0, end=1 -- presumably the first image (index 0), not the second.
image = data.slice_dict (grdata.input_dict,0,1)
model.predict(image)

In [None]:
# Inspect neighbor ids and neighbor coordinates at index 23 of image 0.
for key in ('neighbor_atom_id', 'neighbor_atom_coord'):
    print(image[key][0][23])

In [None]:
# Display the raw input dictionary of the sliced image.
image

In [None]:
# Raw model outputs (via __call__) for the same image.
model.__call__(image)

In [None]:
# Run the standalone C inference executable on the 96-atom fingerprint dump.
# Arguments: saved model directory, fingerprint file, and presumably the atom count (96).
# NOTE(review): this points at graphene_energy.tfdnn, not the graphene_24atoms.tfdnn saved above.
!../c_inference/inference_energy "../example/graphene_energy.tfdnn/" "../example/data_graphene_96atoms/dump_fingerprints.200" 96  

# Save dataset

In [None]:
# Wrap the full input/output dictionaries as a tf.data.Dataset. from_tensor_slices splits
# along the first axis, presumably giving one element per image.
tfdataset = tf.data.Dataset.from_tensor_slices((grdata.input_dict,grdata.output_dict))

In [None]:
# Save the dataset to the 'graphene_tfdataset' directory.
# NOTE(review): tf.data.experimental.save is deprecated in newer TensorFlow releases in favour
# of `Dataset.save`; switch once the minimum TF version allows it.
tf.data.experimental.save(tfdataset, 'graphene_tfdataset')

In [None]:
# Reload the saved dataset, passing the element_spec of the dataset that was saved.
newdataset = tf.data.experimental.load('graphene_tfdataset',element_spec=tfdataset.element_spec)

# Debugging

In [None]:
# Load a separate saved model (graphene_energy.tfdnn), presumably energy-only, for debugging.
pe_model = network.load('graphene_energy.tfdnn')

In [None]:
# Predict without forces (compute_force=False), so no derivative data are needed.
# Uses `onedata` from the earlier single-configuration read cell.
pe_model.predict(onedata.input_dict,compute_force=False)

In [None]:
# Re-read the same single configuration (same files as the earlier read cell).
fp_file = 'data_graphene_96atoms/dump_fingerprints.200'
der_file = 'data_graphene_96atoms/dump_fingerprints_der.200'
onedata = data.read_inputdata_from_lmp(
    batch_mode=False,
    fp_filename=fp_file,
    der_filename=der_file,
)

In [None]:
# Scratch check: an integer written as a numeric string parses via float -> int.
int(float('2'))

In [None]:
# Predicted outputs of `model` for the single configuration.
model.predict(onedata.input_dict)

In [None]:
# Print the predicted force on each atom of the single configuration.
# predict() runs once, outside the loop, instead of once per atom; the atom count
# comes from the data rather than being hard-coded.
force_predict_one = model.predict(onedata.input_dict)['force'][0]
for i in range(onedata.num_atoms):
    print(i+1, force_predict_one[i].numpy())

In [None]:
# Fingerprint entry [0][1] -- presumably the fingerprint vector of atom index 1 in image 0.
onedata.input_dict['fingerprints'][0][1]

In [None]:
# First component of that fingerprint, printed in scientific notation with 8 decimals.
print('%.8e'%onedata.input_dict['fingerprints'][0][1][0].numpy())