In [1]:
# main imports
import os
import sys
import csv
import glob
import math
import shutil
import random
import importlib
import subprocess

import numpy as np
import pandas as pd

# notebook display: show every column, truncate long cell text at 50 characters
pd.set_option("display.max_columns", None, "display.max_colwidth", 50)

In [2]:
# adjust CDS coordinates to include the stop codon

In [3]:
# Gffread solution:
# --adj-stop flag: gffread -g ~/genomicData/hg38/hg38_p12_ucsc.fa --adj-stop -T -o tmp.adj.gtf tmp.gtf
# for selenoproteins it will truncate the CDS at the selenocysteine, since its codon (TGA) is read as a stop codon
# what if the stop codon is past the 3' end of the transcript? will gffread extend it nonetheless?
#    Answer: partially. Gffread seems to have a bug when extending the coordinates.
#    The CDS extends to the correct coordinate, but the exon-chain remains intact, leading to the CDS extending past the exon end

In [4]:
# custom solution
# implement a python script to do this
# extract protein-fasta file from the gtf/gff with gffread
# load the CDS sequences from the fasta, recording whether each one ends in a stop codon
# iterate over the gtf/gff looking up the stop codon status in the proteins dictionary
# if missing stop - extend by 3 bases making sure to respect the exon coordinates
#     can use the cut/intersect/etc methods from definitions
# at the very end - rerun gffread to re-assign phase information

# Additional considerations
# for GTF need to process transcript,exon,cds at once, making sure if exon is extended so is the transcript
# for GFF need to also consider gene records which might be tricky
#    - alternatively we re-generate GFF after applying fixes to gtf
# What if the next 3 bases are not a stop codon? What should we do then?
#    After the initial run over the file applying fixes, we can extract proteins again and check. For anything that still has no stop codon - we should revert to the original version of the transcript instead and flag it as unchanged

In [2]:
# Auto-reload the helper module while developing: with `%autoreload 1`, IPython only
# reloads modules that were registered through `%aimport` (here: `definitions`).
%load_ext autoreload
%autoreload 1

# make the directory containing `definitions` importable
# NOTE(review): absolute, user-specific path - adjust for other environments
sys.path.insert(0, "/ccb/salz4-4/avaraby/orfanage/soft")
%aimport definitions

In [3]:
# data
# NOTE(review): absolute, user-specific paths - adjust for your environment
ref_fasta = "/home/avaraby1/genomes/human/chm13/chm13v2.0.fa"  # genome passed to gffread -g
input_gtf = "/ccb/salz8-1/avaraby/chess_maintenance_scripts/tmp/chess3.0.1.CHM13.fix_dup_gid.gtf"
outbase = "../tmp/chess3.0.1.CHM13.fix_dup_gid.adj_stop"  # prefix of every file written below

# gffread writes directly into outbase's directory; make sure it exists so those calls cannot
# fail just because the output location is missing
_outdir = os.path.dirname(outbase)
if _outdir:
    os.makedirs(_outdir, exist_ok=True)

# Every later cell parses GTF-style attributes (`key "value";`) and writes GTF,
# so refuse GFF3 input up front instead of failing later with an obscure parse error.
is_gff = definitions.gtf_or_gff(input_gtf)=="gff"
assert not is_gff, "input must be GTF (convert GFF3 with `gffread -T` first)"

In [4]:
# run gffread to extract the spliced CDS nucleotide sequence (-x) of each input transcript.
# check_call raises CalledProcessError on a non-zero exit instead of silently letting the
# next cells read a missing or stale FASTA; on success it returns 0 (displayed as before).
cmd = ["gffread","-g",ref_fasta,"-x",outbase+".cds_nt.fa",input_gtf]
subprocess.check_call(cmd)

0

In [5]:
# Load the pass-1 CDS sequences (transcript id -> nucleotide sequence) and reduce each to a flag:
# True when the CDS already ends in a stop codon (TAA/TAG/TGA), False otherwise.
# Upper-case the last codon first in case the genome FASTA is soft-masked (lower-case bases).
stop_status = {
    tid: seq[-3:].upper() in ("TAA", "TAG", "TGA")
    for tid, seq in definitions.load_fasta_dict(outbase + ".cds_nt.fa").items()
}

