<a href="https://colab.research.google.com/github/Hisernberg/BLEACH---Bangla-Language-Expert-Adaptive-Corpus-Handler/blob/main/bleach_01.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## BLEACH

In [None]:
#!/usr/bin/env python3
"""
BLEACH: Bangla Language Expert Adaptive Corpus Handler
ACL-ready sparse MoE for Bangla dialect modeling
"""

import os
import math
import random
import warnings
warnings.filterwarnings('ignore')

from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from collections import defaultdict, Counter
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.cuda.amp import autocast, GradScaler

from transformers import AutoTokenizer, PreTrainedTokenizer
from datasets import load_dataset, Dataset as HFDataset

# Reproducibility: seed every random number generator the pipeline draws from
# (Python's `random`, NumPy, and PyTorch on CPU and, when present, all GPUs).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Force deterministic cuDNN kernels and turn off the autotuner, which
    # may otherwise pick different algorithms from run to run.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# ====================================================
# PORTION 1: DATA PIPELINE & DIALECT-AWARE SAMPLER
# ====================================================

@dataclass
class DataConfig:
    """Settings for the data pipeline: file locations, tokenizer, batching, dialects.

    ``dialect_to_id`` and ``id_to_dialect`` are derived from ``dialects`` in
    ``__post_init__``; any value passed for ``dialect_to_id`` is overwritten.
    """
    train_path: str = "/content/cleaned_bangla_train (1).csv"
    val_path: str = "/content/cleaned_bangla_val (1).csv"
    test_path: str = "/content/cleaned_bangla_test (2).csv"

    # Tokenizer settings
    tokenizer_name: str = "csebuetnlp/banglabert"  # Pre-trained Bangla tokenizer
    vocab_size: int = 32000
    max_length: int = 256  # Conservative for Colab

    # Batch settings
    batch_size: int = 8  # Effective batch size (after gradient accumulation)
    micro_batch_size: int = 2  # Per-step batch size fed to the model
    num_workers: int = 2

    # Dialect settings
    dialects: Tuple[str, ...] = ("Chittagong", "Sylhet", "Barisal", "Noakhali", "Mymensingh")
    dialect_to_id: Optional[Dict[str, int]] = None

    def __post_init__(self):
        # Dialect ids follow the position of each name in ``dialects``.
        self.dialect_to_id = {d: i for i, d in enumerate(self.dialects)}
        self.id_to_dialect = {i: d for d, i in self.dialect_to_id.items()}

    @property
    def gradient_accumulation_steps(self):
        """Number of micro-batches that make up one effective batch.

        Uses floor division and never returns less than 1, so a
        ``micro_batch_size`` larger than ``batch_size`` cannot yield a step
        count of zero.
        """
        return max(1, self.batch_size // self.micro_batch_size)

class DialectAwareSampler(Sampler):
    """Sampler that interleaves dialects so every dialect is equally represented.

    Each dialect is oversampled (by repetition) up to the size of the largest
    dialect, then indices are emitted round-robin across dialects. With
    ``shuffle=True`` the stream is cut into ``batch_size`` chunks, the trailing
    partial chunk is dropped, and the chunks are shuffled as units.

    Args:
        dataset: The dataset being sampled (stored, not otherwise used here).
        batch_size: Chunk size used when shuffling; must be positive.
        dialect_ids: One dialect id per dataset index.
        shuffle: Whether to shuffle within dialects and across chunks.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, dataset, batch_size, dialect_ids, shuffle=True):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.dataset = dataset
        self.batch_size = batch_size
        self.dialect_ids = dialect_ids
        self.shuffle = shuffle

        # Group dataset indices by dialect.
        self.dialect_groups = defaultdict(list)
        for idx, dialect_id in enumerate(dialect_ids):
            self.dialect_groups[dialect_id].append(idx)

        # Every dialect contributes as many samples as the largest one
        # (0 when there is no data at all).
        self.num_samples_per_dialect = max(
            (len(group) for group in self.dialect_groups.values()), default=0
        )

    def __iter__(self):
        per_dialect = self.num_samples_per_dialect

        # Build one equally long index list per dialect.
        dialect_cycles = []
        for group in self.dialect_groups.values():
            pool = list(group)  # copy: never reorder the stored groups
            if self.shuffle:
                random.shuffle(pool)
            repeats = -(-per_dialect // len(pool))  # ceil division
            dialect_cycles.append((pool * repeats)[:per_dialect])

        # Round-robin interleave: d0[0], d1[0], ..., d0[1], d1[1], ...
        indices = [cycle[i] for i in range(per_dialect) for cycle in dialect_cycles]

        if self.shuffle:
            # Shuffle whole chunks (not single samples) so each chunk keeps its
            # dialect balance; the trailing partial chunk is dropped.
            usable = len(indices) - len(indices) % self.batch_size
            chunks = [
                indices[start:start + self.batch_size]
                for start in range(0, usable, self.batch_size)
            ]
            random.shuffle(chunks)
            indices = [idx for chunk in chunks for idx in chunk]

        return iter(indices)

    def __len__(self):
        """Number of indices one pass of ``__iter__`` yields."""
        total = self.num_samples_per_dialect * len(self.dialect_groups)
        if self.shuffle:
            # Mirror the partial-chunk drop performed in __iter__.
            total -= total % self.batch_size
        return total

class BanglaDialectDataset(Dataset):
    """Map-style dataset yielding tokenized text together with its dialect id.

    Rows are read from a CSV that must provide ``text`` and ``dialect``
    columns; rows with missing values or with a dialect not listed in
    ``config.dialects`` are discarded.
    """

    def __init__(self, file_path: str, tokenizer: PreTrainedTokenizer, config: DataConfig):
        self.config = config
        self.tokenizer = tokenizer

        # Read the raw table and make sure the required columns exist.
        self.df = pd.read_csv(file_path)
        for column in ("text", "dialect"):
            assert column in self.df.columns, f"Missing '{column}' column"

        # Drop incomplete rows, then keep only the configured dialects.
        complete_rows = self.df.dropna(subset=["text", "dialect"])
        self.df = complete_rows[complete_rows["dialect"].isin(config.dialects)]

        # Integer dialect label for every remaining row, in row order.
        lookup = config.dialect_to_id
        self.dialect_ids = [lookup[name] for name in self.df["dialect"].values]

        print(f"Loaded {len(self)} samples from {file_path}")
        print(f"Dialect distribution: {dict(Counter(self.df['dialect'].values))}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return token ids, attention mask, dialect id and raw text for row ``idx``."""
        text = str(self.df.iloc[idx]["text"])
        dialect_id = self.dialect_ids[idx]

        # Truncate only; padding is deferred to the collate function.
        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.config.max_length,
            padding=False,
            return_tensors=None
        )
        token_ids = encoding["input_ids"]
        mask = encoding.get("attention_mask", [1] * len(token_ids))

        return {
            "input_ids": token_ids,
            "attention_mask": mask,
            "dialect_id": dialect_id,
            "text": text
        }

def dynamic_padding_collate_fn(batch, tokenizer, max_length=None):
    """Pad a list of samples to the longest sequence and build LM labels.

    Args:
        batch: List of dicts with ``input_ids``, ``attention_mask`` and
            ``dialect_id`` entries.
        tokenizer: Object exposing ``pad(encoded, padding=, max_length=,
            return_tensors=)`` that returns padded ``input_ids`` and
            ``attention_mask`` tensors.
        max_length: Forwarded to ``tokenizer.pad``.

    Returns:
        Dict with ``input_ids``, ``attention_mask``, ``labels`` and
        ``dialect_ids`` tensors. ``labels`` is a copy of ``input_ids`` with
        padded positions set to -100; no shifting is done here.
    """
    input_ids = [item["input_ids"] for item in batch]
    attention_mask = [item["attention_mask"] for item in batch]
    dialect_ids = torch.tensor([item["dialect_id"] for item in batch], dtype=torch.long)

    # Dynamic padding to the longest sequence in this batch
    padded = tokenizer.pad(
        {"input_ids": input_ids, "attention_mask": attention_mask},
        padding="longest",
        max_length=max_length,
        return_tensors="pt"
    )

    # Labels mirror the inputs, but padding must not contribute to the loss:
    # -100 is the value a cross-entropy loss with ignore_index=-100 skips.
    labels = padded["input_ids"].clone()
    labels = labels.masked_fill(padded["attention_mask"] == 0, -100)

    return {
        "input_ids": padded["input_ids"],
        "attention_mask": padded["attention_mask"],
        "labels": labels,
        "dialect_ids": dialect_ids
    }

def create_dataloaders(config):
    """Build the tokenizer plus train/val/test dataloaders.

    Args:
        config: A ``DataConfig``-like object providing the paths, tokenizer
            name, batch sizes, ``max_length`` and ``num_workers``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, tokenizer)``.

    Raises:
        ValueError: If no training rows survive filtering.
    """
    from functools import partial

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or "[PAD]"

    # Create datasets
    train_dataset = BanglaDialectDataset(config.train_path, tokenizer, config)
    val_dataset = BanglaDialectDataset(config.val_path, tokenizer, config)
    test_dataset = BanglaDialectDataset(config.test_path, tokenizer, config)

    # Fail early with an actionable message instead of inside the sampler
    if not train_dataset.dialect_ids:
        raise ValueError(f"Training dataset is empty. Please check the file '{config.train_path}' and ensure it contains valid 'text' and 'dialect' columns matching the specified dialects: {config.dialects}")

    # Dialect-balanced ordering for training; chunks of `batch_size` stay
    # balanced while the loader hands out `micro_batch_size` slices of them.
    train_sampler = DialectAwareSampler(
        train_dataset, config.batch_size, train_dataset.dialect_ids, shuffle=True
    )

    # A partial (unlike a lambda) can be pickled, which worker processes
    # started with the "spawn" method require.
    collate_fn = partial(
        dynamic_padding_collate_fn, tokenizer=tokenizer, max_length=config.max_length
    )

    # Settings shared by all three loaders. Pinned memory only helps
    # host-to-GPU copies, so request it only when CUDA is present.
    common = dict(
        batch_size=config.micro_batch_size,  # micro-batch for gradient accumulation
        collate_fn=collate_fn,
        num_workers=config.num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    train_loader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        drop_last=True,
        **common
    )
    val_loader = DataLoader(val_dataset, shuffle=False, **common)
    test_loader = DataLoader(test_dataset, shuffle=False, **common)

    return train_loader, val_loader, test_loader, tokenizer

