In [None]:
from hera_stats.bias_jackknife import bias_jackknife, bandpower
from hera_pspec import UVPSpec
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import os
from scipy.interpolate import interp1d
from scipy.stats import norm, multivariate_normal, gaussian_kde
from scipy.linalg import block_diag, LinAlgError
import time
from more_itertools import powerset
from itertools import combinations
import networkx as nx
from matplotlib import cm
from copy import deepcopy

# Shared keyword arguments for the step-style, density-normalized histograms below.
default_hist_kwargs = dict(histtype="step", bins="auto", density=True)

In [None]:
# Draw simulated bandpowers with every bandpower() default (defaults live in hera_stats).
bp = bandpower()

### Check that the draws are Gaussian about the correct mean

In [None]:
# Sample statistics of the draws should reproduce the requested parameters.
print(f"bandpower mean: {bp.mean}")
print(f"bandpower bias: {bp.bias}")
print(f"bandpower std: {bp.std}")

draws = bp.bp_draws
est_mean = np.mean(draws)
est_std = np.std(draws)  # same computation as np.sqrt(np.var(draws))
print(f"bandpower estimated mean: {est_mean}")
print(f"bandpower estimated std: {est_std}")

plt.hist(draws.flatten(), **default_hist_kwargs)
plt.yscale("log")

### Do a new draw with a bias

In [None]:
# Same check with a bias of 1: the estimated mean should shift, the std should not.
bp = bandpower(bias=1)
print(f"bandpower mean: {bp.mean}")
print(f"bandpower bias: {bp.bias}")
print(f"bandpower std: {bp.std}")

draws = bp.bp_draws
est_mean = np.mean(draws)
est_std = np.std(draws)  # same computation as np.sqrt(np.var(draws))
print(f"bandpower estimated mean: {est_mean}")
print(f"bandpower estimated std: {est_std}")

plt.hist(draws.flatten(), **default_hist_kwargs)
plt.yscale("log")

### Do a draw with a few different scales

In [None]:
# Per-epoch noise levels: compare the requested std of each of the four epochs with
# the sample std of its column of draws.
bp = bandpower(bias=1, std=[1, 1.2, 0.8, 1.5])
print(f"bandpower mean: {bp.mean}")
print(f"bandpower bias: {bp.bias}")
print(f"bandpower std: {bp.std}")

draws = bp.bp_draws
est_mean = np.mean(draws)
est_std = np.std(draws, axis=0)  # per-epoch; same as np.sqrt(np.var(draws, axis=0))
print(f"bandpower estimated mean: {est_mean}")
print(f"bandpower estimated std: {est_std}")

for ep_ind in range(4):
    plt.hist(draws[:, ep_ind], **default_hist_kwargs)
plt.yscale("log")

# Make a jackknife object and compare numerical vs. analytic posteriors with a histogram

In [None]:
def hist_check(bias=1, num_draw=int(1e3), std=0.5, check=1, num_pow=2, debug=False, bias_prior_mean=0):
    """Compare the posteriors of two bias_jackknife configurations on the same draws.

    check=1: analytic vs. numerical posterior evaluation.
    check=2: analytic with ``std`` exactly as given to bandpower vs. analytic with ``std``
        expanded to one value per epoch. With debug=True the internal Gaussian
        parameters of both jackknives are printed and the function returns early.

    Histograms the absolute and fractional posterior differences on a new figure and
    prints their maxima. Raises ValueError for any other value of ``check``.
    """
    bp_test = bandpower(bias=bias, num_draw=num_draw, std=std, num_pow=num_pow)
    if check == 1:
        title = "Analytic vs. Numerical"
        bjack_1 = bias_jackknife(bp_test, bias_prior_mean=bias_prior_mean, analytic=True)
        bjack_2 = bias_jackknife(bp_test, bias_prior_mean=bias_prior_mean, analytic=False)
    elif check == 2:
        title = "Analytic vs. Analytic"
        bjack_1 = bias_jackknife(bp_test, bias_prior_mean=bias_prior_mean, analytic=True)
        # Broadcast so a scalar *or* a per-epoch std becomes a length-num_pow vector;
        # np.array(num_pow*[std]) would turn a list std into a (num_pow, num_pow) array.
        bp_test.std = np.broadcast_to(np.asarray(std), (num_pow,)).copy()
        bjack_2 = bias_jackknife(bp_test, bias_prior_mean=bias_prior_mean, analytic=True)

        if debug:
            for case in (True, False):
                print(f"params for bjack_1 in {case} case: {bjack_1._get_mod_var_mean_gauss_2(case, debug=True)}")
                print(f"params for bjack_2 in {case} case: {bjack_2._get_mod_var_mean_gauss_2(case, debug=True)}")
            return
    else:
        raise ValueError("check keyword can only take values 1 or 2")

    fig, ax = plt.subplots(figsize=(8, 4), ncols=2)
    print(bjack_1.num_hyp)

    post_diffs = bjack_1.post - bjack_2.post
    # Where both posteriors are 0 the ratio is 0/0 -> NaN. Silence the warning, keep
    # NaNs out of the histogram (auto binning fails on them) and ignore them in the max.
    with np.errstate(divide="ignore", invalid="ignore"):
        frac_post_diffs = post_diffs / (0.5 * (bjack_1.post + bjack_2.post))
    finite_frac = frac_post_diffs[np.isfinite(frac_post_diffs)]

    ax[0].hist(post_diffs.flatten(), **default_hist_kwargs)
    ax[1].hist(finite_frac, **default_hist_kwargs)
    ax[0].set_xlabel("Absolute Difference")
    ax[1].set_xlabel("Fractional Difference")
    fig.suptitle(title)

    print(f"max fractional difference is {np.nanmax(np.abs(frac_post_diffs))}")
    print(f"max absolute difference is {np.amax(np.abs(post_diffs))}")

In [None]:
# Analytic vs. numerical posteriors for three epochs with default bias and prior mean.
hist_check(num_pow=3)

In [None]:
# Same comparison with a nonzero bias prior mean.
hist_check(num_pow=3, bias_prior_mean=1)

In [None]:
# Analytic vs. analytic (std as given vs. per-epoch std); debug=True prints internal parameters.
hist_check(num_draw=int(1e3), check=2, debug=False)

In [None]:
# Analytic vs. numerical with an explicit, uniform per-epoch std for four epochs.
hist_check(std=[0.5, 0.5, 0.5, 0.5], check=1, num_pow=4)

In [None]:
# Analytic vs. numerical with non-uniform per-epoch stds.
hist_check(std=[0.5, 0.7, 0.1, 0.5], check=1, num_pow=4)

## Do a bigger draw and histogram the posteriors - compare biased vs. unbiased

In [None]:
# Ten million draws each, unbiased (bias=0) and biased (bias=1), with analytic posteriors.
# NOTE(review): num_draw=1e7 makes these objects large — confirm available memory.
bp_big_draw = bandpower(bias=0, num_draw=int(1e7))
bjack_big_draw = bias_jackknife(bp_big_draw, analytic=True)
bp_big_draw_bias = bandpower(bias=1, num_draw=int(1e7))
bjack_big_draw_bias = bias_jackknife(bp_big_draw_bias, analytic=True)



In [None]:
# Sampling distribution of P(n=0|d) (the first hypothesis' posterior) with and without a bias.
post_bins = np.linspace(0, 1, num=1001)
plt.figure(figsize=(12, 6))
for jk_obj, curve_label in ((bjack_big_draw, "Unbiased"), (bjack_big_draw_bias, "Biased")):
    plt.hist(jk_obj.post[0], bins=post_bins, histtype='step', density=True, label=curve_label)
plt.xlabel("$P(n=0|d)$")
plt.ylabel("PDF")
plt.yscale("log")
plt.xscale("log")
plt.legend()

# Library of useful functions

In [None]:
def make_diag_bpc_list(num_pow, var):
    """Build one diagonal bias prior covariance per subset of epochs.

    Subsets follow ``powerset(range(num_pow))`` order (empty subset first). Each matrix
    has ``var`` on the diagonal entries of the epochs in its subset and 0 elsewhere.
    Returns a list of (num_pow, num_pow) arrays.
    """
    def _diag_for(subset):
        diagonal = np.zeros(num_pow)
        diagonal[np.array(subset, dtype=int)] = var
        return np.diag(diagonal)

    return [_diag_for(subset) for subset in powerset(np.arange(num_pow, dtype=int))]

def make_cov_list(mode, bias_prior_std, num_pow):
    """Return the bias prior covariance matrices for a jackknife mode.

    'ternary' gives [no bias (zeros), independent biases (var * I), one shared bias
    (var * ones)]. Any other mode gives one diagonal covariance per subset of epochs
    (make_diag_bpc_list). Here var = bias_prior_std**2.
    """
    var = bias_prior_std**2
    if mode != 'ternary':
        return make_diag_bpc_list(num_pow, var)
    square = [num_pow, num_pow]
    return [np.zeros(square), var * np.eye(num_pow), var * np.ones(square)]

def make_sim_from_bpc_list(bpc_list, num_pow, num_draw, mean, std, bias_prior_mean):
    """Simulate one bandpower object per bias prior covariance in bpc_list.

    Biases are drawn (num_draw of them) from a multivariate normal with covariance
    ``bpc``. The first covariance always gets a zero bias mean; later ones get
    ``bias_prior_mean`` on epochs with nonzero prior variance and 0 elsewhere.
    Returns the list of simulations in bpc_list order.
    """
    sims = []
    for idx, cov in enumerate(bpc_list):
        if idx == 0:
            bias_mean = np.zeros(num_pow)
        else:
            bias_mean = np.where(np.diag(cov) != 0, bias_prior_mean, 0)
        bias_draws = np.random.multivariate_normal(mean=bias_mean, cov=cov, size=num_draw)
        sims.append(bandpower(mean=mean, bias=bias_draws, std=std, num_pow=num_pow, num_draw=num_draw))
    return sims
        

def make_sim_list(mode, num_pow=3, bp_prior_mean=0.1, bp_prior_std=0.05, std=1, bias_prior_std=10,
                  num_draw=int(1e6), bias_prior_mean=0):
    """Simulate bandpowers under every hypothesis of a jackknife mode.

    True bandpower means are drawn per draw and epoch from
    N(bp_prior_mean, bp_prior_std); biases use the covariances from
    make_cov_list(mode, ...). Returns one simulation per hypothesis.
    Raises AssertionError unless mode is 'ternary' or 'diagonal'.
    """
    valid_modes = ['ternary', 'diagonal']
    assert mode in valid_modes, f"mode must be in one of {valid_modes}"

    n_draw = int(num_draw)
    true_means = np.random.normal(loc=bp_prior_mean, scale=bp_prior_std, size=(n_draw, num_pow))
    covs = make_cov_list(mode, bias_prior_std, num_pow)
    return make_sim_from_bpc_list(covs, num_pow, n_draw, true_means, std, bias_prior_mean)

def get_jk_list(sim_list, mode, num_pow=3, bp_prior_mean=0.1, bp_prior_std=0.05, std=1, bias_prior_std=10,
                num_draw=int(1e6), hyp_prior=None, bias_prior_mean=0):
    """Run bias_jackknife on every simulation in sim_list with shared prior settings.

    num_pow, std and num_draw mirror make_sim_list's signature but are not used here.
    Returns the jackknives in sim_list order.
    """
    shared = dict(bp_prior_mean=bp_prior_mean, bp_prior_std=bp_prior_std,
                  bias_prior_mean=bias_prior_mean, bias_prior_std=bias_prior_std,
                  mode=mode, hyp_prior=hyp_prior)
    return [bias_jackknife(sim, **shared) for sim in sim_list]

            

