# AutoEncoder_images report 0
## based on AutoEncoder_images.py 
I develop in Spyder IDE, which is very convenient because of the workspace concept taken from Matlab - each object is kept in the workspace after execution if not deleted, and can be opened up or called. It also supports execution by cells (or any selected code). However, it does not save results, which is why I use Jupyter.

In [1]:
# -*- coding: utf-8 -*-
"""@author: oz.livneh@gmail.com

* All rights of this project and my code are reserved to me, Oz Livneh.
* Feel free to use - for personal use!
* Use at your own risk ;-)
"""
# to enable interactive plotting in Jupyter notebook:
%matplotlib notebook

In [2]:
#%% main parameters
#--------- general -------------
debugging=True # executes the debugging (short) sections, prints results
# debugging=False
torch_manual_seed=0 # integer or None for no seed; for torch reproducibility, as much as possible
#torch_manual_seed=None

#--------- data -------------
images_folder_path=r'D:\AI Data\DeepFake\ZioNLight Bibi.mp4 frames'

#random_transforms=True # soft data augmentation - color jitter (no random cropping or flipping - to keep all images aligned)
random_transforms=False
#max_dataset_length=100 # if positive: builds a dataset by sampling only max_dataset_length samples from all available data; requires user approval
max_dataset_length=0 # if non-positive: not restricting dataset length - using all available data
seed_for_dataset_downsampling=0 # integer or None for no seed; for sampling max_dataset_length samples from dataset

validation_ratio=0.3 # validation dataset ratio from total dataset length

#batch_size_int_or_ratio_float=1e-2 # if float: batch_size=round(batch_size_over_dataset_length*len(dataset_to_split))
batch_size_int_or_ratio_float=64 # if int: this is the batch size, should be 2**n
data_workers=0 # 0 means no multiprocessing in dataloaders
#data_workers='cpu cores' # sets data_workers=multiprocessing.cpu_count()

shuffle_dataset_indices_for_split=True # dataset indices for dataloaders are shuffled before splitting to train and validation indices
#shuffle_dataset_indices_for_split=False
dataset_shuffle_random_seed=0 # numpy seed for sampling the indices for the dataset, before splitting to train and val dataloaders
#dataset_shuffle_random_seed=None
dataloader_shuffle=True # samples are shuffled inside each dataloader, on each epoch
#dataloader_shuffle=False

#--------- net -------------
#net_architecture='simple auto-encoder'
#net_architecture='bigger auto-encoder'
net_architecture='v3 auto-encoder'

loss_name='MSE'

#--------- training -------------
train_model_else_load_weights=True
#train_model_else_load_weights=False # instead of training, loads a pre-trained model and uses it

epochs=30
learning_rate=1e-1
momentum=0.9

lr_scheduler_step_size=1
lr_scheduler_decay_factor=0.9

best_model_criterion='min val epoch MSE' # criterion for choosing best net weights during training as the final weights
return_to_best_weights_in_the_end=True # when training complets, loads weights of the best net, definied by best_model_criterion
#return_to_best_weights_in_the_end=False

training_progress_ratio_to_log_loss=0.25 # <=1, inter-epoch logging and reporting loss and metrics during training, period_in_batches_to_log_loss=round(training_progress_ratio_to_log_loss*dataset_samples_number['train']/batch_size)
#plot_realtime_stats_on_logging=True # incomplete implementation!
plot_realtime_stats_on_logging=False
#plot_realtime_stats_after_each_epoch=True
plot_realtime_stats_after_each_epoch=False
#plot_loss_in_log_scale=True
plot_loss_in_log_scale=False

#offer_mode_saving=True # offer model weights saving ui after training (only if train_model_else_load_weights=True)
offer_mode_saving=False
models_folder_path='D:\My Documents\Dropbox\Python\DatingAI\Data\Saved Models'

In [3]:
#%% initialization
import logging
logging.basicConfig(format='%(asctime)s %(funcName)s (%(levelname)s): %(message)s',
                   datefmt='%Y-%m-%d %H:%M:%S')
logger=logging.getLogger('data processing logger')
logger.setLevel(logging.INFO)

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math
import random
from time import time
import copy
import PIL
import multiprocessing
if data_workers=='cpu cores':
    data_workers=multiprocessing.cpu_count()

import torch
torch.manual_seed(torch_manual_seed)

import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils

def plot_from_image_filenames_list(sample_indices_to_plot,image_filenames_list,
                                       images_folder_path,images_per_row=4):
    assert (isinstance(images_per_row,int) and images_per_row>0), 'images_per_row is invalid, must be a positive integer'
    plt.figure()
    columns_num=math.ceil(len(sample_indices_to_plot)/images_per_row)
    for i,sample_index in enumerate(sample_indices_to_plot):
        image_filename=image_filenames_list[sample_index]
        image_array=plt.imread(os.path.join(images_folder_path,image_filename))    
        
        plt.subplot(columns_num,images_per_row,i+1)
        plt.imshow(image_array)
        plt.title(image_filename)
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
    plt.show()

class images_torch_dataset(torch.utils.data.Dataset):
    def __init__(self,image_filenames,images_folder_path,transform_func=None):
        self.image_filenames=image_filenames
        self.images_folder_path=images_folder_path
        self.transform_func=transform_func

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self,idx):
        image_filename=self.image_filenames[idx]
        image_path=os.path.join(self.images_folder_path,image_filename)
        image_array=PIL.Image.open(image_path)
        
        if self.transform_func!=None:
            image_array=self.transform_func(image_array)
        
        sample={'image filename':image_filename,'image array':image_array}
        return sample

def plot_from_torch_dataset(sample_indices_to_plot,torch_dataset,
                            images_per_row=4,image_format='PIL->torch'):
    assert (isinstance(images_per_row,int) and images_per_row>0), 'images_per_row is invalid, must be a positive integer'
    plt.figure()
    columns_num=math.ceil(len(sample_indices_to_plot)/images_per_row)
    for i,sample_index in enumerate(sample_indices_to_plot):
        sample=torch_dataset[sample_index]
        image_filename=sample['image filename']
        image_array=sample['image array']
        if image_format=='np->torch': # to return from a torch format that reached from a np format, to a np for plotting. see # Helper function to show a batch from https://pytorch.org/tutorials/beginner/data_loading_tutorial
            image_array=image_array.transpose((1,2,0))
        elif image_format=='PIL->torch': # to return from a torch format that reached from a PIL format, to a np for plotting. see # Helper function to show a batch from https://pytorch.org/tutorials/beginner/data_loading_tutorial
            image_array=image_array.numpy().transpose((1,2,0))
        
        plt.subplot(columns_num,images_per_row,i+1)
        plt.imshow(image_array)
        plt.title(image_filename)
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
    plt.show()

def training_stats_plot(stats_dict,fig,loss_subplot,MSE_subplot):
    running_stats_df=pd.DataFrame.from_dict(stats_dict['train']['running metrics'],orient='index')
    epoch_train_stats_df=pd.DataFrame.from_dict(stats_dict['train']['epoch metrics'],orient='index')
    epoch_val_stats_df=pd.DataFrame.from_dict(stats_dict['val']['epoch metrics'],orient='index')
    
    loss_subplot.clear() # clearing plot before plotting, to avoid over-plotting
    if len(running_stats_df)>0:
        loss_subplot.plot(running_stats_df['loss per sample'],'-x',label='running train')
    loss_subplot.plot(epoch_train_stats_df['loss per sample'],'k-o',label='epoch train')
    loss_subplot.plot(epoch_val_stats_df['loss per sample'],'r-o',label='epoch val')
    loss_subplot.set_ylabel('loss per sample')
    loss_subplot.grid()
    loss_subplot.legend(loc='best')
    if plot_loss_in_log_scale:
        loss_subplot.set_yscale('log')
        
    MSE_subplot.clear() # clearing plot before plotting, to avoid over-plotting
    if len(running_stats_df)>0:
        MSE_subplot.plot(running_stats_df['MSE']**0.5,'-x',label='running train')
    MSE_subplot.plot(epoch_train_stats_df['MSE']**0.5,'k-o',label='epoch train')
    MSE_subplot.plot(epoch_val_stats_df['MSE']**0.5,'r-o',label='epoch val')
    MSE_subplot.set_ylabel('sqrt(MSE)')
    MSE_subplot.grid()
    MSE_subplot.legend(loc='best')
    fig.canvas.draw()

