In [63]:
import torch
import torch.nn as nn
import numpy as np
import torchsummary

import tml

class AEBase(nn.Module):
    """Base class for autoencoders assembled from an encoder and a decoder.

    Subclasses supply the two sub-networks by overriding `_get_encoder`
    and `_get_decoder`. Both hooks are called from `__init__`, after
    `self.input_shape` has been set, so any attribute they read must be
    assigned before `super().__init__(...)` is invoked by the subclass.

    Attributes set by `__init__`:
        input_shape: the per-sample input shape passed by the caller.
        encoder: module returned by `_get_encoder()`.
        decoder: module returned by `_get_decoder()`.
        latent_shape: per-sample shape of the encoder output (batch
            dimension stripped), measured with a dummy forward pass.
    """

    def __init__(self, input_shape=(3,84,84)):
        super().__init__()
        self.input_shape = input_shape
        self.encoder = self._get_encoder()
        self.decoder = self._get_decoder()
        # Measure the latent shape by pushing a batch of two all-zero samples
        # through the encoder. Only the resulting shape is needed, so autograd
        # is disabled to avoid building a graph that would be discarded.
        with torch.no_grad():
            probe = self.encoder(torch.zeros(2, *input_shape))
        self.latent_shape = tuple(probe.shape)[1:]

    def forward(self, x):
        """Encode `x`, decode the latent code and return the reconstruction."""
        z = self.encoder(x)
        y = self.decoder(z)
        return y

    def _get_encoder(self):
        """Return the encoder module. Must be overridden by subclasses."""
        raise NotImplementedError()

    def _get_decoder(self):
        """Return the decoder module. Must be overridden by subclasses."""
        raise NotImplementedError()
   
    
class AEAlex(AEBase):
    """Convolutional autoencoder with a fully connected bottleneck.

    The encoder is a stack of six convolutions followed by a `tml.View`
    reshape, a linear projection to `latent_shape[0]` features and an
    optional dropout layer. The decoder starts with a linear layer and a
    `tml.View` reshape, followed by six transposed convolutions.

    NOTE(review): `tml.View` is a project helper whose argument convention
    is not visible in this file. The encoder passes `(2*2*128, -1)` while
    the decoder passes `(-1, (128,2,2))` -- confirm both match its API.
    """

    def __init__(self, input_shape=(3,84,84), latent_shape=(2*2*128,), dropout=0.5):
        # These attributes are read by _get_encoder/_get_decoder, which the
        # base-class constructor calls, so they must exist before super().
        self.latent_shape = latent_shape
        self.dropout = dropout
        super().__init__(input_shape=input_shape)

    def _get_encoder(self):
        """Build the conv stack, the reshape, the projection and the dropout."""
        # (in_channels, out_channels, kernel_size, stride); each of these
        # convolutions is followed by a LeakyReLU.
        activated = [
            (3, 16, 6, 1),
            (16, 32, 5, 2),
            (32, 64, 6, 1),
            (64, 128, 5, 2),
            (128, 128, 5, 2),
        ]
        layers = []
        for c_in, c_out, kernel, step in activated:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=step))
            layers.append(nn.LeakyReLU())
        # The last convolution has no activation before the reshape.
        layers.append(nn.Conv2d(128, 128, kernel_size=5, stride=1))
        layers.append(tml.View(2*2*128, -1))
        layers.append(nn.Linear(2*2*128, self.latent_shape[0]))
        # A non-positive dropout rate disables the layer.
        if self.dropout > 0:
            layers.append(nn.Dropout(self.dropout))
        else:
            layers.append(nn.Identity())
        return nn.Sequential(*layers)

    def _get_decoder(self):
        """Build the linear expansion, the reshape and the deconv stack."""
        layers = [
            nn.Linear(self.latent_shape[0], 2*2*128),
            tml.View(-1, (128,2,2)),
        ]
        # (in_channels, out_channels, kernel_size, stride); each of these
        # transposed convolutions is followed by a LeakyReLU.
        activated = [
            (128, 128, 5, 1),
            (128, 128, 5, 2),
            (128, 64, 5, 2),
            (64, 32, 6, 1),
            (32, 16, 5, 2),
        ]
        for c_in, c_out, kernel, step in activated:
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=kernel, stride=step))
            layers.append(nn.LeakyReLU())
        # The output layer has no activation.
        layers.append(nn.ConvTranspose2d(16, 3, kernel_size=6, stride=1))
        return nn.Sequential(*layers)
    
