<a href="https://colab.research.google.com/github/AbhiJeet70/PowerfulGNNs/blob/main/Attack_SUN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:


# Install necessary packages
!pip install torch-geometric
!pip install matplotlib
!pip install scikit-learn

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import subgraph
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
import networkx as nx
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from torch_geometric.datasets import Planetoid, Flickr
from sklearn.decomposition import PCA
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import networkx as nx
from sklearn.neighbors import NearestNeighbors
import pandas as pd
from torch_geometric.utils import to_networkx
from torch_geometric.nn.models.autoencoder import GAE

import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")

# Set random seeds for reproducibility
def set_seed(seed=42):
    """
    Seed the Python, NumPy and PyTorch random number generators.

    Parameters:
    - seed: Integer seed handed to every generator (default 42).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # The CUDA generator is only seeded explicitly when a GPU is present.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)

# Seed every RNG once, with the default seed, before any stochastic code runs.
set_seed()

# Check if GPU is available and set device
# `device` is a module-level global that helper functions elsewhere in this
# file read directly instead of taking it as a parameter.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


# Load datasets
def load_dataset(dataset_name):
    """
    Return the PyG dataset object registered under `dataset_name`.

    Parameters:
    - dataset_name: One of "Cora", "PubMed", "CiteSeer" (Planetoid) or "Flickr".

    Returns:
    - The dataset, stored under ./data/<name>.

    Raises:
    - ValueError: If `dataset_name` is not one of the supported names.
    """
    planetoid_names = ("Cora", "PubMed", "CiteSeer")
    if dataset_name in planetoid_names:
        return Planetoid(root=f"./data/{dataset_name}", name=dataset_name)
    if dataset_name == "Flickr":
        return Flickr(root="./data/Flickr")
    raise ValueError(f"Unsupported dataset: {dataset_name}")


# Split dataset into train/validation/test
# Defaults: 70% train, 10% validation, 20% test. The test nodes are further
# halved into attack targets and a clean evaluation set.
def split_dataset(data, test_size=0.2, val_size=0.1):
    """
    Attach random train/val/test masks, plus target/clean-test masks, to `data`.

    The node order is shuffled with NumPy's global RNG (so `np.random.seed`
    controls the split). The first block of the permutation becomes the training
    set, the next the validation set and the remainder the test set. The test
    set is then halved: the first half is `target_mask` (attack evaluation) and
    the second half is `clean_test_mask`.

    Parameters:
    - data: Graph object exposing `num_nodes` (and normally `x`); modified in place.
    - test_size: Fraction of nodes assigned to the test set.
    - val_size: Fraction of nodes assigned to the validation set.

    Returns:
    - data: The same object, with `train_mask`, `val_mask`, `test_mask`,
      `target_mask` and `clean_test_mask` set as boolean tensors.

    Raises:
    - ValueError: If `test_size + val_size` leaves a negative number of training nodes.
    """
    num_nodes = data.num_nodes

    # Build the masks on the same device as the node features so that they can
    # index model outputs without a device mismatch.
    features = getattr(data, "x", None)
    mask_device = features.device if features is not None else torch.device("cpu")

    indices = np.arange(num_nodes)
    np.random.shuffle(indices)
    perm = torch.from_numpy(indices).long().to(mask_device)

    num_test = int(test_size * num_nodes)
    num_val = int(val_size * num_nodes)
    num_train = num_nodes - num_test - num_val
    if num_train < 0:
        raise ValueError(
            f"test_size ({test_size}) + val_size ({val_size}) exceed the number of available nodes."
        )

    def _mask_of(node_ids):
        mask = torch.zeros(num_nodes, dtype=torch.bool, device=mask_device)
        mask[node_ids] = True
        return mask

    test_start = num_train + num_val
    data.train_mask = _mask_of(perm[:num_train])
    data.val_mask = _mask_of(perm[num_train:test_start])
    data.test_mask = _mask_of(perm[test_start:])

    # Half of the test partition is reserved for attack evaluation. Deriving the
    # count from num_test keeps the halves balanced for any test_size, not only
    # the default 0.2.
    num_target = num_test // 2
    data.target_mask = _mask_of(perm[test_start:test_start + num_target])
    data.clean_test_mask = _mask_of(perm[test_start + num_target:])

    return data

# Define GNN Model with multiple architectures (GCN, GraphSAGE, GAT)
class GNN(torch.nn.Module):
    """
    Two-layer graph neural network with a selectable convolution type.

    Parameters:
    - input_dim: Number of input features per node.
    - hidden_dim: Width of the hidden layer (per attention head for 'GAT').
    - output_dim: Number of output channels (classes).
    - model_type: 'GCN', 'GraphSage' or 'GAT'.

    Raises:
    - ValueError: If `model_type` is not one of the supported names.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, model_type='GCN'):
        super().__init__()
        if model_type == 'GCN':
            self.conv1 = GCNConv(input_dim, hidden_dim)
            self.conv2 = GCNConv(hidden_dim, output_dim)
        elif model_type == 'GraphSage':
            self.conv1 = SAGEConv(input_dim, hidden_dim)
            self.conv2 = SAGEConv(hidden_dim, output_dim)
        elif model_type == 'GAT':
            # Eight concatenated heads feed a single averaged output head.
            self.conv1 = GATConv(input_dim, hidden_dim, heads=8, concat=True)
            self.conv2 = GATConv(hidden_dim * 8, output_dim, heads=1, concat=False)
        else:
            # Fail at construction time instead of producing a model without
            # layers that only breaks on the first forward pass.
            raise ValueError(
                f"Unsupported model_type: {model_type!r}; expected 'GCN', 'GraphSage' or 'GAT'."
            )

    def forward(self, x, edge_index):
        """Return the raw (un-normalised) output of the second convolution."""
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

