In [None]:
import json
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
class Tokenizer:
    """Space-separated word tokenizer that maps words to integer ids.

    Ids 0-3 are reserved for special tokens (padding, unknown, start-of-sequence,
    end-of-sequence); vocabulary words are numbered from 4 upwards.
    ``max_len`` is the longest fitted sentence plus the SOS/EOS tokens, and
    ``encode`` pads every sequence up to that length.
    """

    def __init__(self):
        self.dicionario = {}          # word -> id
        self.dicionario_inverso = {}  # id -> word
        self.token_desconhecido = 1   # unknown word
        self.token_padding = 0
        self.token_sos = 2
        self.token_eos = 3
        self.max_len = 0              # longest fitted sentence + SOS + EOS

    def fit(self, dataset):
        """Add every space-separated word of every text in ``dataset`` to the vocabulary.

        Safe to call more than once: new words get ids after the largest id
        already assigned, so previously assigned ids never change.
        """
        # Ids 0-3 are reserved; continue after the largest id in use so a
        # second fit() never hands out an id that is already taken.
        indice = max(self.dicionario.values(), default=3) + 1
        for texto in dataset:
            palavras = texto.split(' ')
            # +2 accounts for the SOS and EOS tokens that encode() adds.
            self.max_len = max(self.max_len, len(palavras) + 2)
            for palavra in palavras:
                if palavra not in self.dicionario:
                    self.dicionario[palavra] = indice
                    indice += 1

        self.dicionario_inverso = {
            valor: chave for chave, valor in self.dicionario.items()
        }

    def encode(self, texto):
        """Convert ``texto`` into ids: SOS, one id per word, EOS, then padding.

        Words missing from the vocabulary become ``token_desconhecido``.
        The result is padded with ``token_padding`` up to ``max_len``. Longer
        sequences are returned as-is (not truncated).
        """
        texto_tokenizado = [self.token_sos]
        for token in texto.split(' '):
            texto_tokenizado.append(self.dicionario.get(token, self.token_desconhecido))
        texto_tokenizado.append(self.token_eos)

        # A negative count yields an empty list, i.e. no padding.
        texto_tokenizado.extend([self.token_padding] * (self.max_len - len(texto_tokenizado)))
        return texto_tokenizado

    def decode(self, tokens):
        """Convert a sequence of ids back into a space-joined string.

        Accepts Python ints, numpy integers or 0-d tensors. Ids that are not
        vocabulary words are rendered as ``<pad>``, ``<unk>``, ``<sos>`` or
        ``<eos>`` (anything unrecognised becomes ``<unk>``) instead of
        breaking the join.
        """
        especiais = {
            self.token_padding: '<pad>',
            self.token_desconhecido: '<unk>',
            self.token_sos: '<sos>',
            self.token_eos: '<eos>',
        }
        texto = []
        for token in tokens:
            # Normalise to a Python int so numpy/torch scalars hit the dict keys.
            token = int(token)
            if token in self.dicionario_inverso:
                texto.append(self.dicionario_inverso[token])
            else:
                texto.append(especiais.get(token, '<unk>'))
        return ' '.join(texto)

    def save(self):
        """Write the vocabulary and ``max_len`` as JSON files in the working directory."""
        with open('dicionario.json', 'w') as f:
            json.dump(self.dicionario, f)
        with open('dicionario_inverso.json', 'w') as f:
            json.dump(self.dicionario_inverso, f)
        # max_len is needed so encode() pads to the same length after load().
        with open('tokenizer_config.json', 'w') as f:
            json.dump({'max_len': self.max_len}, f)

    def load(self):
        """Restore the state written by ``save`` from the working directory."""
        with open('dicionario.json', 'r') as f:
            self.dicionario = json.load(f)
        with open('dicionario_inverso.json', 'r') as f:
            # JSON object keys are always strings; convert back to int ids.
            self.dicionario_inverso = {int(chave): valor for chave, valor in json.load(f).items()}
        try:
            with open('tokenizer_config.json', 'r') as f:
                self.max_len = json.load(f)['max_len']
        except FileNotFoundError:
            # Files saved by the older save() have no config; keep the current max_len.
            pass

