In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import numpy as np
import time
import os

from layers import *

In [2]:
# Model params
# Module-level settings for this notebook. The values mirror the keyword
# defaults in main_function's signature; main_function only sees them if they
# are passed to it explicitly.

# Model / architecture settings
image_size = 28
channels = 1
patch_size = 4
kernel_size = 5
num_kernels = 32
maxpooling_size = [2]
maxpooling_stride = [2]
maxpooling_pad = 'same'
num_layers = 7
embed_dim = 64
num_heads = 4
mlp_dim = 64
drop_prob = 0.3
num_classes = 10

# Optimisation settings
batch_size = 50
epochs = 2
lr = 1e-3
lr_end = 1e-5
kl_factor = 1e-3

# Run-mode switches
Training = True
continue_training = False
saved_model_epochs = 800

In [3]:
def main_function(image_size=28, kernel_size=5, num_kernels=32, patch_size=4, maxpooling_size=[2],
                  maxpooling_stride=[2], maxpooling_pad='same', num_layers=7, num_classes=10,
                  embed_dim=64, num_heads=4, mlp_dim=64, channels=1, drop_prob=0.3, batch_size=50,
                  epochs=2, lr=0.001, lr_end=0.00001, kl_factor=0.001, Training=True,
                  continue_training=False, saved_model_epochs=800):
    """Train a DensityPropCNN on MNIST and validate it after every epoch.

    Each epoch runs one pass over the training set with Adam, then evaluates
    on the MNIST test split, prints the metrics and overwrites the checkpoint
    ``vdp_trans_model.pth`` under
    ``./saved_models/VDP_cnn_epoch_{epochs}_kl_{kl_factor}_lr_{lr}/``.

    Args:
        kernel_size, num_kernels, maxpooling_size, maxpooling_stride,
        maxpooling_pad, num_classes: Forwarded to ``DensityPropCNN`` as
            ``kernel_size``, ``num_kernel``, ``pooling_size``,
            ``pooling_stride``, ``pooling_pad`` and ``units``. The list
            defaults are shared between calls; this function only forwards
            them and never modifies them.
        batch_size: Batch size for both data loaders.
        epochs: Number of training epochs (also part of the checkpoint path).
        lr: Adam learning rate (also part of the checkpoint path).
        kl_factor: Weight of the KL term in the training loss.
        continue_training: If true, load weights from a previous checkpoint
            before training.
        saved_model_epochs: Epoch count used to build the path of the
            checkpoint loaded when ``continue_training`` is true.
        image_size, patch_size, num_layers, embed_dim, num_heads, mlp_dim,
        channels, drop_prob, lr_end, Training: Accepted for interface
            compatibility but not used by this function. In particular the
            function always trains, regardless of ``Training``.

    Returns:
        None. Progress is printed; MNIST is downloaded to ``./data`` if it is
        not already there, and the checkpoint is written to disk.
    """
    checkpoint_dir = f'./saved_models/VDP_cnn_epoch_{epochs}_kl_{kl_factor}_lr_{lr}/'
    checkpoint_file = 'vdp_trans_model.pth'

    # Loss constants shared by the training and validation phases.
    nll_scale = 0.001
    sigma_min, sigma_max = 1e-10, 1e+6

    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

    # Reshuffle the training set every epoch so mini-batches are not presented
    # in the same fixed order each time; evaluation order does not matter.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = DensityPropCNN(kernel_size=kernel_size, num_kernel=num_kernels,
                           pooling_size=maxpooling_size, pooling_stride=maxpooling_stride,
                           pooling_pad=maxpooling_pad, units=num_classes)
    if continue_training:
        # NOTE(review): this resumes from ./saved_models_new/..._lr_latest/, which is
        # not the directory this function writes to - confirm that is intended.
        saved_model_path = f'./saved_models_new/VDP_cnn_epoch_{saved_model_epochs}_kl_{kl_factor}_lr_latest/'
        # The model is never moved off the CPU, so map the checkpoint there; this
        # also lets a checkpoint saved on a GPU machine load on a CPU-only one.
        # SECURITY: torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(os.path.join(saved_model_path, checkpoint_file),
                                map_location='cpu')
        model.load_state_dict(state_dict)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    # A polynomial decay from lr to lr_end was sketched here but is disabled, so
    # lr_end is currently unused. If it is re-enabled via LambdaLR, the lambda must
    # return a multiplier of the initial lr, not the absolute learning rate.

    # -------------------------  Training loop -------------------------

    train_acc = np.zeros(epochs)
    valid_acc = np.zeros(epochs)
    train_err = np.zeros(epochs)
    valid_error = np.zeros(epochs)
    start = time.time()

    for epoch in range(epochs):
        print(f'Epoch: {epoch+1}/{epochs}')
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        # --------------------  Training phase  --------------------
        for batch_idx, (x, y) in enumerate(train_loader):
            optimizer.zero_grad()
            mu_out, sigma, kl = model(x)
            # Clamp the predicted variance so the NLL stays finite.
            loss_final = nll_gaussian(y, mu_out, torch.clamp(sigma, min=sigma_min, max=sigma_max)) * nll_scale
            loss = 0.5 * (loss_final + kl_factor * kl)

            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            _, predicted = torch.max(mu_out, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
            if batch_idx % 50 == 0:
                print(f'Step: {batch_idx}, Loss: {loss_final.item()}, kl {kl.item()}, Training accuracy so far: {correct / total:.3f}, sigma norm {np.mean(sigma.detach().numpy())}')
        train_acc[epoch] = correct / total
        train_err[epoch] = total_loss / len(train_loader)
        print(f'Training Acc: {train_acc[epoch]}, Training error: {train_err[epoch]}')

        # --------------------  Validation phase  --------------------
        model.eval()
        correct_valid = 0
        total_valid = 0
        val_loss = 0

        # No gradients are needed for evaluation; skipping the autograd graph
        # saves memory and time.
        with torch.no_grad():
            for x, y in test_loader:
                mu_out, sigma, kl = model(x)
                total_valid += y.size(0)
                # Same variance clamp as in training, so both losses are
                # comparable and a zero variance cannot produce inf/NaN.
                loss_v = nll_gaussian(y, mu_out, torch.clamp(sigma, min=sigma_min, max=sigma_max)) * nll_scale
                val_loss += loss_v.item()
                _, predicted = torch.max(mu_out, 1)
                correct_valid += (predicted == y).sum().item()

        valid_acc[epoch] = correct_valid / total_valid
        valid_error[epoch] = val_loss / len(test_loader)
        stop = time.time()
        # Overwrite the checkpoint after every epoch.
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, checkpoint_file))

        # Elapsed time is measured from before the first epoch and includes validation.
        print(f'Total Training Time: {stop - start:.2f}s')
        print(f'Validation Acc: {valid_acc[epoch]}, Validation error: {valid_error[epoch]}')


