In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import multiprocessing as mp

os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={mp.cpu_count()}"

In [3]:
import jax
import jax.numpy as jnp
import jax.random as jr

from cohlib.jax.dists import sample_from_gamma, sample_obs, sample_ccn_rank1
from cohlib.jax.observations import add0
from cohlib.jax.simtools import load_gamma, construct_gamma_init

from cohlib.jax.models import ToyModel
from cohlib.utils import pickle_open

In [4]:
jax.config.update('jax_platform_name', 'cpu')
platform = jax.lib.xla_bridge.get_backend().platform.casefold()
print("Platform: ", platform)
print(len(jax.devices()))

Platform:  cpu
28


In [5]:
res = pickle_open('/projectnb/stephenlab/jtauber/cohlib/hydra/gaussian_obs_postfix/batch/outputs/2024-11-25-ch4prelim/12-29-46/res.pickle')
cfg = res['cfg']

In [12]:
lcfg = cfg.latent
ocfg = cfg.obs
mcfg = cfg.model

# NOTE
ocfg.ov2 = -1
lcfg.L = 7

num_devices = len(jax.devices())
print(f"NUM_DEVICES={num_devices}")

gamma_load = load_gamma(cfg)

gamma_full = gamma_load['gamma']
K = gamma_full.shape[-1]
freqs = gamma_load['freqs']
nz_true = gamma_load['nonzero_inds']
nz_target = gamma_load['target_inds']
eigvec = gamma_load['eigvec']
eigval = gamma_load['eigval']


# sample latent and observations according to gamma and observation distribution
print(f"Sampling {lcfg.L} samples from gamma {lcfg.gamma}; seed = {lcfg.seed}; scale = {lcfg.scale}")
if ocfg.obs_type == 'pp_relu' or ocfg.obs_type == 'pp_log':
    print(f'alpha = {ocfg.alpha}')
if ocfg.obs_type == 'gaussian':
    print(f'obs var = {ocfg.ov1}e{ocfg.ov2}')
lrk = jr.key(lcfg.seed)

zs_target = sample_ccn_rank1(lrk, eigvec, eigval, K, lcfg.L)
gamma_full_dummytarget = gamma_full.copy()
gamma_full_dummytarget = gamma_full_dummytarget.at[nz_target,:,:].set(jnp.eye(K, dtype=complex))

zs = sample_from_gamma(lrk, gamma_full_dummytarget, lcfg.L)
zs = zs.at[nz_target,:,:].set(zs_target)

zs_0dc = jnp.apply_along_axis(add0, 0, zs)
xs = jnp.fft.irfft(zs_0dc, axis=0)

obs, obs_params = sample_obs(xs, params)
obs_type = ocfg.obs_type

# initialize gamma
gamma_init, nz_model = construct_gamma_init(cfg, obs, gamma_load)

# instantiate model and run em
print(f"Running EM for {mcfg.emiters} iters. Newton iters = {mcfg.maxiter}")

# TODO: handle m_step options in better way
# if mcfg.m_step_option == 'standard':
#     m_step_params = None
# elif mcfg.m_step_option == 'low-rank':
#     m_step_params = {'rank': mcfg.m_step_rank}

# NOTE
m_step_option = 'standard'
m_step_params = None

model = ToyModel()
model.initialize_latent(gamma_init, freqs, nz_model)
model.initialize_observations(obs_params, obs_type)
# model.fit_em(obs, mcfg.emiters, mcfg.maxiter, m_step_option=mcfg.m_step_option, m_step_params=m_step_params)

NUM_DEVICES=28
Sampling 7 samples from gamma k3-chlg4-gaussian-rank1-nz9; seed = 7; scale = 1
obs var = 1e-1
Sampling Gaussian observations with variance 1e^-1
EM initialization: 'flat-init'
Setting model support to 9 Hz - 9 Hz
Running EM for 50 iters. Newton iters = 10


In [13]:
from cohlib.jax.models import JaxOptim
gamma_inv = jnp.zeros_like(gamma_init)
gamma_inv_nz = jnp.linalg.inv(gamma_init[nz_model,:,:])
data = obs
obs_type = ocfg.obs_type
num_newton_iters = mcfg.maxiter
params = {'obs': obs_params,
          'freqs': freqs,
          'nonzero_inds': nz_model,
          'K': K}

optimizer = JaxOptim(data, gamma_inv, params, obs_type, num_iters=num_newton_iters)

In [14]:
mus_output, upss_output = optimizer.run_e_step_par_ts(obs, num_devices=3)

num_batches = 3
batch shape: (1000, 3, 3)
batch shape: (1000, 3, 3)
batch shape: (1000, 3, 1)


In [22]:
for b in range(len(mus_output)):
    print(mus_output[b].shape)

(3, 1, 3)
(3, 1, 3)
(1, 1, 3)


In [26]:
mus_output[1][0,0,:]

Array([25.432165 +9.978339j,  6.633691+21.721048j, 10.168804+24.531487j],      dtype=complex64)