<a href="https://colab.research.google.com/github/AbhiJeet70/PowerfulGNNs/blob/main/Attack_ESAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install necessary packages
!pip install torch-geometric
!pip install matplotlib
!pip install scikit-learn

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import subgraph
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
import networkx as nx
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from torch_geometric.datasets import Planetoid, Flickr
from sklearn.decomposition import PCA
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import networkx as nx
from sklearn.neighbors import NearestNeighbors
import pandas as pd
from torch_geometric.utils import to_networkx
from torch_geometric.nn.models.autoencoder import GAE
import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import pandas as pd
import gc
from torch_geometric.utils import from_networkx


import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed the PyTorch, NumPy and Python random number generators.

    Args:
        seed: Integer seed handed to every generator (default 42).
    """
    # Seeding order is torch, then NumPy, then the stdlib `random` module.
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

    # Also seed the current CUDA device when one is available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

set_seed()

# Check if GPU is available and set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


# Load datasets
def load_dataset(dataset_name):
    """Build the dataset object for a supported benchmark name.

    Args:
        dataset_name: One of "Cora", "PubMed", "CiteSeer" (loaded through
            ``Planetoid``) or "Flickr" (loaded through ``Flickr``).

    Returns:
        The constructed dataset, stored under ``./data/<name>``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == "Flickr":
        return Flickr(root="./data/Flickr")
    if dataset_name in ("Cora", "PubMed", "CiteSeer"):
        return Planetoid(root=f"./data/{dataset_name}", name=dataset_name)
    raise ValueError(f"Unsupported dataset: {dataset_name}")



import numpy as np

def split_dataset(data, test_size=0.2, val_size=0.1, seed=None, target_size=0.1):
    """
    Splits the nodes into train, validation, and test sets.
    Additionally splits the test set into target and clean-test masks for
    attack evaluation.

    The node indices are shuffled once and carved up in this order:
    train | validation | test, where the test region itself is
    target (first ``target_size`` fraction of all nodes) followed by
    clean test (whatever remains). If the target count exceeds the test
    count, the target mask covers the whole test set and the clean-test
    mask is empty.

    Args:
        data: Object exposing ``num_nodes``; the masks are attached to it as
            attributes (e.g. a PyTorch Geometric Data object).
        test_size: Fraction of nodes to use for the test set.
        val_size: Fraction of nodes to use for the validation set.
        seed: Random seed for reproducibility. When given, NumPy's global
            generator is re-seeded with it; when None the current global
            state is used.
        target_size: Fraction of all nodes used for the target mask
            (previously fixed at 0.1, which remains the default).

    Returns:
        data: The same object, with ``train_mask``, ``val_mask``,
        ``test_mask``, ``target_mask`` and ``clean_test_mask`` set to
        boolean NumPy arrays of length ``num_nodes``.

    Raises:
        AssertionError: If any of the train/val/test/target sizes is zero.
    """
    num_nodes = int(data.num_nodes)

    # Shuffle node indices (optionally re-seeding the global NumPy generator)
    indices = np.arange(num_nodes)
    if seed is not None:
        np.random.seed(seed)
    np.random.shuffle(indices)

    # Split sizes; the training set takes whatever val/test leave over
    num_test = int(test_size * num_nodes)
    num_val = int(val_size * num_nodes)
    num_train = num_nodes - num_test - num_val
    num_target = int(target_size * num_nodes)

    # Ensure there are enough nodes for the splits
    assert num_train > 0, "Training set size is too small!"
    assert num_val > 0, "Validation set size is too small!"
    assert num_test > 0, "Test set size is too small!"
    assert num_target > 0, "Target mask size is too small!"

    def _mask(selected):
        # Boolean mask over all nodes that is True at the selected indices.
        mask = np.zeros(num_nodes, dtype=bool)
        mask[selected] = True
        return mask

    test_start = num_train + num_val
    target_end = test_start + num_target

    # Masks are stored as NumPy arrays; convert to tensors downstream if needed
    data.train_mask = _mask(indices[:num_train])
    data.val_mask = _mask(indices[num_train:test_start])
    data.test_mask = _mask(indices[test_start:])
    data.target_mask = _mask(indices[test_start:target_end])
    data.clean_test_mask = _mask(indices[target_end:])

    return data


def select_high_centrality_nodes(data, num_nodes_to_select):
    """Pick the nodes with the highest degree centrality.

    Args:
        data: Graph object whose ``edge_index`` tensor lists the edges.
        num_nodes_to_select: Maximum number of node ids to return.

    Returns:
        A long tensor on the module-level ``device`` holding the selected
        node ids, most central first. Only nodes that appear in at least
        one edge can be selected, because the graph is built from the
        edge list alone.
    """
    # Build an undirected NetworkX graph from the transposed edge list.
    edge_pairs = data.edge_index.cpu().numpy().T
    nx_graph = nx.Graph()
    nx_graph.add_edges_from(edge_pairs)

    # Rank by centrality, highest first; the sort is stable so ties keep
    # the order in which nodes were inserted into the graph.
    scores = nx.degree_centrality(nx_graph)
    ranked = [node for node, _ in sorted(scores.items(), key=lambda item: item[1], reverse=True)]

    top_nodes = ranked[:num_nodes_to_select]
    return torch.tensor(top_nodes, dtype=torch.long).to(device)


