Notebook corresponding to section 2.3.2 on the UKE dataset.

Imports

In [None]:
# Local dependencies
from NET_CUP.data_loading import data_tree, xyp, feature_type
import NET_CUP.datasources_config as datasources_config

# Other dependencies
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
import plotly.graph_objects as go
import plotly.subplots as sp

Functions

In [33]:
def get_thresholds(df: pd.DataFrame, eps: float = 1e-7):
    """
    Find, for every achievable fraction of "others" patients labeled as unsure, the open
    interval (left, right) on the patient-level sum of patch distances that labels the
    fewest pancreas/intestine patients as unsure.

    Candidate intervals are spanned by pairs (i <= j) of the sorted "others" distances and
    widened by ``eps`` on both sides, so the "others" at positions i..j fall strictly inside
    (this count is exact only if the "others" distances are distinct). For each
    others_unsure fraction the interval with the smallest pancreas/intestine unsure fraction
    is kept; ties are broken in favour of the smallest ``|left + right|``, i.e. the interval
    whose midpoint lies closest to zero.

    :param df: DataFrame with columns ``'y_true'`` (2 = others, any other value =
        pancreas/intestine) and ``'sum_patch_distances'`` (per-patient sum of patch
        decision-function values).
    :param eps: Margin added around the bounding "others" distances so that they lie
        inside the open interval. Defaults to 1e-7.
    :return: dict mapping ``others_unsure`` -> ``(pi_unsure, left_threshold, right_threshold)``.
        Key ``0.0`` maps to ``(0.0, 0.0, 0.0)`` (an empty interval, i.e. nobody is unsure);
        the remaining keys are ``k / n_others`` for ``k = 1..n_others`` and are inserted in
        increasing order.
    """
    others_distances = sorted(df.loc[df['y_true'] == 2, 'sum_patch_distances'].to_list())
    pi_distances = df.loc[df['y_true'] != 2, 'sum_patch_distances'].to_numpy()
    n_others = len(others_distances)
    n_pi = len(pi_distances)

    thresholds = {0.0: (0.0, 0.0, 0.0)}
    for i in range(n_others):
        left_threshold = others_distances[i] - eps
        for j in range(i, n_others):
            right_threshold = others_distances[j] + eps

            others_unsure = (j - i + 1) / n_others
            # Fraction of pancreas/intestine patients that would fall into the unsure interval
            pi_unsure = np.count_nonzero(
                (pi_distances > left_threshold) & (pi_distances < right_threshold)) / n_pi

            best = thresholds.get(others_unsure)
            if best is None or pi_unsure < best[0]:
                thresholds[others_unsure] = (pi_unsure, left_threshold, right_threshold)
            elif pi_unsure == best[0] and abs(left_threshold + right_threshold) < abs(best[1] + best[2]):
                # Same cost: prefer the interval centred closer to zero
                thresholds[others_unsure] = (pi_unsure, left_threshold, right_threshold)
    return thresholds


def get_predictions(y_pred, svm_distance_sum, thresholds):
    """
    Expand one patient-level prediction into one prediction per threshold entry.

    :param y_pred: Class predicted for the patient without the unsure option.
    :param svm_distance_sum: The patient's summed patch decision-function values.
    :param thresholds: dict whose values are tuples ``(_, left, right)``, as returned by
        ``get_thresholds``; iterated in insertion order.
    :return: numpy array with one entry per threshold: 2 (unsure) when
        ``left < svm_distance_sum < right``, otherwise ``y_pred``.
    """
    return np.array([
        2 if bounds[1] < svm_distance_sum < bounds[2] else y_pred
        for bounds in thresholds.values()
    ])

Settings

In [34]:
# Adjust these settings
# The module is imported under a private alias because the assignment below rebinds the
# global name `feature_type` (used by later cells) from the module to an enum member;
# without the alias, re-running this cell fails with an AttributeError.
from NET_CUP.data_loading import feature_type as _feature_type_module

feature_type = _feature_type_module.FeatureType.RETCCL
# LogisticRegression is also imported above, presumably as an alternative classifier;
# the training loop only needs fit / predict / decision_function.
classifier = SVC()

