In [1]:
import tensorflow as tf
import os
import jax
from jax.lib import xla_bridge

# Pin JAX to the first GPU and configure XLA. These variables are presumably
# read when the JAX backend is first initialised (get_backend() below), so
# they must be set before that call.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# XLA_FLAGS holds all flags in ONE space-separated string: assigning it twice
# kept only the last flag, so both flags are set together here.
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_cuda_data_dir=/usr/local/cuda-12.1 "
    "--xla_gpu_force_compilation_parallelism=1"
)

# Checking for GPU access
print("Device: {}".format(xla_bridge.get_backend().platform))

# Checking the GPU available
gpus = jax.devices("gpu")
print("Number of avaliable devices : {}".format(len(gpus)))

# Hide the GPU from TensorFlow so it does not grab GPU memory needed by JAX.
tf.config.set_visible_devices([], device_type="GPU")

2023-09-11 14:47:15.546520: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-09-11 14:47:15.595909: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Device: gpu
Number of avaliable devices : 1


In [2]:
import tensorflow_datasets as tfds
from galsim_jax.datasets import cosmos
import numpy as np
from astropy.stats import mad_std
import jax.numpy as jnp
from flax import linen as nn

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

In [3]:
import galsim as gs
from galsim.bounds import _BoundsI
# Pixel grid shared by every stamp and Fourier-space array below.
STAMP_SIZE, PIXEL_SCALE = 128, 0.03

# Positions passed to cat.makeGalaxy. The yielded keys use gal.index, which
# differs from these numbers (see the printed indices). A longer list
# (80100-80108, 80203, 80252, 81005) was used previously.
indices = [80101, 80102, 80203]
cat = gs.COSMOSCatalog()

