In [None]:
from google.colab import drive

# Mount Google Drive at /content/gdrive; every data, module and checkpoint path
# used later in this notebook lives under that mount point.
drive.mount('/content/gdrive')

In [2]:
import os

In [3]:
import sys

# Directory on the mounted drive to add to the module search path so that
# project-local modules stored there can be imported by later cells.
MODULE_DIR = '/content/gdrive/path/to/module_dir'

# Guarded so that re-running this cell does not pile up duplicate entries.
if MODULE_DIR not in sys.path:
    sys.path.append(MODULE_DIR)

In [None]:
# Display the module search path to check that the directory appended above is present.
os.sys.path

In [5]:
import math
import sys
import pickle
import time
import numpy as np

from docopt import docopt
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction
from nmt import Hypothesis, NMT
import numpy as np
from typing import List, Tuple, Dict, Set, Union
from tqdm import tqdm
from utils import read_corpus, batch_iter
from vocab import Vocab, VocabEntry

import torch
import torch.nn.utils
import torch.nn as nn

In [6]:
def evaluate_ppl(model, dev_data, batch_size=32):
    """ Compute the perplexity of the model over the dev sentences.
    @param model (NMT): NMT Model; calling it on a batch must return per-example log-likelihoods
    @param dev_data (list of (src_sent, tgt_sent)): list of tuples containing source and target sentence
    @param batch_size (int): number of sentence pairs scored per forward pass
    @returns ppl (perplexity on dev sentences): exp(total negative log-likelihood / total target words)
    """
    # Remember the current mode so it can be restored once evaluation is done
    restore_train_mode = model.training
    model.eval()

    total_nll = 0.
    total_tgt_words = 0.

    # Gradients are not needed for scoring
    with torch.no_grad():
        for src_batch, tgt_batch in batch_iter(dev_data, batch_size):
            batch_nll = -model(src_batch, tgt_batch).sum()
            total_nll += batch_nll.item()

            # every target token except the leading '<s>' has to be predicted
            total_tgt_words += sum(len(sent[1:]) for sent in tgt_batch)

    perplexity = np.exp(total_nll / total_tgt_words)

    if restore_train_mode:
        model.train()

    return perplexity

In [7]:
def compute_corpus_level_bleu_score(references: List[List[str]], hypotheses: List[Hypothesis]) -> float:
    """ Score decoded hypotheses against gold references with corpus-level BLEU.
    @param references (List[List[str]]): gold-standard target sentences, one token list per sentence
    @param hypotheses (List[Hypothesis]): decoding results, one for each reference; only `.value` is read
    @returns bleu_score: corpus-level BLEU score as returned by nltk's corpus_bleu
    """
    # The first reference alone decides whether boundary tokens are present; if it
    # starts with '<s>', the first and last token of every reference are dropped.
    has_boundary_tokens = references[0][0] == '<s>'
    gold_sents = [ref[1:-1] for ref in references] if has_boundary_tokens else references

    # corpus_bleu expects a list of candidate references per hypothesis;
    # here each hypothesis has exactly one.
    wrapped_refs = [[sent] for sent in gold_sents]
    decoded_sents = [hyp.value for hyp in hypotheses]

    return corpus_bleu(wrapped_refs, decoded_sents)

In [29]:
def _find_latest_checkpoint(ckpt_dir: str) -> Tuple[str, int, int]:
    """ Locate the most recently written model checkpoint in a directory.
    Checkpoint files are expected to be named '<epoch>_<iter>_model.bin', the pattern train() saves under.
    @param ckpt_dir (str): directory holding the checkpoints
    @returns (file name, epoch, iteration) of the newest '*.bin' file
    @raises FileNotFoundError: if the directory contains no '*.bin' file
    @raises ValueError: if the newest file name does not start with '<epoch>_<iter>_'
    """
    # '.bin.optim' companions are excluded here; the optimizer file is derived from the model file
    candidates = [f_name for f_name in os.listdir(ckpt_dir) if f_name.endswith('.bin')]
    if not candidates:
        raise FileNotFoundError('no model checkpoint (*.bin) found in {}'.format(ckpt_dir))

    latest = max(candidates, key=lambda f_name: os.path.getctime(os.path.join(ckpt_dir, f_name)))
    prefixes = latest.split('_')
    return latest, int(prefixes[0]), int(prefixes[1])