# ====================================================
# PORTION 2: BLEACH MoE ARCHITECTURE
# ====================================================

@dataclass
class ModelConfig:
    """Hyper-parameters describing the BLEACH transformer and its MoE layers.

    ``head_dim`` is not a constructor field; it is derived in
    ``__post_init__`` as ``hidden_dim // num_attention_heads``.
    """

    # --- Transformer backbone ---
    vocab_size: int = 32000
    hidden_dim: int = 768
    num_layers: int = 8
    num_attention_heads: int = 12
    intermediate_dim: int = 3072  # hidden_dim * 4

    # --- Mixture-of-experts ---
    num_experts: int = 5
    expert_capacity_factor: float = 1.2
    router_jitter_noise: float = 0.01
    expert_dropout: float = 0.1
    router_aux_loss_coef: float = 0.01

    # --- Routing behaviour ---
    top_k: int = 1  # Top-1 routing
    router_init_range: float = 0.02
    use_load_balancing: bool = True

    # --- Dropout / normalisation ---
    attention_dropout: float = 0.1
    hidden_dropout: float = 0.1
    layer_norm_eps: float = 1e-5

    # --- Position information (RoPE and ALiBi are mutually exclusive) ---
    max_position_embeddings: int = 512
    rope_theta: float = 10000.0
    use_rope: bool = True
    use_alibi: bool = False  # Can't use both

    # --- Feed-forward flavour ---
    use_swiglu: bool = True

    # --- Optional LoRA adapters on the experts ---
    use_lora: bool = True
    lora_rank: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.1

    def __post_init__(self):
        # Only one positional scheme may be active at a time.
        both_schemes = self.use_rope and self.use_alibi
        assert not both_schemes, "Cannot use both RoPE and ALiBi"

        # Heads must split the hidden dimension evenly.
        per_head, leftover = divmod(self.hidden_dim, self.num_attention_heads)
        assert leftover == 0, "hidden_dim must be divisible by num_heads"
        self.head_dim = per_head

class LoRALayer(nn.Module):
    """Low-rank additive update applied on top of a caller-supplied weight.

    The update is ``lora_B(lora_A(dropout(x))) * (alpha / rank)``. ``lora_B``
    starts at zero, so a fresh layer reproduces the plain linear map.
    """

    def __init__(self, in_dim, out_dim, rank=8, alpha=16, dropout=0.1):
        super().__init__()
        # Down-projection to `rank`, then up-projection to `out_dim`.
        self.lora_A = nn.Linear(in_dim, rank, bias=False)
        self.lora_B = nn.Linear(rank, out_dim, bias=False)
        self.scaling = alpha / rank
        self.dropout = nn.Dropout(dropout)

        # A gets a Kaiming-uniform start, B is zeroed.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x, base_weight):
        """Return ``x @ base_weight.T`` plus the scaled low-rank correction.

        x: [..., in_dim]; base_weight: the weight matrix being adapted.
        """
        low_rank = self.lora_B(self.lora_A(self.dropout(x)))
        base_out = F.linear(x, base_weight)
        return base_out + low_rank * self.scaling

