# Testing with multiple datasets

In [9]:
import sys
sys.path.append('../')
import torch
from notebook_setup import device, smooth_graph, create_new_set_of_models, train_models_and_get_histories, update_dict
from oslow.models.oslow import OSlow
from oslow.data.synthetic.graph_generator import GraphGenerator
from oslow.data.synthetic.utils import RandomGenerator
from oslow.data.synthetic.parametric import AffineParametericDataset
from oslow.data.synthetic.nonparametric import AffineNonParametericDataset
from oslow.models.normalization import ActNorm
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import random

%load_ext autoreload
%autoreload 2
# Index of the GPU this notebook should run on; overrides the `device` imported from notebook_setup.
gpu_index = 5
if torch.cuda.is_available():
    # Fall back to the first GPU instead of crashing on hosts with fewer GPUs.
    selected_gpu = gpu_index if gpu_index < torch.cuda.device_count() else 0
    if selected_gpu != gpu_index:
        print(f"GPU {gpu_index} not available; falling back to GPU {selected_gpu}")
    device = torch.device(f"cuda:{selected_gpu}")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using GPU: NVIDIA A100-SXM4-80GB


In [10]:
# Block 3: Generate datasets
num_covariates = 3
num_samples = 100  # samples per parametric dataset
num_nonparametric_samples = 1000  # samples per nonparametric dataset
true_ordering = [2, 0, 1]
# NOTE(review): batch_size and lr are not read in this cell; confirm the training code
# uses them rather than hard-coding its own copies.
batch_size = 512
lr = 0.005

# NOTE(review): graph_type="full" with enforce_ordering presumably yields a complete DAG
# consistent with true_ordering -- confirm in GraphGenerator.
graph_generator = GraphGenerator(
    num_nodes=num_covariates,
    seed=0,
    graph_type="full",
    enforce_ordering=true_ordering,
)

# NOTE(review): these generator objects are shared by several datasets below; if
# RandomGenerator is stateful, each dataset's noise depends on construction order.
gaussian_noise_generator = RandomGenerator('normal', seed=10, loc=0, scale=1)
laplace_noise_generator = RandomGenerator('laplace', seed=10, loc=0, scale=1)
# low == high, so presumably every link coefficient is exactly 1 -- confirm this is intended.
link_generator = RandomGenerator('uniform', seed=110, low=1, high=1)

datasets = {
    "sinusoidal": AffineParametericDataset(
        num_samples=num_samples,
        graph_generator=graph_generator,
        noise_generator=gaussian_noise_generator,
        link_generator=link_generator,
        link="sinusoid",
        perform_normalization=False,
    ),
    # "cubic": AffineParametericDataset(
    #     num_samples=num_samples,
    #     graph_generator=graph_generator,
    #     noise_generator=gaussian_noise_generator,
    #     link_generator=link_generator,
    #     link="cubic",
    #     perform_normalization=True,
    # ),
    "laplace_linear": AffineParametericDataset(
        num_samples=num_samples,
        graph_generator=graph_generator,
        noise_generator=laplace_noise_generator,
        link_generator=link_generator,
        link="linear",
        perform_normalization=False,
        additive=True,
    ),
    "nonparametric_affine": AffineNonParametericDataset(
        num_samples=num_nonparametric_samples,
        graph_generator=graph_generator,
        noise_generator=gaussian_noise_generator,
        invertibility_coefficient=0.0,
        perform_normalization=False,
        additive=False,
    ),
    "nonparametric_additive": AffineNonParametericDataset(
        num_samples=num_nonparametric_samples,
        graph_generator=graph_generator,
        noise_generator=gaussian_noise_generator,
        invertibility_coefficient=0.0,
        perform_normalization=False,
        additive=True,
    ),
    "nonparametric_almost_invertible": AffineNonParametericDataset(
        num_samples=num_nonparametric_samples,
        graph_generator=graph_generator,
        noise_generator=gaussian_noise_generator,
        invertibility_coefficient=1.0,
        perform_normalization=False,
        additive=False,
    )
}

print("Datasets generated:")
for name in datasets.keys():
    print(f"- {name}")


Datasets generated:
- sinusoidal
- laplace_linear
- nonparametric_affine
- nonparametric_additive
- nonparametric_almost_invertible


# Recursive Ordering algorithm

See the `determine_ordering` function below. It is a recursive, greedy algorithm that determines the causal ordering of the covariates one position at a time.

