In [None]:
import time
import logging
import sys
import gzip

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

import scipy.stats as stats
from scipy.signal import savgol_filter

from collections import Counter

%matplotlib inline

# Never truncate long cell contents when a DataFrame is displayed.
pd.options.display.max_colwidth = None

# Sample directories processed by the cells below (order matters: later cells index by position).
samples = [
    'data_2021-08-27/ky08',
    'data_2021-08-27/ky10',
    'data_2021-08-27/ky17',
    'data_2021-08-27/ky18',
]
# Alternative sample set:
# samples = ['data_2021-08-27/ky11', 'data_2021-08-27/ky13']

In [None]:
# Load the preprocessed per-sample tables, one DataFrame per entry of `samples` (same order).
# The files are read with pd.read_csv's default comma separator despite their .gtf suffix.
dfs = []

for i, sample_dir in enumerate(samples):
    # The file prefix is the sample directory's own name, e.g. 'data_2021-08-27/ky08' -> 'ky08'.
    # Taking the last path component works for sample names of any length.
    sample_name = sample_dir.rstrip('/').split('/')[-1]
    df = pd.read_csv(sample_dir + '/' + sample_name + '_info.new.clean.gtf')
    print(i)
    dfs.append(df)

In [None]:
# Pause ("hotspot") calling parameters.
SEARCH_SIZE = 100     # half-width of the background window around a candidate 3' end position
COUNT_THRESHOLD = 4   # only positions with MORE than this many rows (read 3' ends) are tested
N_STD = 3             # a candidate is a pause if count > window mean + N_STD * window std


def _call_pauses_on_chrom(site_counts, searchsize, count_threshold, n_std):
    """Return the (3prime, count) pairs on one chromosome/strand that are called as pauses.

    Parameters
    ----------
    site_counts : list of (position, count) tuples
        One entry per distinct 3' end position, sorted by position ascending.
    searchsize : int
        Positions strictly closer than ``searchsize`` to the candidate form its background window.
    count_threshold : int
        A position is only tested when ``count > count_threshold``.
    n_std : float
        Number of window standard deviations the count must exceed the window mean by.

    Returns
    -------
    list of (position, count) tuples, in ascending position order.
    """
    pauses = []
    for i, (prime, count) in enumerate(site_counts):
        if count <= count_threshold:
            continue

        window = []

        # Upstream flank: walk left from the neighbour of i while still inside the window.
        lower = prime - searchsize
        for j in range(i - 1, -1, -1):
            if site_counts[j][0] > lower:
                window.append(site_counts[j][1])
            else:
                break

        # Downstream flank: starts AT i, so the candidate's own count is part of its background.
        # NOTE(review): the upstream flank excludes the candidate; confirm this asymmetry is intended.
        upper = prime + searchsize
        for j in range(i, len(site_counts)):
            if site_counts[j][0] < upper:
                window.append(site_counts[j][1])
            else:
                break

        # Pad with zeros to 2*searchsize entries so positions without reads contribute 0.
        window += [0] * (searchsize * 2 - len(window))

        std = np.std(window)
        mean = sum(window) / (searchsize * 2)
        if count > mean + n_std * std:
            pauses.append((prime, count))
    return pauses


def call_pauses(strand_reads, searchsize=SEARCH_SIZE, count_threshold=COUNT_THRESHOLD, n_std=N_STD):
    """Call pauses chromosome by chromosome on reads from a single strand.

    ``strand_reads`` must have ``chr_1`` and ``3prime`` columns; the count at a position is the
    number of rows sharing that (chr_1, 3prime).

    Returns a list of ``[chr_1, 3prime, count]`` rows, grouped by chromosome in order of first
    appearance, then by ascending position.
    """
    rows = []
    for chrom in strand_reads.chr_1.unique():
        # groupby sorts its keys, so positions come out in ascending order.
        site_counts = list(strand_reads[strand_reads.chr_1 == chrom].groupby('3prime').size().items())
        for prime, count in _call_pauses_on_chrom(site_counts, searchsize, count_threshold, n_std):
            rows.append([chrom, prime, count])
    return rows


# A hotspot is identified by chromosome, 3' end position AND strand.
HOTSPOT_KEYS = ['chr_1', '3prime', 'read_strand_1']

lorax_sigs = []

for df_num, d in enumerate(dfs):
    print("working on: " + str(df_num))

    # Call pauses on each strand separately ('+' rows first, then '-').
    hotspot_rows = []
    for strand in ('+', '-'):
        strand_reads = d[d.read_strand_1 == strand]
        hotspot_rows.extend(row + [strand] for row in call_pauses(strand_reads))
    hotspots = pd.DataFrame(hotspot_rows, columns=['chr_1', '3prime', 'val', 'read_strand_1'])

    top_hotspots = hotspots.sort_values('val', ascending=False)

    # Attach the per-read annotations of `d` to each hotspot. Every hotspot matches several rows
    # (count > COUNT_THRESHOLD), so keep one row per hotspot. Strand must be part of the key,
    # otherwise pauses on '+' and '-' at the same coordinate collapse into a single row.
    lorax_sig = d.merge(top_hotspots, on=HOTSPOT_KEYS, how='right')
    lorax_sig = lorax_sig.drop_duplicates(HOTSPOT_KEYS).sort_values('val', ascending=False)

    print(len(lorax_sig))

    lorax_sigs.append(lorax_sig)

    location = samples[df_num]
    lorax_sig.to_csv(location + '/hotspots.new.clean.gtf', index=False)

In [None]:
# Save wt NET-seq pauses as BED files (one per strand) for plotting.

# Position in `samples` / `lorax_sigs` of the wild-type sample (samples[2] = .../ky17 with the
# default sample list).
WT_SAMPLE_INDEX = 2


def write_pause_bed(sig, strand, path):
    """Write the pauses of `sig` on one strand as a headerless, tab-separated 3-column BED file.

    Each pause becomes the 1-nt interval [3prime, 3prime + 1).
    NOTE(review): this treats `3prime` as a 0-based coordinate (BED convention); confirm against
    how `3prime` was computed upstream.

    Parameters
    ----------
    sig : DataFrame with ``chr_1``, ``3prime`` and ``read_strand_1`` columns.
    strand : '+' or '-'
    path : output file path; its parent directory is created if missing.
    """
    import os

    strand_sig = sig[sig.read_strand_1 == strand]
    # Build explicit columns instead of selecting '3prime' twice and mutating a slice.
    bed = pd.DataFrame({
        'chrom': strand_sig['chr_1'],
        'start': strand_sig['3prime'],
        'end': strand_sig['3prime'] + 1,
    })
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    bed.to_csv(path, sep='\t', index=False, header=False)


l = lorax_sigs[WT_SAMPLE_INDEX]

write_pause_bed(l, '+', 'tmm/plus_pause.bed')
write_pause_bed(l, '-', 'tmm/minus_pause.bed')