# RL environment setup
https://stable-baselines3.readthedocs.io/en/v1.0/guide/imitation.html

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gymnasium as gym

from PIL import Image, ImageOps
from gymnasium import Env, ActionWrapper, ObservationWrapper, RewardWrapper, Wrapper
from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete
from stable_baselines3 import PPO, A2C
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.policies import obs_as_tensor
from torch.nn import AdaptiveAvgPool2d

from IPython.display import display, clear_output
from pathlib import Path
from typing import Callable
from multiprocessing import cpu_count

In [None]:
from scripts.render import AgentView
from scripts.backbone import *
from scripts.dataset import Normalize

In [None]:
VIEW_SIZE = 128   # side length of the agent's square viewport
CHANNELS = 64     # CNN backbone width (must match the checkpoint below: "...-64-4-...")
DEPTH = 4         # CNN backbone depth (must match the checkpoint below)
LATENT_DIM = 128  # size of the embedding produced by the projection head

# Frozen CNN encoder + average pooling, followed by a projection head to LATENT_DIM.
# map_location lets checkpoints saved on another device (e.g. GPU) load onto DEVICE.
# NOTE(review): 512 is the projection input width, presumably the pooled backbone width — confirm.
# NOTE(review): no .eval() is called here; confirm VisualEncoder(frozen=True) handles inference mode.
visual_backbone = CNNEncoder(out_channels=CHANNELS, depth=DEPTH, residual=True).to(DEVICE)
visual_backbone.load_state_dict(torch.load('./models/visual-encoder-CNN-R-64-4-S.pt', map_location=DEVICE))
visual_encoder = VisualEncoder(visual_backbone, reduce=AdaptiveAvgPool2d((1, 1)), frozen=True)
visual_projection = VisualProjection(visual_encoder, 512, LATENT_DIM).to(DEVICE)
visual_projection.load_state_dict(torch.load('./models/visual-projection-CNN.pt', map_location=DEVICE))

In [None]:
# Page images that have a semantic-segmentation mask available (bare file names;
# the same names are later opened from ./data/images/).
# Path.name is portable, unlike splitting str(path) on '/' (fails with Windows separators).
images = [path.name for path in Path('./data/masks').glob('*.png')]

# Group pages by document: the document key is the file name with its last
# '-'-separated part removed. Pages inside each document are kept in sorted order.
docs = {}
for source in sorted(images):
    docs.setdefault('-'.join(source.split('-')[:-1]), []).append(source)

In [None]:
# Preview: what the agent sees at the top of a randomly chosen page.
source = np.random.choice(images)
with Image.open(f'./data/images/{source}') as page:
    # inverted grayscale, the same preprocessing the environments use
    image = 255 - np.array(ImageOps.grayscale(page))
nav = AgentView(image.astype(np.uint8), VIEW_SIZE, bias=0)
observation = nav.top()  # reused by the embedding-plot cell below
plt.imshow(observation, cmap='gray')
plt.show()

## Base-Env

In [None]:
class DocNav(Env):
    """
    Base document-navigation environment built on the AgentView renderer.

    Every episode starts on the first page of `pages`, with the viewport given by
    `AgentView.top()`. Actions are 4-vectors in [-1, 1] forwarded unchanged to
    `AgentView.transform`; observations are the inverted-grayscale viewport.

    Args:
        pages: image file names, loaded from ./data/images/.
        view_size: side length of the square viewport.
        max_episode_steps: truncate after this many steps; None (or 0) means no limit.
        render_mode: stored as `self.render_mode`; only 'rgb_array' is declared.
    """
    metadata = {'render_modes': ['rgb_array']}

    def __init__(self, pages: list, view_size: int,
                 max_episode_steps: int = None, render_mode: str = 'rgb_array'):
        super().__init__()
        self.pages = pages
        self.dim = view_size
        self.max_episode_steps = max_episode_steps or float('inf')
        # previously accepted but silently dropped; Env.render_mode must be set explicitly
        self.render_mode = render_mode
        # renderer-native spaces
        self.action_space = Box(low=-1, high=1, shape=(4,), dtype=np.float32)
        self.observation_space = Box(low=0, high=255, shape=(self.dim, self.dim, 1), dtype=np.uint8)
        self.nav = None

    def render(self):
        """Return the raw visual observation (even when wrapped by an ObservationWrapper)."""
        return self.observation.astype(np.uint8)

    def close(self):
        """Drop the renderer."""
        self.nav = None

    def info(self):
        """Current page index plus the renderer's rotation, zoom and center."""
        state = self.nav.state
        return {'page': self.index, 'rotation': state[2], 'zoom': state[3], 'center': self.nav.loc}

    def reward(self):
        """0 while the viewport is on the page, -1 once it has left it."""
        return 0. if self.nav.isin() else -1.

    def terminated(self):
        """Success flag; subclasses set `self.done` from their reward()."""
        return self.done

    def truncated(self):
        """True when out of steps or when the viewport has left the page."""
        return self.steps >= self.max_episode_steps or not self.nav.isin()

    def reset(self, seed: int = None, options: dict = None) -> tuple:
        """Start a new episode on the first page; returns (observation, info)."""
        super().reset(seed=seed, options=options)  # seeds self.np_random
        self.steps = 0
        self.done = False
        # always start on the first page
        self.index = 0
        source = self.pages[self.index]
        with Image.open(f'./data/images/{source}') as page:
            # inverted grayscale of the page image
            image = 255 - np.array(ImageOps.grayscale(page))
        self.nav = AgentView(image.astype(np.uint8), self.dim, bias=0)
        # initial viewport
        self.observation = self.nav.top()
        return self.observation, self.info()

    def step(self, action: np.ndarray) -> tuple:
        """Apply `action` through the renderer; returns the 5-tuple gymnasium step result."""
        self.steps += 1
        # previous pose, compared against by subclass rewards
        # NOTE(review): stored by reference — confirm AgentView does not mutate `state` in place
        self.last_state = self.nav.state
        self.observation = self.nav.transform(action)
        self.last_action = action
        reward = self.reward()
        terminated, truncated = self.terminated(), self.truncated()
        # The terminated slot deliberately carries `terminated or truncated` (original note:
        # for vector envs, which miss the truncated flag); random_walk relies on this to stop.
        return (self.observation,
                reward,
                terminated or truncated,
                truncated,
                self.info())


