In [None]:
""" Code to fit a pair-potential to reproduce a dimer dissociation curve """

In [None]:
import itertools as it
import os
from types import SimpleNamespace

import matplotlib.pyplot as plt

import plato_pylib.plato.mod_plato_inp_files as modInp
import plato_pylib.plato.parse_tbint_files as parseTbint
import plato_pylib.utils.job_running_functs as jobRun

import plato_fit_integrals.core.coeffs_to_tables as coeffToTab
import plato_fit_integrals.core.create_analytical_reprs as analyticFuncts
import plato_fit_integrals.core.opt_runner as optRunner
import plato_fit_integrals.core.obj_funct_calculator as objFunctCalc
import plato_fit_integrals.core.workflow_coordinator as wflowCoord

import plato_fit_integrals.initialise.obj_functs_targ_vals as objCmpFuncts
import plato_fit_integrals.initialise.create_coeff_tables_converters as createCoeffTabs
import plato_fit_integrals.initialise.create_ecurve_workflows as ecurves
import plato_fit_integrals.initialise.fit_analytic_to_initial_tables as fitInit

import plato_fit_integrals.utils.plot_functs as fitPlotFuncts


# Dimer separations used to build the dissociation-curve geometries
# (units are whatever ecurves.createDimerDissocCurveStructs expects - TODO confirm)
TEST_SEPS = list(range(2, 7))
ATOM_SYMBOL = "Si"

# Number of reference jobs handed to jobRun.executeRunCommsParralel at once
N_CORES = 6

# Tight-binding model dataset that supplies the starting integral tables
MODEL_DATAFOLDER = "Test/format_4"
FULL_PATH_MODEL_DATAFOLDER = modInp.getAbsolutePathForPlatoTightBindingDataSet(MODEL_DATAFOLDER)

# Fitted-model calculations run in WORK_FOLDER; reference calculations in a sub-folder of it.
# Derived from WORK_FOLDER so that changing the work folder moves both sets of jobs together.
WORK_FOLDER = "work_folder"
REF_WORK_FOLDER = os.path.join(WORK_FOLDER, "ref_calcs")

# Name of the attribute that the energy workflows store their outputs under
OUT_ATTR = "energy_vals"

# Set to False to skip (re-)running the reference jobs; the reference workflow is still parsed
# afterwards, so presumably its outputs must already exist in REF_WORK_FOLDER
RUN_REF_JOBS = True

# Parameters for the analytic function we use to represent the pair-potential.
# The cutoff is read from the "<sym>_<sym>.bdt" file of the model dataset.
RCUT = parseTbint.getBdtRcut(os.path.join(FULL_PATH_MODEL_DATAFOLDER, "{}_{}.bdt".format(ATOM_SYMBOL, ATOM_SYMBOL)))
REF_R0 = 1
N_POLY = 5
TAIL_DELTA = 0.5


In [None]:
def getModOptsDict():
    """ Return the plato input-file options shared by the fitted and reference workflows.

    Returns:
        dict: {"blochstates": [1, 1, 1], "dataset": MODEL_DATAFOLDER}. A new dict (and a new
              list) is built on every call, so callers may modify the result freely.
    """
    return {"blochstates": [1, 1, 1],
            "dataset": MODEL_DATAFOLDER}

In [None]:

def createCoeffsToTablesObj():
    """ Build the converter between analytic-function coefficients and the pair-potential table.

    Loads the integral tables from FULL_PATH_MODEL_DATAFOLDER and selects the "pairpot" table for
    the ATOM_SYMBOL-ATOM_SYMBOL pair (inclCorrs=False). It represents that table with the analytic
    form from createAnalyticRepFunct and wraps everything in a CoeffsTablesConverter.

    Returns:
        coeffToTab.CoeffsTablesConverter holding one analytic function / one table-info pair.
    """
    allInteg = createCoeffTabs.createIntegHolderFromModelFolderPath(FULL_PATH_MODEL_DATAFOLDER)
    pairPotInfo = coeffToTab.IntegralTableInfo(FULL_PATH_MODEL_DATAFOLDER, "pairpot", ATOM_SYMBOL, ATOM_SYMBOL)
    startTable = allInteg.getIntegTableFromInfoObj(pairPotInfo, inclCorrs=False)
    return coeffToTab.CoeffsTablesConverter([createAnalyticRepFunct(startTable)],
                                            [pairPotInfo],
                                            allInteg)



