In [1]:
import os, h5py
os.environ['CUDA_VISIBLE_DEVICES']='1'
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import logomaker
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from tqdm import tqdm

from cremerl import utils, model_zoo, shuffle

import shuffle_test

#import gymnasium as gym

import logging

# Set the logging level to WARNING
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

In [2]:
expt_name = 'DeepSTARR'

# Build the HDF5-backed data module for this experiment; later cells read
# `data_module.x_test` from it.
data_path = '../../data/'
filepath = os.path.join(data_path, f'{expt_name}_data.h5')
data_module = utils.H5DataModule(filepath, batch_size=100, lower_case=False, transpose=False)


In [3]:
# Base network (argument 2 is presumably the number of output heads — TODO confirm),
# wrapped in the Lightning module with an MSE objective.
deepstarr2 = model_zoo.deepstarr(2)
loss = torch.nn.MSELoss()
optimizer_dict = utils.configure_optimizer(
    deepstarr2,
    lr=0.001,
    weight_decay=1e-6,
    decay_factor=0.1,
    patience=5,
    monitor='val_loss',
)

# Restore the weights of the checkpoint with the best validation performance.
standard_cnn = utils.load_model_from_checkpoint(
    model_zoo.DeepSTARR(deepstarr2, criterion=loss, optimizer=optimizer_dict),
    'DeepSTARR_standard.ckpt',
)

# Sanity-check prediction on a single test sequence (index 100).
pred = utils.get_predictions(standard_cnn, data_module.x_test[np.newaxis, 100], batch_size=100)

2023-08-13 09:04:17.652489: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

In [4]:

def get_swap_greedy(x, x_mut, tile_ranges):
    """Exchange every tile in ``tile_ranges`` between ``x`` and ``x_mut``.

    Tiles are ``[start, end)`` slices along the second axis. Neither input is
    modified; the swapped values are always read from the original arrays.

    Returns:
        (ori, mut): ``ori`` is a copy of ``x`` whose tiles come from ``x_mut``;
        ``mut`` is a copy of ``x_mut`` whose tiles come from ``x``.
    """
    swapped_ori = x.copy()
    swapped_mut = x_mut.copy()
    for tile in tile_ranges:
        window = slice(tile[0], tile[1])
        swapped_ori[:, window] = x_mut[:, window]
        swapped_mut[:, window] = x[:, window]
    return swapped_ori, swapped_mut

def get_score(pred):
    """Combined swap score for one 4-row prediction block.

    Assumes rows are ordered like one trial of ``get_batch``: [x, x_mut, ori, mut].
    The score is ``(pred[0] - pred[2]) + (pred[3] - pred[1])``.

    Returns:
        (first, combined): ``combined`` is the per-output score array and
        ``first`` is its element 0.
    """
    combined = (pred[0] - pred[2]) + (pred[3] - pred[1])
    return combined[0], combined

def generate_tile_ranges(sequence_length, window_size, stride):
    """Split ``[0, sequence_length)`` into windows of ``window_size`` every ``stride``.

    The last window is always extended to end at ``sequence_length``, so that
    trailing positions are covered. Windows are ``[start, end]`` lists of ints,
    with ``end`` exclusive (used as slice bounds).

    Args:
        sequence_length: Total length to tile (int).
        window_size: Width of each window (positive int).
        stride: Step between window starts (positive int).

    Returns:
        List of ``[start, end]`` pairs. If the length is 0, the list is empty. If the
        sequence is shorter than one window, the list is a single ``[0, sequence_length]``.

    Raises:
        ValueError: if ``window_size`` or ``stride`` is not positive.
    """
    if window_size <= 0 or stride <= 0:
        raise ValueError("window_size and stride must be positive")
    if sequence_length <= 0:
        return []

    # Starts beyond the end of the sequence would give empty windows (possible when
    # stride > window_size), so drop them.
    starts = [
        s for s in range(0, sequence_length - window_size + stride, stride)
        if s < sequence_length
    ]
    if not starts:
        # Sequence shorter than a window: a single window covers everything.
        starts = [0]

    ranges = [[s, min(s + window_size, sequence_length)] for s in starts]
    # Make sure the tail of the sequence is covered by the last window.
    ranges[-1][1] = sequence_length
    return ranges

