In [1]:
import logging
import math
import os
import sys
import time
from typing import Optional, Tuple, Dict, List
from tqdm import tqdm

# import argparse
import click
import gtfparse
import pandas as pd

# import bisect

# Module-level constants shared by the cells below.
version = 1.5

# Line and field separator characters.
NEWLINE, TAB = "\n", "\t"

In [2]:
version = 1.5
CLASS_COLUMN_USED = [0, 1, 2, 3, 5, 6, 7, 28, 30, 31]
CLASS_COLUMN_NAME = [
    "isoform",
    "chrom",
    "strand",
    "length",
    "structural_category",
    "associated_gene",
    "associated_transcript",
    "ORF_length",
    "CDS_start",
    "CDS_end",
]

In [3]:
# GFF3 annotation parsed by readGFFandGetData further down.
# Set ISOANNOT_ANNOTATION_FILE to point at a local copy instead of editing
# this machine-specific absolute path.
annotation_file = os.environ.get(
    "ISOANNOT_ANNOTATION_FILE",
    "/home/milo/workspace/Homo_sapiens_GRCh38_Ensembl_86.gff3",
)

In [4]:
def createGTFFromSqanti(
    exons_gtf: str, transcript_classification: str, junctions: str
) -> pd.DataFrame:
    """Build a tappAS-style GTF table from the outputs of sqanti3_qc.

    NOTE(review): the next notebook cell redefines ``createGTFFromSqanti`` with
    a different, file-writing signature; whichever cell ran last is the one
    bound in the kernel.

    For every classification row this emits, in order:

    * a ``transcript`` and a ``gene`` record;
    * when ``CDS_start`` is present, a ``CDS`` record and a ``protein`` record;
    * a ``genomic`` record;
    * when ``CDS_start`` is present, four ``TranscriptAttributes`` records
      (``3UTR_Length``, ``5UTR_Length``, ``CDS``, ``polyA_Site``).

    It then appends one record per non-"transcript" feature of the exon GTF
    and one ``splice_junction`` record per junction.

    Params:
    -------
    exons_gtf : `str`
        Path to the exon GTF, parsed with ``gtfparse.parse_gtf``.
    transcript_classification : `str`
        Path to the tab-separated classification file. It must contain the
        columns listed in ``required_columns`` below.
    junctions : `str`
        Path to the tab-separated junctions file.

    Returns:
    --------
    :class:`~pandas.DataFrame`
        Columns ``seqname, source, feature, start, end, score, strand, frame,
        attribute``. Coordinates are stored as integer strings.

    Raises:
    -------
    SystemExit
        With status 1 if the classification file lacks a required column.
    """
    logger = logging.getLogger("IsoAnnotLite_SQ1")
    source = "tappAS"
    attr_source = "TranscriptAttributes"
    aux = "."
    gtf_columns = [
        "seqname",
        "source",
        "feature",
        "start",
        "end",
        "score",
        "strand",
        "frame",
        "attribute",
    ]

    logger.debug(f"reading classification file {transcript_classification}")
    classification_df = pd.read_csv(transcript_classification, delimiter="\t")

    required_columns = [
        "isoform",
        "chrom",
        "strand",
        "length",
        "structural_category",
        "associated_gene",
        "associated_transcript",
        "ORF_length",
        "CDS_start",
        "CDS_end",
    ]
    missing_names = [
        name for name in required_columns if name not in classification_df.columns
    ]
    if missing_names:
        logger.error(
            "File classification does not have the necessary fields. "
            f"The columns {','.join(missing_names)} were not found in the "
            "classification file."
        )
        # Non-zero status: this is a failure, not a normal exit.
        sys.exit(1)

    # Collect plain dict records and build the DataFrame once at the end.
    # Appending to a list is amortised O(1), and there is no manual index
    # counter that could overwrite an earlier record.
    records = []

    def _add(seqname, src, feature, start, end, strand, attribute):
        """Append one GTF record, normalising coordinates to integer strings."""
        records.append(
            {
                "seqname": seqname,
                "source": src,
                "feature": feature,
                "start": str(int(start)),
                "end": str(int(end)),
                "score": aux,
                "strand": strand,
                "frame": aux,
                "attribute": attribute,
            }
        )

    # TODO: vectorize this
    for row in tqdm(classification_df.itertuples(), total=len(classification_df)):
        transcript = row.isoform
        strand = row.strand
        length = row.length

        _add(
            transcript,
            source,
            "transcript",
            1,
            length,
            strand,
            f"ID={row.associated_transcript}; primary_class={row.structural_category}",
        )
        _add(
            transcript,
            source,
            "gene",
            1,
            length,
            strand,
            f"ID={row.associated_gene};Name={row.associated_gene};Desc={row.associated_gene}",
        )

        # Missing CDS coordinates arrive as NaN (pandas default) or a literal "NA".
        has_cds = row.CDS_start != "NA" and not pd.isnull(row.CDS_start)
        if has_cds:
            cds_start = int(row.CDS_start)
            cds_end = int(row.CDS_end)
            protein_desc = (
                f"ID=Protein_{transcript};Name=Protein_{transcript};"
                f"Desc=Protein_{transcript}"
            )
            prot_length = int(math.ceil((cds_end - cds_start - 1) / 3))
            _add(transcript, source, "CDS", cds_start, cds_end, strand, protein_desc)
            _add(transcript, source, "protein", 1, prot_length, strand, protein_desc)

        _add(transcript, source, "genomic", 1, 1, strand, f"Chr={row.chrom}")

        if has_cds:
            _add(
                transcript,
                attr_source,
                "3UTR_Length",
                cds_end + 1,
                length,
                strand,
                "ID=3UTR_Length;Name=3UTR_Length;Desc=3UTR_Length",
            )
            # NOTE(review): the 5'UTR end is set to CDS_start itself, so it
            # shares one base with the CDS record; confirm whether CDS_start - 1
            # was intended.
            _add(
                transcript,
                attr_source,
                "5UTR_Length",
                1,
                cds_start,
                strand,
                "ID=5UTR_Length;Name=5UTR_Length;Desc=5UTR_Length",
            )
            _add(
                transcript,
                attr_source,
                "CDS",
                cds_start,
                cds_end,
                strand,
                "ID=CDS;Name=CDS;Desc=CDS",
            )
            _add(
                transcript,
                attr_source,
                "polyA_Site",
                length,
                length,
                strand,
                "ID=polyA_Site;Name=polyA_Site;Desc=polyA_Site",
            )

    logger.debug(f"reading exon file {exons_gtf}")
    exons_df = gtfparse.parse_gtf(exons_gtf)
    for row in tqdm(exons_df.itertuples(), total=len(exons_df)):
        if row.feature == "transcript":  # only sub-transcript features are wanted
            continue
        _add(
            row.transcript_id,
            source,
            row.feature,
            row.start,
            row.end,
            row.strand,
            f"Chr={str(row.seqname)}",
        )

    logger.debug(f"reading junctions file {junctions}")
    junct_df = pd.read_csv(junctions, delimiter="\t")
    for row in tqdm(junct_df.itertuples(), total=len(junct_df)):
        _add(
            row.isoform,
            source,
            "splice_junction",
            row.genomic_start_coord,
            row.genomic_end_coord,
            row.strand,
            f"ID={row.junction_number}_{row.canonical};Chr={row.chrom}",
        )

    logger.debug(f"number of records: {len(records)}")
    logger.debug("converting records to dataframe")
    results_df = pd.DataFrame(records, columns=gtf_columns)
    logger.debug(f"results_df shape: {results_df.shape}")
    return results_df


