# Visualize model performance

In [None]:
import numpy as np
import pandas as pd
import anndata as adata

from tqdm import tqdm
from pathlib import Path
from typing import Tuple, List, Dict

import plotly.graph_objects as go
import plotly as plotly
import plotly.express as px
import plotly.figure_factory as ff
from plotly.subplots import make_subplots

from sklearn.metrics import mean_squared_error, confusion_matrix
from sklearn.metrics.pairwise import cosine_similarity as skl_cosine

from scipy.stats import pearsonr
from scipy.spatial.distance import cosine as scipy_cosine
from scipy.spatial.distance import braycurtis, cdist
from math import sqrt

%load_ext blackcellmagic

In [None]:
# Root folder where figures and their source data are written
viz_prefix = "???/deconvolution_benchmarking/visualizations"

# Root folder of the experiment whose results are plotted in this notebook
prefix = "???/deconvolution_benchmarking/01_purity_levels_experiment/include_normal_epithelial"

# Tumour purity levels 0.05, 0.10, ..., 0.95; rounding to 3 decimals strips the
# floating-point noise that np.arange accumulates (e.g. 0.35000000000000003).
purity_levels = np.round(np.arange(0.05, 1, 0.05), 3).tolist()

# Major cell types, in the column order used throughout the notebook
c_types = [
    "Cancer Epithelial",
    "Normal Epithelial",
    "T-cells",
    "B-cells",
    "Myeloid",
    "CAFs",
    "Endothelial",
    "PVL",
    "Plasmablasts",
]

### Load ground truth

In [None]:
# Ground-truth proportions (tab-separated, first column is the mixture index),
# restricted to the major cell types.
truth_df = pd.read_csv(
    Path(prefix) / "data" / "results" / "truth.csv", sep="\t", index_col=0
)[c_types]

# Reproducible 5% sample of mixtures; each mixture's purity level is its
# ground-truth "Cancer Epithelial" proportion.
truth_copy_df = truth_df.sample(frac=0.05, random_state=41).copy()
truth_copy_df = truth_copy_df.assign(purity_level=truth_copy_df["Cancer Epithelial"])

# Long format: one row per (mixture_id, cell_type) holding its true proportion
pivot_truth_df = (
    truth_copy_df.reset_index()
    .melt(
        id_vars=["index", "purity_level"],
        value_vars=c_types,
        var_name="cell_type",
        value_name="truth",
    )
    .rename(columns={"index": "mixture_id"})
    .set_index(["mixture_id", "cell_type"])
)

### Extract colour palette

In [None]:
# Load the per-cell-type colour palette (tab-separated)
ctype_colour_pallete_df = pd.read_csv(
    Path(prefix) / "data" / "Whole_miniatlas_colour_pallete.csv", sep="\t"
)

# Map each cell type to its fill and line colours; if a cell type appears more
# than once, the last row wins.
ctype_colour_pallete_d = {
    celltype: {"fill": fill_colour, "line": line_colour}
    for celltype, fill_colour, line_colour in zip(
        ctype_colour_pallete_df["all_celltype"],
        ctype_colour_pallete_df["fill"],
        ctype_colour_pallete_df["line"],
    )
}

## [Fig]. Box plot of CPM across purity levels

In [None]:
# Read predictions for CPM
method = "cpm"

res_df = pd.read_csv(
    Path(prefix).joinpath(f"data/results/{method}.csv"), sep="\t", index_col=0
)
# Keep the major cell types and clamp negative predicted proportions to zero.
# `clip` returns a new frame, unlike masked assignment on a column subset,
# which raises SettingWithCopyWarning.
res_df = res_df[c_types].clip(lower=0)

# Clone truth_df and tag each mixture with its tumour purity, i.e. its
# ground-truth "Cancer Epithelial" proportion.
truth_copy_df = truth_df.copy()
truth_copy_df["Purity Level"] = truth_copy_df["Cancer Epithelial"]