# Keep these settings
pca = PCA(0.95)  # float n_components: keep enough components to explain 95 % of the variance
patches_per_patient = 100
patch_size = 4096
border_patches = True

Load data

In [35]:
# Build the patient tree and discard slides for which no features were extracted
data = data_tree.create_tree(datasources_config.PATIENTS_PATH, datasources_config.ENUMBER_PATH)
data_tree.drop_slides_without_extracted_features(data, feature_type, datasources_config.UKE_DATASET_DIR)

# Patient groups by origin code: 'p' -> pancreas, 'i' -> small intestine, 'o' -> others
pancreas_patients, intestine_patients, others_patients = (
    data_tree.get_patients(data, origin_code) for origin_code in ('p', 'i', 'o')
)

all_patients = pancreas_patients + intestine_patients + others_patients

Training and testing

In [36]:
# Leave-one-patient-out (LOO) evaluation. For every held-out patient:
#   * 60 % of the remaining pancreas/small-intestine patients train the patch-level classifier,
#   * the other 40 % plus all remaining "others" patients are used to choose the "unsure"
#     thresholds on the patient-level sum of patch decision values (see get_thresholds),
#   * the held-out patient is then classified once per threshold setting.
n_patients = len(all_patients)
n_threshold_columns = len(others_patients) + 1

y_true_patient_level_complete = np.empty(n_patients, dtype=int)
# Each row corresponds to one patient and contains the predictions for the different confidence thresholds
y_pred_patient_level_complete = np.empty((n_patients, n_threshold_columns), dtype=int)

for i, loo_patient in tqdm(enumerate(all_patients), total=n_patients):
    loo_origin = loo_patient.origin.value

    # Training pools without the LOO patient. They are built as new lists instead of
    # remove()/append() on the global lists, so an exception or a KeyboardInterrupt in
    # the middle of the loop cannot leave a patient missing from those lists.
    train_pancreas = [p for p in pancreas_patients if p is not loo_patient]
    train_intestine = [p for p in intestine_patients if p is not loo_patient]
    train_others = [p for p in others_patients if p is not loo_patient]
    pi_patients = train_pancreas + train_intestine

    # svm_patients are used to train the classifier
    # dist_patients are used for calculating the patient level sum of patch distances used for determining the thresholds
    svm_patients, dist_pi_patients = train_test_split(pi_patients,
                                                      train_size=0.6,
                                                      stratify=xyp.get_patient_level_y(pi_patients))
    dist_patients = dist_pi_patients + train_others

    # Train the patch level classifier (PCA is re-fitted on this fold's training patches only)
    X_svm_patch_level, y_svm_patch_level, _ = xyp.get_patch_level_Xyp_complete(svm_patients, patches_per_patient, feature_type, patch_size, border_patches, datasources_config.UKE_DATASET_DIR)
    X_svm_patch_level = pca.fit_transform(X_svm_patch_level)
    classifier.fit(X_svm_patch_level, y_svm_patch_level)

    # Calculate the sum of patch distances for every patient in dist_patients
    sum_patch_distances_complete = []
    y_dist_true_patient_level_complete = []
    for dist_patient in dist_patients:
        X_dist_patch_level, _, _ = xyp.get_patch_level_Xyp_complete([dist_patient], patches_per_patient, feature_type, patch_size, border_patches, datasources_config.UKE_DATASET_DIR)
        y_dist_true_patient_level_complete.append(dist_patient.origin.value)
        sum_patch_distances_complete.append(sum(classifier.decision_function(pca.transform(X_dist_patch_level))))

    # Determine thresholds based on the calculated sum of patch distances for every patient
    df = pd.DataFrame({'y_true': y_dist_true_patient_level_complete, 'sum_patch_distances': sum_patch_distances_complete})
    thresholds = get_thresholds(df)

    # Classify the LOO patient: majority vote over its patches, then one prediction per threshold
    X_test_patch_level, _, _ = xyp.get_patch_level_Xyp_complete([loo_patient], patches_per_patient, feature_type, patch_size, border_patches, datasources_config.UKE_DATASET_DIR)
    X_test_transformed = pca.transform(X_test_patch_level)
    y_pred_patient_level = np.bincount(classifier.predict(X_test_transformed).astype(int)).argmax()
    distance = sum(classifier.decision_function(X_test_transformed))
    preds = list(get_predictions(y_pred_patient_level, distance, thresholds))
    if loo_origin == 2:
        # With an "others" patient held out, one "others" patient fewer is available for the
        # thresholds, so get_thresholds returns one entry fewer; duplicating the first non-zero
        # level aligns the row with the other patients' columns.
        preds = preds[:1] + [preds[1]] + preds[1:]

    y_true_patient_level_complete[i] = loo_origin
    y_pred_patient_level_complete[i, :] = np.array(preds)