In [5]:
def createGTFFromSqanti(file_exons, file_trans, file_junct, filename):
    """Write a tappAS-style GTF built from SQANTI3 outputs and return lookup dicts.

    NOTE(review): this redefines the DataFrame-returning ``createGTFFromSqanti``
    from the previous cell with a different signature; whichever cell ran last
    is the one bound in the kernel.

    The classification header is checked against the module-level
    ``CLASS_COLUMN_NAME`` / ``CLASS_COLUMN_USED`` pairs *before* ``filename`` is
    opened, so a malformed input does not truncate an existing output file.

    Parameters
    ----------
    file_exons : str
        Path to the exon GTF. Only lines with exactly 9 tab-separated fields
        are used, and ``transcript`` features are skipped. The transcript ID is
        taken as the first double-quoted value in the attribute column.
    file_trans : str
        Path to the tab-separated classification file (first line is a header).
    file_junct : str
        Path to the tab-separated junctions file (first line is a header).
    filename : str
        Output GTF path; it is overwritten.

    Returns
    -------
    tuple of dict
        ``(dc_exons, dc_coding, dc_gene, dc_SQstrand)``, each keyed by isoform ID:

        * ``dc_exons``: list of ``[start, end]`` int pairs, in file order.
        * ``dc_coding``: flat list of ``CDS_start, CDS_end, ORF_length``
          strings, one triple per classification row for that isoform.
        * ``dc_gene``: flat list of ``gene, structural_category,
          associated_transcript``. Associated IDs starting with "ENS" are
          truncated at the first ".".
        * ``dc_SQstrand``: strand of the isoform.

    Raises
    ------
    SystemExit
        With status 1 if the classification header does not match the
        expected column layout.
    """
    source = "tappAS"
    source_attr = "TranscriptAttributes"
    aux = "."

    dc_coding = {}
    dc_gene = {}
    dc_SQstrand = {}
    dc_exons = {}

    def _gtf_line(seqname, src, feature, start, end, strand, desc):
        """Format one GTF record: 9 tab-separated columns plus a newline."""
        return (
            "\t".join(
                [seqname, src, feature, str(start), str(end), aux, strand, aux, desc]
            )
            + "\n"
        )

    # --- validate the classification header before touching the output file
    with open(file_trans) as trans_fh:
        header = next(trans_fh, "")  # empty file -> empty header -> fails below
    header_fields = header.rstrip("\r\n").split("\t")
    for column, position in zip(CLASS_COLUMN_NAME, CLASS_COLUMN_USED):
        found = header_fields[position] if position < len(header_fields) else ""
        # Substring match, so extra characters around the header name are tolerated.
        if column not in found:
            logging.error(
                f"File classification does not have the correct structure. "
                f"The column '{column}' is not in the position {position} "
                f"in the classification file. We have found the column "
                f"'{found}'."
            )
            sys.exit(1)

    with open(filename, "w+") as res:
        # --- transcript, gene, CDS/protein, genomic and TranscriptAttributes
        with open(file_trans) as trans_fh:
            next(trans_fh, None)  # header already validated above
            for line in trans_fh:
                # Strip the line ending so it never leaks into the last column.
                fields = line.rstrip("\r\n").split("\t")
                transcript = fields[0]
                chrom = fields[1]
                strand = fields[2]
                length = fields[3]
                category = fields[5]
                gene = fields[6]
                associated = fields[7]
                orf = fields[28]
                cds_start = fields[30]
                cds_end = fields[31]
                has_cds = cds_start != "NA"

                dc_SQstrand[transcript] = strand

                res.write(
                    _gtf_line(
                        transcript,
                        source,
                        "transcript",
                        "1",
                        length,
                        strand,
                        f"ID={associated}; primary_class={category}",
                    )
                )
                res.write(
                    _gtf_line(
                        transcript,
                        source,
                        "gene",
                        "1",
                        length,
                        strand,
                        f"ID={gene}; Name={gene}; Desc={gene}",
                    )
                )

                protein_desc = (
                    f"ID=Protein_{transcript}; Name=Protein_{transcript}; "
                    f"Desc=Protein_{transcript}"
                )
                if has_cds:
                    res.write(
                        _gtf_line(
                            transcript, source, "CDS", cds_start, cds_end, strand,
                            protein_desc,
                        )
                    )
                    prot_length = int(math.ceil((int(cds_end) - int(cds_start) - 1) / 3))
                    res.write(
                        _gtf_line(
                            transcript, source, "protein", "1", prot_length, strand,
                            protein_desc,
                        )
                    )
                else:
                    res.write(
                        _gtf_line(
                            transcript, source, "CDS", ".", ".", strand, protein_desc
                        )
                    )

                # ENSMUS213123.1 -> ENSMUS213123
                associated_key = (
                    associated.split(".")[0]
                    if associated.startswith("ENS")
                    else associated
                )
                dc_gene.setdefault(transcript, []).extend(
                    [gene, category, associated_key]
                )
                dc_coding.setdefault(transcript, []).extend([cds_start, cds_end, orf])

                res.write(
                    _gtf_line(
                        transcript, source, "genomic", "1", "1", strand, f"Chr={chrom}"
                    )
                )

                if has_cds:
                    res.write(
                        _gtf_line(
                            transcript,
                            source_attr,
                            "3UTR_Length",
                            int(cds_end) + 1,
                            length,
                            strand,
                            "ID=3UTR_Length; Name=3UTR_Length; Desc=3UTR_Length",
                        )
                    )
                    # NOTE(review): the 5'UTR end is set to CDS_start itself, so it
                    # shares one base with the CDS record; confirm whether
                    # CDS_start - 1 was intended.
                    res.write(
                        _gtf_line(
                            transcript,
                            source_attr,
                            "5UTR_Length",
                            1,
                            int(cds_start),
                            strand,
                            "ID=5UTR_Length; Name=5UTR_Length; Desc=5UTR_Length",
                        )
                    )
                    res.write(
                        _gtf_line(
                            transcript,
                            source_attr,
                            "CDS",
                            cds_start,
                            cds_end,
                            strand,
                            "ID=CDS; Name=CDS; Desc=CDS",
                        )
                    )
                    res.write(
                        _gtf_line(
                            transcript,
                            source_attr,
                            "polyA_Site",
                            length,
                            length,
                            strand,
                            "ID=polyA_Site; Name=polyA_Site; Desc=polyA_Site",
                        )
                    )

        # --- exons
        with open(file_exons) as exons_fh:
            for line in exons_fh:
                fields = line.rstrip("\r\n").split("\t")
                if len(fields) != 9:
                    logging.error(
                        "File corrected doesn't have the correct number of columns (9)."
                    )
                    continue
                transcript = fields[8].split('"')[1].strip()
                feature = fields[2]
                if feature == "transcript":  # only sub-transcript features are wanted
                    continue
                start = int(fields[3])
                end = int(fields[4])
                dc_exons.setdefault(transcript, []).append([start, end])
                res.write(
                    _gtf_line(
                        transcript, source, feature, start, end, fields[6],
                        f"Chr={fields[0]}",
                    )
                )

        # --- splice junctions
        with open(file_junct) as junct_fh:
            next(junct_fh, None)  # skip header
            for line in junct_fh:
                fields = line.rstrip("\r\n").split("\t")
                # The junctions file can have a variable number of columns;
                # only columns 0-14 are relied on.
                res.write(
                    _gtf_line(
                        fields[0],
                        source,
                        "splice_junction",
                        fields[4],
                        fields[5],
                        fields[2],
                        f"ID={fields[3]}_{fields[14]}; Chr={fields[1]}",
                    )
                )

    return dc_exons, dc_coding, dc_gene, dc_SQstrand


