# Implementing Transformer Models
## Practical IX
Carel van Niekerk & Hsien-Chin Lin

9-20.12.2024

---

In this practical we will implement the training script and train the transformer model.

### 1. Essentials of the training script

In the training script, the prepared dataset is used to train a model instance. Before training, the model should be initialised, a dataloader should be created from the dataset, and the loss function, optimiser and learning rate scheduler should be initialised.
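The initialisation order described above can be sketched as follows. This is a minimal sketch and not the course's implementation: `toy_model` is a hypothetical stand-in for the transformer, and `noam_factor` is an assumed version of the warm-up schedule from the original Transformer paper. The course's own `TransformerLRScheduler` may differ.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the transformer; any nn.Module is set up the same way.
toy_model = nn.Linear(16, 100)

# Cross entropy over 100 hypothetical classes.
loss_fn = nn.CrossEntropyLoss()

# The optimiser is created from the parameters of the model that will be trained.
optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1.0, weight_decay=0.01)

def noam_factor(step, d_model=64, warmup_steps=100):
    """Linear warm-up followed by inverse square-root decay."""
    step = max(step, 1)  # avoid 0 ** -0.5 on the first call
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# LambdaLR multiplies the base learning rate (1.0 here) by noam_factor(step).
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_factor)

print(optimizer.param_groups[0]['lr'])  # learning rate at the first step
```

With a base learning rate of 1.0, the value returned by `noam_factor` is the effective learning rate. It peaks at `step == warmup_steps` and decays afterwards.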

#### 1.1. Dataloaders

The `DataLoader` is a PyTorch class that loads data from a dataset in batches. It is initialised with the dataset and the batch size, and the training loop iterates over it to obtain the data for each batch.

The dataloader for the training data is initialised as follows:

```python
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
```
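To illustrate what iterating over a dataloader yields, here is a small self-contained sketch. The dataset `toy_dataset` and its sizes are hypothetical and only serve to show the batch shapes.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: 10 samples with 4 features and one label each.
features = torch.arange(40, dtype=torch.float32).reshape(10, 4)
labels = torch.arange(10)
toy_dataset = TensorDataset(features, labels)

toy_dataloader = DataLoader(toy_dataset, batch_size=4, shuffle=True)

# Collect the shape of every batch produced in one pass over the data.
batch_shapes = []
for x, y in toy_dataloader:
    batch_shapes.append((tuple(x.shape), tuple(y.shape)))

print(len(toy_dataloader))  # number of batches per epoch
print(batch_shapes)
```

Because 10 is not divisible by 4, the last batch contains only 2 samples. Passing `drop_last=True` would discard this incomplete batch.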

#### 1.2. Training the model

Once the model, dataloader, loss function, optimiser and learning rate scheduler are initialised, the model can be trained. The training loop iterates over the batches in the dataloader and performs the following steps:

1. Load the data for the batch.
2. Perform a forward pass through the model.
3. Calculate the loss.
4. Perform a backward pass through the model.
5. Update the model parameters.
6. Update the learning rate.
7. Repeat for the next batch.
8. Repeat for the next epoch.
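The steps above can be sketched on a toy classification task. The model, data and hyperparameters here are hypothetical. The sketch only shows the order of the calls, with `StepLR` standing in for whichever learning rate scheduler is used.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy task: classify 8-dimensional points by the sign of their first feature.
inputs = torch.randn(256, 8)
targets = (inputs[:, 0] > 0).long()
loader = DataLoader(TensorDataset(inputs, targets), batch_size=32, shuffle=True)

toy_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)

epoch_losses = []
for epoch in range(5):                      # step 8: repeat for the next epoch
    toy_model.train()
    running_loss = 0.0
    for x, y in loader:                     # steps 1 and 7: load the next batch
        optimizer.zero_grad()               # clear gradients from the previous batch
        logits = toy_model(x)               # step 2: forward pass
        batch_loss = loss_fn(logits, y)     # step 3: calculate the loss
        batch_loss.backward()               # step 4: backward pass
        optimizer.step()                    # step 5: update the parameters
        scheduler.step()                    # step 6: update the learning rate
        running_loss += batch_loss.item()
    epoch_losses.append(running_loss / len(loader))

print(epoch_losses)
```

Note that `optimizer.zero_grad()` is called for every batch. Without it, PyTorch accumulates gradients across batches.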

# Exercises

1. Initialise a small version of your transformer model (do not use more than 4 layers and 64 hidden units unless you have access to sufficient compute).
2. Initialise the dataloader using the dataset class from practical 4.
3. Initialise the loss function (cross entropy loss), optimiser and learning rate scheduler.
4. Implement the training loop.
5. Train the model for 5 epochs and ensure that the loss decreases on both the training and validation sets. You can use a small, randomly selected subset of the training data to speed up training.
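One possible way to draw a random subset, sketched under the assumption that the dataset supports integer indexing and `len()`. The dataset `full_dataset` and the subset size are hypothetical.

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Hypothetical stand-in for the full training set.
full_dataset = TensorDataset(torch.arange(1000))

# Fix the seed so that the same subset is drawn on every run.
generator = torch.Generator().manual_seed(0)
indices = torch.randperm(len(full_dataset), generator=generator)[:100].tolist()
small_dataset = Subset(full_dataset, indices)

print(len(small_dataset))
```

Unlike `Subset(dataset, range(n))`, which takes the first `n` samples, this draws samples from the whole dataset. That can matter when the data is ordered.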

In [3]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import tqdm

from dataset import get_costum_dataset
from modelling.model import Transformer
from modelling.lr_scheduler import TransformerLRScheduler
from transformers import GPT2Tokenizer

In [4]:
model = Transformer(
    vocab_size=10_000,
    d_model=64,  # small model, as recommended in exercise 1
    n_heads=8,
    num_encoder_layers=4,
    num_decoder_layers=4,
    dim_feedforward=2048,
    dropout=0.1,
    max_len=64
)

loss = nn.CrossEntropyLoss()

# Optional setup with weight decay groups and the custom scheduler (currently disabled):
# no_decay = ['bias', 'LayerNorm.weight']
# optimizer_grouped_parameters = [
#     {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
#     {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
# ]
# optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=1e-3)
# scheduler = TransformerLRScheduler(optimizer, d_model=64, warmup_steps=100)

tokenizer = GPT2Tokenizer.from_pretrained("modelling/bpe_v=10000_l=64")

ds_train = get_costum_dataset(split='train', tokenize=True) # 871_399 samples
ds_val = get_costum_dataset(split='validation', tokenize=True) # 543 samples

print(len(ds_train))
print(len(ds_val))

# Use only the first 100,000 training samples to speed up training
ds_train = torch.utils.data.Subset(ds_train, range(100_000))


Using the latest cached version of the dataset since wmt17 couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'de-en' at /Users/leonmarkwart/.cache/huggingface/datasets/wmt17/de-en/0.0.0/54d3aacfb5429020b9b85b170a677e4bc92f2449 (last modified on Thu Jan 16 01:41:05 2025).


871399
543


In [5]:
sample = ds_val[50]
print(sample.keys())
print(tokenizer.decode(sample['src_input']))
print(tokenizer.decode(sample['tgt_input']))
print(tokenizer.decode(sample['tgt_output']))

def collate_fn(batch):
    # Stack the already padded token id lists of all samples into [B, max_len] tensors.
    src_input = [s['src_input'] for s in batch]
    tgt_input = [s['tgt_input'] for s in batch]
    tgt_output = [s['tgt_output'] for s in batch]
    return {
        'src_input': torch.tensor(src_input),
        'tgt_input': torch.tensor(tgt_input),
        'tgt_output': torch.tensor(tgt_output)
    }

train_loader = DataLoader(ds_train, batch_size=64, shuffle=True, collate_fn=collate_fn, drop_last=True)
# Keep the last, smaller batch so that every validation sample is evaluated.
val_loader = DataLoader(ds_val, batch_size=64, shuffle=False, collate_fn=collate_fn, drop_last=False)

dict_keys(['translation', 'src_input', 'tgt_input', 'tgt_output'])
Sie würde es erlauben, gegen Urteile Einspruch zu erheben. [EOS] [PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
[BOS] This would allow for appeals to be made against judgements. [EOS] [PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
This would allow for appeals to be made against judgements. [EOS] [PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
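The decoded sample shows that most positions of a target sequence can be padding. The following minimal sketch shows how logits of shape `[batch, max_len, vocab_size]` can be flattened for `nn.CrossEntropyLoss` while excluding padded positions from the average. The tensor sizes are hypothetical, and `pad_id = 0` is an assumption based on the `!= 0` padding masks used in this notebook. The notebook's own loss does not set `ignore_index`, so this is an optional variation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, max_len, vocab_size = 2, 5, 11  # hypothetical sizes
pad_id = 0  # assumption: [PAD] has token id 0

logits = torch.randn(batch_size, max_len, vocab_size)  # stand-in for the model output
targets = torch.tensor([[4, 7, 2, 0, 0],
                        [9, 2, 0, 0, 0]])              # padded target token ids

# CrossEntropyLoss expects [N, vocab_size] logits and [N] class indices,
# so the batch and sequence dimensions are flattened together.
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)
loss_value = loss_fn(logits.reshape(-1, vocab_size), targets.reshape(-1))

# Reference computation: average the per-token losses over non-padded positions only.
per_token = nn.functional.cross_entropy(
    logits.reshape(-1, vocab_size), targets.reshape(-1), reduction='none')
keep = targets.reshape(-1) != pad_id
reference = per_token[keep].mean()

print(loss_value.item(), reference.item())
```

Passing the unflattened `[batch, max_len, vocab_size]` logits together with `[batch, max_len]` targets raises a target-size error, because `nn.CrossEntropyLoss` expects the class dimension in second position.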


In [11]:
train_losses = []
val_losses = []

# Select the device once and use it for the model and for every batch.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
model = model.to(device)

# Create the optimizer from the parameters of the model that is actually trained.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

def compute_loss(batch):
    # Token ids stay integer (long) tensors; they are indices, not features.
    src_input = batch['src_input'].to(device)
    tgt_input = batch['tgt_input'].to(device)
    tgt_output = batch['tgt_output'].to(device)  # [B, max_len]

    # Boolean padding masks, assuming the [PAD] token has id 0
    src_input_mask = (src_input != 0)
    tgt_input_mask = (tgt_input != 0)

    output = model(src_input, tgt_input, src_input_mask, tgt_input_mask)  # [B, max_len, vocab_size]
    # CrossEntropyLoss expects [N, vocab_size] logits and [N] class indices.
    return loss(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))

for epoch in range(10):
    model.train()
    train_loss = 0
    pbar = tqdm.tqdm(train_loader, desc=f"Training Epoch {epoch}")
    for batch in pbar:
        optimizer.zero_grad()
        loss_step = compute_loss(batch)
        loss_step.backward()
        optimizer.step()
        #scheduler.step()
        train_loss += loss_step.item()
        pbar.set_postfix({'loss': loss_step.item()})
    train_losses.append(train_loss / len(train_loader))

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            val_loss += compute_loss(batch).item()
    val_losses.append(val_loss / len(val_loader))

    print(f"Epoch {epoch}: Train Loss = {train_losses[-1]}, Validation Loss = {val_losses[-1]}")



In [16]:
import pytorch_lightning as pl

class TransformerModel(pl.LightningModule):


    def __init__(self):
        super().__init__()
        self.model = Transformer(
            vocab_size=10_000, 
            d_model=64, 
            n_heads=8, 
            num_encoder_layers=4, 
            num_decoder_layers=4, 
            dim_feedforward=2048, 
            dropout=0.1, 
            max_len=64
        )

        self.loss = torch.nn.CrossEntropyLoss()
        #self.scheduler = scheduler
        self.lr = 1e-3

        #self.last_feed_forward = self.model.head[0].weight.detach().clone()

    def forward(self, src_input, tgt_input):
        # Boolean padding masks, assuming the [PAD] token has id 0
        src_input_mask = (src_input != 0)
        tgt_input_mask = (tgt_input != 0)
        return self.model(src_input, tgt_input, src_input_mask, tgt_input_mask)

    def _compute_loss(self, batch):
        output = self(batch['src_input'], batch['tgt_input'])  # [B, max_len, vocab_size]
        tgt_output = batch['tgt_output']                        # [B, max_len]
        # CrossEntropyLoss expects [N, vocab_size] logits and [N] class indices,
        # so the batch and sequence dimensions are flattened together.
        return self.loss(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))

    def training_step(self, batch, batch_idx):
        loss_step = self._compute_loss(batch)
        self.log('train_loss', loss_step, prog_bar=True, on_epoch=False, on_step=True, logger=True)
        self.log('lr', self.lr, prog_bar=True, on_epoch=False, on_step=True, logger=True)
        return loss_step

    '''def on_after_backward(self):
        global_step = self.global_step
        for name, param in self.model.named_parameters():
            self.logger.experiment.add_histogram(name, param, global_step)
            if param.requires_grad:
                self.logger.experiment.add_histogram(f"{name}_grad", param.grad, global_step)
            else:
                print(f"NO GRADIENTS FOR {name}")'''

    def validation_step(self, batch, batch_idx):
        loss_step = self._compute_loss(batch)
        self.log('val_loss', loss_step, prog_bar=True, on_epoch=True, on_step=False, logger=True)
        return loss_step
    
    def configure_optimizers(self):
        """no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]"""
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=0.0)
        return optimizer
    
    def train_dataloader(self):
        return train_loader
    
    def val_dataloader(self):
        return val_loader
    
trainer = pl.Trainer(
    overfit_batches=100, 
    enable_checkpointing=False,
    log_every_n_steps=10,
    callbacks=[pl.callbacks.LearningRateMonitor(logging_interval='step')],
    
    )
transformer = TransformerModel()

#tuner = pl.tuner.tuning.Tuner(trainer)
#lr_finder = tuner.lr_find(transformer)
#fig = lr_finder.plot(suggest=True)
#plt.savefig('lr_finder.png')

trainer.fit(transformer)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/leonmarkwart/miniconda3/lib/python3.11/site-packages/pytorch_lightning/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.

  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | model | Transformer      | 4.2 M  | train
1 | loss  | CrossEntropyLoss | 0      | train
---------------------------------------------------
4.2 M     Trainable params
0         Non-trainable params
4.2 M     Total params
16.933    Total estimated model params size (MB)



In [None]:
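# Encode one example in the same layout as the dataset samples: the decoder input
# starts with [BOS], and the decoder output is the input shifted by one token.
src_input = tokenizer.encode("Hallo, ich heiße Leon und bin 27 Jahre alt.[EOS]", max_length=64, truncation=True, padding='max_length')
tgt_input = tokenizer.encode("[BOS]Hello, my name is Leon and i am 27", max_length=64, truncation=True, padding='max_length')
tgt_output = tokenizer.encode("Hello, my name is Leon and i am 27 years", max_length=64, truncation=True, padding='max_length')
print(src_input)
print(tokenizer.decode(src_input))
print(tgt_input)
print(tokenizer.decode(tgt_input))
print(tgt_output)
print(tokenizer.decode(tgt_output))

# Add a batch dimension and move the tensors to the device of the trained module.
src_input = torch.tensor(src_input).unsqueeze(0).to(transformer.device)
tgt_input = torch.tensor(tgt_input).unsqueeze(0).to(transformer.device)
tgt_output = torch.tensor(tgt_output).unsqueeze(0).to(transformer.device)

# Disable dropout and gradient tracking for inference.
transformer.eval()
with torch.no_grad():
    output = transformer(src_input, tgt_input)  # the module builds the padding masks itself
output = output.argmax(dim=-1)  # most likely token id at every position
print(output)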
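The decoder input and the decoder output used above are two views of the same sequence, shifted by one token. A minimal sketch of this construction on plain token id lists follows. The helper `make_decoder_pair` and the special token ids are hypothetical. The layout mirrors the decoded dataset sample shown earlier, where the decoder input starts with `[BOS]` and the output ends with one additional `[PAD]`.

```python
def make_decoder_pair(token_ids, bos_id, eos_id, pad_id, max_len):
    """Return (tgt_input, tgt_output), each of length max_len, shifted by one token."""
    full = [bos_id] + list(token_ids) + [eos_id]
    full = full[:max_len + 1]                           # truncate, leaving room for the shift
    full = full + [pad_id] * (max_len + 1 - len(full))  # pad to max_len + 1 tokens
    return full[:-1], full[1:]

# Hypothetical token ids: 1 = [BOS], 2 = [EOS], 0 = [PAD]
tgt_in, tgt_out = make_decoder_pair([15, 27, 33], bos_id=1, eos_id=2, pad_id=0, max_len=6)
print(tgt_in)   # [1, 15, 27, 33, 2, 0]
print(tgt_out)  # [15, 27, 33, 2, 0, 0]
```

At every position `t`, the model sees `tgt_in[:t + 1]` and is trained to predict `tgt_out[t]`, which is the next token of the sequence.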