In [1]:
## VAE digits

import torch.nn as nn

import numpy as np
import json
import os
import pickle

import torch
import torch.optim as optim

import collections

from torchsummary import summary as torch_network_summary

import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.stats import norm

from tqdm import tqdm_notebook as tqdm

# Prefer the GPU when CUDA is available; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Unzip Image Files

In [3]:
!pip install py7zlib

[31mERROR: Could not find a version that satisfies the requirement py7zlib (from versions: none)[0m
[31mERROR: No matching distribution found for py7zlib[0m


In [2]:
import subprocess
import py7zlib 

ModuleNotFoundError: No module named 'py7zlib'

In [None]:


# Extract every file of a .7z archive into a single destination folder by
# shelling out to the 7-Zip command-line tool.
archiveman = r'c:\Program Files\7-zip\7z' # 7z.exe comes with 7-zip
archivepath = r'C:\Path\to\archive.7z'
destinationpath = r'C:\Destination\of\copy'

# "7z e" with no file arguments extracts all members, so one process handles the
# whole archive. This avoids listing members through py7zlib (which is not
# importable in this environment) and launching a separate 7z process per file.
# check_output raises CalledProcessError if 7z exits with a non-zero status.
_ = subprocess.check_output([archiveman, 'e', archivepath, '-o{}'.format(destinationpath)])

In [None]:


class Encoder_part(nn.Module):
    """Convolutional encoder that maps an image batch to a ``z_dim``-sized code.

    The network is a stack of ``Conv2d -> [BatchNorm2d] -> LeakyReLU -> [Dropout2d]``
    layers followed by linear heads applied to the flattened feature map.

    * ``use_VAE=False``: the output is ``Flattened(features)``.
    * ``use_VAE=True``: ``mu`` and ``log_var`` heads are evaluated, the output is the
      reparameterised sample ``mu + exp(log_var / 2) * eps`` with ``eps ~ N(0, I)``
      drawn independently for every sample in the batch, and the KL term (summed
      over batch and latent dimensions) is stored on ``self.kl_loss``.

    Args:
        input_dim: ``(channels, height, width)`` of one input sample.
        encoder_conv_filters: output channel count of each conv layer.
        encoder_conv_kernel_size: kernel size of each conv layer.
        encoder_conv_strides: stride of each conv layer.
        z_dim: size of the latent code.
        use_batch_norm: insert ``BatchNorm2d`` after every convolution.
        use_dropout: insert ``Dropout2d(p=0.5)`` after every activation.
        use_VAE: enable the variational heads described above.

    Attributes set for callers:
        output_shape: ordered mapping of layer name -> ``[channels, H, W]``.
        encoder_cnn_output_shape: ``[channels, H, W]`` of the last conv layer.
        Flattened_shape: number of features after flattening (plain ``int``).
        kl_loss: KL term of the most recent forward pass (VAE mode only).
    """

    def __init__(self,
                 input_dim,
                 encoder_conv_filters,
                 encoder_conv_kernel_size,
                 encoder_conv_strides,
                 z_dim,
                 use_batch_norm=False,
                 use_dropout=False,
                 use_VAE=False):
        super(Encoder_part, self).__init__()

        self.name = 'variational_autoencoder'

        self.input_dim = input_dim
        self.input_filter = self.input_dim[0]
        # Layer i consumes the channels produced by layer i-1 (the image channels for i == 0).
        self.encoder_conv_filters_input = [self.input_filter] + list(encoder_conv_filters[:-1])
        self.encoder_conv_filters_output = encoder_conv_filters
        self.encoder_conv_kernel_size = encoder_conv_kernel_size
        self.encoder_conv_strides = encoder_conv_strides
        self.z_dim = z_dim

        self.use_VAE = use_VAE

        self.use_batch_norm = use_batch_norm
        self.use_dropout = use_dropout

        self.n_layers_encoder = len(encoder_conv_filters)

        self.conv_layers = nn.Sequential()
        self.output_shape = collections.OrderedDict()

        self.output_shape['Input'] = self.input_dim
        current_layer_name = 'Input'
        current_input_shape = self.output_shape[current_layer_name]

        for i in range(self.n_layers_encoder):

            _, input_H, input_W = current_input_shape
            current_layer_name = 'Layer {} Conv2d'.format(i)

            kernel = self.encoder_conv_kernel_size[i]
            stride = self.encoder_conv_strides[i]

            # Height the layer would produce without padding.
            unpadded_H = int((input_H - kernel) / stride) + 1

            # Pad only when the unpadded height differs from input_H / stride.
            # The padding is derived from the height alone and applied to both
            # spatial dimensions.
            if int(input_H / stride) != unpadded_H:
                padding = max(0, int((kernel - (unpadded_H % stride)) / 2))
            else:
                padding = 0

            output_H = int((input_H - kernel + 2 * padding) / stride) + 1
            output_W = int((input_W - kernel + 2 * padding) / stride) + 1

            self.output_shape[current_layer_name] = [self.encoder_conv_filters_output[i], output_H, output_W]
            current_input_shape = self.output_shape[current_layer_name]

            self.conv_layers.add_module(current_layer_name,
                                        nn.Conv2d(self.encoder_conv_filters_input[i],
                                                  self.encoder_conv_filters_output[i],
                                                  kernel_size=kernel,
                                                  stride=stride,
                                                  padding=padding))
            if self.use_batch_norm:
                self.conv_layers.add_module('Layer {} BatchNorm2d'.format(i),
                                            nn.BatchNorm2d(self.encoder_conv_filters_output[i]))

            self.conv_layers.add_module('Layer {} LeakyReLU'.format(i), nn.LeakyReLU())

            if self.use_dropout:
                self.conv_layers.add_module('Layer {} Dropout2d'.format(i), nn.Dropout2d(p=0.5))

        # Stored as a list so it can be concatenated with a batch dimension downstream.
        self.encoder_cnn_output_shape = list(current_input_shape)
        # np.prod returns a NumPy integer; nn.Linear is given a plain Python int.
        self.Flattened_shape = int(np.prod(self.encoder_cnn_output_shape))
        # Always created so the module's parameter set does not depend on use_VAE.
        self.Flattened = nn.Linear(self.Flattened_shape, self.z_dim)

        if self.use_VAE:
            self.mu = nn.Linear(self.Flattened_shape, self.z_dim)
            self.log_var = nn.Linear(self.Flattened_shape, self.z_dim)
            self.kl_loss = None

    def forward(self, x):
        """Encode ``x`` and return a ``(batch, z_dim)`` tensor.

        In VAE mode this also updates ``self.kl_loss`` as a side effect.
        """
        x = self.conv_layers(x)
        x = x.reshape(x.size(0), -1)

        if not self.use_VAE:
            return self.Flattened(x)

        mu = self.mu(x)
        log_var = self.log_var(x)

        # randn_like gives one independent noise vector per sample and inherits
        # mu's device and dtype, so no global `device` is required here.
        epsilon = torch.randn_like(mu)
        encoder_output = mu + torch.exp(log_var / 2) * epsilon

        # KL(N(mu, sigma^2) || N(0, I)), summed over the batch and latent dimensions.
        self.kl_loss = -0.5 * (1 + log_var - mu ** 2 - torch.exp(log_var)).sum()

        return encoder_output
        

class Decoder_part(nn.Module):
    """Transposed-convolution decoder that maps a latent code back to an image.

    A linear layer expands the ``z_dim`` code to the encoder's final feature-map
    size. The result is reshaped to ``(batch, C, H, W)`` and passed through a stack
    of ``ConvTranspose2d`` layers. Every layer except the last is followed by
    ``[BatchNorm2d] -> LeakyReLU -> [Dropout2d]``; the last is followed by ``Sigmoid``.

    Args:
        encoder_cnn_output_shape: ``[channels, H, W]`` (list or tuple) of the
            feature map the decoder starts from.
        decoder_conv_t_filters: output channel count of each transposed conv.
        decoder_conv_t_kernel_size: kernel size of each transposed conv.
        decoder_conv_t_strides: stride of each transposed conv.
        z_dim: size of the latent code.
        use_batch_norm: insert ``BatchNorm2d`` after every non-final layer.
        use_dropout: insert ``Dropout2d(p=0.5)`` after every non-final activation.

    Attributes set for callers:
        output_shape: ordered mapping of layer name -> ``[channels, H, W]``.
    """

    def __init__(self,
                 encoder_cnn_output_shape,
                 decoder_conv_t_filters,
                 decoder_conv_t_kernel_size,
                 decoder_conv_t_strides,
                 z_dim,
                 use_batch_norm=False,
                 use_dropout=False):
        super(Decoder_part, self).__init__()

        self.encoder_cnn_output_shape = encoder_cnn_output_shape
        # Layer i consumes the channels produced by layer i-1 (the encoder channels for i == 0).
        self.decoder_conv_t_filters_input = [encoder_cnn_output_shape[0]] + list(decoder_conv_t_filters[:-1])
        self.decoder_conv_t_filters_output = decoder_conv_t_filters
        self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size
        self.decoder_conv_t_strides = decoder_conv_t_strides
        self.z_dim = z_dim

        self.use_batch_norm = use_batch_norm
        self.use_dropout = use_dropout

        self.n_layers_decoder = len(decoder_conv_t_filters)

        self.conv_layers = nn.Sequential()
        self.output_shape = collections.OrderedDict()

        self.output_shape['Input'] = self.encoder_cnn_output_shape
        current_layer_name = 'Input'
        current_input_shape = self.output_shape[current_layer_name]

        # np.prod returns a NumPy integer; nn.Linear is given a plain Python int.
        self.Reshaping = nn.Linear(self.z_dim, int(np.prod(self.encoder_cnn_output_shape)))

        for i in range(self.n_layers_decoder):

            _, input_H, input_W = current_input_shape
            # The module name says "Conv2d" although the layer is a ConvTranspose2d;
            # it is kept because it is part of the state_dict keys.
            current_layer_name = 'Layer {} Conv2d'.format(i)

            kernel = self.decoder_conv_t_kernel_size[i]
            stride = self.decoder_conv_t_strides[i]

            # Height the layer would produce without any padding.
            unpadded_H = (input_H - 1) * stride + kernel

            # Choose padding / output_padding so the output height becomes
            # input_H * stride. The values are derived from the height alone and
            # applied to both spatial dimensions.
            if int(input_H * stride) != unpadded_H:
                padding = (unpadded_H - input_H * stride) / 2
                if padding % 1 == 0.5:
                    # Odd surplus: round the padding up and give one row/column back.
                    padding = max(0, int(padding)) + 1
                    output_padding = 1
                else:
                    padding = int(padding)
                    output_padding = 0
            else:
                padding = 0
                output_padding = 0

            output_H = (input_H - 1) * stride + kernel - 2 * padding + output_padding
            output_W = (input_W - 1) * stride + kernel - 2 * padding + output_padding

            self.output_shape[current_layer_name] = [self.decoder_conv_t_filters_output[i], output_H, output_W]
            current_input_shape = self.output_shape[current_layer_name]

            self.conv_layers.add_module(current_layer_name,
                                        nn.ConvTranspose2d(self.decoder_conv_t_filters_input[i],
                                                           self.decoder_conv_t_filters_output[i],
                                                           kernel_size=kernel,
                                                           stride=stride,
                                                           padding=padding,
                                                           output_padding=output_padding))
            if i < self.n_layers_decoder - 1:
                if self.use_batch_norm:
                    self.conv_layers.add_module('Layer {} BatchNorm2d'.format(i),
                                                nn.BatchNorm2d(self.decoder_conv_t_filters_output[i]))

                self.conv_layers.add_module('Layer {} LeakyReLU'.format(i), nn.LeakyReLU())

                if self.use_dropout:
                    self.conv_layers.add_module('Layer {} Dropout2d'.format(i), nn.Dropout2d(p=0.5))
            else:
                # The final layer squashes its output into [0, 1].
                self.conv_layers.add_module('Layer {} Sigmoid'.format(i), nn.Sigmoid())

    def forward(self, x):
        """Decode a ``(batch, z_dim)`` tensor into the output of the last layer."""
        x = self.Reshaping(x)
        # Unpacking works for both list and tuple shapes; the previous
        # list concatenation raised TypeError when given a tuple.
        x = x.reshape(x.size(0), *self.encoder_cnn_output_shape)
        x = self.conv_layers(x)

        return x

class Autoenocoder_pt(nn.Module):
    """Chains an encoder module and a decoder module into one autoencoder.

    Args:
        encoder: module mapping an input batch to a latent code.
        decoder: module mapping a latent code back to input space.

    The two modules are exposed as ``encoder_part`` and ``decoder_part``.
    """

    def __init__(self,
                 encoder, decoder):
        super(Autoenocoder_pt, self).__init__()
        # Use the constructor arguments. Previously the notebook-level globals
        # `encoder_part` / `decoder_part` were read, so the arguments were
        # ignored and the class failed if those globals did not exist.
        self.encoder_part = encoder
        self.decoder_part = decoder

    def forward(self, x):
        """Return ``decoder_part(encoder_part(x))``."""
        x = self.encoder_part(x)
        x = self.decoder_part(x)

        return x

## Training

In [None]:
# Input geometry: one channel, 28 x 28 pixels.
channels, H, W = 1, 28, 28

# A random single-image batch; its per-sample shape (channels, H, W) defines input_dim.
sample_input = torch.rand(1, channels, H, W).to(device)
input_dim = tuple(sample_input.shape[1:])

# Encoder architecture: one entry per convolution layer.
encoder_conv_filters = [32, 64, 64, 64]
encoder_conv_kernel_size = [3, 3, 3, 3]
encoder_conv_strides = [1, 2, 2, 1]

# Decoder architecture: one entry per transposed-convolution layer.
decoder_conv_t_filters = [64, 64, 32, 1]
decoder_conv_t_kernel_size = [3, 3, 3, 3]
decoder_conv_t_strides = [1, 2, 2, 1]

# Size of the latent code.
z_dim = 2

# False -> plain autoencoder; True -> variational heads in the encoder.
use_VAE_mode = False

encoder_part = Encoder_part(
    input_dim=input_dim,
    encoder_conv_filters=encoder_conv_filters,
    encoder_conv_kernel_size=encoder_conv_kernel_size,
    encoder_conv_strides=encoder_conv_strides,
    z_dim=z_dim,
    # use_batch_norm=True,
    # use_dropout=True,
    use_VAE=use_VAE_mode,
).to(device)

# The decoder starts from the shape of the encoder's last convolutional feature map.
decoder_part = Decoder_part(
    encoder_cnn_output_shape=encoder_part.encoder_cnn_output_shape,
    decoder_conv_t_filters=decoder_conv_t_filters,
    decoder_conv_t_kernel_size=decoder_conv_t_kernel_size,
    decoder_conv_t_strides=decoder_conv_t_strides,
    z_dim=z_dim,
    # use_batch_norm=True,
    # use_dropout=True,
).to(device)

AE = Autoenocoder_pt(encoder_part, decoder_part)

In [None]:
# Print torchsummary's report for the autoencoder, given the per-sample input shape.
torch_network_summary(AE, input_size=encoder_part.input_dim)

In [None]:
from utils.loaders import load_mnist, load_model
from torch.utils.data import DataLoader

In [None]:
(_x_train, y_train), (_x_test, y_test) = load_mnist()

# Move axis 3 to position 1 so the Conv2d layers receive channels-first input
# (assumes load_mnist returns channels-last image arrays — TODO confirm).
_x_train_t, _x_test_t = (
    np.transpose(images, (0, 3, 1, 2)) for images in (_x_train, _x_test)
)

# Tensor copies of both splits on the compute device.
x_train, x_test = (
    torch.from_numpy(images).to(device) for images in (_x_train_t, _x_test_t)
)

# One [image, label] pair per training sample, the format handed to DataLoader.
input_data = [list(pair) for pair in zip(_x_train_t, y_train)]

# ---- Training configuration ----
LEARNING_RATE = 0.0005
R_LOSS_FACTOR = 1000          # weight of the reconstruction loss relative to the KL term (VAE mode)
BATCH_SIZE = 32
EPOCHS = 100
EARLY_STOPPING_PATIENCE = 20  # stop after this many consecutive epochs without improvement

optimizer = optim.Adam(AE.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss().to(device)

# Built once; with shuffle=True the loader still reshuffles at the start of every epoch.
train_loader = DataLoader(input_data, batch_size=BATCH_SIZE, shuffle=True)

AE.train()

best_loss = float('inf')
early_stopping_count = 0

# Per-epoch history; each entry is the sum of the per-batch losses of that epoch.
total_loss_list = []
total_r_loss_list = []
total_kl_loss_list = []

for epoch in range(EPOCHS):
    total_loss = 0
    total_r_loss = 0
    total_kl_loss = 0
    for x_batch, y_batch in tqdm(train_loader):

        # Labels are not used for reconstruction, so only the images are moved.
        x_batch = x_batch.to(device)

        optimizer.zero_grad()

        decoder_output = AE(x_batch)

        # Conventional (prediction, target) order; MSE is symmetric, so the value is unchanged.
        r_loss = criterion(decoder_output, x_batch)
        total_r_loss += r_loss.item()
        if AE.encoder_part.use_VAE:
            # The encoder stores the KL term of its latest forward pass.
            kl_loss = AE.encoder_part.kl_loss
            loss = R_LOSS_FACTOR * r_loss + kl_loss

            total_kl_loss += kl_loss.item()
        else:
            loss = r_loss

        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    # Early-stopping bookkeeping: reset the counter when the epoch loss improves
    # and count the epoch otherwise. Previously the counter was incremented on
    # improvement and never reset, so training stopped after 20 improvements.
    if total_loss < best_loss:
        best_loss = total_loss
        early_stopping_count = 0
    else:
        early_stopping_count += 1

    total_loss_list.append(total_loss)
    total_r_loss_list.append(total_r_loss)

    if AE.encoder_part.use_VAE:
        total_kl_loss_list.append(total_kl_loss)

        info2print = '{}  {}  {}  {}'.format(epoch, total_loss, total_r_loss, total_kl_loss)
    else:
        info2print = '{}  {}  {}'.format(epoch, total_loss, total_r_loss)

    print(info2print)
    if early_stopping_count >= EARLY_STOPPING_PATIENCE:
        break