In [19]:
import torch
import torch.nn as nn
import numpy as np
import torchsummary

import tml

class AEBase(nn.Module):
    """Base class for autoencoders composed of an encoder and a decoder module.

    Subclasses implement ``_get_encoder`` and ``_get_decoder``. Both are called
    once from ``__init__`` (encoder first), after ``self.input_shape`` has been
    assigned, so they may read it.

    Args:
        input_shape: shape of one input sample, without the batch dimension.

    Attributes:
        input_shape: the ``input_shape`` argument.
        encoder: module returned by ``_get_encoder``.
        decoder: module returned by ``_get_decoder``.
        latent_shape: per-sample shape of the encoder output (batch dimension
            dropped), measured by running the encoder on an all-zeros batch.
    """

    def __init__(self, input_shape=(3, 84, 84)):
        super().__init__()
        self.input_shape = input_shape
        self.encoder = self._get_encoder()
        self.decoder = self._get_decoder()
        self.latent_shape = self._probe_latent_shape()

    def _probe_latent_shape(self):
        """Return the encoder's output shape for one ``self.input_shape`` sample."""
        # Probe in eval mode and without autograd so the dummy batch neither
        # builds an autograd graph nor updates running statistics of stateful
        # layers (e.g. BatchNorm). The encoder's train/eval mode is restored.
        was_training = self.encoder.training
        self.encoder.eval()
        try:
            with torch.no_grad():
                # NOTE(review): a batch of 2 is used; presumably to sidestep
                # batch-size-1 special cases in some layers -- unverified.
                probe = self.encoder(torch.zeros(2, *self.input_shape))
        finally:
            self.encoder.train(was_training)
        return tuple(probe.shape)[1:]

    def forward(self, x):
        """Encode ``x`` and return the decoder's reconstruction of it."""
        z = self.encoder(x)
        y = self.decoder(z)
        return y

    def _get_encoder(self):
        """Build and return the encoder module (implemented by subclasses)."""
        raise NotImplementedError()

    def _get_decoder(self):
        """Build and return the decoder module (implemented by subclasses)."""
        raise NotImplementedError()
         
class AEConv(AEBase):
    """Fully convolutional autoencoder with LeakyReLU activations.

    Layer sizes are chosen for 84x84 inputs: the encoder maps ``(C, 84, 84)``
    to a ``(128, 3, 3)`` feature map and the decoder mirrors it back to
    ``(C, 84, 84)``, where ``C = input_shape[0]``. Other spatial sizes may not
    reconstruct to the input size because the decoder's ``output_padding``
    values are fixed.
    """

    def _get_encoder(self):
        """Return the conv encoder: four 5x5 convs, LeakyReLU between them."""
        # Channel count comes from input_shape (set by AEBase.__init__ before
        # this is called) instead of being hard-coded to 3.
        in_channels = self.input_shape[0]
        return nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=5, stride=2), nn.LeakyReLU(),
            nn.Conv2d(16, 32, kernel_size=5, stride=2), nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2), nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=1),
        )

    def _get_decoder(self):
        """Return the transposed-conv decoder mirroring the encoder."""
        out_channels = self.input_shape[0]
        return nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=1), nn.LeakyReLU(),
            # output_padding=1 recovers the sizes lost to flooring in the
            # stride-2 encoder convs (7->18, 18->40, 40->84 for 84x84 inputs).
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2, output_padding=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=5, stride=2, output_padding=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(16, out_channels, kernel_size=5, stride=2, output_padding=1),
        )
    
