# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    """Convolutional variational autoencoder.

    The encoder is four ``Conv2d(kernel_size=4, stride=2, padding=1)`` layers
    (each followed by ReLU) ending in 256 channels; every layer halves an
    even spatial size. The flattened encoder output is projected to the
    latent mean (``fc1``) and log-variance (``fc2``). ``fc3`` maps a latent
    sample back to ``h_dim`` features, which the decoder (four mirrored
    ``ConvTranspose2d`` layers ending in a sigmoid) upsamples to an image.

    Args:
        image_channels: Number of channels of the input/output images.
        h_dim: Size of the flattened encoder output. It must equal
            ``256 * (H / 16) * (W / 16)`` for the images fed to ``forward``
            (e.g. 1024 for 32x32 inputs, 256 for 16x16 inputs); otherwise
            ``fc1``/``fc2`` raise a shape-mismatch error.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, image_channels=3, h_dim=1024, z_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )
        self.fc1 = nn.Linear(h_dim, z_dim)  # latent mean
        self.fc2 = nn.Linear(h_dim, z_dim)  # latent log-variance
        self.fc3 = nn.Linear(z_dim, h_dim)  # latent sample -> decoder input
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(log_var / 2)``.

        Sampling through ``eps`` keeps the result differentiable with respect
        to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: Image batch of shape ``(batch, image_channels, H, W)``.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``. The reconstruction has
            values in ``[0, 1]`` (sigmoid output); ``mu`` and ``log_var``
            have shape ``(batch, z_dim)``.
        """
        features = self.encoder(x)
        h = features.view(x.size(0), -1)
        mu, log_var = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, log_var)
        # Restore the encoder's feature-map shape instead of assuming
        # (256, 1, 1): that assumption only held for h_dim == 256 and made
        # the default h_dim=1024 configuration fail at this reshape.
        z = self.fc3(z).view_as(features)
        return self.decoder(z), mu, log_var


class MDNRNN(nn.Module):
    """Single-layer LSTM followed by a linear head of mixture parameters.

    Args:
        z_dim: Size of the latent vector fed in at each step.
        a_dim: Size of the action vector fed in at each step.
        h_dim: LSTM hidden size.
        n_gaussians: Number of mixture components; the head emits
            ``n_gaussians * z_dim * 3`` values per sample.
    """

    def __init__(self, z_dim, a_dim, h_dim, n_gaussians=5):
        super().__init__()
        self.rnn = nn.LSTM(z_dim + a_dim, h_dim, batch_first=True)
        self.fc = nn.Linear(h_dim, n_gaussians * z_dim * 3)

    def forward(self, z, a, h):
        """Advance the LSTM by one time step.

        Args:
            z: Latent batch of shape ``(batch, z_dim)``.
            a: Action batch of shape ``(batch, a_dim)``.
            h: LSTM state ``(h_0, c_0)`` as returned by ``init_hidden``.

        Returns:
            Tuple ``(y, out)`` where ``y`` is the raw head output of shape
            ``(batch, n_gaussians * z_dim * 3)`` and ``out`` is the LSTM
            output for this step, shape ``(batch, h_dim)``.
        """
        # Concatenate along features and add a length-1 time axis
        # (batch_first=True expects (batch, seq, feature)).
        x = torch.cat([z, a], dim=-1).unsqueeze(1)
        # NOTE(review): the updated (h_n, c_n) state is discarded here, so
        # callers cannot feed the returned tensor back in as ``h``.
        out, _ = self.rnn(x, h)
        out = out.squeeze(1)
        y = self.fc(out)
        return y, out

    def init_hidden(self, batch_size, h_dim=None):
        """Return a zero ``(h_0, c_0)`` state for ``batch_size`` sequences.

        Args:
            batch_size: Number of sequences in the batch.
            h_dim: Hidden size of the state tensors. Defaults to the LSTM's
                own ``hidden_size``, so callers need not repeat it.

        Returns:
            Tuple of two zero tensors of shape
            ``(num_layers, batch_size, h_dim)``, created with the same dtype
            and device as this module's parameters.
        """
        if h_dim is None:
            h_dim = self.rnn.hidden_size
        # Allocate from an existing parameter so the state follows the
        # module across .to(device) / .double() calls.
        ref = self.fc.weight
        shape = (self.rnn.num_layers, batch_size, h_dim)
        return (ref.new_zeros(shape), ref.new_zeros(shape))


class Controller(nn.Module):
    """Linear policy head: one affine map from ``[z, h]`` to an action vector.

    Args:
        z_dim: Size of the latent vector ``z``.
        h_dim: Size of the recurrent feature vector ``h``.
        a_dim: Size of the produced action vector.
    """

    def __init__(self, z_dim, h_dim, a_dim):
        super().__init__()
        joint_dim = z_dim + h_dim
        self.fc1 = nn.Linear(in_features=joint_dim, out_features=a_dim)

    def forward(self, z, h):
        """Concatenate ``z`` and ``h`` on the last axis and project to actions.

        Returns:
            Tensor whose last dimension is ``a_dim``; no activation is applied.
        """
        joint = torch.cat((z, h), dim=-1)
        action = self.fc1(joint)
        return action


class GATOAgent(nn.Module):
    """Agent that chains a VAE, an MDN-RNN and a linear controller.

    Args:
        image_channels: Number of channels of the observed images.
        z_dim: Latent size produced by the VAE encoder.
        a_dim: Action size.
        h_dim: Used both as the VAE's flattened-encoder size and as the LSTM
            hidden size. With the VAE's four stride-2 convolutions ending in
            256 channels, the default 256 corresponds to 16x16 input images.
        n_gaussians: Number of mixture components of the MDN head.
    """

    def __init__(self, image_channels=3, z_dim=32, a_dim=3, h_dim=256, n_gaussians=5):
        super().__init__()
        self.vae = VAE(image_channels, h_dim, z_dim)
        self.mdnrnn = MDNRNN(z_dim, a_dim, h_dim, n_gaussians)
        self.controller = Controller(z_dim, h_dim, a_dim)

    def forward(self, image, action, h):
        """Run one perception -> memory -> control step.

        Args:
            image: Image batch of shape ``(batch, image_channels, H, W)``.
            action: Previous action batch of shape ``(batch, a_dim)``.
            h: LSTM state ``(h_0, c_0)`` as returned by ``init_hidden``.

        Returns:
            Tuple ``(action, h, mu, log_var, mdn_output)``: the new action,
            the MDN-RNN's per-step output features, the VAE posterior
            parameters, and the raw MDN head output.
        """
        # Sample the latent code from the encoder. Calling ``self.vae(image)``
        # returns the decoded image as its first element, which is 4-D and
        # cannot be concatenated with the 2-D action inside the MDN-RNN.
        features = self.vae.encoder(image).view(image.size(0), -1)
        mu = self.vae.fc1(features)
        log_var = self.vae.fc2(features)
        z = self.vae.reparameterize(mu, log_var)

        mdn_output, h = self.mdnrnn(z, action, h)
        action = self.controller(z, h)
        return action, h, mu, log_var, mdn_output

    def init_hidden(self, batch_size):
        """Return a zero LSTM state for ``batch_size`` sequences."""
        # MDNRNN.init_hidden takes the hidden size explicitly; read it from
        # the LSTM so it always matches the constructed network.
        return self.mdnrnn.init_hidden(batch_size, self.mdnrnn.rnn.hidden_size)
