Code adapted from: https://github.com/pytorch/examples/tree/master/vae (modified for 3D subcubes)

In [1278]:
from __future__ import print_function
import argparse
import os
import h5py
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
%matplotlib inline

In [1279]:
# Hard-coded replacements for the argparse options of the upstream example.
batch_size = 1
epochs = 20
no_cuda = False
log_interval = 1

# Train on the GPU unless it is disabled or unavailable.
cuda = torch.cuda.is_available() if not no_cuda else False

seed = 1
torch.manual_seed(seed)

device = torch.device("cuda") if cuda else torch.device("cpu")

# Worker / pinned-memory DataLoader options are only used with CUDA.
if cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

In [1280]:
class HydrogenDataset(Dataset):
    """Dataset of sampled subcubes stored in an HDF5 file.

    Reads the ``'sample32'`` dataset of the HDF5 file and serves each entry
    as a single-channel cube of shape (1, cube_size, cube_size, cube_size).
    """

    def __init__(self, h5_file, root_dir, cube_size=128):
        """
        Args:
            h5_file (string): name of the .h5 file with the sampled subcubes.
            root_dir (string): directory containing the .h5 file (a trailing
                separator is optional).
            cube_size (int): side length each sample is reshaped to. Defaults
                to 128, the value previously hard-coded in __getitem__.
        """
        # os.path.join works whether or not root_dir ends with a separator;
        # plain string concatenation silently produced a wrong path without one.
        path = os.path.join(root_dir, h5_file)
        file_size = os.path.getsize(path) / 1e6  # in MBs
        print("The file size is " + str(int(file_size)) + " MBs")

        # Keep the open file object so it can be closed explicitly (see close()).
        self._h5 = h5py.File(path, 'r')
        self.subcubes = self._h5['sample32']
        self.h5_file = h5_file
        self.root_dir = root_dir
        self.cube_size = cube_size

    def __len__(self):
        """Return the number of subcubes in the ``'sample32'`` dataset."""
        return len(self.subcubes)

    def __getitem__(self, idx):
        """Return subcube ``idx`` with a leading channel axis.

        Only the requested entry is read from the HDF5 file, so memory use
        stays at one sample (or one batch when used through a DataLoader).

        Returns:
            array of shape (1, cube_size, cube_size, cube_size). The leading
            singleton axis is the channel dimension expected by ``nn.Conv3d``;
            without it training fails on a dimension mismatch.
        """
        side = self.cube_size
        return self.subcubes[idx].reshape((1, side, side, side))

    def close(self):
        """Close the underlying HDF5 file."""
        self._h5.close()

In [1281]:
# Load the sampled subcubes from ../data/sample_32.h5 and display the object.
sampled_subcubes = HydrogenDataset(root_dir="../data/",
                                   h5_file="sample_32.h5")
sampled_subcubes

The file size is 268 MBs


<__main__.HydrogenDataset at 0x11b8b16a0>

In [None]:
# Data Loaders
# NOTE(review): both loaders wrap the same dataset, so the "test" loss is
# measured on training data. Split the data (e.g. with
# torch.utils.data.random_split) for a real held-out evaluation.
train_loader = DataLoader(
        dataset=sampled_subcubes,
        batch_size=batch_size,
        shuffle=True,
        **kwargs)

# Evaluation does not need shuffling. A fixed order also makes the first
# batch, which test() uses for its reconstruction comparison, reproducible.
test_loader = DataLoader(
        dataset=sampled_subcubes,
        batch_size=batch_size,
        shuffle=False,
        **kwargs)

