# Visualize model performance

In [None]:
import numpy as np
import pandas as pd
import anndata as adata

from tqdm import tqdm
from pathlib import Path
from typing import Tuple, List, Dict

import plotly.graph_objects as go
import plotly as plotly
import plotly.express as px
import plotly.figure_factory as ff
from plotly.subplots import make_subplots

from sklearn.metrics import mean_squared_error
from sklearn.metrics.pairwise import cosine_similarity as skl_cosine

from scipy.stats import pearsonr
from scipy.spatial.distance import cosine as scipy_cosine
from math import sqrt

%load_ext blackcellmagic

In [None]:
# Root folder that receives the exported figures and their source data
viz_prefix = "???/deconvolution_benchmarking/visualizations"

# Root folder of the experiment being plotted
prefix = "???/deconvolution_benchmarking/02_normal_epithelial_lineages_experiment"

# Tumour purity levels: 0.45 up to (excluding) 0.65 in steps of 0.05,
# rounded to 3 decimals to strip floating-point noise
purity_levels = np.round(np.arange(0.45, 0.65, 0.05), 3).tolist()

# Cell types, in the column order used for the results tables
c_types = [
    "Endothelial",
    "CAFs",
    "PVL",
    "B-cells",
    "T-cells",
    "Myeloid",
    "Plasmablasts",
    # Cancer epithelium followed by the three normal epithelial lineages
    "Cancer Epithelial",
    "Luminal Progenitors",
    "Mature Luminal",
    "Myoepithelial",
]

# Molecular subtypes and the colour assigned to each of them
mol_subtypes = ["HER2+", "TNBC", "ER+"]
mol_subtypes_color_map = dict(
    [
        ("HER2+", "#E07F80"),  # red
        ("ER+", "#95AAD3"),  # blue
        ("TNBC", "#94C47D"),  # green
    ]
)

# Display order of the methods, shared by every figure
methods_order = ["BayesPrism", "Scaden", "MuSiC", "hspe", "DWLS", "CBX", "EPIC"]

## Load groundtruth

In [None]:
# Read the ground-truth table (tab-separated, first column used as the index)
truth_path = Path(prefix) / "data/results/truth.csv"
truth_df = pd.read_csv(truth_path, sep="\t", index_col=0)

# Keep only the cell types listed in c_types, in that order
truth_df = truth_df.loc[:, c_types]

# Preview the first two rows
truth_df.head(2)

## Extract colour palette

In [None]:
# Read the colour table (tab-separated) for the cell types in the data
ctype_colour_pallete_df = pd.read_csv(
    Path(prefix) / "data/Whole_miniatlas_colour_pallete.csv", sep="\t"
)

# Map every cell type to its fill and line colours
ctype_colour_pallete = {
    cell_type: {"fill": fill_colour, "line": line_colour}
    for cell_type, fill_colour, line_colour in zip(
        ctype_colour_pallete_df["all_celltype"],
        ctype_colour_pallete_df["fill"],
        ctype_colour_pallete_df["line"],
    )
}

## Extract clinical subtype metadata

In [None]:
# Collect information about patient subtypes
meta_df = pd.read_csv(
    Path(prefix).joinpath("data/Whole_miniatlas_meta_normal_lineages.csv"),
    index_col=0,
    sep="\t",
)

# One row per unique (Patient, subtype) pair, ordered by how often the pair occurs.
# The count column is discarded by selecting the two key columns by name: its label
# is 0 on pandas < 2.0 and "count" on pandas >= 2.0, so dropping it by label would
# only work on one of the two.
patient_subtype_df = (
    meta_df[["Patient", "subtype"]]
    .value_counts()
    .reset_index()[["Patient", "subtype"]]
)

## [Fig 3b]. Boxplots of differences across top 6 methods for each molecular subtype

In [None]:
# Only the cancer and normal lineages are plotted
plot_c_types = [
    "Cancer Epithelial",
    "Luminal Progenitors",
    "Mature Luminal",
    "Myoepithelial",
]

# Method identifiers, as used in the results file names
methods = ["bprism_v2", "scaden", "cbx", "music", "dwls", "hspe", "epic"]

# Tumour purity is fixed to 50%
pur_lvl = 0.5

