Inferring demography and mutation spectrum history for human populations
==
Run mushi on 3-SFSs computed from 1000 Genomes Project (1KG) data

In [None]:
import mushi

import msprime
import stdpopsim

import numpy as np
import scipy

import matplotlib as mpl
from collections import OrderedDict
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import pandas as pd
import seaborn as sns
import pickle
import copy
import glob
import os

import sklearn
import umap
import tensorly.tenalg as ta
import tensorly.decomposition as td

In [None]:
# path to 3-SFS data previously computed with 1KG pipeline
data_dir = '../example_data'

# enable LaTeX rendering in plots (requires a LaTeX installation) and load amsmath.
# Matplotlib >= 3.3 expects the preamble as a single string; the list form is deprecated.
mpl.rc('text', usetex=True)
mpl.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

## Load data
### Parse population manifest

In [None]:
# Map each super-population to its populations, parsed from 3-SFS filenames of
# the form 3-SFS.<super_pop>.<pop>.tsv
pops = {}
for path in glob.glob(f'{data_dir}/3-SFS.*.tsv'):
    super_pop, pop = os.path.basename(path).split('.')[1:3]
    members = pops.setdefault(super_pop, [])
    if pop not in members:
        members.append(pop)

# deterministic order: super-populations and their members sorted alphabetically
pops = OrderedDict(
    (super_pop, sorted(pops[super_pop])) for super_pop in sorted(pops)
)
pops

### Load 1KG 3-SFSs

In [None]:
# Canonical column order for triplet mutation types: ancestral center base
# (A or C), then derived center base, then 5' and 3' flanking bases.
sorted_triplets = [f'{five}{anc}{three}>{five}{der}{three}'
                   for anc in 'AC'
                   for der in 'ACGT' if der != anc
                   for five in 'ACGT'
                   for three in 'ACGT']

ksfs_dict = {}
plt.figure(figsize=(3, 3))
for color_idx, super_pop in enumerate(pops):
    for member_idx, pop in enumerate(pops[super_pop]):
        ksfs = mushi.kSFS(file=f'{data_dir}/3-SFS.{super_pop}.{pop}.tsv')
        ksfs_dict[pop] = ksfs
        # only the first population of each super-population gets a legend entry
        legend_label = super_pop if member_idx == 0 else None
        ksfs.plot_total(kwargs=dict(ls='', alpha=0.75, marker='o', ms=5, mfc='none',
                                    label=legend_label, c=f'C{color_idx}'))
        plt.xscale('log')
        plt.yscale('log')
        # reorder SFS columns to follow sorted_triplets; reindex yields the
        # reordered types and a column indexer (assumes a pandas Index)
        reordered_types, column_order = ksfs.mutation_types.reindex(sorted_triplets)
        ksfs.mutation_types = reordered_types
        ksfs.X = ksfs.X[:, column_order]
plt.legend();

#### Number of segregating variants in each population

In [None]:
# total number of segregating variants per population
records = [(super_pop, pop, ksfs_dict[pop].X.sum())
           for super_pop in pops
           for pop in pops[super_pop]]
df = pd.DataFrame(records, columns=('superpop', 'pop', 'segregating sites'))
plt.figure(figsize=(12, 5))
# one bar per population, colored by super-population
sns.barplot(data=df, x='pop', y='segregating sites', hue='superpop', dodge=False);

### Triplet target sizes
Load masked genome size (also previously computed with 1KG pipeline)

In [None]:
# per-triplet target sizes, indexed by triplet context with a single 'count' column
masked_genome_size = pd.read_csv(f'{data_dir}/masked_size.tsv',
                                 sep='\t',
                                 header=None,
                                 index_col=0,
                                 names=('count',))

# bar chart of the target size of each triplet context
plt.figure(figsize=(6, 3))
sns.barplot(data=masked_genome_size.T)
plt.xticks(rotation=90);

## Define a few parameters

### Frequency masking
Clip the highest-frequency SFS bins, which are susceptible to ancestral state misidentification.

