## Visualize

### Plot data

In [None]:
import h5py
import numpy as np
import os

import sys
sys.path.append('..')
from util.plotting import plot_image_rows



def get_path(model_sim, ws, encoded, norm):
    """Return the HDF5 dataset path for one simulation suite.

    Parameters
    ----------
    model_sim : str
        Simulation key: "p21c", "zreion" or "ctrpx".
    ws : float
        Value formatted into the filename after "ws" (e.g. 0.0 -> "ws0.0").
    encoded : bool
        If True, return an encoded dataset path; otherwise the unencoded
        "_subdiv_sliced" dataset path.
    norm : bool
        Only consulted when ``encoded`` is True: selects the
        "_norm_encoded" file instead of the "_encoded" one.

    Returns
    -------
    str
        Absolute path of the ``.hdf5`` file.

    Raises
    ------
    ValueError
        If ``model_sim`` is not one of the known simulation keys.
    """
    prefixes = {
        "p21c": "/users/jsolt/data/jsolt/21cmFAST_sims/p21c14/p21c14",
        "zreion": "/users/jsolt/data/jsolt/zreion_sims/zreion24/zreion24",
        "ctrpx": "/users/jsolt/data/jsolt/centralpix_sims/centralpix05/centralpix05",
    }
    try:
        dp = prefixes[model_sim]
    except KeyError:
        # Previously an unknown key left `dp` unbound (UnboundLocalError).
        raise ValueError(
            f"unknown model_sim {model_sim!r}; expected one of {sorted(prefixes)}"
        ) from None

    if encoded:
        suffix = "_norm_encoded" if norm else "_encoded"
    else:
        suffix = "_subdiv_sliced"
    return f"{dp}{suffix}_ws{ws}.hdf5"
 

# Simulation suite key (see get_path) and the `ws` value used in filenames.
sim = "ctrpx"
ws = 0.0

# Unencoded, encoded, and normalized-encoded datasets for the same sim/ws.
unenc_fname = get_path(sim, ws, False, False)
enc_fname = get_path(sim, ws, True, False)
norm_enc_fname = get_path(sim, ws, True, True)

# Per-sim figure directory. makedirs also creates missing parents
# (../figures/data_figures) and is a no-op if the directory already exists,
# unlike the previous exists()+mkdir() pair.
fig_dir = f"../figures/data_figures/{sim}_ws{ws}"
os.makedirs(fig_dir, exist_ok=True)

# Lightcone indices to plot: start*ntrvl, (start+1)*ntrvl, ..., (end-1)*ntrvl.
start, end = 0, 8
ntrvl = 4
n = np.arange(start * ntrvl, end * ntrvl, ntrvl)
# Six evenly spaced slice indices in [0, 29], truncated to int.
z = np.linspace(0, 29, 6, dtype=int)



In [18]:

# Read the selected lightcones (indices `n`) and their parameters from the
# encoded dataset, printing the full dataset shapes first.
with h5py.File(enc_fname, 'r') as h5:
    bt_dset = h5['lightcones/brightness_temp']
    param_dset = h5['lightcone_params/physparams']
    print(bt_dset.shape)
    print(param_dset.shape)
    sample, labels = bt_dset[n], param_dset[n]

# Keep only the selected slice indices `z` along axis 1.
sample = sample[:, z]
print(sample.shape)
print(labels.shape)


(3445, 30, 4, 32, 32)
(3445, 3)
(8, 6, 4, 32, 32)
(8, 3)


In [19]:

%matplotlib inline

# One figure per channel (axis 2 of `sample`). Each row is one loaded
# lightcone, labelled with its row position and physparams column 1; the
# row holds that lightcone's selected slices for the channel.
# NOTE(review): the label uses the row position `ni`, not the dataset index n[ni].
for c in range(sample.shape[2]):
    rowdict = {
        f"lc {ni} dur {labels[ni, 1]:.1f}": sample[ni, :, c]
        for ni in range(len(n))
    }
    plot_image_rows(rowdict, fname=f"{fig_dir}/latent_channel_{c}_{sim}_ws{ws}.jpg", title=f"{sim} ws={ws} encoded channel {c}")


In [20]:
# Same selection as the encoded cell, but read from the "_norm_encoded"
# dataset; full dataset shapes are printed first.
with h5py.File(norm_enc_fname, 'r') as h5:
    bt_dset = h5['lightcones/brightness_temp']
    param_dset = h5['lightcone_params/physparams']
    print(bt_dset.shape)
    print(param_dset.shape)
    sample, labels = bt_dset[n], param_dset[n]

# Keep only the selected slice indices `z` along axis 1.
sample = sample[:, z]
print(sample.shape)
print(labels.shape)

(3445, 30, 4, 32, 32)
(3445, 3)
(8, 6, 4, 32, 32)
(8, 3)


In [21]:
%matplotlib inline

