<a href="https://colab.research.google.com/github/AbhiJeet70/PowerfulGNNs/blob/main/Attack_SAGN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install necessary packages
!pip install torch-geometric
!pip install matplotlib
!pip install scikit-learn

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import subgraph
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
import networkx as nx
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from torch_geometric.datasets import Planetoid, Flickr
from sklearn.decomposition import PCA
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import networkx as nx
from sklearn.neighbors import NearestNeighbors
import pandas as pd
from torch_geometric.utils import to_networkx
from torch_geometric.nn.models.autoencoder import GAE
import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import pandas as pd
import gc
from torch_geometric.utils import from_networkx


import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU, plus CUDA if present)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)

# Seed every RNG once when this cell runs so the whole notebook is reproducible.
set_seed()

# Module-level device; code below (e.g. run_sagn_attacks) moves data and models here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


# Load datasets
def load_dataset(dataset_name):
    """Return the PyG dataset named ``dataset_name``, stored under ``./data``.

    Supported names are the Planetoid citation graphs ("Cora", "PubMed",
    "CiteSeer") and "Flickr". Any other name raises ``ValueError``.
    """
    planetoid_names = ("Cora", "PubMed", "CiteSeer")
    if dataset_name in planetoid_names:
        return Planetoid(root=f"./data/{dataset_name}", name=dataset_name)
    if dataset_name == "Flickr":
        return Flickr(root="./data/Flickr")
    raise ValueError(f"Unsupported dataset: {dataset_name}")


import numpy as np
import torch

def split_dataset(data, test_size=0.2, val_size=0.1, seed=None):
    """
    Randomly split the nodes of ``data`` into train / validation / test sets.

    The test set is further divided: its first half becomes ``target_mask``
    (nodes reserved for attack evaluation) and the remainder becomes
    ``clean_test_mask``.

    Args:
        data: PyTorch Geometric Data object (reads ``num_nodes`` and ``x``).
        test_size: Fraction of nodes to use for the test set.
        val_size: Fraction of nodes to use for the validation set.
        seed: Optional seed for the shuffle. When given, a private
            ``np.random.RandomState`` is used, which produces the same
            permutation as ``np.random.seed(seed)`` but leaves the global NumPy
            RNG untouched. When ``None``, the global NumPy RNG is used.

    Returns:
        The same ``data`` object with boolean ``train_mask``, ``val_mask``,
        ``test_mask``, ``target_mask`` and ``clean_test_mask`` attributes on
        ``data.x.device``.

    Raises:
        ValueError: if any of the resulting splits would be empty.
    """
    num_nodes = data.num_nodes
    indices = np.arange(num_nodes)

    rng = np.random.RandomState(seed) if seed is not None else np.random
    rng.shuffle(indices)

    # Calculate split sizes
    num_test = int(test_size * num_nodes)
    num_val = int(val_size * num_nodes)
    num_train = num_nodes - num_test - num_val
    # Half of the test set is reserved as attack targets. This used to be the
    # hard-coded int(0.1 * num_nodes), which only matches "half of the test set"
    # for test_size == 0.2 (and is identical to num_test // 2 in that case).
    num_target = num_test // 2

    if num_train <= 0:
        raise ValueError("Training set size is too small!")
    if num_val <= 0:
        raise ValueError("Validation set size is too small!")
    if num_test <= 0:
        raise ValueError("Test set size is too small!")
    if num_target <= 0:
        raise ValueError("Target mask size is too small!")

    mask_device = data.x.device

    def _mask(selected):
        mask = torch.zeros(num_nodes, dtype=torch.bool, device=mask_device)
        mask[selected] = True
        return mask

    test_start = num_train + num_val
    data.train_mask = _mask(indices[:num_train])
    data.val_mask = _mask(indices[num_train:test_start])
    data.test_mask = _mask(indices[test_start:])
    data.target_mask = _mask(indices[test_start:test_start + num_target])
    data.clean_test_mask = _mask(indices[test_start + num_target:])

    # Internal invariants: the slices above are disjoint by construction.
    assert not (data.train_mask & data.val_mask).any(), "Train and validation masks overlap!"
    assert not (data.train_mask & data.test_mask).any(), "Train and test masks overlap!"
    assert not (data.val_mask & data.test_mask).any(), "Validation and test masks overlap!"

    return data


# Select nodes to poison based on high-centrality (degree centrality) for a stronger impact
def select_high_centrality_nodes(data, num_nodes_to_select):
    """Return the ``num_nodes_to_select`` nodes with the highest degree centrality.

    The graph is built as an undirected ``nx.Graph`` from ``data.edge_index``, so
    duplicated and reversed edges count once. Nodes that appear in no edge are
    never selected. Ties keep networkx's node insertion order, because the sort is stable.

    Returns:
        LongTensor of node indices on the same device as ``data.edge_index``.
    """
    graph = nx.Graph()
    graph.add_edges_from(data.edge_index.cpu().numpy().T)
    centrality = nx.degree_centrality(graph)
    ranked = sorted(centrality, key=centrality.get, reverse=True)
    # Use the data's own device rather than the notebook-level ``device`` global,
    # so the result can always index ``data`` directly.
    return torch.tensor(ranked[:num_nodes_to_select], dtype=torch.long, device=data.edge_index.device)


