In [1]:
!pip install transformers==4.7 torchinfo rouge

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
!git clone https://github.com/Taeksu-Kim/Transformer.git

fatal: destination path 'Transformer' already exists and is not an empty directory.


In [3]:
cd Transformer/PyTorch

/content/Transformer/PyTorch


In [4]:
!git clone https://github.com/songys/Chatbot_data.git

fatal: destination path 'Chatbot_data' already exists and is not an empty directory.


In [5]:
# common
import math
import random
import os
import gc
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from torchinfo import summary
from rouge import Rouge

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from transformers import AutoTokenizer

# custom
from transformer import Transformer

In [6]:
def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    torch.backends.cudnn.benchmark = False  # type: ignore
    torch.cuda.manual_seed_all(seed)

seed = 42

seed_everything(seed)

In [7]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

epochs = 300
learning_rate = 1e-4
weight_decay = 1e-2
batch_size = 64

gradient_scaler = True
# use_lr_scheduler = False

early_stopping_patience = 10

save_name = 'tft_model'

In [8]:
df = pd.read_csv('./Chatbot_data/ChatbotData.csv')

In [9]:
model_path = "monologg/kobigbird-bert-base"
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [10]:
def cal_token_len(text, tokenizer):
  return len(tokenizer.encode(text))

In [11]:
df['enc_token_len'] = [ cal_token_len(df.iloc[i]['Q'], tokenizer) for i in tqdm(range(df.shape[0])) ]

100%|██████████| 11823/11823 [00:04<00:00, 2944.68it/s]


In [12]:
df['dec_token_len'] = [ cal_token_len(df.iloc[i]['A'], tokenizer) for i in tqdm(range(df.shape[0])) ]

100%|██████████| 11823/11823 [00:05<00:00, 2286.50it/s]


In [13]:
df.head()

Unnamed: 0,Q,A,label,enc_token_len,dec_token_len
0,12시 땡!,하루가 또 가네요.,0,6,8
1,1지망 학교 떨어졌어,위로해 드립니다.,0,8,6
2,3박4일 놀러가고 싶다,여행은 언제나 좋죠.,0,11,8
3,3박4일 정도 놀러가고 싶다,여행은 언제나 좋죠.,0,12,8
4,PPL 심하네,눈살이 찌푸려지죠.,0,6,9


In [14]:
tar_per_list = [95,98,99,100]
tar_col = df['enc_token_len']

for i in tar_per_list:
    print('{}% length : {}'.format(i, np.percentile(tar_col,i)))

95% length : 16.0
98% length : 19.0
99% length : 20.0
100% length : 30.0


In [15]:
max_enc_len = 20

In [16]:
tar_col = df['dec_token_len']

for i in tar_per_list:
    print('{}% length : {}'.format(i, np.percentile(tar_col,i)))

95% length : 18.0
98% length : 20.0
99% length : 23.0
100% length : 42.0


In [17]:
max_dec_len = 26

In [18]:
df = df[(df['enc_token_len']<=max_enc_len)&(df['dec_token_len']<=max_dec_len)]

In [19]:
df.head()

Unnamed: 0,Q,A,label,enc_token_len,dec_token_len
0,12시 땡!,하루가 또 가네요.,0,6,8
1,1지망 학교 떨어졌어,위로해 드립니다.,0,8,6
2,3박4일 놀러가고 싶다,여행은 언제나 좋죠.,0,11,8
3,3박4일 정도 놀러가고 싶다,여행은 언제나 좋죠.,0,12,8
4,PPL 심하네,눈살이 찌푸려지죠.,0,6,9


In [20]:
train, valid =  train_test_split(df, test_size=0.05, random_state=seed, shuffle=True)

