In [1]:
import tensorflow as tf
import os
import jax
from jax.lib import xla_bridge

# Restrict this process to the first GPU and configure XLA.
# NOTE(review): these variables are presumably read when the JAX/TF GPU
# backends initialise (first triggered by the get_backend() call below),
# so they must be set before that point — confirm for the installed versions.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# XLA_FLAGS is a single space-separated string. Assigning it twice would keep
# only the last flag, so both flags are joined into one value.
os.environ["XLA_FLAGS"] = " ".join(
    [
        "--xla_gpu_cuda_data_dir=/usr/local/cuda-12.1",
        "--xla_gpu_force_compilation_parallelism=1",
    ]
)

# Checking for GPU access
print("Device: {}".format(xla_bridge.get_backend().platform))

# Checking the GPU available
gpus = jax.devices("gpu")
print("Number of avaliable devices : {}".format(len(gpus)))

# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type="GPU")

2023-09-11 14:29:29.048975: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-09-11 14:29:29.096443: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Device: gpu
Number of avaliable devices : 1


In [2]:
import sys
import tensorflow_datasets as tfds

import numpy as np
import jax.numpy as jnp
import optax
import wandb
import logging

from astropy.stats import mad_std
from tensorflow_probability.substrates import jax as tfp
from flax import linen as nn  # Linen API
from jax import random

from tqdm.auto import tqdm

# Load the training split (plus its metadata) as a tf.data pipeline.
train_dset, info = tfds.load(name='hsc_photoz', with_info=True, split="train")

# Pull the first 1000 examples to look at their distributions.
specz, cutouts = [], []
for entry in train_dset.take(1000):
    specz.append(entry["attrs"]["specz_redshift"])
    cutouts.append(entry["image"])
specz = np.stack(specz)
cutouts = np.stack(cutouts)

# Robust (MAD-based) standard deviation of each band, in g, r, i, z, y order.
# These per-band scales are used by `preprocessing` to normalise the images.
band_names = ["g", "r", "i", "z", "y"]
scaling = [
    mad_std(cutouts[..., band].flatten()) for band in range(len(band_names))
]

# Using a mapping function to apply preprocessing to our data
def preprocessing(example):
    """Asinh-stretch the images of an example using the per-band scales.

    Each band is divided by its entry in the module-level ``scaling`` list and
    then by 3 before applying ``asinh``. Only the image is returned, because a
    generative model uses it as both its input and its target.
    """
    per_band_scale = tf.constant(scaling)
    normalised = example["image"] / per_band_scale / 3.0
    return tf.math.asinh(normalised)

def input_fn(mode="train", batch_size=64):
    """Build the preprocessed tf.data pipeline over the ``hsc_photoz`` train split.

    mode: 'train' or 'test'. 'train' reads the first 80% of the split, repeats
        it indefinitely and shuffles with a 10000-example buffer. Any other value
        reads the remaining 20% once, unshuffled.
    batch_size: images per batch; an incomplete final batch is dropped.
    """
    is_train = mode == "train"
    split = "train[:80%]" if is_train else "train[80%:]"
    dataset = tfds.load('hsc_photoz', split=split)
    if is_train:
        dataset = dataset.repeat().shuffle(10000)

    # Batch first, so `preprocessing` receives whole batches.
    batched = dataset.batch(batch_size, drop_remainder=True)
    preprocessed = batched.map(preprocessing)
    # Fetch the next batches while the current one is consumed (-1 == autotune).
    return preprocessed.prefetch(-1)

# Dataset as a numpy iterator
# (shuffled, infinitely repeated training batches of 64 preprocessed images)
dset = input_fn().as_numpy_iterator()

# Generating a random key for JAX
rng = random.PRNGKey(0)
# Size of the input to initialize the encoder parameters
# Dummy batch: 1 image of 64x64 pixels with 5 bands (channels last).
batch_enc = jnp.ones((1, 64, 64, 5))

# Encoder/decoder hyper-parameters: one entry per group of ResNet blocks.
# The last decoder width (5) matches the 5 image bands.
latent_dim = 64
c_hidden_enc = (64, 128, 256)
num_blocks_enc = (1, 1, 1)
c_hidden_dec = (128, 64, 32, 5)
num_blocks_dec = (1, 1, 1, 1)

# Size of the input to initialize the decoder parameters
# Matches the encoder's latent: a 4x4 spatial grid with latent_dim (64) channels.
batch_dec = jnp.ones((1, 4, 4, 64))

act_fn = nn.leaky_relu

In [3]:
from flax import linen as nn  # Linen API
from tensorflow_probability.substrates import jax as tfp

# Loading distributions and bijectors from TensorFlow Probability (JAX version)
# Short aliases used by the model definitions that follow.
tfd = tfp.distributions
tfb = tfp.bijectors  # NOTE(review): not referenced anywhere else in the visible notebook


