In [None]:
import dataclasses
from mpi4py import MPI
import numpy as np
import jax.numpy as jnp
import jax

# Diffskyopt Tutorial: Fitting to COSMOS + Extending to New Datasets

This tutorial demonstrates how to use `diffopt` (kdescent + multigrad) to fit `diffsky` to COSMOS data, and shows how to extend the fit to new datasets such as BGS. To reduce clutter and keep the focus on the important components, the code blocks are not directly executable (e.g. `...` marks omitted arguments and helpers). As an exercise, you can fill in the missing pieces and verify that everything executes as expected.

---

## 1. Generating and Distributing Diffsky Halos with MPI


In [None]:
from diffskyopt.diffsky_model import generate_weighted_grid_lc_data, lc_data_slice

# Number of MPI processes and this process's index in the global communicator
SIZE = MPI.COMM_WORLD.size
RANK = MPI.COMM_WORLD.rank

# Only the root's lists matter for scatter; other ranks pass None.
lc_data_slices = None
halo_upweights_slices = None
if RANK == 0:
    # Rank 0 alone builds the full weighted lightcone grid
    # (TODO: generate dithered grids on each rank)
    full_lc_data, full_halo_upweights = generate_weighted_grid_lc_data(
        z_min=0.4, z_max=2.0, num_z_grid=100,
        lgmp_min=10.5, lgmp_max=15.0, num_m_grid=100,
        sky_area_degsq=1.21, ran_key=jax.random.key(0),
    )
    # One contiguous, near-equal chunk of halo indices per rank
    indices = np.array_split(np.arange(full_lc_data.z_obs.size), SIZE)
    lc_data_slices = []
    halo_upweights_slices = []
    for rank_indices in indices:
        lc_data_slices.append(lc_data_slice(full_lc_data, rank_indices))
        halo_upweights_slices.append(full_halo_upweights[rank_indices])

# scatter hands element i of the root's list to rank i
lc_data = MPI.COMM_WORLD.scatter(lc_data_slices, root=0)
halo_upweights = MPI.COMM_WORLD.scatter(halo_upweights_slices, root=0)

---

## 2. Defining the COSMOS Loss Function

The core fitting logic is encapsulated in a class, e.g. [`CosmosFit`](../diffskyopt/lossfuncs/cosmos_fit.py). This class loads the data, prepares targets, and defines the loss function using kdescent and multigrad.

In [None]:
from diffopt import kdescent, multigrad
from cosmos20_colors import load_cosmos20
from diffsky.param_utils import diffsky_param_wrapper as dpw

class CosmosFit:
    def __init__(self, ...):
        # Load and mask COSMOS data
        cat = load_cosmos20()
        # Mask out NaNs and select redshift range
        # Prepare data_targets and data_weights
        self.data_targets, self.data_weights = self._prepare_data_targets_and_weights(cat)

        # Distribute model halos as above
        # self.lc_data, self.halo_upweights = ... (see previous section)

        # Pretrain kdescent kernels for KDE and Fourier loss terms
        ktrain = kdescent.KPretrainer.from_training_data(
            self.data_targets, self.data_weights,
            num_eval_kernels=40, num_eval_fourier_positions=20,
            comm=MPI.COMM_WORLD
        )
        self.kcalc = kdescent.KCalc(ktrain)

    def targets_and_weights_from_params(self, params, randkey):
        # Each rank computes photometry for its slice
        targets, weights = compute_targets_and_weights(
            params, self.lc_data,
            ran_key=jax.random.split(randkey, SIZE)[RANK],
            weights=self.halo_upweights
        )
        return targets, weights

    def sumstats_from_params(self, params, randkey):
        keys = jax.random.split(randkey, 3)
        model_targets, model_weights = self.targets_and_weights_from_params(
            params, keys[0])
        model_k, data_k, err_k = self.kcalc.compare_kde_counts(
            keys[1], model_targets, model_weights, return_err=True)
        model_f, data_f, err_f = self.kcalc.compare_fourier_counts(
            keys[2], model_targets, model_weights, return_err=True)
        sumstats = jnp.concatenate([model_k, model_f])
        sumstats_aux = jnp.concatenate([data_k, err_k, data_f, err_f])
        return sumstats, sumstats_aux

    def loss_from_sumstats(self, sumstats, sumstats_aux):
        # Compute reduced chi^2 or similar metric
        normalized_residuals = (sumstats - sumstats_aux[:len(sumstats)]) / sumstats_aux[len(sumstats):]
        return jnp.mean(normalized_residuals**2)

    def get_multi_grad_calc(self):
        return self.MultiGradModel(aux_data=dict(fit_instance=self))

    @dataclasses.dataclass
    class MultiGradModel(multigrad.OnePointModel):
        aux_data: dict
        sumstats_func_has_aux: bool = True

        def calc_partial_sumstats_from_params(self, params, randkey):
            fit_instance = self.aux_data["fit_instance"]
            return fit_instance.sumstats_from_params(params, randkey)

        def calc_loss_from_sumstats(self, sumstats, sumstats_aux, randkey=None):
            fit_instance = self.aux_data["fit_instance"]
            return fit_instance.loss_from_sumstats(sumstats, sumstats_aux)