99it [19:27, 11.79s/it]


Visualization

In [None]:
# One row per confidence level (column of y_pred_patient_level_complete). Column layout:
#   0    x value: level index / number of "others" patients (fraction of others labeled unsure)
#   1-3  true others    -> fraction predicted pancreas / small intestine / unsure
#   4-6  true pancreas  -> fraction predicted pancreas / small intestine / unsure
#   7-9  true intestine -> fraction predicted pancreas / small intestine / unsure
n_confidence_levels = y_pred_patient_level_complete.shape[1]
results = np.empty((n_confidence_levels, 10))
for i in range(n_confidence_levels):
    # labels=[0, 1, 2] keeps the matrix 3x3 even if a class is absent from y_true and y_pred
    cf_matrix = confusion_matrix(y_true_patient_level_complete, y_pred_patient_level_complete[:, i], labels=[0, 1, 2])
    row_fractions = cf_matrix / cf_matrix.sum(axis=1, keepdims=True)
    # Derived from the data instead of assuming a fixed number of "others" patients
    results[i, 0] = i / max(n_confidence_levels - 1, 1)
    results[i, 1:4] = row_fractions[2]
    results[i, 4:7] = row_fractions[0]
    results[i, 7:10] = row_fractions[1]

# Extracting data from the 'results' array
x_values = results[:, 0]

# Create subplots with 3 rows and 1 column
fig = sp.make_subplots(rows=3, cols=1, shared_xaxes=True, subplot_titles=['Pancreas', 'Small Intestine', 'Others'], vertical_spacing=0.18)

# Scatter plot for "pancreas" category
fig.add_trace(go.Scatter(x=x_values, y=results[:, 4], mode='lines', line={'color': '#6C8EBF', 'dash': 'solid'},
                         name='Identified as "pancreas"', legendgroup='pancreas', legendgrouptitle_text='Proportion of "pancreas"'), row=1, col=1)
fig.add_trace(go.Scatter(x=x_values, y=results[:, 5], mode='lines', line={'color': '#6C8EBF', 'dash': 'dash'},
                         name='Misclassified as "small intestine"', legendgroup='pancreas'), row=1, col=1)
fig.add_trace(go.Scatter(x=x_values, y=results[:, 6], mode='lines', line={'color': '#6C8EBF', 'dash': 'dashdot'},
                         name='Labeled as "unsure"', legendgroup='pancreas'), row=1, col=1)

# Scatter plot for "small intestine" category
fig.add_trace(go.Scatter(x=x_values, y=results[:, 8], mode='lines', line={'color': '#D79B01', 'dash': 'solid'},
                         name='Identified as "small intestine"', legendgroup='si', legendgrouptitle_text='Proportion of "small intestine"'), row=2, col=1)
fig.add_trace(go.Scatter(x=x_values, y=results[:, 7], mode='lines', line={'color': '#D79B01', 'dash': 'dash'},
                         name='Misclassified as "pancreas"', legendgroup='si'), row=2, col=1)
fig.add_trace(go.Scatter(x=x_values, y=results[:, 9], mode='lines', line={'color': '#D79B01', 'dash': 'dashdot'},
                         name='Labeled as "unsure"', legendgroup='si'), row=2, col=1)

# Scatter plot for "others" category
fig.add_trace(go.Scatter(x=x_values, y=results[:, 3], mode='lines', line={'color': '#81B366', 'dash': 'solid'},
                         name='Labeled as "unsure"', legendgroup='others', legendgrouptitle_text='Proportion of "others"'), row=3, col=1)