In [None]:
clip_low = 0
clip_high = 10
# A separate mask is needed per population because the number of haplotypes n
# (and therefore the SFS length, n - 1 bins) varies between populations.
# Bin i is kept iff clip_low <= i < n - clip_high - 1, i.e. the first
# `clip_low` and the last `clip_high` bins are masked out.
freq_mask = {}
for super_pop in pops:
    for pop in pops[super_pop]:
        n_hap = ksfs_dict[pop].n
        sfs_bins = np.arange(n_hap - 1)
        freq_mask[pop] = (clip_low <= sfs_bins) & (sfs_bins < n_hap - clip_high - 1)

### Time discretization
Time grid of epoch boundaries (measured in generations)

In [None]:
# 200 log-spaced epoch boundaries from 1 to 200,000 generations
change_points = np.logspace(0, np.log10(2e5), num=200)

### Total mutation rate
Mutation rate per site per generation (Scally, 2016)

In [None]:
# per-site, per-generation mutation rate (Scally, 2016)
u = 1.25e-8

Mutation rate per masked genome per generation

In [None]:
# per-generation mutation rate for the whole masked genome:
# per-site rate times the total count summed over all triplet contexts
total_sites = masked_genome_size['count'].sum()
mu0 = u * total_sites
mu0

### Generation time
Generation time for time calibration (Fenner, 2005)

In [None]:
# generation time in years (Fenner, 2005); multiplies generation-scaled times to give years in plots
t_gen = 29

## Infer effective population size history $\eta(t)\equiv 2N(t)$

### Regularization parameters and convergence criteria

In [None]:
# penalty weights for eta(t): total variation and spline smoothing
regularization_eta = {'alpha_tv': 1e2, 'alpha_spline': 3e3}
# optimizer stopping criteria, passed to every infer_history call below
convergence = {'tol': 1e-12, 'max_iter': 2000}

### Loop over populations, inferring history for each
- For the first population (YRI) we set `eta_ref = None` to use the default constant MLE as the reference history.
- For the other populations, we use the YRI history as the reference

In [None]:
# The first population visited (the alphabetically last member of the
# alphabetically first super-population) is fit with a negligible ridge
# penalty against the default reference; its history then becomes the ridge
# reference (eta_ref) for every other population.
eta_ref = None

for idx_super, super_pop in enumerate(sorted(pops)):
    for idx, pop in enumerate(reversed(pops[super_pop])):
        print(super_pop, pop)
        ksfs = ksfs_dict[pop]
        # clear previous solutions so the cell can be rerun
        ksfs.clear_eta()
        ksfs.clear_mu()
        is_reference = idx_super == 0 and idx == 0
        alpha_ridge = 1e-10 if is_reference else 1e3
        ksfs.infer_history(change_points, mu0,
                           eta_ref=eta_ref, alpha_ridge=alpha_ridge,
                           infer_mu=False, loss='prf',
                           mask=freq_mask[pop],
                           **regularization_eta, **convergence)
        if is_reference:
            eta_ref = ksfs.eta

### Plot histories
Plot results separately for each superpopulation

In [None]:
# One row per super-population: left column shows the total SFS with its fit,
# right column the inferred eta(t) on a time axis scaled by t_gen.
n_super = len(pops)
fig, axes = plt.subplots(n_super, 2, sharex='col', figsize=(6, 1.7 * n_super), squeeze=False)
for idx_super, super_pop in enumerate(pops):
    is_bottom_row = idx_super == n_super - 1
    ax_sfs, ax_eta = axes[idx_super]
    for idx, pop in enumerate(pops[super_pop]):
        color = f'C{idx}'

        plt.sca(ax_sfs)
        ksfs_dict[pop].plot_total(
            kwargs=dict(ls='', alpha=0.5, marker='o', ms=5, mfc='none', c=color, label=pop, rasterized=True),
            line_kwargs=dict(c=color, ls=':', marker='.', ms=3, alpha=0.5, lw=1, rasterized=True),
            fill_kwargs=dict(color=color, alpha=0))
        plt.xscale('log')
        plt.yscale('log')
        plt.legend(fontsize=6, loc='upper right')
        if not is_bottom_row:
            plt.xlabel(None)

        plt.sca(ax_eta)
        ksfs_dict[pop].eta.plot(t_gen=t_gen, lw=2, label=pop, alpha=0.75, c=color)
        plt.xlim([1e3, 1e6])
        plt.legend(fontsize=6, loc='upper right')
        if not is_bottom_row:
            plt.xlabel(None)
