In [None]:
import torch
import os
import torchvision
from torchvision.transforms import v2
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torch.nn as nn
from PIL import Image, ImageOps
import torch
import pdb
import numpy as np
import yaml
from tqdm import tqdm
import sys
import matplotlib.pyplot as plt
sys.path.append('..')
from template import utils
from torchvision.utils import save_image

import torch.nn.functional as F

In [None]:
# Load the experiment configuration from config.yaml in the working directory.
# The context manager closes the file handle deterministically instead of
# leaving it open until the garbage collector gets to it.
with open("config.yaml") as config_file:
    config = yaml.safe_load(config_file)

# Cast explicitly so a quoted value in the YAML file still yields an int.
batch_size = int(config["BATCH_SIZE"])

print(f"Our config: {config}")

In [None]:
# Preprocessing applied to every dataset sample: convert the image to a
# tensor, then resize it to a fixed 128x128 resolution (the plotting cells
# further down rely on this size).
transform = transforms.Compose([
    transforms.ToTensor(),
    v2.Resize((128, 128))
])

In [None]:
# Build the three CelebA splits (stored under ./data, downloaded if missing),
# all sharing the same preprocessing pipeline.
train_dataset, valid_dataset, test_dataset = (
    torchvision.datasets.CelebA(root='./data', split=split_name,
                                download=True, transform=transform)
    for split_name in ('train', 'valid', 'test')
)

# Wrap every split in a DataLoader; only the training split is shuffled.
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
validloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Display images from the dataset

In [None]:
# Pull a single batch from the test loader to sanity-check the tensor shapes.
imgs, labels = next(iter(testloader))
print(f"Image Shapes: {imgs.shape}")
print(f"Label Shapes: {labels.shape}")

In [None]:
# Number of random training images shown in the preview row.
N_IMGS = 8
fig, ax = plt.subplots(1,N_IMGS)
fig.set_size_inches(3 * N_IMGS, 3)

# Random dataset indices (drawn with replacement; no seed is set in this cell).
ids = np.random.randint(low=0, high=len(train_dataset), size=N_IMGS)

for i, n in enumerate(ids):
    # (3, 128, 128) channel-first tensor -> (128, 128, 3) array for imshow
    img = train_dataset[n][0].numpy().reshape(3,128,128).transpose(1, 2, 0)
    ax[i].imshow(img)
    #ax[i].set_title(f"Img #{n}  Label: {train_dataset[n][1]}")
    #ax[i].axis("off")
plt.show()

In [None]:
def save_model(model, optimizer, epoch, stats, exp_no = 7):
    """Save a training checkpoint for the given epoch.

    The checkpoint is written to
    ``experiments/experiment_<exp_no>/models/checkpoint_epoch_<epoch>.pth``
    relative to the current working directory. The directory is created if it
    does not exist yet.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch index; stored in the checkpoint and used in the file name.
        stats: Training statistics, stored unchanged under the ``'stats'`` key.
        exp_no: Experiment number that selects the output directory.

    Returns:
        None
    """
    model_dir = os.path.join("experiments", f"experiment_{exp_no}", "models")
    # exist_ok=True replaces the separate exists() check and its
    # check-then-create race.
    os.makedirs(model_dir, exist_ok=True)
    savepath = os.path.join(model_dir, f"checkpoint_epoch_{epoch}.pth")

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'stats': stats
    }, savepath)

    return