def get_mut_info(mode, num_pow=3, bp_prior_mean=0.1, bp_prior_std=0.05, std=1, bias_prior_std=10,
                 num_draw=int(1e6), hyp_prior=None, bias_prior_mean=0):
    """Monte Carlo estimate of the mutual information (bits) between hypothesis and data.

    Simulates draws under every hypothesis of ``mode`` (make_sim_list), labels each draw
    with a hypothesis sampled from ``hyp_prior`` (uniform when None), assembles the
    mixture data set and runs one bias_jackknife on it.

    Returns H(d) - H(d | hypothesis). H(d) is the sample mean of -log2(jk.evid), where
    draws with non-positive evidence contribute 0. H(d | hypothesis) is jk.sum_entropy.
    Prints a warning if any log-evidence term is NaN or infinite. Raises ValueError if
    hyp_prior has more entries than the mode has hypotheses.
    """
    num_draw = int(num_draw)
    sim_list = make_sim_list(mode, num_pow=num_pow, bp_prior_mean=bp_prior_mean, bp_prior_std=bp_prior_std,
                             std=std, bias_prior_std=bias_prior_std, num_draw=num_draw,
                             bias_prior_mean=bias_prior_mean)
    if hyp_prior is None:
        num_hyp = len(sim_list)
        hyp_prior = np.ones(num_hyp) / num_hyp

    # One hypothesis label per draw, sampled from the hypothesis prior.
    trial = np.random.multinomial(1, hyp_prior, size=num_draw).argmax(axis=1)
    if np.any(trial >= len(sim_list)):
        raise ValueError(f"hyp_prior has more entries than the {len(sim_list)} hypotheses of mode '{mode}'")

    # Row j of the mixture is row j of the simulation for hypothesis trial[j]. Masked
    # copies replace np.choose, which accepts at most 32 choices (64 in NumPy 2) and
    # so fails for 'diagonal' mode with num_pow >= 6.
    mixdat = np.empty_like(sim_list[0].bp_draws)
    for hyp_ind, sim in enumerate(sim_list):
        in_hyp = trial == hyp_ind
        mixdat[in_hyp] = sim.bp_draws[in_hyp]

    mix_sim = deepcopy(sim_list[0])
    mix_sim.bp_draws = mixdat

    jk = bias_jackknife(mix_sim, bp_prior_mean=bp_prior_mean, bp_prior_std=bp_prior_std, mode=mode,
                        bias_prior_std=bias_prior_std, hyp_prior=hyp_prior, bias_prior_mean=bias_prior_mean)

    # np.where evaluates -log2 everywhere before masking; silence the warnings raised
    # by the non-positive entries that are masked to 0 anyway.
    with np.errstate(divide="ignore", invalid="ignore"):
        logs = np.where(jk.evid > 0, -np.log2(jk.evid), 0)
    Hd = logs.mean()
    Hcond = jk.sum_entropy

    if np.any(np.isnan(logs)):
        print("Some nans in logs")

    if np.any(np.isinf(logs)):
        print("Some infs in logs")

    return Hd - Hcond

def mut_info_wrap(mode, num_pow=4, bp_prior_mean=0, bp_prior_std=0, std=1,
                  bias_prior_means=np.linspace(0, 10, num=100),
                  bias_prior_stds=np.logspace(-3, 3, num=100), num_draw=int(1e6), hyp_prior=None):
    """Evaluate get_mut_info on a (bias prior std, bias prior mean) grid.

    Returns (bias_prior_stds, bias_prior_means, mut_info_arr), where mut_info_arr[i, j]
    is the mutual information for bias_prior_stds[i] and bias_prior_means[j].
    Prints the outer-loop (std) index as a progress indicator.
    """
    mut_info_arr = np.zeros([len(bias_prior_stds), len(bias_prior_means)])
    for bps_ind, bias_prior_std in enumerate(bias_prior_stds):
        print(bps_ind)
        for bpm_ind, bias_prior_mean in enumerate(bias_prior_means):
            mut_info_arr[bps_ind, bpm_ind] = get_mut_info(mode, num_pow=num_pow, bp_prior_mean=bp_prior_mean,
                                                          bp_prior_std=bp_prior_std, bias_prior_std=bias_prior_std,
                                                          num_draw=num_draw, std=std, hyp_prior=hyp_prior,
                                                          bias_prior_mean=bias_prior_mean)
    # Return the mean grid itself (the previous `bpm` was undefined -> NameError).
    return (bias_prior_stds, bias_prior_means, mut_info_arr)


_concs = ["Null hypothesis accepted",
          "All epochs likely biased identically"]
_not_concs = ["Null hypothesis rejectied",
              "All epochs likely biased differently"]

def get_odds(jk, odds_thresh):
    """Posterior odds of each hypothesis relative to the maximum-posterior hypothesis.

    Parameters
    ----------
    jk : object with a ``post`` array of hypothesis posteriors. The flat argmax is
        used to index ``post`` along its first axis, so a single draw is assumed
        (``post`` of shape (num_hyp,) or (num_hyp, 1)), as in run_jk.
    odds_thresh : float
        A hypothesis competes with the best one when the best is favoured by less
        than this factor, i.e. post / post_max > 1 / odds_thresh.

    Returns
    -------
    odds : array
        post / post[max_post_ind]; every entry is <= 1.
    conc : dict
        'max_post_ind' (flat argmax), 'comp_inds' (flattened boolean mask of
        competitors, excluding the maximum itself) and 'odds' (the competitors' odds).
    """
    max_post_ind = np.argmax(jk.post)
    odds = jk.post / jk.post[max_post_ind]
    # odds are normalised to the maximum, so they never exceed 1; competitors are the
    # hypotheses the maximum beats by less than odds_thresh.
    comp_inds = (odds > 1 / odds_thresh).flatten()
    comp_inds[max_post_ind] = False

    conc = {"max_post_ind": max_post_ind, "comp_inds": comp_inds,
            "odds": odds[comp_inds]}

    return (odds, conc)

def get_max_compet_combos(epochs, **conc):
    """Translate get_odds indices into the epoch subsets they stand for.

    Hypotheses are assumed to follow ``powerset(epochs)`` order (empty subset first),
    the order make_diag_bpc_list uses for the 'diagonal' mode.

    Keyword arguments are the ``conc`` dict from get_odds. 'max_post_ind' and
    'comp_inds' are used; other keys such as 'odds' are ignored.

    Returns (max_post_combo, comp_combos): the subset of the maximum-posterior
    hypothesis and a list with the subsets of the competing hypotheses.
    """
    combos = list(powerset(epochs))
    max_post_combo = combos[conc["max_post_ind"]]
    # A plain list cannot be indexed by a boolean mask, so filter explicitly.
    comp_combos = [combo for combo, keep in zip(combos, conc["comp_inds"]) if keep]

    return (max_post_combo, comp_combos)
    
    
def run_jk(bp_meas, num_pow=3, bp_prior_mean=0.1, bp_prior_std=0.05, bias_prior_std=10, odds_thresh=10,
           std=1, bias_prior_mean=0, jk_mode="diagonal", print_post=False):
    """Run bias jackknives on one measured set of bandpowers.

    Wraps bp_meas in a non-simulated bandpower object and runs a jackknife for the
    'ternary' and then the 'diagonal' mode. With jk_mode == "diag_only", only
    'diagonal' is run. Returns (odds_list, concs), holding one get_odds result per
    mode that was run, in run order.
    """
    bp = bandpower(bp_meas=bp_meas, simulate=False, num_pow=num_pow, num_draw=1, std=std)

    modes_to_run = ["diagonal"] if jk_mode == "diag_only" else ["ternary", "diagonal"]

    odds_list, concs = [], []
    for mode in modes_to_run:
        jk = bias_jackknife(bp, bp_prior_mean=bp_prior_mean, bp_prior_std=bp_prior_std,
                            bias_prior_std=bias_prior_std, mode=mode, bias_prior_mean=bias_prior_mean)
        if print_post:
            print(jk.post)
        mode_odds, mode_conc = get_odds(jk, odds_thresh)
        odds_list.append(mode_odds)
        concs.append(mode_conc)

    return (odds_list, concs)
        
    

# Intro plot ("intuition builder")

In [None]:
class int_builder_dat:
    """Toy bandpowers for the "intuition builder" plot.

    Holds three sets of num_pow draws with true signal s and noise std n: unbiased,
    every epoch biased by b1, and only epoch 0 biased by b2. b1 and b2 are sized so
    either bias moves the z-score of the epoch-averaged bandpower to zwant. With
    new_draw=True the sets are rejection-sampled (and saved to fns when save_new is
    set); otherwise they are loaded from the .npy files listed in fns.
    """

    def __init__(self, s=0.1, n=1, zwant=4, num_pow=4, new_draw=False, save_new=False,
                 fns=['unbiased_int_build.npy', 'bias1_int_build.npy', 'bias2_int_build.npy'],
                 zcut=0.25):
        # Configuration must be stored before the bias sizes and draws are computed.
        self.s, self.n, self.zwant, self.num_pow = s, n, zwant, num_pow
        self.new_draw, self.save_new = new_draw, save_new
        self.fns, self.zcut = fns, zcut

        self.b1, self.b2 = self.get_b1_b2()
        self.bp_meas_unb, self.bp_meas_b1, self.bp_meas_b2 = self.get_draws()

    def get_b1_b2(self):
        """Return (b1, b2): the bias applied to every epoch and the epoch-0-only bias."""
        root_num_pow = np.sqrt(self.num_pow)
        b_all = self.zwant / root_num_pow * self.n
        b_single = self.zwant * root_num_pow * self.n
        return (b_all, b_single)

    def get_draws(self):
        """Return (unbiased, b1-biased, b2-biased) draws, generated or loaded per new_draw."""
        if not self.new_draw:
            unb, biased_all, biased_one = self.load_draws()
            return (unb, biased_all, biased_one)

        unb, biased_all, biased_one = self.gen_new_draws()
        if self.save_new:
            for fn, arr in zip(self.fns, [unb, biased_all, biased_one]):
                np.save(fn, arr)
        return (unb, biased_all, biased_one)

    def gen_new_draws(self):
        """Redraw all three sets until both biased z-scores lie within zcut of zwant.

        Only the biased sets are tested; the unbiased set is simply redrawn on each pass.
        Loops until an acceptable draw is found.
        """
        def z_offset(meas):
            return np.sqrt(self.num_pow) * np.mean(meas - self.s) / self.n - self.zwant

        biased_one_loc = [self.s + self.b2] + (self.num_pow - 1) * [self.s]
        while True:
            bp_meas_unb = np.random.normal(loc=self.s, scale=self.n, size=self.num_pow)
            bp_meas_b1 = np.random.normal(loc=self.s + self.b1, scale=self.n, size=self.num_pow)
            bp_meas_b2 = np.random.normal(loc=biased_one_loc, scale=self.n, size=self.num_pow)
            if np.abs(z_offset(bp_meas_b1)) < self.zcut and np.abs(z_offset(bp_meas_b2)) < self.zcut:
                return (bp_meas_unb, bp_meas_b1, bp_meas_b2)

    def load_draws(self):
        """Load one array per file in self.fns and return them as a tuple, in order."""
        return tuple(np.load(fn) for fn in self.fns)
    
class int_builder_plot:
    """Two-panel "intuition builder" figure drawn from an int_builder_dat instance.

    Both panels show the unbiased toy draws in black. The left panel adds the draws
    with every epoch biased by b1; the right panel adds the draws with only epoch 0
    biased by b2. Each panel shows true-value lines and 1- and 2-sigma bands. The
    figure is drawn and saved to ``fn`` on construction.
    """

    def __init__(self, ibd, fn="int_builder.pdf", right_xlim=1.47):
        self.ibd = ibd
        self.set_plot_params(right_xlim)

        self.fig, self.ax = plt.subplots(figsize=(13, 6), ncols=2)

        self.plot_dat()
        self.set_ax_labels()
        self.fig.savefig(fn)

    def set_plot_params(self, right_xlim):
        """Set hard-coded styling that works well for this figure.

        x positions and the x-limit follow ibd.num_pow. right_xlim is the right edge of
        the colored band in the right-hand panel.
        """
        num_pow = self.ibd.num_pow
        self.xunb = np.arange(0, num_pow, dtype=int)
        # Offset the biased points slightly so both sets of error bars stay visible.
        self.xb = self.xunb + 0.1
        self.ebar_params = {"marker":'s', "yerr":self.ibd.n, "linestyle":'', "capsize":10, "elinewidth":1}
        self.ylim = [-4, 10]
        # 0.35 of padding past the last epoch (3.35 for four epochs).
        self.xlim = [-0.25, num_pow - 1 + 0.35]
        self.alpha_high = 0.25
        self.alpha_low = 0.2
        self.colors = ["#cc0000", "#3465a4"]
        self.fill_xlims = [[-0.25, 5], [-0.25, right_xlim]]

    def set_ax_labels(self):
        """Label the axes and apply the shared limits to both panels."""
        self.ax[0].set_ylabel("Toy Bandpower", fontsize="xx-large")
        for ax in self.ax:
            ax.set_ylim(self.ylim)
            ax.set_xlim(self.xlim)
            ax.set_xlabel("Epoch", fontsize="xx-large")

    def plot_dat(self):
        """Draw error bars, true-value lines, shaded bands and legends in both panels."""
        # NOTE(review): these odds are hard-coded (presumably computed separately for the
        # saved draws) and will not follow changes to the data or priors.
        legend_entries = [r'$O($"at least one bias"$,$"all unbiased"$) = 2.71\cdot10^{-5}$',
                          r'$O($"all biased equally"$,$"all unbiased"$) = 4.98$',
                          r'$O($"$d_0$ biased only"$,$"all unbiased"$) = 94.0$']
        panels = zip(self.ax, [self.ibd.bp_meas_b1, self.ibd.bp_meas_b2], [self.ibd.b1, self.ibd.b2],
                     self.colors, self.fill_xlims, legend_entries[1:])
        for ax, dat, b, color, fill_xlim, biased_label in panels:
            ax.errorbar(self.xunb, self.ibd.bp_meas_unb, color='black', label=legend_entries[0], **self.ebar_params)
            ax.axhline(y=self.ibd.s, color='black', linestyle='--', linewidth=1)
            self.do_fill(ax, fill_xlim, b, color)
            ax.errorbar(self.xb, dat, color=color, label=biased_label, **self.ebar_params)
            ax.axhline(y=self.ibd.s + b, color=color, linestyle='--', linewidth=1)
            ax.legend(frameon=False, loc="best")
            ax.set_xticks(range(self.ibd.num_pow))

    def do_fill(self, ax_ob, xlims, b, color):
        """Shade 1- and 2-sigma bands: gray around s, ``color`` around s + b over xlims."""
        ax_ob.fill_between([-0.25, 5], self.ibd.s + self.ibd.n, y2 = self.ibd.s-self.ibd.n, color='gray',
                           alpha=self.alpha_high)
        ax_ob.fill_between([-0.25, 5], self.ibd.s + 2 * self.ibd.n, y2 = self.ibd.s-2 * self.ibd.n, color='gray',
                           alpha=self.alpha_low)
        ax_ob.fill_between(xlims, self.ibd.s + b + self.ibd.n, y2 = self.ibd.s + b-self.ibd.n, color=color,
                           alpha=self.alpha_high)
        ax_ob.fill_between(xlims, self.ibd.s + b + 2 * self.ibd.n, y2 = self.ibd.s + b-2 * self.ibd.n, color=color,
                           alpha=self.alpha_low)


