# Семинар "Natural Language Processing. Часть 1"

На этом семинаре мы попробуем реализовать основные составляющие части архитектуры трансформер: токенайзер, SDPA, Encoder/Decoder Blocks; попробуем обучить получившуюся топологию на wikitext и посчитать Perplexity. Далее возьмем предобученный BERT и сравним качество.

In [None]:
import os
import re
import warnings
import random
from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np
import torch
from tqdm.notebook import tqdm

warnings.filterwarnings("ignore")

In [None]:
def seed_everything(seed: int):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    and configures cuDNN for reproducible kernel selection.

    Args:
        seed: Value used to seed all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point; the hash seed of the
    # running process is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every GPU, not only the current one (silently ignored without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may choose different kernels from run to run, so it
    # has to stay off when reproducibility is the goal.
    torch.backends.cudnn.benchmark = False

seed_everything(42)

### Токенизация

- Byte-Pair Encoding (BPE)
Алгоритм обучается, начиная с базового словаря символов, и итеративно объединяет наиболее часто встречающиеся пары символов или токенов. Эффективно обрабатывает редкие и неизвестные слова. Используется в GPT-2, GPT-3, RoBERTa и LLaMA.

- WordPiece
Алгоритм обучения похож на BPE, но для выбора пар к слиянию использует не частоту, а следующую вероятностную оценку
$$score=\frac{\text{freq_of_pair}}{\text{freq_of_first_element} \times \text{freq_of_second_element}}$$
Кроме того, элементы объединяются не с начала слова, а с конца. WordPiece применяется в BERT, DistilBERT и MobileBERT.

- Unigram
В отличие от BPE и WordPiece, этот метод начинает обучение с большого словаря и постепенно удаляет из него токены, которые меньше всего уменьшают правдоподобие корпуса. Используется в T5 и ALBERT.

Рассмотрим BPE как наиболее популярный и простой для понимания алгоритм

1. Нормализация (опционально) и пре-токенизация
2. Инициализируем словарь всеми ASCII символами (либо всеми символами, встречающимися в корпусе)
3. Повторяем пока не достигли ограничения на размер словаря  
3.1 Вычисляем частоты всех пар идущих подряд токенов  
3.2 Назначаем новым токеном объединение двух существующих токенов, которое встречается чаще других пар в корпусе