def load_model(model, optimizer, savepath, map_location=None):
    """Restore model and optimizer state from a checkpoint file.

    The checkpoint is expected to be a dict with the keys
    ``'model_state_dict'``, ``'optimizer_state_dict'``, ``'epoch'`` and
    ``'stats'``.

    Args:
        model: Module that receives the stored model state (modified in place).
        optimizer: Optimizer that receives the stored optimizer state
            (modified in place).
        savepath: Path of the checkpoint file.
        map_location: Optional device remapping forwarded to ``torch.load``,
            e.g. ``"cpu"`` to open a checkpoint written on a GPU machine on a
            CPU-only machine. ``None`` keeps ``torch.load``'s default.

    Returns:
        Tuple ``(model, optimizer, epoch, stats)``.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(savepath, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint["epoch"]
    stats = checkpoint["stats"]

    return model, optimizer, epoch, stats

In [None]:
def train_epoch(model, train_loader, optimizer, criterion, epoch, device):
    """Train the model for one pass over ``train_loader``.

    Args:
        model: Module called as ``model(images)``; must return
            ``(reconstruction, quantization_loss)``.
        train_loader: Iterable with a length that yields ``(images, _)``
            batches; the second element is ignored.
        optimizer: Optimizer updating the model parameters.
        criterion: Callable ``criterion(recons, quant_loss, images)`` whose
            first return value is the scalar loss to backpropagate.
        epoch: Zero-based epoch index, only used for the progress bar text.
        device: Device the image batches are moved to.

    Returns:
        Tuple ``(mean_loss, loss_list)``: the mean of the per-batch losses as
        a plain Python float, and the list of per-batch losses.
    """
    loss_list = []

    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, (images, _) in progress_bar:
        images = images.to(device)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass
        recons, quant_loss = model(images)

        # Only the total loss is needed for training; the criterion's second
        # return value is not used here.
        loss, _unused = criterion(recons, quant_loss, images)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # .item() synchronises with the device, so read the value only once
        batch_loss = loss.item()
        loss_list.append(batch_loss)
        progress_bar.set_description(f"Epoch {epoch+1} Iter {i+1}: loss {batch_loss:.5f}. ")

    # Built-in float rather than a NumPy scalar, so statistics that end up in
    # checkpoints consist of plain Python types only.
    mean_loss = float(np.mean(loss_list))

    return mean_loss, loss_list


@torch.no_grad()
def eval_model(model, eval_loader, criterion, device, epoch=None, savefig=False, savepath="", writer=None):
    """Evaluate the model on ``eval_loader`` (validation or test) without gradients.

    Args:
        model: Module called as ``model(images)``; must return
            ``(reconstruction, quantization_loss)``.
        eval_loader: Iterable yielding ``(images, _)`` batches; the second
            element is ignored.
        criterion: Callable ``criterion(recons, quant_loss, images)`` returning
            two scalar tensors ``(loss, mse)``.
        device: Device the image batches are moved to.
        epoch: Value used in the file name of the saved reconstruction image.
        savefig: If True, the first (up to) 64 reconstructions of the first
            batch are written to ``<savepath>/recons<epoch>.png``.
        savepath: Directory for the reconstruction image; created if missing.
        writer: Accepted for interface compatibility; not used by this function.

    Returns:
        Tuple ``(loss, recons_loss)`` of plain Python floats: the mean over all
        batches of the criterion's first and second return value respectively.
    """
    loss_list = []
    recons_loss = []

    for i, (images, _) in enumerate(eval_loader):
        images = images.to(device)

        # Forward pass
        recons, quant_loss = model(images)

        loss, mse = criterion(recons, quant_loss, images)
        loss_list.append(loss.item())
        recons_loss.append(mse.item())

        if i == 0 and savefig:
            # save_image does not create directories; make sure the target
            # exists. An empty savepath means "current directory".
            if savepath:
                os.makedirs(savepath, exist_ok=True)
            save_image(recons[:64].cpu(), os.path.join(savepath, f"recons{epoch}.png"))

    # Average the per-batch values; convert to built-in floats so callers can
    # store them in checkpoints as plain Python types.
    loss = float(np.mean(loss_list))
    recons_loss = float(np.mean(recons_loss))

    return loss, recons_loss


def train_model(model, optimizer, scheduler, criterion, train_loader, valid_loader,
                num_epochs, savepath, writer, save_frequency=2, device=None):
    """Train a model for a given number of epochs.

    Each epoch first evaluates the model on ``valid_loader`` and then trains
    it for one pass over ``train_loader``. The scheduler is stepped with the
    validation loss measured at the start of the epoch.

    Args:
        model: Module to train.
        optimizer: Optimizer updating the model parameters.
        scheduler: Scheduler stepped as ``scheduler.step(validation_loss)``.
        criterion: Loss callable forwarded to ``train_epoch`` / ``eval_model``.
        train_loader: Loader for the training data.
        valid_loader: Loader for the validation data.
        num_epochs: Number of epochs to run.
        savepath: Directory handed to ``eval_model`` for reconstruction images.
        writer: Forwarded to ``eval_model``.
        save_frequency: Checkpoints, images and log lines are produced every
            ``save_frequency`` epochs and additionally on the last epoch.
        device: Device the batches are moved to. When ``None`` it is taken
            from the model's first parameter, so the function does not depend
            on a notebook-level ``device`` variable.

    Returns:
        Tuple ``(train_loss, val_loss, loss_iters, val_loss_recons)``:
        per-epoch mean training loss, per-epoch validation loss, the
        concatenated per-iteration training losses, and the per-epoch second
        value returned by ``eval_model``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    train_loss = []
    val_loss = []
    val_loss_recons = []
    loss_iters = []

    for epoch in range(num_epochs):
        log_epoch = (epoch % save_frequency == 0 or epoch == num_epochs - 1)

        # validation epoch
        model.eval()  # important for dropout and batch norms
        loss, recons_loss = eval_model(
                model=model, eval_loader=valid_loader, criterion=criterion,
                device=device, epoch=epoch, savefig=log_epoch, savepath=savepath,
                writer=writer
            )
        val_loss.append(loss)
        val_loss_recons.append(recons_loss)

        # training epoch
        model.train()  # important for dropout and batch norms
        mean_loss, cur_loss_iters = train_epoch(
                model=model, train_loader=train_loader, optimizer=optimizer,
                criterion=criterion, epoch=epoch, device=device
            )

        # PLATEAU SCHEDULER
        scheduler.step(val_loss[-1])
        train_loss.append(mean_loss)
        # extend in place instead of rebuilding the whole list every epoch
        loss_iters.extend(cur_loss_iters)

        if log_epoch:
            # Checkpoint on the same epochs that are logged. This includes the
            # final epoch, so the fully trained weights are always written even
            # when num_epochs - 1 is not a multiple of save_frequency.
            stats = {
                "train_loss": train_loss,
                "valid_loss": val_loss,
                "loss_iters": loss_iters
            }
            save_model(model=model, optimizer=optimizer, epoch=epoch, stats=stats)

            print(f"    Train loss: {round(mean_loss, 5)}")
            print(f"    Valid loss: {round(loss, 5)}")
            print(f"       Valid loss recons: {round(val_loss_recons[-1], 5)}")

    print(f"Training completed")
    return train_loss, val_loss, loss_iters, val_loss_recons

