In [9]:
import sys
sys.path.append('../')
import torch
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt
from notebooks.notebook_setup import device, smooth_graph, create_new_set_of_models, train_models_and_get_histories, update_dict
from oslow.models.oslow import OSlowTest
from oslow.data.synthetic.graph_generator import GraphGenerator
from oslow.data.synthetic.utils import RandomGenerator
from oslow.data.synthetic.parametric import AffineParametericDataset
from oslow.data.synthetic.nonparametric import AffineNonParametericDataset
from oslow.models.normalization import ActNorm
from oslow.training.trainer import Trainer
from oslow.training.permutation_learning.buffered_methods import GumbelTopK
from tqdm import tqdm
import numpy as np

%load_ext autoreload
%autoreload 2
device = 'cuda:0'
print(device)

num_samples = 1000
permutation_batch_size = 128
flow_batch_size = 128
epochs = 500
flow_lr = 0.001
perm_lr = 0.00001
flow_freq = 1
perm_freq = 0
num_nodes = 4

graph_generator = GraphGenerator(
    num_nodes=num_nodes,
    seed=12,
    graph_type="full",
    enforce_ordering=[i for i in range(num_nodes)],
)

# These generators are also needed to generate the data
gaussian_noise_generator = RandomGenerator('normal', seed=30, loc=0, scale=1)
link_generator = RandomGenerator('uniform', seed=1100, low=1, high=1)

# parameteric with sin(x) + x non-linearity and softplus
dset_sinusoidal = AffineParametericDataset(
    num_samples=num_samples,
    graph_generator=graph_generator,
    noise_generator=gaussian_noise_generator,
    link_generator=link_generator,
    link="sinusoid",
    perform_normalization=False,
    standard=True,
)
class CustomTensorDataset(torch.utils.data.Dataset):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    def __init__(self, tensor: torch.Tensor) -> None:
        self.tensor = tensor

    def __getitem__(self, index):
        return self.tensor[index]

    def __len__(self):
        return len(self.tensor)


