# Seq2Seq 모델 구현 및 챗봇데이터 학습


In [93]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader


import pandas as pd

from konlpy.tag import Mecab


class LSTMDecoder(nn.Module):
    """LSTM decoder that maps embedded target tokens to per-step class logits.

    Args:
        embedding_dim: Size of each embedded input token.
        hidden_dim: Hidden size of the LSTM. It must match the hidden size of
            the state passed to ``forward`` (e.g. the encoder's).
        num_classes: Number of output classes (normally the vocabulary size).
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, embedding_dim, hidden_dim, num_classes=18, num_layers=2):
        super(LSTMDecoder, self).__init__()
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim,
                            num_layers=num_layers, batch_first=True)
        # Projects every LSTM output step onto the class scores.
        self.net = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, hidden=None):
        """Return logits of shape (batch, seq_len, num_classes).

        Args:
            x: Embedded tokens of shape (batch, seq_len, embedding_dim).
            hidden: Optional ``(h_0, c_0)`` initial state, each of shape
                (num_layers, batch, hidden_dim). ``None`` starts from zeros.
        """
        outputs, _ = self.lstm(x, hidden)
        return self.net(outputs)


class LSTMEncoder(nn.Module):
    """LSTM encoder that summarises an embedded sequence as its last hidden state.

    Args:
        embedding_dim: Size of each embedded input token.
        output_dim: Hidden size of the LSTM (size of the returned state).
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, embedding_dim, output_dim, num_layers=2):
        super(LSTMEncoder, self).__init__()
        # The sizes come from the constructor arguments instead of being
        # hard-coded, so the encoder matches whatever embedding feeds it.
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=output_dim,
                            num_layers=num_layers, batch_first=True)

    def forward(self, embedded_ids):
        """Return the final hidden state, shape (num_layers, batch, output_dim).

        Args:
            embedded_ids: Embedded tokens of shape
                (batch, seq_len, embedding_dim).
        """
        _, (last_hidden_state, _) = self.lstm(embedded_ids)
        return last_hidden_state


class Seq2Seq(nn.Module):
    """Encoder-decoder model with a word embedding shared by both sides.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        embedding_dim: Size of the word embeddings.
        output_dim: Hidden size used by both the encoder and the decoder.
    """

    def __init__(self, vocab_size, embedding_dim, output_dim):
        super(Seq2Seq, self).__init__()
        self.word_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.encoder = LSTMEncoder(embedding_dim, output_dim)
        self.decoder = LSTMDecoder(embedding_dim, output_dim, num_classes=vocab_size)

    def forward(self, input_ids, output_ids):
        """Return logits of shape (batch, target_len, vocab_size).

        The decoder is teacher-forced with ``output_ids``. Shifting the
        targets for the loss is left to the caller.

        Args:
            input_ids: Source token indices, shape (batch, source_len).
            output_ids: Target token indices, shape (batch, target_len).
        """
        embedded_ids = self.word_embedding(input_ids)
        encoder_hidden = self.encoder(embedded_ids)
        # The encoder exposes only its hidden state, so the decoder's cell
        # state starts from zeros.
        initial_state = (encoder_hidden, torch.zeros_like(encoder_hidden))
        embedded_targets = self.word_embedding(output_ids)
        return self.decoder(embedded_targets, initial_state)


In [3]:
chatbot_data = pd.read_csv('../chatbot_data/ChatbotData.csv')
chatbot_data

# Instantiate the morphological analyzer: utterances are split into morphemes
# with it, and those morphemes become the vocabulary entries.
pos_tagger = Mecab()

question = chatbot_data['Q']  # encoder input of the Seq2Seq model
answer = chatbot_data['A']  # decoder input of the Seq2Seq model
# Every utterance (questions followed by answers), used to build the vocab.
total_utterances = pd.concat([question, answer])

단어 임베딩용 vocab을 만들기 위해, 아래와 같이 발화를 형태소 단위로 나눈 뒤 vocab에 넣어 줍니다.

이미 들어 있는 토큰인지 확인한 뒤에 추가하므로 vocab에 중복되는 토큰은 존재하지 않습니다.

In [92]:
# Build the vocabulary: the special tokens first, then every distinct morpheme
# in first-seen order. A morpheme that is not the first one of its eojeol
# (space-separated word) is stored with a '##' continuation prefix.
special_tokens = ['[PAD]', '[MASK]', '[START]', '[END]', '[UNK]']
vocab = list(special_tokens)
# Set mirror of `vocab` so each duplicate check is O(1) rather than a list scan.
seen_tokens = set(vocab)

for utterance in total_utterances:
    for eojeol in pos_tagger.pos(utterance, flatten=False, join=True):
        for position, token in enumerate(eojeol):
            # The position inside the eojeol, not whether the token was newly
            # added, decides the prefix. The duplicate check is made on the
            # exact string that gets stored.
            entry = token if position == 0 else '##' + token
            if entry not in seen_tokens:
                seen_tokens.add(entry)
                vocab.append(entry)

vocab_size = len(vocab)

In [7]:
# Lookup tables between each token and its position in `vocab`.
index2token = dict(enumerate(vocab))
token2index = {token: index for index, token in index2token.items()}

In [8]:
# Vocabulary size, i.e. the number of entries currently in `vocab`.
vocab_size = len(vocab)

