In [None]:
#Graph generator: variable supercell size, next nearest neighbor and PBC, constant (all-ones) node features, xyz edge attributes

import os
import re
import torch
import numpy as np
from ase.io import read
from torch_geometric.data import Data

# Tolerance added to each atom's nearest (and next-nearest) neighbor distance
# when compute_edge_index decides which atoms belong to that neighbor shell.
delta = 0.1  # Adjust as needed; presumably in the same length units as the .xyz coordinates — TODO confirm

# Half-width of the uniform random displacement that perturb_graphs applies
# independently to every Cartesian coordinate.
perturbation_size = 0.05  # Adjust as needed

# Function to compute edge_index based on positions and delta, considering PBCs
def compute_edge_index(atoms, delta):
    """Build directed edges from each atom's first and second neighbor shells.

    For every atom ``i`` the minimum-image distances to all other atoms are
    sorted.  Atoms within ``d1 + delta`` of ``i`` (``d1`` being the smallest
    distance) form the nearest-neighbor shell; among the atoms left over, those
    within ``d2 + delta`` (``d2`` being the smallest remaining distance) form
    the next-nearest-neighbor shell.  An edge ``(j, i)`` is stored for every
    shell member ``j``.  Shells are chosen per atom, so the resulting edge list
    is not guaranteed to be symmetric.

    Args:
        atoms: Structure exposing ``get_positions()`` and
            ``get_all_distances(mic=True)`` (an ASE ``Atoms`` object in this file).
        delta: Tolerance added to the shell distances ``d1`` and ``d2``.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)`` with sources in row 0
        and targets in row 1; shape ``(2, 0)`` when no edge exists.
    """
    num_atoms = len(atoms.get_positions())

    # Compute distance matrix considering PBCs
    dist_matrix = atoms.get_all_distances(mic=True)
    edges = set()  # Use a set to avoid duplicate edges

    for i in range(num_atoms):
        # Mark the self-distance as infinite so it sorts last
        dist_matrix[i, i] = np.inf

        distances = dist_matrix[i]
        sorted_indices = np.argsort(distances)
        sorted_distances = distances[sorted_indices]

        # Keep only real neighbors.  Without this filter the infinite
        # self-distance satisfies ``inf <= inf + delta`` whenever a shell has
        # no other candidate (a single atom, or every other atom already in
        # the first shell), which produced spurious self-loops (i, i).
        finite = np.isfinite(sorted_distances)
        sorted_indices = sorted_indices[finite]
        sorted_distances = sorted_distances[finite]
        if sorted_distances.size == 0:
            continue

        # Nearest-neighbor shell: everything within d1 + delta
        nn_cutoff = sorted_distances[0] + delta
        in_first_shell = sorted_distances <= nn_cutoff
        for j in sorted_indices[in_first_shell]:
            edges.add((int(j), i))

        # Next-nearest-neighbor shell: same rule applied to what is left
        remaining_indices = sorted_indices[~in_first_shell]
        remaining_distances = sorted_distances[~in_first_shell]
        if remaining_distances.size > 0:
            nnn_cutoff = remaining_distances[0] + delta
            for j in remaining_indices[remaining_distances <= nnn_cutoff]:
                edges.add((int(j), i))

    if not edges:
        # Handle graphs with no edges
        return torch.empty((2, 0), dtype=torch.long)

    return torch.tensor(list(edges), dtype=torch.long).t().contiguous()

# Function to compute edge attributes as displacement vectors considering PBCs
def compute_edge_attr(atoms, edge_index):
    """Return the minimum-image displacement vector of every edge.

    For an edge whose ``edge_index`` column is ``(row, col)`` the attribute is
    ``position[row] - position[col]``, evaluated in fractional coordinates,
    shifted to the nearest periodic image and converted back to Cartesian.

    Args:
        atoms: Structure exposing ``pbc``, ``get_cell()`` and
            ``get_scaled_positions()`` (an ASE ``Atoms`` object in this file).
        edge_index: Integer tensor with two rows, one column per edge.

    Returns:
        ``torch.float`` tensor holding one Cartesian displacement per edge.
    """
    row, col = edge_index
    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()

    # Fractional displacement of every edge
    delta_scaled = scaled_positions[row.numpy()] - scaled_positions[col.numpy()]

    # Apply the minimum image convention along periodic axes only.  Wrapping a
    # non-periodic axis (z for pbc = [True, True, False]) would fold large
    # out-of-plane separations back into the cell and disagree with the
    # distances returned by get_all_distances(mic=True).
    periodic = np.asarray(atoms.pbc, dtype=bool)
    delta_scaled[:, periodic] -= np.round(delta_scaled[:, periodic])

    displacement_vectors = delta_scaled @ cell  # Convert back to Cartesian coordinates

    return torch.tensor(displacement_vectors, dtype=torch.float)

# Function to compute node features
def compute_node_features(atoms):
    """Return a constant node feature (a single 1.0) for every atom.

    Args:
        atoms: Any sized collection of atoms; only ``len(atoms)`` is used.

    Returns:
        ``torch.float`` tensor of ones with shape ``(len(atoms), 1)``.
    """
    feature_shape = (len(atoms), 1)
    return torch.ones(feature_shape, dtype=torch.float)

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta):
    """Assemble a PyTorch Geometric ``Data`` graph for one structure.

    Args:
        atoms: Structure passed on to the edge and node feature helpers.
        delta: Neighbor-shell tolerance forwarded to ``compute_edge_index``.

    Returns:
        ``Data`` object with ``x``, ``edge_index`` and ``edge_attr`` set.
    """
    # Connectivity comes first because the edge attributes are derived from it
    connectivity = compute_edge_index(atoms, delta)
    displacements = compute_edge_attr(atoms, connectivity)

    return Data(
        x=compute_node_features(atoms),
        edge_index=connectivity,
        edge_attr=displacements,
    )

# Function to apply perturbations to a graph list
def perturb_graphs(graph_list, atoms_list, perturbation_size, delta):
    """Create randomly displaced copies of the given graphs.

    Each graph is cloned and paired with its source structure.  Every Cartesian
    coordinate of that structure is shifted by an independent draw from the
    uniform distribution on ``[-perturbation_size, perturbation_size)``; the
    shift is skipped when ``perturbation_size`` is not positive.  The clone
    keeps the connectivity of the input graph; only ``edge_attr`` and ``x`` are
    recomputed from the shifted structure.

    Args:
        graph_list: Graphs to perturb.
        atoms_list: Structures matching ``graph_list`` element by element.
        perturbation_size: Half-width of the uniform displacement.
        delta: Not used by this function.

    Returns:
        List of perturbed graph clones, in input order.
    """
    results = []
    for source_graph, source_atoms in zip(graph_list, atoms_list):
        shifted_positions = source_atoms.get_positions().copy()
        if perturbation_size > 0:
            shifted_positions = shifted_positions + np.random.uniform(
                -perturbation_size, perturbation_size, shifted_positions.shape
            )

        # Work on a copy of the structure so the caller's atoms stay untouched
        shifted_atoms = source_atoms.copy()
        shifted_atoms.set_positions(shifted_positions)
        shifted_atoms.pbc = [True, True, False]  # Ensure PBCs are enabled

        # Connectivity is inherited from the unperturbed graph; only the
        # geometry-dependent tensors are refreshed on the clone.
        new_graph = source_graph.clone()
        new_graph.edge_attr = compute_edge_attr(shifted_atoms, source_graph.edge_index)
        new_graph.x = compute_node_features(shifted_atoms)
        results.append(new_graph)

    return results

# Function to read graphs and atoms from a folder
def read_graphs_from_folder(folder_path, delta):
    """Load every ``.xyz`` file of a folder and convert it into a graph.

    Files are processed in sorted (lexicographic) name order.  Each structure
    gets periodic boundary conditions along x and y, and the resulting graph
    is tagged with ``label``: the last run of digits found in the file name,
    or ``None`` when the name contains no digits.

    Args:
        folder_path: Directory containing the ``.xyz`` files.
        delta: Neighbor-shell tolerance forwarded to the graph builder.

    Returns:
        Tuple ``(graph_list, atoms_list)`` of equal length and matching order.
    """
    xyz_files = [entry for entry in os.listdir(folder_path) if entry.endswith('.xyz')]
    xyz_files.sort()

    graph_list, atoms_list = [], []
    for filename in xyz_files:
        atoms = read(os.path.join(folder_path, filename))
        atoms.pbc = [True, True, False]  # Ensure PBCs are enabled

        # The label is the last number appearing in the file name, if any
        numbers = re.findall(r'\d+', filename)
        label = int(numbers[-1]) if numbers else None

        graph = create_graph_from_structure(atoms, delta)
        graph.label = label

        graph_list.append(graph)
        atoms_list.append(atoms)
        print(f"Graph for {filename} created and added to the list with label {label}.")

    return graph_list, atoms_list

# Paths to the supercell files folders (relative to the current working directory)
supercell_folder_3x3 = 'supercells_flatband_3x3'
supercell_folder_4x4 = 'supercells_flatband_4x4'

# Read graphs and atoms from the 3x3 folder; the atoms are kept alongside the
# graphs because perturb_graphs needs the original positions
graph_list_unperturbed_3x3, atoms_list_unperturbed_3x3 = read_graphs_from_folder(supercell_folder_3x3, delta)
print("Unperturbed graph list for 3x3 supercells created with periodic boundary conditions accounted for.")

# Read graphs and atoms from the 4x4 folder
graph_list_unperturbed_4x4, atoms_list_unperturbed_4x4 = read_graphs_from_folder(supercell_folder_4x4, delta)
print("Unperturbed graph list for 4x4 supercells created with periodic boundary conditions accounted for.")

# Perturb the unperturbed 3x3 graphs to create the first perturbed graph list
# (random displacements; results differ between runs unless NumPy's global RNG is seeded)
graph_list_set_1 = perturb_graphs(graph_list_unperturbed_3x3, atoms_list_unperturbed_3x3, perturbation_size, delta)

# Perturb the unperturbed 4x4 graphs to create the second perturbed graph list
graph_list_set_2 = perturb_graphs(graph_list_unperturbed_4x4, atoms_list_unperturbed_4x4, perturbation_size, delta)

print("Two perturbed graph lists created from 3x3 and 4x4 supercells respectively.")


In [None]:
#Graph generator: variable supercell size, next nearest neighbor and PBC, xyz node features, xyz edge attributes

import os
import re
import torch
import numpy as np
from ase.io import read
from torch_geometric.data import Data

# Tolerance added to each atom's nearest (and next-nearest) neighbor distance
# when compute_edge_index decides which atoms belong to that neighbor shell.
delta = 0.1  # Adjust as needed; presumably in the same length units as the .xyz coordinates — TODO confirm

# Half-width of the uniform random displacement that perturb_graphs applies
# independently to every Cartesian coordinate.
perturbation_size = 0.05  # Adjust as needed

# Function to compute edge_index based on positions and delta, considering PBCs
def compute_edge_index(atoms, delta):
    """Build directed edges from each atom's first and second neighbor shells.

    For every atom ``i`` the minimum-image distances to all other atoms are
    sorted.  Atoms within ``d1 + delta`` of ``i`` (``d1`` being the smallest
    distance) form the nearest-neighbor shell; among the atoms left over, those
    within ``d2 + delta`` (``d2`` being the smallest remaining distance) form
    the next-nearest-neighbor shell.  An edge ``(j, i)`` is stored for every
    shell member ``j``.  Shells are chosen per atom, so the resulting edge list
    is not guaranteed to be symmetric.

    Args:
        atoms: Structure exposing ``get_positions()`` and
            ``get_all_distances(mic=True)`` (an ASE ``Atoms`` object in this file).
        delta: Tolerance added to the shell distances ``d1`` and ``d2``.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)`` with sources in row 0
        and targets in row 1; shape ``(2, 0)`` when no edge exists.
    """
    num_atoms = len(atoms.get_positions())

    # Compute distance matrix considering PBCs
    dist_matrix = atoms.get_all_distances(mic=True)
    edges = set()  # Use a set to avoid duplicate edges

    for i in range(num_atoms):
        # Mark the self-distance as infinite so it sorts last
        dist_matrix[i, i] = np.inf

        distances = dist_matrix[i]
        sorted_indices = np.argsort(distances)
        sorted_distances = distances[sorted_indices]

        # Keep only real neighbors.  Without this filter the infinite
        # self-distance satisfies ``inf <= inf + delta`` whenever a shell has
        # no other candidate (a single atom, or every other atom already in
        # the first shell), which produced spurious self-loops (i, i).
        finite = np.isfinite(sorted_distances)
        sorted_indices = sorted_indices[finite]
        sorted_distances = sorted_distances[finite]
        if sorted_distances.size == 0:
            continue

        # Nearest-neighbor shell: everything within d1 + delta
        nn_cutoff = sorted_distances[0] + delta
        in_first_shell = sorted_distances <= nn_cutoff
        for j in sorted_indices[in_first_shell]:
            edges.add((int(j), i))

        # Next-nearest-neighbor shell: same rule applied to what is left
        remaining_indices = sorted_indices[~in_first_shell]
        remaining_distances = sorted_distances[~in_first_shell]
        if remaining_distances.size > 0:
            nnn_cutoff = remaining_distances[0] + delta
            for j in remaining_indices[remaining_distances <= nnn_cutoff]:
                edges.add((int(j), i))

    if not edges:
        # Handle graphs with no edges
        return torch.empty((2, 0), dtype=torch.long)

    return torch.tensor(list(edges), dtype=torch.long).t().contiguous()

# Function to compute edge attributes as displacement vectors considering PBCs
def compute_edge_attr(atoms, edge_index):
    """Return the minimum-image displacement vector of every edge.

    For an edge whose ``edge_index`` column is ``(row, col)`` the attribute is
    ``position[row] - position[col]``, evaluated in fractional coordinates,
    shifted to the nearest periodic image and converted back to Cartesian.

    Args:
        atoms: Structure exposing ``pbc``, ``get_cell()`` and
            ``get_scaled_positions()`` (an ASE ``Atoms`` object in this file).
        edge_index: Integer tensor with two rows, one column per edge.

    Returns:
        ``torch.float`` tensor holding one Cartesian displacement per edge.
    """
    row, col = edge_index
    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()

    # Fractional displacement of every edge
    delta_scaled = scaled_positions[row.numpy()] - scaled_positions[col.numpy()]

    # Apply the minimum image convention along periodic axes only.  Wrapping a
    # non-periodic axis (z for pbc = [True, True, False]) would fold large
    # out-of-plane separations back into the cell and disagree with the
    # distances returned by get_all_distances(mic=True).
    periodic = np.asarray(atoms.pbc, dtype=bool)
    delta_scaled[:, periodic] -= np.round(delta_scaled[:, periodic])

    displacement_vectors = delta_scaled @ cell  # Convert back to Cartesian coordinates

    return torch.tensor(displacement_vectors, dtype=torch.float)

# Function to compute node features (relative positions within unit cells)
def compute_node_features(atoms, SUPERCELL_SIZE):
    """Express every atom's position relative to the first atom of its unit cell.

    Atoms are grouped by their order: each consecutive run of
    ``len(atoms) // SUPERCELL_SIZE**2`` atoms is treated as one unit cell of
    the 2D ``SUPERCELL_SIZE x SUPERCELL_SIZE`` supercell.

    Args:
        atoms: Structure supporting ``len()`` and ``get_positions()``.
        SUPERCELL_SIZE: Number of unit cells along each in-plane direction.

    Returns:
        ``torch.float`` tensor with one row per atom; the first atom of every
        unit cell maps to the zero vector.

    Raises:
        ValueError: If the atom count is not a multiple of ``SUPERCELL_SIZE**2``.
    """
    atom_count = len(atoms)
    cell_count = SUPERCELL_SIZE * SUPERCELL_SIZE  # For a 2D lattice

    atoms_per_cell, leftover = divmod(atom_count, cell_count)
    if leftover != 0:
        raise ValueError("Number of atoms per unit cell is not consistent with total atoms and supercell dimensions.")

    coords = atoms.get_positions()
    offsets = np.zeros_like(coords)

    # Walk through the unit cells in file order; ``first`` is the reference atom
    for first in range(0, atom_count, atoms_per_cell):
        members = slice(first, first + atoms_per_cell)
        offsets[members] = coords[members] - coords[first]

    return torch.tensor(offsets, dtype=torch.float)

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta, SUPERCELL_SIZE):
    """Assemble a PyTorch Geometric ``Data`` graph for one structure.

    Args:
        atoms: Structure passed on to the edge and node feature helpers.
        delta: Neighbor-shell tolerance forwarded to ``compute_edge_index``.
        SUPERCELL_SIZE: Supercell size forwarded to ``compute_node_features``
            and stored on the graph as ``supercell_size``.

    Returns:
        ``Data`` object with ``x``, ``edge_index``, ``edge_attr`` and
        ``supercell_size`` set.
    """
    # Connectivity comes first because the edge attributes are derived from it
    connectivity = compute_edge_index(atoms, delta)
    displacements = compute_edge_attr(atoms, connectivity)

    graph = Data(
        x=compute_node_features(atoms, SUPERCELL_SIZE),
        edge_index=connectivity,
        edge_attr=displacements,
    )
    graph.supercell_size = SUPERCELL_SIZE  # Store supercell size in graph
    return graph

# Function to apply perturbations to a graph list
def perturb_graphs(graph_list, atoms_list, perturbation_size, delta, SUPERCELL_SIZE):
    """Create randomly displaced copies of the given graphs.

    Each graph is cloned and paired with its source structure.  Every Cartesian
    coordinate of that structure is shifted by an independent draw from the
    uniform distribution on ``[-perturbation_size, perturbation_size)``; the
    shift is skipped when ``perturbation_size`` is not positive.  The clone
    keeps the connectivity of the input graph; only ``edge_attr`` and ``x`` are
    recomputed from the shifted structure.

    Args:
        graph_list: Graphs to perturb.
        atoms_list: Structures matching ``graph_list`` element by element.
        perturbation_size: Half-width of the uniform displacement.
        delta: Not used by this function.
        SUPERCELL_SIZE: Supercell size forwarded to ``compute_node_features``.

    Returns:
        List of perturbed graph clones, in input order.
    """
    results = []
    for source_graph, source_atoms in zip(graph_list, atoms_list):
        shifted_positions = source_atoms.get_positions().copy()
        if perturbation_size > 0:
            shifted_positions = shifted_positions + np.random.uniform(
                -perturbation_size, perturbation_size, shifted_positions.shape
            )

        # Work on a copy of the structure so the caller's atoms stay untouched
        shifted_atoms = source_atoms.copy()
        shifted_atoms.set_positions(shifted_positions)
        shifted_atoms.pbc = [True, True, False]  # Ensure PBCs are enabled

        # Connectivity is inherited from the unperturbed graph; only the
        # geometry-dependent tensors are refreshed on the clone.
        new_graph = source_graph.clone()
        new_graph.edge_attr = compute_edge_attr(shifted_atoms, source_graph.edge_index)
        new_graph.x = compute_node_features(shifted_atoms, SUPERCELL_SIZE)
        results.append(new_graph)

    return results

# Function to read graphs and atoms from a folder
def read_graphs_from_folder(folder_path, delta, SUPERCELL_SIZE):
    """Load every ``.xyz`` file of a folder and convert it into a graph.

    Files are processed in sorted (lexicographic) name order.  Each structure
    gets periodic boundary conditions along x and y, and the resulting graph
    is tagged with ``label``: the last run of digits found in the file name,
    or ``None`` when the name contains no digits.

    Args:
        folder_path: Directory containing the ``.xyz`` files.
        delta: Neighbor-shell tolerance forwarded to the graph builder.
        SUPERCELL_SIZE: Supercell size forwarded to the graph builder.

    Returns:
        Tuple ``(graph_list, atoms_list)`` of equal length and matching order.
    """
    xyz_files = [entry for entry in os.listdir(folder_path) if entry.endswith('.xyz')]
    xyz_files.sort()

    graph_list, atoms_list = [], []
    for filename in xyz_files:
        atoms = read(os.path.join(folder_path, filename))
        atoms.pbc = [True, True, False]  # Ensure PBCs are enabled

        # The label is the last number appearing in the file name, if any
        numbers = re.findall(r'\d+', filename)
        label = int(numbers[-1]) if numbers else None

        graph = create_graph_from_structure(atoms, delta, SUPERCELL_SIZE)
        graph.label = label

        graph_list.append(graph)
        atoms_list.append(atoms)
        print(f"Graph for {filename} created and added to the list with label {label}.")

    return graph_list, atoms_list

# Paths to the supercell files folders (relative to the current working directory)
supercell_folder_1 = 'supercells_flatband_rotated_shifted_3x3'
supercell_folder_2 = 'supercells_flatband_rotated_shifted_4x4'

# Linear supercell sizes matching the two folders above.  They were used below
# without being defined in this cell, so running it on a fresh kernel raised
# NameError.
SUPERCELL_SIZE_1 = 3
SUPERCELL_SIZE_2 = 4

# Read graphs and atoms from the 3x3 folder; the atoms are kept alongside the
# graphs because perturb_graphs needs the original positions
graph_list_unperturbed_1, atoms_list_unperturbed_1 = read_graphs_from_folder(
    supercell_folder_1, delta, SUPERCELL_SIZE_1
)
print("Unperturbed graph list for 3x3 supercells created with periodic boundary conditions accounted for.")

# Read graphs and atoms from the 4x4 folder
graph_list_unperturbed_2, atoms_list_unperturbed_2 = read_graphs_from_folder(
    supercell_folder_2, delta, SUPERCELL_SIZE_2
)
print("Unperturbed graph list for 4x4 supercells created with periodic boundary conditions accounted for.")

# Perturb the unperturbed 3x3 graphs to create the first perturbed graph list
graph_list_set_1 = perturb_graphs(
    graph_list_unperturbed_1, atoms_list_unperturbed_1, perturbation_size, delta, SUPERCELL_SIZE_1
)

# Perturb the unperturbed 4x4 graphs to create the second perturbed graph list
graph_list_set_2 = perturb_graphs(
    graph_list_unperturbed_2, atoms_list_unperturbed_2, perturbation_size, delta, SUPERCELL_SIZE_2
)

print("Two perturbed graph lists created from 3x3 and 4x4 supercells respectively.")

In [None]:
#Graph generator: variable supercell size, connected graphs, next nearest neighbor and PBC, xyz node features, xyz edge attributes

import os
import re
import torch
import numpy as np
from ase.io import read
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
import networkx as nx

delta = 0.1 # Tolerance added to each atom's nearest / next-nearest neighbor distance when compute_edge_index selects its neighbor shells

perturbation_size = 0.05 # Half-width of the uniform random displacement perturb_graphs applies to every Cartesian coordinate

# Function to compute edge_index based on positions and delta, considering PBCs
def compute_edge_index(atoms, delta):
    """Build directed edges from each atom's first and second neighbor shells.

    For every atom ``i`` the minimum-image distances to all other atoms are
    sorted.  Atoms within ``d1 + delta`` of ``i`` (``d1`` being the smallest
    distance) form the nearest-neighbor shell; among the atoms left over, those
    within ``d2 + delta`` (``d2`` being the smallest remaining distance) form
    the next-nearest-neighbor shell.  An edge ``(j, i)`` is stored for every
    shell member ``j``.  Shells are chosen per atom, so the resulting edge list
    is not guaranteed to be symmetric.

    Args:
        atoms: Structure exposing ``get_positions()`` and
            ``get_all_distances(mic=True)`` (an ASE ``Atoms`` object in this file).
        delta: Tolerance added to the shell distances ``d1`` and ``d2``.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)`` with sources in row 0
        and targets in row 1; shape ``(2, 0)`` when no edge exists.
    """
    num_atoms = len(atoms.get_positions())

    # Compute distance matrix considering PBCs
    dist_matrix = atoms.get_all_distances(mic=True)
    edges = set()  # Use a set to avoid duplicate edges

    for i in range(num_atoms):
        # Mark the self-distance as infinite so it sorts last
        dist_matrix[i, i] = np.inf

        distances = dist_matrix[i]
        sorted_indices = np.argsort(distances)
        sorted_distances = distances[sorted_indices]

        # Keep only real neighbors.  Without this filter the infinite
        # self-distance satisfies ``inf <= inf + delta`` whenever a shell has
        # no other candidate (a single atom, or every other atom already in
        # the first shell), which produced spurious self-loops (i, i).
        finite = np.isfinite(sorted_distances)
        sorted_indices = sorted_indices[finite]
        sorted_distances = sorted_distances[finite]
        if sorted_distances.size == 0:
            continue

        # Nearest-neighbor shell: everything within d1 + delta
        nn_cutoff = sorted_distances[0] + delta
        in_first_shell = sorted_distances <= nn_cutoff
        for j in sorted_indices[in_first_shell]:
            edges.add((int(j), i))

        # Next-nearest-neighbor shell: same rule applied to what is left
        remaining_indices = sorted_indices[~in_first_shell]
        remaining_distances = sorted_distances[~in_first_shell]
        if remaining_distances.size > 0:
            nnn_cutoff = remaining_distances[0] + delta
            for j in remaining_indices[remaining_distances <= nnn_cutoff]:
                edges.add((int(j), i))

    if not edges:
        # Handle graphs with no edges
        return torch.empty((2, 0), dtype=torch.long)

    return torch.tensor(list(edges), dtype=torch.long).t().contiguous()

# Function to compute edge attributes as displacement vectors considering PBCs
def compute_edge_attr(atoms, edge_index):
    """Return the minimum-image displacement vector of every edge.

    For an edge whose ``edge_index`` column is ``(row, col)`` the attribute is
    ``position[row] - position[col]``, evaluated in fractional coordinates,
    shifted to the nearest periodic image and converted back to Cartesian.

    Args:
        atoms: Structure exposing ``pbc``, ``get_cell()`` and
            ``get_scaled_positions()`` (an ASE ``Atoms`` object in this file).
        edge_index: Integer tensor with two rows, one column per edge.

    Returns:
        ``torch.float`` tensor holding one Cartesian displacement per edge.
    """
    row, col = edge_index
    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()

    # Fractional displacement of every edge
    delta_scaled = scaled_positions[row.numpy()] - scaled_positions[col.numpy()]

    # Apply the minimum image convention along periodic axes only.  Wrapping a
    # non-periodic axis (z for pbc = [True, True, False]) would fold large
    # out-of-plane separations back into the cell and disagree with the
    # distances returned by get_all_distances(mic=True).
    periodic = np.asarray(atoms.pbc, dtype=bool)
    delta_scaled[:, periodic] -= np.round(delta_scaled[:, periodic])

    displacement_vectors = delta_scaled @ cell  # Convert back to Cartesian coordinates

    return torch.tensor(displacement_vectors, dtype=torch.float)

