In [None]:
# Global variables: run-wide names and hyperparameters read by the later cells.

MODEL_NAME = 'flatband_Hyperedge_test'  # Base name used in the saved model, plot, embedding and label file names
SUPERCELL_SIZE = 3 # Size of supercell in (NxN) unit cells
SUPERCELL_SIZE_1 = 3  # Supercell size passed for the first (3x3) folder; assigned again before the folders are read
SUPERCELL_SIZE_2 = 4  # Supercell size passed for the second (4x4) folder; assigned again before the folders are read
#CUT_OFF_DISTANCE = 5  # Distance in angstroms below which nodes are connected with edges
#MASKING_PERCENTAGE = 0.1 # Percentage of nodes features and edge attributes that are masked
BATCH_SIZE = 64 # Number of graphs (materials) run through the network at a time
NUM_EPOCHS = 1000 # How many times the trainer runs through the training data
LEARNING_RATE = 1e-3 # Learning rate handed to the Adam optimizer
EMBEDDING_DIMENSION = 128 # How many dimensions the final vector (structural fingerprint) has
HIDDEN_DIMENSION = 128 # Width of the hidden node representation inside the network
LAMBDA_PARAM = 5e-3 # Lambda parameter (off-diagonal weight) for the Barlow Twins loss

In [None]:
#Graph generator: hyperedge, variable supercell size, next nearest neighbor and PBC, xyz node features, xyz edge attributes

import os
import re
import torch
import numpy as np
from ase.io import read
from torch_geometric.data import Data

# Tolerance added to the nearest and next-nearest neighbour distances when
# compute_edge_index decides which atoms belong to each neighbour shell.
delta = 0.1 

# Half-width of the uniform random displacement that perturb_graphs adds to
# every atomic coordinate.
perturbation_size = 0.05 

# Function to compute edge_index based on positions and delta, considering PBCs
def compute_edge_index(atoms, delta):
    """Build an undirected nearest / next-nearest-neighbour edge list.

    For every atom the closest other atom defines the nearest-neighbour
    distance ``d1``; every atom within ``d1 + delta`` is linked to it.  The
    closest atom beyond that shell defines ``d2`` and every atom within
    ``d2 + delta`` is linked as well.  Distances are taken from
    ``atoms.get_all_distances(mic=True)``.

    Args:
        atoms: Structure object providing ``get_positions()`` and
            ``get_all_distances(mic=True)``.
        delta: Tolerance added to ``d1`` and ``d2`` when selecting each shell.

    Returns:
        ``torch.long`` tensor of shape ``[2, num_edges]``.  Each connected
        pair appears exactly once per direction, in lexicographic order, and
        there are no self-loops.  The shape is ``[2, 0]`` when no atom has a
        neighbour.
    """
    num_atoms = len(atoms.get_positions())

    # Work on a private copy so the caller's distance matrix is never modified.
    dist_matrix = np.array(atoms.get_all_distances(mic=True), dtype=float)
    if num_atoms:
        # Self-distances must never win the "closest atom" search.
        np.fill_diagonal(dist_matrix, np.inf)

    directed_pairs = set()
    for i in range(num_atoms):
        row = dist_matrix[i]

        # Only finite entries are real neighbours; this drops the atom itself,
        # which previously slipped in as a self-loop whenever no second shell
        # existed (inf <= inf).
        others = np.flatnonzero(np.isfinite(row))
        if others.size == 0:
            continue
        other_dist = row[others]

        # First shell: everything within d1 + delta.
        cutoff = other_dist.min() + delta

        # Second shell: everything within d2 + delta, where d2 is the closest
        # distance beyond the first shell.  Because d2 exceeds the first
        # cutoff, this wider cutoff also covers the first shell.
        beyond_first_shell = other_dist[other_dist > cutoff]
        if beyond_first_shell.size:
            cutoff = beyond_first_shell.min() + delta

        for j in others[other_dist <= cutoff].tolist():
            # Store both directions in the set so the graph is undirected
            # without duplicating edges that both endpoints selected.
            directed_pairs.add((i, j))
            directed_pairs.add((j, i))

    if not directed_pairs:
        # Handle graphs with no edges
        return torch.empty((2, 0), dtype=torch.long)

    # Sorting gives a reproducible edge order independent of set iteration.
    return torch.tensor(sorted(directed_pairs), dtype=torch.long).t().contiguous()

# Function to compute edge attributes as displacement vectors considering PBCs
def compute_edge_attr(atoms, edge_index):
    """Return one displacement vector per edge using the minimum-image convention.

    For an edge whose first-row index is ``r`` and second-row index is ``c``
    the attribute is ``position[r] - position[c]``, evaluated in fractional
    coordinates, folded to the nearest periodic image and converted back to
    Cartesian coordinates with the cell matrix.

    Args:
        atoms: Structure object providing ``get_scaled_positions()`` and
            ``get_cell()``.
        edge_index: Integer tensor of shape ``[2, num_edges]``.

    Returns:
        ``torch.float`` tensor with one row of Cartesian components per edge.
    """
    row, col = edge_index
    # NumPy indexing needs host arrays; .cpu() is a no-op for CPU tensors.
    first = row.detach().cpu().numpy()
    second = col.detach().cpu().numpy()

    scaled_positions = atoms.get_scaled_positions()
    cell = atoms.get_cell()

    # Compute displacement vectors considering PBCs
    delta_scaled = scaled_positions[first] - scaled_positions[second]
    # Minimum image convention: rounding is applied to every fractional component.
    delta_scaled -= np.round(delta_scaled)
    displacement_vectors = delta_scaled @ cell  # Convert back to Cartesian coordinates

    return torch.tensor(displacement_vectors, dtype=torch.float)

# Function to compute node features (relative positions within unit cells)
def compute_node_features(atoms, SUPERCELL_SIZE):
    """Return each atom's position relative to the first atom of its unit cell.

    Atoms are assumed to be stored unit cell by unit cell, i.e. the first
    ``num_atoms / SUPERCELL_SIZE**2`` atoms form the first cell, the next
    group the second cell, and so on.  Within every group the first atom is
    the reference and therefore gets an all-zero feature row.

    Args:
        atoms: Structure object supporting ``len()`` and ``get_positions()``.
        SUPERCELL_SIZE: Number of unit cells along each of the two in-plane
            directions (the supercell holds ``SUPERCELL_SIZE**2`` cells).

    Returns:
        ``torch.float`` tensor with the same shape as the position array.

    Raises:
        ValueError: If ``SUPERCELL_SIZE`` is not positive, or if the atom
            count is zero or not a multiple of ``SUPERCELL_SIZE**2``.
    """
    num_atoms = len(atoms)
    N = int(SUPERCELL_SIZE)
    if N < 1:
        # Previously surfaced as an unrelated ZeroDivisionError.
        raise ValueError("SUPERCELL_SIZE must be a positive integer.")
    total_unit_cells = N * N

    # Compute the number of atoms per unit cell
    atoms_per_unit_cell = num_atoms // total_unit_cells

    # Check for consistency (an empty group size also means the split is impossible)
    if atoms_per_unit_cell == 0 or atoms_per_unit_cell * total_unit_cells != num_atoms:
        raise ValueError("Number of atoms per unit cell is not consistent with total atoms and supercell dimensions.")

    # View the positions as [cell, atom-in-cell, component] and subtract the
    # first atom of every cell from the whole cell in one broadcasted step.
    positions = np.asarray(atoms.get_positions())
    grouped = positions.reshape(total_unit_cells, atoms_per_unit_cell, positions.shape[-1])
    relative_positions = (grouped - grouped[:, :1, :]).reshape(positions.shape)

    return torch.tensor(relative_positions, dtype=torch.float)

# Function to compute hyperedges based on existing edges
def compute_hyperedges(edge_index, num_nodes):
    """Enumerate three-node hyperedges ``(i, j, k)`` centred on node ``j``.

    For every node ``j`` the distinct targets of its outgoing edges are
    collected, and each unordered pair ``{i, k}`` of those targets yields one
    hyperedge with ``j`` in the middle row.

    Args:
        edge_index: Integer tensor of shape ``[2, num_edges]``; row 0 holds
            the node whose neighbour list is extended, row 1 the neighbour.
        num_nodes: Number of nodes; every index in ``edge_index`` must be
            smaller than this.

    Returns:
        ``torch.long`` tensor of shape ``[3, num_hyperedges]`` (``[3, 0]``
        when no node has two distinct neighbours).  Hyperedges are ordered by
        centre node, then by ascending ``i`` and ``k`` with ``i < k``.
    """
    # One bulk conversion instead of a .item() call per index.
    neighbor_sets = [set() for _ in range(num_nodes)]
    for src, tgt in edge_index.t().tolist():
        neighbor_sets[src].add(tgt)

    hyperedges = []
    for j, neighbor_set in enumerate(neighbor_sets):
        # Sorting makes the output independent of set iteration order.
        neighbors = sorted(neighbor_set)
        for pos, i in enumerate(neighbors):
            for k in neighbors[pos + 1:]:
                hyperedges.append([i, j, k])

    if not hyperedges:
        return torch.empty((3, 0), dtype=torch.long)

    return torch.tensor(hyperedges, dtype=torch.long).t().contiguous()  # Shape [3, num_hyperedges]

