# Analyze repertoire stats model performance on validation set

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Union

%matplotlib inline
import seaborn as sns
import genetools
from IPython.display import display, Markdown

In [2]:
from malid import config, helpers, logger
import crosseval
from malid.trained_model_wrappers import RepertoireClassifier

# Analyze

In [3]:
# For every (gene locus, classification target) combination enabled in config:
# load the trained repertoire-stats models' cross-validation outputs, export
# per-model classification reports / confusion matrices, write and display a
# model comparison table, and plot per-fold feature importance variability.
for gene_locus in config.gene_loci_used:
    for target_obs_col in config.classification_targets:
        try:
            target_obs_col.confirm_compatibility_with_gene_locus(gene_locus)
            target_obs_col.confirm_compatibility_with_cross_validation_split_strategy(
                config.cross_validation_split_strategy
            )
        except Exception as err:
            # Skip invalid combinations
            logger.warning(f"{err}. Skipping.")
            continue

        # Where the trained models were written (input to this notebook).
        models_base_dir = RepertoireClassifier._get_model_base_dir(
            gene_locus=gene_locus,
            target_obs_column=target_obs_col,
            sample_weight_strategy=config.sample_weight_strategy,
        )  # should already exist

        # Where this notebook writes its reports and figures.
        output_base_dir = RepertoireClassifier._get_output_base_dir(
            gene_locus=gene_locus,
            target_obs_column=target_obs_col,
            sample_weight_strategy=config.sample_weight_strategy,
        )  # might not yet exist
        output_base_dir.mkdir(parents=True, exist_ok=True)  # create if needed

        model_output_prefix = models_base_dir / "train_smaller_model"
        results_output_prefix = output_base_dir / "train_smaller_model"

        # One failing combination is logged (with traceback) and must not abort
        # the remaining combinations, hence the broad except at the bottom.
        try:
            logger.info(
                f"{gene_locus}, {target_obs_col} from {model_output_prefix} to {results_output_prefix}"
            )

            ## Load and summarize
            experiment_set = crosseval.ExperimentSet.load_from_disk(
                output_prefix=model_output_prefix
            )

            # Remove global fold (we trained global fold model, but now get evaluation scores on cross-validation folds only)
            # TODO: make kdict support: del self.model_outputs[:, fold_id]
            # Snapshot the keys before deleting so we never mutate the container
            # while iterating over something derived from it.
            # NOTE(review): whether the `[:, -1]` slice is a copy or a view is not
            # visible from here; the snapshot is safe in both cases.
            global_fold_keys = list(experiment_set.model_outputs[:, -1].keys())
            for key in global_fold_keys:
                logger.debug(f"Removing {key} (global fold)")
                del experiment_set.model_outputs[key]

            experiment_set_global_performance = experiment_set.summarize()
            # The lambdas are invoked inside export_all_models during this same
            # iteration, so capturing results_output_prefix by closure is safe.
            experiment_set_global_performance.export_all_models(
                func_generate_classification_report_fname=lambda model_name: f"{results_output_prefix}.classification_report.{model_name}.txt",
                func_generate_confusion_matrix_fname=lambda model_name: f"{results_output_prefix}.confusion_matrix.{model_name}.png",
                dpi=72,
            )
            combined_stats = (
                experiment_set_global_performance.get_model_comparison_stats(sort=True)
            )
            combined_stats.to_csv(
                f"{results_output_prefix}.compare_model_scores.tsv",
                sep="\t",
            )
            display(
                Markdown(
                    f"## {gene_locus}, {target_obs_col} from {model_output_prefix} to {results_output_prefix}"
                )
            )
            display(combined_stats)

            ## Review binary misclassifications: Binary prediction vs ground truth
            # For binary case, make new confusion matrix of actual disease label (y) vs predicted y_binary
            # (But this changes global score metrics)
            if (
                target_obs_col.value.is_target_binary_for_repertoire_composition_classifier
            ):
                # this is a binary healthy/sick classifier
                # re-summarize with different ground truth label
                experiment_set.summarize(
                    global_evaluation_column_name=target_obs_col.value.confusion_matrix_expanded_column_name
                ).export_all_models(
                    func_generate_classification_report_fname=lambda model_name: f"{results_output_prefix}.classification_report.{model_name}.binary_vs_ground_truth.txt",
                    func_generate_confusion_matrix_fname=lambda model_name: f"{results_output_prefix}.confusion_matrix.{model_name}.binary_vs_ground_truth.png",
                    confusion_matrix_pred_label="Predicted binary label",
                    dpi=72,
                )

            ## also create the "coefficient variability" plot, over all the CV folds
            for (
                model_name,
                model_global_performance,
            ) in experiment_set_global_performance.model_global_performances.items():
                # get feature importances for each fold
                feature_importances: Union[
                    pd.DataFrame, None
                ] = model_global_performance.feature_importances

                if feature_importances is None:
                    # feature importances are not available for this model
                    continue

                fig = plt.figure(figsize=(9, 9))
                try:
                    # One horizontal box per column of the importances frame,
                    # using absolute values so sign flips don't cancel out.
                    sns.boxplot(data=feature_importances.abs(), orient="h")
                    plt.title(
                        f"Feature importance (absolute value) variability: {model_name}"
                    )
                    plt.tight_layout()
                    genetools.plots.savefig(
                        fig,
                        f"{results_output_prefix}.feature_importances.{model_name}.png",
                        dpi=72,
                    )
                finally:
                    # Always release the figure, even if plotting or saving
                    # failed; otherwise it stays registered with pyplot while
                    # the outer loop moves on to the next combination.
                    plt.close(fig)

        except Exception as err:
            logger.exception(f"{gene_locus}, {target_obs_col} failed with error: {err}")

