In [4]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"

In [5]:
from functools import partial
from pathlib import Path

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas
import typer
from chainconsumer import Chain, ChainConfig, ChainConsumer, Truth
from chainconsumer.plotting.config import PlotConfig
from tqdm import tqdm

from bpd import DATA_DIR, HOME_DIR
from bpd.draw import draw_exponential_galsim
from bpd.io import load_dataset, save_dataset
from bpd.plotting import (
    get_timing_figure,
    set_rc_params,
)
from bpd.sample import sample_galaxy_params_skew
from bpd.utils import DEFAULT_HYPERPARAMS, get_snr

FIG_DIR = HOME_DIR / "paper"
CHAIN_DIR = DATA_DIR / "cache_chains"

In [49]:
INPUT_PATHS = {
    "timing_results": CHAIN_DIR / "exp23_43" / "timing_results_43.npz",
    "timing_conv": CHAIN_DIR / "exp23_43" / "convergence_results_43.npz",
    "exp70_sp": CHAIN_DIR / "exp70_51" / "g_samples_512_plus.npy",
    "exp70_sm": CHAIN_DIR / "exp70_51" / "g_samples_512_minus.npy",
    "exp70_errs": CHAIN_DIR / "exp70_51" / "g_samples_514_errs.npz",
    "exp71_sp": CHAIN_DIR / "exp71_51" / "shear_samples_512_plus.npz",
    "exp71_sm": CHAIN_DIR / "exp71_51" / "shear_samples_512_minus.npz",
    "exp71_errs": CHAIN_DIR / "exp71_51" / "g_samples_514_errs.npz",
    "exp72_sp": CHAIN_DIR / "exp72_51" / "g_samples_512_plus.npy",
    "exp72_sm": CHAIN_DIR / "exp72_51" / "g_samples_512_minus.npy",
    "exp72_errs": CHAIN_DIR / "exp72_51" / "g_samples_514_errs.npz",
    "exp73_sp": CHAIN_DIR / "exp73_51" / "shear_samples_512_plus.npz",
    "exp73_sm": CHAIN_DIR / "exp73_51" / "shear_samples_512_minus.npz",
    "exp73_errs": CHAIN_DIR / "exp73_51" / "g_samples_514_errs_514.npz",
}

# Use medians instead of means

In [58]:
gp1 = np.load(INPUT_PATHS["exp70_sp"])
gm1 = np.load(INPUT_PATHS["exp70_sm"])

gps1 = load_dataset(INPUT_PATHS["exp70_errs"])["g_plus"]
gms1 = load_dataset(INPUT_PATHS["exp70_errs"])["g_minus"]

m1_mean = np.median(gp1[:, 0] - gm1[:, 0]) / 2 / 0.02 - 1
m1s = (np.median(gps1[:, :, 0], axis=1) - np.median(gms1[:, :, 0], axis=1)) / 2 / 0.02 - 1
m1_std = m1s.std() / np.sqrt(len(m1s))

print(m1_mean / 1e-3, m1_std / 1e-3 * 3)


-0.3437093517237866 0.948180672083316


In [59]:
gp1 = np.load(INPUT_PATHS["exp72_sp"])
gm1 = np.load(INPUT_PATHS["exp72_sm"])

gps1 = load_dataset(INPUT_PATHS["exp72_errs"])["gp"]
gms1 = load_dataset(INPUT_PATHS["exp72_errs"])["gm"]

m1_mean = np.median(gp1[:, 0] - gm1[:, 0]) / 2 / 0.02 - 1
m1s = (np.median(gps1[:, :, 0], axis=1) - np.median(gms1[:, :, 0], axis=1)) / 2 / 0.02 - 1
m1_std = m1s.std() / np.sqrt(len(m1s))

print(m1_mean / 1e-3, m1_std / 1e-3 * 3)


-0.19600598010849346 0.9483421903300215


In [62]:
dsp2 = load_dataset(INPUT_PATHS["exp71_sp"])
dsm2 = load_dataset(INPUT_PATHS["exp71_sm"])
gp1 = jnp.stack([dsp2["samples"]["g1"], dsp2["samples"]["g2"]], axis=-1)
gm1 = jnp.stack([dsm2["samples"]["g1"], dsm2["samples"]["g2"]], axis=-1)

gps1 = load_dataset(INPUT_PATHS["exp71_errs"])["plus"]['g']
gms1 = load_dataset(INPUT_PATHS["exp71_errs"])["minus"]['g']

m1_mean = np.median(gp1[:, 0] - gm1[:, 0]) / 2 / 0.02 - 1
m1s = (np.median(gps1[:, :, 0], axis=1) - np.median(gms1[:, :, 0], axis=1)) / 2 / 0.02 - 1
m1_std = m1s.std() / np.sqrt(len(m1s))

print(m1_mean / 1e-3, m1_std / 1e-3 * 3)


-0.315995486336651 1.1676360604036442


In [63]:
dsp2 = load_dataset(INPUT_PATHS["exp73_sp"])
dsm2 = load_dataset(INPUT_PATHS["exp73_sm"])
gp1 = jnp.stack([dsp2["samples"]["g1"], dsp2["samples"]["g2"]], axis=-1)
gm1 = jnp.stack([dsm2["samples"]["g1"], dsm2["samples"]["g2"]], axis=-1)

gps1 = load_dataset(INPUT_PATHS["exp73_errs"])["plus"]['g']
gms1 = load_dataset(INPUT_PATHS["exp73_errs"])["minus"]['g']

m1_mean = np.median(gp1[:, 0] - gm1[:, 0]) / 2 / 0.02 - 1
m1s = (np.median(gps1[:, :, 0], axis=1) - np.median(gms1[:, :, 0], axis=1)) / 2 / 0.02 - 1
m1_std = m1s.std() / np.sqrt(len(m1s))

print(m1_mean / 1e-3, m1_std / 1e-3 * 3)


-0.22854954440398867 1.6791922446479193


# Galaxy distributions