In [144]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
LATENT_DIM = 2

In [145]:
# DCGAN pytorch tutorial
def deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    """Custom deconvolutional layer for simplicity."""
    layers = []
    layers.append(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)    

class Decoder(nn.Module):
    def __init__(self, latent_dim=LATENT_DIM, image_size=28, conv_dim=32):
        super(Decoder, self).__init__()
        self.fc = deconv(latent_dim, conv_dim*8, 2, stride=1, pad=0, bn=False)
        self.deconv1 = deconv(conv_dim*8, conv_dim*4, 4)
        self.deconv2 = deconv(conv_dim*4, conv_dim*2, 3) # hacky to change kernel size
        self.deconv3 = deconv(conv_dim*2, conv_dim, 4)
        self.deconv4 = deconv(conv_dim, 1, 4, bn=False)
        
    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)
        print(z.size())
        out = self.fc(z) # (?, 256, 2, 2)
        print(out.size())
        out = F.leaky_relu(self.deconv1(out), 0.05) # (?, 128, 4, 4)
        print(out.size())
        out = F.leaky_relu(self.deconv2(out), 0.05) # (?, 64, 7, 7)
        print(out.size())
        out = F.leaky_relu(self.deconv3(out), 0.05) # (?, 32, 14, 14)
        print(out.size())
        out = F.tanh(self.deconv4(out)) # (?, 1, 28, 28)
        return out


In [148]:
# CNN pytorch tutorial
# but the tutorial doesn't have an inverse-CNN, so this is totally BS!!
class Decoder(nn.Module):
    def __init__(self, latent_dim=LATENT_DIM):
        super(Decoder, self).__init__()
        self.fc = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 32, 7, stride=1, padding=0),
            nn.BatchNorm2d(32))
        self.layer1 = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.ReLU())
        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1),
            nn.ReLU())
        
    def forward(self, x):
        x = x.view(x.size(0), x.size(1), 1, 1)
        print(x.size())
        out = self.fc(x)
        print(out.size())
        out = self.layer1(out)
        print(out.size())
        out = self.layer2(out)
        print(out.size())
        return out

In [149]:
decoder = Decoder()
decoder.cuda()

Decoder(
  (fc): Sequential(
    (0): ConvTranspose2d(2, 32, kernel_size=(7, 7), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (layer1): Sequential(
    (0): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
  )
  (layer2): Sequential(
    (0): ConvTranspose2d(16, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
  )
)

In [150]:
x = decoder(Variable(torch.rand(10,LATENT_DIM)).cuda())
x.size()

torch.Size([10, 2, 1, 1])
torch.Size([10, 32, 7, 7])
torch.Size([10, 16, 14, 14])
torch.Size([10, 1, 28, 28])


torch.Size([10, 1, 28, 28])

In [156]:
# DCGAN pytorch tutorial
def conv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    """Custom convolutional layer for simplicity."""
    layers = []
    layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)

class Encoder(nn.Module):
    def __init__(self, latent_dim=LATENT_DIM, image_size=28, conv_dim=32):
        super(Encoder, self).__init__()
        self.conv1 = conv(1, conv_dim, 4, bn=False)
        self.conv2 = conv(conv_dim, conv_dim*2, 4)
        self.conv3 = conv(conv_dim*2, conv_dim*4, 4)
        self.conv4 = conv(conv_dim*4, conv_dim*8, 4)
        # self.fc = conv(conv_dim*8, 1, int(image_size/16), 1, 0, False)
        self.linear1 = nn.Linear(conv_dim*8, latent_dim)
        self.linear2 = nn.Linear(conv_dim*8, latent_dim)
        
    def forward(self, x):
        out = F.leaky_relu(self.conv1(x), 0.05) # (?, 32, 14, 14)
        print(out.size())
        out = F.leaky_relu(self.conv2(out), 0.05) # (?, 64, 7, 7)
        print(out.size())
        out = F.leaky_relu(self.conv3(out), 0.05) # (?, 128, 3, 3)
        print(out.size())
        out = F.leaky_relu(self.conv4(out), 0.05) # (?, 256, 1, 1)
        print(out.size())
        out = out.squeeze()
        mean, logvar = self.linear1(out), self.linear2(out)
        return mean, logvar

In [155]:
# CNN pytorch tutorial
class Encoder(nn.Module):
    def __init__(self, latent_dim=LATENT_DIM):
        super(Encoder, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.linear1 = nn.Linear(7*7*32, latent_dim)
        self.linear2 = nn.Linear(7*7*32, latent_dim)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        mu, logvar = self.linear1(out), self.linear2(out)
        return mu, logvar

In [157]:
encoder = Encoder()
encoder.cuda()

Encoder(
  (conv1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
  (conv2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv4): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (linear1): Linear(in_features=256, out_features=2, bias=True)
  (linear2): Linear(in_features=256, out_features=2, bias=True)
)

In [158]:
mu, logvar = encoder(Variable(torch.rand(10,1,28,28)).cuda())
mu.size(), logvar.size() # should be bs,LATENT_DIM

torch.Size([10, 32, 14, 14])
torch.Size([10, 64, 7, 7])
torch.Size([10, 128, 3, 3])
torch.Size([10, 256, 1, 1])


(torch.Size([10, 2]), torch.Size([10, 2]))