# Attempt at supervised transformer
Trying to implement a transformer as described in the ["Attention Is All You Need" (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762) paper and the **week 5 notebook**.

## Setup

### All your imports are belong to us!

In [2]:
import torch
import math
from torch import nn
import copy

### Constant definition and other setup

In [3]:
### define the device to use
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {DEVICE}")

Device: cpu


## Model implementation

### subsequent mask
This is a mask for the first (self-)attention block of the **decoder**. That block must not take future outputs into account: each position may only use the outputs generated so far.

In [None]:
def subsequent_mask(size):
    """Build a causal mask of shape (1, size, size).

    Entry [0, i, j] is True when j <= i (position i may look at position j)
    and False when j > i (a future position that must be hidden).
    """
    # 1s strictly above the diagonal mark the "future" positions.
    future = torch.triu(torch.ones((1, size, size)), diagonal=1)
    future = future.to(torch.uint8)
    # Invert: allowed positions are the ones that are NOT in the future.
    return future == 0

### Decoder layer

In [None]:
class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, source attention, then feed-forward.

    Each of the three sub-blocks is run through its own SublayerConnection
    (defined elsewhere in the notebook).

    Args:
        size: model width, kept as ``self.size``.
        self_attn: attention module applied to the decoder input itself.
        src_attn: attention module whose keys/values come from ``memory``
            (presumably the encoder output — confirm against the caller).
        feed_forward: position-wise feed-forward module.
        dropout: dropout rate handed to each SublayerConnection.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # One wrapper per sub-block, used in order by forward().
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Apply the three sub-blocks in sequence (Figure 1, right side)."""
        first, second, third = self.sublayer[0], self.sublayer[1], self.sublayer[2]

        def masked_self_attention(h):
            # tgt_mask hides positions the decoder must not attend to.
            return self.self_attn(h, h, h, tgt_mask)

        def source_attention(h):
            # Queries come from the decoder; keys and values from memory.
            return self.src_attn(h, memory, memory, src_mask)

        x = first(x, masked_self_attention)
        x = second(x, source_attention)
        return third(x, self.feed_forward)

### Single-head (scaled dot-product) attention *implementation*

In [None]:
def attention(query, key, value, mask=None, dropout=None):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query, key, value: tensors whose last two dims are (sequence, features);
            d_k is read from the last dim of ``query``. The sequence length of
            ``key`` must match that of ``value`` for the final matmul.
        mask: optional tensor broadcastable to the score matrix; score entries
            where ``mask == 0`` are replaced by -1e9 before the softmax.
        dropout: optional callable (e.g. nn.Dropout) applied to the weights.

    Returns:
        A tuple ``(weighted_values, attention_weights)``.
    """
    scale = math.sqrt(query.size(-1))
    logits = query @ key.transpose(-2, -1) / scale
    if mask is not None:
        # A large finite negative (rather than -inf) keeps fully masked rows
        # from turning into NaN after the softmax.
        logits = logits.masked_fill(mask == 0, -1e9)
    weights = torch.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights @ value, weights

### Multi Headed Attention

In [None]:
class MultiHeadedAttention(nn.Module):
    """Multi-head attention (Figure 2 of the paper).

    ``d_model`` is split evenly into ``h`` heads of ``d_k = d_model // h``
    features; value heads are assumed to use the same width as key heads.
    """

    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        # Four d_model -> d_model projections: query, key, value, and (last)
        # the output projection applied after the heads are merged.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        # Attention weights from the most recent forward call.
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value`` over all heads at once.

        Assumes inputs are (batch, seq, d_model) — TODO confirm; the batch
        size is taken from ``query.size(0)``. When given, ``mask`` gets a new
        axis at dim 1 so the same mask broadcasts over every head.
        """
        batch = query.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1)

        # Project, then reshape to (batch, h, seq, d_k). zip stops after the
        # first three projections; the fourth is reserved for the output.
        query, key, value = (
            proj(t).view(batch, -1, self.h, self.d_k).transpose(1, 2)
            for proj, t in zip(self.linears, (query, key, value))
        )

        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # Merge the heads back into (batch, seq, h * d_k).
        merged = x.transpose(1, 2).contiguous()
        merged = merged.view(batch, -1, self.h * self.d_k)
        # Drop local references to the per-head tensors before the final
        # projection (memory can only be reclaimed if nothing else holds them).
        del query, key, value
        return self.linears[-1](merged)

### MLP
Both the encoder and the decoder contain an FFN (position-wise feed-forward) block, which consists of two linear layers
with a ReLU activation and a dropout layer between them.

In [None]:
class PositionwiseFeedForward(nn.Module):
    """Feed-forward block computing ``w_2(dropout(relu(w_1(x))))``.

    The linear layers act on the last dimension of ``x``, so the same weights
    are applied at every position.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)  # expand d_model -> d_ff
        self.w_2 = nn.Linear(d_ff, d_model)  # project d_ff -> d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = torch.relu(self.w_1(x))
        return self.w_2(self.dropout(hidden))

