In [11]:
import os, h5py
os.environ['CUDA_VISIBLE_DEVICES']='1'
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import logomaker
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from tqdm import tqdm
import copy

from cremerl import utils, model_zoo, shuffle

import shuffle_test

#import gymnasium as gym

import logging

# Set the logging level to WARNING
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

In [2]:
expt_name = 'DeepSTARR'

# load data
data_path = '../../data/'
filepath = os.path.join(data_path, expt_name+'_data.h5')
data_module = utils.H5DataModule(filepath, batch_size=100, lower_case=False, transpose=False)


In [3]:
deepstarr2 = model_zoo.deepstarr(2)
loss = torch.nn.MSELoss()
optimizer_dict = utils.configure_optimizer(deepstarr2, lr=0.001, weight_decay=1e-6, decay_factor=0.1, patience=5, monitor='val_loss')
standard_cnn = model_zoo.DeepSTARR(deepstarr2,
                                  criterion=loss,
                                  optimizer=optimizer_dict)

# load checkpoint for model with best validation performance
standard_cnn = utils.load_model_from_checkpoint(standard_cnn, 'DeepSTARR_standard.ckpt')

# evaluate best model
pred = utils.get_predictions(standard_cnn, data_module.x_test[np.newaxis,100], batch_size=100)

2023-08-09 14:27:24.603366: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

In [4]:
def get_swap_greedy(x, x_mut, tile_ranges):
    ori = x.copy()
    mut = x_mut.copy()
    for tile_range in tile_ranges:
        ori[:, tile_range[0]:tile_range[1]] = x_mut[:, tile_range[0]:tile_range[1]]
        mut[:, tile_range[0]:tile_range[1]] = x[:, tile_range[0]:tile_range[1]]

    return ori, mut

def get_score(pred):
    score1 = pred[0] - pred[2]
    score2 = pred[3] - pred[1]
    return (score1+score2)[0], score1+score2

def generate_tile_ranges(sequence_length, window_size, stride):
    ranges = []
    start = np.arange(0, sequence_length - window_size + stride, stride)

    for s in start:
        e = min(s + window_size, sequence_length)
        ranges.append([s, e])

    if start[-1] + window_size - stride < sequence_length:  # Adjust the last range
        ranges[-1][1] = sequence_length

    return ranges

In [5]:
def get_batch(x, tile_range, tile_ranges_ori, trials):
    test_batch = []
    for i in range(trials):
        test_batch.append(x)
        x_mut = shuffle.dinuc_shuffle(x.copy())
        test_batch.append(x_mut)

        ori = x.copy()
        mut = x_mut.copy()
        
        ori, mut = get_swap_greedy(ori, mut, tile_ranges_ori)
        
        ori[:, tile_range[0]:tile_range[1]] = x_mut[:, tile_range[0]:tile_range[1]]
        mut[:, tile_range[0]:tile_range[1]] = x[:, tile_range[0]:tile_range[1]]
        
        test_batch.append(ori)
        test_batch.append(mut)

    #print(np.array(test_batch).shape)
    return np.array(test_batch)


def get_batch_score(pred, trials):

    score = []
    score_sep = []
    for i in range(0, pred.shape[0], 2):
        # print(f"Viewing number {i}")
        score1 = pred[0] - pred[i]
        score2 = pred[i+1] - pred[1]
        score.append((np.sum((score1, score2)[0])).tolist()) #np.sum(score1+score2, keepdims=True)
        score_sep.append((score1+score2).tolist())
        
    # print(score)
        
    final = np.sum(np.array(score), axis=0)/trials

    #max_ind = np.argmax(final)
    #block_ind = np.argmax(np.array(score)[:, max_ind])
    #print(np.array(total_score)[:, max_ind])
    total_score_sep = np.sum(np.array(score_sep), axis=0)/trials

    #print(np.max(score))
    return final

In [6]:
trials = 1000
test2 = []

x = data_module.x_test[1].numpy()
tile_ranges = generate_tile_ranges(x.shape[1], 5, 5)
trainer = pl.Trainer(accelerator='gpu', devices='1', logger=None, enable_progress_bar=False)
for i in range(50):
    batch = get_batch(x, tile_ranges[i], [], trials=trials)
    #print(batch.shape)
    dataloader = torch.utils.data.DataLoader(batch, batch_size=100, shuffle=False)
    pred = np.concatenate(trainer.predict(standard_cnn, dataloaders=dataloader))
    #print(pred.shape)
    total_score = get_batch_score(pred, trials=trials)
    test2.append(total_score)

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))
x = np.arange(1, 51)

