In [None]:
import gymnasium as gym
import numpy as np
import torch
from utils import *
from dreamer import *
import random
# --- Console formatting: wide lines, long summaries, no scientific notation for tensors ---
np.set_printoptions(threshold=2000, linewidth=200)
torch.set_printoptions(threshold=2000, linewidth=200, sci_mode=False)

# --- Environment ---
environmentName = "CarRacing-v3"
renderMode = None
seed = 1

# --- Training schedule ---
numUpdates = 40000        # the main loop runs up to and including start + numUpdates
episodesBeforeStart = 10  # iterations scheduled before `start`; each of them plays an episode
playInterval = 10         # from `start` on, an episode is played whenever i % playInterval == 0
numNewEpisodePlay = 1     # repetitions of the inner play/train loop per iteration
stepCountLimit = 256      # step counter value that ends an episode; also the video frameLimit
bufferSize = 20           # forwarded to EpisodeBuffer(size=...)

# --- Run identity and output locations ---
runName = f"{environmentName}_MINUS_ADV_SAMPLE"
checkpointToLoad = f"checkpoints/{runName}_5500"
metricsFilename = f"metrics/{runName}"
plotFilename = f"plots/{runName}"
videoFilename = f"videos/{runName}"

# --- Persistence switches and cadence (in loop iterations) ---
resume = True
saveMetrics = True
saveCheckpoints = True
saveMetricsInterval = 10
checkpointInterval = 500

# Seed every random number generator used by the run.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = gym.make(environmentName, render_mode=renderMode)
observationShape = torch.tensor(env.observation_space.shape)
# Every Gymnasium space exposes `.shape` (it is `()` for Discrete), so checking for
# that attribute never reaches the discrete case and yields an empty tensor there.
# Discrete spaces are identified by their `.n` attribute instead, and both branches
# now produce a torch tensor.
if hasattr(env.action_space, "n"):
    actionSize = torch.tensor([int(env.action_space.n)])
else:
    actionSize = torch.tensor(env.action_space.shape)
print(f"Env {environmentName} with observations {observationShape} and actions {actionSize}\n###\n")
dreamer = Dreamer()

episodeBuffer = EpisodeBuffer(size=bufferSize)

# `start` is the iteration index the training loop counts from: the update count
# stored on the loaded model when resuming, otherwise zero.
if resume:
    dreamer.loadCheckpoint(checkpointToLoad)
    start = dreamer.totalUpdates
else:
    start = 0

def _toObservationTensor(frame):
    """Turn a channels-last env frame into a batched, channels-first float tensor on `device`, scaled by 1/255."""
    return torch.from_numpy(np.transpose(frame, (2, 0, 1))).unsqueeze(0).to(device).float()/255.0


