In [21]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from censai import PhysicalModelv2, RIMSourceUnetv2, PowerSpectrum
from censai.models import UnetModelv2, RayTracer
from censai.utils import nullwriter, rim_residual_plot as residual_plot, plot_to_image
from censai.data.lenses_tng_v3 import decode_train, decode_physical_model_info
import os, glob, re, json
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable
from argparse import Namespace
import math, json
import matplotlib.pylab as pylab
import h5py
from tqdm import tqdm

# Widen the pandas display limits so large tables are shown in full.
_display_options = {
    'display.max_rows': 500,
    'display.max_columns': 500,
    'display.width': 2000,
    'display.max_colwidth': 200,
}
for _option, _value in _display_options.items():
    pd.set_option(_option, _value)

# Sub-directories of the project root given by the CENSAI_PATH environment variable.
result_dir, data_path, models_path = (
    os.path.join(os.getenv("CENSAI_PATH"), subdir)
    for subdir in ("results", "data", "models")
)

# Every listed text element of a figure gets the same 'x-large' font size.
# ('figure.figsize': (10, 10) is intentionally left out, as in the disabled entry
# of the previous version of this cell.)
params = dict.fromkeys(
    (
        'legend.fontsize',
        'axes.labelsize',
        'axes.titlesize',
        'xtick.labelsize',
        'ytick.labelsize',
    ),
    'x-large',
)
pylab.rcParams.update(params)


In [17]:
# Name of the trained model; it is also the name of its directory under "models".
model = "RIMSource128hstv3_control_002_A0_L3_FLM1.0_211108220845"

# Load the arguments stored in script_params.json inside the model directory and
# expose them as attributes (args.steps, args.filters, ...).
checkpoints_dir = os.path.join(os.getenv("CENSAI_PATH"), "models", model)
with open(os.path.join(checkpoints_dir, "script_params.json"), "r") as f:
    args = Namespace(**json.load(f))

# Keep only the directory name of the first validation dataset and look it up
# under the local "data" directory. normpath drops a trailing separator, which
# would otherwise make the extracted name empty.
_dataset = os.path.basename(os.path.normpath(args.val_datasets[0]))
dataset_dir = os.path.join(os.getenv("CENSAI_PATH"), "data", _dataset)

# glob returns paths in arbitrary order; sort them so that the order of the
# examples (and therefore any positional skip/take split) is reproducible.
record_paths = sorted(glob.glob(os.path.join(dataset_dir, "*.tfrecords")))
if not record_paths:
    raise FileNotFoundError(f"No .tfrecords files found in {dataset_dir}")

files = tf.data.Dataset.from_tensor_slices(record_paths)
dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type="GZIP"),
                           block_length=1, num_parallel_calls=tf.data.AUTOTUNE)

# Decode the physical model description from the first record only.
# NOTE(review): this presumes the description is the same for every record -- confirm.
physical_params = next(iter(dataset.map(decode_physical_model_info)))
dataset = dataset.map(decode_train)

# The number of examples is read from dataset_size.txt in the dataset directory.
total_items = int(np.loadtxt(os.path.join(dataset_dir, "dataset_size.txt")))
print(total_items)

10138


In [18]:
# Constructor argument of PhysicalModelv2 -> key of the decoded physical_params record.
_phys_keys = {
    "pixels": "pixels",
    "kappa_pixels": "kappa pixels",
    "src_pixels": "src pixels",
    "image_fov": "image fov",
    "kappa_fov": "kappa fov",
    "src_fov": "source fov",
}
# Each stored value is converted with .numpy() before being handed to the model.
phys = PhysicalModelv2(
    **{argument: physical_params[key].numpy() for argument, key in _phys_keys.items()},
    method="fft"
)