plt.tight_layout();

Plot all demographies on the same axes

In [None]:
# overlay every population's eta(t), one color per super-population
fig = plt.figure(figsize=(6, 1.7 * len(pops)))
for color_idx, super_pop in enumerate(sorted(pops)):
    for member_idx, pop in enumerate(reversed(pops[super_pop])):
        legend_label = super_pop if member_idx == 0 else None
        ksfs_dict[pop].eta.plot(t_gen=t_gen, lw=3, label=legend_label, alpha=0.3,
                                c=f'C{color_idx}')
        plt.xlim([1e3, 1e6])
plt.legend()
plt.tight_layout();

### TMRCA

As a diagnostic to make sure the time range we used is consistent with the histories we inferred, plot the CDF of the sample TMRCA for each population

In [None]:
# One panel per super-population showing the TMRCA CDF of each population
# under its inferred eta, on a time axis converted to years via t_gen.
n_super = len(pops)
# squeeze=False keeps `axes` 2-D even when there is a single super-population
fig, axes = plt.subplots(n_super, 1, sharex=True, figsize=(4, 2 * n_super), squeeze=False)
for ax, super_pop in zip(axes[:, 0], pops):
    plt.sca(ax)
    for pop in pops[super_pop]:
        ksfs = ksfs_dict[pop]
        plt.plot(t_gen * ksfs.eta.change_points, ksfs.tmrca_cdf(ksfs.eta), label=pop)
    plt.ylabel('TMRCA CDF')
    plt.ylim([0, 1])
    plt.legend(title=super_pop, fontsize=9)
plt.xlabel('$t$ (years ago)')
plt.xscale('log')
plt.tight_layout();

## Infer mutation spectrum history $\boldsymbol\mu(t)$

### TCC>TTC pulse in Europeans
First we focus on sharply timing the TCC>TTC pulse and on assessing its sensitivity to demographic assumptions.
#### Regularization parameters
Total variation regularization, to minimize the number of change points.

In [None]:
# total-variation penalty favoring few change points, with a negligible ridge term
regularization_mu = {'beta_tv': 7e1, 'beta_ridge': 1e-10}

#### Loop over EUR populations, inferring history for each under three alternative demographic conditionings
1. The European demography from Tennessen et al. (assumed by Harris and Pritchard to time the TCC>TTC pulse)
2. The demographies for each EUR population reported by Speidel et al. using the Relate method
3. The inferred demographies from above 

In [None]:
import scipy.interpolate  # `import scipy` alone is not guaranteed to load the interpolate submodule


def _relate_eta(pop, grid):
    """Return the Relate (Speidel et al.) demography for `pop` as a mushi.eta.

    Reads ``../Relate_histories/relate_{pop}.coal``. The first line is skipped,
    the second is parsed as epoch times and the third as coalescence rates
    with the first two entries dropped (presumably population indices --
    TODO confirm against the Relate .coal format). eta is the reciprocal rate,
    resampled by nearest-neighbour interpolation at ``grid[:-1]`` and placed
    on change points ``grid[1:-1]``. `grid` is the epoch-boundary array from
    ``mushi.eta.arrays()[0]``.
    """
    with open(f'../Relate_histories/relate_{pop}.coal') as f:
        f.readline()
        t = np.fromstring(f.readline(), sep=' ')
        y = 1 / np.fromstring(f.readline(), sep=' ')[2:]
    resample = scipy.interpolate.interp1d(t, y, kind='nearest')
    return mushi.eta(grid[1:-1], resample(grid[:-1]))


