In [None]:
# Clustering: takes the final embeddings generated in notebook 4), clusters them with HDBSCAN, and visualises the
# clusters in 2D using UMAP and/or t-SNE. The training, validation or full dataset can be selected for analysis.
# The remaining cells provide plots and tools for analysing the quality of the results.

In [None]:
# Dimensionality reduction algorithm used for the 2D cluster plots: 't-SNE', 'UMAP' or 'both'.
# Note: the property plots further down read the saved t-SNE coordinates, so they need 't-SNE' or 'both'.
DIMENSIONALITY_REDUCTION_ALGORITHM = 't-SNE'
# Data split to cluster and plot: 'training', 'validation' or 'full'.
DATA_SPLIT = 'validation'

# The supercell folder used to make the final embeddings. Used to read/visualise the structure of each material
INPUT_FOLDER = 'Supercells/2DMatpedia Sublattices 3x3'

# The name of your model used throughout the pipeline (selects the embedding/label files and output names)
MODEL_NAME = 'Test'

# Supercell size N (for an N x N supercell) used to recover unit-cell quantities. Must match INPUT_FOLDER and the model
SUPERCELL_SIZE = 3

# Fail fast on typos: an unrecognised value would otherwise silently skip the cluster plots
# or silently fall back to clustering the full dataset.
assert DIMENSIONALITY_REDUCTION_ALGORITHM in ('t-SNE', 'UMAP', 'both'), DIMENSIONALITY_REDUCTION_ALGORITHM
assert DATA_SPLIT in ('training', 'validation', 'full'), DATA_SPLIT
assert isinstance(SUPERCELL_SIZE, int) and SUPERCELL_SIZE >= 1, SUPERCELL_SIZE

In [None]:
# Embedding Loader: loads the final embeddings and labels generated in 4) and splits them into their training and validation sets

import torch
import numpy as np

torch.set_printoptions(profile="full")

embeddings_path = f'Embeddings/embeddings_{MODEL_NAME}.pt'
labels_path = f'Labels/labels_{MODEL_NAME}.pt'

# Load onto the CPU so the notebook runs without a GPU
embeddings = torch.load(embeddings_path, map_location=torch.device('cpu')).numpy()
labels = torch.load(labels_path, map_location=torch.device('cpu'))

# Every embedding row needs exactly one label, otherwise the shuffled split silently misaligns them
assert len(labels) == len(embeddings), f"{len(labels)} labels for {len(embeddings)} embeddings"

# 80/20 training/validation split sizes
NUM_TRAINING = int(len(embeddings) * 0.8)
NUM_VALIDATE = len(embeddings) - NUM_TRAINING

# Seeded shuffle for reproducibility. A dedicated RandomState(42) produces exactly the same permutation as
# np.random.seed(42) followed by np.random.permutation, so the split is unchanged, but NumPy's global RNG
# state is no longer reset as a side effect for later cells.
fixed_seed = 42
indices = np.random.RandomState(fixed_seed).permutation(len(embeddings))

# Vectorised row selection for the embeddings; labels stay a Python list as before
shuffled_embeddings = embeddings[indices]
shuffled_labels = [labels[i] for i in indices]

training_embeddings = shuffled_embeddings[:NUM_TRAINING]
training_labels = shuffled_labels[:NUM_TRAINING]
validation_embeddings = shuffled_embeddings[NUM_TRAINING:]
validation_labels = shuffled_labels[NUM_TRAINING:]



In [None]:
# Clustering Workflow: clusters the loaded embeddings with HDBSCAN, then displays the clusters
# in a 2D plot using a dimensionality reduction algorithm (t-SNE, UMAP or both).
# Also prints summary statistics, including the average outlier score and the numbers of clustered and noise points.
# Naming: saves the material/cluster labels as 'clustered_labels.npy', the 2D coordinates as 'data_tsne_2d.npy' /
# 'data_umap_2d.npy', and the plots as 't-SNE Clusters - {MODEL_NAME}.png' and 'UMAP Clusters - {MODEL_NAME}.png'


import matplotlib.pyplot as plt
import hdbscan
import numpy as np
from sklearn.manifold import TSNE
import umap.umap_ as umap

def hdbscan_clustering(data, min_samples=3, min_cluster_size=4):
    """
    Cluster `data` with HDBSCAN.

    :param data: array-like of shape (n_samples, n_features), the embeddings to cluster.
    :param min_samples: HDBSCAN `min_samples` (default 3, the previously hard-coded value).
    :param min_cluster_size: HDBSCAN `min_cluster_size` (default 4, the previously hard-coded value).
    :return: (fitted HDBSCAN object, cluster label per sample with -1 for noise, hdb.outlier_scores_)
    """
    # prediction_data=True keeps the extra data HDBSCAN needs for predicting on new points later
    hdb = hdbscan.HDBSCAN(min_samples=min_samples, min_cluster_size=min_cluster_size, prediction_data=True)
    cluster_labels = hdb.fit_predict(data)
    return hdb, cluster_labels, hdb.outlier_scores_

def tsne_plot(data, cluster_labels, perplexity=20, random_state=42):
    """
    Project `data` to 2D with t-SNE, save the coordinates and plot the HDBSCAN clusters.

    Side effects: writes the 2D coordinates to 'data_tsne_2d.npy' (read back by the property-plot cells)
    and the figure to 'Plots and Visualisations/Clustering Plots/t-SNE Clusters - {MODEL_NAME}.png'.

    :param data: array of shape (n_samples, n_features), the embeddings that were clustered.
    :param cluster_labels: cluster id per sample, -1 marking noise.
    :param perplexity: requested t-SNE perplexity (default 20, the previously hard-coded value).
    :param random_state: seed for a reproducible projection.
    """
    import os

    cluster_labels = np.asarray(cluster_labels)

    # scikit-learn rejects perplexity >= n_samples; clamp so small splits still run
    # (no effect once there are at least 21 samples).
    effective_perplexity = min(perplexity, max(len(data) - 1, 1))
    tsne = TSNE(n_components=2, perplexity=effective_perplexity, random_state=random_state, init='pca')
    data_tsne_2d = tsne.fit_transform(data)

    np.save("data_tsne_2d.npy", data_tsne_2d)

    noise = (cluster_labels == -1)
    clustered = ~noise

    plt.figure(figsize=(10, 8))
    if noise.any():
        plt.scatter(data_tsne_2d[noise, 0], data_tsne_2d[noise, 1],
                    c='lightgray', s=10, alpha=0.5, label='Noise')

    plt.scatter(data_tsne_2d[clustered, 0], data_tsne_2d[clustered, 1],
                c=cluster_labels[clustered], cmap='tab20', s=10, alpha=0.7)

    # Write each cluster id at the centroid of its points (noise excluded)
    for label in np.unique(cluster_labels[clustered]):
        centroid = data_tsne_2d[cluster_labels == label].mean(axis=0)
        plt.text(centroid[0], centroid[1], str(label), fontsize=8, fontweight='bold',
                 color='black', ha='center', va='center')

    plt.title('t-SNE Visualization of HDBSCAN Clusters in 2D')
    plt.xlabel('t-SNE 1')
    plt.ylabel('t-SNE 2')
    # Only the noise layer carries a legend label; without it legend() just warns about no labelled artists
    if noise.any():
        plt.legend()

    out_dir = 'Plots and Visualisations/Clustering Plots'
    # savefig does not create missing folders
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f'{out_dir}/t-SNE Clusters - {MODEL_NAME}.png')
    plt.show()

