In [None]:
# Raw string: the trailing backslashes are shell line continuations and must stay in the
# text (in a normal string literal, backslash + newline silently joins the lines).
usage = r"""Run with papermill:

papermill srsnv_report.ipynb output_srsnv_report.ipynb \
    -p model_file <> \
    -p params_file <> \
    -p srsnv_qc_h5_file <> \
    -p output_LoD_plot <> \
    -p qual_vs_ppmseq_tags_table <> \
    -p training_progerss_plot <> \
    -p SHAP_importance_plot <> \
    -p SHAP_beeswarm_plot <> \
    -p trinuc_stats_plot <> \
    -p output_qual_per_feature <> \
    -p qual_histogram <> \
    -p logit_histogram <> \
    -p calibration_fn_with_hist <>

Then convert to html:

jupyter nbconvert --to html output_srsnv_report.ipynb --no-input --output srsnv_report.html"""

In [1]:
import functools
import pandas as pd
import os
import base64
from IPython.display import Image, HTML, display
import joblib
import json
import math

pd.options.display.max_rows = 200

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from ugvc.mrd.srsnv_plotting_utils import signif

In [None]:
# papermill parameters
# NOTE(review): presumably this cell carries the papermill "parameters" tag, so values passed
# with `papermill ... -p <name> <value>` (see the usage string) override these defaults;
# confirm the tag in the .ipynb metadata. Every input defaults to None so that the input
# check in the following cell can report any parameter that was not supplied.
model_file = None  # model saved with joblib; loaded with joblib.load further down
params_file = None  # JSON file with the training parameters
srsnv_qc_h5_file = None  # HDF5 file with the QC tables read via pd.read_hdf
output_LoD_plot = None  # image path prefix; '.png' is appended when displayed
qual_vs_ppmseq_tags_table = None  # image path prefix; '.png' is appended when displayed
training_progerss_plot = None  # image path prefix; '.png' is appended (spelling kept: it is the CLI parameter name)
SHAP_importance_plot = None  # image path prefix; '.png' is appended when displayed
SHAP_beeswarm_plot = None  # image path prefix; '.png' is appended when displayed
trinuc_stats_plot = None  # image path prefix; '.png' is appended when displayed
output_qual_per_feature = None  # path prefix; '<feature name>.png' is appended per numerical feature
qual_histogram = None  # image path prefix; '.png' is appended when displayed
logit_histogram = None  # image path prefix; '.png' is appended when displayed
calibration_fn_with_hist = None  # image path prefix; '.png' is appended when displayed


In [None]:
# Fail fast when papermill did not supply one of the required inputs.
# globals() is used (not locals()) because the lookup happens inside a comprehension;
# at notebook cell level both refer to the same namespace.
missing = [
    varname
    for varname in (
        "model_file",
        "params_file",
        "srsnv_qc_h5_file",
        "output_LoD_plot",
        "qual_vs_ppmseq_tags_table",
        "training_progerss_plot",
        "SHAP_importance_plot",
        "SHAP_beeswarm_plot",
        "trinuc_stats_plot",
        "output_qual_per_feature",
        "qual_histogram",
        "logit_histogram",
        "calibration_fn_with_hist",
    )
    if globals()[varname] is None
]

if missing:
    raise ValueError("Following inputs missing:\n" + os.linesep.join(missing))

In [None]:
def safe_run(method):
    """Decorate ``method`` so that a failure does not abort the report.

    Any ``Exception`` raised by the wrapped call is printed as
    ``"Error in <name>: <message>"`` and ``None`` is returned instead.
    ``BaseException`` subclasses (e.g. ``KeyboardInterrupt``) still propagate.
    The wrapper keeps the wrapped function's metadata via ``functools.wraps``.
    """

    @functools.wraps(method)
    def guarded(*args, **kwargs):
        try:
            outcome = method(*args, **kwargs)
        except Exception as exc:
            print(f"Error in {method.__name__}: {exc}")
            outcome = None
        return outcome

    return guarded

In [None]:
# Load the trained model and the training parameters.
# NOTE(review): joblib.load unpickles the file and can execute arbitrary code;
# only point model_file at trusted pipeline outputs.
model = joblib.load(model_file)
# Model files written after BIOIN-1558 are dicts holding the model(s) under 'models'.
if isinstance(model, dict):
    model = model["models"]
# Models saved from cross-validation come as a list; keep the first one.
if isinstance(model, list):
    model = model[0]
with open(params_file, encoding="utf-8") as params_fh:
    params = json.load(params_fh)

In [None]:
@safe_run
def display_test_train(image_path, titlestr, report_name='test', other_dataset='test'):
    """Show two dataset variants of the same plot side by side.

    Parameters
    ----------
    image_path : str
        Path of the left-hand image *without* the ``.png`` suffix. It is expected
        to contain ``.{report_name}.`` so the right-hand image path can be derived.
        If it does not, both panels show the same file.
    titlestr : str
        Figure super-title.
    report_name : str
        Dataset name embedded in ``image_path``; used as the left panel title.
    other_dataset : str
        Dataset name substituted for ``report_name`` to locate the right-hand image
        and used as its title. Defaults to ``'test'`` (the previously hard-coded
        value). With the default ``report_name='test'`` both panels therefore show
        the same image; pass e.g. ``other_dataset='train'`` for a real comparison.

    Errors (e.g. a missing image) are printed by ``safe_run`` instead of raised.
    """
    image_path1 = image_path + '.png'
    image_path2 = image_path.replace(f".{report_name}.", f".{other_dataset}.") + '.png'

    img1 = mpimg.imread(image_path1)
    img2 = mpimg.imread(image_path2)

    fig, ax = plt.subplots(1, 2, figsize=(20, 10), constrained_layout=True)
    for axis, img, title in ((ax[0], img1, report_name), (ax[1], img2, other_dataset)):
        axis.imshow(img)
        axis.axis('off')
        axis.set_title(title, fontsize=20)

    fig.suptitle(titlestr, fontsize=24, y=0.95)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# Dataset name: file name of params_file (after the last '/') up to its first '.'.
dataname = params_file.rpartition('/')[2].partition('.')[0]


In [None]:
@safe_run
def display_with_vertical_lines(df, sep_columns=None, unique_id='dataframe'):
    """
    Render ``df`` as an HTML table in the notebook with vertical rules between cells.

    All CSS rules are scoped to ``#<unique_id>`` so they only affect this table.

    Parameters:
    df (pd.DataFrame): Table to render; the index is included.
    sep_columns (list of int, optional): Column indices before which a thicker
        (2px) separator is drawn, implemented as a right border on the
        (col + 1)-th cell of every row. ``None`` (default) keeps only the thin
        1px rules drawn between all cells.
    unique_id (str): HTML id given to the table and used to scope the CSS.
    """
    table_html = df.to_html(
        border=0, classes='dataframe', justify='left', index=True
    ).replace('<table', f'<table id="{unique_id}"')

    # Base rules: left-aligned headers, right-aligned cells, thin vertical lines.
    css = f'''
    <style>
    #{unique_id} th {{
        text-align: left;  /* Left-align column headers */
        padding: 6px;  /* Adjust padding if necessary */
    }}
    #{unique_id} td {{
        text-align: right;  /* Right-align table body cells */
        padding: 6px;  /* Adjust padding if necessary */
    }}
    #{unique_id} td, #{unique_id} th {{
        border-right: 1px solid #000;
    }}
    '''

    # Optional thicker separators: one CSS rule per requested column.
    separator_css = '' if sep_columns is None else ''.join(
        f"#{unique_id} td:nth-child({col + 1}), #{unique_id} th:nth-child({col + 1}) {{ border-right: 2px solid black !important; }}\n"
        for col in sep_columns
    )

    display(HTML(css + separator_css + "</style>" + table_html))

In [None]:
@safe_run
def display_image_with_max_width(image_path, w=1.0):
    """
    Embed one image in the notebook output, capped at a fraction of the frame width.

    The file content is inlined as a base64 data URI (declared as ``image/png``),
    so the image is carried in the notebook output rather than linked by path.

    Parameters:
    - image_path: Path to the image file.
    - w: Maximum width as a fraction of the notebook frame (between 0 and 1);
      truncated to an integer percentage.
    """
    width_percent = int(w * 100)

    with open(image_path, "rb") as handle:
        payload = base64.b64encode(handle.read()).decode('utf-8')

    display(HTML(
        f'<img src="data:image/png;base64,{payload}" style="max-width: {width_percent}%; height: auto;">'
    ))

@safe_run
def display_images_grid(image_paths, w=1/3):
    """
    Displays multiple images in a grid layout in a Jupyter notebook.

    Images are embedded as base64 PNG data URIs and laid out in flex rows of
    ``floor(1 / w)`` images, each capped at ``w`` of the notebook frame width.
    Paths that are not existing regular files get an empty placeholder of the
    same width and are listed afterwards in a red error report.

    Parameters:
    - image_paths: Sequence of image file paths (must support len() and slicing).
    - w: Fraction of the notebook frame width for each image, in (0, 1].

    Raises:
    - ValueError: if ``w`` is outside (0, 1]. Like any other error here, it is
      printed by ``safe_run`` rather than propagated.
    """
    # Local import: only this helper needs HTML escaping.
    from html import escape

    if not 0 < w <= 1:
        raise ValueError(f"w must be in (0, 1], got {w!r}")

    width_percent = int(w * 100)
    # The epsilon keeps float round-off in 1 / w (a value landing just below an
    # integer) from dropping a column; w <= 1 guarantees at least one image per row.
    images_per_row = math.floor(1 / w + 1e-9)

    rows = []
    missing_files = []
    for start in range(0, len(image_paths), images_per_row):
        cells = []
        for image_path in image_paths[start:start + images_per_row]:
            # isfile rather than exists: a directory path would otherwise fail in open().
            if os.path.isfile(image_path):
                with open(image_path, "rb") as image_file:
                    encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
                cells.append(
                    f'<img src="data:image/png;base64,{encoded_image}" '
                    f'style="max-width: {width_percent}%; height: auto;">'
                )
            else:
                # An explicit width makes the empty placeholder actually reserve the slot.
                cells.append(
                    f'<div style="width: {width_percent}%; max-width: {width_percent}%; height: auto;"></div>'
                )
                missing_files.append(image_path)
        rows.append(
            '<div style="display: flex; justify-content: space-between;">'
            + ''.join(cells)
            + '</div>'
        )

    display(HTML(''.join(rows)))

    if missing_files:
        # Paths are inserted into HTML, so escape characters such as '<' and '&'.
        error_report = (
            '<p style="font-family: monospace; color: red;">'
            'The following files were not found:<br>'
            + '<br>'.join(escape(str(path)) for path in missing_files)
            + '</p>'
        )
        display(HTML(error_report))

In [None]:
# Report title (plain string literal; there is nothing to interpolate).
display(HTML('<font size="6">SRSNV pipeline report </font>'))

* This report contains an analysis of the SRSNV model training.
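* We train a binary classifier per SNV.
* The predicted probabilities are translated to quality: quality = -10*log10(1 - p), where p is the predicted probability that the SNV is true (i.e. the quality is the Phred-scaled error probability).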
* The quality is used as a threshold for discriminating true and false variants.

<!--TOC_PLACEHOLDER-->

In [None]:
display(HTML('<font size="5">Run Info </font>'))
run_info_table = pd.read_hdf(srsnv_qc_h5_file, key='run_info_table')
# Insert a line break after every '/' of the docker image name so the long string wraps
# in the HTML table. The membership check keeps a QC file without this row from raising
# KeyError and aborting the rest of the report.
if ('Docker image', '') in run_info_table.index and isinstance(
    run_info_table.loc[('Docker image', '')], str
):
    run_info_table.loc[('Docker image', '')] = run_info_table.loc[('Docker image', '')].replace('/', '/<br>')
# escape=False so the inserted <br> tags render as line breaks.
display(HTML(run_info_table.to_frame().to_html(escape=False)))

# Summary of quality statistics

## Summary statistics table

In [None]:
run_quality_summary_table = pd.read_hdf(
    srsnv_qc_h5_file,
    key='run_quality_summary_table',
)
# Format every value with signif(value, 4) and render the Series as a one-column table.
# NOTE(review): presumably signif rounds to 4 significant digits (not decimal places);
# confirm in ugvc.mrd.srsnv_plotting_utils.
(
    run_quality_summary_table
    .apply(signif, args=(4,))
    .astype(str)
    .to_frame()
)

## SNVQ vs Recall

In [None]:
# SNVQ vs recall (LoD) figure, shown at half the notebook width.
image_path1 = output_LoD_plot + '.png'
display_image_with_max_width(image_path=image_path1, w=0.5)

We calculate the residual SNV rate as follows:
```
error rate in test data = # errors / # bases sequenced
```
where:
```
# errors = # of single substitution snps > filter thresh
# bases sequenced = # of bases aligned * % mapq60 * ratio_of_bases_in_coverage_range *
                    read_filter_correction_factor * recall[threshold]
```
and: 
```
# of bases aligned = mean_coverage * bases in region * downsampling factor
downsampling factor = % of the featuremap reads sampled for test set
```

In [None]:
# Show the normalization factors only when the training params provide them.
if 'normalization_factors_dict' in params and params['normalization_factors_dict'] is not None:
    print('Normalization factors:')
    display(
        pd.Series(params['normalization_factors_dict'], name='').to_frame()
    )

## SNVQ percentiles

In [None]:
# SNVQ percentile table, rendered with vertical rules between cells.
run_quality_table = pd.read_hdf(
    srsnv_qc_h5_file,
    key='run_quality_table_display',
)
display_with_vertical_lines(run_quality_table, sep_columns=None, unique_id='run_quality_table')

## SNVQ histogram (TP reads)

In [None]:
# SNVQ histogram on TP reads, at half the notebook width.
display_image_with_max_width(image_path=qual_histogram + '.png', w=0.5)

# SNVQ and statistics per feature

## Categorical features

### SNVQ vs start/end ppmSeq tag

The table below presents median SNVQ values on the TP (homozygous substitutions) training dataset. Numbers in square brackets are the proportion of data points with each start/end tag combination.

In [None]:
# Median SNVQ per start/end ppmSeq tag combination, at 40% of the notebook width.
display_image_with_max_width(image_path=qual_vs_ppmseq_tags_table + '.png', w=0.4)

### Quality as function of trinuc context and alt

In [None]:
# Quality by trinucleotide context and alt base, at full notebook width.
display_image_with_max_width(image_path=trinuc_stats_plot + '.png', w=1.0)

## Numerical features

In [None]:
# One plot per numerical feature; the file name is the prefix + feature name + '.png'.
image_paths = [
    output_qual_per_feature + feature_name + '.png'
    for feature_name in params["numerical_features"]
]
display_images_grid(image_paths, w=1 / 3)

# Training

## General training information

In [None]:
training_info_table = pd.read_hdf(
    srsnv_qc_h5_file,
    key='training_info_table',
)
training_info_table.to_frame()  # display the Series as a one-column table

## ROC AUC
Values below are ROC AUC phred scores, i.e., $-10 \log_{10}(1-\text{AUC})$. NaN values indicate a problem calculating the ROC AUC score, e.g. when there are no mixed reads. 

In [None]:
roc_auc_table = pd.read_hdf(
    srsnv_qc_h5_file,
    key='roc_auc_table',
)
roc_auc_table  # shown as the cell output

## Logit histogram
The logit of a prediction is defined as $$\text{logit} = 10 \log_{10}\frac{p}{1-p}$$ where $p$ is the predicted probability to be True. When $p$ is close to 1, logit is close to ML_qual.

The following plot presents histograms of the logits. Histograms for the predictions of each data fold are calculated separately and overlaid in the plot.

In [None]:
# Per-fold logit histograms, at half the notebook width.
display_image_with_max_width(image_path=logit_histogram + '.png', w=0.5)

## ML_qual -> SNVQ mapping function
The function that maps the model's ML_qual values to SNVQ values. Histograms of ML_qual and SNVQ are also provided, as well as the derivative $\frac{d\text{ML}\_\text{qual}}{d\text{SNVQ}}$ (denoted 'deriv' in the plot).

In [None]:
# ML_qual -> SNVQ mapping with histograms, at 35% of the notebook width.
display_image_with_max_width(image_path=calibration_fn_with_hist + '.png', w=0.35)

## Training progress

In [None]:
# Training progress plot, at 60% of the notebook width.
display_image_with_max_width(image_path=training_progerss_plot + '.png', w=0.6)

# Feature importance: SHAP
SHAP values estimate how much each feature value contributed to the model prediction; in our binary classification case, that prediction is the model's logit value. For a given input, the output logit equals an overall bias term plus the sum of the SHAP values of all features. Large positive SHAP values "push" the prediction towards True, and large negative values push it towards False.

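For example, for a linear classifier (logistic regression), the logit value is $$y = b + \sum_i w_i x_i,$$ where the $x_i$'s are the feature values and $b$ is the intercept. Assuming independent features, the SHAP value of feature $i$ for this prediction is $w_i (x_i - \mathbb{E}[x_i])$, which reduces to $w_i x_i$ when the features are mean-centered. The bias term is then $\mathbb{E}[y] = b + \sum_i w_i \mathbb{E}[x_i]$.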

## SHAP bar plot
The following plot measures the importance of each feature by its mean absolute SHAP value.

In [None]:
# Mean |SHAP| per feature bar plot, at 40% of the notebook width.
display_image_with_max_width(image_path=SHAP_importance_plot + '.png', w=0.4)

## SHAP beeswarm plot
More insight into the model's prediction is available from the beeswarm plot. This is a scatter plot of SHAP values (each point is the SHAP value for one datapoint), providing insight into how SHAP values are distributed by feature. Moreover, the colors represent the feature value at the given datapoint, revealing whether particular values tend to affect the predictions in a certain direction. 

In [None]:
# SHAP beeswarm plot, at full notebook width.
display_image_with_max_width(image_path=SHAP_beeswarm_plot + '.png', w=1.0)

The number in brackets next to categorical features refers to the appropriate colormap on the right that corresponds to the possible values of the feature. 

# Data and model parameters

In [None]:

display(HTML('<font size="4">Input parameters: </font>'))

# Model hyper-parameters, one bullet per entry (nothing is printed if the key is absent).
for item, item_value in (params.get('model_parameters') or {}).items():
    print(f"    * {item}: {item_value}")

params_for_print = [
    'numerical_features',
    'categorical_features_dict',
    'train_set_size',
    'test_set_size',
]
for p in params_for_print:
    if p not in params:
        # Report a missing key instead of raising KeyError in the report's last cell.
        print(f"    * {p}: not available")
        continue
    value = params[p]
    if isinstance(value, list):
        print(f"    * {p}:")
        for pp in value:
            print(f"        - {pp}")
    elif isinstance(value, dict):
        print(f"    * {p}:")
        for k, v in value.items():
            print(f"        - {k}: {v}")
    else:
        print(f"    * {p}: {value}")
