# Memory Model: MDN-RNN
This notebook explains the architecture, training pipeline, and visualization of our MDN-RNN memory model.

In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
import torchvision.transforms as transforms
import glob
import sys
from tqdm import tqdm
from IPython.display import display, HTML, clear_output


In [2]:
import sys
import os
from pathlib import Path

# Make the project root (the notebook's parent directory) importable so that `src.*` resolves
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# Project root as a string, used below to locate the trained model weights
path = str(Path.cwd().parent)

## 1. Data Preparation: CarRacingDataset_RNN

The `CarRacingDataset_RNN` class loads and preprocesses the training data for our RNN-MDN model. It works on sequences of latent vectors that were previously extracted from frames with a Variational Autoencoder (VAE).

Key features:
- Loads sequences of latent vectors from files
- Creates input-output pairs for training sequence prediction
- Handles time-series data through a sequence-length parameter
- Prepares batched data for efficient training
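To illustrate how input-output pairs for next-step prediction can be built, here is a minimal sketch. It assumes each episode is stored as an array of latent vectors of shape `[T, latent_dim]` plus one action per step of shape `[T, action_dim]`; the class `LatentSequenceDataset` is hypothetical and is not the actual `CarRacingDataset_RNN` implementation. The target sequence is simply the input sequence shifted one step forward.

```python
import numpy as np


class LatentSequenceDataset:
    """Hypothetical sketch: slice one episode of latents into (input, target) windows.

    Implements __len__/__getitem__, i.e. the same protocol as
    torch.utils.data.Dataset, so it could be wrapped in a DataLoader.
    """

    def __init__(self, latents, actions, seq_len):
        assert len(latents) == len(actions), "one action per latent step"
        self.latents = np.asarray(latents, dtype=np.float32)  # [T, latent_dim]
        self.actions = np.asarray(actions, dtype=np.float32)  # [T, action_dim]
        self.seq_len = seq_len

    def __len__(self):
        # Each window needs seq_len inputs plus one extra step for the shifted target
        return max(0, len(self.latents) - self.seq_len)

    def __getitem__(self, idx):
        z_in = self.latents[idx : idx + self.seq_len]               # z_t
        a_in = self.actions[idx : idx + self.seq_len]               # a_t
        z_target = self.latents[idx + 1 : idx + 1 + self.seq_len]   # z_{t+1}
        return z_in, a_in, z_target


# Usage: a toy episode of 10 steps with 4-dim latents and 3-dim actions
T, latent_dim, action_dim = 10, 4, 3
latents = np.arange(T * latent_dim, dtype=np.float32).reshape(T, latent_dim)
actions = np.zeros((T, action_dim), dtype=np.float32)
ds = LatentSequenceDataset(latents, actions, seq_len=5)
z_in, a_in, z_target = ds[0]
print(len(ds), z_in.shape, z_target.shape)  # 5 (5, 4) (5, 4)
```

Shifting by exactly one step means the model is trained to predict $z_{t+1}$ from $(z_t, a_t)$ and its recurrent state, which is what the memory model does at inference time.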

In [3]:
from src.utils.dataset import CarRacingDataset_RNN



## 2. RNN-MDN Model Architecture

The RNN-MDN (Recurrent Neural Network with Mixture Density Network) combines the sequence-modeling capability of an LSTM with the probabilistic output of a Mixture Density Network. This suits the memory component of our world model because:

1. The LSTM component captures temporal dependencies in sequences of latent vectors
2. The MDN component models the uncertainty in predicting future latent states

<img src='imgs/rnn_mdn.png' width=800>


### Mixture Density Network (MDN)

The MDN outputs parameters for a mixture of Gaussian distributions:

- **π**: The mixture weights (which Gaussian to pick)
- **μ**: The means of each Gaussian
- **σ**: The standard deviations of each Gaussian

For a mixture with K components and D-dimensional output:
- π has shape [K]
- μ has shape [K, D]
- σ has shape [K, D]
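A common way to produce these parameters is a single linear layer whose output of size $K + 2KD$ is split into three parts: a softmax turns the first $K$ values into valid mixture weights, and an exponential keeps every σ positive. The NumPy sketch below shows only this split for one time step; the layout of the raw vector and the function `split_mdn_params` are assumptions, and `MDNRNN` may order or parameterize its outputs differently.

```python
import numpy as np


def split_mdn_params(raw, K, D):
    """Split a raw head output of size K + 2*K*D into (pi, mu, sigma).

    Assumed (hypothetical) layout: [pi logits (K) | mu (K*D) | log sigma (K*D)].
    """
    assert raw.shape[-1] == K + 2 * K * D
    logits = raw[:K]
    mu = raw[K : K + K * D].reshape(K, D)
    log_sigma = raw[K + K * D :].reshape(K, D)

    # Softmax (shifted by the max for stability) so that pi >= 0 and sums to 1
    e = np.exp(logits - logits.max())
    pi = e / e.sum()
    # Exponentiate so every standard deviation is strictly positive
    sigma = np.exp(log_sigma)
    return pi, mu, sigma


# 5 Gaussians over a 32-dim latent, matching the num_gaussians and latent_dim
# values used when MDNRNN is instantiated later in this notebook
K, D = 5, 32
raw = np.random.default_rng(0).normal(size=K + 2 * K * D)
pi, mu, sigma = split_mdn_params(raw, K, D)
print(pi.shape, mu.shape, sigma.shape)  # (5,) (5, 32) (5, 32)
```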

The probability density function is:

$$p(y|x) = \sum_{k=1}^{K} \pi_k(x) \mathcal{N}(y|\mu_k(x), \sigma_k^2(x))$$

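Where:
- $\pi_k(x)$ is the mixture weight for component $k$, with $\pi_k(x) \ge 0$ and $\sum_{k=1}^{K} \pi_k(x) = 1$
- $\mathcal{N}(y|\mu_k(x), \sigma_k^2(x))$ is the Gaussian probability density for component $k$; since $\sigma_k$ has one entry per output dimension, the covariance is diagonal and the density factorizes over the $D$ dimensions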
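To make the formula concrete, the sketch below evaluates $p(y|x)$ for a single target $y$, assuming diagonal-covariance components as implied by σ having shape [K, D]. It is a direct transcription of the sum for illustration only (`mixture_density` is a hypothetical helper); for training, the log-sum-exp formulation in the training section is preferable.

```python
import numpy as np


def mixture_density(y, pi, mu, sigma):
    """p(y) = sum_k pi_k * N(y | mu_k, diag(sigma_k^2)).

    y: [D], pi: [K], mu: [K, D], sigma: [K, D]
    """
    # Per-dimension Gaussian densities, shape [K, D]
    norm = 1.0 / (np.sqrt(2.0 * np.pi) * sigma)
    dens = norm * np.exp(-0.5 * ((y - mu) / sigma) ** 2)
    # Diagonal covariance: the component density is the product over dimensions
    comp = dens.prod(axis=1)  # [K]
    return float(np.sum(pi * comp))


# Two 1-D components, with y sitting exactly on the first mean
pi = np.array([0.7, 0.3])
mu = np.array([[0.0], [2.0]])
sigma = np.array([[1.0], [0.5]])
print(round(mixture_density(np.array([0.0]), pi, mu, sigma), 4))  # 0.2793
```

Almost all of the density comes from the first component ($0.7 \cdot 1/\sqrt{2\pi} \approx 0.2793$), because $y$ lies four standard deviations away from the second component's mean.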

<img src='imgs/mdn.png' width=800>


### Dream System Architecture

The RNN-MDN model serves as the "memory" component in our World Model architecture. It predicts a distribution over the next latent state from the current latent state and action; the inference calls in this notebook pass only these two inputs.

In dreaming mode, the model feeds its own sampled predictions back in as input, so it can generate sequences of latent states without observations from the environment. This lets the agent "imagine" trajectories through the environment.
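Sampling the next latent from the mixture takes two steps: pick a component according to π, then draw from that component's Gaussian. The notebook calls `sample_mdn(pi, mu, sigma)` and passes `tau` into the model itself; the sketch below is a hypothetical NumPy sampler (`sample_mixture`, not the project's `sample_mdn`) that applies the temperature inside the sampler instead. It follows a convention used in the original World Models work: divide the π logits by τ and scale σ by $\sqrt{\tau}$, so a lower τ gives more deterministic dreams and a higher τ gives more varied ones.

```python
import numpy as np


def sample_mixture(logits, mu, sigma, tau=1.0, rng=None):
    """Hypothetical sampler: draw one D-dim sample from a diagonal Gaussian mixture.

    logits: [K] unnormalized mixture weights, mu/sigma: [K, D], tau > 0.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Temperature-scaled softmax over components (lower tau -> sharper choice)
    scaled = logits / tau
    p = np.exp(scaled - scaled.max())
    p /= p.sum()
    k = rng.choice(len(p), p=p)
    # Sample from the chosen component; sqrt(tau) shrinks or widens its spread
    eps = rng.standard_normal(mu.shape[1])
    return mu[k] + sigma[k] * np.sqrt(tau) * eps


logits = np.array([2.0, 0.0, -1.0])
mu = np.array([[0.0, 0.0], [5.0, 5.0], [-5.0, -5.0]])
sigma = np.ones((3, 2))
rng = np.random.default_rng(0)
# With tau = 0.01 the dominant component is chosen almost surely and the noise
# is scaled by 0.1, so the sample lands close to [0, 0]
z_next = sample_mixture(logits, mu, sigma, tau=0.01, rng=rng)
print(z_next)
```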

<img src='imgs/dream_diagram.png' width=800>


## 3. Training the RNN-MDN Model

Training the RNN-MDN model involves minimizing the negative log-likelihood of the target latent vectors given the predicted mixture parameters. 

The loss function is derived from the probability density function of the mixture model:

$$\text{Loss} = -\log\left(\sum_{k=1}^{K} \pi_k \mathcal{N}(y|\mu_k, \sigma_k^2)\right)$$

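To avoid numerical underflow, we compute the loss in log space. Define $x_k = \log \pi_k + \log \mathcal{N}(y|\mu_k, \sigma_k^2)$, so that $\text{Loss} = -\log\sum_{k=1}^{K} e^{x_k}$. Exponentiating $x_k$ directly can underflow to zero when the densities are tiny, so we apply the log-sum-exp trick:

$$\log\sum_{k=1}^{K} e^{x_k} = a + \log\sum_{k=1}^{K} e^{x_k - a}$$

where $a = \max_k x_k$. After subtracting $a$, the largest term in the sum is $e^0 = 1$, so the sum is at least 1 and its logarithm is finite.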
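The loss can be computed entirely in log space as follows. This NumPy sketch (`mdn_nll` is a hypothetical helper, not the project's training code) assumes diagonal-covariance components and a single target vector; in practice it would be averaged over batch elements and time steps. The usage compares it with the naive "compute densities, then take the log" approach on a 32-dimensional target with narrow components, where the naive density underflows.

```python
import numpy as np


def mdn_nll(y, log_pi, mu, sigma):
    """Negative log-likelihood -log sum_k pi_k N(y | mu_k, diag(sigma_k^2)).

    y: [D], log_pi: [K], mu/sigma: [K, D]. Computed entirely in log space.
    """
    # log N for each component, summed over the D independent dimensions -> [K]
    log_norm = -0.5 * np.log(2.0 * np.pi) - np.log(sigma) - 0.5 * ((y - mu) / sigma) ** 2
    x = log_pi + log_norm.sum(axis=1)
    # Log-sum-exp trick: subtract the max before exponentiating
    a = x.max()
    return -(a + np.log(np.exp(x - a).sum()))


rng = np.random.default_rng(0)
K, D = 5, 32
log_pi = np.log(np.full(K, 1.0 / K))
mu = rng.normal(size=(K, D))
sigma = np.full((K, D), 0.05)
y = rng.normal(size=D)  # a target far from every mean, relative to sigma

stable = mdn_nll(y, log_pi, mu, sigma)

# Naive version: multiply densities directly, then sum over components
comp_dens = np.exp(-0.5 * ((y - mu) / sigma) ** 2) / (np.sqrt(2 * np.pi) * sigma)  # [K, D]
naive_density = np.sum(np.exp(log_pi) * comp_dens.prod(axis=1))
print(stable, naive_density)  # a finite loss vs. a density that underflows to 0.0
```

Taking `-np.log(naive_density)` here would give `inf`, whereas the log-space version still returns a finite loss with a usable gradient signal.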

In [4]:
from src.models.dds_vae import Vision

# Instantiate model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vision = Vision(n_features_to_select=0.03, 
                in_ch=3, 
                out_ch=3, 
                base_ch=16, 
                alpha=1.0, 
                delta=0.1
).to(device)
vision.load_state_dict(torch.load(path+'/src/trained_models/vision_03_miniVAE.pth', map_location=device, weights_only=True))
vision.eval()


Vision(
  (unet1): UNet(
    (inc): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
    )
    (down1): Down(
      (down): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): ConvBlock(
          (conv): Sequential(
            (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
        )
      )
    )
    (down2): Down(
      (down): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): ConvBlock(
          (conv): Sequential(
            (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(

In [5]:
from src.models.mdn_rnn import MDNRNN

memory = MDNRNN(latent_dim=32, 
                action_dim=3, 
                hidden_dim=256, 
                num_gaussians=5
).to(device)

memory.load_state_dict(torch.load(path+'/src/trained_models/memory.pth', map_location=device, weights_only=True))

<All keys matched successfully>

## 4. Dream Visualization with RNN-MDN

In this section, we use the trained RNN-MDN model to generate "dreams": sequences of latent vectors that are decoded back into images with the VAE decoder. This shows how the model can imagine plausible future states without real input.

The process involves:
1. Starting with an initial latent vector
2. Using the RNN-MDN to predict the next latent vector
3. Decoding the predicted latent vector into an image using the VAE
4. Repeating the process to generate a sequence of images
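The four steps above form a closed loop in which each sampled latent becomes the next input. The skeleton below shows only that control flow. The functions `encode`, `mdn_step`, and `decode` are hypothetical NumPy stand-ins for `vision.encode`, the `MDNRNN` forward pass followed by `sample_mdn`, and `vision.decode`; the hidden state is a plain array here rather than an LSTM `(h, c)` tuple.

```python
import numpy as np

LATENT_DIM, HIDDEN_DIM = 32, 8
rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(HIDDEN_DIM, LATENT_DIM + 3))  # toy recurrent weights


def encode(frame):
    """Stand-in for vision.encode: map a frame to a latent vector."""
    return frame.reshape(-1)[:LATENT_DIM].astype(np.float64)


def mdn_step(z, action, h):
    """Stand-in for the MDN-RNN step plus sampling: returns (z_next, h_next)."""
    h_next = np.tanh(W @ np.concatenate([z, action]) + h)
    z_next = np.resize(h_next, LATENT_DIM)  # pretend this is a sample from the mixture
    return z_next, h_next


def decode(z):
    """Stand-in for vision.decode: map a latent back to a (tiny) image."""
    return z.reshape(4, 8)


def dream(first_frame, actions):
    z = encode(first_frame)          # 1. start from an encoded real frame
    h = np.zeros(HIDDEN_DIM)         # fresh hidden state
    frames = []
    for a in actions:
        z, h = mdn_step(z, a, h)     # 2. predict/sample the next latent
        frames.append(decode(z))     # 3. decode it into an image
    return frames                    # 4. one frame per action in the sequence


frames = dream(rng.random((8, 8)), [np.array([0.0, 1.0, 0.0])] * 10)
print(len(frames), frames[0].shape)  # 10 (4, 8)
```

Only the first frame comes from the environment; every later latent is the model's own output, which is why errors can compound over a long dream.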

### **Predict next frame**

In [7]:
import gymnasium as gym
import pygame
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
import cv2
from src.utils.utils import TRANSFORM as transform
from src.models.mdn_rnn import sample_mdn

def run_car_racing_rnn_mda(env_name, vision, mdnrnn, transform, device, scale=1, resolution=(150, 150), tau =1.0):
    
    # Initialize pygame for rendering
    pygame.init()
    # Double the width: the real frame and the predicted next frame are shown side by side
    resolution = (resolution[0] * 2 * scale, resolution[1] * scale)
    screen = pygame.display.set_mode(resolution)
    clock = pygame.time.Clock() 
    
    action = np.zeros(3)  # Initialize action array
    
    def get_action(keys):
        """ Map keyboard input to actions """
        action[0] = -1.0 if keys[pygame.K_LEFT] else 1.0 if keys[pygame.K_RIGHT] else 0.0  # Steering
        action[1] = 1.0 if keys[pygame.K_UP] else 0.0  # Accelerate
        action[2] = 1.0 if keys[pygame.K_DOWN] else 0.0  # Brake
        return action
    
    
    # Initialize the environment
    env = gym.make(env_name, render_mode='rgb_array')
    obs, _ = env.reset()
    
    running = True
    h = mdnrnn.rnn.init_hidden(1) 
    h = (h[0].to(device), h[1].to(device))
    
    cnt = 0
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
        
        keys = pygame.key.get_pressed()  # Get current key states
        action = get_action(keys)        # Update action based on key presses

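        # Environment step (Gymnasium returns terminated and truncated separately)
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated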
        
        # Render and process the frame
        x = transform(obs).unsqueeze(0).to(device)  # Transform frame to tensor

        with torch.no_grad():
            mask, mini_mask, z = vision.encode(x)           
                
            # Predict the next latent state with the RNN-MDN and sample from the mixture
            action_tensor = torch.tensor(action, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

            pi, mu, sigma, h = mdnrnn(z.unsqueeze(0),action_tensor, h=h, tau=tau)
            z_next = sample_mdn(pi, mu, sigma)
            
            # Decode MDN sampled latent vector
            x_hat, mask_hat, mini_mask_hat =  vision.decode(z_next.squeeze(0))
            reconstructed = x_hat

        cnt+=1
        
        # Prepare images for display: CHW -> WHC, since pygame.surfarray expects (width, height, channels)
        reconstructed = (reconstructed.squeeze(0).permute(2, 1, 0).cpu().numpy() * 255).astype(np.uint8)
        obs = (x.squeeze(0).permute(2, 1, 0).cpu().numpy() * 255).astype(np.uint8)
        
        # Place the real frame (left) and the predicted next frame (right) side by side; axis 0 is width here
        full_image = np.concatenate((obs, reconstructed), axis=0)
        full_image_resized = cv2.resize(full_image, (resolution[1], resolution[0]), interpolation=cv2.INTER_LINEAR)
        
        # Display the combined image
        clock.tick(30)
        pygame.surfarray.blit_array(screen, full_image_resized)
        pygame.display.flip()
        
        if done:
            obs, _ = env.reset()  # Reset environment if done
            h = mdnrnn.rnn.init_hidden(1) 
            h = (h[0].to(device), h[1].to(device))
            
    pygame.quit()
    env.close()

run_car_racing_rnn_mda(env_name="CarRacing-v3", vision=vision, mdnrnn=memory, transform=transform, device=device, scale=4, tau=0.1)


### **Dream**

In [8]:
import gymnasium as gym
import pygame
import numpy as np
import torch
import torch.nn.functional as F
import cv2
from src.utils.utils import setup_video_writer
from src.models.mdn_rnn import sample_mdn

def run_car_racing_rnn_mda(env_name, vision, mdnrnn, transform, device, scale=1, resolution=(150, 150), tau=1.0, video_filepath='renders/dream.mp4', save_video=False):
    # Initialize the environment
    env = gym.make(env_name, render_mode='rgb_array')
    obs, _ = env.reset()
    # Optional warm-up: skip N initial frames with a no-op action (N = 0, so currently disabled)
    for _ in range(0):
        env.step(np.array([0, 0, 0]))

    # Initialize pygame for rendering
    pygame.init()
    resolution = (resolution[0] * scale, resolution[1] * scale)
    screen = pygame.display.set_mode(resolution)
    clock = pygame.time.Clock()

    action = np.zeros(3)  # Initialize action array
    video_writer = setup_video_writer(video_filepath, resolution[::-1]) if save_video else None

    def get_action(keys):
        """ Map keyboard input to actions """
        action[0] = -1.0 if keys[pygame.K_LEFT] else 1.0 if keys[pygame.K_RIGHT] else 0.0  # Steering
        action[1] = 1.0 if keys[pygame.K_UP] else 0.0  # Accelerate
        action[2] = 1.0 if keys[pygame.K_DOWN] else 0.0  # Brake
        return action

    running = True
    h = mdnrnn.rnn.init_hidden(1)
    h = (h[0].to(device), h[1].to(device))

    cnt = 0
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False

        keys = pygame.key.get_pressed()  # Get current key states
        action = get_action(keys)  # Update action based on key presses

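        # Environment step (Gymnasium returns terminated and truncated separately)
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        obs_tensor = transform(obs).unsqueeze(0).to(device)  # Transform frame to tensor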

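        with torch.no_grad():
            if cnt == 0:
                # Seed the dream by encoding the first real frame
                mask, mini_mask, z = vision.encode(obs_tensor)
                z = z.unsqueeze(0)
            else:
                # Dream: feed the previously sampled latent back in as input
                z = z_next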

            action_tensor = torch.tensor(action, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
            pi, mu, sigma, h = mdnrnn(z, action_tensor, h=h, tau=tau)
            z_next = sample_mdn(pi, mu, sigma)

            x_hat, mask_hat, mini_mask_hat = vision.decode(z_next.squeeze(0))
            reconstructed = x_hat

        cnt += 1

        # Prepare the reconstructed image for display
        reconstructed = (reconstructed.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        reconstructed_resized = cv2.resize(reconstructed, (resolution[0], resolution[1]))

        # Convert reconstructed image to pygame surface
        frame_surface = pygame.surfarray.make_surface(reconstructed_resized.swapaxes(0, 1))

        # Blit frame to the screen
        screen.blit(frame_surface, (0, 0))

        # Draw key indicators as fixed-size red shapes near the bottom-left corner
        indicator_color = (255, 0, 0)  # Red color for indicators

        # Left arrow (steer left)
        if keys[pygame.K_LEFT]:
            pygame.draw.polygon(screen, indicator_color, [(40, resolution[1] - 40), (20, resolution[1] - 30), (40, resolution[1] - 20)])

        # Right arrow (steer right)
        if keys[pygame.K_RIGHT]:
            pygame.draw.polygon(screen, indicator_color, [(100, resolution[1] - 40), (120, resolution[1] - 30), (100, resolution[1] - 20)])

        # Up arrow (accelerate)
        if keys[pygame.K_UP]:
            pygame.draw.polygon(screen, indicator_color, [(70, resolution[1] - 70), (50, resolution[1] - 50), (90, resolution[1] - 50)])

        # Square (brake)
        if keys[pygame.K_DOWN]:
            pygame.draw.rect(screen, indicator_color, (140, resolution[1] - 60, 40, 40))

        # Update the display
        clock.tick(30)
        pygame.display.flip()

        if save_video and video_writer is not None:
            video_writer.write(cv2.cvtColor(reconstructed_resized, cv2.COLOR_RGB2BGR))

        if done:
            obs, _ = env.reset()  # Reset environment if done
            h = mdnrnn.rnn.init_hidden(1)
            h = (h[0].to(device), h[1].to(device))
            cnt = 0  # Re-seed the dream from the new episode's first frame

    pygame.quit()
    env.close()

    if save_video and video_writer is not None:
        video_writer.release()


run_car_racing_rnn_mda(env_name="CarRacing-v3", vision=vision, mdnrnn=memory, transform=transform, device=device, scale=4, tau=.001, save_video=False)