In [None]:
class VAE(nn.Module):
    """3D convolutional variational autoencoder for single-channel cubes.

    Encoder: three stride-2 ``Conv3d`` layers (the first two each followed by
    a 2x ``MaxPool3d`` whose indices are kept), then fully connected layers
    producing the mean and log-variance of the latent Gaussian.

    Decoder: fully connected layers back to the flattened conv features, then
    ``ConvTranspose3d`` layers, with ``MaxUnpool3d`` reusing the encoder's
    pooling indices.

    Args:
        cube_size: side length of the input cube. Must be a positive multiple
            of 32, because every stride-2 conv and every pool halves the
            side. Defaults to 128, the previously hard-coded size.
        hidden_dim: width of the hidden fully connected layers.
        latent_dim: dimensionality of the latent space.
        verbose: if True, print progress and shape messages on every
            forward pass (the previous version always printed them).

    Raises:
        ValueError: if ``cube_size`` is not a positive multiple of 32.
    """

    def __init__(self, cube_size=128, hidden_dim=5096, latent_dim=32,
                 verbose=False):
        super(VAE, self).__init__()
        if cube_size <= 0 or cube_size % 32 != 0:
            raise ValueError(
                "cube_size must be a positive multiple of 32, got {!r}".format(cube_size))
        self.cube_size = cube_size
        self.verbose = verbose
        # Three stride-2 convs and two 2x max-pools shrink each spatial side by
        # 2**5 = 32. The last conv has 32 output channels, so the flattened
        # feature size is 32 * (cube_size // 32)**3 (2048 for a 128**3 input).
        self._bottleneck_channels = 32
        self._bottleneck_side = cube_size // 32
        flat_features = self._bottleneck_channels * self._bottleneck_side ** 3

        # ---------------- Encoder ----------------
        # The first conv takes 1 input channel (single-channel cube). Each
        # later conv's in_channels equals the previous conv's out_channels.

        # Conv 1: side -> side/2, then max-pool -> side/4
        self.encode_conv1 = nn.Conv3d(in_channels=1,
                                      out_channels=8,
                                      kernel_size=(4, 4, 4),
                                      stride=(2, 2, 2),
                                      padding=(1, 1, 1))
        nn.init.xavier_uniform_(self.encode_conv1.weight)
        self.encode_relu1 = nn.ReLU()
        # return_indices=True so the decoder can invert this pooling.
        self.encode_maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2),
                                            stride=(2, 2, 2),
                                            return_indices=True)

        # Conv 2: side/4 -> side/8, then max-pool -> side/16
        self.encode_conv2 = nn.Conv3d(in_channels=8,
                                      out_channels=16,
                                      kernel_size=(4, 4, 4),
                                      stride=(2, 2, 2),
                                      padding=(1, 1, 1))
        nn.init.xavier_uniform_(self.encode_conv2.weight)
        self.encode_relu2 = nn.ReLU()
        self.encode_maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2),
                                            stride=(2, 2, 2),
                                            return_indices=True)

        # Conv 3: side/16 -> side/32 (no pooling after the last conv)
        self.encode_conv3 = nn.Conv3d(in_channels=16,
                                      out_channels=32,
                                      kernel_size=(4, 4, 4),
                                      stride=(2, 2, 2),
                                      padding=(1, 1, 1))
        nn.init.xavier_uniform_(self.encode_conv3.weight)
        self.encode_relu3 = nn.ReLU()

        # Fully connected head: flattened conv features -> hidden -> (mu, logvar)
        self.encode_fc1 = nn.Sequential(
            nn.Linear(in_features=flat_features, out_features=hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5))
        self.encode_fc2 = nn.Sequential(
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5))
        # Two parallel heads: latent mean (fc31) and log-variance (fc32).
        self.encode_fc31 = nn.Sequential(
            nn.Linear(in_features=hidden_dim, out_features=latent_dim))
        self.encode_fc32 = nn.Sequential(
            nn.Linear(in_features=hidden_dim, out_features=latent_dim))

        # ---------------- Decoder ----------------
        # Mirrors the encoder: Conv3d -> ConvTranspose3d, MaxPool3d -> MaxUnpool3d.
        # A ReLU follows decode_fc1. Without it, decode_fc1 and decode_fc2 are
        # two stacked Linear layers, which collapse into one linear map.
        self.decode_fc1 = nn.Sequential(
            nn.Linear(in_features=latent_dim, out_features=hidden_dim),
            nn.ReLU())
        self.decode_fc2 = nn.Sequential(
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5))
        self.decode_fc3 = nn.Sequential(
            nn.Linear(in_features=hidden_dim, out_features=flat_features),
            nn.ReLU(),
            nn.Dropout(0.5))

        # Deconv 1: side/32 -> side/16, then unpool (inverse of encode_maxpool2) -> side/8
        self.decode_conv1 = nn.ConvTranspose3d(in_channels=32,
                                               out_channels=16,
                                               kernel_size=(4, 4, 4),
                                               stride=(2, 2, 2),
                                               padding=(1, 1, 1))
        self.decode_relu1 = nn.ReLU()
        self.decode_maxunpool1 = nn.MaxUnpool3d(kernel_size=(2, 2, 2),
                                                stride=(2, 2, 2))

        # Deconv 2: side/8 -> side/4, then unpool (inverse of encode_maxpool1) -> side/2
        self.decode_conv2 = nn.ConvTranspose3d(in_channels=16,
                                               out_channels=8,
                                               kernel_size=(4, 4, 4),
                                               stride=(2, 2, 2),
                                               padding=(1, 1, 1))
        self.decode_relu2 = nn.ReLU()
        self.decode_maxunpool2 = nn.MaxUnpool3d(kernel_size=(2, 2, 2),
                                                stride=(2, 2, 2))

        # Deconv 3: side/2 -> side, back to 1 channel
        self.decode_conv3 = nn.ConvTranspose3d(in_channels=8,
                                               out_channels=1,
                                               kernel_size=(4, 4, 4),
                                               stride=(2, 2, 2),
                                               padding=(1, 1, 1))
        self.decode_relu3 = nn.ReLU()
        # Currently unused: the encoder has no third max-pool to invert.
        self.decode_maxunpool3 = nn.MaxUnpool3d(kernel_size=(2, 2, 2),
                                                stride=(2, 2, 2))

    def _log(self, *args):
        """Print ``args`` only when the model was built with ``verbose=True``."""
        if self.verbose:
            print(*args)

    # Encoding part of VAE
    def encode(self, x):
        """Encode ``x`` into the parameters of the latent Gaussian.

        Args:
            x: tensor of shape (batch, 1, cube_size, cube_size, cube_size).

        Returns:
            (mu, logvar, indices_list, size_list):
            - mu, logvar: shape (batch, latent_dim).
            - indices_list: the two max-pool index tensors.
            - size_list: the sizes before each pool, which the decoder needs
              to invert the pooling.
        """
        self._log("Starting Encoding")
        self._log("----------------------------")

        out = self.encode_relu1(self.encode_conv1(x))
        size1 = out.size()
        out, ind1 = self.encode_maxpool1(out)

        out = self.encode_relu2(self.encode_conv2(out))
        size2 = out.size()
        out, ind2 = self.encode_maxpool2(out)

        out = self.encode_relu3(self.encode_conv3(out))

        # Flatten per sample. The old view(1, -1) merged the whole batch into
        # one row and was only correct for batch_size == 1.
        out = out.view(out.size(0), -1)

        out = self.encode_fc2(self.encode_fc1(out))
        out_mu = self.encode_fc31(out)
        out_logvar = self.encode_fc32(out)

        self._log("Encode - Forward Pass Finished")
        self._log(out_mu.shape)
        self._log(out_logvar.shape)
        self._log("----------------------------")

        return out_mu, out_logvar, [ind1, ind2], [size1, size2]

    # Reparametrization Trick
    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    # Decoding part of VAE
    def decode(self, z, indices_list, size_list):
        """Decode latent ``z`` back to a cube, using the encoder's pooling info.

        Args:
            z: tensor of shape (batch, latent_dim).
            indices_list, size_list: as returned by ``encode``.

        Returns:
            tensor of shape (batch, 1, cube_size, cube_size, cube_size).
        """
        self._log("----------------------------")
        self._log("Starting Decoding")

        out = self.decode_fc3(self.decode_fc2(self.decode_fc1(z)))
        side = self._bottleneck_side
        out = out.view(out.size(0), self._bottleneck_channels, side, side, side)

        out = self.decode_relu1(self.decode_conv1(out))
        # Invert encode_maxpool2 (second pooling -> indices_list[1]).
        out = self.decode_maxunpool1(out,
                                     indices=indices_list[1],
                                     output_size=size_list[1])

        out = self.decode_relu2(self.decode_conv2(out))
        # Invert encode_maxpool1 (first pooling -> indices_list[0]).
        out = self.decode_maxunpool2(out,
                                     indices=indices_list[0],
                                     output_size=size_list[0])

        # No final unpool: the encoder's third conv has no matching max-pool.
        out = self.decode_relu3(self.decode_conv3(out))
        return out

    # Forward Pass
    def forward(self, x):
        """Return (reconstruction, mu, logvar) for input ``x``."""
        mu, logvar, indices_list, size_list = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, indices_list, size_list), mu, logvar
       

