# In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from transformers import GraphormerForGraphClassification, GraphormerConfig
from torch.optim import AdamW
from tqdm import tqdm
import networkx as nx
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
import pandas as pd
from collections import Counter
import random

# Seed every random number generator used by this script.
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed passed unchanged to each generator.
    """
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

    # CUDA generators are only seeded when a CUDA device is available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Set unconditionally, exactly like the seeding above is not.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


class GraphDataset(Dataset):
    """In-memory dataset of adjacency matrices loaded from two directories.

    Every ``.npy`` file in ``hamiltonian_dir`` is stored with label 1 and
    every ``.npy`` file in ``non_hamiltonian_dir`` with label 0. Files are
    loaded in sorted filename order so that a sample index always refers to
    the same file; ``os.listdir`` alone gives no ordering guarantee, which
    would make index-based splits differ between runs or machines.

    Attributes are plain lists:
        graphs: loaded NumPy arrays, Hamiltonian files first.
        labels: integer labels (1 or 0), aligned with ``graphs``.
    """

    def __init__(self, hamiltonian_dir, non_hamiltonian_dir):
        """Load all ``.npy`` files from both directories into memory.

        Args:
            hamiltonian_dir: Directory whose ``.npy`` files get label 1.
            non_hamiltonian_dir: Directory whose ``.npy`` files get label 0.
        """
        self.graphs = []
        self.labels = []

        self._load_dir(hamiltonian_dir, 1)
        self._load_dir(non_hamiltonian_dir, 0)

    def _load_dir(self, directory, label):
        """Append every ``.npy`` array in ``directory`` with the given label."""
        # sorted() pins the sample order; other file types are skipped.
        for filename in sorted(os.listdir(directory)):
            if filename.endswith('.npy'):
                self.graphs.append(np.load(os.path.join(directory, filename)))
                self.labels.append(label)

    def __len__(self):
        """Return the number of loaded graphs."""
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors.

        Returns:
            ``{'adj_matrix': float tensor copy of the stored array,
            'label': 0-dim long tensor}``.
        """
        adj_matrix = self.graphs[idx]
        label = self.labels[idx]
        return {
            'adj_matrix': torch.tensor(adj_matrix, dtype=torch.float),
            'label': torch.tensor(label, dtype=torch.long)
        }

# Accuracy and macro-averaged F1 for a set of predictions.
def compute_metrics(preds, labels):
    """Return ``(accuracy, f1)`` for predicted versus true class tensors.

    Args:
        preds: Tensor of predicted class ids.
        labels: Tensor of ground-truth class ids.

    Returns:
        Tuple of the ``accuracy_score`` and the ``f1_score`` computed with
        ``average='macro'``.
    """
    # Both tensors are detached and moved to host memory before conversion.
    y_pred, y_true = (t.detach().cpu().numpy() for t in (preds, labels))
    acc = accuracy_score(y_true, y_pred)
    # 'macro' averages the per-class F1 scores with equal weight.
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    return acc, macro_f1