def umap_plot(data, cluster_labels, n_neighbors=20, min_dist=0.5, random_state=42):
    """
    Project `data` to 2D with UMAP, save the coordinates and plot the HDBSCAN clusters.

    Side effects: writes the 2D coordinates to 'data_umap_2d.npy' and the figure to
    'Plots and Visualisations/Clustering Plots/UMAP Clusters - {MODEL_NAME}.png'.

    :param data: array of shape (n_samples, n_features), the embeddings that were clustered.
    :param cluster_labels: cluster id per sample, -1 marking noise.
    :param n_neighbors: UMAP n_neighbors (default 20, the previously hard-coded value).
    :param min_dist: UMAP min_dist (default 0.5, the previously hard-coded value).
    :param random_state: seed for a reproducible projection.
    """
    import os

    cluster_labels = np.asarray(cluster_labels)

    umap_reducer = umap.UMAP(n_neighbors=n_neighbors, n_components=2, min_dist=min_dist,
                             random_state=random_state, init='pca')
    data_umap_2d = umap_reducer.fit_transform(data)

    np.save("data_umap_2d.npy", data_umap_2d)

    noise = (cluster_labels == -1)
    clustered = ~noise

    plt.figure(figsize=(10, 8))
    if noise.any():
        plt.scatter(data_umap_2d[noise, 0], data_umap_2d[noise, 1],
                    c='lightgray', s=10, alpha=0.5, label='Noise')

    plt.scatter(data_umap_2d[clustered, 0], data_umap_2d[clustered, 1],
                c=cluster_labels[clustered], cmap='tab20', s=10, alpha=0.7)

    # Write each cluster id at the centroid of its points (noise excluded)
    for label in np.unique(cluster_labels[clustered]):
        centroid = data_umap_2d[cluster_labels == label].mean(axis=0)
        plt.text(centroid[0], centroid[1], str(label), fontsize=8, fontweight='bold',
                 color='black', ha='center', va='center')

    plt.title('UMAP Visualization of HDBSCAN Clusters in 2D')
    plt.xlabel('UMAP 1')
    plt.ylabel('UMAP 2')
    # Only the noise layer carries a legend label; without it legend() just warns about no labelled artists
    if noise.any():
        plt.legend()

    out_dir = 'Plots and Visualisations/Clustering Plots'
    # savefig does not create missing folders
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f'{out_dir}/UMAP Clusters - {MODEL_NAME}.png')
    plt.show()

def save_clusters_with_labels(cluster_labels, embedding_labels, filename='clustered_labels.npy'):
    """
    Save material labels and cluster ids side by side as a two-column .npy array.

    Column 0 holds the embedding (material) label and column 1 the cluster id. np.column_stack upcasts
    to a common dtype, so when the labels are strings both columns are stored as strings.
    """
    paired = np.column_stack((embedding_labels, cluster_labels))
    np.save(filename, paired)
    print(f"Clusters and labels saved to {filename}")

def full_clustering_workflow(data, embedding_labels, dimensionality_reduction='both'):
    """
    Cluster `data` with HDBSCAN, save the labels, optionally plot in 2D, and print summary statistics.

    :param data: embeddings to cluster, one row per material.
    :param embedding_labels: material labels aligned with the rows of `data`.
    :param dimensionality_reduction: 't-SNE', 'UMAP' or 'both'; any other value skips the plots.
    :return: (fitted HDBSCAN object, cluster label per row with -1 for noise)
    """
    print("Clustering with HDBSCAN...")
    hdb, cluster_labels, outlier_scores = hdbscan_clustering(data)

    # Persist material label + cluster id pairs to disk
    save_clusters_with_labels(cluster_labels, embedding_labels)

    # Run each requested projection, in the order t-SNE then UMAP
    plotters = (
        ('t-SNE', "Visualizing with t-SNE...", tsne_plot),
        ('UMAP', "Visualizing with UMAP...", umap_plot),
    )
    for method, message, plotter in plotters:
        if dimensionality_reduction in (method, 'both'):
            print(message)
            plotter(data, cluster_labels)

    is_noise = (cluster_labels == -1)
    is_clustered = ~is_noise

    print("Total number of data points:", len(cluster_labels))
    print("Number of clustered points:", np.sum(is_clustered))
    print("Number of noise points:", np.sum(is_noise))

    # Points whose outlier score is NaN are left out of the averages
    has_score = ~np.isnan(outlier_scores)
    summaries = (
        (is_clustered, "Average Outlier Score for Clustered Points", "No valid outlier scores for clustered points."),
        (is_noise, "Average Outlier Score for Noise Points", "No valid outlier scores for noise points."),
    )
    for mask, avg_prefix, empty_message in summaries:
        scores = outlier_scores[mask & has_score]
        if scores.size > 0:
            print(f"{avg_prefix}: {np.mean(scores):.4f}")
        else:
            print(empty_message)

    return hdb, cluster_labels

def load_clusters_with_labels(filename='clustered_labels.npy'):
    """
    Load a two-column (material label, cluster id) array as written by save_clusters_with_labels.

    np.column_stack stores both columns as strings when the material labels are strings, so the cluster
    column is converted back to integers here. Without this, comparisons such as `cluster_labels == -1`
    are silently False for every row.

    :param filename: path of the .npy file.
    :return: (embedding_labels as stored, cluster_labels as an int array)
    """
    loaded_array = np.load(filename)
    embedding_labels = loaded_array[:, 0]
    cluster_labels = loaded_array[:, 1].astype(int)
    print(f"Clusters and labels loaded from {filename}")
    return embedding_labels, cluster_labels


# Pick the embeddings/labels for the requested DATA_SPLIT; 'full' (or any other value)
# falls back to the entire dataset.
split_lookup = {
    'training': (training_embeddings, training_labels),
    'validation': (validation_embeddings, validation_labels),
}
data_for_clustering, labels_for_clustering = split_lookup.get(DATA_SPLIT, (embeddings, labels))

# Cluster and plot with the configured dimensionality reduction algorithm
hdb, cluster_labels = full_clustering_workflow(
    data_for_clustering,
    labels_for_clustering,
    dimensionality_reduction=DIMENSIONALITY_REDUCTION_ALGORITHM,
)

# Reload the saved pairs; embedding_labels from here are used by the cluster-inspection cells below
embedding_labels, loaded_cluster_labels = load_clusters_with_labels('clustered_labels.npy')


In [None]:
# 2D visualisation plot: Average lattice vector length
# Computes the average magnitude of the unit-cell lattice vectors a and b and uses it to colour the points on the saved t-SNE axes
# Note: only works with t-SNE (reads 'data_tsne_2d.npy')

import os
from ase.io import read
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

# One value per material in the split that was actually clustered, so the colours line up with the rows
# of 'data_tsne_2d.npy'. Hard-coding validation_labels broke every DATA_SPLIT other than 'validation'.
lattice_lengths = np.zeros(len(labels_for_clustering))

for i, label in enumerate(labels_for_clustering):
    file_path = os.path.join(INPUT_FOLDER, f"supercell_2dm-{label}.xyz")
    atoms = read(file_path)

    # cellpar() -> (a, b, c, alpha, beta, gamma); the file holds the SUPERCELL_SIZE x SUPERCELL_SIZE
    # supercell, so divide by SUPERCELL_SIZE to get unit-cell lengths
    a_val, b_val = atoms.cell.cellpar()[:2]
    lattice_lengths[i] = 0.5 * (a_val + b_val) / SUPERCELL_SIZE


