## Notebook that trains a VAE for galaxy image generation

Implementation adapted from: https://github.com/pytorch/examples/blob/master/vae/main.py

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.utils import save_image
import torchvision
from PIL import Image
import pandas as pd
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.mixture import GaussianMixture
import torchvision.utils as vutils
import matplotlib.animation as animation
from IPython.display import HTML

# set up accordingly
data_dir = "../../data" # directory with data files
labeled_image_dir = "labeled/"       # folder within data directory with labeled images (0, 1)
scored_image_dir = "scored_128"      # folder within data directory with scored images
device = torch.device("cuda")        # cuda or cpu
collab = False                        # google collab flag

In [2]:
# google collab
# mount drive, copy over data as zip and unzip it
if collab:
  from google.colab import drive
  drive.mount('/content/drive')
  collab_dir = "/content/"

  zip_path = os.path.join(data_dir, 'labeled.zip')
  !cp '{zip_path}' .
  !unzip -q labeled.zip
  !rm labeled.zip

  zip_path = os.path.join(data_dir, 'scored_128.zip')
  !cp '{zip_path}' .
  !unzip -q scored_128.zip
  !rm scored_128.zip
else:
    labeled_image_dir = os.path.join(data_dir, labeled_image_dir)
    scored_image_dir = os.path.join(data_dir, scored_image_dir)

In [3]:
class GalaxyDataset(torch.utils.data.Dataset):
    """
    Galaxy dataset class
    Builds a dataset from the labeled and scored images. 
    Requires a threshold score for scored images. 
    Images with a score below the threshold are not used.
    """

    def __init__(self, csv_file, image_dir, scored_dir=None, scores_file=None, transform=None, train=True, size=(128, 128), train_split=0.8, scored_threshold=3):
        self.labels = pd.read_csv(csv_file, index_col="Id")
        self.labels = self.labels[self.labels['Actual'] == 1.0]
        self.size = size
        self.original = []
        self.scores = None
        if scores_file is not None and train == True:
          self.scores = pd.read_csv(scores_file, index_col="Id")
        self.samples = []
        if train == True:
          self.labels = self.labels[:int(self.labels.shape[0]*train_split)]
        else:
          self.labels = self.labels[int(self.labels.shape[0]*train_split):]
        self.image_dir = image_dir
        self.transform = transform
        self.scored_dir = scored_dir
        self.scored_threshold = scored_threshold
        self.load_dataset()
    def __len__(self):
        return len(self.samples)

    def load_dataset(self):
      print("Loading Dataset...")
      for id, _ in self.labels.iterrows():
        img_name = os.path.join(self.image_dir,
                                  str(id)+'.png')
        self.original.append(Image.open(img_name))
        self.samples.append(Image.open(img_name).resize(self.size))
      
      if self.scores is not None:
        for id, score in self.scores.iterrows():
          if score.item() > self.scored_threshold:

            img_name = os.path.join(self.scored_dir,
                                      str(id)+'.png')
            self.original.append(Image.open(img_name))
            self.samples.append(Image.open(img_name).resize(self.size))
        
      print("Dataset Loaded")

    def __getitem__(self, idx):

        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = self.samples[idx]
        if self.transform:
            image = self.transform(image)

        return image

In [4]:
batch_size = 64
size = (128, 128)
train_transformation = torchvision.transforms.Compose([
                            torchvision.transforms.ToTensor(),
                            torchvision.transforms.Normalize(0, 255.0)
])
val_transformation = torchvision.transforms.Compose([
                            torchvision.transforms.ToTensor(),
                            torchvision.transforms.Normalize(0, 255.0)
])

train_dataset = GalaxyDataset(os.path.join(data_dir, "labeled.csv"), labeled_image_dir, scored_dir=scored_image_dir, scores_file=os.path.join(data_dir, "scored.csv"), transform=train_transformation, train=True, size=size, train_split=0.8, scored_threshold=2.60)
val_dataset = GalaxyDataset(os.path.join(data_dir, "labeled.csv"), labeled_image_dir, transform=val_transformation, train=False, size=size, train_split=0.8)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=1, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, num_workers=1, pin_memory=True)

print("Loaded {} train images, {} val images".format(len(train_dataset), len(val_dataset)))

