In [1]:
import sys
sys.path.append('../')
import torch
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt
from notebooks.notebook_setup import device, smooth_graph, create_new_set_of_models, train_models_and_get_histories, update_dict
from oslow.models.oslow import OSlowTest
from oslow.data.synthetic.graph_generator import GraphGenerator
from oslow.data.synthetic.utils import RandomGenerator
from oslow.data.synthetic.parametric import AffineParametericDataset
from oslow.data.synthetic.nonparametric import AffineNonParametericDataset
from oslow.models.normalization import ActNorm
from oslow.training.trainer import Trainer
from oslow.config import GumbelTopKConfig, BirkhoffConfig, GumbelSinkhornStraightThroughConfig, ContrastiveDivergenceConfig
from tqdm import tqdm
import numpy as np

%load_ext autoreload
%autoreload 2
device = 'cuda:0'
print(device)

num_samples = 5000
num_samples_topk = 128
permutation_batch_size = 128
flow_batch_size = 128
epochs = 50
flow_lr = 0.0001
perm_lr = 0.000001
flow_freq = 4
perm_freq = 4
num_nodes = 5

graph_generator = GraphGenerator(
    num_nodes=num_nodes,
    seed=12,
    graph_type="full",
    enforce_ordering=[i for i in range(num_nodes)],
)
graph = graph_generator.generate_dag()

# These generators are also needed to generate the data
gaussian_noise_generator = RandomGenerator('normal', seed=30, loc=0, scale=1)
link_generator = RandomGenerator('uniform', seed=1100, low=0.5, high=1.5)

# parameteric with sin(x) + x non-linearity and softplus
dset_sinusoidal = AffineParametericDataset(
    num_samples=num_samples,
    graph=graph,
    noise_generator=gaussian_noise_generator,
    link_generator=link_generator,
    link="sinusoid",
    perform_normalization=False,
    standard=True,
)
class CustomTensorDataset(torch.utils.data.Dataset):
    r"""Map-style dataset over a single tensor.

    Unlike :class:`torch.utils.data.TensorDataset`, which wraps several tensors
    and yields a tuple per index, this wraps exactly one tensor and returns
    ``tensor[index]`` directly (no tuple wrapping).

    Args:
        tensor (Tensor): data tensor; samples are taken along its first
            dimension, so ``len(dataset) == len(tensor)``.
    """

    def __init__(self, tensor: torch.Tensor) -> None:
        self.tensor = tensor

    def __getitem__(self, index) -> torch.Tensor:
        # Index along the first dimension and return the row as-is.
        return self.tensor[index]

    def __len__(self) -> int:
        return len(self.tensor)


# Wrap the generated samples (presumably a pandas DataFrame, via `.values` --
# confirm) as a float32 tensor, and build two independently shuffled loaders
# over the same dataset: first the flow loader, then the permutation loader.
dataset = CustomTensorDataset(
    torch.tensor(dset_sinusoidal.samples.values).float()
)
flow_dataloader, permutation_dataloader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for batch_size in (flow_batch_size, permutation_batch_size)
)

cuda:0


In [2]:
# Build the OSlow flow model and train it jointly with a Gumbel-top-k
# permutation learner, logging the run to Weights & Biases.
# Uses names defined in the first cell (num_nodes, graph, the dataloaders,
# learning rates, frequencies, epochs and `device`); `%autoreload 2` is
# already enabled there, so it is not repeated here.
from oslow.models.oslow import OSlow
import wandb

# Seed torch's global RNG before building the model so its initialisation is reproducible.
torch.random.manual_seed(101)
model = OSlow(in_features=num_nodes,
              layers=[128, 64, 128],
              dropout=None,
              residual=False,
              activation=torch.nn.LeakyReLU(),
              additive=False,
              num_transforms=1,
              normalization=ActNorm,
              # NOTE(review): created with Python-int loc/scale on the CPU;
              # confirm OSlow/Trainer move it to `device` when needed.
              base_distribution=torch.distributions.Normal(loc=0, scale=1),
              ordering=None)


def flow_optimizer(params):
    """Optimizer factory passed to Trainer as ``flow_optimizer``: AdamW with lr=flow_lr."""
    return torch.optim.AdamW(params, lr=flow_lr)


def perm_optimizer(params):
    """Optimizer factory passed to Trainer as ``permutation_optimizer``: AdamW with lr=perm_lr."""
    return torch.optim.AdamW(params, lr=perm_lr)


permutation_learning_config = GumbelTopKConfig(
    num_samples=num_samples_topk,
    buffer_size=10,
    buffer_update=10,
    set_gamma_uniform=True,
)


# Alternative permutation learner (currently disabled):
# permutation_learning_config = GumbelSinkhornStraightThroughConfig(temp=0.1, iters=20)
temperature_scheduler = 'linear'
temperature = 1.0

# A BirkhoffConfig is only built for num_nodes <= 4 and is disabled otherwise
# (presumably because its cost grows with the number of permutations, n! --
# confirm against BirkhoffConfig).
birkhoff_config = None if num_nodes > 4 else BirkhoffConfig(
    num_samples=100, frequency=1, print_legend=False)
trainer = Trainer(model=model,
                  dag=graph,
                  flow_dataloader=flow_dataloader,
                  perm_dataloader=permutation_dataloader,
                  flow_optimizer=flow_optimizer,
                  permutation_optimizer=perm_optimizer,
                  flow_frequency=flow_freq,
                  temperature=temperature,
                  temperature_scheduler=temperature_scheduler,
                  permutation_frequency=perm_freq,
                  max_epochs=epochs,
                  flow_lr_scheduler=torch.optim.lr_scheduler.ConstantLR,
                  permutation_lr_scheduler=torch.optim.lr_scheduler.ConstantLR,
                  permutation_learning_config=permutation_learning_config,
                  birkhoff_config=birkhoff_config,
                  device=device,
                  perform_final_buffer_search=True,)
wandb.init(project="notebooks", entity="ordered-causal-discovery",
            tags=[
                permutation_learning_config.method,
                f"num_nodes-{num_nodes}",
                f"epochs-{epochs}",
                f"base-temperature-{temperature}",
                f"temperature-scheduling-{temperature_scheduler}",
                "no-sigmoid",
            ],)
# Always close the W&B run, even when training is interrupted (e.g. by a
# KeyboardInterrupt) or raises; otherwise the run is left open in the kernel.
# Failed or interrupted runs are finished with a non-zero exit code so they
# are not reported as successful.
try:
    trainer.train()
except BaseException:
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

2024-02-08 20:30:23 ERROR    Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhamidrezakamkari[0m ([33mordered-causal-discovery[0m). Use [1m`wandb login --relogin`[0m to force relogin


KeyboardInterrupt: 