In [None]:
class Tokenizer:
    """Minimal Byte-Pair Encoding (BPE) tokenizer.

    Text is split into chunks with the regular expression in ``self.pat``;
    every chunk starts as a list of characters and is compressed by replaying
    the merges learned in :meth:`train`.
    """

    def __init__(self) -> None:
        self.vocab = {} # {<token>: <id>}
        # {(<left token>, <right token>): <id of the merged token>}; ids grow
        # with every merge, so a smaller id means an earlier (higher priority) merge.
        self.merges = {}
        self.inverse_vocab = {} # {<id>: <token>}
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""")
        self.special_tokens = {}
        self.unk_token = "<unk>" # unknown
        self.pad_token = "<pad>" # padding

    def train(
        self,
        text: str,
        vocab_size: int,
        verbose: bool = False
    ) -> Dict:
        """Learn a BPE vocabulary from ``text``.

        Args:
            text: Training corpus.
            vocab_size: Target vocabulary size, special tokens and single
                characters included. Training stops earlier when no adjacent
                token pair is left to merge.
            verbose: Print every merge as it is learned.

        Returns:
            The learned ``{token: id}`` vocabulary (also stored in ``self.vocab``).
        """
        self._add_special_tokens()
        # Retraining must not inherit merges learned from a previous corpus.
        self.merges = {}

        word_freqs = defaultdict(int)
        for word in self._pre_tokenize(text):
            word_freqs[word] += 1

        # Sorted so that character ids do not depend on set iteration order.
        alphabet = sorted({char for word in word_freqs for char in word})

        self.vocab = {**self.special_tokens}
        for idx, char in enumerate(alphabet, start=len(self.special_tokens)):
            self.vocab[char] = idx

        splits = {word: list(word) for word in word_freqs}

        with tqdm(total=vocab_size) as pbar:
            pbar.update(len(self.vocab))
            while len(self.vocab) < vocab_size:
                pair_freqs = self._compute_pair_freqs(splits, word_freqs)

                if not pair_freqs:
                    # Every word is already a single token.
                    break

                # On ties max() keeps the pair that was counted first.
                best_pair = max(pair_freqs, key=pair_freqs.get)

                self.merges[best_pair] = len(self.vocab)
                new_token = best_pair[0] + best_pair[1]
                self.vocab[new_token] = len(self.vocab)

                splits = self._merge_pair(best_pair, splits)

                if verbose:
                    print(f"Merged {best_pair} -> {new_token}, vocab size: {len(self.vocab)}")

                pbar.update(1)

        self._build_inverse_vocab()
        return self.vocab

    def _add_special_tokens(self) -> None:
        """Reserve the first four ids for the special tokens."""
        self.special_tokens = {
            self.unk_token: 0,
            self.pad_token: 1,
            "<bos>": 2, # begin of sentence
            "<eos>": 3  # end of sentence
        }

    def _pre_tokenize(self, text: str) -> List:
        """Split ``text`` into the chunks matched by ``self.pat``."""
        return self.pat.findall(text)

    def _compute_pair_freqs(self, splits, word_freqs) -> Dict:
        """Count adjacent token pairs, weighting each word by its frequency."""
        pair_freqs = defaultdict(int)
        for word, freq in word_freqs.items():
            split = splits[word]
            for pair in zip(split, split[1:]):
                pair_freqs[pair] += freq

        return pair_freqs

    def _apply_merge(self, pair: Tuple, tokens: List) -> List:
        """Return ``tokens`` with every left-to-right occurrence of ``pair`` fused."""
        merged = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
                merged.append(pair[0] + pair[1])
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        return merged

    def _merge_pair(self, pair: Tuple, splits: Dict) -> Dict:
        """Apply the merge ``pair`` to the split of every word."""
        return {word: self._apply_merge(pair, split) for word, split in splits.items()}

    def _build_inverse_vocab(self) -> None:
        """Rebuild the ``{id: token}`` lookup from ``self.vocab``."""
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}

    def encode(self, text: str) -> List:
        """Convert ``text`` into a list of token ids.

        Tokens that are missing from the vocabulary are mapped to the id of
        ``self.unk_token``.
        """
        encoded = []
        for word in self._pre_tokenize(text):
            tokens = list(word)
            while len(tokens) > 1:
                candidates = [
                    pair for pair in zip(tokens, tokens[1:]) if pair in self.merges
                ]
                if not candidates:
                    break

                # Replay merges in the order they were learned (smallest id
                # first); taking the left-most known pair instead can yield a
                # segmentation that training would never have produced.
                best = min(candidates, key=self.merges.__getitem__)
                tokens = self._apply_merge(best, tokens)

            for token in tokens:
                if token in self.vocab:
                    encoded.append(self.vocab[token])
                else:
                    encoded.append(self.vocab[self.unk_token])

        return encoded

    def decode(self, token_ids):
        """Join the tokens behind ``token_ids``; unknown ids become ``self.unk_token``."""
        # The inverse table is only filled by train(); rebuild it when the
        # vocabulary was assigned directly (e.g. restored from a checkpoint).
        if len(self.inverse_vocab) != len(self.vocab):
            self._build_inverse_vocab()
        return "".join(
            self.inverse_vocab.get(token_id, self.unk_token) for token_id in token_ids
        )

In [None]:
def test_tokenizer_train():
    """Train on a tiny corpus and check the exact vocabulary that is learned."""
    expected_tokens = [
        '<unk>', '<pad>', '<bos>', '<eos>', ' ', 'd', 'e', 'h', 'l', 'o', 'r',
        't', 'w', 'y', 'he', 'hel', 'hell', 'hello', ' w', ' wo'
    ]
    corpus = "hello world hello there hey"

    learned = Tokenizer().train(corpus, vocab_size=20, verbose=True)

    assert len(learned.keys()) > 0
    assert list(learned.keys()) == expected_tokens

def test_tokenizer_encode():
    """Check that encoding yields integer ids and decoding restores the text."""
    tokenizer = Tokenizer()
    tokenizer.train("hello world hello there hey", vocab_size=20, verbose=False)

    source_str = "hello world"
    token_ids = tokenizer.encode(source_str)
    assert all(isinstance(token_id, int) for token_id in token_ids)
    assert len(token_ids) > 0

    assert tokenizer.decode(token_ids) == source_str

    assert tokenizer.unk_token in tokenizer.vocab

def test_tokenizer_all():
    """Run every tokenizer check and report success."""
    for check in (test_tokenizer_train, test_tokenizer_encode):
        check()
    print("All tokenizer tests passed!")

In [None]:
# Run the tokenizer self-checks; an AssertionError here means a check failed.
test_tokenizer_all()

  0%|          | 0/20 [00:00<?, ?it/s]

Merged ('h', 'e') -> he, vocab size: 15
Merged ('he', 'l') -> hel, vocab size: 16
Merged ('hel', 'l') -> hell, vocab size: 17
Merged ('hell', 'o') -> hello, vocab size: 18
Merged (' ', 'w') ->  w, vocab size: 19
Merged (' w', 'o') ->  wo, vocab size: 20


  0%|          | 0/20 [00:00<?, ?it/s]

All tokenizer tests passed!


### Transformer

__Scaled Dot-Product Attention__
$$Attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
$$MultiHead(Q, K, V) = Concat(\text{head}_1, \text{head}_2, ..., \text{head}_h)W^O$$
$$where \; \text{head}_i = Attention(QW_i^Q, KW_i^K, VW_i^V)$$

__Cross-Attention__

![pic](https://media.geeksforgeeks.org/wp-content/uploads/20250319173029489747/cross_attention_.webp)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None, is_causal=False):
    """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        Q: Queries; the size of the last dimension is used as ``d_k``.
        K: Keys, multiplied with ``Q`` over the last dimension.
        V: Values, combined using the attention weights.
        mask: Optional mask broadcastable to the score matrix. A boolean mask
            keeps positions that are ``True``; any other dtype is added to the
            scores before the softmax.
        is_causal: When set, a query at position ``i`` only attends to keys at
            positions ``<= i``.

    Returns:
        Tuple ``(output, attention_weights)``.
    """
    scale = math.sqrt(Q.size(-1))
    logits = (Q @ K.transpose(-2, -1)) / scale

    if is_causal:
        num_queries, num_keys = logits.shape[-2:]
        visible = torch.tril(
            torch.ones(num_queries, num_keys, dtype=torch.bool, device=Q.device)
        )
        logits = logits.masked_fill(visible.logical_not(), float('-inf'))

    if mask is not None:
        if mask.dtype == torch.bool:
            # False entries are hidden from the softmax.
            logits = logits.masked_fill(~mask, float('-inf'))
        else:
            # Non-boolean masks act as an additive bias on the scores.
            logits = logits + mask

    attention_weights = logits.softmax(dim=-1)

    return attention_weights @ V, attention_weights

def cross_attention(Q, K, V, mask=None):
    """Attend from ``Q`` to ``K``/``V`` without causal masking.

    Thin wrapper around ``scaled_dot_product_attention``; returns its
    ``(output, attention_weights)`` pair.
    """
    output, attention_weights = scaled_dot_product_attention(Q, K, V, mask=mask)
    return output, attention_weights

$$MultiHead(Q, K, V) = Concat(\text{head}_1, \text{head}_2, ..., \text{head}_h)W^O$$
$$where \; \text{head}_i = Attention(QW_i^Q, KW_i^K, VW_i^V)$$

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project, split into heads, attend, merge, project.

    ``d_model`` must be divisible by ``num_heads``; every head works on
    ``d_model // num_heads`` features.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Input projections for queries, keys and values, then the output one.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size):
        """View the feature dimension as (heads, head_dim) and move heads to dim 1."""
        per_head = projected.view(batch_size, -1, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, Q, K, V, mask=None, is_causal=False):
        """Return ``(output, attention_weights)`` for the given inputs."""
        batch_size = Q.size(0)

        queries = self._split_heads(self.W_q(Q), batch_size)
        keys = self._split_heads(self.W_k(K), batch_size)
        values = self._split_heads(self.W_v(V), batch_size)

        heads, attn_weights = scaled_dot_product_attention(
            queries, keys, values, mask=mask, is_causal=is_causal
        )

        # Move heads back next to the features and flatten them into d_model.
        merged = heads.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return self.W_o(merged), attn_weights

![pic](https://quantdare.com/wp-content/uploads/2021/11/transformer_arch.png)

In [None]:
class PositionWiseFFN(nn.Module):
    """Feed-forward sub-layer: Linear -> ReLU -> Dropout -> Linear.

    Maps ``d_model`` features to ``d_ff`` and back to ``d_model``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        hidden = self.linear1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

