Quickstart
==

This page demonstrates how to use the Python module `mushi` to infer mutation spectrum history and demography.

We use `mushi` to infer the history of the mutation process, which we can think of as the mutation rate as a function of time for each triplet mutation type.
In `mushi`, we use coalescent theory and optimization techniques to learn about this history from the $k$-SFS: a matrix whose columns are sample frequency spectra (SFS) for each mutation type.
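To make the shape of the $k$-SFS concrete, here is a minimal sketch with a made-up matrix. The mutation type labels and counts are hypothetical, and the layout (rows as sample frequency classes, columns as mutation types) follows the description above rather than any `mushi` internals.

```python
import numpy as np

# Hypothetical toy k-SFS: rows are sample frequency classes (1, 2, 3 copies),
# columns are mutation types. The counts are made up for illustration.
mutation_types = ['ACG>ATG', 'TCC>TTC', 'AAA>ACA']
toy_ksfs = np.array([[50, 30, 20],
                     [12, 10, 8],
                     [5, 3, 2]])

# Summing over mutation types (columns) gives the ordinary SFS.
sfs = toy_ksfs.sum(axis=1)

# Summing over sample frequencies (rows) gives the variant count per mutation type.
variant_spectrum = toy_ksfs.sum(axis=0)

print(sfs.tolist())               # [100, 30, 10]
print(variant_spectrum.tolist())  # [67, 43, 30]
```

The two marginal sums correspond to the two summary plots made later on this page: the total SFS and the population variant spectrum.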


## Imports
We first import the `mushi` package, and a few other standard packages.

In [None]:
import mushi

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

## Load $k$-SFS data

### $3$-SFS for 1000 Genomes FIN population

We load $3$-SFS data for the 1000 Genomes Finnish population, which we previously computed with the `mutyper` [package](https://github.com/harrispopgen/mutyper) and saved in tabular format.

In [None]:
ksfs = mushi.kSFS(file='../example_data/3-SFS.EUR.FIN.tsv')

Plot the population variant spectrum (summing the $k$-SFS over sample frequencies).

In [None]:
ksfs.as_df().sum(0).plot.bar(figsize=(17, 3))
plt.xticks(family='monospace')
plt.ylabel('number of variants');

### Target sizes
We will also need the masked genome size for each mutation type, which we previously computed with `mutyper targets`. These sizes define the mutational target for each triplet context.

In [None]:
masked_genome_size = pd.read_csv('../example_data/masked_size.tsv', sep='\t', header=None, index_col=0)
masked_genome_size.index.name = 'mutation type'

masked_genome_size.plot.bar(figsize=(6, 3), legend=False)
plt.xticks(family='monospace')
plt.ylabel('mutational target size (sites)');

### Normalized variant spectra

With this we can compute the number of SNPs per target in each mutation type. Notice the enrichment of C>T transitions at CpG sites.

In [None]:
normalized_hit_rates = ksfs.as_df().sum(0).to_frame(name='variant count')
# look up the target size of the ancestral triplet (the part of the label before '>')
normalized_hit_rates['target size'] = [int(masked_genome_size.loc[context, 1])
                                       for context, _ in normalized_hit_rates['variant count'].index.str.split('>')]

(normalized_hit_rates['variant count'] /
 normalized_hit_rates['target size']).plot.bar(figsize=(17, 3), legend=False)
plt.xticks(family='monospace')
plt.ylabel('variants per target');
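The same normalization can be sketched on toy data without a Python loop. The counts and target sizes below are made up, and the variable names are hypothetical; the only thing taken from the cell above is the convention that the ancestral context is the part of the label before `>`.

```python
import pandas as pd

# Hypothetical variant counts per mutation type and target sizes per ancestral
# triplet context (made-up numbers, not the 1000 Genomes values).
variant_count = pd.Series({'ACG>ATG': 900, 'ACG>AAG': 100, 'TCC>TTC': 400})
target_size = pd.Series({'ACG': 1_000, 'TCC': 4_000})

# The ancestral context is the part of the label before '>'.
context = variant_count.index.str.split('>').str[0]

# Divide each count by the target size of its ancestral context.
variants_per_target = variant_count / target_size.loc[context].to_numpy()
print(variants_per_target.to_dict())
```

Two mutation types that share an ancestral context (here `ACG`) are divided by the same target size, which is why the lookup is keyed on the context rather than on the full mutation type label.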

## Plot the SFS
The SFS is obtained by summing the $k$-SFS across mutation types (the columns) within each sample frequency class.

In [None]:
ksfs.plot_total()

Plot on log scale instead of linear

In [None]:
ksfs.plot_total()
plt.xscale('log')
plt.yscale('log')

## Plot the $k$-SFS
Plot $k$-SFS composition as a scatter (a color for each mutation type)

In [None]:
ksfs.plot()

Plot again using the centered log ratio transform option (`clr=True`) to represent compositions over mutation types in each frequency class

In [None]:
ksfs.plot(clr=True)
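For readers unfamiliar with the centered log ratio (CLR) transform, the following is a minimal sketch of its standard definition applied to each row of a matrix of positive counts. It is not taken from the `mushi` source, which may handle details such as zero counts differently; the function name `clr` and the toy data are assumptions for illustration.

```python
import numpy as np


def clr(x):
    """Centered log ratio transform of each row of a strictly positive matrix."""
    log_x = np.log(x)
    # Subtract the row-wise mean of the logs (the log of the geometric mean).
    return log_x - log_x.mean(axis=1, keepdims=True)


# Hypothetical counts: two frequency classes (rows), three mutation types (columns).
toy = np.array([[50., 30., 20.],
                [12., 10., 8.]])
z = clr(toy)

# Each transformed row sums to zero, and rescaling a row leaves it unchanged,
# so only the composition over mutation types matters.
print(np.allclose(z.sum(axis=1), 0.0))
print(np.allclose(clr(10 * toy), z))
```

Because the transform is invariant to the total count in each row, it lets frequency classes with very different numbers of variants be compared on the same scale.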

Use a heatmap to show the $k$-SFS as a matrix, colored using a compositional centralization transform.
This method wraps the `clustermap` function in [Seaborn](https://seaborn.pydata.org/generated/seaborn.clustermap.html#seaborn.clustermap), and you can use keyword arguments you would use in that function. A few of note:
- `figsize`: a tuple giving the width and height of the figure
- `col_cluster`: if `False`, don't cluster the columns (giving an ordinary heatmap)
- `xticklabels`: if `True`, force printing of all the mutation type labels
- `robust`: if `True`, set the color range from robust quantiles of the data rather than from the extreme values
- `cmap`: color map name

In [None]:
ksfs.clustermap(figsize=(17, 7), xticklabels=True, robust=True, cmap='RdBu_r')

## Basic model parameters
### Total mutation rate
To compute the total mutation rate in units of mutations per masked genome per generation, we multiply an estimate of the site-wise rate by the total target size.

In [None]:
μ0 = 1.25e-8 * masked_genome_size[1].sum()
μ0
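As a worked example of the same arithmetic, here is a sketch with made-up target sizes. The site-wise rate is the one used in the cell above; the contexts and sizes in `toy_target` are hypothetical.

```python
import pandas as pd

site_rate = 1.25e-8  # mutations per site per generation, as in the cell above

# Hypothetical target sizes (number of sites) for two triplet contexts.
toy_target = pd.Series({'ACG': 2_000_000, 'TCC': 6_000_000})

# Total rate = site-wise rate times the total number of target sites,
# here roughly 0.1 mutations per (toy) masked genome per generation.
mu0_toy = site_rate * toy_target.sum()
print(mu0_toy)
```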

### Generation time
To render time in years rather than generations, we use an estimate of the generation time (in years).

In [None]:
t_gen = 29

### Discrete time grid
We define a grid of times to represent histories on, measured retrospectively from the present in units of Wright-Fisher generations.

In [None]:
t = np.logspace(np.log10(1), np.log10(200000), 200)
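The properties of this grid can be checked directly. The sketch below rebuilds the same grid, confirms that it is logarithmically spaced (consecutive points differ by a constant ratio), and converts it to years using the generation time defined above; the names `ratios` and `years` are introduced here only for illustration.

```python
import numpy as np

# Same grid as above: 200 points from 1 to 200,000 generations ago.
t = np.logspace(np.log10(1), np.log10(200000), 200)
t_gen = 29  # years per generation, as defined above

# On a log-spaced grid, consecutive points differ by a constant ratio,
# so recent times are resolved much more finely than ancient ones.
ratios = t[1:] / t[:-1]
print(len(t), t[0], np.allclose(ratios, ratios[0]))

# Converting generations to years is a simple rescaling.
years = t * t_gen
print(years[-1])  # about 5.8 million years
```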

## Jointly infer demography and mutation spectrum history

To infer a time-calibrated mutation spectrum history (MuSH), we need to jointly estimate the demographic history, since this defines the diffusion timescale of the coalescent process.

We now run the optimization, setting a few parameters to control how complicated we let the histories look. We use the `verbose=True` argument to print convergence messages.

In [None]:
ksfs.infer_history(t, μ0, alpha_tv=1e2, alpha_spline=3e3, alpha_ridge=1e-4,
                   beta_rank=1e1, beta_tv=7e1, beta_spline=1e1, beta_ridge=1e-4,
                   tol=1e-11, verbose=True)
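The parameter names suggest total variation (`tv`), spline, and ridge penalties, but that reading is an assumption based on the names alone, not on the `mushi` documentation. As a generic illustration of how penalties of this kind favor simpler histories, the sketch below scores two made-up curves on a grid; the functions `total_variation` and `roughness` are hypothetical helpers and do not reproduce the objective that `mushi` optimizes.

```python
import numpy as np


def total_variation(y):
    """Sum of absolute first differences: small for piecewise-constant curves."""
    return np.abs(np.diff(y)).sum()


def roughness(y):
    """Sum of squared second differences: small for smooth curves."""
    return np.square(np.diff(y, n=2)).sum()


# Two hypothetical histories on a five-point grid (made-up values).
step = np.array([1., 1., 2., 2., 2.])    # a single jump
wiggly = np.array([1., 2., 1., 2., 1.])  # oscillates at every point

print(total_variation(step), total_variation(wiggly))  # 1.0 4.0
print(roughness(step), roughness(wiggly))              # 2.0 12.0
```

Under either score the oscillating curve is penalized more heavily, so larger penalty weights push an optimizer toward histories with fewer jumps or less curvature.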

Hopefully you agree that was fast 🏎

## Visualizing inferred histories

After inferring histories, we can access them as instance attributes of the `kSFS` object:
- demography: `ksfs.eta` (or `ksfs.η`)
- MuSH: `ksfs.mu` (or `ksfs.μ`)

We can inspect these and see that they are both objects with base class `mushi.histories.history`.

In [None]:
print(ksfs.eta.__class__)
print(ksfs.eta.__class__.__bases__[0])

print(ksfs.mu.__class__)
print(ksfs.mu.__class__.__bases__[0])

We'll now check that the demography has a few features we expect in the Finnish population: the out-of-Africa bottleneck shared by all Eurasians, a later bottleneck associated with northward migration, and exponential population growth toward the present.
- The plot on the left shows fit to the SFS
- The plot on the right shows the inferred haploid effective population size history.

In [None]:
plt.figure(figsize=(10, 5))
plt.subplot(121)
ksfs.plot_total()
plt.xscale('log')
plt.yscale('log')
plt.subplot(122)
ksfs.eta.plot(t_gen=t_gen)
plt.xlim([1e3, 1e6]);

Now let's take a look at the inferred mutation spectrum history (MuSH).
- The plot on the left shows the measured $k$-SFS composition (points) and the fit from `mushi` (lines)
- The plot on the right shows the inferred MuSH

In [None]:
plt.figure(figsize=(16, 5))
plt.subplot(121)
ksfs.plot(clr=True)
plt.subplot(122)
ksfs.μ.plot(t_gen=t_gen, clr=True, alpha=0.75)
ksfs.μ.plot(('TCC>TTC',), t_gen=t_gen, clr=True, lw=5)
plt.xscale('log')
plt.xlim([1e3, 1e6]);

We can also plot the MuSH as a heatmap with the y axis representing time (with an interface similar to the `kSFS.clustermap` method above).

In [None]:
ksfs.μ.clustermap(t_gen=t_gen, figsize=(17, 7), xticklabels=True, robust=True, cmap='RdBu_r')

Now that you have a MuSH, you can start answering questions about mutation spectrum history! 🤸