# In [None]:
# Graph Generator: takes two folders of supercells (each with a different supercell size) and creates 4 graph lists (2 unperturbed and 2 perturbed)
# Representation: nearest + next-nearest neighbor edges with scalar (Cartesian distance) attributes, nearest-neighbor hyperedges with scalar (angle) attributes,
# node features given as vector Cartesian positions relative to the first atom of each unit cell, and connections across the (in-plane periodic) supercell boundary
# Note: this code expects two different folders, generated using step 2), that contain the same materials but with different supercell sizes
# Naming: three of the four graph lists are saved ('graph_list_unperturbed_1.pt', 'graph_list_set_1.pt', 'graph_list_set_2.pt')

import os
import re
import torch
import numpy as np
from ase.io import read
from torch_geometric.data import Data

# Supercell sizes N (each structure is an N x N in-plane supercell) used to create
# the two perturbed versions of each material's graph
SUPERCELL_SIZE_1 = 3
SUPERCELL_SIZE_2 = 4

# Paths to the supercell file folders; folder 1 must contain the SUPERCELL_SIZE_1
# structures and folder 2 the SUPERCELL_SIZE_2 structures
SUPERCELL_FOLDER_1 = 'Supercells/2DMatpedia Sublattices 3x3'
SUPERCELL_FOLDER_2 = 'Supercells/2DMatpedia Sublattices 4x4'

# Folder that the three saved lists of graphs are written to
OUTPUT_FOLDER = 'Graphs/2DMatpedia Sublattices'

# Tolerance delta (in Angstroms) added to the nearest and next-nearest neighbor
# distances when deciding which atoms count as neighbors
DELTA = 0.1

# Perturbation size in Angstroms: maximum random displacement applied to each
# Cartesian coordinate of each atom
PERTURBATION_SIZE = 0.05

# Ensure the output folder exists (executed when this cell/module runs)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# 1. Function to compute separate nearest and next-nearest neighbor edge connections
def compute_nn_nnn_edge_index(atoms, delta):
    """
    Build nearest-neighbor (NN) and NN + next-nearest-neighbor (NNN) edges.

    For each atom ``i``, the distances to every *other* atom are taken from
    ``atoms.get_all_distances(mic=True)``. With ``d1`` the smallest of them,
    every atom within ``d1 + delta`` is a nearest neighbor of ``i``. Among the
    atoms beyond that cutoff, with ``d2`` the smallest remaining distance,
    every atom within ``d2 + delta`` is a next-nearest neighbor of ``i``.

    A pair becomes an edge if either endpoint selects the other. Every
    undirected edge appears exactly once in each direction in the returned
    tensors (mutual selections are not double-counted), and an atom is never
    connected to itself.

    Args:
        atoms: ASE ``Atoms``-like object providing ``get_all_distances``.
        delta: Distance tolerance, in the same units as the atomic positions.

    Returns:
        (edge_index_nn, edge_index_full): ``torch.long`` tensors of shape
        ``[2, 2 * num_undirected_edges]`` (``[2, 0]`` when there are no edges).
        ``edge_index_nn`` holds only NN edges; ``edge_index_full`` holds the
        NN edges plus the NNN edges.
    """
    dist_matrix = np.asarray(atoms.get_all_distances(mic=True), dtype=float)
    num_atoms = dist_matrix.shape[0]
    all_indices = np.arange(num_atoms)

    # Undirected edges are stored canonically as (min, max) so that a pair
    # selected from both of its endpoints is only recorded once.
    nn_pairs = set()
    full_pairs = set()

    for i in range(num_atoms):
        # Exclude atom i explicitly (instead of an inf diagonal) so it can never
        # end up as its own nearest or next-nearest neighbor.
        others = all_indices != i
        if not others.any():
            continue
        candidate_indices = all_indices[others]
        candidate_distances = dist_matrix[i, others]

        order = np.argsort(candidate_distances)
        sorted_indices = candidate_indices[order]
        sorted_distances = candidate_distances[order]

        # Nearest neighbors: everything within d1 + delta
        nn_cutoff = sorted_distances[0] + delta
        is_nn = sorted_distances <= nn_cutoff
        for j in sorted_indices[is_nn]:
            pair = (min(i, int(j)), max(i, int(j)))
            nn_pairs.add(pair)
            full_pairs.add(pair)

        # Next-nearest neighbors: everything within d2 + delta, where d2 is the
        # smallest distance beyond the NN cutoff
        remaining_indices = sorted_indices[~is_nn]
        remaining_distances = sorted_distances[~is_nn]
        if remaining_distances.size > 0:
            nnn_cutoff = remaining_distances[0] + delta
            for j in remaining_indices[remaining_distances <= nnn_cutoff]:
                full_pairs.add((min(i, int(j)), max(i, int(j))))

    def _to_undirected(pairs):
        # Sorted for deterministic column order; each pair is emitted as
        # (a, b) followed later by its reverse (b, a).
        if not pairs:
            return torch.empty((2, 0), dtype=torch.long)
        forward = torch.tensor(sorted(pairs), dtype=torch.long).t().contiguous()
        return torch.cat([forward, forward.flip(0)], dim=1)

    edge_index_nn = _to_undirected(nn_pairs)
    edge_index_full = _to_undirected(full_pairs)

    return edge_index_nn, edge_index_full