class ExpertFFN(nn.Module):
    """One feed-forward expert: SwiGLU with optional LoRA adapters per projection."""

    def __init__(self, config: ModelConfig, expert_id: int = 0):
        super().__init__()
        self.expert_id = expert_id
        self.config = config

        hidden, inner = config.hidden_dim, config.intermediate_dim

        # The three bias-free projections of a SwiGLU block.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

        # One LoRA adapter per projection when enabled.
        self.use_lora = config.use_lora
        if self.use_lora:
            lora_args = (config.lora_rank, config.lora_alpha, config.lora_dropout)
            self.lora_gate = LoRALayer(hidden, inner, *lora_args)
            self.lora_up = LoRALayer(hidden, inner, *lora_args)
            self.lora_down = LoRALayer(inner, hidden, *lora_args)

        # Output dropout
        self.dropout = nn.Dropout(config.expert_dropout)

        self._init_weights()

    def _init_weights(self):
        # All projections share a zero-mean normal init whose std is taken
        # from config.router_init_range.
        std = self.config.router_init_range
        for projection in (self.gate_proj, self.up_proj, self.down_proj):
            nn.init.normal_(projection.weight, mean=0.0, std=std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the expert to ``x`` of shape [num_tokens, hidden_dim]."""
        if self.use_lora:
            # Each adapter adds its low-rank term to the matching base weight.
            gate = self.lora_gate(x, self.gate_proj.weight)
            up = self.lora_up(x, self.up_proj.weight)
            projected = self.lora_down(F.silu(gate) * up, self.down_proj.weight)
        else:
            gate = self.gate_proj(x)
            up = self.up_proj(x)
            projected = self.down_proj(F.silu(gate) * up)

        return self.dropout(projected)

class Router(nn.Module):
    """Token-level top-1 router for the MoE layers.

    A bias-free linear layer scores each token against every expert. The
    ``expert_counts`` and ``total_tokens`` buffers accumulate routing
    statistics while training (updated from ``compute_aux_loss``).
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts

        # Router linear layer
        self.router = nn.Linear(config.hidden_dim, config.num_experts, bias=False)

        # Initialize router weights
        nn.init.normal_(self.router.weight, mean=0.0, std=config.router_init_range)

        # Jitter noise for exploration
        self.jitter_noise = config.router_jitter_noise

        # Statistics
        self.register_buffer("expert_counts", torch.zeros(config.num_experts))
        self.register_buffer("total_tokens", torch.tensor(0.0))

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route each token to its highest-probability expert.

        Selection is always top-1; ``config.top_k`` is not consulted here.

        Args:
            hidden_states: [num_tokens, hidden_dim]

        Returns:
            ``(gates, indices, router_logits)`` where ``gates`` is the softmax
            over experts [T, E], ``indices`` the chosen expert per token [T],
            and ``router_logits`` the raw scores [T, E].
        """
        # Additive Gaussian noise on the input encourages exploration; it is
        # applied only in training mode.
        if self.training and self.jitter_noise > 0:
            hidden_states = hidden_states + torch.randn_like(hidden_states) * self.jitter_noise

        router_logits = self.router(hidden_states)  # [T, num_experts]
        gates = F.softmax(router_logits, dim=-1)  # [T, num_experts]

        # Top-1 selection
        _, top1_indices = torch.topk(gates, k=1, dim=-1)  # [T, 1]
        top1_indices = top1_indices.squeeze(-1)  # [T]

        return gates, top1_indices, router_logits

    @staticmethod
    def _coefficient_of_variation(values: torch.Tensor) -> torch.Tensor:
        """Return std/mean of ``values``, or 0 when fewer than two entries exist.

        The sample standard deviation of a single value is NaN, which would
        otherwise poison the auxiliary loss for a one-expert configuration.
        """
        if values.numel() < 2:
            return values.new_zeros(())
        return torch.std(values) / (torch.mean(values) + 1e-8)

    def compute_aux_loss(self, router_logits: torch.Tensor, expert_mask: torch.Tensor) -> torch.Tensor:
        """Compute the load-balancing auxiliary loss.

        Args:
            router_logits: [T, num_experts] raw router scores.
            expert_mask: [T, num_experts] one-hot (or weighted) assignments.

        Returns:
            Scalar ``router_aux_loss_coef * (importance_cv + load_cv)``, or a
            zero scalar when load balancing is disabled.
        """
        if not self.config.use_load_balancing:
            return torch.tensor(0.0, device=router_logits.device)

        T = router_logits.shape[0]

        # Importance: total router probability mass received per expert.
        router_probs = F.softmax(router_logits, dim=-1)  # [T, E]
        importance = router_probs.sum(dim=0)  # [E]
        importance_loss = self._coefficient_of_variation(importance)

        # Load: number of tokens assigned per expert.
        load = expert_mask.sum(dim=0)  # [E]
        load_loss = self._coefficient_of_variation(load)

        aux_loss = self.config.router_aux_loss_coef * (importance_loss + load_loss)

        # Running statistics are tracked only in training mode.
        if self.training:
            self.expert_counts += load.detach()
            self.total_tokens += T

        return aux_loss

class SparseMoELayer(nn.Module):
    """Sparse mixture-of-experts feed-forward layer with top-1 routing.

    Every token is sent to the single expert the router scores highest. Each
    expert processes at most ``capacity`` tokens per call; tokens beyond that
    are dropped at random and contribute a zero output.
    """

    def __init__(self, config: ModelConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Create experts
        self.experts = nn.ModuleList([
            ExpertFFN(config, expert_id=i) for i in range(config.num_experts)
        ])

        # Router
        self.router = Router(config)

        # Capacity factor
        self.capacity_factor = config.expert_capacity_factor

        # Layer norm for router input
        self.layer_norm = nn.LayerNorm(config.hidden_dim, eps=config.layer_norm_eps)

        # Dropout
        self.dropout = nn.Dropout(config.hidden_dropout)

        # Sampled routing decisions, kept for offline analysis
        self.routing_decisions = []

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route tokens to experts and combine their outputs.

        Args:
            hidden_states: [batch_size, seq_len, hidden_dim]

        Returns:
            ``(output, aux_loss)``: ``output`` has the input's shape and
            ``aux_loss`` is the router's load-balancing loss.
        """
        original_shape = hidden_states.shape
        hidden_dim = original_shape[-1]
        num_experts = self.config.num_experts

        # Flatten to token level: [T, hidden_dim]
        tokens = hidden_states.reshape(-1, hidden_dim)
        num_tokens = tokens.shape[0]

        # The router sees a normalized view; experts see the raw tokens.
        gates, indices, router_logits = self.router(self.layer_norm(tokens))

        expert_mask = F.one_hot(indices, num_classes=num_experts).float()  # [T, E]

        # Probability the router gave to each token's chosen expert. Scaling
        # the expert output by it lets the task loss reach the router weights.
        top_gate = gates.gather(-1, indices.unsqueeze(-1))  # [T, 1]

        # Per-expert capacity; floor of 1 so short inputs are not dropped wholesale.
        capacity = max(1, int(self.capacity_factor * num_tokens / num_experts))

        expert_outputs = torch.zeros_like(tokens)

        for expert_idx, expert in enumerate(self.experts):
            selected = torch.nonzero(expert_mask[:, expert_idx], as_tuple=True)[0]
            if selected.numel() == 0:
                continue

            if selected.numel() > capacity:
                # Over capacity: keep a random subset, drop the rest.
                keep = torch.randperm(selected.numel(), device=hidden_states.device)[:capacity]
                selected = selected[keep]

            expert_output = expert(tokens[selected]) * top_gate[selected]

            # Each token has exactly one expert, so plain assignment suffices;
            # the cast keeps dtypes aligned under mixed precision.
            expert_outputs[selected] = expert_output.to(expert_outputs.dtype)

        output = self.dropout(expert_outputs.reshape(original_shape))

        # The balancing loss uses the pre-capacity assignments.
        aux_loss = self.router.compute_aux_loss(router_logits, expert_mask)

        # Store routing decisions for analysis
        if self.training and random.random() < 0.01:  # Sample ~1% of calls
            self.routing_decisions.append({
                "indices": indices.detach().cpu(),
                "gates": gates.detach().cpu(),
                "layer": self.layer_idx
            })
            if len(self.routing_decisions) > 1000:  # Limit memory
                self.routing_decisions.pop(0)

        return output, aux_loss

class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding using the split-halves rotation layout."""

    def __init__(self, dim: int, max_seq_len: int = 512, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # One inverse frequency per rotated pair: theta^(-2i/dim).
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents))

    def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
        """Rotate ``x`` of shape [batch_size, num_heads, seq_len, head_dim].

        The sequence length is read from ``x`` itself; the ``seq_len``
        argument is accepted but its value is not used.
        """
        _, _, seq_len, _ = x.shape

        # Angle for every (position, frequency) pair: [seq_len, dim//2]
        positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)

        # Add leading batch and head axes for broadcasting: [1, 1, seq_len, dim//2]
        angles = angles[None, None]
        cos, sin = torch.cos(angles), torch.sin(angles)

        # Pair feature i of the first half with feature i of the second half
        # and rotate each pair by its angle.
        half = self.dim // 2
        first, second = x[..., :half], x[..., half:]
        rotated = torch.cat(
            [first * cos - second * sin, second * cos + first * sin], dim=-1
        )

        return rotated.type_as(x)

class Attention(nn.Module):
    """Multi-head self-attention with optional RoPE or ALiBi position handling.

    Attention is causal by default: a position attends only to itself and to
    earlier positions, as next-token training requires.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # QKV projections
        self.q_proj = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.o_proj = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)

        # Dropout
        self.attn_dropout = nn.Dropout(config.attention_dropout)
        self.resid_dropout = nn.Dropout(config.hidden_dropout)

        # RoPE
        self.use_rope = config.use_rope
        if self.use_rope:
            self.rope = RotaryPositionEmbedding(
                dim=config.head_dim,
                max_seq_len=config.max_position_embeddings,
                theta=config.rope_theta
            )

        # ALiBi (alternative)
        self.use_alibi = config.use_alibi
        if self.use_alibi:
            self.register_buffer(
                "alibi_slopes",
                torch.tensor(self._get_alibi_slopes(config.num_attention_heads))
            )

        # Initialize
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.k_proj.weight)
        nn.init.xavier_uniform_(self.v_proj.weight)
        nn.init.xavier_uniform_(self.o_proj.weight)

    def _get_alibi_slopes(self, n_heads):
        """Return one ALiBi slope per head as a list of floats.

        For a power-of-two head count the slopes form a geometric sequence.
        Otherwise the slopes of the nearest lower power of two are extended
        with every second slope of the next power of two, as in the ALiBi
        reference implementation.
        """
        def get_slopes_power_of_2(n):
            start = 2**(-2**(-(math.log2(n)-3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n_heads).is_integer():
            return get_slopes_power_of_2(n_heads)

        # Closest power of 2 below n_heads
        nearest_power = 2**math.floor(math.log2(n_heads))
        slopes = get_slopes_power_of_2(nearest_power)

        # Remaining heads: interleaved slopes from the doubled head count
        extra_slopes = get_slopes_power_of_2(2 * nearest_power)
        slopes.extend(extra_slopes[0::2][:n_heads - nearest_power])

        return slopes

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal: bool = True,
    ) -> torch.Tensor:
        """Apply self-attention.

        Args:
            x: [batch_size, seq_len, hidden_dim]
            attention_mask: Optional [batch_size, seq_len]; positions equal to
                0 are excluded as attention keys.
            causal: When True (default), mask out keys at later positions.

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch_size, seq_len, hidden_dim = x.shape
        num_heads = self.config.num_attention_heads
        head_dim = self.config.head_dim

        # Project and split into heads: [batch, heads, seq_len, head_dim]
        q = self.q_proj(x).view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

        # Apply RoPE
        if self.use_rope:
            q = self.rope(q, seq_len)
            k = self.rope(k, seq_len)

        # Scaled dot-product scores: [batch, heads, seq_len, seq_len]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)

        if self.use_alibi:
            # Bias of slope * key_position. Under a causal mask this matches
            # the usual -slope * (query_pos - key_pos) form, because the
            # per-query term is constant within a softmax row.
            positions = torch.arange(seq_len, device=x.device).unsqueeze(0).unsqueeze(0)
            alibi = positions * self.alibi_slopes.view(1, -1, 1, 1)
            attn_scores = attn_scores + alibi

        # A large finite negative (rather than -inf) keeps softmax free of
        # NaNs even if every key of a row ends up masked.
        mask_value = torch.finfo(attn_scores.dtype).min

        if causal:
            future = torch.ones(seq_len, seq_len, device=x.device).triu(1).bool()
            attn_scores = attn_scores.masked_fill(future, mask_value)

        if attention_mask is not None:
            attn_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq_len]
            attn_scores = attn_scores.masked_fill(attn_mask == 0, mask_value)

        # Softmax
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.attn_dropout(attn_probs)

        # Weighted sum of values, then merge heads back to hidden_dim
        context = torch.matmul(attn_probs, v)  # [batch, heads, seq_len, head_dim]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_dim)

        # Output projection
        output = self.o_proj(context)
        output = self.resid_dropout(output)

        return output

class TransformerBlock(nn.Module):
    """Pre-norm transformer block whose feed-forward part is dense or MoE."""

    def __init__(self, config: ModelConfig, layer_idx: int, is_moe_layer: bool = False):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.is_moe_layer = is_moe_layer

        # Attention sub-layer and its input norm
        self.attention = Attention(config)
        self.attention_norm = nn.LayerNorm(config.hidden_dim, eps=config.layer_norm_eps)

        # Feed-forward sub-layer: exactly one of `moe` / `dense_ffn` is set
        if is_moe_layer:
            self.moe = SparseMoELayer(config, layer_idx)
            self.ffn_norm = nn.LayerNorm(config.hidden_dim, eps=config.layer_norm_eps)
            self.dense_ffn = None
        else:
            self.moe = None
            self.ffn_norm = nn.LayerNorm(config.hidden_dim, eps=config.layer_norm_eps)

            # SwiGLU projections for the dense variant
            hidden, inner = config.hidden_dim, config.intermediate_dim
            self.dense_ffn = nn.ModuleDict({
                'gate_proj': nn.Linear(hidden, inner, bias=False),
                'up_proj': nn.Linear(hidden, inner, bias=False),
                'down_proj': nn.Linear(inner, hidden, bias=False),
                'dropout': nn.Dropout(config.hidden_dropout)
            })

            # Zero-mean normal init for the three projections
            for key in ('gate_proj', 'up_proj', 'down_proj'):
                nn.init.normal_(self.dense_ffn[key].weight, mean=0.0, std=config.router_init_range)

    def forward_dense_ffn(self, x: torch.Tensor) -> torch.Tensor:
        """Run the dense SwiGLU feed-forward: down(silu(gate(x)) * up(x)), then dropout."""
        ffn = self.dense_ffn
        gated = F.silu(ffn['gate_proj'](x)) * ffn['up_proj'](x)
        return ffn['dropout'](ffn['down_proj'](gated))

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and feed-forward, each with a residual connection.

        Returns: (hidden_states, aux_loss); aux_loss is a zero scalar for
        dense blocks.
        """
        # Attention sub-layer
        x = x + self.attention(self.attention_norm(x), attention_mask)

        # Feed-forward sub-layer
        normed = self.ffn_norm(x)
        use_moe = self.is_moe_layer and self.moe is not None
        if use_moe:
            ffn_output, aux_loss = self.moe(normed)
        else:
            ffn_output = self.forward_dense_ffn(normed)
            aux_loss = torch.tensor(0.0, device=normed.device)

        return x + ffn_output, aux_loss

    def gradient_checkpointing_enable(self):
        """Set the block's gradient-checkpointing flag to True."""
        self._gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Set the block's gradient-checkpointing flag to False."""
        self._gradient_checkpointing = False

class BLEACHModel(nn.Module):
    """BLEACH language model: a transformer stack with MoE blocks at fixed layers.

    Blocks at layer indices 2, 4 and 6 use a sparse MoE feed-forward; all
    other blocks are dense. The LM head shares its weight with the word
    embedding table.
    """

    def __init__(self, config: ModelConfig, vocab_size: int):
        super().__init__()
        self.config = config
        self.vocab_size = vocab_size

        # Embeddings
        self.word_embeddings = nn.Embedding(vocab_size, config.hidden_dim)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_dim)

        # Dropout
        self.emb_dropout = nn.Dropout(config.hidden_dropout)

        # Transformer layers
        # Place MoE at layers 2, 4, 6 (at most 3 MoE layers, fewer if the
        # stack is shallower)
        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_layers):
            is_moe_layer = layer_idx in [2, 4, 6]
            self.layers.append(
                TransformerBlock(config, layer_idx, is_moe_layer)
            )

        # Final layer norm
        self.final_norm = nn.LayerNorm(config.hidden_dim, eps=config.layer_norm_eps)

        # LM head
        self.lm_head = nn.Linear(config.hidden_dim, vocab_size, bias=False)

        # Tie weights: the head reuses the embedding matrix
        self.lm_head.weight = self.word_embeddings.weight

        # Initialize
        self._init_weights()

        # Parameter count
        self.print_parameter_count()

    def _init_weights(self):
        # Embeddings
        nn.init.normal_(self.word_embeddings.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=0.02)

        # LM head already tied

    def print_parameter_count(self):
        """Print trainable parameter totals, split into expert and other parameters."""
        total_params = 0
        moe_params = 0
        dense_params = 0

        for name, param in self.named_parameters():
            if param.requires_grad:
                params = param.numel()
                total_params += params

                # Classified by parameter name: anything under a `moe`
                # module whose name also contains `expert`.
                if 'moe' in name and 'expert' in name:
                    moe_params += params
                else:
                    dense_params += params

        print("=" * 60)
        print("MODEL PARAMETER COUNT")
        print("=" * 60)
        print(f"Total parameters: {total_params:,}")
        print(f"Dense parameters: {dense_params:,}")
        print(f"MoE parameters: {moe_params:,}")
        print(f"Experts: {self.config.num_experts}")
        print(f"MoE layers: {sum(1 for layer in self.layers if layer.is_moe_layer)}")
        print(f"Model size: {total_params * 4 / 1e9:.2f} GB (FP32)")
        print(f"Model size: {total_params * 2 / 1e9:.2f} GB (FP16)")
        print("=" * 60)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True
    ) -> Dict[str, torch.Tensor]:
        """Run the model and optionally compute the training loss.

        Args:
            input_ids: [batch_size, seq_len] token ids.
            attention_mask: Optional [batch_size, seq_len]; passed to every block.
            labels: Optional [batch_size, seq_len]; shifted by one position
                internally, entries equal to -100 are ignored by the loss.
            return_dict: When False, return ``(logits, loss)`` if a loss was
                computed, else just ``logits``.

        Returns:
            Dict with ``logits``, ``loss`` (LM loss plus summed auxiliary
            loss, or None without labels), ``hidden_states`` (after the final
            norm) and ``aux_loss``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``config.max_position_embeddings``.
        """
        batch_size, seq_len = input_ids.shape

        # The position table has a fixed size; fail with a clear message
        # instead of an out-of-range embedding lookup.
        max_positions = self.config.max_position_embeddings
        if seq_len > max_positions:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings {max_positions}"
            )

        # Create position IDs
        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Get embeddings
        word_embeds = self.word_embeddings(input_ids)
        pos_embeds = self.position_embeddings(position_ids)
        hidden_states = word_embeds + pos_embeds
        hidden_states = self.emb_dropout(hidden_states)

        # Transformer layers; auxiliary losses are summed over all blocks
        total_aux_loss = torch.tensor(0.0, device=hidden_states.device)

        for layer in self.layers:
            hidden_states, aux_loss = layer(hidden_states, attention_mask)
            total_aux_loss = total_aux_loss + aux_loss

        # Final norm
        hidden_states = self.final_norm(hidden_states)

        # LM head
        logits = self.lm_head(hidden_states)

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            # Position t predicts token t + 1
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten for cross-entropy
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            lm_loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

            # Add auxiliary loss
            loss = lm_loss + total_aux_loss

        if not return_dict:
            return (logits, loss) if loss is not None else logits

        return {
            "logits": logits,
            "loss": loss,
            "hidden_states": hidden_states,
            "aux_loss": total_aux_loss
        }

