In [4]:
import sys
sys.path.append('../')
import torch
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt
from notebooks.notebook_setup import device, smooth_graph, create_new_set_of_models, train_models_and_get_histories, update_dict
from oslow.models.oslow import OSlowTest
from oslow.data.synthetic.graph_generator import GraphGenerator
from oslow.data.synthetic.utils import RandomGenerator
from oslow.data.synthetic.parametric import AffineParametericDataset
from oslow.data.synthetic.nonparametric import AffineNonParametericDataset
from oslow.models.normalization import ActNorm
from oslow.training.trainer import Trainer
from oslow.config import GumbelTopKConfig, BirkhoffConfig, SoftSortConfig
from tqdm import tqdm
import numpy as np

%load_ext autoreload
%autoreload 2
device = 'cuda:4'
print(device)

# ---- Experiment hyper-parameters ----
# Data / batching: batch sizes equal num_samples, so each loader step sees
# the whole dataset (full-batch training).
num_samples = 128
flow_batch_size = 128
permutation_batch_size = 128

# Optimisation schedule.
epochs = 20_000
flow_lr = 0.005
perm_lr = 0.005
flow_freq = 1  # passed to Trainer as flow_frequency
perm_freq = 4  # passed to Trainer as permutation_frequency

# Number of variables in the synthetic causal graph.
num_nodes = 10

# DAG over `num_nodes` variables with graph_type="full" (presumably fully
# connected) whose causal order is forced to the identity 0, 1, ..., num_nodes - 1.
graph_generator = GraphGenerator(
    num_nodes=num_nodes,
    seed=12,
    graph_type="full",
    enforce_ordering=list(range(num_nodes)),
)
graph = graph_generator.generate_dag()

# Random generators that drive the synthetic data: Gaussian noise with
# loc=0, scale=1, and link weights drawn from uniform(low=1, high=1). That
# presumably makes every link weight exactly 1 (confirm RandomGenerator's
# 'uniform' semantics).
gaussian_noise_generator = RandomGenerator('normal', seed=30, loc=0, scale=1)
link_generator = RandomGenerator('uniform', seed=1100, low=1, high=1)

# Parametric affine dataset with the "sinusoid" link (author's note:
# sin(x) + x non-linearity and softplus); samples are left unnormalised.
dset_sinusoidal = AffineParametericDataset(
    graph=graph,
    num_samples=num_samples,
    link="sinusoid",
    link_generator=link_generator,
    noise_generator=gaussian_noise_generator,
    perform_normalization=False,
)
class CustomTensorDataset(torch.utils.data.Dataset):
    r"""Map-style dataset over the rows of a single tensor.

    Unlike ``torch.utils.data.TensorDataset`` (whose docstring this class used
    to copy, wrongly advertising ``*tensors``), it wraps exactly one tensor and
    returns the bare row instead of a 1-tuple. A ``DataLoader`` with the
    default collate function therefore yields plain batched tensors.

    Args:
        tensor (Tensor): data to serve; sample ``i`` is ``tensor[i]``, i.e.
            samples are indexed along the first dimension.
    """

    def __init__(self, tensor: torch.Tensor) -> None:
        self.tensor = tensor

    def __getitem__(self, index):
        """Return ``tensor[index]`` (the sample(s) at ``index`` along dim 0)."""
        return self.tensor[index]

    def __len__(self):
        """Return the number of samples, i.e. the size of the first dimension."""
        return len(self.tensor)


# Wrap the generated samples as float32 rows; the flow and permutation
# loaders shuffle independently over this same dataset.
dataset = CustomTensorDataset(
    torch.tensor(dset_sinusoidal.samples.values, dtype=torch.float32)
)
flow_dataloader = DataLoader(dataset, shuffle=True, batch_size=flow_batch_size)
permutation_dataloader = DataLoader(dataset, shuffle=True, batch_size=permutation_batch_size)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
cuda:2


  return np.where(x > threshold, x, np.log(1 + np.exp(x)))
  return np.where(x > threshold, x, np.log(1 + np.exp(x)))
  return np.where(x > threshold, x, np.log(1 + np.exp(x)))
  return np.where(x > threshold, x, np.log(1 + np.exp(x)))


