In [None]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

from collections import Counter
import pandas as pd
import os
from pathlib import Path
from operator import itemgetter

import itertools as it
from collections import Counter
import numpy as np

from aiida import load_profile
from aiida.orm import load_node

from pymatgen.io.ase import AseAtomsAdaptor

from pybat import Cathode

from ase.build.tools import sort
from ase.io.espresso import read_espresso_in
from ase.visualize import view

from itables import init_notebook_mode
init_notebook_mode(all_interactive=True)

import itables.options as opt

from project_settings import *

# itables display options. CSS declarations must be separated by ';' — a value such as
# "width:80%:left" is invalid and the browser silently drops the whole width declaration.
# Left alignment is expressed explicitly with margin-left.
opt.style = "table-layout:auto;width:80%;margin-left:0"
opt.columnDefs = [{"className": "dt-left", "targets": "_all"}]

# Show full cell contents in plain pandas reprs.
pd.set_option('display.max_colwidth', None)
# Silences SettingWithCopyWarning for the whole session.
# NOTE(review): this also hides genuine chained-assignment bugs; consider removing it.
pd.set_option('mode.chained_assignment', None)

In [None]:
# Load the AiiDA profile (no name given) before any AiiDA node is loaded or submitted below.
load_profile()

***
# Sampling of Li/vacancy configurations with `icet` (via `pybat`)

In [None]:
from config_class import ConfigSet

# NOTE(review): `config_set` is not referenced in the cells visible here; confirm it is still needed.
config_set = ConfigSet()


In [None]:
# Structures to sample, with the composition of their fully-lithiated cell:
#   spinel_LiMnNiO   -> Li8Mn12Ni4O32
#   spinel_LiMnO     -> Li8Mn16O32
#   olivine_LiFePO   -> Li4Fe4P4O16
#   olivine_LiMnFePO -> Li4Mn2Fe2P4O16
#   olivine_LiMnPO   -> Li4Mn4P4O16

# Keys of the sampling dictionary are (short_name, cell_size) tuples; only size 1 is sampled.
olivine_names = ['olivine_LiFePO', 'olivine_LiMnFePO', 'olivine_LiMnPO']
spinel_names = ['spinel_LiMnNiO', 'spinel_LiMnO']

olivine_df_keys = [(name, 1) for name in olivine_names]
olivine_df_keys
spinel_df_keys = [(name, 1) for name in spinel_names]
sampling_df_keys = [*olivine_df_keys, *spinel_df_keys]

fully_lithiated_df = pd.read_pickle(Path(project_dir) / 'data' / 'fully_lithiated_df.pkl')
# NOTE(review): labels rows positionally — assumes the pickle holds the olivines first and then
# the spinels, in the same order as the name lists above. Verify against how the pickle was written.
fully_lithiated_df.insert(0, 'short_name', olivine_names + spinel_names)

print(list(fully_lithiated_df.columns))
fully_lithiated_df.head()

In [None]:
# Default symmetry tolerance, used for the enumeration and for the space-group labels.
GLOBAL_SYMPREC = 1e-5


def _multiplicity(values):
    """Return a ``{value: count}`` dict for the items of ``values``."""
    return dict(Counter(values))


