In [None]:
import pandas as pd
%load_ext autoreload
%autoreload 2

In [1]:
import holoviews as hv
import panel as pn
import plangym
from plangym.utils import process_frame
hv.extension('bokeh')

In [2]:
pn.extension("tabulator", theme="dark")

In [3]:
# Shared Montezuma environment used by the cells below (created through
# plangym with `ray=True` and `n_workers=10`).
env = plangym.make(
    "PlanMontezuma-v0",
    obs_type="coords",
    return_image=True,
    frameskip=3,
    check_death=True,
    episodic_life=False,
    n_workers=10,
    ray=True,
)

2024-10-18 07:37:45,319	INFO worker.py:1786 -- Started a local Ray instance.
[36m(RemoteEnv pid=553895)[0m A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[36m(RemoteEnv pid=553895)[0m [Powered by Stella]
A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


In [4]:
import param
import time
import numpy as  np


class FaiRunner(param.Parameterized):
    """Panel control strip that drives a search object one step at a time.

    The runner owns play / pause / reset / step buttons, a progress bar, a
    summary table and a few numeric inputs. A periodic callback registered in
    ``__panel__`` calls :meth:`run`, which advances ``fai`` by one
    ``step_tree()`` call whenever ``is_running`` is true.

    ``fai`` must expose ``step_tree()``, ``reset()``, ``summary()``, ``oobs``,
    ``n_walkers``, ``iteration`` and a writable ``erase_coef`` (these are the
    only members used here). ``plot`` is optional and, when given, must expose
    ``reset(fai)`` and ``send(fai)``.
    """
    # True while the periodic callback is allowed to advance the search.
    is_running = param.Boolean(default=False)

    def __init__(self, fai, n_steps, plot=None, report_interval=100):
        """Create the widgets and store the objects being driven.

        Args:
            fai: Search object advanced by :meth:`run`.
            n_steps: Upper bound of the progress bar and number of steps after
                which the runner stops itself.
            plot: Optional display object refreshed through ``reset``/``send``.
            report_interval: Initial value of the widget controlling how often
                the table and plot are refreshed (compared against
                ``fai.iteration``).
        """
        super().__init__()
        self.reset_btn = pn.widgets.Button(icon="restore", button_type="primary")
        self.play_btn = pn.widgets.Button(icon="player-play", button_type="primary")
        self.pause_btn = pn.widgets.Button(icon="player-pause", button_type="primary")
        self.step_btn = pn.widgets.Button(name="Step", button_type="primary")
        self.progress = pn.indicators.Progress(name="Progress", value=0, width=600, max=n_steps, bar_color="primary")
        # Seconds slept at the end of every `run` call.
        self.sleep_val = pn.widgets.FloatInput(value=0.0, width=60)
        self.report_interval = pn.widgets.IntInput(value=report_interval)
        self.table = pn.widgets.Tabulator()
        self.fai = fai
        self.n_steps = n_steps
        # Number of `step_tree()` calls performed since the last reset.
        self.curr_step = 0
        self.plot = plot
        self.erase_coef_val = pn.widgets.FloatInput(value=0.05, width=60, name="erase")

    @param.depends("erase_coef_val.value")
    def update_erase_coef(self):
        """Copy the value of the "erase" input onto ``fai.erase_coef``."""
        self.fai.erase_coef = self.erase_coef_val.value

    @param.depends("reset_btn.value")
    def on_reset_click(self):
        """Reset the search, the widgets and the optional plot.

        Calls ``fai.reset()``, re-enables play/step, disables pause, stops the
        run loop and pushes a fresh one-row summary to the table.
        """
        self.fai.reset()
        self.curr_step = 0
        # NOTE(review): the bar is reset to 1 although `__init__` starts it at 0
        # and `curr_step` is 0 here — confirm this is intentional.
        self.progress.value = 1
        # NOTE(review): duplicate of the assignment two lines above.
        self.curr_step = 0
        self.play_btn.disabled = False
        self.pause_btn.disabled = True
        self.step_btn.disabled = False
        self.is_running = False
        self.progress.bar_color = "primary"
        summary = pd.DataFrame(self.fai.summary(), index=[0])
        self.table.value = summary
        if self.plot is not None:
            self.plot.reset(self.fai)
            self.plot.send(self.fai)

    @param.depends("play_btn.value")
    def on_play_click(self):
        """Start the run loop and swap the enabled state of play/pause."""
        self.play_btn.disabled = True
        self.pause_btn.disabled = False
        self.is_running = True

    # NOTE(review): this handler depends on `.clicks` while the other button
    # handlers depend on `.value` — confirm the difference is intended.
    @param.depends("pause_btn.clicks")
    def on_pause_click(self):
        """Stop the run loop and swap the enabled state of play/pause."""
        self.play_btn.disabled = False
        self.pause_btn.disabled = True
        self.is_running = False

    @param.depends("step_btn.value")
    def on_step_click(self):
        """Perform exactly one :meth:`run` call.

        ``is_running`` is forced to True for the call and to False afterwards,
        so stepping while playing leaves the runner stopped (the play/pause
        buttons are not updated here).
        """
        self.is_running = True
        self.run()
        self.is_running = False

    def run(self):
        """Advance the search by one step if the runner is active.

        Does nothing when ``is_running`` is false. Otherwise calls
        ``fai.step_tree()``, updates the progress bar, applies the two stop
        conditions below, refreshes the table/plot every ``report_interval``
        iterations and finally sleeps for ``sleep_val`` seconds.
        """
        if not self.is_running:
            return
        self.fai.step_tree()
        self.curr_step += 1
        self.progress.value = self.curr_step
        # Stop condition 1: the step budget is exhausted; lock all controls.
        if self.curr_step >= self.n_steps:
            self.is_running = False
            self.progress.bar_color = "success"
            self.step_btn.disabled = True
            self.play_btn.disabled = True
            self.pause_btn.disabled = True

        # Stop condition 2: the sum of `fai.oobs` equals `n_walkers - 1`
        # (presumably all walkers but one are out of bounds — TODO confirm).
        if self.fai.oobs.sum().cpu().item() == self.fai.n_walkers - 1:
            self.is_running = False
            self.progress.bar_color = "danger"

        if self.fai.iteration % self.report_interval.value == 0:
            summary = pd.DataFrame(self.fai.summary(), index=[0])
            self.table.value = summary
            if self.plot is not None:
                self.plot.send(self.fai)
        # This sleep runs inside the periodic callback itself.
        time.sleep(self.sleep_val.value)

    def __panel__(self):
        """Build the layout and schedule :meth:`run` as a periodic callback.

        The ``param.depends`` methods are placed in the returned column;
        presumably Panel re-invokes them when their dependencies change, which
        is how the button handlers get triggered — verify against Panel docs.
        """
        # NOTE(review): a new periodic callback (period=20) is registered every
        # time this method is called; rendering the same runner twice would
        # schedule `run` twice.
        pn.state.add_periodic_callback(self.run, period=20)
        return pn.Column(
            self.table,
            self.progress,
            pn.Row(
                self.play_btn,
                self.pause_btn,
                self.reset_btn,
                self.step_btn,
                pn.pane.Markdown("**Sleep**"),
                self.sleep_val,
                self.report_interval,
                self.erase_coef_val,
            ),
            self.on_play_click,
            self.on_pause_click,
            self.on_reset_click,
            self.on_step_click,
            self.update_erase_coef,
        )
    
        

In [5]:
from fragile.montezuma import aggregate_visits
from functools import partial
import param
import panel as pn
from plangym.utils import process_frame
import numpy as  np
from fragile.shaolin.stream_plots import RGB, Image
from  holoviews.streams import Pipe

# Room layout as a 4 x 9 grid: each entry is a room id (0-23) and -1 marks a
# grid cell that holds no room.
PYRAMID = [
    [-1, -1, -1, 0, 1, 2, -1, -1, -1],
    [-1, -1, 3, 4, 5, 6, 7, -1, -1],
    [-1, 8, 9, 10, 11, 12, 13, 14, -1],
    [15, 16, 17, 18, 19, 20, 21, 22, 23],
]
# (row, column) grid positions of the -1 cells of PYRAMID; `set_empty_rooms`
# fills these cells with the value 255.
EMPTY_ROOMS = [(0, 0), (0, 1), (0, 2), (0, 6), (0, 7), (0, 8), (1, 0), (1, 1), (1, 7), (1, 8), (2, 0), (2, 8)]

def get_rooms_xy(pyramid=None) -> np.ndarray:
    """Return the grid position of every room id found in ``pyramid``.

    Args:
        pyramid: List of rows (lists of ints) where each entry is a room id and
            negative entries mark cells without a room. Defaults to the
            module-level ``PYRAMID``.

    Returns:
        Integer array with one ``[x, y]`` row per room, ordered by room id,
        where ``x`` is the column index and ``y`` the row index of the room in
        ``pyramid``. Room ids in ``range(max_id + 1)`` that do not appear in
        ``pyramid`` are skipped, so the result is only indexable by room id
        when the ids are contiguous.
    """
    pyramid = pyramid if pyramid is not None else PYRAMID
    n_rooms = max(max(row) for row in pyramid) + 1
    rooms_xy = []
    for room in range(n_rooms):
        # Search the pyramid that was actually requested, not the global
        # default, so that custom layouts are honoured.
        for y, row in enumerate(pyramid):
            if room in row:
                rooms_xy.append([row.index(room), y])
                break
    return np.array(rooms_xy)


def get_pyramid_layout(room_h=160, room_w=160, channels=3, pyramid=None, empty_rooms=None):
    """Allocate a zero-filled canvas for the whole pyramid and mark empty cells.

    The canvas has one ``room_h`` x ``room_w`` tile per grid cell of
    ``pyramid`` (module-level ``PYRAMID`` when omitted) and ``channels``
    channels. The cells listed in ``empty_rooms`` are then filled by
    ``set_empty_rooms``, whose return value is returned.
    """
    if pyramid is None:
        pyramid = PYRAMID
    n_rows = len(pyramid)
    n_cols = len(pyramid[0])
    canvas = np.zeros((n_rows * room_h, n_cols * room_w, channels))
    return set_empty_rooms(canvas, empty_rooms, height=room_h, width=room_w)

def set_empty_rooms(all_rooms, empty_rooms=None, height=160, width=160):
    """Fill the tiles of the given grid cells with 255, in place.

    Args:
        all_rooms: Canvas array indexed as ``[row_pixel, col_pixel, channel]``;
            it is modified in place.
        empty_rooms: Iterable of ``(row, col)`` grid cells to fill. Defaults to
            the module-level ``EMPTY_ROOMS``.
        height: Tile size along the first axis.
        width: Tile size along the second axis.

    Returns:
        The same ``all_rooms`` object that was passed in.
    """
    if empty_rooms is None:
        empty_rooms = EMPTY_ROOMS
    white = np.full(3, 255, dtype=np.uint8)
    for row, col in empty_rooms:
        row_span = slice(row * height, (row + 1) * height)
        col_span = slice(col * width, (col + 1) * width)
        all_rooms[row_span, col_span] = white
    return all_rooms

def draw_rooms(rooms, pyramid_layout=None, height=160, width=160):
    """Paste every room image into its tile of the pyramid canvas.

    Args:
        rooms: Mapping from room id to the image assigned to that room's tile.
        pyramid_layout: Canvas to draw on (modified in place). A fresh canvas
            from ``get_pyramid_layout()`` is used when omitted.
        height, width: Tile sizes used to compute the target slices.

    Returns:
        The canvas with the rooms drawn on it.

    The tile position of each room is looked up through the notebook-level
    ``env`` (``env.gym_env.get_room_xy``).
    """
    if pyramid_layout is None:
        pyramid_layout = get_pyramid_layout()
    for room_id, room_img in rooms.items():
        i, j = env.gym_env.get_room_xy(room_id)
        # NOTE(review): the first canvas axis is sliced in multiples of `width`
        # and the second in multiples of `height`; this only matches the canvas
        # built by `get_pyramid_layout` when height == width — confirm.
        first_axis = slice(j * width, (j + 1) * width)
        second_axis = slice(i * height, (i + 1) * height)
        pyramid_layout[first_axis, second_axis, :] = room_img
    return pyramid_layout

def to_pyramid_coords(observ, room_xy, width=160, height=160):
    """Convert room-local positions into positions on the whole pyramid canvas.

    Args:
        observ: 2-D array whose columns 0, 1 and 2 hold x, y and room id
            (cast to int64 here).
        room_xy: Array indexable by room id that yields the ``[x, y]`` grid
            cell of each room.
        width, height: Size of one room tile along x and y.

    Returns:
        Integer array of shape ``(len(observ), 2)`` with ``[x, y]`` positions
        offset by the origin of each walker's room tile.
    """
    room_ids = observ[:, 2].astype(np.int64)
    local_xy = np.stack(
        [observ[:, 0].astype(np.int64), observ[:, 1].astype(np.int64)],
        axis=1,
    )
    tile_size = np.array([width, height])
    tile_origin = room_xy[room_ids] * tile_size
    return local_xy + tile_origin

def to_plot_coords(room_coords, width=160, height=160):
    """Map pixel positions onto the [-0.5, 0.5] plot range, flipping y.

    Column 0 of ``room_coords`` is scaled so that 0 maps to -0.5 and
    ``width - 1`` maps to 0.5. Column 1 is mirrored first, so 0 maps to 0.5
    and ``height - 1`` maps to -0.5.

    Returns:
        Tuple ``(plot_x, plot_y)`` of float arrays.
    """
    last_col = width - 1
    last_row = height - 1
    plot_x = room_coords[:, 0] / last_col - 0.5
    plot_y = (last_row - room_coords[:, 1]) / last_row - 0.5
    return plot_x, plot_y

def draw_pyramid(data, pyramid_layout=None):
    """Render the pyramid canvas with the given rooms as an ``hv.RGB`` element.

    Args:
        data: Mapping from room id to room image, as accepted by
            ``draw_rooms``. ``None`` (the content of a stream that has not
            received anything yet) is treated as "no rooms", so the callback
            can be evaluated before the first ``send``.
        pyramid_layout: Optional canvas forwarded to ``draw_rooms``.

    Returns:
        ``hv.RGB`` of the canvas, sized 1440 x 640 with both axes hidden.
    """
    rooms = data if data is not None else {}
    canvas = draw_rooms(rooms, pyramid_layout)
    return hv.RGB(canvas).opts(width=1440, height=640, xaxis=None, yaxis=None)
    
def draw_tree_pyramid(data, max_x: int = 1440, max_y: int = 640, room_xy=None):
    """Draw the search tree (edges and nodes) over the whole pyramid.

    Args:
        data: Object exposing ``observ`` and ``parent`` tensors and
            ``env.gym_env._x_repeat``. A falsy value yields an empty overlay.
        max_x, max_y: Pixel size of the pyramid canvas, used to normalise the
            coordinates through ``to_plot_coords``.
        room_xy: Room-id -> ``[x, y]`` grid cell lookup; defaults to
            ``get_rooms_xy()``.

    Returns:
        ``hv.Segments * hv.Scatter`` overlay linking every node to its parent.
    """
    room_xy = room_xy if room_xy is not None else get_rooms_xy()
    if not data:
        # Plot options go through `.opts`; they are not constructor parameters.
        return hv.Segments().opts(bgcolor=None) * hv.Scatter().opts(bgcolor=None)
    observ = data.observ.cpu().numpy().astype(np.int64)
    # Undo the horizontal repeat on the room-local x BEFORE adding the room
    # offset (as `MontezumaDisplay.send` does); scaling afterwards would also
    # shrink the per-room offsets and misalign the tree with the room tiles.
    observ[:, 0] = observ[:, 0] / data.env.gym_env._x_repeat
    room_coords = to_pyramid_coords(observ, room_xy)
    parents = data.parent.cpu().numpy()
    plot_x, plot_y = to_plot_coords(room_coords, width=max_x, height=max_y)
    # One segment per node, from the parent's position to the node's position.
    segs = plot_x[parents], plot_y[parents], plot_x, plot_y
    edges = hv.Segments(segs).opts(line_color="white", bgcolor=None)
    nodes = hv.Scatter((plot_x, plot_y)).opts(size=2, bgcolor=None, color="red", line_width=0.01, xaxis=None, yaxis=None)
    return edges * nodes

def draw_tree_best_room(data, width=160, height=160):
    """Draw the part of the search tree that lies in the best walker's room.

    The "best" walker is the one with the largest ``cum_reward``. Only nodes in
    that walker's room are drawn, and only edges whose parent is in the same
    room are kept (positions are room-local, so an edge to a parent in another
    room has no meaningful end point in this plot).

    Args:
        data: Object exposing ``observ``, ``parent`` and ``cum_reward`` tensors
            and ``env.gym_env._x_repeat``. A falsy value yields an empty overlay.
        width, height: Pixel size of a room, used to normalise coordinates.

    Returns:
        ``hv.Segments * hv.Scatter`` overlay.
    """
    if not data:
        # Plot options go through `.opts`; they are not constructor parameters.
        return hv.Segments().opts(bgcolor=None) * hv.Scatter().opts(bgcolor=None)
    room_coords = data.observ.cpu().numpy().astype(np.int64)
    best_ix = data.cum_reward.argmax().cpu().item()
    in_room = room_coords[:, 2] == room_coords[best_ix, 2]
    parents = data.parent.cpu().numpy()
    room_coords[:, 0] = room_coords[:, 0] / data.env.gym_env._x_repeat
    # Compute plot coordinates for ALL nodes: `parents` holds indices into the
    # full walker arrays, so it must not be used on room-filtered arrays.
    plot_x, plot_y = to_plot_coords(room_coords, width=width, height=height)
    child_ix = np.flatnonzero(in_room & in_room[parents])
    parent_ix = parents[child_ix]
    segs = plot_x[parent_ix], plot_y[parent_ix], plot_x[child_ix], plot_y[child_ix]
    edges = hv.Segments(segs).opts(line_color="black", bgcolor=None)
    nodes = hv.Scatter((plot_x[in_room], plot_y[in_room])).opts(size=2, bgcolor=None, color="red")
    return edges * nodes

class MontezumaDisplay:
    """Streaming plots for a Montezuma search.

    Shows the frame of the best walker, the grey-scale image of that walker's
    room overlaid with its visit counts, and the pyramid of discovered rooms
    overlaid with the search tree. All plots are fed through :meth:`send`.
    """

    def __init__(self):
        self.best_rgb = RGB()
        self.room_grey = Image(cmap="greys")
        self.visits_image = Image(alpha=0.7, xaxis=None, yaxis=None, cmap="fire", bgcolor=None)
        self._clear_buffers()
        self.pipe_tree = Pipe()
        self.room_pipe = Pipe()
        # Index of the walker whose frame was last pushed to `best_rgb`.
        self._curr_best = -1
        tree_callback = partial(draw_tree_pyramid, room_xy=get_rooms_xy())
        self.tree_pyramid = hv.DynamicMap(tree_callback, streams=[self.pipe_tree])
        self.pyramid = hv.DynamicMap(draw_pyramid, streams=[self.room_pipe])

    def _clear_buffers(self):
        """(Re)create the per-room buffers and forget the visited rooms."""
        self.visited_rooms = []
        # Float array filled with NaN, one 160 x 160 map per room.
        self.visits = np.full((24, 160, 160), np.nan)
        self.rooms = np.zeros((24, 160, 160))

    def reset(self, fai):
        """Drop all cached room data; ``fai`` is accepted but not used."""
        self._clear_buffers()

    def send(self, fai):
        """Push the current state of ``fai`` to every plot."""
        best_ix = fai.cum_reward.argmax().cpu().item()
        best_frame = fai.rgb[best_ix]
        # The frame is only re-sent when the index of the best walker changes.
        if self._curr_best != best_ix:
            self.best_rgb.send(best_frame)
            self._curr_best = best_ix

        observ = fai.observ.cpu().numpy().astype(np.int64)
        x_repeat = fai.env.gym_env._x_repeat
        observ[:, 0] = observ[:, 0] / x_repeat
        room_ix = observ[:, 2]
        for room in np.unique(room_ix):
            if room in self.visited_rooms:
                continue
            # First time this room shows up: refresh the pyramid and cache a
            # grey-scale image of the room taken from the first walker in it.
            self.visited_rooms.append(room)
            self.room_pipe.send(fai.env.gym_env.rooms)
            first_walker = np.argmax(room_ix == room)
            frame = fai.rgb[first_walker][50:]
            self.rooms[room] = process_frame(frame, mode="L").copy()

        best_room_ix = room_ix[best_ix]
        self.room_grey.send(self.rooms[best_room_ix])
        room_visits = fai.visits[best_room_ix][None]
        room_visits = aggregate_visits(room_visits, block_size=8, upsample=True)[0]
        # Cells with zero visits become NaN before being sent to the overlay.
        room_visits[room_visits == 0] = np.nan
        self.visits_image.send(room_visits)
        self.pipe_tree.send(fai)

    def __panel__(self):
        """Return the layout: frame and room plots on top, pyramid below."""
        top_row = pn.Row(
            self.best_rgb.plot,
            self.room_grey.plot * self.visits_image.plot,
        )
        return pn.Column(top_row, self.pyramid * self.tree_pyramid)
                

In [6]:
hv.renderer("bokeh").theme = "carbon"
import pandas as pd

In [7]:
from fragile.utils import remove_notebook_margin

In [8]:
remove_notebook_margin()

In [9]:
plot = MontezumaDisplay()

In [10]:
from fragile.montezuma import FractalTree
n_walkers = 5000

fai = FractalTree(n_walkers=n_walkers, env=env, device="cpu", min_leafs=100, walkers_ix=100)

In [11]:
runner = FaiRunner(fai, 1000000, plot=plot)
pn.panel(pn.Column(runner, plot))

In [None]:
x = plot.visits.copy()
x[np.isnan(x)] = 0
x

In [None]:
fai.min_leafs = 100

In [None]:
import numpy as np

# Block-sum scratch cell: `original_array` is the (batch_size, width, height)
# array `x` built in an earlier cell from `plot.visits` with NaNs set to 0.
batch_size = x.shape[0]
width, height = 160, 160  # Original dimensions
original_array = x  # alternative test data: np.random.randint(0, 10, size=(batch_size, width, height))

# Assumes width and height are divisible by the block size (nothing checks it;
# the reshape below fails otherwise).
block_size = 2
new_width = width // block_size
new_height = height // block_size

# Step 1: Reshape the array to separate the blocks
reshaped_array = original_array.reshape(
    batch_size,
    new_width, block_size,
    new_height, block_size
)

# Step 2: Sum over the block dimensions (axes 2 and 4)
aggregated_array = reshaped_array.sum(axis=(2, 4))

# 'aggregated_array' now has shape (batch_size, new_width, new_height)
print("Aggregated array shape:", aggregated_array.shape)

# With the values above, every batch item of 'aggregated_array' is 80x80.


In [None]:
import numpy as np


def aggregate_array(array, block_size=2):
    """Sum non-overlapping ``block_size`` x ``block_size`` blocks of each item.

    Args:
        array: Array of shape ``(batch_size, width, height)``.
        block_size: Side of the square blocks that are summed together.

    Returns:
        Array of shape ``(batch_size, width // block_size, height // block_size)``.

    Raises:
        ValueError: If ``width`` or ``height`` is not a multiple of
            ``block_size``.
    """
    batch, width, height = array.shape
    if width % block_size or height % block_size:
        raise ValueError(
            f"array dimensions {(width, height)} must be divisible by block_size={block_size}"
        )
    # Split each spatial axis into (n_blocks, block_size) and sum the block axes.
    blocks = array.reshape(
        batch,
        width // block_size, block_size,
        height // block_size, block_size,
    )
    return blocks.sum(axis=(2, 4))


# Example usage:
batch_size = 10
width, height = 160, 160
original_array = np.random.randint(0, 10, size=(batch_size, width, height))
aggregated_array = aggregate_array(original_array, block_size=2)
print("Aggregated array shape:", aggregated_array.shape)


In [None]:
hv.Image(aggregated_array[1]) + hv.Image(upsampled_array[1])

In [None]:

env.gym_env.rooms.keys()

In [None]:
h, w = 160, 160
ph, pw = 4, 9
all_rooms = np.zeros((h * ph, w*pw, 3))


In [None]:


get_rooms_xy()

In [None]:
room_xy

In [None]:

room_xy = np.array([list(env.gym_env.get_room_xy(i)) for i in range(24)], dtype=np.int64)

    

In [None]:
all_rooms = np.zeros((h * ph, w*pw, 3))
# `to_pyramid_coords` requires the room-id -> [x, y] grid lookup as its second
# argument; use the same lookup that `draw_tree_pyramid` uses.
room_coords = to_pyramid_coords(fai.observ.cpu().numpy(), get_rooms_xy())
room_coords

In [None]:
nan_rooms = set_empty_rooms(all_rooms * np.nan)#.astype(np.uint8)
nan_rooms[room_coords[:, 1], room_coords[:, 0]] = np.array([255, 0, 0], dtype=np.float32)

In [None]:
nan_rooms[0:160, 640:800] = env.gym_env.rooms[1]

In [None]:
plt.imshow(nan_rooms.astype(np.uint8))

In [None]:
(hv.RGB(draw_rooms(env.gym_env.rooms))*
 draw_tree_pyramid(fai)).opts(width=1440, height=640, xaxis=None, yaxis=None)

In [None]:

# `draw_rooms` takes the room mapping first and the canvas second.
(hv.RGB(draw_rooms(env.gym_env.rooms, nan_rooms.copy()))*
 hv.RGB(nan_rooms)).opts(width=1000, height=600)#, xaxis=None, yaxis=None)

In [None]:

room_xy

In [None]:
# An empty subscript is a syntax error; display the whole layout instead.
PYRAMID

In [None]:
import matplotlib.pyplot as plt
plt.imshow(env.gym_env.rooms[1])

In [None]:
import pickle 
with open("rooms.pkl", "wb") as f:
    pickle.dump(env.gym_env.rooms, f)

In [None]:
np.unique(fai.observ[:, 2])

In [None]:
batch_ix = 1
plot.rooms[5] = process_frame(fai.rgb[batch_ix][50:], mode="L").copy()

In [None]:
hv.Image(plot.rooms[5])

In [None]:
def draw_pyramid(rooms):
    """Lay out one small image pane per pyramid cell in a grid of Panel rows.

    NOTE: this redefinition shadows the earlier ``draw_pyramid(data,
    pyramid_layout=None)`` for every cell executed after it.

    Args:
        rooms: Sequence indexable by room id that yields one 2-D image per
            room; ``rooms[0]`` is used as the shape template for empty cells.

    Returns:
        ``pn.Column`` holding one ``pn.Row`` of image panes per pyramid row.
    """
    height = 120
    width = 120
    PYRAMID = [
        [-1, -1, -1, 0, 1, 2, -1, -1, -1],
        [-1, -1, 3, 4, 5, 6, 7, -1, -1],
        [-1, 8, 9, 10, 11, 12, 13, 14, -1],
        [15, 16, 17, 18, 19, 20, 21, 22, 23],
    ]
    img_opts = dict(
        xaxis=None, yaxis=None, width=width, height=height, backend_opts={"plot.toolbar.autohide": True}, bgcolor=None, cmap="greys",
        normalize=True, framewise=True,
    )
    # All-NaN placeholder drawn in the cells that hold no room (-1).
    empty = np.ones_like(rooms[0]) * np.nan
    all_rows = []
    for rooms_row in PYRAMID:
        curr_row = []
        for ix in rooms_row:
            arr = empty if ix == -1 else rooms[ix]
            # Apply the options defined above (the original referenced an
            # undefined name `opts`).
            img = hv.Image(arr).opts(**img_opts)
            curr_row.append(pn.pane.HoloViews(img, sizing_mode='fixed', margin=(-10, -15), min_height=height, min_width=width))
        curr_row = pn.Row(*curr_row, sizing_mode='fixed', margin=(0, 0))
        all_rows.append(curr_row)
    return pn.Column(*all_rows, sizing_mode='fixed', margin=(0, 0))
    
            

In [None]:
plot.rooms[1].max()

In [None]:
draw_pyramid(plot.rooms)

In [None]:
draw_pyramid(plot.rooms)

In [None]:
plot.room_grey.shape

In [None]:
plot.best_rgb.plot

In [None]:
fai.observ

In [None]:

# NOTE(review): `observ` is only assigned in a later cell and `draw_tree` is
# not defined in any visible cell, so this cell fails on a fresh kernel —
# presumably `draw_tree` is an earlier name of one of the draw_tree_* helpers.
room_ix = observ[:, 2]
(hv.Image(plot.rooms[1]).opts(cmap="greys") * 
 hv.Image(plot.visits[1]).opts(alpha=0.7, cmap="fire") * draw_tree(fai)
 
)

In [None]:
observ = fai.observ.cpu().numpy().astype(np.int64)
observ[:, 0] = observ[:, 0] / 2
obs_y = observ[:, 0] / 160 -0.5
obs_x = observ[:, 1] / 160 -0.5
room_ix = observ[:, 2]
hv.Scatter((obs_x, obs_y))

In [None]:
hv.Scatter((obs_x, obs_y))

In [None]:
from fragile.shaolin.stream_plots import RGB, Image

rgb = Image()

In [None]:
obs = fai.observ.cpu().numpy().astype(np.int64)
obs[:, 0] = obs[:, 0]/2
obs[:, :3]

In [None]:
visits = np.zeros((24, 160, 160), dtype=np.int32)


In [None]:
np.unique(obs[:, 2])

In [None]:
visits[obs[:, 2], obs[:, 1], obs[:, 0]] += 1

In [None]:
plt.imshow(visits[1, :, :])

In [None]:
x = process_frame(rgb, mode="L")
plt.imshow(x)

In [None]:
from plangym.utils import process_frame

In [None]:
fai.rgb[0][50:].shape

In [None]:
rgb.plot

In [None]:
# NOTE(review): `self` is not defined at notebook top level; this looks pasted
# from inside a method — presumably `fai` was intended. Confirm before running.
data = self.env.step_batch(states=self.state, actions=self.action)
# The result is unpacked into six values.
new_states, observ, reward, oobs, _truncateds, infos = data

In [None]:
oobs

In [None]:
_truncateds

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.imshow(infos[0]["rgb"])

In [None]:
# `FaiRunner` has no `control` attribute; `is_running` lives on the runner itself.
runner.is_running

In [None]:
runner.progress.bar_color = "danger"