In [None]:
# Model Training and Fingerprint Generation Workflow: takes the perturbed graph lists generated in 3) to train a GNN encoder 
# then generates a set of final embeddings using the unperturbed graph list and the trained model.
# Naming: saves the weights of the trained model as '{MODEL_NAME}.pth', the final embeddings as 'embeddings_{MODEL_NAME}.pt'
# and the final embedding labels as 'labels_{MODEL_NAME}.pt'

In [None]:
# Global Variables (hyperparameters shared by every cell below):

# --- Naming ---
MODEL_NAME = 'Test'

# --- Data ---
SUPERCELL_SIZE = 3  # Supercell size (N x N unit cells) used for graph list 1; must match SUPERCELL_SIZE_1 in 3)
BATCH_SIZE = 48  # Number of materials passed through the network per batch

# --- Architecture ---
EMBEDDING_DIMENSION = 192  # Dimensionality of the final vector (the structural fingerprint)
HIDDEN_DIMENSION = 192  # Width of the hidden node representations inside the encoder
LAYER_NUMBER = 6  # Number of hyperedge message-passing layers in the encoder
DROPOUT_RATE = 0.1  # Dropout probability, active only during training

# --- Optimisation ---
NUM_EPOCHS = 1000  # Number of full passes over the training data
LEARNING_RATE = 1e-4  # Optimiser step size: how strongly each update changes the weights
WEIGHT_DECAY = 1e-5  # Weight-decay strength passed to the optimiser
TEMPERATURE = 0.07  # InfoNCE softmax temperature: lower values push negatives further from positives

In [None]:
# Graph list loader: loads the graph lists generated in 3)

import os

import torch

# Folder location of graph lists
INPUT_FOLDER = 'Graphs/2DMatpedia Sublattices'

# Paths are built with os.path.join rather than a hard-coded backslash: "\" is only a path
# separator on Windows, and "\g" is an invalid escape sequence in a normal string literal.
# NOTE: torch.load unpickles arbitrary Python objects -- only load graph files you generated
# yourself. NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects pickled graph objects; pass weights_only=False there if loading fails.

# Load unperturbed graph list 1
graph_list_unperturbed_1 = torch.load(os.path.join(INPUT_FOLDER, "graph_list_unperturbed_1.pt"))

# Load the first perturbed graph list
graph_list_set_1 = torch.load(os.path.join(INPUT_FOLDER, "graph_list_set_1.pt"))

# Load the second perturbed graph list
graph_list_set_2 = torch.load(os.path.join(INPUT_FOLDER, "graph_list_set_2.pt"))

print("Successfully loaded all three graph lists.")

In [None]:
# GNN Encoder: takes a graph and carries out hyperedge message passing using a custom message passing layer 
# and generates an embedding or fingerprint using global attention pooling.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, GlobalAttention
from torch_geometric.utils import add_self_loops
from torch_geometric.nn.norm import LayerNorm

