In [None]:
import time
import torch
import pickle
import os
import datetime
import numpy as np
import networkx as nx
import yaml
import types
import wandb

In [None]:
# toolkit
from gTDR.models import DAG_GNN
import gTDR.utils.DAG_GNN as utils 
from gTDR.trainers.DAG_GNN_trainer import Trainer

## Arguments & Parameters

Specify the setup in config, including:
* `seed`: (int) Random seed for reproducibility.
* `use_cuda`: (bool) Whether to use CUDA for training. If True and a GPU is available, the model will be trained on the GPU.
* `save_results`: (bool) Whether to save the training checkpoints.
* `save_folder`: (str) Directory to save the trainer's checkpoints.

In [None]:
# Read the YAML experiment configuration and expose its top-level keys as
# attributes of `args` (e.g. args.seed, args.use_cuda).
config_filename = "../configs/DAG_GNN_synthetic_parameters.yaml"
with open(config_filename) as config_file:
    # safe_load is PyYAML's shorthand for load(..., Loader=yaml.SafeLoader).
    configs = yaml.safe_load(config_file)
args = types.SimpleNamespace(**configs)

Set seed for reproducibility.

In [None]:
# Seed NumPy and PyTorch from the configured value; the CUDA generators are
# seeded as well when the config requests GPU usage.
seed = args.seed
np.random.seed(seed)
torch.manual_seed(seed)
if args.use_cuda:
    # Same two calls as before, in the same order: current device, then all devices.
    for seed_cuda_rng in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda_rng(seed)

Start a `wandb` run to monitor the experiment (nll_train, kl_train, mse_train, shd_train, ELBO_loss).

In [None]:
# Open a Weights & Biases run; the handle is kept in `run`.
run = wandb.init(
    project="DAG-GNN",
    name="synthetic",
)

## Data

In this demo, we use synthetic data.

In [None]:
# Build the synthetic dataset with the configured seed. The helper returns
# three data loaders plus the ground-truth graph used later for evaluation.
(
    train_loader,
    valid_loader,
    test_loader,
    ground_truth_G,
) = utils.load_synthetic_data(seed=args.seed)

## Model

You may specify these model parameters in config:

* `encoder`: (str) This determines the type of encoder used in the model. It can be either `mlp` for a multi-layer perceptron or `sem` for a structural equation model. This affects how the data is processed and transformed in the initial phase of the model.

* `decoder`: (str) This determines the type of decoder used in the model. Similar to the encoder, it can be either `mlp` or `sem`. This affects how the latent variables are transformed back to the original data space.

* `data_variable_size`: (int) This is the number of nodes in your data.

* `x_dims`: (int) This is the dimensionality of the input data.

* `z_dims`: (int) This is the dimensionality of the latent space.

* `encoder_hidden`: (int) This is the size of the hidden layer in the encoder.

* `decoder_hidden`: (int) This is the size of the hidden layer in the decoder.

* `encoder_dropout`: (float) This is the dropout rate applied in the encoder.

* `decoder_dropout`: (float) This is the dropout rate applied in the decoder.

In [None]:
# Instantiate the DAG-GNN model from the parsed config namespace `args`.
model = DAG_GNN(args)

## Training

You may specify these training parameters in config:

* `lr` (float): This is the learning rate for the optimizer. It determines the step size at each iteration while moving towards a minimum of a loss function.

* `lr_decay` (int): This is the step size for the learning rate scheduler. The learning rate will be reduced every `lr_decay` number of epochs.

* `gamma` (float): This is the factor by which the learning rate will be reduced at each step of the learning rate scheduler. A `gamma` of 1.0 means the learning rate will stay the same.

* `tau_A` (float): This is a regularization parameter for the adjacency matrix. It controls the degree of sparsity in the learned graph structure. A higher value of `tau_A` would enforce more sparsity, i.e., it would encourage the learned graph to have fewer edges. Conversely, a lower value would result in a graph with more edges.

* `lambda_A` (float): This is another regularization parameter for the adjacency matrix. It is used in the computation of the loss function. More specifically, it controls the contribution of the graph complexity to the overall loss. A higher value of `lambda_A` would mean that you are penalizing complex graphs more heavily. On the other hand, a lower value means you are more tolerant of complex graphs.

* `c_A` (int): This is a parameter used in the adaptive learning rate mechanism. It scales the learning rate based on the change in the graph structure. A higher value of `c_A` would lead to a more drastic reduction in the learning rate when the graph structure changes, which could help with stability but might slow down learning. A lower value of `c_A`, on the other hand, means that the learning rate stays more constant even when the graph structure changes, which could speed up learning but might lead to instability.

* `graph_threshold` (float): This is the threshold used to binarize the learned adjacency matrix. Edges with weights below this threshold are removed.

* `h_tol` (str): This is the tolerance for the stopping criterion. The learning process stops when the computed `h(A)` (a measure of the complexity of the graph structure) is below this tolerance.

* `k_max_iter` (int): This is the maximum number of iterations for the outer loop of the learning process.

* `epochs` (int): This is the number of epochs for which the model will be trained. An epoch is a complete pass over the entire training dataset.

* `optimizer` (str): This is the type of optimizer used for the learning process. It can be 
    * `Adam`, 
    * `LBFGS`, 
    * `SGD`.

In [None]:
# Wrap the model in a trainer that also receives the ground-truth graph.
trainer = Trainer(
    model,
    args,
    ground_truth_G,
    report_log=True,
)
# train() returns three graphs, unpacked here as the best by ELBO, NLL and MSE;
# use_wandb=True is passed so the run is monitored through wandb.
(
    best_ELBO_graph,
    best_NLL_graph,
    best_MSE_graph,
) = trainer.train(train_loader, use_wandb=True)

## Evaluation (when ground truth is available)

In [None]:
"""
# fdr: false discovery rate, lower the better
# tpr: true positive rate, higher the better
# fpr: false positive rate, lower the better
# shd: structural Hamming distance, lower the better
# nnz: number of nonzeros, closer to the ground truth the better
"""
# Score each of the three graphs returned by training against the ground truth.
best_graph = {'ELBO': best_ELBO_graph, 'NLL': best_NLL_graph, 'MSE': best_MSE_graph}
for key, candidate in best_graph.items():
    fdr, tpr, fpr, shd, nnz = utils.count_accuracy(ground_truth_G, nx.DiGraph(candidate))
    print('Best %s Graph Accuracy: fdr'%(key), fdr, ' tpr ', tpr, ' fpr ', fpr, 'shd', shd, 'nnz', nnz)

# Unthresholded adjacency matrix held by the trainer, copied to a NumPy array.
learned_A = trainer.origin_A.data.cpu().clone().numpy()
# Evaluate several pruning thresholds. Each threshold is applied to a fresh
# copy so the result does not depend on the order (or monotonicity) of the
# threshold list; previously the matrix was zeroed in place across iterations.
for thres in [0.1, 0.2, 0.3]:
    graph = learned_A.copy()
    graph[np.abs(graph) < thres] = 0
    fdr, tpr, fpr, shd, nnz = utils.count_accuracy(ground_truth_G, nx.DiGraph(graph))
    print('threshold %.1f, Accuracy: fdr'%(thres), fdr, ' tpr ', tpr, ' fpr ', fpr, 'shd', shd, 'nnz', nnz)

In [None]:
# Close the wandb run opened earlier in the notebook.
wandb.finish()