def _kpsf_half_plane(psf, interp_factor=1, padding_factor=1):
    """Draw ``psf`` in Fourier space on the half-plane grid used by the model.

    The k-image covers kx in [0, Nk // 2] and ky in [-Nk // 2, Nk // 2 - 1],
    with Nk = STAMP_SIZE * interp_factor * padding_factor, so the returned
    complex64 array has shape (Nk, Nk // 2 + 1). It is fftshift-ed along
    axis 0 only.
    """
    nk = STAMP_SIZE * interp_factor * padding_factor
    bounds = _BoundsI(0, nk // 2, -nk // 2, nk // 2 - 1)
    imkpsf = psf.drawKImage(
        bounds=bounds,
        scale=2.0 * np.pi / (STAMP_SIZE * padding_factor * PIXEL_SCALE),
        recenter=False,
    )
    return np.fft.fftshift(imkpsf.array, 0).astype("complex64")


def _log_noise_power_spectrum(real):
    """Float32 log noise power spectrum of ``real`` on the STAMP_SIZE half-plane grid.

    Adapted from
    https://github.com/ml4astro/galaxy2galaxy/blob/6d8b20722a5545c8c79a19cb67c6131c061763ed/galaxy2galaxy/data_generators/galsim_utils.py#L146
    Modes where the drawn k-image of ``real`` is exactly zero are set to 10.
    """
    bounds = _BoundsI(0, STAMP_SIZE // 2, -STAMP_SIZE // 2, STAMP_SIZE // 2 - 1)
    imG = real.drawKImage(
        bounds=bounds,
        scale=2.0 * np.pi / (STAMP_SIZE * PIXEL_SCALE),
        recenter=False,
    )
    mask = np.fft.fftshift(imG.array, axes=0) != 0

    shape = (STAMP_SIZE, STAMP_SIZE)
    # _get_update_rootps is private GalSim API and may return an array cached
    # on the noise object; copy it so the in-place rescaling below cannot
    # corrupt that cache for any later use of real.noise.
    ps = real.noise._get_update_rootps(shape, wcs=gs.PixelScale(PIXEL_SCALE)).copy()

    # Rescale the self-conjugate modes by sqrt(2), as in galaxy2galaxy.
    rt2 = np.sqrt(2.0)
    ps[0, 0] *= rt2
    if shape[1] % 2 == 0:  # x dimension even
        ps[0, shape[1] // 2] *= rt2
    if shape[0] % 2 == 0:  # y dimension even
        ps[shape[0] // 2, 0] *= rt2
        if shape[1] % 2 == 0:  # both dimensions even
            ps[shape[0] // 2, shape[1] // 2] *= rt2

    # np.where evaluates both branches; silence log(0) warnings coming from
    # modes that the mask replaces anyway.
    with np.errstate(divide="ignore"):
        log_ps = np.log(ps**2)
    return np.where(mask, log_ps, 10).astype("float32")


def generate_examples(cat, indices):
    """Yield model inputs for each COSMOS galaxy in ``indices``.

    Each real galaxy is convolved with its original PSF (``gal.original_psf``)
    and drawn on a STAMP_SIZE x STAMP_SIZE stamp at PIXEL_SCALE without a
    pixel response.

    Parameters
    ----------
    cat : galsim.COSMOSCatalog
        Catalogue the galaxies are drawn from.
    indices : iterable of int
        Positions passed to ``cat.makeGalaxy``.

    Yields
    ------
    key : str
        ``gal.index`` of the drawn galaxy as a decimal string.
    features : dict
        "image" (float32 stamp), "kpsf_real" / "kpsf_imag" (parts of the
        half-plane Fourier PSF), "noise_std" (pixel noise standard deviation
        of the convolved galaxy) and "ps" (float32 log noise power spectrum).
    """
    for i in indices:
        gal = cat.makeGalaxy(i, gal_type='real', noise_pad_size=0.8 * PIXEL_SCALE * STAMP_SIZE)
        psf = gal.original_psf
        real = gs.Convolve(psf, gal)

        cosmos_stamp = real.drawImage(
            nx=STAMP_SIZE,
            ny=STAMP_SIZE,
            scale=PIXEL_SCALE,
            method="no_pixel",
        ).array.astype("float32")

        kpsf = _kpsf_half_plane(psf)

        yield "%d" % gal.index, {
            "image": cosmos_stamp,
            "kpsf_real": kpsf.real,
            "kpsf_imag": kpsf.imag,
            "noise_std": np.sqrt(real.noise.getVariance()),
            "ps": _log_noise_power_spectrum(real),
        }

In [4]:
# Lazy generator: nothing is drawn until fr_examples() iterates it, and it can
# only be iterated once.
examples_fr = generate_examples(cat, indices)

In [5]:
# for ex in examples_fr:
#     print(ex[1]["image"][0][0])

In [6]:
# examples_fr_array = np.asarray(list(examples_fr))

In [7]:
# examples_fr_array = np.fromiter(examples_fr, dtype=float)

In [8]:
def fr_examples(examples_fr):
    """Collect ``(key, features)`` examples into stacked NumPy arrays.

    Parameters
    ----------
    examples_fr : iterable of (str, dict)
        Pairs like those yielded by ``generate_examples``. Each dict needs
        "image", "kpsf_real", "kpsf_imag" and "noise_std".

    Returns
    -------
    images, psf_real, psf_img : np.ndarray
        Per-example arrays stacked on a new leading axis, each with a trailing
        channel axis added.
    std : np.ndarray
        1-D array of the per-example noise standard deviations.
    indexes : list of str
        The example keys, in iteration order.

    Raises
    ------
    ValueError
        If ``examples_fr`` yields nothing, e.g. a generator that was already
        consumed.
    """
    # Local accumulators: module-level lists made every call append to the
    # previous call's results.
    galaxies_images, kpsf_real_img, kpsf_imag_img, std_list, indexes = [], [], [], [], []

    for key, example in examples_fr:
        print("Index: ", key)
        indexes.append(key)
        galaxies_images.append(np.expand_dims(example["image"], axis=-1))
        kpsf_real_img.append(np.expand_dims(example["kpsf_real"], axis=-1))
        kpsf_imag_img.append(np.expand_dims(example["kpsf_imag"], axis=-1))
        std_list.append(example["noise_std"])

    if not indexes:
        raise ValueError(
            "No examples to collect: the input is empty "
            "(a generator can only be consumed once)."
        )

    images = np.stack(galaxies_images, axis=0)
    psf_real = np.stack(kpsf_real_img, axis=0)
    psf_img = np.stack(kpsf_imag_img, axis=0)
    std = np.stack(std_list, axis=-1)

    return images, psf_real, psf_img, std, indexes
    
# Consume the generator into stacked arrays (prints each galaxy key).
images, psf_real, psf_img, std, indexes = fr_examples(examples_fr)

86297
Index:  86297
86298
Index:  86298
86405
Index:  86405


In [9]:
# images[0][0][0][0]

In [10]:
# Expected (n_galaxies, 128, 65, 1): half-plane Fourier PSF plus a channel axis.
psf_real.shape

(3, 128, 65, 1)

In [11]:
import os
import numpy as np

def load_galaxies(folder_path):
    """
    Load all the .npy files in a folder and stack them into a single NumPy array.

    Files are read in ascending file-name order. Each loaded array gets a
    trailing channel axis, so 2-D images of shape (rows, cols) produce an
    array of shape (num_files, rows, cols, 1).

    Parameters:
        folder_path (str): The path to the folder containing the .npy files.

    Returns:
        np.ndarray: The stacked data, one entry per .npy file.

    Raises:
        ValueError: If the folder does not exist or contains no .npy files.
    """
    if not os.path.exists(folder_path):
        raise ValueError("The folder does not exist!")

    # Ascending file-name order defines the order of the stacked entries.
    sorted_files = sorted(f for f in os.listdir(folder_path) if f.endswith(".npy"))

    galaxies_data = []
    for filename in sorted_files:
        file_path = os.path.join(folder_path, filename)
        if os.path.isfile(file_path):
            galaxies_data.append(np.expand_dims(np.load(file_path), axis=-1))

    # np.stack([]) fails with an unhelpful message; say what is actually wrong.
    if not galaxies_data:
        raise ValueError(f"No .npy files found in {folder_path!r}.")

    return np.stack(galaxies_data)

In [12]:
# One array per .npy file in 'galaxies-pres', stacked in sorted file-name order.
# NOTE(review): rows are matched to `images` purely by that order; confirm it
# follows `indices`.
galaxies_data = load_galaxies('galaxies-pres')

In [13]:
# galaxies_data.shape

In [14]:
# import matplotlib.pyplot as plt

# # Plotting the original, predicted and their differences for 8 examples
# num_rows, num_cols = 3, 4

# plt.figure(figsize=(10, 8))

# fig, axes = plt.subplots(num_rows, num_cols, figsize=(10, 8))

# for ax, z_img in zip(axes.flatten(), images):
#     ax.imshow(z_img.mean(axis=-1))
#     ax.axis("off")

# # Add a title to the figure
# fig.suptitle("Original galaxies", fontsize=12, y=0.99)

# # Adjust the layout of the subplots
# fig.tight_layout()

# plt.savefig('Original_FR_galaxies_2.png')
# plt.close(fig)

In [15]:
# import matplotlib.pyplot as plt

# fig = plt.figure(frameon=False)
# fig.set_size_inches(3,3)

# ax = plt.Axes(fig, [0., 0., 1., 1.])
# ax.set_axis_off()
# fig.add_axes(ax)

# ax.imshow(images[9, ...].mean(axis=-1), aspect='auto')

# fig.savefig('Original_galaxy_presentation.png')
# plt.close(fig)

In [16]:
# import matplotlib.pyplot as plt

# fig = plt.figure(frameon=False)
# fig.set_size_inches(3,3)

# ax = plt.Axes(fig, [0., 0., 1., 1.])
# ax.set_axis_off()
# fig.add_axes(ax)

# ax.imshow(images[9, 32:96, 32:96, ...].mean(axis=-1), aspect='auto')

# plt.savefig('Original_galaxy_presentation_2.png')
# plt.close(fig)

In [17]:
# import matplotlib.pyplot as plt

# fig = plt.figure(frameon=False)
# fig.set_size_inches(3,3)

# ax = plt.Axes(fig, [0., 0., 1., 1.])
# ax.set_axis_off()
# fig.add_axes(ax)

# ax.imshow(images[9, 48:80, 48:80, ...].mean(axis=-1), aspect='auto')

# plt.savefig('Original_galaxy_presentation_3.png')
# plt.close(fig)

In [18]:
from galsim_jax.dif_models import AutoencoderKLModule

# Fixed PRNG keys: `rng` initialises the weights, `rng_2` is passed as `seed`.
rng = jax.random.PRNGKey(0)
rng_2 = jax.random.PRNGKey(1)

# `latent_dim` is passed to the model as `resolution`; it equals the 128-pixel
# side of the input stamps.
latent_dim = 128
act_fn = nn.gelu

# Dummy one-image batch used to initialise the parameters.
batch_autoenc = jnp.ones((1, 128, 128, 1))

# ch_mult=(1, 2, 4): the printed shapes show a 32 x 32 latent map.
Autoencoder = AutoencoderKLModule(
    resolution=latent_dim,
    in_channels=1,
    out_ch=1,
    ch=1,
    ch_mult=(1, 2, 4),
    num_res_blocks=2,
    z_channels=1,
    double_z=True,
    embed_dim=1,
    act_fn=act_fn,
)
params = Autoencoder.init(rng, x=batch_autoenc, seed=rng_2)

x : (1, 128, 128, 1)
Conv_in : (1, 128, 128, 1)
Down : (1, 32, 32, 4)
Mid : (1, 32, 32, 4)
Conv_out : (1, 32, 32, 2)
Moments shape : (1, 32, 32, 2)
Posterior : tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[1, 32, 32], event_shape=[1], dtype=float32)
Working with z of shape (1, 1, 32, 32) = 1024 dimensions.


In [19]:
import wandb
api = wandb.Api()

# Checkpoint restored into the ch_mult=(1, 2, 4) model below. `run` is also
# read by load_checkpoint(), which builds the checkpoint path from run.id.
run = api.run("jonnyytorres/VAE-SD/iebszvhp")
artifact = api.artifact('jonnyytorres/VAE-SD/iebszvhp-checkpoint:best', type='model')
# Local directory the artifact was downloaded to.
artifact_dir = artifact.download()

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [20]:
from flax.serialization import to_state_dict, msgpack_serialize, from_bytes

def load_checkpoint(ckpt_file, state, artifact_dir=None):
    """Load the best W&B checkpoint into a Flax pytree.

    Parameters
    ----------
    ckpt_file : str
        File name inside the artifact directory, e.g. "checkpoint.msgpack".
    state : pytree
        Template whose structure ``from_bytes`` uses to deserialize the bytes.
    artifact_dir : str, optional
        Directory holding the downloaded artifact, e.g. the value returned by
        ``artifact.download()``. When omitted, falls back to
        ``artifacts/{run.id}-checkpoint:best/`` built from the notebook-global
        ``run``, which is the previous behaviour.

    Returns
    -------
    A pytree with the structure of ``state`` holding the checkpoint values.
    """
    if artifact_dir is None:
        artifact_dir = f"artifacts/{run.id}-checkpoint:best/"
    ckpt_path = os.path.join(artifact_dir, ckpt_file)
    with open(ckpt_path, "rb") as data_file:
        byte_data = data_file.read()
    return from_bytes(state, byte_data)

In [21]:
# Loading checkpoint for the best step: restores the trained (1, 2, 4) weights
# over the freshly initialised `params` (used as the structure template).
# wandb.init()
params = load_checkpoint("checkpoint.msgpack", params)

In [22]:
def save_samples(z, name):
    """Save up to 12 channel-averaged images from ``z`` as a 3x4 grid.

    Parameters
    ----------
    z : array-like indexed as ``[image, ..., channel]``
        Images to plot; only the first 12 are shown.
    name : str
        Output path passed to ``savefig``.
    """
    num_rows, num_cols = 3, 4
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(10, 8))

    # Hide every panel's frame first, so panels beyond len(z) stay blank
    # instead of showing empty axes with ticks.
    for ax in axes.flat:
        ax.axis("off")
    for ax, z_img in zip(axes.flat, z):
        ax.imshow(tf.math.reduce_mean(z_img, axis=-1))

    fig.suptitle("Samples of predicted galaxies", fontsize=16)
    fig.tight_layout()

    fig.savefig(name)
    plt.close(fig)

In [23]:
from galsim_jax.convolution import convolve_kpsf

import matplotlib.pyplot as plt
# Model inputs: the rendered stamps and the complex Fourier-space PSF.
x = images
kpsf_real, kpsf_imag = psf_real, psf_img
kpsf = kpsf_real + 1j * kpsf_imag
# Per-galaxy noise level, shaped to broadcast against (batch, H, W, C) images.
std = std.reshape((-1, 1, 1, 1))

# Fixed keys for reproducibility; rng_1 is the `seed` given to the model.
rng_2 = jax.random.PRNGKey(1)
rng, rng_1 = jax.random.split(jax.random.PRNGKey(0))

# Image-shaped reconstruction of `x` (see the shape check below).
q = Autoencoder.apply(params, x=x, seed=rng_1)
z = q

# Saving the samples of the predicted images and their difference from the original images
# save_samples(z, 'Autoencoder_samples.png')

x : (3, 128, 128, 1)
Conv_in : (3, 128, 128, 1)
Down : (3, 32, 32, 4)
Mid : (3, 32, 32, 4)
Conv_out : (3, 32, 32, 2)
Moments shape : (3, 32, 32, 2)
Posterior : tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[3, 32, 32], event_shape=[1], dtype=float32)
Working with z of shape (1, 1, 32, 32) = 1024 dimensions.


In [24]:
# The reconstruction is image-shaped: (n_galaxies, 128, 128) without the channel axis.
q[..., 0].shape

(3, 128, 128)

In [25]:
# Fourier PSF without its channel axis: (n_galaxies, 128, 65).
kpsf[..., 0].shape

(3, 128, 65)

In [26]:
# Convolve each reconstruction with its galaxy's Fourier-space PSF. jnp keeps
# the result a JAX array (as in loss_fn) instead of converting it to a
# TensorFlow tensor and back.
p = jnp.expand_dims(jax.vmap(convolve_kpsf)(q[..., 0], kpsf[..., 0]), axis=-1)

z = p

# Saving the samples of the predicted images and their difference from the original images
# save_samples(z, 'Convolve_samples.png')

In [27]:
# Converting array into float32 so scale_diag matches the float32 images.
std = np.float32(std)

# Gaussian pixel likelihood around the PSF-convolved prediction. It is kept
# under its own name so `p` still holds the convolved images; rebinding `p`
# would pass the distribution itself as `loc` when the cell is re-run.
# NOTE(review): rng_1 also seeded Autoencoder.apply; JAX keys are normally
# split rather than reused.
p_dist = tfd.MultivariateNormalDiag(loc=p, scale_diag=std)

pred_model_2 = p_dist.sample(seed=rng_1)

# Saving the samples of the predicted images and their difference from the original images
# save_samples(pred_model_2, 'MNDist_samples.png')

In [28]:
# Must match `images` (n_galaxies, 128, 128, 1) for the pixel-wise comparisons below.
galaxies_data.shape

(3, 128, 128, 1)

In [29]:
# Reference prediction (pred_model_1): Gaussian pixel noise with the same
# per-galaxy std, centred on the images loaded from 'galaxies-pres'.
p = tfd.MultivariateNormalDiag(loc=galaxies_data, scale_diag=std)

pred_model_1 = p.sample(seed=rng_1)

# Saving the samples of the predicted images and their difference from the original images
# save_samples(pred_model_1, 'MNDist_samples_FR.png')

In [30]:
def norm_values_diff(orig, inf1, inf2, num_images=3):
    """Shared colour-scale limits for the residual maps of two predictions.

    For each of the first ``num_images`` images, the residuals
    ``inf - orig`` are averaged over the last (channel) axis.

    Returns
    -------
    [min, max]
        Smallest and largest residual value over both predictions and all
        inspected images.
    """
    min_values = []
    max_values = []

    for i in range(num_images):
        diff_1 = (inf1[i, ...] - orig[i, ...]).mean(axis=-1)
        diff_2 = (inf2[i, ...] - orig[i, ...]).mean(axis=-1)

        min_values.append(np.minimum(diff_1.min(), diff_2.min()))
        # Upper limit must take the LARGER maximum; np.minimum here clipped
        # the top of the colour bar.
        max_values.append(np.maximum(diff_1.max(), diff_2.max()))

    return [np.min(min_values), np.max(max_values)]

In [31]:
# Shared colour limits for the residual maps of both predictions.
min_value, max_value = norm_values_diff(images, pred_model_1, pred_model_2, num_images=3)

In [32]:
import matplotlib as mpl
def diff_inference(orig, inf1, inf2, name, vmin=None, vmax=None):
    """Save a 3x6 grid comparing two predictions with the original images.

    Columns are:
    - the original image;
    - ``inf1`` ("Reference model");
    - ``inf2`` ("State-of-the-art model");
    - the residuals ``inf1 - orig`` and ``inf2 - orig`` on a shared colour scale;
    - an empty column that leaves room for the colour bar.

    Channels are averaged before display.

    Parameters
    ----------
    orig, inf1, inf2 : array-like indexed as ``[image, ..., channel]``
        At least 3 images each (one per row).
    name : str
        Output path passed to ``savefig``.
    vmin, vmax : float, optional
        Colour limits of the residual panels. By default they are the min /
        max over the displayed residuals, computed here instead of read from
        the notebook-global ``min_value`` / ``max_value``, which may belong to
        a different pair of predictions.
    """
    num_rows, num_cols = 3, 6

    # Channel-averaged residual maps, one (inf1, inf2) pair per row.
    residuals = [
        ((inf1[i, ...] - orig[i, ...]).mean(axis=-1),
         (inf2[i, ...] - orig[i, ...]).mean(axis=-1))
        for i in range(num_rows)
    ]
    if vmin is None:
        vmin = min(float(r.min()) for pair in residuals for r in pair)
    if vmax is None:
        vmax = max(float(r.max()) for pair in residuals for r in pair)

    # Only one figure is created; a separate plt.figure() would leave an
    # empty, never-closed figure behind.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(18, 9))

    labels = (
        "Original images",
        "Reference model",
        "State-of-the-art model",
        "Difference original with \n reference model images",
        "Difference original with \n State-of-the-art images",
    )
    text_kw = dict(verticalalignment='top', fontsize=10, color="white", weight="bold")

    for i, row in enumerate(axes):
        panels = (orig[i, ...], inf1[i, ...], inf2[i, ...])
        for ax, img in zip(row[:3], panels):
            ax.imshow(img.mean(axis=-1))
        for ax, res in zip(row[3:5], residuals[i]):
            im = ax.imshow(res, vmin=vmin, vmax=vmax)
        for ax in row:
            ax.axis("off")
        if i == 0:
            for ax, label in zip(row, labels):
                ax.text(2, 2, label, **text_kw)

    fig.tight_layout()

    # Colour bar for the residual panels, drawn in the space of the last column.
    cb_ax = fig.add_axes([0.835, .015, .015, .965])
    cbar = fig.colorbar(im, cax=cb_ax, label="Pixel Value", orientation="vertical")
    cbar.ax.tick_params(labelsize=12)

    fig.savefig(name)
    plt.close(fig)

In [33]:
# Saving the differences between the predicted images of both models and their
# difference from the original images (reference = pred_model_1, new = pred_model_2).
diff_inference(images, pred_model_1, pred_model_2, "difference_best_model.png")

<Figure size 1800x900 with 0 Axes>

In [34]:
# Fixed PRNG keys: `rng` initialises the weights, `rng_2` is passed as `seed`.
rng = jax.random.PRNGKey(0)
rng_2 = jax.random.PRNGKey(1)

# `latent_dim` is passed to the model as `resolution`; it equals the 128-pixel
# side of the input stamps.
latent_dim = 128
act_fn = nn.gelu

# Dummy one-image batch used to initialise the parameters.
batch_autoenc = jnp.ones((1, 128, 128, 1))

# ch_mult=(1, 2, 4, 8): the printed shapes show a 16 x 16 latent map.
Autoencoder2 = AutoencoderKLModule(
    resolution=latent_dim,
    in_channels=1,
    out_ch=1,
    ch=1,
    ch_mult=(1, 2, 4, 8),
    num_res_blocks=2,
    z_channels=1,
    double_z=True,
    embed_dim=1,
    act_fn=act_fn,
)
params2 = Autoencoder2.init(rng, x=batch_autoenc, seed=rng_2)

x : (1, 128, 128, 1)
Conv_in : (1, 128, 128, 1)
Down : (1, 16, 16, 8)
Mid : (1, 16, 16, 8)
Conv_out : (1, 16, 16, 2)
Moments shape : (1, 16, 16, 2)
Posterior : tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[1, 16, 16], event_shape=[1], dtype=float32)
Working with z of shape (1, 1, 16, 16) = 256 dimensions.


In [35]:
# Checkpoint for the ch_mult=(1, 2, 4, 8) model. Rebinding `run` matters:
# load_checkpoint() builds the checkpoint path from the global run.id.
run = api.run("jonnyytorres/VAE-SD/5rxwcd2b")
artifact = api.artifact('jonnyytorres/VAE-SD/5rxwcd2b-checkpoint:best', type='model')
artifact_dir = artifact.download()

# Loading checkpoint for the best step
params2 = load_checkpoint("checkpoint.msgpack", params2)

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [36]:
# Generating fixed keys for JAX so the sample is reproducible.
rng, rng_2 = jax.random.PRNGKey(0), jax.random.PRNGKey(1)
rng, rng_1 = jax.random.split(rng)

# Reconstruction with the 16 x 16 bottleneck model.
# NOTE(review): rng_1 seeds both the model call and the pixel-noise sample.
q2 = Autoencoder2.apply(params2, x=x, seed=rng_1)

# Convolve with each galaxy's Fourier PSF; jnp keeps a JAX array (as in
# loss_fn) instead of round-tripping through TensorFlow.
p2 = jnp.expand_dims(jax.vmap(convolve_kpsf)(q2[..., 0], kpsf[..., 0]), axis=-1)
z2 = p2

p2 = tfd.MultivariateNormalDiag(loc=p2, scale_diag=std)

model_bottleneck_16 = p2.sample(seed=rng_1)

x : (3, 128, 128, 1)
Conv_in : (3, 128, 128, 1)
Down : (3, 16, 16, 8)
Mid : (3, 16, 16, 8)
Conv_out : (3, 16, 16, 2)
Moments shape : (3, 16, 16, 2)
Posterior : tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[3, 16, 16], event_shape=[1], dtype=float32)
Working with z of shape (1, 1, 16, 16) = 256 dimensions.