def collate_fn(batch, num_heads=None, max_dist=20):
    """Collate dataset items into one padded batch of Graphormer inputs.

    Each graph is padded to the largest node count in the batch. Per graph
    the function derives, from the adjacency matrix alone:

    * node features and in/out degree: the node degree reported by
      networkx (the same values are used for all three), padded with 0;
    * ``spatial_pos``: hop count of the shortest path between every node
      pair, with ``max_dist`` for pairs that are unreachable within
      ``max_dist`` hops and for padded positions;
    * ``attn_edge_type``: the adjacency matrix cast to int64, padded with 0;
    * ``input_edges``: for each node pair a one-hot over hop distance
      (index ``d`` is 1 when the pair is ``d < max_dist`` hops apart),
      repeated ``num_heads`` times along a trailing axis, padded with 0;
    * ``attn_bias``: all zeros, one row/column larger than the padded
      graph size.

    Args:
        batch: List of items with ``'adj_matrix'`` (square CPU tensor) and
            ``'label'``.
        num_heads: Size of the trailing axis of ``input_edges``. When
            ``None`` (the default, and what ``DataLoader`` uses) it is read
            from the module-level ``config.num_attention_heads``.
        max_dist: Shortest-path cutoff and size of the distance axis of
            ``input_edges``.

    Returns:
        Dict with keys ``input_nodes``, ``attn_bias``, ``in_degree``,
        ``out_degree``, ``spatial_pos``, ``attn_edge_type``,
        ``input_edges`` and ``labels``; all integer tensors are
        ``torch.long`` and ``attn_bias`` is float32.
    """
    if num_heads is None:
        # Looked up at call time, so ``config`` may be defined later in the
        # module than this function.
        num_heads = config.num_attention_heads

    adj_matrices = [item['adj_matrix'] for item in batch]
    labels = torch.tensor([item['label'] for item in batch], dtype=torch.long)
    max_size = max(adj.shape[0] for adj in adj_matrices)

    node_features, degree_list, spatial_pos_list = [], [], []
    attn_edge_type_list, input_edges_list = [], []

    for adj in adj_matrices:
        adj_np = adj.numpy()
        pad_size = max_size - adj.shape[0]
        G = nx.from_numpy_array(adj_np)

        # NOTE(review): a real node of degree 0 and a padded slot both end
        # up with value 0 here. If the model treats index 0 as padding (the
        # Hugging Face reference collator shifts indices by +1, which
        # suggests so), isolated nodes would be masked out - TODO confirm.
        degrees = np.array([degree for _, degree in G.degree()])
        degrees_pad = torch.tensor(np.pad(degrees, (0, pad_size), 'constant'), dtype=torch.long)
        degree_list.append(degrees_pad)
        node_features.append(degrees_pad.unsqueeze(1))

        # Start from "unreachable / padding" everywhere and overwrite only
        # the pairs that the BFS actually reached, in a single pass.
        spatial_pos = np.full((max_size, max_size), max_dist, dtype=np.int64)
        hop_onehot = np.zeros((max_size, max_size, max_dist), dtype=np.int64)
        for src, reachable in nx.all_pairs_shortest_path_length(G, cutoff=max_dist):
            for dst, dist in reachable.items():
                spatial_pos[src, dst] = dist
                # dist == max_dist is allowed by the cutoff but has no slot
                # on the distance axis.
                if dist < max_dist:
                    hop_onehot[src, dst, dist] = 1
        spatial_pos_list.append(torch.tensor(spatial_pos, dtype=torch.long))

        input_edges = np.repeat(hop_onehot[:, :, :, np.newaxis], num_heads, axis=3)
        input_edges_list.append(torch.tensor(input_edges, dtype=torch.long))

        edge_type_pad = np.pad(
            adj_np.astype(np.int64), ((0, pad_size), (0, pad_size)), 'constant', constant_values=0
        )
        attn_edge_type_list.append(torch.tensor(edge_type_pad, dtype=torch.long))

    # One extra row/column compared with the padded graph size.
    attn_bias = torch.zeros((len(adj_matrices), max_size + 1, max_size + 1), dtype=torch.float32)

    return {
        'input_nodes': torch.stack(node_features),
        'attn_bias': attn_bias,
        'in_degree': torch.stack(degree_list),
        'out_degree': torch.stack(degree_list),
        'spatial_pos': torch.stack(spatial_pos_list),
        'attn_edge_type': torch.stack(attn_edge_type_list),
        'input_edges': torch.stack(input_edges_list),
        'labels': labels
    }


