# Training

In [None]:
# !cd ../bs-gym/ && pip install -e . -q
# !cd .
# !cd ../pytorch-a2c-ppo-acktr-gail/ && pip install -e . -q

In [None]:
# !which python
# !which pip
# !pip list

In [None]:
import numpy as np
import time
import logging
import sys
import os

import torch
from tqdm import tqdm

from a2c_ppo_acktr.algo import PPO
from a2c_ppo_acktr.storage import RolloutStorage
from bs_gym.gymbattlesnake import BattlesnakeEnv

In [None]:
logger = logging.getLogger('default')

logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s') # %(msecs)d %(name)s 

# getLogger('default') returns the same object on every call, so re-running this
# cell would stack another stdout handler and print every message twice.
# Drop the handler a previous run attached (identified by name) before adding a fresh one.
for old_handler in [h for h in logger.handlers if h.get_name() == 'stdout']:
    logger.removeHandler(old_handler)

stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.set_name('stdout')
stdout_handler.setLevel(logging.DEBUG)
stdout_handler.setFormatter(formatter)

logger.addHandler(stdout_handler)

In [None]:
from performance import check_performance
from policy import SnakePolicyBase, create_policy

from utils import n_opponents,  device
from utils import PathHelper, plot_graphs

# Path / model-group helper used by the training and configuration cells.
u = PathHelper()

# Placeholders: setup_rollouts() fills rollouts/tmp_env, and the config cell
# assigns n_envs, n_steps and CPU_THREADS. None / -1 mean "not configured yet".
rollouts = tmp_env = None
n_envs = n_steps = CPU_THREADS = -1

def setup_rollouts(n_envs) -> None:
    """Allocate the global ``RolloutStorage`` for ``n_envs`` parallel games.

    A temporary ``BattlesnakeEnv`` is created only to read its observation and
    action spaces. It is closed before returning, including when the storage
    allocation fails. The rollout length is taken from the module-level
    ``n_steps``, which the config cell must set before this is called.

    Side effects: sets the globals ``rollouts`` and ``tmp_env``. ``tmp_env`` is
    closed but kept so ``setup_agent`` can read its spaces.

    Raises:
        ValueError: if the global ``n_steps`` has not been set to a positive value.
    """
    global rollouts, tmp_env

    logger.debug('Setting up rollouts...')
    if n_steps <= 0:
        raise ValueError(f'n_steps must be set to a positive value before setup_rollouts() (got {n_steps})')

    tmp_env = BattlesnakeEnv(n_threads=2, n_envs=n_envs)
    try:
        # rollout storage: game turns played and rewards
        # NOTE(review): in upstream pytorch-a2c-ppo-acktr-gail the 5th positional
        # argument is `recurrent_hidden_state_size`. Passing n_steps allocates an
        # (n_steps + 1, n_envs, n_steps) hidden-state buffer. Confirm SnakePolicyBase
        # really needs that size.
        rollouts = RolloutStorage(n_steps,
                                  n_envs,
                                  tmp_env.observation_space.shape,
                                  tmp_env.action_space,
                                  n_steps)
    finally:
        # Only the spaces were needed; release the env even if allocation failed.
        tmp_env.close()

policy = None
next_best_policy = None
second_best_policy = None
agent = None

