In [1]:
import numpy as np
import pandas as pd
import torch
import os
from fastprogress import progress_bar
from torch.utils.data import DataLoader
import seaborn as sns
from prettytable import PrettyTable
# Apply seaborn's styling with font elements scaled by 1.8 for any figure drawn later.
sns.set(font_scale=1.8)

In [2]:
# Notebook-wide configuration.
# (The former module-level `global` statements were no-ops: names assigned at
# the top level of a notebook are already globals of this namespace.)
data_folder = 'data'              # directory holding the input data files
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when one is visible
results_folder = 'results_FSei'   # root directory of the training runs / outputs

# Utils

In [3]:
from src.attribution_utils import IGDataset,extract_seq,get_motif,mat_product
from src.utils import hot_encode_sequence

# JASPAR database

In [4]:
def parse_meme_motifs(meme_path):
    """Read the letter-probability matrices from a MEME-format motif file.

    A motif starts at a ``MOTIF`` line (its name is everything after the first
    space-separated token), its matrix rows follow the ``letter-probability``
    line, and the matrix ends at a ``URL`` line, a blank line, the next
    ``MOTIF`` line or the end of the file.

    Parameters
    ----------
    meme_path : str
        Path of the MEME file.

    Returns
    -------
    dict
        Maps motif name -> ``np.ndarray`` built by stacking the matrix rows
        along axis 1, i.e. shape ``(values_per_row, motif_width)``. Motifs
        without any matrix row are omitted.
    """
    parsed = {}
    name = None
    rows = []
    in_matrix = False
    # `with` guarantees the handle is closed (the previous code leaked it).
    with open(meme_path) as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if line.startswith('MOTIF'):
                # Flush the previous motif; this also covers files whose
                # matrices are not terminated by a URL line.
                if rows:
                    parsed[name] = np.stack(rows, axis=1)
                name = ' '.join(line.split(' ')[1:])
                rows = []
                in_matrix = False
            elif line.startswith('letter'):
                in_matrix = True
            elif in_matrix:
                if not line or line.startswith('URL'):
                    in_matrix = False
                else:
                    # split() tolerates any run of spaces/tabs between values.
                    rows.append(np.array(line.split(), dtype=float))
    if rows:
        parsed[name] = np.stack(rows, axis=1)
    return parsed


motifs = parse_meme_motifs('jaspar.meme.txt')

In [5]:
df_path = 'data/Labelled_Data_IR_iDiffIR_corrected'
fa_file = 'data/data.fa'
# Accept the table path with or without its '.txt' suffix. Only a trailing
# '.txt' is removed, so dots elsewhere in the path (e.g. './data/x') are kept;
# the previous split('.')[0] truncated such paths.
if df_path.endswith('.txt'):
    df_path = df_path[:-len('.txt')]
df_all = pd.read_csv(df_path+'.txt',delimiter='\t',header=None)
df_seq = pd.read_csv(fa_file,header=None)
strand = df_seq[0][0][-3:] #can be (+) or (.)

# Rebuild the FASTA-style header ('>col0:col1-col2' + strand suffix) for every
# row of the label table so it can be joined with the FASTA records.
df_all['header'] = ('>' + df_all[0].astype(str) + ':' + df_all[1].astype(str)
                    + '-' + df_all[2].astype(str) + strand)

# The FASTA file is read as alternating lines: even rows are headers, odd rows
# are sequences (i.e. one sequence line per record is assumed).
fasta_headers = df_seq[0].iloc[::2].reset_index(drop=True)
fasta_sequences = df_seq[0].iloc[1::2].reset_index(drop=True)
df_seq_all = pd.DataFrame({'header': fasta_headers,
                           'sequence': fasta_sequences.str.upper()})

df_all = df_all.rename(columns={7: "label"})

df_final = pd.merge(df_seq_all[['header',"sequence"]],df_all[['header','label']],
                    on='header',how='inner').drop_duplicates()

# Background nucleotide frequencies over all retained sequences. Only A/C/G/T
# are counted, so any other symbol is ignored.
nt_counts = {nt: int(df_final['sequence'].str.count(nt).sum()) for nt in 'ACGT'}
sum_nt = sum(nt_counts.values())
if sum_nt == 0:
    raise ValueError('No A/C/G/T found in df_final: the header merge produced no usable sequences')
DNAalphabet = {nt: count/sum_nt for nt, count in nt_counts.items()}
# Column vector (A, C, G, T order) so it broadcasts against the motif matrices.
back_freq = np.array(list(DNAalphabet.values())).reshape(4,1)