{"message": "GeneLocus.BCR, TargetObsColumnEnum.disease from /home/maxim/code/immune-repertoire-classification/data/data_v_20231027/in_house_peak_disease_timepoints/repertoire_stats/BCR/disease/train_smaller_model to /home/maxim/code/immune-repertoire-classification/out/in_house_peak_disease_timepoints/repertoire_stats/BCR/disease/train_smaller_model", "time": "2023-11-04T01:56:10.061205"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:19.672589"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:19.717026"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:19.741446"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:19.753578"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:19.766884"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:19.790708"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:19.814702"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:19.826922"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:19.840173"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:19.864502"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:19.888318"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:19.900710"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:20.214594"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:20.239465"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:20.263599"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:20.275928"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:20.289634"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:20.313601"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:20.337222"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:20.349448"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:20.362823"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:20.386788"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:20.410348"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:20.422432"}


## GeneLocus.BCR, TargetObsColumnEnum.disease from /home/maxim/code/immune-repertoire-classification/data/data_v_20231027/in_house_peak_disease_timepoints/repertoire_stats/BCR/disease/train_smaller_model to /home/maxim/code/immune-repertoire-classification/out/in_house_peak_disease_timepoints/repertoire_stats/BCR/disease/train_smaller_model

Unnamed: 0,ROC-AUC (weighted OvO) per fold,ROC-AUC (macro OvO) per fold,au-PRC (weighted OvO) per fold,au-PRC (macro OvO) per fold,Accuracy per fold,MCC per fold,Accuracy global,MCC global,sample_size,n_abstentions,sample_size including abstentions,abstention_rate,missing_classes
elasticnet_cv0.25,0.929 +/- 0.003 (in 3 folds),0.935 +/- 0.004 (in 3 folds),0.930 +/- 0.006 (in 3 folds),0.938 +/- 0.009 (in 3 folds),0.747 +/- 0.051 (in 3 folds),0.678 +/- 0.062 (in 3 folds),0.746,0.677,410,0,410,0.0,False
elasticnet_cv,0.929 +/- 0.003 (in 3 folds),0.935 +/- 0.004 (in 3 folds),0.930 +/- 0.006 (in 3 folds),0.938 +/- 0.007 (in 3 folds),0.739 +/- 0.046 (in 3 folds),0.669 +/- 0.056 (in 3 folds),0.739,0.668,410,0,410,0.0,False
elasticnet_sklearn_with_lambdamax,0.929 +/- 0.003 (in 3 folds),0.935 +/- 0.004 (in 3 folds),0.930 +/- 0.005 (in 3 folds),0.938 +/- 0.007 (in 3 folds),0.739 +/- 0.046 (in 3 folds),0.669 +/- 0.056 (in 3 folds),0.739,0.668,410,0,410,0.0,False
elasticnet0.25_sklearn_with_lambdamax,0.929 +/- 0.003 (in 3 folds),0.935 +/- 0.004 (in 3 folds),0.930 +/- 0.007 (in 3 folds),0.938 +/- 0.009 (in 3 folds),0.747 +/- 0.051 (in 3 folds),0.678 +/- 0.062 (in 3 folds),0.746,0.677,410,0,410,0.0,False
elasticnet0.75_sklearn_with_lambdamax,0.929 +/- 0.001 (in 3 folds),0.936 +/- 0.002 (in 3 folds),0.931 +/- 0.004 (in 3 folds),0.939 +/- 0.006 (in 3 folds),0.739 +/- 0.049 (in 3 folds),0.670 +/- 0.059 (in 3 folds),0.739,0.669,410,0,410,0.0,False
elasticnet_cv0.75,0.928 +/- 0.001 (in 3 folds),0.936 +/- 0.002 (in 3 folds),0.930 +/- 0.004 (in 3 folds),0.938 +/- 0.006 (in 3 folds),0.739 +/- 0.049 (in 3 folds),0.670 +/- 0.059 (in 3 folds),0.739,0.669,410,0,410,0.0,False
lasso_cv,0.927 +/- 0.002 (in 3 folds),0.934 +/- 0.001 (in 3 folds),0.927 +/- 0.004 (in 3 folds),0.935 +/- 0.005 (in 3 folds),0.720 +/- 0.035 (in 3 folds),0.645 +/- 0.042 (in 3 folds),0.72,0.645,410,0,410,0.0,False
lasso_sklearn_with_lambdamax,0.927 +/- 0.002 (in 3 folds),0.934 +/- 0.002 (in 3 folds),0.928 +/- 0.004 (in 3 folds),0.936 +/- 0.005 (in 3 folds),0.722 +/- 0.031 (in 3 folds),0.648 +/- 0.037 (in 3 folds),0.722,0.648,410,0,410,0.0,False
ridge_sklearn_with_lambdamax,0.926 +/- 0.002 (in 3 folds),0.934 +/- 0.004 (in 3 folds),0.929 +/- 0.006 (in 3 folds),0.936 +/- 0.010 (in 3 folds),0.744 +/- 0.036 (in 3 folds),0.676 +/- 0.045 (in 3 folds),0.744,0.676,410,0,410,0.0,False
ridge_cv,0.926 +/- 0.002 (in 3 folds),0.934 +/- 0.004 (in 3 folds),0.929 +/- 0.006 (in 3 folds),0.937 +/- 0.010 (in 3 folds),0.744 +/- 0.036 (in 3 folds),0.676 +/- 0.045 (in 3 folds),0.744,0.676,410,0,410,0.0,False


