In [None]:
import numpy as np
import pandas as pd

import altair as alt
import seaborn as sns
import matplotlib.pyplot as plt

import dms_variants.codonvarianttable

# Definitions: column names given to the lookup table (lut) when it is read.
# The lut file is read without a header row, so these names are assigned in
# file-column order: barcode, substitution label, number of substitutions.
barcode_column = 'barcode'
aa_column = 'aa_substitutions'
n_aa_column = 'n_aa_substitutions'

# Input file
lut_filename = 'rib_lut.csv' # The "Read Illumina Barcode" lut is sufficient

# Filters
# pandas query string applied to the lut after the per-mutation occurrence
# count has been added. The spelling 'aa_substitutions_occurence' must match
# the column name created when the lut is loaded.
mut_freq_filter = 'aa_substitutions_occurence > 1' # Filter for minimum count of mutations (independent from barcode) in lut

# Heatmap
# heatmap: draw a mutation heatmap at the end of the analysis.
# interactive: True builds an Altair chart from the lut; False uses
# dms_variants' CodonVariantTable, which needs the two settings below.
heatmap = True
interactive = True

# Only used by the non-interactive heatmap branch.
consensus_lut_filename = 'con_lut.csv'
gen_sequence = 'ATGCATTCTCAAAAGAGAGTTGTTGTTTTAGGTTCCGGTGTTATCGGTTTATCCTCTGCTTTGATTTTGGCTAGAAAGGGTTACTCCGTTCATATTTTGGCAAGAGATTTGCCAGAAGATGTCTCTTCTCAAACTTTTGCTTCTCCATGGGCTGGTGCTAATTGGACTCCTTTTATGACTTTGACTGATGGTCCAAGACAAGCTAAATGGGAAGAATCTACTTTCAAGAAGTGGGTTGAATTGGTTCCAACTGGTCATGCTATGTGGTTGAAAGGTACTAGAAGATTCGCTCAAAACGAGGATGGTTTGTTAGGTCATTGGTACAAGGATATTACCCCAAACTATAGACCATTGCCATCTTCAGAATGTCCACCAGGTGCTATTGGTGTTACTTATGATACTTTGTCTGTTCACGCTCCAAAGTACTGTCAATACTTGGCTAGAGAATTGCAAAAGTTGGGTGCTACCTTTGAAAGAAGAACTGTTACATCTTTGGAACAAGCCTTTGATGGTGCTGATTTGGTTGTTAATGCTACTGGTTTAGGTGCTAAGTCCATTGCTGGTATTGATGATCAAGCTGCTGAACCTATTAGAGGTCAAACTGTTTTGGTTAAGTCTCCATGTAAGAGGTGTACTATGGATTCTTCTGATCCAGCTTCTCCAGCTTACATTATTCCAAGACCAGGTGGTGAAGTTATTTGTGGTGGTACTTACGGTGTTGGTGATTGGGATTTGTCAGTTAATCCAGAAACCGTCCAGAGAATTTTGAAGCACTGTTTGAGATTGGACCCAACCATTTCTTCAGATGGTACTATTGAAGGTATCGAAGTCTTGAGACACAATGTCGGTTTAAGACCAGCTAGAAGAGGTGGTCCTAGAGTTGAAGCTGAAAGAATAGTTTTGCCATTGGATAGGACCAAGTCACCATTGTCTTTAGGTAGAGGTTCTGCTAGAGCTGCCAAAGAAAAAGAAGTTACTTTGGTTCACGCTTACGGTTTTTCATCTGCTGGTTATCAACAATCTTGGGGTGCTGCTGAAGATGTTGCTCAATTGGTTGATGAAGCCTTTCAAAGATATCATGGTGCTGCTAGAGAA' # From snapgene

In [None]:
# Read the headerless lookup table, naming its columns from the config cell.
# Every missing value in the frame is replaced by the label 'wt'.
lut_columns = [barcode_column, aa_column, n_aa_column]
lut = pd.read_csv(lut_filename, header=None, names=lut_columns).fillna('wt')

# Annotate each row with how many rows of the lut carry the same
# substitution label (counted over rows, independent of the barcode).
occurrence_by_label = lut[aa_column].value_counts()
lut_freq = lut.copy()
lut_freq['aa_substitutions_occurence'] = lut[aa_column].map(occurrence_by_label)

# Rows that pass the configured frequency filter.
lut_size = lut_freq.query(mut_freq_filter).copy()

def _lut_stats_row(label, labels_all, labels_frequent):
    """Return one summary record for a group of lookup-table rows.

    Parameters
    ----------
    label
        Value stored in the 'n_muts' column ('total' or a substitution count).
    labels_all : pandas.Series
        Substitution labels of every lut row in the group.
    labels_frequent : pandas.Series
        Substitution labels of the group's rows that passed the frequency filter.

    Returns
    -------
    dict
        Row and distinct-label counts for both series.
    """
    return {
        'n_muts': label,
        'n_muts_size': len(labels_all),
        # dropna=False mirrors `.unique().shape[0]`, which counts NaN as a value.
        'n_unique_muts_size': labels_all.nunique(dropna=False),
        'n_frequent_muts_size': len(labels_frequent),
        'n_frequent_unique_muts_size': labels_frequent.nunique(dropna=False),
    }