artists = []
container = ax.bar(x, test2[:], color="orange")
artists.append(container)

#ani = animation.ArtistAnimation(fig=fig, artists=artists, interval=400)
plt.show()
#ani

In [6]:
def extend_sequence(one_hot_sequence):
    A, L = one_hot_sequence.shape

    # Create an all-ones row
    ones_row = np.zeros(L)

    # Add the all-ones row to the original sequence
    new_sequence = np.vstack((one_hot_sequence, ones_row))

    return np.array(new_sequence, dtype='float32')

def taking_action(sequence_with_ones, tile_range):
    start_idx, end_idx = tile_range

    # Ensure the start_idx and end_idx are within valid bounds
    #if start_idx < 0 or start_idx >= sequence_with_ones.shape[1] or end_idx < 0 or end_idx >= sequence_with_ones.shape[1]:
    #    raise ValueError("Invalid tile range indices.")

    # Copy the input sequence to avoid modifying the original sequence
    modified_sequence = sequence_with_ones.copy()

    # Modify the last row within the specified tile range
    modified_sequence[-1, start_idx:end_idx] = 1

    return np.array(modified_sequence, dtype='float32')

In [18]:
def convert_elements(input_list):
    input_list = input_list.tolist()
    num_columns = 5  # Number of elements to process in each group

    # Calculate the number of elements needed to pad the list
    padding_length = num_columns - (len(input_list) % num_columns)
    last_value = input_list[-1]
    padded_list = input_list + [last_value] * padding_length

    # Convert the padded list to a NumPy array for efficient operations
    input_array = np.array(padded_list)
    reshaped_array = input_array.reshape(-1, num_columns)

    # Check if each row has the same value (all 0s or all 1s)
    row_all_zeros = np.all(reshaped_array == 0, axis=1)
    row_all_ones = np.all(reshaped_array == 1, axis=1)

    # Replace all 0s with 0 and all 1s with 1 in the result array
    output_array = np.where(row_all_zeros, 0, np.where(row_all_ones, 1, reshaped_array[:, 0]))

    # Flatten the result array to get the final output list
    output_list = output_array.flatten()

    return output_list

In [19]:
class SeqGame:
    def __init__(self, sequence, model_func):
        self.seq = sequence
        self.ori_seq = sequence.copy()
        self.tile_ranges = generate_tile_ranges(sequence.shape[1], 5, 5)
        self.tile_ranges_done = []
        self.levels = 20
        self.current_level = 0
        self.num_trials = 10
        self.action_size = 50
        
        self.prev_score = -float("inf")
        self.current_score = 0
        
        self.trainer = pl.Trainer(accelerator='gpu', devices='1', logger=None, enable_progress_bar=False)
        self.model = model_func
        
        if self.seq.shape[0]!=5:
            self.seq = extend_sequence(self.seq)
            self.ori_seq = extend_sequence(self.ori_seq)
        
    
    def get_initial_state(self):
        self.seq = self.ori_seq.copy()
        self.tile_ranges = generate_tile_ranges(self.seq.shape[1], 5, 5)
        self.tile_ranges_done = []
        self.current_level = 0
        
        return self.seq
    
    
    def get_next_state(self, action):
        self.prev_score = self.current_score
        self.current_level += 1
        
        self.seq = taking_action(self.seq, self.tile_ranges[action])
        
        batch = get_batch(self.seq[:4, :], self.tile_ranges[action], self.tile_ranges_done, self.num_trials)
        dataloader = torch.utils.data.DataLoader(batch, batch_size=100, shuffle=False)
        pred = np.concatenate(self.trainer.predict(self.model, dataloaders=dataloader))
        
        self.current_score = np.tanh(5 * get_batch_score(pred, self.num_trials)) #ADDED TANH
        
        return self.seq
    
    def get_valid_moves(self):
        return (convert_elements(self.seq[-1, :]) == 0).astype(np.uint8)
    
    def terminate(self): #state
        if self.current_level >= self.levels:
            return True
        if self.current_score < self.prev_score:
            return True
    
        return False
    
    def get_score(self):
        return self.current_score

In [None]:
sequence = data_module.x_test[1].numpy()
seqgame = SeqGame(sequence, standard_cnn)
state = seqgame.get_initial_state()

