In [110]:
from transformers import T5Tokenizer
from pprint import pprint

input = ["Translate the following English text into French: Hello, how are you? target:",
         "Translate the following English text into French: Good morning, everyone. target",
          "Translate the following English text into French: Can you help me with this? target" ]

target = ["Bonjour, comment ça va ?",
          "Bonjour à tous.",
          "Pouvez-vous m'aider avec ceci ?"]
tokenizer = T5Tokenizer.from_pretrained("t5-small")


encoder_tokenize = tokenizer(input, return_tensors="pt", padding=True)
decoder_tokenize = tokenizer(target, return_tensors="pt", padding=True)

encoder_input_ids = encoder_tokenize.input_ids
encoder_attention_mask = encoder_tokenize.attention_mask
decoder_input_ids = decoder_tokenize.input_ids
decoder_attention_mask = decoder_tokenize.attention_mask


loading file spiece.model from cache at C:\Users\yuhei/.cache\huggingface\hub\models--t5-small\snapshots\df1b051c49625cf57a3d0d8d3863ed4d13564fe4\spiece.model
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at C:\Users\yuhei/.cache\huggingface\hub\models--t5-small\snapshots\df1b051c49625cf57a3d0d8d3863ed4d13564fe4\tokenizer_config.json


tensor([[30355,    15,     8,   826,  1566,  1499,   139,  2379,    10,  8774,
             6,   149,    33,    25,    58,  2387,    10,     1],
        [30355,    15,     8,   826,  1566,  1499,   139,  2379,    10,  1804,
          1379,     6,   921,     5,  2387,     1,     0,     0],
        [30355,    15,     8,   826,  1566,  1499,   139,  2379,    10,  1072,
            25,   199,   140,    28,    48,    58,  2387,     1]])
['Bonjour, comment ça va ?',
 'Bonjour à tous.',
 "Pouvez-vous m'aider avec ceci ?"]


In [111]:
import numpy as np
import torch
import torch.nn as nn
from pprint import pprint

batch_size = encoder_input_ids.size(0)
encoder_seq_length = encoder_input_ids.size(1)
decoder_seq_length = decoder_input_ids.size(1)

kwargs = {
    "batch_size": batch_size,
    "num_embedding": max_ids + 1,
    "seq_length": (encoder_seq_length, decoder_seq_length),
    "hidden_size": 768,
    "num_layer": 12,
    "num_heads": 12,
    "ffn_hidden_size": 3072,

}

