In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken

# Transformer components

In [76]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are passed separately so the same module serves
    encoder self-attention, decoder masked self-attention, and encoder-decoder
    (cross) attention, where the queries come from a different sequence than
    the keys/values.

    Args:
        embed_dim: Size of input/output vectors; must be divisible by ``n_heads``.
        n_heads: Number of heads; each head attends over ``embed_dim // n_heads`` features.
    """

    def __init__(self, embed_dim: int, n_heads: int):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads

        assert (
            self.head_dim * n_heads == embed_dim
        ), "embed_dim needs to be divisible by n_heads"

        self.queries = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.keys = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.values = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.fc_out = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        pre_queries: torch.Tensor,
        pre_keys: torch.Tensor,
        pre_values: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Attend from every query position to the key/value positions.

        Args:
            pre_queries: Vectors to project into queries, ``[N, queries_seq_len, embed_dim]``.
            pre_keys: Vectors to project into keys, ``[N, keys_seq_len, embed_dim]``.
            pre_values: Vectors to project into values, ``[N, keys_seq_len, embed_dim]``
                (keys and values must have the same length).
            mask: ``None`` for no masking, or a tensor broadcastable to
                ``[N, n_heads, queries_seq_len, keys_seq_len]``; positions where it
                equals 0/False are excluded from attention.

        Returns:
            One attention vector per query, ``[N, queries_seq_len, embed_dim]``.
        """
        N = pre_queries.shape[0]
        queries_seq_len = pre_queries.shape[1]
        keys_seq_len = pre_keys.shape[1]
        values_seq_len = pre_values.shape[1]

        # Project, then split the feature dimension into heads: [N, seq_len, n_heads, head_dim].
        queries = self.queries(pre_queries).reshape(N, queries_seq_len, self.n_heads, self.head_dim)
        keys = self.keys(pre_keys).reshape(N, keys_seq_len, self.n_heads, self.head_dim)
        values = self.values(pre_values).reshape(N, values_seq_len, self.n_heads, self.head_dim)

        # Scaled scores for every (query, key) pair per head: [N, n_heads, queries_seq_len, keys_seq_len].
        attention_grid = torch.einsum("nqhd, nkhd -> nhqk", queries, keys) / self.head_dim ** (1 / 2)

        if mask is not None:
            # Use the most negative finite value of the score dtype. A hard-coded
            # -1e20 cannot be represented in float16, so masked_fill raises under
            # mixed precision. exp() of this value still underflows to 0 in the softmax.
            attention_grid = attention_grid.masked_fill(mask == 0, torch.finfo(attention_grid.dtype).min)

        # Normalize over the key axis: each query's weights over the keys sum to 1.
        attention_weights = torch.softmax(attention_grid, dim=3)

        # Weighted sum of values per head: [N, queries_seq_len, n_heads, head_dim].
        multi_head = torch.einsum("nhql, nlhd -> nqhd", attention_weights, values)
        # Concatenate the heads back into embed_dim and apply the output projection.
        multi_head_attention = multi_head.reshape(N, queries_seq_len, self.embed_dim)
        return self.fc_out(multi_head_attention)


