# Collate model predictions from purity-level partitions

In [None]:
import numpy as np
import pandas as pd
import anndata as adata

from tqdm import tqdm
from pathlib import Path
from typing import Tuple, List, Dict

import plotly.graph_objects as go
import plotly as plotly
import plotly.express as px
import plotly.figure_factory as ff


%load_ext blackcellmagic

In [None]:
# Prefix to visualizations folder
viz_prefix = "???/deconvolution_benchmarking/visualizations"

# Prefix to the experiment we're plotting
prefix = "???/deconvolution_benchmarking/01_purity_levels_experiment/include_normal_epithelial"

# Tumour purity levels: 0.05, 0.10, ..., 0.95.
# Built on an integer grid and then scaled, because np.arange with a
# fractional step can return a different number of elements depending on
# floating-point round-off.
purity_levels = (np.arange(5, 100, 5) / 100).round(3).tolist()

# Major cell types (this order is reused for column and legend ordering)
c_types = [
    "Cancer Epithelial",
    "Normal Epithelial",
    "T-cells",
    "B-cells",
    "Myeloid",
    "CAFs",
    "Endothelial",
    "PVL",
    "Plasmablasts",
]

In [None]:
# Load the colour palette table (tab-separated despite the .csv extension)
ctype_colour_pallete_df = pd.read_csv(
    Path(prefix) / "data/Whole_miniatlas_colour_pallete.csv", sep="\t"
)

# Map every cell type to its fill and line colours
ctype_colour_pallete_d = {
    cell_type: {"fill": fill_colour, "line": line_colour}
    for cell_type, fill_colour, line_colour in zip(
        ctype_colour_pallete_df["all_celltype"],
        ctype_colour_pallete_df["fill"],
        ctype_colour_pallete_df["line"],
    )
}

In [None]:
# Retrieve subtype information from the metadata table (tab-separated)
meta_df = pd.read_csv(
    Path(prefix) / "data/Whole_miniatlas_meta_9_10.csv", index_col=0, sep="\t"
)


# Build a patient -> subtype lookup table: keep the unique
# (Patient, subtype) pairs and discard the "NAME" column that appears
# once the index is reset
subtype_df = meta_df.loc[:, ["Patient", "subtype"]].drop_duplicates()
subtype_df = subtype_df.reset_index().drop(columns=["NAME"])

### Sparse training mixtures simulation method 

We used the Sparse method to simulate training bulk mixtures. It follows 2 steps: 
- Fix the percentage of one cell type at a value P1 in [0, 1]
- The percentages of the other cell types then sum up to 1 - P1
<br>
This ensures that we have many mixtures in which one cell type has a high percentage <br>
For the very first experiment (major cell types across purity levels), we generated mixtures using both this method and the regular method (used in Menden et al. and Cobos et al.) <br>
This supplementary figure aims to illustrate that the Sparse method produces mixtures with a higher variability of cell-type percentages than the regular method

In [None]:
def _load_mixture_proportions(tsv_path: Path, cell_types: List[str]) -> pd.DataFrame:
    """Load a table of per-mixture cell-type proportions.

    Args:
        tsv_path: Tab-separated file whose first column holds the mixture IDs
            and whose remaining columns hold one proportion per cell type.
        cell_types: Cell-type columns to keep, in the desired column order.

    Returns:
        A dataframe with a leading "mixtures" column (the mixture IDs)
        followed by one column per entry of ``cell_types``.

    Raises:
        KeyError: If any of ``cell_types`` is not a column of the file.
    """
    props_df = pd.read_csv(tsv_path, index_col=0, sep="\t")
    # Empty cells are treated as a proportion of 0
    props_df = props_df.fillna(0)
    # Re-arrange columns
    props_df = props_df[cell_types]
    # Name the index explicitly so the ID column is always called "mixtures",
    # whether or not the first column of the file has a header
    return props_df.rename_axis("mixtures").reset_index()


# Cell-type proportions of mixtures generated using sparse method
sparse_meta_df = _load_mixture_proportions(
    Path(prefix).joinpath("data/train/training_mixt_pror_sparse.tsv"), c_types
)