Loading Dataset...
Dataset Loaded
Loading Dataset...
Dataset Loaded
Loaded 2554 train images, 200 val images


Setup weight initialization

In [0]:
def kaiming_init(m):
      if isinstance(m, (nn.Linear, nn.Conv2d)):
          init.kaiming_normal(m.weight)
          if m.bias is not None:
              m.bias.data.fill_(0)
      elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
          m.weight.data.fill_(1)
          if m.bias is not None:
              m.bias.data.fill_(0)

In [0]:
class Encoder(nn.Module):
    def __init__(self, config):
        self.config = config
        super(Encoder, self).__init__()
        # 1 * 128 * 128
        self.conv1 = nn.Conv2d(1, 32, 3, stride=2, padding=1)
        self.batch_norm1 = nn.BatchNorm2d(32)
        # 32 * 64 * 64
        self.conv2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(64)
        # 64 * 32 * 32
        self.conv3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.batch_norm3 = nn.BatchNorm2d(128)
        # 128 * 16 * 16
        self.conv4 = nn.Conv2d(128, 128, 3, stride=2, padding=1)
        self.batch_norm4 = nn.BatchNorm2d(128)
        # 128 * 8 * 8
        self.conv5 = nn.Conv2d(128, 128, 3, stride=2, padding=1)
        self.batch_norm5 = nn.BatchNorm2d(128)
        # 128 * 4 * 4

        self.mu_linear = nn.Linear(128 * 4 * 4, self.config['z_dim'])
        self.log_sigma_squared_linear = nn.Linear(128 * 4 * 4, self.config['z_dim'])
        self.weight_init()

    def weight_init(self):
        for m in self._modules:
            kaiming_init(m)
            
    def forward(self, x):
        x = self.batch_norm1(F.relu(self.conv1(x)))
        x = self.batch_norm2(F.relu(self.conv2(x)))
        x = self.batch_norm3(F.relu(self.conv3(x)))
        x = self.batch_norm4(F.relu(self.conv4(x)))
        x = self.batch_norm5(F.relu(self.conv4(x)))
        x = x.view(-1, 128 * 4 * 4)
        return self.mu_linear(x), self.log_sigma_squared_linear(x)

In [0]:
class Decoder(nn.Module):
    def __init__(self, config):
        self.config = config
        super(Decoder, self).__init__()
        self.upsample = nn.Upsample()
        self.linear1 = nn.Linear(self.config['z_dim'], 128 * 4 * 4)
        self.linear2 = nn.Linear(128 * 4 * 4, 128 * 4 * 4)
        self.conv1 = nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1)
        self.batch_norm1 = nn.BatchNorm2d(128)
        self.conv2 = nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1)
        self.batch_norm2 = nn.BatchNorm2d(128)
        self.conv3 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
        self.batch_norm3 = nn.BatchNorm2d(64)
        self.conv4 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)
        self.batch_norm4 = nn.BatchNorm2d(32)
        self.conv5 = nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1)
        self.batch_norm5 = nn.BatchNorm2d(1)
        self.conv6 = nn.Conv2d(1, 1, 1, stride=1)
    
        self.weight_init()

    def weight_init(self):
        for m in self._modules:
            kaiming_init(m)
            
    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = x.view(-1, 128, 4, 4)
        x = self.batch_norm1(F.relu(self.conv1(x)))
        x = self.batch_norm2(F.relu(self.conv2(x)))
        x = self.batch_norm3(F.relu(self.conv3(x)))
        x = self.batch_norm4(F.relu(self.conv4(x)))
        x = self.batch_norm5(F.relu(self.conv5(x)))
        x = self.conv6(x)
        return x

In [0]:
class VAE(nn.Module):
    def __init__(self, config):
        super(VAE, self).__init__()
        self.config = config
        self.Encoder = Encoder(config)
        self.Decoder = Decoder(config)
        
    def forward(self, x):
        mu, log_sigma_sq = self.Encoder(x)
        # sample
        z = self.reparametrize(mu, log_sigma_sq)
        return self.Decoder(z), mu, log_sigma_sq

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.randn_like(std)
        return eps * std + mu

    def encode(self, x):
        return self.Encoder(x)
        
    def decode(self, x):
        return self.Decoder(x)

Set up training parameters

In [0]:
config = {'z_dim': 64}
vae = VAE(config).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
Beta = 1
Lambda = 1
epochs = 500