# Every constructor argument of the U-Net is read from the saved training
# arguments (args) under the same name.
_unet_argument_names = (
    "filters",
    "filter_scaling",
    "kernel_size",
    "layers",
    "block_conv_layers",
    "strides",
    "bottleneck_kernel_size",
    "resampling_kernel_size",
    "input_kernel_size",
    "gru_kernel_size",
    "upsampling_interpolation",
    "kernel_l2_amp",
    "bias_l2_amp",
    "kernel_l1_amp",
    "bias_l1_amp",
    "activation",
    "initializer",
    "batch_norm",
    "dropout_rate",
)
unet = UnetModelv2(**{name: getattr(args, name) for name in _unet_argument_names})
# The RIM wraps the physical model and the U-Net; its remaining settings are
# taken from the saved training arguments (args) under the same names.
_rim_settings = {
    name: getattr(args, name)
    for name in ("steps", "adam", "source_link", "source_init", "flux_lagrange_multiplier")
}
rim = RIMSourceUnetv2(physical_model=phys, unet=unet, **_rim_settings)
# Restore the U-Net weights from the most recent checkpoint found in checkpoints_dir.
ckpt = tf.train.Checkpoint(net=rim.unet)
checkpoint_manager = tf.train.CheckpointManager(ckpt, checkpoints_dir, max_to_keep=args.max_to_keep)
# restore(None) does not load anything and does not complain, so without this
# check a missing checkpoint would silently leave the network with its initial weights.
if checkpoint_manager.latest_checkpoint is None:
    raise FileNotFoundError(f"No checkpoint found in {checkpoints_dir}")
# Only the network is tracked by this Checkpoint object; expect_partial() silences
# the warnings about checkpoint values that are not restored.
checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint).expect_partial()

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x2b131d5a8310>

In [27]:
# Run the RIM over the whole dataset and store inputs, predictions and
# cross-correlation coefficients in HDF5 files, examples_per_shard examples per file.
k_bins = 40
examples_per_shard = 1000
batch_size = 10


output_dir = os.path.join(os.getenv("CENSAI_PATH"), "results", model + "_" + _dataset)
# makedirs also creates a missing "results" parent and is a no-op when the directory exists.
os.makedirs(output_dir, exist_ok=True)
ps_lens = PowerSpectrum(bins=k_bins, pixels=physical_params["pixels"].numpy())
ps_x = PowerSpectrum(bins=k_bins, pixels=physical_params["src pixels"].numpy())

# Ceiling division: a final, partially filled shard holds the remainder.
shards = (total_items + examples_per_shard - 1) // examples_per_shard
k = 0  # running example index over all shards, used in the HDF5 group names
for shard in tqdm(range(shards)):
    data = (
        dataset
        .skip(shard * examples_per_shard)
        .take(examples_per_shard)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    # The context manager closes the file even if a prediction step raises.
    with h5py.File(os.path.join(output_dir, f"predictions_{shard:02d}.h5"), 'w') as hf:
        for lens, source, kappa, noise_rms, psf in data:
            source_pred, chi_squared = rim.predict(lens, kappa, noise_rms, psf)
            # The last entry of source_pred is used as the final source estimate.
            lens_pred = phys.forward(source_pred[-1], kappa, psf)

            # remove channel dimension because power spectrum expect only [batch, pixels, pixels] shaped tensor.
            _ps_lens = ps_lens.cross_correlation_coefficient(lens[..., 0], lens_pred[..., 0])
            _ps_source = ps_x.cross_correlation_coefficient(source[..., 0], source_pred[-1][..., 0])

            # The last batch of a shard can be smaller than batch_size. Use a separate
            # name so the configured batch_size is not overwritten for the next shard.
            examples_in_batch = lens.shape[0]
            for b in range(examples_in_batch):
                g = hf.create_group(f'data_{k:d}')
                g.create_dataset("lens",        data=lens[b])
                g.create_dataset("source",      data=source[b])
                g.create_dataset("kappa",       data=kappa[b])
                g.create_dataset("lens_pred",   data=lens_pred[b])
                # Axis 1 is the batch axis here; axis 0 is kept whole
                # (presumably one entry per RIM step -- TODO confirm).
                g.create_dataset("source_pred", data=source_pred[:, b])
                g.create_dataset("chi_squared", data=chi_squared[:, b])
                g.create_dataset("ps_lens",     data=_ps_lens[b])
                g.create_dataset("ps_source",   data=_ps_source[b])
                k += 1

100%|██████████| 11/11 [10:47<00:00, 58.83s/it]


In [28]:
# Display the directory that holds the predictions_XX.h5 files.
output_dir

'/home/aadam/scratch/Censai/results/RIMSource128hstv3_control_002_A0_L3_FLM1.0_211108220845_lenses128hst_TNG_rau_200k_control_denoised_validated_val'