class ResNetBlock(nn.Module):
    """Residual block that halves (encode) or doubles (decode) the spatial size.

    Both the residual branch ``z`` and the shortcut ``x`` are a single 3x3,
    stride-2, "SAME"-padded convolution with ``c_out`` filters: ``nn.Conv``
    when encoding and ``nn.ConvTranspose`` when decoding. The block returns
    ``act_fn(act_fn(z) + x)``.

    NOTE: ``subsample`` is accepted for interface compatibility but is NOT
    read. Every block resamples by a factor of 2 regardless of its value.
    Parameter names (``Conv_0``/``Conv_1`` or ``ConvTranspose_0``/
    ``ConvTranspose_1``) are unchanged, so existing checkpoints still load.
    """

    act_fn: callable  # Activation function
    c_out: int  # Output feature size
    subsample: bool = False  # Currently ignored: both paths always use stride 2

    @nn.compact
    def __call__(self, x, encode=True):
        print("Input X Resnet", x.shape)
        # Both paths use the same layer type and hyper-parameters; only the
        # direction (down- vs up-sampling) depends on `encode`.
        conv_cls = nn.Conv if encode else nn.ConvTranspose
        conv_kwargs = dict(kernel_size=(3, 3), padding="SAME", strides=(2, 2))

        # Residual branch F(x). Created first, so it gets the `_0` name.
        z = conv_cls(self.c_out, **conv_kwargs)(x)
        print("Z Resnet", z.shape)
        z = self.act_fn(z)

        # Learned shortcut, so x matches z's shape before the sum.
        x = conv_cls(self.c_out, **conv_kwargs)(x)
        print("Sum X Resnet", x.shape)

        return self.act_fn(z + x)


class ResNetEnc(nn.Module):
    """Small convolutional encoder with ResNet blocks as intermediate layers.

    Pipeline: a stride-2 3x3 stem conv to ``latent_dim`` channels, then one
    group of ``block_class`` instances per entry of ``num_blocks`` (group ``g``
    has width ``c_hidden[g]``), then a per-position Dense layer to
    ``2 * latent_dim`` features. The Dense output is split in half into the
    loc and scale of a diagonal Gaussian, which is returned. ``encode`` is
    accepted but not read; the blocks are always called with ``encode=True``.
    """

    act_fn: callable
    block_class: nn.Module
    num_blocks: tuple = (1, 1, 1)
    c_hidden: tuple = (64, 128, 256)
    latent_dim: int = 64

    @nn.compact
    def __call__(self, x, encode=True):
        print(x.shape)
        # Stem convolution on the raw image to scale up the channel count.
        stem = nn.Conv(
            self.latent_dim, kernel_size=(3, 3), padding="SAME", strides=(2, 2)
        )
        x = self.act_fn(stem(x))
        print(x.shape)

        for group, n_blocks in enumerate(self.num_blocks):
            for position in range(n_blocks):
                # The first block of every group except group 0 is flagged for
                # subsampling (ResNetBlock as written in this notebook ignores it).
                block = self.block_class(
                    c_out=self.c_hidden[group],
                    act_fn=self.act_fn,
                    subsample=position == 0 and group > 0,
                )
                x = block(x, encode=True)

        # Per-position projection: 4x4x128 for a 64x64 input with latent_dim=64.
        net = nn.Dense(features=self.latent_dim * 2)(x)
        print("Dense shape", net.shape, "\n")

        loc = net[..., : self.latent_dim]
        # NOTE(review): the raw Dense output is used as the scale, so it is not
        # constrained to be positive.
        scale = net[..., self.latent_dim :]
        return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)


class ResNetDec(nn.Module):
    """Small convolutional decoder with ResNet blocks as intermediate layers.

    Applies one group of ``block_class`` instances per entry of ``num_blocks``
    (group ``g`` has width ``c_hidden[g]``), always with ``encode=False``,
    then a softplus. Returns a diagonal Gaussian over the output image with
    the softplus output as its mean and a fixed scale of 0.01 on every
    channel. ``encode`` is accepted but not read.
    """

    act_fn: callable
    block_class: nn.Module
    num_blocks: tuple = (1, 1, 1, 1)
    c_hidden: tuple = (128, 64, 32, 5)

    @nn.compact
    def __call__(self, x, encode=False):
        # Creating the ResNet blocks
        for block_idx, block_count in enumerate(self.num_blocks):
            for bc in range(block_count):
                # Flag the first block of each group, except the very first one
                # (ResNetBlock as written in this notebook ignores the flag).
                subsample = bc == 0 and block_idx > 0
                x = self.block_class(
                    c_out=self.c_hidden[block_idx],
                    act_fn=self.act_fn,
                    subsample=subsample,
                )(x, encode=False)

        # Softplus keeps the predicted pixel means strictly positive.
        x = nn.activation.softplus(x)
        # Image is now 64x64x5 for the 4x4x64 latent and the default c_hidden.
        # The scale_diag length follows the output channel count instead of
        # being hard-coded to 5 (identical for 5-channel outputs).
        r = tfd.MultivariateNormalDiag(loc=x, scale_diag=[0.01] * x.shape[-1])

        return r

In [4]:
# Initializing the Encoder
# NOTE(review): the names `Encoder` and `Decoder` are rebound to classes in a
# later cell, so these module instances are only valid until that cell runs.
Encoder = ResNetEnc(
    act_fn=act_fn,
    block_class=ResNetBlock,
    latent_dim=latent_dim,
    c_hidden=c_hidden_enc,
    num_blocks=num_blocks_enc,
)
# Initialise encoder parameters by tracing the dummy (1, 64, 64, 5) input.
params_enc = Encoder.init(rng, batch_enc)

# Taking 64 images of the dataset
# NOTE(review): batch_im is never read in the visible notebook.
batch_im = next(dset)
# Generating new keys to use them for inference
rng, rng_1 = random.split(rng)