# Select nodes to poison based on high-centrality (degree centrality) for a stronger impact
def select_high_centrality_nodes(data, num_nodes_to_select):
    """
    Return the nodes with the highest degree centrality.

    Parameters:
    - data: Graph object whose `edge_index` lists the edges.
    - num_nodes_to_select: Maximum number of node ids to return.

    Returns:
    - Long tensor of node ids on the module-level `device`, most central first.
      Only nodes that appear in at least one edge can be returned, because the
      NetworkX graph is built from the edge list alone.
    """
    edge_pairs = data.edge_index.cpu().numpy().T
    nx_graph = nx.Graph()
    nx_graph.add_edges_from(edge_pairs)

    scores = nx.degree_centrality(nx_graph)
    ranked = sorted(scores, key=scores.get, reverse=True)
    top_nodes = ranked[:num_nodes_to_select]
    return torch.tensor(top_nodes, dtype=torch.long).to(device)


class TriggerGenerator(torch.nn.Module):
    """
    Two-layer MLP that maps a feature vector to a trigger vector of the same size.

    Parameters:
    - input_dim: Size of the input and output feature vectors.
    - hidden_dim: Size of the hidden layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),   # project into the hidden space
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),   # project back to input_dim
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to `x` and return the result."""
        return self.mlp(x)

class GAE(torch.nn.Module):
    """
    Minimal graph auto-encoder wrapper pairing an encoder with a decoder.

    NOTE: this definition shadows the `GAE` imported from
    `torch_geometric.nn.models.autoencoder` at the top of the file.

    Parameters:
    - encoder: Callable invoked as `encoder(x, edge_index)`.
    - decoder: Callable invoked as `decoder(z, edge_index)`.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def encode(self, x, edge_index):
        """Return the encoder's output for node features `x`."""
        latent = self.encoder(x, edge_index)
        return latent

    def decode(self, z, edge_index):
        """Return the decoder's output for latent codes `z`."""
        decoded = self.decoder(z, edge_index)
        return decoded

class OODDetector(torch.nn.Module):
    """
    Out-of-distribution detector built on a graph auto-encoder.

    Node features are encoded to a latent space and decoded back; the mean
    squared reconstruction error is the out-of-distribution score.

    Parameters:
    - input_dim: Number of input features per node.
    - hidden_dim: Hidden width of both encoder and decoder.
    - latent_dim: Size of the latent embedding.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(input_dim, hidden_dim, latent_dim)
        # The wrapper shares the two modules created above.
        self.gae = GAE(self.encoder, self.decoder)

    def forward(self, x, edge_index):
        """Return the latent embedding of every node."""
        return self.gae.encode(x, edge_index)

    def reconstruct(self, z, edge_index):
        """Decode latent codes `z` back to feature space."""
        return self.decoder(z, edge_index)

    def reconstruction_loss(self, x, edge_index):
        """Return the MSE between `x` and its encode/decode round trip."""
        latent = self.gae.encode(x, edge_index)
        rebuilt = self.reconstruct(latent, edge_index)
        return F.mse_loss(rebuilt, x)

    def detect_ood(self, x, edge_index, threshold):
        """Return whether the reconstruction loss exceeds `threshold`."""
        return self.reconstruction_loss(x, edge_index) > threshold

class Encoder(torch.nn.Module):
    """
    Two-layer GCN that maps node features to latent embeddings.

    Parameters:
    - input_dim: Number of input features per node.
    - hidden_dim: Width of the hidden layer.
    - latent_dim: Size of the produced embedding.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, latent_dim)

    def forward(self, x, edge_index):
        """Return the latent embedding (no activation after the last layer)."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

class Decoder(torch.nn.Module):
    """
    Two-layer GCN that maps latent embeddings back to feature space.

    Parameters:
    - input_dim: Number of features to reconstruct per node.
    - hidden_dim: Width of the hidden layer.
    - latent_dim: Size of the incoming embedding.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.conv1 = GCNConv(latent_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, input_dim)

    def forward(self, z, edge_index):
        """Return the reconstructed node features for latent codes `z`."""
        hidden = F.relu(self.conv1(z, edge_index))
        return self.conv2(hidden, edge_index)



def train_ood_detector(ood_detector, data, optimizer, epochs=50):
    """
    Fit the detector to reconstruct the features of the training nodes.

    Parameters:
    - ood_detector: Module callable as `ood_detector(x, edge_index)` and exposing
      `reconstruct(z, edge_index)`.
    - data: Graph object providing `x`, `edge_index` and `train_mask`.
    - optimizer: Optimizer over the detector's parameters.
    - epochs: Number of full-graph optimisation steps.
    """
    ood_detector.train()
    report_every = 10

    for epoch in range(epochs):
        optimizer.zero_grad()

        # Encode, then decode the whole graph in one pass.
        latent = ood_detector(data.x, data.edge_index)
        rebuilt = ood_detector.reconstruct(latent, data.edge_index)

        # Only training nodes contribute to the objective.
        train_nodes = data.train_mask
        loss = F.mse_loss(rebuilt[train_nodes], data.x[train_nodes])
        loss.backward()
        optimizer.step()

        if epoch % report_every:
            continue
        print(f"Epoch {epoch}, Reconstruction Loss: {loss.item():.4f}")