n_rows = 3  # one row per demographic conditioning: Tennessen, Relate, mushi
fig, axes = plt.subplots(nrows=n_rows, ncols=4, figsize=(17.4, 7.5))

species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("OutOfAfrica_2T12")
ddb = model.get_demography_debugger()
# prepend t = 0 so the trajectory also covers the most recent epoch
steps = np.concatenate((np.array([0]), change_points))
# eta = 1 / pairwise coalescence rate for two samples from the model's second
# population (presumably EUR in OutOfAfrica_2T12 -- TODO confirm)
eta_Tennessen = mushi.eta(change_points,
                          1 / ddb.coalescence_rate_trajectory(steps=steps,
                                                              num_samples=[0, 2],
                                                              double_step_validation=False)[0])

for idx, pop in enumerate(pops['EUR']):
    print(pop)
    color = f'C{idx}'
    eta_mushi = ksfs_dict[pop].eta
    eta_relate = _relate_eta(pop, eta_mushi.arrays()[0])
    # eta_mushi is conditioned on last, so each population leaves this loop
    # with mu inferred under the mushi demography
    for row_idx, eta in enumerate((eta_Tennessen, eta_relate, eta_mushi)):
        # only the bottom row keeps its x-axis labels
        is_bottom_row = row_idx == n_rows - 1

        ksfs_dict[pop].clear_mu()
        ksfs_dict[pop].infer_history(change_points, mu0, eta=eta, infer_eta=False,
                                     loss='prf', **regularization_mu,
                                     **convergence, mask=freq_mask[pop])

        # column 0: total SFS and its fit
        plt.sca(axes[row_idx, 0])
        ksfs_dict[pop].plot_total(kwargs=dict(ls='', alpha=0.5, marker='o', ms=5, mfc='none', c=color, label=pop, rasterized=True),
                                  line_kwargs=dict(c=color, ls=':', marker='.', ms=3, alpha=0.5, lw=1, rasterized=True),
                                  fill_kwargs=dict(color=color, alpha=0))
        plt.xscale('log')
        plt.yscale('log')
        plt.legend(fontsize=6, loc='upper right')
        if not is_bottom_row:
            plt.xlabel(None)

        # column 1: eta held by the kSFS after inference (presumably the eta just supplied)
        plt.sca(axes[row_idx, 1])
        ksfs_dict[pop].eta.plot(t_gen=t_gen, lw=2, label=pop, alpha=0.75, c=color)
        plt.xlim([1e3, 1e6])
        plt.legend(fontsize=6, loc='upper right')
        if not is_bottom_row:
            plt.xlabel(None)

        # column 2: TCC>TTC component of the SFS composition
        plt.sca(axes[row_idx, 2])
        ksfs_dict[pop].plot(('TCC>TTC',), clr=True,
                            kwargs=dict(ls='', c=color, marker='o', ms=5, mfc='none', alpha=0.5, label=pop, rasterized=True),
                            line_kwargs=dict(c=color, ls=':', marker='.', ms=3, alpha=0.5, lw=1, rasterized=True))
        plt.ylabel('TCC$\\to$TTC component of\nvariant count composition')
        plt.legend(fontsize=6)

        # column 3: inferred TCC>TTC mutation intensity history
        plt.sca(axes[row_idx, 3])
        plt.gca().set_prop_cycle(None)
        ksfs_dict[pop].mu.plot(('TCC>TTC',), t_gen=t_gen, clr=False, c=color, alpha=0.75, lw=2, label=pop)
        plt.ylabel(r'TCC$\to$TTC mutation intensity')
        plt.xlim([1e3, 1e6])
        plt.legend(fontsize=6)

plt.tight_layout();

### Mutation spectrum histories for all populations
Now we'll aim for smooth histories that capture all components well
#### Regularization parameters

In [None]:
# smoothness-oriented penalties for mu(t): rank penalty (hard=True, presumably a
# hard rank constraint), no total variation, strong spline smoothing
regularization_mu = {'hard': True, 'beta_rank': 2e2, 'beta_tv': 0, 'beta_spline': 1e5}

