In [3]:
import os, h5py
os.environ['CUDA_VISIBLE_DEVICES']='1'
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import logomaker
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from tqdm import tqdm

from cremerl import utils, model_zoo, shuffle

import shuffle_test

#import gymnasium as gym

import logging

# Set the logging level to WARNING
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

In [4]:
expt_name = 'DeepSTARR'

# Load the HDF5 dataset for this experiment; `data_module.x_test` is used below.
data_path = '../../data/'
filepath = os.path.join(data_path, f'{expt_name}_data.h5')
data_module = utils.H5DataModule(
    filepath,
    batch_size=100,
    lower_case=False,
    transpose=False,
)


In [5]:
# Two-output DeepSTARR network wrapped in the project's Lightning module,
# trained with MSE and the settings returned by utils.configure_optimizer.
deepstarr2 = model_zoo.deepstarr(2)
loss = torch.nn.MSELoss()
optimizer_dict = utils.configure_optimizer(
    deepstarr2,
    lr=0.001,
    weight_decay=1e-6,
    decay_factor=0.1,
    patience=5,
    monitor='val_loss',
)
standard_cnn = model_zoo.DeepSTARR(deepstarr2, criterion=loss, optimizer=optimizer_dict)

# Restore weights from 'DeepSTARR_standard.ckpt' (presumably the best-validation checkpoint).
standard_cnn = utils.load_model_from_checkpoint(standard_cnn, 'DeepSTARR_standard.ckpt')

# Smoke test: predict a single test sequence (index 100, kept as a batch of one).
pred = utils.get_predictions(
    standard_cnn,
    data_module.x_test[np.newaxis, 100],
    batch_size=100,
)

2023-08-07 10:46:00.927655: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

In [44]:
def get_batch(x, tile_ranges, trials):
    """Build a batch of tile-swap variants of ``x`` for ``trials`` shuffles.

    For every trial the batch receives ``x``, a shuffled copy produced by
    ``shuffle_test.dinuc_shuffle`` (applied to the transposed array), and then,
    for each tile in ``tile_ranges``, the pair
    (``x`` with that tile taken from the shuffle, the shuffle with that tile
    taken from ``x``).

    NOTE(review): a later cell redefines ``get_batch`` with four arguments;
    whichever definition runs last is the one callers see.
    """
    rows = []
    for _ in range(trials):
        shuffled = (shuffle_test.dinuc_shuffle(x.copy().transpose())).transpose()
        rows += [x, shuffled]
        for tile_range in tile_ranges:
            lo, hi = tile_range[0], tile_range[1]
            x_swapped = x.copy()
            x_swapped[:, lo:hi] = shuffled[:, lo:hi]
            shuffled_swapped = shuffled.copy()
            shuffled_swapped[:, lo:hi] = x[:, lo:hi]
            rows += [x_swapped, shuffled_swapped]
    return np.array(rows)


def find_max(pred, trials):
    """Average tile-swap scores over shuffle trials and locate the best tile.

    ``pred`` holds model outputs for a batch built by ``get_batch``: ``trials``
    equal-sized consecutive blocks, each laid out as
    ``[x, x_mut, ori_0, mut_0, ..., ori_{T-1}, mut_{T-1}]``.
    Within a block, candidate ``k`` is scored against the unswapped pair with
    ``score1 = pred(x) - pred(ori)`` and ``score2 = pred(mut) - pred(x_mut)``.
    Candidate 0 is the unswapped pair itself (always 0); candidate ``k >= 1``
    is tile ``k - 1``. Every tile, including the last, is scored.

    Returns (four values, as ``greedy_search_ori`` unpacks them)
    -------
    total_score_sep : np.ndarray
        Trial-averaged ``score1 + score2`` per candidate (output axes kept).
    best_score : float
        Largest trial-averaged selection score.
    best_index : int
        Candidate index of ``best_score`` (tile ``best_index - 1``).
    block_ind : int
        Trial whose own selection score at ``best_index`` was highest.
    """
    pred = np.asarray(pred)
    block_size = pred.shape[0] // trials
    n_candidates = block_size // 2
    blocks = pred[: trials * block_size]
    blocks = blocks.reshape((trials, block_size) + blocks.shape[1:])

    # Even slots hold x / ori_k, odd slots hold x_mut / mut_k; take every pair
    # so the final tile is not skipped.
    firsts = blocks[:, 0:2 * n_candidates:2]
    seconds = blocks[:, 1:2 * n_candidates:2]
    score1 = blocks[:, :1] - firsts
    score2 = seconds - blocks[:, 1:2]

    # Selection score: score1 summed over the model outputs. score2 only feeds
    # total_score_sep. NOTE(review): confirm whether score1 + score2 was intended.
    per_trial = score1.reshape(trials, n_candidates, -1).sum(axis=2).astype(np.float64)
    final = per_trial.mean(axis=0)

    best_index = int(np.argmax(final))
    block_ind = int(np.argmax(per_trial[:, best_index]))
    total_score_sep = (score1 + score2).astype(np.float64).mean(axis=0)
    return total_score_sep, final.max(), best_index, block_ind


