In [None]:
# Data gen

from src import analysis
from phise import Context
import astropy.units as u
import numpy as np

def get_ctx() -> Context:
    """Build a fresh VLTI observation context configured for this study.

    Starts from ``Context.get_VLTI()`` and overrides:

    * the first companion's contrast ``c`` to 1e-2 (the companion flux used
      elsewhere in this notebook is ``c * target.f``);
    * the chip's ``σ`` to 14 zeros in nanometres;
      NOTE(review): presumably one entry per chip phase element, i.e. no
      perturbation — confirm against phise;
    * ``monochromatic`` to True.

    Returns
    -------
    Context
        A new context; callers are free to mutate it further.
    """
    context = Context.get_VLTI()
    context.target.companions[0].c = 1e-2
    chip = context.interferometer.chip
    chip.σ = np.zeros(14) * u.nm
    context.monochromatic = True
    return context

N = 10_000


def get_distrib(ctx, n=None) -> np.ndarray:
    """Draw repeated observations from ``ctx`` and collect their kernel outputs.

    Each sample calls ``ctx.observe()`` once and converts the result with
    ``ctx.interferometer.chip.process_outputs``; the processed value is stored
    as one column of the returned array.

    Parameters
    ----------
    ctx : Context
        Observation context to sample from.
    n : int, optional
        Number of samples. Defaults to the module-level ``N``, read at call
        time, so ``get_distrib(ctx)`` behaves exactly as before.

    Returns
    -------
    np.ndarray
        Array of shape ``(3, n)``; column ``i`` holds the 3 kernel values of
        sample ``i``.

    Raises
    ------
    ValueError
        If ``n`` is negative.
    """
    n_samples = N if n is None else int(n)
    if n_samples < 0:
        raise ValueError(f"n must be non-negative, got {n_samples}")

    dists = np.empty((3, n_samples))
    # Refresh the progress line ~100 times rather than on every sample:
    # thousands of carriage-return prints slow down and bloat notebook output.
    report_every = max(1, n_samples // 100)
    for i in range(n_samples):
        done = i + 1
        if i == 0 or done % report_every == 0 or done == n_samples:
            print(f"⌛ Sampling {done}/{n_samples} ({done/n_samples:.2%})", end='\r')
        # observe() returns raw intensities; process to get kernels
        outs = ctx.observe()
        dists[:, i] = ctx.interferometer.chip.process_outputs(outs)
    print("✅ Done" + " " * 30)
    return dists

# Full context ----------------------------------------------------------------

ctx_full = get_ctx()
dists_full = get_distrib(ctx_full)

# Star only -------------------------------------------------------------------

ctx_so = get_ctx()
ctx_so.target.companions = []
dists_so = get_distrib(ctx_so)

# Planet only -----------------------------------------------------------------

ctx_po = get_ctx()
print("Before scaling:")
print("    Star flux:", ctx_po.target.f)
print("    Companion flux:", ctx_po.target.companions[0].c * ctx_po.target.f)
scale = 1e12
ctx_po.target.f /= scale
ctx_po.target.companions[0].c *= scale
print("After scaling:")
print("    Star flux:", ctx_po.target.f)
print("    Companion flux:", ctx_po.target.companions[0].c * ctx_po.target.f)
dists_po = get_distrib(ctx_po)

# Sum of the separate star-only and planet-only draws: x = n_* + S_p + n_p
# (star noise + planet signal + planet noise).
dists_comb = dists_so + dists_po

# Star noise plus a fixed planet signal: x = n_* + S_p. The per-kernel median of
# the planet-only run (one value per row) stands in for S_p and is broadcast
# across all samples of that kernel.
planet_signal = np.median(dists_po, axis=1, keepdims=True)
dists_star_noise = dists_so + planet_signal

In [None]:
# Plot

import matplotlib.pyplot as plt

# Each dists_* array has one row per kernel (3 rows, see get_distrib).
N_KERNELS = 3
# Use 2*sqrt(samples) as number of bins
N_BINS = 2 * int(np.sqrt(N))
# Percentile window that sets the histogram range on each kernel.
PCT_LOW, PCT_HIGH = 5, 95

all_dists = [dists_so, dists_po, dists_full, dists_comb, dists_star_noise]

# One set of bin edges per kernel, spanning the union of the percentile windows
# of every distribution on that kernel. Percentiles are taken on row `k` only,
# because the kernels can have different spreads.
# NOTE(review): the range is seeded at 0, so it always contains zero — confirm
# this is desired.
kernel_bins = []
for k in range(N_KERNELS):
    kmin, kmax = 0.0, 0.0
    for dist in all_dists:
        lo, hi = np.percentile(dist[k], [PCT_LOW, PCT_HIGH])
        kmin = min(kmin, lo)
        kmax = max(kmax, hi)
    kernel_bins.append(np.linspace(kmin, kmax, N_BINS + 1))


def plot_kernel_hists(series):
    """Draw one log-scaled density histogram panel per kernel.

    Parameters
    ----------
    series : list of (dists, label, hist_kwargs)
        ``dists`` is a ``(N_KERNELS, n_samples)`` array, ``label`` its legend
        entry, and ``hist_kwargs`` extra keyword arguments for ``Axes.hist``.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure (one row of axes per kernel).
    """
    # squeeze=False keeps a 2-D axes array even if N_KERNELS == 1.
    fig, axs = plt.subplots(N_KERNELS, 1, figsize=(15, 8), tight_layout=True, squeeze=False)
    for i, ax in enumerate(axs[:, 0]):
        for dist, label, hist_kwargs in series:
            # Samples outside the bin range are not counted; with density=True
            # the histogram integrates to 1 over the plotted range only.
            ax.hist(dist[i], bins=kernel_bins[i], alpha=0.5, label=label,
                    density=True, log=True, **hist_kwargs)
        ax.set_title(f'Distribution on Kernel {i+1}')
        ax.set_xlabel('Intensity [photon events]')
        ax.set_ylabel('Probability density')
        ax.legend()
    return fig


STEP_STYLE = {'histtype': 'step', 'linewidth': 2}

# Assign the figures so the cell's last line does not echo a Figure repr.
fig_so = plot_kernel_hists([(dists_so, 'Star Only', {})])
fig_po = plot_kernel_hists([(dists_po, 'Planet Only', {})])
fig_cmp = plot_kernel_hists([
    (dists_full, 'Full model', {}),
    (dists_comb, r'$x = n_* + S_p + n_p$', STEP_STYLE),
    (dists_star_noise, r'$x = n_* + S_p$', STEP_STYLE),
])