#### Loop over populations, inferring history for each

In [None]:
# The first population visited (the alphabetically last member of the
# alphabetically first super-population) is fit with a negligible ridge
# penalty; its MuSH then becomes the ridge reference (mu_ref) for the rest.
mu_ref = None
for idx_super, super_pop in enumerate(sorted(pops)):
    for idx, pop in enumerate(reversed(pops[super_pop])):
        print(super_pop, pop)
        ksfs = ksfs_dict[pop]
        ksfs.clear_mu()
        is_reference = idx_super == 0 and idx == 0
        beta_ridge = 1e-10 if is_reference else 1e4
        ksfs.infer_history(change_points, mu0, infer_eta=False,
                           mu_ref=mu_ref, beta_ridge=beta_ridge,
                           loss='prf', mask=freq_mask[pop],
                           **regularization_mu, **convergence)
        if is_reference:
            mu_ref = ksfs.mu

#### Plot histories
Plot the 3-SFS and the inferred MuSH for each population (this produces many plots).

In [None]:
# Colors for the six singlet (center-base) mutation classes. They are identical
# for every population, so they are built once outside the loop.
pal = sns.color_palette('husl', n_colors=6)
col_map = dict(zip(('A>C', 'A>G', 'A>T', 'C>A', 'C>G', 'C>T'), pal))

with mpl.rc_context(rc={'text.usetex': False}):
    for super_pop in sorted(pops):
        for pop in reversed(pops[super_pop]):
            print(pop)
            ksfs = ksfs_dict[pop]

            # left: SFS composition and fit; right: MuSH composition over time
            plt.figure(figsize=(6, 10))
            plt.subplot(121)
            ksfs.plot(clr=True, kwargs=dict(alpha=0.25, ls='', marker='o',
                                            ms=3, mfc='none', rasterized=True),
                      line_kwargs=dict(ls=':', marker='.', ms=2, alpha=0.25,
                                       lw=1, rasterized=True))
            plt.ylabel('variant count composition')
            plt.subplot(122)
            ksfs.mu.plot(t_gen=t_gen, clr=True, alpha=0.5, lw=2)
            plt.xscale('log')
            plt.ylabel('mutation intensity composition')
            plt.xlim([1e3, 1e6])
            plt.show()

            # singlet class from the center bases (characters 1 and 5 of e.g. 'ACG>ATG')
            singlets = [f'{anc}>{der}' for anc, der in zip(ksfs.mutation_types.str[1],
                                                           ksfs.mutation_types.str[5])]
            col_colors = [col_map[singlet] for singlet in singlets]

            ksfs.clustermap(figsize=(20, 7), col_cluster=False,
                            xticklabels=True, rasterized=True,
                            robust=True, cmap='RdBu_r',
                            col_colors=col_colors)
            plt.show()

            ksfs.mu.clustermap(t_gen=t_gen,
                               figsize=(20, 7), col_cluster=False, xticklabels=True, rasterized=True,
                               robust=True, cmap='RdBu_r',
                               col_colors=col_colors)
            plt.show()

### Tensor decomposition and mutation signature dynamics
#### Non-negative CP decomposition
We stack the MuSHs for each population to form a tensor of order 3, then use a rank-10 non-negative CP decomposition to extract factors for each dimension: sample, time, and mutation type.

In [None]:
# target size (masked site count) of the ancestral triplet of each mutation type
targets = np.array([masked_genome_size.loc[mutation_type.split('>')[0], 'count']
                    for mutation_type in sorted_triplets])

# Order-3 tensor (population x time x mutation type): each population's MuSH
# divided by the triplet target sizes. Populations are ordered by sorted
# super-population, with members in reverse alphabetical order.
Z = np.stack([ksfs_dict[pop].mu.Z / targets
              for super_pop in sorted(pops)
              for pop in reversed(pops[super_pop])], axis=0)

