# GNM Experiment: Sweeping, Saving, and Querying

This notebook shows how to set up and run a parameter sweep for a generative network model (GNM) using the `gnm` library. We will:

1.  **Configure Environment**: Set up the PyTorch device and load necessary data.
2.  **Define Parameters**: Specify the parameter space for both binary and weighted network generation.
3.  **Set Evaluation Criteria**: Define the metrics to evaluate the similarity between generated networks and a real-world consensus network.
4.  **Run the Sweep**: Execute the experiment using `fitting.perform_sweep`.
5.  **Save and Query**: Use the `ExperimentEvaluation` class to save the results and demonstrate how to query them based on specific parameters.

## 1. Imports

First, we import all the necessary modules from the `gnm` library and `torch`.

In [None]:
from gnm.fitting.experiment_saving import ExperimentEvaluation
from gnm.fitting.experiment_dataclasses import Experiment
from gnm import defaults, fitting, generative_rules, weight_criteria, evaluation
import torch

## 2. Environment Setup and Data Loading

We use a GPU if one is available and fall back to the CPU otherwise. We then load a pre-defined distance matrix and a binary consensus network, which will serve as the ground truth for our evaluations.

In [None]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

distance_matrix = defaults.get_distance_matrix(device=DEVICE)
binary_consensus_network = defaults.get_binary_network(device=DEVICE)

print(f"Distance matrix shape: {distance_matrix.shape}")
print(f"Binary consensus network shape: {binary_consensus_network.shape}")

## 3. Defining Sweep Parameters

Here, we define the parameter space for our sweep. This is broken down into two parts:

* `BinarySweepParameters`: Defines the parameters for generating the network topology (the binary connections). We specify values for `eta` and `gamma`, the generative rule (`MatchingIndex`), and relationship types.
* `WeightedSweepParameters`: Defines the parameters for assigning weights to the connections, including the `alpha` parameter and the optimization criterion.
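
To make the roles of `eta` and `gamma` concrete, here is a minimal pure-Python sketch of the power-law wiring rule as it is commonly written in the GNM literature: the probability of adding a candidate edge is proportional to its distance raised to `eta` times a topological affinity (for example, the matching index) raised to `gamma`. The function name `wiring_probabilities` and the toy values are hypothetical, and whether `gnm` computes exactly this form internally is an assumption.

```python
def wiring_probabilities(distances, affinities, eta, gamma):
    """Normalised power-law scores for a list of candidate edges (illustrative only)."""
    # Unnormalised score per candidate edge: distance^eta * affinity^gamma.
    scores = [(d ** eta) * (k ** gamma) for d, k in zip(distances, affinities)]
    total = sum(scores)
    # Normalise so the scores form a probability distribution over candidates.
    return [s / total for s in scores]


# Made-up toy values for three candidate edges.
distances = [1.0, 2.0, 4.0]
affinities = [1.0, 1.0, 2.0]

probs = wiring_probabilities(distances, affinities, eta=-1.0, gamma=1.0)
print(probs)
```

With a negative `eta`, longer edges are penalised; with a positive `gamma`, edges with a higher affinity are favoured. In this toy case the longest edge makes up for its length with a higher affinity.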

In [None]:
# For this demo, we use single values. To perform a wider sweep,
# you could use torch.linspace or a list of values.
eta_values = torch.tensor([1.0])
gamma_values = torch.tensor([-1.0])

# Number of connections to generate. We halve the sum because each
# undirected connection appears twice in a symmetric adjacency matrix.
num_connections = int(binary_consensus_network.sum().item() / 2)

binary_sweep_parameters = fitting.BinarySweepParameters(
    eta=eta_values,
    gamma=gamma_values,
    lambdah=torch.tensor([0.0]),
    distance_relationship_type=["powerlaw"],
    preferential_relationship_type=["powerlaw"],
    heterochronicity_relationship_type=["powerlaw"],
    generative_rule=[generative_rules.MatchingIndex()],
    num_iterations=[num_connections],
)

weighted_sweep_parameters = fitting.WeightedSweepParameters(
    alpha=[0.01],
    optimisation_criterion=[weight_criteria.DistanceWeightedCommunicability(distance_matrix=distance_matrix)],
)

print("Binary Sweep Parameters:")
print(binary_sweep_parameters)
print("\nWeighted Sweep Parameters:")
print(weighted_sweep_parameters)
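
The demo above uses one value per parameter, so the sweep contains a single combination. As a rough sketch of how a wider sweep grows, the snippet below builds evenly spaced value lists and enumerates every pairing with `itertools.product`. The helper `linspace` is a hypothetical pure-Python stand-in; with PyTorch, `torch.linspace(-3.0, 0.0, 4)` gives the same values as a tensor. The grid ranges are arbitrary illustrations, not recommended settings.

```python
from itertools import product


def linspace(start, stop, num):
    """Return `num` evenly spaced floats from start to stop, inclusive."""
    if num == 1:
        return [float(start)]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


eta_grid = linspace(-3.0, 0.0, 4)    # 4 candidate values for eta
gamma_grid = linspace(0.0, 0.5, 3)   # 3 candidate values for gamma

# Every (eta, gamma) pairing is one parameter combination in the sweep.
combinations = list(product(eta_grid, gamma_grid))
print(f"{len(combinations)} combinations")
```

The number of combinations is the product of the list lengths, so each extra swept parameter multiplies the run time.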

## 4. Full Sweep Configuration

The `SweepConfig` object combines the binary and weighted parameters with general simulation settings, such as the number of simulations to run for each parameter combination.

In [None]:
num_simulations = 1

sweep_config = fitting.SweepConfig(
    binary_sweep_parameters=binary_sweep_parameters,
    weighted_sweep_parameters=weighted_sweep_parameters,
    num_simulations=num_simulations,
    distance_matrix=[distance_matrix]
)

## 5. Defining Evaluation Criteria

We need to define how to score the generated networks. We use Kolmogorov-Smirnov (KS) statistics to compare the distributions of several network properties (clustering, degree, edge length) in each generated network against those of the real network. `MaxCriteria` then takes the maximum (worst) of these KS statistics as a single energy score for the binary model. The weighted networks are scored separately, on node strength and weighted clustering.

In [None]:
criteria = [evaluation.ClusteringKS(), evaluation.DegreeKS(), evaluation.EdgeLengthKS(distance_matrix)]
energy = evaluation.MaxCriteria(criteria)
binary_evaluations = [energy]

weighted_evaluations = [evaluation.WeightedNodeStrengthKS(normalise=True), evaluation.WeightedClusteringKS()]
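
For intuition, the two-sample KS statistic is the largest gap between the empirical cumulative distribution functions of two samples, and the energy is the largest such gap across the chosen properties. The sketch below is a pure-Python illustration of that idea, not the `gnm` implementation; `ks_statistic` and the toy values are hypothetical.

```python
from bisect import bisect_right


def ks_statistic(sample_a, sample_b):
    """Two-sample KS statistic: the largest gap between the two empirical CDFs."""
    a, b = sorted(sample_a), sorted(sample_b)
    # The gap between two step functions can only change at an observed value.
    points = sorted(set(a) | set(b))
    return max(
        abs(bisect_right(a, x) / len(a) - bisect_right(b, x) / len(b))
        for x in points
    )


# Made-up toy property values for a "real" and a "generated" network.
real = {"degree": [1, 2, 2, 3], "clustering": [0.1, 0.2, 0.3, 0.4]}
generated = {"degree": [1, 2, 3, 3], "clustering": [0.3, 0.4, 0.5, 0.9]}

ks_values = {name: ks_statistic(real[name], generated[name]) for name in real}
energy = max(ks_values.values())  # the worst-matching property sets the score
print(ks_values, energy)
```

Taking the maximum means a generated network only scores well if it matches the real network on every property at once.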

## 6. Running the Experiment Sweep

With all configurations in place, we call `fitting.perform_sweep` to run the simulations. This function iterates through all parameter combinations, generates networks, evaluates them, and returns a list of `Experiment` objects containing the results. We set `verbose=True` to see progress.

In [None]:
experiments = fitting.perform_sweep(
    sweep_config=sweep_config,
    binary_evaluations=binary_evaluations,
    real_binary_matrices=binary_consensus_network,
    weighted_evaluations=weighted_evaluations,
    save_model=False,  # Set to True to save the model instances
    save_run_history=False,
    verbose=True,
)
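
Conceptually, a sweep is a nested loop: for each parameter combination, run the requested number of simulations, score each one, and keep the results. The toy version below sketches that structure in pure Python. `ToyResult`, `toy_energy`, and `toy_sweep` are hypothetical names, and `toy_energy` is an arbitrary stand-in for "generate a network and evaluate it"; none of this reflects how `fitting.perform_sweep` is implemented.

```python
import random
from dataclasses import dataclass
from itertools import product
from statistics import mean


@dataclass
class ToyResult:
    eta: float
    gamma: float
    energies: list  # one energy value per simulation


def toy_energy(eta, gamma, rng):
    # Stand-in scoring function: a bowl centred on (eta=-2, gamma=0.5)
    # plus a little noise to mimic run-to-run variation.
    return (eta + 2.0) ** 2 + (gamma - 0.5) ** 2 + rng.uniform(0.0, 0.01)


def toy_sweep(eta_values, gamma_values, num_simulations, seed=0):
    rng = random.Random(seed)  # seeded so the sweep is reproducible
    results = []
    for eta, gamma in product(eta_values, gamma_values):
        energies = [toy_energy(eta, gamma, rng) for _ in range(num_simulations)]
        results.append(ToyResult(eta, gamma, energies))
    return results


results = toy_sweep([-3.0, -2.0, -1.0], [0.0, 0.5], num_simulations=3)
# Lower energy means a better fit, so pick the combination with the lowest mean.
best = min(results, key=lambda r: mean(r.energies))
print(len(results), best.eta, best.gamma)
```

Averaging over several simulations per combination reduces the influence of any single noisy run when ranking parameter settings.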

## 7. Saving and Querying Experiment Results

Finally, we save the results and query them in three steps:

1.  Instantiate `ExperimentEvaluation`.
2.  Use `.save_experiments()` to save the list of experiment objects. This will typically save to a file for later access.
3.  Use `.query_experiments()` to filter and retrieve specific experiments from the saved set based on their parameters.

In [None]:
# The ExperimentEvaluation class handles saving and loading
eval_handler = ExperimentEvaluation()

# Save the list of experiments
eval_handler.save_experiments(experiments)
print(f"Saved {len(experiments)} experiment(s).")

# Query the experiments by a binary parameter
print("\nQuerying for generative_rule = 'MatchingIndex'...")
query_by_rule = eval_handler.query_experiments(by='generative_rule', value='MatchingIndex')
print(f"Found {len(query_by_rule)} experiment(s).")
print(query_by_rule)

# Query the experiments by a weighted parameter
print("\nQuerying for alpha = 0.01...")
query_by_alpha = eval_handler.query_experiments(by='alpha', value=0.01)
print(f"Found {len(query_by_alpha)} experiment(s).")
print(query_by_alpha)
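
The save-then-query pattern can be illustrated without `gnm` at all. The sketch below writes a list of plain dictionaries to a JSON file and filters them by a single field. The helpers `save_records` and `query_records`, the record layout, and the values are all made up for illustration; the storage format and matching rules used by `ExperimentEvaluation` may differ.

```python
import json
import os
import tempfile


def save_records(records, path):
    """Write a list of result dictionaries to a JSON file."""
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(records, handle)


def query_records(path, by, value):
    """Load saved records and keep those whose field `by` equals `value`."""
    with open(path, encoding="utf-8") as handle:
        records = json.load(handle)
    return [record for record in records if record.get(by) == value]


# Made-up toy records standing in for experiment results.
records = [
    {"generative_rule": "MatchingIndex", "alpha": 0.01, "energy": 0.21},
    {"generative_rule": "MatchingIndex", "alpha": 0.05, "energy": 0.34},
    {"generative_rule": "Neighbours", "alpha": 0.01, "energy": 0.40},
]

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "experiments.json")
    save_records(records, path)
    by_rule = query_records(path, by="generative_rule", value="MatchingIndex")
    by_alpha = query_records(path, by="alpha", value=0.01)

print(len(by_rule), len(by_alpha))
```

Filtering on one field at a time keeps the interface simple; queries on several parameters can be built by intersecting the results.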