In [1]:
from censai import RIMSharedUnetv3, PhysicalModelv2, PowerSpectrum, RIMSourceUnetv2, AnalyticalPhysicalModel
from censai.models import SharedUnetModelv4, UnetModelv2
from censai.data.lenses_tng_v3 import decode_results, decode_physical_model_info, decode_all
import tensorflow as tf
import numpy as np
import os, glob, json
import h5py
from censai.definitions import log_10
from tqdm import tqdm
from argparse import Namespace


# Root of the Censai working tree, read once from the CENSAI_PATH environment variable.
censai_root = os.getenv("CENSAI_PATH")
# Sub-directories for evaluation outputs, TFRecord datasets and trained model checkpoints.
result_dir, data_path, models_path = (
    os.path.join(censai_root, subdir) for subdir in ("results", "data", "models")
)


In [2]:
# Identifiers of the trained models (sub-directories of `models_path`) and of the
# TFRecord datasets (sub-directories of `data_path`) evaluated in this notebook.
model = "RIMSU128hstv4_control_008_RMSP1_TS8_NLtanh_TWuniform_211117121747"
source_model = "RIMSource128hstv3_control_001_A0_L2_FLM0.0_211108220845"
_train_dataset = "lenses128hst_TNG_rau_200k_control_denoised_validated_train"
_val_dataset = "lenses128hst_TNG_rau_200k_control_denoised_validated_val"
_test_dataset = "lenses128hst_TNG_rau_200k_control_denoised_testset_validated"
bins = 40  # number of bins handed to every PowerSpectrum below

checkpoints_dir = os.path.join(models_path, model)

# Re-use the parameters the model was trained with (compression type, buffer size, batch size, ...).
with open(os.path.join(checkpoints_dir, "script_params.json"), "r") as f:
    args = Namespace(**json.load(f))


def load_raw_records(dataset_name):
    """Return the undecoded TFRecord stream of one dataset directory.

    Every ``*.tfrecords`` file under ``data_path/dataset_name`` is read through an
    interleave (``block_length=1``); each per-file dataset is shuffled with a buffer
    sized to the number of files, exactly as the three copy-pasted pipelines did before.

    Parameters
    ----------
    dataset_name : str
        Name of the dataset sub-directory inside ``data_path``.

    Returns
    -------
    tf.data.Dataset
        Dataset of serialized records (not yet decoded).

    Raises
    ------
    FileNotFoundError
        If the directory holds no ``.tfrecords`` file. Without this check an empty
        file list only fails later, inside TensorFlow, with an unrelated-looking error.
    """
    pattern = os.path.join(data_path, dataset_name, "*.tfrecords")
    paths = glob.glob(pattern)
    if not paths:
        raise FileNotFoundError(f"No .tfrecords files match {pattern}")
    n_files = len(paths)
    return tf.data.Dataset.from_tensor_slices(paths).interleave(
        lambda path: tf.data.TFRecordDataset(path, compression_type=args.compression_type).shuffle(n_files),
        block_length=1,
        num_parallel_calls=tf.data.AUTOTUNE,
    )


train_records = load_raw_records(_train_dataset)
# Read off global parameters from first example in dataset
physical_params = next(iter(train_records.map(decode_physical_model_info)))

train_dataset = train_records.map(decode_results).shuffle(buffer_size=args.buffer_size)
val_dataset = load_raw_records(_val_dataset).map(decode_results).shuffle(buffer_size=args.buffer_size)
test_dataset = load_raw_records(_test_dataset).map(decode_results).shuffle(buffer_size=args.buffer_size)

# One power-spectrum helper per image type, each built with that image's pixel count.
ps_lens = PowerSpectrum(bins=bins, pixels=physical_params["pixels"].numpy())
ps_source = PowerSpectrum(bins=bins, pixels=physical_params["src pixels"].numpy())
ps_kappa = PowerSpectrum(bins=bins, pixels=physical_params["kappa pixels"].numpy())


In [5]:
# Physical (forward) model built from the parameters stored with the training set.
phys = PhysicalModelv2(
    pixels=physical_params["pixels"].numpy(),
    kappa_pixels=physical_params["kappa pixels"].numpy(),
    src_pixels=physical_params["src pixels"].numpy(),
    image_fov=physical_params["image fov"].numpy(),
    kappa_fov=physical_params["kappa fov"].numpy(),
    src_fov=physical_params["source fov"].numpy(),
    method="fft",
)