# Function to compute hyperedge attributes (angles between edges)
def compute_hyperedge_attr(atoms, hyperedge_index):
    """Return the angle at the centre node of every hyperedge ``(i, j, k)``.

    The two arms ``j -> i`` and ``j -> k`` are evaluated in fractional
    coordinates, folded to the nearest periodic image and converted to
    Cartesian vectors with the cell matrix; the attribute is the angle
    between them in radians.

    Args:
        atoms: Structure object providing ``get_scaled_positions()`` and
            ``get_cell()``.
        hyperedge_index: Integer tensor of shape ``[3, num_hyperedges]`` with
            rows ``i``, ``j`` (centre) and ``k``.

    Returns:
        ``torch.float`` tensor of shape ``[num_hyperedges, 1]``.  The second
        dimension is kept even when there are no hyperedges, so callers can
        always read ``size(1)``.
    """
    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()

    # Host-side integer arrays for NumPy fancy indexing.
    i, j, k = hyperedge_index.detach().cpu().numpy()

    # Displacement vectors considering PBCs, for all hyperedges at once
    delta_ji_scaled = scaled_positions[i] - scaled_positions[j]
    delta_ji_scaled -= np.round(delta_ji_scaled)
    d_ji = delta_ji_scaled @ cell

    delta_jk_scaled = scaled_positions[k] - scaled_positions[j]
    delta_jk_scaled -= np.round(delta_jk_scaled)
    d_jk = delta_jk_scaled @ cell

    # Row-wise dot products and norms; the small constant keeps the division
    # finite when an arm has zero length.
    dots = np.sum(d_ji * d_jk, axis=1)
    norms = np.linalg.norm(d_ji, axis=1) * np.linalg.norm(d_jk, axis=1)
    cos_theta = dots / (norms + 1e-8)
    # Clip guards arccos against round-off slightly outside [-1, 1].
    angles = np.arccos(np.clip(cos_theta, -1.0, 1.0))

    return torch.tensor(angles[:, None], dtype=torch.float)

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta, SUPERCELL_SIZE):
    """Assemble a PyTorch Geometric ``Data`` graph for one structure.

    The pieces are produced by the helper functions of this file:
    connectivity, per-edge attributes, per-node features, three-node
    hyperedges and their attributes.

    Args:
        atoms: Structure object accepted by the helper functions.
        delta: Neighbour-shell tolerance forwarded to ``compute_edge_index``.
        SUPERCELL_SIZE: Supercell size forwarded to ``compute_node_features``
            and stored on the graph as ``supercell_size``.

    Returns:
        ``Data`` object with ``x``, ``edge_index``, ``edge_attr``,
        ``hyperedge_index``, ``hyperedge_attr`` and ``supercell_size``.
    """
    # Pairwise connectivity and what is attached to each edge
    connectivity = compute_edge_index(atoms, delta)
    edge_features = compute_edge_attr(atoms, connectivity)

    # Per-node features
    features = compute_node_features(atoms, SUPERCELL_SIZE)

    # Three-node hyperedges derived from the connectivity
    triplets = compute_hyperedges(connectivity, len(atoms))
    triplet_features = compute_hyperedge_attr(atoms, triplets)

    graph = Data(
        x=features,
        edge_index=connectivity,
        edge_attr=edge_features,
        hyperedge_index=triplets,
        hyperedge_attr=triplet_features,
    )

    # Kept on the graph so later steps can recompute node features.
    graph.supercell_size = SUPERCELL_SIZE

    return graph

# Function to apply perturbations to a graph list
def perturb_graphs(graph_list, atoms_list, perturbation_size, delta):
    """Create position-perturbed copies of graphs while keeping their connectivity.

    Each structure's coordinates receive independent uniform noise in
    ``[-perturbation_size, perturbation_size]``; the edge attributes, node
    features and hyperedge attributes are then recomputed on the original
    ``edge_index`` / ``hyperedge_index``.

    Args:
        graph_list: Graphs to copy; they are not modified.
        atoms_list: Structures matching ``graph_list`` element by element;
            they are not modified.
        perturbation_size: Half-width of the uniform noise; no noise is drawn
            when it is not positive.
        delta: Accepted for signature compatibility; not used here.

    Returns:
        List of perturbed graph clones in the same order as ``graph_list``.
    """
    results = []
    for source_graph, source_atoms in zip(graph_list, atoms_list):
        # Work on copies so the caller's graph and structure stay untouched.
        noisy_graph = source_graph.clone()
        coords = source_atoms.get_positions().copy()

        if perturbation_size > 0:
            noise = np.random.uniform(-perturbation_size, perturbation_size, coords.shape)
            coords = coords + noise

        noisy_atoms = source_atoms.copy()
        noisy_atoms.set_positions(coords)
        noisy_atoms.pbc = [True, True, False]  # Ensure PBCs are enabled

        # Same topology as the source graph, new geometry-derived attributes.
        noisy_graph.edge_attr = compute_edge_attr(noisy_atoms, source_graph.edge_index)
        noisy_graph.x = compute_node_features(noisy_atoms, source_graph.supercell_size)
        noisy_graph.hyperedge_attr = compute_hyperedge_attr(noisy_atoms, source_graph.hyperedge_index)

        results.append(noisy_graph)

    return results

# Function to read graphs and atoms from a folder
def read_graphs_from_folder(folder_path, delta, SUPERCELL_SIZE):
    """Load every ``.xyz`` file of a folder and turn it into a graph.

    Files are processed in sorted name order.  The last run of digits in the
    file name becomes the graph's ``label`` (``None`` when the name contains
    no digits).

    Args:
        folder_path: Directory to scan for ``.xyz`` files.
        delta: Neighbour-shell tolerance forwarded to graph construction.
        SUPERCELL_SIZE: Supercell size forwarded to graph construction.

    Returns:
        Tuple ``(graph_list, atoms_list)`` with matching order.
    """
    graphs, structures = [], []

    for filename in sorted(name for name in os.listdir(folder_path) if name.endswith('.xyz')):
        atoms = read(os.path.join(folder_path, filename))
        atoms.pbc = [True, True, False]  # Ensure PBCs are enabled

        # Label = last number found in the file name, if any
        digit_runs = re.findall(r'\d+', filename)
        label = int(digit_runs[-1]) if digit_runs else None

        graph = create_graph_from_structure(atoms, delta, SUPERCELL_SIZE)
        graph.label = label

        graphs.append(graph)
        structures.append(atoms)

        print(f"Graph for {filename} created and added to the list with label {label}.")

    return graphs, structures

# Paths to the supercell files folders (relative to the working directory)
supercell_folder_1 = 'supercells_flatband_rotated_shifted_aligned_3x3'
supercell_folder_2 = 'supercells_flatband_rotated_shifted_aligned_4x4'

# Supercell sizes passed alongside each folder; these repeat the values set in
# the global-variables cell.
SUPERCELL_SIZE_1 = 3  # Adjust as needed
SUPERCELL_SIZE_2 = 4  # Adjust as needed

# Read graphs and atoms from the 3x3 folder
graph_list_unperturbed_1, atoms_list_unperturbed_1 = read_graphs_from_folder(
    supercell_folder_1, delta, SUPERCELL_SIZE_1
)
print("Unperturbed graph list for 3x3 supercells created with periodic boundary conditions accounted for.")

# Read graphs and atoms from the 4x4 folder
graph_list_unperturbed_2, atoms_list_unperturbed_2 = read_graphs_from_folder(
    supercell_folder_2, delta, SUPERCELL_SIZE_2
)
print("Unperturbed graph list for 4x4 supercells created with periodic boundary conditions accounted for.")

# Perturb the unperturbed 3x3 graphs to create the first perturbed graph list
graph_list_set_1 = perturb_graphs(
    graph_list_unperturbed_1, atoms_list_unperturbed_1, perturbation_size, delta
)

# Perturb the unperturbed 4x4 graphs to create the second perturbed graph list
# NOTE(review): later cells pair set 1 and set 2 by list index, which presumably
# relies on both folders sorting their files in matching order -- confirm.
graph_list_set_2 = perturb_graphs(
    graph_list_unperturbed_2, atoms_list_unperturbed_2, perturbation_size, delta
)

print("Two perturbed graph lists created from 3x3 and 4x4 supercells respectively.")


In [None]:
#Encoder: Hyperedge

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.utils import add_self_loops
from torch_geometric.nn.norm import LayerNorm

class HyperedgeGNNConv(MessagePassing):
    """Message-passing layer that combines pairwise-edge and hyperedge messages.

    Node features are first projected to ``out_channels``.  Two propagation
    passes then run on the projected features: one over the ordinary edges
    and one over edges derived from the three-node hyperedges.  The layer
    output is ``ReLU(projected + edge_messages + hyperedge_messages)``.

    Args:
        in_channels: Width of the incoming node features.
        out_channels: Width of the projected features, messages and output.
        edge_dim: Width of the edge attributes.
        hyperedge_dim: Width of the hyperedge attributes.
        aggr: Aggregation scheme handed to ``MessagePassing``.
    """

    def __init__(self, in_channels, out_channels, edge_dim, hyperedge_dim, aggr='add'):
        super().__init__(aggr=aggr)

        # Linear transformation for node features
        self.lin_node = nn.Linear(in_channels, out_channels)

        # Messages are built from the *projected* node features, so their
        # input width is out_channels (not in_channels) plus the attribute
        # width.  The two coincide whenever in_channels == out_channels.
        self.lin_edge = nn.Linear(out_channels + edge_dim, out_channels)
        self.lin_hyperedge = nn.Linear(out_channels + hyperedge_dim, out_channels)

        # Activation function
        self.relu = nn.ReLU()

    def forward(self, x, edge_index, edge_attr, hyperedge_index, hyperedge_attr):
        """Run one layer.

        Args:
            x: Node features ``[num_nodes, in_channels]``.
            edge_index: Edge indices ``[2, num_edges]``.
            edge_attr: Edge attributes ``[num_edges, edge_dim]``.
            hyperedge_index: Hyperedge indices ``[3, num_hyperedges]`` with
                the receiving (centre) node in row 1.
            hyperedge_attr: Hyperedge attributes
                ``[num_hyperedges, hyperedge_dim]``.

        Returns:
            Updated node features ``[num_nodes, out_channels]``.
        """
        # Initial node feature transformation
        x = self.lin_node(x)

        # Message passing for edges
        edge_messages = self.propagate(
            edge_index, x=x, edge_attr=edge_attr, mode='edge'
        )

        # Expand every hyperedge (i, j, k) into three pairwise messages that
        # all arrive at the centre node j: (i -> j), (k -> j) and (j -> j).
        sender_indices = torch.cat([hyperedge_index[0], hyperedge_index[2], hyperedge_index[1]])
        receiver_indices = hyperedge_index[1].repeat(3)
        hyperedge_edge_index = torch.stack([sender_indices, receiver_indices], dim=0)

        # Each of the three messages carries the attribute of its hyperedge;
        # repeat(3, 1) matches the block order of the concatenation above.
        hyperedge_edge_attr = hyperedge_attr.repeat(3, 1)

        hyperedge_messages = self.propagate(
            hyperedge_edge_index, x=x, edge_attr=hyperedge_edge_attr, mode='hyperedge'
        )

        # Combine messages
        out = x + edge_messages + hyperedge_messages
        return self.relu(out)

    def message(self, x_j, edge_attr, mode):
        """Build the message for each (expanded) edge.

        Args:
            x_j: Features of the sending node of every edge.
            edge_attr: Attribute attached to every edge.
            mode: ``'edge'`` or ``'hyperedge'``; selects the linear map.

        Raises:
            ValueError: If ``mode`` is neither ``'edge'`` nor ``'hyperedge'``.
        """
        msg_input = torch.cat([x_j, edge_attr], dim=-1)
        if mode == 'edge':
            return self.lin_edge(msg_input)
        if mode == 'hyperedge':
            return self.lin_hyperedge(msg_input)
        raise ValueError("Invalid mode. Use 'edge' or 'hyperedge'.")

    # NOTE: message_and_aggregate is intentionally not overridden.  A no-op
    # override would be detected as a fused implementation and return None
    # for sparse adjacency inputs; the base class falls back to message()
    # plus aggregation instead.