In [6]:
# load exon and cds chains

def _gtf_attribute(attributes, key):
    """Extract the quoted value of GTF attribute ``key`` from each attribute string.

    The match is anchored at the start of the string or right after a ';', so that e.g.
    ``ref_gene_id "X"`` is never read as ``gene_id``. Rows lacking the attribute get NaN.
    """
    return attributes.str.extract(r'(?:^|;)\s*' + key + r' "([^"]*)"', expand=False)


# load gtf to get the attributes (keep only the transcript lines)
df = definitions.load_gtf(input_gtf)
df = df[df["type"] == "transcript"].reset_index(drop=True)
df["tid"] = _gtf_attribute(df["attributes"], "transcript_id")
df["gid"] = _gtf_attribute(df["attributes"], "gene_id")
df["gene_name"] = _gtf_attribute(df["attributes"], "gene_name")  # NaN where absent
assert df["tid"].notna().all(), "transcript line without a transcript_id attribute"

# load and merge exon chains; validate="one_to_one" turns duplicated transcript ids into an
# error instead of a silent row explosion
exons = definitions.get_chains(input_gtf, feature_type="exon", coords=False, phase=False)
df = df.merge(exons[["tid", "chain"]], on="tid", how="outer", indicator="merge", validate="one_to_one")
df = df.rename(columns={"chain": "exon_chain"})
assert (df["merge"] == "both").all(), "invalid merge of exons"
df = df.drop(columns="merge")

# load and merge cds chains
cds = definitions.get_chains(input_gtf, feature_type="CDS", coords=True, phase=False)
df = df.merge(cds[["tid", "chain", "has_cds"]], on="tid", how="outer", indicator="merge", validate="one_to_one")
df = df.rename(columns={"chain": "cds_chain"})
assert (df["merge"] == "both").all(), "invalid merge of cds"
df = df.drop(columns="merge")

# add the pass-1 stop-codon status (True/False; NaN for ids absent from the CDS FASTA)
df["status"] = df["tid"].map(stop_status)

df.head()

Unnamed: 0,seqid,source,type,start,end,score,strand,phase,attributes,tid,gid,gene_name,exon_chain,cds_chain,has_cds,status
0,chr19,Liftoff,transcript,4460,6999,.,+,.,"transcript_id ""CHS.1.1""; gene_id ""CHS.1""; gene...",CHS.1.1,CHS.1,CHS.1,"((4460, 4813), (5199, 5307), (5807, 6999))",(),0,
1,chr1,Liftoff,transcript,864517,868933,.,-,.,"transcript_id ""CHS.100.1""; gene_id ""CHS.100""; ...",CHS.100.1,CHS.100,CHS.100,"((864517, 865022), (868548, 868933))",(),0,
2,chr1,Liftoff,transcript,864517,868933,.,-,.,"transcript_id ""CHS.100.2""; gene_id ""CHS.100""; ...",CHS.100.2,CHS.100,CHS.100,"((864517, 865022), (868538, 868933))",(),0,
3,chr1,Liftoff,transcript,864517,868933,.,-,.,"transcript_id ""CHS.100.3""; gene_id ""CHS.100""; ...",CHS.100.3,CHS.100,CHS.100,"((864517, 865314), (868538, 868933))",(),0,
4,chr1,CHESS,transcript,864517,869839,.,-,.,"transcript_id ""CHS.100.8""; gene_id ""CHS.100""; ...",CHS.100.8,CHS.100,CHS.100,"((864517, 865314), (868538, 868944), (869739, ...",(),0,