def createAnalyticRepFunct(integTable):
    """ Create the Cawkwell17ModTailRepr analytic representation of a tabulated pair potential.

    Node positions come from fitInit.findCrossings on the tabulated data (presumably where the
    potential changes sign). They are then promoted to fit variables so the optimiser can move
    the node(s). The value at REF_R0 is interpolated from the table.

    Args:
        integTable: table object whose `.integrals` holds the tabulated data
                    (presumably distance/value pairs - TODO confirm).

    Returns:
        The analytic function object, with node positions promoted to variables.
    """
    crossings = fitInit.findCrossings(integTable.integrals)
    r0Value = fitInit.getInterpYValGivenXValandInpData(REF_R0, integTable.integrals)
    reprKwargs = dict(rCut=RCUT, refR0=REF_R0, valAtR0=r0Value,
                      nPoly=N_POLY, tailDelta=TAIL_DELTA,
                      nodePositions=crossings)
    analyticRepr = analyticFuncts.Cawkwell17ModTailRepr(**reprKwargs)
    analyticRepr.promoteNodePositionsToVariables()  # makes the node(s) movable during the fit
    print("nodePositions = {}".format(crossings))
    return analyticRepr



In [None]:
def createWorkFlowCoordAndObjFunctCalc(inpGeoms, targEnergies):
    """ Build the tight-binding dissociation-curve workflow and the objects that drive its fit.

    Args:
        inpGeoms: geometries making up the dissociation curve.
        targEnergies: target energies the workflow output is compared against.

    Returns:
        tuple: (WorkFlowCoordinator wrapping the single workflow,
                objective-function calculator comparing its output with targEnergies)
    """
    dissocWorkFlow = createWorkFlowDissocCurve(inpGeoms)
    calculator = createObjFunctCalculator(dissocWorkFlow, targEnergies)
    coordinator = wflowCoord.WorkFlowCoordinator([dissocWorkFlow])
    return coordinator, calculator
    
def createWorkFlowDissocCurve(inpGeoms):
    """ Create the "tb1" energy workflow for the given geometries.

    Input options come from getModOptsDict(); jobs go in WORK_FOLDER and energies are stored under
    OUT_ATTR. Before being returned, the workflow is passed through
    wflowCoord.decorateWorkFlowWithPrintOutputsEveryNSteps.
    """
    workFlowFactory = ecurves.CreateStructEnergiesWorkFlow(inpGeoms, getModOptsDict(), WORK_FOLDER, "tb1",
                                                          outAttr=OUT_ATTR)
    tbWorkFlow = workFlowFactory()
    wflowCoord.decorateWorkFlowWithPrintOutputsEveryNSteps(tbWorkFlow)
    return tbWorkFlow

def createObjFunctCalculator(inpWorkFlow, targEnergies):
    """ Create the objective-function calculator comparing a workflow's energies with targets.

    Uses the "relrootsqrdev" comparison (presumably a relative root-squared deviation) averaged
    with the "mean" method. errorRetVal=1e10 is passed through (presumably the value reported
    when a calculation fails - TODO confirm).
    """
    return ecurves.createObjFunctCalculatorFromEcurveWorkflow(inpWorkFlow, targEnergies, "relrootsqrdev",
                                                              averageMethod="mean",
                                                              errorRetVal=1e10)



