In [1]:
import sys
import momi
import dill

In [3]:
import logging

# Route INFO-level log records (momi reports optimiser progress at this level)
# to a file so long-running fits leave a trace.
logging.basicConfig(
    filename="log-AllAdmix-RandStarts-InitFit_Continued.log",
    level=logging.INFO,
)


In [None]:
# Same model as the other "model 1", except that fits are started from random
# parameter values to see whether the optimiser is getting trapped in local
# optima.
full_model = momi.DemographicModel(
    N_e=3500000,
    gen_time=1.5,
    muts_per_gen=3.5e-9,
)


# --- Population-size parameters ---------------------------------------------
# Divergence times are unknown at this point, so sizes are kept piecewise
# constant: this avoids having to specify times for size changes that could be
# incompatible with the divergence times, and keeps the number of free
# parameters from exploding.
#
# Sizes are ordered from past to present; plausible ranges come from MSMC Ne
# estimates. Each species gets one size for the "waiting time" before the next
# split/speciation, followed by a small number of more recent sizes. This limits
# computational complexity -- still a lot of free parameters!
#
# No starting value is given, so each size starts from a random value within
# the stated (lower, upper) range. Registration order is kept exactly as listed.
size_param_bounds = [
    # Sizes immediately post-divergence (origination).
    ("n_mm", 50000, 50000000),
    ("n_crm", 50000, 50000000),
    ("n_mongo", 50, 5000000),
    ("n_pungu", 50, 5000000),
    ("n_dikume", 50, 5000000),
    ("n_caroli", 50, 5000000),
    # One size change some time after divergence, for all species.
    ("n_mm_1", 100, 50000000),
    ("n_crm_1", 100, 50000000),
    ("n_mongo_1", 100, 5000000),
    ("n_pungu_1", 100, 5000000),
    ("n_dikume_1", 100, 5000000),
    ("n_caroli_1", 100, 5000000),
    # Another one for the 'longer' persisting branches. caroli gets one here
    # because of the 'waiting time' till speciation in Barombi Mbo post
    # colonization.
    ("n_mm_2", 100, 50000000),
    ("n_crm_2", 100, 50000000),
    ("n_caroli_2", 100, 5000000),
    ("n_mongo_2", 100, 5000000),
    ("n_pungu_2", 100, 5000000),
    ("n_dikume_2", 100, 5000000),
    # A most-recent size, motivated by wanting a recent estimate for the
    # riverine populations (declared for every population except dikume).
    ("n_mm_3", 100, 50000000),
    ("n_crm_3", 100, 50000000),
    ("n_caroli_3", 100, 5000000),
    ("n_mongo_3", 100, 5000000),
    ("n_pungu_3", 100, 5000000),
]

for size_name, size_lower, size_upper in size_param_bounds:
    full_model.add_size_param(size_name, lower=size_lower, upper=size_upper)

# --- Divergence-time parameters ----------------------------------------------
# No starting value is given, so each time starts from a random value within
# its range. The lake formed 2 million years ago, which is used as the upper
# bound for the oldest split (t_mm_crm); younger splits use tighter upper
# bounds and lower bounds of 1000 or 100.
full_model.add_time_param("t_mm_crm", lower=100000, upper=2000000)

# Bmbo_anc is time at which the lake was colonized/the group diverged from riverine,
# whereas mongo_myaka is the time at which mongo/myaka split from this ancestral pop
# (kept for reference, currently disabled):
#full_model.add_time_param("t_bmbo_anc", lower = 10000, upper = 2000000, t0 = 1500000, upper_constraints = ["t_mm_crm"])

# Each remaining split is constrained to be no older than the split before it.
nested_splits = [
    # (parameter, lower, upper, older split used as upper constraint)
    ("t_crm_caroli", 1000, 1000000, "t_mm_crm"),
    ("t_mongo_caroli", 100, 500000, "t_crm_caroli"),
    ("t_mongo_pungu", 100, 500000, "t_mongo_caroli"),
    ("t_pungu_dikume", 100, 500000, "t_mongo_pungu"),
]