# Function to compute node features (relative positions within unit cells)
def compute_node_features(atoms, SUPERCELL_SIZE):
    """Express every atom's position relative to the first atom of its unit cell.

    Atoms are grouped by their order: each consecutive run of
    ``len(atoms) // SUPERCELL_SIZE**2`` atoms is treated as one unit cell of
    the 2D ``SUPERCELL_SIZE x SUPERCELL_SIZE`` supercell.

    Args:
        atoms: Structure supporting ``len()`` and ``get_positions()``.
        SUPERCELL_SIZE: Number of unit cells along each in-plane direction.

    Returns:
        ``torch.float`` tensor with one row per atom; the first atom of every
        unit cell maps to the zero vector.

    Raises:
        ValueError: If the atom count is not a multiple of ``SUPERCELL_SIZE**2``.
    """
    atom_count = len(atoms)
    cell_count = SUPERCELL_SIZE * SUPERCELL_SIZE  # For a 2D lattice

    atoms_per_cell, leftover = divmod(atom_count, cell_count)
    if leftover != 0:
        raise ValueError("Number of atoms per unit cell is not consistent with total atoms and supercell dimensions.")

    coords = atoms.get_positions()
    offsets = np.zeros_like(coords)

    # Walk through the unit cells in file order; ``first`` is the reference atom
    for first in range(0, atom_count, atoms_per_cell):
        members = slice(first, first + atoms_per_cell)
        offsets[members] = coords[members] - coords[first]

    return torch.tensor(offsets, dtype=torch.float)

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta, SUPERCELL_SIZE, label):
    """Assemble a labelled PyTorch Geometric ``Data`` graph for one structure.

    Args:
        atoms: Structure passed on to the edge and node feature helpers.
        delta: Neighbor-shell tolerance forwarded to ``compute_edge_index``.
        SUPERCELL_SIZE: Supercell size forwarded to ``compute_node_features``
            and stored on the graph as ``supercell_size``.
        label: Value stored on the graph as ``label``.

    Returns:
        ``Data`` object with ``x``, ``edge_index``, ``edge_attr``,
        ``supercell_size`` and ``label`` set.
    """
    # Connectivity comes first because the edge attributes are derived from it
    connectivity = compute_edge_index(atoms, delta)
    displacements = compute_edge_attr(atoms, connectivity)

    graph = Data(
        x=compute_node_features(atoms, SUPERCELL_SIZE),
        edge_index=connectivity,
        edge_attr=displacements,
    )

    # Bookkeeping attributes carried along with the graph
    graph.supercell_size = SUPERCELL_SIZE
    graph.label = label
    return graph

# Function to apply perturbations to a graph list
def perturb_graphs(graph_list, atoms_list, perturbation_size, delta, SUPERCELL_SIZE):
    """Create randomly displaced copies of the given graphs.

    Each graph is cloned and paired with its source structure.  Every Cartesian
    coordinate of that structure is shifted by an independent draw from the
    uniform distribution on ``[-perturbation_size, perturbation_size)``; the
    shift is skipped when ``perturbation_size`` is not positive.  The clone
    keeps the connectivity of the input graph; only ``edge_attr`` and ``x`` are
    recomputed from the shifted structure.

    Args:
        graph_list: Graphs to perturb.
        atoms_list: Structures matching ``graph_list`` element by element.
        perturbation_size: Half-width of the uniform displacement.
        delta: Not used by this function.
        SUPERCELL_SIZE: Supercell size forwarded to ``compute_node_features``.

    Returns:
        List of perturbed graph clones, in input order.
    """
    results = []
    for source_graph, source_atoms in zip(graph_list, atoms_list):
        shifted_positions = source_atoms.get_positions().copy()
        if perturbation_size > 0:
            shifted_positions = shifted_positions + np.random.uniform(
                -perturbation_size, perturbation_size, shifted_positions.shape
            )

        # Work on a copy of the structure so the caller's atoms stay untouched
        shifted_atoms = source_atoms.copy()
        shifted_atoms.set_positions(shifted_positions)
        shifted_atoms.pbc = [True, True, False]  # Ensure PBCs are enabled

        # Connectivity is inherited from the unperturbed graph; only the
        # geometry-dependent tensors are refreshed on the clone.
        new_graph = source_graph.clone()
        new_graph.edge_attr = compute_edge_attr(shifted_atoms, source_graph.edge_index)
        new_graph.x = compute_node_features(shifted_atoms, SUPERCELL_SIZE)
        results.append(new_graph)

    return results

# Function to read graphs and atoms from a folder
def read_graphs_from_folder(folder_path, delta, SUPERCELL_SIZE):
    """Load every ``.xyz`` file of a folder and convert it into a labelled graph.

    Files are processed in sorted (lexicographic) name order.  Each structure
    gets periodic boundary conditions along x and y.  The graph label is the
    last run of digits found in the file name, or ``None`` when the name
    contains no digits.

    Args:
        folder_path: Directory containing the ``.xyz`` files.
        delta: Neighbor-shell tolerance forwarded to the graph builder.
        SUPERCELL_SIZE: Supercell size forwarded to the graph builder.

    Returns:
        Tuple ``(graph_list, atoms_list)`` of equal length and matching order.
    """
    xyz_files = [entry for entry in os.listdir(folder_path) if entry.endswith('.xyz')]
    xyz_files.sort()

    graph_list, atoms_list = [], []
    for filename in xyz_files:
        atoms = read(os.path.join(folder_path, filename))
        atoms.pbc = [True, True, False]  # Ensure PBCs are enabled

        # The label is the last number appearing in the file name, if any
        numbers = re.findall(r'\d+', filename)
        label = int(numbers[-1]) if numbers else None

        graph_list.append(create_graph_from_structure(atoms, delta, SUPERCELL_SIZE, label))
        atoms_list.append(atoms)
        print(f"Graph for {filename} created and added to the list with label {label}.")

    return graph_list, atoms_list

# Function to check connectivity and separate graphs along with their corresponding atoms
def separate_connected_disconnected(graph_list, atoms_list):
    """Split graphs (and their structures) by whether the graph is connected.

    Each graph is converted to an undirected NetworkX graph and tested with
    ``nx.is_connected``.

    Args:
        graph_list: PyTorch Geometric graphs to classify.
        atoms_list: Structures matching ``graph_list`` element by element.

    Returns:
        Tuple ``(connected_graphs, connected_atoms, disconnected_graphs,
        disconnected_atoms)``; input order is preserved within each group.
    """
    connected_graphs, connected_atoms = [], []
    disconnected_graphs, disconnected_atoms = [], []

    for graph, atoms in zip(graph_list, atoms_list):
        undirected = to_networkx(graph, to_undirected=True)

        # Pick the pair of output lists this graph belongs to
        if nx.is_connected(undirected):
            graph_bucket, atoms_bucket = connected_graphs, connected_atoms
        else:
            graph_bucket, atoms_bucket = disconnected_graphs, disconnected_atoms

        graph_bucket.append(graph)
        atoms_bucket.append(atoms)

    return connected_graphs, connected_atoms, disconnected_graphs, disconnected_atoms

# Paths to the supercell files folders (relative to the current working directory)
supercell_folder_1 = 'Supercells/supercells_flatband_rotated_shifted_aligned_3x3'
supercell_folder_2 = 'Supercells/supercells_flatband_rotated_shifted_aligned_4x4'

# Linear supercell sizes matching the two folders above.  They were used below
# without being defined in this cell, so running it on a fresh kernel raised
# NameError.
SUPERCELL_SIZE_1 = 3
SUPERCELL_SIZE_2 = 4

# Read graphs and atoms from the 3x3 folder
graph_list_unperturbed_1, atoms_list_unperturbed_1 = read_graphs_from_folder(
    supercell_folder_1, delta, SUPERCELL_SIZE_1
)
print("Unperturbed graph list for 3x3 supercells created with periodic boundary conditions accounted for.")

# Separate connected and disconnected graphs for 3x3 supercells
connected_graphs_unperturbed_1, connected_atoms_unperturbed_1, disconnected_graphs_unperturbed_1, disconnected_atoms_unperturbed_1 = separate_connected_disconnected(
    graph_list_unperturbed_1, atoms_list_unperturbed_1
)
print(f"Total graphs in 3x3: {len(graph_list_unperturbed_1)}")
print(f"Connected graphs in 3x3: {len(connected_graphs_unperturbed_1)}")
print(f"Disconnected graphs in 3x3: {len(disconnected_graphs_unperturbed_1)}")

# Read graphs and atoms from the 4x4 folder
graph_list_unperturbed_2, atoms_list_unperturbed_2 = read_graphs_from_folder(
    supercell_folder_2, delta, SUPERCELL_SIZE_2
)
print("Unperturbed graph list for 4x4 supercells created with periodic boundary conditions accounted for.")

# Separate connected and disconnected graphs for 4x4 supercells
connected_graphs_unperturbed_2, connected_atoms_unperturbed_2, disconnected_graphs_unperturbed_2, disconnected_atoms_unperturbed_2 = separate_connected_disconnected(
    graph_list_unperturbed_2, atoms_list_unperturbed_2
)
print(f"Total graphs in 4x4: {len(graph_list_unperturbed_2)}")
print(f"Connected graphs in 4x4: {len(connected_graphs_unperturbed_2)}")
print(f"Disconnected graphs in 4x4: {len(disconnected_graphs_unperturbed_2)}")

# Extract labels of connected graphs
labels_connected_1 = set(graph.label for graph in connected_graphs_unperturbed_1)
labels_connected_2 = set(graph.label for graph in connected_graphs_unperturbed_2)

# Find common labels, i.e. structures whose graph is connected at both supercell sizes
common_labels = labels_connected_1.intersection(labels_connected_2)
print(f"Number of common connected graphs: {len(common_labels)}")

# Filter connected graphs and atoms to only keep those with labels in common_labels
filtered_connected_graphs_unperturbed_1 = []
filtered_connected_atoms_unperturbed_1 = []
for graph, atoms in zip(connected_graphs_unperturbed_1, connected_atoms_unperturbed_1):
    if graph.label in common_labels:
        filtered_connected_graphs_unperturbed_1.append(graph)
        filtered_connected_atoms_unperturbed_1.append(atoms)

filtered_connected_graphs_unperturbed_2 = []
filtered_connected_atoms_unperturbed_2 = []
for graph, atoms in zip(connected_graphs_unperturbed_2, connected_atoms_unperturbed_2):
    if graph.label in common_labels:
        filtered_connected_graphs_unperturbed_2.append(graph)
        filtered_connected_atoms_unperturbed_2.append(atoms)

print(f"Filtered connected graphs in 3x3: {len(filtered_connected_graphs_unperturbed_1)}")
print(f"Filtered connected graphs in 4x4: {len(filtered_connected_graphs_unperturbed_2)}")

# Perturb the filtered connected 3x3 graphs to create the first perturbed graph list
graph_list_set_1 = perturb_graphs(
    filtered_connected_graphs_unperturbed_1, filtered_connected_atoms_unperturbed_1, perturbation_size, delta, SUPERCELL_SIZE_1
)

# Perturb the filtered connected 4x4 graphs to create the second perturbed graph list
graph_list_set_2 = perturb_graphs(
    filtered_connected_graphs_unperturbed_2, filtered_connected_atoms_unperturbed_2, perturbation_size, delta, SUPERCELL_SIZE_2
)

# Expose the unperturbed 3x3 graphs under a generic name
graph_list_unperturbed = filtered_connected_graphs_unperturbed_1
print("Two perturbed graph lists created from common connected graphs in 3x3 and 4x4 supercells respectively.")

In [None]:
#Graph generator: hyperedge, variable supercell size, next nearest neighbor and PBC, xyz node features, xyz edge attributes

import os
import re
import torch
import numpy as np
from ase.io import read
from torch_geometric.data import Data

# Tolerance added to each atom's nearest (and next-nearest) neighbor distance
# when compute_edge_index decides which atoms belong to that neighbor shell.
delta = 0.1  # Adjust as needed; presumably in the same length units as the .xyz coordinates — TODO confirm

# Half-width of a random perturbation; not referenced by the definitions
# visible in this part of the cell.
perturbation_size = 0.05  # Adjust as needed

# Function to compute edge_index based on positions and delta, considering PBCs
def compute_edge_index(atoms, delta):
    """Build a symmetric, duplicate-free edge list from first and second neighbor shells.

    For every atom ``i`` the minimum-image distances to all other atoms are
    sorted.  Atoms within ``d1 + delta`` of ``i`` (``d1`` being the smallest
    distance) form the nearest-neighbor shell; among the atoms left over, those
    within ``d2 + delta`` (``d2`` being the smallest remaining distance) form
    the next-nearest-neighbor shell.  Every shell member ``j`` contributes both
    ``(i, j)`` and ``(j, i)``, so the graph is undirected.

    Args:
        atoms: Structure exposing ``get_positions()`` and
            ``get_all_distances(mic=True)`` (an ASE ``Atoms`` object in this file).
        delta: Tolerance added to the shell distances ``d1`` and ``d2``.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)`` whose columns are the
        unique directed edges in sorted order; shape ``(2, 0)`` when no edge
        exists.
    """
    num_atoms = len(atoms.get_positions())

    # Compute distance matrix considering PBCs
    dist_matrix = atoms.get_all_distances(mic=True)
    edge_index_set = set()  # Use a set to avoid duplicate edges

    for i in range(num_atoms):
        # Mark the self-distance as infinite so it sorts last
        dist_matrix[i, i] = np.inf

        distances = dist_matrix[i]
        sorted_indices = np.argsort(distances)
        sorted_distances = distances[sorted_indices]

        # Keep only real neighbors.  Without this filter the infinite
        # self-distance satisfies ``inf <= inf + delta`` whenever a shell has
        # no other candidate, which produced spurious self-loops (i, i).
        finite = np.isfinite(sorted_distances)
        sorted_indices = sorted_indices[finite]
        sorted_distances = sorted_distances[finite]
        if sorted_distances.size == 0:
            continue

        # Nearest-neighbor shell: everything within d1 + delta
        nn_cutoff = sorted_distances[0] + delta
        in_first_shell = sorted_distances <= nn_cutoff
        shell_members = sorted_indices[in_first_shell]

        # Next-nearest-neighbor shell: same rule applied to what is left
        remaining_indices = sorted_indices[~in_first_shell]
        remaining_distances = sorted_distances[~in_first_shell]
        if remaining_distances.size > 0:
            nnn_cutoff = remaining_distances[0] + delta
            second_shell = remaining_indices[remaining_distances <= nnn_cutoff]
            shell_members = np.concatenate([shell_members, second_shell])

        # Store both orientations in the set.  Concatenating a flipped copy of
        # the edge tensor afterwards duplicated every edge whose reverse had
        # already been found from the other atom's side.
        for j in shell_members:
            j = int(j)
            edge_index_set.add((i, j))
            edge_index_set.add((j, i))

    if not edge_index_set:
        # Handle graphs with no edges
        return torch.empty((2, 0), dtype=torch.long)

    # Sorting makes the edge order reproducible
    return torch.tensor(sorted(edge_index_set), dtype=torch.long).t().contiguous()

# Function to compute edge attributes as displacement vectors considering PBCs
def compute_edge_attr(atoms, edge_index):
    """Return the minimum-image displacement vector of every edge.

    For an edge whose ``edge_index`` column is ``(row, col)`` the attribute is
    ``position[row] - position[col]``, evaluated in fractional coordinates,
    shifted to the nearest periodic image and converted back to Cartesian.

    Args:
        atoms: Structure exposing ``pbc``, ``get_cell()`` and
            ``get_scaled_positions()`` (an ASE ``Atoms`` object in this file).
        edge_index: Integer tensor with two rows, one column per edge.

    Returns:
        ``torch.float`` tensor holding one Cartesian displacement per edge.
    """
    row, col = edge_index
    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()

    # Fractional displacement of every edge
    delta_scaled = scaled_positions[row.numpy()] - scaled_positions[col.numpy()]

    # Apply the minimum image convention along periodic axes only.  Wrapping a
    # non-periodic axis would fold large separations along that axis back into
    # the cell and disagree with the distances returned by
    # get_all_distances(mic=True).
    periodic = np.asarray(atoms.pbc, dtype=bool)
    delta_scaled[:, periodic] -= np.round(delta_scaled[:, periodic])

    displacement_vectors = delta_scaled @ cell  # Convert back to Cartesian coordinates

    return torch.tensor(displacement_vectors, dtype=torch.float)

# Function to compute node features (relative positions within unit cells)
def compute_node_features(atoms, SUPERCELL_SIZE):
    """Return each atom's position relative to the first atom of its unit cell.

    Atoms are assigned to unit cells purely by their order: consecutive runs
    of ``len(atoms) // SUPERCELL_SIZE**2`` atoms form one cell. The offsets
    are plain Cartesian differences (no periodic wrapping is applied).

    Args:
        atoms: Structure supporting ``len()`` and ``get_positions()``.
        SUPERCELL_SIZE: N for an N x N (2D) supercell.

    Returns:
        Float tensor with the same shape as ``atoms.get_positions()``.

    Raises:
        ValueError: If the atom count is not a multiple of ``N * N``.
    """
    atom_count = len(atoms)
    cell_count = SUPERCELL_SIZE * SUPERCELL_SIZE  # 2D lattice: N x N cells
    per_cell = atom_count // cell_count

    if per_cell * cell_count != atom_count:
        raise ValueError("Number of atoms per unit cell is not consistent with total atoms and supercell dimensions.")

    coords = atoms.get_positions()
    offsets = np.zeros_like(coords)

    # Walk the atoms one unit cell at a time; the first atom of each run is
    # the reference and therefore maps to the zero vector.
    for first in range(0, atom_count, per_cell):
        members = slice(first, first + per_cell)
        offsets[members] = coords[members] - coords[first]

    return torch.tensor(offsets, dtype=torch.float)

# Function to compute hyperedges based on existing edges
def compute_hyperedges(edge_index, num_nodes):
    """Enumerate ``(i, j, k)`` triplets where ``i`` and ``k`` are neighbours of ``j``.

    A node ``t`` counts as a neighbour of ``s`` when ``edge_index`` contains
    the column ``(s, t)``. Every unordered pair of distinct neighbours of a
    centre node yields exactly one triplet.

    Args:
        edge_index: Integer tensor whose first two rows are source and target.
        num_nodes: Number of nodes in the graph.

    Returns:
        Long tensor of shape ``[3, num_hyperedges]`` (rows: i, centre j, k).
    """
    sources = edge_index[0].tolist()
    targets = edge_index[1].tolist()

    outgoing = [[] for _ in range(num_nodes)]
    for src, tgt in zip(sources, targets):
        outgoing[src].append(tgt)

    triplets = []
    for centre, raw_neighbours in enumerate(outgoing):
        # Repeated edges must not produce repeated (or degenerate) pairs.
        unique = list(set(raw_neighbours))
        for pos, first in enumerate(unique):
            for second in unique[pos + 1:]:
                triplets.append([first, centre, second])

    if not triplets:
        return torch.empty((3, 0), dtype=torch.long)
    return torch.tensor(triplets, dtype=torch.long).t().contiguous()

# Function to compute hyperedge attributes (angles between edges)
def compute_hyperedge_attr(atoms, hyperedge_index):
    """Return the angle at the centre node of every ``(i, j, k)`` hyperedge.

    The angle is measured between the minimum-image vectors ``j -> i`` and
    ``j -> k``. Lattice-vector shifts are applied only along the axes flagged
    periodic in ``atoms.pbc``.

    Args:
        atoms: Structure exposing ``get_cell()``, ``get_scaled_positions()``
            and ``pbc`` (an ASE ``Atoms`` object in this notebook).
        hyperedge_index: Integer tensor of shape ``[3, num_hyperedges]`` with
            rows ``i``, ``j`` (centre) and ``k``.

    Returns:
        Float tensor of shape ``[num_hyperedges, 1]`` holding angles in
        radians; the shape is ``[0, 1]`` when there are no hyperedges.
    """
    i_idx, j_idx, k_idx = hyperedge_index.detach().cpu().numpy()

    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()
    periodic = np.asarray(atoms.pbc, dtype=bool)

    def _from_centre(end_idx):
        # Minimum-image Cartesian vectors pointing from node j to end_idx.
        frac = scaled_positions[end_idx] - scaled_positions[j_idx]
        frac[:, periodic] -= np.round(frac[:, periodic])
        return frac @ cell

    d_ji = _from_centre(i_idx)
    d_jk = _from_centre(k_idx)

    # The 1e-8 keeps the division finite if an arm has zero length; clipping
    # guards arccos against round-off just outside [-1, 1].
    norms = np.linalg.norm(d_ji, axis=1) * np.linalg.norm(d_jk, axis=1)
    cos_theta = (d_ji * d_jk).sum(axis=1) / (norms + 1e-8)
    angles = np.arccos(np.clip(cos_theta, -1.0, 1.0))

    return torch.tensor(angles, dtype=torch.float).unsqueeze(1)

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta, SUPERCELL_SIZE):
    """Assemble a PyTorch Geometric ``Data`` object for one structure.

    Args:
        atoms: Atomic structure passed through to the ``compute_*`` helpers.
        delta: Shell tolerance forwarded to ``compute_edge_index``.
        SUPERCELL_SIZE: N for an N x N supercell; also stored on the graph.

    Returns:
        ``Data`` with ``x``, ``edge_index``, ``edge_attr``, ``hyperedge_index``
        and ``hyperedge_attr`` set, plus a ``supercell_size`` attribute.
    """
    # Pairwise connectivity and its per-edge attributes.
    connectivity = compute_edge_index(atoms, delta)
    pair_attr = compute_edge_attr(atoms, connectivity)

    # Per-atom features.
    features = compute_node_features(atoms, SUPERCELL_SIZE)

    # Three-body terms are derived from the pairwise connectivity.
    triplets = compute_hyperedges(connectivity, len(atoms))
    triplet_attr = compute_hyperedge_attr(atoms, triplets)

    graph = Data(
        x=features,
        edge_index=connectivity,
        edge_attr=pair_attr,
        hyperedge_index=triplets,
        hyperedge_attr=triplet_attr,
    )
    # Kept on the graph so perturb_graphs can rebuild node features later.
    graph.supercell_size = SUPERCELL_SIZE
    return graph

# Function to apply perturbations to a graph list
def perturb_graphs(graph_list, atoms_list, perturbation_size, delta):
    """Return noisy copies of the graphs, keeping the original connectivity.

    Each atom coordinate is displaced by an independent draw from
    ``uniform(-perturbation_size, perturbation_size)`` (global NumPy RNG).
    Edge attributes, node features and hyperedge attributes are recomputed
    from the displaced positions; ``edge_index`` and ``hyperedge_index`` are
    reused unchanged from the input graph.

    Args:
        graph_list: Graphs produced by ``create_graph_from_structure``.
        atoms_list: Structures matching ``graph_list`` element by element.
        perturbation_size: Half-width of the uniform noise; values ``<= 0``
            leave the positions untouched.
        delta: Unused here; accepted for signature compatibility.

    Returns:
        List of cloned graphs with recomputed attributes.
    """
    def _perturb_one(graph, atoms):
        noisy_graph = graph.clone()  # never touch the caller's graph

        coords = atoms.get_positions().copy()
        if perturbation_size > 0:
            coords += np.random.uniform(-perturbation_size, perturbation_size, coords.shape)

        noisy_atoms = atoms.copy()
        noisy_atoms.set_positions(coords)
        noisy_atoms.pbc = [True, True, False]  # periodic in-plane only

        # Connectivity is deliberately inherited from the unperturbed graph.
        noisy_graph.edge_attr = compute_edge_attr(noisy_atoms, graph.edge_index)
        noisy_graph.x = compute_node_features(noisy_atoms, graph.supercell_size)
        noisy_graph.hyperedge_attr = compute_hyperedge_attr(noisy_atoms, graph.hyperedge_index)
        return noisy_graph

    return [_perturb_one(graph, atoms) for graph, atoms in zip(graph_list, atoms_list)]

# Function to read graphs and atoms from a folder
def read_graphs_from_folder(folder_path, delta, SUPERCELL_SIZE):
    """Build one graph per ``.xyz`` file found directly in ``folder_path``.

    Files are processed in sorted name order. The last run of digits in the
    file name becomes the graph's ``label`` (``None`` if the name has no
    digits). Each structure is marked periodic along the first two axes.

    Args:
        folder_path: Directory containing the ``.xyz`` files.
        delta: Tolerance forwarded to ``create_graph_from_structure``.
        SUPERCELL_SIZE: N for an N x N supercell, forwarded likewise.

    Returns:
        Tuple ``(graph_list, atoms_list)`` with matching order.
    """
    graphs, structures = [], []

    for filename in sorted(f for f in os.listdir(folder_path) if f.endswith('.xyz')):
        structure = read(os.path.join(folder_path, filename))
        structure.pbc = [True, True, False]  # periodic in-plane only

        # Label = last number embedded in the file name, when there is one.
        numbers = re.findall(r'\d+', filename)
        label = int(numbers[-1]) if numbers else None

        graph = create_graph_from_structure(structure, delta, SUPERCELL_SIZE)
        graph.label = label

        graphs.append(graph)
        structures.append(structure)

        print(f"Graph for {filename} created and added to the list with label {label}.")

    return graphs, structures

# Driver: build unperturbed graphs for two supercell sizes, then derive one
# perturbed copy of each list. The names assigned below are the outputs of
# this cell.

# Folders holding the .xyz supercell files (paths are relative to the working directory)
supercell_folder_1 = 'Supercells/supercells_flatband_rotated_shifted_aligned_3x3'
supercell_folder_2 = 'Supercells/supercells_flatband_rotated_shifted_aligned_4x4'

# N for the N x N supercell stored in each folder. compute_node_features
# raises ValueError when a file's atom count is not a multiple of N * N.
SUPERCELL_SIZE_1 = 3  # must match the contents of supercell_folder_1
SUPERCELL_SIZE_2 = 4  # must match the contents of supercell_folder_2

# Read graphs and atoms from the 3x3 folder (one graph per .xyz file, sorted by name)
graph_list_unperturbed_1, atoms_list_unperturbed_1 = read_graphs_from_folder(
    supercell_folder_1, delta, SUPERCELL_SIZE_1
)
print("Unperturbed graph list for 3x3 supercells created with periodic boundary conditions accounted for.")