In [26]:
class WordHandler:
    """Convert sentences to vocabulary indices and back.

    Args:
        vocab: List of vocabulary tokens.
        pos_tagger: Tagger whose ``pos(sentence, flatten=False, join=True)``
            returns, per eojeol, a list of ``'morpheme/TAG'`` strings.
        token2index: Mapping from token string to index.
        index2token: Mapping from index to token string.
    """

    def __init__(self, vocab, pos_tagger, token2index, index2token):
        self.vocab = vocab
        self.pos_tagger = pos_tagger
        self.token2index = token2index
        self.index2token = index2token

    def encode(self, sentence):
        """Return the list of vocabulary indices for ``sentence``.

        Morphemes that are not the first one of their eojeol are looked up
        with the '##' continuation prefix first, matching how the vocabulary
        stores them, then without the prefix. Anything still unknown maps to
        '[UNK]'.
        """
        unk_index = self.token2index['[UNK]']
        encoded_vector = []
        for eojeol in self.pos_tagger.pos(sentence, flatten=False, join=True):
            for position, token in enumerate(eojeol):
                index = None
                if position > 0:
                    index = self.token2index.get('##' + token)
                if index is None:
                    index = self.token2index.get(token, unk_index)
                encoded_vector.append(index)

        return encoded_vector

    def decode(self, indice, join=True):
        """Return the token strings for ``indice``.

        ``join`` is accepted for compatibility but is not used.
        """
        decoded_vector = [self.index2token[index] for index in indice]

        return decoded_vector

    def decode_without_tag(self, indice):
        """Return the tokens joined by spaces, keeping only the text before '/'.

        Any '##' continuation prefix stored in the vocabulary is kept.
        """
        decoded_vector = ' '.join([self.index2token[index].split('/')[0] for index in indice])

        return decoded_vector

    @staticmethod
    def return_max_seq_len(sentences):
        """Return the length of the longest item in ``sentences`` (0 if empty)."""
        return max((len(sentence) for sentence in sentences), default=0)
            
        
        

In [74]:
class ChitChatDataset(Dataset):
    """Dataset of (input, output) index sequences with a fixed length.

    Every item is wrapped as ``[START] tokens [END]`` and then padded with
    ``[PAD]`` to exactly ``max_seq_len`` entries. Sequences that are too long
    are truncated before ``[END]`` is appended, so all items share one length
    and can be stacked into a batch.

    Args:
        input_ids: Indexable collection of token-index lists (model inputs).
        output_ids: Indexable collection of token-index lists (targets).
        index2token: Mapping from index to token string.
        token2index: Mapping from token string to index. It must contain
            '[START]', '[END]' and '[PAD]'.
        max_seq_len: Length of every returned tensor, including the
            '[START]' and '[END]' markers.

    Raises:
        ValueError: If ``max_seq_len`` is too small to hold both markers.
    """

    def __init__(self, input_ids, output_ids, index2token, token2index, max_seq_len):
        if max_seq_len < 2:
            raise ValueError('max_seq_len must be at least 2 to hold [START] and [END]')
        self.input_ids = input_ids
        self.output_ids = output_ids
        self.index2token = index2token
        self.token2index = token2index
        self.max_seq_len = max_seq_len

    def _to_padded_tensor(self, ids):
        """Wrap ``ids`` in START/END and pad or truncate it to ``max_seq_len``."""
        # Two slots are reserved for the START and END markers.
        body = list(ids)[:self.max_seq_len - 2]
        sequence = [self.token2index['[START]']] + body + [self.token2index['[END]']]
        padding_block = self.max_seq_len - len(sequence)
        sequence += [self.token2index['[PAD]']] * padding_block
        return torch.LongTensor(sequence)

    def __getitem__(self, idx):
        """Return the ``(input, output)`` LongTensor pair at position ``idx``."""
        source = self._to_padded_tensor(self.input_ids[idx])
        target = self._to_padded_tensor(self.output_ids[idx])
        return source, target

    def __len__(self):
        """Return the number of (input, output) pairs."""
        return len(self.input_ids)

In [55]:
# Bundle the vocabulary, the tagger and both lookup tables into one helper.
handler = WordHandler(
    vocab=vocab,
    pos_tagger=pos_tagger,
    token2index=token2index,
    index2token=index2token,
)

In [56]:
# Questions are the encoder input and answers are the decoder side, so the
# questions must end up in `input_ids` and the answers in `output_ids`.
input_ids = question.map(handler.encode)
output_ids = answer.map(handler.encode)

In [57]:
# Longest encoded sequence over both collections of index lists.
max_seq_len = max(
    handler.return_max_seq_len(ids) for ids in (output_ids, input_ids)
)
max_seq_len

40

In [75]:
# Wrap the encoded pairs in a dataset that uses a sequence length of 60.
chitchat_data = ChitChatDataset(
    input_ids=input_ids,
    output_ids=output_ids,
    index2token=index2token,
    token2index=token2index,
    max_seq_len=60,
)

In [77]:
# Shuffled mini-batches of 10 pairs each.
chichat_dataloader = DataLoader(dataset=chitchat_data, batch_size=10, shuffle=True)

In [107]:
# Smoke test: pass only the first batch through `embedding`, then stop.
# NOTE(review): `embedding` is not defined in the visible cells; presumably it
# is created in another notebook cell. Confirm before running.
# NOTE(review): the loop variable `input` shadows the builtin of the same name.
for input, output in chichat_dataloader:
    a = embedding(input)
    break