for split_name, split_lower, split_upper, older_split in nested_splits:
    full_model.add_time_param(
        split_name,
        lower=split_lower,
        upper=split_upper,
        upper_constraints=[older_split],
    )

# Start adding leaves to the model
# And set pop sizes, for a time-range constrained by divergence
#
# Every population is sampled at t=0 with its most recent size parameter. Size
# changes are then placed either at a divergence-time parameter (given by name)
# or at a fixed fraction (0.5, 0.1, 0.05) of one, computed by a lambda from the
# current parameter values.

# SgalMM: fractions of t_mm_crm.
# NOTE(review): the first set_size uses a fixed t=2000000, which equals the upper
# bound declared for t_mm_crm -- confirm this fixed time is intended.
full_model.add_leaf("SgalMM", N="n_mm_3", t=0)
full_model.set_size("SgalMM", N="n_mm", t=2000000)
full_model.set_size("SgalMM", N="n_mm", t="t_mm_crm")
full_model.set_size("SgalMM", N="n_mm_1", t=lambda params: params.t_mm_crm * 0.5)
full_model.set_size("SgalMM", N="n_mm_2", t=lambda params: params.t_mm_crm * 0.1)
full_model.set_size("SgalMM", N="n_mm_3", t=lambda params: params.t_mm_crm * 0.05)

# SgalCRM: fractions of t_crm_caroli.
# NOTE(review): n_crm_1 is applied both at t_crm_caroli and at half of it, whereas
# the other populations repeat their origination size at the second time point --
# confirm this difference is intended.
full_model.add_leaf("SgalCRM", N="n_crm_3", t=0)
full_model.set_size("SgalCRM", N="n_crm", t="t_mm_crm")
full_model.set_size("SgalCRM", N="n_crm_1", t="t_crm_caroli")
full_model.set_size("SgalCRM", N="n_crm_1", t=lambda params: params.t_crm_caroli * 0.5)
full_model.set_size("SgalCRM", N="n_crm_2", t=lambda params: params.t_crm_caroli * 0.1)
full_model.set_size("SgalCRM", N="n_crm_3", t=lambda params: params.t_crm_caroli * 0.05)

# caroli: fractions of t_crm_caroli (not of t_mongo_caroli).
# NOTE(review): the fractional times are not guaranteed to be younger than
# t_mongo_caroli -- verify the resulting ordering of size changes is as intended.
full_model.add_leaf("caroli", N="n_caroli_3", t=0)
full_model.set_size("caroli", N="n_caroli", t="t_crm_caroli")
full_model.set_size("caroli", N="n_caroli", t="t_mongo_caroli")
full_model.set_size("caroli", N="n_caroli_1", t=lambda params: params.t_crm_caroli * 0.5)
full_model.set_size("caroli", N="n_caroli_2", t=lambda params: params.t_crm_caroli * 0.1)
full_model.set_size("caroli", N="n_caroli_3", t=lambda params: params.t_crm_caroli * 0.05)

# mongo: fractions of t_mongo_pungu.
full_model.add_leaf("mongo", N="n_mongo_3", t=0)
full_model.set_size("mongo", N="n_mongo", t="t_mongo_caroli")
full_model.set_size("mongo", N="n_mongo", t="t_mongo_pungu")
full_model.set_size("mongo", N="n_mongo_1", t=lambda params: params.t_mongo_pungu * 0.5)
full_model.set_size("mongo", N="n_mongo_2", t=lambda params: params.t_mongo_pungu * 0.1)
full_model.set_size("mongo", N="n_mongo_3", t=lambda params: params.t_mongo_pungu * 0.05)

# pungu: fractions of t_pungu_dikume.
full_model.add_leaf("pungu", N="n_pungu_3", t=0)
full_model.set_size("pungu", N="n_pungu", t="t_mongo_pungu")
full_model.set_size("pungu", N="n_pungu", t="t_pungu_dikume")
full_model.set_size("pungu", N="n_pungu_1", t=lambda params: params.t_pungu_dikume * 0.5)
full_model.set_size("pungu", N="n_pungu_2", t=lambda params: params.t_pungu_dikume * 0.1)
full_model.set_size("pungu", N="n_pungu_3", t=lambda params: params.t_pungu_dikume * 0.05)

