In [1]:
from sklearn.metrics import mean_squared_error
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from ase.visualize import view
import ase.io
from amp import Amp
from amp.model.neuralnetwork import NeuralNetwork
from amp.model import LossFunction
import operator
import amptorch
import copy
import matplotlib
from skorch import NeuralNetRegressor
from skorch.dataset import CVSplit
from skorch.callbacks import Checkpoint, EpochScoring
from skorch.callbacks.lr_scheduler import LRScheduler
import skorch.callbacks.base
from amptorch.gaussian import SNN_Gaussian
from amptorch.model import BPNN, CustomMSELoss
from amptorch.data_preprocess import AtomsDataset, factorize_data, collate_amp, TestDataset
from amptorch.skorch_model import AMP
from amptorch.skorch_model.utils import target_extractor, energy_score, forces_score
from amptorch.analysis import parity_plot
from torch.utils.data import DataLoader
from torch.nn import init
from skorch.utils import to_numpy
import matplotlib.pyplot as plt

In [2]:
def Select_image(images):
    train_image = []
    test_image = []
    for index,image in enumerate(images[:]):
        if index-index//3*3 == 0:
            test_image.append(image)
        else:
            train_image.append(image)
    print(len(train_image), len(test_image))
    return train_image, test_image


In [3]:
def run_model(assignments, numberi=0):
    class train_end_load_best_valid_loss(skorch.callbacks.base.Callback):
        def on_train_end(self, net, X, y):
            net.load_params('valid_best_params.pt')
    out_file = open('log-sigopt','a')
    # LR_schedule = LRScheduler("CosineAnnealingLR", T_max=assignments["T_max"])
    cp = Checkpoint(monitor='valid_loss_best', fn_prefix='valid_best_')
    load_best_valid_loss = train_end_load_best_valid_loss()
    images = ase.io.read('./traj_taged_adsorptionenergy.traj', index=':')
    Gs = {}
    Gs["G2_etas"] = np.logspace(np.log10(0.05), np.log10(5.0), num=4)
    Gs["G2_rs_s"] = [0] * 4
    Gs["G4_etas"] = [0.005]
    Gs["G4_zetas"] = [1.0]
    Gs["G4_gammas"] = [+1.0, -1]
    Gs["cutoff"] = 6.5
    forcetraining = False
    label = "sigopt-zeolite"
    train_images, test_images = Select_image(images)

    DFT_energies_test = [image.get_potential_energy() for image in test_images]
    DFT_energies_train = [image.get_potential_energy() for image in train_images]
    training_data = AtomsDataset(train_images, SNN_Gaussian, Gs, forcetraining=forcetraining,
            label=label, cores=1, delta_data=None, specific_atoms=True)

    unique_atoms = training_data.elements
    fp_length = training_data.fp_length
    device = "cpu"
    torch.set_num_threads(1)
    optimizer = optim.LBFGS
    batch_size = len(training_data)
    net = NeuralNetRegressor(
        module=BPNN(
            unique_atoms,
            [fp_length, assignments["hidden_layers"], assignments["num_nodes"]],
            device,
            forcetraining=forcetraining,
        ),
        criterion=CustomMSELoss,
        criterion__force_coefficient=0,
        optimizer=optimizer,
        #optimizer=torch.optim.LBFGS,
        lr=assignments["learning_rate"],
        #lr=1e-1,
        batch_size=batch_size,
        max_epochs=assignments["epochs"],
        iterator_train__collate_fn=collate_amp,
        iterator_train__shuffle=False,
        iterator_valid__collate_fn=collate_amp,
        iterator_valid__shuffle=False,
        device=device,
        train_split=CVSplit(cv=0.2),
        callbacks=[
            EpochScoring(
                energy_score,
                on_train=True,
                use_caching=True,
                target_extractor=target_extractor,
            ),
             cp,
             load_best_valid_loss,
            # LR_schedule
        ],
    )
    calc = AMP(training_data, net, label,specific_atoms=True)
    calc.train(overwrite=True)
    for image in test_images:
        image.set_calculator(calc)
    for image in train_images:
        image.set_calculator(calc)
    pred_energies_test = [image.get_potential_energy() for image in test_images]
    pred_energies_train = [image.get_potential_energy() for image in train_images]
    for e_t in pred_energies_test:
        if np.isnan(e_t):
            e_t = 1e10
    for e_tr in pred_energies_train:
        if np.isnan(e_tr):
            e_tr =1e10
    energy_rmse_test = mean_squared_error(pred_energies_test, DFT_energies_test)
    energy_rmse_train = mean_squared_error(pred_energies_train, DFT_energies_train)
    print('***************', file=out_file)
    print('No. ', numberi, file=out_file)
    print('**RMSE_TEST:', energy_rmse_test,file=out_file)
    print('**RMSE_Train:', energy_rmse_train,file=out_file)
    print('Energies:::::::', file=out_file)
    print('train data',file=out_file)
    for i in range(len(train_images)):
        print('(',DFT_energies_train[i],';',pred_energies_train[i],') ', end='',file=out_file)
    print('****',file=out_file)
    print('test data', file=out_file)
    for i in range(len(test_images)):
        print('(',DFT_energies_test[i],';',pred_energies_test[i],') ', end='',file=out_file)
    print('****',file=out_file)
    fig = plt.figure('ML vs DFT')
    ax = fig.add_subplot(111)
    ax.scatter(DFT_energies_train, pred_energies_train, c='b')
    ax.scatter(DFT_energies_test, pred_energies_test, c='r')
    ax.set_xlabel('E_DFT')
    ax.set_ylabel('E_ML')
    figure_name = './DFTvsML'+ str(numberi)+ '.png'
    fig.savefig(figure_name, bbox_inches='tight')
    fig = 0
    return energy_rmse_test, energy_rmse_train