while True:
    print(state)
    valid_moves = seqgame.get_valid_moves()
    print("Valid moves", [i for i in range(seqgame.action_size) if valid_moves[i]==1])
    #action = int(input("Take action: "))
    action = np.random.randint(0, 50)
    print(action)
    
    if valid_moves[action] == 0:
        print("Invalid action")
        continue
    
    state = seqgame.get_next_state(action)
    
    print(seqgame.get_score())
    is_terminal = seqgame.terminate()
    
    if is_terminal:
        print(state)
        print("Game ended")
        break

In [9]:
class Node:
    def __init__(self, parent, prior_p):
        self.parent = parent
        self.children = {}
        self.n_visits = 0
        self.Q = 0
        self.u = 0
        self.P = prior_p
    
    def expand(self, action_priors):
        for action, prob in action_priors:
            if action not in self.children:
                self.children[action] = Node(self, prob)
    
    def select(self, c_puct):
        return max(self.children.items(),
                   key=lambda act_node: act_node[1].get_value(c_puct))
    
    def update(self, leaf_value):
        self.n_visits += 1
        self.Q += 1.0 * (leaf_value - self.Q) / self.n_visits
    
    def update_recursive(self, leaf_value):
        if self.parent:
            self.parent.update_recursive(-leaf_value)
        self.update(leaf_value)
    
    def get_value(self, c_puct):
        self.u = (c_puct * self.P *
                  np.sqrt(self.parent.n_visits) / (1 + self.n_visits))
        return self.Q + self.u
    
    def is_leaf(self):
        return self.children == {}
    
    def is_root(self):
        return self.parent is None

In [10]:
import copy

class MCTS:
    def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
        self.root = Node(None, 1.0)
        self.policy = policy_value_fn
        self.c_puct = c_puct
        self.n_playout = n_playout
    
    def playout(self, state):
        node = self.root
        while True:
            if node.is_leaf():
                break
            
            action, node = node.select(self.c_puct)
            state.get_next_state(action)
        
        action_probs, leaf_value = self.policy(state)
        end = state.terminate()
        if not end:
            node.expand(action_probs)
        else:
            leaf_value = state.current_score
        
        node.update_recursive(-leaf_value)
    
    def get_move_probs(self, state, temp=1e-3):
        for n in range(self.n_playout):
            state_copy = copy.deepcopy(state)
            state_copy.playout = 1
            
            self.playout(state_copy)
        
        act_visits = [(act, node.v_visits)
                      for act, node in self.root.children.items()]
        
        # Debug...? Need to confirm
        if not act_visits:
            return [], []
        acts, visits = zip(*act_visits)
        act_probs = F.softmax(1.0/temp * np.log(np.array(visits) + 1e-10))
        
        # To prevent from selecting the same move, and then normalize the act_probs
        valid_moves = state.get_valid_moves()
        act_probs = act_probs * valid_moves / np.sum(act_probs)
        
        return acts, act_probs
    
    def update_with_move(self, last_move):
        if last_move in self.root.children:
            self.root = self.root_children[last_move]
            self.root.parent = None
        else:
            self.root = Node(None, 1.0)
    
    def __str__(self):
        return "MCTS Algorithm"

In [13]:
class MCTSMutater:
    def __init__(self, policy_value_function, 
                 c_puct=5, n_playout=2000, is_selfplay=0):
        self.mcts = MCTS(policy_value_function, c_puct, n_playout)
        self.is_selfplay = is_selfplay
    
    def set_player_ind(self, p):
        self.player = p
    
    def reset_Mutater(self):
        self.mcts.update_with_move(-1)
    
    def get_action(self, Env, temp=1e-3, return_prob=0):
        move_probs = np.zeros(Env.action_size) #SUBJECT TO CHANGE
        
        acts, probs = self.mcts.get_move_probs(Env, temp)
        
        if acts:
            move_probs[list(acts)] = probs
            if self.is_selfplay:
                move = np.random.choice(
                    acts, 
                    p=0.75*probs + 0.25 * np.random.dirichlet(0.3*np.ones(len(probs)))
                )
                self.mcts.update_with_move(move)
                print("MCTS Mutater moved: %d\n" % (move))
            else:
                move = np.random.choice(acts, p=probs)
                self.mcts.update_with_move(-1)
                print("MCTS Mutater moved: %d\n" % (move))
            
            if return_prob:
                return move, move_probs
            else:
                return move
        else:
            return [], []
    
    def __str__(self):
        return "MCTS {}".format(self.player)

