In [0]:
# Parameters cell (presumably overridden by papermill when the notebook is executed
# for another run — TODO confirm). Paths are absolute and site-specific.

# When True, a later cell starts a debugpy server so an IDE can attach.
DEBUG = False
# pmap configuration file; the matching `.args` file is read from the same path
# with the extension replaced by `.args`.
CONFIG_FILE = '/datascope/subaru/data/targeting/dSph/draco/pmap/draco_nb/ga-pmap_20250313213620.config'
# Directory that holds the probability map (`pmap.h5` is loaded from here).
OUTPUT_PATH = '/datascope/subaru/data/targeting/dSph/draco/pmap/draco_nb'

# Plot the probability map

In [0]:
import os, sys
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import commentjson as json

In [0]:
# Small default font so the dense multi-panel figures below stay legible.
plt.rcParams['font.size'] = 6

In [0]:
# Re-import edited project modules automatically before each cell is executed,
# so changes to the pfs.ga.targeting sources are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [0]:
# Optionally start a debugpy server so an IDE can attach to this kernel. The `debug`
# global doubles as an "already listening" flag, so re-running the cell does not
# call listen() a second time.
# NOTE(review): binding to 0.0.0.0 accepts debugger connections on every network
# interface, and an attached debugger can execute arbitrary code; consider
# 127.0.0.1 plus an SSH tunnel if that fits the workflow.
if DEBUG:
    if 'debug' not in globals():
        import debugpy
        debugpy.listen(
            ('0.0.0.0', int(os.environ['PFS_TARGETING_DEBUGPORT']))
        )
        debug = True

# Imports

In [0]:
import pfs.utils

from pfs.ga.targeting.scripts.pmap.notebooks.notebooks import *
from pfs.ga.targeting.targets.dsph import GALAXIES as DSPH_FIELDS
from pfs.ga.targeting.targets.m31 import M31_FIELDS
from pfs.ga.targeting import ProbabilityMap

# Load the pmap config and args file

In [0]:
# Load the pmap configuration from CONFIG_FILE. `load_pmap_config` presumably comes
# from the star import of pfs.ga.targeting.scripts.pmap.notebooks.notebooks.
config = load_pmap_config(CONFIG_FILE)

In [0]:
# The run's arguments live next to the config file: same path, `.args` extension.
# They are parsed with commentjson (imported as `json`), which tolerates comments.
args_file = '{}.args'.format(os.path.splitext(CONFIG_FILE)[0])
with open(args_file) as args_fp:
    args = json.load(args_fp)

In [0]:
# Display the parsed contents of the `.args` file.
args

# Plot definitions

In [0]:
from pfs.ga.targeting.targets.dsph import GALAXIES as DSPH_FIELDS
from pfs.ga.targeting.targets.m31 import M31_FIELDS
from pfs.ga.targeting.instrument import *
from pfs.ga.targeting.diagram import CMD, CCD, FOV, FP, ColorAxis, MagnitudeAxis
from pfs.ga.targeting.photometry import Photometry, Magnitude, Color
from pfs.ga.targeting.projection import Pointing, WcsProjection

In [0]:
# Resolve the target field from the arguments saved with the pmap run, then build
# the photometric system and the color-magnitude / color-color diagrams for it.
# NOTE(review): M31_FIELDS is imported above but not handled here; only dSph
# targets are supported by this cell.
if args.get('dsph') is not None:
    field = DSPH_FIELDS[args['dsph']]
else:
    # Without an explicit error, `field` would be unbound and the next line would
    # fail with an unhelpful NameError.
    raise ValueError(f"No 'dsph' target found in {args_file}; cannot select a field.")

hsc = field.get_photometry()
cmd = field.get_cmd()
ccd = field.get_ccd()

In [0]:
# Field center; used in the next cell as the pointing for the TAN WCS projection
# and the SubaruWFC model.
pointing = field.get_center()
pointing

In [0]:
# Projections centered on the field. `fov` (built on the TAN `wcs` projection) is
# drawn on WCS axes in the ghost plots below; `fp` wraps the SubaruWFC model
# (presumably the PFS focal plane — it is not used in the cells shown here).
wcs = WcsProjection(pointing, proj='TAN')
wfc = SubaruWFC(pointing)
fov = FOV(projection=wcs)
fp = FP(wfc)

