In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

In [2]:
class PrintLayer(nn.Module):
    def __init__(self):
        super(PrintLayer, self).__init__()
    
    def forward(self, x, layerName = 'default'):
        # Do your print / debug stuff here
        print(layerName+': {}'.format(x.shape))
        return x

In [3]:
nc = 4

class CVAE(nn.Module):
    def __init__(self, nz, nBatch, nChannels, W, H, _channel):
        super(CVAE, self).__init__()

        self.nz = nz
        self.nBatch = nBatch
        self.nChannels = nChannels
        self.W = W
        self.H = H
        self._channel = _channel

        self.encoder = nn.Sequential(
            
#             [input -> (1) -> (2) -> (3) -> output]
            nn.Conv3d(nChannels, 16, (3, 3, 3), stride=(1, 1, 1), padding=(1,1,1),bias=False),
            PrintLayer(),
            nn.BatchNorm3d(16),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.decoder = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose3d(16, 8, (3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False),
            nn.BatchNorm3d(8),
            nn.Sigmoid()
        )

        self.fc1 = nn.Linear(16*nBatch*W*H*_channel, nChannels)
        self.fc21 = nn.Linear(8, nz)
        self.fc22 = nn.Linear(8, nz)

        self.fc3 = nn.Linear(nz, 8)
        self.fc4 = nn.Linear(8, 16*nBatch*W*H*_channel)

        self.lrelu = nn.LeakyReLU()
        self.relu = nn.ReLU()
        # self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        conv = self.encoder(x);
        print("encode conv", conv.size())
        h1 = self.fc1(conv.view(-1, 16*self.nBatch*self.W*self.H*self._channel))
        print("encode h1", h1.size())
        return self.fc21(h1), self.fc22(h1)

    def decode(self, z):
        h3 = self.relu(self.fc3(z))
        deconv_input = self.fc4(h3)
        print("deconv_input", deconv_input.size())
        deconv_input = deconv_input.view(self.nBatch,16,self.W,self.H,self._channel)
        print("deconv_input", deconv_input.size())
        return self.decoder(deconv_input)

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
#         if self.have_cuda:
#         eps = torch.cuda.FloatTensor(std.size()).normal_()
#         else:
        eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)

    def forward(self, x):
        print("x", x.size())
        mu, logvar = self.encode(x)
        print("mu, logvar", mu.size(), logvar.size())
        z = self.reparametrize(mu, logvar)
        print("z", z.size())
        decoded = self.decode(z)
        print("decoded", decoded.size())
        return decoded, mu, logvar
      
# build model
model = CVAE(4, 3, 8, 240, 240, 1)


In [4]:
input = torch.randn(3, 8, 240, 240, 1)
output,_,_ = model(input)
model.modules[2].output:size()
print(output.shape)

x torch.Size([3, 8, 240, 240, 1])
default: torch.Size([3, 16, 240, 240, 1])
encode conv torch.Size([3, 16, 240, 240, 1])
encode h1 torch.Size([1, 8])
mu, logvar torch.Size([1, 4]) torch.Size([1, 4])
z torch.Size([1, 4])
deconv_input torch.Size([1, 2764800])
deconv_input torch.Size([3, 16, 240, 240, 1])
decoded torch.Size([3, 8, 240, 240, 1])


TypeError: 'method' object is not subscriptable