In [1]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from torch import nn
import torch.nn.functional as F
from PIL import Image, ImageEnhance
import plotly.express as px
import plotly.graph_objects as go
import os
import numpy as np

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
class CustomFaceDataset(torch.utils.data.Dataset):
    def __init__(self, path, transform=None, shift = None):
        self.path = path
        self.transform = transform
        self.size = len(os.listdir(path))
        if shift is not None:
            self.shift = shift
        else:
            self.shift = 0
        
    def __getitem__(self, index): # without labels
        image_path = self.path + f'{index + self.shift}'.rjust(5, '0') + '.jpg'
        x = Image.open(image_path)
        if self.transform is not None:
            x = self.transform(x)
        return x, 0
    
    def __len__(self):
        return self.size

In [4]:
transform = transforms.Compose([
        transforms.Resize(128),
        transforms.ToTensor()
    ])

train_face_dataset = CustomFaceDataset('../celeba_hq_256/', transform=transform, shift=1000)

val_face_dataset = CustomFaceDataset('../data/', transform=transform)

In [5]:
px.imshow(transforms.ToPILImage()(val_face_dataset[228][0]))

In [7]:
train_face_dataloader = DataLoader(train_face_dataset, batch_size=16, shuffle=False)

val_face_dataloader = DataLoader(val_face_dataset, batch_size=16, shuffle=False)

In [8]:
final_conv_size = 83
kernel_size = 16

In [9]:
class Encoder(nn.Module):
    
    def __init__(self, encoded_space_dim):
        super().__init__()
        
        self.conv = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size),
            nn.ReLU(),
            nn.BatchNorm2d(8),
            nn.Conv2d(8, 16, kernel_size),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 32, kernel_size),
            nn.ReLU(),
        )
        self.flatten = nn.Flatten()

        self.linear = nn.Sequential(
            nn.Linear(32 * final_conv_size ** 2, 128),
            nn.ReLU()
        )

        self.linear_mu = nn.Sequential(
            nn.Linear(128, encoded_space_dim),
            nn.Dropout(0.25)
        )

        self.linear_sigma = nn.Sequential(
            nn.Linear(128, encoded_space_dim),
            nn.Dropout(0.25)
        )
        
        self.N = torch.distributions.Normal(0, 1)

    def forward(self, x):
        x = self.conv(x)
        x = self.flatten(x)
        x = self.linear(x)
        
        mu = self.linear_mu(x)
        sigma = self.linear_sigma(x)

        N = self.N.sample(mu.shape).to(device)

        z = mu + torch.exp(sigma / 2)*N
        return z, mu, sigma

In [10]:
class Decoder(nn.Module):
    
    def __init__(self, encoded_space_dim):
        super().__init__()
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 32 * final_conv_size**2),
            nn.ReLU()
        )

        self.unflatten = nn.Unflatten(unflattened_size=(32, final_conv_size, final_conv_size), dim=1)

        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 8, kernel_size),
            nn.ReLU(),
            nn.BatchNorm2d(8),
            nn.ConvTranspose2d(8, 3, kernel_size),
            nn.ReLU()
        )
        
    def forward(self, x):
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        x = torch.sigmoid(x)
        return x

In [11]:
class VariationalAutoencoder(nn.Module):
    def __init__(self, latent_dims):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        z, mu, sigma = self.encoder(x)
        return self.decoder(z), mu, sigma

In [12]:
from tqdm import tqdm

In [13]:
model = VariationalAutoencoder(36).to(device)

In [14]:
def Loss(logit, data, mu, sigma):
    mse = nn.MSELoss(reduction='sum')
    return mse(logit, data) + -0.5 * (1 + sigma - torch.pow(mu, 2) - torch.exp(sigma)).sum()
criterion = Loss
optimizer = torch.optim.Adam(model.parameters(),lr=0.0001)

In [15]:
def fit(model, dataloader):
    model.train()
    running_loss = 0.0
    for data, _ in dataloader:
        data = data.to(device)
        optimizer.zero_grad()
        reconstruction, mu, sigma = model(data)
        loss = Loss(reconstruction, data, mu, sigma)
        loss.backward()
        running_loss += loss.item()
        optimizer.step()
    train_loss = running_loss/len(dataloader.dataset)
    return train_loss

In [16]:
def validate(model, dataloader):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for data, _ in dataloader:
            data = data.to(device)
            reconstruction, mu, sigma = model(data)
            loss = Loss(reconstruction, data, mu, sigma)
            running_loss += loss.item()
    val_loss = running_loss/len(dataloader.dataset)
    return val_loss

In [17]:
train_loss = []
val_loss = []
epochs = 20

for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss = fit(model, train_face_dataloader)
    val_epoch_loss = validate(model, val_face_dataloader)
    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {val_epoch_loss:.4f}")

Epoch 1 of 20
Train Loss: 4048.6328
Val Loss: 3530.4314
Epoch 2 of 20
Train Loss: 3462.6632
Val Loss: 3361.8752
Epoch 3 of 20
Train Loss: 3403.0095
Val Loss: 3317.0634
Epoch 4 of 20
Train Loss: 3385.5013
Val Loss: 3328.4905
Epoch 5 of 20
Train Loss: 3376.7414
Val Loss: 3322.4435
Epoch 6 of 20
Train Loss: 3371.9736
Val Loss: 3321.0204
Epoch 7 of 20
Train Loss: 3365.1173
Val Loss: 3310.3787
Epoch 8 of 20
Train Loss: 3360.2927
Val Loss: 3310.5229
Epoch 9 of 20
Train Loss: 3355.4733
Val Loss: 3309.9623
Epoch 10 of 20
Train Loss: 3350.6559
Val Loss: 3307.0626
Epoch 11 of 20
Train Loss: 3348.6117
Val Loss: 3301.1228
Epoch 12 of 20
Train Loss: 3344.8681
Val Loss: 3295.7545
Epoch 13 of 20
Train Loss: 3342.1914
Val Loss: 3297.2419
Epoch 14 of 20
Train Loss: 3339.3312
Val Loss: 3299.5755
Epoch 15 of 20
Train Loss: 3337.0367
Val Loss: 3319.7567
Epoch 16 of 20
Train Loss: 3336.2330
Val Loss: 3294.5630
Epoch 17 of 20
Train Loss: 3333.6481
Val Loss: 3301.9657
Epoch 18 of 20
Train Loss: 3332.0032
Val

In [52]:
fig = go.Figure(
    go.Scatter(x=torch.linspace(1, 10, 10), y=train_loss, mode='lines', name='train loss')
)

fig.add_trace(
    go.Scatter(x=torch.linspace(1, 10, 10), y=val_loss, mode='lines', name='validation loss')
)

fig.show()

In [91]:
model = model.to(device)

In [98]:
test = model.encoder(val_face_dataset[228][0][None,:,:,:].to(device))[0]

In [219]:
model.eval()
img = transforms.ToPILImage()(model.decoder(4*torch.rand((1,36)).to(device))[0])

enc = ImageEnhance.Sharpness(img)

img = enc.enhance(2.5)

px.imshow(img)

In [87]:
model = model.cpu()
torch.save(model.state_dict(), './VAE_1_0_0_CPU.pth')