kl_start = 40
kl_anneal_time = 20
annealing_factor = 0

# Reconstruction + KL divergence losses summed over all elements and batch, annealed
def loss_function(recon_x, x, mu, logvar, epoch):
    recon_x = recon_x.view(-1, size[0]*size[1])
    x = x.view(-1, size[0]*size[1])

    BCE = torch.nn.BCEWithLogitsLoss(reduction='none')
    recon_loss = BCE(recon_x, x).sum(1)/(size[0]*size[1])
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # add normalization by image size
    KLD = (-0.5*(1 + logvar - mu.pow(2) - logvar.exp()).sum(1))/config['z_dim']
    
    annealing_factor = min(kl_anneal_time, max(epoch - kl_start, 0))/kl_anneal_time
    return (recon_loss + Beta*KLD).mean() #+ Lambda*l1_loss

In [0]:
def train(epoch):
    vae.train()
    train_loss = 0
    for batch_idx, (data) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        loss = loss_function(recon_batch, data, mu, logvar, epoch)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 5 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item()))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / (len(train_loader.dataset) // batch_size)))

In [0]:
def val(epoch):
    vae.eval()
    val_loss = 0
    with torch.no_grad():
        for i, (data)  in enumerate(val_loader):
            data = data.to(device)
            recon_batch, mu, logvar = vae(data)
            recon_sigmoided = torch.sigmoid(recon_batch)
            val_loss += loss_function(recon_batch, data, mu, logvar, epoch).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_sigmoided.view(batch_size, 1, size[0], size[1])[:n]])
                save_image(comparison.cpu(),
                         os.path.join(results_dir, 'reconstruction',str(epoch) + '.png'), nrow=n, padding=2, pad_value=1)

    val_loss /=  (len(val_loader.dataset) // batch_size)
    print('====> Validation loss: {:.4f}'.format(val_loss))

Train and periodically save output from generator

In [25]:
results_dir = "results_vae"
if not os.path.exists(os.path.join(outf, results_dir)):
  os.mkdir(os.path.join(outf, results_dir))
  os.mkdir(os.path.join(outf, results_dir, "reconstruction"))
  os.mkdir(os.path.join(outf, results_dir, "sample"))
for epoch in range(1, epochs):
  train(epoch)
  val(epoch)
  with torch.no_grad():
    sample = torch.randn(16, config['z_dim']).to(device)
    sample = torch.sigmoid(vae.decode(sample).cpu())
    save_image(sample.view(16, 1, size[0], size[1]),
                os.path.join(outf, results_dir, 'sample', str(epoch) + '.png'), padding=2, pad_value=1)

====> Epoch: 1 Average loss: 0.3872
====> Validation loss: 0.4398
====> Epoch: 2 Average loss: 0.3147
====> Validation loss: 0.4074
====> Epoch: 3 Average loss: 0.2982
====> Validation loss: 0.3789
====> Epoch: 4 Average loss: 0.2820
====> Validation loss: 0.3794
====> Epoch: 5 Average loss: 0.2688
====> Validation loss: 0.5629
====> Epoch: 6 Average loss: 0.2490
====> Validation loss: 0.3170
====> Epoch: 7 Average loss: 0.2505
====> Validation loss: 6.5031
====> Epoch: 8 Average loss: 0.2189
====> Validation loss: 0.2746
====> Epoch: 9 Average loss: 0.1989
====> Validation loss: 0.2493
====> Epoch: 10 Average loss: 0.1862
====> Validation loss: 0.2678
====> Epoch: 11 Average loss: 0.1707
====> Validation loss: 0.2157
====> Epoch: 12 Average loss: 0.1548
====> Validation loss: 0.2368
====> Epoch: 13 Average loss: 0.1417
====> Validation loss: 0.1811
====> Epoch: 14 Average loss: 0.1294
====> Validation loss: 0.1624
====> Epoch: 15 Average loss: 0.1181
====> Validation loss: 0.1485
====

KeyboardInterrupt: ignored

In [0]:
torch.save(vae.state_dict(), os.path.join(outf, results_dir, "vae.model"))
if collab:
  !cp {results_dir}' vae.model '{data_dir}'/'{results_dir}'
  !cp -rf '{results_dir}' '{data_dir}'