### Embeddings

In [None]:
class Embeddings(nn.Module):
    """Token-embedding lookup whose output is multiplied by sqrt(d_model)."""

    def __init__(self, d_model, vocab):
        super().__init__()
        self.d_model = d_model
        # Lookup table with one d_model-wide row per vocabulary entry.
        self.lut = nn.Embedding(vocab, d_model)

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale

### Positional encoding **(Maybe this shouldn't be in Model implementation)**

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings, then apply dropout.

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    The table is precomputed once for ``max_len`` positions and registered as
    a buffer of shape (1, max_len, d_model). Both even and odd ``d_model``
    are supported.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        # 10000 is the wavelength base from the formula above; using its log
        # lets the frequencies be computed as exp(-2i * log(10000) / d_model).
        LOG_CONSTANT = math.log(10000)

        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * -(LOG_CONSTANT / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd columns, one fewer than frequencies when
        # d_model is odd; slice so the shapes always match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add encodings for the first ``x.size(1)`` positions to ``x``.

        Assumes x is (batch, seq_len, d_model) with seq_len <= max_len —
        TODO confirm at call sites.
        """
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

### Make the model

In [None]:
def make_model(
    src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1
):
    """Helper: construct an encoder-decoder transformer from hyperparameters.

    Args:
        src_vocab: size of the source vocabulary.
        tgt_vocab: size of the target vocabulary.
        N: number of stacked layers passed to Encoder and to Decoder.
        d_model: model (embedding) width.
        d_ff: hidden width of the position-wise feed-forward blocks.
        h: number of attention heads; must divide ``d_model``.
        dropout: dropout rate used by attention, feed-forward and positional
            encoding, and passed on to the encoder/decoder layers.

    Returns:
        An ``EncoderDecoder`` whose parameters with more than one dimension
        have been Xavier-uniform initialised.
    """
    c = copy.deepcopy
    # Pass `dropout` through so attention honours the requested rate instead
    # of silently falling back to MultiHeadedAttention's default of 0.1.
    attn = MultiHeadedAttention(h, d_model, dropout)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    # NOTE:
    # The forward function of the EncoderDecoder class is what is
    # called when the model is called.
    # Every layer gets deep copies, so no weights are shared between them.
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab),
    )

    # Initialize weight matrices with Glorot / fan_avg; 1-D parameters
    # (biases, norm scales) keep their default initialisation.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model

## Loss implementation and training loop

In [None]:
EPOCHS = 5
num_steps = 5 # 5_000
step = 0

# NOTE(review): `train_loader`, `make_batch`, `optimizer`, `rnn2`, `rnn`,
# `pbar` and `checkpoint_file` are not defined in this notebook; they appear to
# be carried over from the week 5 RNN language-model notebook and must exist
# before this cell runs.
for epoch in range(EPOCHS):  # EPOCHS is an int, so no len()
    for batch in train_loader:
        # concatenate the `token_ids`
        batch_token_ids = make_batch(batch)
        batch_token_ids = batch_token_ids.to(DEVICE)

        # forward through the model
        optimizer.zero_grad()
        batch_logits = rnn2(batch_token_ids)

        # compute the loss (negative log-likelihood)
        p_ws = torch.distributions.Categorical(logits=batch_logits)

        # Categorical.log_prob gives the log-probability of each token; summing
        # over dim 1 collapses that dimension, and the mean over what remains
        # yields a scalar loss.
        # NOTE(review): log_prob is evaluated on the same tokens that were fed
        # in, which is only a valid language-model loss if the model shifts its
        # inputs internally (as the week 5 RNN presumably does). A transformer
        # decoder will need explicit input/target shifting.
        loss = -torch.sum(p_ws.log_prob(batch_token_ids), dim=1).mean()

        # backward and optimize
        loss.backward()
        optimizer.step()
        step += 1
        pbar.update(1)

        # Report
        if step % 5 == 0:
            pbar.set_description(
                f"epoch={epoch}, step={step}, loss={loss.item():.1f}"
            )

        # save checkpoint
        if step % 50 == 0:
            # NOTE(review): this saves `rnn` while the forward pass above uses
            # `rnn2`. Confirm they share parameters; otherwise save
            # `rnn2.state_dict()` instead.
            torch.save(rnn.state_dict(), checkpoint_file)
        if step >= num_steps:
            break
    # The inner `break` only leaves the batch loop; stop the epoch loop too so
    # training does not run extra steps past `num_steps`.
    if step >= num_steps:
        break