In [4]:
import gc
import gzip
import logging
import math
import re
import sys

import matplotlib.backends.backend_pdf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyranges as pr
from PIL import Image
from scipy import sparse

import collections as cl
import gc
import logging
import sys
from typing import Dict, List, Optional, Tuple, Union

import matplotlib.backends.backend_pdf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyranges as pr
import ray
import seaborn as sns
from scipy.stats import gaussian_kde, norm

In [5]:
def read_fragments_from_file(
    fragments_bed_filename, use_polars: bool = True
) -> pr.PyRanges:
    """
    Read fragments BED file to PyRanges object.

    The file is scanned once to count the leading empty/comment lines and to
    determine the number of columns from the first data line; it is then parsed
    with polars or pandas and wrapped in a PyRanges object.

    Parameters
    ----------
    fragments_bed_filename: Fragments BED filename (str or path-like). A name ending in ".gz" is read as gzip.
    use_polars: Use polars instead of pandas for reading the fragments BED file.

    Returns
    -------
    PyRanges object of fragments.

    Raises
    ------
    ValueError
        If the first data line has fewer than 4 tab-separated columns (this
        includes a file without any data line).
    """

    bed_column_names = (
        "Chromosome",
        "Start",
        "End",
        "Name",
        "Score",
        "Strand",
        "ThickStart",
        "ThickEnd",
        "ItemRGB",
        "BlockCount",
        "BlockSizes",
        "BlockStarts",
    )

    # Accept path-like objects as well: ".endswith" below needs a plain string.
    fragments_bed_filename = str(fragments_bed_filename)

    # Set the correct open function depending if the fragments BED file is gzip compressed or not.
    open_fn = gzip.open if fragments_bed_filename.endswith(".gz") else open

    skip_rows = 0
    nbr_columns = 0
    with open_fn(fragments_bed_filename, "rt") as fragments_bed_fh:
        for line in fragments_bed_fh:
            # Remove newlines and spaces.
            line = line.strip()

            if not line or line.startswith("#"):
                # Count number of empty lines and lines which start with a comment before the actual data.
                skip_rows += 1
            else:
                # Get number of columns from the first real BED entry.
                nbr_columns = len(line.split("\t"))

                # Stop reading the BED file.
                break

    if nbr_columns < 4:
        raise ValueError(
            f'Fragments BED file needs to have at least 4 columns. "{fragments_bed_filename}" contains only '
            f"{nbr_columns} columns."
        )

    # Only name the columns that are actually present in the file.
    column_names = list(bed_column_names[:nbr_columns])

    if use_polars:
        import polars as pl

        # Read fragments BED file with polars.
        # NOTE(review): "has_headers" and "sep" are the keyword names of older polars
        # releases (newer ones use "has_header" / "separator") - confirm against the
        # installed polars version before upgrading.
        df = (
            pl.read_csv(
                fragments_bed_filename,
                has_headers=False,
                skip_rows=skip_rows,
                sep="\t",
                use_pyarrow=True,
                new_columns=column_names,
            )
            .with_columns(
                [
                    pl.col("Chromosome").cast(pl.Utf8),
                    pl.col("Start").cast(pl.Int32),
                    pl.col("End").cast(pl.Int32),
                    pl.col("Name").cast(pl.Utf8),
                ]
            )
            .to_pandas()
        )

        # Convert "Name" column to pd.Categorical as groupby operations will be done on it later.
        df["Name"] = df["Name"].astype("category")
    else:
        # Requested dtypes per column. The "Start" key used to be misspelled as
        # "Start'", so the int32 dtype was never applied to the "Start" column.
        pandas_dtypes = {
            "Chromosome": str,
            "Start": np.int32,
            "End": np.int32,
            "Name": "category",
            "Strand": str,
        }

        # Read fragments BED file with pandas.
        df = pd.read_table(
            fragments_bed_filename,
            sep="\t",
            skiprows=skip_rows,
            header=None,
            names=column_names,
            doublequote=False,
            engine="c",
            # Restrict the dtype mapping to columns that exist (e.g. no "Strand" in a 4/5 column file).
            dtype={
                name: dtype
                for name, dtype in pandas_dtypes.items()
                if name in column_names
            },
        )

    # Convert pandas dataframe to PyRanges dataframe.
    # This will convert "Chromosome" and "Strand" columns to pd.Categorical.
    return pr.PyRanges(df)

