<a href="https://colab.research.google.com/github/aachen6/deepTC/blob/master/colab/deepTC_net_image.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DeepTC - Post-binding Architecture

The objective of *deepTC* is described on the [deepTC GitHub page](https://github.com/aachen6/deepTC/), and the analysis is outlined below.
1. Data Preprocessing
 - 1.1 [Satellite images and tracks of TC](https://github.com/aachen6/deepTC/blob/master/colab/deepTC_images_tracks_sync.ipynb)
 - 1.2 [Statistics of satellite images and tracks](https://github.com/aachen6/deepTC/blob/master/colab/deepTC_images_tracks_stats.ipynb)

2. Model for TC Image
 - **2.1 [Post-binding architecture of TC image](https://github.com/aachen6/deepTC/blob/master/colab/deepTC_net_image.ipynb)**
 - 2.2 [CNN model for TC image classification](https://github.com/aachen6/deepTC/blob/master/colab/deepTC_classification_cnn5.ipynb)
 - 2.3 [Resnet model for TC image classification](https://github.com/aachen6/deepTC/blob/master/colab/deepTC_classification_resnet.ipynb)
 - 2.4 [Resnet model for TC image intensity estimation](https://github.com/aachen6/deepTC/blob/master/colab/deepTC_intensity_resnet.ipynb)
 - 2.5 [Operation of TC image prediction](https://github.com/aachen6/deepTC/blob/master/colab/deepTC_image_prediction.ipynb)

3. Model for TC Track
 - 3.1 [Post-binding architecture of TC track](https://github.com/aachen6/deepTC/blob/master/colab/deepTC_net_track.ipynb)
 - 3.2 [LSTM model for TC track prediction](https://github.com/aachen6/deepTC/blob/master/colab/deepTC_track_lstm.ipynb)
 - 3.3 [LSTM model with attention for TC track prediction](https://github.com/aachen6/deepTC/blob/master/colab/deepTC_track_lstm.ipynb)
 - 3.4 [LSTM-CNN model for TC track prediction](https://github.com/aachen6/deepTC/blob/master/colab/deepTC_track_lstmcnn.ipynb)

4. Generative Model for TC Image
 - 4.1 [DCGAN model for deepTC](https://github.com/aachen6/deepTC/blob/master/colab/deepTC_generative_dcgan.ipynb)
 - 4.2 [SAGAN model for deepTC](https://github.com/aachen6/deepTC/blob/master/colab/deepTC_generative_sagan.ipynb)
 
Now, the best track and satellite image datasets of historical TCs are ready from the first two notebooks. This notebook covers the architecture of *deepTC*, which is based on *PyTorch*. To explore different deep neural network architectures efficiently, *deepTC* features a post-binding deep neural network architecture: the network is built from a configuration file, without the need to revise the code. Let's start by importing the necessary Python modules and, in particular, installing PyTorch with GPU support on *Google Colab*.

In [0]:
# basics
import os
import yaml
import pickle
import numpy as np
import pandas as pd
from copy import deepcopy

# install an earlier version of pillow, as the latest pillow causes an error
# when loading images from a zip file
!pip install pillow==4.1.1

# handling the images
from PIL import Image
from zipfile import ZipFile

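# install and load pytorch
# note: wheel.pep425tags exists only in older `wheel` releases, and recent Colab
# runtimes ship with PyTorch preinstalled; this cell targets the original setup
from os.path import exists
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag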
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'

!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-1.0.0-{platform}-linux_x86_64.whl torchvision

import torch
torch.backends.cudnn.enabled = False  # cudnn doesn't seem to work

## Post-Binding Deep Neural Network

It is beneficial to test different architectures of deep neural networks. To make this process efficient, let's decouple model construction from the code by post-binding the deep neural network from a configuration file. A *YAML* configuration file defines the architecture of the deep neural network. Two classes construct the network from that file: a static class that maps method names (strings) to PyTorch methods or class instances, and a PyTorch module subclass that generates the model instance. Currently, sequential models, extended with residual blocks/networks, are implemented. The idea extends naturally to more complex deep neural network architectures.

The first class, *PyTorchCall*, is simply a static class that maps PyTorch methods by their string names and passes along the corresponding arguments. Only the methods needed for the current application are included at the moment. The implementation is straightforward, relying on Python's built-in *getattr* function.

~~~python
class PyTorchCall:
    @staticmethod
    def map_torch_call(func_str):
        return getattr(PyTorchCall, '_' + func_str)
    # any pytorch calls to be added below
~~~
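The string-to-constructor dispatch can be tried without PyTorch. The sketch below is a minimal stand-in, not the notebook's code: `Dense` is a hypothetical dataclass playing the role of `nn.Linear`, and `LayerFactory` is a hypothetical counterpart of `PyTorchCall`. The `args`/`kwargs` unpacking mirrors what the `PyTorchCall` static methods do with each configuration entry.

```python
from dataclasses import dataclass


@dataclass
class Dense:
    """Hypothetical stand-in for nn.Linear, used only to show the dispatch."""
    in_features: int
    out_features: int
    bias: bool = True


class LayerFactory:
    """Hypothetical counterpart of PyTorchCall: config string -> constructor."""

    @staticmethod
    def map_call(func_str):
        # 'dense' resolves to LayerFactory._dense; unknown names raise AttributeError
        return getattr(LayerFactory, '_' + func_str)

    @staticmethod
    def _dense(spec):
        # positional and keyword arguments come straight from the config entry
        return Dense(*spec['args'], **spec['kwargs'])


spec = {'args': [8192, 10], 'kwargs': {'bias': False}}
layer = LayerFactory.map_call('dense')(spec)
print(layer)  # Dense(in_features=8192, out_features=10, bias=False)
```

Because the lookup goes through `getattr`, a misspelled layer name in the configuration (for example `relue`) fails immediately with an `AttributeError` when the model is built, rather than silently producing a different network.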

In [0]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class PyTorchCall:

    # will update to unroll function variables using **kwarg
  
    @staticmethod
    def map_torch_call(func_str): return getattr(PyTorchCall, '_' + func_str)

    # pytorch nn calls
    @staticmethod 
    def _linear(args): return nn.Linear(*args['args'], **args['kwargs'])
    @staticmethod
    def _dropout(args): return nn.Dropout(*args['args'], **args['kwargs'])
    @staticmethod
    def _conv2d(args):
        return nn.Conv2d(*args['args'], **args['kwargs'])
    @staticmethod
    def _deconv2d(args):
        return nn.ConvTranspose2d(*args['args'], **args['kwargs'])
    @staticmethod
    def _upsample(args):
        return nn.Upsample(*args['args'], **args['kwargs'])
    @staticmethod
    def _batchnorm2d(args):
        return nn.BatchNorm2d(*args['args'], **args['kwargs'])
    @staticmethod     
    def _avgpool2d(args):
        return nn.AvgPool2d(*args['args'], **args['kwargs'])
    @staticmethod    
    def _maxpool2d(args):
        return nn.MaxPool2d(*args['args'], **args['kwargs'])
    @staticmethod
    def _lstm(args):
        return nn.LSTM(*args['args'], **args['kwargs'])
 
    @staticmethod
    def _relu(args):
        return nn.ReLU(*args['args'], **args['kwargs'])
    @staticmethod
    def _tanh(args):
        return nn.Tanh(*args['args'], **args['kwargs'])
    @staticmethod
    def _sigmoid(args):
        return nn.Sigmoid(*args['args'], **args['kwargs'])
    @staticmethod 
    def _leakyrelu(args):
        return nn.LeakyReLU(*args['args'], **args['kwargs'])

    # pytorch functional
    #@staticmethod
    #def _relu(args): return F.relu
    @staticmethod 
    def _batchnorm(args): return F.batch_norm
    @staticmethod
    def _softmax(args): return F.softmax

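    # placeholders: view layers are reshaped directly in YML2ModelNet.forward
    @staticmethod
    def _view1d(args): return None
    @staticmethod
    def _view2d(args): return None
    
    # placeholders: packing/padding of sequences is handled inside RnnTS
    @staticmethod
    def _pad_packed(args): return None
    @staticmethod
    def _pack_padded(args): return None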
      
    # pytorch loss
    @staticmethod
    def _l1loss(): return nn.L1Loss()
    @staticmethod
    def _mseloss(): return nn.MSELoss() 
    @staticmethod
    def _bceloss(): return nn.BCELoss()
    @staticmethod
    def _crossentropy(): return nn.CrossEntropyLoss() 
		
    # pytorch optimizer
    @staticmethod
    def _sgd(): return optim.SGD
    @staticmethod
    def _adam(): return optim.Adam
		

The second class, *YML2ModelNet*, inherits PyTorch *nn.Module* and generates a PyTorch model instance from the *YAML* configuration file. The configuration file defines each layer of the deep neural network in terms of PyTorch *nn.Module* components. An example of a two-layer convolutional network is shown below:

~~~yaml
models:
  cnn2:
    - layer1-sequential:  # grouped into a sequential block; can be expanded into individual layers
        - conv2d:
            args: [1, 32, 3]
            kwargs: {padding: 1, stride: 1}
        - maxpool2d:
            args: [2]
            kwargs: {padding: 0, stride: 2}
        - relu:
            args: []
            kwargs: {}
    - layer2-sequential:
        - conv2d:
            args: [32, 32, 3]
            kwargs: {padding: 1, stride: 1}
        - maxpool2d:
            args: [2]
            kwargs: {padding: 0, stride: 2}
        - relu:
            args: []
            kwargs: {}
    - layer3-view1d:
        args: []
        kwargs: {}
    - layer4-linear:
        args: [8192, 10]
        kwargs: {}
~~~
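To see how `YML2ModelNet` reads this structure, the sketch below walks the parsed configuration without building any PyTorch modules. It assumes the YAML above loads into nested lists and dicts (written here as a Python literal, so PyYAML is not needed); `walk_layers` and `block` are hypothetical helpers, not part of the notebook. Note that `kwargs` must be a YAML mapping (`{padding: 1}`), since a flow sequence such as `[padding: 1]` loads as a list and cannot be unpacked with `**`.

```python
def block(c_in):
    """One conv -> maxpool -> relu group, as in layer1/layer2 of the example."""
    return [
        {'conv2d': {'args': [c_in, 32, 3], 'kwargs': {'padding': 1, 'stride': 1}}},
        {'maxpool2d': {'args': [2], 'kwargs': {'padding': 0, 'stride': 2}}},
        {'relu': {'args': [], 'kwargs': {}}},
    ]


# Parsed form of the cnn2 example (assumed to match what yaml.safe_load returns)
config = {'models': {'cnn2': [
    {'layer1-sequential': block(1)},
    {'layer2-sequential': block(32)},
    {'layer3-view1d': {'args': [], 'kwargs': {}}},
    {'layer4-linear': {'args': [8192, 10], 'kwargs': {}}},
]}}


def walk_layers(config, model_name):
    """Yield (name, type, ops) using the same key splitting as YML2ModelNet.__init__."""
    for lyr_cfg in config['models'][model_name]:
        lyr_key = list(lyr_cfg.keys())[0]        # e.g. 'layer1-sequential'
        lyr_name, lyr_type = lyr_key.split('-')   # -> 'layer1', 'sequential'
        lyr_args = lyr_cfg[lyr_key]
        if lyr_type == 'sequential':
            ops = [list(row.keys())[0] for row in lyr_args]  # one op per list item
        else:
            ops = [lyr_type]
        yield lyr_name, lyr_type, ops


for name, kind, ops in walk_layers(config, 'cnn2'):
    print(name, kind, ops)
```

Since each key is split on `-` into exactly two parts, layer names themselves must not contain a hyphen.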

The PyTorch module subclass should inherit *nn.Module* and define the following methods:

~~~python
def __init__(self, config, model_name): ...  # define each layer as a module attribute
def forward(self, x): ...                    # forward pass through the network
~~~
The *__init__* method constructs each layer as an attribute of the module, and these attributes are then used in the *forward* method. The *view* operation is included as a layer but handled separately in *forward*; alternatively, it could be wrapped into its own PyTorch *nn.Module* subclass. Separate *nn.Module* subclasses are also created for special or non-sequential network components, which serve as building blocks in the model. Currently, these include an RNN wrapper (*RnnTS*) and a residual block (*ResBlock*) for residual networks. The same idea extends to more complex deep neural network architectures, such as embeddings and attention, which will be covered in later notebooks.
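The `view1d` branch in `forward` flattens every axis except the batch axis, so the first argument of the following `linear` layer must equal channels × height × width. The sketch below traces that arithmetic for the `cnn2` example using the standard output-size formula for `Conv2d`/`MaxPool2d` (dilation 1, floor rounding). The 64×64 single-channel input is our assumption, chosen because it is the input size consistent with `args: [8192, 10]`; the helper names are hypothetical.

```python
import numpy as np


def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of Conv2d/MaxPool2d along one axis (dilation=1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def cnn2_flat_features(height, width, channels=32):
    """Trace two (conv 3x3, pad 1) + (maxpool 2, stride 2) groups, then flatten."""
    for _ in range(2):
        height, width = conv_out(height, 3, 1, 1), conv_out(width, 3, 1, 1)  # conv2d
        height, width = conv_out(height, 2, 2, 0), conv_out(width, 2, 2, 0)  # maxpool2d
    return channels * height * width


# view1d applied to a NumPy stand-in for a (batch, channel, height, width) activation
x = np.zeros((4, 32, 16, 16))
n = int(np.prod(x.shape[1:]))
flat = x.reshape(-1, n)
print(cnn2_flat_features(64, 64), flat.shape)  # 8192 (4, 8192)
```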

In [0]:
class YML2ModelNet(torch.nn.Module):
    
    def __init__(self, config, model_name):
      
        super(YML2ModelNet, self).__init__()
      
        self.layers = []
        cfg_model = config['models'][model_name] 
        
        for lyr_cfg in cfg_model:
          
            # get current layer name, type, and arguments
            lyr_key = list(lyr_cfg.keys())[0]
            [lyr_name, lyr_type] = lyr_key.split('-')
            
            # get layer argument in place or outside model scope            
            lyr_args = lyr_cfg[lyr_key]
            if lyr_args=='None': lyr_args = config[lyr_key]            
                
            if lyr_type in ['rnn','gru', 'lstm']:
                n_hidden = lyr_args['args'][1] 
                n_layers = lyr_args['kwargs']['num_layers']
                lyr_ts = self.yml2lyr(lyr_type, lyr_args)
                lyr_inst = RnnTS(lyr_type, n_hidden, n_layers, lyr_ts)
              
            elif lyr_type=='resblock':
                key_args = {}
                for _key, _args in lyr_args.items():
                    [_name, _type] = _key.split('-')
                    key_args[_name] = [_type, _args]
                residual = self.yml2lyr(*key_args['residual'])
                identity = self.yml2lyr(*key_args['identity'])
                activate = self.yml2lyr(*key_args['activate'])
                lyr_inst = ResBlock(residual, identity, activate)
              
            else:  # nn module layer
                lyr_inst = self.yml2lyr(lyr_type, lyr_args)
              
            # register layer to the class 
            setattr(self, lyr_name, lyr_inst)
            lyr_ref = getattr(self, lyr_name)
            self.layers.append([lyr_type, lyr_args, lyr_ref]) 
                       
            
    def yml2lyr(self, lyr_type, lyr_args):
            
        if lyr_type=='sequential':
            modules = []
            for row in lyr_args:
                r_func = list(row.keys())[0]
                r_args = row[r_func]                  
                r_module = PyTorchCall.map_torch_call(r_func)(r_args)
                modules.append(r_module)
            lyr_obj = nn.Sequential(*modules)
            
        else: # individual nn.module
            lyr_obj = PyTorchCall.map_torch_call(lyr_type)(lyr_args) 
       
        return lyr_obj             
              

    def forward(self, x, seq_mask=None):      
        # implementation of Module forward method
        for i, (lyr_type, ly_arg, lyr_ref) in enumerate(self.layers):

            if lyr_type=='view1d': 
                n = int(np.prod(x.size()[1:]))
                x = x.view(-1, n)
              
            elif lyr_type=='view2d':
                kwargs = ly_arg['kwargs']
                x = x.view(x.size()[0], kwargs['channel'], kwargs['size'], kwargs['size'])
                
            elif lyr_type in ['rnn', 'gru', 'lstm']:
                x = lyr_ref(x, seq_mask)
                
            else: # nn module and subclass
                x = lyr_ref(x)
                
        return x              
            

      
class RnnTS(torch.nn.Module):              
    def __init__(self, lyr_type, n_hidden, n_layer, lyr_obj):
      
        super(RnnTS, self).__init__()
        
        self.layer = lyr_obj
        self.lyr_type = lyr_type
        self.n_hidden = n_hidden
        self.n_layer  = n_layer
              
    def init_hidden(self, batch_size, device):
      
        if self.lyr_type == 'lstm':  # currently, only LSTM model is implemented
            self.h = torch.zeros(self.n_layer, batch_size, self.n_hidden)
            self.c = torch.zeros(self.n_layer, batch_size, self.n_hidden)
            
        self.h = self.h.to(device)
        self.c = self.c.to(device)
        
        return
      
    def forward(self, x, seq_mask=None):
      
        batch_size = x.size()[0]
        device = x.get_device() if x.is_cuda else 'cpu'  
    
        if seq_mask is not None: # mask for sequences with different length       
            x = nn.utils.rnn.pack_padded_sequence(x, seq_mask, batch_first=True)   
       
        self.init_hidden(batch_size, device)
        x, (self.h, self.c) = self.layer(x, (self.h, self.c))
        
        if seq_mask is not None:
            x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)                

        return x          
      
      
class ResBlock(torch.nn.Module):
   
    def __init__(self, residual, identity, activate):
      
        super(ResBlock, self).__init__()
        
        self.residual = residual  # sequential residual mapping
        self.identity = identity  # identity mapping of the input
        self.activate = activate  # activation
        
    def forward(self, x):
 
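        residual = self.residual(x)
        # add out of place: an in-place `+=` can overwrite a tensor autograd saved
        # for backward (e.g. when the identity branch passes x through unchanged)
        x = self.identity(x) + residual
        x = self.activate(x)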
                
        return x            
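
In `YML2ModelNet`, a `resblock` entry is a mapping whose keys are split on `-` into the `residual`, `identity`, and `activate` sub-layers handed to `ResBlock`. The arithmetic of `ResBlock.forward`, out = activate(identity(x) + residual(x)), can be checked with NumPy; in this sketch the residual branch is an assumed linear map and the identity branch a pass-through, both hypothetical choices for illustration.

```python
import numpy as np


def relu(v):
    return np.maximum(v, 0.0)


def res_block(x, residual, identity, activate):
    """Residual block forward: activate(identity(x) + residual(x)), computed out of place."""
    return activate(identity(x) + residual(x))


rng = np.random.default_rng(0)
W = rng.normal(size=(4, 4))   # weights of the hypothetical residual branch
x = rng.normal(size=(2, 4))   # batch of 2 samples, 4 features each
x_before = x.copy()

out = res_block(x, residual=lambda v: v @ W, identity=lambda v: v, activate=relu)
print(out.shape)  # (2, 4)
```

Computing the sum out of place, as `res_block` does, leaves the input array untouched even with a pass-through identity. When input and output shapes differ, the identity branch would instead be a projection (for example a 1×1 convolution).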
    

## Storm Dataset Class

A lot of effort in solving any machine learning problem goes into data preparation. To simplify the process, the dataset is expected to be a *Pandas DataFrame*, with the column names of the input and target variable(s) specified in the configuration file; image data are referenced by file path. *ImageDataSet* inherits the PyTorch *Dataset* class and handles data preparation before the data are fed into training. According to the PyTorch documentation, a custom dataset class should inherit *Dataset* and override the following methods:

~~~python
def __len__(self): ...           # so that len(dataset) returns the size of the dataset
def __getitem__(self, idx): ...  # support indexing, so that dataset[i] returns the i-th sample
~~~

The *__getitem__* method returns the sample index together with the training input and target; the index can be used later during post-processing, e.g. to identify misclassified images from the confusion matrix. A data-split method is also implemented to split the data into training, validation, and test sets. As shown earlier, the number of samples per class is not well balanced; therefore, a label-aware option is included, which can be enabled to preserve the per-class sample ratios across the training, validation, and test sets.
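
The label-aware option amounts to a stratified split: each class is shuffled and cut separately, so every subset keeps roughly the original class proportions. A minimal NumPy sketch of the idea follows; `stratified_split` is a hypothetical helper, and the class labels and counts are made up for illustration.

```python
from collections import Counter

import numpy as np


def stratified_split(labels, frac, seed=64):
    """Return (part1, part2) positions with about `frac` of every class in part1."""
    labels = np.asarray(labels)
    part1, part2 = [], []
    for cat in sorted(set(labels.tolist())):
        idx = np.flatnonzero(labels == cat)   # positions of this class
        rng = np.random.RandomState(seed)     # reseeded per class, as random_split does
        rng.shuffle(idx)
        cut = int(np.floor(len(idx) * frac))
        part1.extend(idx[:cut].tolist())
        part2.extend(idx[cut:].tolist())
    return part1, part2


labels = ['TS'] * 80 + ['H1'] * 15 + ['H5'] * 5   # hypothetical, imbalanced classes
train, valid = stratified_split(labels, 0.8)
print(Counter(labels[i] for i in train), Counter(labels[i] for i in valid))
```

Reseeding per class makes each class's split reproducible regardless of the order in which the classes are visited.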




In [0]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from torch.utils.data.sampler import SubsetRandomSampler

class ImageDataSet(Dataset):
    
    def __init__(self, config, transform=None, cats=None, hotstart=False):
        
        self.model_object = config['model_object']
        self.batch_size   = config['batch_size']
        self.col_input    = config['col_input']
        self.col_target   = config['col_target']
        self.n_workers    = config['n_workers']
        
        # read input files
        f_storm_msg = config['f_storm_msg']
        self.pd_storm = pd.read_msgpack(f_storm_msg)
        
        # load image archive
        if self.col_input=='image':
            f_image_zip = config['f_image_zip']
            self.img_archv = ZipFile(f_image_zip, 'r')
            # define transformation    
            self.transform = transform
        
        # train valid test split
        f_data_yml = config['f_data_yml']
        if hotstart: # read from file
            with open(f_data_yml, 'r') as fp:
                data = yaml.safe_load(fp)  # newer PyYAML requires an explicit Loader for yaml.load
                self.data_indices = data['indices']
                if 'class' in self.model_object:
                    self.one_hot_key = data['one_hot_key'] 
                    self.one_hot_rev = data['one_hot_rev'] 
                    
        else: # create a new split
            if 'class' in self.model_object:
                # define one hot key map
                self.one_hot_key = {}
                self.one_hot_rev = {}
                if cats is None:
                    cats = set(self.pd_storm[self.col_target].tolist())
                    cats = sorted(list(cats))
                for i, cat in enumerate(sorted(cats)):
                    self.one_hot_key[cat] = i
                    self.one_hot_rev[i]   = cat
            
            # train validation split
            self.data_indices = self.train_valid_test(config)

            data = {} # save dataset
            data['indices'] = self.data_indices  
            if 'class' in self.model_object:
                data['one_hot_key'] = self.one_hot_key
                data['one_hot_rev'] = self.one_hot_rev            
                      
            with open(f_data_yml, 'w') as fp: 
                yaml.dump(data, fp)
            
        # summary of dataset
        valid_pct = 1. - config['valid_pct']
        test_pct  = 1. - (config['test_pct'] or 0.)  # no test set when test_pct is None
        n_train   = len(self.data_indices['train'])
        n_valid   = len(self.data_indices['valid'])
        n_test    = len(self.data_indices['test'] or [])
        batch     = self.batch_size
        
        divider = '-' * 36
        header  = '{:<10s}{:>10s}{:>10s}{:>10s}'
        record1 = '{:<10s}{:>10.2f}{:>10.2f}{:>10.2f}'
        record2 = '{:<10s}{:>10d}{:>10d}{:>10d}'

        print (divider)
        print ('summary of dataset')
        print (divider)
        print (header.format(' ', 'train', 'valid', 'test')) 
        print (record1.format('percent', test_pct*valid_pct, test_pct*(1-valid_pct), 1-test_pct))
        print (record2.format('size', n_train, n_valid, n_test))
        print (record2.format('batch', int(n_train/batch), int(n_valid/batch), int(n_test/batch)))

        return            
        
        
    def __len__(self):
        return len(self.pd_storm)
            

    def __getitem__(self, idx):
 
        col_input  = self.col_input
        col_target = self.col_target
        row = self.pd_storm.iloc[idx]  # positional: assumes a default RangeIndex so labels equal positions

        # input
        if self.col_input=='image':
            image = row[col_input]
            temp = image.split('.')[0].split('_')
            f_image = temp[0] + '_' + temp[1] + '.jpg'
            sample = Image.open(self.img_archv.open(f_image))
            if self.transform is not None: 
                sample = self.transform(sample)
        else:
            sample = torch.FloatTensor(row[col_input].values)
          
        # target
        if 'class' in self.model_object:
            cat = row[col_target]          
            target = self.one_hot_key[cat]   # no need to create one hot for pytorch
        
        if self.model_object=='regression':
            if isinstance(col_target, list):
                target = torch.FloatTensor(row[col_target].values)
            else:
                target = torch.FloatTensor([row[col_target]])
                
        return idx, sample, target 
		
    
    def random_split(self, indices, pct, label_aware=None, shuffle=True, seed=64):
      
        # creating data indices for two splits:
        indices_1 = []  # first half of the indices
        indices_2 = []  # second half of the indices
        if label_aware is not None:
            pd_sub = self.pd_storm.iloc[indices]
            cats = set(self.pd_storm[label_aware].tolist())
            for cat in cats:
                sub_indices = pd_sub[pd_sub[label_aware]==cat].index.tolist()
                if shuffle: 
                    np.random.seed(seed)
                    np.random.shuffle(sub_indices)
                isplit = int(np.floor(len(sub_indices)*pct))            
                indices_1 = indices_1 + sub_indices[:isplit]
                indices_2 = indices_2 + sub_indices[isplit:]
        else:
            if shuffle: 
                np.random.seed(seed)
                np.random.shuffle(indices)
            isplit = int(np.floor(len(indices)*pct))            
            indices_1 = indices_1 + indices[:isplit]
            indices_2 = indices_2 + indices[isplit:]
                
        return indices_1,  indices_2
                       
        
    def train_valid_test(self, config):
        
        data_indices = {} 
        valid_pct = 1. - config['valid_pct']
        test_pct  = None if config['test_pct'] is None else 1. - config['test_pct']
        label_aware = config['label_aware']
        shuffle = config['shuffle'] 
        seed = config['seed'] #+100
        
        indices = self.pd_storm.index.tolist()
        
        if config['test_pct'] is None:
            test_indices = None
            train_indices, valid_indices = self.random_split(indices, valid_pct, label_aware, shuffle, seed)
        else:
            _indices, test_indices = self.random_split(indices, test_pct, label_aware, shuffle)
            train_indices, valid_indices = self.random_split(_indices, valid_pct, label_aware, shuffle, seed) 
            
        data_indices['train'] = train_indices
        data_indices['valid'] = valid_indices
        data_indices['test'] = test_indices
            
        return data_indices

      
    def load_data(self):
      
        data_split = {} 
        
        batch_size = self.batch_size
        n_workers = self.n_workers
       
        train_indices = self.data_indices['train']
        valid_indices = self.data_indices['valid']
        test_indices  = self.data_indices['test']
            
        train_sampler = SubsetRandomSampler(train_indices)
        valid_sampler = SubsetRandomSampler(valid_indices)

        data_split['train'] = DataLoader(self, batch_size=batch_size, sampler=train_sampler, num_workers=n_workers)
        data_split['valid'] = DataLoader(self, batch_size=batch_size, sampler=valid_sampler, num_workers=n_workers)        
        
        if test_indices is None:
            data_split['test'] = None 
        else:
            test_sampler = SubsetRandomSampler(test_indices)  
            data_split['test'] = DataLoader(self, batch_size=batch_size, sampler=test_sampler, num_workers=n_workers)
        
        return data_split
      
     
    def normalization_factor(self, sample_a, sample_b):
      
        (n_a, mean_a, std_a) = sample_a
        (n_b, mean_b, std_b) = sample_b
  
        n_c = n_a + n_b
        mean_c = n_a*mean_a + n_b*mean_b
        mean_c = mean_c/n_c
  
        numerator = (n_a-1)*std_a**2. + (n_b-1)*std_b**2. + \
                    n_a*(mean_a-mean_c)**2. + n_b*(mean_b-mean_c)**2.
  
        denominator = n_c - 1
  
        std_c = np.sqrt(numerator/denominator)
  
        return np.array([n_c, mean_c, std_c])
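
`normalization_factor` is not described in the text. Reading the code, it merges the count, mean, and sample standard deviation of two subsets into those of their union, which would let normalization statistics be accumulated subset by subset without holding all the data at once. The sketch below re-implements the same formula as a free function (`pooled_stats`, a hypothetical name) and checks it against NumPy on synthetic data; it assumes `std` is the sample standard deviation (`ddof=1`), which is what the `n - 1` terms imply.

```python
import numpy as np


def pooled_stats(sample_a, sample_b):
    """Combine (n, mean, sample std) of two subsets, same formula as normalization_factor."""
    n_a, mean_a, std_a = sample_a
    n_b, mean_b, std_b = sample_b
    n_c = n_a + n_b
    mean_c = (n_a * mean_a + n_b * mean_b) / n_c
    # within-subset scatter plus between-subset scatter around the pooled mean
    numerator = ((n_a - 1) * std_a**2 + (n_b - 1) * std_b**2
                 + n_a * (mean_a - mean_c)**2 + n_b * (mean_b - mean_c)**2)
    return n_c, mean_c, np.sqrt(numerator / (n_c - 1))


def summary(v):
    return v.size, v.mean(), v.std(ddof=1)


rng = np.random.default_rng(1)
a = rng.normal(100.0, 5.0, size=300)   # synthetic values for one subset
b = rng.normal(120.0, 8.0, size=500)   # synthetic values for another subset
n_c, mean_c, std_c = pooled_stats(summary(a), summary(b))
both = np.concatenate([a, b])
print(n_c, np.isclose(mean_c, both.mean()), np.isclose(std_c, both.std(ddof=1)))
```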

## Image Training Class

This is where things start to get interesting. The trainer class links everything together and performs the training that optimizes the network against the loss objective. When the trainer is initialized, parameters such as the number of epochs, batch size, and loss function are read from the configuration file. Some state parameters are also initialized to record the training state for assessing training and model performance, such as *train_batch_loss* and *train_batch_accu*. The model net is passed to the trainer, and the loss function and optimizer are initialized from the configuration file. Finally, we simply loop through the data iterator, feed the inputs to the network, and optimize.
~~~python
for i_epoch in range(self.max_epochs):
    for i_batch, (_, images, labels) in enumerate(data['train']): 
        self.optimizer.zero_grad()  # set the gradient to zero 
        predicts = self.model(images)  # make prediction
        loss = self.criterion(predicts, labels)  # calculate loss 
        loss.backward() # backpropagation to get the weight update
        self.optimizer.step() # update weight using the optimizer
~~~
During training, we follow an approach similar to [this tutorial](https://github.com/GokuMohandas/practicalAI/blob/master/notebooks/11_Convolutional_Neural_Networks.ipynb) by Goku Mohandas to calculate the running epoch loss and accuracy.
~~~python
batch_accu = self.accuracy(predicts, labels)
epoch_accu += (batch_accu - epoch_accu) / (i_batch + 1)
~~~              
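The update above is an incremental mean: after batch k, `epoch_accu` equals the plain average of the first k batch values, so no list of batch results needs to be kept. A quick check with made-up batch losses (`running_mean` is a hypothetical helper):

```python
def running_mean(values):
    """Incremental mean, m_k = m_{k-1} + (v_k - m_{k-1}) / k, as in the epoch loop."""
    mean = 0.0
    for i, v in enumerate(values):
        mean += (v - mean) / (i + 1)
    return mean


batch_losses = [0.9, 0.7, 0.65, 0.5]   # hypothetical per-batch losses
print(running_mean(batch_losses), sum(batch_losses) / len(batch_losses))
```

Every batch carries the same weight, so a smaller final batch counts as much as a full one; weighting by batch size would give the exact per-sample mean.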
Depending on the number of parameters, training may take a long time. It is cost-efficient to stop early if the proposed architecture does not work well for the problem. To check the progress and performance of training while it is running, two methods are implemented: one shows a progress bar for each epoch, and the other dynamically visualizes the batch and running epoch loss and accuracy. Both methods embed HTML in the notebook that refreshes dynamically during training.

~~~python
from IPython.display import HTML

def html_loss_plot(self, image):
    return HTML("<img src='{0}'/>".format(image))

def html_progress(self, var, value, max=100):
    return HTML("""{var}: <progress value='{value}' max='{max}' style='width: 80%'>{value}
                            </progress>""".format(var=var, value=value, max=max))
 ~~~
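The rule that sets `stop_early` lives in `update_save_state`, which is not shown in this part of the notebook. As an assumption, a common choice is a patience rule on the validation loss. The sketch below uses the notebook's state keys `epoch_index`, `best_epoch`, `stop_step`, and `stop_early`, while `best_loss`, `patience`, and `tol` are hypothetical additions.

```python
def update_early_stop(state, valid_loss, patience=3, tol=1e-4):
    """Patience rule (an assumption; the notebook's update_save_state is not shown)."""
    best = state.get('best_loss', float('inf'))
    if valid_loss < best - tol:           # meaningful improvement: reset the counter
        state['best_loss'] = valid_loss
        state['best_epoch'] = state['epoch_index']
        state['stop_step'] = 0
    else:                                 # no improvement this epoch
        state['stop_step'] += 1
    state['stop_early'] = state['stop_step'] >= patience
    return state


state = {'epoch_index': 0, 'best_epoch': -1, 'stop_step': 0, 'stop_early': False}
for epoch, loss in enumerate([1.0, 0.8, 0.79995, 0.81, 0.82, 0.83]):
    state['epoch_index'] = epoch
    update_early_stop(state, loss)
    if state['stop_early']:
        break
print(state['best_epoch'], state['epoch_index'])
```

The tolerance keeps tiny fluctuations (such as 0.8 to 0.79995 above) from resetting the patience counter.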
During training, the state parameters are saved to a file for post-processing and/or hot-starting the training. The model's state dict is saved as well, using the *torch.save* method.



In [0]:
import io
import base64
import matplotlib.pyplot as plt
from IPython.display import HTML
from IPython.display import display

class ImageTrainer(object):
    def __init__(self, params, model, hotstart=False):
        # CUDA for PyTorch
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.use_cuda else 'cpu')
        if self.use_cuda:
            divider = '-' * 36
            print(divider)
            print('summary of GPU')
            print(divider)
            print(torch.cuda.get_device_name(0))
            print('Memory Usage:')
            print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
            print('Cached:   ', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')
        else:
            print('training with cpu')   
 
        # model objective
        self.model_object = params['model_object']
        self.col_target = params['col_target']

        # hyper params
        self.max_epochs = params['max_epochs']
        self.batch_size = params['batch_size']
        self.r_learning = params['r_learning']
        self.loss_func  = params['loss_func']
        self.optim_func = params['optimizer']
 
        # path for output
        self.f_state    = params['f_state_yml']
        self.f_model    = params['f_model_pth']
        self.f_test     = params['f_test_yml']
        
        # training state 
        self.state = {'stop_early':   False,
                      'stop_criteria': 99.9,
                      'stop_step':    0,
                      'epoch_index':  0,
                      'best_epoch':   -1,
                      'best_accu' :   -1,
                      'test_loss':    -1,
                      'test_accu':    -1,
                      'train_epoch_loss': [],
                      'train_epoch_accu': [],
                      'train_batch_loss': [],
                      'train_batch_accu': [],
                      'valid_epoch_loss': [],
                      'valid_epoch_accu': [],
                      'valid_batch_loss': [],
                      'valid_batch_accu': []}

        # model
        self.model = model.to(self.device)
        # loss
        self.criterion = PyTorchCall.map_torch_call(self.loss_func)()
        # optimizer
        self.optimizer = PyTorchCall.map_torch_call(self.optim_func)()
        self.optimizer = self.optimizer(model.parameters(), lr=self.r_learning)
        
        if hotstart: # hotstart 
            model_state = torch.load(params['f_model_pth'])
            self.model.load_state_dict(model_state['model_state_dict'])
            for key, value in self.state.items():
                self.state[key] = model_state[key]
    

    def train_loop(self, data):
        
        divider = '-' * 36
        print (divider)
        print ('training')
        print (divider)
        
        loss_plot = display(self.html_loss_plot('PLOT'), display_id=True)
        
        # loop over epochs
        for i_epoch in range(self.max_epochs):
            
            self.state['epoch_index'] = i_epoch
            
            # training
            epoch_loss = 0.
            epoch_accu = 0.
            self.model.train()
            bar_train = display(self.html_progress('Train', 0, 100), display_id=True)
            for i_batch, (_, inputs, targets) in enumerate(data['train']):

                # transfer to GPU
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                
                # model computations
                self.optimizer.zero_grad()
                predicts = self.model(inputs)
 
                loss = self.criterion(predicts, targets)
                batch_loss = loss.item()
                epoch_loss += (batch_loss - epoch_loss) / (i_batch + 1)
                
                loss.backward()
                self.optimizer.step()
                
                batch_accu = self.accuracy(predicts, targets)
                epoch_accu += (batch_accu - epoch_accu) / (i_batch + 1)
                
                self.state['train_batch_loss'].append(batch_loss)
                self.state['train_batch_accu'].append(batch_accu)

                pct_done = (i_batch+1)/len(data['train'])*100
                bar_train.update(self.html_progress('Train', pct_done, 100))

            self.state['train_epoch_loss'].append(epoch_loss)
            self.state['train_epoch_accu'].append(epoch_accu)
  
            # validation
            epoch_loss = 0.
            epoch_accu = 0.
            self.model.eval()
            bar_valid = display(self.html_progress('Valid', 0, 100), display_id=True)
            with torch.no_grad():  # no gradients are needed for validation; saves memory
                for i_batch, (_, inputs, targets) in enumerate(data['valid']):
                    # transfer to GPU
                    inputs = inputs.to(self.device)
                    targets = targets.to(self.device)

                    # model computations
                    predicts = self.model(inputs)

                    loss = self.criterion(predicts, targets)
                    batch_loss = loss.item()
                    epoch_loss += (batch_loss - epoch_loss) / (i_batch + 1)

                    batch_accu = self.accuracy(predicts, targets)
                    epoch_accu += (batch_accu - epoch_accu) / (i_batch + 1)

                    self.state['valid_batch_loss'].append(batch_loss)
                    self.state['valid_batch_accu'].append(batch_accu)

                    pct_done = (i_batch+1)/len(data['valid'])*100
                    bar_valid.update(self.html_progress('Valid', pct_done, 100))
                
            self.state['valid_epoch_loss'].append(epoch_loss)
            self.state['valid_epoch_accu'].append(epoch_accu)
                        
            # epoch summary
            header   = '{:<12s}{:>10s}{:>10s}'
            record1  = '{:<12s}' + "{:>10.3f}"*2

            if i_epoch%1==0:
                print (divider)
                print ('summary of epoch:', i_epoch)
                print (divider)
                print (header.format('loss - ', 'train', 'valid'))
                print (record1.format(' ', self.state['train_epoch_loss'][-1], self.state['valid_epoch_loss'][-1]))
                print ('accuracy/error - ')
                if isinstance(self.col_target, list):
                    # accuracy may be a single value or one value per target column
                    for split in ('train', 'valid'):
                        row = self.state[split + '_epoch_accu'][-1]
                        if isinstance(row, float): row = [row]
                        record2 = '{:<12s}' + "{:>10.3f}"*len(row)
                        print (record2.format(split, *row))
                else:
                    print (record1.format(' ', self.state['train_epoch_accu'][-1], self.state['valid_epoch_accu'][-1]))
                
                uri = self.update_loss_plot()
                loss_plot.update(self.html_loss_plot(uri))
                print (' ')
            
            self.update_save_state()
            if self.state['stop_early']: break
             
        
    def test_loop(self, data, apply_softmax=False):
        total = 0
        correct = 0.   # sample-weighted sum of batch metrics, averaged after the loop
        test_idxs, test_targets, test_predicts = [], [], []
        self.model.eval()
        with torch.no_grad():
            for i_batch, (idxs, inputs, targets) in enumerate(data['test']):

                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                predicts = self.model(inputs)
                if apply_softmax: predicts = F.softmax(predicts, 1)

                # weight each batch by its size so a smaller last batch is not over-counted
                n_batch = int(targets.size(0))
                total += n_batch
                correct += self.accuracy(predicts, targets) * n_batch

                # for classification keep the predicted class index instead of the scores
                if 'class' in self.model_object:
                    _, predicts = torch.max(predicts, 1)

                test_idxs.append(idxs)
                test_targets.append(targets)
                test_predicts.append(predicts)

        # accuracy (classification) or mean absolute error (regression) over all samples
        correct /= total
        test_idxs = torch.cat(test_idxs).cpu().numpy()
        test_targets = torch.cat(test_targets).cpu().numpy()
        test_predicts = torch.cat(test_predicts).cpu().numpy()
       
        # test summary
        divider = '-' * 36
        print(divider)
        print('summary of test')
        print(divider)
        print('{:<10s}{:>10s}{:>10s}'.format('', 'total', 'accuracy'))
        print('{:<10s}{:>10d}{:>10.3f}'.format('test', total, correct))    
        
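        # save test results; arrays are written as plain lists so the YAML stays readable
        test_results = {}
        test_results['idxs'] = test_idxs
        test_results['labels'] = test_targets
        test_results['predicts'] = test_predicts
        test_results['accuracy'] = correct

        with open(self.f_test, 'w') as fp:
            yaml.dump({k: (v.tolist() if isinstance(v, np.ndarray) else v)
                       for k, v in test_results.items()}, fp)

        return test_results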
      
        
    def accuracy(self, predicts, targets):
      
        if 'class' in self.model_object: 
            _, predicts_indices = predicts.max(dim=1)
            n_correct = torch.eq(predicts_indices, targets).sum().item()
            return n_correct / len(predicts_indices) * 100

        if self.model_object=='regression':
            accu = torch.abs(predicts - targets).mean().item()
            return accu
          
   
    def html_loss_plot(self, image):
        
        h = HTML("<img src='{0}'/>".format(image))
    
        return h

    
    def html_progress(self, var, value, max=100):
      
        h = HTML("""{var}: <progress value='{value}' max='{max}' style='width: 80%'>{value}
                           </progress>""".format(var=var, value=value, max=max))
    
        return h
       
      
    def update_loss_plot(self):
        
        train_batch_loss = self.state['train_batch_loss']
        train_batch_accu = self.state['train_batch_accu']
        train_epoch_loss = self.state['train_epoch_loss']
        train_epoch_accu = self.state['train_epoch_accu']

        ntb = len(train_batch_loss)
        nte = len(train_epoch_loss)
        nnn = ntb/nte 
        xtb = np.arange(ntb)/nnn 
        xte = np.arange(nte, dtype=np.int16)

        valid_batch_loss = self.state['valid_batch_loss']
        valid_batch_accu = self.state['valid_batch_accu']
        valid_epoch_loss = self.state['valid_epoch_loss']
        valid_epoch_accu = self.state['valid_epoch_accu']

        nvb = len(valid_batch_loss)
        nve = len(valid_epoch_loss)
        nnn = nvb/nve 
        xvb = np.arange(nvb)/nnn 
        xve = np.arange(nve, dtype=np.int16)

        n = 2;  m = 2  # n rows x m columns of panels (one column per target feature)
        if isinstance(self.col_target, list):
            m = max(len(self.col_target), m)
            n = 3

        fig, axes = plt.subplots(n, m, figsize=(12,8))
        # loss: per batch (x axis in fractional epochs) and per epoch
        axes[0,0].plot(xtb, train_batch_loss, label='train')
        axes[0,0].plot(xvb, valid_batch_loss, label='valid')
        axes[0,0].set_title('batch loss')
        axes[0,0].legend()
        axes[0,1].plot(xte, train_epoch_loss)
        axes[0,1].plot(xve, valid_epoch_loss)
        axes[0,1].set_title('epoch loss')
        # accuracy
        if isinstance(self.col_target, list):
            # one column per feature; a scalar accuracy per epoch gives a single column
            train_accu = np.array(train_epoch_accu).reshape(nte, -1)
            valid_accu = np.array(valid_epoch_accu).reshape(nve, -1)
            for i in range(train_accu.shape[1]): # feature
                # train
                #axes[1, i].plot(xtb, np.array(train_batch_accu)[:, i])
                axes[1, i].plot(xte, train_accu[:, i])
                # validation
                #axes[2, i].plot(xvb, np.array(valid_batch_accu)[:, i])
                axes[2, i].plot(xve, valid_accu[:, i])

        else: # single feature prediction or classification
            axes[1,0].plot(xtb, train_batch_accu)
            axes[1,0].plot(xvb, valid_batch_accu)
            axes[1,1].plot(xte, train_epoch_accu)
            axes[1,1].plot(xve, valid_epoch_accu)

        # encode the figure as a base64 PNG data URI for the <img> tag
        bio = io.BytesIO()
        fig.savefig(bio, format='png')
        uri = 'data:image/png;base64,' + base64.b64encode(bio.getvalue()).decode()

        plt.close(fig)

        return uri
      
      
    def update_save_state(self):

        # initialize the best score on the first epoch
        if self.state['epoch_index']==0:
            if 'class' in self.model_object: self.state['best_accu'] = 0.
            if self.model_object=='regression': self.state['best_accu'] = 999.
            self.state['best_epoch'] = self.state['epoch_index']

        cur_accu = self.state['valid_epoch_accu'][-1]

        is_improve = True
        # classification accuracy is the percentage of correct predictions: higher is better
        if 'class' in self.model_object and self.state['best_accu']>cur_accu: is_improve = False
        # regression accuracy is the mean absolute error: lower is better
        if self.model_object=='regression' and self.state['best_accu']<cur_accu: is_improve = False

        if is_improve:
            self.state['best_accu'] = cur_accu
            self.state['best_epoch'] = self.state['epoch_index']
            # save the model together with a copy of the training state
            state_cp = deepcopy(self.state)
            state_cp['model_state_dict'] = self.model.state_dict()
            state_cp['optim_state_dict'] = self.optimizer.state_dict()
            torch.save(state_cp, self.f_model)

        # save state after best_accu/best_epoch are updated so the file is current
        with open(self.f_state, 'w') as fp: yaml.dump(self.state, fp)
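The trainer averages per-batch loss and accuracy with an incremental mean, `m += (x - m) / (i + 1)`, and keeps the checkpoint whose validation metric is best (highest accuracy for classification, lowest mean absolute error for regression). The pure-Python sketch below restates both rules outside PyTorch; the helper names `running_mean`, `is_improvement` and `best_epoch` are hypothetical and not part of deepTC. Note that the incremental mean weights every batch equally, so a smaller final batch counts as much as a full one.

```python
def running_mean(values):
    """Incremental mean, as in epoch_loss += (batch_loss - epoch_loss) / (i_batch + 1)."""
    mean = 0.0
    for i, x in enumerate(values):
        mean += (x - mean) / (i + 1)
    return mean


def is_improvement(model_object, best, current):
    """Selection rule of update_save_state (ties count as an improvement).

    Classification accuracy (percent correct) is better when higher;
    regression accuracy (mean absolute error) is better when lower.
    """
    if 'class' in model_object:
        return current >= best
    if model_object == 'regression':
        return current <= best
    raise ValueError('unknown model_object: %r' % model_object)


def best_epoch(model_object, valid_epoch_accu):
    """Return (epoch index, metric) of the checkpoint that would be kept."""
    best = 0. if 'class' in model_object else 999.   # same initial values as the trainer
    best_idx = 0
    for idx, accu in enumerate(valid_epoch_accu):
        if is_improvement(model_object, best, accu):
            best, best_idx = accu, idx
    return best_idx, best


print(running_mean([2.0, 4.0, 9.0]))                      # 5.0, the arithmetic mean
print(best_epoch('classification', [60.0, 72.5, 70.0]))   # (1, 72.5)
print(best_epoch('regression', [5.2, 3.1, 3.4]))          # (1, 3.1)
```

Because the rule compares only against the best value seen so far, the saved `.pth` file always holds the best epoch, not the last one.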
            

## Storm Inference Class

Once the model is trained and tuned, this class builds the network from the configuration file, loads the trained weights from the saved PyTorch checkpoint (state dictionary), and then uses the model to make predictions on new samples.

class ImageInference(object):

    def __init__(self, config, model_name):
        # CUDA for PyTorch
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.use_cuda else 'cpu')

        # Model: build the network from the config, then load the trained weights;
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
        model = YML2ModelNet(config, model_name)
        model_state = torch.load(config['params']['f_model_pth'], map_location=self.device)
        model.load_state_dict(model_state['model_state_dict'])
        self.model = model.to(self.device)

    def inference(self, imgs, apply_softmax=False):
        # evaluation mode (dropout off, batch norm uses running statistics), no gradients
        self.model.eval()
        with torch.no_grad():
            imgs = imgs.to(self.device)
            predicts = self.model(imgs)
            # optionally turn class scores into probabilities
            if apply_softmax: predicts = F.softmax(predicts, 1)

        return predicts.cpu().numpy()
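`inference` (like `test_loop`) can apply `F.softmax(predicts, 1)` to turn class scores into probabilities. Softmax preserves the order of the scores within each row, since the ratio of two probabilities is `exp(z_i - z_j)`, so the arg-max, and therefore the predicted class and the classification accuracy, is the same with or without it. The small NumPy sketch below illustrates this on made-up logits; the `softmax` helper is written here for illustration and is not PyTorch's implementation.

```python
import numpy as np


def softmax(logits, axis=1):
    """Row-wise softmax, playing the role of F.softmax(x, 1) for a (batch, classes) array."""
    # subtracting the row max does not change the result but keeps exp() from overflowing
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# made-up scores for 3 images and 4 classes
logits = np.array([[2.0, 1.0, 0.1, -1.0],
                   [0.0, 0.0, 3.0, 0.5],
                   [1000.0, 999.0, 0.0, 0.0]])

probs = softmax(logits)
print(probs.sum(axis=1))     # each row sums to 1 (up to rounding)
print(probs.argmax(axis=1))  # [0 2 0], identical to logits.argmax(axis=1)
```

Use `apply_softmax=True` when calibrated class probabilities are needed (for example, to report a confidence); for the class label alone, the raw scores are sufficient.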
      