In [26]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from censai import RIMUnet, PhysicalModel, RIM, RIMSharedUnet, PowerSpectrum
from censai.models import VAE, VAESecondStage, SharedUnetModel, UnetModel
from censai.utils import rim_residual_plot, update
from censai.data.lenses_tng_v2 import decode_train, decode_physical_model_info
from censai.definitions import log_10
import os, glob, re, json
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable
from argparse import Namespace
import math, json
import matplotlib.pylab as pylab
import h5py
from tqdm import tqdm

# Display options so wide result tables are shown in full.
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 2000)
pd.set_option('display.max_colwidth', 200)

# All inputs/outputs live under $CENSAI_PATH; fail fast with a clear message instead of the
# obscure TypeError os.path.join raises when the variable is missing (os.getenv returns None).
_censai_path = os.getenv("CENSAI_PATH")
if _censai_path is None:
    raise RuntimeError(
        "CENSAI_PATH environment variable is not set; it must point to the directory "
        "containing the results/, data/ and models/ folders."
    )
result_dir = os.path.join(_censai_path, "results")
data_path = os.path.join(_censai_path, "data")
models_path = os.path.join(_censai_path, "models")

# Larger fonts for every matplotlib figure produced by this notebook.
params = {
    'legend.fontsize': 'x-large',
    'axes.labelsize': 'x-large',
    'axes.titlesize': 'x-large',
    'xtick.labelsize': 'x-large',
    'ytick.labelsize': 'x-large',
}
pylab.rcParams.update(params)


In [13]:
# Model to evaluate and the dataset to evaluate it on.
# Previously evaluated candidates, kept for reference:
# model = "RIMSU512_k128_NIEs_019_TI1000_32_B5_210918011431"
# model = "RIMSU512_k128_NIEs_017_TI1000_16_B5_210918010833"
# model = "RIMSU512_k128_NIE2nsvdO_033_TS10_F16_L5_IK11_NLrelu_al0.04_GAplus_42_B10_lr0.0005_dr0.8_ds5000_TWquadratic_210923032150"
# dataset = "lenses512_k128_NIE_10k_verydiffuse"
model = "RIMSU128hst_control_012_TS15_F16_211020033949"
_dataset = "lenses128hst_TNG_VAE_200k_control_validated_val"

checkpoints_dir = os.path.join(models_path, model)

# Hyperparameters saved by the training script; needed to rebuild the same architecture.
with open(os.path.join(checkpoints_dir, "script_params.json"), "r") as f:
    args = Namespace(**json.load(f))

# Sort the record files: glob order depends on the filesystem, and the order decides which
# examples fall into which output shard later on.
_dataset_dir = os.path.join(data_path, _dataset)
record_files = sorted(glob.glob(os.path.join(_dataset_dir, "*.tfrecords")))
if not record_files:
    raise FileNotFoundError(f"No .tfrecords files found in {_dataset_dir}")

# Read concurrently from multiple records.
files = tf.data.Dataset.from_tensor_slices(record_files)
dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type="GZIP"),
                           block_length=1, num_parallel_calls=tf.data.AUTOTUNE)

# Read off global parameters (grid sizes, fields of view, noise, PSF) from the first example.
physical_params = next(iter(dataset.map(decode_physical_model_info)), None)
if physical_params is None:
    raise ValueError(f"Dataset {_dataset_dir} contains no examples")
dataset = dataset.map(decode_train)


In [7]:

# Forward lensing model, configured from the dataset's own physical parameters so predictions
# are simulated on the same grids/noise/PSF as the data.
phys = PhysicalModel(
    pixels=physical_params["pixels"].numpy(),
    kappa_pixels=physical_params["kappa pixels"].numpy(),
    src_pixels=physical_params["src pixels"].numpy(),
    image_fov=physical_params["image fov"].numpy(),
    kappa_fov=physical_params["kappa fov"].numpy(),
    src_fov=physical_params["source fov"].numpy(),
    method=args.forward_method,
    noise_rms=physical_params["noise rms"].numpy(),
    psf_sigma=physical_params["psf sigma"].numpy()
)

# Rebuild the network with the exact hyperparameters used at training time (script_params.json),
# otherwise the checkpoint variables would not match.
unet = SharedUnetModel(
    filters=args.filters,
    filter_scaling=args.filter_scaling,
    kernel_size=args.kernel_size,
    layers=args.layers,
    block_conv_layers=args.block_conv_layers,
    strides=args.strides,
    bottleneck_kernel_size=args.bottleneck_kernel_size,
    bottleneck_filters=args.bottleneck_filters,
    resampling_kernel_size=args.resampling_kernel_size,
    input_kernel_size=args.input_kernel_size,
    gru_kernel_size=args.gru_kernel_size,
    upsampling_interpolation=args.upsampling_interpolation,
    kernel_l2_amp=args.kernel_l2_amp,
    bias_l2_amp=args.bias_l2_amp,
    kernel_l1_amp=args.kernel_l1_amp,
    bias_l1_amp=args.bias_l1_amp,
    activation=args.activation,
    alpha=args.alpha,
    initializer=args.initializer,
    batch_norm=args.batch_norm,
    dropout_rate=args.dropout_rate
)
rim = RIMSharedUnet(
    physical_model=phys,
    unet=unet,
    steps=args.steps,
    adam=args.adam,
    kappalog=args.kappalog,
    source_link=args.source_link,
    kappa_normalize=args.kappa_normalize,
    kappa_init=args.kappa_init,
    source_init=args.source_init
)

