In [10]:
from pycisTopic.clust_vis import (
    find_clusters,
    run_umap,
    run_tsne,
    plot_metadata,
    plot_topic,
)
import os
import glob
import pickle
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import Image, display
import math

%load_ext lab_black

The lab_black extension is already loaded. To reload it, use:
  %reload_ext lab_black


In [11]:
# Derive one "<sample>.<subsample size>" identifier per topic-model pickle:
# the filename part before "__" joined with the third-from-last
# dot-separated field of the filename (e.g. "05k").
samples = [
    filename.split("__")[0] + "." + filename.split(".")[-3]
    for filename in (
        model_path.split("/")[-1]
        for model_path in sorted(
            glob.glob("cistopic_objects_subsampled/*singlets*model*topics.pkl")
        )
    )
]

# A set can never be longer than the list it was built from, so any length
# difference means at least one identifier occurs more than once.
print(
    "samples are not unique!"
    if len(set(samples)) != len(samples)
    else "samples are unique."
)

samples are unique.


In [12]:
# Map each sample identifier (the filename part before "__") to the path of
# its Seurat cell type prediction table.
prediction_paths = sorted(
    glob.glob("cell_type_classification/*__cell_type_seurat.txt")
)
prediction_path_dict = dict(
    zip(
        (path.split("/")[-1].split("__")[0] for path in prediction_paths),
        prediction_paths,
    )
)
prediction_path_dict