# Read graphs and atoms from the 4x4 folder
graph_list_unperturbed_2, atoms_list_unperturbed_2 = read_graphs_from_folder(
    supercell_folder_2, delta, SUPERCELL_SIZE_2
)
print("Unperturbed graph list for 4x4 supercells created with periodic boundary conditions accounted for.")

# Perturb the unperturbed 3x3 graphs to create the first perturbed graph list.
# perturb_graphs draws from the global NumPy RNG, so results vary between
# runs unless the caller seeds np.random beforehand.
graph_list_set_1 = perturb_graphs(
    graph_list_unperturbed_1, atoms_list_unperturbed_1, perturbation_size, delta
)

# Perturb the unperturbed 4x4 graphs to create the second perturbed graph list
graph_list_set_2 = perturb_graphs(
    graph_list_unperturbed_2, atoms_list_unperturbed_2, perturbation_size, delta
)

print("Two perturbed graph lists created from 3x3 and 4x4 supercells respectively.")

In [None]:
# Graph generator: hyperedges, variable supercell size, nearest and next-nearest neighbors with PBC, xyz node features, scalar distance edge attributes

import os
import re
import torch
import numpy as np
from ase.io import read
from torch_geometric.data import Data

# Shell tolerance used by compute_edge_index: neighbours within
# (shell distance + delta) of a node's nearest / next-nearest distance are
# treated as belonging to that shell. Same length unit as the atomic positions.
delta = 0.1  # Adjust as needed

# Half-width of the uniform random displacement that perturb_graphs adds to
# every Cartesian coordinate.
perturbation_size = 0.05  # Adjust as needed

# Function to compute edge_index based on positions and delta, considering PBCs
def compute_edge_index(atoms, delta):
    """Connect every atom to its nearest and next-nearest neighbour shells.

    For each atom ``i`` the first shell contains all atoms within
    ``d1 + delta`` of it, where ``d1`` is its smallest minimum-image distance.
    The second shell contains all remaining atoms within ``d2 + delta``, where
    ``d2`` is the smallest distance outside the first shell. The resulting
    relation is symmetrised, so ``(i, j)`` is present whenever ``(j, i)`` is.

    Each directed edge appears exactly once and an atom is never connected to
    itself, even when no atom lies beyond the first shell.

    Args:
        atoms: Structure exposing ``get_all_distances(mic=True)`` (an ASE
            ``Atoms`` object in this notebook).
        delta: Tolerance added to each shell distance.

    Returns:
        Long tensor of shape ``[2, num_edges]`` with columns in sorted order.
    """
    dist_matrix = np.array(atoms.get_all_distances(mic=True), dtype=float)
    num_atoms = dist_matrix.shape[0]
    # Self-distances are pushed to infinity so they never define a shell.
    np.fill_diagonal(dist_matrix, np.inf)

    edges = set()
    for i in range(num_atoms):
        distances = dist_matrix[i]
        candidates = np.isfinite(distances)  # everything except the atom itself
        if not candidates.any():
            continue  # isolated atom: nothing to connect

        # First shell: within delta of the closest neighbour.
        d1 = distances[candidates].min()
        shell = candidates & (distances <= d1 + delta)

        # Second shell: within delta of the closest atom outside the first.
        beyond = candidates & ~shell
        if beyond.any():
            d2 = distances[beyond].min()
            shell = shell | (beyond & (distances <= d2 + delta))

        for j in np.flatnonzero(shell):
            j = int(j)
            # Store both directions in the set; concatenating a flipped copy
            # afterwards would duplicate every mutual neighbour pair.
            edges.add((i, j))
            edges.add((j, i))

    if not edges:
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(sorted(edges), dtype=torch.long).t().contiguous()

# Function to compute edge attributes based on positions and edge_index, considering PBCs
def compute_edge_attr(atoms, edge_index):
    """Return the minimum-image length of every edge.

    The displacement ``position[row] - position[col]`` is shifted by whole
    lattice vectors along the axes flagged periodic in ``atoms.pbc`` only,
    matching the ``get_all_distances(mic=True)`` call that selected the edges,
    and its Euclidean norm is returned.

    Args:
        atoms: Structure exposing ``get_cell()``, ``get_scaled_positions()``
            and ``pbc`` (an ASE ``Atoms`` object in this notebook).
        edge_index: Integer tensor of shape ``[2, num_edges]``.

    Returns:
        Float tensor of shape ``[num_edges, 1]``.
    """
    row, col = edge_index
    # detach().cpu() keeps the NumPy conversion valid for tensors that live
    # on another device or carry autograd history.
    row = row.detach().cpu().numpy()
    col = col.detach().cpu().numpy()

    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()
    periodic = np.asarray(atoms.pbc, dtype=bool)

    delta_scaled = scaled_positions[row] - scaled_positions[col]
    # Wrap only the periodic axes; a non-periodic axis has no images to fold to.
    delta_scaled[:, periodic] -= np.round(delta_scaled[:, periodic])
    displacement_vectors = delta_scaled @ cell  # back to Cartesian coordinates

    edge_distances = np.linalg.norm(displacement_vectors, axis=1)
    return torch.tensor(edge_distances, dtype=torch.float).unsqueeze(1)

# Function to compute node features (relative positions within unit cells)
def compute_node_features(atoms, SUPERCELL_SIZE):
    """Express every atom's position relative to its unit cell's first atom.

    Unit cells are inferred from atom order alone: each consecutive block of
    ``len(atoms) // SUPERCELL_SIZE**2`` atoms is one cell. Offsets are raw
    Cartesian differences; no periodic wrapping is performed.

    Args:
        atoms: Structure supporting ``len()`` and ``get_positions()``.
        SUPERCELL_SIZE: N for an N x N (2D) supercell.

    Returns:
        Float tensor shaped like ``atoms.get_positions()``.

    Raises:
        ValueError: If the atom count is not a multiple of ``N * N``.
    """
    n_atoms = len(atoms)
    n_cells = SUPERCELL_SIZE * SUPERCELL_SIZE  # N x N cells in the plane
    block = n_atoms // n_cells

    if block * n_cells != n_atoms:
        raise ValueError("Number of atoms per unit cell is not consistent with total atoms and supercell dimensions.")

    xyz = atoms.get_positions()
    relative = np.zeros_like(xyz)

    # One pass per unit cell; the block's first atom is the local origin.
    for origin in range(0, n_atoms, block):
        in_cell = slice(origin, origin + block)
        relative[in_cell] = xyz[in_cell] - xyz[origin]

    return torch.tensor(relative, dtype=torch.float)

# Function to compute hyperedges based on existing edges
def compute_hyperedges(edge_index, num_nodes):
    """List every ``(i, j, k)`` triplet with ``i`` and ``k`` adjacent to ``j``.

    Adjacency follows the direction stored in ``edge_index``: ``t`` is a
    neighbour of ``s`` when the column ``(s, t)`` exists. Each unordered pair
    of distinct neighbours of a centre node gives one triplet.

    Args:
        edge_index: Integer tensor whose first two rows are source and target.
        num_nodes: Number of nodes in the graph.

    Returns:
        Long tensor of shape ``[3, num_hyperedges]`` (rows: i, centre j, k).
    """
    heads = edge_index[0].tolist()
    tails = edge_index[1].tolist()

    adjacency = [[] for _ in range(num_nodes)]
    for head, tail in zip(heads, tails):
        adjacency[head].append(tail)

    found = []
    for middle, listed in enumerate(adjacency):
        # Collapse repeated edges before pairing the neighbours up.
        distinct = list(set(listed))
        for rank, left in enumerate(distinct):
            for right in distinct[rank + 1:]:
                found.append([left, middle, right])

    if not found:
        return torch.empty((3, 0), dtype=torch.long)
    return torch.tensor(found, dtype=torch.long).t().contiguous()

# Function to compute hyperedge attributes (angles between edges)
def compute_hyperedge_attr(atoms, hyperedge_index):
    """Return the angle at the centre node of every ``(i, j, k)`` hyperedge.

    The angle is measured between the minimum-image vectors ``j -> i`` and
    ``j -> k``. Lattice-vector shifts are applied only along the axes flagged
    periodic in ``atoms.pbc``.

    Args:
        atoms: Structure exposing ``get_cell()``, ``get_scaled_positions()``
            and ``pbc`` (an ASE ``Atoms`` object in this notebook).
        hyperedge_index: Integer tensor of shape ``[3, num_hyperedges]`` with
            rows ``i``, ``j`` (centre) and ``k``.

    Returns:
        Float tensor of shape ``[num_hyperedges, 1]`` holding angles in
        radians; the shape is ``[0, 1]`` when there are no hyperedges.
    """
    i_idx, j_idx, k_idx = hyperedge_index.detach().cpu().numpy()

    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()
    periodic = np.asarray(atoms.pbc, dtype=bool)

    def _from_centre(end_idx):
        # Minimum-image Cartesian vectors pointing from node j to end_idx.
        frac = scaled_positions[end_idx] - scaled_positions[j_idx]
        frac[:, periodic] -= np.round(frac[:, periodic])
        return frac @ cell

    d_ji = _from_centre(i_idx)
    d_jk = _from_centre(k_idx)

    # The 1e-8 keeps the division finite if an arm has zero length; clipping
    # guards arccos against round-off just outside [-1, 1].
    norms = np.linalg.norm(d_ji, axis=1) * np.linalg.norm(d_jk, axis=1)
    cos_theta = (d_ji * d_jk).sum(axis=1) / (norms + 1e-8)
    angles = np.arccos(np.clip(cos_theta, -1.0, 1.0))

    return torch.tensor(angles, dtype=torch.float).unsqueeze(1)

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta, SUPERCELL_SIZE):
    """Turn one atomic structure into a PyTorch Geometric ``Data`` graph.

    Args:
        atoms: Atomic structure passed through to the ``compute_*`` helpers.
        delta: Shell tolerance forwarded to ``compute_edge_index``.
        SUPERCELL_SIZE: N for an N x N supercell; also stored on the graph.

    Returns:
        ``Data`` with ``x``, ``edge_index``, ``edge_attr``, ``hyperedge_index``
        and ``hyperedge_attr`` set, plus a ``supercell_size`` attribute.
    """
    # Two-body part: who is connected to whom, and the attribute of each edge.
    pairs = compute_edge_index(atoms, delta)
    pair_attr = compute_edge_attr(atoms, pairs)

    # One feature row per atom.
    features = compute_node_features(atoms, SUPERCELL_SIZE)

    # Three-body part, built on top of the pair connectivity.
    triples = compute_hyperedges(pairs, len(atoms))
    triple_attr = compute_hyperedge_attr(atoms, triples)

    graph = Data(
        x=features,
        edge_index=pairs,
        edge_attr=pair_attr,
        hyperedge_index=triples,
        hyperedge_attr=triple_attr,
    )
    # perturb_graphs reads this back when recomputing node features.
    graph.supercell_size = SUPERCELL_SIZE
    return graph

# Function to apply perturbations to a graph list
def perturb_graphs(graph_list, atoms_list, perturbation_size, delta):
    """Create position-perturbed copies of graphs without changing topology.

    Every coordinate receives independent noise drawn from
    ``uniform(-perturbation_size, perturbation_size)`` via the global NumPy
    RNG. Edge attributes, node features and hyperedge attributes are rebuilt
    from the noisy positions, while ``edge_index`` and ``hyperedge_index``
    are taken as-is from the input graph.

    Args:
        graph_list: Graphs produced by ``create_graph_from_structure``.
        atoms_list: Structures matching ``graph_list`` element by element.
        perturbation_size: Half-width of the uniform noise; values ``<= 0``
            leave the positions untouched.
        delta: Unused here; accepted for signature compatibility.

    Returns:
        List of cloned graphs with recomputed attributes.
    """
    results = []
    for source_graph, source_atoms in zip(graph_list, atoms_list):
        clone = source_graph.clone()  # the input graph stays untouched

        shaken = source_atoms.get_positions().copy()
        if perturbation_size > 0:
            shaken += np.random.uniform(-perturbation_size, perturbation_size, shaken.shape)

        shaken_atoms = source_atoms.copy()
        shaken_atoms.set_positions(shaken)
        shaken_atoms.pbc = [True, True, False]  # periodic in-plane only

        # Topology is reused; only the geometry-dependent tensors change.
        clone.edge_attr = compute_edge_attr(shaken_atoms, source_graph.edge_index)
        clone.x = compute_node_features(shaken_atoms, source_graph.supercell_size)
        clone.hyperedge_attr = compute_hyperedge_attr(shaken_atoms, source_graph.hyperedge_index)

        results.append(clone)

    return results

# Function to read graphs and atoms from a folder
def read_graphs_from_folder(folder_path, delta, SUPERCELL_SIZE):
    """Load every ``.xyz`` file in ``folder_path`` and convert it to a graph.

    Files are handled in sorted name order. The final run of digits in a file
    name is used as the graph's ``label`` (``None`` when the name contains no
    digits). Each structure is flagged periodic along its first two axes.

    Args:
        folder_path: Directory containing the ``.xyz`` files.
        delta: Tolerance forwarded to ``create_graph_from_structure``.
        SUPERCELL_SIZE: N for an N x N supercell, forwarded likewise.

    Returns:
        Tuple ``(graph_list, atoms_list)`` with matching order.
    """
    graph_list, atoms_list = [], []

    xyz_files = sorted(entry for entry in os.listdir(folder_path) if entry.endswith('.xyz'))
    for filename in xyz_files:
        atoms = read(os.path.join(folder_path, filename))
        atoms.pbc = [True, True, False]  # periodic in-plane only

        # The trailing number of the file name doubles as the sample label.
        digit_runs = re.findall(r'\d+', filename)
        label = int(digit_runs[-1]) if digit_runs else None

        graph = create_graph_from_structure(atoms, delta, SUPERCELL_SIZE)
        graph.label = label

        graph_list.append(graph)
        atoms_list.append(atoms)

        print(f"Graph for {filename} created and added to the list with label {label}.")

    return graph_list, atoms_list

# Driver: build unperturbed graphs for two supercell sizes, then derive one
# perturbed copy of each list. The names assigned below are the outputs of
# this cell.

# Folders holding the .xyz supercell files (paths are relative to the working directory)
supercell_folder_1 = 'Supercells/supercells_flatband_rotated_shifted_aligned_3x3'
supercell_folder_2 = 'Supercells/supercells_flatband_rotated_shifted_aligned_4x4'

# N for the N x N supercell stored in each folder. compute_node_features
# raises ValueError when a file's atom count is not a multiple of N * N.
SUPERCELL_SIZE_1 = 3  # must match the contents of supercell_folder_1
SUPERCELL_SIZE_2 = 4  # must match the contents of supercell_folder_2

# Read graphs and atoms from the 3x3 folder (one graph per .xyz file, sorted by name)
graph_list_unperturbed_1, atoms_list_unperturbed_1 = read_graphs_from_folder(
    supercell_folder_1, delta, SUPERCELL_SIZE_1
)
print("Unperturbed graph list for 3x3 supercells created with periodic boundary conditions accounted for.")

# Read graphs and atoms from the 4x4 folder
graph_list_unperturbed_2, atoms_list_unperturbed_2 = read_graphs_from_folder(
    supercell_folder_2, delta, SUPERCELL_SIZE_2
)
print("Unperturbed graph list for 4x4 supercells created with periodic boundary conditions accounted for.")

# Perturb the unperturbed 3x3 graphs to create the first perturbed graph list.
# perturb_graphs draws from the global NumPy RNG, so results vary between
# runs unless the caller seeds np.random beforehand.
graph_list_set_1 = perturb_graphs(
    graph_list_unperturbed_1, atoms_list_unperturbed_1, perturbation_size, delta
)

# Perturb the unperturbed 4x4 graphs to create the second perturbed graph list
graph_list_set_2 = perturb_graphs(
    graph_list_unperturbed_2, atoms_list_unperturbed_2, perturbation_size, delta
)

print("Two perturbed graph lists created from 3x3 and 4x4 supercells respectively.")

In [None]:
# Graph generator: hyperedges, variable supercell size, lattice-vector and intra-unit-cell edges, xyz node features, scalar distance edge attributes

import os
import re
import torch
import numpy as np
from ase.io import read
from torch_geometric.data import Data
from scipy.spatial import cKDTree  # Added import for cKDTree

# Search radius used by compute_edge_index when matching an atom to the site
# one lattice vector away. The query runs on scaled (fractional) positions,
# so this value is a fraction of the supercell edge, not a Cartesian length.
delta = 0.1  # Adjust as needed

# Half-width of the uniform random displacement that perturb_graphs adds to
# every Cartesian coordinate.
perturbation_size = 0.005  # Adjust as needed

# Function to compute edge_index based on positions and delta, considering PBCs
def compute_edge_index(atoms, delta, N):
    """Connect atoms one unit-cell lattice vector apart and atoms sharing a cell.

    Two kinds of edges are produced, always in both directions:

    * Lattice edges: atom ``i`` is linked to every atom found within
      ``delta`` (in scaled coordinates) of ``s_i +/- a/N`` or ``s_i +/- b/N``,
      where ``s_i`` is its scaled position. The search is periodic in the
      first two axes, so a target stored just across a cell face from the
      wrapped candidate point is still found.
    * Intra-cell edges: all atoms in the same unit cell are linked pairwise.
      Unit cells are consecutive runs of ``len(atoms) // N**2`` atoms in file
      order.

    Args:
        atoms: Structure exposing ``get_scaled_positions()`` (an ASE ``Atoms``
            object in this notebook).
        delta: Search radius in scaled (fractional) coordinates.
        N: Supercell size; the supercell holds ``N x N`` unit cells.

    Returns:
        Long tensor of shape ``[2, num_edges]`` with columns in sorted order.

    Raises:
        ValueError: If the atom count is not a multiple of ``N * N``.
    """
    scaled_positions = np.asarray(atoms.get_scaled_positions(), dtype=float)
    num_atoms = len(scaled_positions)

    total_unit_cells = N * N  # For a 2D lattice
    atoms_per_unit_cell = num_atoms // total_unit_cells
    if atoms_per_unit_cell * total_unit_cells != num_atoms:
        raise ValueError("Number of atoms per unit cell is not consistent with total atoms and supercell dimensions.")

    if num_atoms == 0:
        return torch.empty((2, 0), dtype=torch.long)

    tree = cKDTree(scaled_positions)

    # Unit-cell lattice vectors expressed in supercell scaled coordinates.
    step = 1.0 / N
    shifts_scaled = np.array([
        [step, 0.0, 0.0],
        [-step, 0.0, 0.0],
        [0.0, step, 0.0],
        [0.0, -step, 0.0],
    ])
    # The tree itself is not periodic, so each candidate is also queried at
    # its eight in-plane periodic images; otherwise a candidate at ~0.0 would
    # miss an atom stored at ~1.0 (and vice versa).
    images = np.array([
        [da, db, 0.0] for da in (-1.0, 0.0, 1.0) for db in (-1.0, 0.0, 1.0)
    ])

    edges = set()

    # Lattice-vector connections.
    for i, s_i in enumerate(scaled_positions):
        for shift_scaled in shifts_scaled:
            candidate = np.mod(s_i + shift_scaled, 1.0)
            for hits in tree.query_ball_point(candidate + images, r=delta):
                for idx in hits:
                    if idx != i:  # no self-loops
                        edges.add((i, idx))
                        edges.add((idx, i))

    # Fully connect the atoms inside each unit cell.
    for start in range(0, num_atoms, atoms_per_unit_cell):
        group = range(start, start + atoms_per_unit_cell)
        edges.update((a, b) for a in group for b in group if a != b)

    if not edges:
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(sorted(edges), dtype=torch.long).t().contiguous()

# Function to compute edge attributes based on positions and edge_index, considering PBCs
def compute_edge_attr(atoms, edge_index):
    """Return the minimum-image length of every edge.

    The displacement ``position[row] - position[col]`` is shifted by whole
    lattice vectors along the axes flagged periodic in ``atoms.pbc`` only,
    and its Euclidean norm is returned.

    Args:
        atoms: Structure exposing ``get_cell()``, ``get_scaled_positions()``
            and ``pbc`` (an ASE ``Atoms`` object in this notebook).
        edge_index: Integer tensor of shape ``[2, num_edges]``.

    Returns:
        Float tensor of shape ``[num_edges, 1]``.
    """
    row, col = edge_index
    # detach().cpu() keeps the NumPy conversion valid for tensors that live
    # on another device or carry autograd history.
    row = row.detach().cpu().numpy()
    col = col.detach().cpu().numpy()

    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()
    periodic = np.asarray(atoms.pbc, dtype=bool)

    delta_scaled = scaled_positions[row] - scaled_positions[col]
    # Wrap only the periodic axes; a non-periodic axis has no images to fold to.
    # NOTE(review): intra-cell edges may join atoms that are far apart in the
    # supercell; rounding fractional coordinates then picks the nearest image
    # per axis, which is presumably what is wanted here - confirm.
    delta_scaled[:, periodic] -= np.round(delta_scaled[:, periodic])
    displacement_vectors = delta_scaled @ cell  # back to Cartesian coordinates

    edge_distances = np.linalg.norm(displacement_vectors, axis=1)
    return torch.tensor(edge_distances, dtype=torch.float).unsqueeze(1)

# Function to compute node features (relative positions within unit cells)
def compute_node_features(atoms, SUPERCELL_SIZE):
    """Compute per-atom offsets from the first atom of the owning unit cell.

    The owning cell is determined by position in the atom list only: every
    consecutive group of ``len(atoms) // SUPERCELL_SIZE**2`` atoms is treated
    as one unit cell. Offsets are direct Cartesian differences, without any
    periodic wrapping.

    Args:
        atoms: Structure supporting ``len()`` and ``get_positions()``.
        SUPERCELL_SIZE: N for an N x N (2D) supercell.

    Returns:
        Float tensor shaped like ``atoms.get_positions()``.

    Raises:
        ValueError: If the atom count is not a multiple of ``N * N``.
    """
    total = len(atoms)
    cells = SUPERCELL_SIZE * SUPERCELL_SIZE  # planar N x N tiling
    group_size = total // cells

    if group_size * cells != total:
        raise ValueError("Number of atoms per unit cell is not consistent with total atoms and supercell dimensions.")

    cartesian = atoms.get_positions()
    local = np.zeros_like(cartesian)

    # The leading atom of each group is that group's origin.
    for head in range(0, total, group_size):
        span = slice(head, head + group_size)
        local[span] = cartesian[span] - cartesian[head]

    return torch.tensor(local, dtype=torch.float)

# Function to compute hyperedges based on existing edges
def compute_hyperedges(edge_index, num_nodes):
    """Build ``(i, j, k)`` triplets from pairs of neighbours of each node ``j``.

    ``t`` is considered a neighbour of ``s`` when ``edge_index`` holds the
    column ``(s, t)``. One triplet is emitted per unordered pair of distinct
    neighbours of the centre node.

    Args:
        edge_index: Integer tensor whose first two rows are source and target.
        num_nodes: Number of nodes in the graph.

    Returns:
        Long tensor of shape ``[3, num_hyperedges]`` (rows: i, centre j, k).
    """
    starts = edge_index[0].tolist()
    ends = edge_index[1].tolist()

    reachable = [[] for _ in range(num_nodes)]
    for start, end in zip(starts, ends):
        reachable[start].append(end)

    collected = []
    for hub, spokes in enumerate(reachable):
        # Duplicate edges would otherwise yield duplicate triplets.
        spokes = list(set(spokes))
        for at, a in enumerate(spokes):
            for b in spokes[at + 1:]:
                collected.append([a, hub, b])

    if not collected:
        return torch.empty((3, 0), dtype=torch.long)
    return torch.tensor(collected, dtype=torch.long).t().contiguous()

# Function to compute hyperedge attributes (angles between edges)
def compute_hyperedge_attr(atoms, hyperedge_index):
    """Return the angle at the centre node of every ``(i, j, k)`` hyperedge.

    The angle is measured between the minimum-image vectors ``j -> i`` and
    ``j -> k``. Lattice-vector shifts are applied only along the axes flagged
    periodic in ``atoms.pbc``.

    Args:
        atoms: Structure exposing ``get_cell()``, ``get_scaled_positions()``
            and ``pbc`` (an ASE ``Atoms`` object in this notebook).
        hyperedge_index: Integer tensor of shape ``[3, num_hyperedges]`` with
            rows ``i``, ``j`` (centre) and ``k``.

    Returns:
        Float tensor of shape ``[num_hyperedges, 1]`` holding angles in
        radians; the shape is ``[0, 1]`` when there are no hyperedges.
    """
    i_idx, j_idx, k_idx = hyperedge_index.detach().cpu().numpy()

    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()
    periodic = np.asarray(atoms.pbc, dtype=bool)

    def _from_centre(end_idx):
        # Minimum-image Cartesian vectors pointing from node j to end_idx.
        frac = scaled_positions[end_idx] - scaled_positions[j_idx]
        frac[:, periodic] -= np.round(frac[:, periodic])
        return frac @ cell

    d_ji = _from_centre(i_idx)
    d_jk = _from_centre(k_idx)

    # The 1e-8 keeps the division finite if an arm has zero length; clipping
    # guards arccos against round-off just outside [-1, 1].
    norms = np.linalg.norm(d_ji, axis=1) * np.linalg.norm(d_jk, axis=1)
    cos_theta = (d_ji * d_jk).sum(axis=1) / (norms + 1e-8)
    angles = np.arccos(np.clip(cos_theta, -1.0, 1.0))

    return torch.tensor(angles, dtype=torch.float).unsqueeze(1)

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta, SUPERCELL_SIZE):
    """Turn one atomic structure into a PyTorch Geometric ``Data`` graph.

    Args:
        atoms: Atomic structure passed through to the ``compute_*`` helpers.
        delta: Search radius forwarded to ``compute_edge_index``.
        SUPERCELL_SIZE: N for an N x N supercell; forwarded to
            ``compute_edge_index`` and ``compute_node_features`` and stored on
            the returned graph.

    Returns:
        ``Data`` with ``x``, ``edge_index``, ``edge_attr``, ``hyperedge_index``
        and ``hyperedge_attr`` set, plus a ``supercell_size`` attribute.
    """
    # Two-body part: lattice-vector and intra-cell connectivity with attributes.
    links = compute_edge_index(atoms, delta, SUPERCELL_SIZE)
    link_attr = compute_edge_attr(atoms, links)

    # One feature row per atom.
    features = compute_node_features(atoms, SUPERCELL_SIZE)

    # Three-body part, built on top of the pair connectivity.
    triples = compute_hyperedges(links, len(atoms))
    triple_attr = compute_hyperedge_attr(atoms, triples)

    graph = Data(
        x=features,
        edge_index=links,
        edge_attr=link_attr,
        hyperedge_index=triples,
        hyperedge_attr=triple_attr,
    )
    # perturb_graphs reads this back when recomputing node features.
    graph.supercell_size = SUPERCELL_SIZE
    return graph

