# Playing around with the split isolation model

In [1]:
import msprime
import numpy as np
import matplotlib.pyplot as plt
import demesdraw
import moments
import dadi


In [2]:
# Prior bounds for the split-isolation parameters. sample_params() draws each
# value uniformly between the matching lower and upper entry, iterating over
# the keys of lower_bound_params, so both dicts must share the same keys.
lower_bound_params = {"t_split": 100, "N1": 100, "N2": 100, "Na": 100}
upper_bound_params = {"t_split": 5000, "N1": 10_000, "N2": 10_000, "Na": 20_000}

# mutation_rate is forwarded to the inference calls further down;
# recombination_rate mirrors the default used by create_SFS.
# Units are presumably per base pair per generation (TODO confirm).
mutation_rate = 5.7e-9
recombination_rate = 3.386e-9

In [3]:
def sample_params(lower_bounds=None, upper_bounds=None, rng=None, max_redraws=1000):
    """Draw one integer value per parameter, uniformly between its bounds.

    Each value is drawn from ``uniform(lower, upper)`` and truncated with
    ``int()``. The stored (truncated) value is never allowed to equal the
    midpoint ``(lower + upper) / 2`` of its interval: if it does, the value is
    redrawn. The previous implementation compared the *untruncated* float with
    the midpoint, which practically never matched, so the stored integer could
    still land exactly on the midpoint.

    Parameters
    ----------
    lower_bounds, upper_bounds : dict, optional
        Mappings from parameter name to bound. Default to the module-level
        ``lower_bound_params`` / ``upper_bound_params``. Every key of
        ``lower_bounds`` must also be present in ``upper_bounds``.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, the global ``np.random`` state is
        used, exactly as before.
    max_redraws : int, optional
        Upper limit on midpoint-avoiding redraws per parameter. It guards
        against intervals in which every draw truncates to the midpoint
        (e.g. ``lower == upper``); in that case the last draw is kept.

    Returns
    -------
    dict
        Parameter name -> sampled ``int``, in the key order of ``lower_bounds``.

    Raises
    ------
    KeyError
        If a key of ``lower_bounds`` is missing from ``upper_bounds``.
    """
    if lower_bounds is None:
        lower_bounds = lower_bound_params
    if upper_bounds is None:
        upper_bounds = upper_bound_params
    draw_uniform = np.random.uniform if rng is None else rng.uniform

    sampled_params = {}
    for key, lower_bound in lower_bounds.items():
        upper_bound = upper_bounds[key]
        midpoint = (lower_bound + upper_bound) / 2

        # Exactly one draw per parameter in the common case, so a seeded run
        # consumes the random stream the same way the old code did.
        value = int(draw_uniform(lower_bound, upper_bound))

        # Redraw while the truncated value sits on the midpoint. Redrawn values
        # stay inside the bounds by construction, so no clamping is needed.
        redraws = 0
        while value == midpoint and redraws < max_redraws:
            value = int(draw_uniform(lower_bound, upper_bound))
            redraws += 1

        sampled_params[key] = value

    return sampled_params

In [4]:
# Draw one parameter set from the uniform priors. No RNG seed is set in the
# cells shown before this call, so the values differ from run to run.
sampled_params = sample_params()
print(sampled_params)

{'t_split': 1358, 'N1': 255, 'N2': 5722, 'Na': 9570}


In [5]:
import os
# Change the working directory, presumably so the `src.*` imports in the
# following cells resolve — TODO confirm.
# NOTE(review): absolute, machine-specific path; the notebook will not run
# elsewhere without editing it.
os.chdir('/sietch_colab/akapoor/Demographic_Inference/')

In [6]:
from src.demographic_models import split_isolation_model_simulation