In [None]:
# Ground-truth rows at the fixed purity level; identical for every method
subset_truth_df = truth_df[truth_df["Cancer Epithelial"] == pur_lvl].sort_index()

pivot_diff_l = []
for method in tqdm(methods):
    res_df = pd.read_csv(
        Path(prefix) / f"data/results/{method}.csv", sep="\t", index_col=0
    )

    # Predictions for the same samples, sorted like the ground truth
    subset_res_df = res_df[res_df.index.isin(subset_truth_df.index)].sort_index()

    # Signed error (prediction minus truth) per sample and cell type
    diff_df = subset_res_df[subset_truth_df.columns] - subset_truth_df

    # The patient ID is the part of the sample name before the first underscore
    diff_df["Patient"] = [sample.split("_")[0] for sample in diff_df.index]

    # Long format: one row per (sample, cell type) error
    pivot_diff_df = diff_df.melt(
        id_vars=["Patient"],
        value_vars=c_types,
        var_name="Cell Type",
        value_name="Error",
    )
    pivot_diff_df["Method"] = method
    pivot_diff_l.append(pivot_diff_df)

# Stack all methods and express the errors as percentages
boxplot_diff_df = pd.concat(pivot_diff_l, axis=0)
boxplot_diff_df["Error Pct"] = (boxplot_diff_df["Error"].astype(float) * 100).round(4)

# Attach each patient's subtype from the metadata
boxplot_diff_df = boxplot_diff_df.merge(patient_subtype_df, on="Patient", how="inner")

# Swap the method identifiers for their display names
method_display_names = {
    "scaden": "Scaden",
    "dwls": "DWLS",
    "cbx": "CBX",
    "epic": "EPIC",
    "music": "MuSiC",
    "bisque": "Bisque",
    "bprism_v2": "BayesPrism",
}
boxplot_diff_df.replace(method_display_names, inplace=True)

In [None]:
# Save source data, restricted to the cell types shown in the figure
fig_3b_source_df = boxplot_diff_df[boxplot_diff_df["Cell Type"].isin(plot_c_types)]
fig_3b_source_df.to_csv(Path(viz_prefix) / "source_data/figure_3b.tsv", sep="\t")

In [None]:
# Per-cell-type colour look-ups: line colours feed px.box, fill colours the traces
line_colours = {name: shades["line"] for name, shades in ctype_colour_pallete.items()}
fill_colours = {name: shades["fill"] for name, shades in ctype_colour_pallete.items()}

# Notched boxplots: one facet per subtype, boxes coloured by cell type
fig = px.box(
    boxplot_diff_df[boxplot_diff_df["Cell Type"].isin(plot_c_types)],
    x="Method",
    y="Error Pct",
    color="Cell Type",
    color_discrete_map=line_colours,
    facet_col="subtype",
    facet_col_spacing=0.02,
    notched=True,
    category_orders={"Method": methods_order},
)

# Restyle the traces of each cell type (traces are selected by name)
for c_type in c_types:
    fig.update_traces(
        selector={"name": c_type},
        fillcolor=fill_colours[c_type],
        line={"width": 1},
        marker={"size": 2},
        opacity=1,
    )

# Axes styling shared by all facets
fig.update_xaxes(
    title="Method",
    title_font_size=13,
    title_standoff=3,
    side="bottom",
    linecolor="black",
    linewidth=1,
    ticks="outside",
    ticklen=4,
    tickwidth=1,
    tickmode="array",
    tickangle=90,
    tickfont_size=12,
    showticklabels=True,
)
fig.update_yaxes(
    range=[-55, 101],
    dtick=25,
    linecolor="black",
    linewidth=1,
    ticks="outside",
    ticklen=4,
    tickwidth=1,
    tickfont_size=13,
    showgrid=True,
    gridwidth=0.5,
    gridcolor="lightgray",
)

# Only the first column of subplots gets a y-axis title
fig.update_yaxes(
    col=1,
    title="Raw prediction errors (%)",
    title_font_size=12,
    title_standoff=2,
)

# Overall layout: grouped boxes, transparent background, horizontal legend below
fig.layout.update(
    boxmode="group",
    boxgroupgap=0.05,
    plot_bgcolor="rgba(0,0,0,0)",
    font_color="black",
    margin={"t": 20, "l": 0, "r": 0, "b": 0},
    showlegend=True,
    legend={
        "title": "Cell type",
        "title_font_size": 13,
        "font_size": 13,
        "orientation": "h",
        "xanchor": "center",
        "x": 0.5,
        "yanchor": "bottom",
        "y": -0.35,
    },
)

