In [1]:
import warnings
import os
warnings.filterwarnings("ignore") 

# Suppress TensorFlow logging messages
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import json
import os
import pandas as pd
import pysam
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

import tensorflow as tf
from basenji import seqnn, stream, dataset, dna_io

from akita_utils.numpy_utils import ut_dense
from akita_utils.dna_utils import dna_1hot

In [2]:
# Load the dataset statistics (sequence/target geometry) for the chosen genome.
mouse_dir = "/project/fudenber_735/tensorflow_models/akita/v2/data/mm10/"
data_stats_file = f"{mouse_dir}/statistics.json"

# To use the human dataset instead, switch to the hg38 directory:
# human_dir = "/project/fudenber_735/tensorflow_models/akita/v2/data/hg38/"
# data_stats_file = f"{human_dir}/statistics.json"

with open(data_stats_file) as data_stats_open:
    data_stats = json.load(data_stats_open)

# Values read straight from the statistics file.
seq_length = data_stats["seq_length"]
target_length = data_stats["target_length"]
hic_diags = data_stats["diagonal_offset"]

# Quantities converted from base pairs to bins via the pooling width.
target_crop = data_stats["crop_bp"] // data_stats["pool_width"]
target_length1 = seq_length // data_stats["pool_width"]

In [3]:
# Prediction settings
batch_size = 8
head_index = 1  # mouse!
# head_index = 0  # human!
rc = False

# Shifts are given as a comma-separated string and parsed into a list of ints.
shifts = "0"
shifts = [int(token) for token in shifts.split(",")]

# Directory holding the trained models
models_dir = "/project/fudenber_735/tensorflow_models/akita/v2/models"

In [4]:
# Open the genome FASTA used to fetch the input sequences.
genome_fasta = "/project/fudenber_735/genomes/mm10/mm10.fa"
# Use pysam.FastaFile: the lowercase-f `Fastafile` spelling is only a
# deprecated backward-compatibility alias of this class.
genome_open = pysam.FastaFile(genome_fasta)

# Pick which trained model (fold) to load.
model_index = 0
# model_index = 5

In [5]:
# Parameter file and best checkpoint for the selected model and head.
params_file = f"{models_dir}/f{model_index}c0/train/params.json"
model_file = f"{models_dir}/f{model_index}c0/train/model{head_index}_best.h5"

# Read the model and training parameters.
with open(params_file) as params_open:
    params = json.load(params_open)
params_train = params["train"]
params_model = params["model"]

# Only fall back to the training batch size when none was chosen earlier.
if batch_size is None:
    batch_size = params_train["batch_size"]

# Build the network, restore the weights for the chosen head, and set up the
# (reverse-complement, shifts) ensemble.
seqnn_model = seqnn.SeqNN(params_model)
seqnn_model.restore(model_file, head_i=head_index)
seqnn_model.build_ensemble(rc, shifts)