class EncoderBlock(nn.Module):
    """Encoder layer: self-attention then FFN, each followed by residual + LayerNorm."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionWiseFFN(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, sublayer_output):
        """Apply dropout to the sub-layer output, add the residual, normalise."""
        return norm(residual + self.dropout(sublayer_output))

    def forward(self, x, mask=None):
        """Transform ``x``; ``mask`` is forwarded to the self-attention."""
        attended, _ = self.self_attention(x, x, x, mask=mask)
        after_attention = self._add_and_norm(self.norm1, x, attended)

        transformed = self.ffn(after_attention)
        return self._add_and_norm(self.norm2, after_attention, transformed)

class DecoderBlock(nn.Module):
    """Decoder layer: self-attention, cross-attention over the encoder output, FFN.

    Every sub-layer is followed by dropout, a residual connection and LayerNorm.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.cross_attention = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionWiseFFN(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, sublayer_output):
        """Apply dropout to the sub-layer output, add the residual, normalise."""
        return norm(residual + self.dropout(sublayer_output))

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None, is_causal=True):
        """Transform ``x`` using ``encoder_output`` as keys/values of the cross-attention.

        ``tgt_mask`` and ``is_causal`` apply to the self-attention, ``src_mask``
        to the cross-attention.
        """
        self_attended, _ = self.self_attention(
            x, x, x, mask=tgt_mask, is_causal=is_causal
        )
        hidden = self._add_and_norm(self.norm1, x, self_attended)

        cross_attended, _ = self.cross_attention(
            hidden, encoder_output, encoder_output, mask=src_mask
        )
        hidden = self._add_and_norm(self.norm2, hidden, cross_attended)

        return self._add_and_norm(self.norm3, hidden, self.ffn(hidden))

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer with fixed sinusoidal positional encodings.

    Args:
        src_vocab_size: Number of rows of the source embedding table.
        tgt_vocab_size: Number of rows of the target embedding table and size
            of the output logits.
        d_model: Embedding / hidden size.
        num_heads: Attention heads per block.
        num_encoder_blocks: Number of stacked ``EncoderBlock`` layers.
        num_decoder_blocks: Number of stacked ``DecoderBlock`` layers.
        d_ff: Hidden size of the feed-forward sub-layers.
        max_seq_len: Number of positions in the positional-encoding table.
        dropout: Dropout probability used throughout the model.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8,
                 num_encoder_blocks=6, num_decoder_blocks=6, d_ff=2048, max_seq_len=5000, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # A non-persistent buffer follows the module across .to(device) calls
        # (no per-forward copy) and stays out of state_dict, so checkpoints
        # keep exactly the same keys as before.
        self.register_buffer(
            "pos_encoding",
            self._create_pos_encoding(max_seq_len, d_model),
            persistent=False,
        )

        self.encoder = nn.ModuleList([
            EncoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_encoder_blocks)
        ])
        self.decoder = nn.ModuleList([
            DecoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_blocks)
        ])

        self.output_layer = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def _create_pos_encoding(self, max_len, d_model):
        """Build the ``(1, max_len, d_model)`` table of sine/cosine encodings.

        Even feature indices hold sines and odd indices hold cosines of
        ``position / 10000^(2i / d_model)``.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                           (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe.unsqueeze(0)

    def _embed(self, embedding, token_ids):
        """Look up, scale by sqrt(d_model), add positions and apply dropout."""
        embedded = embedding(token_ids) * math.sqrt(self.d_model)
        embedded = embedded + self.pos_encoding[:, :token_ids.size(1), :]
        return self.dropout(embedded)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, is_causal=True):
        """Return logits over the target vocabulary for every position of ``tgt``.

        Args:
            src: Source token ids, indexed as ``(batch, position)``.
            tgt: Target (decoder input) token ids, indexed as ``(batch, position)``.
            src_mask: Optional mask for the encoder self-attention and the
                decoder cross-attention.
            tgt_mask: Optional mask for the decoder self-attention.
            is_causal: Apply causal masking in the decoder self-attention.
        """
        encoder_output = self._embed(self.src_embedding, src)
        for encoder_layer in self.encoder:
            encoder_output = encoder_layer(encoder_output, src_mask)

        decoder_output = self._embed(self.tgt_embedding, tgt)
        for decoder_layer in self.decoder:
            decoder_output = decoder_layer(
                decoder_output, encoder_output, src_mask, tgt_mask, is_causal
            )

        output = self.output_layer(decoder_output)

        return output

In [None]:
def test_attention():
    """Smoke-test plain, causal and multi-head attention on random inputs."""
    batch_size, seq_len, d_model = 2, 10, 64
    num_heads = 8
    input_shape = (batch_size, seq_len, d_model)

    Q, K, V = (torch.randn(*input_shape) for _ in range(3))

    output, weights = scaled_dot_product_attention(Q, K, V)
    assert output.shape == input_shape, "Attention output shape wrong"
    assert weights.shape == (batch_size, seq_len, seq_len), "Attention weights shape wrong"

    # Hiding future positions has to change the result.
    causal_output, causal_weights = scaled_dot_product_attention(Q, K, V, is_causal=True)
    assert not torch.allclose(output, causal_output), "Causal masking not working"

    mha_output, mha_weights = MultiHeadAttention(d_model, num_heads)(Q, K, V)
    assert mha_output.shape == input_shape, "MHA output shape wrong"

    print("All attention tests passed!")

def test_transformer():
    """Check the shape of the logits returned by a Transformer forward pass."""
    src_vocab_size, tgt_vocab_size = 1000, 1000
    batch_size, src_len, tgt_len = 2, 10, 8
    expected_shape = (batch_size, tgt_len, tgt_vocab_size)

    transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8)

    src = torch.randint(0, src_vocab_size, (batch_size, src_len))
    tgt = torch.randint(0, tgt_vocab_size, (batch_size, tgt_len))

    logits = transformer(src, tgt)
    assert logits.shape == expected_shape, "Transformer output shape wrong"

    print("All transformer tests passed!")

In [None]:
# Run the attention and Transformer smoke tests; each prints a line on success.
test_attention()
test_transformer()

All attention tests passed!
All transformer tests passed!


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from torch.optim import Adam
import math

class WikiTextDataset(Dataset):
    """Non-empty WikiText-2 lines encoded as ``(source_ids, target_ids)`` pairs.

    Source and target hold the same token sequence.

    Args:
        tokenizer: Object with ``encode(text)``, ``vocab`` and ``pad_token``.
        split: Dataset split to load.
        max_samples: Number of raw rows to load; blank rows are dropped
            afterwards, so the dataset can be shorter than this.
        max_length: Maximum number of token ids kept per item.
    """

    def __init__(self, tokenizer, split="train", max_samples=2500, max_length=512):
        self.dataset = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split=f"{split}[:{max_samples}]")
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.texts = [text for text in self.dataset['text'] if text.strip()]

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return two equal 1-D id tensors for the text at ``idx``."""
        # Source and target are the same text, so it is encoded only once
        # instead of running the tokenizer twice per item.
        token_ids = self.tokenizer.encode(self.texts[idx])[:self.max_length]

        if len(token_ids) < 2:
            # Items shorter than two tokens are replaced by two padding ids.
            token_ids = [self.tokenizer.vocab[self.tokenizer.pad_token]] * 2

        return torch.tensor(token_ids), torch.tensor(token_ids)

