In [1]:
from pathlib import Path
import yaml
import sys
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pymatgen.core.structure import Structure, Element
import torch
sys.path.append("..")
from ai4mat.models.megnet_pytorch.megnet_on_structures import MEGNetOnStructures
from ai4mat.data.data import get_unit_cell, StorageResolver, read_defects_descriptions
from ai4mat.common.sparse_representation import get_sparse_defect, SINGLE_ENENRGY_COLUMN
from ai4mat.common.random_defect_generation import generate_structure_with_random_defects
from ai4mat.common.eos import EOS

In [2]:
training_experiment = "combined_mixed_all_train"
model_names = {
    "formation_energy_per_site": "megnet_pytorch/sparse/05-12-2022_19-50-53/d6b7ce45",
    "homo_lumo_gap_min": "megnet_pytorch/sparse/05-12-2022_19-50-53/831cc496"}

predictors = dict()
for target, trial_name in model_names.items():
    with open(StorageResolver()["trials"] / f"{trial_name}.yaml", "r") as f:
        config = yaml.safe_load(f)
    predictors[target] = MEGNetOnStructures(config['model_params'])
    predictors[target].load(StorageResolver()["checkpoints"] / training_experiment / target / trial_name / "0.pth",
                            map_location='cpu')

experiment_path = StorageResolver()["experiments"].joinpath(training_experiment)
with open(Path(experiment_path, "config.yaml")) as experiment_file:
    experiment_config = yaml.safe_load(experiment_file)
# We don't check whether the data split
# It's the caller's responsibility to ensure that all the datasets are used for training
training_datasets = experiment_config["datasets"]

In [3]:
unit_cells = dict()
defects_list = []
for dataset in tqdm(training_datasets):
    defects = read_defects_descriptions(StorageResolver()["csv_cif"] / dataset)
    materials = defects.base.unique()
    assert len(materials) == 1
    material = materials[0]
    unit_cell = get_unit_cell(StorageResolver()["csv_cif"] / dataset, materials)[material]
    cell = defects.cell.unique()
    assert len(cell) == 1
    cell = cell[0]
    # We have two MoS2/WSe2 unit cells with slightly different height,
    # both are valid and there is no need to include them here
    unit_cells[material] = EOS().get_augmented_struct(unit_cell)
    defects_list.append(defects)
defects_pd = pd.concat(defects_list, axis=0)
if 'pbc' in defects_pd.columns:
    defects_pd = defects_pd.drop(columns=['pbc'])

  0%|          | 0/8 [00:00<?, ?it/s]



In [4]:
from itertools import chain
from collections import namedtuple
SubstitutionDefect = namedtuple('substitution_defect', ['type', 'from_', 'to'])
VacancyDefect = namedtuple('vacancy_defect', ['type', 'element'])
def to_named_tuple(dict_):
    if dict_['type'] == 'substitution':
        return SubstitutionDefect(dict_['type'], dict_['from'], dict_['to'])
    elif dict_['type'] == 'vacancy':
        return VacancyDefect(**dict_)
    else:
        raise ValueError(f"Unknown defect type {dict_['type']}")

In [5]:
available_defects = defects_pd.groupby(['base', 'cell']).apply(lambda x: set(map(to_named_tuple, chain(*x.defects))))

In [6]:
import ipywidgets as widgets

In [7]:
material_labels = list(map(lambda x: f"{x[0]}, {x[1][0]}x{x[1][1]} supercell", available_defects.index))

In [9]:
def get_label(defect_tuple):
    if defect_tuple.type == 'substitution':
        return f"{defect_tuple.from_} -> {defect_tuple.to}"
    elif defect_tuple.type == 'vacancy':
        return f"{defect_tuple.element} vacancy"
    else:
        raise ValueError(f"Unknown defect type {defect_tuple.type}")

In [87]:
from IPython.display import clear_output
base_selection = widgets.RadioButtons(options=zip(material_labels, available_defects.index), description='Base material')
total_structures_selection = widgets.IntSlider(min=0, max=100, step=1, value=1, description="Total structures to generate")
total_defects_selection = widgets.IntSlider(min=0, max=15, step=1, value=1, description="Total defects")
max_defect_counts_selection = dict()
def prepare_defect_sliders(base_material):
    max_defect_counts_selection.clear()
    for defect in available_defects[base_material]:
        max_defect_counts_selection[defect] = widgets.IntSlider(min=0, max=15, step=1, value=1, description=get_label(defect))
    controls.children = [base_selection,
                            total_structures_selection,
                            total_defects_selection,
                            widgets.Label("Max counts for each defect type:"),
                            *max_defect_counts_selection.values()]
def select_defects(change):
    if change['type'] == 'change' and change['name'] == 'value':
        prepare_defect_sliders(change['owner'].value)
base_selection.observe(select_defects)
controls = widgets.VBox([base_selection])
prepare_defect_sliders(base_selection.value)
display(controls)