---

## 3. Fitting the Model with Adam Optimization

To support parallelization, use the multigrad optimizer to minimize the loss function. This is typically run as an MPI job (see the `fit.pbs` job script, located at `/home/apearl/jobs/diffskyopt/cosmos_fit/fit.pbs` on the author's system), but in principle you can also run it on a single rank in this notebook.

In [None]:
cosmos_fit = CosmosFit(...)
calc = cosmos_fit.get_multi_grad_calc()
key = jax.random.key(1)
params, losses = calc.run_adam(
    guess=cosmos_fit.default_u_param_arr,
    nsteps=1000, learning_rate=0.1,
    randkey=key, progress=True
)
# Save results for validation. Under MPI every rank runs this cell, so only
# rank 0 writes, avoiding concurrent writes to the same file.
# (Assumes run_adam returns the same trajectory on every rank.)
if RANK == 0:
    np.savez("cosmos_fit_results.npz", params=np.asarray(params), losses=np.asarray(losses))

---

## 4. Validating the Fit

After fitting, generate diagnostic plots using the validation scripts.

In [None]:
from diffskyopt.scripts import cosmos_validation

# Diagnostics are evaluated at the parameters from the final Adam step
final_params = params[-1]
for plot_func in (cosmos_validation.n_of_z_plot,
                  cosmos_validation.n_of_ithresh_plot):
    plot_func(cosmos_fit, model_params=final_params)
cosmos_validation.smhm_drift_plot(model_params=final_params)

---

## 5. Extending to More Datasets: BGS Example

Next, let's fit COSMOS and BGS data simultaneously. Create a new class `BGSFit` with the same structure as `CosmosFit`, but load the BGS data with a placeholder `load_bgs` function (not defined here; you will need to supply it).

In [None]:
class BGSFit:
    def __init__(self, ...):
        # Load and mask BGS data
        cat = load_bgs()
        self.data_targets, self.data_weights = self._prepare_data_targets_and_weights(cat)
        # Distribute model halos as above
        # Pretrain kdescent kernels
        ktrain = kdescent.KPretrainer.from_training_data(
            self.data_targets, self.data_weights,
            num_eval_kernels=40, num_eval_fourier_positions=20,
            comm=MPI.COMM_WORLD
        )
        self.kcalc = kdescent.KCalc(ktrain)
        # ...

    def _prepare_data_targets_and_weights(self, cat):
        # Return data_targets and data_weights in the same format as CosmosFit
        ...

    # All other methods can be inherited from CosmosFit

---

## 6. Joint Fitting: Combining Losses

To fit both datasets simultaneously, simply sum their loss terms.

In [None]:
cosmos_fit = CosmosFit(...)
bgs_fit = BGSFit(...)


def combined_loss(params, randkey):
    """Return the sum of the COSMOS and BGS losses at ``params``.

    ``randkey`` is split into one sub-key per dataset so the two stochastic
    loss estimates do not reuse the same random stream.
    """
    cosmos_key, bgs_key = jax.random.split(randkey)
    cosmos_sumstats, cosmos_aux = cosmos_fit.sumstats_from_params(
        params, cosmos_key)
    bgs_sumstats, bgs_aux = bgs_fit.sumstats_from_params(params, bgs_key)
    loss_cosmos = cosmos_fit.loss_from_sumstats(cosmos_sumstats, cosmos_aux)
    loss_bgs = bgs_fit.loss_from_sumstats(bgs_sumstats, bgs_aux)
    return loss_cosmos + loss_bgs


In [None]:
# In parallel, we will need to make use of MultiGrad
cosmos_multigrad = cosmos_fit.get_multi_grad_calc()
bgs_multigrad = bgs_fit.get_multi_grad_calc()

# experimental
multigroup = multigrad.OnePointGroup(
    (cosmos_multigrad, bgs_multigrad))

# computes the sum of the loss of each multigrad model in the group.
# randkey is passed explicitly because sumstats_from_params splits it.
initial_loss_and_grad = multigroup.calc_loss_and_grad_from_params(
    cosmos_fit.default_u_param_arr, randkey=key)

# run gradient descent; keep the results under new names so the COSMOS-only
# `params`/`losses` from section 3 are not overwritten.
# (Assumes OnePointGroup.run_adam returns (params, losses) like
# OnePointModel.run_adam.)
joint_params, joint_losses = multigroup.run_adam(
    guess=cosmos_fit.default_u_param_arr,
    nsteps=1000, learning_rate=0.1,
    randkey=key, progress=True
)