In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import imageio
import time
import torch
from MazeEnvironment import MazeEnv

In [None]:
def visualize_cnn(env, cnn, cmap='gray', save_path=None):
    """Plot the feature maps produced by every ``Conv2d`` layer of ``cnn``.

    The environment is reset once; ``observations['image']`` is divided by
    255, displayed, and then pushed through the layers of ``cnn`` one at a
    time. After each ``torch.nn.Conv2d`` layer a grid of subplots is drawn
    with one panel per output channel.

    Args:
        env: Environment whose ``reset()`` returns ``(observations, info)``
            with the image stored under ``observations['image']``. The image
            axes are permuted ``(0, 3, 1, 2)`` after a batch axis is added, so
            the image is presumably channel-last (H, W, C) -- confirm against
            the environment.
        cnn: Iterable of layers (e.g. ``torch.nn.Sequential``), applied in
            iteration order.
        cmap: Matplotlib colormap used for the feature maps.
        save_path: Optional path prefix. When given, the input image is saved
            to ``save_path + "_inp.png"`` and the feature maps of the layer at
            position ``k`` in ``cnn`` to ``save_path + str(k) + ".png"``.

    Returns:
        None. Figures are shown (and optionally saved) as a side effect.
    """
    # Run the input on the device the model already lives on; only fall back
    # to "CUDA if available" when the model exposes no parameters.
    first_param = next(cnn.parameters(), None) if hasattr(cnn, 'parameters') else None
    if first_param is not None:
        device = first_param.device
    else:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    observations, _ = env.reset()
    img = observations['image'] / 255.0

    plt.imshow(img)
    if save_path is not None:
        plt.savefig(save_path + "_inp.png")
    plt.show()

    x = torch.as_tensor(img, dtype=torch.float32).unsqueeze(0)
    x = x.permute(0, 3, 1, 2).to(device)

    for layer_idx, layer in enumerate(cnn):
        with torch.no_grad():
            x = layer(x)
        if not isinstance(layer, torch.nn.Conv2d):
            continue

        # Move the activations to the host once per layer, not once per channel.
        feature_maps = x[0].detach().cpu().numpy()
        num_channels = feature_maps.shape[0]

        # Smallest near-square grid that is guaranteed to hold every channel.
        n_cols = int(np.ceil(np.sqrt(num_channels)))
        n_rows = int(np.ceil(num_channels / n_cols))
        # squeeze=False keeps ``axs`` a 2-D array even for a 1x1 grid.
        fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
        axes = axs.ravel()

        # Hide every panel's frame up front so unused cells stay blank.
        for ax in axes:
            ax.axis('off')
        for channel in range(num_channels):
            axes[channel].set_title(f'Channel {channel}')
            axes[channel].imshow(feature_maps[channel], cmap=cmap)

        if save_path is not None:
            # Name the file after the layer position so layers do not
            # overwrite each other's output.
            fig.savefig(save_path + str(layer_idx) + '.png')
        plt.show()
        plt.close(fig)

In [None]:
def animate_policy(env, model, FPS: int = 12, do_truncate: bool = True, goal_dist=None):
    """Run one episode with ``model`` and animate it inline, frame by frame.

    Before every step the current observation image is drawn with a caption
    giving the telemetry, the chosen action (as an arrow) and the step count.
    Once the episode ends, the last observation is drawn with a
    "Final State" caption.

    Args:
        env: Environment with ``reset(options=...)`` returning
            ``(observation, info)`` and ``step(action)`` returning
            ``(observation, reward, terminated, truncated, info)``.
            Observations are indexed with ``'image'`` and ``'telemetry'``.
        model: Object whose ``predict(observation)`` returns a sequence whose
            first element is the action; that element is passed to
            ``env.step`` and used to index the arrow table (0..3 map to
            up, right, left, down).
        FPS: Target frames per second. A falsy value disables throttling.
        do_truncate: When True, a ``truncated`` signal also ends the episode;
            otherwise only ``terminated`` does.
        goal_dist: Forwarded to ``env.reset`` as ``options['goal_dist']``.

    Returns:
        None. Frames are displayed as a side effect.
    """
    figure_size = (5, 5)
    arrows = ["↑", "→", "←", "↓"]

    def show_frame(state, label, action_str, step):
        # Replace the previous frame in the cell output, then draw this one.
        clear_output(wait=True)
        fig, ax = plt.subplots(figsize=figure_size)
        ax.imshow(state['image'])
        ax.axis('off')
        # Caption sits below the image, in axes-relative coordinates.
        ax.text(0.5, -0.15,
                f"{label}: {state['telemetry']!s}\nAction: {action_str}\nTime Step: {step}",
                transform=ax.transAxes, fontsize=12,
                verticalalignment='bottom', horizontalalignment='center')
        plt.show()
        # Release the figure so long episodes do not accumulate open figures.
        plt.close(fig)

    s, _ = env.reset(options={'goal_dist': goal_dist})
    step = 0

    while True:
        start_time = time.time()

        action = model.predict(s)
        step += 1
        action_str = arrows[action[0]]

        show_frame(s, "State", action_str, step)

        s, _reward, terminated, truncated, _ = env.step(action[0])

        if FPS:
            # Sleep only for whatever is left of this frame's time budget.
            time.sleep(max(0, 1 / FPS - (time.time() - start_time)))

        if terminated or (truncated and do_truncate):
            break

    # Show the observation the episode ended on, with the last action taken.
    show_frame(s, "Final State", action_str, step)