In [4]:
# Forward the values from the configuration cell explicitly. A bare
# main_function() call ignores them and runs with the signature defaults, so
# editing the configuration cell would have no effect on the run.
main_function(
    image_size=image_size,
    kernel_size=kernel_size,
    num_kernels=num_kernels,
    patch_size=patch_size,
    maxpooling_size=maxpooling_size,
    maxpooling_stride=maxpooling_stride,
    maxpooling_pad=maxpooling_pad,
    num_layers=num_layers,
    num_classes=num_classes,
    embed_dim=embed_dim,
    num_heads=num_heads,
    mlp_dim=mlp_dim,
    channels=channels,
    drop_prob=drop_prob,
    batch_size=batch_size,
    epochs=epochs,
    lr=lr,
    lr_end=lr_end,
    kl_factor=kl_factor,
    Training=Training,
    continue_training=continue_training,
    saved_model_epochs=saved_model_epochs,
)

Epoch: 1/2
Step: 0, Loss: 583.7427368164062, kl 7.222908020019531, Training accuracy so far: 0.200, sigma norm 2.406434873591934e-07
Step: 50, Loss: 16.555526733398438, kl 7.166213035583496, Training accuracy so far: 0.683, sigma norm 1.3451494851324242e-07
Step: 100, Loss: 33.47002410888672, kl 7.154552459716797, Training accuracy so far: 0.754, sigma norm 1.5209624848466774e-07
Step: 150, Loss: 23.23363494873047, kl 7.147308349609375, Training accuracy so far: 0.790, sigma norm 1.1624350548800066e-07
Step: 200, Loss: 27.083805084228516, kl 7.13693380355835, Training accuracy so far: 0.806, sigma norm 9.491463259791999e-08
Step: 250, Loss: 38.787254333496094, kl 7.12849235534668, Training accuracy so far: 0.815, sigma norm 1.0768489744350518e-07
Step: 300, Loss: 23.597301483154297, kl 7.114316940307617, Training accuracy so far: 0.819, sigma norm 9.5577327385854e-08
Step: 350, Loss: 35.03846740722656, kl 7.102791786193848, Training accuracy so far: 0.826, sigma norm 1.1729179760777697