In [None]:
class VAE(nn.Module):
    """Convolutional variational-style autoencoder for single-channel images.

    The layer sizes are fixed so that a ``(batch, 1, 28, 28)`` input is reduced
    to a ``32 x 5 x 5`` feature map (28 -> 14 -> 7 -> 5) before the linear
    layers; the decoder mirrors that path back to ``(batch, 1, 28, 28)``
    (5 -> 7 -> 14 -> 28). Inputs with other spatial sizes will not match
    ``nn.Linear(5 * 5 * 32, 150)``.

    Args:
        encoded_space_dim: Size of the latent vector ``z``.
        lambda_: Multiplier applied to the sampled noise in :meth:`encode`;
            ``0`` makes the encoder deterministic (``z == z_mean``).
    """

    def __init__(self, encoded_space_dim, lambda_):
        super().__init__()
        self.encoded_space_dim = encoded_space_dim

        ### Encoder
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),    # 28 -> 14
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),   # 14 -> 7
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=1, padding=0),  # 7 -> 5
            nn.ReLU(True)
        )
        self.encoder_lin = nn.Sequential(
            nn.Linear(5 * 5 * 32, 150),
            nn.ReLU(True)
        )

        self.z_mean = nn.Linear(150, encoded_space_dim)
        # Despite the name, encode() uses exp(z_var) as the noise scale, so
        # this head effectively predicts a log standard deviation.
        self.z_var = nn.Linear(150, encoded_space_dim)
        self.lambda_ = lambda_

        ### Decoder
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 150),
            nn.ReLU(True),
            nn.Linear(150, 5 * 5 * 32),
            nn.ReLU(True)
        )
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, stride=1, output_padding=0),                  # 5 -> 7
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),        # 7 -> 14
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1)          # 14 -> 28
        )

    def forward(self, x):
        """Encode ``x`` to a sampled latent and decode it back to an image."""
        x = self.encode(x)
        x = self.decode(x)
        return x

    def encode(self, x):
        """Map images to a sampled latent vector of size ``encoded_space_dim``.

        Computes ``z = z_mean + exp(z_var) * eps * lambda_`` with
        ``eps ~ N(0, 1)``. Noise is drawn on every call, in both train and
        eval mode. Only ``z`` is returned; ``z_mean`` / ``z_var`` are not
        exposed, so no KL term can be computed from this method's output.
        # NOTE(review): if training needs a KL penalty it must be computed
        # elsewhere — confirm against the training loop.
        """
        # Apply convolutions
        x = self.encoder_cnn(x)
        # Flatten everything except the batch dimension
        x = torch.flatten(x, 1)
        # Apply linear layers
        x = self.encoder_lin(x)
        z_mean = self.z_mean(x)
        z_var = self.z_var(x)
        # randn_like draws on z_mean's own device and dtype. This replaces the
        # deprecated torch.cuda.FloatTensor constructor, which allocated on
        # the *current* CUDA device, and the CPU-only torch.randn fallback,
        # which broke on non-CUDA accelerators.
        eps = torch.randn_like(z_mean)
        z = z_mean + torch.exp(z_var) * eps * self.lambda_
        return z

    def decode(self, x):
        """Map latent vectors back to ``(batch, 1, 28, 28)`` images in ``[0, 1]``."""
        # Apply linear layers
        x = self.decoder_lin(x)
        # Reshape to the 32 x 5 x 5 feature map the transposed convs expect
        x = x.view([-1, 32, 5, 5])
        # Apply transposed convolutions
        x = self.decoder_conv(x)
        # Squash to [0, 1] pixel intensities
        x = torch.sigmoid(x)
        return x