# Function to apply perturbations to a graph list
def perturb_graphs(graph_list, atoms_list, perturbation_size, delta):
    """Create position-perturbed copies of graphs without changing topology.

    Every coordinate receives independent noise drawn from
    ``uniform(-perturbation_size, perturbation_size)`` via the global NumPy
    RNG. Edge attributes, node features and hyperedge attributes are rebuilt
    from the noisy positions using the cloned graph's existing
    ``edge_index`` and ``hyperedge_index``.

    Args:
        graph_list: Graphs produced by ``create_graph_from_structure``.
        atoms_list: Structures matching ``graph_list`` element by element.
        perturbation_size: Half-width of the uniform noise; values ``<= 0``
            leave the positions untouched.
        delta: Unused here; accepted for signature compatibility.

    Returns:
        List of cloned graphs with recomputed attributes.
    """
    def _perturb_one(graph, atoms):
        noisy_graph = graph.clone()  # never touch the caller's graph

        coords = atoms.get_positions().copy()
        if perturbation_size > 0:
            coords += np.random.uniform(-perturbation_size, perturbation_size, coords.shape)

        noisy_atoms = atoms.copy()
        noisy_atoms.set_positions(coords)
        noisy_atoms.pbc = [True, True, False]  # periodic in-plane only

        # Only geometry-dependent tensors are refreshed; indices are kept.
        noisy_graph.edge_attr = compute_edge_attr(noisy_atoms, noisy_graph.edge_index)
        noisy_graph.x = compute_node_features(noisy_atoms, graph.supercell_size)
        noisy_graph.hyperedge_attr = compute_hyperedge_attr(noisy_atoms, noisy_graph.hyperedge_index)
        return noisy_graph

    return [_perturb_one(graph, atoms) for graph, atoms in zip(graph_list, atoms_list)]

# Function to read graphs and atoms from a folder
def read_graphs_from_folder(folder_path, delta, SUPERCELL_SIZE):
    """Build one graph per ``.xyz`` file found in ``folder_path``.

    Files are processed in sorted filename order. Every structure is read,
    flagged as periodic along x and y only, converted with
    ``create_graph_from_structure`` and labelled with the last run of digits
    in its filename (``None`` when the name contains no digits).

    Args:
        folder_path: Directory to scan for ``.xyz`` files.
        delta: Neighbour tolerance forwarded to ``create_graph_from_structure``.
        SUPERCELL_SIZE: Supercell size forwarded to ``create_graph_from_structure``.

    Returns:
        Tuple ``(graph_list, atoms_list)`` with matching order.
    """
    graphs, structures = [], []
    xyz_names = sorted(name for name in os.listdir(folder_path) if name.endswith('.xyz'))

    for filename in xyz_names:
        # Load the structure and enable PBCs in the x/y directions
        atoms = read(os.path.join(folder_path, filename))
        atoms.pbc = [True, True, False]

        # The label is the last number appearing in the filename, if any
        numbers = re.findall(r'\d+', filename)
        label = int(numbers[-1]) if numbers else None

        graph = create_graph_from_structure(atoms, delta, SUPERCELL_SIZE)
        graph.label = label

        graphs.append(graph)
        structures.append(atoms)

        print(f"Graph for {filename} created and added to the list with label {label}.")

    return graphs, structures

# Input folders holding the two sets of supercell .xyz files
supercell_folder_1 = 'Supercells/supercells_flatband_rotated_shifted_aligned_3x3'
supercell_folder_2 = 'Supercells/supercells_flatband_rotated_shifted_aligned_4x4'

# Supercell sizes passed along with each folder (adjust as needed)
SUPERCELL_SIZE_1, SUPERCELL_SIZE_2 = 3, 4

# Unperturbed graphs and their atoms objects for the 3x3 folder
graph_list_unperturbed_1, atoms_list_unperturbed_1 = read_graphs_from_folder(
    supercell_folder_1,
    delta,
    SUPERCELL_SIZE_1,
)
print("Unperturbed graph list for 3x3 supercells created with periodic boundary conditions accounted for.")

# Unperturbed graphs and their atoms objects for the 4x4 folder
graph_list_unperturbed_2, atoms_list_unperturbed_2 = read_graphs_from_folder(
    supercell_folder_2,
    delta,
    SUPERCELL_SIZE_2,
)
print("Unperturbed graph list for 4x4 supercells created with periodic boundary conditions accounted for.")

# First perturbed list: jittered copies of the 3x3 graphs
graph_list_set_1 = perturb_graphs(
    graph_list_unperturbed_1,
    atoms_list_unperturbed_1,
    perturbation_size,
    delta,
)

# Second perturbed list: jittered copies of the 4x4 graphs
graph_list_set_2 = perturb_graphs(
    graph_list_unperturbed_2,
    atoms_list_unperturbed_2,
    perturbation_size,
    delta,
)

print("Two perturbed graph lists created from 3x3 and 4x4 supercells respectively.")

In [None]:
#Graph generator: next nearest neighbour and PBC, xyz node features and edge attributes

import os
import re
import torch
import numpy as np
from ase.io import read
from torch_geometric.data import Data

# Tolerance added to the nearest (d1) and next-nearest (d2) neighbour distance
# in compute_edge_index: atoms within d1 + delta / d2 + delta form each shell.
delta = 0.1  # Adjust as needed

# Half-width of the uniform random offset that perturb_graphs adds to every
# coordinate (offsets are drawn from [-perturbation_size, perturbation_size)).
perturbation_size = 0.05  # Adjust as needed

# Function to compute edge_index based on positions and delta, considering PBCs
def compute_edge_index(atoms, delta):
    """Build a directed edge list linking each atom to its two nearest shells.

    For every atom ``i`` the minimum-image distances to all *other* atoms are
    sorted. Atoms within ``d1 + delta`` of ``i`` (``d1`` is the smallest
    distance) form the nearest-neighbour shell; among the remaining atoms,
    those within ``d2 + delta`` (``d2`` is the smallest remaining distance)
    form the next-nearest shell. An edge ``(j, i)`` is added for every atom
    ``j`` in either shell.

    Args:
        atoms: Structure exposing ``get_positions()`` and
            ``get_all_distances(mic=True)`` (e.g. an ASE ``Atoms`` object).
        delta: Tolerance added to each shell distance.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)``; row 0 holds the
        neighbour index ``j`` and row 1 the central atom ``i``. The shape is
        ``(2, 0)`` when there are no edges.
    """
    num_atoms = len(atoms.get_positions())

    # Pairwise distances under the minimum-image convention
    dist_matrix = atoms.get_all_distances(mic=True)
    edges = set()  # A set avoids duplicate edges

    for i in range(num_atoms):
        # Mark the self-distance so it can never be picked as a neighbour
        dist_matrix[i, i] = np.inf
        distances = dist_matrix[i]

        order = np.argsort(distances)
        # Drop the self entry (infinite distance). If it were kept, a structure
        # in which every other atom already sits in the nearest shell would use
        # inf as the next shell distance and emit a self-loop (i, i).
        order = order[np.isfinite(distances[order])]
        sorted_distances = distances[order]
        if sorted_distances.size == 0:
            continue  # No other atom: nothing to connect

        # Nearest-neighbour shell: everything within d1 + delta
        in_first_shell = sorted_distances <= sorted_distances[0] + delta
        for j in order[in_first_shell]:
            edges.add((int(j), i))

        # Next-nearest shell is searched among the atoms left over
        remaining_indices = order[~in_first_shell]
        remaining_distances = sorted_distances[~in_first_shell]
        if remaining_distances.size > 0:
            nnn_cutoff = remaining_distances[0] + delta
            for j in remaining_indices[remaining_distances <= nnn_cutoff]:
                edges.add((int(j), i))

    if not edges:
        # Keep the (2, 0) layout expected for an edge-less graph
        return torch.empty((2, 0), dtype=torch.long)

    return torch.tensor(list(edges), dtype=torch.long).t().contiguous()

# Function to compute edge attributes as displacement vectors considering PBCs
def compute_edge_attr(atoms, edge_index):
    """Return the displacement vector of every edge as its edge attribute.

    The displacement is ``position[row] - position[col]`` evaluated in
    fractional coordinates, wrapped to the nearest periodic image, and mapped
    back to Cartesian coordinates with the cell matrix.

    Args:
        atoms: Structure exposing ``get_cell()``, ``get_scaled_positions()``
            and ``get_pbc()`` (e.g. an ASE ``Atoms`` object).
        edge_index: Tensor of shape ``(2, num_edges)``; its first row is used
            as ``row`` and its second as ``col``.

    Returns:
        ``torch.float`` tensor of shape ``(num_edges, 3)``.
    """
    row, col = edge_index
    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()
    periodic = np.asarray(atoms.get_pbc(), dtype=bool)

    # Fractional displacement of every edge
    delta_scaled = scaled_positions[row.numpy()] - scaled_positions[col.numpy()]

    # Minimum image convention, applied only along periodic axes: wrapping a
    # non-periodic axis would fold real separations back into the cell and
    # disagree with the mic distances used to select the edges.
    delta_scaled[:, periodic] -= np.round(delta_scaled[:, periodic])

    # Convert back to Cartesian coordinates
    displacement_vectors = delta_scaled @ cell

    return torch.tensor(displacement_vectors, dtype=torch.float)

# Function to compute node features (relative positions within unit cells)
def compute_node_features(atoms, SUPERCELL_SIZE):
    """Return each atom's position relative to the first atom of its unit cell.

    The structure is treated as an ``N x N`` supercell (``N`` is
    ``SUPERCELL_SIZE``). Atoms are assigned to unit cells purely by their
    order: consecutive runs of ``len(atoms) // (N * N)`` atoms make up one
    cell, and the first atom of each run is that cell's reference point.

    Args:
        atoms: Structure supporting ``len()`` and ``get_positions()``.
        SUPERCELL_SIZE: Number of unit cells along each of the two directions.

    Returns:
        ``torch.float`` tensor with one relative-position row per atom.

    Raises:
        ValueError: If the atom count is not a multiple of ``N * N``.
    """
    atom_count = len(atoms)
    cell_count = SUPERCELL_SIZE * SUPERCELL_SIZE  # 2D lattice

    per_cell, leftover = divmod(atom_count, cell_count)
    if leftover:
        raise ValueError("Number of atoms per unit cell is not consistent with total atoms and supercell dimensions.")

    # First atom index of every unit cell, following the file order
    cell_starts = list(range(0, atom_count, per_cell))

    coords = atoms.get_positions()
    offsets = np.zeros_like(coords)
    for start in cell_starts:
        stop = start + per_cell
        # Shift the whole cell so its first atom sits at the origin
        offsets[start:stop] = coords[start:stop] - coords[start]

    return torch.tensor(offsets, dtype=torch.float)

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta, SUPERCELL_SIZE):
    """Assemble a PyTorch Geometric ``Data`` graph for one atomic structure.

    Args:
        atoms: Structure passed to the edge and node feature helpers.
        delta: Neighbour tolerance forwarded to ``compute_edge_index``.
        SUPERCELL_SIZE: Supercell size forwarded to ``compute_node_features``.

    Returns:
        ``Data`` with ``x`` (node features), ``edge_index`` and ``edge_attr``.
    """
    # Connectivity first, since the edge attributes are evaluated on it
    connectivity = compute_edge_index(atoms, delta)
    edge_vectors = compute_edge_attr(atoms, connectivity)

    # Node features: positions relative to each unit cell's first atom
    relative_positions = compute_node_features(atoms, SUPERCELL_SIZE)

    return Data(
        x=relative_positions,
        edge_index=connectivity,
        edge_attr=edge_vectors,
    )

# Function to apply perturbations to a graph list
def perturb_graphs(graph_list, atoms_list, perturbation_size, SUPERCELL_SIZE):
    """Return perturbed copies of the graphs in ``graph_list``.

    Each graph is paired with the atoms object at the same position in
    ``atoms_list``. The atomic positions are jittered with independent uniform
    offsets in ``[-perturbation_size, perturbation_size)``; edge attributes
    and node features are then recomputed from the jittered structure while
    the cloned graph's ``edge_index`` is left unchanged.

    Args:
        graph_list: Graphs to perturb; they are cloned, not modified.
        atoms_list: Atoms objects matching ``graph_list`` one-to-one.
        perturbation_size: Half-width of the uniform jitter; no jitter is
            applied when it is not positive.
        SUPERCELL_SIZE: Supercell size forwarded to ``compute_node_features``.

    Returns:
        List of perturbed graph clones, in input order.
    """
    jittered_graphs = []
    for source_graph, source_atoms in zip(graph_list, atoms_list):
        # Work on a clone so the unperturbed graph stays intact
        new_graph = source_graph.clone()

        # Private copy of the original coordinates
        coords = source_atoms.get_positions().copy()
        if perturbation_size > 0:
            coords += np.random.uniform(-perturbation_size, perturbation_size, coords.shape)

        # Structure carrying the jittered coordinates, periodic in x and y only
        jittered_atoms = source_atoms.copy()
        jittered_atoms.set_positions(coords)
        jittered_atoms.pbc = [True, True, False]

        # Same connectivity, fresh geometry-dependent tensors
        new_graph.edge_attr = compute_edge_attr(jittered_atoms, new_graph.edge_index)
        new_graph.x = compute_node_features(jittered_atoms, SUPERCELL_SIZE)

        jittered_graphs.append(new_graph)

    return jittered_graphs

# Path to the supercell files folder
supercell_folder = 'supercells_flatband_rotated_shifted_3x3'

# NOTE(review): SUPERCELL_SIZE is not assigned in this cell; presumably it is
# defined by an earlier cell -- confirm before running this cell on its own.

# Lists to store the graph objects and the corresponding atoms
graph_list_unperturbed = []
atoms_list_unperturbed = []

# os.listdir() returns entries in arbitrary order; sorting keeps the graph and
# atoms lists in a reproducible, filename-ordered sequence.
for filename in sorted(os.listdir(supercell_folder)):
    if filename.endswith('.xyz'):
        # Read the structure file
        filepath = os.path.join(supercell_folder, filename)
        atoms = read(filepath)
        atoms.pbc = [True, True, False]  # Periodic along x and y only

        # The label is the last number in the filename (None if there is none)
        match = re.findall(r'\d+', filename)
        if match:
            label = int(match[-1])
        else:
            label = None

        # Create graph with PBCs
        graph = create_graph_from_structure(atoms, delta, SUPERCELL_SIZE)
        graph.label = label

        # Keep graph and atoms at matching positions in the two lists
        graph_list_unperturbed.append(graph)
        atoms_list_unperturbed.append(atoms)

        print(f"Graph for {filename} created and added to the list with label {label}.")

print("Unperturbed graph list created with periodic boundary conditions accounted for.")

# First perturbed graph list
graph_list_set_1 = perturb_graphs(graph_list_unperturbed, atoms_list_unperturbed, perturbation_size, SUPERCELL_SIZE)

# Second perturbed graph list (an independent random draw from the same inputs)
graph_list_set_2 = perturb_graphs(graph_list_unperturbed, atoms_list_unperturbed, perturbation_size, SUPERCELL_SIZE)

# No scaling or masking happens in this cell, so the message does not claim it
print("Three graph lists created: unperturbed, perturbed set 1, and perturbed set 2.")


In [None]:
#Graph generator: next nearest neighbour and PBC, xyz node features, x edges

import os
import re
import torch
import numpy as np
from ase.io import read
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler

# Tolerance added to the nearest (d1) and next-nearest (d2) neighbour distance
# in compute_edge_index: atoms within d1 + delta / d2 + delta form each shell.
delta = 0.1  # Adjust as needed

# Half-width of the uniform random offset that perturb_graphs adds to every
# coordinate (offsets are drawn from [-perturbation_size, perturbation_size)).
perturbation_size = 0.05  # Adjust as needed

# Function to compute edge_index based on positions and delta, considering PBCs
def compute_edge_index(atoms, delta):
    """Build a directed edge list linking each atom to its two nearest shells.

    For every atom ``i`` the minimum-image distances to all *other* atoms are
    sorted. Atoms within ``d1 + delta`` of ``i`` (``d1`` is the smallest
    distance) form the nearest-neighbour shell; among the remaining atoms,
    those within ``d2 + delta`` (``d2`` is the smallest remaining distance)
    form the next-nearest shell. An edge ``(j, i)`` is added for every atom
    ``j`` in either shell.

    Args:
        atoms: Structure exposing ``get_positions()`` and
            ``get_all_distances(mic=True)`` (e.g. an ASE ``Atoms`` object).
        delta: Tolerance added to each shell distance.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)``; row 0 holds the
        neighbour index ``j`` and row 1 the central atom ``i``. The shape is
        ``(2, 0)`` when there are no edges.
    """
    num_atoms = len(atoms.get_positions())

    # Pairwise distances under the minimum-image convention
    dist_matrix = atoms.get_all_distances(mic=True)
    edges = set()  # A set avoids duplicate edges

    for i in range(num_atoms):
        # Mark the self-distance so it can never be picked as a neighbour
        dist_matrix[i, i] = np.inf
        distances = dist_matrix[i]

        order = np.argsort(distances)
        # Drop the self entry (infinite distance). If it were kept, a structure
        # in which every other atom already sits in the nearest shell would use
        # inf as the next shell distance and emit a self-loop (i, i).
        order = order[np.isfinite(distances[order])]
        sorted_distances = distances[order]
        if sorted_distances.size == 0:
            continue  # No other atom: nothing to connect

        # Nearest-neighbour shell: everything within d1 + delta
        in_first_shell = sorted_distances <= sorted_distances[0] + delta
        for j in order[in_first_shell]:
            edges.add((int(j), i))

        # Next-nearest shell is searched among the atoms left over
        remaining_indices = order[~in_first_shell]
        remaining_distances = sorted_distances[~in_first_shell]
        if remaining_distances.size > 0:
            nnn_cutoff = remaining_distances[0] + delta
            for j in remaining_indices[remaining_distances <= nnn_cutoff]:
                edges.add((int(j), i))

    if not edges:
        # Keep the (2, 0) layout expected for an edge-less graph
        return torch.empty((2, 0), dtype=torch.long)

    return torch.tensor(list(edges), dtype=torch.long).t().contiguous()

# Function to compute edge attributes based on positions and edge_index, considering PBCs
def compute_edge_attr(atoms, edge_index):
    """Return the length of every edge as a one-column edge attribute.

    The displacement ``position[row] - position[col]`` is evaluated in
    fractional coordinates, wrapped to the nearest periodic image, mapped back
    to Cartesian coordinates with the cell matrix, and reduced to its norm.

    Args:
        atoms: Structure exposing ``get_cell()``, ``get_scaled_positions()``
            and ``get_pbc()`` (e.g. an ASE ``Atoms`` object).
        edge_index: Tensor of shape ``(2, num_edges)``; its first row is used
            as ``row`` and its second as ``col``.

    Returns:
        ``torch.float`` tensor of shape ``(num_edges, 1)`` with edge lengths.
    """
    row, col = edge_index
    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()
    periodic = np.asarray(atoms.get_pbc(), dtype=bool)

    # Fractional displacement of every edge
    delta_scaled = scaled_positions[row.numpy()] - scaled_positions[col.numpy()]

    # Minimum image convention, applied only along periodic axes: wrapping a
    # non-periodic axis would fold real separations back into the cell and
    # disagree with the mic distances used to select the edges.
    delta_scaled[:, periodic] -= np.round(delta_scaled[:, periodic])

    displacement_vectors = delta_scaled @ cell
    edge_distances = np.linalg.norm(displacement_vectors, axis=1)

    # unsqueeze(1): one feature column per edge
    return torch.tensor(edge_distances, dtype=torch.float).unsqueeze(1)

# Function to compute node features (relative positions within unit cells)
def compute_node_features(atoms, SUPERCELL_SIZE):
    """Return each atom's position relative to the first atom of its unit cell.

    The structure is treated as an ``N x N`` supercell (``N`` is
    ``SUPERCELL_SIZE``). Atoms are assigned to unit cells purely by their
    order: consecutive runs of ``len(atoms) // (N * N)`` atoms make up one
    cell, and the first atom of each run is that cell's reference point.

    Args:
        atoms: Structure supporting ``len()`` and ``get_positions()``.
        SUPERCELL_SIZE: Number of unit cells along each of the two directions.

    Returns:
        ``torch.float`` tensor with one relative-position row per atom.

    Raises:
        ValueError: If the atom count is not a multiple of ``N * N``.
    """
    atom_count = len(atoms)
    cell_count = SUPERCELL_SIZE * SUPERCELL_SIZE  # 2D lattice

    per_cell, leftover = divmod(atom_count, cell_count)
    if leftover:
        raise ValueError("Number of atoms per unit cell is not consistent with total atoms and supercell dimensions.")

    # First atom index of every unit cell, following the file order
    cell_starts = list(range(0, atom_count, per_cell))

    coords = atoms.get_positions()
    offsets = np.zeros_like(coords)
    for start in cell_starts:
        stop = start + per_cell
        # Shift the whole cell so its first atom sits at the origin
        offsets[start:stop] = coords[start:stop] - coords[start]

    return torch.tensor(offsets, dtype=torch.float)

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta, SUPERCELL_SIZE):
    """Assemble a PyTorch Geometric ``Data`` graph for one atomic structure.

    Args:
        atoms: Structure passed to the edge and node feature helpers.
        delta: Neighbour tolerance forwarded to ``compute_edge_index``.
        SUPERCELL_SIZE: Supercell size forwarded to ``compute_node_features``.

    Returns:
        ``Data`` with ``x`` (node features), ``edge_index`` and ``edge_attr``.
    """
    # Connectivity first, since the edge attributes are evaluated on it
    connectivity = compute_edge_index(atoms, delta)
    edge_lengths = compute_edge_attr(atoms, connectivity)

    # Node features: positions relative to each unit cell's first atom
    relative_positions = compute_node_features(atoms, SUPERCELL_SIZE)

    return Data(
        x=relative_positions,
        edge_index=connectivity,
        edge_attr=edge_lengths,
    )

# Function to apply perturbations to a graph list
def perturb_graphs(graph_list, atoms_list, perturbation_size, delta, SUPERCELL_SIZE):
    """Return perturbed copies of the graphs in ``graph_list``.

    Each graph is paired with the atoms object at the same position in
    ``atoms_list``. The atomic positions are jittered with independent uniform
    offsets in ``[-perturbation_size, perturbation_size)``; edge attributes
    and node features are then recomputed from the jittered structure on the
    original graph's ``edge_index``.

    Args:
        graph_list: Graphs to perturb; they are cloned, not modified.
        atoms_list: Atoms objects matching ``graph_list`` one-to-one.
        perturbation_size: Half-width of the uniform jitter; no jitter is
            applied when it is not positive.
        delta: Accepted for signature compatibility; not used here.
        SUPERCELL_SIZE: Supercell size forwarded to ``compute_node_features``.

    Returns:
        List of perturbed graph clones, in input order.
    """
    jittered_graphs = []
    for source_graph, source_atoms in zip(graph_list, atoms_list):
        # Work on a clone so the unperturbed graph stays intact
        new_graph = source_graph.clone()

        # Private copy of the original coordinates
        coords = source_atoms.get_positions().copy()
        if perturbation_size > 0:
            coords += np.random.uniform(-perturbation_size, perturbation_size, coords.shape)

        # Structure carrying the jittered coordinates, periodic in x and y only
        jittered_atoms = source_atoms.copy()
        jittered_atoms.set_positions(coords)
        jittered_atoms.pbc = [True, True, False]

        # Connectivity is reused as-is; only geometry-dependent tensors change
        new_graph.edge_attr = compute_edge_attr(jittered_atoms, source_graph.edge_index)
        new_graph.x = compute_node_features(jittered_atoms, SUPERCELL_SIZE)

        jittered_graphs.append(new_graph)

    return jittered_graphs

# Path to the supercell files folder
supercell_folder = 'supercells_flatband'

# NOTE(review): SUPERCELL_SIZE is not assigned in this cell; presumably it is
# defined by an earlier cell -- confirm before running this cell on its own.

# Lists to store the graph objects and the corresponding atoms
graph_list_unperturbed = []
atoms_list_unperturbed = []

# os.listdir() returns entries in arbitrary order; sorting keeps the graph and
# atoms lists in a reproducible, filename-ordered sequence.
for filename in sorted(os.listdir(supercell_folder)):
    if filename.endswith('.xyz'):
        # Read the structure file
        filepath = os.path.join(supercell_folder, filename)
        atoms = read(filepath)
        atoms.pbc = [True, True, False]  # Periodic along x and y only

        # The label is the last number in the filename (None if there is none)
        match = re.findall(r'\d+', filename)
        if match:
            label = int(match[-1])
        else:
            label = None

        # Create graph with PBCs
        graph = create_graph_from_structure(atoms, delta, SUPERCELL_SIZE)
        graph.label = label

        # Keep graph and atoms at matching positions in the two lists
        graph_list_unperturbed.append(graph)
        atoms_list_unperturbed.append(atoms)

        print(f"Graph for {filename} created and added to the list with label {label}.")

print("Unperturbed graph list created with periodic boundary conditions accounted for.")

# First perturbed graph list
graph_list_set_1 = perturb_graphs(graph_list_unperturbed, atoms_list_unperturbed, perturbation_size, delta, SUPERCELL_SIZE)

# Second perturbed graph list (an independent random draw from the same inputs)
graph_list_set_2 = perturb_graphs(graph_list_unperturbed, atoms_list_unperturbed, perturbation_size, delta, SUPERCELL_SIZE)

# No scaling or masking happens in this cell, so the message does not claim it
print("Three graph lists created: unperturbed, perturbed set 1, and perturbed set 2.")

In [None]:
#Graph generator: next nearest neighbour and PBC, 0 node features, xyz edge attributes

import os
import re
import torch
import numpy as np
from ase.io import read
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler

# Tolerance added to the nearest (d1) and next-nearest (d2) neighbour distance
# in compute_edge_index: atoms within d1 + delta / d2 + delta form each shell.
delta = 0.1  # Adjust as needed

# Half-width of the uniform random offset that perturb_graphs adds to every
# coordinate (offsets are drawn from [-perturbation_size, perturbation_size)).
perturbation_size = 0.05  # Adjust as needed

