In [22]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, DefaultDataCollator, DataCollatorWithPadding
from torch.utils.data import DataLoader
from transformer.transformer import Transformer


In [29]:
# WMT14 German-English corpus; training here uses only the first 10,000 pairs.
wmt14 = load_dataset("wmt/wmt14", "de-en")
train_subset = wmt14["train"].select(range(10_000))

# One pretrained BERT tokenizer per language: German for the source side,
# English for the target side.
de_tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")
en_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Register explicit begin/end-of-sequence tokens on the English tokenizer.
# The call is left as the last expression so its return value is displayed.
special_tokens_dict = dict(bos_token="<s>", eos_token="</s>")
en_tokenizer.add_special_tokens(special_tokens_dict)



2

In [31]:
# Size of the English tokenizer; the model cell passes this same value as the
# second positional argument of Transformer.
len(en_tokenizer)

30524

In [39]:
# Look up the id the English tokenizer assigns to the token "lol".
# NOTE(review): the recorded output is 100 - presumably the [UNK] id; confirm.
en_tokenizer.convert_tokens_to_ids("lol")

100

In [None]:
def preprocess_function(examples, max_length=128):
    """Tokenize a batch of German/English translation pairs.

    Intended for ``Dataset.map(..., batched=True)``. The German text is encoded
    with the module-level ``de_tokenizer`` and the English text with
    ``en_tokenizer``; both are truncated and padded to ``max_length``.

    Args:
        examples: Batch mapping whose ``'translation'`` entry is a list of
            dicts, each holding a ``'de'`` and an ``'en'`` string.
        max_length: Length every encoded sequence is truncated/padded to.
            Defaults to 128, the value the model cell passes as ``max_len``;
            the two should be kept in sync. A different value can be supplied
            through ``Dataset.map(..., fn_kwargs={"max_length": n})``.

    Returns:
        dict with ``input_ids`` / ``attention_mask`` (German side) and
        ``decoder_input_ids`` / ``decoder_attention_mask`` / ``labels``
        (English side). ``labels`` holds the same ids as ``decoder_input_ids``
        but in independent row lists; this function does no shifting.
    """
    translations = examples['translation']
    de_texts = [pair['de'] for pair in translations]
    en_texts = [pair['en'] for pair in translations]

    # Identical tokenization settings for both languages.
    tokenize_kwargs = dict(truncation=True, max_length=max_length, padding="max_length")
    de_inputs = de_tokenizer(de_texts, **tokenize_kwargs)
    en_inputs = en_tokenizer(en_texts, **tokenize_kwargs)

    target_ids = en_inputs["input_ids"]
    model_inputs = {
        "input_ids": de_inputs["input_ids"],
        "attention_mask": de_inputs["attention_mask"],
        "decoder_input_ids": target_ids,
        "decoder_attention_mask": en_inputs["attention_mask"],
        # Copy row by row: a plain .copy() of the outer list would leave every
        # row shared between "labels" and "decoder_input_ids".
        "labels": [list(row) for row in target_ids],
    }

    return model_inputs

In [40]:
# Inspect the special tokens registered on the English tokenizer; the recorded
# output lists the added '<s>' / '</s>' next to the existing BERT tokens.
en_tokenizer.special_tokens_map

{'bos_token': '<s>',
 'eos_token': '</s>',
 'unk_token': '[UNK]',
 'sep_token': '[SEP]',
 'pad_token': '[PAD]',
 'cls_token': '[CLS]',
 'mask_token': '[MASK]'}

In [20]:
# Run the preprocessing over the subset, 1000 examples per call.
tokenized_dataset = train_subset.map(preprocess_function, batched=True, batch_size=1000)

# Expose only the columns listed here, returned as torch tensors.
model_columns = [
    "input_ids",
    "attention_mask",
    "decoder_input_ids",
    "decoder_attention_mask",
    "labels",
]
tokenized_dataset.set_format(type="torch", columns=model_columns)

# Shuffled mini-batches of 8 examples, assembled by the default collator.
data_collator = DefaultDataCollator()
train_dataloader = DataLoader(
    tokenized_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=data_collator,
)

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [53]:
# Seed torch's RNGs before anything random happens in this cell, so weight
# initialisation, the DataLoader shuffle order and dropout masks repeat when
# the cell is re-run (GPU kernels may still add small nondeterminism).
SEED = 42
torch.manual_seed(SEED)

# Positional arguments: German and English tokenizer sizes, then the two
# padding ids - presumably (src_vocab, tgt_vocab, src_pad, tgt_pad); confirm
# against Transformer.__init__.
model = Transformer(
    len(de_tokenizer),
    len(en_tokenizer),
    de_tokenizer.pad_token_id,
    en_tokenizer.pad_token_id,
    forward_dim=2048,
    emb_dim=512,
    num_heads=8,
    num_layers=6,
    max_len=128,  # same length the preprocessing pads/truncates to
    dropout_rate=0.1
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # Move model to device
model.train()  # make training mode (dropout active) explicit
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Label positions equal to the English pad id are excluded from the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=en_tokenizer.pad_token_id)

