In [40]:
import sys
sys.path.append('../')
import torch
import matplotlib.pyplot as plt
from notebook_setup import device, smooth_graph, create_new_set_of_models, train_models_and_get_histories, update_dict
from oslow.models.oslow import OSlow
from oslow.data.synthetic.graph_generator import GraphGenerator
from oslow.data.synthetic.utils import RandomGenerator
from oslow.data.synthetic.parametric import AffineParametericDataset
from oslow.models.normalization import ActNorm
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import numpy as np
from itertools import permutations

%load_ext autoreload
%autoreload 2
if torch.cuda.is_available():
    # Set the device to GPU k
    device = torch.device("cuda:0")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using GPU: NVIDIA A100-SXM4-80GB


Generate a causal graph using the GraphGenerator class. Here, we specify the number of nodes (3) and enforce a specific ordering [1, 0, 2]. This graph will be used as the ground truth for our causal discovery experiment.

In [41]:
num_covariates = 4
true_ordering = [1, 2, 0, 3] # [2, 1, 0] # [1, 3, 0, 2, 4]
graph_generator = GraphGenerator(
    num_nodes=num_covariates,
    seed=0,
    graph_type="full",
    enforce_ordering=true_ordering,
)
graph = graph_generator.generate_dag()

Here, we generate synthetic data based on the causal graph. We create an AffineParametericDataset 
with sinusoidal links between variables. This dataset will be used to train our OSlow models and 
test our causal discovery strategy.

In [42]:
num_samples = 50000
gaussian_noise_generator = RandomGenerator('normal', seed=10, loc=0, scale=1)
link_generator = RandomGenerator('uniform', seed=110, low=1, high=1)

dset_sinusoidal = AffineParametericDataset(
    num_samples=num_samples,
    graph_generator=graph_generator,  # Not graph, requires a GraphGenerator object that generates the DAG
    noise_generator=gaussian_noise_generator,
    link_generator=link_generator,
    link="sinusoid",
    perform_normalization=False,
)

This cell defines the settings for our OSlow models and their training process. We specify the model 
architecture (additive or not, number of transforms, normalization method) and training parameters 
(batch size, learning rate, number of epochs). These settings will be used for all OSlow models we create.

In [43]:
base_model_instantiation_setting = dict(
    additive = False,
    num_transforms = 1,
    normalization = ActNorm,
    base_distribution = torch.distributions.Normal(loc=0, scale=1),
    use_standard_ordering=False,
)

base_training_setting = dict(
    batch_size=512,
    lr=0.005,
    epoch_count=10,
    use_standard_ordering=False,
)
batch_size = base_training_setting['batch_size']
lr = base_training_setting['lr']
epoch_count = base_training_setting['epoch_count']
use_standard_ordering = base_training_setting['use_standard_ordering']

In [44]:
tensor_samples = torch.tensor(dset_sinusoidal.samples.values).float()
torch_dataset = TensorDataset(tensor_samples)
torch_dataloader = DataLoader(torch_dataset, batch_size=base_training_setting['batch_size'], shuffle=True)

# Sanity Check

Check that the a model conditioned on the true ordering of covariates actually corresponds to the highest log-likelihood (lowest loss).

In [45]:
# # comment to select the ones you want to plot
# subset_to_consider = [
#     "sinusoidal", # only show the last 0.1 of the epochs
# ]
# all_models_trained = {}
# for key in subset_to_consider:
#     print("key: ", key)
#     print("dataset: ", dset_sinusoidal)
#     print("all_models_trained: ", all_models_trained)
#     # print("all_models_trained[key]: ", all_models_trained[key])
#     dset = dset_sinusoidal
#     all_models_trained[key] = create_new_set_of_models(**base_model_instantiation_setting)
#     all_histories = train_models_and_get_histories(**base_training_setting, dset=dset, all_models=all_models_trained[key])

#     smoothed_histories = smooth_graph(all_histories, window_size=100)
#     # create two subplots and unpack the output array immediately
#     plt.figure(figsize=(15, 5))
#     plt.subplot(121)
#     plt.title(f"full loss graph {key}")
#     plt.xlabel("epochs")
#     plt.ylabel("loss")
#     for order in all_histories.keys():
#         plt.plot(all_histories[order], label=order)
#     plt.legend()
#     plt.subplot(122)
#     plt.xlabel("epochs")
#     plt.ylabel("loss")
#     for order in all_histories.keys():
#         ending_portion = int(0.1 * len(smoothed_histories[order]))
#         plt.plot(smoothed_histories[order][-ending_portion:], label=order)
#     plt.title(f"{key} ending portion smoothed")
#     plt.legend()
#     plt.show()