# Load the simulation

In [0]:
# Load the simulation referenced by the pmap config; `sim.data` holds named arrays
# (their shapes are listed in the next cell).
sim = load_simulation(config)

In [0]:
# List every array stored in the simulation together with its shape.
for array_name in sim.data.keys():
    print(f'{array_name} {sim.data[array_name].shape}')

# Load observations

In [0]:
# Load the observed catalog for this field as configured in the pmap config.
obs = load_observations(field, config)

In [0]:
# Axes of the color-magnitude diagram; the histogram cells below use axes[0] and
# axes[1], and the probability map is built on them.
cmd.axes

# Plot the observations and the simulation

In [0]:
# Observed catalog in the CMD (left column) and CCD (right column). The bottom row
# repeats the catalog and overplots the stars that pass the selection cuts in red.
mask = field.get_selection_mask(
    obs,
    observed=True,
    nb=config.cut_nb,
    blue=config.keep_blue,
    probcut=None,
)

f, axs = plt.subplots(2, 2, figsize=(6, 8), dpi=120)

for column, diagram in enumerate((cmd, ccd)):
    diagram.plot_catalog(axs[0, column], obs, observed=True)

for column, diagram in enumerate((cmd, ccd)):
    diagram.plot_catalog(axs[1, column], obs, observed=True)
    diagram.plot_catalog(axs[1, column], obs, observed=True, mask=mask, color='red')

f.tight_layout()

In [0]:
# Same layout for the simulation, thinned to every 10th star to keep the plot light.
s = np.s_[::10]

mask = field.get_selection_mask(sim, observed=True, nb=config.cut_nb, blue=config.keep_blue)

f, axs = plt.subplots(2, 2, figsize=(6, 8), dpi=120)

# Options shared by every panel.
sim_plot_kw = dict(observed=True, s=s, size=0.05)

for column, diagram in enumerate((cmd, ccd)):
    diagram.plot_simulation(axs[0, column], sim, **sim_plot_kw)

for column, diagram in enumerate((cmd, ccd)):
    diagram.plot_simulation(axs[1, column], sim, **sim_plot_kw)
    diagram.plot_simulation(axs[1, column], sim, **sim_plot_kw, mask=mask, color='red')

f.tight_layout()

# Update the population weights

Here we manually rescale the weights of the Galaxia Milky Way (MW) foreground populations until the number of simulated stars inside the selection cuts matches the observations.

In [0]:
# Population weights requested in the config; when None, the next cell skips the
# reweighting step.
config.population_weights

In [0]:
# Recompute the simulation's population weights (`w1`) and category assignment
# (`g1`) by calling PMapScript's private, name-mangled __update_weights method on
# an instance that only has its config set. Both names are used by later cells.
# The instance gets its own name so it does not clobber `s`, which the plotting
# cells use for the thinning slice.
if config.population_weights is not None:
    pmap_script = PMapScript()
    pmap_script._config = config
    w1, g1 = pmap_script._PMapScript__update_weights(sim)
else:
    # No reweighting configured: keep the simulation's original categories so the
    # comparison cells below still run; there are no updated weights to report.
    w1, g1 = None, sim.data['g']

In [0]:
# Number of objects inside the selection cuts: the observations vs. the simulation
# with its original (`sim.data['g']`) and its updated (`g1`) category assignment.
# The final ratio compares the updated simulation to the observations.
mask = field.get_selection_mask(obs, nb=config.cut_nb, blue=config.keep_blue, observed=True)
n_obs = mask.sum()
print('obs', n_obs)

mask = field.get_selection_mask(sim, nb=config.cut_nb, blue=config.keep_blue, observed=True)
mask = sim.apply_categories(mask, g=sim.data['g'])
n_sim_orig = mask.sum()
print('sim (original weights)', n_sim_orig)

mask = field.get_selection_mask(sim, nb=config.cut_nb, blue=config.keep_blue, observed=True)
mask = sim.apply_categories(mask, g=g1)
n_sim = mask.sum()
print('sim (updated weights)', n_sim)

n_sim / n_obs

In [0]:
# Side-by-side CMD (top row) and CCD (bottom row) of the stars inside the cuts:
# observations, simulation with updated categories, simulation with original ones.
f, axs = plt.subplots(2, 3, figsize=(6, 6), dpi=120)

