# Injective model playground

In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import logging
from pathlib import Path

import lightning as pl
import normflows as nf
import numpy as np
import torch
import torch.nn as nn
from lightning.pytorch.callbacks import ModelCheckpoint

from ciflows.datasets.lightning import MultiDistrDataModule, DatasetName
from ciflows.distributions.linear import ClusteredLinearGaussianDistribution
from ciflows.flows import TwoStageTraining, plCausalInjFlowModel
from ciflows.flows.glow import (GlowBlock, InjectiveGlowBlock, ReshapeFlow,
                                Squeeze)


  from .autonotebook import tqdm as notebook_tqdm


In [18]:
# Image shape fed to the injective flow. get_inj_model reads index 0 as the
# number of channels and index 1 as the (square) spatial size.
input_shape = (3, 64, 64)

def get_inj_model(
    input_shape,
    *,
    n_hidden=32,
    n_glow_blocks=3,
    n_mixing_layers=4,
    n_injective_layers=8,
    dropout_probability=0.2,
    debug=False,
):
    """Build the injective Glow-style flow between a small latent tensor and an image.

    The flow list is assembled starting from the latent side:

    1. ``n_injective_layers`` stages, each consisting of ``n_glow_blocks``
       ``GlowBlock``s followed by one ``InjectiveGlowBlock``. The channel
       bookkeeping doubles ``n_chs`` after every stage. The first two stages
       use ``split_mode="checkerboard"``, the remaining ones ``"channel"``.
    2. ``n_mixing_layers`` stages, each consisting of ``n_glow_blocks``
       ``GlowBlock``s followed by a ``Squeeze``. The bookkeeping divides
       ``n_chs`` by 4 and doubles the spatial size after every stage.

    Parameters
    ----------
    input_shape : tuple of int
        Image shape. Index 0 is read as the number of channels and index 1 as
        the (square) spatial size.
    n_hidden : int, default 32
        ``hidden_channels`` passed to every ``GlowBlock`` / ``InjectiveGlowBlock``.
    n_glow_blocks : int, default 3
        Number of ``GlowBlock``s per stage.
    n_mixing_layers : int, default 4
        Number of ``GlowBlock`` + ``Squeeze`` stages.
    n_injective_layers : int, default 8
        Number of ``GlowBlock`` + ``InjectiveGlowBlock`` stages.
    dropout_probability : float, default 0.2
        Forwarded to every ``GlowBlock``.
    debug : bool, default False
        Forwarded to ``InjectiveGlowBlock`` and enables per-stage channel traces.

    Returns
    -------
    nf.NormalizingFlow
        The assembled flow with a fixed (non-trainable) ``DiagGaussian`` base.
        Two extra attributes are attached: ``output_n_chs`` and
        ``output_latent_size``, the channel count and spatial size of the
        latent representation.

    Raises
    ------
    ValueError
        If a layer count is negative, or if the latent channel count / latent
        spatial size implied by the layer counts is not an integer. Previously
        these cases were silently truncated with ``int()``, which produces a
        flow whose channel bookkeeping no longer matches ``input_shape``.
    """
    if min(n_glow_blocks, n_mixing_layers, n_injective_layers) < 0:
        raise ValueError(
            "n_glow_blocks, n_mixing_layers and n_injective_layers must be non-negative"
        )

    n_channels, img_size = input_shape[0], input_shape[1]

    # Work in exact integer arithmetic and fail loudly when the requested
    # architecture cannot reproduce the input shape.
    total_chs = n_channels * 4**n_mixing_layers
    chs_divisor = 2**n_injective_layers
    size_divisor = 2**n_mixing_layers
    if total_chs % chs_divisor != 0:
        raise ValueError(
            f"n_channels * 4**n_mixing_layers = {total_chs} is not divisible by "
            f"2**n_injective_layers = {chs_divisor}; the latent channel count "
            "would not be an integer"
        )
    if img_size % size_divisor != 0:
        raise ValueError(
            f"image size {img_size} is not divisible by "
            f"2**n_mixing_layers = {size_divisor}"
        )

    # Fixed settings passed straight through to the block constructors.
    use_lu = True
    gamma = 1e-6
    activation = "linear"
    net_actnorm = False

    n_layers = n_mixing_layers + n_injective_layers  # only used by the debug trace

    n_chs = total_chs // chs_divisor
    latent_size = img_size // size_divisor
    init_n_chs, init_latent_size = n_chs, latent_size
    print(
        "Starting at latent representation: ", n_chs, "with latent size: ", latent_size
    )
    q0 = nf.distributions.DiagGaussian(
        (n_chs, latent_size, latent_size), trainable=False
    )

    flows = []
    split_mode = "channel"

    # NOTE: the original code guarded the GlowBlocks with `if i % 1 == 0`, which
    # is always true; its `else` branch (flatten -> autoregressive spline ->
    # unflatten) could never run and has been removed.
    for i in range(n_injective_layers):
        split_mode = "checkerboard" if i <= 1 else "channel"

        flows += [
            GlowBlock(
                channels=n_chs,
                hidden_channels=n_hidden,
                use_lu=use_lu,
                scale=True,
                split_mode=split_mode,
                net_actnorm=net_actnorm,
                dropout_probability=dropout_probability,
            )
            for _ in range(n_glow_blocks)
        ]

        # input to inj flow is what is at the X -> V layer
        flows.append(
            InjectiveGlowBlock(
                channels=n_chs,
                hidden_channels=n_hidden,
                activation=activation,
                scale=True,
                gamma=gamma,
                debug=debug,
                split_mode=split_mode,
                net_actnorm=net_actnorm,
            )
        )
        n_chs *= 2
        if debug:
            print(f"On layer {n_layers - i}, n_chs = {n_chs//2} -> {n_chs}")

    # The mixing stages deliberately keep whatever split_mode the last
    # injective stage used (or "channel" when there are no injective stages).
    for i in range(n_mixing_layers):
        flows += [
            GlowBlock(
                channels=n_chs,
                hidden_channels=n_hidden,
                use_lu=use_lu,
                scale=False,
                split_mode=split_mode,
                dropout_probability=dropout_probability,
            )
            for _ in range(n_glow_blocks)
        ]

        flows.append(Squeeze())
        n_chs //= 4
        latent_size *= 2
        if debug:
            print(f"On layer {n_mixing_layers - i}, n_chs = {n_chs}")

    model = nf.NormalizingFlow(q0=q0, flows=flows)
    # Expose the latent geometry so callers can size the bijective model.
    model.output_n_chs = init_n_chs
    model.output_latent_size = init_latent_size
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(pytorch_total_params)

    return model


