In [1]:
import re
import torch

from src.model import Tformer, save
from src.dataset import Preprocessing, MakeDataset

In [27]:
class E2E_dialog:
    """End-to-end chit-chat responder backed by a pretrained ``Tformer`` checkpoint.

    The tokenizer is taken from ``dataset.transformers_tokenizer`` and must
    provide ``vocab_size()``, ``encode_as_ids(text)`` and ``Decode(ids)``.
    Replies are produced by greedy decoding: at every step the single most
    probable next token is appended until the end token is predicted or
    ``MAX_LENGTH`` steps have been taken.
    """

    # Token ids wrapped around the encoded input; decoding starts from
    # START_TOKEN_ID and stops as soon as END_TOKEN_ID is predicted.
    START_TOKEN_ID = 2
    END_TOKEN_ID = 3

    def __init__(self, dataset, model_path):
        """Build the model and load its weights on the CPU.

        Args:
            dataset: Object exposing a ``transformers_tokenizer`` attribute.
            model_path: Path to a ``state_dict`` checkpoint readable by
                ``torch.load``.
        """
        self.vocab = dataset.transformers_tokenizer
        self.vocab_size = self.vocab.vocab_size()

        # NOTE(review): these hyper-parameters presumably have to match the ones
        # the checkpoint was trained with, otherwise load_state_dict will fail.
        self.model = Tformer(
            num_tokens=self.vocab_size,
            dim_model=256,
            num_heads=8,
            dff=512,
            num_layers=2,
            dropout_p=0.1,
        )
        device = torch.device('cpu')
        self.model.load_state_dict(torch.load(model_path, map_location=device))
        self.model.eval()  # inference mode for layers such as dropout
        self.MAX_LENGTH = 50  # upper bound on the number of decoding steps

    def preprocess_sentence(self, sentence):
        """Return ``sentence`` with leading/trailing whitespace removed."""
        # NOTE(review): the replacement is the matched group itself, so this
        # substitution leaves the text unchanged. It is kept as-is on purpose;
        # confirm against the training-time preprocessing before altering it.
        sentence = re.sub(r"([?.!,])", r"\1", sentence)
        return sentence.strip()

    def evaluate(self, sentence):
        """Greedily decode a reply for ``sentence``.

        Args:
            sentence: Raw user utterance.

        Returns:
            A tuple ``(token_ids, mean_prob)``:
            ``token_ids`` is a 1-D NumPy array that begins with the start token
            and holds every predicted token except the end token;
            ``mean_prob`` is a 0-d NumPy array with the average softmax
            probability of the token chosen at each step (the step that
            predicted the end token is included).
        """
        sentence = self.preprocess_sentence(sentence)
        source_ids = (
            [self.START_TOKEN_ID]
            + self.vocab.encode_as_ids(sentence)
            + [self.END_TOKEN_ID]
        )
        source = torch.tensor([source_ids])
        output = torch.tensor([[self.START_TOKEN_ID]])

        step_probs = []
        # Inference only: disabling autograd avoids building a graph for each
        # forward pass of the decoding loop.
        with torch.no_grad():
            # The source never changes while decoding, so its masks are built once.
            src_mask = self.model.generate_square_subsequent_mask(source.shape[1])
            src_padding_mask = self.model.gen_attention_mask(source)

            for _ in range(self.MAX_LENGTH):
                tgt_mask = self.model.generate_square_subsequent_mask(output.shape[1])
                tgt_padding_mask = self.model.gen_attention_mask(output)

                # assumes the model returns (tgt_len, batch, vocab) — TODO confirm;
                # after the transpose the decoding position is on axis 1.
                logits = self.model(
                    source, output, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
                ).transpose(0, 1)

                # Keep only the distribution for the most recent position.
                last_step = logits[:, -1:, :].reshape(-1).cpu()
                best = torch.max(torch.softmax(last_step, dim=0), dim=-1)
                step_probs.append(best.values)
                predicted_id = best.indices.view(1, 1)

                # Stop before appending when the end token is predicted.
                if predicted_id.item() == self.END_TOKEN_ID:
                    break

                # Feed the chosen token back in as decoder input for the next step.
                output = torch.cat([output, predicted_id], dim=1)

        mean_prob = sum(step_probs) / len(step_probs)
        return torch.squeeze(output, dim=0).cpu().numpy(), mean_prob.detach().numpy()

    def predict(self, sentence):
        """Decode a reply, print the exchange and return it.

        Args:
            sentence: Raw user utterance.

        Returns:
            ``(predicted_sentence, predicted_sentence_p)`` — the decoded reply
            and the mean per-step probability reported by ``evaluate``.
        """
        prediction, predicted_sentence_p = self.evaluate(sentence)
        # Drop ids outside the vocabulary before decoding; the leading start
        # token is still part of ``prediction`` and is handed to the tokenizer.
        kept_ids = [int(token) for token in prediction if token < self.vocab_size]
        predicted_sentence = self.vocab.Decode(kept_ids)

        print('Input: {}'.format(sentence))
        print('Output: {}'.format(predicted_sentence))

        return predicted_sentence, predicted_sentence_p

In [28]:
# Relative path to the chit-chat transformer checkpoint that is passed to
# E2E_dialog as `model_path` and read with torch.load.
# NOTE(review): the directory name is spelled "trasnformer"; it presumably
# mirrors the folder on disk, so leave the string untouched unless the folder
# is renamed too.
chitchat_pretrain_path = './chatbot_data/pretraining/4_chitchat_trasnformer_model/chitchat_transformer_0.035401_steps_69.pt'

In [29]:
# MakeDataset supplies the tokenizer (`dataset.transformers_tokenizer`) that
# E2E_dialog uses for encoding the input and decoding the reply.
dataset = MakeDataset()
# Build the responder and load the pretrained weights from the checkpoint path.
e2e = E2E_dialog(dataset, chitchat_pretrain_path)

In [33]:
%%time
# Time a single greedy-decoded reply: `s` is the decoded sentence and `p` is
# the mean per-step probability returned alongside it.
s, p = e2e.predict('오')

Input: 오
Output: 제가 들어드릴게요.
CPU times: total: 62.5 ms
Wall time: 24.9 ms


In [31]:
# `p` comes from the predict cell and is a 0-d NumPy array; convert it to a
# plain Python float so the notebook displays a scalar.
float(p)

0.4220072329044342