In [None]:
import pandas as pd
import glob
from collections import defaultdict
from statistics import median
from statistics import mean
from kneed import KneeLocator
import pymannkendall as mk
from statistics import stdev
import matplotlib.pyplot as plt
import numpy as np
from natsort import natsorted
import importlib

In [None]:
def format_func_K(value, tick_number):
    """Format an axis tick as a whole number of thousands, e.g. 2000 -> "2K".

    Has the ``(value, position)`` signature expected by matplotlib's
    ``FuncFormatter``.

    Args:
        value: Tick value (int or float).
        tick_number: Tick position passed by the formatter; not used.

    Returns:
        ``"<n>K"`` when ``value`` is a non-zero exact multiple of 1000,
        otherwise an empty string so the tick stays unlabelled.
    """
    # Only non-zero exact multiples of 1000 receive a label.
    if value != 0 and value % 1000 == 0:
        return '{:.0f}K'.format(value / 1000)
    # Return text explicitly instead of falling through with an implicit None.
    return ''

In [None]:
def Counting(filename, tpm_thr_min, tpm_thr_max, psi_thr_min, psi_thr_max):
    """Summarise, per depth value, how many genes/junctions pass the filters.

    Reads the whitespace-separated file ``path + filename`` (``path`` is a
    module-level variable). Every non-blank row must hold seven fields:
    ``phenotype sample gene junction psi depth tpm``. A row is kept when
    ``psi_thr_min <= psi < psi_thr_max`` and ``tpm_thr_min <= tpm < tpm_thr_max``.
    For each sample, the distinct genes and distinct ``gene_junction`` ids among
    the kept rows are counted; the per-sample counts are then grouped by the
    sample's depth value and reduced to a mean and a sample standard deviation.

    Args:
        filename: File name appended to the global ``path``.
        tpm_thr_min: Inclusive lower TPM bound.
        tpm_thr_max: Exclusive upper TPM bound.
        psi_thr_min: Inclusive lower PSI bound.
        psi_thr_max: Exclusive upper PSI bound.

    Returns:
        A 4-tuple of dicts keyed by the depth field exactly as read from the
        file: ``(mean genes, stdev genes, mean junctions, stdev junctions)``.
        Samples without any kept row contribute to no depth group. A depth
        group made of a single sample gets a standard deviation of ``0.0``.

    Raises:
        ValueError: If a non-blank row does not have exactly seven fields, or
            a psi/tpm field that gets evaluated is not numeric.
    """
    genes_by_sample = defaultdict(set)      # sample -> genes with a kept row
    junctions_by_sample = defaultdict(set)  # sample -> kept "gene_junction" ids
    depth_by_sample = {}                    # sample -> depth field (last row wins)

    # The context manager closes the file even if a row fails to parse.
    with open(path + filename) as summary:
        for line in summary:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            phenotype, sample, gene, junction, psi, depth, tpm = line.split()
            depth_by_sample[sample] = depth
            # Short-circuit keeps the original laziness: tpm is only parsed
            # for rows whose psi is already inside the window.
            if (psi_thr_min <= float(psi) < psi_thr_max
                    and tpm_thr_min <= float(tpm) < tpm_thr_max):
                genes_by_sample[sample].add(gene)
                junctions_by_sample[sample].add(gene + '_' + junction)

    mean_gene_dict, stdev_gene_dict = _summarise_by_depth(
        genes_by_sample, depth_by_sample)
    mean_junction_dict, stdev_junction_dict = _summarise_by_depth(
        junctions_by_sample, depth_by_sample)

    return (mean_gene_dict, stdev_gene_dict, mean_junction_dict, stdev_junction_dict)


def _summarise_by_depth(items_by_sample, depth_by_sample):
    """Return (mean, stdev) dicts of per-sample set sizes, grouped by depth."""
    # A list, not a set: two samples in the same depth group that happen to
    # have the same count must both enter the mean and standard deviation.
    counts_by_depth = defaultdict(list)
    for sample, items in items_by_sample.items():
        counts_by_depth[depth_by_sample[sample]].append(len(items))

    means = {}
    stdevs = {}
    for depth, counts in counts_by_depth.items():
        means[depth] = mean(counts)
        # statistics.stdev needs at least two values; report no spread for a
        # single-sample group instead of raising StatisticsError.
        stdevs[depth] = stdev(counts) if len(counts) > 1 else 0.0
    return means, stdevs


