## PyTorch CEA model

In [1]:
# general imports
import time
import types
import os
from shutil import copyfile, copy

# ML stuff
import numpy as np
from matplotlib.image import imread

# PyTorch stuff
import torch
import torch.nn as nn
from torch.autograd import Variable

import torchvision
from torchvision import datasets
from torchsummary import summary

# custom utils
from logger import TBLogger
from extract_patches import *
from pytorch_utils import *

%load_ext autoreload
%autoreload 2

In [2]:
import warnings
warnings.filterwarnings('ignore')

In [3]:
current_notebook_path = 'ConvAAE.ipynb'
debug_vals = types.SimpleNamespace()

In [4]:
device = get_current_device()
device

device(type='cuda')

In [5]:
TINY = 1e-15

patch_size = 64 # image size = 64 x 64 = 4096
batch_size = 256
emb_size = 3200
conv_emb_size = 50 * 8 * 8

## Model

In [6]:
# just a namespace
class ConvAAE():
    class Encoder(nn.Module):
        def __init__(self, image_channels=1):
            super(ConvAAE.Encoder, self).__init__()
            
            self.encoder = nn.Sequential(
                nn.Conv2d(image_channels, 16, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(16),
                nn.LeakyReLU(),
                
                nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(16),
                nn.LeakyReLU(),
                
                nn.MaxPool2d(2, stride=2),
            
                nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(),
                
                nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(),
                
                nn.MaxPool2d(2, stride=2),
                
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(),
                
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(),
                
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(),
                
                nn.MaxPool2d(2, stride=2),
                
                nn.Conv2d(64, 50, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(50),
                nn.LeakyReLU(),
                
                nn.Conv2d(50, 50, kernel_size=3, stride=1, padding=1),
#                 nn.BatchNorm2d(50),
#                 nn.LeakyReLU(),
                
#                 nn.Conv2d(50, 50, kernel_size=3, stride=1, padding=1),
#                 nn.BatchNorm2d(50),
#                 nn.LeakyReLU(),
                
#                 nn.MaxPool2d(2, stride=2),
            
                Flatten()
            )

        def forward(self, input_data):
            # generate embeddings for our images
            embeddings = self.encoder(input_data)

            return embeddings

        
    class Decoder(nn.Module):
        def __init__(self, image_channels=1):
            super(ConvAAE.Decoder, self).__init__()
            
            self.decoder = nn.Sequential(
                Unflatten(C=50, H=8, W=8),

                # upsampling + conv
                nn.ConvTranspose2d(50, 64, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(),
                
                nn.Conv2d(64, 64, 3, 1, 1),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(),
                
                nn.Conv2d(64, 64, 3, 1, 1),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(),

                # upsampling + conv
                nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(),
                
                nn.Conv2d(32, 32, 3, 1, 1),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(),
                
                nn.Conv2d(32, 32, 3, 1, 1),
                nn.BatchNorm2d(32),
                nn.LeakyReLU(),

                # upsampling + conv
                nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(16),
                nn.LeakyReLU(),
                
                nn.Conv2d(16, 16, 3, 1, 1),
                nn.BatchNorm2d(16),
                nn.LeakyReLU(),
                
#                 # upsampling + conv
#                 nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1),
#                 nn.BatchNorm2d(16),
#                 nn.LeakyReLU(),
                
#                 nn.Conv2d(16, 16, 3, 1, 1),
#                 nn.BatchNorm2d(16),
#                 nn.LeakyReLU(),

                nn.Conv2d(16, image_channels, 3, 1, 1),
            )


        def forward(self, input_embeddings):
            data = self.decoder(input_embeddings)
            
            return data

    class Discriminator(nn.Module):
        def __init__(self, image_channels=1):
            super(ConvAAE.Discriminator, self).__init__()

            self.discriminator = nn.Sequential(
                nn.Linear(emb_size, 30),
                nn.LeakyReLU(),
                
                nn.Linear(30, 10),
                nn.LeakyReLU(),

                # linear layer with 1 element
                nn.Linear(10, 1),

                nn.Sigmoid()
            )

        def forward(self, input_embeddings):
            # process embeddings and guess if it's real or fake
            discriminator_output = self.discriminator(input_embeddings)

            return discriminator_output

print('    ENCODER')
summary(ConvAAE.Encoder().to(device), (1, patch_size, patch_size))

print('\n\n    DECODER')
summary(ConvAAE.Decoder().to(device), (1, emb_size))

print('\n\n    DISCRIMINATOR')
summary(ConvAAE.Discriminator().to(device), (1, emb_size))

    ENCODER
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 64, 64]             160
       BatchNorm2d-2           [-1, 16, 64, 64]              32
         LeakyReLU-3           [-1, 16, 64, 64]               0
            Conv2d-4           [-1, 16, 64, 64]           2,320
       BatchNorm2d-5           [-1, 16, 64, 64]              32
         LeakyReLU-6           [-1, 16, 64, 64]               0
         MaxPool2d-7           [-1, 16, 32, 32]               0
            Conv2d-8           [-1, 32, 32, 32]           4,640
       BatchNorm2d-9           [-1, 32, 32, 32]              64
        LeakyReLU-10           [-1, 32, 32, 32]               0
           Conv2d-11           [-1, 32, 32, 32]           9,248
      BatchNorm2d-12           [-1, 32, 32, 32]              64
        LeakyReLU-13           [-1, 32, 32, 32]               0
        MaxPool2d-14       

In [7]:
learning_rate = 1e-5


def reset_model():
    global get_reconstruction_loss, encoder_net, decoder_net, discrim_net
    global decoder_optimizer, encoder_optimizer, encoder_generator_optimizer, discriminator_optimizer

    get_reconstruction_loss = nn.MSELoss()

    # creating sub-models
    encoder_net = ConvAAE.Encoder().to(device)
    decoder_net = ConvAAE.Decoder().to(device)
    discrim_net = ConvAAE.Discriminator().to(device)

    # optimizes decoder
    decoder_optimizer = torch.optim.Adam(decoder_net.parameters(), weight_decay=0, lr=learning_rate)

    # optimizes encoder/generator
    encoder_optimizer = torch.optim.Adam(encoder_net.parameters(), weight_decay=0, lr=learning_rate)

    # optimizes encoder/generator
    encoder_generator_optimizer = torch.optim.Adam(encoder_net.parameters(), 
        weight_decay=0, lr=learning_rate / 100)

    # optimizes discriminator
    discriminator_optimizer = torch.optim.Adam(discrim_net.parameters(), 
        weight_decay=0, lr=learning_rate / 50)


reset_model()

### Data Loader

In [None]:
from os import listdir
from os.path import isfile, join

In [9]:
data_path = './data'
files_per_epoch = 5

# get all red files
red_img_files = [f for f in listdir(data_path)
                 if isfile(join(data_path, f)) and 'RED' in f]

# random shuffle
np.random.seed(42)

In [10]:
len(red_img_files)

58

In [11]:
from sklearn.utils import shuffle

def get_new_epoch_patches():
    np.random.shuffle(red_img_files)
    # list of .IMG files names
    epoch_files = red_img_files[:files_per_epoch]

    # extract patches and concatenate all of them into one list
    images, ids = extract_patches_from_img(epoch_files, patch_size=patch_size)
    
    images, ids = shuffle(images, ids)
    
    # numpy to tensor
    tensor_images = numpy_images_to_tensor_dataset(images)
    
    tensor_images = tensor_images / 255
    
    # create loader for PyTorch
    dataset = torch.utils.data.TensorDataset(tensor_images)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    
    return loader

### Training procedure

In [12]:
def train_model(log_name, print_interval=20):
    global log_images
    
    # images used for logging and testing
    log_images = None
    
    log_dir = './logs/' + log_name
    
    # make sure we are not training in existing directory
    if os.path.isdir(log_dir):
        raise ValueError('Please, use a new logging directory for training')
        
    if not os.path.exists(log_dir + '/models'):
        os.makedirs(log_dir + '/models')

    # make a new logger for new directory
    tb_logger = TBLogger(log_dir)
    
    # make a copy so you we can recall what/how exactly we trained this time
    copy(current_notebook_path, log_dir)
    
    # now we can start the training
    encoder_net.train(), decoder_net.train() #, discrim_net.train()
    
    for epoch in range(num_epochs):
        reconstruction_losses = []
        generator_losses = []
        discriminator_losses = []
        
        # get random group of images
        # split them into patches, combine to one list
        # and convert this list into PyTorch dataset
        loader = get_new_epoch_patches()
        
        # loop over mini-batches sampled from patches in current dataset
        for i, (data,) in enumerate(loader):
            
            # store some images to track the training progress
            if log_images is None:
                log_images = data.clone().to(device)
            
            images = data.to(device)
            
            ### FORWARD PASS THROUGH TWO OUTPUTS
            
            ## Autoencoder part
            # run through encoder net and get embeddings
            embeddings = encoder_net(images)

            # run embeddings through dencoder net and get reconstruction
            reconstructed_images = decoder_net(embeddings)

            reconstruction_loss = get_reconstruction_loss(images, reconstructed_images)
            
            ## Generator + Discriminator part
            
            # they are fake, but it's "normal" :)
            batch_size = images.size()[0]
            fake_embeddings = Variable(torch.randn(batch_size, emb_size) * 5.).to(device)

            real_embeddings_discrimination = discrim_net(embeddings)

            # discriminator loss
            discriminator_loss = -torch.mean(torch.log(discrim_net(fake_embeddings) + TINY) + 
                                             torch.log(1 - real_embeddings_discrimination + TINY))

            # we are trying to maximize generator's loss on real embeddings
            generator_loss = -torch.mean(torch.log(real_embeddings_discrimination + TINY))

            ### BACKWARD PASS

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            
            # update gradients of encoder + decoder
            reconstruction_loss.backward(retain_graph=True)

            # update encoder and decoder models
            encoder_optimizer.step()
            decoder_optimizer.step()
            
            # update gradients of encoder as generator
            
            encoder_generator_optimizer.zero_grad()
            generator_loss.backward(retain_graph=True)
            encoder_generator_optimizer.step()

            # update gradients of discriminator only
            # reset negative generator gradients
            discriminator_optimizer.zero_grad()
            discriminator_loss.backward()
            discriminator_optimizer.step()

            reconstruction_losses.append(reconstruction_loss.item())
            generator_losses.append(generator_loss.item())
            discriminator_losses.append(discriminator_loss.item())
            
        
        # get mean averaged per mini-batch
        mean_r_loss = np.mean(reconstruction_losses)
        mean_g_loss = np.mean(generator_losses)
        mean_d_loss = np.mean(discriminator_losses)
        
        # log to file on every epoch
        tb_logger.add_loss(mean_r_loss)
        tb_logger.log_scalar('generator/loss', mean_g_loss)
        tb_logger.log_scalar('discriminator/loss', mean_d_loss)
        
        # calculate and log val loss
        reconstructed_images = decoder_net(encoder_net(log_images))
        reconstruction_val_loss = get_reconstruction_loss(log_images, reconstructed_images)
        tb_logger.log_scalar('validation/loss', reconstruction_val_loss.item())
        
        if epoch == 0:
            # print original raw image
            save_img_grid(f'{log_dir}/_src_image.jpg', encoder_net, decoder_net, log_images, print_input=True)
        
        # log to console from time to time
        if epoch % print_interval == 0:
            print(time.strftime('%X') + ' - epoch [{}/{}], r loss: {:.8f}, g loss: {:.4f}, d loss: {:.4f}'
                  .format(epoch+1, num_epochs, mean_r_loss, mean_g_loss, mean_d_loss))
            
            # save the model for myself and model state for sharing with others
            torch.save(encoder_net.state_dict(), f'{log_dir}/models/encoder_net_e_{epoch}.pth')
            torch.save(decoder_net.state_dict(), f'{log_dir}/models/decoder_net_e_{epoch}.pth')
            torch.save(discrim_net.state_dict(), f'{log_dir}/models/discrim_net_e_{epoch}.pth')
            
            save_img_grid(f'{log_dir}/e_{epoch}_image.jpg', encoder_net, decoder_net, log_images)
            
            # save the histogram. time consuming operation
            tb_logger.log_histogram('histograms/encoder', embeddings.clone().detach().cpu())
    
    # saving the final model
    torch.save(encoder_net.state_dict(), f'{log_dir}/models/encoder_net_e_{epoch}.pth')
    torch.save(decoder_net.state_dict(), f'{log_dir}/models/decoder_net_e_{epoch}.pth')
    torch.save(discrim_net.state_dict(), f'{log_dir}/models/discrim_net_e_{epoch}.pth')
    
    print('DONE')

def save_test_img(fname, encoder_net, decoder_net, tensor_images, print_input=False):
    test_data = tensor_images[:1]

    # forward
    if print_input:
        np_output = test_data.data.cpu().numpy()
    else:
        test_output = encoder_net(test_data)
        test_output = decoder_net(test_output)
        np_output = test_output.data.cpu().numpy()
    
    plt.imsave(fname, np.squeeze(np_output), vmin=test_data.min(), vmax=test_data.max())
    
def save_img_grid(fname, encoder_net, decoder_net, tensor_images, print_input=False):
    grid_size = 8
    
    test_data = tensor_images[:grid_size ** 2]
    
    # forward
    if print_input:
        np_output = test_data.data.cpu().numpy()
    else:
        test_output = encoder_net(test_data)
        test_output = decoder_net(test_output)
        np_output = test_output.data.cpu().numpy()
    
    fig, img_plots = plt.subplots(grid_size, grid_size, figsize=(8, 8), gridspec_kw = {'wspace':0.01, 'hspace':0.01})

    fig.patch.set_facecolor('black')

    for i in range(0, grid_size):
        for j in range(0, grid_size):
            img_plt = img_plots[i, j]

            img = np_output[i * grid_size + j]
            img_plt.imshow(np.squeeze(img), vmin=test_data.min(), vmax=test_data.max())
            img_plt.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    
    fig.savefig(fname)
    plt.close(fig)

In [13]:
# log_images.min(), log_images.max(), log_images.mean()

## Training
### Please, use new logging directory name  before each training!

### Running for the first time

In [None]:
reset_model()

num_epochs = 1500

# specify a new name for log directory!
train_model('added adversarial component ' + time.strftime('%H-%M-%S'), print_interval=5)

20:22:15 - epoch [1/1500], r loss: 0.32265964, g loss: 0.5614, d loss: 1.3526
20:36:32 - epoch [6/1500], r loss: 0.07376406, g loss: 0.5760, d loss: 1.3288
20:57:18 - epoch [11/1500], r loss: 0.00337090, g loss: 0.5896, d loss: 1.3053
21:11:51 - epoch [16/1500], r loss: 0.00112747, g loss: 0.5979, d loss: 1.2902
21:29:17 - epoch [21/1500], r loss: 0.00072740, g loss: 0.6107, d loss: 1.2701
21:45:29 - epoch [26/1500], r loss: 0.00058217, g loss: 0.6191, d loss: 1.2557
22:03:44 - epoch [31/1500], r loss: 0.00056214, g loss: 0.6376, d loss: 1.2289
22:20:39 - epoch [36/1500], r loss: 0.00058928, g loss: 0.6389, d loss: 1.2226
22:39:31 - epoch [41/1500], r loss: 0.00015694, g loss: 0.6457, d loss: 1.2102
22:56:12 - epoch [46/1500], r loss: 0.00010201, g loss: 0.6555, d loss: 1.1965
23:12:15 - epoch [51/1500], r loss: 0.00016270, g loss: 0.6685, d loss: 1.1778
23:31:12 - epoch [56/1500], r loss: 0.00038326, g loss: 0.6752, d loss: 1.1671
23:49:50 - epoch [61/1500], r loss: 0.00031130, g loss

### Training same model

In [31]:
num_epochs = 1500

# specify a new name for log directory!
train_model('added adversarial component ' + time.strftime('%H-%M-%S'), print_interval=5)

14:08:47 - epoch [1/1500], r loss: 0.00005139, g loss: 0.6922, d loss: 1.4305


KeyboardInterrupt: 

### Running from scratch

In [102]:
reset_model()

num_epochs = 2000

train_model('cae_v5_lrelu_17')

01:07:00 - epoch [1/2000], loss:7170.91958926
01:08:04 - epoch [6/2000], loss:88.76065694
01:09:07 - epoch [11/2000], loss:35.16238949
01:10:11 - epoch [16/2000], loss:9.74772633
01:11:14 - epoch [21/2000], loss:19.80191330
01:12:18 - epoch [26/2000], loss:5.94075181
01:13:20 - epoch [31/2000], loss:17.81406898
01:14:24 - epoch [36/2000], loss:5.08296149
01:15:27 - epoch [41/2000], loss:17.50005694
01:16:31 - epoch [46/2000], loss:4.69269189
01:17:34 - epoch [51/2000], loss:17.45890252
01:18:38 - epoch [56/2000], loss:4.49350875
01:19:40 - epoch [61/2000], loss:16.39188738
01:20:44 - epoch [66/2000], loss:4.42343875
01:21:46 - epoch [71/2000], loss:16.17326309
01:22:50 - epoch [76/2000], loss:4.28442504
01:23:52 - epoch [81/2000], loss:15.79591764
01:24:56 - epoch [86/2000], loss:4.19751638
01:25:59 - epoch [91/2000], loss:15.42403963
01:27:03 - epoch [96/2000], loss:4.28876181
01:28:05 - epoch [101/2000], loss:15.70192102
01:29:09 - epoch [106/2000], loss:4.16870721
01:30:12 - epoch [