class TriggerGenerator(torch.nn.Module):
    """Bottleneck MLP mapping feature vectors back to their own dimension.

    Layout: ``Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, input_dim)``.
    The DPGBA attack feeds it averaged neighbour features to produce trigger features.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Build the layers in the same order as before so seeded initialisation matches.
        squeeze = torch.nn.Linear(input_dim, hidden_dim)
        expand = torch.nn.Linear(hidden_dim, input_dim)
        self.mlp = torch.nn.Sequential(squeeze, torch.nn.ReLU(), expand)

    def forward(self, x):
        """Apply the MLP; the last dimension of ``x`` must equal ``input_dim``."""
        return self.mlp(x)

class GAE(torch.nn.Module):
    """Minimal encoder/decoder container.

    This definition shadows the ``GAE`` imported from torch_geometric. Both methods
    simply delegate to the wrapped callables, passing the graph's ``edge_index``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def encode(self, x, edge_index):
        """Run the wrapped encoder."""
        encoded = self.encoder(x, edge_index)
        return encoded

    def decode(self, z, edge_index):
        """Run the wrapped decoder."""
        decoded = self.decoder(z, edge_index)
        return decoded

class OODDetector(torch.nn.Module):
    """Graph auto-encoder that scores inputs by feature-reconstruction error.

    A GCN ``Encoder`` maps node features to ``latent_dim`` embeddings and a GCN
    ``Decoder`` maps them back to ``input_dim``. ``detect_ood`` compares the
    mean-squared reconstruction error against a threshold.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Encoder first, then decoder: construction order fixes seeded initialisation.
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(input_dim, hidden_dim, latent_dim)
        # Wraps the very same module objects, so no extra parameters are created.
        self.gae = GAE(self.encoder, self.decoder)

    def forward(self, x, edge_index):
        """Return the latent node embeddings."""
        return self.gae.encode(x, edge_index)

    def reconstruct(self, z, edge_index):
        """Decode latent embeddings ``z`` back to feature space."""
        return self.decoder(z, edge_index)

    def reconstruction_loss(self, x, edge_index):
        """MSE between ``x`` and its reconstruction, averaged over all nodes and features."""
        latent = self.gae.encode(x, edge_index)
        rebuilt = self.reconstruct(latent, edge_index)
        return F.mse_loss(rebuilt, x)

    def detect_ood(self, x, edge_index, threshold):
        """Return a 0-dim bool tensor: whether the graph-wide loss exceeds ``threshold``.

        NOTE(review): this is a single verdict for the whole input, not a per-node flag.
        """
        return self.reconstruction_loss(x, edge_index) > threshold

class Encoder(torch.nn.Module):
    """Two-layer GCN: ``input_dim -> hidden_dim (ReLU) -> latent_dim``."""

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, latent_dim)

    def forward(self, x, edge_index):
        """Embed node features ``x`` using the graph structure ``edge_index``."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)

class Decoder(torch.nn.Module):
    """Two-layer GCN mirroring ``Encoder``: ``latent_dim -> hidden_dim (ReLU) -> input_dim``."""

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.conv1 = GCNConv(latent_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, input_dim)

    def forward(self, z, edge_index):
        """Map latent embeddings ``z`` back to the input feature dimension."""
        hidden = self.conv1(z, edge_index).relu()
        return self.conv2(hidden, edge_index)



def train_ood_detector(ood_detector, data, optimizer, epochs=50):
    """Fit ``ood_detector`` to reconstruct the training-node features of ``data``.

    Each epoch is one full-batch step. The encoder sees the whole graph, but the
    MSE reconstruction loss only covers ``data.train_mask``. The loss is printed
    on epochs 0, 10, 20, ...
    """
    ood_detector.train()
    train_mask = data.train_mask
    for epoch in range(epochs):
        optimizer.zero_grad()
        latent = ood_detector(data.x, data.edge_index)
        rebuilt = ood_detector.reconstruct(latent, data.edge_index)
        loss = F.mse_loss(rebuilt[train_mask], data.x[train_mask])
        loss.backward()
        optimizer.step()
        if epoch % 10 != 0:
            continue
        print(f"Epoch {epoch}, Reconstruction Loss: {loss.item():.4f}")



def train_with_poisoned_data(model, data, optimizer, poisoned_nodes, trigger_gen, attack, ood_detector=None, alpha=0.7, early_stopping=False, epochs=100):
    """Poison ``data`` with ``attack`` via ``inject_trigger`` and train ``model`` on the result.

    Args:
        model: node classifier called as ``model(x, edge_index)``; trained in place.
        data: clean graph handed to ``inject_trigger``, which edits a clone of it.
        optimizer: optimizer over ``model``'s parameters.
        poisoned_nodes: node indices handed to ``inject_trigger``.
        trigger_gen, attack, ood_detector, alpha: forwarded to ``inject_trigger``.
        early_stopping: despite its name, this only enables printing the loss every
            10 epochs; training always runs for ``epochs`` epochs.
        epochs: number of full-batch training epochs. The default of 100 is the
            previously hard-coded value.

    Returns:
        ``(model, data_poisoned)``.
    """
    data_poisoned = inject_trigger(data, poisoned_nodes, attack, trigger_gen, ood_detector, alpha)
    train_mask = data_poisoned.train_mask

    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        out = model(data_poisoned.x, data_poisoned.edge_index)
        # cross_entropy applies log_softmax internally. That is harmless for models
        # that already return log-probabilities, because log_softmax is idempotent.
        loss = F.cross_entropy(out[train_mask], data_poisoned.y[train_mask])
        loss.backward()
        optimizer.step()

        if early_stopping and epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    return model, data_poisoned



from sklearn.cluster import KMeans

