1KG
==
Run mushi on 3-SFS computed from 1000 Genomes Project data

In [None]:
%matplotlib inline 
# %matplotlib notebook
from mushi import kSFS
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA, NMF
from umap import UMAP
import pickle
from skbio.stats.composition import ilr, clr, closure, inner

In [None]:
# plt.style.use('dark_background')

In [None]:
# Map each super-population code to the list of its population codes, in order
# of first appearance in the sample panel (columns 2 and 3 of the TSV).
pops = {}
with open('data/phase3_1000genomes/integrated_call_samples_v3.20130502.ALL.panel') as f:
    f.readline()  # discard the header row
    for line in f:
        fields = line.split('\t')
        pop, super_pop = fields[1:3]
        members = pops.setdefault(super_pop, [])
        if pop not in members:
            members.append(pop)
# pops['AFR'].append('Batwa')

In [None]:
# Display the super-population -> populations mapping built from the panel file.
pops

### Load 1KG AFR and Batwa 3-SFSs

In [None]:
# Canonical ordering of the 96 triplet mutation types, written as
# '<5'><anc><3'>><5'><der><3'>' with the ancestral centre base folded to A or C.
sorted_triplets = [f'{a5}{a}{a3}>{a5}{d}{a3}' for a in 'AC' for d in 'ACGT' if d != a for a5 in 'ACGT' for a3 in 'ACGT']

# Load one pickled mushi object per population.
ksfs_dict = {}
for super_pop in pops:
    for pop in pops[super_pop]:
        # NOTE: pickle can execute arbitrary code on load; only use trusted files.
        # The context manager closes the file handle (the bare open() leaked it).
        with open(f'1KG/scons_output/{pop}/mushi.pkl', 'rb') as pkl_file:
            ksfs_dict[pop] = pickle.load(pkl_file)

# Put every population's mutation types (and the matching columns of X and μ.Z)
# into the canonical order above.
for pop in ksfs_dict:
    ksfs = ksfs_dict[pop]
    mutation_types, column_order = ksfs.mutation_types.reindex(sorted_triplets)

    # Index.reindex returns None as the indexer when the index already equals
    # the target; indexing with `[:, None]` would insert a new axis instead of
    # leaving the columns alone, so only permute when an indexer is returned.
    if column_order is not None:
        # An entry of -1 marks a triplet absent from the data; using it as a
        # column index would silently pick the last column.
        if (column_order < 0).any():
            missing = [t for t, j in zip(sorted_triplets, column_order) if j < 0]
            raise ValueError(f'{pop}: mutation types missing from 3-SFS: {missing}')
        ksfs.X = ksfs.X[:, column_order]
        ksfs.μ.Z = ksfs.μ.Z[:, column_order]

    ksfs.mutation_types = mutation_types
    ksfs.μ.mutation_types = mutation_types

### Mutation type enrichment as a heatmap

In [None]:
# One colour per central-base substitution class, used to annotate the columns.
col_map = {'A>C':'C0', 'A>G':'C1', 'A>T':'C2', 'C>A':'C3', 'C>G':'C4', 'C>T':'C5'}

# For every population draw two clustermaps (one from the kSFS object, one from
# its μ attribute) and save each as a PDF.
for pop in ksfs_dict:
    print(pop)
    ksfs = ksfs_dict[pop]

    # central ancestral base and central derived base of each triplet, e.g. 'A>C'
    types = ksfs.mutation_types
    singlets = types.str[1].str.cat(types.str[5], sep='>')
    col_colors = [col_map[singlet] for singlet in singlets]

    # keyword arguments common to both clustermaps
    shared_kwargs = dict(figsize=(20, 7), col_cluster=False, xticklabels=True,
                         rasterized=True, col_colors=col_colors, cmap='RdBu_r')

    g = ksfs.clustermap(vmin=0.75, vmax=1.25, center=1, **shared_kwargs)
    tick_labels = g.ax_heatmap.get_xmajorticklabels()
    g.ax_heatmap.set_xticklabels(tick_labels, fontsize = 9, family='monospace')
    g.savefig(f'/Users/williamdewitt/Downloads/heatmap.{pop}.pdf', transparent=True)
    plt.show()

    g = ksfs.μ.clustermap(vmin=0.5, vmax=1.5, **shared_kwargs)
    tick_labels = g.ax_heatmap.get_xmajorticklabels()
    g.ax_heatmap.set_xticklabels(tick_labels, fontsize = 9, family='monospace')
    g.savefig(f'/Users/williamdewitt/Downloads/heatmap.{pop}.mu.pdf', transparent=True)
    plt.show()

