# Example Script for Running a Sweep with WandB logging

<p>Wandb (Weights and Biases) is a service that tracks and stores experiment parameters and metrics while providing excellent visualisation tools — ideal for running experiments. Here, you'll see how to use wandb in tandem with the GNM toolbox.</p>

<p><i>Wandb is a separate service, not affiliated with this toolbox — for wandb-specific support, have a look at their documentation.</i></p>

In [None]:
# basic imports from the package with torch and numpy
import numpy as np
import torch
from gnm.fitting.experiment_saving import *
from gnm.fitting.experiment_dataclasses import Experiment
from gnm import defaults, fitting, generative_rules, weight_criteria, evaluation

# import wandb - run 'pip install wandb' if it's not already installed
import wandb

# Choose the compute device: use the first CUDA GPU when one is available
# (parallel execution on a GPU speeds the sweep up considerably),
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Load the package's default distance matrix and binary network onto the
# chosen device. This example only works with the binary network.
distance_matrix = defaults.get_distance_matrix(device=DEVICE)
binary_consensus_network = defaults.get_binary_network(device=DEVICE)

# Authenticate with wandb. You'll need to create an account if you
# don't have one already.
wandb.login()

In [None]:
# Parameter grid for the demo: 2 eta values x 2 gamma values = 4 combinations.
# For a wider search, replace these with larger grids (e.g. built with
# torch.linspace) — any number of values works.
eta_values = torch.tensor([1.0, 1.5])
gamma_values = torch.tensor([-1.0, -0.5])

# Number of connections to grow. The sum is halved because, presumably, the
# network is an undirected (symmetric) adjacency matrix in which every edge
# appears twice — TODO confirm against gnm.defaults.
num_connections = int(binary_consensus_network.sum().item() / 2)


In [None]:
# Mostly default settings for a sweep. See the other example scripts for an
# in-depth look at each parameter and how to use it.

# Parameters controlling how the binary network is generated. Each list /
# tensor holds the values to sweep over; only eta and gamma vary here.
binary_sweep_parameters = fitting.BinarySweepParameters(
    eta=eta_values,
    gamma=gamma_values,
    lambdah=torch.tensor([0.0]),
    distance_relationship_type=["powerlaw"],
    preferential_relationship_type=["powerlaw"],
    heterochronicity_relationship_type=["powerlaw"],
    generative_rule=[generative_rules.MatchingIndex()],
    num_iterations=[num_connections],
)

# Parameters controlling how edge weights are optimised on the generated
# network.
weighted_sweep_parameters = fitting.WeightedSweepParameters(
    alpha=[0.01],
    optimisation_criterion=[
        weight_criteria.DistanceWeightedCommunicability(distance_matrix=distance_matrix),
    ],
)

# Bundle the binary and weighted parameters, the number of simulations, and
# the distance matrix into the single config object consumed by perform_sweep.
sweep_config = fitting.SweepConfig(
    binary_sweep_parameters=binary_sweep_parameters,
    weighted_sweep_parameters=weighted_sweep_parameters,
    num_simulations=1,
    distance_matrix=[distance_matrix],
)

# Additional criteria for evaluating the generated networks against a real
# connectome. The three KS criteria are combined into one score via
# MaxCriteria (presumably the maximum of the three — verify in gnm.evaluation).
criteria = [
    evaluation.ClusteringKS(),
    evaluation.DegreeKS(),
    evaluation.EdgeLengthKS(distance_matrix),
]
energy = evaluation.MaxCriteria(criteria)
binary_evaluations = [energy]

# Criteria applied to the weighted networks.
weighted_evaluations = [
    evaluation.WeightedNodeStrengthKS(normalise=True),
    evaluation.WeightedClusteringKS(),
]


In [None]:
# Run the experiment sweep. With wandb logging enabled, a wandb link should be
# printed in the output; follow it to see the model parameters visualised in
# the wandb dashboard, which makes the parameter combinations easier to compare.
# Model and run-history saving are switched off for this demo.
# NOTE(review): weighted_evaluations is passed without any real weighted matrix;
# confirm how perform_sweep evaluates weighted networks in that case.
experiments = fitting.perform_sweep(
    sweep_config=sweep_config,
    binary_evaluations=binary_evaluations,
    real_binary_matrices=binary_consensus_network,
    weighted_evaluations=weighted_evaluations,
    save_model=False,
    save_run_history=False,
    verbose=True,
    wandb_logging=True,  # Set this to True for logging; it's False by default
)