# 2. Function to compute edge attributes as scalar distances between atoms (for a given edge_index)
def compute_edge_attr(atoms, edge_index):
    """
    Compute the length of every edge under the minimum-image convention.

    The fractional displacement of each edge is wrapped by subtracting its
    rounded value, but only along lattice directions that are periodic
    according to ``atoms.pbc``. Non-periodic directions (the out-of-plane axis
    is set to ``False`` throughout this file) are left unwrapped. This keeps
    the edge lengths consistent with a pbc-aware minimum-image distance such
    as the ``get_all_distances(mic=True)`` used when the neighbor lists are
    built.

    Args:
        atoms: ASE ``Atoms``-like object providing ``get_scaled_positions``,
            ``get_cell`` and ``pbc``.
        edge_index: ``[2, num_edges]`` long tensor of (source, target) indices.

    Returns:
        ``[num_edges, 1]`` float32 tensor of edge lengths in Cartesian units
        (``[0, 1]`` for an empty ``edge_index``).
    """
    row, col = edge_index
    cell = np.asarray(atoms.get_cell())
    scaled_positions = np.asarray(atoms.get_scaled_positions())
    periodic = np.asarray(atoms.pbc, dtype=bool)

    # Fractional displacement vectors, wrapped only along periodic axes
    delta_scaled = scaled_positions[row.cpu().numpy()] - scaled_positions[col.cpu().numpy()]
    delta_scaled -= np.where(periodic, np.round(delta_scaled), 0.0)

    displacement_vectors = delta_scaled @ cell
    edge_distances = np.linalg.norm(displacement_vectors, axis=1)
    edge_attr = torch.tensor(edge_distances, dtype=torch.float).unsqueeze(1)
    return edge_attr

# 3. Function to compute node features as relative cartesian positions within each unit cell
def compute_node_features(atoms, SUPERCELL_SIZE):
    """
    Compute per-atom positions relative to the first atom of its unit cell.

    Atoms are grouped into ``SUPERCELL_SIZE * SUPERCELL_SIZE`` consecutive,
    equally sized blocks, one per unit cell of the 2D supercell. The grouping
    relies on the atom order in the .xyz file being unit cell by unit cell.

    Args:
        atoms: ASE ``Atoms``-like object providing ``__len__`` and
            ``get_positions``.
        SUPERCELL_SIZE: N for an N x N in-plane supercell.

    Returns:
        ``[num_atoms, 3]`` float32 tensor of Cartesian offsets; the first atom
        of every unit cell maps to the zero vector.

    Raises:
        ValueError: If the structure has no atoms, or its atom count is not a
            multiple of the number of unit cells.
    """
    num_atoms = len(atoms)
    N = SUPERCELL_SIZE
    total_unit_cells = N * N  # For a 2D lattice

    if num_atoms == 0:
        raise ValueError("Cannot compute node features for a structure with no atoms.")

    # Compute the number of atoms per unit cell and check it is consistent
    atoms_per_unit_cell = num_atoms // total_unit_cells
    if atoms_per_unit_cell * total_unit_cells != num_atoms:
        raise ValueError(
            "Number of atoms per unit cell is not consistent with "
            "total atoms and supercell dimensions."
        )

    positions = np.asarray(atoms.get_positions(), dtype=float)

    # View positions as [unit_cell, atom_in_cell, xyz] and subtract each cell's
    # first atom (the [:, :1] slice keeps the axis for broadcasting).
    grouped = positions.reshape(total_unit_cells, atoms_per_unit_cell, -1)
    relative_positions = (grouped - grouped[:, :1, :]).reshape(positions.shape)

    node_features = torch.tensor(relative_positions, dtype=torch.float)
    return node_features