# Main loop: iterations below `start` always collect an episode, later ones collect
# one every `playInterval` iterations. Every inner pass trains the world model on a
# sampled episode and then the actor-critic on the resulting states.
# The environment is closed in `finally` so that interrupting the cell or an
# exception inside the loop does not leave it open.
try:
    for i in range(start - episodesBeforeStart, start + numUpdates + 1):
        for _ in range(numNewEpisodePlay):
            if i % playInterval == 0 or i < start:
                # NOTE(review): with start == 0 the warm-up indices are negative, so abs(i)
                # makes e.g. i = -10 and i = 10 reset the env with the same seed.
                observation, info = env.reset(seed=seed + abs(i))
                observation = _toObservationTensor(observation)
                observations, actions, rewards = [observation], [], []
                stepCount, totalReward, done = 1, 0, False
                while not done:
                    action = dreamer.act(observation, reset=(stepCount == 1))
                    observation, reward, terminated, truncated, info = env.step(action.cpu().numpy())
                    observation = _toObservationTensor(observation)
                    stepCount += 1
                    # Episodes are also cut off once the step counter reaches stepCountLimit.
                    done = terminated or truncated or stepCount >= stepCountLimit
                    totalReward += reward

                    observations.append(observation)
                    actions.append(action)
                    rewards.append(reward)

                # The initial observation is stored too, so there is one more observation
                # than there are actions/rewards.
                episodeBuffer.addEpisode(torch.stack(observations).squeeze(1),
                                        torch.stack(actions).to(device),
                                        torch.tensor(rewards).view(-1).to(device))

            selectedEpisodeObservations, selectedEpisodeActions, selectedEpisodeRewards = episodeBuffer.sampleEpisode()
            sampledFullState, worldModelLoss, reconstructionLoss, rewardPredictionLoss, klLoss = dreamer.trainWorldModel(selectedEpisodeObservations, selectedEpisodeActions, selectedEpisodeRewards)
            criticLoss, actorLoss, valueEstimate = dreamer.trainActorCritic(sampledFullState)

        # Metrics rows hold the losses of the last inner pass and the reward of the most
        # recently played episode.
        if i % saveMetricsInterval == 0 and i > start and saveMetrics:
            saveLossesToCSV(metricsFilename, {
                "i": i,
                "worldModelLoss": worldModelLoss,
                "reconstructionLoss": reconstructionLoss,
                "rewardPredictionLoss": rewardPredictionLoss,
                "klLoss": klLoss,
                "criticLoss": criticLoss,
                "actorLoss": actorLoss,
                "valueEstimate": valueEstimate,
                "totalReward": totalReward})

            print(f"\nnewest actions:\n{episodeBuffer.getNewestEpisode()[1][:5]}")

        if i % checkpointInterval == 0 and i > start and saveCheckpoints:
            print(f"i {i:6}: worldModelLoss, criticLoss, actorLoss, reward = {worldModelLoss:8.4f}, {criticLoss:8.4f}, {actorLoss:8.4f}, {totalReward:.2f}")
            # Record the iteration on the model before saving; the resume path reads it back as `start`.
            dreamer.totalUpdates = i
            dreamer.saveCheckpoint(f"checkpoints/{runName}_{i}")
            plotMetrics(metricsFilename, show=False, save=True, savePath=f"{plotFilename}_{i}") # TODO: plot should replace the unnecessary previous file
            saveVideoFromGymEnv(dreamer, environmentName, f"{videoFilename}_{i}", frameLimit=stepCountLimit)
finally:
    env.close()

Env CarRacing-v3 with observations tensor([96, 96,  3]) and actions tensor([3])
###



  checkpoint = torch.load(checkpointPath)


Loaded checkpoint from: checkpoints/CarRacing-v3_MINUS_ADV_SAMPLE_5500.pth
advantages tensor([-0.0663, -0.0638, -0.0602, -0.0575, -0.0528, -0.0488, -0.0442, -0.0404, -0.0360, -0.0315, -0.0262, -0.0217, -0.0170, -0.0116, -0.0065], device='cuda:0')
logprobs tensor([ 8.5534, 15.5636, 13.4584, 18.6046, 20.7004, 17.3613,  6.5820, 14.7392, 18.3163, 11.1700,  9.9984,  8.5417, 18.1395, 13.3020, 20.3292], device='cuda:0', grad_fn=<StackBackward0>)
entropies tensor([-3.7420, -3.7420, -3.7420, -3.7420, -3.7420, -3.7420, -3.7420, -3.7420, -3.7420, -3.7420, -3.7420, -3.7420, -3.7420, -3.7420, -3.7420], device='cuda:0', grad_fn=<StackBackward0>)
advantages tensor([0.0122, 0.0120, 0.0114, 0.0111, 0.0119, 0.0098, 0.0105, 0.0094, 0.0099, 0.0085, 0.0069, 0.0063, 0.0044, 0.0047, 0.0010], device='cuda:0')
logprobs tensor([10.7749, 10.0968, 21.5406, 12.1486, 20.7378, 13.8009, 10.8133, 20.9681, 17.2069, 19.9627, 20.9928, 19.0710, 20.8915, 15.9634, 20.2766], device='cuda:0', grad_fn=<StackBackward0>)
entropi