# Define the evaluation function
def evaluate(model, dataloader, device, criterion):
    """Run one gradient-free pass over ``dataloader`` and report metrics.

    The model is switched to eval mode and is left in eval mode on return.

    Args:
        model: Callable accepting the keyword inputs ``input_nodes``,
            ``input_edges``, ``attn_bias``, ``in_degree``, ``out_degree``,
            ``spatial_pos`` and ``attn_edge_type``; its output must expose
            ``.logits``.
        dataloader: Iterable of batch dicts holding those keys plus
            ``'labels'``; must support ``len()``.
        device: Device every batch tensor is moved to.
        criterion: Callable ``criterion(logits, labels)`` returning a
            scalar tensor.

    Returns:
        Tuple ``(loss, accuracy, f1)`` where ``loss`` is the mean of the
        per-batch criterion values, ``accuracy`` is the fraction of correct
        predictions in ``[0, 1]`` and ``f1`` is the macro-averaged F1.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    model.eval()
    input_keys = (
        'input_nodes', 'input_edges', 'attn_bias', 'in_degree',
        'out_degree', 'spatial_pos', 'attn_edge_type',
    )
    total_loss, correct, total = 0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            inputs = {key: batch[key].to(device) for key in input_keys}
            labels = batch['labels'].to(device)

            logits = model(**inputs).logits
            total_loss += criterion(logits, labels).item()

            _, preds = torch.max(logits, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            # tolist() copies the whole batch to host ints in one transfer;
            # extending with the tensor itself would store one 0-dim
            # (possibly CUDA) tensor per sample.
            all_preds.extend(preds.tolist())
            all_labels.extend(labels.tolist())

    accuracy = correct / total
    f1 = f1_score(all_labels, all_preds, average='macro')

    return total_loss / len(dataloader), accuracy, f1


# Define the training function
def train(model, train_loader, val_loader, test_loader, device, optimizer, criterion, num_epochs, patience):
    """Train ``model`` with early stopping on validation accuracy.

    Before the first epoch the untrained model is evaluated on all three
    loaders and logged as epoch 0. After every epoch the model is evaluated
    on the validation and test loaders and the full metrics table is
    rewritten to ``graphormer_training_results.csv``. Whenever validation
    accuracy strictly improves, the model's ``state_dict`` is saved to
    ``./best_graphormer_model.pth``.

    Args:
        model: Model accepting the Graphormer keyword inputs and returning
            an object with ``.logits``.
        train_loader: Loader of training batches (dicts of tensors).
        val_loader: Loader used for early stopping and checkpointing.
        test_loader: Loader evaluated for logging only.
        device: Device every batch tensor is moved to.
        optimizer: Optimizer stepping the model parameters.
        criterion: Callable ``criterion(logits, labels)`` returning a
            scalar loss tensor.
        num_epochs: Maximum number of epochs.
        patience: Number of consecutive epochs without a validation
            accuracy improvement after which training stops.

    Returns:
        None. Results are written to disk and printed.
    """
    best_accuracy = 0  # Initialize best accuracy
    early_stop_counter = 0
    best_model_path = './best_graphormer_model.pth'
    results_path = 'graphormer_training_results.csv'
    columns = ['epoch', 'train_loss', 'train_acc', 'train_f1', 'val_loss', 'val_acc', 'val_f1', 'test_loss', 'test_acc', 'test_f1']
    input_keys = (
        'input_nodes', 'input_edges', 'attn_bias', 'in_degree',
        'out_degree', 'spatial_pos', 'attn_edge_type',
    )
    results = []

    # Initial evaluation before training
    initial_train_loss, initial_train_acc, initial_train_f1 = evaluate(model, train_loader, device, criterion)
    initial_val_loss, initial_val_acc, initial_val_f1 = evaluate(model, val_loader, device, criterion)
    initial_test_loss, initial_test_acc, initial_test_f1 = evaluate(model, test_loader, device, criterion)

    # Log initial metrics before the first epoch
    results.append([0, initial_train_loss, initial_train_acc, initial_train_f1, initial_val_loss, initial_val_acc, initial_val_f1, initial_test_loss, initial_test_acc, initial_test_f1])

    print(f"Initial Metrics - Train Loss: {initial_train_loss:.4f}, Val Loss: {initial_val_loss:.4f}, Test Loss: {initial_test_loss:.4f}")

    # Save initial results to CSV before training
    pd.DataFrame(results, columns=columns).to_csv(results_path, index=False)

    for epoch in range(1, num_epochs + 1):
        # evaluate() leaves the model in eval mode, so re-enable train mode.
        model.train()
        train_loss = 0
        all_preds, all_labels = [], []

        for batch in tqdm(train_loader, desc=f'Epoch {epoch}/{num_epochs}'):
            inputs = {key: batch[key].to(device) for key in input_keys}
            labels = batch['labels'].to(device)

            optimizer.zero_grad()

            logits = model(**inputs).logits
            loss = criterion(logits, labels)
            loss.backward()

            optimizer.step()

            train_loss += loss.item()
            _, preds = torch.max(logits, 1)

            # Collect host-side ints; extending with the tensors themselves
            # would keep one 0-dim (possibly CUDA) tensor per sample.
            all_preds.extend(preds.tolist())
            all_labels.extend(labels.tolist())

        train_acc, train_f1 = compute_metrics(torch.tensor(all_preds), torch.tensor(all_labels))
        train_loss /= len(train_loader)

        # Evaluate on validation set
        val_loss, val_acc, val_f1 = evaluate(model, val_loader, device, criterion)

        # Checkpoint / early-stopping bookkeeping
        if val_acc > best_accuracy:
            best_accuracy = val_acc
            early_stop_counter = 0
            torch.save(model.state_dict(), best_model_path)
            # val_acc is a fraction in [0, 1]; scale it before printing '%'.
            print(f"Best model saved with validation accuracy: {val_acc * 100:.2f}%")
        else:
            early_stop_counter += 1
            print(f"Epochs without improvement: {early_stop_counter}")

        # Test evaluation
        test_loss, test_acc, test_f1 = evaluate(model, test_loader, device, criterion)

        # Log results
        results.append([epoch, train_loss, train_acc, train_f1, val_loss, val_acc, val_f1, test_loss, test_acc, test_f1])

        # Save results to CSV after each epoch
        pd.DataFrame(results, columns=columns).to_csv(results_path, index=False)

        print(f"Epoch {epoch}/{num_epochs}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Test Loss: {test_loss:.4f}")

        # Checked only after logging so the stopping epoch is in the CSV too.
        if early_stop_counter >= patience:
            print("Early stopping triggered.")
            break

    print("Training complete. Best model saved as best_graphormer_model.pth")


# Set the random seed
set_seed(42)

# Directories containing the .npy files
hamiltonian_dir = './hamiltonian_small_mat'
non_hamiltonian_dir = './non_hamiltonian_small_mat'

# Initialize dataset
dataset = GraphDataset(hamiltonian_dir, non_hamiltonian_dir)
indices = list(range(len(dataset)))

# Specify the train_val size and test size (both are sample counts)
train_val_size = 100  # Example value, modify as per your need
test_size = 500

# Split dataset into train_val and test sets using train_test_split.
# Labels are read from dataset.labels directly: indexing the dataset would
# convert every adjacency matrix to a tensor just to look at its label.
train_val_indices, test_indices = train_test_split(
    indices,
    test_size=test_size,
    stratify=dataset.labels,
    random_state=41,
)

# Keep only the first train_val_size indices, then split them into train and
# validation sets (80% train, 20% validation)
train_val_subset = train_val_indices[:train_val_size]
train_indices, val_indices = train_test_split(
    train_val_subset,
    test_size=0.2,
    stratify=[dataset.labels[i] for i in train_val_subset],
    random_state=41,
)

# Create subsets for train, validation, and test datasets
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)

# Calculate the distribution of training, validation, and test sets
train_counter = Counter(dataset.labels[i] for i in train_indices)
val_counter = Counter(dataset.labels[i] for i in val_indices)
test_counter = Counter(dataset.labels[i] for i in test_indices)

# Print the class distribution
print(f"Training Dataset: {train_counter[1]} True, {train_counter[0]} False")
print(f"Validation Dataset: {val_counter[1]} True, {val_counter[0]} False")
print(f"Test Dataset: {test_counter[1]} True, {test_counter[0]} False")

# Define data loaders
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False, collate_fn=collate_fn)

# Define the model configuration
max_nodes_in_dataset = max(adj.shape[0] for adj in dataset.graphs)
# Define a smaller model configuration.
# `config` must stay a module-level name: collate_fn reads
# config.num_attention_heads each time a batch is built.
# NOTE(review): num_node_types, num_node_features, hidden_size,
# intermediate_size, hidden_act, hidden_dropout_prob and
# attention_probs_dropout_prob look like BERT-style argument names; check the
# GraphormerConfig documentation to confirm they are actually honoured by the
# model rather than stored as unused extra attributes - TODO confirm.
config = GraphormerConfig(
    num_classes=2,
    num_node_types=1,
    num_node_features=1,
    hidden_size=768,
    num_attention_heads=12,
    num_hidden_layers=4,
    max_nodes=max_nodes_in_dataset,
    multi_hop_max_dist=20,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
)

# Initialize the model
model = GraphormerForGraphClassification(config)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Optimizer and loss function
optimizer = AdamW(model.parameters(), lr=1e-5)
criterion = torch.nn.CrossEntropyLoss()

# Train the model
train(model, train_loader, val_loader, test_loader, device, optimizer, criterion, num_epochs=100, patience=10)
