In [1]:
import os, h5py
os.environ['CUDA_VISIBLE_DEVICES']='1'
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import logomaker
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from tqdm import tqdm
import copy, math
import collections

from cremerl import utils, model_zoo, shuffle

import shuffle_test

#import gymnasium as gym

import logging

# Set the logging level to WARNING
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

In [2]:
expt_name = 'DeepSTARR'

# load data
data_path = '../../data/'
filepath = os.path.join(data_path, expt_name+'_data.h5')
data_module = utils.H5DataModule(filepath, batch_size=100, lower_case=False, transpose=False)


In [3]:
deepstarr2 = model_zoo.deepstarr(2)
loss = torch.nn.MSELoss()
optimizer_dict = utils.configure_optimizer(deepstarr2, lr=0.001, weight_decay=1e-6, decay_factor=0.1, patience=5, monitor='val_loss')
standard_cnn = model_zoo.DeepSTARR(deepstarr2,
                                  criterion=loss,
                                  optimizer=optimizer_dict)

# load checkpoint for model with best validation performance
standard_cnn = utils.load_model_from_checkpoint(standard_cnn, 'DeepSTARR_standard.ckpt')

# evaluate best model
pred = utils.get_predictions(standard_cnn, data_module.x_test[np.newaxis,100], batch_size=100)

2023-08-11 17:30:58.878604: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

In [4]:
def get_swap_greedy(x, x_mut, tile_ranges):
    ori = x.copy()
    mut = x_mut.copy()
    for tile_range in tile_ranges:
        ori[:, tile_range[0]:tile_range[1]] = x_mut[:, tile_range[0]:tile_range[1]]
        mut[:, tile_range[0]:tile_range[1]] = x[:, tile_range[0]:tile_range[1]]

    return ori, mut

def get_score(pred):
    score1 = pred[0] - pred[2]
    score2 = pred[3] - pred[1]
    return (score1+score2)[0], score1+score2

def generate_tile_ranges(sequence_length, window_size, stride):
    ranges = []
    start = np.arange(0, sequence_length - window_size + stride, stride)

    for s in start:
        e = min(s + window_size, sequence_length)
        ranges.append([s, e])

    if start[-1] + window_size - stride < sequence_length:  # Adjust the last range
        ranges[-1][1] = sequence_length

    return ranges

In [5]:
def get_batch(x, tile_range, tile_ranges_ori, trials):
    test_batch = []
    for i in range(trials):
        test_batch.append(x)
        x_mut = shuffle.dinuc_shuffle(x.copy())
        test_batch.append(x_mut)

        ori = x.copy()
        mut = x_mut.copy()
        
        ori, mut = get_swap_greedy(ori, mut, tile_ranges_ori)
        
        ori[:, tile_range[0]:tile_range[1]] = x_mut[:, tile_range[0]:tile_range[1]]
        mut[:, tile_range[0]:tile_range[1]] = x[:, tile_range[0]:tile_range[1]]
        
        test_batch.append(ori)
        test_batch.append(mut)

    #print(np.array(test_batch).shape)
    return np.array(test_batch)


def get_batch_score(pred, trials):

    score = []
    score_sep = []
    for i in range(0, pred.shape[0], 2):
        # print(f"Viewing number {i}")
        score1 = pred[0] - pred[i]
        score2 = pred[i+1] - pred[1]
        score.append((np.sum((score1, score2)[0])).tolist()) #np.sum(score1+score2, keepdims=True)
        score_sep.append((score1+score2).tolist())
        
    # print(score)
        
    final = np.sum(np.array(score), axis=0)/trials

    #max_ind = np.argmax(final)
    #block_ind = np.argmax(np.array(score)[:, max_ind])
    #print(np.array(total_score)[:, max_ind])
    total_score_sep = np.sum(np.array(score_sep), axis=0)/trials

    #print(np.max(score))
    return final

In [6]:
def extend_sequence(one_hot_sequence):
    A, L = one_hot_sequence.shape

    # Create an all-ones row
    ones_row = np.zeros(L)

    # Add the all-ones row to the original sequence
    new_sequence = np.vstack((one_hot_sequence, ones_row))

    return np.array(new_sequence, dtype='float32')

