# Import

In [None]:
# -*- coding:utf-8 -*-
import os
import math
import time
import random
import torch
import torch.nn as nn
import torch.optim as opt
from torch.utils.data import DataLoader
from data_helper import create_or_get_voca, LSTMSeq2SeqDataset
from Customize_Seq2SeqWithAttention.model import Encoder, AttentionDecoder, Seq2SeqWithAttention
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib import font_manager, rc

# Font used when rendering attention heat-maps so that Korean tick labels are
# drawn with real glyphs. The font is optional: when the file is missing we
# keep matplotlib's default font (Korean labels may then render incorrectly)
# instead of failing at import time.
_ATTENTION_FONT_PATH = "c:/Windows/Fonts/NanumBarunGothic.ttf"
if os.path.isfile(_ATTENTION_FONT_PATH):
    font_name = font_manager.FontProperties(fname=_ATTENTION_FONT_PATH).get_name()
    rc('font', family=font_name)
else:
    font_name = None
    print('[Warning] font file not found: {} '
          '(attention plots will use the default matplotlib font)'.format(_ATTENTION_FONT_PATH))
plt.rcParams.update({'font.size': 7})

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

- 먼저 NanumBarunGothic.ttf 를 다운받아서 설치해야 합니다. (https://vhrms.tistory.com/403 저는 여기서 다운받았습니다.)
- 만약 설치하지 않으면 Attention 이미지에서 한글이 깨지게 됩니다.
- Device는 GPU를 기본으로 하고, GPU가 없을 경우 CPU를 사용하게 되어 있습니다.

# Trainer

In [None]:
class Trainer(object):
    """Set up data, vocabularies and loss for a Korean-to-English seq2seq model.

    Note that constructing an instance starts training right away, because
    ``__init__`` finishes by calling ``self.train()``.
    """

    def __init__(self, args):
        """Resolve corpus paths, build vocabularies and loaders, then train.

        Args:
            args: configuration object; this method reads ``data_path`` and the
                four ``*_filename`` attributes from it.
        """
        self.args = args

        # Corpus files all live under the same data directory.
        data_dir = self.args.data_path
        self.x_train_path = os.path.join(data_dir, self.args.src_train_filename)  # train input
        self.y_train_path = os.path.join(data_dir, self.args.tar_train_filename)  # train target
        self.x_val_path = os.path.join(data_dir, self.args.src_val_filename)      # validation input
        self.y_val_path = os.path.join(data_dir, self.args.tar_val_filename)      # validation target

        # The vocabularies must exist before the loaders, which receive them.
        self.ko_voc, self.en_voc = self.get_voca()
        self.train_loader = self.get_train_loader()
        self.val_loader = self.get_val_loader()

        # Cross entropy that skips target positions holding the <pad> id.
        pad_id = self.en_voc['<pad>']
        self.criterion = nn.CrossEntropyLoss(ignore_index=pad_id)

        self.writer = SummaryWriter()  # TensorBoard event writer
        self.train()                   # start training immediately

In [None]:
    def train(self):
        """Train the attention seq2seq model with periodic logging, validation and checkpoints.

        Builds the encoder/decoder from ``encoder_parameter()`` /
        ``decoder_parameter()``, wraps the model in ``nn.DataParallel`` and
        iterates ``self.args.epochs`` times over ``self.train_loader``.

        Every ``train_step_print`` steps the training metrics are written to
        TensorBoard and printed; every ``val_step_print`` steps the validation
        set is evaluated; every ``step_save`` steps a checkpoint is written
        (and attention maps are plotted when ``get_attention`` is set).

        Raises:
            KeyboardInterrupt: re-raised after a checkpoint has been saved, so
                that an interrupt actually stops training.
        """
        start = time.time()

        encoder = Encoder(**self.encoder_parameter())
        decoder = AttentionDecoder(**self.decoder_parameter())
        model = Seq2SeqWithAttention(encoder, decoder, self.args.sequence_size, self.args.get_attention)
        model = nn.DataParallel(model)
        # Move to the module-level device instead of calling .cuda(), so the
        # CPU fallback works on machines without a GPU.
        model.to(device)
        model.train()

        # Every parameter must belong to exactly one optimizer. Building both
        # optimizers over model.parameters() would apply two Adam updates to
        # every weight on each step. The encoder optimizer owns the encoder
        # weights; the decoder optimizer owns everything else in the model.
        encoder_params = list(encoder.parameters())
        encoder_param_ids = {id(param) for param in encoder_params}
        remaining_params = [param for param in model.parameters() if id(param) not in encoder_param_ids]
        encoder_optimizer = opt.Adam(encoder_params, lr=self.args.learning_rate)
        decoder_optimizer = opt.Adam(remaining_params, lr=self.args.learning_rate)

        epoch_step = len(self.train_loader) + 1
        total_step = self.args.epochs * epoch_step
        train_ratios = cal_teacher_forcing_ratio(self.args.learning_method, total_step)
        # The validation schedule holds roughly one entry per 100 training steps.
        val_ratios = cal_teacher_forcing_ratio('Mixed_Sampling', int(total_step / 100) + 1)

        step = 0
        attention = None

        for epoch in range(self.args.epochs):
            for i, data in enumerate(self.train_loader, 0):
                try:
                    src_input, tar_input, tar_output = data

                    # The schedule spans the whole run, so index it with the
                    # global step; the per-epoch batch index would restart the
                    # schedule at the beginning of every epoch.
                    result = model(src_input, tar_input, teacher_forcing_rate=train_ratios[step])
                    if self.args.get_attention:
                        output, attention = result
                    else:
                        output = result

                    # Get loss & accuracy
                    loss, accuracy, ppl = self.loss_accuracy(output, tar_output)

                    # Training Log
                    if step % self.args.train_step_print == 0:
                        self.writer.add_scalar('train/loss', loss.item(), step)
                        self.writer.add_scalar('train/accuracy', accuracy.item(), step)
                        self.writer.add_scalar('train/PPL', ppl, step)

                        print('[Train] epoch : {0:2d}  iter: {1:4d}/{2:4d}  step : {3:6d}/{4:6d}  '
                              '=>  loss : {5:10f}  accuracy : {6:12f}  PPL : {7:6f}'
                              .format(epoch, i, epoch_step, step, total_step, loss.item(), accuracy.item(), ppl))

                    # Validation Log
                    if step % self.args.val_step_print == 0:
                        # Clamp the index so it can never run past the end of
                        # the schedule; fall back to full teacher forcing if
                        # the schedule is empty.
                        if len(val_ratios) > 0:
                            val_rate = val_ratios[min(step // 100, len(val_ratios) - 1)]
                        else:
                            val_rate = 1.0

                        model.eval()
                        try:
                            with torch.no_grad():
                                val_loss, val_accuracy, val_ppl = self.val(model, teacher_forcing_rate=val_rate)
                        finally:
                            # Always return to training mode, even if validation fails.
                            model.train()

                        self.writer.add_scalar('val/loss', val_loss, step)
                        self.writer.add_scalar('val/accuracy', val_accuracy, step)
                        self.writer.add_scalar('val/PPL', val_ppl, step)

                        print('[Val] epoch : {0:2d}  iter: {1:4d}/{2:4d}  step : {3:6d}/{4:6d}  '
                              '=>  loss : {5:10f}  accuracy : {6:12f}   PPL : {7:10f}'
                              .format(epoch, i, epoch_step, step, total_step, val_loss, val_accuracy, val_ppl))

                    # Save Model Point
                    if step % self.args.step_save == 0:
                        print("time :", time.time() - start)
                        if self.args.get_attention:
                            self.plot_attention(step, src_input, tar_input, attention)
                        self.model_save(model=model, encoder_optimizer=encoder_optimizer,
                                        decoder_optimizer=decoder_optimizer, epoch=epoch, step=step)

                    # Optimizer update
                    encoder_optimizer.zero_grad()
                    decoder_optimizer.zero_grad()
                    loss.backward()
                    encoder_optimizer.step()
                    decoder_optimizer.step()
                    step += 1

                # On Ctrl-C, save a checkpoint and then propagate the
                # interrupt; swallowing it would make training unstoppable.
                except KeyboardInterrupt:
                    self.model_save(model=model, encoder_optimizer=encoder_optimizer,
                                    decoder_optimizer=decoder_optimizer, epoch=epoch, step=step)
                    raise

In [None]:
def cal_teacher_forcing_ratio(learning_method, total_step):
    """Build the teacher-forcing ratio schedule, one entry per step.

    Args:
        learning_method: one of ``'Teacher_Forcing'`` (always 1.0),
            ``'Scheduled_Sampling'`` (linear decay from 1.0 to 0.0) or
            ``'Mixed_Sampling'`` (1.0 for the first half, then a linear decay
            from 1.0 to 0.0).
        total_step: number of entries the schedule must contain.

    Returns:
        A sequence of ``total_step`` ratios: a list for ``'Teacher_Forcing'``
        and ``'Mixed_Sampling'``, a NumPy array for ``'Scheduled_Sampling'``.

    Raises:
        NotImplementedError: if ``learning_method`` is not a known method.
    """
    if learning_method == 'Teacher_Forcing':
        return [1.0] * total_step

    if learning_method == 'Scheduled_Sampling':
        import numpy as np
        # np.linspace yields total_step evenly spaced points in [0, 1];
        # reversing them gives a decay from 1.0 down to 0.0.
        return np.linspace(0.0, 1.0, num=total_step)[::-1]

    if learning_method == 'Mixed_Sampling':
        import numpy as np
        # Split so both parts always add up to total_step. Halving twice
        # dropped one entry for odd totals, which made the last step index
        # fall outside the schedule. The extra step goes to the
        # teacher-forced part.
        decay_steps = total_step // 2
        forced_steps = total_step - decay_steps
        teacher_forcing_ratios = [1.0] * forced_steps
        teacher_forcing_ratios.extend(np.linspace(0.0, 1.0, num=decay_steps)[::-1])
        return teacher_forcing_ratios

    raise NotImplementedError(
        'learning method must be one of [Teacher_Forcing, Scheduled_Sampling, Mixed_Sampling]')

In [None]:
    def get_voca(self):
        """Return the ``(ko_voc, en_voc)`` vocabulary pair.

        First asks ``create_or_get_voca`` for the vocabularies stored under
        ``args.dictionary_path``. If that raises ``OSError``, the call is
        repeated with the training corpora paths supplied as well.
        """
        try:
            korean, english = create_or_get_voca(save_path=self.args.dictionary_path)
        except OSError:
            # Retry, this time also passing the training corpora.
            korean, english = create_or_get_voca(
                save_path=self.args.dictionary_path,
                ko_corpus_path=self.x_train_path,
                en_corpus_path=self.y_train_path,
            )
        return korean, english

In [None]:
    def get_train_loader(self):
        """Build the DataLoader over the training corpus.

        Batches are drawn through a ``RandomSampler``. No random seed is fixed
        here, so the sampling order is not reproducible between runs.
        """
        dataset = LSTMSeq2SeqDataset(
            self.x_train_path,
            self.y_train_path,
            self.ko_voc,
            self.en_voc,
            self.args.sequence_size,
        )
        return DataLoader(
            dataset,
            batch_size=self.args.batch_size,
            sampler=torch.utils.data.RandomSampler(dataset),
        )
    def get_val_loader(self):
        """Build the DataLoader over the validation corpus.

        Mirrors ``get_train_loader`` but reads the validation files. Batches
        are drawn through a ``RandomSampler`` and no random seed is fixed, so
        the sampling order is not reproducible between runs.
        """
        dataset = LSTMSeq2SeqDataset(
            self.x_val_path,
            self.y_val_path,
            self.ko_voc,
            self.en_voc,
            self.args.sequence_size,
        )
        return DataLoader(
            dataset,
            batch_size=self.args.batch_size,
            sampler=torch.utils.data.RandomSampler(dataset),
        )

In [None]:
   def encoder_parameter(self):
        param = {
            'embedding_size': 5000,
            'embedding_dim': self.args.embedding_dim,
            'pad_id': self.ko_voc['<pad>'],
            'rnn_dim': self.args.encoder_rnn_dim,
            'rnn_bias': True,
            'n_layers': self.args.encoder_n_layers,
            'embedding_dropout': self.args.encoder_embedding_dropout,
            'rnn_dropout': self.args.encoder_rnn_dropout,
            'dropout': self.args.encoder_dropout,
            'residual_used': self.args.encoder_residual_used,
            'bidirectional': self.args.encoder_bidirectional_used,
            'encoder_output_transformer': self.args.encoder_output_transformer,
            'encoder_output_transformer_bias': self.args.encoder_output_transformer_bias,
            'encoder_hidden_transformer': self.args.encoder_hidden_transformer,
            'encoder_hidden_transformer_bias': self.args.encoder_hidden_transformer_bias
        }
        return param

    def decoder_parameter(self):
        param = {
            'embedding_size': 5000,
            'embedding_dim': self.args.embedding_dim,
            'pad_id': self.en_voc['<pad>'],
            'rnn_dim': self.args.decoder_rnn_dim,
            'rnn_bias': True,
            'n_layers': self.args.decoder_n_layers,
            'embedding_dropout': self.args.decoder_embedding_dropout,
            'rnn_dropout': self.args.decoder_rnn_dropout,
            'dropout': self.args.decoder_dropout,
            'residual_used': self.args.decoder_residual_used,
            'attention_score_func': self.args.attention_score
        }
        return param

In [None]:
    def loss_accuracy(self, out, tar):
        # out => [embedding_size, sequence_len, vocab_size]
        # tar => [embedding_size, sequence_len]
        out = out.view(-1, out.size(-1))
        tar = tar.view(-1).to(device)
        # out => [embedding_size * sequence_len, vocab_size]
        # tar => [embedding_size * sequence_len]
        loss = self.criterion(out, tar)
        ppl = math.exp(loss.item())

        _, indices = out.max(-1)
        invalid_targets = tar.eq(self.en_voc['<pad>'])
        equal = indices.eq(tar)
        total = 1
        for i in equal.size():
            total *= i
        accuracy = torch.div(equal.masked_fill_(invalid_targets, 0).long().sum().to(dtype=torch.float32), total)
        return loss, accuracy, ppl

In [None]:
    def val(self, model, teacher_forcing_rate):
        total_loss = 0
        total_accuracy = 0
        total_ppl = 0
        with torch.no_grad():
            count = 0
            for data in self.val_loader:
                src_input, tar_input, tar_output = data
                output = model(src_input, tar_input, teacher_forcing_rate=teacher_forcing_rate)

                if isinstance(output, tuple):
                    output = output[0]
                loss, accuracy, ppl = self.loss_accuracy(output, tar_output)
                total_loss += loss.item()
                total_accuracy += accuracy.item()
                total_ppl += ppl
                count += 1
            _, indices = output.view(-1, output.size(-1)).max(-1)
            indices = indices[:self.args.sequence_size].tolist()
            a = src_input[0].tolist()
            b = tar_output[0].tolist()
            print(self.tensor2sentence_ko(a))
            print(self.tensor2sentence_en(indices))
            print(self.tensor2sentence_en(b))
            avg_loss = total_loss / count
            avg_accuracy = total_accuracy / count
            avg_ppl = total_ppl / count
            return avg_loss, avg_accuracy, avg_ppl

In [None]:
    def model_save(self, model, encoder_optimizer, decoder_optimizer, epoch, step):
        model_name = '{0:06d}_model_1.pth'.format(step)
        model_path = os.path.join(self.args.model_path, model_name)
        torch.save({
            'epoch': epoch,
            'steps': step,
            'seq_len': self.args.sequence_size,
            'encoder_parameter': self.encoder_parameter(),
            'decoder_parameter': self.decoder_parameter(),
            'model_state_dict': model.state_dict(),
            'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
            'decoder_optimizer_state_dict': decoder_optimizer.state_dict()

        }, model_path)

In [None]:
    def tensor2sentence_en(self, indices: torch.Tensor) -> list:
        result = []
        translation_sentence = []
        for idx in indices:
            word = self.en_voc.IdToPiece(idx)
            if word == '</s>':
                break
            translation_sentence.append(word)
        translation_sentence = ''.join(translation_sentence).replace('▁', ' ').strip()
        result.append(translation_sentence)
        return result

    def tensor2sentence_ko(self, indices: torch.Tensor) -> list:
        result = []
        translation_sentence = []
        for idx in indices:
            word = self.ko_voc.IdToPiece(idx)
            if word == '<pad>':
                break
            translation_sentence.append(word)
        translation_sentence = ''.join(translation_sentence).replace('▁', ' ').strip()
        result.append(translation_sentence)
        return result

In [None]:
    def plot_attention(self, step, src_input, trg_input, attention):
        """Save attention heat-maps for randomly chosen samples of a batch.

        Images are written to ``<args.img_path>/<step>_step/attention-<n>.png``.
        At most ``args.plot_count`` samples are plotted, and never more than
        the batch contains.

        Args:
            step: global training step, used for the output directory name.
            src_input: batch of source token ids (first dimension = sample).
            trg_input: batch of target token ids (first dimension = sample).
            attention: batch of attention maps; ``attention[i]`` is drawn as a
                2-D image.
        """
        filepath = os.path.join(self.args.img_path, '{0:06d}_step'.format(step))
        # Also creates missing parent directories and tolerates an existing one.
        os.makedirs(filepath, exist_ok=True)

        def tick_labels(voc, token_ids, count):
            """Decode ids to pieces, blank out <pad>, pad or trim to `count` labels."""
            words = [voc.IdToPiece(token.item()) for token in token_ids]
            words = ['' if word == '<pad>' else word for word in words]
            return (words + [''] * count)[:count]

        src_input = src_input.detach().to('cpu')
        trg_input = trg_input.detach().to('cpu')
        attention = attention.detach().to('cpu')

        # Sample from every item in the batch and never ask for more samples
        # than exist (random.sample raises ValueError in that case).
        batch_size = src_input.shape[0]
        plot_count = min(self.args.plot_count, batch_size)
        if plot_count <= 0:
            return
        chosen = random.sample(range(batch_size), plot_count)

        for num, i in enumerate(chosen):
            weights = attention[i]
            n_rows, n_cols = weights.shape[-2], weights.shape[-1]

            # NOTE(review): as in the original labelling, columns are tagged
            # with target pieces and rows with source pieces - confirm this
            # matches the layout of the attention tensor the model returns.
            x_labels = tick_labels(self.en_voc, trg_input[i], n_cols)
            y_labels = tick_labels(self.ko_voc, src_input[i], n_rows)

            fig = plt.figure()
            try:
                ax = fig.add_subplot(111)
                cax = ax.matshow(weights.numpy(), cmap='bone')
                fig.colorbar(cax)

                # Pin one tick per cell before assigning labels, so each label
                # lines up with its own row/column.
                ax.set_xticks(list(range(n_cols)))
                ax.set_xticklabels(x_labels, rotation=90)
                ax.set_yticks(list(range(n_rows)))
                ax.set_yticklabels(y_labels)

                fig.savefig(fname=os.path.join(filepath, 'attention-{}.png'.format(num)))
            finally:
                # Release the figure; otherwise every call leaks open figures.
                plt.close(fig)