# In [None]:  (notebook cell marker, kept as a comment so the file parses as Python)
import warnings
warnings.filterwarnings("ignore", message="An issue occurred while importing 'torch-sparse'")
warnings.filterwarnings("ignore", message="An issue occurred while importing 'torch-cluster'")

import numpy as np
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

from adversarial_nets import AdversarialEstimator, GraphDataset


def build_peer_operator(A: np.ndarray, row_normalize: bool = True) -> np.ndarray:
    """
    Return the peer operator P for adjacency matrix A.

    A float copy of A is made (the input is never modified) and its diagonal is
    set to zero, so a node is not counted as its own peer. With
    ``row_normalize=True`` each row is divided by its row sum (degree); rows
    whose sum is not positive are returned as all zeros instead of NaN/inf.
    """
    peer = A.astype(float)  # astype returns a fresh array by default
    np.fill_diagonal(peer, 0.0)
    if not row_normalize:
        return peer

    degree = peer.sum(axis=1, keepdims=True)
    normalized = np.zeros_like(peer)
    # `where` leaves entries of non-positive-degree rows at their zero initial
    # value; errstate silences floating-point warnings inside the division.
    with np.errstate(divide="ignore", invalid="ignore"):
        np.divide(peer, degree, out=normalized, where=(degree > 0))
    return normalized

def simulate_linear_in_means(X: np.ndarray,
                             A: np.ndarray,
                             theta: np.ndarray,
                             noise_std: float = 0.1,
                             rng: "np.random.Generator | None" = None) -> np.ndarray:
    """
    Ground-truth simulator: y = (I - rho P)^{-1} (alpha*1 + beta*x + gamma*P x + eps)

    Parameters
    ----------
    X : array of shape (n, k) or (n,)
        Node features; when 2-D only the first column is used.
    A : array of shape (n, n)
        Adjacency matrix. P is derived from it with a zero diagonal and row
        normalization (``build_peer_operator(A, row_normalize=True)``).
    theta : sequence of four numbers
        [alpha, beta, gamma, rho].
    noise_std : float
        Standard deviation of the i.i.d. Gaussian shock eps. When exactly 0 no
        random numbers are drawn, so the result is deterministic and no RNG
        state is consumed. A negative value raises ValueError (from NumPy).
    rng : numpy.random.Generator, optional
        Generator used for eps. If None, the legacy global ``np.random`` state
        is used (the original behavior).

    Returns
    -------
    ndarray of shape (n,)
        The equilibrium outcome vector y.

    Raises
    ------
    ValueError
        If theta does not have exactly four entries.
    numpy.linalg.LinAlgError
        If (I - rho P) is singular.
    """
    alpha, beta, gamma, rho = map(float, theta)
    feats = np.asarray(X, dtype=float)
    x = feats if feats.ndim == 1 else feats[:, 0]
    n = x.shape[0]
    P = build_peer_operator(A, row_normalize=True)

    if noise_std == 0:
        # Skip the draw: avoids silently advancing the RNG on deterministic calls.
        eps = np.zeros(n)
    else:
        draw = np.random.normal if rng is None else rng.normal
        eps = draw(0.0, noise_std, size=n)

    # Right-hand side: alpha*1 + beta*x + gamma*P*x + eps
    rhs = alpha * np.ones(n) + beta * x + gamma * (P @ x) + eps

    # Solve (I - rho P) y = rhs rather than forming the inverse explicitly.
    return np.linalg.solve(np.eye(n) - rho * P, rhs)