# The BOS id only exists once add_special_tokens has been run on en_tokenizer.
# Fail with a clear message on a stale kernel instead of a cryptic torch.full
# error inside the loop.
bos_token_id = en_tokenizer.bos_token_id
if bos_token_id is None:
    raise ValueError(
        "en_tokenizer has no bos_token; run the cell that calls add_special_tokens first"
    )

for i, batch in enumerate(train_dataloader):
    optimizer.zero_grad()

    batch = {k: v.to(device) for k, v in batch.items()}

    # Add BOS token for the decoder input. The dtype is explicit so the column
    # is integer-typed regardless of how torch.full infers it.
    bos_tokens = torch.full((batch['decoder_input_ids'].shape[0], 1),
                            bos_token_id,
                            dtype=torch.long,
                            device=device)
    decoder_in = torch.cat([bos_tokens, batch['decoder_input_ids']], dim=1)

    # Drop the last position: the decoder input is the target shifted right by
    # one, so position t is trained to predict labels[:, t] (teacher forcing).
    decoder_in = decoder_in[:, :-1]

    # Forward pass - shape: batch_size x seq_length x vocab_size
    logits = model(batch["input_ids"], decoder_in)

    # Reshape for loss calculation. Loss_fn expects batch to be flattened
    batch_size, seq_len, vocab_size = logits.shape
    logits_reshaped = logits.contiguous().view(-1, vocab_size)
    labels_reshaped = batch['labels'].contiguous().view(-1)

    loss = loss_fn(logits_reshaped, labels_reshaped)
    loss.backward()
    optimizer.step()

    if i % 10 == 0:
        print(f"Batch {i}, Loss: {loss.item()}")

Batch 0, Loss: 10.600179672241211
Batch 10, Loss: 8.862262725830078
Batch 20, Loss: 7.864902973175049
Batch 30, Loss: 7.641119956970215
Batch 40, Loss: 6.921629428863525
Batch 50, Loss: 7.2100830078125
Batch 60, Loss: 6.524007320404053
Batch 70, Loss: 6.3765363693237305
Batch 80, Loss: 6.521219253540039
Batch 90, Loss: 6.109145164489746
Batch 100, Loss: 6.268571853637695
Batch 110, Loss: 5.916914939880371
Batch 120, Loss: 6.040268421173096
Batch 130, Loss: 6.002480506896973
Batch 140, Loss: 5.708379745483398
Batch 150, Loss: 5.76394510269165
Batch 160, Loss: 5.843137741088867
Batch 170, Loss: 6.137154579162598
Batch 180, Loss: 5.7472100257873535
Batch 190, Loss: 5.356274604797363
Batch 200, Loss: 5.383860111236572
Batch 210, Loss: 5.734013557434082
Batch 220, Loss: 5.699577808380127
Batch 230, Loss: 5.559024333953857
Batch 240, Loss: 5.112642765045166


KeyboardInterrupt: 

In [55]:
# Element count of every parameter tensor (trainable or not), then their sum.
param_sizes = [weights.numel() for weights in model.parameters()]
total_params = sum(param_sizes)

In [56]:
# Display the parameter count computed in the previous cell.
print(total_params)

91480892


In [None]:
from torchsummary import summary


In [64]:
def count_parameters(model):
    """Return the number of elements across the model's trainable parameters.

    Parameters whose ``requires_grad`` is False are left out of the total.
    """
    trainable = (weights for weights in model.parameters() if weights.requires_grad)
    return sum(weights.numel() for weights in trainable)

# Show the module tree.
print(model)

# Overall number of trainable parameter elements.
total_params = count_parameters(model)
print(f"Total trainable parameters: {total_params:,}")

# Per-parameter breakdown; frozen parameters are skipped.
for name, parameter in model.named_parameters():
    if not parameter.requires_grad:
        continue
    print(f"{name}: {parameter.numel():,}")

Transformer(
  (encoder): Encoder(
    (embedding): Embedding(31102, 512)
    (pos_embedding): Embedding(128, 512)
    (dropout): Dropout(p=0.1, inplace=False)
    (transformer_blocks): ModuleList(
      (0-5): 6 x TransformerBlock(
        (norm_1): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        (norm_2): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (mha): MultiHeadAttention(
          (Q): Linear(in_features=512, out_features=512, bias=True)
          (K): Linear(in_features=512, out_features=512, bias=True)
          (V): Linear(in_features=512, out_features=512, bias=True)
          (w0): Linear(in_features=512, out_features=512, bias=True)
        )
        (ffnn): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
      )
    )
  )
  (decoder): Decoder(
    (embedding): E