In [1]:
from einops import rearrange
import numpy as np
import gymnasium as gym
import ale_py
import torch
from torch import nn, optim

from math import prod

  File "/home/fitti/.conda/envs/puffer/lib/python3.11/site-packages/gymnasium/envs/registration.py", line 594, in load_plugin_envs
    fn()
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/site-packages/shimmy/registration.py", line 304, in register_gymnasium_envs
    _register_atari_envs()
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/site-packages/shimmy/registration.py", line 205, in _register_atari_envs
    import ale_py
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/site-packages/ale_py/__init__.py", line 66, in <module>
    register_v0_v4_envs()
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/site-packages/ale_py/registration.py", line 176, in register_v0_v4_envs
    _register_rom_configs(legacy_games, obs_types, versions)
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/site-packages/ale_py/registration.py", line 62, in _register_rom_configs
    gymnasium.register(
    ^^^^^^^^^^^^^^^^^^
AttributeError: partially initialized module 'gymnasium' has 

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
env = gym.make("ALE/Breakout-v5")
env = gym.wrappers.TransformObservation(env, lambda x: rearrange(x, "h w c -> c h w"))

A.L.E: Arcade Learning Environment (version 0.9.0+750d7f9)
[Powered by Stella]


In [4]:
ob_shape, n_actions = env.observation_space.shape, env.action_space.n
ob_shape = (ob_shape[2], ob_shape[0], ob_shape[1])
ob_shape, n_actions

((3, 210, 160), 4)

In [5]:
class EncoderBlock(nn.Module):
    def __init__(self, in_channels, first_channels, second_channels, out_channels):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, first_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(first_channels, second_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(second_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        return self.encoder(x)

In [6]:
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, first_channels, second_channels, out_channels, out_size=None):
        super().__init__()
        if out_size:
            upsampler = nn.UpsamplingBilinear2d(size=out_size)
        else:
            upsampler = nn.UpsamplingBilinear2d(scale_factor=2)
        self.decoder = nn.Sequential(
            upsampler,
            nn.Conv2d(in_channels, first_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(first_channels, second_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(second_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.decoder(x)

In [7]:
class Encoder(nn.Module):
    def __init__(self, ob_shape):
        super().__init__()
        c, h, w = ob_shape
        self.encoder = nn.Sequential(
            EncoderBlock(c, 16, 32, 32),
            EncoderBlock(32, 48, 48, 64),
            EncoderBlock(64, 128, 128, 64),
            EncoderBlock(64, 48, 32, 32),
        )

        dummy_in = torch.zeros((1,) + tuple(ob_shape))
        self.out_shape = self.encoder(dummy_in).shape[1:]

    def forward(self, x):
        return self.encoder(x)

In [8]:
enc = Encoder(ob_shape)
enc.out_shape

torch.Size([32, 13, 10])

In [9]:
class Decoder(nn.Module):
    def __init__(self, ob_shape):
        super().__init__()
        c, h, w = ob_shape
        self.decoder = nn.Sequential(
            DecoderBlock(32, 32, 48, 64),
            DecoderBlock(64, 32, 16, c, out_size=[h, w]),
        )

    def forward(self, x):
        return self.decoder(x)

In [10]:
class RecurrentNetwork(nn.Module):
    def __init__(self, in_shape, n_hiddens):
        super().__init__()
        self.n_hiddens = n_hiddens
        self.step = nn.Linear(prod(in_shape) + n_hiddens, n_hiddens)

    def init_memory(self):
        return torch.zeros(self.n_hiddens, dtype=torch.float32).unsqueeze(0)

    def forward(self, state, memory):
        x = torch.cat((state.flatten(1), memory), dim=1)
        return self.step(x)

In [11]:
class PredictionNetwork(nn.Module):
    def __init__(self, n_hiddens, out_shape):
        super().__init__()
        self.input = nn.Linear(1, prod(out_shape))
        self.step = nn.Linear(n_hiddens, prod(out_shape))

    def forward(self, memory, action):
        return self.step(memory) + self.input(action)

In [12]:
class PolicyNetwork(nn.Module):
    def __init__(self, n_hiddens, n_actions):
        super().__init__()
        self.head = nn.Linear(n_hiddens, n_actions)

    def forward(self, x):
        return self.head(x)

In [13]:
class ValueNetwork(nn.Module):
    def __init__(self, n_hiddens):
        super().__init__()
        self.head = nn.Linear(n_hiddens, 1)

    def forward(self, x):
        return self.head(x)

In [14]:
import random

In [15]:
def select_action(logits):
    dist = torch.distributions.Categorical(logits=logits) # shape [batch_size, n_actions]
    
    action = dist.sample() # Shape: [n_actions]
    log_prob = dist.log_prob(action)

    return action, log_prob

In [16]:
n_hiddens = 128
gamma = 0.99
enc_lr = 6e-5
dec_lr = 5e-5
rec_lr = 4e-5
prd_lr = 3e-5
val_lr = 2e-5
pol_lr = 1e-5

enc_net = Encoder(ob_shape).to(device=device)
dec_net = Decoder(ob_shape).to(device=device)
rec_net = RecurrentNetwork(enc_net.out_shape, n_hiddens).to(device=device)
prd_net = PredictionNetwork(n_hiddens, enc_net.out_shape).to(device=device)
pol_net = PolicyNetwork(n_hiddens, n_actions).to(device=device)
val_net = ValueNetwork(n_hiddens).to(device=device)

In [17]:
mse_loss_fn = nn.MSELoss()

enc_optimizer = optim.Adam(enc_net.parameters(), lr=enc_lr)
dec_optimizer = optim.Adam(dec_net.parameters(), lr=dec_lr)
rec_optimizer = optim.Adam(rec_net.parameters(), lr=rec_lr)
prd_optimizer = optim.Adam(prd_net.parameters(), lr=prd_lr)
val_optimizer = optim.Adam(val_net.parameters(), lr=val_lr)
pol_optimizer = optim.Adam(pol_net.parameters(), lr=pol_lr)

In [18]:
ob, _ = env.reset()
ob = torch.from_numpy(ob).to(device=device, dtype=torch.float32, non_blocking=True).unsqueeze(0)
memory = rec_net.init_memory().to(device=device, non_blocking=True)

ret = 0
best_ret = 0
episodes = 0
while True:
    state = enc_net(ob)
    memory = rec_net(state.detach(), memory.detach())
    
    decoded = dec_net(state)
    value = val_net(memory)
    logits = pol_net(memory)
    action, log_prob = select_action(logits)
    pred_state = prd_net(memory.detach(), action.detach().to(dtype=torch.float32))
    
    next_ob, reward, done, truncated, _ = env.step(action)
    
    if done or truncated:
        next_ob, _ = env.reset()
        
    next_ob = torch.from_numpy(next_ob).to(device=device, dtype=torch.float32, non_blocking=True).unsqueeze(0)
    ret += reward

    if done or truncated:
        episodes += 1
        if ret > best_ret:
            best_ret = ret
        print("Episode", episodes, "Best return:", best_ret, end="\t\t\r")
        ret = 0

    with torch.no_grad():
        next_state = enc_net(next_ob)
        next_memory = rec_net(next_state, memory)
        next_value = val_net(next_memory)
        
    surprise = mse_loss_fn(next_state.flatten(1), pred_state)
    td_target = (reward + surprise.detach()) + gamma * next_value * ~(done | truncated)
    
    with torch.no_grad():
        advantage = (td_target - value)
    
    dec_loss = mse_loss_fn(next_ob, decoded)
    prd_loss = surprise
    val_loss = mse_loss_fn(td_target, value)
    pol_loss = -log_prob * advantage

    loss = dec_loss + 0.5 * prd_loss + 0.25 * val_loss + 0.125 * pol_loss

    enc_optimizer.zero_grad()
    dec_optimizer.zero_grad()
    rec_optimizer.zero_grad()
    prd_optimizer.zero_grad()
    val_optimizer.zero_grad()
    pol_optimizer.zero_grad()
    loss.backward()
    enc_optimizer.step()
    dec_optimizer.step()
    rec_optimizer.step()
    prd_optimizer.step()
    val_optimizer.step()
    pol_optimizer.step()

    ob = next_ob

Episode 2 Best return: 7.0		

ValueError: Expected parameter logits (Tensor of shape (1, 4)) of distribution Categorical(logits: torch.Size([1, 4])) to satisfy the constraint IndependentConstraint(Real(), 1), but found invalid values:
tensor([[nan, nan, nan, nan]], device='cuda:0', grad_fn=<SubBackward0>)