In [1]:
import math

import sentencepiece as spm
import torch.nn as nn

sp = spm.SentencePieceProcessor()
sp.load("sp-bpt-anderson.model")

True

# Model

In [2]:
import torch

from model import TransformerModule

max_seq_len = 128

In [3]:
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("sp-bpt-anderson.model")

True

In [5]:
model_path = (
    "checkpoints/machine-translation-mode=1.0-epoch=94-step=415625-loss=1.9441.ckpt"
)
model = TransformerModule.load_from_checkpoint(model_path)
model.eval()
model.to("cuda")

TransformerModule(
  (src_embedding): Embedding(8000, 256)
  (tgt_embedding): Embedding(8000, 256)
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (transformer_decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-5): 6 x 

In [32]:
import torch
from sentencepiece import SentencePieceProcessor
from torch import nn

from model import TransformerModule
from typing import Optional, List

class BeamSearch:

    def __init__(self, model: TransformerModule, sp: SentencePieceProcessor, device: Optional[str] = None,
                 max_seq_len: int = 128):
        self.model = model
        self.sp = sp
        self.max_seq_len = max_seq_len
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def greedy_search_from_text(self, text):
        self.model.eval()
        self.model.to(self.device)

        src_tensor = self.create_source_tensor(text)
        src_padding_mask = self.get_padding_mask(src_tensor, pad_idx=self.sp.pad_id())
        return self.greedy_search(src_tensor, src_padding_mask)
        
    def greedy_search(self, src_tensor: torch.Tensor, src_padding_mask: torch.Tensor):
        batch_size = src_tensor.shape[0]
        # Get initial encoder output
        # if we wrap it with `torch.no_grad()` it doesn't work for some reason.
        with torch.enable_grad():
            memory = self.model.encode(src_tensor, src_padding_mask=src_padding_mask)
            memory = memory.to(self.device)
            mask = torch.zeros(batch_size).type(torch.bool).to(self.device)
            
        with torch.no_grad():
            # Create decoder input.
            # it starts with <bos> token.
            y_pred = (
                torch.ones(batch_size, 1)
                .fill_(self.sp.bos_id())
                .type(torch.long)
                .to(self.device)
            )
            
            for i in range(self.max_seq_len - 1):
                tgt_mask = (nn.Transformer.generate_square_subsequent_mask(y_pred.size(1))
                            .type(torch.bool).to(self.device))
                out = self.model.decode(y_pred, memory, tgt_mask)
                prob = self.model.out(out[:, -1])
                _, next_words = torch.max(prob, dim=1)
                
                y_pred = torch.cat(
                    [y_pred, 
                     next_words.masked_fill(mask, self.sp.pad_id()).type_as(src_tensor.data).unsqueeze(1)], dim=1).to(self.device)
                
                mask |= next_words == self.sp.eos_id()
                if mask.all().item():
                    break
                    
        return y_pred, prob

    def convert_output_to_text(self, y_pred: torch.Tensor):
        batch_size = y_pred.shape[0]
        output = [None] * batch_size
        for i in range(batch_size):
            output[i] = self.sp.Decode(y_pred[i].tolist())
        return output

    def create_source_tensor(self, texts: List[str]) -> torch.Tensor:
        # Create src input
        batch_size = len(texts)
        src_tensor = torch.zeros(batch_size, self.max_seq_len, dtype=torch.int32).to(self.model.device)
        
        for i, text in enumerate(texts):
            src_tokenized = self.sp.Encode(text, add_bos=True, add_eos=True)
            src_tokenized = src_tokenized[:self.max_seq_len]
            if src_tokenized[-1] != self.sp.eos_id():
                src_tokenized[-1] = self.sp.eos_id()
            
            src_tensor[i, :len(src_tokenized)] = torch.Tensor(src_tokenized)
        src_tensor = src_tensor.to(self.device)
        return src_tensor

    @staticmethod
    def get_padding_mask(seq, pad_idx: int):
        return torch.tensor(seq == pad_idx).to(seq.device)


search = BeamSearch(model, sp)
y_pred, y_prob = search.greedy_search_from_text(["제가 이번 여름 휴가 보낸 이야기를 할게요.", "도와드릴까요??", "돈 줘", "해리포터 읽고 있어요"])
search.convert_output_to_text(y_pred)

  return torch.tensor(seq == pad_idx).to(seq.device)


['I will tell you the story of my vacation this summer vacation.',
 'Would you like me to help you?',
 'Please give me money.',
 "I'm reading Harry Potter."]

In [ ]:
from beam_search import BeamSearch


search = BeamSearch(model, sp)
search.greedy_search("제가 이번 여름 휴가 보낸 이야기를 할게요.")

In [ ]:
text = "제가 이번 여름 휴가 보낸 이야기를 할게요."
src_tokenized = sp.encode(text, add_bos=True, add_eos=True)

src_tokenized = src_tokenized[:max_seq_len]
if src_tokenized[-1] != sp.eos_id():
    src_tokenized[-1] = sp.eos_id()

src_tensor = torch.zeros(max_seq_len, dtype=torch.int32).to(model.device)
src_tensor[: len(src_tokenized)] = torch.tensor(src_tokenized)
src_tensor = src_tensor.unsqueeze(0)

src_padding_mask = get_padding_mask(src_tensor, pad_idx=sp.pad_id())

print(f"src_tensor:", src_tensor.shape, src_tensor)
print(f"tgt_input :", tgt_input.shape, tgt_input)

In [ ]:
with torch.no_grad():
    # tgt_idx = 1
    # tgt_padding_mask = get_padding_mask(tgt_input[:, :tgt_idx], pad_idx=sp.pad_id())
    # tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_idx, device='cuda')

    memory = model.encode(src=src_tensor, src_padding_mask=src_padding_mask)
    
    # output = model(
    #     src_tensor,
    #     tgt_input[:, :tgt_idx],
    #     src_padding_mask=src_padding_mask,
    #     tgt_padding_mask=tgt_padding_mask,
    #     tgt_mask=tgt_mask
    # )

    print(memory)
    

In [ ]:
torch.max(output[0], dim=1)

In [ ]:
tgt_input

In [ ]:
tgt_input = torch.zeros(max_seq_len, dtype=torch.int32).to(model.device)
tgt_input[0] = sp.bos_id()
tgt_input = tgt_input.unsqueeze(0)

with torch.no_grad():
    tgt_idx = 1
    while True:
        print(tgt_input[:, :tgt_idx][0].tolist())
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_idx, device='cuda').type(torch.bool)
    
        output = model.decode(tgt_input[:, :tgt_idx], memory, tgt_mask=tgt_mask)
        output = model.out(output)
        token_idx = torch.argmax(output[0][-1]).item()
        tgt_input[:, tgt_idx] = token_idx
        print(sp.decode(tgt_input[0].tolist()))
        tgt_idx += 1

        if token_idx == 3:
            break
        
        
    # print(torch.argmax(output[0][-1]))