def setup_agent(model_path=None,
                value_loss_coef=0.5,
                entropy_coef=0.01,
                max_grad_norm=0.5,
                clip_param=0.2,
                ppo_epoch=4,
                num_mini_batch=16,
                eps=1e-5,
                lr=5e-5,
                use_clipped_value_loss=True,
                use_gradient_accumulation=False,
                gradient_accumulation_steps=1
                ) -> None:
    """Build the trainable policy, its two snapshot copies, and the PPO agent.

    ``policy`` is optionally initialised from the checkpoint at ``model_path``.
    ``next_best_policy`` and ``second_best_policy`` start as copies of its weights.
    All other arguments are forwarded unchanged to ``PPO``.

    Requires ``setup_rollouts()`` to have run first: its ``tmp_env`` provides the
    observation/action spaces that size the networks.

    Side effects: sets the globals ``policy``, ``next_best_policy``,
    ``second_best_policy`` and ``agent``.

    Raises:
        RuntimeError: if ``setup_rollouts()`` has not been called yet.
    """
    global policy, next_best_policy, second_best_policy, agent
    logger.debug('Setting up agent...')

    if tmp_env is None:
        raise RuntimeError('setup_rollouts() must be called before setup_agent()')

    obs_shape = tmp_env.observation_space.shape
    action_space = tmp_env.action_space

    # policies
    policy = create_policy(obs_shape, action_space, SnakePolicyBase)
    # load latest model if found
    if model_path is not None:
        # map_location='cpu': checkpoints saved during GPU training would otherwise fail
        # to load on a CPU-only host. load_state_dict copies the weights into the
        # policy's own parameters regardless of device, and setup_env() moves the
        # policy to `device` later.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        policy.load_state_dict(torch.load(model_path, map_location='cpu'))

    next_best_policy = create_policy(obs_shape, action_space, SnakePolicyBase)
    second_best_policy = create_policy(obs_shape, action_space, SnakePolicyBase)
    next_best_policy.load_state_dict(policy.state_dict())
    second_best_policy.load_state_dict(policy.state_dict())

    # TODO: second_best model

    # agent
    agent = PPO(policy,
                value_loss_coef=value_loss_coef,
                entropy_coef=entropy_coef,
                max_grad_norm=max_grad_norm,
                clip_param=clip_param,
                ppo_epoch=ppo_epoch,
                num_mini_batch=num_mini_batch,
                eps=eps,
                lr=lr,
                use_clipped_value_loss=use_clipped_value_loss,
                use_gradient_accumulation=use_gradient_accumulation,
                gradient_accumulation_steps=gradient_accumulation_steps,
                )

env = None
obs = None

def setup_env() -> None:
    """Create the self-play training environment and seed ``rollouts`` with its first observation.

    Requires ``setup_rollouts()`` and ``setup_agent()`` to have run. The opponents are
    ``n_opponents - 1`` references to the policy being trained, plus ``next_best_policy``.

    Side effects: sets the globals ``env`` and ``obs``, and moves the three policies and
    ``rollouts`` to ``device``.
    """
    global env, obs
    logger.debug('Setting up environment...')

    # sanity check
    assert rollouts is not None, 'call setup_rollouts() first'
    assert policy is not None
    assert next_best_policy is not None
    assert second_best_policy is not None
    assert agent is not None

    # Move the networks to the target device *before* handing them to the env.
    # The env is built with device=device and holds the opponent policies, so they
    # should already live there if the env runs them (e.g. during reset()).
    policy.to(device)
    next_best_policy.to(device)
    second_best_policy.to(device)

    # environment: the live policy plays against itself and against the prior best
    opponent_policies = [policy] * (n_opponents - 1)
    opponent_policies.append(next_best_policy)

    env = BattlesnakeEnv(n_threads=CPU_THREADS, n_envs=n_envs, opponents=opponent_policies, device=device)

    obs = env.reset()
    rollouts.obs[0].copy_(torch.tensor(obs))

    # send storage to the same device as the networks
    rollouts.to(device)