# Load the saved toy draws (new_draw=False reads the .npy files named in fns) and render
# and save the figure. right_xlim=0.47 confines the right panel's colored band to epoch 0.
ibd = int_builder_dat(zwant=1.5, new_draw=False, save_new=False, zcut=0.5)
ibp = int_builder_plot(ibd, right_xlim=0.47)

In [None]:
# Run the jackknife on each toy data set (unbiased, all epochs biased, epoch 0 biased).
# run_jk returns one entry per mode it ran ("ternary", then "diagonal"), so index 1 is
# the diagonal mode. Per-data-set results are collected in odds_list / conc_list.
odds_list = []
conc_list = []
for bp_meas in [ibd.bp_meas_unb, ibd.bp_meas_b1, ibd.bp_meas_b2]:
    odds, conc = run_jk(bp_meas, num_pow=4, bias_prior_std=5, bias_prior_mean=0 * np.ones(4), print_post=False,
                        bp_prior_mean=0, bp_prior_std=0)
    odds_list.append(odds)
    conc_list.append(conc)
    # odds are relative to the maximum posterior (which contributes 1), so this is the
    # summed odds of every other diagonal-mode hypothesis.
    print(np.sum(odds[1]) - 1)
    print(odds[1])
    # Odds of the best hypothesis against hypothesis 0.
    print(1 / odds[1][0])

## Check which hypotheses are "most discernible" from one another and use that as the decision ordering

We have to construct mutual information plots for Stage 1, Stage 2, and potentially Stage 3.

In [None]:
def I2d_plot_wrapper(calc_I=True, save_I=True, num_pow=4, num_draw = int(1e4), save_plot=True,
                     bps = np.logspace(-3, 3, num=100), bpm=np.linspace(0, 10, num=100)):
    """Compute or load the 2-D mutual information grid and draw it as a color map.

    Parameters
    ----------
    calc_I : when True, recompute with mut_info_wrap ('diagonal' mode) over bps x bpm;
        otherwise load ``mut_info_2d_pow_{num_pow}_draw_{num_draw}.npy``.
    save_I : save a freshly computed array to that .npy file.
    save_plot : save the current figure to the matching .pdf file.
    bps, bpm : bias prior stds (y axis, log scale) and bias prior means (x axis).

    NOTE(review): the cache file name encodes only num_pow and num_draw, not the grids,
    so a loaded array is assumed to match bps/bpm.

    Returns the mutual information array, indexed [std_index, mean_index].
    """
    prefix = f"mut_info_2d_pow_{num_pow}_draw_{num_draw}"
    npy_fn, plot_fn = f"{prefix}.npy", f"{prefix}.pdf"

    if not calc_I:
        I2d = np.load(npy_fn)
    else:
        _, _, I2d = mut_info_wrap("diagonal", num_draw=num_draw, num_pow=num_pow, bias_prior_stds=bps,
                                  bias_prior_means=bpm)
        if save_I:
            np.save(npy_fn, I2d)

    bps_g, bpm_g = np.meshgrid(bps, bpm)
    # meshgrid puts bpm along rows, so transpose I2d ([std, mean]) to line up with it.
    plt.pcolor(bpm_g, bps_g, I2d.T, cmap='plasma', vmin=0, vmax=4, edgecolors='face')
    plt.colorbar(label="Mutual Information (bits)")
    plt.xlabel("Bias Prior Mean (Error Bar Widths)")
    plt.ylabel("Bias Prior Width (Error Bar Widths)")
    plt.yscale("log")
    if save_plot:
        plt.savefig(plot_fn)
    return I2d
# Plot from the cached .npy array instead of recomputing it (calc_I=False).
I2d = I2d_plot_wrapper(calc_I = False)

In [None]:
# Row 20 of the [std_index, mean_index] mutual information grid, i.e. one bias prior
# std, plotted against the bias-prior-mean index.
# NOTE(review): x is the array index, not the mean itself, and index 0 is dropped on a
# log x-axis — confirm the log scale is intended.
plt.plot(I2d[20, :])
plt.xscale("log")

In [None]:
# Mutual information vs. bias prior width for two ternary sub-problems and for the full
# diagonal (all bias combinations) problem, which is drawn on a twin y-axis.
modes = ["ternary", "ternary", "diagonal"]
labels = ["Unbiased vs. All Biased Differently", "Unbiased vs. All Biased Identically",
          "All Combinations of Biases"]
hyp_priors = [np.array([0.5, 0.5, 0]), np.array([0.5, 0, 0.5]), None]
num_draw = int(1e6)
fig, ax = plt.subplots(figsize=(8, 4.5))
lines = []
calc_I = False  # True recomputes (slow); False loads the cached .npy curves
save = False
num_pow = 4
bias_prior_mean = 0
# NOTE(review): both ternary entries use the same file, mut_info_demo_ternary.npy, so
# saving overwrites one curve with the other and loading plots a single curve twice.
for mode, label, hyp_prior in zip(modes, labels, hyp_priors):
    if calc_I:
        # mut_info_wrap scans a grid of bias prior means and returns
        # (stds, means, I[std_index, mean_index]); scan one mean and keep its column.
        bps, _, I = mut_info_wrap(mode, num_draw=num_draw, hyp_prior=hyp_prior, num_pow=num_pow,
                                  bias_prior_means=[bias_prior_mean],
                                  bias_prior_stds=np.logspace(-2, 2, num=20))
        I = I[:, 0]
        if save:
            np.save(f"mut_info_demo_{mode}.npy", I)
    else:
        bps = np.logspace(-3, 3, num=100)
        I = np.load(f"mut_info_demo_{mode}.npy")
    if mode == "ternary":
        line, = ax.plot(bps, I, label=label)
    else:
        twin_ax = ax.twinx()
        line, = twin_ax.plot(bps, I, label=label, color="tab:green")
        twin_ax.set_ylabel("Mutual Information (Bits)", fontsize="xx-large")
    lines.append(line)

ax.legend(lines, labels, loc="upper left", frameon=False)
ax.set_ylabel("Mutual Information (Bits)", fontsize="xx-large")
ax.set_xlabel("Bias Prior Width (in Error Bar Units)", fontsize="xx-large")
ax.set_xscale("log")
ax.axvline(10, color='black', linestyle='--')
fig.savefig("mutual_info_demo_3stage.pdf")

In [None]:
# Mutual information vs. bias prior width for the full diagonal problem with a nonzero
# bias prior mean (5 error bars).
modes = ["diagonal"]
labels = ["All Combinations of Biases"]
hyp_priors = [None]
num_draw = int(1e5)
fig, ax = plt.subplots(figsize=(8, 4.5))
lines = []
calc_I = True  # True recomputes (slow); False loads the cached .npy curve
save = True
num_pow = 4
bias_prior_mean = 5
for mode, label, hyp_prior in zip(modes, labels, hyp_priors):
    if calc_I:
        # mut_info_wrap scans a grid of bias prior means and returns
        # (stds, means, I[std_index, mean_index]); scan one mean and keep its column.
        bps, _, I = mut_info_wrap(mode, num_draw=num_draw, hyp_prior=hyp_prior, num_pow=num_pow,
                                  bias_prior_means=[bias_prior_mean],
                                  bias_prior_stds=np.logspace(-2, 2, num=20))
        I = I[:, 0]
        if save:
            np.save(f"mut_info_demo_mean_offset_{mode}.npy", I)
    else:
        bps = np.logspace(-3, 3, num=100)
        I = np.load(f"mut_info_demo_mean_offset_{mode}.npy")
    if mode == "ternary":
        line, = ax.plot(bps, I, label=label)
    else:
        twin_ax = ax.twinx()
        line, = twin_ax.plot(bps, I, label=label, color="tab:green")
        twin_ax.set_ylabel("Mutual Information (Bits)", fontsize="xx-large")
    lines.append(line)

ax.legend(lines, labels, loc="upper left", frameon=False)
ax.set_ylabel("Mutual Information (Bits)", fontsize="xx-large")
ax.set_xlabel("Bias Prior Width (in Error Bar Units)", fontsize="xx-large")
ax.set_xscale("log")
ax.axvline(10, color='black', linestyle='--')
fig.savefig("mutual_info_demo_mean_offset.pdf")

In [None]:
# Mutual information vs. bias prior width for the three decision stages.
# NOTE(review): this cell cannot run as written:
#   * get_mut_info asserts mode is 'ternary' or 'diagonal', so "stage1"/"stage2" fail;
#   * mut_info_wrap returns three values (stds, means, I), not two.
# NOTE(review): the last epoch's std of 11 stands out next to 1-1.4 — confirm it is intended.
modes = ["stage1", "stage2", "diagonal"]
labels = ["Stage 1", "Stage 2", "Stage 3"]
num_draw = int(1e4)
fig, ax = plt.subplots(figsize=(8, 4.5))
lines = []
for mode, label in zip(modes, labels):
    bps, I = mut_info_wrap(mode, num_pow=4, num_draw=num_draw, std=[1, 1.2, 1.4, 11])
    if mode in ["stage1", "stage2"]:
        line, = ax.plot(bps, I, label=label)
        lines.append(line)
    else:
        # The last stage is drawn on its own y-axis.
        twin_ax = ax.twinx()
        line, = twin_ax.plot(bps, I, label=label, color="tab:green")
        lines.append(line)
        twin_ax.set_ylabel("Mutual Information (Bits)", fontsize="xx-large")

ax.legend(lines, labels, loc="upper left", frameon=False)
ax.set_ylabel("Mutual Information (Bits)", fontsize="xx-large")
ax.set_xlabel("Bias Prior Width (in Error Bar Units)", fontsize="xx-large")
ax.set_xscale("log")
ax.axvline(10, color='black', linestyle='--')

# Find out what happens if you use a prior that is too broad

In [None]:
# Simulate diagonal-mode data whose true bias prior std is true_std, then analyse it with
# the correct prior (jk_list_good_prior) and with a much broader one (jk_list_bad_prior).
# Both lists are needed by the odds-histogram comparison further down.
true_std = 2
broad_prior_std = 70
sims = make_sim_list('diagonal', num_pow=4, bp_prior_mean=0.1, bp_prior_std=0.05, std=1, bias_prior_std=true_std,
                     num_draw=int(1e4))

jk_list_good_prior = get_jk_list(sims, mode="diagonal", num_pow=4, bp_prior_mean=0.1, bp_prior_std=0.05, std=1,
                                 bias_prior_std=true_std, num_draw=int(1e4))
jk_list_bad_prior = get_jk_list(sims, mode="diagonal", num_pow=4, bp_prior_mean=0.1, bp_prior_std=0.05, std=1,
                                bias_prior_std=broad_prior_std, num_draw=int(1e4))