class remainder_time:
    def __init__(self,time_seconds):
        self.time_seconds=time_seconds
        self.hours=int(time_seconds/3600)
        self.remainder_minutes=int((time_seconds-self.hours*3600)/60)
        self.remainder_seconds=time_seconds-self.hours*3600-self.remainder_minutes*60

logger.info('script initialized')

2019-06-10 20:03:27 <module> (INFO): script initialized


In [4]:
#%% building a torch dataset of PIL images with torchvision transforms
"""torchvision.transforms accept PIL images, and not np images that are 
    created when using skimage as presented in 
    https://pytorch.org/tutorials/beginner/data_loading_tutorial

torchvision transforms: https://pytorch.org/docs/stable/torchvision/transforms.html
"""

if random_transforms:
    transform_func=torchvision.transforms.Compose([
    #            torchvision.transforms.Resize(400),
    #            torchvision.transforms.RandomCrop(390),
                torchvision.transforms.CenterCrop(200),
                torchvision.transforms.ColorJitter(brightness=0.1,contrast=0.1,saturation=0,hue=0),
    #            torchvision.transforms.RandomHorizontalFlip(p=0.5),
                torchvision.transforms.ToTensor(),
                ])
else:
    transform_func=torchvision.transforms.Compose([
                torchvision.transforms.CenterCrop(200),
                torchvision.transforms.ToTensor(),
                ])
crop_approval=input('ATTENTION: using torchvision.transforms.CenterCrop(200), approve? y/[n] ')
if crop_approval!='y':
    raise RuntimeError('user did not approve torchvision.transforms.CenterCrop(200)!')
"""torchvision.transforms.ToTensor() Converts a PIL Image or numpy.ndarray (H x W x C) in the 
    range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the 
    range [0.0, 1.0] if the PIL Image belongs to one of the modes 
    (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has 
    dtype = np.uint8
In the other cases, tensors are returned without scaling
source: https://pytorch.org/docs/stable/torchvision/transforms.html
"""
image_filenames=os.listdir(images_folder_path)

if max_dataset_length>0:
    user_data_approval=input('ATTENTION: downsampling is chosen - building a dataset by sampling only max_dataset_length=%d samples from all available data! approve? y/[n] '%(round(max_dataset_length)))
    if user_data_approval!='y':
        raise RuntimeError('user did not approve dataset max_dataset_length sampling!')
    random.seed(seed_for_dataset_downsampling)
    image_filenames=random.sample(image_filenames,max_dataset_length)

images_dataset=images_torch_dataset(image_filenames,images_folder_path,transform_func=transform_func)
sample_size=images_dataset[0]['image array'].size()
sample_pixels_per_channel=sample_size[1]*sample_size[2]
sample_pixels_all_channels=sample_size[0]*sample_pixels_per_channel
logger.info('set a PyTorch dataset of length %.2e, input size (assuming it is constant): (%d,%d,%d)'%(
        len(image_filenames),sample_size[0],sample_size[1],sample_size[2]))

ATTENTION: using torchvision.transforms.CenterCrop(200), approve? y/[n] y


2019-06-10 20:03:29 <module> (INFO): set a PyTorch dataset of length 8.31e+03, input size (assuming it is constant): (3,200,200)


In [5]:
#%% (debugging) verifying dataset by plotting
samples_to_plot=20
#sampling_for_sample_verification='none' # plotting first samples_to_plot samples
sampling_for_sample_verification='random' # plotting randomly selected samples_to_plot samples, using seed_for_sample_verification seed
seed_for_sample_verification=0
images_per_row=4
# end of inputs ---------------------------------------------------------------
if sampling_for_sample_verification=='none':
    sample_indices_to_plot=range(samples_to_plot)
elif sampling_for_sample_verification=='random':
    random.seed(seed_for_sample_verification)
    sample_indices_to_plot=random.sample(range(len(image_filenames)),samples_to_plot)
else:
    raise RuntimeError('unsupported sampling_for_sample_verification input!')

if debugging:
    plot_from_image_filenames_list(sample_indices_to_plot,image_filenames,
                                   images_folder_path,images_per_row)
    plt.suptitle('plotting images directly from disk')
    
    plot_from_torch_dataset(sample_indices_to_plot,images_dataset,
                            images_per_row,image_format='PIL->torch')
    plt.suptitle('plotting images from PyTorch dataset')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [6]:
#%% splitting to train and val datsets and dataloaders
dataset_to_split=images_dataset

if isinstance(batch_size_int_or_ratio_float,int):
    batch_size=batch_size_int_or_ratio_float
elif isinstance(batch_size_int_or_ratio_float,float):
    batch_size=round(batch_size_int_or_ratio_float*len(dataset_to_split))
else:
    raise RuntimeError('unsupported batch_size input!')
if batch_size<1:
    batch_size=1
    logger.warning('batch_size=round(batch_size_over_dataset_length*len(dataset_to_split))<1 so batch_size=1 was set')
if batch_size==1:
    user_batch_size=input('batch_size=1 should cause errors since batch_size>1 is generally assumed! enter a new batch size equal or larger than 1, or smaller than 1 to abort: ')
    if user_batch_size<1:
        raise RuntimeError('aborted by user batch size decision')
    else:
        batch_size=round(user_batch_size)

dataset_length=len(dataset_to_split)
dataset_indices=list(range(dataset_length))
split_index=int((1-validation_ratio)*dataset_length)
if shuffle_dataset_indices_for_split:
    np.random.seed(dataset_shuffle_random_seed)
    np.random.shuffle(dataset_indices)
train_indices=dataset_indices[:split_index]
val_indices=dataset_indices[split_index:]

# splitting the dataset to train and val
train_dataset=torch.utils.data.Subset(dataset_to_split,train_indices)
val_dataset=torch.utils.data.Subset(dataset_to_split,val_indices)

# creating the train and val dataloaders
train_dataloader=torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,
                        num_workers=data_workers,shuffle=dataloader_shuffle)
val_dataloader=torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,
                        num_workers=data_workers,shuffle=dataloader_shuffle)

# structuring
dataset_indices={'train':train_indices,'val':val_indices}
datasets={'train':train_dataset,'val':val_dataset}
dataset_samples_number={'train':len(train_dataset),'val':len(val_dataset)}

dataloaders={'train':train_dataloader,'val':val_dataloader}
dataloader_batches_number={'train':len(train_dataloader),'val':len(val_dataloader)}

logger.info('dataset split to training and validation datasets and dataloaders with validation_ratio=%.1f, lengths: (train,val)=(%d,%d)'%(
        validation_ratio,dataset_samples_number['train'],dataset_samples_number['val']))

2019-06-10 20:03:30 <module> (INFO): dataset split to training and validation datasets and dataloaders with validation_ratio=0.3, lengths: (train,val)=(5814,2492)


In [7]:
#%% (debugging) verifying dataloaders
images_per_row=4
# end of inputs ---------------------------------------------------------------

if debugging:
    if __name__=='__main__' or data_workers==0: # required in Windows for multi-processing
        samples_batches={}
        for phase in ['train','val']:
            samples_batch=next(iter(dataloaders[phase]))
            samples_batches.update({phase:samples_batch})
    else:
        raise RuntimeError('cannot use multiprocessing (data_workers>0 in dataloaders) in Windows when executed not as main!')
        
    columns_num=math.ceil(batch_size/images_per_row)
    for phase in ['train','val']:
        plt.figure()
        for i in range(batch_size):
            samples_batch=samples_batches[phase]
            image_array=samples_batch['image array'][i].numpy().transpose((1,2,0))
            image_filename=samples_batch['image filename'][i]
            
            plt.subplot(columns_num,images_per_row,i+1)
            plt.imshow(image_array)
            plt.title(image_filename)
            plt.xticks(ticks=[])
            plt.yticks(ticks=[])
        plt.suptitle('plotting a batch from the %s dataloader'%phase)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [8]:
#%% setting the NN
if training_progress_ratio_to_log_loss>1:
    raise RuntimeError('invalid training_progress_ratio_to_log_loss=%.2f, must be <=1'%training_progress_ratio_to_log_loss)
period_in_batches_to_log_loss=round(training_progress_ratio_to_log_loss*dataset_samples_number['train']/batch_size) # logging only during training

device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if plot_realtime_stats_on_logging or plot_realtime_stats_after_each_epoch:
    logger.warning('plotting from inside the net loop is not working, should be debugged...')

if net_architecture=='simple auto-encoder':
    # inspired by https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
    class autoencoder(nn.Module):
        def __init__(self):
            super(autoencoder, self).__init__()
            self.encoder = nn.Sequential(
                nn.Conv2d(3,16,20,stride=4,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(16),
                
                nn.Conv2d(16,3,8,stride=2,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(3),
            )
            self.decoder = nn.Sequential(
                nn.ConvTranspose2d(3,16,8,stride=2,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(16),
                
                nn.ConvTranspose2d(16,3,20,stride=4,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(3),
            )
    
        def forward(self,x):
            x=self.encoder(x)
            x=self.decoder(x)
            return x
    model=autoencoder()
elif net_architecture=='bigger auto-encoder':
    class autoencoder(nn.Module):
        def __init__(self):
            super(autoencoder, self).__init__()
            self.encoder = nn.Sequential(
                nn.Conv2d(3,16,4,stride=2,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(16),
                
                nn.Conv2d(16,8,4,stride=1,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(8),
                
                nn.Conv2d(8,2,4,stride=2,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(2),
            )
            self.decoder = nn.Sequential(
                nn.ConvTranspose2d(2,8,4,stride=2,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(8),
                
                nn.ConvTranspose2d(8,16,4,stride=1,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(16),
                
                nn.ConvTranspose2d(16,3,4,stride=2,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(3),
            )
    
        def forward(self,x):
            x=self.encoder(x)
            x=self.decoder(x)
            return x
    model=autoencoder()
elif net_architecture=='v3 auto-encoder':
    class autoencoder(nn.Module):
        def __init__(self):
            super(autoencoder, self).__init__()
            self.encoder = nn.Sequential(
                nn.Conv2d(3,16,4,stride=2,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(16),
                
                nn.Conv2d(16,8,4,stride=1,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(8),
                
                nn.Conv2d(8,4,4,stride=2,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(4),
            )
            self.decoder = nn.Sequential(
                nn.ConvTranspose2d(4,8,4,stride=2,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(8),
                
                nn.ConvTranspose2d(8,16,4,stride=1,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(16),
                
                nn.ConvTranspose2d(16,3,4,stride=2,bias=False),
                nn.ReLU(True),
                nn.BatchNorm2d(3),
            )
    
        def forward(self,x):
            x=self.encoder(x)
            x=self.decoder(x)
            return x
    model=autoencoder()
else:
    raise RuntimeError('untreated net_architecture!')

model=model.to(device)
optimizer=torch.optim.SGD(model.parameters(),lr=learning_rate,momentum=momentum)

if loss_name=='MSE':
    loss_fn=nn.MSELoss(reduction='mean').to(device)
else:
    raise RuntimeError('untreated loss_name input')
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,
    step_size=lr_scheduler_step_size,gamma=lr_scheduler_decay_factor)

In [9]:
#%% (debugging) verifying model outputs and loss
if debugging:
    if __name__=='__main__' or data_workers==0:
        batch=next(iter(dataloaders['train']))
        input_images=batch['image array']
        input_images=input_images.to(device)
        print('input shape:',input_images.shape)
        model.eval()
        output_images=model(input_images)
        print('output shape:',output_images.shape)
        print('nn.MSELoss(input_images,output_images):',
              nn.MSELoss(reduction='mean').to(device)(input_images,output_images))
        print('((input_images-output_images)**2).mean():',
              ((input_images-output_images)**2).mean())
    else:
        raise RuntimeError('cannot use multiprocessing (data_workers>0 in dataloaders) in Windows when executed not as main!')

input shape: torch.Size([64, 3, 200, 200])
output shape: torch.Size([64, 3, 200, 200])
nn.MSELoss(input_images,output_images): tensor(0.2694, device='cuda:0', grad_fn=<MeanBackward1>)
((input_images-output_images)**2).mean(): tensor(0.2694, device='cuda:0', grad_fn=<MeanBackward1>)


In [10]:
#%% training the net
if train_model_else_load_weights and (__name__=='__main__' or data_workers==0):
    stats_dict={'train':{'epoch metrics':{},
                     'running metrics':{}}, # running = measuerd on samples only since the last log
                     'val':{'epoch metrics':{}}}
    
    total_batches=epochs*(dataloader_batches_number['train']+dataloader_batches_number['val'])
    logger.info("started training '%s' net on %s"%(net_architecture,device))
    tic=time()
    for epoch in range(epochs):
        for phase in ['train','val']:
            if phase == 'train':
                scheduler.step()
                model.train() # set model to training mode
            else:
                model.eval() # set model to evaluate mode
            
            epoch_loss=0.0 # must be a float
            epoch_squared_error=0.0
            samples_processed_since_last_log=0
            loss_since_last_log=0.0 # must be a float
            squared_error_since_last_log=0.0
            
            for i_batch,batch in enumerate(dataloaders[phase]):
                input_images=batch['image array'].to(device)
                
                optimizer.zero_grad() # zero the parameter gradients
                
                # forward
                with torch.set_grad_enabled(phase=='train'): # if phase=='train' it tracks tensor history for grad calc
                    output_images=model(input_images)
                    loss=loss_fn(output_images,input_images)
                    if torch.isnan(loss):
                        raise RuntimeError('reached NaN loss - aborting training!')
                    # backward + optimize if training
                    if phase=='train':
                        loss.backward()
                        optimizer.step()
                
                # accumulating stats
                samples_number=len(input_images)
                samples_processed_since_last_log+=samples_number
                
                current_loss=loss.item()*samples_number*sample_pixels_all_channels # the loss is averaged across samples and pixels in each minibatch, so it is multiplied to return to a total
                epoch_loss+=current_loss
                loss_since_last_log+=current_loss
                
                with torch.set_grad_enabled(False):
                    batch_squared_error=((output_images-input_images)**2).sum()
                    batch_squared_error=batch_squared_error.item()
                epoch_squared_error+=batch_squared_error
                squared_error_since_last_log+=batch_squared_error
                
                if phase=='train' and i_batch%period_in_batches_to_log_loss==(period_in_batches_to_log_loss-1):
                    loss_since_last_log_per_sample=loss_since_last_log/samples_processed_since_last_log
                    MSE_since_last_log=squared_error_since_last_log/samples_processed_since_last_log
                    stats_dict[phase]['running metrics'].update({pd.Timestamp.now():
                        {'epoch':epoch+1,'batch':i_batch+1,'loss per sample':loss_since_last_log_per_sample,
                         'MSE':MSE_since_last_log}})
            
                    completed_batches=epoch*(dataloader_batches_number['train']+dataloader_batches_number['val'])+(i_batch+1)
                    completed_batches_progress=completed_batches/total_batches
                    passed_seconds=time()-tic
                    expected_seconds=passed_seconds/completed_batches_progress*(1-completed_batches_progress)
                    expected_remainder_time=remainder_time(expected_seconds)
                    
                    logger.info('(epoch %d/%d, batch %d/%d, %s) running loss per sample (since last log): %.3e, running sqrt(MSE) (since last log): sqrt(%.3e)=%.3e\n\tETA: %dh:%dm:%.0fs'%(
                                epoch+1,epochs,i_batch+1,dataloader_batches_number[phase],phase,
                                loss_since_last_log_per_sample,
                                MSE_since_last_log,MSE_since_last_log**0.5,
                                expected_remainder_time.hours,expected_remainder_time.remainder_minutes,expected_remainder_time.remainder_seconds))
                    
                    loss_since_last_log=0.0 # must be a float
                    squared_error_since_last_log=0.0
                    samples_processed_since_last_log=0
            
            # epoch stats
            epoch_loss_per_sample=epoch_loss/dataset_samples_number[phase]
            epoch_MSE=epoch_squared_error/dataset_samples_number[phase]
            
            stats_dict[phase]['epoch metrics'].update({pd.Timestamp.now():
                        {'epoch':epoch,
                         'loss per sample':epoch_loss_per_sample,
                         'MSE':epoch_MSE}})
            if phase=='val':
                if best_model_criterion=='min val epoch MSE':
                    best_criterion_current_value=epoch_MSE
                    if epoch==0:
                        best_criterion_best_value=best_criterion_current_value
                        best_model_wts=copy.deepcopy(model.state_dict())
                        best_epoch=epoch
                    else:
                        if best_criterion_current_value<best_criterion_best_value:
                            best_criterion_best_value=best_criterion_current_value
                            best_model_wts=copy.deepcopy(model.state_dict())
                            best_epoch=epoch
                
                completed_epochs_progress=(epoch+1)/epochs
                passed_seconds=time()-tic
                expected_seconds=passed_seconds/completed_epochs_progress*(1-completed_epochs_progress)
                expected_remainder_time=remainder_time(expected_seconds)
                
                # not printing epoch stats for training, since in this phase they are being measured while the weights are being updated, unlike in validation where stats are measured with no update
                logger.info('(epoch %d, %s) epoch loss per sample: %.3e, epoch sqrt(MSE): sqrt(%.3e)=%.3e\n\tETA: %dh:%dm:%.0fs'%(
                                    epoch+1,phase,
                                    epoch_loss_per_sample,
                                    epoch_MSE,epoch_MSE**0.5,
                                    expected_remainder_time.hours,
                                    expected_remainder_time.remainder_minutes,
                                    expected_remainder_time.remainder_seconds))
                print('-'*10)
    toc=time()
    elapsed_sec=toc-tic
    pytorch_total_wts=sum(p.numel() for p in model.parameters())
    pytorch_trainable_wts=sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.info('finished training %d epochs in %dm:%.1fs, trainable/total weigths: %d/%d'%(
            epochs,elapsed_sec//60,elapsed_sec%60,pytorch_trainable_wts,pytorch_total_wts))
    if return_to_best_weights_in_the_end:
        model.load_state_dict(best_model_wts)
        logger.info("loaded weights of best model according to '%s' criterion: best value %.3f achieved in epoch %d"%(
                best_model_criterion,best_criterion_best_value,best_epoch+1))
    if not (plot_realtime_stats_on_logging or plot_realtime_stats_after_each_epoch):
        fig=plt.figure()
        plt.suptitle('model stats')
        loss_subplot=plt.subplot(1,2,1)
        MSE_subplot=plt.subplot(1,2,2)
    training_stats_plot(stats_dict,fig,loss_subplot,MSE_subplot)
else: # train_model_else_load_weights==False
    ui_model_name=input('model weights file name to load: ')
    model_weights_file_path=os.path.join(models_folder_path,ui_model_name)
    if not os.path.isfile(model_weights_file_path):
        raise RuntimeError('%model_weights_path does not exist!')
    model_weights=torch.load(model_weights_file_path)
    model.load_state_dict(model_weights)
    logger.info('model weights from %s were loaded'%model_weights_file_path)

2019-06-10 20:03:37 <module> (INFO): started training 'v3 auto-encoder' net on cuda:0
2019-06-10 20:03:47 <module> (INFO): (epoch 1/30, batch 23/91, train) running loss per sample (since last log): 2.543e+04, running sqrt(MSE) (since last log): sqrt(2.543e+04)=1.595e+02
	ETA: 0h:28m:21s
2019-06-10 20:03:57 <module> (INFO): (epoch 1/30, batch 46/91, train) running loss per sample (since last log): 6.469e+03, running sqrt(MSE) (since last log): sqrt(6.469e+03)=8.043e+01
	ETA: 0h:28m:7s
2019-06-10 20:04:07 <module> (INFO): (epoch 1/30, batch 69/91, train) running loss per sample (since last log): 4.348e+03, running sqrt(MSE) (since last log): sqrt(4.348e+03)=6.594e+01
	ETA: 0h:27m:52s
2019-06-10 20:04:29 <module> (INFO): (epoch 1, val) epoch loss per sample: 3.334e+03, epoch sqrt(MSE): sqrt(3.334e+03)=5.774e+01
	ETA: 0h:25m:17s


----------


2019-06-10 20:04:40 <module> (INFO): (epoch 2/30, batch 23/91, train) running loss per sample (since last log): 3.172e+03, running sqrt(MSE) (since last log): sqrt(3.172e+03)=5.632e+01
	ETA: 0h:25m:36s
2019-06-10 20:04:50 <module> (INFO): (epoch 2/30, batch 46/91, train) running loss per sample (since last log): 2.913e+03, running sqrt(MSE) (since last log): sqrt(2.913e+03)=5.397e+01
	ETA: 0h:25m:47s
2019-06-10 20:05:01 <module> (INFO): (epoch 2/30, batch 69/91, train) running loss per sample (since last log): 2.723e+03, running sqrt(MSE) (since last log): sqrt(2.723e+03)=5.218e+01
	ETA: 0h:26m:1s
2019-06-10 20:05:24 <module> (INFO): (epoch 2, val) epoch loss per sample: 2.558e+03, epoch sqrt(MSE): sqrt(2.558e+03)=5.058e+01
	ETA: 0h:25m:2s


----------


2019-06-10 20:05:35 <module> (INFO): (epoch 3/30, batch 23/91, train) running loss per sample (since last log): 2.523e+03, running sqrt(MSE) (since last log): sqrt(2.523e+03)=5.023e+01
	ETA: 0h:25m:5s
2019-06-10 20:05:45 <module> (INFO): (epoch 3/30, batch 46/91, train) running loss per sample (since last log): 2.467e+03, running sqrt(MSE) (since last log): sqrt(2.467e+03)=4.967e+01
	ETA: 0h:25m:8s
2019-06-10 20:05:56 <module> (INFO): (epoch 3/30, batch 69/91, train) running loss per sample (since last log): 2.423e+03, running sqrt(MSE) (since last log): sqrt(2.423e+03)=4.923e+01
	ETA: 0h:25m:8s
2019-06-10 20:06:19 <module> (INFO): (epoch 3, val) epoch loss per sample: 2.341e+03, epoch sqrt(MSE): sqrt(2.341e+03)=4.838e+01
	ETA: 0h:24m:21s


----------


2019-06-10 20:06:31 <module> (INFO): (epoch 4/30, batch 23/91, train) running loss per sample (since last log): 2.315e+03, running sqrt(MSE) (since last log): sqrt(2.315e+03)=4.811e+01
	ETA: 0h:24m:26s
2019-06-10 20:06:42 <module> (INFO): (epoch 4/30, batch 46/91, train) running loss per sample (since last log): 2.226e+03, running sqrt(MSE) (since last log): sqrt(2.226e+03)=4.718e+01
	ETA: 0h:24m:26s
2019-06-10 20:06:52 <module> (INFO): (epoch 4/30, batch 69/91, train) running loss per sample (since last log): 2.093e+03, running sqrt(MSE) (since last log): sqrt(2.093e+03)=4.575e+01
	ETA: 0h:24m:25s
2019-06-10 20:07:15 <module> (INFO): (epoch 4, val) epoch loss per sample: 1.890e+03, epoch sqrt(MSE): sqrt(1.890e+03)=4.348e+01
	ETA: 0h:23m:40s


----------


2019-06-10 20:07:26 <module> (INFO): (epoch 5/30, batch 23/91, train) running loss per sample (since last log): 1.835e+03, running sqrt(MSE) (since last log): sqrt(1.835e+03)=4.283e+01
	ETA: 0h:23m:38s
2019-06-10 20:07:37 <module> (INFO): (epoch 5/30, batch 46/91, train) running loss per sample (since last log): 1.751e+03, running sqrt(MSE) (since last log): sqrt(1.751e+03)=4.185e+01
	ETA: 0h:23m:35s
2019-06-10 20:07:48 <module> (INFO): (epoch 5/30, batch 69/91, train) running loss per sample (since last log): 1.705e+03, running sqrt(MSE) (since last log): sqrt(1.705e+03)=4.130e+01
	ETA: 0h:23m:31s
2019-06-10 20:08:11 <module> (INFO): (epoch 5, val) epoch loss per sample: 1.653e+03, epoch sqrt(MSE): sqrt(1.653e+03)=4.065e+01
	ETA: 0h:22m:49s


----------


2019-06-10 20:08:21 <module> (INFO): (epoch 6/30, batch 23/91, train) running loss per sample (since last log): 1.643e+03, running sqrt(MSE) (since last log): sqrt(1.643e+03)=4.054e+01
	ETA: 0h:22m:42s
2019-06-10 20:08:32 <module> (INFO): (epoch 6/30, batch 46/91, train) running loss per sample (since last log): 1.627e+03, running sqrt(MSE) (since last log): sqrt(1.627e+03)=4.034e+01
	ETA: 0h:22m:36s
2019-06-10 20:08:42 <module> (INFO): (epoch 6/30, batch 69/91, train) running loss per sample (since last log): 1.615e+03, running sqrt(MSE) (since last log): sqrt(1.615e+03)=4.019e+01
	ETA: 0h:22m:30s
2019-06-10 20:09:05 <module> (INFO): (epoch 6, val) epoch loss per sample: 1.583e+03, epoch sqrt(MSE): sqrt(1.583e+03)=3.978e+01
	ETA: 0h:21m:53s


----------


2019-06-10 20:09:16 <module> (INFO): (epoch 7/30, batch 23/91, train) running loss per sample (since last log): 1.582e+03, running sqrt(MSE) (since last log): sqrt(1.582e+03)=3.977e+01
	ETA: 0h:21m:47s
2019-06-10 20:09:27 <module> (INFO): (epoch 7/30, batch 46/91, train) running loss per sample (since last log): 1.571e+03, running sqrt(MSE) (since last log): sqrt(1.571e+03)=3.963e+01
	ETA: 0h:21m:41s
2019-06-10 20:09:37 <module> (INFO): (epoch 7/30, batch 69/91, train) running loss per sample (since last log): 1.550e+03, running sqrt(MSE) (since last log): sqrt(1.550e+03)=3.937e+01
	ETA: 0h:21m:34s
2019-06-10 20:10:01 <module> (INFO): (epoch 7, val) epoch loss per sample: 1.527e+03, epoch sqrt(MSE): sqrt(1.527e+03)=3.908e+01
	ETA: 0h:21m:1s


----------


2019-06-10 20:10:11 <module> (INFO): (epoch 8/30, batch 23/91, train) running loss per sample (since last log): 1.526e+03, running sqrt(MSE) (since last log): sqrt(1.526e+03)=3.907e+01
	ETA: 0h:20m:54s
2019-06-10 20:10:22 <module> (INFO): (epoch 8/30, batch 46/91, train) running loss per sample (since last log): 1.511e+03, running sqrt(MSE) (since last log): sqrt(1.511e+03)=3.888e+01
	ETA: 0h:20m:48s
2019-06-10 20:10:33 <module> (INFO): (epoch 8/30, batch 69/91, train) running loss per sample (since last log): 1.504e+03, running sqrt(MSE) (since last log): sqrt(1.504e+03)=3.878e+01
	ETA: 0h:20m:40s
2019-06-10 20:10:56 <module> (INFO): (epoch 8, val) epoch loss per sample: 1.479e+03, epoch sqrt(MSE): sqrt(1.479e+03)=3.846e+01
	ETA: 0h:20m:8s


----------


2019-06-10 20:11:07 <module> (INFO): (epoch 9/30, batch 23/91, train) running loss per sample (since last log): 1.476e+03, running sqrt(MSE) (since last log): sqrt(1.476e+03)=3.842e+01
	ETA: 0h:20m:0s
2019-06-10 20:11:17 <module> (INFO): (epoch 9/30, batch 46/91, train) running loss per sample (since last log): 1.461e+03, running sqrt(MSE) (since last log): sqrt(1.461e+03)=3.822e+01
	ETA: 0h:19m:53s
2019-06-10 20:11:28 <module> (INFO): (epoch 9/30, batch 69/91, train) running loss per sample (since last log): 1.459e+03, running sqrt(MSE) (since last log): sqrt(1.459e+03)=3.820e+01
	ETA: 0h:19m:45s
2019-06-10 20:11:51 <module> (INFO): (epoch 9, val) epoch loss per sample: 1.432e+03, epoch sqrt(MSE): sqrt(1.432e+03)=3.784e+01
	ETA: 0h:19m:12s


----------


2019-06-10 20:12:01 <module> (INFO): (epoch 10/30, batch 23/91, train) running loss per sample (since last log): 1.433e+03, running sqrt(MSE) (since last log): sqrt(1.433e+03)=3.786e+01
	ETA: 0h:19m:4s
2019-06-10 20:12:11 <module> (INFO): (epoch 10/30, batch 46/91, train) running loss per sample (since last log): 1.423e+03, running sqrt(MSE) (since last log): sqrt(1.423e+03)=3.772e+01
	ETA: 0h:18m:55s
2019-06-10 20:12:22 <module> (INFO): (epoch 10/30, batch 69/91, train) running loss per sample (since last log): 1.405e+03, running sqrt(MSE) (since last log): sqrt(1.405e+03)=3.748e+01
	ETA: 0h:18m:47s
2019-06-10 20:12:46 <module> (INFO): (epoch 10, val) epoch loss per sample: 1.392e+03, epoch sqrt(MSE): sqrt(1.392e+03)=3.731e+01
	ETA: 0h:18m:17s


----------


2019-06-10 20:12:56 <module> (INFO): (epoch 11/30, batch 23/91, train) running loss per sample (since last log): 1.393e+03, running sqrt(MSE) (since last log): sqrt(1.393e+03)=3.733e+01
	ETA: 0h:18m:9s
2019-06-10 20:13:07 <module> (INFO): (epoch 11/30, batch 46/91, train) running loss per sample (since last log): 1.383e+03, running sqrt(MSE) (since last log): sqrt(1.383e+03)=3.719e+01
	ETA: 0h:18m:1s
2019-06-10 20:13:18 <module> (INFO): (epoch 11/30, batch 69/91, train) running loss per sample (since last log): 1.379e+03, running sqrt(MSE) (since last log): sqrt(1.379e+03)=3.714e+01
	ETA: 0h:17m:54s
2019-06-10 20:13:41 <module> (INFO): (epoch 11, val) epoch loss per sample: 1.359e+03, epoch sqrt(MSE): sqrt(1.359e+03)=3.687e+01
	ETA: 0h:17m:24s


----------


2019-06-10 20:13:52 <module> (INFO): (epoch 12/30, batch 23/91, train) running loss per sample (since last log): 1.361e+03, running sqrt(MSE) (since last log): sqrt(1.361e+03)=3.689e+01
	ETA: 0h:17m:16s
2019-06-10 20:14:03 <module> (INFO): (epoch 12/30, batch 46/91, train) running loss per sample (since last log): 1.355e+03, running sqrt(MSE) (since last log): sqrt(1.355e+03)=3.681e+01
	ETA: 0h:17m:8s
2019-06-10 20:14:14 <module> (INFO): (epoch 12/30, batch 69/91, train) running loss per sample (since last log): 1.344e+03, running sqrt(MSE) (since last log): sqrt(1.344e+03)=3.666e+01
	ETA: 0h:16m:60s
2019-06-10 20:14:37 <module> (INFO): (epoch 12, val) epoch loss per sample: 1.331e+03, epoch sqrt(MSE): sqrt(1.331e+03)=3.648e+01
	ETA: 0h:16m:30s


----------


2019-06-10 20:14:48 <module> (INFO): (epoch 13/30, batch 23/91, train) running loss per sample (since last log): 1.335e+03, running sqrt(MSE) (since last log): sqrt(1.335e+03)=3.654e+01
	ETA: 0h:16m:22s
2019-06-10 20:14:58 <module> (INFO): (epoch 13/30, batch 46/91, train) running loss per sample (since last log): 1.327e+03, running sqrt(MSE) (since last log): sqrt(1.327e+03)=3.643e+01
	ETA: 0h:16m:13s
2019-06-10 20:15:09 <module> (INFO): (epoch 13/30, batch 69/91, train) running loss per sample (since last log): 1.318e+03, running sqrt(MSE) (since last log): sqrt(1.318e+03)=3.630e+01
	ETA: 0h:16m:4s
2019-06-10 20:15:32 <module> (INFO): (epoch 13, val) epoch loss per sample: 1.308e+03, epoch sqrt(MSE): sqrt(1.308e+03)=3.617e+01
	ETA: 0h:15m:34s


----------


2019-06-10 20:15:42 <module> (INFO): (epoch 14/30, batch 23/91, train) running loss per sample (since last log): 1.312e+03, running sqrt(MSE) (since last log): sqrt(1.312e+03)=3.622e+01
	ETA: 0h:15m:25s
2019-06-10 20:15:52 <module> (INFO): (epoch 14/30, batch 46/91, train) running loss per sample (since last log): 1.306e+03, running sqrt(MSE) (since last log): sqrt(1.306e+03)=3.614e+01
	ETA: 0h:15m:17s
2019-06-10 20:16:03 <module> (INFO): (epoch 14/30, batch 69/91, train) running loss per sample (since last log): 1.297e+03, running sqrt(MSE) (since last log): sqrt(1.297e+03)=3.602e+01
	ETA: 0h:15m:8s
2019-06-10 20:16:25 <module> (INFO): (epoch 14, val) epoch loss per sample: 1.287e+03, epoch sqrt(MSE): sqrt(1.287e+03)=3.588e+01
	ETA: 0h:14m:38s


----------


2019-06-10 20:16:36 <module> (INFO): (epoch 15/30, batch 23/91, train) running loss per sample (since last log): 1.290e+03, running sqrt(MSE) (since last log): sqrt(1.290e+03)=3.592e+01
	ETA: 0h:14m:29s
2019-06-10 20:16:46 <module> (INFO): (epoch 15/30, batch 46/91, train) running loss per sample (since last log): 1.288e+03, running sqrt(MSE) (since last log): sqrt(1.288e+03)=3.588e+01
	ETA: 0h:14m:20s
2019-06-10 20:16:57 <module> (INFO): (epoch 15/30, batch 69/91, train) running loss per sample (since last log): 1.279e+03, running sqrt(MSE) (since last log): sqrt(1.279e+03)=3.577e+01
	ETA: 0h:14m:11s
2019-06-10 20:17:21 <module> (INFO): (epoch 15, val) epoch loss per sample: 1.270e+03, epoch sqrt(MSE): sqrt(1.270e+03)=3.563e+01
	ETA: 0h:13m:44s


----------


2019-06-10 20:17:31 <module> (INFO): (epoch 16/30, batch 23/91, train) running loss per sample (since last log): 1.277e+03, running sqrt(MSE) (since last log): sqrt(1.277e+03)=3.573e+01
	ETA: 0h:13m:35s
2019-06-10 20:17:42 <module> (INFO): (epoch 16/30, batch 46/91, train) running loss per sample (since last log): 1.265e+03, running sqrt(MSE) (since last log): sqrt(1.265e+03)=3.557e+01
	ETA: 0h:13m:26s
2019-06-10 20:17:52 <module> (INFO): (epoch 16/30, batch 69/91, train) running loss per sample (since last log): 1.262e+03, running sqrt(MSE) (since last log): sqrt(1.262e+03)=3.552e+01
	ETA: 0h:13m:17s
2019-06-10 20:18:15 <module> (INFO): (epoch 16, val) epoch loss per sample: 1.255e+03, epoch sqrt(MSE): sqrt(1.255e+03)=3.543e+01
	ETA: 0h:12m:48s


----------


2019-06-10 20:18:26 <module> (INFO): (epoch 17/30, batch 23/91, train) running loss per sample (since last log): 1.260e+03, running sqrt(MSE) (since last log): sqrt(1.260e+03)=3.550e+01
	ETA: 0h:12m:39s
2019-06-10 20:18:36 <module> (INFO): (epoch 17/30, batch 46/91, train) running loss per sample (since last log): 1.255e+03, running sqrt(MSE) (since last log): sqrt(1.255e+03)=3.543e+01
	ETA: 0h:12m:30s
2019-06-10 20:18:47 <module> (INFO): (epoch 17/30, batch 69/91, train) running loss per sample (since last log): 1.254e+03, running sqrt(MSE) (since last log): sqrt(1.254e+03)=3.541e+01
	ETA: 0h:12m:21s
2019-06-10 20:19:10 <module> (INFO): (epoch 17, val) epoch loss per sample: 1.242e+03, epoch sqrt(MSE): sqrt(1.242e+03)=3.524e+01
	ETA: 0h:11m:53s


----------


2019-06-10 20:19:20 <module> (INFO): (epoch 18/30, batch 23/91, train) running loss per sample (since last log): 1.242e+03, running sqrt(MSE) (since last log): sqrt(1.242e+03)=3.524e+01
	ETA: 0h:11m:44s
2019-06-10 20:19:31 <module> (INFO): (epoch 18/30, batch 46/91, train) running loss per sample (since last log): 1.242e+03, running sqrt(MSE) (since last log): sqrt(1.242e+03)=3.524e+01
	ETA: 0h:11m:35s
2019-06-10 20:19:41 <module> (INFO): (epoch 18/30, batch 69/91, train) running loss per sample (since last log): 1.241e+03, running sqrt(MSE) (since last log): sqrt(1.241e+03)=3.523e+01
	ETA: 0h:11m:26s
2019-06-10 20:20:04 <module> (INFO): (epoch 18, val) epoch loss per sample: 1.231e+03, epoch sqrt(MSE): sqrt(1.231e+03)=3.509e+01
	ETA: 0h:10m:58s


----------


2019-06-10 20:20:14 <module> (INFO): (epoch 19/30, batch 23/91, train) running loss per sample (since last log): 1.230e+03, running sqrt(MSE) (since last log): sqrt(1.230e+03)=3.507e+01
	ETA: 0h:10m:49s
2019-06-10 20:20:25 <module> (INFO): (epoch 19/30, batch 46/91, train) running loss per sample (since last log): 1.232e+03, running sqrt(MSE) (since last log): sqrt(1.232e+03)=3.510e+01
	ETA: 0h:10m:39s
2019-06-10 20:20:35 <module> (INFO): (epoch 19/30, batch 69/91, train) running loss per sample (since last log): 1.229e+03, running sqrt(MSE) (since last log): sqrt(1.229e+03)=3.506e+01
	ETA: 0h:10m:30s
2019-06-10 20:20:58 <module> (INFO): (epoch 19, val) epoch loss per sample: 1.222e+03, epoch sqrt(MSE): sqrt(1.222e+03)=3.496e+01
	ETA: 0h:10m:3s


----------


2019-06-10 20:21:08 <module> (INFO): (epoch 20/30, batch 23/91, train) running loss per sample (since last log): 1.222e+03, running sqrt(MSE) (since last log): sqrt(1.222e+03)=3.496e+01
	ETA: 0h:9m:53s
2019-06-10 20:21:19 <module> (INFO): (epoch 20/30, batch 46/91, train) running loss per sample (since last log): 1.222e+03, running sqrt(MSE) (since last log): sqrt(1.222e+03)=3.495e+01
	ETA: 0h:9m:44s
2019-06-10 20:21:29 <module> (INFO): (epoch 20/30, batch 69/91, train) running loss per sample (since last log): 1.222e+03, running sqrt(MSE) (since last log): sqrt(1.222e+03)=3.495e+01
	ETA: 0h:9m:35s
2019-06-10 20:21:52 <module> (INFO): (epoch 20, val) epoch loss per sample: 1.214e+03, epoch sqrt(MSE): sqrt(1.214e+03)=3.484e+01
	ETA: 0h:9m:8s


----------


2019-06-10 20:22:02 <module> (INFO): (epoch 21/30, batch 23/91, train) running loss per sample (since last log): 1.216e+03, running sqrt(MSE) (since last log): sqrt(1.216e+03)=3.487e+01
	ETA: 0h:8m:58s
2019-06-10 20:22:13 <module> (INFO): (epoch 21/30, batch 46/91, train) running loss per sample (since last log): 1.214e+03, running sqrt(MSE) (since last log): sqrt(1.214e+03)=3.484e+01
	ETA: 0h:8m:49s
2019-06-10 20:22:23 <module> (INFO): (epoch 21/30, batch 69/91, train) running loss per sample (since last log): 1.217e+03, running sqrt(MSE) (since last log): sqrt(1.217e+03)=3.489e+01
	ETA: 0h:8m:39s
2019-06-10 20:22:46 <module> (INFO): (epoch 21, val) epoch loss per sample: 1.205e+03, epoch sqrt(MSE): sqrt(1.205e+03)=3.472e+01
	ETA: 0h:8m:13s


----------


2019-06-10 20:22:57 <module> (INFO): (epoch 22/30, batch 23/91, train) running loss per sample (since last log): 1.209e+03, running sqrt(MSE) (since last log): sqrt(1.209e+03)=3.477e+01
	ETA: 0h:8m:3s
2019-06-10 20:23:07 <module> (INFO): (epoch 22/30, batch 46/91, train) running loss per sample (since last log): 1.204e+03, running sqrt(MSE) (since last log): sqrt(1.204e+03)=3.469e+01
	ETA: 0h:7m:54s
2019-06-10 20:23:18 <module> (INFO): (epoch 22/30, batch 69/91, train) running loss per sample (since last log): 1.208e+03, running sqrt(MSE) (since last log): sqrt(1.208e+03)=3.476e+01
	ETA: 0h:7m:44s
2019-06-10 20:23:41 <module> (INFO): (epoch 22, val) epoch loss per sample: 1.199e+03, epoch sqrt(MSE): sqrt(1.199e+03)=3.463e+01
	ETA: 0h:7m:18s


----------


2019-06-10 20:23:51 <module> (INFO): (epoch 23/30, batch 23/91, train) running loss per sample (since last log): 1.204e+03, running sqrt(MSE) (since last log): sqrt(1.204e+03)=3.470e+01
	ETA: 0h:7m:8s
2019-06-10 20:24:02 <module> (INFO): (epoch 23/30, batch 46/91, train) running loss per sample (since last log): 1.199e+03, running sqrt(MSE) (since last log): sqrt(1.199e+03)=3.463e+01
	ETA: 0h:6m:59s
2019-06-10 20:24:12 <module> (INFO): (epoch 23/30, batch 69/91, train) running loss per sample (since last log): 1.197e+03, running sqrt(MSE) (since last log): sqrt(1.197e+03)=3.460e+01
	ETA: 0h:6m:49s
2019-06-10 20:24:35 <module> (INFO): (epoch 23, val) epoch loss per sample: 1.193e+03, epoch sqrt(MSE): sqrt(1.193e+03)=3.454e+01
	ETA: 0h:6m:23s


----------


2019-06-10 20:24:46 <module> (INFO): (epoch 24/30, batch 23/91, train) running loss per sample (since last log): 1.199e+03, running sqrt(MSE) (since last log): sqrt(1.199e+03)=3.463e+01
	ETA: 0h:6m:14s
2019-06-10 20:24:57 <module> (INFO): (epoch 24/30, batch 46/91, train) running loss per sample (since last log): 1.193e+03, running sqrt(MSE) (since last log): sqrt(1.193e+03)=3.454e+01
	ETA: 0h:6m:4s
2019-06-10 20:25:07 <module> (INFO): (epoch 24/30, batch 69/91, train) running loss per sample (since last log): 1.193e+03, running sqrt(MSE) (since last log): sqrt(1.193e+03)=3.454e+01
	ETA: 0h:5m:55s
2019-06-10 20:25:30 <module> (INFO): (epoch 24, val) epoch loss per sample: 1.188e+03, epoch sqrt(MSE): sqrt(1.188e+03)=3.447e+01
	ETA: 0h:5m:28s


----------


2019-06-10 20:25:40 <module> (INFO): (epoch 25/30, batch 23/91, train) running loss per sample (since last log): 1.193e+03, running sqrt(MSE) (since last log): sqrt(1.193e+03)=3.454e+01
	ETA: 0h:5m:19s
2019-06-10 20:25:51 <module> (INFO): (epoch 25/30, batch 46/91, train) running loss per sample (since last log): 1.192e+03, running sqrt(MSE) (since last log): sqrt(1.192e+03)=3.452e+01
	ETA: 0h:5m:9s
2019-06-10 20:26:01 <module> (INFO): (epoch 25/30, batch 69/91, train) running loss per sample (since last log): 1.188e+03, running sqrt(MSE) (since last log): sqrt(1.188e+03)=3.447e+01
	ETA: 0h:4m:60s
2019-06-10 20:26:24 <module> (INFO): (epoch 25, val) epoch loss per sample: 1.183e+03, epoch sqrt(MSE): sqrt(1.183e+03)=3.440e+01
	ETA: 0h:4m:33s


----------


2019-06-10 20:26:35 <module> (INFO): (epoch 26/30, batch 23/91, train) running loss per sample (since last log): 1.184e+03, running sqrt(MSE) (since last log): sqrt(1.184e+03)=3.441e+01
	ETA: 0h:4m:24s
2019-06-10 20:26:45 <module> (INFO): (epoch 26/30, batch 46/91, train) running loss per sample (since last log): 1.186e+03, running sqrt(MSE) (since last log): sqrt(1.186e+03)=3.444e+01
	ETA: 0h:4m:14s
2019-06-10 20:26:56 <module> (INFO): (epoch 26/30, batch 69/91, train) running loss per sample (since last log): 1.185e+03, running sqrt(MSE) (since last log): sqrt(1.185e+03)=3.442e+01
	ETA: 0h:4m:5s
2019-06-10 20:27:18 <module> (INFO): (epoch 26, val) epoch loss per sample: 1.180e+03, epoch sqrt(MSE): sqrt(1.180e+03)=3.434e+01
	ETA: 0h:3m:39s


----------


2019-06-10 20:27:29 <module> (INFO): (epoch 27/30, batch 23/91, train) running loss per sample (since last log): 1.181e+03, running sqrt(MSE) (since last log): sqrt(1.181e+03)=3.437e+01
	ETA: 0h:3m:29s
2019-06-10 20:27:39 <module> (INFO): (epoch 27/30, batch 46/91, train) running loss per sample (since last log): 1.187e+03, running sqrt(MSE) (since last log): sqrt(1.187e+03)=3.446e+01
	ETA: 0h:3m:20s
2019-06-10 20:27:50 <module> (INFO): (epoch 27/30, batch 69/91, train) running loss per sample (since last log): 1.175e+03, running sqrt(MSE) (since last log): sqrt(1.175e+03)=3.428e+01
	ETA: 0h:3m:10s
2019-06-10 20:28:12 <module> (INFO): (epoch 27, val) epoch loss per sample: 1.175e+03, epoch sqrt(MSE): sqrt(1.175e+03)=3.429e+01
	ETA: 0h:2m:44s


----------


2019-06-10 20:28:23 <module> (INFO): (epoch 28/30, batch 23/91, train) running loss per sample (since last log): 1.177e+03, running sqrt(MSE) (since last log): sqrt(1.177e+03)=3.431e+01
	ETA: 0h:2m:34s
2019-06-10 20:28:33 <module> (INFO): (epoch 28/30, batch 46/91, train) running loss per sample (since last log): 1.181e+03, running sqrt(MSE) (since last log): sqrt(1.181e+03)=3.437e+01
	ETA: 0h:2m:25s
2019-06-10 20:28:43 <module> (INFO): (epoch 28/30, batch 69/91, train) running loss per sample (since last log): 1.177e+03, running sqrt(MSE) (since last log): sqrt(1.177e+03)=3.431e+01
	ETA: 0h:2m:15s
2019-06-10 20:29:06 <module> (INFO): (epoch 28, val) epoch loss per sample: 1.172e+03, epoch sqrt(MSE): sqrt(1.172e+03)=3.423e+01
	ETA: 0h:1m:49s


----------


2019-06-10 20:29:17 <module> (INFO): (epoch 29/30, batch 23/91, train) running loss per sample (since last log): 1.174e+03, running sqrt(MSE) (since last log): sqrt(1.174e+03)=3.427e+01
	ETA: 0h:1m:40s
2019-06-10 20:29:27 <module> (INFO): (epoch 29/30, batch 46/91, train) running loss per sample (since last log): 1.173e+03, running sqrt(MSE) (since last log): sqrt(1.173e+03)=3.425e+01
	ETA: 0h:1m:30s
2019-06-10 20:29:38 <module> (INFO): (epoch 29/30, batch 69/91, train) running loss per sample (since last log): 1.172e+03, running sqrt(MSE) (since last log): sqrt(1.172e+03)=3.424e+01
	ETA: 0h:1m:20s
2019-06-10 20:30:01 <module> (INFO): (epoch 29, val) epoch loss per sample: 1.169e+03, epoch sqrt(MSE): sqrt(1.169e+03)=3.420e+01
	ETA: 0h:0m:55s


----------


2019-06-10 20:30:11 <module> (INFO): (epoch 30/30, batch 23/91, train) running loss per sample (since last log): 1.172e+03, running sqrt(MSE) (since last log): sqrt(1.172e+03)=3.424e+01
	ETA: 0h:0m:45s
2019-06-10 20:30:21 <module> (INFO): (epoch 30/30, batch 46/91, train) running loss per sample (since last log): 1.171e+03, running sqrt(MSE) (since last log): sqrt(1.171e+03)=3.422e+01
	ETA: 0h:0m:35s
2019-06-10 20:30:32 <module> (INFO): (epoch 30/30, batch 69/91, train) running loss per sample (since last log): 1.173e+03, running sqrt(MSE) (since last log): sqrt(1.173e+03)=3.425e+01
	ETA: 0h:0m:26s
2019-06-10 20:30:54 <module> (INFO): (epoch 30, val) epoch loss per sample: 1.166e+03, epoch sqrt(MSE): sqrt(1.166e+03)=3.415e+01
	ETA: 0h:0m:0s


----------


2019-06-10 20:30:54 <module> (INFO): finished training 30 epochs in 27m:17.2s, trainable/total weigths: 6766/6766
2019-06-10 20:30:54 <module> (INFO): loaded weights of best model according to 'min val epoch MSE' criterion: best value 1166.347 achieved in epoch 30


<IPython.core.display.Javascript object>

In [11]:
#%% post-training model evaluation
"""the validation class_metrics_df measured here in the model evaluation must be identical to those measured during the
    last/best epoch, UNLIKE the training metrics - since the train phase metrics measured during training were being 
    measured while the weights were being updated in batches (!), not after the train phase epoch completed (which 
    would require another iteration on the train dataloader to measure metrics, as is done here without training)
"""
logger.info('started model evaluation')
model.eval() # set model to evaluate mode
for phase in ['train','val']:
    epoch_loss=0.0 # must be a float
    epoch_squared_error=0.0
    
    for i_batch,batch in enumerate(dataloaders[phase]):
        input_images=batch['image array'].to(device)
                    
        # forward
        with torch.set_grad_enabled(False): # if phase=='train' it tracks tensor history for grad calc
            output_images=model(input_images)
            loss=loss_fn(output_images,input_images)
        
        # accumulating stats
        samples_number=len(input_images)            
        current_loss=loss.item()*sample_pixels_all_channels # the loss is averaged across samples and pixels in each minibatch, so it is multiplied to return to a total
        epoch_loss+=current_loss
        
        with torch.set_grad_enabled(False):
            batch_squared_error=((output_images-input_images)**2).sum()
            batch_squared_error=batch_squared_error.item()
        epoch_squared_error+=batch_squared_error
        
    # epoch stats
    epoch_loss_per_sample=epoch_loss/dataset_samples_number[phase]
    epoch_MSE=epoch_squared_error/dataset_samples_number[phase]
    
    logger.info('(post-training, %s) loss per sample: %.3e, sqrt(MSE): sqrt(%.3e)=%.3e'%(
                    phase,epoch_loss_per_sample,epoch_MSE,epoch_MSE**0.5))
          
logger.info('completed model evaluation')

2019-06-10 20:30:54 <module> (INFO): started model evaluation
2019-06-10 20:31:24 <module> (INFO): (post-training, train) loss per sample: 1.830e+01, sqrt(MSE): sqrt(1.169e+03)=3.419e+01
2019-06-10 20:31:37 <module> (INFO): (post-training, val) loss per sample: 1.825e+01, sqrt(MSE): sqrt(1.166e+03)=3.415e+01
2019-06-10 20:31:37 <module> (INFO): completed model evaluation


In [12]:
#%% inspecting auto-encoding by plotting (building the batches from samples)
samples_to_plot=20
#sampling_for_sample_verification='none' # plotting first samples_to_plot samples from datasets (train, val)
sampling_for_sample_verification='random' # plotting randomly selected samples_to_plot samples from datasets (train, val), using seed_for_sample_verification seed
seed_for_sample_verification=0
images_per_row=4
# end of inputs ---------------------------------------------------------------

# concatenating the samples to inspect into batches
image_batches={}
for phase in ['train','val']:
    if sampling_for_sample_verification=='none':
        sample_indices_to_plot=range(samples_to_plot)
    elif sampling_for_sample_verification=='random':
        random.seed(seed_for_sample_verification)
        sample_indices_to_plot=random.sample(range(len(datasets[phase])),samples_to_plot)
    else:
        raise RuntimeError('unsupported sampling_for_sample_verification input!')
    
    image_tensors_list=[]
    for i,i_sample in enumerate(sample_indices_to_plot):
        image_array=datasets[phase][i_sample]['image array'].unsqueeze(0)
        image_tensors_list.append(image_array)
    image_batches.update({phase:torch.cat(image_tensors_list,0)})

# applying the model, plotting results
model.eval() # set model to evaluate mode
for phase in ['train','val']:
    input_images=image_batches[phase].to(device)
    with torch.set_grad_enabled(False):
        output_images=model(input_images)
    
    # see https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
    input_images_grid=np.transpose(vutils.make_grid(
            input_images,nrow=images_per_row,padding=5,normalize=True).cpu(),(1,2,0))
    output_images_grid=np.transpose(vutils.make_grid(
            output_images,nrow=images_per_row,padding=5,normalize=True).cpu(),(1,2,0))
    
    plt.figure()
    plt.subplot(1,2,1)
    plt.axis('off')
    plt.title('original images')
    plt.imshow(input_images_grid)
    
    plt.subplot(1,2,2)
    plt.axis('off')
    plt.title('reconstructed images')
    plt.imshow(output_images_grid)

    plt.suptitle('%s batch'%phase)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [13]:
#%% saving the model
if offer_mode_saving and train_model_else_load_weights:
    try: os.mkdir(models_folder_path)
    except FileExistsError: pass # if the folder exists already - do nothing
    
    saving_decision=input('save model weights? [y]/n ')
    if saving_decision!='n':
        ui_model_name=input('name model weights file: ')
        model_weights_file_path=os.path.join(models_folder_path,ui_model_name+'.ptweights')
        if os.path.isfile(model_weights_file_path):
            alternative_filename=input('%s already exists, give a different file name to save, the same file name to over-write, or hit enter to abort: '%model_weights_file_path)
            if alternative_filename=='':
                raise RuntimeError('aborted by user')
            else:
                model_weights_file_path=os.path.join(models_folder_path,alternative_filename+'.ptweights')
        torch.save(model.state_dict(),model_weights_file_path)       
        logger.info('%s saved'%model_weights_file_path)

In [14]:
#%%
logger.info('script completed')

2019-06-10 20:31:37 <module> (INFO): script completed