* For each candidate starting covariate, draw random permutations of the other remaining covariates (10 per candidate). Each permutation is shuffled uniformly at random, so the same one can be drawn more than once. Prepend the candidate to each permutation; together with the already-fixed prefix, this forms a complete ordering of length num_total_covariates. Create and train a separate OSlow model for every sampled ordering.
* Once training has finished, for each starting covariate, average the final recorded log probability over all OSlow models that start with that covariate. The final recorded value is the mean log probability of the last training mini-batch.
* The starting covariate whose OSlow models achieve the highest average log probability (i.e. the lowest negative log-likelihood) is chosen as the next covariate in the ordering.
* Fix that covariate and recurse on the remaining covariates until only one is left, which is placed last.

For example, if we have 3 variables [0, 1, 2], then for the starting covariate 0 the sampled orderings are [0, 1, 2] and [0, 2, 1], each possibly drawn several times. We take the final log probability of every model trained on one of these orderings and average them to obtain the score for covariate 0.

In [11]:
def sample_permutations(remaining, num_samples=10, rng=None):
    """
    Draw `num_samples` independent, uniformly random permutations of `remaining`.

    Each draw is an independent shuffle, so the same permutation may appear more than
    once in the result.

    Args:
    remaining (list or tuple): Covariates to permute. It is never modified.
    num_samples (int): Number of permutations to draw.
    rng (random.Random, optional): Source of randomness. Defaults to the global `random`
        module; pass a seeded `random.Random` for reproducible draws.

    Returns:
    List[List[int]]: `num_samples` permutations, each an independent list object.
    """
    shuffler = random if rng is None else rng
    sampled_permutations = []
    for _ in range(num_samples):
        # Fresh copy per sample: the caller's sequence is untouched and no two samples
        # share a list object (mutating one sample never affects another).
        perm = list(remaining)
        # Shuffling a 0- or 1-element list consumes no randomness, so no special case is needed.
        shuffler.shuffle(perm)
        sampled_permutations.append(perm)

    return sampled_permutations


def create_oslow_model_with_ordering(ordering, additive=False, base_distribution=None):
    """
    Build an untrained OSlow model whose variable ordering is fixed to `ordering`.

    Args:
    ordering (list[int]): Covariate ordering; its length sets the number of input features.
    additive (bool): Forwarded to OSlow's `additive` flag.
    base_distribution (torch.distributions.Distribution, optional): Base distribution of the
        flow. Defaults to a standard Normal(loc=0, scale=1).

    Returns:
    OSlow: A freshly constructed model (two hidden layers of width 100, LeakyReLU, one
    transform, ActNorm normalization).
    """
    base = torch.distributions.Normal(loc=0, scale=1) if base_distribution is None else base_distribution
    # NOTE(review): a torch.distributions object is not an nn.Module, so moving the model
    # with .to(device) presumably does not move this base distribution -- confirm OSlow
    # handles device placement of its base distribution.

    architecture = dict(
        in_features=len(ordering),
        layers=[100, 100],
        dropout=None,
        residual=False,
        activation=torch.nn.LeakyReLU(),
        additive=additive,
        num_transforms=1,
        normalization=ActNorm,
    )
    return OSlow(**architecture, base_distribution=base, ordering=torch.tensor(ordering))

def initialize_models_and_optimizers(determined_ordering, remaining_covariates, num_total_covariates, model_params,
                                     num_permutation_samples=10):
    """
    Create one untrained OSlow model per sampled candidate ordering.

    For every covariate in `remaining_covariates` (a candidate for the next position),
    `num_permutation_samples` random permutations of the other remaining covariates are
    sampled and appended to it. Each result is prefixed with `determined_ordering` to give
    a full ordering. A model is built for each full ordering, moved to the global `device`,
    and given a `perm_matrix` attribute. Despite the function name, no optimizers are
    created here.

    Args:
    - determined_ordering: list of covariates that have already been ordered
    - remaining_covariates: list of covariates not yet placed
    - num_total_covariates: total number of covariates (side length of each permutation matrix)
    - model_params: dict of keyword arguments forwarded to create_oslow_model_with_ordering
    - num_permutation_samples: number of permutations sampled per candidate (default 10)

    Returns:
    - models: dict mapping each full ordering (as a tuple) to a list of OSlow models. An
      ordering sampled more than once gets several independently initialised models.
    - histories: dict with the same keys holding one empty list per model, to be filled
      during training
    """
    # Sample all candidate orderings first, then build the models.
    all_permutations = []
    for i, start_covariate in enumerate(remaining_covariates):
        # Slicing also covers the first and last positions, so no special cases are needed.
        others = remaining_covariates[:i] + remaining_covariates[i + 1:]
        for sample in sample_permutations(others, num_samples=num_permutation_samples):
            all_permutations.append([start_covariate] + sample)

    models = {}
    histories = {}
    for perm in all_permutations:
        full_perm = determined_ordering + perm
        perm_key = tuple(full_perm)  # lists are unhashable, so key by tuple

        model = create_oslow_model_with_ordering(full_perm, **model_params).to(device)

        # Row i of the permutation matrix has a single 1 in column full_perm[i].
        perm_matrix = torch.zeros((num_total_covariates, num_total_covariates))
        for row, col in enumerate(full_perm):
            perm_matrix[row, col] = 1
        model.perm_matrix = perm_matrix.to(device)

        models.setdefault(perm_key, []).append(model)
        histories.setdefault(perm_key, []).append([])

    return models, histories