ckpt = tf.train.Checkpoint(net=rim.unet)
checkpoint_manager = tf.train.CheckpointManager(ckpt, checkpoints_dir, max_to_keep=args.max_to_keep)
# latest_checkpoint is None when the directory holds no checkpoint; restore(None) would then
# silently keep the random initial weights and every prediction below would be meaningless.
if checkpoint_manager.latest_checkpoint is None:
    raise FileNotFoundError(f"No checkpoint found in {checkpoints_dir}")
# expect_partial(): only the network is restored here, so silence warnings about other saved
# objects (presumably optimizer state from training) that are intentionally left unrestored.
checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint).expect_partial()

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x2b3859ac4040>

In [32]:
# Evaluation settings.
k_bins = 40                # number of bins of the cross-correlation coefficient spectra
total_items = 5000         # number of dataset examples to evaluate
examples_per_shard = 1000  # examples written to each predictions_XX.h5 file
batch_size = 10            # examples per rim.predict call

# Predictions go to $CENSAI_PATH/results/<model>_<dataset>/predictions_XX.h5.
output_dir = os.path.join(result_dir, model + "_" + _dataset)
os.makedirs(output_dir, exist_ok=True)

# One spectrum helper per grid: lens images, kappa maps and sources each have their own pixel count.
ps_lens = PowerSpectrum(bins=k_bins, pixels=physical_params["pixels"].numpy())
ps_x = PowerSpectrum(bins=k_bins, pixels=physical_params["kappa pixels"].numpy())
ps_source = PowerSpectrum(bins=k_bins, pixels=physical_params["src pixels"].numpy())

# Ceiling division: a partial final shard holds the remainder.
shards = (total_items + examples_per_shard - 1) // examples_per_shard
k = 0  # running example index across all shards; names the HDF5 groups data_<k>
for shard in tqdm(range(shards)):
    start = shard * examples_per_shard
    # Never read past total_items, even in the final (possibly partial) shard.
    shard_size = min(examples_per_shard, total_items - start)
    data = (
        dataset.skip(start)
        .take(shard_size)
        .batch(batch_size)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
    # `with` guarantees the file is closed even if prediction raises mid-shard.
    with h5py.File(os.path.join(output_dir, f"predictions_{shard:02d}.h5"), 'w') as hf:
        for lens, source, kappa in data:
            source_pred, kappa_pred, chi_squared = rim.predict(lens)
            # Re-simulate the lens from the final RIM estimates.
            lens_pred = phys.forward(source_pred[-1], kappa_pred[-1])
            lam = phys.lagrange_multiplier(y_true=lens, y_pred=lens_pred)

            # Remove channel dimension because power spectrum expects only [batch, pixels, pixels] shaped tensors.
            _ps_lens = ps_lens.cross_correlation_coefficient(lens[..., 0], lens_pred[..., 0])
            _ps_kappa = ps_x.cross_correlation_coefficient(log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0])
            _ps_source = ps_source.cross_correlation_coefficient(source[..., 0], source_pred[-1][..., 0])

            # Convert to numpy once per batch instead of slicing tensors example by example.
            # Batch-first quantities, indexed as [b].
            per_example = {
                "lens": lens,
                "source": source,
                "kappa": kappa,
                "lens_pred": lens_pred,
                "lambda": lam,
                "ps_lens": _ps_lens,
                "ps_kappa": _ps_kappa,
                "ps_source": _ps_source,
            }
            # Quantities stacked over RIM steps along axis 0 (source_pred[-1] is the final estimate),
            # indexed as [:, b] to keep the whole trajectory of each example.
            per_step = {
                "source_pred": source_pred,
                "kappa_pred": kappa_pred,
                "chi_squared": chi_squared,
            }
            per_example = {name: np.asarray(value) for name, value in per_example.items()}
            per_step = {name: np.asarray(value) for name, value in per_step.items()}

            # Separate name on purpose: the last batch of a shard may be smaller, and overwriting the
            # global `batch_size` would shrink the batching of every following shard.
            n_examples = lens.shape[0]
            for b in range(n_examples):
                g = hf.create_group(f'data_{k:d}')
                for name, arr in per_example.items():
                    g.create_dataset(name, data=arr[b])
                for name, arr in per_step.items():
                    g.create_dataset(name, data=arr[:, b])
                k += 1

 20%|██        | 1/5 [02:51<11:27, 171.87s/it]

