In [None]:
# Load in the relevant packages
from ase.io import read,write
from ase.build import bulk
import numpy as np
import ase.db as db
from ase.visualize import view
from ase.calculators.lj import LennardJones
import matplotlib.pyplot as plt
import json
import sympy
import matplotlib.pyplot as plt
import os
import sys

# Make modules living in the parent directory (e.g. `helper`) importable.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))

# Put the parent directory at the front of the module search path. The guard
# keeps the cell idempotent: re-running it does not keep prepending duplicate
# entries to sys.path.
if sys.path[:1] != [parent_dir]:
    sys.path.insert(0, parent_dir)

# Import the module from the parent directory
from helper import fit_task

# Set RESET_DB = True to delete an existing G2B.db (the database that
# FT.scramble writes to and ccs_fetch reads from) before generating new data.
# NOTE(review): whether FT.scramble appends to or overwrites an existing
# database is not visible here -- confirm before relying on either behaviour.
RESET_DB = False
if RESET_DB:
    try:
        os.remove('G2B.db')
    except FileNotFoundError:
        pass  # nothing to reset on a first run

# ---- SYSTEM ----

# Empirical damping parameter used when rattling structures; hard materials
# need a small value. The natural range is probably around 0.25-5.
# NOTE(review): 0.1 lies below that stated range -- confirm the intended bounds.
damping = 0.1

# Melting point.
# NOTE(review): ZnO melts at ~1975 degC (~2248 K); confirm which unit fit_task expects.
Tm = 1975

# Wurtzite cell parameters a and c, plus the internal parameter u.
a0, c0, u0 = 3.238, 5.176, 0.3814

# Reference ZnO wurtzite crystal built from the cell parameters above.
crystal = bulk(
    "ZnO",
    "wurtzite",
    a=a0,
    c=c0,
    u=u0,
)

# ---- FORCE FIELD ----
# Reference force field: https://pubs.acs.org/doi/10.1021/jp411308z

r_cut = 12.0                            # cutoff radius (passed to fit_task and, as R_c, to ccs_fetch)
CCS_res = 0.1                           # CCS resolution passed to fit_task
pairs = ["Zn-Zn", "Zn-O", "O-O"]        # element pairs to fit
charge_dict = dict(Zn=1.14, O=-1.14)    # point charge per element (used by the CCS+Q mode)

# Pair potential: a Morse term plus a Born-Mayer repulsion,
#   V(r) = D_ij * [ (1 - exp(-a_ij * (r - r_0)))^2 - 1 ] + C_ij * exp(-r / rho)
V_ij = "D_ij*( (1-exp(-a_ij*(r_ij-r_0)))**2 - 1) + C_ij*exp(-r_ij/rho)"

# Per-pair parameters of V_ij. Zn-Zn and Zn-O keep only the exponential
# repulsion (D_ij = 0); O-O is a pure Morse term (C_ij = 0, so rho has no effect).
FF = {
    "Zn-Zn": dict(D_ij=0.0,    a_ij=0.0,   r_0=0.0,   C_ij=78.91,    rho=0.5177),
    "Zn-O":  dict(D_ij=0.0,    a_ij=0.0,   r_0=0.0,   C_ij=257600.0, rho=0.1396),
    "O-O":   dict(D_ij=0.1567, a_ij=3.405, r_0=1.164, C_ij=0.0,      rho=1.0),
}


# Bundle the crystal, system settings and force-field definition into a
# fitting task; every argument is passed by keyword.
FT = fit_task(
    crystal=crystal,
    Tm=Tm,
    damping=damping,
    pairs=pairs,
    V_ij=V_ij,
    charge_dict=charge_dict,
    r_cut=r_cut,
    CCS_res=CCS_res,
)

# Hand the per-pair reference parameters (FF) to the task.
FT.assign_params(FF)

# BUILD INITIAL TRAINING-SET
We first determine the bounds for the volume and then try to fill in data uniformly across the accessible volume range.

In [None]:
# Build the initial training set: determine the volume bounds, then try to
# fill in data uniformly across that volume range.
FT.init_training()

# BUILD SCRAMBLED DATA SET

In [None]:
# Generate the scrambled data set (size=1000) and store it in G2B.db.
# NOTE(review): this passes damping=0.25 explicitly instead of the `damping`
# value configured in the SYSTEM section -- confirm the override is intended.
FT.scramble(DB="G2B.db", damping=0.25, size=1000)

# Check sampling quality

In [None]:
# Run the task's sampling-quality check on the generated data.
FT.check_sampling()

# Fit and analyse

In [None]:
import pandas as pd
# Results table for the fits WITH forces. The next cell concatenates rows onto
# it, so re-running that cell accumulates more repeats; re-run this cell to reset.
df=pd.DataFrame()