In [7]:
def create_SFS(
    sampled_params, mode, num_samples, demographic_model, length=1e7, mutation_rate=5.7e-9, recombination_rate = 3.386e-9, **kwargs
):
    """Build a site frequency spectrum, either simulated or read from a VCF.

    Parameters
    ----------
    sampled_params
        Passed unchanged to ``demographic_model`` (pretrain mode only).
    mode : str
        ``"pretrain"`` simulates an SFS with msprime; ``"inference"`` builds
        one from a VCF file with dadi.
    num_samples
        Pretrain mode: mapping of population name -> number of samples; each
        entry becomes an ``msprime.SampleSet`` with ``ploidy=1``.
        Inference mode: used as ``2 * num_samples`` for the single projection
        size, so it must be a number there (the dict form does not work).
    demographic_model : callable
        Called as ``demographic_model(sampled_params)``; its result is handed
        to ``msprime.Demography.from_demes`` (pretrain mode only).
    length : float
        Simulated sequence length; also the factor the SFS is multiplied by.
    mutation_rate, recombination_rate : float
        Rates passed to ``msprime.sim_mutations`` / ``msprime.sim_ancestry``.
    **kwargs
        Pretrain mode: ``random_seed`` (ancestry seed, default 295 as before)
        and ``mutation_seed`` (default ``None``, i.e. unseeded as before).
        Inference mode: ``vcf_file``, ``pop_file`` and ``popname``, all required.

    Returns
    -------
    moments.Spectrum (pretrain) or dadi.Spectrum (inference).

    Raises
    ------
    ValueError
        If ``mode`` is not recognised, or if a required inference keyword
        argument is missing.
    """

    if mode == "pretrain":
        # Build the msprime demography from the demes graph of the model
        graph = demographic_model(sampled_params)
        demog = msprime.Demography.from_demes(graph)

        # One haploid SampleSet per requested population
        samples = [
            msprime.SampleSet(sample_size, population=pop_name, ploidy=1)
            for pop_name, sample_size in num_samples.items()
        ]

        # Joint ancestry simulation of all sampled populations
        ts = msprime.sim_ancestry(
            samples=samples,
            demography=demog,
            sequence_length=length,
            recombination_rate=recombination_rate,
            random_seed=kwargs.get("random_seed", 295),
        )

        # Sample node ids per population, looked up once. Adding mutations
        # below does not change the nodes, so these remain valid afterwards.
        samples_by_pop = {
            pop.id: ts.samples(population=pop.id) for pop in ts.populations()
        }
        for pop_id, nodes in samples_by_pop.items():
            print(f"Population {pop_id} samples:", nodes)

        # Overlay mutations on the simulated genealogies
        ts = msprime.sim_mutations(
            ts, rate=mutation_rate, random_seed=kwargs.get("mutation_seed")
        )

        # SFS axes follow population-id order; unsampled populations are dropped
        sample_sets = [nodes for nodes in samples_by_pop.values() if len(nodes) > 0]

        sfs = ts.allele_frequency_spectrum(sample_sets=sample_sets, mode="site", polarised=True)

        # tskit span-normalises the spectrum by default (see its documentation);
        # multiplying by the simulated length converts it back to counts.
        sfs *= length

        return moments.Spectrum(sfs)

    if mode == "inference":
        vcf_file = kwargs.get("vcf_file", None)
        pop_file = kwargs.get("pop_file", None)
        popname = kwargs.get("popname", None)

        if vcf_file is None or pop_file is None:
            raise ValueError(
                "vcf_file and pop_file must be provided in inference mode."
            )
        # Previously a missing popname was forwarded to dadi as [None].
        if popname is None:
            raise ValueError("popname must be provided in inference mode.")

        dd = dadi.Misc.make_data_dict_vcf(vcf_file, pop_file)
        return dadi.Spectrum.from_data_dict(
            dd, [popname], projections=[2 * num_samples], polarized=True
        )

    # Previously an unknown mode fell through to `return sfs` and raised
    # UnboundLocalError; report the actual problem instead.
    raise ValueError(
        f"Unknown mode {mode!r}; expected 'pretrain' or 'inference'."
    )