In [37]:
# Fixed PRNG keys: `rng` initialises the weights, `rng_2` is passed as `seed`.
rng = jax.random.PRNGKey(0)
rng_2 = jax.random.PRNGKey(1)

# `latent_dim` is passed to the model as `resolution`; it equals the 128-pixel
# side of the input stamps.
latent_dim = 128
act_fn = nn.gelu

# Dummy one-image batch used to initialise the parameters.
batch_autoenc = jnp.ones((1, 128, 128, 1))

# ch_mult=(1, 2, 4, 8, 16): the printed shapes show an 8 x 8 latent map.
Autoencoder3 = AutoencoderKLModule(
    resolution=latent_dim,
    in_channels=1,
    out_ch=1,
    ch=1,
    ch_mult=(1, 2, 4, 8, 16),
    num_res_blocks=2,
    z_channels=1,
    double_z=True,
    embed_dim=1,
    act_fn=act_fn,
)
params3 = Autoencoder3.init(rng, x=batch_autoenc, seed=rng_2)

x : (1, 128, 128, 1)
Conv_in : (1, 128, 128, 1)
Down : (1, 8, 8, 16)
Mid : (1, 8, 8, 16)
Conv_out : (1, 8, 8, 2)
Moments shape : (1, 8, 8, 2)
Posterior : tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[1, 8, 8], event_shape=[1], dtype=float32)
Working with z of shape (1, 1, 8, 8) = 64 dimensions.