# Initializing the Decoder
Decoder = ResNetDec(
    act_fn=act_fn,
    block_class=ResNetBlock,
    c_hidden=c_hidden_dec,
    num_blocks=num_blocks_dec,
)
# rng_1 is consumed here; later random draws should split a fresh key
# rather than reuse it.
params_dec = Decoder.init(rng_1, batch_dec)

# Defining a general list of the parameters
# [encoder params, decoder params]: this structure is also the template that
# the saved checkpoint is restored into.
params = [params_enc, params_dec]

(1, 64, 64, 5)
(1, 32, 32, 64)
Input X Resnet (1, 32, 32, 64)
Z Resnet (1, 16, 16, 64)
Sum X Resnet (1, 16, 16, 64)
Input X Resnet (1, 16, 16, 64)
Z Resnet (1, 8, 8, 128)
Sum X Resnet (1, 8, 8, 128)
Input X Resnet (1, 8, 8, 128)
Z Resnet (1, 4, 4, 256)
Sum X Resnet (1, 4, 4, 256)
Dense shape (1, 4, 4, 128) 

Input X Resnet (1, 4, 4, 64)
Z Resnet (1, 8, 8, 128)
Sum X Resnet (1, 8, 8, 128)
Input X Resnet (1, 8, 8, 128)
Z Resnet (1, 16, 16, 64)
Sum X Resnet (1, 16, 16, 64)
Input X Resnet (1, 16, 16, 64)
Z Resnet (1, 32, 32, 32)
Sum X Resnet (1, 32, 32, 32)
Input X Resnet (1, 32, 32, 32)
Z Resnet (1, 64, 64, 5)
Sum X Resnet (1, 64, 64, 5)


In [5]:
import wandb
api = wandb.Api()

# Reference ResNet run (entity/project/run_id). `run.id` is also read by
# load_checkpoint to locate the downloaded files.
run = api.run("jonnyytorres/resnet-comp-dim/0swrd6wi")
# Best model checkpoint of that run, downloaded locally; artifact_dir holds
# the path returned by download().
artifact = api.artifact('jonnyytorres/resnet-comp-dim/0swrd6wi-checkpoint:best', type='model')
artifact_dir = artifact.download()

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [6]:
from flax.serialization import to_state_dict, msgpack_serialize, from_bytes

def load_checkpoint(ckpt_file, state, artifact_dir=None):
    """Restore a Flax msgpack checkpoint into ``state``.

    Args:
        ckpt_file: file name of the checkpoint inside the artifact directory.
        state: template pytree (e.g. freshly initialised params) whose
            structure the serialized bytes are restored into.
        artifact_dir: directory containing ``ckpt_file``. Prefer passing the
            path returned by ``artifact.download()``. When omitted, it falls
            back to ``artifacts/<run.id>-checkpoint:best/`` built from the
            notebook-global ``run`` (presumably W&B's default download
            location for the ``:best`` alias — confirm).

    Returns:
        A pytree with the structure of ``state`` and the checkpoint's values.
    """
    if artifact_dir is None:
        artifact_dir = f"artifacts/{run.id}-checkpoint:best/"
    ckpt_path = os.path.join(artifact_dir, ckpt_file)
    with open(ckpt_path, "rb") as data_file:
        byte_data = data_file.read()
    return from_bytes(state, byte_data)

In [7]:
 # Loading checkpoint for the best step
params = load_checkpoint("checkpoint.msgpack", params)

# Predicting over an example of data. The eval pipeline is unshuffled, so
# this is always the first batch of the held-out split.
dataset_eval = input_fn("test")
test_iterator = dataset_eval.as_numpy_iterator()
batch = next(test_iterator)
# Taking 16 images as example
batch = batch[:16, ...]

# Dividing the list of parameters obtained before
params_enc, params_dec = params
# Distribution of latent space calculated using the batch of data
q = ResNetEnc(
    act_fn=act_fn,
    block_class=ResNetBlock,
    latent_dim=latent_dim,
    c_hidden=c_hidden_enc,
    num_blocks=num_blocks_enc,
).apply(params_enc, batch)
# Sample the latent with a fresh key: rng_1 was already consumed by
# Decoder.init, and reusing a JAX key correlates the two random draws.
rng, rng_1 = random.split(rng)
z = q.sample(seed=rng_1)

# Decoder's output distribution over images, given the latent sample
p = ResNetDec(
    act_fn=act_fn,
    block_class=ResNetBlock,
    c_hidden=c_hidden_dec,
    num_blocks=num_blocks_dec,
).apply(params_dec, z)
# Sample images from the decoder's output distribution
rng, rng_1 = random.split(rng)
# NOTE: `z` is rebound here to image-space samples; later cells plot it as
# the first model's reconstructions.
z = p.sample(seed=rng_1)

(16, 64, 64, 5)
(16, 32, 32, 64)
Input X Resnet (16, 32, 32, 64)
Z Resnet (16, 16, 16, 64)
Sum X Resnet (16, 16, 16, 64)
Input X Resnet (16, 16, 16, 64)
Z Resnet (16, 8, 8, 128)
Sum X Resnet (16, 8, 8, 128)
Input X Resnet (16, 8, 8, 128)
Z Resnet (16, 4, 4, 256)
Sum X Resnet (16, 4, 4, 256)
Dense shape (16, 4, 4, 128) 

