In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm

import plotly
import plotly.express as px
from plotly.subplots import make_subplots
from plotly import graph_objects as go

%load_ext blackcellmagic

In [None]:
# Both folders below live under the same (anonymised) benchmarking root
_benchmark_root = "???/deconvolution_benchmarking"

# Destination for figures and their source-data tables
viz_prefix = f"{_benchmark_root}/visualizations"

# Experiment whose results are plotted in this notebook
prefix = f"{_benchmark_root}/01_purity_levels_experiment/include_normal_epithelial"

In [None]:
# All patient IDs, in the order their per-patient files are loaded below
p_ids = """
    CID4471  CID44971 CID4066  CID3586  CID3941
    CID3963  CID44041 CID4530N CID3838  CID3946
    CID4040  CID4461  CID44991 CID45171 CID4535
    CID3948  CID4398  CID4463  CID4495  CID4513
    CID4465  CID4067  CID4290A CID3921  CID4523
    CID4515
""".split()

# Major cell-type labels (SMOTE-synthesised cells get their own label later)
major_c_types = [
    "Cancer Epithelial", "Normal Epithelial",
    "T-cells", "B-cells", "Myeloid",
    "Endothelial", "CAFs", "PVL", "Plasmablasts",
]

## Extract colour palette

In [None]:
# Load lineage mapping of all cell types in the data.
# NOTE: despite the .csv extension the file is read as tab-separated.
colour_pallete_df = pd.read_csv(
    Path(prefix).joinpath("data/Whole_miniatlas_colour_pallete.csv"), sep="\t"
)

# Map cell type ("all_celltype") -> colour ("line"). Zipping the two columns
# avoids DataFrame.iterrows (slow, builds a Series per row); as with a loop,
# a later duplicate cell type overwrites an earlier one.
colour_pallete_d = dict(
    zip(colour_pallete_df["all_celltype"], colour_pallete_df["line"])
)

# Colour for SMOTE-synthesised cells.
# NOTE(review): the UMAP plotting cell overrides SMOTE markers with "#E4E4E4";
# confirm which grey is intended.
colour_pallete_d["SMOTE"] = "#999B9E"

## Plot UMAP for all patients in one plot

In [None]:
# Collect one merged (UMAP coordinates + SMOTE metadata) frame per patient
umap_l = []

for p in tqdm(p_ids):
    patient_dir = Path(prefix).joinpath(f"data/patients/{p}")
    metadata_df = pd.read_csv(
        patient_dir.joinpath("smoted_metadata.csv"), sep="\t", index_col=0
    )
    umap_df = pd.read_csv(
        patient_dir.joinpath("seurat/UMAPC.csv"), sep="\t", index_col=0
    )

    # Seurat turns SMOTE_4608 into SMOTE.4608
    # Revert this (for this patient's SMOTE cells only) so we can merge
    umap_df.index = [
        i.replace(".", "_") if i.startswith(f"{p}_SMOTE") else i for i in umap_df.index
    ]

    # Also replace "-" by "_" in metadata index
    # (presumably so IDs line up with the UMAP table's - confirm)
    metadata_df.index = [i.replace("-", "_") for i in metadata_df.index]

    # The merge below is an inner join on cell ID, so any cell whose ID does
    # not line up after the renaming above would vanish silently. Report it.
    n_umap_only = int((~umap_df.index.isin(metadata_df.index)).sum())
    n_metadata_only = int((~metadata_df.index.isin(umap_df.index)).sum())
    if n_umap_only or n_metadata_only:
        tqdm.write(
            f"{p}: {n_umap_only} UMAP-only and {n_metadata_only} "
            "metadata-only cells dropped by the merge"
        )

    # Merge 2 DataFrames to get cell types with UMAP coordinates
    umap_metadata_df = pd.merge(umap_df, metadata_df, left_index=True, right_index=True)

    # Readable "original"/"synthesized" flag (exported as source data)
    umap_metadata_df["synthesized_str"] = umap_metadata_df["synthesized"].replace(
        {False: "original", True: "synthesized"}
    )

    # Add patient id (the facet column of the figure)
    umap_metadata_df["Patient"] = p

    # Switch cell labels of all SMOTE cells to "SMOTE"
    # (.eq(True) is the lint-clean equivalent of `== True`)
    is_smote = umap_metadata_df["synthesized"].eq(True)
    umap_metadata_df.loc[is_smote, "cell_labels"] = "SMOTE"

    umap_l.append(umap_metadata_df)