In [38]:
# Checkpoint for the ch_mult=(1, 2, 4, 8, 16) model. Rebinding `run` matters:
# load_checkpoint() builds the checkpoint path from the global run.id.
run = api.run("jonnyytorres/VAE-SD/gakxodml")
artifact = api.artifact('jonnyytorres/VAE-SD/gakxodml-checkpoint:best', type='model')
artifact_dir = artifact.download()

# Loading checkpoint for the best step
params3 = load_checkpoint("checkpoint.msgpack", params3)

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [39]:
# Generating fixed keys for JAX so the sample is reproducible.
rng, rng_2 = jax.random.PRNGKey(0), jax.random.PRNGKey(1)
rng, rng_1 = jax.random.split(rng)

# Reconstruction with the 8 x 8 bottleneck model.
# NOTE(review): rng_1 seeds both the model call and the pixel-noise sample.
q3 = Autoencoder3.apply(params3, x=x, seed=rng_1)

# Convolve with each galaxy's Fourier PSF; jnp keeps a JAX array (as in
# loss_fn) instead of round-tripping through TensorFlow.
p3 = jnp.expand_dims(jax.vmap(convolve_kpsf)(q3[..., 0], kpsf[..., 0]), axis=-1)
z3 = p3

p3 = tfd.MultivariateNormalDiag(loc=p3, scale_diag=std)