# ====================================================
# PORTION 3: MEMORY-EFFICIENT TRAINING LOOP
# ====================================================

@dataclass
class TrainingConfig:
    """Hyperparameters and bookkeeping options consumed by ``Trainer``.

    Every field has a default, so ``TrainingConfig()`` is directly usable;
    override individual fields by keyword.
    """

    # Optimization
    learning_rate: float = 3e-4  # base LR given to the optimizer; the cosine schedule bottoms out at 10% of it
    weight_decay: float = 0.01  # forwarded to the optimizer
    betas: Tuple[float, float] = (0.9, 0.999)  # momentum coefficients forwarded to the optimizer
    gradient_accumulation_steps: int = 4  # micro-batches accumulated per optimizer step
    max_grad_norm: float = 1.0  # threshold used when clipping the gradient norm

    # Scheduler
    num_warmup_steps: int = 1000  # only read by the fallback (neither "cosine" nor "linear") warmup schedule
    num_training_steps: int = 10000  # horizon of the cosine / linear schedules
    lr_scheduler_type: str = "cosine"  # "cosine", "linear", or any other value for warmup-only

    # Training
    num_epochs: int = 10
    max_steps: int = -1  # cap on optimizer steps; values <= 0 disable the cap
    logging_steps: int = 50  # progress-bar refresh interval, in optimizer steps
    eval_steps: int = 500  # validation interval, in optimizer steps
    save_steps: int = 1000  # periodic checkpoint interval, in optimizer steps

    # Mixed precision
    use_amp: bool = True  # enables autocast for the training forward pass and the GradScaler
    amp_dtype: torch.dtype = torch.float16  # autocast dtype

    # Gradient checkpointing
    use_gradient_checkpointing: bool = True  # calls gradient_checkpointing_enable() on sub-modules that provide it

    # Regularization
    r_drop_alpha: float = 0.5  # weight of the R-Drop KL term; values <= 0 skip the second forward pass
    label_smoothing: float = 0.1  # NOTE(review): not referenced by the Trainer code in this section

    # MoE specific
    moe_aux_loss_weight: float = 0.01  # multiplier the Trainer applies to the MoE auxiliary loss

    # Checkpoints
    output_dir: str = "./bleach_checkpoints"  # created by the Trainer; checkpoints are written as "<name>.pt"
    save_total_limit: int = 3  # how many of the most recent checkpoint files are retained in output_dir

class LionOptimizer(torch.optim.Optimizer):
    """
    Lion optimizer (Evolved Sign Momentum)
    Reference: https://arxiv.org/abs/2302.06675

    Per-parameter update with gradient ``g`` and momentum ``m``::

        p <- p * (1 - lr * weight_decay)                  # decoupled weight decay
        p <- p - lr * sign(beta1 * m + (1 - beta1) * g)   # sign update
        m <- beta2 * m + (1 - beta2) * g                  # momentum tracking

    Weight decay is applied directly to the parameters rather than added to
    the gradient. Adding it to the gradient would pass it through ``sign()``
    (losing its magnitude) and leak it into the momentum buffer.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate (>= 0).
        betas: ``(beta1, beta2)``, each in ``[0, 1)``.
        weight_decay: Decoupled weight-decay coefficient (>= 0).

    Raises:
        ValueError: If any hyperparameter is outside its valid range.
    """

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The closure's return value, or ``None`` when no closure is given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Lion does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']  # momentum m_{t-1}

                # Decoupled weight decay: shrink the weights directly so the
                # decay term never enters the sign or the momentum.
                if weight_decay != 0:
                    p.mul_(1 - lr * weight_decay)

                # Update direction u_t = sign(beta1 * m_{t-1} + (1 - beta1) * g_t)
                update_direction = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_()
                p.add_(update_direction, alpha=-lr)

                # Momentum for the next step: m_t = beta2 * m_{t-1} + (1 - beta2) * g_t
                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss

class RDropLoss:
    """R-Drop regularization.

    Combines the mean cross-entropy of two stochastic forward passes with a
    symmetric KL term that pulls their predictive distributions together:
    ``total = ce + alpha * kl``.
    """

    def __init__(self, alpha=0.5):
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')
        self.kl_loss = nn.KLDivLoss(reduction='none')

    def compute(self, logits1, logits2, labels, attention_mask=None):
        """
        Compute the R-Drop loss for next-token prediction.

        logits1, logits2: [batch_size, seq_len, vocab_size]
        labels: [batch_size, seq_len]
        attention_mask: [batch_size, seq_len] (optional, non-zero = keep)

        Both loss terms are averaged over the target positions that actually
        contribute: positions whose (shifted) label equals the criterion's
        ``ignore_index`` are excluded, as are positions masked out by
        ``attention_mask``. If no position is valid, the terms are 0.

        Returns:
            (total_loss, ce_loss, kl_loss) as scalar tensors.
        """
        vocab_size = logits1.size(-1)

        # Shift for language modeling (position t predicts token t + 1), then flatten
        flat_logits1 = logits1[..., :-1, :].reshape(-1, vocab_size)
        flat_logits2 = logits2[..., :-1, :].reshape(-1, vocab_size)
        flat_labels = labels[..., 1:].reshape(-1)

        # Per-token cross entropy, averaged over the two passes
        token_ce = (
            self.ce_loss(flat_logits1, flat_labels)
            + self.ce_loss(flat_logits2, flat_labels)
        ) / 2

        # Symmetric per-token KL divergence between the two distributions
        log_prob1 = F.log_softmax(flat_logits1, dim=-1)
        log_prob2 = F.log_softmax(flat_logits2, dim=-1)
        token_kl = (
            self.kl_loss(log_prob1, log_prob2.exp()).sum(-1)
            + self.kl_loss(log_prob2, log_prob1.exp()).sum(-1)
        ) / 2

        # Ignored labels yield zero CE; counting them in the denominator
        # would dilute the loss, so they are dropped from the average.
        valid = flat_labels.ne(self.ce_loss.ignore_index)
        if attention_mask is not None:
            valid = valid & attention_mask[..., 1:].reshape(-1).bool()

        weights = valid.to(token_ce.dtype)
        denom = weights.sum().clamp(min=1.0)
        ce_loss = (token_ce * weights).sum() / denom
        kl_loss = (token_kl * weights).sum() / denom

        # Combined loss
        total_loss = ce_loss + self.alpha * kl_loss

        return total_loss, ce_loss, kl_loss