class AEAlex(AEBase):
    """Convolutional autoencoder with a fully connected bottleneck.

    The conv stack (same as the plain conv autoencoder, plus a final
    LeakyReLU) is flattened, projected by a Linear layer to the latent size,
    and optionally passed through Dropout. The decoder projects back, reshapes
    to the conv feature map and upsamples with transposed convs.

    Args:
        input_shape: per-sample input shape ``(C, H, W)``. The flatten size
            (``_CONV_FLAT``) assumes 84x84 spatial inputs.
        latent_shape: bottleneck size, as a 1-tuple ``(n,)`` or an int ``n``.
            ``AEBase.__init__`` later overwrites ``self.latent_shape`` with
            the shape measured from the encoder.
        dropout: dropout probability applied to the encoder output; a value
            ``<= 0`` replaces Dropout with Identity.
    """

    # Conv feature-map shape for 84x84 inputs, and its flattened size.
    _CONV_SHAPE = (128, 3, 3)
    _CONV_FLAT = 128 * 3 * 3

    def __init__(self, input_shape=(3, 84, 84), latent_shape=(3 * 3 * 128,), dropout=0.5):
        if isinstance(latent_shape, int):
            latent_shape = (latent_shape,)
        # Must be set before AEBase.__init__, which calls _get_encoder and
        # _get_decoder; both read these attributes.
        self.dropout = dropout
        self.latent_shape = latent_shape
        super().__init__(input_shape=input_shape)

    def _get_encoder(self):
        """Return the encoder: conv stack, flatten, Linear, optional Dropout."""
        in_channels = self.input_shape[0]
        return nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=5, stride=2), nn.LeakyReLU(),
            nn.Conv2d(16, 32, kernel_size=5, stride=2), nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2), nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=1), nn.LeakyReLU(),
            # NOTE(review): tml.View argument semantics are defined elsewhere;
            # the summary output shows this yields (batch, 1152).
            tml.View(self._CONV_FLAT, -1),
            nn.Linear(self._CONV_FLAT, self.latent_shape[0]),
            nn.Dropout(self.dropout) if self.dropout > 0 else nn.Identity(),
        )

    def _get_decoder(self):
        """Return the decoder: Linear, reshape to the conv map, transposed convs."""
        out_channels = self.input_shape[0]
        return nn.Sequential(
            nn.Linear(self.latent_shape[0], self._CONV_FLAT),
            tml.View(-1, self._CONV_SHAPE),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2, output_padding=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=5, stride=2, output_padding=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(16, out_channels, kernel_size=5, stride=2, output_padding=1),
        )
    
class AEGDN(AEBase):
    """Convolutional autoencoder using ``tml.GDN`` layers as nonlinearities.

    Same conv/transposed-conv geometry as the LeakyReLU conv autoencoder
    (sized for 84x84 inputs), with ``tml.GDN(n)`` after each layer except the
    final decoder layer. ``C = input_shape[0]`` sets the input/output channels.
    """

    def _get_encoder(self):
        """Return the conv encoder with a GDN layer after every conv."""
        # NOTE(review): tml.GDN(n) is defined elsewhere; presumably
        # generalized divisive normalization over n channels -- confirm.
        in_channels = self.input_shape[0]
        return nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=5, stride=2), tml.GDN(16),
            nn.Conv2d(16, 32, kernel_size=5, stride=2), tml.GDN(32),
            nn.Conv2d(32, 64, kernel_size=5, stride=2), tml.GDN(64),
            nn.Conv2d(64, 128, kernel_size=5, stride=1), tml.GDN(128),
        )

    def _get_decoder(self):
        """Return the transposed-conv decoder with GDN between layers."""
        out_channels = self.input_shape[0]
        return nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=1), tml.GDN(64),
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2, output_padding=1), tml.GDN(32),
            nn.ConvTranspose2d(32, 16, kernel_size=5, stride=2, output_padding=1), tml.GDN(16),
            nn.ConvTranspose2d(16, out_channels, kernel_size=5, stride=2, output_padding=1),
        )
    
# Select the autoencoder variant to inspect. Only the selected model is built,
# so no throwaway model is constructed (which would also advance the RNG and
# change the selected model's initial weights).
MODEL_VARIANTS = {"conv": AEConv, "alex": AEAlex, "gdn": AEGDN}
model_name = "alex"
model = MODEL_VARIANTS[model_name]()

# Use the model's own input shape so the summary cannot drift from it.
torchsummary.summary(model, input_size=model.input_shape, device="cpu")
print(model.latent_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 40, 40]           1,216
         LeakyReLU-2           [-1, 16, 40, 40]               0
            Conv2d-3           [-1, 32, 18, 18]          12,832
         LeakyReLU-4           [-1, 32, 18, 18]               0
            Conv2d-5             [-1, 64, 7, 7]          51,264
         LeakyReLU-6             [-1, 64, 7, 7]               0
            Conv2d-7            [-1, 128, 3, 3]         204,928
         LeakyReLU-8            [-1, 128, 3, 3]               0
              View-9                 [-1, 1152]               0
           Linear-10                 [-1, 1152]       1,328,256
          Dropout-11                 [-1, 1152]               0
           Linear-12                 [-1, 1152]       1,328,256
             View-13            [-1, 128, 3, 3]               0
  ConvTranspose2d-14             [-1, 6