Code taken from: https://github.com/pytorch/examples/tree/master/vae

In [12]:
from __future__ import print_function
import argparse
import os
import h5py
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

In [13]:
# Hard-coded stand-ins for the command-line options of the upstream example.
batch_size = 4
epochs = 2
no_cuda = False
log_interval = 10

# The GPU is used only when it is both permitted and actually present.
cuda = torch.cuda.is_available() if not no_cuda else False

# Fixed seed for PyTorch's random number generator.
seed = 1
torch.manual_seed(seed)

# Device that tensors and the model are moved to.
device = torch.device("cuda") if cuda else torch.device("cpu")

# Extra DataLoader options; only populated when running on CUDA.
kwargs = {}
if cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}

In [14]:
class HydrogenDataset(Dataset):
    """Dataset of sampled subcubes stored in an HDF5 file.

    Items are the entries of the ``'sample32'`` HDF5 dataset inside the
    file. The h5py dataset object is indexed on demand in ``__getitem__``
    rather than being copied into a Python list up front.
    """

    def __init__(self, h5_file, root_dir):
        """
        Args:
            h5_file (string): name of the h5 file with 32 sampled cubes.
            root_dir (string): Directory with the .h5 file.

        Raises:
            OSError: if the file does not exist or cannot be opened.
            KeyError: if the file has no ``'sample32'`` dataset.
        """
        # os.path.join inserts the separator when root_dir lacks a trailing
        # slash; plain string concatenation silently built a wrong path.
        h5_path = os.path.join(root_dir, h5_file)

        file_size = os.path.getsize(h5_path) / 1e6  # in MBs
        print("The file size is " + str(int(file_size)) + " MBs")

        # Keep the file handle so it can be released with close().
        self._h5 = h5py.File(h5_path, 'r')
        self.subcubes = self._h5['sample32']
        self.h5_file = h5_file
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of subcubes (length of the first axis)."""
        return len(self.subcubes)

    def __getitem__(self, idx):
        """Return the subcube stored at position ``idx``.

        The value is whatever indexing ``self.subcubes`` yields; no
        conversion or normalisation is applied here.
        NOTE(review): the original notes say the file is produced by a
        ``get_sample()`` helper that is not visible in this notebook, so the
        per-item dimensions are determined there -- confirm.
        """
        return self.subcubes[idx]

    def close(self):
        """Close the underlying HDF5 file; items cannot be read afterwards."""
        self._h5.close()

In [15]:
# Wrap the sampled-subcube HDF5 file in a Dataset; the bare name on the last
# line makes the notebook display the object.
sampled_subcubes = HydrogenDataset(root_dir="../data/", h5_file="sample_32.h5")
sampled_subcubes

The file size is 268 MBs


<__main__.HydrogenDataset at 0x11c64a588>

In [16]:
# Data loaders.
# NOTE(review): both loaders wrap the same `sampled_subcubes` object, so the
# "test" loss is computed on the training data -- confirm this is intended.
train_loader = DataLoader(sampled_subcubes,
                          batch_size=batch_size,
                          shuffle=True,
                          **kwargs)

test_loader = DataLoader(sampled_subcubes,
                         batch_size=batch_size,
                         shuffle=True,
                         **kwargs)

In [17]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for cubic 3D volumes.

    Encoder: two groups of (strided Conv3d -> ReLU -> MaxPool3d) followed by
    two fully-connected layers and two linear heads that produce the mean and
    the log-variance of the latent Gaussian.
    Decoder: two fully-connected layers followed by two ConvTranspose3d
    layers that upsample back to the input resolution, with a final sigmoid
    so reconstructions lie in [0, 1].

    Args:
        in_channels (int): channels per input volume.
        cube_size (int): edge length of the cubic input; must be a positive
            multiple of 16 because the encoder divides each edge by 16.
        latent_dim (int): size of the latent vector.

    Raises:
        ValueError: if ``cube_size`` is not a positive multiple of 16.
    """

    def __init__(self, in_channels=1, cube_size=128, latent_dim=32):
        super(VAE, self).__init__()

        if cube_size < 16 or cube_size % 16 != 0:
            raise ValueError(
                "cube_size must be a positive multiple of 16, got {}".format(
                    cube_size))

        self.in_channels = in_channels
        self.cube_size = cube_size
        self.latent_dim = latent_dim

        # Each encoder group halves every edge twice (stride-2 convolution,
        # then 2x2x2 pooling), so two groups shrink the edge by 16.
        self._bottleneck_channels = 16
        self._bottleneck_size = cube_size // 16
        # Flattened conv output = channels * edge^3 (8192 for a 128^3 cube).
        flat_features = (self._bottleneck_channels
                         * self._bottleneck_size ** 3)

        # in_channels of a layer must equal out_channels of the previous one;
        # the first layer takes the number of channels of the data itself.
        self.encode_group1 = nn.Sequential(
            nn.Conv3d(in_channels=in_channels,
                      out_channels=64,
                      kernel_size=(4, 4, 4),
                      stride=(2, 2, 2),
                      padding=(1, 1, 1)),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 2, 2),
                         stride=(2, 2, 2)))

        self.encode_group2 = nn.Sequential(
            nn.Conv3d(in_channels=64,
                      out_channels=self._bottleneck_channels,
                      kernel_size=(4, 4, 4),
                      stride=(2, 2, 2),
                      padding=(1, 1, 1)),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 2, 2),
                         stride=(2, 2, 2)))

        # Fully connected layers applied to the flattened conv features.
        self.encode_fc1 = nn.Sequential(
            nn.Linear(in_features=flat_features,
                      out_features=2048),
            nn.ReLU(),
            nn.Dropout(0.5))

        self.encode_fc2 = nn.Sequential(
            nn.Linear(in_features=2048,
                      out_features=2048),
            nn.ReLU(),
            nn.Dropout(0.5))

        # Two heads on the shared features: mean and log-variance.
        self.encode_fc3 = nn.Sequential(
            nn.Linear(in_features=2048,
                      out_features=latent_dim))
        self.encode_fc3_logvar = nn.Sequential(
            nn.Linear(in_features=2048,
                      out_features=latent_dim))

        # Grouping convolutional and fully-connected layers
        self._encode_conv = nn.Sequential(
            self.encode_group1,
            self.encode_group2
        )

        self._encode_fc = nn.Sequential(
            self.encode_fc1,
            self.encode_fc2
        )

        # Decoder: latent vector -> flat bottleneck features.
        self.decode_fc = nn.Sequential(
            nn.Linear(in_features=latent_dim,
                      out_features=2048),
            nn.ReLU(),
            nn.Linear(in_features=2048,
                      out_features=flat_features),
            nn.ReLU())

        # kernel_size == stride == 4 with no padding multiplies each edge by
        # exactly 4, so two layers undo the encoder's factor of 16.
        self.decode_conv = nn.Sequential(
            nn.ConvTranspose3d(in_channels=self._bottleneck_channels,
                               out_channels=64,
                               kernel_size=4,
                               stride=4),
            nn.ReLU(),
            nn.ConvTranspose3d(in_channels=64,
                               out_channels=in_channels,
                               kernel_size=4,
                               stride=4))

    # Encoding part of VAE
    def encode(self, x):
        """Map a (batch, channels, D, H, W) tensor to ``(mu, logvar)``.

        Both returned tensors have shape (batch, latent_dim).
        """
        out = self._encode_conv(x)
        out = out.view(out.size(0), -1)
        out = self._encode_fc(out)
        return self.encode_fc3(out), self.encode_fc3_logvar(out)

    # Reparametrization Trick
    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    # Decoding part of VAE
    def decode(self, z):
        """Map latent vectors (batch, latent_dim) to reconstructed volumes.

        Returns a tensor of shape (batch, in_channels, cube_size, cube_size,
        cube_size) with values in [0, 1].
        """
        out = self.decode_fc(z)
        out = out.view(z.size(0),
                       self._bottleneck_channels,
                       self._bottleneck_size,
                       self._bottleneck_size,
                       self._bottleneck_size)
        return torch.sigmoid(self.decode_conv(out))

    # Forward Pass
    def forward(self, x):
        """Encode, sample and decode.

        Args:
            x: tensor of shape (batch, channels, D, H, W), or
               (batch, D, H, W) for volumes stored without a channel axis.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.

        Raises:
            ValueError: if ``x`` is neither 4- nor 5-dimensional.
        """
        if x.dim() == 4:
            # Insert the channel axis that Conv3d expects.
            x = x.unsqueeze(1)
        elif x.dim() != 5:
            raise ValueError(
                "expected a 4D or 5D input, got {} dimensions".format(
                    x.dim()))
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [None]:
# Build the network, place it on the selected device, and let Adam optimise
# all of its parameters with a learning rate of 1e-3.
model = VAE()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: summed binary cross-entropy plus KL divergence.

    Args:
        recon_x: reconstruction with values in [0, 1].
        x: target with the same number of elements as ``recon_x``; it is
           reshaped to ``recon_x``'s shape, so a missing or extra singleton
           axis (or a flattened layout) is accepted.
        mu: latent means.
        logvar: latent log-variances.

    Returns:
        Scalar tensor ``BCE + KLD``.
    """
    # Reshape to the reconstruction's own shape instead of a fixed width of
    # 784, which only fits 28x28 inputs.
    target = x.reshape(recon_x.shape)
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

In [None]:
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print progress.

    Uses the module-level ``model``, ``optimizer``, ``device``,
    ``train_loader``, ``loss_function`` and ``log_interval``.

    Args:
        epoch (int): epoch number, used only in the printed messages.
    """
    model.train()
    train_loss = 0
    for batch_idx, batch in enumerate(train_loader):
        # A dataset that returns bare samples is collated into a single
        # tensor, while (data, label) samples arrive as a list/tuple; accept
        # both instead of always unpacking two values.
        data = batch[0] if isinstance(batch, (list, tuple)) else batch
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            # Per-sample loss of the current batch.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    # The loss is summed over elements, so divide by the number of samples.
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