Input X Resnet (16, 4, 4, 64)
Z Resnet (16, 8, 8, 128)
Sum X Resnet (16, 8, 8, 128)
Input X Resnet (16, 8, 8, 128)
Z Resnet (16, 16, 16, 64)
Sum X Resnet (16, 16, 16, 64)
Input X Resnet (16, 16, 16, 64)
Z Resnet (16, 32, 32, 32)
Sum X Resnet (16, 32, 32, 32)
Input X Resnet (16, 32, 32, 32)
Z Resnet (16, 64, 64, 5)
Sum X Resnet (16, 64, 64, 5)


In [8]:
# TO PLOT ORIGINAL HSC IMAGES IN BLUE

import matplotlib.pyplot as plt

fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
    # Average over the band (last) axis and show it with the default colormap.
    z_img = batch[i, ...]
    ax.imshow(z_img.mean(axis=-1))
    ax.axis('off')

# Adjust the layout of the subplots
fig.tight_layout()

fig.savefig('hsc_examples.png')
plt.close(fig)

In [9]:
# TO PLOT ORIGINAL HSC IMAGES IN FALSE COLOR

from astropy.visualization import make_lupton_rgb

def luptonize(img):
    """Return a false-colour RGB image using astropy's Lupton stretch.

    Channels 2, 1 and 0 of ``img`` are mapped to red, green and blue
    (i, r, g if the bands follow the g, r, i, z, y order used earlier).
    """
    red = img[:, :, 2]
    green = img[:, :, 1]
    blue = img[:, :, 0]
    return make_lupton_rgb(red, green, blue, Q=15, stretch=0.5, minimum=0)

fig, axes = plt.subplots(4, 4, figsize=(10, 10))
# First 16 raw (un-normalised) training images in false colour.
for ax, entry in zip(axes.flat, train_dset.take(16)):
    ax.imshow(luptonize(entry['image']))
    ax.axis('off')

# Adjust the layout of the subplots
fig.tight_layout()

fig.savefig('hsc_examples_color.png')
plt.close(fig)

In [10]:
import matplotlib.pyplot as plt

def save_samples(z, name, orig=None):
    """Save a 3x8 grid comparing originals, predictions and their difference.

    Row 1 shows the band-averaged original image, row 2 the band-averaged
    prediction, and row 3 the prediction minus the original (both
    band-averaged). Only the first 8 examples are drawn.

    Args:
        z: predicted images, indexable as ``z[i, ...]`` with bands last.
        name: output path passed to ``savefig``.
        orig: original images to compare against. Defaults to the
            notebook-global ``batch`` for backward compatibility.
    """
    if orig is None:
        orig = batch

    num_rows, num_cols = 3, 8
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(25, 9))

    for i, (ax1, ax2, ax3) in enumerate(zip(axes[0, :], axes[1, :], axes[2, :])):
        batch_img = orig[i, ...]
        z_img = z[i, ...]

        # Plotting original image
        ax1.imshow(batch_img.mean(axis=-1))
        ax1.axis("off")
        # Plotting predicted image
        ax2.imshow(z_img.mean(axis=-1))
        ax2.axis("off")
        # Plotting difference between predicted and original image
        ax3.imshow(z_img.mean(axis=-1) - batch_img.mean(axis=-1))
        ax3.axis("off")

    # Add a title to the figure
    fig.suptitle("Samples of predicted galaxies", fontsize=16)

    # Adjust the layout of the subplots
    fig.tight_layout()

    fig.savefig(name)
    plt.close(fig)

In [11]:
# Saving the samples of the predicted images and their difference from the original images
# (z holds the first model's decoded samples; the originals are `batch`).
save_samples(z, 'HSC_First_Model.png')

# NOTE(review): no wandb.init() run is visible in this notebook (only the
# public Api is used), so this call presumably has nothing to finish.
wandb.finish()

In [12]:
# Dataset as a numpy iterator
# Recreates the shuffled, repeated training iterator for the second model.
dset = input_fn().as_numpy_iterator()

# Generating a random key for JAX
# Two independent keys: rng for parameter init and rng_2 for the sampling
# seed that is passed to init.
rng, rng_2 = jax.random.PRNGKey(0), jax.random.PRNGKey(1)
# Size of the input to initialize the encoder parameters
batch_autoenc = jnp.ones((1, 64, 64, 5))

# NOTE: the autoencoder cell passes this value as the image `resolution`
# (64x64 inputs), not as a latent size.
latent_dim = 64
act_fn = nn.gelu

In [13]:
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from tensorflow_probability.substrates import jax as tfp

# Loading distributions from TensorFlow Probability (JAX version)
# Rebinds the `tfd` alias used by the autoencoder modules that follow.
tfd = tfp.distributions


def Normalize(num_groups=10):
    """Return a GroupNorm layer with ``num_groups`` groups, a learned scale and eps=1e-6."""
    group_norm = nn.GroupNorm(
        num_groups=num_groups,
        epsilon=1e-6,
        use_scale=True,
    )
    return group_norm