def _train_models(models, histories, dataloader, training_params, desc):
    """Train every model in `models` with Adam, appending each mini-batch's mean log-probability to `histories`."""
    for perm_key in tqdm(models, desc=desc):
        for model, history in zip(models[perm_key], histories[perm_key]):
            optimizer = torch.optim.Adam(model.parameters(), lr=training_params['lr'])

            for _ in range(training_params['epoch_count']):
                for batch, in dataloader:
                    batch = batch.to(device)
                    log_prob = model.log_prob(batch).mean()
                    loss = -log_prob

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    history.append(log_prob.item())


def _average_final_log_probs(histories, candidates, position):
    """Score each candidate by the mean final log-probability of the models whose ordering has it at index `position` (-inf if none)."""
    scores = {}
    for covariate in candidates:
        final_log_probs = []
        for perm_key, model_histories in histories.items():
            if perm_key[position] != covariate:
                continue
            for history in model_histories:
                if not history:
                    raise ValueError(
                        f"No log-probabilities were recorded for ordering {list(perm_key)}; "
                        "check training_params['epoch_count'] and that the dataloader is non-empty."
                    )
                final_log_probs.append(history[-1])
        scores[covariate] = sum(final_log_probs) / len(final_log_probs) if final_log_probs else float('-inf')
    return scores


def determine_ordering(remaining_covariates,
                       dataloader,
                       model_params,
                       training_params,
                       determined_ordering=None,
                       num_total_covariates=None,
                       true_ordering=None,
                       depth=0):
    """
    Greedily determine a causal ordering by repeatedly choosing the best next covariate.

    At each level, one OSlow model is trained for each sampled candidate ordering. The
    remaining covariate whose models reach the highest average final training
    log-probability is fixed next, and the search recurses on the rest.

    Input:
    - remaining_covariates: list of covariates not yet placed
    - dataloader: DataLoader yielding 1-tuples of sample batches (e.g. built from a TensorDataset)
    - model_params: dict of keyword arguments forwarded to the model constructor
    - training_params: dict with 'lr' (Adam learning rate) and 'epoch_count'
      (passes over the dataloader per model)
    - determined_ordering: list of covariates that have already been ordered
    - num_total_covariates: total number of covariates (defaults to len(remaining_covariates))
    - true_ordering: optional ground-truth ordering, used only for progress printing
    - depth: current depth of recursion (used for printing)

    Returns:
    - determined_ordering: list of covariates in the determined ordering

    Raises:
    - ValueError: if a trained model recorded no log-probabilities (e.g. epoch_count == 0
      or an empty dataloader)
    """
    # Copy into a fresh list so tuples are accepted and the caller's list is never aliased.
    determined_ordering = [] if determined_ordering is None else list(determined_ordering)
    if num_total_covariates is None:
        num_total_covariates = len(remaining_covariates)

    # Zero or one covariate left: nothing to compare, so append what remains and stop.
    if len(remaining_covariates) <= 1:
        return determined_ordering + list(remaining_covariates)

    print(f"\nCurrent stage of recursion at depth {depth}:")
    print(f"  Ground truth ordering: {true_ordering}")
    print(f"  Fixed ordering so far: {determined_ordering}")
    print(f"  Remaining covariates: {remaining_covariates}")

    models, histories = initialize_models_and_optimizers(
        determined_ordering, remaining_covariates, num_total_covariates, model_params
    )

    _train_models(
        models, histories, dataloader, training_params,
        desc=f"Training models for {len(remaining_covariates)} remaining covariates",
    )

    # NOTE(review): each score is the last mini-batch's mean training log-probability,
    # computed before that batch's optimizer step. It is therefore noisy when the data
    # span several batches; a full-data evaluation after training may be more reliable.
    starting_covariate_avg_log_probs = _average_final_log_probs(
        histories, remaining_covariates, position=len(determined_ordering)
    )

    # Higher log-likelihood is better.
    next_covariate = max(starting_covariate_avg_log_probs, key=starting_covariate_avg_log_probs.get)

    # Sanity check
    print(f"\nSanity Check Results at depth {depth}:")
    print(f"  Average log probabilities for each starting covariate:")
    for covariate, avg_log_prob in starting_covariate_avg_log_probs.items():
        print(f"    Covariate {covariate}: {avg_log_prob}")
    print(f"  Best covariate at this stage: {next_covariate}")
    print(f"  Current ordering so far (including new covariate): {determined_ordering + [next_covariate]}")
    # true_ordering is optional; slicing None would raise TypeError.
    if true_ordering is not None:
        print(f"  True ordering so far (including new covariate): {true_ordering[:depth+1]}")

    still_remaining = [c for c in remaining_covariates if c != next_covariate]
    return determine_ordering(still_remaining, dataloader, model_params, training_params,
                              determined_ordering + [next_covariate], num_total_covariates,
                              true_ordering, depth + 1)

