# Setup

## Create Filesystem
This notebook is primarily meant to be executed in Colab as a computational backend. To run it on your own hardware with data, you need to set `data_dir` and `ALLOW_IO` yourself.

This notebook is viewable directly on Colab at [https://colab.research.google.com/github/rcharan/phutball/blob/rl/pytorch-implementation/model-training.ipynb](https://colab.research.google.com/github/rcharan/phutball/blob/rl/pytorch-implementation/model-training.ipynb) (Colab mirrors the GitHub copy). If it has moved branches, or you are looking at a past commit, see the [Google instructions](https://colab.research.google.com/github/googlecolab/colabtools/blob/master/notebooks/colab-github-demo.ipynb) on where to find this file.

The workflow is:
 - Data stored in (my personal/private) Google Drive
 - Utilities/library files (for importing) on github, edited on local hardware and pushed to github.
 - Notebook hosted on github, edited either in Colab or locally (depending on the relative value of having a GPU attached versus being able to use regular Jupyter keyboard shortcuts/a superior interface)

In [1]:
# Attempt Colab setup if on Colab
try:
  import google.colab
except:
  ALLOW_IO = False
else:
  # Mount Google Drive at data_dir
  #  (for data)
  from google.colab import drive
  from os.path import join
  ROOT = '/content/drive'
  DATA = 'My Drive/phutball'
  drive.mount(ROOT)
  ALLOW_IO = True
  data_dir = join(ROOT, DATA)
  !mkdir "{data_dir}"     # in case we haven't created it already   

  # Pull in code from github
  %cd /content
  github_repo = 'https://github.com/rcharan/phutball'
  !git clone -b rl {github_repo}
  %cd /content/phutball
  
  # Point python to code base
  import sys
  sys.path.append('/content/phutball/pytorch-implementation')

  # Updater for library functions changed on local hardware and pushed to github
  #  (circuitous, I know)
  def update_repo():
    !git pull

## Imports

In [32]:
%%capture

%load_ext autoreload
%autoreload 2

import os
import gc

# Codebase
from lib.model_v1          import TDConway
from lib.moves             import create_placement_getter, get_jumps
from lib.utilities         import config
from lib.testing_utilities import create_state, visualize_state, boards
from lib.timer             import Timer
from lib.moves             import END_LOC, COL, CHAIN

# Graphics for visualization
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
%matplotlib inline
plt.ioff()

# PyTorch
import torch
from torch.optim import Optimizer

# Tensorflow solely for the Progress Bar (totally worth it)
from tensorflow.keras.utils import Progbar as ProgressBar

## Device Management Utilities
Setup for GPU, CPU, or TPU (TPU support is incomplete and does not work well).

In [33]:
# Select the compute device and define memory helpers for it:
#  TPU (only when use_tpu is set), else a CUDA GPU if available, else CPU.
#  Every branch defines `device`, print_memory_usage, print_max_memory_usage
#  and garbage_collect.
use_tpu = False

if use_tpu:
  # Install PyTorch/XLA
  # NOTE(review): $VERSION is not set anywhere in this notebook; it must be
  #  defined (Python variable or environment variable) before enabling TPU.
  !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
  !python pytorch-xla-env-setup.py --version $VERSION
  import torch_xla
  import torch_xla.core.xla_model as xm
  
  # Set the device
  device = xm.xla_device()
  
  # Memory inspection (placeholders: no TPU implementation yet)
  def print_memory_usage():
    print('TPU memory inspection not implemented')
  def print_max_memory_usage():
    print('TPU memory inspection not implemented')
  def garbage_collect():
    gc.collect() # No TPU specific implementation yet
    
elif torch.cuda.is_available():
  # Set the device
  device = torch.device('cuda')
  
  # Echo GPU info (shell output is captured as a list of lines)
  gpu_info = !nvidia-smi
  gpu_info = '\n'.join(gpu_info)
  print(gpu_info)
  
  # Memory inspection and management (CUDA-specific helpers aliased to
  #  the same names the CPU branch uses)
  from lib.memory import (
    print_memory_usage_cuda     as print_memory_usage,
    print_max_memory_usage_cuda as print_max_memory_usage,
    garbage_collect_cuda        as garbage_collect
  )

else:
  # Set the device to CPU
  device = torch.device('cpu')
  
  # Echo RAM info (total physical memory reported by psutil)
  from psutil import virtual_memory
  from lib.memory import format_bytes
  ram = virtual_memory().total
  print(format_bytes(ram), 'available memory on CPU-based runtime')
  
  # Memory inspection and management
  from lib.memory import (
    print_memory_usage, 
    print_max_memory_usage,
    garbage_collect
  )

8.00GiB available memory on CPU-based runtime


# Configuration

In [None]:
# Limits used when the bot enumerates and scores candidate moves
BATCH_SIZE = 64    # chunk size for torch.split: candidate positions per model call
MAX_JUMPS  = 300   # limit handed to get_jumps (presumably caps jump sequences explored)

# PyTorch Tools
These tools will be moved to `.py` files in the `lib` folder once they are stable.

## Model Persistence

In [None]:
class ModelState:
  '''Bundle of a model, its optimizer and training bookkeeping, plus helpers
  to checkpoint it to (and restore it from) disk.

  NOTE(review): load(), load_head() and the dataset helpers still reference
  names from an earlier project (BERT, TargetEnsemble, target heads,
  train/valid tensors, 'penn/' paths), and open_safely is not defined or
  imported in the visible notebook -- confirm/adapt before relying on them.
  '''

  def __init__(self, model_name, model = None,
               target_ensemble = None, 
               optimizer = None, optimizer_class = torch.optim.Adam,
               epochs_done = 0, 
               best_val_loss = float('inf'),
               train_tensor = None, valid_tensor = None, valid_df = None,
               other_data = None):
    self.model_name      = model_name
    self.model           = model
    self.target_ensemble = target_ensemble
    self.epochs_done     = epochs_done
    self.optimizer       = optimizer
    self.optimizer_class = optimizer_class
    self.best_val_loss   = best_val_loss
    self.other_data      = other_data
    self.train_tensor    = train_tensor
    self.valid_tensor    = valid_tensor
    self.valid_df        = valid_df
    # Written by state_dict(); initialized so save() works before anything is recorded
    self.game_lengths    = []
    # Fallback model names searched (last first) by load_head()
    self.parent_models   = []
    
  def update(self, epochs_done, best_val_loss, other_data):
    '''Record training progress.'''
    self.epochs_done   = epochs_done
    self.best_val_loss = best_val_loss
    self.other_data    = other_data

  @property
  def fname(self):
    '''Checkpoint path for this model.'''
    return  f'penn/checkpoints/{self.model_name}.pt'

  def temp_fname(self, subdir, model_name, info):
    '''Path for an auxiliary checkpoint (e.g. a single head) of model_name.'''
    return f'penn/checkpoints/{subdir}/{model_name}/{info}.pt'

  @property
  def trainable_params(self):
    '''Names of the model parameters that currently require gradients.'''
    return [
      name 
      for name, param in self.model.named_parameters()
      if  param.requires_grad
    ]

  def state_dict(self):
    '''Collect everything save() writes to disk.'''
    return {
        'model'            : self.model.state_dict(),
        'games_done'       : self.epochs_done,
        # Also stored under the names inspect() and load() read back
        'epochs_done'      : self.epochs_done,
        'best_val_loss'    : self.best_val_loss,
        'optimizer'        : self.optimizer.state_dict(),
        'game_lengths'     : self.game_lengths,
        'dropout'          : self.model.dropout,
        'trainable_params' : self.trainable_params,
    }

  def save_head(self, target, epoch):
    '''Save the head for target.col_name, tagged with the epoch number.'''
    param = self.model.heads[target.col_name]
    fname = self.temp_fname('heads', self.model_name, f'{target.col_name}-{epoch}')
    with open_safely(fname, 'wb') as f:
      torch.save(param.state_dict(), f)

  def register_parent_model(self, names):
    '''Set the model names load_head() falls back to (searched last first).'''
    self.parent_models = names

  def load_head(self, target):
    '''Load the head for target at target.best_val_epoch.

    Looks for this model's own checkpoint first, then each registered parent
    model (popping from the end of the list). Prints a message and returns
    None if no checkpoint is found.
    '''
    param = self.model.heads[target.col_name]
    model_names = self.parent_models.copy()
    def get_fname(model_base):
      return self.temp_fname('heads', model_base,
                              f'{target.col_name}-{target.best_val_epoch}')
    fname = get_fname(self.model_name)
    while not os.path.exists(fname):
      try:
        model_name = model_names.pop()
      except IndexError:
        print(f'Unable to load {target.col_name} at epoch {target.best_val_epoch}')
        return
      fname = get_fname(model_name)

    with open_safely(fname, 'rb') as f:
      sd = torch.load(f)
    
    param.load_state_dict(sd)

  def save(self):
    '''Write state_dict() to self.fname.'''
    with open_safely(self.fname, 'wb') as f:
      torch.save(self.state_dict(), f)

  def load(self, initial_state = lambda _ : None):
    '''Restore model, optimizer and data from self.fname.

    If no checkpoint exists, returns initial_state(self.model_name) instead;
    otherwise returns self.
    '''
    path = self.fname
    if os.path.exists(path):
      print(f'Restoring from {path}')
    else:
      print(f'No saved model found, proceeding with a fresh initialization')
      return initial_state(self.model_name)

    with open_safely(self.fname, 'rb') as f:
      sd = torch.load(f, map_location = torch.device('cpu'))

    # NOTE(review): BERT and TargetEnsemble are not defined in this notebook,
    #  and state_dict() saves no 'targets' key; this presumably needs adapting
    #  to TDConway.
    self.target_ensemble = TargetEnsemble([])
    self.target_ensemble.load_state_dict(sd['targets'])

    self.model = BERT(sd['dropout'], self.target_ensemble)
    self.model.load_state_dict(sd['model'])
    self.model = self.model.to(device)

    # Restore which parameters were being trained
    for name, param in self.model.named_parameters():
      param.requires_grad_(name in sd['trainable_params'])

    trainable_params = [param for param in self.model.parameters() if param.requires_grad]

    self.optimizer = self.optimizer_class(trainable_params)
    self.optimizer.load_state_dict(sd['optimizer'])

    self.best_val_loss = sd['best_val_loss']
    self.epochs_done   = sd['epochs_done']
    self.train_tensor  = sd['train_tensor']
    self.valid_tensor  = sd['valid_tensor']
    self.other_data    = sd['other_data']
    self.valid_df      = sd['valid_df']

    self.train_tensor = [t.to(device) for t in self.train_tensor]
    self.valid_tensor = [t.to(device) for t in self.valid_tensor]

    print(f'{self.epochs_done} epochs run with'
          f' best validation loss {self.best_val_loss:.2f}')
    
    return self
    
  def inspect(self):
    '''Print the progress recorded in the checkpoint on disk, if any.'''
    if not os.path.exists(self.fname):
      print('No such model')
      return

    state_dict = torch.load(self.fname, map_location = torch.device('cpu'))
    print(f'''Epochs run: {state_dict['epochs_done']}\n'''
          f'''Valid loss: {state_dict['best_val_loss']:.3f}''')

  def size_on_disk(self):
    '''Human-readable size of the checkpoint file ('0' if it does not exist).'''
    # Imported here: the notebook only imports format_bytes on CPU runtimes
    from lib.memory import format_bytes
    if not os.path.exists(self.fname):
      return '0'
    else:
      return format_bytes(os.path.getsize(self.fname))
  
  @classmethod
  def _dataset(cls, tensors, batch_size, shuffle = True):
    '''Wrap a sequence of tensors in a batched DataLoader.'''
    from torch.utils.data import TensorDataset, DataLoader
    dataset = TensorDataset(*tensors)
    return DataLoader(dataset, batch_size, shuffle = shuffle)

  def train_dataset(self, batch_size):
    '''Shuffled batches over self.train_tensor.'''
    return self._dataset(self.train_tensor, batch_size, shuffle = True)

  def valid_dataset(self, batch_size):
    '''In-order batches over self.valid_tensor.'''
    return self._dataset(self.valid_tensor, batch_size, shuffle = False)

## Optimizer

In [None]:
class AlternatingTDLambda(Optimizer):
  '''Implements tracing and updates for the TD(λ) algorithm.
  
  For details see Sutton and Barto, Reinforcement Learning 2ed,
  Chapter 12 Section 2.
  
  Modifications:
    (1) the algorithm is modified to compute traces
        one step early (when the graph is available). Hence, trace must be
        initialized with update_trace before calling step.
        
    (2) Because the board switches sides each turn, eligibility trace updates
        must be of the form z_t+1 <-  -λz_t + ∇v
  
  You may find the following greeks a helpful reference:
  alpha : Learning Rate
  lambda: Exponential Decay parameter for the eligibility trace
          (which is essentially momentum but with a different
           interpretation and is occasionally zeroed out).
  delta : Temporal difference (TD) i.e. the difference between the estimated
          value of a step and the realized value upon the best move (based
          on further estimation of course).
  '''
  def __init__(self, parameters, alpha, lamda):
    self.alpha = alpha
    self.lamda = lamda # Note alternate spelling: `lambda` is a Python keyword
    defaults   = dict(alpha = alpha, lamda = lamda)

    super().__init__(parameters, defaults)
    
  @torch.no_grad()
  def step(self, delta, update_trace = True):
    '''Performs a single optimization step, updating the trace *afterwards*.
    
    update_trace must be called before the first step to initialize the trace.
    
    Arguments
    ---------
    
    delta: Difference between estimated value of move at time t and realized
           value after moving and going to time t+1. A Python number or a
           single-element tensor (converted with float()).

    update_trace: if True (default), refresh the traces from the current
           gradients after the parameters are updated.

    Raises
    ------
    RuntimeError: if a parameter with a gradient has no initialized trace.
    '''
    # Use a plain scalar so a tensor delta of shape (1, 1) cannot broadcast
    #  into a larger shape during the in-place update of e.g. 1-D biases.
    delta = float(delta)

    for group in self.param_groups:
      alpha = group['alpha']
      
      for p in group['params']:
        if p.grad is None:
          continue
        
        state = self.state[p]
        if 'trace' not in state:
          raise RuntimeError('Traces must be initialized before calling step')
        
        # Note gradient *ascent* in reinforcement learning
        p.add_(state['trace'], alpha = alpha * delta)
    
    if update_trace:
      self.update_trace()

  @torch.no_grad()
  def update_trace(self):
    '''Updates the trace based on the gradients: z <- -λ z + grad.
    
    This also is the only way to initialize the traces (to zero, before
    adding the gradient). It must be called after evaluating the starting
    position and backprop at the beginning of the game.
    '''
    for group in self.param_groups:
      lamda = group['lamda']
      
      for p in group['params']:
        if p.grad is None:
          continue
        
        state = self.state[p]
        
        # Initialize to zero if necessary
        if 'trace' not in state:
          state['trace'] = torch.zeros_like(p)
          
        state['trace'].mul_(-lamda).add_(p.grad)
      
  @torch.no_grad()
  def zero_trace(self):
    '''Reset every initialized eligibility trace to zero.

    Parameters that have no trace yet are left untouched.
    '''
    for group in self.param_groups:
      for p in group['params']:
        # Key by the parameter itself; .get avoids creating empty entries
        state = self.state.get(p)
        if state is None or 'trace' not in state:
          continue
        state['trace'].zero_()
      

## Training Loop

In [None]:
def training_loop(optimizer, num_games, off_policy = lambda _ : None):
  '''Train by self-play for `num_games` games, each from the 'H10' start.

  Each game is played by game_loop with the notebook-level `model`.
  '''
  start_position = create_state('H10').to(device)
  for game_number in range(1, num_games + 1):
    print(f'\nPlaying game {game_number} of {num_games}:')
    game_loop(start_position, model, optimizer, off_policy)

In [None]:
def restart(optimizer, score):
  '''(Re)start TD(λ) tracing from a freshly evaluated position.

  Zeros the eligibility traces, backpropagates `score` so that its gradient
  seeds the new trace, then clears the gradients.

  Arguments
  ---------
  optimizer: object providing zero_trace, update_trace and zero_grad
             (e.g. AlternatingTDLambda)
  score    : single-element tensor, the model's value for the position

  Returns
  -------
  The value of `score` as a Python float.
  '''
  optimizer.zero_trace()
  score.backward()
  optimizer.update_trace()
  optimizer.zero_grad()
  return score.item()

In [None]:
def game_loop(initial_state, model, optimizer, off_policy):
  '''Training loop that plays one game, applying TD(λ) updates as it goes.

  Arguments
  ---------
  initial_state: starting board tensor (unbatched; a batch dimension is
                 added before calling model)
  model        : value network; called here on the starting position only.
                 NOTE: moves are scored inside get_next_move_training, which
                 uses the notebook-level `model`.
  optimizer    : AlternatingTDLambda-style optimizer providing step,
                 update_trace, zero_trace and zero_grad
  off_policy   : callable forwarded to get_next_move_training
  '''
  # Just in case
  optimizer.zero_grad()
  
  # Initialization: the starting position's gradient seeds the trace
  state    = initial_state
  score, _ = model(state.unsqueeze(0))
  v_t      = restart(optimizer, score)
  
  # Progress Bar (initial target is a guess at game length, grown as needed)
  bar      = ProgressBar(284)
  move_num = 1
  
  while True:
    # Determine the next move from the *current* position
    game_over, moved_off_policy, new_state, score = \
      get_next_move_training(state, off_policy = off_policy)
    
    if game_over:
      # The bot wins: realized value is 1
      delta = 1 - v_t
      optimizer.step(delta, update_trace = False)
      
      # Terminate the progress bar
      bar.target = move_num
      bar.update(move_num)

      break
    
    elif moved_off_policy:
      # Equivalent to starting a new game from the new position
      v_t = restart(optimizer, score)
      
    else:
      score.backward()
      # score is from the opponent's view, so our realized value is 1 - score.
      #  Use a float so delta carries no autograd graph.
      v_next = score.item()
      delta  = (1 - v_next) - v_t
      optimizer.step(delta)
      optimizer.zero_grad()
      
      v_t = v_next

    # Advance the game (both on- and off-policy moves change the board)
    state = new_state
      
    # Progress bar: extend the target when we get close to it
    move_num += 1
    if move_num >= bar.target * 0.9:
      bar.target += bar.target // 10 + 1
    
    bar.update(move_num)
    


In [None]:
def get_next_move_training(curr_state, off_policy = lambda _ : None):
  '''Get the next move for the bot
  
  Gets the next move for the bot with computations and
  return value suitable for training only
  (i.e. gradients are taken)
  
  If off_policy returns a move index, that move is
  selected instead.
  
  Gradients with respect to the value-function applied
  at the chosen move are accumulated and available to the caller
  
  Inputs
  ------
  
  curr_state: binary tensor of shape (channels, rows, cols)
              representing the game state
              
  off_policy: callable with signature
              off_policy(num_available_moves: int) returning
              either None or the index of the move desired.
              It is called at most once per move. If the return
              value is not None AND the bot cannot otherwise win
              on that move, then that move is made.
              
  Outputs
  -------
  game_over  : boolean. Whether the bot can (and does) win
             on this move. The bot *always* plays a
             win-in-one move when it is available, regardless
             of the off-policy argument.
             
  off_policy : boolean. Whether an off-policy move was made.
             OR: value is None if game is over
  
  new_state  : a binary tensor of same shape as curr_state
             representing the new state of the game after
             the bot moves AND the board is flipped around
             to present it from opponents view. OR: value
             is None if game is over.
               
  value      : value of the value-function applied to new_state.
             OR: value is None if the game is over.
  '''
  # Compute the placements
  placements = get_placements(curr_state)

  # Compute the jumps
  jumps = get_jumps(curr_state, MAX_JUMPS)

  # Deal with special cases/win condition for the jump
  
  # No jumps to worry about
  if len(jumps) == 0:
    moves = placements
  
  # Win condition
  # NOTE(review): only checked when exactly one jump is returned; presumably
  #  get_jumps returns a single jump once it finds a winning one -- confirm.
  elif (
    len(jumps) == 1 and
    jumps[0][CHAIN][END_LOC][COL] in [config.cols, config.cols-1]):
    return True, None, None, None # The game is over!
  
  # Regular jump evaluation
  else:
    # Retain only the final state
    jumps = [jump_data[0] for jump_data in jumps]
    # Build on the placements' device so torch.cat works on GPU too.
    #  (assumes jump_data[0] is nested-list-like -- TODO confirm in lib.moves)
    jumps = torch.tensor(jumps, dtype = torch.bool, device = placements.device)
    moves = torch.cat([placements, jumps])
    
  # Turn the board around to represent the opponent's view
  moves = torch.flip(moves, [-1])
  
  # Either make an off policy move, or evaluate the value-function
  #  to determine the policy
  off_policy_index = off_policy(moves.shape[0])
  if off_policy_index is not None:
    move     = moves[off_policy_index]
    score, _ = model(move.unsqueeze(0))
    # Return the unbatched state, matching the on-policy branch
    return False, True, move, score
  
  # Batch the moves
  batches = torch.split(moves, BATCH_SIZE)
  
  # We only need to differentiate the best score;
  #  track which one that is.
  best_score = None
  best_index = None
  curr_index = 0

  for batch in batches:
    # Run the model
    score, index = model(batch)

    # Update running tally of best score. Rebinding best_score drops the
    #  reference to the previous best (and its graph).
    if best_score is None or score > best_score:
      best_score = score
      best_index = curr_index + index

    # Keep track of how many indices we've traversed to 
    #  get best_index correct
    curr_index += batch.shape[0]
    
    gc.collect()
    
  # Return
  return False, False, moves[best_index], best_score

# Execution

In [None]:
# Build the placement generator for the selected device; get_next_move_training
#  calls get_placements(curr_state) to compute the placement moves.
get_placements = create_placement_getter(device)

In [None]:
# Instantiate the value network from the shared config and move it to `device`
model = TDConway(config)
model = model.to(device)

In [None]:
# Starting position ('H10', the same one training_loop uses), on `device`
initial_state = create_state('H10').to(device)
# Render it (visualize_state is given a CPU copy)
visualize_state(initial_state.cpu())

In [None]:
# Keep the starting position under the name `curr_state` (the parameter
#  name get_next_move_training uses) for interactive experimentation
curr_state = initial_state

In [None]:
# Smoke test: compute one bot move from the starting position. Returns
#  (game_over, moved_off_policy, new_state, score), with gradients tracked.
get_next_move_training(initial_state)