# Implementing a Vanilla VAE on the FashionMNIST dataset

In [22]:
import os, sys

import numpy as np

from tqdm import tqdm
import torch
from torch.nn import functional as F
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader, TensorDataset

from models import vanilla_vae

In [23]:
# Prefer the first CUDA GPU, but fall back to the CPU when CUDA is unavailable
# so that every later `.to(device)` call still works on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device, file=sys.stderr)

Using device: cuda:0


In [24]:
# Fetch FashionMNIST (downloaded if missing) and split it into train / validation parts
train_ratio = 0.9
batch_size = 100
dataset = FashionMNIST('~/datasets', download=True)

# The leading `train_ratio` share of the images is used for training,
# the remainder for validation. The split is positional (no shuffling here).
images = dataset.data
train_size = int(len(images) * train_ratio)
print('train_size =', train_size)
train_images, val_images = images[:train_size], images[train_size:]

# Each image is paired with itself: the class labels are not used.
train_dataset = TensorDataset(train_images, train_images)
val_dataset = TensorDataset(val_images, val_images)

# Only the training batches are shuffled.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

train_size = 54000


In [25]:
# Build the encoder, the decoder and the VAE that wraps both, all on `device`
img_size = 28
latent_size = 10
flat_dim = img_size * img_size  # number of values in one flattened image

# NOTE(review): hidden_dim reuses img_size (28) for both networks — confirm this
# width is intended rather than a placeholder.
enc = vanilla_vae.Encoder(input_dim=flat_dim, hidden_dim=img_size, z_dim=latent_size)
enc = enc.to(device)
dec = vanilla_vae.Decoder(z_dim=latent_size, hidden_dim=img_size, output_dim=flat_dim)
dec = dec.to(device)
vae = vanilla_vae.VAE(enc, dec).to(device)

# Adam over every parameter the VAE exposes
optim = torch.optim.Adam(vae.parameters(), lr=0.003)

In [26]:
# training loop
epochs = 50
pbar = tqdm(range(epochs), colour='green')
for epoch in pbar:
    epoch_loss = 0.0
    for batch_x, _ in train_dataloader:
        # move to device, scale raw pixel values (0-255) to [0, 1] and
        # flatten every image to a single vector
        batch_x_dev = batch_x.to(device).float().div(255).flatten(start_dim=1)
        # forward pass
        predicted, z_mu, z_var = vae(batch_x_dev)
        # The reconstruction must have exactly the target's shape for BCE.
        # Reshaping to the target (instead of a hard-coded
        # (batch_size, img_size, img_size)) avoids a shape mismatch against the
        # flattened input and also works for a smaller final batch.
        predicted = predicted.reshape(batch_x_dev.shape)
        # reconstruction loss, summed over pixels and samples
        # (`reduction='sum'` replaces the deprecated `size_average=False`)
        rec_loss = F.binary_cross_entropy(predicted, batch_x_dev, reduction='sum')
        # KL divergence loss; the formula treats z_var as the log-variance
        kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1.0 - z_var)
        # total loss
        loss = rec_loss + kl_loss

        # update parameters
        optim.zero_grad()
        loss.backward()
        optim.step()

        epoch_loss += loss.item()

    # show the mean per-sample loss of the finished epoch on the progress bar
    pbar.set_postfix(loss=epoch_loss / len(train_dataloader.dataset))
    
    

  0%|[32m                                                                                                     [0m| 0/50 [00:00<?, ?it/s][0m


TypeError: linear(): argument 'input' (position 1) must be Tensor, not ReLU

In [None]:
# Always raises AssertionError, which stops a "Run All" at this cell.
# NOTE(review): presumably a deliberate guard so the cell below is only run
# by hand — confirm, and consider removing it once the notebook is cleaned up.
assert False

In [None]:
# training loop with per-epoch validation and early stopping
max_epochs = 50
patience = 3  # stop after this many consecutive epochs without a better validation loss


def _prepare_batch(img):
    """Move a batch of raw images to `device`, scale to [0, 1] and flatten each image."""
    return img.to(device).float().div(255).flatten(start_dim=1)


def _vae_loss(reconstruction, target, z_mu, z_var):
    """Return the summed VAE loss (reconstruction BCE + KL divergence) for one batch.

    The reconstruction is reshaped to the target's shape before the BCE.
    The KL formula treats `z_var` as the log-variance of the latent Gaussian.
    """
    recon_loss = F.binary_cross_entropy(
        reconstruction.reshape(target.shape), target, reduction='sum'
    )
    kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1.0 - z_var)
    return recon_loss + kl_loss


best_test_loss = float('inf')
patience_counter = 0
for epoch in range(max_epochs):
    # ---- training pass ----
    vae.train()
    train_loss = 0.0
    for img, _ in train_dataloader:
        img = _prepare_batch(img)
        optim.zero_grad()
        # forward pass
        img_sample, z_mu, z_var = vae(img)
        loss = _vae_loss(img_sample, img, z_mu, z_var)
        loss.backward()
        optim.step()
        train_loss += loss.item()

    # ---- validation pass (no gradients, no parameter updates) ----
    vae.eval()
    test_loss = 0.0
    with torch.no_grad():
        for img, _ in val_dataloader:
            img = _prepare_batch(img)
            img_sample, z_mu, z_var = vae(img)
            test_loss += _vae_loss(img_sample, img, z_mu, z_var).item()

    # average of the summed batch losses over the number of batches
    train_loss /= len(train_dataloader)
    test_loss /= len(val_dataloader)
    print(f'Epoch {epoch}, Train Loss: {train_loss:.2f}, Test Loss: {test_loss:.2f}')

    # early stopping on the validation loss
    if test_loss < best_test_loss:
        best_test_loss = test_loss
        patience_counter = 0
    else:
        patience_counter += 1
    if patience_counter >= patience:
        break