In [1]:
from datasets import load_dataset
import tiktoken
import torch
import torch.nn as nn
from torch.nn import functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# First 1,000 French-English pairs of the WMT14 training split (kept small for fast iteration).
dataset = load_dataset(path="wmt14", name="fr-en", split="train[:1000]")

In [3]:
# Show the dataset summary (features, num_rows) via the notebook's rich display rather than print().
dataset

Dataset({
    features: ['translation'],
    num_rows: 1000
})


In [6]:
# GPT-2 byte-level BPE, used for both the French source and the English target text.
tokenizer = tiktoken.get_encoding(encoding_name="gpt2")

In [7]:
# GPT-2 has no dedicated BOS/PAD tokens. The <|endoftext|> id is reused as both the decoder's
# start token and its end token. Padding gets a brand-new id one past the GPT-2 vocabulary,
# so the embedding table needs n_vocab + 1 rows.
eos_token_id = tokenizer.eot_token  # 50256 for "gpt2"
pad_token_id = tokenizer.n_vocab  # 50257 for "gpt2"

In [8]:
def tokenize_function(examples, text_tokenizer=None):
    """Tokenize parallel French/English sentences into lists of 1-D int64 tensors.

    Args:
        examples: mapping whose ``'translation'`` entry is a sequence of dicts with
            ``'fr'`` and ``'en'`` strings (e.g. the HF WMT14 dataset).
        text_tokenizer: object with an ``encode(str) -> list[int]`` method; defaults to the
            notebook-level GPT-2 ``tokenizer``.

    Returns:
        ``(french_tokens, english_tokens)``: two equal-length lists of int64 tensors.
        No <EOS> is added here; that happens when the decoder inputs and labels are built.
    """
    if text_tokenizer is None:
        text_tokenizer = tokenizer
    # Materialise once; the translation column is iterated twice below.
    pairs = list(examples['translation'])

    # dtype=torch.long is explicit. Otherwise an empty sentence becomes a float32 tensor,
    # which silently promotes later torch.cat/stack results to float and breaks nn.Embedding.
    french_tokens = [torch.tensor(text_tokenizer.encode(item['fr']), dtype=torch.long) for item in pairs]
    english_tokens = [torch.tensor(text_tokenizer.encode(item['en']), dtype=torch.long) for item in pairs]

    return french_tokens, english_tokens

In [9]:
# Token-id tensors (one per sentence) for the source (fr) and target (en) sides.
french_sentences, english_sentences = tokenize_function(examples=dataset)

In [10]:
# Fixed padded lengths; the English side is shorter because GPT-2's BPE encodes English more compactly.
encoder_max_context_length, decoder_max_context_length = 30, 20

In [11]:
def pad_or_truncate(sequences, max_length, pad_token_id):
    """Bring every 1-D tensor to exactly ``max_length`` and stack them.

    Longer sequences are cut from the right. Shorter ones are right-padded with
    ``pad_token_id`` (int64 padding).

    Returns:
        Tensor of shape ``[len(sequences), max_length]``.
    """
    def _fit(seq):
        # Truncation path: keep the first max_length tokens untouched.
        if seq.shape[0] > max_length:
            return seq[:max_length]
        filler = torch.full((max_length - seq.shape[0],), pad_token_id, dtype=torch.long)
        return torch.cat([seq, filler], dim=0)

    return torch.stack([_fit(seq) for seq in sequences])

# [N, encoder_max_context_length] source ids, right-padded with pad_token_id.
encoder_inputs = pad_or_truncate(
    french_sentences,
    max_length=encoder_max_context_length,
    pad_token_id=pad_token_id,
)

In [12]:

# Teacher-forcing pairs for the decoder:
#   input  = [<EOS>, t1, ..., tn]   (<EOS> doubles as the start token)
#   label  = [t1, ..., tn, <EOS>]   (the input shifted left by one)
decoder_inputs = []
decoder_labels = []

for seq in english_sentences:
    # Reserve one slot for the <EOS> added to each side, so both fit the context exactly.
    # .long() guards against float tensors produced by tokenizing an empty sentence.
    # NOTE(review): a sentence cut here still gets a trailing <EOS> label, so the model is
    # taught to stop mid-sentence on long targets.
    body = seq[:decoder_max_context_length - 1].long()
    decoder_inputs.append(torch.cat([torch.tensor([eos_token_id], dtype=torch.long), body], dim=0))
    decoder_labels.append(torch.cat([body, torch.tensor([eos_token_id], dtype=torch.long)], dim=0))

