# Training

In [None]:
# !cd ../bs-gym/ && pip install -e . -q
# !cd .
# !cd ../pytorch-a2c-ppo-acktr-gail/ && pip install -e . -q

In [None]:
# !which python
# !which pip
# !pip list

In [None]:
import numpy as np
import time
from datetime import datetime
import torch
from tqdm import tqdm

from a2c_ppo_acktr.algo import PPO
from a2c_ppo_acktr.storage import RolloutStorage
from bs_gym.gymbattlesnake import BattlesnakeEnv

In [None]:
from performance import check_performance
from policy import SnakePolicyBase, create_policy

from utils import n_opponents,  device
from utils import PathHelper, plot_graphs

# Helper object used by the cells below for model paths and metric persistence.
u = PathHelper()

# Module-level state. `rollouts` and `tmp_env` are rebound by setup_rollouts();
# the three sizes are -1 placeholders until the configuration cell assigns them.
rollouts = tmp_env = None
n_envs = n_steps = CPU_THREADS = -1

def setup_rollouts(n_envs) -> None:
    """Create the module-level rollout buffer for ``n_envs`` environments.

    A temporary ``BattlesnakeEnv`` is built only to read its observation
    shape and action space, and is closed again before returning. The closed
    instance stays bound to the global ``tmp_env`` because ``setup_agent``
    reads the same two attributes from it.

    Args:
        n_envs: Number of parallel environments the storage is sized for.
            This parameter shadows the module-level ``n_envs``.

    Side effects:
        Rebinds the globals ``rollouts`` and ``tmp_env``; reads the global
        ``n_steps``.
    """
    global rollouts, tmp_env

    print('Setting up rollouts...')
    tmp_env = BattlesnakeEnv(n_threads=2, n_envs=n_envs)

    try:
        # rollout storage: game turns played and rewards
        # NOTE(review): n_steps is passed twice; presumably the last slot is the
        # recurrent hidden-state size — confirm against RolloutStorage.
        rollouts = RolloutStorage(n_steps,
                                  n_envs,
                                  tmp_env.observation_space.shape,
                                  tmp_env.action_space,
                                  n_steps)
    finally:
        # Always release the temporary environment, even if building the
        # storage fails (e.g. n_steps is still the -1 placeholder).
        tmp_env.close()

# Placeholders rebound by setup_agent(): the policy being trained, the frozen
# copy it is evaluated against, and the PPO agent.
policy = best_old_policy = agent = None

def setup_agent(model_path=None,
                value_loss_coef=0.5,
                entropy_coef=0.01,
                max_grad_norm=0.5,
                clip_param=0.2,
                ppo_epoch=4,
                num_mini_batch=16,
                eps=1e-5,
                lr=5e-5,
                use_clipped_value_loss=True,
                use_gradient_accumulation=False,
                gradient_accumulation_steps=1
                ) -> None:
    """Create the trainable policy, its frozen "prior best" copy and the PPO agent.

    Must be called after ``setup_rollouts``, because the observation shape and
    action space are read from the global ``tmp_env``.

    Args:
        model_path: Optional path to a saved ``state_dict``. When given, the
            weights are loaded into the new policy; when ``None`` the policy
            keeps whatever ``create_policy`` initialised it with.
        value_loss_coef, entropy_coef, max_grad_norm, clip_param, ppo_epoch,
        num_mini_batch, eps, lr, use_clipped_value_loss,
        use_gradient_accumulation, gradient_accumulation_steps:
            Forwarded unchanged to the ``PPO`` constructor.

    Side effects:
        Rebinds the globals ``policy``, ``best_old_policy`` and ``agent``.
        ``best_old_policy`` starts as an exact weight copy of ``policy``.
    """
    global policy, best_old_policy, agent
    print('Setting up agent...')

    # sanity check (same style as setup_env): the spaces come from tmp_env
    assert tmp_env is not None, "call setup_rollouts() before setup_agent()"

    obs_shape = tmp_env.observation_space.shape
    action_space = tmp_env.action_space

    # policies
    policy = create_policy(obs_shape, action_space, SnakePolicyBase)
    # load latest model if found
    if model_path is not None:
        # map_location='cpu' lets a checkpoint written from a CUDA run be
        # restored on a machine without that GPU; load_state_dict copies the
        # values into the policy's existing parameters.
        # torch.load unpickles the file, so only load checkpoints you trust.
        state_dict = torch.load(model_path, map_location='cpu')
        policy.load_state_dict(state_dict)

    best_old_policy = create_policy(obs_shape, action_space, SnakePolicyBase)
    best_old_policy.load_state_dict(policy.state_dict())

    # TODO: second_best model

    # agent
    agent = PPO(policy,
                value_loss_coef=value_loss_coef,
                entropy_coef=entropy_coef,
                max_grad_norm=max_grad_norm,
                clip_param=clip_param,
                ppo_epoch=ppo_epoch,
                num_mini_batch=num_mini_batch,
                eps=eps,
                lr=lr,
                use_clipped_value_loss=use_clipped_value_loss,
                use_gradient_accumulation=use_gradient_accumulation,
                gradient_accumulation_steps=gradient_accumulation_steps,
                )