# Column 0: every observed star.
s = np.s_[::1]
mask = field.get_selection_mask(obs, nb=config.cut_nb, blue=config.keep_blue, observed=True)
for row, diagram in enumerate((cmd, ccd)):
    diagram.plot_observation(axs[row, 0], obs, size=0.05, mask=mask, s=s)
axs[0, 0].set_title('OBS')

# Columns 1-2: every 3rd simulated star, with updated and original categories.
s = np.s_[::3]
sim_variants = (
    (g1, 'SIM updated weights'),
    (sim.data['g'], 'SIM original weights'),
)
for column, (categories, title) in enumerate(sim_variants, start=1):
    mask = field.get_selection_mask(sim, nb=config.cut_nb, blue=config.keep_blue, observed=True)
    for row, diagram in enumerate((cmd, ccd)):
        diagram.plot_simulation(
            axs[row, column], sim,
            observed=True, apply_categories=True,
            mask=mask, g=categories, s=s, size=0.05,
        )
    axs[0, column].set_title(title)

for ax in axs.flat:
    ax.grid()
    ax.set_xlim(-1, 2.2)

f.tight_layout()

# Plot color histograms

In [0]:
def plot_histogram(ax, obs, sim, axis, bins, plot_populations=True, g=None, w=None):
    """Overplot normalized histograms of observed and simulated stars along one axis.

    Only stars that pass ``field.get_selection_mask`` with the config's ``cut_nb`` and
    ``keep_blue`` settings are histogrammed. Relies on the notebook globals ``field``
    and ``config``, and on ``g1``/``w1`` when ``g``/``w`` are not given.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    obs, sim
        Observed catalog and simulation; values are read with
        ``get_diagram_values([axis], observed=True)``.
    axis
        Diagram axis to histogram, e.g. ``cmd.axes[0]``.
    bins
        Passed to ``np.histogram`` for the observations. The resulting edges are
        reused for every other curve, so all curves share the same binning.
    plot_populations : bool
        Also draw one curve per entry of ``config.population_names``, scaled by the
        summed weights of its categories ``2*i`` and ``2*i + 1``.
    g
        Category assignment of the simulated stars; defaults to the global ``g1``.
    w
        Population weights; defaults to the global ``w1``. When it is None, the
        per-population curves are skipped.
    """
    if g is None:
        g = g1
    if w is None:
        w = w1

    # Observations inside the cuts.
    ((x, _x_err),) = obs.get_diagram_values([axis], observed=True)
    mask = field.get_selection_mask(obs, nb=config.cut_nb, blue=config.keep_blue, observed=True)
    hist, bins = np.histogram(x[mask], bins=bins, density=True)
    # Histogram values belong to whole bins: draw steps centred on the bin centres
    # (where='mid'). The default where='pre' would shift every curve left by half a bin.
    centers = 0.5 * (bins[1:] + bins[:-1])
    ax.step(centers, hist, where='mid', lw=1, label='OBS')

    # Simulation inside the same cuts, with `sim.apply_categories` applied to both
    # the mask and the values using the category assignment `g`.
    ((x, _x_err),) = sim.get_diagram_values([axis], observed=True)
    mask = field.get_selection_mask(sim, nb=config.cut_nb, blue=config.keep_blue, observed=True)
    mask = sim.apply_categories(mask, g=g)
    x = sim.apply_categories(x, g=g)
    hist, _ = np.histogram(x[mask], bins=bins, density=True)
    ax.step(centers, hist, where='mid', lw=1, label='SIM')

    if not plot_populations or w is None:
        return

    x_sel = x[mask]
    # NOTE(review): assumes the first column of the 2D `mask` selects the same stars
    # as the flattened `x[mask]` (kept from the original indexing).
    g_sel = g[mask[:, 0]]
    for i, name in enumerate(config.population_names):
        # TODO: assumes every population owns two categories, 2*i and 2*i + 1
        # (presumably single stars and binaries); revisit if binaries are absent.
        in_pop = (g_sel == 2 * i) | (g_sel == 2 * i + 1)
        if not np.any(in_pop):
            # With density=True an empty selection gives 0/0, i.e. an all-NaN curve.
            continue
        hist, _ = np.histogram(x_sel[in_pop], bins=bins, density=True)
        ax.step(centers, (w[2 * i] + w[2 * i + 1]) * hist, where='mid', lw=0.5, label=name)

