# Variational AutoEncoder

In [15]:
import os
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
import pytorch_model_summary
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Output directory for reconstructions, latent-space plots and interpolations.
# makedirs(exist_ok=True) avoids the check-then-create race of
# `if not exists: mkdir`.
os.makedirs('./VAE_img', exist_ok=True)

In [16]:
def normalization(tensor, min_value, max_value):
    """Linearly rescale *tensor* so its values span ``[min_value, max_value]``.

    The global minimum of the tensor maps to ``min_value`` and its global
    maximum maps to ``max_value``.

    A constant tensor has zero range: dividing by it would produce NaN
    (0 / 0) for every element, so such input is mapped to ``min_value``
    everywhere instead.
    """
    shifted = tensor - tensor.min()
    value_range = shifted.max()
    if value_range == 0:
        # Degenerate input (e.g. a blank image): avoid 0 / 0 -> NaN.
        return shifted + min_value
    return shifted / value_range * (max_value - min_value) + min_value

def value_round(tensor):
    """Round every element of *tensor* to the nearest integer (ties go to even)."""
    return tensor.round()

def to_img(x):
    """Reshape a batch of flattened 784-value rows into ``(N, 1, 28, 28)`` images."""
    n = x.shape[0]
    return x.view(n, 1, 28, 28)

# Preprocessing pipeline:
#   1. ToTensor            -> float image tensor
#   2. normalization       -> min-max rescale each image to [0, 1]
#   3. value_round         -> round every pixel, giving binary {0, 1} targets
#                             suitable for the BCE reconstruction loss
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda img: normalization(img, 0, 1)),
    transforms.Lambda(lambda img: value_round(img)),
])

batch_size = 1024