# Pivot ground truth and predictions to long format
# (one row per mixture x cell type)
pivot_truth_df = (
    truth_copy_df.reset_index()
    .melt(id_vars=["index", "Purity Level"])
    .rename(
        columns={
            "index": "Mixture ID",
            "variable": "Cell Type",
            "value": "Actual Proportion",
        }
    )
)
pivot_res_df = (
    res_df.reset_index()
    .melt(id_vars=["index"])
    .rename(
        columns={
            "index": "Mixture ID",
            "variable": "Cell Type",
            "value": "Predicted Proportion",
        }
    )
)

# Merge pivoted predictions and ground truth
pivot_truth_res_df = pivot_res_df.merge(
    pivot_truth_df, on=["Mixture ID", "Cell Type"], how="outer"
)
pivot_truth_res_df.set_index(["Mixture ID"], inplace=True)

# Keep only a subset of tumour purity levels so the plot stays readable.
# Round before comparing: exact float matching would silently drop mixtures
# whose stored purity carries float noise (e.g. 0.35000000000000003). Rows with
# a NaN purity (no ground truth after the outer merge) are dropped as before.
reduced_purity_levels = np.arange(0.05, 1, 0.15).round(3).tolist()
keep_rows = pivot_truth_res_df["Purity Level"].round(3).isin(reduced_purity_levels)

# `.copy()` detaches the filtered frame so the column assignments below do not
# raise SettingWithCopyWarning.
pivot_truth_res_df = pivot_truth_res_df[keep_rows].copy()

# Express actual and predicted proportions as percentages
for proportion_col in ("Predicted Proportion", "Actual Proportion"):
    pivot_truth_res_df[proportion_col] = pivot_truth_res_df[proportion_col] * 100

In [None]:
# Save the plotted data as source data for Supplementary Figure 5.
# DataFrame.to_csv does not create missing directories, so make sure the
# output folder exists first.
source_data_path = Path(viz_prefix).joinpath("source_data/supp_figure_5.tsv")
source_data_path.parent.mkdir(parents=True, exist_ok=True)
pivot_truth_res_df.to_csv(source_data_path, sep="\t")

In [None]:
# Plot purity on a percentage scale so the x tick labels match the "(%)" axis
# title. The y values were already converted to percentages. Rounding removes
# float noise such as 35.00000000000001 from tick and hover labels.
plot_df = pivot_truth_res_df.assign(
    **{"Purity Level": (pivot_truth_res_df["Purity Level"] * 100).round(1)}
)

fig = px.box(
    plot_df,
    x="Purity Level",
    y="Predicted Proportion",
    color="Cell Type",
    color_discrete_map={
        ctype: colours["line"] for ctype, colours in ctype_colour_pallete_d.items()
    },
)

# Small outlier markers and thin box outlines for a compact figure
fig.update_traces(marker=dict(size=2.5), line=dict(width=0.8))

fig.update_yaxes(
    title="Predicted proportion (%)",
    linecolor="black",
    linewidth=0.8,
    ticks="outside",
    ticklen=3,
    tickfont_size=9,
    title_font_size=10,
    title_standoff=5,
    showgrid=True,
    gridwidth=0.75,
    gridcolor="lightgray",
    range=[0, 26],
    dtick=5,
)
fig.update_xaxes(
    title="Tumour purity levels (%)",
    linecolor="black",
    linewidth=0.8,
    ticks="outside",
    ticklen=3,
    tickfont_size=9,
    title_font_size=10,
    title_standoff=5,
    # One tick under each group of boxes
    tickvals=sorted(plot_df["Purity Level"].dropna().unique().tolist()),
)

# Tight margins, transparent background, no legend
fig.update_layout(
    margin=dict(t=0, l=0, r=0, b=0),
    plot_bgcolor="rgba(0,0,0,0)",
    showlegend=False,
    font=dict(size=9, color="black"),
)

figure_path = Path(viz_prefix).joinpath("figures/supp_figures/supp_fig_5.svg")
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(figure_path, height=300, width=600, scale=5)