# Non-negative PARAFAC (CP) decomposition of rank 10. The first returned item
# (named `core` here) is presumably the CP weight vector; one factor matrix per mode follows.
core, (factors_sample, factors_history, factors_signature) = td.parafac(
    Z, rank=10, non_negative=True,
    orthogonalise=False, normalize_factors=False,
    tol=1e-10, n_iter_max=10000,
    random_state=0)

#### Project the sample dimension factors to two principal components

In [None]:
import sklearn.decomposition  # `import sklearn` alone does not load its submodules

# labels in the same row order used to build Z (and hence factors_sample)
sample_pops = [pop for super_pop in sorted(pops) for pop in reversed(pops[super_pop])]
sample_super_pops = [super_pop for super_pop in sorted(pops) for pop in reversed(pops[super_pop])]

xy = sklearn.decomposition.PCA(n_components=2, whiten=True).fit_transform(factors_sample)
plt.figure(figsize=(4, 4))
# seaborn >= 0.12 requires x and y as keyword arguments
sns.scatterplot(x=xy[:, 0], y=xy[:, 1], hue=sample_super_pops)
for pop, point in zip(sample_pops, xy):
    plt.annotate(pop, point, size=6)
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.tight_layout()
plt.show()

#### Project the time dimension factors to two principal components

In [None]:
import sklearn.decomposition  # `import sklearn` alone does not load its submodules

xy = sklearn.decomposition.PCA(n_components=2, whiten=True).fit_transform(factors_history)

# Epoch start times in years, aligned with the rows of factors_history. All
# populations were fit on the same change_points grid, so any one supplies them.
epoch_times = t_gen * next(iter(ksfs_dict.values())).mu.arrays()[0][:-1]
idxs = (1e3 <= epoch_times) & (epoch_times <= 1e6)
xy_window = xy[idxs]
times_window = epoch_times[idxs]

# Interleave consecutive points so that each segment (i -> i+1) becomes its own
# two-point line whose hue is the segment's start time t_i.
x = np.dstack((xy_window[:-1, 0], xy_window[1:, 0])).flatten()
y = np.dstack((xy_window[:-1, 1], xy_window[1:, 1])).flatten()
z = np.dstack((times_window[:-1], times_window[:-1])).flatten()

plt.figure(figsize=(4, 4))
norm = mpl.colors.LogNorm(vmin=1e3, vmax=1e6)
# invisible scatter (s=0) only provides a color mappable for the colorbar
points = plt.scatter(x, y, c=z, s=0, norm=norm, cmap='viridis')
# seaborn >= 0.12 requires keyword data arguments; z is the hue variable
sns.lineplot(x=x, y=y, hue=z, hue_norm=norm, lw=3, legend=False, palette='viridis')
cbar = plt.colorbar(points)
cbar.set_label('$t$ (years ago)', rotation=90)
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.tight_layout()
plt.show()

#### Project the mutation type dimension factors to two principal components

In [None]:
import sklearn.decomposition  # `import sklearn` alone does not load its submodules

with mpl.rc_context(rc={'text.usetex': False}):
    # mutation types are identical for all populations (all reindexed to
    # sorted_triplets), so any one population supplies them
    ref_ksfs = next(iter(ksfs_dict.values()))
    # singlet class from the center bases (characters 1 and 5 of e.g. 'ACG>ATG')
    singlets = [f'{anc}>{der}' for anc, der in zip(ref_ksfs.mutation_types.str[1],
                                                   ref_ksfs.mutation_types.str[5])]
    # rescale loadings by triplet target size before projecting
    xy = sklearn.decomposition.PCA(n_components=2, whiten=True).fit_transform(factors_signature * targets[:, np.newaxis])
    plt.figure(figsize=(4, 4))
    # seaborn >= 0.12 requires x and y as keyword arguments
    sns.scatterplot(x=xy[:, 0], y=xy[:, 1], hue=singlets, palette='husl')
    # annotate only outlying mutation types (|coordinate| > 1 on either PC)
    for mutation_type, point in zip(ref_ksfs.mu.mutation_types, xy):
        if np.abs(point).max() > 1:
            plt.annotate(mutation_type, point, size=6)
    plt.xlabel('PC 1')
    plt.ylabel('PC 2')
    plt.tight_layout()
    plt.show()

