In [None]:
import pandas as pd
%load_ext autoreload
%autoreload 2

In [None]:
import holoviews as hv
import panel as pn
import plangym
from plangym.utils import process_frame
hv.extension('bokeh')

In [None]:
pn.extension()

In [None]:
# Build the planning environment used by the fractal tree below.
# One keyword per line so individual settings are easy to tweak and diff.
env = plangym.make(
    'PlanMontezuma-v0',
    obs_type="coords",
    return_image=True,
    frameskip=3,
    check_death=False,
    n_workers=18,
    ray=True,
    episodic_life=True,
)

In [None]:
import param
import time
import numpy as  np


class FaiRunner(param.Parameterized):
    """Panel control bar that drives a search object one ``step_tree()`` at a time.

    The runner owns play / pause / reset / step buttons, a progress bar, a
    one-row summary table and two numeric inputs (sleep time per step and
    reporting interval). A periodic callback calls :meth:`run`, which only
    does work while ``is_running`` is True.

    Args:
        fai: Search object. The runner calls ``fai.reset()``, ``fai.step_tree()``
            and ``fai.summary()`` and reads ``fai.oobs``, ``fai.n_walkers`` and
            ``fai.iteration``.
        n_steps: Number of steps after which the run is considered finished;
            also the maximum of the progress bar.
        plot: Optional display object exposing ``reset(fai)`` and ``send(fai)``.
        report_interval: Initial value of the reporting-interval input; the
            table and plot are refreshed when ``fai.iteration`` is a multiple
            of it.
    """

    is_running = param.Boolean(default=False)

    def __init__(self, fai, n_steps, plot=None, report_interval=100):
        super().__init__()
        self.reset_btn = pn.widgets.Button(icon="restore", button_type="primary")
        self.play_btn = pn.widgets.Button(icon="player-play", button_type="primary")
        self.pause_btn = pn.widgets.Button(icon="player-pause", button_type="primary")
        self.step_btn = pn.widgets.Button(name="Step", button_type="primary")
        self.progress = pn.indicators.Progress(name="Progress", value=0, width=600, max=n_steps, bar_color="primary")
        self.sleep_val = pn.widgets.FloatInput(value=0.0, width=60)
        self.report_interval = pn.widgets.IntInput(value=report_interval)
        self.table = pn.widgets.DataFrame(height=50, width=600)
        self.fai = fai
        self.n_steps = n_steps
        self.curr_step = 0
        self.plot = plot
        # Handle of the periodic callback driving `run`; created lazily in
        # `__panel__` so that re-rendering the runner does not stack callbacks.
        self._periodic_cb = None

    def _publish_summary(self):
        """Refresh the summary table and, if present, push the state to the plot."""
        self.table.value = pd.DataFrame(self.fai.summary(), index=[0])
        if self.plot is not None:
            self.plot.send(self.fai)

    @param.depends("reset_btn.value")
    def on_reset_click(self):
        """Reset the search, the step counter and every widget to its initial state."""
        self.fai.reset()
        self.curr_step = 0
        # Keep the bar in sync with the step counter (it previously jumped to 1).
        self.progress.value = 0
        self.play_btn.disabled = False
        self.pause_btn.disabled = True
        self.step_btn.disabled = False
        self.is_running = False
        self.progress.bar_color = "primary"
        if self.plot is not None:
            self.plot.reset(self.fai)
        self._publish_summary()

    @param.depends("play_btn.value")
    def on_play_click(self):
        """Start continuous stepping."""
        self.play_btn.disabled = True
        self.pause_btn.disabled = False
        self.is_running = True

    @param.depends("pause_btn.clicks")
    def on_pause_click(self):
        """Pause continuous stepping."""
        self.play_btn.disabled = False
        self.pause_btn.disabled = True
        self.is_running = False

    @param.depends("step_btn.value")
    def on_step_click(self):
        """Run exactly one step, leaving the runner stopped afterwards."""
        self.is_running = True
        self.run()
        self.is_running = False

    def run(self):
        """Advance the search by one step if the runner is active.

        Updates the progress bar, stops when ``n_steps`` is reached (bar turns
        "success") or when all walkers but one are flagged in ``fai.oobs`` (bar
        turns "danger"), and periodically refreshes the table and plot.
        """
        if not self.is_running:
            return
        self.fai.step_tree()
        self.curr_step += 1
        self.progress.value = self.curr_step
        if self.curr_step >= self.n_steps:
            self.is_running = False
            self.progress.bar_color = "success"
            self.step_btn.disabled = True
            self.play_btn.disabled = True
            self.pause_btn.disabled = True

        if self.fai.oobs.sum().cpu().item() == self.fai.n_walkers - 1:
            self.is_running = False
            self.progress.bar_color = "danger"

        # The input widget can be cleared or set to 0 by the user; fall back to
        # reporting every step instead of failing on `% 0` / `% None`.
        interval = self.report_interval.value
        if not interval or interval < 1:
            interval = 1
        if self.fai.iteration % interval == 0:
            self._publish_summary()

        # An empty or negative sleep input must not crash the callback.
        # NOTE(review): sleeping here blocks the callback's thread/event loop.
        delay = self.sleep_val.value
        if delay and delay > 0:
            time.sleep(delay)

    def __panel__(self):
        """Return the widget layout, registering the periodic callback only once."""
        if self._periodic_cb is None:
            self._periodic_cb = pn.state.add_periodic_callback(self.run, period=1)
        return pn.Column(
            self.table,
            self.progress,
            pn.Row(
                self.play_btn,
                self.pause_btn,
                self.reset_btn,
                self.step_btn,
                pn.pane.Markdown("**Sleep**"),
                self.sleep_val,
                self.report_interval,
            ),
            self.on_play_click,
            self.on_pause_click,
            self.on_reset_click,
            self.on_step_click,
        )
    
        