def get_swap_greedy(x, x_mut, tile_ranges):
    """Exchange every listed tile between ``x`` and ``x_mut``.

    Returns new arrays ``(x with the tiles taken from x_mut, x_mut with the
    tiles taken from x)``; tiles index axis 1. Inputs are not modified.
    """
    swapped_x = x.copy()
    swapped_mut = x_mut.copy()
    for tile in tile_ranges:
        lo, hi = tile[0], tile[1]
        swapped_x[:, lo:hi] = x_mut[:, lo:hi]
        swapped_mut[:, lo:hi] = x[:, lo:hi]
    return swapped_x, swapped_mut

def get_score(pred):
    """Combined swap score for a prediction ordered ``[x, x_mut, ori, mut]``.

    Returns the first component of ``(pred[0] - pred[2]) + (pred[3] - pred[1])``
    and the full combined vector.
    """
    combined = (pred[0] - pred[2]) + (pred[3] - pred[1])
    return combined[0], combined

def generate_tile_ranges(sequence_length, window_size, stride):
    """Return ``[start, end]`` windows covering positions ``0..sequence_length``.

    Windows of ``window_size`` positions start every ``stride`` positions from
    0, and the final window always ends exactly at ``sequence_length``.
    Windows that would start at or past the end (possible when
    ``stride > window_size``) are dropped instead of producing empty tiles.
    A sequence shorter than one window yields ``[[0, sequence_length]]``, and
    a non-positive ``sequence_length`` yields ``[]``.

    Example: ``generate_tile_ranges(12, 5, 5) -> [[0, 5], [5, 10], [10, 12]]``.
    """
    if sequence_length <= 0:
        return []

    # max(..., 1) guarantees the window starting at 0 even for short sequences.
    stop = max(sequence_length - window_size + stride, 1)
    ranges = [
        [int(start), int(min(start + window_size, sequence_length))]
        for start in np.arange(0, stop, stride)
        if start < sequence_length
    ]
    # Stretch the last window to the end of the sequence.
    ranges[-1][1] = int(sequence_length)
    return ranges