def _normalize_tss_matrix_per_barcode(tss_matrix, minimum_signal_window, min_norm):
    """
    Divide every row (barcode) of the coverage matrix by its own background signal.

    The background of a row is the mean of the average over its first and the
    average over its last `minimum_signal_window` columns, floored at `min_norm`.
    """

    def _normalize_row(row):
        background = (
            np.mean(row.iloc[-minimum_signal_window:])
            + np.mean(row.iloc[:minimum_signal_window])
        ) / 2
        return row / max(background, min_norm)

    return tss_matrix.apply(_normalize_row, axis=1)


def _plot_tss_profile(positions, enrichment, flank_window, color, plot, save, log):
    """Draw the sample-level TSS profile, optionally save it, then show or close the figure."""
    fig, ax = plt.subplots()
    ax.plot(positions, enrichment, color=color)
    ax.set_xlim(-flank_window, flank_window)
    ax.set_xlabel("Position from TSS", fontsize=10)
    ax.set_ylabel("Normalized enrichment", fontsize=10)
    if save is not None:
        fig.savefig(save)
    if plot:
        log.info("Plotting normalized sample TSS enrichment")
        plt.show()
    else:
        plt.close(fig)


def profile_tss(
    fragments: Union[str, pd.DataFrame, pr.PyRanges],
    annotation: Union[pd.DataFrame, pr.PyRanges],
    valid_bc: Optional[List[str]] = None,
    plot: Optional[bool] = True,
    plot_data: Optional[pd.DataFrame] = None,
    n_cpu: Optional[int] = 1,
    partition: Optional[int] = 5,
    flank_window: Optional[int] = 1000,
    tss_window: Optional[int] = 50,
    minimum_signal_window: Optional[int] = 100,
    rolling_window: Optional[int] = 10,
    min_norm: Optional[float] = 0.2,
    color: Optional[str] = None,
    save: Optional[str] = None,
    return_TSS_enrichment_per_barcode: Optional[bool] = False,
    return_TSS_coverage_matrix_per_barcode: Optional[bool] = False,
    return_plot_data: Optional[bool] = False,
    use_polars: Optional[bool] = True
):
    r"""
    Plot the Transcription Start Site (TSS) profile. It is computed as the summed accessibility signal (sample-level), or the number of cut sites per base (barcode-level), in a space around the full set of annotated TSSs and is normalized by the minimum signal in the window. This profile is helpful to assess the signal-to-noise ratio of the library, as it is well known that TSSs and the promoter regions around them have, on average, a high degree of chromatin accessibility compared to the intergenic and intronic regions of the genome.

    Parameters
    ---------
    fragments: str, pd.DataFrame or pyRanges
            The path to the fragments file containing chromosome, start, end and assigned barcode for each read (e.g. from CellRanger ATAC (/outs/fragments.tsv.gz)) or a data frame
            containing 'Chromosome', 'Start', 'End', 'Name', and 'Score', which indicates the number of times that a fragments is found assigned to that barcode. The fragments data
            frame can be obtained using PyRanges:
                    import pyranges as pr
                    fragments = pr.read_bed(fragments_file, as_df=True)
            Any other object (e.g. a pyRanges) is used as is.
    annotation: pd.DataFrame or pyRanges
            A data frame or pyRanges containing transcription start sites for each gene, with 'Chromosome', 'Start' and 'Strand' as columns (additional columns will be ignored). This data frame can be easily obtained via pybiomart:
                    # Get TSS annotations
                    import pybiomart as pbm
                    # For mouse
                    dataset = pbm.Dataset(name='mmusculus_gene_ensembl',  host='http://www.ensembl.org')
                    # For human
                    dataset = pbm.Dataset(name='hsapiens_gene_ensembl',  host='http://www.ensembl.org')
                    # For fly
                    dataset = pbm.Dataset(name='dmelanogaster_gene_ensembl',  host='http://www.ensembl.org')
                    # Query TSS list and format
                    annot = dataset.query(attributes=['chromosome_name', 'transcription_start_site', 'strand', 'external_gene_name', 'transcript_biotype'])
                    filter = annot['Chromosome/scaffold name'].str.contains('CHR|GL|JH|MT')
                    annot = annot[~filter]
                    annot['Chromosome/scaffold name'] = annot['Chromosome/scaffold name'].str.replace(r'(\b\S)', r'chr\1', regex=True)
                    annot.columns=['Chromosome', 'Start', 'Strand', 'Gene', 'Transcript_type']
                    # Select TSSs of protein coding genes
                    annot = annot[annot.Transcript_type == 'protein_coding']
    valid_bc: list, optional
            A list containing selected barcodes. Default: None (all barcodes found in `fragments`).
    plot: bool, optional
            Whether to return the plot to the console. Default: True.
    plot_data: pd.DataFrame, optional
            Data frame containing precomputed plot data with a 'Position' and a 'TSS_enrichment' column (the legacy column name 'TSSEnrichment' is accepted too). When given, nothing is computed from `fragments`: the profile is only plotted and per-barcode outputs are not available. Default: None.
    n_cpu: int, optional
            Currently not used by this function. Default: 1.
    partition: int, optional
            Number of barcode chunks in which the coverage matrix is computed. If None or <= 1, the matrix is computed in one go. Default: 5.
    flank_window: int, optional
            Flanking window around the TSS. Default: 1000 (+/- 1000 bp).
    tss_window: int, optional
            Window around the TSS used to count fragments in the TSS when calculating the TSS enrichment per barcode. Default: 50 (+/- 50 bp).
    minimum_signal_window: int, optional
            Tail window use to normalize the TSS enrichment. Default: 100 (average signal in the 100bp in the extremes of the TSS window).
    rolling_window: int, optional
            Rolling window used to smooth signal. Use None to disable smoothing. Default: 10.
    min_norm: float, optional
            Minimum normalization score. If the average minimum signal value is below this value, this number is used to normalize the TSS signal. This approach penalizes cells with fewer reads. Default: 0.2.
    color: str, optional
            Line color for the plot. Default: None.
    save: str, optional
            Output file to save plot. Default: None.
    return_TSS_enrichment_per_barcode: bool, optional
            Whether to return a data frame containing the normalized enrichment score on the TSS for each barcode. Default: False.
    return_TSS_coverage_matrix_per_barcode: bool, optional
            Whether to return a matrix containing the normalized enrichment in each position in the window for each barcode, with positions as columns and barcodes as rows. Default: False.
    return_plot_data: bool, optional
            Whether to return the TSS profile plot data. Default: False.
    use_polars: bool, optional
            Whether to read the fragments file with polars instead of pandas (only used when `fragments` is a path). Default: True.

    Return
    ------
    dict or None
            A dictionary containing a :class:`pd.DataFrame` with the normalized enrichment score on the TSS for each barcode ('TSS_enrichment'), a :class:`pd.DataFrame` with the normalized enrichment scores in each position for each barcode ('TSS_coverage_mat') and/or a :class:`pd.DataFrame` with the TSS profile plot data ('TSS_plot_data'). None if no output was requested.
    """
    # Create logger
    level = logging.INFO
    log_format = "%(asctime)s %(name)-12s %(levelname)-8s %(message)s"
    handlers = [logging.StreamHandler(stream=sys.stdout)]
    logging.basicConfig(level=level, format=log_format, handlers=handlers)
    log = logging.getLogger("cisTopic")

    if isinstance(plot_data, pd.DataFrame):
        # Precomputed profile: plot it and return early. There is no coverage matrix
        # in this mode, so none of the per-barcode outputs can be produced.
        log.info("Using plot_data. TSS enrichment per barcode will not be computed")
        if return_TSS_enrichment_per_barcode or return_TSS_coverage_matrix_per_barcode:
            log.warning(
                "Per-barcode TSS outputs cannot be computed from plot_data and will not be returned"
            )
        # This function emits the column as "TSS_enrichment"; "TSSEnrichment" is kept
        # as a fallback for data frames using the previously expected name.
        enrichment_column = (
            "TSS_enrichment"
            if "TSS_enrichment" in plot_data.columns
            else "TSSEnrichment"
        )
        if plot is True or save is not None:
            _plot_tss_profile(
                plot_data["Position"],
                plot_data[enrichment_column],
                flank_window,
                color,
                plot,
                save,
                log,
            )
        if return_plot_data:
            log.info("Returning normalized sample TSS enrichment data")
            return {"TSS_plot_data": plot_data}
        return None

    if isinstance(fragments, str):
        log.info("Reading fragments file")
        fragments = read_fragments_from_file(fragments, use_polars=use_polars)
    elif isinstance(fragments, pd.DataFrame):
        fragments = pr.PyRanges(fragments)

    if valid_bc is not None:
        log.info("Using provided valid barcodes")
        fragments = fragments[fragments.Name.isin(set(valid_bc))]
    else:
        valid_bc = list(set(fragments.Name.tolist()))

    log.info("Formatting annnotation")
    if isinstance(annotation, pr.PyRanges):
        annotation = annotation.df
    # Work on a copy so the caller's annotation is not modified (and pandas does not
    # warn about assigning to a slice).
    tss_space_annotation = annotation[["Chromosome", "Start", "Strand"]].copy()
    tss_space_annotation["End"] = tss_space_annotation["Start"] + flank_window
    tss_space_annotation["Start"] = tss_space_annotation["Start"] - flank_window
    tss_space_annotation = tss_space_annotation[
        ["Chromosome", "Start", "End", "Strand"]
    ]
    tss_space_annotation = pr.PyRanges(tss_space_annotation)

    log.info("Creating coverage matrix")
    if partition is not None and partition > 1:
        # Skip empty chunks (fewer barcodes than partitions).
        barcode_chunks = [
            chunk for chunk in np.array_split(valid_bc, partition) if len(chunk) > 0
        ]
        TSS_matrix = pd.concat(
            [
                get_tss_matrix(
                    fragments[fragments.Name.isin(set(chunk))],
                    flank_window,
                    tss_space_annotation,
                )
                for chunk in barcode_chunks
            ]
        )
    else:
        TSS_matrix = get_tss_matrix(fragments, flank_window, tss_space_annotation)
    # Fill after concatenation: chunks with different position columns introduce
    # NaN when stacked, which would otherwise turn the column sums into NaN.
    TSS_matrix = TSS_matrix.fillna(0)
    log.info("Coverage matrix done")

    expected_positions = list(range(2 * flank_window + 1))
    if TSS_matrix.columns.tolist() != expected_positions:
        # Add the expected positions that are absent (as zero coverage) and sort the
        # columns; columns that are already present keep their values.
        all_positions = sorted(
            set(TSS_matrix.columns.tolist()).union(expected_positions)
        )
        TSS_matrix = TSS_matrix.reindex(columns=all_positions, fill_value=0)

    if rolling_window is not None:
        # Smooth along positions. Done on the transposed frame because the "axis"
        # argument of DataFrame.rolling is deprecated in recent pandas releases.
        TSS_matrix = (
            TSS_matrix.T.rolling(window=rolling_window, min_periods=0).mean().T
        )

    # Sample-level profile: summed signal per position, normalized by the average
    # signal in both tails of the window (floored at min_norm).
    TSS_counts = TSS_matrix.values.sum(axis=0)
    div = max(
        (
            np.mean(TSS_counts[-minimum_signal_window:])
            + np.mean(TSS_counts[0:minimum_signal_window])
        )
        / 2,
        min_norm,
    )
    # NOTE(review): positions run from -flank_window - 1 to flank_window - 1, i.e.
    # matrix column `flank_window` is labelled -1 rather than 0. Kept unchanged;
    # confirm against the column convention of get_tss_matrix.
    positions = range(-flank_window - 1, flank_window)
    if plot is True or save is not None:
        _plot_tss_profile(
            positions, TSS_counts / div, flank_window, color, plot, save, log
        )

    output = {}
    if return_TSS_enrichment_per_barcode or return_TSS_coverage_matrix_per_barcode:
        TSS_norm = _normalize_tss_matrix_per_barcode(
            TSS_matrix, minimum_signal_window, min_norm
        )
    if return_TSS_enrichment_per_barcode:
        # Per-barcode score: mean normalized signal in the +/- tss_window columns
        # around the centre of the window.
        TSS_enrich = pd.DataFrame(
            TSS_norm.iloc[
                :, range(flank_window - tss_window, flank_window + tss_window)
            ].mean(axis=1)
        )
        TSS_enrich.columns = ["TSS_enrichment"]
        output["TSS_enrichment"] = TSS_enrich
    if return_TSS_coverage_matrix_per_barcode:
        log.info("Returning normalized TSS coverage matrix per barcode")
        output["TSS_coverage_mat"] = TSS_norm
    if return_plot_data:
        log.info("Returning normalized sample TSS enrichment data")
        output["TSS_plot_data"] = pd.DataFrame(
            {
                "Position": positions,
                "TSS_enrichment": TSS_counts / div,
            }
        )
    return output if output else None