In [6]:
# Distribution of motif widths (number of columns of each motif matrix).
motif_widths = [matrix.shape[1] for matrix in motifs.values()]
pd.DataFrame(motif_widths).describe(percentiles=[0.5,0.75,0.8,0.9,0.95,0.99])

Unnamed: 0,0
count,949.0
mean,12.287671
std,3.399625
min,5.0
50%,12.0
75%,14.0
80%,15.0
90%,16.0
95%,18.0
99%,21.0


# Main

In [7]:
Udir = '239916'                                  # run directory inside results_folder
Udir_path = os.path.join(results_folder,Udir)
model_path = 'model_23_10_04:10:59.pkl'
# SECURITY: torch.load unpickles arbitrary Python objects - only load
# checkpoints from a trusted source.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only
# machine; without it torch.load tries to recreate the CUDA tensors and fails.
# NOTE(review): newer PyTorch releases may default to weights_only=True, which
# rejects fully pickled models - confirm against the installed version.
model = torch.load(os.path.join(Udir_path,model_path), map_location=device).to(device)
window_size = 24       # passed to extract_seq as its window-size argument
IG_threshhold = 0.7    # passed to extract_seq as its threshold argument

In [9]:
# Integrated-gradients extraction for target 0, run on the test + validation indices.
df_path = 'Labelled_Data_IR_iDiffIR_corrected'
fa_file = 'data.fa'
relevant_targets = [0]

test_sampler = np.loadtxt('data/test_indices.txt', dtype=int)
valid_sampler = np.loadtxt('data/valid_indices.txt', dtype=int)
sampler = np.concatenate((test_sampler, valid_sampler))

IG_dataset_0 = IGDataset(
    df_path=df_path,
    fa_file=fa_file,
    sampler=sampler,
    relevant_targets=relevant_targets,
)
IG_loader_0 = DataLoader(IG_dataset_0, batch_size=32)

# extract_seq returns a 4-tuple; see src.attribution_utils for the exact contents.
extracted_0 = extract_seq(model, IG_loader_0, window_size, IG_threshhold, 0, device)
seqs_0, IGs_0, Scores_0, n_unique_sequences_0 = extracted_0

In [12]:
# Widest motif (in matrix columns) considered by the motif scan below: motifs
# with more columns than this are skipped.
IG_window_size = 24

In [13]:
binding_sites_0 = {}
i = 0
for tf in progress_bar(motifs.keys()):
    if IG_window_size >= motifs[tf].shape[1]:
        motif = motifs[tf]
        pseudo_motif = np.where(motif==0,0.001,motif)
        log_odd_motif = np.log(np.divide(pseudo_motif, back_freq))
        max_motif_score = np.sum(np.max(log_odd_motif),axis=0)
        for seq in seqs_0:
            hot_encode = hot_encode_sequence(seq)
            if mat_product(log_odd_motif,hot_encode,threshold=0.6):
                binding_sites_0[tf] = binding_sites_0.get(tf,0) + 1

binding_sites_sorted_0 = sorted(binding_sites_0.items(), key=lambda x:x[1], reverse=True)
converted_dict_0 = dict(binding_sites_sorted_0)

In [14]:
relevant_targets = [1]
test_sampler = np.loadtxt('data/test_indices.txt', dtype=int)
valid_sampler = np.loadtxt('data/valid_indices.txt', dtype=int)
sampler = np.concatenate((test_sampler,valid_sampler))
input_prefix = 'Labelled_Data_IR_iDiffIR_corrected'
fa_file = 'final_data_IR/data.fa'
IG_dataset_1 = IGDataset(df_path=input_prefix,fa_file=fa_file,sampler = sampler, relevant_targets=relevant_targets)
IG_loader_1 = DataLoader(IG_dataset_1,batch_size=32)
seqs_1,IGs_1,Scores_1,n_unique_sequences_1 = extract_seq(model,IG_loader_0,window_size,IG_threshhold,1,device)

In [15]:
binding_sites_1 = {}
i = 0
for tf in progress_bar(motifs.keys()):
    if IG_window_size >= motifs[tf].shape[1]:
        motif = motifs[tf]
        pseudo_motif = np.where(motif==0,0.001,motif)
        log_odd_motif = np.log(np.divide(pseudo_motif, back_freq))
        max_motif_score = np.sum(np.max(log_odd_motif),axis=0)
        for seq in seqs_1:
            hot_encode = hot_encode_sequence(seq)
            if mat_product(log_odd_motif,hot_encode,threshold=0.8*max_motif_score):
                binding_sites_1[tf] = binding_sites_1.get(tf,0) + 1

