In [None]:
import torch
import torch.nn as nn

from dataset import get_dataLoader
from model.vanilla_vae import VanillaVAE

In [None]:
# Prefer Apple's Metal backend when present, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(device)

# CelebA loader from the project's dataset module; the meaning of
# partition=-1 / patch_size is defined in dataset.get_dataLoader (TODO confirm).
dataloader = get_dataLoader('Data/celeba', partition=-1, batch_size=64, patch_size=64)

# Number of batches per epoch.
len(dataloader)

In [None]:
# 3-channel (RGB) VAE with a 128-dimensional latent space, moved to `device`.
model = VanillaVAE(in_channels=3, latent_dim=128).to(device)

# Adam with no weight decay.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.005, weight_decay=0.0)

# Multiplies the learning rate by 0.95 on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)



In [None]:

# One pass over `dataloader` (one epoch) of VAE training, then one LR decay step.
model.train()  # make sure train-mode behaviour (e.g. BatchNorm/Dropout, if present) is active
log_every = 100  # report progress every N batches instead of flooding the output every batch
num_batches = len(dataloader)

for idx, (image, _) in enumerate(dataloader):
    image = image.to(device)

    optimizer.zero_grad()
    results = model(image)
    # NOTE(review): M_N presumably weights the KL term inside
    # VanillaVAE.loss_function — TODO confirm.
    loss = model.loss_function(*results, M_N=0.00025)
    loss['loss'].backward()
    optimizer.step()

    if idx % log_every == 0 or idx == num_batches - 1:
        print(f"{idx}/{num_batches}  loss={loss['loss'].item():.4f}")

# Called once after the full pass, so ExponentialLR decays the LR once per epoch.
scheduler.step()

In [None]:
# Generate 4 images from the model. Inference is done in eval mode without
# autograd; the previous train/eval mode is restored afterwards so training
# cells are unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        out = model.sample(4, device)
finally:
    model.train(was_training)

In [None]:
import matplotlib.pyplot as plt

# Convert the generated batch from a 4-D torch tensor (presumably N, C, H, W)
# to an (N, H, W, C) numpy array that plt.imshow can display.
# Guarded so re-running this cell does not try to permute an
# already-converted numpy array (which has no .detach()).
# NOTE(review): imshow clips floats outside [0, 1]; if the decoder emits
# another range (e.g. [-1, 1]) rescale before plotting — TODO confirm.
if torch.is_tensor(out):
    out = out.detach().cpu().permute(0, 2, 3, 1).numpy()





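## TODO

1. Log training metrics and images to TensorBoard.
2. Tune the weight decay (currently 0.0).
3. Tune the learning-rate scheduler (currently `ExponentialLR`, gamma 0.95).
4. Tune the learning rate (currently 0.005).
5. Tune the KL-divergence weight (`kld_weight`).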


In [None]:
# Display the 4th generated image; `out` is expected to be the (N, H, W, C)
# numpy array produced by the conversion cell.
plt.imshow(out[3])


In [None]:
# Take image #4 of the first batch, add a leading batch dimension of 1,
# and move it to the model's device.
x = next(iter(dataloader))[0][4].unsqueeze(0).to(device)

In [None]:
# Run the single-image batch `x` through the VAE for inspection.
# Eval mode + no_grad: with a batch of one, train-mode normalisation layers
# (if the model has any) would use that single image's statistics and update
# their running stats. The previous mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        forward_result = model(x)
finally:
    model.train(was_training)


In [None]:
# Show the input image next to its reconstruction (forward_result[0]).
# Both are presumably (1, C, H, W) tensors; squeeze drops the batch dim and
# permute moves channels last for imshow.
fig, axes = plt.subplots(1, 2)
for ax, img, title in zip(axes, (x, forward_result[0]), ('x', 'recons')):
    ax.imshow(img.squeeze().permute(1, 2, 0).detach().cpu().numpy())
    ax.set_title(title)
    ax.axis('off')
plt.show()

In [None]:
# Loss for the single-image forward pass (same M_N as in training); the result
# is a dict that contains at least the 'loss' key used by the training loop.
model.loss_function(*forward_result, M_N=0.00025)

In [None]:
# Unpack the first four outputs of the forward pass. The input is named
# `inputs` rather than `input` so the builtin input() is not shadowed in the
# notebook namespace. Assumes forward_result has at least 4 entries.
recons, inputs, mu, log_var = forward_result[:4]

mu.shape, log_var.shape

In [None]:

# Closed-form KL divergence between N(mu, exp(log_var)) and N(0, I):
# -0.5 * sum(1 + log_var - mu^2 - exp(log_var)) over dim 1, averaged over dim 0.
kld_loss = (-0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(dim=1)).mean(dim=0)


In [None]:
# Per-row sum over dim 1 of the KL integrand, i.e. kld_loss before the -0.5
# factor and the mean over dim 0.
torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1)

In [None]:
# Hand-computed KL term, for comparison with the loss_function output above.
kld_loss

In [None]:
# NOTE(review): scratch calculation; 128 matches the latent_dim used to build
# the model, but the purpose is not stated — TODO explain or remove.
1/128

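# Experiment

Train the VAE with the project's `utils.trainer.train` helper and log results to TensorBoard with `SummaryWriter`.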

In [None]:
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from model.vanilla_vae import VanillaVAE
from utils.trainer import train, train_step
from dataset import get_dataLoader

In [None]:
# Prefer Apple's Metal backend when present, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Model: 3-channel VAE with a 128-dimensional latent space.
model = VanillaVAE(in_channels=3, latent_dim=128).to(device)

# Optimiser and per-step exponential LR decay (x0.95 per scheduler.step()).
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.005, weight_decay=0.0)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

# Data.
dataloader = get_dataLoader('Data/celeba', partition=-1, batch_size=64, patch_size=64)

# TensorBoard writer (default log directory).
writer = SummaryWriter()

In [None]:
# Train for 20 epochs with the project's utils.trainer.train helper.
# Its internals are not visible here — presumably it steps `scheduler` and logs
# to `writer`; TODO confirm in utils/trainer.py.
train(model,
      dataloader,
      optimizer,
      scheduler,
      epochs=20,
      writer=writer,
      device=device)


In [None]:
import matplotlib.pyplot as plt

In [None]:
# Generate and show one sample from the trained model. Previously this cell relied
# on an `out` variable left over from the exploratory section. That variable does
# not exist on a fresh kernel. It can also be a raw 4-D tensor, which imshow
# cannot display.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        generated = model.sample(1, device)
finally:
    model.train(was_training)

# Presumably (1, C, H, W): drop the batch dim and move channels last for imshow.
plt.imshow(generated[0].permute(1, 2, 0).cpu().numpy())
plt.axis('off')
plt.show()

In [None]:
# Close the TensorBoard writer created in the setup cell (flushes pending events).
writer.close()

In [None]:
from torchvision.utils import make_grid

# Log 4 original images, their reconstructions, and 4 fresh samples to TensorBoard.
writer = SummaryWriter()

x = next(iter(dataloader))[0][0:4]
x = x.to(device)

# Inference only: eval mode (so layers such as BatchNorm, if present, neither use
# batch statistics nor update running stats) and no autograd graph.
# Restore the previous mode even if something fails.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        out = model.sample(4, device)
        result = model(x)
finally:
    model.train(was_training)

# NOTE(review): add_image expects float pixels in [0, 1]; if the decoder outputs
# another range (e.g. [-1, 1]) pass normalize=True to make_grid — TODO confirm.
writer.add_image('original_image', make_grid(x))
writer.add_image('reconstruction', make_grid(result[0]))
writer.add_image('generation', make_grid(out))
writer.close()


In [None]:
# First four images of one batch; note this tensor stays on the CPU.
x = next(iter(dataloader))[0][:4]

In [None]:
# Leading dimension should be 4 (from the [0:4] slice).
x.shape

In [None]:
# `x` was sliced from the dataloader without moving it, so it is on the CPU
# while the model lives on `device`. Move it first; otherwise the forward pass
# fails with a device mismatch whenever device != 'cpu'.
result = model(x.to(device))

In [None]:
# Shape of the second forward output — presumably the input echoed back by the
# model; TODO confirm against VanillaVAE.forward.
result[1].shape