# Quantifying poly(A) sites as junction counts relative to 5' TE splice junction


The 'throw it out there' idea is to test whether this is reliable, i.e. whether just considering read counts across poly(A) sites (accounting for a minimum overhang) gives a usable quantification.

In [47]:
import os
import pandas as pd
import pysam
os.getcwd()

# Input paths: PAQR relative usages table, the file listing its sample columns,
# and the BAM files for one wild-type and one homozygous sample.
paqr_rel_usages_path = "../../data/relative_usages.filtered.tsv"
rel_usages_header_path = "../../data/relative_usages.header.out"
m323k_wt_1_bam_path = "/home/sam/cluster/TDP43_RNA/TDP_F210I_M323K/M323K/New_adult_brain/processed/M323K_WT_1/M323K_WT_1_unique_rg_fixed.bam"
m323k_hom_1_bam_path = "/home/sam/cluster/TDP43_RNA/TDP_F210I_M323K/M323K/New_adult_brain/processed/M323K_HOM_1/M323K_HOM_1_unique_rg_fixed.bam"

# NOTE(review): read_csv treats the first line as a header by default, and the
# column labels are overwritten further down. If the TSV has no header row
# (the names come from a separate file), the first record is silently lost -
# confirm, and pass header=None if so.
rel_usages = pd.read_csv(paqr_rel_usages_path,sep="\t")

# One sample name per line, in the order the sample columns appear in the table
with open(rel_usages_header_path) as inpt:
    
    sample_names = [line.rstrip() for line in inpt]


#print(rel_usages)
#print(sample_names)

#First 10 columns are the same in relative usages df
colnames_rel_usages = ["chr",
 "cluster_start",
 "cluster_end",
 "site_id",
 "score",
 "strand",
 "n_along_exon",
 "total_sites_on_exon", 
 "paqr_exon_id", 
 "gene_id" # this is technically transcript_id but for now I'll leave it...
]

#Rest of columns are samples names in order found in relative usages output df
# (must happen BEFORE sample_names is de-duplicated below, so the column count still matches)
colnames_rel_usages.extend(sample_names)

print(colnames_rel_usages)


#Because samples are paired according to config - can appear multiple times in df...
# From here on sample_names is an unordered set of unique names
sample_names = set(sample_names)
print(sample_names)

# NOTE(review): since a sample name can repeat, the resulting column labels are
# not guaranteed to be unique - selecting such a column by name returns a DataFrame.
rel_usages.columns = colnames_rel_usages
print(rel_usages)

['chr', 'cluster_start', 'cluster_end', 'site_id', 'score', 'strand', 'n_along_exon', 'total_sites_on_exon', 'paqr_exon_id', 'gene_id', 'M323K_WT_1', 'M323K_HOM_1', 'M323K_WT_2', 'M323K_HOM_2', 'M323K_WT_3', 'M323K_HOM_3', 'M323K_WT_4', 'M323K_HOM_4', 'M323K_WT_1', 'M323K_HOM_5']
{'M323K_HOM_4', 'M323K_WT_1', 'M323K_HOM_5', 'M323K_HOM_1', 'M323K_HOM_3', 'M323K_WT_2', 'M323K_WT_4', 'M323K_HOM_2', 'M323K_WT_3'}
        chr  cluster_start  cluster_end               site_id  score strand  \
0     chr16       20544628     20544662   chr16:+:20544651:TE      4      +   
1      chrX      133931787    133931831   chrX:+:133931813:TE      7      +   
2      chrX      133932295    133932296   chrX:+:133932296:TE      1      +   
3     chr17       75390823     75390843   chr17:+:75390843:TE      1      +   
4     chr17       75391754     75391779   chr17:+:75391769:TE      8      +   
...     ...            ...          ...                   ...    ...    ...   
3498   chr1      171388592    1713

In [48]:

#grouped = rel_usages.groupby("gene_id")
#print(grouped)

#for a, b in grouped:
#    print(a)
#    print(b["site_id"].to_list())

    
    