# Function to compute edge_index based on positions and delta, considering PBCs
def compute_edge_index(atoms, delta):
    """Build a directed edge list linking each atom to its two nearest shells.

    For every atom ``i`` the minimum-image distances to all *other* atoms are
    sorted. Atoms within ``d1 + delta`` of ``i`` (``d1`` is the smallest
    distance) form the nearest-neighbour shell; among the remaining atoms,
    those within ``d2 + delta`` (``d2`` is the smallest remaining distance)
    form the next-nearest shell. An edge ``(j, i)`` is added for every atom
    ``j`` in either shell.

    Args:
        atoms: Structure exposing ``get_positions()`` and
            ``get_all_distances(mic=True)`` (e.g. an ASE ``Atoms`` object).
        delta: Tolerance added to each shell distance.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)``; row 0 holds the
        neighbour index ``j`` and row 1 the central atom ``i``. The shape is
        ``(2, 0)`` when there are no edges.
    """
    num_atoms = len(atoms.get_positions())

    # Pairwise distances under the minimum-image convention
    dist_matrix = atoms.get_all_distances(mic=True)
    edges = set()  # A set avoids duplicate edges

    for i in range(num_atoms):
        # Mark the self-distance so it can never be picked as a neighbour
        dist_matrix[i, i] = np.inf
        distances = dist_matrix[i]

        order = np.argsort(distances)
        # Drop the self entry (infinite distance). If it were kept, a structure
        # in which every other atom already sits in the nearest shell would use
        # inf as the next shell distance and emit a self-loop (i, i).
        order = order[np.isfinite(distances[order])]
        sorted_distances = distances[order]
        if sorted_distances.size == 0:
            continue  # No other atom: nothing to connect

        # Nearest-neighbour shell: everything within d1 + delta
        in_first_shell = sorted_distances <= sorted_distances[0] + delta
        for j in order[in_first_shell]:
            edges.add((int(j), i))

        # Next-nearest shell is searched among the atoms left over
        remaining_indices = order[~in_first_shell]
        remaining_distances = sorted_distances[~in_first_shell]
        if remaining_distances.size > 0:
            nnn_cutoff = remaining_distances[0] + delta
            for j in remaining_indices[remaining_distances <= nnn_cutoff]:
                edges.add((int(j), i))

    if not edges:
        # Keep the (2, 0) layout expected for an edge-less graph
        return torch.empty((2, 0), dtype=torch.long)

    return torch.tensor(list(edges), dtype=torch.long).t().contiguous()

# Function to compute edge attributes as displacement vectors considering PBCs
def compute_edge_attr(atoms, edge_index):
    """Return the displacement vector of every edge as its edge attribute.

    The displacement is ``position[row] - position[col]`` evaluated in
    fractional coordinates, wrapped to the nearest periodic image, and mapped
    back to Cartesian coordinates with the cell matrix.

    Args:
        atoms: Structure exposing ``get_cell()``, ``get_scaled_positions()``
            and ``get_pbc()`` (e.g. an ASE ``Atoms`` object).
        edge_index: Tensor of shape ``(2, num_edges)``; its first row is used
            as ``row`` and its second as ``col``.

    Returns:
        ``torch.float`` tensor of shape ``(num_edges, 3)``.
    """
    row, col = edge_index
    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()
    periodic = np.asarray(atoms.get_pbc(), dtype=bool)

    # Fractional displacement of every edge
    delta_scaled = scaled_positions[row.numpy()] - scaled_positions[col.numpy()]

    # Minimum image convention, applied only along periodic axes: wrapping a
    # non-periodic axis would fold real separations back into the cell and
    # disagree with the mic distances used to select the edges.
    delta_scaled[:, periodic] -= np.round(delta_scaled[:, periodic])

    # Convert back to Cartesian coordinates
    displacement_vectors = delta_scaled @ cell

    return torch.tensor(displacement_vectors, dtype=torch.float)

def compute_node_features(atoms, SUPERCELL_SIZE):
    """Return a constant (all-ones) feature column with one row per atom.

    Args:
        atoms: Structure supporting ``len()``.
        SUPERCELL_SIZE: Accepted for signature compatibility; not used here.

    Returns:
        ``torch.float`` tensor of shape ``(len(atoms), 1)`` filled with ones.
    """
    return torch.ones((len(atoms), 1), dtype=torch.float)


# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta, SUPERCELL_SIZE):
    """Assemble a PyTorch Geometric ``Data`` graph for one atomic structure.

    Args:
        atoms: Structure passed to the edge and node feature helpers.
        delta: Neighbour tolerance forwarded to ``compute_edge_index``.
        SUPERCELL_SIZE: Forwarded to ``compute_node_features``.

    Returns:
        ``Data`` with ``x`` (node features), ``edge_index`` and ``edge_attr``.
    """
    # Connectivity first, since the edge attributes are evaluated on it
    connectivity = compute_edge_index(atoms, delta)
    edge_vectors = compute_edge_attr(atoms, connectivity)

    # Node features as produced by this cell's compute_node_features
    node_values = compute_node_features(atoms, SUPERCELL_SIZE)

    return Data(
        x=node_values,
        edge_index=connectivity,
        edge_attr=edge_vectors,
    )

# Function to apply perturbations to a graph list
def perturb_graphs(graph_list, atoms_list, perturbation_size, delta, SUPERCELL_SIZE):
    """Return perturbed copies of the graphs in ``graph_list``.

    Each graph is paired with the atoms object at the same position in
    ``atoms_list``. The atomic positions are jittered with independent uniform
    offsets in ``[-perturbation_size, perturbation_size)``; edge attributes
    and node features are then recomputed from the jittered structure on the
    original graph's ``edge_index``.

    Args:
        graph_list: Graphs to perturb; they are cloned, not modified.
        atoms_list: Atoms objects matching ``graph_list`` one-to-one.
        perturbation_size: Half-width of the uniform jitter; no jitter is
            applied when it is not positive.
        delta: Accepted for signature compatibility; not used here.
        SUPERCELL_SIZE: Forwarded to ``compute_node_features``.

    Returns:
        List of perturbed graph clones, in input order.
    """
    jittered_graphs = []
    for source_graph, source_atoms in zip(graph_list, atoms_list):
        # Work on a clone so the unperturbed graph stays intact
        new_graph = source_graph.clone()

        # Private copy of the original coordinates
        coords = source_atoms.get_positions().copy()
        if perturbation_size > 0:
            coords += np.random.uniform(-perturbation_size, perturbation_size, coords.shape)

        # Structure carrying the jittered coordinates, periodic in x and y only
        jittered_atoms = source_atoms.copy()
        jittered_atoms.set_positions(coords)
        jittered_atoms.pbc = [True, True, False]

        # Connectivity is reused as-is; only geometry-dependent tensors change
        new_graph.edge_attr = compute_edge_attr(jittered_atoms, source_graph.edge_index)
        new_graph.x = compute_node_features(jittered_atoms, SUPERCELL_SIZE)

        jittered_graphs.append(new_graph)

    return jittered_graphs

# Path to the supercell files folder
supercell_folder = 'supercells_flatband'

# NOTE(review): SUPERCELL_SIZE is not assigned in this cell; presumably it is
# defined by an earlier cell -- confirm before running this cell on its own.

# Lists to store the graph objects and the corresponding atoms
graph_list_unperturbed = []
atoms_list_unperturbed = []

# os.listdir() returns entries in arbitrary order; sorting keeps the graph and
# atoms lists in a reproducible, filename-ordered sequence.
for filename in sorted(os.listdir(supercell_folder)):
    if filename.endswith('.xyz'):
        # Read the structure file
        filepath = os.path.join(supercell_folder, filename)
        atoms = read(filepath)
        atoms.pbc = [True, True, False]  # Periodic along x and y only

        # The label is the last number in the filename (None if there is none)
        match = re.findall(r'\d+', filename)
        if match:
            label = int(match[-1])
        else:
            label = None

        # Create graph with PBCs
        graph = create_graph_from_structure(atoms, delta, SUPERCELL_SIZE)
        graph.label = label

        # Keep graph and atoms at matching positions in the two lists
        graph_list_unperturbed.append(graph)
        atoms_list_unperturbed.append(atoms)

        print(f"Graph for {filename} created and added to the list with label {label}.")

print("Unperturbed graph list created with periodic boundary conditions accounted for.")

# First perturbed graph list
graph_list_set_1 = perturb_graphs(graph_list_unperturbed, atoms_list_unperturbed, perturbation_size, delta, SUPERCELL_SIZE)

# Second perturbed graph list (an independent random draw from the same inputs)
graph_list_set_2 = perturb_graphs(graph_list_unperturbed, atoms_list_unperturbed, perturbation_size, delta, SUPERCELL_SIZE)

# No scaling or masking happens in this cell, so the message does not claim it
print("Three graph lists created: unperturbed, perturbed set 1, and perturbed set 2.")

In [None]:
# Graph generator: next nearest neighbour and PBC, no atomic number, perturbation after edge connection, masking and scaling applied

import os
import re
import torch
import numpy as np
from ase.io import read
from ase.geometry import find_mic
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler

# Tolerance added to the nearest (d1) and next-nearest (d2) neighbour distance
# in compute_edge_index: atoms within d1 + delta / d2 + delta form each shell.
delta = 0.1  # Adjust as needed

# Half-width of the uniform random offset that perturb_graphs adds to every
# coordinate (offsets are drawn from [-perturbation_size, perturbation_size)).
perturbation_size = 0.05  # Adjust as needed

# Function to compute edge_index based on positions and delta, considering PBCs
def compute_edge_index(atoms, delta):
    """Build a directed edge list linking each atom to its two nearest shells.

    For every atom ``i`` the minimum-image distances to all *other* atoms are
    sorted. Atoms within ``d1 + delta`` of ``i`` (``d1`` is the smallest
    distance) form the nearest-neighbour shell; among the remaining atoms,
    those within ``d2 + delta`` (``d2`` is the smallest remaining distance)
    form the next-nearest shell. An edge ``(j, i)`` is added for every atom
    ``j`` in either shell.

    Args:
        atoms: Structure exposing ``get_positions()`` and
            ``get_all_distances(mic=True)`` (e.g. an ASE ``Atoms`` object).
        delta: Tolerance added to each shell distance.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)``; row 0 holds the
        neighbour index ``j`` and row 1 the central atom ``i``. The shape is
        ``(2, 0)`` when there are no edges.
    """
    num_atoms = len(atoms.get_positions())

    # Pairwise distances under the minimum-image convention; the cell and PBC
    # flags are consumed by get_all_distances itself, so they are not read here
    dist_matrix = atoms.get_all_distances(mic=True)
    edges = set()  # A set avoids duplicate edges

    for i in range(num_atoms):
        # Mark the self-distance so it can never be picked as a neighbour
        dist_matrix[i, i] = np.inf
        distances = dist_matrix[i]

        order = np.argsort(distances)
        # Drop the self entry (infinite distance). If it were kept, a structure
        # in which every other atom already sits in the nearest shell would use
        # inf as the next shell distance and emit a self-loop (i, i).
        order = order[np.isfinite(distances[order])]
        sorted_distances = distances[order]
        if sorted_distances.size == 0:
            continue  # No other atom: nothing to connect

        # Nearest-neighbour shell: everything within d1 + delta
        in_first_shell = sorted_distances <= sorted_distances[0] + delta
        for j in order[in_first_shell]:
            edges.add((int(j), i))

        # Next-nearest shell is searched among the atoms left over
        remaining_indices = order[~in_first_shell]
        remaining_distances = sorted_distances[~in_first_shell]
        if remaining_distances.size > 0:
            nnn_cutoff = remaining_distances[0] + delta
            for j in remaining_indices[remaining_distances <= nnn_cutoff]:
                edges.add((int(j), i))

    if not edges:
        # Keep the (2, 0) layout expected for an edge-less graph
        return torch.empty((2, 0), dtype=torch.long)

    return torch.tensor(list(edges), dtype=torch.long).t().contiguous()

# Function to compute edge attributes based on positions and edge_index, considering PBCs
def compute_edge_attr(atoms, edge_index):
    """Return the length of every edge as a one-column edge attribute.

    The displacement ``position[row] - position[col]`` is evaluated in
    fractional coordinates, wrapped to the nearest periodic image, mapped back
    to Cartesian coordinates with the cell matrix, and reduced to its norm.

    Args:
        atoms: Structure exposing ``get_cell()``, ``get_scaled_positions()``
            and ``get_pbc()`` (e.g. an ASE ``Atoms`` object).
        edge_index: Tensor of shape ``(2, num_edges)``; its first row is used
            as ``row`` and its second as ``col``.

    Returns:
        ``torch.float`` tensor of shape ``(num_edges, 1)`` with edge lengths.
    """
    row, col = edge_index
    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()
    periodic = np.asarray(atoms.get_pbc(), dtype=bool)

    # Fractional displacement of every edge
    delta_scaled = scaled_positions[row.numpy()] - scaled_positions[col.numpy()]

    # Minimum image convention, applied only along periodic axes: wrapping a
    # non-periodic axis would fold real separations back into the cell and
    # disagree with the mic distances used to select the edges.
    delta_scaled[:, periodic] -= np.round(delta_scaled[:, periodic])

    displacement_vectors = delta_scaled @ cell
    edge_distances = np.linalg.norm(displacement_vectors, axis=1)

    # unsqueeze(1): one feature column per edge
    return torch.tensor(edge_distances, dtype=torch.float).unsqueeze(1)

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta):
    """Assemble a PyTorch Geometric ``Data`` graph for one atomic structure.

    Node features are a dummy all-ones column (one row per atom); edges and
    their attributes come from ``compute_edge_index`` / ``compute_edge_attr``.

    Args:
        atoms: Structure supporting ``len()`` and accepted by the edge helpers.
        delta: Neighbour tolerance forwarded to ``compute_edge_index``.

    Returns:
        ``Data`` with ``x``, ``edge_index`` and ``edge_attr``.
    """
    # Connectivity first, since the edge attributes are evaluated on it
    connectivity = compute_edge_index(atoms, delta)
    edge_lengths = compute_edge_attr(atoms, connectivity)

    # Placeholder node features: a single constant per atom
    placeholder_features = torch.ones((len(atoms), 1), dtype=torch.float)

    return Data(
        x=placeholder_features,
        edge_index=connectivity,
        edge_attr=edge_lengths,
    )

# Function to apply perturbations to a graph list
def perturb_graphs(graph_list, atoms_list, perturbation_size):
    """Return perturbed copies of the graphs in ``graph_list``.

    Each graph is paired with the atoms object at the same position in
    ``atoms_list``. The atomic positions are jittered with independent uniform
    offsets in ``[-perturbation_size, perturbation_size)`` and the edge
    attributes are recomputed from the jittered structure on the original
    graph's ``edge_index``. Node features are left as cloned.

    Args:
        graph_list: Graphs to perturb; they are cloned, not modified.
        atoms_list: Atoms objects matching ``graph_list`` one-to-one.
        perturbation_size: Half-width of the uniform jitter; no jitter is
            applied when it is not positive.

    Returns:
        List of perturbed graph clones, in input order.
    """
    jittered_graphs = []
    for source_graph, source_atoms in zip(graph_list, atoms_list):
        # Work on a clone so the unperturbed graph stays intact
        new_graph = source_graph.clone()

        # Private copy of the original coordinates
        coords = source_atoms.get_positions().copy()
        if perturbation_size > 0:
            coords += np.random.uniform(-perturbation_size, perturbation_size, coords.shape)

        # Structure carrying the jittered coordinates, periodic in x and y only
        jittered_atoms = source_atoms.copy()
        jittered_atoms.set_positions(coords)
        jittered_atoms.pbc = [True, True, False]

        # Connectivity is reused as-is; only the edge attributes change
        new_graph.edge_attr = compute_edge_attr(jittered_atoms, source_graph.edge_index)

        jittered_graphs.append(new_graph)

    return jittered_graphs

# Function to scale node features and edge attributes in each graph list
def scale_graphs(graph_list):
    """Standardize node features and edge attributes of a graph list in place.

    One ``StandardScaler`` is fitted on the node features (``graph.x``) of all
    graphs concatenated, and a second one on all non-empty ``graph.edge_attr``
    tensors concatenated; each graph is then transformed with these shared
    scalers.  A graph whose ``edge_attr`` is ``None`` or has no rows ends up
    with an empty float tensor of shape ``(0, 1)``.

    Args:
        graph_list: list of graphs exposing tensor attributes ``x`` and
            ``edge_attr``.  An empty list is accepted and is a no-op.

    Returns:
        None.  The graphs are modified in place.
    """
    # Nothing to fit on; torch.cat() would raise on an empty list of tensors.
    if not graph_list:
        return

    node_blocks = [graph.x for graph in graph_list]
    # Only graphs that actually have edges contribute to the edge statistics
    edge_blocks = [
        graph.edge_attr
        for graph in graph_list
        if graph.edge_attr is not None and graph.edge_attr.shape[0] > 0
    ]

    node_scaler = StandardScaler()
    node_scaler.fit(torch.cat(node_blocks, dim=0).numpy())

    edge_scaler = None  # stays None when no graph has any edge
    if edge_blocks:
        edge_scaler = StandardScaler()
        edge_scaler.fit(torch.cat(edge_blocks, dim=0).numpy())

    for graph in graph_list:
        graph.x = torch.tensor(node_scaler.transform(graph.x.numpy()), dtype=torch.float)

        has_edges = graph.edge_attr is not None and graph.edge_attr.shape[0] > 0
        if has_edges and edge_scaler is not None:
            scaled_edges = edge_scaler.transform(graph.edge_attr.numpy())
            graph.edge_attr = torch.tensor(scaled_edges, dtype=torch.float)
        else:
            # No edges: normalise edge_attr to an empty (0, 1) tensor
            graph.edge_attr = torch.tensor([], dtype=torch.float).reshape(0, 1)

# Function to mask a percentage of node features and edge attributes in each graph
def mask_graphs(graph_list, masking_percentage=MASKING_PERCENTAGE):
    """Zero out a random subset of node-feature rows and edge-attribute rows.

    Operates in place on every graph in ``graph_list``.  For a tensor with
    ``n`` rows, ``ceil(masking_percentage * n)`` rows are picked without
    replacement, but never fewer than one when ``n > 0``; nothing is picked
    when ``n == 0``.  Node rows are drawn before edge rows, both with
    ``np.random.choice``.
    """
    def rows_to_mask(total):
        # At least one row is masked whenever any row exists
        if total > 0:
            return max(1, int(np.ceil(masking_percentage * total)))
        return 0

    for graph in graph_list:
        # Node features
        node_total = graph.x.shape[0]
        hidden_nodes = rows_to_mask(node_total)
        if hidden_nodes > 0:
            picked_nodes = np.random.choice(node_total, hidden_nodes, replace=False)
            graph.x[picked_nodes] = 0

        # Edge attributes
        edge_total = graph.edge_attr.shape[0]
        hidden_edges = rows_to_mask(edge_total)
        if hidden_edges > 0:
            picked_edges = np.random.choice(edge_total, hidden_edges, replace=False)
            graph.edge_attr[picked_edges] = 0

# Path to the supercell files folder
supercell_folder = 'supercells_flatband'

# Lists to store the graph objects and the corresponding atoms
graph_list_unperturbed = []
atoms_list_unperturbed = []

# Iterate over all XYZ files in the supercells folder.
# os.listdir() returns names in arbitrary order, so they are sorted to keep
# the order of the graph/atoms lists reproducible between runs and machines.
for filename in sorted(os.listdir(supercell_folder)):
    if filename.endswith('.xyz'):
        # Read the structure file
        filepath = os.path.join(supercell_folder, filename)
        atoms = read(filepath)
        atoms.pbc = [True, True, False]  # Periodic along the first two cell vectors only

        # The label is the last run of digits in the filename (None if there is none)
        match = re.findall(r'\d+', filename)
        if match:
            label = int(match[-1])  # Get the last number
        else:
            label = None  # Handle cases where no number is found

        # Create graph with PBCs
        graph = create_graph_from_structure(atoms, delta)
        graph.label = label

        # Append the graph and atoms to their respective lists
        graph_list_unperturbed.append(graph)
        atoms_list_unperturbed.append(atoms)

        print(f"Graph for {filename} created and added to the list with label {label}.")

print("Unperturbed graph list created with periodic boundary conditions accounted for.")

# Two independently perturbed copies of the unperturbed graphs
graph_list_set_1 = perturb_graphs(graph_list_unperturbed, atoms_list_unperturbed, perturbation_size)
graph_list_set_2 = perturb_graphs(graph_list_unperturbed, atoms_list_unperturbed, perturbation_size)

# Apply scaling to each graph list (each list is standardized on its own statistics)
scale_graphs(graph_list_unperturbed)
scale_graphs(graph_list_set_1)
scale_graphs(graph_list_set_2)

# Apply masking to the perturbed graph lists only
mask_graphs(graph_list_set_1)
mask_graphs(graph_list_set_2)

print("Three graph lists created: unperturbed, perturbed set 1, and perturbed set 2 with scaling and masking applied.")

In [None]:
#Graph generator: next nearest neighbour and PBC, perturbation after edge connection, masking and scaling applied

import os
import re
import torch
import numpy as np
from ase.io import read
from ase.geometry import find_mic
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler

# Tolerance added to the nearest / next-nearest neighbour distance when
# compute_edge_index collects each neighbour shell
delta = 0.1  # Adjust as needed

# Half-width of the uniform random displacement that perturb_graphs applies to
# every atomic coordinate
perturbation_size = 0.05  # Adjust as needed

# Function to compute edge_index based on positions and delta, considering PBCs
def compute_edge_index(atoms, delta):
    """Connect every atom to its nearest and next-nearest neighbour shells.

    Distances come from ``atoms.get_all_distances(mic=True)``.  For each atom
    ``i`` the first shell is every other atom within ``d1 + delta``, where
    ``d1`` is the distance to its closest neighbour.  The second shell is built
    the same way from the atoms that are left (``d2 + delta``).  An edge
    ``(j, i)`` is stored for every shell member ``j``; edges point towards
    ``i``, so the result is not necessarily symmetric.

    An atom is never its own neighbour.  This matters when no other atom is
    left for a shell (a single-atom structure, or all other atoms already in
    the first shell): in that case the shell is simply empty instead of
    producing a self-loop.

    Args:
        atoms: structure providing ``get_all_distances(mic=True)`` (an ASE
            ``Atoms`` object in this notebook).
        delta: tolerance added to the shell distance when collecting a shell.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)`` whose columns are
        ``(source j, target i)`` pairs in lexicographic order, or an empty
        ``(2, 0)`` tensor when there are no edges.
    """
    # Work on a private float copy so the diagonal can be overwritten safely
    dist_matrix = np.array(atoms.get_all_distances(mic=True), dtype=float)
    num_atoms = dist_matrix.shape[0]
    np.fill_diagonal(dist_matrix, np.inf)  # exclude self-distances

    edges = set()  # set: an edge found twice is stored once
    for i in range(num_atoms):
        distances = dist_matrix[i]

        # Only atoms at a finite distance can be neighbours; this drops atom i itself
        candidates = np.flatnonzero(np.isfinite(distances))
        if candidates.size == 0:
            continue
        candidate_distances = distances[candidates]

        # First shell: everything within d1 + delta
        nn_cutoff = candidate_distances.min() + delta
        in_first_shell = candidate_distances <= nn_cutoff
        edges.update((int(j), i) for j in candidates[in_first_shell])

        # Second shell: same rule applied to the atoms outside the first shell
        leftover = candidates[~in_first_shell]
        if leftover.size > 0:
            leftover_distances = candidate_distances[~in_first_shell]
            nnn_cutoff = leftover_distances.min() + delta
            edges.update((int(j), i) for j in leftover[leftover_distances <= nnn_cutoff])

    if not edges:
        # Handle graphs with no edges
        return torch.empty((2, 0), dtype=torch.long)

    # Sorting gives an edge order that does not depend on set iteration order
    return torch.tensor(sorted(edges), dtype=torch.long).t().contiguous()

# Function to compute edge attributes based on positions and edge_index, considering PBCs
def compute_edge_attr(atoms, edge_index):
    """Return the minimum-image length of every edge as a ``(num_edges, 1)`` tensor.

    The Cartesian separation of each edge's two atoms is reduced with ASE's
    ``find_mic`` using the cell and the periodicity flags of ``atoms``.
    ``compute_edge_index`` selects the edges from
    ``atoms.get_all_distances(mic=True)``, so using the minimum-image
    convention here keeps the edge attributes consistent with the distances
    used to choose the edges, including for non-orthogonal cells.

    Args:
        atoms: ASE ``Atoms`` object (positions, cell and pbc are read).
        edge_index: ``torch.long`` tensor of shape ``(2, num_edges)``.

    Returns:
        ``torch.float`` tensor of shape ``(num_edges, 1)`` with one distance
        per edge, in the order of the columns of ``edge_index``.
    """
    row, col = edge_index
    if row.numel() == 0:
        # No edges: nothing to measure
        return torch.empty((0, 1), dtype=torch.float)

    positions = atoms.get_positions()
    raw_vectors = positions[row.numpy()] - positions[col.numpy()]

    # find_mic returns (minimum-image vectors, their lengths)
    _, edge_distances = find_mic(raw_vectors, atoms.get_cell(), atoms.get_pbc())
    return torch.tensor(edge_distances, dtype=torch.float).unsqueeze(1)

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta):
    """Assemble a PyTorch Geometric ``Data`` graph for one structure.

    The connectivity is whatever ``compute_edge_index(atoms, delta)`` returns
    and the edge attributes are ``compute_edge_attr`` evaluated on that
    connectivity.  Each node carries a single feature, the atomic number of
    the corresponding atom, stored as a float.
    """
    connectivity = compute_edge_index(atoms, delta)
    edge_features = compute_edge_attr(atoms, connectivity)

    # One row per atom, one column holding its atomic number
    atomic_numbers = [[atom.number] for atom in atoms]
    x = torch.tensor(atomic_numbers, dtype=torch.float)

    return Data(
        x=x,
        edge_index=connectivity,
        edge_attr=edge_features,
    )