def create_sampling_df(input_pmg, short_name, cell_size, **kwargs):
    """Enumerate cation configurations of a structure and tabulate them.

    Parameters
    ----------
    input_pmg : pymatgen structure
        Parent structure, passed to ``Cathode.from_structure``.
    short_name : str
        Label of the parent structure; prefixes ``specific_name``.
    cell_size : int
        Size label stored in the ``cell_size`` column and in ``specific_name``.
    **kwargs
        Forwarded unchanged to ``Cathode.get_cation_configurations``.
        ``substitution_sites`` and ``sizes`` are required: they define the
        denominator of the ``lithiation`` fraction. If ``symprec`` is given it is
        also used for the ``space_group`` column, otherwise ``GLOBAL_SYMPREC``.

    Returns
    -------
    pandas.DataFrame
        One row per configuration. Columns, in this order: short_name, cell_size,
        configuration, li_number, formula, cell_vectors, space_group,
        ase_structure, lithiation, li_mpc, cell_mpc, sg_mpc, index, specific_name.
        Later cells insert columns positionally, so keep this order stable.
    """
    # ! Lithiation currently works only for one size: only sizes[0] enters the denominator.
    n_li_sites = len(kwargs['substitution_sites']) * kwargs['sizes'][0]
    # Label space groups with the same tolerance that the enumeration used.
    symprec = kwargs.get('symprec', GLOBAL_SYMPREC)

    initial_cat = Cathode.from_structure(input_pmg)
    configuration_list = initial_cat.get_cation_configurations(**kwargs)
    n_conf = len(configuration_list)

    ase_list = [AseAtomsAdaptor.get_atoms(cathode.as_ordered_structure()) for cathode in configuration_list]
    li_list = [atoms.get_chemical_symbols().count('Li') for atoms in ase_list]
    # Cell matrices are stored as strings so they can be counted/compared as dict keys.
    cell_vector_list = [np.array2string(atoms.cell.todict()['array']) for atoms in ase_list]
    sg_list = [cathode.space_group(symprec=symprec) for cathode in configuration_list]

    return_df = pd.DataFrame({
        'short_name': [short_name] * n_conf,
        'cell_size': [cell_size] * n_conf,
        'configuration': configuration_list,
        'li_number': li_list,
        'formula': [atoms.get_chemical_formula() for atoms in ase_list],
        'cell_vectors': cell_vector_list,
        'space_group': sg_list,
        'ase_structure': ase_list,
    })
    return_df['lithiation'] = return_df['li_number'] / n_li_sites

    # Multiplicities: number of configurations sharing the same Li count / cell / space group.
    # TODO(review): an earlier note said the multiplicities "still need to be applied";
    # confirm whether further use (e.g. weighting) was intended beyond these columns.
    return_df['li_mpc'] = return_df['li_number'].map(_multiplicity(li_list))
    return_df['cell_mpc'] = return_df['cell_vectors'].map(_multiplicity(cell_vector_list))
    return_df['sg_mpc'] = return_df['space_group'].map(_multiplicity(sg_list))

    return_df['index'] = [str(i) for i in return_df.index]
    return_df['specific_name'] = (
        return_df['short_name']
        + '-' + return_df['cell_size'].astype(str)
        + '-' + return_df['lithiation'].astype(str)
        + '-' + return_df['index']
    )

    return return_df

# Enumerate the Li/vacancy configurations of every (short_name, cell_size) key, in key order.
conf_df_list = []

for sampling_df_key in sampling_df_keys:
    short_name, cell_size = sampling_df_key

    # Fully-lithiated parent structure for this key.
    is_match = fully_lithiated_df['short_name'] == short_name
    loop_pmg = fully_lithiated_df.loc[is_match, 'pmg_in'].values[0]
    loop_ase = AseAtomsAdaptor.get_atoms(loop_pmg)
    loop_ase.center()
    n_li = loop_ase.get_chemical_symbols().count('Li')

    # The first n_li site indices are substituted by Li or a vacancy, at any Li fraction.
    kwargs_dict = {
        'substitution_sites': range(n_li),
        'cation_list': ["Li", "Vac"],
        'sizes': [cell_size],
        'concentration_restrictions': {"Li": (0, 1)},
        'max_configurations': None,
        'symprec': GLOBAL_SYMPREC,
    }

    conf_df_list.append(
        create_sampling_df(input_pmg=loop_pmg, short_name=short_name, cell_size=cell_size, **kwargs_dict)
    )

sampling_df_dict = dict(zip(sampling_df_keys, conf_df_list))

In [None]:
# Keys are (short_name, cell_size) tuples.
sampling_df_dict.keys()

In [None]:
from IPython.display import display

show_columns_sampling = ['index', 'short_name', 'specific_name', 'li_number', 'formula', 'space_group', 'lithiation', 'li_mpc', 'cell_mpc', 'sg_mpc']

# Set to True to open each configuration set in the ASE viewer (pauses on input()).
show = False

# Inspect a subset of the sampled systems: positions 0 and 2 of sampling_df_keys.
for sampling_df_key in itemgetter(0, 2)(sampling_df_keys):
    current_df = sampling_df_dict[sampling_df_key]
    print("Structure: {}, Size: {}, # Configs: {}".format(*sampling_df_key, current_df.shape[0]))
    # Expressions inside a loop are never auto-displayed (ast_node_interactivity="all"
    # only affects top-level statements), so display the table explicitly.
    display(current_df[show_columns_sampling])
    if show:
        view(current_df['ase_structure'].values)
        input("Press Enter to continue...")

