In [None]:
import os

import einops as ein
from einops.layers.torch import Rearrange

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [None]:
class CONV_VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    The encoder is three stride-2 convolutions followed by two linear heads
    that produce the mean and log-variance of the approximate posterior.
    The decoder mirrors it with a linear layer and three transposed
    convolutions ending in a sigmoid, so outputs lie in [0, 1].

    Args:
        latent_dim: Size of the latent vector. Stored as ``self.latent_dim``
            so that calling code can query it.
    """

    def __init__(self, latent_dim=2):
        super().__init__()
        # Exposed because the training loop inspects `model.latent_dim`.
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 4, 2, 1),    # 28x28 input -> 14x14
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),   # 14x14 -> 7x7
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),  # 7x7 -> 3x3
            nn.ReLU(),
            nn.Flatten()
        )
        self.fc21 = nn.Linear(128*3*3, latent_dim)  # posterior mean
        self.fc22 = nn.Linear(128*3*3, latent_dim)  # posterior log-variance
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128*3*3),
            nn.ReLU(),
            # (batch, 128*3*3) -> (batch, 128, 3, 3). Same reshape as the
            # einops Rearrange layer used before; it has no parameters, so
            # the state_dict keys of this module are unchanged.
            nn.Unflatten(1, (128, 3, 3)),
            # output_padding=1 turns 3x3 into 7x7 (instead of 6x6) so the
            # two following layers reach 14x14 and then 28x28.
            nn.ConvTranspose2d(128, 64, 4, 2, 1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for images ``x``."""
        x = self.encoder(x)
        return self.fc21(x), self.fc22(x)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        """Map latent vectors ``z`` of shape (batch, latent_dim) to images."""
        return self.decoder(z)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def loss(self, recon_x, x, mu, logvar):
        """Summed binary cross-entropy plus the KL term of the VAE objective.

        Both terms are summed (not averaged) over the batch, so the value
        scales with the batch size.
        """
        BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
        # KL divergence between N(mu, exp(logvar)) and the standard normal.
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD

In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened 784-pixel images.

    Encoder: 784 -> 400 -> 100 -> (mu, logvar).
    Decoder: latent -> 100 -> 400 -> 784, followed by a sigmoid.

    Args:
        lantent_dim: Size of the latent vector. The misspelled parameter name
            is kept so existing keyword callers keep working; the value is
            also stored as ``self.latent_dim``.
    """

    def __init__(self, lantent_dim=2):
        super().__init__()
        # Exposed under the correctly spelled name so calling code can query it.
        self.latent_dim = lantent_dim
        # The Conv2d that used to be assigned to `fc1` was overwritten by this
        # Linear layer on the very next line and never used, so it is removed.
        self.fc1 = nn.Linear(784, 400)
        self.fc1a = nn.Linear(400, 100)
        self.fc21 = nn.Linear(100, lantent_dim)  # posterior mean
        self.fc22 = nn.Linear(100, lantent_dim)  # posterior log-variance
        self.fc3 = nn.Linear(lantent_dim, 100)   # first decoder layer
        self.fc3a = nn.Linear(100, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        """Return ``(mu, logvar)`` for already flattened inputs of width 784."""
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc1a(h1))
        return self.fc21(h2), self.fc22(h2)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        """Map latent vectors to flattened images of width 784 with values in [0, 1]."""
        h3 = F.relu(self.fc3(z))
        h4 = F.relu(self.fc3a(h3))
        return torch.sigmoid(self.fc4(h4))

    def forward(self, x):
        """Flatten ``x`` to (-1, 784), encode, sample and decode.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``; the reconstruction is flat.
        """
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def loss(self, recon_x, x, mu, logvar):
        """Summed binary cross-entropy plus the KL term of the VAE objective.

        ``x`` is flattened to (-1, 784) to match the flat reconstruction.
        Both terms are summed over the batch.
        """
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
        # KL divergence between N(mu, exp(logvar)) and the standard normal.
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = CONV_VAE(2).to(device)
optimizer = optim.Adam(model.parameters(), lr=5e-4)

batch_size2 = 256
log_interval2 = 10
epochs2 = 10

#torch.manual_seed(1) # args.seed

# `device` is a torch.device object, not a string, so its `.type` is what has
# to be compared; comparing the object itself to "cuda" never matched and the
# worker / pinned-memory settings were silently dropped.
kwargs = {'num_workers': 4, 'pin_memory': True} if device.type == "cuda" else {}

# Get train and test data
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size2, shuffle=True, **kwargs)
# Evaluation does not need shuffling; a fixed order also makes the first
# test batch (used for the reconstruction figure) the same on every run.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size2, shuffle=False, **kwargs)


In [None]:
# Per-epoch average losses; train() and test() append to these module-level lists.
train_losses, test_losses = [], []

def train(epoch):
    """Run one pass over ``train_loader`` and record the average loss.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``. The summed loss is divided by the number of samples in the
    dataset, appended to ``train_losses`` and printed.

    Args:
        epoch: Epoch number, used only in the printed message.
    """
    model.train()  # switch the model to training mode
    running_loss = 0
    for images, _ in train_loader:
        images = images.to(device)
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        reconstruction, mu, logvar = model(images)
        batch_loss = model.loss(reconstruction, images, mu, logvar)
        batch_loss.backward()  # backpropagation: compute the gradients
        running_loss += batch_loss.item()
        optimizer.step()  # apply the parameter update
    mean_loss = running_loss / len(train_loader.dataset)
    train_losses.append(mean_loss)
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, mean_loss))

def test(epoch):
    """Evaluate the module-level ``model`` on ``test_loader``.

    The summed loss over all batches is divided by the number of samples in
    the test dataset, appended to ``test_losses`` and printed.

    Args:
        epoch: Epoch number. Currently unused; kept so the signature mirrors
            ``train``.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data, _ in test_loader:
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            # The loss lives on the model; there is no free `loss_function`
            # in this notebook, so calling one raised NameError.
            test_loss += model.loss(recon_batch, data, mu, logvar).item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('====> Test set loss: {:.4f}'.format(test_loss))


