In [1]:
import os
import numpy as np
import torch

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

from tqdm.notebook import tqdm

import time

import matplotlib.pyplot as plt

In [2]:
# Pick the compute backend in order of preference: CUDA, then Apple's MPS,
# then plain CPU. The conditional expression short-circuits, so the MPS
# checks only run when CUDA is unavailable.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if (torch.backends.mps.is_available() and torch.backends.mps.is_built())
    else "cpu"
)

print("Using device", device)

Using device mps


In [3]:
# --- Run configuration -------------------------------------------------------
batch_size = 64  # 256
epochs = 15
seed = 0  # fixed seed so weight init, shuffling and sampling are repeatable
fig_folder = "results"
backup_folder = "backup"

# Seed the global RNGs once, before any model or data loader is created.
torch.manual_seed(seed)
np.random.seed(seed)

# Make sure the output folders exist (no error if they already do).
for f in fig_folder, backup_folder:
    os.makedirs(f, exist_ok=True)

In [4]:
# MNIST train/test splits stored under ./data, converted with ToTensor().
# Only the training split requests a download.
mnist_train = datasets.MNIST('data', train=True, download=True,
                             transform=transforms.ToTensor())
mnist_test = datasets.MNIST('data', train=False,
                            transform=transforms.ToTensor())

# Shuffle the training batches; keep the test batches in dataset order.
train_loader = DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=False)

In [5]:
# Pull the first (index, (images, labels)) item from the test loader to
# sanity-check the batch shape.
examples = enumerate(test_loader)
first_item = next(examples)
batch_idx = first_item[0]
example_data, example_targets = first_item[1]
example_data.shape

torch.Size([64, 1, 28, 28])

