This notebook creates a separate sub-folder for each (slow, fast) pair of half-lives and runs one set of simulations per pair.
Cells are sampled at the time points, and in the numbers, given by each monkey's sampling data.

In [39]:
import numpy as np
import random
import math
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 16})
from matplotlib import gridspec
from csv import reader
from collections import defaultdict
import os
import sys

In [40]:
def get_strains_num(num_cells, strains_prop):
    """Split `num_cells` cells across strains in proportion to `strains_prop`.

    Each strain first gets round(num_cells * prop) cells. Rounding can make the total
    differ from `num_cells`:
      * a shortfall is added to the strain with the largest proportion (the last such
        strain if several tie), as before;
      * an excess is removed one cell at a time from whichever strain currently holds
        the most cells, so no count is driven negative.

    Args:
        num_cells: target total number of cells.
        strains_prop: mapping strain -> proportion (presumably summing to ~1; not checked).

    Returns:
        dict strain -> integer cell count, in the key order of `strains_prop`
        (empty dict for empty `strains_prop`).
    """
    if not strains_prop:
        return {}

    strains_num = {strain: round(num_cells * prop) for strain, prop in strains_prop.items()}

    # Last strain with the maximal proportion receives any shortfall (matches original tie-break).
    max_prop = max(strains_prop.values())
    strain_with_max_prop = [strain for strain, prop in strains_prop.items() if prop == max_prop][-1]

    shortfall = num_cells - sum(strains_num.values())
    if shortfall > 0:
        strains_num[strain_with_max_prop] += shortfall
    # Rounding up overshot the target: trim one cell at a time from the currently largest strain.
    while shortfall < 0:
        largest = max(strains_num, key=strains_num.get)
        strains_num[largest] -= 1
        shortfall += 1

    return strains_num

In [41]:
def get_strains_num_to_decay(strains_num, cells_to_decay):
    """Count how many of the selected cells fall in each strain (histogram bin).

    Cells are numbered 0..total-1 by laying the strains out consecutively in the key
    order of `strains_num`: strain k owns the indices [cumsum[k-1], cumsum[k]).
    Each index in `cells_to_decay` is attributed to the strain whose range contains it.
    Strains with zero cells own an empty range and can never be selected.

    Args:
        strains_num: mapping strain -> non-negative integer cell count.
        cells_to_decay: sequence/array of integer cell indices in [0, total);
            order does not matter.

    Returns:
        dict strain -> number of selected cells in that strain. Every strain of
        `strains_num` is present (0 when none of its cells were selected), and the
        values sum to len(cells_to_decay).

    Raises:
        ValueError: if any index lies outside [0, total), i.e. cannot be attributed
            to a strain. (Previously this silently called sys.exit(1).)
    """
    strains = list(strains_num.keys())
    boundaries = np.cumsum(list(strains_num.values()))
    total_cells = int(boundaries[-1]) if len(boundaries) else 0
    cell_indices = np.asarray(cells_to_decay).ravel()

    if cell_indices.size and (cell_indices.min() < 0 or cell_indices.max() >= total_cells):
        raise ValueError(
            "cell indices must lie in [0, %d); got min %s, max %s"
            % (total_cells, cell_indices.min(), cell_indices.max())
        )

    # side='right': an index equal to a boundary belongs to the *next* strain, so
    # zero-count strains (repeated boundaries) are skipped automatically.
    strain_positions = np.searchsorted(boundaries, cell_indices, side='right')
    per_strain = np.bincount(strain_positions, minlength=len(strains))

    return {strain: int(count) for strain, count in zip(strains, per_strain)}

In [42]:
def get_divergence(muts, strain_counts):
    """Return the count-weighted mean of `muts`.

    Computes sum_i muts[i] * strain_counts[i] / sum(strain_counts), accumulating the
    numerator left to right from 0.0.

    Raises ZeroDivisionError when `strain_counts` sums to 0 (callers check for an empty
    pool before calling), and IndexError if `strain_counts` is shorter than `muts`.
    """
    weighted_sum = sum(
        (mut * strain_counts[position] for position, mut in enumerate(muts)),
        0.0,
    )
    return weighted_sum / sum(strain_counts)

