In [1]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from torch import nn
import torch.nn.functional as F
from PIL import Image, ImageEnhance
import plotly.express as px
import plotly.graph_objects as go
import os
import numpy as np

In [2]:
# Select the compute device once: GPU when CUDA is usable, otherwise CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [3]:
class CustomFaceDataset(torch.utils.data.Dataset):
    """Unlabelled image dataset backed by a directory of numbered ``.jpg`` files.

    Item ``index`` is read from the file named ``index + shift`` zero-padded to
    five digits, e.g. ``00042.jpg``. Every item is returned with a dummy label 0.

    Args:
        path: Directory that holds the images (with or without a trailing slash).
        transform: Optional callable applied to the loaded PIL image.
        shift: Offset added to every index before the file name is built;
            ``None`` means no offset.
    """

    def __init__(self, path, transform=None, shift=None):
        self.path = path
        self.transform = transform
        self.shift = 0 if shift is None else shift
        # Count only files that __getitem__ can actually address: correctly
        # named images whose number is at or beyond the shift. Counting every
        # directory entry made the last `shift` indices point at missing files
        # (and also counted any stray non-image file).
        self.size = sum(1 for name in os.listdir(path) if self._is_addressable(name))

    @staticmethod
    def _file_name(number):
        """Return the file name used for image ``number`` (zero-padded to 5 digits)."""
        return f'{number}'.rjust(5, '0') + '.jpg'

    def _is_addressable(self, name):
        """Tell whether directory entry ``name`` is an image reachable through an index."""
        stem, ext = os.path.splitext(name)
        if ext != '.jpg' or not (stem.isascii() and stem.isdigit()):
            return False
        number = int(stem)
        return number >= self.shift and self._file_name(number) == name

    def __getitem__(self, index): # without labels
        """Return ``(image, 0)`` for ``index``; raise IndexError when out of range."""
        if not 0 <= index < self.size:
            # IndexError is what the sequence protocol expects, and it is clearer
            # than a FileNotFoundError for a file that was never supposed to exist.
            raise IndexError(f'index {index} is out of range for a dataset of size {self.size}')
        image_path = os.path.join(self.path, self._file_name(index + self.shift))
        # Decode inside a context manager so the file handle is closed right away,
        # and force RGB so every item has the same number of channels.
        with Image.open(image_path) as img:
            x = img.convert('RGB')
        if self.transform is not None:
            x = self.transform(x)
        return x, 0

    def __len__(self):
        """Number of addressable images."""
        return self.size

In [4]:
# Shared preprocessing pipeline: Resize(128) followed by conversion to a tensor.
transform = transforms.Compose([transforms.Resize(128), transforms.ToTensor()])

# Training images: file numbers are offset by 1000 inside the dataset.
train_face_dataset = CustomFaceDataset(path='../celeba_hq_256/', shift=1000, transform=transform)

# Validation images: same preprocessing, no offset.
val_face_dataset = CustomFaceDataset(path='../data/', transform=transform)

In [5]:
# Sanity check: turn validation item 228 back into a PIL image and display it.
to_pil = transforms.ToPILImage()
sample_tensor = val_face_dataset[228][0]
px.imshow(to_pil(sample_tensor))

In [6]:
# Throwaway convolutional stack with the same layer sizes as the encoder's,
# used only to read off the output shape for one validation sample.
conv = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=8, kernel_size=16),
    nn.ReLU(inplace=True),
    nn.Conv2d(in_channels=8, out_channels=16, kernel_size=8),
    nn.BatchNorm2d(num_features=16),
    nn.ReLU(inplace=True),
    nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4),
    nn.ReLU(inplace=True),
)

probe_batch = val_face_dataset[228][0].unsqueeze(0)  # add a leading batch axis
conv(probe_batch).shape

torch.Size([1, 32, 103, 103])

In [7]:
# Batches of 16 for both splits. The training loader reshuffles every epoch so
# the optimiser does not see the images in the same file-number order each time;
# the validation loader keeps a fixed order so its results are repeatable.
train_face_dataloader = DataLoader(train_face_dataset, batch_size=16, shuffle=True)

val_face_dataloader = DataLoader(val_face_dataset, batch_size=16, shuffle=False)

In [8]:
# Side length of the last convolutional feature map for a 128-pixel input:
# each unpadded, stride-1 convolution (kernels 16, 8 and 4) trims kernel - 1 pixels.
final_conv_size = 128 - (16 - 1) - (8 - 1) - (4 - 1)

In [9]:
class Encoder(nn.Module):
    """Convolutional VAE encoder that returns a sampled latent code.

    Pipeline: three unpadded convolutions -> flatten -> 128-unit hidden layer ->
    two linear heads (mean and log-variance) -> reparameterised sample.

    Args:
        encoded_space_dim: Size of the latent vector.
        conv_out_size: Height/width of the feature map produced by ``self.conv``
            for the images that will be fed in. When omitted, the module-level
            ``final_conv_size`` is used, as before.

    After every forward pass ``self.kl`` holds the KL divergence between the
    predicted Gaussian and the standard normal, summed over batch and latent
    dimensions.
    """

    def __init__(self, encoded_space_dim, conv_out_size=None):
        super().__init__()
        if conv_out_size is None:
            conv_out_size = final_conv_size

        self.conv = nn.Sequential(
            nn.Conv2d(3, 8, 16),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 8),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 4),
            nn.ReLU(True)
        )
        self.flatten = nn.Flatten()

        self.linear = nn.Sequential(
            nn.Linear(32 * conv_out_size ** 2, 128),
            nn.ReLU(True)
        )

        self.linear_mu = nn.Sequential(
            nn.Linear(128, encoded_space_dim)
        )

        # This head predicts the log-variance (forward exponentiates it), so it
        # must be unconstrained. A trailing ReLU clamped it to >= 0, which forced
        # every posterior variance to be >= 1. The attribute name and Sequential
        # wrapper are kept so parameter names in saved state dicts do not change.
        self.linear_sigma = nn.Sequential(
            nn.Linear(128, encoded_space_dim)
        )

        # No longer used by forward; retained so code reading `encoder.N` still works.
        self.N = torch.distributions.Normal(0, 1)
        self.kl = 0 # kld

    def forward(self, x):
        """Encode image batch ``x`` and return one latent sample per image."""
        x = self.conv(x)
        x = self.flatten(x)
        x = self.linear(x)

        mu = self.linear_mu(x)
        log_var = self.linear_sigma(x)

        # Reparameterisation: the noise is created with mu's device and dtype, so
        # the module does not rely on a notebook-level `device` variable.
        eps = torch.randn_like(mu)
        z = mu + torch.exp(log_var / 2) * eps

        self.kl = -0.5 * (1 + log_var - torch.pow(mu, 2) - torch.exp(log_var)).sum()
        return z