def structural_model(x: np.ndarray, adjacency: np.ndarray, y0: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """
    Structural mapping required by AdversarialEstimator.

    Returns the noise-free linear-in-means outcome for parameters
    ``theta = [alpha, beta, gamma, rho]`` on the graph ``adjacency``.
    It calls ``simulate_linear_in_means`` with ``noise_std=0.0``.

    Parameters
    ----------
    x : array of shape (n, k) or (n,)
        Node features; only the first feature (column 0) enters the model.
    adjacency : array of shape (n, n)
    y0 : ignored. It is kept so the signature matches what the estimator
        passes in.
    theta : sequence of four numbers [alpha, beta, gamma, rho].

    Returns
    -------
    ndarray of shape (n,)
    """
    features = np.asarray(x)
    if features.ndim == 1:
        # Accept a flat feature vector as a single-column matrix.
        features = features[:, None]
    # The model uses a single scalar feature; extra columns are ignored.
    return simulate_linear_in_means(features[:, [0]], adjacency, theta, noise_std=0.0)

class _SimpleGNN(nn.Module):
    """
    One GCN layer -> ReLU -> dropout -> graph mean-pool -> linear head.

    Defined at module level (not inside the factory) so that instances can be
    pickled, e.g. with ``torch.save(model)``; pickle cannot locate locally
    defined classes.
    """

    def __init__(self, in_dim: int, hid_dim: int, out_dim: int, dropout: float = 0.2):
        super().__init__()
        self.conv = GCNConv(in_dim, hid_dim)
        self.classifier = nn.Linear(hid_dim, out_dim)
        self.dropout = dropout

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        # One pooled vector per graph in the batch.
        x = global_mean_pool(x, batch)
        # Raw logits, no sigmoid. Per the original note the estimator applies
        # BCEWithLogitsLoss; that is not verifiable from this file.
        return self.classifier(x)


def discriminator_factory(input_dim: int, hidden_dim: int = 32, num_classes: int = 1,
                          dropout: float = 0.2) -> nn.Module:
    """
    Build a very small GNN discriminator: one GCN layer, mean pool, linear head.

    Parameters
    ----------
    input_dim : number of node features fed to the GCN layer.
    hidden_dim : width of the GCN layer.
    num_classes : number of output logits per graph.
    dropout : dropout probability applied after the GCN layer during training
        (default 0.2, the previously hard-coded value).
    """
    return _SimpleGNN(input_dim, hidden_dim, num_classes, dropout=dropout)



def create_test_graph_dataset(n_nodes: int = 150,
                              p_edge: float = 0.05,
                              theta_true = (0.7, 1.2, 0.8, 0.3),
                              seed: int = 42,
                              noise_std: float = 0.1) -> GraphDataset:
    """
    Make a toy graph with one exogenous feature per node and outcomes y from the true theta.

    The graph is Erdos-Renyi G(n_nodes, p_edge). X is one standard-normal
    feature per node. Y follows the linear-in-means model with Gaussian shocks
    of standard deviation ``noise_std``. The graph, X and the shocks all come
    from ``seed``, so the dataset is fully reproducible.

    If ``|rho| >= 0.99 / lambda_max`` (lambda_max = spectral radius of the
    row-normalized peer operator), rho is shrunk to ``0.8 / lambda_max`` with
    its sign preserved, and a RuntimeWarning is emitted. In that case the data
    no longer correspond to the rho passed in.
    """
    rng = np.random.default_rng(seed)
    G = nx.erdos_renyi_graph(n_nodes, p_edge, seed=seed)
    A = nx.to_numpy_array(G, dtype=float)

    # One scalar feature per node
    X = rng.normal(0.0, 1.0, size=(n_nodes, 1))

    alpha, beta, gamma, rho = map(float, theta_true)

    # Keep (I - rho P) well away from singular.
    P = build_peer_operator(A, row_normalize=True)
    eigvals = np.linalg.eigvals(P)
    lambda_max = float(np.max(np.abs(eigvals))) if eigvals.size else 0.0
    if lambda_max > 0 and abs(rho) >= 0.99 / lambda_max:
        rho_safe = float(np.copysign(0.8 / lambda_max, rho))
        warnings.warn(
            f"rho={rho} violates |rho| < 1/lambda_max (lambda_max={lambda_max:.4f}); "
            f"simulating with rho={rho_safe:.4f} instead.",
            RuntimeWarning,
            stacklevel=2,
        )
        rho = rho_safe
    theta = np.array([alpha, beta, gamma, rho])

    # The model is linear in the shock: y = (I - rho P)^{-1}(mean part)
    # + (I - rho P)^{-1} eps. Draw eps from the seeded generator so that
    # `seed` also fixes the noise.
    Y_mean = simulate_linear_in_means(X, A, theta, noise_std=0.0)
    eps = rng.normal(0.0, noise_std, size=n_nodes)
    Y = Y_mean + np.linalg.solve(np.eye(n_nodes) - rho * P, eps)

    N = list(range(n_nodes))
    return GraphDataset(X=X, Y=Y, A=A, N=N)

if __name__ == "__main__":
    # True structural parameters: (alpha, beta, gamma, rho)
    TRUE_THETA = (0.7, 1.2, 0.8, 0.3)
    SEED = 123

    # 0) Seed the global RNGs for repeatable runs. np.random feeds the
    #    simulator's default noise; torch feeds nn.Linear initialisation and
    #    dropout in the discriminator.
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # 1) Data
    data = create_test_graph_dataset(n_nodes=180, p_edge=0.06, theta_true=TRUE_THETA, seed=SEED)

    # 2) Bounds: loose for (alpha, beta, gamma); rho constrained by spectral radius of P
    P = build_peer_operator(data.A, row_normalize=True)
    eigvals = np.linalg.eigvals(P)
    lam_max = float(np.max(np.abs(eigvals))) if eigvals.size else 1.0
    rho_cap = (0.99 / lam_max) if lam_max > 0 else 0.99  # |rho| < 1 / lambda_max
    bounds = [(-5.0, 5.0), (-5.0, 5.0), (-5.0, 5.0), (-rho_cap, rho_cap)]

    # 3) Estimator (small settings; bump up for real runs)
    estimator = AdversarialEstimator(
        ground_truth_data=data,
        structural_model=structural_model,
        initial_params=[0.0, 0.0, 0.0, 0.0],
        bounds=bounds,
        discriminator_factory=discriminator_factory,
        gp_params=dict(
            initial_point_generator="sobol",
            n_initial_points=64,
            noise=0.10,
        ),
        # You can also pass sampler options here if your API exposes them.
    )

    # 4) Estimate
    # m = subgraphs per objective eval; num_epochs = discriminator training per eval
    result = estimator.estimate(m=128, num_epochs=8, verbose=True)
    theta_hat = np.asarray(result["x"] if isinstance(result, dict) else result.x, dtype=float).ravel()
    if theta_hat.size != len(TRUE_THETA):
        # zip() below would otherwise truncate silently.
        raise RuntimeError(f"Expected {len(TRUE_THETA)} estimated parameters, got {theta_hat.size}")

    # 5) Report
    # NOTE: if create_test_graph_dataset had to shrink rho, the data were
    # generated with that value rather than TRUE_THETA[3].
    names = ["alpha", "beta", "gamma", "rho"]
    print("\n=== Estimated parameters (linear-in-means, 4 params) ===")
    for nm, t, th in zip(names, TRUE_THETA, theta_hat):
        print(f"{nm:>6s}: true = {t: .4f},  est = {th: .4f},  abs.err = {abs(th - t): .4f}")