In [None]:
# Validate the raw environment against the gymnasium API (short 10-step episodes).
env = DocNav(pages=images, view_size=VIEW_SIZE, max_episode_steps=10)
check_env(env=env, warn=True)

In [None]:
def random_walk(env, scale: float = 1., limit: int = 100):
    """
    Take a random walk through `env` and animate it inline, to sanity-check the env.

    Args:
        env: gymnasium-style env whose info dict carries 'center', 'rotation' and 'zoom'.
        scale: multiplier applied to every sampled action.
        limit: maximum number of steps.

    Stops as soon as `step` reports `terminated`.
    """
    env.reset()
    frame = None  # image artist currently on the axes
    for step in range(limit):
        action = env.action_space.sample() * scale
        _, reward, terminated, truncated, info = env.step(action)
        center, rotation, zoom = info['center'], info['rotation'], info['zoom']
        plt.title((f'Action: {np.round(action, 4)}\nReward: {reward:.4f}\n'
                   f'Steps: {step + 1:<3}  Done: {terminated}   Lost: {truncated}\n'
                   f'Center: {np.round(center, 2)}\nRotation: {rotation:.2f}\nZoom: {zoom:.2f}'),
                  ha='left', x=0, fontdict={'family': 'monospace', 'size': 10})
        if terminated:
            return
        # Replace the previous frame rather than stacking one more image on the axes each
        # step (stacked images make every redraw slower). The last frame stays visible on exit.
        if frame is not None:
            frame.remove()
        frame = plt.imshow(env.render(), cmap='gray')
        display(plt.gcf())
        clear_output(wait=True)


In [None]:
random_walk(env=env, scale=0.1)  # small random moves on the raw env

In [None]:
# Scatter the embedding of the sample observation from the preview cell above.
with torch.no_grad():
    embedding = (
        visual_projection(Normalize(observation).unsqueeze(1).to(DEVICE))
        .unsqueeze(1)
        .cpu()
        .numpy()
    )
plt.scatter(range(len(embedding)), embedding, s=10, c=embedding / 2.5, cmap='rainbow')
plt.title(f'Embedding value: [{np.min(embedding):.0f},{np.max(embedding):.0f}] shape: {embedding.shape}')
plt.show()

In [None]:
class EncodedVisual(ObservationWrapper):
    """
    Observation wrapper that replaces the raw viewport with its embedding vector.

    Args:
        env: environment producing raw visual observations.
        dim: length of the embedding vector (the declared observation size).
        encoder: callable applied to the normalized frame (e.g. the visual projection model).
        device: torch device the encoder input is moved to.
    """
    def __init__(self, env, dim, encoder, device):
        super().__init__(env)
        self.encoder = encoder
        self.device = device
        self.observation_space = Box(low=-1, high=1, shape=(dim,), dtype=np.float32)

    def observation(self, observation) -> np.ndarray:
        """
        Encode one native renderer frame into the operational observation.

        The squeezed encoder output is scaled by 0.25, a heuristic meant to keep values
        inside the declared [-1, 1] box. The values are not clipped.
        """
        with torch.no_grad():
            batch = Normalize(observation).unsqueeze(1).to(self.device)
            features = self.encoder(batch)
        return features.cpu().numpy().squeeze() * 0.25
    