class T5(nn.Module):
    def __init__(self, batch_size, num_embedding, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
        super().__init__()
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.hidden_size = hidden_size
        self.num_layer = num_layer
        self.num_heads = num_heads
        self.ffn_hidden_size = ffn_hidden_size

        self.encoder_embedding = nn.Embedding(num_embedding, hidden_size, padding_idx=0)
        self.decoder_embedding = nn.Embedding(num_embedding, hidden_size, padding_idx=0)
        self.encoder = Encoder(batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)
        self.decoder = Decoder(batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)


    def forward(self, encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask):
        encoder_embedding = self.encoder_embedding(encoder_input_ids)
        output_encoder = self.encoder(encoder_embedding, encoder_attention_mask)
        decoder_embedding = self.decoder_embedding(decoder_input_ids)
        decoder_output = self.decoder(decoder_embedding, output_encoder, encoder_attention_mask, decoder_attention_mask)

        return decoder_output

class Encoder(nn.Module):
    def __init__(self, batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
        super().__init__()
        self._setupEncoderLayer(batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)

    def _setupEncoderLayer(self, batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
        encoder_layer_list = []
        for _ in range(num_layer):
            encoder_layer = EncoderLayer(batch_size, num_heads, seq_length, hidden_size, ffn_hidden_size)
            encoder_layer_list.append(encoder_layer)
        self.encoder_module = nn.ModuleList(encoder_layer_list)

    def forward(self, encoder_embedding, encoder_attention_mask):
        tokens = encoder_embedding
        for encoder_layer in self.encoder_module:
            tokens = encoder_layer(tokens, encoder_attention_mask)
        output_encoder = tokens

        return output_encoder
    
class EncoderLayer(nn.Module):
    def __init__(self, batch_size, num_heads, seq_length, hidden_size, ffn_hidden_size):
        super().__init__()
        self.multi_head_attention = MultiHeadAttention(batch_size, num_heads, seq_length, hidden_size, check_positional_embedding=True, check_mask=False)
        self.add_norm1 = AddNorm(batch_size, seq_length, hidden_size, check_encoder=True)
        self.feed_forward = FeedForward(hidden_size, ffn_hidden_size)
        self.add_norm2 = AddNorm(batch_size, seq_length, hidden_size, check_encoder=True)

    def forward(self, tokens, encoder_attention_mask):
        skip1 = tokens
        multi_head_attention = self.multi_head_attention(tokens, tokens, tokens, encoder_attention_mask, encoder_attention_mask)
        add_norm1 = self.add_norm1(multi_head_attention, skip1)
        skip2 = add_norm1
        feed_forward = self.feed_forward(add_norm1)
        add_norm2 = self.add_norm2(feed_forward, skip2)
        tokens = add_norm2

        return tokens
    
    
class Decoder(nn.Module):
    def __init__(self, batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
        super().__init__()
        self._setupDecoderLayer(batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)

    def _setupDecoderLayer(self, batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
        decoder_list = []
        for _ in range(num_layer):
            decoder_layer = DecoderLayer(batch_size, num_heads, seq_length, hidden_size, ffn_hidden_size)
            decoder_list.append(decoder_layer)
        self.decoder_module = nn.ModuleList(decoder_list)

    def forward(self, decoder_embedding, output_encoder, encoder_attention_mask, decoder_attention_mask):
        tokens = decoder_embedding
        for decoder_layer in self.decoder_module:
            tokens = decoder_layer(tokens, output_encoder, encoder_attention_mask, decoder_attention_mask)
        output_decoder = tokens

        return output_decoder
    
class DecoderLayer(nn.Module):
    def __init__(self, batch_size, num_heads, seq_length, hidden_size, ffn_hidden_size):
        super().__init__()
        self.masked_multi_head_attention =  MultiHeadAttention(batch_size, num_heads, seq_length, hidden_size, check_positional_embedding=True, check_mask=True)
        self.add_norm1 = AddNorm(batch_size, seq_length, hidden_size, check_encoder=False)
        self.cross_multi_head_attention = MultiHeadAttention(batch_size, num_heads, seq_length, hidden_size, check_positional_embedding=False, check_mask=False)
        self.add_norm2 = AddNorm(batch_size, seq_length, hidden_size, check_encoder=False)
        self.feed_forward = FeedForward(hidden_size, ffn_hidden_size)
        self.add_norm3 = AddNorm(batch_size, seq_length, hidden_size, check_encoder=False)

    def forward(self, tokens, output_encoder, encoder_attention_mask, decoder_attention_mask):
        skip1 = tokens
        masked_multi_head_attention = self.masked_multi_head_attention(tokens, tokens, tokens, decoder_attention_mask, decoder_attention_mask)
        add_norm1 = self.add_norm1(masked_multi_head_attention, skip1)
        skip2 = add_norm1
        cross_multi_head_attention = self.cross_multi_head_attention(tokens, output_encoder, output_encoder, decoder_attention_mask, encoder_attention_mask)
        add_norm2 = self.add_norm2(cross_multi_head_attention, skip2)
        skip3 = add_norm2
        feed_forward = self.feed_forward(tokens)
        add_norm3 = self.add_norm3(feed_forward, skip3)
        tokens = add_norm3

        return tokens
    
class MultiHeadAttention(nn.Module):
    def __init__(self, batch_size, num_heads, seq_length, hidden_size, check_positional_embedding, check_mask):
        super().__init__()
        self._setupHeadQKV(num_heads, hidden_size)
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.seq_length = seq_length
        self.hidden_size = hidden_size
        self.check_positional_embedding = check_positional_embedding
        self.check_mask = check_mask
        self.softmax = nn.Softmax(dim=-1)

    def _setupHeadQKV(self, num_heads, hidden_size):
        query_module = []
        key_module = []
        value_module = []
        head_hidden_size = int(hidden_size / num_heads)

        for _ in range(num_heads):
            query_module.append(nn.Linear(hidden_size, head_hidden_size))
            key_module.append(nn.Linear(hidden_size, head_hidden_size))
            value_module.append(nn.Linear(hidden_size, head_hidden_size))

        self.query_module = nn.ModuleList(query_module)
        self.key_module = nn.ModuleList(key_module)
        self.value_module = nn.ModuleList(value_module)

    def _outputRelativePositionalEmbeddingScalar(self, query, batch_size, seq_length, hidden_size, num_heads):
        if self.check_mask: seq_length = self.seq_length[1]
        else: seq_length = self.seq_length[0]

        embed_Module = []
        head_hidden_size = int(hidden_size / num_heads)
        position_ids = torch.tensor(list(range(seq_length)), dtype=torch.long).reshape(1, seq_length).expand(batch_size, seq_length)
        for id in range(num_heads): embed_Module.append(nn.Embedding(seq_length + 1, head_hidden_size))
        self.embed_module = nn.ModuleList(embed_Module)
        for id in range(num_heads):
            head_query = self.query_module[id](query)
            tmp_relative_position_embedding_scalar = (head_query@(self.embed_module[id](position_ids).transpose(1, 2))).reshape(1, batch_size, seq_length, seq_length)
            if id == 0: relative_position_embedding_scalar = tmp_relative_position_embedding_scalar
            else: relative_position_embedding_scalar = torch.concat([relative_position_embedding_scalar, tmp_relative_position_embedding_scalar], dim=0)

        return relative_position_embedding_scalar
    
    def _outputAttention(self, query, key, value, batch_size, seq_length, hidden_size, num_heads, check_positional_embedding, check_mask, encoder_attention_mask, decoder_attention_mask):

        if check_positional_embedding:
            if check_mask:
                seq_length1 = seq_length2 = self.seq_length[1]
            else:
                seq_length1 = seq_length2 = self.seq_length[0]

        else:
            seq_length1 = self.seq_length[1]
            seq_length2 = self.seq_length[0]

        head_hidden_size = int(hidden_size / num_heads)
        encoder_attention_mask = encoder_attention_mask.reshape(batch_size, -1, 1).expand(batch_size, seq_length1, seq_length2)
        decoder_attention_mask = decoder_attention_mask.reshape(batch_size, 1, -1).expand(batch_size, seq_length1, seq_length2)

        padding_map = encoder_attention_mask * decoder_attention_mask
        mask_map = torch.tensor(np.tril(np.ones((seq_length1, seq_length2))), dtype=torch.long)

        if check_positional_embedding:relative_position_embedding_scalar = self._outputRelativePositionalEmbeddingScalar(query, batch_size, seq_length, hidden_size, num_heads)
        else: relative_position_embedding_scalar = torch.zeros_like(padding_map, dtype=torch.float).expand(num_heads, batch_size, seq_length1, seq_length2)

        for id in range(num_heads):
            head_query = self.query_module[id](query)
            head_key = self.key_module[id](key)
            head_value = self.value_module[id](value)

            if check_mask: tmp_head_attention = self.softmax(padding_map * (mask_map * (head_query@head_key.transpose(1, 2)) / (head_hidden_size) + relative_position_embedding_scalar[id]))@head_value
            else:tmp_head_attention = self.softmax(padding_map * ((head_query@head_key.transpose(1, 2)) / (head_hidden_size) + relative_position_embedding_scalar[id]))@head_value
            if id == 0: head_attention = tmp_head_attention
            else: head_attention = torch.concat([head_attention, tmp_head_attention], dim=-1)
        output_attention = head_attention

        return output_attention

    def forward(self, query, key, value, encoder_attention_mask, decoder_attention_mask):
        output_attention = self._outputAttention(query, key, value, self.batch_size, self.seq_length, self.hidden_size, self.num_heads,
                                                  self.check_positional_embedding, self.check_mask, encoder_attention_mask, decoder_attention_mask)

        return output_attention
    
class AddNorm(nn.Module):
    def __init__(self, batch_size, seq_length, hidden_size, check_encoder):
        super().__init__()
        self._setupAddNormModule(batch_size, seq_length, hidden_size, check_encoder)

    def _setupAddNormModule(self, batch_size, seq_length, hidden_size, check_encoder):
        if check_encoder: seq_length = seq_length[0]
        else: seq_length = seq_length[1]
        self.layer_norm = nn.LayerNorm((batch_size, seq_length, hidden_size))

    def forward(self, tokens, skipped_tokens):
        tokens += skipped_tokens
        tokens = self.layer_norm(tokens)

        return tokens
    
class FeedForward(nn.Module):
    def __init__(self, hidden_size, ffn_hidden_size):
        super().__init__()
        self._setupFeedForwardModule(hidden_size, ffn_hidden_size)

    def _setupFeedForwardModule(self, hidden_size, ffn_hidden_size):
        dense1 = nn.Linear(hidden_size, ffn_hidden_size)
        relu1 = nn.ReLU()
        dense2 = nn.Linear(ffn_hidden_size, hidden_size)
        relu2 = nn.ReLU()
        self.feed_foward_module = nn.ModuleList([dense1, relu1, dense2, relu2])

    def forward(self, tokens):
        for module in self.feed_foward_module:
            tokens = module(tokens)
        return tokens
    
model = T5(**kwargs)
outputs = model(encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask)
print(model)


T5(
  (encoder_embedding): Embedding(30356, 768, padding_idx=0)
  (decoder_embedding): Embedding(30356, 768, padding_idx=0)
  (encoder): Encoder(
    (encoder_module): ModuleList(
      (0-11): 12 x EncoderLayer(
        (multi_head_attention): MultiHeadAttention(
          (query_module): ModuleList(
            (0-11): 12 x Linear(in_features=768, out_features=64, bias=True)
          )
          (key_module): ModuleList(
            (0-11): 12 x Linear(in_features=768, out_features=64, bias=True)
          )
          (value_module): ModuleList(
            (0-11): 12 x Linear(in_features=768, out_features=64, bias=True)
          )
          (softmax): Softmax(dim=-1)
          (embed_module): ModuleList(
            (0-11): 12 x Embedding(19, 64)
          )
        )
        (add_norm1): AddNorm(
          (layer_norm): LayerNorm((3, 18, 768), eps=1e-05, elementwise_affine=True)
        )
        (feed_forward): FeedForward(
          (feed_foward_module): ModuleList(
           