# Building a Transformer from Scratch

Reference: ([link text](https://medium.com/towards-data-science/a-complete-guide-to-write-your-own-transformers-29e23f371ddd))

In [1]:
import torch
import torch.nn as nn
import math

## Multi-head attention

In [2]:
# Create the multi-head attention class for the transformer model

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim=256, num_heads=4):
        """
        input_dim: Dimensionality of the input.
        num_heads: The number of attention heads to split the input into.
        """

        super(MultiHeadAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        # check if the hidden dimensions are divisible by the number of heads
        assert hidden_dim % num_heads == 0, "Hidden Dimensions must be divisible by number of heads"
        self.Wv = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Value part
        self.Wk = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Key part
        self.Wq = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Query part
        self.Wo = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Output layer

    def check_sdpa_inputs(self, x):
        assert x.size(1) == self.num_heads, f"Expected size of x to be ({-1, self.num_heads, -1, self.hidden_dim // self.num_heads}), got {x.size()}"
        assert x.size(3) == self.hidden_dim // self.num_heads

    def scaled_dot_product_attention(self, query, key, value, attention_mask=None, key_padding_mask=None):
        """
        query: tensor of shape (batch_size, num_heads, query_sequence_length, hidden_dim//num_heads)
        key : tensor of shape (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
        value : tensor of shape (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
        attention_mask : tensor of shape (query_sequence_length, key_sequence_length)
        key_padding_mask : tensor of shape (sequence_length, key_sequence_length)

        """

        self.check_sdpa_inputs(query)
        self.check_sdpa_inputs(key)
        self.check_sdpa_inputs(value)

        d_k = query.size(-1)
        target_len, source_len = query.size(-2), key.size(-2)

        # logits = (B, H, tgt_len, E) * (B, H, E, src_len) = (B, H, tgt_len, src_len)
        logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        # Creating the attention mask here
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                assert attention_mask.size() == (target_len, source_len)
                attention_mask = attention_mask.unsqueeze(0)
                logits = logits + attention_mask

            else:
                raise ValueError(f"Attention mask size {attention_mask.size()}")

        # Creating Key mask here
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # Broadcast over batch size, number of heads
            logits = logits + key_padding_mask

        attention = torch.softmax(logits, dim=-1)
        output = torch.matmul(attention, value) # (batch_size, num_heads, sequence_length, hidden_dims)

        return output, attention

    def split_into_heads(self, x, num_heads):
        batch_size, seq_length, hidden_dim = x.size()
        x = x.view(batch_size, seq_length, num_heads, hidden_dim // num_heads)

        return x.transpose(1, 2) # Final dimension will be (batch_size, num_heads, seq_length, hidden_dim // num_heads)

    def combine_heads(self, x):
        batch_size, num_heads, seq_length, head_hidden_dim = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, num_heads * head_hidden_dim)

    def forward(self, query, key, value, attention_mask=None, key_padding_mask=None):
        """
        query : tensor of shape (batch_size, query_sequence_length, hidden_dim)
        key : tensor of shape (batch_size, key_sequence_length, hidden_dim)
        value : tensor of shape (batch_size, key_sequence_length, hidden_dim)
        attention_mask : tensor of shape (query_sequence_length, key_sequence_length)
        key_padding_mask : tensor of shape (sequence_length, key_sequence_length)

        """
        query = self.Wq(query)
        key = self.Wk(key)
        value = self.Wv(value)

        query = self.split_into_heads(query, self.num_heads)
        key = self.split_into_heads(key, self.num_heads)
        value = self.split_into_heads(value, self.num_heads)

        attention_values, attention_weights = self.scaled_dot_product_attention(
            query=query,
            key=key,
            value=value,
            attention_mask=attention_mask,
            key_padding_mask=key_padding_mask
        )

        grouped = self.combine_heads(attention_values)
        output = self.Wo(grouped)

        self.attention_weights = attention_weights

        return output



## Encoder Class

When receiving and treating an input, a transformer has no sense of order as it looks at the sequence as a whole, in opposition to what RNNs do. We therefore need to add a hint of temporal order so that the transformer can learn dependencies.



### Positional Encoding part of the Encoder

In [3]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """

        x = x + self.pe[:, :x.size(1), :]
        return x



### Feed Forward part of the Encoder

In [4]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super(PositionWiseFeedForward, self).__init__()

        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))



### Encoder block part of the Encoder

In [5]:
class EncoderBlock(nn.Module):
    def __init__(self, n_dim: int, dropout: float, n_heads: int):
        super(EncoderBlock, self).__init__()

        self.MHA = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
        self.norm1 = nn.LayerNorm(n_dim)
        self.ff = PositionWiseFeedForward(n_dim, n_dim)
        self.norm2 = nn.LayerNorm(n_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_padding_mask=None):
        assert x.ndim==3, f"Expected input to be 3-Dimensional; got {x.ndim}"
        attention_output = self.MHA(x, x, x, key_padding_mask=src_padding_mask)

        ff_output = self.ff(x)
        output = x + self.norm2(ff_output)

        return output


### Encoder Class

In [17]:
class Encoder(nn.Module):
    def __init__(
            self,
            vocab_size: int,
            n_dim: int,
            dropout: float,
            n_encoder_blocks: int,
            n_heads: int
        ):

        super(Encoder, self).__init__()
        self.n_dim = n_dim
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=n_dim)
        self.positional_encoding = PositionalEncoding(d_model=n_dim, dropout=dropout)
        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(n_dim, dropout, n_heads) for _ in range(n_encoder_blocks)
        ])

    def forward(self,x, padding_mask=None):
        x = self.embedding(x) * math.sqrt(self.n_dim)
        x = self.positional_encoding(x)
        for block in self.encoder_blocks:
            x = block(x=x, src_padding_mask=padding_mask)
        return x

## Decoder Class

In [7]:
class DecoderBlock(nn.Module):
    def __init__(self, n_dim: int, dropout: float, n_heads: int):
        super(DecoderBlock, self).__init__()

        # The first Multi-Head Attention has a mask to avoid looking at the future
        self.self_attention = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
        self.norm1 = nn.LayerNorm(n_dim)

        # The second Multi-Head Attention will take inputs from the encoder as key/value inputs
        self.cross_attention = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
        self.norm2 = nn.LayerNorm(n_dim)

        self.ff = PositionWiseFeedForward(n_dim, n_dim)
        self.norm3 = nn.LayerNorm(n_dim)

    def forward(self, target, memory, target_mask=None, target_padding_mask=None, memory_padding_mask=None):
        masked_attention_output = self.self_attention(
            query = target,
            key = target,
            value = target,
            attention_mask = target_mask,
            key_padding_mask = target_padding_mask
        )
        x1 = target + self.norm1(masked_attention_output)

        cross_attention_output = self.cross_attention(
            query = x1,
            key = memory,
            value = memory,
            attention_mask = None,
            key_padding_mask = memory_padding_mask
        )

        ff_output = self.ff(x2)
        output = x2 + self.norm3(ff_output)

        return output


In [8]:
class Decoder(nn.Module):
    def __init__(self, vocab_size: int, n_dim: int, dropout: float, max_seq_len: int, n_decoder_blocks: int, n_heads: int):
        super(Decoder, self).__init__()

        self.embedding = nn.Embedding(num_embeddings = vocab_size, embedding_dim = n_dim)
        self.positional_encoding = PositionalEncoding(
            d_model = n_dim,
            dropout = dropout
        )
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(n_dim, dropout, n_heads) for _ in range(n_decoder_blocks)
        ])

    def forward(self, target, memory, target_mask=None, target_padding_mask=None, memory_padding_mask=None):
        x = self.embedding(target)
        x = self.positional_encoding(x)

        for block in self.decoder_blocks:
            x = block(
                x,
                memory,
                target_mask = target_mask,
                target_padding_mask = target_padding_mask,
                memory_padding_mask= memory_padding_mask
            )
        return x

In [9]:
class SimpleTransformer(nn.Module):
    def __init__(self, **kwargs):
        super(SimpleTransformer, self).__init__()

        for key, value in kwargs.items():
            print(f" * {key}={value}")

        self.vocab_size = kwargs.get('vocab_size')
        self.model_dim = kwargs.get('model_dim')
        self.dropout = kwargs.get('dropout')
        self.n_encoder_layers = kwargs.get('n_encoder_layers')
        self.n_decoder_layers = kwargs.get('n_decoder_layers')
        self.n_heads = kwargs.get('n_heads')
        self.batch_size = kwargs.get('batch_size')
        self.PAD_IDX = kwargs.get('pad_idx', 0)

        self.encoder = Encoder(
            self.vocab_size,
            self.model_dim,
            self.dropout,
            self.n_encoder_layers,
            self.n_heads
        )

        self.decoder = Decoder(
            self.vocab_size,
            self.model_dim,
            self.dropout,
            self.n_decoder_layers,
            self.n_heads
        )

        self.fc = nn.Linear(self.model_dim, self.vocab_size)


    @staticmethod
    def generate_square_subsequent_mask(size: int):
        # Generate a Triangular (size, size) mask

        mask = (1 - torch.triu(torch.ones(size, size), diagonal = 1)).bool()
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

        return mask


    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input
            x: (B, S) with elements in (0, C) where C is num_classes
        Output
            (B, S, E) embedding
        """

        mask = (x == self.PAD_IDX).float()
        encoder_padding_mask = mask.masked_fill(mask == 1, float('-inf'))

        # (B, S, E)
        encoder_output = self.encoder(
            x,
            padding_mask = encoder_padding_mask
        )

        return encoder_output, encoder_padding_mask


    def decode(
            self,
            target: torch.Tensor,
            memory: torch.Tensor,
            memory_padding_mask=None
        ) -> torch.Tensor:
        """
        B = Batch size
        S = Source sequence length
        L = Target sequence length
        E = Model dimension

        Input
            encoded_x: (B, S, E)
            y: (B, L) with elements in (0, C) where C is num_classes
        Output
            (B, L, C) logits
        """

        mask = (target == self.PAD_IDX).float()
        target_padding_mask = mask.masked_fill(mask == 1, float('-inf'))

        decoder_output = self.decoder(
            target=target,
            memory=memory,
            tgt_mask=self.generate_square_subsequent_mask(target.size(1)),
            target_padding_mask=target_padding_mask,
            memory_padding_mask=memory_padding_mask,

        )
        output = self.fc(decoder_output)  # shape (B, L, C)
        return output


    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor
    ) -> torch.Tensor:
        """
        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (B, L, C) logits
        """

        # Encoder output shape (B, S, E)
        encoder_output, encoder_padding_mask = self.encode(x)

        # Decoder output shape (B, L, C)
        decoder_output = self.decode(
            target=y,
            memory=encoder_output,
            memory_padding_mask=encoder_padding_mask,
        )

        return decoder_output


    def predict(
        self,
        x: torch.Tensor,
        sos_idx: int=1,
        eos_idx: int=2,
        max_length: int=None
    ) -> torch.Tensor:
        """
        Method to use at inference time. Predict y from x one token at a time. This method is greedy
        decoding. Beam search can be used instead for a potential accuracy boost.

        Input
            x: str
        Output
            (B, L, C) logits
        """

        # Pad the tokens with beginning and end of sentence tokens
        x = torch.cat([
            torch.tensor([sos_idx]),
            x,
            torch.tensor([eos_idx])]
        ).unsqueeze(0)

        encoder_output, mask = self.transformer.encode(x) # (B, S, E)

        if not max_length:
            max_length = x.size(1)

        outputs = torch.ones((x.size()[0], max_length)).type_as(x).long() * sos_idx
        for step in range(1, max_length):
            y = outputs[:, :step]
            probs = self.transformer.decode(y, encoder_output)
            output = torch.argmax(probs, dim=-1)

            # Uncomment if you want to see step by step predicitons
            # print(f"Knowing {y} we output {output[:, -1]}")

            if output[:, -1].detach().numpy() in (eos_idx, sos_idx):
                break
            outputs[:, step] = output[:, -1]


        return outputs

## Defining a Toy dataset

In [12]:
import numpy as np
import torch
from torch.utils.data import Dataset


np.random.seed(0)

def generate_random_string():
    len = np.random.randint(10, 20)
    return "".join([chr(x) for x in np.random.randint(97, 97+26, len)])

class ReverseDataset(Dataset):
    def __init__(self, n_samples, pad_idx, sos_idx, eos_idx):
        super(ReverseDataset, self).__init__()
        self.pad_idx = pad_idx
        self.sos_idx = sos_idx
        self.eos_idx = eos_idx
        self.values = [generate_random_string() for _ in range(n_samples)]
        self.labels = [x[::-1] for x in self.values]

    def __len__(self):
        return len(self.values)  # number of samples in the dataset

    def __getitem__(self, index):
        return self.text_transform(self.values[index].rstrip("\n")), \
            self.text_transform(self.labels[index].rstrip("\n"))

    def text_transform(self, x):
        return torch.tensor([self.sos_idx] + [ord(z)-97+3 for z in x] + [self.eos_idx])

## Define Training and Evaluation Steps

In [13]:
import tqdm
PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2

def train(model, optimizer, loader, loss_fn, epoch):
    model.train()
    losses = 0
    acc = 0
    history_loss = []
    history_acc = []

    with tqdm(loader, position=0, leave=True) as tepoch:
        for x, y in tepoch:
            tepoch.set_description(f"Epoch {epoch}")

            optimizer.zero_grad()
            logits = model(x, y[:, :-1])
            loss = loss_fn(logits.contiguous().view(-1, model.vocab_size), y[:, 1:].contiguous().view(-1))
            loss.backward()
            optimizer.step()
            losses += loss.item()

            preds = logits.argmax(dim=-1)
            masked_pred = preds * (y[:, 1:]!=PAD_IDX)
            accuracy = (masked_pred == y[:, 1:]).float().mean()
            acc += accuracy.item()

            history_loss.append(loss.item())
            history_acc.append(accuracy.item())
            tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy.item())

    return losses / len(list(loader)), acc / len(list(loader)), history_loss, history_acc