In [None]:
import param
import panel as pn
from plangym.utils import process_frame
import numpy as  np
from fragile.shaolin.stream_plots import RGB, Image
from  holoviews.streams import Pipe

def draw_tree(data):
    """Build a HoloViews overlay of the search tree (edges plus nodes).

    Args:
        data: Object exposing ``observ`` and ``parent`` attributes that support
            ``.cpu().numpy()``. A falsy value (e.g. the ``None`` a stream holds
            before anything is sent) yields an empty overlay.

    Returns:
        ``hv.Segments * hv.Scatter`` overlay. Column 0 of ``observ`` is halved
        and mapped to x, column 1 is flipped vertically and mapped to y, both
        shifted by -0.5.
    """
    if not data:
        # `bgcolor` is a plot option, not a constructor Parameter, so it has to
        # go through `.opts()` exactly like in the non-empty branch below.
        empty_edges = hv.Segments([]).opts(bgcolor=None)
        empty_nodes = hv.Scatter([]).opts(bgcolor=None)
        return empty_edges * empty_nodes
    observ = data.observ.cpu().numpy().astype(np.int64)
    parents = data.parent.cpu().numpy()
    # Assigning the float result back into the int array truncates it.
    observ[:, 0] = observ[:, 0] / 2
    obs_x = (observ[:, 0]) / 160 -0.5
    obs_y = (159 - observ[:, 1]) / 159 -0.5
    # One segment per node, drawn from its parent's position to its own.
    segs = obs_x[parents], obs_y[parents], obs_x, obs_y
    edges = hv.Segments(segs).opts(line_color="black", bgcolor=None)
    nodes = hv.Scatter((obs_x, obs_y)).opts(size=2, bgcolor=None, color="red")
    return edges * nodes