In [43]:
def _read_hist_counts(hist_filename):
    """Read a space-delimited histogram file into {bin number (0-indexed): cell count}.

    The first row is a header and is skipped; blank rows are ignored. Each bin's count
    is the third column rounded to the nearest integer (Python `round`), with negative
    values clipped to 0.
    """
    counts = {}
    with open(hist_filename, 'r') as read_obj:
        csv_reader = reader(read_obj, delimiter = ' ')
        next(csv_reader)  # skip header
        for bin_num, row in enumerate(r for r in csv_reader if r):
            counts[bin_num] = max(0, round(float(row[2])))
    return counts


def get_inputs(monkey_ID):
    """Load one monkey's initial slow/fast histograms and its paired decay rates.

    Paths are built from module-level configuration:
      * <results_folder>/<monkey_ID>/<monkey_ID><slow_filename_suffix>: slow-decaying histogram
      * <results_folder>/<monkey_ID>/<monkey_ID><fast_filename_suffix>: fast-decaying histogram
      * <data_folder>/<half_lives_subfolder>/<monkey_ID><thalf_filename_suffix>: tab-separated
        half-lives. Rows whose first field starts with "slow" hold a slow half-life in
        weeks; every other non-blank row holds a fast half-life in days.

    Returns:
        (strains_num_slow, strains_num_fast, decay_rates_slow, decay_rates_fast).
        The first two map bin number -> cell count. The last two are lists, in file
        order, of per-day (= per-generation) decay rates ln(2) / t_half.
    """
    slow_filename = os.path.join(results_folder, monkey_ID, monkey_ID+slow_filename_suffix)
    fast_filename = os.path.join(results_folder, monkey_ID, monkey_ID+fast_filename_suffix)
    thalf_filename = os.path.join(data_folder, half_lives_subfolder, monkey_ID+thalf_filename_suffix)

    strains_num_slow = _read_hist_counts(slow_filename)
    strains_num_fast = _read_hist_counts(fast_filename)

    decay_rates_slow = []
    decay_rates_fast = []
    with open(thalf_filename, 'r') as read_obj:
        for row in reader(read_obj, delimiter = '\t'):
            if not row:
                continue  # tolerate blank lines (e.g. a trailing newline)
            if row[0].startswith("slow"):
                # Slow t1/2 is given in weeks -> convert to days (1 generation = 1 day)
                decay_rates_slow.append(math.log(2) / (float(row[1]) * 7))
            else:
                # Fast t1/2 is already in days
                decay_rates_fast.append(math.log(2) / float(row[1]))

    return strains_num_slow, strains_num_fast, decay_rates_slow, decay_rates_fast

In [44]:
def modify_initial_hist(strains_num_slow, strains_num_fast, scale_factor=100):
    """Scale both initial histograms by `scale_factor` (default 100), in place.

    Args:
        strains_num_slow: mapping bin -> cell count for slow-decaying cells (mutated).
        strains_num_fast: mapping bin -> cell count for fast-decaying cells (mutated).
        scale_factor: multiplier applied to every count.

    Returns:
        The same two dict objects, now scaled, as (strains_num_slow, strains_num_fast).
    """
    for strains_num in (strains_num_slow, strains_num_fast):
        for strain in strains_num:
            # Only values change, never keys, so mutating while iterating is safe.
            strains_num[strain] *= scale_factor

    return strains_num_slow, strains_num_fast

In [45]:
def get_sampling_details(monkey_ID):
    """Return this monkey's sampling schedule as {sampling day: number of sequences}.

    Reads <data_folder>/<sample_times_subfolder>/<monkey_ID>_<type_aorc><sampling_filename_suffix>,
    a space-delimited file with one header row. Column 2 is the sampling week (converted
    to days by multiplying by 7) and column 3 is the number of sequences drawn. If a week
    appears more than once, the last row for it wins.
    """
    sampling_filename = os.path.join(
        data_folder,
        sample_times_subfolder,
        monkey_ID + "_" + type_aorc + sampling_filename_suffix,
    )

    with open(sampling_filename, 'r') as sampling_file:
        rows = reader(sampling_file, delimiter = ' ')
        next(rows)  # header line
        return {int(fields[1]) * 7: int(fields[2]) for fields in rows}