def plot_tsne_by_lattice_length(coords_2d, lattice_lengths,
                               cmap='turbo',
                               title='2D Projection',
                               vmin=1, vmax=15):
    """
    Scatter the 2D embedding coloured (log scale) by average lattice vector length and save the figure.

    :param coords_2d: array of shape (n_samples, 2), e.g. the saved t-SNE coordinates.
    :param lattice_lengths: one value per row of coords_2d.
    :param cmap: matplotlib colormap name.
    :param title: plot title.
    :param vmin: lower limit of the logarithmic colour scale (default matches the previous hard-coded 1).
    :param vmax: upper limit of the logarithmic colour scale (default matches the previous hard-coded 15).
    """
    norm = mcolors.LogNorm(vmin=vmin, vmax=vmax)

    plt.figure(figsize=(10, 8))
    points = plt.scatter(coords_2d[:, 0], coords_2d[:, 1],
                         c=lattice_lengths,
                         cmap=cmap,
                         norm=norm,
                         s=10)
    # Without a colourbar the colours cannot be read back as lengths
    plt.colorbar(points, label='Average lattice vector length')
    plt.title(title)
    # Axis labels consistent with the other t-SNE property plots
    plt.xlabel('t-SNE 1')
    plt.ylabel('t-SNE 2')

    out_dir = "Plots and Visualisations/Property Plots/Lattice Length"
    # savefig does not create missing folders
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f"{out_dir}/{MODEL_NAME}_lattice_length.png")
    plt.show()

# Load the 2D t-SNE coordinates saved by the clustering cell
data_tsne_2d = np.load("data_tsne_2d.npy")
# A stale file (e.g. from another DATA_SPLIT or a UMAP-only run) would otherwise mis-colour points or fail obscurely
assert len(data_tsne_2d) == len(lattice_lengths), (
    f"data_tsne_2d.npy has {len(data_tsne_2d)} rows but {len(lattice_lengths)} materials were processed; "
    "re-run the clustering cell with t-SNE enabled for the current DATA_SPLIT."
)

plot_tsne_by_lattice_length(data_tsne_2d, lattice_lengths,
                            title="t-SNE colored by Average Lattice Vector Length")

In [None]:
# 2D visualisation plot: Number of atoms in unit cell
# Note: only works with t-SNE

import os
from ase.io import read
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

# One value per material in the split that was actually clustered (aligned with 'data_tsne_2d.npy' rows);
# hard-coding validation_labels broke every DATA_SPLIT other than 'validation'.
num_atoms_in_unitcell = np.zeros(len(labels_for_clustering))

for i, label in enumerate(labels_for_clustering):
    file_path = os.path.join(INPUT_FOLDER, f"supercell_2dm-{label}.xyz")
    atoms = read(file_path)
    # The file holds a SUPERCELL_SIZE x SUPERCELL_SIZE supercell, so divide out the number of repeats
    num_atoms_in_unitcell[i] = len(atoms) // (SUPERCELL_SIZE**2)

def plot_tsne_by_unitcell_atoms(coords_2d, num_atoms_in_unitcell,
                               cmap='turbo',
                               title='t-SNE colored by # of Atoms in Unit Cell'):
    """
    Scatter the 2D embedding coloured (log scale, 1 to max) by atoms per unit cell and save the figure.

    :param coords_2d: array of shape (n_samples, 2), e.g. the saved t-SNE coordinates.
    :param num_atoms_in_unitcell: one count per row of coords_2d.
    :param cmap: matplotlib colormap name.
    :param title: plot title.
    """
    counts = np.asarray(num_atoms_in_unitcell)
    norm = mcolors.LogNorm(vmin=1, vmax=counts.max())

    plt.figure(figsize=(10, 8))
    points = plt.scatter(coords_2d[:, 0], coords_2d[:, 1],
                         c=counts,
                         cmap=cmap,
                         norm=norm,
                         s=10)
    # Without a colourbar the colours cannot be read back as atom counts
    plt.colorbar(points, label='Atoms per unit cell')
    plt.title(title)
    plt.xlabel('t-SNE 1')
    plt.ylabel('t-SNE 2')

    out_dir = "Plots and Visualisations/Property Plots/Atoms in Unit Cell"
    # savefig does not create missing folders
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f"{out_dir}/{MODEL_NAME}_unitcell_atoms.png")
    plt.show()

data_tsne_2d = np.load("data_tsne_2d.npy")
# A stale file (e.g. from another DATA_SPLIT or a UMAP-only run) would otherwise mis-colour points or fail obscurely
assert len(data_tsne_2d) == len(num_atoms_in_unitcell), (
    f"data_tsne_2d.npy has {len(data_tsne_2d)} rows but {len(num_atoms_in_unitcell)} materials were processed; "
    "re-run the clustering cell with t-SNE enabled for the current DATA_SPLIT."
)

plot_tsne_by_unitcell_atoms(
    data_tsne_2d,
    num_atoms_in_unitcell,
    title="t-SNE colored by # of Atoms in Single Unit Cell"
)


In [None]:
# 2D visualisation plot: Lattice types (work in progress)
# Note: only works with t-SNE

import os
from ase.io import read
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

# 1. Classification Function
def classify_2D_lattice(a, b, angle_deg, tol=1e-1, angle_tol_90=1.0, angle_tol_hex=2.0):
    """
    Classify a 2D Bravais lattice from the two in-plane cell lengths and the angle between them.

    :param a: length of the first cell vector.
    :param b: length of the second cell vector (same units as a and tol).
    :param angle_deg: angle between the two vectors in degrees (gamma from ASE's cellpar()).
    :param tol: absolute tolerance for treating a and b as equal.
    :param angle_tol_90: tolerance in degrees for a right angle (default 1.0, previously hard-coded).
    :param angle_tol_hex: tolerance in degrees for the hexagonal angle, 120 or 60 (default 2.0, previously hard-coded).
    :return: "Square", "Rectangular", "Hexagonal" or "Oblique".

    Note: the centred-rectangular (rhombic) lattice is not distinguished yet and falls under "Oblique".
    """
    same_lengths = abs(a - b) < tol
    is_right_angle = abs(angle_deg - 90) < angle_tol_90
    # A hexagonal cell can be described with either a 120 deg or a 60 deg angle between equal-length vectors;
    # previously only 120 deg was accepted, so 60 deg hexagonal cells were labelled "Oblique".
    is_hex_angle = min(abs(angle_deg - 120), abs(angle_deg - 60)) < angle_tol_hex

    if is_right_angle:
        return "Square" if same_lengths else "Rectangular"
    if same_lengths and is_hex_angle:
        return "Hexagonal"
    return "Oblique"

# 2. Read Structures & Assign Lattice Types
# One type per material in the split that was actually clustered (aligned with 'data_tsne_2d.npy' rows);
# hard-coding validation_labels broke every DATA_SPLIT other than 'validation'.
lattice_types = []

for label in labels_for_clustering:
    file_path = os.path.join(INPUT_FOLDER, f"supercell_2dm-{label}.xyz")
    atoms = read(file_path)

    # cellpar() -> (a, b, c, alpha, beta, gamma); gamma is the angle between a and b
    a_val, b_val, _, _, _, gamma = atoms.cell.cellpar()

    # Convert supercell lengths to unit-cell lengths before applying the absolute length tolerance
    lattice_types.append(classify_2D_lattice(a_val / SUPERCELL_SIZE, b_val / SUPERCELL_SIZE, gamma))

lattice_types = np.array(lattice_types)


