# Model analysis and tests

Input: DFT_DB.db, DFTB_DB.db, CCS_DB.db and PiNN_DB.db — ASE databases with coordinates, energies and forces for all levels of theory.

Output: Correlation plot comparing per-atom DFT energies with the model energies (DFTB+CCS and DFTB+CCS+PiNN).

In [None]:
# Location of the local PARAUTOMATIK installation -- edit this before running.
PARAUTOMATIK_PATH = 'WRITE-Your-PATH-PARAUTOMATIK-Here'

In [3]:
import os 
import sys
from ase import Atoms
from ase import io
from ase.io import read, write
import ase.db as db
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tnrange, tqdm
# Remember the directory the notebook was started from; later cells chdir
# back to it.
base_dir = os.getcwd()

# Ensure the ModelAnalysis output directory exists. exist_ok=True replaces
# the separate isdir() check, so there is no check-then-create race.
os.makedirs(os.path.join(base_dir, "ModelAnalysis"), exist_ok=True)

In [None]:
# Read per-atom energies for every structure from the four ASE databases.
# Structures are matched across databases by the row id of the DFTB database.
# NOTE(review): this assumes all four databases were written in the same
# order, so equal ids refer to the same structure -- confirm.
dbname_DFT = 'DFT_DB.db'
dbname_DFTB = 'DFTB_DB.db'
dbname_CCS = 'CCS_DB.db'
dbname_PiNN = 'PiNN_DB.db'
db_dftb = db.connect(dbname_DFTB)
db_ccs = db.connect(dbname_CCS)
db_dft = db.connect(dbname_DFT)
db_pinn = db.connect(dbname_PiNN)

# epred and ecpinn are not filled here; kept in case other cells refer to them.
epred = []
edft = []
edftb = []
epinn = []
eccs = []
ecpinn = []

for row in tqdm(db_dftb.select(), total=db_dftb.count()):
    # Number of atoms straight from the row; no need to build an Atoms object.
    natoms = row.natoms
    row_id = row.id  # avoid shadowing the builtin `id`
    epinn.append(db_pinn.get(id=row_id).energy / natoms)
    # `row` already is the DFTB entry, so it is not fetched a second time.
    edftb.append(row.energy / natoms)
    eccs.append(db_ccs.get(id=row_id).energy / natoms)
    edft.append(db_dft.get(id=row_id).energy / natoms)

In [None]:
# Build the model energies per atom by summing the per-structure terms
# element-wise: DFTB + CCS, and DFTB + CCS + PiNN.
# NOTE(review): this assumes the CCS and PiNN databases store correction
# energies on top of DFTB rather than total energies -- confirm.
edftb_ccs = [e_dftb + e_ccs for e_dftb, e_ccs in zip(edftb, eccs)]

edftb_ccs_pinn = [
    e_dftb + e_ccs + e_pinn
    for e_dftb, e_ccs, e_pinn in zip(edftb, eccs, epinn)
]
os.chdir(base_dir)

In [None]:
# Plot per-atom model energies against the DFT reference energies.
e_min, e_max = np.min(edft), np.max(edft)

plt.figure(figsize=(10, 10))
# x/y are passed as keywords because seaborn >= 0.12 no longer accepts them
# positionally. Each series carries its own label, so the legend does not
# depend on the order in which artists were added to the axes.
sns.scatterplot(x=edft, y=edftb_ccs, s=50, alpha=0.8, color='r',
                label="DFTB+CCS")
sns.scatterplot(x=edft, y=edftb_ccs_pinn, s=50, alpha=0.3, color='g',
                label="DFTB+CCS+PiNN")

# y = x reference line; it is unlabelled, so the legend leaves it out.
plt.plot([e_min, e_max], [e_min, e_max], 'k-', lw=2.5)
plt.legend(fontsize=30)
plt.xlabel('E$_{DFT}$/atom (eV)', fontsize=30)
plt.ylabel('E$_{model}$/atom (eV)', fontsize=30)
# Both axes are limited to the DFT energy range; model points outside it are
# clipped from view.
plt.xlim(e_min, e_max)
plt.ylim(e_min, e_max)
# NOTE(review): saved to the current directory, not the ModelAnalysis/
# directory created earlier -- confirm which location is intended.
plt.savefig('DFTvsModels.png')
plt.show()
