In [17]:
from typing import Callable, List, Tuple

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions.distribution import Distribution

# coupling layer
based on

https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/09-normalizing-flows.html (VPN)

https://sebastiancallh.github.io/post/affine-normalizing-flows/

https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial11/NF_image_modeling.html

In [18]:
class AffineCouplingLayer(nn.Module):
    def __init__(
        self,
        theta: nn.Module,
        split: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
        swap: int
    ):
        super().__init__()
        self.theta = theta
        self.split = split
        self.swap=swap

    def f(self, x: torch.Tensor) -> torch.Tensor:
        """f : x -> z. The inverse of g."""
        x2, x1 = self.split(x,self.swap)
        t, s = self.theta(x1)
        z1, z2 = x1, x2 * torch.exp(s) + t 
        """
        z1 = x1
        z2 = x2 * torch.exp(s) + t
        """
        log_det = s.sum(-1) #-1 ???
        return torch.cat((z1, z2), dim=-1), log_det

    def g(self, z: torch.Tensor) -> torch.Tensor:
        """g : z -> x. The inverse of f."""
        z1, z2 = self.split(z,self.swap)
        t, s = self.theta(z1)
        x1, x2 = z1, (z2 - t) * torch.exp(-s)
        return torch.cat((x2, x1), dim=-1)

In [19]:
class NormalizingFlow(nn.Module):
    def __init__(self, latent: Distribution, flows: List[nn.Module]):
        super().__init__()
        self.latent = latent
        self.flows = flows

    def latent_log_prob(self, z: torch.Tensor) -> torch.Tensor:
        return torch.sum(self.latent.log_prob(z))

    def latent_sample(self, num_samples: int = 1) -> torch.Tensor:
        return self.latent.sample((num_samples,))                  

    def sample(self, num_samples: int = 1) -> torch.Tensor:
        """Sample a new observation x by sampling z from
        the latent distribution and pass through g."""
        return self.g(self.latent_sample(num_samples)) 
    

    def f(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:  #forward
        """Maps observation x to latent variable z.
        Additionally, computes the log determinant
        of the Jacobian for this transformation.
        Inveres of g."""
        z, sum_log_abs_det = x, torch.ones(x.size(0)).to(x.device)
        for flow in self.flows:
            z, log_abs_det = flow.f(z)
            sum_log_abs_det += log_abs_det

        return z, sum_log_abs_det

    def g(self, z: torch.Tensor) -> torch.Tensor:
        """Maps latent variable z to observation x.
        Inverse of f."""
        with torch.no_grad():
            x = z
            for flow in reversed(self.flows):
                x = flow.g(x)

            return x

    def g_steps(self, z: torch.Tensor) -> List[torch.Tensor]:
        """Maps latent variable z to observation x
        and stores intermediate results."""
        xs = [z]
        for flow in reversed(self.flows):
            xs.append(flow.g(xs[-1]))

        return xs

    def log_prob(self, x: torch.Tensor) -> torch.Tensor:
        """Computes log p(x) using the change of variable formula."""
        z, log_abs_det = self.f(x)
        return self.latent_log_prob(z) + log_abs_det

    def __len__(self) -> int:
        return len(self.flows)

In [20]:
class ThetaNetwork(nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_hidden: int,
        hidden_dim: int,
        num_params: int,
    ):
        super().__init__()
        self.input = nn.Linear(in_dim, hidden_dim)
        self.hidden = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_hidden)]
        )

        self.num_params = num_params
        self.out_dim = out_dim
        self.dims = nn.Linear(hidden_dim, out_dim * num_params)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.leaky_relu(self.input(x))
        for h in self.hidden:
            x = F.leaky_relu(h(x))

        batch_params = self.dims(x).reshape(x.size(0), self.out_dim, -1) #???
        params = batch_params.chunk(self.num_params, dim=-1) #???
        return [p.squeeze(-1) for p in params]

In [21]:
def SplitFunc(x: torch.Tensor,swap: int) -> Tuple[torch.Tensor, torch.Tensor]:
    if swap==0:
        return x[:,::2], x[:,1::2]
    else: 
        return x[:,1::2], x[:,::2] 

In [22]:
from constants import N_nod
def configure_theta():
    theta=ThetaNetwork(
                in_dim = N_nod//2,
                out_dim = N_nod//2,
                num_hidden = 4,  #2 to 6
                hidden_dim =100 , #100-1024
                num_params = 2)
    return theta
def configure_flows(n_flows):  # n_flows=8,...,12
    flows=[]
    for i in range(n_flows):
        flows.append(AffineCouplingLayer(configure_theta(),split=SplitFunc,swap=i%2))
    flows = nn.ModuleList(flows)
    return flows 
nf=NormalizingFlow(latent=torch.distributions.Normal(loc=0.0, scale=1.),
                  flows=configure_flows(8))
                  

In [26]:
class Pipeline(pl.LightningModule):
    def __init__(
        self,
        model,
        criterion,
        optimizer_class=torch.optim.Adam,
        optimizer_kwargs={"lr": 0.001},
    ) -> None:
        super().__init__()
        self.model = model
        self.loss = criterion
        self.optimizer_class = optimizer_class
        self.optimizer_kwargs = optimizer_kwargs


    def configure_optimizers(self):
        optimizer = self.optimizer_class(
            self.model.parameters(), **self.optimizer_kwargs
        )
        return optimizer

    def training_step(self, batch, batch_idx):
        print("step")
        z = batch
        x = self.model.g(z)
        print("x:",x.shape)
        LP=self.model.log_prob(x)
        loss = self.loss(x,LP)
        self.log('train_loss', loss)
        return loss


    def on_train_epoch_end(self):
        print("epoch end")
        pass
    

In [27]:
from LOSS import KL_osc
from Data import train_loader
from pytorch_lightning.loggers import TensorBoardLogger
pipeline=Pipeline(model=nf,criterion=KL_osc)
trainer = pl.Trainer(
    max_epochs=5,
    logger=TensorBoardLogger(save_dir=f"logs/nf"),
    num_sanity_val_steps=0,
)

trainer.fit(model=pipeline, train_dataloaders=train_loader)
torch.save(nf.state_dict(), "model_weights1.pth")

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name  | Type            | Params
------------------------------------------
0 | model | NormalizingFlow | 336 K 
1 | loss  | KL_with_S       | 0     
------------------------------------------
336 K     Trainable params
0         Non-trainable params
336 K     Total params
1.344     Total estimated model params size (MB)
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"


A Jupyter Widget

step
x: torch.Size([256, 10])
epoch end
step
x: torch.Size([256, 10])
epoch end
step
x: torch.Size([256, 10])
epoch end
step
x: torch.Size([256, 10])
epoch end
step
x: torch.Size([256, 10])
epoch end