class GCNEncoder(torch.nn.Module):
    """Two-layer GCN encoder: ``in_channels -> 2*out_channels (ReLU) -> out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 2 * out_channels
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return ``out_channels``-dimensional node embeddings."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

def select_diverse_nodes(data, num_nodes_to_select, num_clusters=None):
    """Pick up to ``num_nodes_to_select`` nodes spread across the embedding space.

    Steps:
      1. Embed the nodes with a freshly initialised, untrained ``GCNEncoder``.
      2. Cluster the embeddings with K-means into ``num_clusters`` groups. The
         default is the number of distinct values in ``data.y``.
      3. Take the node closest to each centroid.
      4. Append the ``len(centroid_nodes) // 2`` nodes with the highest out-degree
         as extra candidates.

    Candidates are de-duplicated in priority order (centroid nodes first) and then
    truncated. Note that fewer than ``num_nodes_to_select`` nodes may be returned,
    because at most ``num_clusters + num_clusters // 2`` candidates are generated.

    Returns:
        LongTensor of node indices on ``data.x.device``.
    """
    if num_clusters is None:
        num_clusters = len(torch.unique(data.y))
    dev = data.x.device

    encoder = GCNEncoder(data.num_features, out_channels=16).to(dev)
    encoder.eval()
    with torch.no_grad():
        embeddings = encoder(data.x, data.edge_index)

    # K-means runs on CPU via scikit-learn.
    kmeans = KMeans(n_clusters=num_clusters, random_state=42).fit(embeddings.cpu().numpy())
    labels = torch.as_tensor(kmeans.labels_, device=dev)  # built once, not per cluster
    cluster_centers = torch.as_tensor(kmeans.cluster_centers_, device=dev)

    centroid_nodes = []
    for cluster_id in range(num_clusters):
        members = torch.where(labels == cluster_id)[0]
        if members.numel() == 0:
            continue  # empty cluster: argmin over an empty tensor would raise
        distances = torch.norm(embeddings[members] - cluster_centers[cluster_id], dim=1)
        centroid_nodes.append(members[torch.argmin(distances)])

    degree = torch.bincount(data.edge_index[0], minlength=data.num_nodes).to(dev)
    high_degree_nodes = torch.topk(degree, len(centroid_nodes) // 2).indices

    if centroid_nodes:
        centroid_tensor = torch.stack(centroid_nodes)
    else:
        centroid_tensor = torch.empty(0, dtype=torch.long, device=dev)

    candidates = torch.cat([centroid_tensor, high_degree_nodes.to(dev)]).tolist()
    # De-duplicate while keeping priority order. torch.unique() would sort by node
    # id, so the truncation would drop the highest-numbered nodes instead of the
    # lowest-priority ones.
    ordered = list(dict.fromkeys(candidates))[:num_nodes_to_select]
    return torch.tensor(ordered, dtype=torch.long, device=dev)

def _in_neighbors(data, nodes, device):
    """For each node in ``nodes``, return the sources of edges pointing into it.

    Duplicated edges give duplicated entries. Each tensor is moved to ``device``.
    """
    src, dst = data.edge_index[0], data.edge_index[1]
    return [src[dst == node].to(device) for node in nodes]


def _mean_features(data, neighbor_lists, device):
    """Stack the mean ``data.x`` row of each neighbour list into one tensor on ``device``.

    Empty neighbour lists fall back to the mean over all nodes.
    """
    return torch.stack([
        data.x[nbrs].mean(dim=0) if len(nbrs) > 0 else data.x.mean(dim=0)
        for nbrs in neighbor_lists
    ]).to(device)


def _pairs_to_edge_index(pairs, nodes, device):
    """Map local index pairs ``(i, j)`` onto ``(nodes[i], nodes[j])`` as a (2, E) edge_index.

    Always returns a 2-row long tensor, including when ``pairs`` is empty.
    """
    if len(pairs) == 0:
        return torch.empty((2, 0), dtype=torch.long, device=device)
    local = torch.tensor(pairs, dtype=torch.long, device=device).t()
    return nodes.to(device)[local]


def inject_trigger(data, poisoned_nodes, attack_type, trigger_gen=None, ood_detector=None, alpha=0.7, trigger_size=5, trigger_density=0.5, input_dim=None):
    """Return a poisoned copy of ``data`` for backdoor attack ``attack_type``.

    ``data`` itself is never modified; ``data.clone()`` is edited instead.
    "In-neighbour mean" below means the mean feature of the sources of edges
    pointing into a node, or the global feature mean if there are none.

    Attack types:
      * 'SBA-Samp': applies to the first ``trigger_size`` poisoned nodes (the
        trigger nodes). Each gets its in-neighbour mean plus N(0, 0.02^2) noise.
        The trigger nodes are wired together by an Erdos-Renyi graph with edge
        probability ``trigger_density``, and each gets one edge to a uniformly
        random node.
      * 'SBA-Gen': like SBA-Samp, but with N(0, 0.03^2) noise. Trigger nodes i < j
        are connected with probability exp(-||(f_i - f_j) / std||^2), where std
        is the per-feature standard deviation of ``data.x``.
      * 'DPGBA': every poisoned node gets a random mix of its own features (weight
        in [0.5, 0.8)) and ``trigger_gen(in-neighbour mean)``. Requires ``trigger_gen``.
      * 'GTA': every poisoned node gets its in-neighbour mean plus N(0, 0.05^2) noise.
      * 'UGBA': ``select_diverse_nodes`` picks ``len(poisoned_nodes)`` nodes. Each
        gets its in-neighbour mean plus N(2, 0.5^2) noise, and one extra edge.
        NOTE(review): the modified nodes are the diverse nodes, not ``poisoned_nodes``.
      * Any other value returns the unmodified clone.

    ``ood_detector``, ``alpha`` and ``input_dim`` are accepted for interface
    compatibility but are not used.

    Raises:
        ValueError: if ``poisoned_nodes`` is empty, if DPGBA is requested without
            ``trigger_gen``, or if the generator's output width differs from ``data.x``.
    """
    data_poisoned = data.clone()
    device = data_poisoned.x.device

    if len(poisoned_nodes) == 0:
        raise ValueError("No poisoned nodes selected. Ensure 'poisoned_nodes' is populated and non-empty.")

    # Adjust trigger_size if it exceeds the number of poisoned nodes
    trigger_size = min(trigger_size, len(poisoned_nodes))

    if attack_type == 'SBA-Samp':
        trigger_nodes = poisoned_nodes[:trigger_size]
        avg_features = _mean_features(data, _in_neighbors(data, trigger_nodes, device), device)
        natural_features = avg_features + torch.randn_like(avg_features) * 0.02

        G = nx.erdos_renyi_graph(trigger_size, trigger_density)
        # The ER graph uses local ids 0..trigger_size-1. Map them onto the trigger
        # nodes; previously they were wired between graph nodes 0..trigger_size-1.
        trigger_edge_index = _pairs_to_edge_index(list(G.edges), trigger_nodes, device)

        poisoned_edges = torch.stack((
            trigger_nodes,
            torch.randint(0, data.num_nodes, (trigger_size,), device=device)
        ))

        data_poisoned.edge_index = torch.cat([data.edge_index, trigger_edge_index, poisoned_edges], dim=1)
        data_poisoned.x[trigger_nodes] = natural_features

    elif attack_type == 'SBA-Gen':
        trigger_nodes = poisoned_nodes[:trigger_size]
        avg_features = _mean_features(data, _in_neighbors(data, trigger_nodes, device), device)
        natural_features = avg_features + torch.normal(mean=0.0, std=0.03, size=avg_features.shape, device=device)
        # Constant feature columns have std 0; clamp to avoid 0/0 -> NaN similarities.
        feature_std = data.x.std(dim=0).clamp_min(1e-8)

        pairs = []
        for i in range(trigger_size):
            for j in range(i + 1, trigger_size):
                similarity = torch.exp(-torch.norm((natural_features[i] - natural_features[j]) / feature_std) ** 2)
                if similarity > torch.rand(1).item():
                    pairs.append([i, j])

        # Map local pair indices onto trigger node ids. The previous
        # ``+= poisoned_nodes.unsqueeze(0)`` broadcast (2, E) against (1, T); that
        # raised unless E == T and never produced the intended mapping.
        trigger_edge_index = _pairs_to_edge_index(pairs, trigger_nodes, device)

        poisoned_edges = torch.stack((
            trigger_nodes,
            torch.randint(0, data.num_nodes, (trigger_size,), device=device)
        ))

        data_poisoned.edge_index = torch.cat([data.edge_index, trigger_edge_index, poisoned_edges], dim=1)
        data_poisoned.x[trigger_nodes] = natural_features

    elif attack_type == 'DPGBA':
        if trigger_gen is None:
            raise ValueError("Trigger generator is required for the DPGBA attack.")

        avg_features = _mean_features(data, _in_neighbors(data, poisoned_nodes, device), device)
        with torch.no_grad():
            trigger_features = trigger_gen(avg_features)

        if trigger_features.shape[1] != data.x.shape[1]:
            raise ValueError(f"Trigger feature dimension mismatch: {trigger_features.shape[1]} vs {data.x.shape[1]}")

        # Per-node mixing weight on the original features, drawn from U[0.5, 0.8).
        node_alphas = torch.rand(len(poisoned_nodes), device=device) * 0.3 + 0.5
        distribution_preserved_features = (
            node_alphas.unsqueeze(1) * data.x[poisoned_nodes]
            + (1 - node_alphas.unsqueeze(1)) * trigger_features
        )
        data_poisoned.x[poisoned_nodes] = distribution_preserved_features

    elif attack_type == 'GTA':
        avg_features = _mean_features(data, _in_neighbors(data, poisoned_nodes, device), device)
        data_poisoned.x[poisoned_nodes] = avg_features + torch.randn_like(avg_features) * 0.05

    elif attack_type == 'UGBA':
        diverse_nodes = select_diverse_nodes(data_poisoned, len(poisoned_nodes)).to(device)
        connected_nodes = _in_neighbors(data_poisoned, diverse_nodes, device)
        avg_features = _mean_features(data_poisoned, connected_nodes, device)
        refined_trigger_features = avg_features + torch.normal(mean=2.0, std=0.5, size=avg_features.shape, device=device)
        data_poisoned.x[diverse_nodes] = refined_trigger_features

        # One extra edge per node: to its first in-neighbour, or else to the next diverse node.
        new_edges = []
        for i in range(len(diverse_nodes)):
            neighbors = connected_nodes[i]
            if len(neighbors) > 0:
                target = neighbors[0]
            else:
                target = diverse_nodes[(i + 1) % len(diverse_nodes)]
            new_edges.append([int(diverse_nodes[i]), int(target)])

        new_edges = torch.tensor(new_edges, dtype=torch.long, device=device).reshape(-1, 2).t()
        data_poisoned.edge_index = torch.cat([data_poisoned.edge_index, new_edges], dim=1)

    return data_poisoned


def dominant_set_clustering(data, threshold=0.7, use_pca=True, pca_components=10):
    """
    Flag feature-space outliers with K-Means and neutralise them in place.

    Despite the name, this is not dominant-set clustering. Nodes are clustered by
    K-Means, with one cluster per distinct value in ``data.y``, optionally after
    PCA. Nodes whose distance to their own centroid exceeds the
    ``100 * threshold`` percentile are flagged. With the default 0.7, roughly
    the farthest 30% are flagged.

    Parameters:
    - data: PyG data object representing the graph (modified in place).
    - threshold: quantile in [0, 1] used as the distance cut-off.
    - use_pca: whether to reduce features with PCA before clustering.
    - pca_components: PCA output dimension; PCA is only applied when the feature
      dimension is larger than this.

    Returns:
    - pruned_nodes: set of flagged node indices.
    - data: the same object. For flagged nodes, ``y`` is set to -1 and ``x`` is
      replaced by the feature mean taken before the replacement.
    """
    # NOTE(review): if labels were already set to -1, that value counts as a class here.
    n_clusters = len(data.y.unique())

    feats = data.x.detach().cpu().numpy()
    if use_pca and feats.shape[1] > pca_components:
        feats = PCA(n_components=pca_components).fit_transform(feats)

    km = KMeans(n_clusters=n_clusters, random_state=42).fit(feats)
    dist_to_center = np.linalg.norm(feats - km.cluster_centers_[km.labels_], axis=1)

    cutoff = np.percentile(dist_to_center, 100 * threshold)
    pruned_nodes = set(np.where(dist_to_center > cutoff)[0])

    if pruned_nodes:
        outliers = torch.tensor(list(pruned_nodes), dtype=torch.long, device=data.x.device)
        data.y[outliers] = -1
        # The right-hand side is evaluated before assignment, so this is the pre-replacement mean.
        data.x[outliers] = data.x.mean(dim=0).to(data.x.device)

    return pruned_nodes, data

def defense_prune_edges(data, quantile_threshold=0.9):
    """
    Prune the most dissimilar edges, measured by cosine similarity of endpoint features.

    Parameters:
    - data: PyG data object representing the graph (modified in place).
    - quantile_threshold: fraction of edges to KEEP. With the default 0.9, the 10%
      of edges whose endpoints are least similar are removed.

    Returns:
    - data: the same object, with ``edge_index`` restricted to the kept edges.
      A graph without edges is returned unchanged.
    """
    edge_index = data.edge_index
    if edge_index.size(1) == 0:
        return data  # nothing to prune, and torch.quantile rejects empty input

    norm_features = F.normalize(data.x, p=2, dim=1)
    src, dst = edge_index[0], edge_index[1]
    cosine_similarities = torch.sum(norm_features[src] * norm_features[dst], dim=1)

    # The cut-off is the (1 - q) quantile, so that a fraction q of edges survives.
    # Using the q quantile kept only the top (1 - q): with q = 0.9 it discarded 90%
    # of the graph instead of the documented 10%.
    similarity_threshold = torch.quantile(cosine_similarities, 1.0 - quantile_threshold).item()

    data.edge_index = edge_index[:, cosine_similarities >= similarity_threshold]
    return data



def defense_prune_and_discard_labels(data, quantile_threshold=0.2):
    """
    Prune low-similarity edges and discard the labels of nodes that lost many of them.

    Parameters:
    - data: PyG data object representing the graph (``edge_index`` and ``y`` are
      modified in place).
    - quantile_threshold: edges whose endpoint cosine similarity is below this
      quantile are removed. With the default 0.2, roughly the bottom 20% are removed.

    Returns:
    - data: the same object with pruned edges. Nodes touching more pruned edges than
      the median per-node count get ``y = -1``. A graph without edges is returned
      unchanged.

    NOTE(review): PyTorch's cross_entropy/nll_loss ignore index -100 by default,
    so -1 labels must be masked out before training on this data.
    """
    edge_index = data.edge_index
    if edge_index.size(1) == 0:
        return data  # nothing to prune, and torch.quantile rejects empty input

    norm_features = F.normalize(data.x, p=2, dim=1)
    src, dst = edge_index[0], edge_index[1]
    cosine_similarities = torch.sum(norm_features[src] * norm_features[dst], dim=1)

    adaptive_threshold = torch.quantile(cosine_similarities, quantile_threshold).item()
    pruned_mask = cosine_similarities < adaptive_threshold
    data.edge_index = edge_index[:, ~pruned_mask]

    # Count, per node, how many pruned edges it was an endpoint of.
    pruned_src, pruned_dst = edge_index[:, pruned_mask]
    pruned_nodes_count = torch.bincount(torch.cat([pruned_src, pruned_dst]), minlength=data.num_nodes)

    # torch.median returns the lower median for even-length input.
    threshold_count = int(torch.median(pruned_nodes_count).item())
    nodes_to_discard = torch.where(pruned_nodes_count > threshold_count)[0]
    data.y[nodes_to_discard] = -1

    return data

# Compute ASR and Clean Accuracy (using .detach() to avoid retaining computation graph)
def compute_metrics(model, data, poisoned_nodes, target_label=None):
    """Evaluate ``model`` on ``data`` and return ``(asr, clean_acc)`` as percentages.

    Args:
        model: classifier called as ``model(x, edge_index)``; switched to eval mode.
        data: graph with ``x``, ``edge_index``, ``y`` and ``test_mask``.
        poisoned_nodes: node indices the attack was applied to.
        target_label: optional attack target class.
            - If ``None`` (the default and the historical behaviour), ``asr`` is the
              percentage of ``poisoned_nodes`` predicted as their own ``data.y``
              label. That is accuracy on the poisoned nodes, NOT a true attack
              success rate.
            - If given, ``asr`` is the percentage of poisoned nodes predicted as
              ``target_label``.

    Returns:
        ``(asr, clean_acc)``. ``clean_acc`` is the accuracy over ``data.test_mask``.
        Either value is ``nan`` when its node set is empty (previously an empty
        ``poisoned_nodes`` raised ZeroDivisionError).
    """
    model.eval()
    with torch.no_grad():
        pred = model(data.x, data.edge_index).argmax(dim=1)

    num_poisoned = len(poisoned_nodes)
    if num_poisoned == 0:
        asr = float("nan")
    else:
        reference = data.y[poisoned_nodes] if target_label is None else target_label
        hits = (pred[poisoned_nodes] == reference).sum().item()
        asr = hits / num_poisoned * 100

    test_mask = data.test_mask
    num_test = int(test_mask.sum().item())
    if num_test == 0:
        clean_acc = float("nan")
    else:
        correct = (pred[test_mask] == data.y[test_mask]).sum().item()
        clean_acc = correct / num_test * 100

    return asr, clean_acc



# Visualization Function
# Visualize PCA for Attacks
# Added function to visualize PCA projections of node embeddings for different attacks
def visualize_pca_for_attacks(attack_embeddings_dict):
    """Scatter-plot a 2-D PCA projection of node embeddings for each attack.

    ``attack_embeddings_dict`` maps an attack name to a dict with two entries:
      * ``'data'``: a 2-D tensor of node embeddings or features.
      * ``'poisoned_nodes'``: an index tensor.
    Poisoned nodes are drawn in red and all others in blue. PCA is refit for each
    attack.

    Subplots are laid out three per row. Up to six attacks use the original 2x3
    grid; beyond that, rows are added. The fixed grid previously raised as soon as
    a seventh attack was passed.
    """
    ncols = 3
    nrows = max(2, -(-len(attack_embeddings_dict) // ncols))  # ceil-divide, at least 2 rows
    pca = PCA(n_components=2)
    plt.figure(figsize=(20, 5 * nrows))

    for i, (attack, attack_data) in enumerate(attack_embeddings_dict.items(), 1):
        embeddings = attack_data['data'].detach().cpu().numpy()
        poisoned_nodes = attack_data['poisoned_nodes'].detach().cpu().numpy()

        pca_result = pca.fit_transform(embeddings)

        clean_mask = np.ones(embeddings.shape[0], dtype=bool)
        clean_mask[poisoned_nodes] = False
        clean_embeddings = pca_result[clean_mask]
        poisoned_embeddings = pca_result[~clean_mask]

        plt.subplot(nrows, ncols, i)
        plt.scatter(clean_embeddings[:, 0], clean_embeddings[:, 1], s=10, alpha=0.5, label='Clean Nodes', c='b')
        plt.scatter(poisoned_embeddings[:, 0], poisoned_embeddings[:, 1], s=10, alpha=0.8, label='Poisoned Nodes', c='r')
        plt.title(f'PCA Visualization for {attack}')
        plt.xlabel('PCA Component 1')
        plt.ylabel('PCA Component 2')
        plt.legend()

    plt.tight_layout()
    plt.show()



import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid


class SubstructureAwareGNN(nn.Module):
    """Node classifier combining ego-subgraph, cut-subgraph and global views.

    Three branches each produce ``hidden_channels`` features:
      * ego: mean features over each node's 2-hop in-neighbourhood, followed by a
        ``MessagePassingLayer``;
      * cut: mean features of each node's out-neighbours after randomly dropping
        half of the edges, followed by a ``MessagePassingLayer``;
      * global: a linear projection of the raw features.
    The concatenation is mapped to ``out_channels`` log-probabilities.

    NOTE(review): the cut branch draws fresh random edge scores on every forward
    pass, including in eval mode. Predictions are therefore only reproducible
    when the torch RNG is seeded.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super(SubstructureAwareGNN, self).__init__()
        self.ego_gnn = MessagePassingLayer(in_channels, hidden_channels)
        self.cut_gnn = MessagePassingLayer(in_channels, hidden_channels)
        self.global_encoder = nn.Linear(in_channels, hidden_channels)
        self.final_fc = nn.Linear(3 * hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return log-probabilities (``log_softmax`` along dim 1) for every node."""
        ego_features = self.extract_ego_subgraph(x, edge_index)
        cut_features = self.extract_cut_subgraph(x, edge_index)

        ego_encoded = self.ego_gnn(ego_features, edge_index)
        cut_encoded = self.cut_gnn(cut_features, edge_index)
        global_encoded = self.global_encoder(x)

        combined_features = torch.cat([ego_encoded, cut_encoded, global_encoded], dim=-1)
        output = self.final_fc(combined_features)
        return F.log_softmax(output, dim=1)

    def extract_ego_subgraph(self, x, edge_index, k=2):
        """Mean of ``x`` over each node's k-hop neighbourhood, including the node itself.

        For every node ``i`` this equals ``x[k_hop_subgraph(i, k, edge_index)[0]].mean(0)``
        with PyG's default ``source_to_target`` flow, i.e. following edges backwards
        from targets to their sources. The previous implementation called
        ``k_hop_subgraph`` once per node, making each forward pass O(N * E). This
        version computes all neighbourhoods at once with sparse matrix products.
        """
        num_nodes = x.size(0)
        dev = x.device

        def as_matrix(index):
            # (num_nodes x num_nodes) sparse matrix with a positive entry at each (row, col).
            values = torch.ones(index.size(1), dtype=x.dtype, device=dev)
            return torch.sparse_coo_tensor(index, values, (num_nodes, num_nodes)).coalesce()

        # adjacency[t, s] > 0 for every edge s -> t.
        adjacency = as_matrix(edge_index.flip(0))
        diag = torch.arange(num_nodes, device=dev).repeat(2, 1)

        frontier = as_matrix(diag)   # frontier[c, m] > 0: m was reached from centre c at this hop
        reach_parts = [diag]
        for _ in range(k):
            # Re-binarise each hop so the values stay small (they only mark membership).
            frontier = as_matrix(torch.sparse.mm(frontier, adjacency).coalesce().indices())
            reach_parts.append(frontier.indices())

        # Unique (centre, member) pairs; every centre reaches itself, so counts >= 1.
        reach_idx = as_matrix(torch.cat(reach_parts, dim=1)).indices()
        rows = reach_idx[0]
        counts = torch.bincount(rows, minlength=num_nodes).to(x.dtype)
        mean_op = torch.sparse_coo_tensor(reach_idx, 1.0 / counts[rows], (num_nodes, num_nodes))
        return torch.sparse.mm(mean_op, x)

    def extract_cut_subgraph(self, x, edge_index):
        """Mean out-neighbour features after randomly dropping half of the edges.

        Each edge gets a uniform random score, standing in for an edge-betweenness
        approximation. The ``E // 2`` highest-scoring edges are removed. Each node
        then takes the mean ``x`` of its remaining out-neighbours, with duplicate
        edges counted. Nodes with no remaining out-edges keep their own features.
        The computation is vectorised with scatter-adds; the previous version
        scanned the edge list once per node.
        """
        num_edges = edge_index.size(1)
        edge_weights = torch.rand(num_edges, device=edge_index.device)  # Replace with actual edge weights
        _, sorted_indices = edge_weights.sort(descending=True)
        keep = torch.ones(num_edges, dtype=torch.bool, device=edge_index.device)
        keep[sorted_indices[:num_edges // 2]] = False
        src, dst = edge_index[:, keep]

        num_nodes = x.size(0)
        neighbor_sums = torch.zeros_like(x).index_add_(0, src, x[dst])
        neighbor_counts = torch.bincount(src, minlength=num_nodes)

        cut_features = x.clone()
        has_neighbors = neighbor_counts > 0
        cut_features[has_neighbors] = neighbor_sums[has_neighbors] / neighbor_counts[has_neighbors].unsqueeze(1).to(x.dtype)
        return cut_features


class MessagePassingLayer(MessagePassing):
    """Linear projection, then sum-aggregation of neighbour messages, then ReLU.

    No self-loops are added, so a node's own features only contribute if
    ``edge_index`` already contains a self-loop. Direction follows PyG's default
    ``source_to_target`` flow.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr="add")
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        projected = self.linear(x)
        return self.propagate(edge_index, x=projected)

    # NOTE: PyG inspects parameter names; ``x_j`` selects source-node features.
    def message(self, x_j):
        return x_j

    def update(self, aggr_out):
        return aggr_out.relu()