# 3. Plot t-SNE by Lattice Type
def plot_tsne_by_lattice_type(coords_2d, lattice_types, title='t-SNE Colored by 2D Lattice Type'):
    """
    Scatter the 2D embedding with one colour per lattice type and a discrete legend, then save the figure.

    :param coords_2d: array of shape (n_samples, 2), e.g. the saved t-SNE coordinates.
    :param lattice_types: array of lattice-type strings, one per row of coords_2d.
    :param title: plot title.
    """
    lattice_types = np.asarray(lattice_types)
    palette = plt.get_cmap('tab10')

    plt.figure(figsize=(10, 8))
    for i, lat_type in enumerate(np.unique(lattice_types)):
        mask = (lattice_types == lat_type)
        # Wrap the index: calling a ListedColormap with an integer past its length returns the same
        # "over" colour, so extra types would otherwise all share one colour
        plt.scatter(coords_2d[mask, 0],
                    coords_2d[mask, 1],
                    c=[palette(i % palette.N)],
                    s=10,
                    label=lat_type)

    plt.title(title)
    plt.xlabel("t-SNE 1")
    plt.ylabel("t-SNE 2")
    plt.legend()

    out_dir = "Plots and Visualisations/Property Plots/Lattice Type"
    # savefig does not create missing folders
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f"{out_dir}/{MODEL_NAME}_lattice_types.png")
    plt.show()

# 4. Load t-SNE coordinates & Plot
data_tsne_2d = np.load("data_tsne_2d.npy")
# A stale file (e.g. from another DATA_SPLIT or a UMAP-only run) would otherwise mis-colour points or fail obscurely
assert len(data_tsne_2d) == len(lattice_types), (
    f"data_tsne_2d.npy has {len(data_tsne_2d)} rows but {len(lattice_types)} materials were processed; "
    "re-run the clustering cell with t-SNE enabled for the current DATA_SPLIT."
)
plot_tsne_by_lattice_type(data_tsne_2d, lattice_types)


In [None]:
# 2D visualisation plot: Aspect ratios
# Note: only works with t-SNE

import os
from ase.io import read
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

# One value per material in the split that was actually clustered (aligned with 'data_tsne_2d.npy' rows);
# hard-coding validation_labels broke every DATA_SPLIT other than 'validation'.
aspect_ratios = np.zeros(len(labels_for_clustering))

for i, label in enumerate(labels_for_clustering):
    file_path = os.path.join(INPUT_FOLDER, f"supercell_2dm-{label}.xyz")
    atoms = read(file_path)

    # cellpar() -> (a, b, c, alpha, beta, gamma); scale supercell lengths to unit-cell lengths
    a_val, b_val = atoms.cell.cellpar()[:2]
    shorter, longer = sorted((a_val / SUPERCELL_SIZE, b_val / SUPERCELL_SIZE))

    # Ratio >= 1 (longer / shorter). The old guard only covered b == 0, so a zero-length `a` still divided by zero.
    aspect_ratios[i] = longer / shorter if shorter > 0 else 1.0

def plot_tsne_by_aspect_ratio(coords_2d, aspect_ratios,
                              cmap='turbo',
                              title='t-SNE colored by Aspect Ratio'):
    """
    Scatter the 2D embedding coloured (log scale, 1 to 10) by lattice aspect ratio and save the figure.

    :param coords_2d: array of shape (n_samples, 2), e.g. the saved t-SNE coordinates.
    :param aspect_ratios: one ratio (>= 1) per row of coords_2d; values above 10 saturate the colour scale.
    :param cmap: matplotlib colormap name.
    :param title: plot title.
    """
    norm = mcolors.LogNorm(vmin=1, vmax=10)

    plt.figure(figsize=(10, 8))
    points = plt.scatter(coords_2d[:, 0], coords_2d[:, 1],
                         c=aspect_ratios,
                         cmap=cmap,
                         norm=norm,
                         s=10)
    # Without a colourbar the colours cannot be read back as ratios
    plt.colorbar(points, label='Aspect ratio (longer / shorter)')
    plt.title(title)
    plt.xlabel('t-SNE 1')
    plt.ylabel('t-SNE 2')

    out_dir = "Plots and Visualisations/Property Plots/Aspect Ratio"
    # savefig does not create missing folders
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f"{out_dir}/{MODEL_NAME}_aspect_ratio.png")
    plt.show()

# Load the saved 2D t-SNE coords
data_tsne_2d = np.load("data_tsne_2d.npy")
# A stale file (e.g. from another DATA_SPLIT or a UMAP-only run) would otherwise mis-colour points or fail obscurely
assert len(data_tsne_2d) == len(aspect_ratios), (
    f"data_tsne_2d.npy has {len(data_tsne_2d)} rows but {len(aspect_ratios)} materials were processed; "
    "re-run the clustering cell with t-SNE enabled for the current DATA_SPLIT."
)

# Plot with aspect ratio
plot_tsne_by_aspect_ratio(
    data_tsne_2d,
    aspect_ratios,
    title="t-SNE colored by Lattice Aspect Ratio"
)

In [None]:
# 2D visualisation plot: Buckling Heights
# Note: only works with t-SNE

import os
from ase.io import read
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

# One value per material in the split that was actually clustered (aligned with 'data_tsne_2d.npy' rows);
# hard-coding validation_labels broke every DATA_SPLIT other than 'validation'.
buckling_heights = np.zeros(len(labels_for_clustering))

for i, label in enumerate(labels_for_clustering):
    file_path = os.path.join(INPUT_FOLDER, f"supercell_2dm-{label}.xyz")
    atoms = read(file_path)

    # Buckling height = spread of the out-of-plane (z) coordinates over all atoms in the file
    z_coords = atoms.positions[:, 2]
    buckling_heights[i] = z_coords.max() - z_coords.min()

def plot_tsne_by_buckling_height(coords_2d, buckling_heights,
                                 cmap='viridis',
                                 title='t-SNE colored by Buckling Height'):
    """
    Scatter the 2D embedding coloured (linear scale, 0 to 10) by buckling height and save the figure.

    :param coords_2d: array of shape (n_samples, 2), e.g. the saved t-SNE coordinates.
    :param buckling_heights: one z-spread per row of coords_2d; values above 10 saturate the colour scale.
    :param cmap: matplotlib colormap name.
    :param title: plot title.
    """
    norm = mcolors.Normalize(vmin=0, vmax=10)

    plt.figure(figsize=(10, 8))
    points = plt.scatter(coords_2d[:, 0], coords_2d[:, 1],
                         c=buckling_heights,
                         cmap=cmap,
                         norm=norm,
                         s=10)
    # Without a colourbar the colours cannot be read back as heights
    plt.colorbar(points, label='Buckling height (z max - z min)')
    plt.title(title)
    plt.xlabel('t-SNE 1')
    plt.ylabel('t-SNE 2')

    out_dir = "Plots and Visualisations/Property Plots/Buckling Height"
    # savefig does not create missing folders
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f"{out_dir}/{MODEL_NAME}_buckling_height.png")
    plt.show()

# Load your saved t-SNE coordinates
data_tsne_2d = np.load("data_tsne_2d.npy")
# A stale file (e.g. from another DATA_SPLIT or a UMAP-only run) would otherwise mis-colour points or fail obscurely
assert len(data_tsne_2d) == len(buckling_heights), (
    f"data_tsne_2d.npy has {len(data_tsne_2d)} rows but {len(buckling_heights)} materials were processed; "
    "re-run the clustering cell with t-SNE enabled for the current DATA_SPLIT."
)

