In [2]:
import os, h5py
os.environ['CUDA_VISIBLE_DEVICES']='1'
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import logomaker
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from tqdm import tqdm
import copy, math
import collections

from cremerl import utils, model_zoo, shuffle

import shuffle_test

#import gymnasium as gym

import logging

# Set the logging level to WARNING
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

In [3]:
# Experiment identifier; it also selects which HDF5 file is read below.
expt_name = 'DeepSTARR'

# Dataset location, relative to the notebook's working directory.
data_path = '../../data/'
filepath = os.path.join(data_path, f"{expt_name}_data.h5")

# Wrap the HDF5 file in the project's data module (batch_size=100).
data_module = utils.H5DataModule(
    filepath,
    batch_size=100,
    lower_case=False,
    transpose=False,
)


In [4]:
# Rebuild the DeepSTARR network (argument 2 — presumably the number of output
# tasks, TODO confirm) and wrap it in the project's Lightning module with an
# MSE loss and the optimizer/scheduler config from utils.configure_optimizer.
deepstarr2 = model_zoo.deepstarr(2)
loss = torch.nn.MSELoss()
optimizer_dict = utils.configure_optimizer(
    deepstarr2,
    lr=0.001,
    weight_decay=1e-6,
    decay_factor=0.1,
    patience=5,
    monitor='val_loss',
)
standard_cnn = model_zoo.DeepSTARR(deepstarr2, criterion=loss, optimizer=optimizer_dict)

# Restore the weights stored in the checkpoint file.
standard_cnn = utils.load_model_from_checkpoint(standard_cnn, 'DeepSTARR_standard.ckpt')

# Smoke test: predict a single test sequence (index 100) to check that the
# restored model runs end to end. `pred` is not read by the cells below.
pred = utils.get_predictions(
    standard_cnn,
    data_module.x_test[np.newaxis, 100],
    batch_size=100,
)

2023-08-11 15:06:38.700069: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

In [5]:
def get_swap_greedy(x, x_mut, tile_ranges):
    """Exchange the given column tiles between ``x`` and ``x_mut``.

    Args:
        x: 2-D array, sliced as ``x[:, start:end]``.
        x_mut: array of the same shape (e.g. a shuffled copy of ``x``).
        tile_ranges: iterable of ``[start, end]`` column ranges.

    Returns:
        ``(ori, mut)``. ``ori`` is a copy of ``x`` whose tiles come from
        ``x_mut``. ``mut`` is a copy of ``x_mut`` whose tiles come from ``x``.
        Neither input is modified.
    """
    swapped_x = x.copy()
    swapped_mut = x_mut.copy()
    for tile_range in tile_ranges:
        columns = slice(tile_range[0], tile_range[1])
        # Always read from the untouched inputs, never from the copies.
        swapped_x[:, columns] = x_mut[:, columns]
        swapped_mut[:, columns] = x[:, columns]
    return swapped_x, swapped_mut

def get_score(pred):
    """Combine rows 0-3 of ``pred`` into a tile-swap score.

    The score is ``(pred[0] - pred[2]) + (pred[3] - pred[1])``.

    Returns:
        ``(first, full)``: element 0 of the score, and the full score.
    """
    combined = (pred[0] - pred[2]) + (pred[3] - pred[1])
    return combined[0], combined

def generate_tile_ranges(sequence_length, window_size, stride):
    """Split positions ``[0, sequence_length)`` into windows taken every ``stride``.

    Windows are clipped at ``sequence_length``, and the last window always
    ends exactly at the end of the sequence. For example, 249 positions with
    window 5 / stride 5 give 50 tiles, the last one ``[245, 249]``.

    Args:
        sequence_length: number of positions to tile.
        window_size: width of each window (> 0).
        stride: distance between window starts (> 0).

    Returns:
        List of ``[start, end]`` pairs (Python ints). A sequence shorter than
        the window yields a single tile covering it. An empty sequence yields
        no tiles.

    Raises:
        ValueError: if ``window_size`` or ``stride`` is not positive.
    """
    if window_size <= 0 or stride <= 0:
        raise ValueError("window_size and stride must be positive")
    if sequence_length <= 0:
        return []
    # Starts run up to (excluding) sequence_length - window_size + stride. The
    # last start is therefore >= sequence_length - window_size, so the last
    # window reaches the end of the sequence. Clamping to 1 keeps one tile
    # when the sequence is shorter than the window (previously an IndexError).
    stop = max(sequence_length - window_size + stride, 1)
    return [
        [start, min(start + window_size, sequence_length)]
        for start in range(0, stop, stride)
    ]