model_bottleneck_8 = p3.sample(seed=rng_1)

x : (3, 128, 128, 1)
Conv_in : (3, 128, 128, 1)
Down : (3, 8, 8, 16)
Mid : (3, 8, 8, 16)
Conv_out : (3, 8, 8, 2)
Moments shape : (3, 8, 8, 2)
Posterior : tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[3, 8, 8], event_shape=[1], dtype=float32)
Working with z of shape (1, 1, 8, 8) = 64 dimensions.


In [40]:
def norm_values_two_diff(orig, inf1, inf2, inf3, num_images=3):
    """Shared colour-scale limits for the residual maps of three predictions.

    For each of the first ``num_images`` images, the residuals
    ``inf - orig`` are averaged over the last (channel) axis.

    Returns
    -------
    [min, max]
        Smallest and largest residual value over all three predictions and
        all inspected images.
    """
    min_values = []
    max_values = []

    for i in range(num_images):
        diffs = [(inf[i, ...] - orig[i, ...]).mean(axis=-1) for inf in (inf1, inf2, inf3)]

        min_values.append(np.min([d.min() for d in diffs]))
        # Upper limit must take the LARGEST maximum; np.minimum here clipped
        # the top of the colour bar.
        max_values.append(np.max([d.max() for d in diffs]))

    return [np.min(min_values), np.max(max_values)]

