In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import torch
import glob
import pickle
import matplotlib.pyplot as plt

from interpretation.interpret import compute_importance_score_c_type, compute_importance_score_bias, visualize_sequence_imp
from models.models import CATAC2, CATAC_w_bias

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Sample from peak sequences

In [None]:
with open('../results/peaks_seq.pkl', 'rb') as file:
    seq = pickle.load(file)

seq

Unnamed: 0_level_0,chr,start,end,middle_peak,sequence,GC_cont
peakID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1:14154-15100,1,14154,15100,14627,GCCCACCGGCCCCAGGCTCCTGTCTCCCCCCAGGTGTGTGGTGATG...,0.585938
1:15698-16677,1,15698,16677,16188,GGCCATTAGGCTCTCAGCATGACTATTTTTAGAGACCCCGTGTCTG...,0.599854
1:17116-17963,1,17116,17963,17540,GGCTGACCATTGCCTTGGACCGCTCTTGGCAGTCGAAGAAGATTCT...,0.597168
1:28903-29812,1,28903,29812,29358,CACATGCAGCCACTGAGCACTTGAAATGTGGATAGTCTGAATTGAG...,0.489746
1:180637-181553,1,180637,181553,181095,ATGCTGATGGATTGTCAGCTTCCCAGATGTGCAAGAATCTCTCCTC...,0.537109
...,...,...,...,...,...,...
Y:56724332-56725266,Y,56724332,56725266,56724799,ACGGAATGGAATGGAATCCAAAGGAATGGAATAGAATGGAATGGAA...,0.389404
Y:56727619-56728608,Y,56727619,56728608,56728114,TGGAATGCACTCGAATGCAATGGAGTCGAAACAAATGGACTGGAAT...,0.391602
Y:56763067-56763979,Y,56763067,56763979,56763523,TGGAAGGGAGTGTAATGCAAGGTTCTCGAAAATAATGGAATCGAAT...,0.393555
Y:56829118-56830024,Y,56829118,56830024,56829571,TTCCTTTCCTTTCCATTCCATTCCTTTCCATTCCATTCCATGCGAG...,0.435059


# Add sequence overlapping with CUT&Tag peaks

In [None]:
#Open peak pandas dataframe and write to .bed file
with open('../results/peaks_seq.pkl', 'rb') as file:
    peaks = pickle.load(file)

peaks.chr = ["chr" + str(c) for c in peaks.chr]
peaks[["chr","start","end"]].to_csv('../results/all_peaks.bed', sep="\t", index=False, header=False)

#Compute intersection between genomic regions
#-----------------------------------------------

#MyoD1
!bedtools intersect -a ../results/all_peaks.bed -b ../data/MYOD1_25m_s450.bed -wa > ../results/MyoD1_intersect.bed
MYOD1 = pd.read_csv("../results/MyoD1_intersect.bed", header=None, sep= "\t")
MYOD1["ID"] = MYOD1[0] + ":" +  MYOD1[1].astype('str') + "-" + MYOD1[2].astype("str")
MYOD1["ID"] =  [x[3:] for x in MYOD1.ID]; MYOD1 = MYOD1.ID.tolist()

#HES1
!bedtools intersect -a ../results/all_peaks.bed -b ../data/HES1_D8_REP1_filt_TFBS_HES1_s400.bed -wa > ../results/HES1_intersect.bed

HES1 = pd.read_csv("../results/HES1_intersect.bed", header=None, sep= "\t")
HES1["ID"] = HES1[0] + ":" +  HES1[1].astype('str') + "-" + HES1[2].astype("str")
HES1["ID"] =  [x[3:] for x in HES1.ID]; HES1 = HES1.ID.tolist()

#PAX3
!bedtools intersect -a ../results/all_peaks.bed -b ../data/PAX3_D8_REP1_s500.bed  -wa > ../results/PAX3_intersect.bed

PAX3 = pd.read_csv("../results/PAX3_intersect.bed", header=None, sep= "\t")
PAX3["ID"] = PAX3[0] + ":" +  PAX3[1].astype('str') + "-" + PAX3[2].astype("str")
PAX3["ID"] =  [x[3:] for x in PAX3.ID]; PAX3 = PAX3.ID.tolist()