#### Display mutation type loadings as mutation signatures

In [None]:
# wide table: one row per mutation type, one column per signature (1-based)
n_signatures = factors_signature.shape[1]
signature_df = pd.DataFrame(factors_signature, columns=list(range(1, n_signatures + 1)))
signature_df['mutation type'] = ksfs_dict[pop].mu.mutation_types
signature_df['singlet type'] = singlets

# one color per singlet class, then one color per mutation type (bar order)
pal = sns.color_palette('husl', n_colors=6)
colors = dict(zip(('A>C', 'A>G', 'A>T', 'C>A', 'C>G', 'C>T'), pal))
palette = [colors[singlet] for singlet in signature_df['singlet type']]
# long format: one row per (mutation type, signature) loading
signature_df = signature_df.melt(id_vars=['mutation type', 'singlet type'], var_name='signature')

with mpl.rc_context(rc={'text.usetex': False}):
    g = sns.FacetGrid(signature_df, row='signature', aspect=8, height=1.5,
                      margin_titles=True, sharey=True)
    g = g.map(sns.barplot, 'mutation type', 'value',
              order=ksfs_dict[pop].mu.mutation_types,
              palette=palette)
    g.set_xticklabels(rotation=90, fontsize=8, family='monospace')
    # color the bottom panel's tick labels by singlet class
    bottom_ax = g.axes.flat[-1]
    for tick_label, color in zip(bottom_ax.get_xticklabels(), palette):
        tick_label.set_color(color)
    plt.tight_layout()
    plt.show()

#### Mutation signature dynamics in each population
- rows correspond to 10 mutation signatures
- columns correspond to super populations

In [None]:
# signature loadings rescaled by triplet target size, then normalized so that
# each signature's loadings sum to 1
factors_signature_normed = factors_signature * targets[:, np.newaxis]
factors_signature_normed /= factors_signature_normed.sum(0, keepdims=True)

# project the target-size-rescaled MuSH tensor onto the signatures along the
# mutation-type mode (mode 2)
Z_transform = ta.mode_dot(Z * targets[np.newaxis, np.newaxis, :], factors_signature_normed.T, mode=2)
# for each population and time, each signature's share of the total
Z_transform_normed = Z_transform / Z_transform.sum(2, keepdims=True)

n_signatures = Z_transform.shape[2]
n_super = len(pops)
# squeeze=False keeps `axes` 2-D even when there is a single super-population
fig, axes = plt.subplots(n_signatures, n_super, sharex=True, sharey='row',
                         figsize=(2.1 * n_super, 1.5 * n_signatures), squeeze=False)
for k in range(n_signatures):
    row = 0  # population index into Z_transform_normed (same order used to build Z)
    for col, super_pop in enumerate(sorted(pops)):
        plt.sca(axes[k, col])
        for pop in reversed(pops[super_pop]):
            plt.plot(t_gen * ksfs_dict[pop].mu.arrays()[0][:-1], Z_transform_normed[row, :, k],
                     label=pop)
            row += 1
        plt.xscale('log')
        plt.xlim([1e3, 1e6])
        if k == 0:
            plt.title(super_pop)
            plt.legend(fontsize=6, loc='upper right', ncol=2)
plt.xlabel('$t$ (years ago)')
plt.tight_layout()
plt.show()

The same histories, overlaid for all populations

In [None]:
# One panel per signature, with all populations overlaid and colored by
# super-population. row_order follows the population order of Z_transform_normed.
n_signatures = Z_transform.shape[2]
fig, axes = plt.subplots(n_signatures, 1, sharex=True, sharey='row', figsize=(3, 1.5 * n_signatures))
row_order = [(color_idx, super_pop, member_idx, pop)
             for color_idx, super_pop in enumerate(sorted(pops))
             for member_idx, pop in enumerate(reversed(pops[super_pop]))]
