In [1]:
from transformers import AutoTokenizer
from datasets import load_dataset
import torch

  from .autonotebook import tqdm as notebook_tqdm




In [196]:
# tokenizer_src = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-de-en')  # German tokenizer
# tokenizer_trg = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')  # English tokenizer
tokenizer_src = AutoTokenizer.from_pretrained('dbmdz/german-gpt2')  # German tokenizer
tokenizer_trg = AutoTokenizer.from_pretrained('openai-community/gpt2')  # English tokenizer



In [85]:
def preprocess_function(examples, tokenizer_src, src_language, trg_language, max_length):
    inputs = [example[src_language] for example in examples["translation"]]
    targets = [example[trg_language] for example in examples["translation"]]
    model_inputs = tokenizer_src(inputs, text_target=targets, max_length=max_length, truncation=True, padding="max_length")
    labels = tokenizer_trg(targets, max_length=max_length, truncation=True, padding="max_length")["input_ids"]
    model_inputs["labels"] = labels
    return model_inputs

def prepare_data(tokenizer_src, tokenizer_trg, batch_size=4, num_workers=2, test_fraction=0.2, max_length=512):
    # Load dataset; ignore validation set (tst2013) and use test set only (tst2014)
    src_language, trg_language = 'de', 'en'
    dataset = load_dataset("ted_talks_iwslt", language_pair=(src_language, trg_language), year="2014")
    dataset = dataset.train_test_split(test_size=test_fraction, shuffle=True)
    trainset, testset = dataset['train'], dataset['test']
    # Preprocess datasets
    tokenized_trainset = trainset.map(lambda examples: preprocess_function(examples, tokenizer_src, src_language, tokenizer_trg, trg_language, max_length), batched=True)
    tokenized_testset = testset.map(lambda examples: preprocess_function(examples, tokenizer_src, src_language, tokenizer_trg, trg_language, max_length), batched=True)
    # Create dataloaders
    trainloader = torch.utils.data.DataLoader(tokenized_trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(tokenized_testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return trainloader, testloader

In [6]:
dataset = load_dataset("ted_talks_iwslt", language_pair=("de", "en"), year="2014")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Downloading data: 100%|██████████| 1.67G/1.67G [04:25<00:00, 6.27MB/s] 
Generating train split: 2972 examples [00:06, 461.77 examples/s]


In [11]:
dataset = dataset["train"].train_test_split(test_size=0.1, shuffle=True)

In [20]:
testset = dataset["test"]

In [86]:
tokenized_testset = testset.map(lambda examples: preprocess_function(examples, tokenizer_src, "de", "en", 128), batched=True)

Map: 100%|██████████| 298/298 [00:00<00:00, 2442.14 examples/s]


In [92]:
testloader = torch.utils.data.DataLoader(tokenized_testset, batch_size=2, shuffle=False, num_workers=1)

In [93]:
batch_size = 2
ix = torch.randint(len(testloader), (batch_size,))
x = torch.stack([data.dataset[i][0] for i in ix])
y = torch.tensor([data.dataset[i][1] for i in ix])

<torch.utils.data.dataloader.DataLoader at 0x184fbe3d0>

In [229]:
import math
import torch
from torch import nn
from torch.nn import functional as F


# https://github.com/tintn/vision-transformer-from-scratch/blob/main/vit.py
class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415

    Taken from https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
    """

    def forward(self, input):
        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))


class FNSAttentionHead(nn.Module):
    """
    A single attention head.
    This module is used in the FNSMultiHeadAttention module.

    """
    def __init__(self, beta, bandwidth, sphere_radius, hidden_size, attention_head_size, dropout, bias=True, is_cross_attention=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.attention_head_size = attention_head_size
        # Create the query, key, and value projection layers
        self.query = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.key = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.value = nn.Linear(hidden_size, attention_head_size, bias=bias)

        self.dropout = nn.Dropout(dropout)
    
        self.beta, self.bandwidth = beta, bandwidth
        self.sphere_radius = sphere_radius
        
        self.is_cross_attention = is_cross_attention

    def forward(self, x, encoder_output_states=None):
        if encoder_output_states is not None:
            assert self.is_cross_attention, "Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
            query = F.normalize(self.query(x), p=2, dim=-1)
            key = F.normalize(self.key(encoder_output_states), p=2, dim=-1)
            value = F.normalize(self.value(encoder_output_states), p=2, dim=-1)   
        else:
            query = F.normalize(self.query(x), p=2, dim=-1)
            key = F.normalize(self.key(x), p=2, dim=-1)
            value = F.normalize(self.value(x), p=2, dim=-1)                
        # print(f'query shape: {query.shape}')
        # print(f'key shape: {key.shape}')
        # print(f'value shape: {value.shape}')

        beta, bandwidth = self.beta, self.bandwidth
        sphere_radius = self.sphere_radius
        d_intrinsic = self.attention_head_size

        # geodesic distance on sphere
        eps = 1e-7  # for limiting the divergence from acos
        g_dist = torch.acos(torch.clamp(query @ key.transpose(-2, -1), -1+eps, 1-eps)) * sphere_radius
        
        # Calculate the attention scores
        if beta < 2:
            attn_score = (1 + g_dist/bandwidth**0.5)**(-d_intrinsic-beta)
        else:
            attn_score = torch.exp((-g_dist/bandwidth**0.5)**(beta/(beta-1)))
        attn_score_shape = attn_score.shape
        D_inv = torch.diag_embed(attn_score.sum(-1)**(-1))  # inverse of degree matrix of attn_score
        K_tilde = D_inv @ attn_score @ D_inv
        attention_probs = F.normalize(K_tilde,p=1,dim=3)  # can do this as the attn weights are always positive
        attention_probs = self.attn_dropout(attention_probs)

        # Calculate the attention output
        attention_output = attention_probs @ value

        return (attention_output, attention_probs)


class FNSMultiHeadAttention(nn.Module):
    """
    Multi-head attention module.
    This module is used in the TransformerEncoder module.
    """

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_attention_heads = config["num_attention_heads"]
        # The attention head size is the hidden size divided by the number of attention heads
        self.attention_head_size = self.hidden_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # Whether or not to use bias in the query, key, and value projection layers
        self.qkv_bias = config["qkv_bias"]
        # Create a list of attention heads
        self.heads = nn.ModuleList([])
        # Whether it is cross attention
        self.is_cross_attention = is_cross_attention

        self.beta = config['beta']
        self.bandwidth = config['bandwidth']
        self.sphere_radius = config['sphere_radius']     

        for _ in range(self.num_attention_heads):
            head = FNSAttentionHead(
                self.beta,
                self.bandwidth,
                self.sphere_radius,
                self.hidden_size,
                self.attention_head_size,
                config["attention_probs_dropout_prob"],
                self.qkv_bias,
                self.is_cross_attention
            )
            self.heads.append(head)
        # Create a linear layer to project the attention output back to the hidden size
        # In most cases, all_head_size and hidden_size are the same
        self.output_projection = nn.Linear(self.all_head_size, self.hidden_size)
        self.output_dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x, output_attentions=False, encoder_output_states=None):
        # Calculate the attention output for each attention head
        attention_outputs = [head(x, encoder_output_states) for head in self.heads]
        # Concatenate the attention outputs from each attention head
        attention_output = torch.cat([attention_output for attention_output, _ in attention_outputs], dim=-1)
        # Project the concatenated attention output back to the hidden size
        attention_output = self.output_projection(attention_output)
        attention_output = self.output_dropout(attention_output)
        # Return the attention output and the attention probabilities (optional)
        if not output_attentions:
            return (attention_output, None)
        else:
            attention_probs = torch.stack([attention_probs for _, attention_probs in attention_outputs], dim=1)
            return (attention_output, attention_probs)


class FasterFNSMultiHeadAttention(nn.Module):
    """
    Multi-head attention module with some optimizations.
    All the heads are processed simultaneously with merged query, key, and value projections.
    """

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.is_cross_attention = is_cross_attention
        
        self.hidden_size = config["hidden_size"]
        self.num_attention_heads = config["num_attention_heads"]
        # The attention head size is the hidden size divided by the number of attention heads
        self.attention_head_size = self.hidden_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # Whether or not to use bias in the query, key, and value projection layers
        self.qkv_bias = config["qkv_bias"]
        # Create a linear layer to project the query, key, and value
        if self.is_cross_attention:
            self.kv_projection = nn.Linear(self.hidden_size, self.all_head_size * 2, bias=self.qkv_bias)
            self.q_projection = nn.Linear(self.hidden_size, self.all_head_size, bias=self.qkv_bias)
        else:
            self.qkv_projection = nn.Linear(self.hidden_size, self.all_head_size * 3, bias=self.qkv_bias)
        self.attn_dropout = nn.Dropout(config["attention_probs_dropout_prob"])
        # Create a linear layer to project the attention output back to the hidden size
        # In most cases, all_head_size and hidden_size are the same
        self.output_projection = nn.Linear(self.all_head_size, self.hidden_size)
        self.output_dropout = nn.Dropout(config["hidden_dropout_prob"])

        self.beta = config['beta']
        self.bandwidth = config['bandwidth']
        self.sphere_radius = config['sphere_radius']

    def forward(self, x, attention_mask=None, output_attentions=False, encoder_hidden_states=None):
        # Project the query, key, and value
        if encoder_hidden_states is not None:
            assert hasattr(
                self, "q_projection"
            ), "If class is used as cross attention, the weights `q_projection` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
            query = self.q_projection(x)
            kv = self.kv_projection(encoder_hidden_states)
            key, value = torch.chunk(kv, 2, dim=-1)
        else:
            # (batch_size, sequence_length, hidden_size) -> (batch_size, sequence_length, all_head_size * 3)
            qkv = self.qkv_projection(x)
            # Split the projected query, key, and value into query, key, and value
            # (batch_size, sequence_length, all_head_size * 3) -> (batch_size, sequence_length, all_head_size)
            query, key, value = torch.chunk(qkv, 3, dim=-1)
        # Resize the query, key, and value to (batch_size, num_attention_heads, sequence_length, attention_head_size)
        batch_size, sequence_length, _ = query.size()
        num_attention_heads, attention_head_size = self.num_attention_heads, self.attention_head_size

        beta, bandwidth = self.beta, self.bandwidth
        sphere_radius = self.sphere_radius
        d_intrinsic = attention_head_size

        query = F.normalize(query.view(batch_size, sequence_length, num_attention_heads, attention_head_size).transpose(1, 2), p=2, dim=-1)
        key = F.normalize(key.view(batch_size, sequence_length, num_attention_heads, attention_head_size).transpose(1, 2), p=2, dim=-1)
        value = value.view(batch_size, sequence_length, num_attention_heads, attention_head_size).transpose(1, 2)
        # print(f'query shape: {query.shape}')
        # print(f'key shape: {key.shape}')
        # print(f'value shape: {value.shape}')        

        # geodesic distance on sphere
        eps = 1e-7  # for limiting the divergence from acos
        g_dist = torch.acos(torch.clamp(query @ key.transpose(-2, -1), -1+eps, 1-eps)) * sphere_radius
        
        # Calculate the attention scores
        if beta < 2:
            attn_score = (1 + g_dist/bandwidth**0.5)**(-d_intrinsic-beta)
        else:
            attn_score = torch.exp((-g_dist/bandwidth**0.5)**(beta/(beta-1)))
        D_inv = torch.diag_embed(attn_score.sum(-1)**(-1))  # inverse of degree matrix of attn_score
        K_tilde = D_inv @ attn_score @ D_inv
        K_tilde = K_tilde.masked_fill(attention_mask.expand(-1,self.num_attention_heads,-1,-1)==0, -1e9) # Mask
        attention_probs = F.normalize(K_tilde,p=1,dim=3)  # can do this as the attn weights are always positive
        attention_probs = self.attn_dropout(attention_probs)

        # Calculate the attention output
        attention_output = attention_probs @ value
        # Resize the attention output
        # from (batch_size, num_attention_heads, sequence_length, attention_head_size)
        # To (batch_size, sequence_length, all_head_size)
        attention_output = attention_output.transpose(1, 2) \
                                           .contiguous() \
                                           .view(batch_size, sequence_length, self.all_head_size)
        # Project the attention output back to the hidden size
        attention_output = self.output_projection(attention_output)
        attention_output = self.output_dropout(attention_output)
        # Return the attention output and the attention probabilities (optional)
        if not output_attentions:
            return (attention_output, None)
        else:
            return (attention_output, attention_probs)


class MLP(nn.Module):
    """
    A multi-layer perceptron module.
    """

    def __init__(self, config):
        super().__init__()
        self.dense_1 = nn.Linear(config["hidden_size"], config["intermediate_size"])
        self.activation = NewGELUActivation()
        self.dense_2 = nn.Linear(config["intermediate_size"], config["hidden_size"])
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x):
        x = self.dense_1(x)
        x = self.activation(x)
        x = self.dense_2(x)
        x = self.dropout(x)
        return x


class FNSEncoderBlock(nn.Module):
    """
    A single transformer block.
    """

    def __init__(self, config):
        super().__init__()
        self.use_faster_attention = config.get("use_faster_attention", False)
        if self.use_faster_attention:
            self.attention = FasterFNSMultiHeadAttention(config)
        else:
            self.attention = FNSMultiHeadAttention(config)
        self.layernorm_1 = nn.LayerNorm(config["hidden_size"])
        self.mlp = MLP(config)
        self.layernorm_2 = nn.LayerNorm(config["hidden_size"])

    def forward(self, x, attention_mask=None, output_attentions=False):
        # Self-attention
        attention_output, attention_probs = \
            self.attention(x, attention_mask=attention_mask, output_attentions=output_attentions)
        # Skip connection
        x = self.layernorm_1(x + attention_output)
        # Feed-forward network
        mlp_output = self.mlp(x)
        # Skip connection
        x = self.layernorm_2(x + mlp_output)
        # Return the transformer block's output and the attention probabilities (optional)
        if not output_attentions:
            return (x, None)
        else:
            return (x, attention_probs)
        

class FNSDecoderBlock(nn.Module):
    """
    A single transformer block.
    """

    def __init__(self, config):
        super().__init__()
        self.use_faster_attention = config.get("use_faster_attention", False)
        if self.use_faster_attention:
            self.self_attention = FasterFNSMultiHeadAttention(config)
            self.cross_attention = FasterFNSMultiHeadAttention(config, is_cross_attention=True)
        else:
            self.self_attention = FNSMultiHeadAttention(config)
            self.cross_attention = FNSMultiHeadAttention(config, is_cross_attention=True)
        self.layernorm_1 = nn.LayerNorm(config["hidden_size"])
        self.layernorm_2 = nn.LayerNorm(config["hidden_size"])
        self.mlp = MLP(config)
        self.layernorm_3 = nn.LayerNorm(config["hidden_size"])

    def forward(self, x, encoder_output_states, src_mask=None, trg_mask=None, output_attentions=False):
        # Self-attention
        attention_output, self_attention_probs = \
            self.self_attention(x, attention_mask=src_mask, output_attentions=output_attentions)
        # Skip connection
        x = self.layernorm_1(x + attention_output)
        # Cross-attention
        attention_output, cross_attention_probs = \
            self.cross_attention(x, attention_mask=trg_mask, output_attentions=output_attentions, encoder_output_states=encoder_output_states)
        # Skip connection
        x = self.layernorm_2(x + attention_output)
        # Feed-forward network
        mlp_output = self.mlp(x)
        # Skip connection
        x = self.layernorm_3(x + mlp_output)
        # Return the transformer block's output and the attention probabilities (optional)
        if not output_attentions:
            return (x, None)
        else:
            return (x, self_attention_probs, cross_attention_probs)


class FNSEncoder(nn.Module):
    """
    The transformer encoder module.
    """

    def __init__(self, config):
        super().__init__()
        self.padding_idx = config["src_pad_token_id"]
        # Embeddings
        self.token_embedding = nn.Embedding(
            num_embeddings=config["src_vocab_size"],
            embedding_dim=config["hidden_size"],
            padding_idx=config["src_pad_token_id"],
        )
        self.positional_embedding = nn.Embedding(
            num_embeddings=config["max_length"],
            embedding_dim=config["hidden_size"],
        )
        self.dropout = nn.Dropout(p=config["encoder_dropout_prob"])
        # Create a list of transformer blocks
        self.blocks = nn.ModuleList([])
        for _ in range(config["num_encoder_layers"]):
            block = FNSEncoderBlock(config)
            self.blocks.append(block)

    def forward(self, x, attention_mask=None, output_attentions=False):
        # Create the position ids from the input token ids. Any padded tokens remain padded.
        position_ids = torch.arange(0, x.shape[-1]).to(x.device)
        position_embeddings = self.positional_embedding(position_ids)
        token_embeddings = self.token_embedding(x)
        # Dropout 
        x = self.dropout(position_embeddings + token_embeddings)
        # Calculate the transformer block's output for each block
        all_attentions = []
        for block in self.blocks:
            x, attention_probs = block(x, attention_mask=attention_mask, output_attentions=output_attentions)
            if output_attentions:
                all_attentions.append(attention_probs)
        # Return the encoder's output and the attention probabilities (optional)
        if not output_attentions:
            return (x, None)
        else:
            return (x, all_attentions)    
    
class FNSDecoder(nn.Module):
    """
    The transformer decoder module.
    """

    def __init__(self, config, bias=True):
        super().__init__()
        self.padding_idx = config["trg_pad_token_id"]
        # Embeddings
        self.token_embedding = nn.Embedding(
            num_embeddings=config["trg_vocab_size"],
            embedding_dim=config["hidden_size"],
            padding_idx=config["trg_pad_token_id"],
        )
        self.positional_embedding = nn.Embedding(
            num_embeddings=config["max_length"],
            embedding_dim=config["hidden_size"],
        )
        self.dropout = nn.Dropout(p=config["decoder_dropout_prob"])
        # Create a list of transformer blocks
        self.blocks = nn.ModuleList([])
        for _ in range(config["num_decoder_layers"]):
            block = FNSDecoderBlock(config)
            self.blocks.append(block)
        # Tie output linear weights to input embedding matrix
        self.fc = nn.Linear(config["hidden_size"], config["trg_vocab_size"], bias=bias)
        self.fc.weight = self.token_embedding.weight 
        
    def forward(self, x, embedding_output_states, src_mask=None, trg_mask=None, output_attentions=False):
        # Create the position ids from the input token ids. Any padded tokens remain padded.
        position_ids = torch.arange(0, x.shape[-1]).to(x.device)
        position_embeddings = self.positional_embedding(position_ids)
        token_embeddings = self.token_embedding(x)
        # Dropout 
        x = self.dropout(position_embeddings + token_embeddings)
        # Calculate the transformer block's output for each block
        all_self_attentions = []
        all_cross_attentions = []
        for block in self.blocks:
            x, self_attention_probs, cross_attention_probs = block(x, embedding_output_states, src_mask=src_mask, trg_mask=trg_mask, output_attentions=output_attentions)
            if output_attentions:
                all_self_attentions.append(self_attention_probs)
                all_cross_attentions.append(cross_attention_probs)
        # Linear layer
        x = self.fc(x)
        # Softmax
        x = nn.Softmax(dim=-1)(x)
        # Return logits and the attention probabilities (optional)
        if not output_attentions:
            return (x, None)
        else:
            return (x, all_self_attentions, all_cross_attentions)

class FNSForTranslation(nn.Module):
    """
    The seq2seq model for neural machine translation.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Create the transformer encoder module
        self.encoder = FNSEncoder(config)
        # Create the transformer decoder module
        self.decoder = FNSDecoder(config)
        # Initialize the weights
        self.apply(self._init_weights)

    def forward(self, x, encoder_mask, decoder_mask, output_attentions=False):
        # Calculate the encoder's output
        encoder_output, encoder_self_attentions = self.encoder(x, attention_mask = encoder_mask, output_attentions=output_attentions)
        # Calculate the decoder's output
        decoder_output, decoder_self_attentions, decoder_cross_attentions = self.decoder(x, encoder_output, src_mask=encoder_mask, trg_mask=decoder_mask, output_attentions=output_attentions)
        # Return the logits and the attention probabilities (optional)
        if not output_attentions:
            return (decoder_output, None, None, None)
        else:
            return (decoder_output, encoder_self_attentions, decoder_self_attentions, decoder_cross_attentions)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config["initializer_range"])
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config["initializer_range"])
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