{"message": "GeneLocus.TCR, TargetObsColumnEnum.disease from /home/maxim/code/immune-repertoire-classification/data/data_v_20231027/in_house_peak_disease_timepoints/repertoire_stats/TCR/disease/train_smaller_model to /home/maxim/code/immune-repertoire-classification/out/in_house_peak_disease_timepoints/repertoire_stats/TCR/disease/train_smaller_model", "time": "2023-11-04T01:56:24.171401"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:37.728914"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:37.814801"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:37.838733"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:37.851128"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:37.864616"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:37.889207"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:37.913450"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:37.925602"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:37.939011"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:37.963211"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:37.986783"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:37.998967"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:38.239082"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:38.263429"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:38.287179"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:38.299412"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:38.312493"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:38.336043"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:38.359249"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:38.371084"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:38.384085"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:38.407660"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:38.430668"}


{"message": "Inserting phantom class: Healthy/Background", "time": "2023-11-04T01:56:38.442728"}


## GeneLocus.TCR, TargetObsColumnEnum.disease from /home/maxim/code/immune-repertoire-classification/data/data_v_20231027/in_house_peak_disease_timepoints/repertoire_stats/TCR/disease/train_smaller_model to /home/maxim/code/immune-repertoire-classification/out/in_house_peak_disease_timepoints/repertoire_stats/TCR/disease/train_smaller_model

Unnamed: 0,ROC-AUC (weighted OvO) per fold,ROC-AUC (macro OvO) per fold,au-PRC (weighted OvO) per fold,au-PRC (macro OvO) per fold,Accuracy per fold,MCC per fold,Accuracy global,MCC global,sample_size,n_abstentions,sample_size including abstentions,abstention_rate,missing_classes
lasso_cv,0.941 +/- 0.014 (in 3 folds),0.942 +/- 0.017 (in 3 folds),0.933 +/- 0.014 (in 3 folds),0.935 +/- 0.016 (in 3 folds),0.733 +/- 0.036 (in 3 folds),0.667 +/- 0.053 (in 3 folds),0.733,0.665,367,0,367,0.0,False
lasso_sklearn_with_lambdamax,0.941 +/- 0.014 (in 3 folds),0.942 +/- 0.017 (in 3 folds),0.933 +/- 0.014 (in 3 folds),0.935 +/- 0.016 (in 3 folds),0.733 +/- 0.036 (in 3 folds),0.667 +/- 0.053 (in 3 folds),0.733,0.665,367,0,367,0.0,False
elasticnet_cv0.75,0.941 +/- 0.014 (in 3 folds),0.941 +/- 0.017 (in 3 folds),0.933 +/- 0.014 (in 3 folds),0.935 +/- 0.016 (in 3 folds),0.733 +/- 0.039 (in 3 folds),0.665 +/- 0.059 (in 3 folds),0.733,0.663,367,0,367,0.0,False
elasticnet0.75_sklearn_with_lambdamax,0.941 +/- 0.014 (in 3 folds),0.941 +/- 0.017 (in 3 folds),0.933 +/- 0.014 (in 3 folds),0.935 +/- 0.016 (in 3 folds),0.733 +/- 0.039 (in 3 folds),0.665 +/- 0.059 (in 3 folds),0.733,0.663,367,0,367,0.0,False
elasticnet_cv_lambda1se,0.941 +/- 0.013 (in 3 folds),0.942 +/- 0.016 (in 3 folds),0.932 +/- 0.012 (in 3 folds),0.936 +/- 0.013 (in 3 folds),0.722 +/- 0.019 (in 3 folds),0.654 +/- 0.033 (in 3 folds),0.722,0.652,367,0,367,0.0,False
elasticnet_cv0.25_lambda1se,0.941 +/- 0.012 (in 3 folds),0.942 +/- 0.015 (in 3 folds),0.933 +/- 0.011 (in 3 folds),0.937 +/- 0.012 (in 3 folds),0.717 +/- 0.019 (in 3 folds),0.648 +/- 0.036 (in 3 folds),0.717,0.646,367,0,367,0.0,False
elasticnet_cv0.25,0.940 +/- 0.014 (in 3 folds),0.940 +/- 0.017 (in 3 folds),0.933 +/- 0.014 (in 3 folds),0.935 +/- 0.016 (in 3 folds),0.733 +/- 0.039 (in 3 folds),0.666 +/- 0.059 (in 3 folds),0.733,0.664,367,0,367,0.0,False
elasticnet_cv,0.940 +/- 0.014 (in 3 folds),0.941 +/- 0.017 (in 3 folds),0.932 +/- 0.013 (in 3 folds),0.935 +/- 0.015 (in 3 folds),0.736 +/- 0.039 (in 3 folds),0.669 +/- 0.059 (in 3 folds),0.736,0.667,367,0,367,0.0,False
elasticnet_sklearn_with_lambdamax,0.940 +/- 0.014 (in 3 folds),0.941 +/- 0.017 (in 3 folds),0.932 +/- 0.013 (in 3 folds),0.935 +/- 0.015 (in 3 folds),0.736 +/- 0.039 (in 3 folds),0.669 +/- 0.059 (in 3 folds),0.736,0.667,367,0,367,0.0,False
elasticnet0.25_sklearn_with_lambdamax,0.940 +/- 0.014 (in 3 folds),0.940 +/- 0.017 (in 3 folds),0.933 +/- 0.014 (in 3 folds),0.936 +/- 0.015 (in 3 folds),0.736 +/- 0.039 (in 3 folds),0.669 +/- 0.059 (in 3 folds),0.736,0.667,367,0,367,0.0,False
