In [2]:
import sys 
sys.path.append('../')
import numpy as np
import torch
import torch.utils.data
from torch import optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from models import *
from loss_functions import *
import os 
os.chdir("/ContinuousBernoulliVAE/notebooks")

# Output locations, relative to the notebooks/ directory selected above.
IMAGE_PATH = "../images/cb_mean/"
MODEL_PATH = "../trained_models/"
# Neither save_image nor torch.save creates missing parent directories, so
# create them now instead of crashing when the first image is written
# (after a full training epoch has already run).
os.makedirs(IMAGE_PATH, exist_ok=True)
os.makedirs(MODEL_PATH, exist_ok=True)

# Latent size used when sampling z ~ N(0, I) for model.decode in the training
# loop; it must match CBVAE_Mean's latent dimension (TODO confirm in models.py).
DIM = 20
EPOCHS = 100
BATCH_SIZE = 128
torch.manual_seed(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# MNIST images converted to float tensors in [0, 1] by ToTensor().
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE, shuffle=True)
# Not shuffled, so the first batch (used for reconstruction images) is the
# same every epoch.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, download=True,
                   transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE, shuffle=False)


model = CBVAE_Mean().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


def train(epoch):
    """Run one optimisation epoch over ``train_loader``.

    Relies on the module-level ``model``, ``optimizer``, ``train_loader``,
    ``device`` and ``vae_loss``. ``vae_loss`` is treated as returning an
    iterable of loss tensors whose sum is the objective that is back-propagated.

    Args:
        epoch: Epoch number, used only in the printed progress messages.

    Returns:
        NumPy array with one entry per training batch: the batch's summed loss
        components multiplied by ``len(train_loader) / len(train_loader.dataset)``
        (i.e. divided by the average batch size).
    """
    model.train()
    running_total = 0
    per_batch = []
    n_batches = len(train_loader)
    n_samples = len(train_loader.dataset)

    for step, (images, _labels) in enumerate(train_loader):
        images = images.to(device)
        optimizer.zero_grad()

        recon, mu, logvar = model(images)
        components = vae_loss(recon, images, mu, logvar)
        # Record the per-batch value before collapsing the components.
        per_batch.append(sum(c.item() for c in components))

        total = sum(components)
        total.backward()
        running_total += total.item()
        optimizer.step()

        if step % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                step * len(images),
                n_samples,
                100. * step / n_batches,
                total.item() / len(images)))

    running_total /= n_samples
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, running_total))
    return np.array(per_batch) / n_samples * n_batches


def test(epoch):
    """Evaluate the model on ``test_loader`` and save a reconstruction grid.

    Relies on the module-level ``model``, ``test_loader``, ``device``,
    ``vae_loss`` and ``IMAGE_PATH``. For the first test batch, the first ``n``
    (at most 8) inputs and their reconstructions are written as a two-row grid
    to ``IMAGE_PATH/reconstruction_<epoch>.png``.

    Args:
        epoch: Epoch number, used in the printed summary and the image filename.

    Returns:
        NumPy array with one entry per test batch: the batch's summed loss
        components multiplied by ``len(test_loader) / len(test_loader.dataset)``
        (i.e. divided by the average batch size).
    """
    model.eval()
    test_loss = 0.0
    test_loss_vals = []
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            loss = vae_loss(recon_batch, data, mu, logvar)
            test_loss_vals.append(sum([temp.item() for temp in loss]))
            # Accumulate a Python float rather than keeping a device tensor alive.
            test_loss += sum(loss).item()

            if i == 0:
                n = min(data.size(0), 8)
                # Reshape with the actual batch size instead of a hard-coded 128,
                # so this works for any loader batch size.
                recon_images = recon_batch.reshape(data.size(0), 1, 28, 28)[:n]
                comparison = torch.cat([data[:n], recon_images])
                save_image(comparison.cpu(),
                           os.path.join(IMAGE_PATH, f'reconstruction_{epoch}.png'),
                           nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return np.array(test_loss_vals) / len(test_loader.dataset) * len(test_loader)


# Per-batch loss history, concatenated across all epochs.
train_loss_vals_total = np.array([])
test_loss_vals_total = np.array([])

for epoch in range(1, EPOCHS + 1):
    # One optimisation pass, then an evaluation pass; keep both histories.
    train_loss = train(epoch)
    train_loss_vals_total = np.concatenate(
        (np.ravel(train_loss_vals_total), np.ravel(train_loss)))
    test_loss = test(epoch)
    test_loss_vals_total = np.concatenate(
        (np.ravel(test_loss_vals_total), np.ravel(test_loss)))

    # Decode 64 standard-normal latent codes into an 8x8 grid of samples.
    with torch.no_grad():
        z = torch.randn(64, DIM).to(device)
        sample = model.decode(z).cpu()
        save_image(sample.view(64, 1, 28, 28),
                   f'{IMAGE_PATH}/sample_{epoch}.png')

# Pickles the whole module object (not just the state_dict).
torch.save(model, f'{MODEL_PATH}/cbvae_mean.pt')
# np.save("tmp/cbvae_train_loss_vals_total.npy", train_loss_vals_total)
# np.save("tmp/cbvae_test_loss_vals_total.npy", test_loss_vals_total)


====> Epoch: 1 Average loss: 918.4989
====> Test set loss: 906.3009
====> Epoch: 2 Average loss: 901.0791
====> Test set loss: 897.2676
====> Epoch: 3 Average loss: 895.9760
====> Test set loss: 894.4448
====> Epoch: 4 Average loss: 893.9942
====> Test set loss: 892.9327
====> Epoch: 5 Average loss: 892.7968
====> Test set loss: 892.0292
====> Epoch: 6 Average loss: 891.9752
====> Test set loss: 891.3128
====> Epoch: 7 Average loss: 891.3824
====> Test set loss: 890.7731
====> Epoch: 8 Average loss: 890.9635
====> Test set loss: 890.6860
====> Epoch: 9 Average loss: 890.5882
====> Test set loss: 890.1424
====> Epoch: 10 Average loss: 890.2858
====> Test set loss: 889.8862
====> Epoch: 11 Average loss: 890.0462
====> Test set loss: 889.7014
====> Epoch: 12 Average loss: 889.8537
====> Test set loss: 889.6605
====> Epoch: 13 Average loss: 889.6803
====> Test set loss: 889.4214
====> Epoch: 14 Average loss: 889.5145
====> Test set loss: 889.4498
====> Epoch: 15 Average loss: 889.3757
====