In [None]:
!pip install transformers==4.7 torchinfo rouge

In [None]:
!git clone https://github.com/Taeksu-Kim/Transformer.git

In [None]:
cd Transformer/PyTorch

In [None]:
!git clone https://github.com/songys/Chatbot_data.git

In [5]:
# common
import math
import random
import os
import gc
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from torchinfo import summary
from rouge import Rouge

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from transformers import AutoTokenizer

# custom
from transformer import Transformer

In [6]:
def seed_everything(seed: int = 42):
    """Seed every RNG this notebook touches and make cuDNN deterministic.

    Args:
        seed: value passed to Python's ``random``, NumPy's global RNG and
            torch's CPU and CUDA generators.
    """
    # Host-side generators.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # CUDA generators: the current device, then every visible device.
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # type: ignore
    cudnn.benchmark = False  # type: ignore


seed = 42
seed_everything(seed)

In [7]:
# Run on the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Optimisation hyper-parameters.
epochs, batch_size = 300, 64
learning_rate, weight_decay = 1e-4, 1e-2

# Stop after this many consecutive epochs without a validation-loss improvement.
early_stopping_patience = 10

# Checkpoint basename; the training loop writes './<save_name>.ckpt'.
save_name = 'chatbot_model'

In [8]:
# Load the Q/A corpus from the Chatbot_data repository cloned above
# (columns Q, A, label, as shown by df.head() below).
df = pd.read_csv('./Chatbot_data/ChatbotData.csv')

In [9]:
# Only the tokenizer of this pretrained checkpoint is used (vocabulary and
# special-token ids); the network itself is the custom Transformer built below.
model_path = "monologg/kobigbird-bert-base"
tokenizer = AutoTokenizer.from_pretrained(model_path)

Downloading:   0%|          | 0.00/870 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/241k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/492k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/169 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/373 [00:00<?, ?B/s]

In [10]:
def cal_token_len(text, tokenizer):
    """Return how many ids ``tokenizer.encode`` produces for ``text``.

    The count is whatever ``encode`` returns with its default arguments.
    """
    token_ids = tokenizer.encode(text)
    return len(token_ids)

In [11]:
# Token length of every question. Iterating the column directly avoids building
# a whole row Series for each ``df.iloc[i]`` lookup.
df['enc_token_len'] = [cal_token_len(question, tokenizer) for question in tqdm(df['Q'])]

100%|██████████| 11823/11823 [00:03<00:00, 3702.07it/s]


In [12]:
# Token length of every answer. Iterating the column directly avoids building
# a whole row Series for each ``df.iloc[i]`` lookup.
df['dec_token_len'] = [cal_token_len(answer, tokenizer) for answer in tqdm(df['A'])]

100%|██████████| 11823/11823 [00:03<00:00, 3309.52it/s]


In [13]:
# Preview the first rows together with the new token-length columns.
df.head()

Unnamed: 0,Q,A,label,enc_token_len,dec_token_len
0,12시 땡!,하루가 또 가네요.,0,6,8
1,1지망 학교 떨어졌어,위로해 드립니다.,0,8,6
2,3박4일 놀러가고 싶다,여행은 언제나 좋죠.,0,11,8
3,3박4일 정도 놀러가고 싶다,여행은 언제나 좋죠.,0,12,8
4,PPL 심하네,눈살이 찌푸려지죠.,0,6,9


In [14]:
# Token-length percentiles of the questions (used to choose max_enc_len).
tar_per_list = [95, 98, 99, 100]
tar_col = df['enc_token_len']

for i, length in zip(tar_per_list, np.percentile(tar_col, tar_per_list)):
    print('{}% length : {}'.format(i, length))

95% length : 16.0
98% length : 19.0
99% length : 20.0
100% length : 30.0


In [15]:
# Encoder length cap (token count incl. special tokens): the 99th percentile of
# question lengths printed above.
max_enc_len = 20

In [16]:
# Same percentile report for the answers (used to choose max_dec_len).
tar_col = df['dec_token_len']

for i, length in zip(tar_per_list, np.percentile(tar_col, tar_per_list)):
    print('{}% length : {}'.format(i, length))

95% length : 18.0
98% length : 20.0
99% length : 23.0
100% length : 42.0


In [17]:
# Decoder length cap (token count incl. special tokens), set somewhat above the
# 99th percentile (23) printed above; longer pairs are filtered out next.
max_dec_len = 26