def _pad_collate(batch, pad_token_id):
    """Right-pad sources and targets to their longest item and stack them.

    Args:
        batch: Sequence of ``(source_ids, target_ids)`` pairs of 1-D tensors.
        pad_token_id: Id used to fill the shorter sequences.

    Returns:
        Tuple ``(padded_src, padded_tgt)`` of 2-D tensors; sources and targets
        are padded independently of each other.
    """
    def pad_and_stack(sequences):
        longest = max(len(seq) for seq in sequences)
        rows = []
        for seq in sequences:
            filler = torch.full((longest - len(seq),), pad_token_id, dtype=torch.long)
            rows.append(torch.cat([seq, filler]))
        return torch.stack(rows)

    sources, targets = zip(*batch)

    return pad_and_stack(sources), pad_and_stack(targets)

$$Perplexity = exp(CrossEntropy)$$
Чем меньше - тем лучше

In [None]:
def calculate_perplexity(
    model: Transformer,
    dataset: Dataset,
    tokenizer: Tokenizer,
    device: str,
    max_length: int = 512
) -> float:
    """Compute ``exp(mean next-token cross-entropy)`` of ``model`` on ``dataset``.

    The model is switched to eval mode and left in it.

    Args:
        model: Called as ``model(src, tgt_input)``; must return logits whose
            last dimension is the vocabulary.
        dataset: Yields ``(source_ids, target_ids)`` pairs; evaluated one item
            at a time.
        tokenizer: Provides the padding id via ``vocab[pad_token]``.
        device: Device the batches are moved to (anything ``Tensor.to`` accepts).
        max_length: Targets longer than this are truncated.

    Returns:
        The perplexity over all non-padding target tokens, or ``inf`` when
        there are none.
    """
    # Both are loop-invariant, so they are set up once rather than per sample.
    pad_id = tokenizer.vocab[tokenizer.pad_token]
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id, reduction='sum')

    model.eval()
    total_loss = 0
    total_tokens = 0

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    with torch.no_grad():
        for src, tgt in dataloader:
            src, tgt = src.to(device), tgt.to(device)

            # Slicing is a no-op for targets that already fit.
            tgt = tgt[:, :max_length]

            if tgt.size(1) <= 1:
                # No (input, next token) pair can be formed.
                continue

            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]

            output = model(src, tgt_input)

            loss = loss_fn(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))

            total_loss += loss.item()
            total_tokens += (tgt_output != pad_id).sum().item()

    if total_tokens == 0:
        return float('inf')

    return math.exp(total_loss / total_tokens)