In [7]:
def get_n(chain, n, forward=True):
    """Return the first (``forward=True``) or last ``n`` positions of an interval chain.

    ``chain`` is a sequence of ``(start, end)`` intervals with inclusive ends, in ascending
    order. The returned list keeps that order. The interval at the cut point is trimmed so
    that the total length is exactly ``n``. If the chain holds fewer than ``n`` positions,
    the whole chain is returned. ``n <= 0`` returns ``[]``.

    Args:
        chain: intervals to take positions from.
        n: number of positions wanted.
        forward: take from the start of the chain if True, from its end otherwise.

    Returns:
        list of ``(start, end)`` tuples.
    """
    if n <= 0:
        # without this guard the boundary interval would be trimmed to an inverted
        # (start, start + n - 1) or (end - n + 1, end) interval
        return []

    res = []
    count = 0
    if forward:
        for start, end in chain:
            ilen = end - start + 1
            remaining = n - count
            if ilen < remaining:
                res.append((start, end))
                count += ilen
            else:
                res.append((start, start + remaining - 1))
                break
    else:
        for start, end in reversed(chain):
            ilen = end - start + 1
            remaining = n - count
            if ilen < remaining:
                res.append((start, end))
                count += ilen
            else:
                res.append((end - remaining + 1, end))
                break
        # intervals were collected from the end backwards; restore ascending order
        res.reverse()
    return res

def extend_intervals(c1: list, c2: list, num_coords: int = 3, forward: bool = True) -> tuple:
    """Grow interval chain ``c2`` by ``num_coords`` positions past its 3' end, walking along ``c1``.

    Used to append the stop codon to a CDS chain (``c2``) using the transcript's exon chain
    (``c1``), so the added positions follow the exon structure across introns. When ``c1``
    has fewer than ``num_coords`` positions beyond the end of ``c2``, the missing positions
    are added by lengthening the outermost interval of BOTH chains. In that case the
    transcript itself is extended as well.

    Intervals are (start, end) pairs with inclusive ends, listed in ascending genomic order.
    ``c2`` is assumed to lie within ``c1``.

    Args:
        c1: exon chain. Its outermost interval may be replaced in place.
        c2: CDS chain. On the forward branch, the caller's list is also extended in place
            before the name is re-bound to the merged result.
        num_coords: number of positions to add (callers pass 3 for a stop codon).
        forward: True to extend past ``c2[-1][1]`` ('+' strand); False to extend before
            ``c2[0][0]`` ('-' strand).

    Returns:
        Tuple ``(c1, c2)``: the (possibly extended) exon chain and the extended CDS chain.

    Raises:
        Exception: if more than ``num_coords`` positions were extracted. This is not
            expected, since ``get_n`` never returns more than requested.
    """
    if forward:
        # c1_cut = part of c1 from the last CDS base (c2[-1][1]) to the transcript end
        # NOTE(review): assumes definitions.cut is inclusive at both ends - TODO confirm
        c1_cut = definitions.cut(c1,c2[-1][1],c1[-1][1])
        # drop the first position of c1_cut: it is the last CDS base, already covered by c2
        if c1_cut[0][0]==c1_cut[0][1]:
            c1_cut.pop(0)
        else:
            tmp = list(c1_cut[0])
            tmp[0]+=1
            c1_cut[0] = tuple(tmp)

        # extract first N positions from the cut chain (those right after the CDS)
        extracted_positions = get_n(c1_cut,num_coords)

        # append these to the CDS chain; merge(..., True) presumably joins abutting
        # intervals so the extension fuses with the last CDS interval - verify in definitions
        c2.extend(extracted_positions)
        c2 = definitions.merge(c2,True)
        c2 = [tuple(x) for x in c2]
        
        # definitions.clen presumably returns the total number of positions in a chain
        if definitions.clen(extracted_positions)==num_coords:
            pass
        elif definitions.clen(extracted_positions)<num_coords: # not long enough? Add missing pieces
            # ran off the 3' end of the transcript: push the end of the last CDS interval
            # and of the last exon one base further, once per missing position
            for i in range(num_coords-definitions.clen(extracted_positions)):
                tmp = list(c2[-1])
                tmp[1]+=1
                c2[-1] = tuple(tmp)

                tmp = list(c1[-1])
                tmp[1]+=1
                c1[-1] = tuple(tmp)
        else:
            raise Exception(c1,c2,num_coords)

        return c1,c2

    else:
        # c1_cut = part of c1 from the transcript start (c1[0][0]) to the first CDS base (c2[0][0])
        c1_cut = definitions.cut(c1,c1[0][0],c2[0][0])
        # drop the last position of c1_cut: it is the first CDS base, already covered by c2
        if c1_cut[-1][0]==c1_cut[-1][1]:
            c1_cut.pop()
        else:
            tmp = list(c1_cut[-1])
            tmp[1]-=1
            c1_cut[-1] = tuple(tmp)

        # extract last N positions from the cut chain (those right before the CDS)
        extracted_positions = get_n(c1_cut,num_coords,False)

        # prepend these to the CDS chain and merge abutting intervals
        c2 = extracted_positions+c2
        c2 = definitions.merge(c2,True)
        c2 = [tuple(x) for x in c2]
        
        if definitions.clen(extracted_positions)==num_coords:
            pass
        elif definitions.clen(extracted_positions)<num_coords: # not long enough? Add missing pieces
            # ran off the left-most (3' on the '-' strand) end of the transcript: move the start
            # of the first CDS interval and of the first exon one base left per missing position
            for i in range(num_coords-definitions.clen(extracted_positions)):
                tmp = list(c2[0])
                tmp[0]-=1
                c2[0] = tuple(tmp)

                tmp = list(c1[0])
                tmp[0]-=1
                c1[0] = tuple(tmp)
        else:
            raise Exception(c1,c2,num_coords)

        return c1,c2

