In [1]:
import sys
import warnings
sys.path.append('/path/to/parent/directory/of/structural_gnn_lib')
warnings.filterwarnings("ignore", message="An issue occurred while importing 'torch-scatter'")

In [None]:
import torch
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import random
from tqdm.notebook import tqdm

from structural_gnn_lib import (
    AdversarialEstimator,
    objective_function,
    linear_in_means_model
)

N_NODES=2000
N_SAMPLES=1000
N_EPOCHS=4
RESOLUTION=5

class TestDataset:
    """Generate test dataset for linear-in-means model."""
    
    def __init__(self, num_nodes=100, true_a=1.0, true_b=2.0, p=0.01, seed=42):
        """
        Initialize test dataset.
        
        Parameters:
        -----------
        num_nodes : int
            Number of nodes in the graph
        true_a : float
            True intercept parameter
        true_b : float
            True slope parameter
        p : float
            Edge probability for Erdos-Renyi graph
        seed : int
            Random seed
        """
        self.num_nodes = num_nodes
        self.true_a = true_a
        self.true_b = true_b
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        
        # Generate the NetworkX graph
        self.G = nx.erdos_renyi_graph(n=num_nodes, p=p, seed=seed)
        
        # Create attributes with the expected names for GroundTruthGenerator
        self.A = nx.adjacency_matrix(self.G).todense()  # Adjacency matrix
        self.X = np.random.randn(num_nodes, 1)          # Exogenous features
        self.N = set(range(num_nodes))                  # Node indices
        
        # Generate outcomes Y based on linear-in-means model
        self.Y = np.zeros((num_nodes, 1))
        for i in range(num_nodes):
            neighbors = list(self.G.neighbors(i))
            if neighbors:
                mean_neighbor_x = np.mean(self.X[neighbors])
            else:
                mean_neighbor_x = 0.0
            self.Y[i] = true_a + true_b * mean_neighbor_x



def visualize_objective_surface(estimator, m, resolution, num_epochs, verbose=False):
    """
    Visualize the objective function as a 2D surface.
    
    Parameters:
    -----------
    estimator : AdversarialEstimator
        The adversarial estimator instance
    m : int
        Number of nodes to sample for each evaluation
    resolution : int
        Resolution of the grid for parameter values
    num_epochs : int
        Number of epochs to train the discriminator for each evaluation
    verbose : bool
        Whether to print progress information
    """

    a_range = np.linspace(-3, 5, resolution)
    b_range = np.linspace(-1, 5, resolution)
    A, B = np.meshgrid(a_range, b_range)

    Z = np.zeros_like(A)
    
    total_evals = resolution * resolution
    with tqdm(total=total_evals, desc="Evaluating objective surface") as pbar:
        for i in range(resolution):
            for j in range(resolution):
                theta = [A[i, j], B[i, j]]
                Z[i, j] = objective_function(
                    theta,
                    estimator.ground_truth_generator,
                    estimator.synthetic_generator,
                    num_epochs=num_epochs,
                    m=m
                    verbose=True
                )
                pbar.update(1)
    
    fig = plt.figure(figsize=(12, 5))
    
    ax1 = fig.add_subplot(121, projection='3d')
    surf = ax1.plot_surface(A, B, Z, cmap='viridis', alpha=0.8)
    ax1.set_xlabel('Parameter a')
    ax1.set_ylabel('Parameter b')
    ax1.set_zlabel('Discriminator Accuracy')
    ax1.set_title('Objective Function Surface')
    
    true_a, true_b = 1.0, 2.0  

    ax1.scatter([true_a], [true_b], [Z.min()], color='red', s=100, marker='*', label='True params')
    ax2 = fig.add_subplot(122)
    contour = ax2.contour(A, B, Z, levels=20, cmap='viridis')
    ax2.clabel(contour, inline=True, fontsize=8)
    ax2.scatter([true_a], [true_b], color='red', s=100, marker='*', label='True params')
    
    if hasattr(estimator, 'estimated_params') and estimator.estimated_params is not None:
        est_a, est_b = estimator.estimated_params
        ax1.scatter([est_a], [est_b], [Z.min()], color='orange', s=100, marker='^', label='Estimated params')
        ax2.scatter([est_a], [est_b], color='orange', s=100, marker='^', label='Estimated params')
    
    ax2.set_xlabel('Parameter a')
    ax2.set_ylabel('Parameter b')
    ax2.set_title('Objective Function Contours')
    ax2.legend()
    
    plt.colorbar(surf, ax=ax1, shrink=0.5)
    plt.tight_layout()
    plt.show()
    
    return Z, (A, B)