# 1. Class that defines the hyperedge message passing layer
class HyperedgeGNNConv(MessagePassing):
    """Message-passing layer that mixes pairwise-edge and three-body hyperedge messages.

    The updated node state is
    ``ReLU(lin_node(x) + agg(edge messages) + agg(hyperedge messages))``,
    where ``agg`` is the aggregation scheme handed to ``MessagePassing``.

    Args:
        in_channels: width of the incoming node features.
        out_channels: width of the node features this layer produces.
        edge_dim: width of each pairwise edge attribute vector.
        hyperedge_dim: width of each hyperedge attribute vector.
        aggr: neighbourhood aggregation forwarded to ``MessagePassing``.
    """

    def __init__(self, in_channels, out_channels, edge_dim, hyperedge_dim, aggr='add'):
        super(HyperedgeGNNConv, self).__init__(aggr=aggr)

        self.lin_node = nn.Linear(in_channels, out_channels)
        # Messages are built from x_j *after* lin_node has been applied, so the message
        # MLPs take out_channels (not in_channels) plus the attribute width. When
        # in_channels == out_channels the parameter shapes are the same either way.
        self.lin_edge = nn.Linear(out_channels + edge_dim, out_channels)
        self.lin_hyperedge = nn.Linear(out_channels + hyperedge_dim, out_channels)
        self.relu = nn.ReLU()

    def forward(self, x, edge_index, edge_attr, hyperedge_index, hyperedge_attr):
        """Run one round of edge + hyperedge message passing.

        Args:
            x: node features with ``in_channels`` columns.
            edge_index: pairwise connectivity passed to ``propagate``.
            edge_attr: one ``edge_dim`` row per pairwise edge.
            hyperedge_index: ``[3, num_hyperedges]`` node triples ``(i, j, k)``; ``j`` receives.
            hyperedge_attr: one ``hyperedge_dim`` row per hyperedge.

        Returns:
            Updated node features with ``out_channels`` columns.
        """
        # ---- Node & Edge transforms ----
        x = self.lin_node(x)
        edge_messages = self.propagate(edge_index, x=x, edge_attr=edge_attr, mode='edge')

        # ---- Expand each hyperedge (i, j, k) into the directed edges i->j and k->j ----
        # The central node j deliberately does not message itself.
        sender_indices = torch.cat([hyperedge_index[0], hyperedge_index[2]])  # all i, then all k
        receiver_indices = hyperedge_index[1].repeat(2)  # j, once for each sender block
        hyperedge_edge_index = torch.stack([sender_indices, receiver_indices], dim=0)

        # Attribute rows must follow the same order as the senders: i->j block, then k->j block.
        hyperedge_edge_attr = hyperedge_attr.repeat(2, 1)

        # Propagate hyperedge messages
        hyperedge_messages = self.propagate(
            hyperedge_edge_index, x=x, edge_attr=hyperedge_edge_attr, mode='hyperedge'
        )

        # ---- Combine everything ----
        out = x + edge_messages + hyperedge_messages
        return self.relu(out)

    def message(self, x_j, edge_attr, mode):
        """Build one message per directed edge from the sender state and the edge attribute.

        ``mode`` selects which linear map is applied: ``'edge'`` -> ``lin_edge``,
        ``'hyperedge'`` -> ``lin_hyperedge``.

        Raises:
            ValueError: if ``mode`` is neither ``'edge'`` nor ``'hyperedge'``.
        """
        if mode == 'edge':
            msg_input = torch.cat([x_j, edge_attr], dim=-1)
            msg = self.lin_edge(msg_input)
        elif mode == 'hyperedge':
            msg_input = torch.cat([x_j, edge_attr], dim=-1)
            msg = self.lin_hyperedge(msg_input)
        else:
            raise ValueError("Invalid mode. Use 'edge' or 'hyperedge'.")
        return msg

# 2. Class that defines the shape of the GNN encoder using the hyperedge message passing layer
class GNNWithHyperedges(nn.Module):
    """Graph encoder producing one fingerprint vector per graph.

    Pipeline: linear node embedding -> ``num_layers`` residual HyperedgeGNNConv blocks
    (each followed by LayerNorm, ReLU and dropout) -> global attention pooling ->
    two-layer MLP head.

    Args:
        num_node_features: width of ``data.x``.
        num_edge_features: width of ``data.edge_attr``.
        num_hyperedge_features: width of ``data.hyperedge_attr``.
        hidden_dim: width of node representations inside the message-passing stack.
        embedding_dim: width of the returned graph embedding.
        num_layers: number of HyperedgeGNNConv layers.
        dropout_rate: dropout probability, active only in training mode.
    """

    def __init__(
        self,
        num_node_features,
        num_edge_features,
        num_hyperedge_features,
        hidden_dim=128,
        embedding_dim=64,
        num_layers=4,
        dropout_rate=0.1,
    ):
        super(GNNWithHyperedges, self).__init__()

        self.num_layers = num_layers
        self.dropout_rate = dropout_rate

        self.node_embedding = nn.Linear(num_node_features, hidden_dim)
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        for _ in range(num_layers):
            conv = HyperedgeGNNConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                edge_dim=num_edge_features,
                hyperedge_dim=num_hyperedge_features,
                aggr='mean',
            )
            self.convs.append(conv)
            self.norms.append(LayerNorm(hidden_dim))

        # Global Attention Pooling: a learned per-node gate weights the graph readout
        self.attention_pool = GlobalAttention(gate_nn=nn.Linear(hidden_dim, 1))

        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, embedding_dim),
            nn.ReLU(),
            nn.Dropout(p=self.dropout_rate),
            nn.Linear(embedding_dim, embedding_dim),
        )

    def forward(self, data):
        """Embed every graph in ``data``.

        Args:
            data: a graph (or batch of graphs) providing ``x``, ``edge_index``, ``edge_attr``,
                ``hyperedge_index``, ``hyperedge_attr`` and, for batches, ``batch``.

        Returns:
            One ``embedding_dim`` row per graph identified by the batch vector.
        """
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        hyperedge_index = data.hyperedge_index
        hyperedge_attr = data.hyperedge_attr

        # A single (un-batched) graph may expose `batch` as an attribute whose value is None
        # (PyG's Data does this), so hasattr() is not a sufficient test; treat a missing or
        # None batch vector as "all nodes belong to graph 0".
        batch = getattr(data, 'batch', None)
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        x = self.node_embedding(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout_rate, training=self.training)

        for conv, norm in zip(self.convs, self.norms):
            x_residual = x
            x = conv(x, edge_index, edge_attr, hyperedge_index, hyperedge_attr)
            x = norm(x)
            x = F.relu(x) + x_residual  # residual connection around each block
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        # Apply Global Attention pooling instead of mean pooling
        x = self.attention_pool(x, batch)

        return self.fc(x)