def taking_action(sequence_with_ones, tile_range):
    start_idx, end_idx = tile_range

    # Ensure the start_idx and end_idx are within valid bounds
    #if start_idx < 0 or start_idx >= sequence_with_ones.shape[1] or end_idx < 0 or end_idx >= sequence_with_ones.shape[1]:
    #    raise ValueError("Invalid tile range indices.")

    # Copy the input sequence to avoid modifying the original sequence
    modified_sequence = sequence_with_ones.copy()

    # Modify the last row within the specified tile range
    modified_sequence[-1, start_idx:end_idx] = 1

    return np.array(modified_sequence, dtype='float32')

In [7]:
def convert_elements(input_list):
    input_list = input_list.tolist()
    num_columns = 5  # Number of elements to process in each group

    # Calculate the number of elements needed to pad the list
    padding_length = num_columns - (len(input_list) % num_columns)
    last_value = input_list[-1]
    padded_list = input_list + [last_value] * padding_length

    # Convert the padded list to a NumPy array for efficient operations
    input_array = np.array(padded_list)
    reshaped_array = input_array.reshape(-1, num_columns)

    # Check if each row has the same value (all 0s or all 1s)
    row_all_zeros = np.all(reshaped_array == 0, axis=1)
    row_all_ones = np.all(reshaped_array == 1, axis=1)

    # Replace all 0s with 0 and all 1s with 1 in the result array
    output_array = np.where(row_all_zeros, 0, np.where(row_all_ones, 1, reshaped_array[:, 0]))

    # Flatten the result array to get the final output list
    output_list = output_array.flatten()

    return output_list

In [8]:
class SeqGame:
    def __init__(self, sequence, model_func):
        self.seq = sequence
        self.ori_seq = sequence.copy()
        self.tile_ranges = generate_tile_ranges(sequence.shape[1], 5, 5)
        self.levels = 20
        self.num_trials = 10
        self.action_size = 50
        
        self.prev_score = -float("inf")
        self.current_score = 0
        
        self.trainer = pl.Trainer(accelerator='gpu', devices='1', logger=None, enable_progress_bar=False)
        self.model = model_func
        
        if self.seq.shape[0]!=5:
            self.seq = extend_sequence(self.seq)
            self.ori_seq = extend_sequence(self.ori_seq)
        
    
    def get_initial_state(self):
        self.seq = self.ori_seq.copy()
        
        return self.seq
    
    
    def get_next_state(self, action, tile_ranges_done):
        self.prev_score = self.current_score
        # self.current_level += 1
        
        self.seq = taking_action(self.seq, self.tile_ranges[action])
        
        batch = get_batch(self.seq[:4, :], self.tile_ranges[action], tile_ranges_done, self.num_trials)
        dataloader = torch.utils.data.DataLoader(batch, batch_size=100, shuffle=False)
        pred = np.concatenate(self.trainer.predict(self.model, dataloaders=dataloader))
        
        self.current_score = np.tanh(1 * get_batch_score(pred, self.num_trials)) #ADDED TANH
        
        return self.seq
    
    def get_valid_moves(self):
        return (convert_elements(self.seq[-1, :]) == 0).astype(np.uint8)
    
    def terminate(self, level, current_score, parent_score): #state
        # if self.current_level >= self.levels:
        #     return True
        # if self.current_score < self.prev_score:
        #     return True
        
        if level >= self.levels:
            return True
        if current_score < parent_score:
            return True
    
        return False
    
    def set_seq(self, seq):
        self.seq = seq
    
    def get_seq(self):
        return self.seq.copy()
    
    def get_score(self):
        return self.current_score