def polyASite_id_to_coordinate_tuple(paqr_df, site_id_colname = "site_id", group_col = "gene_id"):
    '''
    Build a nested dict of {<group_col value>: {site_id: (chr, start, end)}}.

    Site ids are expected to look like 'chr11:+:55110898:TE'. The first
    ':'-separated field is used as the chromosome and the third field as the
    position; the returned interval is (chr, position, position + 1).

    NOTE(review): the position is used exactly as it appears in the site id
    (no 1-based -> 0-based shift is applied) - confirm which convention the
    site id coordinate follows.

    paqr_df: DataFrame containing site_id_colname and group_col
    site_id_colname: name of column holding the poly(A) site id strings
    group_col: name of column to group sites by (outer dict keys)
    '''

    def _site_to_interval(site_id):
        # 'chr11:+:55110898:TE' -> ('chr11', 55110898, 55110899)
        fields = site_id.split(':')
        return (fields[0], int(fields[2]), int(fields[2]) + 1)

    return {
        group_key: {site_id: _site_to_interval(site_id)
                    for site_id in sub_df[site_id_colname].to_list()}
        for group_key, sub_df in paqr_df.groupby(group_col)
    }


# {transcript_id: {site_id: (chr, start, end)}} for every poly(A) site in the table
polya_jnc_coords = polyASite_id_to_coordinate_tuple(rel_usages)

# For testing, keep only the first 5 transcripts (in dict iteration order)
small_polya_jnc_coords = dict(list(polya_jnc_coords.items())[:5])
#print(small_polya_jnc_coords)

Now I've got my PAQR-inferred / PolyASite poly(A) sites, I want to try and count the number of reads that span these positions - 'junction reads'

As I will want to quantify these junctions relative to the splice junction at the 5' end of the terminal exon, I have to treat the poly(A) junction alignments as if they were a splice junction (where sequence is disjointly aligned to the genome)

STAR (and likely other splice-aware aligners) has a --alignSJDBoverhangMin parameter, which requires that a putative junction read 'overhangs' the junction by at least x nt (x = 3 by default) in order for it to be assigned to the junction


Reads aligned to genome at poly(A) sites aren't subject to this parameter, so junction counts would be inflated relative to the 5' splice junction


Using pysam, for each read crossing the poly(A) site, check that neither the start nor the end of the read alignment falls within 3nt of the poly(A) site coordinate:

abs(polya_site - alignment_start) > 3 and abs(alignment_end - polya_site) > 3 = valid poly(A) site junction read





In [49]:
#WT 1

# Open the WT 1 BAM; the handle stays open because later code in this cell reuses it
wt_1 = pysam.AlignmentFile(m323k_wt_1_bam_path, "rb")

# {transcript_id: {site_id: number of valid poly(A) junction reads}}
jnc_counts_dict = {}

for transcript, polya_dict in small_polya_jnc_coords.items():

    # junction read counts for every poly(A) site of this transcript
    tr_polya_dict = {}

    for polya_site, (chrom, site_start, site_end) in polya_dict.items():

        n_valid_reads = 0

        # every read whose alignment overlaps the poly(A) site window
        for read_entry in wt_1.fetch(chrom, site_start, site_end):

            # 0-based leftmost reference coordinate of the aligned sequence
            align_start = read_entry.reference_start
            # one past the last aligned reference position (None if not available)
            align_end = read_entry.reference_end

            # read must extend more than 3 nt on both sides of the site
            if abs(site_start - align_start) > 3 and abs(align_end - site_start) > 3:
                n_valid_reads += 1

        tr_polya_dict[polya_site] = n_valid_reads

    jnc_counts_dict[transcript] = tr_polya_dict

#print(jnc_counts_dict)


def is_polya_junction_read(read_entry,polya_start, sj_overhang):
    '''
    Return True if the read's alignment spans the poly(A) site with more than
    sj_overhang reference positions on each side, otherwise False.

    read_entry: object exposing reference_start (0-based leftmost aligned
                reference coordinate) and reference_end (one past the last
                aligned reference position, or None if not available),
                e.g. a pysam AlignedSegment
    polya_start: reference coordinate of the poly(A) site
    sj_overhang: overhang that must be exceeded (strictly) on both sides

    Reads with no alignment coordinates are never junction reads. Signed
    distances are used, so a read lying entirely upstream or downstream of the
    site is rejected rather than passing because it is merely far away.
    '''

    align_start = read_entry.reference_start
    align_end = read_entry.reference_end

    # No usable alignment span (e.g. unmapped read / no cigar) -> cannot span the site
    if align_start is None or align_end is None:
        return False

    upstream_overhang = polya_start - align_start
    downstream_overhang = align_end - polya_start

    return upstream_overhang > sj_overhang and downstream_overhang > sj_overhang
    

    
# Same counting as above, but delegating the per-read check to is_polya_junction_read
jnc_counts_dict_2 = {}