# Cell-type proportions of mixtures generated using regular method
reg_meta_df = _load_mixture_proportions(
    Path(prefix).joinpath("data/train/training_mixt_pror_regular.tsv"), c_types
)

#### Plot sparse method

In [None]:
# Reshape to long format: one row per (mixture, cell type) pair
pivot_sparse_meta_df = sparse_meta_df.melt(
    id_vars=["mixtures"],
    value_vars=c_types,
    var_name="c_types",
    value_name="fractions",
)

# Convert proportions to cell counts. Round before casting: astype(int)
# truncates, so a product that lands just below a whole number because of
# floating-point error would otherwise lose one cell.
# NOTE(review): 500 is presumably the number of cells per mixture - confirm.
pivot_sparse_meta_df["counts"] = (
    (pivot_sparse_meta_df["fractions"] * 500).round().astype(int)
)
# The patient ID is the part of the mixture ID before the first underscore
pivot_sparse_meta_df["Patient"] = (
    pivot_sparse_meta_df["mixtures"].str.split("_").str[0]
)

# Merge with subtype and sort patients by subtype
pivot_sparse_meta_df = pivot_sparse_meta_df.merge(
    subtype_df, on=["Patient"], how="inner"
)
pivot_sparse_meta_df["subtype"] = pd.Categorical(
    pivot_sparse_meta_df["subtype"], categories=["HER2+", "TNBC", "ER+"]
)
# Stable sort so the row order within a subtype (and therefore the order in
# which patients first appear) is reproducible between runs
pivot_sparse_meta_df = pivot_sparse_meta_df.sort_values(
    ["subtype"], kind="mergesort"
)

In [None]:
# Export the long-format table as source data, using reader-friendly
# column headers (the dataframe itself keeps its original column names)
pivot_sparse_meta_df.rename(
    columns={
        "mixtures": "Mixture ID",
        "c_types": "Cell type",
        "fractions": "Proportion",
        "counts": "Count",
        "subtype": "PAM50 subtype",
    }
).to_csv(
    Path(viz_prefix, "source_data/supp_figure_3b.tsv"),
    index=False,
    sep="\t",
)

In [None]:
# Notched box plot of cell counts per patient, one box per cell type.
# Box outlines take the "line" colour of each cell type.
fig = px.box(
    pivot_sparse_meta_df,
    x="Patient",
    y="counts",
    color="c_types",
    color_discrete_map={
        name: colours["line"] for name, colours in ctype_colour_pallete_d.items()
    },
    notched=True,
    category_orders={"c_types": c_types},
)
fig.update_traces(quartilemethod="exclusive")

# Update color of all traces: each cell type's trace (matched by name) gets
# its "fill" colour, thin outlines and small markers
for c_type in c_types:
    fig.update_traces(
        fillcolor=ctype_colour_pallete_d[c_type]["fill"],
        line=dict(width=0.25),
        marker=dict(size=1),
        opacity=1,
        selector=dict(name=c_type),
    )

# Update axes
fig.update_xaxes(
    type="category",
    title="Patient",
    title_standoff=6,
    title_font_size=10,
    ticks="outside",
    ticklen=2,
    showticklabels=True,
    tickmode="array",
    tickwidth=0.5,
    tickangle=-270,
    tickfont_size=9,
    linecolor="black",
    linewidth=0.5,
)
fig.update_yaxes(
    title="Counts",
    title_standoff=6,
    title_font_size=10,
    linecolor="black",
    linewidth=0.5,
    range=[-1, 510],
    dtick=50,
    ticks="outside",
    ticklen=2,
    tickwidth=0.5,
    tickfont_size=9,
    showgrid=True,
    gridwidth=0.5,
    gridcolor="lightgray",
)

# Update layout: transparent background, no legend, no outer margins
fig.update_layout(
    plot_bgcolor="rgba(0,0,0,0)",
    font=dict(size=10, color="black"),
    legend=dict(title="Cell type"),
    showlegend=False,
    margin=dict(t=0, l=0, r=0, b=0),
)

# Save into images
fig.write_image(
    Path(viz_prefix).joinpath("figures/supp_figures/supp_fig_3b.svg"),
    width=560,
    height=400,
    scale=1,
)

