# VAE Training

In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch, wandb, cv2
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchsummary import summary
from vae import VAE, loss_fn
from camvid import *

In [None]:
# data
# Input images are resized/handled by CamVidDataset using SHAPE (presumably
# (height, width) — confirm against camvid.CamVidDataset).
SHAPE = (66, 200)
batch_size = 16
IMAGE_PATH = "../CamVid/train/*.png"
VAL_PATH = "../CamVid/val/*.png"

train_dataset = CamVidDataset(SHAPE, IMAGE_PATH)
# Shuffle training data so each epoch sees batches in a different order;
# a fixed order biases SGD updates.
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

# Validation order does not matter for evaluation, so it stays unshuffled.
val_dataset = CamVidDataset(SHAPE, VAL_PATH)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=2)

In [None]:
# Display one raw training image straight from disk.
sample_image_path = "../CamVid/train/0001TP_009240.png"
img = cv2.imread(sample_image_path)
# cv2.imread returns None (it does not raise) when the file is missing or
# unreadable; fail with a clear message instead of an opaque cvtColor error.
if img is None:
    raise FileNotFoundError(f"could not read image: {sample_image_path}")
# OpenCV loads pixels in BGR order; matplotlib expects RGB.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)

In [None]:
# Visual sanity check of the dataset pipeline: render the first element of the
# first training sample. np.moveaxis moves axis 0 to the end, i.e. this assumes
# a channels-first tensor (C, H, W) — TODO confirm against CamVidDataset.
first_input = train_dataset[0][0]
channels_last = np.moveaxis(first_input.detach().numpy(), 0, -1)
plt.imshow(channels_last)
plt.show()

In [None]:
# config
torch.cuda.empty_cache()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"using dev {device}")

# Hyperparameters defined once so the model, optimizer and wandb config agree.
LATENT = 25
LR = 0.0001
EPOCHS = 30

model = VAE(latent=LATENT).to(device)
# NOTE(review): these two criteria are not used by the training loop, which
# calls vae.loss_fn instead; kept in case other cells reference them.
encoder_loss = nn.KLDivLoss()
decoder_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
# torchsummary.summary prints its own table; wrapping it in print() would also
# print its return value. Input shape is derived from SHAPE (3 colour channels).
summary(model, (3, *SHAPE))

# init wandb with the full hyperparameter set so every run records them
wandb.init(
    project="deepdriving-autoencoder",
    config={
        "batch_size": batch_size,
        "epochs": EPOCHS,
        "architecture": "vae",
        "activation": "relu",
        "input_size": SHAPE,
        "learning_rate": LR,
        "latent": LATENT,
    },
)
config = wandb.config

In [None]:
# training + logging loop
# Each epoch: train on all of `trainloader`, evaluate on `val_loader`, then log
# the epoch-mean KL and MSE losses (train and validation) to wandb.
wandb.watch(model, log="all")

try:
    for epoch in range(config.epochs):
        # ---- train ----
        model.train()
        kl_sum = mse_sum = 0.0
        n_batches = 0
        with tqdm(trainloader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}")

            for data, target in tepoch:
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()

                # encode -> kl loss -> decode -> mse loss
                # (model outputs are already on `device`; no extra .to() needed)
                x_hat, _, p, q = model(data)
                loss, kl_loss, mse_loss = loss_fn(x_hat, target, p, q)
                loss.backward()
                optimizer.step()

                # Convert to plain floats so progress/logging doesn't hold
                # tensors (and their autograd history) alive.
                kl_val, mse_val = float(kl_loss), float(mse_loss)
                kl_sum += kl_val
                mse_sum += mse_val
                n_batches += 1

                # print statistics
                tepoch.set_postfix(kl_loss=kl_val, mse_loss=mse_val)

        # ---- validate ----
        model.eval()
        val_kl_sum = val_mse_sum = 0.0
        n_val_batches = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                x_hat, _, p, q = model(data)
                _, kl_loss, mse_loss = loss_fn(x_hat, target, p, q)
                val_kl_sum += float(kl_loss)
                val_mse_sum += float(mse_loss)
                n_val_batches += 1

        # log epoch means (max(..., 1) guards against an empty loader)
        wandb.log({
            "epoch": epoch+1,
            "kl_loss": kl_sum / max(n_batches, 1),
            "mse_loss": mse_sum / max(n_batches, 1),
            "val_kl_loss": val_kl_sum / max(n_val_batches, 1),
            "val_mse_loss": val_mse_sum / max(n_val_batches, 1),
        })
finally:
    # Close the wandb run even if training is interrupted or fails.
    wandb.finish()
print('Finished Training')