def get_bij_model(
    n_chs,
    latent_size,
    adj_mat,
    cluster_sizes,
    intervention_targets,
    confounded_variables,
):
    """Build the bijective flow that acts on the flattened latent representation.

    The flow is a stack of autoregressive rational-quadratic spline layers over
    a vector of length ``n_chs * latent_size * latent_size``, followed by a
    ``ReshapeFlow`` from that flat vector to ``(n_chs, latent_size, latent_size)``
    so the result can be passed to the injective model.

    Parameters
    ----------
    n_chs : int
        Number of channels of the latent representation.
    latent_size : int
        Spatial size (height == width) of the latent representation.
    adj_mat, cluster_sizes, intervention_targets, confounded_variables
        Currently unused. They are kept in the signature for a
        ``ClusteredLinearGaussianDistribution`` base distribution that is not
        enabled; the base distribution is a fixed ``DiagGaussian``.

    Returns
    -------
    nf.NormalizingFlow
        Flow with a non-trainable ``DiagGaussian`` base over the flat latent vector.
    """
    n_spline_layers = 6
    # Settings of each AutoregressiveRationalQuadraticSpline layer.
    net_hidden_layers = 2
    net_hidden_dim = 32

    # Length of the flattened latent vector; shared by the base distribution,
    # every spline layer and the final reshape.
    latent_dim = n_chs * latent_size * latent_size

    print("Starting at latent representation: ", n_chs, latent_size, latent_size)
    q0 = nf.distributions.DiagGaussian((latent_dim,), trainable=False)

    flows = [
        nf.flows.AutoregressiveRationalQuadraticSpline(
            num_input_channels=latent_dim,
            num_blocks=net_hidden_layers,
            num_hidden_channels=net_hidden_dim,
            permute_mask=True,
        )
        for _ in range(n_spline_layers)
    ]

    # maps x to (n_chs * latent_size * latent_size), while v is mapped to (n_chs, latent_size, latent_size)
    flows.append(
        ReshapeFlow(
            shape_in=(latent_dim,),
            shape_out=(n_chs, latent_size, latent_size),
        )
    )
    model = nf.NormalizingFlow(q0=q0, flows=flows)

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(pytorch_total_params)
    return model

