In [69]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch._dynamo as dynamo
import math
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding, AutoTokenizer, TrainerCallback, pipeline
from datasets import load_dataset, interleave_datasets
from datasets import Dataset as HFDataset
from torch.utils.data import Dataset, Subset
from torch.cuda.amp import autocast
import os
import random
import numpy as np
import re
import zlib
from typing import List, Dict
import wandb
import time
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from itertools import islice
import pickle
import json
import matplotlib.pyplot as plt

In [70]:
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.cuda.set_per_process_memory_fraction(0.95, device=0)
torch.autograd.set_detect_anomaly(True)
#os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
#torch._inductor.config.triton.unique_kernel_names = True
#torch._inductor.config.triton.cudagraphs = True

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x79af5954d2d0>

In [71]:
class RotaryPositionalEmbeddings(nn.Module):
    """
    Rotary Positional Embeddings (RoPE).

    Instead of adding a position vector to each token, RoPE rotates every consecutive
    pair of channels of a query/key vector by an angle proportional to the token's
    position. This makes attention scores depend on relative offsets between positions.

    The rotations for `max_seq_len` positions are precomputed once. They are stored as
    unit complex numbers in the `freqs_cis` buffer of shape [max_seq_len, dim / 2].
    `dim` is expected to be even, since channels are paired.
    """
    def __init__(self, dim: int, max_seq_len: int = 4096, base: int = 10000):
        super().__init__()

        # One inverse frequency per channel pair: base ** (-2i / dim)
        exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
        inv_freq = 1.0 / (base ** exponents)

        # Angle table: angles[pos, i] = pos * inv_freq[i]
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)

        # e^{i*angle} = cos(angle) + i*sin(angle), with unit magnitude
        unit_rotations = torch.polar(torch.ones_like(angles), angles)
        self.register_buffer("freqs_cis", unit_rotations)

    def forward(self, x: torch.Tensor, start_pos: int = 0):
        """
        Rotate `x` ([batch, seq_len, num_heads, head_dim]) for the absolute positions
        start_pos .. start_pos + seq_len - 1.

        The rotation is computed in float32, and the result is cast back to x's dtype.
        """
        seq_len = x.shape[1]

        # View each channel pair (a, b) as the complex number a + ib
        channel_pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        x_complex = torch.view_as_complex(channel_pairs)

        # Rotations for the requested window, broadcast over batch and heads
        rotations = self.freqs_cis[start_pos : start_pos + seq_len]
        rotations = rotations[None, :, None, :]

        # Complex multiplication applies the rotation; then unpack back to real channels
        rotated = torch.view_as_real(x_complex * rotations).flatten(3)
        return rotated.type_as(x)

In [72]:
class SwiGLU(nn.Module):
    """
    SwiGLU feed-forward block: w2( SiLU(w1(x)) * w3(x) ).

    The SiLU-activated branch w1 gates the linear branch w3 element-wise. The product is
    then projected back to `dim` by w2.

    Args:
        dim: input/output width.
        hidden_dim: inner width. When None, it defaults to int(2 * 4 * dim / 3), rounded
            up to the next multiple of `multiple_of`. An explicit value is used as-is
            and is not rounded.
        multiple_of: rounding granularity for the default hidden width.
        bias: whether the three linear layers have bias terms.
    """
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256, bias: bool = False):
        super().__init__()
        if hidden_dim is None:
            # Two thirds of the classic 4*dim FFN width, then ceil to a multiple of `multiple_of`
            hidden_dim = int(2 * (4 * dim) / 3)
            hidden_dim = -(-hidden_dim // multiple_of) * multiple_of

        # Registration order (w1, w2, w3) is kept so parameter order and init RNG usage match.
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)   # gate branch (SiLU)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)   # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)   # linear branch

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))