In [14]:
def evaluate(model, loader, loss_fn):
    model.eval()
    losses = 0
    acc = 0
    history_loss = []
    history_acc = []

    for x, y in tqdm(loader, position=0, leave=True):

        logits = model(x, y[:, :-1])
        loss = loss_fn(logits.contiguous().view(-1, model.vocab_size), y[:, 1:].contiguous().view(-1))
        losses += loss.item()

        preds = logits.argmax(dim=-1)
        masked_pred = preds * (y[:, 1:]!=PAD_IDX)
        accuracy = (masked_pred == y[:, 1:]).float().mean()
        acc += accuracy.item()

        history_loss.append(loss.item())
        history_acc.append(accuracy.item())

    return losses / len(list(loader)), acc / len(list(loader)), history_loss, history_acc

## Train!

In [15]:
def collate_fn(batch):
    """
    This function pads inputs with PAD_IDX to have batches of equal length
    """
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(src_sample)
        tgt_batch.append(tgt_sample)

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first=True)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX, batch_first=True)
    return src_batch, tgt_batch

In [18]:
# Model hyperparameters
args = {
    'vocab_size': 128,
    'model_dim': 128,
    'dropout': 0.1,
    'n_encoder_layers': 1,
    'n_decoder_layers': 1,
    'n_heads': 4
}