In [5]:
import wandb


def flow_optimizer(params):
    """Optimizer factory handed to the Trainer for the flow parameters."""
    return torch.optim.Adam(params, lr=flow_lr)


def perm_optimizer(params):
    """Optimizer factory handed to the Trainer for the permutation parameters."""
    return torch.optim.Adam(params, lr=perm_lr)


permutation_learning_config = SoftSortConfig(temp=0.1, iters=20)
# Alternative tried earlier (class not imported in this notebook):
# GumbelSinkhornStraightThroughConfig(temp=0.1, iters=20)

# Train one model per temperature schedule, each logged to its own wandb run.
for temperature_scheduler in ['linear', 'constant']:
    temperature = 1.

    # Build a fresh, identically-seeded model for every run. Previously the
    # model was created once before the loop, so the 'constant' run started
    # from the weights already trained by the 'linear' run and the two
    # schedules were not comparable.
    torch.random.manual_seed(42)
    model = OSlowTest(
        in_features=num_nodes,
        base_matrix=torch.eye(num_nodes),
    )

    # Birkhoff diagnostics are only enabled for graphs with at most 4 nodes.
    birkhoff_config = None if num_nodes > 4 else BirkhoffConfig(
        num_samples=100, frequency=1, print_legend=False)
    # NOTE(review): torch's ConstantLR defaults to factor=1/3, total_iters=5,
    # i.e. it *lowers* the LR for the first 5 scheduler steps. If Trainer
    # instantiates these classes with default arguments, that applies here;
    # confirm this is intended.
    trainer = Trainer(model=model,
                      dag=graph,
                      flow_dataloader=flow_dataloader,
                      perm_dataloader=permutation_dataloader,
                      flow_optimizer=flow_optimizer,
                      permutation_optimizer=perm_optimizer,
                      flow_frequency=flow_freq,
                      temperature=temperature,
                      temperature_scheduler=temperature_scheduler,
                      permutation_frequency=perm_freq,
                      max_epochs=epochs,
                      flow_lr_scheduler=torch.optim.lr_scheduler.ConstantLR,
                      permutation_lr_scheduler=torch.optim.lr_scheduler.ConstantLR,
                      permutation_learning_config=permutation_learning_config,
                      birkhoff_config=birkhoff_config,
                      device=device)
    wandb.init(project="notebooks", entity="ordered-causal-discovery",
               tags=[
                   permutation_learning_config.method,
                   f"num_nodes-{num_nodes}",
                   f"epochs-{epochs}",
                   f"base-temperature-{temperature}",
                   f"temperature-scheduling-{temperature_scheduler}",
                   "no-sigmoid",
               ],)
    try:
        trainer.train()
    finally:
        # Close the run even if training raises, so a failed run is not left
        # open when the next iteration calls wandb.init.
        wandb.finish()



VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
flow/loss,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
flow/step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
permutation/backward_penalty,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
permutation/loss,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
permutation/step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
permutation/temperature,███▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁

0,1
epoch,4000.0
flow/loss,-10.0
flow/step,4001.0
permutation/backward_penalty,0.0
permutation/loss,-10.0
permutation/step,16004.0
permutation/temperature,0.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112366689162122, max=1.0…

In [None]:
import wandb

# NOTE(review): scratch cell. It reuses `trainer`, `temperature`,
# `temperature_scheduler` and `permutation_learning_config` left over from the
# training loop, so it presumably continues training the *last*, already-trained
# model under a new wandb run rather than starting fresh. It also fails on a
# fresh kernel unless that loop has run. Consider deleting it. At minimum it now
# closes its wandb run, even when training raises.
wandb.init(project="notebooks", entity="ordered-causal-discovery",
           tags=[
               permutation_learning_config.method,
               f"num_nodes-{num_nodes}",
               f"epochs-{epochs}",
               f"base-temperature-{temperature}",
               f"temperature-scheduling-{temperature_scheduler}",
               "no-sigmoid",
           ],)
try:
    trainer.train()
finally:
    wandb.finish()

2024-02-07 18:02:14 ERROR    Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhamidrezakamkari[0m ([33mordered-causal-discovery[0m). Use [1m`wandb login --relogin`[0m to force relogin