In [8]:
# process line by line - remember strands
def fix_coordinates(row):
    """Pass-1 stop-codon extension for one transcript record.

    Returns ``[exon_chain, cds_chain]``. The original chains are returned untouched when any
    of the following holds:
      - the transcript has no CDS;
      - the CDS length is not a multiple of 3;
      - the exon chain cut to the CDS span does not equal the CDS chain;
      - the CDS already ends in a stop codon (``status`` is True);
      - the first exon starts before position 4 ("too close to start").
    Otherwise the CDS is extended by 3 positions past its 3' end with ``extend_intervals``,
    which may lengthen the exon chain too.
    """
    exon_chain = row["exon_chain"]
    cds_chain = row["cds_chain"]

    leave_as_is = (
        not row["has_cds"]
        or not definitions.clen(cds_chain) % 3 == 0  # invalid CDS
        or not definitions.cut(exon_chain, cds_chain[0][0], cds_chain[-1][1]) == list(cds_chain)
        or row["status"] == True  # already has stop
        or exon_chain[0][0] < 4  # too close to start
    )
    if leave_as_is:
        return [exon_chain, cds_chain]

    on_plus_strand = row["strand"] == "+"
    new_exons, new_cds = extend_intervals(list(exon_chain), list(cds_chain), 3, on_plus_strand)
    return [new_exons, new_cds]

In [9]:
# now we need to iterate over the dataframe and write out the transcripts for the initial validation

