In [None]:
import sys

if "google.colab" in sys.modules:
    # Mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')
    original_data = '/content/drive/My Drive/original_dataset'
    final_data = '/content/drive/My Drive/Final_Dataset'

    # Install required packages
    !pip install pymatgen torch_geometric mp_api
    import torch
    from torch_geometric.data import Data

else:
    original_data = "original_dataset"
    final_data = "Final_Dataset"

In [None]:
import json
import os

import numpy as np
import pandas as pd
from mp_api.client import MPRester
from pymatgen.analysis.local_env import MinimumDistanceNN
from pymatgen.core import Structure, PeriodicSite, DummySpecie, Composition
from pymatgen.core.periodic_table import Element

import config


API_KEY = config.API_KEY # Materials Project API key, read from the imported `config` module

## Create a defects cloud

In [None]:
# For testing: load one defective structure together with its reference structure

# Reference structure
ref_file_path = f"{final_data}/ref_cifs/high_GaSe.cif"
reference_structure = Structure.from_file(ref_file_path)

# Defective structure
defective_file_path = f"{original_data}/high_GaSe/cifs/GaSe_Ga72Se69S1_90b1b09f-acf0-46e6-8173-a2e71c884054.cif"
defective_structure = Structure.from_file(defective_file_path)

In [None]:
def struct_to_dict(structure):
    """Index the sites of ``structure`` by their rounded fractional coordinates.

    Coordinates are rounded to 3 decimals and then wrapped into ``[0, 1)``, so
    periodic images of the same position (e.g. ``0.9996`` and ``0.0004`` along
    one axis) map to the same key instead of looking like different sites.

    Args:
        structure: Object exposing ``frac_coords`` (one row of fractional
            coordinates per site) and ``sites`` in the same order.

    Returns:
        dict: ``{(a, b, c): site}`` keyed by the rounded, wrapped coordinates.
        If several sites share a key, the last one wins.
    """
    rounded_coords = np.round(structure.frac_coords, 3)
    # Wrap into the unit cell; the second round removes the float noise the
    # modulo introduces for negative inputs (e.g. -0.001 -> 0.999).
    wrapped_coords = np.round(np.mod(rounded_coords, 1.0), 3)
    return {tuple(coord): site for coord, site in zip(wrapped_coords, structure.sites)}

# Functions to get formation energy from Materials Project
def get_formation(element, API_KEY):
    """Return the mean ``energy_per_atom`` of the single-element entries for ``element``.

    Queries the Materials Project summary endpoint for materials that contain
    only ``element`` and averages their ``energy_per_atom`` values.

    Args:
        element: Element symbol passed to the search (e.g. ``"Ga"``).
        API_KEY: Materials Project API key used to open the ``MPRester`` session.

    Returns:
        float: Average ``energy_per_atom`` over the returned entries.

    Raises:
        ValueError: If no returned entry carries an ``energy_per_atom`` value.
    """
    with MPRester(API_KEY) as mpr:
        results = mpr.materials.summary.search(
            elements=[element],
            num_elements=1,
            fields=["energy_per_atom"],
        )
        # Skip entries without a value; np.mean cannot average over None
        forms_list = [
            result.energy_per_atom
            for result in results
            if result.energy_per_atom is not None
        ]

    if not forms_list:
        # np.mean([]) would only warn and return NaN, which callers would then
        # use (and cache) as if it were a real energy.
        raise ValueError(f"No energy_per_atom data returned for element {element!r}")

    return float(np.mean(forms_list))
    

def get_from_json(element, API_KEY, cache_path="./test.json"):
    """Return the energy value for ``element``, using a JSON file as a cache.

    The cache is a single JSON object mapping element symbols to values. On a
    cache miss the value is fetched with ``get_formation`` and the whole cache
    is rewritten, so the file always holds exactly one valid JSON document.

    Args:
        element: Element symbol to look up.
        API_KEY: Materials Project API key, only used on a cache miss.
        cache_path: Location of the JSON cache file. A missing or unreadable
            file is treated as an empty cache.

    Returns:
        The cached or freshly fetched value for ``element``.

    Raises:
        Any exception raised by ``get_formation`` on a cache miss; the cache
        file is left untouched in that case.
    """
    try:
        with open(cache_path, "r") as cache_file:
            the_dict = json.load(cache_file)
    except (FileNotFoundError, json.JSONDecodeError):
        # No cache yet, or the file holds no valid JSON: start from scratch
        the_dict = {}

    if not isinstance(the_dict, dict):
        # Valid JSON but not an object; it cannot be used as a lookup table
        the_dict = {}

    if element in the_dict:
        return the_dict[element]

    # Fetch before touching the file so a failed request cannot corrupt the cache
    to_return = get_formation(element, API_KEY)
    the_dict[element] = to_return
    with open(cache_path, "w") as cache_file:
        json.dump(the_dict, cache_file)

    return to_return

def fe_site(original, new):
    """Compute the energy term associated with one defect site.

    Args:
        original: Species string of the atom found on the reference site.
        new: Species string of the atom now on the site, or ``0`` for a vacancy.

    Returns:
        The negated value of ``get_from_json(original, API_KEY)``; for a
        substitution the value for ``new`` is added to it.
    """
    removed_term = get_from_json(original, API_KEY) * -1

    if new == 0:
        # Vacancy: only the removed atom contributes
        return removed_term

    # Substitution: add the contribution of the incoming atom
    added_term = get_from_json(new, API_KEY)
    return removed_term + added_term

def get_defects_structure(defective_struct, reference_struct):
    """Build a structure made only of the defect sites of ``defective_struct``.

    Every site of ``reference_struct`` is looked up in ``defective_struct`` by
    its rounded fractional coordinates (see ``struct_to_dict``):

    * found with a different species -> substitution; the defective site is kept;
    * not found -> vacancy; a ``DummySpecie`` site is created at that position
      in the lattice of ``defective_struct``.

    Each defect site gets the properties ``original_an``, ``new_an``,
    ``an_change``, ``vacancy_defect``, ``substitution_defect``,
    ``bonds_broken`` and ``site_fe``.

    Args:
        defective_struct: Structure containing the defects.
        reference_struct: Reference structure the defective one is compared to.

    Returns:
        Structure: The defects-only structure with the properties listed above
        attached to every site.

    Raises:
        ValueError: If no substitution or vacancy is found.
    """
    mindnn = MinimumDistanceNN()
    # struct to dict
    defective_dict = struct_to_dict(defective_struct)
    reference_dict = struct_to_dict(reference_struct)

    # Get lattice of defective structure
    structure_lattice = defective_struct.lattice

    # Defect sites and their properties, kept index-aligned with each other
    defects_list = []
    defects_properties = []

    for ref_coord, ref_site in reference_dict.items():
        # Use the reference coordinates to get the defective site
        def_site = defective_dict.get(ref_coord)

        if def_site is None:
            # The reference site is missing from the defective structure: vacancy
            vacant_site = PeriodicSite(
                species=DummySpecie(),
                coords=ref_coord,
                coords_are_cartesian=False,
                lattice=structure_lattice,
            )
            defects_list.append(vacant_site)
            defects_properties.append({
                "original_an": ref_site.specie.Z,
                "new_an": 0,
                "an_change": 0 - ref_site.specie.Z,
                "vacancy_defect": 1.0,
                "substitution_defect": 0.0,
                # Coordination number of the removed atom in the reference structure
                "bonds_broken": mindnn.get_cn(reference_struct, reference_struct.sites.index(ref_site)),
                "site_fe": fe_site(ref_site.species_string, 0),
            })

        elif ref_site.specie != def_site.specie:
            # Same position, different species: substitution
            defects_list.append(def_site)
            defects_properties.append({
                "original_an": ref_site.specie.Z,
                "new_an": def_site.specie.Z,
                "an_change": def_site.specie.Z - ref_site.specie.Z,
                "vacancy_defect": 0.0,
                "substitution_defect": 1.0,
                "bonds_broken": 0.0,
                "site_fe": fe_site(ref_site.species_string, def_site.species_string),
            })

    if not defects_list:
        raise ValueError(
            "No vacancy or substitution found: the defective structure matches the reference structure."
        )

    # create a defects structure
    defects_struct = Structure.from_sites(defects_list)

    # Attach properties by position. The new structure keeps the order of
    # ``defects_list``, so no site-equality lookup is needed and no site can
    # silently end up without its properties.
    for a_site, site_properties in zip(defects_struct.sites, defects_properties):
        a_site.properties.update(site_properties)

    return defects_struct

# Defects-only structure of the sample defective/reference pair
defects_structure = get_defects_structure(
    defective_struct=defective_structure,
    reference_struct=reference_structure,
)
print(defects_structure)

In [None]:
# Turn the defects structure into crystal graph
def get_c_graph(structure):
    """Convert a defects-only structure into plain-list graph components.

    Args:
        structure: Structure whose sites carry the properties ``bonds_broken``,
            ``original_an``, ``new_an``, ``an_change``, ``vacancy_defect``,
            ``substitution_defect`` and ``site_fe``.

    Returns:
        tuple: ``(nodes, edges, edge_features, the_ids, the_ratios)`` where

        * ``nodes`` holds one feature list per site: the site index followed by
          the properties in the order listed above;
        * ``edges`` is ``[sources, targets]`` over every ordered pair of sites,
          including each site paired with itself;
        * ``edge_features`` holds, per pair, the distance, a 0/1 flag telling
          whether both sites have the same ``an_change``, and the absolute
          difference of their ``site_fe``;
        * ``the_ids`` holds ``Z - 1`` for each symbol of the structure formula
          (``0`` when ``Element`` rejects the symbol);
        * ``the_ratios`` holds each symbol's amount divided by the site count.
    """
    defect_sites = structure.sites

    # Node features: site index followed by the per-site properties
    property_order = (
        "bonds_broken",
        "original_an",
        "new_an",
        "an_change",
        "vacancy_defect",
        "substitution_defect",
        "site_fe",
    )
    nodes = [
        [index] + [site.properties[key] for key in property_order]
        for index, site in enumerate(defect_sites)
    ]

    # Edges: every ordered pair of sites, with the features of that pair
    sources = []
    targets = []
    edge_features = []
    for src, site_a in enumerate(defect_sites):
        for dst, site_b in enumerate(defect_sites):
            sources.append(src)
            targets.append(dst)

            separation = site_a.distance(site_b)
            same_change = int(site_a.properties["an_change"] == site_b.properties["an_change"])
            fe_gap = np.abs(site_a.properties["site_fe"] - site_b.properties["site_fe"])

            edge_features.append([separation, same_change, fe_gap])

    edges = [sources, targets]

    # Global features: element ids and their share of the sites
    site_count = len(defect_sites)
    element_amounts = Composition(structure.formula).get_el_amt_dict()

    the_ids = []
    the_ratios = []
    for symbol, amount in element_amounts.items():
        try:
            element_id = Element(symbol).Z - 1
        except ValueError:
            # Symbol that Element does not accept
            element_id = 0
        the_ids.append(element_id)
        the_ratios.append(amount / site_count)

    return nodes, edges, edge_features, the_ids, the_ratios

# Graph components of the sample defects-only structure, printed one per line
sample_graph = get_c_graph(structure=defects_structure)

for component in sample_graph:
    print(component)

I utilised the **cloud representation of defects** in this project. 

This approach, formulated by Nikita Kazeev, maps the location of every defect and builds a graph representation of the resulting *defects-only structure*.

To do this, you need to have:

1. The **defective structure**. This is the structure **with** defects
2. The **pristine structure**. This is the pure (defect-free) form of the defective structure.

With these crystal structures, one can map the locations and get some attributes of the defect sites using `get_defects_structure(defective_structure, pristine_structure)` to create a **defects-only structure**.

Once you have the defects-only structure, you can get its graphical representation using `get_c_graph(defects_only_structure)`

The crystal graph consists of:
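1. **Nodes:** These represent the defect sites (substituted atoms and vacancies).
2. **Edges:** These connect every pair of defect sites.
3. **Edge attributes:** These describe each pair of sites: the distance between them, whether they have the same atomic-number change, and the difference between their site energies.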
4. **Global features:** These are features for the whole defects-only structure.

## Implement the cloud representation on the whole dataset

In [None]:
# Split the data into training, test and validation subsets
from sklearn.model_selection import train_test_split

comb_df = pd.read_csv(f"{final_data}/combined_dataset.csv")

# First split: 65% training / 35% hold-out, stratified on the 'strata' column
train_set, holdout_set = train_test_split(
    comb_df,
    test_size=0.35,
    stratify=comb_df['strata'],
    random_state=42,
)

# Second split: the hold-out is halved into test and validation sets
# NOTE(review): unlike the first split, this one is not stratified - confirm this is intended
test_set, val_set = train_test_split(holdout_set, test_size=0.5, random_state=42)

In [None]:
# Create graph representation of the structures
def graphy(row):
    """Convert one dataset row into a ``torch_geometric`` ``Data`` graph.

    The defective structure is read from
    ``{original_data}/{dataset_material}/cifs/{_id}.cif`` and its reference from
    ``{final_data}/ref_cifs/{dataset_material}.cif``; both are reduced to a
    defects-only structure and turned into graph tensors.

    Args:
        row: Mapping-like row providing ``"dataset_material"``, ``"_id"`` and
            ``"band_gap_value"``.

    Returns:
        Data: Graph with ``x`` (node features), ``edge_index``, ``edge_attr``,
        the global ``the_ids`` / ``the_ratios`` features and the target ``y``
        taken from ``band_gap_value``.
    """
    # Read the keys first: nesting double quotes inside a double-quoted f-string
    # is a SyntaxError on Python versions older than 3.12.
    material = row["dataset_material"]
    structure_id = row["_id"]

    defective = Structure.from_file(f"{original_data}/{material}/cifs/{structure_id}.cif")
    reference = Structure.from_file(f"{final_data}/ref_cifs/{material}.cif")

    defects_only_structure = get_defects_structure(defective, reference)

    nodes, edges, edge_features, ids, ratios = get_c_graph(defects_only_structure)

    target = row["band_gap_value"]

    the_data = Data(
        x=torch.tensor(nodes, dtype=torch.float),
        edge_index=torch.tensor(edges, dtype=torch.long),
        edge_attr=torch.tensor(edge_features, dtype=torch.float),
        # Graph-level features get a leading dimension of size 1
        the_ids=torch.tensor(ids, dtype=torch.long).unsqueeze(0),
        the_ratios=torch.tensor(ratios, dtype=torch.float).unsqueeze(0),
        y=torch.tensor(target, dtype=torch.float).unsqueeze(0),
    )
    return the_data


In [None]:

# Turn each dataset into graph data and save it
def _build_and_save_graphs(split_df, file_name):
    """Convert every row of ``split_df`` with ``graphy`` and save the list of graphs.

    Args:
        split_df: DataFrame whose rows are passed one by one to ``graphy``.
        file_name: File name written inside ``{final_data}/combined``.

    Returns:
        list: The graphs built for ``split_df``, in row order.
    """
    graphs = split_df.apply(graphy, axis=1).tolist()
    torch.save(graphs, f"{final_data}/combined/{file_name}")
    return graphs


# Create the output folder up front so a missing directory cannot make the
# save fail after the graphs have already been built.
os.makedirs(f"{final_data}/combined", exist_ok=True)

training = _build_and_save_graphs(train_set, "training.pt")
validating = _build_and_save_graphs(val_set, "validating.pt")
testing = _build_and_save_graphs(test_set, "testing.pt")