# Create an interface with ZSL and relax it with ALIGNN

Use the Zur and McGill superlattice (ZSL) matching [algorithm](https://doi.org/10.1063/1.3330840) to create interfaces between two materials, using the Pymatgen [implementation](https://pymatgen.org/pymatgen.analysis.interfaces.html#pymatgen.analysis.interfaces.zsl).

<h2 style="color:green">Usage</h2>

1. Drop the materials files into the "uploads" folder in the JupyterLab file browser
1. Set Input Parameters (e.g. `MILLER_INDICES`, `THICKNESS`, `MAX_AREA`) below or use the default values
1. Click "Run" > "Run All" to run all cells
1. Wait for the run to complete (depending on the area, it can take 1-2 min or more). Scroll down to view cell results.
1. Review the strain plot and modify its parameters as needed

## Methodology

The following happens in the script below:

1. Create slabs for each input material. The materials data is passed in from and back to the web application according to this description (TBA).
   We assume that two input materials are either in bulk form (e.g. Ni crystal) or layered (e.g. graphene). 
   
   We construct the interface along the Z-axis. The material at the bottom of the interface is referred to as the "**substrate**", and the material at the top as the "**layer**".

2. Perform strain matching on the slabs to extract the supercell dimensions. The algorithm has a set of parameters, such as the maximum area considered, that can be configured by editing the cells below.

3. When the strain matching is finished, the interface with the lowest strain (and the smallest number of atoms) is selected. We create the corresponding supercells and place them at a specified distance from each other (note no shift is performed currently).
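
The strain-matching step in item 2 starts by looking for integer multiples of the two surface cell areas that nearly coincide. The following is a minimal sketch of that idea in plain Python, not the Pymatgen implementation; the function name `matching_area_multiples` and the cell areas are illustrative assumptions.

```python
def matching_area_multiples(film_area, substrate_area, max_area, max_area_ratio_tol):
    """Return (n_film, n_substrate) pairs whose supercell areas nearly coincide.

    A pair is kept when both supercell areas stay within max_area and
    film_area / substrate_area is within max_area_ratio_tol of
    n_substrate / n_film.
    """
    pairs = []
    for n_film in range(1, int(max_area / film_area) + 1):
        for n_substrate in range(1, int(max_area / substrate_area) + 1):
            mismatch = abs(film_area / substrate_area - n_substrate / n_film)
            if mismatch < max_area_ratio_tol:
                pairs.append((n_film, n_substrate))
    return pairs


# Illustrative surface cell areas in Angstrom^2 (assumed values, not computed from the inputs)
pairs = matching_area_multiples(film_area=5.2, substrate_area=5.4, max_area=50, max_area_ratio_tol=0.09)
print(pairs[:5])
```

Raising `max_area` or the area tolerance admits more candidate pairs, which is why the run time grows with `MAX_AREA`.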


## 1. Set Input Parameters

### 1.1. Select Substrate and Layer from Input Materials

In [None]:
MATERIALS_FILES = [
    "Ni.json",
    "Gr.json",
]

SUBSTRATE_PARAMETERS = {
    "MATERIAL_INDEX": 0,  # the index of the material in the materials_in list
    "MILLER_INDICES": (1, 1, 1),  # the miller indices of the interfacial plane
    "THICKNESS": 6,  # in layers
}

LAYER_PARAMETERS = {
    "MATERIAL_INDEX": 1, # the index of the material in the materials_in list
    "MILLER_INDICES": (0, 0, 1),  # the miller indices of the interfacial plane
    "THICKNESS": 1,  # in layers
}

### 1.2. Set Interface Parameters

Set the distance between the layer and the substrate, and the maximum area to consider when matching.


In [None]:
INTERFACE_PARAMETERS = {
    "DISTANCE_Z": 3.0, # in Angstroms
    "MAX_AREA": 50, # in Angstroms^2
}

### 1.3. Set Algorithm Parameters

In [None]:
ZSL_PARAMETERS = {
    "MAX_AREA": INTERFACE_PARAMETERS["MAX_AREA"],  # The area to consider in Angstrom^2
    "MAX_AREA_TOL": 0.09,  # The area within this tolerance is considered equal
    "MAX_LENGTH_TOL": 0.03,  # supercell lattice vectors lengths within this tolerance are considered equal
    "MAX_ANGLE_TOL": 0.01,  # supercell lattice angles within this tolerance are considered equal
    "STRAIN_TOL": 10e-6,  # strains within this tolerance are considered equal
}
RELAXATION_PARAMETERS = {
    "FMAX": 0.05,
}

## 2. Install Packages

In [None]:
!python --version
!pip list
# !pip install pymatgen==2024.2.8 ase==3.22.1 nbformat==5.9.2 ipykernel==6.29.2 matgl==0.9.2 nglview==3.1.1 alignn==2024.2.4 PyQt5==5.15.10

## 3. Create interfaces

### 3.1. Extract Interfaces and Terminations

Extract all possible layer/substrate supercell combinations within the maximum area including different terminations.

In [None]:
import json
from src.pymatgen_coherent_interface_builder import CoherentInterfaceBuilder, ZSLGenerator
from src.utils import to_pymatgen

materials_in = []
FOLDER = "uploads"
for file in MATERIALS_FILES:
    with open(f"{FOLDER}/{file}", "r") as f:
        data = f.read()
        materials_in.append(json.loads(data))
        
if "materials_in" in globals():
    pymatgen_materials = [to_pymatgen(item) for item in materials_in]
for material in pymatgen_materials:
    print(material, "\n")


def create_interfaces(settings):
    print("Creating interfaces...")
    zsl = ZSLGenerator(
        max_area_ratio_tol=settings["ZSL_PARAMETERS"]["MAX_AREA_TOL"],
        max_area=settings["ZSL_PARAMETERS"]["MAX_AREA"],
        max_length_tol=settings["ZSL_PARAMETERS"]["MAX_LENGTH_TOL"],
        max_angle_tol=settings["ZSL_PARAMETERS"]["MAX_ANGLE_TOL"],
    )

    cib = CoherentInterfaceBuilder(
        substrate_structure=pymatgen_materials[settings["SUBSTRATE_PARAMETERS"]["MATERIAL_INDEX"]],
        film_structure=pymatgen_materials[settings["LAYER_PARAMETERS"]["MATERIAL_INDEX"]],
        substrate_miller=settings["SUBSTRATE_PARAMETERS"]["MILLER_INDICES"],
        film_miller=settings["LAYER_PARAMETERS"]["MILLER_INDICES"],
        zslgen=zsl,
        strain_tol=settings["ZSL_PARAMETERS"]["STRAIN_TOL"],
    )

    # Find terminations
    cib._find_terminations()
    terminations = cib.terminations

    # Create interfaces for each termination
    interfaces = {}
    for termination in terminations:
        interfaces[termination] = []
        for interface in cib.get_interfaces(
            termination,
            gap=settings["INTERFACE_PARAMETERS"]["DISTANCE_Z"],
            film_thickness=settings["LAYER_PARAMETERS"]["THICKNESS"],
            substrate_thickness=settings["SUBSTRATE_PARAMETERS"]["THICKNESS"],
            in_layers=True,
        ):
            # Wrap atoms to unit cell
            interface["interface"].make_supercell((1,1,1), to_unit_cell=True)
            interfaces[termination].append(interface)
    return interfaces, terminations


interfaces, terminations = create_interfaces(
    settings={
        "SUBSTRATE_PARAMETERS": SUBSTRATE_PARAMETERS,
        "LAYER_PARAMETERS": LAYER_PARAMETERS,
        "ZSL_PARAMETERS": ZSL_PARAMETERS,
        "INTERFACE_PARAMETERS": INTERFACE_PARAMETERS,
    }
)

### 3.2. Print out the interfaces and terminations

In [None]:
print(f'Found {len(terminations)} terminations')
for termination in terminations:
    print(f"Found {len(interfaces[termination])} interfaces for", termination, "termination")

## 4. Sort interfaces by strain

### 4.1. Sort all interfaces

In [None]:
# Could be "strain", "von_mises_strain", "mean_abs_strain"
strain_mode = "mean_abs_strain"

# Sort interfaces by the specified strain mode and number of sites
def sort_interfaces(interfaces, terminations):
    sorted_interfaces = {}
    for termination in terminations:
        sorted_interfaces[termination] = sorted(
            interfaces[termination], key=lambda x: (x[strain_mode], x["interface"].num_sites)
        )
    return sorted_interfaces


sorted_interfaces = sort_interfaces(interfaces, terminations)

### 4.2. Print out interfaces with lowest strain for each termination

In [None]:
for termination in terminations:
    print(f"Interface with lowest strain for termination {termination} (index 0):")
    first_interface = sorted_interfaces[termination][0]  # use the sorted list so that index 0 has the lowest strain
    print("    strain:", first_interface[strain_mode] * 100, "%")
    print("    number of atoms:", first_interface["interface"].num_sites)

## 5. Plot the results

Plot the number of atoms vs strain. Adjust the parameters as needed.
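
Several interfaces can share the same strain and atom count, so they would overlap as a single point on the plot. A minimal sketch of grouping such duplicates before plotting, using hypothetical records rather than real interface data:

```python
from collections import defaultdict

# Hypothetical records: (termination, strain in %, number of atoms)
records = [
    ("termination_a", 1.25, 14),
    ("termination_a", 1.25, 14),
    ("termination_b", 1.25, 14),
    ("termination_b", 0.40, 56),
]

# Group records that would land on the same plot point
points = defaultdict(list)
for index, (termination, strain, num_atoms) in enumerate(records):
    points[(strain, num_atoms)].append((index, termination))

for (strain, num_atoms), members in points.items():
    print(f"strain={strain:.2f}% atoms={num_atoms}: {len(members)} interface(s)")
```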


In [None]:
import plotly.graph_objs as go
from collections import defaultdict

PLOT_SETTINGS = {
    "HEIGHT": 600,
    "X_SCALE": "log",  # or linear
    "Y_SCALE": "log",  # or linear
}


def plot_strain_vs_atoms(sorted_interfaces, terminations, settings):
    # Create a mapping from termination to its index
    termination_to_index = {termination: i for i, termination in enumerate(terminations)}

    grouped_interfaces = defaultdict(list)
    for termination, interfaces in sorted_interfaces.items():
        for index, interface_data in enumerate(interfaces):
            strain_percentage = interface_data[strain_mode] * 100  # plot the same strain measure used for sorting
            num_sites = interface_data["interface"].num_sites
            key = (strain_percentage, num_sites)
            grouped_interfaces[key].append((index, termination))

    data = []
    for (strain, num_sites), indices_and_terminations in grouped_interfaces.items():
        termination_indices = defaultdict(list)
        for index, termination in indices_and_terminations:
            termination_indices[termination].append(index)
        all_indices = [index for indices in termination_indices.values() for index in indices]
        index_range = f"{min(all_indices)}-{max(all_indices)}" if len(all_indices) > 1 else str(min(all_indices))

        hover_text = "<br>-----<br>".join(
             f"Termination: {termination}<br>Termination index: {termination_to_index[termination]}<br>Interfaces Index Range: {index_range}<br>Strain: {strain:.2f}%<br>Atoms: {num_sites}"
            for termination, indices in termination_indices.items()
        )
        trace = go.Scatter(
            x=[strain],
            y=[num_sites],
            text=[hover_text],
            mode="markers",
            hoverinfo="text",
            name=f"Indices: {index_range}",
        )
        data.append(trace)

    layout = go.Layout(
        xaxis=dict(title="Strain (%)", type=settings["X_SCALE"]),
        yaxis=dict(title="Number of atoms", type=settings["Y_SCALE"]),
        hovermode="closest",
        height=settings["HEIGHT"],
        legend_title_text="Interfaces Index Range",
    )
    fig = go.Figure(data=data, layout=layout)
    fig.show()



plot_strain_vs_atoms(sorted_interfaces, terminations, PLOT_SETTINGS)

for i, termination in enumerate(terminations):
    print(f"Termination {i}:", termination)

## 6. Select the interface to relax

### 6.1. Select the interface with the desired termination and strain

The data in `sorted_interfaces` now contains an object with the following structure:

```json
{
    "('C_P6/mmm_2', 'Si_R-3m_1')": [
        { ...interface for ('C_P6/mmm_2', 'Si_R-3m_1') at index 0...},
        { ...interface for ('C_P6/mmm_2', 'Si_R-3m_1') at index 1...},
        ...
    ],
    "<termination at index 1>": [
        { ...interface for 'termination at index 1' at index 0...},
        { ...interface for 'termination at index 1' at index 1...},
        ...
    ]
}
```

First select the index of the termination, then the index of the interface within that termination's list. The list is sorted by strain, so index 0 has the minimum strain.
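
A minimal sketch of this two-level selection on toy data. The `toy_*` names and the flattened records are hypothetical stand-ins; the real records hold a structure object under the `"interface"` key.

```python
# Hypothetical stand-in for the interface records; only the keys used here are included
toy_interfaces = {
    "termination_a": [
        {"mean_abs_strain": 0.020, "num_sites": 30},
        {"mean_abs_strain": 0.005, "num_sites": 52},
        {"mean_abs_strain": 0.005, "num_sites": 26},
    ],
}
toy_terminations = list(toy_interfaces)

# Sort by strain first, then by number of atoms to break ties
toy_sorted = {
    t: sorted(toy_interfaces[t], key=lambda x: (x["mean_abs_strain"], x["num_sites"]))
    for t in toy_terminations
}

# Pick the termination, then the interface within that termination's sorted list
toy_termination_index = 0
toy_interface_index = 0
selected = toy_sorted[toy_terminations[toy_termination_index]][toy_interface_index]
print(selected)
```

With equal strain, the tuple key places the smaller cell first, so index 0 is the lowest-strain interface with the fewest atoms.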

In [None]:
termination_index = 0
interface_index = 0

termination = terminations[termination_index]

interface = sorted_interfaces[termination][interface_index]["interface"]

## 7. Apply relaxation
### 7.1. Apply relaxation to the selected interface
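
`FMAX` sets the stopping criterion for the optimizer: ASE optimizers stop once the largest force on any atom drops below `fmax` (in eV/Å). A minimal sketch of that check in plain Python, with a hypothetical helper name and made-up force vectors:

```python
import math


def is_converged(forces, fmax):
    """Return True when the largest per-atom force norm is below fmax."""
    largest = max(math.sqrt(fx**2 + fy**2 + fz**2) for fx, fy, fz in forces)
    return largest < fmax


# Made-up forces in eV/Angstrom for a three-atom structure
example_forces = [(0.01, 0.00, -0.02), (0.00, 0.03, 0.00), (-0.01, 0.00, 0.04)]
print(is_converged(example_forces, fmax=0.05))
print(is_converged(example_forces, fmax=0.03))
```

A smaller `FMAX` gives a more tightly relaxed structure at the cost of more optimization steps.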

In [None]:
import plotly.graph_objs as go
from IPython.display import display
from plotly.subplots import make_subplots
from src.utils import ase_to_poscar, pymatgen_to_ase
from ase.optimize import BFGS
from alignn.ff.ff import AlignnAtomwiseCalculator as Calculator, default_path, wt1_path

# Set up the calculator
model_path = default_path()
calculator = Calculator(path=model_path)

# Set up the interface for relaxation
ase_interface = pymatgen_to_ase(interface)
ase_interface.set_calculator(calculator)
dyn = BFGS(ase_interface)

# Initialize empty lists to store steps and energies
steps = []
energies = []

# Create a plotly figure widget
fig = make_subplots(rows=1, cols=1, specs=[[{"type": "scatter"}]])
scatter = go.Scatter(x=[], y=[], mode='lines+markers', name='Energy')
fig.add_trace(scatter)
fig.update_layout(title_text='Real-time Optimization Progress', xaxis_title='Step', yaxis_title='Energy (eV)')

# Display figure widget
f = go.FigureWidget(fig)
display(f)

# Define a callback function to update the plot at each step
def plotly_callback():
    step = dyn.nsteps
    energy = ase_interface.get_total_energy()

    # Add the new step and energy to the lists
    steps.append(step)
    energies.append(energy)

    print(f"Step: {step}, Energy: {energy:.4f} eV")

    # Update the figure with the new data
    with f.batch_update():
        f.data[0].x = steps
        f.data[0].y = energies


# Run the relaxation
dyn.attach(plotly_callback, interval=1)
dyn.run(fmax=RELAXATION_PARAMETERS["FMAX"])

# Extract results
ase_original_interface = pymatgen_to_ase(interface)
ase_original_interface.set_calculator(calculator)
ase_final_interface = ase_interface

original_energy = ase_original_interface.get_total_energy() 
relaxed_energy = ase_interface.get_total_energy()

# Print out the final relaxed structure and energy
print('Original structure:\n', ase_to_poscar(ase_original_interface))
print('\nRelaxed structure:\n', ase_to_poscar(ase_final_interface))
print(f"The final energy is {float(relaxed_energy):.3f} eV.")

### 7.2. View structure before and after relaxation

In [None]:
import base64
from ase.io import write
from ase.build import make_supercell
from IPython.display import HTML
import io

def visualize_material_base64(material, title: str, rotation: str = '0x', number_of_repetitions: int = 3):
    """
    Returns an HTML string with a Base64-encoded image for visualization,
    including the name of the file, positioned horizontally.
    """
    # Set the number of unit cell repetitions in the x and y directions
    n = number_of_repetitions
    material_repeat = make_supercell(material, [[n,0,0],[0,n,0],[0,0,1]])
    text = f"{material.symbols} - {title}"

    # Write image to a buffer to display in HTML
    buf = io.BytesIO()
    write(buf, material_repeat, format='png', rotation=rotation)
    buf.seek(0)
    img_str = base64.b64encode(buf.read()).decode('utf-8')
    html_str = f'''
    <div style="display: inline-block; margin: 10px; vertical-align: top;">
        <p>{text}</p>
        <img src="data:image/png;base64,{img_str}" alt="{title}" />
    </div>
    '''
    return html_str

html_original = visualize_material_base64(ase_original_interface, "original", "-90x")
html_relaxed = visualize_material_base64(ase_final_interface, "relaxed", "-90x")

# Display the interfaces before and after relaxation
html_content = f'<div style="display: flex;">{html_original}{html_relaxed}</div>'
display(HTML(html_content))


### 7.3. Calculate the interface energy using ALIGNN
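
The quantity reported here is the difference between the energy of the interface and the summed energies of its two components, normalized by the interface area. A minimal sketch of that arithmetic with made-up energies; the helper name is hypothetical, and the conversion uses 1 eV/Å² ≈ 16.0218 J/m².

```python
EV_PER_ANG2_TO_J_PER_M2 = 16.0218  # 1 eV = 1.60218e-19 J, 1 Angstrom^2 = 1e-20 m^2


def delta_energy_per_area(interface_energy, substrate_energy, layer_energy, area):
    """Return (E_interface - E_substrate - E_layer) / area in eV/Angstrom^2."""
    return (interface_energy - substrate_energy - layer_energy) / area


# Made-up energies (eV) and area (Angstrom^2), for illustration only
delta = delta_energy_per_area(-100.5, -80.0, -20.0, area=5.0)
print(f"{delta:.4f} eV/Ang^2 = {delta * EV_PER_ANG2_TO_J_PER_M2:.4f} J/m^2")
```

A negative value means the combined interface is lower in energy than the separated substrate and layer.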

In [None]:
def filter_atoms_by_tag(atoms, material_index):
    """Filter atoms by their tag, corresponding to the material index."""
    return atoms[atoms.get_tags() == material_index]


def calculate_energy(atoms, calculator):
    """Set calculator for atoms and return their total energy."""
    atoms.set_calculator(calculator)
    return atoms.get_total_energy()


def calculate_delta_energy(total_energy, *component_energies):
    """Calculate the delta energy by subtracting component energies from the total energy."""
    return total_energy - sum(component_energies)


# Filter atoms for original and relaxed interfaces
substrate_original_interface = filter_atoms_by_tag(ase_original_interface, SUBSTRATE_PARAMETERS["MATERIAL_INDEX"])
layer_original_interface = filter_atoms_by_tag(ase_original_interface, LAYER_PARAMETERS["MATERIAL_INDEX"])
substrate_relaxed_interface = filter_atoms_by_tag(ase_final_interface, SUBSTRATE_PARAMETERS["MATERIAL_INDEX"])
layer_relaxed_interface = filter_atoms_by_tag(ase_final_interface, LAYER_PARAMETERS["MATERIAL_INDEX"])

# Calculate energies
original_substrate_energy = calculate_energy(substrate_original_interface, calculator)
original_layer_energy = calculate_energy(layer_original_interface, calculator)
relaxed_substrate_energy = calculate_energy(substrate_relaxed_interface, calculator)
relaxed_layer_energy = calculate_energy(layer_relaxed_interface, calculator)

# Calculate delta energies
delta_original = calculate_delta_energy(original_energy, original_substrate_energy, original_layer_energy)
delta_relaxed = calculate_delta_energy(relaxed_energy, relaxed_substrate_energy, relaxed_layer_energy)

# Calculate area and effective delta per area
area = ase_original_interface.get_volume() / ase_original_interface.cell[2, 2]
number_of_interface_atoms = ase_final_interface.get_global_number_of_atoms()
number_of_substrate_atoms = substrate_relaxed_interface.get_global_number_of_atoms()
number_of_layer_atoms = layer_relaxed_interface.get_global_number_of_atoms()
effective_delta_relaxed = (relaxed_energy/number_of_interface_atoms - (relaxed_substrate_energy/number_of_substrate_atoms + relaxed_layer_energy/number_of_layer_atoms)) / (2 * area)

# Print out the metrics
print(f"Original Substrate energy: {original_substrate_energy:.4f} eV")
print(f"Relaxed Substrate energy: {relaxed_substrate_energy:.4f} eV")
print(f"Original Layer energy: {original_layer_energy:.4f} eV")
print(f"Relaxed Layer energy: {relaxed_layer_energy:.4f} eV")
print("\nDelta between interface energy and sum of component energies")
print(f"Original Delta: {delta_original:.4f} eV")
print(f"Relaxed Delta: {delta_relaxed:.4f} eV")
print(f"Original Delta per area: {delta_original / area:.4f} eV/Ang^2")
print(f"Relaxed Delta per area: {delta_relaxed / area:.4f} eV/Ang^2")
print(f"Relaxed interface energy: {relaxed_energy:.4f} eV")
# 1 eV/Ang^2 = 16.0218 J/m^2, so multiply to convert to J/m^2
print(f"Effective relaxed Delta per area: {effective_delta_relaxed:.4f} eV/Ang^2 ({effective_delta_relaxed * 16.0218:.4f} J/m^2)\n")

# Print out the POSCARs
print("Relaxed interface:\n", ase_to_poscar(ase_final_interface))
print("Relaxed substrate:\n", ase_to_poscar(substrate_relaxed_interface))
print("Relaxed layer:\n", ase_to_poscar(layer_relaxed_interface))

## References

[1] ASE ALIGNN calculator relaxation example: https://github.com/usnistgov/alignn?tab=readme-ov-file#alignnff

[2] Plotly 3D molecular visualization: https://dash.plotly.com/dash-bio/molecule3dviewer