In [9]:
class Node:
    def __init__(self, action, state, done, reward, mcts, level, tile_ranges_done, parent=None):
        self.env = parent.env
        self.action = action
        
        self.is_expanded = False
        self.parent = parent
        self.children = {}
        
        self.action_space_size = self.env.action_size
        self.child_total_value = np.zeros(
            [self.action_space_size], dtype=np.float32
        ) # Q
        self.child_priors = np.zeros([self.action_space_size], dtype=np.float32) # P
        self.child_number_visits = np.zeros(
            [self.action_space_size], dtype=np.float32
        ) # N
        self.valid_actions = (convert_elements(state[-1, :]) == 0).astype(np.uint8)
        
        self.reward = reward
        self.done = done
        self.state = state
        self.level = level
        
        self.tile_ranges_done = tile_ranges_done
        
        self.mcts = mcts
    
    @property
    def number_visits(self):
        return self.parent.child_number_visits[self.action]
    
    @number_visits.setter
    def number_visits(self, value):
        self.parent.child_number_visits[self.action] = value
        
    @property
    def total_value(self):
        return self.parent.child_total_value[self.action]
    
    @total_value.setter
    def total_value(self, value):
        self.parent.child_total_value[self.action] = value
        
    def child_Q(self):
        return self.child_total_value / (1 + self.child_number_visits)
    
    def child_U(self):
        return (
            math.sqrt(self.number_visits)
            * self.child_priors
            / (1 + self.child_number_visits)
        )
    
    def best_action(self):
        child_score = self.child_Q() + self.mcts.c_puct * self.child_U()
        masked_child_score = child_score
        # masked_child_score[~self.valid_actions] = -np.inf
        masked_child_score = masked_child_score * self.valid_actions
        return np.argmax(masked_child_score)
    
    def select(self):
        current_node = self
        while current_node.is_expanded:
            best_action = current_node.best_action()
            current_node = current_node.get_child(best_action)
        return current_node
    
    def expand(self, child_priors):
        self.is_expanded = True
        self.child_priors = child_priors
        
    def set_state(self, state):
        self.state = state
        self.valid_actions = (convert_elements(state[-1, :]) == 0).astype(np.uint8)
    
    def get_child(self, action):
        if action not in self.children:

            self.env.set_seq(self.state.copy())
            next_state = self.env.get_next_state(action, self.tile_ranges_done)
            # self.tile_ranges_done.append(self.tile_ranges.pop(action))
            new_tile_ranges_done = copy.deepcopy(self.tile_ranges_done)
            # print(new_tile_ranges_done)
            new_tile_ranges_done.append(self.env.tile_ranges[action])
            # swap tile_ranges
            reward = self.env.get_score()
            terminated = self.env.terminate(self.level, reward, self.parent.reward)
            self.children[action] = Node(
                state=next_state, 
                action=action, 
                parent=self, 
                reward=reward,
                done=terminated,
                mcts=self.mcts, 
                level=self.level+1, 
                tile_ranges_done=new_tile_ranges_done
            )
        return self.children[action]
    
    def backup(self, value):
        current = self
        while current.parent is not None:
            current.number_visits += 1
            current.total_value += value
            current = current.parent

class RootParentNode:
    def __init__(self, env):
        self.parent = None
        self.child_total_value = collections.defaultdict(float)
        self.child_number_visits = collections.defaultdict(float)
        self.env = env
        self.reward = -np.inf

class MCTS:
    def __init__(self, model, mcts_param):
        self.model = model
        self.temperature = mcts_param["temperature"]
        self.dir_epsilon = mcts_param["dirichlet_epsilon"]
        self.dir_noise = mcts_param["dirichlet_noise"]
        self.num_sims = mcts_param["num_simulations"]
        self.exploit = mcts_param["argmax_tree_policy"]
        self.add_dirichlet_noise = mcts_param["add_dirichlet_noise"]
        self.c_puct = mcts_param["puct_coefficient"]
    
    def compute_action(self, node):
        for _ in range(self.num_sims):
            leaf = node.select()
            if leaf.done:
                value = leaf.reward
            else:
                child_priors, value = self.model(torch.tensor(leaf.state).unsqueeze(0))
                child_priors = torch.softmax(child_priors, axis=1).squeeze(0).cpu().detach().numpy()
                if self.add_dirichlet_noise:
                    child_priors = (1 - self.dir_epsilon) * child_priors
                    child_priors += self.dir_epsilon * np.random.dirichlet(
                        [self.dir_noise] * child_priors.size
                    )
                
                leaf.expand(child_priors)
            leaf.backup(value)
            
        tree_policy = node.child_number_visits / node.number_visits
        tree_policy = tree_policy / np.max(tree_policy)
        tree_policy = np.power(tree_policy, self.temperature)
        tree_policy = tree_policy / np.sum(tree_policy)
        if self.exploit:
            action = np.argmax(tree_policy)
        else:
            action = np.random.choice(np.arange(node.action_space_size), p=tree_policy)
        return tree_policy, action, node.children[action]

