In [1]:
import warnings
import os
warnings.filterwarnings("ignore") 

# Suppress TensorFlow logging messages
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import json
import os
import pandas as pd
import pysam
import numpy as np
import tensorflow as tf
from basenji import seqnn, stream
from akita_utils.numpy_utils import ut_dense

from helper import simple_seqs_gen, calculate_min_insulation

In [2]:
# Dataset-level parameters, read from the statistics.json stored with the mm10 data.
mouse_dir = "/project/fudenber_735/tensorflow_models/akita/v2/data/mm10/"
data_stats_file = '%s/statistics.json' % mouse_dir

with open(data_stats_file) as stats_handle:
    data_stats = json.load(stats_handle)

# Values copied straight out of the statistics file.
seq_length, target_length, hic_diags = (
    data_stats[key] for key in ('seq_length', 'target_length', 'diagonal_offset')
)

# Derived quantities, both expressed in units of pool_width.
# NOTE(review): presumably pool_width is the bin size in bp, making these bin counts — confirm.
target_crop = data_stats['crop_bp'] // data_stats['pool_width']
target_length1 = seq_length // data_stats['pool_width']

In [3]:
# Prediction settings.
batch_size = 8
head_index = 1  # mouse!

# Ensembling options handed to build_ensemble() when the model is loaded:
# `shifts` is parsed from a comma-separated string into a list of ints.
rc = False
shifts = [int(offset) for offset in "0".split(",")]

# Root directory that holds the trained models.
models_dir = "/project/fudenber_735/tensorflow_models/akita/v2/models"

In [4]:
# Index of the model to load (used to build the model path in the next cell).
model_index = 0

# Handle on the mm10 genome FASTA; it is closed in the final cell of the notebook.
genome_fasta = "/project/fudenber_735/genomes/mm10/mm10.fa"
genome_open = pysam.Fastafile(genome_fasta)

In [5]:
# Both the parameter file and the weights live in <models_dir>/f<model_index>c0/train.
train_dir = models_dir + f"/f{model_index}c0" + "/train"
params_file = train_dir + "/params.json"
model_file = train_dir + f"/model{head_index}_best.h5"

# Parse the saved configuration and split it into its training and model sections.
with open(params_file) as params_handle:
    params = json.load(params_handle)
params_train = params["train"]
params_model = params["model"]

# Fall back to the training batch size only when none was chosen earlier.
if batch_size is None:
    batch_size = params_train["batch_size"]

# Instantiate the network, restore the weights for the selected head,
# then configure ensembling with the rc / shifts settings chosen above.
seqnn_model = seqnn.SeqNN(params_model)
seqnn_model.restore(model_file, head_i=head_index)
seqnn_model.build_ensemble(rc, shifts)

# From here on seq_length is the model's value; it replaces the one read
# from the data statistics file earlier in the notebook.
seq_length = int(params_model["seq_length"])

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 sequence (InputLayer)       [(None, 1310720, 4)]         0         []                            
                                                                                                  
 stochastic_reverse_complem  ((None, 1310720, 4),         0         ['sequence[0][0]']            
 ent (StochasticReverseComp   ())                                                                 
 lement)                                                                                          
                                                                                                  
 stochastic_shift (Stochast  (None, 1310720, 4)           0         ['stochastic_reverse_complemen
 icShift)                                                           t[0][0]']               

In [6]:
# Table of boundaries to score in this run (tab-separated).
# The companion table is "./../results/disruption_resilient_boundaries.tsv";
# to process it instead, load it into `resilient_tads` with the same read_csv call.
sensitive_path = "./../results/disruption_sensitive_boundaries.tsv"
sensitive_tads = pd.read_csv(sensitive_path, index_col=None, sep="\t")

In [7]:
# Sequence generator built from the boundary table and the open genome.
# NOTE(review): what exactly it yields is defined in helper.simple_seqs_gen — confirm there.
sensitive_seqs = simple_seqs_gen(sensitive_tads, genome_open)

# Prediction stream over those sequences, configured with the chosen batch size.
preds_stream = stream.PredStreamGen(seqnn_model, sensitive_seqs, batch_size)

In [8]:
# Number of predictions to pull from the stream: one per row of the boundary table.
num_preds = len(sensitive_tads.index)

In [9]:
# Collect one minimum-insulation value per prediction, in prediction-index
# order, so the list lines up with the rows of `sensitive_tads`.
ins_min = []

for pred_index in range(num_preds):

    print("pred_index= ", pred_index)
    pred_ut = preds_stream[pred_index]

    # Convert the streamed prediction to a dense map. The diagonal offset comes
    # from the data statistics (`hic_diags`) instead of a hard-coded 2, so the
    # conversion stays in sync with the dataset the model was trained on.
    # The result is deliberately NOT named `map`: at notebook top level that
    # would shadow the builtin for every cell executed afterwards.
    contact_map = ut_dense(pred_ut, diagonal_offset=hic_diags)

    # Collapse the third axis before scoring.
    # NOTE(review): axis 2 is presumably the target/track axis — confirm.
    average_map = contact_map.mean(axis=2)
    ins_min.append(calculate_min_insulation(average_map))
                        

pred_index=  0
pred_index=  1
pred_index=  2
pred_index=  3
pred_index=  4
pred_index=  5
pred_index=  6
pred_index=  7
pred_index=  8
pred_index=  9
pred_index=  10
pred_index=  11
pred_index=  12
pred_index=  13
pred_index=  14
pred_index=  15
pred_index=  16
pred_index=  17
pred_index=  18
pred_index=  19
pred_index=  20
pred_index=  21
pred_index=  22
pred_index=  23
pred_index=  24
pred_index=  25
pred_index=  26
pred_index=  27
pred_index=  28
pred_index=  29
pred_index=  30
pred_index=  31
pred_index=  32
pred_index=  33
pred_index=  34
pred_index=  35
pred_index=  36
pred_index=  37
pred_index=  38
pred_index=  39
pred_index=  40
pred_index=  41
pred_index=  42
pred_index=  43
pred_index=  44
pred_index=  45
pred_index=  46
pred_index=  47
pred_index=  48
pred_index=  49
pred_index=  50
pred_index=  51
pred_index=  52
pred_index=  53
pred_index=  54
pred_index=  55
pred_index=  56
pred_index=  57
pred_index=  58
pred_index=  59
pred_index=  60
pred_index=  61
pred_index=  62
pr

In [None]:
# Store the values collected in `ins_min` as the "ins16_min" column of the
# boundary table (modifies `sensitive_tads` in place; an existing column of
# that name is overwritten).
# NOTE(review): "ins16" presumably names the insulation window used by
# calculate_min_insulation — confirm against helper.py.
sensitive_tads["ins16_min"] = ins_min
# resilient_tads["ins16_min"] = ins_min

In [None]:
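# Saving is currently disabled. Note that the sensitive-boundaries output path
# below is the same file the table was read from, so uncommenting that line
# overwrites the input TSV in place (now including the "ins16_min" column).
# sensitive_tads.to_csv("./../results/disruption_sensitive_boundaries.tsv", sep="\t", index=False)
# resilient_tads.to_csv("./../results/disruption_resilient_boundaries.tsv", sep="\t", index=False)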

In [None]:
# Release the genome FASTA handle opened with pysam earlier in the notebook.
genome_open.close()