In [None]:
# Usage text for this notebook. A raw string is used so that the trailing
# backslashes (shell line continuations) are kept verbatim instead of being
# consumed as Python line continuations inside the literal.
# The parameter list mirrors the papermill parameters cell and the input check:
# every name listed here is required.
usage = r"""Run with papermill:

papermill srsnv_report.ipynb output_srsnv_report.ipynb \
    -p report_name <> \
    -p model_file <> \
    -p params_file <> \
    -p srsnv_qc_h5_file <> \
    -p output_roc_plot <> \
    -p output_LoD_plot <> \
    -p qual_vs_ppmseq_tags_table <> \
    -p training_progerss_plot <> \
    -p SHAP_importance_plot <> \
    -p SHAP_beeswarm_plot <> \
    -p trinuc_stats_plot <> \
    -p output_LoD_qual_plot <> \
    -p output_cm_plot <> \
    -p output_obsereved_qual_plot <> \
    -p output_ML_qual_hist <> \
    -p output_qual_per_feature <> \
    -p output_bepcr_hists <> \
    -p output_bepcr_fpr <> \
    -p output_bepcr_recalls <>
Then convert to html

jupyter nbconvert --to html output_srsnv_report.ipynb --no-input --output srsnv_report.html"""

In [1]:
import pandas as pd
import os
import base64
from IPython.display import Image, HTML, display
import joblib
import json
import math

pd.options.display.max_rows = 200

import matplotlib.pyplot as plt
import matplotlib.image as mpimg



In [None]:
# papermill parameters
report_name = None
model_file = None
params_file = None
srsnv_qc_h5_file = None
output_roc_plot = None
output_LoD_plot = None
qual_vs_ppmseq_tags_table = None
training_progerss_plot = None
SHAP_importance_plot = None
SHAP_beeswarm_plot = None
trinuc_stats_plot = None
output_LoD_qual_plot = None
output_cm_plot = None
output_obsereved_qual_plot = None
output_ML_qual_hist = None
output_qual_per_feature = None
output_bepcr_hists = None
output_bepcr_fpr = None
output_bepcr_recalls = None

In [None]:
# Fail fast when papermill did not inject one of the required parameters:
# collect every parameter name whose value is still None and report them all
# in a single error.
missing = []
for varname in (
    "report_name",
    "model_file",
    "params_file",
    "srsnv_qc_h5_file",
    "output_roc_plot",
    "output_LoD_plot",
    "qual_vs_ppmseq_tags_table",
    "training_progerss_plot",
    "SHAP_importance_plot",
    "SHAP_beeswarm_plot",
    "trinuc_stats_plot",
    "output_LoD_qual_plot",
    "output_cm_plot",
    "output_obsereved_qual_plot",
    "output_ML_qual_hist",
    "output_qual_per_feature",
    "output_bepcr_hists",
    "output_bepcr_fpr",
    "output_bepcr_recalls",
):
    # Look the parameter up by name in the cell's namespace.
    value = locals()[varname]
    if value is None:
        missing.append(varname)

if missing:
    raise ValueError(f"Following inputs missing:\n{os.linesep.join(missing)}")

In [None]:
# load files
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code - only point model_file at trusted artifacts.
model = joblib.load(model_file)
# The loaded object may be a dict holding the model(s) under the 'models' key ...
if isinstance(model, dict): # joblib after BIOIN-1558
    model = model['models']
# ... and/or a list of models, of which only the first one is kept.
if isinstance(model, list): # For models saved from CV
    model = model[0]
# The parameters are read from a UTF-8 encoded JSON file into `params`.
with open(params_file, 'r', encoding="utf-8") as f:
    params = json.load(f)

In [None]:
def display_test_train(image_path, titlestr, report_name='test'):
    """
    Show two PNG plots side by side in a single matplotlib figure.

    Parameters:
    - image_path: Path of the first image, without the '.png' extension.
    - titlestr: Title placed above both panels.
    - report_name: Label of the first panel; also the '.<report_name>.' token
      that is swapped for '.test.' in image_path to locate the second image.
    """
    # The companion dataset is pinned to 'test'; when report_name is also
    # 'test' the substitution is a no-op and both panels load the same file.
    other_dataset = 'test'
    primary_png = image_path + '.png'
    companion_png = image_path.replace(f".{report_name}.", f".{other_dataset}.") + '.png'

    # Load both images up front, paired with the label shown above each panel.
    panels = [
        (report_name, mpimg.imread(primary_png)),
        (other_dataset, mpimg.imread(companion_png)),
    ]

    fig, axes = plt.subplots(1, 2, figsize=(20, 10), constrained_layout=True)
    for panel_ax, (label, pixels) in zip(axes, panels):
        panel_ax.imshow(pixels)
        panel_ax.axis('off')
        panel_ax.set_title(label, fontsize=20)

    fig.suptitle(titlestr, fontsize=24, y=0.95)
    plt.show()