In [None]:
# Same raw env, but observations are now embedding vectors.
env = EncodedVisual(
    DocNav(images, VIEW_SIZE, max_episode_steps=10),
    dim=LATENT_DIM, encoder=visual_projection, device=DEVICE,
)
check_env(env=env, warn=True)

random_walk(env=env, scale=0.1)

## Learning environment
The renderer's native action space may be too complex for standard RL to be useful.

In [None]:
class DiscreteRotate(ActionWrapper):
    """
    Reduce the native 4-D action to three discrete moves on the third component.

    Indices 0/1/2 map to -1/180, 0 and +1/180. The author's intent is one degree
    counter-clockwise, hold, and one degree clockwise; how the renderer scales the value
    is not visible here. The other three components are always 0.
    """
    def __init__(self, env):
        super().__init__(env)
        self.action_space = Discrete(3)
        self.action_value = list(np.array([-1., 0., 1.]) / 180.)

    def action(self, value):
        """Translate a discrete index into the renderer's native action vector."""
        step = self.action_value[int(value)]
        return np.array([0, 0, step, 0])
    

In [None]:
# Encoded observations + three discrete rotation actions (no step limit).
env = DiscreteRotate(EncodedVisual(
    DocNav(pages=images, view_size=VIEW_SIZE),
    dim=LATENT_DIM, encoder=visual_projection, device=DEVICE,
))
check_env(env=env, warn=True)

random_walk(env=env, limit=10)

In [None]:
class PageAlign(DocNav):
    """
    Alignment task on a randomly chosen, randomly posed page.

    Intended for use with a rotation-only action wrapper (e.g. DiscreteRotate). The agent
    should turn toward the nearest of the (0, 90, 180, 270) orientations one degree at a
    time, so a good episode needs fewer than 45 steps. Episodes are truncated after 200
    steps.
    """
    def __init__(self, pages: list, view_size: int):
        super().__init__(pages=pages, view_size=view_size, max_episode_steps=200)

    @staticmethod
    def _misalignment(angle):
        """Distance from `angle` to the nearest multiple of 90."""
        angle = angle % 90
        return min(angle, 90 - angle)

    def reward(self):
        """
        Alignment-only reward; success is any of the (0, 90, 180, 270) orientations.

        Returns:
            -1000 if the viewport has left the page;
            1000 / steps once aligned (also sets self.done, which ends the episode);
            -0.01 when the last action reduced the misalignment;
            -0.1 * steps otherwise (a growing penalty for stalling or moving away).
        """
        if not self.nav.isin():
            return -1000.
        # evaluate the current state
        # NOTE(review): exact float comparison — relies on rotations landing exactly on 0
        curr = self._misalignment(self.nav.state[2])
        if curr == 0:
            self.done = True
            return 1000. / self.steps
        # compare the current and previous states
        prev = self._misalignment(self.last_state[2])
        if prev > curr:  # moved in the right direction
            return -0.01
        return -0.1 * self.steps

    def reset(self, seed: int = None, options: dict = None) -> tuple:
        """Load a random page at a random pose; returns (observation, info)."""
        # Skip DocNav.reset: it would load and render pages[0] only to discard it.
        # Env.reset still seeds self.np_random, which all sampling below uses, so
        # reset(seed=...) gives reproducible episodes.
        super(DocNav, self).reset(seed=seed, options=options)
        self.steps = 0
        self.done = False
        rng = self.np_random
        # load a random page
        self.index = int(rng.integers(len(self.pages)))
        source = self.pages[self.index]
        with Image.open(f'./data/images/{source}') as page:
            image = 255 - np.array(ImageOps.grayscale(page))
        self.nav = AgentView(image.astype(np.uint8), self.dim, bias=0)
        # Random viewport; resample until there is something to see (some pages are
        # half-empty). The cap avoids an endless loop on a blank page: after max_tries
        # the last sample is accepted.
        max_tries = 100
        for _ in range(max_tries):
            center = (np.array(self.nav.space.center) * (0.25 + rng.random() * 1.5)).astype(int)
            rotation = rng.integers(360)
            zoom = -1 - rng.random() * 2.5
            observation = self.nav.set_state(center, rotation, zoom)
            if np.std(observation) >= 10.:
                break
        self.observation = observation
        return self.observation, self.info()
    

In [None]:
def random_env():
    """
    Build the training env: a random page to align (PageAlign), observed through the
    visual embedding (EncodedVisual) and driven by three discrete rotations (DiscreteRotate).
    """
    page_env = PageAlign(images, VIEW_SIZE)
    encoded = EncodedVisual(page_env, LATENT_DIM, visual_projection, DEVICE)
    return DiscreteRotate(encoded)

env = random_env()
check_env(env=env, warn=True)

random_walk(env=env, limit=10)