In [None]:
# Build a num_rows x num_rows lattice of 2-D points covering [-8, 8] on both
# axes; it is decoded after each epoch to visualise a 2-D latent space.
num_rows = 20
a = torch.linspace(-8., 8., num_rows)
# Every row of x_t is a copy of `a`: the first coordinate increases left to right.
x_t = a.repeat(num_rows, 1)
# y_t is constant along a row and decreases from the top row to the bottom row.
y_t = torch.flip(x_t.t(), dims=(0,))
# One (x, y) pair per row -> shape (num_rows * num_rows, 2).
art_nums = torch.stack((x_t.flatten(), y_t.flatten())).t().to(device)
print(art_nums.size())

In [None]:
# Train and evaluate for `epochs2` epochs. For models with a 2-D latent
# space, the latent lattice `art_nums` is decoded after every epoch and the
# resulting image grid is written to disk.
# The target directory is created up front so the first save on a fresh
# checkout does not fail because the directory is missing.
os.makedirs('results-lin-lin', exist_ok=True)

for epoch in range(1, epochs2 + 1):
    train(epoch)
    test(epoch)
    # Not every model object stores `latent_dim`; fall back to the width of
    # the mean head (`fc21`), which both VAE classes in this notebook define.
    if getattr(model, 'latent_dim', model.fc21.out_features) == 2:
        with torch.no_grad():
            sample = model.decode(art_nums).cpu()
            save_image(sample.view(-1, 1, 28, 28),
                       'results-lin-lin/sample_' + str(epoch) + '.png', nrow=num_rows)

