In [13]:
import msprime
import numpy as np
import matplotlib.pyplot as plt
import demesdraw
import moments
import dadi
import time 
import nlopt
from tqdm import tqdm
import json


In [14]:
upper_bound_params = {
    "t_split": 5000, 
    "m": 1e-4,
    "N1": 10000,
    "N2": 10000,
    "Na": 20000
}

lower_bound_params =  {
    "t_split": 100, 
    "m": 1e-8,
    "N1": 100,
    "N2": 100,
    "Na": 10000

}

In [15]:
def sample_params():
    sampled_params = {}
    for key in lower_bound_params:
        lower_bound = lower_bound_params[key]
        upper_bound = upper_bound_params[key]
        sampled_value = np.random.uniform(lower_bound, upper_bound)

        # Initialize adjusted_value with sampled_value by default
        adjusted_value = sampled_value

        # Check if the sampled parameter is equal to the mean of the uniform distribution
        mean_value = (lower_bound + upper_bound) / 2
        if sampled_value == mean_value:
            # Add a small random value to avoid exact mean, while keeping within bounds
            adjustment = np.random.uniform(-0.1 * (upper_bound - lower_bound), 0.1 * (upper_bound - lower_bound))
            adjusted_value = sampled_value + adjustment
            # Ensure the adjusted value is still within the bounds
            adjusted_value = max(min(adjusted_value, upper_bound), lower_bound)

        # Assign adjusted_value to sampled_params
        if key == "m":
            sampled_params[key] = adjusted_value
        else:
            sampled_params[key] = int(adjusted_value)

    return sampled_params

In [16]:
sampled_params = sample_params()
print(sampled_params)

{'t_split': 2858, 'm': 6.911408461632738e-05, 'N1': 4108, 'N2': 7247, 'Na': 17237}


In [17]:
with open("/sietch_colab/akapoor/Demographic_Inference/experiment_config.json") as f:
    experiment_config = json.load(f)

mutation_rate = experiment_config['mutation_rate']
recombination_rate = experiment_config['recombination_rate']
length = experiment_config['genome_length']

In [18]:
print(mutation_rate)
print(recombination_rate)
print(length)

5.7e-09
3.386e-09
1000000.0


In [19]:
import os
os.chdir('/sietch_colab/akapoor/Demographic_Inference/')

In [20]:
# import src.demographic_models as demographic_models

# demographic_model = demographic_models.split_isolation_model_simulation

In [21]:
if not os.path.isdir("./data/"):
    os.makedirs("./data/")
os.system("rm ./data/*.vcf.gz")
os.system("rm ./data/*.h5")

rm: cannot remove './data/*.vcf.gz': No such file or directory
rm: cannot remove './data/*.h5': No such file or directory


256

In [22]:
def run_msprime_replicates(g,experiment_config, num_reps=100):
    # Set up the demography from demes
    demog = msprime.Demography.from_demes(g)

    # Dynamically define the samples using msprime.SampleSet, based on the sample_sizes dictionary
    samples = [
        msprime.SampleSet(sample_size, population=pop_name, ploidy=1)
        for pop_name, sample_size in experiment_config['num_samples'].items()
    ]

    tree_sequences = msprime.sim_ancestry(
        samples,
        demography=demog,
        sequence_length=experiment_config['genome_length'],
        recombination_rate=experiment_config['recombination_rate'],
        num_replicates=num_reps,
        random_seed=42,
    )
    for ii, ts in enumerate(tree_sequences):
        ts = msprime.sim_mutations(ts, rate=experiment_config['mutation_rate'], random_seed=ii + 1)
        vcf_name = "./data/split_mig.{0}.vcf".format(ii)
        with open(vcf_name, "w+") as fout:
            ts.write_vcf(fout, allow_position_zero=True)
        os.system(f"gzip {vcf_name}")