class Downsample(nn.Module):
    """Stride-2 3x3 convolution that halves H and W of an NHWC tensor.

    NOTE(review): the legacy padding (``legacy_pad=True``, the default) is
    kept bit-for-bit so existing checkpoints still load. Its explicit
    ``jnp.pad`` adds one row at the bottom (axis 1) and one extra *input
    channel* (axis 3) instead of a column on the width axis (axis 2), and
    the conv then pads both spatial axes again. Because of the extra channel
    the kernel is (3, 3, C+1, C), so correcting it changes parameter shapes.
    ``legacy_pad=False`` pads one pixel at the bottom and right with no extra
    channel and uses a VALID conv. This looks like the intended port of a
    PyTorch right/bottom pad. The spatial output size is the same either way
    (e.g. 64 -> 32).
    """

    in_channels: int
    legacy_pad: bool = True  # False: corrected padding (not checkpoint-compatible)

    def setup(self):
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=((0, 1), (0, 1)) if self.legacy_pad else "VALID",
        )

    def __call__(self, x):
        if self.legacy_pad:
            # Pads H (axis 1) and C (axis 3); kept for checkpoint compatibility.
            pad = ((0, 0), (0, 1), (0, 0), (0, 1))
        else:
            # One pixel at the bottom (H) and right (W), channels untouched.
            pad = ((0, 0), (0, 1), (0, 1), (0, 0))
        x = jnp.pad(x, pad, mode="constant", constant_values=0)
        return self.conv(x)


class ResnetBlock(nn.Module):
    """Pre-activation residual block with a 1x1 projection shortcut.

    The residual path is (GroupNorm(5 groups) -> act_fn -> 3x3 conv) applied
    twice, with ``out_channels`` filters. The 1x1 ``nin_shortcut`` is applied
    to the input unconditionally, even when the widths already match, and
    the two results are summed. ``in_channels`` is kept for interface
    compatibility; Flax infers the input width, so it is not read. Channel
    counts must be divisible by 5 for the GroupNorm layers.
    """

    in_channels: int
    out_channels: int
    act_fn: callable = nn.gelu  # Activation function

    def setup(self):
        self.norm1 = Normalize(num_groups=5)
        self.conv1 = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
        )
        self.norm2 = Normalize(num_groups=5)
        self.conv2 = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
        )

        self.nin_shortcut = nn.Conv(
            self.out_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding=((0, 0), (0, 0)),
        )

    def __call__(self, x):
        h = self.conv1(self.act_fn(self.norm1(x)))
        h = self.conv2(self.act_fn(self.norm2(h)))
        # The shortcut projection brings x to out_channels before the sum.
        return self.nin_shortcut(x) + h


class DownsamplingBlock(nn.Module):
    """One encoder resolution level: ``num_res_blocks`` ResnetBlocks, then a Downsample.

    The output width is ``ch * ch_mult[block_idx]``. Every level except the
    last one ends with a stride-2 Downsample. ``resolution`` is accepted but
    not read.
    """

    ch: int
    ch_mult: tuple
    num_res_blocks: int
    resolution: int
    block_idx: int
    act_fn: callable = nn.gelu  # Activation function

    def setup(self):
        self.num_resolutions = len(self.ch_mult)
        # The input width comes from the previous level (level 0 reads `ch`).
        block_in = self.ch * ((1,) + tuple(self.ch_mult))[self.block_idx]
        block_out = self.ch * self.ch_mult[self.block_idx]

        self.block = [
            ResnetBlock(block_in, block_out, self.act_fn)
            for _ in range(self.num_res_blocks)
        ]

        is_last_level = self.block_idx == self.num_resolutions - 1
        self.downsample = None if is_last_level else Downsample(block_out)

    def __call__(self, h):
        for res_block in self.block:
            h = res_block(h)
        return h if self.downsample is None else self.downsample(h)


class MidBlock(nn.Module):
    """Two width-preserving ResnetBlocks applied in sequence.

    No ``act_fn`` is passed, so both blocks use ResnetBlock's default
    (``nn.gelu``).
    """

    in_channels: int

    def setup(self):
        width = self.in_channels
        self.block_1 = ResnetBlock(width, width)
        self.block_2 = ResnetBlock(width, width)

    def __call__(self, h):
        return self.block_2(self.block_1(h))


class Encoder(nn.Module):
    """Convolutional encoder: stem conv -> downsampling levels -> mid block -> output conv.

    The output has ``2 * z_channels`` channels when ``double_z`` is set and
    ``z_channels`` otherwise. ``out_ch`` and ``in_channels`` are accepted for
    signature parity with ``Decoder`` but are not read here.
    """

    ch: int
    out_ch: int
    ch_mult: tuple
    num_res_blocks: int
    in_channels: int
    resolution: int
    z_channels: int
    double_z: bool
    act_fn: callable = nn.gelu  # Activation function

    def setup(self):
        self.num_resolutions = len(self.ch_mult)

        # Stem: 3x3 conv to `ch` channels, spatial size unchanged.
        self.conv_in = nn.Conv(
            self.ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
        )

        # One DownsamplingBlock per resolution level; all but the last halve H and W.
        self.down = [
            DownsamplingBlock(
                ch=self.ch,
                ch_mult=self.ch_mult,
                num_res_blocks=self.num_res_blocks,
                resolution=self.resolution,
                block_idx=level,
                act_fn=self.act_fn,
            )
            for level in range(self.num_resolutions)
        ]

        # Mid block runs at the widest level's channel count.
        self.mid = MidBlock(self.ch * self.ch_mult[-1])

        # Output head: GroupNorm (default 10 groups) -> act -> 3x3 conv.
        self.norm_out = Normalize()
        out_features = self.z_channels * 2 if self.double_z else self.z_channels
        self.conv_out = nn.Conv(
            out_features,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
        )

    def __call__(self, x):
        print("x :", x.shape)
        hs = self.conv_in(x)
        print("Conv_in :", hs.shape)
        for stage in self.down:
            hs = stage(hs)
        print("Down :", hs.shape)

        hs = self.mid(hs)
        print("Mid :", hs.shape)

        hs = self.conv_out(self.act_fn(self.norm_out(hs)))
        print("Conv_out :", hs.shape)
        return hs