In [None]:
def animate_and_save(env, model, video_path, FPS: int = 12, do_truncate: bool = True, goal_dist=None):
    """Run one episode with ``model`` and write the rendered frames to a video.

    Each step is rendered off-screen to an RGB array (observation image plus
    a caption with telemetry, action arrow and step count). After the episode
    ends, one extra "Final State" frame is appended and all frames are handed
    to ``imageio.mimsave``.

    Args:
        env: Environment with ``reset(options=...)`` returning
            ``(observation, info)`` and ``step(action)`` returning
            ``(observation, reward, terminated, truncated, info)``.
            Observations are indexed with ``'image'`` and ``'telemetry'``.
        model: Object whose ``predict(observation)`` returns a sequence whose
            first element is the action; that element is passed to
            ``env.step`` and used to index the arrow table (0..3 map to
            up, right, left, down).
        video_path: Output path passed to ``imageio.mimsave``.
        FPS: Frame rate passed to ``imageio.mimsave`` as ``fps``. Rendering
            itself is not throttled: playback speed is determined by the
            saved file, so sleeping between frames would only slow the export.
        do_truncate: When True, a ``truncated`` signal also ends the episode;
            otherwise only ``terminated`` does.
        goal_dist: Forwarded to ``env.reset`` as ``options['goal_dist']``.

    Returns:
        None. The video file is written as a side effect.
    """
    figure_size = (5, 5)
    arrows = ["↑", "→", "←", "↓"]
    frames = []

    def capture_frame(state, label, action_str, step):
        fig, ax = plt.subplots(figsize=figure_size)
        try:
            ax.imshow(state['image'])
            ax.axis('off')
            ax.text(0.5, -0.15,
                    f"{label}: {state['telemetry']!s}\nAction: {action_str}\nTime Step: {step}",
                    transform=ax.transAxes, fontsize=12,
                    verticalalignment='bottom', horizontalalignment='center')

            # Rasterise with the Agg canvas and grab the pixels.
            # ``tostring_rgb`` is deprecated in Matplotlib, so read the
            # RGBA buffer instead and drop the alpha channel.
            canvas = FigureCanvas(fig)
            canvas.draw()
            rgba = np.asarray(canvas.buffer_rgba())
            # Copy so the frame does not alias the canvas's internal buffer.
            frames.append(rgba[..., :3].copy())
        finally:
            # Always release the figure, even if drawing failed.
            plt.close(fig)

    s, _ = env.reset(options={'goal_dist': goal_dist})
    step = 0

    while True:
        action = model.predict(s)
        step += 1
        action_str = arrows[action[0]]

        capture_frame(s, "State", action_str, step)

        s, _reward, terminated, truncated, _ = env.step(action[0])
        if terminated or (truncated and do_truncate):
            break

    # Final frame: the observation the episode ended on.
    capture_frame(s, "Final State", action_str, step)

    # Save the frames as a video
    print(f"Saving video to {video_path}...")
    imageio.mimsave(video_path, frames, fps=FPS)
    print("Done.")


In [2]:
# Export plots.ipynb to a plain Python script with nbconvert (writes plots.py).
!jupyter nbconvert --to python plots.ipynb

[NbConvertApp] Converting notebook plots.ipynb to python
[NbConvertApp] Writing 2853 bytes to plots.py
