TENSORBOARD

In [1]:
# Launch TensorBoard inline in the notebook.
# `%load_ext tensorboard` loads the TensorBoard notebook extension, which
# provides the `%tensorboard` magic used on the next line. That magic is given
# the event-file directory (--logdir) and is bound to localhost on port 6008.
# NOTE(review): --logdir is an absolute path on one particular machine; point
# it at your own `runs` directory before executing this cell.
%load_ext tensorboard
%tensorboard --logdir "C:/Users/Jed/Desktop/8803DRL/option-critic-pytorch/runs" --host localhost --port 6008

Launching TensorBoard...

In [None]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import matplotlib.pyplot as plt
from matplotlib import animation, colors
from IPython.display import HTML
import numpy as np
import torch
from fourrooms import Fourrooms

def run_and_render_episode(env, model, max_steps=200):
    """
    Roll out a single episode and collect one rendered frame per state.

    The rollout is pure inference, so it runs under ``torch.no_grad()``: no
    autograd graph is recorded for the model's forward passes.

    Args:
        env: Environment exposing ``reset()``, ``render()`` and
            ``step(action)``; ``step`` must return a 4-tuple
            ``(next_obs, reward, done, info)``.
        model: Agent exposing ``get_state(tensor)``, ``greedy_option(state)``,
            ``get_action(state, option)`` (returning a 3-tuple whose first
            item is the action) and
            ``predict_option_termination(state, option)`` (returning
            ``(terminated, greedy_option)``).
        max_steps: Upper bound on the number of environment steps taken.

    Returns:
        list: The values returned by ``env.render()`` -- one for the initial
        state followed by one for every step taken.
    """
    frames = []
    steps = 0
    done = False

    # Nothing here is trained, so disable gradient tracking for the whole
    # rollout instead of building (and discarding) a graph on every step.
    with torch.no_grad():
        # 1. Initialize: reset, encode the first observation, capture frame 0.
        obs = env.reset()
        state = model.get_state(
            torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
        )
        frames.append(env.render())

        # 2. Start the episode with the greedy option for the initial state.
        current_option = model.greedy_option(state)

        while not done and steps < max_steps:
            # Only the action is needed at evaluation time; the other two
            # returned values are ignored.
            action, _logp, _entropy = model.get_action(state, current_option)

            next_obs, _reward, done, _ = env.step(action)
            frames.append(env.render())

            # Re-encode the new observation, then ask the model whether the
            # current option terminates there and which option is greedy.
            state = model.get_state(
                torch.tensor(next_obs, dtype=torch.float32).unsqueeze(0)
            )
            option_termination, greedy_option = (
                model.predict_option_termination(state, current_option)
            )

            # Switch options only when the current one has terminated.
            if option_termination:
                current_option = greedy_option

            steps += 1

    return frames

def display_animation(frames, interval=200):
    """
    Turn a sequence of grid frames into an inline JavaScript animation.

    Args:
        frames: Indexable sequence of 2-D grids accepted by ``imshow``. The
            original author's note gives the cell values as Wall=1, Empty=0
            and Agent/Goal=-1.
        interval: Passed straight through to ``FuncAnimation`` as its
            ``interval`` argument.

    Returns:
        IPython.display.HTML: Wrapper around ``anim.to_jshtml()``.
    """
    # Three discrete colours, one per bin of the boundary norm below:
    # [-1.5, -0.5) -> red, [-0.5, 0.5) -> white, [0.5, 1.5] -> black.
    palette = colors.ListedColormap(['red', 'white', 'black'])
    value_norm = colors.BoundaryNorm([-1.5, -0.5, 0.5, 1.5], palette.N)

    figure, axes = plt.subplots(figsize=(6, 6))
    axes.axis('off')

    # Draw the first frame once; later frames only swap the image data.
    image = axes.imshow(
        frames[0],
        cmap=palette,
        norm=value_norm,
        interpolation='nearest',
    )

    def _show_frame(frame_index):
        image.set_data(frames[frame_index])
        return [image]

    frame_count = len(frames)
    movie = animation.FuncAnimation(
        figure,
        _show_frame,
        frames=frame_count,
        interval=interval,
        blit=True,
    )

    # Close the current figure so the notebook does not also show a static plot.
    plt.close()

    rendered = movie.to_jshtml()
    return HTML(rendered)

# --- EXECUTION ---
# Load a trained agent, roll out one episode in the Fourrooms environment and
# show the result as an inline animation.

# Build the environment and load the trained model.
# NOTE(review): 'path_to_your_trained_model.pth' is a placeholder -- replace it
# with a real checkpoint path before running.
# NOTE(review): the result is used directly as a model (`.eval()` is called on
# it), so the checkpoint presumably holds a whole pickled module rather than a
# state_dict. torch.load unpickles the file, so only load checkpoints you
# trust; also confirm that the installed PyTorch version's `weights_only`
# default permits loading a full module.
env = Fourrooms()
option_critic = torch.load('path_to_your_trained_model.pth')  # expects a full model object, see notes above
option_critic.eval()  # switch the model to evaluation mode
episode_frames = run_and_render_episode(env, option_critic)

# Render the animation; as the last expression of the cell, the returned HTML
# object is what the notebook displays.
display_animation(episode_frames)