# Attempt at supervised transformer
Trying to implement a transformer as per the ["Attention Is All You Need" (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762) paper and the **week 5 notebook**.

## Setup

### All your imports are belong to us!

In [None]:
import torch

### Constant definition and other setup

In [3]:
# define the device to use
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {DEVICE}")

Device: cpu


## Loss implementation

In [None]:
EPOCHS = 5
num_steps = 5 # 5_000
step = 0
epoch = 0

for epoch in range(EPOCHS):  # EPOCHS is an int, so iterate over range(EPOCHS) directly
    for batch in train_loader:
        # concatenate the `token_ids`
        batch_token_ids = make_batch(batch)
        batch_token_ids = batch_token_ids.to(DEVICE)

        # forward through the model
        optimizer.zero_grad()
        batch_logits = rnn2(batch_token_ids)

        # compute the loss (negative log-likelihood)
        p_ws = torch.distributions.Categorical(logits=batch_logits) 

        # Exercise: write the loss of the RNN language model
        # hint: check the doc https://pytorch.org/docs/stable/distributions.html#categorical
        # NB: even with the right loss, training is slow and the generated samples won't be very good.
        #
        # NOTE:
        # We need the negative log-likelihood, which we get from the `log_prob` method of the
        # Categorical object. Summing over dim=1 reduces the log-probabilities to a one-dimensional
        # vector, and taking its mean reduces that vector to a scalar:
        loss = -torch.sum(p_ws.log_prob(batch_token_ids), dim=1).mean()

        # backward and optimize
        loss.backward()
        optimizer.step()
        step += 1
        pbar.update(1)

        # Report
        if step % 5 == 0:
            loss = loss.detach().cpu()
            pbar.set_description(f"epoch={epoch}, step={step}, loss={loss:.1f}")

        # save checkpoint
        if step % 50 == 0:
            # save the model that is actually being trained above (`rnn2`)
            torch.save(rnn2.state_dict(), checkpoint_file)
        if step >= num_steps:
            break

    # the `break` above only exits the inner loop, so stop the epoch loop as well
    if step >= num_steps:
        break
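
To sanity-check the reduction used in the loss line, here is a minimal sketch in plain Python (no PyTorch) that spells out the same computation by hand: take the log-probability of each observed token, sum over the time steps of a sequence, negate, and average over the batch. The helper names `log_softmax` and `sequence_nll` and the toy values are made up for illustration; the sketch also assumes that the logits at position `t` are already aligned with the token at position `t`, as the loss line implies.

```python
import math

def log_softmax(logits):
    """Return log-probabilities for one vector of unnormalised logits."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def sequence_nll(batch_logits, batch_token_ids):
    """Negative log-likelihood: sum over time steps, mean over sequences.

    batch_logits: nested list with shape [batch][time][vocab]
    batch_token_ids: nested list with shape [batch][time]
    """
    per_sequence = []
    for seq_logits, seq_ids in zip(batch_logits, batch_token_ids):
        # log-probability of the observed token at each time step
        log_probs = [log_softmax(step)[tok] for step, tok in zip(seq_logits, seq_ids)]
        per_sequence.append(-sum(log_probs))
    return sum(per_sequence) / len(per_sequence)

# Toy batch (hypothetical): 2 sequences, 3 time steps, vocabulary of 4 tokens.
# All-zero logits give a uniform distribution over the vocabulary.
toy_logits = [[[0.0] * 4 for _ in range(3)] for _ in range(2)]
toy_tokens = [[0, 1, 2], [3, 3, 0]]
toy_loss = sequence_nll(toy_logits, toy_tokens)
print(toy_loss)  # equals 3 * log(4) up to floating-point error
```

With uniform predictions every token costs `log(4)`, so each three-token sequence costs `3 * log(4)`. An untrained model should start near `sequence_length * log(vocab_size)`, which is a handy check that the loss is wired up correctly.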