## Plot sequencing effiency

- What fraction of the *total* data is on target?

In [272]:
import os
import sys
import pandas as pd
import numpy as np


from nomadic.pipeline.qcbams.plot import MappingStatesAndColors


import plotly.graph_objects as go
import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams["figure.dpi"] = 300

In [378]:
def get_experiment_dir(expt, barcoding="single_end"):
    d = f"../experiments/{expt}/nomadic/guppy/hac/{barcoding}"
    return d

def load_mapping_df(expt, barcoding="single_end", state="primary"):
    expt_dir = get_experiment_dir(expt, barcoding)
    df = pd.read_csv(f"{expt_dir}/qc-bams/table.mapping.{state}_state.csv")
    return df

def load_balance_df(expt, barcoding="single_end"):
    expt_dir = get_experiment_dir(expt, barcoding)
    df = pd.read_csv(f"{expt_dir}/target-extraction/table.target_coverage.overview.csv")
    return df

In [379]:
def load_seq_effiency(expt, barcoding="single_end"):
    expt_dir = get_experiment_dir(expt, barcoding)
    df = pd.read_csv(
        f"../../nomadic/experiments/{expt_dir}"
        "/qc-efficiency/table.overall.effiency.csv")
    return df

In [380]:
msc = MappingStatesAndColors()

## Load data

*Experiment*

In [381]:
#expt = "2021-11-14_strain-validation-flongle-lfb"
expt = "2022-02-16_zmb-discards-8plex"

*NOMADIC*

In [382]:
eff_df = load_seq_effiency(expt)
eff_df.columns = ["subset", "n_reads", "n_bases"]

In [383]:
eff_df

Unnamed: 0,subset,n_reads,n_bases
0,n_total,1635169.0,4848295000.0
1,n_passed_qc,1513165.0,4563743000.0
2,n_barcoded,1454852.0,4396009000.0
3,n_mapped,1095022.0,3416287000.0
4,n_ontarget,1016555.0,3334339000.0


*NOMADIC2*

In [384]:

mapping_df = load_mapping_df(expt)
balance_df = load_balance_df(expt)

## Visualise with Sankey

In [106]:
eff_df

Unnamed: 0,subset,n_reads,n_bases
0,n_total,1635169.0,4848295000.0
1,n_passed_qc,1513165.0,4563743000.0
2,n_barcoded,1454852.0,4396009000.0
3,n_mapped,1095022.0,3416287000.0
4,n_ontarget,1016555.0,3334339000.0


In [254]:
link = {
    "source": [0, 0, 2, 2, 4, 4, 6, 6],
    "target": [1, 2, 3, 4, 5, 6, 7, 8],
    "value": [1635169-1513165, 
              1513165, 
              1513165-1454852, 
              1454852,
              1454852-1095022,
              1095022,
              1095022-1016555,
              1016555
             ],
}
node = {
    "pad": 15,
    "thickness": 15,
    "line": dict(color = "black", width = 0.5),
    "color": sns.color_palette("viridis", 9).as_hex(), #[rgb(c) for c in sns.color_palette("Set1", 3)],
    "label": ["Total Reads", 
              "Failed QC", "Passed QC", 
              "Unclassified", "Barcoded",
              "Unmapped", "Mapped",
              "Off-target", "On-target"
             ],
    "x": [0, 0.2, 0.2, 0.4, 0.4, 0.6, 0.6],
    "y": [0, 1, 0.1, 1, 0.1, 1, 0.1]
}

In [256]:
plot_data = go.Sankey(
    domain={'x': [0, 1], 'y': [0, 1]},
    link=link,
    node=node,
    arrangement="snap"
)

In [257]:
fig = go.Figure(plot_data)
fig.show()

- This would be version one, then...

## Adding  genes

In [261]:
gene_sum_df = (balance_df
               .query("overlap == 'any'")
               .groupby("gene_name")
               .sum()
              )

In [263]:
gene_sum_df