In [ ]:

def greedy_decode(model, src, src_padding_mask, max_len, start_symbol, device='cuda'):
    """
    src: 인코더의 입력 문장, [batch_size, src_len]
    src_mask: 소스 문장의 마스크, [batch_size, 1, src_len]
    max_len: 생성할 최대 문장 길이
    start_symbol: 문장 시작을 나타내는 심볼 (BOS 토큰)
    """
    src = src.to(device)
    src_padding_mask = src_padding_mask.to(device)

    memory = model.encode(src, src_padding_mask=src_padding_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)

    for i in range(max_len-1):
        memory = memory.to(device)
        
        tgt_mask = (nn.Transformer.generate_square_subsequent_mask(ys.size(1))
                    .type(torch.bool)).to(device)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.out(out[:, -1])
        _, next_word = torch.max(prob, dim = 1)
        
        next_word = next_word[-1].item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)

        # print(sp.decode(ys[0].tolist()))
        if next_word == 3:
            break

    return ys

text = "오랫동안 일부 사람들은 생명이 없는 물질로부터 생명이 발생할 수 있다고 믿었습니다."
src_tokenized = sp.encode(text, add_bos=True, add_eos=True)

src_tokenized = src_tokenized[:max_seq_len]
if src_tokenized[-1] != sp.eos_id():
    src_tokenized[-1] = sp.eos_id()

src_tensor = torch.zeros(max_seq_len, dtype=torch.int32).to(model.device)
src_tensor[: len(src_tokenized)] = torch.tensor(src_tokenized)
src_tensor = src_tensor.unsqueeze(0)
src_padding_mask = get_padding_mask(src_tensor, pad_idx=sp.pad_id())

output = greedy_decode(model, src_tensor, src_padding_mask, max_len=128, start_symbol=sp.bos_id())

print(sp.decode(output[0].tolist()))

In [178]:
model

TransformerModule(
  (src_embedding): Embedding(8000, 256)
  (tgt_embedding): Embedding(8000, 256)
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (transformer_decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-5): 6 x 

In [71]:
print(output.shape)
tgt_input[:, 1] = torch.argmax(output[-1], dim=1)

torch.Size([1, 1, 8000])


In [72]:
tgt_input

tensor([[ 2, 17,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0]], device='cuda:0', dtype=torch.int32)

In [56]:
def greedy_decode(model, src, max_len, start_symbol, device='cuda'):
    src = src.to(device)

    memory = model(src
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys


def translate(model, text, sp):
    model.eval()
    pad_idx = sp.pad_idx()
    
    src_tokenized = sp.encode(text, add_bos=True, add_eos=True)
    if src_tokenized[-1] != sp.eos_id():
        src_tokenized[-1] = sp.eos_id()
        
    src_tensor = torch.zeros(max_seq_len, dtype=torch.int32).to(model.device)
    src_tensor[:len(src_tokenized)] = torch.tensor(src_tokenized)
    src_tensor = src_tensor.unsqueeze(0)

    src_padding_mask = get_padding_mask(src_tensor, pad_idx)

    madel(src_tensor, src_padding_mask=src_padding_mask)

    
    
    greedy_decode(model, src_tensor, max_len=128, start_symbol=sp.bos_idx())

def get_padding_mask(seq, pad_idx:int):
    torch.tensor(seq == pad_idx).to(seq.device)
                  
translate(model, '안녕! 내일 학교에서 보자! 내일은 돈 꼭 갚아!', sp=sp)

SyntaxError: invalid syntax (3820642031.py, line 6)

In [49]:
model._make_padding_mask(4, 0)

AttributeError: 'int' object has no attribute 'device'