# Right-pad both sides with pad_token_id and stack into [N, decoder_max_context_length] tensors.
decoder_inputs = pad_or_truncate(decoder_inputs, decoder_max_context_length, pad_token_id)
decoder_labels = pad_or_truncate(decoder_labels, decoder_max_context_length, pad_token_id)

In [95]:
seed = 1337
torch.manual_seed(seed)  # reproducible weight init and get_batch() sampling

# One row per GPT-2 id plus the extra padding id that sits just past the GPT-2 vocabulary.
vocab_size = pad_token_id + 1
# The positional tables must cover exactly the padded lengths used during data prep.
encoder_block_size = encoder_max_context_length
decoder_block_size = decoder_max_context_length
d_model = 32
n_head = 4
assert d_model % n_head == 0, "d_model must be divisible by n_head"
n_encoder_layers = 3
n_decoder_layers = 3
dropout = 0.2
batch_size = 32


In [96]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [97]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4*d_model, ReLU, project back to d_model, then dropout."""

    def __init__(self):
        super().__init__()
        hidden = 4 * d_model
        expand = nn.Linear(d_model, hidden)
        # Project back to d_model so the output can be added onto the residual stream.
        contract = nn.Linear(hidden, d_model)
        self.net = nn.Sequential(expand, nn.ReLU(), contract, nn.Dropout(dropout))

    def forward(self, x):
        out = self.net(x)
        return out

In [98]:
class EncoderMultiHeadAttention(nn.Module):
    """Bidirectional multi-head self-attention that never attends to padded key positions."""

    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(d_model, 3 * d_model)  # fused Q/K/V projection
        self.proj = nn.Linear(d_model, d_model)  # output projection after concatenating heads
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Self-attend over ``x`` while ignoring padded keys.

        Args:
            x: [B, T, C] activations (C == d_model).
            mask: [B, T] key mask; positions equal to 0/False are not attended to.

        Returns:
            [B, T, C] projected attention output (after dropout).
        """
        B, T, C = x.shape
        qkv = self.attn(x)  # [B, T, 3*C]
        q, k, v = qkv.split(d_model, 2)  # all [B, T, C]
        q = q.view(B, T, n_head, C // n_head).transpose(1, 2)  # [B, nh, T, hs]
        k = k.view(B, T, n_head, C // n_head).transpose(1, 2)  # [B, nh, T, hs]
        v = v.view(B, T, n_head, C // n_head).transpose(1, 2)  # [B, nh, T, hs]
        full_att = q @ k.transpose(-2, -1) / (k.shape[-1]) ** 0.5  # [B, nh, T, T]
        key_mask = mask.unsqueeze(1).unsqueeze(2)  # [B, 1, 1, T], broadcast over heads and queries
        # Fill with the dtype's most negative finite value instead of -inf. If every key in a row
        # is padding (e.g. an empty source sentence), -inf makes softmax return NaN and the NaN
        # spreads through the whole batch and its gradients. The finite fill gives a harmless
        # uniform row instead. Rows with at least one real key are effectively unchanged, because
        # exp(min - max) underflows to 0.
        full_att = full_att.masked_fill(key_mask == 0, torch.finfo(full_att.dtype).min)
        attention_scores = F.softmax(full_att, dim=-1)
        context_vectors = attention_scores @ v  # [B, nh, T, hs]
        context_vectors = context_vectors.transpose(1, 2).contiguous().view(B, T, C)  # concat heads -- [B, T, C]
        out = self.dropout(self.proj(context_vectors))
        return out

In [99]:
class Encoder_Block(nn.Module):
    """Pre-LN encoder layer: padded self-attention, then feed-forward, each inside a residual.

    LayerNorm is applied *before* each sub-layer, a departure from the post-LN layout of
    "Attention Is All You Need".
    """

    def __init__(self):
        super().__init__()
        self.sa = EncoderMultiHeadAttention()
        self.ffwd = FeedForward()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        attended = x + self.sa(self.ln1(x), mask)
        return attended + self.ffwd(self.ln2(attended))

In [100]:
class Encoder(nn.Module):
    """Source encoder: shared token embedding + learned positions -> encoder blocks -> final LayerNorm."""

    def __init__(self, shared_embedding):
        super().__init__()
        self.token_embedding_table = shared_embedding  # the same nn.Embedding instance the caller passes to the decoder
        self.position_embedding_table = nn.Embedding(encoder_block_size, d_model)
        self.blocks = nn.ModuleList(
            [Encoder_Block() for _ in range(n_encoder_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)

    def forward(self, idx):
        """Encode source token ids.

        Args:
            idx: [B, T] integer token ids, right-padded with pad_token_id.

        Returns:
            [B, T, d_model] encoder states.

        Raises:
            ValueError: if T exceeds encoder_block_size. This replaces an opaque index error
                from the position lookup (a device-side assert on CUDA).
        """
        T = idx.shape[1]
        if T > encoder_block_size:
            raise ValueError(
                f"source length {T} exceeds encoder_block_size={encoder_block_size}; truncate the input"
            )
        tok_emb = self.token_embedding_table(idx)  # [B, T, d_model]
        # Build positions on the input's own device, not the notebook-global `device`,
        # so the module works wherever it and its inputs actually live.
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # [T, d_model]
        x = tok_emb + pos_emb
        mask = idx != pad_token_id  # [B, T]; created on idx.device already
        for block in self.blocks:
            x = block(x, mask)
        return self.ln_f(x)

In [101]:
class DecoderMultiHeadAttention(nn.Module):
    """Causal multi-head self-attention for the decoder.

    Each position attends only to itself and earlier positions. The lower-triangular
    `tril` buffer supports sequences up to decoder_block_size long.
    """

    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(d_model, 3 * d_model)  # fused Q/K/V projection
        self.proj = nn.Linear(d_model, d_model)
        causal = torch.tril(torch.ones(decoder_block_size, decoder_block_size))
        self.register_buffer('tril', causal)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        head_size = C // n_head

        def to_heads(t):
            # [B, T, C] -> [B, nh, T, hs]
            return t.view(B, T, n_head, head_size).transpose(1, 2)

        q, k, v = map(to_heads, self.attn(x).split(d_model, 2))
        scores = q @ k.transpose(-2, -1) / (k.shape[-1]) ** 0.5  # [B, nh, T, T]
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float("-inf"))  # hide future positions
        weights = F.softmax(scores, dim=-1)
        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)  # concat heads -- [B, T, C]
        return self.dropout(self.proj(merged))

In [102]:
class DecoderEncoderMultiHeadAttention(nn.Module):
    """Cross-attention: decoder queries attend over encoder outputs, skipping padded source positions."""

    def __init__(self):
        super().__init__()
        self.kv = nn.Linear(d_model, 2 * d_model)  # fused K/V projection of the encoder outputs
        self.q = nn.Linear(d_model, d_model)  # query projection of the decoder stream
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_outputs, encoder_outputs_mask):
        """Attend from the decoder stream to the encoder outputs.

        Args:
            x: [B, T_decoder, C] decoder activations.
            encoder_outputs: [B, T_encoder, C] encoder states.
            encoder_outputs_mask: [B, T_encoder]; positions equal to 0/False are not attended to.

        Returns:
            [B, T_decoder, C] projected attention output (after dropout).
        """
        B, T_decoder, C = x.shape
        B, T_encoder, C = encoder_outputs.shape
        kv = self.kv(encoder_outputs)  # [B, T_encoder, 2 * d_model]
        k, v = kv.split(d_model, 2)  # each [B, T_encoder, d_model]
        q = self.q(x)  # [B, T_decoder, d_model]
        q = q.view(B, T_decoder, n_head, C // n_head).transpose(1, 2)  # [B, nh, T_decoder, hs]
        k = k.view(B, T_encoder, n_head, C // n_head).transpose(1, 2)  # [B, nh, T_encoder, hs]
        v = v.view(B, T_encoder, n_head, C // n_head).transpose(1, 2)  # [B, nh, T_encoder, hs]
        full_att = q @ k.transpose(-2, -1) / (k.shape[-1]) ** 0.5  # [B, nh, T_decoder, T_encoder]
        key_mask = encoder_outputs_mask.unsqueeze(1).unsqueeze(2)  # [B, 1, 1, T_encoder]
        # Finite fill instead of -inf: a source that is entirely padding would otherwise make every
        # softmax row NaN. Rows with at least one real key are effectively unchanged.
        full_att = full_att.masked_fill(key_mask == 0, torch.finfo(full_att.dtype).min)
        attention_scores = F.softmax(full_att, dim=-1)
        context_vectors = attention_scores @ v  # [B, nh, T_decoder, hs]
        context_vectors = context_vectors.transpose(1, 2).contiguous().view(B, T_decoder, C)  # concat heads -- [B, T_decoder, d_model]
        out = self.dropout(self.proj(context_vectors))
        return out

In [103]:
class Decoder_Block(nn.Module):
    """Pre-LN decoder layer: causal self-attention -> cross-attention -> feed-forward, each residual.

    The sub-layer order follows "Attention Is All You Need". LayerNorm is applied *before*
    each sub-layer (pre-LN), which departs from that paper.
    """

    def __init__(self):
        super().__init__()
        self.sa1 = DecoderMultiHeadAttention()  # causal self-attention
        self.sa2 = DecoderEncoderMultiHeadAttention()  # cross-attention to the encoder
        self.ffwd = FeedForward()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ln3 = nn.LayerNorm(d_model)

    def forward(self, x, encoder_outputs, encoder_outputs_mask):
        x = x + self.sa1(self.ln1(x))
        # Cross-attend *before* the feed-forward, so the FFN can transform the information just
        # read from the source. With FFN-then-cross-attention, the last sub-layer of every block
        # (and of the whole decoder) is an unprocessed cross-attention read-out.
        x = x + self.sa2(self.ln2(x), encoder_outputs, encoder_outputs_mask)
        x = x + self.ffwd(self.ln3(x))
        return x

In [104]:
class Decoder(nn.Module):
    """Target decoder: shared embedding + learned positions -> decoder blocks -> LayerNorm -> tied LM head."""

    def __init__(self, shared_embedding):
        super().__init__()
        self.token_embedding_table = shared_embedding
        self.position_embedding_table = nn.Embedding(decoder_block_size, d_model)
        self.blocks = nn.ModuleList([Decoder_Block() for _ in range(n_decoder_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)
        # Weight tying: nn.Linear stores its weight as [out, in] = [vocab_size, d_model], the same
        # shape as the embedding table, so the Parameter is shared directly (no transpose needed).
        # The bias stays a separate parameter.
        self.lm_head.weight = self.token_embedding_table.weight

    def forward(self, idx, encoder_outputs, encoder_outputs_mask, targets=None):
        """Run the decoder and optionally compute the token-level loss.

        Args:
            idx: [B, T] decoder input ids (T <= decoder_block_size).
            encoder_outputs: [B, T_encoder, d_model] encoder states.
            encoder_outputs_mask: [B, T_encoder]; 0/False marks padded source positions.
            targets: optional [B, T] label ids; positions equal to pad_token_id are ignored.

        Returns:
            (logits, loss). Without targets, logits are [B, T, vocab_size] and loss is None.
            With targets, logits are flattened to [B*T, vocab_size] and loss is the mean
            cross-entropy over non-pad positions.

        Raises:
            ValueError: if T exceeds decoder_block_size.
        """
        T = idx.shape[1]
        if T > decoder_block_size:
            raise ValueError(
                f"decoder length {T} exceeds decoder_block_size={decoder_block_size}"
            )
        tok_emb = self.token_embedding_table(idx)  # [B, T, d_model]
        # Positions on the input's device rather than the notebook-global `device`.
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # [T, d_model]
        x = tok_emb + pos_emb
        for block in self.blocks:
            x = block(x, encoder_outputs, encoder_outputs_mask)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # [B, T, vocab_size]
        if targets is None:
            return logits, None

        B, T, V = logits.shape
        logits = logits.view(B * T, V)
        # reshape (not view) tolerates non-contiguous targets. Ignore padding via the named
        # constant rather than a hardcoded id.
        loss = F.cross_entropy(logits, targets.reshape(-1), ignore_index=pad_token_id)
        return logits, loss

In [105]:
class Transformer(nn.Module):
    """Encoder-decoder translator whose encoder, decoder and output head share one embedding table."""

    def __init__(self):
        super().__init__()
        self.shared_embedding = nn.Embedding(vocab_size, d_model)
        # Scale nn.Embedding's default N(0, 1) init down by sqrt(d_model). The table also serves as
        # the (tied) output projection, so this keeps the initial logits small.
        with torch.no_grad():
            self.shared_embedding.weight /= (d_model ** 0.5)
        self.encoder = Encoder(self.shared_embedding)
        self.decoder = Decoder(self.shared_embedding)

    def forward(self, encoder_inputs, decoder_inputs, targets=None):
        """Encode the source batch, then decode; returns the decoder's (logits, loss)."""
        # The comparison already produces the mask on encoder_inputs' device. Moving it to the
        # notebook-global `device` would break if the model runs elsewhere.
        encoder_outputs_mask = encoder_inputs != pad_token_id  # [B, T_encoder]
        encoder_outputs = self.encoder(encoder_inputs)
        logits, loss = self.decoder(decoder_inputs, encoder_outputs, encoder_outputs_mask, targets)
        return logits, loss

In [106]:
def get_batch():
    """Draw `batch_size` random training examples (with replacement).

    Returns:
        (encoder inputs, decoder inputs, decoder labels), each indexed by the same
        random rows of the notebook-level training tensors.
    """
    n_examples = encoder_inputs.shape[0]
    picks = torch.randint(n_examples, (batch_size,))
    return encoder_inputs[picks], decoder_inputs[picks], decoder_labels[picks]

In [107]:
# Move the full training tensors once, so get_batch() indexes tensors that already live on `device`.
encoder_inputs, decoder_inputs, decoder_labels = (
    tensor.to(device) for tensor in (encoder_inputs, decoder_inputs, decoder_labels)
)

transformer = Transformer().to(device)

In [None]:
optimizer = torch.optim.AdamW(transformer.parameters(), lr=6e-4)
max_iters = 5000
eval_iter = 500  # report the training loss every `eval_iter` steps

# The inference cells call transformer.eval(); switch dropout back on before training.
transformer.train()
# range() is required: iterating an int raises TypeError. `step` also avoids shadowing builtin iter().
for step in range(max_iters):
    enc_inputs, dec_inputs, dec_labels = get_batch()
    _, loss = transformer(enc_inputs, dec_inputs, dec_labels)
    if step % eval_iter == 0:
        print(f"Loss: {loss.item()}")  # single-batch training loss, not a held-out estimate
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


Loss: 11.368083000183105


In [None]:
transformer.train(False)  # identical to .eval(): disables dropout for inference

In [93]:
class Hypothesis:
    """One partial translation tracked by beam search.

    Attributes:
        tokens: list of token ids generated so far (starting with the start token).
        log_prob: summed log-probability of the generated tokens.
        length: number of tokens, used for length normalisation.
    """

    def __init__(self, tokens, log_prob, length):
        self.tokens, self.log_prob, self.length = tokens, log_prob, length

    def get_score(self):
        """Length-normalised log-probability (higher is better)."""
        score = self.log_prob / self.length
        return score

In [None]:
sentence = "Bonjour tout le monde!"
k = 2  # beam width
max_completed_hypotheses = 1  # stop once this many hypotheses are finished


def beam_search(source_sentence, beam_width=k, n_best=max_completed_hypotheses):
    """Translate `source_sentence` with length-normalised beam search.

    Each step expands every live hypothesis by its `beam_width` most likely next tokens.
    It then keeps only the `beam_width` best candidates overall, ranked by
    Hypothesis.get_score. A candidate is finished when it emits <EOS> or grows past
    decoder_max_context_length tokens; finished hypotheses are never expanded again.
    The search stops when `n_best` hypotheses are finished or no live hypotheses remain,
    so it cannot loop forever on an empty beam.

    Returns:
        Finished hypotheses sorted best-first (non-empty whenever beam_width >= 1).

    Only local names are used, so the training tensors (e.g. `encoder_inputs`) are not
    overwritten and get_batch() still works if training is re-run.
    """
    # Truncate like the training data; the encoder's position table has a fixed number of rows.
    src_ids = torch.tensor(tokenizer.encode(source_sentence), dtype=torch.long, device=device)
    src_ids = src_ids[:encoder_max_context_length].unsqueeze(0)  # [1, T_enc]
    src_mask = src_ids != pad_token_id

    beam = [Hypothesis([eos_token_id], 0.0, 1)]  # <EOS> doubles as the start token
    completed = []
    with torch.no_grad():
        enc_out = transformer.encoder(src_ids)
        while beam and len(completed) < n_best:
            candidates = []
            for hyp in beam:
                dec_in = torch.tensor([hyp.tokens], dtype=torch.long, device=device)
                logits, _ = transformer.decoder(dec_in, enc_out, src_mask)
                next_logits = logits[0, -1]  # [vocab_size]
                next_logits[pad_token_id] = float("-inf")  # padding is never a valid output token
                # log_softmax is the numerically stable form of log(softmax(.)).
                log_probs = F.log_softmax(next_logits, dim=-1)
                top_log_probs, top_ids = torch.topk(log_probs, beam_width)
                for lp, tok in zip(top_log_probs.tolist(), top_ids.tolist()):
                    candidates.append(Hypothesis(hyp.tokens + [tok], hyp.log_prob + lp, hyp.length + 1))

            # Global pruning keeps the search width at beam_width instead of growing as k**t.
            candidates.sort(key=lambda h: h.get_score(), reverse=True)
            beam = []
            for cand in candidates[:beam_width]:
                if cand.tokens[-1] == eos_token_id or len(cand.tokens) > decoder_max_context_length:
                    completed.append(cand)
                else:
                    beam.append(cand)

    return sorted(completed, key=lambda h: h.get_score(), reverse=True)


best = beam_search(sentence)[0]
# Drop the start token and any special ids (<EOS>, pad) before decoding to text.
tokenizer.decode([t for t in best.tokens[1:] if t not in (eos_token_id, pad_token_id)])