In [18]:
# Keep only pairs whose question and answer both fit their length caps.
df = df.loc[df['enc_token_len'].le(max_enc_len) & df['dec_token_len'].le(max_dec_len)]

In [19]:
# Re-inspect the frame after dropping over-long pairs.
df.head()

Unnamed: 0,Q,A,label,enc_token_len,dec_token_len
0,12시 땡!,하루가 또 가네요.,0,6,8
1,1지망 학교 떨어졌어,위로해 드립니다.,0,8,6
2,3박4일 놀러가고 싶다,여행은 언제나 좋죠.,0,11,8
3,3박4일 정도 놀러가고 싶다,여행은 언제나 좋죠.,0,12,8
4,PPL 심하네,눈살이 찌푸려지죠.,0,6,9


In [20]:
# Hold out 5% of the filtered pairs for validation; seeded for reproducibility.
train, valid =  train_test_split(df, test_size=0.05, random_state=seed, shuffle=True)

In [21]:
def mk_token_inputs(text, max_seq_len, mode='encoder', tok=None):
    """Tokenize ``text`` into a fixed-length 1-D tensor of token ids.

    The ids are padded to ``max_seq_len`` and truncated to it when the text is
    longer, so every sample has the same length and batches from the default
    DataLoader collate can be stacked.

    Args:
        text: raw input string.
        max_seq_len: target sequence length, including special tokens.
        mode: ``'encoder'`` keeps the tokenizer's [CLS]/[SEP] ids; ``'decoder'``
            replaces the first [CLS] with the BOS id and the first [SEP] with
            the EOS id.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.

    Returns:
        int64 ``torch.Tensor`` of token ids.
    """
    if tok is None:
        tok = tokenizer

    # padding='max_length' alone does not shorten long inputs; without
    # truncation=True a long prompt (e.g. at inference) would exceed max_seq_len.
    input_ids = tok.encode(text, max_length=max_seq_len, padding='max_length', truncation=True)

    if mode == 'decoder':
        # Locate both markers before overwriting either one.
        cls_idx = input_ids.index(tok.cls_token_id)
        sep_idx = input_ids.index(tok.sep_token_id)

        input_ids[cls_idx] = tok.bos_token_id
        input_ids[sep_idx] = tok.eos_token_id

    return torch.tensor(input_ids, dtype=torch.long)

In [22]:
class chatbot_dataset(Dataset):
    """Map-style dataset over a DataFrame with question ('Q') and answer ('A') columns.

    Each item is a dict with:
        'enc_inputs': question ids from ``mk_token_inputs`` (padded to enc_max_len);
        'dec_inputs': answer ids from ``mk_token_inputs`` in decoder mode
                      (padded to dec_max_len).
    """

    def __init__(self, df, enc_max_len, dec_max_len):
        self.df = df
        self.enc_max_len, self.dec_max_len = enc_max_len, dec_max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        question = self.df['Q'].iloc[index]
        answer = self.df['A'].iloc[index]
        enc = mk_token_inputs(question, self.enc_max_len)
        dec = mk_token_inputs(answer, self.dec_max_len, mode='decoder')
        return {'enc_inputs': enc, 'dec_inputs': dec}

In [23]:
# Decoder sequences get one extra slot: the training step feeds dec_inputs[:, :-1]
# to the decoder and scores against dec_inputs[:, 1:], both of length max_dec_len.
train_dataset = chatbot_dataset(train, max_enc_len, max_dec_len+1)
valid_dataset = chatbot_dataset(valid, max_enc_len, max_dec_len+1)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=0, shuffle=True)
# No shuffling for validation: the per-epoch mean of batch losses drives
# checkpointing and early stopping, so it should use the same batches every epoch.
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, num_workers=0, shuffle=False)

In [24]:
# Grab one training batch to drive the model summaries below.
batch = next(iter(train_dataloader))

In [25]:
# Config class
# A dict subclass whose keys can also be read and written as attributes
# (config.arg instead of config['arg']), optionally loaded from a JSON file.
class Config(dict):
    """Dictionary with attribute-style access to its keys."""

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails. Raise AttributeError
        # (not KeyError) for missing keys so that hasattr(), getattr(obj, name,
        # default), copy.deepcopy and pickle, which probe optional attributes,
        # keep working.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    __setattr__ = dict.__setitem__

    @classmethod
    def load(cls, file):
        """Read a JSON object from ``file`` and return it as a ``cls`` instance."""
        import json  # not imported by the notebook's import cell

        with open(file, 'r', encoding='utf-8') as f:
            return cls(json.load(f))