class Upsample(nn.Module):
    """Double H and W of an NHWC tensor with bicubic resizing, then apply a 3x3 'same' conv."""

    in_channels: int

    def setup(self):
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
        )

    def __call__(self, hs):
        n, height, width, channels = hs.shape
        target_shape = (n, 2 * height, 2 * width, channels)
        upsampled = jax.image.resize(hs, shape=target_shape, method="bicubic")
        return self.conv(upsampled)


class UpsamplingBlock(nn.Module):
    """One decoder resolution level: ``num_res_blocks + 1`` ResnetBlocks, then an Upsample.

    The input width is that of the next-coarser level (or the coarsest level
    itself for the last index). The output width is ``ch * ch_mult[block_idx]``.
    Every level except index 0 ends with a 2x Upsample. ``resolution`` is
    accepted but not read.
    """

    ch: int
    ch_mult: tuple
    num_res_blocks: int
    resolution: int
    block_idx: int
    act_fn: callable = nn.gelu  # Activation function

    def setup(self):
        self.num_resolutions = len(self.ch_mult)
        is_coarsest = self.block_idx == self.num_resolutions - 1
        source_level = -1 if is_coarsest else self.block_idx + 1
        block_in = self.ch * self.ch_mult[source_level]
        block_out = self.ch * self.ch_mult[self.block_idx]

        self.block = [
            ResnetBlock(block_in, block_out, self.act_fn)
            for _ in range(self.num_res_blocks + 1)
        ]

        self.upsample = Upsample(block_out) if self.block_idx != 0 else None

    def __call__(self, h):
        for res_block in self.block:
            h = res_block(h)
        return h if self.upsample is None else self.upsample(h)


class Decoder(nn.Module):
    """Convolutional decoder: input conv -> mid block -> upsampling levels -> output conv.

    ``up[k]`` handles resolution level ``k``; the levels are applied from the
    coarsest (``up[-1]``) to the finest (``up[0]``). The output has ``out_ch``
    channels. ``in_channels`` and ``double_z`` are accepted for signature
    parity with ``Encoder`` but are not read here.
    """

    ch: int
    out_ch: int
    ch_mult: tuple
    num_res_blocks: int
    in_channels: int
    resolution: int
    z_channels: int
    double_z: bool
    act_fn: callable = nn.gelu  # Activation function

    def setup(self):
        self.num_resolutions = len(self.ch_mult)

        block_in = self.ch * self.ch_mult[-1]  # width of the coarsest level
        # Latent side length: resolution halved once per downsampling level.
        latent_res = self.resolution // 2 ** (self.num_resolutions - 1)
        # NOTE(review): recorded channels-first, while the tensors are NHWC.
        self.z_shape = (1, self.z_channels, latent_res, latent_res)

        # NOTE(review): conv_in maps z to `ch` channels rather than `block_in`;
        # the mid block's 1x1 shortcuts lift it to `block_in`.
        self.conv_in = nn.Conv(
            self.ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
        )

        print(
            f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions."
        )

        self.mid = MidBlock(block_in)

        # up[k] is built with block_idx=k.
        self.up = [
            UpsamplingBlock(
                ch=self.ch,
                ch_mult=self.ch_mult,
                num_res_blocks=self.num_res_blocks,
                resolution=self.resolution,
                block_idx=level,
                act_fn=self.act_fn,
            )
            for level in range(self.num_resolutions)
        ]

        # Output head: GroupNorm (5 groups) -> act -> 3x3 conv to out_ch.
        self.norm_out = Normalize(num_groups=5)
        self.conv_out = nn.Conv(
            self.out_ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
        )

    def __call__(self, z):
        hs = self.mid(self.conv_in(z))
        # Coarsest level first.
        for stage in reversed(self.up):
            hs = stage(hs)
        return self.conv_out(self.act_fn(self.norm_out(hs)))


