In [55]:
import sys

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

sys.path.append('../')  # make sure we can import transformer_lm

# Training a transformer language model

In this notebook, we will learn how to

1. preprocess data for language modeling
2. use `torch.utils.data` to handle batching in an efficient and standard way
3. train a transformer language model

Specifically, we will use the Tiny Shakespeare dataset, a roughly 1 MB plain-text collection of lines from Shakespeare's plays, to train a language model. The goal of this notebook is to walk you through pre-processing the dataset, preparing it for training with the PyTorch DataLoader, creating a language model, training it, and using it to generate text.

We will train a character-based language model rather than a word-based one, because:

1. It's faster to train it to the point that it can generate text
2. We don't want to complicate the homework with BPE tokenization
3. We work with a small dataset which might not be enough to train a word-based language model

> Feel free to try training a word-based language model on a larger dataset, such as the WikiText-2 dataset, which is available in the Hugging Face `datasets` library.

# Step 1: Load and Explore the Dataset
The first step is to load the dataset and explore it. Tiny Shakespeare is distributed as a single plain-text file, which we can download from the following URL: https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

Feel free to use `wget` to download the dataset or just download the file manually and upload it to your Colab instance.

Here's how you can use `wget` to download the dataset:
```
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O tiny_shakespeare.txt
```

## Coding task 3.1: load the data and take a look

Read the file to a variable named `raw_data` and print the first 1000 characters.

### Grading criteria
**(1 point max)**

1 point if everything works

In [56]:
with open("tiny_shakespeare.txt", "r") as f:
    raw_data = f.read()

print(raw_data[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



## Inline question 3.1: raw text preprocessing
**(1 point max, 1 extra point for creative ideas)**

Think about how you can pre-process the data (in terms of modifying the text). Provide three ideas and explain why you think they are useful or not. Think about the size of the data, the tokenization method (we will use a character-level language model), your computational resources, and what kind of text you want to generate. Make this answer as extensive as possible.

***Your answer:*** 
1. Convert all characters into lowercase. This can reduce the size of the vocabulary by 26 without too much loss of information. Capitalization has two primary use cases here: start of the sentence capitalization and proper nouns. The first function can be replaced by other punctuation or special characters, and the second function doesn't seem to be necessary since we can recognize the names and places without capitalization.
2. Remove newline characters and replace them with a special token when necessary. We don't need a newline character after every ":" and dialogue line. So instead of  
   "\<person1\>: <span style="color:red">\n</span>  
   \<some words\><span style="color:red">\n</span>  
   \<some other words\><span style="color:red">\n</span>  
   <span style="color:red">\n</span>  
   \<person2\>:<span style="color:red">\n</span>  
   ...,"  
we can do 
"\<person1\>: \<some words\>\<some other words\><span style="color:red">\<special token\></span> \<person2\>: \<some words\>\<some other words\>..."
   This not only reduces the size of the corpus but also preserves the separation between dialogue lines.
3. The text contains special symbols, such as "$" and "&", but the uses don't make much sense. For example, I found the following in the text:  
   a. "Now stops thy spring; my sea sha$l suck them dry,"  
   b. "England and France, and lord of Ireland, &c."  
The single occurrence of "$" seems to be a typo for the letter 'l', so we can just replace "$" with "l." The symbol '&' appears only three times, and every occurrence is followed by a letter "c." "&c" stands for "et cetera," so we can replace "&c" with "etc." By performing those two substitutions, we further reduce the vocabulary size by 2.
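
The three ideas above can be sketched in a few lines. This is only an illustration: the `preprocess` helper and the `|` separator are hypothetical, and the sketch assumes that `|` does not already occur in the corpus.

```python
import re

def preprocess(text: str, speaker_sep: str = "|") -> str:
    """Apply the three ideas; `speaker_sep` is a hypothetical special token."""
    # Idea 3: repair the odd symbols ("$" as a typo for "l", "&c" for "etc").
    text = text.replace("$", "l").replace("&c", "etc")
    # Idea 1: lowercase everything to shrink the vocabulary.
    text = text.lower()
    # Idea 2: blank lines separate speakers, so mark them with the special token...
    text = re.sub(r"\n\s*\n", f" {speaker_sep} ", text.strip())
    # ...and flatten the remaining single newlines into spaces.
    text = re.sub(r"\s*\n\s*", " ", text)
    return text

sample = "First Citizen:\nLet us kill him, &c.\n\nAll:\nSpeak, sha$l we?\n"
print(preprocess(sample))
```

Note that lowercasing and newline removal also change what the model can generate: it will never produce capital letters or verse line breaks on its own.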

# Step 2: preparing the data for the model

## Coding task 3.2
Similar to previous homeworks, where we made a vocabulary of words, we will make a vocabulary of characters.

1. Make a vocabulary of all characters
2. Make `char2idx`
3. Make a class `Tokenizer` that stores `char2idx` and has two methods: `encode` and `decode` that encode and decode text using `char2idx` and `idx2char` dictionaries.
   * You might find it useful to create `idx2char` dictionary inside the `__init__` method of the `Tokenizer` class.
4. Create a `Tokenizer` object
5. Convert the text to a list of integers using `char2idx`, assign it to a variable named `data`
6. Print the first 100 items of `data`

It's useful to have a function that converts a sequence of indices to a string. You will need it to convert the output of the model to text when you generate text, but it is also very useful for **debugging** your pre-processing code.

### Grading criteria
**(2 points max)**

1. 1 point for `char2idx` dictionary
2. 1 point for `Tokenizer` class that passes the tests below

In [57]:
# YOUR CODE STARTS HERE (our implementation is about 4 lines using comprehensions, but it's alright if yours is longer)
char2idx = {ch: i for i, ch in enumerate(set(raw_data))}
class Tokenizer:
    def __init__(self, char2idx):
        self.char_to_idx = char2idx
        self.idx_to_char = {i: ch for ch, i in char2idx.items()}

    def encode(self, string):
        return [self.char_to_idx[ch] for ch in string]

    def decode(self, tokens):
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.flatten().tolist()
        return ''.join([self.idx_to_char[i] for i in tokens])
# YOUR CODE ENDS HERE

In [58]:
_tokenizer = Tokenizer(char2idx)

_token_ids = _tokenizer.encode("hello")
_text = _tokenizer.decode(_token_ids)

assert isinstance(_token_ids, list), "token_ids should be a list"
assert isinstance(_token_ids[0], int), "token_ids should be a list of integers"
assert _text == "hello", "decode should work correctly and return the original text"

del _tokenizer, _token_ids, _text
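
One detail worth keeping in mind: the iteration order of a Python `set` of strings can differ between interpreter sessions, so ids built with `enumerate(set(...))` may not match a model saved in an earlier session. A minimal sketch of a deterministic alternative, using a toy string in place of `raw_data` (the `*_sorted` names are illustrative only):

```python
text = "hello world"  # stand-in for raw_data

# sorted() fixes the order, so the same text always yields the same ids.
vocab = sorted(set(text))
char2idx_sorted = {ch: i for i, ch in enumerate(vocab)}
idx2char_sorted = {i: ch for ch, i in char2idx_sorted.items()}

ids = [char2idx_sorted[ch] for ch in text]
roundtrip = "".join(idx2char_sorted[i] for i in ids)

print(vocab)
print(ids)
print(roundtrip == text)
```

With a sorted vocabulary, a tokenizer rebuilt from the same text in a new session maps every character to the same index.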

# Chunk the data

Our data is too long to be processed in one go. We will split it into chunks of length 128. We will use the first 128 characters to predict the next character. This is a decent length for a sequence, but you can play with it if you want.

## Coding task 3.3

1. Create a list of sequences of length `MAX_LEN + 1`. Each sequence should be a list of integers. You'll see why we need `+ 1` in a minute.
   * You might need to get rid of your last example if it's shorter than `MAX_LEN + 1` characters. We need all data to be of the same length to simplify batching.
   * In the next homework we will implement batching for sequences of different lengths, and you are probably not going to enjoy it; it's a bit tricky.
2. Split the data into training and validation sets. Use 90% of the data for training and 10% for validation.
3. Make x and y pairs for your data. Remember that we want to use the first 128 characters to predict the next character. So, `x` should be the first 128 characters and `y` should be a shifted version of the same sequence, so it's the last 128 characters. Name them `train_x` and `train_y` for the training set and `val_x` and `val_y` for the validation set.
4. Print an example from the training set. You should see that `x` matches the first 128 characters of its chunk of the original text, and that `y` matches the last 128 characters of the same chunk, i.e. `x` shifted by one character.

You can simply stride over the data, taking `data[i:i + MAX_LEN + 1]` for each `i` in `range(0, len(data), MAX_LEN)`; there is no need to do anything fancy. Feel free to explore more complex approaches, but only after the rest of the homework is done. You receive no extra points if your homework is not finished.
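
To make the shift between `x` and `y` concrete, here is a small sketch on a toy sequence. The `make_chunks` helper and its `stride` parameter are hypothetical names, not part of the assignment.

```python
def make_chunks(ids, max_len, stride=None):
    """Split `ids` into windows of max_len + 1 items and return (x, y) pairs."""
    stride = stride or max_len
    pairs = []
    # Stop early enough that every window has exactly max_len + 1 items.
    for start in range(0, len(ids) - max_len, stride):
        window = ids[start:start + max_len + 1]
        pairs.append((window[:-1], window[1:]))  # y is x shifted by one
    return pairs

toy = list(range(10))
pairs = make_chunks(toy, max_len=4)
for x, y in pairs:
    print(x, "->", y)
```

Passing a `stride` smaller than `max_len` produces overlapping windows, which yields more training examples from the same text at the cost of redundancy.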

### Grading criteria

1. 1 point for `data_chunks` list and train-test split
2. 1 point for dataset and dataloader objects
3. Extra point for a more interesting way to chunk the text
4. Extra point for implementing a custom dataset class

In [59]:
MAX_LEN = 128

# YOUR CODE STARTS HERE (our implementation is about 13 lines, but it's alright if yours is different)
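# Windows of MAX_LEN + 1 characters taken with stride MAX_LEN
data_chunks = [raw_data[i:i+MAX_LEN + 1] for i in range(0, len(raw_data), MAX_LEN)]
# x is the first MAX_LEN characters, y is the same window shifted by one; a short trailing window is dropped
data_chunks = [(chunk[0: MAX_LEN], chunk[1:]) for chunk in data_chunks if len(chunk) == MAX_LEN + 1]

# NOTE: the task asks for a 90/10 split; this solution uses 80/20
split_idx = int(len(data_chunks) * 0.8)
train_data = data_chunks[:split_idx]
val_data = data_chunks[split_idx:]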
train_x, train_y = zip(*train_data)
val_x, val_y = zip(*val_data)

print(f"First Train Example:\nx:\"{train_x[0]}\"\n->\ny:\"{train_y[0]}\"")
# YOUR CODE ENDS HERE

First Train Example:
x:"First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to "
->
y:"irst Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to d"


# Using `torch.utils.data`

We will use `torch.utils.data.Dataset` to create a dataset object that will be used to create a `torch.utils.data.DataLoader` object. The `DataLoader` object will be used to create batches of data.

## Coding task 3.4

Your task is to learn how to use `torch.utils.data.Dataset` and `torch.utils.data.DataLoader` classes and to apply them to our data.

1. Convert your data to tensors of type long
2. Create a `torch.utils.data.Dataset` object for each of the train and validation sets. Name them `train_dataset` and `val_dataset`. You can use the `TensorDataset` class for this or make a new class that inherits from `torch.utils.data.Dataset` and implements the `__getitem__` and `__len__` methods.
3. Try indexing `train_dataset` to get a single example and decode it using `tokenizer.decode()`. What does it contain? Use the tokenizer to decode one example (both x and y). Does it look like valid text? Are the targets shifted by one character?
4. Use the `DataLoader` class to create `train_loader` and `val_loader` objects. It will shuffle and batch data for you. You can use the following parameters:
   * `dataset` - the dataset object you created in the previous step
   * `batch_size` - your choice!
   * `shuffle` - True for training data, False for validation data
   * `num_workers` - 8, the number of worker processes used to prepare batches in parallel
5. Try iterating over `train_loader` and print the shapes of the batches.
    * You can use `break` to stop the loop after the first iteration.
6. Try decoding a batch that you get from `train_loader`. Does it look like valid text? Are the targets shifted by one character?

Learn more about data feeding in pytorch here: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html


**NOTE:**
1. `TensorDataset` returns a tuple of tensors. Usually these are `(x, y)` pairs, where `x` is the input and `y` is the target. In our case, `x` is the input sequence and `y` is the same sequence shifted by one character. This is how we will train our language model. We will use the first 128 characters to predict the next character.
2. You need to convert your pytorch tensor into a python list in order to use `tokenizer.decode()`. Feel free to do it in-place or modify the `decode` method of the `Tokenizer` class to accept **BOTH** python lists and pytorch tensors. You can check what datatype you have using `isinstance()` function.
3. Printing might look a bit weird because you have a lot of `\n` in the data. It is alright, just be careful when you are verifying that your data is correct.
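
As an alternative to `TensorDataset`, a custom dataset can slice the `(x, y)` windows on the fly from one long tensor of ids. The sketch below is a minimal illustration on toy data; the class name `CharChunkDataset` is hypothetical.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class CharChunkDataset(Dataset):
    """Hypothetical dataset: serves (x, y) windows from one long id sequence."""

    def __init__(self, ids, max_len):
        self.ids = torch.as_tensor(ids, dtype=torch.long)
        self.max_len = max_len

    def __len__(self):
        # Non-overlapping windows, each needing max_len + 1 ids.
        return max(0, (len(self.ids) - 1) // self.max_len)

    def __getitem__(self, index):
        if not 0 <= index < len(self):
            raise IndexError(index)
        start = index * self.max_len
        window = self.ids[start:start + self.max_len + 1]
        return window[:-1], window[1:]  # targets are inputs shifted by one

toy_dataset = CharChunkDataset(list(range(10)), max_len=4)
toy_loader = DataLoader(toy_dataset, batch_size=2, shuffle=False)
x, y = next(iter(toy_loader))
print(x.shape, y.shape)
print(x.tolist(), y.tolist())
```

Compared with pre-building every pair, this keeps only one copy of the ids in memory, which matters more for larger corpora than for Tiny Shakespeare.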

### Grading criteria

* 1 point for `train_dataset` and `val_dataset` objects
* 1 point if each test is written and passed:
  * train dataset element is correctly processed and x and y correspond to the correct characters
  * printed the shapes of the items that you get from `train_loader`
  * decoded a batch from `train_loader` and printed the decoded text and it is correct

In [60]:
BATCH_SIZE = 3  # think about a better batch size for training, this is just a placeholder

# YOUR CODE STARTS HERE (our implementation is about 13 lines)
BATCH_SIZE = 160
def encode_dataset(dataset, tokenizer):
    return [tokenizer.encode(s) for s in dataset]

tokenizer = Tokenizer(char2idx)
train_dataset = torch.utils.data.TensorDataset(
    torch.tensor(encode_dataset(train_x, tokenizer), dtype=torch.long),
    torch.tensor(encode_dataset(train_y, tokenizer), dtype=torch.long))
val_dataset = torch.utils.data.TensorDataset(
    torch.tensor(encode_dataset(val_x, tokenizer), dtype=torch.long),
    torch.tensor(encode_dataset(val_y, tokenizer), dtype=torch.long))

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)  # no shuffling for validation

for x, y in train_loader:
    print(x.shape, y.shape)
    print(tokenizer.decode(x[0]), y[0])
    break


# YOUR CODE ENDS HERE

torch.Size([160, 128]) torch.Size([160, 128])
GLOUCESTER:
Else wherefore breathe I in a Christian land?

BUCKINGHAM:
Then know, it is your fault that you resign
The supreme s tensor([54, 18, 22,  1, 57, 55, 14, 57, 21, 39, 56, 57, 51, 44, 31, 61, 15,  8,
        31,  2, 31, 53, 34,  2, 31, 61,  7,  2, 31,  6,  3,  8, 31, 61, 35, 61,
        38, 58, 61,  6, 61,  1,  8,  2, 38, 44,  3, 38,  6, 58, 61, 51,  6, 58,
        60, 47, 56, 56, 40, 22,  1, 59, 35, 10, 48, 32, 52,  5, 39, 56, 14,  8,
        31, 58, 61, 63, 58, 34, 15, 41, 61, 38,  3, 61, 38, 44, 61,  0, 34, 46,
         2, 61, 53,  6, 46, 51,  3, 61,  3,  8,  6,  3, 61,  0, 34, 46, 61,  2,
        31, 44, 38, 20, 58, 56, 14,  8, 31, 61, 44, 46, 43,  2, 31, 28, 31, 61,
        44, 31])


In [61]:
print(len(train_loader), len(val_loader))

44 11


# Train a Transformer model

Import your `TransformerLM` model from the `modeling_transformer` file and train it on the data you prepared above.
You know the drill: define a model, an optimizer, and a training loop, log everything to wandb.
You can also save your model using `TransformerLM.save_pretrained()` method and load it using `TransformerLM.from_pretrained()` method in case you want to.

### Tricky part

In PyTorch, the basic form of `F.cross_entropy` takes logits of shape `(batch_size, num_classes)` and targets of shape `(batch_size,)` containing the class indices. In our case, the logits tensor has the shape `(batch_size, seq_len, num_classes)` and the targets are of shape `(batch_size, seq_len)`. We need to reshape the logits and the targets to make them compatible with this form of `F.cross_entropy`. You can do it like this:

```python
bs, seq_len, num_classes = logits.shape
logits = logits.reshape(bs * seq_len, num_classes)
targets = targets.reshape(bs * seq_len)
```

or, equivalently, like this:

```python
logits = logits.view(-1, logits.size(-1))
targets = targets.view(-1)
```
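
As a quick sanity check, the sketch below uses random toy tensors (all sizes are arbitrary) to compare the flattened loss with the other layout that `F.cross_entropy` accepts, where the class dimension comes second.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, seq_len, num_classes = 2, 5, 7  # toy sizes, chosen arbitrarily
logits = torch.randn(bs, seq_len, num_classes)
targets = torch.randint(0, num_classes, (bs, seq_len))

# Option 1: flatten batch and sequence dimensions together.
flat_loss = F.cross_entropy(logits.reshape(-1, num_classes), targets.reshape(-1))

# Option 2: move classes to dim 1, giving (batch, classes, seq) with (batch, seq) targets.
perm_loss = F.cross_entropy(logits.transpose(1, 2), targets)

print(torch.allclose(flat_loss, perm_loss))
```

Both variants average over every position in the batch. Note that `view` only works when the tensor's memory layout allows it, while `reshape` falls back to a copy when needed.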

Try monitoring your GPU consumption and max it out. The more efficient your code is, the faster your model will train.
During training, log your loss and accuracy. You may want to log accuracy only every 100 batches or so, because it is a bit slow to compute. You can also log the learning rate.
During evaluation you just need to log the perplexity, the loss, and accuracy. Perplexity is just `exp(loss)`.
Accuracy is not the most standard metric for language models, but it is very interpretable and easy to compute. Don't expect it to be high, though.
Be mindful of how frequently you evaluate your model. You don't want to evaluate it too often, because it will slow down your training loop.

> You can also log the number of batches you process in one second (throughput) as a measure of efficiency. It is not required, but it is a good idea to monitor it.
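
To illustrate the relation between loss and perplexity, here is a small sketch in plain Python. The helper name and the numbers are made up for illustration. It weights each batch by its token count, which avoids giving a smaller final batch the same weight as a full one.

```python
import math

def perplexity_from_batches(batch_losses, batch_token_counts):
    """Return (mean loss per token, perplexity) from per-batch mean losses."""
    total_nll = sum(loss * n for loss, n in zip(batch_losses, batch_token_counts))
    total_tokens = sum(batch_token_counts)
    mean_loss = total_nll / total_tokens
    return mean_loss, math.exp(mean_loss)

# Illustrative numbers: one full batch and one smaller final batch.
mean_loss, ppl = perplexity_from_batches([2.0, 1.0], [300, 100])
print(f"loss={mean_loss:.4f} perplexity={ppl:.4f}")

# A model that guesses uniformly over V characters has loss ln(V) and perplexity V.
V = 65  # assumed vocabulary size, for illustration only
print(math.log(V), math.exp(math.log(V)))
```

The uniform baseline is a useful reference point: an untrained model should start with a perplexity close to the vocabulary size.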

## Coding task 3.5

Make a training loop and train your model.

### Grading criteria
**(5 points + extra points)**

* 2 points for training loop
* 1 point for using the GPU
* 1 point for evaluation loop (we recommend to make it into a separate function to make your code more readable)
* 1 point for wandb logging of train loss, eval loss, train accuracy, eval accuracy, eval perplexity. You can also log the learning rate, but it is not required.
* -1 point if you forget to zero your gradients between batches
* -1 point if you forget to put your model into evaluation mode during evaluation and back into training mode during training
* Extra point for using a learning rate scheduler
* Extra point for any other improvements to the training loop
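
One possible way to attach a warmup schedule to an optimizer is `torch.optim.lr_scheduler.LambdaLR`. The sketch below applies the inverse-square-root warmup formula from "Attention Is All You Need" to a toy linear layer. The `noam_factor` helper and the default sizes are assumptions for illustration.

```python
import torch

def noam_factor(step, hidden_size=64, warmup_steps=1500):
    """Multiplier for the base learning rate at a given step (illustrative defaults)."""
    step = max(step, 1)  # avoid 0 ** -0.5 on the very first call
    return hidden_size ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

toy_model = torch.nn.Linear(4, 4)
# The base lr of 1.0 acts as a scale factor for the schedule.
optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1.0)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_factor)

lrs = []
for _ in range(5):
    optimizer.step()   # no gradients here, so parameters are left unchanged
    scheduler.step()   # advance the schedule after each optimizer step
    lrs.append(scheduler.get_last_lr()[0])

print(lrs)
```

During warmup the learning rate grows linearly with the step, and after `warmup_steps` it decays proportionally to the inverse square root of the step.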


In [62]:
# YOUR CODE STARTS HERE
from transformer_lm.modeling_transformer import TransformerLM
import wandb

def evaluate(model, val_loader, device=torch.device("cuda")):
    """Evaluate the model on the validation set. Return the loss, accuracy, and perplexity.

    Args:
    model: The model to evaluate.
    val_loader: DataLoader for the validation set.
    
    Returns:
    loss: The average loss on the validation set.
    accuracy: The accuracy on the validation set.
    perplexity: The perplexity on the validation set.
    """
    with torch.no_grad():
        model.eval()
        total_loss = 0
        correct_preds = 0
        total_preds = 0
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            logits = logits.view(-1, len(char2idx))
            targets = y.view(-1)
            loss = F.cross_entropy(logits, targets)
            total_loss += loss.item()
            correct_preds += (logits.argmax(1) == targets).sum().item()
            total_preds += targets.shape[0]
            
        loss = total_loss / len(val_loader)
        accuracy = correct_preds / total_preds
        perplexity = np.exp(loss)
        model.train()
        return loss, accuracy, perplexity

def noam_schedule(scale_factor, hidden_size, step, warmup_steps=1500):
    """Learning rate schedule from the "Attention is All You Need" paper.

    Args:
    scale_factor: The scale factor to apply to the learning rate.
    hidden_size: The hidden size of the model.
    step: The current step in training.
    warmup_steps: The number of warmup steps (default is 1500).
    
    Returns:
    lr: The learning rate for the current step.
    """
    step = max(step, 1)
    return scale_factor * (hidden_size ** (-0.5) * min(step ** (-0.5), step * (warmup_steps ** (-1.5))))

def train(model, train_loader, val_loader, optimizer, config=None, log_to_wandb=False, device=torch.device("cuda")):
    """Train the model on the training set and evaluate it on the validation set.
    
    Args:
    model: The model to train.
    train_loader: DataLoader for the training set.
    val_loader: DataLoader for the validation set.
    optimizer: The optimizer to use for training.
    config: A dictionary containing hyperparameters for training.
    log_to_wandb: Whether to log training and validation metrics to Weights & Biases.
    """
    # log hyperparameters
    if log_to_wandb:
        wandb.login()
        wandb.init(project=config["project"], config=config)

    model.train()
    batch_count = 0
    best_val_acc = -1
    for _ in range(config["num_epochs"]):
        total_preds = 0
        correct_preds = 0
        for x, y in train_loader:
            batch_count += 1
            
            # learning rate schedule
            lr = noam_schedule(config["scale_factor"], config["hidden"], batch_count)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            
            # training step
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            logits = logits.view(-1, len(char2idx))
            targets = y.view(-1)
            loss = F.cross_entropy(logits, targets)
            loss.backward()
            optimizer.step()
            
            # metrics update
            total_preds += targets.shape[0]
            correct_preds += (logits.argmax(1) == targets).sum().item()

            # log metrics
            if batch_count % 10 == 0:
                # train metrics
                train_acc = correct_preds / total_preds
                loss = loss.item()
                print(f"Batch {batch_count}, Train accuracy: {train_acc:.4f}, Train Loss: {loss:.4f}")

                # val metrics (evaluate() puts the model back into train mode before returning)
                eval_loss, eval_accuracy, eval_perplexity = evaluate(model, val_loader, device)
                print(f"Val Loss: {eval_loss:.4f}, Val Accuracy: {eval_accuracy:.4f}, Val Perplexity: {eval_perplexity:.4f}")
                best_val_acc = max(best_val_acc, eval_accuracy)

                # only call wandb when logging was requested, otherwise wandb.log fails without a run
                if log_to_wandb:
                    wandb.log({"learning_rate": lr}, step=batch_count)
                    wandb.log({"train_loss": loss, "train_accuracy": train_acc}, step=batch_count)
                    wandb.log({"val_loss": eval_loss, "val_accuracy": eval_accuracy, "val_perplexity": eval_perplexity}, step=batch_count)
    # end training loop
    if log_to_wandb:
        wandb.run.summary["best_val_accuracy"] = best_val_acc



    


config = {
    "project": "hw-5",
    "num_layers": 4,
    "hidden": 64,
    "num_heads": 8,
    "fcn_hidden": 64,
    "vocab_size": len(char2idx),
    "max_seq_len": MAX_LEN,
    "num_epochs": 50,
    "lr_schedule": "Noam",
    "scale_factor": 1,
    "optimizer": "adamw",  # matches torch.optim.AdamW used below
    
}

device = torch.device("cuda")
model = TransformerLM(
    num_layers=config["num_layers"],
    hidden=config["hidden"],
    num_heads=config["num_heads"],
    fcn_hidden=config["fcn_hidden"],
    vocab_size=config["vocab_size"],
    max_seq_len=config["max_seq_len"],
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=config["scale_factor"])
train(model, train_loader, val_loader, optimizer, config, True)
wandb.finish()

# YOUR CODE ENDS HERE

Batch 10, Train accuracy: 0.0132, Train Loss: 4.2713
Val Loss: 4.2616, Val Accuracy: 0.0130, Val Perplexity: 70.9226
Batch 20, Train accuracy: 0.0144, Train Loss: 4.2427
Val Loss: 4.2276, Val Accuracy: 0.0159, Val Perplexity: 68.5552
Batch 30, Train accuracy: 0.0163, Train Loss: 4.1908
Val Loss: 4.1720, Val Accuracy: 0.0247, Val Perplexity: 64.8422
Batch 40, Train accuracy: 0.0203, Train Loss: 4.1075
Val Loss: 4.0896, Val Accuracy: 0.0500, Val Perplexity: 59.7149
Batch 50, Train accuracy: 0.0703, Train Loss: 3.9926
Val Loss: 3.9677, Val Accuracy: 0.0976, Val Perplexity: 52.8620
Batch 60, Train accuracy: 0.0943, Train Loss: 3.8267
Val Loss: 3.7934, Val Accuracy: 0.1387, Val Perplexity: 44.4084
Batch 70, Train accuracy: 0.1126, Train Loss: 3.6154
Val Loss: 3.5916, Val Accuracy: 0.1499, Val Perplexity: 36.2926
Batch 80, Train accuracy: 0.1230, Train Loss: 3.4635
Val Loss: 3.4398, Val Accuracy: 0.1511, Val Perplexity: 31.1804
Batch 90, Train accuracy: 0.1478, Train Loss: 3.3587
Val Loss: 3


Run history:
learning_rate,▁▁▁▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇███████▇▇▇▇▇▇▇▇
train_accuracy,▁▂▃▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇██████████████
train_loss,█▇▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▃▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇██████████████
val_loss,█▇▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_perplexity,█▅▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:
best_val_accuracy,0.45903
learning_rate,0.00267
train_accuracy,0.46947
train_loss,1.76145
val_accuracy,0.45871
val_loss,1.83996
val_perplexity,6.2963


# Generate text using your model

Now it's time to see what this model can do. Implement a generation function.
The idea is to start with some prefix text, predict the next character, append it to the prefix, and repeat the process.
You can stop generating text when you reach MAX_LEN tokens.

Use `torch.no_grad()` context manager to make sure that you don't compute gradients during generation, or it will blow up your GPU memory.

## Coding task 3.6

Implement a generation function that accepts a prefix text and generates the next tokens up to MAX_LEN.

### Grading criteria
**(2 points)**

* 2 points for generation function
* -1 point if you forget to put your model to evaluation mode during generation and back to training mode after generation or if you forget to use `torch.no_grad()` context manager, or if you are not using the GPU.

In [65]:
# YOUR CODE STARTS HERE (our implementation is about 10 lines)

def generate_text(model, start_text, max_len=MAX_LEN):
    """Greedily extend `start_text` one character at a time, up to `max_len` tokens."""
    model.eval()
    with torch.no_grad():
        x = torch.tensor([tokenizer.encode(start_text)], dtype=torch.long).to(device)
        for _ in range(max_len - x.shape[1]):
            logits = model(x)
            logits = logits[0, -1, :]  # scores for the next character only
            next_token = torch.argmax(logits).reshape(1, 1)
            x = torch.cat([x, next_token], dim=1)
    model.train()
    return tokenizer.decode(x)
    

print(generate_text(model, "LADY"))
# YOUR CODE ENDS HERE

LADY ANNE:
Therefore the common of the world,
The world of the world of the world of the world.

GLOUCESTER:
I will not thee wit
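
The sample above is decoded greedily (`argmax`), which often falls into repetitive loops such as "the world of the world". A common alternative is to sample from the softmax distribution with a temperature. The sketch below shows the sampling step on a toy logits vector; the helper name `sample_next_token` is hypothetical.

```python
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=1.0, generator=None):
    """Pick one token id from a 1-D logits vector."""
    if temperature <= 0:
        # Treat non-positive temperature as greedy decoding.
        return torch.argmax(logits).item()
    # Lower temperature sharpens the distribution, higher flattens it.
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator).item()

gen = torch.Generator().manual_seed(0)
toy_logits = torch.tensor([2.0, 1.0, 0.1])

greedy = sample_next_token(toy_logits, temperature=0)
samples = [sample_next_token(toy_logits, 1.0, gen) for _ in range(200)]
print(greedy, {i: samples.count(i) for i in sorted(set(samples))})
```

In a generation loop, this call would take the place of the `argmax` step, with the returned id appended to the running sequence.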


# Exploring hyperparameters and understanding Transformers

Train at least 10 models with different hyperparameters and compare them using wandb. Write a short report (500-1000 words).


### Grading criteria
**(5 points max + extra points)**

* 4 points for training 10+ models. (5-9 models = 2 points, 1-4 models = 1 point)
* 1 point for training report that describes what you did and what you learned about the hyperparameters and efficient training.