In [None]:
# One env copy per CPU core for vectorised rollouts.
# NOTE(review): make_vec_env defaults to DummyVecEnv, which steps the copies sequentially
# in this process — confirm whether parallel stepping was intended.
num_cores = cpu_count()
vec = make_vec_env(env_id=random_env, n_envs=num_cores)

In [None]:
#!rm -rf runs/rl-align
#agent = PPO('MlpPolicy', vec, verbose=1, learning_rate=1e-4, tensorboard_log='runs/rl-align/')
#agent.learn(total_timesteps=1e6)
#agent.save('./models/align-PPO')

    ------------------------------------------
    | rollout/                |              |
    |    ep_len_mean          | 34.5         |
    |    ep_rew_mean          | 29.1         |
    | time/                   |              |
    |    fps                  | 61           |
    |    iterations           | 62           |
    |    time_elapsed         | 16582        |
    |    total_timesteps      | 1015808      |
    | train/                  |              |
    |    approx_kl            | 0.0024247682 |
    |    clip_fraction        | 0.0209       |
    |    clip_range           | 0.2          |
    |    entropy_loss         | -0.33        |
    |    explained_variance   | 0.419        |
    |    learning_rate        | 0.0001       |
    |    loss                 | 673          |
    |    n_updates            | 610          |
    |    policy_gradient_loss | -0.00194     |
    |    value_loss           | 1.94e+03     |
    ------------------------------------------
    
    
    tensorboard --bind_all --logdir ./runs

In [None]:
agent = PPO.load(path='./models/align-PPO', env=env, print_system_info=True)  # pretrained alignment policy

In [None]:
def get_action_proba(agent, embedding):
    """
    Return the policy's action probabilities for an observation.

    A leading batch axis is added to `embedding` before it is converted to a tensor on
    the policy's device. The result is the distribution's `probs` as a numpy array, so a
    categorical action distribution is assumed (e.g. a Discrete action space).
    """
    batch = embedding[np.newaxis, :]
    with torch.no_grad():
        tensor = obs_as_tensor(batch, agent.policy.device)
        distribution = agent.policy.get_distribution(tensor).distribution
        return distribution.probs.cpu().numpy()


def run_episode(env, agent, limit=100):
    """
    Let a trained agent drive `env` and animate the episode inline.

    Args:
        env: a single (non-vector) env; it is wrapped in a one-env DummyVecEnv.
        agent: a loaded SB3 model. The action probabilities shown come from
            get_action_proba, which expects a categorical policy.
        limit: maximum number of steps.

    Stops as soon as the vector env reports the episode as done.
    """
    venv = DummyVecEnv([lambda: env])
    venv.render_mode = 'rgb_array'
    observation = venv.reset()
    frame = None  # image artist currently on the axes
    for step in range(limit):
        dist = get_action_proba(agent, observation)[0]
        proba = ', '.join(str(p) for p in np.round(dist, 2))
        action, _ = agent.predict(observation, deterministic=True)
        # VecEnv.step returns a single `done` flag per env (terminated or truncated)
        observation, reward, done, info = venv.step(action)
        center, rotation, zoom = info[0]['center'], info[0]['rotation'], info[0]['zoom']
        plt.title((f'Action: {np.round(action[0], 4)}   Proba: [{proba}]\nReward: {reward[0]:.4f}\n'
                   f'Steps: {step + 1:<3}  Done: {done[0]}\n'
                   f'Center: {np.round(center, 2)}\nRotation: {rotation:.2f}\nZoom: {zoom:.2f}'),
                  ha='left', x=0, fontdict={'family': 'monospace', 'size': 10})
        if done[0]:
            return
        # Replace the previous frame instead of stacking images on the axes every step;
        # the last frame stays visible when the loop exits.
        if frame is not None:
            frame.remove()
        frame = plt.imshow(env.render(), cmap='gray')
        display(plt.gcf())
        clear_output(wait=True)


In [None]:
env = random_env()  # fresh random page and pose
run_episode(env=env, agent=agent, limit=50)

In [None]:
class TestAlign(PageAlign):
    """
    Evaluation variant of PageAlign with identical rewards that never signals `done`.

    Episodes therefore end only by truncation (the step limit or leaving the page),
    which shows whether the agent holds an aligned orientation.
    """
    def reward(self):
        # Reuse PageAlign's reward instead of duplicating it, so the two cannot drift apart.
        value = super().reward()
        # PageAlign.reward marks success by setting self.done; clear it so the episode continues.
        self.done = False
        return value
    

In [None]:
# Same agent on the never-terminating variant: does it stay aligned?
env = DiscreteRotate(EncodedVisual(
    TestAlign(pages=images, view_size=VIEW_SIZE),
    dim=LATENT_DIM, encoder=visual_projection, device=DEVICE,
))
run_episode(env=env, agent=agent, limit=50)