class GNNWithHyperedges(nn.Module):
    """Graph encoder producing one embedding vector per graph.

    Pipeline: linear node embedding -> ``num_layers`` blocks of
    ``HyperedgeGNNConv`` + ``LayerNorm`` + ReLU with a residual connection
    and dropout -> global mean pooling -> two-layer MLP.

    Args:
        num_node_features: Width of the input node features.
        num_edge_features: Width of the edge attributes.
        num_hyperedge_features: Width of the hyperedge attributes.
        hidden_dim: Width of the node representation inside the network.
        embedding_dim: Width of the returned graph embedding.
        num_layers: Number of convolution blocks.
        dropout_rate: Dropout probability used throughout.
    """

    def __init__(
        self,
        num_node_features,
        num_edge_features,
        num_hyperedge_features,
        hidden_dim=128,
        embedding_dim=64,
        num_layers=4,
        dropout_rate=0.1,
    ):
        super().__init__()

        self.num_layers = num_layers
        self.dropout_rate = dropout_rate

        # Initial linear transformation for node features
        self.node_embedding = nn.Linear(num_node_features, hidden_dim)

        # Lists to hold convolutional and normalization layers
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        for _ in range(num_layers):
            self.convs.append(
                HyperedgeGNNConv(
                    in_channels=hidden_dim,
                    out_channels=hidden_dim,
                    edge_dim=num_edge_features,
                    hyperedge_dim=num_hyperedge_features,
                    aggr='mean',
                )
            )
            self.norms.append(LayerNorm(hidden_dim))

        # Global mean pooling
        self.pool = global_mean_pool

        # Final MLP for graph-level embedding
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, embedding_dim),
            nn.ReLU(),
            nn.Dropout(p=self.dropout_rate),
            nn.Linear(embedding_dim, embedding_dim),
        )

    def forward(self, data):
        """Encode a (batched) graph object.

        Args:
            data: Object exposing ``x``, ``edge_index``, ``edge_attr``,
                ``hyperedge_index``, ``hyperedge_attr`` and optionally
                ``batch``.

        Returns:
            Tensor of shape ``[num_graphs, embedding_dim]``.
        """
        x = data.x  # Node features
        edge_index = data.edge_index  # Edge indices
        edge_attr = data.edge_attr  # Edge attributes
        hyperedge_index = data.hyperedge_index  # Hyperedge indices
        hyperedge_attr = data.hyperedge_attr  # Hyperedge attributes

        # A single un-batched graph may have no batch vector, or expose it as
        # None; treat all of its nodes as belonging to graph 0.
        batch = getattr(data, 'batch', None)
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        # Initial node embedding
        x = self.node_embedding(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout_rate, training=self.training)

        # Apply convolutional layers with residual connections
        for conv, norm in zip(self.convs, self.norms):
            x_residual = x  # Save residual
            x = conv(x, edge_index, edge_attr, hyperedge_index, hyperedge_attr)
            # Pass the batch vector so normalisation statistics are computed
            # per graph; without it they are shared by every graph in the
            # mini-batch and an embedding depends on its batch companions.
            x = norm(x, batch)
            x = F.relu(x)
            x = x + x_residual  # Residual connection
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        # Global mean pooling to get graph-level representation
        x = self.pool(x, batch)

        # Final MLP layers to get embedding
        return self.fc(x)


In [None]:
#Loss function: barlow twins

def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a 1-D tensor.

    Entries come out in row-major order with the diagonal skipped.
    """
    rows, cols = x.shape
    assert rows == cols
    # Drop the first (diagonal) element and regroup the remaining
    # rows*rows - 1 values into rows - 1 chunks of rows + 1.  Every chunk then
    # ends with a diagonal element, so trimming the last column removes
    # exactly the diagonal.
    without_first = x.flatten()[1:]
    chunks = without_first.view(rows - 1, rows + 1)
    trimmed = chunks[:, :-1]
    return trimmed.flatten()

def barlow_twins_loss(z_a, z_b, lambd=LAMBDA_PARAM, eps=1e-8):
    """
    Computes the Barlow Twins loss between two sets of embeddings.

    Both inputs are standardised per embedding dimension over the batch, the
    D x D cross-correlation matrix is formed, and the loss pushes its
    diagonal towards 1 and its off-diagonal entries towards 0.

    Args:
        z_a: Embeddings from the first set (batch_size x embedding_dim)
        z_b: Embeddings from the second set (batch_size x embedding_dim)
        lambd: Balancing parameter for off-diagonal terms
        eps: Small constant added to the standard deviations so that an
            embedding dimension that is constant over the batch does not
            cause a division by zero.
    Returns:
        loss: Scalar tensor representing the loss
    Raises:
        ValueError: If the batch holds fewer than two samples, for which the
            per-dimension standard deviation is undefined (NaN).
    """
    N = z_a.size(0)
    if N < 2:
        # std over a single sample is NaN; a NaN loss would silently corrupt
        # every weight on the next optimizer step.
        raise ValueError(
            f"barlow_twins_loss needs at least 2 samples per batch, got {N}."
        )

    # Subtract mean and divide by standard deviation (batch dimension)
    z_a_norm = (z_a - z_a.mean(dim=0)) / (z_a.std(dim=0) + eps)
    z_b_norm = (z_b - z_b.mean(dim=0)) / (z_b.std(dim=0) + eps)

    # Compute the cross-correlation matrix
    c = torch.mm(z_a_norm.T, z_b_norm) / N  # D x D matrix

    # Loss terms (out-of-place so the correlation matrix is left untouched)
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()
    off_diag = off_diagonal(c).pow(2).sum()
    return on_diag + lambd * off_diag

In [None]:
# Data loader code with necessary adjustments

import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import DataLoader

# The two graph lists are paired index by index (graph k of set 1 with graph k
# of set 2), so they must have the same length; fail early with a clear
# message instead of an IndexError or silently dropped graphs further down.
if len(graph_list_set_1) != len(graph_list_set_2):
    raise ValueError(
        f"Paired graph lists must have equal length, got "
        f"{len(graph_list_set_1)} and {len(graph_list_set_2)}."
    )

# Define constants
NUM_TRAINING = int(len(graph_list_set_1) * 0.8)  # 80% for training
NUM_VALIDATE = len(graph_list_set_1) - NUM_TRAINING  # Remaining for validation

# Set random seed for reproducibility
fixed_seed = 42
np.random.seed(fixed_seed)
indices = np.random.permutation(len(graph_list_set_1))

# Shuffle both graph lists with the same permutation so pairs stay aligned
shuffled_graph_list_set_1 = [graph_list_set_1[i] for i in indices]
shuffled_graph_list_set_2 = [graph_list_set_2[i] for i in indices]

# Split the datasets into training and validation sets
train_graph_list_set_1 = shuffled_graph_list_set_1[:NUM_TRAINING]
val_graph_list_set_1 = shuffled_graph_list_set_1[NUM_TRAINING:]

train_graph_list_set_2 = shuffled_graph_list_set_2[:NUM_TRAINING]
val_graph_list_set_2 = shuffled_graph_list_set_2[NUM_TRAINING:]

# Create DataLoaders for training and validation sets.
# shuffle=False keeps the two loaders of a pair in lock-step; the trainer
# iterates them together with zip().
train_loader_set_1 = DataLoader(train_graph_list_set_1, batch_size=BATCH_SIZE, shuffle=False)
val_loader_set_1 = DataLoader(val_graph_list_set_1, batch_size=BATCH_SIZE, shuffle=False)

train_loader_set_2 = DataLoader(train_graph_list_set_2, batch_size=BATCH_SIZE, shuffle=False)
val_loader_set_2 = DataLoader(val_graph_list_set_2, batch_size=BATCH_SIZE, shuffle=False)

print("Data loaders created for training and validation sets.")


In [None]:
# Trainer adjusted for hyperedge GNN model

import os
import matplotlib.pyplot as plt
import torch
import torch.optim as optim

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create the output directories used for model checkpoints and loss plots
os.makedirs("Models", exist_ok=True)
os.makedirs("Training Plots", exist_ok=True)

# Echo the hyperparameters of this run
print(f"""
-----------training starting---------------
BATCH_SIZE: {BATCH_SIZE}
NUM_EPOCHS: {NUM_EPOCHS}
Learning Rate: {LEARNING_RATE}
Hidden Dimension: {HIDDEN_DIMENSION}
Embedding Dimension: {EMBEDDING_DIMENSION}
Lambda (Barlow Twins Loss): {LAMBDA_PARAM}
""")

# Initialize the model; the feature widths are read from the first graph of set 1
model = GNNWithHyperedges(
    num_node_features=graph_list_set_1[0].x.size(1),  # Adjust based on node features
    num_edge_features=graph_list_set_1[0].edge_attr.size(1),  # Adjust based on edge features
    num_hyperedge_features=graph_list_set_1[0].hyperedge_attr.size(1),  # Adjust based on hyperedge features
    hidden_dim=HIDDEN_DIMENSION,
    embedding_dim=EMBEDDING_DIMENSION,
    num_layers=2,  # Customize as needed
    dropout_rate=0.1,
).to(device)

# Adam optimizer with a small L2 weight decay
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)

def compute_loss(model, data1, data2, lambd):
    """Embed two batches with the same model and return their Barlow Twins loss.

    Both batches are moved to the module-level ``device`` before encoding.
    """
    first_view = data1.to(device)
    second_view = data2.to(device)
    first_embedding = model(first_view)
    second_embedding = model(second_view)
    return barlow_twins_loss(first_embedding, second_embedding, lambd)

def train_one_epoch(model, loader_set_1, loader_set_2, optimizer, lambd):
    """Run one optimisation pass over the paired loaders.

    The two loaders are consumed together; iteration stops at the shorter one.

    Args:
        model: Network to train (switched to training mode here).
        loader_set_1: Loader yielding the first batch of each pair.
        loader_set_2: Loader yielding the second batch of each pair.
        optimizer: Optimizer stepping the model parameters.
        lambd: Off-diagonal weight of the Barlow Twins loss.

    Returns:
        Mean loss over the batch pairs that were actually processed.

    Raises:
        ZeroDivisionError: If the loaders yield no batch pair.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for data1, data2 in zip(loader_set_1, loader_set_2):
        optimizer.zero_grad()
        loss = compute_loss(model, data1, data2, lambd)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    # Divide by the number of processed pairs: zip() truncates to the shorter
    # loader, so len(loader_set_1) could overstate the count.
    return total_loss / num_batches