class TriggerGenerator(nn.Module):
    """Two-layer MLP that maps a feature vector to a same-sized trigger vector.

    Args:
        input_dim: Size of the input (and output) feature vectors.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        squeeze = nn.Linear(input_dim, hidden_dim)  # input_dim -> hidden_dim
        expand = nn.Linear(hidden_dim, input_dim)   # hidden_dim -> input_dim
        self.mlp = nn.Sequential(squeeze, nn.ReLU(), expand)

    def forward(self, x):
        """Run ``x`` through the MLP and return the result."""
        trigger = self.mlp(x)
        return trigger

class GAE(nn.Module):
    """Minimal auto-encoder wrapper that pairs an encoder with a decoder.

    Both components are stored as given and are called with
    ``(tensor, edge_index)``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def encode(self, x, edge_index):
        """Return the encoder's output for ``x`` on ``edge_index``."""
        latent = self.encoder(x, edge_index)
        return latent

    def decode(self, z, edge_index):
        """Return the decoder's output for latent ``z`` on ``edge_index``."""
        rebuilt = self.decoder(z, edge_index)
        return rebuilt

class OODDetector(nn.Module):
    """Reconstruction-based out-of-distribution detector built on a graph auto-encoder.

    The encoder and decoder are registered both directly on this module and
    inside ``self.gae``; they are the same objects in both places.

    Args:
        input_dim: Node feature size.
        hidden_dim: Hidden width of the encoder and decoder.
        latent_dim: Size of the latent embedding.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(input_dim, hidden_dim, latent_dim)
        self.gae = GAE(self.encoder, self.decoder)

    def forward(self, x, edge_index):
        """Return the latent embedding of ``x``."""
        return self.gae.encode(x, edge_index)

    def reconstruct(self, z, edge_index):
        """Decode latent embedding ``z`` back to feature space."""
        return self.decoder(z, edge_index)

    def reconstruction_loss(self, x, edge_index):
        """Return the mean squared error between ``x`` and its reconstruction."""
        latent = self.gae.encode(x, edge_index)
        rebuilt = self.reconstruct(latent, edge_index)
        return F.mse_loss(rebuilt, x)

    def detect_ood(self, x, edge_index, threshold):
        """Return whether the reconstruction loss over the input exceeds ``threshold``."""
        return self.reconstruction_loss(x, edge_index) > threshold

class Encoder(nn.Module):
    """Two-layer GCN mapping node features to a latent embedding.

    Args:
        input_dim: Node feature size.
        hidden_dim: Width of the intermediate layer.
        latent_dim: Size of the produced embedding.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, latent_dim)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> conv2 and return the latent embedding."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

class Decoder(nn.Module):
    """Two-layer GCN mapping a latent embedding back to feature space.

    Args:
        input_dim: Size of the reconstructed node features.
        hidden_dim: Width of the intermediate layer.
        latent_dim: Size of the incoming latent embedding.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.conv1 = GCNConv(latent_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, input_dim)

    def forward(self, z, edge_index):
        """Apply conv1 -> ReLU -> conv2 and return the reconstructed features."""
        hidden = F.relu(self.conv1(z, edge_index))
        return self.conv2(hidden, edge_index)



def train_ood_detector(ood_detector, data, optimizer, epochs=50):
    """Train the detector to reconstruct the features of the training nodes.

    Args:
        ood_detector: Module that is callable as ``(x, edge_index)`` and
            exposes ``reconstruct(z, edge_index)``.
        data: Graph object providing ``x``, ``edge_index`` and ``train_mask``.
        optimizer: Optimizer over the detector's parameters.
        epochs: Number of full-graph optimisation steps.

    The detector is left in training mode. The loss is printed every tenth
    epoch, starting with epoch 0.
    """
    ood_detector.train()

    for epoch in range(epochs):
        optimizer.zero_grad()

        # Encode, then decode back to feature space.
        latent = ood_detector(data.x, data.edge_index)
        rebuilt = ood_detector.reconstruct(latent, data.edge_index)

        # Only the training nodes contribute to the reconstruction loss.
        train_rows = data.train_mask
        loss = F.mse_loss(rebuilt[train_rows], data.x[train_rows])

        loss.backward()
        optimizer.step()

        if epoch % 10:
            continue
        print(f"Epoch {epoch}, Reconstruction Loss: {loss.item():.4f}")



def train_with_poisoned_data(model, data, optimizer, poisoned_nodes, trigger_gen, attack, ood_detector=None, alpha=0.7, early_stopping=False, epochs=100):
    """Inject a trigger into ``data`` and train ``model`` on the poisoned graph.

    The model is called as ``model(x, edge_index)``, so it must follow that
    calling convention.

    Args:
        model: Node classifier callable as ``model(x, edge_index)``.
        data: Clean graph; it is not modified (``inject_trigger`` works on a copy).
        optimizer: Optimizer over ``model``'s parameters.
        poisoned_nodes: Node ids that receive the trigger.
        trigger_gen: Trigger generator forwarded to ``inject_trigger``.
        attack: Attack name forwarded to ``inject_trigger``.
        ood_detector: Forwarded to ``inject_trigger``.
        alpha: Forwarded to ``inject_trigger``.
        early_stopping: Despite its name, this flag only enables printing
            the loss every tenth epoch; training never stops early.
        epochs: Number of training epochs (previously fixed at 100, which
            remains the default).

    Returns:
        Tuple ``(model, data_poisoned)``; the model is left in training mode.

    Raises:
        TypeError: If ``inject_trigger`` does not return a PyG ``Data`` object.
    """
    # Apply trigger injection
    data_poisoned = inject_trigger(data, poisoned_nodes, attack, trigger_gen, ood_detector, alpha)

    # Ensure data_poisoned is a valid PyG Data object
    if not isinstance(data_poisoned, Data):
        raise TypeError("data_poisoned must be a PyTorch Geometric Data object.")

    train_mask = data_poisoned.train_mask

    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()

        out = model(data_poisoned.x, data_poisoned.edge_index)

        # The defense functions in this file mark discarded labels with -1;
        # ignore them here so that a defended graph can still be trained on.
        loss = F.cross_entropy(out[train_mask], data_poisoned.y[train_mask], ignore_index=-1)

        loss.backward()
        optimizer.step()

        # Optional logging
        if early_stopping and epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    return model, data_poisoned



from sklearn.cluster import KMeans

class GCNEncoder(nn.Module):
    """Two-layer GCN that embeds nodes into ``out_channels`` dimensions.

    Args:
        in_channels: Node feature size.
        out_channels: Embedding size; the hidden layer is twice as wide.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 2 * out_channels
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> conv2 and return the node embeddings."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