In [None]:
# Show the actions of the episode most recently returned by episodeBuffer.sampleEpisode() in the training cell.
selectedEpisodeActions

In [None]:
# Dump each actor parameter (name, shape, raw values), with a blank line after each entry.
for name, param in dreamer.actor.named_parameters():
    print(
        f"Name: {name}",
        f"Shape: {param.shape}",
        f"Values: {param.data}\n",
        sep="\n",
    )

In [None]:
# Try out rollout of the world model
import tkinter as tk
from tkinter import ttk
from PIL import Image, ImageTk
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

# Initialize your Dreamer model and device here
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the start image. The file is opened in a context manager so the handle is
# released, and it is converted to RGB so that PNGs saved with an alpha channel,
# a palette or as grayscale still give a height x width x 3 array for the transpose.
with Image.open("startImage2.png") as start_image_file:  # Replace with uploaded image path
    start_image = np.array(start_image_file.convert("RGB"))
start_image_tensor = torch.from_numpy(np.transpose(start_image, (2, 0, 1))).unsqueeze(0).to(device).float() / 255.0

# Initialize the rollout. This is inference only, so no autograd graph is recorded
# for the states carried between GUI steps.
with torch.no_grad():
    recurrent_state, latent_state = dreamer.rolloutInitialize(start_image_tensor)

# Dark-mode palette used by the Tk widgets and the figure
BG_COLOR = "#333333"
FG_COLOR = "#DDDDDD"
SLIDER_COLOR = "#555555"
SLIDER_THUMB_COLOR = "#AAAAAA"
BUTTON_COLOR = "#444444"
BUTTON_HOVER_COLOR = "#666666"


def _leave_fullscreen(event):
    """Key handler: switch the root window out of fullscreen mode."""
    return root.attributes("-fullscreen", False)


# Root window: starts fullscreen, ESC leaves fullscreen
root = tk.Tk()
root.title("Dreamer Rollout Interface")
root.configure(bg=BG_COLOR)
root.attributes('-fullscreen', True)
root.bind("<Escape>", _leave_fullscreen)

# Size the window to the screen and anchor it at the top-left corner
root.geometry(f"{root.winfo_screenwidth()}x{root.winfo_screenheight()}+0+0")

# ttk theme: per-widget-class options applied on top of the 'clam' theme
style = ttk.Style()
style.theme_use('clam')
for widget_class, widget_options in (
    ("TFrame", dict(background=BG_COLOR)),
    ("TLabel", dict(background=BG_COLOR, foreground=FG_COLOR)),
    ("TButton", dict(background=BUTTON_COLOR, foreground=FG_COLOR, font=("Arial", 12), relief="flat", padding=8)),
):
    style.configure(widget_class, **widget_options)
style.map("TButton", background=[("active", BUTTON_HOVER_COLOR)])

# Matplotlib figure embedded in the Tk window; it shows the rollout frames
fig, ax = plt.subplots(figsize=(7, 7))
for themed_artist in (fig.patch, ax):
    themed_artist.set_facecolor(BG_COLOR)

canvas = FigureCanvasTkAgg(fig, master=root)
canvas_widget = canvas.get_tk_widget()
canvas_widget.pack(side=tk.TOP, fill=tk.BOTH, expand=True, pady=(20, 10))

def update_observation_image(obs_image):
    """Replace the contents of the embedded axes with `obs_image` and redraw the canvas.

    The axes are cleared first, the image is drawn without axis decorations,
    and the Tk canvas is refreshed so the change becomes visible.
    """
    ax.clear()
    ax.imshow(obs_image)
    ax.set_axis_off()
    canvas.draw()

# Row of action sliders, packed below the image
slider_frame = ttk.Frame(root)
slider_frame.pack(side=tk.TOP, pady=10)