class MontezumaDisplay:
    """Streaming plots for a search run: best frame, room images, visit counts, tree.

    State kept per room (``N_ROOMS`` rooms, ``FRAME_SIZE`` x ``FRAME_SIZE`` cells):
    a grayscale room image and a visit-count map whose unvisited cells are NaN.
    """

    N_ROOMS = 24
    FRAME_SIZE = 160

    def __init__(self, ):
        self.best_rgb = RGB()
        self.room_images = [Image(cmap="greys") for _ in range(self.N_ROOMS)]
        self.visits_images = [
            Image(alpha=0.7, xaxis=None, yaxis=None, cmap="fire") for _ in range(self.N_ROOMS)
        ]
        self.visits = self._empty_visits()
        self.rooms = np.zeros(self._grid_shape())
        self.visited_rooms = []
        self.pipe = Pipe()
        self._curr_best = -1
        self.tree = hv.DynamicMap(draw_tree, streams=[self.pipe])

    def _grid_shape(self):
        """Shape of the per-room arrays: (rooms, rows, columns)."""
        return (self.N_ROOMS, self.FRAME_SIZE, self.FRAME_SIZE)

    def _empty_visits(self):
        """Return a float visit map filled with NaN (NaN marks unvisited cells)."""
        return np.full(self._grid_shape(), np.nan)

    def reset(self, fai):
        """Forget visited rooms, visit counts and the cached best-walker index.

        ``fai`` is accepted for interface compatibility and is not used.
        """
        self.visited_rooms = []
        self.visits = self._empty_visits()
        self.rooms = np.zeros(self._grid_shape())
        # Without this the best frame is not re-sent after a reset when the
        # best walker happens to keep the same index.
        self._curr_best = -1

    def send(self, fai):
        """Push the current state of ``fai`` to every stream.

        Reads ``fai.cum_reward``, ``fai.rgb`` and ``fai.observ``; ``observ``
        columns 0, 1, 2 are used as x (halved), y and room index.
        """
        best_ix = fai.cum_reward.argmax().cpu().item()
        # NOTE(review): the frame is only refreshed when the best *index*
        # changes, not when that walker's frame changes - confirm intent.
        if best_ix != self._curr_best:
            self.best_rgb.send(fai.rgb[best_ix])
            self._curr_best = best_ix

        observ = fai.observ.cpu().numpy().astype(np.int64)
        observ[:, 0] = observ[:, 0] / 2
        room_ix = observ[:, 2]
        cells = (room_ix, observ[:, 1], observ[:, 0])
        # Turn NaN (never visited) into 0 for the touched cells, then add one
        # per walker. `np.add.at` is unbuffered, so several walkers sharing a
        # cell are all counted; `a[idx] = a[idx] + 1` would count them once.
        touched = self.visits[cells]
        self.visits[cells] = np.where(np.isnan(touched), 0.0, touched)
        np.add.at(self.visits, cells, 1)

        for ix in np.unique(room_ix):
            if ix not in self.visited_rooms:
                self.visited_rooms.append(ix)
                # First walker currently in this room provides the room image.
                batch_ix = np.argmax(room_ix == ix)
                self.rooms[ix] = process_frame(fai.rgb[batch_ix][50:], mode="L").copy()
                self.room_images[ix].send(self.rooms[ix])
            self.visits_images[ix].send(self.visits[ix])
        self.pipe.send(fai)

    def __panel__(self):
        """Lay out the best frame, room 1 with its visit map, and room 1 with the tree."""
        return pn.Row(self.best_rgb.plot,
                      self.room_images[1].plot * self.visits_images[1].plot,
                      self.room_images[1].plot * self.tree)
                

In [None]:
from fragile.montezuma import FractalTree
# Number of walkers handed to the fractal tree.
n_walkers = 250
# Streaming display that the runner feeds via plot.reset(fai) / plot.send(fai).
plot = MontezumaDisplay()
# Search object built on the `env` created earlier in the notebook.
fai = FractalTree(n_walkers=n_walkers, env=env, device="cpu")

In [None]:
# Control bar for the search; the second argument is the step budget (n_steps).
runner = FaiRunner(fai, 1000000, plot=plot)
# Show the controls above the plots.
pn.panel(pn.Column(runner, plot))

In [None]:
hv.RGB(fai.rgb[2])

In [None]:
np.argmax(fai.observ[:, 2] == 5)

In [None]:
np.unique(fai.observ[:, 2])

In [None]:
batch_ix = 1
plot.rooms[5] = process_frame(fai.rgb[batch_ix][50:], mode="L").copy()

In [None]:
hv.Image(plot.rooms[5])