# Note: this replaces the seq_length previously read from statistics.json.
seq_length = int(params_model["seq_length"])

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 sequence (InputLayer)       [(None, 1310720, 4)]         0         []                            
                                                                                                  
 stochastic_reverse_complem  ((None, 1310720, 4),         0         ['sequence[0][0]']            
 ent (StochasticReverseComp   ())                                                                 
 lement)                                                                                          
                                                                                                  
 stochastic_shift (Stochast  (None, 1310720, 4)           0         ['stochastic_reverse_complemen
 icShift)                                                           t[0][0]']               

In [6]:
def simple_seqs_gen(
    seq_coords_df,
    genome_open
):
    """Yield one encoded wild-type sequence per row of ``seq_coords_df``.

    For every row, the interval given by the ``chrom``, ``window_start`` and
    ``window_end`` fields is fetched from ``genome_open``, upper-cased, and
    passed through ``dna_1hot``. Sequences are produced lazily, in row order.
    """
    for _, coords in seq_coords_df.iterrows():
        window_dna = genome_open.fetch(
            coords.chrom, coords.window_start, coords.window_end
        )
        yield dna_1hot(window_dna.upper())

In [7]:
# Path of the boundary table to score; the "explained" set is currently disabled.
# explained_path = "./all_explained_boundaries.tsv"
unexplained_path = "./all_unexplained_boundaries.tsv"

In [8]:
# Load the tab-separated boundary table. simple_seqs_gen later reads the
# chrom, window_start and window_end fields of each row.
# explained_tads = pd.read_csv(explained_path, sep="\t", index_col=None)
unexplained_tads = pd.read_csv(unexplained_path, sep="\t", index_col=None)

In [None]:
# explained_tads

In [10]:
# Prediction stream over all boundary windows: sequences come lazily from the
# simple_seqs_gen generator and are fed to the model with the chosen batch_size.
# Individual predictions are retrieved below by integer index.
preds_stream = stream.PredStreamGen(
        seqnn_model,
        simple_seqs_gen(unexplained_tads, genome_open),
        batch_size,
    )

In [11]:
# One prediction is requested per row of the boundary table.
num_preds = len(unexplained_tads)

In [12]:
from akita_utils.stats_utils import slide_diagonal_insulation, _extract_centered_window

In [None]:
# OFF-16

In [None]:
# crop_width=29 since TAD boundary is 5 bins + 12 bins at each side

In [13]:
# NOTE(review): scratch arithmetic — presumably the span in bp covered by the
# 16-bin insulation window (16 bins * 2 * 2048 bp per bin); confirm the bin size.
16 * 2 * 2048

65536

In [14]:
def calculate_min_insulation(target_map, window=16, crop_around_center=True, crop_width=3):
    """Return the minimum insulation score of a contact map, ignoring NaNs.

    Parameters
    ----------
    target_map
        Contact map handed to ``slide_diagonal_insulation`` (presumably a
        square 2-D array -- confirm against akita_utils).
    window : int
        Window size forwarded to ``slide_diagonal_insulation`` and, when
        cropping, to ``_extract_centered_window``.
    crop_around_center : bool
        If True, only a centered window of the insulation scores is searched
        for the minimum; otherwise all scores are considered.
    crop_width : int
        Width forwarded to ``_extract_centered_window``. This argument used to
        be ignored (a width of 3 was hard-coded); it is now honored, and its
        default is 3 so that calls relying on the default return the same
        value as before.

    Returns
    -------
    The ``np.nanmin`` of the (optionally cropped) insulation scores.
    """
    insulation_scores = slide_diagonal_insulation(target_map, window)

    if crop_around_center:
        insulation_scores = _extract_centered_window(
            insulation_scores, window=window, width=crop_width
        )

    return np.nanmin(insulation_scores)

In [15]:
# Minimum insulation score for every predicted map, in boundary-table order.
ins_min = []

for pred_index in range(num_preds):

    print("pred_index= ", pred_index)
    permuted_preds_matrix = preds_stream[pred_index]

    # Named `pred_map` rather than `map` so the builtin map() is not shadowed
    # in the notebook's global namespace for every later cell.
    # NOTE(review): diagonal_offset is hard-coded to 2 here, while the value
    # from statistics.json is available as `hic_diags` -- confirm they agree.
    pred_map = ut_dense(permuted_preds_matrix, diagonal_offset=2)

    # Average over axis 2 of the dense map (presumably the target axis --
    # confirm) before computing the insulation score.
    average_map = pred_map.mean(axis=2)
    ins = calculate_min_insulation(average_map)
    ins_min.append(ins)

pred_index=  0
pred_index=  1
pred_index=  2
pred_index=  3
pred_index=  4
pred_index=  5
pred_index=  6
pred_index=  7
pred_index=  8
pred_index=  9
pred_index=  10
pred_index=  11
pred_index=  12
pred_index=  13
pred_index=  14
pred_index=  15
pred_index=  16
pred_index=  17
pred_index=  18
pred_index=  19
pred_index=  20
pred_index=  21
pred_index=  22
pred_index=  23
pred_index=  24
pred_index=  25
pred_index=  26
pred_index=  27
pred_index=  28
pred_index=  29
pred_index=  30
pred_index=  31
pred_index=  32
pred_index=  33
pred_index=  34
pred_index=  35
pred_index=  36
pred_index=  37
pred_index=  38
pred_index=  39
pred_index=  40
pred_index=  41
pred_index=  42
pred_index=  43
pred_index=  44
pred_index=  45
pred_index=  46
pred_index=  47
pred_index=  48
pred_index=  49
pred_index=  50
pred_index=  51
pred_index=  52
pred_index=  53
pred_index=  54
pred_index=  55
pred_index=  56
pred_index=  57
pred_index=  58
pred_index=  59
pred_index=  60
pred_index=  61
pred_index=  62
pr

In [16]:
# Attach the per-boundary minimum insulation scores as a new column
# (this modifies unexplained_tads in place).
unexplained_tads["ins16_min"] = ins_min

In [17]:
# Display the table, now including the ins16_min column.
unexplained_tads

Unnamed: 0,SCD_h1_m0_t0,SCD_h1_m0_t1,SCD_h1_m0_t2,SCD_h1_m0_t3,SCD_h1_m0_t4,SCD_h1_m0_t5,chrom,end,rel_disruption_end,rel_disruption_start,start,type,window_end,window_start,miss_per,ins16_min
0,12.625,13.9800,9.8400,9.760,10.5600,8.8000,chr1,4410000,667648,665600,4400000,up4,5061504,3750784,9.879583,-0.204339
1,0.717,0.8345,0.9307,0.885,0.8260,0.6606,chr1,4410000,620544,618496,4400000,down15,5061504,3750784,9.879583,-0.204339
2,1.633,1.6190,1.6640,1.480,1.4375,1.1750,chr1,5160000,622592,620544,5150000,down14,5811504,4500784,14.626723,-0.233104
3,1.366,1.4660,1.3620,1.321,1.3210,1.1160,chr18,73180000,620544,618496,73170000,down15,73831504,72520784,16.061542,-0.285407
4,12.480,14.4100,14.8400,15.160,13.9500,11.8360,chr1,6200000,651264,649216,6190000,tad1,6851504,5540784,17.484202,-0.343384
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2258,1.087,0.9897,1.3480,1.361,1.1460,1.0230,chr16,21450000,618496,616448,21440000,down16,22101504,20790784,7.270233,-0.294114
2259,9.890,9.0900,10.0400,10.080,9.6600,8.3100,chr9,121710000,647168,645120,121700000,down2,122361504,121050784,9.879583,-0.271511
2260,8.640,8.5500,8.8800,8.586,8.4700,7.2800,chr9,122360000,651264,649216,122350000,tad1,123011504,121700784,10.986476,-0.177083
2261,1.611,1.3910,2.0600,2.105,1.7330,1.5070,chr7,122190000,618496,616448,122180000,down16,122841504,121530784,2.325502,-0.129382


In [18]:
# NOTE: this writes to the same relative file name that was read through
# `unexplained_path`, so the input TSV is overwritten with the table that
# now carries the extra ins16_min column.
unexplained_tads.to_csv("all_unexplained_boundaries.tsv", sep="\t", index=False)

In [None]:
# all_data.shape

In [None]:
# np.save("/scratch1/smaruj/test_mouse_fold0_AkitaV2/combined_pred_matrices.npy", all_data)

In [None]:
# Close the genome FASTA handle opened earlier in the notebook.
genome_open.close()