In [None]:
def test(epoch):
    """Evaluate the model on ``test_loader`` and print the average loss.

    For the first batch, up to 8 inputs and their reconstructions are
    concatenated and written to ``results/reconstruction_<epoch>.png``.
    Uses the module-level ``model``, ``device``, ``test_loader`` and
    ``loss_function``.

    Args:
        epoch (int): epoch number, used in the output file name.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            # Accept both bare-tensor batches and (data, label) batches.
            data = batch[0] if isinstance(batch, (list, tuple)) else batch
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                # Reshape to the input's own shape rather than a fixed
                # (batch_size, 1, 28, 28), which fails for a short batch or
                # for any non-28x28 data.
                comparison = torch.cat([data[:n],
                                        recon_batch.reshape(data.shape)[:n]])
                # Make sure the output directory exists before writing.
                os.makedirs('results', exist_ok=True)
                # NOTE(review): save_image is not imported anywhere in the
                # visible notebook (the upstream example takes it from
                # torchvision.utils) -- confirm it is in scope. It also
                # presumably expects image-shaped batches, so volumetric
                # data may need to be sliced to 2D first -- TODO confirm.
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [None]:
if __name__ == "__main__":
    # Train and evaluate for `epochs` epochs (numbered from 1); after each
    # epoch, decode 64 random latent vectors and save them as an image.
    for epoch in range(1, epochs + 1):
        train(epoch)
        test(epoch)
        with torch.no_grad():
            # NOTE(review): the latent width 20 does not match the 32 output
            # features of VAE.encode_fc3 defined in this notebook -- confirm
            # which latent size is intended.
            sample = torch.randn(64, 20).to(device)
            sample = model.decode(sample).cpu()
            # NOTE(review): the (64, 1, 28, 28) view matches the 28x28 images
            # of the upstream example, not the h5 subcubes loaded here, and
            # save_image is not imported in the visible notebook -- confirm
            # both before running. The 'results/' directory must also exist.
            save_image(sample.view(64, 1, 28, 28),
                       'results/sample_' + str(epoch) + '.png')