In [None]:
# Enable autoreload (mode 2) so edits to imported modules are picked up live
# without restarting the kernel
%load_ext autoreload
%autoreload 2
# Environment variables for this kernel: restrict CUDA to device index 0
%env CUDA_VISIBLE_DEVICES=0
# WANDB_* variables are presumably read by the `wandb` client (notebook name,
# silent mode) -- confirm against the wandb documentation
%env WANDB_NOTEBOOK_NAME train.ipynb
%env WANDB_SILENT true

from collections import defaultdict
import os

import numpy as np
import pandas as pd
import torch
import wandb

import data
import inept

# Compute device: use the GPU whenever torch reports CUDA as available
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Folder layout, anchored at the absolute path of the current working directory
BASE_FOLDER = os.path.abspath('')
DATA_FOLDER, MODEL_FOLDER = (
    os.path.join(BASE_FOLDER, subfolder)
    for subfolder in ('../data/', 'models/')
)

- VERIFY
  - Check that rewards are normalized after (?) the advantage is computed

- HIGH PRIORITY
  - Add parallel envs of different sizes
  - Fix reconstruction speed of memories; this should give a 10x training speedup (slow reconstruction is likely the cause of low GPU utilization)

- LOW PRIORITY
  - Add parallel envs of different sizes
  - Make backward (MAX_NODES, MAX_BATCH) batching work
  - Add multithreading to forward and distributed to backward