binding_sites_sorted_1 = sorted(binding_sites_1.items(), key=lambda x:x[1], reverse=True)
converted_dict_1 = dict(binding_sites_sorted_1)

In [16]:
# Per-TF difference in hit rate: (class-1 hits / #class-1 windows) minus
# (class-0 hits / #class-0 windows), ordered from largest to smallest.
all_tfs = set(converted_dict_0.keys()) | set(converted_dict_1.keys())
rate_gap = {}
for tf in all_tfs:
    rate_1 = int(converted_dict_1.get(tf,0))/len(seqs_1)
    rate_0 = int(converted_dict_0.get(tf,0))/len(seqs_0)
    rate_gap[tf] = rate_1 - rate_0
prob_diff = dict(sorted(rate_gap.items(), key=lambda pair: pair[1], reverse=True))

In [69]:
from statsmodels.stats.multitest import multipletests
from statsmodels.stats.proportion import proportions_ztest

# Two-sample proportion z-test per TF (hits among class-0 windows vs hits among
# class-1 windows). p-values are stored in prob_diff key order.
group_sizes = [len(seqs_0),len(seqs_1)]
pvalue_dict = {}
for tf in prob_diff:
    hit_counts = [int(converted_dict_0.get(tf,0)), int(converted_dict_1.get(tf,0))]
    z_score, p_value = proportions_ztest(hit_counts, group_sizes)
    pvalue_dict[tf] = p_value

# Bonferroni correction; element [1] of the result holds the adjusted p-values.
raw_pvalues = list(pvalue_dict.values())
adjusted_pvalues = multipletests(raw_pvalues,method='bonferroni')[1]

In [70]:
# Ranked table of TFs: hit counts per class, hit-rate difference and the
# Bonferroni-adjusted p-value. adjusted_pvalues is in prob_diff key order.
table = PrettyTable(["rank","TF", "motif", "non-IR", "IR","diff prob","pvalue"])
for rank,(tf,adjusted_p) in enumerate(zip(prob_diff.keys(),adjusted_pvalues),start=1):
    table.add_row([rank,tf,get_motif(motifs[tf]),converted_dict_0.get(tf,0),converted_dict_1.get(tf,0),
                  round(prob_diff.get(tf),3),adjusted_p])
table_text = table.get_string()
results_file = os.path.join(Udir_path,'nsites.txt')
# print(file=...) needs an open, writable file object; passing the path string
# (as before) raises AttributeError. The context manager also closes the file.
with open(results_file,'w') as handle:
    print(table_text,file=handle)
print(table_text)

+------+--------------------------------+--------------------------+--------+-----+-----------+-------------------------+
| rank |               TF               |          motif           | non-IR |  IR | diff prob |          pvalue         |
+------+--------------------------------+--------------------------+--------+-----+-----------+-------------------------+
|  1   |    MA1648.1 MA1648.1.TCF12     |       CGCACCTGCCG        |  9334  | 625 |   0.263   |  7.598230103939114e-165 |
|  2   |     MA0522.3 MA0522.3.TCF3     |       CGCACCTGCCC        |  8067  | 565 |   0.244   | 2.1572453987821423e-159 |
|  3   |    MA0003.1 MA0003.1.TFAP2A    |        GCCCGGGGG         |  7490  | 527 |   0.228   | 2.6820652811093807e-148 |
|  4   |     MA0830.2 MA0830.2.TCF4     |      CGGCACCTGCCCC       |  5525  | 474 |   0.225   |  4.358577861161846e-186 |
|  5   |    MA0820.1 MA0820.1.FIGLA     |        ACCACCTGTT        |  7643  | 489 |   0.201   |  8.566996566935598e-113 |
|  6   |     MA1515.1 MA

In [18]:
# Size overview: rows of each IG dataset, unique-sequence counts reported by
# extract_seq, and number of extracted windows per class.
(
    IG_dataset_0.df_final.shape[0],
    IG_dataset_1.df_final.shape[0],
    n_unique_sequences_0,
    n_unique_sequences_1,
    len(seqs_0),
    len(seqs_1),
)

(11102, 3182, tensor(9152), tensor(313), 61244, 1503)

In [19]:
DNAalphabet  # background nucleotide frequencies (fractions of A/C/G/T over the df_final sequences)

{'A': 0.2683035884612275,
 'C': 0.2310727183862361,
 'G': 0.23145329313601504,
 'T': 0.26917040001652137}