# Zero line
fig.add_hline(y=0, line_dash="dot", line_width=0.3)


def _strip_facet_prefix(annotation):
    """Keep only the text after the last "=" of an annotation and restyle it."""
    annotation.update(text=annotation.text.split("=")[-1], y=1, font_size=13)


# Format annotations
fig.for_each_annotation(_strip_facet_prefix)

# Save as svg
fig.write_image(
    Path(viz_prefix) / "figures/main_figures/main_fig_3b.svg",
    width=950,
    height=525,
    scale=5,
)

## [Fig 3a]. RMSE over tumour purity level across cell type

In [None]:
# Order in which the cell types are listed for the heatmaps
c_types_order = [
    "Plasmablasts",
    "PVL",
    "CAFs",
    "Endothelial",
    "Myeloid",
    "B-cells",
    "T-cells",
    "Myoepithelial",
    "Mature Luminal",
    "Luminal Progenitors",
    "Cancer Epithelial",
]

# The three normal lineages
normal_lineages = ["Luminal Progenitors", "Mature Luminal", "Myoepithelial"]

# Method identifiers, as used in the results file names
methods = ["bprism_v2", "scaden", "music", "cbx", "dwls", "hspe", "epic"]

# Tumour purity is fixed to 50%
pur_lvl = 0.5

In [None]:
def plot_rmse_heatmap(
    avg_diff_df: pd.DataFrame,
    outfile_name: str,
    metric: str,
    metric_suffix: str,
    colorscale: str,
    c_types: List,
    plot_w: int,
    plot_h: int,
    z_range: Tuple[float, float] = (0, 50),
    dticks: int = 10,
    auto_open: bool = True,
) -> None:
    """Plot an annotated heatmap of an error metric and write it to an image file.

    Every cell is coloured by its value and annotated with the value rounded to
    2 decimals. The figure is saved with ``fig.write_image``; nothing is returned.

    Args:
        - avg_diff_df:        DataFrame of metric values; its index is shown on the
                              y-axis and its columns on the x-axis
        - outfile_name:       path of the image file to write
        - metric:             metric name, used as the colour bar title
        - metric_suffix:      suffix appended to the colour bar tick labels
        - colorscale:         plotly colour scale used for the heatmap
        - c_types:            cell types in a specific order we'd like to appear on the y-axis
        - plot_w:             width passed to ``write_image``
        - plot_h:             height passed to ``write_image``
        - z_range:            (min, max) of the colour axis; any two-element sequence
                              works. Maximum error (between Scaden, CBX and EPIC) is ~56%
                              so we only need to set maximum zaxis to 50%.
                              This ensures extreme errors are very red on the scale
        - dticks:             tick spacing of the colour bar
        - auto_open:          currently unused; kept so existing callers keep working

    """
    # The default is an immutable tuple so it cannot be mutated across calls
    z_min, z_max = z_range[0], z_range[1]

    # Create annotated heatmap object with plotly
    fig = ff.create_annotated_heatmap(
        z=avg_diff_df.values,
        # Annotate each cell in the heatmap with the corresponding labels
        annotation_text=avg_diff_df.values.round(2).astype(str),
        zmin=z_min,
        zmax=z_max,
        x=avg_diff_df.columns.tolist(),  # DataFrame columns go on the x-axis
        y=avg_diff_df.index.tolist(),  # DataFrame index goes on the y-axis
        colorscale=colorscale,
        # The colour bar is switched off, so the colorbar settings below only
        # matter if showscale is turned back on
        showscale=False,
        hoverinfo="text",
        text=avg_diff_df.values.round(4),
        colorbar=dict(
            title=metric,
            ticks="outside",
            ticksuffix=metric_suffix,
            dtick=dticks,
            orientation="h",
            ticklen=2,
        ),
    )

    # Update axes
    fig.update_xaxes(
        title="Methods",
        title_font_size=8,
        title_standoff=1,
        ticks="outside",
        tickfont_size=7,
        showticklabels=True,
        ticklen=2,
        tickwidth=0.5,
        tickangle=90,
        linecolor="black",
        linewidth=0.5,
        side="bottom",
    )
    fig.update_yaxes(
        title="Cell Types",
        title_font_size=8,
        title_standoff=1,
        linecolor="black",
        linewidth=0.5,
        categoryorder="array",
        categoryarray=c_types,  # Order cell types by lineages
        ticks="outside",
        showticklabels=True,
        ticklen=2,
        tickwidth=0.5,
        tickfont_size=7,
    )

    # Update layout
    fig["layout"].update(
        margin=dict(t=0, l=0, r=0, b=0),
        font_size=7,
        plot_bgcolor="rgba(0,0,0,0)",
        font_color="black",
    )

    # Write the static image
    fig.write_image(
        outfile_name,
        width=plot_w,
        height=plot_h,
        scale=5,
    )