# Analytical model; its `kappa_field` method generates the SIE kappa maps of the SIE test.
phys_sie = AnalyticalPhysicalModel(
    pixels=physical_params["pixels"].numpy(),
    image_fov=physical_params["image fov"].numpy(),
    src_fov=physical_params["source fov"].numpy()
)

# Load RIM for source only
rim_source_dir = os.path.join(models_path, source_model)
with open(os.path.join(rim_source_dir, "unet_hparams.json")) as f:
    unet_source_params = json.load(f)
unet_source = UnetModelv2(**unet_source_params)
with open(os.path.join(rim_source_dir, "rim_hparams.json")) as f:
    rim_source_params = json.load(f)
rim_source = RIMSourceUnetv2(phys, unet_source, **rim_source_params)
ckpt_s = tf.train.Checkpoint(net=unet_source)
checkpoint_manager_s = tf.train.CheckpointManager(ckpt_s, rim_source_dir, 1)
# `restore(None)` does not raise: it silently leaves the network with its initial
# weights. Fail loudly instead of evaluating an untrained model.
if checkpoint_manager_s.latest_checkpoint is None:
    raise FileNotFoundError(f"No checkpoint found in {rim_source_dir}")
# `expect_partial` silences warnings about checkpoint entries that are not used here.
checkpoint_manager_s.checkpoint.restore(checkpoint_manager_s.latest_checkpoint).expect_partial()

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x2b5068ae62b0>

In [6]:
# Rebuild the shared U-Net from its saved hyper-parameters and restore its weights.
# `checkpoints_dir` (defined with the run configuration) is the directory of `model`.
with open(os.path.join(checkpoints_dir, "unet_hparams.json")) as f:
    unet_params = json.load(f)

unet = SharedUnetModelv4(**unet_params)
ckpt = tf.train.Checkpoint(net=unet)
checkpoint_manager = tf.train.CheckpointManager(ckpt, checkpoints_dir, 1)
# `restore(None)` does not raise: it silently leaves the network with its initial
# weights. Fail loudly instead of evaluating an untrained model.
if checkpoint_manager.latest_checkpoint is None:
    raise FileNotFoundError(f"No checkpoint found in {checkpoints_dir}")
checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint).expect_partial()
with open(os.path.join(checkpoints_dir, "rim_hparams.json")) as f:
    rim_params = json.load(f)

rim = RIMSharedUnetv3(phys, unet, **rim_params)

# Number of examples evaluated per dataset and number of generated SIE examples.
train_size = 1000
val_size = 1000
test_size = 1000
sie_size = 1000
# Parallel lists: the HDF5 group name and the example count for each evaluated dataset.
dataset_names = [_train_dataset, _val_dataset, _test_dataset]
dataset_shapes = [train_size, val_size, test_size]

In [7]:
# Bounds of the uniform distributions from which the SIE test lenses are drawn.
min_theta_e, max_theta_e = 0.5, 2.5  # Einstein radius range
max_shift = 0.1                      # largest radial offset of the lens position
max_ellipticity = 0.6                # ellipticity is drawn from [0, max_ellipticity)

In [10]:
output_file = os.path.join(result_dir, model + "_" + source_model + "_temp.h5")

# Batch size of the SIE test. It is used both to batch the dataset and to compute the
# row offsets written to the file, so the two can never disagree.
SIE_BATCH_SIZE = 1