In [None]:
# Column-clustered heatmap of the Batwa μ history.
# The line that would add 'Batwa' to `pops` is commented out where `pops` is
# built, so 'Batwa' is only in `ksfs_dict` if that line is re-enabled. Guard the
# lookup so a fresh Restart & Run All does not stop here with a KeyError.
if 'Batwa' in ksfs_dict:
    batwa = ksfs_dict['Batwa']

    # Build the column colours from this population's own mutation types
    # instead of relying on `col_colors` leaking out of an earlier cell's loop.
    batwa_types = batwa.mutation_types
    batwa_singlets = batwa_types.str[1].str.cat(batwa_types.str[5], sep='>')
    col_map = {'A>C':'C0', 'A>G':'C1', 'A>T':'C2', 'C>A':'C3', 'C>G':'C4', 'C>T':'C5'}
    col_colors = [col_map[singlet] for singlet in batwa_singlets]

    g = batwa.μ.clustermap(figsize=(20, 7), col_cluster=True, xticklabels=True, rasterized=True,
                           vmin=0.5, vmax=1.5,
                           col_colors=col_colors, cmap='RdBu_r')
    g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xmajorticklabels(), fontsize = 9, family='monospace',)
    g.savefig(f'/Users/williamdewitt/Downloads/heatmap.Batwa.mu.clustered.pdf', transparent=True)
    plt.show()
else:
    print("'Batwa' is not in ksfs_dict; skipping the clustered Batwa heatmap")

### TMRCA CDF

In [None]:
# One figure per super-population: TMRCA CDF of each member population plotted
# against the change points of its η history.
for super_pop in pops:
    print(super_pop)
    plt.figure(figsize=(3, 3))
    for pop in pops[super_pop]:
        plt.plot(ksfs_dict[pop].η.change_points, ksfs_dict[pop].tmrca_cdf(), label=pop)
    # Axis decoration is the same for every curve, so apply it once per figure
    # rather than once per population.
    plt.xlabel('$t$')
    plt.ylabel('TMRCA CDF')
    plt.ylim([0, 1])
    plt.xscale('symlog')
    plt.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# One figure per super-population overlaying the η history of each member
# population; each figure is written to a PDF named after the super-population.
for super_pop, members in pops.items():
    print(super_pop)
    plt.figure(figsize=(3, 3))
    for pop in members:
        history = ksfs_dict[pop].η
        history.plot(label=pop, lw=3, alpha=0.8, zorder=-1)
    # plt.legend(loc=(1.04,0))
    plt.tight_layout()
    out_path = f'/Users/williamdewitt/Downloads/{super_pop}.N.pdf'
    plt.savefig(out_path, transparent=True)
    plt.show()

Expectation of empirical spectra from counting variants, followed by a compositional data transformation, then PCA

cite: https://link.springer.com/article/10.1023/A:1023818214614 and Aitchison

and see: https://stats.stackexchange.com/questions/305965/can-i-use-the-clr-centered-log-ratio-transformation-to-prepare-data-for-pca

In [None]:
# PCA of integrated signal (expected seg sites)

import matplotlib.animation as animation
# from IPython.display import HTML

# Time grid: 0 followed by the μ change points of `pop`, where `pop` is whatever
# value a previously executed cell's loop left behind.
time = np.concatenate(([0], ksfs_dict[pop].μ.change_points))

# embedding = UMAP(n_components=2,
# #                  min_dist=0,
#                  n_neighbors=10,
#                 )
embedding = PCA(n_components=2)#, whiten=True)