# Plot using the computed buckling heights
plot_tsne_by_buckling_height(data_tsne_2d, buckling_heights,
                             title="t-SNE colored by Buckling Height (Z-Coords)")

In [None]:
# 2D visualisation plot: Average coordination number
# Note: only works with t-SNE

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from ase.io import read

# 1) Load the graph list (PyTorch Geometric Data objects)
# NOTE(review): this folder is hard-coded and does not follow MODEL_NAME / INPUT_FOLDER; confirm it matches the current model.
graph_list_path = os.path.join('Graphs/sublattice_test', "graph_list_unperturbed_1.pt")
# The file contains pickled Data objects. PyTorch 2.6 changed torch.load's default to weights_only=True, which
# refuses such objects, so full unpickling is requested explicitly. Only do this for trusted files:
# unpickling can execute arbitrary code.
graph_list = torch.load(graph_list_path, weights_only=False)

# 2) Build a dict: { graph_label : average_coordination_number }
label_to_avgcoord = {}
for graph in graph_list:
    num_nodes = graph.x.shape[0]            # number of nodes (atoms)
    num_edges = graph.edge_index.shape[1]   # number of directed edges stored in edge_index
    # Edges per node, halved; presumably because each bond is stored once in each direction.
    # NOTE(review): confirm against the graph builder. If edge_index holds each neighbour relation once per
    # receiving node, the coordination number is num_edges / num_nodes without the factor of 2.
    label_to_avgcoord[graph.label] = num_edges / (num_nodes * 2.0)

# 3) One value per material in the split that was actually clustered (aligned with 'data_tsne_2d.npy' rows).
# Materials without a graph become NaN (masked in the plot) instead of aborting the cell with a KeyError.
coordination_values = np.array(
    [label_to_avgcoord.get(lbl, np.nan) for lbl in labels_for_clustering], dtype=float
)
n_missing = int(np.isnan(coordination_values).sum())
if n_missing:
    print(f"Warning: {n_missing} material(s) have no graph in {graph_list_path}; they are shown as missing.")

# 4) Plot the t-SNE embedding, colored by average coordination number
def plot_tsne_by_coordination(coords_2d, coordination_vals,
                              cmap='viridis',
                              title='t-SNE colored by Avg Coordination Number'):
    """
    Scatter the 2D embedding coloured (log scale, 1 to max) by average coordination number and save the figure.

    :param coords_2d: array of shape (n_samples, 2), e.g. the saved t-SNE coordinates.
    :param coordination_vals: one value per row of coords_2d; NaN entries are ignored when scaling.
    :param cmap: matplotlib colormap name.
    :param title: plot title.
    """
    values = np.asarray(coordination_vals, dtype=float)
    # Scale from the argument; the previous version read the global `coordination_values` by mistake.
    # nanmax ignores materials that have no graph.
    norm = mcolors.LogNorm(vmin=1, vmax=np.nanmax(values))

    plt.figure(figsize=(10, 8))
    points = plt.scatter(coords_2d[:, 0], coords_2d[:, 1],
                         c=values,
                         cmap=cmap,
                         norm=norm,
                         s=10)
    # Without a colourbar the colours cannot be read back as coordination numbers
    plt.colorbar(points, label='Average coordination number')
    plt.title(title)
    plt.xlabel('t-SNE 1')
    plt.ylabel('t-SNE 2')

    out_dir = "Plots and Visualisations/Property Plots/Coordination Number"
    # savefig does not create missing folders
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f"{out_dir}/{MODEL_NAME}_coordination_number.png")
    plt.show()

data_tsne_2d = np.load("data_tsne_2d.npy")
# A stale file (e.g. from another DATA_SPLIT or a UMAP-only run) would otherwise mis-colour points or fail obscurely
assert len(data_tsne_2d) == len(coordination_values), (
    f"data_tsne_2d.npy has {len(data_tsne_2d)} rows but {len(coordination_values)} materials were processed; "
    "re-run the clustering cell with t-SNE enabled for the current DATA_SPLIT."
)

plot_tsne_by_coordination(
    data_tsne_2d,
    coordination_values,
    title="t-SNE colored by Avg Coordination Number"
)


In [None]:
# Cluster Finder: returns the cluster number of a given material and then views the material in an interactive window

from ase.io import read
from ase.visualize import view

import os

# Material you want to find
MATERIAL_LABEL = '1_Ru'

# Row(s) of the clustered split whose label matches MATERIAL_LABEL. A dedicated name avoids overwriting
# the global `indices` permutation created by the embedding-loader cell.
material_indices = np.where(embedding_labels == MATERIAL_LABEL)[0]

if len(material_indices) == 0:
    print(f"Material with label {MATERIAL_LABEL} is not found in the dataset.")
else:
    index = material_indices[0]  # Assuming each material label is unique
    cluster = cluster_labels[index]
    if cluster == -1:
        print(f"Material with label {MATERIAL_LABEL} is considered noise (not assigned to any cluster).")
    else:
        print(f"Material with label {MATERIAL_LABEL} is in cluster {int(cluster)}.")

# Read the structure from the XYZ file. os.path.join replaces the hard-coded backslash, which formed the
# invalid escape sequence '\s' and only works as a path separator on Windows.
structure_path = os.path.join(INPUT_FOLDER, f"supercell_2dm-{MATERIAL_LABEL}.xyz")
if os.path.exists(structure_path):
    atoms = read(structure_path)
    # View the structure (opens an interactive window)
    view(atoms)
else:
    print(f"Structure file {structure_path} not found.")


In [None]:
# Defines the cluster you want to analyse and returns the labels of the materials within it

CLUSTER_NUM = 32
# Boolean mask over the clustered split, then the labels of the materials it selects
in_cluster = (cluster_labels == CLUSTER_NUM)
cluster = embedding_labels[in_cluster]
cluster

In [None]:
# Cluster Analyser: Provides the properties of and visualizes every node in the specified cluster

import os
import shutil
import numpy as np
import torch
import pandas as pd
from ase.io import read
from torch_geometric.data import Data
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from IPython.display import HTML

# Distance tolerance added to the nearest / next-nearest neighbour distances when building graph edges
delta = 0.1  # Adjust as needed

# Folder for the per-material plots. os.path.join replaces the backslash literal, whose '\T' and '\C' are
# invalid escape sequences and which only acts as a path separator on Windows.
plot_folder = os.path.join('Plots and Visualisations', 'Temp', 'Cluster Node Plots')

# Start from an empty folder so plots from a previously analysed cluster are not mixed in
os.makedirs(plot_folder, exist_ok=True)
for filename in os.listdir(plot_folder):
    file_path = os.path.join(plot_folder, filename)
    try:
        if os.path.isfile(file_path) or os.path.islink(file_path):
            os.unlink(file_path)  # Remove the file
        elif os.path.isdir(file_path):
            shutil.rmtree(file_path)  # Remove the directory
    except OSError as e:
        # Best effort: report and keep going rather than aborting the cell
        print(f"Failed to delete {file_path}. Reason: {e}")