def to_gtf(row,outFP):
    """Write one transcript to an open GTF handle: transcript line, then exons, then CDS.

    Coordinates come from ``new_exon_chain`` / ``new_cds_chain``. The transcript line keeps its
    original attributes and appends two flags, ``exon_adjustment_status`` and
    ``cds_adjustment_status``. Each flag is "1" if the corresponding new chain differs from the
    original chain, else "0". Exon and CDS lines carry only the transcript_id, gene_id and
    gene_name attributes. Every line reuses the transcript's score, strand and phase columns;
    CDS phase is expected to be re-assigned by a later gffread pass.

    Args:
        row: one transcript record (e.g. a DataFrame row) with seqid, source, type, score,
            strand, phase, attributes, exon_chain, cds_chain, new_exon_chain, new_cds_chain.
        outFP: writable text handle.
    """
    def as_tuples(chain):
        return [tuple(iv) for iv in chain]

    def emit(feature, start, end, attrs):
        fields = [row["seqid"], row["source"], feature, str(start), str(end),
                  row["score"], row["strand"], row["phase"], attrs]
        outFP.write("\t".join(fields) + "\n")

    exon_flag = int(as_tuples(row["exon_chain"]) != as_tuples(row["new_exon_chain"]))
    cds_flag = int(as_tuples(row["cds_chain"]) != as_tuples(row["new_cds_chain"]))

    new_exons = row["new_exon_chain"]
    base_attrs = row["attributes"].rstrip().rstrip(";").rstrip()
    transcript_attrs = f'{base_attrs}; exon_adjustment_status "{exon_flag}"; cds_adjustment_status "{cds_flag}";'
    emit(row["type"], new_exons[0][0], new_exons[-1][1], transcript_attrs)

    # child features only keep the identifying attributes
    keep = {"transcript_id", "gene_id", "gene_name"}
    parsed = definitions.extract_attributes(row["attributes"])
    child_attrs = definitions.to_attribute_string({k: v for k, v in parsed.items() if k in keep})

    for feature, chain in (("exon", row["new_exon_chain"]), ("CDS", row["new_cds_chain"])):
        for iv in chain:
            emit(feature, iv[0], iv[1], child_attrs)

In [10]:
# pass 1: tentatively extend each eligible CDS by 3 nt past its 3' end
pass1_chains = df.apply(lambda tx: pd.Series(fix_coordinates(tx)), axis=1)
df[["new_exon_chain", "new_cds_chain"]] = pass1_chains
df.head()

Unnamed: 0,seqid,source,type,start,end,score,strand,phase,attributes,tid,gid,gene_name,exon_chain,cds_chain,has_cds,status,new_exon_chain,new_cds_chain
0,chr19,Liftoff,transcript,4460,6999,.,+,.,"transcript_id ""CHS.1.1""; gene_id ""CHS.1""; gene...",CHS.1.1,CHS.1,CHS.1,"((4460, 4813), (5199, 5307), (5807, 6999))",(),0,,"((4460, 4813), (5199, 5307), (5807, 6999))",()
1,chr1,Liftoff,transcript,864517,868933,.,-,.,"transcript_id ""CHS.100.1""; gene_id ""CHS.100""; ...",CHS.100.1,CHS.100,CHS.100,"((864517, 865022), (868548, 868933))",(),0,,"((864517, 865022), (868548, 868933))",()
2,chr1,Liftoff,transcript,864517,868933,.,-,.,"transcript_id ""CHS.100.2""; gene_id ""CHS.100""; ...",CHS.100.2,CHS.100,CHS.100,"((864517, 865022), (868538, 868933))",(),0,,"((864517, 865022), (868538, 868933))",()
3,chr1,Liftoff,transcript,864517,868933,.,-,.,"transcript_id ""CHS.100.3""; gene_id ""CHS.100""; ...",CHS.100.3,CHS.100,CHS.100,"((864517, 865314), (868538, 868933))",(),0,,"((864517, 865314), (868538, 868933))",()
4,chr1,CHESS,transcript,864517,869839,.,-,.,"transcript_id ""CHS.100.8""; gene_id ""CHS.100""; ...",CHS.100.8,CHS.100,CHS.100,"((864517, 865314), (868538, 868944), (869739, ...",(),0,,"((864517, 865314), (868538, 868944), (869739, ...",()


In [11]:
# Write the pass-1 transcripts. `with` closes the handle even if to_gtf raises. A plain loop
# replaces df.apply(... pd.Series(None) ...), which built a throwaway frame and, on older
# pandas, could call the function twice on the first row (writing that transcript twice).
with open(outbase+".p1.gtf","w") as outFP:
    for _row_idx, tx_row in df.iterrows():
        to_gtf(tx_row,outFP)

In [12]:
# assign phase with gffread
# re-write the pass-1 file through gffread (-T: GTF output) so CDS phases are re-assigned;
# check_call raises CalledProcessError if gffread fails instead of continuing silently
cmd = ["gffread","-o",outbase+".p2.gtf","-T",outbase+".p1.gtf"]
subprocess.check_call(cmd)

