In [None]:
import msprime as msp
import demes
import demesdraw

# --- Model parameters ---------------------------------------------------
Ne = 1e4                                        # base size the populations are scaled from
q_anc, q0, q1 = 2.0 * Ne, 2.0 * Ne, 1.0 * Ne    # initial sizes of "anc", "P0", "P1"
m = 1e-4                                        # symmetric migration rate between P0 and P1
tau = 1000                                      # time of the split of P0/P1 from "anc"

# --- Demography: one ancestral population splitting into P0 and P1 ------
demo = msp.Demography()
for pop_name, pop_size in (("anc", q_anc), ("P0", q0), ("P1", q1)):
    demo.add_population(name=pop_name, initial_size=pop_size)
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate=m)

tmp = ["P0", "P1"]  # names of the two derived populations
demo.add_population_split(time=tau, derived=tmp, ancestral="anc")

# Convert to a demes graph (reused by the momi3 cells) and draw it.
g = demo.to_demes()
demesdraw.tubes(g)
# print(g)  # uncomment to dump the demes graph as text

# --- Simulation ---------------------------------------------------------
# Same sample size for each derived population.
# NOTE(review): presumably 10 diploid individuals, i.e. 20 genomes per
# population under msprime's default ploidy — confirm.
sample_size = 10
samples = dict.fromkeys(tmp, sample_size)

# Ancestry first, then mutations on top; both seeded for reproducibility.
anc = msp.sim_ancestry(
    samples=samples,
    demography=demo,
    recombination_rate=1e-8,
    sequence_length=1e8,
    random_seed=42,
)
ts = msp.sim_mutations(anc, rate=1e-8, random_seed=42)

import jax.random as jr
import numpy as np
import random
# JAX PRNG key built from a fixed seed; not referenced by any cell visible here.
key = jr.PRNGKey(seed=0)

In [None]:
from momi3.momi import Momi3
import numpy as np
# Build a momi3 IICR object from the demes graph `g`.
# NOTE(review): the argument 2 is presumably the number of sampled lineages — confirm.
momi_object = Momi3(g).iicr(2)
# Alternative (disabled): expose the epoch-0 start_size of demes 0 and 1 together.
# params = [("demes", 0, "epochs", 0, "start_size"), ("demes", 1, "epochs", 0, "start_size")]
# Parameter to expose, written as a path tuple.
# NOTE(review): looks like a path into the demes graph (deme 0 -> epoch 0 -> start_size) — confirm.
params = [("demes", 0, "epochs", 0, "start_size")]
# reparameterize returns a pair; `x` is a mapping (it has .keys()) that is
# later copied and passed to loglik, `f` is not used in the cells visible here.
f, x = momi_object.reparameterize(list(params))
# Keys of `x` as a list; parameters[0] is the entry overwritten in the likelihood sweep.
parameters = list(x.keys())
# Disabled: manually override the first parameter's value.
# x[parameters[0]] = np.array(8000.0)
print(x)
from momi3.jsfs import JSFS
from momi3.momi import Momi3
import jax
# `g` is a demes-formatted demographic model; `ts` is the simulated tree sequence.

# Sample node IDs for each population. TreeSequence.samples expects a single
# integer population ID (a list is rejected by recent tskit releases).
# NOTE(review): IDs 1 and 2 are presumably P0 and P1, following the
# add_population order (anc, P0, P1) — confirm.
p0_nodes = ts.samples(population=1)
p1_nodes = ts.samples(population=2)

# Expected-SFS object; sample counts are taken from the data so they always
# match the dimensions of the observed spectrum built below.
momi_sfs_object = Momi3(g).sfs({'P0': len(p0_nodes), 'P1': len(p1_nodes)})

# Observed joint spectrum as raw counts (span_normalise=False).
# polarised=True keeps it unfolded (indexed by derived-allele counts); the
# default would return a folded spectrum.
# NOTE(review): assumes momi3 expects an unfolded JSFS — confirm.
afs = ts.allele_frequency_spectrum(
    sample_sets=[p0_nodes, p1_nodes],
    span_normalise=False,
    polarised=True,
)
jsfs = JSFS.from_dense(afs, ["P0", "P1"])

# print(momi_sfs_object.loglik(x, jsfs))
# The gradient call below is disabled because it currently raises an error.
# loglik is presumably meant to be given the parameter mapping plus a joint
# SFS object — confirm.
# jax.grad(momi_sfs_object.loglik)(x, jsfs)

In [None]:
from jax import vmap
import jax.numpy as jnp
# Grid of 20 evenly spaced trial values from 7000 to 13000; adjust range/steps as needed.
x_values = jnp.linspace(7000, 13000, num=20)

# Objective evaluated at each trial value (vectorised over x_values with vmap).
def compute_likelihood(val):
    """Return the negated log-likelihood with the first parameter set to `val`.

    A copy of the module-level mapping `x` is made, its `parameters[0]` entry
    is replaced by `val`, and the copy is passed to
    `momi_sfs_object.loglik` together with `jsfs`. `x` itself is not modified.
    """
    trial = x.copy()
    trial[parameters[0]] = val
    return -momi_sfs_object.loglik(trial, jsfs)

# Map compute_likelihood over every entry of x_values in a single batched call.
batched_likelihood = vmap(compute_likelihood)
likelihoods = batched_likelihood(x_values)

In [None]:
import matplotlib.pyplot as plt
# Plot the profile computed for each trial value in x_values.
# `likelihoods` holds negated log-likelihoods (compute_likelihood returns
# -loglik), so the labels say so rather than "likelihood".
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x_values, likelihoods, label='Negative log-likelihood')
ax.set_xlabel('x (parameter values)')
ax.set_ylabel('Negative log-likelihood')
ax.set_title('Negative log-likelihood over parameter values')
ax.legend()
ax.grid(True)
plt.show()