In [41]:
# Shared colour limits for the residual maps of the 32^2, 16^2 and 8^2 bottleneck models.
min_value, max_value = norm_values_two_diff(images, pred_model_2, model_bottleneck_16, model_bottleneck_8, num_images=3)

In [42]:
def diff_two_inference(orig, inf1, inf2, inf3, name, vmin=None, vmax=None):
    """Save a 3x8 grid comparing three predictions with the original images.

    Columns are:
    - the original image;
    - ``inf1``, ``inf2``, ``inf3`` (labelled bottleneck 32^2, 16^2, 8^2);
    - their residuals ``inf - orig`` on a shared colour scale;
    - an empty column that leaves room for the colour bar.

    Channels are averaged before display.

    Parameters
    ----------
    orig, inf1, inf2, inf3 : array-like indexed as ``[image, ..., channel]``
        At least 3 images each (one per row).
    name : str
        Output path passed to ``savefig``.
    vmin, vmax : float, optional
        Colour limits of the residual panels. By default they are the min /
        max over the displayed residuals, computed here instead of read from
        the notebook-global ``min_value`` / ``max_value``.
    """
    num_rows, num_cols = 3, 8
    predictions = (inf1, inf2, inf3)

    # Channel-averaged residual maps, one triple per row.
    residuals = [
        [(inf[i, ...] - orig[i, ...]).mean(axis=-1) for inf in predictions]
        for i in range(num_rows)
    ]
    if vmin is None:
        vmin = min(float(r.min()) for row in residuals for r in row)
    if vmax is None:
        vmax = max(float(r.max()) for row in residuals for r in row)

    # Only one figure is created; a separate plt.figure() would leave an
    # empty, never-closed figure behind.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(24, 9))

    labels = (
        "Original images",
        "Bottleneck size = $32^2$",
        "Bottleneck size = $16^2$",
        "Bottleneck size = $8^2$",
        "Difference original with \n$32^2$ Bottleneck size images",
        "Difference original with \n$16^2$ Bottleneck size images",
        "Difference original with \n$8^2$ Bottleneck size images",
    )
    text_kw = dict(verticalalignment='top', fontsize=10, color="white", weight="bold")

    for i, row in enumerate(axes):
        # Columns 0-3: original and the three predictions.
        panels = [orig[i, ...]] + [inf[i, ...] for inf in predictions]
        for ax, img in zip(row[:4], panels):
            ax.imshow(img.mean(axis=-1))
        # Columns 4-6: residuals of models 1, 2 and 3.
        for ax, res in zip(row[4:7], residuals[i]):
            im = ax.imshow(res, vmin=vmin, vmax=vmax)
        for ax in row:
            ax.axis("off")
        if i == 0:
            for ax, label in zip(row, labels):
                ax.text(2, 2, label, **text_kw)

    fig.tight_layout()

    # Colour bar for the residual panels, drawn in the space of the last column.
    cb_ax = fig.add_axes([0.885, .024, .015, .95])
    cbar = fig.colorbar(im, cax=cb_ax, label="Pixel Value", orientation="vertical")
    cbar.ax.tick_params(labelsize=12)

    fig.savefig(name)
    plt.close(fig)

In [43]:
# Saving the predictions of the 32^2, 16^2 and 8^2 bottleneck models and their
# differences from the original images.
diff_two_inference(images, pred_model_2, model_bottleneck_16, model_bottleneck_8,  "difference_depth_model_presentation.png")

<Figure size 2400x900 with 0 Axes>

In [60]:
def norm_values_one_diff(orig, inf1, num_images=3):
    """Colour-scale limits for the residual maps of a single prediction.

    The residual ``inf1 - orig`` of each of the first ``num_images`` images is
    averaged over the last (channel) axis.

    Returns
    -------
    [min, max]
        Smallest and largest residual value over the inspected images.
    """
    residuals = [(inf1[i, ...] - orig[i, ...]).mean(axis=-1) for i in range(num_images)]
    lows = [r.min() for r in residuals]
    highs = [r.max() for r in residuals]
    return [np.min(lows), np.max(highs)]

In [61]:
# Colour limits for the residual maps of the reference prediction alone.
min_value, max_value = norm_values_one_diff(images, pred_model_1, num_images=3)

In [64]:
def diff_one_inference(orig, inf1, name, vmin=None, vmax=None):
    """Save a 3x4 grid comparing one prediction with the original images.

    Columns are:
    - the original image;
    - ``inf1`` ("Previous model");
    - the residual ``inf1 - orig``;
    - an empty column that leaves room for the colour bar.

    Channels are averaged before display.

    Parameters
    ----------
    orig, inf1 : array-like indexed as ``[image, ..., channel]``
        At least 3 images each (one per row).
    name : str
        Output path passed to ``savefig``.
    vmin, vmax : float, optional
        Colour limits of the residual panels. By default they are the min /
        max over the displayed residuals, computed here instead of read from
        the notebook-global ``min_value`` / ``max_value``.
    """
    num_rows, num_cols = 3, 4

    residuals = [(inf1[i, ...] - orig[i, ...]).mean(axis=-1) for i in range(num_rows)]
    if vmin is None:
        vmin = min(float(r.min()) for r in residuals)
    if vmax is None:
        vmax = max(float(r.max()) for r in residuals)

    # Only one figure is created; a separate plt.figure() would leave an
    # empty, never-closed figure behind.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 9))

    titles = ('Original images', 'Previous model', 'Difference previous model')

    for i, (ax_orig, ax_pred, ax_diff, ax_pad) in enumerate(axes):
        ax_orig.imshow(orig[i, ...].mean(axis=-1))
        ax_pred.imshow(inf1[i, ...].mean(axis=-1))
        im = ax_diff.imshow(residuals[i], vmin=vmin, vmax=vmax)
        for ax in (ax_orig, ax_pred, ax_diff, ax_pad):
            ax.axis("off")
        if i == 0:
            for ax, title in zip((ax_orig, ax_pred, ax_diff), titles):
                ax.set_title(title, fontsize=10)

    fig.tight_layout()

    # Colour bar for the residual column, drawn in the space of the last column.
    cb_ax = fig.add_axes([0.755, .02, .015, .94])
    cbar = fig.colorbar(im, cax=cb_ax, label="Pixel Value", orientation="vertical")
    cbar.ax.tick_params(labelsize=12)

    fig.savefig(name)
    plt.close(fig)

