<a href="https://colab.research.google.com/github/AndrewBoessen/CSCI3387_Notebooks/blob/main/Milestone_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CSCI3387 Project Milestone #2
_Andrew Boessen, Ian Bourgin, Theodore Grace_

# __Simulating Real-Time Game Environments with Neural Networks__

## Project Objective / Definition:

Video games are defined by a _game loop_ in which the system (1) collects user inputs, (2) updates the game state, and (3) renders the screen pixels.

An interactive game environment $\mathcal E$ consists of:

- a latent space $\mathcal S$ of states,
- a set of actions $\mathcal A$,
- a space of observations $\mathcal O$,
- a projection function $V : \mathcal S \rightarrow \mathcal O$, and
- a transition function $p(s|s',a)$, where $s',s \in \mathcal S$ and $a \in \mathcal A$.
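The sketch below makes the three steps of the game loop concrete under this definition. The toy state, action set, `toy_transition` and `toy_render` functions are hypothetical stand-ins for $\mathcal S$, $\mathcal A$, $p$, and $V$: a skier's column on a tiny 1-D slope, not _Skiing_ itself. For simplicity the toy transition is deterministic, whereas $p$ may in general be stochastic.

```python
from typing import Any, Callable, List, Sequence


def run_game_loop(s0: Any, actions: Sequence[Any],
                  transition: Callable[[Any, Any], Any],
                  render: Callable[[Any], Any]) -> List[Any]:
    """Run the game loop for a fixed sequence of inputs and return every rendered frame."""
    observations = []
    s = s0
    for a in actions:                   # (1) collect the user input a
        s = transition(s, a)            # (2) update the latent state via p
        observations.append(render(s))  # (3) render the observation o = V(s)
    return observations


# Hypothetical toy game: the state is the skier's column on a 5-cell-wide slope.
WIDTH = 5
MOVES = {"left": -1, "right": +1, "noop": 0}  # the three actions


def toy_transition(s: int, a: str) -> int:
    return min(max(s + MOVES[a], 0), WIDTH - 1)  # clamp the skier to the slope


def toy_render(s: int) -> str:
    return "".join("S" if i == s else "." for i in range(WIDTH))


frames = run_game_loop(2, ["left", "left", "left", "right", "noop"], toy_transition, toy_render)
print(frames)  # ['.S...', 'S....', 'S....', '.S...', '.S...']
```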

Our project aims to simulate a game environment with a neural network. This can be defined as an _Interactive World Simulation_. Given an environment $\mathcal E$ and an initial state $s_0 \in \mathcal S$, it is a distribution function $q(o_n|o_{<n}, a_{<n})$, where $o_i \in \mathcal O$ and $a_i \in \mathcal A$. The Interactive World Simulation objective is to minimize $\mathbb{E}(D(o^i_q, o^i_p))$. Here $D : \mathcal O \times \mathcal O \rightarrow \mathbb{R}$ is a distance metric between observations, $o^i_q \sim q$, and $o^i_p \sim V(p)$.
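To illustrate the objective, the sketch below computes a Monte-Carlo estimate of $\mathbb{E}(D(o^i_q, o^i_p))$ from paired frames. In each pair, one frame is generated by the simulation $q$ and one is rendered by the real game. Choosing $D$ as the per-pixel mean squared error is an assumption made here for concreteness; the objective only requires some distance metric on $\mathcal O$.

```python
import numpy as np


def mse_distance(o_q: np.ndarray, o_p: np.ndarray) -> float:
    """D(o_q, o_p): per-pixel mean squared error between two frames (an assumed choice of D)."""
    diff = o_q.astype(np.float64) - o_p.astype(np.float64)  # cast avoids uint8 wrap-around
    return float(np.mean(diff ** 2))


def simulation_objective(simulated: np.ndarray, real: np.ndarray) -> float:
    """Average D over N paired frames of shape (N, H, W, C), estimating E[D(o_q, o_p)]."""
    if simulated.shape != real.shape:
        raise ValueError("simulated and real frames must have the same shape")
    return float(np.mean([mse_distance(q, p) for q, p in zip(simulated, real)]))


# Two frames of 210 x 160 RGB pixels, the size of a Skiing observation
real = np.zeros((2, 210, 160, 3), dtype=np.uint8)
simulated = real.copy()
simulated[1] += 2  # the second simulated frame is off by 2 intensity levels everywhere
print(simulation_objective(simulated, real))  # 2.0
```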

## Atari Skiing

For our project, we have chosen to simulate _Skiing_, which is a game released for the Atari 2600. _Skiing_ is a simple game where the objective is to reach the bottom of the ski course in the least amount of time. To reach the bottom of the course, players must dodge obstacles and pass through a series of gates, indicated by flagpoles.

<div>
<img src="https://drive.google.com/uc?export=view&id=1aVJjOk12eTrPt7HU0Sc-5dMra27mSD8t" height="300">
</div>

We chose this game because of its simple graphics, small environment, and limited set of actions. In the case of _Skiing_, the environment $\mathcal E$ is defined as follows:

- $\mathcal S$ is the program's dynamic memory contents (i.e., the game variables).
- $\mathcal O$ is the rendered screen pixels.
- $V$ is the game's rendering logic.
- $\mathcal A$ contains three actions: move left, move right, and no action.
- $p$ is the game's logic.

_Skiing_ contains ten levels of varying difficulty and length. Some levels contain gates and others do not. For more information about _Skiing_, see [the AtariAge page](https://atariage.com/manual_html_page.php?SoftwareLabelID=434).







# __Model Architecture__

To simulate the game environment, we use a spatiotemporal (ST) transformer [(Xu et al. 2021)](https://arxiv.org/pdf/2001.02908). It autoregressively models the game's environment from a single stream of action and image tokens.

Following [(Ramesh et al., 2021)](https://arxiv.org/pdf/2102.12092), we use a Vector Quantized Variational Autoencoder (VQ-VAE) [(Van Den Oord et al., 2017)](https://arxiv.org/pdf/1711.00937) to encode images into a discrete latent space. A decoder-only transformer then autoregressively predicts the next frame using MaskGIT [(Chang et al. 2022)](https://arxiv.org/pdf/2202.04200). The VQ-VAE decoder turns the predicted tokens into a new image, which is shown to the user as the next frame of the simulation.

Our task is very similar to video generation. Like VideoGPT [(Yan et al. 2021)](https://arxiv.org/pdf/2104.10157), our architecture uses a VQ-VAE and a transformer to generate sequences of images. Our model differs in two ways. First, it performs real-time video generation (~20 fps). Second, the user's actions only become available as generation proceeds, so it must generate video autoregressively, one frame at a time.
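MaskGIT decodes a frame's tokens in parallel over a few refinement steps rather than one token at a time. The sketch below is a simplified, hypothetical version of this iterative decoding with MaskGIT's cosine masking schedule. Its assumptions:

- `logits_fn` stands in for the ST-transformer conditioned on past frames and actions.
- `mask_id` is a placeholder value for the mask token.
- Refinements from the paper, such as temperature-annealed confidence sampling, are omitted.

```python
import math

import torch


def maskgit_decode(logits_fn, num_tokens, num_steps=8, mask_id=-1, generator=None):
    """Fill in one frame of `num_tokens` discrete VQ-VAE codes, MaskGIT-style (simplified sketch).

    logits_fn(tokens) must return (num_tokens, K) logits; masked positions hold `mask_id`.
    """
    tokens = torch.full((num_tokens,), mask_id, dtype=torch.long)  # start fully masked
    for step in range(num_steps):
        masked = tokens == mask_id
        probs = torch.softmax(logits_fn(tokens), dim=-1)                      # (N, K)
        sampled = torch.multinomial(probs, 1, generator=generator).squeeze(1)  # (N,)
        confidence = probs.gather(1, sampled.unsqueeze(1)).squeeze(1)
        # Tokens decided in earlier steps get infinite confidence so they are never re-masked
        confidence = torch.where(masked, confidence, torch.full_like(confidence, float("inf")))
        candidate = torch.where(masked, sampled, tokens)
        # Cosine schedule: number of tokens that remain masked after this step (0 at the last step)
        num_masked = math.floor(num_tokens * math.cos(math.pi / 2 * (step + 1) / num_steps))
        if num_masked > 0:
            # Re-mask the least confident of the newly sampled tokens
            remask = torch.argsort(confidence)[:num_masked]
            candidate[remask] = mask_id
        tokens = candidate
    return tokens


# Usage with a dummy "model" that ignores its input and returns random logits over K = 16 codes
torch.manual_seed(0)
frame_tokens = maskgit_decode(lambda t: torch.randn(t.shape[0], 16), num_tokens=20, num_steps=4)
```

Each step commits the most confident predictions and re-predicts the rest. A whole frame therefore costs only `num_steps` transformer calls instead of one call per token, which is what makes a ~20 fps target plausible.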

We chose this approach because it allows us to generate long sequences of images while maintaining continuity between frames and respecting user inputs. For example, the transformer can attend to tokens from past images within its context window.

It also suits real-time interaction. GameNGen [(Valevski et al., 2024)](https://arxiv.org/pdf/2408.14837) uses a computationally expensive diffusion model to simulate the game environment. Our approach should be able to generate more images per second, which enables real-time interaction with the simulation. Encoding images with a VQ-VAE lets the transformer work in a lower-dimensional latent space, which further reduces the computation needed to generate an image. Finally, a traditional transformer [(Vaswani et al. 2017)](https://arxiv.org/pdf/1706.03762) has a quadratic memory cost. The ST-transformer is more memory efficient and balances model performance with computational constraints.
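A back-of-the-envelope count shows where the ST-transformer's savings come from. The sketch below counts query-key pairs per attention layer in two cases: full attention over all $T \cdot H \cdot W$ tokens, and one spatial plus one temporal layer. The clip length and token grid ($T = 16$, $H = W = 16$) are illustrative assumptions, not our model's actual configuration.

```python
def attention_pairs(T: int, H: int, W: int) -> tuple:
    """Query-key pairs scored by full attention vs. one spatial + one temporal attention layer."""
    full = (T * H * W) ** 2          # every token attends to every other token
    spatial = T * (H * W) ** 2       # each frame attends over its own H*W tokens
    temporal = (H * W) * T ** 2      # each spatial position attends over the T time steps
    return full, spatial + temporal


full, st = attention_pairs(T=16, H=16, W=16)
print(full, st, round(full / st, 1))  # 16777216 1114112 15.1
```

Doubling $T$ quadruples the full-attention count but only doubles the spatial term. The spatial term dominates the ST cost as long as $T < H \cdot W$.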

## Image Encoder (VQ-VAE)

The VQ-VAE is an encoder-decoder network that encodes an image into a discrete latent space. It is closely related to the VAE, which consists of three parts:

- an encoder, which parameterizes a posterior distribution $q(z|x)$ of the latent random variable $z$ given input $x$,
- a prior distribution $p(z)$, and
- a decoder with a distribution $p(x|z)$ over the input data.

In a VAE, the posterior and prior are typically normal distributions. The VQ-VAE instead uses vector quantization to obtain categorical prior and posterior distributions. Samples drawn from these distributions index an embedding table, and the resulting embeddings are used as inputs to the decoder network.

<div>
<img src="https://drive.google.com/uc?export=view&id=1icVx2q_agOGYHNsDEDSQqZq5U5X6QPxi" width="600">
</div>

The latent embedding space $e$ is defined as $e \in \mathbb{R}^{K \times D}$. Here $K$ is the size of the discrete latent space (i.e. a $K$-way categorical distribution) and $D$ is the embedding dimension. To produce a latent encoding, the encoder takes an input $x$ and produces an output $z_e(x)$. A nearest-neighbor look-up in the shared embedding space $e$ gives the index $i$ of the closest embedding. The embedding table then returns the embedding $e_i$, where $1 \leq i \leq K$.

The posterior categorical distribution $q(z|x)$ is defined as a one-hot encoding:

$$q(z=k|x) = \begin{cases}
1, & \text{for } k = \text{argmin}_j ||z_e(x)-e_j||_2 \\
0, & \text{otherwise}
\end{cases}$$

where $z_e(x)$ is the output of the encoder network. This output passes through the discretisation bottleneck and is mapped to its nearest embedding. This mapping, $z_q(x)$, is described as:

$$z_q(x) = e_k,
\text{where } k = \text{argmin}_j ||z_e(x) - e_j||_2$$
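The nearest-neighbor lookup in these equations can be written in a few lines of PyTorch. This minimal sketch (the function name `quantize` is ours) works on a batch of flattened encoder outputs. It returns the index $k$, the one-hot posterior $q(z|x)$, and the quantized vector $z_q(x) = e_k$. The `VectorQuantizerEMA` module later in this notebook performs the same lookup on feature maps and additionally updates the codebook.

```python
import torch
import torch.nn.functional as F


def quantize(z_e: torch.Tensor, codebook: torch.Tensor):
    """Map encoder outputs to their nearest codebook embeddings.

    z_e:      (N, D) encoder outputs z_e(x)
    codebook: (K, D) embedding table e
    Returns (k, one_hot, z_q) with shapes (N,), (N, K), (N, D).
    """
    # Squared Euclidean distance to every embedding, shape (N, K); squaring keeps the argmin
    distances = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(dim=-1)
    k = distances.argmin(dim=1)                                      # k = argmin_j ||z_e(x) - e_j||_2
    one_hot = F.one_hot(k, num_classes=codebook.shape[0]).float()    # q(z = k | x)
    z_q = codebook[k]                                                # z_q(x) = e_k
    return k, one_hot, z_q


codebook = torch.tensor([[0.0, 0.0], [10.0, 10.0], [0.0, 10.0]])  # K = 3, D = 2
z_e = torch.tensor([[1.0, 1.0], [9.0, 8.0], [1.0, 9.0]])
k, one_hot, z_q = quantize(z_e, codebook)
print(k.tolist())  # [0, 1, 2]
```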

## Generative ST-Transformer Model

In our model, we use a transformer to generate new images by autoregressively sampling from the latent space $z$. Inspired by Genie [(Bruce et al. 2024)](https://arxiv.org/pdf/2402.15391), we use an ST-transformer for this task.

<div>
<img src="https://drive.google.com/uc?export=view&id=1SJl-txgiz32S9ZTLlRPGq-AKL8U2hkpS" width="600">
</div>

In a traditional transformer, every token attends to all others. An ST-transformer instead contains $L$ spatiotemporal blocks with interleaved spatial and temporal attention layers, each followed by a feed-forward layer (FFW) as in standard attention blocks. The spatial self-attention layer attends over the $1 \times H \times W$ tokens within each time step. The temporal layer attends over the $T \times 1 \times 1$ tokens across the $T$ time steps. Crucially, the dominant computational cost in this architecture, the spatial attention layer, scales linearly with the number of frames rather than quadratically. This makes it much more efficient for video generation with consistent dynamics over extended interactions.
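A single spatiotemporal block can be sketched in PyTorch as follows. This is a minimal illustration, not our final model, and several details are assumptions:

- pre-norm residual connections and a 4x-wide GELU feed-forward layer,
- a causal mask on the temporal attention, so that frame $t$ only attends to frames $\leq t$,
- the class name `STBlock`, which is hypothetical.

```python
import torch
import torch.nn as nn


class STBlock(nn.Module):
    """Spatial attention within each frame, causal temporal attention across frames, then FFW."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.norm_s = nn.LayerNorm(dim)
        self.spatial_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm_t = nn.LayerNorm(dim)
        self.temporal_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm_f = nn.LayerNorm(dim)
        self.ffw = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, N, D) token embeddings, where N = H * W tokens per frame
        B, T, N, D = x.shape
        # Spatial attention: each of the B*T frames attends over its own 1 x H x W tokens
        s = self.norm_s(x).reshape(B * T, N, D)
        x = x + self.spatial_attn(s, s, s, need_weights=False)[0].reshape(B, T, N, D)
        # Temporal attention: each of the B*N positions attends over its T x 1 x 1 tokens
        t = self.norm_t(x).permute(0, 2, 1, 3).reshape(B * N, T, D)
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        t = self.temporal_attn(t, t, t, attn_mask=causal, need_weights=False)[0]  # True = blocked
        x = x + t.reshape(B, N, T, D).permute(0, 2, 1, 3)
        # Position-wise feed-forward layer
        return x + self.ffw(self.norm_f(x))


torch.manual_seed(0)
block = STBlock(dim=32, num_heads=4)
x = torch.randn(2, 5, 12, 32)  # batch of 2 clips, 5 frames, 12 tokens per frame
y = block(x)
```

Each attention call only sees sequences of length $H \cdot W$ or $T$, never $T \cdot H \cdot W$. This is where the memory savings over full attention come from.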



# __Data Collection via RL Agent__

The goal of our project is a simulation that human players can interact with in real time, so our training data should reflect _human gameplay_. Unfortunately, there is no large, uniform dataset of human gameplay for Atari games, so we need an alternate source of data that approximates it. Similar to GameNGen [(Valevski et al., 2024)](https://arxiv.org/pdf/2408.14837), we train an RL agent to collect sequences of gameplay data by interacting with a game environment. A typical RL setup tries to maximize the game score, but our goal is to generate training data. The data the RL agent collects should therefore be diverse and cover a variety of scenarios to maximize training data efficiency.

Because _Skiing_ is so simple, a very simple reward function is enough to achieve this goal. The objective of the game is to reach the end of the course in the least amount of time possible, so the agent's reward is the negative of the elapsed in-game time. Gymnasium's _Skiing_ environment assigns negative rewards for elapsed seconds and penalizes missed gates as additional seconds. To build a complete dataset, we record all of the agent's training trajectories throughout the entire training process.

We use the [Gymnasium](https://gymnasium.farama.org/index.html) library to provide the game environment from which the RL agent collects data. This library includes many Atari game environments, including [Skiing](https://gymnasium.farama.org/environments/atari/skiing/). In this environment, the action space $\mathcal A$ contains 3 discrete actions, and the observation space is $\mathcal O = [0, 255]^{H \times W \times C}$, where $H = 210$, $W = 160$, and $C = 3$. As it trains, the agent collects sequences of observations, which we then use as training data for the generative model.
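The generative model learns $q(o_n|o_{<n}, a_{<n})$, so each recorded trajectory must keep observations and actions aligned. Below is a minimal, hypothetical helper for recording one episode through Gymnasium's `reset`/`step` interface. `record_episode` and the `policy` callable are illustrative names of our own; during training, `policy` would be the agent's epsilon-greedy policy.

```python
import numpy as np


def record_episode(env, policy, max_steps=10_000):
    """Roll out one episode and return aligned observation, action, and reward arrays.

    Assumes the Gymnasium API: reset() -> (obs, info) and
    step(a) -> (obs, reward, terminated, truncated, info).
    """
    obs, _ = env.reset()
    observations, actions, rewards = [obs], [], []
    for _ in range(max_steps):
        action = policy(obs)
        obs, reward, terminated, truncated, _ = env.step(action)
        observations.append(obs)
        actions.append(action)
        rewards.append(reward)
        if terminated or truncated:
            break
    # Observations o_0..o_n have one more entry than actions a_0..a_{n-1}
    return np.stack(observations), np.array(actions), np.array(rewards)
```

With Gymnasium installed, you could run `env = gym.make('ALE/Skiing-v5')` and then `record_episode(env, lambda obs: env.action_space.sample())` to record one episode of random play.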

## RL Agent Method

# __Re: Milestone #1__

Based on feedback from Milestone #1, we have changed multiple parts of our model's architecture but kept the same overall objective for the project. In our last milestone, we proposed to simulate the game DOOM using a diffusion model trained on data collected from an RL agent. DOOM's game environment is complex and a diffusion model is computationally expensive. We have therefore decided to simulate a simpler Atari game instead, and have modified the architecture to fit within our computational constraints. Above, we propose a new architecture that differs from GameNGen and justify why it is a better fit for our objective. We have also chosen to train the RL agent using Deep Q-Learning rather than the Proximal Policy Optimization algorithm used in GameNGen.

# __Contribution__

## Andrew Boessen

- Researched new architectural changes from Milestone #1
- Wrote report on Project Definition, Model Architecture, VQ-VAE, ST-Transformer, and Data Collection

## Ian Bourgin

- Researched RL methods and data collection
- Created PyTorch implementation of Deep Q-Learning

## Theodore Grace

- Created PyTorch implementation of VQ-VAE network

# _Dependencies_

In [None]:
!pip install umap-learn
!pip install gymnasium
!pip install "gymnasium[atari, accept-rom-license]"



# _Imports_

In [None]:
from __future__ import print_function

%matplotlib inline
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import torch.optim as optim

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid

import numpy as np
from six.moves import xrange
import umap
from collections import deque
import random

import gymnasium as gym
from gymnasium.wrappers import RecordVideo
from IPython import display
from IPython.display import HTML
from base64 import b64encode

# __VQ-VAE Implementation__

The PyTorch implementation of the VQ-VAE network is given below.

In [None]:
class VectorQuantizerEMA(nn.Module):
    '''
    Implements Exponential Moving Average (EMA) vector quantization for VQ-VAE models.
    '''
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self._embedding_dim = embedding_dim # Dimension of the embedding, D
        self._num_embeddings = num_embeddings # Number of categories in distribution, K

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) # Embedding table (no trailing comma, which would make this a tuple)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost # Constant used in loss function

        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)

            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n)

            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

In [None]:
class Residual(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=num_residual_hiddens,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(in_channels=num_residual_hiddens,
                      out_channels=num_hiddens,
                      kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        return x + self._block(x)


class ResidualStack(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
                             for _ in range(self._num_residual_layers)])

    def forward(self, x):
        for i in range(self._num_residual_layers):
            x = self._layers[i](x)
        return F.relu(x)

In [None]:
class Encoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Encoder, self).__init__()

        self._conv1 = nn.Conv2d(in_channels=in_channels, out_channels=num_hiddens//2,
                                 kernel_size=4,
                                 stride=2, padding=1)

        self._conv2 = nn.Conv2d(in_channels=num_hiddens//2, out_channels=num_hiddens,
                                 kernel_size=4,
                                 stride=2, padding=1)

        self._conv3 = nn.Conv2d(in_channels=num_hiddens, out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)

        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

    def forward(self, inputs):
        x = self._conv1(inputs)
        x = F.relu(x)
        x = self._conv2(x)
        x = F.relu(x)
        x = self._conv3(x)
        return self._residual_stack(x)

In [None]:
class Decoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Decoder, self).__init__()

        self._conv1 = nn.Conv2d(in_channels=in_channels, out_channels=num_hiddens,
                                kernel_size=3, stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens, num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)
        self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
                                                out_channels=num_hiddens // 2,
                                                kernel_size=4, stride=2, padding=1)
        self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens // 2,
                                                out_channels=3,
                                                kernel_size=4, stride=2, padding=1)

    def forward(self, inputs):
        x = self._conv1(inputs)
        x = self._residual_stack(x)
        x = self._conv_trans_1(x)
        x = F.relu(x)
        return self._conv_trans_2(x)

In [None]:
class VQVAE(nn.Module):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, decay):
        super(VQVAE, self).__init__()

        self._encoder = Encoder(3, num_hiddens, num_residual_layers, num_residual_hiddens)
        self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens, out_channels=embedding_dim,
                                      kernel_size=1, stride=1)

        self._vq = VectorQuantizerEMA(num_embeddings, embedding_dim, commitment_cost, decay)
        self._decoder = Decoder(embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)

    def forward(self, x):
        z = self._encoder(x) # encode image to latent
        z = self._pre_vq_conv(z)
        loss, quantized, perplexity, _ = self._vq(z) # quantize encoding to discrete space
        x_recon = self._decoder(quantized) # reconstruction of input from decoder
        return loss, x_recon, perplexity

# RL Implementation (in progress)



In [None]:
# Deep Q Network
#
# This model is used to approximate the Q function

class DQN(nn.Module):
    def __init__(self, input_shape, num_actions):
        super(DQN, self).__init__()
        self.input_shape = input_shape
        self.num_actions = num_actions
        self.features = nn.Sequential(
            # Extract feature vector from input observation
            # This uses multiple CNN layers to get a feature representation
            nn.Conv2d(input_shape[2], 32, kernel_size=8, stride=4), # Convert from 3 to 32 channels, downsample by 4x
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), # Convert from 32 to 64 channels, downsample by 2x
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten() # Flatten 2d grid to vector
        )

        # Calculate the size of the output from the convolutional layers
        conv_out_size = self._get_conv_out_size(input_shape)

        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512), # Linear projection to feature space
            nn.ReLU(),
            nn.Linear(512, num_actions) # Linear projection to action space
        )

    def _get_conv_out_size(self, shape):
        o = self.features(torch.zeros(1, shape[2], shape[0], shape[1]))
        return int(torch.prod(torch.tensor(o.size()))) # Get size of flattened vector

    def forward(self, x):
        x = x.permute(0, 3, 1, 2) # Permute to (batch_size, channels, height, width)
        features = self.features(x)
        return self.fc(features)

In [None]:
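def optimize(self, mini_batch, policy_network, target_network):
  '''
  Perform one Deep Q-Learning gradient step on a mini batch sampled from replay memory.
  `self` is the SkiingDQL agent (for discount_factor, loss and optimizer). Each experience is a
  tuple (state, action, new_state, reward, terminated) where states are raw (H, W, C) uint8 frames.
  '''
  states, actions, new_states, rewards, terminated = zip(*mini_batch)

  # Stack observations into (B, H, W, C) float tensors scaled to [0, 1]
  states = torch.as_tensor(np.stack(states), dtype=torch.float32) / 255.0
  new_states = torch.as_tensor(np.stack(new_states), dtype=torch.float32) / 255.0
  actions = torch.as_tensor(actions, dtype=torch.int64)
  rewards = torch.as_tensor(rewards, dtype=torch.float32)
  terminated = torch.as_tensor(terminated, dtype=torch.float32)

  # Q-values predicted by the policy network for the actions that were actually taken
  current_q = policy_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

  # Bellman target from the target network: r + gamma * max_a' Q_target(s', a'),
  # without bootstrapping past terminal states
  with torch.no_grad():
    next_q = target_network(new_states).max(dim=1).values
    target_q = rewards + self.discount_factor * next_q * (1.0 - terminated)

  # Compute the loss for the whole mini batch and optimize the policy network once
  loss = self.loss(current_q, target_q)
  self.optimizer.zero_grad()
  loss.backward()
  self.optimizer.step()
  return loss.item()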



In [None]:
# Defining memory class for experience replay
from collections import deque
import random

class ReplayMemory(): # Deque containing some past experience where experience is a tuple (state,action,new_state,reward,terminated)
  def __init__(self, maxlen):
    self.memory = deque([],maxlen)

  def append(self, experience): #add experience to deque
    self.memory.append(experience)

  def sample(self, batch_size):  # return a random sample of experiences; random sampling breaks the correlation between consecutive time steps
    return random.sample(self.memory, batch_size)

  def __len__(self): #returns length of deque
    return len(self.memory)

class SkiingDQL():
  # Hyperparameters (could be turned into constructor arguments later)
  lr = 0.001
  discount_factor = 0.9
  network_sync_rate = 10  # number of steps the agent takes before syncing the policy and target networks
  replay_memory_size = 1000
  mini_batch_size = 32  # size of the batch sampled from replay memory

  loss = nn.MSELoss()

  # Gymnasium's Skiing action indices: 0 = NOOP, 1 = RIGHT, 2 = LEFT
  ACTIONS = ["No_Action", "Right", "Left"]

  def train(self, episodes):
    env = gym.make('ALE/Skiing-v5', render_mode="rgb_array")
    num_actions = env.action_space.n  # 3 discrete actions
    input_space = env.observation_space.shape  # (210, 160, 3)

    epsilon = 1  # start with fully random actions
    memory = ReplayMemory(self.replay_memory_size)  # replay memory of past experience

    # Create the policy and target networks and copy the weights so they start out identical
    policy_network = DQN(input_space, num_actions)
    target_network = DQN(input_space, num_actions)
    target_network.load_state_dict(policy_network.state_dict())

    # The optimizer must be instantiated with the policy network's parameters
    self.optimizer = torch.optim.Adam(policy_network.parameters(), lr=self.lr)

    # Total reward per episode (Skiing rewards are negative, so values closer to 0 are better)
    rewards_per_episode = np.zeros(episodes)

    # Keep track of epsilon decay; the epsilon-greedy policy trades off exploration and exploitation
    epsilon_decays = []

    # Track the number of steps taken to know when to sync networks
    step_count = 0

    for i in range(episodes):
      state, _ = env.reset()  # reset() returns (observation, info)
      terminated = False  # True when the skier reaches the bottom of the course
      truncated = False  # True when the environment cuts the episode off (e.g. a step limit)

      while not (terminated or truncated):
        # Epsilon-greedy action selection: full exploration at first, mostly greedy later
        if random.random() < epsilon:
          action = env.action_space.sample()  # random action
        else:
          with torch.no_grad():
            state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0) / 255.0
            action = policy_network(state_t).argmax().item()  # best action

        # Perform the action
        new_state, reward, terminated, truncated, _ = env.step(action)

        # Save the experience to memory, then move to the new state
        memory.append((state, action, new_state, reward, terminated))
        state = new_state
        step_count += 1
        rewards_per_episode[i] += reward

        # Train once enough experience has been collected
        if len(memory) > self.mini_batch_size:
          mini_batch = memory.sample(self.mini_batch_size)
          optimize(self, mini_batch, policy_network, target_network)

        # Periodically sync the target network with the policy network
        if step_count > self.network_sync_rate:
          target_network.load_state_dict(policy_network.state_dict())
          step_count = 0

      # Linearly decay epsilon once per episode
      epsilon = max(epsilon - 1 / episodes, 0)
      epsilon_decays.append(epsilon)

    env.close()
    return policy_network, rewards_per_episode, epsilon_decays