# jk_list keeps its original meaning (the broad-prior analysis) for the sensitivity check.
jk_list = jk_list_bad_prior
def sens_fpr_sig(jk_list, sig=3):
    """Pooled sensitivity and false-positive rate of "some bias present" calls.

    A draw is truly positive when any epoch's simulated |bias| exceeds ``sig``. It is
    called positive when its maximum-posterior hypothesis is not hypothesis 0. Counts
    are pooled over every jackknife in jk_list.
    Returns (sensitivity, false_positive_rate).
    """
    true_pos = false_pos = n_pos = n_neg = 0
    for jk in jk_list:
        truly_pos = np.any(np.abs(jk.bp_obj.bias) > sig, axis=1)
        called_pos = np.argmax(jk.post, axis=0) > 0

        true_pos += (called_pos & truly_pos).sum()
        false_pos += (called_pos & ~truly_pos).sum()
        n_pos += truly_pos.sum()
        n_neg += (~truly_pos).sum()

    return (true_pos / n_pos, false_pos / n_neg)

# Sensitivity / false-positive rate of jk_list, counting a draw as truly biased when |bias| > 6.
sens, fpr = sens_fpr_sig(jk_list, sig=6)
print(sens, fpr)

In [None]:
def sens_fpr_ratio_wrapper(bias_prior_stds=np.arange(1, 101), true_std=1, num_draw=int(1e4), sig=3):
    """Sensitivity / false-positive-rate ratio for a range of assumed bias prior stds.

    For every assumed std, fresh 'diagonal' data with true bias prior std ``true_std``
    is simulated and analysed with that assumed std. (sens, fpr) is printed and
    sens / fpr recorded. Returns the ratios as a NumPy array in bias_prior_stds order.
    """
    def _ratio_for(assumed_std):
        sims = make_sim_list('diagonal', num_pow=4, bp_prior_mean=0.1, bp_prior_std=0.05, std=1,
                             bias_prior_std=true_std, num_draw=num_draw)
        jk_list = get_jk_list(sims, mode="diagonal", num_pow=4, bp_prior_mean=0.1, bp_prior_std=0.05, std=1,
                              bias_prior_std=assumed_std, num_draw=num_draw)
        sens, fpr = sens_fpr_sig(jk_list, sig=sig)
        print(sens, fpr)
        return sens / fpr

    return np.array([_ratio_for(assumed_std) for assumed_std in bias_prior_stds])

# Sensitivity / FPR ratio vs. the assumed bias prior std, one curve per true bias std.
# A separate loop name leaves the global true_std from the earlier cell untouched.
bias_prior_stds = np.arange(10, 110, 10)
for true_width in range(1, 6):
    rats = sens_fpr_ratio_wrapper(bias_prior_stds=bias_prior_stds, true_std=true_width, num_draw=int(1e4), sig=3)
    plt.plot(bias_prior_stds, rats, label=f"true bias std = {true_width}")
# Reference ratio of 10, drawn once instead of once per curve.
plt.axhline(10, color='black', linestyle='--')
plt.xlabel("Assumed bias prior std")
plt.ylabel("Sensitivity / FPR")
plt.legend()

In [None]:
def hist_odds(jk_list, norm_ind, label, bins=np.logspace(-7, 3, num=101)):
    """Histogram posterior odds against hypothesis norm_ind for data simulated under it.

    Takes jk_list[norm_ind] and divides every hypothesis' posterior by that of
    hypothesis norm_ind. For norm_ind == 0 it also prints the specificity, i.e. the
    fraction of draws in which no other hypothesis has odds >= 1 against hypothesis 0.
    Plots on the current axes with a log x-axis and refreshes the legend.
    """
    chosen = jk_list[norm_ind]
    odds = chosen.post / chosen.post[norm_ind]
    if norm_ind == 0:
        spec = np.mean(np.amax(odds[1:], axis=0) < 1)
        print(f"{label} specificity: {spec}")
    plt.hist(odds.flatten(), bins=bins, histtype='step', label=label)
    plt.xscale("log")
    plt.legend()
    
def hist_odds_wrapper(good_jk_list, bad_jk_list, norm_ind):
    """Overlay hist_odds for the correct-prior and incorrect-prior jackknife lists."""
    for jk_list, label in ((good_jk_list, "correct prior"), (bad_jk_list, "incorrect prior")):
        hist_odds(jk_list, norm_ind, label)
        

# Odds against hypothesis 0 under a correct vs. an overly broad bias prior.
# NOTE(review): needs jk_list_good_prior and jk_list_bad_prior from an earlier cell —
# confirm they are defined before running.
hist_odds_wrapper(jk_list_good_prior, jk_list_bad_prior, 0)

### Intuition-building Plot

In [None]:
def get_lods(p1, p2, ind=30):
    """Log-odds in decibels, 10 * log10(p1 / p2), taken at element ``ind``.

    ind defaults to 30, the previously hard-coded evaluation index.
    """
    return 10 * np.log10(p1 / p2)[ind]

# Log-odds (dB) of the "epoch 0 biased only" hypothesis (p3) against the alternatives,
# as a function of the bias prior std, for data in which only epoch 0 is offset by X.
num_pow = 4

X = np.linspace(0, 10, num=100)
Y = np.zeros([100, num_pow])
Y[:, 0] = X

# The unbiased likelihood does not depend on the prior std, so compute it once.
p0 = multivariate_normal.pdf(Y, mean=np.zeros(num_pow), cov=np.eye(num_pow))

lods10 = []  # p3 vs. p0, plotted as "3/0"
lods12 = []  # p3 vs. p2, plotted as "3/2"
lods31 = []  # p3 vs. p1, plotted as "3/1"
stds = np.logspace(0, 2, num=100, base=10)
for std in stds:
    var = std**2
    # Unit noise variance plus: p1 independent biases on every epoch, p2 one bias
    # shared by all epochs, p3 a bias on epoch 0 only.
    p1 = multivariate_normal.pdf(Y, mean=np.zeros(num_pow), cov=np.eye(num_pow)*(var + 1))
    p2 = multivariate_normal.pdf(Y, mean=np.zeros(num_pow), cov=np.eye(num_pow) + var * np.ones([num_pow, num_pow]))
    p3 = multivariate_normal.pdf(Y, mean=np.zeros(num_pow),
                                 cov=var * np.diag([1] + (num_pow - 1) * [0]) + np.eye(num_pow))

    lods10.append(get_lods(p3, p0))
    lods12.append(get_lods(p3, p2))
    lods31.append(get_lods(p3, p1))

plt.plot(stds, lods10, label="3/0")
plt.plot(stds, lods12, label="3/2")
plt.plot(stds, lods31, label="3/1")
plt.axhline(0, color='black', linestyle='--')
plt.xlabel("Bias prior std (error bar units)")
plt.ylabel("Log-odds (dB)")
plt.legend()
# Prior std maximizing the summed log-odds of p3 against the three alternatives.
print(stds[np.argmax(np.array(lods31) + np.array(lods12) + np.array(lods10))])

# Make a plot of the mutual information as a function of bias std and correlation coefficient.

In [None]:

def get_H(p):
    """Shannon entropy in bits of the distribution(s) along axis 0 of ``p``.

    Uses the convention 0 * log2(0) = 0, so zero-probability outcomes contribute
    nothing instead of producing NaN.
    """
    p = np.asarray(p, dtype=float)
    # log2(0) = -inf and 0 * -inf = NaN; evaluate safely, then mask those terms to 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        plogp = np.where(p > 0, p * np.log2(p), 0.0)
    return -np.sum(plogp, axis=0)




    


# Memo sampling distribution plot

In [None]:
def set_hist_plots(ax_obj, title):
    """Apply the shared x/y labels, legend, title and y-range to a posterior histogram axis."""
    posterior_label = r'$P(\mathcal{H}=$"sys"$ | \hat{P}, S_0)$'
    ax_obj.set_xlabel(posterior_label)
    ax_obj.set_ylabel("Counts")
    ax_obj.legend(frameon=False, loc="upper center")
    ax_obj.set_title(title)
    ax_obj.set_ylim([5e-1, 1e6])

In [None]:
# Sampling distributions of P(H="sys" | P, S_0) (top row) and sensitivity curves
# (bottom row) under less (left) and more (right) informative bias priors.
# NOTE(review): get_sens_odds and do_sens_plot are not defined in this part of the
# notebook — confirm where they come from before running.
fig, ax = plt.subplots(figsize=(12, 12), ncols=2, nrows=2)

hist_kwargs = {"bins": np.linspace(0, 1, num=1001), "density": False, 'histtype': 'step', 'log': True}

unbiased = {'mean': 1, 'std': 1, 'bias': 0, 'num_draw': int(1e7)}
mod_bias = {'mean': 1, 'std': 1, 'bias': 2, 'num_draw': int(1e7)}
# Per-draw arrays are laid out (num_draw, num_pow), as in make_sim_list and
# make_sim_from_bpc_list; the single bias sits on the last of the four epochs.
bias = np.repeat(np.array([0, 0, 0, 4])[np.newaxis, :], int(1e7), axis=0)
large_bias = {'mean': np.ones([int(1e7), 4]), 'std': 1, 'bias': bias, 'num_draw': int(1e7)}

uninformative = {'bp_prior_mean': 1,'bp_prior_std': 0.5,'bias_prior_mean': 0,'bias_prior_std': 10}
informative = {'bp_prior_mean': 1,'bp_prior_std': 0.5,'bias_prior_mean': 3,'bias_prior_std': 3}

bias_param_list = [unbiased, mod_bias, large_bias]

# Raw strings: "\s" is an invalid escape sequence in a normal string literal.
labels = ["Unbiased", r"2-$\sigma$ Bias on all epochs", r"4-$\sigma$ bias on one epoch"]
spec=np.linspace(0.001, 0.999, num=999)

prior_panels = [(uninformative, "Less Informative Priors"), (informative, "More Informative Priors")]
for col, (prior_params, title) in enumerate(prior_panels):
    for bias_params, label in zip(bias_param_list, labels):
        bp = bandpower(**bias_params)
        jk = bias_jackknife(bp, **prior_params)
        ax[0, col].hist(jk.post[0], label=label, **hist_kwargs)
        if "epoch" in label:
            sens, odds, redo_prob, _, _ = get_sens_odds(**bias_params, **prior_params)
            linestyle = '--' if "one" in label else "-"
            do_sens_plot(ax[1, col], spec, sens, odds, redo_prob, linestyle=linestyle, label="")
    set_hist_plots(ax[0, col], title)

fig.savefig("sampling_dist_prior_compare.png")

In [None]:
# Same comparison as the previous figure, with a prior probability p_bad of being biased
# passed to the jackknife.
# NOTE(review): get_sens_odds and do_sens_plot are not defined in this part of the
# notebook — confirm where they come from before running.
fig, ax = plt.subplots(figsize=(12, 12), ncols=2, nrows=2)
p_bad = 0.5

hist_kwargs = {"bins": np.linspace(0, 1, num=1001), "density": False, 'histtype': 'step', 'log': True}

unbiased = {'mean': 1, 'std': 1, 'bias': 0, 'num_draw': int(1e7)}
mod_bias = {'mean': 1, 'std': 1, 'bias': 2, 'num_draw': int(1e7)}
# (num_draw, num_pow) layout; the single bias sits on the last of the four epochs.
bias = np.repeat(np.array([0, 0, 0, 4])[np.newaxis, :], int(1e7), axis=0)
large_bias = {'mean': np.ones([int(1e7), 4]), 'std': 1, 'bias': bias, 'num_draw': int(1e7)}

uninformative = {'bp_prior_mean': 1,'bp_prior_std': 0.5,'bias_prior_mean': 0,'bias_prior_std': 10, 'p_bad': p_bad}
informative = {'bp_prior_mean': 1,'bp_prior_std': 0.5,'bias_prior_mean': 3,'bias_prior_std': 3, 'p_bad': p_bad}

bias_param_list = [unbiased, mod_bias, large_bias]

# Raw strings: "\s" is an invalid escape sequence in a normal string literal.
labels = ["Unbiased", r"2-$\sigma$ Bias on all epochs", r"4-$\sigma$ bias on one epoch"]
spec=np.linspace(0.001, 0.999, num=999)

prior_panels = [(uninformative, "Less Informative Priors"), (informative, "More Informative Priors")]
for col, (prior_params, title) in enumerate(prior_panels):
    for bias_params, label in zip(bias_param_list, labels):
        bp = bandpower(**bias_params)
        jk = bias_jackknife(bp, **prior_params)
        ax[0, col].hist(jk.post[0], label=label, **hist_kwargs)
        if "epoch" in label:
            sens, odds, redo_prob, _, _ = get_sens_odds(**bias_params, **prior_params)
            linestyle = '--' if "one" in label else "-"
            do_sens_plot(ax[1, col], spec, sens, odds, redo_prob, linestyle=linestyle, label="")
    set_hist_plots(ax[0, col], title)

fig.savefig("sampling_dist_prior_compare_p_bad_mod.png")