class AutoencoderKLModule(nn.Module):
    """Autoencoder with a diagonal-Gaussian latent.

    The forward pass runs Encoder, then a 1x1 ``quant_conv`` to
    ``2 * embed_dim`` channels, then the Gaussian posterior and a sample from
    it. The sample goes through a 1x1 ``post_quant_conv`` to ``z_channels``
    and then the Decoder. ``__call__`` returns the decoder's output
    distribution: a diagonal Gaussian with a fixed scale of 0.01 per channel.
    """

    ch: int
    out_ch: int
    ch_mult: tuple
    num_res_blocks: int
    in_channels: int
    resolution: int
    z_channels: int
    double_z: bool
    embed_dim: int
    act_fn: callable = nn.gelu  # Activation function

    def setup(self):
        self.encoder = Encoder(
            self.ch,
            self.out_ch,
            self.ch_mult,
            self.num_res_blocks,
            self.in_channels,
            self.resolution,
            self.z_channels,
            self.double_z,
            self.act_fn,
        )
        self.decoder = Decoder(
            self.ch,
            self.out_ch,
            self.ch_mult,
            self.num_res_blocks,
            self.in_channels,
            self.resolution,
            self.z_channels,
            self.double_z,
            self.act_fn,
        )
        self.quant_conv = nn.Conv(
            2 * self.embed_dim,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
        )
        self.post_quant_conv = nn.Conv(
            self.z_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
        )

    def encode(self, x):
        """Encode ``x`` into a diagonal-Gaussian posterior over the latent."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        print("Moments shape :", moments.shape)
        # quant_conv emits 2 * embed_dim channels laid out as [loc | scale], so
        # the split point is embed_dim (equal to z_channels in this notebook's
        # config, where both are 5).
        # NOTE(review): the scale half is used raw (no softplus/exp), so it is
        # not constrained to be positive.
        posterior = tfd.MultivariateNormalDiag(
            loc=moments[..., : self.embed_dim],
            scale_diag=moments[..., self.embed_dim :],
        )
        print("Posterior :", posterior)

        return posterior

    def decode(self, h):
        """Map a latent sample to the decoder's output distribution over images."""
        h = self.post_quant_conv(h)
        h = self.decoder(h)
        # Image is now 64x64x5 for this notebook's config. The scale_diag
        # length follows the decoder's output channels (out_ch) instead of
        # being hard-coded to 5.
        q = tfd.MultivariateNormalDiag(loc=h, scale_diag=[0.01] * h.shape[-1])
        return q

    def __call__(self, x, seed):
        posterior = self.encode(x)
        h = posterior.sample(seed=seed)
        q = self.decode(h)

        return q  # , posterior

In [14]:
# Initializing the AutoEncoder
# 3 resolution levels with widths ch * (1, 2, 4) = 5, 10, 20 channels and a
# 5-channel latent (mean and scale of width 5 each).
Autoencoder = AutoencoderKLModule(
    ch_mult=(1, 2, 4),
    num_res_blocks=1,
    double_z=True,
    z_channels=5,
    resolution=latent_dim,  # 64: input image height/width
    in_channels=5,
    out_ch=5,
    ch=5,
    embed_dim=5,
    act_fn=act_fn,
)

# Trace with the dummy batch; `seed` feeds the posterior sample in __call__.
params = Autoencoder.init(rng, x=batch_autoenc, seed=rng_2)

# Taking 64 images of the dataset
# NOTE(review): batch_im is never read in the visible notebook.
batch_im = next(dset)
# Generating new keys to use them for inference
# NOTE(review): rng_2 was already used as the init seed above. Neither key
# from this split is read later in the visible notebook (rng_1 is
# reassigned before use).
rng_1, rng_2 = jax.random.split(rng_2)

x : (1, 64, 64, 5)
Conv_in : (1, 64, 64, 5)
Down : (1, 16, 16, 20)
Mid : (1, 16, 16, 20)
Conv_out : (1, 16, 16, 10)
Moments shape : (1, 16, 16, 10)
Posterior : tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[1, 16, 16], event_shape=[5], dtype=float32)
Working with z of shape (1, 5, 16, 16) = 1280 dimensions.


In [15]:
import wandb
api = wandb.Api()

# Second model's run. This rebinds the global `run` that load_checkpoint reads
# to build its default checkpoint directory.
run = api.run("jonnyytorres/VAE-SD/0jld85s1")
# Best checkpoint of that run, downloaded locally; artifact_dir holds the
# path returned by download().
artifact = api.artifact('jonnyytorres/VAE-SD/0jld85s1-checkpoint:best', type='model')
artifact_dir = artifact.download()

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [16]:
# Loading checkpoint for the best step
# `params` (from Autoencoder.init) only supplies the pytree structure here.
params2 = load_checkpoint("checkpoint.msgpack", params)

In [17]:
# Predicting over an example of data
# First batch of the unshuffled held-out split (train[80%:]), cut to 16 images.
dataset_eval = input_fn("test")
test_iterator = dataset_eval.as_numpy_iterator()
batch = next(test_iterator)
# Taking 16 images as example
batch = batch[:16, ...]

rng, rng_1 = random.split(rng)
# X estimated distribution
# Full encode -> sample latent -> decode pass; returns the decoder's output
# distribution over images.
q2 = Autoencoder.apply(params2, x=batch, seed=rng_1)
# Sample some variables from the posterior distribution
# (z2 holds image-space samples, presumably the same shape as `batch`)
rng, rng_1 = random.split(rng)
z2 = q2.sample(seed=rng_1)

x : (16, 64, 64, 5)
Conv_in : (16, 64, 64, 5)
Down : (16, 16, 16, 20)
Mid : (16, 16, 16, 20)
Conv_out : (16, 16, 16, 10)
Moments shape : (16, 16, 16, 10)
Posterior : tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[16, 16, 16], event_shape=[5], dtype=float32)
Working with z of shape (1, 5, 16, 16) = 1280 dimensions.


In [18]:
# Saving the samples of the predicted images and their difference from the original images
# (z2 holds the second model's decoded samples; the originals are `batch`).
save_samples(z2, 'HSC_Second_Model.png')

# NOTE(review): no wandb.init() run is visible in this notebook, so this call
# presumably has nothing to finish.
wandb.finish()