In [None]:
from ccs_fit.scripts.ccs_fetch import ccs_fetch
from ccs_fit import ccs_fit
import pandas as pd
import seaborn as sns

# Learning curve WITH forces. For each training-set size Ns, repeat the
# fetch + fit N_REPEATS times and record, for both the CCS fit and the UNC fit,
# the total overlap RMSE and the charge error.
# NOTE(review): the repeats presumably differ because ccs_fetch draws a random
# subset of Ns structures on each call -- TODO confirm.
# Rows are appended to the existing `df`, so re-running this cell accumulates
# more repeats (re-run the `df = pd.DataFrame()` cell to start over).
SAMPLE_SIZES = [1, 2, 4, 8, 16, 32]
N_REPEATS = 10

# Collect plain dicts and build the DataFrame once; concatenating inside the
# loop is quadratic. If the loop is interrupted, finished fits remain in `records`.
records = []
for Ns in SAMPLE_SIZES:
    for _ in range(N_REPEATS):
        ccs_fetch(mode="CCS+Q", DFT_DB="G2B.db", include_forces=True,
                  charge_dict=charge_dict, R_c=r_cut, Ns=Ns)
        ccs_fit("CCS_input.json")

        Overlap = float(FT.calculate_overlap_rmse()['Total'])
        q_err = FT.compare_q()
        records.append({"Method": "CCS", "No_samples": Ns, "Overlap": Overlap, "q_error": q_err})

        Overlap = float(FT.calculate_overlap_rmse(UNC=True)['Total'])
        q_err = FT.compare_q(UNC=True)
        records.append({"Method": "UNC", "No_samples": Ns, "Overlap": Overlap, "q_error": q_err})

df = pd.concat([df, pd.DataFrame(records)], ignore_index=True)

# Plot once, after all sample sizes are done (previously these plots were
# inside the Ns loop and redrawn after every sample size).
for metric, label in (("Overlap", "Overlap RMSE"), ("q_error", "q error")):
    fig, ax = plt.subplots()
    sns.lineplot(x='No_samples', y=metric, hue='Method', data=df, ax=ax)
    ax.set(xlabel="Number of training samples (Ns)", ylabel=label,
           title=f"{label} vs. training-set size (with forces)")
    plt.show()



In [None]:
import pandas as pd
# Results table for the fits WITHOUT forces (include_forces=False). The next
# cell concatenates rows onto it, so re-run this cell to reset the table.
df_NoF=pd.DataFrame()

In [None]:
from ccs_fit.scripts.ccs_fetch import ccs_fetch
from ccs_fit import ccs_fit
import pandas as pd
import seaborn as sns

# Learning curve WITHOUT forces (include_forces=False). For each training-set
# size Ns, repeat the fetch + fit N_REPEATS times and record, for both the CCS
# fit and the UNC fit, the total overlap RMSE and the charge error.
# NOTE(review): the repeats presumably differ because ccs_fetch draws a random
# subset of Ns structures on each call -- TODO confirm.
# Rows are appended to the existing `df_NoF`, so re-running this cell accumulates
# more repeats (re-run the `df_NoF = pd.DataFrame()` cell to start over).
SAMPLE_SIZES = [8, 16, 32]
N_REPEATS = 10

# Collect plain dicts and build the DataFrame once; concatenating inside the
# loop is quadratic. If the loop is interrupted, finished fits remain in `records`.
records = []
for Ns in SAMPLE_SIZES:
    for _ in range(N_REPEATS):
        ccs_fetch(mode="CCS+Q", DFT_DB="G2B.db", include_forces=False,
                  charge_dict=charge_dict, R_c=r_cut, Ns=Ns)
        ccs_fit("CCS_input.json")

        Overlap = float(FT.calculate_overlap_rmse()['Total'])
        q_err = FT.compare_q()
        records.append({"Method": "CCS", "No_samples": Ns, "Overlap": Overlap, "q_error": q_err})

        Overlap = float(FT.calculate_overlap_rmse(UNC=True)['Total'])
        q_err = FT.compare_q(UNC=True)
        records.append({"Method": "UNC", "No_samples": Ns, "Overlap": Overlap, "q_error": q_err})

df_NoF = pd.concat([df_NoF, pd.DataFrame(records)], ignore_index=True)

# Plot once, after all sample sizes are done (previously these plots were
# inside the Ns loop and redrawn after every sample size).
for metric, label in (("Overlap", "Overlap RMSE"), ("q_error", "q error")):
    fig, ax = plt.subplots()
    sns.lineplot(x='No_samples', y=metric, hue='Method', data=df_NoF, ax=ax)
    ax.set(xlabel="Number of training samples (Ns)", ylabel=label,
           title=f"{label} vs. training-set size (without forces)")
    plt.show()