Unnamed: 0_level_0,reads_total,reads_mapped,bases_total,bases_mapped,mean_read_length,mean_read_qual,mismatches,error_rate
gene_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
CRT1,22339.0,22339.0,75566717.0,70329993.0,152324.0,1270.3,5627319.0,3.391961
DHFR,255256.0,255256.0,889405773.0,820265221.0,158056.0,1207.2,43882354.0,2.271811
DHPS,119563.0,119563.0,408919847.0,387680440.0,150972.0,1238.3,21275815.0,2.322249
K13,172970.0,172970.0,640288633.0,596591684.0,164642.0,1131.5,33338717.0,2.364168
MDR1,226662.0,226662.0,802392488.0,782838399.0,160486.0,1100.1,29951190.0,1.637052
MSP2,48396.0,48396.0,175051434.0,162679831.0,159902.0,1204.4,17494919.0,4.533634
PMI,100450.0,100450.0,327386326.0,289416485.0,146909.0,1115.2,15839647.0,2.297388
PMIII,70918.0,70918.0,236173342.0,224549826.0,145672.0,1296.7,13752535.0,2.613582


In [347]:
link = {
    "source": [0, 0, 2, 2, 4, 4, 6, 6] + [8]*8,
    "target": [1, 2, 3, 4, 5, 6, 7, 8] + list(range(9, 9+8)),
    "value": [1635169-1513165, 
              1513165, 
              1513165-1454852, 
              1454852,
              1454852-1095022,
              1095022,
              1095022-1016555,
              1016555
             ] + gene_sum_df["reads_total"].tolist(),
}
node = {
    "pad": 15,
    "thickness": 15,
    "line": dict(color = "black", width = 0.5),
    "color": sns.color_palette("viridis", 9).as_hex()
    + sns.color_palette("Spectral", 8).as_hex(), #[rgb(c) for c in sns.color_palette("Set1", 3)],
    "label": ["Total Reads", 
              "Failed QC", "Passed QC", 
              "Unclassified", "Barcoded",
              "Unmapped", "Mapped",
              "Off-target", "On-target"
             ] + gene_sum_df.index.tolist(),
    "x": [0, 0.15, 0.15, 0.3, 0.3, 0.45, 0.45, 0.6, 0.6] + [1]*8,
    "y": [0, 0.1, 0.1, 0.5, 0.2, 0.5, 0.3, 0.5, 0.4] + list(np.linspace(-0.2, 0.2, 9))
}

In [348]:
plot_data = go.Sankey(
    domain={'x': [0, 1], 'y': [0, 1]},
    link=link,
    node=node,
    arrangement="snap"
)

In [349]:
fig = go.Figure(plot_data)
fig.show()

TODO:
- Automate or partially automate
- Add percentages and additional information
- For clinical discards, how to deal with the fact that a whole row of samples were dropped? That needs to come outof the whole analysis now...

## Automate from input data (why?)

*Before amplicons*

In [365]:
sources = np.repeat(np.arange(0, 8, 2), 2).tolist()
targets = np.arange(0, 9).tolist()

In [371]:
# Values
STAT = "n_reads"
values = []
for i, row in eff_df.iterrows():
    if i == 0:
        last_value = row[STAT]
        continue
    current_value = row[STAT]
    lost = last_value - current_value
    values.append(lost)
    values.append(current_value)
    last_value = current_value
    
# Labels
labels_text = ["Total Reads", 
               "Failed QC", "Passed QC", 
               "Unclassified", "Barcoded",
               "Unmapped", "Mapped",
               "Off-target", "On-target"
              ]
norm_val = eff_df.query("subset == 'n_total'").squeeze()["n_reads"]
labels_percent = [
    f"{l} ({100*v/norm_val:.01f}%)"
    for l, v in zip(labels_text, [norm_val] + values)
]

In [372]:
labels_percent

['Total Reads (100.0%)',
 'Failed QC (7.5%)',
 'Passed QC (92.5%)',
 'Unclassified (3.6%)',
 'Barcoded (89.0%)',
 'Unmapped (22.0%)',
 'Mapped (67.0%)',
 'Off-target (4.8%)',
 'On-target (62.2%)']

In [394]:
mapping_df[msc.primary_levels].sum()

unmapped       11836
hs_mapped     347699
pf_mapped    1094955
dtype: int64

In [402]:
eff_df.index = eff_df["subset"]