def train_with_poisoned_data(model, data, optimizer, poisoned_nodes, trigger_gen, attack, ood_detector=None, alpha=0.7, early_stopping=False, epochs=100):
    """
    Inject a backdoor trigger into `data` and train `model` on the result.

    Parameters:
    - model: Node classifier called as `model(x, edge_index)`.
    - data: Clean graph; it is passed to `inject_trigger`, which returns the poisoned graph.
    - optimizer: Optimizer over the model's parameters.
    - poisoned_nodes: Node ids handed to `inject_trigger`.
    - trigger_gen: Trigger generator handed to `inject_trigger`.
    - attack: Attack type name handed to `inject_trigger`.
    - ood_detector: Optional detector forwarded to `inject_trigger`.
    - alpha: Blending factor forwarded to `inject_trigger`.
    - early_stopping: Despite its name this flag only enables a loss print every
      10 epochs; training always runs for all `epochs`.
    - epochs: Number of full-graph training steps (default 100).

    Returns:
    - (model, data_poisoned): The trained model and the poisoned graph.
    """
    # Apply trigger injection
    data_poisoned = inject_trigger(data, poisoned_nodes, attack, trigger_gen, ood_detector, alpha)

    train_mask = data_poisoned.train_mask

    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()

        out = model(data_poisoned.x, data_poisoned.edge_index)
        loss = F.cross_entropy(out[train_mask], data_poisoned.y[train_mask])

        loss.backward()
        optimizer.step()

        # Optional progress output.
        if early_stopping and epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    return model, data_poisoned



from sklearn.cluster import KMeans