class Trainer:
    """Memory-efficient trainer for BLEACH"""

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: TrainingConfig,
        device: torch.device
    ):
        """Wire up the model, data loaders and all training machinery.

        Moves ``model`` to ``device``, optionally enables gradient
        checkpointing, then builds the optimizer, LR scheduler, gradient
        scaler and R-Drop loss from ``config``. Also creates
        ``config.output_dir`` and prints a short configuration summary.
        """
        self.config, self.device = config, device
        self.train_loader, self.val_loader = train_loader, val_loader

        # Device placement happens before the optimizer is created
        self.model = model
        model.to(device)

        if config.use_gradient_checkpointing:
            self._enable_gradient_checkpointing()

        # Optimizer -> scheduler (the scheduler wraps the optimizer)
        optimizer_kwargs = dict(
            lr=config.learning_rate,
            betas=config.betas,
            weight_decay=config.weight_decay,
        )
        self.optimizer = LionOptimizer(self.model.parameters(), **optimizer_kwargs)
        self.scheduler = self._create_scheduler()

        # Loss scaling for mixed precision and the R-Drop criterion
        self.scaler = GradScaler(enabled=config.use_amp)
        self.r_drop = RDropLoss(alpha=config.r_drop_alpha)

        # Progress tracking
        self.global_step = 0
        self.best_val_loss = float('inf')

        # Checkpoints are written here
        os.makedirs(config.output_dir, exist_ok=True)

        print("Trainer initialized with:")
        print(f"  Device: {device}")
        print(f"  Gradient accumulation steps: {config.gradient_accumulation_steps}")
        print(f"  Mixed precision: {config.use_amp}")
        print(f"  Gradient checkpointing: {config.use_gradient_checkpointing}")

    def _enable_gradient_checkpointing(self):
        """Enable gradient checkpointing for MoE layers"""
        for module in self.model.modules():
            if hasattr(module, 'gradient_checkpointing_enable'):
                module.gradient_checkpointing_enable()
        print("Gradient checkpointing enabled")

    def _create_scheduler(self):
        """Create learning rate scheduler"""
        if self.config.lr_scheduler_type == "cosine":
            return torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=self.config.num_training_steps,
                eta_min=self.config.learning_rate * 0.1
            )
        elif self.config.lr_scheduler_type == "linear":
            return torch.optim.lr_scheduler.LinearLR(
                self.optimizer,
                start_factor=1.0,
                end_factor=0.1,
                total_iters=self.config.num_training_steps
            )
        else:
            return torch.optim.lr_scheduler.LambdaLR(
                self.optimizer,
                lambda step: min(
                    (step + 1) / (self.config.num_warmup_steps + 1),
                    1.0
                )
            )

    def train_epoch(self, epoch: int):
        """Train for one epoch and return the mean per-batch loss.

        Each micro-batch runs one forward pass, or two when
        ``config.r_drop_alpha > 0`` (R-Drop). Gradients are accumulated over
        ``config.gradient_accumulation_steps`` micro-batches before an
        optimizer step. Logging, validation and checkpointing are triggered
        on ``global_step`` (optimizer-step) boundaries.

        The MoE auxiliary loss is weighted by ``config.moe_aux_loss_weight``
        in both the R-Drop and the standard branch, so toggling R-Drop does
        not change the auxiliary-loss weight as a side effect.

        Args:
            epoch: Epoch index, used only for the progress-bar label.

        Returns:
            Total (unscaled) loss averaged over the batches actually
            processed; 0.0 if the loader yielded no batches. Averaging over
            processed batches keeps the value correct when ``max_steps``
            ends the epoch early.
        """
        self.model.train()
        total_loss = 0.0
        total_lm_loss = 0.0
        total_aux_loss = 0.0
        batches_seen = 0
        accum_steps = self.config.gradient_accumulation_steps
        aux_weight = self.config.moe_aux_loss_weight

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch}")

        for step, batch in enumerate(progress_bar):
            # Move batch to device
            batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            # Forward pass with mixed precision
            with autocast(enabled=self.config.use_amp, dtype=self.config.amp_dtype):
                # First forward pass
                outputs1 = self.model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"]
                )

                # R-Drop: second forward pass with different dropout
                if self.config.r_drop_alpha > 0:
                    self.model.train()  # Ensure dropout is active
                    outputs2 = self.model(
                        input_ids=batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                        labels=batch["labels"]
                    )

                    # Compute R-Drop loss
                    rdrop_loss, lm_loss, _ = self.r_drop.compute(
                        outputs1["logits"],
                        outputs2["logits"],
                        batch["labels"],
                        batch["attention_mask"]
                    )

                    # Add auxiliary losses
                    total_aux = outputs1["aux_loss"] + outputs2["aux_loss"]
                    loss = rdrop_loss + total_aux * aux_weight
                else:
                    # Standard loss: the model's "loss" is lm_loss + aux_loss
                    # (unweighted), so split it and re-apply the configured
                    # weight to the auxiliary part.
                    total_aux = outputs1["aux_loss"]
                    lm_loss = outputs1["loss"] - total_aux
                    loss = lm_loss + total_aux * aux_weight

            # Scale loss for gradient accumulation
            loss = loss / accum_steps

            # Backward pass
            self.scaler.scale(loss).backward()

            # Update statistics
            batches_seen += 1
            total_loss += loss.item() * accum_steps
            total_lm_loss += lm_loss.item() if isinstance(lm_loss, torch.Tensor) else 0
            total_aux_loss += total_aux.item() if isinstance(total_aux, torch.Tensor) else 0

            # Gradient accumulation
            # NOTE(review): if the loader length is not a multiple of
            # accum_steps, the trailing micro-batches' gradients are not
            # stepped here and remain in .grad when the epoch ends.
            if (step + 1) % accum_steps == 0:
                # Unscale gradients so clipping sees their true magnitude
                self.scaler.unscale_(self.optimizer)

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.max_grad_norm
                )

                # Optimizer step
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

                # Scheduler step
                if self.scheduler is not None:
                    self.scheduler.step()

                self.global_step += 1

                # Logging
                if self.global_step % self.config.logging_steps == 0:
                    avg_loss = total_loss / (step + 1)
                    avg_lm_loss = total_lm_loss / (step + 1)
                    avg_aux_loss = total_aux_loss / (step + 1)

                    progress_bar.set_postfix({
                        "loss": f"{avg_loss:.4f}",
                        "lm_loss": f"{avg_lm_loss:.4f}",
                        "aux_loss": f"{avg_aux_loss:.4f}",
                        "lr": f"{self.optimizer.param_groups[0]['lr']:.2e}"
                    })

                # Evaluation
                if self.global_step % self.config.eval_steps == 0:
                    val_loss = self.evaluate()
                    print(f"Step {self.global_step}: val_loss = {val_loss:.4f}")

                    # Save best model
                    if val_loss < self.best_val_loss:
                        self.best_val_loss = val_loss
                        self.save_checkpoint("best")

                # Save checkpoint
                if self.global_step % self.config.save_steps == 0:
                    self.save_checkpoint(f"step_{self.global_step}")

            # Early stopping if max steps reached
            if (self.config.max_steps > 0 and
                self.global_step >= self.config.max_steps):
                break

        # Average over what was actually processed, not the loader length
        return total_loss / max(batches_seen, 1)

    @torch.no_grad()
    def evaluate(self):
        """Evaluate on the validation set and return the token-weighted mean loss.

        Uses the same autocast settings as training (``config.use_amp`` and
        ``config.amp_dtype``) so validation numbers are comparable. Each
        batch loss is weighted by its number of target tokens: labels other
        than -100 after the one-position shift that the model applies before
        its cross-entropy. Batches with no target tokens are skipped,
        because their mean loss is undefined.

        The model is switched to eval mode for the pass and back to train
        mode afterwards.

        Returns:
            Weighted mean of the model's ``"loss"`` output (approximately
            0.0 when no batch contributed).
        """
        self.model.eval()
        total_loss = 0.0
        total_tokens = 0

        for batch in tqdm(self.val_loader, desc="Evaluating"):
            batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            with autocast(enabled=self.config.use_amp, dtype=self.config.amp_dtype):
                outputs = self.model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"]
                )

            loss = outputs["loss"]
            if loss is None:
                continue

            # Count the targets that actually enter the loss (shifted labels)
            num_targets = (batch["labels"][..., 1:] != -100).sum().item()
            if num_targets == 0:
                # Mean over zero targets is NaN; do not let it poison the total
                continue

            total_loss += loss.item() * num_targets
            total_tokens += num_targets

        self.model.train()
        return total_loss / (total_tokens + 1e-8)

    def save_checkpoint(self, name: str):
        """Write the full training state to ``<output_dir>/<name>.pt``.

        The file holds the step counter, model / optimizer / scheduler /
        scaler state dicts, the best validation loss so far and the training
        config. Old checkpoints are pruned afterwards.
        """
        target = os.path.join(self.config.output_dir, f"{name}.pt")
        scheduler_state = self.scheduler.state_dict() if self.scheduler else None

        state = {
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": scheduler_state,
            "scaler_state_dict": self.scaler.state_dict(),
            "best_val_loss": self.best_val_loss,
            "config": self.config,
        }
        torch.save(state, target)
        print(f"Checkpoint saved to {target}")

        # Enforce the retention limit
        self._cleanup_checkpoints()

    def _cleanup_checkpoints(self):
        """Keep only latest N checkpoints"""
        checkpoints = []
        for f in os.listdir(self.config.output_dir):
            if f.endswith(".pt"):
                checkpoints.append(f)

        if len(checkpoints) > self.config.save_total_limit:
            checkpoints.sort(key=lambda x: os.path.getmtime(os.path.join(self.config.output_dir, x)))
            for f in checkpoints[:-self.config.save_total_limit]:
                os.remove(os.path.join(self.config.output_dir, f))

    def train(self):
        """Main training loop"""
        print("Starting training...")

        for epoch in range(self.config.num_epochs):
            train_loss = self.train_epoch(epoch)
            print(f"Epoch {epoch} - Train loss: {train_loss:.4f}")

            # Evaluate
            val_loss = self.evaluate()
            print(f"Epoch {epoch} - Val loss: {val_loss:.4f}")

            # Save checkpoint
            self.save_checkpoint(f"epoch_{epoch}")

            # Early stopping
            if self.config.max_steps > 0 and self.global_step >= self.config.max_steps:
                print(f"Reached max steps {self.config.max_steps}")
                break

        print("Training completed!")
        print(f"Best validation loss: {self.best_val_loss:.4f}")

# ====================================================
# PORTION 4: EVALUATION, ROUTING ANALYSIS & VISUALIZATION
# ====================================================