In [None]:
# Loss Function: uses the InfoNCE loss function

import torch
import torch.nn.functional as F

def info_nce_loss(z1, z2, temperature=0.07):
    """NT-Xent / InfoNCE contrastive loss between two aligned views.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form a positive pair; every other row
    of the stacked ``[z1; z2]`` matrix serves as a negative for that anchor.

    Args:
        z1: Tensor of shape (N, d), embeddings of the first view.
        z2: Tensor of shape (N, d), embeddings of the second view.
        temperature: divisor applied to the cosine similarities before the softmax.

    Returns:
        Scalar tensor: cross-entropy averaged over all 2N anchors.
    """
    batch_size = z1.shape[0]
    stacked = torch.cat([F.normalize(z1, dim=1), F.normalize(z2, dim=1)], dim=0)

    logits = (stacked @ stacked.t()) / temperature

    # An embedding may be neither its own positive nor its own negative.
    self_pairs = torch.eye(2 * batch_size, dtype=torch.bool, device=stacked.device)
    logits = logits.masked_fill(self_pairs, float('-inf'))

    # Anchor i (first view) pairs with i + N; anchor i + N (second view) pairs with i.
    targets = torch.cat(
        [torch.arange(batch_size, 2 * batch_size), torch.arange(0, batch_size)], dim=0
    ).to(stacked.device)

    return F.cross_entropy(logits, targets)


In [None]:
# Data Loader: splits the graph lists into training and validation sets with an 80/20 split using a fixed seed for reproducibility

import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import DataLoader
import numpy as np

# Sanity check: the same permutation is applied to all three lists below, and the two
# perturbed lists are consumed as positive pairs, so the lists must line up one-to-one.
assert len(graph_list_set_1) == len(graph_list_set_2) == len(graph_list_unperturbed_1), (
    f"Graph lists differ in length: {len(graph_list_set_1)}, "
    f"{len(graph_list_set_2)}, {len(graph_list_unperturbed_1)}"
)

# Calculates training and validation set sizes
NUM_TRAINING = int(len(graph_list_set_1) * 0.8)  # 80% for training
NUM_VALIDATE = len(graph_list_set_1) - NUM_TRAINING  # Remaining for validation

# Seeded permutation for a reproducible split. A private RandomState yields exactly the
# permutation that np.random.seed(fixed_seed) + np.random.permutation(...) would, without
# resetting NumPy's global random state for the rest of the notebook.
fixed_seed = 42
indices = np.random.RandomState(fixed_seed).permutation(len(graph_list_set_1))

# Shuffle the graph lists (same order for all three)
shuffled_graph_list_set_1 = [graph_list_set_1[i] for i in indices]
shuffled_graph_list_set_2 = [graph_list_set_2[i] for i in indices]
# Unperturbed graphs in the same shuffled order (not consumed by the loaders below)
shuffled_graph_list_set_3 = [graph_list_unperturbed_1[i] for i in indices]