In [230]:
config = {
    "beta": 1,
    "bandwidth": 1,
    "sphere_radius": 1,
    "hidden_size": 1,
    "num_encoder_layers": 1,
    "num_decoder_layers": 1,
    "num_attention_heads": 1,
    "intermediate_size": 4, # 4 * hidden_size
    "hidden_dropout_prob": 0,
    "encoder_dropout_prob": 0,
    "decoder_dropout_prob": 0,
    "attention_probs_dropout_prob": 0,
    "initializer_range": 0.1,
    "qkv_bias": True,
    "use_faster_attention": True,
    "src_vocab_size": tokenizer_src.vocab_size,
    "src_pad_token_id": tokenizer_src.pad_token_id,
    "trg_vocab_size": tokenizer_trg.vocab_size,
    "trg_pad_token_id": tokenizer_trg.pad_token_id,
    "max_length": 128,
}
model = FNSForTranslation(config)    

In [115]:
model 

FNSForTranslation(
  (encoder): FNSEncoder(
    (token_embedding): Embedding(58101, 1, padding_idx=58100)
    (positional_embedding): Embedding(128, 1)
    (dropout): Dropout(p=0, inplace=False)
    (blocks): ModuleList(
      (0): FNSEncoderBlock(
        (attention): FasterFNSMultiHeadAttention(
          (qkv_projection): Linear(in_features=1, out_features=3, bias=True)
          (attn_dropout): Dropout(p=0, inplace=False)
          (output_projection): Linear(in_features=1, out_features=1, bias=True)
          (output_dropout): Dropout(p=0, inplace=False)
        )
        (layernorm_1): LayerNorm((1,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (dense_1): Linear(in_features=1, out_features=4, bias=True)
          (activation): NewGELUActivation()
          (dense_2): Linear(in_features=4, out_features=1, bias=True)
          (dropout): Dropout(p=0, inplace=False)
        )
        (layernorm_2): LayerNorm((1,), eps=1e-05, elementwise_affine=True)
      )
   

In [116]:
def greedy_decode(model, source, tokenizer_trg, max_len, device):
    bos_idx = tokenizer_trg.bos_token_id
    eos_idx = tokenizer_trg.eos_token_id

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.encode(source)
    # Initialize the decoder input with the sos token
    decoder_input = torch.empty(1, 1).fill_(bos_idx).type_as(source).to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break

        # build causal mask for target 
        decoder_mask = torch.triu(torch.ones((1, decoder_input.size(1), decoder_input.size(1))), diagonal=1).type(torch.int)
        decoder_mask = (decoder_mask == 0).type_as(source_mask).to(device)

        # calculate output
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # get next token
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
        )

        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)

In [200]:
tokenizer_trg.bos_token_id

50256

In [235]:
x 

tensor([[20011,   589,  6737, 11962,    16,   103,   663,    78,  1052, 34217,
            45,     2,   541, 26079,   518,   745,     0, 58100, 58100, 58100,
         58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100,
         58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100,
         58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100,
         58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100,
         58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100,
         58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100,
         58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100,
         58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100,
         58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100,
         58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100, 58100,
         58100, 58100, 58100, 58100, 58100, 58100, 5

In [231]:
model.encoder.token_embedding(x)

IndexError: index out of range in self

In [222]:
model.encoder(x, encoder_mask)

IndexError: index out of range in self

In [215]:
batch_size = 2

def causal_mask(size):
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0

def get_batch():
    data = testloader
    ix = torch.randint(len(data), (batch_size,))
    x = torch.tensor([testloader.dataset["input_ids"][i] for i in ix]) # (B, N) 
    y = torch.tensor([testloader.dataset["labels"][i] for i in ix]) # (B, N) 
    encoder_mask = torch.tensor([testloader.dataset["attention_mask"][i] for i in ix]) # (B, N) 
    encoder_mask = (encoder_mask.unsqueeze(-1)@encoder_mask.unsqueeze(1)).view(batch_size, 1, config["max_length"], config["max_length"]) # (B,1,N,N)
    decoder_mask = torch.stack([(torch.tensor(testloader.dataset["attention_mask"][i] != 0)).unsqueeze(0).int() & causal_mask(config["max_length"]) for i in ix]) # (B,1,N,N)
    return x, y, encoder_mask, decoder_mask

In [216]:
x, y, encoder_mask, decoder_mask = get_batch()

In [218]:
encoder_mask.shape

torch.Size([2, 1, 128, 128])