In [10]:
def get_batch(x, tile_range, tile_ranges_ori, trials):
    """Build a prediction batch that measures the effect of swapping tiles.

    For each of ``trials`` trials, four rows are emitted, in this order:
      1. ``x`` itself,
      2. ``x_mut`` — a fresh ``shuffle.dinuc_shuffle`` of a copy of ``x``,
      3. ``x`` with all tiles in ``tile_ranges_ori`` plus ``tile_range`` taken from ``x_mut``,
      4. ``x_mut`` with those same tiles taken from ``x``.

    Returns:
        ``np.array`` stacking the ``4 * trials`` rows.
    """
    # Swapping reads from the untouched x / x_mut, so the previously chosen tiles and
    # the new tile can be applied in one pass.
    tiles_to_swap = list(tile_ranges_ori) + [tile_range]
    rows = []
    for _ in range(trials):
        shuffled = shuffle.dinuc_shuffle(x.copy())
        swapped_x, swapped_shuffled = get_swap_greedy(x, shuffled, tiles_to_swap)
        rows.extend([x, shuffled, swapped_x, swapped_shuffled])
    return np.array(rows)


def get_batch_score(pred, trials):
    """Average the per-trial swap scores of a batch produced by ``get_batch``.

    ``get_batch`` emits 4 rows per trial, in the order ``[x, x_mut, ori, mut]``. Here
    ``ori`` is ``x`` with the selected tiles taken from the shuffled ``x_mut``, and
    ``mut`` is ``x_mut`` with those tiles taken from ``x``. For every trial, matching
    ``get_score``:

        score1 = pred(x)   - pred(ori)     # change when the tiles are replaced in x
        score2 = pred(mut) - pred(x_mut)   # change when the tiles are inserted into x_mut

    Args:
        pred: Array of shape ``(4 * n_trials, ...)``; any trailing axes (e.g. model
            outputs) are kept.
        trials: Divisor for the averages (normally the number of trials in ``pred``).

    Returns:
        (final, total_score_sep):
          - ``final`` is the sum of ``score1 + score2`` over trials and outputs,
            divided by ``trials``.
          - ``total_score_sep`` is the per-output sum over trials, divided by ``trials``.

    Raises:
        ValueError: if the number of rows is not a multiple of 4.
    """
    pred = np.asarray(pred)
    if pred.shape[0] % 4 != 0:
        raise ValueError(f"expected 4 rows per trial, got {pred.shape[0]} rows")

    # One block of 4 rows per trial; compare rows within the same trial only.
    blocks = pred.reshape(-1, 4, *pred.shape[1:])
    x_pred, x_mut_pred, ori_pred, mut_pred = (blocks[:, k] for k in range(4))

    score_sep = (x_pred - ori_pred) + (mut_pred - x_mut_pred)
    final = np.sum(score_sep) / trials
    total_score_sep = np.sum(score_sep, axis=0) / trials
    return final, total_score_sep

In [11]:
trials = 10
test2 = []

# Score every tile of test sequence 1 on its own (no previously swapped tiles).
seq_onehot = data_module.x_test[1].numpy()
tile_ranges = generate_tile_ranges(seq_onehot.shape[1], 5, 5)
trainer = pl.Trainer(accelerator='gpu', devices='1', logger=None, enable_progress_bar=False)
# Iterate over the actual tiles instead of a hard-coded count, so a different
# sequence length or tiling cannot index past the end or skip tiles.
for tile_range in tqdm(tile_ranges, desc="tiles"):
    batch = get_batch(seq_onehot, tile_range, [], trials=trials)
    dataloader = torch.utils.data.DataLoader(batch, batch_size=100, shuffle=False)
    pred = np.concatenate(trainer.predict(standard_cnn, dataloaders=dataloader))
    total_score, total_score_sep = get_batch_score(pred, trials=trials)
    test2.append(total_score_sep)

In [12]:
# Per-tile averaged score for each model output (one array per tile).
print(f"{test2}")