In [10]:
class CNN_v0(nn.Module):
    def __init__(self, action_dim):
        super(CNN_v0, self).__init__()
        
        self.convblock1 = nn.Sequential(
            nn.Conv1d(5, 32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32), 
            nn.ReLU()
        )
        
        self.convblock2 = nn.Sequential(
            nn.Conv1d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64), 
            nn.ReLU()
        )
        
        self.convblock3 = nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128), 
            nn.ReLU()
        )
        
        self.policyHead = nn.Sequential(
            nn.Conv1d(128, 50, kernel_size=3, padding=1), 
            nn.BatchNorm1d(50), 
            nn.Flatten(), 
            nn.Linear(50 * 249, action_dim) # 4 * action_dim
        )
        
        self.valueHead = nn.Sequential(
            nn.Conv1d(128, 50, kernel_size=3, padding=1), 
            nn.BatchNorm1d(50), 
            nn.Flatten(), 
            nn.Linear(50 * 249, 128), 
            nn.Linear(128, 1), 
            nn.Tanh()
        )
    
    def forward(self, x):
        x = self.convblock1(x)
        x = self.convblock2(x)
        x = self.convblock3(x)
        
        policy = self.policyHead(x)
        value = self.valueHead(x)
        return policy, value

In [None]:
sequence = data_module.x_test[1].numpy()
seqgame = SeqGame(sequence, standard_cnn)


mcts_config = {
    "puct_coefficient": 2.0,
    "num_simulations": 10000,
    "temperature": 1.5,
    "dirichlet_epsilon": 0.25,
    "dirichlet_noise": 0.03,
    "argmax_tree_policy": False,
    "add_dirichlet_noise": True,
}

model = CNN_v0(seqgame.action_size)
model.eval()

mcts = MCTS(model, mcts_config)

# Get the initial sequence and create the root node
initial_sequence = seqgame.get_seq()
    
root_node = Node(
    state=initial_sequence,
    reward=0,
    done=False,
    action=None,
    parent=RootParentNode(env=seqgame),
    mcts=mcts, 
    level=0, 
    tile_ranges_done=[]
)

while not root_node.done:  # Loop until the root node indicates the game is done
    print("Current sequence:", convert_elements(root_node.state[-1,:]))
    valid_moves = root_node.valid_actions
    print("Valid moves:", [i for i in range(seqgame.action_size) if valid_moves[i] == 1])

    # Perform simulations and select an action using MCTS
    mcts_probs, action, next_node = mcts.compute_action(root_node)
    print("MCTS probabilities:", mcts_probs)
    print("Selected action:", action)
    

    if valid_moves[action] == 0:
        print("Invalid action, skipping.")
        continue
    
    print(root_node.tile_ranges_done)
    print(root_node.reward)

    root_node = next_node

print("Game ended.")