In [None]:
#Visualisation
def Plotting(phenotype, average_as, std_as, tpm_thr_min, tpm_thr_max, max_y,
             tick_labels=None):
    """Draw a bar chart of mean counts per depth bin with error bars.

    Bars are ordered by natural sort of the keys of ``average_as``.

    Args:
        phenotype: Not used by the plot; kept so existing calls stay valid.
        average_as: dict mapping depth bin -> mean count (bar height).
        std_as: dict with the same keys -> spread drawn as the error bar.
        tpm_thr_min: Lower TPM bound; only shown in the y-axis label.
        tpm_thr_max: Upper TPM bound; only shown in the y-axis label.
        max_y: Upper limit of the y axis.
        tick_labels: Optional x tick labels, one per bin in sorted key order.
            When omitted, the notebook's ten fixed labels are used if there
            are exactly ten bins; otherwise the bin keys themselves are used.

    Raises:
        KeyError: If ``std_as`` lacks a key present in ``average_as``.
        ValueError: If ``tick_labels`` does not have one label per bin.
    """
    # Natural sort orders embedded numbers numerically (presumably the keys
    # look like '60M' ... '150M', where plain string sorting would misorder).
    depths = natsorted(average_as.keys())
    gene_y = [average_as[depth] for depth in depths]
    # Look errors up by the same keys so each error bar belongs to its bar,
    # instead of sorting the two dicts independently.
    errors = [std_as[depth] for depth in depths]

    if tick_labels is None:
        # Labels used by the original figure; the bracketed numbers are
        # presumably the number of samples per bin -- verify against the data.
        legacy_labels = ['60M (4)', '70M (10)', '80M (23)', '90M (18)',
                         '100M (8)', '110M (4)', '120M (15)', '130M (7)',
                         '140M (8)', '150M (2)']
        if len(depths) == len(legacy_labels):
            tick_labels = legacy_labels
        else:
            tick_labels = [str(depth) for depth in depths]
    elif len(tick_labels) != len(depths):
        raise ValueError('tick_labels needs one label per bin: got {} labels for {} bins'
                         .format(len(tick_labels), len(depths)))

    fig, ax = plt.subplots(figsize=(15, 10))
    barWidth = 0.5
    # One bar position per bin, rather than a hard-coded ten.
    positions = np.arange(len(depths))
    ax.bar(positions, gene_y, yerr=errors, width=barWidth - 0.1, color="#004D40")

    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, fontsize=40, rotation=90)
    # tick_params also covers ticks created later, e.g. after the y limits change.
    ax.tick_params(axis='y', labelsize=50)
    ax.yaxis.set_major_formatter(plt.FuncFormatter(format_func_K))
    ax.set_ylabel(str(tpm_thr_min) + '<=TPM<' + str(tpm_thr_max), fontsize=60)

    ax.set_ylim(0, max_y)
    plt.show()

In [None]:
# Counting() opens `path + filename` by plain string concatenation, so a
# non-empty value must end with a path separator (e.g. "data/").
path = ""#path to covid_summary.tsv

In [None]:
filename = 'covid_summary.tsv'

# One Counting() pass per TPM interval [min, max); the PSI window is
# [0.05, 0.95) for every pass.
one_gene, one_std_gene, one_junction, one_std_junction = Counting(
    filename, tpm_thr_min=0, tpm_thr_max=0.1, psi_thr_min=0.05, psi_thr_max=0.95)
two_gene, two_std_gene, two_junction, two_std_junction = Counting(
    filename, tpm_thr_min=0.1, tpm_thr_max=0.5, psi_thr_min=0.05, psi_thr_max=0.95)
three_gene, three_std_gene, three_junction, three_std_junction = Counting(
    filename, tpm_thr_min=0.5, tpm_thr_max=1, psi_thr_min=0.05, psi_thr_max=0.95)
four_gene, four_std_gene, four_junction, four_std_junction = Counting(
    filename, tpm_thr_min=1, tpm_thr_max=10, psi_thr_min=0.05, psi_thr_max=0.95)
five_gene, five_std_gene, five_junction, five_std_junction = Counting(
    filename, tpm_thr_min=10, tpm_thr_max=float('inf'), psi_thr_min=0.05, psi_thr_max=0.95)