In [46]:
def find_max(pred, trials):
    """Average tile-swap scores over shuffle trials and locate the best tile.

    ``pred`` holds model outputs for a batch built by ``get_batch``: ``trials``
    equal-sized consecutive blocks, each laid out as
    ``[x, x_mut, ori_0, mut_0, ..., ori_{T-1}, mut_{T-1}]``.
    Within a block, candidate ``k`` is scored against the unswapped pair with
    ``score1 = pred(x) - pred(ori)`` and ``score2 = pred(mut) - pred(x_mut)``.
    Candidate 0 is the unswapped pair itself (always 0); candidate ``k >= 1``
    is tile ``k - 1``. Every tile, including the last, is scored.

    Returns (four values, as ``greedy_search_ori`` unpacks them)
    -------
    total_score_sep : np.ndarray
        Trial-averaged ``score1 + score2`` per candidate (output axes kept).
    best_score : float
        Largest trial-averaged selection score.
    best_index : int
        Candidate index of ``best_score`` (tile ``best_index - 1``).
    block_ind : int
        Trial whose own selection score at ``best_index`` was highest.
    """
    pred = np.asarray(pred)
    block_size = pred.shape[0] // trials
    n_candidates = block_size // 2
    blocks = pred[: trials * block_size]
    blocks = blocks.reshape((trials, block_size) + blocks.shape[1:])

    # Even slots hold x / ori_k, odd slots hold x_mut / mut_k; take every pair
    # so the final tile is not skipped.
    firsts = blocks[:, 0:2 * n_candidates:2]
    seconds = blocks[:, 1:2 * n_candidates:2]
    score1 = blocks[:, :1] - firsts
    score2 = seconds - blocks[:, 1:2]

    # Selection score: score1 summed over the model outputs. score2 only feeds
    # total_score_sep. NOTE(review): confirm whether score1 + score2 was intended.
    per_trial = score1.reshape(trials, n_candidates, -1).sum(axis=2).astype(np.float64)
    final = per_trial.mean(axis=0)

    best_index = int(np.argmax(final))
    block_ind = int(np.argmax(per_trial[:, best_index]))
    total_score_sep = (score1 + score2).astype(np.float64).mean(axis=0)
    return total_score_sep, final.max(), best_index, block_ind

In [45]:
def get_batch(x, tile_ranges_new, tile_ranges_ori, trials):
    """Build the model-input batch for one greedy-search step.

    For each of ``trials`` shuffles ``x_mut`` of ``x`` (from
    ``shuffle.dinuc_shuffle``), the batch receives ``x`` and ``x_mut``. Then,
    for every candidate tile in ``tile_ranges_new``, it receives the pair:

    - ``x`` with the already-accepted ``tile_ranges_ori`` and the candidate
      tile taken from ``x_mut``;
    - ``x_mut`` with the same tiles taken from ``x``.

    Per trial the layout is ``[x, x_mut, ori_0, mut_0, ...]``, which is what
    ``find_max`` expects. Tiles index axis 1; ``x`` itself is not modified.
    """
    test_batch = []
    for _ in range(trials):
        x_mut = shuffle.dinuc_shuffle(x.copy())
        test_batch.append(x)
        test_batch.append(x_mut)

        # Accepted swaps are the same for every candidate tile in this trial,
        # so apply them once and copy per candidate.
        base_ori, base_mut = get_swap_greedy(x, x_mut, tile_ranges_ori)

        for tile_range in tile_ranges_new:
            lo, hi = tile_range[0], tile_range[1]
            ori = base_ori.copy()
            mut = base_mut.copy()
            ori[:, lo:hi] = x_mut[:, lo:hi]
            mut[:, lo:hi] = x[:, lo:hi]
            test_batch.append(ori)
            test_batch.append(mut)

    return np.array(test_batch)


