In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [2]:
x = torch.zeros(1, 2, 64)

In [30]:
class VAE(nn.Module):
    def __init__(self, n_params=32):
        super(VAE, self).__init__()

        # Encoder
        self.conv1 = nn.Conv1d(2, 3, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(3, 32, kernel_size=2, stride=2, padding=0)
        self.conv3 = nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(1024, 128)

        # Latent space
        self.fc21 = nn.Linear(128, 2)
        self.fc22 = nn.Linear(128, 2)

        # Decoder
        self.fc3 = nn.Linear(2, 128)
        self.fc4 = nn.Linear(128, 1024)
        self.deconv1 = nn.ConvTranspose1d(32, 32, kernel_size=3, stride=1, padding=1)
        self.deconv2 = nn.ConvTranspose1d(32, 32, kernel_size=3, stride=1, padding=1)
        self.deconv3 = nn.ConvTranspose1d(32, 32, kernel_size=2, stride=2, padding=0)
        self.conv5 = nn.Conv1d(32, 2, kernel_size=3, stride=1, padding=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.relu(self.conv3(out))
        out = self.relu(self.conv4(out))
        print(out.size())
        out = out.view(out.size(0), -1)
        print(out.size())
        h1 = self.relu(self.fc1(out))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = Variable(std.data.new(std.size()).normal_())
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z):
        h3 = self.relu(self.fc3(z))
        print(h3.size())
        out = self.relu(self.fc4(h3))
        print(out.size())
        # import pdb; pdb.set_trace()
        out = out.view(out.size(0), 32, 32)
        out = self.relu(self.deconv1(out))
        out = self.relu(self.deconv2(out))
        out = self.relu(self.deconv3(out))
        out = self.sigmoid(self.conv5(out))
        return out

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [31]:
model = VAE()
x_, mu, logvar = model(x)
print(x_.size(), mu.size(), logvar.size())

torch.Size([1, 32, 32])
torch.Size([1, 1024])
torch.Size([1, 128])
torch.Size([1, 1024])
torch.Size([1, 2, 64]) torch.Size([1, 2]) torch.Size([1, 2])
