In [1]:
from model import VariationalAutoencoder
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from PIL import Image, ImageEnhance
import plotly.express as px
import plotly.graph_objects as go
import os
import numpy as np

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [3]:
class CustomFaceDataset(torch.utils.data.Dataset):
    """Unlabelled image dataset over a directory of zero-padded, numbered JPEGs.

    Item ``i`` is read from ``<path>/<i + shift, zero-padded to 5 digits>.jpg``
    (e.g. ``01000.jpg`` for ``i=0, shift=1000``), so the files are expected to be
    numbered contiguously starting at ``shift``.

    Args:
        path: Directory containing the images. A trailing separator is optional.
        transform: Optional callable applied to the loaded PIL image.
        shift: Offset added to every index before building the file name.
            ``None`` is treated as 0.

    Each item is ``(image, 0)``; the constant 0 is a dummy label.
    """

    def __init__(self, path, transform=None, shift=None):
        self.path = path
        self.transform = transform
        # Count only the files __getitem__ can actually open: stray entries such as
        # .DS_Store would otherwise inflate len() and make the last indices point
        # at non-existent images.
        self.size = sum(1 for name in os.listdir(path) if name.endswith('.jpg'))
        self.shift = 0 if shift is None else shift

    def _image_path(self, index):
        """Return the file path of the image at dataset position ``index``."""
        file_name = f'{index + self.shift}'.rjust(5, '0') + '.jpg'
        # os.path.join works whether or not ``path`` ends with a separator.
        return os.path.join(self.path, file_name)

    def __getitem__(self, index):  # without labels
        # Image.open is lazy and keeps the file open until the pixels are read;
        # converting inside ``with`` loads the data and releases the handle even
        # when no transform is set. convert('RGB') also guarantees 3 channels so
        # the occasional grayscale/CMYK file still batches with the rest.
        with Image.open(self._image_path(index)) as img:
            x = img.convert('RGB')
        if self.transform is not None:
            x = self.transform(x)
        return x, 0

    def __len__(self):
        return self.size

In [4]:
# Resize so the shorter image side is 128 px (the source directory name suggests
# 256 px originals), then convert to a float tensor in [0, 1].
transform = transforms.Compose([transforms.Resize(128), transforms.ToTensor()])

# Training images are numbered from 01000.jpg (shift=1000); the validation images
# live in their own directory and are numbered from 00000.jpg (default shift).
train_face_dataset = CustomFaceDataset(
    '../celeba_hq_256/',
    transform=transform,
    shift=1000,
)
val_face_dataset = CustomFaceDataset('../data/', transform=transform)

In [5]:
# Visual sanity check: show validation image #228 after the dataset transform.
sample_image = val_face_dataset[228][0]
px.imshow(transforms.ToPILImage()(sample_image))

In [6]:
# Shuffle the training set so each epoch sees a fresh batch order (without it,
# every epoch replays identical batches in identical order). Validation order
# stays fixed; it only feeds a summed loss, so shuffling it buys nothing.
train_face_dataloader = DataLoader(train_face_dataset, batch_size=16, shuffle=True)

val_face_dataloader = DataLoader(val_face_dataset, batch_size=16, shuffle=False)

In [7]:
from tqdm import tqdm

In [8]:
# Build the VAE and place it on the selected device.
# NOTE(review): 36 is the first constructor argument — presumably the latent
# dimension; confirm against model.VariationalAutoencoder.
model = VariationalAutoencoder(36, device)
model = model.to(device)

In [9]:
def Loss(logit, data, mu, sigma):
    """VAE objective: summed squared reconstruction error plus a KL penalty.

    Args:
        logit: Reconstruction produced by the model (same shape as ``data``).
        data: Target batch.
        mu: Latent means.
        sigma: Latent spread term. The closed-form KL expression below,
            ``-0.5 * sum(1 + sigma - mu**2 - exp(sigma))``, is the standard one
            when ``sigma`` holds the log-variance (log sigma^2) — TODO confirm
            that is what the model returns.

    Returns:
        A scalar tensor summed over every element of the batch (not averaged).
    """
    reconstruction_term = F.mse_loss(logit, data, reduction='sum')
    kl_term = -0.5 * torch.sum(1 + sigma - mu.pow(2) - sigma.exp())
    return reconstruction_term + kl_term