def count_parameters(model) -> int:
    """Return how many scalar parameters of ``model`` have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


In [None]:
last_iteration = -1

# Per-iteration training histories. train() rebinds them to the lists returned by
# u.load_data() and appends one entry per iteration; the plotting cell reads them.
rewards = []
value_losses = []
action_losses = []
dist_entropies = []
lengths = []


def _mean_or_nan(values):
    """Return ``np.mean(values)``, or NaN (without numpy's empty-slice warning) if ``values`` is empty."""
    return np.mean(values) if len(values) > 0 else np.nan


def _collect_rollouts():
    """Play ``n_steps`` environment steps with the global ``policy``, inserting each step into ``rollouts``.

    Returns:
        (episode_rewards, episode_lengths): one entry for every ``info`` dict that
        reported an ``'episode'`` during these steps. Both may be empty.
    """
    episode_rewards = []
    episode_lengths = []
    for step in tqdm(range(n_steps)):
        with torch.no_grad():
            value, action, action_log_prob, recurrent_hidden_states = policy.act(rollouts.obs[step],
                                                            rollouts.recurrent_hidden_states[step],
                                                            rollouts.masks[step])
        obs, reward, done, infos = env.step(action.cpu().squeeze())
        obs = torch.tensor(obs)
        reward = torch.tensor(reward).unsqueeze(1)

        for info in infos:
            if 'episode' in info.keys():
                episode_rewards.append(info['episode']['r'])
                episode_lengths.append(info['episode']['l'])

        # 0.0 for envs that reported done on this step, 1.0 otherwise
        masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
        bad_masks = torch.FloatTensor([[0.0] if 'bad_transition' in info.keys() else [1.0] for info in infos])
        rollouts.insert(obs, recurrent_hidden_states, action, action_log_prob, value, reward, masks, bad_masks)
    return episode_rewards, episode_lengths


def train(num_updates, start_iteration=0, check_perf=True, test_every=5, required_winrate=0.3, force_save=-1,
          gamma=0.99, lambd=0.95, datafile=None):
    """Run ``num_updates`` PPO iterations (collect rollouts, update, log, checkpoint).

    Args:
        num_updates: number of iterations to run.
        start_iteration: index of the first iteration. Used for data loading, logging and checkpoint names.
        check_perf: whether to periodically play ``policy`` against ``next_best_policy``.
        test_every: run that check every ``test_every`` iterations (never at iteration 0; disabled if <= 0).
        required_winrate: winrate above which ``policy`` is promoted to ``next_best_policy`` and saved.
        force_save: if > 0, every ``force_save`` iterations without a promotion, copy ``policy``
            into ``next_best_policy`` and save it as a ``tmp_iter`` checkpoint.
        gamma, lambd: forwarded to ``rollouts.compute_returns``.
        datafile: forwarded to ``u.load_data`` / ``u.save_data``.

    Returns:
        (rewards, value_losses, lengths): the module-level history lists.

    Side effects: rebinds the module-level history lists and ``last_iteration``, and writes
    checkpoints with ``torch.save``.
    """
    logger.debug("Starting training.")
    # action_losses / dist_entropies must be declared global as well. Otherwise the
    # assignments below create locals, and the module-level lists used for plotting stay empty.
    global last_iteration, rewards, value_losses, action_losses, dist_entropies, lengths

    data = u.load_data(datafile=datafile, start_iteration=start_iteration)
    rewards = data['rewards']
    value_losses = data['value_losses']
    action_losses = data['action_losses']
    dist_entropies = data['dist_entropies']
    lengths = data['lengths']

    training_start = time.time()
    for i in range(num_updates):
        iteration_start = time.time()
        j = start_iteration + i

        policy.eval()
        logger.info(f"Iteration {j}: Generating rollouts")
        episode_rewards, episode_lengths = _collect_rollouts()

        with torch.no_grad():
            next_value = policy.get_value(
                rollouts.obs[-1],
                rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]
            ).detach()

        policy.train()

        logger.debug("Training policy on rollouts...")
        # Positional args follow upstream a2c_ppo_acktr's compute_returns(next_value, use_gae,
        # gamma, gae_lambda, use_proper_time_limits). Confirm if this is a fork.
        rollouts.compute_returns(next_value, True, gamma, lambd, False)
        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()

        policy.eval()

        total_num_steps = (j + 1) * n_envs * n_steps

        # An iteration may end without any finished episode; record NaN instead of warning.
        lengths.append(_mean_or_nan(episode_lengths))
        rewards.append(_mean_or_nan(episode_rewards))
        value_losses.append(value_loss)
        action_losses.append(action_loss)
        dist_entropies.append(dist_entropy)

        time_before_save = time.time()
        u.save_data(rewards=rewards,
                    value_losses= value_losses,
                    action_losses=action_losses,
                    dist_entropies=dist_entropies, 
                    lengths=lengths, 
                    iteration=j, datafile=datafile)

        updated_after_check = False
        if check_perf and test_every > 0 and j % test_every == 0 and j != 0:
            test_start = time.time()

            logger.info("=" * 80)
            logger.info(f"Iteration {j} Results")

            # TODO: do in parallel
            # Check the performance of the current policy against the prior best
            winrate = check_performance(policy, next_best_policy, device=torch.device(device))

            logger.info(f"Winrate vs prior best: {winrate*100:.2f}%")
            # np.max / np.min raise ValueError on an empty list
            if episode_lengths:
                logger.info(f"Median Length: {np.median(episode_lengths)}")
                logger.info(f"Max Length: {np.max(episode_lengths)}")
                logger.info(f"Min Length: {np.min(episode_lengths)}")
            else:
                logger.info("No episodes finished this iteration; length stats unavailable.")

            if winrate > required_winrate:
                logger.info(f"Policy winrate is > {required_winrate * 100:g}%. Updating prior best model")
                second_best_policy.load_state_dict(next_best_policy.state_dict()) # shift the prior best to second best
                next_best_policy.load_state_dict(policy.state_dict())
                logger.info("Saving latest best model.")
                updated_after_check = True
                torch.save(next_best_policy.state_dict(), u.get_modelpath(iteration=j))
            else:
                logger.info("Policy has not learned enough yet... keep training!")

            logger.info(f"Time taken (performance test total): {time.time() - test_start:.2f} seconds")
            logger.info("-" * 80)

        if not updated_after_check and force_save > 0 and j % force_save == 0:
            logger.info("Force updated next best policy.")
            next_best_policy.load_state_dict(policy.state_dict())
            logger.info("Force saving latest model.")
            torch.save(policy.state_dict(), u.get_modelpath(custom=f'tmp_iter', iteration=j))

        iteration_end = time.time()

        logger.info(f"Time taken (model save total): {iteration_end - time_before_save:.2f} seconds")
        logger.info(f"Time taken (iteration total): {iteration_end - iteration_start:.2f} seconds")
        logger.info(f"Time taken (since start): {iteration_end - training_start:.2f} seconds")
        logger.info(f"Completed {total_num_steps} steps")

    logger.info("Saving final model.")
    last_iteration = start_iteration + num_updates - 1
    # pass `iteration=` by keyword, like every other get_modelpath call in this notebook
    torch.save(policy.state_dict(), u.get_modelpath(iteration=last_iteration))
    torch.save(policy.state_dict(), u.get_modelpath_latest())

    return rewards, value_losses, lengths