# Shuffled MNIST loader (downloaded into ./MNIST_dataset on first use).
dataset = MNIST('./MNIST_dataset', transform=img_transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [17]:
class VariationalAutoencoder(nn.Module):
    """Fully connected variational autoencoder for flattened images.

    The encoder maps ``input_dim`` features to one vector of size
    ``2 * latent_dim``. The first ``latent_dim`` entries are the mean of the
    approximate posterior; the remaining ``latent_dim`` entries are its
    log-variance. The decoder maps a latent sample back to ``input_dim``
    values in (0, 1) through a final sigmoid.

    Args:
        input_dim: Number of input features (default 28 * 28).
        hidden_dim: Width of the single hidden layer of encoder and decoder.
        latent_dim: Dimensionality of the latent space.

    With the defaults, the layer structure is identical to a fixed
    784-400-40 / 20-400-784 network.
    """

    def __init__(self, input_dim=28 * 28, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.latent_dim = latent_dim
        # The encoder outputs the concatenation [mu | logvar] of the latent
        # Gaussian's mean and log-variance.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 2 * latent_dim))
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid())

    def _split(self, h):
        """Split an encoder output ``[mu | logvar]`` into ``(mu, logvar)``."""
        return h[:, :self.latent_dim], h[:, self.latent_dim:]

    def reparametrization(self, mu, logvar):
        """Sample ``z = mu + std * eps`` with ``eps ~ N(0, I)`` (reparameterization trick).

        The noise is drawn with ``torch.randn_like`` so it matches the device
        and dtype of the inputs. The previous ``torch.cuda.FloatTensor`` call
        crashed on CPU and forced float32.
        """
        std = torch.exp(0.5 * logvar)  # == sqrt(exp(logvar))
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode *x*, sample a latent code, and decode it.

        Returns:
            ``(x_gen, mu, logvar)``: the reconstruction, plus the posterior
            mean and log-variance.
        """
        mu, logvar = self._split(self.encoder(x))
        z = self.reparametrization(mu, logvar)
        return self.decoder(z), mu, logvar

    def interpolation(self, x_1, x_2, alpha):
        """Decode a sample from a posterior interpolated between two inputs.

        The means and the log-variances of the two encodings are each blended
        as ``(1 - alpha) * a + alpha * b``. A latent code is then *sampled*
        from the blended distribution, so the output is stochastic.
        """
        mu_1, logvar_1 = self._split(self.encoder(x_1))
        mu_2, logvar_2 = self._split(self.encoder(x_2))
        mu = (1 - alpha) * mu_1 + alpha * mu_2
        logvar = (1 - alpha) * logvar_1 + alpha * logvar_2
        z = self.reparametrization(mu, logvar)
        return self.decoder(z)

In [18]:
# Build the VAE on the GPU and print a per-layer summary for one flattened
# 28x28 input.
model = VariationalAutoencoder().cuda()
print(
    pytorch_model_summary.summary(
        model,
        torch.zeros(1, 28 * 28).cuda(),
        show_input=True,
    )
)

-----------------------------------------------------------------------
      Layer (type)         Input Shape         Param #     Tr. Param #
          Linear-1            [1, 784]         314,000         314,000
            ReLU-2            [1, 400]               0               0
          Linear-3            [1, 400]          16,040          16,040
          Linear-4             [1, 20]           8,400           8,400
            ReLU-5            [1, 400]               0               0
          Linear-6            [1, 400]         314,384         314,384
         Sigmoid-7            [1, 784]               0               0
Total params: 652,824
Trainable params: 652,824
Non-trainable params: 0
-----------------------------------------------------------------------


In [19]:
def visualize_latent_space(vae, dataloader, method='PCA', dimension=2, epoch=0):
    '''
    Project the VAE's latent means to 2-D or 3-D and save a scatter plot.

    Every batch from ``dataloader`` is flattened and passed through
    ``vae.encoder``. The first half of each encoder output row (the mean
    vector) is collected, reduced with PCA or t-SNE, and plotted coloured by
    label. The figure is written to
    ``./VAE_img/latent_space_{method}_{2D|3D}_epoch_{epoch}.png``.

    Args:
        vae: Model with an ``encoder`` whose output is ``[mu | logvar]``.
        dataloader: Iterable of ``(images, labels)`` batches.
        method: 'PCA' or 't-SNE'.
        dimension: 2 or 3 (dimensionality of the visualization).
        epoch: Used only in the plot title and output filename.

    Raises:
        ValueError: If ``method`` or ``dimension`` is unsupported. Both are
            checked before any data is encoded, so a typo fails immediately
            instead of after a full pass over the dataset.
    '''
    if method not in ('PCA', 't-SNE'):
        raise ValueError("Invalid method. Use 'PCA' or 't-SNE'.")
    if dimension not in (2, 3):
        raise ValueError("Invalid dimension. Use 2 or 3.")

    latent_vectors, labels = _collect_latent_means(vae, dataloader)

    # Dimensionality reduction.
    if method == 'PCA':
        reducer = PCA(n_components=dimension)
    else:
        reducer = TSNE(n_components=dimension, perplexity=30, random_state=777)
    latent_reduced = reducer.fit_transform(latent_vectors)

    _plot_latent(latent_reduced, labels, method, dimension, epoch)


def _collect_latent_means(vae, dataloader):
    """Encode every batch; return (latent means, labels) as NumPy arrays."""
    # Run on whatever device the model's parameters live on.
    device = next(vae.parameters()).device
    means, labels = [], []
    with torch.no_grad():
        for img, label in dataloader:
            h = vae.encoder(img.view(img.size(0), -1).to(device))
            # The encoder output is [mu | logvar]; keep only the mean half.
            means.append(h[:, :h.size(1) // 2].cpu())
            labels.append(label)
    return torch.cat(means).numpy(), torch.cat(labels).numpy()


def _plot_latent(points, labels, method, dimension, epoch):
    """Scatter-plot reduced latent points coloured by label and save under ./VAE_img."""
    title = f'Latent Space Visualization ({method}) - Epoch {epoch}'
    if dimension == 2:
        fig = plt.figure(figsize=(8, 6))
        scatter = plt.scatter(points[:, 0], points[:, 1], c=labels, cmap='viridis', s=5)
        plt.colorbar(scatter)
        plt.title(title)
        plt.xlabel('Component 1')
        plt.ylabel('Component 2')
        path = f'./VAE_img/latent_space_{method}_2D_epoch_{epoch}.png'
    else:
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')
        scatter = ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=labels, cmap='viridis', s=5)
        fig.colorbar(scatter, ax=ax, pad=0.1)
        ax.set_title(title)
        ax.set_xlabel('Component 1')
        ax.set_ylabel('Component 2')
        ax.set_zlabel('Component 3')
        path = f'./VAE_img/latent_space_{method}_3D_epoch_{epoch}.png'
    try:
        fig.savefig(path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

In [None]:
# Training hyper-parameters.
num_epochs = 50
learning_rate = 3e-4

# Reconstruction loss (BCELoss needs predictions in [0, 1]; the decoder ends
# in a Sigmoid) and the optimiser over all model parameters.
BCE = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

def evaluate_reconstruction(vae, dataloader):
    """Return the mean per-pixel MSE between inputs and their VAE reconstructions.

    Each batch is flattened to ``(batch, -1)``, moved to the device of
    ``vae``'s parameters, and passed through ``vae``. The forward pass
    samples a latent code, so the result is stochastic. Per-batch mean MSEs
    are weighted by batch size, giving the mean over every sample actually
    seen.

    Args:
        vae: Module returning ``(x_gen, mu, logvar)`` for a flattened batch.
        dataloader: Iterable of ``(images, labels)`` batches.

    Returns:
        The average reconstruction MSE as a float, or ``nan`` if the
        dataloader yields no samples.
    """
    criterion = nn.MSELoss()
    device = next(vae.parameters()).device
    total_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        for img, _ in dataloader:
            img = img.view(img.size(0), -1).to(device)
            x_gen, _, _ = vae(img)
            total_loss += criterion(x_gen, img).item() * img.size(0)
            # Count the samples actually seen rather than using
            # len(dataloader.dataset); the two differ with drop_last=True.
            num_samples += img.size(0)

    if num_samples == 0:
        return float('nan')
    return total_loss / num_samples

def _kl_divergence(mu, logvar, num_elements):
    """Return KL(N(mu, exp(logvar)) || N(0, I)), summed over the batch and divided by *num_elements*.

    Dividing by the number of input pixels in the batch puts the KL term on
    the same per-element scale as ``nn.BCELoss`` with its default mean
    reduction.
    """
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / num_elements


# Train on whichever device the model was placed on.
device = next(model.parameters()).device

for epoch in range(num_epochs):
    epoch_loss = 0.0  # sample-weighted sum of batch losses (kept on device)
    num_seen = 0
    for img, _ in dataloader:
        img = img.view(img.size(0), -1).to(device)
        x_gen, mu, logvar = model(img)

        # KL divergence, normalised by the pixels actually in this batch.
        # The final batch can be smaller than batch_size when the dataset
        # size is not a multiple of it.
        KLD = _kl_divergence(mu, logvar, img.numel())

        # Total loss: per-pixel reconstruction BCE plus the KL term.
        loss = BCE(x_gen, img) + KLD

        # Optimisation step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.detach() * img.size(0)
        num_seen += img.size(0)

    if epoch % 10 == 0 or (epoch + 1) == num_epochs:
        # Report the mean loss over the whole epoch, not just the last batch.
        avg_loss = float(epoch_loss) / max(num_seen, 1)
        print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, avg_loss))

        # Save the last batch's inputs and their reconstructions.
        x_gt = to_img(img.detach().cpu())
        x_gen = to_img(x_gen.detach().cpu())
        save_image(x_gt, './VAE_img/ground_truth_{}.png'.format(epoch))
        save_image(x_gen, './VAE_img/generated_x{}.png'.format(epoch))

        # Latent-space visualisation.
        visualize_latent_space(model, dataloader, method='t-SNE', dimension=2, epoch=epoch)

        # Reconstruction quality.
        recon_loss = evaluate_reconstruction(model, dataloader)
        print(f'Epoch {epoch + 1}/{num_epochs}, Reconstruction Loss: {recon_loss:.4f}')

        # Interpolate between the first two images of a fresh batch.
        # linspace includes alpha = 1.0, so the last row uses x_2's posterior.
        # A float-step arange(0, 1, 0.1) stops at 0.9, and its element count
        # is subject to floating-point rounding.
        with torch.no_grad():
            sample_batch, _ = next(iter(dataloader))
            sample_batch = sample_batch.view(sample_batch.size(0), -1).to(device)
            x_1, x_2 = sample_batch[0:1], sample_batch[1:2]
            generated_images = torch.cat(
                [model.interpolation(x_1, x_2, alpha)
                 for alpha in torch.linspace(0.0, 1.0, 11)],
                0,
            ).cpu()
        save_image(generated_images.view(-1, 1, 28, 28),
                   './VAE_img/interpolation_{}.png'.format(epoch),
                   nrow=1)

# Persist the trained weights (state_dict only, not the module object).
torch.save(obj=model.state_dict(), f='./variational_autoencoder.pth')