Object was never used (type <class 'tensorflow.python.ops.tensor_array_ops.TensorArray'>):
<tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x2b389fd54b20>
If you want to mark it as used call its "mark_used()" method.
It was originally created here:
  File "/home/aadam/environments/censai3.8/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3458, in run_code
    self.showtraceback(running_compiled_code=True)  File "<ipython-input-28-433ba2cef1a5>", line 19, in <module>
    source_pred, kappa_pred, chi_squared = rim.predict(lens)  File "/lustre04/scratch/aadam/Censai/censai/rim_shared_unet.py", line 196, in predict
    source, kappa, states = self.time_step(source, kappa, source_grad, kappa_grad, states)  File "/home/aadam/environments/censai3.8/lib/python3.8/site-packages/tensorflow/python/util/tf_should_use.py", line 247, in wrapped


Object was never used (type <class 'tensorflow.python.ops.tensor_array_ops.TensorArray'>):
<tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x2b389fd54b20>
If you want to mark it as used call its "mark_used()" method.
It was originally created here:
  File "/home/aadam/environments/censai3.8/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3458, in run_code
    self.showtraceback(running_compiled_code=True)  File "<ipython-input-28-433ba2cef1a5>", line 19, in <module>
    source_pred, kappa_pred, chi_squared = rim.predict(lens)  File "/lustre04/scratch/aadam/Censai/censai/rim_shared_unet.py", line 196, in predict
    source, kappa, states = self.time_step(source, kappa, source_grad, kappa_grad, states)  File "/home/aadam/environments/censai3.8/lib/python3.8/site-packages/tensorflow/python/util/tf_should_use.py", line 247, in wrapped


Object was never used (type <class 'tensorflow.python.ops.tensor_array_ops.TensorArray'>):
<tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x2b389fd54eb0>
If you want to mark it as used call its "mark_used()" method.
It was originally created here:
  File "/home/aadam/environments/censai3.8/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3458, in run_code
    self.showtraceback(running_compiled_code=True)  File "<ipython-input-28-433ba2cef1a5>", line 19, in <module>
    source_pred, kappa_pred, chi_squared = rim.predict(lens)  File "/lustre04/scratch/aadam/Censai/censai/rim_shared_unet.py", line 196, in predict
    source, kappa, states = self.time_step(source, kappa, source_grad, kappa_grad, states)  File "/home/aadam/environments/censai3.8/lib/python3.8/site-packages/tensorflow/python/util/tf_should_use.py", line 247, in wrapped


Object was never used (type <class 'tensorflow.python.ops.tensor_array_ops.TensorArray'>):
<tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x2b389fd54eb0>
If you want to mark it as used call its "mark_used()" method.
It was originally created here:
  File "/home/aadam/environments/censai3.8/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3458, in run_code
    self.showtraceback(running_compiled_code=True)  File "<ipython-input-28-433ba2cef1a5>", line 19, in <module>
    source_pred, kappa_pred, chi_squared = rim.predict(lens)  File "/lustre04/scratch/aadam/Censai/censai/rim_shared_unet.py", line 196, in predict
    source, kappa, states = self.time_step(source, kappa, source_grad, kappa_grad, states)  File "/home/aadam/environments/censai3.8/lib/python3.8/site-packages/tensorflow/python/util/tf_should_use.py", line 247, in wrapped


Object was never used (type <class 'tensorflow.python.ops.tensor_array_ops.TensorArray'>):
<tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x2b3859c38c40>
If you want to mark it as used call its "mark_used()" method.
It was originally created here:
  File "/home/aadam/environments/censai3.8/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3458, in run_code
    self.showtraceback(running_compiled_code=True)  File "<ipython-input-28-433ba2cef1a5>", line 19, in <module>
    source_pred, kappa_pred, chi_squared = rim.predict(lens)  File "/lustre04/scratch/aadam/Censai/censai/rim_shared_unet.py", line 196, in predict
    source, kappa, states = self.time_step(source, kappa, source_grad, kappa_grad, states)  File "/home/aadam/environments/censai3.8/lib/python3.8/site-packages/tensorflow/python/util/tf_should_use.py", line 247, in wrapped


Object was never used (type <class 'tensorflow.python.ops.tensor_array_ops.TensorArray'>):
<tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x2b3859c38c40>
If you want to mark it as used call its "mark_used()" method.
It was originally created here:
  File "/home/aadam/environments/censai3.8/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3458, in run_code
    self.showtraceback(running_compiled_code=True)  File "<ipython-input-28-433ba2cef1a5>", line 19, in <module>
    source_pred, kappa_pred, chi_squared = rim.predict(lens)  File "/lustre04/scratch/aadam/Censai/censai/rim_shared_unet.py", line 196, in predict
    source, kappa, states = self.time_step(source, kappa, source_grad, kappa_grad, states)  File "/home/aadam/environments/censai3.8/lib/python3.8/site-packages/tensorflow/python/util/tf_should_use.py", line 247, in wrapped
100%|██████████| 5/5 [14:20<00:00, 172.09s/it]