In [None]:
def display_loss(latent_dims, epochs=100, model_cls=None):
    """Train one model per latent size and write its loss curves to text files.

    Despite the name nothing is plotted here. For every entry ``d`` of
    ``latent_dims`` a fresh model and Adam optimizer replace the module-level
    ``model`` and ``optimizer``, ``train``/``test`` are run for ``epochs``
    epochs, and the per-epoch losses are written to ``"{d}train.txt"`` and
    ``"{d}test.txt"`` in the current working directory as the ``str()`` of a
    Python list.

    Args:
        latent_dims: Iterable of latent sizes to train.
        epochs: Number of epochs per model (default 100, the previously
            hard-coded value).
        model_cls: Callable that builds a model from a latent size. Defaults
            to ``VAE``; resolved at call time.

    Side effects:
        Rebinds the globals ``model``, ``optimizer``, ``train_losses`` and
        ``test_losses``; the two loss lists are empty again on return.
    """
    global model
    global optimizer
    global train_losses
    global test_losses

    if model_cls is None:
        model_cls = VAE

    for dim in list(latent_dims):
        model = model_cls(dim).to(device)
        optimizer = optim.Adam(model.parameters(), lr=5e-4)
        # train() and test() append to these globals, so start each run clean.
        train_losses = []
        test_losses = []
        for epoch in range(1, epochs + 1):
            train(epoch)
            test(epoch)

        # Write immediately so finished runs survive a failure in a later
        # one, and use `with` so the file handles are always closed.
        with open(f"{dim}test.txt", "w") as test_file:
            test_file.write(f"{test_losses}")
        with open(f"{dim}train.txt", "w") as train_file:
            train_file.write(f"{train_losses}")

    train_losses = []
    test_losses = []

# Long-running: trains one model per listed latent size (100 epochs each by default).
display_loss([2, 4, 8, 16, 32])

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def read_array(file_name, skip=1):
    """Read a list of floats that was written to a text file as ``str(list)``.

    Args:
        file_name: Path of the text file, e.g. containing ``[1.0, 2.5, 3.0]``.
        skip: Number of leading values to drop. The default of 1 removes the
            first entry for better visualization.

    Returns:
        The remaining values as a list of floats. An empty list in the file
        (``[]``) yields ``[]`` instead of raising ``ValueError``.
    """
    with open(file_name, 'r') as f:
        text = f.read()

    tokens = text.replace('[', '').replace(']', '').replace(' ', '').split(',')
    # Blank tokens appear for an empty list or a trailing comma; float('')
    # would raise, so they are skipped.
    values = [float(tok) for tok in tokens if tok.strip()]

    return values[skip:]

# Compare the test-loss curves of the linear and the convolutional VAE.
# Each panel shows two latent sizes; the curves come from the text files
# parsed by read_array (which drops the first recorded value).
folder1 = 'results-lin/'
folder2 = 'results-conv/'

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

panels = (
    (ax1, (2, 4), 'Latent dimension 2 and 4'),
    (ax2, (8, 16), 'Latent dimension 8 and 16'),
)
for axis, dims, title in panels:
    for dim in dims:
        axis.plot(read_array(folder1 + f'{dim}test.txt'), label=f'linear {dim}')
        axis.plot(read_array(folder2 + f'{dim}test.txt'), label=f'conv {dim}')
    axis.set_title(title)
    axis.set_xlabel('Epoch')
    axis.set_ylabel('Test loss')
    axis.legend()

plt.show()

In [None]:
# NOTE(review): these files are loaded as whole pickled objects, so the model
# classes must be defined before this cell runs and only trusted files should
# be loaded. Newer PyTorch releases may require `weights_only=False` for this
# kind of checkpoint - confirm against the installed version.
model_conv = torch.load('models/conv_vae16.pth', map_location=device)
model_lin = torch.load('models/lin_vae16.pth', map_location=device)

# Reconstruct the first test batch with both models.
with torch.no_grad():
    sample = next(iter(test_loader))[0].to(device)
    recon_lin, _, _ = model_lin(sample.view(-1, 784))
    recon_conv, _, _ = model_conv(sample)

# One row per source (original, linear VAE, conv VAE), five images per row.
fig, axs = plt.subplots(3, 5, figsize=(10, 6))
rows = (
    ('Original', sample),
    ('Linear', recon_lin),
    ('Conv', recon_conv),
)
for row, (row_label, images) in enumerate(rows):
    for i in range(5):
        axs[row, i].imshow(images[i].cpu().view(28, 28), cmap='gray')
    axs[row, 0].set_ylabel(row_label, fontsize=16)

# Hide the ticks but keep the axes so the row labels stay visible.
for ax in axs.flat:
    ax.set(xticks=[], yticks=[])

plt.tight_layout()

plt.show()


In [None]:
# Saves the whole module object currently bound to the global `model`.
# NOTE(review): display_loss rebinds `model`, so after that cell has run this
# is no longer the conv model created in the setup cell - confirm before saving.
torch.save(model, "conv_vae.pth")