# One entry per action component: display name, slider range, initial value
action_labels = ["Steer", "Acceleration", "Brake"]
action_ranges = [(-1, 1), (0, 1), (0, 1)]
action = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32, device=device)
sliders = []

# Fixed-width empty frame on the left that shifts the sliders to the right
spacer = ttk.Frame(slider_frame, width=200, style="TFrame")
spacer.pack(side=tk.LEFT)

# Appearance options shared by every slider
scale_appearance = dict(
    resolution=0.01,
    orient=tk.HORIZONTAL,
    length=300,
    bg=BG_COLOR,
    fg=FG_COLOR,
    troughcolor=SLIDER_COLOR,
    sliderrelief="flat",
    highlightthickness=0,
    activebackground=SLIDER_THUMB_COLOR,
)

# A bold label followed by its slider, laid out left to right
for i, (label_text, (lower, upper)) in enumerate(zip(action_labels, action_ranges)):
    label = ttk.Label(slider_frame, text=label_text, font=("Arial", 12, "bold"))
    label.pack(side=tk.LEFT, padx=(20, 10))

    slider = tk.Scale(slider_frame, from_=lower, to=upper, **scale_appearance)
    slider.set(action[i].item())
    slider.pack(side=tk.LEFT, padx=(0, 20))
    sliders.append(slider)

# Step function
def step():
    """Advance the world-model rollout by one step and display the predicted frame.

    Reads the current slider values as the action, calls `dreamer.rolloutStep`
    with the stored recurrent/latent state, stores the returned states in the
    module-level `recurrent_state` / `latent_state` (and the action in `action`),
    and draws the returned observation in the GUI.
    """
    global recurrent_state, latent_state, action
    action_values = [slider.get() for slider in sliders]
    action = torch.tensor(action_values, dtype=torch.float32, device=device)

    # Rollout step. Autograd is disabled because the states are fed back in on every
    # click; recording gradients would chain an ever-growing graph across steps.
    with torch.no_grad():
        next_recurrent_state, next_latent_state, next_observation, _next_reward = dreamer.rolloutStep(
            recurrent_state, latent_state, action
        )
    recurrent_state, latent_state = next_recurrent_state, next_latent_state

    # Convert observation to image and display. `detach()` guarantees the NumPy
    # conversion works even if the returned tensor still requires grad.
    obs_image = next_observation.detach().squeeze().permute(1, 2, 0).cpu().numpy()
    obs_image = np.clip(obs_image * 255, 0, 255).astype(np.uint8)
    update_observation_image(obs_image)

# "X" button pinned near the top-right corner; it destroys the root window
close_button = ttk.Button(root, text="X", style="TButton", command=root.destroy)
close_button.place(anchor="ne", relx=0.98, rely=0.02)

# "Step" button packed below the slider row; each click runs one rollout step
step_button = ttk.Button(root, text="Step", style="TButton", command=step)
step_button.pack(pady=20, side=tk.TOP)

# Show the start image before any step has been taken
update_observation_image(start_image)

# Hand control to Tk; this call blocks until the window is closed
root.mainloop()


In [None]:

# Render the sampled episode next to the world model's reconstruction of it.
# This is pure inference, so autograd is switched off: no graph is kept for the
# whole episode and the resulting tensor carries no gradient history.
with torch.no_grad():
    # The first observation is dropped from the reference frames.
    # NOTE(review): presumably because reconstructions start at the second frame — confirm.
    original = selectedEpisodeObservations[1:].cpu()
    reconstructed = dreamer.reconstructObservations(selectedEpisodeObservations, selectedEpisodeActions).cpu()
    # Concatenate along the last dimension, then resize every frame to 512 x 1024.
    sideBySide = F.interpolate(torch.cat([original, reconstructed], dim=-1), size=(512, 1024), mode='bilinear')
saveVideoFrom4DTensor(sideBySide, f"results/sideBySideRepresentation_{runName}.mp4", fps=30)