In [10]:
class Decoder(nn.Module):
    """Transposed-convolution decoder that maps latent vectors to 3-channel images.

    Pipeline: linear layer -> reshape to ``(32, conv_out_size, conv_out_size)`` ->
    three unpadded transposed convolutions (kernels 4, 8, 16) -> sigmoid.

    Args:
        encoded_space_dim: Size of the latent vector.
        conv_out_size: Height/width of the feature map the latent vector is
            reshaped to before the transposed convolutions. When omitted, the
            module-level ``final_conv_size`` is used, as before.
    """

    def __init__(self, encoded_space_dim, conv_out_size=None):
        super().__init__()
        if conv_out_size is None:
            conv_out_size = final_conv_size

        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 32 * conv_out_size**2),
            nn.ReLU(True)
        )

        self.unflatten = nn.Unflatten(unflattened_size=(32, conv_out_size, conv_out_size), dim=1)

        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 4),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 8),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 3, 16)
        )

    def forward(self, x):
        """Decode latent batch ``x``; the sigmoid keeps every output value in (0, 1)."""
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        x = torch.sigmoid(x)
        return x

In [11]:
class VariationalAutoencoder(nn.Module):
    """Encoder/decoder pair that share one latent size.

    Args:
        latent_dims: Latent dimensionality passed to both Encoder and Decoder.
    """

    def __init__(self, latent_dims):
        super().__init__()
        self.encoder = Encoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        """Run ``x`` through the encoder, then feed its output to the decoder."""
        return self.decoder(self.encoder(x))

In [12]:
from tqdm import tqdm

In [20]:
def train_vae(autoencoder, data, epochs=20):
    """Train a VAE with Adam on summed-MSE reconstruction loss plus KL divergence.

    Args:
        autoencoder: Model whose forward pass returns the reconstruction and whose
            ``encoder.kl`` holds the KL term of the most recent forward pass.
        data: Iterable of ``(images, _)`` batches, e.g. a DataLoader; the second
            element of each batch is ignored.
        epochs: Number of passes over ``data``.

    Returns:
        The same ``autoencoder`` object, trained in place.

    Side effects:
        Shows a plotly line chart of the mean per-sample loss for each epoch.
    """
    # Follow the model's own device instead of a notebook-level global.
    model_device = next(autoencoder.parameters()).device
    train_loss = np.zeros((epochs, 2))
    autoencoder.train()
    crit = nn.MSELoss(reduction='sum')
    opt = torch.optim.Adam(autoencoder.parameters())
    for epoch in tqdm(range(epochs)):
        dyn_loss = 0.
        n_samples = 0
        for x, _ in data:
            x = x.to(model_device)
            opt.zero_grad()
            x_hat = autoencoder(x)
            recon_loss = crit(x_hat, x)
            loss = recon_loss + autoencoder.encoder.kl
            loss.backward()
            opt.step()
            dyn_loss += loss.item()
            n_samples += x.shape[0]
        train_loss[epoch][0] = epoch + 1
        # Average over the samples actually seen this epoch. Dividing by the size
        # of one particular global dataset gave a wrong scale for any other `data`.
        train_loss[epoch][1] = dyn_loss / max(n_samples, 1)

    fig = go.Figure(data=go.Scatter(x=train_loss[:, 0], y=train_loss[:, 1]))
    fig.update_layout(
        title='VAE training loss',
        xaxis_title='epoch',
        yaxis_title='mean loss per sample (summed MSE + KL)',
    )
    fig.show()
    return autoencoder

In [79]:
# Build the VAE with a 36-dimensional latent space, then move it to the selected device.
model = VariationalAutoencoder(latent_dims=36)
model = model.to(device)

In [80]:
# Fit the model with the default number of epochs.
# NOTE(review): this trains on val_face_dataloader rather than train_face_dataloader — confirm that is intended.
model = train_vae(autoencoder=model, data=val_face_dataloader)

100%|██████████| 20/20 [01:04<00:00,  3.24s/it]


In [151]:
# Generate one image from the decoder alone and sharpen it for display.
model.eval()

# Read the latent size and device from the model instead of repeating literals.
latent_dim = model.decoder.decoder_lin[0].in_features
latent_device = next(model.parameters()).device

# Draw the latent code from the standard-normal prior that the KL term pulls the
# encoder towards; uniform values in [0, 2) only cover an all-positive corner of
# latent space. no_grad: this is inference, so no autograd graph is needed.
with torch.no_grad():
    latent = torch.randn((1, latent_dim), device=latent_device)
    generated = model.decoder(latent)[0]

img = transforms.ToPILImage()(generated)

enc = ImageEnhance.Sharpness(img)

img = enc.enhance(2.5)

px.imshow(img)