In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from tensorflow_probability.substrates import jax as tfp

# Loading distributions from TensorFlow Probability (JAX version)
tfd = tfp.distributions


def Normalize(num_groups=10):
    return nn.GroupNorm(num_groups=num_groups, epsilon=1e-6, use_scale=True)


class Downsample(nn.Module):
    in_channels: int

    def setup(self):
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=((0, 1), (0, 1)),
        )

    def __call__(self, x):
        pad = ((0, 0), (0, 1), (0, 0), (0, 1))
        x = jnp.pad(x, pad, mode="constant", constant_values=0)
        x = self.conv(x)
        return x


class ResnetBlock(nn.Module):
    in_channels: int
    out_channels: int

    def setup(self):
        self.norm1 = Normalize(num_groups=5)
        self.conv1 = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
        )
        self.norm2 = Normalize(num_groups=5)
        self.conv2 = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
        )

        self.nin_shortcut = nn.Conv(
            self.out_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding=((0, 0), (0, 0)),
        )

    def __call__(self, x):
        h = x
        h = self.norm1(h)
        h = nn.swish(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nn.swish(h)
        h = self.conv2(h)
        x = self.nin_shortcut(x)
        x_ = x + h

        return x + h


class DownsamplingBlock(nn.Module):
    ch: int
    ch_mult: tuple
    num_res_blocks: int
    resolution: int
    block_idx: int

    def setup(self):
        self.ch_mult_ = self.ch_mult
        self.num_resolutions = len(self.ch_mult_)
        in_ch_mult = (1,) + tuple(self.ch_mult_)
        block_in = self.ch * in_ch_mult[self.block_idx]
        block_out = self.ch * self.ch_mult_[self.block_idx]

        res_blocks = []
        for _ in range(self.num_res_blocks):
            res_blocks.append(
                ResnetBlock(
                    block_in,
                    block_out,
                )
            )
        block_in = block_out
        self.block = res_blocks

        self.downsample = None
        if self.block_idx != self.num_resolutions - 1:
            self.downsample = Downsample(block_in)

    def __call__(self, h):
        for i, res_block in enumerate(self.block):
            h = res_block(h)

        if self.downsample is not None:
            h = self.downsample(h)

        return h


class MidBlock(nn.Module):
    in_channels: int

    def setup(self):
        self.block_1 = ResnetBlock(
            self.in_channels,
            self.in_channels,
        )
        self.block_2 = ResnetBlock(
            self.in_channels,
            self.in_channels,
        )

    def __call__(self, h):
        h = self.block_1(h)
        h = self.block_2(h)

        return h


class Encoder(nn.Module):
    ch: int
    out_ch: int
    ch_mult: tuple
    num_res_blocks: int
    in_channels: int
    resolution: int
    z_channels: int
    double_z: bool

    def setup(self):
        self.num_resolutions = len(self.ch_mult)

        # downsampling
        self.conv_in = nn.Conv(
            self.ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
        )

        curr_res = self.resolution
        downsample_blocks = []

        for i_level in range(self.num_resolutions):
            downsample_blocks.append(
                DownsamplingBlock(
                    ch=self.ch,
                    ch_mult=self.ch_mult,
                    num_res_blocks=self.num_res_blocks,
                    resolution=self.resolution,
                    block_idx=i_level,
                )
            )
            if i_level != self.num_resolutions - 1:
                curr_res = curr_res // 2

        self.down = downsample_blocks

        # middle
        mid_channels = self.ch * self.ch_mult[-1]
        self.mid = MidBlock(mid_channels)
        # end
        self.norm_out = Normalize()
        self.conv_out = nn.Conv(
            self.z_channels * 2 if self.double_z else self.z_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
        )

    def __call__(self, x):
        # downsampling
        print("x :", x.shape)
        hs = self.conv_in(x)
        print("Conv_in :", hs.shape)
        for block in self.down:
            hs = block(hs)
        print("Down :", hs.shape)

        # middle
        hs = self.mid(hs)
        print("Mid :", hs.shape)

        # end
        hs = self.norm_out(hs)
        hs = nn.swish(hs)
        hs = self.conv_out(hs)
        print("Conv_out :", hs.shape)
        # Gaussian Distribution
        # q = tfd.MultivariateNormalDiag(
        #     loc=hs[..., : self.z_channels], scale_diag=hs[..., self.z_channels :]
        # )
        # print("Gauss distribution :", q)
        return hs



In [2]:
# Example usage
input_shape = (1, 64, 64, 5)  # (batch_size, height, width, channels)

# Create an instance of Downsample
downsample = Downsample(in_channels=5)

# Generate random input
rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, input_shape)

# Initialize the module and apply it to the input
params = downsample.init(rng, x)
output = downsample.apply(params, x)