# Beware all ye blinded folks! Spherical bandpowers be shown below!
# Beware all ye blinded folks! Spherical bandpowers be shown below!
# Beware all ye blinded folks! Spherical bandpowers be shown below!
# Beware all ye blinded folks! Spherical bandpowers be shown below!
# Beware all ye blinded folks! Spherical bandpowers be shown below!
# Beware all ye blinded folks! Spherical bandpowers be shown below!
# Beware all ye blinded folks! Spherical bandpowers be shown below!
# Beware all ye blinded folks! Spherical bandpowers be shown below!
# Beware all ye blinded folks! Spherical bandpowers be shown below!
# Beware all ye blinded folks! Spherical bandpowers be shown below!
# Beware all ye blinded folks! Spherical bandpowers be shown below!
# Beware all ye blinded folks! Spherical bandpowers be shown below!
# Beware all ye blinded folks! Spherical bandpowers be shown below!
# Beware all ye blinded folks! Spherical bandpowers be shown below!
# Beware all ye blinded folks! Spherical bandpowers be shown below!

In [None]:
# Root of the H1C IDR3 power-spectrum repository. Set the H1C_IDR3_PS_REPO environment
# variable to point elsewhere rather than editing this hard-coded default path.
ps_repo = os.environ.get("H1C_IDR3_PS_REPO", "/Users/mike_e_dubs/Repositories/H1C_IDR3_Power_Spectra")
epoch_0_results = f"{ps_repo}/Epoch_0_Power_Spectra/results_files"
epoch_res_dirs = [epoch_0_results, ]
# Epochs 1-3 live under SPOILERS/ (presumably the blinded results warned about above).
for epoch in range(1, 4):
    epoch_results = f"{ps_repo}/SPOILERS/Epoch_{epoch}_Power_Spectra/results_files"
    epoch_res_dirs.append(epoch_results)
print(epoch_res_dirs)

In [None]:
def get_epochs_in_band_field(band, field):
    """Return the indices into ``epoch_res_dirs`` that hold a P(k) file for ``band``/``field``."""
    return [
        epoch_idx
        for epoch_idx, results_dir in enumerate(epoch_res_dirs)
        if os.path.exists(f"{results_dir}/Pofk_Band_{band}_Field_{field}.h5")
    ]

bands = list('12')
fields = list('ABCDE')

# Which epochs have a P(k) results file, keyed by (band, field).
epoch_dict = {
    (band, field): get_epochs_in_band_field(band, field)
    for band in bands
    for field in fields
}
print(epoch_dict)
# NOTE(review): Field C is pinned to epochs 0-2 in both bands regardless of which files were
# found above; the reason is not recorded here — confirm.
epoch_dict[('1', 'C')] = [0, 1, 2]
epoch_dict[('2', 'C')] = [0, 1, 2]
print(epoch_dict)

In [None]:
#spw = 0
#field = "B"
#bp_prior_mean = 300
#bp_prior_std = bp_prior_mean / 2
#bias_prior_std = 1e5

def get_uvp_goodies(spw, field, mode, epochs=np.arange(4, dtype=int), comp='real'):
    """Load per-epoch bandpowers, their 1-sigma errors and the k values for one band/field.

    Parameters
    ----------
    spw : int
        Band number (1 or 2); band 1 reports 29 k-bins, anything else 25.
    field : str
        Field letter used in the results filename.
    mode : str
        'dsq' for Delta^2 files or 'pk' for P(k) files.
    epochs : iterable of int
        Indices into ``epoch_res_dirs``. Epochs without a results file are skipped with a message,
        so the returned arrays may have fewer rows than ``epochs``.
    comp : str
        Attribute of the complex data to keep ('real' or 'imag').

    Returns
    -------
    bp_meas_k, stds_k : ndarray, shape (n_found_epochs, num_k)
    ks : ndarray, shape (num_k,)

    Raises
    ------
    ValueError
        If ``mode`` is unknown or no results file exists for any requested epoch.
    """
    bp_meas = []
    stds = []
    if mode == 'dsq':
        modestr = "Deltasq"
    elif mode == 'pk':
        modestr = 'Pofk'
    else:
        raise ValueError("mode must be 'dsq' or 'pk'")

    # Keep the last successfully read file: an unread UVPSpec has no k information.
    last_uvp = None
    for epoch in epochs:
        path = f"{epoch_res_dirs[epoch]}/{modestr}_Band_{spw}_Field_{field}.h5"
        if not os.path.exists(path):
            print(f"Could not find results file for epoch {epoch}, Band {spw}, Field {field}")
            continue
        uvp = UVPSpec()
        uvp.read_hdf5(path)
        bp_meas.append(getattr(uvp.data_array[0].squeeze(), comp))
        # NOTE(review): errors always come from cov_array_real, even when comp='imag' — confirm intended.
        stds.append(np.sqrt(uvp.cov_array_real[0].squeeze().diagonal()))
        last_uvp = uvp

    if last_uvp is None:
        raise ValueError(f"No {modestr} results files found for Band {spw}, Field {field}, "
                         f"epochs {list(epochs)}")

    num_k = 29 if spw == 1 else 25
    # First 3 are all 0, only the next num_k are reported in this band
    k_slice = slice(3, 3 + num_k)
    bp_meas_k = np.array(bp_meas)[:, k_slice]
    stds_k = np.array(stds)[:, k_slice]
    # Assumes every epoch shares the same k grid. The kparas store the actual k magnitude in these files?
    ks = last_uvp.get_kparas(0)[k_slice]

    return(bp_meas_k, stds_k, ks)



In [None]:
def get_odds_list_k(ks, stds_k, bp_meas_k, epochs, bppm=10, use_I=False, bias_mult=10, jk_mode="diag_only",
                    bias_prior_mean=0):
    """Run the bias jackknife independently at every k and collect its outputs.

    Parameters
    ----------
    ks : sequence of float
        k values; column ``k_ind`` of ``stds_k`` / ``bp_meas_k`` belongs to ``ks[k_ind]``.
    stds_k, bp_meas_k : ndarray, shape (n_epochs, n_k)
        Per-epoch 1-sigma errors and measured bandpowers.
    epochs : sequence
        Epochs in the rows; only its length is used (number of bandpowers).
    bppm : float
        Bandpower prior mean in Delta^2 units. At each k it is converted to P(k) units as
        ``2 pi^2 bppm / k^3``, and that value (and half of it, for the std) is used as the prior.
    use_I : bool
        If True, choose the bias prior width by maximising the mutual information from
        ``mut_info_wrap``; otherwise use ``bias_mult * std``.
    bias_mult : float
        Bias prior width in units of each epoch's error bar (when ``use_I`` is False).
    jk_mode : str
        Forwarded to ``run_jk``.
    bias_prior_mean : float
        Bias prior mean in units of each epoch's error bar.

    Returns
    -------
    odds_list_k, conc_list_k : list
        The ``run_jk`` outputs for each k, in the order of ``ks``.
    """
    odds_list_k = []
    conc_list_k = []
    for k_ind, k in enumerate(ks):
        # Delta^2 -> P(k) conversion of the prior mean at this k.
        bppm_use = bppm * 2 * np.pi**2 / k**3
        std = stds_k[:, k_ind]
        bp_meas = bp_meas_k[:, k_ind]
        if use_I:
            bias_prior_stds = np.logspace(4, 10, num=100)
            _, I = mut_info_wrap('diagonal', num_pow=len(epochs), bp_prior_mean=bppm_use, bp_prior_std=bppm_use/2,
                                 std=std, bias_prior_stds=bias_prior_stds, num_draw=int(1e4))
            # Smallest width reaching 90% of the maximum information.
            bias_prior_std = bias_prior_stds[np.argmax(I / I.max() > 0.9)]
            print(bias_prior_std / np.amax(std))
        else:
            bias_prior_std = bias_mult * std
        # Use the k-dependent prior (previously the raw Delta^2 value bppm was passed here,
        # inconsistent with the use_I branch above).
        odds_list, concs = run_jk(bp_meas, bp_prior_mean=bppm_use, bp_prior_std=bppm_use / 2,
                                  bias_prior_mean=bias_prior_mean * std,
                                  std=std, bias_prior_std=bias_prior_std, num_pow=len(epochs), jk_mode=jk_mode)
        odds_list_k.append(odds_list)
        conc_list_k.append(concs)

    return(odds_list_k, conc_list_k)

def get_stage12_scat(ks, odds_list_k):
    """Build scatter data of inverse odds (so violators plot high) for jackknife stages 1 and 2.

    Stage 1 is reported at every k; stage 2 only at the k values whose entry in
    ``odds_list_k`` has more than one stage.

    Returns
    -------
    ((ks, inv_odds_stage1), (ks_stage2, inv_odds_stage2))
    """
    inv_odds_stage1 = [1 / odds_list_k[idx][0] for idx in range(len(ks))]

    ks_stage2 = []
    inv_odds_stage2 = []
    for idx in range(len(ks)):
        stages = odds_list_k[idx]
        if len(stages) > 1:
            inv_odds_stage2.append(1 / stages[1])
            ks_stage2.append(ks[idx])

    return (ks, inv_odds_stage1), (ks_stage2, inv_odds_stage2)

def plot_stage_12(band_ind, field_ind, ax, stage1_xy, stage2_xy):
    """Scatter stage-1 (blue) and stage-2 (orange) inverse odds against k on one panel.

    Parameters
    ----------
    band_ind, field_ind : int
        Panel position; only used to decide which labels/titles to draw.
    ax : matplotlib Axes
        Target panel.
    stage1_xy, stage2_xy : tuple of (x, y) sequences
        As returned by ``get_stage12_scat``.
    """
    ax.scatter(*stage1_xy, color='blue', alpha=0.5)
    ax.scatter(*stage2_xy, color='orange', alpha=0.5)
    # Reference lines at 1 and 10.
    ax.axhline(1, color='black', linestyle='--')
    ax.axhline(10, color='red', linestyle='--')

    ax.set_yscale("log")
    if field_ind == 0:
        ax.set_ylabel(f"Band {band_ind + 1}, Odds", fontsize="xx-large")
    if band_ind == 0:
        ax.set_title(f"Field {'BCDE'[field_ind]}", fontsize="xx-large")
    if band_ind == 1:
        ax.set_xlabel("k ($h$ Mpc$^{-1}$)", fontsize="xx-large")
    # Lay out the figure that owns this panel. plt.tight_layout() acts on the *current* figure,
    # which in the wrapper is the last one created, not this one.
    ax.figure.tight_layout()
    
def get_epoch_combos(epochs):
    """Return every subset of ``epochs`` (empty set first, by increasing size) as an object array of tuples."""
    all_subsets = list(powerset(epochs))
    return np.array(all_subsets, dtype=object)

def stoplight_plot(ax, epoch_combos, ks, conc_list_k):
    """Draw a (combination x k) image marking stage-3 jackknife conclusions.

    Pixel values: 2 = the MAP bias configuration (``'max_post_ind'``), 1 = competing
    configurations (``'comp_inds'``), 0 = neither. k-bins without a stage-3 result stay 0.

    Parameters
    ----------
    ax : matplotlib Axes
    epoch_combos : sequence
        Row labels, one per bias configuration.
    ks : sequence
        k values (one column each).
    conc_list_k : list
        Per-k conclusions; entry ``[2]`` is the stage-3 dict when present.
    """
    num_combo = len(epoch_combos)
    image = np.zeros([num_combo, len(ks)])
    for k_ind in range(len(ks)):
        concs = conc_list_k[k_ind]
        if len(concs) < 3:
            continue  # this k never reached stage 3
        stage3 = concs[2]
        compet_inds = stage3['comp_inds']
        image[np.array(stage3['max_post_ind']), k_ind] = 2
        if len(compet_inds):
            image[compet_inds, k_ind] = 1
    ax.imshow(image, vmin=0, vmax=2, aspect='auto', interpolation='none')
    ax.set_yticks(range(len(epoch_combos)))
    ax.set_yticklabels(epoch_combos, fontsize="xx-large")
    # Lay out this panel's own figure rather than whichever figure is current.
    ax.figure.tight_layout()
    
def odds_plot_2d(ax, epoch_combos, ks, odds_list_k, band_ind, field_ind, vmin=0.1, vmax=1, log=False):
    """Show the per-k odds of every bias configuration as an image (rows = configurations).

    Column ``k_ind`` is ``odds_list_k[k_ind][0]`` (log10 of it when ``log`` is True); the colour
    scale is clipped to [vmin, vmax].

    Returns
    -------
    The AxesImage from ``ax.imshow`` (e.g. for a colorbar).
    """
    num_combo = len(epoch_combos)
    num_k = len(ks)
    image = np.zeros([num_combo, num_k])
    for k_ind in range(num_k):
        if log:
            image[:, k_ind] = np.log10(odds_list_k[k_ind][0].T)
        else:
            image[:, k_ind] = odds_list_k[k_ind][0].T

    # NOTE(review): the x extent runs from the first to the last k, so pixel edges (not centres)
    # sit on those k values.
    cax = ax.imshow(image, aspect='auto', interpolation='none', cmap='plasma', vmax=vmax, vmin=vmin,
                    extent=[ks[0], ks[-1], num_combo - 0.5, -0.5])
    ax.set_yticks(range(len(epoch_combos)))
    ax.set_yticklabels(epoch_combos, fontsize="xx-large")
    do_multipanel_labels(ax, band_ind, field_ind, "Bias Configurations")
    ax.tick_params(labelsize="xx-large")

    # Lay out this panel's own figure rather than whichever figure is current.
    ax.figure.tight_layout()
    return(cax)
    


