In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [2]:
# draw from the prior
u_1 = np.random.uniform(0,1)
u_2 = np.random.uniform(0,1)
# gamma + alpha = 1 - u_1
gamma = u_2 * (1 - u_1) / 2 # ensure that alpha > gamma
alpha = (1 - u_1) - gamma
beta = u_1

In [3]:
gamma, alpha, beta

(0.07536668596951239, 0.09568824157375029, 0.8289450724567373)

In [4]:
def summary_statistic(cluster_df):
    g = len(cluster_df['Cluster Size'])
    sum_population = (cluster_df['Cluster Size'] * cluster_df['Number of Clusters']).sum()
    y1 = g / sum_population
    y2 = 1 - ((cluster_df['Number of Clusters'] / sum_population) ** 2).sum()
    return np.array([y1, y2])


def subsample_population(population_df, subsample_size=473):
    # Expand the population DataFrame into individual bacteria
    expanded_population = population_df.loc[
        population_df.index.repeat(population_df['count'])
    ].reset_index(drop=True)

    # Randomly subsample
    subsample = expanded_population.sample(n=subsample_size, replace=False)

    # Reconstruct the subsample cluster sizes
    subsample_cluster_sizes = subsample.groupby('genotype').size()

    # Convert back to a DataFrame
    cluster_df = subsample_cluster_sizes.value_counts().reset_index()
    cluster_df.columns = ['Cluster Size', 'Number of Clusters']

    return cluster_df


def simulate_data(alpha, beta, gamma):
    N_target = 1e4  # Target population size

    # Initialize population
    population = {0: 1}  # Start with one bacterium of genotype 0
    N = 1  # Current population size
    time = 0  # Simulation time
    times = [0]  # Record times
    sizes = [1]  # Record total population sizes

    # Simulation loop
    while N < N_target:
        # Compute total rate
        rates = {genotype: (alpha + gamma + beta) * count for genotype, count in population.items()}
        total_rate = sum(rates.values())

        # If total rate is zero, population is extinct
        if total_rate == 0 or N == 0:
            print('extinct!!')
            return np.array([0, 1])  # Return default summary statistic

        # Sample time to next event
        delta_t = np.random.exponential(1 / total_rate)
        time += delta_t

        # Determine which event occurs
        event = np.random.uniform(0, total_rate)
        cumulative_rate = 0
        for genotype, rate in rates.items():
            cumulative_rate += rate
            if event < cumulative_rate:
                selected_genotype = genotype
                break

        # Determine event type (replication, death, mutation)
        event_type = np.random.choice(['replication', 'death', 'mutation'],
                                      p=[alpha, gamma, beta])
        if event_type == 'replication':
            population[selected_genotype] += 1
            N += 1
        elif event_type == 'death':
            population[selected_genotype] -= 1
            if population[selected_genotype] == 0:
                del population[selected_genotype]
            N -= 1
        elif event_type == 'mutation':
            population[selected_genotype] -= 1
            if population[selected_genotype] == 0:
                del population[selected_genotype]
            # Create new genotype
            new_genotype = max(population.keys(), default=0) + 1
            population[new_genotype] = 1

        # Record data
        times.append(time)
        sizes.append(N)

    # Create population DataFrame
    population_df = pd.DataFrame(
        population.items(),
        columns=['genotype', 'count']
    )

    cluster_df = subsample_population(population_df)
    y = summary_statistic(cluster_df)
    
    return y


In [5]:
# Given data
cluster_sizes = [1, 2, 3, 4, 5, 8, 10, 15, 23, 30]
number_of_clusters = [282, 20, 13, 4, 2, 1, 1, 1, 1, 1]


true_cluster_df = pd.DataFrame({
    'Cluster Size': cluster_sizes,
    'Number of Clusters': number_of_clusters
})

In [6]:
true_summary_statistics = summary_statistic(true_cluster_df)

In [7]:
true_summary_statistics

array([0.02114165, 0.64189712])

In [8]:
from scipy.stats import multivariate_normal

def ABC_sampler(epsilon, true_summary_statistics):
    # Multivariate Gaussian parameters
    mean = [0.6, 0.2]
    cov = [[0.007, -0.008], [-0.008, 0.01]]

    while True:
        alpha, gamma = np.random.multivariate_normal(mean, cov)
        if alpha > 0 and gamma > 0 and alpha + gamma < 1 and alpha > gamma :
            break
    beta = 1 - alpha - gamma
    statistic_proposal = simulate_data(alpha, beta, gamma)
    distance = np.linalg.norm(statistic_proposal - true_summary_statistics, ord=2)
    proposal_pdf = multivariate_normal.pdf([alpha, gamma], mean=mean, cov=cov)
    weight = 1 / proposal_pdf
    
    if distance <= epsilon: 
        zero_weight = False
    else:
        zero_weight = True
    return alpha, gamma, weight, zero_weight, statistic_proposal, distance

In [None]:
from joblib import Parallel, delayed
import tqdm

n_samples = 50000
epsilon = 0.2


def worker(_):
    return ABC_sampler(epsilon, true_summary_statistics)

results = Parallel(n_jobs=-1)(delayed(worker)(i) for i in tqdm.tqdm(range(n_samples)))

# Unpack results
alphas, gammas, weights, zero_weights, statistic_proposals, distances = zip(*[(alpha, gamma, weight, zero_weights, statistic_proposal, distance) for alpha, gamma, weight, zero_weights, statistic_proposal, distance in results])

# Convert to lists (optional)
alphas = list(alphas)
gammas = list(gammas)
weights = list(weights)
zero_weights = list(zero_weights)
statistic_proposals = list(statistic_proposals)
distances = list(distances)

print("Sampling completed!")



  0%|          | 180/50000 [00:46<4:46:17,  2.90it/s]

In [None]:
import pickle

results = {
    'alphas': alphas,
    'gammas': gammas,
    'weights': weights,
    'zero_weights': zero_weights,
    'summary_satistics' : statistic_proposals,
    'distance' : distances,
}

# Save to a pkl file
with open('abc_results.pkl', 'wb') as file:
    pickle.dump(results, file)

print("Results saved to abc_results.pkl")