In [26]:
# Small configuration (d_model=256, 3 encoder + 3 decoder layers), used here to
# inspect its size; a larger configuration replaces it further down.
# NOTE(review): the key 'drop_out_raito' is kept exactly as spelled; presumably the
# `transformer` module reads it under this name — confirm before renaming.
config_dict = dict(
    vocab_size=tokenizer.vocab_size,
    d_model=256,
    max_enc_len=max_enc_len,
    max_dec_len=max_dec_len,
    pad_id=tokenizer.pad_token_id,
    bos_id=tokenizer.bos_token_id,
    eos_id=tokenizer.eos_token_id,
    use_decoder=True,
    init_std=2e-2,
    norm_eps=1e-12,
    drop_out_raito=0.1,
    num_enc_layers=3,
    num_dec_layers=3,
    num_att_heads=4,
    feed_forward_dim=1024,
)

config = Config(config_dict)

In [27]:
# Instantiate the small (d_model=256, 3+3 layers) configuration to inspect its size.
model = Transformer(config)

In [28]:
# Layer shapes / parameter counts on a real batch. The decoder receives the same
# slice the training step feeds it (dec_inputs[:, :-1]); [:, 1:] is the target.
enc_inputs = batch['enc_inputs']
dec_inputs = batch['dec_inputs'][:, :-1]
summary(model, input_data=[enc_inputs, dec_inputs])

Layer (type:depth-idx)                                            Output Shape              Param #
Transformer                                                       [64, 26, 32500]           --
├─TransformerEncoder: 1-1                                         [64, 20, 256]             --
│    └─Embedding: 2-1                                             [64, 20, 256]             8,320,000
│    └─ModuleList: 2-2                                            --                        --
│    │    └─TransformerEncoderLayer: 3-1                          [64, 20, 256]             789,760
│    │    └─TransformerEncoderLayer: 3-2                          [64, 20, 256]             789,760
│    │    └─TransformerEncoderLayer: 3-3                          [64, 20, 256]             789,760
├─TransformerDecoder: 1-2                                         [64, 26, 32500]           --
│    └─Embedding: 2-3                                             [64, 26, 256]             8,320,000
│    └─ModuleLis

In [29]:
# Configuration that is actually trained: d_model=512, 6 encoder + 6 decoder layers.
# NOTE(review): the key 'drop_out_raito' is kept exactly as spelled; presumably the
# `transformer` module reads it under this name — confirm before renaming.
config_dict = dict(
    vocab_size=tokenizer.vocab_size,
    d_model=512,
    max_enc_len=max_enc_len,
    max_dec_len=max_dec_len,
    pad_id=tokenizer.pad_token_id,
    bos_id=tokenizer.bos_token_id,
    eos_id=tokenizer.eos_token_id,
    use_decoder=True,
    init_std=2e-2,
    norm_eps=1e-12,
    drop_out_raito=0.1,
    num_enc_layers=6,
    num_dec_layers=6,
    num_att_heads=4,
    feed_forward_dim=1024,
)

config = Config(config_dict)

In [30]:
# Rebuild the model with the larger configuration (d_model=512, 6+6 layers);
# this instance is moved to the device and optimised below.
model = Transformer(config)

In [31]:
# Layer shapes / parameter counts on a real batch. The decoder receives the same
# slice the training step feeds it (dec_inputs[:, :-1]); [:, 1:] is the target.
enc_inputs = batch['enc_inputs']
dec_inputs = batch['dec_inputs'][:, :-1]
summary(model, input_data=[enc_inputs, dec_inputs])

Layer (type:depth-idx)                                            Output Shape              Param #
Transformer                                                       [64, 26, 32500]           --
├─TransformerEncoder: 1-1                                         [64, 20, 512]             --
│    └─Embedding: 2-1                                             [64, 20, 512]             16,640,000
│    └─ModuleList: 2-2                                            --                        --
│    │    └─TransformerEncoderLayer: 3-1                          [64, 20, 512]             2,102,784
│    │    └─TransformerEncoderLayer: 3-2                          [64, 20, 512]             2,102,784
│    │    └─TransformerEncoderLayer: 3-3                          [64, 20, 512]             2,102,784
│    │    └─TransformerEncoderLayer: 3-4                          [64, 20, 512]             2,102,784
│    │    └─TransformerEncoderLayer: 3-5                          [64, 20, 512]             2,102,784
│ 

In [32]:
# Move the model's parameters to the selected device (GPU if available).
model.to(device)

