
# Graph Neural Network Backdoor Attack and Defense Experiment

This experiment explores the vulnerability of Graph Neural Networks (GNNs) to several backdoor attacks. It evaluates an adaptive outlier detection defense alongside two edge-pruning defenses. We tested five attack methods on the **Cora** and **PubMed** datasets and compared three defenses: autoencoder-based outlier detection, cosine-similarity edge pruning (Prune), and pruning with selective label discarding (Prune + LD). Below is a summary of our methodology, findings, and visual analysis.

## Methodology Overview

### Attack Models
Five attack methods were evaluated:

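1. **SBA-Samp** and **SBA-Gen**: Subgraph-based Backdoor Attack. SBA-Samp samples the trigger subgraph from an Erdős-Rényi random graph. SBA-Gen generates it with a Watts-Strogatz (small-world) or Barabási-Albert (preferential-attachment) model. In both variants, the first five poisoned nodes (`trigger_size=5`) receive their neighbors' mean feature plus small Gaussian noise, along with an edge to a randomly chosen node.
2. **GTA**: Graph Trojaning Attack. The features of the poisoned (high-centrality) nodes are replaced by their neighbors' mean feature plus Gaussian noise.
3. **UGBA**: Unnoticeable Graph Backdoor Attack. Poisoned nodes are chosen by K-means clustering for diversity. Their features are set to a blend of a trigger generator's output and their neighbors' mean feature.
4. **DPGBA**: Distribution Preserved Graph Backdoor Attack. Each poisoned node's features are blended with the trigger generator's output, using a random per-node weight in [0.5, 0.8) on the original features so the result stays close to the clean feature distribution.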
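
Below is a minimal sketch of how the two SBA variants differ in trigger topology. It assumes the trigger subgraph is wired among the target nodes themselves. The helper `make_trigger_edges` and its `variant` names are hypothetical; the generators are the standard NetworkX ones.

```python
import networkx as nx
import torch


def make_trigger_edges(target_nodes, variant="Samp", density=0.5, seed=0):
    """Hypothetical helper: build a trigger subgraph and relabel it onto target_nodes.

    Returns an undirected edge index of shape [2, 2 * E] in PyG convention.
    """
    n = len(target_nodes)
    if variant == "Samp":
        # SBA-Samp style: sample a random Erdos-Renyi pattern with the given edge density
        g = nx.erdos_renyi_graph(n, density, seed=seed)
    elif variant == "Gen-SW":
        # SBA-Gen style: small-world pattern (ring with k=2, then rewiring with probability 0.4)
        g = nx.watts_strogatz_graph(n, k=2, p=0.4, seed=seed)
    elif variant == "Gen-PA":
        # SBA-Gen style: preferential-attachment pattern (requires 1 <= m < n)
        g = nx.barabasi_albert_graph(n, m=2, seed=seed)
    else:
        raise ValueError(f"Unknown variant: {variant}")

    # Local edge list with ids 0..n-1; reshape keeps the [2, E] shape even when E == 0
    local = torch.tensor(list(g.edges), dtype=torch.long).reshape(-1, 2).t()
    # Map local ids onto the real node ids of the host graph
    global_edges = target_nodes[local]
    # Store both directions, as PyG does for undirected graphs such as Cora
    return torch.cat([global_edges, global_edges.flip(0)], dim=1)


targets = torch.tensor([10, 42, 7, 99, 3])
trigger_edges = make_trigger_edges(targets, variant="Gen-SW")
```

The resulting edges can be concatenated onto `data.edge_index`. Relabeling through `target_nodes[local]` ties the trigger to the chosen nodes. Without it, the generator's ids `0..n-1` would be read as the ids of existing graph nodes.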

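The attacks insert triggers into the nodes with the highest degree centrality (10 nodes on Cora, 40 on PubMed) to manipulate node classification. UGBA is the exception: it poisons cluster-representative nodes selected by K-means.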
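
The selection and feature-replacement steps used for GTA can be sketched in plain PyTorch. NetworkX defines degree centrality as degree divided by N - 1, so ranking by raw degree gives the same order up to tie-breaking. The function names and toy graph are illustrative assumptions, not the notebook's code.

```python
import torch


def top_degree_nodes(edge_index, num_nodes, k):
    """Indices of the k highest-degree nodes (edge_index stores both directions)."""
    deg = torch.bincount(edge_index[0], minlength=num_nodes)
    return torch.topk(deg, k).indices


def gta_like_poison(x, edge_index, nodes, noise_std=0.05, generator=None):
    """Replace each node's features with its neighbors' mean feature plus Gaussian noise."""
    x_poisoned = x.clone()
    for node in nodes.tolist():
        # Sources of edges pointing into `node`, i.e. its neighbors
        neighbors = edge_index[0][edge_index[1] == node]
        # Fall back to the global mean for isolated nodes, as the notebook does
        base = x[neighbors].mean(dim=0) if neighbors.numel() > 0 else x.mean(dim=0)
        x_poisoned[node] = base + noise_std * torch.randn(base.shape, generator=generator)
    return x_poisoned


# Toy graph: node 0 links to 1, 2, 3 and node 3 links to 4
edge_index = torch.tensor([[0, 1, 0, 2, 0, 3, 3, 4],
                           [1, 0, 2, 0, 3, 0, 4, 3]])
x = torch.eye(5)
targets = top_degree_nodes(edge_index, num_nodes=5, k=2)
x_poisoned = gta_like_poison(x, edge_index, targets)
```

Neighbor means are computed from the clean features `x`, so the order in which nodes are poisoned does not affect the result.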

### Defense Mechanism
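To mitigate these attacks, an **Adaptive Outlier Detection Mechanism (Dominant Set)** was used, built on an **Autoencoder** for **Out-of-Distribution (OOD) detection**. The adaptive pruning function filtered nodes based on a cohesiveness score: a pre-filter first flagged nodes with a high reconstruction loss, and a k-nearest neighbor (KNN) clustering-based step followed. In the notebook code, the three defenses reported in the results table are implemented as follows:

- **Outlier Detection**: nodes whose autoencoder reconstruction loss exceeds the 90th percentile have their features replaced by the mean feature vector and their labels discarded (set to -1).
- **Prune**: edges are ranked by the cosine similarity of their endpoint features, and only edges at or above the 0.8 quantile are kept.
- **Prune + LD**: the same pruning is applied. Nodes incident to more than the median number of pruned edges then have their labels discarded.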
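
The thresholding logic of these defenses can be sketched in isolation. This sketch assumes the per-node reconstruction losses have already been computed by an autoencoder, as `OODDetector.reconstruction_loss` does in the code below. The function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def flag_outliers(recon_loss, q=0.9):
    """Indices of nodes whose reconstruction loss exceeds the q-quantile."""
    threshold = torch.quantile(recon_loss, q)
    return torch.nonzero(recon_loss > threshold, as_tuple=True)[0]


def prune_dissimilar_edges(x, edge_index, q=0.8):
    """Split edges into (kept, pruned) by the cosine similarity of endpoint features."""
    z = F.normalize(x, p=2, dim=1)
    sim = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=1)
    keep = sim >= torch.quantile(sim, q)  # edges at or above the q-quantile survive
    return edge_index[:, keep], edge_index[:, ~keep]


def nodes_to_discard(pruned_edges, num_nodes):
    """Nodes incident to more pruned edges than the median node (label-discard step)."""
    counts = torch.bincount(pruned_edges.flatten(), minlength=num_nodes)
    return torch.nonzero(counts > counts.median(), as_tuple=True)[0]


loss = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 5.0])
outliers = flag_outliers(loss)

x = torch.tensor([[1.0, 0.0], [1.0, 0.1], [0.0, 1.0], [-1.0, 0.0]])
edge_index = torch.tensor([[0, 0, 0], [1, 2, 3]])
kept, pruned = prune_dissimilar_edges(x, edge_index, q=0.5)
discard = nodes_to_discard(pruned, num_nodes=4)
```

With q = 0.8, as in the experiments, pruning keeps only the 20% most similar edges. This is an aggressive setting, and it explains why Prune + LD discards many labels.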

### Metrics
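- **ASR (Attack Success Rate)**: intended as the proportion of poisoned nodes classified as the adversary intends. In `compute_metrics`, it is computed as the percentage of poisoned nodes whose prediction matches the label stored in `data.y` on the (possibly defended) graph.
- **Clean Accuracy**: classification accuracy on the test split (`data.test_mask`). Nodes whose labels a defense has discarded (`y = -1`) count as misclassified in both metrics.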
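
The ASR definition presupposes a target class chosen by the adversary, which this notebook does not define. Below is a minimal sketch of both metrics under that assumption. `target_class` is a hypothetical parameter. Clean accuracy here excludes the attacked nodes and any labels discarded by a defense (`y == -1`), which differs from `compute_metrics` in the code below.

```python
import torch


def attack_success_rate(pred, attacked_nodes, target_class):
    """Percentage of attacked nodes predicted as the adversary's (assumed) target class."""
    return (pred[attacked_nodes] == target_class).float().mean().item() * 100


def clean_accuracy(pred, y, eval_mask, attacked_nodes):
    """Accuracy (%) over eval_mask, excluding attacked nodes and discarded labels (y == -1)."""
    mask = eval_mask.clone()
    mask[attacked_nodes] = False
    mask &= y >= 0
    return (pred[mask] == y[mask]).float().mean().item() * 100


pred = torch.tensor([2, 2, 1, 0, 2, 0])
y = torch.tensor([0, 2, 1, 0, -1, 1])
attacked = torch.tensor([0, 1, 4])
asr = attack_success_rate(pred, attacked, target_class=2)
acc = clean_accuracy(pred, y, torch.ones(6, dtype=torch.bool), attacked)
```

Separating the two metrics this way prevents a defense that only discards labels from appearing to lower ASR.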

| Dataset | Model      | Attack   | Defense            | ASR (%) | Clean Accuracy (%) |
|---------|------------|----------|--------------------|---------|--------------------|
| Cora    | GCN        | SBA-Samp | None               | 90.00   | 88.54              |
| Cora    | GCN        | SBA-Samp | Outlier Detection  | 90.00   | 78.93              |
| Cora    | GCN        | SBA-Samp | Prune              | 90.00   | 83.92              |
| Cora    | GCN        | SBA-Samp | Prune + LD         | 0.00    | 46.40              |
| Cora    | GCN        | SBA-Gen  | None               | 100.00  | 87.43              |
| Cora    | GCN        | SBA-Gen  | Outlier Detection  | 100.00  | 77.63              |
| Cora    | GCN        | SBA-Gen  | Prune              | 90.00   | 81.70              |
| Cora    | GCN        | SBA-Gen  | Prune + LD         | 0.00    | 44.92              |
| Cora    | GCN        | GTA      | None               | 90.00   | 87.43              |
| Cora    | GCN        | GTA      | Outlier Detection  | 90.00   | 78.00              |
| Cora    | GCN        | GTA      | Prune              | 100.00  | 79.30              |
| Cora    | GCN        | GTA      | Prune + LD         | 0.00    | 43.99              |
| Cora    | GCN        | UGBA     | None               | 100.00  | 86.88              |
| Cora    | GCN        | UGBA     | Outlier Detection  | 100.00  | 77.45              |
| Cora    | GCN        | UGBA     | Prune              | 80.00   | 77.82              |
| Cora    | GCN        | UGBA     | Prune + LD         | 0.00    | 43.81              |
| Cora    | GCN        | DPGBA    | None               | 100.00  | 86.69              |
| Cora    | GCN        | DPGBA    | Outlier Detection  | 100.00  | 77.63              |
| Cora    | GCN        | DPGBA    | Prune              | 70.00   | 76.34              |
| Cora    | GCN        | DPGBA    | Prune + LD         | 0.00    | 42.88              |
| Cora    | GraphSage  | SBA-Samp | None               | 100.00  | 89.09              |
| Cora    | GraphSage  | SBA-Samp | Outlier Detection  | 100.00  | 79.11              |
| Cora    | GraphSage  | SBA-Samp | Prune              | 90.00   | 78.00              |
| Cora    | GraphSage  | SBA-Samp | Prune + LD         | 0.00    | 43.07              |
| PubMed  | GCN        | SBA-Samp | None               | 90.00   | 84.30              |
| PubMed  | GCN        | SBA-Samp | Outlier Detection  | 82.50   | 76.69              |
| PubMed  | GCN        | SBA-Samp | Prune              | 90.00   | 82.17              |
| PubMed  | GCN        | SBA-Samp | Prune + LD         | 0.00    | 40.65              |
| PubMed  | GCN        | SBA-Gen  | None               | 92.50   | 86.58              |
| PubMed  | GCN        | SBA-Gen  | Outlier Detection  | 72.50   | 78.54              |
| PubMed  | GCN        | SBA-Gen  | Prune              | 92.50   | 83.90              |
| PubMed  | GCN        | SBA-Gen  | Prune + LD         | 0.00    | 42.23              |
| PubMed  | GCN        | GTA      | None               | 95.00   | 87.29              |
| PubMed  | GCN        | GTA      | Outlier Detection  | 0.00    | 79.33              |
| PubMed  | GCN        | GTA      | Prune              | 60.00   | 84.07              |
| PubMed  | GCN        | GTA      | Prune + LD         | 0.00    | 42.73              |
| PubMed  | GCN        | UGBA     | None               | 95.00   | 87.45              |
| PubMed  | GCN        | UGBA     | Outlier Detection  | 85.00   | 79.63              |
| PubMed  | GCN        | UGBA     | Prune              | 92.50   | 83.69              |
| PubMed  | GCN        | UGBA     | Prune + LD         | 0.00    | 42.53              |
| PubMed  | GCN        | DPGBA    | None               | 95.00   | 87.80              |
| PubMed  | GCN        | DPGBA    | Outlier Detection  | 20.00   | 79.91              |
| PubMed  | GCN        | DPGBA    | Prune              | 92.50   | 83.92              |
| PubMed  | GCN        | DPGBA    | Prune + LD         | 0.00    | 42.73              |
| PubMed  | GraphSage  | SBA-Samp | None               | 95.00   | 86.84              |
| PubMed  | GraphSage  | SBA-Samp | Outlier Detection  | 87.50   | 78.65              |
| PubMed  | GraphSage  | SBA-Samp | Prune              | 92.50   | 82.50              |
| PubMed  | GraphSage  | SBA-Samp | Prune + LD         | 0.00    | 41.47              |
| PubMed  | GraphSage  | SBA-Gen  | None               | 95.00   | 88.66              |
| PubMed  | GraphSage  | SBA-Gen  | Outlier Detection  | 75.00   | 80.65              |
| PubMed  | GraphSage  | SBA-Gen  | Prune              | 92.50   | 83.64              |
| PubMed  | GraphSage  | SBA-Gen  | Prune + LD         | 0.00    | 42.56              |
| PubMed  | GraphSage  | GTA      | None               | 95.00   | 88.79              |
| PubMed  | GraphSage  | GTA      | Outlier Detection  | 0.00    | 80.73              |
| PubMed  | GraphSage  | GTA      | Prune              | 67.50   | 84.48              |
| PubMed  | GraphSage  | GTA      | Prune + LD         | 0.00    | 42.89              |




# Install necessary packages
!pip install torch-geometric
!pip install matplotlib
!pip install scikit-learn

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from torch_geometric.datasets import Planetoid, Flickr
from sklearn.decomposition import PCA
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import networkx as nx
from sklearn.neighbors import NearestNeighbors
import pandas as pd

import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")

# Set random seeds for reproducibility
def set_seed(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

set_seed()

# Run on CPU; use torch.device('cuda' if torch.cuda.is_available() else 'cpu') to enable a GPU
device = torch.device('cpu')
print(f"Using device: {device}")


# Load datasets
def load_dataset(dataset_name):
    if dataset_name in ["Cora", "PubMed"]:
        dataset = Planetoid(root=f"./data/{dataset_name}", name=dataset_name)
    elif dataset_name == "Flickr":
        dataset = Flickr(root="./data/Flickr")
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")
    return dataset

# Dataset-specific device check (for large datasets like OGB-arxiv)
def set_dataset_device(dataset_name):
    return torch.device('cpu') if dataset_name == "OGB-arxiv" else device

# Split dataset into train/validation/test
# Random split: 70% train, 10% validation, 20% test (the test part is further split into target and clean test nodes)
def split_dataset(data, test_size=0.2, val_size=0.1):
    num_nodes = data.num_nodes
    indices = np.arange(num_nodes)
    np.random.shuffle(indices)

    num_test = int(test_size * num_nodes)
    num_val = int(val_size * num_nodes)
    num_train = num_nodes - num_test - num_val

    train_mask = torch.zeros(num_nodes, dtype=torch.bool).to(device)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool).to(device)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool).to(device)

    train_mask[indices[:num_train]] = True
    val_mask[indices[num_train:num_train + num_val]] = True
    test_mask[indices[num_train + num_val:]] = True

    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask

    # Mask out 20% nodes for attack performance evaluation (half target, half clean test)
    num_target = int(0.1 * num_nodes)  # Half of 20%
    target_mask = torch.zeros(num_nodes, dtype=torch.bool).to(device)
    clean_test_mask = torch.zeros(num_nodes, dtype=torch.bool).to(device)
    target_mask[indices[num_train + num_val:num_train + num_val + num_target]] = True
    clean_test_mask[indices[num_train + num_val + num_target:]] = True

    data.target_mask = target_mask
    data.clean_test_mask = clean_test_mask

    return data

# Define GNN Model with multiple architectures (GCN, GraphSAGE, GAT)
class GNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, model_type='GCN'):
        super(GNN, self).__init__()
        if model_type == 'GCN':
            self.conv1 = GCNConv(input_dim, hidden_dim)
            self.conv2 = GCNConv(hidden_dim, output_dim)
        elif model_type == 'GraphSage':
            self.conv1 = SAGEConv(input_dim, hidden_dim)
            self.conv2 = SAGEConv(hidden_dim, output_dim)
        elif model_type == 'GAT':
            self.conv1 = GATConv(input_dim, hidden_dim, heads=8, concat=True)
            self.conv2 = GATConv(hidden_dim * 8, output_dim, heads=1, concat=False)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

# Select nodes to poison based on high-centrality (degree centrality) for a stronger impact
def select_high_centrality_nodes(data, num_nodes_to_select):
    graph = nx.Graph()
    edge_index = data.edge_index.cpu().numpy()
    graph.add_edges_from(edge_index.T)
    centrality = nx.degree_centrality(graph)
    sorted_nodes = sorted(centrality, key=centrality.get, reverse=True)
    return torch.tensor(sorted_nodes[:num_nodes_to_select], dtype=torch.long).to(device)


class TriggerGenerator(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(TriggerGenerator, self).__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, input_dim)
        )

    def forward(self, x):
        return self.mlp(x)

# OOD Detector (Autoencoder) for detecting poisoned nodes
class OODDetector(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(OODDetector, self).__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 16),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(16, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

    def reconstruction_loss(self, x):
        decoded = self.forward(x)
        loss = F.mse_loss(decoded, x, reduction='none').mean(dim=1)
        return loss

# Train the OOD Detector
def train_ood_detector(ood_detector, data, optimizer, epochs=50):
    ood_detector.train()
    for epoch in range(epochs):
        optimizer.zero_grad()  # Clear the gradients
        reconstructed = ood_detector(data.x)
        # Use only the training mask to train the OOD detector
        loss = F.mse_loss(reconstructed[data.train_mask], data.x[data.train_mask])
        loss.backward()  # Backpropagation
        optimizer.step()  # Update weights

        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Reconstruction Loss: {loss.item():.4f}")

# Training Function with Poisoned Data
def train_with_poisoned_data(model, data, optimizer, poisoned_nodes, trigger_gen, attack, alpha=0.7, early_stopping=False):
    # Apply trigger injection once without tracking gradients: trigger_gen is not optimized here,
    # so the poisoned features can be treated as constants during training
    with torch.no_grad():
        data_poisoned = inject_trigger(data, poisoned_nodes, attack, trigger_gen, alpha)

    # Training loop
    model.train()
    for epoch in range(100):
        optimizer.zero_grad()  # Clear the gradients

        # Forward pass
        out = model(data_poisoned.x, data_poisoned.edge_index)

        # Calculate loss
        loss = F.cross_entropy(out[data_poisoned.train_mask], data_poisoned.y[data_poisoned.train_mask])

        # Backward pass (no retain_graph needed, since the poisoned features carry no autograd graph)
        loss.backward()

        # Update parameters
        optimizer.step()

        # Print the loss every 10 epochs (despite its name, early_stopping only toggles logging)
        if early_stopping and epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    return model, data_poisoned

from sklearn.cluster import KMeans

# Clustering-based Node Selection for UGBA
def select_diverse_nodes(data, num_nodes_to_select, num_clusters=10):
    """
    Select nodes using a clustering-based approach to ensure diversity.

    Parameters:
    - data: PyG data object representing the graph.
    - num_nodes_to_select: Number of nodes to select for poisoning.
    - num_clusters: Number of clusters to form for diversity.

    Returns:
    - Tensor containing indices of selected nodes.
    """
    features = data.x.cpu().numpy()

    # Perform K-means clustering to find representative nodes
    kmeans = KMeans(n_clusters=num_clusters, random_state=42).fit(features)
    labels = kmeans.labels_
    cluster_centers = kmeans.cluster_centers_

    # Select nodes closest to the cluster centers
    selected_nodes = []
    for i in range(num_clusters):
        cluster_indices = np.where(labels == i)[0]
        center = cluster_centers[i]
        distances = np.linalg.norm(features[cluster_indices] - center, axis=1)
        closest_node = cluster_indices[np.argmin(distances)]
        selected_nodes.append(closest_node)

    # If we need more nodes than clusters, add random nodes from each cluster
    while len(selected_nodes) < num_nodes_to_select:
        for i in range(num_clusters):
            cluster_indices = np.where(labels == i)[0]
            random_node = np.random.choice(cluster_indices)
            if random_node not in selected_nodes:
                selected_nodes.append(random_node)
            if len(selected_nodes) >= num_nodes_to_select:
                break

    return torch.tensor(selected_nodes[:num_nodes_to_select], dtype=torch.long).to(device)

# Inject Trigger Function for Different Attacks
def inject_trigger(data, poisoned_nodes, attack_type, trigger_gen=None, alpha=0.7, trigger_size=5, trigger_density=0.5, model_type='SW', input_dim=None):
    data_poisoned = data.clone()    # Clone data to avoid overwriting the original graph

    if attack_type == 'SBA-Samp':
        connected_nodes = [data.edge_index[0][data.edge_index[1] == node] for node in poisoned_nodes[:trigger_size]]
        avg_features = torch.stack([data.x[nodes].mean(dim=0) if len(nodes) > 0 else data.x.mean(dim=0) for nodes in connected_nodes])
        natural_features = avg_features + torch.randn_like(avg_features) * 0.02  # Smaller randomness

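        # Generate a random (Erdos-Renyi) trigger subgraph with the given edge density.
        # Note: the subgraph's node IDs 0..trigger_size-1 are used as-is, so these edges
        # connect graph nodes 0..trigger_size-1 rather than the poisoned nodes.
        G = nx.erdos_renyi_graph(trigger_size, trigger_density)
        # reshape(-1, 2) keeps the [2, E] shape even if the sampled subgraph has no edges
        trigger_edge_index = torch.tensor(list(G.edges), dtype=torch.long).reshape(-1, 2).t().contiguous()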

        poisoned_edges = torch.stack([
            poisoned_nodes[:trigger_size],
            torch.randint(0, data.num_nodes, (trigger_size,), device=device)
        ])

        data_poisoned.edge_index = torch.cat([data.edge_index, trigger_edge_index.to(device), poisoned_edges.to(device)], dim=1)
        data_poisoned.x[poisoned_nodes[:trigger_size]] = natural_features[:trigger_size]

    elif attack_type == 'SBA-Gen':
        # Enhanced Subgraph-based Backdoor Attack - Generated (SBA-Gen)
        connected_nodes = [data.edge_index[0][data.edge_index[1] == node] for node in poisoned_nodes[:trigger_size]]
        avg_features = torch.stack([data.x[nodes].mean(dim=0) if len(nodes) > 0 else data.x.mean(dim=0) for nodes in connected_nodes])
        natural_features = avg_features + torch.randn_like(avg_features) * 0.03

        # Generate the trigger subgraph with a small-world (SW, Watts-Strogatz) or
        # preferential-attachment (PA, Barabasi-Albert) model.
        # With odd k, networkx links each node to k // 2 neighbors per side, so k=3 starts from a ring.
        if model_type == 'SW':
            G = nx.watts_strogatz_graph(trigger_size, k=3, p=0.4)
        elif model_type == 'PA':
            G = nx.barabasi_albert_graph(trigger_size, m=3)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

        # Add edges from generated subgraph and connect to existing nodes
        trigger_edge_index = torch.tensor(list(G.edges)).t().contiguous()
        poisoned_edges = torch.stack([
            poisoned_nodes[:trigger_size],
            torch.randint(0, data.num_nodes, (trigger_size,), device=device)
        ])
        data_poisoned.edge_index = torch.cat([data.edge_index, trigger_edge_index.to(device), poisoned_edges.to(device)], dim=1)
        data_poisoned.x[poisoned_nodes[:trigger_size]] = natural_features[:trigger_size]

    elif attack_type == 'DPGBA':
        # Distribution-Preserving Graph Backdoor Attack (DPGBA) simplified
        connected_nodes = [data.edge_index[0][data.edge_index[1] == node] for node in poisoned_nodes]
        avg_features = torch.stack([data.x[nodes].mean(dim=0) if len(nodes) > 0 else data.x.mean(dim=0) for nodes in connected_nodes])

        # Generate trigger features that preserve the distribution
        trigger_features = trigger_gen(avg_features)

        # Blend trigger features with original features to keep in-distribution
        node_alphas = torch.rand(len(poisoned_nodes)).to(device) * 0.3 + 0.5  # Random alpha between 0.5 and 0.8
        distribution_preserved_features = node_alphas.unsqueeze(1) * data.x[poisoned_nodes] + (1 - node_alphas.unsqueeze(1)) * trigger_features

        # Update poisoned nodes with blended features
        data_poisoned.x[poisoned_nodes] = distribution_preserved_features


    elif attack_type == 'GTA':
        # Enhanced Graph Trojan Attack (GTA)
        connected_nodes = [data.edge_index[0][data.edge_index[1] == node] for node in poisoned_nodes]
        avg_features = torch.stack([data.x[nodes].mean(dim=0) if len(nodes) > 0 else data.x.mean(dim=0) for nodes in connected_nodes])

        trigger_features = avg_features + torch.randn_like(avg_features) * 0.05
        data_poisoned.x[poisoned_nodes] = trigger_features

    elif attack_type == 'UGBA':
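        # Unnoticeable Graph Backdoor Attack (UGBA)
        # Note: UGBA modifies cluster-representative nodes (diverse_nodes) rather than poisoned_nodes,
        # while compute_metrics still evaluates ASR on poisoned_nodes.
        diverse_nodes = select_diverse_nodes(data, len(poisoned_nodes), num_clusters=10)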
        connected_nodes = [data.edge_index[0][data.edge_index[1] == node] for node in diverse_nodes]
        avg_features = torch.stack([data.x[nodes].mean(dim=0) if len(nodes) > 0 else data.x.mean(dim=0) for nodes in connected_nodes])
        trigger_features = trigger_gen(avg_features) * 0.6 + avg_features * 0.4
        refined_trigger_features = trigger_features + torch.randn_like(avg_features) * 0.02
        data_poisoned.x[diverse_nodes] = refined_trigger_features

    else:
        raise ValueError(f"Unsupported attack type: {attack_type}")

    return data_poisoned

# Defense 1: Outlier Detection (OD)
def defense_outlier_detection(ood_detector, data, threshold=0.9):
    """
    Identifies outlier nodes using an OOD detector and removes their influence by modifying labels and features.

    Parameters:
    - ood_detector: Pre-trained OOD detector to identify outliers.
    - data: PyG data object representing the graph.
    - threshold: Quantile threshold for identifying outliers (e.g., 0.9 means top 10% of loss values are outliers).

    Returns:
    - pruned_nodes: Set of nodes identified as outliers.
    - data: Updated PyG data object with modified features and labels for outliers.
    """
    ood_detector.eval()
    with torch.no_grad():
        reconstruction_loss = ood_detector.reconstruction_loss(data.x)
        threshold_loss = torch.quantile(reconstruction_loss, threshold).item()

        # Identify outliers with reconstruction loss higher than the threshold
        outliers = torch.where(reconstruction_loss > threshold_loss)[0]
        pruned_nodes = set(outliers.cpu().numpy())

        # Update data to reflect removal of outlier influence
        if len(pruned_nodes) > 0:
            data.y[outliers] = -1  # Discard labels for outliers
            data.x[outliers] = data.x.mean(dim=0).to(device)  # Replace features with average feature to mitigate impact

    return pruned_nodes, data

# Defense 2: Adaptive Prune Edges based on Cosine Similarity
def defense_prune_edges(data, quantile_threshold=0.9):
    """
    Prunes edges based on adaptive cosine similarity between node features.

    Parameters:
    - data: PyG data object representing the graph.
    - quantile_threshold: Quantile of the similarity distribution below which edges are pruned (e.g., 0.9 keeps only the 10% most similar edges).

    Returns:
    - data: Updated PyG data object with pruned edges.
    """
    features = data.x
    norm_features = F.normalize(features, p=2, dim=1)  # Normalize features
    edge_index = data.edge_index

    # Calculate cosine similarity for each edge
    src, dst = edge_index[0], edge_index[1]
    cosine_similarities = torch.sum(norm_features[src] * norm_features[dst], dim=1)

    # Adaptive threshold based on quantile of similarity distribution
    similarity_threshold = torch.quantile(cosine_similarities, quantile_threshold).item()

    # Keep edges with cosine similarity at or above the threshold
    pruned_mask = cosine_similarities >= similarity_threshold
    pruned_edges = edge_index[:, pruned_mask]

    # Update edge index with pruned edges
    data.edge_index = pruned_edges

    return data


# Defense 3: Prune + LD (Adaptive Prune and Selective Discard Labels)
def defense_prune_and_discard_labels(data, quantile_threshold=0.8):
    """
    Prunes edges based on adaptive cosine similarity and discards labels of nodes connected by pruned edges selectively.

    Parameters:
    - data: PyG data object representing the graph.
    - quantile_threshold: Quantile threshold for cosine similarity pruning (e.g., 0.2 means pruning edges in the bottom 20%).

    Returns:
    - data: Updated PyG data object with pruned edges and selectively discarded labels.
    """
    features = data.x
    norm_features = F.normalize(features, p=2, dim=1)  # Normalize features using PyTorch
    edge_index = data.edge_index

    # Calculate cosine similarity for each edge
    src, dst = edge_index[0], edge_index[1]
    cosine_similarities = torch.sum(norm_features[src] * norm_features[dst], dim=1)

    # Use quantile to determine adaptive threshold for pruning
    adaptive_threshold = torch.quantile(cosine_similarities, quantile_threshold).item()

    # Mask edges with similarity below the adaptive threshold
    pruned_mask = cosine_similarities < adaptive_threshold
    pruned_edges = edge_index[:, ~pruned_mask]  # Retain edges that are above the threshold

    # Update edge index with pruned edges
    data.edge_index = pruned_edges

    # Selectively discard labels of nodes connected by many pruned edges
    pruned_src, pruned_dst = edge_index[:, pruned_mask]
    pruned_nodes_count = torch.bincount(torch.cat([pruned_src, pruned_dst]), minlength=data.num_nodes)

    # Only discard labels if the node has a high count of pruned edges
    threshold_count = int(torch.median(pruned_nodes_count).item())  # Use median count as a threshold
    nodes_to_discard = torch.where(pruned_nodes_count > threshold_count)[0]

    data.y[nodes_to_discard] = -1  # Use -1 to represent discarded labels

    return data





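# Compute ASR and Clean Accuracy (inference runs under torch.no_grad(), so no computation graph is kept)
def compute_metrics(model, data, poisoned_nodes):
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        _, pred = out.max(dim=1)
        # "ASR" here is the percentage of poisoned nodes whose prediction matches data.y.
        # No target class is defined, and nodes whose labels a defense set to -1 never match.
        asr = (pred[poisoned_nodes] == data.y[poisoned_nodes]).sum().item() / len(poisoned_nodes) * 100
        # Accuracy over the test mask; discarded labels (-1) also count as errors here
        clean_acc = accuracy_score(data.y[data.test_mask].cpu(), pred[data.test_mask].cpu()) * 100
    return asr, clean_acc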

# Visualize PCA for Attacks
# Added function to visualize PCA projections of node embeddings for different attacks
def visualize_pca_for_attacks(attack_embeddings_dict):
    pca = PCA(n_components=2)
    plt.figure(figsize=(20, 10))

    for i, (attack, attack_data) in enumerate(attack_embeddings_dict.items(), 1):
        # Extract the node embeddings and poisoned nodes indices
        embeddings = attack_data['data'].cpu().numpy()
        poisoned_nodes = attack_data['poisoned_nodes'].cpu().numpy()

        # Apply PCA to the node embeddings
        pca_result = pca.fit_transform(embeddings)

        # Create masks for clean and poisoned nodes
        clean_mask = np.ones(embeddings.shape[0], dtype=bool)
        clean_mask[poisoned_nodes] = False

        # Extract clean and poisoned node embeddings after PCA
        clean_embeddings = pca_result[clean_mask]
        poisoned_embeddings = pca_result[~clean_mask]

        # Plotting clean and poisoned nodes
        plt.subplot(2, 3, i)
        plt.scatter(clean_embeddings[:, 0], clean_embeddings[:, 1], s=10, alpha=0.5, label='Clean Nodes', c='b')
        plt.scatter(poisoned_embeddings[:, 0], poisoned_embeddings[:, 1], s=10, alpha=0.8, label='Poisoned Nodes', c='r')
        plt.title(f'PCA Visualization for {attack}')
        plt.xlabel('PCA Component 1')
        plt.ylabel('PCA Component 2')
        plt.legend()

    plt.tight_layout()
    plt.show()

import gc

# Run all attacks and apply defenses
def run_all_attacks():
    datasets = ["Cora", "PubMed", "Flickr"]
    results_summary = []

    for dataset_name in datasets:
        dataset = load_dataset(dataset_name)
        data = dataset[0].to(device)
        input_dim = data.num_features
        output_dim = dataset.num_classes if isinstance(dataset.num_classes, int) else dataset.num_classes[0]
        data = split_dataset(data)

        # Dataset-specific poisoning budgets
        dataset_budgets = {
            'Cora': 10,
            'PubMed': 40,
            'Flickr': 160
        }
        poisoned_node_budget = dataset_budgets.get(dataset_name, 10)

        # Define GNN models for experiments
        model_types = ['GCN', 'GraphSage', 'GAT']

        for model_type in model_types:
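            # Initialize the model, optimizer, trigger generator and OOD detector once per model type.
            # Note: the same model and optimizer keep training across all attacks in the loop below.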
            model = GNN(input_dim=input_dim, hidden_dim=64, output_dim=output_dim, model_type=model_type).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
            trigger_gen = TriggerGenerator(input_dim=input_dim, hidden_dim=64).to(device)
            ood_detector = OODDetector(input_dim=input_dim, hidden_dim=64).to(device)

            # Train OOD detector
            ood_optimizer = torch.optim.Adam(ood_detector.parameters(), lr=0.001)
            train_ood_detector(ood_detector, data, ood_optimizer)

            # Select nodes to poison based on high-centrality (degree centrality) for a stronger impact
            poisoned_nodes = select_high_centrality_nodes(data, poisoned_node_budget)

            # Define different attacks
            attack_methods = ['SBA-Samp', 'SBA-Gen', 'GTA', 'UGBA', 'DPGBA']

            for attack in attack_methods:
                # Train model with poisoned data
                trained_model, data_poisoned = train_with_poisoned_data(
                    model, data, optimizer, poisoned_nodes, trigger_gen, attack, alpha=0.7, early_stopping=True
                )

                # Compute ASR and Clean Accuracy before applying any defense
                asr, clean_acc = compute_metrics(trained_model, data_poisoned, poisoned_nodes)
                results_summary.append({
                    "Dataset": dataset_name,
                    "Model": model_type,
                    "Attack": attack,
                    "Defense": "None",
                    "ASR": asr,
                    "Clean Accuracy": clean_acc
                })
                print(f"Dataset: {dataset_name}, Model: {model_type}, Attack: {attack}, Defense: None - ASR: {asr:.2f}%, Clean Accuracy: {clean_acc:.2f}%")

                # Apply defenses
                # Defense 1: Outlier Detection (OD)
                pruned_nodes, data_poisoned_od = defense_outlier_detection(ood_detector, data_poisoned.clone(), threshold=0.9)
                asr_od, clean_acc_od = compute_metrics(trained_model, data_poisoned_od, poisoned_nodes)
                results_summary.append({
                    "Dataset": dataset_name,
                    "Model": model_type,
                    "Attack": attack,
                    "Defense": "Outlier Detection",
                    "ASR": asr_od,
                    "Clean Accuracy": clean_acc_od
                })
                print(f"Dataset: {dataset_name}, Model: {model_type}, Attack: {attack}, Defense: Outlier Detection - ASR: {asr_od:.2f}%, Clean Accuracy: {clean_acc_od:.2f}%")

                # Defense 2: Prune
                data_poisoned_prune = defense_prune_edges(data_poisoned.clone(), quantile_threshold=0.8)
                asr_prune, clean_acc_prune = compute_metrics(trained_model, data_poisoned_prune, poisoned_nodes)
                results_summary.append({
                    "Dataset": dataset_name,
                    "Model": model_type,
                    "Attack": attack,
                    "Defense": "Prune",
                    "ASR": asr_prune,
                    "Clean Accuracy": clean_acc_prune
                })
                print(f"Dataset: {dataset_name}, Model: {model_type}, Attack: {attack}, Defense: Prune - ASR: {asr_prune:.2f}%, Clean Accuracy: {clean_acc_prune:.2f}%")

                # Defense 3: Prune + LD
                data_poisoned_prune_ld = defense_prune_and_discard_labels(data_poisoned.clone(), quantile_threshold=0.8)
                asr_prune_ld, clean_acc_prune_ld = compute_metrics(trained_model, data_poisoned_prune_ld, poisoned_nodes)
                results_summary.append({
                    "Dataset": dataset_name,
                    "Model": model_type,
                    "Attack": attack,
                    "Defense": "Prune + LD",
                    "ASR": asr_prune_ld,
                    "Clean Accuracy": clean_acc_prune_ld
                })
                print(f"Dataset: {dataset_name}, Model: {model_type}, Attack: {attack}, Defense: Prune + LD - ASR: {asr_prune_ld:.2f}%, Clean Accuracy: {clean_acc_prune_ld:.2f}%")

                # Drop references to per-attack objects to free memory
                # (deleting entries from locals() inside a function has no effect)
                del trained_model, data_poisoned, data_poisoned_od, data_poisoned_prune, data_poisoned_prune_ld

                # Release memory
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    # Summarize Results in a Table
    results_df = pd.DataFrame(results_summary)
    print("\nSummary of Attack Success Rate and Clean Accuracy Before and After Defenses:")
    print(results_df)

    # Save results to CSV
    results_df.to_csv("backdoor_attack_results_summary.csv", index=False)

run_all_attacks()
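
The cell calls `defense_prune_edges(..., quantile_threshold=0.8)` and `defense_prune_and_discard_labels(...)`, but their bodies are not part of this section. The sketch below is one plausible reading, assuming that edges are scored by the cosine similarity of their endpoint features, that low-similarity edges are removed, and that the LD step discards the labels of every node touching a removed edge. The names `prune_edges`, `prune_and_discard_labels`, and the parameter `q` (the quantile below which edges are cut) are hypothetical, and how the notebook's `quantile_threshold=0.8` maps onto `q` is not shown here. It uses plain lists so it runs without PyTorch.

```python
import math


def cosine(u, v):
    """Cosine similarity of two equal-length feature vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu and nv else 0.0


def quantile(values, q):
    """Linear-interpolation quantile of a non-empty list, q in [0, 1]."""
    s = sorted(values)
    pos = (len(s) - 1) * q
    lo, hi = math.floor(pos), math.ceil(pos)
    return s[lo] + (s[hi] - s[lo]) * (pos - lo)


def prune_edges(x, edges, q):
    """Hypothetical Prune: drop edges whose endpoint similarity is below the q-quantile.

    x: per-node feature vectors; edges: list of (u, v) index pairs.
    Returns (kept_edges, pruned_edges).
    """
    sims = [cosine(x[u], x[v]) for u, v in edges]
    cut = quantile(sims, q)
    kept = [e for e, s in zip(edges, sims) if s >= cut]
    pruned = [e for e, s in zip(edges, sims) if s < cut]
    return kept, pruned


def prune_and_discard_labels(x, edges, labels, q):
    """Hypothetical Prune + LD: prune, then mask (None) labels of nodes on a pruned edge."""
    kept, pruned = prune_edges(x, edges, q)
    touched = {n for edge in pruned for n in edge}
    masked = [None if i in touched else y for i, y in enumerate(labels)]
    return kept, masked


if __name__ == "__main__":
    # Toy graph: node 2 shares no features with node 1, so edge (1, 2) scores 0.0.
    x = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    edges = [(0, 1), (1, 2), (2, 3), (0, 3)]
    print(prune_edges(x, edges, 0.25))
    print(prune_and_discard_labels(x, edges, [0, 0, 1, 1], 0.25))
```

In the cell above, the defended graphs are passed straight to `compute_metrics` together with the already-trained `trained_model`, so under this reading pruning and label discarding act only at evaluation time, not during training.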


Using device: cpu


Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


Epoch 0, Reconstruction Loss: 0.0213
Epoch 10, Reconstruction Loss: 0.0153
Epoch 20, Reconstruction Loss: 0.0132
Epoch 30, Reconstruction Loss: 0.0124
Epoch 40, Reconstruction Loss: 0.0121
Epoch 0, Loss: 1.9662
Epoch 10, Loss: 1.4551
Epoch 20, Loss: 0.9441
Epoch 30, Loss: 0.6230
Epoch 40, Loss: 0.4510
Epoch 50, Loss: 0.3446
Epoch 60, Loss: 0.2920
Epoch 70, Loss: 0.2578
Epoch 80, Loss: 0.2321
Epoch 90, Loss: 0.2144
Dataset: Cora, Model: GCN, Attack: SBA-Samp, Defense: None - ASR: 90.00%, Clean Accuracy: 88.54%
Dataset: Cora, Model: GCN, Attack: SBA-Samp, Defense: Outlier Detection - ASR: 90.00%, Clean Accuracy: 78.93%
Dataset: Cora, Model: GCN, Attack: SBA-Samp, Defense: Prune - ASR: 90.00%, Clean Accuracy: 83.92%
Dataset: Cora, Model: GCN, Attack: SBA-Samp, Defense: Prune + LD - ASR: 0.00%, Clean Accuracy: 46.40%
Epoch 0, Loss: 0.1922
Epoch 10, Loss: 0.1802
Epoch 20, Loss: 0.1651
Epoch 30, Loss: 0.1521
Epoch 40, Loss: 0.1414
Epoch 50, Loss: 0.1413
Epoch 60, Loss: 0.1319
Epoch 70, Loss:

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.test.index
Processing...
Done!


Epoch 0, Reconstruction Loss: 0.0095
Epoch 10, Reconstruction Loss: 0.0036
Epoch 20, Reconstruction Loss: 0.0014
Epoch 30, Reconstruction Loss: 0.0007
Epoch 40, Reconstruction Loss: 0.0004
Epoch 0, Loss: 1.1002
Epoch 10, Loss: 1.0339
Epoch 20, Loss: 0.9595
Epoch 30, Loss: 0.8772
Epoch 40, Loss: 0.7984
Epoch 50, Loss: 0.7202
Epoch 60, Loss: 0.6517
Epoch 70, Loss: 0.5947
Epoch 80, Loss: 0.5494
Epoch 90, Loss: 0.5130
Dataset: PubMed, Model: GCN, Attack: SBA-Samp, Defense: None - ASR: 90.00%, Clean Accuracy: 84.30%
Dataset: PubMed, Model: GCN, Attack: SBA-Samp, Defense: Outlier Detection - ASR: 82.50%, Clean Accuracy: 76.69%
Dataset: PubMed, Model: GCN, Attack: SBA-Samp, Defense: Prune - ASR: 90.00%, Clean Accuracy: 82.17%
Dataset: PubMed, Model: GCN, Attack: SBA-Samp, Defense: Prune + LD - ASR: 0.00%, Clean Accuracy: 40.65%
Epoch 0, Loss: 0.4867
Epoch 10, Loss: 0.4678
Epoch 20, Loss: 0.4488
Epoch 30, Loss: 0.4374
Epoch 40, Loss: 0.4275
Epoch 50, Loss: 0.4186
Epoch 60, Loss: 0.4054
Epoch 7

Downloading https://drive.usercontent.google.com/download?id=1crmsTbd1-2sEXsGwa2IKnIB7Zd3TmUsy&confirm=t
Downloading https://drive.usercontent.google.com/download?id=1join-XdvX3anJU_MLVtick7MgeAQiWIZ&confirm=t
Downloading https://drive.usercontent.google.com/download?id=1uxIkbtg5drHTsKt-PAsZZ4_yJmgFmle9&confirm=t
Downloading https://drive.usercontent.google.com/download?id=1htXCtuktuCW8TR8KiKfrFDAxUgekQoV7&confirm=t
Processing...
Done!


Epoch 0, Reconstruction Loss: 3.6785
Epoch 10, Reconstruction Loss: 3.3606
Epoch 20, Reconstruction Loss: 2.9332
Epoch 30, Reconstruction Loss: 2.7447
Epoch 40, Reconstruction Loss: 2.6318
Epoch 0, Loss: 2.0753
Epoch 10, Loss: 1.6098
Epoch 20, Loss: 1.5804
Epoch 30, Loss: 1.5437
Epoch 40, Loss: 1.5324
Epoch 50, Loss: 1.5266
Epoch 60, Loss: 1.5077
Epoch 70, Loss: 1.4997
Epoch 80, Loss: 1.4894
Epoch 90, Loss: 1.4826
Dataset: Flickr, Model: GCN, Attack: SBA-Samp, Defense: None - ASR: 41.88%, Clean Accuracy: 50.63%
Dataset: Flickr, Model: GCN, Attack: SBA-Samp, Defense: Outlier Detection - ASR: 35.00%, Clean Accuracy: 44.87%
Dataset: Flickr, Model: GCN, Attack: SBA-Samp, Defense: Prune - ASR: 41.88%, Clean Accuracy: 44.24%
Dataset: Flickr, Model: GCN, Attack: SBA-Samp, Defense: Prune + LD - ASR: 0.00%, Clean Accuracy: 26.75%
Epoch 0, Loss: 1.4751
Epoch 10, Loss: 1.4660
Epoch 20, Loss: 1.4602
Epoch 30, Loss: 1.4527
Epoch 40, Loss: 1.4487
Epoch 50, Loss: 1.4472
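
To read the defense trade-offs directly off the printed lines, a small pure-Python helper can parse each `Dataset: ..., Defense: ... - ASR: ...%, Clean Accuracy: ...%` line and report each defense's change relative to `Defense: None` for the same dataset, model, and attack. The helper names (`parse_results`, `defense_deltas`) are hypothetical; the sample input is copied from the Cora result lines in the log, and `results_df` from the cell above carries the same fields if pandas is preferred.

```python
import re

# Matches the per-defense summary lines printed by run_all_attacks().
RESULT_LINE = re.compile(
    r"Dataset: (?P<dataset>[^,]+), Model: (?P<model>[^,]+), "
    r"Attack: (?P<attack>[^,]+), Defense: (?P<defense>.+?) - "
    r"ASR: (?P<asr>[\d.]+)%, Clean Accuracy: (?P<acc>[\d.]+)%"
)


def parse_results(log_text):
    """Return one dict per result line; training-loss and download lines are skipped."""
    rows = []
    for line in log_text.splitlines():
        m = RESULT_LINE.search(line)
        if m:
            row = m.groupdict()
            row["asr"] = float(row["asr"])
            row["acc"] = float(row["acc"])
            rows.append(row)
    return rows


def defense_deltas(rows):
    """(dataset, defense, ASR change, clean-accuracy change) in points vs. Defense: None."""
    baseline = {
        (r["dataset"], r["model"], r["attack"]): r
        for r in rows if r["defense"] == "None"
    }
    out = []
    for r in rows:
        base = baseline.get((r["dataset"], r["model"], r["attack"]))
        if base is not None and r["defense"] != "None":
            out.append((
                r["dataset"],
                r["defense"],
                round(r["asr"] - base["asr"], 2),
                round(r["acc"] - base["acc"], 2),
            ))
    return out


# Sample input: the Cora result lines from the log, plus one training line to skip.
CORA_LOG = """\
Epoch 90, Loss: 0.2144
Dataset: Cora, Model: GCN, Attack: SBA-Samp, Defense: None - ASR: 90.00%, Clean Accuracy: 88.54%
Dataset: Cora, Model: GCN, Attack: SBA-Samp, Defense: Outlier Detection - ASR: 90.00%, Clean Accuracy: 78.93%
Dataset: Cora, Model: GCN, Attack: SBA-Samp, Defense: Prune - ASR: 90.00%, Clean Accuracy: 83.92%
Dataset: Cora, Model: GCN, Attack: SBA-Samp, Defense: Prune + LD - ASR: 0.00%, Clean Accuracy: 46.40%
"""

if __name__ == "__main__":
    for row in defense_deltas(parse_results(CORA_LOG)):
        print(row)
```

On the Cora lines, Outlier Detection and Prune leave ASR at 90.00% while costing 9.61 and 4.62 points of clean accuracy respectively; Prune + LD drops ASR by 90 points, to 0.00%, but costs 42.14 points of clean accuracy.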