In [27]:
class AlphaDNA:
    def __init__(self, model, optimizer, env, args, mcts_config = {
        "puct_coefficient": 2.0,
        "num_simulations": 10000,
        "temperature": 1.5,
        "dirichlet_epsilon": 0.25,
        "dirichlet_noise": 0.03,
        "argmax_tree_policy": False,
        "add_dirichlet_noise": True,}
                 ):
        self.model = model
        self.optimizer = optimizer
        self.env = env
        self.args = args
        self.mcts = MCTS(model, mcts_config)
        
        self.initial_sequence = env.get_seq()
    
    def selfPlay(self):
        memory = []
        state = self.env.get_seq()

        root_node = Node(
            state=self.initial_sequence,
            reward=0,
            done=False,
            action=None,
            parent=RootParentNode(env=seqgame),
            mcts=self.mcts, 
            level=0, 
            tile_ranges_done=[]
        )

        while True:  # Loop until the root node indicates the game is done
            valid_moves = root_node.valid_actions
            mcts_probs, action, next_node = self.mcts.compute_action(root_node)
            
            memory.append((root_node.state, mcts_probs, next_node.reward))

            if valid_moves[action] == 0:
                print("Invalid action, skipping.")
                continue
            
            if next_node.done:
                return memory

            root_node = next_node
    
    def train(self, memory):
        np.random.shuffle(memory)
        for batchIdx in range(0, len(memory), self.args['batch_size']):
            sample = memory[batchIdx:min(len(memory) - 1, batchIdx + self.args['batch_size'])]
            state, policy_targets, value_targets = zip(*sample)
            
            state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets)
            
            state = torch.tensor(state, dtype=torch.float32)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32)
            value_targets = torch.tensor(value_targets, dtype=torch.float32)
            
            out_policy, out_value = self.model(state)
            
            priors = nn.Softmax(dim=-1)(out_policy)
            
            policy_loss = torch.mean(
                -torch.sum(policy_targets * torch.log(priors), dim=-1)
            )
            value_loss = torch.mean(torch.pow(value_targets - out_value, 2))
            
            total_loss = policy_loss + value_loss
            
            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()
            
            return total_loss
        
    def learn(self):
        for iteration in tqdm(range(self.args['num_iterations'])):
            print(f"Iteration {iteration}:")
            memory = []
            
            self.model.eval()
            for selfPlay_iteration in tqdm(range(self.args['num_selfPlay_iterations'])):
                print(f"SelfPlay Iteration {selfPlay_iteration}")
                memory += self.selfPlay()
            
            self.model.train()
            for epoch in tqdm(range(self.args['num_epochs'])):
                print(f"Training Epoch {epoch}")
                loss = self.train(memory)
                print(f"Total Loss: {loss}")
            
            torch.save(self.model.state_dict(), f"model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}.pt")
            
            
            

In [30]:
sequence = data_module.x_test[1].numpy()
seqgame = SeqGame(sequence, standard_cnn)

model = CNN_v0(seqgame.action_size)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

args = {
    'num_iterations': 20, # 3 
    'num_selfPlay_iterations': 2, # 500
    'num_epochs': 4, 
    'batch_size': 2, # 64
}

mcts_config = {
    "puct_coefficient": 2.0,
    "num_simulations": 1000,
    "temperature": 1.5,
    "dirichlet_epsilon": 0.25,
    "dirichlet_noise": 0.03,
    "argmax_tree_policy": False,
    "add_dirichlet_noise": True,
}

alphadna = AlphaDNA(model, optimizer, seqgame, args, mcts_config)
alphadna.learn()

  0%|          | 0/20 [00:00<?, ?it/s]

Iteration 0:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:05<00:00,  2.97s/it]
100%|██████████| 4/4 [00:00<00:00, 134.91it/s]
  5%|▌         | 1/20 [00:05<01:53,  6.00s/it]

Training Epoch 0
Total Loss: 5.298501014709473
Training Epoch 1
Total Loss: 22.32219696044922
Training Epoch 2
Total Loss: 28.530973434448242
Training Epoch 3
Total Loss: 3.477583885192871
Iteration 1:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:17<00:00,  8.86s/it]
100%|██████████| 4/4 [00:00<00:00, 138.49it/s]
 10%|█         | 2/20 [00:23<03:52, 12.93s/it]

Training Epoch 0
Total Loss: 22.160442352294922
Training Epoch 1
Total Loss: 22.214080810546875
Training Epoch 2
Total Loss: 21.115890502929688
Training Epoch 3
Total Loss: 20.44860076904297
Iteration 2:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:17<00:00,  8.88s/it]
100%|██████████| 4/4 [00:00<00:00, 131.88it/s]
 15%|█▌        | 3/20 [00:41<04:17, 15.16s/it]