# 4. Function to compute hyperedges using only nearest-neighbor edge connections
def compute_hyperedges(edge_index_nn, num_nodes):
    """
    Enumerate angle triplets from an edge list (intended for NN edges only).

    For every centre node ``j``, each unordered pair of its distinct
    neighbors ``(i, k)`` yields one hyperedge ``[i, j, k]``. Neighbor lists
    are built from the source -> target direction of ``edge_index_nn`` and
    duplicate entries are collapsed before pairs are formed.

    Args:
        edge_index_nn: ``[2, num_edges]`` long tensor (row 0 sources, row 1
            targets).
        num_nodes: Number of nodes in the graph.

    Returns:
        ``[3, num_hyperedges]`` long tensor with columns ``[i, j, k]``
        (``[3, 0]`` when no node has at least two neighbors).
    """
    sources, targets = edge_index_nn.tolist()

    neighbours = [[] for _ in range(num_nodes)]
    for src, tgt in zip(sources, targets):
        neighbours[src].append(tgt)

    # Drop repeated neighbors (order follows set iteration order)
    neighbours = [list(set(nbrs)) for nbrs in neighbours]

    # Every pair (first, second) with first listed before second, per centre
    triplets = [
        [first, centre, second]
        for centre, nbrs in enumerate(neighbours)
        for pos, first in enumerate(nbrs)
        for second in nbrs[pos + 1:]
    ]

    if not triplets:
        return torch.empty((3, 0), dtype=torch.long)
    return torch.tensor(triplets, dtype=torch.long).t().contiguous()  # [3, num_hyperedges]

# 5. Function to compute hyperedge attributes as angle between two edges from a given hyperedge index
def compute_hyperedge_attr(atoms, hyperedge_index):
    """
    Compute the angle i-j-k (at centre j) for every hyperedge ``[i, j, k]``.

    The vectors j->i and j->k are built from fractional coordinates and
    wrapped with the minimum-image rule only along lattice directions that are
    periodic according to ``atoms.pbc`` (the out-of-plane axis is
    non-periodic in this file).

    Args:
        atoms: ASE ``Atoms``-like object providing ``get_scaled_positions``,
            ``get_cell`` and ``pbc``.
        hyperedge_index: ``[3, num_hyperedges]`` long tensor of ``[i, j, k]``.

    Returns:
        ``[num_hyperedges, 1]`` float32 tensor of angles in radians. An empty
        index yields shape ``[0, 1]``, matching the edge-attribute layout.
    """
    num_hyperedges = hyperedge_index.size(1)
    if num_hyperedges == 0:
        # torch.tensor([]) would give shape [0]; keep the trailing feature dim
        return torch.empty((0, 1), dtype=torch.float)

    cell = np.asarray(atoms.get_cell())
    scaled_positions = np.asarray(atoms.get_scaled_positions())
    periodic = np.asarray(atoms.pbc, dtype=bool)

    idx_i, idx_j, idx_k = hyperedge_index.cpu().numpy()

    def _to_cartesian(delta_scaled):
        # Minimum image along periodic axes only, then fractional -> Cartesian
        return (delta_scaled - np.where(periodic, np.round(delta_scaled), 0.0)) @ cell

    d_ji = _to_cartesian(scaled_positions[idx_i] - scaled_positions[idx_j])
    d_jk = _to_cartesian(scaled_positions[idx_k] - scaled_positions[idx_j])

    # 1e-8 guards against division by zero for coincident atoms
    dots = np.sum(d_ji * d_jk, axis=1)
    norms = np.linalg.norm(d_ji, axis=1) * np.linalg.norm(d_jk, axis=1)
    cos_theta = dots / (norms + 1e-8)
    angles = np.arccos(np.clip(cos_theta, -1.0, 1.0))  # Angle in radians

    hyperedge_attr = torch.tensor(angles, dtype=torch.float).unsqueeze(1)  # [num_hyperedges, 1]
    return hyperedge_attr