# One ilr-transformed vector per population, stored with a leading axis of
# length one so the populations can be stacked row-wise below.
# NOTE(review): closure() is applied to the rows of L @ μ.Z *before* they are
# summed, whereas the AFR heatmap cell sums first and then applies closure —
# confirm which order is intended.
Z_dict = {}
for pop in ksfs_dict:
    ksfs = ksfs_dict[pop]
    row_compositions = closure(ksfs.L @ ksfs.μ.Z)
    Z_dict[pop] = np.array([ilr(row_compositions.sum(0, keepdims=True))])

# stack all populations and learn the two principal components
Z = np.concatenate([Z_dict[pop] for pop in ksfs_dict])
embedding.fit(Z)

# project each population separately
Z_transform_dict = {}
for pop in ksfs_dict:
    Z_transform_dict[pop] = embedding.transform(Z_dict[pop])

# one marker per population, coloured by super-population
fig, ax = plt.subplots(figsize=(5, 3))
for idx, (super_pop, members) in enumerate(pops.items()):
    colour = f'C{idx}'
    for pop in members:
        pc1, pc2 = Z_transform_dict[pop][0, :]
        plt.plot(pc1, pc2, 'o', markersize=5, c=colour, label=pop)
    #     ax.set_xlabel('UMAP 1')
    #     ax.set_ylabel('UMAP 2')
ax.set_xlabel('PC 1')
ax.set_ylabel('PC 2')
ax.legend(loc=(1.04, -0.15), fancybox=True, framealpha=0, ncol=2)
plt.tight_layout()
plt.savefig(f'/Users/williamdewitt/Downloads/X_pca.pdf', transparent=True)

In [None]:
# sns.barplot(ksfs_dict[pop].mutation_types, embedding.components_[0,:])
# plt.show()

In [None]:
# Normalised mutation spectrum (column sums of X passed through closure) for
# each AFR population.
Z_dict = {}
for pop in pops['AFR']:
    Z_dict[pop] = closure(ksfs_dict[pop].X.sum(0, keepdims=True))

# One row per triplet mutation type, annotated with its flanking bases and its
# central substitution.
mutation_types = ksfs_dict[pop].μ.mutation_types
df = pd.DataFrame(index=mutation_types)
df["5'"] = df.index.str[0]
df['mutation'] = df.index.str[1].str.cat(df.index.str[5], sep='→')
df["3'"] = df.index.str[2]

# Ratio of each population's spectrum to that of the reference population.
reference = 'MSL'
for other in ('LWK', 'YRI', 'GWD', 'ESN'):
    df[f'{other} Vs {reference}'] = Z_dict[other] / Z_dict[reference]

# long format: one row per (triplet, comparison)
df = df.melt(id_vars=["5'", 'mutation', "3'"], var_name='comparison')

g = sns.FacetGrid(
    df,
    row='mutation',
    col='comparison',
    row_order=('C→A', 'C→G', 'C→T', 'A→G', 'A→C', 'A→T'),
    margin_titles=True,
    height=1.5,
)

def facet_heatmap(data, color, **kwargs):
    """Draw one facet: a 5'-base by 3'-base heatmap of the 'value' column."""
    grid = data.pivot(index="5'", columns="3'", values='value')
    heatmap_ax = sns.heatmap(grid, **kwargs)
    heatmap_ax.invert_yaxis()

# dedicated axes for the shared colorbar, placed to the right of the grid
cbar_ax = g.fig.add_axes([1.1, .3, .05, .4])

heatmap_kwargs = dict(cbar_ax=cbar_ax, cmap='RdBu_r', center=1, vmin=0.84, vmax=1.16)
g = g.map_dataframe(facet_heatmap, **heatmap_kwargs)

# blank out the text artists on every facet before setting fresh titles
for ax in g.axes.flat:
    plt.setp(ax.texts, text="")
g.set_titles(row_template = '{row_name}', col_template = '{col_name}')

g.fig.tight_layout()
plt.tight_layout()
# so the colorbar doesn't overlap the plot
# g.fig.subplots_adjust(right=.9)
g.savefig('/Users/williamdewitt/Downloads/AFR_heatmap.pdf')
plt.show()

In [None]:
import matplotlib.animation as animation
# from IPython.display import HTML