# Print the output shape
print(output.shape)

print(Downsample(in_channels=5).tabulate(rng, x))

(1, 32, 32, 5)

[3m                               Downsample Summary                               [0m
┏━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath[0m[1m [0m┃[1m [0m[1mmodule    [0m[1m [0m┃[1m [0m[1minputs           [0m[1m [0m┃[1m [0m[1moutputs         [0m[1m [0m┃[1m [0m[1mparams           [0m[1m [0m┃
┡━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│      │ Downsample │ [2mfloat32[0m[1,64,64,… │ [2mfloat32[0m[1,32,32… │                   │
├──────┼────────────┼───────────────────┼──────────────────┼───────────────────┤
│ conv │ Conv       │ [2mfloat32[0m[1,65,64,… │ [2mfloat32[0m[1,32,32… │ bias: [2mfloat32[0m[5]  │
│      │            │                   │                  │ kernel:           │
│      │            │                   │                  │ [2mfloat32[0m[3,3,6,5]  │
│      │            │                   │                  │                   

In [3]:
from clu import parameter_overview

# Generate random input tensors
x = jnp.zeros((1, 64, 64, 5))
rng = jax.random.PRNGKey(0)

# Initialize DownsamplingBlock
down_block = DownsamplingBlock(
    ch=5,
    ch_mult=(1, 2, 4),
    num_res_blocks=2,
    resolution=64,
    block_idx=1,
)

params = down_block.init(rng, x)
# Apply DownsamplingBlock using the apply method
output = down_block.apply({"params": params["params"]}, x)
print(output.shape)

print(down_block.tabulate(rng, x))

print(parameter_overview.get_parameter_overview(params))

(1, 32, 32, 10)

[3m                           DownsamplingBlock Summary                            [0m
┏━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath         [0m[1m [0m┃[1m [0m[1mmodule       [0m[1m [0m┃[1m [0m[1minputs       [0m[1m [0m┃[1m [0m[1moutputs      [0m[1m [0m┃[1m [0m[1mparams      [0m[1m [0m┃
┡━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│               │ Downsampling… │ [2mfloat32[0m[1,64… │ [2mfloat32[0m[1,32… │              │
├───────────────┼───────────────┼───────────────┼───────────────┼──────────────┤
│ block_0       │ ResnetBlock   │ [2mfloat32[0m[1,64… │ [2mfloat32[0m[1,64… │              │
├───────────────┼───────────────┼───────────────┼───────────────┼──────────────┤
│ block_0/norm1 │ GroupNorm     │ [2mfloat32[0m[1,64… │ [2mfloat32[0m[1,64… │ bias:        │
│               │               │               │               │ [2mfloat32

In [4]:
# Generate random input tensors
x = jnp.zeros((1, 64, 64, 5))
rng = jax.random.PRNGKey(0)

# Initialize ResnetBlock
resnet_block = ResnetBlock(
    in_channels=5,
    out_channels=5,
)

params = resnet_block.init(rng, x)
# Apply ResnetBlock using the apply method
output = resnet_block.apply({"params": params["params"]}, x)
print(output.shape)

print(resnet_block.tabulate(rng, x))

print(parameter_overview.get_parameter_overview(params))

(1, 64, 64, 5)

[3m                              ResnetBlock Summary                               [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath        [0m[1m [0m┃[1m [0m[1mmodule     [0m[1m [0m┃[1m [0m[1minputs        [0m[1m [0m┃[1m [0m[1moutputs       [0m[1m [0m┃[1m [0m[1mparams       [0m[1m [0m┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│              │ ResnetBlock │ [2mfloat32[0m[1,64,… │ [2mfloat32[0m[1,64,… │               │
├──────────────┼─────────────┼────────────────┼────────────────┼───────────────┤
│ norm1        │ GroupNorm   │ [2mfloat32[0m[1,64,… │ [2mfloat32[0m[1,64,… │ bias:         │
│              │             │                │                │ [2mfloat32[0m[5]    │
│              │             │                │                │ scale:        │
│              │             │                │                │ [2mfloat32[0m[5]    

In [5]:
# Example usage
rng = jax.random.PRNGKey(0)
input_shape = (1, 64, 64, 5)

# Initialize Encoder
encoder = Encoder(
    ch_mult=(1, 2, 4),
    num_res_blocks=2,
    double_z=True,
    z_channels=5,
    resolution=64,
    in_channels=5,
    out_ch=5,
    ch=5,
)

# Apply Encoder
params = encoder.init(rng, x=jnp.ones(input_shape))
output = encoder.apply(params, x=jnp.ones(input_shape))

# Print the output shape
# print("Encoder output :", output.shape)  # Shape: (1, 16, 16, 5)

x : (1, 64, 64, 5)
Conv_in : (1, 64, 64, 5)
Down : (1, 16, 16, 20)
Mid : (1, 16, 16, 20)
Conv_out : (1, 16, 16, 10)
x : (1, 64, 64, 5)
Conv_in : (1, 64, 64, 5)
Down : (1, 16, 16, 20)
Mid : (1, 16, 16, 20)
Conv_out : (1, 16, 16, 10)


In [6]:
class Upsample(nn.Module):
    in_channels: int

    def setup(self):
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
        )

    def __call__(self, hs):
        batch, height, width, channels = hs.shape
        hs = jax.image.resize(
            hs,
            shape=(batch, height * 2, width * 2, channels),
            method="nearest",
        )
        hs = self.conv(hs)
        return hs


class UpsamplingBlock(nn.Module):
    ch: int
    ch_mult: tuple
    num_res_blocks: int
    resolution: int
    block_idx: int

    def setup(self):
        self.ch_mult_ = self.ch_mult
        self.num_resolutions = len(self.ch_mult_)

        if self.block_idx == self.num_resolutions - 1:
            block_in = self.ch * self.ch_mult_[-1]
        else:
            block_in = self.ch * self.ch_mult_[self.block_idx + 1]

        block_out = self.ch * self.ch_mult_[self.block_idx]

        res_blocks = []
        for _ in range(self.num_res_blocks + 1):
            res_blocks.append(ResnetBlock(block_in, block_out))

        block_in = block_out

        self.block = res_blocks

        self.upsample = None
        if self.block_idx != 0:
            self.upsample = Upsample(block_in)

    def __call__(self, h):
        for i, res_block in enumerate(self.block):
            h = res_block(h)

        if self.upsample is not None:
            h = self.upsample(h)

        return h

In [7]:
from clu import parameter_overview

# Generate random input tensors
x = jnp.zeros((1, 16, 16, 5))
rng = jax.random.PRNGKey(0)

# Initialize UpsamplingBlock
up_block = UpsamplingBlock(
    ch=5,
    ch_mult=(4, 2, 1),
    num_res_blocks=2,
    resolution=64,
    block_idx=2,
)

params = up_block.init(rng, x)
# Apply UpsamplingBlock using the apply method
output = up_block.apply({"params": params["params"]}, x)
print(output.shape)

print(up_block.tabulate(rng, x))

print(parameter_overview.get_parameter_overview(params))

(1, 32, 32, 5)

[3m                            UpsamplingBlock Summary                             [0m
┏━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath         [0m[1m [0m┃[1m [0m[1mmodule       [0m[1m [0m┃[1m [0m[1minputs       [0m[1m [0m┃[1m [0m[1moutputs      [0m[1m [0m┃[1m [0m[1mparams      [0m[1m [0m┃
┡━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│               │ UpsamplingBl… │ [2mfloat32[0m[1,16… │ [2mfloat32[0m[1,32… │              │
├───────────────┼───────────────┼───────────────┼───────────────┼──────────────┤
│ block_0       │ ResnetBlock   │ [2mfloat32[0m[1,16… │ [2mfloat32[0m[1,16… │              │
├───────────────┼───────────────┼───────────────┼───────────────┼──────────────┤
│ block_0/norm1 │ GroupNorm     │ [2mfloat32[0m[1,16… │ [2mfloat32[0m[1,16… │ bias:        │
│               │               │               │               │ [2mfloat32[

In [8]:
import numpy as np


class Decoder(nn.Module):
    ch: int
    out_ch: int
    ch_mult: tuple
    num_res_blocks: int
    in_channels: int
    resolution: int
    z_channels: int
    double_z: bool

    def setup(self):
        self.num_resolutions = len(self.ch_mult)

        block_in = self.ch * self.ch_mult[self.num_resolutions - 1]
        curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, self.z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv(
            self.ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
        )

        print(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )

        # middle
        self.mid = MidBlock(block_in)

        # upsampling
        upsample_blocks = []

        for i_level in reversed(range(self.num_resolutions)):
            upsample_blocks.append(
                UpsamplingBlock(
                    ch=self.ch,
                    ch_mult=self.ch_mult,
                    num_res_blocks=self.num_res_blocks,
                    resolution=self.resolution,
                    block_idx=i_level,
                )
            )
            if i_level != 0:
                curr_res = curr_res * 2
        self.up = list(reversed(upsample_blocks))  # reverse to get consistent order

        # end
        self.norm_out = Normalize(num_groups=5)
        self.conv_out = nn.Conv(
            self.out_ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
        )

    def __call__(self, z):
        # z to block_in
        hs = self.conv_in(z)

        # middle
        hs = self.mid(hs)

        # upsampling
        for block in reversed(self.up):
            hs = block(hs)

        # end
        hs = self.norm_out(hs)
        hs = nn.swish(hs)
        hs = self.conv_out(hs)

        return hs

In [9]:
# Example usage
rng = jax.random.PRNGKey(0)
input_shape = (1, 16, 16, 5)

# Initialize ResnetBlock
decoder = Decoder(
    ch_mult=(1, 2, 4),
    num_res_blocks=2,
    double_z=False,
    z_channels=5,
    resolution=64,
    in_channels=5,
    out_ch=5,
    ch=5,
)

# Apply Resnet
params = decoder.init(rng, z=jnp.ones(input_shape))
output = decoder.apply(params, z=jnp.ones(input_shape))

# Print the output shape
print("Decoder output :", output.shape)  # Shape: (1, 64, 64, 5)

Working with z of shape (1, 5, 16, 16) = 1280 dimensions.
Working with z of shape (1, 5, 16, 16) = 1280 dimensions.
Decoder output : (1, 64, 64, 5)


In [10]:
class AutoencoderKLModule(nn.Module):
    ch: int
    out_ch: int
    ch_mult: tuple
    num_res_blocks: int
    in_channels: int
    resolution: int
    z_channels: int
    double_z: bool
    embed_dim: int

    def setup(self):
        self.encoder = Encoder(
            self.ch,
            self.out_ch,
            self.ch_mult,
            self.num_res_blocks,
            self.in_channels,
            self.resolution,
            self.z_channels,
            self.double_z,
        )
        self.decoder = Decoder(
            self.ch,
            self.out_ch,
            self.ch_mult,
            self.num_res_blocks,
            self.in_channels,
            self.resolution,
            self.z_channels,
            self.double_z,
        )
        self.quant_conv = nn.Conv(
            2 * self.embed_dim,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
        )
        self.post_quant_conv = nn.Conv(
            self.z_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
        )

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        print("Moments shape :", moments.shape)
        posterior = tfd.MultivariateNormalDiag(
            loc=moments[..., : self.z_channels],
            scale_diag=moments[..., self.z_channels :],
        )
        print("Posterior :", posterior)
        # posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, h):
        h = self.post_quant_conv(h)
        h = self.decoder(h)
        # Image is now 64x64x5
        x_ = tfd.MultivariateNormalDiag(
            loc=h, scale_diag=[0.01, 0.01, 0.01, 0.01, 0.01]
        )
        return x_

    def __call__(self, x, seed):
        posterior = self.encode(x)
        # rng = self.make_rng('gaussian')
        # key, subkey = jax.random.split(rng)
        h = posterior.sample(seed=seed)
        x_ = self.decode(h)
        return x_, posterior

In [11]:
# Example usage
rng, rng_2 = jax.random.PRNGKey(0), jax.random.PRNGKey(1)
# rngs = {'params': rng, 'seed': rng_2}
input_shape = (1, 64, 64, 5)

# Initialize ResnetBlock
autoencoder = AutoencoderKLModule(
    ch_mult=(1, 2, 4),
    num_res_blocks=2,
    double_z=True,
    z_channels=5,
    resolution=64,
    in_channels=5,
    out_ch=5,
    ch=5,
    embed_dim=5,
)

# Apply Resnet
params = autoencoder.init(rng, x=jnp.ones(input_shape), seed=rng_2)

# rngs = {'params': params, 'gaussian': jax.random.PRNGKey(1)}
print(rng_2)

rng_1, rng_2 = jax.random.split(rng_2)

# params
output = autoencoder.apply(params, x=jnp.ones(input_shape), seed=rng_2)
print(rng_2)

# Print the output shape
print("Decoder output :", output[0])  # Shape: (1, 64, 64, 5)

x : (1, 64, 64, 5)
Conv_in : (1, 64, 64, 5)
Down : (1, 16, 16, 20)
Mid : (1, 16, 16, 20)
Conv_out : (1, 16, 16, 10)
Moments shape : (1, 16, 16, 10)
Posterior : tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[1, 16, 16], event_shape=[5], dtype=float32)
Working with z of shape (1, 5, 16, 16) = 1280 dimensions.
[0 1]
x : (1, 64, 64, 5)
Conv_in : (1, 64, 64, 5)
Down : (1, 16, 16, 20)
Mid : (1, 16, 16, 20)
Conv_out : (1, 16, 16, 10)
Moments shape : (1, 16, 16, 10)
Posterior : tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[1, 16, 16], event_shape=[5], dtype=float32)
Working with z of shape (1, 5, 16, 16) = 1280 dimensions.
[3819641963 2025898573]
Decoder output : tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[1, 64, 64], event_shape=[5], dtype=float32)