def write_samples_and_rec_map(experiment_config):

    # Define the file paths
    samples_file = "./data/samples.txt"
    flat_map_file ="./data/flat_map.txt"

    # Open and write the sample file
    with open(samples_file, "w+") as fout:
        fout.write("sample\tpop\n")

        # Dynamically define samples based on the num_samples dictionary
        sample_idx = 0  # Initialize sample index
        for pop_name, sample_size in experiment_config['num_samples'].items():
            for _ in range(sample_size):
                fout.write(f"tsk_{sample_idx}\t{pop_name}\n")
                sample_idx += 1

    # Write the recombination map file
    with open(flat_map_file, "w+") as fout:
        fout.write("pos\tMap(cM)\n")
        fout.write("0\t0\n")
        fout.write(f"{experiment_config['genome_length']}\t{experiment_config['recombination_rate'] * experiment_config['genome_length'] * 100}\n")




In [23]:
import src.demographic_models as demographic_models
demographic_model = demographic_models.split_isolation_model_simulation

In [24]:
print("running msprime and writing vcfs")
g = demographic_model(sampled_params)

running msprime and writing vcfs


In [25]:
g

Graph(description='', time_units='generations', generation_time=1, doi=[], metadata={}, demes=[Deme(name='ancestral', description='', start_time=inf, ancestors=[], proportions=[], epochs=[Epoch(start_time=inf, end_time=2858, start_size=17237, end_size=17237, size_function='constant', selfing_rate=0, cloning_rate=0)]), Deme(name='N1', description='', start_time=2858, ancestors=['ancestral'], proportions=[1.0], epochs=[Epoch(start_time=2858, end_time=0, start_size=4108, end_size=4108, size_function='constant', selfing_rate=0, cloning_rate=0)]), Deme(name='N2', description='', start_time=2858, ancestors=['ancestral'], proportions=[1.0], epochs=[Epoch(start_time=2858, end_time=0, start_size=7247, end_size=7247, size_function='constant', selfing_rate=0, cloning_rate=0)])], migrations=[AsymmetricMigration(source='N1', dest='N2', start_time=2858, end_time=0, rate=6.911408461632738e-05), AsymmetricMigration(source='N2', dest='N1', start_time=2858, end_time=0, rate=6.911408461632738e-05)], puls

In [26]:
run_msprime_replicates(g,experiment_config)

print("writing samples and recombination map")
write_samples_and_rec_map(experiment_config)


writing samples and recombination map


In [27]:
import ray 

In [28]:
# Initialize Ray
ray.init(ignore_reinit_error=True)

# Define your function with Ray's remote decorator
@ray.remote
def get_LD_stats(rep_ii, r_bins):
    vcf_file = f"./data/split_mig.{rep_ii}.vcf.gz"
    time1 = time.time()
    ld_stats = moments.LD.Parsing.compute_ld_statistics(
        vcf_file,
        rec_map_file="./data/flat_map.txt",
        pop_file="./data/samples.txt",
        pops=["N1", "N2"],
        r_bins=r_bins,
        report=False,
    )
    time2 = time.time()
    print("  finished rep", rep_ii, "in", int(time2 - time1), "seconds")
    return ld_stats


2024-10-14 14:03:07,765	INFO worker.py:1781 -- Started a local Ray instance.


In [29]:
num_reps = 100
# define the bin edges
r_bins = np.array([0, 1e-6, 2e-6, 5e-6, 1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3])
# Parallel execution with Ray
print("parsing LD statistics in parallel")
# Submit tasks to Ray in parallel using .remote()
futures = [get_LD_stats.remote(ii, r_bins) for ii in range(num_reps)]
# Gather results with ray.get() to collect them once the tasks are finished
ld_stats = ray.get(futures)
# Optionally, you can convert the list of results into a dictionary with indices
ld_stats_dict = {ii: result for ii, result in enumerate(ld_stats)}


parsing LD statistics in parallel


[36m(get_LD_stats pid=3403749)[0m   finished rep 73 in 18 seconds
[36m(get_LD_stats pid=3400795)[0m   finished rep 90 in 23 seconds[32m [repeated 32x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m
[36m(get_LD_stats pid=3405658)[0m   finished rep 1 in 29 seconds[32m [repeated 54x across cluster][0m
[36m(get_LD_stats pid=3405252)[0m   finished rep 39 in 33 seconds


In [30]:
print("computing mean and varcov matrix from LD statistics sums")
mv = moments.LD.Parsing.bootstrap_data(ld_stats_dict)

print("running inference")
# Run inference using the parsed data
demo_func = moments.LD.Demographics2D.split_mig
# Set up the initial guess
# The split_mig function takes four parameters (nu0, nu1, T, m), and we append
# the last parameter to fit Ne, which doesn't get passed to the function but
# scales recombination rates so can be simultaneously fit
p_guess = [0.1, 2, 0.075, 0.8, 10000]
p_guess = moments.LD.Util.perturb_params(p_guess, fold=0.1)
opt_params, LL = moments.LD.Inference.optimize_log_lbfgsb(
    p_guess, [mv["means"], mv["varcovs"]], [demo_func], rs=r_bins, verbose = 3
)


computing mean and varcov matrix from LD statistics sums
running inference
3       , -2.17333e+07, array([ 0.100294   ,  2.13672    ,  0.07435    ,  0.778923   ,  10147.8    ])
6       , -2.17341e+07, array([ 0.100294   ,  2.13459    ,  0.07435    ,  0.778923   ,  10158      ])
9       , -192018     , array([ 0.212762   ,  2.18508    ,  0.0386801  ,  0.844859   ,  10319.7    ])
12      , -191876     , array([ 0.212762   ,  2.1829     ,  0.0386801  ,  0.844859   ,  10330      ])
15      , -191555     , array([ 0.213034   ,  2.18513    ,  0.0386502  ,  0.844889   ,  10330.2    ])
18      , -191414     , array([ 0.213034   ,  2.18294    ,  0.0386502  ,  0.844889   ,  10340.6    ])
21      , -189731     , array([ 0.214124   ,  2.1853     ,  0.0385307  ,  0.845007   ,  10372.5    ])
24      , -189590     , array([ 0.214124   ,  2.18312    ,  0.0385307  ,  0.845007   ,  10382.9    ])
27      , -182838     , array([ 0.218542   ,  2.186      ,  0.0380566  ,  0.845477   ,  10543.2    ])
30     

In [31]:
physical_units = moments.LD.Util.rescale_params(
    opt_params, ["nu", "nu", "T", "m", "Ne"]
)

print("Simulated parameters:")
print(f"  N(deme0)         :  {g.demes[1].epochs[0].start_size:.1f}")
print(f"  N(deme1)         :  {g.demes[2].epochs[0].start_size:.1f}")
print(f"  Div. time (gen)  :  {g.demes[1].epochs[0].start_time:.1f}")
print(f"  Migration rate   :  {g.migrations[0].rate:.6f}")
print(f"  N(ancestral)     :  {g.demes[0].epochs[0].start_size:.1f}")

print("best fit parameters:")
print(f"  N(deme0)         :  {physical_units[0]:.1f}")
print(f"  N(deme1)         :  {physical_units[1]:.1f}")
print(f"  Div. time (gen)  :  {physical_units[2]:.1f}")
print(f"  Migration rate   :  {physical_units[3]:.6f}")
print(f"  N(ancestral)     :  {physical_units[4]:.1f}")

Simulated parameters:
  N(deme0)         :  4108.0
  N(deme1)         :  7247.0
  Div. time (gen)  :  2858.0
  Migration rate   :  0.000069
  N(ancestral)     :  17237.0
best fit parameters:
  N(deme0)         :  6108.1
  N(deme1)         :  6715.7
  Div. time (gen)  :  606.7
  Migration rate   :  0.000002
  N(ancestral)     :  356672.2