def get_stage_inds(stage, odds_list_k):
    """Return ``np.where`` indices of the k-bins whose jackknife ran exactly ``stage`` stages."""
    reached_stage = [len(per_k) == stage for per_k in odds_list_k]
    return np.where(reached_stage)

def do_multipanel_labels(ax, band_ind, field_ind, label):
    """Label one panel of a 2 x 4 (band x field) grid.

    The left column gets a "Band N, <label>" y-label, the bottom row gets the k x-label,
    and the top row gets a "Field X" title (fields B-E).
    """
    field_letters = 'BCDE'
    big = "xx-large"
    if field_ind == 0:
        band_number = 1 if band_ind == 0 else 2
        ax.set_ylabel(f"Band {band_number}, {label}", fontsize=big)
    if band_ind == 1:
        ax.set_xlabel("$k$ ($h$ Mpc$^{-1}$)", fontsize=big)
    if band_ind == 0:
        ax.set_title(f"Field {field_letters[field_ind]}", fontsize=big)

def do_error_plot_stage3(bp_ax, dat, ks, stds, fade_alpha, label, color, odds_list_k, conc_list_k, epoch,
                         epoch_combos, band_ind, field_ind, stagger=0, jk_mode="diag_only", plot_z=True):
    """Plot one epoch's bandpowers (or z-scores) against k on ``bp_ax``.

    In "full" mode, k-bins are grouped by how many jackknife stages they reached
    (``get_stage_inds``). Stage 1 is drawn with ``fade_alpha``; stages 2 and 3 use
    distinct markers.

    Otherwise, k-bins whose MAP bias configuration (``conc_list_k[p][0]["max_post_ind"]``,
    looked up in ``epoch_combos``) contains ``epoch`` are drawn opaque and the rest faded.

    Parameters
    ----------
    dat, stds : ndarray
        This epoch's bandpowers and 1-sigma errors, one entry per k.
    label, color : str
        Legend label and colour for this epoch.
    stagger : float
        Horizontal k offset so epochs do not overlap.
    plot_z : bool
        (non-"full" modes) plot ``dat / stds`` as points instead of bandpowers with error bars.
    """
    if jk_mode == "full":
        alphas = [fade_alpha, 1, 1]
        markers = ['.', '^', 'X']
        for stage in range(1, 4):
            stage_inds = get_stage_inds(stage, odds_list_k)
            k_use = np.array(ks)[stage_inds]
            # Label only the first stage so each epoch appears once in the legend
            # (all stages share the same label text).
            bp_ax.errorbar((k_use + stagger), dat[stage_inds], alpha=alphas[stage - 1],
                           yerr=stds[stage_inds], color=color,
                           marker=markers[stage - 1], linestyle='', markersize=8,
                           label=label if stage == 1 else None)
    else:
        outlying_bools = [epoch in epoch_combos[conc_list_k[p][0]["max_post_ind"]] for p in range(len(ks))]
        # NOTE(review): the faded alpha here is hard-coded to 0.15 rather than fade_alpha.
        alphas = [1, 0.15]
        for inds_ind, inds_use in enumerate([np.array(outlying_bools), np.logical_not(outlying_bools)]):
            k_use = np.array(ks)[inds_use]
            dat_use = dat[inds_use]
            std_use = stds[inds_use]
            # Only the flagged group carries the legend label.
            label_use = label if not inds_ind else None
            if plot_z:
                bp_ax.scatter((k_use + stagger), dat_use / std_use, alpha=alphas[inds_ind],
                              color=color, s=100,
                              marker='.', label=label_use)
            else:
                bp_ax.errorbar((k_use + stagger), dat_use, alpha=alphas[inds_ind],
                               yerr=std_use, color=color,
                               marker='.', linestyle='', markersize=8, label=label_use)
        if plot_z:
            bp_label = "$z$-score"
        else:
            bp_label = "Bandpower (mK$^2$ $h^3$ Mpc$^{-3}$)"
        do_multipanel_labels(bp_ax, band_ind, field_ind, bp_label)
    
def bp_plot(bp_ax, bp_meas_k, stds_k, ks, odds_list_k, epochs, conc_list_k, epoch_combos, band_ind, field_ind,
            stagger=0.004, bppm=10, plot_z=True,
            fade_alpha=0.2, jk_mode="diag_only"):
    """Plot every epoch's bandpowers for one (band, field) panel.

    Epoch number ``eind`` (its position in ``epochs``) is shifted in k by ``stagger * eind``.
    The x-limits are padded by ten times the largest shift.

    Parameters
    ----------
    bp_meas_k, stds_k : ndarray, shape (n_epochs, n_k)
    ks : sequence of float
    epochs : sequence of int
        Epoch numbers; also index the colour list (so at most 0-3).
    stagger : float
        Per-epoch k offset. It used to be silently replaced by 0.004; the default keeps that value.
    bppm : float
        Unused; kept for backward compatibility.
    fade_alpha, jk_mode, plot_z :
        Forwarded to ``do_error_plot_stage3``.
    """
    colors = ["tab:blue", "tab:orange", "tab:green", "tab:red"]
    offset = stagger  # used for the x-limit padding if epochs is empty
    for eind, epoch in enumerate(epochs):
        offset = stagger * eind
        do_error_plot_stage3(bp_ax, bp_meas_k[eind], ks, stds_k[eind], fade_alpha, f"epoch {epoch}",
                             colors[epoch], odds_list_k, conc_list_k, epoch, epoch_combos, band_ind,
                             field_ind, stagger=offset,
                             jk_mode=jk_mode, plot_z=plot_z)

    bp_ax.axhline(0, color='black', linestyle='--')
    bp_ax.legend(frameon=False, fontsize="xx-large")
    bp_ax.set_xlim(ks[0] - 10 * offset, ks[-1] + 10 * offset)
    bp_ax.tick_params(labelsize="xx-large")
    
def H_plot(ax, odds_list_k, ks):
    """Plot, against k, the Shannon entropy (bits) of the hypothesis probabilities implied by the odds.

    At each k, probabilities are taken proportional to 1 / odds, using column 0 of
    ``odds_list_k[k][0]``, and then normalised.
    """
    def _entropy_bits(odds):
        probs = 1 / (odds * np.sum(1 / odds))
        return -probs @ np.log2(probs)

    entropies = [_entropy_bits(odds_list_k[idx][0][:, 0]) for idx in range(len(ks))]
    ax.plot(ks, entropies)



def bias_jackknife_wrapper(comp="real", bppm=10, stagger=0.004, use_I=False, fade_alpha=0.2,
                           bias_mult=[20, 10, 7, 10], jk_mode="diag_only", fields='BCDE', bias_prior_mean=0,
                           dual_mean=False, plot_z=True):
    """Run the per-k bias jackknife for every (band, field) and save three summary figures.

    Parameters
    ----------
    comp : str
        'real' or 'imag' component of the bandpowers; also used in the output filenames.
    bppm : float
        Bandpower prior mean (Delta^2 units), forwarded to ``get_odds_list_k`` and ``bp_plot``.
    stagger : float
        Per-epoch k offset in the bandpower panels.
    use_I, bias_prior_mean :
        Forwarded to ``get_odds_list_k``.
    fade_alpha : float
        Alpha for de-emphasised points.
    bias_mult : sequence of float
        Bias prior width in error-bar units, indexed by field position (not mutated).
    jk_mode : str
        "full" draws stage-1/2 scatter and stoplight panels. Anything else draws entropy
        curves and odds images with an odds colorbar.
    fields : str
        Field letters to loop over (at most four; the grids have four columns).
    dual_mean : bool
        Unused; kept for backward compatibility.
    plot_z : bool
        Plot z-scores instead of bandpowers with error bars.

    Saves stage12_{comp}.pdf, stoplight_{comp}.png and bandpowers_{comp}.pdf.
    """
    fig, ax = plt.subplots(figsize=(32, 8), nrows=2, ncols=4)
    conc_fig, conc_ax = plt.subplots(figsize=(32, 16), nrows=2, ncols=4)
    bp_fig, bp_ax = plt.subplots(figsize=(32, 16), nrows=2, ncols=4)
    for band_ind, band in enumerate('12'):
        for field_ind, field in enumerate(fields):
            epochs = epoch_dict[(band, field)]
            bp_meas_k, stds_k, ks = get_uvp_goodies(int(band), field, 'pk', epochs=epochs, comp=comp)
            odds_list_k, conc_list_k = get_odds_list_k(ks, stds_k, bp_meas_k, epochs, bppm=bppm, use_I=use_I,
                                                       bias_mult=bias_mult[field_ind], jk_mode=jk_mode,
                                                       bias_prior_mean=bias_prior_mean)

            epoch_combos = get_epoch_combos(epochs)
            if jk_mode == 'full':
                # Stage 1/2 inverse-odds scatter, then the stage-3 stoplight image.
                stage1_xy, stage2_xy = get_stage12_scat(ks, odds_list_k)
                plot_stage_12(band_ind, field_ind, ax[band_ind, field_ind], stage1_xy, stage2_xy)
                stoplight_plot(conc_ax[band_ind, field_ind], epoch_combos, ks, conc_list_k)
            else:
                H_plot(ax[band_ind, field_ind], odds_list_k, ks)
                odds_plot_2d(conc_ax[band_ind, field_ind], epoch_combos, ks, odds_list_k, band_ind, field_ind)
            # Forward jk_mode so "full" runs use the stage-based bandpower plot.
            bp_plot(bp_ax[band_ind, field_ind], bp_meas_k, stds_k, ks, odds_list_k, epochs, conc_list_k,
                    epoch_combos, band_ind, field_ind, bppm=bppm, stagger=stagger, fade_alpha=fade_alpha,
                    plot_z=plot_z, jk_mode=jk_mode)

    conc_fig.tight_layout(h_pad=1, w_pad=1)
    if jk_mode != 'full':
        # Matches odds_plot_2d's default colour limits. Named odds_norm so it does not shadow scipy.stats.norm.
        odds_norm = mpl.colors.Normalize(vmin=0.1, vmax=1)
        cbar = conc_fig.colorbar(cm.ScalarMappable(norm=odds_norm, cmap='plasma'), ax=conc_ax)
        cbar.set_label("Odds against Maximum a Posteriori Hypothesis", fontsize=40)
        cbar.ax.tick_params(labelsize=20)

    fig.savefig(f"stage12_{comp}.pdf")
    conc_fig.savefig(f"stoplight_{comp}.png")
    bp_fig.savefig(f"bandpowers_{comp}.pdf")

# Real part: bias prior std = 10 x error bars, default zero-mean bias prior, bppm=0.
bias_jackknife_wrapper(jk_mode="diag_only", fields='BCDE', bppm=0, bias_mult=10 * np.ones(4),
                       use_I=False, fade_alpha=0.15)

In [None]:
# Real part: bias prior std = 1 x error bars, bias prior mean = +10 x error bars.
bias_jackknife_wrapper(bias_prior_mean=10, fields='BCDE', jk_mode="diag_only", bppm=0,
                       bias_mult=np.ones(4), use_I=False, fade_alpha=0.15)

In [None]:
# Real part: bppm=1, bias prior std = 1 x error bars, bias prior mean = -6 x error bars.
bias_jackknife_wrapper(bias_prior_mean=-6, fields='BCDE', jk_mode="diag_only", bppm=1,
                       bias_mult=np.ones(4), use_I=False, fade_alpha=0.15)

In [None]:
# Imaginary part: bias prior std = 1 x error bars, bias prior mean = +6 x error bars.
bias_jackknife_wrapper(comp='imag', bias_prior_mean=6, fields='BCDE', jk_mode="diag_only", bppm=0,
                       bias_mult=np.ones(4), use_I=False, fade_alpha=0.15)

In [None]:
# Imaginary part: bias prior std = 1 x error bars, bias prior mean = -6 x error bars.
bias_jackknife_wrapper(comp='imag', bias_prior_mean=-6, fields='BCDE', jk_mode="diag_only", bppm=0,
                       bias_mult=np.ones(4), use_I=False, fade_alpha=0.15)