In [73]:
class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention (GQA) module with Rotary Positional Embeddings (RoPE) and KV caching.
    Uses multiple query heads grouped to share a smaller number of key and value heads.
    This reduces memory and compute cost while maintaining better quality than Multi-Query Attention (MQA).
    Each group of query heads attends to the same key/value head, enabling faster inference with improved flexibility.
    """
    def __init__(self, d_model: int, n_head: int, num_kv_heads: int, rope: RotaryPositionalEmbeddings | None, bias: bool = False):
        super().__init__()
        assert n_head % num_kv_heads == 0, "n_head must be divisible by num_kv_heads"
        assert d_model % n_head == 0, "d_model must be divisible by n_head"
        self.n_head = n_head
        self.kv_head = num_kv_heads
        self.d_model = d_model
        self.head_dim = d_model // n_head
        self.rope = rope

        # Query projection: one per head (shape: [d_model, n_head*head_dim])
        self.wq = nn.Linear(d_model, n_head * self.head_dim, bias=bias)
        
        # Single key and value projections shared across heads (shape: [d_model, head_dim])
        self.wk = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=bias)
        self.wv = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=bias)

        # Output projection to restore original dimension
        self.wo = nn.Linear(n_head * self.head_dim, d_model, bias=bias)

    def forward(
        self, 
        x_q: torch.Tensor, 
        x_kv: torch.Tensor | None = None, 
        mask: torch.Tensor | None = None, 
        kv_cache: tuple | None = None, 
        is_causal: bool = False
    ):
        # In self-attention, K and V are derived from the same input as Q. 
        is_self_attention = x_kv is None
        if is_self_attention:
            x_kv = x_q

        batch_size, seq_len_q, _ = x_q.shape
        _, seq_len_kv, _ = x_kv.shape
        
        # Compute Q, K, V projections
        xq, xk, xv = self.wq(x_q), self.wk(x_kv), self.wv(x_kv)
        
        # Reshape Q: [batch, seq_len, n_head * head_dim] → [batch, seq_len, n_head, head_dim]
        xq = xq.view(batch_size, seq_len_q, self.n_head, self.head_dim)
        # Reshape K, V: [batch, seq_len, head_dim] → [batch, seq_len, num_kv_heads, head_dim]
        xk = xk.view(batch_size, seq_len_kv, self.kv_head, self.head_dim)
        xv = xv.view(batch_size, seq_len_kv, self.kv_head, self.head_dim)
        
        # Apply Rotary Positional Embeddings to Q and K only if rope handed
        if self.rope is not None:
            start_pos = 0
            if kv_cache is not None and kv_cache[0] is not None:
                start_pos = kv_cache[0].shape[1]
            xq = self.rope(xq, start_pos)
            xk = self.rope(xk, start_pos)
            

        # Key-Value caching for efficient autoregressive decoding
        if is_self_attention and kv_cache is not None and kv_cache[0] is not None:
            cached_k, cached_v = kv_cache
            # Append new keys and values to cached tensors along sequence dimension
            xk = torch.cat([cached_k, xk], dim=1)
            xv = torch.cat([cached_v, xv], dim=1)
        
        # Update cache to return
        updated_kv_cache = (xk.clone().detach(), xv.clone().detach())

        # enable_gqa=True in scaled_dot_product_attention will do this:
        # Repeat K and V across all heads to match Q's head dimension
        # Shape: [batch, seq_len, num_kv_heads, head_dim] → [batch, seq_len, n_head, head_dim]
        #n_repeats = self.n_head // self.kv_head
        keys=xk#keys = xk.repeat_interleave(n_repeats, dim=2)
        values=xv#values = xv.repeat_interleave(n_repeats, dim=2)

        # Prepare tensors for scaled dot-product attention
        # Transpose to shape [batch, n_head, seq_len, head_dim]
        xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)

        # Perform attention with optional causal masking
        if is_causal:
            device = xq.device
            full_seq_len_kv = xk.shape[1]
            causal_mask = torch.tril(torch.ones((1, 1, seq_len_q, full_seq_len_kv), device=device, dtype=torch.bool), diagonal=0)
            if mask is not None:
                mask = mask & causal_mask
            else:
                mask = causal_mask

        if seq_len_q == 1:
            mask = None
        output = F.scaled_dot_product_attention(xq, keys, values, attn_mask=mask, enable_gqa=True)
        
        # Reshape back to [batch, seq_len, n_head * head_dim]
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, -1)

        # Final linear projection to output dimension
        return self.wo(output), updated_kv_cache

In [74]:
class ModernEncoderLayer(nn.Module):
    """
    Modern Transformer Encoder Layer employing:
    - Pre-LayerNorm for better training stability and gradient flow,
      replacing the original post-LN design.
    - Dropout strategically placed after attention and feed-forward blocks
      to prevent overfitting.
      """
    def __init__(self, d_model: int, n_head: int, num_kv_heads: int, rope: RotaryPositionalEmbeddings, dropout: float = 0.1):
        super().__init__()
        self.self_attn = GroupedQueryAttention(d_model, n_head, num_kv_heads, rope)
        self.feed_forward = SwiGLU(dim=d_model, hidden_dim=None)
        self.attention_norm = nn.LayerNorm(d_model)
        self.ffn_norm = nn.LayerNorm(d_model)
        self.attn_dropout = nn.Dropout(dropout)
        self.ffn_dropout = nn.Dropout(dropout)
    def forward(self, x: torch.Tensor, mask: torch.Tensor | None):
        # Pre-LN
        x = self.attention_norm(x)
        h, _ = self.self_attn(x_q=x, mask=mask)
        x = x + self.attn_dropout(h)
        out = x + self.ffn_dropout(self.feed_forward(self.ffn_norm(x)))
        return out

In [75]:
class ModernDecoderLayer(nn.Module):
    """
    Pre-LayerNorm Transformer decoder layer with three residual sub-layers:

      1. causal self-attention (RoPE, optional KV cache),
      2. cross-attention over the encoder output `memory` (no RoPE),
      3. SwiGLU feed-forward.

    Each sub-layer receives a LayerNorm-ed copy of the residual stream. Its output goes
    through dropout and is then added back to the stream.
    """
    def __init__(self, d_model: int, n_head: int, num_kv_heads: int, rope: RotaryPositionalEmbeddings, dropout: float = 0.1):
        super().__init__()
        # Sub-layers (registration order kept: it fixes parameter order)
        self.self_attn = GroupedQueryAttention(d_model, n_head, num_kv_heads, rope=rope)
        self.cross_attn = GroupedQueryAttention(d_model, n_head, num_kv_heads, rope=None)
        self.feed_forward = SwiGLU(dim=d_model, hidden_dim=None)

        # One pre-norm per sub-layer
        self.sa_norm, self.ca_norm, self.ffn_norm = (nn.LayerNorm(d_model) for _ in range(3))

        # One dropout per residual branch
        self.sa_dropout, self.ca_dropout, self.ffn_dropout = (nn.Dropout(dropout) for _ in range(3))

    def forward(
        self, 
        x: torch.Tensor, 
        memory: torch.Tensor,
        src_mask: torch.Tensor | None,
        tgt_mask: torch.Tensor | None,
        self_attn_kv_cache: tuple | None = None,
    ):
        """
        Returns:
            (updated hidden states, self-attention KV cache returned by `self_attn`).
        """
        # 1) Causal self-attention over the (possibly cached) target sequence
        attn_out, updated_sa_kv = self.self_attn(
            x_q=self.sa_norm(x), mask=tgt_mask, kv_cache=self_attn_kv_cache, is_causal=True
        )
        x = x + self.sa_dropout(attn_out)

        # 2) Cross-attention: target queries attend to encoder memory under src_mask
        if self.cross_attn is not None:
            cross_out, _ = self.cross_attn(x_q=self.ca_norm(x), x_kv=memory, mask=src_mask)
            x = x + self.ca_dropout(cross_out)

        # 3) Position-wise feed-forward
        ffn_out = self.feed_forward(self.ffn_norm(x))
        return x + self.ffn_dropout(ffn_out), updated_sa_kv

In [76]:
class ModernEncoderDecoder(nn.Module):
    """
    Encoder-decoder Transformer assembled from ModernEncoderLayer / ModernDecoderLayer.

    - One nn.Embedding is shared by the encoder and decoder inputs. `fc_out` is a separate
      (untied) projection to the vocabulary.
    - A single RotaryPositionalEmbeddings instance (head_dim = d_model // n_head) is handed
      to every layer. Decoder cross-attention is built without RoPE.
    - `forward` returns `(loss, logits)`. The loss is label-smoothed cross-entropy that
      ignores label id 0 (PAD_TOKEN_ID in this notebook).
    """
    def __init__(self, vocab_size: int, d_model: int, n_head: int, num_kv_heads: int, num_encoder_layers: int, num_decoder_layers: int, dropout: float = 0.1, max_seq_len: int = 1024, label_smothing=0.08):
        super().__init__()
        self.d_model = d_model
        # Minimal attribute bag standing in for a config object
        self.config = lambda: None
        self.config.vocab_size = vocab_size

        self.embedding = nn.Embedding(vocab_size, d_model)
        rope = RotaryPositionalEmbeddings(d_model // n_head, max_seq_len=max_seq_len)

        self.encoder_layers = nn.ModuleList([
            ModernEncoderLayer(d_model, n_head, num_kv_heads, rope, dropout) for _ in range(num_encoder_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            ModernDecoderLayer(d_model, n_head, num_kv_heads, rope, dropout) for _ in range(num_decoder_layers)
        ])
        self.embedding_dropout = nn.Dropout(dropout)
        self.encoder_norm = nn.LayerNorm(d_model, eps=5e-5)
        self.decoder_norm = nn.LayerNorm(d_model, eps=5e-5)
        self.fc_out = nn.Linear(d_model, vocab_size, bias=False)

        self.apply(self._init_weights)
        # NOTE(review): this resets EVERY LayerNorm eps to 1e-5, overriding the eps=5e-5
        # given to encoder_norm/decoder_norm above. Kept as-is; confirm which is intended.
        self.apply(self._apply_eps)

        self.label_smoothing=label_smothing

    @staticmethod
    def _key_padding_mask(attention_mask):
        """
        Convert a [batch, seq] 1/0 attention mask into a [batch, 1, 1, seq] boolean key mask.

        True marks keys that may be attended to. The singleton query axis broadcasts over
        any number of queries. The same mask therefore serves encoder self-attention,
        decoder self-attention (combined with the causal mask inside the attention) and
        cross-attention, whose query length may differ from the key length. Returns None
        when `attention_mask` is None.
        """
        if attention_mask is None:
            return None
        return (attention_mask == 1)[:, None, None, :]

    def forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, labels=None):
        """
        Args:
            input_ids: encoder token ids, presumably [batch, src_len].
            attention_mask: encoder mask, 1 = real token.
            decoder_input_ids: decoder token ids, presumably [batch, tgt_len].
            decoder_attention_mask: decoder mask, 1 = real token.
            labels: optional target ids aligned with decoder_input_ids. Id 0 is ignored in the loss.

        Returns:
            (loss or None, logits [batch, tgt_len, vocab_size])
        """
        # Padding masks hide padded KEYS only. The previous `m[:, None, :] | m[:, :, None]`
        # construction let every real query attend to padding tokens. Every row keeps at
        # least the first (non-pad) token, so no attention row is fully masked.
        src_mask = self._key_padding_mask(attention_mask)
        tgt_mask = self._key_padding_mask(decoder_attention_mask)

        # Encoder
        h = self.embedding_dropout(self.embedding(input_ids))
        for layer in self.encoder_layers:
            h = layer(h, src_mask)
        memory = self.encoder_norm(h)

        # Decoder (the layer's self-attention adds the causal mask itself)
        h = self.embedding_dropout(self.embedding(decoder_input_ids))
        for layer in self.decoder_layers:
            h, _ = layer(h, memory, src_mask=src_mask, tgt_mask=tgt_mask)
        h = self.decoder_norm(h)

        logits = self.fc_out(h)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1),
                ignore_index=0,
                label_smoothing=self.label_smoothing,
            )

        return loss, logits

    def _init_weights(self, module):
        """
        Weight initialisation, applied to every submodule via `self.apply`.

        Linear/Embedding weights ~ N(0, 0.02), biases zeroed, LayerNorm set to identity.
        The residual output projections (GQA `wo`, SwiGLU `w2`) are then re-drawn with
        std 0.02 / sqrt(2 * num_decoder_layers), the GPT-2 depth-scaled init.
        nn.Module.apply visits children before their parent, so the generic Linear init
        of `wo`/`w2` has already run when the parent module is reached.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
            if module.weight is not None:
                torch.nn.init.ones_(module.weight)

        elif isinstance(module, (GroupedQueryAttention, SwiGLU)):
            # Previously unreachable (nested under the nn.Linear branch) and referenced an
            # unset config.num_decoder_layers.
            out_proj = module.wo if isinstance(module, GroupedQueryAttention) else module.w2
            num_layers = max(1, len(self.decoder_layers))
            torch.nn.init.normal_(out_proj.weight, mean=0.0, std=0.02 / math.sqrt(2 * num_layers))

    def _apply_eps(self, module):
        # Force a uniform LayerNorm epsilon across the model.
        if isinstance(module, nn.LayerNorm):
            module.eps = 1e-5

    @staticmethod
    def _select_next_token(next_token_logits, temperature, top_k, top_p, do_sample):
        """
        Pick one token id per row from [batch, vocab] logits. Returns [batch, 1].

        Greedy argmax when `do_sample` is False. Otherwise the logits are temperature-scaled,
        optionally restricted by top-k and by nucleus (top-p) filtering, and sampled.
        """
        if not do_sample:
            return torch.argmax(next_token_logits, dim=-1, keepdim=True)

        next_token_logits = next_token_logits / temperature
        if top_k > 0:
            k = min(top_k, next_token_logits.size(-1))
            kth_value = torch.topk(next_token_logits, k).values[:, -1:]
            next_token_logits = next_token_logits.masked_fill(next_token_logits < kth_value, float('-inf'))
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept; always keep the best one.
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False
            remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
            next_token_logits = next_token_logits.masked_fill(remove, float('-inf'))
        probs = F.softmax(next_token_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(
        self, 
        input_ids, 
        decoder_input_ids=None,
        max_new_tokens=50, 
        eos_token_id=4,
        temperature=0.8,
        top_k=40,
        top_p=0.9,
        do_sample=True,
        attention_mask=None,
        pad_token_id=0,
    ):
        """
        Autoregressive decoding with a self-attention KV cache.

        Args:
            input_ids: encoder token ids [batch, src_len].
            decoder_input_ids: optional decoder prompt. Defaults to the first encoder token
                (the language token in this notebook's data format).
            max_new_tokens: maximum number of decoding steps.
            eos_token_id: a sequence is finished once it emits this id.
            temperature, top_k, top_p, do_sample: sampling controls (see _select_next_token).
            attention_mask: optional encoder mask (1 = real token) for padded batches.
            pad_token_id: emitted for sequences that have already finished.

        Returns:
            [batch, prompt_len + n_steps] tensor: the decoder prompt followed by generated
            ids. Decoding stops once every sequence has emitted EOS or after
            `max_new_tokens` steps. The module's train/eval mode is restored on exit.
        """
        was_training = self.training
        self.eval()
        try:
            device = input_ids.device
            src_mask = self._key_padding_mask(attention_mask)

            # Encoder pass (once)
            h = self.embedding(input_ids)
            for layer in self.encoder_layers:
                h = layer(h, mask=src_mask)
            memory = self.encoder_norm(h)

            kv_cache = [(None, None)] * len(self.decoder_layers)

            if decoder_input_ids is None:
                current_tokens = input_ids[:, :1]
            else:
                current_tokens = decoder_input_ids.to(device)

            # Warm up the KV cache with all prompt tokens except the last one
            if current_tokens.size(1) > 1:
                h = self.embedding(current_tokens[:, :-1])
                for i, layer in enumerate(self.decoder_layers):
                    h, kv_cache[i] = layer(h, memory, src_mask=src_mask, tgt_mask=None, self_attn_kv_cache=None)
                next_token_to_process = current_tokens[:, -1:]
            else:
                next_token_to_process = current_tokens.clone()

            generated_ids = [current_tokens]
            finished = torch.zeros(current_tokens.size(0), dtype=torch.bool, device=device)
            for _ in range(max_new_tokens):
                h = self.embedding(next_token_to_process)
                for i, layer in enumerate(self.decoder_layers):
                    h, kv_cache[i] = layer(h, memory, src_mask=src_mask, tgt_mask=None, self_attn_kv_cache=kv_cache[i])

                logits = self.fc_out(self.decoder_norm(h))  # [batch, 1, vocab_size]
                next_token = self._select_next_token(logits[:, -1, :], temperature, top_k, top_p, do_sample)

                # Sequences that already produced EOS only emit padding from now on
                next_token = next_token.masked_fill(finished.unsqueeze(-1), pad_token_id)
                generated_ids.append(next_token)

                finished |= (next_token.squeeze(-1) == eos_token_id)
                if finished.all():
                    break
                next_token_to_process = next_token

            return torch.cat(generated_ids, dim=1)
        finally:
            self.train(was_training)

In [77]:
tokenizer = Tokenizer.from_file("tokenizer_clean.json")

# Special-token ids, listed in id order (PAD=0, UNK=1, EN=2, DE=3, EOS=4, MASK=5).
# NOTE(review): these must agree with the ids stored in tokenizer_clean.json; confirm there.
(
    PAD_TOKEN_ID,
    UNK_TOKEN_ID,
    EN_TOKEN_ID,
    DE_TOKEN_ID,
    EOS_TOKEN_ID,
    MASK_TOKEN_ID,
) = range(6)

BATCH_SIZE = 26

In [78]:
# Architecture hyper-parameters.
# NOTE(review): vocab_size is hard-coded; confirm it matches the loaded tokenizer's vocabulary.
model_config = dict(
    vocab_size=30000,
    d_model=640,
    n_head=8,
    num_kv_heads=2,
    num_encoder_layers=10,
    num_decoder_layers=12,
)
model = ModernEncoderDecoder(**model_config)
n_trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Model initialized. Number of parameters: {n_trainable_params:,}")

Model initialized. Number of parameters: 148,984,320


In [79]:
# Stream the English and German OSCAR splits, tag each record with its language,
# shuffle within a bounded buffer, and interleave the two streams 50/50.
buffer_size = 10000 
seed = 42


def _oscar_stream(config_name, lang_code):
    """Return a shuffled, language-tagged streaming OSCAR split (lazy; nothing is downloaded yet)."""
    stream = load_dataset(
        'oscar',
        name=config_name,
        split='train',
        streaming=True
    )
    stream = stream.map(lambda ex: {"text": ex["text"], "lang": lang_code})
    return stream.shuffle(seed=seed, buffer_size=buffer_size)


en = _oscar_stream('unshuffled_deduplicated_en', "en")
de = _oscar_stream('unshuffled_deduplicated_de', "de")

# Stop as soon as either language runs out
streaming_dataset = interleave_datasets(
    [en, de],
    probabilities=[0.5, 0.5],
    stopping_strategy="first_exhausted",
    seed=seed
)

# Stop words and patterns: any match anywhere in a text rejects it (see filter_texts).
# The pattern is compiled with re.VERBOSE, which IGNORES literal whitespace. Every space
# inside a multi-word phrase must therefore be written as \s+. Without it, "privacy policy"
# would only match the glued string "privacypolicy".
GENERAL_BAD_PATTERNS = re.compile(
    r'''
    \b(
        casino|gambling|poker|betting|slots?|roulette|blackjack|baccarat|craps|freespins|bonus|jackpot|wager|no\s+deposit|ohne\s+einzahlung|kostenlos\s+spielen|echtes\s+geld|spielautomaten|spielhalle|spielbank|willkommensbonus|startguthaben|casinospiele|
        porn|porno|escort|erotic|hookup|onlyfans|nudes?|camgirls?|sexkontakte|erotik|sexchat|live\s+sex|stripchat|webcamsex|geschlechtsverkehr|selbstbefriedigung|masturbation|pornos|pornhub|xvideos|xnxx|vibrators?|dicks?|cums?|
        fast\s+cash|bad\s+credits?|zinsfrei|geld\s+leihen|kredit\s+aufnehmen|ratenzahlung|schnellkredit|binary\s+options?|payday\s+loans?|payday\s+advance|cash\s+advance|short-term\s+loans?|no\s+credit\s+check|guaranteed\s+loan|Kurzzeitkredite?|Minikredite?|Sofortkredite?|Kredit\s+ohne\s+Schufa|schnelles\s+Geld|Geld\s+sofort|
        tinder|badoo|parship|elitepartner|lovoo|flirt|verlieben|
        test\s+answers?|cheat\s+sheet|homework\s+help|buy\s+answers?|buy\s+exam|abitur\s+lösung|prüfung\s+antworten|examen\s+lösung|
        bitcoin|ethereum|blockchain|nft|ico|airdrop|pump\s+and\s+dump|binance|coinbase|kraken|crypto\s+trading|krypto|kryptowährung|
        privacy\s+policy|terms\s+of\s+use|terms\s+and\s+conditions|all\s+rights\s+reserved|copyright|impressum|datenschutz|nutzungsbedingungen|alle\s+rechte\s+vorbehalten|cookie\s+policy|agb|rechtliche\s+hinweise|haftungsausschluss|
        viagra|levitra|cialis|penis|enlargement|erection|erektionsstörung|potenzmittel|libido|sexualstörung|
        weight\s+loss|fat\s+burning|diet\s+pills|appetite\s+suppressant|abnehmen|diätpillen|fettverbrennung|schnell\s+abnehmen|
        make\s+money\s+online|side\s+hustle|get\s+rich\s+quick|passives\s+einkommen|geld\s+verdienen|heimarbeit|schnell\s+reich\s+werden|
        click\s+here|buy\s+now|order\s+now|free\s+trial|limited\s+offer|jetzt\s+kaufen|hier\s+klicken|jetzt\s+abonnieren|kostenlos\s+testen|nur\s+heute
    )\b
    |
    -{3,}|={3,}|\*{3,}|
    (?:(?:\w+\s*,\s*){10,}\w+)
    ''',
    re.IGNORECASE | re.VERBOSE
)

# Stop header words: a text is rejected when any line in its header (the slice that
# filter_texts checks) starts with one of these phrases. MULTILINE makes ^ match at
# every line start. The trailing \b requires a whole-word match, so words such as
# "Helpful", "Sendung" or "Registered" no longer trip the "help", "send" and "register" entries.
BOILERPLATE_HEADER_PATTERNS = re.compile(
    r'^(?:\s*)'
    r'(you are not logged in|you do not have permission|access this page|'
    r'terms of use|privacy policy|cookies?|all rights reserved|'
    r'sign in|log in|register|create an account|register|'
    r'skip to content|main navigation|toggle navigation|'
    r'select language|choose your region|'
    r'cheap|discounts?|easy|billige|rabatte|einfach|'
    r'sie sind nicht angemeldet|kein zugriff|'
    r'zur hauptnavigation|navigation überspringen|'
    r'anmelden|einloggen|registrieren|konto erstellen|'
    r'nutzungsbedingungen|datenschutz|cookies?|'
    r'alle rechte vorbehalten|sprache auswählen|region wählen|'
    r'günstig|rabatt|einfach|schnell|kostenlos|angebot|aktionen|'
    r'jetzt anmelden|mehr erfahren|hier klicken|'
    r'help|hilfe|assist|unterstützen|call|anrufen|send|senden|respond|antworten|fill|ausfüllen)'
    r'\b',

    re.IGNORECASE | re.MULTILINE
)

def entropy(text):
    """
    Shannon entropy of the character distribution of `text`, in bits per character.

    Kept for reference: it was one of the many tested filters, and no useful patterns
    were found. Returns 0 for an empty string.
    """
    # Character counts, keyed in order of first appearance
    counts = dict.fromkeys(text, 0)
    for ch in text:
        counts[ch] += 1

    n_chars = sum(counts.values())
    return -sum(
        (count / n_chars) * math.log2(count / n_chars)
        for count in counts.values()
    )


def filter_texts(example: dict,
                 min_num_of_words = 100,
                 max_digit_ratio=0.18, 
                 min_alpha_word_ratio=0.75, 
                 max_symbol_ratio=0.1, 
                 header_check_length=200, 
                 allowed_uppercase_ratio=0.07, 
                 logging=False,
                 min_text_len=384) -> bool:
    """
    Heuristic quality filter for one record (used as a `Dataset.filter` predicate).

    Args:
        example: dict with "text" (str) and "lang" ("en" or "de", as tagged by the loading cell).
        min_num_of_words: minimum number of whitespace-separated words.
        max_digit_ratio: maximum share of digit characters over all characters.
        min_alpha_word_ratio: minimum share of purely alphabetic words (apostrophes ignored).
        max_symbol_ratio: currently unused; kept for signature compatibility.
        header_check_length: number of leading characters scanned for boilerplate lines.
        allowed_uppercase_ratio: maximum share of uppercase characters.
        logging: print the reason whenever a text is rejected.
        min_text_len: minimum text length in characters.

    Returns:
        True to keep the text, False to drop it.
    """
    text = example["text"]
    lang = example["lang"]

    def reject(reason):
        # Optionally report why the text was dropped; always signals "drop".
        if logging:
            print(reason)
        return False

    # Base length filter (characters)
    if not text or len(text) < min_text_len:
        return False

    # Blocklisted words/phrases anywhere in the text
    if GENERAL_BAD_PATTERNS.search(text):
        return reject("general bad pattern")

    # Boilerplate phrases at the start of a line near the top of the text
    if BOILERPLATE_HEADER_PATTERNS.search(text[:header_check_length]):
        return reject("header bad pattern")

    # Statistical filters
    words = text.lower().split()
    num_words = len(words)
    if num_words < min_num_of_words:
        return reject("num words")

    # Type-token ratio: share of DISTINCT words after stripping edge punctuation. The
    # previous version divided the total word count by itself, so this filter never fired.
    # Longer texts naturally repeat more words, so the threshold relaxes with length.
    cleaned_words = [word.strip(".,!?;:`'\"") for word in words]
    ttr = len(set(cleaned_words)) / num_words
    if num_words < 500:
        ttr_threshold = 0.4
    elif num_words < 2000:
        ttr_threshold = 0.35
    else:
        ttr_threshold = 0.3
    if ttr < ttr_threshold:
        return reject("unique words")

    # Digits are counted per character; alphabetic-ness is judged per word
    num_digits = 0
    num_alpha_words = 0
    for word in words:
        num_digits += sum(c.isdigit() for c in word)
        if word.replace("`", "").replace("'", "").isalpha():
            num_alpha_words += 1

    if (num_digits / len(text)) > max_digit_ratio:
        return reject("digit words")

    if (num_alpha_words / num_words) < min_alpha_word_ratio:
        return reject("alpha ratio")

    # Mean word length. German gets a higher upper bound (presumably for long compounds).
    mean_word_len = sum(len(w) for w in words) / num_words
    if not (3 < mean_word_len < (15 if lang == 'de' else 12)):
        return reject("word len")

    # Share of uppercase characters
    uppercase_chars_ratio = sum(1 for c in text if c.isupper()) / len(text)
    if uppercase_chars_ratio > allowed_uppercase_ratio:
        return reject("uppercase ratio")
    return True

def tokenizer_test(batch, max_len=320, min_len=256, max_avg_token_id=7000, min_avg_token_len=3.2, max_avg_token_len=5, max_unk_count=3):
    """
    Batched `Dataset.map` function: normalise, tokenize and quality-check texts.

    Each text is normalised (newlines -> "[NL]", typographic quotes/dashes/ellipses ->
    ASCII, currency symbols -> words). It is then tokenized without special tokens and
    truncated to `max_len - 1` tokens. A text passes when:
      * it still has at least `min_len` tokens,
      * it contains fewer than `max_unk_count` UNK tokens,
      * its mean token id is at most `max_avg_token_id`,
      * its mean token-string length lies within [min_avg_token_len, max_avg_token_len].

    Returns:
        {"tokenizer_test": list[bool], "token_ids": list[list[int] | None]}, aligned with
        batch["text"]. token_ids is None for rejected texts.
    """
    # Text standardization (applied in this order)
    replacements = {
        "\n": "[NL]",
        "“": '"',
        "”": '"',
        "„": '"',
        "’": "'",
        "—": "-",
        "…": "...",
        "`": "'",
        "''": '"',
        "$": "dollars",
        "€": "euros",
        "½": "1/2"
    }

    def normalise(text):
        for old, new in replacements.items():
            text = text.replace(old, new)
        return text

    texts = [normalise(text) for text in batch["text"]]
    encodings = tokenizer.encode_batch(texts, add_special_tokens=False)

    # One slot below max_len; presumably leaves room for a special token added later — TODO confirm
    max_tokens_per_text = max_len - 1

    batch_test = []
    batch_token_ids = [None] * len(texts)
    for i, enc in enumerate(encodings):
        token_ids = enc.ids[:max_tokens_per_text]
        # Keep the token strings aligned with the truncated ids
        tokens_as_strings = enc.tokens[:len(token_ids)]

        keep = (
            len(token_ids) >= min_len
            # too many unknown tokens
            and sum(1 for t in token_ids if t == UNK_TOKEN_ID) < max_unk_count
            # text dominated by rare (high-id) tokens
            and np.mean(token_ids) <= max_avg_token_id
            and len(tokens_as_strings) > 0
            # Mean token length must fall INSIDE the window. The old check
            # `max < avg < min` could never be true, so this filter never fired.
            and min_avg_token_len
            <= sum(len(t) for t in tokens_as_strings) / len(tokens_as_strings)
            <= max_avg_token_len
        )

        batch_test.append(bool(keep))
        if keep:
            batch_token_ids[i] = token_ids

    return {
        "tokenizer_test": batch_test,
        "token_ids": batch_token_ids
    }

def filter_by_tokenizer_test(example):
    """
    `Dataset.filter` predicate: keep rows whose "tokenizer_test" flag is set.

    This is separate from `tokenizer_test` so that the batched map can both tokenize and
    return `token_ids`, while the keep/drop decision is applied in a later filter step.
    """
    verdict = example["tokenizer_test"]
    return verdict

# Three lazy stages: heuristic text filter -> batched tokenization plus token-level
# checks (adds "tokenizer_test" / "token_ids") -> drop rows that failed those checks.
streaming_dataset = (
    streaming_dataset
    .filter(filter_texts)
    .map(tokenizer_test, batched=True, batch_size=BATCH_SIZE * 8)
    .filter(filter_by_tokenizer_test)
)

In [42]:
# Span parameters: span lengths follow a geometric distribution (unclipped mean
# MEAN_SPAN_LENGTH) and are clipped to [MIN_SPAN_LENGTH, MAX_SPAN_LENGTH].
MEAN_SPAN_LENGTH = 12.0
MIN_SPAN_LENGTH = 2
MAX_SPAN_LENGTH = 40
p = 1.0 / MEAN_SPAN_LENGTH  # success probability of the geometric distribution
def generate_span_length():
    """Draw one clipped span length using NumPy's global RNG."""
    return min(max(np.random.geometric(p=p), MIN_SPAN_LENGTH), MAX_SPAN_LENGTH)

class DataCollator:
    """Span-masking denoising collator with a fixed masking probability.

    Each batch item is a dict with ``"token_ids"`` (list of ints) and ``"lang"``
    (``"en"`` selects ``en_token_id``; any other value selects ``de_token_id``).
    For every item the collator builds

    * encoder input ``[lang] + token_ids + [EOS]`` with random spans replaced by
      ``[MASK]``; each run of consecutive ``[MASK]`` tokens is collapsed into a
      single ``[MASK]`` and the freed positions are re-padded at the end;
    * decoder input ``[lang] + token_ids`` (teacher forcing);
    * labels ``token_ids + [EOS]``.

    Everything is right-padded with ``pad_token_id`` to the longest item of the
    batch; attention masks are 1 for non-pad tokens and 0 for padding.

    Args:
        mask_prob: per-token probability of starting a masked span.
        copy_prob: probability that a whole batch is left unmasked (experimental).
        en_token_id, de_token_id, eos_token_id, pad_token_id, mask_token_id:
            special-token ids.
    """

    def __init__(self, mask_prob, copy_prob, en_token_id=2, de_token_id=3, eos_token_id=4, pad_token_id=0, mask_token_id=5,):
        self.mask_prob = mask_prob
        self.copy_prob = copy_prob  # tested feature

        self.EN_TOKEN_ID = en_token_id
        self.DE_TOKEN_ID = de_token_id
        self.EOS_TOKEN_ID = eos_token_id
        self.PAD_TOKEN_ID = pad_token_id
        self.MASK_TOKEN_ID = mask_token_id

        self.all_ids = set()  # NOTE(review): unused in this class; kept for compatibility

    def __call__(self, batch: List[Dict]) -> Dict[str, torch.Tensor]:
        """Collate ``batch`` into padded ``torch.long`` tensors.

        ``input_ids``/``attention_mask`` have shape ``[batch, max_len + 2]`` and
        ``decoder_input_ids``/``decoder_attention_mask``/``labels`` have shape
        ``[batch, max_len + 1]``, where ``max_len`` is the longest
        ``token_ids`` in the batch.
        """
        # copy_prob may stabilize training, but in practice it was not needed.
        # Strict `<` so copy_prob=0 never fires (random.random() can return 0.0).
        actual_mask_prob = 0.0 if random.random() < self.copy_prob else self.mask_prob

        # Dynamic padding: pad only to the longest item of this batch.
        batch_len = max((len(item["token_ids"]) for item in batch), default=0)

        batch_input_ids = []
        batch_labels = []
        batch_decoder_input_ids = []
        batch_decoder_attention_mask = []
        for item in batch:
            input_ids, decoder_ids, labels = self._build_sequences(item, batch_len)
            batch_input_ids.append(input_ids)
            batch_labels.append(labels)
            batch_decoder_input_ids.append(decoder_ids)
            batch_decoder_attention_mask.append(self._non_pad_mask(decoder_ids))

        # Span masking; padding and special tokens are never masked.
        input_ids_np = np.array(batch_input_ids, dtype=np.int32)
        can_be_masked = (
            (input_ids_np != self.PAD_TOKEN_ID)
            & (input_ids_np != self.EN_TOKEN_ID)
            & (input_ids_np != self.DE_TOKEN_ID)
            & (input_ids_np != self.EOS_TOKEN_ID)
        )
        mask_selection = self._sample_span_mask(input_ids_np.shape, actual_mask_prob) & can_be_masked
        input_ids_np[mask_selection] = self.MASK_TOKEN_ID

        # Unite nearby mask tokens ([1], [MASK], [MASK], [MASK], [1] -> [1], [MASK], [1])
        fixed_batch_input_ids = []
        batch_attention_mask = []
        for ids in input_ids_np.tolist():
            fixed_ids, attention_mask = self._collapse_mask_runs(ids)
            fixed_batch_input_ids.append(fixed_ids)
            batch_attention_mask.append(attention_mask)

        return {
            "input_ids": torch.tensor(fixed_batch_input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(batch_attention_mask, dtype=torch.long),
            "decoder_input_ids": torch.tensor(batch_decoder_input_ids, dtype=torch.long),
            "decoder_attention_mask": torch.tensor(batch_decoder_attention_mask, dtype=torch.long),
            "labels": torch.tensor(batch_labels, dtype=torch.long),
        }

    def _build_sequences(self, item, batch_len):
        """Return right-padded (encoder ids, decoder ids, labels) for one item.

        Every item gets the same amount of padding relative to its length, so
        all encoder rows have length ``batch_len + 2`` and all decoder/label rows
        ``batch_len + 1`` (the old ``padding_length - 1`` made rows ragged).
        """
        token_ids = list(item["token_ids"])
        lang_token = self.EN_TOKEN_ID if item["lang"] == "en" else self.DE_TOKEN_ID
        padding = [self.PAD_TOKEN_ID] * (batch_len - len(token_ids))
        input_ids = [lang_token] + token_ids + [self.EOS_TOKEN_ID] + padding
        decoder_ids = [lang_token] + token_ids + padding
        labels = token_ids + [self.EOS_TOKEN_ID] + padding
        return input_ids, decoder_ids, labels

    def _non_pad_mask(self, ids):
        """Attention mask: 1 for every non-pad id, 0 for padding."""
        return [0 if tid == self.PAD_TOKEN_ID else 1 for tid in ids]

    def _sample_span_mask(self, shape, mask_prob):
        """Boolean array of positions to replace with [MASK].

        A span starts at each position with probability ``mask_prob``; the start
        position and the next ``generate_span_length()`` positions are selected.
        The span counter is reset for every row so a span never continues into
        the next sequence of the batch.
        """
        starts = np.random.rand(*shape) < mask_prob
        selection = starts.copy()
        if not starts.any():
            return selection
        for row in range(shape[0]):
            remaining = 0
            for col in range(shape[1]):
                if remaining > 0:
                    selection[row, col] = True
                    remaining -= 1
                elif starts[row, col]:
                    remaining = generate_span_length()
        return selection

    def _collapse_mask_runs(self, ids):
        """Merge consecutive [MASK] ids into one and re-pad to the original length.

        Returns ``(collapsed_ids, attention_mask)``. The mask is recomputed from
        the collapsed ids so the freed positions (now padding) get 0 even when
        the row already had trailing padding.
        """
        collapsed = []
        for tid in ids:
            if tid == self.MASK_TOKEN_ID and collapsed and collapsed[-1] == self.MASK_TOKEN_ID:
                continue
            collapsed.append(tid)
        collapsed += [self.PAD_TOKEN_ID] * (len(ids) - len(collapsed))
        return collapsed, self._non_pad_mask(collapsed)

In [45]:
# Dynamic masking schedule (used to encourage the decoder early in training).
def step_scheduler(start_pos, end_pos, max_steps):
    """Return ``max_steps`` per-step values decaying from ~start_pos to end_pos.

    A half-cosine on [0, pi] is sampled at ``2 * max_steps`` points and only its
    second half is kept, rescaled to [0, 1]. The curve therefore drops fastest at
    the beginning, flattens out, and reaches ``end_pos`` exactly at the last step.
    """
    angles = np.linspace(0, np.pi, max_steps * 2)
    decay = 0.5 * (np.cos(angles) + 1)
    scale = start_pos - end_pos
    return end_pos + scale * decay[max_steps:] * 2
    
class DataCollatorDynamic:
    def __init__(self, start_prob, end_prob, steps, en_token_id=2, de_token_id=3, eos_token_id=4, pad_token_id=0, mask_token_id=5,):
        self.schedule = step_scheduler(start_prob, end_prob, steps)
        
        self.EN_TOKEN_ID = en_token_id
        self.DE_TOKEN_ID = de_token_id
        self.EOS_TOKEN_ID = eos_token_id
        self.PAD_TOKEN_ID = pad_token_id
        self.MASK_TOKEN_ID = mask_token_id

        self.step = 0

        self.all_ids = set()

    def __call__(self, batch: List[Dict], step: int | None = None) -> Dict[str, torch.Tensor]:
        batch_input_ids = []
        batch_attention_mask = []
        batch_labels = []
        batch_decoder_input_ids = []
        batch_decoder_attention_mask = []

        # Identify current step and mask_prob
        try:
            if not step:
                step = self.step
                self.step += 1
            actual_mask_prob = self.schedule[step]
        except IndexError:
            actual_mask_prob = self.schedule[-1]

        #if step % 36 == 0:
            #print(f"step: {step}, actual_mask_prob: {actual_mask_prob}")
            
        # Dynamic batch_len
        batch_len = 0
        for item in batch:
            if len(item["token_ids"]) > batch_len:
                batch_len = len(item["token_ids"])

        #print(f'Batch len: {batch_len}. Ids: {item["token_ids"][0:5]}...')
        for item in batch:
            token_ids = item["token_ids"]
            lang = item["lang"]
            
            # process token ids into encoder input, decoder input and output.
            lang_token = self.EN_TOKEN_ID if lang == "en" else self.DE_TOKEN_ID

            input_ids = [lang_token] + token_ids + [self.EOS_TOKEN_ID]
            decoder_ids = [lang_token] + token_ids
            labels = token_ids + [self.EOS_TOKEN_ID]
            
            padding_length = batch_len - len(token_ids)
            input_ids.extend([self.PAD_TOKEN_ID] * (padding_length-1))
            decoder_ids.extend([self.PAD_TOKEN_ID] * padding_length)
            labels.extend([self.PAD_TOKEN_ID] * padding_length)
            
            decoder_attention_mask = [1 if tid != self.PAD_TOKEN_ID else 0 for tid in decoder_ids]
            attention_mask = [1 if tid != self.PAD_TOKEN_ID else 0 for tid in input_ids]

            # Add processed items to batch
            batch_input_ids.append(input_ids)
            batch_attention_mask.append(attention_mask)
            batch_labels.append(labels)
            batch_decoder_input_ids.append(decoder_ids)
            batch_decoder_attention_mask.append(decoder_attention_mask)

        # Maksing spans
        input_ids_np = np.array(batch_input_ids, dtype=np.int32)
        rand = np.random.rand(*input_ids_np.shape)
        can_be_masked = (input_ids_np != self.PAD_TOKEN_ID) & \
                        (input_ids_np != self.EN_TOKEN_ID) & \
                        (input_ids_np != self.DE_TOKEN_ID) & \
                        (input_ids_np != self.EOS_TOKEN_ID)

        mask_selection = (rand < actual_mask_prob)
        force = 0
        for i in range(len(mask_selection)):
            for j in range(len(mask_selection[i])):
                if force>0:
                    mask_selection[i][j]=True
                    force-=1
                else:
                    if mask_selection[i][j]==True:
                        force = generate_span_length()
                
        mask_selection = mask_selection & can_be_masked
        mask_sum = sum([1 if i else 0 for i in mask_selection[0]])
        #print(f"tok_len = {batch_len}, mask_sum = {mask_sum}, mask% = {1/(batch_len/mask_sum)}")
        
        # Unite several nearby mask tokens ([1], [MASK], [MASK], [MASK], [1] -> [1], [MASK], [1])
        input_ids_np[mask_selection] = self.MASK_TOKEN_ID
        batch_input_ids = input_ids_np.tolist()
        fixed_batch_input_ids = []
        fixed_batch_attention_mask = []
        for ids, mask in zip(batch_input_ids, batch_attention_mask):
            fixed_ids = []
            pads = []
            fixed_mask = []
            span_sterted = False
            for i in ids:
                if span_sterted:
                    if i != self.MASK_TOKEN_ID:
                        span_sterted = False
                        fixed_ids.append(i)
                    else:
                        pads.append(self.PAD_TOKEN_ID)
                elif i == self.MASK_TOKEN_ID:
                    fixed_ids.append(i)
                    span_sterted = True
                else:
                    fixed_ids.append(i)
            fixed_mask = mask[len(mask)-len(pads):] = [0]*len(pads)
            fixed_ids = fixed_ids + pads
            fixed_batch_input_ids.append(fixed_ids)
            fixed_batch_attention_mask.append(fixed_mask)
                
        return {
            "input_ids": torch.tensor(fixed_batch_input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(batch_attention_mask, dtype=torch.long),
            "decoder_input_ids": torch.tensor(batch_decoder_input_ids, dtype=torch.long),
            "decoder_attention_mask": torch.tensor(batch_decoder_attention_mask, dtype=torch.long),
            "labels": torch.tensor(batch_labels, dtype=torch.long),
        }

In [46]:
# Static collator: fixed span-start probability, no unmasked "copy" batches.
data_collator = DataCollator(mask_prob=0.023, copy_prob=0)

# Dynamic collator: span-start probability follows step_scheduler from 0.08
# down to 0.023 over the first 6000 collator calls.
data_collator_dynamic = DataCollatorDynamic(start_prob=0.08, end_prob=0.023, steps=6000)

In [38]:
# Useful logs
class AdvancedDiagnosticsCallback(TrainerCallback):
    """Log training diagnostics to the active W&B run on every Trainer log event.

    Logged: loss and perplexity (exp(loss)) when ``logs`` contains ``"loss"``,
    the first optimizer param group's learning rate, and mean / std / histogram /
    near-zero ratio of the ``model.fc_out`` weights. Does nothing when no W&B run
    is active or the Trainer did not pass a model.
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        model = kwargs.get("model")
        # Log through wandb.run, the same object checked for None below, instead
        # of relying on a notebook-level global.
        run = wandb.run
        if model is None or run is None:
            return

        step = state.global_step

        scalars = {}
        # Loss & perplexity
        if logs and "loss" in logs:
            loss_val = logs["loss"]
            scalars["loss"] = loss_val
            # torch.exp yields inf (rather than raising) for very large losses.
            scalars["perplexity"] = torch.exp(torch.tensor(loss_val)).item()

        # Learning rate: only if the Trainer passed an optimizer with param groups.
        optimizer = kwargs.get("optimizer")
        param_groups = getattr(optimizer, "param_groups", None)
        if param_groups:
            scalars["learning_rate"] = param_groups[0]["lr"]

        if scalars:
            run.log(scalars, step=step)

        # Weight stats of the output projection (fc_out maps hidden states to logits).
        weights = model.fc_out.weight.detach().cpu().float()
        run.log({
            "weights/fc_out_mean": weights.mean().item(),
            "weights/fc_out_std": weights.std().item(),
            "weights/fc_out_hist": wandb.Histogram(weights.numpy()),
            # Sparsity: fraction of weights with |w| < 1e-5
            "weights/fc_out_sparsity": (weights.abs() < 1e-5).float().mean().item(),
        }, step=step)


diag_callback = AdvancedDiagnosticsCallback()

In [16]:
# Pre-training (denoising) configuration.
training_args = TrainingArguments(
    output_dir="./results",
    # schedule
    max_steps=3500,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,
    # optimizer
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_steps=300,
    weight_decay=0.01,
    max_grad_norm=1.0,
    # checkpointing & logging
    save_strategy="steps",
    save_steps=300,
    save_total_limit=8,
    save_safetensors=False,
    logging_strategy="steps",
    logging_steps=50,
    report_to="none",  # W&B logging goes through the custom callback instead
    # runtime
    remove_unused_columns=False,  # keep "token_ids"/"lang" for the custom collator
    bf16=True,
    fp16=False,
    torch_compile=True,
    # NOTE(review): "aot_eager" is a debugging backend (no optimized kernel codegen);
    # "inductor" is torch.compile's default.
    torch_compile_backend="aot_eager",
    dataloader_num_workers=2,
)

In [39]:
# Make every encoder-layer parameter trainable (encoder freezing was toggled during experiments).
for param in (p for layer in model.encoder_layers for p in layer.parameters()):
    param.requires_grad = True

In [22]:
# Start the W&B run for pre-training (global `run` is the active run) and build the Trainer.
run = wandb.init(project="EDT-lm", name="Pretrain")
model.train()
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator_dynamic,
    train_dataset=streaming_dataset,
    callbacks=[diag_callback],
)

In [66]:
# To continue a previous run pass e.g. resume_from_checkpoint="results/checkpoint-900" (or True for the latest).
history = trainer.train(resume_from_checkpoint=None)

In [46]:
# Persist only the parameters (state_dict), not the module object.
torch.save(obj=model.state_dict(), f="model_weights.pth")

In [18]:
# weights_only=True restricts unpickling to tensors/containers (no arbitrary code execution).
state_dict = torch.load("model_weights.pth", map_location='cuda', weights_only=True)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [28]:
# weights_only=True restricts unpickling to tensors/containers (no arbitrary code execution).
state_dict = torch.load("checkpoints/checkpoint-3315-deen-pretrain/pytorch_model.bin", map_location='cuda', weights_only=True)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [14]:
# weights_only=True restricts unpickling to tensors/containers (no arbitrary code execution).
state_dict = torch.load("results_translate/checkpoint-4242/pytorch_model.bin", map_location='cuda', weights_only=True)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [80]:
# Classic top-p (nucleus) sampling for the encoder-decoder model.
@torch.no_grad()
def top_p_decode(
    model, tokenizer, input_ids, start_tokens, attention_mask=None, max_new_tokens=50,
    eos_token_id=4, top_p=0.3, temperature=1.0,
):
    """Sample a continuation from ``model`` with nucleus (top-p) sampling.

    The source is encoded once; the decoder is re-run on the whole generated
    prefix at every step (no KV cache is passed).

    Args:
        model: encoder-decoder exposing ``embedding``, ``encoder_layers``,
            ``encoder_norm``, ``decoder_layers``, ``decoder_norm`` and ``fc_out``.
            It is switched to eval mode and moved to the chosen device (and left so).
        tokenizer: unused; kept for call-site compatibility.
        input_ids: encoder input ``[batch, src_len]``; batch size 1 is assumed
            (the decoder prefix and probability bookkeeping use a single row).
        start_tokens: list of decoder start token ids.
        attention_mask: optional ``[batch, src_len]`` mask with 1 for real tokens.
        max_new_tokens: maximum number of sampled tokens.
        eos_token_id: stop once this id is sampled (``None`` disables the check).
        top_p: cumulative-probability threshold of the nucleus.
        temperature: logits are divided by this before nucleus filtering.

    Returns:
        ``(decoder_input_ids, raw_probabilities, tempered_probabilities)``: the
        start tokens followed by the sampled ids (shape ``[1, len]``), and for each
        sampled token its probability under the raw softmax and under the final
        temperature-scaled, top-p-filtered distribution.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.eval()
    model.to(device)
    input_ids = input_ids.to(device)

    # Encoder pass
    encoder_outputs = model.embedding(input_ids)
    if attention_mask is not None:
        # Move to the model device; a CPU mask against CUDA activations would fail.
        attn_mask = (attention_mask.to(device) == 1)
        # [batch, src_len] -> [batch, src_len, src_len] pairwise mask.
        # NOTE(review): OR-combining leaves every key visible to real queries;
        # confirm this matches the mask convention the encoder layers expect.
        attn_mask = attn_mask[:, None, :] | attn_mask[:, :, None]
    else:
        attn_mask = None

    for layer in model.encoder_layers:
        encoder_outputs = layer(encoder_outputs, mask=attn_mask)
    memory = model.encoder_norm(encoder_outputs)

    decoder_input_ids = torch.tensor([start_tokens], device=device)

    raw_probabilities = []
    tempered_probabilities = []

    for _ in range(max_new_tokens):
        tgt_embeddings = model.embedding(decoder_input_ids)

        # NOTE(review): src_mask=None, so cross-attention also sees source padding.
        for layer in model.decoder_layers:
            tgt_embeddings, _ = layer(
                tgt_embeddings,
                memory,
                src_mask=None,
                tgt_mask=None,
                self_attn_kv_cache=None
            )

        h = model.decoder_norm(tgt_embeddings)
        logits = model.fc_out(h)  # [batch, seq_len, vocab]

        raw_next_token_logits = logits[:, -1, :]
        raw_probs_dist = F.softmax(raw_next_token_logits, dim=-1)

        next_token_logits = raw_next_token_logits / temperature

        # Nucleus filtering: drop tokens after the sorted cumulative probability
        # exceeds top_p; the shift by one keeps the token that crosses the
        # threshold, so the most likely token is always kept.
        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_mask = cumulative_probs > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False

        indices_to_remove = sorted_mask.scatter(1, sorted_indices, sorted_mask)
        next_token_logits[indices_to_remove] = -float('inf')

        final_probs_dist = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(final_probs_dist, num_samples=1)

        token_id = next_token.item()
        raw_probabilities.append(raw_probs_dist[0, token_id].item())
        tempered_probabilities.append(final_probs_dist[0, token_id].item())

        decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)

        if eos_token_id is not None and (next_token == eos_token_id).all():
            break

    return decoder_input_ids, raw_probabilities, tempered_probabilities

In [16]:
# Fresh iterator over the filtered streaming pre-training data, for manual inspection below.
iterator = iter(streaming_dataset)

In [49]:
# Collator used for manual inspection. The first assignment was immediately
# overwritten (dead code); the dynamic variant is kept here as an alternative:
# data_collator = DataCollatorDynamic(start_prob=0.3, end_prob=0.023, steps=40000)
data_collator = DataCollator(mask_prob=0.023, copy_prob=0)

In [50]:
# Next pre-training example; the collators read its "token_ids" and "lang" fields.
raw_item = next(iterator)

In [52]:
#text = "I have a cat. His name is Leo. He is a big, fluffy cat with bright green eyes. Leo loves to sleep in the sun. He also likes to play with a small red ball. Every morning, he waits by his bowl for his favorite food. When I pet him, he starts to purr loudly. He is a very good and gentle friend."
#raw_item["token_ids"] = tokenizer.encode(text).ids

In [53]:
# Collate the single example into a batch of one (span masking applied to input_ids).
item = data_collator([raw_item])

In [27]:
# Show the (span-masked) encoder input, then let the model reconstruct it with top-p sampling.
print("Base sequence:")
print(tokenizer.decode(item["input_ids"].tolist()[0], skip_special_tokens=False).replace("[NL]", "\n"))

# NOTE(review): decoding always starts from [SOSE]; for a German item the start token
# should presumably be the item's own language token (item["decoder_input_ids"][0, 0]) — confirm.
output_ids, raw_probs, probs = top_p_decode(
    model,
    tokenizer,
    input_ids=item["input_ids"][0].unsqueeze(0),
    #attention_mask=item.get("attention_mask"),
    max_new_tokens=300,
    eos_token_id=4,
    top_p=0.8, 
    temperature=1,
    start_tokens=tokenizer.encode("[SOSE]").ids
)

print()
print("Denoised sequence (top-p autoregresive search model1)")
# output_ids includes the start token(s) followed by the sampled ids.
out_list = output_ids[0].tolist()
decoded_text = tokenizer.decode(out_list, skip_special_tokens=False).replace("[NL]", "\n")
print(decoded_text)

Base sequence:
[SOSE] We are here to help you. If you have need of our services, please call us, day or night. Or, if you prefer, you can fill out the form on the right.
 Are you thinking about pre-planning your funeral? Pre-planning is the best way to choose how you're remembered, to ease the emotional and financial burden on your loved ones, to protect yourself from rising funeral costs, and to let your family know your final wishes.
 If you are looking for information[MASK] for a loved[MASK] been entrusted to our care, you can use the form below to narrow down your search.
 Ordering flowers from our site ensures that your order will reach us or the family in a timely manner, and your gesture of support will remain acknowledged in the Book of Memories[UNK] for future generations. We only work with local florists so we can maintain the sense of urgency and quality of your selections. We thank you for helping to support the family during their time of need, and will fondly remember you

In [67]:
#for prob, raw_prob, token in zip(probs, raw_probs, out_list):
#   print(prob, raw_prob, token, tokenizer.decode([token], skip_special_tokens=False))

## Fine-tuning on EN/DE translation pairs

In [56]:
# Load the parallel EN/DE corpus, split across several JSON files of the form
# {"en": [...], "de": [...]} where the two lists are aligned by index.
TRANSLATION_FILES = [
    "translations.json",
    "translations2.json",
    "translations3.json",
    "translations4.json",
    "translations5.json",
    "translations6.json",
]

text_pairs = {"en": [], "de": []}
for path in TRANSLATION_FILES:
    with open(path, "r", encoding="utf-8") as f:
        part = json.load(f)
    # Pairs are matched by index, so a per-file length mismatch would silently
    # shift every following pair.
    if len(part["en"]) != len(part["de"]):
        raise ValueError(f"{path}: {len(part['en'])} 'en' texts vs {len(part['de'])} 'de' texts")
    text_pairs["en"].extend(part["en"])
    text_pairs["de"].extend(part["de"])

In [57]:
# Number of loaded EN/DE pairs (before tokenization and filtering).
len(text_pairs["en"])

79446

In [54]:
class DataCollatorTranslate:
    """Collator for supervised EN<->DE translation pairs.

    Each batch item holds ``"en_ids"`` and ``"de_ids"``: token-id lists that
    already contain the language start token and ``[EOS]`` (as produced by
    ``tokenize``). The source side feeds the encoder. The target side is split
    for teacher forcing into decoder input ``target[:-1]`` and labels
    ``target[1:]``. Everything is right-padded with ``pad_token_id`` to the
    longest EN or DE sequence in the batch.

    Args:
        translate_from: ``"de"`` or ``"en"`` to fix the source language; a falsy
            value picks the direction at random for each item.
        pad_token_id: padding id; attention masks are 0 exactly where it occurs.
    """

    def __init__(self, translate_from, pad_token_id=0):
        self.translate_from = translate_from
        self.PAD_TOKEN_ID = pad_token_id

    def __call__(self, batch: List[Dict]) -> Dict[str, torch.Tensor]:
        # Pad source, decoder input and labels to the longest text of either language.
        batch_len = max(
            (max(len(item["en_ids"]), len(item["de_ids"])) for item in batch),
            default=0,
        )

        batch_input_ids = []
        batch_attention_mask = []
        batch_labels = []
        batch_decoder_input_ids = []
        batch_decoder_attention_mask = []
        for item in batch:
            translate_to = ["de", "en"]
            if self.translate_from:
                translate_from = self.translate_from
            else:
                translate_from = random.choice(translate_to)
            translate_to.remove(translate_from)  # ValueError for an unsupported language

            # Copy the source list: padding it in place would mutate the caller's item.
            from_ids = list(item[f"{translate_from}_ids"])
            to_ids = item[f"{translate_to[0]}_ids"]

            # Teacher forcing: decoder sees target[:-1] and predicts target[1:].
            labels = list(to_ids[1:])
            decoder_ids = list(to_ids[:-1])

            from_ids += [self.PAD_TOKEN_ID] * (batch_len - len(from_ids))
            decoder_ids += [self.PAD_TOKEN_ID] * (batch_len - len(decoder_ids))
            labels += [self.PAD_TOKEN_ID] * (batch_len - len(labels))

            attention_mask = [1 if token != self.PAD_TOKEN_ID else 0 for token in from_ids]
            decoder_attention_mask = [1 if token != self.PAD_TOKEN_ID else 0 for token in decoder_ids]

            batch_input_ids.append(from_ids)
            batch_attention_mask.append(attention_mask)
            batch_labels.append(labels)
            batch_decoder_input_ids.append(decoder_ids)
            batch_decoder_attention_mask.append(decoder_attention_mask)

        return {
            "input_ids": torch.tensor(batch_input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(batch_attention_mask, dtype=torch.long),
            "decoder_input_ids": torch.tensor(batch_decoder_input_ids, dtype=torch.long),
            "decoder_attention_mask": torch.tensor(batch_decoder_attention_mask, dtype=torch.long),
            "labels": torch.tensor(batch_labels, dtype=torch.long),
        }

In [55]:
# Fine-tune in the DE -> EN direction only.
data_collator_translate = DataCollatorTranslate(translate_from="de")

In [60]:
# Fine-tuning configuration for the translation task.
training_args_translate = TrainingArguments(
    output_dir="./results_translate",
    # schedule
    num_train_epochs=2,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,
    # optimizer
    learning_rate=5e-6,
    lr_scheduler_type="cosine",
    warmup_steps=45,
    weight_decay=0.01,
    max_grad_norm=1.0,
    # evaluation, checkpointing & logging: once per epoch
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=8,
    save_safetensors=False,
    logging_strategy="epoch",
    report_to="none",
    # runtime
    remove_unused_columns=False,  # keep "en_ids"/"de_ids" for the custom collator
    bf16=True,
    fp16=False,
    torch_compile=True,
    torch_compile_backend="aot_eager",
    dataloader_num_workers=8,
)

In [62]:
def tokenize(batch, tokenizer=tokenizer):
    """Normalize and tokenize a batch of EN/DE text pairs.

    Each text is normalized with a fixed replacement table applied in order
    (later rules see the output of earlier ones). Both sides are then encoded
    and wrapped as ``[lang] ... [EOS]``: 2 for EN, 3 for DE, 4 for EOS.

    Returns:
        dict with ``"en_ids"`` and ``"de_ids"``: one token-id list per pair.
    """
    # Text standardization; order matters (e.g. "`" -> "'" runs before "''" -> '"').
    replacements = (
        ("\n", "[NL]"),
        ("“", '"'),
        ("”", '"'),
        ("„", '"'),
        ("’", "'"),
        ("‘", "'"),
        ("—", "-"),
        ("–", "-"),
        ("…", "..."),
        ("`", "'"),
        ("''", '"'),
        ("$", "(dollars)"),
        ("€", "(euros)"),
        ("½", "1(backslash)2"),
        ("²", "(caret)2"),
        ("/", "(backslash)"),
        ("^", "(caret)"),
        ("″", '"'),
        ("%", "(percent)"),
        ("=", "(equals)"),
    )

    def normalize(text):
        for old, new in replacements:
            text = text.replace(old, new)
        return text

    count = len(batch["en"])
    texts_en = [normalize(batch["en"][i]) for i in range(count)]
    texts_de = [normalize(batch["de"][i]) for i in range(count)]

    en_encodings = tokenizer.encode_batch(texts_en)
    de_encodings = tokenizer.encode_batch(texts_de)

    return {
        "en_ids": [[2, *enc.ids, 4] for enc in en_encodings],
        "de_ids": [[3, *enc.ids, 4] for enc in de_encodings],
    }
        
def filter_translation(item):
    """Return True when a tokenized EN/DE pair passes the quality filters.

    Lengths count every id in the list, including start and EOS tokens. Checks:
    each side has 6..320 tokens, shorter/longer length ratio >= 0.85, and at most
    one unknown token (id 1) per side.
    """
    min_len, max_len = 6, 320
    min_ratio = 0.85
    unk_tok_id, max_unk_tok = 1, 1

    en_ids, de_ids = item["en_ids"], item["de_ids"]
    shorter, longer = sorted((len(en_ids), len(de_ids)))

    if shorter < min_len or longer > max_len:
        return False
    if shorter / longer < min_ratio:
        return False
    return en_ids.count(unk_tok_id) <= max_unk_tok and de_ids.count(unk_tok_id) <= max_unk_tok
    
# Build the translation dataset: shuffle pairs, tokenize in batches, drop low-quality pairs.
translate_dataset = (
    HFDataset.from_dict(text_pairs)
    .shuffle(seed=42)
    .map(tokenize, batched=True, batch_size=BATCH_SIZE)
    .filter(filter_translation)
)

Map:   0%|          | 0/79446 [00:00<?, ? examples/s]

Filter:   0%|          | 0/79446 [00:00<?, ? examples/s]

In [63]:
# 95/5 train/test split of the filtered pairs (seeded for reproducibility).
# NOTE(review): the name `datasets` shadows the `datasets` package name (only its members are imported).
datasets = translate_dataset.train_test_split(test_size=0.05, seed=42)
datasets

DatasetDict({
    train: Dataset({
        features: ['en', 'de', 'en_ids', 'de_ids'],
        num_rows: 62974
    })
    test: Dataset({
        features: ['en', 'de', 'en_ids', 'de_ids'],
        num_rows: 3315
    })
})

In [64]:
train_translate_dataset, test_translate_dataset = datasets["train"], datasets["test"]

In [None]:
# Fine-tuning Trainer; the W&B run and diagnostics callback are disabled for this stage.
#run = wandb.init(project="EDT-lm", name="FineTune")
trainer_fine_tune = Trainer(
    model=model,
    args=training_args_translate,
    data_collator=data_collator_translate,
    train_dataset=train_translate_dataset,
    eval_dataset=test_translate_dataset,
    #callbacks=[diag_callback],
)

In [None]:
# Run fine-tuning; the bare expression displays the returned training summary.
trainer_fine_tune.train()

In [23]:
# Iterator over held-out translation pairs for manual inspection.
iterator = iter(test_translate_dataset)

In [24]:
# Collate the next held-out pair into a batch of one.
item = data_collator_translate([next(iterator)])

In [82]:
# Encoder input ids of the collated test example.
item["input_ids"]

tensor([[    3,   233,    16,  2987,   826,   870, 23438,    10,  7599,    65,
           874,    17,  1359,   669,     7,    42,   111,    15, 22777,    17,
          1012,   390, 12025,  9056,  3001,  5271,  5566,    17,     7,    10,
           332,    45,    86, 11673,    11,  3757,   200,   458,    14,    25,
           790, 10725,    11,    24,  4126,   149,   633,  9600,     7,   141,
            93,    15, 11806,  5326,     7,    15,   870,    62, 16710,     7,
         16400,  6576,    10,  2120,   208,    65,   115,     7,   214, 17508,
             9,    16,    62,  7214,    11, 15944,    18,    24,  4892,    12,
            15,   423,    31,   717, 14766,    24, 10230,     8,    10,     4]])

In [100]:
# weights_only=True restricts unpickling to tensors/containers (no arbitrary code execution).
state_dict = torch.load("results_translate/checkpoint-4848/pytorch_model.bin", map_location='cuda', weights_only=True)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [95]:
# German sample texts of increasing complexity (names presumably follow CEFR levels A1..C2),
# plus an informal colloquial one. Each is wrapped in the model's [SOSD] ... [EOS] markers.
text_a1 = "[SOSD] Ich bin Anna. Ich fahre nach Berlin. Ich buche ein Hotel. Das Hotel ist gut. Berlin ist groß. Ich mag Berlin. [EOS]"
text_a2 = "[SOSD] Hallo, mein Name ist Anna. Nächste Woche fliege ich nach Berlin, weil ich die Stadt besuchen will. Ich habe schon ein Hotelzimmer gebucht. Ich möchte die Sehenswürdigkeiten sehen und vielleicht in einem Restaurant essen. Kannst du mir helfen? [EOS]"

text_b1 = "[SOSD] Ich bereite gerade meine Reise nach Berlin vor und freue mich schon sehr darauf. Obwohl das Wetter nicht immer perfekt ist, hoffe ich, dass wir viele Parks erkunden können. Mir wurde gesagt, dass man unbedingt das Brandenburger Tor besuchen sollte. Falls du Zeit hast, könnten wir uns vielleicht treffen, um gemeinsam die Stadt zu entdecken. [EOS]"
text_b2 = "[SOSD] Bei der Planung meiner Berlin-Reise lege ich besonderen Wert darauf, nicht nur die touristischen Hauptattraktionen abzuklappern, sondern auch einen authentischen Eindruck vom Leben in den verschiedenen Stadtteilen zu bekommen. Es wäre schade, wenn man die kulturelle Vielfalt, die Berlin auszeichnet, verpassen würde. Deshalb habe ich vor, mich abseits der ausgetretenen Pfade zu bewegen und die Stadt auf eigene Faust zu erkunden. [EOS]"

text_c1 = "[SOSD] Die Faszination Berlins ergibt sich aus dem ständigen Spannungsfeld zwischen seiner turbulenten Vergangenheit und seiner dynamischen Gegenwart. Um die Stadt wirklich zu begreifen, genügt es nicht, die Überreste der Mauer zu betrachten; man muss sich mit den vielschichtigen Narrativen auseinandersetzen, die das heutige Stadtbild prägen. Der Facettenreichtum der Metropole offenbart sich erst bei der eingehenden Beschäftigung mit ihren subkulturellen Strömungen und der permanenten urbanen Transformation. [EOS]"
text_c2 = "[SOSD] Die epistemologische Auseinandersetzung mit Berlin als Palimpsest der deutschen Geschichte erfordert eine hohe Ambiguitätstoleranz seitens des Betrachters. Die diskursive Hegemonie etablierter Gedenkorte wird zunehmend durch partizipative Geschichtsnarrative dekonstruiert, was die per se prekäre Verfasstheit kollektiver Identität perpetuiert und zu einer fortwährenden Neuaushandlung des städtischen Selbstverständnisses führt. [EOS]"

text_free = "[SOSD] Hey, was geht ab? Ich hab mir überlegt, wir könnten heute Abend ins Kino gehen. Der neue Film soll mega krass sein, oder? Ich hab aber keine Ahnung, worum es geht. Soll halt gut sein. Bei dir läuft's ja eh. Also, sag Bescheid! [EOS]"

# Replace the encoder input with one of the samples.
# NOTE(review): the other keys of `item` (attention_mask, decoder_*, labels) still belong to the previous example.
item["input_ids"] = torch.tensor([tokenizer.encode(text_b1).ids])

In [101]:
# Compare the custom top_p_decode with model.generate on the same input. Each method is
# called once untimed (warm-up) and then timed between torch.cuda.synchronize() calls.
# NOTE(review): sampling settings differ between runs (top_p 0.2 vs 0.25; top_k only for
# generate; the generate warm-up starts from id 2 with max_new_tokens=5), so the timings
# are not strictly like-for-like.
print("Base sequence:")
print(tokenizer.decode(item["input_ids"].tolist()[0], skip_special_tokens=False).replace("[NL]", "\n"))

# Warm-up call of the custom decoder (untimed).
output_ids, raw_probs, probs = top_p_decode(
    model,
    tokenizer,
    input_ids=item["input_ids"][0].unsqueeze(0),
    start_tokens=tokenizer.encode("[SOSE]").ids,
    max_new_tokens=200,
    top_p=0.2,
    temperature=1
)

torch.cuda.synchronize()
start_time_custom = time.perf_counter()

# Timed call of the custom decoder.
output_ids, raw_probs, probs = top_p_decode(
    model,
    tokenizer,
    input_ids=item["input_ids"][0].unsqueeze(0),
    start_tokens=tokenizer.encode("[SOSE]").ids,
    max_new_tokens=200,
    top_p=0.2,
    temperature=1
)

torch.cuda.synchronize()
end_time_custom = time.perf_counter()
duration_custom = end_time_custom - start_time_custom

print()
# Count includes the start token(s).
print(f"Tokens generated: {len(output_ids[0])}")
print(f"Execution time: {duration_custom:.4f} second")

# Warm-up call of model.generate (untimed).
generated_ids_creative = model.generate(
    input_ids=item['input_ids'].to("cuda"),
    #attention_mask=item['attention_mask'].to("cuda"),
    decoder_input_ids=torch.tensor([2]).unsqueeze(0).to("cuda"),
    max_new_tokens=5,
    temperature=1,
    top_k=50,
    top_p=0.2,
    do_sample=True,
)

torch.cuda.synchronize()
start_time_generate = time.perf_counter()

# Timed call of model.generate.
generated_ids_creative = model.generate(
    input_ids=item['input_ids'].to("cuda"),
    #attention_mask=item['attention_mask'].to("cuda"),
    decoder_input_ids=torch.tensor(tokenizer.encode("[SOSE]").ids).unsqueeze(0).to("cuda"),
    max_new_tokens=200,
    temperature=1,
    top_k=3,
    top_p=0.25,
    do_sample=True,
)

torch.cuda.synchronize()
end_time_generate = time.perf_counter()
duration_generate = end_time_generate - start_time_generate

print(f"Tokens generated: {len(generated_ids_creative[0])}")
print(f"Execution time: {duration_generate:.4f} second")

print()
print("Denoised sequence (top-p autoregresive search)")
# Nested list [[ids...]] for the batch of one.
out_list = generated_ids_creative.tolist()
#print(out_list)
decoded_text = tokenizer.decode(out_list[0], skip_special_tokens=False).replace("[NL]", "\n")
print(decoded_text)

Base sequence:
[SOSD] Ich bereite gerade meine Reise nach Berlin vor und freue mich schon sehr darauf. Obwohl das Wetter nicht immer perfekt ist, hoffe ich, dass wir viele Parks erkunden können. Mir wurde gesagt, dass man unbedingt das Brandenburger Tor besuchen sollte. Falls du Zeit hast, könnten wir uns vielleicht treffen, um gemeinsam die Stadt zu entdecken. [EOS]

Tokens generated: 60
Execution time: 1.2906 second
Tokens generated: 60
Execution time: 1.2155 second

Denoised sequence (top-p autoregresive search)
[SOSE] I am currently planning my trip to Berlin and am already very excited about it. Although the weather isn't always perfect, I hope we can explore many parks. I was told that you should definitely visit the Brandenburg Gate. If you have time, we might meet together to discover the city together.[EOS]


In [65]:
# Per-token probabilities from the last top_p_decode call. `probs`/`raw_probs` hold one entry
# per sampled token, so pair them with the last len(probs) ids of `output_ids` (skipping the
# start tokens) rather than with `out_list`, which holds model.generate's nested output.
sampled_ids = output_ids[0].tolist()[len(output_ids[0]) - len(probs):]
for prob, raw_prob, token in zip(probs, raw_probs, sampled_ids):
    print(prob, raw_prob, token, tokenizer.decode([token], skip_special_tokens=False))

In [None]:
# Comparison with a professional model of similar size (Helsinki-NLP MarianMT, DE -> EN).
model_name = "Helsinki-NLP/opus-mt-de-en"
translator = pipeline("translation", model=model_name)

# [SOSD]/[EOS] are markers for the custom model only; strip them so the reference
# model translates the same plain German text. (These names were previously undefined.)
german_text_b2 = text_b2.replace("[SOSD]", "").replace("[EOS]", "").strip()
german_text_c2 = text_c2.replace("[SOSD]", "").replace("[EOS]", "").strip()

translated_output = translator(german_text_b2)

print("\n--- Результат перекладу ---")  # header: "Translation result"
print(f"Original (DE): {german_text_b2.strip()}")
print(f"Translation (EN): {translated_output[0]['translation_text']}")


translated_output_c2 = translator(german_text_c2)
print("\n--- Перевірка на тексті C2 ---")  # header: "Check on text C2"
print(f"Original (DE): {german_text_c2.strip()}")
print(f"Translation (EN): {translated_output_c2[0]['translation_text']}")