In [None]:
def train(
    epochs: int = 3,
    batch_size: int = 8,
    num_encoder_blocks: int = 4,
    num_decoder_blocks: int = 4,
    num_heads: int = 8,
    d_ff: int = 1024
) -> Tuple[Transformer, Tokenizer]:
    """Train a BPE tokenizer and a Transformer on a WikiText-2 subset.

    The tokenizer is fitted on the first 1000 training rows with a vocabulary
    of 5000 tokens. The model (``d_model=256``) is trained with Adam
    (``lr=1e-4``) to predict each next target token, and validation perplexity
    is printed after every epoch.

    Args:
        epochs: Number of passes over the training set.
        batch_size: Training batch size.
        num_encoder_blocks: Encoder depth.
        num_decoder_blocks: Decoder depth.
        num_heads: Attention heads per block.
        d_ff: Hidden size of the feed-forward sub-layers.

    Returns:
        The trained model and the tokenizer it was trained with.
    """
    vocab_size = 5000
    tokenizer = Tokenizer()

    print("Loading Salesforce/wikitext dataset for tokenizer training...")
    tokenizer_corpus = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train[:1000]")
    all_text = " ".join(line for line in tokenizer_corpus['text'] if line.strip())

    print("Training tokenizer on Salesforce/wikitext data...")
    tokenizer.train(all_text, vocab_size=vocab_size)
    pad_id = tokenizer.vocab[tokenizer.pad_token]

    train_dataset = WikiTextDataset(tokenizer, "train", max_samples=2500)
    eval_dataset = WikiTextDataset(tokenizer, "validation", max_samples=500)

    transformer = Transformer(
        src_vocab_size=vocab_size,
        tgt_vocab_size=vocab_size,
        d_model=256,
        num_heads=num_heads,
        num_encoder_blocks=num_encoder_blocks,
        num_decoder_blocks=num_decoder_blocks,
        d_ff=d_ff,
        dropout=0.1
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transformer.to(device)

    optimizer = Adam(transformer.parameters(), lr=1e-4)
    # Padding positions do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=pad_id)

    def collate(batch):
        return _pad_collate(batch, pad_id)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=collate)

    for epoch in range(epochs):
        transformer.train()
        running_loss = 0

        for batch_idx, (src, tgt) in enumerate(train_loader):
            src = src.to(device)
            tgt = tgt.to(device)

            # The decoder sees tokens [0, n-1) and is scored on tokens [1, n).
            decoder_input, expected = tgt[:, :-1], tgt[:, 1:]

            optimizer.zero_grad()

            logits = transformer(src, decoder_input)
            loss = criterion(logits.reshape(-1, logits.size(-1)), expected.reshape(-1))

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")

        avg_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1}. Average Loss: {avg_loss:.4f}")

        perplexity = calculate_perplexity(transformer, eval_dataset, tokenizer, device)
        print(f"Epoch {epoch+1} Perplexity: {perplexity:.2f}")

    return transformer, tokenizer

