In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
from pytorch_msssim import ssim

import matplotlib.pyplot as plt
from matplotlib import animation 
from PIL import Image

from tqdm import tqdm

import os

from IPython.display import HTML

# Config

In [None]:
# Optimisation hyper-parameters
learning_rate = 1e-4
epochs = 250
batch_size = 128

# Run identifier: every artefact below is written under outputs/<experiment_name>
experiment_name = "convae_v1"

# Architecture / input geometry
num_layers = 2
max_filters = 1024
image_size = 32

# Dataset locations, one sub-folder per split
data_prefix = "data"
train_data_folder, val_data_folder, test_data_folder = (
    f"{data_prefix}/{split}/" for split in ("train", "val", "test")
)

# Output artefacts
output_prefix = f"outputs/{experiment_name}"
output_dir = f"{output_prefix}/generated"
model_output_path = f"{output_prefix}/model.pt"
animation_output_path = f"{output_prefix}/animation.mp4"
loss_output_path = f"{output_prefix}/loss.jpg"

In [None]:
# Prefer CUDA when it is available, otherwise run on the CPU.
gpu = torch.cuda.is_available()
device = torch.device("cuda") if gpu else torch.device("cpu")

In [None]:
print(f"{gpu} {device}")

In [None]:
# Create outputs/<experiment>/generated (and its parents). exist_ok=True replaces the
# check-then-create pattern, so the cell is idempotent and free of the race between the
# existence check and the mkdir.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Seed both global RNGs used below (NumPy for sampling, PyTorch for init/shuffling/noise).
seed = 42
for seed_fn in (np.random.seed, torch.manual_seed):
    _ = seed_fn(seed)

# Load Dataset

In [None]:
def load_images_from_folder(folder):
    """Load every file in ``folder`` as an image array.

    Args:
        folder: directory to scan (not recursive).

    Returns:
        dict mapping file name -> ``np.array`` of the opened image. File names are
        visited in sorted order, so the dict's insertion order (and everything
        derived from it, such as seeded random picks by index) is reproducible
        across file systems. Sub-directories are skipped.
    """
    dataset = {}
    # os.listdir order is arbitrary; sort it for a deterministic dataset ordering.
    for file in sorted(os.listdir(folder)):
        path = os.path.join(folder, file)
        if not os.path.isfile(path):
            continue
        # The context manager closes the underlying file handle; np.array forces the
        # pixel data to be read before the file is closed.
        with Image.open(path) as image:
            dataset[file] = np.array(image)
    return dataset

In [None]:
train = load_images_from_folder(folder=train_data_folder)  # {file name: image array}

In [None]:
val = load_images_from_folder(folder=val_data_folder)  # {file name: image array}

In [None]:
test = load_images_from_folder(folder=test_data_folder)  # {file name: image array}

In [None]:
tuple(len(split) for split in (train, val, test))

# Visualize Some Examples

In [None]:
# Positional (list) view of each split's images, used for random visual inspection.
data = {
    split_name: list(split_images.values())
    for split_name, split_images in (("train", train), ("test", test), ("val", val))
}

In [None]:
# 2 x 3 grid: one column per split, two random images per column.
for col, split_name in enumerate(data):
    split_images = data[split_name]
    for row in range(2):
        pick = np.random.randint(0, len(split_images))
        ax = plt.subplot(2, 3, row * 3 + col + 1)
        plt.imshow(split_images[pick])
        ax.set_title(f"{split_name} ({pick})")
plt.tight_layout()

# Preprocess Data

In [None]:
# Preprocessing pipeline applied to every image array:
#   ToTensor: (H, W, C) uint8 array -> (C, H, W) float tensor scaled to [0, 1]
#   Resize:   to exactly image_size x image_size. An int size would only fix the
#             shorter edge, giving non-square tensors for non-square inputs, which
#             breaks batching and the model's square-input assumption.
#   clamp:    bicubic interpolation on float tensors can overshoot [0, 1]; the BCE
#             reconstruction loss targets and imshow both expect values in [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.Lambda(lambda t: t.clamp(0.0, 1.0)),
])

In [None]:
# Preprocess every split up front; test items keep their file name so reconstructions
# can later be saved under the same name.
train_data = list(map(transform, train.values()))
val_data = list(map(transform, val.values()))
test_data = [(fname, transform(img)) for fname, img in test.items()]

# Setup Dataloaders

In [None]:
def _make_dataloader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the notebook's shared batching settings.

    Args:
        dataset: an indexable dataset (here, a list of tensors or (name, tensor) pairs).
        shuffle: whether to reshuffle the order every epoch.
    """
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4, pin_memory=gpu
    )


# Only training benefits from shuffling. Validation/test passes just accumulate metrics,
# so a fixed order keeps them reproducible and keeps test outputs in file-name order.
train_dataloader = _make_dataloader(train_data, shuffle=True)
val_dataloader = _make_dataloader(val_data, shuffle=False)
test_dataloader = _make_dataloader(test_data, shuffle=False)

# Visualize Preprocessed Samples

In [None]:
def make_grid(images, height, width, axis):
    """Draw up to ``height * width`` (C, H, W) image tensors on a grid of axes.

    Args:
        images: ``(text, tensors)`` pair. ``text`` becomes the figure title; an
            ``int`` (as passed by FuncAnimation via ``enumerate(all_samples)``) is
            rendered as ``"Epoch: <text>"``. ``tensors`` is an iterable of (C, H, W)
            tensors, on any device.
        height, width: number of grid rows / columns to fill.
        axis: 2-D array of matplotlib Axes, as returned by ``plt.subplots``.

    Returns:
        The same ``axis`` array.
    """
    text, images = images
    for num, image in enumerate(images):
        if num == height * width:
            break
        ax = axis[num // width, num % width]
        # Remove images from a previous call (e.g. the prior animation frame) so frames
        # do not pile up on the axes and all get re-rendered every frame.
        for old_image in list(ax.images):
            old_image.remove()
        ax.imshow(np.asarray(image.detach().cpu().permute(1, 2, 0)))
    if isinstance(text, int):
        text = f"Epoch: {text}"
    # Title and lay out the figure that owns these axes rather than a global `fig`,
    # which may point at a different figure by the time an animation is saved.
    figure = axis.flat[0].figure
    figure.suptitle(text, va="baseline")
    figure.tight_layout()
    return axis

In [None]:
def get_samples_from_data(data, sample_size, test=False):
    """Draw ``sample_size`` distinct items from ``data`` using the global NumPy RNG.

    Args:
        data: list of image tensors, or of ``(file_name, tensor)`` pairs when ``test``.
        sample_size: number of items to draw without replacement.
        test: if True, take the tensor from the second element of each pair.

    Returns:
        A tensor stacking the selected images along a new leading dimension.
    """
    indices = np.random.choice(len(data), size=sample_size, replace=False)
    picked = [np.asarray(data[i][1] if test else data[i]) for i in indices]
    if not picked:
        return torch.as_tensor(picked)
    # Stack into one ndarray first: torch.as_tensor on a list of ndarrays is very
    # slow and emits a UserWarning.
    return torch.as_tensor(np.stack(picked))

In [None]:
# Fixed sample sets: the validation one is reconstructed after every epoch to visualise
# training progress, the test one after training. Evaluation order is preserved because
# both draws consume the global NumPy RNG.
sample, test_sample = (
    get_samples_from_data(val_data, 16),
    get_samples_from_data(test_data, 16, test=True),
)

In [None]:
# Show the fixed validation sample before training.
grid_rows, grid_cols = 4, 4
fig, axis = plt.subplots(nrows=grid_rows, ncols=grid_cols, figsize=(8, 6), dpi=80)
plt.tight_layout()
_ = make_grid(("Sample", sample), grid_rows, grid_cols, axis)

# Model Time

In [None]:
# Ref: https://github.com/sksq96/pytorch-vae/blob/master/vae-cnn.ipynb
class ConvolutionalVAE(nn.Module):
    """Convolutional variational auto-encoder.

    Encoder: ``num_layers`` blocks of Conv2d -> BatchNorm2d -> ReLU whose channel count
    doubles each block, then a flatten and two linear heads for the latent mean and
    log-variance. Decoder: a linear layer back to the flattened feature map, an
    unflatten, then mirrored ConvTranspose2d -> BatchNorm2d blocks with ReLU between
    them and a Sigmoid at the end, so reconstructions lie in (0, 1).

    Args:
        image_channels: channels of the input images.
        max_filters: target channel count of the last encoder block. If it is not
            divisible by ``2 ** (num_layers - 1)`` the channel ladder rounds down; the
            flattened size is computed from the channels actually produced.
        num_layers: number of encoder (and decoder) conv blocks; must be >= 1.
        kernel_size, stride, padding: shared by every (transposed) convolution.
        latent_dim: size of the latent vector z.
        input_image_dimensions: height (== width) of the square input images.
            NOTE(review): the decoder only restores this size when every encoder conv
            maps it exactly (e.g. kernel_size == stride == 2, padding 0, and a size
            divisible by ``2 ** num_layers``).

    Raises:
        ValueError: if ``num_layers < 1`` or ``max_filters`` is too small to give the
            first block at least one channel.
    """

    def __init__(self, image_channels=3, max_filters=512, num_layers=4, kernel_size=2, stride=2,
                 padding=0, latent_dim=128, input_image_dimensions=96):
        super().__init__()
        channel_sizes = self.calculate_channel_sizes(image_channels, max_filters, num_layers)

        # Encoder: one Conv2d -> BatchNorm2d -> ReLU block per (in, out) channel pair.
        encoder_layers = nn.ModuleList()
        for in_channels, out_channels in channel_sizes:
            # bias=False: the BatchNorm2d that follows applies its own learnable shift.
            encoder_layers.append(
                nn.Conv2d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                    stride=stride, padding=padding, bias=False
                )
            )
            encoder_layers.append(nn.BatchNorm2d(out_channels))
            encoder_layers.append(nn.ReLU())
        # Flatten (batch, C, H, W) -> (batch, C * H * W) for the linear latent heads.
        encoder_layers.append(nn.Flatten())
        self.encoder = nn.Sequential(*encoder_layers)

        # Channels actually produced by the last encoder block. This equals max_filters
        # only when max_filters is divisible by 2 ** (num_layers - 1); using it keeps
        # hidden_dim and the decoder's Unflatten consistent with the encoder.
        final_channels = channel_sizes[-1][1]
        hidden_dim, feature_size = self.get_flattened_size(
            kernel_size, stride, final_channels, input_image_dimensions, padding=padding
        )

        # Latent space: mean and log-variance heads.
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)

        # Decoder: expand z back to the encoder's final feature map, then mirror the encoder.
        decoder_layers = nn.ModuleList()
        decoder_layers.append(nn.Linear(latent_dim, hidden_dim))
        decoder_layers.append(nn.Unflatten(1, (final_channels, feature_size, feature_size)))
        last_index = len(channel_sizes) - 1
        for i, (out_channels, in_channels) in enumerate(reversed(channel_sizes)):
            decoder_layers.append(
                nn.ConvTranspose2d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                    stride=stride, padding=padding, bias=False
                )
            )
            decoder_layers.append(nn.BatchNorm2d(out_channels))
            # ReLU between blocks; Sigmoid on the final block so outputs are in (0, 1),
            # which the binary cross-entropy reconstruction loss requires.
            decoder_layers.append(nn.ReLU() if i != last_index else nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers)

    def get_flattened_size(self, kernel_size, stride, filters, input_image_dimensions, padding=0):
        """Return ``(filters * s * s, s)``, where ``s`` is the spatial size after every
        Conv2d in ``self.encoder`` (square inputs, dilation 1)."""
        size = input_image_dimensions
        for layer in self.encoder:
            if isinstance(layer, nn.Conv2d):
                size = (size + 2 * padding - kernel_size) // stride + 1
        return filters * size * size, size

    def calculate_channel_sizes(self, image_channels, max_filters, num_layers):
        """Return the ``(in_channels, out_channels)`` pair for each encoder block.

        The first block maps ``image_channels`` to ``max_filters // 2 ** (num_layers - 1)``
        and every later block doubles the channel count.

        Raises:
            ValueError: if ``num_layers < 1`` or the first block would get 0 channels.
        """
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        first_channels = max_filters // 2 ** (num_layers - 1)
        if first_channels < 1:
            raise ValueError(
                f"max_filters={max_filters} is too small for num_layers={num_layers}; "
                f"need at least {2 ** (num_layers - 1)}"
            )
        channel_sizes = [(image_channels, first_channels)]
        for _ in range(1, num_layers):
            prev = channel_sizes[-1][1]
            channel_sizes.append((prev, prev * 2))
        return channel_sizes

    def forward(self, x):
        """Encode ``x``, sample z via the reparameterisation trick and decode it.

        Returns:
            ``(reconstructed, mu, log_var)``.
        """
        hidden_state = self.encoder(x)
        mu = self.fc_mu(hidden_state)
        log_var = self.fc_log_var(hidden_state)
        z = self.reparameterize(mu, log_var)
        reconstructed = self.decoder(z)
        return reconstructed, mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, 1)`` and ``std = exp(0.5 * log_var)``."""
        std = torch.exp(log_var.mul(0.5))  # log sqrt(x) = log x^0.5 = 0.5 log x
        epsilon = torch.randn_like(mu)
        z = mu + (epsilon * std)
        return z

# Loss Function

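Since this is a VAE, in addition to reconstructing the input we also minimize the KL divergence between the encoder's approximate posterior $q(z \mid x) = \mathcal{N}(\mu, \sigma^2)$ and the standard normal prior $p(z) = \mathcal{N}(0, I)$.
The total loss is the reconstruction loss (binary cross-entropy, summed over pixels) plus this KL divergence, which has the closed form $D_{KL} = -\tfrac{1}{2} \sum \left(1 + \log\sigma^2 - \mu^2 - \sigma^2\right)$.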

In [None]:
def VAELoss(x, reconstructed_x, mu, log_var):
    """VAE objective: reconstruction term plus KL regulariser.

    Args:
        x: target images, values in [0, 1].
        reconstructed_x: decoder output, values in [0, 1].
        mu, log_var: latent mean and log-variance from the encoder.

    Returns:
        ``(total, reconstruction, kl)``. Every term is summed over the batch, so it
        scales with the batch size.
    """
    # Binary cross-entropy summed over every pixel of every image.
    bce = nn.functional.binary_cross_entropy(reconstructed_x, x, reduction='sum')
    # Closed-form KL( N(mu, exp(log_var)) || N(0, I) ), summed over batch and latent dims.
    kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    total = bce + kl_divergence
    return total, bce, kl_divergence

# Training Time

In [None]:
# Build the VAE from the config hyper-parameters, move it to the selected device and
# display its architecture.
model = ConvolutionalVAE(
    max_filters=max_filters,
    num_layers=num_layers,
    input_image_dimensions=image_size,
).to(device)
model

In [None]:
trainable_params = model.parameters()
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)  # Adam over all model weights

In [None]:
def _run_epoch(dataloader, training):
    """Run one pass of the global `model` over `dataloader`.

    When `training` is True the model is put in train mode and `optimizer` is stepped
    after every batch. Otherwise the model is put in eval mode (BatchNorm uses and does
    not update its running statistics) and autograd is disabled.

    Returns:
        (total, reconstruction, KL) losses averaged per sample. VAELoss sums over the
        batch, so dividing by the number of samples rather than batches gives a value
        that does not depend on batch_size or on a smaller final batch.
    """
    model.train(training)
    totals = [0.0, 0.0, 0.0]
    grad_context = torch.enable_grad() if training else torch.no_grad()
    with grad_context:
        for batch in tqdm(dataloader):
            batch = batch.to(device)
            if training:
                optimizer.zero_grad()
            reconstructed, mu, log_var = model(batch)
            batch_losses = VAELoss(batch, reconstructed, mu, log_var)
            if training:
                batch_losses[0].backward()
                optimizer.step()
            for k, loss in enumerate(batch_losses):
                totals[k] += loss.item()
    num_samples = max(len(dataloader.dataset), 1)
    return tuple(total / num_samples for total in totals)


all_samples = []
all_train_loss = []
all_val_loss = []

for epoch in range(epochs):
    train_loss, train_recon_loss, train_kl_d = _run_epoch(train_dataloader, training=True)
    val_loss, val_recon_loss, val_kl_d = _run_epoch(val_dataloader, training=False)

    # Reconstruct the fixed sample (model is in eval mode after validation).
    with torch.no_grad():
        epoch_sample, _, _ = model(sample.to(device))
    all_samples.append(epoch_sample.detach().cpu())

    all_train_loss.append((train_loss, train_recon_loss, train_kl_d))
    all_val_loss.append((val_loss, val_recon_loss, val_kl_d))

    print(
        f"Epoch: {epoch+1}/{epochs}:\n"
        f"Train Loss = {train_loss}, Train Reconstruction Loss = {train_recon_loss}, Train KL Divergence = {train_kl_d}\n"
        f"Val Loss = {val_loss}, Val Reconstruction Loss = {val_recon_loss}, Val KL Divergence = {val_kl_d}"
    )

# Visualize Training

In [None]:
# Re-plot the original (un-reconstructed) validation sample for side-by-side comparison.
rows = cols = 4
fig, axis = plt.subplots(rows, cols, dpi=80, figsize=(8, 6))
plt.tight_layout()
_ = make_grid(("Sample", sample), rows, cols, axis)

In [None]:
# Build an animation that redraws the sample-reconstruction grid once per epoch.
fig, axis = plt.subplots(4, 4, figsize=(8, 6), dpi=80)
plt.tight_layout()
plt.close()  # close the current figure so the blank grid is not rendered inline
epoch_frames = [(epoch_idx, recon) for epoch_idx, recon in enumerate(all_samples)]
anim = animation.FuncAnimation(
    fig,
    make_grid,
    frames=epoch_frames,
    fargs=(4, 4, axis),
    interval=100,
    repeat=False,
)

In [None]:
video_html = anim.to_html5_video()
HTML(video_html)

# Evaluation

In [None]:
model.train(False)  # same as model.eval(); returns (and displays) the model

In [None]:
input_batches = []
output_batches = []
file_names = []

# Testing loop: reconstruct every test image. Eval mode is set here so the cell does not
# depend on an earlier cell having run.
model.eval()
with torch.no_grad():
    for filename, image in tqdm(test_dataloader):
        image = image.to(device)
        reconstructed, _, _ = model(image)
        input_batches.append(image.cpu())
        output_batches.append(reconstructed.cpu())
        file_names.extend(filename)

# Concatenate batch tensors directly instead of building a list of per-image ndarrays,
# which torch.as_tensor converts very slowly (with a warning).
all_inputs = torch.cat(input_batches)
all_outputs = torch.cat(output_batches)

mse = nn.functional.mse_loss(all_outputs, all_inputs)
ssim_score = ssim(all_outputs, all_inputs, data_range=1.0, win_size=11, win_sigma=1.5, K=(0.01, 0.03))

# Print Metrics
print(
    f"MSE = {mse}, SSIM = {ssim_score}"
)

In [None]:
# Plot the fixed set of original test images.
fig, axis = plt.subplots(4, 4, figsize=(8, 6), dpi=80)
plt.tight_layout()
title_and_images = ("Test Sample", test_sample)
_ = make_grid(title_and_images, 4, 4, axis)

In [None]:
# Reconstruct the fixed test sample. The input is moved to the model's device (it was
# built on the CPU), and the output is moved back to the CPU because make_grid converts
# it with NumPy.
with torch.no_grad():
    reconstructed = model(test_sample.to(device))[0].detach().cpu()

In [None]:
# Plot the model's reconstructions of the fixed test images.
n_rows, n_cols = 4, 4
fig, axis = plt.subplots(n_rows, n_cols, figsize=(8, 6), dpi=80)
plt.tight_layout()
_ = make_grid(("Reconstructed Test", reconstructed), n_rows, n_cols, axis)

In [None]:
# Per-epoch reconstruction and KL curves for both splits; saved to loss_output_path.
loss_fig, ax = plt.subplots(figsize=(8, 6), dpi=100)
curves = (
    (all_train_loss, 1, "Train Reconstruction Loss"),
    (all_train_loss, 2, "Train KL-Divergence"),
    (all_val_loss, 1, "Validation Reconstruction Loss"),
    (all_val_loss, 2, "Validation KL-Divergence"),
)
for history, column, label in curves:
    ax.plot([row[column] for row in history], label=label)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
loss_fig.tight_layout()
loss_fig.savefig(loss_output_path)

# Save Model

In [None]:
trained_weights = model.state_dict()  # parameters and buffers only, not the class
torch.save(trained_weights, model_output_path)

# Save Generated Images

In [None]:
# Write each reconstruction under its source file name. A new name is used instead of
# overwriting `all_outputs`, so the cell can be re-run (a second `.permute` on the
# NumPy array would fail).
generated_images = all_outputs.permute(0, 2, 3, 1).numpy()  # (N, C, H, W) -> (N, H, W, C)
for image, name in zip(generated_images, file_names):
    plt.imsave(os.path.join(output_dir, name), image)

# Save Animation

In [None]:
# Encode the per-epoch animation to MP4 with matplotlib's registered ffmpeg writer.
ffmpeg_writer_class = animation.writers['ffmpeg']
anim.save(animation_output_path, writer=ffmpeg_writer_class())