{'BIO_ddseq_1.FIXEDCELLS.05k': 'cell_type_classification/BIO_ddseq_1.FIXEDCELLS.05k__cell_type_seurat.txt',
 'BIO_ddseq_1.FIXEDCELLS.15k': 'cell_type_classification/BIO_ddseq_1.FIXEDCELLS.15k__cell_type_seurat.txt',
 'BIO_ddseq_1.FIXEDCELLS.1k': 'cell_type_classification/BIO_ddseq_1.FIXEDCELLS.1k__cell_type_seurat.txt',
 'BIO_ddseq_1.FIXEDCELLS.25k': 'cell_type_classification/BIO_ddseq_1.FIXEDCELLS.25k__cell_type_seurat.txt',
 'BIO_ddseq_1.FIXEDCELLS.2k': 'cell_type_classification/BIO_ddseq_1.FIXEDCELLS.2k__cell_type_seurat.txt',
 'BIO_ddseq_1.FIXEDCELLS.3k': 'cell_type_classification/BIO_ddseq_1.FIXEDCELLS.3k__cell_type_seurat.txt',
 'BIO_ddseq_2.FIXEDCELLS.05k': 'cell_type_classification/BIO_ddseq_2.FIXEDCELLS.05k__cell_type_seurat.txt',
 'BIO_ddseq_2.FIXEDCELLS.15k': 'cell_type_classification/BIO_ddseq_2.FIXEDCELLS.15k__cell_type_seurat.txt',
 'BIO_ddseq_2.FIXEDCELLS.1k': 'cell_type_classification/BIO_ddseq_2.FIXEDCELLS.1k__cell_type_seurat.txt',
 'BIO_ddseq_2.FIXEDCELLS.25k': 'cell

In [13]:
def _model_key(model_path):
    """Return the "<sample>.<subsample size>" key for a topic-model pickle path.

    The key is the filename part before "__" joined by "." with the
    third-from-last dot-separated field of the filename.
    """
    filename = model_path.split("/")[-1]
    sample_name = filename.split("__")[0]
    subsample_tag = filename.split(".")[-3]
    return sample_name + "." + subsample_tag


# Map each sample key to the path of its topic-model pickle.
cto_model_path_dict = {
    _model_key(model_path): model_path
    for model_path in sorted(
        glob.glob("cistopic_objects_subsampled/*singlets*model*topics.pkl")
    )
}
cto_model_path_dict

{'BIO_ddseq_1.FIXEDCELLS.05k': 'cistopic_objects_subsampled/BIO_ddseq_1.FIXEDCELLS__cto.scrublet0-4.fmx.singlets.05k.model_10topics.pkl',
 'BIO_ddseq_1.FIXEDCELLS.15k': 'cistopic_objects_subsampled/BIO_ddseq_1.FIXEDCELLS__cto.scrublet0-4.fmx.singlets.15k.model_11topics.pkl',
 'BIO_ddseq_1.FIXEDCELLS.1k': 'cistopic_objects_subsampled/BIO_ddseq_1.FIXEDCELLS__cto.scrublet0-4.fmx.singlets.1k.model_10topics.pkl',
 'BIO_ddseq_1.FIXEDCELLS.25k': 'cistopic_objects_subsampled/BIO_ddseq_1.FIXEDCELLS__cto.scrublet0-4.fmx.singlets.25k.model_9topics.pkl',
 'BIO_ddseq_1.FIXEDCELLS.2k': 'cistopic_objects_subsampled/BIO_ddseq_1.FIXEDCELLS__cto.scrublet0-4.fmx.singlets.2k.model_12topics.pkl',
 'BIO_ddseq_1.FIXEDCELLS.3k': 'cistopic_objects_subsampled/BIO_ddseq_1.FIXEDCELLS__cto.scrublet0-4.fmx.singlets.3k.model_11topics.pkl',
 'BIO_ddseq_2.FIXEDCELLS.05k': 'cistopic_objects_subsampled/BIO_ddseq_2.FIXEDCELLS__cto.scrublet0-4.fmx.singlets.05k.model_11topics.pkl',
 'BIO_ddseq_2.FIXEDCELLS.15k': 'cistopic_

In [None]:
# --- Clustering / plotting configuration -----------------------------------
leiden_res = [3.0]
# k passed to find_clusters below. The cluster column name is assumed to embed
# this value ("pycisTopic_leiden_<k>_<res>") -- TODO confirm against pycisTopic.
leiden_k = 10
to_plot_leiden = ["pycisTopic_leiden_" + str(leiden_k) + "_" + str(x) for x in leiden_res]

to_plot_vars = [
    "Unique_nr_frag",
    "TSS_enrichment",
    "Dupl_rate",
    "FRIP",
    "Doublet_scores_fragments",
    "fmx_sample",
    "seurat_cell_type",
]


def _available_plot_vars(cto, variables):
    """Return a copy of `variables` without "fmx_sample" if `cto` lacks that column.

    Every code path that plots must apply this filter; previously the tSNE
    fallback path skipped it whenever the UMAP png already existed.
    """
    if "fmx_sample" in cto.cell_data.columns:
        return list(variables)
    return [v for v in variables if v != "fmx_sample"]


def _plot_summary(cto, variables, reduction_name, save_path, n_cols=4):
    """Draw the metadata summary grid for one projection and save it to `save_path`.

    Parameters
    ----------
    cto : cisTopic object holding `cell_data` and the requested projection.
    variables : list of `cell_data` column names to plot, one panel each.
    reduction_name : name of the projection to plot ("UMAP" or "tSNE" here).
    save_path : output path handed to `plot_metadata(save=...)`.
    n_cols : number of panels per row; the figure is sized 4 x 4 per panel.
    """
    n_rows = math.ceil(len(variables) / n_cols)
    plot_metadata(
        cto,
        reduction_name=reduction_name,
        variables=variables,
        target="cell",
        num_columns=n_cols,
        text_size=16,
        dot_size=15,
        figsize=(n_cols * 4, n_rows * 4),
        save=save_path,
    )


# Make sure the output directory for the summary plots exists before saving.
os.makedirs("plots_qc", exist_ok=True)

for sample in cto_model_path_dict.keys():
    # Fresh list per sample, so dropping "fmx_sample" for one sample does not
    # affect the next one.
    to_plot = to_plot_vars + to_plot_leiden

    print(sample)
    cto_path = cto_model_path_dict[sample]
    cto_path_new = cto_path.replace(".pkl", ".dimreduc.pkl")
    umap_path = f"plots_qc/{sample}__umap_summary.png"
    tsne_path = f"plots_qc/{sample}__tsne_summary.png"

    if sample not in prediction_path_dict:
        print(f"\tPredictions do not exist for {sample}! Skipping")
        continue

    predictions_path = prediction_path_dict[sample]

    if not os.path.exists(cto_path_new):
        # NOTE: pickle.load can execute arbitrary code; only load pickles
        # produced by this pipeline.
        with open(cto_path, "rb") as f:
            cto = pickle.load(f)

        ct_pred = pd.read_csv(predictions_path, sep="\t")

        # Reformat the predictions so they can be added to the cisTopic object:
        # index by cell id, keep the predicted label and its score.
        ct_annot = (
            ct_pred[["composite_sample_id", "cell_type", "cell_type_pred_score"]]
            .copy()
            .set_index("composite_sample_id")
        )
        ct_annot.columns = ["seurat_cell_type", "seurat_cell_type_pred_score"]
        # Rewrite the ids: "-" becomes "___", then the last dot-separated
        # suffix is dropped (presumably to match cto.cell_data.index -- verify).
        ct_annot.index = [x.replace("-", "___") for x in ct_annot.index]
        ct_annot.index = [".".join(x.split(".")[:-1]) for x in ct_annot.index]

        # For the two VIB_hydrop samples, fold the replicate number into the
        # sample name: "___<rep>___<base>.FIXEDCELLS" -> "___<base><rep>.FIXEDCELLS".
        sample_base = sample.split(".")[0]
        if sample_base in ("VIB_hydrop_1", "VIB_hydrop_2"):
            ct_annot.index = [
                x.replace(
                    f"___2___{sample_base}.FIXEDCELLS",
                    f"___{sample_base}2.FIXEDCELLS",
                ).replace(
                    f"___1___{sample_base}.FIXEDCELLS",
                    f"___{sample_base}1.FIXEDCELLS",
                )
                for x in ct_annot.index
            ]

        # Column assignment aligns on the index: cells without a matching
        # prediction row get NaN.
        cto.cell_data["seurat_cell_type"] = ct_annot["seurat_cell_type"]
        cto.cell_data["seurat_cell_type_pred_score"] = ct_annot[
            "seurat_cell_type_pred_score"
        ]
        cto.cell_names = cto.cell_data.index

        # Reset the cell projections, then cluster and compute both embeddings.
        cto.projections["cell"] = {}
        find_clusters(
            cto, target="cell", k=leiden_k, res=leiden_res, prefix="pycisTopic_"
        )
        run_umap(cto, target="cell")
        run_tsne(cto, target="cell")

        to_plot = _available_plot_vars(cto, to_plot)
        _plot_summary(cto, to_plot, "UMAP", umap_path)
        _plot_summary(cto, to_plot, "tSNE", tsne_path)

        with open(cto_path_new, "wb") as f:
            pickle.dump(cto, f, protocol=4)

    else:
        print(f"\t{cto_path_new} exists, skipping")

        # Load the cached object at most once, and only if a plot is missing.
        if not (os.path.exists(umap_path) and os.path.exists(tsne_path)):
            with open(cto_path_new, "rb") as f:
                cto = pickle.load(f)
            to_plot = _available_plot_vars(cto, to_plot)

        # Show each existing summary image; regenerate the ones that are missing.
        for reduction_name, plot_path in (("UMAP", umap_path), ("tSNE", tsne_path)):
            if os.path.exists(plot_path):
                display(Image(plot_path))
            else:
                _plot_summary(cto, to_plot, reduction_name, plot_path)