In [20]:
class Node:
    def __init__(self, env, args, parent=None, action_taken=None):
        self.env = env
        self.args = args
        self.parent = parent
        self.action_taken = action_taken
        
        self.children = []
        self.expandable_moves = env.get_valid_moves()
        
        self.visit_count = 0
        self.value_sum = 0
    
    def is_fully_expanded(self):
        return np.sum(self.expandable_moves) == 0 and len(self.children) > 0
    
    def select(self):
        best_child = None
        best_ucb = -np.inf
        
        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb
        
        return best_child
    
    def get_ucb(self, child):
        q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * np.sqrt(np.log(self.visit_count) / child.visit_count) #could use math.sqrt or math.log
    
    def expand(self):
        action = np.random.choice(np.where(self.expandable_moves == 1)[0])
        self.expandable_moves[action] = 0
        
        child_env = copy.deepcopy(self.env)
        child_state = child_env.get_next_state(action)
        child = Node(child_env, self.args, self, action)
        self.children.append(child)
        return child
    
    def simulate(self):
        value = self.env.current_score
        is_terminal = self.env.terminate()
        print(f"It's terminating with the value of: {value}")
        if is_terminal:
            return value
        
        rollout_env = copy.deepcopy(self.env)
        while True:
            valid_moves = rollout_env.get_valid_moves()
            action = np.random.choice(np.where(valid_moves == 1)[0])
            rollout_state = rollout_env.get_next_state(action)
            value = rollout_env.current_score
            is_terminal = rollout_env.terminate()
            if is_terminal:
                return value
    
    def backpropagate(self, value):
        self.value_sum += value
        self.visit_count += 1
        
        value = self.env.current_score
        if self.parent is not None:
            self.parent.backpropagate(value)

class MCTS:
    def __init__(self, env, args):
        self.env = env
        self.args = args
    
    def search(self):
        root = Node(self.env, self.args)
        
        for search in range(self.args['num_searches']):
            print(f"Conducting search no. {search}")
            node = root
            
            while node.is_fully_expanded():
                node = node.select()
            
            value = self.env.current_score
            is_terminal = self.env.terminate()
            
            if not is_terminal:
                node = node.expand()
                value = node.simulate()

            node.backpropagate(value)
        
        action_probs = np.zeros(self.env.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        action_probs /= np.sum(action_probs)
        return action_probs

In [23]:
sequence = data_module.x_test[1].numpy()
seqgame = SeqGame(sequence, standard_cnn)


args = {
    'C': 1.41, 
    'num_searches': 1000
}

mcts = MCTS(seqgame, args)
state = seqgame.get_initial_state()

while True:
    print(state)
    valid_moves = seqgame.get_valid_moves()
    print("Valid moves", [i for i in range(seqgame.action_size) if valid_moves[i]==1])
    #action = int(input("Take action: "))
    #action = np.random.randint(0, 50)
    mcts_probs = mcts.search()
    print(mcts_probs)
    action = np.argmax(mcts_probs)
    print(f"This is the action: {action}")
    
    if valid_moves[action] == 0:
        print("Invalid action")
        continue
    
    state = seqgame.get_next_state(action)
    
    print(seqgame.get_score())
    is_terminal = seqgame.terminate()
    
    if is_terminal:
        print(state)
        print("Game ended")
        break

[[1. 0. 0. ... 0. 0. 1.]
 [0. 0. 0. ... 1. 1. 0.]
 [0. 0. 1. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
Valid moves [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
Conducting search no. 0
It's terminating with the value of: 1.0175574719905853
Conducting search no. 1
It's terminating with the value of: 0.10970532596111297
Conducting search no. 2


  rank_zero_warn(


It's terminating with the value of: -0.02848142385482788
Conducting search no. 3
It's terminating with the value of: 1.083763748407364
Conducting search no. 4
It's terminating with the value of: 0.233143213391304
Conducting search no. 5
It's terminating with the value of: 0.07494834661483765
Conducting search no. 6
It's terminating with the value of: -0.1829577222466469
Conducting search no. 7
It's terminating with the value of: -0.07001760303974151
Conducting search no. 8
It's terminating with the value of: 0.6581978395581245
Conducting search no. 9
It's terminating with the value of: 0.5906098663806916
Conducting search no. 10
It's terminating with the value of: 1.5126946330070496
Conducting search no. 11
It's terminating with the value of: 0.0947249785065651
Conducting search no. 12
It's terminating with the value of: 0.09443687200546265
Conducting search no. 13
It's terminating with the value of: 0.573985481262207
Conducting search no. 14
It's terminating with the value of: 0.12395