# dikume: fractions of t_pungu_dikume. There is no n_dikume_3 parameter, so the
# schedule is shifted by one step relative to the other populations: the
# origination size persists to half the split time, and n_dikume_2 is the
# present-day size.
full_model.add_leaf("dikume", N="n_dikume_2", t=0)
full_model.set_size("dikume", N="n_dikume", t="t_pungu_dikume")
full_model.set_size("dikume", N="n_dikume", t=lambda params: params.t_pungu_dikume * 0.5)
full_model.set_size("dikume", N="n_dikume_1", t=lambda params: params.t_pungu_dikume * 0.1)
full_model.set_size("dikume", N="n_dikume_2", t=lambda params: params.t_pungu_dikume * 0.05)


# Divergence events, each at a free time parameter. The N argument names the
# size parameter passed along with the move (the size is assumed inherited).
split_events = [
    # (population moved, population joined, split-time parameter, N parameter)
    ("SgalCRM", "SgalMM", "t_mm_crm", "n_crm"),
    ("caroli", "SgalCRM", "t_crm_caroli", "n_caroli"),
    ("mongo", "caroli", "t_mongo_caroli", "n_mongo"),
    ("pungu", "mongo", "t_mongo_pungu", "n_pungu"),
    ("dikume", "pungu", "t_pungu_dikume", "n_dikume"),
]

for moved_pop, joined_pop, split_time, split_size in split_events:
    full_model.move_lineages(moved_pop, joined_pop, t=split_time, N=split_size)

# F-statistics indicate some evidence of hybridization, so pulses of admixture
# are included in the model. The starting magnitudes reflect, among others:
#   1) a 5% pulse of migration from mongo to dikume
#   2) a 4% pulse of migration from mongo to pungu
#   3) a 7% pulse of migration from dikume to Sgal CRM
#   4) a 12% pulse of migration from mongo to Sgal CRM
# All magnitudes are free parameters, capped at 0.5; every pulse also gets a
# free time parameter that may not exceed the named divergence time.
#
# NOTE(review): the (source, dest) order passed to move_lineages for
# p_pungu_mongo and p_caroli_mongo follows the parameter name, whereas for the
# other four pulses it is the reverse of the name. Confirm each direction
# against momi's move_lineages convention and the intended gene flow.
admix_pulses = [
    dict(p="p_dikume_mongo", p0=0.05, t="t_p_dikume_mongo",
         before="t_pungu_dikume", source="mongo", dest="dikume"),
    dict(p="p_pungu_mongo", p0=0.04, t="t_p_pungu_mongo",
         before="t_mongo_pungu", source="pungu", dest="mongo"),
    dict(p="p_caroli_mongo", p0=0.03, t="t_p_caroli_mongo",
         before="t_mongo_caroli", source="caroli", dest="mongo"),
    dict(p="p_SgalCRM_dikume", p0=0.07, t="t_p_SgalCRM_dikume",
         before="t_pungu_dikume", source="dikume", dest="SgalCRM"),
    dict(p="p_SgalCRM_mongo", p0=0.12, t="t_p_SgalCRM_mongo",
         before="t_mongo_caroli", source="mongo", dest="SgalCRM"),
    dict(p="p_SgalCRM_caroli", p0=0.12, t="t_p_SgalCRM_caroli",
         before="t_mongo_caroli", source="caroli", dest="SgalCRM"),
]

# Three separate passes keep the parameter registration order unchanged:
# all pulse magnitudes first, then all pulse times, then the events.
for pulse in admix_pulses:
    full_model.add_pulse_param(pulse["p"], p0=pulse["p0"], upper=0.5)

for pulse in admix_pulses:
    full_model.add_time_param(pulse["t"], upper_constraints=[pulse["before"]])

for pulse in admix_pulses:
    full_model.move_lineages(
        pulse["source"], pulse["dest"], t=pulse["t"], p=pulse["p"])


# Now plot to see if this is coherent
from matplotlib import pyplot as plt