In [65]:
# Saving the predicted images of the previous (reference) model and their
# differences from the original images.
diff_one_inference(images, pred_model_1, "difference_prev_model.png")

<Figure size 1200x900 with 0 Axes>

In [48]:
def diff_one_inference_2(orig, inf1, name, num_rows=5):
    """Save a grid of (original, prediction, residual) rows for one model.

    Parameters
    ----------
    orig, inf1 : array-like indexed as ``[image, ..., channel]``
        Original and predicted images; channels are averaged for display.
    name : str
        Output path passed to ``savefig``.
    num_rows : int, optional
        Maximum number of rows (default 5). It is clamped to the number of
        available images: the fixed 5 rows used to raise IndexError when fewer
        than 5 images were passed (this notebook has 3).
    """
    num_rows = min(num_rows, len(orig), len(inf1))
    num_cols = 3

    # 2.1 inches per row reproduces the original (6, 10.5) size for 5 rows;
    # squeeze=False keeps `axes` 2-D even for a single row.
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(6, 2.1 * num_rows), squeeze=False
    )

    titles = ('Original images', 'Prediction model', 'Difference model')

    for i, (ax1, ax2, ax3) in enumerate(axes):
        orig_mean = orig[i, ...].mean(axis=-1)
        pred_mean = inf1[i, ...].mean(axis=-1)

        ax1.imshow(orig_mean)
        ax2.imshow(pred_mean)
        # Residual on its own (per-panel) colour scale.
        ax3.imshow(pred_mean - orig_mean)
        for ax in (ax1, ax2, ax3):
            ax.axis("off")

        if i == 0:
            for ax, title in zip((ax1, ax2, ax3), titles):
                ax.set_title(title, fontsize=8)

    fig.suptitle(
        "Comparison between original and predicted images for the prediction model", fontsize=10, y=0.99
    )
    fig.tight_layout()

    fig.savefig(name)
    plt.close(fig)

In [None]:
# Saving the predicted images of the new model and their differences from the
# original images.
# NOTE(review): diff_one_inference_2 lays out 5 rows while `images` holds 3
# galaxies, so orig[3] raises IndexError unless the function clamps its rows.
diff_one_inference_2(images, pred_model_2, "difference_new_model.png")

In [None]:
import matplotlib.pyplot as plt

# Borderless 3x3-inch figure showing only the third predicted galaxy.
fig = plt.figure(frameon=False)
fig.set_size_inches(3, 3)

# Axes filling the whole figure, with no frame or ticks.
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
ax.set_axis_off()
ax.imshow(pred_model_2[2, ...].mean(axis=-1), aspect='auto')

fig.savefig('Pred_galaxy_presentation.png')
plt.close(fig)

In [None]:
# Same configuration as the first (1, 2, 4) model cell.
# Fixed PRNG keys: `rng` initialises the weights, `rng_2` is passed as `seed`.
rng = jax.random.PRNGKey(0)
rng_2 = jax.random.PRNGKey(1)

# `latent_dim` is passed to the model as `resolution`; it equals the 128-pixel
# side of the input stamps.
latent_dim = 128
act_fn = nn.gelu

# Dummy one-image batch used to initialise the parameters.
batch_autoenc = jnp.ones((1, 128, 128, 1))

Autoencoder = AutoencoderKLModule(
    resolution=latent_dim,
    in_channels=1,
    out_ch=1,
    ch=1,
    ch_mult=(1, 2, 4),
    num_res_blocks=2,
    z_channels=1,
    double_z=True,
    embed_dim=1,
    act_fn=act_fn,
)
params = Autoencoder.init(rng, x=batch_autoenc, seed=rng_2)

In [None]:
# Re-fetch the (1, 2, 4) checkpoint. Rebinding `run` matters: load_checkpoint()
# builds the checkpoint path from the global run.id.
run = api.run("jonnyytorres/VAE-SD/iebszvhp")
artifact = api.artifact('jonnyytorres/VAE-SD/iebszvhp-checkpoint:best', type='model')
artifact_dir = artifact.download()

# Loading checkpoint for the best step
params = load_checkpoint("checkpoint.msgpack", params)

In [None]:
# Same fixed keys as the earlier prediction cells; rng_1 is the model's `seed`.
rng_2 = jax.random.PRNGKey(1)
rng, rng_1 = jax.random.split(jax.random.PRNGKey(0))

# Reconstructions of `x` from the restored (1, 2, 4) model.
q = Autoencoder.apply(params, x=x, seed=rng_1)
z = q

# PNG grid of the raw reconstructions.
save_samples(z, 'Autoencoder_samples_best.png')

In [None]:
# Convolve each reconstruction with its Fourier-space PSF. jnp keeps the result
# a JAX array (as in loss_fn) instead of converting it to a TensorFlow tensor.
p = jnp.expand_dims(jax.vmap(convolve_kpsf)(q[..., 0], kpsf[..., 0]), axis=-1)

z = p

# PNG grid of the PSF-convolved reconstructions.
save_samples(z, 'Convolve_samples_best.png')

In [None]:
# Gaussian pixel likelihood around the convolved prediction. It is kept under
# its own name so `p` still holds the convolved images; rebinding `p` would
# pass the distribution itself as `loc` when the cell is re-run.
p_dist = tfd.MultivariateNormalDiag(loc=p, scale_diag=std)

pred_model_best = p_dist.sample(seed=rng_1)

# PNG grid of the noisy samples.
save_samples(pred_model_best, 'MNDist_samples_best.png')

