In [None]:
import pickle
import random
import time
from pathlib import Path

import numpy as np
import pandas as pd
import scipy
from scipy.sparse import csr_matrix

import pyranges as pr
from pyranges.pyranges_main import PyRanges

from utils import *
from peak_parser import *


In [None]:
import os
# Hide every GPU so TensorFlow runs on the CPU only. The variable is set before
# `import tensorflow` so that it is already in place when TensorFlow initialises
# CUDA. (A scoped alternative is `with tf.device('/CPU:0'): ...`.)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import tensorflow as tf
import data_io

In [None]:
## Construct dataset

In [None]:
#
# Convert each mark's peak file into a one-hot pickle with peak_2onehot_chrom_whole.
# Output files are named <mark>.pickle, which is the naming the context-generation
# cell expects.
path_peak='data/predicting'
cell='MEF'
path_cell=path_peak+'/'+cell

# MEF peak calls: mark -> peak file inside path_cell.
peak_files={'ATAC':'SRR5077744_peaks.narrowPeak',
            'H3K27ac':'SRR5077641_peaks.narrowPeak',
            'H3K27me3':'SRR5077645_peaks.broadPeak',
            'H3K4me1':'SRR5077633_peaks.broadPeak',
            'H3K4me3':'SRR5077625_peaks.narrowPeak',
            'H3K36me3':'SRR5077653_peaks.broadPeak',
            'H3K9me3':'SRR5077657_peaks.broadPeak',
           }

for mark, peak_name in peak_files.items():
    peak_2onehot_chrom_whole(path_cell+'/'+peak_name, path_cell+'/'+mark+'.pickle')


In [None]:
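#Extract protein-coding promoter regions, strand-aware: 2000 bp upstream to 1000 bp downstream of the TSS
#(the gene start for "+" strand genes, the gene end for "-" strand genes). The GENCODE vM25 GTF was downloaded from https://www.gencodegenes.org/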
#
#zcat gencode.vM25.annotation.gtf.gz | \
#awk 'BEGIN{OFS=FS="\t"}{if($3=="gene") {if($7=="+") {start=$4-2000; end=$4+1000;} else {if($7=="-") start=$5-1000; end=$5+2000;} if(start<0) start=0; print $1,start,end,$3,$6,$7,$2,$8,$9}}'| \
#grep protein_coding |cut -f 1,2,3|sort|uniq> gencode_vM25_gene_promoter_protein_coding_uniq.bed


In [None]:
#Extract promoter sequence
# Write the reference DNA of every promoter bin (from the mm10 / GRCm38 genome) to a FASTA file.
genome_fa_file = 'data/genome/GRCm38.primary_assembly.genome.fa'

path_bins = 'data/genome/windows_mm10/promoter'
bed_file, reg_file, seq_file = (
    f'{path_bins}/gencode_vM25_gene_promoter_protein_coding_bins.bed',
    f'{path_bins}/gencode_vM25_gene_promoter_protein_coding_regs.bed',  # not used in this cell
    f'{path_bins}/gene_promoter_bins.fa',
)

extract_fasta(seq_file=seq_file, fa_file=genome_fa_file, bed_file=bed_file)

In [None]:
#
# Build (DNA, epigenome) context samples for every promoter bin with
# generate_peak_context. Bins are processed in chunks of CHUNK_SIZE; each chunk
# writes its own pickle into tmpdir and the next cell merges them.
epi_infos=['ATAC', 'H3K27ac', 'H3K27me3', 'H3K36me3', 'H3K4me1', 'H3K4me3', 'H3K9me3']
epi_targets=['ATAC', ]  # not used in this cell

CHUNK_SIZE=100000  # bins per chunk

tmpdir=path_peak+'/'+'tmp'
Path(tmpdir).mkdir(parents=True, exist_ok=True)  # scratch dir for per-chunk files

print('###cell line:', cell)

# One-hot peak pickle per mark, named <mark>.pickle (as written by the peak conversion cell).
ohpeak_files={epi: path_cell+'/'+epi+'.pickle' for epi in epi_infos}

bed_file=path_bins+'/'+'gencode_vM25_gene_promoter_protein_coding_bins.bed'
reg_file=path_bins+'/'+'gencode_vM25_gene_promoter_protein_coding_regs.bed'
seq_file=path_bins+'/'+'gene_promoter_bins.fa'

sent_file=path_cell+'/'+'gene_promoter_bins.pickle'

with open(bed_file,'r') as f:
    beds=f.readlines()
with open(reg_file,'r') as f:
    regs=f.readlines()
with open(seq_file,'r') as f:
    seqs=f.readlines()
n=len(regs)  # number of promoter bins

# Chunks are cut by line index, so the three files must be row-aligned: one BED
# line and one region line per bin, and two FASTA lines (header + sequence) per
# bin. Fail loudly instead of silently pairing sequences with the wrong region.
assert len(beds)==n, f'{bed_file}: {len(beds)} lines, expected {n}'
assert len(seqs)==2*n, f'{seq_file}: {len(seqs)} lines, expected {2*n} (2 per bin)'

# ceil(n / CHUNK_SIZE) chunks, and at least one, so that small inputs
# (n <= CHUNK_SIZE) are processed too and no empty trailing chunk is created
# when n is an exact multiple of CHUNK_SIZE.
ns=max(1, -(-n//CHUNK_SIZE))

for i in range(ns):
    sp=i*CHUNK_SIZE
    ep=min(sp+CHUNK_SIZE, n)
    bed_file_sub=tmpdir+'/'+'gene_promoter_bins_'+str(i)+'.bed'
    reg_file_sub=tmpdir+'/'+'gene_promoter_regs_'+str(i)+'.bed'
    seq_file_sub=tmpdir+'/'+'gene_promoter_bins_'+str(i)+'.fa'
    sent_file_sub=tmpdir+'/'+'gene_promoter_bins_'+str(i)+'.pickle'
    with open(bed_file_sub,'w') as f:
        f.writelines(beds[sp:ep])
    with open(reg_file_sub,'w') as f:
        f.writelines(regs[sp:ep])
    with open(seq_file_sub,'w') as f:
        f.writelines(seqs[2*sp:2*ep])  # two FASTA lines per bin

    generate_peak_context(seq_file=seq_file_sub,reg_file=reg_file_sub,label=0,targets=epi_infos,targets_files=ohpeak_files,out_file=sent_file_sub,tmpdir=tmpdir)


In [None]:
#
# Merge the per-chunk context pickles into a single sample file.
# `ns` is the chunk count set by the chunking cell; deriving the file names from
# it (instead of hardcoding seven names) keeps this cell correct for any input size.
file_pickles=['gene_promoter_bins_'+str(i)+'.pickle' for i in range(ns)]
sent_file=path_cell+'/'+'gene_promoter_bins.pickle'
cont_pickles=[]

for name in file_pickles:
    with open(tmpdir+'/'+name, 'rb') as fh:
        samples = pickle.load(fh)  # [((dna_seq,epi_seq),label),...]
    print(name, len(samples))
    cont_pickles.extend(samples)

with open(sent_file, 'wb') as fh:
    pickle.dump(cont_pickles, fh, pickle.HIGHEST_PROTOCOL)

In [None]:
#
# Write the merged promoter-bin samples to TFRecord twice: once as-is and once
# after mask_peak_context(..., target_epis, keep_dna=True).
target_epis=['ATAC',]

sent_file_pos=path_cell+'/'+'gene_promoter_bins.pickle'
file_set=path_cell+'/'+'gene_promoter_bins.tfrecord'
file_set_atac=path_cell+'/'+'gene_promoter_bins_atac.tfrecord'

# Read through this cell's own path rather than the `sent_file` global left over
# from an earlier cell (same file, but no hidden-state dependency).
binsset=data_io.select_sample(sample_file=sent_file_pos)
print(len(binsset))
data_io.write_tfrecord(binsset,file_set)
# NOTE(review): presumably keeps the DNA and only the ATAC track — confirm in mask_peak_context.
binsset_atac=[mask_peak_context(x,target_epis,keep_dna=True) for x in binsset]
data_io.write_tfrecord(binsset_atac,file_set_atac)


In [None]:
## Make prediction

In [None]:
#
import logging
import time
import datetime
from pathlib import Path

import tensorflow as tf
import tensorflow_text as text

import data_io
from data_io import _parse_function

from utils import *
from model import *

In [None]:
vocab_file = 'vocab.txt'
# WordPiece tokenizer used by prepare_batch on the epigenome text; emits int64 token ids.
tokenizer = text.BertTokenizer(vocab_lookup_table=vocab_file, token_out_type=tf.int64)

def prepare_batch(example, label):
    """Tokenize the epigenome part of a batch and pair it with the DNA input.

    example: indexable pair; example[0] is the DNA input and example[1] the
             epigenome text (presumably the (dna, epi) tuple from _parse_function).
    label:   passed through unchanged.

    Returns ((dna, epi_ids), label), where epi_ids is the dense token-id tensor
    obtained by merging the tokenizer's last two (word, wordpiece) dimensions.
    """
    epi_ids = (
        tokenizer.tokenize(example[1])
        .merge_dims(-2, -1)  # flatten words x wordpieces into one token axis
        .to_tensor()         # ragged -> dense
    )
    return (example[0], epi_ids), label

def make_batches(ds, batch_size=32,buffer_size=20000,shuffle=False,):
    """Turn a dataset of (example, label) pairs into tokenized, prefetched batches.

    ds:          input dataset.
    batch_size:  number of elements per batch.
    buffer_size: shuffle buffer size; only used when shuffle=True.
    shuffle:     shuffle elements before batching. Keep False when outputs must
                 stay in input order.

    Returns the dataset batched, mapped through prepare_batch (parallelism
    AUTOTUNE) and prefetched with an AUTOTUNE buffer.
    """
    pipeline = ds.shuffle(buffer_size) if shuffle else ds
    pipeline = pipeline.batch(batch_size)
    pipeline = pipeline.map(prepare_batch, tf.data.AUTOTUNE)
    return pipeline.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
#
# Stream the promoter-bin TFRecord through the tokenizer in fixed-order batches.
BUFFER_SIZE = 20000    # shuffle buffer size (not referenced below)
BATCH_SIZE = 32        # not referenced below
PRED_BATCH_SIZE = 128  # batch size used for inference

cell='MEF'
tranf='CTCF'

path_peak='data/predicting'
path_cell=path_peak+'/'+cell
path_tf=path_cell+'/'+tranf  # not referenced below

file_set=path_cell+'/'+'gene_promoter_bins.tfrecord'
peak_ds = tf.data.TFRecordDataset([file_set]).map(_parse_function)
# shuffle must stay False: the output cell pairs predictions with BED rows by position.
peak_batches=make_batches(peak_ds,batch_size=PRED_BATCH_SIZE,shuffle=False)


In [None]:
#
# Build and compile the EIformer model. The hyperparameters presumably have to
# match the ones used to train the weights loaded in the next cell.
LEARNING_RATE=0.001

eiformer_hparams = {
    'num_layers': 1,
    'd_model': 32,
    'num_heads': 1,
    'dff': 128,
    'vocab_size': 3**7,   # NOTE(review): looks like 3 states per mark x 7 marks — confirm against vocab.txt
    'len_motif': 12,
    'dropout_rate': 0.1,
}
epiformer = EIformer(**eiformer_hparams)

# Call the model once on symbolic inputs so its variables exist before
# load_weights: DNA input of shape (200, 4), epigenome token ids of shape (200,).
dna_in, epi_in = tf.keras.Input((200,4)), tf.keras.Input((200,))
epiformer((dna_in, epi_in))

# Loss and metric both treat the model outputs as logits: from_logits=True, and
# accuracy threshold 0. (logit 0 corresponds to probability 0.5).
epiformer.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.)],
)

In [None]:
#
# Restore the trained weights for `tranf` and score every promoter bin.
file_weight='models/weights_'+tranf+'.h5'

epiformer.load_weights(file_weight)
# The model is compiled with BinaryCrossentropy(from_logits=True) and
# BinaryAccuracy(threshold=0.), so predict() returns logits. Convert them to
# probabilities so `pred_proba` (and the score column written to the BED) is in [0, 1].
pred_logits = epiformer.predict(x=peak_batches)
pred_proba = tf.sigmoid(pred_logits).numpy()

In [None]:
#
# Attach one prediction score per promoter bin to the bin BED file and save it.
path_bins='data/genome/windows_mm10/promoter'
file_peak_ori=path_bins+'/'+'gencode_vM25_gene_promoter_protein_coding_bins.bed'
print(file_peak_ori)
peak_ori=pd.read_csv(file_peak_ori,sep='\t',header=None)

# predict() output may have shape (n, 1); flatten to one score per row.
# Scores are matched to BED rows purely by position. This assumes the TFRecord
# kept every bin in BED order. NOTE(review): confirm data_io.select_sample
# neither drops nor reorders samples.
scores=np.ravel(pred_proba)
assert len(scores)==len(peak_ori), f'{len(scores)} predictions for {len(peak_ori)} BED rows'
peak_ori[3]=scores  # NOTE(review): overwrites column 3 if the BED already has one

file_peak_pred=path_cell+'/'+'pred_sites_'+tranf+'.bed'
peak_ori.to_csv(file_peak_pred,sep='\t',index=False,header=False)