class GCNEncoder(torch.nn.Module):
    """
    Two-layer GCN encoder whose hidden width is twice the output width.

    Parameters:
    - in_channels: Number of input features per node.
    - out_channels: Size of the produced embedding.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 2 * out_channels
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return node embeddings (no activation after the last layer)."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

def select_diverse_nodes(data, num_nodes_to_select, num_clusters=None):
    """
    Select nodes using a clustering-based approach to ensure diversity, along with
    high-degree and high-betweenness nodes.

    Candidates are gathered in priority order: (1) the node closest to each
    K-Means cluster centre of the GCN embeddings, (2) the highest-degree nodes,
    (3) the nodes with the highest betweenness centrality. Duplicates are dropped
    while keeping that order, and the list is truncated to `num_nodes_to_select`.

    Parameters:
    - data: PyG data object representing the graph.
    - num_nodes_to_select: Maximum number of nodes to return.
    - num_clusters: Number of clusters to form for diversity. Defaults to the number
      of distinct labels in `data.y` if not provided.

    Returns:
    - Long tensor of selected node indices on the same device as `data.x`. It may
      hold fewer than `num_nodes_to_select` entries, because at most
      `num_clusters + 2 * (num_clusters // 2)` candidates are gathered.
    """
    if num_clusters is None:
        num_clusters = len(torch.unique(data.y))

    feature_device = data.x.device

    # The encoder is untrained (randomly initialised); it only provides a
    # structure-aware projection for clustering. It must live on the same device
    # as the features it is applied to.
    encoder = GCNEncoder(data.num_features, out_channels=16).to(feature_device)
    encoder.eval()
    with torch.no_grad():
        embeddings = encoder(data.x, data.edge_index).cpu().numpy()

    kmeans = KMeans(n_clusters=num_clusters, random_state=42).fit(embeddings)
    labels = kmeans.labels_
    cluster_centers = kmeans.cluster_centers_

    # One representative per cluster: the member closest to the centre.
    representatives = []
    for cluster_id in range(num_clusters):
        members = np.where(labels == cluster_id)[0]
        if members.size == 0:
            # Guard against an empty cluster, where argmin would raise.
            continue
        distances = np.linalg.norm(embeddings[members] - cluster_centers[cluster_id], axis=1)
        representatives.append(int(members[np.argmin(distances)]))

    extra_per_source = len(representatives) // 2

    # minlength makes the degree vector cover nodes that never appear as a source.
    degree = torch.bincount(data.edge_index[0], minlength=data.num_nodes)
    high_degree_nodes = torch.topk(degree, extra_per_source).indices.cpu().tolist()

    central_nodes = []
    if extra_per_source > 0:
        # Betweenness centrality is expensive, so only compute it when it is used.
        G = to_networkx(data, to_undirected=True)
        betweenness = nx.betweenness_centrality(G)
        ranked = sorted(betweenness, key=betweenness.get, reverse=True)
        central_nodes = [int(node) for node in ranked[:extra_per_source]]

    # Order-preserving de-duplication keeps cluster representatives first, which
    # a sorting de-duplication (e.g. torch.unique) would not.
    ordered = list(dict.fromkeys(representatives + high_degree_nodes + central_nodes))
    selected = ordered[:num_nodes_to_select]

    return torch.tensor(selected, dtype=torch.long).to(feature_device)

def _neighbor_mean_features(data, nodes):
    """Return one row per node: the mean feature vector of the sources of its incoming edges (global mean if it has none)."""
    sources, targets = data.edge_index[0], data.edge_index[1]
    rows = []
    for node in nodes:
        neighbors = sources[targets == node]
        rows.append(data.x[neighbors].mean(dim=0) if len(neighbors) > 0 else data.x.mean(dim=0))
    return torch.stack(rows)


def _trigger_edges_to_global(local_pairs, trigger_nodes, device):
    """Map trigger-local (i, j) index pairs to a (2, E) long edge_index over global node ids."""
    if len(local_pairs) == 0:
        return torch.empty((2, 0), dtype=torch.long, device=device)
    local = torch.tensor(local_pairs, dtype=torch.long, device=device).t().contiguous()
    return trigger_nodes.to(device)[local]


def inject_trigger(data, poisoned_nodes, attack_type, trigger_gen=None, ood_detector=None, alpha=0.7, trigger_size=5, trigger_density=0.5, input_dim=None):
    """
    Return a poisoned copy of `data` for the requested backdoor attack.

    Supported `attack_type` values:
    - 'SBA-Samp': Erdos-Renyi trigger subgraph among the first `trigger_size` poisoned nodes.
    - 'SBA-Gen': trigger subgraph whose edges are sampled from a Gaussian feature similarity.
    - 'DPGBA': blends each poisoned node's features with `trigger_gen` output.
    - 'GTA': replaces poisoned features with noisy neighbour averages.
    - 'UGBA': perturbs nodes chosen by `select_diverse_nodes` and adds one edge per node.
    Any other value returns an unmodified clone.

    Parameters:
    - data: Graph object providing `x`, `edge_index`, `y`, `num_nodes` and `clone()`; not modified.
    - poisoned_nodes: Non-empty sequence/tensor of node ids to poison.
    - attack_type: Name of the attack (see above).
    - trigger_gen: Callable producing trigger features; required for 'DPGBA'.
    - ood_detector: Accepted for interface compatibility; not used here.
    - alpha: Accepted for interface compatibility; not used here.
    - trigger_size: Maximum number of poisoned nodes in the SBA trigger subgraph.
    - trigger_density: Edge probability of the 'SBA-Samp' random subgraph.
    - input_dim: Accepted for interface compatibility; not used here.

    Returns:
    - data_poisoned: The modified clone.

    Raises:
    - ValueError: If `poisoned_nodes` is empty, if 'DPGBA' is requested without
      `trigger_gen`, or if the generator output width differs from the feature width.
    """
    # Clone data to avoid overwriting the original graph
    data_poisoned = data.clone()
    device = data_poisoned.x.device

    if len(poisoned_nodes) == 0:
        raise ValueError("No poisoned nodes selected. Ensure 'poisoned_nodes' is populated and non-empty.")

    poisoned_nodes = torch.as_tensor(poisoned_nodes, dtype=torch.long, device=device)

    # Adjust trigger_size if it exceeds the number of poisoned nodes
    trigger_size = min(trigger_size, len(poisoned_nodes))
    trigger_nodes = poisoned_nodes[:trigger_size]

    if attack_type == 'SBA-Samp':
        # Subgraph-Based Attack - Random Sampling
        avg_features = _neighbor_mean_features(data, trigger_nodes)
        natural_features = avg_features + torch.randn_like(avg_features) * 0.02  # Small randomness

        # The random graph is defined over local indices 0..trigger_size-1; map
        # them onto the poisoned node ids so the trigger sits on those nodes.
        G = nx.erdos_renyi_graph(trigger_size, trigger_density)
        trigger_edge_index = _trigger_edges_to_global(list(G.edges), trigger_nodes, device)

        # Attach every trigger node to one uniformly random node of the graph.
        poisoned_edges = torch.stack([
            trigger_nodes,
            torch.randint(0, data.num_nodes, (trigger_size,), device=device)
        ])

        data_poisoned.edge_index = torch.cat([data.edge_index, trigger_edge_index, poisoned_edges], dim=1)
        data_poisoned.x[trigger_nodes] = natural_features

    elif attack_type == 'SBA-Gen':
        # Subgraph-Based Attack - Gaussian
        # Clamp the std so constant feature columns do not cause a division by zero.
        feature_std = data.x.std(dim=0).clamp_min(1e-8)

        avg_features = _neighbor_mean_features(data, trigger_nodes)
        natural_features = avg_features + torch.normal(mean=0.0, std=0.03, size=avg_features.shape).to(device)

        # Sample each candidate edge with probability equal to the Gaussian
        # similarity of its endpoints' (standardised) features.
        local_pairs = []
        for i in range(trigger_size):
            for j in range(i + 1, trigger_size):
                similarity = torch.exp(-torch.norm((natural_features[i] - natural_features[j]) / feature_std)**2)
                if similarity > torch.rand(1).item():  # Random threshold
                    local_pairs.append([i, j])

        trigger_edge_index = _trigger_edges_to_global(local_pairs, trigger_nodes, device)

        poisoned_edges = torch.stack([
            trigger_nodes,
            torch.randint(0, data.num_nodes, (trigger_size,), device=device)
        ])

        data_poisoned.edge_index = torch.cat([data.edge_index, trigger_edge_index, poisoned_edges], dim=1)
        data_poisoned.x[trigger_nodes] = natural_features

    elif attack_type == 'DPGBA':
        avg_features = _neighbor_mean_features(data, poisoned_nodes).to(device)

        if trigger_gen is None:
            raise ValueError("Trigger generator is required for the DPGBA attack.")

        with torch.no_grad():  # Disable gradient tracking for feature generation
            trigger_features = trigger_gen(avg_features)

        if trigger_features.shape[1] != data.x.shape[1]:
            raise ValueError(f"Trigger feature dimension mismatch: {trigger_features.shape[1]} vs {data.x.shape[1]}")

        # Per-node blend weight in [0.5, 0.8) for the original features.
        node_alphas = torch.rand(len(poisoned_nodes)).to(device) * 0.3 + 0.5
        distribution_preserved_features = (
            node_alphas.unsqueeze(1) * data.x[poisoned_nodes]
            + (1 - node_alphas.unsqueeze(1)) * trigger_features
        )

        data_poisoned.x[poisoned_nodes] = distribution_preserved_features

    elif attack_type == 'GTA':
        # Graph Trojan Attack
        avg_features = _neighbor_mean_features(data, poisoned_nodes)
        trigger_features = avg_features + torch.randn_like(avg_features) * 0.05
        data_poisoned.x[poisoned_nodes] = trigger_features

    elif attack_type == 'UGBA':
        # Unnoticeable Graph Backdoor Attack
        # NOTE(review): this branch poisons the nodes chosen by select_diverse_nodes,
        # not `poisoned_nodes` themselves; only len(poisoned_nodes) is used. Confirm
        # that callers evaluating on `poisoned_nodes` expect this.
        diverse_nodes = select_diverse_nodes(data_poisoned, len(poisoned_nodes))
        connected_nodes = [
            data_poisoned.edge_index[0][data_poisoned.edge_index[1] == node] for node in diverse_nodes
        ]

        avg_features = torch.stack([
            data_poisoned.x[nodes].mean(dim=0) if len(nodes) > 0 else data_poisoned.x.mean(dim=0)
            for nodes in connected_nodes
        ])
        refined_trigger_features = avg_features + torch.normal(mean=2.0, std=0.5, size=avg_features.shape).to(device)
        data_poisoned.x[diverse_nodes] = refined_trigger_features

        # One extra edge per selected node: to its first neighbour, or to the next
        # selected node when it has no neighbour.
        new_edges = []
        for i, node in enumerate(diverse_nodes):
            if len(connected_nodes[i]) > 0:
                neighbor = connected_nodes[i][0]
            else:
                neighbor = diverse_nodes[(i + 1) % len(diverse_nodes)]
            new_edges.append([int(node), int(neighbor)])

        new_edges = torch.tensor(new_edges, dtype=torch.long).t().contiguous().to(data_poisoned.edge_index.device)
        data_poisoned.edge_index = torch.cat([data_poisoned.edge_index, new_edges], dim=1)

    return data_poisoned

import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import pandas as pd
import gc


def dominant_set_clustering(data, threshold=0.7, use_pca=True, pca_components=10):
    """
    Flag nodes that lie far from their K-Means cluster centre and neutralise them.

    Parameters:
    - data: PyG data object representing the graph; modified in place.
    - threshold: Quantile (0-1) of the centre distances above which a node is an outlier.
    - use_pca: Whether to reduce the features with PCA before clustering.
    - pca_components: Number of PCA components; PCA is skipped when the feature
      width does not exceed this value.

    Returns:
    - pruned_nodes: Set of node indices identified as outliers.
    - data: The same object; outliers have label -1 and the mean feature vector.
    """
    # One cluster per distinct label value.
    n_clusters = len(data.y.unique())

    node_features = data.x.detach().cpu().numpy()
    reduce_dims = use_pca and node_features.shape[1] > pca_components
    if reduce_dims:
        node_features = PCA(n_components=pca_components).fit_transform(node_features)

    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    kmeans.fit(node_features)

    # Distance of every node to the centre of the cluster it was assigned to.
    assigned_centers = kmeans.cluster_centers_[kmeans.labels_]
    distances = np.linalg.norm(node_features - assigned_centers, axis=1)

    cutoff = np.percentile(distances, 100 * threshold)
    pruned_nodes = set(np.where(distances > cutoff)[0])
    if len(pruned_nodes) == 0:
        return pruned_nodes, data

    outliers = torch.tensor(list(pruned_nodes), dtype=torch.long, device=data.x.device)

    # -1 marks the label as discarded.
    data.y[outliers] = -1

    # Overwrite outlier features with the mean over all nodes.
    data.x[outliers] = data.x.mean(dim=0).to(data.x.device)

    return pruned_nodes, data

def defense_prune_edges(data, quantile_threshold=0.9):
    """
    Prunes the most dissimilar edges based on cosine similarity between node features.

    Parameters:
    - data: PyG data object representing the graph; `edge_index` is replaced in place.
    - quantile_threshold: Fraction of edges to keep, counted from the most similar
      (e.g., 0.9 keeps the 90% most similar edges and prunes the 10% most dissimilar).

    Returns:
    - data: The same object with pruned edges. Graphs without edges are returned unchanged.
    """
    edge_index = data.edge_index
    if edge_index.size(1) == 0:
        # torch.quantile rejects empty input; there is nothing to prune anyway.
        return data

    norm_features = F.normalize(data.x, p=2, dim=1)  # Unit-length rows

    # Cosine similarity of each edge's endpoints.
    src, dst = edge_index[0], edge_index[1]
    cosine_similarities = torch.sum(norm_features[src] * norm_features[dst], dim=1)

    # Pruning the top (1 - q) most dissimilar edges means cutting at the (1 - q)
    # quantile of the similarity distribution, not at the q quantile.
    similarity_threshold = torch.quantile(cosine_similarities, 1.0 - quantile_threshold).item()

    # Keep edges whose similarity reaches the threshold.
    keep_mask = cosine_similarities >= similarity_threshold
    data.edge_index = edge_index[:, keep_mask]

    return data



def defense_prune_and_discard_labels(data, quantile_threshold=0.2):
    """
    Drop the least similar edges and discard the labels of the nodes they touched most.

    Parameters:
    - data: PyG data object representing the graph; `edge_index` and `y` are modified in place.
    - quantile_threshold: Similarity quantile below which edges are pruned
      (e.g., 0.2 prunes edges in the bottom 20%).

    Returns:
    - data: The same object with pruned edges; nodes whose number of pruned
      incident edges exceeds the median count get label -1.
    """
    original_edges = data.edge_index
    unit_x = F.normalize(data.x, p=2, dim=1)

    # Cosine similarity of each edge's endpoints.
    heads, tails = original_edges[0], original_edges[1]
    edge_similarity = (unit_x[heads] * unit_x[tails]).sum(dim=1)

    cutoff = torch.quantile(edge_similarity, quantile_threshold).item()
    is_pruned = edge_similarity < cutoff

    # Retain everything that is not strictly below the cutoff.
    data.edge_index = original_edges[:, ~is_pruned]

    # Count, per node, how many pruned edges it was an endpoint of.
    removed = original_edges[:, is_pruned]
    touched = torch.cat([removed[0], removed[1]])
    prune_counts = torch.bincount(touched, minlength=data.num_nodes)

    # The median count is the bar a node must exceed to lose its label.
    median_count = int(torch.median(prune_counts).item())
    nodes_to_discard = (prune_counts > median_count).nonzero(as_tuple=True)[0]

    data.y[nodes_to_discard] = -1  # -1 represents a discarded label

    return data

# Compute ASR and Clean Accuracy (using .detach() to avoid retaining computation graph)
def compute_metrics(model, data, poisoned_nodes):
    """
    Evaluate `model` on the poisoned nodes and on the test split.

    Parameters:
    - model: Classifier called as `model(x, edge_index)`.
    - data: Graph object providing `x`, `edge_index`, `y` and `test_mask`.
    - poisoned_nodes: Indices of the poisoned nodes.

    Returns:
    - asr: Percentage of poisoned nodes whose prediction equals `data.y`.
    - clean_acc: Accuracy (in percent) on the nodes selected by `data.test_mask`.
    """
    model.eval()
    with torch.no_grad():
        logits = model(data.x, data.edge_index).detach()
        pred = logits.max(dim=1)[1]

        hits = (pred[poisoned_nodes] == data.y[poisoned_nodes]).sum().item()
        asr = hits / len(poisoned_nodes) * 100

        y_true = data.y[data.test_mask].cpu()
        y_pred = pred[data.test_mask].cpu()
        clean_acc = accuracy_score(y_true, y_pred) * 100
    return asr, clean_acc



# Visualization Function
# Visualize PCA for Attacks
# Added function to visualize PCA projections of node embeddings for different attacks
def visualize_pca_for_attacks(attack_embeddings_dict):
    """
    Plot a 2-D PCA projection of node embeddings for every attack.

    Parameters:
    - attack_embeddings_dict: Maps an attack name to a dict with keys 'data'
      (embedding tensor) and 'poisoned_nodes' (tensor of node indices).

    Each attack gets its own panel in a 2x3 grid; clean nodes are drawn in blue
    and poisoned nodes in red.
    """
    pca = PCA(n_components=2)
    plt.figure(figsize=(20, 10))

    panel = 0
    for attack, attack_data in attack_embeddings_dict.items():
        panel += 1
        embeddings = attack_data['data'].detach().cpu().numpy()
        poisoned_nodes = attack_data['poisoned_nodes'].detach().cpu().numpy()

        # PCA is refitted for every attack.
        projected = pca.fit_transform(embeddings)

        is_poisoned = np.zeros(embeddings.shape[0], dtype=bool)
        is_poisoned[poisoned_nodes] = True
        clean_points = projected[~is_poisoned]
        poisoned_points = projected[is_poisoned]

        plt.subplot(2, 3, panel)
        plt.scatter(clean_points[:, 0], clean_points[:, 1], s=10, alpha=0.5, label='Clean Nodes', c='b')
        plt.scatter(poisoned_points[:, 0], poisoned_points[:, 1], s=10, alpha=0.8, label='Poisoned Nodes', c='r')
        plt.title(f'PCA Visualization for {attack}')
        plt.xlabel('PCA Component 1')
        plt.ylabel('PCA Component 2')
        plt.legend()

    plt.tight_layout()
    plt.show()

import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import to_dense_adj, k_hop_subgraph
import torch.nn.functional as F

class SUNLayer(nn.Module):
    """
    Single SUN-style layer combining root, neighbourhood and global updates.

    Parameters:
    - in_channels: Number of input features per node.
    - out_channels: Number of output features per node.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Separate transformations for root and non-root nodes
        self.root_mlp = nn.Linear(in_channels, out_channels)
        self.non_root_mlp = nn.Linear(in_channels, out_channels)
        self.global_mlp = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index, subgraph_masks):
        """
        Parameters:
        - x: Node feature matrix with one row per node.
        - edge_index: (2, E) tensor of edges.
        - subgraph_masks: Boolean mask over nodes marking which are roots.

        Returns:
        - Sum of the root, neighbourhood and global terms, one row per node.
        """
        num_nodes = x.size(0)

        # Neighbourhood sum: row i accumulates x[j] for every edge (i, j). This
        # equals multiplying x by the dense adjacency matrix, without building an
        # N x N matrix on every forward pass.
        # NOTE: index_add_ may sum in a different order than a dense matmul, so
        # results can differ in the last floating-point bits.
        src, dst = edge_index[0], edge_index[1]
        local_features = torch.zeros_like(x).index_add_(0, src, x[dst])

        # Global term: the transformed mean feature vector, broadcast to all nodes.
        global_features = self.global_mlp(torch.mean(x, dim=0, keepdim=True))
        global_features = global_features.expand(num_nodes, global_features.size(1))

        # Root term: non-zero only for nodes flagged in subgraph_masks.
        root_features = torch.zeros((num_nodes, global_features.size(1)), device=x.device)
        root_features[subgraph_masks] = self.root_mlp(x[subgraph_masks])

        # Neighbourhood term, applied to every node.
        non_root_features = self.non_root_mlp(local_features)

        return root_features + non_root_features + global_features