# Overlaps the two intervals above: everything with TPM >= 1.
six_gene, six_std_gene, six_junction, six_std_junction = Counting(
    filename, tpm_thr_min=1, tpm_thr_max=float('inf'), psi_thr_min=0.05, psi_thr_max=0.95)



In [None]:
# Largest mean gene count over all six TPM intervals; passed (plus a margin)
# to Plotting() as the y-axis limit.
y_max = max(
    max(per_depth.values())
    for per_depth in (one_gene, two_gene, three_gene,
                      four_gene, five_gene, six_gene)
)

In [None]:
#Mann-Kendall Trend test
# Mean gene counts of the first TPM interval, ordered by naturally sorted depth key.
result = [one_gene[depth] for depth in natsorted(one_gene)]
# Swap in the other *_gene dicts to test the remaining intervals.
print(mk.original_test(result))

In [None]:
#Genes
# Choose ONE of the TPM intervals for which Counting() was run:
#   (0, 0.1), (0.1, 0.5), (0.5, 1), (1, 10), (10, inf), (1, inf)
tpm_thr_min = 10
tpm_thr_max = float('inf')

# Look the (min, max) pair up as a whole. An unknown pair raises instead of
# silently leaving mean_value/std_value at whatever an earlier cell assigned,
# which would plot one interval's data under another interval's axis label.
try:
    mean_value, std_value = {
        (0, 0.1): (one_gene, one_std_gene),
        (0.1, 0.5): (two_gene, two_std_gene),
        (0.5, 1): (three_gene, three_std_gene),
        (1, 10): (four_gene, four_std_gene),
        (10, float('inf')): (five_gene, five_std_gene),
        (1, float('inf')): (six_gene, six_std_gene),
    }[(tpm_thr_min, tpm_thr_max)]
except KeyError:
    raise ValueError(
        'No gene summary was computed for TPM interval [{}, {})'.format(
            tpm_thr_min, tpm_thr_max)
    ) from None
    

In [None]:
# Gene bar chart for the selected TPM interval; 1000 is headroom above the tallest bar.
Plotting(
    phenotype='Covid-19',
    average_as=mean_value,
    std_as=std_value,
    tpm_thr_min=tpm_thr_min,
    tpm_thr_max=tpm_thr_max,
    max_y=y_max + 1000,
)

In [None]:
#Junctions
# Choose ONE of the TPM intervals for which Counting() was run:
#   (0, 0.1), (0.1, 0.5), (0.5, 1), (1, 10), (10, inf), (1, inf)
tpm_thr_min = 10
tpm_thr_max = float('inf')

# Look the (min, max) pair up as a whole. An unknown pair raises instead of
# silently leaving mean_value/std_value at whatever an earlier cell assigned
# (for example the gene summaries), which would then be plotted as junctions.
try:
    mean_value, std_value = {
        (0, 0.1): (one_junction, one_std_junction),
        (0.1, 0.5): (two_junction, two_std_junction),
        (0.5, 1): (three_junction, three_std_junction),
        (1, 10): (four_junction, four_std_junction),
        (10, float('inf')): (five_junction, five_std_junction),
        (1, float('inf')): (six_junction, six_std_junction),
    }[(tpm_thr_min, tpm_thr_max)]
except KeyError:
    raise ValueError(
        'No junction summary was computed for TPM interval [{}, {})'.format(
            tpm_thr_min, tpm_thr_max)
    ) from None
    

In [None]:
# Largest mean junction count over all six TPM intervals; passed (plus a
# margin) to Plotting() as the y-axis limit.
y_max = max(
    max(per_depth.values())
    for per_depth in (one_junction, two_junction, three_junction,
                      four_junction, five_junction, six_junction)
)

In [None]:
# Junction bar chart for the selected TPM interval; 10000 is headroom above the tallest bar.
Plotting(
    phenotype='Covid-19',
    average_as=mean_value,
    std_as=std_value,
    tpm_thr_min=tpm_thr_min,
    tpm_thr_max=tpm_thr_max,
    max_y=y_max + 10000,
)

In [None]:
# Mann-Kendall trend test on the mean junction counts of the first TPM
# interval, ordered by naturally sorted depth key.
result = [one_junction[depth] for depth in natsorted(one_junction)]
# Swap in the other *_junction dicts to test the remaining intervals.
print(mk.original_test(result))