In [None]:
# Imaginary part: very wide (70 x error bars) zero-mean bias prior.
bias_jackknife_wrapper(comp='imag', fields='BCDE', jk_mode="diag_only", bppm=0,
                       bias_mult=70 * np.ones(4), use_I=False, fade_alpha=0.15)

In [None]:
def get_p_seq(L, dom=True):
    """Entropy (bits) of an L-outcome distribution with one dominant outcome.

    Weights are 10 for the first outcome and 1 elsewhere. With ``dom=False`` the second
    outcome gets weight 5, i.e. a runner-up.
    """
    weights = np.ones(L)
    weights[0] = 10
    if not dom:
        weights[1] = 5
    probs = weights / weights.sum()
    return -probs @ np.log2(probs)
print(get_p_seq(4) - get_p_seq(4, dom=False))
print(get_p_seq(16) - get_p_seq(16, dom=False))


In [None]:
# Posterior quantities for each pair of epochs.
for epoch_pair in ([0, 2], [0, 1], [1, 2]):
    plot_post_quant(epochs=np.array(epoch_pair))

In [None]:
# Epochs 1-3 with hypothesis prior (0.5, 0.25, 0.25).
plot_post_quant(epochs=np.arange(1, 4), hyp_prior=np.array([0.5, 0.25, 0.25]))

In [None]:
# Two-hypothesis version: third hypothesis given zero prior weight.
plot_post_quant(two_hyp=True, hyp_prior=np.array([0.5, 0.5, 0]))

In [None]:
def mode_test(spw, field, k_inds=None, epochs=np.arange(4), bpw=6, bppm=1e9):
    """Print binary, ternary and diagonal bias-jackknife posteriors for measured bandpowers.

    Parameters
    ----------
    spw : int
        Band number passed to ``get_uvp_goodies``.
    field : str
        Field letter.
    k_inds : iterable of int, optional
        Indices into the loaded k array to test; defaults to all of them.
    epochs : array-like of int
        Epochs to load.
    bpw : float
        Bias prior std in units of the mean per-epoch error bar at each k.
    bppm : float
        Bandpower prior mean in Delta^2 units, converted to P(k) as ``2 pi^2 bppm / k^3``.
        Previously hard-coded via ``np.logspace(9, 9)`` (50 copies of 1e9, only the first used).
    """
    bps, stds, ks = get_uvp_goodies(spw, field, 'pk', epochs=epochs)
    bppms_use = 2 * np.pi**2 * bppm / ks**3
    if k_inds is None:
        k_inds = range(len(ks))
    num_diag_hyp = 2**len(epochs)
    for k_ind in k_inds:
        print(f"k = {ks[k_ind]}")
        bp = bandpower(simulate=False, bp_meas=bps[:, k_ind], num_draw=1, num_pow=len(epochs),
                       std=stds[:, k_ind])
        # Priors shared by all three tests at this k.
        prior_kwargs = dict(bp_prior_mean=bppms_use[k_ind], bp_prior_std=bppms_use[k_ind] / 2,
                            bias_prior_std=bpw * np.mean(stds[:, k_ind]))
        bjk = bias_jackknife(bp, **prior_kwargs, hyp_prior=[0.5, 0.5], mode="binary")
        print(f"Binary test: {bjk.post}")
        bjk = bias_jackknife(bp, **prior_kwargs, hyp_prior=[0, 0.5, 0.5], mode="ternary")
        print(f"Ternary test: {bjk.post}")
        max_ternary_ind = np.argmax(bjk.post)
        bjk = bias_jackknife(bp, **prior_kwargs, hyp_prior=np.ones(num_diag_hyp) / num_diag_hyp,
                             mode="diagonal")
        max_diag_ind = np.argmax(bjk.post)
        if max_ternary_ind == 1:
            max_diag_post = np.amax(bjk.post)
            # 10 log10 odds of the MAP diagonal hypothesis against every hypothesis.
            lodds = 10 * np.log10(max_diag_post / bjk.post)
            print(f"Diagonal test result: {np.diag(bjk.bias_prior.cov[max_diag_ind]).nonzero()}")

            # Hypotheses within (1, 10) of the MAP on that scale count as competitors.
            comps = np.logical_and(1 < lodds, lodds < 10)
            comp_matrs = bjk.bias_prior.cov[comps.flatten()]
            comp_h = [np.diag(matr).nonzero() for matr in comp_matrs]
            print(f"competing hypotheses: {comp_h}")
            print(f"lodds: {lodds[comps]}")

mode_test(2, 'D', [19, 24], bpw=10)
    

In [None]:
# Band 1, Field C (epochs 0-2 only).
mode_test(1, 'C', bpw=10, epochs=np.arange(3))

In [None]:
# Band 2, Field B (epochs 0-1 only).
mode_test(2, 'B', bpw=10, epochs=np.arange(2))

# First run the binary test between the null and all four, then all four correlated, then dissect if necessary

## Run the diagonal test with the prior chosen to optimize the likelihood ratio for biases of a certain size and type, and print the log-odds. If the log-odds exceed 10 for all competing hypotheses, report "conclusive"; otherwise, list the nearest competing hypotheses.



# Scratch Paper

In [None]:
# Bind the dense block_diag explicitly: a later scratch cell rebinds `block_diag` to
# scipy.sparse.block_diag, which has a different call signature.
from scipy.linalg import block_diag as dense_block_diag


def get_num_hyp_max(M):
    """Closed-form hypothesis count ``1 + M + M (M - 1) 2**(M - 2)`` used to size ``cov``."""
    return(1 + M + M * (M - 1) * 2**(M - 2))
dim = 2
num_hyp = get_num_hyp_max(dim)

# One covariance per hypothesis: null (all zero), one per single bandpower,
# all bandpowers independent, and all bandpowers fully correlated.
# NOTE: only these dim + 3 entries are filled; for dim != 2 the rest stay zero.
cov = np.zeros([num_hyp, dim, dim])
for i in range(dim):
    cov[i + 1, i, i] = 1
cov[dim + 1] = np.diag(np.ones(dim))
cov[dim + 2] = np.ones([dim, dim])

# Small diagonal regulariser so every covariance (including the null) is positive definite.
# Added exactly once per hypothesis; the old loop added it num_hyp times to every matrix.
ebar = 0.0001 * np.diag(np.ones(dim))
cov += ebar
cov_block = dense_block_diag(*cov)
mean = np.zeros(num_hyp * dim)

In [None]:
num_draw = int(1e6)
# One joint draw covers every hypothesis' block; reshape so dat[d, j] is draw d under hypothesis j.
dat = np.random.multivariate_normal(mean=mean, cov=cov_block, size=num_draw).reshape([num_draw, num_hyp, dim])
# likes[i, d, j] = density under hypothesis i of draw d that was simulated from hypothesis j.
likes = np.zeros([num_hyp, num_draw, num_hyp])
for hyp_ind in range(num_hyp):
    hyp_cov = cov[hyp_ind]
    likes[hyp_ind] = multivariate_normal.pdf(dat, mean=np.zeros(dim), cov=hyp_cov)

def get_posts(prior, likes):
    """Normalise prior-weighted likelihoods over the leading (hypothesis) axis into posteriors."""
    # Transposing puts the hypothesis axis last so the 1-D prior broadcasts along it.
    weighted = (prior * likes.T).T
    return weighted / weighted.sum(axis=0)

def get_Ixy(posts, prior, num_draw):
    """Monte-Carlo estimate of the mutual information between the true hypothesis and the data.

    For each draw a true hypothesis is sampled from ``prior``. The posterior that hypothesis
    receives on its own simulated data, ``posts[i, d, i]``, is then averaged in
    ``-log2`` to give the conditional entropy H.

    Parameters
    ----------
    posts : ndarray, shape (n_hyp, num_draw, n_hyp)
        ``posts[i, d, j]`` is the posterior of hypothesis i for draw d simulated under j.
    prior : array-like, shape (n_hyp,)
    num_draw : int
        Must equal ``posts.shape[1]``.

    Returns
    -------
    (Ixy, Ixy / Hmax), where Hmax is the prior entropy in bits.

    Raises
    ------
    ValueError
        If ``num_draw`` does not match ``posts.shape[1]``.
    """
    prior = np.asarray(prior)
    if posts.shape[1] != num_draw:
        raise ValueError(f"num_draw={num_draw} does not match posts.shape[1]={posts.shape[1]}")
    pulls = np.random.multinomial(1, prior, size=num_draw).argmax(axis=1)
    hyp_inds = np.arange(len(prior))
    # true_posts[i, d] = posts[i, d, i]. Fancy indexing has no cap on the number of hypotheses,
    # unlike np.choose (NPY_MAXARGS choices), which fails e.g. for the 53 hypotheses of M = 4.
    true_posts = posts[hyp_inds, :, hyp_inds]
    post_use_final = true_posts[pulls, np.arange(num_draw)]
    H = -np.nanmean(np.log2(post_use_final))
    # Zero-prior hypotheses give 0 * log2(0) = nan, which nansum skips.
    with np.errstate(divide="ignore", invalid="ignore"):
        Hmax = -np.nansum(prior * np.log2(prior))
    Ixy = Hmax - H
    return(Ixy, Ixy/Hmax)



In [None]:
def get_performance(prior, likes, special=False, num_draw=int(1e6)):
    """Classification rates and normalised mutual information for a hypothesis prior.

    Parameters
    ----------
    prior : ndarray, shape (n_hyp,)
    likes : ndarray, shape (n_hyp, num_draw, n_hyp)
        ``likes[i, d, j]``: likelihood under hypothesis i of draw d simulated from j.
    special : bool
        If True, decide "null" (index 0) whenever its posterior exceeds 0.5. Otherwise
        decide by MAP among the remaining hypotheses.
    num_draw : int
        Forwarded to ``get_Ixy``; should equal the number of draws in ``likes``.

    Returns
    -------
    class_rate : ndarray, shape (n_hyp, n_hyp)
        ``class_rate[j, k]`` = fraction of draws from true hypothesis j classified as k.
    Ixy_frac : float
        Mutual information as a fraction of the prior entropy.
    """
    posts = get_posts(prior, likes)
    if special:
        max_post = np.where(posts[0] > 0.5, 0, np.argmax(posts[1:], axis=0) + 1)
    else:
        # MAP rule for each datum, second axis is true hypothesis
        max_post = np.argmax(posts, axis=0)
    # Compare against every hypothesis index; this was hard-coded to 5.
    num_hyp = len(prior)
    class_rate = (max_post[:, :, np.newaxis] == np.arange(num_hyp)).mean(axis=0)
    _, Ixy_frac = get_Ixy(posts, prior, num_draw)

    return(class_rate, Ixy_frac)

# Candidate hypothesis priors (null first), each summing to 1.
priors = [
    np.array([0.5, 0, 0, 0.5, 0]),
    np.array([0.5, 0, 0, 0.25, 0.25]),
    np.array([0.5, 0.16, 0.17, 0.17, 0]),
    np.array([0.5, 0.125, 0.125, 0.125, 0.125]),
    np.array([0.2, 0.2, 0.2, 0.2, 0.2]),
]
for prior in priors:
    class_rate, Ixy_frac = get_performance(prior, likes)
    print(class_rate, f"Ixy_frac: {Ixy_frac}", sep="\n")


In [None]:
# `max_post` only exists inside get_performance, so this cell raised NameError on a fresh
# kernel. Recompute the MAP decisions for the last prior in `priors` here.
max_post = np.argmax(get_posts(priors[-1], likes), axis=0)
# Overall fraction of draws whose MAP decision equals the true hypothesis (second axis).
sens = np.mean(max_post == np.arange(max_post.shape[1]))
print(sens)