# Dataset name used in plot titles: the params file's base name up to its first '.'.
# os.path.basename is used instead of splitting on '/' so the path separator of
# the running platform is honoured.
dataname = os.path.basename(params_file).split('.')[0]


In [None]:
def display_with_vertical_lines(df, sep_columns=None, unique_id='dataframe'):
    """
    Render a DataFrame as an HTML table with vertical rules between its columns.

    The CSS is scoped to the table through its HTML id, so it only affects this table.

    Parameters:
    df (pd.DataFrame): The DataFrame to display (its index is rendered too).
    sep_columns (list of int): Positions that get a thicker separator. For each value c,
                               the (c + 1)-th cell of every row receives a 2px right border;
                               with the index rendered as the first cell of a row, that puts
                               the line just before DataFrame column c.
    unique_id (str): HTML id given to the table and used as the CSS selector prefix.
    """
    # Render the table and tag every '<table' opening with the requested id.
    table_html = df.to_html(border=0, classes='dataframe', justify='left', index=True).replace(
        '<table', f'<table id="{unique_id}"'
    )

    # Base rules: header/body alignment, padding and a thin right border on every cell.
    css_parts = [f'''
    <style>
    #{unique_id} th {{
        text-align: left;  /* Left-align column headers */
        padding: 6px;  /* Adjust padding if necessary */
    }}
    #{unique_id} td {{
        text-align: right;  /* Right-align table body cells */
        padding: 6px;  /* Adjust padding if necessary */
    }}
    #{unique_id} td, #{unique_id} th {{
        border-right: 1px solid #000;
    }}
    ''']

    # Optional thicker separators, one CSS rule per requested position.
    if sep_columns is not None:
        css_parts.extend(
            f"#{unique_id} td:nth-child({col + 1}), #{unique_id} th:nth-child({col + 1}) {{ border-right: 2px solid black !important; }}\n"
            for col in sep_columns
        )

    css_parts.append("</style>")

    # Emit the stylesheet followed by the table as one HTML output.
    display(HTML("".join(css_parts) + table_html))

In [None]:
def _png_img_tag(image_path, width_percent):
    """Return an <img> tag embedding the file at image_path as a base64 PNG data URI."""
    with open(image_path, "rb") as image_file:
        encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
    return f'<img src="data:image/png;base64,{encoded_image}" style="max-width: {width_percent}%; height: auto;">'


def display_image_with_max_width(image_path, w=1.0):
    """
    Displays an image in a Jupyter notebook with a width that is a fraction of the notebook frame.

    Parameters:
    - image_path: Path to the image file.
    - w: Fraction of the notebook frame width (between 0 and 1).
    """
    # round() rather than int(): int() truncates, so float error such as
    # 0.29 * 100 == 28.999999999999996 would silently become 28%.
    width_percent = round(w * 100)

    # Display the image using HTML
    display(HTML(_png_img_tag(image_path, width_percent)))


def display_images_grid(image_paths, w=1/3):
    """
    Displays multiple images in a grid layout in a Jupyter notebook.
    Images are displayed in rows, with each image's width being a fraction of the notebook frame.

    Parameters:
    - image_paths: List of paths to the image files. An empty list produces an empty output.
    - w: Fraction of the notebook frame width for each image (between 0 and 1).
         Values above 1 are laid out one image per row.

    Raises:
    - ValueError: if w is not positive.
    """
    if w <= 0:
        raise ValueError(f"w must be a positive fraction of the frame width, got {w}")

    # Same rounding rule as display_image_with_max_width (see comment there).
    width_percent = round(w * 100)

    # Number of images per row. The small tolerance keeps a float quotient that lands
    # just below an integer from losing a column, and max(1, ...) keeps the
    # range() step below non-zero when w > 1.
    images_per_row = max(1, math.floor(1 / w + 1e-9))

    # Build one flex row per chunk of images_per_row images.
    rows_html = []
    for row_start in range(0, len(image_paths), images_per_row):
        row_paths = image_paths[row_start:row_start + images_per_row]
        row_images = ''.join(_png_img_tag(path, width_percent) for path in row_paths)
        rows_html.append(
            '<div style="display: flex; justify-content: space-between;">' + row_images + '</div>'
        )

    # Display the images in a grid using HTML
    display(HTML(''.join(rows_html)))

