<a href="https://colab.research.google.com/github/AbhiJeet70/The-Power-of-Many-Investigating-Defense-Mechanisms-for-Resilient-Graph-Neural-Networks/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:


# Install necessary packages
# NOTE: the leading `!` is an IPython/Jupyter shell escape, so these lines only
# work inside a notebook kernel and are a syntax error under plain `python`.
# NOTE(review): networkx and pandas are imported below but not installed here;
# presumably they are preinstalled in the target Colab runtime - verify.
!pip install torch-geometric
!pip install matplotlib
!pip install scikit-learn

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from torch_geometric.datasets import Planetoid, Flickr
from sklearn.decomposition import PCA
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import networkx as nx
from sklearn.neighbors import NearestNeighbors
import pandas as pd
from torch_geometric.utils import to_networkx

import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")

# Seed every random number generator used in this file for reproducibility
def set_seed(seed=42):
    """Seed PyTorch, NumPy and Python's ``random`` with the same value.

    Parameters:
    - seed: Integer seed applied to each generator. Defaults to 42.

    When CUDA is available the current CUDA device is seeded as well.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)

set_seed()

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")


# Load datasets
def load_dataset(dataset_name):
    """Return the dataset object registered under ``dataset_name``.

    Parameters:
    - dataset_name: One of "Cora", "PubMed", "CiteSeer" (loaded through
      ``Planetoid``) or "Flickr" (loaded through ``Flickr``).

    Returns:
    - The constructed dataset, rooted at ``./data/<dataset_name>``.

    Raises:
    - ValueError: If ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == "Flickr":
        return Flickr(root="./data/Flickr")
    if dataset_name not in ("Cora", "PubMed", "CiteSeer"):
        raise ValueError(f"Unsupported dataset: {dataset_name}")
    return Planetoid(root=f"./data/{dataset_name}", name=dataset_name)


# Split dataset into train/validation/test
# With the default arguments 70% of the nodes are used for training, 10% for
# validation and 20% for testing; the test nodes are further halved into a
# "target" set (attack evaluation) and a "clean test" set.
def split_dataset(data, test_size=0.2, val_size=0.1):
    """Attach random train/val/test masks (plus target/clean-test masks) to ``data``.

    Parameters:
    - data: Graph object exposing ``num_nodes``; it is modified in place.
    - test_size: Fraction of nodes assigned to the test set.
    - val_size: Fraction of nodes assigned to the validation set.

    Returns:
    - The same ``data`` object with boolean ``train_mask``, ``val_mask``,
      ``test_mask``, ``target_mask`` and ``clean_test_mask`` attributes.
      ``target_mask`` and ``clean_test_mask`` partition ``test_mask``; the
      target part receives half of the test nodes (rounded down).

    Raises:
    - ValueError: If the requested fractions leave a negative number of nodes
      for any of the splits.
    """
    num_nodes = data.num_nodes
    num_test = int(test_size * num_nodes)
    num_val = int(val_size * num_nodes)
    num_train = num_nodes - num_test - num_val
    if min(num_test, num_val, num_train) < 0:
        raise ValueError(
            f"Invalid split: test_size={test_size} and val_size={val_size} "
            f"do not fit into {num_nodes} nodes"
        )

    # Keep the masks on the same device as the node features so that they can
    # index them directly, instead of relying on a module-level device.
    features = getattr(data, 'x', None)
    mask_device = features.device if features is not None else torch.device('cpu')

    indices = np.arange(num_nodes)
    np.random.shuffle(indices)

    def make_mask(selected):
        mask = torch.zeros(num_nodes, dtype=torch.bool, device=mask_device)
        mask[selected] = True
        return mask

    test_start = num_train + num_val
    # Half of the test nodes are reserved as attack targets. For the default
    # test_size=0.2 this equals the previous hard-coded 10% of all nodes.
    target_end = test_start + num_test // 2

    data.train_mask = make_mask(indices[:num_train])
    data.val_mask = make_mask(indices[num_train:test_start])
    data.test_mask = make_mask(indices[test_start:])
    data.target_mask = make_mask(indices[test_start:target_end])
    data.clean_test_mask = make_mask(indices[target_end:])

    return data