for transcript, polya_dict in small_polya_jnc_coords.items():

    tr_polya_counts = {}

    for polya_site, coords_tuple in polya_dict.items():
        chrom, site_start, site_end = coords_tuple
        overlapping_reads = wt_1.fetch(chrom, site_start, site_end)
        # True counts as 1, so the sum is the number of valid junction reads
        tr_polya_counts[polya_site] = sum(is_polya_junction_read(read, site_start, 3)
                                          for read in overlapping_reads)

    jnc_counts_dict_2[transcript] = tr_polya_counts

#print(jnc_counts_dict_2)

# Finished with the WT 1 handle opened earlier in this cell
wt_1.close()

def get_polya_junction_counts_dict(bam_path, jnc_coords_dict, sj_overhang):
    '''
    Count poly(A) 'junction reads' for every poly(A) site of every transcript.

    bam_path: path to a BAM file (fetch() by region needs an index for it)
    jnc_coords_dict: nested dict of {transcript: {site_id: (chr, start, end)}}
    sj_overhang: overhang passed on to is_polya_junction_read for each read

    Returns a nested dict of {transcript: {site_id: junction read count}}.
    '''

    #final output dict of {tr: {site_id: jnc_counts}}
    counts_dict = {}

    # Context manager guarantees the BAM handle is closed even if fetch() or
    # the per-read check raises part-way through
    with pysam.AlignmentFile(bam_path, "rb") as bam:

        for transcript, polya_coords_dict in jnc_coords_dict.items():

            counts_dict[transcript] = {
                # True counts as 1, so the sum is the number of valid junction reads
                polya_site: sum(is_polya_junction_read(read, site_start, sj_overhang)
                                for read in bam.fetch(chrom, site_start, site_end))
                for polya_site, (chrom, site_start, site_end) in polya_coords_dict.items()
            }

    return counts_dict

# Poly(A) junction read counts for the test transcripts, requiring > 3 nt overhang
wt_1_polya_jnc_counts = get_polya_junction_counts_dict(bam_path=m323k_wt_1_bam_path,
                                                       jnc_coords_dict=small_polya_jnc_coords,
                                                       sj_overhang=3)
hom_1_polya_jnc_counts = get_polya_junction_counts_dict(bam_path=m323k_hom_1_bam_path,
                                                        jnc_coords_dict=small_polya_jnc_coords,
                                                        sj_overhang=3)

print(wt_1_polya_jnc_counts)
print(hom_1_polya_jnc_counts)