# Split the datasets into training and validation sets
train_graph_list_set_1 = shuffled_graph_list_set_1[:NUM_TRAINING]
val_graph_list_set_1 = shuffled_graph_list_set_1[NUM_TRAINING:]

train_graph_list_set_2 = shuffled_graph_list_set_2[:NUM_TRAINING]
val_graph_list_set_2 = shuffled_graph_list_set_2[NUM_TRAINING:]

# Training loaders reshuffle every epoch, so each contrastive batch (and therefore the
# in-batch negatives) changes from epoch to epoch. The two views must stay aligned
# sample-for-sample, so each loader gets its own generator with the same seed. When the
# two loaders are iterated together (zip in the trainer), both draw identical permutations.
# NOTE: iterating only one of them on its own would break that lockstep.
train_loader_set_1 = DataLoader(train_graph_list_set_1, batch_size=BATCH_SIZE, shuffle=True,
                                generator=torch.Generator().manual_seed(fixed_seed))
train_loader_set_2 = DataLoader(train_graph_list_set_2, batch_size=BATCH_SIZE, shuffle=True,
                                generator=torch.Generator().manual_seed(fixed_seed))

# Validation batches stay fixed so the validation loss is comparable across epochs.
val_loader_set_1 = DataLoader(val_graph_list_set_1, batch_size=BATCH_SIZE, shuffle=False)
val_loader_set_2 = DataLoader(val_graph_list_set_2, batch_size=BATCH_SIZE, shuffle=False)

print("Data loaders created for training and validation sets.")


In [None]:
# Trainer: initialises the encoder with random weights, then trains it on the training dataset. Prints the training and validation
# loss for each epoch and saves the model with the best validation loss
# Note: you can stop this code cell from running at any point and it will still save the current best model

import os
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch so the random weight initialisation and dropout masks are reproducible,
# matching the fixed seed used for the train/validation split. (Some CUDA kernels can
# still be nondeterministic.)
TORCH_SEED = 42
torch.manual_seed(TORCH_SEED)

# Directories
os.makedirs("Models", exist_ok=True)
os.makedirs("Plots and Visualisations/Training Plots", exist_ok=True)

# Print the full training setup so every run records the hyperparameters it used
print(f"""
-----------training starting---------------
BATCH_SIZE: {BATCH_SIZE}
NUM_EPOCHS: {NUM_EPOCHS}
Learning Rate: {LEARNING_RATE}
Hidden Dimension: {HIDDEN_DIMENSION}
Embedding Dimension: {EMBEDDING_DIMENSION}
Temperature (InfoNCE): {TEMPERATURE}
Layers: {LAYER_NUMBER}
Dropout Rate: {DROPOUT_RATE}
Weight Decay: {WEIGHT_DECAY}
Device: {device}
""")

# 1. Initialise the Model
# Input feature widths are read from the first graph, so they follow the data automatically.
model = GNNWithHyperedges(
    num_node_features=graph_list_set_1[0].x.size(1),
    num_edge_features=graph_list_set_1[0].edge_attr.size(1),
    num_hyperedge_features=graph_list_set_1[0].hyperedge_attr.size(1),
    hidden_dim=HIDDEN_DIMENSION,
    embedding_dim=EMBEDDING_DIMENSION,
    num_layers=LAYER_NUMBER,
    dropout_rate=DROPOUT_RATE,
).to(device)

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# 2. InfoNCE (NT-Xent) Loss Function
def info_nce_loss(z1, z2, temperature=0.07):
    """
    Computes the InfoNCE (NT-Xent) loss for two batches of embeddings z1, z2.
    Each batch has size N, so we get 2N embeddings total.

    Row i of z1 and row i of z2 are a positive pair; all other rows act as negatives.

    Args:
        z1: Tensor of shape (N, d), embeddings for view 1.
        z2: Tensor of shape (N, d), embeddings for view 2.
        temperature: softmax temperature applied to the cosine similarities.

    Returns:
        Scalar tensor: cross-entropy averaged over the 2N anchors.

    Raises:
        ValueError: if z1 and z2 do not have the same shape (the views must be aligned).
    """
    if z1.shape != z2.shape:
        raise ValueError(
            f"info_nce_loss expects aligned views of equal shape, got {tuple(z1.shape)} and {tuple(z2.shape)}"
        )

    # Normalize embeddings so the dot products below are cosine similarities
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    N = z1.size(0)
    # Concatenate along batch dimension: 2N x d
    z = torch.cat([z1, z2], dim=0)

    # Pairwise similarity (2N x 2N)
    sim_matrix = torch.matmul(z, z.t()) / temperature

    # Labels: each sample in [0..N-1] has its positive at index i+N,
    # each sample in [N..2N-1] has its positive at index i-N.
    # Built on the embeddings' own device rather than a notebook-global `device`.
    labels = torch.cat(
        [torch.arange(N, 2 * N, device=z.device), torch.arange(0, N, device=z.device)], dim=0
    )

    # Mask out self-similarity
    mask = torch.eye(2 * N, dtype=torch.bool, device=z.device)
    sim_matrix = sim_matrix.masked_fill(mask, float('-inf'))

    # Cross-entropy over rows
    return F.cross_entropy(sim_matrix, labels)