class Evaluator:
    """Comprehensive evaluator for BLEACH"""

    def __init__(self, model: nn.Module, tokenizer: PreTrainedTokenizer, device: torch.device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

    def compute_perplexity(self, dataloader: DataLoader) -> Tuple[float, float]:
        """Compute perplexity on a dataset.

        Perplexity is ``exp`` of the token-weighted mean *language-modeling*
        loss. The model's ``"loss"`` output also contains the MoE auxiliary
        loss, which is subtracted when the output provides ``"aux_loss"`` so
        it does not inflate the perplexity. Batches are weighted by their
        number of target tokens (labels other than -100 after the
        one-position shift); batches without targets are skipped.

        Returns:
            ``(perplexity, avg_loss)``. ``perplexity`` is ``inf`` if
            ``exp(avg_loss)`` overflows.

        Raises:
            ValueError: If the dataloader yields no target tokens at all.
        """
        total_loss = 0.0
        total_tokens = 0

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Computing perplexity"):
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

                with autocast(enabled=True, dtype=torch.float16):
                    outputs = self.model(
                        input_ids=batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                        labels=batch["labels"]
                    )

                loss = outputs["loss"]
                if loss is None:
                    continue

                # Strip the auxiliary (load-balancing) term from the loss
                aux_loss = outputs.get("aux_loss")
                lm_loss = loss - aux_loss if aux_loss is not None else loss

                num_targets = (batch["labels"][..., 1:] != -100).sum().item()
                if num_targets == 0:
                    # Mean over zero targets is NaN; skip the batch
                    continue

                total_loss += lm_loss.item() * num_targets
                total_tokens += num_targets

        if total_tokens == 0:
            raise ValueError("Cannot compute perplexity: no target tokens were evaluated")

        avg_loss = total_loss / total_tokens
        try:
            perplexity = math.exp(avg_loss)
        except OverflowError:
            perplexity = float("inf")

        return perplexity, avg_loss

    def compute_bleu(self, dataloader: DataLoader, max_samples: int = 100) -> float:
        """Compute the mean sentence-level BLEU of generated continuations.

        For each batch, the first 10 tokens of every sequence are used as a
        prompt. The sampled continuation (prompt included) is compared with
        the full original sequence after decoding and whitespace splitting.
        Pairs where either side decodes to no tokens are skipped.

        Args:
            dataloader: Yields batches containing ``"input_ids"``.
            max_samples: Maximum number of scored pairs. Scoring stops as
                soon as this many pairs are scored, including mid-batch.

        Returns:
            Average smoothed BLEU over the scored pairs, or 0.0 if none.
        """
        from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

        smoothie = SmoothingFunction().method4
        total_bleu = 0.0
        count = 0

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Computing BLEU"):
                if count >= max_samples:
                    break

                input_ids = batch["input_ids"].to(self.device)

                # Generate text
                generated = self.generate_text(
                    input_ids[:, :10],  # Use first 10 tokens as prompt
                    max_length=50,
                    temperature=0.7
                )

                # Score each (reference, candidate) pair
                for ref_ids, cand_ids in zip(input_ids, generated):
                    if count >= max_samples:
                        break

                    ref_tokens = self.tokenizer.decode(
                        ref_ids, skip_special_tokens=True).split()
                    cand_tokens = self.tokenizer.decode(
                        cand_ids, skip_special_tokens=True).split()

                    if ref_tokens and cand_tokens:
                        total_bleu += sentence_bleu(
                            [ref_tokens],
                            cand_tokens,
                            smoothing_function=smoothie
                        )
                        count += 1

        return total_bleu / count if count > 0 else 0.0

    def generate_text(self, input_ids: torch.Tensor, max_length: int = 50,
                     temperature: float = 1.0, top_p: float = 0.9) -> torch.Tensor:
        """Generate text using nucleus sampling"""
        generated = input_ids.clone()

        for _ in range(max_length):
            with torch.no_grad():
                outputs = self.model(
                    input_ids=generated,
                    attention_mask=torch.ones_like(generated)
                )

            logits = outputs["logits"][:, -1, :] / temperature

            # Nucleus sampling
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[:, indices_to_remove] = float('-inf')

            # Sample
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            generated = torch.cat([generated, next_token], dim=1)

            # Stop if EOS token
            if (next_token == self.tokenizer.eos_token_id).any():
                break

        return generated

    def analyze_routing(self, dataloader: DataLoader) -> Dict:
        """Analyze routing patterns and expert utilization.

        Forward hooks on every ``SparseMoELayer`` capture the layer's most
        recent entry in ``routing_decisions`` after each forward pass. Each
        captured decision is counted exactly once.

        Returns:
            Dict with:
              * ``"expert_counts"``: fraction of routed slots per expert
                (all zeros if nothing was captured).
              * ``"dialect_expert_matrix"``: zero tensor of shape
                ``[num_dialects, num_experts]``. NOTE(review): it is not
                populated here; filling it needs the layout of
                ``decision["indices"]`` relative to the batch, which this
                method does not know.
              * ``"layer_routing"``: captured decisions per layer index.
              * ``"gate_entropy"``: mean gate entropy over all captured
                decisions (0 if none).
        """
        num_experts = self.model.config.num_experts
        routing_stats = {
            "expert_counts": torch.zeros(num_experts),
            "dialect_expert_matrix": torch.zeros(
                len(self.model.config.dialects),
                num_experts
            ),
            "layer_routing": defaultdict(list),
            "gate_entropy": []
        }

        # Decisions captured during the current forward pass only
        pending = []

        # Hook to capture routing decisions
        def routing_hook(module, input, output):
            if hasattr(module, 'routing_decisions') and module.routing_decisions:
                pending.append((module.layer_idx, module.routing_decisions[-1]))

        # Register hooks on MoE layers
        hooks = [
            module.register_forward_hook(routing_hook)
            for _, module in self.model.named_modules()
            if isinstance(module, SparseMoELayer)
        ]

        try:
            with torch.no_grad():
                for batch in tqdm(dataloader, desc="Analyzing routing"):
                    pending.clear()

                    # Forward pass
                    _ = self.model(
                        input_ids=batch["input_ids"].to(self.device),
                        attention_mask=batch["attention_mask"].to(self.device)
                    )

                    # Process only this batch's decisions, so earlier
                    # batches are not counted again.
                    for layer_idx, decision in pending:
                        routing_stats["layer_routing"][layer_idx].append(decision)

                        indices = decision["indices"]
                        gates = decision["gates"]

                        # Expert counts
                        for expert_idx in range(num_experts):
                            count = (indices == expert_idx).sum().item()
                            routing_stats["expert_counts"][expert_idx] += count

                        # Gate entropy (diversity measure)
                        entropy = -torch.sum(gates * torch.log(gates + 1e-8), dim=-1).mean()
                        routing_stats["gate_entropy"].append(entropy.item())
        finally:
            # Always detach the hooks, even if a forward pass failed
            for hook in hooks:
                hook.remove()

        # Normalize counts (skip when nothing was routed, to avoid NaNs)
        total_routed = routing_stats["expert_counts"].sum()
        if total_routed > 0:
            routing_stats["expert_counts"] = routing_stats["expert_counts"] / total_routed
        routing_stats["gate_entropy"] = np.mean(routing_stats["gate_entropy"]) if routing_stats["gate_entropy"] else 0

        return routing_stats

    def plot_routing_heatmap(self, routing_stats: Dict, save_path: str = None):
        """Draw expert utilization next to the dialect-expert matrix.

        Left panel: bar chart of ``routing_stats["expert_counts"]`` with a
        dashed line at the uniform share ``1 / num_experts``. Right panel
        (only when ``"dialect_expert_matrix"`` is present): annotated
        heatmap of that matrix. The figure is saved to ``save_path`` when
        one is given, and always shown.
        """
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        bar_ax, heat_ax = axes[0], axes[1]

        # Expert utilization bar chart
        expert_counts = routing_stats["expert_counts"].cpu().numpy()
        bar_ax.bar(range(len(expert_counts)), expert_counts)
        bar_ax.set_xlabel("Expert ID")
        bar_ax.set_ylabel("Utilization Fraction")
        bar_ax.set_title("Expert Utilization")
        bar_ax.axhline(
            y=1/len(expert_counts),
            color='r',
            linestyle='--',
            label=f"Target ({1/len(expert_counts):.2f})",
        )
        bar_ax.legend()

        # Dialect-expert matrix (if available)
        if "dialect_expert_matrix" in routing_stats:
            matrix = routing_stats["dialect_expert_matrix"].cpu().numpy()
            im = heat_ax.imshow(matrix, cmap='YlOrRd', aspect='auto')
            heat_ax.set_xlabel("Expert ID")
            heat_ax.set_ylabel("Dialect ID")
            heat_ax.set_title("Dialect-Expert Alignment")
            plt.colorbar(im, ax=heat_ax)

            # Write each cell's value on top of the heatmap
            n_rows, n_cols = matrix.shape[0], matrix.shape[1]
            for i in range(n_rows):
                for j in range(n_cols):
                    heat_ax.text(j, i, f"{matrix[i, j]:.2f}",
                                 ha="center", va="center", color="black")

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.show()

    def plot_training_curves(self, log_file: str, save_path: str = None):
        """Plot training curves from log file"""
        if not os.path.exists(log_file):
            print(f"Log file {log_file} not found")
            return

        # Parse log file (simplified)
        steps, losses, lrs = [], [], []

        with open(log_file, 'r') as f:
            for line in f:
                if "loss" in line and "lm_loss" in line:
                    # Extract values
                    import re
                    step_match = re.search(r"Step (\d+)", line)
                    loss_match = re.search(r"loss=([\d\.]+)", line)
                    lr_match = re.search(r"lr=([\d\.e-]+)", line)

                    if step_match and loss_match:
                        steps.append(int(step_match.group(1)))
                        losses.append(float(loss_match.group(1)))
                        if lr_match:
                            lrs.append(float(lr_match.group(1)))

        if not steps:
            print("No training data found in log file")
            return

        fig, axes = plt.subplots(1, 2, figsize=(12, 4))

        # Loss curve
        axes[0].plot(steps, losses, 'b-', linewidth=2)
        axes[0].set_xlabel("Training Steps")
        axes[0].set_ylabel("Loss")
        axes[0].set_title("Training Loss")
        axes[0].grid(True, alpha=0.3)

        # Learning rate curve
        if lrs:
            axes[1].plot(steps[:len(lrs)], lrs, 'r-', linewidth=2)
            axes[1].set_xlabel("Training Steps")
            axes[1].set_ylabel("Learning Rate")
            axes[1].set_title("Learning Rate Schedule")
            axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.show()

# ====================================================
# PORTION 5: ABLATIONS, SCALING & EXPERIMENTS
# ====================================================

@dataclass
class ExperimentConfig:
    """Configuration for one ablation / scaling experiment.

    ``ExperimentRunner.run_experiment`` turns these settings into a
    ``ModelConfig`` and a ``TrainingConfig``. Fields flagged with
    NOTE(review) are not read in the part of ``run_experiment`` visible in
    this section - confirm whether they are used further down.
    """

    # Base model config
    base_config: Optional[ModelConfig] = None  # NOTE(review): not read in the visible part of run_experiment

    # Ablation variants
    use_moe: bool = True  # when False, the model is built with a single expert
    num_experts: int = 5
    top_k: int = 1  # 1 or 2
    use_lora: bool = True
    use_r_drop: bool = True  # False sets the trainer's r_drop_alpha to 0.0 (otherwise 0.5)
    expert_dropout: float = 0.1  # NOTE(review): not read in the visible part of run_experiment

    # Scaling variants
    hidden_dim: int = 768  # 512, 768, 1024
    num_layers: int = 8  # 6, 8, 10
    num_moe_layers: int = 3  # 2, 3, 4; NOTE(review): not read in the visible part of run_experiment

    # Training
    learning_rate: float = 3e-4
    batch_size: int = 8  # NOTE(review): not read in the visible part of run_experiment
    num_steps: int = 5000  # passed to TrainingConfig as num_training_steps

    # Output
    experiment_name: str = "baseline"  # also the sub-directory name for this experiment's checkpoints
    save_dir: str = "./experiments"  # NOTE(review): the runner uses its own experiment_dir; this field is not read in the visible code

class ExperimentRunner:
    """Run ablation / scaling experiments and collect their results.

    Each experiment builds a fresh model from an ``ExperimentConfig``, trains
    it, evaluates perplexity and routing statistics on the validation loader,
    and appends one result dict to ``self.results``. Results are persisted to
    ``<experiment_dir>/results.csv`` after every successful run.

    Attributes:
        data_config: Data configuration handed to ``create_dataloaders``.
        results: One dict per successfully completed experiment.
        experiment_dir: Directory holding ``results.csv`` and the
            per-experiment training output directories.
    """

    # Scalar result keys that become CSV / DataFrame columns, in column order.
    _SCALAR_COLUMNS = ("experiment", "perplexity", "loss", "gate_entropy", "parameters")

    def __init__(self, data_config: "DataConfig"):
        """Store the data configuration and make sure the output dir exists."""
        self.data_config = data_config
        self.results: List[Dict[str, Any]] = []
        self.experiment_dir = "./experiments"
        os.makedirs(self.experiment_dir, exist_ok=True)

    def run_experiment(self, exp_config: "ExperimentConfig") -> Optional[Dict[str, Any]]:
        """Train and evaluate the model described by ``exp_config``.

        Args:
            exp_config: Settings for this run (architecture switches,
                learning rate, number of steps, experiment name).

        Returns:
            The result dict that was appended to ``self.results``, or ``None``
            when training or evaluation raised. Failures are printed with a
            traceback and swallowed on purpose so that one broken variant does
            not abort a whole ablation study.
        """
        print(f"\n{'='*60}")
        print(f"Running experiment: {exp_config.experiment_name}")
        print(f"{'='*60}")

        # A "dense" run is modelled as a single-expert network.
        num_experts = exp_config.num_experts if exp_config.use_moe else 1
        model_config = ModelConfig(
            hidden_dim=exp_config.hidden_dim,
            num_layers=exp_config.num_layers,
            num_experts=num_experts,
            use_lora=exp_config.use_lora,
            # Routing to more experts than exist is never meaningful.
            top_k=min(exp_config.top_k, num_experts),
        )

        # Create dataloaders
        train_loader, val_loader, _test_loader, tokenizer = create_dataloaders(self.data_config)

        # Create model
        model = BLEACHModel(model_config, vocab_size=tokenizer.vocab_size)

        # Training config
        train_config = TrainingConfig(
            learning_rate=exp_config.learning_rate,
            num_training_steps=exp_config.num_steps,
            r_drop_alpha=0.5 if exp_config.use_r_drop else 0.0,
            output_dir=os.path.join(self.experiment_dir, exp_config.experiment_name),
        )

        # Train
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        trainer = Trainer(model, train_loader, val_loader, train_config, device)

        try:
            trainer.train()

            # Evaluate
            evaluator = Evaluator(model, tokenizer, device)
            perplexity, loss = evaluator.compute_perplexity(val_loader)
            routing_stats = evaluator.analyze_routing(val_loader)

            result = {
                "experiment": exp_config.experiment_name,
                "perplexity": perplexity,
                "loss": loss,
                "expert_utilization": routing_stats["expert_counts"].cpu().numpy().tolist(),
                # Coerce to a plain float so the value survives CSV export
                # and plotting even if the evaluator returns a 0-dim tensor.
                "gate_entropy": float(routing_stats.get("gate_entropy", 0)),
                "parameters": sum(p.numel() for p in model.parameters() if p.requires_grad),
                # Snapshot the settings; aliasing __dict__ would let later
                # mutation of exp_config rewrite the recorded result.
                "config": dict(vars(exp_config)),
            }

            self.results.append(result)

            # Persist after every run so partial studies are not lost.
            self.save_results()

            print(f"\nExperiment {exp_config.experiment_name} completed:")
            print(f"  Perplexity: {perplexity:.2f}")
            print(f"  Loss: {loss:.4f}")
            print(f"  Expert utilization: {result['expert_utilization']}")

            return result

        except Exception as e:
            # Deliberate best-effort: report and continue with the next run.
            print(f"Experiment failed: {e}")
            import traceback
            traceback.print_exc()
            return None

    def run_ablation_study(self):
        """Run the fixed suite of ablation / scaling experiments, then plot.

        Experiments that fail are skipped (see ``run_experiment``); the final
        analysis covers whichever runs completed.
        """
        experiments = [
            # Baseline: MoE with LoRA, Top-1
            ExperimentConfig(
                experiment_name="baseline_moe_lora_top1",
                use_moe=True,
                use_lora=True,
                top_k=1,
                use_r_drop=True,
            ),
            # Ablation: Dense model (no MoE)
            ExperimentConfig(
                experiment_name="dense_baseline",
                use_moe=False,
                use_lora=False,
                use_r_drop=True,
            ),
            # Ablation: MoE without LoRA
            ExperimentConfig(
                experiment_name="moe_no_lora",
                use_moe=True,
                use_lora=False,
                top_k=1,
                use_r_drop=True,
            ),
            # Ablation: Top-2 routing
            ExperimentConfig(
                experiment_name="moe_lora_top2",
                use_moe=True,
                use_lora=True,
                top_k=2,
                use_r_drop=True,
            ),
            # Ablation: No R-Drop
            ExperimentConfig(
                experiment_name="moe_lora_no_rdrop",
                use_moe=True,
                use_lora=True,
                top_k=1,
                use_r_drop=False,
            ),
            # Scaling: Larger hidden dimension
            ExperimentConfig(
                experiment_name="moe_large_hidden",
                use_moe=True,
                use_lora=True,
                hidden_dim=1024,
                num_layers=6,  # Fewer layers to keep similar param count
                use_r_drop=True,
            ),
        ]

        # Run all experiments
        for exp_config in experiments:
            self.run_experiment(exp_config)

        # Analyze results
        self.analyze_results()

    def _results_to_frame(self) -> pd.DataFrame:
        """Flatten ``self.results`` into one DataFrame row per experiment.

        Scalar metrics become columns directly; the per-expert utilization
        list is expanded into ``expert_<i>_util`` columns. Runs with fewer
        experts leave the missing columns as NaN.
        """
        rows = []
        for result in self.results:
            row = {key: result[key] for key in self._SCALAR_COLUMNS}
            for i, util in enumerate(result["expert_utilization"]):
                row[f"expert_{i}_util"] = util
            rows.append(row)
        return pd.DataFrame(rows)

    def save_results(self):
        """Save experiment results to ``<experiment_dir>/results.csv``.

        Does nothing when no experiment has completed yet.
        """
        if not self.results:
            return

        csv_path = os.path.join(self.experiment_dir, "results.csv")
        self._results_to_frame().to_csv(csv_path, index=False)
        print(f"\nResults saved to {csv_path}")

    @staticmethod
    def _bar_per_experiment(ax, df: pd.DataFrame, column: str, ylabel: str, title: str):
        """Draw one bar per experiment for ``column`` on ``ax``."""
        positions = range(len(df))
        ax.bar(positions, df[column])
        ax.set_xticks(positions)
        ax.set_xticklabels(df["experiment"], rotation=45, ha='right')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3, axis='y')

    @staticmethod
    def _compare_groups(ax, df: pd.DataFrame, patterns, labels, title: str):
        """Bar-plot mean perplexity of experiment groups selected by name.

        ``patterns`` are plain substrings matched against the experiment
        name. The panel is left empty unless every group has a member.
        """
        masks = [df["experiment"].str.contains(pattern, regex=False) for pattern in patterns]
        if not all(mask.any() for mask in masks):
            return

        means = [df[mask]["perplexity"].mean() for mask in masks]
        ax.bar(labels, means)
        ax.set_ylabel("Perplexity")
        ax.set_title(title)
        ax.grid(True, alpha=0.3, axis='y')

        # Add value labels just above each bar
        for i, value in enumerate(means):
            ax.text(i, value + 0.1, f"{value:.2f}", ha='center')

    def analyze_results(self, save_path: Optional[str] = None):
        """Analyze and visualize experiment results.

        Draws a 2x3 grid (perplexity, parameter efficiency, routing entropy,
        expert-utilization heatmap, MoE vs dense, top-1 vs top-2) and prints
        a summary table.

        Args:
            save_path: Optional file path; when given, the figure is also
                written there. Defaults to ``None`` (display only).
        """
        if not self.results:
            print("No results to analyze")
            return

        # The in-memory results are the source of truth; results.csv is
        # written from exactly this data by save_results().
        df = self._results_to_frame()

        # Plot comparison
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))

        # 1. Perplexity comparison
        self._bar_per_experiment(
            axes[0, 0], df, "perplexity",
            "Perplexity (lower is better)", "Perplexity Comparison",
        )

        # 2. Parameter efficiency
        ax = axes[0, 1]
        ax.scatter(df["parameters"] / 1e6, df["perplexity"])
        ax.set_xlabel("Parameters (millions)")
        ax.set_ylabel("Perplexity")
        ax.set_title("Parameter Efficiency")
        ax.grid(True, alpha=0.3)

        # Add labels to points
        for _, row in df.iterrows():
            ax.annotate(row["experiment"],
                        (row["parameters"] / 1e6, row["perplexity"]),
                        fontsize=8, alpha=0.7)

        # 3. Routing entropy
        self._bar_per_experiment(
            axes[0, 2], df, "gate_entropy", "Routing Entropy", "Routing Diversity",
        )

        # 4. Expert utilization heatmap
        expert_cols = [col for col in df.columns if col.startswith("expert_")]
        if expert_cols:
            ax = axes[1, 0]
            # Force a float matrix: runs with fewer experts contribute NaN.
            util_matrix = df[expert_cols].to_numpy(dtype=float)
            im = ax.imshow(util_matrix.T, cmap='YlOrRd', aspect='auto')
            ax.set_xlabel("Experiment")
            ax.set_ylabel("Expert")
            ax.set_title("Expert Utilization")
            ax.set_xticks(range(len(df)))
            ax.set_xticklabels(df["experiment"], rotation=45, ha='right')
            ax.set_yticks(range(len(expert_cols)))
            ax.set_yticklabels([f"Exp{i}" for i in range(len(expert_cols))])
            plt.colorbar(im, ax=ax)

        # 5. MoE vs Dense comparison
        self._compare_groups(axes[1, 1], df, ("moe", "dense"), ["MoE", "Dense"], "MoE vs Dense")

        # 6. Top-1 vs Top-2 comparison
        self._compare_groups(
            axes[1, 2], df, ("top1", "top2"), ["Top-1", "Top-2"], "Top-1 vs Top-2 Routing",
        )

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.show()

        # Print summary table
        print("\n" + "="*80)
        print("EXPERIMENT SUMMARY")
        print("="*80)
        print(df.to_string(index=False))
        print("\nKey Insights:")
        print("1. Lower perplexity is better")
        print("2. Higher routing entropy indicates more diverse expert usage")
        print("3. Balanced expert utilization (close to 0.2 for 5 experts) is optimal")
        print("="*80)