In [21]:
def mk_token_inputs(text, max_seq_len, mode='encoder'):
    input_ids = tokenizer.encode(text,max_length=max_seq_len, padding='max_length')
    
    if mode == 'decoder':
      cls_idx = input_ids.index(tokenizer.cls_token_id)
      sep_idx = input_ids.index(tokenizer.sep_token_id)

      input_ids[cls_idx] = tokenizer.bos_token_id
      input_ids[sep_idx] = tokenizer.eos_token_id

    return torch.tensor(input_ids, dtype=int)

In [22]:
class chatbot_dataset(Dataset):

  def __init__(self, df, enc_max_len, dec_max_len):
    self.df = df
    self.enc_max_len = enc_max_len
    self.dec_max_len = dec_max_len

  def __len__(self):
    return self.df.shape[0]

  def __getitem__(self, index):

    return {'enc_inputs' : mk_token_inputs(self.df['Q'].iloc[index], self.enc_max_len),
            'dec_inputs' : mk_token_inputs(self.df['A'].iloc[index], self.dec_max_len, mode='decoder'),
           }

In [23]:
train_dataset = chatbot_dataset(train, max_enc_len, max_dec_len+1)
valid_dataset = chatbot_dataset(valid, max_enc_len, max_dec_len+1)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=0, shuffle=True)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, num_workers=0, shuffle=True)

In [24]:
for i, batch in enumerate(train_dataloader):
    break

In [25]:
# Config Class
# dict class를 json으로 바꿔서 confg.arg 와 같이 사용할 수 있게 만드는 class
class Config(dict): 
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__

    @classmethod
    def load(cls, file):
        with open(file, 'r') as f:
            config = json.loads(f.read())
            return Config(config)

In [26]:
config_dict = {
    'vocab_size' : tokenizer.vocab_size,
    'd_model' : 256,
    'max_enc_len' : max_enc_len,
    'max_dec_len' : max_dec_len,
    'pad_id' : tokenizer.pad_token_id,
    'bos_id' : tokenizer.bos_token_id,
    'eos_id' : tokenizer.eos_token_id,
    'use_decoder' : True,
    'init_std' : 2e-2,
    'norm_eps' : 1e-12, 
    'drop_out_raito' : 0.1,
    'num_enc_layers' : 3,
    'num_dec_layers' : 3,
    'num_att_heads' : 4,
    'feed_forward_dim' : 1024,
}

config = Config(config_dict)

In [27]:
model = Transformer(config)

In [28]:
enc_inputs = batch['enc_inputs']
dec_inputs = batch['dec_inputs'][:,1:]
summary(model, input_data=[enc_inputs, dec_inputs])

Layer (type:depth-idx)                                            Output Shape              Param #
Transformer                                                       [64, 26, 32500]           --
├─TransformerEncoder: 1-1                                         [64, 20, 256]             --
│    └─Embedding: 2-1                                             [64, 20, 256]             8,320,000
│    └─ModuleList: 2-2                                            --                        --
│    │    └─TransformerEncoderLayer: 3-1                          [64, 20, 256]             789,760
│    │    └─TransformerEncoderLayer: 3-2                          [64, 20, 256]             789,760
│    │    └─TransformerEncoderLayer: 3-3                          [64, 20, 256]             789,760
├─TransformerDecoder: 1-2                                         [64, 26, 32500]           --
│    └─Embedding: 2-3                                             [64, 26, 256]             8,320,000
│    └─ModuleLis

In [29]:
config_dict = {
    'vocab_size' : tokenizer.vocab_size,
    'd_model' : 512,
    'max_enc_len' : max_enc_len,
    'max_dec_len' : max_dec_len,
    'pad_id' : tokenizer.pad_token_id,
    'bos_id' : tokenizer.bos_token_id,
    'eos_id' : tokenizer.eos_token_id,
    'use_decoder' : True,
    'init_std' : 2e-2,
    'norm_eps' : 1e-12, 
    'drop_out_raito' : 0.1,
    'num_enc_layers' : 6,
    'num_dec_layers' : 6,
    'num_att_heads' : 4,
    'feed_forward_dim' : 1024,
}