In [None]:
# Recompute the shared residual colour limits for THIS pair of predictions.
# In linear order, the min_value / max_value globals still held the limits from
# norm_values_one_diff(images, pred_model_1), which would mis-scale these panels.
min_value, max_value = norm_values_diff(images, pred_model_1, pred_model_best, num_images=3)
diff_inference(images, pred_model_1, pred_model_best, "difference_pred_models_best.png")

In [None]:
# Self-contained init: every input is set here instead of being inherited from
# whichever cell ran last (the preceding prediction cell left `rng` as a split
# key). Values mirror the other initialisation cells.
rng, rng_2 = jax.random.PRNGKey(0), jax.random.PRNGKey(1)
batch_autoenc = jnp.ones((1, 128, 128, 1))
latent_dim = 128
act_fn = nn.gelu

# NOTE(review): these parameters are freshly initialised and never restored
# from a checkpoint, so the cells below run an UNTRAINED (1, 2, 4, 8, 16)
# model. If the trained weights were intended, load them as done for params3
# (run "gakxodml").
Autoencoder = AutoencoderKLModule(
    ch_mult=(1, 2, 4, 8, 16),
    num_res_blocks=2,
    double_z=True,
    z_channels=1,
    resolution=latent_dim,
    in_channels=1,
    out_ch=1,
    ch=1,
    embed_dim=1,
    act_fn=act_fn,
)

params = Autoencoder.init(rng, x=batch_autoenc, seed=rng_2)

In [None]:
def save_samples(z, batch, name=None):
    """Save up to 12 channel-averaged images from ``z`` as a 3x4 grid.

    This redefinition shadows the earlier two-argument ``save_samples``.

    Parameters
    ----------
    z : array-like indexed as ``[image, ..., channel]``
        Images to plot; only the first 12 are shown.
    batch : any
        Not used for plotting; accepted so ``save_samples(z, batch, name)``
        works. When ``name`` is omitted, this argument is taken as the file
        name. That keeps the two-argument calls ``save_samples(z, name)`` from
        earlier cells working after this cell has run, instead of raising a
        TypeError.
    name : str, optional
        Output path passed to ``savefig``.
    """
    if name is None:
        name = batch

    num_rows, num_cols = 3, 4
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(10, 8))

    # Hide every panel's frame first, so panels beyond len(z) stay blank.
    for ax in axes.flat:
        ax.axis("off")
    for ax, z_img in zip(axes.flat, z):
        ax.imshow(tf.math.reduce_mean(z_img, axis=-1))

    fig.suptitle("Samples of predicted galaxies", fontsize=16)
    fig.tight_layout()

    fig.savefig(name)
    plt.close(fig)

In [None]:
# Fixed keys, as in the other prediction cells; rng_1 is the model's `seed`.
rng_2 = jax.random.PRNGKey(1)
rng, rng_1 = jax.random.split(jax.random.PRNGKey(0))

# Reconstruct the 'galaxies-pres' images with the current Autoencoder/params.
# NOTE(review): in linear order these are the freshly initialised
# (1, 2, 4, 8, 16) parameters, i.e. an untrained model.
q = Autoencoder.apply(params, x=galaxies_data, seed=rng_1)
z = q

save_samples(z, galaxies_data, 'Autoencoder_FR_samples.png')

In [None]:
def loss_fn(params, rng_key, batch, reg_term):
    """ELBO-style loss of the notebook-global ``Autoencoder``.

    Parameters
    ----------
    params : pytree
        Parameters for ``Autoencoder.apply``.
    rng_key : jax PRNG key
        Passed to ``Autoencoder.apply`` as ``seed``.
    batch : dict
        Must provide "image", "kpsf_real", "kpsf_imag" and "noise_std".
        Images are assumed to be (batch, H, W, 1); TODO confirm for other inputs.
    reg_term : float
        Weight applied to the KL term.

    Returns
    -------
    (loss, nll)
        ``loss = -mean(log_likelihood - reg_term * kl)`` and
        ``nll = -mean(log_likelihood)``, shaped for
        ``jax.value_and_grad(..., has_aux=True)``.
    """
    x = batch["image"]
    kpsf = batch["kpsf_real"] + 1j * batch["kpsf_imag"]
    # Per-galaxy noise level broadcast over (H, W, C).
    std = batch["noise_std"].reshape((-1, 1, 1, 1))

    # Autoencode the batch, then convolve with each galaxy's Fourier PSF.
    q = Autoencoder.apply(params, x=x, seed=rng_key)
    p = jnp.expand_dims(jax.vmap(convolve_kpsf)(q[..., 0], kpsf[..., 0]), axis=-1)
    p = tfd.MultivariateNormalDiag(loc=p, scale_diag=std)

    # Standard-normal reference with the same per-image shape as `x` (this was
    # hard-coded to (1, 128, 128, 1)); the leading 1 broadcasts over the batch.
    prior = tfd.MultivariateNormalDiag(jnp.zeros((1,) + tuple(x.shape[1:])))
    # NOTE(review): this KL compares the pixel likelihood `p` with a standard
    # normal over pixels, not a latent posterior with its prior as in a usual
    # VAE ELBO. Autoencoder.apply returns an image-shaped array here, so the
    # posterior is not reachable from this function. Confirm this is intended.
    kl = tfd.kl_divergence(p, prior)

    log_likelihood = p.log_prob(x)

    # Regularisation factor applied to the KL term.
    elbo = log_likelihood - reg_term * kl

    loss = -jnp.mean(elbo)
    return loss, -jnp.mean(log_likelihood)

# Verifying that jax.value_and_grad works through loss_fn.
# `batch_im` is not defined in any cell of this notebook, so assemble it from
# the arrays prepared above, using the keys loss_fn reads.
batch_im = {
    "image": images,
    "kpsf_real": psf_real,
    "kpsf_imag": psf_img,
    "noise_std": std,
}

kl_reg_w = 1e-3
(loss, log_likelihood), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, rng, batch_im, kl_reg_w)

In [None]:
# Scalar loss returned by loss_fn.
loss

In [None]:
# Auxiliary output: negative mean log-likelihood (loss_fn returns -mean(log_likelihood)).
log_likelihood

In [None]:
# Trailing ';' suppresses display of the (large) gradient pytree.
grads;