In [46]:
def create_oslow_model_with_ordering(ordering):
    return OSlow(
        in_features=len(ordering),
        layers=[100, 100],
        dropout=None,
        residual=False,
        activation=torch.nn.LeakyReLU(),
        additive=False,
        num_transforms=1,
        normalization=ActNorm,
        base_distribution=torch.distributions.Normal(loc=0, scale=1),
        ordering=torch.tensor(ordering)  # Pass the ordering here
    )

# Recursive Ordering algorithm

See the determined_ordering function. This is a recursive algorithm to determine the causal ordering of covariates.

* For each starting covariate, we will sample the remaining covariates to form complete permutations of length num_total_covariates by sampling uniformly without replacement. We will create and train a unique OSLow model for each permutation that we have sampled.
* Once training has finished, for each starting covariate, calculate the average of the log probabilities of the data over all the OSLow models with that starting covariate.
* The starting covariate used in the OSLow models achieving the lowest log probability on average (in expectation) indicates which covariate comes first.
* Fix the first element and then continue this recursively until all the elements have been ordered.

For example, if we have 3 variables [0, 1, 2], then for the starting covariate 0, we would have the models conditioned on permutations [0, 1, 2] and [0, 2, 1]. We would take the final log probability for the [0, 1, 2] model and the final log probability for the [0, 2, 1] model and find their average final log probability by taking their mean. 

In [49]:
def sample_permutations(remaining):
    """
    From a list or tuple of remaining covariates, return a list of all possible permutations, each of which is a list.

    Returns:
    List[List[int]]: List of all possible permutations of the remaining covariates
    """
    return [list(p) for p in permutations(remaining)]

def initialize_models_and_optimizers(determined_ordering, remaining_covariates, num_total_covariates):
    """
    Initialize models, optimizers, and histories for all possible permutations of remaining covariates.

    Args:
    - determined_ordering: list of covariates that have already been ordered
    - remaining_covariates: list of remaining covariates to consider
    - num_total_covariates: total number of covariates in the dataset

    Returns:
    - models: dictionary of OSlow models for each permutation
    - optimizers: dictionary of optimizers for each model
    - histories: dictionary to store training histories for each model
    """
    all_permutations = sample_permutations(remaining_covariates)
    models = {}
    histories = {}

    for perm in all_permutations:
        full_perm = determined_ordering + perm
        perm_key = tuple(full_perm)  # Use tuple as dictionary key

        # Create model
        model = create_oslow_model_with_ordering(full_perm).to(device)
        models[perm_key] = model

        # Initialize history
        histories[perm_key] = []

        # Create initial permutation matrix
        perm_matrix = torch.zeros((num_total_covariates, num_total_covariates))
        for i, j in enumerate(full_perm):
            perm_matrix[i, j] = 1
        perm_matrix = perm_matrix.to(device)

        # Store permutation matrix
        models[perm_key].perm_matrix = perm_matrix

    return models, histories