In [6]:
def readGFFandGetData(filenameMod):
    """Parse a tappAS-style GFF3 annotation and group its raw lines by transcript.

    Each returned dict maps column 1 (the transcript ID) to the list of raw
    lines routed to that category, in file order. Line endings are kept.
    Lines starting with "#" and lines without exactly nine tab-separated
    fields are ignored.

    Routing (column 2 = source, column 3 = feature):

    * source ``tappAS``:
        - ``transcript`` / ``gene`` / ``CDS`` -> ``dcTrans``
        - ``exon`` -> ``dcExon``
        - ``genomic`` -> ``dcGenomic``
        - ``splice_junction`` -> ``dcSpliceJunctions``
        - ``protein`` -> ``dcProt``
        - any other feature is dropped.
    * source ``TranscriptAttributes`` -> ``dcTranscriptAttributes``.
    * any other source is routed on the last space-separated token of the
      attribute column (e.g. ``PosType=N``):
        - ending in "T" -> ``dcTransFeatures``
        - ending in "P", "G" or "N" -> ``dcProtFeatures``
        - otherwise the line is dropped.

    Returns
    -------
    tuple of dict
        ``(dcTrans, dcExon, dcTransFeatures, dcGenomic, dcSpliceJunctions,
        dcProt, dcProtFeatures, dcTranscriptAttributes)``
    """
    dcTrans = {}
    dcExon = {}
    dcTransFeatures = {}
    dcGenomic = {}
    dcSpliceJunctions = {}
    dcProt = {}
    dcProtFeatures = {}
    dcTranscriptAttributes = {}

    # Feature (column 3) -> destination dict, for lines whose source is "tappAS".
    tappas_targets = {
        "transcript": dcTrans,
        "gene": dcTrans,
        "CDS": dcTrans,
        "exon": dcExon,
        "genomic": dcGenomic,
        "splice_junction": dcSpliceJunctions,
        "protein": dcProt,
    }

    with open(filenameMod, "r") as f:
        # No header line: every non-comment line is a record.
        for line in f:
            if line.startswith("#"):
                continue
            fields = line.split("\t")
            if len(fields) != 9:
                continue

            transcript = fields[0]
            source = fields[1]
            if source == "tappAS":
                target = tappas_targets.get(fields[2])
            elif source == "TranscriptAttributes":
                target = dcTranscriptAttributes
            else:
                # Strip the line ending first so that a final line without
                # "\n" (or a CRLF line) is still classified.
                pos_tag = fields[8].rstrip("\r\n").split(" ")[-1]
                if pos_tag.endswith("T"):
                    target = dcTransFeatures
                elif pos_tag.endswith(("P", "G", "N")):
                    target = dcProtFeatures
                else:
                    target = None

            if target is not None:
                # setdefault + append avoids the quadratic list re-concatenation.
                target.setdefault(transcript, []).append(line)

    return (
        dcTrans,
        dcExon,
        dcTransFeatures,
        dcGenomic,
        dcSpliceJunctions,
        dcProt,
        dcProtFeatures,
        dcTranscriptAttributes,
    )

