In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision.transforms as T
from torchvision.utils import make_grid, save_image

import time
from PIL import Image

from tqdm import tqdm
from matplotlib import pyplot as plt
%matplotlib inline

is_cuda = torch.cuda.is_available()

# If we have a GPU available, we'll set our device to GPU. We'll use this device variable later in our code.
if is_cuda:
    device = torch.device("cuda")
    print("GPU is available")
else:
    device = torch.device("cpu")
    print("GPU not available, CPU used")

GPU not available, CPU used


In [2]:
train_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=T.Compose(
        [T.Resize(32), T.ToTensor()]
    )
)

In [3]:
to_pil_image = T.ToPILImage()

# VAE

In [80]:
class VAE(nn.Module):
    def __init__(self, latent_size=128):
        super(VAE, self).__init__()
        
        # encoder
        self.e_conv1 = nn.Conv2d(1, 6, 5)
        self.e_pool = nn.MaxPool2d(2, 2)
        self.e_conv2 = nn.Conv2d(6, 16, 5)
        self.e_fc = nn.Linear(16 * 5 * 5, 300)
        self.e_fc_mu = nn.Linear(300, 128)
        self.e_fc_log_var = nn.Linear(300, 128)

        # decoder
        self.d_conv1 = nn.ConvTranspose2d(latent_size, 64, kernel_size=4, stride=1, padding=0, bias=False)
        self.d_norm1 = nn.BatchNorm2d(64)

        self.d_conv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, bias=False)
        self.d_norm2 = nn.BatchNorm2d(32)

        self.d_conv3 = nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1, bias=False)
        self.d_norm3 = nn.BatchNorm2d(16)

        self.d_conv4 = nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1, bias=False)
        
        self.kl = 0
        
    def encoder(self, x):
        x = self.e_pool(F.relu(self.e_conv1(x)))
        x = self.e_pool(F.relu(self.e_conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.e_fc(x))
        return self.e_fc_mu(x), self.e_fc_log_var(x)
    

    def decoder(self, z):
        z = F.relu(self.d_norm1(self.d_conv1(z)))
        z = F.relu(self.d_norm2(self.d_conv2(z)))
        z = F.relu(self.d_norm3(self.d_conv3(z)))
        z = F.sigmoid(self.d_conv4(z))
        return z
    
    def forward(self, x):
        mu, log_var = self.encoder(x)
        
        # Reparameterize
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + (eps * std)
                
        # Add dimensions for conv
        z = z.view(z.shape[0], z.shape[1], 1, 1)

        return self.decoder(z), mu, log_var

In [124]:

def train(net, train_data, batch_size=10, learning_rate=0.0001, epochs=10):
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    criterion = nn.BCELoss(reduction='sum')
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    running_loss = []
    example_inputs, _ = next(iter(train_loader))

    for epoch in range(epochs):
        print("Epoch:" , epoch+1)

        for bi, data in tqdm(enumerate(train_loader), total=int(len(train_data)/train_loader.batch_size)):
            inputs, _ = data
            inputs = inputs.to(device)
            
            # Forward
            outputs, mu, log_var = net(inputs)
            
            # Backward
            kld = -0.5 * (1 + log_var - mu **2 - log_var.exp()).sum() # Kullback–Leibler divergence
            loss = criterion(outputs, inputs) + kld
            loss.backward()
            optimizer.step()
            
            running_loss.append(loss.item())
        
#         latent_inputs = torch.randn(64, 128, 1, 1, device=device)
#         generated_img = net.decoder(latent_inputs)
        generated_img, _, _ = net(example_inputs)
        generated_img = make_grid(generated_img)
    
        #SAVE IMAGE
        im = Image.fromarray(np.array(to_pil_image(generated_img)))
        im.save(f"Images_output/vae_cnn/epoch_{epoch}.jpeg")

        print(f'Loss: {np.mean(running_loss[-len(train_data):])}')

    return loss




In [125]:
net = VAE()
net.to(device)

VAE(
  (e_conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (e_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (e_conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (e_fc): Linear(in_features=400, out_features=300, bias=True)
  (e_fc_mu): Linear(in_features=300, out_features=128, bias=True)
  (e_fc_log_var): Linear(in_features=300, out_features=128, bias=True)
  (d_conv1): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (d_norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (d_conv2): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (d_norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (d_conv3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (d_norm3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (d_conv4): ConvTran

In [126]:
loss = train(net, train_data, epochs=20)

  0%|          | 6/6000 [00:00<01:47, 55.57it/s]

Epoch: 1


100%|██████████| 6000/6000 [01:10<00:00, 85.31it/s]
  0%|          | 9/6000 [00:00<01:11, 83.68it/s]

Loss: 2828.880521952311
Epoch: 2


100%|██████████| 6000/6000 [01:09<00:00, 85.72it/s]
  0%|          | 9/6000 [00:00<01:11, 83.51it/s]

Loss: 2719.9200387268065
Epoch: 3


100%|██████████| 6000/6000 [01:09<00:00, 85.91it/s]
  0%|          | 9/6000 [00:00<01:11, 84.21it/s]

Loss: 2694.5236110772025
Epoch: 4


100%|██████████| 6000/6000 [01:09<00:00, 85.94it/s]
  0%|          | 9/6000 [00:00<01:11, 83.83it/s]

Loss: 2683.9204586079913
Epoch: 5


100%|██████████| 6000/6000 [01:09<00:00, 85.92it/s]
  0%|          | 9/6000 [00:00<01:11, 83.88it/s]

Loss: 2676.0614129313153
Epoch: 6


100%|██████████| 6000/6000 [01:09<00:00, 85.79it/s]
  0%|          | 9/6000 [00:00<01:11, 83.72it/s]

Loss: 2667.4917278984917
Epoch: 7


100%|██████████| 6000/6000 [01:09<00:00, 85.91it/s]
  0%|          | 9/6000 [00:00<01:10, 84.69it/s]

Loss: 2664.4672402111237
Epoch: 8


100%|██████████| 6000/6000 [01:09<00:00, 86.21it/s]
  0%|          | 9/6000 [00:00<01:11, 83.71it/s]

Loss: 2661.091461680094
Epoch: 9


100%|██████████| 6000/6000 [01:09<00:00, 86.17it/s]
  0%|          | 9/6000 [00:00<01:10, 85.40it/s]

Loss: 2659.23280617721
Epoch: 10


100%|██████████| 6000/6000 [01:10<00:00, 85.19it/s]
  0%|          | 8/6000 [00:00<01:22, 72.86it/s]

Loss: 2658.8563928853355
Epoch: 11


100%|██████████| 6000/6000 [01:10<00:00, 85.12it/s]
  0%|          | 9/6000 [00:00<01:11, 83.98it/s]

Loss: 2642.0232948506673
Epoch: 12


100%|██████████| 6000/6000 [01:10<00:00, 85.54it/s]
  0%|          | 9/6000 [00:00<01:11, 83.95it/s]

Loss: 2645.9984269348142
Epoch: 13


100%|██████████| 6000/6000 [01:09<00:00, 85.91it/s]
  0%|          | 9/6000 [00:00<01:10, 84.39it/s]

Loss: 2647.454697906494
Epoch: 14


100%|██████████| 6000/6000 [01:09<00:00, 85.97it/s]
  0%|          | 9/6000 [00:00<01:11, 83.58it/s]

Loss: 2646.296274975586
Epoch: 15


100%|██████████| 6000/6000 [01:09<00:00, 86.17it/s]
  0%|          | 9/6000 [00:00<01:10, 84.51it/s]

Loss: 2646.0597701293946
Epoch: 16


100%|██████████| 6000/6000 [01:09<00:00, 86.13it/s]
  0%|          | 9/6000 [00:00<01:10, 84.64it/s]

Loss: 2647.447543355306
Epoch: 17


100%|██████████| 6000/6000 [01:09<00:00, 86.10it/s]
  0%|          | 9/6000 [00:00<01:10, 84.55it/s]

Loss: 2645.7887810587563
Epoch: 18


100%|██████████| 6000/6000 [01:09<00:00, 86.18it/s]
  0%|          | 9/6000 [00:00<01:11, 84.31it/s]

Loss: 2645.2983165690102
Epoch: 19


100%|██████████| 6000/6000 [01:09<00:00, 86.06it/s]
  0%|          | 9/6000 [00:00<01:10, 84.79it/s]

Loss: 2644.0261201802573
Epoch: 20


100%|██████████| 6000/6000 [01:09<00:00, 86.79it/s]

Loss: 2641.198073492432