def create_result_datasets(group, data_len, with_sie_params=False):
    """Allocate every float32 dataset of one result group.

    Parameters
    ----------
    group : h5py.Group
        Group that receives the datasets.
    data_len : int
        Number of examples (leading dimension of the per-example datasets).
    with_sie_params : bool
        Also allocate the datasets holding the sampled SIE parameters
        (``einstein_radius``, ``position``, ``orientation``, ``ellipticity``).
    """
    psf_pixels = int(physical_params["psf pixels"].numpy())
    lens_shape = [data_len, phys.pixels, phys.pixels, 1]
    source_shape = [data_len, phys.src_pixels, phys.src_pixels, 1]
    kappa_shape = [data_len, phys.kappa_pixels, phys.kappa_pixels, 1]
    # Per-step predictions carry an extra axis of length `rim.steps` after the example axis.
    source_steps_shape = [data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1]
    kappa_steps_shape = [data_len, rim.steps, phys.kappa_pixels, phys.kappa_pixels, 1]
    shapes = {
        "lens": lens_shape,
        "psf": [data_len, psf_pixels, psf_pixels, 1],
        "psf_fwhm": [data_len],
        "noise_rms": [data_len],
        "source": source_shape,
        "kappa": kappa_shape,
        "lens_pred": lens_shape,
        "lens_pred2": lens_shape,
        "source_pred": source_steps_shape,
        "source_pred2": source_steps_shape,
        "kappa_pred": kappa_steps_shape,
        "chi_squared": [data_len, rim.steps],
        "chi_squared2": [data_len, rim.steps],
        "lens_coherence_spectrum": [data_len, bins],
        "source_coherence_spectrum": [data_len, bins],
        "lens_coherence_spectrum2": [data_len, bins],
        "source_coherence_spectrum2": [data_len, bins],
        "kappa_coherence_spectrum": [data_len, bins],
        "lens_frequencies": [bins],
        "source_frequencies": [bins],
        "kappa_frequencies": [bins],
        "kappa_fov": [1],
        "source_fov": [1],
        "lens_fov": [1],
    }
    if with_sie_params:
        shapes.update({
            "einstein_radius": [data_len],
            "position": [data_len, 2],
            "orientation": [data_len],
            "ellipticity": [data_len],
        })
    for name, shape in shapes.items():
        group.create_dataset(name=name, shape=shape, dtype=np.float32)


def write_group_metadata(group):
    """Fill the batch-independent datasets of a group: frequency bins and fields of view."""
    for name, pixels, spectrum in (
        ("lens_frequencies", phys.pixels, ps_lens),
        ("source_frequencies", phys.src_pixels, ps_source),
        ("kappa_frequencies", phys.kappa_pixels, ps_kappa),
    ):
        # Histogram the positive FFT frequencies only to obtain the bin edges,
        # then store the mid-point of each bin.
        _, edges = np.histogram(np.fft.fftfreq(pixels)[:pixels // 2], bins=spectrum.bins)
        group[name][:] = (edges[:-1] + edges[1:]) / 2
    group["kappa_fov"][0] = phys.kappa_fov
    group["source_fov"][0] = phys.src_fov
    # `lens_fov` used to be allocated but never written; store the image field of view
    # that was handed to the physical model.
    group["lens_fov"][0] = float(physical_params["image fov"].numpy())


def reconstruct_and_save(group, i_begin, lens, source, kappa, noise_rms, psf, fwhm):
    """Run both reconstructions on one batch and write inputs and outputs to ``group``.

    Parameters
    ----------
    group : h5py.Group
        Group previously prepared with ``create_result_datasets``.
    i_begin : int
        Row of the datasets at which this batch starts.
    lens, source, kappa, noise_rms, psf, fwhm : tf.Tensor
        One batch of observations, ground truths and observation parameters.

    Returns
    -------
    int
        ``i_end``, the row just past the last one written.
    """
    i_end = i_begin + source.shape[0]

    # Compute predictions for kappa and source
    source_pred, kappa_pred, chi_squared = rim.predict(lens, noise_rms, psf)
    lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf)
    # Re-optimize source with a trained source model
    source_pred2, chi_squared2 = rim_source.predict(lens, kappa_pred[-1], noise_rms, psf)
    lens_pred2 = phys.forward(source_pred2[-1], kappa_pred[-1], psf)

    def as_float32(tensor):
        return tensor.numpy().astype(np.float32)

    def example_axis_first(tensor):
        # Swap the first two axes so the example axis leads, matching the
        # [data_len, rim.steps, ...] layout of the datasets.
        return as_float32(tf.transpose(tensor, perm=(1, 0, 2, 3, 4)))

    results = {
        "lens": as_float32(lens),
        "psf": as_float32(psf),
        "psf_fwhm": as_float32(fwhm),
        "noise_rms": as_float32(noise_rms),
        "source": as_float32(source),
        "kappa": as_float32(kappa),
        "lens_pred": as_float32(lens_pred),
        "lens_pred2": as_float32(lens_pred2),
        "source_pred": example_axis_first(source_pred),
        "source_pred2": example_axis_first(source_pred2),
        "kappa_pred": example_axis_first(kappa_pred),
        "chi_squared": as_float32(tf.transpose(chi_squared)),
        "chi_squared2": as_float32(tf.transpose(chi_squared2)),
        # Compute Power spectrum of converged predictions
        "lens_coherence_spectrum": as_float32(
            ps_lens.cross_correlation_coefficient(lens[..., 0], lens_pred[..., 0])),
        "lens_coherence_spectrum2": as_float32(
            ps_lens.cross_correlation_coefficient(lens[..., 0], lens_pred2[..., 0])),
        "source_coherence_spectrum": as_float32(
            ps_source.cross_correlation_coefficient(source[..., 0], source_pred[-1][..., 0])),
        "source_coherence_spectrum2": as_float32(
            ps_source.cross_correlation_coefficient(source[..., 0], source_pred2[-1][..., 0])),
        "kappa_coherence_spectrum": as_float32(
            ps_kappa.cross_correlation_coefficient(log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0])),
    }
    for name, values in results.items():
        group[name][i_begin:i_end] = values
    return i_end