# 1. Compute edges (nearest and next-nearest neighbors) with PBC
def compute_edge_index(atoms, delta):
    """
    Build directed graph edges to each atom's nearest and next-nearest neighbours.

    For every atom i, all atoms within `delta` of its nearest-neighbour distance d1 are linked, then all atoms
    within `delta` of the next distance d2 beyond that shell. Each neighbour j adds the edge (j, i).
    Distances come from atoms.get_all_distances(mic=True), i.e. the minimum-image convention over the
    periodic directions set on `atoms`.

    :param atoms: ASE Atoms (anything providing get_positions() and get_all_distances(mic=...)).
    :param delta: distance tolerance added to each neighbour-shell distance.
    :return: LongTensor of shape (2, num_edges) with rows (source j, target i); shape (2, 0) if there are no edges.
    """
    num_atoms = len(atoms.get_positions())

    dist_matrix = np.array(atoms.get_all_distances(mic=True), dtype=float)
    # Exclude self-distances so an atom is never its own neighbour
    np.fill_diagonal(dist_matrix, np.inf)

    edges = set()
    for i in range(num_atoms):
        order = np.argsort(dist_matrix[i])
        dists = dist_matrix[i][order]

        # Only finite distances are real neighbours; with a single atom the old code linked it to itself
        finite = np.isfinite(dists)
        order, dists = order[finite], dists[finite]
        if dists.size == 0:
            continue

        # Nearest-neighbour shell
        nn_cutoff = dists[0] + delta
        edges.update((int(j), i) for j in order[dists <= nn_cutoff])

        # Next-nearest-neighbour shell: first distance beyond the NN shell, plus delta
        beyond = dists > nn_cutoff
        if beyond.any():
            nnn_cutoff = dists[beyond][0] + delta
            edges.update((int(j), i) for j in order[beyond & (dists <= nnn_cutoff)])

    if not edges:
        return torch.empty((2, 0), dtype=torch.long)
    # Sorted so the edge order is reproducible from run to run
    return torch.tensor(sorted(edges), dtype=torch.long).t().contiguous()

# 2. Create a PyTorch Geometric Data object (no edge_attr used)
def create_graph_from_structure(atoms, delta):
    """
    Wrap the neighbour edges of `atoms` in a PyTorch Geometric Data object.

    Nodes carry no features yet (x has shape (num_atoms, 0)) and no edge_attr is attached.

    :param atoms: ASE Atoms of the structure.
    :param delta: neighbour-shell distance tolerance passed to compute_edge_index.
    :return: torch_geometric Data with `x` and `edge_index`.
    """
    return Data(
        x=torch.empty((len(atoms), 0), dtype=torch.float),
        edge_index=compute_edge_index(atoms, delta),
    )

# 3. Plotting function
def plot_graph(atoms, graph, margin=0.1, tolerance=0.1, y_size=6, filename=None):
    """
    Draw the in-plane (x, y) projection of a structure with its graph edges, then save or show it.

    Atoms are coloured by z-layer (z rounded to multiples of `tolerance`). Edges are drawn as grey lines
    using the minimum-image displacement, so bonds across the periodic boundary stay short. The supercell
    outline, shifted by half the c vector, is drawn dashed. The figure height is fixed at `y_size` inches
    and the width follows the padded cell's aspect ratio.

    :param atoms: ASE Atoms of the structure.
    :param graph: object with an `edge_index` of shape (2, num_edges).
    :param margin: fractional padding around the cell outline.
    :param tolerance: z-binning width used to group atoms into layers.
    :param y_size: figure height in inches.
    :param filename: if given, save to this path (dpi=300) and close the figure; otherwise plt.show().
    """
    positions = atoms.get_positions()
    # Computed once; previously recomputed twice for every edge
    scaled_positions = atoms.get_scaled_positions()
    cell_matrix = np.asarray(atoms.get_cell())

    x, y, z = positions[:, 0], positions[:, 1], positions[:, 2]

    # Group z positions into layers using the tolerance
    z_grouped = np.round(z / tolerance) * tolerance
    z_unique, layer_index = np.unique(z_grouped, return_inverse=True)
    n_layers = len(z_unique)

    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9; pyplot.get_cmap is the supported call
    cmap = plt.get_cmap('viridis', n_layers)

    fig, ax = plt.subplots()

    # Supercell outline, shifted by 0.5*c
    c_shift = 0.5 * cell_matrix[2]
    corner_positions = np.array([
        [0, 0, 0],
        cell_matrix[0],
        cell_matrix[0] + cell_matrix[1],
        cell_matrix[1],
        [0, 0, 0],
    ]) + c_shift

    # Plot boundaries from the outline, padded by `margin`
    x_min, x_max = corner_positions[:, 0].min(), corner_positions[:, 0].max()
    y_min, y_max = corner_positions[:, 1].min(), corner_positions[:, 1].max()
    x_pad = (x_max - x_min) * margin
    y_pad = (y_max - y_min) * margin
    x_min, x_max = x_min - x_pad, x_max + x_pad
    y_min, y_max = y_min - y_pad, y_max + y_pad

    # Keep the y-size fixed; width follows the aspect ratio
    fig.set_size_inches(y_size * (x_max - x_min) / (y_max - y_min), y_size)

    ax.plot(corner_positions[:, 0], corner_positions[:, 1], 'k--', linewidth=1)
    ax.scatter(x, y, c=layer_index, s=50, cmap=cmap, zorder=2)

    # Edge segments from atom i (row 0) towards the minimum image of atom j (row 1), vectorised over all edges.
    # The displacement variable is named so it no longer shadows the notebook-level `delta` tolerance.
    edge_index = np.asarray(graph.edge_index)
    src, dst = edge_index[0], edge_index[1]
    frac_displacement = scaled_positions[dst] - scaled_positions[src]
    frac_displacement -= np.round(frac_displacement)
    start = positions[src]
    end = start + frac_displacement @ cell_matrix
    segments = np.stack([start[:, :2], end[:, :2]], axis=1)
    ax.add_collection(LineCollection(segments, colors='gray', linewidths=1, zorder=1))

    # Final plot settings
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_aspect('equal')
    ax.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

    if filename:
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()

# 4. Process each material in the cluster
def process_nodes_in_cluster(CLUSTER_NUM, cluster_labels, embedding_labels, input_folder):
    """
    Build, plot and summarise every material assigned to one cluster.

    For each member, the supercell is read from `input_folder`, turned into a neighbour graph (using the
    notebook-level `delta`), drawn to a PNG inside `plot_folder`, and summarised as one table row.
    Per-material failures are printed and skipped.

    :param CLUSTER_NUM: int, the cluster id to process.
    :param cluster_labels: np.array of integers (one cluster label per material).
    :param embedding_labels: np.array of strings or ints (identifiers for materials), aligned with cluster_labels.
    :param input_folder: Path to folder containing supercells, e.g. 'Supercells/supercells_sublattices_3x3'.
    :return: list of row dicts for the processed materials, or None if the cluster has no members.
    """
    member_rows = np.where(cluster_labels == CLUSTER_NUM)[0]
    if len(member_rows) == 0:
        print(f"No points found in cluster {CLUSTER_NUM}.")
        return None

    rows = []
    for row in member_rows:
        # Labels may be strings like '1000_Eu' or ints; always work with the string form
        material = str(embedding_labels[row])
        file_path = os.path.join(input_folder, f"supercell_2dm-{material}.xyz")

        try:
            atoms = read(file_path)
            atoms.pbc = [True, True, False]  # periodic in-plane only (2D system)

            graph = create_graph_from_structure(atoms, delta)

            plot_filepath = os.path.join(plot_folder, f"graph_cluster_{CLUSTER_NUM}_material_{material}.png")
            plot_graph(atoms, graph, margin=0.2, tolerance=0.1, y_size=6, filename=plot_filepath)
            print(f"Graph for material label {material} from cluster {CLUSTER_NUM} saved as {plot_filepath}.")

            # cellpar() -> (a, b, c, alpha, beta, gamma)
            cellpar = atoms.cell.cellpar()
            a_val, b_val, c_val = cellpar[:3]
            # a and b are divided by SUPERCELL_SIZE to give unit-cell lengths; c is reported as-is
            lengths = [
                f"{a_val / SUPERCELL_SIZE:.3f}",
                f"{b_val / SUPERCELL_SIZE:.3f}",
                f"{c_val:.3f}",
            ]
            angles = [f"{angle:.3f}" for angle in cellpar[3:]]

            atoms_per_cell = len(atoms) // (SUPERCELL_SIZE ** 2)
            # NOTE(review): assumes the first `atoms_per_cell` atoms form one complete unit cell
            # (true for ASE's atoms.repeat ordering); confirm for how these supercells were built.
            formula = atoms[:atoms_per_cell].get_chemical_formula()

            rows.append({
                "Cluster": CLUSTER_NUM,
                "Material": material,
                "Chemical Formula": formula,
                "No. atoms per cell": atoms_per_cell,
                "Unitcell Lengths": lengths,
                "Cell Angles": angles,
                "Plot": plot_filepath
            })

        except FileNotFoundError:
            print(f"Structure file for material label {material} not found at {file_path}.")
        except Exception as e:
            # Best effort: one bad structure should not stop the rest of the cluster from being processed
            print(f"Error processing material {material}: {e}")

    return rows