class SUN(nn.Module):
    """
    Two-layer SUN-style node classifier.

    Parameters:
    - num_features: Number of input features per node.
    - num_classes: Number of output classes.
    - hidden_channels: Width of the hidden layer.
    """

    def __init__(self, num_features, num_classes, hidden_channels):
        super().__init__()
        self.layer1 = SUNLayer(num_features, hidden_channels)
        self.layer2 = SUNLayer(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        """
        Parameters:
        - x: Node feature matrix with one row per node.
        - edge_index: (2, E) tensor of edges.

        Returns:
        - Per-node log-probabilities over the classes (log_softmax output).
        """
        num_nodes = x.size(0)

        # Every node is treated as a root. Scanning each node's 1-hop subgraph
        # with k_hop_subgraph ends in this same all-True mask, at O(N * E) cost on
        # every forward pass; its fourth return value is an edge mask, so using it
        # to index nodes also fails when the graph has fewer edges than nodes.
        subgraph_masks = torch.ones(num_nodes, dtype=torch.bool, device=x.device)

        # Pass through SUN layers
        x = self.layer1(x, edge_index, subgraph_masks)
        x = F.relu(x)
        x = self.layer2(x, edge_index, subgraph_masks)
        return F.log_softmax(x, dim=1)


def _sun_train_epochs(model, optimizer, graph, log_prefix, epochs=200, log_every=10):
    """Full-batch train ``model`` on the ``train_mask`` nodes of ``graph``.

    Prints ``"<log_prefix>, Epoch <n>, Loss: <loss>"`` every ``log_every``
    epochs.
    """
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        out = model(graph.x, graph.edge_index)  # Pass `x` and `edge_index` separately
        loss = F.cross_entropy(out[graph.train_mask], graph.y[graph.train_mask])
        loss.backward()
        optimizer.step()
        if epoch % log_every == 0:
            print(f"{log_prefix}, Epoch {epoch}, Loss: {loss.item():.4f}")


def _sun_predict(model, graph):
    """Return the arg-max class index for every node of ``graph`` (eval mode, no grad)."""
    model.eval()
    with torch.no_grad():
        return model(graph.x, graph.edge_index).argmax(dim=1)


def _sun_test_accuracy(predictions, graph):
    """Return the fraction (0..1) of ``graph.test_mask`` nodes predicted correctly."""
    hits = (predictions[graph.test_mask] == graph.y[graph.test_mask]).sum().item()
    return hits / graph.test_mask.sum().item()


def run_all_attacks_sun():
    """Run the baseline and every backdoor attack against SUN on each dataset.

    For each of Cora, PubMed and CiteSeer this:

    1. loads and splits the dataset, then trains a SUN model for 200 epochs
       (Adam, lr=0.002) and records its clean test accuracy;
    2. builds a trigger generator and trains an OOD detector for 10 epochs
       (the detector is trained here but not consulted afterwards in this
       function);
    3. selects the nodes to poison using a per-dataset budget;
    4. for each attack, injects the trigger, trains for another 200 epochs on
       the poisoned graph starting from the baseline state, and records the
       attack success rate (ASR) and clean test accuracy.

    All rows are printed and written to
    ``sun_backdoor_attack_results_summary.csv``. Failures of a single attack
    or a whole dataset are printed and skipped so the remaining runs still
    execute.
    """
    import copy  # local import: only needed to snapshot the baseline state

    datasets = ["Cora", "PubMed", "CiteSeer"]
    # Dataset-specific poisoning budgets
    dataset_budgets = {'Cora': 10, 'PubMed': 40, 'CiteSeer': 30}
    # Define attack methods
    attack_methods = ['SBA-Samp', 'SBA-Gen', 'GTA', 'UGBA', 'DPGBA']
    results_summary = []

    for dataset_name in datasets:
        try:
            print(f"Starting process for dataset: {dataset_name}")

            # Load the dataset
            dataset = load_dataset(dataset_name)
            print(f"Loaded dataset: {dataset_name}")
            data = dataset[0].to(device)
            print(f"Dataset {dataset_name} has {data.num_nodes} nodes and {data.num_edges} edges.")

            input_dim = data.num_features
            output_dim = dataset.num_classes if isinstance(dataset.num_classes, int) else dataset.num_classes[0]
            data = split_dataset(data)
            print(f"Dataset {dataset_name} split into train/val/test.")

            poisoned_node_budget = dataset_budgets.get(dataset_name, 10)

            # Initialize SUN model and optimizer
            print(f"Initializing SUN for dataset: {dataset_name}")
            model = SUN(num_features=input_dim, num_classes=output_dim, hidden_channels=64).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=0.002)

            # Train baseline model
            print(f"Training baseline model for dataset: {dataset_name}")
            _sun_train_epochs(model, optimizer, data, f"Dataset: {dataset_name}")

            # Evaluate baseline accuracy
            baseline_acc = _sun_test_accuracy(_sun_predict(model, data), data)
            print(f"Dataset: {dataset_name}, Baseline Accuracy: {baseline_acc * 100:.2f}%")
            results_summary.append({
                "Dataset": dataset_name,
                "Model": "SUN",
                "Attack": "None",
                "Defense": "None",
                "ASR": "N/A",
                "Clean Accuracy": baseline_acc * 100
            })

            # Snapshot the clean baseline so every attack starts from the same
            # weights and optimizer moments. Without this, each attack would
            # keep training the model already poisoned by the previous ones
            # and its ASR / clean accuracy would not be attributable to it.
            # state_dict() shares storage with the live tensors, hence the
            # deep copies.
            baseline_model_state = copy.deepcopy(model.state_dict())
            baseline_optim_state = copy.deepcopy(optimizer.state_dict())

            # Initialize Trigger Generator and OOD Detector
            print(f"Initializing Trigger Generator and OOD Detector for dataset: {dataset_name}")
            trigger_gen = TriggerGenerator(input_dim=input_dim, hidden_dim=64).to(device)
            ood_detector = OODDetector(input_dim=input_dim, hidden_dim=64, latent_dim=16).to(device)
            ood_optimizer = torch.optim.Adam(ood_detector.parameters(), lr=0.001)
            train_ood_detector(ood_detector, data, ood_optimizer, epochs=10)

            # Select nodes to poison
            print(f"Selecting poisoned nodes for dataset: {dataset_name}")
            poisoned_nodes = select_high_centrality_nodes(data, poisoned_node_budget).to(device)
            print(f"Selected {len(poisoned_nodes)} poisoned nodes.")

            for attack in attack_methods:
                try:
                    print(f"Starting attack {attack} on dataset: {dataset_name}")

                    # Rewind to the clean baseline before this attack. The
                    # optimizer snapshot is copied again because loading may
                    # reuse the given tensors, which Adam then updates in place.
                    model.load_state_dict(baseline_model_state)
                    optimizer.load_state_dict(copy.deepcopy(baseline_optim_state))

                    # Inject the trigger (the generator is only handed to DPGBA)
                    data_poisoned = inject_trigger(
                        data=data,
                        poisoned_nodes=poisoned_nodes,
                        attack_type=attack,
                        trigger_gen=trigger_gen if attack == "DPGBA" else None,
                        alpha=0.7,
                        trigger_size=10,
                        trigger_density=0.5
                    ).to(device)

                    # Train the model on poisoned data
                    _sun_train_epochs(
                        model,
                        optimizer,
                        data_poisoned,
                        f"Dataset: {dataset_name}, Attack: {attack}",
                    )

                    # Evaluate ASR and clean accuracy
                    predictions = _sun_predict(model, data_poisoned)
                    # NOTE(review): ASR is scored against the labels of the
                    # un-poisoned `data`; confirm this matches the intended
                    # attack-success definition for these attacks.
                    asr = (predictions[poisoned_nodes] == data.y[poisoned_nodes]).sum().item() / len(poisoned_nodes) * 100
                    clean_acc = _sun_test_accuracy(predictions, data_poisoned) * 100

                    results_summary.append({
                        "Dataset": dataset_name,
                        "Model": "SUN",
                        "Attack": attack,
                        "Defense": "None",
                        "ASR": asr,
                        "Clean Accuracy": clean_acc
                    })
                    print(f"Dataset: {dataset_name}, Attack: {attack} - ASR: {asr:.2f}%, Clean Accuracy: {clean_acc:.2f}%")

                except Exception as e:
                    # Best effort: report the failure and move on to the next attack.
                    print(f"Error during attack {attack} on dataset {dataset_name}: {e}")

        except Exception as e:
            # Best effort: report the failure and move on to the next dataset.
            print(f"Error with dataset {dataset_name}: {e}")

    # Save and display results
    results_df = pd.DataFrame(results_summary)
    print("\nSummary of Results:")
    print(results_df)
    results_df.to_csv("sun_backdoor_attack_results_summary.csv", index=False)