In [None]:
def draw_pyramid(rooms):
    """Arrange the room images in the pyramid layout given by ``PYRAMID``.

    Args:
        rooms: Indexable collection of 2-D arrays, one per room index used in
            ``PYRAMID`` (0-23).

    Returns:
        A ``pn.Column`` of ``pn.Row`` objects, one row per pyramid level. Grid
        positions marked ``-1`` are filled with an all-NaN image of the same
        shape as ``rooms[0]``.
    """
    height = 120
    width = 120
    PYRAMID = [
        [-1, -1, -1, 0, 1, 2, -1, -1, -1],
        [-1, -1, 3, 4, 5, 6, 7, -1, -1],
        [-1, 8, 9, 10, 11, 12, 13, 14, -1],
        [15, 16, 17, 18, 19, 20, 21, 22, 23],
    ]
    img_opts = dict(
        xaxis=None, yaxis=None, width=width, height=height,
        backend_opts={"plot.toolbar.autohide": True}, bgcolor=None, cmap="greys",
        normalize=True, framewise=True,
    )
    empty = np.ones_like(rooms[0]) * np.nan
    all_rows = []
    for rooms_row in PYRAMID:
        curr_row = []
        for ix in rooms_row:
            arr = empty if ix == -1 else rooms[ix]
            # The options dict is `img_opts`; the previous `.opts(opts)`
            # referenced an undefined name.
            img = hv.Image(arr).opts(**img_opts)
            curr_row.append(pn.pane.HoloViews(img, sizing_mode='fixed', margin=(-10, -15), min_height=height, min_width=width))
        curr_row = pn.Row(*curr_row, sizing_mode='fixed', margin=(0, 0))
        all_rows.append(curr_row)
    return pn.Column(*all_rows, sizing_mode='fixed', margin=(0, 0))
    
            

In [None]:
plot.rooms[1].max()

In [None]:
draw_pyramid(plot.rooms)

In [None]:
draw_pyramid(plot.rooms)

In [None]:
# `room_images` is a plain Python list of stream plots, so it has no `.shape`.
len(plot.room_images)

In [None]:
plot.best_rgb.plot

In [None]:
fai.observ

In [None]:

# Derive the room indices directly from `fai` so this cell does not depend on
# an `observ` variable that is only defined in a later cell.
room_ix = fai.observ.cpu().numpy().astype(np.int64)[:, 2]
(hv.Image(plot.rooms[1]).opts(cmap="greys") * 
 hv.Image(plot.visits[1]).opts(alpha=0.7, cmap="fire") * draw_tree(fai)
 
)

In [None]:
# Scatter of the walkers' coordinates, scaled by 1/160 and shifted by -0.5.
observ = fai.observ.cpu().numpy().astype(np.int64)
# Assigning the float result back into the int array truncates it.
observ[:, 0] = observ[:, 0] / 2
# NOTE(review): here column 0 feeds y and column 1 feeds x, the opposite of
# draw_tree (column 0 -> x, column 1 flipped -> y) - confirm which is intended.
obs_y = observ[:, 0] / 160 -0.5
obs_x = observ[:, 1] / 160 -0.5
room_ix = observ[:, 2]
hv.Scatter((obs_x, obs_y))

In [None]:
hv.Scatter((obs_x, obs_y))

In [None]:
from fragile.shaolin.stream_plots import RGB, Image

rgb = Image()

In [None]:
obs = fai.observ.cpu().numpy().astype(np.int64)
obs[:, 0] = obs[:, 0]/2
obs[:, :3]

In [None]:
visits = np.zeros((24, 160, 160), dtype=np.int32)


In [None]:
np.unique(obs[:, 2])

In [None]:
# `visits[idx] += 1` increments a cell only once even when several walkers
# share it; `np.add.at` is unbuffered and counts every occurrence.
np.add.at(visits, (obs[:, 2], obs[:, 1], obs[:, 0]), 1)

In [None]:
# Imported here because the notebook's matplotlib import lives in a later cell,
# so a fresh top-to-bottom run would otherwise hit a NameError.
import matplotlib.pyplot as plt

plt.imshow(visits[1, :, :])

In [None]:
# NOTE(review): at this point `rgb` is the `Image()` stream plot created in an
# earlier cell, not a frame array - presumably a frame such as `fai.rgb[0]`
# was intended; confirm before relying on this cell.
x = process_frame(rgb, mode="L")
# NOTE(review): `plt` is only imported in a later cell of this notebook.
plt.imshow(x)

In [None]:
from plangym.utils import process_frame

In [None]:
fai.rgb[0][50:].shape

In [None]:
rgb.plot

In [None]:
# `self` does not exist at cell top level; this line appears to be pasted from
# a method, so it is pointed at the `fai` instance instead.
# NOTE(review): assumes `fai` exposes `env`, `state` and `action` - confirm.
data = fai.env.step_batch(states=fai.state, actions=fai.action)
new_states, observ, reward, oobs, _truncateds, infos = data

In [None]:
oobs

In [None]:
_truncateds

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.imshow(infos[0]["rgb"])

In [None]:
# `is_running` is a parameter of FaiRunner itself; there is no `control` attribute.
runner.is_running

In [None]:
runner.progress.bar_color = "danger"