In [7]:
# Tuple of per-transcript dicts, in order: dcTrans, dcExon, dcTransFeatures,
# dcGenomic, dcSpliceJunctions, dcProt, dcProtFeatures, dcTranscriptAttributes.
annot_tpl = readGFFandGetData(filenameMod=annotation_file)

In [12]:
# Transcripts with tappAS transcript/gene/CDS lines (dcTrans, tuple index 0).
len(annot_tpl[0])

188667

In [13]:
# Transcripts with tappAS exon lines (dcExon, tuple index 1).
len(annot_tpl[1])

188667

In [8]:
# Transcripts with TranscriptAttributes lines (dcTranscriptAttributes, last index).
len(annot_tpl[-1])

0

In [9]:
# Transcripts with transcript-level feature lines (dcTransFeatures, tuple index 2).
len(annot_tpl[2])

124947

In [10]:
# Transcripts with protein/gene/N-tagged feature lines (dcProtFeatures, tuple index 6).
len(annot_tpl[6])

79641

In [12]:
# Raw dcProtFeatures lines for one example transcript.
annot_tpl[6]['ENST00000618881']

['ENST00000618881\tGeneOntology\tP\t.\t.\t.\t.\t.\tID=GO:0006810; Name=transport; PosType=N\n',
 'ENST00000618881\tGeneOntology\tC\t.\t.\t.\t.\t.\tID=GO:0005622; Name=intracellular; PosType=N\n',
 'ENST00000618881\tGeneOntology\tC\t.\t.\t.\t.\t.\tID=GO:0005737; Name=cytoplasm; PosType=N\n',
 'ENST00000618881\tGeneOntology\tP\t.\t.\t.\t.\t.\tID=GO:0006406; Name=mRNA export from nucleus; PosType=N\n',
 'ENST00000618881\tGeneOntology\tC\t.\t.\t.\t.\t.\tID=GO:0005634; Name=nucleus; PosType=N\n',
 'ENST00000618881\tGeneOntology\tF\t.\t.\t.\t.\t.\tID=GO:0003723; Name=RNA binding; PosType=N\n',
 'ENST00000618881\tGeneOntology\tF\t.\t.\t.\t.\t.\tID=GO:0005515; Name=protein binding; PosType=N\n',
 'ENST00000618881\tGeneOntology\tF\t.\t.\t.\t.\t.\tID=GO:0003676; Name=nucleic acid binding; PosType=N\n',
 'ENST00000618881\tReactome\tpathway\t.\t.\t.\t.\t.\tID=R-HSA-159236; Name=Transport of Mature mRNA derived from an Intron-Containing Transcript; PosType=N\n',
 'ENST00000618881\tPFAM\tDOMAIN\t304