# Execute the SUN experiment sweep; results are printed and written to CSV.
run_all_attacks_sun()


Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m17.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1
Using device: cpu
Starting process for dataset: Cora


Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


Loaded dataset: Cora
Dataset Cora has 2708 nodes and 10556 edges.
Dataset Cora split into train/val/test.
Initializing SUN for dataset: Cora
Training baseline model for dataset: Cora
Dataset: Cora, Epoch 0, Loss: 2.0277
Dataset: Cora, Epoch 10, Loss: 0.5944
Dataset: Cora, Epoch 20, Loss: 0.3298
Dataset: Cora, Epoch 30, Loss: 0.1908
Dataset: Cora, Epoch 40, Loss: 0.1105
Dataset: Cora, Epoch 50, Loss: 0.0634
Dataset: Cora, Epoch 60, Loss: 0.0378
Dataset: Cora, Epoch 70, Loss: 0.0241
Dataset: Cora, Epoch 80, Loss: 0.0165
Dataset: Cora, Epoch 90, Loss: 0.0121
Dataset: Cora, Epoch 100, Loss: 0.0093
Dataset: Cora, Epoch 110, Loss: 0.0075
Dataset: Cora, Epoch 120, Loss: 0.0061
Dataset: Cora, Epoch 130, Loss: 0.0052
Dataset: Cora, Epoch 140, Loss: 0.0044
Dataset: Cora, Epoch 150, Loss: 0.0038
Dataset: Cora, Epoch 160, Loss: 0.0034
Dataset: Cora, Epoch 170, Loss: 0.0030
Dataset: Cora, Epoch 180, Loss: 0.0027
Dataset: Cora, Epoch 190, Loss: 0.0024
Dataset: Cora, Baseline Accuracy: 87.62%
Initial

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.test.index
Processing...
Done!


