In [4]:
import normflows as nf

import torch
import numpy as np
from nf_class import SystematicFlow  


# -----------------------------------------------------------------------------
# Dummy Flow Layer (identity transform; subclasses nn.Module)
# -----------------------------------------------------------------------------
import torch.nn as nn

class DummyFlow(nn.Module):
    """Identity flow layer for exercising ``SystematicFlow`` without real transforms.

    Both directions return the input tensor unchanged, together with a
    zero log-determinant (the log-Jacobian of the identity map is 0).
    """

    def __init__(self):
        super().__init__()

    def forward(self, z, context=None):
        """Map ``z`` forward (identity).

        Args:
            z: Input tensor; its first dimension is treated as the batch size.
            context: Ignored.

        Returns:
            Tuple ``(z, log_det)`` where ``z`` is the input object itself and
            ``log_det`` is a zero vector of length ``z.shape[0]``.
        """
        # new_zeros matches both device *and* dtype of the input, so a
        # float64 batch gets a float64 log-det (torch.zeros would not).
        return z, z.new_zeros(z.shape[0])

    def inverse(self, x, context=None):
        """Map ``x`` back (identity).

        Args:
            x: Input tensor; its first dimension is treated as the batch size.
            context: Ignored.

        Returns:
            Tuple ``(x, log_det)`` where ``x`` is the input object itself and
            ``log_det`` is a zero vector of length ``x.shape[0]`` with the
            same dtype and device as ``x``.
        """
        return x, x.new_zeros(x.shape[0])

# -----------------------------------------------------------------------------
# Dummy Base Distribution (q0)
# -----------------------------------------------------------------------------
class DummyBaseDistribution:
    """Stand-in base distribution ``q0`` for testing ``SystematicFlow``.

    Samples come from a standard normal, but every reported log-probability
    is zero. The reported ``log_q`` therefore does NOT match the sampling
    density; it only keeps downstream log-prob arithmetic trivial.
    """

    def __init__(self, shape=(2,)):
        # Per-sample event shape; samples have shape (num_samples, *shape).
        self.shape = shape

    def __call__(self, num_samples=1, context=None):
        """Draw ``num_samples`` standard-normal samples.

        Args:
            num_samples: Number of samples (leading dimension of the output).
            context: Ignored.

        Returns:
            Tuple ``(z, log_q)``: ``z`` has shape ``(num_samples, *self.shape)``;
            ``log_q`` is a zero vector of length ``num_samples`` with the same
            dtype and device as ``z``.
        """
        z = torch.randn(num_samples, *self.shape)
        log_q = z.new_zeros(num_samples)
        return z, log_q

    def log_prob(self, z, context=None):
        """Return a zero log-probability for each row of ``z``.

        Args:
            z: Input tensor; its first dimension is treated as the batch size.
            context: Ignored.

        Returns:
            Zero vector of length ``z.shape[0]``. ``new_zeros`` keeps it on
            ``z``'s dtype and device, so float64 inputs give float64 output.
        """
        return z.new_zeros(z.shape[0])

# -----------------------------------------------------------------------------
# Dummy Target Distribution (p)
# -----------------------------------------------------------------------------
class DummyTargetDistribution:
    """Stand-in target distribution ``p`` for testing ``SystematicFlow``.

    ``log_prob`` ignores its arguments and always returns a batch of four
    dummy items.
    """

    def log_prob(self, i, pedestal=None):
        """Return a dummy ``(z, context, log_p)`` triple.

        Args:
            i: Unused.
            pedestal: Unused.

        Returns:
            ``z`` -- standard-normal tensor of shape ``(4, 2)``;
            ``context`` -- zeros of shape ``(4, 1)``;
            ``log_p`` -- zeros of shape ``(4,)``.
        """
        batch = 4
        # Draw the random part first so the global RNG stream is consumed
        # exactly as before; the zero tensors use no randomness.
        draws = torch.randn(batch, 2)
        ctx = torch.zeros(batch, 1)
        log_density = torch.zeros(batch)
        return draws, ctx, log_density

# -----------------------------------------------------------------------------
# Testing SystematicFlow
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: build a SystematicFlow from identity DummyFlows and the
    # dummy distributions above, call each public method once, and print
    # the results. Nothing is asserted; the output is inspected by eye.
    # NOTE(review): no RNG seed is set, so the printed values change per run.
    # NOTE(review): when run as a notebook cell, every name bound here
    # (flow, z, x, log_det, log_prob, ...) stays in the global namespace.
    # Define dummy components
    flows = [DummyFlow(), DummyFlow()]  # Two dummy flows
    q0 = DummyBaseDistribution(shape=(2,))
    p = DummyTargetDistribution()

    # Create an instance of SystematicFlow
    flow = SystematicFlow(flows=flows, q0=q0, p=p)

    # Test forward pass
    z = torch.randn(5, 2)  # Latent variable
    x = flow.forward(z)
    print("Forward pass output:", x)

    # Test forward and log determinant
    # Reuses the same z, so x here is comparable with the previous print.
    x, log_det = flow.forward_and_log_det(z)
    print("Forward and log determinant output:", x, log_det)

    # Test inverse pass
    # Fresh random batch; note z is rebound to the inverse result.
    x_rand = torch.randn(5, 2)
    z = flow.inverse(x_rand)
    print("Inverse pass output:", z)

    # Test inverse and log determinant
    z, log_det = flow.inverse_and_log_det(x_rand)
    print("Inverse and log determinant output:", z, log_det)

    # Test sampling
    samples, log_prob = flow.sample(num_samples=3)
    print("Sampled data:", samples)
    print("Log probability of samples:", log_prob)

    # Test log probability
    # Rebinds log_prob (the sample log-probs above are no longer referenced).
    log_prob = flow.log_prob(x_rand)
    print("Log probability of input:", log_prob)

    # Test symmetric KLD
    # NOTE(review): semantics of i are defined in nf_class.SystematicFlow;
    # presumably it is forwarded to p.log_prob(i, ...) -- confirm there.
    sym_kld = flow.symmetric_kld(i=0)
    print("Symmetric KLD:", sym_kld)



Forward pass output: tensor([[ 0.1098, -1.7427],
        [-0.0491, -1.5347],
        [ 0.4971,  0.9044],
        [-0.6229, -0.3093],
        [ 2.0344, -0.3587]])
Forward and log determinant output: tensor([[ 0.1098, -1.7427],
        [-0.0491, -1.5347],
        [ 0.4971,  0.9044],
        [-0.6229, -0.3093],
        [ 2.0344, -0.3587]]) tensor([0., 0., 0., 0., 0.])
Inverse pass output: tensor([[ 0.2862, -0.1005],
        [ 0.2021, -0.7505],
        [ 0.3440, -1.2535],
        [-1.9999,  0.6734],
        [ 1.6624, -1.5152]])
Inverse and log determinant output: tensor([[ 0.2862, -0.1005],
        [ 0.2021, -0.7505],
        [ 0.3440, -1.2535],
        [-1.9999,  0.6734],
        [ 1.6624, -1.5152]]) tensor([0., 0., 0., 0., 0.])
Sampled data: tensor([[-0.2932, -0.5512],
        [ 0.9578,  1.5651],
        [ 0.4992,  0.1731]])
Log probability of samples: tensor([0., 0., 0.])
Log probability of input: tensor([0., 0., 0., 0., 0.])
Symmetric KLD: tensor(0.)