Transformer(
  (encoder): TransformerEncoder(
    (word_embedding): Embedding(32500, 512, padding_idx=0)
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attention): AddNorm(
          (layer): MultiHeadAttention(
            (query_proj): Linear(in_features=512, out_features=512, bias=True)
            (key_proj): Linear(in_features=512, out_features=512, bias=True)
            (value_proj): Linear(in_features=512, out_features=512, bias=True)
            (scaled_dot_attn): ScaledDotProductAttention()
            (linear): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
        )
        (feed_forward): AddNorm(
          (layer): PoswiseFeedForward(
            (feed_forward): Sequential(
              (0): Linear(in_features=512, out_features=1024, bias=True)
              (1): Dropout(p=0.1, inplace=False)
              (2): ReLU()
              (3): Linear(in_f

In [33]:
# AdamW over all model parameters; lr and weight decay come from the hyper-parameter cell.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [34]:
# Shared ROUGE scorer used by cal_rouge_n for the training/validation metric.
rouge = Rouge()

In [35]:
def cal_rouge_n(y_pred, y_true, tok=None, scorer=None):
    """Mean per-sample ROUGE-1 F1 between greedy predictions and targets.

    Args:
        y_pred: logits (batch, seq, vocab); argmax over the last dim gives the
            predicted ids.
        y_true: target ids (batch, seq).
        tok: tokenizer used for decoding; defaults to the notebook ``tokenizer``.
        scorer: object with a rouge-style ``get_scores(hyp, ref)``; defaults to
            the notebook ``rouge``.

    Returns:
        Average of the per-sample scores (each rounded to 4 decimals), or 0.0
        for an empty batch. A sample scores 0 when its decoded prediction has
        no ' </s>' marker, or when the prediction or reference is empty once
        cut at that marker.
    """
    tok = tokenizer if tok is None else tok
    scorer = rouge if scorer is None else scorer

    # One device->host transfer per batch instead of one per decoded row.
    pred_ids = torch.argmax(y_pred, dim=-1).tolist()
    true_ids = y_true.tolist()
    if not true_ids:
        return 0.0

    scores = []
    for pred_row, true_row in zip(pred_ids, true_ids):
        score = 0
        reference = tok.decode(true_row)
        hypothesis = tok.decode(pred_row)
        if ' </s>' in hypothesis:
            # Keep the text before the first EOS marker. NOTE(review): '.' is
            # rewritten presumably because the rouge package splits sentences
            # on '.'; ' +002E' keeps it as an ordinary token — confirm.
            reference = reference.split(' </s>')[0].replace('.', ' +002E')
            hypothesis = hypothesis.split(' </s>')[0].replace('.', ' +002E')

            # rouge's get_scores raises ValueError on an empty hypothesis or reference.
            if hypothesis and reference:
                score = scorer.get_scores(hypothesis, reference)[0]['rouge-1']['f']
        scores.append(round(score, 4))

    return sum(scores) / len(scores)

In [36]:
def cal_lm_acc(y_pred, y_true, pad_id):
    """Token-level accuracy of greedy predictions, ignoring padding positions.

    :param y_pred: predicted logits (bs, n_seq, n_vocab)
    :param y_true: gold token ids (bs, n_seq)
    :param pad_id: id marking padding in ``y_true``; those positions are excluded
    :return: 0-dim tensor = correct non-pad tokens / max(number of non-pad tokens, 1)
    """
    predicted = y_pred.argmax(dim=-1).int()
    not_pad = y_true.ne(pad_id).int()
    # A prediction only counts where the target is a real (non-pad) token.
    correct = y_true.eq(predicted).int() * not_pad
    # Clamping the denominator at 1 avoids 0/0 on an all-padding batch.
    return correct.sum() / not_pad.sum().clamp(min=1)

In [37]:
def _teacher_forced_loss_and_rouge(batch):
    """Run one teacher-forced forward pass; return ``(loss, rouge_1_f1)``.

    The decoder is fed ``dec_inputs[:, :-1]`` and scored against
    ``dec_inputs[:, 1:]``; target positions equal to ``config.pad_id`` are
    ignored by the loss.
    """
    targets = batch['dec_inputs'][:, 1:]
    logits = model(enc_inputs=batch['enc_inputs'],
                   dec_inputs=batch['dec_inputs'][:, :-1])[0]

    criterion = nn.CrossEntropyLoss(ignore_index=config.pad_id)
    # reshape (not view): the sliced targets/logits are not guaranteed contiguous.
    loss = criterion(logits.reshape(-1, config.vocab_size), targets.reshape(-1))
    rouge_1_f1 = cal_rouge_n(logits, targets)
    return loss, rouge_1_f1


def train_step(batch, epoch, training):
    """Run one optimisation step (``training=True``) or one evaluation step.

    Args:
        batch: dict with 'enc_inputs' and 'dec_inputs' tensors; moved to ``device`` here.
        epoch: current epoch index (unused; kept for the caller's signature).
        training: True for a training step, otherwise evaluation under no_grad.

    Returns:
        Training: ``(loss, rouge_1_f1, lr)``; evaluation: ``(loss, rouge_1_f1)``.
        ``loss`` is a 0-dim tensor without autograd history, so the caller can
        sum it across batches without keeping every batch's graph alive.
    """
    batch = {key: value.to(device) for key, value in batch.items()}

    if training is True:
        model.train()
        optimizer.zero_grad()

        # NOTE(review): float16 autocast is normally paired with
        # torch.cuda.amp.GradScaler to avoid gradient underflow; none is used here.
        with torch.cuda.amp.autocast():
            loss, rouge_1_f1 = _teacher_forced_loss_and_rouge(batch)

        loss.backward()
        optimizer.step()

        lr = optimizer.param_groups[0]["lr"]
        # Detach: the caller accumulates this value across the whole epoch.
        return loss.detach(), rouge_1_f1, round(lr, 10)

    model.eval()
    with torch.no_grad():
        loss, rouge_1_f1 = _teacher_forced_loss_and_rouge(batch)
    return loss, rouge_1_f1
        # return loss, lm_acc

In [38]:
# ANSI escape sequences for coloured progress-bar postfixes and log lines
# (END resets the terminal colour).
GREEN, YELLOW, END = '\033[92m', '\033[93m', '\033[0m'

In [39]:
%%time
# Training loop. Each epoch runs one pass over train_dataloader (optimisation)
# and one over valid_dataloader (evaluation). The model is checkpointed whenever
# the rounded mean validation loss improves, and training stops after
# `early_stopping_patience` consecutive epochs without improvement.

# Per-epoch history.
loss_plot, val_loss_plot = [], []
lrs = []

# NOTE(review): best_val_rouge_1_f1 is initialised but never updated or read.
best_val_rouge_1_f1 = 0
best_val_loss = np.inf

best_epoch = 0
patience = 0

for epoch in range(epochs):
    gc.collect()
    total_loss, total_val_loss = 0, 0
    total_rouge_1_f1, total_val_rouge_1_f1 = 0, 0
    
    # --- training pass ---
    tqdm_dataset = tqdm(enumerate(train_dataloader), total=train_dataloader.__len__())
    training = True
    for batch_idx, batch in tqdm_dataset:
        batch_loss, batch_rouge_1_f1, lr = train_step(batch, epoch, training)
        # NOTE(review): batch_loss is a tensor; if it still carries autograd
        # history, this running sum keeps every batch's graph alive for the epoch.
        total_loss += batch_loss
        total_rouge_1_f1 += batch_rouge_1_f1
        
        # Running means in the progress bar (keys padded for column alignment).
        tqdm_dataset.set_postfix({
            '%+10s' % 'Epoch': epoch + 1,
            '%10s' % GREEN + 'Loss' : '{:.4f}'.format(total_loss/(batch_idx+1)) + END,
            '%10s' % YELLOW + 'Rouge_1_F1' : '{:.4f}'.format(total_rouge_1_f1/(batch_idx+1)) + END,
            '%5s' % 'LR' : lr,
        })

    # Mean training loss over the epoch's batches, rounded to 4 decimals.
    train_epoch_loss = round(float((total_loss/(batch_idx+1)).detach().cpu()), 4)
    loss_plot.append(train_epoch_loss)
    
    # --- validation pass (train_step evaluates under torch.no_grad) ---
    tqdm_dataset = tqdm(enumerate(valid_dataloader), total=valid_dataloader.__len__())
    training = False
    for batch_idx, batch in tqdm_dataset:
        batch_loss, batch_rouge_1_f1 = train_step(batch, epoch, training)
        total_val_loss += batch_loss
        total_val_rouge_1_f1 += batch_rouge_1_f1

        tqdm_dataset.set_postfix({
            '%+12s' % 'Epoch': epoch + 1,
            '%6s' % GREEN + 'Val Loss' : '{:.4f}'.format(total_val_loss/(batch_idx+1)) + END,
            '%6s' % YELLOW + 'Val Rouge_1_F1' : '{:.4f}'.format(total_val_rouge_1_f1/(batch_idx+1)) + END,
        })

    # Mean validation loss over the epoch's batches, rounded to 4 decimals.
    valid_epoch_loss = round(float((total_val_loss/(batch_idx+1)).detach().cpu()), 4)
    val_loss_plot.append(valid_epoch_loss) 

    # Checkpoint on improvement; otherwise count towards early stopping.
    if valid_epoch_loss < best_val_loss:
        print(YELLOW + 'Best_Val_Loss is updated from {:>5} to {:>5} on epoch {}'.format(best_val_loss, valid_epoch_loss, epoch+1) + END)
        best_val_loss = valid_epoch_loss
        best_epoch = epoch+1
        torch.save(model.state_dict(), './'+save_name+'.ckpt')
        patience = 0
    else:
        patience += 1
    
    # lr is the value returned by the last training step of this epoch.
    lrs.append(lr)
    
    if patience == early_stopping_patience:
        break

100%|██████████| 174/174 [00:26<00:00,  6.62it/s,      Epoch=1,      [92mLoss=5.3277[0m,      [93mRouge_1_F1=0.2886[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.75it/s,        Epoch=1,  [92mVal Loss=3.9877[0m,  [93mVal Rouge_1_F1=0.2986[0m]


[93mBest_Val_Loss is updated from   inf to 3.9877 on epoch 1[0m


100%|██████████| 174/174 [00:24<00:00,  6.99it/s,      Epoch=2,      [92mLoss=3.6535[0m,      [93mRouge_1_F1=0.2850[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 14.15it/s,        Epoch=2,  [92mVal Loss=3.4309[0m,  [93mVal Rouge_1_F1=0.2924[0m]


[93mBest_Val_Loss is updated from 3.9877 to 3.4309 on epoch 2[0m


100%|██████████| 174/174 [00:22<00:00,  7.83it/s,      Epoch=3,      [92mLoss=3.1448[0m,      [93mRouge_1_F1=0.3019[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.95it/s,        Epoch=3,  [92mVal Loss=3.0863[0m,  [93mVal Rouge_1_F1=0.3049[0m]


[93mBest_Val_Loss is updated from 3.4309 to 3.0863 on epoch 3[0m


100%|██████████| 174/174 [00:22<00:00,  7.61it/s,      Epoch=4,      [92mLoss=2.7256[0m,      [93mRouge_1_F1=0.3329[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.65it/s,        Epoch=4,  [92mVal Loss=2.9097[0m,  [93mVal Rouge_1_F1=0.3348[0m]


[93mBest_Val_Loss is updated from 3.0863 to 2.9097 on epoch 4[0m


100%|██████████| 174/174 [00:22<00:00,  7.71it/s,      Epoch=5,      [92mLoss=2.3074[0m,      [93mRouge_1_F1=0.3805[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.54it/s,        Epoch=5,  [92mVal Loss=2.7410[0m,  [93mVal Rouge_1_F1=0.3491[0m]


[93mBest_Val_Loss is updated from 2.9097 to 2.741 on epoch 5[0m


100%|██████████| 174/174 [00:22<00:00,  7.70it/s,      Epoch=6,      [92mLoss=1.9024[0m,      [93mRouge_1_F1=0.4469[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.61it/s,        Epoch=6,  [92mVal Loss=2.6504[0m,  [93mVal Rouge_1_F1=0.3648[0m]


[93mBest_Val_Loss is updated from 2.741 to 2.6504 on epoch 6[0m


100%|██████████| 174/174 [00:22<00:00,  7.70it/s,      Epoch=7,      [92mLoss=1.5208[0m,      [93mRouge_1_F1=0.5299[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.68it/s,        Epoch=7,  [92mVal Loss=2.5323[0m,  [93mVal Rouge_1_F1=0.3903[0m]


[93mBest_Val_Loss is updated from 2.6504 to 2.5323 on epoch 7[0m


100%|██████████| 174/174 [00:22<00:00,  7.71it/s,      Epoch=8,      [92mLoss=1.1783[0m,      [93mRouge_1_F1=0.6255[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.76it/s,        Epoch=8,  [92mVal Loss=2.5209[0m,  [93mVal Rouge_1_F1=0.3940[0m]


[93mBest_Val_Loss is updated from 2.5323 to 2.5209 on epoch 8[0m


100%|██████████| 174/174 [00:22<00:00,  7.68it/s,      Epoch=9,      [92mLoss=0.8864[0m,      [93mRouge_1_F1=0.7174[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.78it/s,        Epoch=9,  [92mVal Loss=2.4424[0m,  [93mVal Rouge_1_F1=0.4024[0m]


[93mBest_Val_Loss is updated from 2.5209 to 2.4424 on epoch 9[0m


100%|██████████| 174/174 [00:22<00:00,  7.74it/s,      Epoch=10,      [92mLoss=0.6472[0m,      [93mRouge_1_F1=0.7989[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.84it/s,        Epoch=10,  [92mVal Loss=2.3564[0m,  [93mVal Rouge_1_F1=0.4462[0m]


[93mBest_Val_Loss is updated from 2.4424 to 2.3564 on epoch 10[0m


100%|██████████| 174/174 [00:22<00:00,  7.62it/s,      Epoch=11,      [92mLoss=0.4604[0m,      [93mRouge_1_F1=0.8654[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.70it/s,        Epoch=11,  [92mVal Loss=2.4739[0m,  [93mVal Rouge_1_F1=0.4330[0m]
100%|██████████| 174/174 [00:22<00:00,  7.75it/s,      Epoch=12,      [92mLoss=0.3228[0m,      [93mRouge_1_F1=0.9153[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.86it/s,        Epoch=12,  [92mVal Loss=2.4295[0m,  [93mVal Rouge_1_F1=0.4474[0m]
100%|██████████| 174/174 [00:22<00:00,  7.65it/s,      Epoch=13,      [92mLoss=0.2186[0m,      [93mRouge_1_F1=0.9528[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.80it/s,        Epoch=13,  [92mVal Loss=2.4883[0m,  [93mVal Rouge_1_F1=0.4499[0m]
100%|██████████| 174/174 [00:22<00:00,  7.75it/s,      Epoch=14,      [92mLoss=0.1488[0m,      [93mRouge_1_F1=0.9747[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.52it/s,        Epoch=14,  [9

[93mBest_Val_Loss is updated from 2.3564 to 2.3523 on epoch 14[0m


100%|██████████| 174/174 [00:22<00:00,  7.64it/s,      Epoch=15,      [92mLoss=0.1046[0m,      [93mRouge_1_F1=0.9839[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.67it/s,        Epoch=15,  [92mVal Loss=2.3416[0m,  [93mVal Rouge_1_F1=0.4879[0m]


[93mBest_Val_Loss is updated from 2.3523 to 2.3416 on epoch 15[0m


100%|██████████| 174/174 [00:22<00:00,  7.70it/s,      Epoch=16,      [92mLoss=0.0770[0m,      [93mRouge_1_F1=0.9880[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.38it/s,        Epoch=16,  [92mVal Loss=2.5229[0m,  [93mVal Rouge_1_F1=0.4630[0m]
100%|██████████| 174/174 [00:22<00:00,  7.71it/s,      Epoch=17,      [92mLoss=0.0575[0m,      [93mRouge_1_F1=0.9902[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.91it/s,        Epoch=17,  [92mVal Loss=2.6043[0m,  [93mVal Rouge_1_F1=0.4639[0m]
100%|██████████| 174/174 [00:22<00:00,  7.71it/s,      Epoch=18,      [92mLoss=0.0462[0m,      [93mRouge_1_F1=0.9908[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.63it/s,        Epoch=18,  [92mVal Loss=2.5596[0m,  [93mVal Rouge_1_F1=0.4633[0m]
100%|██████████| 174/174 [00:22<00:00,  7.72it/s,      Epoch=19,      [92mLoss=0.0413[0m,      [93mRouge_1_F1=0.9914[0m,    LR=0.0001]
100%|██████████| 10/10 [00:00<00:00, 13.72it/s,        Epoch=19,  [9

CPU times: user 9min 43s, sys: 19.1 s, total: 10min 3s
Wall time: 10min 1s





In [40]:
# Restore the checkpoint with the lowest validation loss saved by the training
# loop. Tensors are mapped onto the current device so the load also works when
# the checkpoint was written on a different one (e.g. GPU -> CPU). Then switch
# to evaluation mode.
model.load_state_dict(torch.load('./'+save_name+'.ckpt', map_location=device))
model.eval()

Transformer(
  (encoder): TransformerEncoder(
    (word_embedding): Embedding(32500, 512, padding_idx=0)
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attention): AddNorm(
          (layer): MultiHeadAttention(
            (query_proj): Linear(in_features=512, out_features=512, bias=True)
            (key_proj): Linear(in_features=512, out_features=512, bias=True)
            (value_proj): Linear(in_features=512, out_features=512, bias=True)
            (scaled_dot_attn): ScaledDotProductAttention()
            (linear): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
        )
        (feed_forward): AddNorm(
          (layer): PoswiseFeedForward(
            (feed_forward): Sequential(
              (0): Linear(in_features=512, out_features=1024, bias=True)
              (1): Dropout(p=0.1, inplace=False)
              (2): ReLU()
              (3): Linear(in_f

In [41]:
def inference(text, enc_max_len):
    """Reply to ``text`` with a single forward pass.

    The model is called with encoder inputs only. NOTE(review): this presumably
    makes the custom Transformer decode on its own; confirm in `transformer`.
    The argmax token is taken at every position, and the decoded text is cut
    at the first ' </s>'. This function is redefined by a later cell with
    token-by-token greedy decoding.
    """
    enc_inputs = mk_token_inputs(text, enc_max_len).unsqueeze(0).to(device)

    # No autograd graph is needed at inference time.
    with torch.no_grad():
        logits = model(enc_inputs)[0]

    outputs = torch.argmax(logits, dim=-1).to('cpu')[0]
    outputs = tokenizer.decode(outputs).split(' </s>')[0]

    return outputs

In [42]:
# Sample prompt for a quick qualitative check; uncomment an alternative to try it.
# text = '오늘 진짜 좋은 일 있었어!'  # "Something really good happened today!"
text = '아 슬슬 피곤하네'  # "Ah, I'm starting to get tired"
# text = '잘까? 어떡하지?'  # "Should I sleep? What should I do?"
# text = '학습 잘 된 걸까?'  # "Did the training go well?"

In [43]:
# Single-pass prediction using the inference() defined in the previous cell.
inference(text, config.max_enc_len)

'사랑하는 것보다 낫죠.'

In [44]:
def inference(text, max_enc_len):
    """Generate a reply to ``text`` with greedy, token-by-token decoding.

    The decoder input starts as [BOS] followed by padding. At step ``i`` the
    model's argmax at position ``i`` becomes the next token and is written into
    slot ``i + 1``. Decoding stops at EOS or after ``config.max_dec_len - 1`` tokens.

    Args:
        text: user utterance.
        max_enc_len: encoder length; longer inputs are truncated to it.

    Returns:
        The decoded reply (EOS not included).
    """
    # Encoder tokens, padded and truncated to max_enc_len. The input does not
    # change between steps, so it is built once.
    enc_token = tokenizer.encode(text, max_length=max_enc_len, padding='max_length', truncation=True)
    enc_tensor = torch.tensor([enc_token]).to(device)

    # Decoder tokens: [BOS] followed by padding, max_dec_len long.
    dec_token = [config.bos_id] + [config.pad_id] * (config.max_dec_len - 1)
    dec_token = dec_token[:config.max_dec_len]

    response = []
    # No autograd graph is needed while generating.
    with torch.no_grad():
        for i in range(config.max_dec_len - 1):
            logits = model(enc_tensor, torch.tensor([dec_token]).to(device))[0]
            # Greedy choice for the token that follows position i.
            word_id = int(torch.argmax(logits[0, i]))

            # Stop once the EOS token is generated.
            if word_id == config.eos_id:
                break
            # Keep the token for the reply and feed it back as the next decoder input.
            response.append(word_id)
            dec_token[i + 1] = word_id

    return tokenizer.decode(response)

In [45]:
# Greedy token-by-token decoding using the redefined inference().
inference(text, config.max_enc_len)

'사랑하는 것보다 낫죠.'

In [46]:
# Interactive chat loop: read a line, print the model's reply, stop on "exit".
# (The banner below says: "Type exit to quit.")
print('종료를 원하실 시에는 exit를 입력해주세요.')
while True:
    print("input > ", end="")
    string = str(input())
    if string == 'exit':
        break
    print(f"output > {inference(string, config.max_enc_len)}")

종료를 원하실 시에는 exit를 입력해주세요.
input > 안녕하세요
output > 안녕하세요.
input > 너는 누구니?
output > 저는 위로봇입니다.
input > 오늘 너무 피곤해
output > 푹 쉬세요.
input > 공부 더 안하고 쉬어도 될까?
output > 지금처럼 잘하고 있어요.
input > 그래도 미래가 불안해
output > 서로 예의가 없다고 생각해보세요.
input > 그건 무슨 말이야?
output > 저도 좋아해요.
input > 킹받네
output > 직접 물어보세요.
input > exit