In [8]:
# Number of samples drawn from each deme (create_SFS builds its SampleSets
# with ploidy=1, so these are haploid counts).
sample_sizes = {
    "N1": 15,  # 15 samples from population 1
    "N2": 8    # 8 samples from population 2
}

# Simulate the joint SFS for the sampled parameters. The rates come from the
# notebook-level constants instead of repeating the literals here, so the
# simulation and the later inference calls cannot drift apart.
sfs = create_SFS(
    sampled_params,
    'pretrain',
    sample_sizes,
    demographic_model=split_isolation_model_simulation,
    length=1e7,
    mutation_rate=mutation_rate,
    recombination_rate=recombination_rate,
)

Population 0 samples: []
Population 1 samples: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]
Population 2 samples: [15 16 17 18 19 20 21 22]


In [9]:
# One axis per sampled deme, each of size (sample count + 1): 15 and 8 samples
# give the (16, 9) shown below.
sfs.shape

(16, 9)

In [10]:
from src.parameter_inference import run_inference_dadi, run_inference_moments
from src.demographic_models import split_isolation_model_dadi, split_isolation_model_moments

In [11]:
# p0 holds the simulated values in the order [N1, N2, t_split]; it is passed as
# the second positional argument to the run_inference_* calls (presumably the
# optimiser's starting point — TODO confirm).
p0 = [sampled_params[name] for name in ("N1", "N2", "t_split")]
# Per-deme sample counts as a list, in the order [N1, N2].
num_samples = [sample_sizes[deme] for deme in ("N1", "N2")]
demographic_model = "split_isolation_model"
k = 1

In [12]:
sampled_params

{'t_split': 1358, 'N1': 255, 'N2': 5722, 'Na': 9570}

In [13]:
# Repeats the earlier assignment of the same value; it has no effect when the
# cells are run in order.
demographic_model = "split_isolation_model"

In [20]:
# Fit the model to the simulated SFS with the project's dadi wrapper, starting
# from p0 and constrained to the same ranges as the sampling priors.
# NOTE(review): num_samples=100 is a literal; the `num_samples` list built from
# sample_sizes is not passed — confirm what run_inference_dadi expects here.
# NOTE(review): length=1e8, but the SFS was simulated with length=1e7 — confirm
# the mismatch is intentional.
model_list, opt_theta_list, opt_params_final_list, ll_list = run_inference_dadi(
    sfs,
    p0,
    num_samples = 100,
    demographic_model = demographic_model,
    k = k,
    lower_bound=[100, 100, 100], # [N1, N2 , t_split]
    upper_bound=[10000, 10000, 5000], # [N1, N2, t_split]
    mutation_rate=mutation_rate,
    length=1e8,
    top_values_k = 1
)



OPT DADI PARAMETER: [ 156.99711776 1357.48807327 1126.3159604 ]


In [26]:
opt_params_final_list[0]

{'N1': 325.9846424190158,
 'N2': 7757.592331981406,
 't_split': 771.4647505541604}

In [27]:
sampled_params

{'t_split': 1358, 'N1': 255, 'N2': 5722, 'Na': 9570}

In [28]:
# Same fit with the project's moments wrapper. This rebinds model_list,
# opt_theta_list, opt_params_final_list and ll_list, so the dadi results from
# the earlier call are no longer reachable under these names.
# NOTE(review): length=1e8, but the SFS was simulated with length=1e7 — confirm
# the mismatch is intentional.
model_list, opt_theta_list, opt_params_final_list, ll_list = run_inference_moments(
    sfs,
    p0,
    demographic_model = demographic_model,
    k = k,
    lower_bound=[100, 100, 100], # [N1, N2 , t_split]
    upper_bound=[10000, 10000, 5000], # [N1, N2, t_split]
    mutation_rate=mutation_rate,
    length=1e8,
    top_values_k = 1
)

OPT MOMENTS PARAMETER: [ 406.73383597 5812.7052273  1181.78608151]


In [29]:
opt_params_final_list[0]