In [None]:
dataset = [
    "bom celular",
    "sol, choveu durante o dia",
    "bom dia, não gostei disso",
]

tokenizer = Tokenizer()
tokenizer.fit(dataset)

# One id list per sentence: SOS, word ids, EOS, padded to tokenizer.max_len.
tokenized_dataset = list(map(tokenizer.encode, dataset))
tokenized_dataset[1]

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017) followed by dropout.

    The precomputed table has shape (max_len, 1, dim_model) and is sliced along
    the first axis of the input. The input must therefore be sequence-first:
    (seq_len, batch, dim_model).
    """

    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()

        self.dropout = nn.Dropout(dropout_p)

        pos_encoding = torch.zeros(max_len, dim_model)
        # Column vector of positions 0, 1, ..., max_len - 1.
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1)
        # 1 / 10000^(2i / dim_model) for each even index 2i, computed in log space.
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)

        # Even columns get sin, odd columns get cos. For odd dim_model there is
        # one fewer odd column than frequencies, so the cos half uses only the
        # first dim_model // 2 of them (otherwise the assignment shape mismatches).
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term[: dim_model // 2])

        # (max_len, dim_model) -> (max_len, 1, dim_model) so it broadcasts over the batch axis.
        pos_encoding = pos_encoding.unsqueeze(1)
        self.register_buffer("pos_encoding", pos_encoding)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        """Add the positional encoding to ``token_embedding`` and apply dropout.

        ``token_embedding`` is (seq_len, batch, dim_model) with seq_len <= max_len.
        """
        return self.dropout(token_embedding + self.pos_encoding[: token_embedding.size(0)])

In [None]:
class Transformer(nn.Module):
    """Token-level encoder-decoder built on ``nn.Transformer``.

    ``forward`` takes batch-first token ids (batch, seq). It returns
    sequence-first logits (tgt_seq, batch, num_tokens), because
    ``nn.Transformer`` is used with its default ``batch_first=False``.
    """

    def __init__(
        self,
        num_tokens,
        dim_model,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        dropout_p,
        max_len=5000,
    ):
        """
        Args:
            num_tokens: vocabulary size (embedding rows and output logits).
            dim_model: embedding / model width.
            num_heads: attention heads per layer.
            num_encoder_layers: number of encoder layers.
            num_decoder_layers: number of decoder layers.
            dropout_p: dropout probability used throughout.
            max_len: longest sequence the positional-encoding table supports.
        """
        super().__init__()
        self.dim_model = dim_model

        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model,
            dropout_p=dropout_p,
            max_len=max_len,
        )

        self.embedding = nn.Embedding(num_tokens, dim_model)

        self.transformer = nn.Transformer(
            d_model=dim_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout_p,
        )

        self.output = nn.Linear(dim_model, num_tokens)

    def forward(self, src, tgt, tgt_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        """Run the encoder-decoder and return next-token logits.

        Args:
            src: (batch, src_seq) source token ids.
            tgt: (batch, tgt_seq) decoder input token ids.
            tgt_mask: optional (tgt_seq, tgt_seq) additive causal mask, e.g.
                from ``get_tgt_mask``.
            src_pad_mask: optional (batch, src_seq) key-padding mask (True =
                padding). Applied to encoder self-attention and to the
                decoder's cross-attention over the encoder output.
            tgt_pad_mask: optional (batch, tgt_seq) key-padding mask (True = padding).

        Returns:
            (tgt_seq, batch, num_tokens) logits.
        """
        scale = math.sqrt(self.dim_model)
        # Go sequence-first *before* adding positions. PositionalEncoding
        # indexes its table with axis 0; on a batch-first tensor every token of
        # sample b would receive the encoding of position b.
        src = self.embedding(src).permute(1, 0, 2) * scale
        tgt = self.embedding(tgt).permute(1, 0, 2) * scale
        src = self.positional_encoder(src)
        tgt = self.positional_encoder(tgt)

        output_transformer = self.transformer(
            src,
            tgt,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            # Keep the decoder from attending to padded encoder positions as well.
            memory_key_padding_mask=src_pad_mask,
        )
        return self.output(output_transformer)

    def get_tgt_mask(self, size) -> torch.Tensor:
        """Return a (size, size) float causal mask: 0 on/below the diagonal, -inf above.

        Position i may therefore attend only to positions <= i.
        """
        mask = torch.tril(torch.ones(size, size))
        mask = mask.masked_fill(mask == 0, float('-inf'))
        mask = mask.masked_fill(mask == 1, float(0.0))
        return mask

In [None]:
# Seed every RNG the notebook touches so weight init, dropout and training are reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(
    # Vocabulary words plus the 4 reserved ids (padding, unknown, SOS, EOS).
    num_tokens=len(tokenizer.dicionario) + 4,
    dim_model=32,
    num_heads=2,
    num_encoder_layers=3,
    num_decoder_layers=3,
    dropout_p=0.1,
).to(device)
opt = optim.SGD(model.parameters(), lr=0.01)
# Padding positions after EOS are not real targets; exclude them from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.token_padding)

In [None]:
NUM_EPOCHS = 500

model.train()
mean_loss = float("nan")
for _ in range(NUM_EPOCHS):
    epoch_loss = 0.0
    for sentence in tokenized_dataset:
        # Copy task: the model learns to reproduce its own input.
        tokens = torch.tensor(sentence, device=device).unsqueeze(0)  # (1, seq)
        X, y = tokens, tokens

        # Teacher forcing: the decoder sees tokens[:-1] and predicts tokens[1:].
        y_input = y[:, :-1]
        y_expected = y[:, 1:]

        tgt_mask = model.get_tgt_mask(y_input.size(1)).to(device)
        # True at padding positions so attention ignores them.
        src_pad_mask = X == tokenizer.token_padding
        tgt_pad_mask = y_input == tokenizer.token_padding

        pred = model(X, y_input, tgt_mask, src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask)

        # (seq, batch, vocab) -> (batch, vocab, seq), the layout CrossEntropyLoss expects.
        pred = pred.permute(1, 2, 0)
        loss = loss_fn(pred, y_expected)

        opt.zero_grad()
        loss.backward()
        opt.step()

        epoch_loss += loss.item()
    mean_loss = epoch_loss / len(tokenized_dataset)

# Mean training loss of the final epoch.
mean_loss

In [None]:
model.eval()

# Encode the source exactly like the training data (SOS + word ids + EOS +
# padding) rather than hand-writing token ids.
source_text = "sol,"
x = torch.tensor([tokenizer.encode(source_text)], device=device)
src_pad_mask = x == tokenizer.token_padding
y_input = torch.tensor([[tokenizer.token_sos]], dtype=torch.long, device=device)

# Greedy decoding: append the most likely next token until EOS or max_len is reached.
with torch.no_grad():
    for _ in range(tokenizer.max_len - 1):
        tgt_mask = model.get_tgt_mask(y_input.size(1)).to(device)
        pred = model(x, y_input, tgt_mask, src_pad_mask=src_pad_mask)  # (tgt_seq, 1, vocab)

        # Logits of the last position -> most likely next token, shape (1, 1).
        next_item = pred[-1].argmax(dim=-1, keepdim=True)
        y_input = torch.cat((y_input, next_item), dim=1)

        if next_item.item() == tokenizer.token_eos:
            break

# Drop SOS, and drop the final token only if it really is EOS (decoding may
# stop at the length limit without producing one).
generated = y_input[0, 1:].tolist()
if generated and generated[-1] == tokenizer.token_eos:
    generated = generated[:-1]
tokenizer.decode(generated)