VBox(children=(RadioButtons(description='Base material', options=(('BN, 8x8 supercell', ('BN', (8, 8, 1))), ('…

In [81]:
from collections import defaultdict
max_defect_counts = defaultdict(dict)
for defect, count_widget in max_defect_counts_selection.items():
    if defect.type == 'substitution':
        max_defect_counts[defect.from_][defect.to] = count_widget.value
    else:
        max_defect_counts[defect.element]["Vacancy"] = count_widget.value

In [82]:
from tqdm.auto import trange
rng = np.random.default_rng(42)
reference_supercell = unit_cells[base_selection.value[0]].copy()
reference_supercell.make_supercell(base_selection.value[1])
structures = []
for i in trange(total_structures_selection.value):
    structures.append(generate_structure_with_random_defects(total_defects_selection.value, max_defect_counts, reference_supercell, rng, False))

  0%|          | 0/45 [00:00<?, ?it/s]

In [83]:
elements = set(map(lambda x: x.specie, reference_supercell))
for defect in available_defects[base_selection.value]:
    if defect.type == 'substitution':
        elements.add(Element(defect.from_))
        elements.add(Element(defect.to))
    else:
        elements.add(Element(defect.element))
single_atom_energies_dummy = pd.DataFrame(data=np.zeros((len(elements),1)), columns=[SINGLE_ENENRGY_COLUMN],
                                         index=np.fromiter(elements, dtype=Element, count=len(elements)))
sparse_structures = []
for structure in tqdm(structures):
    sparse_structures.append(get_sparse_defect(structure, unit_cells[base_selection.value[0]], base_selection.value[1], single_atom_energies_dummy)[0])

  0%|          | 0/45 [00:00<?, ?it/s]

In [84]:
sparse_structures[0]

Structure Summary
Lattice
    abc : 19.799508 27.605556 20.0
 angles : 90.0 90.0 90.0
 volume : 10931.528537328959
      A : 19.799508 0.0 1.2123702048446206e-15
      B : -1.6903527897041506e-15 27.605556 1.6903527897041506e-15
      C : 0.0 0.0 20.0
    pbc : True True True
PeriodicSite: X0+ (-0.0000, 16.5081, 4.1800) [0.0000, 0.5980, 0.2090]
PeriodicSite: N (9.8998, 11.9098, 4.1825) [0.5000, 0.4314, 0.2091]
PeriodicSite: N (9.8998, 11.0949, 6.2887) [0.5000, 0.4019, 0.3144]
PeriodicSite: X0+ (4.9499, 0.4141, 6.2800) [0.2500, 0.0150, 0.3140]
PeriodicSite: N (14.8496, 9.6093, 6.2887) [0.7500, 0.3481, 0.3144]
PeriodicSite: N (14.8496, 23.4121, 6.2887) [0.7500, 0.8481, 0.3144]
PeriodicSite: X0+ (1.6434, 4.1960, 4.1800) [0.0830, 0.1520, 0.2090]
PeriodicSite: N (11.5497, 8.7944, 4.1825) [0.5833, 0.3186, 0.2091]

In [85]:
for target, predictor in predictors.items():
    print(target)
    print(predictor.predict_structures(sparse_structures))

formation_energy_per_site


  0%|          | 0/45 [00:00<?, ?it/s]



[[ -9.45735  ]
 [ -7.1647673]
 [ -5.304395 ]
 [ -9.379557 ]
 [ -9.531681 ]
 [ -5.1471143]
 [ -7.1961794]
 [ -9.436419 ]
 [ -9.612769 ]
 [ -7.259413 ]
 [ -5.28455  ]
 [-11.56819  ]
 [ -3.2931852]
 [ -0.7973844]
 [ -7.283188 ]
 [ -9.352597 ]
 [ -7.3912344]
 [ -3.1820154]
 [ -9.339229 ]
 [-11.493329 ]
 [ -7.2290707]
 [ -9.43199  ]
 [ -9.460617 ]
 [ -3.045741 ]
 [-11.735934 ]
 [ -7.3334246]
 [ -7.20026  ]
 [ -9.43051  ]
 [ -9.422464 ]
 [ -7.4376698]
 [-11.467733 ]
 [ -9.395529 ]
 [-11.54327  ]
 [ -7.2522736]
 [ -9.501735 ]
 [ -9.353995 ]
 [-11.560653 ]
 [ -9.432962 ]
 [ -9.386898 ]
 [ -7.4142904]
 [ -7.4720917]
 [ -7.4508877]
 [ -5.0502334]
 [ -7.223666 ]
 [ -5.2595005]]
homo_lumo_gap_min


  0%|          | 0/45 [00:00<?, ?it/s]

[[0.6124114 ]
 [0.4156559 ]
 [0.31034902]
 [0.71984357]
 [0.78188115]
 [0.36402264]
 [0.8198599 ]
 [0.6532231 ]
 [0.6980879 ]
 [0.56682277]
 [0.6357219 ]
 [0.8310488 ]
 [0.3923952 ]
 [0.3038568 ]
 [0.6742657 ]
 [0.36102802]
 [0.6294773 ]
 [0.4698363 ]
 [0.40958408]
 [0.5824387 ]
 [0.6034634 ]
 [0.79876274]
 [0.42780495]
 [0.48113972]
 [0.76717037]
 [0.6436002 ]
 [0.66482496]
 [0.4774354 ]
 [0.519976  ]
 [0.69543743]
 [0.79816467]
 [0.6611786 ]
 [0.7785849 ]
 [0.71011573]
 [0.67526966]
 [0.3014007 ]
 [0.73278165]
 [0.47251245]
 [0.4178466 ]
 [0.7560054 ]
 [0.63701427]
 [0.5684327 ]
 [0.6332985 ]
 [0.44464138]
 [0.351126  ]]