def train(args: Dict):
    """ Train the NMT Model.
    The model is validated every `valid_niter` iterations; whenever dev perplexity improves,
    the model is saved as '<epoch>_<iter>_model.bin' (optimizer state next to it with an
    extra '.optim' suffix) inside the `save_to` directory. After `patience` non-improving
    validations the learning rate is decayed and the best checkpoint is restored.
    The function returns when `max_num_trial` decays were hit or `max_epoch` is reached.
    @param args (Dict): run configuration. Keys read: 'train_src', 'train_tgt', 'dev_src',
        'dev_tgt', 'vocab', 'batch_size', 'clip_grad', 'valid_niter', 'log_every', 'save_to',
        'continue_training', 'embed_size', 'hidden_size', 'dropout', 'uniform_init', 'lr',
        'patience', 'max_num_trial', 'lr_decay', 'max_epoch'
    @raises ValueError: if the training corpus is empty
    @raises FileNotFoundError: if 'continue_training' is set but no checkpoint exists in 'save_to'
    """
    train_data_src = read_corpus(args['train_src'], source='src')
    train_data_tgt = read_corpus(args['train_tgt'], source='tgt')

    dev_data_src = read_corpus(args['dev_src'], source='src')
    dev_data_tgt = read_corpus(args['dev_tgt'], source='tgt')

    train_data = list(zip(train_data_src, train_data_tgt))
    dev_data = list(zip(dev_data_src, dev_data_tgt))

    if not train_data:
        # with no batches the endless epoch loop below would spin without ever terminating
        raise ValueError('training corpus is empty')

    train_batch_size = int(args['batch_size'])
    clip_grad = float(args['clip_grad'])
    valid_niter = int(args['valid_niter'])
    log_every = int(args['log_every'])
    model_save_path = args['save_to']  # a directory; individual checkpoints are files inside it
    continue_training = bool(args['continue_training'])

    os.makedirs(model_save_path, exist_ok=True)

    vocab = Vocab.load(args['vocab'])

    # Path of the best checkpoint file written so far (None until one exists)
    best_model_path = None

    if continue_training:
        # Load the most recent checkpoint from the directory where checkpoints are saved
        recent_ckpt, last_train_epoch, last_train_iter = _find_latest_checkpoint(model_save_path)
        best_model_path = os.path.join(model_save_path, recent_ckpt)
        model = NMT.load(best_model_path)
    else:
        model = NMT(embed_size=int(args['embed_size']),
                    hidden_size=int(args['hidden_size']),
                    dropout_rate=float(args['dropout']),
                    vocab=vocab)

    model.train()

    if not continue_training:
        uniform_init = float(args['uniform_init'])
        if np.abs(uniform_init) > 0.:
            print('uniformly initialize parameters [-%f, +%f]' % (uniform_init, uniform_init), file=sys.stderr)
            for _, p in model.named_parameters():
                nn.init.uniform_(p, -uniform_init, uniform_init)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print('use device: %s' % device, file=sys.stderr)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=float(args['lr']))

    if continue_training:
        # The optimizer state is always written next to its model file, so derive the
        # path from the model checkpoint instead of searching for it separately.
        optim_path = best_model_path + '.optim'
        print('restore parameters of the optimizers {}'.format(optim_path), file=sys.stderr)
        optimizer.load_state_dict(torch.load(optim_path, map_location=device))

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    cum_examples = report_examples = epoch = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()
    print('begin Maximum Likelihood training')

    # Number of leading batches to skip in the first epoch after a resume
    resume_skip = 0
    if continue_training:
        train_iter = last_train_iter
        epoch = last_train_epoch - 1

        # train_iter counts batches across all epochs, so convert it to a position inside
        # the epoch that was interrupted.
        # assumes batch_iter yields ceil(len(data) / batch_size) batches per epoch — TODO confirm against utils.batch_iter
        batches_per_epoch = math.ceil(len(train_data) / train_batch_size)
        done_in_epoch = last_train_iter - (last_train_epoch - 1) * batches_per_epoch
        # clamp in case the batch size or corpus changed since the checkpoint was written
        resume_skip = max(0, min(batches_per_epoch, done_in_epoch))

    while True:

        epoch += 1

        for idx, (src_sents, tgt_sents) in enumerate(batch_iter(train_data, batch_size=train_batch_size, shuffle=True)):
            # src_sents shape: (batch_size,)
            # tgt_sents shape: (batch_size,)

            if idx < resume_skip:
                continue

            train_iter += 1

            optimizer.zero_grad()

            batch_size = len(src_sents)

            example_losses = -model(src_sents, tgt_sents)

            batch_loss = example_losses.sum()
            loss = batch_loss / batch_size

            loss.backward()

            # clip gradient
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

            optimizer.step()

            batch_losses_val = batch_loss.item()
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            tgt_words_num_to_predict = sum(len(s[1:]) for s in tgt_sents) # omitting leading '<s>'
            report_tgt_words += tgt_words_num_to_predict
            cum_tgt_words += tgt_words_num_to_predict
            report_examples += batch_size
            cum_examples += batch_size

            if train_iter % log_every == 0:
                print('epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f ' \
                      'cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec' % (epoch, train_iter,
                                                                                         report_loss / report_examples,
                                                                                         math.exp(report_loss / report_tgt_words),
                                                                                         cum_examples,
                                                                                         report_tgt_words / (time.time() - train_time),
                                                                                         time.time() - begin_time),
                      file=sys.stderr)

                train_time = time.time()
                report_loss = report_tgt_words = report_examples = 0.

            # perform validation
            if train_iter % valid_niter == 0:
                print('epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f cum. examples %d' % (epoch, train_iter,
                                                                                         cum_loss / cum_examples,
                                                                                         np.exp(cum_loss / cum_tgt_words),
                                                                                         cum_examples), file=sys.stderr)

                cum_loss = cum_examples = cum_tgt_words = 0.

                print('begin validation ...', file=sys.stderr)

                # compute dev. ppl; lower perplexity is better, so negate it to get a score to maximize
                dev_ppl = evaluate_ppl(model, dev_data, batch_size=128) # dev batch size can be a bit larger
                valid_metric = -dev_ppl

                print('validation: iter %d, dev. ppl %f' % (train_iter, dev_ppl), file=sys.stderr)

                is_better = len(hist_valid_scores) == 0 or valid_metric > max(hist_valid_scores)
                hist_valid_scores.append(valid_metric)

                if is_better:
                    patience = 0
                    best_model_path = os.path.join(model_save_path, f'{epoch}_{train_iter}_model.bin')
                    print('save currently the best model to [%s]' % best_model_path, file=sys.stderr)
                    model.save(best_model_path)

                    # also save the optimizers' state
                    torch.save(optimizer.state_dict(), best_model_path + '.optim')
                elif patience < int(args['patience']):
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == int(args['patience']):
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == int(args['max_num_trial']):
                            print('early stop!', file=sys.stderr)
                            # return rather than exit(0): inside a notebook kernel exit() does not
                            # interrupt the running cell, so the loop would keep going
                            return

                        # decay lr, and restore from previously best checkpoint
                        lr = optimizer.param_groups[0]['lr'] * float(args['lr_decay'])
                        print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

                        # load model from the best checkpoint *file* (model_save_path is only the directory);
                        # best_model_path is always set here because the first validation always saves
                        params = torch.load(best_model_path, map_location=lambda storage, loc: storage)
                        model.load_state_dict(params['state_dict'])
                        model = model.to(device)

                        print('restore parameters of the optimizers', file=sys.stderr)
                        optimizer.load_state_dict(torch.load(best_model_path + '.optim', map_location=device))

                        # set new lr
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        # reset patience
                        patience = 0

                # Checked after every validation, whether or not the model improved.
                # '>=' also covers a resume that starts at or beyond the final epoch.
                if epoch >= int(args['max_epoch']):
                    print('reached maximum number of epochs!', file=sys.stderr)
                    return

        # only the first epoch after a resume is entered part-way through
        resume_skip = 0