In [None]:
all_rmse_l = []

# Ground truth at the fixed purity level, annotated with each patient's subtype.
# This does not depend on the method, so it is built once instead of per method.
purity_truth_df = truth_df[truth_df["Cancer Epithelial"] == pur_lvl].sort_index()
# The patient ID is the part of the sample name before the first underscore
purity_truth_df["Patient"] = [i.split("_")[0] for i in purity_truth_df.index]
purity_truth_df = (
    purity_truth_df.reset_index()
    .merge(patient_subtype_df, on="Patient", how="inner")
    .set_index(["index"])
)

for subtype in ["HER2+", "ER+", "TNBC"]:

    # Ground truth of the molecular subtype we're iterating over; Patient and
    # subtype columns are dropped once they've been used for filtering
    subset_truth_df = (
        purity_truth_df[purity_truth_df["subtype"] == subtype]
        .sort_index()
        .drop(["Patient", "subtype"], axis=1)
    )

    rmse_l = []

    for method in tqdm(methods):
        res_df = pd.read_csv(
            Path(prefix).joinpath(f"data/results/{method}.csv"), sep="\t", index_col=0
        )

        # Clip to 0 and sort columns as truth_df
        res_df = res_df.clip(lower=0)[truth_df.columns]

        # Keep only the samples present in this subtype's ground truth
        subset_res_df = res_df[res_df.index.isin(subset_truth_df.index)].sort_index()

        # Sanity check
        assert subset_res_df.index.equals(subset_truth_df.index)
        assert subset_res_df.columns.equals(subset_truth_df.columns)

        # Calculate RMSE per cell type, in percentage points
        method_rmse_df = np.sqrt(
            np.square(subset_res_df * 100 - subset_truth_df * 100).mean(axis=0)
        ).to_frame(name=method)

        # Append to rmse_l
        rmse_l.append(method_rmse_df)

    # Concatenate
    rmse_df = pd.concat(rmse_l, axis=1)

    # Replace method names
    rmse_df.rename(
        columns={
            "scaden": "Scaden",
            "dwls": "DWLS",
            "cbx": "CBX",
            "epic": "EPIC",
            "music": "MuSiC",
            "cpm": "CPM",
            "bprism_v2": "BayesPrism",
        },
        inplace=True,
    )

    # Plot RMSE heatmap. The figure is written under viz_prefix, like the other
    # figures and source data, rather than relative to the working directory.
    plot_rmse_heatmap(
        avg_diff_df=rmse_df[methods_order],  # Reorder methods to methods_order
        outfile_name=str(
            Path(viz_prefix).joinpath(f"figures/main_figures/main_fig_3a_{subtype}.svg")
        ),
        metric="RMSE",
        metric_suffix="",
        c_types=c_types_order,
        z_range=[0, 50],
        dticks=10,
        colorscale="reds",
        plot_w=250,
        plot_h=175,
    )

    # Collect RMSE and method
    rmse_df = rmse_df.reset_index().rename(columns={"index": "Cell type"})
    rmse_df["pam50_subtype"] = subtype
    all_rmse_l.append(rmse_df)

In [None]:
# Save source data: the RMSE tables of all subtypes stacked row-wise
all_rmse_df = pd.concat(all_rmse_l, axis=0)
fig_3a_source_path = Path(viz_prefix) / "source_data/figure_3a.tsv"
all_rmse_df.to_csv(fig_3a_source_path, sep="\t")