# Distinct numbers of substitutions per variant, in ascending order.
muts_sorted = np.sort(lut[n_aa_column].unique())

# Collect the records first and build the frame once. DataFrame.append was
# removed in pandas 2.0, and growing a frame row by row is quadratic anyway.
stats_rows = [_lut_stats_row('total', lut[aa_column], lut_size[aa_column])]
for n_subs in muts_sorted:
    stats_rows.append(
        _lut_stats_row(
            n_subs,
            lut.loc[lut[n_aa_column] == n_subs, aa_column],
            lut_size.loc[lut_size[n_aa_column] == n_subs, aa_column],
        )
    )

# Rows get a clean 0..n-1 index; the append-based version left every row
# with index label 0.
lut_stats = pd.DataFrame(stats_rows)

display(lut_stats)

# One row per distinct (substitution label, count, occurrence) combination,
# wild-type rows (zero substitutions) removed, most frequent first.
for_plot_full = (
    lut_freq
    .drop(columns=[barcode_column])
    .drop_duplicates()
    .query(f"{n_aa_column} > 0")
    .sort_values(by='aa_substitutions_occurence', ascending=False)
    .reset_index(drop=True)
    .copy()
)

# Rank (row index after sorting) against occurrence, for every mutated variant.
ax_all = sns.scatterplot(
    data=for_plot_full,
    x=for_plot_full.index,
    y="aa_substitutions_occurence",
)
ax_all.set_title('Counts for all mutations')
plt.show()

# Same plot restricted to variants with exactly one substitution; the index
# is reset so the x axis is again a contiguous rank.
for_plot = for_plot_full.query(f"{n_aa_column} == 1").reset_index(drop=True)

ax_single = sns.scatterplot(
    data=for_plot,
    x=for_plot.index,
    y="aa_substitutions_occurence",
)
ax_single.set_title('Counts for the single mutations')
plt.show()

if heatmap and interactive:
    # Interactive Altair heatmap built directly from the lookup table.

    # Number of lut rows per substitution label, single substitutions only.
    single_muts = (
        lut
        .groupby([aa_column, n_aa_column], as_index=False)
        .size()
        .query(f"{n_aa_column} == 1")
        .copy()
    )

    # Split each label into its leading letter, first run of digits and
    # trailing letter (assumes labels look like 'A123C' or 'A123*' - confirm).
    single_muts['wt_aa'] = single_muts[aa_column].str.extract(r'(^[A-Z*])')
    single_muts['position'] = single_muts[aa_column].str.extract(r'([0-9]+)').astype(int)
    single_muts['mutated_aa'] = single_muts[aa_column].str.extract(r'([A-Z*]$)')

    heatmap_columns = ['position', 'wt_aa', 'mutated_aa', 'size', aa_column]
    smp = single_muts[heatmap_columns].reset_index(drop=True)

    # Manual switch: set to True to drop substitutions whose mutated residue is '*'.
    remove_stop_codon = False
    if remove_stop_codon:
        smp = smp.query('mutated_aa != "*"')

    # Lift Altair's row limit so the full table can be embedded in the chart.
    alt.data_transformers.disable_max_rows()

    # Horizontal brush drawn on the overview strip; it filters both layers below.
    brush = alt.selection_interval(encodings=['x'])

    bar = (
        alt.Chart(smp)
        .mark_bar(size=2, color='grey')
        .encode(alt.X('position:O'))
        .properties(width=alt.Step(2))
        .add_selection(brush)
    )

    # Heatmap cells: position against mutated residue, coloured by row count.
    muts = (
        alt.Chart(smp)
        .mark_rect()
        .encode(
            alt.X('position:O'),
            y='mutated_aa:O',
            color='size:Q',
            tooltip=['size', aa_column],
        )
        .transform_filter(brush)
        .add_selection(alt.selection_single())
    )

    # Text layer printing the original residue letter at each position.
    wt = (
        alt.Chart(smp)
        .mark_text()
        .encode(
            x='position:O',
            y='wt_aa:O',
            text='wt_aa:O',
            tooltip=['position'],
        )
        .transform_filter(brush)
        .add_selection(alt.selection_single())
    )

    hm = alt.vconcat(bar, muts + wt)
    display(hm)

elif heatmap:
    # Static heatmap produced by dms_variants from the consensus lookup
    # table and the gene sequence configured in the first cell.
    variants = dms_variants.codonvarianttable.CodonVariantTable(
        barcode_variant_file=consensus_lut_filename,
        geneseq=gen_sequence,
    )

    display(variants.plotMutHeatmap('all', 'aa', samples=None))