# ====================================================
# MAIN EXECUTION
# ====================================================

def main():
    """Run the whole BLEACH pipeline end to end.

    Builds the dataloaders and model from default configs, trains, reports
    validation perplexity and routing statistics, and finally launches the
    ablation study.
    """
    rule = "=" * 60
    print("BLEACH: Bangla Language Expert Adaptive Corpus Handler")
    print(rule)

    # Pick the GPU when available and report what we are running on
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    if device.type == "cuda":
        gpu_name = torch.cuda.get_device_name(0)
        print(f"GPU: {gpu_name}")
        gpu_bytes = torch.cuda.get_device_properties(0).total_memory
        print(f"Memory: {gpu_bytes / 1e9:.2f} GB")

    # Default configurations for data, model and training
    data_config = DataConfig()
    model_config = ModelConfig()
    train_config = TrainingConfig()

    # [Step 1] Data
    print("\n[Step 1] Creating dataloaders...")
    loaders_and_tokenizer = create_dataloaders(data_config)
    train_loader, val_loader, _test_loader, tokenizer = loaders_and_tokenizer

    # [Step 2] Model
    print("\n[Step 2] Creating BLEACH model...")
    model = BLEACHModel(model_config, vocab_size=tokenizer.vocab_size)

    # [Step 3] Training
    print("\n[Step 3] Starting training...")
    Trainer(model, train_loader, val_loader, train_config, device).train()

    # [Step 4] Validation metrics
    print("\n[Step 4] Evaluating model...")
    evaluator = Evaluator(model, tokenizer, device)
    val_ppl, val_loss = evaluator.compute_perplexity(val_loader)
    print(f"Validation Perplexity: {val_ppl:.2f}")
    print(f"Validation Loss: {val_loss:.4f}")

    # [Step 5] Routing analysis
    print("\n[Step 5] Analyzing routing patterns...")
    routing_stats = evaluator.analyze_routing(val_loader)
    evaluator.plot_routing_heatmap(routing_stats, save_path="routing_heatmap.png")

    expert_usage = routing_stats['expert_counts'].cpu().numpy()
    print(f"Expert utilization: {expert_usage}")
    gate_entropy = routing_stats.get('gate_entropy', 0)
    print(f"Routing entropy: {gate_entropy:.4f}")

    # [Step 6] Ablation study
    print("\n[Step 6] Running ablation study...")
    ExperimentRunner(data_config).run_ablation_study()

    # Closing banner with pointers to the produced artefacts
    print("\n" + rule)
    print("EXPERIMENT COMPLETED SUCCESSFULLY")
    print(rule)
    for line in (
        "Next steps:",
        "1. Check routing_heatmap.png for expert utilization",
        "2. Check experiments/results.csv for ablation results",
        "3. Check bleach_checkpoints/ for saved models",
    ):
        print(line)
    print(rule)