class TransformerBlock(nn.Module):
    """One post-norm transformer layer.

    Multi-head attention, then a position-wise feed-forward network (GELU).
    Each sub-layer is wrapped in ``LayerNorm(input + Dropout(sublayer(input)))``.

    Args:
        embed_dim: Size of the token vectors.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to each sub-layer's output.
        forward_expansion: Feed-forward hidden size as a multiple of ``embed_dim``.
    """

    def __init__(
        self, embed_dim: int, n_heads: int, dropout: float, forward_expansion: int
    ):
        super().__init__()
        self.attention = SelfAttention(embed_dim, n_heads)
        # Each residual sub-layer gets its own LayerNorm. Reusing one instance for
        # both would tie the learned scale/shift of the attention and feed-forward paths.
        self.norm = nn.LayerNorm(embed_dim)  # after the attention sub-layer
        self.norm_ff = nn.LayerNorm(embed_dim)  # after the feed-forward sub-layer
        self.linear_1 = nn.Linear(embed_dim, forward_expansion * embed_dim)
        self.linear_2 = nn.Linear(forward_expansion * embed_dim, embed_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(
        self,
        pre_queries: torch.Tensor,
        pre_keys: torch.Tensor,
        pre_values: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention + feed-forward with residual connections.

        Args:
            pre_queries: Query-side vectors ``[N, queries_seq_len, embed_dim]``; also the
                residual input. They are the same tensor as keys/values for self-attention;
                in the decoder's cross-attention the keys/values are the encoder output.
            pre_keys: Key-side vectors ``[N, keys_seq_len, embed_dim]``.
            pre_values: Value-side vectors ``[N, keys_seq_len, embed_dim]``.
            mask: Attention mask forwarded to ``SelfAttention`` (``None`` for none).

        Returns:
            Tensor of shape ``[N, queries_seq_len, embed_dim]``.
        """
        attention = self.attention(pre_queries, pre_keys, pre_values, mask)
        # Dropout goes on the sub-layer output before the residual add, not on the
        # normalized sum. Otherwise it would also randomly zero the residual stream.
        x = self.norm(pre_queries + self.dropout(attention))
        x_fc = self.linear_2(F.gelu(self.linear_1(x)))
        out = self.norm_ff(x + self.dropout(x_fc))
        return out


class Encoder(nn.Module):
    """Transformer encoder.

    Token embeddings plus learned positional embeddings, followed by a stack of
    ``TransformerBlock`` self-attention layers.

    Args:
        src_vocab_size: Number of source token ids (rows of the word embedding).
        embed_dim: Size of the token vectors.
        n_layers: Number of stacked ``TransformerBlock`` layers.
        n_heads: Attention heads per layer.
        device: Device recorded on the module. It is kept for backward compatibility;
            ``forward`` creates its position ids on the input's device.
        forward_expansion: Feed-forward hidden size as a multiple of ``embed_dim``.
        dropout: Dropout probability.
        max_seq_length: Longest sequence the learned positional embedding supports.
    """

    def __init__(
        self,
        src_vocab_size: int,
        embed_dim: int,
        n_layers: int,
        n_heads: int,
        device: torch.device,
        forward_expansion: int,
        dropout: float,
        max_seq_length: int,
    ):
        super().__init__()
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.positional_embedding = nn.Embedding(max_seq_length, embed_dim)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_dim, n_heads, dropout, forward_expansion)
                for _ in range(n_layers)
            ]
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Encode a batch of source token ids.

        Args:
            x: Token ids ``[N, seq_len]``. ``seq_len`` may be any value up to
                ``max_seq_length`` and need not equal it.
            mask: Attention mask passed to every layer (``None`` for no masking).

        Returns:
            Encoded vectors ``[N, seq_len, embed_dim]``.

        Raises:
            IndexError: if ``seq_len`` exceeds ``max_seq_length``.
        """
        N, seq_length = x.shape
        max_len = self.positional_embedding.num_embeddings
        if seq_length > max_len:
            raise IndexError(
                f"sequence length {seq_length} exceeds max_seq_length {max_len}"
            )
        # Build position ids directly on the input's device. self.device goes stale
        # if the module is later moved with .to()/.cuda().
        positions = torch.arange(seq_length, device=x.device).expand(N, -1)
        out = self.dropout(self.word_embedding(x) + self.positional_embedding(positions))
        for layer in self.layers:
            # Self-attention: queries, keys and values are all the current vectors.
            out = layer(out, out, out, mask)
        return out


class DecoderBlock(nn.Module):
    """One transformer decoder layer.

    First, masked self-attention over the target vectors, with a residual and
    LayerNorm. The result then serves as queries for a ``TransformerBlock``
    whose keys/values are the encoder output (cross-attention), followed by
    that block's feed-forward sub-layer.

    Args:
        embed_dim: Size of the token vectors.
        n_heads: Number of attention heads.
        forward_expansion: Feed-forward hidden size as a multiple of ``embed_dim``.
        dropout: Dropout probability applied to sub-layer outputs.
        device: Stored on the module; not used by ``forward``.
    """

    def __init__(
        self,
        embed_dim: int,
        n_heads: int,
        forward_expansion: int,
        dropout: float,
        device: torch.device,
    ):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.attention = SelfAttention(embed_dim, n_heads)
        self.transformer_block = TransformerBlock(embed_dim, n_heads, dropout, forward_expansion)
        self.device = device
        self.dropout = nn.Dropout(p=dropout)

    def forward(
        self,
        x: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        src_mask: torch.Tensor,
        trg_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Apply masked self-attention, then cross-attention + feed-forward.

        Args:
            x: Target-side vectors ``[N, trg_len, embed_dim]``.
            keys: Encoder output used as cross-attention keys ``[N, src_len, embed_dim]``.
            values: Encoder output used as cross-attention values ``[N, src_len, embed_dim]``.
            src_mask: Mask over the source positions for cross-attention. It is optional
                (``None``) and only serves to skip padded source tokens.
            trg_mask: Mask for the target self-attention. It is required: it must be causal
                so that position i cannot attend to later target positions.

        Returns:
            Tensor of shape ``[N, trg_len, embed_dim]``.
        """
        attention = self.attention(x, x, x, trg_mask)
        # Dropout on the sub-layer output before the residual add, then normalize.
        # The result is the query input to the cross-attention.
        queries = self.norm(x + self.dropout(attention))
        out = self.transformer_block(queries, keys, values, src_mask)
        return out


class Decoder(nn.Module):
    """Transformer decoder.

    Target-token embeddings plus learned positional embeddings, a stack of
    ``DecoderBlock`` layers attending to the encoder output, and a final linear
    projection onto the target vocabulary.

    Args:
        trg_vocab_size: Number of target token ids (embedding rows and output size).
        embed_dim: Size of the token vectors.
        n_layers: Number of stacked ``DecoderBlock`` layers.
        n_heads: Attention heads per layer.
        device: Device recorded on the module. It is kept for backward compatibility;
            ``forward`` creates its position ids on the input's device.
        forward_expansion: Feed-forward hidden size as a multiple of ``embed_dim``.
        dropout: Dropout probability.
        max_seq_length: Longest target sequence the positional embedding supports.
    """

    def __init__(
        self,
        trg_vocab_size: int,
        embed_dim: int,
        n_layers: int,
        n_heads: int,
        device: torch.device,
        forward_expansion: int,
        dropout: float,
        max_seq_length: int,
    ):
        super().__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_dim)
        self.positional_embedding = nn.Embedding(max_seq_length, embed_dim)
        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_dim, n_heads, forward_expansion, dropout, device)
                for _ in range(n_layers)
            ]
        )
        self.fc_out = nn.Linear(embed_dim, trg_vocab_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(
        self,
        x: torch.Tensor,
        encoder_out: torch.Tensor,
        src_mask: torch.Tensor,
        trg_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Decode target token ids against the encoder output.

        Args:
            x: Target token ids ``[N, trg_len]``, with ``trg_len <= max_seq_length``.
            encoder_out: Encoder output ``[N, src_len, embed_dim]``; used as keys and values.
            src_mask: Source mask for cross-attention (``None`` for no masking).
            trg_mask: Causal mask for target self-attention.

        Returns:
            Unnormalized scores (logits) ``[N, trg_len, trg_vocab_size]``. No softmax is
            applied. Position t holds the scores for the token that follows ``x[:, t]``.
            For training, feed them to ``F.cross_entropy`` against the targets shifted
            by one. For greedy generation, take ``argmax(dim=-1)`` at the last position,
            append it to ``x`` and call again.

        Raises:
            IndexError: if ``trg_len`` exceeds ``max_seq_length``.
        """
        N, seq_length = x.shape
        max_len = self.positional_embedding.num_embeddings
        if seq_length > max_len:
            raise IndexError(
                f"sequence length {seq_length} exceeds max_seq_length {max_len}"
            )
        # Position ids on the input's device; self.device may be stale after .to().
        positions = torch.arange(seq_length, device=x.device).expand(N, -1)
        x = self.dropout(self.word_embedding(x) + self.positional_embedding(positions))
        for layer in self.layers:
            # The encoder output serves as both keys and values for cross-attention.
            x = layer(x, encoder_out, encoder_out, src_mask, trg_mask)
        out = self.fc_out(x)
        return out

# Transformer itself (putting it all together)

In [None]:
class ChatGPT(nn.Module):
    """Placeholder model, not implemented yet.

    NOTE(review): presumably intended as a decoder-only (GPT-style) model; confirm
    when implementing.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *args, **kwargs):
        """Not implemented yet.

        Raises:
            NotImplementedError: always, so an accidental call fails loudly
                instead of silently returning ``None``.
        """
        raise NotImplementedError("ChatGPT.forward is not implemented yet")


class Translator(nn.Module):
    """Encoder-decoder transformer for sequence-to-sequence translation.

    Args:
        src_vocab_size: Number of source token ids.
        trg_vocab_size: Number of target token ids (size of the output logits).
        src_pad_idx: Padding id in source batches; those positions are masked out.
        trg_pad_idx: Padding id in target batches; those key positions are masked out.
        embed_dim: Size of the token vectors.
        n_layers: Layers in both the encoder and the decoder.
        n_heads: Attention heads per layer.
        device: Device the masks are moved to.
        forward_expansion: Feed-forward hidden size as a multiple of ``embed_dim``.
        dropout: Dropout probability.
        max_seq_length: Longest source/target sequence supported by the positional embeddings.
    """

    def __init__(
        self,
        src_vocab_size: int,
        trg_vocab_size: int,
        src_pad_idx: int,
        trg_pad_idx: int,
        embed_dim: int,
        n_layers: int,
        n_heads: int,
        device: torch.device,
        forward_expansion: int,
        dropout: float,
        max_seq_length: int
    ):
        super().__init__()

        self.encoder = Encoder(
            src_vocab_size,
            embed_dim,
            n_layers,
            n_heads,
            device,
            forward_expansion,
            dropout,
            max_seq_length
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_dim,
            n_layers,
            n_heads,
            device,
            forward_expansion,
            dropout,
            max_seq_length
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src: torch.Tensor) -> torch.Tensor:
        """Padding mask for the source batch.

        Args:
            src: Source token ids ``[N, src_len]``.

        Returns:
            Bool tensor ``[N, 1, 1, src_len]``: True where the token is not padding.
            The singleton axes broadcast over heads and query positions.
        """
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg: torch.Tensor) -> torch.Tensor:
        """Causal + padding mask for decoder self-attention.

        Args:
            trg: Target token ids ``[N, trg_len]``.

        Returns:
            Bool tensor ``[N, 1, trg_len, trg_len]``. Entry ``[n, 0, i, j]`` is True
            iff ``j <= i`` (no peeking at later tokens) and ``trg[n, j]`` is not padding.
        """
        N, trg_length = trg.shape
        causal = torch.tril(
            torch.ones(trg_length, trg_length, dtype=torch.bool, device=trg.device)
        )
        not_pad = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)  # [N, 1, 1, trg_len]
        trg_mask = (causal & not_pad).expand(N, 1, trg_length, trg_length)
        return trg_mask.to(self.device)

    def forward(
        self,
        src: torch.Tensor,
        trg: torch.Tensor,
    ) -> torch.Tensor:
        """Translate a batch.

        Args:
            src: Source token ids ``[N, src_len]``.
            trg: Target token ids ``[N, trg_len]`` (teacher-forced decoder input).

        Returns:
            Decoder logits ``[N, trg_len, trg_vocab_size]``.
        """
        src_mask = self.make_src_mask(src)
        # Must be the causal target mask. A source-style padding mask would let each
        # target position attend to future target tokens.
        trg_mask = self.make_trg_mask(trg)
        encoder_src = self.encoder(src, src_mask)
        out = self.decoder(trg, encoder_src, src_mask, trg_mask)
        return out

# Training 

In [1]:
# Training is still to be done; for now this cell explores the multilingual BERT tokenizer.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
text = "Replace me by any text you'd like."

sentence_a = 'a person walked into a bar на русском бар пиов '
sentence_b = 'anotehr sentence'

# Tokenize each sentence on its own: a batch of one, so input_ids is [1, n_tokens].
encoded_bert, encoded2_bert = [tokenizer(s, return_tensors='pt') for s in (sentence_a, sentence_b)]

input_ids_a = encoded_bert['input_ids']
print('encoded_bert', encoded_bert)
# NOTE: len() of a [1, n_tokens] tensor is the batch size (1), not the token count.
print('encoded_bert input_ids', input_ids_a, input_ids_a.shape, len(input_ids_a))

# Drop the batch dimension to get 1-D id sequences.
encoded_bert_ids, encoded2_bert_ids = (enc['input_ids'][0, :] for enc in (encoded_bert, encoded2_bert))

# Tokenize both sentences as one batch; padding=True right-pads the shorter one (with id 0, per the output).
padded_seq = tokenizer([sentence_a, sentence_b], padding=True, return_tensors='pt')
print('padded', padded_seq)

for ids in (encoded_bert_ids, encoded2_bert_ids):
    print(ids)

decoded_bert, decoded2_bert = map(tokenizer.decode, (encoded_bert_ids, encoded2_bert_ids))
print(decoded_bert)
print(decoded2_bert)

# Id 0 is the value the padded batch above is filled with.
print(tokenizer.decode([0]))

  from .autonotebook import tqdm as notebook_tqdm


encoded_bert {'input_ids': tensor([[  101,   169, 19867, 11401, 10157, 33734, 10336, 10708,   169, 18121,
         10122, 64644, 14748,   556, 21607, 10541,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
encoded_bert input_ids tensor([[  101,   169, 19867, 11401, 10157, 33734, 10336, 10708,   169, 18121,
         10122, 64644, 14748,   556, 21607, 10541,   102]]) torch.Size([1, 17]) 1
padded {'input_ids': tensor([[  101,   169, 19867, 11401, 10157, 33734, 10336, 10708,   169, 18121,
         10122, 64644, 14748,   556, 21607, 10541,   102],
        [  101, 12797, 10216, 16757, 49219,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1