Dropout + additional conv layers led to brown pictures, but with more details -> overall this did not improve the results. When only using kernel size = 2 we saw very blocky pictures, so we opted for kernel size 3 with a smaller stride. The results were as expected: we lost some of the blockiness. Increasing the number of encoder linear layers did not help (it led to brownness).
After fixing this with additional conv operations at the end, we still had the problem that the VAE was overfitting too much to the input vectors.

In [None]:
# Code was taken from: https://colab.research.google.com/github/zalandoresearch/pytorch-vq-vae/blob/master/vq-vae.ipynb
# Note, that the commitment_cost is the beta parameter from the paper

class VectorQuantizer(nn.Module):
    """Vector-quantization layer with a learnable codebook.

    Every spatial position of a BCHW input is replaced by its nearest codebook
    row (squared Euclidean distance over the channel axis).

    Args:
        num_embeddings: Number of codebook rows.
        embedding_dim: Length of each codebook row; must equal the number of
            input channels.
        commitment_cost: Weight of the term that pulls the inputs towards
            their selected codebook rows.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        # Codebook, initialised uniformly in [-1/num_embeddings, 1/num_embeddings]
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        bound = 1/self._num_embeddings
        self._embedding.weight.data.uniform_(-bound, bound)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        """Quantize ``inputs`` (BCHW).

        Returns:
            Tuple ``(loss, quantized, perplexity, encodings)``: the scalar
            quantization loss, the quantized tensor in BCHW layout, the
            perplexity of the code usage in this batch, and the one-hot code
            assignments of shape ``(B*H*W, num_embeddings)``.
        """
        # Move channels last (BCHW -> BHWC) so that each position is one row
        channels_last = inputs.permute(0, 2, 3, 1).contiguous()
        original_shape = channels_last.shape
        codebook = self._embedding.weight

        rows = channels_last.view(-1, self._embedding_dim)

        # Squared distances via ||x||^2 + ||e||^2 - 2 * x.e
        row_sq = rows.pow(2).sum(dim=1, keepdim=True)
        code_sq = codebook.pow(2).sum(dim=1)
        distances = row_sq + code_sq - 2 * torch.matmul(rows, codebook.t())

        # One-hot encoding of the nearest code per row
        nearest = distances.argmin(dim=1).unsqueeze(1)
        encodings = torch.zeros(nearest.shape[0], self._num_embeddings, device=channels_last.device)
        encodings.scatter_(1, nearest, 1)

        # Look up the selected codes and restore the BHWC shape
        quantized = torch.matmul(encodings, codebook).view(original_shape)

        # First term only sends gradients to the inputs, second only to the codebook
        e_latent_loss = F.mse_loss(quantized.detach(), channels_last)
        q_latent_loss = F.mse_loss(quantized, channels_last.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through: forward value is the code, gradient flows to the inputs
        quantized = channels_last + (quantized - channels_last).detach()

        usage = encodings.mean(dim=0)
        perplexity = torch.exp(-(usage * torch.log(usage + 1e-10)).sum())

        # Back to BCHW for the layers that follow
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings


# Hyperparameters for a quick shape check of the quantizer.
embedding_dim = 64
num_embeddings = 512

commitment_cost = 0.25

VQ = VectorQuantizer(embedding_dim=embedding_dim, num_embeddings=num_embeddings, commitment_cost=commitment_cost)

# Dummy input: batch of 64, 64 channels (= embedding_dim), 8x8 spatial grid.
inp = torch.zeros(64, 64, 8, 8)

# Only the quantized tensor is inspected here; the other outputs are discarded.
_, out, _, _= VQ(inp)
print(out.shape)

In [None]:
# NOTE: disabled first draft of the VQ-VAE, kept inside a string literal so it
# is never executed. As written it would not run: forward() reads
# self.pre_quant_conv, self.embedding, self.beta and self.post_quant_conv,
# none of which are created in __init__.
'''# Code was heavily inspired by Explaining AI: https://www.youtube.com/watch?v=1ZHzAOutcnw

class VQVAE(nn.Module):
    def __init__(self):
        super(VQVAE, self).__init__()
        
        # Define Convolutional Encoders
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels = 3, out_channels = 16, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
        )

        VQ = VectorQuantizer(embedding_dim=64, num_embeddings=512, commitment_cost=0.25)

        # Define decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels = 64, out_channels = 32, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(), 
            nn.ConvTranspose2d(in_channels = 32, out_channels = 16, kernel_size = 3, stride = 1, padding = 1),
            nn.Tanh(),
            nn.ConvTranspose2d(in_channels = 16, out_channels = 3, kernel_size = 3, stride = 1, padding = 1),
        )
    
    def forward(self, x):
        encoded_output = self.encoder(x)
        quant_input = self.pre_quant_conv(encoded_output)

        ## Quantization

        B, C, H, W = quant_input.shape
        quant_input = quant_input.permute(0, 2, 3, 1)
        # Shape: B, Pixels, C
        quant_input = quant_input.reshape(quant_input.size(0), -1, quant_input.size(-1))

        # Reshape for torch.cdist
        weights = self.embedding.weight[None, :].repeat(quant_input.size(0), 1, 1).permute(0, 2, 1)
        # Compute pairwise distances
        # print(quant_input.shape)
        # dist = torch.cdist(quant_input, self.embedding.weight[None, :].repeat(quant_input.size(0), 1, 1))
        dist = torch.cdist(quant_input, weights)

        # print(dist.shape)
        # Find index of nearest embedding
        min_encoding_indices = torch.argmin(dist, dim=-1)

        # Select the embedding weights
        print(quant_input.shape)
        print(min_encoding_indices.shape)
        print(self.embedding.weight.shape)
        quant_out = torch.index_select(self.embedding.weight.repeat(quant_input.size(0), 1, 1), 2, min_encoding_indices.view(-1))
        print(quant_out.shape)
        quant_input = quant_input.reshape((-1, quant_input.size(-1)))

        # Compute losses (mean or some does not change the gradient only the gradient length)
        commitment_loss = torch.mean((quant_out.detach() - quant_input)**2)
        codebook_loss = torch.mean((quant_out - quant_input.detach()**2))
        quantize_losses = codebook_loss + self.beta*commitment_loss

        # Reshape the output
        quant_out = quant_input + (quant_out - quant_input).detach()
        quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
        min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-1), quant_out))

        # Decoding

        decoder_input = self.post_quant_conv(quant_out)
        output = self.decoder(decoder_input)

        return output, quantize_losses
'''

In [None]:
class VQVAE(nn.Module):
    """Convolutional VQ-VAE: strided-conv encoder, vector quantizer, transposed-conv decoder.

    The encoder maps 3-channel images to 64-channel feature maps, the
    quantizer snaps every spatial position to one of 512 codebook vectors of
    length 64, and the decoder maps the quantized features back to 3 channels.
    """

    def __init__(self):
        super().__init__()

        # Encoder: three stride-2 convolutions (kernel 4, padding 1), each
        # halving an even spatial size; channels 3 -> 16 -> 32 -> 64.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )

        # Codebook of 512 vectors whose length matches the encoder's 64 channels
        self.VQ = VectorQuantizer(num_embeddings=512, embedding_dim=64, commitment_cost=0.25)

        # Decoder: mirrors the encoder with transposed convolutions
        # (64 -> 32 -> 16 -> 3); no activation after the last layer.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
            nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        """Return ``(reconstruction, quantization_loss)`` for an image batch ``x``."""
        features = self.encoder(x)
        # perplexity and one-hot encodings are produced by the quantizer but not used here
        q_loss, quantized, _ppl, _encodings = self.VQ(features)
        reconstruction = self.decoder(quantized)

        return reconstruction, q_loss



In [None]:
def vqvae_loss_function(recons, quantize_losses, target):
    """Compute the VQ-VAE training loss.

    Args:
        recons: Reconstructed images produced by the model.
        quantize_losses: Scalar quantization loss returned by the model.
        target: Original images, same shape as ``recons``.

    Returns:
        Tuple ``(loss, recons_loss)``: the total loss (mean-squared
        reconstruction error plus the quantization loss) and the reconstruction
        error on its own, so callers can log it separately.
    """
    recons_loss = F.mse_loss(recons, target)

    loss = recons_loss + quantize_losses

    # Return the reconstruction term as the second value. Callers already hold
    # quantize_losses, and they log this value as the reconstruction loss.
    return loss, recons_loss

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vqvae = VQVAE()
criterion = vqvae_loss_function
optimizer = torch.optim.Adam(vqvae.parameters(), lr=3e-4)
# Plateau scheduler: configured with factor 0.5 and patience 3; it is stepped
# with the validation loss inside train_model.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 3, factor = 0.5, verbose = True)
vqvae = vqvae.to(device)

In [None]:
# Directory for the reconstruction images written during validation. It is
# relative to the working directory (the same "experiments/experiment_7" tree
# that save_model writes its checkpoints into) instead of a user-specific
# absolute path, and it is created up front so the first image can be saved.
savepath = os.path.join("experiments", "experiment_7")
os.makedirs(savepath, exist_ok=True)


In [None]:
# Train the VQ-VAE for 20 epochs; reconstruction images go to `savepath`.
# No summary writer is used (writer=None).
train_loss, val_loss, loss_iters, val_loss_recons = train_model(
        model=vqvae, optimizer=optimizer, scheduler=scheduler, criterion=vqvae_loss_function,
        train_loader=trainloader, valid_loader=validloader, num_epochs=20, savepath=savepath, writer=None)

In [None]:
# Decode randomly drawn codebook entries with the trained `vqvae`.
# The decoder of this model does not take a flat Gaussian vector: it expects a
# (B, 64, 16, 16) grid of codebook vectors (128x128 inputs are downsampled by
# three stride-2 convolutions). We therefore draw code indices uniformly at
# random. No prior over the codes has been learned, so these samples are not
# expected to look like faces.
with torch.no_grad():
    codebook = vqvae.VQ._embedding.weight                      # (num_codes, code_dim)
    code_ids = torch.randint(0, codebook.shape[0], (64, 16, 16), device=codebook.device)
    z = codebook[code_ids].permute(0, 3, 1, 2).contiguous()    # (64, code_dim, 16, 16)
    sample = vqvae.decoder(z)


recons = sample.view(64, 3, 128, 128).cpu()

recons.shape

In [None]:
# Shape of the decoded samples (repeats the check at the end of the previous cell).
recons.shape

In [None]:
# Show the first 10 decoded samples side by side.
# figsize is given in inches: (20, 2) gives roughly square 2-inch tiles, whereas
# the previous (128, 128) requested a 128x128-inch canvas.
fig, axes = plt.subplots(1, 10, figsize=(20, 2))

for i in range(10):
    # CHW -> HWC for imshow. Clamp to [0, 1] because the decoder has no output
    # activation, so its values are not bounded to the valid image range.
    img = recons[i].clamp(0, 1).permute(1, 2, 0).numpy()
    axes[i].imshow(img)
    axes[i].axis('off')  # Turn off axis labels for clarity

plt.tight_layout()
plt.show()

In [None]:
# NOTE(review): `test_data` is first assigned in the following cell and `i` is
# left over from an earlier loop, so this line raises NameError on a fresh
# "Restart & Run All" - confirm the intended cell order or remove this cell.
test_data[i][0].shape

In [None]:
# Compare test images (top row) with their VQ-VAE reconstructions (bottom row).
test_data, labels = next(iter(testloader))
test_data = test_data.to(device)

with torch.no_grad():
    # The model trained in this notebook is `vqvae`; the name `cvae` used here
    # before is not defined anywhere in the notebook.
    sample, _ = vqvae(test_data)

    # The output already has shape (B, 3, 128, 128). Not forcing a
    # view(batch_size, ...) keeps this working for a batch smaller than batch_size.
    recons = sample.cpu()

    # Never index past the end of a small batch
    n_show = min(10, recons.shape[0])

    # figsize is in inches; squeeze=False keeps `axes` two-dimensional even for one column
    fig, axes = plt.subplots(2, n_show, figsize=(2 * n_show, 4), squeeze=False)

    for i in range(n_show):
        # CHW -> HWC for imshow; the reconstruction is clamped because the
        # decoder output is not bounded to [0, 1]
        img = recons[i].clamp(0, 1).permute(1, 2, 0).numpy()
        test_img = test_data[i].cpu().permute(1, 2, 0).numpy()
        axes[0][i].imshow(test_img)
        axes[0][i].axis('off')
        axes[1][i].imshow(img)
        axes[1][i].axis('off')  # Turn off axis labels for clarity

    plt.tight_layout()
    plt.show()

# Frechet Inception Distance