In [6]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    The encoder maps an image batch of shape ``(N, 1, 28, 28)`` to the mean and
    log-variance of a diagonal Gaussian over a ``latent_dim``-dimensional
    latent space. The decoder maps latent vectors back to images with values
    in ``[0, 1]`` (sigmoid output).

    Args:
        latent_dim: Size of the latent space. Defaults to 2, which is what the
            rest of the notebook (sampling, predictor, scatter plot) assumes.
    """

    # Shape of the feature map between the conv stack and the dense layers:
    # the stride-2 convolution halves 28x28 to 14x14 with 64 channels.
    _FEATURE_CHANNELS = 64
    _FEATURE_SIZE = 14

    def __init__(self, latent_dim=2):
        super().__init__()
        self.latent_dim = latent_dim
        n_features = self._FEATURE_CHANNELS * self._FEATURE_SIZE * self._FEATURE_SIZE  # 12544

        # Encoder: four convolutions followed by a dense bottleneck.
        self.enc1 = nn.Conv2d(1, 32, 3, padding="same")
        self.enc2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.enc3 = nn.Conv2d(64, 64, 3, padding="same")
        self.enc4 = nn.Conv2d(64, 64, 3, padding="same")
        self.enc5 = nn.Linear(n_features, 32)

        # Decoder: dense layer back to the feature map, then upsample to 28x28.
        self.dec1 = nn.Linear(latent_dim, n_features)
        self.dec2 = nn.ConvTranspose2d(64, 32, 3, stride=(2, 2), padding=1, output_padding=1)
        self.dec3 = nn.Conv2d(32, 1, 3, padding="same")

        # Heads producing the parameters of the approximate posterior.
        self.mu = nn.Linear(32, latent_dim)
        self.logvar = nn.Linear(32, latent_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for images ``x``."""
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))
        # Each conv layer is applied exactly once; enc4 used to be applied
        # twice in a row, which silently added a weight-tied fifth conv pass.
        x = F.relu(self.enc4(x))
        x = x.flatten(start_dim=1)
        x = F.relu(self.enc5(x))
        mu = self.mu(x)
        logvar = self.logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        """Map latent vectors ``z`` of shape ``(N, latent_dim)`` to images in ``[0, 1]``."""
        z = F.relu(self.dec1(z))
        z = z.view(-1, self._FEATURE_CHANNELS, self._FEATURE_SIZE, self._FEATURE_SIZE)
        z = F.relu(self.dec2(z))
        x = torch.sigmoid(self.dec3(z))
        return x

    def forward(self, x):
        """Encode ``x``, sample a latent vector, and return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [7]:
def vae_loss(recon_x, x, mu, logvar):
    """Negative ELBO: reconstruction error plus KL divergence.

    Both terms are summed over every element and over the batch (no
    averaging).

    The KL term is the closed form for a diagonal Gaussian posterior against
    a standard normal prior, from Appendix B of
    Kingma and Welling, "Auto-Encoding Variational Bayes", ICLR 2014
    (https://arxiv.org/abs/1312.6114):

        KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')

    sigma_squared = torch.exp(logvar)
    kl_term = -0.5 * (1 + logvar - mu ** 2 - sigma_squared).sum()

    return reconstruction_term + kl_term

In [8]:
def train(model, loss_function, epoch, log_interval=100):
    """Train ``model`` for one epoch and print the average loss.

    Relies on notebook-level globals: ``train_loader`` (batches of
    ``(images, labels)``), ``device``, and ``optimizer`` (which must already
    be built over ``model``'s parameters).

    Args:
        model: Module whose forward pass returns ``(reconstruction, mu, logvar)``.
        loss_function: Callable ``(reconstruction, data, mu, logvar) -> scalar tensor``.
        epoch: Epoch number, used only for the progress bar and the summary line.
        log_interval: Unused; kept so existing calls keep working.
    """
    model.train()
    train_loss = 0
    batches = tqdm(train_loader, position=1, total=len(train_loader), desc=f"Epoch {epoch}")
    for data, _ in batches:  # labels are not needed for the plain VAE
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    # Accumulated batch losses divided by the number of training samples.
    print('====> Epoch: {}\tAverage loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


# def test(model, loss_function, epoch, batch_size, fig_folder):
#     model.eval()
#     test_loss = 0
#     with torch.no_grad():
#         for i, (data, _) in enumerate(test_loader):
#             data = data.to(device)
#             recon_batch, mu, logvar = model(data)
#             test_loss += loss_function(recon_batch, data, mu, logvar).item()
#             if i == 0:
#                 n = min(data.size(0), 8)
#                 comparison = torch.cat([data[:n],
#                                       recon_batch.view(batch_size, 1, 28, 28)[:n]])
#                 save_image(comparison.cpu(),
#                          f'{fig_folder}/reconstruction_{epoch}.png', nrow=n)

#     test_loss /= len(test_loader.dataset)
#     print('====> Test set loss: {:.4f}'.format(test_loss))


In [None]:
# Baseline run: train the VAE on its own, without the label predictor.
vae_model = VAE()
vae_model.to(device)

# `train` reads `optimizer` from the notebook namespace, so it has to be
# defined here before the loop starts.
optimizer = optim.Adam(params=vae_model.parameters(), lr=1e-3)

for epoch in tqdm(range(1, epochs + 1)):
    train(vae_model, vae_loss, epoch)

# Save the trained weights for the latent-space comparison later on.
backup_file_vae = f"{backup_folder}/vae_model_no_predictor.p"
torch.save(vae_model.state_dict(), backup_file_vae)

  0%|          | 0/15 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/938 [00:00<?, ?it/s]

====> Epoch: 1	Average loss: 217.5319


Epoch 2:   0%|          | 0/938 [00:00<?, ?it/s]

In [None]:
class Predictor(nn.Module):
    """Small MLP mapping a 2-d latent vector to 10 class logits.

    Two hidden layers of 128 units with ReLU activations; the output layer
    returns raw logits (no softmax).
    """

    def __init__(self):
        super().__init__()
        hidden_units = 128
        self.pred1 = nn.Linear(2, hidden_units)
        self.pred2 = nn.Linear(hidden_units, hidden_units)
        self.pred3 = nn.Linear(hidden_units, 10)

    def forward(self, x):
        """Return class logits of shape ``(N, 10)`` for latent inputs ``x``."""
        hidden = x
        for layer in (self.pred1, self.pred2):
            hidden = F.relu(layer(hidden))
        return self.pred3(hidden)

In [None]:
def predictor_loss(y, logits):
    """Mean cross-entropy between class ``logits`` and integer labels ``y``."""
    return F.cross_entropy(logits, y)

In [None]:
def combined_loss(batch, models, weight_pred_loss=20.):
    """VAE loss plus a weighted classification loss on the latent mean.

    Args:
        batch: ``(x, y)`` pair of inputs and class labels.
        models: ``(vae_model, predictor_model)`` pair.
        weight_pred_loss: Multiplier applied to the predictor's loss before
            it is added to the VAE loss.

    Returns:
        ``vae_loss + weight_pred_loss * predictor_loss``.
    """
    autoencoder, classifier = models
    inputs, labels = batch

    reconstruction, mu, logvar = autoencoder(inputs)
    # The predictor sees the posterior mean, not a sampled latent vector.
    class_logits = classifier(mu)

    autoencoder_term = vae_loss(recon_x=reconstruction, x=inputs, mu=mu, logvar=logvar)
    classifier_term = predictor_loss(y=labels, logits=class_logits)
    return autoencoder_term + weight_pred_loss*classifier_term

In [None]:
def train_combined(models, epoch, log_interval=100):
    """Train the VAE and the predictor jointly for one epoch.

    Relies on notebook-level globals: ``train_loader``, ``device``, and
    ``optimizer``. Only parameters registered with ``optimizer`` are updated,
    so it must cover every model in ``models`` that is meant to learn.

    Args:
        models: ``(vae_model, predictor_model)`` pair, passed to ``combined_loss``.
        epoch: Epoch number, used only for the progress bar and the summary line.
        log_interval: Unused; kept so existing calls keep working.
    """
    for model in models:
        model.train()
    train_loss = 0
    batches = tqdm(train_loader, total=len(train_loader), position=1, desc=f"Epoch {epoch}")
    for batch in batches:
        # Move both images and labels to the training device.
        batch = [b.to(device) for b in batch]
        optimizer.zero_grad()
        loss = combined_loss(batch, models)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    # Accumulated batch losses divided by the number of training samples.
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

In [None]:
# Second run: train a fresh VAE together with the label predictor.
vae_model = VAE().to(device)
pred_model = Predictor().to(device)

# The optimizer must cover BOTH networks. With only the VAE's parameters the
# predictor would stay at its random initialization for the whole run.
joint_parameters = list(vae_model.parameters()) + list(pred_model.parameters())
optimizer = optim.Adam(joint_parameters, lr=1e-3)

for epoch in tqdm(range(1, epochs + 1)):
    train_combined(models=(vae_model, pred_model), epoch=epoch)
    # After each epoch, decode 64 latent vectors drawn from the N(0, I) prior
    # and save the resulting image grid.
    with torch.no_grad():
        sample = torch.randn(64, 2).to(device)
        sample = vae_model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28),
                   f"{fig_folder}/sample_{epoch}.png")