dataset = CustomTensorDataset(torch.tensor(dset_sinusoidal.samples.values).float())
flow_dataloader = DataLoader(dataset, batch_size=flow_batch_size, shuffle=True)
permutation_dataloader = DataLoader(dataset, batch_size=permutation_batch_size, shuffle=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
cuda:0


In [10]:
%autoreload 2

from oslow.visualization.birkhoff import get_all_permutation_matrices
from oslow.models.oslow import OSlow
import wandb
from oslow.training.permutation_learning.initialization import uniform_gamma_init_func


torch.random.manual_seed(101)
model = OSlow(in_features=num_nodes,
              layers=[128, 64, 128],
              dropout=None,
              residual=False,
              activation=torch.nn.LeakyReLU(),
              additive=False,
              num_transforms=1,
              normalization=ActNorm,
              base_distribution=torch.distributions.Normal(loc=0, scale=1),
              ordering=None)


def flow_optimizer(params): return torch.optim.AdamW(params, lr=flow_lr)
def perm_optimizer(params): return torch.optim.AdamW(params, lr=perm_lr)


# permutation_learning_config = GumbelTopKConfig(
#     num_samples=num_samples,
#     buffer_size=10,
#     buffer_update=10,
#     set_gamma_uniform=True,
# )


# permutation_learning_config = GumbelSinkhornStraightThroughConfig(temp=0.1, iters=20)
temperature_scheduler = 'constant'
temperature = 1.0

perm_module = lambda in_features: GumbelTopK(in_features, num_samples=num_samples, buffer_size=10, buffer_update=10, initialization_function=uniform_gamma_init_func)

birkhoff_config = None 
trainer = Trainer(model=model,
                  dag=dset_sinusoidal.dag,
                  flow_dataloader=flow_dataloader,
                  perm_dataloader=permutation_dataloader,
                  flow_optimizer=flow_optimizer,
                  permutation_optimizer=perm_optimizer,
                  flow_frequency=flow_freq,
                  temperature=temperature,
                  temperature_scheduler=temperature_scheduler,
                  permutation_frequency=perm_freq,
                  max_epochs=epochs,
                  flow_lr_scheduler=torch.optim.lr_scheduler.ConstantLR,
                  permutation_lr_scheduler=torch.optim.lr_scheduler.ConstantLR,
                  permutation_learning_module=perm_module,
                  device=device)
wandb.init(project="notebooks", entity="ordered-causal-discovery",
           name='oslow-ensemble',
            tags=[
                f"num_nodes-{num_nodes}",
                f"epochs-{epochs}",
                f"base-temperature-{temperature}",
                f"temperature-scheduling-{temperature_scheduler}",
                "no-sigmoid",
            ],)
trainer.train()
wandb.finish()

ensemble_model = trainer.model

No final phase!




0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
evaluation/avg_backward_penalty,▆▅▁▃▅▅▅▅▃▄▃▅▅▃▆▆▄▅▃▃▆▃▆▆▅▆▅▇▄▃▅▃▅▆▆▅▆▃█▆
evaluation/best_backward_penalty,▁▂▄▇▇▄▇▁▅▇▄▄▄▄▂▇▁▅▂▄▇▂▄▇▄▇▂▁▇▅▅▂▂▇▇▇▁▅█▅
flow_ensemble/loss,█▇▃▇▄▄▅▃▃▃▃▅▄▄▆▃▅▄▃▃▄▄▂▃▃▄▂▅▃▄▄▂▄▄▃▃▁▅▁▆
flow_ensemble/step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
permutation/temperature,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,499.0
evaluation/avg_backward_penalty,0.0
evaluation/best_backward_penalty,0.0
flow_ensemble/loss,4.55708
flow_ensemble/step,4000.0
permutation/temperature,1.0


In [11]:
# %autoreload 2

from oslow.visualization.birkhoff import get_all_permutation_matrices
from oslow.evaluation import backward_relative_penalty

# model_fixed = {}
# ensemble_model = trainer.model
# # iterate over all the 24 permutations of [0, 1, 2, 3]
# for perm in get_all_permutation_matrices(num_nodes):
#     perm_list = torch.argmax(perm, dim=-1).cpu().numpy().tolist()
#     perm_list_name = ''.join([str(i) for i in perm_list])

#     torch.random.manual_seed(101)
#     model_fixed[perm_list_name] = OSlow(in_features=num_nodes,
#                 layers=[128, 64, 128],
#                 dropout=None,
#                 residual=False,
#                 activation=torch.nn.LeakyReLU(),
#                 additive=False,
#                 num_transforms=1,
#                 normalization=ActNorm,
#                 base_distribution=torch.distributions.Normal(loc=0, scale=1),
#                 ordering=None)


#     def flow_optimizer(params): return torch.optim.AdamW(params, lr=flow_lr)
#     def perm_optimizer(params): return torch.optim.AdamW(params, lr=perm_lr)


#     permutation_learning_config = GumbelTopKConfig(
#         num_samples=num_samples,
#         buffer_size=10,
#         buffer_update=10,
#         set_gamma_custom=[
#             perm_list
#         ]
#     )


#     # permutation_learning_config = GumbelSinkhornStraightThroughConfig(temp=0.1, iters=20)
#     temperature_scheduler = 'constant'
#     temperature = 0.00000001

#     birkhoff_config = None
#     trainer = Trainer(model=model_fixed[perm_list_name],
#                     dag=graph,
#                     flow_dataloader=flow_dataloader,
#                     perm_dataloader=permutation_dataloader,
#                     flow_optimizer=flow_optimizer,
#                     permutation_optimizer=perm_optimizer,
#                     flow_frequency=flow_freq,
#                     temperature=temperature,
#                     temperature_scheduler=temperature_scheduler,
#                     permutation_frequency=perm_freq,
#                     max_epochs=epochs,
#                     flow_lr_scheduler=torch.optim.lr_scheduler.ConstantLR,
#                     permutation_lr_scheduler=torch.optim.lr_scheduler.ConstantLR,
#                     permutation_learning_config=permutation_learning_config,
#                     birkhoff_config=birkhoff_config,
#                     device=device)
#     wandb.init(project="notebooks", entity="ordered-causal-discovery",
#                name=f"perm-{perm_list_name}",
#                 tags=[
#                     permutation_learning_config.method,
#                     f"num_nodes-{num_nodes}",
#                     f"epochs-{epochs}",
#                     f"base-temperature-{temperature}",
#                     f"temperature-scheduling-{temperature_scheduler}",
#                     "no-sigmoid",
#                 ],)
#     trainer.train()
#     wandb.finish()

In [12]:
model.eval()
all_perms = get_all_permutation_matrices(4)

all_scores = []
all_backward = []
with torch.no_grad():
    for perm in all_perms:
        perm = perm.float().to(device)
        all_log_probs = []
        for x in flow_dataloader:
            x = x.to(device)
            all_log_probs.append(
                model.log_prob(x, perm.unsqueeze(0).repeat(x.shape[0], 1, 1))
                .mean()
                .item()
            )
        perm_list = [x for x in torch.argmax(perm, dim=1).cpu().numpy().tolist()]
        perm_list_formatted = "".join([str(x) for x in perm_list])
        score = sum(all_log_probs) / len(all_log_probs)
        backward = backward_relative_penalty(perm_list, dset_sinusoidal.dag)
        all_scores.append(-score)
        all_backward.append(backward)
        print(
            f"Permutation: {perm_list_formatted}\nLog Prob: {score}\nBackward count: {backward}\n"
        )
        print("-----")

Permutation: 0123
Log Prob: -4.12761914730072
Backward count: 0.0

-----
Permutation: 0132
Log Prob: -4.20219761133194
Backward count: 0.16666666666666666

-----
Permutation: 0213
Log Prob: -4.242488324642181
Backward count: 0.16666666666666666

-----
Permutation: 0231
Log Prob: -4.463877320289612
Backward count: 0.3333333333333333

-----
Permutation: 0312
Log Prob: -4.284694015979767
Backward count: 0.3333333333333333

-----
Permutation: 0321
Log Prob: -4.509818196296692
Backward count: 0.5

-----
Permutation: 1023
Log Prob: -4.245047152042389
Backward count: 0.16666666666666666

-----
Permutation: 1032
Log Prob: -4.329052746295929
Backward count: 0.3333333333333333

-----
Permutation: 1203
Log Prob: -4.331453859806061
Backward count: 0.3333333333333333

-----
Permutation: 1230
Log Prob: -4.49785715341568
Backward count: 0.5

-----
Permutation: 1302
Log Prob: -4.390943646430969
Backward count: 0.5

-----
Permutation: 1320
Log Prob: -4.574305713176727
Backward count: 0.6666666666666666

In [13]:
import matplotlib as mpl
import matplotlib.pyplot as plt


def latexify(fig_width, fig_height, font_size=7, legend_size=5, labelsize=7):
    """Set up matplotlib's RC params for LaTeX plotting."""
    params = {
        "backend": "ps",
        "text.latex.preamble": "\\usepackage{amsmath,amsfonts,amssymb,amsthm, mathtools,times}",
        "axes.labelsize": font_size,
        "axes.titlesize": font_size,
        "legend.fontsize": legend_size,
        "xtick.labelsize": labelsize,
        "ytick.labelsize": labelsize,
        "text.usetex": True,
        "figure.figsize": [fig_width, fig_height],
        "font.family": "serif",
        "xtick.minor.size": 0.5,
        "xtick.major.pad": 3,
        "xtick.minor.pad": 3,
        "xtick.major.size": 1,
        "ytick.minor.size": 0.5,
        "ytick.major.pad": 1.5,
        "ytick.major.size": 1,
    }

    mpl.rcParams.update(params)
    plt.rcParams.update(params)


COLORS = {
    "green": "#12f913",
    "blue": "#0000ff",
    "red": "#ff0000",
    "pink": "#fb87c4",
    "black": "#000000",
}

LIGHT_COLORS = {
    "blue": (0.237808, 0.688745, 1.0),
    "red": (1.0, 0.519599, 0.309677),
    "green": (0.0, 0.790412, 0.705117),
    "pink": (0.936386, 0.506537, 0.981107),
    "yellow": (0.686959, 0.690574, 0.0577502),
    "black": "#535154",
}

DARK_COLORS = {
    "green": "#3E9651",
    "red": "#CC2529",
    "blue": "#396AB1",
    "black": "#535154",
}

GOLDEN_RATIO = (np.sqrt(5) - 1.0) / 2

cm = 1 / 2.54
FIG_WIDTH = 17 * cm
FONT_SIZE = 10
LEGEND_SIZE = 8

In [14]:
latexify(
    FIG_WIDTH / 3,
    FIG_WIDTH * 0.25,
    font_size=LEGEND_SIZE,
    legend_size=LEGEND_SIZE,
    labelsize=6,
)
fig, ax = plt.subplots(1, 1, figsize=(FIG_WIDTH / 3, FIG_WIDTH * 0.25))
color_list = ["red", "blue", "green", "pink"]
fill_colors = ["red", "blue", "green", "pink"]
linestyles = ["-"] * 4

ax.scatter(all_backward, all_scores, s=4, marker="*", c=LIGHT_COLORS["blue"])
ax.set_xlabel(r"CBC")
ax.set_ylabel(r"Negative log-likelihood")
# hs = axes[0, 0].get_legend_handles_labels()[0]

# fig.legend(hs, names, loc="upper center", ncol=4)
plt.subplots_adjust(left=0.15, bottom=0.17, right=0.95, top=0.95)

fig.savefig("ensemble_four.pdf")
plt.close()

*c* argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with *x* & *y*.  Please use the *color* keyword-argument or provide a 2D array with a single row if you intend to specify the same RGB or RGBA value for all points.
