## Install & Import

In [None]:
!pip install import_ipynb

Collecting import_ipynb
  Downloading https://files.pythonhosted.org/packages/63/35/495e0021bfdcc924c7cdec4e9fbb87c88dd03b9b9b22419444dc370c8a45/import-ipynb-0.1.3.tar.gz
Building wheels for collected packages: import-ipynb
  Building wheel for import-ipynb (setup.py) ... done
  Created wheel for import-ipynb: filename=import_ipynb-0.1.3-cp36-none-any.whl size=2976 sha256=2ab021b9415b99c6d24c873dbb41746e08d782d81cd955748088f4230845960d
  Stored in directory: /root/.cache/pip/wheels/b4/7b/e9/a3a6e496115dffdb4e3085d0ae39ffe8a814eacc44bbf494b5
Successfully built import-ipynb
Installing collected packages: import-ipynb
Successfully installed import-ipynb-0.1.3


In [None]:
!pip install pytorch-ignite

Collecting pytorch-ignite
  Downloading https://files.pythonhosted.org/packages/db/4d/49a158b7e7ce3e31d2b921eea274c945e48eea07eb9df3da987b64ee87b6/pytorch_ignite-0.4.3-py3-none-any.whl (193kB)
[K     |█▊                              | 10kB 21.3MB/s eta 0:00:01[K     |███▍                            | 20kB 18.4MB/s eta 0:00:01[K     |█████                           | 30kB 14.9MB/s eta 0:00:01[K     |██████▊                         | 40kB 13.4MB/s eta 0:00:01[K     |████████▌                       | 51kB 9.2MB/s eta 0:00:01[K     |██████████▏                     | 61kB 9.8MB/s eta 0:00:01[K     |███████████▉                    | 71kB 9.9MB/s eta 0:00:01[K     |█████████████▌                  | 81kB 10.0MB/s eta 0:00:01[K     |███████████████▏                | 92kB 10.1MB/s eta 0:00:01[K     |█████████████████               | 102kB 8.7MB/s eta 0:00:01[K     |██████████████████▋             | 112kB 8.7MB/s eta 0:00:01[K     |████████████████████▎           | 1

In [None]:
# Colab only: change the working directory to the 'Colab Notebooks' folder on Google Drive.
%cd /content/drive/MyDrive/Colab Notebooks

/content/drive/MyDrive/Colab Notebooks


In [None]:
from copy import copy
from copy import deepcopy
import numpy as np
import torch
from ignite.engine import Engine
from ignite.engine import Events
from ignite.metrics import RunningAverage
from ignite.contrib.handlers.tqdm_logger import ProgressBar

import import_ipynb

In [None]:
import korean_semantic_classification_utils as utils

## Define our trainer

In [None]:
VERBOSE_SILENT = 0
VERBOSE_EPOCH_WISE = 1
VERBOSE_BATCH_WISE = 2


class MyEngine(Engine):
    """Ignite ``Engine`` that carries the model, loss and optimizer it drives.

    The same class backs both the training engine and the validation engine;
    which role an instance plays depends on the process function passed in
    (``MyEngine.train`` or ``MyEngine.validate``). Each instance also tracks
    the lowest validation loss seen so far (``best_loss``) and a copy of the
    matching weights (``best_model``), updated by ``check_best``.
    """

    def __init__(self, func, model, crit, optimizer, config):
        """
        Args:
            func: process function, called by ignite as ``func(engine, batch)``.
            model: the ``torch.nn.Module`` being trained / evaluated.
            crit: loss function, called as ``crit(y_hat, y)``.
            optimizer: optimizer stepped by ``train``.
            config: run configuration; ``config.max_length`` is read by
                ``train`` and ``validate``.
        """
        self.model = model
        self.crit = crit
        self.optimizer = optimizer
        self.config = config

        super().__init__(func)

        self.best_loss = np.inf
        self.best_model = None

        # Batches are moved to whatever device the model's parameters live on.
        self.device = next(model.parameters()).device

    @staticmethod
    def _prepare_batch(engine, mini_batch):
        """Move ``mini_batch.text``/``.label`` to the engine's device and truncate the text."""
        x, y = mini_batch.text, mini_batch.label
        x, y = x.to(engine.device), y.to(engine.device)

        # Keep at most max_length positions along dim 1
        # (presumably the token axis of a (batch, seq_len) tensor — TODO confirm).
        x = x[:, :engine.config.max_length]
        return x, y

    @staticmethod
    def _accuracy(y_hat, y):
        """Return the fraction of rows whose arg-max prediction equals ``y``.

        Only integer (int64) labels are treated as class indices; for any other
        label dtype the accuracy is reported as 0.
        """
        if y.dtype != torch.long:
            return 0.
        return float((torch.argmax(y_hat, dim=-1) == y).sum()) / float(y.size(0))

    @staticmethod
    def train(engine, mini_batch):
        """Run one optimisation step on ``mini_batch`` and report its statistics.

        Returns:
            dict with ``'loss'``, ``'accuracy'``, ``'|param|'`` (parameter norm)
            and ``'|g_param|'`` (gradient norm). These keys must match the
            training metric names registered in ``attach``.
        """
        engine.model.train()
        engine.optimizer.zero_grad()

        x, y = MyEngine._prepare_batch(engine, mini_batch)

        # Take feed-forward
        y_hat = engine.model(x)

        loss = engine.crit(y_hat, y)
        loss.backward()

        accuracy = MyEngine._accuracy(y_hat, y)

        # Norms are measured after backward() and before step(), i.e. on the
        # weights and gradients that this update is about to use.
        p_norm = float(utils.get_parameter_norm(engine.model.parameters()))
        g_norm = float(utils.get_grad_norm(engine.model.parameters()))

        engine.optimizer.step()

        return {
            'loss': float(loss),
            'accuracy': float(accuracy),
            '|param|': p_norm,
            '|g_param|': g_norm,
        }

    @staticmethod
    def validate(engine, mini_batch):
        """Evaluate ``mini_batch`` without gradients.

        Returns:
            dict with ``'loss'`` and ``'accuracy'``, matching the validation
            metric names registered in ``attach``.
        """
        engine.model.eval()

        with torch.no_grad():
            x, y = MyEngine._prepare_batch(engine, mini_batch)

            y_hat = engine.model(x)

            loss = engine.crit(y_hat, y)
            accuracy = MyEngine._accuracy(y_hat, y)

        return {
            'loss': float(loss),
            'accuracy': float(accuracy),
        }

    @staticmethod
    def attach(train_engine, validation_engine, verbose=VERBOSE_BATCH_WISE):
        """Register running-average metrics and logging on both engines.

        Args:
            train_engine: engine whose process function is ``MyEngine.train``.
            validation_engine: engine whose process function is ``MyEngine.validate``.
            verbose: ``VERBOSE_SILENT`` logs nothing; ``VERBOSE_EPOCH_WISE``
                prints one summary line per epoch for each engine;
                ``VERBOSE_BATCH_WISE`` additionally shows a training progress bar.
        """
        def attach_running_average(engine, metric_name):
            RunningAverage(output_transform=lambda x: x[metric_name]).attach(
                engine,
                metric_name,
            )

        # Must match the keys of the dict returned by MyEngine.train.
        training_metric_names = ['loss', 'accuracy', '|param|', '|g_param|']

        for metric_name in training_metric_names:
            attach_running_average(train_engine, metric_name)

        if verbose >= VERBOSE_BATCH_WISE:
            pbar = ProgressBar(bar_format=None, ncols=120)
            pbar.attach(train_engine, training_metric_names)

        if verbose >= VERBOSE_EPOCH_WISE:
            @train_engine.on(Events.EPOCH_COMPLETED)
            def print_train_logs(engine):
                print('Epoch {} - |param| = {:.2e} |g_param| = {:.2e} loss = {:.4e} accuracy = {:.4f}'.format(
                    engine.state.epoch,
                    engine.state.metrics['|param|'],
                    engine.state.metrics['|g_param|'],
                    engine.state.metrics['loss'],
                    engine.state.metrics['accuracy'],
                ))

        # Must match the keys of the dict returned by MyEngine.validate.
        validation_metric_names = ['loss', 'accuracy']

        for metric_name in validation_metric_names:
            attach_running_average(validation_engine, metric_name)

        if verbose >= VERBOSE_EPOCH_WISE:
            @validation_engine.on(Events.EPOCH_COMPLETED)
            def print_valid_logs(engine):
                # Handlers fire in registration order: if check_best is added
                # after this handler, best_loss here is the best of earlier epochs.
                print('Validation - loss={:.4e} accuracy={:.4f} best_loss={:.4e}'.format(
                    engine.state.metrics['loss'],
                    engine.state.metrics['accuracy'],
                    engine.best_loss,
                ))

    @staticmethod
    def check_best(engine):
        """Snapshot the model weights when the validation loss is <= the best so far."""
        loss = float(engine.state.metrics['loss'])
        if loss <= engine.best_loss:
            engine.best_loss = loss
            # state_dict() returns references to the live tensors, which later
            # optimizer steps update in place, so take an independent copy.
            engine.best_model = deepcopy(engine.model.state_dict())

    @staticmethod
    def save_model(engine, train_engine, config, **kwargs):
        """Save the best weights recorded on ``engine`` together with ``config``.

        Writes ``{'model': engine.best_model, 'config': config, **kwargs}`` to
        ``config.model_fn`` with ``torch.save``. ``train_engine`` is accepted
        (e.g. for use as an ignite handler argument) but not used here.
        """
        torch.save(
            {
                'model': engine.best_model,
                'config': config,
                **kwargs
            }, config.model_fn
        )
          
class Trainer():
    """Wire a training engine and a validation engine together and run them."""

    def __init__(self, config):
        """
        Args:
            config: run configuration; ``n_epochs`` and ``verbose`` are read in
                ``train`` and the whole object is handed to each ``MyEngine``.
        """
        self.config = config

    def train(self, model, crit, optimizer, train_loader, valid_loader):
        """Train ``model``, validating after every epoch, and return it.

        After training, the model is loaded with the weights that achieved the
        lowest validation loss. If no such weights were recorded (for example,
        every validation loss was NaN), the final weights are kept and a
        warning is issued instead of crashing.

        Returns:
            The same ``model`` object.
        """
        import warnings

        train_engine = MyEngine(MyEngine.train, model, crit, optimizer, self.config)
        validation_engine = MyEngine(MyEngine.validate, model, crit, optimizer, self.config)
        MyEngine.attach(
            train_engine, validation_engine, verbose=self.config.verbose
        )

        def run_validation(engine, validation_engine, valid_loader):
            # One full pass over the validation set after each training epoch.
            validation_engine.run(valid_loader, max_epochs=1)

        train_engine.add_event_handler(
            Events.EPOCH_COMPLETED,
            run_validation,
            validation_engine, valid_loader,
        )

        validation_engine.add_event_handler(
            Events.EPOCH_COMPLETED,
            MyEngine.check_best,
        )

        train_engine.run(train_loader, max_epochs=self.config.n_epochs)

        # best_model stays None if check_best never accepted a loss;
        # load_state_dict(None) would raise, so fall back to the final weights.
        if validation_engine.best_model is None:
            warnings.warn('No best model was recorded; keeping the final weights.')
        else:
            model.load_state_dict(validation_engine.best_model)

        return model
          
        