In [None]:
# Train with the default hyperparameters; yields the model and its tokenizer.
transformer_model, bpe_tokenizer = train()

# Persist the weights together with the tokenizer's `vocab` and `merges`.
# No other tokenizer attribute (e.g. `inverse_vocab`) is stored in the file.
torch.save({
    'model_state_dict': transformer_model.state_dict(),
    'tokenizer_vocab': bpe_tokenizer.vocab,
    'tokenizer_merges': bpe_tokenizer.merges
}, 'transformer_wikitext.pth')


Loading Salesforce/wikitext dataset for tokenizer training...


README.md: 0.00B [00:00, ?B/s]

wikitext-2-raw-v1/test-00000-of-00001.pa(…):   0%|          | 0.00/733k [00:00<?, ?B/s]

wikitext-2-raw-v1/train-00000-of-00001.p(…):   0%|          | 0.00/6.36M [00:00<?, ?B/s]

wikitext-2-raw-v1/validation-00000-of-00(…):   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Training tokenizer on Salesforce/wikitext data...


  0%|          | 0/5000 [00:00<?, ?it/s]

Epoch 1, Batch 0, Loss: 8.6789
Epoch 1, Batch 50, Loss: 6.6234
Epoch 1, Batch 100, Loss: 5.8077
Epoch 1, Batch 150, Loss: 5.6281
Epoch 1, Batch 200, Loss: 5.5021
Epoch 1. Average Loss: 6.2151
Epoch 1 Perplexity: 242.38
Epoch 2, Batch 0, Loss: 5.4134
Epoch 2, Batch 50, Loss: 5.2919
Epoch 2, Batch 100, Loss: 5.0738
Epoch 2, Batch 150, Loss: 5.0843
Epoch 2, Batch 200, Loss: 5.2055
Epoch 2. Average Loss: 5.1875
Epoch 2 Perplexity: 160.48
Epoch 3, Batch 0, Loss: 4.9049
Epoch 3, Batch 50, Loss: 4.7686
Epoch 3, Batch 100, Loss: 4.7907
Epoch 3, Batch 150, Loss: 4.6506
Epoch 3, Batch 200, Loss: 4.6257
Epoch 3. Average Loss: 4.8083
Epoch 3 Perplexity: 124.30