# Placeholders rebound by setup_env(): the training environment and the
# observation returned by its reset.
env = obs = None

def setup_env() -> None:
    """Build the training environment and seed the rollout buffer.

    Requires ``setup_agent`` to have run (asserted) and ``rollouts`` to exist.
    The environment is created with the current ``policy`` in every opponent
    slot, reset once, and the resulting observation is copied into slot 0 of
    ``rollouts.obs``. Finally both policies and the storage are moved to
    ``device``.

    Side effects:
        Rebinds the globals ``env`` and ``obs``.
    """
    global env, obs
    print('Setting up environment...')

    # sanity check: everything setup_agent() creates must already exist
    for required in (policy, best_old_policy, agent):
        assert required is not None

    # environment: the policy being trained plays every opponent slot
    opponents = [policy for _ in range(n_opponents)]
    env = BattlesnakeEnv(n_threads=CPU_THREADS,
                         n_envs=n_envs,
                         opponents=opponents,
                         device=device)
    # TODO: second_best model

    obs = env.reset()
    first_obs = torch.tensor(obs)
    rollouts.obs[0].copy_(first_obs)

    # send network/storage to gpu
    for movable in (policy, best_old_policy, rollouts):
        movable.to(device)

def count_parameters(model) -> int:
    """Return the total element count of ``model``'s trainable parameters.

    Only parameters whose ``requires_grad`` flag is truthy are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


In [None]:
# Index of the last completed iteration; -1 until train() has finished.
last_iteration = -1

# Per-iteration metric histories (each one is its own, independent list).
rewards, value_losses, action_losses, dist_entropies, lengths = [], [], [], [], []

def train(num_updates, start_iteration=0, test_every=5, required_winrate=0.3, force_save=-1,
          gamma=0.99, lambd=0.95, datafile=None):
    """Run ``num_updates`` iterations of rollout collection followed by a PPO update.

    Each iteration fills ``rollouts`` with ``n_steps`` environment steps,
    updates ``policy`` through ``agent.update``, appends the iteration's
    metrics to the module-level history lists and persists them via
    ``u.save_data``. Periodically the policy is evaluated against
    ``best_old_policy`` and checkpointed.

    Args:
        num_updates: Number of iterations to run.
        start_iteration: Iteration number assigned to the first update; used
            for logging and checkpoint names, and passed to ``u.load_data``.
        test_every: Evaluate against ``best_old_policy`` on every iteration
            ``j`` with ``j % test_every == 0`` and ``j != 0``.
        required_winrate: The policy replaces ``best_old_policy`` (and is
            saved) when its winrate is strictly greater than this value.
        force_save: When ``> 0``, save a temporary checkpoint on every
            iteration divisible by it, unless a best model was already saved
            in that iteration.
        gamma, lambd: Forwarded to ``rollouts.compute_returns``
            (presumably discount factor and GAE lambda — confirm).
        datafile: Forwarded to ``u.load_data`` and ``u.save_data``.

    Returns:
        Tuple ``(rewards, value_losses, lengths)`` — the module-level history
        lists, including whatever ``u.load_data`` returned at the start.

    Side effects:
        Rebinds the globals ``last_iteration``, ``rewards``, ``value_losses``,
        ``action_losses``, ``dist_entropies`` and ``lengths``; writes model
        checkpoints with ``torch.save``.
    """
    # action_losses and dist_entropies have to be declared global too:
    # otherwise the assignments below only create locals and the module-level
    # lists later handed to plot_graphs() stay empty.
    global last_iteration, rewards, value_losses, action_losses, dist_entropies, lengths
    print("Starting training.")

    data = u.load_data(datafile=datafile, start_iteration=start_iteration)
    rewards = data['rewards']
    value_losses = data['value_losses']
    action_losses = data['action_losses']
    dist_entropies = data['dist_entropies']
    lengths = data['lengths']

    # The timer starts once, so "Time taken" below is cumulative over iterations.
    start = time.time()
    for i in range(num_updates):
        j = start_iteration + i
        episode_rewards = []
        episode_lengths = []

        policy.eval()
        print(f"Iteration {j}: Generating rollouts")
        for step in tqdm(range(n_steps)):
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = policy.act(
                    rollouts.obs[step],
                    rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])
            obs, reward, done, infos = env.step(action.cpu().squeeze())
            obs = torch.tensor(obs)
            reward = torch.tensor(reward).unsqueeze(1)

            # Environments that finished an episode this step report it in `info`.
            for info in infos:
                if 'episode' in info:
                    episode_rewards.append(info['episode']['r'])
                    episode_lengths.append(info['episode']['l'])

            # mask is 0.0 where the episode ended on this step, 1.0 otherwise
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor([[0.0] if 'bad_transition' in info else [1.0] for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action, action_log_prob, value, reward, masks, bad_masks)

        # Value estimate of the last stored observation, used to bootstrap the returns.
        with torch.no_grad():
            next_value = policy.get_value(
                rollouts.obs[-1],
                rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]
            ).detach()

        policy.train()

        print("Training policy on rollouts...")
        # NOTE(review): the positional True/False are presumably use_gae and
        # use_proper_time_limits — confirm against RolloutStorage.compute_returns.
        rollouts.compute_returns(next_value, True, gamma, lambd, False)
        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()

        policy.eval()

        total_num_steps = (j + 1) * n_envs * n_steps
        end = time.time()

        # Rewards and lengths are appended together, so one emptiness check
        # covers both. NaN is recorded when no episode finished (the value
        # np.mean would produce, without its empty-slice warning).
        has_episodes = len(episode_lengths) > 0
        lengths.append(np.mean(episode_lengths) if has_episodes else np.nan)
        rewards.append(np.mean(episode_rewards) if has_episodes else np.nan)
        value_losses.append(value_loss)
        action_losses.append(action_loss)
        dist_entropies.append(dist_entropy)

        print(f"Time taken: {end - start:.2f} seconds")
        print(f"Completed {total_num_steps} steps")

        u.save_data(rewards=rewards,
                    value_losses=value_losses,
                    action_losses=action_losses,
                    dist_entropies=dist_entropies,
                    lengths=lengths,
                    iteration=j, datafile=datafile)

        saved_model = False
        if j % test_every == 0 and j != 0:
            print("\n")
            print("=" * 80)
            print("Iteration", j, "Results")
            # TODO: do in parallel
            # Check the performance of the current policy against the prior best
            winrate = check_performance(policy, best_old_policy, device=torch.device(device))
            print(f"Winrate vs prior best: {winrate*100:.2f}%")
            if has_episodes:
                print(f"Median Length: {np.median(episode_lengths)}")
                print(f"Max Length: {np.max(episode_lengths)}")
                print(f"Min Length: {np.min(episode_lengths)}")
            else:
                # np.max / np.min raise ValueError on an empty sequence.
                print("No episodes finished in this iteration; length statistics unavailable.")

            if winrate > required_winrate:
                # Report the threshold actually in use rather than a fixed 30%.
                print(f"Policy winrate is > {required_winrate * 100:g}%. Updating prior best model")
                best_old_policy.load_state_dict(policy.state_dict())
                print("Saving latest best model.")
                saved_model = True
                torch.save(best_old_policy.state_dict(), u.get_modelpath(iteration=j))
            else:
                print("Policy has not learned enough yet... keep training!")
            print("-" * 80)
        if not saved_model and (force_save > 0 and j % force_save == 0):
            print("Force saving latest model.")
            torch.save(policy.state_dict(), u.get_modelpath(custom='tmp_iter', iteration=j))

    print("Saving final model.")
    last_iteration = start_iteration + num_updates - 1
    # Pass the iteration by keyword, like the other get_modelpath calls above.
    torch.save(policy.state_dict(), u.get_modelpath(iteration=last_iteration))
    torch.save(policy.state_dict(), u.get_modelpath_latest())

    return rewards, value_losses, lengths

In [None]:
# Keep TF32 disabled for CUDA matmuls, i.e. multiply in full FP32 precision.
# Set this to True to allow TF32 instead.
torch.backends.cuda.matmul.allow_tf32 = False

# FORMULA: The total training set size per iteration is n_envs * n_steps
n_envs = 400     # parallel environments
n_steps = 600    # steps per environment to simulate
CPU_THREADS = 6  # passed as n_threads to the training BattlesnakeEnv in setup_env()

num_updates = 100 # iterations to run
test_every = 10 # iterations per test
required_winrate = 0.3  # winrate vs. the prior best that must be exceeded to promote the policy
force_save = 5 # iterations per force save
# TODO: flexible test_every

# Name handed to PathHelper.set_modelgroup below.
# NOTE(review): presumably selects where checkpoints/metrics are stored — confirm in utils.
MODEL_GROUP = 'test7'

u.set_modelgroup(MODEL_GROUP, read_tmp=True)



# =====================================
# Look up the most recent checkpoint; training resumes from `iteration + 1`.
latest_model_path, iteration = u.get_latest_model()
print(f"Latest model: {latest_model_path}")
print(f"Latest iteration: {iteration}")

In [None]:

# Order matters: setup_rollouts() creates `rollouts` and `tmp_env`,
# setup_agent() reads the spaces from `tmp_env`, and setup_env() asserts that
# the agent exists before it writes the first observation into `rollouts`.
setup_rollouts(n_envs)

# The values below repeat setup_agent()'s defaults except num_mini_batch (64 instead of 16).
setup_agent(model_path=latest_model_path,
            value_loss_coef=0.5,
            entropy_coef=0.01,
            max_grad_norm=0.5,
            clip_param=0.2,
            num_mini_batch=64,
            ppo_epoch=4,
            eps=1e-5,
            # lr=3e-4,
            lr=5e-5,
            use_clipped_value_loss=True,
            use_gradient_accumulation=False,
            gradient_accumulation_steps=1
            )
print(f"Trainable Parameters: {count_parameters(policy)}")


setup_env()

# Resume one iteration after the latest checkpoint; `datafile` keeps its default (None).
# NOTE(review): `iteration+1` assumes u.get_latest_model() returned an integer —
# confirm what it returns when no checkpoint exists yet.
rewards2, value_losses2, lengths2 = train(num_updates, start_iteration=iteration+1,
                                          test_every=test_every,
                                          required_winrate=required_winrate,
                                          force_save=force_save,
                                          gamma=0.99,
                                          lambd=0.95,
                                          )



# TODO: send notification when complete

In [None]:
# Plot the module-level metric histories (not the values returned by train()).
plot_graphs(rewards, value_losses, action_losses, dist_entropies, lengths)