- LINKS
  - [Original paper (pg 24)](https://arxiv.org/pdf/1909.07528.pdf)
  - [Original blog](https://openai.com/research/emergent-tool-use)
  - [Gym](https://gymnasium.farama.org/)
  - [Slides](https://glouppe.github.io/info8004-advanced-machine-learning/pdf/pleroy-hide-and-seek.pdf)
  - [PPO implementation](https://github.com/nikhilbarhate99/PPO-PyTorch/blob/master/PPO.py#L38)
  - [Residual SA](https://github.com/openai/multi-agent-emergence-environments/blob/bafaf1e11e6398624116761f91ae7c93b136f395/ma_policy/layers.py#L89)

### Parameters

In [2]:
# --- Notebook-level settings ---
note_kwargs = dict(seed=42)

# --- Data ---
data_kwargs = dict(
    dataset='MouseVisual',
    standardize=True,
    # pca_dim=[min(512, *M.shape) for M in modalities],
    num_nodes=10,
)

# --- Environment ---
env_kwargs = dict(
    # Spatial dimensions: 2 -> (x, y, vx, vy), 3 -> (x, y, z, vx, vy, vz)
    dim=3,
    # None tries to emulate all modalities; an int or a list of ints only
    # considers those modalities as targets
    reward_distance_target=1,
)

# --- Environment reward weights, one dict per training stage ---
# Stage order: boundary, +origin, +vel+act, -origin+dist
stages_kwargs = dict(
    env=(
        dict(penalty_bound=1),
        dict(reward_origin=1),
        dict(penalty_velocity=1, penalty_action=1),
        dict(reward_origin=0, reward_distance=1),
    ),
)

# --- Policy ---
policy_kwargs = dict(
    # Main arguments
    num_features_per_node=2 * env_kwargs['dim'],
    # modal_sizes=None,  # Determined in the running script
    output_dim=env_kwargs['dim'],
    # Backpropagation
    update_maxbatch=None,  # Total memory to sample from during backprop
    update_batch=10_000,  # Memory to sample from during each backprop epoch
    update_minibatch=10_000,  # Max memories to backprop at a time
    update_load_level='minibatch',  # What stage to reconstruct memories from compressed form
    update_cast_level='minibatch',  # What stage to cast to GPU memory
    # Internal arguments
    embed_dim=64,
    feature_embed_dim=32,
    # Training arguments
    action_std_init=0.6,
    action_std_min=0.1,
)

# --- Training ---
train_kwargs = dict(
    max_ep_timesteps=1e3,
    max_timesteps=5e6,
    update_timesteps=5e3,
    max_batch=None,  # Max number of nodes to calculate actions for at a time
    max_nodes=None,  # Max number of nodes to use as neighbors in action calculation
    episode_random_samples=True,  # Random nodes each epoch
    use_wandb=True,  # Record performance to wandb
)

# --- Early stopping ---
# Number of full-length episodes that make up one training (update) cycle
_episodes_per_cycle = int(train_kwargs['update_timesteps'] / train_kwargs['max_ep_timesteps'])
es_kwargs = dict(
    # Global parameters
    buffer=6 * _episodes_per_cycle,  # 6 training cycles
    # `average` method parameters
    window_size=3 * _episodes_per_cycle,  # 3 training cycles
)

### Load Data

In [3]:
# Reproducibility: seed the torch (CPU, plus CUDA when present) and numpy generators
torch.manual_seed(note_kwargs['seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(note_kwargs['seed'])
np.random.seed(note_kwargs['seed'])

# Read the raw dataset, then fit the preprocessing pipeline on it and transform
modalities, types, features = data.load_data(
    data_kwargs['dataset'],
    DATA_FOLDER,
)
ppc = inept.utilities.Preprocessing(
    **data_kwargs,
    device=DEVICE,
)
processed_modalities = ppc.fit_transform(modalities)

# With per-episode resampling disabled, draw the single fixed subsample once here
if not train_kwargs['episode_random_samples']:
    modalities, keys = ppc.subsample(
        processed_modalities,
        return_idx=True,
    )
    modalities = ppc.cast(modalities)

### Train Policy

In [None]:
# Monitoring and profiling tips
#   - CUDA memory usage:   `watch -d -n 0.5 nvidia-smi`
#   - System memory usage: `top`
#   - Memory profiling: run as a script and put the following above the function
#        from memory_profiler import profile
#        @profile
#   - Timing with cProfile:
#        python -m cProfile -s time -o profile.prof train.py
#        snakeviz profile.prof

# Build the environment using the reward weights of the first stage
env = inept.environments.trajectory(
    *modalities,
    **env_kwargs,
    **stages_kwargs['env'][0],
    device=DEVICE,
)

# The policy is told the size along axis 1 of each modality the environment returns
policy_kwargs['modal_sizes'] = [
    modality.shape[1] for modality in env.get_return_modalities()
]
policy = inept.models.PPO(**policy_kwargs, device=DEVICE)
policy = policy.train()
early_stopping = inept.utilities.EarlyStopping(**es_kwargs)

# Start the wandb run, flattening every kwargs group into `group/name` config keys
if train_kwargs['use_wandb']:
    wandb.init(
        project='INEPT',
        config={
            f'{group}/{name}': value
            for group, group_kwargs in (
                ('note', note_kwargs),
                ('data', data_kwargs),
                ('env', env_kwargs),
                ('stages', stages_kwargs),
                ('policy', policy_kwargs),
                ('train', train_kwargs),
                ('es', es_kwargs),
            )
            for name, value in group_kwargs.items()
        },
    )

# Initialize logging vars
# Only reset the CUDA peak-memory counters when CUDA is usable: DEVICE may be
# 'cpu', and the reset call raises on a machine without a CUDA device
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
# Timer used to attribute wall time to the named phases logged below
timer = inept.utilities.time_logger(discard_first_sample=True)
# Global step counter, 1-based episode counter, and index into `stages_kwargs['env']`
timestep = 0; episode = 1; stage = 0

# CLI: summarize the forward-pass and backprop configuration before training
print('Beginning training')
# Neighbor nodes per action calculation: `max_nodes`, capped at the sampled node count
num_train_nodes = (
    data_kwargs['num_nodes']
    if train_kwargs['max_nodes'] is None else
    min(data_kwargs['num_nodes'], train_kwargs['max_nodes'])
)
# Forward batch size: `max_batch`, capped at the sampled node count.
# (Must read `max_batch`, not `max_nodes`: otherwise the printed size is wrong
# and `min(int, None)` raises when only `max_batch` is set.)
num_train_batch = (
    data_kwargs['num_nodes']
    if train_kwargs['max_batch'] is None else
    min(data_kwargs['num_nodes'], train_kwargs['max_batch'])
)
print(
    f'Training using {num_train_nodes} nodes out of a'
    f' total {data_kwargs["num_nodes"]} with forward batches of'
    f' size {num_train_batch}.'
)
# `update_maxbatch=None` means no cap, reported as 'all'
update_maxbatch_print = (
    policy_kwargs["update_maxbatch"]
    if policy_kwargs["update_maxbatch"] is not None else 
    'all'
)
print(
    f'Training on {update_maxbatch_print} states'
    f' with batches of size {policy_kwargs["update_batch"]}'
    f' and minibatches of size {policy_kwargs["update_minibatch"]}'
    f' from {int(train_kwargs["update_timesteps"] * data_kwargs["num_nodes"])} total.')

# Simulation loop: one outer iteration per episode, until the global timestep
# budget (`max_timesteps`) is spent or training ends early
while timestep < train_kwargs['max_timesteps']:
    # Sample new data (fresh random nodes every episode when enabled)
    if train_kwargs['episode_random_samples']:
        modalities, keys = ppc.subsample(processed_modalities, return_idx=True)
        modalities = ppc.cast(modalities)
        env.set_modalities(modalities)

    # Reset environment
    env.reset()
    timer.log('Reset Environment')

    # Start episode: per-episode step counter and reward accumulators
    ep_timestep = 0; ep_reward = 0; ep_itemized_reward = defaultdict(lambda: 0)
    while ep_timestep < train_kwargs['max_ep_timesteps']:
        # Rollouts need no gradients; the policy update happens in `policy.update()`
        with torch.no_grad():
            # Get current state
            state = env.get_state(include_modalities=True)
            timer.log('Environment Setup')

            # Get actions from policy
            actions = policy.act_macro(
                state,
                keys=keys,
                max_batch=train_kwargs['max_batch'],
                max_nodes=train_kwargs['max_nodes'],
            ).detach()
            timer.log('Calculate Actions')

            # Step environment and get reward
            rewards, finished, itemized_rewards = env.step(actions, return_itemized_rewards=True)
            # Also mark the last allowed step of the episode as terminal
            finished = finished or (ep_timestep == train_kwargs['max_ep_timesteps']-1)  # Maybe move logic inside env?
            timer.log('Step Environment')

            # Record rewards for policy
            policy.memory.record(
                rewards=rewards.cpu().tolist(),
                is_terminals=finished,
            )

            # Record rewards for logging (mean over the `rewards` tensor, summed over steps)
            ep_reward = ep_reward + rewards.cpu().mean()
            for k, v in itemized_rewards.items():
                ep_itemized_reward[k] += v.cpu().mean()
            timer.log('Record Rewards')

        # Iterate
        timestep += 1
        ep_timestep += 1

        # Update model every `update_timesteps` global steps
        if timestep % train_kwargs['update_timesteps'] == 0:
            print(f'Updating model with average reward {np.mean(policy.memory.storage["rewards"])} on episode {episode} and timestep {timestep}', end='')
            policy.update()
            print(f' ({torch.cuda.max_memory_allocated() / 1024**3:.2f} GB CUDA)')
            # Start a fresh peak-memory window for the next update; skipped on
            # CPU-only machines, where the reset call raises
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()
            timer.log('Update Policy')

        # Escape if finished
        if finished: break

    # Upload stats: per-step averages over the episode just completed
    ep_reward = (ep_reward / ep_timestep).item()
    update = int(timestep / train_kwargs['update_timesteps'])  # Update cycles completed so far
    if train_kwargs['use_wandb']:
        wandb.log({
            **{
            # Measurements
            'end_timestep': timestep,
            'episode': episode,
            'update': update,
            'stage': stage,
            # Parameters
            'action_std': policy.action_std,
            # Outputs
            'average_reward': ep_reward,
            },
            **{'rewards/'+k: (v / ep_timestep).item() for k, v in ep_itemized_reward.items()},
        })
    timer.log('Record Stats')

    # Checkpoint, then advance the stage or decay the action std, whenever early
    # stopping triggers on the episode reward or the timestep budget is exhausted
    if early_stopping(ep_reward) or timestep >= train_kwargs['max_timesteps']:
        # Save model; create the output folder first so a checkpoint is never
        # lost to a missing directory
        os.makedirs(MODEL_FOLDER, exist_ok=True)
        wgt_file = os.path.join(MODEL_FOLDER, f'policy_{stage:02}.wgt')
        torch.save(policy.state_dict(), wgt_file)  # Save just weights
        if train_kwargs['use_wandb']: wandb.save(wgt_file)
        mdl_file = os.path.join(MODEL_FOLDER, f'policy_{stage:02}.mdl')
        torch.save(policy, mdl_file)  # Save whole model
        if train_kwargs['use_wandb']: wandb.save(mdl_file)

        # End if maximum timesteps reached; stop here rather than announcing
        # and configuring a next stage that will never run
        if timestep >= train_kwargs['max_timesteps']:
            print('Maximal timesteps reached')
            break

        # End if at minimum `action_std`
        if policy.action_std <= policy.action_std_min:
            print(f'Ending early on episode {episode} and timestep {timestep}')
            break

        # Activate next stage or decay
        stage += 1
        # CLI
        print(f'Advancing training to stage {stage}')
        if stage < len(stages_kwargs['env']):
            # Activate next stage
            env.set_rewards(stages_kwargs['env'][stage])
        else:
            # All reward stages used: decay policy randomness instead
            policy.decay_action_std()
            # CLI
            print(f'Decaying std to {policy.action_std} on episode {episode} and timestep {timestep}')

        # Reset early stopping
        early_stopping.reset()
    timer.log('Early Stopping')

    # Iterate
    episode += 1

# Print a blank separator line, then aggregate the timer logs with the 'sum' method
print()
timer.aggregate('sum')

# Close the wandb run if one was opened
if train_kwargs['use_wandb']:
    wandb.finish()