In [None]:
# Set up the reference calculations: the same dimer geometries and input options as the fit,
# but run with the "dft2" plato code inside REF_WORK_FOLDER
inpGeoms = ecurves.createDimerDissocCurveStructs(TEST_SEPS, ATOM_SYMBOL, ATOM_SYMBOL)
platoRefStr, varyType = "dft2", None

refWorkFlow = ecurves.CreateStructEnergiesWorkFlow(inpGeoms, getModOptsDict(), REF_WORK_FOLDER,
                                                   platoRefStr, outAttr=OUT_ATTR,
                                                   varyType=varyType)()


In [None]:
# Run the reference jobs in parallel on N_CORES.
# Skipped when RUN_REF_JOBS is False (see the config cell).
if RUN_REF_JOBS:
    refRunComms = refWorkFlow.preRunShellComms  # shell commands the reference workflow needs run first
    jobRun.executeRunCommsParralel(refRunComms,
                                   N_CORES)

In [None]:
# Parse the reference energies. The reference workflow has only a single output attribute,
# so the first entry of namespaceAttrs names it.
refWorkFlow.run()
refEnergies = getattr(refWorkFlow.output,
                      refWorkFlow.namespaceAttrs[0])

In [None]:
# Assemble the pieces of the final fit:
# - coefficient -> table converter
# - tight-binding workflow coordinator plus objective calculator (targeting refEnergies)
# - an ObjectiveFunction combining them
coeffsToTables = createCoeffsToTablesObj()
workFlowCoord, objFunctCalculator = createWorkFlowCoordAndObjFunctCalc(inpGeoms, refEnergies)
finalFitObjFunction = optRunner.ObjectiveFunction(coeffsToTables,
                                                  workFlowCoord,
                                                  objFunctCalculator)


In [None]:
# Fit the analytic coefficients to the starting tabulated pair potential (Nelder-Mead).
# The next cell plots the result from coeffsToTables, so this presumably updates it in place.
fitResInitInts = fitInit.fitAnalyticFormToStartIntegrals(coeffsToTables,
                                                         method='Nelder-Mead')




In [None]:
# Plot the analytic fit against the initial tabulated pair potential, zoomed to the region of interest.
# NOTE(review): relies on the private attribute `_integInfo`; prefer a public accessor if one exists.
figA = fitPlotFuncts.plotFittedIntsVsInitial(coeffsToTables._integInfo[0], coeffsToTables)
axA = figA.get_axes()[0]
axA.set_xlim(2, 12)
# Trailing ';' stops the notebook echoing set_ylim's return tuple below the figure
axA.set_ylim(-0.4, 1.0);

In [None]:
# Evaluate the dissociation curve with the current (pre-optimisation) coefficients
# and plot it against the reference curve
initEnergies = getattr(
    finalFitObjFunction.workFlowCoordinator.runAndGetPropertyValues(),
    OUT_ATTR,
)
fitPlotFuncts.plotDissocCurvesInitVsFinal(inpGeoms, refEnergies, initEnergies)


In [None]:
# Optimise the coefficients against the reference distance-vs-energy curve (Nelder-Mead)
fitRes = optRunner.carryOutOptimisationBasicOptions(finalFitObjFunction,
                                                    method='Nelder-Mead')

In [None]:
# Display the result object returned by optRunner.carryOutOptimisationBasicOptions
fitRes

In [None]:
# Plot the fitted dissociation curve against the target (reference) curve,
# using the energies stored on the optimisation result
finalEnergies = getattr(fitRes.calcVals,
                        OUT_ATTR)
fitPlotFuncts.plotDissocCurvesInitVsFinal(inpGeoms, refEnergies, finalEnergies)

In [None]:
# Plot the final (optimised) pair potential against the initial tabulated one.
# The axis limits used for the initial-fit plot (x: 2-12, y: -0.4-1.0) can be applied to
# figA.get_axes()[0] to zoom in.
figA = fitPlotFuncts.plotFittedIntsVsInitial(coeffsToTables._integInfo[0],
                                             coeffsToTables)