criterion = Loss  # alias of Loss; the training helpers call Loss by name
# Adam over all model parameters with a fixed learning rate of 1e-4.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [10]:
def fit(model, dataloader):
    """Train ``model`` for one epoch over ``dataloader``.

    Relies on the notebook-level ``optimizer``, ``device`` and ``Loss``.

    Returns:
        The accumulated loss divided by the number of samples in the dataset.
        Because ``Loss`` sums over the batch, this is a mean per-sample loss.
    """
    model.train()
    total = 0.0
    for batch, _labels in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        recon, mu, sigma = model(batch)
        batch_loss = Loss(recon, batch, mu, sigma)
        batch_loss.backward()
        total += batch_loss.item()
        optimizer.step()
    return total / len(dataloader.dataset)

In [11]:
def validate(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` in eval mode without gradients.

    Relies on the notebook-level ``device`` and ``Loss``.

    Returns:
        The summed loss divided by the number of samples in the dataset.
    """
    model.eval()
    total = 0.0
    with torch.no_grad():
        for batch, _labels in dataloader:
            batch = batch.to(device)
            recon, mu, sigma = model(batch)
            total += Loss(recon, batch, mu, sigma).item()
    return total / len(dataloader.dataset)

In [12]:
# Per-epoch histories of the value returned by fit / validate.
train_loss, val_loss = [], []
epochs = 20

for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")

    train_epoch_loss = fit(model, train_face_dataloader)
    val_epoch_loss = validate(model, val_face_dataloader)

    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {val_epoch_loss:.4f}")

    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)

Epoch 1 of 20
Train Loss: 3610.4102
Val Loss: 3375.8364
Epoch 2 of 20
Train Loss: 3387.6952
Val Loss: 3304.4331
Epoch 3 of 20
Train Loss: 3356.4900
Val Loss: 3274.1413
Epoch 4 of 20
Train Loss: 3339.8884
Val Loss: 3263.1116
Epoch 5 of 20
Train Loss: 3328.5735
Val Loss: 3257.9079
Epoch 6 of 20
Train Loss: 3320.5782
Val Loss: 3246.3232
Epoch 7 of 20
Train Loss: 3312.9389
Val Loss: 3275.2090
Epoch 8 of 20
Train Loss: 3306.9220
Val Loss: 3235.0253
Epoch 9 of 20
Train Loss: 3303.4774
Val Loss: 3239.7931
Epoch 10 of 20
Train Loss: 3299.5290
Val Loss: 3244.2790
Epoch 11 of 20
Train Loss: 3295.4949
Val Loss: 3227.0699
Epoch 12 of 20
Train Loss: 3292.9659
Val Loss: 3265.2540
Epoch 13 of 20
Train Loss: 3289.4824
Val Loss: 3246.2957
Epoch 14 of 20
Train Loss: 3287.1811
Val Loss: 3227.6395
Epoch 15 of 20
Train Loss: 3285.9641
Val Loss: 3212.8715
Epoch 16 of 20
Train Loss: 3281.8027
Val Loss: 3220.6387
Epoch 17 of 20
Train Loss: 3282.3406
Val Loss: 3237.0313
Epoch 18 of 20
Train Loss: 3278.2988
Val

In [19]:
# One x value per recorded epoch (1, 2, ..., n). Deriving n from the histories
# keeps the axis aligned with the data even when training stopped early; a fixed
# torch.linspace(1, 20, 10) yields only 10 fractional x values for 20 losses.
train_epochs = list(range(1, len(train_loss) + 1))
val_epochs = list(range(1, len(val_loss) + 1))

fig = go.Figure(
    go.Scatter(x=train_epochs, y=train_loss, mode='lines', name='train loss')
)

fig.add_trace(
    go.Scatter(x=val_epochs, y=val_loss, mode='lines', name='validation loss')
)

fig.update_layout(
    title='VAE loss per epoch',
    xaxis_title='epoch',
    yaxis_title='loss per sample',
)

fig.show()

In [22]:
# Put the model on the selected device. The execution counts show this ran after
# the CPU-save cell (In [17]), which had moved the model to the CPU.
model = model.to(device=device)

In [15]:
# Encode validation image #228 as a batch of one. This is inference only, so run
# under no_grad: otherwise an autograd graph is built and kept alive by `test`.
# NOTE(review): the trailing [0] selects the first element of the encoder output
# (first tuple item, or first batch row if a single tensor is returned) —
# confirm against model.VariationalAutoencoder.encoder.
with torch.no_grad():
    test = model.encoder(val_face_dataset[228][0].unsqueeze(0).to(device))[0]

In [17]:
# Move the weights to the CPU before saving so the checkpoint can be loaded on a
# machine without a GPU. Note: this leaves `model` on the CPU for later cells.
model = model.to('cpu')
torch.save(model.state_dict(), './VAE_1_1_0_CPU.pth')