# 5. Main script usage: collect one row per material of CLUSTER_NUM and render them as an HTML table
data_list = []
cluster_data = process_nodes_in_cluster(
    CLUSTER_NUM,
    cluster_labels,
    embedding_labels,
    input_folder=INPUT_FOLDER,
)

if cluster_data is None:
    print(f"No data collected for cluster {CLUSTER_NUM}.")
else:
    data_list.extend(cluster_data)

# Column order of the displayed table
column_order = [
    "Cluster",
    "Material",
    "Chemical Formula",
    "No. atoms per cell",
    "Unitcell Lengths",
    "Cell Angles",
    "Plot",
]

df = pd.DataFrame(data_list, columns=column_order)


def path_to_image_html(path):
    """Render a saved plot path as an <img> tag for the HTML table."""
    return f'<img src="{path}" height="200">'


# escape=False so the <img> tags render as images rather than as literal text
HTML(df.to_html(index=False, escape=False, formatters={"Plot": path_to_image_html}))


In [None]:
# Core Node Visualiser: visualises the core node of every cluster

import os
import shutil
import numpy as np
import torch
import pandas as pd
from ase.io import read
from torch_geometric.data import Data
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from IPython.display import HTML

# Folder for the core-node plots. os.path.join replaces the backslash literal, whose '\T' and '\C' are
# invalid escape sequences and which only acts as a path separator on Windows.
plot_folder = os.path.join('Plots and Visualisations', 'Temp', 'Core Node Plots')
# Distance tolerance added to the nearest / next-nearest neighbour distances when building graph edges
delta = 0.1

# Start from an empty folder so plots from a previous run are not mixed in
os.makedirs(plot_folder, exist_ok=True)
for filename in os.listdir(plot_folder):
    file_path = os.path.join(plot_folder, filename)
    try:
        if os.path.isfile(file_path) or os.path.islink(file_path):
            os.unlink(file_path)
        elif os.path.isdir(file_path):
            shutil.rmtree(file_path)
    except OSError as e:
        # Best effort: report and keep going rather than aborting the cell
        print(f"Failed to delete {file_path}. Reason: {e}")

# Helper functions to build graph and plot
def compute_edge_index(atoms, delta):
    """
    Build directed edges from every atom to its nearest and next-nearest neighbours.

    For each atom i, the minimum-image distances (``get_all_distances(mic=True)``)
    to all *other* atoms are sorted.
    - First shell: atoms within ``d1 + delta`` of i, where d1 is the closest distance.
    - Second shell: of the atoms left over, those within ``d2 + delta``, where d2 is
      the closest remaining distance.
    An edge (j, i) is added for every neighbour j in either shell. Atom i is never
    its own neighbour, so no self-loops are produced.

    Parameters
    ----------
    atoms : ase.Atoms
        Any object providing ``get_positions()`` and ``get_all_distances(mic=True)``.
    delta : float
        Tolerance added to each shell distance.

    Returns
    -------
    torch.Tensor
        Long tensor of shape (2, E). Row 0 holds neighbour indices j and row 1
        holds centre indices i, with edges sorted for a deterministic order.
        Shape is (2, 0) when there are no edges, e.g. a single atom.
    """
    dist_matrix = atoms.get_all_distances(mic=True)
    num_atoms = len(atoms.get_positions())
    all_indices = np.arange(num_atoms)
    edges = set()

    for i in range(num_atoms):
        # Exclude atom i explicitly instead of setting d_ii = inf. With inf on the
        # diagonal, an atom whose other neighbours all fall in the first shell gets
        # d2 = inf, and `inf <= inf + delta` would add the self-loop (i, i).
        others = all_indices[all_indices != i]
        if others.size == 0:
            continue  # a lone atom has no neighbours

        distances = dist_matrix[i, others]
        order = np.argsort(distances)
        sorted_indices = others[order]
        sorted_distances = distances[order]

        # Nearest-neighbour shell
        nn_cutoff = sorted_distances[0] + delta
        for j in sorted_indices[sorted_distances <= nn_cutoff]:
            edges.add((int(j), i))

        # Next-nearest-neighbour shell, taken from whatever lies beyond the first shell
        beyond_first = sorted_distances > nn_cutoff
        remaining_indices = sorted_indices[beyond_first]
        remaining_distances = sorted_distances[beyond_first]
        if remaining_distances.size > 0:
            nnn_cutoff = remaining_distances[0] + delta
            for j in remaining_indices[remaining_distances <= nnn_cutoff]:
                edges.add((int(j), i))

    if not edges:
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(sorted(edges), dtype=torch.long).t().contiguous()


def create_graph_from_structure(atoms, delta):
    """
    Wrap an atomic structure as a PyTorch Geometric graph with no node features.

    Edges come from ``compute_edge_index(atoms, delta)``. ``x`` is an empty
    float tensor of shape (len(atoms), 0), so the graph only records connectivity.
    """
    edges = compute_edge_index(atoms, delta)
    featureless_nodes = torch.empty((len(atoms), 0), dtype=torch.float)
    return Data(x=featureless_nodes, edge_index=edges)