# Function to apply perturbations to a graph list
def perturb_graphs(graph_list, atoms_list, perturbation_size):
    """Return perturbed copies of the graphs; the inputs are left untouched.

    Graphs and structures are paired positionally.  For each pair the graph is
    cloned, every Cartesian coordinate of a copy of the structure is displaced
    by a value drawn uniformly between ``-perturbation_size`` and
    ``+perturbation_size`` (skipped when ``perturbation_size`` is not
    positive), and the clone's ``edge_attr`` is recomputed with
    ``compute_edge_attr`` on the displaced structure.  The connectivity
    (``edge_index``) of the original graph is reused unchanged.
    """
    jittered_graphs = []
    for base_graph, structure in zip(graph_list, atoms_list):
        jittered = base_graph.clone()

        # Private copy of the coordinates, displaced in place
        coords = structure.get_positions().copy()
        if perturbation_size > 0:
            coords += np.random.uniform(-perturbation_size, perturbation_size, coords.shape)

        # Displaced structure, periodic in all directions
        shaken = structure.copy()
        shaken.set_positions(coords)
        shaken.pbc = True

        # Same edges as the unperturbed graph, new attributes
        jittered.edge_attr = compute_edge_attr(shaken, base_graph.edge_index)
        jittered_graphs.append(jittered)

    return jittered_graphs

# Function to scale node features and edge attributes in each graph list
def scale_graphs(graph_list):
    """Standardize node features and edge attributes of a graph list in place.

    One ``StandardScaler`` is fitted on the node features (``graph.x``) of all
    graphs concatenated, and a second one on all non-empty ``graph.edge_attr``
    tensors concatenated; each graph is then transformed with these shared
    scalers.  A graph whose ``edge_attr`` is ``None`` or has no rows ends up
    with an empty float tensor of shape ``(0, 1)``.

    Args:
        graph_list: list of graphs exposing tensor attributes ``x`` and
            ``edge_attr``.  An empty list is accepted and is a no-op.

    Returns:
        None.  The graphs are modified in place.
    """
    # Nothing to fit on; torch.cat() would raise on an empty list of tensors.
    if not graph_list:
        return

    node_blocks = [graph.x for graph in graph_list]
    # Only graphs that actually have edges contribute to the edge statistics
    edge_blocks = [
        graph.edge_attr
        for graph in graph_list
        if graph.edge_attr is not None and graph.edge_attr.shape[0] > 0
    ]

    node_scaler = StandardScaler()
    node_scaler.fit(torch.cat(node_blocks, dim=0).numpy())

    edge_scaler = None  # stays None when no graph has any edge
    if edge_blocks:
        edge_scaler = StandardScaler()
        edge_scaler.fit(torch.cat(edge_blocks, dim=0).numpy())

    for graph in graph_list:
        graph.x = torch.tensor(node_scaler.transform(graph.x.numpy()), dtype=torch.float)

        has_edges = graph.edge_attr is not None and graph.edge_attr.shape[0] > 0
        if has_edges and edge_scaler is not None:
            scaled_edges = edge_scaler.transform(graph.edge_attr.numpy())
            graph.edge_attr = torch.tensor(scaled_edges, dtype=torch.float)
        else:
            # No edges: normalise edge_attr to an empty (0, 1) tensor
            graph.edge_attr = torch.tensor([], dtype=torch.float).reshape(0, 1)

# Function to mask a percentage of node features and edge attributes in each graph
def mask_graphs(graph_list, masking_percentage=0.15):
    """Zero out a random subset of node-feature rows and edge-attribute rows.

    Operates in place on every graph in ``graph_list``.  For a tensor with
    ``n`` rows, ``ceil(masking_percentage * n)`` rows are picked without
    replacement, but never fewer than one when ``n > 0``; nothing is picked
    when ``n == 0``.  Node rows are drawn before edge rows, both with
    ``np.random.choice``.
    """
    def rows_to_mask(total):
        # At least one row is masked whenever any row exists
        if total > 0:
            return max(1, int(np.ceil(masking_percentage * total)))
        return 0

    for graph in graph_list:
        # Node features
        node_total = graph.x.shape[0]
        hidden_nodes = rows_to_mask(node_total)
        if hidden_nodes > 0:
            picked_nodes = np.random.choice(node_total, hidden_nodes, replace=False)
            graph.x[picked_nodes] = 0

        # Edge attributes
        edge_total = graph.edge_attr.shape[0]
        hidden_edges = rows_to_mask(edge_total)
        if hidden_edges > 0:
            picked_edges = np.random.choice(edge_total, hidden_edges, replace=False)
            graph.edge_attr[picked_edges] = 0

# Path to the supercell files folder
supercell_folder = 'supercells'

# Lists to store the graph objects and the corresponding atoms
graph_list_unperturbed = []
atoms_list_unperturbed = []

# Iterate over all XYZ files in the supercells folder.
# os.listdir() returns names in arbitrary order, so they are sorted to keep
# the order of the graph/atoms lists reproducible between runs and machines.
for filename in sorted(os.listdir(supercell_folder)):
    if filename.endswith('.xyz'):
        # Read the structure file
        filepath = os.path.join(supercell_folder, filename)
        atoms = read(filepath)
        atoms.pbc = True  # Ensure PBCs are enabled

        # The label is the first run of digits in the filename (None if there is none)
        match = re.search(r'\d+', filename)
        if match:
            label = int(match.group())
        else:
            label = None  # Handle cases where no number is found

        # Create graph with PBCs
        graph = create_graph_from_structure(atoms, delta)
        graph.label = label

        # Append the graph and atoms to their respective lists
        graph_list_unperturbed.append(graph)
        atoms_list_unperturbed.append(atoms)

        print(f"Graph for {filename} created and added to the list with label {label}.")

print("Unperturbed graph list created with periodic boundary conditions accounted for.")

# Two independently perturbed copies of the unperturbed graphs
graph_list_set_1 = perturb_graphs(graph_list_unperturbed, atoms_list_unperturbed, perturbation_size)
graph_list_set_2 = perturb_graphs(graph_list_unperturbed, atoms_list_unperturbed, perturbation_size)

# Apply scaling to each graph list (each list is standardized on its own statistics)
scale_graphs(graph_list_unperturbed)
scale_graphs(graph_list_set_1)
scale_graphs(graph_list_set_2)

# Apply masking to the perturbed graph lists only
mask_graphs(graph_list_set_1)
mask_graphs(graph_list_set_2)

print("Three graph lists created: unperturbed, perturbed set 1, and perturbed set 2 with scaling and masking applied.")

In [None]:
#Graph generator: next nearest neighbour, perturbation after edge connection, masking and scaling applied

import re
import os
import torch
import numpy as np
from ase.io import read
from scipy.spatial import distance_matrix
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler

# Tolerance added to the nearest / next-nearest neighbour distance when
# compute_edge_index collects each neighbour shell
delta = 0.1  # Adjust as needed

# Function to compute edge_index based on positions and delta
def compute_edge_index(positions, delta):
    """Connect every atom to its nearest and next-nearest neighbour shells.

    Distances are plain Euclidean distances between the rows of ``positions``
    (no periodic images).  For each atom ``i`` the first shell is every other
    atom within ``d1 + delta``, where ``d1`` is the distance to its closest
    neighbour.  The second shell is built the same way from the atoms that
    are left (``d2 + delta``).  An edge ``(j, i)`` is stored for every shell
    member ``j``; edges point towards ``i``, so the result is not necessarily
    symmetric.

    An atom is never its own neighbour.  This matters when no other atom is
    left for a shell (a single-atom structure, or all other atoms already in
    the first shell): in that case the shell is simply empty instead of
    producing a self-loop.

    Args:
        positions: array of shape ``(num_atoms, dim)`` with atomic coordinates.
        delta: tolerance added to the shell distance when collecting a shell.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)`` whose columns are
        ``(source j, target i)`` pairs in lexicographic order, or an empty
        ``(2, 0)`` tensor when there are no edges.
    """
    # Calculate the distance matrix between all atoms
    dist_matrix = distance_matrix(positions, positions)
    num_atoms = len(positions)
    np.fill_diagonal(dist_matrix, np.inf)  # exclude self-distances

    edges = set()  # set: an edge found twice is stored once
    for i in range(num_atoms):
        distances = dist_matrix[i]

        # Only atoms at a finite distance can be neighbours; this drops atom i itself
        candidates = np.flatnonzero(np.isfinite(distances))
        if candidates.size == 0:
            continue
        candidate_distances = distances[candidates]

        # First shell: everything within d1 + delta
        nn_cutoff = candidate_distances.min() + delta
        in_first_shell = candidate_distances <= nn_cutoff
        edges.update((int(j), i) for j in candidates[in_first_shell])

        # Second shell: same rule applied to the atoms outside the first shell
        leftover = candidates[~in_first_shell]
        if leftover.size > 0:
            leftover_distances = candidate_distances[~in_first_shell]
            nnn_cutoff = leftover_distances.min() + delta
            edges.update((int(j), i) for j in leftover[leftover_distances <= nnn_cutoff])

    if not edges:
        # Handle graphs with no edges
        return torch.empty((2, 0), dtype=torch.long)

    # Sorting gives an edge order that does not depend on set iteration order
    return torch.tensor(sorted(edges), dtype=torch.long).t().contiguous()

# Function to compute edge attributes based on positions and edge_index
def compute_edge_attr(positions, edge_index):
    """Return the Euclidean length of every edge as a ``(num_edges, 1)`` float tensor.

    ``edge_index`` is unpacked into its two rows, which are used to index
    ``positions``; the distance is measured directly between the two indexed
    positions (no periodic images).
    """
    sources, targets = edge_index
    separation = positions[sources] - positions[targets]
    lengths = np.linalg.norm(separation, axis=1)
    return torch.tensor(lengths, dtype=torch.float).unsqueeze(1)

# Function to create a PyTorch Geometric Data object from an atomic structure with perturbation
def create_graph_from_structure(atoms, perturbation_size, delta, edge_index=None):
    """Build a ``Data`` graph for ``atoms`` with optionally perturbed edge lengths.

    When ``edge_index`` is not supplied it is computed from the unperturbed
    positions with ``compute_edge_index``.  If ``perturbation_size`` is
    positive, every coordinate is then displaced by a value drawn uniformly
    between ``-perturbation_size`` and ``+perturbation_size`` before the edge
    attributes are computed with ``compute_edge_attr``.  Node features are the
    atomic numbers as floats, one column per node (shape ``(num_atoms, 1)``).
    """
    coords = atoms.get_positions()

    # Connectivity always refers to the unperturbed geometry
    if edge_index is None:
        edge_index = compute_edge_index(coords, delta)

    # NOTE(review): the in-place update assumes get_positions() hands back a
    # copy; otherwise `atoms` itself would be displaced — confirm.
    if perturbation_size > 0:
        coords += np.random.uniform(-perturbation_size, perturbation_size, coords.shape)

    # Edge lengths measured on the (possibly displaced) coordinates
    edge_lengths = compute_edge_attr(coords, edge_index)

    # One row per atom, one column holding its atomic number
    atomic_numbers = [[atom.number] for atom in atoms]
    x = torch.tensor(atomic_numbers, dtype=torch.float)

    return Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_lengths,
    )

# Function to scale node features and edge attributes in each graph list
def scale_graphs(graph_list):
    """Standardize node features and edge attributes of a graph list in place.

    One ``StandardScaler`` is fitted on the node features (``graph.x``) of all
    graphs concatenated, and a second one on all non-empty ``graph.edge_attr``
    tensors concatenated; each graph is then transformed with these shared
    scalers.  A graph whose ``edge_attr`` is ``None`` or has no rows ends up
    with an empty float tensor of shape ``(0, 1)``.

    Args:
        graph_list: list of graphs exposing tensor attributes ``x`` and
            ``edge_attr``.  An empty list is accepted and is a no-op.

    Returns:
        None.  The graphs are modified in place.
    """
    # Nothing to fit on; torch.cat() would raise on an empty list of tensors.
    if not graph_list:
        return

    node_blocks = [graph.x for graph in graph_list]
    # Only graphs that actually have edges contribute to the edge statistics
    edge_blocks = [
        graph.edge_attr
        for graph in graph_list
        if graph.edge_attr is not None and graph.edge_attr.shape[0] > 0
    ]

    node_scaler = StandardScaler()
    node_scaler.fit(torch.cat(node_blocks, dim=0).numpy())

    edge_scaler = None  # stays None when no graph has any edge
    if edge_blocks:
        edge_scaler = StandardScaler()
        edge_scaler.fit(torch.cat(edge_blocks, dim=0).numpy())

    for graph in graph_list:
        graph.x = torch.tensor(node_scaler.transform(graph.x.numpy()), dtype=torch.float)

        has_edges = graph.edge_attr is not None and graph.edge_attr.shape[0] > 0
        if has_edges and edge_scaler is not None:
            scaled_edges = edge_scaler.transform(graph.edge_attr.numpy())
            graph.edge_attr = torch.tensor(scaled_edges, dtype=torch.float)
        else:
            # No edges: normalise edge_attr to an empty (0, 1) tensor
            graph.edge_attr = torch.tensor([], dtype=torch.float).reshape(0, 1)

# Function to mask a percentage of node features and edge attributes in each graph
def mask_graphs(graph_list, masking_percentage=MASKING_PERCENTAGE):
    """Zero out a random subset of node-feature rows and edge-attribute rows.

    Operates in place on every graph in ``graph_list``.  For a tensor with
    ``n`` rows, ``ceil(masking_percentage * n)`` rows are picked without
    replacement, but never fewer than one when ``n > 0``; nothing is picked
    when ``n == 0``.  Node rows are drawn before edge rows, both with
    ``np.random.choice``.
    """
    def rows_to_mask(total):
        # At least one row is masked whenever any row exists
        if total > 0:
            return max(1, int(np.ceil(masking_percentage * total)))
        return 0

    for graph in graph_list:
        # Node features
        node_total = graph.x.shape[0]
        hidden_nodes = rows_to_mask(node_total)
        if hidden_nodes > 0:
            picked_nodes = np.random.choice(node_total, hidden_nodes, replace=False)
            graph.x[picked_nodes] = 0

        # Edge attributes
        edge_total = graph.edge_attr.shape[0]
        hidden_edges = rows_to_mask(edge_total)
        if hidden_edges > 0:
            picked_edges = np.random.choice(edge_total, hidden_edges, replace=False)
            graph.edge_attr[picked_edges] = 0

# Path to the supercell files folder
supercell_folder = 'supercells'

# Lists to store the graph objects
graph_list_set_1 = []
graph_list_set_2 = []
graph_list_original = []

# Iterate over all XYZ files in the supercells folder.
# os.listdir() returns names in arbitrary order, so they are sorted to keep
# the order of the three graph lists reproducible between runs and machines.
for filename in sorted(os.listdir(supercell_folder)):
    if filename.endswith('.xyz'):
        # Read the structure file
        filepath = os.path.join(supercell_folder, filename)
        atoms = read(filepath)

        # The label is the first run of digits in the filename (None if there is none)
        match = re.search(r'\d+', filename)
        if match:
            label = int(match.group())
        else:
            label = None  # Handle cases where no number is found

        # Get positions (unperturbed)
        positions = atoms.get_positions()

        # Compute edge_index once, from the original positions, and share it
        # between the original and the perturbed graphs
        edge_index = compute_edge_index(positions, delta)

        # Create graph_original with no perturbation
        graph_original = create_graph_from_structure(atoms, 0, delta, edge_index=edge_index)
        graph_original.label = label

        # Create two graphs with different random perturbations
        graph_1 = create_graph_from_structure(atoms, 0.05, delta, edge_index=edge_index)
        graph_2 = create_graph_from_structure(atoms, 0.05, delta, edge_index=edge_index)
        graph_1.label = label
        graph_2.label = label

        # Append the graphs to the respective lists
        graph_list_set_1.append(graph_1)
        graph_list_set_2.append(graph_2)
        graph_list_original.append(graph_original)

        print(f"Graphs for {filename} created and added to the lists with label {label}.")

# Apply scaling to each graph list (each list is standardized on its own statistics)
scale_graphs(graph_list_set_1)
scale_graphs(graph_list_set_2)
scale_graphs(graph_list_original)

# Apply masking to the two perturbed graph lists only
mask_graphs(graph_list_set_1)
mask_graphs(graph_list_set_2)

print("Two graph sets created with different random perturbations, scaling, and masking applied.")

In [None]:
#Graph generator: lattice vector and intra unit cell and PBC, perturbation, scaling and masking applied

import os
import re
import torch
import numpy as np
from ase.io import read
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler
from scipy.spatial import cKDTree

# Tolerance used by compute_edge_index as the KD-tree search radius (applied to
# scaled coordinates) when matching lattice-translated atoms
delta = 0.1  # Adjust as needed

# Half-width of the uniform random displacement that perturb_graphs applies to
# every atomic coordinate
perturbation_size = 0.05  # Adjust as needed

# Function to compute edge_index based on positions and delta, considering PBCs
def _images_across_cell_faces(point, radius):
    """Return ``point`` plus its copies shifted by one cell across every face it is within ``radius`` of."""
    images = [point]
    for axis in range(len(point)):
        extra = []
        for image in images:
            if image[axis] < radius:
                # Close to the lower face: also look from just beyond the upper face
                shifted = image.copy()
                shifted[axis] += 1.0
                extra.append(shifted)
            if image[axis] > 1.0 - radius:
                # Close to the upper face: also look from just below the lower face
                shifted = image.copy()
                shifted[axis] -= 1.0
                extra.append(shifted)
        images.extend(extra)
    return images


# Function to compute edge_index based on positions and delta, considering PBCs
def compute_edge_index(atoms, delta, N):
    """Connect lattice-translated atoms and all atoms sharing a unit cell.

    Two kinds of undirected edges (stored in both directions) are created:

    * Lattice edges: each atom's scaled position is shifted by ``±1/N`` along
      the first and along the second scaled coordinate and wrapped back into
      ``[0, 1)``; every other atom within ``delta`` of the shifted position is
      connected to it.  ``delta`` is applied to scaled (fractional)
      coordinates here.  The KD-tree itself is not periodic, so a shifted
      position lying within ``delta`` of a cell face is also queried from its
      image on the opposite face; otherwise an atom at, say, 0.999 would never
      match a target at 0.001.
    * Unit-cell edges: atoms are grouped into consecutive blocks of
      ``num_atoms // (N * N)`` by their index and every pair inside a block is
      connected.  This assumes the atoms of each unit cell are stored
      contiguously in the input file.

    Args:
        atoms: structure providing ``get_scaled_positions()`` (an ASE
            ``Atoms`` object in this notebook).
        delta: search radius, in scaled coordinates, for the lattice edges.
        N: supercell size; the supercell is taken to contain ``N * N`` unit cells.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)`` with the edges as
        columns in lexicographic order, or an empty ``(2, 0)`` tensor.

    Raises:
        ValueError: if the number of atoms is not a multiple of ``N * N``.
    """
    scaled_positions = atoms.get_scaled_positions()
    num_atoms = len(scaled_positions)

    # Build a KDTree of the scaled positions
    tree = cKDTree(scaled_positions)

    # Define shifts in scaled coordinates (±a and ±b)
    shifts_scaled = np.array([
        [1/N, 0, 0],
        [-1/N, 0, 0],
        [0, 1/N, 0],
        [0, -1/N, 0],
    ])

    edges = set()

    # Lattice vector connections
    for i in range(num_atoms):
        for shift_scaled in shifts_scaled:
            candidate = np.mod(scaled_positions[i] + shift_scaled, 1.0)  # Apply PBCs
            for image in _images_across_cell_faces(candidate, delta):
                for idx in tree.query_ball_point(image, r=delta):
                    if idx != i:  # Exclude self-loops
                        edges.add((i, idx))
                        edges.add((idx, i))  # Add reverse edge

    # Determine the number of atoms per unit cell
    total_unit_cells = N * N  # For a 2D lattice
    atoms_per_unit_cell = num_atoms // total_unit_cells

    # Check for consistency
    if atoms_per_unit_cell * total_unit_cells != num_atoms:
        raise ValueError("Number of atoms per unit cell is not consistent with total atoms and supercell dimensions.")

    # Group atoms by unit cell based on their order in the .xyz file
    for start in range(0, num_atoms, atoms_per_unit_cell):
        group = range(start, start + atoms_per_unit_cell)
        # The double loop visits every ordered pair, so both directions are added
        for idx1 in group:
            for idx2 in group:
                if idx1 != idx2:
                    edges.add((idx1, idx2))

    if not edges:
        return torch.empty((2, 0), dtype=torch.long)

    # Sorting gives an edge order that does not depend on set iteration order
    return torch.tensor(sorted(edges), dtype=torch.long).t().contiguous()

# Function to compute edge attributes based on positions and edge_index, considering PBCs
def compute_edge_attr(atoms, edge_index):
    """Return the length of every edge as a ``(num_edges, 1)`` float tensor.

    The separation of each edge's two atoms is computed in scaled coordinates
    and wrapped to the nearest image by rounding, but only along the axes that
    ``atoms.get_pbc()`` reports as periodic.  Along a non-periodic axis (the
    third one when pbc is ``[True, True, False]``) the raw separation is kept,
    so atoms far apart in that direction are not reported as close.  The
    result is converted to Cartesian coordinates with the cell matrix.

    Args:
        atoms: structure providing ``get_scaled_positions()``, ``get_cell()``
            and ``get_pbc()`` (an ASE ``Atoms`` object in this notebook).
        edge_index: ``torch.long`` tensor of shape ``(2, num_edges)``.

    Returns:
        ``torch.float`` tensor of shape ``(num_edges, 1)`` with one distance
        per edge, in the order of the columns of ``edge_index``.
    """
    row, col = edge_index
    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()
    periodic = np.asarray(atoms.get_pbc(), dtype=bool)

    # Compute displacement vectors considering PBCs
    delta_scaled = scaled_positions[row.numpy()] - scaled_positions[col.numpy()]
    # Nearest-image wrap, restricted to the periodic axes
    delta_scaled[:, periodic] -= np.round(delta_scaled[:, periodic])
    displacement_vectors = delta_scaled @ cell

    edge_distances = np.linalg.norm(displacement_vectors, axis=1)
    return torch.tensor(edge_distances, dtype=torch.float).unsqueeze(1)

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta, N):
    """Assemble a PyTorch Geometric ``Data`` graph for one supercell structure.

    The connectivity is whatever ``compute_edge_index(atoms, delta, N)``
    returns and the edge attributes are ``compute_edge_attr`` evaluated on
    that connectivity.  Every node gets the same placeholder feature, a single
    1.0, so ``x`` has shape ``(len(atoms), 1)``.
    """
    connectivity = compute_edge_index(atoms, delta, N)
    edge_features = compute_edge_attr(atoms, connectivity)

    # Placeholder node features: one constant column of ones
    placeholder_x = torch.ones((len(atoms), 1), dtype=torch.float)

    return Data(
        x=placeholder_x,
        edge_index=connectivity,
        edge_attr=edge_features,
    )

# Function to apply perturbations to a graph list
def perturb_graphs(graph_list, atoms_list, perturbation_size):
    """Return perturbed copies of the graphs; the inputs are left untouched.

    Graphs and structures are paired positionally.  For each pair the graph is
    cloned, every Cartesian coordinate of a copy of the structure is displaced
    by a value drawn uniformly between ``-perturbation_size`` and
    ``+perturbation_size`` (skipped when ``perturbation_size`` is not
    positive), and the clone's ``edge_attr`` is recomputed with
    ``compute_edge_attr`` on the displaced structure.  The connectivity
    (``edge_index``) of the original graph is reused unchanged.
    """
    jittered_graphs = []
    for base_graph, structure in zip(graph_list, atoms_list):
        jittered = base_graph.clone()

        # Private copy of the coordinates, displaced in place
        coords = structure.get_positions().copy()
        if perturbation_size > 0:
            coords += np.random.uniform(-perturbation_size, perturbation_size, coords.shape)

        # Displaced structure, periodic along the first two cell vectors only
        shaken = structure.copy()
        shaken.set_positions(coords)
        shaken.pbc = [True, True, False]

        # Same edges as the unperturbed graph, new attributes
        jittered.edge_attr = compute_edge_attr(shaken, base_graph.edge_index)
        jittered_graphs.append(jittered)

    return jittered_graphs

# Function to scale edge attributes in each graph list
def scale_graphs(graph_list):
    """Standardize the edge attributes of a graph list in place.

    A single ``StandardScaler`` is fitted on all non-empty ``edge_attr``
    tensors of the list concatenated, and every graph that has edges is
    transformed with it.  A graph whose ``edge_attr`` is ``None`` or has no
    rows ends up with an empty float tensor of shape ``(0, 1)``.  Node
    features are not touched.
    """
    def has_edges(graph):
        return graph.edge_attr is not None and graph.edge_attr.shape[0] > 0

    # Gather the edge attributes that will define the scaling statistics
    pooled = [graph.edge_attr for graph in graph_list if has_edges(graph)]

    edge_scaler = None  # stays None when no graph has any edge
    if len(pooled) > 0:
        stacked = torch.cat(pooled, dim=0).numpy()
        edge_scaler = StandardScaler()
        edge_scaler.fit(stacked)

    for graph in graph_list:
        if not has_edges(graph) or edge_scaler is None:
            # No edges: normalise edge_attr to an empty (0, 1) tensor
            graph.edge_attr = torch.tensor([], dtype=torch.float).reshape(0, 1)
            continue
        rescaled = edge_scaler.transform(graph.edge_attr.numpy())
        graph.edge_attr = torch.tensor(rescaled, dtype=torch.float)

# Function to mask a percentage of edge attributes in each graph
def mask_graphs(graph_list, masking_percentage=MASKING_PERCENTAGE):
    """Zero out a random subset of edge-attribute rows in every graph, in place.

    For a graph with ``n`` edges, ``ceil(masking_percentage * n)`` rows of
    ``edge_attr`` are picked without replacement with ``np.random.choice``,
    but never fewer than one when ``n > 0``; a graph without edges is left
    unchanged.
    """
    for graph in graph_list:
        edge_total = graph.edge_attr.shape[0]
        if edge_total > 0:
            # At least one edge is masked whenever any edge exists
            hidden_edges = max(1, int(np.ceil(masking_percentage * edge_total)))
        else:
            hidden_edges = 0

        if hidden_edges > 0:
            picked_edges = np.random.choice(edge_total, hidden_edges, replace=False)
            graph.edge_attr[picked_edges] = 0

# Path to the supercell files folder
supercell_folder = 'supercells_flatband'

# Lists to store the graph objects and the corresponding atoms
graph_list_unperturbed = []
atoms_list_unperturbed = []