TF_peaks = MYOD1 + HES1 + PAX3

  pid, fd = os.forkpty()
  pid, fd = os.forkpty()
  pid, fd = os.forkpty()


In [None]:
extra = seq.loc[TF_peaks].sequence.tolist()
seq = seq.sample(10000).sequence.tolist() + extra
len(seq)

12216

In [None]:
seq = pd.Series(seq)
seq

0        GAGGGACTTAGAACATGAGGGACCATCATCTCTGTTCAAATTCACT...
1        ACCCTGGGTGGGGATCCTCGGGGCTTCCGGGTGCAGACCTCCCCAC...
2        TGTTTTCTCCTCTGGAAAGGAGCATGCAGGTGTGTCTGGCTGAGAC...
3        AAAAAAAAAAAGGAACAGTGCTAGAGACAAGTTCAGATAACATCTT...
4        TTCTCCTGCCTCAGCCTCCCAAGTAGCAGGAATTACAGACCTGCAC...
                               ...                        
12211    GGAACTGTTATCATGTTAGAGTAAATTAGATTTCTTGAGGGAAGTG...
12212    GGAACTGTTATCATGTTAGAGTAAATTAGATTTCTTGAGGGAAGTG...
12213    GGAACTGTTATCATGTTAGAGTAAATTAGATTTCTTGAGGGAAGTG...
12214    GGAACTGTTATCATGTTAGAGTAAATTAGATTTCTTGAGGGAAGTG...
12215    TGGAAATATTTCATTGCTTGATAGTGGTACAAGTTAATGATTATGT...
Length: 12216, dtype: object

# Compute importance scores

In [None]:
path_model = '../results/train_res/128_model.pkl'

all_c_type = ['Immature', 'Mesenchymal', 'Myoblast', 'Myogenic', 'Neuroblast',
       'Neuronal', 'Somite']
time_point = ["D8", "D12", "D20", "D22"]

#Load the model
model = CATAC_w_bias(nb_conv=8, nb_filters=128, first_kernel=21, 
                      rest_kernel=3, out_pred_len=1024, 
                      nb_pred=4)
        
model.load_state_dict(torch.load(path_model, map_location=torch.device('cpu')))

path_model_bias = "../data/Tn5_NN_model.h5"

#Compute attribution scores
seq, shap_scores, proj_scores = compute_importance_score_bias(model, path_model_bias, seq, device, "Myogenic", all_c_type, 1)

In [None]:
#Save encoded seq + scores
np.savez('../results/encod_seq.npz', seq[:,:4,:])
np.savez('../results/seq_scores.npz', shap_scores[:,:4,:], proj_scores[:,:4,:])

print("Shap scores saved!")

In [None]:
seq = np.load('../results/encod_seq.npz')["arr_0"]
shap_scores = np.load('../results/seq_scores.npz')
proj_scores = shap_scores['arr_1']; shap_scores = shap_scores['arr_0']

# Visualize few examples

In [None]:
visualize_sequence_imp(proj_scores[[73],:4,:] ,0, 4096)
visualize_sequence_imp(proj_scores[[1266],:4,:] ,0, 4096)
visualize_sequence_imp(proj_scores[[563],:4,:] ,0, 4096)

# Use TF-modisco to find TFBS
Following tutorial at: https://github.com/jmschrei/tfmodisco-lite/blob/main/examples/ModiscoDemonstration.ipynb

In [None]:
!modisco motifs -s  ../results/encod_seq.npz -a  ../results/seq_scores.npz -n 2000 -o modisco_results.h5

In [None]:
!modisco report -i modisco_results.h5 -o report/

In [None]:
from IPython.display import HTML
HTML('report/motifs.html')

# Run TOMTOM on modisco results

In [None]:
!modisco report -i modisco_results.h5 -o report/TOMTOM/ -s report/TOMTOM/ -m ../data/JASPAR_motif.txt

In [None]:
from IPython.display import HTML

HTML('report/motifs.html')