In [2]:
from collections import Counter, defaultdict
from pathlib import Path

import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, fcluster
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline

from utils.conversion import ms_to_numpy

# Inferring selection strength

Using the method of [Messer & Neher (2012)](https://doi.org/10.1534/genetics.112.138461).

## Empirical window

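We choose the window length $L$ such that $u/s = L(\mu + r)/s \approx 0.1$, where $\mu$ and $r$ are the mutation and recombination rates (`mut_rate`, `rec_rate`) and $u = L(\mu + r)$.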

In [3]:
# Range of selection coefficients s the window should be informative for.
min_s, max_s = 0.01, 1.0

# Mutation and recombination rates supplied by the Snakemake rule's params.
mut_rate, rec_rate = snakemake.params["mut_rate"], snakemake.params["rec_rate"]

In [4]:
# Window length L at which u/s = L(mu + r)/s reaches 0.1, for both ends of the s range.
total_rate = mut_rate + rec_rate
L_min, L_max = (int(0.1 * s_bound / total_rate) for s_bound in (min_s, max_s))
for s_bound, L_needed in zip((min_s, max_s), (L_min, L_max)):
    print(f'To get u/s = 0.1 with s={s_bound}, we need L={L_needed}')

In [5]:
# The window length L arrives as a Snakemake wildcard (a string); convert it to int.
raw_window_size = snakemake.wildcards["window_size"]
window_size = int(raw_window_size)

In [6]:
# Combined mutation + recombination rate over the whole window: u = L(mu + r).
total_rate = mut_rate + rec_rate
u = window_size * total_rate
print(f"With a window size (L) of {window_size}, we get u={u}")

Get a window in the center of the region:

In [7]:
# Load site positions and the haplotype matrix, then centre a window of
# length window_size on the midpoint of the observed positions.
ms_path = snakemake.input["ms"]
with open(ms_path) as ms_file:
    positions, haplotypes = ms_to_numpy(ms_file)

region_center = (positions.min() + positions.max())/2
half_window = window_size/2
win_start, win_end = region_center - half_window, region_center + half_window

Subset the haplotypes to the window:

In [8]:
# Keep only sites strictly inside (win_start, win_end). Rows of `haplotypes`
# are sites, so the row selection keeps every sample (column).
in_window = (positions > win_start) & (positions < win_end)
indices = np.nonzero(in_window)[0]
subpositions, subwindow = positions[indices], haplotypes[indices, :]

## Haplotype frequency spectrum

In [9]:
# Encode each sample's alleles in the window (one column of `subwindow`) as a
# string, count identical strings, and rank distinct haplotypes by abundance.
hap_strings = ["".join(map(str, column)) for column in subwindow.T]
hap_counts = Counter(hap_strings)
haps_by_frequency = sorted(hap_counts, key=hap_counts.get, reverse=True)

How many unique haplotypes are there?

In [10]:
# Number of distinct haplotypes in the window.
n_unique_haps = len(hap_counts)
n_unique_haps

Plot the haplotype distance matrix, ordered by most common haplotype:

In [11]:
# Pairwise Hamming distances between distinct haplotypes (most frequent first).
# pdist returns the fraction of differing sites; scale by the number of sites
# in the window to get the number of differing sites.
n_sites = len(subpositions)
hap_arrays_by_frequency = list(map(list, haps_by_frequency))
prop_diff_sites = pdist(hap_arrays_by_frequency, metric="hamming")
diff_sites = squareform(n_sites * prop_diff_sites)

In [12]:
# Heatmap of the number of differing sites between every pair of distinct
# haplotypes, ordered from most to least frequent.
fig, ax = plt.subplots()
im = ax.imshow(diff_sites)
cbar = fig.colorbar(im, ax=ax)
cbar.set_label("Number of differing sites")
ax.set_xlabel("Haplotype (most to least frequent)")
ax.set_ylabel("Haplotype (most to least frequent)")
ax.set_title("Pairwise distances between distinct haplotypes")
plt.show()

## Split into sweep components

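Cluster the distinct haplotypes (single linkage on their Hamming distances) to find groups of related haplotypes. The distance threshold $t$ is raised in steps of 0.01 until no cluster consists of a single distinct haplotype, which gives sensible clusters without hand-tuning $t$.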

In [13]:
# Raise the single-linkage distance threshold t in fixed steps until no flat
# cluster contains only one distinct haplotype. Hamming proportions are <= 1,
# so once t reaches the largest merge height everything is one cluster and
# the loop stops.
t_step = 0.01
iteration = 0
t = 0.0
if len(haps_by_frequency) < 2:
    # linkage() needs at least two observations. A window with no variable
    # sites yields a single distinct haplotype; treat it as one cluster
    # instead of crashing.
    hap_clusters = np.ones(len(haps_by_frequency), dtype=int)
    clust_info = np.unique(hap_clusters, return_counts=True)
else:
    clustering = linkage(prop_diff_sites, method='single')
    min_hap = 1
    while min_hap == 1:
        iteration += 1
        # Derive t from the step count instead of repeatedly adding 0.01,
        # which accumulates floating-point drift in the threshold.
        t = round(iteration * t_step, 10)
        hap_clusters = fcluster(clustering, t=t, criterion='distance')
        clust_info = np.unique(hap_clusters, return_counts=True)
        min_hap = min(clust_info[1])

print(f"Clustering iteration {iteration} with t={t}\n")
for clust, clust_size in zip(*clust_info):
    print(f"Cluster {clust}: {clust_size} haplotypes")

Order haplotypes by their cluster, then by their frequency.

In [14]:
# Group the frequency-sorted haplotypes by cluster label. Clusters keep the
# order in which they are first seen (i.e. by their most frequent haplotype),
# and within a cluster haplotypes stay in descending frequency order.
ordered = defaultdict(list)
for hap, cluster in zip(haps_by_frequency, hap_clusters):
    ordered[cluster].append(hap)

haps_by_cluster = [hap for cluster_haps in ordered.values() for hap in cluster_haps]

In [15]:
# Pairwise Hamming distances between distinct haplotypes, with rows/columns in
# cluster order, scaled from fractions to numbers of differing sites.
n_sites = len(subpositions)
hap_arrays_by_cluster = list(map(list, haps_by_cluster))
prop_diff_sites_cluster = pdist(hap_arrays_by_cluster, metric="hamming")
diff_sites = squareform(n_sites * prop_diff_sites_cluster)

In [16]:
# Heatmap of pairwise differing sites in cluster order; saved as the
# "clusters" output of the Snakemake rule.
fig, ax = plt.subplots()
im = ax.imshow(diff_sites)
cbar = fig.colorbar(im, ax=ax)
cbar.set_label("Number of differing sites")
ax.set_xlabel("Haplotype (by cluster, then frequency)")
ax.set_ylabel("Haplotype (by cluster, then frequency)")
ax.set_title("Pairwise distances between distinct haplotypes, by cluster")
fig.savefig(snakemake.output["clusters"])
plt.show()

## Estimate selection strength

We will estimate one $s$ per cluster, even though some clusters might just be neutral.

In [17]:
# How many different haplotypes do we consider per cluster?
num_most_abundant_unique_haps = int(snakemake.wildcards["haps_per_cluster"])
if num_most_abundant_unique_haps < 1:
    # most_common(<=0) returns [] and counts[0] below would fail obscurely.
    raise ValueError(
        f"haps_per_cluster must be a positive integer, got {num_most_abundant_unique_haps}"
    )

u = window_size*(mut_rate + rec_rate)

# Haplotype strings longer than this are truncated (ending in '...') in the report.
max_hap_chars = 35

result_dfs = []

with open(snakemake.output["estimate"], 'w') as f:
    f.write(f"Window size: {window_size}\n")
    f.write(f"Haplotypes per cluster: {num_most_abundant_unique_haps}\n\n")
    for cluster, haps in ordered.items():

        cluster_name = f'Cluster {cluster}'
        f.write(f'{cluster_name}\n' + '-'*len(cluster_name) + '\n\n')

        # The (at most) num_most_abundant_unique_haps most abundant distinct
        # haplotypes of this cluster, with their counts in the sample.
        haps_subset = Counter({hap: hap_counts[hap] for hap in haps})
        most_common = haps_subset.most_common(num_most_abundant_unique_haps)

        for hap, count in most_common:
            hap_string = hap
            if len(hap_string) > max_hap_chars:
                hap_string = hap_string[:max_hap_chars - 3] + '...'
            f.write(f"{hap_string}\tx{count}\n")
        f.write('\n')

        counts = sorted((count for _, count in most_common), reverse=True)

        n_0 = counts[0]    # count of the most abundant haplotype
        n_c = min(counts)  # count of the least abundant haplotype considered
        i_c = len(counts)  # number of haplotypes considered

        # Selection-strength estimate following Messer & Neher (2012); the
        # relative error is taken as 1/sqrt(i_c).
        s = (u/i_c)*(n_0/n_c)**(1 + (n_c*i_c)/n_0)
        error = s*(1/np.sqrt(i_c))

        f.write(f"i_c = {i_c} haplotypes with an abundance of at least n_c = {n_c}.\n")
        if i_c < num_most_abundant_unique_haps:
            f.write(f"This cluster doesn't have at least {num_most_abundant_unique_haps} haplotypes.\n")
        elif n_c == 1:
            # A singleton among the top haplotypes: no estimate is recorded.
            f.write("This cluster is probably neutral.\n")
        else:
            f.write(f"Estimated s = {s} +/- {error}\n")
            result_dfs.append(pd.DataFrame({
                'estimated_s': s,
                'estimated_error': error,
                'cluster': cluster,
                'counts': counts
            }))

        f.write('\n\n')

In [18]:
# Combine the per-cluster estimates. When no cluster produced an estimate,
# fall back to an empty frame that has only a 'cluster' column.
if result_dfs:
    df = pd.concat(result_dfs)
else:
    df = pd.DataFrame({'cluster': []})

## Plot haplotype frequency spectrum

In [19]:
u = window_size*(mut_rate + rec_rate)

outfolder = Path(snakemake.output["hfs_folder"])
outfolder.mkdir(parents=True, exist_ok=True)

# One observed-vs-expected haplotype frequency spectrum plot per estimated cluster.
for cluster in df.cluster.unique():

    data = df.loc[df.cluster == cluster]
    # Positional access: the concatenated frame repeats index labels across
    # clusters, so lookups by label 0 are fragile.
    y = data.counts.to_numpy()
    s = data.estimated_s.iloc[0]
    ranks = np.arange(len(y))
    observed_fracs = y / y[0]

    # Expected n_i/n_0 = (u/(i*s))**(1 - u/s) for ranks i >= 1; rank 0 is 1.
    i = np.arange(1, len(y))
    beta = 1 - (u/s)
    expected_fracs = [1] + list((u/(i*s))**beta)

    fig, ax = plt.subplots()
    ax.plot(ranks, observed_fracs, label='Observed')
    ax.scatter(ranks, observed_fracs)
    ax.plot(ranks, expected_fracs, label='Expected')
    ax.scatter(ranks, expected_fracs)
    ax.set_yscale('log')
    ax.legend()
    ax.set_xlabel('Haplotype rank i')
    ax.set_ylabel('n_i/n_0')
    ax.set_title(f"HFS for cluster {cluster}")

    fig.savefig(outfolder/f"cluster-{cluster}.pdf")
    plt.show()
    # Release the figure so many clusters don't accumulate open figures.
    plt.close(fig)