# Iterate over all XYZ files in the supercells folder.
# os.listdir() returns names in arbitrary order, so they are sorted to keep
# the order of the graph/atoms lists reproducible between runs and machines.
for filename in sorted(os.listdir(supercell_folder)):
    if filename.endswith('.xyz'):
        # Read the structure file
        filepath = os.path.join(supercell_folder, filename)
        atoms = read(filepath)
        atoms.pbc = [True, True, False]  # Periodic along the first two cell vectors only

        # The label is the first run of digits in the filename (None if there is none)
        match = re.search(r'\d+', filename)
        if match:
            label = int(match.group())
        else:
            label = None  # Handle cases where no number is found

        # Create graph with PBCs
        graph = create_graph_from_structure(atoms, delta, N=SUPERCELL_SIZE)
        graph.label = label

        # Append the graph and atoms to their respective lists
        graph_list_unperturbed.append(graph)
        atoms_list_unperturbed.append(atoms)

        print(f"Graph for {filename} created and added to the list with label {label}.")

print("Unperturbed graph list created with periodic boundary conditions accounted for.")

# Two independently perturbed copies of the unperturbed graphs
graph_list_set_1 = perturb_graphs(graph_list_unperturbed, atoms_list_unperturbed, perturbation_size)
graph_list_set_2 = perturb_graphs(graph_list_unperturbed, atoms_list_unperturbed, perturbation_size)

# Apply scaling to each graph list (each list is standardized on its own statistics)
scale_graphs(graph_list_unperturbed)
scale_graphs(graph_list_set_1)
scale_graphs(graph_list_set_2)

# Apply masking to the perturbed graph lists only
mask_graphs(graph_list_set_1)
mask_graphs(graph_list_set_2)

print("Three graph lists created: unperturbed, perturbed set 1, and perturbed set 2 with scaling and masking applied.")


In [None]:
#Graph generator: nearest neighbour and boundary overlap, perturbation after edge connection, masking and scaling applied
# Global variables should be defined before they are used in function defaults

import os
import re
import torch
import numpy as np
from ase.io import read
from scipy.spatial import distance_matrix, cKDTree
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler

# Function to compute edge_index and edge_type based on positions, delta, N, and supercell boundary edges
def compute_edge_index(positions, delta, atoms, N=SUPERCELL_SIZE):
    """Build directed nearest-neighbour edges plus supercell-boundary edges.

    Args:
        positions: array of atomic coordinates, one row per atom.
        delta: tolerance added to each atom's nearest-neighbour distance;
            every atom within ``min_dist + delta`` of atom ``i`` yields an
            edge ``j -> i``.
        atoms: structure object whose ``cell`` supplies the lattice vectors.
        N: divisor applied to ``cell[0]`` and ``cell[1]`` to obtain the
            vectors ``a`` and ``b`` used for the boundary shifts.

    Returns:
        ``(edge_index, edge_types)``: a ``(2, E)`` long tensor of
        ``[source, target]`` columns and an ``(E,)`` long tensor holding 0
        for internal edges and 1 for supercell-boundary edges.

    Raises:
        ValueError: if the lattice vectors are not available on ``atoms``.
    """
    num_atoms = len(positions)
    edge_index = []
    edge_types = []  # 0 for internal edges, 1 for supercell boundary edges

    # Get the lattice vectors a and b
    if atoms.cell is None or len(atoms.cell) < 2:
        raise ValueError("Lattice vectors a and b are not defined in the atoms object.")
    a = atoms.cell[0] / N
    b = atoms.cell[1] / N

    # Calculate the distance matrix between all atoms
    dist_matrix = distance_matrix(positions, positions)

    # For each node, find its nearest neighbor(s) within a tolerance
    for i in range(num_atoms):
        # Exclude self-distance by setting diagonal to infinity
        dist_matrix[i, i] = np.inf
        min_dist = np.min(dist_matrix[i])
        if min_dist == np.inf:
            # No other atom exists; without this guard `inf <= inf` would
            # select the atom itself and create a self-loop.
            continue
        cutoff_dist = min_dist + delta
        for j in np.where(dist_matrix[i] <= cutoff_dist)[0]:
            edge_index.append([int(j), i])
            edge_types.append(0)  # Internal edge

    # Now connect atoms whose positions coincide after a shift of (N - 1)
    # steps along +/-a or +/-b.
    shifts = [(N - 1) * a, -(N - 1) * a, (N - 1) * b, -(N - 1) * b]
    tolerance = 0.1  # 0.1 angstroms tolerance for matching positions

    # Build a KDTree for efficient nearest-neighbor search
    tree = cKDTree(positions)

    for shift in shifts:
        shifted_positions = positions + shift
        # For every shifted atom i, list the original atoms within `tolerance`
        indices_list = tree.query_ball_point(shifted_positions, tolerance)
        for i, indices in enumerate(indices_list):
            for j in indices:
                if i != j:
                    edge_index.append([j, i])
                    edge_types.append(1)  # Supercell boundary edge

    # Convert edge_index and edge_types to tensors
    if len(edge_index) > 0:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_types = torch.tensor(edge_types, dtype=torch.long)
    else:
        # Handle graphs with no edges
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_types = torch.empty((0,), dtype=torch.long)
    return edge_index, edge_types

# Function to compute edge attributes based on positions, edge_index, and edge_types
def compute_edge_attr(positions, edge_index, edge_types):
    """Return per-edge distances as an ``(E, 1)`` float tensor.

    Args:
        positions: array of atomic coordinates, one row per atom.
        edge_index: ``(2, E)`` collection of ``[source, target]`` node indices.
        edge_types: length-``E`` collection; edges marked 1 (supercell
            boundary) get a distance of 0.

    Returns:
        ``torch.float`` tensor of shape ``(E, 1)`` with the Euclidean distance
        between the endpoints of every edge (0 for boundary edges).
    """
    positions = np.asarray(positions)

    # Convert to plain NumPy index arrays first. Indexing a NumPy array with
    # a torch tensor directly is fragile: a one-element integer tensor can be
    # interpreted as a scalar index, which drops the edge axis and makes
    # `norm(..., axis=1)` fail for graphs with exactly one edge.
    endpoints = np.asarray(edge_index)
    src, dst = endpoints[0], endpoints[1]

    # Compute distances between connected nodes
    edge_distances = np.linalg.norm(positions[src] - positions[dst], axis=1)

    # Set edge distances to zero for supercell boundary edges
    edge_distances[np.asarray(edge_types) == 1] = 0.0

    return torch.tensor(edge_distances, dtype=torch.float).unsqueeze(1)

# Function to create a PyTorch Geometric Data object from an atomic structure with perturbation and fixed edge_index
def create_graph_from_structure(atoms, perturbation_size, delta, N=SUPERCELL_SIZE, edge_index=None, edge_types=None):
    """Turn an atomic structure into a PyTorch Geometric ``Data`` graph.

    Connectivity is taken from ``edge_index``/``edge_types`` when both are
    supplied; otherwise it is computed by ``compute_edge_index`` from the
    unperturbed positions. A uniform random displacement in
    ``[-perturbation_size, perturbation_size]`` is then added to every
    coordinate (skipped when ``perturbation_size`` is not positive), and the
    edge attributes are computed from the displaced positions.

    Returns:
        ``Data`` with ``x`` (one atomic number per node), ``edge_index`` and
        ``edge_attr``.
    """
    coords = atoms.get_positions()

    # Topology is only recomputed when the caller did not pass a complete one
    topology_missing = edge_index is None or edge_types is None
    if topology_missing:
        edge_index, edge_types = compute_edge_index(coords, delta, atoms, N)

    # In-place update of the local array.
    # NOTE(review): relies on get_positions() handing back a copy so that
    # `atoms` itself stays unperturbed - confirm.
    if perturbation_size > 0:
        noise = np.random.uniform(-perturbation_size, perturbation_size, coords.shape)
        coords += noise

    # Edge attributes are computed on the (possibly) perturbed coordinates
    distances = compute_edge_attr(coords, edge_index, edge_types)

    # One feature per node: the atomic number -> shape (num_atoms, 1)
    atomic_numbers = torch.tensor(
        [[atom.number] for atom in atoms], dtype=torch.float
    )

    return Data(
        x=atomic_numbers,
        edge_index=edge_index,
        edge_attr=distances,
    )

# Function to scale node features and edge attributes in each graph list
def scale_graphs(graph_list):
    """Standardise node features and edge attributes of every graph in place.

    One ``StandardScaler`` is fitted on the node features of all graphs in
    ``graph_list`` and another on all non-empty edge attributes; each graph's
    ``x`` and ``edge_attr`` are then replaced by the transformed values.
    Graphs without edge attributes end up with an empty ``(0, 1)`` float
    ``edge_attr``.

    Args:
        graph_list: sequence of graph objects exposing ``x`` and ``edge_attr``
            tensors. An empty sequence is a no-op.

    Returns:
        None. The graphs are modified in place.
    """
    # torch.cat raises on an empty sequence; nothing to fit or scale anyway
    if not graph_list:
        return

    def _has_edges(graph):
        return graph.edge_attr is not None and graph.edge_attr.shape[0] > 0

    # Fit the node scaler on the node features of all graphs stacked together
    node_scaler = StandardScaler()
    node_scaler.fit(torch.cat([graph.x for graph in graph_list], dim=0).numpy())

    # Fit the edge scaler only if at least one graph has edge attributes
    edge_blocks = [graph.edge_attr for graph in graph_list if _has_edges(graph)]
    edge_scaler = None
    if edge_blocks:
        edge_scaler = StandardScaler()
        edge_scaler.fit(torch.cat(edge_blocks, dim=0).numpy())

    # Apply the scaling to each graph
    for graph in graph_list:
        scaled_nodes = node_scaler.transform(graph.x.numpy())
        graph.x = torch.tensor(scaled_nodes, dtype=torch.float)

        if edge_scaler is not None and _has_edges(graph):
            scaled_edges = edge_scaler.transform(graph.edge_attr.numpy())
            graph.edge_attr = torch.tensor(scaled_edges, dtype=torch.float)
        else:
            # No edges: make sure edge_attr is a tensor of the expected shape
            graph.edge_attr = torch.tensor([], dtype=torch.float).reshape(0, 1)

# Function to mask a percentage of node features and edge attributes in each graph
def mask_graphs(graph_list, masking_percentage=MASKING_PERCENTAGE):
    """Zero out a random subset of node features and edge attributes in place.

    Args:
        graph_list: iterable of graph objects exposing ``x`` and ``edge_attr``
            tensors.
        masking_percentage: fraction of rows to zero in each tensor. For a
            positive fraction at least one row is masked and never more than
            the rows available; a fraction of 0 or less masks nothing.

    Returns:
        None. The graphs are modified in place.
    """
    def _choose(total):
        # Nothing to pick from, or masking disabled
        if total <= 0 or masking_percentage <= 0:
            return None
        wanted = max(1, int(np.ceil(masking_percentage * total)))
        # Sampling without replacement raises if more rows are requested
        # than exist, so clamp to the available count.
        return np.random.choice(total, min(wanted, total), replace=False)

    for graph in graph_list:
        # Mask node features first, then edge attributes
        node_mask_indices = _choose(graph.x.shape[0])
        if node_mask_indices is not None:
            graph.x[node_mask_indices] = 0

        edge_mask_indices = _choose(graph.edge_attr.shape[0])
        if edge_mask_indices is not None:
            graph.edge_attr[edge_mask_indices] = 0

# Path to the supercell files folder
supercell_folder = 'supercells'

# Lists to store the graph objects (index-aligned across the three lists)
graph_list_set_1 = []
graph_list_set_2 = []
graph_list_original = []

# Tolerance delta for including nearest neighbors
delta = 0.1  # Adjust this value as needed

# Iterate over all XYZ files in the supercells folder.
# os.listdir() returns names in arbitrary order, so sort them to build the
# lists in the same order on every run and platform.
for filename in sorted(os.listdir(supercell_folder)):
    if not filename.endswith('.xyz'):
        continue

    # Read the structure file
    filepath = os.path.join(supercell_folder, filename)
    atoms = read(filepath)

    # Use the first run of digits in the filename as the label (None if absent)
    match = re.search(r'\d+', filename)
    label = int(match.group()) if match else None

    # Get positions (unperturbed)
    positions = atoms.get_positions()

    # Compute the connectivity once from the original positions; the same
    # edge_index/edge_types objects are passed to all three graphs below.
    edge_index, edge_types = compute_edge_index(positions, delta, atoms, N=SUPERCELL_SIZE)

    # Create graph_original with no perturbation
    graph_original = create_graph_from_structure(atoms, 0, delta, N=SUPERCELL_SIZE, edge_index=edge_index, edge_types=edge_types)
    graph_original.label = label

    # Create two graphs with different perturbations and fixed edge_index and edge_types
    graph_1 = create_graph_from_structure(atoms, 0.05, delta, N=SUPERCELL_SIZE, edge_index=edge_index, edge_types=edge_types)
    graph_2 = create_graph_from_structure(atoms, 0.05, delta, N=SUPERCELL_SIZE, edge_index=edge_index, edge_types=edge_types)
    graph_1.label = label
    graph_2.label = label

    # Append the graphs to the respective lists
    graph_list_set_1.append(graph_1)
    graph_list_set_2.append(graph_2)
    graph_list_original.append(graph_original)

    print(f"Graphs for {filename} created and added to the lists with label {label}.")

# Apply scaling to each graph list (each list is scaled independently)
scale_graphs(graph_list_set_1)
scale_graphs(graph_list_set_2)
scale_graphs(graph_list_original)

# Apply masking to each graph list (except the original graphs)
mask_graphs(graph_list_set_1)
mask_graphs(graph_list_set_2)

print("Two graph sets created with different random perturbations, supercell boundary edges, scaling, and masking applied.")


In [None]:
#Graph generator: nearest neighbour and boundary overlap, no masking, perturbation or scaling

import os
import re
import torch
import numpy as np
from ase.io import read
from scipy.spatial import distance_matrix, cKDTree
from torch_geometric.data import Data

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta=0.1, N=4):
    """Build an undirected nearest-neighbour graph with supercell-boundary links.

    Every atom is linked (in both directions) to all atoms within
    ``min_dist + delta`` of it, with the distance as edge attribute. Atoms
    whose positions coincide (within 0.1) after a shift of ``N - 1`` steps
    along +/-a or +/-b, where ``a = cell[0] / N`` and ``b = cell[1] / N``,
    are additionally linked with an edge attribute of 0.

    Each directed edge is stored once: mutual neighbours and opposite shifts
    would otherwise append the same edge repeatedly.

    Args:
        atoms: structure providing ``get_positions()``, ``cell`` and per-atom
            ``mass``/``number``.
        delta: tolerance added to the nearest-neighbour distance.
        N: divisor applied to the first two cell vectors.

    Returns:
        ``Data`` with ``x`` of shape (num_atoms, 2) holding mass and atomic
        number, ``edge_index`` of shape (2, E) and ``edge_attr`` of shape
        (E, 1).

    Raises:
        ValueError: if the lattice vectors are not available on ``atoms``.
    """
    # Get the positions of the atoms
    positions = atoms.get_positions()
    num_atoms = len(positions)

    # Get the lattice vectors a and b
    if atoms.cell is None or len(atoms.cell) < 2:
        raise ValueError("Lattice vectors a and b are not defined in the atoms object.")
    a = (atoms.cell[0]) / N
    b = (atoms.cell[1]) / N

    # (source, target, kind) -> attribute, in first-insertion order.
    # kind 0 = internal edge, kind 1 = supercell boundary edge.
    edges = {}

    def _add_undirected(i, j, value, kind):
        edges.setdefault((i, j, kind), value)
        edges.setdefault((j, i, kind), value)

    # Calculate the distance matrix between all atoms
    dist_matrix = distance_matrix(positions, positions)

    # For each node, find its nearest neighbor(s) within a tolerance
    for i in range(num_atoms):
        # Exclude self-distance by setting diagonal to infinity
        dist_matrix[i, i] = np.inf
        min_dist = np.min(dist_matrix[i])
        if min_dist == np.inf:
            # No other atom exists; `inf <= inf` would otherwise add a self-loop
            continue
        cutoff_dist = min_dist + delta
        for j in np.where(dist_matrix[i] <= cutoff_dist)[0]:
            _add_undirected(i, int(j), float(dist_matrix[i, j]), 0)

    # Now connect edge atoms across the supercell boundaries
    shifts = [(N - 1) * a, -(N - 1) * a, (N - 1) * b, -(N - 1) * b]
    tolerance = 0.1  # 0.1 angstroms tolerance for matching positions

    # Build a KDTree for efficient nearest-neighbor search
    tree = cKDTree(positions)

    for shift in shifts:
        # For every shifted atom i, list the original atoms within `tolerance`
        indices_list = tree.query_ball_point(positions + shift, tolerance)
        for i, indices in enumerate(indices_list):
            for j in indices:
                if i != j:
                    # Edge attribute is zero since they are connected across the boundary
                    _add_undirected(i, j, 0.0, 1)

    # Convert edge_index and edge_attr to tensors
    if edges:
        edge_index = torch.tensor(
            [[src, dst] for src, dst, _ in edges], dtype=torch.long
        ).t().contiguous()
        edge_attr = torch.tensor(list(edges.values()), dtype=torch.float).unsqueeze(1)
    else:
        # Handle graphs with no edges
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 1), dtype=torch.float)

    # Create node features: atomic mass and atomic number
    node_features = torch.tensor(
        [[atom.mass, atom.number] for atom in atoms], dtype=torch.float
    )

    # Create PyTorch Geometric Data object
    graph = Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_attr,
    )

    return graph

# Path to the supercell files folder
supercell_folder = 'supercells'

# List to store the graph objects
graph_list = []

# Tolerance delta for including nearest neighbors
delta = 0.1  # Adjust this value as needed

# Supercell dimension N
N = 4  # Adjust this value as needed

# Iterate over all XYZ files in the supercells folder.
# os.listdir() returns names in arbitrary order, so sort them to build the
# list in the same order on every run and platform.
for filename in sorted(os.listdir(supercell_folder)):
    if not filename.endswith('.xyz'):
        continue

    # Read the structure file
    filepath = os.path.join(supercell_folder, filename)
    atoms = read(filepath)

    # Use the first run of digits in the filename as the label (None if absent)
    match = re.search(r'\d+', filename)
    label = int(match.group()) if match else None

    # Create the graph with the specified delta and N
    graph = create_graph_from_structure(atoms, delta=delta, N=N)

    # Add the label to the graph
    graph.label = label

    # Append the graph to the list
    graph_list.append(graph)

    print(f"Graph for {filename} created and added to the list with label {label}.")

print("Graph generation complete. Each node is connected to its nearest neighbor(s) within the specified tolerance, including periodic connections across supercell boundaries.")


In [None]:
#Graph generator: nearest neighbour, perturbation after edge connection, masking and scaling applied

import re
import os
import torch
import numpy as np
from ase.io import read
from scipy.spatial import distance_matrix
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler

# Tolerance added to each atom's nearest-neighbour distance when selecting its
# neighbours: compute_edge_index links every atom within min_dist + delta.
delta = 0.1  # Adjust as needed

# Function to compute edge_index based on positions and delta
def compute_edge_index(positions, delta):
    """Build directed nearest-neighbour edges from atomic positions.

    For every atom ``i`` the smallest distance to another atom is found and
    each atom ``j`` within ``min_dist + delta`` yields one directed edge
    ``j -> i``. The reverse edge only appears if ``i`` is in turn within
    ``j``'s own cutoff.

    Args:
        positions: array of atomic coordinates, one row per atom.
        delta: tolerance added to each atom's nearest-neighbour distance.

    Returns:
        ``(2, E)`` long tensor of ``[source, target]`` columns; shape
        ``(2, 0)`` when there are no edges (e.g. a single atom).
    """
    # Calculate the distance matrix between all atoms
    dist_matrix = distance_matrix(positions, positions)
    num_atoms = len(positions)
    edge_index = []

    # For each node, find its nearest neighbor(s) within a tolerance
    for i in range(num_atoms):
        # Exclude self-distance by setting diagonal to infinity
        dist_matrix[i, i] = np.inf
        min_dist = np.min(dist_matrix[i])
        if min_dist == np.inf:
            # No other atom exists; without this guard `inf <= inf` would
            # select the atom itself and create a self-loop.
            continue
        cutoff_dist = min_dist + delta
        for j in np.where(dist_matrix[i] <= cutoff_dist)[0]:
            edge_index.append([int(j), i])

    # Convert edge_index to a tensor if there are edges
    if len(edge_index) > 0:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    else:
        # Handle graphs with no edges
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return edge_index

# Function to compute edge attributes based on positions and edge_index
def compute_edge_attr(positions, edge_index):
    """Return per-edge Euclidean distances as an ``(E, 1)`` float tensor.

    Args:
        positions: array of atomic coordinates, one row per atom.
        edge_index: ``(2, E)`` collection of ``[source, target]`` node indices.

    Returns:
        ``torch.float`` tensor of shape ``(E, 1)`` holding the distance
        between the two endpoints of every edge.
    """
    positions = np.asarray(positions)

    # Convert to plain NumPy index arrays first. Indexing a NumPy array with
    # a torch tensor directly is fragile: a one-element integer tensor can be
    # interpreted as a scalar index, which drops the edge axis and makes
    # `norm(..., axis=1)` fail for graphs with exactly one edge.
    endpoints = np.asarray(edge_index)
    src, dst = endpoints[0], endpoints[1]

    # Compute distances between connected nodes
    edge_distances = np.linalg.norm(positions[src] - positions[dst], axis=1)
    return torch.tensor(edge_distances, dtype=torch.float).unsqueeze(1)

# Function to create a PyTorch Geometric Data object from an atomic structure with perturbation
def create_graph_from_structure(atoms, perturbation_size, delta, edge_index=None):
    """Turn an atomic structure into a PyTorch Geometric ``Data`` graph.

    Connectivity is taken from ``edge_index`` when supplied; otherwise it is
    computed by ``compute_edge_index`` from the unperturbed positions. A
    uniform random displacement in ``[-perturbation_size, perturbation_size]``
    is then added to every coordinate (skipped when ``perturbation_size`` is
    not positive), and edge distances are computed from the displaced
    positions.

    Returns:
        ``Data`` with ``x`` (one atomic number per node), ``edge_index`` and
        ``edge_attr``.
    """
    coords = atoms.get_positions()

    # Topology always comes from the original, unperturbed coordinates
    if edge_index is None:
        edge_index = compute_edge_index(coords, delta)

    # In-place update of the local array.
    # NOTE(review): relies on get_positions() handing back a copy so that
    # `atoms` itself stays unperturbed - confirm.
    if perturbation_size > 0:
        noise = np.random.uniform(-perturbation_size, perturbation_size, coords.shape)
        coords += noise

    # Edge attributes are computed on the (possibly) perturbed coordinates
    distances = compute_edge_attr(coords, edge_index)

    # One feature per node: the atomic number -> shape (num_atoms, 1)
    atomic_numbers = torch.tensor(
        [[atom.number] for atom in atoms], dtype=torch.float
    )

    return Data(
        x=atomic_numbers,
        edge_index=edge_index,
        edge_attr=distances,
    )

# Function to scale node features and edge attributes in each graph list
def scale_graphs(graph_list):
    """Standardise node features and edge attributes of every graph in place.

    One ``StandardScaler`` is fitted on the node features of all graphs in
    ``graph_list`` and another on all non-empty edge attributes; each graph's
    ``x`` and ``edge_attr`` are then replaced by the transformed values.
    Graphs without edge attributes end up with an empty ``(0, 1)`` float
    ``edge_attr``.

    Args:
        graph_list: sequence of graph objects exposing ``x`` and ``edge_attr``
            tensors. An empty sequence is a no-op.

    Returns:
        None. The graphs are modified in place.
    """
    # torch.cat raises on an empty sequence; nothing to fit or scale anyway
    if not graph_list:
        return

    def _has_edges(graph):
        return graph.edge_attr is not None and graph.edge_attr.shape[0] > 0

    # Fit the node scaler on the node features of all graphs stacked together
    node_scaler = StandardScaler()
    node_scaler.fit(torch.cat([graph.x for graph in graph_list], dim=0).numpy())

    # Fit the edge scaler only if at least one graph has edge attributes
    edge_blocks = [graph.edge_attr for graph in graph_list if _has_edges(graph)]
    edge_scaler = None
    if edge_blocks:
        edge_scaler = StandardScaler()
        edge_scaler.fit(torch.cat(edge_blocks, dim=0).numpy())

    # Apply the scaling to each graph
    for graph in graph_list:
        scaled_nodes = node_scaler.transform(graph.x.numpy())
        graph.x = torch.tensor(scaled_nodes, dtype=torch.float)

        if edge_scaler is not None and _has_edges(graph):
            scaled_edges = edge_scaler.transform(graph.edge_attr.numpy())
            graph.edge_attr = torch.tensor(scaled_edges, dtype=torch.float)
        else:
            # No edges: make sure edge_attr is a tensor of the expected shape
            graph.edge_attr = torch.tensor([], dtype=torch.float).reshape(0, 1)

# Function to mask a percentage of node features and edge attributes in each graph
def mask_graphs(graph_list, masking_percentage=MASKING_PERCENTAGE):
    """Zero out a random subset of node features and edge attributes in place.

    Args:
        graph_list: iterable of graph objects exposing ``x`` and ``edge_attr``
            tensors.
        masking_percentage: fraction of rows to zero in each tensor. For a
            positive fraction at least one row is masked and never more than
            the rows available; a fraction of 0 or less masks nothing.

    Returns:
        None. The graphs are modified in place.
    """
    def _choose(total):
        # Nothing to pick from, or masking disabled
        if total <= 0 or masking_percentage <= 0:
            return None
        wanted = max(1, int(np.ceil(masking_percentage * total)))
        # Sampling without replacement raises if more rows are requested
        # than exist, so clamp to the available count.
        return np.random.choice(total, min(wanted, total), replace=False)

    for graph in graph_list:
        # Mask node features first, then edge attributes
        node_mask_indices = _choose(graph.x.shape[0])
        if node_mask_indices is not None:
            graph.x[node_mask_indices] = 0

        edge_mask_indices = _choose(graph.edge_attr.shape[0])
        if edge_mask_indices is not None:
            graph.edge_attr[edge_mask_indices] = 0

# Path to the supercell files folder
supercell_folder = 'supercells'

# Lists to store the graph objects (index-aligned across the three lists)
graph_list_set_1 = []
graph_list_set_2 = []
graph_list_original = []