Run the ordering test on each of the datasets defined above.

In [14]:
# Block 6: Function to run test on a single dataset

def _default_model_params(dataset_name):
    """Pick OSlow keyword arguments matching how the named dataset was generated."""
    if dataset_name == "laplace_linear":
        return {
            "additive": True,
            "base_distribution": torch.distributions.Laplace(loc=0, scale=1)
        }
    if "nonparametric_additive" in dataset_name:
        return {"additive": True}
    return {}


def run_single_test(dataset_name, training_params=None, model_params=None):
    """
    Run the recursive ordering search on one dataset from the global `datasets` dict.

    Args:
    - dataset_name: key into `datasets`
    - training_params: optional dict with 'batch_size', 'lr' and 'epoch_count'. Defaults to
      batch size 512, learning rate 0.005 and 30 epochs.
    - model_params: optional dict of OSlow keyword arguments. By default it is chosen from
      the dataset name: additive with a Laplace base for "laplace_linear", additive for the
      nonparametric additive dataset, and library defaults otherwise.

    Returns:
    - dict with "true_ordering" (the global ground truth) and "discovered_ordering"
    """
    dataset = datasets[dataset_name]
    print(f"\nTesting on {dataset_name} dataset:")

    if model_params is None:
        model_params = _default_model_params(dataset_name)
    if training_params is None:
        training_params = {
            "batch_size": 512,
            "lr": 0.005,
            "epoch_count": 30
        }

    # torch.tensor always copies, so converting straight to float32 needs no clone/detach.
    tensor_samples = torch.tensor(dataset.samples.values, dtype=torch.float32)
    torch_dataset = TensorDataset(tensor_samples)
    # The data is already one small in-memory tensor. Worker processes would only add
    # start-up cost, since they are re-created on every pass over the loader (every epoch
    # of every model), so load in the main process instead.
    dataloader = DataLoader(
        torch_dataset,
        batch_size=training_params['batch_size'],
        shuffle=True,
        pin_memory=True,
        num_workers=0
    )

    discovered_ordering = determine_ordering(
        remaining_covariates=list(range(num_covariates)),
        dataloader=dataloader,
        model_params=model_params,
        training_params=training_params,
        true_ordering=true_ordering,
        num_total_covariates=num_covariates
    )

    print(f"True ordering: {true_ordering}")
    print(f"Discovered ordering: {discovered_ordering}")

    return {
        "true_ordering": true_ordering,
        "discovered_ordering": discovered_ordering
    }

# To test a single dataset instead, call e.g.:
# result = run_single_test("sinusoidal")

# Run the test on every dataset in `datasets`. Results are stored as each run finishes,
# so if one run raises, test_results still holds the runs completed before it.
test_results = {}
for dataset_name in datasets.keys():
    test_results[dataset_name] = run_single_test(dataset_name)


Testing on sinusoidal dataset:

Current stage of recursion at depth 0:
  Ground truth ordering: [2, 0, 1]
  Fixed ordering so far: []
  Remaining covariates: [0, 1, 2]


TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/amli/miniconda3/envs/oslow/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/home/amli/miniconda3/envs/oslow/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_3968828/3802794552.py", line 11, in forward
    return torch.stack([model(x) for model in self.models])
TypeError: expected Tensor as element 0 in argument 0, but got tuple