for k in range(n_signatures):
    plt.sca(axes[k])
    for row, (color_idx, super_pop, member_idx, pop) in enumerate(row_order):
        legend_label = super_pop if member_idx == 0 else None
        plt.plot(t_gen * ksfs_dict[pop].mu.arrays()[0][:-1], Z_transform_normed[row, :, k],
                 label=legend_label, c=f'C{color_idx}')
        plt.xscale('log')
        plt.xlim([1e3, 1e6])
    if k == 0:
        plt.legend(fontsize=6, loc='upper right')
plt.xlabel('$t$ (years ago)')
plt.tight_layout()
plt.show()

### Global divergence in mutation spectra
#### UMAP embedding of mutation signatures through time for all populations

In [None]:
import sklearn.decomposition  # `import sklearn` alone does not load its submodules

# truncate to the time range we want to plot (hard-coded indices into [0, *change_points])
start = 58
end = -30

# all populations were fit on the same change_points grid, so any one supplies the epoch times
ref_mu = next(iter(ksfs_dict.values())).mu
time = np.concatenate(([0], ref_mu.change_points))[start:end]

np.random.seed(1)

# PCA of the time factors, used to initialize the UMAP layout
embedding_init = sklearn.decomposition.PCA(n_components=2)
embedding_init.fit(factors_history)

# one color per super-population, in the order of `pops`
colors = {super_pop: f'C{idx}' for idx, super_pop in enumerate(pops.keys())}

# rows of Z_transform follow the population order used to build Z
z_order = [pop for super_pop in sorted(pops) for pop in reversed(pops[super_pop])]
Z_dict = {pop: Z_transform[row, :, :] for row, pop in enumerate(z_order)}

Z_stack = np.concatenate([Z_dict[pop] for super_pop in pops for pop in pops[super_pop]])

embedding = umap.UMAP(n_components=2,
                      init=embedding_init.transform(Z_stack), random_state=1)
embedding.fit(Z_stack)

# embed each population's signature trajectory, truncated to the plotted window
Z_transform_dict = {pop: embedding.transform(Z_dict[pop])[start:end]
                    for super_pop in pops
                    for pop in pops[super_pop]}

#### Plot embedding

In [None]:
with mpl.rc_context(rc={'text.usetex': False}):

    # 2-D view: each population's trajectory, labelled at its first point
    # (the most recent time in the truncated window)
    plt.figure(figsize=(6, 6))
    for super_pop in pops:
        color = colors[super_pop]
        for idx2, pop in enumerate(pops[super_pop]):
            trajectory = Z_transform_dict[pop]
            plt.plot(*trajectory.T,
                     '-', lw=3, alpha=.5,
                     c=color,
                     label=super_pop if idx2 == 0 else None)
            plt.annotate(pop, trajectory[0, :],
                         ha='center', va='center', c='w',
                         family='monospace',
                         bbox=dict(boxstyle='circle', fc=color, ec=color, lw=2),
                         size=6)
    plt.xlabel('UMAP 1')
    plt.ylabel('UMAP 2')
    plt.legend()
    plt.tight_layout()
    plt.show()

    # 3-D view: the two UMAP coordinates against log10 of time in years.
    # Figure.gca(projection=...) was removed in Matplotlib 3.6; add_subplot is the supported form.
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(projection='3d')
    ax.view_init(20, 70)
    log_t = np.log10(t_gen * time)
    for super_pop in pops:
        for idx2, pop in enumerate(pops[super_pop]):
            ax.plot(*Z_transform_dict[pop].T, log_t,
                    '-', lw=3, alpha=1,
                    c=colors[super_pop],
                    label=super_pop if idx2 == 0 else None)
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.set_zlabel(r'$\log_{10}(t)$')
    ax.set_zticks([3, 4, 5, 6])
    ax.set_zlim([3, 6])
    plt.tight_layout()
    plt.show()