# Save the VAE weights for the latent-space comparison later on.
backup_file_vae = f"{backup_folder}/vae_model_predictor.p"
torch.save(vae_model.state_dict(), backup_file_vae)

In [None]:
# Rebuild the test loader with a batch size of 10000 so that the first
# batch can be taken as x_test / y_test in one go.
test_loader = DataLoader(
    dataset=datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
    batch_size=10000,
    shuffle=False,
)

# Take the first batch; fall back to (None, None) if the loader is empty.
x_test, y_test = next(iter(test_loader), (None, None))

In [None]:
# Reload both trained VAEs onto the CPU for analysis.
# map_location="cpu" is needed because the checkpoints were saved from models
# living on `device` (cuda/mps); without it, loading fails on a machine where
# that backend is unavailable.
# NOTE: torch.load unpickles the file, so only load checkpoints written by
# this notebook.
model_predictor_off = VAE()
model_predictor_off.load_state_dict(
    torch.load(f"{backup_folder}/vae_model_no_predictor.p", map_location="cpu"))

model_predictor_on = VAE()
model_predictor_on.load_state_dict(
    torch.load(f"{backup_folder}/vae_model_predictor.p", map_location="cpu"))

In [None]:
%config InlineBackend.figure_format='retina'

In [None]:
# Compare where each model places the test digits in its 2-d latent space.
models = {'VAE latent space without predictor': model_predictor_off,
          'VAE latent space with predictor': model_predictor_on}

ts = []
for model in models.values():
    model.eval()
    with torch.no_grad():
        # Only the posterior mean is plotted, so run the encoder alone
        # instead of the full forward pass (which also samples and decodes).
        mu, logvar = model.encode(x_test)
        ts.append(mu.numpy())

# Guarded so that re-running the cell does not fail on an already-converted array.
if not isinstance(y_test, np.ndarray):
    y_test = y_test.numpy()

titles = list(models.keys())

# squeeze=False always returns a 2-d array of axes, even for a single panel,
# so no special-casing is needed when indexing.
fig, axes = plt.subplots(ncols=len(ts), figsize=(5*len(ts), 4), squeeze=False)
cmap = plt.get_cmap('viridis', 10)  # one discrete colour per digit class

for ax, t, title in zip(axes.flat, ts, titles):
    im = ax.scatter(
        t[:, 0], t[:, 1], c=y_test,
        cmap=cmap,
        vmin=-0.5, vmax=9.5,  # centres each integer label in its colour bin
        marker='o', s=0.4)
    ax.set_xlim(-4, 4)
    ax.set_ylim(-4, 4)
    ax.set_xlabel("mu[0]")
    ax.set_ylabel("mu[1]")
    ax.set_title(title)

# Reserve the right-hand strip of the figure for a shared colour bar.
fig.subplots_adjust(right=0.8)
fig.colorbar(im, cax=fig.add_axes([0.82, 0.13, 0.02, 0.74]), ticks=range(10));