# [Optional] Generate expert trajectories

**Note: This repository already includes the expert dataset, so you don't need to run this script.** (This script is provided for customization.)

Before running imitation learning algorithms, we need expert trajectories, which are used to learn behaviors or to recover rewards.<br>
In this example, we obtain an optimal policy with a reinforcement learning method (here, the PPO algorithm) and then generate expert trajectories with this trained agent.

Note that this notebook uses the reward function to obtain the optimal agent, but the reward function is never used for optimization in any of the imitation learning exercises in this repository.<br>
That's why imitation learning methods matter.

See [here](https://github.com/tsmatz/reinforcement-learning-tutorials) for the theoretical background behind reinforcement learning algorithms. (An explanation of RL is out of scope for this repository.)

*(back to [index](https://github.com/tsmatz/imitation-learning-tutorials/))*

Before we start, we need to install the required packages.

In [1]:
!pip install torch numpy matplotlib



## 1. Define GridWorld environment

First, we define the GridWorld environment used in all exercises and save it as a Python script file, ```gridworld.py```.<br>
For details about this environment, see [Readme.md](https://github.com/tsmatz/imitation-learning-tutorials/blob/master/Readme.md).

**Note: This repository already has ```gridworld.py```, so you don't need to run this cell.**

In [2]:
%%writefile gridworld.py
import numpy as np
import torch
from torch.nn import functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GridWorld:
    """
    This environment is motivated by the following paper.
    https://proceedings.mlr.press/v15/boularias11a/boularias11a.pdf

    - It has 50 x 50 grids (cells).
    - The agent has four actions for moving in one of the directions of the compass.
    - [Optional] If ```transition_prob``` = True, the actions succeed with probability 0.7,
      a failure results in a uniform random transition to one of the adjacent states.
    - A reward of 10 is given for reaching the goal state, located on the bottom-right corner.
    - For the remaining states,
      the reward function was randomly set to 0 with probability 2/3
      and to −1 with probability 1/3.
    - If the agent moves across the border, it's given the fail reward (i.e., reward=`-1`).
    - The initial state is sampled from a uniform distribution.
    """


    def __init__(self, reward_map=None, valid_states=None, seed=None, transition_prob=False, max_timestep=200, device="cuda"):
        """
        Initialize class.

        Parameters
        ----------
        reward_map : float[grid_size * grid_size]
            Reward for each state.
            Set this value when you load existing world definition (when seed=None).
        valid_states : list(int[2])
            List of states from which the agent can reach the goal state without losing any rewards.
            Each state is a 2d vector, [row, column].
            When you call reset(), the initial state is picked up from these states.
            Set this value when you load existing world definition (when seed=None).
        seed : int
            Seed value to generate new grid (maze).
            Set this value when you create a new world.
            (Above ```reward_map``` and ```valid_states``` are newly generated.)
        transition_prob : bool
            True if the transition probability (above) is enabled.
            False when we generate an expert agent without noise.
        max_timestep : int
            The maximum number of time-step (horizon).
            When it doesn't have finite horizon, set None as max_timestep.
            (If max_timestep=None, it doesn't return trunc flag in step() function.)
        device : string
            Device info ("cuda", "cpu", etc).
        """

        self.device = device
        self.transition_prob = transition_prob
        self.grid_size = 50
        self.action_size = 4
        self.max_timestep = max_timestep
        self.goal_reward = 10

        if seed is None:
            ############################
            ### Load from definition ###
            ############################
            self.reward_map = torch.tensor(reward_map).to(self.device)
            self.valid_states = torch.tensor(valid_states).to(self.device)
        else:
            ################################
            ### Generate a new GridWorld ###
            ################################
            # generate grid
            self.reward_map = torch.zeros(self.grid_size * self.grid_size, dtype=torch.int).to(self.device)
            # bottom-right is goal state
            self.reward_map[-1] = self.goal_reward
            # set reward=−1 with probability 1/3
            sample_n = np.floor((self.grid_size * self.grid_size - 1) / 3).astype(int)
            rng = np.random.default_rng(seed)
            sample_loc = rng.choice(self.grid_size * self.grid_size - 1, size=sample_n, replace=False)
            sample_loc = torch.from_numpy(sample_loc).to(self.device)
            self.reward_map[sample_loc] = -1
            # seek valid states
            valid_states_list = self._greedy_seek_valid_states([self.grid_size-1, self.grid_size-1], [])
            valid_states_list.remove([self.grid_size-1, self.grid_size-1])
            self.valid_states = torch.tensor(valid_states_list).to(self.device)

    def _greedy_seek_valid_states(self, state, old_state_list):
        """
        This method recursively seeks valid states.
        e.g., if a state is surrounded by states with reward=-1,
        this state is invalid, because the agent cannot reach the goal state
        without losing rewards.

        Parameters
        ----------
        state : int[2]
            State to start seeking. It then seeks this state and all its neighboring states recursively.
            This state must be the list of [row, column].
        old_state_list : int[N, 2]
            List of states already checked.
            Each state must be the list of [row, column].
            These items are then skipped for seeking.

        Returns
        ----------
        valid_states : int[N, 2]
            List of new valid states.
            Each state must be the list of [row, column].
        """
        # build new list
        new_state_list = []
        # if the state is already included in the list, do nothing
        if state in old_state_list:
            return new_state_list
        # if the state has reward=-1, do nothing
        if self.reward_map[state[0]*self.grid_size+state[1]] == -1:
            return new_state_list
        # else add the state into the list
        new_state_list.append(state)
        # move up
        if state[0] > 0:
            next_state = list(map(lambda i, j: i + j, state, [-1, 0]))
            new_state_list += self._greedy_seek_valid_states(
                next_state,
                old_state_list + new_state_list)
        # move down
        if state[0] < self.grid_size - 1:
            next_state = list(map(lambda i, j: i + j, state, [1, 0]))
            new_state_list += self._greedy_seek_valid_states(
                next_state,
                old_state_list + new_state_list)
        # move left
        if state[1] > 0:
            next_state = list(map(lambda i, j: i + j, state, [0, -1]))
            new_state_list += self._greedy_seek_valid_states(
                next_state,
                old_state_list + new_state_list)
        # move right
        if state[1] < self.grid_size - 1:
            next_state = list(map(lambda i, j: i + j, state, [0, 1]))
            new_state_list += self._greedy_seek_valid_states(
                next_state,
                old_state_list + new_state_list)
        # return result
        return new_state_list

    def reset(self, batch_size):
        """
        Randomly pick initial states (one per batch item) from the valid states.

        Parameters
        ----------
        batch_size : int
            The number of returned states.

        Returns
        ----------
        state : torch.tensor((batch_size), dtype=int)
            Return the picked-up state id.
        """
        # initialize step count
        self.step_count = 0
        # pick up sample of valid states
        indices = torch.multinomial(torch.ones(len(self.valid_states)).to(self.device), batch_size, replacement=True)
        state_2d = self.valid_states[indices]
        # convert 2d index to 1d index
        state_1d = state_2d[:,0] * self.grid_size + state_2d[:,1]
        # return result
        return state_1d

    def step(self, actions, states, trans_state_only=False, transition_prob=None):
        """
        Take action, proceed step, and return the result.

        Parameters
        ----------
        actions : torch.tensor((batch_size), dtype=int)
            Actions to take
            (0=UP 1=DOWN 2=LEFT 3=RIGHT)
        states : torch.tensor((batch_size), dtype=int)
            Current state id.
        trans_state_only : bool
            Set True to only get the next states (statelessly, without calling reset()).
            (If trans_state_only=True, step() returns only the next states.)
        transition_prob : bool
            Set this to override the default ```transition_prob``` property.
            (For this property, see __init__() above.)

        Returns
        ----------
        new-states : torch.tensor((batch_size), dtype=int)
            New state id.
        rewards : torch.tensor((batch_size), dtype=float)
            The obtained reward.
        term : torch.tensor((batch_size), dtype=bool)
            Flag to check whether it reaches to the goal and terminates.
        trunc : torch.tensor((batch_size), dtype=bool)
            Flag to check whether it's truncated by reaching to max time-step.
            (When max_timestep is None, this is not returned.)
        """
        # get batch size
        batch_size = actions.shape[0]
        # if transition prob is enabled, apply stochastic transition
        if transition_prob is None:
            trans_prob = self.transition_prob # set default
        else:
            trans_prob = transition_prob      # override
        if trans_prob:
            # the action succeeds with probability 0.7
            prob = torch.ones(batch_size, self.action_size).to(self.device)
            mask = F.one_hot(actions, num_classes=self.action_size).bool()
            prob = torch.where(mask, 7.0, prob)
            selected_actions = torch.multinomial(prob, 1, replacement=True)
            selected_actions = selected_actions.squeeze(dim=1)
            action_onehot = F.one_hot(selected_actions, num_classes=self.action_size)
        else:
            # deterministic (probability=1.0 in one state)
            action_onehot = F.one_hot(actions, num_classes=self.action_size)
        # get 2d state
        mod = torch.div(states, self.grid_size, rounding_mode="floor")
        reminder = torch.remainder(states, self.grid_size)
        state_2d = torch.cat((mod.unsqueeze(dim=-1), reminder.unsqueeze(dim=-1)), dim=-1)
        # move state
        # (0=UP 1=DOWN 2=LEFT 3=RIGHT)
        up_and_down = action_onehot[:,1] - action_onehot[:,0]
        left_and_right = action_onehot[:,3] - action_onehot[:,2]
        move = torch.cat((up_and_down.unsqueeze(dim=-1), left_and_right.unsqueeze(dim=-1)), dim=-1)
        new_states = state_2d + move
        # set reward
        if not(trans_state_only):
            rewards = torch.zeros(batch_size).to(self.device)
            rewards = torch.where(new_states[:,0] < 0, -1.0, rewards)
            rewards = torch.where(new_states[:,0] >= self.grid_size, -1.0, rewards)
            rewards = torch.where(new_states[:,1] < 0, -1.0, rewards)
            rewards = torch.where(new_states[:,1] >= self.grid_size, -1.0, rewards)
        # correct location
        new_states = torch.clip(new_states, min=0, max=self.grid_size-1)
        # if succeed, add reward of current state
        states_1d = new_states[:,0] * self.grid_size + new_states[:,1]
        if not(trans_state_only):
            rewards = torch.where(rewards>=0.0, rewards+self.reward_map[states_1d], rewards)
            self.step_count += 1
        # return result
        if trans_state_only:
            return states_1d
        elif self.max_timestep is None:
            return states_1d, rewards, rewards==self.reward_map[self.grid_size * self.grid_size - 1]
        else:
            return states_1d, rewards, rewards==self.reward_map[self.grid_size * self.grid_size - 1], torch.tensor(self.step_count==self.max_timestep).to(self.device).unsqueeze(dim=0).expand(batch_size)

Writing gridworld.py
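
`_greedy_seek_valid_states()` collects the cells that are connected to the goal through cells whose reward is not `-1`. Because it recurses once per visited cell, the recursion can grow deep on large grids (Python's default recursion limit is 1000). The following is a sketch of the same idea written as an iterative breadth-first flood fill, a swapped-in technique rather than the code used in `gridworld.py`. It assumes a plain Python list as the reward map and, like the original, returns `[row, column]` pairs excluding the goal (sorted here, whereas the recursive version returns them in depth-first order).

```python
from collections import deque


def find_valid_states(reward_map, grid_size):
    """Return [row, col] of every cell connected to the goal (bottom-right)
    through cells whose reward is not -1, excluding the goal itself.

    reward_map is a flat list indexed by row * grid_size + col.
    """
    goal = (grid_size - 1, grid_size - 1)
    seen = {goal}
    queue = deque([goal])
    while queue:
        row, col = queue.popleft()
        # same four moves as GridWorld: up, down, left, right
        for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            r, c = row + d_row, col + d_col
            if not (0 <= r < grid_size and 0 <= c < grid_size):
                continue  # outside the grid
            if (r, c) in seen or reward_map[r * grid_size + c] == -1:
                continue  # already visited, or a penalty cell
            seen.add((r, c))
            queue.append((r, c))
    seen.discard(goal)
    # sorted for a deterministic order
    return sorted([r, c] for r, c in seen)
```

Since moves on the grid are symmetric, searching outward from the goal finds exactly the cells from which the goal can be reached without stepping on a `-1` cell.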


Create an environment with a fixed seed, ```1000```.<br>
Note that the transition probability is disabled in this environment (i.e., ```transition_prob=False``` in the constructor).
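
For reference, this is what `transition_prob=True` would do in `step()`: the intended action gets weight `7.0`, each of the other three actions gets weight `1.0`, and the executed action is sampled from these unnormalized weights (as `torch.multinomial()` does). The sketch below reproduces that rule with the standard library only; `transition_weights()` and `sample_executed_action()` are hypothetical helper names.

```python
import random
from collections import Counter

ACTION_SIZE = 4  # 0=UP 1=DOWN 2=LEFT 3=RIGHT


def transition_weights(action, success_weight=7.0):
    """Unnormalized weights: 7 for the intended action, 1 for each other action."""
    return [success_weight if a == action else 1.0 for a in range(ACTION_SIZE)]


def sample_executed_action(action, rng):
    """Sample the action that is actually executed when transitions are noisy."""
    weights = transition_weights(action)
    return rng.choices(range(ACTION_SIZE), weights=weights, k=1)[0]


weights = transition_weights(action=3)
probs = [w / sum(weights) for w in weights]
print(probs)  # [0.1, 0.1, 0.1, 0.7]

rng = random.Random(0)
counts = Counter(sample_executed_action(3, rng) for _ in range(10_000))
print(counts[3] / 10_000)  # close to 0.7
```

So the intended move succeeds with probability 0.7, and each of the other three directions is taken with probability 0.1.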

In [3]:
import torch
from torch.nn import functional as F
from gridworld import GridWorld

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = GridWorld(seed=1000, device=device)

Now let's visualize our GridWorld environment.

The number in each cell indicates the reward for that state.<br>
The goal state is in the bottom-right corner (where the reward is ```10```), and the initial state is picked uniformly from the gray cells.<br>
If the agent reaches the goal state without losing any rewards, its total reward is ```10.0```.

See [Readme.md](https://github.com/tsmatz/imitation-learning-tutorials/blob/master/Readme.md) for details about the game rules of this environment.
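
The reward rule can be traced by hand on a tiny grid. The following stdlib sketch mimics the deterministic case of `step()` on a hypothetical 3 x 3 map: a state id is `row * grid_size + col`, moving across the border gives `-1` and leaves the agent in place, otherwise the agent receives the reward of the cell it enters, and the episode ends when it enters the goal. `toy_step()` and `rollout()` are illustrative names, not part of `gridworld.py`.

```python
GRID_SIZE = 3
GOAL_REWARD = 10
# Hypothetical 3 x 3 reward map, flattened with state_id = row * GRID_SIZE + col
REWARD_MAP = [
    0, -1, 0,
    0, 0, 0,
    -1, 0, GOAL_REWARD,
]
# (row delta, col delta) for 0=UP 1=DOWN 2=LEFT 3=RIGHT
MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}


def toy_step(state_id, action):
    row, col = divmod(state_id, GRID_SIZE)
    d_row, d_col = MOVES[action]
    new_row, new_col = row + d_row, col + d_col
    if 0 <= new_row < GRID_SIZE and 0 <= new_col < GRID_SIZE:
        new_id = new_row * GRID_SIZE + new_col
        return new_id, REWARD_MAP[new_id]
    # crossing the border: reward -1, and the agent stays in place
    return state_id, -1


def rollout(start_id, actions):
    """Apply actions until the goal is reached; return (final state, total reward)."""
    state, total = start_id, 0
    for action in actions:
        state, reward = toy_step(state, action)
        total += reward
        if reward == GOAL_REWARD:
            break
    return state, total


print(rollout(0, [1, 3, 3, 1]))  # (8, 10): avoids the -1 cell
print(rollout(0, [3, 1, 3, 1]))  # (8, 9): steps on the -1 cell once
```

Only the first path collects the maximum total reward of `10`, which is the behavior the expert should learn.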

In [4]:
from IPython.display import HTML, display

valid_states_all = torch.cat((env.valid_states, torch.tensor([env.grid_size-1,env.grid_size-1]).to(device).unsqueeze(dim=0)))
valid_states_all = valid_states_all[:,0] * env.grid_size + valid_states_all[:,1]

html_text = "<table>"
for row in range(env.grid_size):
    html_text += "<tr>"
    for col in range(env.grid_size):
        if row*env.grid_size + col in valid_states_all:
            html_text += "<td bgcolor=\"gray\">"
        else:
            html_text += "<td>"
        html_text += str(env.reward_map[row*env.grid_size+col].tolist())
        html_text += "</td>"
    html_text += "</tr>"
html_text += "</table>"

display(HTML(html_text))

0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49
0,-1,0,0,0,-1,0,0,0,0,0,0,-1,0,-1,-1,-1,-1,0,-1,0,-1,0,0,0,-1,0,0,0,-1,-1,0,-1,0,0,0,0,-1,-1,0,0,0,0,0,0,0,0,0,0,0
-1,0,0,0,0,-1,0,-1,0,0,0,0,0,0,0,0,0,0,0,0,0,-1,-1,-1,0,-1,0,0,0,0,0,-1,0,-1,0,-1,0,-1,0,-1,-1,0,0,0,-1,-1,0,0,-1,-1
0,0,-1,-1,0,-1,0,-1,0,-1,0,0,0,0,0,0,0,-1,0,0,-1,0,0,-1,0,0,0,0,0,0,0,-1,0,0,-1,0,0,-1,-1,0,0,-1,0,0,0,0,-1,0,0,0
0,0,-1,0,-1,-1,0,0,-1,0,-1,0,0,0,0,-1,0,-1,0,-1,0,-1,0,0,-1,-1,0,0,-1,0,0,0,0,0,0,0,0,0,0,0,0,-1,0,0,0,0,0,0,0,-1
-1,-1,-1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,-1,-1,0,-1,0,0,0,0,0,0,0,0,0,0,0,0,-1,0,-1,0,0,0,-1,0,0,-1,0,-1,-1,0,0
0,-1,-1,0,0,-1,0,-1,0,0,-1,0,0,-1,0,0,-1,-1,0,0,-1,0,0,0,-1,0,0,0,0,-1,-1,0,0,-1,0,-1,-1,-1,-1,0,0,-1,0,-1,0,0,0,0,0,-1
-1,-1,0,0,0,0,0,0,0,0,0,0,-1,0,-1,0,0,0,0,0,-1,-1,0,0,-1,-1,0,0,-1,-1,0,-1,0,0,0,0,0,0,0,-1,-1,0,0,0,-1,0,-1,0,-1,-1
0,0,0,0,-1,0,0,0,-1,0,0,0,0,-1,0,0,-1,-1,0,0,-1,0,0,-1,0,0,-1,0,0,0,-1,-1,0,0,-1,0,-1,0,-1,0,0,-1,0,-1,0,0,0,0,0,0
0,0,0,0,0,0,0,0,0,0,-1,-1,0,-1,0,0,-1,0,0,-1,-1,0,0,0,0,0,0,-1,0,0,0,-1,-1,-1,-1,-1,-1,0,-1,-1,0,-1,0,0,0,-1,0,0,0,0
0,0,0,-1,0,-1,0,0,-1,-1,0,0,-1,0,0,0,0,0,0,0,0,0,-1,-1,-1,0,-1,0,0,0,0,0,0,0,0,0,0,-1,0,0,0,0,0,0,-1,-1,0,-1,-1,-1


Save (serialize) this environment as JSON so that the same environment can be restored in the following exercises.

**Note: This repository already has this JSON file (```gridworld.json```), so you don't need to run this cell.**

In [5]:
import json

def serialize_gridworld(env):
    env_dict = env.__dict__
    return {
        "reward_map": env_dict["reward_map"].tolist(),
        "valid_states": env_dict["valid_states"].tolist()
    }

with open("gridworld.json", "w") as f:
    json_str = json.dumps(env, default=serialize_gridworld)
    f.write(json_str)

[Optional] To load the GridWorld environment from the existing JSON file (```gridworld.json```), run as follows.<br>
(The following exercises load the environment in this way.)
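
The save and load cells rely on two standard `json` features: the `default=` hook of `json.dumps()`, which is called for objects the encoder cannot handle natively, and keyword unpacking (`**json_object`) into the constructor. A minimal round trip with a hypothetical `TinyWorld` class that holds plain lists (instead of tensors) looks like this:

```python
import json


class TinyWorld:
    """Hypothetical stand-in for GridWorld that stores plain lists instead of tensors."""

    def __init__(self, reward_map, valid_states):
        self.reward_map = reward_map
        self.valid_states = valid_states


def serialize_world(obj):
    # json.dumps() calls this hook only for objects it cannot encode natively
    if isinstance(obj, TinyWorld):
        return {"reward_map": obj.reward_map, "valid_states": obj.valid_states}
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


original = TinyWorld(reward_map=[0, -1, 0, 10], valid_states=[[0, 0], [1, 0]])
json_str = json.dumps(original, default=serialize_world)
print(json_str)  # {"reward_map": [0, -1, 0, 10], "valid_states": [[0, 0], [1, 0]]}

# keyword unpacking, as in GridWorld(**json_object, device=device)
restored = TinyWorld(**json.loads(json_str))
```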

In [6]:
import json

with open("gridworld.json", "r") as f:
    json_object = json.load(f)
    env = GridWorld(**json_object, device=device)

## 2. Define models

Now we define and instantiate the expert model (neural network), which is trained with PPO (Proximal Policy Optimization), a widely used reinforcement learning algorithm.<br>
I don't explain PPO here; please refer to the [reinforcement learning tutorials](https://github.com/tsmatz/reinforcement-learning-tutorials) for details.

PPO uses 2 models: an actor model and a value model.<br>
First, I define these 2 models as follows. (They are trained later.)

> Note: GridWorld is a primitive environment, so I assume that both the action logits and the value are linear in the state feature, i.e., $\text{estimated value} = \mathbf{w}^\top \phi(s)$, where $\phi(s)$ is the feature (here, a one-hot vector) of state $s$.<br>
> Thus, both networks have no hidden layers (see below), which speeds up training. (Depending on the problem, you should design other networks.)
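
Because $\phi(s)$ is one-hot, $\mathbf{w}^\top \phi(s)$ simply picks out the weight that belongs to state $s$, so each bias-free linear layer behaves like a learnable lookup table. The following stdlib sketch checks this on a toy size; the weight matrix `W` is hypothetical and laid out like `nn.Linear.weight`, with shape `(ACTION_SIZE, STATE_SIZE)`.

```python
STATE_SIZE = 5   # toy size (the notebook uses 2500)
ACTION_SIZE = 4

# Hypothetical weights laid out like nn.Linear.weight: shape (ACTION_SIZE, STATE_SIZE)
W = [[0.1 * (a + 1) * i for i in range(STATE_SIZE)] for a in range(ACTION_SIZE)]


def one_hot(s, n):
    return [1.0 if i == s else 0.0 for i in range(n)]


def linear_no_bias(W, x):
    # logits[a] = sum_i W[a][i] * x[i], i.e. w_a^T phi(s)
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) for row in W]


s = 3
logits = linear_no_bias(W, one_hot(s, STATE_SIZE))
column_s = [row[s] for row in W]
print(logits == column_s)  # True: the layer just reads column s of W
```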

In [7]:
import torch.nn as nn

STATE_SIZE = env.grid_size * env.grid_size  # 2500
ACTION_SIZE = env.action_size               # 4

class ActorNet(nn.Module):
    def __init__(self, hidden_dim=16):  # hidden_dim is unused (no hidden layers)
        super().__init__()
        self.output = nn.Linear(STATE_SIZE, ACTION_SIZE, bias=False)

    def forward(self, state, mask=None):
        """
        Set mask (shape (TIMESTEP_SIZE, BATCH_SIZE)) when episodes in the batch have different numbers of time steps.
        """
        logits = self.output(state)
        if mask is not None:
            mask = mask.unsqueeze(dim=-1)
            mask = mask.expand(-1, -1, ACTION_SIZE)
            logits = logits.masked_fill(mask, 0.0)
        return logits

class ValueNet(nn.Module):
    def __init__(self, hidden_dim=16):  # hidden_dim is unused (no hidden layers)
        super().__init__()
        self.output = nn.Linear(STATE_SIZE, 1, bias=False)

    def forward(self, state, mask=None):
        """
        Set mask (shape (TIMESTEP_SIZE, BATCH_SIZE)) when episodes in the batch have different numbers of time steps.
        """
        value = self.output(state)
        if mask is not None:
            mask = mask.unsqueeze(dim=-1)
            value = value.masked_fill(mask, 0.0)
        return value

#
# Generate model
#
actor_func = ActorNet().to(device)
value_func = ValueNet().to(device)

I also define a helper function as follows.

In [8]:
# Pick up stochastic samples using previous actor model
def pick_sample_and_logp(policy, s):
    """
    Stochastically pick up actions with the policy model and return them with their logits and log probabilities.

    Parameters
    ----------
    policy : torch.nn.Module
        Policy network to use
    s : torch.tensor((BATCH_SIZE, STATE_SIZE), dtype=int)
        The feature (one-hot) of state.

    Returns
    ----------
    action : torch.tensor((BATCH_SIZE), dtype=int)
        The picked-up actions.
    logits : torch.tensor((BATCH_SIZE, ACTION_SIZE), dtype=float)
        Logits defining categorical distribution.
        This is needed to optimize model.
    logprb : torch.tensor((BATCH_SIZE), dtype=float)
        Log probability of each picked-up action.
    """
    with torch.no_grad():
        # Get logits from state
        # --> size : (BATCH_SIZE, ACTION_SIZE)
        logits = policy(s.float())
        # From logits to probabilities
        # --> size : (BATCH_SIZE, ACTION_SIZE)
        probs = F.softmax(logits, dim=-1)
        # Pick up action's sample
        # --> size : (BATCH_SIZE, 1)
        a = torch.multinomial(probs, num_samples=1)
        # --> size : (BATCH_SIZE, )
        a = a.squeeze(dim=1)
        # Calculate log probability
        logprb = -F.cross_entropy(logits, a, reduction="none")

        # Return
        return a, logits, logprb
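
`pick_sample_and_logp()` computes the log probability as the negative cross-entropy between the logits and the sampled action, which equals the log-softmax value at that action. The sketch below restates this for a single state with the standard library; `log_softmax()` and `sample_action_and_logp()` are hypothetical names (the notebook itself uses `torch.multinomial()` and `F.cross_entropy()`).

```python
import math
import random


def log_softmax(logits):
    # subtract the max before exponentiating for numerical stability
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sample_action_and_logp(logits, rng):
    """Sample an action from softmax(logits) and return it with its log probability."""
    logp = log_softmax(logits)
    probs = [math.exp(lp) for lp in logp]
    action = rng.choices(range(len(logits)), weights=probs, k=1)[0]
    # -cross_entropy(logits, action) == log_softmax(logits)[action]
    return action, logp[action]


rng = random.Random(0)
toy_logits = [1.0, 2.0, 0.5, -1.0]
action, logp = sample_action_and_logp(toy_logits, rng)
print(action, math.exp(logp))
```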

## 3. Train and generate expert model

**Note: This repository already has the trained models (```expert_actor.pt```, ```expert_value.pt```), so you don't need to run the following training cell.**

Now let's train the models with the PPO algorithm.<br>

In this notebook, I simply reuse the source code from [reinforcement learning tutorials](https://github.com/tsmatz/reinforcement-learning-tutorials) without explanation; please refer to the original notebook for the theoretical background behind PPO.

In this training:

- To speed up training, episodes are run as a batch.
- The environment has no transition probability, so the optimal total reward in a single episode is always ```10.0```. (As the agent is optimized, the average total reward approaches ```10.0```.)
- The agent's objective is to reach the goal state without losing rewards. Thus, I set no discount (```DISCOUNT = 1.0```) in this training.
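
With `DISCOUNT = 1.0`, the cumulative reward computed in the training loop (`cum_rewards`) is simply the sum of the remaining rewards in the episode, accumulated backward from the last step as $G_t = r_t + \gamma G_{t+1}$. A stdlib sketch of that recursion for a single episode (the function name `rewards_to_go()` is hypothetical):

```python
def rewards_to_go(rewards, discount=1.0):
    """G_t = r_t + discount * G_{t+1}, computed backward from the last step."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + discount * running
        returns[t] = running
    return returns


# Hypothetical 4-step episode: one penalty cell, then the goal
print(rewards_to_go([0.0, -1.0, 0.0, 10.0]))  # [9.0, 9.0, 10.0, 10.0]
```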

> Here I use ```-100``` to mark unknown (padded) actions, because the PyTorch cross-entropy function (```torch.nn.functional.cross_entropy()```) has an ```ignore_index``` parameter whose default value is ```-100```.
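
Episodes in a batch end at different steps, so the training loop pads every episode up to the longest one and stores tensors in time-major `(TIMESTEP_SIZE, BATCH_SIZE)` order. The stdlib sketch below shows that padding for action lists and how episode lengths are recovered by counting non-padded entries (as the statistics line `torch.sum(actions!=-100, dim=0)` does). `pad_time_major()` and `episode_lengths()` are hypothetical helpers.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.functional.cross_entropy()


def pad_time_major(episodes, pad=IGNORE_INDEX):
    """Stack variable-length action lists into a time-major (T, B) nested list."""
    max_len = max(len(ep) for ep in episodes)
    padded = [ep + [pad] * (max_len - len(ep)) for ep in episodes]
    # transpose (B, T) -> (T, B), matching the (TIMESTEP_SIZE, BATCH_SIZE) layout
    return [list(step) for step in zip(*padded)]


def episode_lengths(actions_tb, pad=IGNORE_INDEX):
    """Count non-padded entries in each batch column."""
    return [sum(a != pad for a in column) for column in zip(*actions_tb)]


actions_tb = pad_time_major([[1, 3, 1], [2], [0, 0]])
print(actions_tb)                   # [[1, 2, 0], [3, -100, 0], [1, -100, -100]]
print(episode_lengths(actions_tb))  # [3, 1, 2]
```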

In [None]:
#
# Train model
# (see https://github.com/tsmatz/reinforcement-learning-tutorials/blob/master/04-ppo.ipynb)
#
# Operations are processed as a batch to speed up,
# and all working tensor has dimension: (step_count, batch_size, ...)
#
import numpy as np

DISCOUNT = 1.0            # No Discount
LEARNING_RATE = 0.001
BATCH_SIZE = 128

# [TODO] change this value if transition_prob=TRUE
THRESHOLD = 9.0  # if transition_prob=FALSE
# THRESHOLD = -0.95  # if transition_prob=TRUE

# These coefficients are experimentally determined in practice.
kl_coeff = 100000.0  # weight coefficient for KL-divergence loss
vf_coeff = 100.00  # weight coefficient for value loss

reward_records = []
all_params = list(actor_func.parameters()) + list(value_func.parameters())
opt = torch.optim.AdamW(all_params, lr=LEARNING_RATE)
for i in range(999999):

    # print("Epochs :", i+1)

    #
    # Run episode till done as a batch to generate tensors
    #
    done = torch.tensor([False]).to(device)
    # define working items
    # (tensor shape is (TIMESTEP_SIZE, BATCH_SIZE) or (TIMESTEP_SIZE, BATCH_SIZE, C) where C is the number of classes)
    states_work  = torch.empty((0,BATCH_SIZE,STATE_SIZE), dtype=torch.int).to(device)
    actions_work = torch.empty((0,BATCH_SIZE), dtype=torch.int).to(device)
    logits_work  = torch.empty((0,BATCH_SIZE,ACTION_SIZE), dtype=torch.float).to(device)
    logprbs_work = torch.empty((0,BATCH_SIZE), dtype=torch.float).to(device)
    rewards_work = torch.empty((0,BATCH_SIZE), dtype=torch.float).to(device)

    # define done items
    states_done  = []
    actions_done = []
    logits_done  = []
    logprbs_done = []
    rewards_done = []

    # start
    s = env.reset(BATCH_SIZE)
    while not (torch.prod(done) == 1):

        s_onehot = F.one_hot(s, num_classes=STATE_SIZE)
        states_work = torch.cat((states_work, s_onehot.unsqueeze(dim=0)), dim=0)
        a, l, p = pick_sample_and_logp(actor_func, s_onehot)
        s, r, term, trunc = env.step(a, s)
        done = torch.logical_or(term, trunc)

        actions_work = torch.cat((actions_work, a.unsqueeze(dim=0)), dim=0)
        logits_work  = torch.cat((logits_work,  l.unsqueeze(dim=0)), dim=0)
        logprbs_work = torch.cat((logprbs_work, p.unsqueeze(dim=0)), dim=0)
        rewards_work = torch.cat((rewards_work, r.unsqueeze(dim=0)), dim=0)

        # pick up batch to be done and append to done-list
        done_indices = done.nonzero().squeeze(dim=1)
        if done_indices.numel() > 0:
            states_done.append(states_work[:,done_indices,:])
            actions_done.append(actions_work[:,done_indices])
            logits_done.append(logits_work[:,done_indices,:])
            logprbs_done.append(logprbs_work[:,done_indices])
            rewards_done.append(rewards_work[:,done_indices])
        # filter batch to run (not to be done)
        work_indices = (done==False).nonzero().squeeze(dim=1)
        if work_indices.numel() > 0:
            states_work = states_work[:,work_indices,:]
            actions_work = actions_work[:,work_indices]
            logits_work = logits_work[:,work_indices,:]
            logprbs_work = logprbs_work[:,work_indices]
            rewards_work = rewards_work[:,work_indices]
        # also filter the current state
        if work_indices.numel() > 0:
            s = s[work_indices]

    #
    # Prepare tensors for training
    #

    # fill values (0 or -100) to fit to maximum timestep
    timestep_size = env.step_count
    states_done   = [torch.cat((s, torch.zeros((timestep_size-s.shape[0],s.shape[1],STATE_SIZE), dtype=torch.int).to(device)), dim=0) for s in states_done]
    actions_done  = [torch.cat((a, torch.ones((timestep_size-a.shape[0],a.shape[1]), dtype=torch.int).to(device)*-100), dim=0) for a in actions_done]
    logits_done   = [torch.cat((l, torch.zeros((timestep_size-l.shape[0],l.shape[1],ACTION_SIZE), dtype=torch.float).to(device)), dim=0) for l in logits_done]
    logprbs_done  = [torch.cat((p, torch.zeros((timestep_size-p.shape[0],p.shape[1]), dtype=torch.float).to(device)), dim=0) for p in logprbs_done]
    rewards_done  = [torch.cat((r, torch.zeros((timestep_size-r.shape[0],r.shape[1]), dtype=torch.float).to(device)), dim=0) for r in rewards_done]

    # generate tensor from the list of tensor
    # (tensor shape is (TIMESTEP_SIZE, BATCH_SIZE) or (TIMESTEP_SIZE, BATCH_SIZE, C) where C is the number of classes)
    states  = torch.cat(states_done, dim=1)
    actions = torch.cat(actions_done, dim=1)
    logits  = torch.cat(logits_done, dim=1)
    logprbs = torch.cat(logprbs_done, dim=1)
    rewards = torch.cat(rewards_done, dim=1)
    states  = states.float()

    #
    # Generate cumulative rewards
    #
    cum_rewards = torch.zeros_like(rewards).to(device)
    for j in reversed(range(timestep_size)):
        cum_rewards[j,:] = rewards[j,:] + (cum_rewards[j+1,:]*DISCOUNT if j+1 < timestep_size else 0)

    #
    # Train and optimize model parameters
    #
    opt.zero_grad()
    # get values and logits with new parameters
    values_new = value_func(states, mask=(actions==-100))
    logits_new = actor_func(states, mask=(actions==-100))
    # get advantages
    # (values are detached so that the policy term does not update the value network)
    advantages = cum_rewards.unsqueeze(dim=-1) - values_new.detach()
    # calculate P_new / P_old
    logprbs_new = -F.cross_entropy(logits_new.transpose(1,2), actions, reduction="none")
    logprbs_new = logprbs_new.unsqueeze(dim=-1)
    prob_ratio = torch.exp(logprbs_new - logprbs.unsqueeze(dim=-1))
    # calculate KL-div for Categorical distribution
    l0 = logits - torch.amax(logits, dim=-1, keepdim=True) # to reduce quantity
    l1 = logits_new - torch.amax(logits_new, dim=-1, keepdim=True) # to reduce quantity
    e0 = torch.exp(l0)
    e1 = torch.exp(l1)
    e_sum0 = torch.sum(e0, dim=-1, keepdim=True)
    e_sum1 = torch.sum(e1, dim=-1, keepdim=True)
    p0 = e0 / e_sum0
    kl = torch.sum(
        p0 * (l0 - torch.log(e_sum0) - l1 + torch.log(e_sum1)),
        dim=-1,
        keepdim=True)
    # get value loss
    vf_loss = F.mse_loss(
        values_new,
        cum_rewards.unsqueeze(dim=-1),
        reduction="none")
    # get total loss
    loss = -advantages * prob_ratio + kl * kl_coeff + vf_loss * vf_coeff
    # optimize
    loss.sum().backward()
    opt.step()

    #
    # Output statistics (average in batch)
    #
    print("Run iteration {} with total reward {:6.1f}  episode length {:5.1f}".format(
        i+1,
        torch.mean(torch.sum(rewards, dim=0)).tolist(),
        torch.mean(torch.sum(actions!=-100, dim=0).float()).tolist()), end="\r")
    reward_records.append(torch.mean(torch.sum(rewards, dim=0)).tolist())

    #
    # Stop if reward mean is over a threshold
    #
    if np.average(reward_records[-100:]) > THRESHOLD:
        break

print("\nDone")
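
The loss in the training cell combines the policy term `-advantages * prob_ratio`, a KL-divergence penalty weighted by `kl_coeff`, and a value loss weighted by `vf_coeff`. The KL term is computed from the logits after subtracting their maximum, for numerical stability. The following stdlib sketch computes the same quantity, $\mathrm{KL}(p_\text{old} \,\|\, p_\text{new})$, for a single pair of logit vectors via log-softmax; the helper names are hypothetical.

```python
import math


def log_softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def categorical_kl(logits_old, logits_new):
    """KL(p_old || p_new) = sum_a p_old(a) * (log p_old(a) - log p_new(a))."""
    lp_old = log_softmax(logits_old)
    lp_new = log_softmax(logits_new)
    return sum(math.exp(a) * (a - b) for a, b in zip(lp_old, lp_new))


# p_old = [0.5, 0.5], p_new = [0.25, 0.75]  ->  0.5 * log(4/3)
print(categorical_kl([0.0, 0.0], [0.0, math.log(3.0)]))
```

Adding a constant to all logits does not change the distribution, so the KL between `[1, 2, 3]` and `[11, 12, 13]` is (up to rounding) zero.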



Save (serialize) the generated expert model.

In [None]:
torch.save(actor_func.state_dict(), "expert_actor.pt")
torch.save(value_func.state_dict(), "expert_value.pt")

Show how the total reward in a single episode (averaged over the batch) changed during training.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
# Generate the 50-iteration moving average
average_reward = []
for idx in range(len(reward_records)):
    # average over (up to) the last 50 records
    avg_list = reward_records[max(0, idx - 49):idx + 1]
    average_reward.append(np.average(avg_list))
plt.plot(reward_records)
plt.plot(average_reward)

Show the direction (greedy action) that the trained agent picks in each state of GridWorld.

In [None]:
from IPython.display import HTML, display

# get all initial states
valid_states_all = torch.cat((env.valid_states, torch.tensor([env.grid_size-1,env.grid_size-1]).to(device).unsqueeze(dim=0)))
valid_states_all = valid_states_all[:,0] * env.grid_size + valid_states_all[:,1]

# create direction table
with torch.no_grad():
    s = torch.arange(STATE_SIZE).to(device)
    s_onehot = F.one_hot(s, num_classes=STATE_SIZE).float()
    logits = actor_func(s_onehot)
    direction = torch.argmax(logits, dim=-1)
    direction_table = torch.reshape(direction, (env.grid_size, env.grid_size))
    direction_table = direction_table.cpu().numpy()

# show table
html_text = "<table>"
for row in range(env.grid_size):
    html_text += "<tr>"
    for col in range(env.grid_size):
        if row*env.grid_size + col in valid_states_all:
            html_text += "<td bgcolor=\"gray\">"
            #
            # show direction
            #
            index = direction_table[row, col]
            if index == 0:
                html_text += "&#x2191;" # up
            elif index == 1:
                html_text += "&#x2193;" # down
            elif index == 2:
                html_text += "&#x2190;" # left
            elif index == 3:
                html_text += "&#x2192;" # right
        else:
            html_text += "<td>"
        html_text += "</td>"
    html_text += "</tr>"
html_text += "</table>"

display(HTML(html_text))
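
The table above shows, for every valid state, the action with the largest logit (the greedy action), even though the agent samples actions stochastically when it runs. The same rendering can be sketched as plain text with the standard library; `render_policy()` and the toy logits are hypothetical.

```python
ARROWS = {0: "\u2191", 1: "\u2193", 2: "\u2190", 3: "\u2192"}  # UP, DOWN, LEFT, RIGHT


def greedy_action(logits):
    # index of the largest logit; ties go to the first index
    return max(range(len(logits)), key=lambda a: logits[a])


def render_policy(logits_per_state, grid_size, valid_states):
    """Render one arrow per valid state and '.' elsewhere, row by row."""
    lines = []
    for row in range(grid_size):
        line = ""
        for col in range(grid_size):
            s = row * grid_size + col
            line += ARROWS[greedy_action(logits_per_state[s])] if s in valid_states else "."
        lines.append(line)
    return "\n".join(lines)


# Hypothetical 2 x 2 example: states 0, 1, 3 are valid
toy_logits = [[0, 0, 0, 1], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]]
print(render_policy(toy_logits, 2, {0, 1, 3}))
```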

To load the pre-trained models instead, run as follows.

In [None]:
actor_func.load_state_dict(torch.load("expert_actor.pt", map_location=device))
value_func.load_state_dict(torch.load("expert_value.pt", map_location=device))
actor_func = actor_func.eval()
value_func = value_func.eval()

## 4. Generate expert trajectories (expert data)

Now we have a trained expert.<br>
Next, we generate expert data with this trained agent.

In this example, I generate 100,000 trajectories (episodes) and save all data into 10 separate files (each file has 10,000 trajectories).
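
The parameters in the cell below fit together by simple integer arithmetic. The sketch reuses their values and adds a hypothetical `ckpt_index()` mapping, assuming each file collects consecutive inference batches.

```python
# Values from the cell below
episode_num = 100_000         # total episodes
episode_num_in_ckpt = 10_000  # episodes per file
inf_batch_size = 100          # episodes run in parallel per inference batch

total_iter_num = episode_num // inf_batch_size          # inference batches in total
iters_per_ckpt = episode_num_in_ckpt // inf_batch_size  # batches that fill one file
num_files = episode_num // episode_num_in_ckpt


def ckpt_index(i):
    """Hypothetical mapping: batch i goes to file i // iters_per_ckpt."""
    return i // iters_per_ckpt


print(total_iter_num, iters_per_ckpt, num_files)  # 1000 100 10
```

The two `assert` statements in the cell guarantee that these divisions leave no remainder.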

In [None]:
from pathlib import Path
import pickle
import random

# total number of episodes to run
episode_num = 100000
# number of episodes to save in a single checkpoint file
episode_num_in_ckpt = 10000
# batch size to run inference
# (the number of episodes run in parallel)
inf_batch_size = 100
# directory name to save files
dest_dir = "./expert_data"

assert episode_num % episode_num_in_ckpt == 0
assert episode_num_in_ckpt % inf_batch_size == 0

total_iter_num = int(episode_num / inf_batch_size)

Path(dest_dir).mkdir(exist_ok=True)

with torch.no_grad():
    for i in range(total_iter_num):
        #
        # Run episode till done as a batch to generate tensors
        #
        done = torch.tensor([False]).to(device)
        # define working items
        states_work = torch.empty((inf_batch_size,0), dtype=torch.int).to(device)
        actions_work = torch.empty((inf_batch_size,0), dtype=torch.int).to(device)
        rewards_work = torch.empty((inf_batch_size,0), dtype=torch.float).to(device)
        # define done items
        states_done = []
        actions_done = []
        rewards_done = []
        # start
        s = env.reset(inf_batch_size)
        while not (torch.prod(done) == 1):
            s_onehot = F.one_hot(s, num_classes=STATE_SIZE)
            states_work = torch.cat((states_work, s.unsqueeze(dim=1)), dim=1)
            a, _, _ = pick_sample_and_logp(actor_func, s_onehot)
            s, r, term, trunc = env.step(a, s)
            done = torch.logical_or(term, trunc)
            actions_work = torch.cat((actions_work, a.unsqueeze(dim=1)), dim=1)
            rewards_work = torch.cat((rewards_work, r.unsqueeze(dim=1)), dim=1)
            # pick up batch to be done and append to done-list
            done_indices = done.nonzero().squeeze(dim=1)
            if done_indices.numel() > 0:
                states_done.append(states_work[done_indices,:])
                actions_done.append(actions_work[done_indices,:])
                rewards_done.append(rewards_work[done_indices,:])
            # keep only the episodes that are still running
            work_indices = (~done).nonzero().squeeze(dim=1)
            if work_indices.numel() > 0:
                states_work = states_work[work_indices,:]
                actions_work = actions_work[work_indices,:]
                rewards_work = rewards_work[work_indices,:]
                # also filter the current state
                s = s[work_indices]

        #
        # Save results as numpy array
        #

        # e.g., [tensor([[1,3,1],[2,1,1]]), tensor([[1,2,1,3]]), ...]

        # split tensors into list of tensors
        # --> [[tensor([[1,3,1]]), tensor([[2,1,1]])], [tensor([[1,2,1,3]])], ...]
        states_done = [torch.split(s, 1, dim=0) for s in states_done]
        actions_done = [torch.split(a, 1, dim=0) for a in actions_done]
        rewards_done = [torch.split(r, 1, dim=0) for r in rewards_done]
        # flatten into 1-dimension list
        # --> [tensor([[1,3,1]]), tensor([[2,1,1]]), tensor([[1,2,1,3]]), ...]
        states_done = [s2 for s1 in states_done for s2 in s1]
        actions_done = [a2 for a1 in actions_done for a2 in a1]
        rewards_done = [r2 for r1 in rewards_done for r2 in r1]
        # squeeze in each element
        # --> [tensor([1,3,1]), tensor([2,1,1]), tensor([1,2,1,3]), ...]
        states_done = [s.squeeze(dim=0) for s in states_done]
        actions_done = [a.squeeze(dim=0) for a in actions_done]
        rewards_done = [r.squeeze(dim=0) for r in rewards_done]
        # shuffle
        all_done = list(zip(states_done, actions_done, rewards_done))
        random.shuffle(all_done)
        states_done, actions_done, rewards_done = zip(*all_done)
        states_done, actions_done, rewards_done = list(states_done), list(actions_done), list(rewards_done)
        # get step length in each episode
        step_lens = [a.shape[0] for a in actions_done]
        # flatten
        states_done = torch.cat(states_done, dim=0)
        actions_done = torch.cat(actions_done, dim=0)
        rewards_done = torch.cat(rewards_done, dim=0)
        # to numpy
        states_done = states_done.cpu().numpy()
        actions_done = actions_done.cpu().numpy()
        rewards_done = rewards_done.cpu().numpy()
        step_lens = np.array(step_lens)
        # output progress
        print("Processed {:6d} / {:6d} episodes ...".format(inf_batch_size * (i + 1), episode_num), end="\r")
        # save every episode_num_in_ckpt episodes
        # (iters_per_ckpt inference batches make up one checkpoint file)
        iters_per_ckpt = episode_num_in_ckpt // inf_batch_size
        if i % iters_per_ckpt == 0:
            # first batch of a new checkpoint: initialize the store
            states_store = states_done
            actions_store = actions_done
            rewards_store = rewards_done
            timestep_lens_store = step_lens
        else:
            # append to the current checkpoint's store
            states_store = np.concatenate((states_store, states_done), axis=0)
            actions_store = np.concatenate((actions_store, actions_done), axis=0)
            rewards_store = np.concatenate((rewards_store, rewards_done), axis=0)
            timestep_lens_store = np.concatenate((timestep_lens_store, step_lens), axis=0)
        # write the file once the checkpoint is full
        # (checked outside the if/else so that iters_per_ckpt == 1 also works)
        if (i + 1) % iters_per_ckpt == 0:
            ckpt_num = (i + 1) // iters_per_ckpt - 1
            with open(f"{dest_dir}/ckpt{ckpt_num}.pkl", "wb") as f:
                pickle.dump({
                    "states": states_store,
                    "actions": actions_store,
                    "rewards": rewards_store,
                    "timestep_lens": timestep_lens_store,
                }, f)
print("\nDone")
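Each checkpoint file stores every episode concatenated into flat 1-D arrays, and `timestep_lens` records how many steps each episode contributes. Because a state is recorded before each step, every episode has the same number of states, actions and rewards. The following is a minimal sketch of reading a file back into per-episode trajectories. `load_trajectories` is a hypothetical helper, and the toy checkpoint below uses arbitrary made-up values rather than real expert output.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np

def load_trajectories(path):
    """Hypothetical helper: load one checkpoint file and split it
    back into a list of per-episode (states, actions, rewards) tuples."""
    with open(path, "rb") as f:
        data = pickle.load(f)
    # offsets where one episode ends and the next begins in the flat arrays
    split_points = np.cumsum(data["timestep_lens"])[:-1]
    states = np.split(data["states"], split_points)
    actions = np.split(data["actions"], split_points)
    rewards = np.split(data["rewards"], split_points)
    return list(zip(states, actions, rewards))

# toy checkpoint with the same layout: two episodes of lengths 3 and 4
toy = {
    "states": np.array([1, 3, 1, 1, 2, 1, 3]),
    "actions": np.array([0, 2, 1, 3, 3, 0, 1]),
    "rewards": np.array([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]),
    "timestep_lens": np.array([3, 4]),
}
path = Path(tempfile.mkdtemp()) / "ckpt_toy.pkl"
with open(path, "wb") as f:
    pickle.dump(toy, f)

episodes = load_trajectories(path)
for s, a, r in episodes:
    print(s, a, r)
```

Under this layout, calling `load_trajectories("./expert_data/ckpt0.pkl")` on the generated data should return 10,000 episodes.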