yticks = [500,10000,25000,50000,100000,250000,500000,1000000]

# linthreshy is set to the SgalCRM/caroli divergence time. The parameter is
# looked up by name rather than by position, so adding or reordering parameters
# cannot silently change which value is used.
# NOTE(review): earlier notes described this threshold as the caroli/mongo split
# (t_mongo_caroli), but the value actually used has always been t_crm_caroli;
# that behaviour is kept here. Swap the key if the other split was intended.
thresh = dict(full_model.get_params())["t_crm_caroli"]

fig = momi.DemographyPlot(
    full_model, ["dikume", "pungu", "mongo", "caroli", "SgalCRM", "SgalMM"],
    figsize=(10,12), linthreshy=thresh,
    major_yticks=yticks)

# Save the figure that was just drawn in both raster and vector form.
plt.savefig("./AllAdmix-PreFit.png")
plt.savefig("./AllAdmix-PreFit.pdf")

In [None]:
# Including all samples in this admixture model is too much for momi to handle:
# it runs out of memory. As a trial, reduce the problem by down-sampling the
# allele counts to 5 per population.
#
# This cell and the two that follow only need to be run once.

# Read in the full allele-count data.
ac_path = "/global/scratch/users/austinhpatton/cichlids/cameroon/Onil_UMD/momi/data/BM-momi-FinalSamps.snpAlleleCounts.gz"
ac = momi.SnpAlleleCounts.load(ac_path)

In [7]:
# Down-sample every population to the same target size, then save the result.
# NOTE(review): confirm whether momi interprets these numbers as individuals or
# as sampled (haploid) alleles.
pop_dict = {
    pop_name: 5
    for pop_name in ("dikume", "pungu", "mongo", "caroli", "SgalCRM", "SgalMM")
}

downsamp_ac = ac.down_sample(pop_dict)

downsamp_ac_path = "/global/scratch/users/austinhpatton/cichlids/cameroon/Onil_UMD/momi/data/BM-momi-FinalSamps-Downsamp5.snpAlleleCounts.gz"
downsamp_ac.dump(downsamp_ac_path)

In [8]:
# Recompute the site frequency spectrum from the down-sampled allele counts
# (the argument is the number of blocks the data are split into), then save it.
n_blocks = 100
downsamp_sfs = downsamp_ac.extract_sfs(n_blocks)

downsamp_sfs_path = "/global/scratch/users/austinhpatton/cichlids/cameroon/Onil_UMD/momi/data/BM-momi-FinalSamps-Downsamp5.sfs.gz"
downsamp_sfs.dump(downsamp_sfs_path)

In [5]:
# Read the down-sampled SFS back from disk, so this cell also works when the
# SFS was produced in an earlier session or outside of this notebook.
sfs_path = "/global/scratch/users/austinhpatton/cichlids/cameroon/Onil_UMD/momi/data/BM-momi-FinalSamps-Downsamp5.sfs.gz"
sfs = momi.Sfs.load(sfs_path)


In [6]:
# Attach the SFS to the model as the data to fit, processing it in chunks of
# 1000 to keep memory use down.
full_model.set_data(
    sfs,
    mem_chunk_size=1000,
)



In [None]:
# And go ahead and infer!
# Because the model is large, parameter space is explored first: several
# independent rounds of stochastic optimisation are run, each from a random
# starting point, and the fit continues from the best one.
results = []
n_runs = 10

# Checkpoint file handed to every run (the same file name is reused each time).
checks = "AllAdmix-RandStart-StochOptim-Checkpoints-InitFit.txt"

for i in range(n_runs):
    print(f"Starting run {i+1} out of {n_runs}...")

    # Work on a copy of the model with randomised parameter values, so
    # full_model itself is untouched until the best run is chosen.
    full_rand_model = full_model.copy()
    full_rand_model.set_params(randomize=True)

    run_result = full_rand_model.stochastic_optimize(
        snps_per_minibatch=1000,
        printfreq=100,
        num_iters=500,
        svrg_epoch=100,
        save_to_checkpoint=checks,
    )
    results.append(run_result)

    # Snapshot the session after every run so a crash does not lose finished runs.
    dill.dump_session('AllAdmix-StochOptim-tmp_dill.pkl')