In [27]:
import os

# Genome assembly whose TSS annotation is used below.
genome = "hg38"

# BioMart dataset name per supported genome assembly.
pbm_genome_name_dict = {
    "hg38": "hsapiens_gene_ensembl",
    "hg37": "hsapiens_gene_ensembl",
    "mm10": "mmusculus_gene_ensembl",
    "dm6": "dmelanogaster_gene_ensembl",
}

# Ensembl host serving the matching assembly version.
pbm_host_dict = {
    "hg38": "http://www.ensembl.org",
    "hg37": "http://grch37.ensembl.org/",
    "mm10": "http://nov2020.archive.ensembl.org/",
    "dm6": "http://www.ensembl.org",
}

# NOTE(review): the cache file name does not depend on `genome`; delete
# "annotation.tsv" when switching genome, otherwise the old annotation is reused.
if os.path.exists("annotation.tsv"):
    print("Loading cached genome annotation...")
    annotation = pd.read_csv("annotation.tsv", sep="\t", header=0, index_col=0)
else:
    # pybiomart is only needed when the annotation has to be downloaded; it was
    # used as `pbm` here without ever being imported.
    import pybiomart as pbm

    dataset = pbm.Dataset(name=pbm_genome_name_dict[genome], host=pbm_host_dict[genome])

    annotation = dataset.query(
        attributes=[
            "chromosome_name",
            "transcription_start_site",
            "strand",
            "external_gene_name",
            "transcript_biotype",
        ]
    )
    # Drop entries whose chromosome/scaffold name contains CHR, GL, JH or MT.
    # (Named explicitly instead of `filter`, which shadowed the builtin.)
    is_excluded_contig = annotation["Chromosome/scaffold name"].str.contains(
        "CHR|GL|JH|MT"
    )
    # Copy so the column assignment below does not write into a slice.
    annotation = annotation[~is_excluded_contig].copy()
    # regex=True is required: without it pandas >= 2.0 treats the pattern as a
    # literal string and no "chr" prefix is added.
    annotation["Chromosome/scaffold name"] = annotation[
        "Chromosome/scaffold name"
    ].str.replace(r"(\b\S)", r"chr\1", regex=True)
    annotation.columns = ["Chromosome", "Start", "Strand", "Gene", "Transcript_type"]
    # Keep TSSs of protein coding transcripts only.
    annotation = annotation[annotation.Transcript_type == "protein_coding"]
    annotation.to_csv("annotation.tsv", sep="\t")