# "w" truncates any existing file, so every group (including "SIE_test") is created here.
with h5py.File(output_file, "w") as hf:
    for i, dataset in tqdm(enumerate([train_dataset, val_dataset, test_dataset])):
        g = hf.create_group(f'{dataset_names[i]}')
        data_len = dataset_shapes[i]
        create_result_datasets(g, data_len)
        write_group_metadata(g)

        batches = dataset.take(data_len).batch(args.batch_size).prefetch(tf.data.AUTOTUNE)
        for batch, (lens, source, kappa, noise_rms, psf, fwhm) in tqdm(enumerate(batches)):
            reconstruct_and_save(g, batch * args.batch_size, lens, source, kappa, noise_rms, psf, fwhm)

    # Create SIE test
    # Sources, noise levels and PSFs come from the test set; the kappa maps are
    # replaced by randomly drawn analytical SIE profiles.
    # NOTE(review): the random draws below are not seeded, so each run produces a
    # different SIE test set.
    g = hf.create_group("SIE_test")
    data_len = sie_size
    create_result_datasets(g, data_len, with_sie_params=True)
    write_group_metadata(g)

    sie_batches = test_dataset.take(data_len).batch(SIE_BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
    for batch, (_, source, _, noise_rms, psf, fwhm) in tqdm(enumerate(sie_batches)):
        batch_size = source.shape[0]
        # Create some SIE kappa maps
        _r = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0, maxval=max_shift)
        _theta = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=-np.pi, maxval=np.pi)
        x0 = _r * tf.math.cos(_theta)
        y0 = _r * tf.math.sin(_theta)
        ellipticity = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=max_ellipticity)
        phi = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=-np.pi, maxval=np.pi)
        einstein_radius = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=min_theta_e, maxval=max_theta_e)
        kappa = phys_sie.kappa_field(x0, y0, ellipticity, phi, einstein_radius)
        lens = phys.noisy_forward(source, kappa, noise_rms=noise_rms, psf=psf)

        # The offset must use the batch size of *this* loop, not `args.batch_size`;
        # otherwise rows are skipped or fall outside the datasets whenever the two differ.
        i_begin = batch * SIE_BATCH_SIZE
        i_end = reconstruct_and_save(g, i_begin, lens, source, kappa, noise_rms, psf, fwhm)

        # Sampled SIE parameters, stored alongside the reconstructions.
        g["einstein_radius"][i_begin:i_end] = einstein_radius[:, 0, 0, 0].numpy().astype(np.float32)
        g["position"][i_begin:i_end] = tf.stack([x0[:, 0, 0, 0], y0[:, 0, 0, 0]], axis=1).numpy().astype(np.float32)
        g["ellipticity"][i_begin:i_end] = ellipticity[:, 0, 0, 0].numpy().astype(np.float32)
        g["orientation"][i_begin:i_end] = phi[:, 0, 0, 0].numpy().astype(np.float32)

0it [00:00, ?it/s]
1000it [12:21,  1.35it/s]


In [11]:
# Echo the path of the HDF5 results file as the cell output.
output_file

'/home/aadam/scratch/Censai/results/RIMSU128hstv4_control_008_RMSP1_TS8_NLtanh_TWuniform_211117121747_RIMSource128hstv3_control_001_A0_L2_FLM0.0_211108220845_temp.h5'