# Rank the runs by log likelihood and keep the best (largest) one.
best_result = sorted(results, key=lambda r: r.log_likelihood)[-1]
full_model.set_params(best_result.parameters)

# Okay, that was a lot of work. Save the workspace to an intermediate file.
dill.dump_session('AllAdmix-RandStarts_dill_InitFit.pkl')


In [None]:
# Plot the model as fitted by the stochastic optimisation.

# linthreshy is set to the SgalCRM/caroli divergence time. The parameter is
# looked up by name rather than by position, so adding or reordering parameters
# cannot silently change which value is used.
# NOTE(review): earlier notes described this threshold as the caroli/mongo split
# (t_mongo_caroli), but the value actually used has always been t_crm_caroli;
# that behaviour is kept here. Swap the key if the other split was intended.
thresh = dict(full_model.get_params())["t_crm_caroli"]

fig = momi.DemographyPlot(
    full_model, ["dikume", "pungu", "mongo", "caroli", "SgalCRM", "SgalMM"],
    figsize=(10,12),
    major_yticks=yticks,
    linthreshy=thresh)

# Save the figure that was just drawn in both raster and vector form.
plt.savefig("AllAdmix-RandStarts-StochastOptim-Fitted.png")
plt.savefig("AllAdmix-RandStarts-StochastOptim-Fitted.pdf")

# Show the current parameter values as the cell output.
full_model.get_params()
    

In [None]:
# NOTE:
# This fitting procedure (including the earlier stochastic optimization) takes a
# long time. It ended up hitting the 72 hour time limit on the bigmem savio nodes
# (it is also quite memory intensive), so the parameters of the full model are
# updated here to the values reached in the last iteration of the L-BFGS-B
# optimization procedure before the job was cancelled.
# Last step: INFO:momi.demo_model:{it: 208, KLDivergence: Autograd ArrayBox with value 0.0908904874112342}

# Restore the workspace saved after the stochastic-optimisation stage.
# Loading a session unpickles arbitrary objects: only load files you created.
dill.load_session('AllAdmix-RandStarts_dill_InitFit.pkl')

# Values copied from the log. n_mm and n_mm_3 sit at the upper bound declared
# for them (50000000).
newest_params = {
    # origination sizes
    'n_mm': 49999999.99999999,
    'n_crm': 126906.62813382912,
    'n_mongo': 15760.384948624753,
    'n_pungu': 90049.20151010553,
    'n_dikume': 26907.48641417023,
    'n_caroli': 33281.10328474059,
    # first size change
    'n_mm_1': 517260.3009338876,
    'n_crm_1': 5729470.402328302,
    'n_mongo_1': 134743.47637289506,
    'n_pungu_1': 106290.26766388012,
    'n_dikume_1': 2606.092988805131,
    'n_caroli_1': 273.57350258168077,
    # second size change
    'n_mm_2': 18753.540161922345,
    'n_crm_2': 995800.1672198448,
    'n_caroli_2': 252561.5357660461,
    'n_mongo_2': 83418.58307945068,
    'n_pungu_2': 1316.5796224710584,
    'n_dikume_2': 14564.644991469302,
    # third size change
    'n_mm_3': 49999999.99999999,
    'n_crm_3': 11359.467327588545,
    'n_caroli_3': 137722.9948566717,
    'n_mongo_3': 568.9146821018952,
    'n_pungu_3': 27368.18687032608,
    # divergence times
    't_mm_crm': 280645.3195825911,
    't_crm_caroli': 87617.63823326302,
    't_mongo_caroli': 44094.57796405377,
    't_mongo_pungu': 20621.129004557275,
    't_pungu_dikume': 15790.53300240068,
    # pulse magnitudes
    'p_dikume_mongo': 0.14150837101350486,
    'p_pungu_mongo': 0.17991833540889018,
    'p_caroli_mongo': 0.03352510290085085,
    'p_SgalCRM_dikume': 0.0040362924721478784,
    'p_SgalCRM_mongo': 0.05088657962989328,
    'p_SgalCRM_caroli': 0.029539104631394753,
    # pulse times
    't_p_dikume_mongo': 15622.390313997917,
    't_p_pungu_mongo': 2415.8507430564878,
    't_p_caroli_mongo': 42482.39212759363,
    't_p_SgalCRM_dikume': 15681.610358939119,
    't_p_SgalCRM_mongo': 23073.959129754574,
    't_p_SgalCRM_caroli': 35866.74974733347,
}