In [46]:
# Initialize random number generators.
# The simulation draws from two RNGs: numpy's Generator (binomial number of decays per
# generation) and the stdlib `random` module (random.sample picks which cells decay or
# are sampled). Both are seeded from RANDOM_SEED so Restart & Run All reproduces every
# output file. Set RANDOM_SEED = None to draw fresh, non-reproducible entropy instead.
RANDOM_SEED = 20240101
random.seed(RANDOM_SEED)
generator = np.random.default_rng(RANDOM_SEED)

In [47]:
# Set up overall parameters
num_weeks = 150                  # ~3 years of simulated time
num_generations = num_weeks * 7  # 1 generation = 1 day -> 1050 generations
recording_interval = 1           # write one output row every `recording_interval` generations
num_runs = 1000                  # stochastic replicates per (slow, fast) half-life pair

# Project root. Override with the SIV_PROJECT_DIR environment variable on another machine,
# e.g. /Users/nsambaturu/Documents/Narmada_projects/SIV_divergence_decline_again_2024
project_folder = os.environ.get(
    "SIV_PROJECT_DIR",
    "C:\\Users\\361998\\Documents\\Narmada_projects\\SIV_divergence_decline_again_2024",
)
data_folder = os.path.join(project_folder, "data")
results_folder = os.path.join(project_folder, "results")

half_lives_subfolder = "half_lives"
sample_times_subfolder = "sample_times"
muts_filename = "muts.txt"
# Restrict this list (e.g. ["T523"]) for quick test runs.
monkey_IDs = ["T523", "T530", "T537", "T544", "T545", "T623", "T624", "T625", "T627", "T628"]
type_aorc = "DNA0_DNArest"
slow_filename_suffix = "_"+type_aorc+"_slow_hist_3years.txt"
fast_filename_suffix = "_"+type_aorc+"_fast_hist_3years.txt"
thalf_filename_suffix = "_half_lives.txt"
sampling_filename_suffix = "_sample_times.txt"

In [48]:
# Create one output folder per monkey: <results_folder>/<monkey_ID>_<type_aorc>.
# exist_ok=True keeps existing folders and avoids the exists-then-create race.
for monkey_ID in monkey_IDs:
    os.makedirs(os.path.join(results_folder, monkey_ID + "_" + type_aorc), exist_ok=True)

In [49]:
# Get list of mutations (or substitutions per site), one value per histogram bin.
# muts.txt is tab-separated with no header; the value is taken from the second column.
with open(os.path.join(data_folder, muts_filename), 'r') as muts_file:
    muts = [float(fields[1]) for fields in reader(muts_file, delimiter = '\t')]
print(muts)

[0.0005, 0.0015, 0.0025, 0.0035, 0.0045, 0.0055, 0.0065, 0.0075, 0.0085, 0.0095, 0.0105, 0.0115, 0.0125, 0.0135, 0.0145, 0.0155, 0.0165, 0.0175, 0.0185, 0.0195]


In [50]:
def _format_group(counts):
    """Format one group of per-bin counts as '<c0>\t...\t<cN>\t<total>\t<divergence>\t'.

    Divergence is get_divergence(muts, counts), or 0.0 when the group is empty.
    """
    values = list(counts.values())
    total = sum(values)
    divergence = 0.0 if total == 0 else get_divergence(muts, values)
    return "\t".join(str(count) for count in values) + "\t" + str(total) + "\t" + str(divergence) + "\t"


def _format_header(strains):
    """Header line: Generation, then per-bin/total/divergence columns for slow, fast, total, sampled."""
    columns = ["Generation"]
    for suffix, total_name, divergence_name in (
        ("_slow", "total_slow", "Divergence_slow"),
        ("_fast", "total_fast", "Divergence_fast"),
        ("_total", "total", "Divergence_total"),
        ("_sampled", "total_sampled", "Divergence_sample"),
    ):
        columns.extend(str(strain) + suffix for strain in strains)
        columns += [total_name, divergence_name]
    return "\t".join(columns) + "\n"