# Iterate over all XYZ files in the supercells folder.
# os.listdir() returns names in arbitrary order, so sort them to build the
# lists in the same order on every run and platform.
for filename in sorted(os.listdir(supercell_folder)):
    if not filename.endswith('.xyz'):
        continue

    # Read the structure file
    filepath = os.path.join(supercell_folder, filename)
    atoms = read(filepath)

    # Use the first run of digits in the filename as the label (None if absent)
    match = re.search(r'\d+', filename)
    label = int(match.group()) if match else None

    # Get positions (unperturbed)
    positions = atoms.get_positions()

    # Compute the connectivity once from the original positions; the same
    # edge_index object is passed to all three graphs below.
    edge_index = compute_edge_index(positions, delta)

    # Create graph_original with no perturbation
    graph_original = create_graph_from_structure(atoms, 0, delta, edge_index=edge_index)
    graph_original.label = label

    # Create two graphs with different perturbations
    graph_1 = create_graph_from_structure(atoms, 0.05, delta, edge_index=edge_index)
    graph_2 = create_graph_from_structure(atoms, 0.05, delta, edge_index=edge_index)
    graph_1.label = label
    graph_2.label = label

    # Append the graphs to the respective lists
    graph_list_set_1.append(graph_1)
    graph_list_set_2.append(graph_2)
    graph_list_original.append(graph_original)

    print(f"Graphs for {filename} created and added to the lists with label {label}.")

# Apply scaling to each graph list (each list is scaled independently)
scale_graphs(graph_list_set_1)
scale_graphs(graph_list_set_2)
scale_graphs(graph_list_original)

# Apply masking to the two perturbed lists only
mask_graphs(graph_list_set_1)
mask_graphs(graph_list_set_2)

print("Two graph sets created with different random perturbations, scaling, and masking applied.")


In [None]:
#Graph generator: nearest neighbour, perturbation, masking and scaling applied
# Graph set generator: Converts the supercells into undirected graphs connecting them based on nearest neighbors within a tolerance.
# Then, creates 2 versions of each graph both with independently random perturbations and masked nodes/features applied to the graphs.
# The graphs are scaled then stored in two lists.

import re
import os
import torch
import numpy as np
from ase.io import read
from scipy.spatial import distance_matrix
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler

# Fraction (not a percent value) of nodes and of edges that mask_graphs zeroes
# in each graph; used as the default of its masking_percentage parameter.
MASKING_PERCENTAGE = 0.1  # Mask 10% of nodes and edges

# Tolerance added to each atom's nearest-neighbour distance when selecting its
# neighbours (cutoff = min_dist + delta).
delta = 0.1  # Adjust as needed

# Function to create a PyTorch Geometric Data object from an atomic structure with perturbation
def create_graph_from_structure(atoms, perturbation_size, delta):
    """Build an undirected nearest-neighbour graph from perturbed positions.

    A uniform random displacement in ``[-perturbation_size, perturbation_size]``
    is added to every coordinate first, so both the connectivity and the edge
    distances are derived from the perturbed positions. Every atom is linked
    (in both directions) to all atoms within ``min_dist + delta`` of it.

    Each directed edge is stored once: mutual nearest neighbours would
    otherwise append the same pair of edges twice.

    Args:
        atoms: structure providing ``get_positions()`` and per-atom
            ``mass``/``number``.
        perturbation_size: half-width of the uniform displacement.
        delta: tolerance added to the nearest-neighbour distance.

    Returns:
        ``Data`` with ``x`` of shape (num_atoms, 2) holding mass and atomic
        number, ``edge_index`` of shape (2, E) and ``edge_attr`` of shape
        (E, 1).
    """
    # Get the positions of the atoms
    positions = atoms.get_positions()

    # Apply random perturbation to the positions (local array, in place)
    perturbation = np.random.uniform(-perturbation_size, perturbation_size, positions.shape)
    positions += perturbation

    # Calculate the distance matrix between all atoms
    dist_matrix = distance_matrix(positions, positions)
    num_atoms = len(positions)

    # (source, target) -> distance, in first-insertion order
    edges = {}

    # For each node, find its nearest neighbor(s) within a tolerance
    for i in range(num_atoms):
        # Exclude self-distance by setting diagonal to infinity
        dist_matrix[i, i] = np.inf
        min_dist = np.min(dist_matrix[i])
        if min_dist == np.inf:
            # No other atom exists; `inf <= inf` would otherwise add a self-loop
            continue
        cutoff_dist = min_dist + delta
        for j in np.where(dist_matrix[i] <= cutoff_dist)[0]:
            j = int(j)
            distance = float(dist_matrix[i, j])
            # Add edges in both directions since it's an undirected graph
            edges.setdefault((i, j), distance)
            edges.setdefault((j, i), distance)

    # Convert edge_index and edge_attr to tensors
    if edges:
        edge_index = torch.tensor(
            [[src, dst] for src, dst in edges], dtype=torch.long
        ).t().contiguous()
        edge_attr = torch.tensor(list(edges.values()), dtype=torch.float).unsqueeze(1)
    else:
        # Handle graphs with no edges
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 1), dtype=torch.float)

    # Create node features: atomic mass and atomic number
    node_features = torch.tensor(
        [[atom.mass, atom.number] for atom in atoms], dtype=torch.float
    )  # Shape (num_atoms, 2)

    # Create a PyTorch Geometric Data object
    graph = Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_attr,
    )

    return graph

# Function to scale node features and edge attributes in each graph list
def scale_graphs(graph_list):
    """Standardise node features and edge attributes of every graph in place.

    One ``StandardScaler`` is fitted on the node features of all graphs in
    ``graph_list`` and another on all non-empty edge attributes; each graph's
    ``x`` and ``edge_attr`` are then replaced by the transformed values.
    Graphs without edge attributes end up with an empty ``(0, 1)`` float
    ``edge_attr``.

    Args:
        graph_list: sequence of graph objects exposing ``x`` and ``edge_attr``
            tensors. An empty sequence is a no-op.

    Returns:
        None. The graphs are modified in place.
    """
    # torch.cat raises on an empty sequence; nothing to fit or scale anyway
    if not graph_list:
        return

    def _has_edges(graph):
        return graph.edge_attr is not None and graph.edge_attr.shape[0] > 0

    # Fit the node scaler on the node features of all graphs stacked together
    node_scaler = StandardScaler()
    node_scaler.fit(torch.cat([graph.x for graph in graph_list], dim=0).numpy())

    # Fit the edge scaler only if at least one graph has edge attributes
    edge_blocks = [graph.edge_attr for graph in graph_list if _has_edges(graph)]
    edge_scaler = None
    if edge_blocks:
        edge_scaler = StandardScaler()
        edge_scaler.fit(torch.cat(edge_blocks, dim=0).numpy())

    # Apply the scaling to each graph
    for graph in graph_list:
        scaled_nodes = node_scaler.transform(graph.x.numpy())
        graph.x = torch.tensor(scaled_nodes, dtype=torch.float)

        if edge_scaler is not None and _has_edges(graph):
            scaled_edges = edge_scaler.transform(graph.edge_attr.numpy())
            graph.edge_attr = torch.tensor(scaled_edges, dtype=torch.float)
        else:
            # No edges: make sure edge_attr is a tensor of the expected shape
            graph.edge_attr = torch.tensor([], dtype=torch.float).reshape(0, 1)

# Function to mask a percentage of node features and edge attributes in each graph
def mask_graphs(graph_list, masking_percentage=MASKING_PERCENTAGE):
    """Zero out a random subset of node features and edge attributes in place.

    Args:
        graph_list: iterable of graph objects exposing ``x`` and ``edge_attr``
            tensors.
        masking_percentage: fraction of rows to zero in each tensor. For a
            positive fraction at least one row is masked and never more than
            the rows available; a fraction of 0 or less masks nothing.

    Returns:
        None. The graphs are modified in place.
    """
    def _choose(total):
        # Nothing to pick from, or masking disabled
        if total <= 0 or masking_percentage <= 0:
            return None
        wanted = max(1, int(np.ceil(masking_percentage * total)))
        # Sampling without replacement raises if more rows are requested
        # than exist, so clamp to the available count.
        return np.random.choice(total, min(wanted, total), replace=False)

    for graph in graph_list:
        # Mask node features first, then edge attributes
        node_mask_indices = _choose(graph.x.shape[0])
        if node_mask_indices is not None:
            graph.x[node_mask_indices] = 0

        edge_mask_indices = _choose(graph.edge_attr.shape[0])
        if edge_mask_indices is not None:
            graph.edge_attr[edge_mask_indices] = 0

# Path to the supercell files folder
supercell_folder = 'supercells'

# Lists to store the graph objects (index-aligned across the three lists)
graph_list_set_1 = []
graph_list_set_2 = []
graph_list_original = []

# Iterate over all XYZ files in the supercells folder.
# os.listdir() returns names in arbitrary order, so sort them to build the
# lists in the same order on every run and platform.
for filename in sorted(os.listdir(supercell_folder)):
    if not filename.endswith('.xyz'):
        continue

    # Read the structure file
    filepath = os.path.join(supercell_folder, filename)
    atoms = read(filepath)

    # Use the first run of digits in the filename as the label (None if absent)
    match = re.search(r'\d+', filename)
    label = int(match.group()) if match else None

    # Two graphs with perturbation size 0.05 and one with size 0; each call
    # derives its own connectivity from its own positions.
    graph_1 = create_graph_from_structure(atoms, 0.05, delta)
    graph_2 = create_graph_from_structure(atoms, 0.05, delta)
    graph_original = create_graph_from_structure(atoms, 0, delta)

    # Add the label to each graph
    graph_1.label = label
    graph_2.label = label
    graph_original.label = label

    # Append the graphs to the respective lists
    graph_list_set_1.append(graph_1)
    graph_list_set_2.append(graph_2)
    graph_list_original.append(graph_original)

    print(f"Graphs for {filename} created and added to the lists with label {label}.")

# Apply scaling to each graph list (each list is scaled independently)
scale_graphs(graph_list_set_1)
scale_graphs(graph_list_set_2)
scale_graphs(graph_list_original)

# Apply masking to the two perturbed lists only
mask_graphs(graph_list_set_1)
mask_graphs(graph_list_set_2)

print("Two graph sets created with different random perturbations, scaling, and masking applied.")

In [None]:
#Graph generator: nearest neighbour, no masking, perturbation or scaling
# Graph generator: Converts supercell structures into undirected graphs by connecting each node to its nearest neighbor(s) within a tolerance.
import os
import re
import torch
import numpy as np
from ase.io import read
from scipy.spatial import distance_matrix
from torch_geometric.data import Data

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, delta=0.1):
    """Build an undirected nearest-neighbour graph from an atomic structure.

    For every atom i the smallest distance to any other atom is found, and
    i is linked to every atom j whose distance is at most that minimum plus
    ``delta``. Each link is stored in both directions. Distances are plain
    Euclidean distances between the positions (no periodic images).

    Args:
        atoms: Structure providing ``get_positions()`` and iteration over
            atoms that expose ``mass`` and ``number``.
        delta: Tolerance added to each atom's minimum distance.

    Returns:
        A ``Data`` object with ``x`` holding [mass, atomic number] per atom,
        ``edge_index`` of shape (2, E) and ``edge_attr`` of shape (E, 1)
        holding the edge distances. Every directed edge appears exactly once.
    """
    positions = atoms.get_positions()
    num_atoms = len(positions)

    # Calculate the distance matrix between all atoms
    dist_matrix = distance_matrix(positions, positions)
    # Exclude self-distances once, up front, so no atom can select itself
    np.fill_diagonal(dist_matrix, np.inf)

    # Map (source, target) -> distance. Using a dict removes the duplicates that
    # arise when i picks j and j later picks i, while keeping insertion order.
    edges = {}
    for i in range(num_atoms):
        min_dist = np.min(dist_matrix[i])
        # A structure with a single atom has no finite distance; without this
        # guard ``inf <= inf + delta`` would create a self-loop of length inf.
        if not np.isfinite(min_dist):
            continue
        cutoff_dist = min_dist + delta
        for j in np.where(dist_matrix[i] <= cutoff_dist)[0]:
            j = int(j)
            distance = dist_matrix[i, j]
            # Store both directions since the graph is undirected
            edges[(i, j)] = distance
            edges[(j, i)] = distance

    # Convert the collected edges to tensors
    if edges:
        edge_index = torch.tensor(list(edges.keys()), dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(list(edges.values()), dtype=torch.float).unsqueeze(1)
    else:
        # Handle graphs with no edges
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 1), dtype=torch.float)

    # Create node features: atomic mass and atomic number
    node_features = torch.tensor(
        [[atom.mass, atom.number] for atom in atoms], dtype=torch.float
    )

    # Create PyTorch Geometric Data object
    graph = Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_attr,
    )

    return graph

# Settings for this cell: input folder and nearest-neighbour tolerance
supercell_folder = 'supercells'
delta = 0.1  # Added to each atom's minimum distance when selecting neighbours

# Collected graphs, in the order os.listdir reports the files
graph_list = []

for filename in os.listdir(supercell_folder):
    # Anything that is not an XYZ structure file is ignored
    if not filename.endswith('.xyz'):
        continue

    # Load the structure
    filepath = os.path.join(supercell_folder, filename)
    atoms = read(filepath)

    # The first run of digits in the filename is the label; None when absent
    match = re.search(r'\d+', filename)
    label = int(match.group()) if match else None

    # Build the graph, tag it with the label and store it
    graph = create_graph_from_structure(atoms, delta=delta)
    graph.label = label
    graph_list.append(graph)

    print(f"Graph for {filename} created and added to the list with label {label}.")

print("Graph generation complete. Each node is connected to its nearest neighbor(s) within the specified tolerance.")



In [None]:
#Graph generator: hard cutoff, perturbation, masking and scaling applied
# Graph set generator: turns the supercells into undirected graphs whose atoms are connected when they lie within a fixed cutoff distance.
# Two versions of each graph are created with independent random perturbations, plus one unperturbed original.
# The graphs are stored in three lists; every list is scaled, and masking of node features/edge attributes is applied to the two perturbed lists only.
import re
import os
import torch
import numpy as np
from ase.io import read
from scipy.spatial import distance_matrix
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler

# Function to create a PyTorch Geometric Data object from an atomic structure with perturbation
def create_graph_from_structure(atoms, CUT_OFF_DISTANCE, pertubation_size):
    """Build a fixed-cutoff graph from randomly displaced atomic positions.

    Every coordinate is shifted by an independent uniform draw from
    [-pertubation_size, pertubation_size]. Atoms i < j whose displaced
    Euclidean distance is <= CUT_OFF_DISTANCE are then linked in both
    directions.

    Returns a ``Data`` object with ``x`` = [mass, atomic number] per atom,
    ``edge_index`` of shape (2, E) and ``edge_attr`` of shape (E, 1) holding
    the distances of the displaced positions.
    """
    coords = atoms.get_positions()
    # Displace in place, exactly one random draw per coordinate
    coords += np.random.uniform(-pertubation_size, pertubation_size, coords.shape)

    pairwise = distance_matrix(coords, coords)

    # Strict upper triangle keeps each unordered pair once; np.nonzero walks it
    # row by row, i.e. in the same (i, j) order as a nested i < j loop.
    rows, cols = np.nonzero(np.triu(pairwise <= CUT_OFF_DISTANCE, k=1))

    if rows.size > 0:
        # Interleave (i, j) with (j, i) so both directions sit next to each other
        forward = np.stack([rows, cols])
        backward = np.stack([cols, rows])
        directed = np.stack([forward, backward], axis=2).reshape(2, -1)
        edge_index = torch.tensor(directed, dtype=torch.long)
        # Each distance is listed twice, once per direction
        edge_attr = torch.tensor(
            np.repeat(pairwise[rows, cols], 2), dtype=torch.float
        ).unsqueeze(1)
    else:
        # No pair within the cutoff: keep the tensors well-shaped but empty
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 1), dtype=torch.float)

    # Node features: one [mass, atomic number] row per atom
    node_features = torch.tensor(
        [[atom.mass, atom.number] for atom in atoms], dtype=torch.float
    )

    return Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)

# Function to scale node features and edge attributes in each graph list
def scale_graphs(graph_list):
    """Standardize node features and edge attributes of every graph in place.

    One StandardScaler is fitted on the node features pooled over all graphs
    in ``graph_list`` and another on the pooled edge attributes; each graph's
    ``x`` and ``edge_attr`` are then replaced by the transformed values as
    float tensors.

    Args:
        graph_list: Graphs exposing tensor attributes ``x`` and ``edge_attr``.
            Graphs whose ``edge_attr`` is None or has no rows get an empty
            (0, 1) float tensor as ``edge_attr``.

    Returns:
        None. An empty ``graph_list`` is a no-op.
    """
    # torch.cat rejects an empty list, so there is nothing to fit or scale
    if not graph_list:
        return

    def _has_edges(graph):
        # True when the graph carries at least one edge-attribute row
        return graph.edge_attr is not None and graph.edge_attr.shape[0] > 0

    # Fit the node scaler on the features of all graphs stacked together
    node_scaler = StandardScaler()
    node_scaler.fit(torch.cat([graph.x for graph in graph_list], dim=0).numpy())

    # Fit the edge scaler only if at least one graph has edges
    edge_blocks = [graph.edge_attr for graph in graph_list if _has_edges(graph)]
    edge_scaler = None
    if edge_blocks:
        edge_scaler = StandardScaler()
        edge_scaler.fit(torch.cat(edge_blocks, dim=0).numpy())

    # Apply the fitted scalers to each graph
    for graph in graph_list:
        graph.x = torch.tensor(node_scaler.transform(graph.x.numpy()), dtype=torch.float)

        if _has_edges(graph) and edge_scaler is not None:
            graph.edge_attr = torch.tensor(
                edge_scaler.transform(graph.edge_attr.numpy()), dtype=torch.float
            )
        else:
            # No edges: make sure edge_attr is still a tensor of shape (0, 1)
            graph.edge_attr = torch.tensor([], dtype=torch.float).reshape(0, 1)

# Function to mask 10% of node features and edge attributes in each graph
def mask_graphs(graph_list, masking_percentage=MASKING_PERCENTAGE):
    """Zero out a random fraction of node-feature and edge-attribute rows in place.

    For each graph, ``ceil(masking_percentage * n)`` rows of ``graph.x`` and
    of ``graph.edge_attr`` are chosen without replacement and set to 0. For a
    positive fraction at least one row is masked; the count is capped at the
    number of rows so fractions above 1 mask everything instead of raising.
    A fraction of 0 or less masks nothing.

    Args:
        graph_list: Graphs exposing tensor attributes ``x`` and ``edge_attr``.
        masking_percentage: Fraction of rows to mask (e.g. 0.1 for 10%).

    Returns:
        None. The graphs are modified in place.
    """
    def _mask_count(total):
        # Number of rows to zero out for a tensor with ``total`` rows
        if total <= 0 or masking_percentage <= 0:
            return 0
        return min(total, max(1, int(np.ceil(masking_percentage * total))))

    for graph in graph_list:
        # Mask the requested fraction of node feature rows
        num_nodes = graph.x.shape[0]
        num_node_mask = _mask_count(num_nodes)
        if num_node_mask > 0:
            node_mask_indices = np.random.choice(num_nodes, num_node_mask, replace=False)
            graph.x[node_mask_indices] = 0  # Zero out the node features for the selected nodes

        # Mask the requested fraction of edge attribute rows
        num_edges = graph.edge_attr.shape[0]
        num_edge_mask = _mask_count(num_edges)
        if num_edge_mask > 0:
            edge_mask_indices = np.random.choice(num_edges, num_edge_mask, replace=False)
            graph.edge_attr[edge_mask_indices] = 0  # Zero out the edge attributes for the selected edges


# Input folder containing the supercell structure files
supercell_folder = 'supercells'

# One list per variant: two perturbed sets and the unperturbed original
graph_list_set_1, graph_list_set_2, graph_list_original = [], [], []

for filename in os.listdir(supercell_folder):
    # Anything that is not an XYZ structure file is ignored
    if not filename.endswith('.xyz'):
        continue

    # Load the structure
    filepath = os.path.join(supercell_folder, filename)
    atoms = read(filepath)

    # The first run of digits in the filename is the label; None when absent
    match = re.search(r'\d+', filename)
    label = int(match.group()) if match else None

    # Two independently perturbed graphs (size 0.05) and one with size 0.
    # NOTE: CUT_OFF_DISTANCE is not assigned in this cell and must already exist.
    graph_1 = create_graph_from_structure(atoms, CUT_OFF_DISTANCE, 0.05)
    graph_2 = create_graph_from_structure(atoms, CUT_OFF_DISTANCE, 0.05)
    graph_original = create_graph_from_structure(atoms, CUT_OFF_DISTANCE, 0)

    # All three variants share the label of their source file
    graph_1.label = label
    graph_2.label = label
    graph_original.label = label

    # Same index in each list refers to the same source file
    graph_list_set_1.append(graph_1)
    graph_list_set_2.append(graph_2)
    graph_list_original.append(graph_original)

    print(f"Graphs for {filename} created and added to the lists with label {label}.")

# Scale every list; each call fits its own scalers on the list it receives
scale_graphs(graph_list_set_1)
scale_graphs(graph_list_set_2)
scale_graphs(graph_list_original)

# Mask only the perturbed sets; the original list stays unmasked
mask_graphs(graph_list_set_1)
mask_graphs(graph_list_set_2)

print("Two graph sets created with different random perturbations, scaling, and masking applied.")

In [None]:
#Graph Generator: hard cutoff, no masking, perturbation or scaling
# Graph generator: Converts supercell structures into undirected graphs based on a fixed cutoff distance.
import os
import re
import torch
import numpy as np
from ase.io import read
from scipy.spatial import distance_matrix
from torch_geometric.data import Data

# Function to create a PyTorch Geometric Data object from an atomic structure
def create_graph_from_structure(atoms, CUT_OFF_DISTANCE):
    """Build an undirected fixed-cutoff graph from an atomic structure.

    Atoms i < j whose Euclidean distance is <= CUT_OFF_DISTANCE are linked in
    both directions, with the distance stored as the edge attribute.

    Returns a ``Data`` object with ``x`` = [mass, atomic number] per atom,
    ``edge_index`` of shape (2, E) and ``edge_attr`` of shape (E, 1).
    """
    coords = atoms.get_positions()
    pairwise = distance_matrix(coords, coords)

    # Enumerate every unordered pair once (row-major, i < j) and keep the close ones
    upper_i, upper_j = np.triu_indices(len(coords), k=1)
    within = pairwise[upper_i, upper_j] <= CUT_OFF_DISTANCE
    src, dst = upper_i[within], upper_j[within]

    if src.size > 0:
        # Rows [i, j, j, i] reshaped to two columns give (i, j) followed by (j, i)
        both_directions = np.column_stack([src, dst, dst, src]).reshape(-1, 2)
        edge_index = torch.tensor(both_directions, dtype=torch.long).t().contiguous()
        # Each distance is listed twice, once per direction
        edge_attr = torch.tensor(
            np.repeat(pairwise[src, dst], 2), dtype=torch.float
        ).unsqueeze(1)
    else:
        # No pair within the cutoff: keep the tensors well-shaped but empty
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 1), dtype=torch.float)

    # Node features: one [mass, atomic number] row per atom
    node_features = torch.tensor(
        [[atom.mass, atom.number] for atom in atoms], dtype=torch.float
    )

    return Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)

# Settings for this cell: input folder and the hard cutoff used to connect atoms
supercell_folder = 'supercells'
CUT_OFF_DISTANCE = 5.0  # Example value; set according to your requirements

# Collected graphs, in the order os.listdir reports the files
graph_list = []

for filename in os.listdir(supercell_folder):
    # Anything that is not an XYZ structure file is ignored
    if not filename.endswith('.xyz'):
        continue

    # Load the structure
    filepath = os.path.join(supercell_folder, filename)
    atoms = read(filepath)

    # The first run of digits in the filename is the label; None when absent
    match = re.search(r'\d+', filename)
    label = int(match.group()) if match else None

    # Build the unperturbed graph, tag it with the label and store it
    graph = create_graph_from_structure(atoms, CUT_OFF_DISTANCE)
    graph.label = label
    graph_list.append(graph)

    print(f"Graph for {filename} created and added to the list with label {label}.")

print("Graph generation complete. All graphs are connected using the hard cutoff distance without perturbations, masking, or scaling.")


In [None]:
# Function to calculate the average number of nodes and edges in a list of graphs
def calculate_average_nodes_edges(graph_list):
    """Return the mean node count and mean undirected edge count per graph.

    Nodes are counted as rows of ``graph.x``. Edges are counted as half the
    number of columns of ``graph.edge_index`` (integer division), since each
    undirected edge is stored in both directions. An empty list yields (0, 0).
    """
    graph_count = len(graph_list)
    if graph_count == 0:
        return 0, 0

    node_sum = sum(graph.x.shape[0] for graph in graph_list)
    edge_sum = sum(graph.edge_index.shape[1] // 2 for graph in graph_list)

    return node_sum / graph_count, edge_sum / graph_count

# Function to check if every node is connected to at least one edge in each graph
def check_node_connections(graph_list):
    """Find graphs that contain nodes not touched by any edge.

    Returns a list of ``(graph index, isolated node indices)`` tuples, one for
    each graph with at least one node that appears in neither row of
    ``edge_index``. Node indices are listed in ascending order.
    """
    flagged = []

    for position, data in enumerate(graph_list):
        # Every node that appears as a source or as a target of some edge
        sources, targets = data.edge_index[0].tolist(), data.edge_index[1].tolist()
        touched = set(sources).union(targets)

        # Nodes are numbered 0 .. n-1, where n is the number of rows of x
        lonely = sorted(set(range(data.x.shape[0])) - touched)
        if lonely:
            flagged.append((position, lonely))

    return flagged

# Summary statistics over graph_list
average_nodes, average_edges = calculate_average_nodes_edges(graph_list)
print(f"Average number of nodes per graph: {average_nodes}")
print(f"Average number of edges per graph: {average_edges}")

# Report any graph that has nodes without a single edge
graphs_with_isolated_nodes = check_node_connections(graph_list)

if not graphs_with_isolated_nodes:
    print("All graphs have every node connected with at least one edge.")
else:
    print(f"Graphs with isolated nodes (graph index, isolated node indices): {graphs_with_isolated_nodes}")
    print(f"Number of graphs with isolated nodes: {len(graphs_with_isolated_nodes)}")


In [None]:
# Bare expression: in a notebook cell this displays the repr of the first perturbed graph list
graph_list_set_1


In [None]:
# Loop through graph_list_set_1 and print the edge attributes for each graph
for i, graph in enumerate(graph_list_set_1):
    print(f"Graph {i + 1}:")  # Graphs are numbered from 1 in the output
    
    # No None/empty check is made here: edge_attr must already be a tensor
    edge_attributes = graph.edge_attr.numpy()  # Convert to numpy array for easier printing
    print(edge_attributes)
    
    print("-" * 40)  # Separator for readability