# 3. Compute Loss Wrapper
def compute_loss(model, data1, data2):
    """Embed both augmented views with the shared encoder and return their InfoNCE loss.

    Both batches are moved to the training device before the forward passes; the loss
    uses the notebook-wide TEMPERATURE.
    """
    batch_a = data1.to(device)
    batch_b = data2.to(device)
    return info_nce_loss(model(batch_a), model(batch_b), temperature=TEMPERATURE)

# 4. Training for One Epoch
def train_one_epoch(model, loader_set_1, loader_set_2, optimizer):
    """Run one optimisation pass over the paired training loaders.

    Args:
        model: encoder being trained (switched to training mode here).
        loader_set_1: loader over the first perturbed view.
        loader_set_2: loader over the second perturbed view, aligned batch-for-batch
            with ``loader_set_1``.
        optimizer: optimiser stepping ``model``'s parameters.

    Returns:
        Mean per-batch loss over the batches actually processed.

    Raises:
        ValueError: if the loaders yield no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for data1, data2 in zip(loader_set_1, loader_set_2):
        optimizer.zero_grad()
        loss = compute_loss(model, data1, data2)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    # zip() stops at the shorter loader, so divide by the number of batches seen rather
    # than len(loader_set_1); this also avoids a bare ZeroDivisionError on empty input.
    if num_batches == 0:
        raise ValueError("train_one_epoch: the training loaders yielded no batches.")
    return total_loss / num_batches

# 5. Validation 
def validate_model(model, loader_set_1, loader_set_2):
    """Compute the mean contrastive loss on the paired validation loaders without gradients.

    Args:
        model: encoder to evaluate (switched to eval mode here).
        loader_set_1: loader over the first perturbed validation view.
        loader_set_2: loader over the second perturbed validation view, aligned
            batch-for-batch with ``loader_set_1``.

    Returns:
        Mean per-batch loss over the batches actually processed.

    Raises:
        ValueError: if the loaders yield no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for data1, data2 in zip(loader_set_1, loader_set_2):
            total_loss += compute_loss(model, data1, data2).item()
            num_batches += 1

    # Average over the batches zip() actually produced; an empty validation set is an error
    # rather than an unexplained ZeroDivisionError.
    if num_batches == 0:
        raise ValueError("validate_model: the validation loaders yielded no batches.")
    return total_loss / num_batches

