In [1]:
import sys 
sys.path.append('../')
import numpy as np
import torch
import torch.utils.data
from torch import optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from models import * 
from loss_functions import *
import os 
# NOTE(review): absolute working directory; this runs after the `from models import *`
# / `from loss_functions import *` imports above, so it only affects the relative
# data/image/model paths used below.
os.chdir("/ContinuousBernoulliVAE/notebooks")

# Output locations and experiment hyper-parameters.
IMAGE_PATH = "../images/b2/"
MODEL_PATH = "../trained_models/"
DIM = 2       # width of the latent vectors sampled in the training loop
EPOCHS = 100

torch.manual_seed(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST converted to tensors with ToTensor; the training split is downloaded if missing
# and the test split reuses the same 'data' root.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(root='data', train=True, download=True,
                           transform=transforms.ToTensor()),
    batch_size=128,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(root='data', train=False,
                           transform=transforms.ToTensor()),
    batch_size=128,
    shuffle=False,
)

model = VAE2().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train(epoch):
    """Run one training epoch over ``train_loader``.

    For every batch: forward pass, ``vae_loss``, backward on the first loss
    component only (``sum(loss[:1])``), then an optimizer step. A progress
    line is printed every 10 batches and the epoch average at the end.

    Args:
        epoch: Epoch number, used only in the printed log lines.

    Returns:
        NumPy array with one entry per batch: the sum of *all* components
        returned by ``vae_loss``, multiplied by
        ``len(train_loader) / len(train_loader.dataset)``.
    """
    model.train()
    running_total = 0
    batch_totals = []
    for step, (images, _labels) in enumerate(train_loader):
        images = images.to(device)
        optimizer.zero_grad()

        reconstruction, mu, logvar = model(images)
        components = vae_loss(reconstruction, images, mu, logvar)
        # Every component is logged, but only the first one is optimised.
        # NOTE(review): confirm that excluding the remaining terms from the
        # gradient is intentional.
        batch_totals.append(sum(c.item() for c in components))
        objective = sum(components[:1])
        objective.backward()
        running_total += objective.item()

        optimizer.step()

        if step % 10 != 0:
            continue
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(images), len(train_loader.dataset),
            100. * step / len(train_loader),
            objective.item() / len(images)))

    dataset_size = len(train_loader.dataset)
    running_total /= dataset_size
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, running_total))
    return np.array(batch_totals) / dataset_size * len(train_loader)


def test(epoch):
    """Evaluate the model on ``test_loader`` and save a reconstruction grid.

    The first component of ``vae_loss`` is accumulated over the test set and
    printed as a per-sample average, mirroring ``train()``. For the first
    batch, up to 8 inputs and their reconstructions are written to
    ``{IMAGE_PATH}/reconstruction_<epoch>.png``.

    Args:
        epoch: Epoch number, used in the log line and the image file name.

    Returns:
        NumPy array with one entry per test batch: the sum of *all*
        components returned by ``vae_loss``, multiplied by
        ``len(test_loader) / len(test_loader.dataset)``.
    """
    model.eval()
    test_loss = 0.0
    test_loss_vals = []
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)

            loss = vae_loss(recon_batch, data, mu, logvar)
            test_loss_vals.append(sum(term.item() for term in loss))
            # Accumulate a Python float (as train() does) rather than a device tensor.
            test_loss += sum(loss[:1]).item()

            if i == 0:
                batch_size = data.size(0)
                n = min(batch_size, 8)
                # Reshape using the real batch size instead of a hard-coded 128, so
                # this keeps working if the loader's batch_size changes or the
                # first batch is smaller than expected.
                recon_images = recon_batch.view(batch_size, 1, 28, 28)
                comparison = torch.cat([data[:n], recon_images[:n]])

                save_image(comparison.cpu(),
                           f'{IMAGE_PATH}/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return np.array(test_loss_vals) / len(test_loader.dataset) * len(test_loader)


# Flat per-batch loss histories, extended after every epoch.
train_loss_vals_total = np.array([])
test_loss_vals_total = np.array([])
for epoch in range(1, EPOCHS + 1):
    train_loss = train(epoch)
    train_loss_vals_total = np.concatenate((train_loss_vals_total, train_loss))
    test_loss = test(epoch)
    test_loss_vals_total = np.concatenate((test_loss_vals_total, test_loss))

    # Decode 64 latent vectors drawn from N(0, I) and save them as one image grid.
    with torch.no_grad():
        sample = model.decode(torch.randn(64, DIM).to(device)).cpu()
        save_image(sample.view(64, 1, 28, 28),
                   f'{IMAGE_PATH}/sample_{epoch}.png')

# Saves the whole model object (torch.save pickles it), not only its state_dict.
torch.save(model, f'{MODEL_PATH}/vae2.pt')



====> Epoch: 1 Average loss: 177.7060
====> Test set loss: 154.7232
====> Epoch: 2 Average loss: 150.9735
====> Test set loss: 148.1803
====> Epoch: 3 Average loss: 146.5368
====> Test set loss: 145.1202
====> Epoch: 4 Average loss: 143.9201
====> Test set loss: 143.2080
====> Epoch: 5 Average loss: 142.1817
====> Test set loss: 142.5774
====> Epoch: 6 Average loss: 141.0830
====> Test set loss: 141.1926
====> Epoch: 7 Average loss: 139.9673
====> Test set loss: 140.3855
====> Epoch: 8 Average loss: 139.1946
====> Test set loss: 139.8720
====> Epoch: 9 Average loss: 138.4176
====> Test set loss: 138.5857
====> Epoch: 10 Average loss: 137.8754
====> Test set loss: 138.5185
====> Epoch: 11 Average loss: 137.3226
====> Test set loss: 138.4481
====> Epoch: 12 Average loss: 136.8011
====> Test set loss: 138.2589
====> Epoch: 13 Average loss: 136.4349
====> Test set loss: 137.2948
====> Epoch: 14 Average loss: 136.0271
====> Test set loss: 136.7928
====> Epoch: 15 Average loss: 135.5409
====