# Concatenate across all patients
all_patients_umap_metadata_df = pd.concat(umap_l, axis=0)

In [None]:
# Save the plotted data as source data for Supplementary Figure 2
source_data_path = Path(viz_prefix).joinpath("data/supp_figure_2.tsv")
# pandas will not create a missing output folder, so make sure it exists
source_data_path.parent.mkdir(parents=True, exist_ok=True)

(
    all_patients_umap_metadata_df
    # The unnamed cell-ID index becomes an "index" column, renamed below
    .reset_index()
    .rename(
        columns={
            "index": "Cell ID",
            "UMAPC_1": "UMAP_1",
            "UMAPC_2": "UMAP_2",
            "cell_labels": "Cell type",
            "synthesized_str": "Synthesized/Original",
        },
    )
    # The boolean flag is redundant with "Synthesized/Original"
    .drop(columns="synthesized")
    .to_csv(source_data_path, sep="\t", index=False)
)

In [None]:
# Faceted UMAP: one panel per patient, points coloured by major cell type and
# SMOTE-synthesised cells shown in light grey.

# Panels per facet row; also used below to locate empty slots in the last row
FACET_COL_WRAP = 5

# Styling shared by every x and y axis (identical limits on all panels)
AXIS_STYLE = dict(
    title_font_size=7,
    title_standoff=3,
    linecolor="black",
    linewidth=0.5,
    ticks="outside",
    tickfont_size=6,
    ticklen=2,
    range=[-21, 21],
    dtick=10,
    matches=None,
)

fig = px.scatter(
    all_patients_umap_metadata_df.rename(
        columns={"UMAPC_1": "UMAP_1", "UMAPC_2": "UMAP_2"}
    ),
    x="UMAP_1",
    y="UMAP_2",
    color="cell_labels",
    color_discrete_map=colour_pallete_d,
    facet_col="Patient",
    facet_col_wrap=FACET_COL_WRAP,
    facet_row_spacing=0.025,
    facet_col_spacing=0.025,
    # SMOTE first (presumably so its traces are drawn beneath real cells)
    category_orders={"cell_labels": ["SMOTE"] + major_c_types},
)

# Update each trace separately: tiny, slightly translucent real cells ...
for c_type in major_c_types:
    fig.update_traces(selector=dict(name=c_type), marker=dict(size=1), opacity=0.8)
# ... and opaque light-grey SMOTE cells (overrides colour_pallete_d["SMOTE"])
fig.update_traces(
    selector=dict(name="SMOTE"), marker=dict(size=1, color="#E4E4E4"), opacity=1
)

# Update axes
fig.update_xaxes(**AXIS_STYLE)
fig.update_yaxes(**AXIS_STYLE)

# Update layout (legend settings only take effect if showlegend is enabled)
fig.update_layout(
    font=dict(size=7, color="black"),
    plot_bgcolor="rgba(0,0,0,0)",
    margin=dict(t=10, l=0, r=0, b=0),
    legend=dict(
        title="Cell Types",
        orientation="h",
        yanchor="bottom",
        y=-1,
        xanchor="center",
        x=0.25,
    ),
    showlegend=False,
)

# Nicely format subplot annotations: "Patient=CID4471" -> "CID4471"
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))

# When the last facet row is only partly filled, the panels directly above its
# empty slots (row=2, i.e. the second last row) need their own x-axis title
# and tick labels. Derive those columns from the facet count instead of
# hard-coding them.
n_panels = all_patients_umap_metadata_df["Patient"].nunique()
n_in_last_row = n_panels % FACET_COL_WRAP
if n_panels > FACET_COL_WRAP and n_in_last_row:
    for col in range(n_in_last_row + 1, FACET_COL_WRAP + 1):
        fig.update_xaxes(
            title="UMAP_1",
            title_font_size=7,
            title_standoff=2,
            showticklabels=True,
            ticks="outside",
            tickfont_size=6,
            row=2,
            col=col,
        )

# Save as SVG (write_image does not create missing folders)
fig_path = Path(viz_prefix).joinpath("figures/supp_figures/supp_fig_2.svg")
fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(
    fig_path,
    height=650,
    width=500,
    scale=5,
)