In [6]:
def get_batch(x, tile_range, tile_ranges_ori, trials):
    """Build the prediction batch used to score swapping ``tile_range``.

    For each of ``trials`` shuffles ``x_mut = shuffle.dinuc_shuffle(x)``, four
    rows are emitted in this order:

    1. ``x`` itself,
    2. the shuffle ``x_mut``,
    3. ``x`` with every tile in ``tile_ranges_ori`` and ``tile_range`` taken
       from ``x_mut``,
    4. ``x_mut`` with those same tiles taken from ``x``.

    Returns:
        ``np.ndarray`` stacking the ``4 * trials`` rows along a new axis 0.
    """
    rows = []
    for _ in range(trials):
        window = slice(tile_range[0], tile_range[1])
        x_mut = shuffle.dinuc_shuffle(x.copy())
        # get_swap_greedy works on fresh copies, so x and x_mut stay intact.
        ori, mut = get_swap_greedy(x, x_mut, tile_ranges_ori)
        ori[:, window] = x_mut[:, window]
        mut[:, window] = x[:, window]
        rows.extend((x, x_mut, ori, mut))
    return np.array(rows)


def get_batch_score(pred, trials):
    """Average the tile-swap score over the trials of a ``get_batch`` prediction.

    ``pred`` holds model predictions for the rows built by ``get_batch``:
    consecutive groups of four rows ``(x, x_mut, ori, mut)``, one group per
    trial. Each trial is scored against its *own* ``x`` / ``x_mut``:

        (pred[x] - pred[ori]) + (pred[mut] - pred[x_mut])

    This is taken on output column 0, the same quantity ``get_score`` returns
    first. Per-trial scores are summed and divided by ``trials``.

    Args:
        pred: array of shape ``(4 * n_trials, n_outputs)``. A 1-D array is
            treated as a single output column.
        trials: normaliser for the summed score (the number of trials).

    Returns:
        The averaged score as a Python float.

    Raises:
        ValueError: if the number of rows is not a multiple of 4.
    """
    pred = np.asarray(pred)
    if pred.ndim == 1:
        pred = pred[:, np.newaxis]
    if pred.shape[0] % 4 != 0:
        raise ValueError(
            f"pred must contain 4 rows per trial, got {pred.shape[0]} rows"
        )
    groups = pred.reshape((pred.shape[0] // 4, 4) + pred.shape[1:])
    x, x_mut, ori, mut = (groups[:, k] for k in range(4))
    # Previously every trial reused pred[0]/pred[1] (trial 1's x / x_mut), and
    # the tuple index `(score1, score2)[0]` silently discarded score2.
    score1 = x - ori
    score2 = mut - x_mut
    per_trial = (score1 + score2)[:, 0]
    return float(np.sum(per_trial) / trials)

In [7]:
def extend_sequence(one_hot_sequence):
    """Append an all-zero row to a 2-D ``(rows, length)`` array.

    The new last row starts at 0 everywhere. ``taking_action`` later sets it
    to 1 over the tile ranges that have been acted on.

    Returns:
        A new float32 array with one more row than the input.
    """
    _, length = one_hot_sequence.shape
    zero_row = np.zeros(length)
    stacked = np.vstack([one_hot_sequence, zero_row])
    return np.array(stacked, dtype='float32')

def taking_action(sequence_with_ones, tile_range):
    """Mark ``tile_range`` as acted on in the last row of a state array.

    Args:
        sequence_with_ones: 2-D array whose last row records acted-on positions.
        tile_range: ``(start, end)`` column pair. No bounds checking is done.

    Returns:
        A float32 copy with ``[-1, start:end]`` set to 1. The input is unchanged.
    """
    start_idx, end_idx = tile_range
    marked = sequence_with_ones.copy()
    marked[-1, start_idx:end_idx] = 1
    return marked.astype('float32')

In [8]:
def convert_elements(input_list, num_columns=5):
    """Collapse a per-position mask into one value per tile of ``num_columns``.

    The mask is cut into consecutive groups of ``num_columns`` entries. A
    shorter last group is padded with its final value. Each group is reduced
    to one value:
    - a uniform group (all 0 or all 1) gives that value;
    - a mixed group gives its first entry.

    Args:
        input_list: 1-D array-like mask, e.g. the last row of a game state.
        num_columns: tile width. It should match the window/stride passed to
            ``generate_tile_ranges`` (5 in this notebook).

    Returns:
        1-D ``np.ndarray`` with ``ceil(len(input_list) / num_columns)`` entries,
        or an empty array for empty input.
    """
    values = np.asarray(input_list).ravel().tolist()
    if not values:
        return np.array([])
    # Pad only when the length is not already a multiple of num_columns.
    # Padding a full extra group would add a spurious trailing tile/action.
    padding_length = -len(values) % num_columns
    padded = np.array(values + [values[-1]] * padding_length)
    groups = padded.reshape(-1, num_columns)
    # A uniform group's value equals its first entry, and mixed groups also
    # fall back to the first entry, so column 0 is the reduced result.
    return groups[:, 0].copy()

In [32]:
class SeqGame:
    """Game in which each action marks one 5-position tile of a sequence.

    The state is the input sequence plus an extra last row that records which
    positions have been acted on. ``extend_sequence`` appends this row unless
    the input already has 5 rows. ``get_next_state`` marks a tile and scores
    the accumulated tiles with ``model_func`` via ``get_batch`` and
    ``get_batch_score``.

    Args:
        sequence: array of shape ``(channels, length)``. Presumably 4-row
            one-hot DNA (only rows ``[:4]`` are scored) — TODO confirm.
        model_func: Lightning module passed to ``Trainer.predict``.
    """

    def __init__(self, sequence, model_func):
        self.seq = sequence
        self.ori_seq = sequence.copy()
        self.tile_ranges = generate_tile_ranges(sequence.shape[1], 5, 5)
        # Episode bookkeeping, also reset by get_initial_state. Initialised
        # here so the attributes exist before the first reset.
        self.tile_ranges_done = []
        self.current_level = 0
        self.levels = 20  # depth limit checked by terminate()
        self.num_trials = 10  # shuffles per score evaluation (see get_batch)
        # One action per tile (50 for a 249-long sequence with 5-wide tiles).
        self.action_size = len(self.tile_ranges)

        self.prev_score = -float("inf")
        self.current_score = 0

        self.trainer = pl.Trainer(accelerator='gpu', devices='1', logger=None, enable_progress_bar=False)
        self.model = model_func

        if self.seq.shape[0] != 5:
            self.seq = extend_sequence(self.seq)
            self.ori_seq = extend_sequence(self.ori_seq)

    def get_initial_state(self):
        """Reset the game to the original sequence and return that state."""
        self.seq = self.ori_seq.copy()
        self.tile_ranges = generate_tile_ranges(self.seq.shape[1], 5, 5)
        self.tile_ranges_done = []
        self.current_level = 0

        return self.seq

    def get_next_state(self, action, tile_ranges_done=None):
        """Apply ``action`` to the current state and score the result.

        Args:
            action: index into ``self.tile_ranges``.
            tile_ranges_done: tiles already acted on before this step (the
                search tree passes these explicitly). If omitted, the game's
                own ``self.tile_ranges_done`` is used. The game then records
                this action's tile there and increments ``current_level``, so
                it can be played step by step without a search tree.

        Returns:
            The new state, also stored in ``self.seq``. The score is available
            from ``get_score()``.
        """
        track_episode = tile_ranges_done is None
        if track_episode:
            tile_ranges_done = self.tile_ranges_done

        self.prev_score = self.current_score
        tile_range = self.tile_ranges[action]
        self.seq = taking_action(self.seq, tile_range)

        batch = get_batch(self.seq[:4, :], tile_range, tile_ranges_done, self.num_trials)
        dataloader = torch.utils.data.DataLoader(batch, batch_size=100, shuffle=False)
        pred = np.concatenate(self.trainer.predict(self.model, dataloaders=dataloader))

        # tanh squashes the averaged swap score into (-1, 1).
        self.current_score = np.tanh(get_batch_score(pred, self.num_trials))

        if track_episode:
            self.tile_ranges_done.append(list(tile_range))
            self.current_level += 1
        return self.seq

    def get_valid_moves(self):
        """Return a uint8 array with 1 for every tile not yet acted on."""
        return (convert_elements(self.seq[-1, :]) == 0).astype(np.uint8)

    def terminate(self, level, current_score, parent_score):
        """Return True once ``level`` reaches ``self.levels`` or the score drops.

        Args:
            level: depth at which termination is checked.
            current_score: score of the new state.
            parent_score: score it is compared against.
        """
        if level >= self.levels:
            return True
        if current_score < parent_score:
            return True
        return False

    def set_seq(self, seq):
        """Replace the current state (no copy is made)."""
        self.seq = seq

    def get_seq(self):
        """Return a copy of the current state."""
        return self.seq.copy()

    def get_score(self):
        """Return the score computed by the last ``get_next_state`` call."""
        return self.current_score

In [41]:
class Node:
    """One state in the MCTS search tree over ``SeqGame`` actions.

    Statistics about a node's children live on the node itself, indexed by
    action: ``child_total_value`` (W), ``child_number_visits`` (N) and
    ``child_priors`` (P). A node's *own* visit count and total value are
    therefore stored in its parent's arrays at index ``self.action`` (see the
    ``number_visits`` / ``total_value`` properties).

    Args:
        action: action index that led from ``parent`` to this node.
        state: game state array whose last row marks acted-on positions.
        done: whether the game ends at this node.
        reward: score of this state (``env.get_score()``).
        mcts: owning ``MCTS`` instance (supplies ``c_puct``).
        level: depth of this node in the tree.
        tile_ranges_done: tiles already acted on along the path to this node.
        parent: parent node. It must provide ``env`` (the root uses a
            ``RootParentNode``), so ``None`` is not actually supported.
    """

    def __init__(self, action, state, done, reward, mcts, level, tile_ranges_done, parent=None):
        self.env = parent.env
        self.action = action

        self.is_expanded = False
        self.parent = parent
        self.children = {}

        self.action_space_size = self.env.action_size
        self.child_total_value = np.zeros(
            [self.action_space_size], dtype=np.float32
        )  # W: summed backed-up values
        self.child_priors = np.zeros([self.action_space_size], dtype=np.float32)  # P
        self.child_number_visits = np.zeros(
            [self.action_space_size], dtype=np.float32
        )  # N
        # 1 for tiles not yet acted on, 0 otherwise.
        self.valid_actions = (convert_elements(state[-1, :]) == 0).astype(np.uint8)

        self.reward = reward
        self.done = done
        self.state = state
        self.level = level

        self.tile_ranges_done = tile_ranges_done

        self.mcts = mcts

    @property
    def number_visits(self):
        return self.parent.child_number_visits[self.action]

    @number_visits.setter
    def number_visits(self, value):
        self.parent.child_number_visits[self.action] = value

    @property
    def total_value(self):
        return self.parent.child_total_value[self.action]

    @total_value.setter
    def total_value(self, value):
        self.parent.child_total_value[self.action] = value

    def child_Q(self):
        """Mean value per child, ``W / (1 + N)``."""
        return self.child_total_value / (1 + self.child_number_visits)

    def child_U(self):
        """Exploration term ``sqrt(N_self) * P / (1 + N_child)``."""
        return (
            math.sqrt(self.number_visits)
            * self.child_priors
            / (1 + self.child_number_visits)
        )

    def best_action(self):
        """Return the valid action with the highest ``Q + c_puct * U``.

        Invalid actions get ``-inf``. Multiplying by the 0/1 mask is not
        enough: scores can be negative, so a masked-out action scoring 0 could
        win the argmax. (``~mask`` on a uint8 array is bitwise NOT, not a
        boolean negation, hence the explicit ``astype(bool)``.) If no action
        is valid, every score is ``-inf`` and ``np.argmax`` returns 0.
        """
        child_score = self.child_Q() + self.mcts.c_puct * self.child_U()
        valid = np.asarray(self.valid_actions).astype(bool)
        masked_child_score = np.where(valid, child_score, -np.inf)
        return np.argmax(masked_child_score)

    def select(self):
        """Descend via ``best_action`` until reaching an unexpanded node."""
        current_node = self
        while current_node.is_expanded:
            best_action = current_node.best_action()
            current_node = current_node.get_child(best_action)
        return current_node

    def expand(self, child_priors):
        """Mark this node as expanded and store the children's priors."""
        self.is_expanded = True
        self.child_priors = child_priors

    def set_state(self, state):
        """Replace the state and recompute the valid-action mask from it."""
        self.state = state
        self.valid_actions = (convert_elements(state[-1, :]) == 0).astype(np.uint8)

    def get_child(self, action):
        """Return the child reached by ``action``, creating it on first access.

        Creating a child resets ``env`` to a copy of this node's state, applies
        the action via ``env.get_next_state`` and records the resulting score
        and termination flag.
        """
        if action not in self.children:
            self.env.set_seq(self.state.copy())
            next_state = self.env.get_next_state(action, self.tile_ranges_done)
            new_tile_ranges_done = copy.deepcopy(self.tile_ranges_done)
            new_tile_ranges_done.append(self.env.tile_ranges[action])
            reward = self.env.get_score()
            # NOTE(review): the new child's reward is compared with
            # self.parent.reward (the child's grandparent), and the depth
            # passed is self.level, not the child's level (self.level + 1).
            # RootParentNode.reward is -inf, so first-level children never stop
            # on score. Confirm this offset is intended.
            terminated = self.env.terminate(self.level, reward, self.parent.reward)
            self.children[action] = Node(
                state=next_state,
                action=action,
                parent=self,
                reward=reward,
                done=terminated,
                mcts=self.mcts,
                level=self.level + 1,
                tile_ranges_done=new_tile_ranges_done
            )
        return self.children[action]

    def backup(self, value):
        """Add one visit and ``value`` to this node and all its ancestors.

        The walk stops at the ``RootParentNode``, whose ``parent`` is None.
        """
        current = self
        while current.parent is not None:
            current.number_visits += 1
            current.total_value += value
            current = current.parent

class RootParentNode:
    """Stand-in parent for the root ``Node`` of a search tree.

    It gives the root access to ``env`` and stores the root's own visit count
    and total value in ``defaultdict``s, so any key (the root's ``action`` is
    ``None``) starts at 0.0. Its ``parent`` is ``None``, which stops
    ``Node.backup``. Its reward is ``-inf``.
    """

    def __init__(self, env):
        self.env = env
        self.reward = -np.inf
        self.parent = None
        self.child_number_visits = collections.defaultdict(float)
        self.child_total_value = collections.defaultdict(float)

class MCTS:
    """PUCT Monte Carlo tree search guided by a policy/value network.

    Args:
        model: callable that maps a ``(1, rows, length)`` float tensor to
            ``(policy_logits, value)``, where ``value`` holds a single element.
        mcts_param: dict with keys ``temperature``, ``dirichlet_epsilon``,
            ``dirichlet_noise``, ``num_simulations``, ``argmax_tree_policy``,
            ``add_dirichlet_noise`` and ``puct_coefficient``.
    """

    def __init__(self, model, mcts_param):
        self.model = model
        self.temperature = mcts_param["temperature"]
        self.dir_epsilon = mcts_param["dirichlet_epsilon"]
        self.dir_noise = mcts_param["dirichlet_noise"]
        self.num_sims = mcts_param["num_simulations"]
        self.exploit = mcts_param["argmax_tree_policy"]
        self.add_dirichlet_noise = mcts_param["add_dirichlet_noise"]
        self.c_puct = mcts_param["puct_coefficient"]

    def compute_action(self, node):
        """Run ``num_simulations`` simulations from ``node`` and choose an action.

        Returns:
            ``(tree_policy, action, child)``:
            - ``tree_policy`` is the visit-count distribution over actions;
            - ``action`` is its argmax, or a sample from it;
            - ``child`` is the corresponding child node.

            Visit counts are raised to ``temperature`` (not ``1/temperature``
            as in the usual AlphaZero form), so values > 1 sharpen the policy.
        """
        for _ in range(self.num_sims):
            leaf = node.select()
            if leaf.done:
                # Terminal leaf: back up the game score itself.
                value = leaf.reward
            else:
                # Inference only: no_grad avoids building an autograd graph on
                # every one of the (many) network evaluations.
                with torch.no_grad():
                    logits, value = self.model(torch.tensor(leaf.state).unsqueeze(0))
                child_priors = torch.softmax(logits, dim=1).squeeze(0).cpu().detach().numpy()
                # Back up a plain float, not a tensor, into the numpy statistics.
                value = value.item()
                if self.add_dirichlet_noise:
                    child_priors = (1 - self.dir_epsilon) * child_priors
                    child_priors += self.dir_epsilon * np.random.dirichlet(
                        [self.dir_noise] * child_priors.size
                    )

                leaf.expand(child_priors)
            leaf.backup(value)

        tree_policy = node.child_number_visits / node.number_visits
        tree_policy = tree_policy / np.max(tree_policy)
        tree_policy = np.power(tree_policy, self.temperature)
        tree_policy = tree_policy / np.sum(tree_policy)
        if self.exploit:
            action = np.argmax(tree_policy)
        else:
            action = np.random.choice(np.arange(node.action_space_size), p=tree_policy)
        return tree_policy, action, node.children[action]

In [13]:
class CNN_v0(nn.Module):
    """Policy/value network over a ``(batch, 5, seq_len)`` game state.

    Three Conv1d-BatchNorm-ReLU blocks (5 -> 32 -> 64 -> 128 channels; length
    preserved by ``padding=1``) feed two heads:

    * ``policyHead``: logits of shape ``(batch, action_dim)``;
    * ``valueHead``: a ``tanh``-bounded value of shape ``(batch, 1)``.

    Args:
        action_dim: number of actions (size of the policy logits).
        seq_len: input sequence length; the heads' first Linear layers are
            sized ``50 * seq_len``. Defaults to 249, the length used in this
            notebook.
    """

    def __init__(self, action_dim, seq_len=249):
        super().__init__()

        self.convblock1 = nn.Sequential(
            nn.Conv1d(5, 32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU()
        )

        self.convblock2 = nn.Sequential(
            nn.Conv1d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU()
        )

        self.convblock3 = nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )

        self.policyHead = nn.Sequential(
            nn.Conv1d(128, 50, kernel_size=3, padding=1),
            nn.BatchNorm1d(50),
            nn.Flatten(),
            nn.Linear(50 * seq_len, action_dim)
        )

        self.valueHead = nn.Sequential(
            nn.Conv1d(128, 50, kernel_size=3, padding=1),
            nn.BatchNorm1d(50),
            nn.Flatten(),
            nn.Linear(50 * seq_len, 128),
            # NOTE(review): there is no non-linearity between these two Linear
            # layers, so they compose to a single affine map. Inserting one
            # would shift the state_dict keys of any saved checkpoints.
            nn.Linear(128, 1),
            nn.Tanh()
        )

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a ``(batch, 5, seq_len)`` input."""
        x = self.convblock1(x)
        x = self.convblock2(x)
        x = self.convblock3(x)

        policy = self.policyHead(x)
        value = self.valueHead(x)
        return policy, value

In [None]:
# Play one game on test sequence #1, moving through the MCTS tree one action
# at a time.
sequence = data_module.x_test[1].numpy()
seqgame = SeqGame(sequence, standard_cnn)


mcts_config = {
    "puct_coefficient": 2.0,
    "num_simulations": 10000,
    "temperature": 1.5,
    "dirichlet_epsilon": 0.25,
    "dirichlet_noise": 0.03,
    "argmax_tree_policy": False,
    "add_dirichlet_noise": True,
}

model = CNN_v0(seqgame.action_size)
model.eval()

mcts = MCTS(model, mcts_config)
state = seqgame.get_initial_state()
tree_node = None
while True:
    print(state)
    if tree_node is None:
        # Node takes no `tile_ranges` argument (passing one raised TypeError).
        tree_node = Node(
            state=seqgame.get_seq(),
            reward=0,
            done=False,
            action=None,
            parent=RootParentNode(env=seqgame),
            mcts=mcts,
            level=0,
            tile_ranges_done=[]
        )
    # Read valid moves from the tree node. The simulations move `seqgame`
    # through many hypothetical states, so seqgame.seq is not the game state.
    valid_moves = tree_node.valid_actions
    print("Valid moves", [i for i in range(seqgame.action_size) if valid_moves[i]==1])
    mcts_probs, action, next_node = mcts.compute_action(tree_node)
    print(mcts_probs)
    print(f"This is the action: {action}")

    if valid_moves[action] == 0:
        # Stay on the current node and search again.
        print("Invalid action")
        continue

    # The chosen child already holds the post-action state and its score, so
    # advance the tree instead of re-applying the action to `seqgame`.
    tree_node = next_node
    state = tree_node.state
    print(f"The current level is: {tree_node.level}")

    print(tree_node.reward)
    is_terminal = tree_node.done

    if is_terminal:
        print("Game ended")
        break

In [None]:
# Fresh game on test sequence #1, scored by the DeepSTARR model restored above.
sequence = data_module.x_test[1].numpy()
seqgame = SeqGame(sequence, standard_cnn)

# Search settings (10000 simulations per move).
mcts_config = dict(
    puct_coefficient=2.0,
    num_simulations=10000,
    temperature=1.5,
    dirichlet_epsilon=0.25,
    dirichlet_noise=0.03,
    argmax_tree_policy=False,
    add_dirichlet_noise=True,
)

# Freshly initialised (untrained) policy/value network, switched to eval mode.
model = CNN_v0(seqgame.action_size)
model.eval()
mcts = MCTS(model, mcts_config)

initial_sequence = seqgame.get_seq()
# The next four names are not read by the loop below. They are kept so they
# stay defined in the notebook namespace.
tile_ranges = generate_tile_ranges(sequence.shape[1], 5, 5)
tile_ranges_done = []
prev_score = 0
test_sequence = seqgame.get_seq()

root_node = Node(
    state=initial_sequence,
    reward=0,
    done=False,
    action=None,
    parent=RootParentNode(env=seqgame),
    mcts=mcts,
    level=0,
    tile_ranges_done=[],
)

# Each move runs a full search from the current root, then descends into the
# chosen child (reusing its subtree) until a terminal node is reached.
while not root_node.done:
    print("Current sequence:", convert_elements(root_node.state[-1,:]))
    valid_moves = root_node.valid_actions
    print("Valid moves:", [i for i in range(seqgame.action_size) if valid_moves[i] == 1])

    mcts_probs, action, next_node = mcts.compute_action(root_node)
    print("MCTS probabilities:", mcts_probs)
    print("Selected action:", action)
    print(valid_moves)

    if valid_moves[action] == 0:
        # Search again from the same root.
        print("Invalid action, skipping.")
        continue

    print(root_node.tile_ranges_done)
    print(root_node.reward)
    root_node = next_node

print("Game ended.")


In [47]:
# Play a full game on test sequence #1 with MCTS driven by CNN_v0.
sequence = data_module.x_test[1].numpy()
seqgame = SeqGame(sequence, standard_cnn)

mcts_config = dict(
    puct_coefficient=2.0,
    num_simulations=10000,
    temperature=1.5,
    dirichlet_epsilon=0.25,
    dirichlet_noise=0.03,
    argmax_tree_policy=False,
    add_dirichlet_noise=True,
)

model = CNN_v0(seqgame.action_size)
model.eval()
mcts = MCTS(model, mcts_config)

initial_sequence = seqgame.get_seq()
# Not read by the loop below; kept so these names remain defined.
tile_ranges = generate_tile_ranges(sequence.shape[1], 5, 5)
tile_ranges_done = []
prev_score = 0
test_sequence = seqgame.get_seq()

root_node = Node(
    state=initial_sequence,
    reward=0,
    done=False,
    action=None,
    parent=RootParentNode(env=seqgame),
    mcts=mcts,
    level=0,
    tile_ranges_done=[],
)

# Search, pick an action, descend into the chosen child; stop on a terminal node.
while not root_node.done:
    print("Current sequence:", convert_elements(root_node.state[-1,:]))
    valid_moves = root_node.valid_actions
    print("Valid moves:", [i for i in range(seqgame.action_size) if valid_moves[i] == 1])

    mcts_probs, action, next_node = mcts.compute_action(root_node)
    print("MCTS probabilities:", mcts_probs)
    print("Selected action:", action)

    if valid_moves[action] == 0:
        print("Invalid action, skipping.")
        continue

    print(root_node.tile_ranges_done)
    print(root_node.reward)
    root_node = next_node

print("Game ended.")


Current sequence: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0.]
Valid moves: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
MCTS probabilities: [8.2636916e-06 7.2403163e-05 5.3674253e-06 2.9216562e-06 5.3674253e-06
 5.3674253e-06 5.3674253e-06 5.3674253e-06 5.3674253e-06 5.3674253e-06
 5.3674253e-06 5.3674253e-06 2.9216562e-06 2.9216562e-06 5.3674253e-06
 1.1548863e-05 5.3674253e-06 9.9936849e-01 5.3674253e-06 5.3674253e-06
 4.8417147e-05 5.3674253e-06 5.3674253e-06 5.3674253e-06 5.3674253e-06
 2.9216562e-06 5.3674253e-06 5.3674253e-06 5.3674253e-06 5.3674253e-06
 1.8698600e-04 5.3674253e-06 5.3674253e-06 5.3674253e-06 5.3674253e-06
 5.3674253e-06 5.3674253e-06 5.3674253e-06 5.3674253e-06 3.2665117e-05
 5.4109831e-05 5.3674253e-06 5.36742