# 6. Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta, SUPERCELL_SIZE):
    """
    Assemble one PyTorch Geometric graph for a supercell structure.

    Graph fields:
      - x: per-atom Cartesian offset from the first atom of its unit cell
      - edge_index / edge_attr: nearest + next-nearest neighbor edges and
        their lengths
      - hyperedge_index / hyperedge_attr: triplets built from the
        nearest-neighbor edges only, and their angles
      - supercell_size: the SUPERCELL_SIZE argument
    """
    nn_edges, all_edges = compute_nn_nnn_edge_index(atoms, delta)
    edge_lengths = compute_edge_attr(atoms, all_edges)
    features = compute_node_features(atoms, SUPERCELL_SIZE)

    # Triplets come from the NN edges only, not from the NN + NNN set
    triplets = compute_hyperedges(nn_edges, len(atoms))
    angles = compute_hyperedge_attr(atoms, triplets)

    graph = Data(
        x=features,
        edge_index=all_edges,
        edge_attr=edge_lengths,
        hyperedge_index=triplets,
        hyperedge_attr=angles,
    )
    graph.supercell_size = SUPERCELL_SIZE
    return graph

# 7. Apply perturbations to a list of graphs as a random movement of each atom in each Cartesian axis
def perturb_graphs(graph_list, atoms_list, perturbation_size, delta, rng=None):
    """
    Create perturbed copies of graphs by randomly displacing every atom.

    Each Cartesian coordinate of each atom is shifted by an independent draw
    from U(-perturbation_size, perturbation_size). Edge lengths, node features
    and hyperedge angles are then recomputed from the displaced positions.
    The connectivity (``edge_index`` and ``hyperedge_index``) of the original
    graph is reused unchanged; to recompute connectivity after perturbation,
    call ``create_graph_from_structure`` on the displaced structure instead.

    Args:
        graph_list: Graphs built by ``create_graph_from_structure``; each must
            carry a ``supercell_size`` attribute.
        atoms_list: Structures matching ``graph_list`` one-to-one.
        perturbation_size: Maximum displacement per coordinate. Values ``<= 0``
            leave positions unchanged.
        delta: Not used, because connectivity is not recomputed; kept for
            interface compatibility.
        rng: Optional ``numpy.random.Generator`` for reproducible noise. When
            ``None`` (default), the global ``np.random`` state is used.

    Returns:
        List of perturbed graph clones. The input graphs and atoms are not
        modified.

    Raises:
        ValueError: If ``graph_list`` and ``atoms_list`` differ in length
            (``zip`` would otherwise silently drop the excess entries).
    """
    if len(graph_list) != len(atoms_list):
        raise ValueError(
            f"graph_list ({len(graph_list)}) and atoms_list ({len(atoms_list)}) "
            "must have the same length."
        )

    sampler = np.random if rng is None else rng

    perturbed_graph_list = []
    for graph, atoms in zip(graph_list, atoms_list):
        # Clone the graph to avoid modifying the original
        perturbed_graph = graph.clone()

        positions = atoms.get_positions().copy()
        if perturbation_size > 0:
            positions += sampler.uniform(-perturbation_size, perturbation_size, positions.shape)

        # Work on a copy of the structure so the caller's atoms are untouched
        perturbed_atoms = atoms.copy()
        perturbed_atoms.set_positions(positions)
        perturbed_atoms.pbc = [True, True, False]  # Periodic in-plane only

        # Connectivity is reused; only the geometric attributes are recomputed
        perturbed_graph.edge_attr = compute_edge_attr(perturbed_atoms, graph.edge_index)
        perturbed_graph.x = compute_node_features(perturbed_atoms, graph.supercell_size)
        perturbed_graph.hyperedge_attr = compute_hyperedge_attr(perturbed_atoms, graph.hyperedge_index)

        perturbed_graph_list.append(perturbed_graph)

    return perturbed_graph_list