In [None]:
# Work on a copy: later cells insert columns into devel_df, which would otherwise
# also modify the frame stored in sampling_df_dict.
devel_df = sampling_df_dict[('olivine_LiMnPO', 1)].copy()
devel_df[show_columns_sampling]

In [None]:
from aiida.orm import StructureData, load_code
from aiida_quantumespresso_hp.workflows.hubbard import SelfConsistentHubbardWorkChain
from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData
from aiida.engine import submit
from ase.atoms import Atoms
from aiida.plugins import DataFactory

# Order in which species appear in the sorted structure (lower value first).
sorting_dict = {
    'Mn': 0,
    'O': 1,
    'P': 2,
    'Li': 3,
}

devel_df.shape

# PKs recorded from an earlier run of the (commented-out) creation/submission code below,
# one entry per row of devel_df.
# hubbard_data_pks = []
# submit_pks = []
hubbard_data_pks = [3075, 3101, 3127, 3153, 3179, 3205, 3231]
submit_pks = [3100, 3126, 3152, 3178, 3204, 3230, 3256]

pw_code = load_code('qe-dev-pw@lumi-small')
hp_code = load_code('qe-dev-hp@lumi-small')
KpointsData = DataFactory('array.kpoints')

for irow, row in enumerate(devel_df.itertuples()):

    ase_structure = row.ase_structure
    # Stable sort of the atoms by species. ase's `sort` returns atoms[indices], which keeps the
    # cell and the periodic boundary conditions; building a new Atoms from a list of Atom
    # objects would silently reset pbc to False.
    species_rank = [sorting_dict[symbol] for symbol in ase_structure.get_chemical_symbols()]
    ase_ordered = sort(ase_structure, tags=species_rank)
    print(ase_ordered.get_chemical_symbols())

    # One-off creation of the HubbardStructureData nodes (PKs are hard-coded above):
    # aiida_structure = StructureData(ase=ase_ordered)
    # hubbard_data = HubbardStructureData(structure=aiida_structure)
    # hubbard_data.initialize_onsites_hubbard('Mn', '3d', 4.5618)
    # hubbard_data.initialize_intersites_hubbard('Mn', '3d', 'O', '2p', 0.0001, number_of_neighbours=7)
    # hubbard_data.store()
    # hubbard_data_pks.append(hubbard_data.pk)
    hubbard_data = load_node(hubbard_data_pks[irow])
    print(hubbard_data.get_quantum_espresso_hubbard_card())

    builder = SelfConsistentHubbardWorkChain.get_builder_from_protocol(
        pw_code=pw_code,
        hp_code=hp_code,
        hubbard_structure=hubbard_data,
        protocol='fast',
        overrides=Path('hubbard_overrides.yaml'),
    )
    builder.skip_first_relax = True
    # Fixed 2x2x2 k-point mesh for the workflow.
    kpoints = KpointsData()
    kpoints.set_kpoints_mesh([2, 2, 2])
    builder.kpoints = kpoints

    print(type(builder))
    print(builder)

    # submit_return = submit(builder)
    # print(submit_return)
    # submit_pks.append(submit_return.pk)

    # Development guard: only the first row is processed.
    break

In [None]:
# Hard-coded identifiers of a workflow submission (presumably the latest one) for inspection.
# NOTE(review): last_submit_uuid is not used in the visible cells.
last_submit_pk = 3390
last_submit_uuid = 'a9daad90-aa2b-43c3-b9c2-9419fa4f7a1a'
!verdi process report {last_submit_pk}

In [None]:
print(hubbard_data_pks)
print(submit_pks)

devel_df.columns
# Attach the AiiDA PKs at fixed column positions. On a re-run the columns already exist,
# so they are overwritten in place instead of DataFrame.insert raising "already exists".
for position, column, values in (
    (6, 'hubbard_data_pk', hubbard_data_pks),
    (7, 'submit_pk', submit_pks),
):
    if column in devel_df.columns:
        devel_df[column] = values
    else:
        devel_df.insert(position, column, values)

In [None]:
# Sampling overview together with the attached AiiDA PKs.
devel_df[show_columns_sampling + ['hubbard_data_pk', 'submit_pk']]

In [None]:

for submit_pk in submit_pks[1:2]:
    !verdi process status {submit_pk}
    !verdi process report {submit_pk}
    !verdi process show {submit_pk}
    print('='*100)
    break