In [None]:
from transformers import BertForMaskedLM, BertTokenizer

class WikiTextDatasetForBERT(Dataset):
    """Non-empty WikiText-2 lines tokenized and padded to ``max_length``.

    Args:
        tokenizer: Callable tokenizer that returns ``input_ids`` and
            ``attention_mask`` tensors when called with ``return_tensors='pt'``.
        split: Dataset split to load.
        max_samples: Number of raw rows to load; blank rows are dropped
            afterwards, so the dataset can be shorter than this.
        max_length: Length every item is padded or truncated to.
    """

    def __init__(self, tokenizer, split="train", max_samples=2500, max_length=512):
        self.dataset = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split=f"{split}[:{max_samples}]")
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.texts = [text for text in self.dataset['text'] if text.strip()]

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``input_ids`` and ``attention_mask`` for the text at ``idx``."""
        encoding = self.tokenizer(
            self.texts[idx],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Drop only the leading batch dimension; a bare squeeze() would also
        # collapse the sequence dimension when it has length 1.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0)
        }

model_name = "bert-base-uncased"
bert_tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertForMaskedLM.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert_model.to(device)
bert_model.eval()

eval_dataset = WikiTextDatasetForBERT(bert_tokenizer, "validation", max_samples=500)
dataloader = DataLoader(eval_dataset, batch_size=8, shuffle=False)
total_loss = 0
total_tokens = 0

# NOTE: the inputs are passed unmasked, so every scored token is visible to
# the model. Treat the result as a rough reference, not as a language-model
# perplexity that is directly comparable to the decoder trained above.
with torch.no_grad():
    for batch_idx, batch in enumerate(dataloader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # Label -100 is ignored by the loss, which keeps padding positions
        # out of the average.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        outputs = bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        # outputs.loss is the mean over the scored tokens of this batch;
        # weighting it by their count turns the running total into a sum over
        # tokens, so batches of different sizes are averaged correctly.
        loss = outputs.loss
        batch_tokens = attention_mask.sum().item()
        total_loss += loss.item() * batch_tokens
        total_tokens += batch_tokens

avg_loss = total_loss / total_tokens if total_tokens else float('inf')
perplexity = math.exp(avg_loss)

print(f"BERT Perplexity: {perplexity:.2f}")

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BERT Perplexity: 8.88


In [None]:
# Show the module hierarchy of the Transformer returned by train().
print(transformer_model)

Transformer(
  (src_embedding): Embedding(5000, 256)
  (tgt_embedding): Embedding(5000, 256)
  (encoder): ModuleList(
    (0-3): 4 x EncoderBlock(
      (self_attention): MultiHeadAttention(
        (W_q): Linear(in_features=256, out_features=256, bias=True)
        (W_k): Linear(in_features=256, out_features=256, bias=True)
        (W_v): Linear(in_features=256, out_features=256, bias=True)
        (W_o): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffn): PositionWiseFFN(
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (linear2): Linear(in_features=1024, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation): ReLU()
      )
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (decoder): ModuleList(
    (0-3): 4 x DecoderBlock(
      (self_attent

In [None]:
# Show the module hierarchy of the pretrained BERT model for comparison.
print(bert_model)

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi