# A2C with PPO: training a car agent in a Unity environment

In [None]:
from io import BytesIO

from functools import reduce
from itertools import zip_longest

import typing
from typing import NamedTuple
from dataclasses import dataclass, field

import numpy as np

from tqdm.notebook import trange

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dist

from PIL import Image

from mlagents.trainers.demo_loader import load_demonstration

from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.base_env import ActionTuple, DecisionSteps, TerminalSteps

import matplotlib.pyplot as plt
%matplotlib inline

## Unity environment

### Environment initialization

In [None]:
# Open the connection to Unity. With no arguments this presumably attaches to a
# running Unity Editor rather than a built binary -- TODO confirm against the
# ML-Agents documentation.
unity_env = UnityEnvironment()
unity_env.reset()
# Behavior specs are only read after the reset above; the first registered
# behavior name is the one printed here.
name = list(unity_env.behavior_specs.keys())[0]
print(name)

### Environment wrapper

In [None]:
class Observation(NamedTuple):
    """
    Environment observation tuple

    Iterating or unpacking the tuple yields the fields in declaration order:
    ``camera``, ``ifr_r``, ``ifr_l``. Code that iterates an ``Observation``
    (e.g. to convert every array to a tensor) therefore sees the right sensor
    before the left one.
    """
    # Visual observation; the environment wrapper fills it from ``step.obs[0]``
    camera: np.ndarray
    # Filled by the wrapper from ``step.obs[2]``; presumably the right
    # infrared/range sensor -- TODO confirm the sensor type
    ifr_r: np.ndarray
    # Filled by the wrapper from ``step.obs[1]``; presumably the left sensor
    ifr_l: np.ndarray

In [None]:
class UnityEnvWrapper:
    """
    Gym-style facade (``reset`` / ``step``) over a Unity environment.

    Only the first behavior registered in the environment is driven, and only
    the first agent's reward is reported.
    """

    def __init__(self, unity_env: UnityEnvironment):
        self._unity_env = unity_env

        # Behavior specs are read only after an initial reset
        self._unity_env.reset()

        behavior_names = list(self._unity_env.behavior_specs.keys())
        self._name = behavior_names[0]
        self._group_spec = self._unity_env.behavior_specs[self._name]

        self._done = False

    def _get_step(self) -> typing.Union[DecisionSteps, TerminalSteps]:
        """Fetch the latest steps; terminal steps win and mark the episode as done."""
        decisions, terminals = self._unity_env.get_steps(self._name)

        if len(terminals) == 0:
            return decisions

        self._done = True
        return terminals

    def _get_step_observation(self, step: typing.Union[DecisionSteps, TerminalSteps]) -> Observation:
        """Pack the three raw observation arrays of ``step`` into an ``Observation``."""
        raw = step.obs
        # obs[1] feeds the left sensor field and obs[2] the right one
        return Observation(camera=raw[0], ifr_l=raw[1], ifr_r=raw[2])

    def _process_step(self, step: typing.Union[DecisionSteps, TerminalSteps]) -> tuple[Observation, float, bool, str]:
        """Convert ``step`` into an ``(observation, reward, done, info)`` tuple."""
        observation = self._get_step_observation(step)
        reward = step.reward[0]
        finished = isinstance(step, TerminalSteps)
        return observation, reward, finished, ""

    def reset(self) -> Observation:
        """Restart the episode and return its first observation."""
        self._unity_env.reset()
        self._done = False

        return self._get_step_observation(self._get_step())

    def step(self, actions: ActionTuple) -> tuple[Observation, float, bool, str]:
        """
        Apply ``actions`` and advance the simulation by one step.

        Raises ``ValueError`` when the episode has already finished and
        ``reset`` has not been called since.
        """
        if self._done:
            raise ValueError("Actions passed to the done env")

        unity = self._unity_env
        unity.set_actions(self._name, actions)
        unity.step()

        return self._process_step(self._get_step())


### Environment test

In [None]:
# Wrap the raw Unity handle in the gym-like interface used by the rest of the notebook
env = UnityEnvWrapper(unity_env)
env.reset()

This cell checks that the environment responds to actions and returns rewards.

In [None]:
done = False
rewards = []

# Constant action sent on every step. Elsewhere in this notebook index 1 is
# treated as steering; index 0 is presumably throttle -- TODO confirm.
np_action = np.array([[1, 0]])
action = ActionTuple(np_action)

env.reset()

# Run a single episode until the wrapper reports a terminal step,
# recording the reward of every step
while not done:
    _, reward, done, _ = env.step(action)
    rewards.append(reward)

In [None]:
# Per-step rewards of the constant-action test episode
plt.figure(figsize=(20, 10))
plt.plot(rewards)

## Model definition

### Encoders

In [None]:
def size_after_conv(in_size: int, kernel_size: int, stride: int, padding: int = 0, dilation: int = 1) -> int:
    """
    Return the length of one spatial dimension after a convolution.

    Implements the output-size formula used by ``torch.nn.Conv2d``::

        floor((in + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1

    With the default ``padding=0`` and ``dilation=1`` this reduces to
    ``(in_size - kernel_size) // stride + 1``.

    Args:
        in_size: input length along the dimension.
        kernel_size: kernel length along the dimension.
        stride: convolution stride.
        padding: zero padding added to *each* side of the input.
        dilation: spacing between kernel elements.

    Returns:
        Output length along the dimension.
    """
    # Span actually covered by a (possibly dilated) kernel
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (in_size + 2 * padding - effective_kernel) // stride + 1


class ImageEncoder(nn.Module):
    """
    Convolutional encoder that maps a camera frame to a feature vector.

    Three stride-2, 3x3 convolutions are followed by two hidden linear layers
    and a linear output layer of ``out_features`` units.
    """

    def __init__(self, im_size: tuple[int, int, int], out_features: int):
        super().__init__()

        self.im_size = im_size

        self.conv1 = nn.Sequential(nn.Conv2d(im_size[0], 32, 3, stride=2), nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, 3, stride=2), nn.ReLU())
        self.conv3 = nn.Sequential(nn.Conv2d(64, 64, 3, stride=2), nn.ReLU())

        def after_conv_stack(size: int) -> int:
            # Each of the three convolutions uses kernel 3 and stride 2
            for _ in range(3):
                size = size_after_conv(size, 3, 2)
            return size

        # Spatial extent left after the convolutions, times the 64 output channels
        flat_size_w = after_conv_stack(im_size[1])
        flat_size_h = after_conv_stack(im_size[2])
        self.flat_size = flat_size_w * flat_size_h * 64

        self.linear1 = nn.Sequential(nn.Linear(self.flat_size, 128), nn.ReLU())
        self.linear2 = nn.Sequential(nn.Linear(128, 64), nn.ReLU())

        self.out = nn.Linear(64, out_features)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Reinterpret the input as a batch of ``im_size`` images.
        # NOTE(review): ``view`` does not permute axes -- confirm the caller
        # supplies data already laid out as ``im_size``.
        frames = X.view(-1, *self.im_size)

        feature_maps = self.conv3(self.conv2(self.conv1(frames)))

        flat = feature_maps.view(-1, self.flat_size)

        hidden = self.linear2(self.linear1(flat))

        return self.out(hidden)
        

In [None]:
class StateEncoder(nn.Module):
    """
    Fuses several per-sensor embeddings into a single state vector.

    The embeddings are concatenated along dim 1 and passed through one hidden
    ReLU layer and a linear output layer.
    """

    def __init__(self, *in_sizes: int, out_features: int, hidden_size: int = 512):
        super().__init__()

        # Width of the concatenated input is the sum of the individual widths
        total_in = 0
        for size in in_sizes:
            total_in += size

        self.hidden = nn.Sequential(nn.Linear(total_in, hidden_size), nn.ReLU())

        self.linear = nn.Linear(hidden_size, out_features)

    def forward(self, *embeds: torch.Tensor) -> torch.Tensor:
        joined = torch.cat(embeds, dim=1)
        hidden = self.hidden(joined)

        return self.linear(hidden)

In [None]:
class Encoder(nn.Module):
    """
    Recurrent state encoder shared in structure by the actor and the critic.

    Each observation is a sequence of per-sensor inputs. The i-th input goes
    through the i-th entry of ``encoders`` (inputs without a matching encoder
    are passed through unchanged), the results are fused by ``state_encoder``,
    and the fused state is run through two linear layers and an LSTM cell whose
    hidden/cell state is kept on the module between calls.
    """

    def __init__(self, *encoders: nn.Module, state_encoder: nn.Module, state_size: int):
        super().__init__()

        self.encoders = nn.ModuleList(encoders)
        self.state_encoder = state_encoder

        self.linear1 = nn.Sequential(nn.Linear(state_size, 256), nn.ReLU())
        self.linear2 = nn.Sequential(nn.Linear(256, 256), nn.ReLU())

        self.lstm = nn.LSTMCell(input_size=256, hidden_size=128)
        # Recurrent state is held as plain attributes (not parameters/buffers)
        self.hid = torch.zeros(1, 128)
        self.cell = torch.zeros(1, 128)

        self.head = nn.Linear(128, state_size)

    @property
    def memory(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Current (hidden, cell) state, detached from the autograd graph."""
        hidden = self.hid.detach()
        cell = self.cell.detach()
        return (hidden, cell)

    def reset(self, memory: typing.Optional[tuple[torch.Tensor, torch.Tensor]] = None):
        """Replace the recurrent state with ``memory``, or with zeros when it is falsy."""
        if memory:
            self.hid, self.cell = memory
        else:
            self.hid, self.cell = (torch.zeros(1, 128), torch.zeros(1, 128))

    def forward(self, X: torch.Tensor, memory: typing.Optional[tuple[torch.Tensor, torch.Tensor]] = None) -> torch.Tensor:

        def passthrough(x):
            return x

        # One fused state row per observation in the batch
        fused_states = []
        for sensors in X:
            embeddings = (
                encode(sensor)
                for encode, sensor in zip_longest(self.encoders, sensors, fillvalue=passthrough)
            )
            fused_states.append(self.state_encoder(*embeddings))

        features = torch.cat(fused_states)
        features = self.linear2(self.linear1(features))

        # An explicit memory overrides the state kept from the previous call
        if memory is not None:
            self.hid, self.cell = memory

        # The LSTM always starts from the detached state, so gradients never
        # flow into earlier calls
        self.hid, self.cell = self.lstm(features, self.memory)

        return self.head(self.hid)


### PPO implementation

In [None]:
class Actor(nn.Module):
    """
    Policy network producing the two concentration parameters of a Beta
    distribution for every action dimension.
    """

    def __init__(self, encoder: nn.Module, state_size: int, n_actions: int):
        super().__init__()

        self.encoder = encoder

        self.linear1 = nn.Sequential(nn.Linear(state_size, 256), nn.ReLU())
        self.linear2 = nn.Sequential(nn.Linear(256, 128), nn.ReLU())

        # Softplus keeps both concentration parameters strictly positive
        self.alpha_head = nn.Sequential(nn.Linear(128, n_actions), nn.Softplus())
        self.beta_head = nn.Sequential(nn.Linear(128, n_actions), nn.Softplus())

    def forward(self, X: torch.Tensor, memory: typing.Optional[tuple[torch.Tensor, torch.Tensor]] = None) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(alpha, beta)`` tensors for the encoded state."""
        state = self.encoder(X, memory)

        hidden = self.linear2(self.linear1(state))

        alpha = self.alpha_head(hidden)
        beta = self.beta_head(hidden)
        return alpha, beta

    def get_policy(self, state, memory: typing.Optional[tuple[torch.Tensor, torch.Tensor]] = None) -> dist.Beta:
        """Build the action distribution for ``state``."""
        # A Beta distribution is used because its samples lie between 0 and 1
        return dist.Beta(*self(state, memory))


In [None]:
class Critic(nn.Module):
    """
    Value network: encodes the observation and regresses one scalar per state.
    """

    def __init__(self, encoder: nn.Module, state_size: int):
        super().__init__()

        self.encoder = encoder

        self.linear1 = nn.Sequential(nn.Linear(state_size, 256), nn.ReLU())
        self.linear2 = nn.Sequential(nn.Linear(256, 256), nn.ReLU())
        self.linear3 = nn.Sequential(nn.Linear(256, 128), nn.ReLU())

        self.value = nn.Linear(128, 1)

    def forward(self, X: torch.Tensor, memory: typing.Optional[tuple[torch.Tensor, torch.Tensor]] = None) -> torch.Tensor:
        """Return the estimated value, one row per encoded state."""
        state = self.encoder(X, memory)

        hidden = state
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = layer(hidden)

        return self.value(hidden)

### Initialization

In [None]:
def make_actor(n_action: int) -> Actor:
    """
    Assemble a freshly initialised actor for 3x256x256 camera frames.

    The camera embedding (128 features) is fused with two width-1 sensor
    inputs into a 256-dimensional state.
    """
    return Actor(
        encoder=Encoder(
            ImageEncoder((3, 256, 256), 128),
            state_encoder=StateEncoder(128, 1, 1, out_features=256),
            state_size=256,
        ),
        state_size=256,
        n_actions=n_action,
    )

In [None]:
def make_critic() -> Critic:
    """
    Assemble a freshly initialised critic for 3x256x256 camera frames.

    The camera embedding (128 features) is fused with two width-1 sensor
    inputs into a 256-dimensional state.
    """
    return Critic(
        encoder=Encoder(
            ImageEncoder((3, 256, 256), 128),
            state_encoder=StateEncoder(128, 1, 1, out_features=256),
            state_size=256,
        ),
        state_size=256,
    )

In [None]:
def load_model(model: nn.Module, model_name: str, episode: int):
    """
    Load the weights stored for ``model_name`` at ``episode`` into ``model``.

    The checkpoint is read from ``models/<model_name>/<model_name>-<episode>.torch``
    with the episode zero-padded to four digits.
    """
    checkpoint_path = f'models/{model_name}/{model_name}-{episode:04}.torch'
    state = torch.load(checkpoint_path)
    model.load_state_dict(state)


def save_model(model: nn.Module, model_name: str, episode: int):
    """
    Save the weights of ``model`` as the checkpoint of ``model_name`` at ``episode``.

    The file is written to ``models/<model_name>/<model_name>-<episode>.torch``
    with the episode zero-padded to four digits (the same ``04`` format spec
    that ``load_model`` uses, so both always resolve to the same file name).
    The target directory is created when it does not exist yet; without that,
    the first save on a fresh checkout fails.
    """
    import os

    checkpoint_path = f'models/{model_name}/{model_name}-{episode:04}.torch'
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    torch.save(model.state_dict(), checkpoint_path)


In [None]:
# Two action dimensions; the rest of the notebook treats index 1 as steering
actor = make_actor(2)
critic = make_critic()

# Resume from the checkpoints stored for episode 1090
load_model(actor, "actor", 1090)
load_model(critic, "critic", 1090)

# Separate copy of the actor that starts with identical weights; it is the
# model used to collect rollouts in the reinforcement-learning section
actor_old = make_actor(2)
actor_old.load_state_dict(actor.state_dict())

In [None]:
def count_parameters(model: nn.Module) -> int:
    """Return the total number of scalar entries over all parameters of ``model``."""
    total = 0
    for parameter in model.parameters():
        # Product of the dimensions; an empty size contributes 1
        total += reduce(lambda acc, dim: acc * dim, parameter.size(), 1)
    return total

# Displayed as the cell output: number of scalar weights in the actor
count_parameters(actor)

## Learning

In [None]:
Self = typing.TypeVar("Self", bound="Buffer")


class Buffer:
    """
    Shared behaviour for the buffer dataclasses.

    A buffer is a set of parallel columns stored as instance attributes;
    subclasses must provide at least a ``reward`` column.
    """

    def __len__(self) -> int:
        # The ``reward`` column serves as the reference length
        return len(self.reward)

    def __iter__(self):
        # Yield one single-row buffer per position
        yield from map(self.__getitem__, range(len(self)))

    def __getitem__(self, index: typing.Union[int, slice]) -> Self:
        """Return a buffer of the same type holding ``column[index]`` for every column."""
        selected = {}
        for column_name, column in vars(self).items():
            selected[column_name] = column[index]
        return type(self)(**selected)

    def add(self, **kwargs) -> None:
        """Append each keyword value to the column of the same name."""
        columns = self.__dict__
        for column_name, value in kwargs.items():
            columns[column_name].append(value)

    def clear(self) -> None:
        """Empty every column in place."""
        for column in vars(self).values():
            del column[:]

    def batches(self, batch_size) -> typing.Generator[Self, None, None]:
        """Yield consecutive slices of at most ``batch_size`` rows."""
        starts = range(0, len(self), batch_size)
        yield from (self[start : start + batch_size] for start in starts)

In [None]:
def prepare_obs(obs: typing.Iterable[np.ndarray]) -> typing.Iterable[torch.Tensor]:
    """Convert every array of an observation into a float32 tensor, keeping the order."""
    tensors = []
    for array in obs:
        tensors.append(torch.from_numpy(array).float())
    return tensors

### Imitation learning

In [None]:
@dataclass
class ImitationBuffer(Buffer):
    """
    Columns of one demonstration trajectory used for behaviour cloning
    """
    
    # Per-step observation, stored as the list of tensors made by ``prepare_obs``
    observation: list[list[torch.Tensor]] = field(default_factory=list)
    # Per-step reward as recorded in the demonstration
    reward: list[float] = field(default_factory=list)
    # Per-step regression target for the critic
    state_value: list[float] = field(default_factory=list)

    # Per-step demonstrated action
    action: list[torch.Tensor] = field(default_factory=list)


In [None]:
# Hyper-parameters of the behaviour-cloning (imitation) phase
LR = 3e-4              # Adam learning rate for both optimizers below
GAMMA = 0.99           # discount factor used for the critic targets
N_EPOCHS = 3000        # passes over the demonstration dataset
MINIBATCH_SIZE = 20    # consecutive steps per optimisation step

ENTROPY_K = 0.01       # weight of the entropy term added to the log-likelihood

# Separate optimizers so actor and critic are stepped independently
actor_bc_optim = optim.Adam(actor.parameters(), LR)
critic_bc_optim = optim.Adam(critic.parameters(), LR)

In [None]:
# ``info_list`` holds the recorded steps that the next cell iterates over;
# ``spec`` and ``total`` are not referenced elsewhere in the visible notebook.
spec, info_list, total = load_demonstration("demos/record.demo")

In [None]:
# One ImitationBuffer per recorded episode
dataset: list[ImitationBuffer] = []

buffer = ImitationBuffer()

for record in info_list:

    observations = record.agent_info.observations
    # NOTE(review): the camera frame is decoded to whatever array PIL yields
    # (typically uint8, height x width x channels, values 0-255) -- confirm
    # that this matches the scale and axis layout of the live environment's
    # camera observation before mixing both data sources.
    observation = Observation(
        camera=np.array(Image.open(BytesIO(observations[0].compressed_data))),
        ifr_l=np.array([observations[1].float_data.data]),
        ifr_r=np.array([observations[2].float_data.data])
    )

    action = torch.FloatTensor(record.action_info.continuous_actions)
    # Map steering from [-1, 1] to [0, 1] so it lies in the support of the
    # actor's Beta policy. This is the only place demonstrations are rescaled.
    action[1] = (action[1] + 1) / 2
    buffer.add(
        observation=prepare_obs(observation),
        reward=record.agent_info.reward,
        action=action
    )

    if record.agent_info.done:

        # Discounted return-to-go: G_t = r_t + GAMMA * G_{t+1}. The recursion
        # has to run from the LAST step backwards; the values are appended in
        # that backward order and flipped afterwards so that state_value[t]
        # lines up with reward[t].
        r = 0
        for reward in reversed(buffer.reward):
            r = reward + GAMMA * r
            buffer.add(state_value=r)

        buffer.state_value.reverse()

        dataset.append(buffer)
        buffer = ImitationBuffer()

In [None]:
def train_bc(batch: ImitationBuffer, values: torch.Tensor, logprobs: torch.Tensor) -> None:
    """
    Run one behaviour-cloning optimisation step for the critic and the actor.

    The critic is regressed (MSE) onto the stored ``state_value`` targets of
    ``batch``; the actor minimises the negative mean of ``logprobs``.
    """
    targets = torch.tensor(batch.state_value).unsqueeze(1)

    critic_loss = F.mse_loss(values, targets)
    actor_loss = -logprobs.mean()

    # Critic first, then actor; each gets its own zero_grad / backward / step
    updates = ((critic_bc_optim, critic_loss), (actor_bc_optim, actor_loss))
    for optimizer, loss in updates:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [None]:
def sample_value_logprob(batch: ImitationBuffer) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Evaluate the critic and the actor on every step of ``batch`` in order.

    For the first step no explicit memory is passed, so both encoders continue
    from the recurrent state they currently hold; for later steps the state
    left by the previous step is passed on.

    The stored demonstration actions are used as-is. The dataset built in this
    notebook already maps steering to [0, 1] once when it is created; rescaling
    it here again -- in place, on the tensor the dataset owns -- compounded on
    every epoch and pushed the targets towards 1.

    Returns:
        ``(values, logprobs)``: the critic outputs and the demonstrated
        actions' log-probabilities plus ``ENTROPY_K`` times the policy entropy,
        each stacked over the steps of the batch.
    """
    actor_memory = None
    critic_memory = None

    values = []
    logprobs = []

    for item in batch:
        obs = item.observation

        policy = actor.get_policy([obs], actor_memory)
        # NOTE(review): a demonstrated action of exactly 0 or 1 can have an
        # infinite Beta log-density -- consider clamping if losses turn inf/NaN.
        logprobs.append(policy.log_prob(item.action) + ENTROPY_K * policy.entropy())

        values.append(critic([obs], critic_memory))

        critic_memory = critic.encoder.memory
        actor_memory = actor.encoder.memory

    values = torch.stack(values).squeeze(1)
    logprobs = torch.stack(logprobs).squeeze(1)

    return values, logprobs

In [None]:
# Behaviour-cloning loop: every epoch visits all demonstration trajectories in
# a fresh random order
tepoch = trange(N_EPOCHS)
for epoch in tepoch:

    np.random.shuffle(dataset)

    for trajectory in dataset:

        # Each trajectory starts from an all-zero recurrent state
        actor.encoder.reset()
        critic.encoder.reset()    
        
        # Minibatches are consecutive slices, so the recurrent state left by
        # one minibatch is the starting state of the next one
        for batch in trajectory.batches(MINIBATCH_SIZE):

            train_bc(batch, *sample_value_logprob(batch))


In [None]:
# NOTE(review): this writes to the same episode-1090 checkpoints that were
# loaded when the models were created -- confirm that overwriting is intended.
save_model(actor, "actor", 1090)
save_model(critic, "critic", 1090)

### Reinforcement learning

In [None]:
@dataclass
class ReplayBuffer(Buffer):    
    """
    Columns of one short rollout collected for a PPO update
    """

    # Per-step observation, stored as the list of tensors made by ``prepare_obs``
    observation: list[list[torch.Tensor]] = field(default_factory=list)

    # Recurrent (hidden, cell) state of each network recorded for the step
    actor_memory: list[tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
    critic_memory: list[tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)

    # Sampled action and its log-probability under the sampling policy
    action: list[torch.Tensor] = field(default_factory=list)
    old_logprob: list[torch.Tensor] = field(default_factory=list)
    
    # Whether the step ended the episode, and the reward it produced
    is_terminal: list[bool] = field(default_factory=list)
    reward: list[float] = field(default_factory=list)


In [None]:
# Hyper-parameters of the PPO phase. LR, GAMMA, ENTROPY_K and ``buffer`` rebind
# names that the imitation-learning section also defines.
LR = 3e-4          # Adam learning rate for both optimizers below
GAMMA = 0.99       # discount factor
LAMBDA = 0.95      # extra decay multiplied with GAMMA when building return targets
N_EPISODES = 3000  # number of training episodes
N_STEPS = 400      # cap on environment steps per episode
EPS_CLIP = 0.2     # clipping range of the probability ratio
K_EPOCHS = 10      # optimisation passes over each rollout
T_HORIZON = 5      # maximum environment steps per rollout between updates

ENTROPY_K = 0.01   # weight of the entropy term in the actor loss

buffer = ReplayBuffer()

actor_optim = optim.Adam(actor.parameters(), LR)
critic_optim = optim.Adam(critic.parameters(), LR)

In [None]:
def prepare_memory(memory: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Batch a list of per-step ``(hidden, cell)`` pairs.

    All hidden states are stacked into one tensor and all cell states into
    another; the size-1 dimension at position 1 left by the stacking is removed.
    """
    # zip(*memory) regroups the pairs into (all hidden states, all cell states)
    hid, cell = (torch.stack(states).squeeze(1) for states in zip(*memory))

    return hid, cell


In [None]:
def train_ppo(buffer: ReplayBuffer, Qval):
    """
    Run ``K_EPOCHS`` PPO updates of the actor and the critic on one rollout.

    Args:
        buffer: the rollout; observations, recorded recurrent states, actions,
            sampling log-probabilities, rewards and terminal flags per step.
        Qval: bootstrap value of the state that follows the last step.

    After the updates the weights of ``actor`` are copied into ``actor_old``.
    """
    n_steps = len(buffer)

    # Return targets, built backwards from the bootstrap value. float32 keeps
    # them in the same dtype as the network outputs.
    # NOTE(review): the recursion discounts by LAMBDA * GAMMA, which is a
    # lambda-damped return rather than textbook GAE -- confirm this is intended.
    Qvals = np.zeros((n_steps, 1), dtype=np.float32)

    for i, item in enumerate(reversed(buffer)):
        Qval = item.reward + (1 - item.is_terminal) * LAMBDA * GAMMA * Qval
        Qvals[n_steps - i - 1] = Qval

    Qvals = torch.tensor(Qvals)

    # Every stored action / log-prob has a leading size-1 batch dimension.
    # Remove exactly that dimension from both so they line up row by row with
    # the policy output; if only one of them is squeezed, the subtraction
    # below broadcasts to (steps, steps, actions) and mixes timesteps.
    actions = torch.stack(buffer.action).squeeze(1)
    old_logprobs = torch.stack(buffer.old_logprob).squeeze(1)

    # The recorded recurrent states do not change between epochs
    critic_memory = prepare_memory(buffer.critic_memory)
    actor_memory = prepare_memory(buffer.actor_memory)

    for _ in range(K_EPOCHS):

        values = critic(buffer.observation, critic_memory)

        advantage = Qvals - values

        policy = actor.get_policy(buffer.observation, actor_memory)
        logprobs = policy.log_prob(actions)

        ratios = torch.exp(logprobs - old_logprobs)

        surr1 = ratios * advantage.detach()
        surr2 = torch.clamp(ratios, 1 - EPS_CLIP, 1 + EPS_CLIP) * advantage.detach()

        critic_loss = advantage.pow(2).mean()
        # The entropy bonus is SUBTRACTED from the loss so that minimising the
        # loss encourages exploration
        actor_loss = -torch.min(surr1, surr2).mean() - ENTROPY_K * policy.entropy().mean()

        critic_optim.zero_grad()
        critic_loss.backward()
        critic_optim.step()

        actor_optim.zero_grad()
        actor_loss.backward()
        actor_optim.step()

    actor_old.load_state_dict(actor.state_dict())


In [None]:
all_lengths = []
all_rewards = []

tepisodes = trange(N_EPISODES)
for episode in tepisodes:

    # Recurrent state carried from one rollout to the next; None makes the
    # encoders start the episode from zeros
    actor_memory = critic_memory = None

    rewards = []

    steps = 0
    done = False
    obs = env.reset()

    while not done:
        buffer.clear()

        # Continue from the state the previous rollout ended with. This also
        # discards the batched state a preceding PPO update left in the critic.
        actor_old.encoder.reset(actor_memory)
        critic.encoder.reset(critic_memory)

        for t in range(T_HORIZON):
            obs = prepare_obs(obs)

            # Record the recurrent state BEFORE this observation is consumed
            buffer.add(
                observation=obs,
                actor_memory=actor_old.encoder.memory,
                critic_memory=critic.encoder.memory
            )

            policy = actor_old.get_policy([obs])
            action = policy.sample()

            # Step the critic on the same observation so that its LSTM state
            # follows the episode too. Without this call the critic state that
            # is recorded and carried over never changes from its reset value.
            with torch.no_grad():
                critic([obs])

            buffer.add(action=action, old_logprob=policy.log_prob(action).detach())
            
            action = action.clone().numpy()
            action[:, 1] = (2 * action[:, 1]) - 1 # Move steering action from space [0, 1] to [-1, 1]
            action = ActionTuple(action)

            new_obs, reward, done, _ = env.step(action)

            rewards.append(reward)

            buffer.add(reward=reward, is_terminal=done)
            
            steps += 1

            obs = new_obs

            if done or steps == N_STEPS: break

        # Capture the states before the bootstrap call below advances the
        # critic once more; the next rollout resumes from these
        actor_memory = actor_old.encoder.memory
        critic_memory = critic.encoder.memory

        # Bootstrap value of the state following the rollout
        Qval = critic([prepare_obs(obs)]).detach().numpy()
        train_ppo(buffer, Qval)

        if steps == N_STEPS: break

 
    s = np.sum(rewards)
    all_rewards.append(s)
    all_lengths.append(steps)
    tepisodes.set_postfix(reward=s, len=steps)


In [None]:
# Store the PPO-trained weights under episode number 2080
save_model(actor, "actor", 2080)
save_model(critic, "critic", 2080)

In [None]:
# Summed reward of every training episode
plt.figure(figsize=(20, 10))
plt.plot(all_rewards)

## Evaluation

In [None]:
import time

done = False

actor.eval()

obs = env.reset()
# Start the episode from an all-zero recurrent state
actor.encoder.reset()

time.sleep(1)

while not done:
    obs = prepare_obs(obs)

    with torch.no_grad():
        # Sample from the same model that was put into eval mode and whose
        # memory was reset above. Sampling from ``actor_old`` would use a
        # model with stale recurrent state (and stale weights if only the
        # imitation phase was run).
        action = actor.get_policy([obs]).sample().numpy()
        # Move steering from [0, 1] back to [-1, 1]
        action[0, 1] = 2 * action[0, 1] - 1

        action = ActionTuple(action)

        obs, _, done, _ = env.step(action)

        # Pause between steps so the run can be watched (about 60 steps per second)
        time.sleep(1 / 60)

## Close environment

In [None]:
# Close the Unity environment handle opened at the top of the notebook
unity_env.close()

## Tests