Training Epoch 0
Total Loss: 18.718368530273438
Training Epoch 1
Total Loss: 11.52180004119873
Training Epoch 2
Total Loss: 14.892873764038086
Training Epoch 3
Total Loss: 16.590232849121094
Iteration 3:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:26<00:00, 13.14s/it]
100%|██████████| 4/4 [00:00<00:00, 118.77it/s]
 20%|██        | 4/20 [01:07<05:13, 19.57s/it]

Training Epoch 0
Total Loss: 11.645698547363281
Training Epoch 1
Total Loss: 11.100095748901367
Training Epoch 2
Total Loss: 12.235994338989258
Training Epoch 3
Total Loss: 16.988523483276367
Iteration 4:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:24<00:00, 12.23s/it]
100%|██████████| 4/4 [00:00<00:00, 142.80it/s]
 25%|██▌       | 5/20 [01:32<05:20, 21.35s/it]

Training Epoch 0
Total Loss: 9.83155345916748
Training Epoch 1
Total Loss: 8.710453033447266
Training Epoch 2
Total Loss: 5.074153423309326
Training Epoch 3
Total Loss: 7.747463703155518
Iteration 5:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:26<00:00, 13.20s/it]
100%|██████████| 4/4 [00:00<00:00, 131.61it/s]
 30%|███       | 6/20 [01:58<05:23, 23.09s/it]

Training Epoch 0
Total Loss: 6.70822286605835
Training Epoch 1
Total Loss: 7.176239490509033
Training Epoch 2
Total Loss: 3.908320426940918
Training Epoch 3
Total Loss: 6.363714218139648
Iteration 6:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:27<00:00, 13.54s/it]
100%|██████████| 4/4 [00:00<00:00, 133.70it/s]
 35%|███▌      | 7/20 [02:26<05:17, 24.41s/it]

Training Epoch 0
Total Loss: 3.6277413368225098
Training Epoch 1
Total Loss: 7.015326499938965
Training Epoch 2
Total Loss: 4.756281852722168
Training Epoch 3
Total Loss: 3.5805766582489014
Iteration 7:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:33<00:00, 16.76s/it]
100%|██████████| 4/4 [00:00<00:00, 128.31it/s]
 40%|████      | 8/20 [02:59<05:27, 27.33s/it]

Training Epoch 0
Total Loss: 2.1308603286743164
Training Epoch 1
Total Loss: 3.171949863433838
Training Epoch 2
Total Loss: 9.619722366333008
Training Epoch 3
Total Loss: 4.922009468078613
Iteration 8:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:16<00:00,  8.43s/it]
100%|██████████| 4/4 [00:00<00:00, 130.24it/s]
 45%|████▌     | 9/20 [03:16<04:24, 24.07s/it]

Training Epoch 0
Total Loss: 13.017902374267578
Training Epoch 1
Total Loss: 10.377541542053223
Training Epoch 2
Total Loss: 3.364511728286743
Training Epoch 3
Total Loss: 3.805150032043457
Iteration 9:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:35<00:00, 17.51s/it]
100%|██████████| 4/4 [00:00<00:00, 132.21it/s]
 50%|█████     | 10/20 [03:51<04:34, 27.46s/it]

Training Epoch 0
Total Loss: 1.742547869682312
Training Epoch 1
Total Loss: 2.7243292331695557
Training Epoch 2
Total Loss: 2.248629331588745
Training Epoch 3
Total Loss: 2.0029866695404053
Iteration 10:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:21<00:00, 10.50s/it]
100%|██████████| 4/4 [00:00<00:00, 128.21it/s]
 55%|█████▌    | 11/20 [04:12<03:49, 25.51s/it]

Training Epoch 0
Total Loss: 2.198237180709839
Training Epoch 1
Total Loss: 0.6289008259773254
Training Epoch 2
Total Loss: 1.0931742191314697
Training Epoch 3
Total Loss: 4.035010814666748
Iteration 11:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:25<00:00, 12.98s/it]
100%|██████████| 4/4 [00:00<00:00, 134.66it/s]
 60%|██████    | 12/20 [04:38<03:25, 25.66s/it]