{'N1': 406.73383597251353,
 'N2': 5812.7052273046365,
 't_split': 1181.7860815133977}

In [31]:
sampled_params

{'t_split': 1358, 'N1': 255, 'N2': 5722, 'Na': 9570}

In [33]:
model_list[0].shape

(16, 9)

In [5]:
import pickle 

# Load a previously saved linear-model result object from disk.
# NOTE(review): pickle.load can execute arbitrary code, so only load files from
# a trusted source. The absolute path is machine-specific.
with open('/sietch_colab/akapoor/Demographic_Inference/split_isolation_model_seed_42/models/sims_pretrain_100_sims_inference_5_seed_42_num_replicates_10_top_values_5/num_hidden_neurons_1000_num_hidden_layers_3_num_epochs_500_dropout_value_0_weight_decay_0_batch_size_35_EarlyStopping_False/linear_mdl_obj.pkl', 'rb') as f:
    linear_mdl_obj = pickle.load(f) 
    

In [6]:
linear_mdl_obj.keys()

dict_keys(['model', 'training', 'validation', 'param_names'])

In [7]:
linear_mdl_obj['param_names']

['N1', 'N2', 't_split']

In [9]:
linear_mdl_obj['training']['predictions']

array([[ 0.0225102 ,  0.2844953 , -0.38913348],
       [ 0.32076653, -1.26143853,  0.26610066],
       [ 0.30600883,  0.91358245, -0.92541244],
       [ 0.02548777, -1.20893355,  0.0407554 ],
       [ 0.78636933, -0.4630963 , -1.2049197 ],
       [ 0.57094902, -0.05928905,  0.80462959],
       [-0.02463037,  0.3215308 , -0.36589899],
       [-0.23332101, -0.86222919,  1.38787751],
       [ 0.15606663,  0.26391918,  0.62638861],
       [-1.85172169, -1.78494932, -0.12609008],
       [-0.88617999,  0.39155847, -0.52275136],
       [ 1.05669731,  0.26791219, -0.63639187],
       [-0.43703385,  0.80994754, -0.06623194],
       [-0.53928024,  0.09651836,  0.86194908],
       [-0.82681804,  0.48456305,  0.19771553],
       [-0.61022718,  0.34964982,  0.33698961],
       [ 0.48473722,  0.71827473, -0.31718151],
       [-0.38209656, -0.01369259,  0.31282726],
       [ 0.08990455,  1.592333  , -0.75574536],
       [ 0.63406058, -0.28052289,  0.61035364],
       [ 0.27016008, -0.01976338, -1.290

In [10]:
linear_mdl_obj['training']['targets']

array([[-0.96609945,  0.83033466, -0.93672136],
       [ 0.86707513, -0.77050018,  1.32625605],
       [-0.1210686 ,  0.68582214, -1.28454544],
       [-0.44648421, -1.42343085, -0.91763345],
       [ 1.35239927, -0.33941198, -1.17072495],
       [ 1.26422213, -0.39434773,  1.30999598],
       [ 0.23758838, -0.08992668, -0.99045028],
       [-0.91606243, -1.20473756,  1.53339518],
       [-1.24077821, -0.05108675, -0.56980937],
       [-1.72505262, -1.71490525, -0.576172  ],
       [-0.84678039,  0.71871361, -0.93318656],
       [ 1.47381778,  0.54200943, -0.83916094],
       [ 0.4573314 ,  0.57630054,  0.59596687],
       [-0.72676152, -0.42443993,  1.32342821],
       [-0.975547  , -0.36845444, -0.74089357],
       [ 0.01784537,  1.32195716,  0.65747235],
       [ 1.55184754,  1.49866134, -0.14775454],
       [-1.0511274 , -0.51541633, -0.41357132],
       [-0.01224682,  0.51086751, -0.1116996 ],
       [-0.49862069, -0.6427833 ,  0.10887177],
       [-0.79429401, -0.79744319, -1.466