In [15]:
from einops import rearrange
import numpy as np
import gymnasium as gym
import ale_py
import torch
from torch import nn

from math import prod

In [16]:
def _to_channels_first(frame):
    """Reorder one (H, W, C) frame to the channels-first (C, H, W) layout nn.Conv2d expects."""
    return rearrange(frame, "h w c -> c h w")


# NOTE(review): with the gymnasium version used here the wrapper appears to leave
# env.observation_space reporting the raw (H, W, C) layout, which is why the
# observation shape is permuted by hand afterwards — confirm when upgrading gymnasium.
env = gym.wrappers.TransformObservation(gym.make("ALE/Breakout-v5"), _to_channels_first)

In [17]:
# env.observation_space still describes the raw (H, W, C) frame, while the
# networks receive channels-first frames, so reorder the axes to (C, H, W).
raw_shape = env.observation_space.shape
ob_shape = tuple(raw_shape[axis] for axis in (2, 0, 1))
n_actions = env.action_space.n
ob_shape, n_actions

((3, 210, 160), 4)

In [18]:
class EncoderBlock(nn.Module):
    """Three 3x3 conv + ReLU layers followed by a 2x2 max-pool.

    The convolutions use padding=1, so they keep the spatial size; the final
    MaxPool2d(2, 2) halves height and width (rounding odd sizes down).

    Args:
        in_channels: channels of the incoming feature map.
        first_channels: output channels of the first convolution.
        second_channels: output channels of the second convolution.
        out_channels: output channels of the third convolution (and of the block).
    """

    def __init__(self, in_channels, first_channels, second_channels, out_channels):
        super().__init__()
        widths = [in_channels, first_channels, second_channels, out_channels]
        layers = []
        # Each consecutive pair of widths becomes one conv followed by its ReLU.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1), nn.ReLU()]
        layers.append(nn.MaxPool2d(2, 2))
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Map (N, in_channels, H, W) to (N, out_channels, H // 2, W // 2)."""
        return self.encoder(x)

In [19]:
class DecoderBlock(nn.Module):
    """Bilinear upsampling followed by three 3x3 conv + ReLU layers.

    The convolutions use padding=1, so the spatial size after the block is
    the size produced by the upsampler.

    Args:
        in_channels: channels of the incoming feature map.
        first_channels: output channels of the first convolution.
        second_channels: output channels of the second convolution.
        out_channels: output channels of the third convolution (and of the block).
        out_size: optional target (height, width) for the upsampler. When None,
            the input is instead scaled by ``scale_factor``.
        scale_factor: upsampling factor used when ``out_size`` is None (default 2).
    """

    def __init__(self, in_channels, first_channels, second_channels, out_channels,
                 out_size=None, scale_factor=2):
        super().__init__()
        # nn.UpsamplingBilinear2d is deprecated; nn.Upsample with
        # mode="bilinear", align_corners=True is its documented equivalent.
        if out_size is not None:
            upsampler = nn.Upsample(size=out_size, mode="bilinear", align_corners=True)
        else:
            upsampler = nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=True)
        self.decoder = nn.Sequential(
            upsampler,
            nn.Conv2d(in_channels, first_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(first_channels, second_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(second_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Upsample ``x`` (N, in_channels, H, W) and convolve to ``out_channels``."""
        return self.decoder(x)

In [26]:
class Encoder(nn.Module):
    """Stack of four EncoderBlocks mapping an observation to a feature map.

    Each block halves height and width (rounding down), so an observation of
    shape (c, h, w) is encoded to a 32-channel map of shape (32, h // 16, w // 16).

    Args:
        ob_shape: observation shape as (channels, height, width), without a
            batch dimension.

    Attributes:
        out_shape: torch.Size of the encoding of a single observation,
            measured once with a dummy forward pass.
    """

    def __init__(self, ob_shape):
        super().__init__()
        c, _, _ = ob_shape  # unpacking also checks that ob_shape is 3-dimensional
        self.encoder = nn.Sequential(
            EncoderBlock(c, 16, 32, 32),
            EncoderBlock(32, 48, 48, 64),
            EncoderBlock(64, 128, 128, 64),
            EncoderBlock(64, 48, 32, 32),
        )

        # Probe the output shape once. no_grad avoids building an autograd
        # graph (and keeping its activations alive) for this throwaway pass.
        with torch.no_grad():
            dummy_in = torch.zeros((1,) + tuple(ob_shape))
            self.out_shape = self.encoder(dummy_in).shape[1:]

    def forward(self, x):
        """Encode a batch ``x`` of shape (N, c, h, w) into (N, *self.out_shape)."""
        return self.encoder(x)

In [27]:
enc = Encoder(ob_shape=ob_shape)
enc.out_shape  # (channels, height, width) of the feature map for one observation

torch.Size([32, 13, 10])

In [28]:
class Decoder(nn.Module):
    """Two DecoderBlocks mapping a 32-channel feature map back to image channels.

    The first block upsamples by its default factor of 2; the second resizes
    directly to (h // 2, w // 2) and emits ``c`` channels.

    NOTE(review): the output is (c, h // 2, w // 2), i.e. half the observation
    resolution, and the input is presumably the 32-channel Encoder output.
    Confirm that any reconstruction target is downsampled to match.

    Args:
        ob_shape: observation shape as (channels, height, width).
    """

    def __init__(self, ob_shape):
        super().__init__()
        channels, height, width = ob_shape
        half_res = [height // 2, width // 2]
        widen = DecoderBlock(32, 32, 48, 64)
        to_image = DecoderBlock(64, 32, 16, channels, out_size=half_res)
        self.decoder = nn.Sequential(widen, to_image)

    def forward(self, x):
        """Decode ``x`` of shape (N, 32, H, W) into (N, c, h // 2, w // 2)."""
        return self.decoder(x)

In [23]:
class PolicyNetwork(nn.Module):
    """Single-layer recurrent policy head producing one score per action.

    Each call advances the hidden state as
    ``state = activation(step(state) + input(x))`` and returns
    ``out(state)`` as the action scores. No softmax is applied.

    Args:
        in_shape: shape of one input sample without the batch dimension
            (e.g. an ``Encoder.out_shape``). Inputs are flattened from dim 1 on.
        n_hiddens: size of the recurrent hidden state.
        n_actions: number of output scores.
        activation: optional module applied to the updated state. The default
            (None) applies no nonlinearity, which keeps the update purely
            linear, so the state is not bounded across many steps. Passing
            ``nn.Tanh()`` gives the standard Elman-RNN update with a bounded state.
    """

    def __init__(self, in_shape, n_hiddens, n_actions, activation=None):
        super().__init__()
        self.n_hiddens = n_hiddens
        self.input = nn.Sequential(
            nn.Flatten(),
            nn.Linear(prod(in_shape), n_hiddens),
        )
        self.step = nn.Linear(n_hiddens, n_hiddens)
        self.out = nn.Linear(n_hiddens, n_actions)
        # Identity has no parameters, so the state_dict is unchanged by default.
        self.activation = activation if activation is not None else nn.Identity()

    def forward(self, x, state=None):
        """Advance the recurrence by one step.

        Args:
            x: batch of inputs, shape (N, *in_shape).
            state: previous hidden state, typically (N, n_hiddens). If None, a
                zero state with x's batch size, dtype and device is used.

        Returns:
            (out, state): action scores of shape (N, n_actions) and the new
            hidden state of shape (N, n_hiddens).
        """
        if state is None:
            state = x.new_zeros((x.shape[0], self.n_hiddens))
        state = self.activation(self.step(state) + self.input(x))
        out = self.out(state)

        return out, state