# Define GNN Model with multiple architectures (GCN, GraphSAGE, GAT)
class GNN(torch.nn.Module):
    """Two-layer graph neural network with a selectable convolution type.

    Parameters:
    - input_dim: Size of the input node features.
    - hidden_dim: Size of the hidden representation (per head for GAT).
    - output_dim: Number of output logits per node.
    - model_type: 'GCN', 'GraphSage' or 'GAT'.
    - dropout: Dropout probability applied after the first layer.
    - heads: Number of attention heads of the first GAT layer (GAT only).

    Raises:
    - ValueError: If ``model_type`` is not one of the supported names. Without
      this check an unknown name would build a model without layers that only
      fails later, inside ``forward``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, model_type='GCN', dropout=0.5, heads=8):
        super().__init__()
        if model_type == 'GCN':
            self.conv1 = GCNConv(input_dim, hidden_dim)
            self.conv2 = GCNConv(hidden_dim, output_dim)
        elif model_type == 'GraphSage':
            self.conv1 = SAGEConv(input_dim, hidden_dim)
            self.conv2 = SAGEConv(hidden_dim, output_dim)
        elif model_type == 'GAT':
            # The first layer concatenates its heads, so the second layer
            # consumes hidden_dim * heads features.
            self.conv1 = GATConv(input_dim, hidden_dim, heads=heads, concat=True)
            self.conv2 = GATConv(hidden_dim * heads, output_dim, heads=1, concat=False)
        else:
            raise ValueError(
                f"Unsupported model_type: {model_type!r}. "
                "Expected 'GCN', 'GraphSage' or 'GAT'."
            )
        self.model_type = model_type
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return per-node logits for features ``x`` on graph ``edge_index``."""
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x

# Select nodes to poison based on high-centrality (degree centrality) for a stronger impact
def select_high_centrality_nodes(data, num_nodes_to_select):
    """Return the nodes with the highest degree centrality.

    Parameters:
    - data: Graph object whose ``edge_index`` lists the edges.
    - num_nodes_to_select: Maximum number of node indices to return.

    Returns:
    - Long tensor (on the module-level ``device``) with the selected node
      indices, most central first. Only nodes that appear in ``edge_index``
      are candidates, because the graph is built from the edge list alone.
    """
    edge_pairs = data.edge_index.cpu().numpy().T
    graph = nx.Graph()
    graph.add_edges_from(edge_pairs)

    scores = nx.degree_centrality(graph)
    ranked = sorted(scores, key=lambda node: scores[node], reverse=True)
    top_nodes = ranked[:num_nodes_to_select]
    return torch.tensor(top_nodes, dtype=torch.long).to(device)


class TriggerGenerator(torch.nn.Module):
    """Two-layer perceptron that maps a feature vector back to feature space.

    Parameters:
    - input_dim: Size of the input (and output) feature vectors.
    - hidden_dim: Width of the hidden layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        layers = [
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, input_dim),
        ]
        self.mlp = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the perceptron to ``x`` and return the result."""
        generated = self.mlp(x)
        return generated