[array([-2.42212183, -2.05458423]), array([0.25875406, 0.07878317]), array([-0.11537886,  0.05786403]), array([-0.81958828, -0.27240764]), array([-0.06512202, -0.35848022]), array([1.52791846, 0.45657713]), array([1.98820824, 0.86091409]), array([0.43547419, 0.43903367]), array([0.50364161, 0.31939501]), array([3.15768599, 0.5239851 ]), array([1.12237315, 0.06437841]), array([-0.58647064,  0.26434966]), array([ 0.07335519, -0.29134213]), array([-0.60392353, -0.45809508]), array([-1.99475831, -0.46328246]), array([-3.9830762 , -1.02939824]), array([-0.57097154, -0.22275513]), array([2.20200795, 1.28154684]), array([1.59409314, 0.17606212]), array([0.85593808, 0.54308369]), array([2.62548453, 1.63662804]), array([-0.10549095, -0.10093793]), array([ 0.83723722, -0.02134376]), array([1.82707246, 1.31184287]), array([-5.02991436, -2.37946165]), array([ 0.09967744, -0.19268004]), array([0.94160058, 0.37086544]), array([ 0.30530639, -0.63256736]), array([-2.61288682, -0.83662164]), array([-0.

In [None]:
# Grouped bar chart of the per-tile scores computed above: one bar per model output.
# (Passing the list of per-tile arrays straight to `ax.bar` fails: heights of shape
# (n_tiles, n_outputs) cannot broadcast against n_tiles x positions.)
tile_scores = np.asarray(test2)
if tile_scores.ndim == 1:
    # A single score per tile -> treat it as one output column.
    tile_scores = tile_scores[:, np.newaxis]
n_tiles, n_outputs = tile_scores.shape
tile_idx = np.arange(1, n_tiles + 1)
bar_width = 0.8 / n_outputs

fig, ax = plt.subplots(figsize=(10, 6))
for k in range(n_outputs):
    # Centre each group of bars on its tile index.
    offset = (k - (n_outputs - 1) / 2) * bar_width
    ax.bar(tile_idx + offset, tile_scores[:, k], width=bar_width, label=f"output {k}")
ax.axhline(0, color="grey", linewidth=0.8)
ax.set(xlabel="Tile index", ylabel="Mean swap score over trials", title="Per-tile swap score")
ax.legend()
plt.show()

In [8]:
def extend_sequence(one_hot_sequence):
    """Append an all-zeros mask row to a 2-D ``(A, L)`` sequence array.

    The extra row records which positions have already been acted on:
    ``taking_action`` later sets it to 1 over a chosen tile. It therefore starts as
    zeros (no tile chosen yet).

    Args:
        one_hot_sequence: 2-D array of shape ``(A, L)``.

    Returns:
        float32 array of shape ``(A + 1, L)``.
    """
    _, L = one_hot_sequence.shape  # unpacking also enforces a 2-D input
    mask_row = np.zeros((1, L))
    return np.vstack((one_hot_sequence, mask_row)).astype('float32')

def taking_action(sequence_with_ones, tile_range):
    """Mark a tile as chosen in the last (mask) row of the state.

    Args:
        sequence_with_ones: 2-D state array whose last row is the tile mask.
        tile_range: ``(start, end)`` slice bounds along the second axis.

    Returns:
        A new float32 array equal to the input except that the last row is 1 over
        ``[start, end)``. The input is left unchanged.
    """
    start_idx, end_idx = tile_range
    # np.array always copies here, so the caller's state is never mutated.
    updated = np.array(sequence_with_ones, dtype='float32')
    updated[-1, start_idx:end_idx] = 1
    return updated

In [12]:
def convert_elements(input_list, num_columns=5):
    """Collapse a per-position mask into one value per group of ``num_columns``.

    The input is right-padded with its last value up to a multiple of ``num_columns``.
    The first value of each group is then returned. For groups that are all 0 or all 1
    this equals the group's common value.

    Args:
        input_list: 1-D array-like (e.g. the mask row of the game state).
        num_columns: Group size (defaults to 5, matching the 5-wide tiles).

    Returns:
        1-D numpy array with ``ceil(len(input_list) / num_columns)`` entries, or an
        empty array for empty input.
    """
    values = np.asarray(input_list).ravel()
    if values.size == 0:
        return values.copy()

    # (-n) % k is 0 when n is already a multiple of k. The previous `k - n % k`
    # added a whole extra group in that case, producing a spurious action.
    padding_length = (-values.size) % num_columns
    if padding_length:
        values = np.concatenate(
            [values, np.full(padding_length, values[-1], dtype=values.dtype)]
        )

    groups = values.reshape(-1, num_columns)
    return groups[:, 0].copy()

In [19]:
class SeqGame:
    """Greedy tile-swapping game over a sequence scored by a model.

    The state is the input sequence plus one extra mask row (appended by
    ``extend_sequence`` when the input does not already have 5 rows). The input is
    presumably 4 one-hot nucleotide rows — TODO confirm. Each action picks a 5-wide
    tile:
      - ``get_next_state`` marks the tile in the mask row;
      - it scores swapping all tiles chosen so far with ``get_batch`` / ``get_batch_score``;
      - it records the tile as done.
    The game ends after ``levels`` moves, or as soon as the score drops.
    """

    def __init__(self, sequence, model_func, levels=20, num_trials=10):
        self.seq = sequence
        self.ori_seq = sequence.copy()
        self.tile_ranges = generate_tile_ranges(sequence.shape[1], 5, 5)
        self.tile_ranges_done = []
        self.levels = levels
        self.current_level = 0
        self.num_trials = num_trials
        # One action per tile (derived rather than hard-coded).
        self.action_size = len(self.tile_ranges)

        self.prev_score = -float("inf")
        self.current_score = 0
        # Per-output score of the latest move (second value of get_batch_score).
        self.current_score_sep = None

        self.trainer = pl.Trainer(accelerator='gpu', devices='1', logger=None, enable_progress_bar=False)
        self.model = model_func

        # 5 rows = sequence rows + mask row; add the mask row if it is missing.
        if self.seq.shape[0] != 5:
            self.seq = extend_sequence(self.seq)
            self.ori_seq = extend_sequence(self.ori_seq)

    def get_initial_state(self):
        """Reset the game (state, chosen tiles, level and scores) and return the state."""
        # tile_ranges only depends on the (unchanging) sequence length, so it is kept.
        self.seq = self.ori_seq.copy()
        self.tile_ranges_done = []
        self.current_level = 0
        self.prev_score = -float("inf")
        self.current_score = 0
        self.current_score_sep = None
        return self.seq

    def get_next_state(self, action):
        """Apply ``action`` (a tile index), rescore, and return the new state."""
        tile_range = self.tile_ranges[action]
        self.prev_score = self.current_score
        self.current_level += 1

        self.seq = taking_action(self.seq, tile_range)

        # Score swapping the new tile together with all previously chosen tiles,
        # using the sequence rows only (mask row excluded).
        batch = get_batch(self.seq[:-1, :], tile_range, self.tile_ranges_done, self.num_trials)
        dataloader = torch.utils.data.DataLoader(batch, batch_size=100, shuffle=False)
        pred = np.concatenate(self.trainer.predict(self.model, dataloaders=dataloader))

        # get_batch_score returns (scalar, per-output); keep the scalar as the game
        # score so terminate() compares numbers, not tuples.
        self.current_score, self.current_score_sep = get_batch_score(pred, self.num_trials)
        self.tile_ranges_done.append(tile_range)

        return self.seq

    def get_valid_moves(self):
        """Return a uint8 mask with 1 for tiles that have not been chosen yet."""
        return (convert_elements(self.seq[-1, :]) == 0).astype(np.uint8)

    def terminate(self):
        """True once the level budget is used up or the score decreased."""
        if self.current_level >= self.levels:
            return True
        if self.current_score < self.prev_score:
            return True
        return False

    def get_score(self):
        """Scalar score of the latest move."""
        return self.current_score

In [None]:
# Play one game with uniformly random *valid* actions (seeded for reproducibility).
rng = np.random.default_rng(0)
sequence = data_module.x_test[1].numpy()
seqgame = SeqGame(sequence, standard_cnn)
state = seqgame.get_initial_state()

while True:
    valid_moves = seqgame.get_valid_moves()
    # Sample from the valid actions directly instead of retrying random indices.
    valid_actions = np.flatnonzero(valid_moves[:seqgame.action_size])
    if valid_actions.size == 0:
        print("No valid actions left")
        break
    action = int(rng.choice(valid_actions))
    print("Action", action)

    state = seqgame.get_next_state(action)

    print(seqgame.get_score())
    if seqgame.terminate():
        print(state)
        print("Game ended")
        break