def find_max(pred, trials):
    """Average tile-swap scores over shuffle trials and locate the best tile.

    ``pred`` holds model outputs for a batch built by ``get_batch``: ``trials``
    equal-sized consecutive blocks, each laid out as
    ``[x, x_mut, ori_0, mut_0, ..., ori_{T-1}, mut_{T-1}]``.
    Within a block, candidate ``k`` is scored against the unswapped pair with
    ``score1 = pred(x) - pred(ori)`` and ``score2 = pred(mut) - pred(x_mut)``.
    Candidate 0 is the unswapped pair itself (always 0); candidate ``k >= 1``
    is tile ``k - 1``. Every tile, including the last, is scored.

    Returns (four values, as ``greedy_search_ori`` unpacks them)
    -------
    total_score_sep : np.ndarray
        Trial-averaged ``score1 + score2`` per candidate (output axes kept).
    best_score : float
        Largest trial-averaged selection score.
    best_index : int
        Candidate index of ``best_score`` (tile ``best_index - 1``).
    block_ind : int
        Trial whose own selection score at ``best_index`` was highest.
    """
    pred = np.asarray(pred)
    block_size = pred.shape[0] // trials
    n_candidates = block_size // 2
    blocks = pred[: trials * block_size]
    blocks = blocks.reshape((trials, block_size) + blocks.shape[1:])

    # Even slots hold x / ori_k, odd slots hold x_mut / mut_k; take every pair
    # so the final tile is not skipped.
    firsts = blocks[:, 0:2 * n_candidates:2]
    seconds = blocks[:, 1:2 * n_candidates:2]
    score1 = blocks[:, :1] - firsts
    score2 = seconds - blocks[:, 1:2]

    # Selection score: score1 summed over the model outputs. score2 only feeds
    # total_score_sep. NOTE(review): confirm whether score1 + score2 was intended.
    per_trial = score1.reshape(trials, n_candidates, -1).sum(axis=2).astype(np.float64)
    final = per_trial.mean(axis=0)

    best_index = int(np.argmax(final))
    block_ind = int(np.argmax(per_trial[:, best_index]))
    total_score_sep = (score1 + score2).astype(np.float64).mean(axis=0)
    return total_score_sep, final.max(), best_index, block_ind


def get_swap_greedy(x, x_mut, tile_ranges):
    """Return copies of ``x`` and ``x_mut`` with the given tiles exchanged.

    Each entry of ``tile_ranges`` is a ``[start, end)`` span on axis 1. The
    first result takes those spans from ``x_mut``; the second takes them from
    ``x``. Neither input is modified.
    """
    left, right = x.copy(), x_mut.copy()
    for span in tile_ranges:
        start, stop = span[0], span[1]
        left[:, start:stop] = x_mut[:, start:stop]
        right[:, start:stop] = x[:, start:stop]
    return left, right

def get_score(pred):
    """Score a 4-row prediction ``[x, x_mut, ori, mut]``.

    The combined score is ``(pred[0] - pred[2]) + (pred[3] - pred[1])``;
    returns its first component followed by the whole vector.
    """
    drop_from_original = pred[0] - pred[2]
    gain_on_shuffle = pred[3] - pred[1]
    total = drop_from_original + gain_on_shuffle
    return total[0], total

def generate_tile_ranges(sequence_length, window_size, stride):
    """Return ``[start, end]`` windows covering positions ``0..sequence_length``.

    Windows of ``window_size`` positions start every ``stride`` positions from
    0, and the final window always ends exactly at ``sequence_length``.
    Windows that would start at or past the end (possible when
    ``stride > window_size``) are dropped instead of producing empty tiles.
    A sequence shorter than one window yields ``[[0, sequence_length]]``, and
    a non-positive ``sequence_length`` yields ``[]``.

    Example: ``generate_tile_ranges(12, 5, 5) -> [[0, 5], [5, 10], [10, 12]]``.
    """
    if sequence_length <= 0:
        return []

    # max(..., 1) guarantees the window starting at 0 even for short sequences.
    stop = max(sequence_length - window_size + stride, 1)
    ranges = [
        [int(start), int(min(start + window_size, sequence_length))]
        for start in np.arange(0, stop, stride)
        if start < sequence_length
    ]
    # Stretch the last window to the end of the sequence.
    ranges[-1][1] = int(sequence_length)
    return ranges

