In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision.transforms as T
from torchvision.utils import make_grid, save_image

import time
from PIL import Image

from tqdm import tqdm
from matplotlib import pyplot as plt
%matplotlib inline

is_cuda = torch.cuda.is_available()

# If we have a GPU available, we'll set our device to GPU. We'll use this device variable later in our code.
if is_cuda:
    device = torch.device("cuda")
    print("GPU is available")
else:
    device = torch.device("cpu")
    print("GPU not available, CPU used")

GPU not available, CPU used


In [2]:
train_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=T.Compose(
        [T.Resize(32), T.ToTensor()]
    )
)

In [3]:
to_pil_image = T.ToPILImage()

# VAE

In [80]:
class VAE(nn.Module):
    def __init__(self, latent_size=128):
        super(VAE, self).__init__()
        
        # encoder
        self.e_conv1 = nn.Conv2d(1, 6, 5)
        self.e_pool = nn.MaxPool2d(2, 2)
        self.e_conv2 = nn.Conv2d(6, 16, 5)
        self.e_fc = nn.Linear(16 * 5 * 5, 300)
        self.e_fc_mu = nn.Linear(300, 128)
        self.e_fc_log_var = nn.Linear(300, 128)

        # decoder
        self.d_conv1 = nn.ConvTranspose2d(latent_size, 64, kernel_size=4, stride=1, padding=0, bias=False)
        self.d_norm1 = nn.BatchNorm2d(64)

        self.d_conv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, bias=False)
        self.d_norm2 = nn.BatchNorm2d(32)

        self.d_conv3 = nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1, bias=False)
        self.d_norm3 = nn.BatchNorm2d(16)

        self.d_conv4 = nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1, bias=False)
        
        self.kl = 0
        
    def encoder(self, x):
        x = self.e_pool(F.relu(self.e_conv1(x)))
        x = self.e_pool(F.relu(self.e_conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.e_fc(x))
        return self.e_fc_mu(x), self.e_fc_log_var(x)
    

    def decoder(self, z):
        z = F.relu(self.d_norm1(self.d_conv1(z)))
        z = F.relu(self.d_norm2(self.d_conv2(z)))
        z = F.relu(self.d_norm3(self.d_conv3(z)))
        z = F.sigmoid(self.d_conv4(z))
        return z
    
    def forward(self, x):
        mu, log_var = self.encoder(x)
        
        # Reparameterize
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + (eps * std)
                
        # Add dimensions for conv
        z = z.view(z.shape[0], z.shape[1], 1, 1)

        return self.decoder(z), mu, log_var

In [81]:

def train(net, train_data, batch_size=10, learning_rate=0.0001, epochs=10):
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    criterion = nn.BCELoss(reduction='sum')
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    running_loss = []

    for epoch in range(epochs):
        print("Epoch:" , epoch+1)

        for bi, data in tqdm(enumerate(train_loader), total=int(len(train_data)/train_loader.batch_size)):
            inputs, _ = data
            inputs = inputs.to(device)
            
            # Forward
            outputs, mu, log_var = net(inputs)
            
            # Backward
            kld = -0.5 * (1 + log_var - mu **2 - log_var.exp()).sum() # Kullback–Leibler divergence
            loss = criterion(outputs, inputs) + kld
            loss.backward()
            optimizer.step()
            
            running_loss.append(loss.item())
        
        latent_inputs = torch.randn(64, 128, 1, 1, device=device)
        generated_img = net.decoder(latent_inputs)
        generated_img = make_grid(generated_img)

        #SAVE IMAGE
        im = Image.fromarray(np.array(to_pil_image(generated_img)))
        im.save(f"Images_output/vae_cnn/epoch_{epoch}.jpeg")

        print(f'Loss: {np.mean(running_loss[-len(train_data):])}')

    return loss




In [82]:
net = VAE()
net.to(device)

VAE(
  (e_conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (e_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (e_conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (e_fc): Linear(in_features=400, out_features=300, bias=True)
  (e_fc_mu): Linear(in_features=300, out_features=128, bias=True)
  (e_fc_log_var): Linear(in_features=300, out_features=128, bias=True)
  (d_conv1): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (d_norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (d_conv2): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (d_norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (d_conv3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (d_norm3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (d_conv4): ConvTran

In [83]:
loss = train(net, train_data, epochs=20)

  0%|          | 3/6000 [00:00<03:24, 29.29it/s]

Epoch: 1


100%|██████████| 6000/6000 [01:10<00:00, 84.92it/s]
  0%|          | 9/6000 [00:00<01:11, 84.09it/s]

Loss: 2806.670363688151
Epoch: 2


100%|██████████| 6000/6000 [01:10<00:00, 85.68it/s]
  0%|          | 9/6000 [00:00<01:11, 83.44it/s]

Loss: 2668.88618862915
Epoch: 3


100%|██████████| 6000/6000 [01:13<00:00, 81.88it/s]
  0%|          | 9/6000 [00:00<01:11, 83.52it/s]

Loss: 2613.44677210829
Epoch: 4


100%|██████████| 6000/6000 [01:11<00:00, 84.12it/s]
  0%|          | 9/6000 [00:00<01:12, 82.91it/s]

Loss: 2578.476755101522
Epoch: 5


100%|██████████| 6000/6000 [01:11<00:00, 83.89it/s]
  0%|          | 9/6000 [00:00<01:13, 81.08it/s]

Loss: 2553.5606520996093
Epoch: 6


100%|██████████| 6000/6000 [01:11<00:00, 84.40it/s]
  0%|          | 9/6000 [00:00<01:11, 84.38it/s]

Loss: 2536.1216024780274
Epoch: 7


100%|██████████| 6000/6000 [01:13<00:00, 81.84it/s]
  0%|          | 8/6000 [00:00<01:14, 79.94it/s]

Loss: 2523.83304956636
Epoch: 8


100%|██████████| 6000/6000 [01:12<00:00, 82.32it/s]
  0%|          | 9/6000 [00:00<01:12, 82.99it/s]

Loss: 2513.952667681376
Epoch: 9


100%|██████████| 6000/6000 [01:10<00:00, 84.79it/s]
  0%|          | 9/6000 [00:00<01:10, 84.57it/s]

Loss: 2505.58878833686
Epoch: 10


100%|██████████| 6000/6000 [01:12<00:00, 82.30it/s]
  0%|          | 8/6000 [00:00<01:17, 76.93it/s]

Loss: 2499.1515712565106
Epoch: 11


100%|██████████| 6000/6000 [01:13<00:00, 82.08it/s]
  0%|          | 9/6000 [00:00<01:11, 83.60it/s]

Loss: 2463.090596417236
Epoch: 12


100%|██████████| 6000/6000 [01:11<00:00, 84.14it/s]
  0%|          | 7/6000 [00:00<01:29, 67.02it/s]

Loss: 2454.7701266092936
Epoch: 13


100%|██████████| 6000/6000 [01:12<00:00, 82.75it/s]
  0%|          | 8/6000 [00:00<01:15, 79.29it/s]

Loss: 2450.0829407084148
Epoch: 14


100%|██████████| 6000/6000 [01:11<00:00, 83.79it/s]
  0%|          | 9/6000 [00:00<01:12, 82.20it/s]

Loss: 2449.2976134195965
Epoch: 15


100%|██████████| 6000/6000 [01:11<00:00, 84.15it/s]
  0%|          | 9/6000 [00:00<01:11, 83.87it/s]

Loss: 2451.382392014567
Epoch: 16


100%|██████████| 6000/6000 [01:10<00:00, 85.01it/s]
  0%|          | 9/6000 [00:00<01:11, 84.18it/s]

Loss: 2454.6033083251955
Epoch: 17


100%|██████████| 6000/6000 [01:10<00:00, 85.30it/s]
  0%|          | 9/6000 [00:00<01:12, 82.41it/s]

Loss: 2460.323127829997
Epoch: 18


100%|██████████| 6000/6000 [01:11<00:00, 83.46it/s]
  0%|          | 8/6000 [00:00<01:22, 72.62it/s]

Loss: 2465.7297141866047
Epoch: 19


100%|██████████| 6000/6000 [01:10<00:00, 84.68it/s]
  0%|          | 9/6000 [00:00<01:11, 83.59it/s]

Loss: 2469.6023672241213
Epoch: 20


100%|██████████| 6000/6000 [01:10<00:00, 85.09it/s]

Loss: 2472.3979351908365