def _format_row(gen, slow, fast, combined, sampled):
    """One output row for generation `gen`, in the column order of _format_header."""
    return str(gen) + "\t" + "".join(_format_group(group) for group in (slow, fast, combined, sampled)) + "\n"


def _combine_pools(slow, fast):
    """Per-bin sum of slow- and fast-decaying cells (bins taken from the slow histogram)."""
    return {strain: slow[strain] + fast[strain] for strain in slow}


def _sample_pool(strain_counts, num_to_sample):
    """Draw min(num_to_sample, total) distinct cells uniformly from the pool; return per-bin counts."""
    total_cells = sum(strain_counts.values())
    num_to_sample = min(num_to_sample, total_cells)
    cells_to_sample = np.sort(random.sample(range(total_cells), num_to_sample))
    return get_strains_num_to_decay(strain_counts, cells_to_sample)


def _apply_decay(strains_num, decay_rate):
    """Remove Binomial(n_cells, decay_rate) cells, chosen uniformly without replacement, in place."""
    num_cells = sum(strains_num.values())
    num_decays = generator.binomial(num_cells, decay_rate)
    cells_to_decay = np.sort(random.sample(range(num_cells), num_decays))
    for strain, num_to_decay in get_strains_num_to_decay(strains_num, cells_to_decay).items():
        strains_num[strain] -= num_to_decay
        if strains_num[strain] < 0:
            raise RuntimeError("Cell count for bin %s went negative" % strain)


def _save_hist_plot(counts, color, label, monkey_ID, out_folder):
    """Save a bar chart of a per-bin histogram as <monkey_ID>_<type_aorc>_<label>.png."""
    fig, ax = plt.subplots(figsize=(8, 6), dpi=80)
    ax.bar(range(len(counts)), list(counts.values()), color = color)
    ax.set_title(monkey_ID + "_" + label)
    ax.set_xlabel("Bin")
    ax.set_ylabel("Number of cells")
    figure_filename = monkey_ID + "_" + type_aorc + "_" + label + ".png"
    fig.savefig(os.path.join(out_folder, figure_filename), bbox_inches = "tight")
    print("Figure saved as ", figure_filename)
    plt.close(fig)


def _simulate_run(output_filename, initial_slow, initial_fast, decay_rate_slow, decay_rate_fast, sampling_details):
    """Run one stochastic decay trajectory and write it to `output_filename`.

    Each generation, slow then fast cells decay. The combined pool is then sampled with
    sampling_details.get(gen, 0) cells. A row is written every `recording_interval`
    generations and at every sampling time, so sampled data is never dropped. The
    generation in which the last cell decays is recorded (a row of zeros) before the
    run stops. An initially empty pool writes the generation-0 row and stops, rather
    than aborting every remaining run.
    """
    # Work on copies so each run starts from the same initial histograms without re-reading files.
    strains_num_slow = dict(initial_slow)
    strains_num_fast = dict(initial_fast)

    with open(output_filename, 'w') as out_f:
        out_f.write(_format_header(strains_num_slow.keys()))

        # There is always a sample drawn at time 0 (KeyError if the sampling file lacks one).
        strain_counts = _combine_pools(strains_num_slow, strains_num_fast)
        sampled_strains = _sample_pool(strain_counts, sampling_details[0])
        out_f.write(_format_row(0, strains_num_slow, strains_num_fast, strain_counts, sampled_strains))

        for gen in range(1, num_generations+1):  # gen 0 is the initial condition
            if sum(strains_num_slow.values()) <= 0 and sum(strains_num_fast.values()) <= 0:
                print("No more cells to decay")
                break

            _apply_decay(strains_num_slow, decay_rate_slow)
            _apply_decay(strains_num_fast, decay_rate_fast)

            strain_counts = _combine_pools(strains_num_slow, strains_num_fast)
            sampled_strains = _sample_pool(strain_counts, sampling_details.get(gen, 0))

            if gen % recording_interval == 0 or gen in sampling_details:
                out_f.write(_format_row(gen, strains_num_slow, strains_num_fast, strain_counts, sampled_strains))