In [None]:
# Report title, emitted as the HTML output of this cell.
display(HTML(f'<font size="6">SRSNV pipeline report </font>'))

* This report contains an analysis of the SRSNV model training.
* We train a binary classifier per read.
* The probabilities are translated to quality: quality = -10*log10(probability). 
* The quality is used as a threshold for discriminating true and false variants.

<!--TOC_PLACEHOLDER-->

# Run details

In [None]:
# Read the object stored under the 'run_info_table' key of the QC HDF5 file and
# show it as a one-column table (the last expression is the cell's output).
# NOTE(review): .to_frame() implies the stored object is a pandas Series - confirm against the writer.
run_info_table = pd.read_hdf(srsnv_qc_h5_file, key='run_info_table')
run_info_table.to_frame()

# General statistics

## Summary

## Quality

In [None]:
# Read the table stored under the 'run_quality_table_display' key of the QC HDF5
# file and render it as an HTML table whose CSS is scoped to id 'run_quality_table'.
run_quality_table = pd.read_hdf(srsnv_qc_h5_file, key='run_quality_table_display')
display_with_vertical_lines(run_quality_table, unique_id='run_quality_table')

## SNVQ vs Recall

We calculate the residual SNV rate as follows:
```
error rate in test data = # errors / # bases sequenced
```
where:
```
# errors = # of single substitution snps > filter thresh
# bases sequenced = # of bases aligned * % mapq60 * ratio_of_bases_in_coverage_range *
                    read_filter_correction_factor * recall[threshold]
```
and: 
```
# of bases aligned = mean_coverage * bases in region * downsampling factor
downsampling factor = % of the featuremap reads sampled for test set
```

In [None]:
# Show the LoD simulation plot (output_LoD_plot + '.png') under a "Test LoD simulation"
# heading, capped at 70% of the frame width.
image_path1 = output_LoD_plot+'.png'
display(HTML(f'<font size="6">Test LoD simulation </font>'))
display_image_with_max_width(image_path1, 0.7)
# Disabled legacy code, kept for reference: fixed-size display of the same plot and
# a second "Train LoD simulation" plot located by swapping the dataset token in the path.
# display(Image(filename=image_path1), width=800, height=800)
# other_dataset = 'train' if report_name=='test' else 'test'
# other_dataset = 'test'
# image_path2 = output_LoD_plot.replace(f".{report_name}.",f".{other_dataset}.")+'.png'
# display(HTML(f'<font size="6">Train LoD simulation </font>'))
# display(Image(filename=image_path2, width=800, height=800))

# SNVQ and statistics of features

## Categorical features

### SNVQ vs start/end ppmSeq tag

The table below presents the median SNVQ values on the TP (homozygous substitutions) training dataset. Numbers in square brackets are the proportion of datapoints with each start/end tag combination.

In [None]:
# Embed the SNVQ vs start/end ppmSeq tag table image, capped at 40% of the frame width.
display_image_with_max_width(qual_vs_ppmseq_tags_table+'.png', 0.4)

### Quality as function of trinuc context and alt

In [None]:
# Embed the trinucleotide-context statistics plot at up to the full frame width.
display_image_with_max_width(trinuc_stats_plot+'.png', 1.0)

## Numerical features

In [None]:
# One image per numerical feature listed in the params file; each path is the
# output_qual_per_feature prefix followed by the feature name and '.png'.
# The images are laid out three per row.
image_paths = [
    output_qual_per_feature + feature_name + '.png'
    for feature_name in params["numerical_features"]
]
display_images_grid(image_paths, w=1/3)

# Training

## Training progress

In [None]:
# Embed the training progress plot, capped at 60% of the frame width.
# (`training_progerss_plot` is the parameter's actual, misspelled, name.)
display_image_with_max_width(training_progerss_plot+'.png', 0.6)

