# Synonymous mutation spectrum
Get input variables from [papermill](https://papermill.readthedocs.io/) parameterization (note next cell is tagged as `parameters`):

In [1]:
# Default location of the aggregated mutation-counts CSV read by this notebook.
# This cell is tagged `parameters`, so papermill can supply a different value.
input_csv = "results/mutation_counts/aggregated.csv"

Import Python modules:

In [2]:
import os

import altair as alt

import numpy

import pandas as pd

import sklearn.decomposition

Read the mutation counts and assign mutation types:

In [3]:
# Load the per-mutation counts, then label every row with its substitution
# type, built from the first and last characters of `nt_mutation`
# (for example "C16466T" becomes "CtoT").
mutation_counts = pd.read_csv(input_csv)
mutation_counts = mutation_counts.assign(
    mut_type=[f"{mut[0]}to{mut[-1]}" for mut in mutation_counts["nt_mutation"]]
)

mutation_counts

Unnamed: 0,protein,aa_mutation,nt_mutation,codon_change,synonymous,count,nt_site,exclude,exclude_reason,clade,subset,mut_type
0,ORF1ab,ORF1ab,C16466T,CCA>CTA,False,62,16466,False,,19A,all,CtoT
1,M,M,T26767C,ATC>ACC,False,54,26767,False,,19A,all,TtoC
2,ORF1a;ORF1ab,ORF1a;ORF1ab,C3037T,TTC>TTT;TTC>TTT,True,50,3037,False,,19A,all,CtoT
3,ORF1ab,ORF1ab,C19220T,GCT>GTT,False,47,19220,False,,19A,all,CtoT
4,ORF1ab,ORF1ab,A17236G,ATA>GTA,False,45,17236,False,,19A,all,AtoG
...,...,...,...,...,...,...,...,...,...,...,...,...
974785,ORF1a;ORF1ab,ORF1a;ORF1ab,G2900A,GTC>ATC;GTC>ATC,False,1,2900,False,,22C,England,GtoA
974786,ORF1a;ORF1ab,ORF1a;ORF1ab,G2867A,GTA>ATA;GTA>ATA,False,1,2867,False,,22C,England,GtoA
974787,ORF1a;ORF1ab,ORF1a;ORF1ab,G2782T,GTG>GTT;GTG>GTT,True,1,2782,False,,22C,England,GtoT
974788,ORF1a;ORF1ab,ORF1a;ORF1ab,G2525A,GAG>AAG;GAG>AAG,False,1,2525,False,,22C,England,GtoA


For each clade plot the top mutations as a fraction of all mutations in that clade, just using the "all" subset.
You can mouseover points to highlight mutations (which will highlight all mutations at that site on all facets), and click the legend to show/hide excluded or non-excluded mutations.
This plot is useful for identifying apparent outlier sites with aberrantly high mutation counts:

In [27]:
top_n = 100  # plot this many per clade

# Table to plot, one row per retained mutation in the "all" subset:
#  - `freq` is the mutation's count divided by the summed count of *all*
#    mutations in its clade. It is computed before truncating to the top
#    `top_n`, so the denominator is the whole clade rather than just the
#    plotted rows.
#  - `rank` numbers the retained mutations within each clade by decreasing
#    frequency (ties broken by row order).
#  - `exclude` is recoded from booleans to "yes"/"no" strings for the legend.
mutation_freqs = (
    mutation_counts
    .query("subset == 'all'")
    .assign(
        freq=lambda x: x["count"] / x.groupby("clade")["count"].transform("sum"),
    )
    .sort_values(["clade", "count"], ascending=False)
    .groupby("clade")
    .head(n=top_n)
    .assign(
        rank=lambda x: x.groupby("clade")["freq"].rank(ascending=False, method="first"),
        exclude=lambda x: x["exclude"].map({True: "yes", False: "no"}),
    )
)

# Legend-bound selection used to show/hide excluded vs non-excluded mutations;
# both categories start out selected.
select_exclude = alt.selection_multi(
    fields=["exclude"], bind="legend", init=[{"exclude": "yes"}, {"exclude": "no"}],
)

# Mouseover selection keyed on `nt_site`, so hovering a point highlights every
# mutation at that site; nothing is highlighted when the selection is empty.
select_site = alt.selection_single(
    fields=["nt_site"], on="mouseover", empty="none",
)

mutation_freqs_chart = (
    alt.Chart(mutation_freqs)
    .encode(
        x="rank",
        y="freq",
        # Highlighted points get a visible outline and a larger marker.
        strokeWidth=alt.condition(select_site, alt.value(2), alt.value(0)),
        color=alt.Color("exclude", scale=alt.Scale(domain=["yes", "no"])),
        shape=alt.Shape("synonymous"),
        size=alt.condition(select_site, alt.value(50), alt.value(25)),
        tooltip=["nt_site", "nt_mutation", "count", "freq"],
    )
    .mark_point(filled=True, stroke="black")
    .properties(width=250, height=100)
    .facet("clade", columns=3)
    .add_selection(select_exclude, select_site)
    .transform_filter(select_exclude)
)

mutation_freqs_chart

Tally mutation type counts among **only synonymous** mutations for each clade and subset, separating reversions to reference from other mutations:

In [None]:
# Sum the counts of synonymous mutations for every combination of clade,
# subset, mutation type and reversion-to-reference status.
# NOTE(review): `reversion_to_ref` is not among the columns shown in the table
# preview earlier in this notebook -- confirm the input CSV provides it.
mut_type_counts = mutation_counts.query("synonymous").groupby(
    ["clade", "subset", "mut_type", "reversion_to_ref"], as_index=False
)["count"].sum()

# Distinct mutation types present among the tallied synonymous mutations.
mut_types = mut_type_counts["mut_type"].unique().tolist()

Plot total mutation counts for each clade and subset on a log scale:

In [None]:
# Total synonymous mutation count for each clade within each subset
# (summing over mutation type and reversion status).
clade_counts = (
    mut_type_counts
    .groupby(["clade", "subset"], as_index=False)
    .aggregate({"count": "sum"})
)

# One bar per clade, one column of bars per subset. The y-axis uses a log
# scale so that clades with very different totals remain readable together.
clade_counts_chart = (
    alt.Chart(clade_counts)
    .encode(
        x="clade",
        y=alt.Y("count", title="total mutations", scale=alt.Scale(type="log")),
        tooltip=["clade", "count"],
        column=alt.Column("subset", title=None),
    )
    .mark_bar()
    .properties(width=alt.Step(12), height=175)
)

clade_counts_chart

Plot fraction of mutation counts from reversions to reference.
Below you can see these fractions are high, probably indicating there is some issue with calling reversions to reference that is still plaguing the data and such reversions should perhaps be ignored:

In [None]:
# Count per clade / subset / reversion status, summed over mutation types.
reversion_fracs = mut_type_counts.groupby(
    ["clade", "subset", "reversion_to_ref"], as_index=False
)["count"].sum()

# Express each count as a fraction of its clade-and-subset total.
reversion_fracs = reversion_fracs.assign(
    frac=reversion_fracs["count"]
    / reversion_fracs.groupby(["clade", "subset"])["count"].transform("sum")
)

# Bars per clade coloured by reversion status, one column of bars per subset.
reversion_fracs_chart = (
    alt.Chart(reversion_fracs)
    .mark_bar()
    .encode(
        x="clade",
        y=alt.Y("frac", title="fraction of mutations"),
        color="reversion_to_ref",
        column=alt.Column("subset", title=None),
        tooltip=["clade", "count", "frac"],
    )
    .properties(width=alt.Step(12), height=150)
)

reversion_fracs_chart

In [None]:
# Total counts for each clade, split by the two reversion flags.
# NOTE(review): neither `reversion_to_ref` nor `reversion_to_founder` appears in
# the table preview earlier in this notebook -- confirm the input provides them.
mutation_counts.groupby(
    ["clade", "reversion_to_ref", "reversion_to_founder"]
)[["count"]].sum()

In [None]:
# Ad hoc inspection: synonymous mutations in the "all" subset, ordered by
# decreasing count, keeping those whose `nt_mutation` text contains "3037"
# and showing the first 20.
# (Optionally add `.query("reversion_to_ref")` before sorting to restrict the
# view to reversions.)
# NOTE(review): this is a substring match, so it also matches other mutation
# strings that merely contain "3037" (e.g. a site such as 13037).
synonymous_all_by_count = (
    mutation_counts
    .query("synonymous")
    .query("subset == 'all'")
    .sort_values("count", ascending=False)
)

synonymous_all_by_count.query("nt_mutation.str.contains('3037')").head(20)



In [None]:
# Ad hoc inspection: the 20 most-counted C->T mutations among synonymous,
# non-reversion mutations in the "all" subset.
non_reversion_by_count = (
    mutation_counts
    .query("synonymous")
    .query("subset == 'all'")
    .query("not reversion_to_ref")
    .sort_values("count", ascending=False)
)

non_reversion_by_count.query("mut_type == 'CtoT'").head(20)

Get PCA of mutation spectrum separately for each clade and stratifying mutations by whether they are reversions to reference:

In [None]:
# Build a (clade, reversion status) x mutation-type matrix of frequencies.
#
# `mut_type_counts` holds one row per clade / subset / mutation type /
# reversion status. Only the "all" subset is used here: without this filter,
# counts from every subset would be pooled into `total_count`, and
# `pivot_table` (which averages duplicate entries by default) would shrink
# each frequency so rows no longer sum to one.
mut_type_freqs = (
    mut_type_counts
    .query("subset == 'all'")
    .assign(
        # Denominator: all synonymous counts for this clade and reversion status.
        total_count=lambda x: x.groupby(["clade", "reversion_to_ref"])["count"].transform("sum"),
        freq=lambda x: x["count"] / x["total_count"],
    )
    .pivot_table(
        # `total_count` is carried in the index so it survives the pivot and
        # can be restored as a column with `reset_index()` below.
        index=["clade", "reversion_to_ref", "total_count"],
        values="freq",
        columns="mut_type",
        # Mutation types never observed for a row get frequency zero.
        fill_value=0,
    )
)

# Project each row's mutation spectrum onto its first two principal components.
pca = sklearn.decomposition.PCA(n_components=2)
pca_coords = pca.fit_transform(mut_type_freqs.values)
assert len(pca_coords) == len(mut_type_freqs)

# Attach the PCA coordinates and a log10-transformed total count to each row.
mut_type_freqs_pca = (
    mut_type_freqs
    .reset_index()
    .assign(
        principal_component_1=pca_coords[:, 0],
        principal_component_2=pca_coords[:, 1],
        log10_total_count=lambda x: numpy.log10(x["total_count"]),
    )
)

Plot the mutation spectrum PCA for all clades, including both reversions to reference and non-reversions:

In [None]:
# Scatter plot of the two principal components: colour encodes the clade and
# marker shape encodes the reversion-to-reference status of each row.
mut_type_freqs_chart = (
    alt.Chart(mut_type_freqs_pca)
    .mark_point(filled=True, size=50)
    .encode(
        x="principal_component_1",
        y="principal_component_2",
        color=alt.Color("clade", scale=alt.Scale(scheme="viridis")),
        shape="reversion_to_ref",
        tooltip=["clade", "reversion_to_ref", "total_count"],
    )
)

mut_type_freqs_chart

Now do PCA on just non-reversion mutations:

In [None]:
# Keep only the frequency rows whose `reversion_to_ref` index level is False.
mut_type_freqs_no_revert = mut_type_freqs.query("reversion_to_ref == False")

# Fit a fresh two-component PCA on this subset alone (this rebinds `pca`).
pca = sklearn.decomposition.PCA(n_components=2)
pca_coords_no_revert = pca.fit_transform(mut_type_freqs_no_revert.values)
assert len(pca_coords_no_revert) == len(mut_type_freqs_no_revert)

# Restore the index levels as columns, then append the PCA coordinates and the
# log10 of the total count for each row.
mut_type_freqs_no_revert_pca = mut_type_freqs_no_revert.reset_index()
mut_type_freqs_no_revert_pca["principal_component_1"] = pca_coords_no_revert[:, 0]
mut_type_freqs_no_revert_pca["principal_component_2"] = pca_coords_no_revert[:, 1]
mut_type_freqs_no_revert_pca["log10_total_count"] = (
    numpy.log(mut_type_freqs_no_revert_pca["total_count"]) / numpy.log(10)
)

Plot the PCA on just non-reversion mutations:

In [None]:
# Slider bounds: from the smallest log10 total count (truncated to an integer)
# up to the largest one in the table.
slider_min = int(mut_type_freqs_no_revert_pca["log10_total_count"].min())
slider_max = mut_type_freqs_no_revert_pca["log10_total_count"].max()

# Slider-bound selection holding the minimum log10 total count; starts at 4.
total_count_selection = alt.selection_single(
    fields=["log10_total_count"],
    init={"log10_total_count": 4},
    bind=alt.binding_range(
        name="minimum log10 total counts",
        min=slider_min,
        max=slider_max,
    )
)

# Scatter plot of the two principal components coloured by clade, showing only
# rows whose log10 total count is at least the slider value.
mut_type_freqs_no_revert_chart = (
    alt.Chart(mut_type_freqs_no_revert_pca)
    .mark_point(filled=True, size=50)
    .encode(
        x="principal_component_1",
        y="principal_component_2",
        color=alt.Color("clade", scale=alt.Scale(scheme="viridis")),
        tooltip=["clade", "total_count"],
    )
    .add_selection(total_count_selection)
    .transform_filter(
        total_count_selection.log10_total_count <= alt.datum.log10_total_count
    )
)

mut_type_freqs_no_revert_chart