full_model.set_params(newest_params)

In [None]:
# Run a full optimisation of the model from the current parameter values.
full_model.optimize(method="L-BFGS-B")

# Save the fitted state for later bootstrapping. The whole session is
# serialised with dill; note this reuses the file name of the session saved
# after the stochastic-optimisation stage, so that earlier snapshot is replaced.
session_file = 'AllAdmix-RandStarts_dill_InitFit.pkl'
dill.dump_session(session_file)

In [None]:
# Load the session.
# Loading a session unpickles arbitrary objects: only load files you created.
dill.load_session('AllAdmix-RandStarts_dill_InitFit.pkl')

# Plot the fully optimized model!
# linthreshy is set to the SgalCRM/caroli divergence time. The parameter is
# looked up by name rather than by position, so adding or reordering parameters
# cannot silently change which value is used.
# NOTE(review): earlier notes described this threshold as the caroli/mongo split
# (t_mongo_caroli), but the value actually used has always been t_crm_caroli;
# that behaviour is kept here. Swap the key if the other split was intended.
thresh = dict(full_model.get_params())["t_crm_caroli"]

fig = momi.DemographyPlot(
    full_model, ["dikume", "pungu", "mongo", "caroli", "SgalCRM", "SgalMM"],
    figsize=(10,12),
    major_yticks=yticks,
    linthreshy=thresh)

# Save the figure that was just drawn in both raster and vector form.
plt.savefig("AllAdmix-RandStarts_L-BFGS-B-Fitted.png")
plt.savefig("AllAdmix-RandStarts_L-BFGS-B-Fitted.pdf")

# Write the fitted parameters to a space-separated text file with columns
# Set / Parameter / Estimate; every row is tagged with the set name "Observed".
ests = dict(full_model.get_params())
grp = "Observed"

with open('AllAdmix-RandStarts_L-BFGS-B_FittedParams.txt', "w") as f:
    print("Set", "Parameter", "Estimate", file=f)
    for parameter in ests:
        estimate = ests[parameter]
        print(grp, '{} {}'.format(parameter, estimate), file=f)

# Show the log likelihood of the fitted model as the cell output.
full_model.log_likelihood()

In [None]:
# Assess the extent to which the fitted model adequately explains the observed
# data, using f4 statistics and pairwise IBS.
full_model_stats = momi.SfsModelFitStats(full_model)

# Population quartets to test. The printed label and the f4 call are both
# derived from the same tuple, so they cannot drift apart.
f4_quartets = [
    ("dikume", "pungu", "mongo", "SgalCRM"),
    # an alternative including all pops that CRM supposedly has hybridized with
    ("dikume", "mongo", "caroli", "SgalCRM"),
    # only within lake
    ("dikume", "pungu", "mongo", "caroli"),
]

for quartet in f4_quartets:
    print("Computing f4({})".format(", ".join(quartet)))
    f4 = full_model_stats.f4(*quartet)

    print("Expected = {}".format(f4.expected))
    print("Observed = {}".format(f4.observed))
    print("SD = {}".format(f4.sd))
    print("Z(Expected-Observed) = {}".format(f4.z_score))

# Pairwise IBS
full_model_stats.all_pairs_ibs()

# Save the session again, to a new file, in case anything gets messed up and
# an earlier session file is overwritten.
dill.dump_session('AllAdmix-RandStarts-Final-Fitted_dill.pkl')

In [None]:
# Out of curiosity: how well does the model fit diversity within each
# population (read here as a per-population check on the mutation rate)?
within_pop_fit = full_model.fit_within_pop_diversity()
within_pop_fit