# Feature importance: SHAP
SHAP values are an estimation of how much each feature value has contributed to the model prediction; in our binary classification case, to the model's logit value. The output logit equals an overall bias term plus the sum of all features' SHAP values for a given input. Large positive SHAP values "push" the prediction towards True, and large negative values towards False.

For example, for a linear classifier (logistic regression), the logit value is $$y = \sum_i w_i x_i,$$ where the $x_i$'s are the feature values. The SHAP value of feature $i$ for this prediction is $w_i x_i$. 

## SHAP bar plot
The following plot measures the importance of each feature by its mean absolute SHAP value.

In [None]:
# Embed the SHAP importance (bar) plot, capped at 40% of the frame width.
display_image_with_max_width(SHAP_importance_plot+'.png', 0.4)

## SHAP beeswarm plot
More insight into the model's prediction is available from the beeswarm plot. This is a scatter plot of SHAP values (each point is the SHAP value for one datapoint), providing insight into how SHAP values are distributed by feature. Moreover, the colors represent the feature value at the given datapoint, revealing whether particular values tend to affect the predictions in a certain direction. 

In [None]:
# Embed the SHAP beeswarm plot at up to the full frame width.
display_image_with_max_width(SHAP_beeswarm_plot+'.png', 1.0)

The number in brackets next to categorical features refers to the appropriate colormap on the right that corresponds to the possible values of the feature. 

---
# Remnants from old report

In [None]:
#display_test_train(output_LoD_qual_plot,"LoD vs. ML qual \n"+dataname, report_name=report_name)

In [None]:
#display_test_train(output_roc_plot,"ROC curve \n"+dataname, report_name=report_name)

## Residual SNV rate vs Retention and LoD simulation

# Training metrics

In [None]:
# title = 'Confusion matrix'
# display(HTML(f'<font size="4">{title}</font>'))
# display_test_train(output_cm_plot,dataname, report_name=report_name)

In [None]:
# Only the "ML qual calibration by category" figure is live in this cell; the
# surrounding commented-out code is disabled legacy output kept for reference.
# title = 'ML qual hists by class'
# display(HTML(f'<font size="4">{title}</font>'))
# display_test_train(output_ML_qual_hist,dataname, report_name=report_name)

# display(HTML(f'<font size="4">Stratified by category </font>'))
# subset_data_list = [
#     'mixed_cycle_skip',
#     'mixed_non_cycle_skip',
#     'non_mixed_cycle_skip',
#     'non_mixed_non_cycle_skip',
#     'cycle_skip',
#     'non_cycle_skip',
# ]

# for suffix in subset_data_list:
#     image_path = output_bepcr_hists + suffix
#     if os.path.isfile(image_path+'.png'):
#         display_test_train(image_path,dataname, report_name=report_name)

# Heading, then the output_bepcr_fpr plot next to its '.test.' counterpart, titled with the dataset name.
display(HTML(f'<font size="4">ML qual calibration by category </font>'))
display_test_train(output_bepcr_fpr,dataname, report_name=report_name)

# display(HTML(f'<font size="4">Recall rate by category </font>'))
# display_test_train(output_bepcr_recalls,dataname, report_name=report_name)

In [None]:
# display(HTML(f'<font size="4">Feature distribution per label</font>'))
# for f in model.feature_names_in_:
#     image_path = output_qual_per_feature + f
#     if os.path.isfile(image_path + '.png'):        
#         display_test_train(image_path,dataname, report_name=report_name)


In [None]:

# Print the model parameters followed by a fixed selection of entries from the
# params file, as an indented bullet list.
display(HTML('<font size="4">Input parameters: </font>'))

for item, item_value in params['model_parameters'].items():
    print(f"    * {item}: {item_value}")

# Entries of `params` to report; each is printed according to its container type.
params_for_print = [
    'numerical_features',
    'categorical_features_dict',
    'train_set_size',
    'test_set_size',
]
for p in params_for_print:
    p_value = params[p]
    # isinstance() is the idiomatic type check (and also accepts subclasses).
    if isinstance(p_value, list):
        # Lists: one sub-bullet per element.
        print(f"    * {p}:")
        for pp in p_value:
            print(f"        - {pp}")
    elif isinstance(p_value, dict):
        # Dicts: one "key: value" sub-bullet per entry.
        print(f"    * {p}:")
        for k, v in p_value.items():
            print(f"        - {k}: {v}")
    else:
        # Scalars: printed inline.
        print(f"    * {p}: {p_value}")