# Define model here
model = SimpleTransformer(**args)

 * vocab_size=128
 * model_dim=128
 * dropout=0.1
 * n_encoder_layers=1
 * n_decoder_layers=1
 * n_heads=4


TypeError: Decoder.__init__() missing 1 required positional argument: 'n_heads'

In [None]:
# Instantiate datasets
train_iter = ReverseDataset(50000, pad_idx=PAD_IDX, sos_idx=SOS_IDX, eos_idx=EOS_IDX)
eval_iter = ReverseDataset(10000, pad_idx=PAD_IDX, sos_idx=SOS_IDX, eos_idx=EOS_IDX)
dataloader_train = DataLoader(train_iter, batch_size=256, collate_fn=collate_fn)
dataloader_val = DataLoader(eval_iter, batch_size=256, collate_fn=collate_fn)

# During debugging, we ensure sources and targets are indeed reversed
# s, t = next(iter(dataloader_train))
# print(s[:4, ...])
# print(t[:4, ...])
# print(s.size())

# Initialize model parameters
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

# Define loss function : we ignore logits which are padding tokens
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)

# Save history to dictionnary
history = {
    'train_loss': [],
    'eval_loss': [],
    'train_acc': [],
    'eval_acc': []
}

In [None]:
# Main loop
for epoch in range(1, 4):
    start_time = time.time()
    train_loss, train_acc, hist_loss, hist_acc = train(model, optimizer, dataloader_train, loss_fn, epoch)
    history['train_loss'] += hist_loss
    history['train_acc'] += hist_acc
    end_time = time.time()
    val_loss, val_acc, hist_loss, hist_acc = evaluate(model, dataloader_val, loss_fn)
    history['eval_loss'] += hist_loss
    history['eval_acc'] += hist_acc
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Train acc: {train_acc:.3f}, Val loss: {val_loss:.3f}, Val acc: {val_acc:.3f} "f"Epoch time = {(end_time - start_time):.3f}s"))

## Testing our model!

In [None]:
class Translator(nn.Module):
    def __init__(self, transformer):
        super(Translator, self).__init__()
        self.transformer = transformer

    @staticmethod
    def str_to_tokens(s):
        return [ord(z)-97+3 for z in s]

    @staticmethod
    def tokens_to_str(tokens):
        return "".join([chr(x+94) for x in tokens])

    def __call__(self, sentence, max_length=None, pad=False):

        x = torch.tensor(self.str_to_tokens(sentence))

        outputs = self.transformer.predict(sentence)

        return self.tokens_to_str(outputs[0])

In [None]:
translator = Translator(model)

sentence = "Hello World!"
print(sentence)

output = translator(sentence)
print(output: output)