class OODDetector(torch.nn.Module):
    """Graph auto-encoder: a two-layer GCN encoder followed by an MLP decoder.

    Parameters:
    - input_dim: Size of the node features (and of the reconstruction).
    - hidden_dim: Width of the hidden layers in encoder and decoder.
    - latent_dim: Size of the latent embedding produced by the encoder.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # The encoder layers are stored in a Sequential only as a container;
        # forward() calls them one by one because the graph convolutions also
        # need edge_index.
        encoder_layers = [
            GCNConv(input_dim, hidden_dim),
            torch.nn.ReLU(),
            GCNConv(hidden_dim, latent_dim),
        ]
        decoder_layers = [
            torch.nn.Linear(latent_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, input_dim),
        ]
        self.encoder = torch.nn.Sequential(*encoder_layers)
        self.decoder = torch.nn.Sequential(*decoder_layers)

    def forward(self, x, edge_index):
        """Return ``(reconstructed_x, z)`` for node features ``x``."""
        first_conv, activation, second_conv = self.encoder
        hidden = activation(first_conv(x, edge_index))
        z = second_conv(hidden, edge_index)
        return self.decoder(z), z

    def reconstruction_loss(self, x, edge_index):
        """Return the per-node mean squared reconstruction error."""
        reconstructed_x, _latent = self.forward(x, edge_index)
        per_element = F.mse_loss(reconstructed_x, x, reduction='none')
        return per_element.mean(dim=1)

def train_ood_detector(ood_detector, data, optimizer, epochs=50):
    """Fit ``ood_detector`` to reconstruct the features of the training nodes.

    Parameters:
    - ood_detector: Module called as ``ood_detector(x, edge_index)`` that
      returns the reconstruction as the first element of its output.
    - data: Graph object providing ``x``, ``edge_index`` and ``train_mask``.
    - optimizer: Optimizer that updates the detector's parameters.
    - epochs: Number of full-graph optimisation steps.

    The reconstruction loss is printed every 10 epochs.
    """
    ood_detector.train()
    train_nodes = data.train_mask

    for step in range(epochs):
        optimizer.zero_grad()

        # The forward pass sees the whole graph, but only training nodes
        # contribute to the loss.
        reconstruction, _latent = ood_detector(data.x, data.edge_index)
        loss = F.mse_loss(reconstruction[train_nodes], data.x[train_nodes])

        loss.backward()
        optimizer.step()

        if step % 10:
            continue
        print(f"Epoch {step}, Reconstruction Loss: {loss.item():.4f}")


# Training Function with Poisoned Data
def train_with_poisoned_data(model, data, optimizer, poisoned_nodes, trigger_gen, attack, ood_detector=None, alpha=0.7, early_stopping=False, epochs=100):
    """Poison ``data`` with the given attack and train ``model`` on the result.

    Parameters:
    - model: Node classifier called as ``model(x, edge_index)``.
    - data: Clean graph object; it is not modified (a poisoned copy is made).
    - optimizer: Optimizer that updates ``model``.
    - poisoned_nodes: Tensor with the indices of the nodes to poison.
    - trigger_gen: Trigger generator forwarded to ``inject_trigger``.
    - attack: Attack name forwarded to ``inject_trigger``.
    - ood_detector: Optional detector forwarded to ``inject_trigger``.
    - alpha: Interpolation weight forwarded to ``inject_trigger``.
    - early_stopping: Despite its name this flag only enables printing the
      loss every 10 epochs; training always runs for ``epochs`` epochs.
    - epochs: Number of training epochs.

    Returns:
    - Tuple ``(model, data_poisoned)``.
    """
    # Apply trigger injection (forward the caller's alpha instead of a constant)
    data_poisoned = inject_trigger(data, poisoned_nodes, attack, trigger_gen, ood_detector=ood_detector, alpha=alpha)

    # The poisoned features may carry autograd history from trigger generation
    # (e.g. when a decoder network produced them). The classifier only needs
    # their values, so detach them: every epoch then builds an independent
    # graph and no retain_graph=True workaround is required.
    data_poisoned.x = data_poisoned.x.detach()

    # Training loop
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()

        out = model(data_poisoned.x, data_poisoned.edge_index)
        loss = F.cross_entropy(out[data_poisoned.train_mask], data_poisoned.y[data_poisoned.train_mask])

        loss.backward()
        optimizer.step()

        # Optional: print loss during training for insight
        if early_stopping and epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    return model, data_poisoned

from sklearn.cluster import KMeans

class GCNEncoder(torch.nn.Module):
    """Two-layer GCN that embeds nodes into ``out_channels`` dimensions.

    Parameters:
    - in_channels: Size of the input node features.
    - out_channels: Size of the produced embeddings; the hidden layer is
      twice as wide.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 2 * out_channels
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return the node embeddings for features ``x`` on ``edge_index``."""
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        return self.conv2(hidden, edge_index)

def _top_up_with_structural_nodes(data, selected_nodes, num_nodes_to_select):
    """Extend ``selected_nodes`` with high-degree and high-betweenness nodes.

    Candidates are taken alternately from the degree ranking and the
    betweenness ranking, skipping nodes that are already selected, until
    ``num_nodes_to_select`` nodes are available or the rankings are exhausted.
    """
    degree = torch.bincount(data.edge_index[0], minlength=data.num_nodes)
    by_degree = torch.topk(degree, degree.numel()).indices.cpu().tolist()

    graph = to_networkx(data, to_undirected=True)
    betweenness = nx.betweenness_centrality(graph)
    by_betweenness = sorted(betweenness, key=betweenness.get, reverse=True)

    chosen = list(selected_nodes)
    seen = set(chosen)
    for candidates in zip(by_degree, by_betweenness):
        for node in candidates:
            if len(chosen) >= num_nodes_to_select:
                return chosen
            if node not in seen:
                seen.add(node)
                chosen.append(node)
    return chosen


def select_diverse_nodes(data, num_nodes_to_select, num_clusters=None):
    """
    Select a diverse set of nodes via clustering, topped up with structurally important nodes.

    Node embeddings come from a freshly initialised (untrained) GCN encoder.
    They are clustered with K-means and the node closest to each cluster
    centre is taken as that cluster's representative. If the clusters yield
    fewer than `num_nodes_to_select` representatives, the remaining slots are
    filled with high-degree and high-betweenness nodes.

    Parameters:
    - data: PyG data object representing the graph.
    - num_nodes_to_select: Number of nodes to select for poisoning.
    - num_clusters: Number of clusters to form for diversity. Defaults to the
      number of node features and is capped at the number of nodes, because
      K-means cannot form more clusters than there are samples.

    Returns:
    - Tensor containing indices of selected nodes (on the device of `data.x`).
    """
    if num_clusters is None:
        num_clusters = data.num_features  # Use the number of features as the default number of clusters
    num_clusters = max(1, min(int(num_clusters), int(data.num_nodes)))

    # Use GCN encoder to get node embeddings that capture both attribute and
    # structural information. The encoder must live on the same device as the
    # data, otherwise the forward pass fails for GPU-resident graphs.
    encoder = GCNEncoder(data.num_features, out_channels=16).to(data.x.device)
    encoder.eval()
    with torch.no_grad():
        embeddings = encoder(data.x, data.edge_index).cpu().numpy()

    # Perform K-means clustering to find representative nodes
    kmeans = KMeans(n_clusters=num_clusters, random_state=42).fit(embeddings)
    labels = kmeans.labels_

    # Select the node closest to each cluster centre
    selected_nodes = []
    for cluster_id, center in enumerate(kmeans.cluster_centers_):
        members = np.where(labels == cluster_id)[0]
        if members.size == 0:
            continue  # nothing to pick from an empty cluster
        distances = np.linalg.norm(embeddings[members] - center, axis=1)
        selected_nodes.append(int(members[np.argmin(distances)]))

    selected_nodes = selected_nodes[:num_nodes_to_select]

    # The degree / betweenness rankings are expensive (betweenness in
    # particular), so they are only computed when the cluster representatives
    # do not fill the budget.
    if len(selected_nodes) < num_nodes_to_select:
        selected_nodes = _top_up_with_structural_nodes(data, selected_nodes, num_nodes_to_select)

    return torch.tensor(selected_nodes, dtype=torch.long).to(data.x.device)

def _neighbor_mean_features(x, edge_index, nodes):
    """Return, per node, the mean feature vector of the sources of its incoming edges.

    Nodes without incoming edges fall back to the mean over all rows of ``x``.
    """
    global_mean = x.mean(dim=0)
    rows = []
    for node in nodes:
        neighbours = edge_index[0][edge_index[1] == node]
        rows.append(x[neighbours].mean(dim=0) if len(neighbours) > 0 else global_mean)
    return torch.stack(rows)


def _attach_random_trigger_subgraph(data, data_poisoned, poisoned_nodes, trigger_size, trigger_density):
    """Wire an Erdos-Renyi trigger among the first ``trigger_size`` poisoned nodes.

    Also connects each trigger node to one uniformly drawn graph node and
    stores the extended edge list in ``data_poisoned.edge_index``.
    Returns the tensor of trigger nodes.
    """
    graph_device = data.edge_index.device
    trigger_nodes = poisoned_nodes[:trigger_size].to(graph_device)

    trigger_graph = nx.erdos_renyi_graph(trigger_size, trigger_density)
    edge_list = list(trigger_graph.edges)
    # reshape() keeps a well-formed (2, 0) tensor when the random graph has no edges
    local_edges = torch.tensor(edge_list, dtype=torch.long).reshape(len(edge_list), 2).t()
    # Map local indices to global indices using the trigger nodes
    trigger_edges = trigger_nodes[local_edges.to(graph_device)]

    # Randomly connect poisoned nodes to existing graph nodes
    random_targets = torch.randint(0, data.num_nodes, (trigger_size,), device=graph_device)
    anchor_edges = torch.stack([trigger_nodes, random_targets])

    data_poisoned.edge_index = torch.cat([data.edge_index, trigger_edges, anchor_edges], dim=1)
    return trigger_nodes


def inject_trigger(data, poisoned_nodes, attack_type, trigger_gen=None, ood_detector=None, alpha=0.7, trigger_size=5, trigger_density=0.5, input_dim=None):
    """Return a poisoned copy of ``data`` for the requested backdoor attack.

    Parameters:
    - data: Clean graph object providing ``x``, ``edge_index``, ``num_nodes``
      and ``clone()``. It is left unmodified.
    - poisoned_nodes: Long tensor with the indices of the nodes to poison.
    - attack_type: 'SBA-Samp', 'SBA-Gen', 'DPGBA', 'GTA' or 'UGBA'.
    - trigger_gen, alpha, input_dim: Accepted for interface compatibility;
      none of the attacks implemented here reads them.
    - ood_detector: Auto-encoder used by 'DPGBA'; it must return
      ``(reconstruction, latent)`` when called and expose ``decoder``.
    - trigger_size: Number of poisoned nodes wired into the trigger subgraph
      by the SBA attacks; also the minimum size of ``poisoned_nodes``.
    - trigger_density: Edge probability of the Erdos-Renyi trigger subgraph.

    Returns:
    - The poisoned clone of ``data``.

    Raises:
    - ValueError: If fewer than ``trigger_size`` poisoned nodes are given, if
      'DPGBA' is requested without an ``ood_detector``, or if
      ``attack_type`` is unknown (returning clean data for a misspelled
      attack name would silently invalidate an experiment).
    """
    data_poisoned = data.clone()

    if poisoned_nodes.numel() < trigger_size:
        raise ValueError(f"Insufficient poisoned nodes: required {trigger_size}, found {poisoned_nodes.numel()}")

    if attack_type in ('SBA-Samp', 'SBA-Gen'):
        trigger_nodes = _attach_random_trigger_subgraph(
            data, data_poisoned, poisoned_nodes, trigger_size, trigger_density
        )
        # New features: neighbourhood average plus small attack-specific noise
        avg_features = _neighbor_mean_features(data.x, data.edge_index, trigger_nodes)
        if attack_type == 'SBA-Samp':
            noise = torch.randn_like(avg_features) * 0.02
        else:
            noise = torch.normal(mean=0.0, std=0.03, size=avg_features.shape).to(data.x.device)
        data_poisoned.x[trigger_nodes] = avg_features + noise

    elif attack_type == 'DPGBA':
        if ood_detector is None:
            raise ValueError("OODDetector must be provided for DPGBA attack.")

        # Use OODDetector to refine trigger features. Everything runs without
        # gradient tracking so the poisoned features are plain values and do
        # not keep the detector's autograd graph alive.
        ood_detector.eval()
        with torch.no_grad():
            _, latent_embeddings = ood_detector(data.x, data.edge_index)
            poisoned_latent = latent_embeddings[poisoned_nodes]
            refined_latent = poisoned_latent + torch.randn_like(poisoned_latent) * 0.1
            # Map latent embeddings back to feature space using the decoder
            refined_trigger_features = ood_detector.decoder(refined_latent)

        # Interpolate with existing features for in-distribution preservation;
        # each node keeps between 50% and 80% of its original features.
        node_alphas = torch.rand(len(poisoned_nodes)).to(data.x.device) * 0.3 + 0.5
        weights = node_alphas.unsqueeze(1)
        data_poisoned.x[poisoned_nodes] = (
            weights * data.x[poisoned_nodes] + (1 - weights) * refined_trigger_features
        )

    elif attack_type == 'GTA':
        avg_features = _neighbor_mean_features(data.x, data.edge_index, poisoned_nodes)
        data_poisoned.x[poisoned_nodes] = avg_features + torch.randn_like(avg_features) * 0.05

    elif attack_type == 'UGBA':
        # NOTE: this attack poisons a freshly selected set of diverse nodes,
        # not the nodes passed in ``poisoned_nodes`` (only their count is used).
        diverse_nodes = select_diverse_nodes(data_poisoned, len(poisoned_nodes))
        connected_nodes = [
            data_poisoned.edge_index[0][data_poisoned.edge_index[1] == node]
            for node in diverse_nodes
        ]

        global_mean = data_poisoned.x.mean(dim=0)
        avg_features = torch.stack([
            data_poisoned.x[nodes].mean(dim=0) if len(nodes) > 0 else global_mean
            for nodes in connected_nodes
        ])
        shift = torch.normal(mean=2.0, std=0.5, size=avg_features.shape).to(data_poisoned.x.device)
        data_poisoned.x[diverse_nodes] = avg_features + shift

        # Link every selected node to its first neighbour, or to the next
        # selected node when it has none.
        num_selected = len(diverse_nodes)
        new_edges = []
        for i, neighbours in enumerate(connected_nodes):
            partner = neighbours[0] if len(neighbours) > 0 else diverse_nodes[(i + 1) % num_selected]
            new_edges.append([int(diverse_nodes[i]), int(partner)])

        new_edges = torch.tensor(new_edges, dtype=torch.long).reshape(len(new_edges), 2).t()
        new_edges = new_edges.contiguous().to(data_poisoned.edge_index.device)
        data_poisoned.edge_index = torch.cat([data_poisoned.edge_index, new_edges], dim=1)

    else:
        raise ValueError(f"Unsupported attack type: {attack_type}")

    return data_poisoned





import numpy as np
import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import pandas as pd
import gc
from sklearn.metrics import pairwise_distances
from sklearn.cluster import KMeans


def dominant_set_clustering(data, threshold=0.7, use_pca=True, pca_components=10):
    """
    Flags nodes that lie far from their K-Means cluster centre and neutralises them.

    Despite the function name, the clustering step is plain K-Means on the
    (optionally PCA-reduced) node features.

    Parameters:
    - data: PyG data object representing the graph. It is modified in place.
    - threshold: Quantile (0-1) of the centre distances; nodes whose distance
      is strictly above that quantile are treated as outliers.
    - use_pca: Whether to reduce the features with PCA before clustering.
    - pca_components: Number of PCA components; PCA is skipped when the
      features do not have more columns than this.

    Returns:
    - pruned_nodes: Set of node indices identified as outliers.
    - data: The same data object; outliers have label -1 and their features
      replaced by the mean feature vector.
    """
    # Step 1: optional dimensionality reduction
    feature_matrix = data.x.detach().cpu().numpy()
    if use_pca and feature_matrix.shape[1] > pca_components:
        feature_matrix = PCA(n_components=pca_components).fit_transform(feature_matrix)

    # Step 2: one cluster per distinct label value, never more than there are nodes
    class_count = len(torch.unique(data.y).tolist())
    n_clusters = min(class_count, feature_matrix.shape[0])

    # Step 3: cluster and measure every node's distance to its own centre
    fitted = KMeans(n_clusters=n_clusters, random_state=42).fit(feature_matrix)
    assigned_centers = fitted.cluster_centers_[fitted.labels_]
    distances = np.linalg.norm(feature_matrix - assigned_centers, axis=1)

    cutoff = np.percentile(distances, 100 * threshold)
    outlier_candidates = np.where(distances > cutoff)[0]

    # Step 4: neutralise the outliers
    pruned_nodes = set(outlier_candidates)
    if not pruned_nodes:
        return pruned_nodes, data

    outliers = torch.tensor(list(pruned_nodes), dtype=torch.long, device=data.x.device)
    # -1 marks the label as discarded
    data.y[outliers] = -1
    # The mean is taken over all nodes before the outlier rows are overwritten
    data.x[outliers] = data.x.mean(dim=0).to(data.x.device)

    return pruned_nodes, data




def defense_prune_edges(data, quantile_threshold=0.9):
    """
    Prunes edges whose endpoint features have low cosine similarity.

    The cut-off is the `quantile_threshold` quantile of the cosine
    similarities of all edges; only edges at or above it are kept.

    Parameters:
    - data: PyG data object representing the graph. It is modified in place.
    - quantile_threshold: Quantile (0-1) of the similarity distribution used
      as cut-off. Roughly this fraction of the edges, the least similar
      ones, is removed (e.g. 0.9 keeps only about the 10% most similar edges).

    Returns:
    - data: The same PyG data object with `edge_index` reduced to the kept edges.
    """
    edge_index = data.edge_index
    if edge_index.size(1) == 0:
        # No edges to prune; a quantile of an empty tensor is undefined.
        return data

    norm_features = F.normalize(data.x, p=2, dim=1)  # Unit-length feature rows

    # Cosine similarity of the two endpoints of every edge
    src, dst = edge_index[0], edge_index[1]
    cosine_similarities = torch.sum(norm_features[src] * norm_features[dst], dim=1)

    # Adaptive threshold based on quantile of similarity distribution
    similarity_threshold = torch.quantile(cosine_similarities, quantile_threshold).item()

    # Keep edges with cosine similarity at or above the threshold
    keep_mask = cosine_similarities >= similarity_threshold
    data.edge_index = edge_index[:, keep_mask]

    return data



def defense_prune_and_discard_labels(data, quantile_threshold=0.2):
    """
    Prunes low-similarity edges and discards the labels of nodes that lose many edges.

    Parameters:
    - data: PyG data object representing the graph. It is modified in place.
    - quantile_threshold: Quantile (0-1) of the edge cosine similarities;
      edges strictly below it are removed (e.g. 0.2 removes roughly the 20%
      least similar edges).

    Returns:
    - data: The same PyG data object with pruned `edge_index`. Nodes that
      touch more pruned edges than the median node have their label set to -1.
    """
    edge_index = data.edge_index
    if edge_index.size(1) == 0:
        # Nothing to prune and therefore no labels to discard; a quantile of
        # an empty tensor is undefined.
        return data

    norm_features = F.normalize(data.x, p=2, dim=1)  # Unit-length feature rows

    # Cosine similarity of the two endpoints of every edge
    src, dst = edge_index[0], edge_index[1]
    cosine_similarities = torch.sum(norm_features[src] * norm_features[dst], dim=1)

    # Use quantile to determine adaptive threshold for pruning
    adaptive_threshold = torch.quantile(cosine_similarities, quantile_threshold).item()

    # Edges below the threshold are removed, all others are retained
    pruned_mask = cosine_similarities < adaptive_threshold
    data.edge_index = edge_index[:, ~pruned_mask]

    # Count, per node, how many of its edges were pruned (either endpoint)
    pruned_src, pruned_dst = edge_index[:, pruned_mask]
    pruned_nodes_count = torch.bincount(torch.cat([pruned_src, pruned_dst]), minlength=data.num_nodes)

    # Only discard labels of nodes with more pruned edges than the median node
    threshold_count = int(torch.median(pruned_nodes_count).item())
    nodes_to_discard = torch.where(pruned_nodes_count > threshold_count)[0]

    data.y[nodes_to_discard] = -1  # Use -1 to represent discarded labels

    return data

# Compute ASR and Clean Accuracy (evaluated without gradient tracking)
def compute_metrics(model, data, poisoned_nodes):
    """Return ``(asr, clean_acc)`` in percent for ``model`` on ``data``.

    Parameters:
    - model: Node classifier called as ``model(x, edge_index)``.
    - data: Graph object providing ``x``, ``edge_index``, ``y`` and ``test_mask``.
    - poisoned_nodes: Tensor with the indices of the poisoned nodes.

    Returns:
    - asr: Percentage of poisoned nodes whose prediction equals ``data.y``.
    - clean_acc: Percentage of test-mask nodes whose prediction equals
      ``data.y``.
    Either value is 0.0 when the corresponding node set is empty.
    """
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        _, pred = out.max(dim=1)

        num_poisoned = len(poisoned_nodes)
        poisoned_hits = (pred[poisoned_nodes] == data.y[poisoned_nodes]).sum().item()
        asr = poisoned_hits / num_poisoned * 100 if num_poisoned else 0.0

        # Computed on tensors directly, like the baseline accuracy elsewhere
        # in this file; no round trip through CPU arrays is needed.
        test_mask = data.test_mask
        num_test = int(test_mask.sum().item())
        test_hits = (pred[test_mask] == data.y[test_mask]).sum().item()
        clean_acc = test_hits / num_test * 100 if num_test else 0.0
    return asr, clean_acc



# Visualization Function
# Visualize PCA projections of the node features for different attacks
def visualize_pca_for_attacks(attack_embeddings_dict):
    """Plot a 2-D PCA projection per entry, highlighting the poisoned nodes.

    Parameters:
    - attack_embeddings_dict: Mapping from a plot title suffix to a dict with
      the keys ``'data'`` (2-D tensor, one row per node) and
      ``'poisoned_nodes'`` (tensor of node indices).

    One subplot is drawn per entry in a three-column grid whose number of
    rows grows with the number of entries. Nothing is drawn for an empty
    mapping.
    """
    if not attack_embeddings_dict:
        return

    num_cols = 3
    # Ceiling division; at least two rows keeps the layout used for up to six plots
    num_rows = max(2, -(-len(attack_embeddings_dict) // num_cols))

    pca = PCA(n_components=2)
    plt.figure(figsize=(20, 5 * num_rows))

    for i, (attack, attack_data) in enumerate(attack_embeddings_dict.items(), 1):
        embeddings = attack_data['data'].detach().cpu().numpy()
        poisoned_nodes = attack_data['poisoned_nodes'].detach().cpu().numpy()

        # PCA is re-fitted for every entry
        pca_result = pca.fit_transform(embeddings)

        # Create masks for clean and poisoned nodes
        clean_mask = np.ones(embeddings.shape[0], dtype=bool)
        clean_mask[poisoned_nodes] = False

        clean_embeddings = pca_result[clean_mask]
        poisoned_embeddings = pca_result[~clean_mask]

        # Plotting clean and poisoned nodes
        plt.subplot(num_rows, num_cols, i)
        plt.scatter(clean_embeddings[:, 0], clean_embeddings[:, 1], s=10, alpha=0.5, label='Clean Nodes', c='b')
        plt.scatter(poisoned_embeddings[:, 0], poisoned_embeddings[:, 1], s=10, alpha=0.8, label='Poisoned Nodes', c='r')
        plt.title(f'PCA Visualization for {attack}')
        plt.xlabel('PCA Component 1')
        plt.ylabel('PCA Component 2')
        plt.legend()

    plt.tight_layout()
    plt.show()


# Run all attacks and apply defenses
def run_all_attacks():
    """Run every dataset / model / attack / defense combination and report the results.

    For each dataset and model type a baseline model is trained on clean
    data, then each attack poisons the graph and continues training, and the
    attack success rate (ASR) and clean accuracy are measured before and
    after each defense. All results are printed, written to
    ``backdoor_attack_results_summary.csv`` and the post-DSOD features of
    every run are visualised with PCA.

    NOTE(review): the same ``model`` and ``optimizer`` objects are passed to
    the training call of every attack, so each attack continues training the
    model left behind by the previous one. Confirm whether independent,
    freshly initialised models per attack were intended.
    """
    datasets = ["Cora", "PubMed", "CiteSeer"]
    # Dataset-specific poisoning budgets
    dataset_budgets = {
        'Cora': 10,
        'PubMed': 40,
        'CiteSeer': 30
    }
    model_types = ['GCN', 'GraphSage', 'GAT']
    attack_methods = ['SBA-Samp', 'SBA-Gen', 'GTA', 'UGBA', 'DPGBA']
    baseline_epochs = 200  # Number of epochs for baseline training

    # (reported name, function mapping a poisoned graph copy to the defended
    # graph), evaluated in this order for every attack.
    dsod_name = "Dominant Set Outlier Detection"
    defenses = [
        (dsod_name, lambda graph: dominant_set_clustering(graph, threshold=0.9, use_pca=True, pca_components=10)[1]),
        ("Prune", lambda graph: defense_prune_edges(graph, quantile_threshold=0.8)),
        ("Prune + LD", lambda graph: defense_prune_and_discard_labels(graph, quantile_threshold=0.8)),
    ]

    results_summary = []
    attack_embeddings_dict = {}

    def record(dataset_name, model_type, attack, defense, asr, clean_acc):
        """Store one attack/defense result row and print it."""
        results_summary.append({
            "Dataset": dataset_name,
            "Model": model_type,
            "Attack": attack,
            "Defense": defense,
            "ASR": asr,
            "Clean Accuracy": clean_acc
        })
        print(f"Dataset: {dataset_name}, Model: {model_type}, Attack: {attack}, Defense: {defense} - ASR: {asr:.2f}%, Clean Accuracy: {clean_acc:.2f}%")

    for dataset_name in datasets:
        dataset = load_dataset(dataset_name)
        data = dataset[0].to(device)
        input_dim = data.num_features
        output_dim = dataset.num_classes if isinstance(dataset.num_classes, int) else dataset.num_classes[0]
        data = split_dataset(data)
        poisoned_node_budget = dataset_budgets.get(dataset_name, 10)

        for model_type in model_types:
            # Initialize model and optimizer
            model = GNN(input_dim=input_dim, hidden_dim=64, output_dim=output_dim, model_type=model_type).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=0.002)

            # Baseline training without attack or defense
            model.train()
            for _ in range(baseline_epochs):
                optimizer.zero_grad()
                out = model(data.x, data.edge_index)
                loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
                loss.backward()
                optimizer.step()

            # Evaluate baseline accuracy
            model.eval()
            with torch.no_grad():
                predictions = model(data.x, data.edge_index).argmax(dim=1)
                baseline_acc = (predictions[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
            baseline_acc_percentage = baseline_acc * 100
            print(f"Dataset: {dataset_name}, Model: {model_type}, Baseline Accuracy: {baseline_acc_percentage:.2f}%")
            results_summary.append({
                "Dataset": dataset_name,
                "Model": model_type,
                "Attack": "None",
                "Defense": "None",
                "ASR": "N/A",
                "Clean Accuracy": baseline_acc_percentage
            })

            # Initialize Trigger Generator and OOD Detector for attacks
            trigger_gen = TriggerGenerator(input_dim=input_dim, hidden_dim=64).to(device)
            ood_detector = OODDetector(input_dim=input_dim, hidden_dim=64, latent_dim=16).to(device)

            # Train OOD detector
            ood_optimizer = torch.optim.Adam(ood_detector.parameters(), lr=0.001)
            train_ood_detector(ood_detector, data, ood_optimizer)

            # Select nodes to poison based on high-centrality (degree centrality) for a stronger impact
            poisoned_nodes = select_high_centrality_nodes(data, poisoned_node_budget)

            for attack in attack_methods:
                # Train model with poisoned data; only DPGBA receives the detector
                trained_model, data_poisoned = train_with_poisoned_data(
                    model=model,
                    data=data,
                    optimizer=optimizer,
                    poisoned_nodes=poisoned_nodes,
                    trigger_gen=trigger_gen,
                    attack=attack,
                    ood_detector=ood_detector if attack == 'DPGBA' else None,
                    alpha=0.7,
                    early_stopping=True
                )

                # Compute ASR and Clean Accuracy before applying any defense
                asr, clean_acc = compute_metrics(trained_model, data_poisoned, poisoned_nodes)
                record(dataset_name, model_type, attack, "None", asr, clean_acc)

                # Apply each defense to its own copy of the poisoned graph
                dsod_features = None
                for defense_name, apply_defense in defenses:
                    defended = apply_defense(data_poisoned.clone())
                    asr_def, clean_acc_def = compute_metrics(trained_model, defended, poisoned_nodes)
                    record(dataset_name, model_type, attack, defense_name, asr_def, clean_acc_def)
                    if defense_name == dsod_name:
                        dsod_features = defended.x

                # Store the post-DSOD features for visualization. They are
                # moved to the CPU so that the per-run copies do not
                # accumulate in accelerator memory until the final plot.
                attack_embeddings_dict[f"{dataset_name}-{model_type}-{attack}"] = {
                    'data': dsod_features.detach().cpu(),
                    'poisoned_nodes': poisoned_nodes.detach().cpu()
                }

                # Release memory after each attack-defense cycle. (Deleting
                # entries from locals() does not unbind local variables, so
                # only the collector and the CUDA cache are used here.)
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    # Summarize Results in a Table
    results_df = pd.DataFrame(results_summary)
    print("\nSummary of Attack Success Rate and Clean Accuracy Before and After Defenses:")
    print(results_df)

    results_df.to_csv("backdoor_attack_results_summary.csv", index=False)

    # Visualize PCA projections for different attacks
    visualize_pca_for_attacks(attack_embeddings_dict)

# Run the function
run_all_attacks()

# Download the results. google.colab only exists inside a Colab runtime, so
# the import is guarded to avoid crashing elsewhere after the run finished.
try:
    from google.colab import files
except ImportError:
    print("google.colab is not available; results were saved to backdoor_attack_results_summary.csv")
else:
    files.download("backdoor_attack_results_summary.csv")