def determine_ordering(remaining_covariates, determined_ordering=None, num_total_covariates=None, true_ordering=None, true_model=None, depth=0):
    """
    Input:
    - remaining_covariates: list of remaining covariates to consider
    - determined_ordering: list of covariates that have already been ordered
    - num_total_covariates: total number of covariates in the dataset (should be 3)
    - true_ordering: list of true ordering of covariates
    - true_model: a model that is conditioned on the true ordering
    - depth: current depth of recursion

    Determine the correct causal ordering greedily by choosing the best covariate to come next.
    Do this by searching over the remaining covariates and for each covariate, run the following steps:
    For each covariate, sample the remaining covariates to form complete permutations of length num_total_covariates by sampling uniformly without replacement. 
    We will create and train a unique OSLow model for each permutation that we have sampled.
    Once training has finished, for each starting covariate, calculate the average of the log probabilities of the data over all the OSLow models with that starting covariate.
    
    The starting covariate used in the OSLow models achieving the lowest log probability on average (in expectation) indicates which covariate comes first.
    Fix the first element and then continue this recursively until all the elements have been ordered.

    Output:
    - determined_ordering: list of covariates in the determined ordering
    """
    if determined_ordering is None:
        determined_ordering = []
    if num_total_covariates is None:
        num_total_covariates = len(remaining_covariates)

    if len(remaining_covariates) == 1:
        return determined_ordering + remaining_covariates

    print(f"\nCurrent stage of recursion:")
    print(f"  Ground truth ordering: {true_ordering}")
    print(f"  Fixed ordering so far: {determined_ordering}")
    print(f"  Remaining covariates: {remaining_covariates}")

    # Initialize models and optimizers for the remaining covariates
    models, histories = initialize_models_and_optimizers(determined_ordering, remaining_covariates, num_total_covariates)
    
    # Create true model only once at the start of recursion
    if true_ordering is not None and true_model is None:
        true_model = create_oslow_model_with_ordering(true_ordering).to(device)
        true_optimizer = torch.optim.Adam(true_model.parameters(), lr=lr)

    i = 0
    print("models.keys(): ", models.keys())
    for perm_key in tqdm(models, desc=f"Training models for {len(remaining_covariates)} remaining covariates"):
        model = models[perm_key]
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        
        for epoch in range(epoch_count):
            for batch, in torch_dataloader:
                batch = batch.to(device)
                log_prob = model.log_prob(batch).mean()
                loss = -log_prob
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                histories[perm_key].append(log_prob.item())
            
                # Train true ordering model only once at the start
                if true_ordering is not None and true_model is not None and determined_ordering == [] and i == 0:
                    true_log_prob = true_model.log_prob(batch).mean()
                    true_loss = -true_log_prob
                    
                    true_optimizer.zero_grad()
                    true_loss.backward()
                    true_optimizer.step()
        i += 1
        
    if determined_ordering == [] and true_model is not None:
        print("Final log probability of true model after training: ", true_log_prob.item())

    # final_log_probs = {perm_key: histories[perm_key][-1] for perm_key in models}
    # best_perm_key = max(final_log_probs, key=final_log_probs.get)
    # next_covariate = best_perm_key[len(determined_ordering)]

    # Calculate average log probabilities for each starting covariate
    starting_covariate_avg_log_probs = {}
    for covariate in remaining_covariates:
        relevant_perms = [perm for perm in models.keys() if perm[len(determined_ordering)] == covariate]
        print(f"  Covariate {covariate} has {len(relevant_perms)} relevant permutations")
        print(f"  Relevant permutations: {relevant_perms}")
        avg_log_prob = sum(histories[perm][-1] for perm in relevant_perms) / len(relevant_perms)
        starting_covariate_avg_log_probs[covariate] = avg_log_prob

    # Choose the best starting covariate
    next_covariate = max(starting_covariate_avg_log_probs, key=starting_covariate_avg_log_probs.get)
    
    # Sanity check
    if true_ordering is not None and true_model is not None:
        print(f"\nSanity Check Results at depth {depth}:")
        print(f"  Best found average log probability: {starting_covariate_avg_log_probs[next_covariate]}")
        print(f"  Average log probabilities for each starting covariate:")
        for covariate, avg_log_prob in starting_covariate_avg_log_probs.items():
            print(f"    Covariate {covariate}: {avg_log_prob}")
        # print(f"  Best found permutation: {best_perm_key}")
        print(f"  Best covariate at this stage: {next_covariate}")
        print(f"  Current ordering so far (including new covariate): {determined_ordering + [next_covariate]}")
        print(f"  True ordering so far (including new covariate): {true_ordering[:depth+1]}")
    
    remaining_covariates = [i for i in remaining_covariates if i != next_covariate]
    return determine_ordering(remaining_covariates, determined_ordering + [next_covariate], num_total_covariates, true_ordering, true_model, depth+1)

In [48]:
discovered_ordering = determine_ordering(list(range(num_covariates)), true_ordering=true_ordering)
print(f"The full inferred causal ordering is: {discovered_ordering}")
print(f"The true causal ordering is: {true_ordering}")



Current stage of recursion:
  Fixed ordering so far: []
  Remaining covariates: [0, 1, 2, 3]