In [None]:

torch.backends.cuda.matmul.allow_tf32 = False # Disable TF32: CUDA matmuls run at full FP32 precision.
# os.cpu_count() returns None when the count cannot be determined; fall back to 1 thread.
CPU_THREADS = os.cpu_count() or 1

# FORMULA: The total training set size per iteration is n_envs * n_steps
n_envs = 210     # parallel environments
n_steps = 600    # steps per environment to simulate

num_updates = 6000 # target iterations to run
test_every = 50 # iterations per test
required_winrate = 0.3 # winrate vs prior best needed to promote the current policy
check_perf = True
force_save = 5 # iterations per force save
# TODO: flexible test_every

MODEL_GROUP = 'test10'

u.set_modelgroup(MODEL_GROUP, read_tmp=True)



# =====================================
latest_model_path, iteration = u.get_latest_model()

# configure logger
# Re-running this cell would otherwise stack another FileHandler on the shared logger,
# duplicating every log line and leaking the old file descriptor. Close and drop the
# handler a previous run attached (identified by name) first.
for old_handler in [h for h in logger.handlers if h.get_name() == 'training_file']:
    logger.removeHandler(old_handler)
    old_handler.close()

file_handler = logging.FileHandler(u.get_modelpath(custom='training', ext='log'))
file_handler.set_name('training_file')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

print(f"Latest model: {latest_model_path}")
print(f"Latest iteration: {iteration}")

In [None]:

# Allocate rollout storage, then build the PPO agent. The agent loads weights from
# latest_model_path when it is not None.
setup_rollouts(n_envs)

ppo_hyperparams = dict(
    value_loss_coef=0.5,
    entropy_coef=0.01,
    max_grad_norm=0.5,
    clip_param=0.2,
    num_mini_batch=32,
    ppo_epoch=4,
    eps=1e-5,
    lr=5e-5,
    use_clipped_value_loss=True,
    use_gradient_accumulation=False,
    gradient_accumulation_steps=1,
)
setup_agent(model_path=latest_model_path, **ppo_hyperparams)
logger.debug(f"Trainable Parameters: {count_parameters(policy)}")

setup_env()

# Continue numbering right after the iteration reported by u.get_latest_model().
rewards2, value_losses2, lengths2 = train(
    num_updates,
    start_iteration=iteration + 1,
    test_every=test_every,
    check_perf=check_perf,
    required_winrate=required_winrate,
    force_save=force_save,
    gamma=0.998,
    lambd=0.95,
)

# TODO: send notification when complete

In [None]:
# Plot the module-level per-iteration histories (rewards, losses, entropy, episode length).
training_histories = (rewards, value_losses, action_losses, dist_entropies, lengths)
plot_graphs(*training_histories)