In [31]:
def greedy_search_ori(seq, threshold=0, trials=1, max_iters=40, tile_size=5):
    """Greedily select tiles of ``seq`` whose swap with a shuffle raises the score.

    Each step scores every remaining ``tile_size``-wide tile (via ``get_batch``
    and ``find_max``, averaged over ``trials`` shuffles) on top of the tiles
    already accepted. It accepts the best tile while the trial-averaged score
    keeps strictly increasing, for at most ``max_iters`` steps. The accepted
    tiles are then swapped between ``seq`` and one fresh shuffle ``b`` and
    scored with ``get_score``. Predictions come from the module-level
    ``standard_cnn``.

    Parameters
    ----------
    seq : array
        Input sequence; positions are taken along axis 1 (``seq.shape[1]``).
    threshold : float
        NOTE(review): not used by the search; kept so existing calls still work.
    trials : int
        Number of shuffles averaged per greedy step.
    max_iters : int
        Maximum number of greedy steps (default 40).
    tile_size : int
        Width and stride of candidate tiles (default 5).

    Returns
    -------
    b, comb_best, final_score, score_list, total_score
        The shuffle used for the final check; sorted indices of accepted tiles;
        ``get_score``'s scalar and vector results; and the ``total_score_sep``
        array from the last ``find_max`` call (``None`` if no step ran).
    """
    a = seq
    b = (shuffle_test.dinuc_shuffle(a.copy().transpose())).transpose()
    tile_ranges = generate_tile_ranges(a.shape[1], tile_size, tile_size)
    # Original tile index of each remaining candidate; derived from the tile
    # list (not hard-coded) so it stays aligned for any sequence length.
    indices_list = list(range(len(tile_ranges)))
    trainer = pl.Trainer(accelerator='gpu', devices='1', logger=None, enable_progress_bar=False)

    max_score = 0
    comb_best = []
    tile_ranges_ori = []
    total_score = None
    for _ in range(max_iters):
        batch = get_batch(a, tile_ranges, tile_ranges_ori, trials=trials)
        dataloader = torch.utils.data.DataLoader(batch, batch_size=100, shuffle=False)
        pred = np.concatenate(trainer.predict(standard_cnn, dataloaders=dataloader))
        total_score, score, best, _ = find_max(pred, trials=trials)
        if score > max_score:
            max_score = score
            # find_max's candidate 0 is the no-swap baseline, so candidate
            # ``best`` is remaining tile ``best - 1``.
            comb_best.append(indices_list.pop(best - 1))
            tile_ranges_ori.append(tile_ranges.pop(best - 1))
        else:
            break

    comb_best.sort()
    all_tiles = generate_tile_ranges(a.shape[1], tile_size, tile_size)
    act = [all_tiles[i] for i in comb_best]
    ori, mut = get_swap_greedy(a, b, act)

    batch = np.array([a, b, ori, mut])
    dataloader = torch.utils.data.DataLoader(batch, batch_size=100, shuffle=False)
    pred2 = np.concatenate(trainer.predict(standard_cnn, dataloaders=dataloader))
    final_score, score_list = get_score(pred2)

    return b, comb_best, final_score, score_list, total_score

In [47]:
# Repeat the greedy search on test sequence 1 to gauge run-to-run variability.
seq_1 = data_module.x_test[1].numpy()
for i in range(10):
    _, out1, out2, _, total_score = greedy_search_ori(seq_1, threshold=1.4, trials=100)
    print(f"{out1} \t {out2}")

[7, 8, 17, 18, 22, 24, 27, 29, 30, 36, 38, 39, 40, 46] 	 5.159029483795166
[6, 7, 8, 16, 17, 18, 22, 24, 27, 28, 29, 30, 36, 38, 39, 40, 46] 	 5.143600940704346
[19, 24, 27, 30, 36, 40] 	 1.1142568588256836
[17, 19, 22, 24, 27, 29, 30, 36, 38, 39, 40, 46] 	 6.155328750610352
[7, 8, 17, 18, 22, 24, 27, 29, 30, 36, 38, 39, 46] 	 3.663485527038574
[7, 8, 17, 18, 22, 24, 27, 29, 30, 36, 38, 39, 40, 46] 	 7.915139675140381
[8, 17, 18, 22, 24, 27, 29, 30, 36, 39, 46] 	 3.7481472492218018
[6, 7, 8, 17, 18, 19, 22, 24, 27, 29, 30, 36, 38, 39, 40, 46] 	 5.147916793823242
[19, 22, 24, 27, 30, 36, 40, 46] 	 3.187816619873047
[18, 22, 24, 27, 29, 30, 36, 38, 39, 40, 46] 	 5.910625457763672