config = Config(config_dict)

In [30]:
model = Transformer(config)

In [31]:
enc_inputs = batch['enc_inputs']
dec_inputs = batch['dec_inputs'][:,1:]
summary(model, input_data=[enc_inputs, dec_inputs])

Layer (type:depth-idx)                                            Output Shape              Param #
Transformer                                                       [64, 26, 32500]           --
├─TransformerEncoder: 1-1                                         [64, 20, 512]             --
│    └─Embedding: 2-1                                             [64, 20, 512]             16,640,000
│    └─ModuleList: 2-2                                            --                        --
│    │    └─TransformerEncoderLayer: 3-1                          [64, 20, 512]             2,102,784
│    │    └─TransformerEncoderLayer: 3-2                          [64, 20, 512]             2,102,784
│    │    └─TransformerEncoderLayer: 3-3                          [64, 20, 512]             2,102,784
│    │    └─TransformerEncoderLayer: 3-4                          [64, 20, 512]             2,102,784
│    │    └─TransformerEncoderLayer: 3-5                          [64, 20, 512]             2,102,784
│ 

In [32]:
model.to(device)

Transformer(
  (encoder): TransformerEncoder(
    (word_embedding): Embedding(32500, 512, padding_idx=0)
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attention): AddNorm(
          (layer): MultiHeadAttention(
            (query_proj): Linear(in_features=512, out_features=512, bias=True)
            (key_proj): Linear(in_features=512, out_features=512, bias=True)
            (value_proj): Linear(in_features=512, out_features=512, bias=True)
            (scaled_dot_attn): ScaledDotProductAttention()
            (linear): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
        )
        (feed_forward): AddNorm(
          (layer): PoswiseFeedForward(
            (feed_forward): Sequential(
              (0): Linear(in_features=512, out_features=1024, bias=True)
              (1): Dropout(p=0.1, inplace=False)
              (2): ReLU()
              (3): Linear(in_f

In [33]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [34]:
rouge = Rouge()

In [35]:
def cal_rouge_n(y_pred, y_true):
    y_pred = torch.argmax(y_pred, dim=-1)
   
    scores = []

    for i in range(y_true.shape[0]):
        score = 0
        reference = tokenizer.decode(y_true[i])
        hypothesis = tokenizer.decode(y_pred[i])
        if ' </s>' in hypothesis:
            
            reference = reference.split(' </s>')[0].replace('.', ' +002E')
            hypothesis = hypothesis.split(' </s>')[0].replace('.', ' +002E')

            if len(hypothesis) != 0:
              score = rouge.get_scores(hypothesis, reference)[0]['rouge-1']['f']
        scores.append(round(score,4))

    return sum(scores) / len(scores)

    # reference = []
    # hypothesis = []

    # for i in range(y_true.shape[0]):
    #     reference.append(tokenizer.decode(y_true[i]).split(' </s>')[0])
    #     hypothesis.append(tokenizer.decode(y_pred[i]).split(' </s>')[0])

    #     score = rouge.get_scores(hypothesis, reference)[0]['rouge-2']['f']
    #     scores.append(round(score,4))



In [36]:
def cal_lm_acc(y_pred, y_true, pad_id):
    """
    acc 계산 함수
    :param y_true: 정답 (bs, n_seq)
    :param y_pred: 예측 값 (bs, n_seq, n_vocab)
    """
    # 정답 여부 확인
    y_pred = torch.argmax(y_pred, dim=-1).int()
    matches = torch.eq(y_true, y_pred).int()
    
    # pad(0) 인 부분 mask
    mask = y_true.ne(pad_id).int()
    matches *= mask
    
    # 정확도 계산
    accuracy = torch.sum(matches) / torch.maximum(torch.sum(mask), torch.tensor(1, dtype=int))
    return accuracy

In [37]:
def train_step(batch, epoch, training):
    for batch_key in batch.keys():
        batch[batch_key] = batch[batch_key].to(device)

    if training is True:
        model.train()
        optimizer.zero_grad()

        with torch.cuda.amp.autocast():

            logits = model(enc_inputs=batch['enc_inputs'],
                           dec_inputs=batch['dec_inputs'][:,:-1])[0]
            
            CCE = nn.CrossEntropyLoss(ignore_index=config.pad_id)
            loss = loss = CCE(logits.view(-1, config.vocab_size), batch['dec_inputs'][:,1:].contiguous().view(-1))
            # lm_acc = cal_lm_acc(logits, batch['dec_inputs'][:,1:], config.pad_id)
            rouge_1_f1 = cal_rouge_n(logits, batch['dec_inputs'][:,1:])

        loss.backward()
        optimizer.step()
            
        lr = optimizer.param_groups[0]["lr"]

        return loss, rouge_1_f1, round(lr, 10)
        # return loss, lm_acc, round(lr, 10)

    else:
        model.eval()
        with torch.no_grad():
            logits = model(enc_inputs=batch['enc_inputs'],
                           dec_inputs=batch['dec_inputs'][:,:-1])[0]

            CCE = nn.CrossEntropyLoss(ignore_index=config.pad_id)
            loss = loss = CCE(logits.view(-1, config.vocab_size), batch['dec_inputs'][:,1:].contiguous().view(-1))
            # lm_acc = cal_lm_acc(logits, batch['dec_inputs'][:,1:], config.pad_id)
            rouge_1_f1 = cal_rouge_n(logits, batch['dec_inputs'][:,1:])

        return loss, rouge_1_f1
        # return loss, lm_acc

In [38]:
# class color:
PURPLE = '\033[95m'
CYAN = '\033[96m'
DARKCYAN = '\033[36m'
BLUE = '\033[94m'
GREEN = '\033[92m'
YELLOW = '\033[93m'
RED = '\033[91m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
END = '\033[0m'

In [39]:
%%time
# train

loss_plot, val_loss_plot = [], []
lrs = []

check_list = []

best_val_rouge_1_f1 = 0
best_val_loss = 100

best_epoch = 0
patience = 0

for epoch in range(epochs):
    gc.collect()
    total_loss, total_val_loss = 0, 0
    total_rouge_1_f1, total_val_rouge_1_f1 = 0, 0
    
    tqdm_dataset = tqdm(enumerate(train_dataloader), total=train_dataloader.__len__())
    training = True
    for batch_idx, batch in tqdm_dataset:
        batch_loss, batch_rouge_1_f1, lr = train_step(batch, epoch, training)
        total_loss += batch_loss
        total_rouge_1_f1 += batch_rouge_1_f1
        
        tqdm_dataset.set_postfix({
            '%+10s' % 'Epoch': epoch + 1,
            '%10s' % GREEN + 'Loss' : '{:.4f}'.format(total_loss/(batch_idx+1)) + END,
            '%10s' % YELLOW + 'Rouge_1_F1' : '{:.4f}'.format(total_rouge_1_f1/(batch_idx+1)) + END,
            '%5s' % 'LR' : lr,
        })
            
    loss_plot.append(total_loss/(batch_idx+1))
    
    tqdm_dataset = tqdm(enumerate(valid_dataloader), total=valid_dataloader.__len__())
    training = False
    for batch_idx, batch in tqdm_dataset:
        batch_loss, batch_rouge_1_f1 = train_step(batch, epoch, training)
        total_val_loss += batch_loss
        total_val_rouge_1_f1 += batch_rouge_1_f1

        tqdm_dataset.set_postfix({
            '%+12s' % 'Epoch': epoch + 1,
            '%6s' % GREEN + 'Val Loss' : '{:.4f}'.format(total_val_loss/(batch_idx+1)) + END,
            '%6s' % YELLOW + 'Val Rouge_1_F1' : '{:.4f}'.format(total_val_rouge_1_f1/(batch_idx+1)) + END,
        })
    val_loss_plot.append(total_val_loss/(batch_idx+1)) 

    cur_val_loss = round(float((total_val_loss/(batch_idx+1)).detach().cpu()), 3)
    cur_val_rouge_1_f1 = round(float((total_val_rouge_1_f1/(batch_idx+1))), 3)

    # cur_val_loss = round(total_val_loss/(batch_idx+1), 4)
    # cur_val_rouge_1_f1 = round(total_val_rouge_1_f1/(batch_idx+1), 4)

    # if cur_val_loss < best_val_loss:
    if cur_val_rouge_1_f1 > best_val_rouge_1_f1:
        print(YELLOW + 'Best_Val_Rouge_1_F1 is updated from {:>5} to {:>5} on epoch {}'.format(best_val_rouge_1_f1, cur_val_rouge_1_f1, epoch+1) + END)
        best_val_rouge_1_f1 = cur_val_rouge_1_f1
        best_epoch = epoch+1
        torch.save(model.state_dict(), './'+save_name+'.ckpt')
        if best_epoch > 15:
            torch.save(model.state_dict(), './'+save_name+'_loss_{}_val_rouge_1_f1_{}.ckpt'.format(cur_val_loss, cur_val_rouge_1_f1))
            patience = 0
    else:
        patience += 1
    
    lrs.append(lr)
    
    if patience == early_stopping_patience:
        break

100%|██████████| 174/174 [00:37<00:00,  4.60it/s,      Epoch=1,      [92mLoss=5.3271[0m,      [93mRouge_1_F1=0.2889[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 11.29it/s,        Epoch=1,  [92mVal Loss=3.9880[0m,  [93mVal Rouge_1_F1=0.2934[0m]


[93mBest_Val_Rouge_1_F1 is updated from     0 to 0.293 on epoch 1[0m


100%|██████████| 174/174 [00:35<00:00,  4.97it/s,      Epoch=2,      [92mLoss=3.6532[0m,      [93mRouge_1_F1=0.2843[0m,    LR=0.0001]
100%|██████████| 10/10 [00:01<00:00,  9.60it/s,        Epoch=2,  [92mVal Loss=3.4263[0m,  [93mVal Rouge_1_F1=0.2937[0m]


[93mBest_Val_Rouge_1_F1 is updated from 0.293 to 0.294 on epoch 2[0m


100%|██████████| 174/174 [00:43<00:00,  4.05it/s,      Epoch=3,      [92mLoss=3.1457[0m,      [93mRouge_1_F1=0.3013[0m,    LR=0.0001]
100%|██████████| 10/10 [00:01<00:00,  8.60it/s,        Epoch=3,  [92mVal Loss=3.0891[0m,  [93mVal Rouge_1_F1=0.3053[0m]


[93mBest_Val_Rouge_1_F1 is updated from 0.294 to 0.305 on epoch 3[0m


100%|██████████| 174/174 [00:35<00:00,  4.91it/s,      Epoch=4,      [92mLoss=2.7259[0m,      [93mRouge_1_F1=0.3330[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 10.78it/s,        Epoch=4,  [92mVal Loss=2.9124[0m,  [93mVal Rouge_1_F1=0.3347[0m]


[93mBest_Val_Rouge_1_F1 is updated from 0.305 to 0.335 on epoch 4[0m


100%|██████████| 174/174 [00:36<00:00,  4.75it/s,      Epoch=5,      [92mLoss=2.3085[0m,      [93mRouge_1_F1=0.3795[0m,    LR=0.0001]
100%|██████████| 10/10 [00:01<00:00,  8.48it/s,        Epoch=5,  [92mVal Loss=2.7362[0m,  [93mVal Rouge_1_F1=0.3492[0m]


[93mBest_Val_Rouge_1_F1 is updated from 0.335 to 0.349 on epoch 5[0m


100%|██████████| 174/174 [00:38<00:00,  4.50it/s,      Epoch=6,      [92mLoss=1.9031[0m,      [93mRouge_1_F1=0.4466[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 10.01it/s,        Epoch=6,  [92mVal Loss=2.6365[0m,  [93mVal Rouge_1_F1=0.3632[0m]


[93mBest_Val_Rouge_1_F1 is updated from 0.349 to 0.363 on epoch 6[0m


100%|██████████| 174/174 [00:34<00:00,  5.06it/s,      Epoch=7,      [92mLoss=1.5212[0m,      [93mRouge_1_F1=0.5317[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 11.01it/s,        Epoch=7,  [92mVal Loss=2.5240[0m,  [93mVal Rouge_1_F1=0.3864[0m]


[93mBest_Val_Rouge_1_F1 is updated from 0.363 to 0.386 on epoch 7[0m


100%|██████████| 174/174 [00:34<00:00,  5.01it/s,      Epoch=8,      [92mLoss=1.1784[0m,      [93mRouge_1_F1=0.6236[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 11.00it/s,        Epoch=8,  [92mVal Loss=2.5095[0m,  [93mVal Rouge_1_F1=0.3928[0m]


[93mBest_Val_Rouge_1_F1 is updated from 0.386 to 0.393 on epoch 8[0m


100%|██████████| 174/174 [00:34<00:00,  5.02it/s,      Epoch=9,      [92mLoss=0.8866[0m,      [93mRouge_1_F1=0.7194[0m,    LR=0.0001]
100%|██████████| 10/10 [00:01<00:00,  9.42it/s,        Epoch=9,  [92mVal Loss=2.4431[0m,  [93mVal Rouge_1_F1=0.4012[0m]


[93mBest_Val_Rouge_1_F1 is updated from 0.393 to 0.401 on epoch 9[0m


100%|██████████| 174/174 [00:34<00:00,  5.00it/s,      Epoch=10,      [92mLoss=0.6487[0m,      [93mRouge_1_F1=0.7971[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 11.15it/s,        Epoch=10,  [92mVal Loss=2.3773[0m,  [93mVal Rouge_1_F1=0.4522[0m]


[93mBest_Val_Rouge_1_F1 is updated from 0.401 to 0.452 on epoch 10[0m


100%|██████████| 174/174 [00:33<00:00,  5.12it/s,      Epoch=11,      [92mLoss=0.4623[0m,      [93mRouge_1_F1=0.8630[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 11.10it/s,        Epoch=11,  [92mVal Loss=2.4682[0m,  [93mVal Rouge_1_F1=0.4325[0m]
100%|██████████| 174/174 [00:34<00:00,  4.99it/s,      Epoch=12,      [92mLoss=0.3245[0m,      [93mRouge_1_F1=0.9157[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 10.49it/s,        Epoch=12,  [92mVal Loss=2.4451[0m,  [93mVal Rouge_1_F1=0.4459[0m]
100%|██████████| 174/174 [00:35<00:00,  4.96it/s,      Epoch=13,      [92mLoss=0.2205[0m,      [93mRouge_1_F1=0.9514[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 10.74it/s,        Epoch=13,  [92mVal Loss=2.4810[0m,  [93mVal Rouge_1_F1=0.4486[0m]
100%|██████████| 174/174 [00:34<00:00,  5.01it/s,      Epoch=14,      [92mLoss=0.1485[0m,      [93mRouge_1_F1=0.9749[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 10.29it/s,        Epoch=14,  [9

[93mBest_Val_Rouge_1_F1 is updated from 0.452 to 0.469 on epoch 14[0m


100%|██████████| 174/174 [00:35<00:00,  4.95it/s,      Epoch=15,      [92mLoss=0.1033[0m,      [93mRouge_1_F1=0.9850[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 10.59it/s,        Epoch=15,  [92mVal Loss=2.3730[0m,  [93mVal Rouge_1_F1=0.4893[0m]


[93mBest_Val_Rouge_1_F1 is updated from 0.469 to 0.489 on epoch 15[0m


100%|██████████| 174/174 [00:35<00:00,  4.96it/s,      Epoch=16,      [92mLoss=0.0740[0m,      [93mRouge_1_F1=0.9884[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 10.22it/s,        Epoch=16,  [92mVal Loss=2.5645[0m,  [93mVal Rouge_1_F1=0.4591[0m]
100%|██████████| 174/174 [00:35<00:00,  4.95it/s,      Epoch=17,      [92mLoss=0.0539[0m,      [93mRouge_1_F1=0.9911[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 11.03it/s,        Epoch=17,  [92mVal Loss=2.6483[0m,  [93mVal Rouge_1_F1=0.4570[0m]
100%|██████████| 174/174 [00:35<00:00,  4.91it/s,      Epoch=18,      [92mLoss=0.0447[0m,      [93mRouge_1_F1=0.9920[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 10.38it/s,        Epoch=18,  [92mVal Loss=2.5947[0m,  [93mVal Rouge_1_F1=0.4596[0m]
100%|██████████| 174/174 [00:35<00:00,  4.97it/s,      Epoch=19,      [92mLoss=0.0440[0m,      [93mRouge_1_F1=0.9909[0m,    LR=0.0001]
100%|██████████| 10/10 [00:01<00:00,  9.72it/s,        Epoch=19,  [9

CPU times: user 12min 13s, sys: 35.1 s, total: 12min 48s
Wall time: 13min 42s





In [40]:
model.load_state_dict(torch.load('/content/Transformer/PyTorch/tft_model.ckpt'))

<All keys matched successfully>

In [55]:
def inference(text, enc_max_len):
    enc_inputs = mk_token_inputs(text, enc_max_len).unsqueeze(0).to(device)

    logits = model(enc_inputs)[0]
    
    outputs = torch.argmax(logits, dim=-1).to('cpu')[0]
    outputs = tokenizer.decode(outputs).split(' </s>')[0]

    return outputs

In [56]:
# text = '오늘 진짜 좋은 일 있었어!'
text = '아 슬슬 피곤하네'
# text = '잘까? 어떡하지?'
# text = '학습 잘 된 걸까?'

In [57]:
inference(text, config.max_enc_len)

'사랑하는 시간만큼 또 다시 들으세요.'

In [53]:
def inference(text, max_enc_len):
    # enc_token 생성: <string tokens>, [PAD] tokens
    enc_token = tokenizer.encode(text, max_length=max_enc_len, padding='max_length')
    # dec_token 생성: [BOS], [PAD] tokens
    dec_token = [config.bos_id]
    dec_token += [0] * (config.max_dec_len - len(dec_token))
    dec_token = dec_token[:config.max_dec_len]

    response = []
    for i in range(config.max_dec_len - 1):
        output = model(torch.tensor([enc_token]).to(device), torch.tensor([dec_token]).to(device))[0].detach().cpu().numpy()
        word_id = int(np.argmax(output, axis=2)[0][i])

        # [EOS] 토큰이 생성되면 종료
        if word_id == config.eos_id:
            break
        # 예측된 token을 응답에 저장
        response.append(word_id)
        # 예측된 token을 decoder의 다음 입력으로 저장
        dec_token[i + 1] = word_id
    
    return tokenizer.decode(response)

In [54]:
inference(text, config.max_enc_len)

'사랑하는 시간만큼 또 다시 들으세요.'

In [59]:
print('종료를 원하실 시에는 exit를 입력해주세요.')
while True:
    print("input > ", end="")
    string = str(input())
    if string == 'exit':
        break
    print(f"output > {inference(string, config.max_enc_len)}")

종료를 원하실 시에는 exit를 입력해주세요.
input > 안녕하세요
output > 안녕하세요.
input > 챗봇이에요?
output > 저는 사람으로 태어나고 싶어요.
input > exit
