# Design document

## Adversarial estimation on graphs 
Given a ground truth graph dataset $G = (X,Y,N,A)$ where:
- $X$ is an $n \times k$ matrix of exogenous characteristics of individual nodes, i.e. each node is associated with a $k$-dimensional vector of features
- $Y$ is an $n \times l$ matrix of endogenous outcomes of individual nodes, i.e. each node is associated with an $l$-dimensional vector of outcomes
- $N = \{1,...,n\}$ is the set of node indices
- $A \in \{0,1\}^{n\times n}$ is a symmetric adjacency matrix

Structural model $m_{\theta}: R^{n \times k } \times \{0,1\}^{n \times n} \to R^{n \times l }$, where $m$ is parametrized by an unknown vector $\theta$.

Synthetic dataset $G'(\theta) = (X,Y',N,A)$ where $Y'=m_{\theta}(X,A)$.

GNN discriminator $D: g_i \to [0,1]$, where $g_i$ is a graph.

We search for $\theta^*$ such that:
$$ \theta^* \in \arg \min_{\theta} \max_{D} L(G'(\theta),G)$$

where $L$ is a classification quality metric of the discriminator (e.g. accuracy): $D$ is trained to maximize it, and we look for the $\theta$ that minimizes it.
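
To make the objective concrete, one possible choice of $L$ is test-set accuracy. This is only a sketch under assumptions that are not stated above: $D(g_i)$ is read as the predicted probability that $g_i$ is synthetic, predictions are thresholded at $1/2$, and $T$ is a hypothetical held-out set of labeled subgraphs $(g_i, c_i)$ with $c_i = 0$ for subgraphs drawn from $G$ and $c_i = 1$ for subgraphs drawn from $G'(\theta)$:

$$ L(G'(\theta),G) = \frac{1}{|T|} \sum_{(g_i, c_i) \in T} \mathbf{1}\big[ \mathbf{1}[D(g_i) > 1/2] = c_i \big] $$

With balanced classes, a discriminator facing indistinguishable real and synthetic subgraphs cannot do better than $1/2$ in expectation, so values of $L$ close to $1/2$ indicate a good $\theta$.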

### Algorithm description 
Throughout, subgraphs are drawn by ego sampling: each sample is the neighbourhood of radius $h=1$ around a randomly selected node index.
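
As a rough illustration of ego sampling, the sketch below collects the nodes within `h` hops of a center node from a plain 0/1 adjacency matrix. The function name `ego_nodes` and the toy path graph are assumptions made for this example, not part of the library.

```python
import random

def ego_nodes(adjacency, center, h=1):
    """Return the nodes within h hops of `center` (center first), via breadth-first search."""
    visited = [center]
    frontier = [center]
    for _ in range(h):
        next_frontier = []
        for u in frontier:
            for v, connected in enumerate(adjacency[u]):
                if connected and v not in visited:
                    visited.append(v)
                    next_frontier.append(v)
        frontier = next_frontier
    return visited

# Toy path graph 0 - 1 - 2 - 3 as a symmetric 0/1 adjacency matrix.
A = [
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
]

rng = random.Random(0)
centers = rng.sample(range(len(A)), k=2)          # randomly selected node indices
egos = {c: ego_nodes(A, c, h=1) for c in centers}  # one ego neighbourhood per center
```

With `h=1` each sample contains the center and its direct neighbors only, which keeps subgraphs small at the cost of hiding longer-range structure from the discriminator.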

Repeat until convergence of $\theta$:
  - sample ground truth data from $G$
  - for fixed $\theta$ use $m_{\theta}$ to generate synthetic data $G'=(X,Y',N,A)$
  - sample synthetic data from $G'$
  - create a labeled dataset where subgraphs drawn from $G$ have label 0 and synthetic examples from $G'$ have label 1
  - make a train/test split
  - train the GNN discriminator $D$
  - evaluate its performance on the test set; this value is the objective for the current $\theta$
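
The steps above can be sketched end to end for a single candidate $\theta$. To keep the example runnable with NumPy alone, the GNN discriminator is swapped for a logistic regression on hand-made summaries of each ego subgraph; that substitution, the helper names and all hyperparameters are assumptions made for illustration only.

```python
import numpy as np

def linear_in_means(x, A, theta):
    """Structural model: y_i = a + b * mean of neighbours' x (0 for isolated nodes)."""
    a, b = theta
    deg = A.sum(axis=1, keepdims=True)
    mean_x = np.divide(A @ x, deg, out=np.zeros_like(x), where=deg > 0)
    return a + b * mean_x

def ego_summary(x, y, A, node):
    """Summarise the h=1 ego subgraph of `node` as a fixed-length vector."""
    nbrs = np.flatnonzero(A[node])
    mean_nbr_x = x[nbrs, 0].mean() if nbrs.size else 0.0
    return np.array([x[node, 0], y[node, 0], mean_nbr_x])

def train_logistic(F, c, epochs=300, lr=0.5):
    """Plain gradient descent on the logistic loss; returns weights (bias last)."""
    Fb = np.hstack([F, np.ones((F.shape[0], 1))])
    w = np.zeros(Fb.shape[1])
    for _ in range(epochs):
        p = 1.0 / (1.0 + np.exp(-Fb @ w))
        w -= lr * Fb.T @ (p - c) / len(c)
    return w

def one_iteration(theta, x, y_true, A, m=200, seed=0):
    """Run one pass of the loop and return the stand-in discriminator's test accuracy."""
    rng = np.random.default_rng(seed)
    n = x.shape[0]
    y_syn = linear_in_means(x, A, theta)                       # generate G'
    real = [ego_summary(x, y_true, A, i) for i in rng.choice(n, m, replace=False)]
    fake = [ego_summary(x, y_syn, A, i) for i in rng.choice(n, m, replace=False)]
    F = np.array(real + fake)
    c = np.concatenate([np.zeros(m), np.ones(m)])              # 0 = real, 1 = synthetic
    order = rng.permutation(2 * m)                             # shuffled train/test split
    cut = int(0.7 * 2 * m)
    train, test = order[:cut], order[cut:]
    mu, sd = F[train].mean(axis=0), F[train].std(axis=0) + 1e-8
    w = train_logistic((F[train] - mu) / sd, c[train])         # train the discriminator
    Fb_test = np.hstack([(F[test] - mu) / sd, np.ones((len(test), 1))])
    pred = (Fb_test @ w > 0).astype(float)
    return float((pred == c[test]).mean())                     # evaluate on the test set

# Toy ground truth: symmetric Erdos-Renyi graph and linear-in-means outcomes.
rng = np.random.default_rng(42)
n = 300
upper = np.triu(rng.random((n, n)) < 0.05, k=1)
A = (upper | upper.T).astype(int)
x = rng.standard_normal((n, 1))
true_theta = (1.0, 2.0)
y = linear_in_means(x, A, true_theta)

acc_true = one_iteration(true_theta, x, y, A)   # synthetic data matches the ground truth
acc_far = one_iteration((5.0, 2.0), x, y, A)    # intercept far from the truth
```

At the true parameters the two classes are identically distributed, so the accuracy should hover around chance, while a badly wrong intercept makes the classes easy to separate. A linear stand-in like this one is much weaker than a GNN (it would struggle with a wrong slope, for instance), so it only illustrates the control flow.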


## Repository structure

    structural_gnn_lib/
        estimator/
            __init__.py
            estimator.py
        generator/
            __init__.py
            generator.py
        utils/
            __init__.py
            utils.py
    linear_in_means.ipynb

## Tools to use

    import torch
    import torch.nn.functional as F
    from torch_geometric.nn import GCNConv, global_mean_pool
    from torch_geometric.data import Data
    from torch_geometric.loader import DataLoader  # moved out of torch_geometric.data in PyG 2.x
    import numpy as np
    import networkx as nx
    import matplotlib.pyplot as plt
    import random
    from skopt import gp_minimize
    from matplotlib.colors import Normalize
    from matplotlib.cm import ScalarMappable
    import time
    from tqdm.notebook import tqdm

## Required objects

The following sections specify the objects required to implement the algorithm described above. The requirements are deliberately concrete.

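## Test dataset
Example of creating a test dataset from the linear-in-means model. The code is the body of a constructor; the parameter order in the signature is an assumption.

    def __init__(self, num_nodes, p, true_a, true_b, seed):
        self.num_nodes = num_nodes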
        self.true_a = true_a
        self.true_b = true_b
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        
        # Generate a fixed random graph using the Erdos-Rényi model
        self.G = nx.erdos_renyi_graph(n=num_nodes, p=p, seed=seed)
        self.adjacency = nx.adjacency_matrix(self.G).todense()
        
        # Generate univariate node features
        self.x = np.random.randn(num_nodes, 1)
        
        # Generate outcomes y based on neighbors' x values
        self.y = np.zeros((num_nodes, 1))
        for i in range(num_nodes):
            neighbors = list(self.G.neighbors(i))
            if neighbors:
                mean_neighbor_x = np.mean(self.x[neighbors])
            else:
                mean_neighbor_x = 0.0
            self.y[i] = true_a + true_b * mean_neighbor_x

## Generator class

### class GeneratorBase:

This class defines the abstract interface of the generator classes. The general purpose
of a generator is to supply data to the adversarial estimator. The base class abstracts the behaviour
of two child classes:

1) GroundTruthGenerator
    - is instantiated from ground truth data, a graph $G = (X,Y,N,A)$, where $X$ is an $n \times k$ matrix of exogenous covariates for each node, $Y$ is an $n \times l$ matrix of node-level outcomes, $A$ is a symmetric $n \times n$ adjacency matrix (undirected graph) and $N = \{1,...,n\}$ is the set of node indices. The data needed to create a ground truth generator instance are supplied externally and are assumed to be already wrangled into the correct format.

2) SyntheticGenerator
    - inherits exogenous data from the GroundTruthGenerator instance,
      i.e. from $G = (X,Y,N,A)$, $\{X,A,N\}$ are inherited as immutable data members of the SyntheticGenerator class instance. In addition, SyntheticGenerator needs a structural_model function which takes $\{X,A,N\}$ and generates synthetic outcomes $Y'$.

      For example, for the linear-in-means model with $\theta = (a, b)$ the structural_model function is:
      $$ y_i = a + b \cdot \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} x_j $$
      where $\mathcal{N}(i)$ is the set of neighbors of node $i$. For isolated nodes the neighbor mean is taken to be 0, as in the test dataset above.

All generator classes implement an identical sample_subgraphs method:

    def sample_subgraphs(self, node_ids):
        """
        Extract induced subgraphs centered on specified nodes.
        
        For each node in node_ids, creates a subgraph containing the node and all
        its neighbors, with features and outcomes preserved from the original graph.
        
        Parameters:
        -----------
        node_ids : list
            List of node indices to sample subgraphs from
        
        Returns:
        --------
        list
            List of PyTorch Geometric Data objects representing subgraphs
        """

Example of the method body:

        subgraphs = []
        for node in node_ids:
            # Ensure the subgraph contains the target node and its neighbors
            nodes = [node] + list(self.G.neighbors(node))
            subgraph = self.G.subgraph(nodes).copy()
            
            # Relabel nodes for internal consistency
            mapping = {n: i for i, n in enumerate(nodes)}
            subgraph = nx.relabel_nodes(subgraph, mapping)
            
            # Retrieve features and outcomes from the ground truth
            x_sub = torch.tensor(self.x[nodes], dtype=torch.float)
            y_sub = torch.tensor(self.y[nodes], dtype=torch.float)
            
            # Combine x and y into a single feature vector per node
            features = torch.cat([x_sub, y_sub], dim=1)
            
            # Build edge index (ensuring both directions for an undirected graph)
            edge_index = torch.tensor(list(subgraph.edges), dtype=torch.long).t().contiguous()
            if edge_index.numel() > 0:
                edge_index = torch.cat([edge_index, edge_index[[1, 0], :]], dim=1)
            else:
                edge_index = torch.empty((2, 0), dtype=torch.long)
                
            # Create PyTorch Geometric Data object
            data = Data(x=features, edge_index=edge_index, 
                        original_nodes=nodes,  # Store original node IDs for visualization
                        original_graph=subgraph)  # Store NetworkX graph for visualization
            subgraphs.append(data)
        return subgraphs
            
The SyntheticGenerator class additionally has a generate_outcomes(self, theta) method, which generates outcomes $Y'$ for fixed $X,A,N$ and a variable parameter vector $\theta$.
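
A minimal sketch of how the three generator classes could fit together is shown below. The constructor signatures are assumptions, subgraphs are returned as plain dicts rather than PyTorch Geometric Data objects so that only NumPy is needed, immutability of the inherited members is not enforced, and the graph is assumed to have no self-loops.

```python
import numpy as np

class GeneratorBase:
    """Holds graph data (X, Y, A) and samples h=1 ego subgraphs from it."""

    def __init__(self, x, y, adjacency):
        self.x = np.asarray(x, dtype=float)
        self.y = None if y is None else np.asarray(y, dtype=float)
        self.adjacency = np.asarray(adjacency)
        self.num_nodes = self.x.shape[0]

    def sample_subgraphs(self, node_ids):
        """Return one dict per node: node ids, [x | y] features, induced adjacency."""
        if self.y is None:
            raise ValueError("outcomes are not set; call generate_outcomes first")
        subgraphs = []
        for node in node_ids:
            nodes = [node] + np.flatnonzero(self.adjacency[node]).tolist()
            features = np.hstack([self.x[nodes], self.y[nodes]])
            sub_adj = self.adjacency[np.ix_(nodes, nodes)]
            subgraphs.append({"nodes": nodes, "features": features, "adjacency": sub_adj})
        return subgraphs


class GroundTruthGenerator(GeneratorBase):
    """Wraps externally supplied ground truth data."""


class SyntheticGenerator(GeneratorBase):
    """Shares X and A with the ground truth and regenerates Y' for each theta."""

    def __init__(self, ground_truth, structural_model):
        super().__init__(ground_truth.x, None, ground_truth.adjacency)
        self.structural_model = structural_model

    def generate_outcomes(self, theta):
        self.y = self.structural_model(self.x, self.adjacency, theta)
        return self.y


def linear_in_means(x, adjacency, theta):
    """y_i = a + b * mean of neighbours' x, with mean 0 for isolated nodes."""
    a, b = theta
    deg = adjacency.sum(axis=1, keepdims=True)
    mean_x = np.divide(adjacency @ x, deg, out=np.zeros_like(x), where=deg > 0)
    return a + b * mean_x


A = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]])   # star: node 0 linked to 1 and 2
x = np.array([[1.0], [2.0], [4.0]])
truth = GroundTruthGenerator(x, linear_in_means(x, A, (1.0, 2.0)), A)
synthetic = SyntheticGenerator(truth, linear_in_means)
synthetic.generate_outcomes((0.0, 1.0))
real_sub = truth.sample_subgraphs([0])[0]
fake_sub = synthetic.sample_subgraphs([1])[0]
```

Passing the structural model in as a callable keeps SyntheticGenerator independent of any particular model, so swapping linear-in-means for another specification would not touch the sampling code.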


## GraphDiscriminator(torch.nn.Module)

A graph neural network whose purpose is to discriminate between synthetic and ground truth data.

It is initialized with:

    def __init__(self, input_dim, hidden_dim, num_classes):
        """
        Initialize the GNN discriminator.
        
        Parameters:
        -----------
        input_dim : int
            Dimension of input node features
        hidden_dim : int
            Dimension of hidden node representations
        num_classes : int
            Number of output classes (2 for binary classification)
        """

It has a forward(self, data) method, as usual for a torch module:

    def forward(self, data):
        """
        Forward pass through the GNN.
        
        Parameters:
        -----------
        data : torch_geometric.data.Data
            Input graph data
        
        Returns:
        --------
        torch.Tensor
            Logits for each class
        """

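To show the kind of computation such a forward pass performs, here is a plain NumPy illustration: two GCN-style layers, mean pooling over nodes, and a linear read-out to class logits. It is not the PyTorch Geometric implementation; the actual class would presumably be built from `GCNConv` and `global_mean_pool`, and the layer count, the missing bias terms in the graph layers and the random weights are assumptions.

```python
import numpy as np

def gcn_layer(h, adjacency, weight):
    """One GCN-style layer: symmetric-normalised neighbourhood averaging, then ReLU."""
    a_hat = adjacency + np.eye(adjacency.shape[0])        # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    a_norm = a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
    return np.maximum(a_norm @ h @ weight, 0.0)

def forward(features, adjacency, params):
    """Node features -> two GCN layers -> mean pool over nodes -> class logits."""
    h = gcn_layer(features, adjacency, params["w1"])
    h = gcn_layer(h, adjacency, params["w2"])
    graph_embedding = h.mean(axis=0)                      # global mean pooling
    return graph_embedding @ params["w_out"] + params["b_out"]

rng = np.random.default_rng(0)
input_dim, hidden_dim, num_classes = 2, 8, 2              # [x | y] features, binary output
params = {
    "w1": rng.standard_normal((input_dim, hidden_dim)) * 0.1,
    "w2": rng.standard_normal((hidden_dim, hidden_dim)) * 0.1,
    "w_out": rng.standard_normal((hidden_dim, num_classes)) * 0.1,
    "b_out": np.zeros(num_classes),
}

A = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)   # one ego subgraph
features = np.array([[1.0, 7.0], [2.0, 3.0], [4.0, 3.0]])
logits = forward(features, A, params)

# Softmax turns the logits into class probabilities (real vs synthetic).
probs = np.exp(logits - logits.max())
probs /= probs.sum()
```

Mean pooling makes the output independent of how the nodes of a subgraph are ordered, which matters here because the relabeling inside sample_subgraphs is arbitrary.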

## Utility functions

- defined in the utils subdirectory

### create_dataset

Factory function that combines ground truth subgraphs sampled from a ground truth generator with synthetic subgraphs from a synthetic generator instance into one labeled dataset, before the train/test split.

    def create_dataset(real_subgraphs, synthetic_subgraphs):
        """
        Create a dataset combining real and synthetic subgraphs with class labels.

        Parameters:
        -----------
        real_subgraphs : list
            List of PyTorch Geometric Data objects from the ground truth
        synthetic_subgraphs : list
            List of PyTorch Geometric Data objects from the synthetic simulator

        Returns:
        --------
        list
            Combined dataset with class labels (0 for real, 1 for synthetic)
        """
        dataset = []
        for data in real_subgraphs:
            data.label = torch.tensor(0, dtype=torch.long)
            dataset.append(data)
        for data in synthetic_subgraphs:
            data.label = torch.tensor(1, dtype=torch.long)
            dataset.append(data)
        return dataset
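
Because create_dataset appends all real examples first and all synthetic ones second, the subsequent train/test split needs to shuffle, otherwise the test set could contain a single class. A possible helper is sketched below; the name `split_dataset`, the 30% test fraction and the tuple stand-ins for Data objects are assumptions.

```python
import random

def split_dataset(dataset, test_fraction=0.3, seed=0):
    """Shuffle a labeled dataset and split it into train and test lists."""
    shuffled = list(dataset)                 # copy so the caller's list keeps its order
    random.Random(seed).shuffle(shuffled)
    n_test = int(round(test_fraction * len(shuffled)))
    return shuffled[n_test:], shuffled[:n_test]

# Stand-in examples: (name, label) tuples instead of PyG Data objects.
dataset = [(f"real_{i}", 0) for i in range(5)] + [(f"synthetic_{i}", 1) for i in range(5)]
train_set, test_set = split_dataset(dataset, test_fraction=0.3, seed=42)
```

Fixing the seed makes the split reproducible, which helps when comparing objective values across different candidate parameters.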


### evaluate_discriminator

Helper function to evaluate the discriminator on a test set.

    def evaluate_discriminator(model, loader, device):
        """
        Evaluate the discriminator model.

        Parameters:
        -----------
        model : GraphDiscriminator
            The GNN discriminator model
        loader : torch_geometric.loader.DataLoader
            DataLoader containing evaluation data
        device : torch.device
            Device to run computations on

        Returns:
        --------
        float
            Classification accuracy
        """

### objective_function

Objective function for the outer-loop optimization task. It wraps the ground truth generator, synthetic generator and discriminator, manages the whole synthetic data generation process, and returns the discriminator's accuracy on the test set, which is the value to be minimized.

    def objective_function(theta, ground_truth_generator, m, num_epochs=20, verbose=False):
        """
        Objective function for parameter estimation.

        For candidate parameters theta, generates synthetic outcomes, trains a GNN
        discriminator to distinguish between real and synthetic data, and returns
        the test accuracy (which we want to minimize).

        Parameters:
        -----------
        theta : list or numpy.ndarray
            Candidate parameters theta
        ground_truth_generator : GroundTruthGenerator
            The ground truth generator
        m : int
            Number of nodes to sample for subgraphs
        num_epochs : int
            Number of epochs to train the discriminator
        verbose : bool
            Whether to print progress information

        Returns:
        --------
        float
            Test accuracy of the discriminator (objective to minimize)
        """

## AdversarialEstimator 

Uses the utility and helper functions described above to run the adversarial estimation with gp_minimize from scikit-optimize. Its constructor takes the ground truth data, an initial parameter vector and optionally an optimizer.
Within the constructor, the estimator builds the objective function from the utilities and classes above and runs the estimation.

Default optimizer settings:

    result = gp_minimize(
        safe_objective,
        space,
        n_calls=500,                # Total evaluations - reduce this for testing
        n_initial_points=400,       # Initial random evaluations
        noise=0.2,                  # Explicitly model noise
        acq_func='EI',              # Expected Improvement acquisition function
        callback=gp_callback,
        random_state=42,
        n_jobs=-1,                  # Set to 1 for debugging, -1 for production
        verbose=verbose
    )
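
The names `safe_objective` and `space` are not defined in this document. One plausible reading, sketched below, is that `safe_objective` wraps the objective so that a failed or non-finite evaluation is reported as the worst possible accuracy instead of aborting the optimization, and that `space` lists bounds for each parameter. The factory name, the bounds and the toy objective are assumptions.

```python
import math

def make_safe_objective(objective, worst_value=1.0):
    """Wrap `objective` so failures and non-finite values map to `worst_value`."""
    def safe_objective(theta):
        try:
            value = float(objective(theta))
        except Exception:
            return worst_value
        return value if math.isfinite(value) else worst_value
    return safe_objective

def toy_objective(theta):
    """Stand-in for the discriminator accuracy; fails for a degenerate parameter."""
    a, b = theta
    if b == 0:
        raise ZeroDivisionError("degenerate parameter")
    return min(1.0, 0.5 + abs(a - 1.0) / 10 + abs(b - 2.0) / 10)

safe_objective = make_safe_objective(toy_objective)

# (low, high) bounds for the two linear-in-means parameters a and b; the range is a guess.
space = [(-5.0, 5.0), (-5.0, 5.0)]
```

Returning 1.0 on failure fits an accuracy objective that is being minimized: the optimizer treats the failed point as maximally bad and moves on.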

## Misc functions

An additional function outside of the library, meant specifically for visualising
the minimization objective of the two-parameter linear-in-means test model as a 2D surface.
It is used within the test notebook.
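
The grid evaluation behind such a surface plot might look like the sketch below. The helper name `objective_grid`, the parameter ranges and the smooth toy objective are assumptions; in the notebook the objective would be the discriminator accuracy.

```python
import numpy as np

def objective_grid(objective, a_range, b_range, resolution=20):
    """Evaluate objective((a, b)) on a resolution x resolution grid."""
    a_values = np.linspace(a_range[0], a_range[1], resolution)
    b_values = np.linspace(b_range[0], b_range[1], resolution)
    surface = np.empty((resolution, resolution))
    for i, b in enumerate(b_values):          # rows index b, columns index a
        for j, a in enumerate(a_values):
            surface[i, j] = objective((a, b))
    return a_values, b_values, surface

def toy_objective(theta):
    """Smooth stand-in with its minimum of 0.5 at (a, b) = (1, 2)."""
    a, b = theta
    return 0.5 + 0.5 * (1 - np.exp(-((a - 1.0) ** 2 + (b - 2.0) ** 2)))

a_values, b_values, surface = objective_grid(
    toy_objective, (0.0, 2.0), (1.0, 3.0), resolution=5
)
i_min, j_min = np.unravel_index(surface.argmin(), surface.shape)
```

The row and column layout matches what `plt.contourf(a_values, b_values, surface)` expects. Each grid point costs one full discriminator training run, so the number of evaluations grows with the square of `resolution`.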

    def visualize_objective_surface(estimator, m, resolution=20, num_epochs=5, verbose=False):
        """
        Visualize the objective function as a 2D surface.

        Parameters:
        -----------
        estimator : AdversarialEstimator
            The estimator whose objective function is evaluated on the grid
        m : int
            Number of nodes to sample for each evaluation
        resolution : int
            Resolution of the grid for parameter values
        num_epochs : int
            Number of epochs to train the discriminator for each evaluation
        verbose : bool
            Whether to print progress information
        """