if __name__ == "__main__":
    print("Testing Adversarial Estimation for Linear-in-Means Model")
    print("=" * 60)
    
    print("\n1. Generating test dataset...")
    test_data = TestDataset(num_nodes=N_NODES, true_a=1.0, true_b=2.0, p=0.1)
    print(f"   - Number of nodes: {test_data.num_nodes}")
    print(f"   - True parameters: a={test_data.true_a}, b={test_data.true_b}")
    print(f"   - Number of edges: {test_data.G.number_of_edges()}")
    
    print("\n2. Creating adversarial estimator...")
    estimator = AdversarialEstimator(
        ground_truth_data=test_data,
        structural_model=linear_in_means_model,
        initial_params=[0.0, 1.0],
        bounds=[(-5.0, 5.0), (-5.0, 5.0)]
    )
    
    print("\n3. Visualizing objective function surface...")
    visualize_objective_surface(estimator, m=N_SAMPLES, resolution=15, num_epochs=N_EPOCHS)
    
    print("\n4. Running adversarial estimation...")
    estimated_params = estimator.estimate(m=N_SAMPLES, num_epochs=N_EPOCHS, verbose=True)
    
    print("\n5. Results:")
    print(f"   - True parameters: a={test_data.true_a}, b={test_data.true_b}")
    print(f"   - Estimated parameters: a={estimated_params[0]:.4f}, b={estimated_params[1]:.4f}")
    print(f"   - Estimation error: a_error={abs(estimated_params[0] - test_data.true_a):.4f}, "
          f"b_error={abs(estimated_params[1] - test_data.true_b):.4f}")
    
    
    plt.tight_layout()
    plt.show()
    
    print("\n7. Final objective surface with estimated parameters...")
    visualize_objective_surface(estimator, m=N_SAMPLES, resolution=15, num_epochs=5)



Testing Adversarial Estimation for Linear-in-Means Model

1. Generating test dataset...
   - Number of nodes: 1000
   - True parameters: a=1.0, b=2.0
   - Number of edges: 49929

2. Creating adversarial estimator...

3. Visualizing objective function surface...


Evaluating objective surface:   0%|          | 0/225 [00:00<?, ?it/s]

Epoch 0, Loss: 0.3246
Test accuracy for theta=[-3.0, -1.0]: 1.0000
Epoch 0, Loss: 0.2417
Test accuracy for theta=[-2.428571428571429, -1.0]: 1.0000
Epoch 0, Loss: 0.3018
Test accuracy for theta=[-1.8571428571428572, -1.0]: 1.0000
Epoch 0, Loss: 0.3564
Test accuracy for theta=[-1.2857142857142858, -1.0]: 1.0000
Epoch 0, Loss: 0.4400
Test accuracy for theta=[-0.7142857142857144, -1.0]: 1.0000
Epoch 0, Loss: 0.4136
Test accuracy for theta=[-0.14285714285714324, -1.0]: 1.0000
Epoch 0, Loss: 0.6353
Test accuracy for theta=[0.4285714285714284, -1.0]: 1.0000
Epoch 0, Loss: 0.6926
Test accuracy for theta=[1.0, -1.0]: 0.4800
Epoch 0, Loss: 0.6600
Test accuracy for theta=[1.5714285714285712, -1.0]: 1.0000
Epoch 0, Loss: 0.5245
Test accuracy for theta=[2.1428571428571423, -1.0]: 1.0000
Epoch 0, Loss: 0.6491
Test accuracy for theta=[2.7142857142857135, -1.0]: 1.0000
Epoch 0, Loss: 0.5496
Test accuracy for theta=[3.2857142857142856, -1.0]: 1.0000
Epoch 0, Loss: 0.5261
Test accuracy for theta=[3.857