for monkey_ID in monkey_IDs:
    print(monkey_ID)

    initial_slow, initial_fast, decay_rates_slow, decay_rates_fast = get_inputs(monkey_ID)
    print("Decay rates")
    print(decay_rates_slow)
    print(decay_rates_fast)
    initial_slow, initial_fast = modify_initial_hist(initial_slow, initial_fast)
    print("After modifications")
    print(initial_slow)
    print(initial_fast)
    sampling_details = get_sampling_details(monkey_ID)

    num_cd4_total = sum(initial_slow.values()) + sum(initial_fast.values())
    print("Total = ", num_cd4_total)

    monkey_folder = os.path.join(results_folder, monkey_ID+"_"+type_aorc)
    os.makedirs(monkey_folder, exist_ok=True)

    # Plot modified distributions
    _save_hist_plot(initial_slow, '#B8E0D2', "slow", monkey_ID, monkey_folder)
    _save_hist_plot(initial_fast, '#FFC09F', "fast", monkey_ID, monkey_folder)

    # Slow and fast half-lives are paired by position in the half-lives file.
    if len(decay_rates_slow) != len(decay_rates_fast):
        raise ValueError(
            "%s: %d slow vs %d fast half-lives; they must be paired one-to-one"
            % (monkey_ID, len(decay_rates_slow), len(decay_rates_fast))
        )

    # Run one set of simulations for each pair of decay rates
    for rate_index, (decay_rate_slow, decay_rate_fast) in enumerate(zip(decay_rates_slow, decay_rates_fast)):
        print("Rate index ", rate_index)
        print("Decay rate slow = ", decay_rate_slow)
        print("Decay rate fast = ", decay_rate_fast)
        print("")

        sub_folder = "slow_"+("%.4f" % decay_rate_slow)+"_fast_"+("%.4f" % decay_rate_fast)
        print_folder = os.path.join(monkey_folder, sub_folder)
        os.makedirs(print_folder, exist_ok=True)

        for run in range(num_runs):
            output_filename = os.path.join(print_folder, monkey_ID+"_"+type_aorc+"_"+str(run)+".txt")
            _simulate_run(output_filename, initial_slow, initial_fast,
                          decay_rate_slow, decay_rate_fast, sampling_details)

T523
Decay rates
[0.002829172165550797, 0.003414518130837169]
[0.20752909597603153, 0.2567211779851649]
After modifications
{0: 0, 1: 0, 2: 300, 3: 400, 4: 300, 5: 100, 6: 200, 7: 200, 8: 0, 9: 0, 10: 0, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0, 16: 0, 17: 0, 18: 0, 19: 0}
{0: 0, 1: 100, 2: 0, 3: 0, 4: 0, 5: 400, 6: 500, 7: 400, 8: 1300, 9: 500, 10: 100, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0, 16: 0, 17: 0, 18: 0, 19: 0}
Total =  4800
Figure saved as  T523_DNA0_DNArest_slow.png
Figure saved as  T523_DNA0_DNArest_fast.png
Rate index  0
Decay rate slow =  0.002829172165550797
Decay rate fast =  0.20752909597603153

Rate index  1
Decay rate slow =  0.003414518130837169
Decay rate fast =  0.2567211779851649

T530
Decay rates
[0.002829172165550797, 0.0021526309955277804]
[0.20752909597603153, 0.5776226504666211]
After modifications
{0: 0, 1: 200, 2: 400, 3: 800, 4: 800, 5: 200, 6: 800, 7: 0, 8: 200, 9: 600, 10: 400, 11: 200, 12: 0, 13: 0, 14: 0, 15: 0, 16: 0, 17: 0, 18: 0, 19: 0}
{0: 0, 1: 100, 2: 0, 3