# Script entry point: run the full pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

BLEACH: Bangla Language Expert Adaptive Corpus Handler
Using device: cuda
GPU: Tesla T4
Memory: 15.83 GB

[Step 1] Creating dataloaders...
Loaded 17630 samples from /content/cleaned_bangla_train (1).csv
Dialect distribution: {'Noakhali': 1845, 'Chittagong': 7550, 'Barisal': 3037, 'Sylhet': 2844, 'Mymensingh': 2354}
Loaded 3779 samples from /content/cleaned_bangla_val (1).csv
Dialect distribution: {'Sylhet': 609, 'Chittagong': 1618, 'Noakhali': 396, 'Barisal': 651, 'Mymensingh': 505}
Loaded 3779 samples from /content/cleaned_bangla_test (2).csv
Dialect distribution: {'Noakhali': 396, 'Chittagong': 1618, 'Barisal': 651, 'Sylhet': 609, 'Mymensingh': 505}

[Step 2] Creating BLEACH model...
MODEL PARAMETER COUNT
Total parameters: 186,825,984
Dense parameters: 79,275,264
MoE parameters: 107,550,720
Experts: 5
MoE layers: 3
Model size: 0.75 GB (FP32)
Model size: 0.37 GB (FP16)

[Step 3] Starting training...
Gradient checkpointing enabled
Trainer initialized with:
  Device: cuda
  Gradient acc

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch 0:   0%|          | 0/18875 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 500: val_loss = 6.7923
Checkpoint saved to ./bleach_checkpoints/best.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 1000: val_loss = 5.6824
Checkpoint saved to ./bleach_checkpoints/best.pt
Checkpoint saved to ./bleach_checkpoints/step_1000.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 1500: val_loss = 5.5230
Checkpoint saved to ./bleach_checkpoints/best.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 2000: val_loss = 5.7210
Checkpoint saved to ./bleach_checkpoints/step_2000.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 2500: val_loss = 5.8575


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 3000: val_loss = 5.6268
Checkpoint saved to ./bleach_checkpoints/step_3000.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 3500: val_loss = 5.3226
Checkpoint saved to ./bleach_checkpoints/best.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 4000: val_loss = 5.1142
Checkpoint saved to ./bleach_checkpoints/best.pt
Checkpoint saved to ./bleach_checkpoints/step_4000.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 4500: val_loss = 4.9949
Checkpoint saved to ./bleach_checkpoints/best.pt
Epoch 0 - Train loss: 4.3078


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch 0 - Val loss: 4.9629
Checkpoint saved to ./bleach_checkpoints/epoch_0.pt


Epoch 1:   0%|          | 0/18875 [00:10<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenize

Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 5000: val_loss = 4.9678
Checkpoint saved to ./bleach_checkpoints/best.pt
Checkpoint saved to ./bleach_checkpoints/step_5000.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 5500: val_loss = 4.8824
Checkpoint saved to ./bleach_checkpoints/best.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 6000: val_loss = 4.8297
Checkpoint saved to ./bleach_checkpoints/best.pt
Checkpoint saved to ./bleach_checkpoints/step_6000.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 6500: val_loss = 4.7714
Checkpoint saved to ./bleach_checkpoints/best.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 7000: val_loss = 4.7157
Checkpoint saved to ./bleach_checkpoints/best.pt
Checkpoint saved to ./bleach_checkpoints/step_7000.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 7500: val_loss = 4.6981
Checkpoint saved to ./bleach_checkpoints/best.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 8000: val_loss = 4.6444
Checkpoint saved to ./bleach_checkpoints/best.pt
Checkpoint saved to ./bleach_checkpoints/step_8000.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 8500: val_loss = 4.6698


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 9000: val_loss = 4.6381
Checkpoint saved to ./bleach_checkpoints/best.pt
Checkpoint saved to ./bleach_checkpoints/step_9000.pt
Epoch 1 - Train loss: 3.0566


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch 1 - Val loss: 4.6538
Checkpoint saved to ./bleach_checkpoints/epoch_1.pt


Epoch 2:   0%|          | 0/18875 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenize

Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 9500: val_loss = 4.6608


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 10000: val_loss = 4.6505
Checkpoint saved to ./bleach_checkpoints/step_10000.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 10500: val_loss = 4.6411


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 11000: val_loss = 4.6158
Checkpoint saved to ./bleach_checkpoints/best.pt
Checkpoint saved to ./bleach_checkpoints/step_11000.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 11500: val_loss = 4.6050
Checkpoint saved to ./bleach_checkpoints/best.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 12000: val_loss = 4.6172
Checkpoint saved to ./bleach_checkpoints/step_12000.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 12500: val_loss = 4.6521


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 13000: val_loss = 4.6833
Checkpoint saved to ./bleach_checkpoints/step_13000.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 13500: val_loss = 4.6742


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 14000: val_loss = 4.6606
Checkpoint saved to ./bleach_checkpoints/step_14000.pt
Epoch 2 - Train loss: 2.8917


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch 2 - Val loss: 4.6538
Checkpoint saved to ./bleach_checkpoints/epoch_2.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch 3:   0%|          | 0/18875 [00:10<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 14500: val_loss = 4.6527


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 15000: val_loss = 4.7006
Checkpoint saved to ./bleach_checkpoints/step_15000.pt


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 15500: val_loss = 4.7573


You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Evaluating:   0%|          | 0/1890 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step 16000: val_loss = 4.7648
Checkpoint saved to ./bleach_checkpoints/step_16000.pt