In [28]:
# Build the injective flow and read back the latent geometry it starts from;
# the bijective model below is sized from these two values.
inj_model = get_inj_model(input_shape)
init_chs, init_latent_size = inj_model.output_n_chs, inj_model.output_latent_size
print(init_chs, init_latent_size)

Starting at latent representation:  3 with latent size:  4
17151759
3 4


In [21]:
# The graph-related arguments are not needed for this playground, so they are
# passed as None.
bij_model = get_bij_model(
    n_chs=init_chs,
    latent_size=init_latent_size,
    adj_mat=None,
    cluster_sizes=None,
    intervention_targets=None,
    confounded_variables=None,
)

Starting at latent representation:  3 4 4
253344


In [None]:
# Smoke test of the injective flow: every shape is asserted rather than only
# printed, so a regression fails the cell instead of needing to be spotted by eye.

# Sampling from the base distribution should give image-shaped tensors.
test_sample, _ = inj_model.sample(2)
print(test_sample.shape)
assert test_sample.shape == (2, *input_shape)

# inverse: image -> latent
test_tensor = torch.randn(2, *input_shape)
print(test_tensor.shape)
test_sample = inj_model.inverse(test_tensor)
print(test_sample.shape)
assert test_sample.shape == (2, init_chs, init_latent_size, init_latent_size)

# forward: latent -> image
test_latent_tensor = torch.randn(2, init_chs, init_latent_size, init_latent_size)
print(test_latent_tensor.shape)
test_sample = inj_model.forward(test_latent_tensor)
print(test_sample.shape)
assert test_sample.shape == (2, *input_shape)

torch.Size([2, 3, 64, 64])
torch.Size([2, 3, 64, 64])
torch.Size([2, 3, 4, 4])
torch.Size([2, 3, 4, 4])
torch.Size([2, 3, 64, 64])


In [None]:
# Smoke test of the composed map (bijective flow stacked on the injective flow),
# with the expected shapes asserted.

# image -> spatial latent -> flat latent vector
test_tensor = torch.randn(2, *input_shape)
test_sample = bij_model.inverse(inj_model.inverse(test_tensor))
print(test_sample.shape)
assert test_sample.shape == (2, init_chs * init_latent_size * init_latent_size)

# flat latent vector -> spatial latent -> image
test_latent_tensor = torch.randn(2, init_chs * init_latent_size * init_latent_size)
print(test_latent_tensor.shape)
test_sample = inj_model.forward(bij_model.forward(test_latent_tensor))
print(test_sample.shape)
assert test_sample.shape == (2, *input_shape)

torch.Size([2, 48])
torch.Size([2, 48])
torch.Size([2, 3, 64, 64])


In [29]:
# Wrap both flows in the Lightning module. The example input is a pair of
# tensors: a batch of two flat latent vectors and a (2, 1) tensor.
n_chs, latent_size = inj_model.output_n_chs, inj_model.output_latent_size

example_input_array = [
    torch.randn(2, n_chs * latent_size**2),
    torch.randn(2, 1),
]
model = plCausalInjFlowModel(
    example_input_array=example_input_array,
    bij_model=bij_model,
    inj_model=inj_model,
)

/Users/adam2392/miniforge3/envs/ciflows/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'inj_model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['inj_model'])`.
/Users/adam2392/miniforge3/envs/ciflows/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'bij_model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['bij_model'])`.