#### Plot regular method

In [None]:
# Reshape to long format: one row per (mixture, cell type) pair
pivot_reg_meta_df = reg_meta_df.melt(
    id_vars=["mixtures"],
    value_vars=c_types,
    var_name="c_types",
    value_name="fractions",
)
# Convert proportions to cell counts. Round before casting: astype(int)
# truncates, so a product that lands just below a whole number because of
# floating-point error would otherwise lose one cell.
# NOTE(review): 500 is presumably the number of cells per mixture - confirm.
pivot_reg_meta_df["counts"] = (
    (pivot_reg_meta_df["fractions"] * 500).round().astype(int)
)
# The patient ID is the part of the mixture ID before the first underscore
pivot_reg_meta_df["Patient"] = pivot_reg_meta_df["mixtures"].str.split("_").str[0]

# Merge with subtype and sort patients by subtype
pivot_reg_meta_df = pivot_reg_meta_df.merge(subtype_df, on=["Patient"], how="inner")
pivot_reg_meta_df["subtype"] = pd.Categorical(
    pivot_reg_meta_df["subtype"], categories=["HER2+", "TNBC", "ER+"]
)
# Stable sort so the row order within a subtype (and therefore the order in
# which patients first appear) is reproducible between runs
pivot_reg_meta_df = pivot_reg_meta_df.sort_values(["subtype"], kind="mergesort")

In [None]:
# Export the long-format table as source data, using reader-friendly
# column headers (the dataframe itself keeps its original column names)
pivot_reg_meta_df.rename(
    columns={
        "mixtures": "Mixture ID",
        "c_types": "Cell type",
        "fractions": "Proportion",
        "counts": "Count",
        "subtype": "PAM50 subtype",
    }
).to_csv(
    Path(viz_prefix, "source_data/supp_figure_3a.tsv"),
    index=False,
    sep="\t",
)

In [None]:
# Notched box plot of cell counts per patient, one box per cell type.
# Box outlines take the "line" colour of each cell type.
fig = px.box(
    pivot_reg_meta_df,
    x="Patient",
    y="counts",
    color="c_types",
    color_discrete_map={
        name: colours["line"] for name, colours in ctype_colour_pallete_d.items()
    },
    notched=True,
    category_orders={"c_types": c_types},
)
fig.update_traces(quartilemethod="exclusive")

# Update color of all traces: each cell type's trace (matched by name) gets
# its "fill" colour, thin outlines and small markers
for c_type in c_types:
    fig.update_traces(
        fillcolor=ctype_colour_pallete_d[c_type]["fill"],
        line=dict(width=0.25),
        marker=dict(size=1),
        opacity=1,
        selector=dict(name=c_type),
    )

# Update axes
fig.update_xaxes(
    type="category",
    title="Patient",
    title_standoff=6,
    title_font_size=10,
    ticks="outside",
    ticklen=2,
    showticklabels=True,
    tickmode="array",
    tickwidth=0.5,
    tickangle=-270,
    tickfont_size=9,
    linecolor="black",
    linewidth=0.5,
)
fig.update_yaxes(
    title="Counts",
    title_standoff=6,
    title_font_size=10,
    linecolor="black",
    linewidth=0.5,
    range=[-1, 510],
    dtick=50,
    ticks="outside",
    ticklen=2,
    tickwidth=0.5,
    tickfont_size=9,
    showgrid=True,
    gridwidth=0.5,
    gridcolor="lightgray",
)

# Update layout: transparent background, no legend, no outer margins
fig.update_layout(
    plot_bgcolor="rgba(0,0,0,0)",
    font=dict(size=10, color="black"),
    legend=dict(title="Cell type"),
    showlegend=False,
    margin=dict(t=0, l=0, r=0, b=0),
)

# Save into images.
# scale=1 keeps this panel at the same 560x400 output size as the
# supp_fig_3b panel exported by this notebook.
fig.write_image(
    Path(viz_prefix).joinpath("figures/supp_figures/supp_fig_3a.svg"),
    width=560,
    height=400,
    scale=1,
)