models.keys():  dict_keys([(0, 1, 2, 3), (0, 1, 3, 2), (0, 2, 1, 3), (0, 2, 3, 1), (0, 3, 1, 2), (0, 3, 2, 1), (1, 0, 2, 3), (1, 0, 3, 2), (1, 2, 0, 3), (1, 2, 3, 0), (1, 3, 0, 2), (1, 3, 2, 0), (2, 0, 1, 3), (2, 0, 3, 1), (2, 1, 0, 3), (2, 1, 3, 0), (2, 3, 0, 1), (2, 3, 1, 0), (3, 0, 1, 2), (3, 0, 2, 1), (3, 1, 0, 2), (3, 1, 2, 0), (3, 2, 0, 1), (3, 2, 1, 0)])


Training models for 4 remaining covariates: 100%|██████████| 24/24 [06:01<00:00, 15.06s/it]


Final log probability of true model after training:  -10.455864906311035
  Covariate 0 has 6 relevant permutations
  Relevant permutations: [(0, 1, 2, 3), (0, 1, 3, 2), (0, 2, 1, 3), (0, 2, 3, 1), (0, 3, 1, 2), (0, 3, 2, 1)]
  Covariate 1 has 6 relevant permutations
  Relevant permutations: [(1, 0, 2, 3), (1, 0, 3, 2), (1, 2, 0, 3), (1, 2, 3, 0), (1, 3, 0, 2), (1, 3, 2, 0)]
  Covariate 2 has 6 relevant permutations
  Relevant permutations: [(2, 0, 1, 3), (2, 0, 3, 1), (2, 1, 0, 3), (2, 1, 3, 0), (2, 3, 0, 1), (2, 3, 1, 0)]
  Covariate 3 has 6 relevant permutations
  Relevant permutations: [(3, 0, 1, 2), (3, 0, 2, 1), (3, 1, 0, 2), (3, 1, 2, 0), (3, 2, 0, 1), (3, 2, 1, 0)]

Sanity Check Results at depth 0:
  Best found average log probability: -10.748207569122314
  Average log probabilities for each starting covariate:
    Covariate 0: -10.938713391621908
    Covariate 1: -10.748207569122314
    Covariate 2: -10.842563470204672
    Covariate 3: -11.98186206817627
  Best covariate at thi

Training models for 3 remaining covariates: 100%|██████████| 6/6 [01:14<00:00, 12.34s/it]


  Covariate 0 has 2 relevant permutations
  Relevant permutations: [(1, 0, 2, 3), (1, 0, 3, 2)]
  Covariate 2 has 2 relevant permutations
  Relevant permutations: [(1, 2, 0, 3), (1, 2, 3, 0)]
  Covariate 3 has 2 relevant permutations
  Relevant permutations: [(1, 3, 0, 2), (1, 3, 2, 0)]

Sanity Check Results at depth 1:
  Best found average log probability: -10.623709678649902
  Average log probabilities for each starting covariate:
    Covariate 0: -10.669552326202393
    Covariate 2: -10.623709678649902
    Covariate 3: -10.976052284240723
  Best covariate at this stage: 2
  Current ordering so far (including new covariate): [1, 2]
  True ordering so far (including new covariate): [1, 2]

Current stage of recursion:
  Fixed ordering so far: [1, 2]
  Remaining covariates: [0, 3]
models.keys():  dict_keys([(1, 2, 0, 3), (1, 2, 3, 0)])


Training models for 2 remaining covariates: 100%|██████████| 2/2 [00:24<00:00, 12.41s/it]

  Covariate 0 has 1 relevant permutations
  Relevant permutations: [(1, 2, 0, 3)]
  Covariate 3 has 1 relevant permutations
  Relevant permutations: [(1, 2, 3, 0)]

Sanity Check Results at depth 2:
  Best found average log probability: -10.442317008972168
  Average log probabilities for each starting covariate:
    Covariate 0: -10.442317008972168
    Covariate 3: -10.660240173339844
  Best covariate at this stage: 0
  Current ordering so far (including new covariate): [1, 2, 0]
  True ordering so far (including new covariate): [1, 2, 0]
The full inferred causal ordering is: [1, 2, 0, 3]
The true causal ordering is: [1, 2, 0, 3]