In [25]:
def norm_values_diff(orig, inf1, inf2, num_images=3):
    """Shared colour-scale limits for the difference images of two models.

    For each of the first ``num_images`` examples, computes the band-averaged
    differences ``inf1 - orig`` and ``inf2 - orig``. Returns the overall
    ``[min, max]`` across both models, so the two difference panels can share
    one ``vmin``/``vmax``.

    Args:
        orig: original images, indexable as ``orig[i, ...]`` with bands last.
        inf1: predictions of the first model, same layout as ``orig``.
        inf2: predictions of the second model, same layout as ``orig``.
        num_images: number of leading examples to consider (must be >= 1).

    Returns:
        ``[min_value, max_value]``.
    """
    min_values = []
    max_values = []

    for i in range(num_images):
        orig_img = orig[i, ...]
        diff_1 = (inf1[i, ...] - orig_img).mean(axis=-1)
        diff_2 = (inf2[i, ...] - orig_img).mean(axis=-1)

        min_values.append(np.minimum(diff_1.min(), diff_2.min()))
        # The upper limit must keep the larger of the two maxima.
        max_values.append(np.maximum(diff_1.max(), diff_2.max()))

    return [np.min(min_values), np.max(max_values)]

In [26]:
# Shared colour limits for the difference plots, from the first 3 examples of both models' samples.
min_value, max_value = norm_values_diff(batch, z, z2, num_images=3)

In [37]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

def diff_inference(orig, inf1, inf2, name, vmin=None, vmax=None):
    """Save a figure comparing two models' reconstructions against the originals.

    There is one row per example (the first 3). The columns are: original,
    model-1 prediction, model-2 prediction, model-1 minus original, model-2
    minus original, and an empty column beside the colour bar. All images are
    averaged over the last (band) axis. Both difference columns share one
    colour scale.

    Args:
        orig: original images, indexable as ``orig[i, ...]`` with bands last.
        inf1: predictions of the reference model, same layout as ``orig``.
        inf2: predictions of the second model, same layout as ``orig``.
        name: output path passed to ``savefig``.
        vmin, vmax: colour limits for the difference panels. When omitted,
            they are computed from the plotted differences instead of being
            read from notebook globals.
    """
    # Plotting the original, predicted and their differences for 3 examples
    num_rows, num_cols = 3, 6

    # Band-averaged differences of the examples that are drawn (one per row).
    diffs_1 = [np.asarray((inf1[i, ...] - orig[i, ...]).mean(axis=-1)) for i in range(num_rows)]
    diffs_2 = [np.asarray((inf2[i, ...] - orig[i, ...]).mean(axis=-1)) for i in range(num_rows)]
    if vmin is None:
        vmin = min(d.min() for d in diffs_1 + diffs_2)
    if vmax is None:
        vmax = max(d.max() for d in diffs_1 + diffs_2)

    # Exactly one figure is created: an extra plt.figure() call would leave an
    # empty figure open that is never closed.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(18, 9))

    for i, (ax1, ax2, ax3, ax4, ax5, ax6) in enumerate(axes):
        orig_img = orig[i, ...]
        inf1_img = inf1[i, ...]
        inf2_img = inf2[i, ...]

        # Plotting original image
        ax1.imshow(orig_img.mean(axis=-1))
        ax1.axis("off")

        # Plotting predicted image - Model 1
        ax2.imshow(inf1_img.mean(axis=-1))
        ax2.axis("off")

        # Plotting predicted image - Model 2
        ax3.imshow(inf2_img.mean(axis=-1))
        ax3.axis("off")

        # Plotting difference between original and predicted image - Model 1
        im = ax4.imshow(diffs_1[i], vmin=vmin, vmax=vmax)
        ax4.axis("off")

        # Plotting difference between original and predicted image - Model 2
        im = ax5.imshow(diffs_2[i], vmin=vmin, vmax=vmax)
        ax5.axis("off")

        # Sixth column is left blank; the colour bar axes are placed over it.
        ax6.axis("off")

        if i == 0:
            ax1.text(2, 2, "Original images", verticalalignment='top', fontsize=10, color="white", weight="bold")
            ax2.text(2, 2, "Reference model", verticalalignment='top', fontsize=10, color="white", weight="bold")
            ax3.text(2, 2, "Second tested model", verticalalignment='top', fontsize=10, color="white", weight="bold")
            ax4.text(2, 2, "Difference original with \n reference model images", verticalalignment='top', fontsize=10, color="white", weight="bold")
            ax5.text(2, 2, "Difference original with \n second model images", verticalalignment='top', fontsize=10, color="white", weight="bold")

    # Adjust the layout of the subplots
    fig.tight_layout()

    # Colour bar for the shared difference scale, in a dedicated axes.
    cb_ax = fig.add_axes([0.835, .02, .015, .965])
    cbar = fig.colorbar(im, cax=cb_ax, label="Pixel Value", orientation="vertical")

    # Adjust the font size of colorbar ticks
    cbar.ax.tick_params(labelsize=12)

    # Save plot as image
    fig.savefig(name)
    plt.close(fig)


In [38]:
# Saving the differences between the predicted images of both models and their difference from the original images
# batch: originals; z / z2: samples from the reference and the second model.
diff_inference(batch, z, z2, "difference_pred_models.png")

<Figure size 1800x900 with 0 Axes>