# 8. Read graphs and atoms from a folder
def read_graphs_from_folder(folder_path, delta, SUPERCELL_SIZE):
    """
    Build one graph per ``.xyz`` file found directly in ``folder_path``.

    Files are visited in natural order: runs of digits compare numerically
    and other text compares case-insensitively. Each structure is read with
    ASE, marked periodic in-plane only, and converted with
    ``create_graph_from_structure``. The graph's ``label`` is
    ``"<digits>_<word>"`` when the file stem ends in that pattern (e.g.
    ``..._1234_Ta``); otherwise it is the bare stem.

    Returns:
        (graph_list, atoms_list) in the same file order.
    """
    def natural_sort_key(name):
        pieces = re.split(r'(\d+)', name)
        return [int(piece) if piece.isdigit() else piece.lower() for piece in pieces]

    xyz_files = [name for name in os.listdir(folder_path) if name.endswith('.xyz')]
    xyz_files.sort(key=natural_sort_key)

    graphs, structures = [], []
    for filename in xyz_files:
        structure = read(os.path.join(folder_path, filename))
        structure.pbc = [True, True, False]  # Periodic in-plane only

        stem = os.path.splitext(filename)[0]
        found = re.match(r'^.*?(\d+)_(\w+)$', stem)
        label = f"{found.group(1)}_{found.group(2)}" if found else stem

        # Graph with nearest + next-nearest neighbor edges
        graph = create_graph_from_structure(structure, delta, SUPERCELL_SIZE)
        graph.label = label

        graphs.append(graph)
        structures.append(structure)

        print(f"Graph for {filename} created with label {label}.")

    return graphs, structures

# 9. Generate the graphs, perturb them and keep them in 4 separate lists
#    (2 unperturbed lists and their corresponding perturbed lists).
#    Progress messages are derived from SUPERCELL_SIZE_1/2 so they stay
#    correct if those constants change.

# Read graphs and atoms from the first supercell folder
graph_list_unperturbed_1, atoms_list_unperturbed_1 = read_graphs_from_folder(
    SUPERCELL_FOLDER_1, DELTA, SUPERCELL_SIZE_1
)
print(f"Unperturbed graph list for {SUPERCELL_SIZE_1}x{SUPERCELL_SIZE_1} supercells created.")

# Read graphs and atoms from the second supercell folder
graph_list_unperturbed_2, atoms_list_unperturbed_2 = read_graphs_from_folder(
    SUPERCELL_FOLDER_2, DELTA, SUPERCELL_SIZE_2
)
print(f"Unperturbed graph list for {SUPERCELL_SIZE_2}x{SUPERCELL_SIZE_2} supercells created.")

# Perturb the first set of unperturbed graphs
graph_list_set_1 = perturb_graphs(
    graph_list_unperturbed_1, atoms_list_unperturbed_1, PERTURBATION_SIZE, DELTA
)

# Perturb the second set of unperturbed graphs
graph_list_set_2 = perturb_graphs(
    graph_list_unperturbed_2, atoms_list_unperturbed_2, PERTURBATION_SIZE, DELTA
)

print(
    f"Two perturbed graph lists created from {SUPERCELL_SIZE_1}x{SUPERCELL_SIZE_1} "
    f"and {SUPERCELL_SIZE_2}x{SUPERCELL_SIZE_2} supercells respectively."
)

# 10. Save 3 of the graph lists. Only the first unperturbed list is saved; it can
#     be used to generate the final embeddings. os.path.join builds the paths so
#     they work on every OS (a hard-coded backslash only separates folders on
#     Windows and is an invalid escape sequence in a normal string literal).

# Save unperturbed graph list 1
torch.save(graph_list_unperturbed_1, os.path.join(OUTPUT_FOLDER, "graph_list_unperturbed_1.pt"))

# Save the first perturbed graph list (SUPERCELL_SIZE_1 supercells)
torch.save(graph_list_set_1, os.path.join(OUTPUT_FOLDER, "graph_list_set_1.pt"))

# Save the second perturbed graph list (SUPERCELL_SIZE_2 supercells)
torch.save(graph_list_set_2, os.path.join(OUTPUT_FOLDER, "graph_list_set_2.pt"))