In [30]:
# Run configuration consumed by the seeding cell and by train().
args = {
    # corpus and vocabulary files on the mounted drive
    'train_src': "/content/gdrive/path/to/data/train.de-en.de.wmixerprep",
    'train_tgt': "/content/gdrive/path/to/data/train.de-en.en.wmixerprep",
    'dev_src': "/content/gdrive/path/to/data/valid.de-en.de",
    'dev_tgt': "/content/gdrive/path/to/data/valid.de-en.en",
    'vocab': "/content/gdrive/path/to/data/vocab.json",

    # hyper-parameters and run options
    'seed': 0,
    'batch_size': 32,
    'embed_size': 256,
    'hidden_size': 256,
    'clip_grad': 5.0,
    'log_every': 10,
    'max_epoch': 30,
    'patience': 5,
    'max_num_trial': 5,
    'lr_decay': 0.5,
    'beam_size': 5,
    'lr': 0.001,
    'uniform_init': 0.1,
    'save_to': '/content/gdrive/path/to/checkpoint',
    'valid_niter': 100,
    'dropout': 0.3,
    'max_decoding_time_step': 70,
    'cuda': True,
    'continue_training': False,
}

In [31]:
# Make runs repeatable: derive every generator's seed from args['seed'].
seed = int(args['seed'])

# numpy gets a value derived from the base seed
np.random.seed(seed * 13 // 7)

# torch: CPU generator always, CUDA generator only when the config asks for it
torch.manual_seed(seed)
if args['cuda']:
    torch.cuda.manual_seed(seed)

In [None]:
# Start training with the configuration built in the `args` cell above.
train(args)