In [38]:
# Run the greedy search 10 times on each of test sequences 3-5.
for j in range(3, 6):
    print(f"Sequence {j}")
    for i in range(10):
        # greedy_search_ori returns five values; unpack all of them.
        _, out1, out2, _, _ = greedy_search_ori(data_module.x_test[j].numpy(), threshold=0.1, trials=200)
        print(f"{out1} \t {out2}")

Sequence 3
[2, 10, 23, 28, 32, 38, 46] 	 1.753103494644165
[2, 10, 23, 28, 32, 38, 46] 	 1.8625035285949707
[2, 10, 23, 28, 32, 38, 42] 	 2.41713285446167
[6, 10, 23, 28, 32, 47] 	 -0.7116690278053284
[10, 23, 26, 28, 32] 	 0.45512598752975464
[2, 10, 23, 28, 32, 38, 42, 47] 	 1.5824801921844482
[2, 10, 23, 28, 32, 38, 47] 	 0.9569301009178162
[2, 10, 20, 28, 32, 34] 	 -0.8081310987472534
[6, 10, 23, 28, 32, 47] 	 0.22446399927139282
[10, 23, 26, 28, 32, 47] 	 0.9564542770385742
Sequence 4
[17, 21, 33, 37, 41, 47] 	 1.9058430194854736
[10, 17, 37, 47] 	 0.7334312796592712
[17, 21, 33, 41, 47] 	 0.81021648645401
[17, 21, 33, 37, 41, 47] 	 1.7233858108520508
[13, 17, 21, 33, 37, 47] 	 1.164968490600586
[17, 21, 33, 41, 47] 	 1.4922833442687988
[10, 17, 37] 	 1.4800578355789185
[9, 17, 23, 33] 	 1.508580207824707
[10, 17, 37, 41, 47] 	 0.7804054021835327
[9, 10, 17, 37, 47] 	 1.347149133682251
Sequence 5
[20, 28, 34] 	 2.5182693004608154
[20, 28, 31] 	 1.7822425365447998
[10, 20, 28, 31, 

In [35]:
def find_common_elements(input_list):
    """Return the elements that appear in every sub-list of ``input_list``.

    Duplicates are removed. The result follows the order of first appearance
    in ``input_list[0]``, so printed results are deterministic, unlike
    iterating a set. An empty ``input_list`` gives ``[]``.
    """
    if not input_list:
        return []

    common = set(input_list[0]).intersection(*input_list[1:])
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return [item for item in dict.fromkeys(input_list[0]) if item in common]

In [None]:
# For each of the first 10 test sequences, run the search 10 times with
# 2000 shuffles per step and report the tiles chosen in every run.
for j in range(10):
    com = []
    for i in range(10):
        # greedy_search_ori returns five values; unpack all of them.
        _, out1, out2, _, _ = greedy_search_ori(data_module.x_test[j].numpy(), threshold=0.05, trials=2000)
        com.append(out1)
        print(f"{out1} \t {out2}")
    print(f"The common indices among the 10 are: {find_common_elements(com)}")

In [48]:
# For each of the first 10 test sequences, run the search 10 times
# (200 shuffles per step) and report the tiles chosen in every run.
for j in range(10):
    seq_j = data_module.x_test[j].numpy()
    com = []
    for i in range(10):
        _, out1, out2, _, total_score = greedy_search_ori(seq_j, threshold=0.05, trials=200)
        com.append(out1)
        print(f"{out1} \t {out2}")
    print(f"The common indices among the 10 are: {find_common_elements(com)}")

[14, 19, 23, 25, 26, 33, 38, 41, 42] 	 3.2555878162384033
[12, 14, 19, 20, 23, 25, 26, 33, 38, 39, 41, 42, 45, 46] 	 2.8549304008483887
[14, 19, 23, 25, 26, 34, 38, 41, 42] 	 3.373605966567993
[12, 14, 19, 21, 23, 25, 26, 33, 38, 41, 42, 46] 	 2.719654083251953
[14, 19, 23, 25, 26, 33, 37, 41, 42] 	 5.284907341003418
[12, 14, 19, 20, 23, 25, 26, 33, 37, 38, 39, 41, 42, 45, 46] 	 4.650846481323242
[12, 14, 19, 21, 22, 23, 24, 25, 26, 31, 33, 38, 39, 41, 42, 46, 48] 	 5.464040756225586
[12, 14, 19, 21, 23, 24, 25, 26, 33, 38, 41, 42, 46] 	 4.385184288024902
[12, 14, 19, 21, 23, 25, 26, 33, 38, 39, 41, 42, 46] 	 4.349886894226074
[19, 23, 25, 26, 33, 38, 41, 42] 	 2.7852509021759033
The common indices among the 10 are: [41, 42, 19, 23, 25, 26]
[7, 8, 16, 17, 18, 22, 24, 27, 29, 30, 36, 38, 39, 40, 46] 	 5.458824157714844
[7, 8, 17, 19, 22, 24, 27, 29, 30, 36, 38, 39, 40, 46] 	 5.713505744934082
[19, 24, 27, 29, 30, 36, 40, 46] 	 3.5190682411193848
[7, 8, 16, 17, 18, 22, 24, 27, 29, 30, 36

In [None]:
# 10 repeats x first 10 test sequences, 200 shuffles per greedy step;
# prints each run and the tile indices shared by all 10 runs.
for j in range(10):
    seq_j = data_module.x_test[j].numpy()
    com = []
    for i in range(10):
        result = greedy_search_ori(seq_j, threshold=0.05, trials=200)
        _, out1, out2, _, total_score = result
        com.append(out1)
        print(f"{out1} \t {out2}")
    print(f"The common indices among the 10 are: {find_common_elements(com)}")

In [43]:
# Per test sequence (first 10): 10 greedy runs with 2000 shuffles per step,
# then the tile indices chosen in every run.
for j in range(10):
    print(f"Sequence {j}")
    seq_j = data_module.x_test[j].numpy()
    com = []
    for i in range(10):
        _, out1, out2, _, total_score = greedy_search_ori(seq_j, threshold=0.05, trials=2000)
        com.append(out1)
        print(f"{out1} \t {out2}")
    print(f"The common indices among the 10 are: {find_common_elements(com)}")

Sequence 0
[12, 14, 19, 21, 22, 23, 25, 26, 33, 38, 39, 41, 42, 45, 46] 	 4.902645111083984
[12, 14, 19, 20, 21, 23, 24, 25, 26, 33, 38, 39, 41, 42, 45, 46] 	 3.920459032058716
[12, 14, 19, 21, 23, 25, 26, 33, 38, 39, 41, 42, 45, 46] 	 3.6395156383514404
[12, 14, 19, 21, 23, 25, 26, 33, 38, 39, 41, 42, 45, 46, 48] 	 3.931053638458252
[12, 14, 19, 21, 23, 25, 26, 33, 38, 39, 41, 42, 45, 46] 	 2.5523033142089844
[12, 14, 19, 20, 21, 23, 25, 26, 33, 38, 39, 41, 42, 45, 46] 	 3.558357000350952
[12, 14, 19, 21, 23, 25, 26, 33, 38, 39, 41, 42, 45, 46] 	 4.958101272583008
[12, 14, 19, 21, 23, 25, 26, 33, 38, 39, 41, 42, 45, 46, 48] 	 1.9736292362213135
[12, 14, 19, 20, 21, 23, 25, 26, 33, 38, 39, 41, 42, 45, 46] 	 5.810445308685303
[2, 12, 14, 19, 21, 23, 25, 26, 33, 38, 39, 41, 42, 45, 46, 48] 	 3.291626214981079
The common indices among the 10 are: [33, 38, 39, 41, 42, 12, 45, 14, 46, 19, 21, 23, 25, 26]
Sequence 1
[6, 7, 8, 16, 17, 18, 20, 22, 24, 27, 28, 29, 30, 36, 38, 39, 40, 46] 	 6.36