
In [None]:
from pathlib import Path
import ast
import numpy as np
import pandas as pd
from pymatgen.core import Structure, PeriodicSite, DummySpecie
from pymatgen.analysis.local_env import MinimumDistanceNN


In [None]:
# Get combined file
# To do this you need to handle the dataset size issue...


## The workflow

In [None]:
# The combined csv with all the high and low datasets will be named: combined.csv
defects_combined_path = Path('dataset/combined/combined.csv')
defects_df = pd.read_csv(defects_combined_path)


In [None]:
def props_targets(data_set, columns):
    """Split a DataFrame into feature columns and target columns."""
    x_props = data_set.drop(columns=columns)
    y_targets = data_set[columns]
    return x_props, y_targets

## Attributes
Which attributes are the targets and which are the features?
1. HOMO value
2. LUMO value
3. Formation energy*
4. Formation energy per site*

The target values (formation energy, HOMO, LUMO, etc.) are what the model needs to learn to predict.

If the targets are also included among the input features, the model might “cheat” by reading them directly instead of learning from the structure and the other features.
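
A quick way to guard against this kind of leakage is to check that no target name is still present among the feature columns. The sketch below is only illustrative: the helper `leaked_targets` is a hypothetical name, and the column lists are examples built from names used in this notebook.

```python
def leaked_targets(feature_columns, target_columns):
    """Return the target names that still appear among the feature columns."""
    targets = set(target_columns)
    return [name for name in feature_columns if name in targets]


# Example column lists (illustrative, not read from the real dataset)
feature_columns = ["energy", "fermi_level", "formation_energy", "total_mag"]
target_columns = ["formation_energy", "norm_homo", "norm_lumo"]

leaked = leaked_targets(feature_columns, target_columns)
print(leaked)  # any name printed here would let the model "cheat"
```

An empty list means none of the targets can be read directly from the features. Columns derived from a target (for example a per-site version of it) would not be caught by a name check like this one.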

x = node features (matrix, one row per node)

edge_index = connectivity (which node connects to which node)

edge_attr = edge features (like distance)

y = target (band gap)

u = global features (formation energy, total_mag, etc.)
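
The five fields above can be collected in one small container. This is a minimal sketch in plain Python, assuming the layout used by PyTorch Geometric, where `edge_index` has two rows (source nodes and target nodes). The class name `DefectGraph` and all numbers in the toy graph are placeholders, not values from the dataset.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class DefectGraph:
    x: List[List[float]]          # node features, one row per node
    edge_index: List[List[int]]   # two rows: source nodes, target nodes
    edge_attr: List[List[float]]  # one row of features per edge (e.g. distance)
    y: List[float]                # target value(s) for the whole graph
    u: List[float]                # global (graph-level) features

    def validate(self) -> None:
        """Raise ValueError if the fields are inconsistent with each other."""
        num_nodes = len(self.x)
        sources, targets = self.edge_index
        if len(sources) != len(targets):
            raise ValueError("edge_index rows must have the same length")
        if len(self.edge_attr) != len(sources):
            raise ValueError("edge_attr needs one row per edge")
        if any(not 0 <= i < num_nodes for i in sources + targets):
            raise ValueError("edge_index refers to a node that does not exist")


# Toy graph with three nodes; each undirected edge is stored in both directions
graph = DefectGraph(
    x=[[42.0, 0.0], [16.0, 1.0], [16.0, 1.0]],
    edge_index=[[0, 1, 0, 2], [1, 0, 2, 0]],
    edge_attr=[[2.4], [2.4], [3.2], [3.2]],
    y=[1.7],
    u=[0.0],
)
graph.validate()
```

Checking the shapes early makes it easier to spot a mismatch between the node list and the edge list before the graph is handed to a model.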

I think the relationship the model should learn is this: if an atom of species x is missing from the lattice (a vacancy defect) or is replaced by species y (a substitutional defect), the energy change at that site has a specific value (the formation energy). When this value is combined with the formation energies of other defect sites at distances z, the result is a positive or negative change in the band gap of the host material.
So the attributes will be:
1. Defect type (vacancy or substitutional).
2. Atomic number change at the site (from this you get the formation energy).
3. The relation between formation energy and distance.
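
The first two attributes and the distances between defect sites can be sketched as follows. This is only an illustration under stated assumptions: a vacancy is encoded as atomic number 0, positions are Cartesian coordinates, and periodic boundary conditions are ignored. The helper names `describe_defect` and `pairwise_distances` are hypothetical.

```python
import math
from itertools import combinations


def describe_defect(z_before, z_after):
    """Return the defect type and atomic number change for one site.

    Assumption: z_after == 0 encodes a vacancy (the atom was removed).
    """
    defect_type = "vacancy" if z_after == 0 else "substitution"
    return {"type": defect_type, "delta_z": z_after - z_before}


def pairwise_distances(positions):
    """Return the distance between every pair of defect sites.

    Assumption: plain Euclidean distance, with no periodic images.
    """
    return {
        (i, j): math.dist(positions[i], positions[j])
        for i, j in combinations(range(len(positions)), 2)
    }


# Mo (Z=42) replaced by W (Z=74), and an S atom (Z=16) removed
substitution = describe_defect(42, 74)
vacancy = describe_defect(16, 0)

# Placeholder positions for the two defect sites
distances = pairwise_distances([(0.0, 0.0, 0.0), (3.0, 4.0, 0.0)])
print(substitution, vacancy, distances)
```

For a real supercell the distance should account for periodic images, which a structure library can provide.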

In [None]:
# Separate the target columns from the attribute columns
target_columns = ["formation_energy", "norm_homo", "norm_lumo"]
defects_df_x, defects_df_y = props_targets(defects_df, target_columns)

In [None]:
# Split the data into train, validation and test sets
from sklearn.model_selection import train_test_split

# defects_df_x and defects_df_y hold the features and the targets.
# First hold out 20% of the data as the test set
train_df_x, test_df_x, train_df_y, test_df_y = train_test_split(
    defects_df_x, defects_df_y,
    test_size=0.2,
    stratify=defects_df_x["dataset_material"],
    random_state=42)

# Then take 15% of the remaining 80% (12% of the total) for validation
train_df_x, val_df_x, train_df_y, val_df_y = train_test_split(
    train_df_x, train_df_y,
    test_size=0.15,
    stratify=train_df_x["dataset_material"],
    random_state=42)
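
Because the second split is applied to what is left after the first, the final proportions are not 0.15 and 0.2 of the whole dataset. The arithmetic can be checked with a short sketch; `split_fractions` is a hypothetical helper, and the exact row counts produced by scikit-learn may differ slightly because of rounding.

```python
def split_fractions(test_size, val_size):
    """Return (train, val, test) fractions of the full dataset.

    test_size is taken from the full data; val_size is taken from
    whatever remains after the test set has been removed.
    """
    remaining = 1.0 - test_size
    val = remaining * val_size
    train = remaining - val
    return train, val, test_size


train_frac, val_frac, test_frac = split_fractions(test_size=0.2, val_size=0.15)
print(f"train={train_frac:.2f} val={val_frac:.2f} test={test_frac:.2f}")
```

With the values used in this notebook the split works out to roughly 68% train, 12% validation and 20% test.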



In [None]:

# Considering the train_df
# Represent the data as a graph
for index, row in train_df_x.iterrows():
    # Get attributes defining the structures
    cif_id = row["_id"]
    dataset_type = row["dataset_material"]
    the_material = row["base"]

    # Get defective structure
    defective_struct_path = Path(f'dataset/{dataset_type}/cifs/{cif_id}.cif')
    defective_structure = Structure.from_file(defective_struct_path)

    # Get reference structure
    ref_file_path = Path(f'dataset/{dataset_type}/{the_material}.cif')
    ref_unit_cell = Structure.from_file(ref_file_path)

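    # Get cell size and build the reference supercell.
    # make_supercell modifies the structure in place, so expand a copy
    # instead of relying on its return value.
    cell_size = list(ast.literal_eval(row["cell"]))
    reference_structure = ref_unit_cell.copy()
    reference_structure.make_supercell(cell_size)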

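    # Get the defects structure, then its nodes and edges.
    # get_defects_structure and get_nodes_edges are not defined in this section
    # and must be defined before this cell runs.
    defects_structure = get_defects_structure(defective_structure, reference_structure)
    part_nodes, the_edges, the_edge_features = get_nodes_edges(defects_structure)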

    # Add global features to the nodes.
    # "formation_energy" is left out: it is a target column that was already
    # dropped from train_df_x, so row["formation_energy"] would raise a KeyError.
    global_columns = ["energy", "fermi_level", "total_mag", "base", "vacancy_sites",
                    "substituiton_sites", "formation_energy_per_site",
                    "energy_per_atom", "E_1"]

    global_vals = [row[n] for n in global_columns]

    # Append the same global values to every node's feature list
    the_nodes = [sub_list + global_vals for sub_list in part_nodes]
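
The step that compares the defective structure with the reference supercell is not shown in this section. One possible way to do it is sketched below in plain Python. This is not the notebook's `get_defects_structure`; it is a hypothetical stand-in that assumes both structures are given as lists of `(symbol, fractional_coordinates)` and that matching sites lie within a small tolerance expressed in fractional units.

```python
def frac_distance(a, b):
    """Distance between two fractional coordinates with periodic wrap-around."""
    total = 0.0
    for p, q in zip(a, b):
        d = abs(p - q) % 1.0
        d = min(d, 1.0 - d)  # use the nearest periodic image
        total += d * d
    return total ** 0.5


def find_defects(reference_sites, defective_sites, tol=0.05):
    """List the reference sites that are missing or changed in the defective cell."""
    defects = []
    for ref_symbol, ref_coords in reference_sites:
        match = None
        for symbol, coords in defective_sites:
            if frac_distance(ref_coords, coords) <= tol:
                match = symbol
                break
        if match is None:
            defects.append(("vacancy", ref_symbol, None, ref_coords))
        elif match != ref_symbol:
            defects.append(("substitution", ref_symbol, match, ref_coords))
    return defects


# Toy cell: Mo replaced by W, and the upper S atom removed (placeholder coordinates)
reference = [("Mo", (0.0, 0.0, 0.5)), ("S", (1/3, 2/3, 0.4)), ("S", (1/3, 2/3, 0.6))]
defective = [("W", (0.0, 0.0, 0.5)), ("S", (1/3, 2/3, 0.4))]

defects = find_defects(reference, defective)
print(defects)
```

A tolerance in fractional units is a simplification: for strongly relaxed structures or non-cubic cells, a Cartesian distance cutoff would be a safer choice.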
