### Script preparation

This part covers the explainability module. It uses local interpretable model-agnostic explanations (LIME) to examine which features drive the model's predictions for a patient and to check whether those predictions are biologically meaningful.

Make sure all packages from the main pipeline are already installed. This code additionally requires:
1. "lime" v0.2.0.1, which implements the LIME algorithm.

The following phrase marks a cell where you must provide an input:
<font color='orange'>"**INPUT:**"</font>

Other cells do not require any alterations and should be run without any change in the code.

In [None]:
!pip install lime==0.2.0.1

Import the required packages for the entire code.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV, StratifiedGroupKFold
import seaborn as sns
import flowio
import flowutils
import FlowCal
import lime
import lime.lime_tabular
import warnings
warnings.filterwarnings('ignore')

### Prerequisites

A few notes:
1. Keep the original training dataset at hand, as LIME needs it.
2. Expect a runtime of 10-20 minutes per patient, depending on your system's specifications.

<font color='orange'>**INPUT:** Load the training dataset into a pandas DataFrame. Set the name of the training dataset as shown below:

*dataset_name = 'example.csv'*

In [2]:
dataset_name = 'Sample_training_dataset.csv'    
data = pd.read_csv(dataset_name) 

Define the training data, labels, and group labels. "Population" is the column holding the population annotations and "Batch" is the column holding the patient or batch ID. "FSC-H" is also excluded from the training features.

In [3]:
X = data.drop(['Population', 'FSC-H', 'Batch'], axis=1)
y = data['Population']
groups = data['Batch']
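
As a quick sanity check, the split above can be reproduced on a small table. The sketch below assumes a hypothetical toy DataFrame with invented values; only the column names `Population`, `Batch`, and `FSC-H` and the two population labels come from this notebook.

```python
import pandas as pd

# Hypothetical toy table; all numeric values and batch IDs are invented.
toy = pd.DataFrame({
    'FSC-A': [52000.0, 61000.0, 48000.0],
    'FSC-H': [50000.0, 59000.0, 47000.0],
    'SSC-A': [12000.0, 30000.0, 9000.0],
    'Population': ['Erythroid Cells', 'Residual Leukemic Cells', 'Erythroid Cells'],
    'Batch': ['P01', 'P01', 'P02'],
})

# Fail early if a required column is missing.
required = {'Population', 'Batch'}
missing = required - set(toy.columns)
if missing:
    raise KeyError(f'Missing required columns: {sorted(missing)}')

# Same split as above: features, labels, and group IDs.
# errors='ignore' keeps the call valid even if FSC-H is absent.
X_toy = toy.drop(columns=['Population', 'FSC-H', 'Batch'], errors='ignore')
y_toy = toy['Population']
groups_toy = toy['Batch']

print(list(X_toy.columns))
```

Checking for the required columns before the split gives a clearer error message than the `KeyError` raised by `drop` when a column name is misspelled.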

### Testing with explainability

A function for analyzing a test FCS file. It saves a table with the percentage of each population present, saves the requested figures, and saves an explainability matrix computed with LIME.
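
Two steps inside the function below are plain pandas operations and can be tried in isolation: the gate on `FSC-A / FSC-H < 2` and the percentage table restricted to non-erythroid events. The following is a minimal sketch on invented values. The `Lymphocytes` label is hypothetical, and in the real function the `Predicted` column is added by the model after gating rather than being present up front.

```python
import pandas as pd

# Invented events; the third row has an FSC-A/FSC-H ratio of 2.5 and is gated out.
events = pd.DataFrame({
    'FSC-A': [50000.0, 60000.0, 100000.0, 55000.0, 40000.0, 45000.0, 52000.0],
    'FSC-H': [48000.0, 58000.0, 40000.0, 53000.0, 39000.0, 44000.0, 50000.0],
    'Predicted': ['Residual Leukemic Cells', 'Erythroid Cells',
                  'Residual Leukemic Cells', 'Lymphocytes', 'Lymphocytes',
                  'Lymphocytes', 'Erythroid Precursors'],
})

# Keep events whose area/height ratio is below 2; copy so later edits do not touch a view.
gated = events.loc[events['FSC-A'] / events['FSC-H'] < 2].copy()

# Exclude erythroid populations before computing percentages.
erythroid = ['Erythroid Cells', 'Erythroid Precursors']
wbc = gated.loc[~gated['Predicted'].isin(erythroid)]

# Percentage of each remaining population among the retained events.
percentages = wbc['Predicted'].value_counts(normalize=True) * 100
print(percentages)
```

Note that the percentages are relative to the events left after both filters, not to all events in the file.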

<font color='orange'>Please read the READ ME comment inside the function below.</font>

In [4]:
def analyze(address, ml_model, dotplot_num, dotplot_params):
    file_address = address
    global file_name
    file_name = file_address.replace('.fcs', '')
    fcs_file = flowio.FlowData(file_address)
    try:
        spill, markers = flowutils.compensate.get_spill(fcs_file.text['spill'])
    except KeyError:
        spill, markers = flowutils.compensate.get_spill(fcs_file.text['spillover'])
    raw_data = np.reshape(fcs_file.events, (-1, fcs_file.channel_count))
    fluoro_indices = []
    for channel in fcs_file.channels:
        if fcs_file.channels[channel]['PnN'] in markers:
            fluoro_indices.append(int(channel) - 1)
    fluoro_indices.sort()
    comp_data = flowutils.compensate.compensate(raw_data, spill, fluoro_indices)
    channel_list = []
    for i in range(1, fcs_file.channel_count + 1):
        channel_list.append(fcs_file.text['p{}n'.format(i)])
    flow_data = FlowCal.transform.to_rfi(comp_data, amplification_type=(tuple([(0, 0)] * len(channel_list))))
    events = pd.DataFrame(flow_data, columns=channel_list)
    total_events = events.loc[events['FSC-A'] / events['FSC-H'] < 2].copy()
    for ch in ['FSC-H', 'SSC-H', 'Time']:
        try:
            total_events.drop(columns=[ch], inplace=True)
        except KeyError:
            pass
    model = ml_model
    
    # READ ME:
    # If you are testing the provided sample files (csv and fcs), also run the next line of code
    # (remove the leading #) so that the columns of the sample fcs match those of the training dataset.
    #total_events = total_events[['FSC-A', 'SSC-A', 'FITC-A', 'PE-A', 'PerCP-Cy5.5-A', 'PE-Cy7-A', 'APC-A', 'APC-R700-A', 'APC-Cy7-A', 'V450-A', 'V500-C-A']]
    total_events['Predicted'] = model.predict(total_events)
    global wbc_events
    wbc_events = total_events.loc[(total_events['Predicted'] != 'Erythroid Cells') & (total_events['Predicted'] != 'Erythroid Precursors')]
    results = wbc_events['Predicted'].value_counts(normalize=True) * 100
    results.to_csv('{} results.csv'.format(file_name))
    global mrd_events
    mrd_events = wbc_events.loc[wbc_events['Predicted'] == 'Residual Leukemic Cells']
    for n in range(dotplot_num):
        dot_plot(dotplot_params[n][0], dotplot_params[n][1])
    i_list = list(total_events.columns)
    del i_list[-1]
    explanation = pd.DataFrame(index=range(len(i_list)))
    for p in sorted(list(total_events['Predicted'].value_counts().index)):
        pop_events = total_events.loc[total_events['Predicted'] == p].copy()
        try:
            sub_events = pop_events.sample(1000, random_state=13)
        except ValueError:
            sub_events = pop_events.copy()
        interp = pd.DataFrame(index=range(len(i_list)))
        for e in range(len(sub_events.drop(columns='Predicted'))):
            exp = explainer.explain_instance(np.array(sub_events.drop(columns='Predicted'))[e],
                                             model.predict_proba,
                                             top_labels=1)
            df = pd.DataFrame(exp.as_map()[exp.available_labels()[0]], columns=['Label', 'Value'])
            df = df.sort_values('Label').set_index('Label')
            interp['Event ' + str(e)] = df['Value']
        interp.fillna(0, inplace=True)
        interp = interp.T
        interp.loc['Value'] = interp.mean()
        interp = interp.T
        explanation[p] = list(interp['Value'].values)
    explanation.index = i_list
    explanation = (explanation - explanation.min()) / (explanation.max() - explanation.min())
    ax = sns.heatmap(explanation, cmap='Purples', annot=True, fmt='.2f', annot_kws={'size': 6})
    plt.xticks(rotation=45, ha='right')
    plt.title('Explainability Matrix: Patient ' + file_name)
    plt.savefig(f'{file_name} Explainability matrix heatmap.png', bbox_inches='tight', dpi=200)
    plt.close()
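
The aggregation at the end of the function above can be hard to follow, so here is a minimal sketch of the same idea without LIME itself. It assumes each explanation arrives as a list of `(feature_index, weight)` pairs, which is the shape `exp.as_map()` provides per label. The weights, the three-channel feature list, and the helper name `mean_weights` are invented for illustration.

```python
import pandas as pd

features = ['FSC-A', 'SSC-A', 'FITC-A']  # illustrative subset of channels

def mean_weights(explanations, n_features):
    """Average per-feature weights over events; features absent from an explanation count as 0."""
    per_event = pd.DataFrame(index=range(n_features))
    for e, pairs in enumerate(explanations):
        weights = pd.DataFrame(pairs, columns=['Label', 'Value']).set_index('Label')['Value']
        per_event[f'Event {e}'] = weights  # aligned on the feature index
    return per_event.fillna(0).mean(axis=1)

# Invented weights for two populations, two events each.
raw = {
    'Residual Leukemic Cells': [[(0, 0.10), (2, 0.30)],
                                [(0, 0.20), (1, -0.10), (2, 0.50)]],
    'Erythroid Cells': [[(1, 0.40)],
                        [(0, 0.20), (1, 0.20)]],
}
explanation = pd.DataFrame({pop: mean_weights(exps, len(features))
                            for pop, exps in raw.items()})
explanation.index = features

# Column-wise min-max scaling, as in the function above.
scaled = (explanation - explanation.min()) / (explanation.max() - explanation.min())
print(scaled.round(2))
```

Because each column is scaled independently, the values rank features within a population and are not comparable across populations. A population whose averaged weights are all equal would give a zero denominator and a column of NaN values.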

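A function for creating and saving figures that show the MRD population for your parameters of choice. It relies on the global variables wbc_events, mrd_events, and file_name set by analyze, and it must be defined before analyze is called.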

In [5]:
def dot_plot(x, y):
    sns.scatterplot(x=wbc_events[x], y=wbc_events[y], c='lightgrey', s=1)
    sns.scatterplot(x=mrd_events[x], y=mrd_events[y], c='maroon', s=3)
    if x in ['FSC-A', 'SSC-A']:
        plt.xscale('linear')
    else:
        plt.xscale('symlog', linthresh=1000)
        plt.xlim(left=-1000)
    if y in ['FSC-A', 'SSC-A']:
        plt.yscale('linear')
    else:
        plt.yscale('symlog', linthresh=1000)
        plt.ylim(bottom=-1000)
    plt.savefig('{}-{}-{}.png'.format(file_name, x, y))
    plt.close()
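
For intuition about the `symlog` axes used for the fluorescence channels: such a scale is linear within `±linthresh` and logarithmic beyond it, so negative values produced by compensation can be drawn alongside large positive ones. The sketch below is a simplified NumPy version of this kind of mapping, assuming base 10 and a linear-region scale factor of 1. The function name `symlog` is hypothetical and this is an illustration, not the plotting library's own code.

```python
import numpy as np

def symlog(values, linthresh=1000.0, base=10.0):
    """Simplified symmetric-log mapping: linear near zero, logarithmic beyond linthresh."""
    values = np.asarray(values, dtype=float)
    # Scale factor chosen so the linear and logarithmic branches meet at +/- linthresh.
    adj = 1.0 / (1.0 - 1.0 / base)
    abs_v = np.abs(values)
    # Clamp to linthresh so the logarithm is never taken of a ratio below 1.
    log_part = np.log(np.maximum(abs_v, linthresh) / linthresh) / np.log(base)
    outer = np.sign(values) * linthresh * (adj + log_part)
    inner = values * adj
    return np.where(abs_v <= linthresh, inner, outer)

mapped = symlog([-1000.0, 0.0, 500.0, 1000.0, 10000.0, 100000.0])
print(mapped)
```

Beyond the threshold, each decade occupies the same distance on the axis, while values between -1000 and 1000 keep their linear spacing.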

<font color='orange'>**INPUT:** Load the model from a saved pickle file. Set the model name as shown below:

*model_name = 'example.pkl'*

In [6]:
model_name = 'Sample_model.pkl'
# Use a context manager so the file handle is closed after loading.
with open(model_name, 'rb') as f:
    model = pickle.load(f)
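
A common source of errors at prediction time is a mismatch between the columns of the FCS file and the features the model was trained on. The helper below is a minimal sketch of a check for this, assuming the expected names are available as a list (for a fitted scikit-learn model trained on a DataFrame, `model.feature_names_in_` holds them). The function name `align_to_model` and the toy data are hypothetical.

```python
import pandas as pd

def align_to_model(events, expected_features):
    """Return events restricted to, and ordered as, the expected features."""
    expected = list(expected_features)
    missing = [name for name in expected if name not in events.columns]
    if missing:
        raise KeyError(f'Channels missing from the FCS file: {missing}')
    return events[expected].copy()

# Toy example: one extra channel and a different column order.
expected = ['FSC-A', 'SSC-A', 'FITC-A']  # stands in for model.feature_names_in_
events = pd.DataFrame({'SSC-A': [1.0], 'Time': [0.5], 'FITC-A': [2.0], 'FSC-A': [3.0]})
aligned = align_to_model(events, expected)
print(list(aligned.columns))
```

This does the same job as the manually listed column selection in the READ ME comment of `analyze`, but derives the list from the model instead of hard-coding it.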

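Define the LIME explainer. It is built from the training features X and takes its feature and class names from the loaded model.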

In [7]:
explainer = lime.lime_tabular.LimeTabularExplainer(np.array(X),
                                                   feature_names=list(model.feature_names_in_),
                                                   class_names=list(model.classes_),
                                                   discretize_continuous=True,
                                                   random_state=13,
                                                   verbose=False)
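
With `discretize_continuous=True`, LIME bins each continuous feature before fitting its local surrogate, and the default discretizer of `LimeTabularExplainer` is quartile-based. The NumPy sketch below illustrates what quartile binning does to a single feature. It is not LIME's own code, and the sample values and the helper name `quartile_bins` are invented.

```python
import numpy as np

def quartile_bins(train_values, new_values):
    """Assign each new value to a quartile bin (0 to 3) learned from the training values."""
    edges = np.percentile(train_values, [25, 50, 75])
    # searchsorted with the default side='left' places a value equal to an edge in the lower bin.
    return edges, np.searchsorted(edges, new_values)

train = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
edges, bins = quartile_bins(train, np.array([0.5, 3.0, 4.2, 6.9, 10.0]))
print(edges, bins)
```

Since the bin edges come from the training data, the explainer needs the original training features, which is why `np.array(X)` is passed to it.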

<font color='orange'>**INPUT:** Analyze a single FCS file. Set the file name, the number of figures you want, and the pair of parameters to plot in each figure, as shown below:

*file_name = 'example.fcs'*

*fig_num = 2*

*fig_params = [('FSC-A', 'SSC-A'), ('V500-C-A', 'SSC-A')]*

In [8]:
file_name = 'Sample_patient.fcs'
fig_num = 2
fig_params = [('FSC-A', 'SSC-A'), ('V500-C-A', 'SSC-A')]
analyze(file_name, model, fig_num, fig_params)
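
After the call finishes, the working directory should contain one results table, one image per requested dot plot, and the heatmap. The sketch below lists the expected file names by mirroring the format strings used in `analyze` and `dot_plot`; the helper name `expected_outputs` is hypothetical.

```python
def expected_outputs(address, fig_params):
    """List the output file names derived from an FCS file name."""
    stem = address.replace('.fcs', '')
    outputs = [f'{stem} results.csv']
    outputs += [f'{stem}-{x}-{y}.png' for x, y in fig_params]
    outputs.append(f'{stem} Explainability matrix heatmap.png')
    return outputs

outputs = expected_outputs('Sample_patient.fcs',
                           [('FSC-A', 'SSC-A'), ('V500-C-A', 'SSC-A')])
for name in outputs:
    print(name)
```

Comparing this list with the directory contents is a simple way to confirm that a run completed.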