class AEGDN(AEBase):
    """Fully convolutional autoencoder using `tml.GDN` between layers.

    Every convolution except the last one in each half is followed by a
    `tml.GDN` layer constructed with that convolution's output channel
    count. `tml.GDN` is a project-provided module (presumably generalized
    divisive normalization -- confirm against `tml`).
    """

    def _get_encoder(self):
        """Return the six-convolution encoder."""
        # (in_channels, out_channels, kernel_size, stride)
        normalised = [
            (3, 16, 6, 1),
            (16, 32, 5, 2),
            (32, 64, 6, 1),
            (64, 128, 5, 2),
            (128, 128, 5, 2),
        ]
        layers = []
        for c_in, c_out, kernel, step in normalised:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=step))
            layers.append(tml.GDN(c_out))
        # The last convolution is not followed by a GDN layer.
        layers.append(nn.Conv2d(128, 128, kernel_size=5, stride=1))
        return nn.Sequential(*layers)

    def _get_decoder(self):
        """Return the six-layer transposed-convolution decoder."""
        # (in_channels, out_channels, kernel_size, stride)
        normalised = [
            (128, 128, 5, 1),
            (128, 128, 5, 2),
            (128, 64, 5, 2),
            (64, 32, 6, 1),
            (32, 16, 5, 2),
        ]
        layers = []
        for c_in, c_out, kernel, step in normalised:
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=kernel, stride=step))
            layers.append(tml.GDN(c_out))
        # The output layer is not followed by a GDN layer.
        layers.append(nn.ConvTranspose2d(16, 3, kernel_size=6, stride=1))
        return nn.Sequential(*layers)
    
      
class AEConv(AEBase):
    """Fully convolutional autoencoder with LeakyReLU activations.

    The encoder and the decoder each hold six (transposed) convolutions;
    all but the last layer of each half are followed by `nn.LeakyReLU`.
    """

    def _get_encoder(self):
        """Return the six-convolution encoder."""
        # (in_channels, out_channels, kernel_size, stride)
        activated = [
            (3, 16, 6, 1),
            (16, 32, 5, 2),
            (32, 64, 6, 1),
            (64, 128, 5, 2),
            (128, 128, 5, 2),
        ]
        layers = []
        for c_in, c_out, kernel, step in activated:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=step))
            layers.append(nn.LeakyReLU())
        # The last convolution has no activation.
        layers.append(nn.Conv2d(128, 128, kernel_size=5, stride=1))
        return nn.Sequential(*layers)

    def _get_decoder(self):
        """Return the six-layer transposed-convolution decoder."""
        # (in_channels, out_channels, kernel_size, stride)
        activated = [
            (128, 128, 5, 1),
            (128, 128, 5, 2),
            (128, 64, 5, 2),
            (64, 32, 6, 1),
            (32, 16, 5, 2),
        ]
        layers = []
        for c_in, c_out, kernel, step in activated:
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=kernel, stride=step))
            layers.append(nn.LeakyReLU())
        # The output layer has no activation.
        layers.append(nn.ConvTranspose2d(16, 3, kernel_size=6, stride=1))
        return nn.Sequential(*layers)
    
# Build every variant in turn; each assignment replaces the previous one, so
# `model` ends up bound to the last class in the tuple (AEGDN).
for _variant in (AEConv, AEAlex, AEGDN):
    model = _variant()
del _variant


# Print a per-layer summary of the selected model for a 3x84x84 input on CPU,
# then the latent shape measured during construction.
torchsummary.summary(model, input_size=(3,84,84), device="cpu")
print(model.latent_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 79, 79]           1,744
               GDN-2           [-1, 16, 79, 79]               0
            Conv2d-3           [-1, 32, 38, 38]          12,832
               GDN-4           [-1, 32, 38, 38]               0
            Conv2d-5           [-1, 64, 33, 33]          73,792
               GDN-6           [-1, 64, 33, 33]               0
            Conv2d-7          [-1, 128, 15, 15]         204,928
               GDN-8          [-1, 128, 15, 15]               0
            Conv2d-9            [-1, 128, 6, 6]         409,728
              GDN-10            [-1, 128, 6, 6]               0
           Conv2d-11            [-1, 128, 2, 2]         409,728
  ConvTranspose2d-12            [-1, 128, 6, 6]         409,728
              GDN-13            [-1, 128, 6, 6]               0
  ConvTranspose2d-14          [-1, 128,