In [None]:
from itertools import combinations, chain, permutations
def powerset(iterable):
    """Iterate over all subsets of ``iterable`` as tuples, by increasing size.

    Same behaviour as ``more_itertools.powerset`` (which this redefinition shadows):
    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    pool = list(iterable)
    subsets_by_size = [combinations(pool, size) for size in range(len(pool) + 1)]
    return chain(*subsets_by_size)


In [None]:
        bias_prior_std = np.array([2, 2.1, 2.4, 2.3])
        bias_cov_shape = [2**16, 3, 3]
        bias_cov = np.zeros(bias_cov_shape)
        hyp_ind = 0
        for diag_on in powerset(range(3)):
            N_on = len(diag_on)
            if N_on == 0:  # Null hypothesis - all 0 cov. matrix
                hyp_ind += 1
            elif N_on == 1:
                bias_cov[hyp_ind, diag_on, diag_on] = bias_prior_std[np.array(diag_on)]**2
                hyp_ind += 1
            else:
                for k in range(2, N_on + 1):
                    combos = combinations(diag_on, k)
                    for combo in combos:
                        all_pairs = combinations(combo, 2)
                        print(diag_on)
                        print(list(combo))
                        print(f"subcombo: {list(all_pairs)}")

In [None]:
def partitions(n, I=1):
    """Yield the integer partitions of ``n`` whose parts are all >= ``I``, as non-decreasing tuples."""
    yield (n,)
    for smallest in range(I, n // 2 + 1):
        yield from ((smallest,) + rest for rest in partitions(n - smallest, smallest))

def partition(number):
    """Return the set of all compositions (ordered partitions) of ``number``."""
    compositions = {(number,)}
    for head in range(1, number):
        compositions.update((head,) + tail for tail in partition(number - head))
    return compositions

In [None]:
# Show each composition of 3 on its own line.
print(*partition(3), sep="\n")

In [None]:
# Scratch: walk the integer partitions of 3; only the last partition's block count survives the loop.
for part in partitions(3):
    num_blocks = len(part)
    

In [None]:
# NOTE: this rebinds `block_diag` to the sparse version (which takes a single list of blocks),
# shadowing scipy.linalg.block_diag imported at the top of the notebook.
from scipy.sparse import block_diag
size = 3
matr_set = set()
parts = partitions(size)
matr_list = []
print(f"parts: {parts}")
# For each integer partition of `size`, switch each diagonal block fully on (all ones) or off
# (all zeros) in every combination, and collect the resulting dense matrices.
for part in parts:
    num_blocks = len(part)
    for dec in range(2**num_blocks):
        on_flags = bin(dec)[2:].zfill(num_blocks)
        blocks = [np.full([block_size, block_size], int(flag)) for block_size, flag in zip(part, on_flags)]
        matr_list.append(block_diag(blocks).toarray())

cov = np.unique(matr_list, axis=0)
print(cov)
        

In [None]:
# np.unique on a Python set: NumPy presumably wraps the set as a single object element, so this
# loop likely prints the whole set once rather than its members — verify.
matr_set = set(((1, 2, 3), (4, 2,3)))
for item in np.unique(matr_set):
    print(item)

In [None]:
# partition(n) returns all compositions of n, of which there are 2**(n - 1), so this prints 512.
print(len(partition(10)))

In [None]:
# ndarrays are unhashable, so they cannot be set members; this cell used to raise TypeError.
# Show the error and store a hashable tuple-of-tuples version of the array instead.
zeros_2x2 = np.zeros([2, 2])
try:
    st = set((zeros_2x2,))
except TypeError as err:
    print(f"Cannot put an ndarray in a set: {err}")
    st = {tuple(map(tuple, zeros_2x2))}

In [None]:
# Add the element 2 to the set in place.
st |= {2}
print(st)

In [None]:
# Choosing 4 items from a 2-element list gives no combinations (r > n), so nothing is printed.
for item in combinations([0, 1], 4):
    print(item)

In [None]:
# Use the sparse block_diag explicitly instead of relying on an earlier cell having shadowed
# scipy.linalg.block_diag with it.
from scipy.sparse import block_diag as sparse_block_diag

sparse_demo = sparse_block_diag([np.array([[0]]), np.array([[1]])]).toarray()
sparse_demo

In [None]:
# Scratch: for every collection of at least two edges (pairs of nodes 0, 2, 4, 6), fold each
# edge's node set against the others by symmetric difference.
for item in powerset(combinations([0, 2, 4, 6], 2)):
    if len(item) > 1:
        print(f"item: {item}")
        for it in item:
            st = set(it)
            for it2 in item:
                diff = st.symmetric_difference(set(it2))
                # Was `st.update_diff`, an AttributeError. Assumed intent: the in-place symmetric
                # difference explored in the next cell — confirm.
                st.symmetric_difference_update(set(it2))
            

In [None]:
# In-place symmetric difference: {0, 2} ^ {0, 4} keeps only the elements in exactly one set.
st = {0, 2}
st ^= {0, 4}
print(st)

For N diagonals that are on, find the adjacency matrices for complete graphs of each partition

Do you just shift the partitions diagonally?

There's something very important about **conjugate partitions**

On/off, corr/uncorr

"Write down adjacency matrices for all possible cluster graphs of N points"

In [None]:
# All 16 subsets of {0, 1, 2, 3}.
print([*powerset(range(4))])

In [None]:
from scipy.special import comb
# Number of ways to choose 2 of 4 (= 6.0; scipy's comb returns a float by default).
print(comb(4, 2))

# Some code I would like preserved



In [None]:
from more_itertools import set_partitions

In [None]:
# Enumerate all set partitions of {0, 1, 2, 3} (the Bell number B_4 = 15 of them).
parts = set_partitions(np.arange(4))
for part in parts:
    print(part)

In [None]:
from itertools import combinations
# A one-element list has no 2-element combinations, so this evaluates to [].
list(combinations([0], 2))

In [None]:
from scipy.special import comb
def get_num_hyp(N):
    """Return the Bell numbers B_0 through B_{N+1} as a float array of length N + 2.

    Fun fact: these are called Bell numbers. The final entry, B_{N+1}, counts the ways of
    partitioning a set with N + 1 elements, which is the hypothesis count wanted for
    N bandpowers. Note that the whole sequence is returned, not just that entry.
    """
    top = N + 1
    bell = np.zeros(top + 1)
    bell[0] = 1  # recurrence seed: B_0 = 1
    for n in range(top):
        # Bell recurrence: B_{n+1} = sum_k C(n, k) B_k
        bell[n + 1] = sum(comb(n, k, exact=True) * bell[k] for k in range(n + 1))
    return bell

In [None]:
# Prints the Bell numbers B_0..B_6 (get_num_hyp returns the whole sequence).
print(get_num_hyp(5))

In [None]:
# Fully correlated 2x2 covariance plus a small diagonal; probe it along the (1, 1)/sqrt(2) direction.
M = 0.1 * np.eye(2) + np.ones([2, 2])
inv = np.linalg.inv(M)
x = np.full(2, 1 / np.sqrt(2))
print(inv, np.linalg.eig(inv), sep="\n")
print(x @ inv @ x)

In [None]:
# Uncorrelated counterpart (identity plus the same small diagonal), probed along the same direction.
M = 0.1 * np.eye(2) + np.array([[1, 0], [0, 1]])
inv = np.linalg.inv(M)
x = np.full(2, 1 / np.sqrt(2))
print(inv, np.linalg.eig(inv), sep="\n")
print(x @ inv @ x)

In [None]:
# One draw of 100 trials over three nearly equal categories (unseeded, so it changes each run).
np.random.multinomial(100, pvals=[0.33, 0.34, 0.33], size=1)

In [None]:
# Reshape check: with the default C order, each row holds consecutive pairs (0 1, 2 3, ...).
np.arange(10).reshape((5, 2))

In [None]:
N = 16

# One dominant outcome (weight 10) among N.
weights = np.ones(N)
weights[0] = 10
P = weights / weights.sum()
print(P)
# Entropy normalised by its maximum, log(N): 1 would mean uniform.
-P@np.log(P) / np.log(N)

In [None]:
# Sherman-Morrison: 1^T (I + a 1 1^T)^{-1} 1 = n / (1 + n a) = 3/301 for n = 3, a = 100.
np.ones(3) @ np.linalg.inv(np.eye(3) + 100 * np.ones([3, 3])) @ np.ones(3)

In [None]:
# Uncorrelated case: 1^T (101 I)^{-1} 1 = 3/101.
np.ones(3) @ np.linalg.inv(np.eye(3) + 100 * np.eye(3)) @ np.ones(3)

In [None]:
x = np.random.normal(size=2)
y = np.random.normal(loc=[5, 0], size=2)
sig_p = 10
mu_p = np.array([0, 0])
# Conjugate Gaussian update with unit-variance likelihood and N(mu_p, sig_p^2) prior.
post_precision = 1 + 1 / sig_p**2
sig_post = np.sqrt(1 / post_precision)

mu = sig_post**2 * (y + mu_p / sig_p**2)
print(mu, y, sig_post, sep="\n")

In [None]:
# sig_like was never defined in this notebook (NameError on a fresh kernel). The Gaussian
# posterior cell uses a unit likelihood variance (the `1 +` term in sig_post), so the
# likelihood std is assumed to be 1 here — confirm.
sig_like = 1
print(sig_like)
print(1/ (1 + 1/1e4))

In [None]:
# Inverse of a 2x2 correlation matrix with rho = 0.9.
np.linalg.inv([[1, 0.9], [0.9, 1]])

In [None]:
# Close the current matplotlib figure.
plt.close()

In [None]:
# Monte-Carlo mutual information (bits) between a 4-dim datum and a fair binary label:
# label 0 -> N(0, std1^2 I), label 1 -> N(1, (std1^2 + std2^2) I), scanned over std2.
# std_use is a standard deviation, so every covariance uses std_use**2 (it previously passed
# std_use itself as the variance, and std1 likewise).
mean1 = np.zeros(4)
std1 = 1
mean2 = np.repeat(1, 4)
I = []
stds = np.logspace(-2, 2, num=100)
num_draw = int(1e4)
eye4 = np.eye(4)
comp1 = multivariate_normal(mean=mean1, cov=std1**2 * eye4)
for std2 in stds:
    std_use = np.sqrt(std1**2 + std2**2)
    comp2 = multivariate_normal(mean=mean2, cov=std_use**2 * eye4)
    dat1 = np.random.multivariate_normal(mean=mean1, cov=std1**2 * eye4, size=num_draw)
    dat2 = np.random.multivariate_normal(mean=mean2, cov=std_use**2 * eye4, size=num_draw)
    trial = np.random.multinomial(1, [0.5, 0.5], size=num_draw).argmax(axis=1)
    dat = np.choose(trial, [dat1.T, dat2.T]).T

    # Mixture density at each datum -> Monte-Carlo estimate of H(D) in bits.
    P = 0.5 * (comp1.pdf(dat) + comp2.pdf(dat))
    Hd = -np.log2(P).mean()

    # H(D | label): average component entropy, converted from nats to bits.
    Hcond = 0.5 * (comp1.entropy() + comp2.entropy()) / np.log(2)

    I.append(Hd - Hcond)

plt.plot(stds, I)
plt.xscale("log")

In [None]:
from scipy.spatial.distance import jensenshannon as js

# Jensen-Shannon distance (base 2) between N(0, 1) and N(10, scale^2) sampled on [-10, 10].
x = np.linspace(-10, 10, num=100)
y1 = norm.pdf(x)
scales = np.logspace(-2, 2, base=10)
JS = [js(y1, norm(loc=10, scale=scale).pdf(x), base=2) for scale in scales]

# Squared distance = JS divergence.
plt.plot(scales, np.array(JS)**2)
plt.xscale("log")



In [None]:
from scipy.spatial.distance import jensenshannon as js

x = np.linspace(-10, 10, num=100)
X, Y, Z = np.meshgrid(x, x, x)
R = np.array([X.flatten(), Y.flatten(), Z.flatten()]).T
y1 = multivariate_normal(mean=np.zeros(3)).pdf(R)
scales = np.logspace(-2, 2, base=10)
JS = []
for scale in scales:
    y2 = multivariate_normal(mean=scale*np.ones(3), cov=6*np.eye(3)).pdf(R)
    HdA = 0.5 * multivariate_normal(mean=np.zeros(3)).entropy() + 0.5 * multivariate_normal(mean=scale*np.ones(3), cov=6*np.eye(3)).entropy()
    Hd = 
    JS.append(js(y1, y2, base=2))





In [None]:
# Plot JS divergence (squared distance) against the shift scale.
plt.plot(scales, np.square(JS))
plt.xscale("log")

In [None]:
num_pow = 4
num_draw = int(1e6)
# Per-draw biases with covariance 0.2 I on each of the four bandpowers.
bias = np.random.multivariate_normal(mean=np.zeros(num_pow), cov=0.2 * np.eye(num_pow), size=num_draw)
mean = 0
bp = bandpower(num_pow=num_pow, bias=bias, mean=mean, std=1)
jk = bias_jackknife(bp, bias_prior_mean=0, bias_prior_std=10, bp_prior_mean=0, bp_prior_std=0, mode="ternary",
                    hyp_prior=[0.5, 0.5, 0])
# Entropy (bits) of the first two posterior components; 0 * log2(0) gives NaN.
post_first_two = jk.post[:2]
H = (-post_first_two * np.log2(post_first_two)).sum(axis=0)

_, _, _ = plt.hist(H, bins=np.linspace(0, 1, num=100), histtype='step')
print(np.nanmean(H))

In [None]:
# Histogram of the posterior of hypothesis index 1 from the ternary jackknife `jk`.
_, _, _ = plt.hist(jk.post[1], bins=np.linspace(0, 1, num=100), histtype='step')

In [None]:
# Plain mean: unlike np.nanmean, any NaN entries of H (from 0 * log2(0)) propagate to the result.
print(H.mean())