def validate_model(model, loader_set_1, loader_set_2, lambd):
    """Evaluate the mean loss over the paired loaders without gradient tracking.

    The two loaders are consumed together; iteration stops at the shorter one.

    Args:
        model: Network to evaluate (switched to evaluation mode here).
        loader_set_1: Loader yielding the first batch of each pair.
        loader_set_2: Loader yielding the second batch of each pair.
        lambd: Off-diagonal weight of the Barlow Twins loss.

    Returns:
        Mean loss over the batch pairs that were actually processed.

    Raises:
        ZeroDivisionError: If the loaders yield no batch pair.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for data1, data2 in zip(loader_set_1, loader_set_2):
            total_loss += compute_loss(model, data1, data2, lambd).item()
            num_batches += 1

    # Divide by the number of processed pairs: zip() truncates to the shorter
    # loader, so len(loader_set_1) could overstate the count.
    return total_loss / num_batches

def train_model(model, train_loader_1, train_loader_2, val_loader_1, val_loader_2, optimizer, num_epochs, lambd):
    """Train for up to ``num_epochs`` epochs, checkpointing on validation improvement.

    After every epoch the validation loss is compared with the best value so
    far; on improvement the model's ``state_dict`` is written to
    ``Models/<MODEL_NAME>.pth``.  A ``KeyboardInterrupt`` ends training early
    and the losses gathered up to that point are still returned.

    Returns:
        Tuple ``(train_losses, val_losses)`` with one entry per finished epoch.
    """
    best_val_loss = float('inf')
    train_losses, val_losses = [], []

    try:
        for epoch in range(num_epochs):
            # One pass over the training pairs, then one over the validation pairs
            avg_train_loss = train_one_epoch(model, train_loader_1, train_loader_2, optimizer, lambd)
            train_losses.append(avg_train_loss)

            avg_val_loss = validate_model(model, val_loader_1, val_loader_2, lambd)
            val_losses.append(avg_val_loss)

            # Keep only the checkpoint with the lowest validation loss
            is_new_best = avg_val_loss < best_val_loss
            if is_new_best:
                best_val_loss = avg_val_loss
                torch.save(model.state_dict(), f'Models/{MODEL_NAME}.pth')
                print(f"New best model saved with validation loss: {best_val_loss:.4f}")

            print(f"Epoch [{epoch + 1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

    except KeyboardInterrupt:
        print("Training interrupted by user. Best model saved so far.")

    return train_losses, val_losses

def plot_losses(train_losses, val_losses):
    """Plot both loss curves against the epoch number, save the figure and show it.

    The figure is written to ``Training Plots/<MODEL_NAME>.png``.
    """
    plt.figure(figsize=(10, 6))

    curves = (
        (train_losses, 'Training Loss'),
        (val_losses, 'Validation Loss'),
    )
    for history, legend_name in curves:
        # Epochs are numbered from 1 on the x-axis
        epoch_numbers = range(1, len(history) + 1)
        plt.plot(epoch_numbers, history, label=legend_name, marker='o')

    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(f"Training Plots/{MODEL_NAME}.png")
    plt.show()

# Run the training and validation process on the paired loaders, then plot
# the per-epoch loss curves.
train_losses, val_losses = train_model(
    model, 
    train_loader_set_1, train_loader_set_2, 
    val_loader_set_1, val_loader_set_2, 
    optimizer, 
    NUM_EPOCHS, 
    LAMBDA_PARAM
)

plot_losses(train_losses, val_losses)


In [None]:
# Example: inspecting the weights of the first convolution layer's three
# linear maps (edge messages, hyperedge messages, node projection).
print("Lin Edge Weights:", model.convs[0].lin_edge.weight)
print("Lin Hyperedge Weights:", model.convs[0].lin_hyperedge.weight)
print("Lin Node Weights:", model.convs[0].lin_node.weight)


In [None]:
#Generates and saves embeddings

def test_model(model, loader):
    """Embed every batch of ``loader`` and gather the embeddings with their labels.

    The model is put in evaluation mode and gradients are disabled.  Each
    batch is moved to the module-level ``device`` and must expose a ``label``
    attribute.

    Returns:
        Tuple ``(all_embeddings, all_labels)``, each concatenated along
        dimension 0 in loader order.
    """
    model.eval()
    embedding_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            embedding_chunks.append(model(batch))  # (batch_size, embedding_dim)
            label_chunks.append(batch.label)

    all_embeddings = torch.cat(embedding_chunks, dim=0)
    all_labels = torch.cat(label_chunks, dim=0)
    return all_embeddings, all_labels

# Embed the unperturbed 3x3 graphs; each of them carries a 'label' attribute
test_loader = DataLoader(graph_list_unperturbed_1, batch_size=BATCH_SIZE, shuffle=False)

# Call the test function to generate embeddings and labels for every graph in the loader
output_graph_embeddings, output_graph_labels = test_model(model, test_loader)

# Print the shapes of the final embeddings and labels
print(f"Generated embeddings shape: {output_graph_embeddings.shape}")
print(f"Corresponding labels shape: {output_graph_labels.shape}")

embeddings_path = f'Embeddings/embeddings_{MODEL_NAME}.pt'
labels_path = f'Labels/labels_{MODEL_NAME}.pt'

# torch.save does not create missing parent directories, so make sure the
# output folders exist first.
os.makedirs(os.path.dirname(embeddings_path), exist_ok=True)
os.makedirs(os.path.dirname(labels_path), exist_ok=True)

# Save the graph embeddings
torch.save(output_graph_embeddings, embeddings_path)

# Save the matching graph labels
torch.save(output_graph_labels, labels_path)

# Print confirmation message
print(f"embeddings saved to {embeddings_path}")
print(f"labels saved to {labels_path}")

In [None]:
#Clustering workflow
#Uses HDBSCAN to cluster the embeddings and visualises the clusters using t-SNE and UMAP.
#Also prints out additional outlier score plots and other information.

import matplotlib.pyplot as plt
import hdbscan
import numpy as np
from sklearn.manifold import TSNE
import umap.umap_ as umap

# Reload the saved tensors onto the CPU and convert them to NumPy arrays for
# the clustering and plotting helpers below.
embeddings = torch.load(embeddings_path, map_location=torch.device('cpu')).numpy()
labels = torch.load(labels_path, map_location=torch.device('cpu')).numpy()

def hdbscan_clustering(data):
    """Cluster ``data`` with HDBSCAN and report its persistence and outlier scores.

    Returns:
        Tuple ``(clusterer, cluster_labels, stability_scores, outlier_scores)``
        where the last two are the fitted clusterer's
        ``cluster_persistence_`` and ``outlier_scores_`` attributes.
    """
    clusterer = hdbscan.HDBSCAN(min_samples=5, min_cluster_size=4, prediction_data=True)
    assignments = clusterer.fit_predict(data)

    # Both score arrays are read from the fitted clusterer
    persistence = clusterer.cluster_persistence_
    point_outlier_scores = clusterer.outlier_scores_

    print("Cluster Stability Scores:", persistence)
    # Only the first 10 outlier scores are printed, for brevity
    print("Outlier Scores (first 10):", point_outlier_scores[:10])

    return clusterer, assignments, persistence, point_outlier_scores

def tsne_plot(data, cluster_labels):
    """Project ``data`` to 2-D with t-SNE and plot it coloured by cluster.

    Points labelled ``-1`` are drawn in light gray as noise; every other
    cluster is coloured by its label and annotated with the label at the
    centroid of its projected points.  The figure is saved to
    ``Clustering Plots/t-SNE_clusters_<MODEL_NAME>.png`` and then shown.

    Args:
        data: Array of samples to project (one row per sample).
        cluster_labels: NumPy array with one cluster label per sample.
    """
    tsne = TSNE(n_components=2, perplexity=20, random_state=42, init='pca')
    data_tsne_2d = tsne.fit_transform(data)

    unique_labels = np.unique(cluster_labels)
    background_points = (cluster_labels == -1)

    plt.figure(figsize=(10, 8))
    plt.scatter(data_tsne_2d[background_points, 0], data_tsne_2d[background_points, 1],
                c='lightgray', s=10, alpha=0.5, label='Noise')

    plt.scatter(data_tsne_2d[~background_points, 0], data_tsne_2d[~background_points, 1],
                c=cluster_labels[~background_points], cmap='tab20', s=10, alpha=0.7)

    # Write each cluster's label at the centroid of its projected points
    for label in unique_labels:
        if label != -1:
            label_points = data_tsne_2d[cluster_labels == label]
            centroid = np.mean(label_points, axis=0)
            plt.text(centroid[0], centroid[1], str(label), fontsize=8, fontweight='bold',
                     color='black', ha='center', va='center')

    plt.title('t-SNE Visualization of HDBSCAN Clusters in 2D')
    plt.xlabel('t-SNE 1')
    plt.ylabel('t-SNE 2')
    plt.legend()
    # savefig does not create missing directories
    os.makedirs('Clustering Plots', exist_ok=True)
    plt.savefig(f'Clustering Plots/t-SNE_clusters_{MODEL_NAME}.png')
    plt.show()

def umap_plot(data, cluster_labels):
    """Project ``data`` to 2-D with UMAP and plot it coloured by cluster.

    Points labelled ``-1`` are drawn in light gray as noise; every other
    cluster is coloured by its label and annotated with the label at the
    centroid of its projected points.  The figure is saved to
    ``Clustering Plots/UMAP_clusters_<MODEL_NAME>.png`` and then shown.

    Args:
        data: Array of samples to project (one row per sample).
        cluster_labels: NumPy array with one cluster label per sample.
    """
    umap_reducer = umap.UMAP(n_neighbors=20, n_components=2, min_dist=0.5, random_state=42, init='pca')
    data_umap_2d = umap_reducer.fit_transform(data)

    unique_labels = np.unique(cluster_labels)
    background_points = (cluster_labels == -1)

    plt.figure(figsize=(10, 8))
    plt.scatter(data_umap_2d[background_points, 0], data_umap_2d[background_points, 1],
                c='lightgray', s=10, alpha=0.5, label='Noise')

    plt.scatter(data_umap_2d[~background_points, 0], data_umap_2d[~background_points, 1],
                c=cluster_labels[~background_points], cmap='tab20', s=10, alpha=0.7)

    # Write each cluster's label at the centroid of its projected points
    for label in unique_labels:
        if label != -1:
            label_points = data_umap_2d[cluster_labels == label]
            centroid = np.mean(label_points, axis=0)
            plt.text(centroid[0], centroid[1], str(label), fontsize=8, fontweight='bold',
                     color='black', ha='center', va='center')

    plt.title('UMAP Visualization of HDBSCAN Clusters in 2D')
    plt.xlabel('UMAP 1')
    plt.ylabel('UMAP 2')
    plt.legend()
    # savefig does not create missing directories
    os.makedirs('Clustering Plots', exist_ok=True)
    # Explicit extension, consistent with the t-SNE plot; without one the
    # output format would be guessed from whatever follows the last dot in
    # MODEL_NAME.
    plt.savefig(f'Clustering Plots/UMAP_clusters_{MODEL_NAME}.png')
    plt.show()

def save_clusters_with_labels(cluster_labels, embedding_labels, filename='clustered_labels.npy'):
    """Save a two-column array pairing each embedding label with its cluster label.

    Column 0 holds ``embedding_labels`` and column 1 holds ``cluster_labels``;
    the array is written with ``np.save`` to ``filename``.
    """
    label_cluster_table = np.column_stack([embedding_labels, cluster_labels])
    np.save(filename, label_cluster_table)
    print(f"Clusters and labels saved to {filename}")

def full_clustering_workflow(data, embedding_labels):
    """Cluster ``data``, save the labels, plot the clusters and report scores.

    Steps, in order: run ``hdbscan_clustering``; save the cluster labels next to
    ``embedding_labels``; draw the t-SNE and UMAP plots; print the mean
    stability score and the clustered/noise point counts; print the mean
    outlier score of clustered and of noise points (NaN scores excluded); and
    show one outlier-score histogram per group that has valid scores.

    Parameters:
    - data: Embeddings to cluster and visualize.
    - embedding_labels: One label per row of ``data``, saved with the clusters.

    Returns:
    - (hdb, cluster_labels) exactly as returned by ``hdbscan_clustering``.
    """
    def show_histogram(scores, bar_color, title):
        # One 20-bin histogram of outlier scores, shown immediately
        plt.figure(figsize=(8, 5))
        plt.hist(scores, bins=20, color=bar_color, edgecolor='black')
        plt.title(title)
        plt.xlabel('Outlier Score')
        plt.ylabel('Frequency')
        plt.show()

    print("Clustering with HDBSCAN...")
    hdb, cluster_labels, stability_scores, outlier_scores = hdbscan_clustering(data)
    save_clusters_with_labels(cluster_labels, embedding_labels)

    # Each projection is produced by its own plotting function
    print("Visualizing with t-SNE...")
    tsne_plot(data, cluster_labels)

    print("Visualizing with UMAP...")
    umap_plot(data, cluster_labels)

    # Mean stability over the scores returned by the clustering step
    avg_stability_score = np.mean(stability_scores)
    print(f"Average Cluster Stability Score: {avg_stability_score:.4f}")

    # Label -1 marks noise; everything else belongs to a cluster
    in_cluster = (cluster_labels != -1)
    is_noise = (cluster_labels == -1)

    print("Total number of data points:", len(cluster_labels))
    print("Number of clustered points:", np.sum(in_cluster))
    print("Number of noise points:", np.sum(is_noise))

    # Only points with a non-NaN outlier score enter the statistics below
    has_score = ~np.isnan(outlier_scores)
    clustered_outlier_scores = outlier_scores[in_cluster & has_score]
    noise_outlier_scores = outlier_scores[is_noise & has_score]

    # Averages
    if clustered_outlier_scores.size == 0:
        print("No valid outlier scores for clustered points.")
    else:
        avg_clustered_outlier_score = np.mean(clustered_outlier_scores)
        print(f"Average Outlier Score for Clustered Points: {avg_clustered_outlier_score:.4f}")

    if noise_outlier_scores.size == 0:
        print("No valid outlier scores for noise points.")
    else:
        avg_noise_outlier_score = np.mean(noise_outlier_scores)
        print(f"Average Outlier Score for Noise Points: {avg_noise_outlier_score:.4f}")

    # Histograms
    if clustered_outlier_scores.size == 0:
        print("No valid outlier scores for clustered points to plot.")
    else:
        show_histogram(clustered_outlier_scores, 'skyblue',
                       'Outlier Scores Distribution for Clustered Points')

    # Noise points: nothing is printed when there is nothing to plot
    if noise_outlier_scores.size > 0:
        show_histogram(noise_outlier_scores, 'salmon',
                       'Outlier Scores Distribution for Noise Points')

    return hdb, cluster_labels

def load_clusters_with_labels(filename='clustered_labels.npy'):
    """Read back a two-column label table written with ``np.save``.

    Parameters:
    - filename: Path of the ``.npy`` file to load.

    Returns:
    - (embedding_labels, cluster_labels): column 0 and column 1 of the array.
    """
    label_table = np.load(filename)

    # Column 0 holds the embedding labels, column 1 the cluster labels
    embedding_labels, cluster_labels = label_table[:, 0], label_table[:, 1]

    print(f"Clusters and labels loaded from {filename}")
    return embedding_labels, cluster_labels

# Run the full workflow on the first 4000 embeddings and their labels.
# 'embeddings' and 'labels' are assumed to be defined already in the workspace.
hdb, cluster_labels = full_clustering_workflow(embeddings[:4000], labels[:4000])

# Reload the embedding labels from 'clustered_labels.npy' (the default file the
# workflow saves to); the cluster-label column of the file is discarded and the
# cluster_labels returned above stay in use.
embedding_labels, _ = load_clusters_with_labels('clustered_labels.npy')

In [None]:
#Provides the properties of and visualizes every node in the specified cluster, including the full lattice formula

# Label of the cluster inspected in this cell
CLUSTER_NUM = 48
# Embedding labels of every point assigned to that cluster
cluster = embedding_labels[cluster_labels==CLUSTER_NUM]
# NOTE(review): bare expression with no effect at this position; presumably
# left over from inspecting `cluster` interactively
cluster

import os
import shutil
import re
import numpy as np
import torch
import pandas as pd
from ase.io import read
from torch_geometric.data import Data
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from IPython.display import HTML

# Distance tolerance added to the nearest / next-nearest neighbour distance
# when selecting edges (adjust as needed)
delta = 0.1

# Supercell size N: the structures are treated as N x N repetitions of a unit cell
SUPERCELL_SIZE = 3

# Folder that receives the plots produced in this cell
plot_folder = 'cluster_plots'

# Start from an empty plot folder: clear it if it exists, otherwise create it
if os.path.exists(plot_folder):
    for filename in os.listdir(plot_folder):
        file_path = os.path.join(plot_folder, filename)
        try:
            if os.path.isdir(file_path) and not os.path.islink(file_path):
                shutil.rmtree(file_path)  # Real sub-directory: remove recursively
            elif os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)  # File or symlink: remove the entry itself
        except Exception as e:
            print(f"Failed to delete {file_path}. Reason: {e}")
else:
    os.makedirs(plot_folder)

# Function to parse the reduced_sublattice_structure.txt file
def parse_reduced_sublattice_structure(file_path):
    """Build a ``{material label: full lattice formula}`` mapping from a text file.

    Every line starting with ``2dm-`` is treated as a material header whose
    integer label follows the first hyphen. The first whitespace-separated
    token of the line directly after it (e.g. 'TaI3') is stored as that
    material's formula. A header on the very last line has no formula line and
    is ignored.

    A header that cannot be parsed (non-integer label, or a blank line where
    the formula should be) is reported and skipped, so that one malformed
    entry does not discard the entries that follow it.

    Parameters:
    - file_path: Path of the text file to parse.

    Returns:
    - dict mapping int material labels to formula strings; empty when the file
      is missing or cannot be read.
    """
    material_formula_mapping = {}
    try:
        with open(file_path, 'r') as file:
            lines = [line.strip() for line in file]
    except FileNotFoundError:
        print(f"File {file_path} not found.")
        return material_formula_mapping
    except (OSError, ValueError) as e:
        # ValueError covers decoding failures while reading the text
        print(f"An error occurred while parsing the file: {e}")
        return material_formula_mapping

    # Pair every line with its successor: (header candidate, formula line)
    for header, formula_line in zip(lines, lines[1:]):
        if not header.startswith('2dm-'):
            continue

        try:
            material_label = int(header.split('-')[1])
        except ValueError:
            print(f"Skipping entry with unreadable material label: {header!r}")
            continue

        tokens = formula_line.split()
        if not tokens:
            print(f"Skipping material {material_label}: no formula on the following line.")
            continue

        material_formula_mapping[material_label] = tokens[0]

    return material_formula_mapping

# Function to compute edge_index based on positions and delta, considering PBCs
def compute_edge_index(atoms, delta):
    """Connect every atom to its nearest and next-nearest neighbours.

    For each atom i, the smallest distance d1 to any other atom defines the
    nearest-neighbour shell (all atoms within ``d1 + delta``). The smallest
    distance d2 outside that shell defines the next-nearest shell (all
    remaining atoms within ``d2 + delta``). Distances are taken from
    ``atoms.get_all_distances(mic=True)``.

    An atom never becomes its own neighbour: when no other atom exists, or no
    atom lies outside the nearest-neighbour shell, the only candidate left is
    the atom itself (its self-distance is set to infinity), and that shell is
    skipped instead of producing a self-loop.

    Parameters:
    - atoms: ASE Atoms object representing the structure.
    - delta: Tolerance added to d1 and d2 when collecting a shell.

    Returns:
    - torch.long tensor of shape (2, num_edges); each column is
      (neighbour j, atom i). Shape (2, 0) when there are no edges.
    """
    dist_matrix = atoms.get_all_distances(mic=True)
    num_atoms = len(dist_matrix)
    edges = set()  # a set, so an edge reached twice is stored once

    for i in range(num_atoms):
        # Exclude the self-distance from the neighbour search
        dist_matrix[i, i] = np.inf

        distances = dist_matrix[i]
        sorted_indices = np.argsort(distances)
        sorted_distances = distances[sorted_indices]

        # Nearest-neighbour shell
        d1 = sorted_distances[0]
        if not np.isfinite(d1):
            # Only the atom itself is left: no neighbours at all
            continue
        nn_cutoff = d1 + delta
        for j in sorted_indices[sorted_distances <= nn_cutoff]:
            edges.add((int(j), i))

        # Next-nearest shell: closest atoms beyond the nearest-neighbour cutoff
        beyond_nn = sorted_distances > nn_cutoff
        remaining_indices = sorted_indices[beyond_nn]
        remaining_distances = sorted_distances[beyond_nn]

        # A non-finite d2 means only the atom itself remains; skipping it
        # avoids the (i, i) self-loop that `inf <= inf + delta` would add
        if len(remaining_distances) > 0 and np.isfinite(remaining_distances[0]):
            nnn_cutoff = remaining_distances[0] + delta
            for j in remaining_indices[remaining_distances <= nnn_cutoff]:
                edges.add((int(j), i))

    if edges:
        return torch.tensor(list(edges), dtype=torch.long).t().contiguous()

    # Graph without edges
    return torch.empty((2, 0), dtype=torch.long)

# Function to compute edge attributes based on positions and edge_index, considering PBCs
def compute_edge_attr(atoms, edge_index):
    """Return the length of every edge as a float tensor of shape (num_edges, 1).

    The difference of the two endpoints' scaled (fractional) positions is
    wrapped by subtracting its rounded value, converted to Cartesian
    coordinates with the cell, and its norm is the edge attribute.
    NOTE(review): the wrap is applied to all three axes, whatever ``atoms.pbc`` is.

    Parameters:
    - atoms: ASE Atoms object representing the structure.
    - edge_index: Tensor of shape (2, num_edges) with the edge endpoints.
    """
    row, col = edge_index
    lattice = atoms.get_cell()
    fractional = atoms.get_scaled_positions()

    # Fractional offset between the endpoints, wrapped to the nearest image
    offset = fractional[row.numpy()] - fractional[col.numpy()]
    wrapped_offset = offset - np.round(offset)

    lengths = np.linalg.norm(wrapped_offset @ lattice, axis=1)
    return torch.tensor(lengths, dtype=torch.float).unsqueeze(1)

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta):
    """Build the PyTorch Geometric graph of an atomic structure.

    Edges come from ``compute_edge_index`` and their attributes from
    ``compute_edge_attr``; nodes carry no features (``x`` has zero columns).

    Parameters:
    - atoms: ASE Atoms object representing the structure.
    - delta: Neighbour tolerance forwarded to ``compute_edge_index``.

    Returns:
    - ``Data`` object with ``x``, ``edge_index`` and ``edge_attr`` set.
    """
    connectivity = compute_edge_index(atoms, delta)
    edge_lengths = compute_edge_attr(atoms, connectivity)

    # One row per atom, zero feature columns
    featureless_nodes = torch.empty((len(atoms), 0), dtype=torch.float)

    return Data(x=featureless_nodes, edge_index=connectivity, edge_attr=edge_lengths)

# Function to plot the graph with consistent y-size
def plot_graph(atoms, graph, margin=0.1, tolerance=0.1, y_size=6, filename=None):
    """
    Plots the x-y projection of an atomic structure's graph with consistent y-direction size.

    Atoms are drawn as points coloured by z-layer, edges as gray segments and
    the parallelogram spanned by the first two cell vectors as a dashed outline.
    Each edge is drawn from its first endpoint along the wrapped offset to the
    second endpoint, so an edge to a periodic image leaves the cell instead of
    crossing it.

    Parameters:
    - atoms: ASE Atoms object representing the structure.
    - graph: PyTorch Geometric Data object for the graph.
    - margin: Margin added to the plot boundaries as a fraction of plot size.
    - tolerance: Tolerance for grouping atoms into layers by z-coordinate.
    - y_size: Desired size of the plot in the y-direction (in inches).
    - filename: If provided, saves the plot to this file instead of showing it.
    """
    positions = atoms.get_positions()
    cell = atoms.get_cell()

    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]

    # Group z positions into layers using the tolerance and index the layers
    z_grouped = np.round(z / tolerance) * tolerance
    z_unique, layer_indices = np.unique(z_grouped, return_inverse=True)

    # One colour per layer. pyplot.get_cmap is used because
    # matplotlib.cm.get_cmap is no longer available in recent Matplotlib.
    cmap = plt.get_cmap('viridis', len(z_unique))

    fig, ax = plt.subplots()

    # Cell outline in the a-b plane, shifted by half the c vector
    c_shift = 0.5 * cell[2]
    corner_positions = np.array([
        [0, 0, 0],
        cell[0],
        cell[0] + cell[1],
        cell[1],
        [0, 0, 0]
    ]) + c_shift

    # Plot limits follow the shifted outline, widened by the margin
    x_min, x_max = corner_positions[:, 0].min(), corner_positions[:, 0].max()
    y_min, y_max = corner_positions[:, 1].min(), corner_positions[:, 1].max()

    x_margin = (x_max - x_min) * margin
    y_margin = (y_max - y_min) * margin

    x_min -= x_margin
    x_max += x_margin
    y_min -= y_margin
    y_max += y_margin

    # Fix the figure height and scale the width with the data aspect ratio
    aspect_ratio = (x_max - x_min) / (y_max - y_min)
    fig.set_size_inches(y_size * aspect_ratio, y_size)

    ax.plot(corner_positions[:, 0], corner_positions[:, 1], 'k--', linewidth=1)
    ax.scatter(x, y, c=layer_indices, s=50, cmap=cmap, zorder=2)

    # Edge segments, computed for all edges at once. The scaled positions are
    # fetched a single time instead of twice per edge.
    src, dst = graph.edge_index.cpu().numpy()
    scaled_positions = atoms.get_scaled_positions()
    delta_scaled = scaled_positions[dst] - scaled_positions[src]
    delta_scaled -= np.round(delta_scaled)  # wrap the offset to the nearest image
    displacements = delta_scaled @ cell

    segment_starts = positions[src, :2]
    segment_ends = segment_starts + displacements[:, :2]
    segments = np.stack((segment_starts, segment_ends), axis=1)  # (num_edges, 2, 2)

    lc = LineCollection(segments, colors='gray', linewidths=1, zorder=1)
    ax.add_collection(lc)

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_aspect('equal')

    # Remove axes and let the drawing fill the figure
    ax.axis('off')
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

    if filename:
        fig.savefig(filename, dpi=300, bbox_inches='tight')
        plt.close(fig)  # close this figure explicitly so figures do not pile up
    else:
        plt.show()

# Function to process every node in the specified cluster and collect data
def process_nodes_in_cluster(CLUSTER_NUM, hdb, cluster_labels, embedding_labels, formula_mapping,
                             structure_dir='supercells_flatband_rotated_shifted_3x3_test'):
    """
    Plots and tabulates every material assigned to one cluster.

    For each point of the cluster the structure file
    ``{structure_dir}/supercell_2dm-{label}.xyz`` is read, its graph is plotted
    into ``plot_folder`` and one table row is collected. Materials whose file is
    missing or whose processing fails are reported and skipped.

    Parameters:
    - CLUSTER_NUM: Label of the cluster to process.
    - hdb: Unused here; kept so existing calls keep working.
    - cluster_labels: Array with the cluster label of every point.
    - embedding_labels: Array with the material label of every point.
    - formula_mapping: dict from material label to full lattice formula;
      materials not in it are reported as 'Unknown'.
    - structure_dir: Folder containing the supercell ``.xyz`` files.

    Returns:
    - list of row dicts (possibly empty), or None when the cluster has no points.
    """
    cluster_indices = np.where(cluster_labels == CLUSTER_NUM)[0]

    if len(cluster_indices) == 0:
        print(f"No points found in cluster {CLUSTER_NUM}.")
        return None

    data_list = []

    for idx in cluster_indices:
        material_label = int(embedding_labels[idx])
        file_path = f'{structure_dir}/supercell_2dm-{material_label}.xyz'
        full_lattice_formula = formula_mapping.get(material_label, 'Unknown')

        try:
            atoms = read(file_path)
            atoms.pbc = [True, True, False]  # Periodic in the first two cell directions only

            graph = create_graph_from_structure(atoms, delta)

            plot_filename = f'graph_cluster_{CLUSTER_NUM}_material_{material_label}_idx_{idx}.png'
            plot_filepath = os.path.join(plot_folder, plot_filename)
            plot_graph(atoms, graph, margin=0.2, tolerance=0.1, y_size=6, filename=plot_filepath)
            print(f"Graph for material label {material_label} from cluster {CLUSTER_NUM} saved as {plot_filepath}.")

            num_atoms = len(atoms)

            # cellpar() = [a, b, c, alpha, beta, gamma]. The lengths must come
            # from here: the diagonal of the cell matrix equals the lattice
            # vector lengths only for cells aligned with the Cartesian axes
            # (it is wrong for, e.g., hexagonal cells).
            cell_parameters = atoms.cell.cellpar()
            cell_lengths = cell_parameters[:3]
            cell_angles = cell_parameters[3:6]  # degrees

            # In-plane lengths are divided by SUPERCELL_SIZE to report the unit cell
            formatted_cell_lengths = [
                f'{cell_lengths[0] / SUPERCELL_SIZE:.3f}',
                f'{cell_lengths[1] / SUPERCELL_SIZE:.3f}',
                f'{cell_lengths[2]:.3f}'
            ]
            formatted_cell_angles = [f'{angle:.3f}' for angle in cell_angles]

            num_atoms_in_unit_cell = num_atoms // (SUPERCELL_SIZE * SUPERCELL_SIZE)

            # NOTE(review): assumes the first num_atoms_in_unit_cell atoms of the
            # file make up one unit cell - confirm against the supercell generator
            unit_cell_atoms = atoms[:num_atoms_in_unit_cell]
            formula = unit_cell_atoms.get_chemical_formula()

            data_list.append({
                'Cluster': CLUSTER_NUM,
                'Material': material_label,
                'Full Lattice Formula': full_lattice_formula,
                'Chemical Formula': formula,
                'No. atoms per cell': num_atoms_in_unit_cell,
                'Unitcell Lengths': formatted_cell_lengths,
                'Cell Angles': formatted_cell_angles,
                'Plot': plot_filepath  # Ensure that 'Plot' is the last column
            })

        except FileNotFoundError:
            print(f"Structure file for material label {material_label} not found at {file_path}.")
            continue
        except Exception as e:
            # Best effort: one failing material must not stop the rest of the cluster
            print(f"An error occurred while processing material {material_label}: {e}")
            continue

    return data_list

# Material label -> full lattice formula, read from reduced_sublattice_structure.txt
formula_mapping = parse_reduced_sublattice_structure('reduced_sublattice_structure.txt')

# Collect one table row per material of the selected cluster
cluster_data = process_nodes_in_cluster(CLUSTER_NUM, hdb, cluster_labels, embedding_labels, formula_mapping)
if cluster_data is None:
    print(f"No data collected for cluster {CLUSTER_NUM}.")
    data_list = []
else:
    data_list = list(cluster_data)

# Column order of the table; 'Plot' stays last
column_order = [
    'Cluster', 'Material', 'Full Lattice Formula', 'Chemical Formula',
    'No. atoms per cell', 'Unitcell Lengths', 'Cell Angles', 'Plot',
]

# Table of the collected rows in the column order above
df = pd.DataFrame(data_list, columns=column_order)

# Function to convert image paths to HTML tags for display
def path_to_image_html(path):
    """Return an HTML ``<img>`` tag for ``path`` with a fixed height of 200."""
    img_template = '<img src="{}" height="200">'  # same height for every image
    return img_template.format(path)

# Display the DataFrame as HTML: the 'Plot' column is rendered through
# path_to_image_html, escape=False keeps those <img> tags from being escaped,
# and index=False suppresses the index column
HTML(df.to_html(index=False, escape=False, formatters={'Plot': path_to_image_html}))

In [None]:
#Provides the properties of and visualises the core node of every cluster

import os
import shutil
import re
import numpy as np
import torch
import pandas as pd
from ase.io import read
from torch_geometric.data import Data
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from IPython.display import HTML

# Distance tolerance added to the nearest / next-nearest neighbour distance
# when selecting edges (adjust as needed)
delta = 0.1

# Supercell size N: the structures are treated as N x N repetitions of a unit cell
SUPERCELL_SIZE = 3

# Folder that receives the plots produced in this cell
plot_folder = 'Core Node Plots'

# Start from an empty plot folder: clear it if it exists, otherwise create it
if os.path.exists(plot_folder):
    for filename in os.listdir(plot_folder):
        file_path = os.path.join(plot_folder, filename)
        try:
            if os.path.isdir(file_path) and not os.path.islink(file_path):
                shutil.rmtree(file_path)  # Real sub-directory: remove recursively
            elif os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)  # File or symlink: remove the entry itself
        except Exception as e:
            print(f"Failed to delete {file_path}. Reason: {e}")
else:
    os.makedirs(plot_folder)

# Function to compute edge_index based on positions and delta, considering PBCs
def compute_edge_index(atoms, delta):
    """Connect every atom to its nearest and next-nearest neighbours.

    For each atom i, the smallest distance d1 to any other atom defines the
    nearest-neighbour shell (all atoms within ``d1 + delta``). The smallest
    distance d2 outside that shell defines the next-nearest shell (all
    remaining atoms within ``d2 + delta``). Distances are taken from
    ``atoms.get_all_distances(mic=True)``.

    An atom never becomes its own neighbour: when no other atom exists, or no
    atom lies outside the nearest-neighbour shell, the only candidate left is
    the atom itself (its self-distance is set to infinity), and that shell is
    skipped instead of producing a self-loop.

    Parameters:
    - atoms: ASE Atoms object representing the structure.
    - delta: Tolerance added to d1 and d2 when collecting a shell.

    Returns:
    - torch.long tensor of shape (2, num_edges); each column is
      (neighbour j, atom i). Shape (2, 0) when there are no edges.
    """
    dist_matrix = atoms.get_all_distances(mic=True)
    num_atoms = len(dist_matrix)
    edges = set()  # a set, so an edge reached twice is stored once

    for i in range(num_atoms):
        # Exclude the self-distance from the neighbour search
        dist_matrix[i, i] = np.inf

        distances = dist_matrix[i]
        sorted_indices = np.argsort(distances)
        sorted_distances = distances[sorted_indices]

        # Nearest-neighbour shell
        d1 = sorted_distances[0]
        if not np.isfinite(d1):
            # Only the atom itself is left: no neighbours at all
            continue
        nn_cutoff = d1 + delta
        for j in sorted_indices[sorted_distances <= nn_cutoff]:
            edges.add((int(j), i))

        # Next-nearest shell: closest atoms beyond the nearest-neighbour cutoff
        beyond_nn = sorted_distances > nn_cutoff
        remaining_indices = sorted_indices[beyond_nn]
        remaining_distances = sorted_distances[beyond_nn]

        # A non-finite d2 means only the atom itself remains; skipping it
        # avoids the (i, i) self-loop that `inf <= inf + delta` would add
        if len(remaining_distances) > 0 and np.isfinite(remaining_distances[0]):
            nnn_cutoff = remaining_distances[0] + delta
            for j in remaining_indices[remaining_distances <= nnn_cutoff]:
                edges.add((int(j), i))

    if edges:
        return torch.tensor(list(edges), dtype=torch.long).t().contiguous()

    # Graph without edges
    return torch.empty((2, 0), dtype=torch.long)

# Function to compute edge attributes based on positions and edge_index, considering PBCs
def compute_edge_attr(atoms, edge_index):
    """Return the length of every edge as a float tensor of shape (num_edges, 1).

    The difference of the two endpoints' scaled (fractional) positions is
    wrapped by subtracting its rounded value, converted to Cartesian
    coordinates with the cell, and its norm is the edge attribute.
    NOTE(review): the wrap is applied to all three axes, whatever ``atoms.pbc`` is.

    Parameters:
    - atoms: ASE Atoms object representing the structure.
    - edge_index: Tensor of shape (2, num_edges) with the edge endpoints.
    """
    row, col = edge_index
    lattice = atoms.get_cell()
    fractional = atoms.get_scaled_positions()

    # Fractional offset between the endpoints, wrapped to the nearest image
    offset = fractional[row.numpy()] - fractional[col.numpy()]
    wrapped_offset = offset - np.round(offset)

    lengths = np.linalg.norm(wrapped_offset @ lattice, axis=1)
    return torch.tensor(lengths, dtype=torch.float).unsqueeze(1)

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta):
    """Build the PyTorch Geometric graph of an atomic structure.

    Edges come from ``compute_edge_index`` and their attributes from
    ``compute_edge_attr``; nodes carry no features (``x`` has zero columns).

    Parameters:
    - atoms: ASE Atoms object representing the structure.
    - delta: Neighbour tolerance forwarded to ``compute_edge_index``.

    Returns:
    - ``Data`` object with ``x``, ``edge_index`` and ``edge_attr`` set.
    """
    connectivity = compute_edge_index(atoms, delta)
    edge_lengths = compute_edge_attr(atoms, connectivity)

    # One row per atom, zero feature columns
    featureless_nodes = torch.empty((len(atoms), 0), dtype=torch.float)

    return Data(x=featureless_nodes, edge_index=connectivity, edge_attr=edge_lengths)

# Function to plot the graph with consistent y-size
def plot_graph(atoms, graph, margin=0.1, tolerance=0.1, y_size=6, filename=None):
    """
    Plots the x-y projection of an atomic structure's graph with consistent y-direction size.

    Atoms are drawn as points coloured by z-layer, edges as gray segments and
    the parallelogram spanned by the first two cell vectors as a dashed outline.
    Each edge is drawn from its first endpoint along the wrapped offset to the
    second endpoint, so an edge to a periodic image leaves the cell instead of
    crossing it.

    Parameters:
    - atoms: ASE Atoms object representing the structure.
    - graph: PyTorch Geometric Data object for the graph.
    - margin: Margin added to the plot boundaries as a fraction of plot size.
    - tolerance: Tolerance for grouping atoms into layers by z-coordinate.
    - y_size: Desired size of the plot in the y-direction (in inches).
    - filename: If provided, saves the plot to this file instead of showing it.
    """
    positions = atoms.get_positions()
    cell = atoms.get_cell()

    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]

    # Group z positions into layers using the tolerance and index the layers
    z_grouped = np.round(z / tolerance) * tolerance
    z_unique, layer_indices = np.unique(z_grouped, return_inverse=True)

    # One colour per layer. pyplot.get_cmap is used because
    # matplotlib.cm.get_cmap is no longer available in recent Matplotlib.
    cmap = plt.get_cmap('viridis', len(z_unique))

    fig, ax = plt.subplots()

    # Cell outline in the a-b plane, shifted by half the c vector
    c_shift = 0.5 * cell[2]
    corner_positions = np.array([
        [0, 0, 0],
        cell[0],
        cell[0] + cell[1],
        cell[1],
        [0, 0, 0]
    ]) + c_shift

    # Plot limits follow the shifted outline, widened by the margin
    x_min, x_max = corner_positions[:, 0].min(), corner_positions[:, 0].max()
    y_min, y_max = corner_positions[:, 1].min(), corner_positions[:, 1].max()

    x_margin = (x_max - x_min) * margin
    y_margin = (y_max - y_min) * margin

    x_min -= x_margin
    x_max += x_margin
    y_min -= y_margin
    y_max += y_margin

    # Fix the figure height and scale the width with the data aspect ratio
    aspect_ratio = (x_max - x_min) / (y_max - y_min)
    fig.set_size_inches(y_size * aspect_ratio, y_size)

    ax.plot(corner_positions[:, 0], corner_positions[:, 1], 'k--', linewidth=1)
    ax.scatter(x, y, c=layer_indices, s=50, cmap=cmap, zorder=2)

    # Edge segments, computed for all edges at once. The scaled positions are
    # fetched a single time instead of twice per edge.
    src, dst = graph.edge_index.cpu().numpy()
    scaled_positions = atoms.get_scaled_positions()
    delta_scaled = scaled_positions[dst] - scaled_positions[src]
    delta_scaled -= np.round(delta_scaled)  # wrap the offset to the nearest image
    displacements = delta_scaled @ cell

    segment_starts = positions[src, :2]
    segment_ends = segment_starts + displacements[:, :2]
    segments = np.stack((segment_starts, segment_ends), axis=1)  # (num_edges, 2, 2)

    lc = LineCollection(segments, colors='gray', linewidths=1, zorder=1)
    ax.add_collection(lc)

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_aspect('equal')

    # Remove axes and let the drawing fill the figure
    ax.axis('off')
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

    if filename:
        fig.savefig(filename, dpi=300, bbox_inches='tight')
        plt.close(fig)  # close this figure explicitly so figures do not pile up
    else:
        plt.show()

# Function to process the core node of a cluster and collect data
def process_core_node_of_cluster(CLUSTER_NUM, hdb, cluster_labels, embedding_labels,
                                 structure_dir='supercells_flatband_rotated_shifted_aligned_3x3'):
    """
    Plots and describes the most representative material of one cluster.

    The core node is the cluster member with the highest value in
    ``hdb.probabilities_`` (the first one on ties). Its structure file
    ``{structure_dir}/supercell_2dm-{label}.xyz`` is read, its graph is plotted
    into ``plot_folder`` and one table row is returned.

    Parameters:
    - CLUSTER_NUM: Label of the cluster to process.
    - hdb: Object exposing ``probabilities_`` with one value per point.
    - cluster_labels: Array with the cluster label of every point.
    - embedding_labels: Array with the material label of every point.
    - structure_dir: Folder containing the supercell ``.xyz`` files.

    Returns:
    - dict with the table row, or None when the cluster has no points, the
      structure file is missing, or processing the structure fails.
    """
    cluster_indices = np.where(cluster_labels == CLUSTER_NUM)[0]

    if len(cluster_indices) == 0:
        print(f"No points found in cluster {CLUSTER_NUM}.")
        return None

    # Member with the highest probability = core node of the cluster
    cluster_probabilities = hdb.probabilities_[cluster_indices]
    max_prob_index_in_cluster = cluster_indices[np.argmax(cluster_probabilities)]

    material_label = int(embedding_labels[max_prob_index_in_cluster])
    file_path = f'{structure_dir}/supercell_2dm-{material_label}.xyz'

    try:
        atoms = read(file_path)
        atoms.pbc = [True, True, False]  # Periodic in the first two cell directions only

        graph = create_graph_from_structure(atoms, delta)

        plot_filename = f'graph_cluster_{CLUSTER_NUM}_material_{material_label}.png'
        plot_filepath = os.path.join(plot_folder, plot_filename)
        plot_graph(atoms, graph, margin=0.2, tolerance=0.1, y_size=6, filename=plot_filepath)
        print(f"Graph for material label {material_label} from cluster {CLUSTER_NUM} saved as {plot_filepath}.")

        num_atoms = len(atoms)

        # cellpar() = [a, b, c, alpha, beta, gamma]. The lengths must come from
        # here: the diagonal of the cell matrix equals the lattice vector
        # lengths only for cells aligned with the Cartesian axes (it is wrong
        # for, e.g., hexagonal cells).
        cell_parameters = atoms.cell.cellpar()
        cell_lengths = cell_parameters[:3]
        cell_angles = cell_parameters[3:6]  # degrees

        # In-plane lengths are divided by SUPERCELL_SIZE to report the unit cell
        formatted_cell_lengths = [
            f'{cell_lengths[0] / SUPERCELL_SIZE:.3f}',
            f'{cell_lengths[1] / SUPERCELL_SIZE:.3f}',
            f'{cell_lengths[2]:.3f}'
        ]
        formatted_cell_angles = [f'{angle:.3f}' for angle in cell_angles]

        num_atoms_in_unit_cell = num_atoms // (SUPERCELL_SIZE * SUPERCELL_SIZE)

        # NOTE(review): assumes the first num_atoms_in_unit_cell atoms of the
        # file make up one unit cell - confirm against the supercell generator
        unit_cell_atoms = atoms[:num_atoms_in_unit_cell]
        formula = unit_cell_atoms.get_chemical_formula()

        return {
            'Cluster': CLUSTER_NUM,
            'Material': material_label,
            'Chemical Formula': formula,
            'No. atoms per cell': num_atoms_in_unit_cell,
            'Unitcell Lengths': formatted_cell_lengths,
            'Cell Angles': formatted_cell_angles,
            'Plot': plot_filepath  # Ensure that 'Plot' is the last column
        }

    except FileNotFoundError:
        print(f"Structure file for material label {material_label} not found at {file_path}.")
        return None
    except Exception as e:
        # Best effort: a failing structure yields no row instead of stopping the caller
        print(f"An error occurred while processing the structure: {e}")
        return None

# Ensure that embedding_labels, cluster_labels, and hdb are available
# If not already loaded, load them using your existing function
# embedding_labels, cluster_labels = load_clusters_with_labels('clustered_labels.npy')

# Every cluster label that occurs, without the noise label -1
cluster_numbers = np.unique(cluster_labels[cluster_labels != -1])

# Rows of the table, one per cluster whose core node could be processed
data_list = []

for CLUSTER_NUM in cluster_numbers:
    result = process_core_node_of_cluster(CLUSTER_NUM, hdb, cluster_labels, embedding_labels)
    if result is None:
        continue
    data_list.append(result)

# Column order of the table; 'Plot' stays last
column_order = [
    'Cluster', 'Material', 'Chemical Formula', 'No. atoms per cell',
    'Unitcell Lengths', 'Cell Angles', 'Plot',
]

# Table of the collected rows in the column order above
df = pd.DataFrame(data_list, columns=column_order)

# Function to convert image paths to HTML tags for display
def path_to_image_html(path):
    """Return an HTML ``<img>`` tag for ``path`` with a fixed height of 200."""
    img_template = '<img src="{}" height="200">'  # same height for every image
    return img_template.format(path)

# Display the DataFrame as HTML: the 'Plot' column is rendered through
# path_to_image_html, escape=False keeps those <img> tags from being escaped,
# and index=False suppresses the index column
HTML(df.to_html(index=False, escape=False, formatters={'Plot': path_to_image_html}))