In [None]:
# Build the VAE and move its parameters to the selected device.
model = VAE()
model = model.to(device)

In [None]:
# for param in model.parameters():
#     print(param.name)
#     print(type(param.data), param.size())

In [None]:
# Adam with its default learning rate (also used by the upstream VAE example).
# The previous lr=1 is orders of magnitude too large for Adam and makes
# training diverge.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar, kld_weight=1.0, verbose=False):
    """Return the VAE loss: summed squared reconstruction error plus weighted KL.

    Args:
        recon_x: decoder output.
        x: original input. It is reshaped to ``recon_x``'s shape before the
            comparison (previously hard-coded to (-1, 1, 128, 128, 128)).
        mu, logvar: mean and log-variance of the approximate posterior.
        kld_weight: multiplier on the KL term. 1.0 gives the standard VAE
            objective; 0.0 reproduces the previous reconstruction-only loss.
        verbose: if True, print the individual loss terms.

    Returns:
        Scalar tensor ``MSE + kld_weight * KLD``.
    """
    MSE = F.mse_loss(recon_x, x.reshape(recon_x.shape), reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    if verbose:
        print("--------------------------------------")
        print("Calculating Loss...")
        print("MSE Loss = " + str(MSE))
        print("KLD Loss = " + str(KLD))

    # Skip the KL term entirely when it is disabled, so a non-finite KLD
    # cannot turn the loss into NaN via 0 * inf.
    if not kld_weight:
        return MSE
    return MSE + kld_weight * KLD

In [None]:
#list(enumerate(train_loader))[0]

In [None]:
#list(enumerate(train_loader))[0][0]

In [None]:
#list(enumerate(train_loader))[0][1]

In [None]:
#list(enumerate(train_loader))[0][2]

In [None]:
# Shape of one training batch. next(iter(...)) loads a single batch, whereas
# list(enumerate(train_loader)) read the entire dataset into memory first.
next(iter(train_loader)).shape

In [None]:
def train(epoch):
    """Run one training epoch over ``train_loader``.

    Each batch loss is appended to the module-level ``loss_history`` list,
    which must already exist (the ``__main__`` cell creates it). Progress is
    printed every ``log_interval`` batches. At the end, the summed loss
    divided by the dataset size is printed.
    """
    model.train()
    epoch_loss = 0
    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)

    for step, batch in enumerate(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(batch)

        loss = loss_function(reconstruction, batch, mu, logvar)
        loss.backward()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        loss_history.append(batch_loss)

        optimizer.step()

        if step % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, step * len(batch), n_samples,
                100. * step / n_batches,
                batch_loss / len(batch)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, epoch_loss / n_samples))

In [None]:
def test(epoch):
    """Evaluate the model on ``test_loader`` and print the mean per-sample loss.

    ``epoch`` is accepted for symmetry with ``train`` (and for naming saved
    comparison images); it does not affect the computation.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                # Originals and reconstructions of the first batch, stacked
                # for visual comparison (image saving is currently disabled).
                # reshape_as(data) instead of view(batch_size, ...) so a
                # batch shorter than batch_size does not break the reshape.
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                        recon_batch.reshape_as(data)[:n]])

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [None]:
if __name__ == "__main__":
    # Per-batch training losses; train() appends to this module-level list.
    loss_history = []

    epoch_range = range(1, epochs + 1)
    for epoch in epoch_range:
        print("Epoch = " + str(epoch) + " / " + str(epochs))

        train(epoch)

        # Re-plot the whole training-loss curve after every epoch.
        plt.figure(figsize=(16, 8))
        plt.plot(loss_history)
        plt.show()

        test(epoch)