Loading cached genome annotation...


In [31]:
from pycisTopic.cistopic_class import *

In [8]:
import glob

# Find the fragments files below every directory ending in "k" (e.g. "10k") and
# put them in a deterministic (sorted) order.
fragments_list = glob.glob("*k/*_preprocessing_out/data/fragments/*fragments.tsv.gz")
fragments_list.sort()

# Map sample name -> fragments file path. The sample name is the file name up to
# the first ".fragments.tsv.gz".
fragments_dict = {}
for fragments_file in fragments_list:
    basename = fragments_file.rpartition("/")[2]
    sample = basename.partition(".fragments.tsv.gz")[0]
    fragments_dict[sample] = fragments_file

{'BIO_ddseq_1.10k': '10k/10k_preprocessing_out/data/fragments/BIO_ddseq_1.10k.fragments.tsv.gz',
 'BIO_ddseq_2.10k': '10k/10k_preprocessing_out/data/fragments/BIO_ddseq_2.10k.fragments.tsv.gz',
 'BIO_ddseq_3.10k': '10k/10k_preprocessing_out/data/fragments/BIO_ddseq_3.10k.fragments.tsv.gz',
 'BIO_ddseq_4.10k': '10k/10k_preprocessing_out/data/fragments/BIO_ddseq_4.10k.fragments.tsv.gz',
 'BRO_mtscatac_1.10k': '10k/10k_preprocessing_out/data/fragments/BRO_mtscatac_1.10k.fragments.tsv.gz',
 'BRO_mtscatac_2.10k': '10k/10k_preprocessing_out/data/fragments/BRO_mtscatac_2.10k.fragments.tsv.gz',
 'CNA_10xmultiome_1.10k': '10k/10k_preprocessing_out/data/fragments/CNA_10xmultiome_1.10k.fragments.tsv.gz',
 'CNA_10xmultiome_2.10k': '10k/10k_preprocessing_out/data/fragments/CNA_10xmultiome_2.10k.fragments.tsv.gz',
 'CNA_10xv11_1.10k': '10k/10k_preprocessing_out/data/fragments/CNA_10xv11_1.10k.fragments.tsv.gz',
 'CNA_10xv11_2.10k': '10k/10k_preprocessing_out/data/fragments/CNA_10xv11_2.10k.fragments

In [20]:
# Load the fragments of one sample, set a constant "Score" of 1 for every fragment
# and wrap the result in a PyRanges again (despite its name, `fragments_df` ends up
# being a PyRanges object, not a data frame).
fragments_df = pr.PyRanges(
    read_fragments_from_file(
        fragments_dict["BIO_ddseq_1.10k"], use_polars=True
    ).df.assign(Score=1)
)

# d_fragments_from_file(fragments_dict['BIO_ddseq_1.10k'])

  pl.read_csv(