def run_sagn_attacks(datasets=None, dataset_budgets=None, output_csv="sagn_backdoor_attack_results_summary.csv"):
    """Train and evaluate SAGN on each dataset, clean and under each backdoor attack.

    For every dataset, a fresh SAGN model is trained for each setting: no attack,
    then 'SBA-Samp', 'SBA-Gen', 'GTA', 'UGBA' and 'DPGBA'. One summary line per
    setting is printed, and all lines are written to ``output_csv``.

    Args:
        datasets: dataset names to run. Default: Cora, PubMed, CiteSeer.
        dataset_budgets: mapping from dataset name to number of poisoned nodes.
            Default: Cora 10, PubMed 30, CiteSeer 20; names not listed use 10.
        output_csv: path of the CSV summary file.

    A failure in one attack or dataset is reported, with its traceback, and
    skipped so that the remaining runs still execute.
    """
    import traceback

    if datasets is None:
        datasets = ["Cora", "PubMed", "CiteSeer"]
    if dataset_budgets is None:
        dataset_budgets = {'Cora': 10, 'PubMed': 30, 'CiteSeer': 20}  # Poisoning budgets
    results_summary = []

    for dataset_name in datasets:
        try:
            dataset = load_dataset(dataset_name)
            data = dataset[0].to(device)
            input_dim = data.num_features
            output_dim = dataset.num_classes if isinstance(dataset.num_classes, int) else dataset.num_classes[0]
            data = split_dataset(data)
            poisoned_node_budget = dataset_budgets.get(dataset_name, 10)

            print(f"Training SAGN model for dataset {dataset_name}.")
            for attack in [None, 'SBA-Samp', 'SBA-Gen', 'GTA', 'UGBA', 'DPGBA']:
                try:
                    # Fresh model and optimizer for every setting.
                    model = SubstructureAwareGNN(in_channels=input_dim, hidden_channels=64, out_channels=output_dim).to(device)
                    optimizer = torch.optim.Adam(model.parameters(), lr=0.002)

                    if attack is None:
                        # Baseline: train on the clean graph.
                        model.train()
                        for epoch in range(100):
                            optimizer.zero_grad()
                            out = model(data.x, data.edge_index)
                            loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
                            loss.backward()
                            optimizer.step()
                            print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")

                        model.eval()
                        with torch.no_grad():
                            out = model(data.x, data.edge_index)
                            predictions = out.argmax(dim=1)
                            baseline_acc = (predictions[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
                        print(f"Dataset: {dataset_name}, Model: SAGN, Baseline Accuracy: {baseline_acc * 100:.2f}%")

                        results_summary.append(f"Dataset: {dataset_name}, Model: SAGN, Attack: None, Defense: None - ASR: N/A, Clean Accuracy: {baseline_acc * 100:.2f}%")
                    else:
                        trigger_gen = TriggerGenerator(input_dim=data.num_features, hidden_dim=64).to(device)
                        ood_detector = OODDetector(input_dim=input_dim, hidden_dim=64, latent_dim=16).to(device)
                        ood_optimizer = torch.optim.Adam(ood_detector.parameters(), lr=0.001)
                        train_ood_detector(ood_detector, data, ood_optimizer, epochs=10)

                        poisoned_nodes = select_high_centrality_nodes(data, poisoned_node_budget)

                        trained_model, data_poisoned = train_with_poisoned_data(
                            model=model,
                            data=data,
                            optimizer=optimizer,
                            poisoned_nodes=poisoned_nodes,
                            trigger_gen=trigger_gen,
                            attack=attack,
                            ood_detector=ood_detector,
                            alpha=0.7,
                            early_stopping=True
                        )

                        asr, clean_acc = compute_metrics(trained_model, data_poisoned, poisoned_nodes)
                        results_summary.append(f"Dataset: {dataset_name}, Model: SAGN, Attack: {attack}, Defense: None - ASR: {asr:.2f}%, Clean Accuracy: {clean_acc:.2f}%")
                        print(f"Dataset: {dataset_name}, Model: SAGN, Attack: {attack}, Defense: None - ASR: {asr:.2f}%, Clean Accuracy: {clean_acc:.2f}%")
                except Exception as e:
                    # Keep going with the next attack, but don't hide where it failed.
                    print(f"Error during attack {attack} on {dataset_name} with SAGN: {e}")
                    traceback.print_exc()

        except Exception as e:
            print(f"Error processing dataset {dataset_name}: {e}")
            traceback.print_exc()

    # Save results
    results_df = pd.DataFrame(results_summary, columns=["Summary"])
    print("\nSummary of Results:")
    for result in results_summary:
        print(result)
    results_df.to_csv(output_csv, index=False)

run_sagn_attacks()


Using device: cuda
Training SAGN model for dataset Cora.
Epoch 1, Loss: 1.9323
Epoch 2, Loss: 1.8058
Epoch 3, Loss: 1.6905
Epoch 4, Loss: 1.5731
Epoch 5, Loss: 1.4539
Epoch 6, Loss: 1.3397
Epoch 7, Loss: 1.2272
Epoch 8, Loss: 1.1232
Epoch 9, Loss: 1.0292
Epoch 10, Loss: 0.9395
Epoch 11, Loss: 0.8584
Epoch 12, Loss: 0.7816
Epoch 13, Loss: 0.7182
Epoch 14, Loss: 0.6564
Epoch 15, Loss: 0.6051
Epoch 16, Loss: 0.5578
Epoch 17, Loss: 0.5111
Epoch 18, Loss: 0.4768
Epoch 19, Loss: 0.4383
Epoch 20, Loss: 0.4071
Epoch 21, Loss: 0.3866
Epoch 22, Loss: 0.3547
Epoch 23, Loss: 0.3292
Epoch 24, Loss: 0.3117
Epoch 25, Loss: 0.2967
Epoch 26, Loss: 0.2766
Epoch 27, Loss: 0.2608
Epoch 28, Loss: 0.2384
Epoch 29, Loss: 0.2294
Epoch 30, Loss: 0.2166
Epoch 31, Loss: 0.2059
Epoch 32, Loss: 0.1887
Epoch 33, Loss: 0.1787
Epoch 34, Loss: 0.1739
Epoch 35, Loss: 0.1596
Epoch 36, Loss: 0.1464
Epoch 37, Loss: 0.1405
Epoch 38, Loss: 0.1322
Epoch 39, Loss: 0.1271
Epoch 40, Loss: 0.1223
Epoch 41, Loss: 0.1109
Epoch 42,

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.test.index
Processing...
Done!


Training SAGN model for dataset CiteSeer.
Epoch 1, Loss: 1.7962
Epoch 2, Loss: 1.6370
Epoch 3, Loss: 1.4876
Epoch 4, Loss: 1.3490
Epoch 5, Loss: 1.2271
Epoch 6, Loss: 1.1205
Epoch 7, Loss: 1.0242
Epoch 8, Loss: 0.9387
Epoch 9, Loss: 0.8607
Epoch 10, Loss: 0.7921
Epoch 11, Loss: 0.7326
Epoch 12, Loss: 0.6752
Epoch 13, Loss: 0.6256
Epoch 14, Loss: 0.5829
Epoch 15, Loss: 0.5422
Epoch 16, Loss: 0.5081
Epoch 17, Loss: 0.4709
Epoch 18, Loss: 0.4406
Epoch 19, Loss: 0.4105
Epoch 20, Loss: 0.3853
Epoch 21, Loss: 0.3565
Epoch 22, Loss: 0.3322
Epoch 23, Loss: 0.3055
Epoch 24, Loss: 0.2826
Epoch 25, Loss: 0.2607
Epoch 26, Loss: 0.2394
Epoch 27, Loss: 0.2196
Epoch 28, Loss: 0.2030
Epoch 29, Loss: 0.1859
Epoch 30, Loss: 0.1694
Epoch 31, Loss: 0.1546
Epoch 32, Loss: 0.1437
Epoch 33, Loss: 0.1293
Epoch 34, Loss: 0.1205
Epoch 35, Loss: 0.1097
Epoch 36, Loss: 0.0989
Epoch 37, Loss: 0.0920
Epoch 38, Loss: 0.0831
Epoch 39, Loss: 0.0767
Epoch 40, Loss: 0.0699
Epoch 41, Loss: 0.0659
Epoch 42, Loss: 0.0608
E