Loaded dataset: PubMed
Dataset PubMed has 19717 nodes and 88648 edges.
Dataset PubMed split into train/val/test.
Initializing SUN for dataset: PubMed
Training baseline model for dataset: PubMed
Dataset: PubMed, Epoch 0, Loss: 1.1282
Dataset: PubMed, Epoch 10, Loss: 0.7094
Dataset: PubMed, Epoch 20, Loss: 0.6071
Dataset: PubMed, Epoch 30, Loss: 0.5580
Dataset: PubMed, Epoch 40, Loss: 0.5174
Dataset: PubMed, Epoch 50, Loss: 0.4818
Dataset: PubMed, Epoch 60, Loss: 0.4502
Dataset: PubMed, Epoch 70, Loss: 0.4212
Dataset: PubMed, Epoch 80, Loss: 0.3943
Dataset: PubMed, Epoch 90, Loss: 0.3696
Dataset: PubMed, Epoch 100, Loss: 0.3468
Dataset: PubMed, Epoch 110, Loss: 0.3263
Dataset: PubMed, Epoch 120, Loss: 0.3077
Dataset: PubMed, Epoch 130, Loss: 0.2910
Dataset: PubMed, Epoch 140, Loss: 0.2760
Dataset: PubMed, Epoch 150, Loss: 0.2648
Dataset: PubMed, Epoch 160, Loss: 0.2533
Dataset: PubMed, Epoch 170, Loss: 0.2427
Dataset: PubMed, Epoch 180, Loss: 0.2341
Dataset: PubMed, Epoch 190, Loss: 0.22