# Time grid: 0 followed by the μ change points of `pop`, where `pop` is whatever
# value a previously executed cell's loop left behind.
time = np.concatenate(([0], ksfs_dict[pop].μ.change_points))

# ilr-transform each row of every population's μ.Z (after closure), giving one
# trajectory of points per population.
Z_dict = {}
for pop in ksfs_dict:
    Z_dict[pop] = ilr(closure(ksfs_dict[pop].μ.Z))

Z = np.concatenate(tuple(Z_dict[pop] for pop in ksfs_dict))

# use one of these, or the one above learned on the expected segregating sites
# UMAP is stochastic: fix the seed so that Restart & Run All reproduces the
# same embedding (and the same saved figure) every time.
embedding = UMAP(n_components=2,
                 n_neighbors=100,
                 min_dist=0,
                 random_state=0,
#                  metric='cosine',
#                  local_connectivity=2,
#                  n_epochs=500
                )
# embedding = PCA(n_components=2)#, whiten=True)
embedding.fit(Z)

# project each population's trajectory separately
Z_transform_dict = {pop:embedding.transform(Z_dict[pop]) for pop in ksfs_dict}

# one line per population, coloured by super-population; only the first point
# of each trajectory gets a marker
plt.figure(figsize=(5, 3))
# plt.subplot(311)
for idx, super_pop in enumerate(pops):
    for pop in pops[super_pop]:
        plt.plot(*Z_transform_dict[pop].T,
                 '-o',
                 c=f'C{idx}',
                 label=pop, markevery=[0])
plt.xlabel('UMAP 1')
plt.ylabel('UMAP 2')
plt.legend(loc=(1.04, -0.15), fancybox=True, framealpha=0, ncol=2)
plt.tight_layout()

# plt.subplot(312)
# for idx, super_pop in enumerate(pops):
#     for pop in pops[super_pop]:
#         plt.plot(time, Z_transform_dict[pop][:, 0],
#                  '-o',
#                  c=f'C{idx}',
#                  label=pop, markevery=[0])
# plt.xlabel('time (generations)')
# plt.ylabel('UMAP 1')
# plt.xscale('symlog')
# # plt.legend(loc=(1.04, -1), fancybox=True, framealpha=0)
# plt.subplot(313)
# for idx, super_pop in enumerate(pops):
#     for pop in pops[super_pop]:
#         plt.plot(time, Z_transform_dict[pop][:, 1],
#                  '-o',
#                  c=f'C{idx}',
#                  label=pop, markevery=[0])
# plt.xlabel('time (generations)')
# plt.ylabel('UMAP 2')
# plt.xscale('symlog')
# plt.tight_layout()

plt.savefig(f'/Users/williamdewitt/Downloads/Z_embed.pdf', transparent=True)
plt.show()

In [None]:
# Time grid: 0 followed by the μ change points of `pop`, where `pop` is whatever
# value a previously executed cell's loop left behind.
time = np.concatenate(([0], ksfs_dict[pop].μ.change_points))

# NOTE(review): the `alpha` keyword of NMF appears to have been replaced by
# `alpha_W` / `alpha_H` in recent scikit-learn releases — confirm against the
# version this notebook is pinned to.
embedding = NMF(alpha=1.5e0, n_components=10)#, verbose=False, tol=1e-10, max_iter=10000)
# embedding = PCA(n_components=3)

# Stack the transposed μ.Z of every population row-wise and factorise.
stacked_histories = np.concatenate([ksfs_dict[pop].μ.Z.T for pop in ksfs_dict])
embedding.fit(stacked_histories)

# per-population loadings on the learned components
Z_transform_dict = {pop:embedding.transform(ksfs_dict[pop].μ.Z.T) for pop in ksfs_dict}

# plot each learned component against the time grid
plt.figure(figsize=(4, 2))
for i, component in enumerate(embedding.components_):
    plt.plot(time, component, label=f'latent history {i + 1}')
plt.xlabel('$t$')
plt.xscale('symlog')
plt.legend(loc='center left', prop={'size': 7.5}, framealpha=.5)
plt.show()