def plot_graph(atoms, graph, margin=0.1, tolerance=0.1, y_size=6, filename=None):
    """
    Plots the graph of an atomic structure with consistent y-direction size.

    Draws a top-down (x-y) view:
    - the cell outline (a, b vectors, shifted by half of c) as a dashed line;
    - atoms coloured by layer, where z is snapped to multiples of ``tolerance``
      and each distinct snapped value gets its own viridis colour;
    - each edge drawn towards the minimum-image copy of the neighbour along the
      periodic directions (``atoms.pbc``).

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to draw.
    graph : torch_geometric.data.Data
        Only ``graph.edge_index`` (shape (2, E)) is used.
    margin : float
        Fraction of the outline's x/y extent added as padding on each side.
    tolerance : float
        Bin width used to group atoms into layers by z.
    y_size : float
        Figure height in inches. The width follows the padded outline's aspect
        ratio, clamped to at least 1/3 of the height.
    filename : str or None
        If given, save at 300 dpi and close the figure; otherwise show it.
    """
    positions = atoms.get_positions()
    cell = np.array(atoms.get_cell())

    # Group atoms into layers by rounded z so each layer gets one colour.
    z_binned = np.round(positions[:, 2] / tolerance) * tolerance
    z_levels, layer_ids = np.unique(z_binned, return_inverse=True)
    # plt.get_cmap(name, lut) replaces matplotlib.cm.get_cmap, removed in Matplotlib 3.9.
    cmap = plt.get_cmap('viridis', len(z_levels))

    fig, ax = plt.subplots()

    # In-plane cell outline, shifted by half the c vector.
    c_shift = 0.5 * cell[2]
    corner_positions = np.array([
        np.zeros(3),
        cell[0],
        cell[0] + cell[1],
        cell[1],
        np.zeros(3),
    ]) + c_shift

    x_min, x_max = corner_positions[:, 0].min(), corner_positions[:, 0].max()
    y_min, y_max = corner_positions[:, 1].min(), corner_positions[:, 1].max()

    x_margin = (x_max - x_min) * margin
    y_margin = (y_max - y_min) * margin
    x_min -= x_margin
    x_max += x_margin
    y_min -= y_margin
    y_max += y_margin

    # Fixed height; width follows the outline, but never narrower than 1/3 of the height.
    aspect_ratio = (x_max - x_min) / (y_max - y_min)
    if aspect_ratio < 1 / 3:
        aspect_ratio = 1 / 3
    fig.set_size_inches(y_size * aspect_ratio, y_size)

    ax.plot(corner_positions[:, 0], corner_positions[:, 1], 'k--', linewidth=1)
    ax.scatter(positions[:, 0], positions[:, 1], c=layer_ids, s=50, cmap=cmap, zorder=2)

    # Vectorised edge construction: scaled positions are computed once, not twice per edge.
    edge_index = graph.edge_index.detach().cpu().numpy()
    src, dst = edge_index[0], edge_index[1]
    scaled = atoms.get_scaled_positions()
    frac_delta = scaled[dst] - scaled[src]
    # Minimum-image wrap only along periodic axes, matching get_all_distances(mic=True).
    periodic = np.asarray(atoms.pbc, dtype=bool)
    frac_delta[:, periodic] -= np.round(frac_delta[:, periodic])
    start = positions[src]
    end = start + frac_delta @ cell
    segments = np.stack([start[:, :2], end[:, :2]], axis=1)

    ax.add_collection(LineCollection(segments, colors='gray', linewidths=1, zorder=1))

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_aspect('equal')
    ax.axis('off')
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

    if filename:
        fig.savefig(filename, dpi=300, bbox_inches='tight')
        # Close this specific figure so batch plotting does not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()

# Core-node processing
def process_core_node_in_cluster(cluster_num, cluster_labels, embedding_labels, hdb, input_folder):
    """
    Plot the most representative ("core") member of a single cluster.

    The core member is the point labelled ``cluster_num`` with the largest value
    in ``hdb.probabilities_``; ties go to the first such point. Its structure is
    read from ``<input_folder>/supercell_2dm-<label>.xyz``, treated as periodic
    in x and y only, turned into a graph and plotted into ``plot_folder``.

    Returns
    -------
    dict or None
        {"Cluster", "Chemical Formula", "Plot"} on success. None if the cluster
        has no members, the structure file is missing, or reading/plotting
        raises (the error is printed).
    """
    members = np.where(cluster_labels == cluster_num)[0]
    if len(members) == 0:
        return None

    best_member = members[np.argmax(hdb.probabilities_[members])]
    label = str(embedding_labels[best_member])

    file_path = os.path.join(input_folder, f"supercell_2dm-{label}.xyz")
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return None

    try:
        atoms = read(file_path)
        num_atoms = len(atoms)
        atoms.pbc = [True, True, False]  # 2D material: periodic in-plane only

        graph = create_graph_from_structure(atoms, delta)
        plot_filepath = os.path.join(plot_folder, f"cluster_{cluster_num}_core_{label}.png")
        plot_graph(atoms, graph, margin=0.1, tolerance=0.1, y_size=6, filename=plot_filepath)

        # The file holds a SUPERCELL_SIZE x SUPERCELL_SIZE supercell. This assumes its
        # first num_atoms // SUPERCELL_SIZE**2 atoms make up one unit cell.
        atoms_per_cell = num_atoms // (SUPERCELL_SIZE ** 2)
        formula = atoms[:atoms_per_cell].get_chemical_formula()

        return {"Cluster": cluster_num, "Chemical Formula": formula, "Plot": plot_filepath}
    except Exception as e:
        print(f"Error processing {file_path}: {e}")
        return None


def process_all_core_nodes(cluster_labels, embedding_labels, hdb, input_folder, skip_noise=True):
    """
    Run ``process_core_node_in_cluster`` once per distinct cluster label.

    Labels are visited in sorted order (``np.unique``). When ``skip_noise`` is
    True, the HDBSCAN noise label -1 is ignored. Clusters whose processing
    returns None are omitted, so the result holds at most one dict per cluster.
    """
    wanted_clusters = [
        label for label in np.unique(cluster_labels)
        if not (skip_noise and label == -1)
    ]
    # Lazily evaluated, so clusters are still processed one at a time, in order.
    rows = (
        process_core_node_in_cluster(label, cluster_labels, embedding_labels, hdb, input_folder)
        for label in wanted_clusters
    )
    return [row for row in rows if row is not None]


def generate_aligned_html_grid(data_list, images_per_row=3):
    """
    Generates an HTML grid with aligned tops and captions below each image.

    Each entry of ``data_list`` is a dict that must contain 'Plot' (image path
    or URL) and may contain 'Cluster' and 'Chemical Formula'. Missing captions
    fall back to '?' and 'Unknown'. Each image sits in a square, top-aligned
    box that takes ``100 // images_per_row`` percent of the row, captioned
    "Cluster <n>: <formula>".

    Parameters
    ----------
    data_list : list of dict
        One entry per image.
    images_per_row : int, default 3
        Target number of images per row; must be >= 1.

    Returns
    -------
    str
        HTML markup, e.g. for ``IPython.display.HTML``.

    Raises
    ------
    ValueError
        If ``images_per_row`` is less than 1.
    KeyError
        If an entry has no 'Plot' key.
    """
    from html import escape

    if images_per_row < 1:
        raise ValueError(f"images_per_row must be >= 1, got {images_per_row}")

    # Loop-invariant: every card gets the same share of the row width.
    width_percent = 100 // images_per_row

    parts = ['<div style="display: flex; flex-wrap: wrap; align-items: flex-start;">']
    for item in data_list:
        # Escape interpolated values so paths or formulas containing &, <, > or
        # quotes cannot break the attribute or element they are placed in.
        src = escape(str(item['Plot']), quote=True)
        cluster = escape(str(item.get('Cluster', '?')))
        formula = escape(str(item.get('Chemical Formula', 'Unknown')))
        parts.append(f'''
        <div style="flex: 1 0 {width_percent}%; box-sizing: border-box; padding: 10px; text-align: center;">
            <div style="position: relative; width: 100%; padding-top: 100%; overflow: hidden;">
                <img src="{src}" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; object-fit: contain;">
            </div>
            <div style="margin-top: 5px; font-size: 14px;">
                Cluster {cluster}: {formula}
            </div>
        </div>
        ''')
    parts.append('</div>')
    return ''.join(parts)


# One core-node row per non-noise cluster, shown as a 5-column grid of plots.
data_list = process_all_core_nodes(
    cluster_labels, embedding_labels, hdb, INPUT_FOLDER, skip_noise=True
)
html_code = generate_aligned_html_grid(data_list, images_per_row=5)
display(HTML(html_code))