# One figure per channel (axis 2 of `sample`) for the "_norm_encoded" data.
# NOTE(review): files and titles say "prenorm" while the source file is
# "_norm_encoded"; confirm the intended naming.
for c in range(sample.shape[2]):
    rowdict = {
        f"lc {ni} dur {labels[ni, 1]:.1f}": sample[ni, :, c]
        for ni in range(len(n))
    }
    plot_image_rows(rowdict, fname=f"{fig_dir}/prenorm_latent_channel_{c}_{sim}_ws{ws}.jpg", title=f"{sim} ws={ws} prenorm encoded channel {c}")

In [22]:

# Read the same lightcone indices `n` from the unencoded dataset.
with h5py.File(unenc_fname, 'r') as h5:
    bt_dset = h5['lightcones/brightness_temp']
    print(bt_dset.shape)
    input, labels = bt_dset[n], h5['lightcone_params/physparams'][n]

# Keep only the selected slice indices `z` along axis 1.
# NOTE(review): `input` shadows the Python builtin; the plotting cell below
# reads this name, so any rename must be made in both places.
input = input[:, z]
print(input.shape)

(3445, 30, 256, 256)
(8, 6, 256, 256)


In [23]:
%matplotlib inline

# One row per loaded lightcone (label: row position and physparams column 1),
# showing its selected slices of the unencoded data.
# The label previously ended with a stray ")".
rowdict = {f"lc {ni} dur {labels[ni, 1]:.1f}": input[ni] for ni in range(len(n))}

plot_image_rows(rowdict, fname=f"{fig_dir}/unencoded_bT_{sim}_ws{ws}.jpg", title=f"{sim} ws={ws} unencoded bT")


### Dataset Statistics

In [31]:
import os
import numpy as np
import matplotlib.pyplot as plt

# Simulation key -> dataset name (used as both subdirectory and file prefix).
get_name = {
    "p21c": "p21c14",
    "zreion": "zreion24",
    "ctrpx": "centralpix05",
}

# Simulation key -> root directory holding that simulation's datasets.
get_dir = {
    "p21c": "/users/jsolt/data/jsolt/21cmFAST_sims",
    "zreion": "/users/jsolt/data/jsolt/zreion_sims",
    "ctrpx": "/users/jsolt/data/jsolt/centralpix_sims",
}

sims = ["p21c", "ctrpx", "zreion"]
ws = 0.0

# Path of each simulation's "_vae_stats.npz" file, keyed by simulation.
# (Removed a dead `npz_fname, fig_dir = {}, {}` that was overwritten at once.)
npz_fname = {
    sim: f"{get_dir[sim]}/{get_name[sim]}/{get_name[sim]}_ws{ws}_vae_stats.npz"
    for sim in sims
}

# Shared output directory for the statistics figures. This rebinds `fig_dir`,
# which the earlier plotting cells set to a per-sim subdirectory.
fig_dir = "../figures/data_figures"

In [21]:
def _load_npz(path):
    """Load every array from the .npz archive at *path* into a dict, closing the file."""
    # np.load on an .npz returns a lazily-reading NpzFile that holds the file
    # open; copying the arrays out inside `with` releases the handle.
    with np.load(path) as archive:
        return {key: archive[key] for key in archive.files}


# Per-simulation statistics arrays; the cells below read 'mse' and 'corrcoef'.
npz = {sim: _load_npz(npz_fname[sim]) for sim in sims}


#### MSE

In [32]:
%matplotlib inline

# Mean over axis 0 of the stored 'mse' array for each simulation, plotted
# against slice index, with one line per simulation.
fig, ax = plt.subplots()
ax.set_title("MSE: Original v. Decoded ")
ax.set_ylabel("Mean MSE")
ax.set_xlabel("z slice index")
ax.grid(True)  # axis property; setting it once is enough

for sim in sims:
    ax.plot(npz[sim]['mse'].mean(axis=0), label=sim)

ax.legend()
fig.savefig(f"{fig_dir}/mean_mse_per_z.jpeg")
# Close rather than clf(): clf() leaves an open, empty figure that the inline
# backend still renders ("<Figure ... with 0 Axes>") and that is never freed.
plt.close(fig)

<Figure size 640x480 with 0 Axes>

#### Correlation coefficient (original vs. decoded)

In [33]:
%matplotlib inline

# NaN-ignoring mean over axis 0 of the stored 'corrcoef' array for each
# simulation, plotted against slice index, with one line per simulation.
fig, ax = plt.subplots()
ax.set_title("Cross-Correlation: Original v. Decoded ")
ax.set_ylabel("Mean diag(cc matrix)")
ax.set_xlabel("z slice index")
ax.grid(True)  # axis property; setting it once is enough

for sim in sims:
    ax.plot(np.nanmean(npz[sim]['corrcoef'], axis=0), label=sim)

ax.legend()
fig.savefig(f"{fig_dir}/mean_cc_per_z.jpeg")
# Close rather than clf(): clf() leaves an open, empty figure that the inline
# backend still renders ("<Figure ... with 0 Axes>") and that is never freed.
plt.close(fig)

<Figure size 640x480 with 0 Axes>