In [4]:
assignments = {
        'epochs': 995,
        'learning_rate': 0.02315667835660472,
        'hidden_layers': 7,
        'num_nodes': 36
        }

In [5]:
run_model(assignments)

720 361
Calculating fingerprints...
Fingerprints Calculated!
  epoch    energy_score    train_loss    valid_loss    cp     dur
-------  --------------  ------------  ------------  ----  ------
      1          [36m0.6139[0m      [32m343.9013[0m       [35m13.0896[0m     +  1.9110
      2          [36m0.5810[0m      [32m308.0504[0m       19.9706        1.8890
      3          [36m0.5512[0m      [32m277.2128[0m       25.9631        1.9574
      4          [36m0.5105[0m      [32m237.8072[0m       19.2430        1.8756
      5          [36m0.4880[0m      [32m217.2871[0m       16.2464        1.9177
      6          [36m0.4775[0m      [32m208.0200[0m       14.3731        1.9798
      7          [36m0.4494[0m      [32m184.2796[0m       15.3259        1.9570
      8          [36m0.3982[0m      [32m144.6816[0m       16.3194        1.9870
      9          [36m0.3592[0m      [32m117.7025[0m       14.9146        1.8919
     10          [36m0.3314[0m      [32

     96          [36m0.0507[0m        [32m2.3475[0m       11.1698        1.9710
     97          [36m0.0505[0m        [32m2.3313[0m       10.9648        1.9589
     98          [36m0.0502[0m        [32m2.3018[0m       10.7024        1.8830
     99          [36m0.0497[0m        [32m2.2579[0m       10.3863        1.9563
    100          [36m0.0492[0m        [32m2.2097[0m       10.1128        1.9727
    101          [36m0.0488[0m        [32m2.1728[0m        9.9165        1.9368
    102          [36m0.0485[0m        [32m2.1446[0m        9.6813        1.9899
    103          [36m0.0480[0m        [32m2.1061[0m        9.4619        1.9196
    104          [36m0.0475[0m        [32m2.0596[0m        9.3107        1.9783
    105          [36m0.0471[0m        [32m2.0256[0m        9.2312        2.0197
    106          [36m0.0469[0m        [32m2.0038[0m        9.1780        1.9709
    107          [36m0.0467[0m        [32m1.9888[0m        9.1445       

    193          [36m0.0302[0m        [32m0.8336[0m        8.5986        1.9483
    194          [36m0.0300[0m        [32m0.8224[0m        8.6299        1.9223
    195          [36m0.0299[0m        [32m0.8141[0m        8.6348        1.9834
    196          [36m0.0297[0m        [32m0.8058[0m        8.6115        1.9498
    197          [36m0.0295[0m        [32m0.7947[0m        8.5656        1.9738
    198          [36m0.0293[0m        [32m0.7852[0m        8.4883        1.9122
    199          [36m0.0292[0m        [32m0.7757[0m        8.4205        1.9403
    200          [36m0.0290[0m        [32m0.7674[0m        8.3558        1.9138
    201          [36m0.0289[0m        [32m0.7605[0m        8.3159        1.9829
    202          [36m0.0288[0m        [32m0.7545[0m        8.2995        1.9984
    203          [36m0.0286[0m        [32m0.7487[0m        8.2791        1.8923
    204          [36m0.0285[0m        [32m0.7421[0m        8.2102       

    291          [36m0.0172[0m        [32m0.2694[0m       13.0548        1.9304
    292          [36m0.0171[0m        [32m0.2670[0m       13.1150        1.8729
    293          [36m0.0170[0m        [32m0.2634[0m       13.1469        1.9728
    294          [36m0.0169[0m        [32m0.2615[0m       13.1755        1.9742
    295          [36m0.0169[0m        [32m0.2604[0m       13.2031        1.9793
    296          [36m0.0169[0m        [32m0.2596[0m       13.2407        1.9194
    297          [36m0.0168[0m        [32m0.2586[0m       13.3272        1.9422
    298          [36m0.0168[0m        [32m0.2568[0m       13.4459        1.9612
    299          [36m0.0167[0m        [32m0.2547[0m       13.5782        1.9524
    300          [36m0.0166[0m        [32m0.2521[0m       13.7435        1.9712
    301          [36m0.0165[0m        [32m0.2491[0m       13.9301        1.9005
    302          [36m0.0164[0m        [32m0.2462[0m       14.1412       

    389          [36m0.0118[0m        [32m0.1281[0m       28.3987        1.9822
    390          [36m0.0118[0m        [32m0.1275[0m       28.5668        1.8889
    391          [36m0.0118[0m        [32m0.1266[0m       28.7455        1.9633
    392          [36m0.0117[0m        [32m0.1255[0m       28.8817        1.9368
    393          [36m0.0117[0m        [32m0.1247[0m       28.9989        1.9343
    394          [36m0.0117[0m        [32m0.1240[0m       29.1049        1.9570
    395          [36m0.0116[0m        [32m0.1234[0m       29.2284        1.9627
    396          [36m0.0116[0m        [32m0.1227[0m       29.3724        1.9020
    397          [36m0.0115[0m        [32m0.1216[0m       29.5056        1.9240
    398          [36m0.0115[0m        [32m0.1204[0m       29.5869        1.9231
    399          [36m0.0114[0m        [32m0.1192[0m       29.6364        1.9630
    400          [36m0.0114[0m        [32m0.1184[0m       29.6440       

    487          [36m0.0086[0m        [32m0.0682[0m       29.1939        1.9835
    488          [36m0.0086[0m        [32m0.0679[0m       29.1297        1.9344
    489          [36m0.0086[0m        [32m0.0677[0m       29.0348        1.9387
    490          [36m0.0086[0m        [32m0.0673[0m       28.9327        1.8899
    491          [36m0.0086[0m        [32m0.0670[0m       28.8517        1.9319
    492          [36m0.0085[0m        [32m0.0666[0m       28.7520        1.9504
    493          [36m0.0085[0m        [32m0.0661[0m       28.6561        1.9394
    494          [36m0.0085[0m        [32m0.0655[0m       28.5752        1.9441
    495          [36m0.0084[0m        [32m0.0650[0m       28.5064        1.9610
    496          [36m0.0084[0m        [32m0.0646[0m       28.4365        1.8947
    497          [36m0.0084[0m        [32m0.0642[0m       28.3647        1.9165
    498          [36m0.0084[0m        [32m0.0637[0m       28.3127       

    585          [36m0.0063[0m        [32m0.0368[0m       29.9697        1.9075
    586          [36m0.0063[0m        [32m0.0366[0m       29.9801        1.8857
    587          [36m0.0063[0m        [32m0.0363[0m       30.0089        1.9458
    588          [36m0.0063[0m        [32m0.0361[0m       30.0479        1.9565
    589          [36m0.0063[0m        [32m0.0359[0m       30.0980        1.9576
    590          [36m0.0063[0m        [32m0.0357[0m       30.1588        1.8858
    591          [36m0.0062[0m        [32m0.0355[0m       30.2221        1.8639
    592          [36m0.0062[0m        [32m0.0353[0m       30.2822        2.0066
    593          [36m0.0062[0m        [32m0.0352[0m       30.3406        1.9541
    594          [36m0.0062[0m        [32m0.0350[0m       30.3835        1.9900
    595          [36m0.0062[0m        [32m0.0349[0m       30.4221        1.8899
    596          [36m0.0062[0m        [32m0.0348[0m       30.4583       

    683          [36m0.0049[0m        [32m0.0221[0m       31.9973        1.9607
    684          [36m0.0049[0m        [32m0.0221[0m       32.0304        1.9717
    685          [36m0.0049[0m        [32m0.0220[0m       32.0740        1.8963
    686          [36m0.0049[0m        [32m0.0219[0m       32.1242        1.9220
    687          [36m0.0049[0m        [32m0.0218[0m       32.1790        1.9619
    688          [36m0.0049[0m        [32m0.0217[0m       32.2286        1.9662
    689          [36m0.0049[0m        [32m0.0215[0m       32.2679        1.9701
    690          [36m0.0048[0m        [32m0.0214[0m       32.2845        1.8971
    691          [36m0.0048[0m        [32m0.0212[0m       32.2848        1.8700
    692          [36m0.0048[0m        [32m0.0211[0m       32.2776        1.9569
    693          [36m0.0048[0m        [32m0.0210[0m       32.2652        1.9559
    694          [36m0.0048[0m        [32m0.0209[0m       32.2550       

    781          [36m0.0039[0m        [32m0.0140[0m       33.4304        1.9628
    782          [36m0.0039[0m        [32m0.0139[0m       33.4323        1.8771
    783          [36m0.0039[0m        [32m0.0138[0m       33.4394        1.9445
    784          [36m0.0039[0m        [32m0.0138[0m       33.4468        1.9323
    785          [36m0.0039[0m        [32m0.0137[0m       33.4516        1.9644
    786          [36m0.0039[0m        [32m0.0137[0m       33.4503        1.8665
    787          [36m0.0039[0m        [32m0.0136[0m       33.4425        1.8604
    788          [36m0.0039[0m        [32m0.0135[0m       33.4351        1.9271
    789          [36m0.0038[0m        [32m0.0135[0m       33.4302        1.9682
    790          [36m0.0038[0m        [32m0.0134[0m       33.4308        1.9798
    791          [36m0.0038[0m        [32m0.0134[0m       33.4312        1.9060
    792          [36m0.0038[0m        [32m0.0133[0m       33.4336       

    879          [36m0.0031[0m        [32m0.0087[0m       32.3938        1.9166
    880          [36m0.0031[0m        [32m0.0087[0m       32.3966        1.9301
    881          [36m0.0031[0m        [32m0.0087[0m       32.4001        1.9800
    882          [36m0.0031[0m        [32m0.0086[0m       32.4018        1.9851
    883          [36m0.0031[0m        [32m0.0086[0m       32.3982        2.0110
    884          [36m0.0031[0m        [32m0.0085[0m       32.3918        1.8669
    885          [36m0.0030[0m        [32m0.0085[0m       32.3852        1.9340
    886          [36m0.0030[0m        [32m0.0085[0m       32.3790        1.9855
    887          [36m0.0030[0m        [32m0.0084[0m       32.3718        1.9827
    888          [36m0.0030[0m        [32m0.0084[0m       32.3652        1.9354
    889          [36m0.0030[0m        [32m0.0084[0m       32.3602        1.9148
    890          [36m0.0030[0m        [32m0.0083[0m       32.3601       

    977          [36m0.0025[0m        [32m0.0058[0m       32.8847        1.9691
    978          [36m0.0025[0m        [32m0.0058[0m       32.9011        1.9825
    979          [36m0.0025[0m        [32m0.0058[0m       32.9150        1.9332
    980          [36m0.0025[0m        [32m0.0058[0m       32.9263        1.9078
    981          [36m0.0025[0m        [32m0.0057[0m       32.9376        1.9545
    982          [36m0.0025[0m        [32m0.0057[0m       32.9491        1.9819
    983          [36m0.0025[0m        [32m0.0057[0m       32.9637        1.9806
    984          [36m0.0025[0m        [32m0.0057[0m       32.9787        1.9761
    985          [36m0.0025[0m        [32m0.0057[0m       32.9957        2.0031
    986          [36m0.0025[0m        [32m0.0057[0m       33.0161        1.9361
    987          [36m0.0025[0m        [32m0.0057[0m       33.0399        1.9568
    988          [36m0.0025[0m        [32m0.0056[0m       33.0660       

(3.1543664925188364, 1.7374764715340063)