fig.add_trace(go.Scatter(x=x_values, y=results[:, 1], mode='lines', line={'color': '#81B366', 'dash': 'dash'},
                         name='Misclassified as "pancreas"', legendgroup='others'), row=3, col=1)
fig.add_trace(go.Scatter(x=x_values, y=results[:, 2], mode='lines', line={'color': '#81B366', 'dash': 'dashdot'},
                         name='Misclassified as "small intestine"', legendgroup='others'), row=3, col=1)


# Update axis properties and layout
fig.update_xaxes(showline=True, linecolor='black', gridcolor='lightgrey')
fig.update_yaxes(showline=True, linecolor='black', tickvals=[0, 0.2, 0.4, 0.6, 0.8, 1], range=[0, 1.05], gridcolor='lightgrey')
fig.update_layout(showlegend=True, height=702, width=902, plot_bgcolor='white', margin=dict(l=100,r=100,b=70,t=70), font=dict(color='black'),
                  xaxis_showticklabels=True, xaxis2_showticklabels=True,
                  xaxis_title="Confidence", xaxis2_title="Confidence", xaxis3_title="Confidence",
                  yaxis_title="Proportion", yaxis2_title="Proportion", yaxis3_title="Proportion")

# Panel letters A/B/C in the top-left corner of each subplot
for row, panel_letter in enumerate(['A', 'B', 'C'], start=1):
    fig.add_annotation(xref='x domain',
                       yref='y domain',
                       x=-0.13,
                       y=1.3,
                       text=panel_letter,
                       font=dict(size=30),
                       showarrow=False,
                       row=row, col=1)

fig.show()

In [42]:
titles = {0: "Pancreas vs. (Small Intestine & Others)",
          1: "Small intestine vs. (Pancreas & Others)"}

for origin in range(2):
    # One-vs-rest per confidence level: patients of class `origin` are positives; any other
    # prediction (including 2 = unsure) counts as a negative prediction.
    is_positive = y_true_patient_level_complete == origin          # one entry per patient
    predicted_positive = y_pred_patient_level_complete == origin    # patients x confidence levels

    true_pos = np.count_nonzero(predicted_positive & is_positive[:, None], axis=0)
    false_pos = np.count_nonzero(predicted_positive & ~is_positive[:, None], axis=0)
    n_pos = np.count_nonzero(is_positive)          # = true_pos + false_neg
    n_neg = is_positive.size - n_pos               # = false_pos + true_neg

    # Compute rates while avoiding division by zero (NaN when a group is empty)
    true_pos_rates = true_pos / n_pos if n_pos else np.full(true_pos.shape, np.nan)
    false_pos_rates = false_pos / n_neg if n_neg else np.full(false_pos.shape, np.nan)

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=false_pos_rates,
        y=true_pos_rates,
        mode='markers',
        line=dict(color='blue', width=3),
        marker=dict(size=10, color='blue', symbol='circle')
    ))

    fig.update_layout(
        title=dict(
            text=titles[origin],
            x=0.5,
            xanchor='center',
            font=dict(size=24, color='black')
        ),
        xaxis=dict(
            title="False Positive Rate",
            range=[-0.01, 1.02],
            gridcolor='lightgray',
            title_font=dict(size=18),
            tickfont=dict(size=14)
        ),
        yaxis=dict(
            title="True Positive Rate",
            range=[-0.01, 1.02],
            gridcolor='lightgray',
            title_font=dict(size=18),
            tickfont=dict(size=14)
        ),
        legend=dict(
            x=0.7, y=0.2,
            bgcolor='rgba(255, 255, 255, 0.7)',
            bordercolor='black',
            borderwidth=1,
            font=dict(size=14)
        ),
        template='plotly_white',
        width=800,
        height=600
    )

    fig.update_xaxes(showgrid=True, zeroline=True, zerolinewidth=1, zerolinecolor='lightgray')
    fig.update_yaxes(showgrid=True, zeroline=True, zerolinewidth=1, zerolinecolor='lightgray')

    # NOTE(review): `dpi` is probably ignored by most plotly renderers — confirm
    fig.show(dpi=500)