# extract the (now extended) spliced CDS nucleotide sequences again (-x) for validation
cmd = ["gffread","-g",ref_fasta,"-x",outbase+".p2.cds_nt.fa",outbase+".p2.gtf"]
subprocess.check_call(cmd)

0

In [13]:
# Load the pass-2 CDS sequences (transcript id -> nucleotide sequence) and reduce each to a flag:
# True when the (possibly extended) CDS ends in a stop codon (TAA/TAG/TGA), False otherwise.
# Upper-case the last codon first in case the genome FASTA is soft-masked (lower-case bases).
stop_status = {
    tid: seq[-3:].upper() in ("TAA", "TAG", "TGA")
    for tid, seq in definitions.load_fasta_dict(outbase + ".p2.cds_nt.fa").items()
}

In [14]:
# attach the pass-2 result to each transcript (True/False; NaN for ids absent from the
# pass-2 CDS FASTA), then redo the correction keeping only the validated extensions
df = df.assign(status_p2=df["tid"].map(stop_status))


# process line by line - remember strands
def fix_coordinates_p2(row):
    """Second-pass stop-codon extension: keep the pass-1 extension only where it was validated.

    Returns ``[exon_chain, cds_chain]`` for the transcript. The extended chains (via
    ``extend_intervals``) are returned only when both conditions hold:
      - the transcript qualified for extension in pass 1;
      - its extended CDS was confirmed to end in a stop codon (``status_p2`` is True).
    In every other case the original chains are returned. This includes ``status_p2`` being
    False or missing (NaN, i.e. absent from the pass-2 CDS FASTA), so no unverified
    extension reaches the final output.
    """
    unchanged = [row["exon_chain"], row["cds_chain"]]
    # cheap status checks first; the definitions-based checks below are more expensive
    if not row["has_cds"]:
        return unchanged
    if row["status"] == True:  # already has a stop codon
        return unchanged
    if row["status_p2"] != True:  # extension not confirmed (False or NaN) - keep original
        return unchanged
    if not definitions.clen(row["cds_chain"]) % 3 == 0:  # invalid CDS
        return unchanged
    if not definitions.cut(row["exon_chain"], row["cds_chain"][0][0], row["cds_chain"][-1][1]) == list(row["cds_chain"]):
        # the exon chain cannot be cut to the CDS chain
        return unchanged
    if row["exon_chain"][0][0] < 4:  # too close to start
        return unchanged

    # otherwise re-apply the same extension as pass 1
    c1, c2 = extend_intervals(list(row["exon_chain"]), list(row["cds_chain"]), 3, row["strand"] == "+")
    return [c1, c2]

df[["new_exon_chain", "new_cds_chain"]] = df.apply(lambda row: pd.Series(fix_coordinates_p2(row)), axis=1)

# Write the final transcripts. `with` closes the handle even if to_gtf raises. A plain loop
# replaces df.apply(... pd.Series(None) ...), which built a throwaway frame and, on older
# pandas, could call the function twice on the first row (writing that transcript twice).
with open(outbase+".corrected.gtf","w") as outFP:
    for _row_idx, tx_row in df.iterrows():
        to_gtf(tx_row,outFP)

In [15]:
# Normalise the corrected file with gffread:
#   -T                 GTF output
#   -F                 keep all attributes (including the *_adjustment_status flags)
#   --keep-exon-attrs  keep exon/CDS attributes as written
# check_call raises CalledProcessError on failure; on success it returns 0 (displayed as before).
cmd = ["gffread","-o",outbase+".corrected.gffread.gtf","-T","-F","--keep-exon-attrs",outbase+".corrected.gtf"]
subprocess.check_call(cmd)

0

In [16]:
# absolute location of the final, gffread-normalised GTF
os.path.abspath(f"{outbase}.corrected.gffread.gtf")

'/ccb/salz8-1/avaraby/chess_maintenance_scripts/tmp/chess3.0.1.CHM13.fix_dup_gid.adj_stop.corrected.gffread.gtf'