# 6. Full Training Loop
def train_model(model, train_loader_1, train_loader_2, val_loader_1, val_loader_2, 
                optimizer, num_epochs):
    """Train for up to ``num_epochs`` epochs, checkpointing the best validation model.

    Whenever validation loss improves, the weights are written to
    ``Models/{MODEL_NAME}.pth``. When training stops (after the last epoch or on
    KeyboardInterrupt), the best weights are loaded back into ``model``. Later cells
    therefore use the same weights as the saved checkpoint, not the last epoch's.

    Returns:
        (train_losses, val_losses): per-epoch mean losses. If training is interrupted
        between the training and validation phases, ``train_losses`` can be one entry
        longer than ``val_losses``.
    """
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')  # Initialize best validation loss
    best_state = None  # in-memory copy of the best weights seen so far
    checkpoint_path = f'Models/{MODEL_NAME}.pth'

    try:
        for epoch in range(num_epochs):
            # Training
            avg_train_loss = train_one_epoch(model, train_loader_1, train_loader_2, optimizer)
            train_losses.append(avg_train_loss)

            # Validation
            avg_val_loss = validate_model(model, val_loader_1, val_loader_2)
            val_losses.append(avg_val_loss)

            # Check best validation loss
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
                torch.save(model.state_dict(), checkpoint_path)
                print(f"New best model saved with validation loss: {best_val_loss:.4f}")

            print(f"Epoch [{epoch + 1}/{num_epochs}], "
                  f"Training Loss: {avg_train_loss:.4f}, "
                  f"Validation Loss: {avg_val_loss:.4f}")

    except KeyboardInterrupt:
        if best_state is None:
            print("Training interrupted by user before any model was saved.")
        else:
            print("Training interrupted by user. Best model saved so far.")

    # Leave `model` holding the checkpointed (best-validation) weights.
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"Restored best model weights (validation loss: {best_val_loss:.4f}).")

    return train_losses, val_losses

# 7. Plotting 
def plot_losses(train_losses, val_losses):
    """Plot per-epoch training and validation loss, save the figure, then display it.

    The PNG is written to ``Plots and Visualisations/Training Plots/{MODEL_NAME}.png``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    for series, series_label in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
        epochs = range(1, len(series) + 1)
        ax.plot(epochs, series, label=series_label, marker='o')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss')
    ax.legend()
    ax.grid(True)
    fig.savefig(f"Plots and Visualisations/Training Plots/{MODEL_NAME}.png")
    plt.show()

# 8. Run training on the paired loaders, then plot the recorded loss curves
train_losses, val_losses = train_model(
    model=model,
    train_loader_1=train_loader_set_1,
    train_loader_2=train_loader_set_2,
    val_loader_1=val_loader_set_1,
    val_loader_2=val_loader_set_2,
    optimizer=optimizer,
    num_epochs=NUM_EPOCHS,
)

plot_losses(train_losses, val_losses)


In [None]:
# Final Embedding generator: takes the trained model and generates a list of final embeddings using the unperturbed graph list
# Saves the embeddings along with their corresponding labels

for output_folder in ('Embeddings', 'Labels'):
    os.makedirs(output_folder, exist_ok=True)  # no error if the folder already exists

# 1. Function that runs each graph through the model extracting the final embedding and corresponding graph label, returns them both
def test_model(model, loader):
    model.eval()
    embeddings_list = []
    labels_list = []
    
    with torch.no_grad():
        for batch_idx, data in enumerate(loader):
            data = data.to(device)
            embeddings = model(data)
            embeddings_list.append(embeddings.cpu())
            labels_list.extend(data.label)

    # Concatenate embeddings into a single tensor
    all_embeddings = torch.cat(embeddings_list, dim=0)
    
    # Labels remain a list of strings
    return all_embeddings, labels_list

# 2. Embed the full unperturbed graph list with the trained model and save embeddings + labels

test_loader = DataLoader(graph_list_unperturbed_1, batch_size=BATCH_SIZE, shuffle=False)

# Call the test function to generate embeddings and labels for every unperturbed graph
output_graph_embeddings, output_graph_labels = test_model(model, test_loader)

# Row i of the embeddings must belong to label i; catch label attributes that do not
# collate to exactly one entry per graph.
assert output_graph_embeddings.size(0) == len(output_graph_labels), (
    f"{output_graph_embeddings.size(0)} embeddings but {len(output_graph_labels)} labels"
)

# Print the shapes of the final embeddings and labels
print(f"Generated embeddings shape: {output_graph_embeddings.shape}")
print(f"Sample labels: {output_graph_labels[:5]}")

embeddings_path = f'Embeddings/embeddings_{MODEL_NAME}.pt'
labels_path = f'Labels/labels_{MODEL_NAME}.pt'

# Save the embeddings of the full unperturbed graph list
torch.save(output_graph_embeddings, embeddings_path)

# Save the matching labels (same order as the embedding rows)
torch.save(output_graph_labels, labels_path)

# Print confirmation message
print(f"embeddings saved to {embeddings_path}")
print(f"labels saved to {labels_path}")