In [0]:
# Normalized distribution along the first CMD axis: OBS, SIM and each SIM population.
f, ax = plt.subplots(figsize=(3.5, 2.4), dpi=240)

plot_histogram(ax, obs, sim, cmd.axes[0], bins=np.linspace(-1.0, 2.0, 100))

ax.set(xlim=(-1, 2.2), xlabel=cmd.axes[0].label)
ax.legend()

In [0]:
# Normalized distribution along the second CMD axis: OBS, SIM and each SIM population.
f, ax = plt.subplots(figsize=(3.5, 2.4), dpi=240)

plot_histogram(ax, obs, sim, cmd.axes[1], bins=np.linspace(16, 23, 100))

ax.set(xlabel=cmd.axes[1].label)
ax.legend()

In [0]:
# Load and plot the probability map

In [0]:
# Build an empty probability map on the CMD axes and fill it from the `pmap.h5`
# file stored in OUTPUT_PATH.
pmap = ProbabilityMap(cmd.axes)
fn = os.path.join(OUTPUT_PATH, 'pmap.h5')
pmap.load(fn)

In [0]:
# Two panels of the probability map on the CMD; the integer argument (0 and 1)
# presumably selects the map component to draw.
f, axs = plt.subplots(1, 2, figsize=(6, 4), dpi=120)

l0, l1 = [cmd.plot_probability_map(panel, pmap, index) for index, panel in enumerate(axs)]

f.tight_layout()

# Membership probability based on the map

In [0]:
# Look up each observed star in the probability map; `lp_member` is presumably the
# log membership probability (TODO confirm against ProbabilityMap.lookup_lp_member).
lp_member, mask_member = pmap.lookup_lp_member(obs)

# Sanity check: shape, NaN count overall and inside `mask_member`, mask shape, and
# how many stars `mask_member` selects.
lp_member.shape, np.isnan(lp_member).sum(), np.isnan(lp_member[mask_member]).sum(), mask_member.shape, mask_member.sum()

In [0]:
# Every observed star in the CMD (left) and CCD (right), colored by lp_member[..., 0].
f, axs = plt.subplots(1, 2, figsize=(6, 4), dpi=120)

for ax, diagram in zip(axs, (cmd, ccd)):
    diagram.plot_observation(ax, obs, c=lp_member[..., 0])

f.tight_layout()

In [0]:
# Same coloring, restricted to stars inside the selection cuts. `c` is pre-masked
# (presumably because plot_observation applies `mask` to the catalog only).
mask = field.get_selection_mask(obs, nb=config.cut_nb, blue=config.keep_blue, observed=True)

f, axs = plt.subplots(1, 2, figsize=(6, 4), dpi=120)

for ax, diagram in zip(axs, (cmd, ccd)):
    diagram.plot_observation(ax, obs, c=lp_member[..., 0][mask], mask=mask)

f.tight_layout()

# Ghost plots

In [0]:
# Draw a random mask over the observations from the probability map; its
# last-axis index 1 is used by the ghost plots to drop stars from the selection.
ghost_mask = pmap.create_random_mask(obs)

ghost_mask.shape

In [0]:
# Ghost plots: the selected stars (left column) vs. the selection with the stars
# flagged in ghost_mask[..., 1] removed (right column), in the CMD (top) and on the
# sky with WCS axes (bottom).
mask = field.get_selection_mask(obs, observed=True, nb=config.cut_nb, blue=config.keep_blue, probcut=None)
ghost_kept = mask & ~ghost_mask[..., 1]

f = plt.figure(figsize=(4, 4), dpi=240)
gs = f.add_gridspec(2, 2, width_ratios=[1, 1], height_ratios=[1, 1], wspace=0.35, hspace=0.35)

# projection=None gives ordinary rectilinear axes, as add_subplot does by default.
for row, diagram, projection in ((0, cmd, None), (1, fov, wcs.wcs)):
    ax = f.add_subplot(gs[row, 0], projection=projection)
    diagram.plot_observation(ax, obs, color='gray', mask=mask)

    ax = f.add_subplot(gs[row, 1], projection=projection)
    diagram.plot_observation(ax, obs, mask=ghost_kept, color='gray')