In [407]:
eff_df.drop("n_mapped")

Unnamed: 0_level_0,subset,n_reads,n_bases
subset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
n_total,n_total,1635169.0,4848295000.0
n_passed_qc,n_passed_qc,1513165.0,4563743000.0
n_barcoded,n_barcoded,1454852.0,4396009000.0
n_ontarget,n_ontarget,1016555.0,3334339000.0


In [408]:
ORDER = [
    "n_total",
    "n_passed_qc",
    "n_barcoded",
    "pf_mapped",
    "hs_mapped",
    "unmapped",
    "n_ontarget"
]

In [409]:
values = pd.concat(
    [eff_df["n_reads"], 
     mapping_df[msc.primary_levels].sum()]
)[ORDER]

In [410]:
values

n_total        1635169.0
n_passed_qc    1513165.0
n_barcoded     1454852.0
pf_mapped      1094955.0
hs_mapped       347699.0
unmapped         11836.0
n_ontarget     1016555.0
dtype: float64

In [411]:
from dataclasses import dataclass

In [442]:
@dataclass
class Node:
    
    NORM_VALUE = eff_df.query("subset == 'n_total'")["n_reads"].squeeze()
    
    name: str
    value: int
    color: str
    index: int
    label: str = None
        
    def __post_init__(self):
        self.label = f"{self.name} ({100*self.value/self.NORM_VALUE:.01f}%)"

- This will be hard to automate, because of the different branching structures
- One solution would be to *add* the failing nodes here

In [443]:
Node("hi", 2, "blue", 3)

Node(name='hi', value=2, color='blue', index=3, label='hi (0.0%)')

In [444]:
nodes = [
    Node("Total Reads",
         value=1635169,
         color="black",
         index=0),
    Node("No. Failed QC",
         value=1635169-1513165,
         color="lightgrey",
         index=1),
    Node("No. Passed QC",
         value=1513165,
         color="lightgrey",
         index=2),
    Node("Unclassified",
         value=1513165-1454852,
         color="lightgrey"
         index=3,
        )
    Node("No. Barcoded",
         value=1454852,
         color="lightgrey",
         index=4),
    
]

In [445]:
nodes

[Node(name='Total Reads', value=1635169, color='black', index=0, label='Total Reads (100.0%)'),
 Node(name='No. Failed QC', value=122004, color='lightgrey', index=2, label='No. Failed QC (7.5%)'),
 Node(name='No. Passed QC', value=1513165, color='lightgrey', index=1, label='No. Passed QC (92.5%)')]

In [446]:
nodes[0].label

'Total Reads (100.0%)'

### Visualise

In [None]:
link = {
    "source": [0, 0, 2, 2, 4, 4, 6, 6] + [8]*8,
    "target": [1, 2, 3, 4, 5, 6, 7, 8] + list(range(9, 9+8)),
    "value": [1635169-1513165, 
              1513165, 
              1513165-1454852, 
              1454852,
              1454852-1095022,
              1095022,
              1095022-1016555,
              1016555
             ] + gene_sum_df["reads_total"].tolist(),
}
node = {
    "pad": 15,
    "thickness": 15,
    "line": dict(color = "black", width = 0.5),
    "color": sns.color_palette("viridis", 9).as_hex()
    + sns.color_palette("Spectral", 8).as_hex(), #[rgb(c) for c in sns.color_palette("Set1", 3)],
    "label": ["Total Reads", 
              "Failed QC", "Passed QC", 
              "Unclassified", "Barcoded",
              "Unmapped", "Mapped",
              "Off-target", "On-target"
             ] + gene_sum_df.index.tolist(),
    "x": [0, 0.15, 0.15, 0.3, 0.3, 0.45, 0.45, 0.6, 0.6] + [1]*8,
    "y": [0, 0.1, 0.1, 0.5, 0.2, 0.5, 0.3, 0.5, 0.4] + list(np.linspace(-0.2, 0.2, 9))
}

plot_data = go.Sankey(
    domain={'x': [0, 1], 'y': [0, 1]},
    link=link,
    node=node,
    arrangement="snap"
)

fig = go.Figure(plot_data)
fig.show()