Training Epoch 0
Total Loss: 10.132330894470215
Training Epoch 1
Total Loss: 9.091532707214355
Training Epoch 2
Total Loss: 1.5453221797943115
Training Epoch 3
Total Loss: 3.4052395820617676
Iteration 12:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:28<00:00, 14.07s/it]
100%|██████████| 4/4 [00:00<00:00, 136.21it/s]
 65%|██████▌   | 13/20 [05:06<03:04, 26.43s/it]

Training Epoch 0
Total Loss: 1.5085400342941284
Training Epoch 1
Total Loss: 1.0531786680221558
Training Epoch 2
Total Loss: 11.770639419555664
Training Epoch 3
Total Loss: 3.066715717315674
Iteration 13:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:27<00:00, 13.62s/it]
100%|██████████| 4/4 [00:00<00:00, 127.75it/s]
 70%|███████   | 14/20 [05:34<02:40, 26.69s/it]

Training Epoch 0
Total Loss: 5.6675920486450195
Training Epoch 1
Total Loss: 1.8001728057861328
Training Epoch 2
Total Loss: 4.00806188583374
Training Epoch 3
Total Loss: 7.66372013092041
Iteration 14:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:22<00:00, 11.17s/it]
100%|██████████| 4/4 [00:00<00:00, 123.84it/s]
 75%|███████▌  | 15/20 [05:56<02:06, 25.39s/it]

Training Epoch 0
Total Loss: 0.7887231111526489
Training Epoch 1
Total Loss: 17.20461654663086
Training Epoch 2
Total Loss: 1.1866786479949951
Training Epoch 3
Total Loss: 2.547724962234497
Iteration 15:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:28<00:00, 14.01s/it]
100%|██████████| 4/4 [00:00<00:00, 131.93it/s]
 80%|████████  | 16/20 [06:24<01:44, 26.20s/it]

Training Epoch 0
Total Loss: 1.1030526161193848
Training Epoch 1
Total Loss: 8.244341850280762
Training Epoch 2
Total Loss: 4.607093334197998
Training Epoch 3
Total Loss: 1.0634911060333252
Iteration 16:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:30<00:00, 15.41s/it]
100%|██████████| 4/4 [00:00<00:00, 136.12it/s]
 85%|████████▌ | 17/20 [06:55<01:22, 27.60s/it]

Training Epoch 0
Total Loss: 2.801349401473999
Training Epoch 1
Total Loss: 4.642921447753906
Training Epoch 2
Total Loss: 1.8931376934051514
Training Epoch 3
Total Loss: 5.27190637588501
Iteration 17:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:27<00:00, 14.00s/it]
100%|██████████| 4/4 [00:00<00:00, 118.71it/s]
 90%|█████████ | 18/20 [07:23<00:55, 27.74s/it]

Training Epoch 0
Total Loss: 2.2421681880950928
Training Epoch 1
Total Loss: 1.308841347694397
Training Epoch 2
Total Loss: 8.206446647644043
Training Epoch 3
Total Loss: 0.6455740928649902
Iteration 18:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:19<00:00,  9.78s/it]
100%|██████████| 4/4 [00:00<00:00, 135.31it/s]
 95%|█████████▌| 19/20 [07:43<00:25, 25.30s/it]

Training Epoch 0
Total Loss: 3.396185874938965
Training Epoch 1
Total Loss: 0.6861117482185364
Training Epoch 2
Total Loss: 1.7199504375457764
Training Epoch 3
Total Loss: 0.9397327303886414
Iteration 19:




SelfPlay Iteration 0




SelfPlay Iteration 1


100%|██████████| 2/2 [00:19<00:00,  9.61s/it]
100%|██████████| 4/4 [00:00<00:00, 135.75it/s]
100%|██████████| 20/20 [08:02<00:00, 24.12s/it]

Training Epoch 0
Total Loss: 2.9753501415252686
Training Epoch 1
Total Loss: 14.490436553955078
Training Epoch 2
Total Loss: 0.8867294192314148
Training Epoch 3
Total Loss: 16.788740158081055