def select_diverse_nodes(data, num_nodes_to_select, num_clusters=None):
    """Select a mix of cluster-representative and high-degree nodes.

    Nodes are embedded with a freshly initialised (untrained) ``GCNEncoder``,
    the embeddings are clustered with K-Means, and the node closest to each
    cluster centre is taken. These are combined with the highest out-degree
    nodes (half as many as there are clusters), de-duplicated, and truncated.

    Args:
        data: Graph object with ``x``, ``y``, ``edge_index``, ``num_features``
            and ``num_nodes``.
        num_nodes_to_select: Maximum number of node ids returned.
        num_clusters: Number of K-Means clusters; defaults to the number of
            distinct labels in ``data.y``.

    Returns:
        Tensor of unique node ids (ascending, as produced by ``torch.unique``)
        on ``data.x``'s device, at most ``num_nodes_to_select`` long.
    """
    target_device = data.x.device
    if num_clusters is None:
        num_clusters = len(torch.unique(data.y))

    # Embed every node with an untrained encoder in eval mode.
    encoder = GCNEncoder(data.num_features, out_channels=16).to(target_device)
    encoder.eval()
    with torch.no_grad():
        embeddings = encoder(data.x, data.edge_index).to(target_device)

    # K-Means runs on CPU; its outputs are moved back to the target device.
    kmeans = KMeans(n_clusters=num_clusters, random_state=42).fit(embeddings.cpu().numpy())
    assignments = torch.tensor(kmeans.labels_, device=target_device)
    centers = torch.tensor(kmeans.cluster_centers_, device=target_device)

    # One representative per cluster: the member nearest to the centre.
    representatives = []
    for cluster_id in range(num_clusters):
        members = torch.where(assignments == cluster_id)[0]
        gaps = torch.norm(embeddings[members] - centers[cluster_id], dim=1)
        representatives.append(members[torch.argmin(gaps)])

    # Highest-degree nodes, counted from the source row of edge_index.
    degree = torch.bincount(data.edge_index[0], minlength=data.num_nodes).to(target_device)
    hubs = torch.topk(degree, len(representatives) // 2).indices

    combined = torch.cat([torch.tensor(representatives, device=target_device), hubs])
    chosen = torch.unique(combined)[:num_nodes_to_select]
    return chosen.to(target_device)


def inject_trigger(data, poisoned_nodes, attack_type, trigger_gen=None, ood_detector=None, alpha=0.7, trigger_size=5, trigger_density=0.5, input_dim=None):
    """Return a copy of ``data`` with a backdoor trigger injected.

    ``data`` itself is never modified. Non-tensor attributes (for example
    NumPy masks) are converted to tensors in the copy.

    Supported attacks:
        * ``'SBA-Samp'``: the first ``trigger_size`` poisoned nodes get
          neighbour-mean features plus small Gaussian noise, are wired
          together by an Erdos-Renyi graph with edge probability
          ``trigger_density``, and each gets one edge to a random node.
        * ``'SBA-Gen'``: as above, but trigger edges are sampled with a
          probability derived from the feature similarity of the two nodes.
        * ``'GTA'``: the first ``trigger_size`` poisoned nodes get
          neighbour-mean features plus Gaussian noise; no edges are added.
        * ``'DPGBA'``: every poisoned node's features are blended with the
          output of ``trigger_gen`` using a random per-node weight in
          [0.5, 0.8).

    Any other ``attack_type`` returns the unmodified copy.

    Args:
        data: Graph with ``x``, ``edge_index``, ``y``, ``train_mask``,
            ``val_mask``, ``test_mask``, ``target_mask``, ``clean_test_mask``
            and ``num_nodes``.
        poisoned_nodes: Non-empty sequence or tensor of node ids.
        attack_type: Attack name, see above.
        trigger_gen: Module mapping features to trigger features; required
            for ``'DPGBA'``.
        ood_detector: Accepted for interface compatibility; unused here.
        alpha: Accepted for interface compatibility; unused here.
        trigger_size: Maximum number of poisoned nodes forming the trigger
            for the SBA and GTA attacks.
        trigger_density: Edge probability for the ``'SBA-Samp'`` trigger graph.
        input_dim: Accepted for interface compatibility; unused here.

    Returns:
        The poisoned PyG ``Data`` copy, on ``data.x``'s device.

    Raises:
        ValueError: If ``poisoned_nodes`` is empty, if ``'DPGBA'`` is
            requested without ``trigger_gen``, or if the generator output
            does not match the feature dimension.
    """
    def _copy_as_tensor(value, dtype):
        # Clone tensors unchanged; convert anything else to a tensor of `dtype`.
        if isinstance(value, torch.Tensor):
            return value.clone()
        return torch.tensor(value, dtype=dtype)

    data_poisoned = Data(
        x=_copy_as_tensor(data.x, torch.float),
        edge_index=_copy_as_tensor(data.edge_index, torch.long),
        y=_copy_as_tensor(data.y, torch.long),
        train_mask=_copy_as_tensor(data.train_mask, torch.bool),
        val_mask=_copy_as_tensor(data.val_mask, torch.bool),
        test_mask=_copy_as_tensor(data.test_mask, torch.bool),
        target_mask=_copy_as_tensor(data.target_mask, torch.bool),
        clean_test_mask=_copy_as_tensor(data.clean_test_mask, torch.bool),
    ).to(data.x.device if isinstance(data.x, torch.Tensor) else 'cpu')

    device = data_poisoned.x.device

    if len(poisoned_nodes) == 0:
        raise ValueError("No poisoned nodes selected. Ensure 'poisoned_nodes' is populated and non-empty.")

    poisoned_nodes = torch.as_tensor(poisoned_nodes, dtype=torch.long, device=device)
    trigger_size = min(trigger_size, len(poisoned_nodes))
    trigger_nodes = poisoned_nodes[:trigger_size]

    # Every read of the clean graph below happens before the first in-place
    # write to data_poisoned.x, so these aliases always see clean values.
    clean_x = data_poisoned.x
    clean_edge_index = data_poisoned.edge_index

    def _neighbor_mean(nodes):
        # Mean feature of each node's in-neighbours; global mean if it has none.
        fallback = clean_x.mean(dim=0)
        rows = []
        for node in nodes:
            neighbors = clean_edge_index[0][clean_edge_index[1] == node]
            rows.append(clean_x[neighbors].mean(dim=0) if len(neighbors) > 0 else fallback)
        return torch.stack(rows)

    def _to_global_edges(local_edges):
        # Map edges expressed in local 0..k-1 trigger indices to real node ids.
        if len(local_edges) == 0:
            return torch.empty((2, 0), dtype=torch.long, device=device)
        local = torch.tensor(local_edges, dtype=torch.long, device=device).t()
        return trigger_nodes[local]

    def _random_attachment_edges():
        # One edge from each trigger node to a uniformly random node.
        targets = torch.randint(0, data.num_nodes, (trigger_size,), dtype=torch.long, device=device)
        return torch.stack((trigger_nodes, targets), dim=0)

    if attack_type in ['SBA-Samp', 'SBA-Gen', 'GTA']:
        avg_features = _neighbor_mean(trigger_nodes)

        if attack_type == 'SBA-Samp':
            natural_features = avg_features + torch.randn_like(avg_features) * 0.02
            local_edges = list(nx.erdos_renyi_graph(trigger_size, trigger_density).edges)
            trigger_edge_index = _to_global_edges(local_edges)
            poisoned_edges = _random_attachment_edges()

            data_poisoned.edge_index = torch.cat([clean_edge_index, trigger_edge_index, poisoned_edges], dim=1)
            data_poisoned.x[trigger_nodes] = natural_features

        elif attack_type == 'SBA-Gen':
            natural_features = avg_features + torch.normal(mean=0.0, std=0.03, size=avg_features.shape, device=device)
            feature_std = clean_x.std(dim=0)  # loop-invariant, computed once
            local_edges = []
            for i in range(trigger_size):
                for j in range(i + 1, trigger_size):
                    similarity = torch.exp(-torch.norm((natural_features[i] - natural_features[j]) / feature_std)**2)
                    if similarity > torch.rand(1).item():
                        local_edges.append([i, j])

            trigger_edge_index = _to_global_edges(local_edges)
            poisoned_edges = _random_attachment_edges()

            data_poisoned.edge_index = torch.cat([clean_edge_index, trigger_edge_index, poisoned_edges], dim=1)
            data_poisoned.x[trigger_nodes] = natural_features

        elif attack_type == 'GTA':
            # avg_features has one row per trigger node, so only those nodes
            # can be overwritten (same node subset as the SBA attacks).
            trigger_features = avg_features + torch.randn_like(avg_features) * 0.05
            data_poisoned.x[trigger_nodes] = trigger_features

    elif attack_type == 'DPGBA':
        avg_features = _neighbor_mean(poisoned_nodes)

        if trigger_gen is None:
            raise ValueError("Trigger generator is required for the DPGBA attack.")

        with torch.no_grad():
            trigger_features = trigger_gen(avg_features)

        if trigger_features.size(1) != clean_x.size(1):
            raise ValueError(f"Trigger feature dimension mismatch: {trigger_features.size(1)} vs {clean_x.size(1)}")

        # Per-node blend weight in [0.5, 0.8): the original features dominate.
        node_alphas = torch.rand(len(poisoned_nodes), device=device) * 0.3 + 0.5
        distribution_preserved_features = (
            node_alphas.unsqueeze(1) * clean_x[poisoned_nodes] +
            (1 - node_alphas.unsqueeze(1)) * trigger_features
        )
        data_poisoned.x[poisoned_nodes] = distribution_preserved_features

    return data_poisoned



def dominant_set_clustering(data, threshold=0.7, use_pca=True, pca_components=10):
    """
    Flags nodes far from their K-Means cluster centre as outliers and neutralises them.

    Despite the name, the implementation is K-Means (one cluster per class)
    followed by a distance-percentile cut-off. ``data`` is modified in place:
    outliers get label -1 and their features are replaced by the mean feature
    vector.

    Parameters:
    - data: PyG data object representing the graph.
    - threshold: Quantile (0-1) of the distance distribution above which a node counts as an outlier.
    - use_pca: Whether to reduce the features with PCA before clustering.
    - pca_components: Number of PCA components; PCA is skipped when the features are not wider than this.

    Returns:
    - pruned_nodes: Set of node indices identified as outliers.
    - data: The same PyG data object, updated in place.
    """
    # One cluster per distinct class label.
    n_clusters = len(data.y.unique())

    # Optional PCA on the CPU copy of the features.
    node_features = data.x.detach().cpu().numpy()
    if use_pca and node_features.shape[1] > pca_components:
        node_features = PCA(n_components=pca_components).fit_transform(node_features)

    # Distance of every node to the centre of the cluster it was assigned to.
    kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(node_features)
    assigned_centers = kmeans.cluster_centers_[kmeans.labels_]
    distances = np.linalg.norm(node_features - assigned_centers, axis=1)

    # Nodes beyond the chosen percentile are outliers.
    cutoff = np.percentile(distances, 100 * threshold)
    pruned_nodes = set(np.where(distances > cutoff)[0])
    if not pruned_nodes:
        return pruned_nodes, data

    outliers = torch.tensor(list(pruned_nodes), dtype=torch.long, device=data.x.device)

    # -1 marks the label as discarded.
    data.y[outliers] = -1

    # Overwrite outlier features with the mean over all nodes.
    data.x[outliers] = data.x.mean(dim=0).to(data.x.device)

    return pruned_nodes, data

def defense_prune_edges(data, quantile_threshold=0.9):
    """
    Prunes edges whose endpoint features have low cosine similarity.

    The cut-off is the ``quantile_threshold`` quantile of the per-edge cosine
    similarities; only edges at or above it are kept. ``data`` is modified in
    place.

    Parameters:
    - data: PyG data object representing the graph.
    - quantile_threshold: Quantile of the similarity distribution used as the cut-off
      (e.g., 0.9 keeps roughly the 10% most similar edges and prunes the rest).

    Returns:
    - data: The same PyG data object with pruned edges. A graph without edges is returned unchanged.
    """
    edge_index = data.edge_index

    # torch.quantile rejects empty input, and there is nothing to prune anyway.
    if edge_index.numel() == 0:
        return data

    norm_features = F.normalize(data.x, p=2, dim=1)

    # Cosine similarity of the two endpoints of every edge
    src, dst = edge_index[0], edge_index[1]
    cosine_similarities = (norm_features[src] * norm_features[dst]).sum(dim=1)

    # Adaptive threshold taken from the similarity distribution itself
    similarity_threshold = torch.quantile(cosine_similarities, quantile_threshold).item()

    # Keep edges with cosine similarity at or above the threshold
    data.edge_index = edge_index[:, cosine_similarities >= similarity_threshold]

    return data



def defense_prune_and_discard_labels(data, quantile_threshold=0.2):
    """
    Prunes low-similarity edges and discards the labels of nodes that lost many edges.

    Edges whose cosine similarity falls below the ``quantile_threshold``
    quantile are removed. Each node's number of removed incident edges is
    counted, and nodes whose count exceeds the median count get label -1.
    ``data`` is modified in place.

    Parameters:
    - data: PyG data object representing the graph (needs ``x``, ``edge_index``, ``y``, ``num_nodes``).
    - quantile_threshold: Quantile threshold for cosine similarity pruning (e.g., 0.2 means pruning edges in the bottom 20%).

    Returns:
    - data: The same PyG data object with pruned edges and selectively discarded labels.
      A graph without edges is returned unchanged.
    """
    edge_index = data.edge_index

    # torch.quantile rejects empty input; with no edges nothing can be pruned
    # and no labels need discarding.
    if edge_index.numel() == 0:
        return data

    norm_features = F.normalize(data.x, p=2, dim=1)

    # Cosine similarity of the two endpoints of every edge
    src, dst = edge_index[0], edge_index[1]
    cosine_similarities = (norm_features[src] * norm_features[dst]).sum(dim=1)

    # Adaptive threshold taken from the similarity distribution itself
    adaptive_threshold = torch.quantile(cosine_similarities, quantile_threshold).item()

    # Edges strictly below the threshold are pruned, the rest are retained
    pruned_mask = cosine_similarities < adaptive_threshold
    data.edge_index = edge_index[:, ~pruned_mask]

    # Count, per node, how many pruned edges it was an endpoint of
    pruned_endpoints = edge_index[:, pruned_mask].reshape(-1)
    pruned_nodes_count = torch.bincount(pruned_endpoints, minlength=data.num_nodes)

    # Discard labels only for nodes above the median pruned-edge count
    threshold_count = int(torch.median(pruned_nodes_count).item())
    nodes_to_discard = torch.where(pruned_nodes_count > threshold_count)[0]

    data.y[nodes_to_discard] = -1  # -1 represents a discarded label

    return data

# Compute ASR and Clean Accuracy (using .detach() to avoid retaining computation graph)
def compute_metrics(model, data, poisoned_nodes):
    """Evaluate ``model`` on ``data`` and return ``(asr, clean_acc)`` in percent.

    Args:
        model: Node classifier callable as ``model(x, edge_index)``; it is
            switched to eval mode and left there.
        data: Graph object with ``x``, ``edge_index``, ``y`` and ``test_mask``.
        poisoned_nodes: Non-empty index of the poisoned nodes.

    Returns:
        asr: Percentage of poisoned nodes whose prediction equals their
            label in ``data.y``.
        clean_acc: Accuracy (percent) over the nodes in ``data.test_mask``.
    """
    model.eval()
    with torch.no_grad():
        logits = model(data.x, data.edge_index).detach()
        pred = logits.max(dim=1).indices

        # Share of poisoned nodes predicted as the label stored in data.y.
        hits = (pred[poisoned_nodes] == data.y[poisoned_nodes]).sum().item()
        asr = hits / len(poisoned_nodes) * 100

        test_rows = data.test_mask
        clean_acc = accuracy_score(data.y[test_rows].cpu(), pred[test_rows].cpu()) * 100
    return asr, clean_acc



# Visualization: PCA projection of node embeddings for each attack,
# with clean and poisoned nodes plotted separately.
def visualize_pca_for_attacks(attack_embeddings_dict):
    """Plot a 2-D PCA projection of the embeddings for every attack.

    Args:
        attack_embeddings_dict: Mapping from attack name to a dict with
            ``'data'`` (embedding tensor, one row per node) and
            ``'poisoned_nodes'`` (tensor indexing the poisoned rows).

    Each attack gets its own panel in a 2 x 3 subplot grid, so at most six
    attacks fit. PCA is refitted for every attack.
    """
    pca = PCA(n_components=2)
    plt.figure(figsize=(20, 10))

    for panel, attack in enumerate(attack_embeddings_dict, start=1):
        entry = attack_embeddings_dict[attack]
        embeddings = entry['data'].detach().cpu().numpy()
        poisoned_nodes = entry['poisoned_nodes'].detach().cpu().numpy()

        # Project this attack's embeddings onto two principal components.
        projected = pca.fit_transform(embeddings)

        # Boolean row selector for the poisoned nodes; clean is its complement.
        is_poisoned = np.zeros(embeddings.shape[0], dtype=bool)
        is_poisoned[poisoned_nodes] = True
        clean_points = projected[~is_poisoned]
        poisoned_points = projected[is_poisoned]

        plt.subplot(2, 3, panel)
        plt.scatter(clean_points[:, 0], clean_points[:, 1], s=10, alpha=0.5, label='Clean Nodes', c='b')
        plt.scatter(poisoned_points[:, 0], poisoned_points[:, 1], s=10, alpha=0.8, label='Poisoned Nodes', c='r')
        plt.title(f'PCA Visualization for {attack}')
        plt.xlabel('PCA Component 1')
        plt.ylabel('PCA Component 2')
        plt.legend()

    plt.tight_layout()
    plt.show()

!pip install torch torch-geometric

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from itertools import combinations
import networkx as nx
from torch_geometric.utils import to_networkx, from_networkx

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_dataset(name):
    """Load a benchmark dataset and guarantee that its graphs carry node features.

    This definition replaces the earlier ``load_dataset`` in this file.

    Args:
        name: "Flickr" loads the Flickr dataset; any other name is passed to
            ``Planetoid`` (e.g. "Cora", "CiteSeer", "PubMed").

    Returns:
        The dataset object. Graphs obtained from it (``dataset[i]``) get
        one-hot identity features when they have no ``x``.
    """
    def _ensure_features(data):
        # One-hot node identity features for graphs that ship without any.
        if data.x is None:
            data.x = torch.eye(data.num_nodes)
        return data

    if name == "Flickr":
        dataset = Flickr(root="./data/Flickr")
    else:
        dataset = Planetoid(root=f"./data/{name}", name=name)

    # Indexing the dataset returns a copy of the stored graph, so patching
    # `dataset[0]` here would be lost. Installing the fallback as the
    # dataset's transform applies it to every graph a caller retrieves.
    dataset.transform = _ensure_features

    return dataset

# Define the ESAN Model
class ESAN(torch.nn.Module):
    """Subgraph-based node classifier with a shared GCN encoder.

    Every subgraph is pushed through the same two GCN layers and a shared
    linear head; the per-node outputs are then averaged over all subgraphs
    in which the node appears.

    Args:
        input_dim: Node feature size of the subgraphs.
        hidden_dim: Width of both GCN layers.
        output_dim: Number of output scores per node.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.shared_aggregator = torch.nn.Linear(hidden_dim, output_dim)

    def _score_subgraph(self, x, edge_index):
        """Return per-node output scores for one subgraph."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        hidden = self.conv2(hidden, edge_index)
        return self.shared_aggregator(hidden)

    def forward(self, subgraphs, num_nodes, batch_size=10):
        """Return log-probabilities of shape ``(num_nodes, output_dim)``.

        Args:
            subgraphs: Sequence of subgraph objects with ``x``, ``edge_index``
                and ``n_id`` (ids of the subgraph's nodes in the full graph).
            num_nodes: Number of nodes in the full graph.
            batch_size: Number of subgraphs taken per slice of the sequence;
                subgraphs are still processed one at a time.

        Nodes that appear in no subgraph keep an all-zero score row.
        """
        device = next(self.parameters()).device
        out_dim = self.shared_aggregator.out_features

        # Running sum of scores and number of contributions per node.
        node_predictions = torch.zeros((num_nodes, out_dim), device=device)
        node_counts = torch.zeros(num_nodes, device=device)

        for start in range(0, len(subgraphs), batch_size):
            for sub in subgraphs[start:start + batch_size]:
                scores = self._score_subgraph(sub.x.to(device), sub.edge_index.to(device))
                node_predictions[sub.n_id] += scores
                node_counts[sub.n_id] += 1

        # Average over subgraphs; the clamp avoids dividing by zero for
        # nodes that were never covered.
        averaged = node_predictions / node_counts.unsqueeze(1).clamp(min=1)
        return F.log_softmax(averaged, dim=1)



# Train the ESAN model
def train_model(model, subgraphs, data, optimizer, epochs=50):
    """Train ``model`` on the training nodes and return it.

    Args:
        model: Module callable as ``model(subgraphs, num_nodes)`` that
            returns per-node log-probabilities.
        subgraphs: Subgraph collection passed through to the model.
        data: Graph object with ``num_nodes``, ``y`` and ``train_mask``.
        optimizer: Optimizer over the model's parameters.
        epochs: Number of optimisation steps.

    Returns:
        The same model, left in training mode. Loss and training accuracy
        are printed every tenth epoch and after the final one.
    """
    model.train()

    for step in range(1, epochs + 1):
        optimizer.zero_grad()

        log_probs = model(subgraphs, data.num_nodes)
        train_rows = data.train_mask
        train_out = log_probs[train_rows]
        targets = data.y[train_rows]

        loss = F.nll_loss(train_out, targets)
        loss.backward()
        optimizer.step()

        # Progress report, computed from this step's forward pass.
        if step % 10 == 0 or step == epochs:
            correct = (train_out.argmax(dim=1) == targets).sum().item()
            train_acc = correct / train_rows.sum().item()
            print(f"Epoch {step}/{epochs}, Loss: {loss.item():.4f}, Train Accuracy: {train_acc:.4f}")

    return model

# Test the ESAN model
def test_model(model, subgraphs, data):
    """Evaluate ``model`` on the train, validation and test masks.

    Args:
        model: Module callable as ``model(subgraphs, num_nodes)``; it is
            switched to eval mode and left there.
        subgraphs: Subgraph collection passed through to the model.
        data: Graph object with ``num_nodes``, ``y``, ``train_mask``,
            ``val_mask`` and ``test_mask``.

    Returns:
        List ``[train_acc, val_acc, test_acc]`` of accuracies in [0, 1];
        each one is also printed.
    """
    model.eval()

    # Evaluation needs no gradients; skipping autograd avoids building a
    # graph over every subgraph forward pass.
    with torch.no_grad():
        logits = model(subgraphs, data.num_nodes)

    accs = []
    splits = (("Train", data.train_mask), ("Validation", data.val_mask), ("Test", data.test_mask))
    for mask_name, mask in splits:
        pred = logits[mask].argmax(dim=1)
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
        print(f"{mask_name} Accuracy: {acc:.4f}")
    return accs

def _nx_to_pyg_subgraph(nx_subgraph, data):
    """Convert a NetworkX subgraph to PyG form with original node ids and features attached."""
    pyg_subgraph = from_networkx(nx_subgraph)
    # n_id maps each subgraph node back to its id in the original graph.
    pyg_subgraph.n_id = torch.tensor(list(nx_subgraph.nodes), dtype=torch.long)
    if pyg_subgraph.x is None:
        pyg_subgraph.x = data.x[pyg_subgraph.n_id]
    return pyg_subgraph


def generate_subgraphs(data, policy="edge_deleted", max_subgraphs=300):
    """Generate up to ``max_subgraphs`` subgraphs of ``data`` under a selection policy.

    Policies:
        * ``"edge_deleted"``: one subgraph per edge, with that edge removed.
        * ``"node_deleted"``: one subgraph per node, with that node removed.
        * ``"ego"``: the radius-2 ego graph of each node. An extra feature
          column marks the centre node, and the rows are then L2-normalised,
          so features are one dimension wider than ``data.x``.

    Args:
        data: PyG graph with ``x`` and ``edge_index``; treated as undirected.
        policy: One of the policy names above.
        max_subgraphs: Upper bound on the number of subgraphs produced;
            edges/nodes are taken in NetworkX iteration order.

    Returns:
        List of PyG subgraphs, each with ``x``, ``edge_index`` and ``n_id``
        (original ids of its nodes).

    Raises:
        ValueError: If ``policy`` is not a supported policy name.
    """
    supported = ("edge_deleted", "node_deleted", "ego")
    if policy not in supported:
        raise ValueError(f"Unsupported subgraph policy: {policy!r}; expected one of {supported}")

    graph = to_networkx(data, to_undirected=True)
    subgraphs = []

    if policy == "edge_deleted":
        for edge in graph.edges:
            if len(subgraphs) >= max_subgraphs:
                break
            reduced = graph.copy()
            reduced.remove_edge(*edge)
            subgraphs.append(_nx_to_pyg_subgraph(reduced, data))

    elif policy == "node_deleted":
        for node in graph.nodes:
            if len(subgraphs) >= max_subgraphs:
                break
            reduced = graph.copy()
            reduced.remove_node(node)
            subgraphs.append(_nx_to_pyg_subgraph(reduced, data))

    else:  # policy == "ego"
        radius = 2
        for node in graph.nodes:
            if len(subgraphs) >= max_subgraphs:
                break
            ego = nx.ego_graph(graph, node, radius=radius)
            pyg_subgraph = _nx_to_pyg_subgraph(ego, data)

            # Indicator column marking the centre node. It is created with
            # the features' dtype and device so the concatenation also
            # works when the graph lives on a GPU.
            features = pyg_subgraph.x
            central_node_feature = torch.zeros(
                features.size(0), 1, dtype=features.dtype, device=features.device
            )
            central_node_feature[list(ego.nodes).index(node)] = 1

            # Append the indicator, then L2-normalise each row.
            features = torch.cat([features, central_node_feature], dim=1)
            pyg_subgraph.x = F.normalize(features, p=2, dim=1)

            subgraphs.append(pyg_subgraph)

    return subgraphs

def implement_attacks_on_esan():
    """Train ESAN baselines and evaluate backdoor attacks against them.

    For every dataset and subgraph policy a clean ESAN model is trained and
    tested. For every attack, the trigger is then injected into a copy of
    the graph, subgraphs are regenerated from the poisoned graph, and a copy
    of the clean baseline is trained further on them. ESAN consumes
    subgraphs (``model(subgraphs, num_nodes)``), so poisoned training and
    evaluation go through ``train_model`` and the ESAN forward pass rather
    than helpers that call ``model(x, edge_index)``.

    Each attack starts from its own copy of the baseline so that results of
    one attack are not influenced by training done for a previous one.

    Errors are caught and printed per attack, per policy and per dataset so
    that one failure does not abort the whole sweep. All result lines are
    printed and written to ``esan_attack_results_summary_fixed.csv``.
    """
    import copy  # local import: only needed to clone the clean baseline per attack

    datasets = ["Cora", "CiteSeer", "PubMed"]
    policies = ["ego", "edge_deleted", "node_deleted"]  # ESAN-specific policies
    dataset_budgets = {'Cora': 10, 'CiteSeer': 20, 'PubMed': 30}  # Poisoning budgets
    # NOTE(review): inject_trigger in this file has no 'UGBA' branch, so that
    # entry trains on an unmodified copy of the graph; confirm this is intended.
    attacks = ['SBA-Samp', 'SBA-Gen', 'GTA', 'UGBA', 'DPGBA']
    results_summary = []

    for dataset_name in datasets:
        try:
            # Load dataset and split into train/val/test
            dataset = load_dataset(dataset_name)
            data = dataset[0].to(device)
            input_dim = data.num_features
            output_dim = dataset.num_classes if isinstance(dataset.num_classes, int) else dataset.num_classes[0]
            data = split_dataset(data)
            poisoned_node_budget = dataset_budgets.get(dataset_name, 10)

            for policy in policies:
                try:
                    # Generate subgraphs and train the clean ESAN baseline.
                    # The ego policy appends one indicator column to the features.
                    subgraphs = generate_subgraphs(data, policy=policy, max_subgraphs=100)
                    adjusted_input_dim = input_dim + 1 if policy == "ego" else input_dim
                    model = ESAN(adjusted_input_dim, hidden_dim=64, output_dim=output_dim).to(device)
                    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

                    print(f"\nTraining ESAN model ({policy} policy) for dataset {dataset_name}.")
                    model = train_model(model, subgraphs, data, optimizer, epochs=50)

                    # Test baseline model accuracy
                    baseline_accs = test_model(model, subgraphs, data)
                    baseline_test_acc = baseline_accs[2]  # Test accuracy
                    results_summary.append(f"Dataset: {dataset_name}, Model: ESAN, Policy: {policy}, Attack: None, Clean Accuracy: {baseline_test_acc * 100:.2f}%")
                    print(f"Dataset: {dataset_name}, Policy: {policy}, Baseline Test Accuracy: {baseline_test_acc * 100:.2f}%")

                    # Perform attack evaluations
                    for attack in attacks:
                        try:
                            poisoned_nodes = select_high_centrality_nodes(data, poisoned_node_budget)

                            # The trigger generator works on raw node features,
                            # so it uses input_dim, not the ego-adjusted size.
                            trigger_gen = TriggerGenerator(input_dim=input_dim, hidden_dim=64).to(device)

                            # Poison a copy of the graph and rebuild the subgraphs from it.
                            # NOTE(review): generate_subgraphs converts to an undirected
                            # NetworkX graph; confirm that one-directional trigger edges
                            # added by inject_trigger survive that conversion.
                            data_poisoned = inject_trigger(data, poisoned_nodes, attack, trigger_gen, None, 0.7)
                            poisoned_subgraphs = generate_subgraphs(data_poisoned, policy=policy, max_subgraphs=100)

                            # Continue training from a private copy of the clean baseline.
                            attacked_model = copy.deepcopy(model)
                            attacked_optimizer = torch.optim.Adam(attacked_model.parameters(), lr=0.01)
                            attacked_model = train_model(
                                attacked_model, poisoned_subgraphs, data_poisoned, attacked_optimizer, epochs=100
                            )

                            # Compute metrics (ASR and Clean Accuracy) through the ESAN forward pass
                            attacked_model.eval()
                            with torch.no_grad():
                                pred = attacked_model(poisoned_subgraphs, data_poisoned.num_nodes).argmax(dim=1)
                            asr = (pred[poisoned_nodes] == data_poisoned.y[poisoned_nodes]).sum().item() / len(poisoned_nodes) * 100
                            test_rows = data_poisoned.test_mask
                            clean_acc = accuracy_score(data_poisoned.y[test_rows].cpu(), pred[test_rows].cpu()) * 100

                            results_summary.append(f"Dataset: {dataset_name}, Policy: {policy}, Attack: {attack}, ASR: {asr:.2f}%, Clean Accuracy: {clean_acc:.2f}%")
                            print(f"Dataset: {dataset_name}, Policy: {policy}, Attack: {attack}, ASR: {asr:.2f}%, Clean Accuracy: {clean_acc:.2f}%")
                        except Exception as e:
                            print(f"Error during attack {attack} on {dataset_name} with ESAN ({policy} policy): {e}")
                except Exception as e:
                    print(f"Error during ESAN policy {policy} on {dataset_name}: {e}")
        except Exception as e:
            print(f"Error processing dataset {dataset_name}: {e}")

    # Save results to a CSV
    results_df = pd.DataFrame(results_summary, columns=["Summary"])
    print("\nSummary of Results:")
    for result in results_summary:
        print(result)
    results_df.to_csv("esan_attack_results_summary_fixed.csv", index=False)


# Run the fixed implementation
implement_attacks_on_esan()


Using device: cpu

Training ESAN model (ego policy) for dataset Cora.
Epoch 10/50, Loss: 1.4901, Train Accuracy: 0.3917
Epoch 20/50, Loss: 1.1218, Train Accuracy: 0.4955
Epoch 30/50, Loss: 1.0236, Train Accuracy: 0.5245
Epoch 40/50, Loss: 0.9774, Train Accuracy: 0.5424
Epoch 50/50, Loss: 0.9571, Train Accuracy: 0.5530
Train Accuracy: 0.5561
Validation Accuracy: 0.4556
Test Accuracy: 0.4898
Dataset: Cora, Policy: ego, Baseline Test Accuracy: 48.98%
Error during attack SBA-Samp on Cora with ESAN (ego policy): zeros() received an invalid combination of arguments - got (tuple, device=torch.device), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)

Error durin