{'ENSMUST00000000608': {'chr11:+:55110898:TE': 3, 'chr11:+:55113026:TE': 1}, 'ENSMUST00000000844': {'chr2:+:181515384:TE': 212, 'chr2:+:181515774:TE': 102, 'chr2:+:181516756:TE': 26}, 'ENSMUST00000000896': {'chr2:-:154587124:TE': 16, 'chr2:-:154585759:TE': 0}, 'ENSMUST00000000985': {'chr14:+:54368610:TE': 64, 'chr14:+:54369669:TE': 4}, 'ENSMUST00000001079': {'chr3:-:129983184:TE': 49, 'chr3:-:129982765:TE': 1}}
{'ENSMUST00000000608': {'chr11:+:55110898:TE': 3, 'chr11:+:55113026:TE': 1}, 'ENSMUST00000000844': {'chr2:+:181515384:TE': 212, 'chr2:+:181515774:TE': 102, 'chr2:+:181516756:TE': 26}, 'ENSMUST00000000896': {'chr2:-:154587124:TE': 16, 'chr2:-:154585759:TE': 0}, 'ENSMUST00000000985': {'chr14:+:54368610:TE': 64, 'chr14:+:54369669:TE': 4}, 'ENSMUST00000001079': {'chr3:-:129983184:TE': 49, 'chr3:-:129982765:TE': 1}}
5
5
{'ENSMUST00000000608': {'chr11:+:55110898:TE': 3, 'chr11:+:55113026:TE': 1}, 'ENSMUST00000000844': {'chr2:+:181515384:TE': 212, 'chr2:+:181515774:TE': 102, 'chr2:+:18

**Now I need counts for junction reads at 5' end of terminal exon**

The coordinates are provided in the paqr output table, - suggest pulling from the exon_id string

Coordinates are bed-like i.e. 0-based, 1/2 open - start is included, end is not
I want the splice junction coordinates to follow this
if + strand then start of exon = start coordinate
if - strand then start of exon = end coordinate (-1)



In [69]:
#Want dictionary of {transcript_id: (chr, start, end, )}

#print(small_polya_jnc_coords)
#print(rel_usages)

def paqr_out_to_sj_coord_tuple(paqr_df, 
                               group_col = "gene_id",
                               exon_colname = "paqr_exon_id",
                               chr_colname = "chr", 
                               strand_colname = "strand"):
    '''
    Get dict of {transcript_id: (chr, start, end)} where coords are a 1 nt
    window at the 5' splice junction of the terminal exon.

    The exon id is expected to look like
    'ENSMUST00000007216:12:12:20543310:20544909', i.e. its last two
    ':'-separated fields are the exon start and end coordinates.

    + strand: window is (exon start, exon start + 1)
    - strand: window is (exon end - 1, exon end)

    paqr_df: DataFrame containing the four columns named below
    group_col: column to group rows by (keys of the returned dict)
    exon_colname: column holding the terminal exon id string
    chr_colname: column holding the chromosome name
    strand_colname: column holding the strand ('+' or '-')

    Raises ValueError if a group's strand is neither '+' nor '-'.
    '''

    out_dict = {}

    for group_name, group in paqr_df.groupby(group_col):

        # Every row of a group should share strand, chromosome and terminal
        # exon string, so the first row is enough
        first_row = group.iloc[0]

        chrom = str(first_row[chr_colname])
        strand = first_row[strand_colname]
        exon_split = str(first_row[exon_colname]).split(':')

        if strand == "+":
            # 5' end of exon = start coordinate in the exon id
            exon_start = int(exon_split[-2])
            coord_tuple = (chrom, exon_start, exon_start + 1)

        elif strand == "-":
            # 5' end of exon = end coordinate in the exon id (half-open, so last base = end - 1)
            exon_end = int(exon_split[-1])
            coord_tuple = (chrom, exon_end - 1, exon_end)

        else:
            # Fail loudly instead of silently reusing the previous group's coordinates
            raise ValueError("Unexpected strand {0!r} for {1!r} (expected '+' or '-')".format(strand, group_name))

        out_dict[group_name] = coord_tuple

    return out_dict


# {transcript_id: (chr, start, end)} of the terminal exon 5' splice junction
splice_jnc_coords = paqr_out_to_sj_coord_tuple(rel_usages)

# Restrict to the transcripts that are also in the poly(A) test subset
small_splice_jnc_coords = {
    transcript: sj_coords
    for transcript, sj_coords in splice_jnc_coords.items()
    if transcript in small_polya_jnc_coords
}

print(small_splice_jnc_coords)
print(small_polya_jnc_coords)

{'ENSMUST00000000608': ('chr11', 55109377, 55109378), 'ENSMUST00000000844': ('chr2', 181515079, 181515080), 'ENSMUST00000000896': ('chr2', 154588091, 154588092), 'ENSMUST00000000985': ('chr14', 54368317, 54368318), 'ENSMUST00000001079': ('chr3', 129984005, 129984006)}
{'ENSMUST00000000608': {'chr11:+:55110898:TE': ('chr11', 55110898, 55110899), 'chr11:+:55113026:TE': ('chr11', 55113026, 55113027)}, 'ENSMUST00000000844': {'chr2:+:181515384:TE': ('chr2', 181515384, 181515385), 'chr2:+:181515774:TE': ('chr2', 181515774, 181515775), 'chr2:+:181516756:TE': ('chr2', 181516756, 181516757)}, 'ENSMUST00000000896': {'chr2:-:154587124:TE': ('chr2', 154587124, 154587125), 'chr2:-:154585759:TE': ('chr2', 154585759, 154585760)}, 'ENSMUST00000000985': {'chr14:+:54368610:TE': ('chr14', 54368610, 54368611), 'chr14:+:54369669:TE': ('chr14', 54369669, 54369670)}, 'ENSMUST00000001079': {'chr3:-:129983184:TE': ('chr3', 129983184, 129983185), 'chr3:-:129982765:TE': ('chr3', 129982765, 129982766)}}


In [75]:
## Ok now need to search bam for reads overlapping splice jnc coord,
## count number of aligned reads crossing splice junction
## 

def get_splice_junction_counts_dict(bam_path, jnc_coords_dict):
    '''
    Count, per transcript, the reads overlapping its splice junction window.

    bam_path: path to a BAM file (fetch() by region needs an index for it)
    jnc_coords_dict: dict of {transcript: (chr, start, end)}

    Returns a dict of {transcript: count}. Every read returned by fetch() for
    the window is counted; no check is made that the read is actually spliced
    at the junction.
    '''

    # Context manager guarantees the BAM handle is closed even if fetch() raises
    with pysam.AlignmentFile(bam_path, "rb") as bam:

        #final output dict of {tr: count}
        counts_dict = {transcript: sum(1 for read in bam.fetch(chrom, sj_start, sj_end))
                       for transcript, (chrom, sj_start, sj_end) in jnc_coords_dict.items()}

    return counts_dict

# Reads overlapping the terminal exon 5' splice junction of each test transcript
wt_1_splice_jnc_counts = get_splice_junction_counts_dict(bam_path=m323k_wt_1_bam_path,
                                                         jnc_coords_dict=small_splice_jnc_coords)
hom_1_splice_jnc_counts = get_splice_junction_counts_dict(bam_path=m323k_hom_1_bam_path,
                                                          jnc_coords_dict=small_splice_jnc_coords)

# Splice junction counts followed by poly(A) junction counts, sample by sample
print(wt_1_splice_jnc_counts)
print(wt_1_polya_jnc_counts)

print(hom_1_splice_jnc_counts)
print(hom_1_polya_jnc_counts)


{'ENSMUST00000000608': 310, 'ENSMUST00000000844': 450, 'ENSMUST00000000896': 30, 'ENSMUST00000000985': 183, 'ENSMUST00000001079': 192}
{'ENSMUST00000000608': {'chr11:+:55110898:TE': 3, 'chr11:+:55113026:TE': 1}, 'ENSMUST00000000844': {'chr2:+:181515384:TE': 212, 'chr2:+:181515774:TE': 102, 'chr2:+:181516756:TE': 26}, 'ENSMUST00000000896': {'chr2:-:154587124:TE': 16, 'chr2:-:154585759:TE': 0}, 'ENSMUST00000000985': {'chr14:+:54368610:TE': 64, 'chr14:+:54369669:TE': 4}, 'ENSMUST00000001079': {'chr3:-:129983184:TE': 49, 'chr3:-:129982765:TE': 1}}
{'ENSMUST00000000608': 475, 'ENSMUST00000000844': 482, 'ENSMUST00000000896': 48, 'ENSMUST00000000985': 244, 'ENSMUST00000001079': 196}
{'ENSMUST00000000608': {'chr11:+:55110898:TE': 2, 'chr11:+:55113026:TE': 0}, 'ENSMUST00000000844': {'chr2:+:181515384:TE': 268, 'chr2:+:181515774:TE': 90, 'chr2:+:181516756:TE': 27}, 'ENSMUST00000000896': {'chr2:-:154587124:TE': 20, 'chr2:-:154585759:TE': 0}, 'ENSMUST00000000985': {'chr14:+:54368610:TE': 66, 'chr1

These counts are way way off...

I think I'm going to actually have to do some smart filtering to get junction read counts

i.e. for every read in bam.fetch(splice junction), check read is
primary alignment/uniquely aligned

is a split alignment, i.e. has 'N' in the cigar string, and this split alignment starts at the splice junction, i.e. the read skips a bit of reference sequence (intron) and then starts at the splice junction with a contiguous/consecutive alignment to the reference (i.e. exon)

`~~~` - intron

`___ `- exon

`|` - splice junction

`------` - read

I'm interested in reads spanning  the splice junction of 'exon 2' (i.e. the TE)

What i'm really after is the total number of reads supporting inclusion of/crossing this junction
I'm not particularly interested which upstream exon/ splice junction it is spliced to, just that it supports inclusion of the terminal exon junction

If I want to make sure it is a read connecting two exons (but without caring what the other junction is), I can check the cigar string such that the intron (N) is bounded by exact reference matches of at least (3) nt on either side e.g. a string like
**70M84N5M** or for readability **70M 84N 5M** 
would be a valid read because a reference gap of 84N (intron) is separated by a consecutive match to reference (exon) >= 3nt on either side

To be extra sure, I should really check the N & M segments align with reference coordinate/junction I'm interested in...

**To this end, I just want reads that cross the TE splice junction and contain a 'right hand overhang' (cf. --alignSJDBoverhangMin) of at least x nt (default 3nt), i.e. x nt of the terminal exon**

Maybe also only want to count spliced alignments, but for now let's not worry about that??


5' exon 1                            exon 2            3'

`____________|~~~~~~~~~~~~~~~~~~|______________`

`xxxxxx-------NNNNNNNNNNNNNNNNNN-------xxxxxxxx` Read supporting splicing-in/inclusion of this TE

`xx-----------NNNNNNNNNNNNNNNNNN--xxxxxxxxxxxxx`

`xxxxxxxxxxxxxxxxxxxxxxxxxxxx----------------xx`


check if right hand overhang is > --alignSJDBoverhangMin (usually 3nt) i.e. aligned portion of read has at least 3nt of terminal exon sequence



In [109]:
# (chr, start, end) splice junction window of a single test transcript
# (.get returns None if the transcript is not in the dict)
test_tr_splice_jnc_tuple = small_splice_jnc_coords.get('ENSMUST00000000985')
#print(test_tr_splice_jnc_tuple)

# Re-open the WT 1 BAM for the exploratory checks below
wt_1 = pysam.AlignmentFile(m323k_wt_1_bam_path, "rb")

# Counter referenced only by the commented-out test code below
x = 0


#test_reads_dict = {}
#for read in wt_1.fetch(test_tr_splice_jnc_tuple[0],
#                    test_tr_splice_jnc_tuple[1],
#                    test_tr_splice_jnc_tuple[2]):
    
#    while x < 10:
#        if test_tr_splice_jnc_tuple in test_reads_dict.keys():
#            #add AlignedSegment object for read to end of list for given splice jnc
#            test_reads_dict[test_tr_splice_jnc_tuple].append(read.to_dict())
#        
#        else:
#            test_reads_dict[test_tr_splice_jnc_tuple] = [read.to_dict()]
#            
#        x += 1
#    else:
#        break
    
#print(x)
#print(test_reads_dict)

'''
cigartuples

    the cigar alignment. The alignment is returned as a list of tuples of (operation, length).

    If the alignment is not present, None is returned.

    The operations are:
    M 	BAM_CMATCH 	0
    I 	BAM_CINS 	1
    D 	BAM_CDEL 	2
    N 	BAM_CREF_SKIP 	3
    S 	BAM_CSOFT_CLIP 	4
    H 	BAM_CHARD_CLIP 	5
    P 	BAM_CPAD 	6
    = 	BAM_CEQUAL 	7
    X 	BAM_CDIFF 	8
    B 	BAM_CBACK 	9


I neeed to check for M followed by 

'''

#spliced_read_test_count_dict = {}


# Count reads overlapping the test splice junction window whose CIGAR contains a
# match | reference skip (intron) | match run with enough aligned bases in both
# flanking matches. Each read is counted at most once, even if it contains
# several qualifying M | N | M runs.
# NOTE(review): the position of the N segment is not compared with the junction
# coordinate, so any sufficiently anchored splice within the read qualifies.

# Both flanking M segments must be strictly longer than this
min_overhang = 3

# M | N | M = [0, 3, 0] (operation codes, i.e. first element of each cigartuples entry)
seq_to_check = [0, 3, 0]

jnc_count = 0
for read in wt_1.fetch(test_tr_splice_jnc_tuple[0],
                       test_tr_splice_jnc_tuple[1],
                       test_tr_splice_jnc_tuple[2]):

    # cigartuples is None when no alignment is present - treat as no operations
    parsed_cigar = read.cigartuples or []

    # (index of first M, index of last M) for every M | N | M run in the CIGAR
    cigar_check = [(i, i + len(seq_to_check) - 1) for i in range(len(parsed_cigar))
                   if [tup[0] for tup in parsed_cigar[i:i + len(seq_to_check)]] == seq_to_check]

    if not cigar_check:
        print("Read does not have a putative exon-exon junction alignment")
        continue

    # Has a putative exon - exon alignment; is at least one anchored well enough?
    has_valid_junction = any(parsed_cigar[idx_start][1] > min_overhang
                             and parsed_cigar[idx_end][1] > min_overhang
                             for idx_start, idx_end in cigar_check)

    if has_valid_junction:
        print("read has minimum overhang & is valid junction read")
        jnc_count += 1

    else:
        print("read does not have minimum exon-exon overhang")


print(jnc_count)

print(test_tr_splice_jnc_tuple)

# Total number of reads overlapping a fixed 1 nt window on chr2 (no filtering)
count = sum(1 for read in wt_1.fetch('chr2', 154588091, 154588092))

print(count)

read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is valid junction read
read has minimum overhang & is 