In [None]:
def facet_heatmap(data, color, **kwargs):
    """Draw one facet: a 5'-base by 3'-base heatmap of the 'value' column."""
    data = data.pivot(index="5'", columns="3'", values='value')
    sns.heatmap(data, **kwargs).invert_yaxis()


# Number of latent histories, taken from the fitted model (the bare name
# `n_components` was never defined in this notebook).
n_components = embedding.n_components

# For every population, show how each triplet mutation type loads on each
# latent history, as a grid of 5' x 3' heatmaps.
for pop in ksfs_dict:
    print(pop)
    # rows: latent histories, columns: mutation types
    W = embedding.transform(ksfs_dict[pop].μ.Z.T).T

    # normalise each latent history's loadings to sum to one across types
    total_weight = W.sum(1, keepdims=True)
    W = W / total_weight

    df = pd.DataFrame(data=W.T,
                      index=ksfs_dict[pop].μ.mutation_types,
                      columns=range(1, n_components + 1))

    # colour limits are taken before the string annotation columns are added
    df_min = df.values.min()
    df_max = df.values.max()

    df["5'"] = df.index.str[0]
    df['mutation'] = df.index.str[1].str.cat(df.index.str[5], sep='→')
    df["3'"] = df.index.str[2]

    df = df.melt(id_vars=["5'", 'mutation', "3'"], var_name='latent history')

    # `height` (not `size`) sets the facet height, as in the AFR heatmap cell
    g = sns.FacetGrid(df, row='mutation', col='latent history',
                      row_order=('C→A', 'C→G', 'C→T', 'A→G', 'A→C', 'A→T'),
                      margin_titles=True,
                      height=1.5
                      )

    # colorbar axes
    cbar_ax = g.fig.add_axes([1.1, .3, .05, .4])

    g = g.map_dataframe(facet_heatmap,
                        cbar_ax=cbar_ax,
                        cmap='RdBu_r',
                        center=0,
                        vmin=df_min, vmax=df_max
                        )

    # so the colorbar doesn't overlap the plot
    # g.fig.subplots_adjust(right=.9)
#     plt.savefig('/Users/williamdewitt/Downloads/PC_heatmap.pdf')
    plt.show()

In [None]:
# Scatter of the second vs third latent-history loadings, one colour per
# population. The slice is limited to exactly two columns: with more than three
# components, `[:, 1:]` would unpack extra arrays into scatter's size, colour
# and marker arguments.
plt.figure()
for pop in ksfs_dict:
    plt.scatter(*Z_transform_dict[pop][:, 1:3].T, alpha=0.8)
plt.show()

In [None]:
# Tidy table with one row per (triplet, population): the ancestral triplet
# context, the derived centre base, and the loadings in columns 1 and 2 of the
# population's transformed matrix (labelled latent history 2 and 3).
# Rows are collected in a local list; the original reused the name `pops` for
# that list, which overwrote the super-population mapping and broke any re-run
# of the earlier cells.
columns = ['population', 'context', 'derived', 'latent history 2', 'latent history 3']
records = []
for i, triplet in enumerate(sorted_triplets):
    for pop in ksfs_dict:
        xy = Z_transform_dict[pop][i, 1:]
        records.append((pop, triplet[:3], triplet[5], xy[0], xy[1]))

df = pd.DataFrame(records, columns=columns)

In [None]:
# Grid of scatter plots (rows: derived base, columns: ancestral context) of
# latent history 2 against latent history 3, coloured by population.
g = sns.relplot(
    data=df,
    x='latent history 2',
    y='latent history 3',
    row='derived',
    col='context',
    hue='population',
    height=2,
    aspect=1,
    alpha=.8,
)
# facet titles read '<context>><derived base>'
g = g.set_titles("{col_name}>{row_name}")
g.savefig('/Users/williamdewitt/Downloads/AFR.signatures.pdf')
plt.show()

In [None]:
# `df` has the columns population / context / derived / latent history 2 /
# latent history 3 and no 'singlet' column, so x='singlet' cannot be resolved.
# NOTE(review): 'derived' (the derived centre base) is used as the closest
# existing column — confirm this is the intended grouping.
sns.catplot(y='latent history 3', x='derived', row='context', data=df, hue='population')
plt.show()