# In [None]:
"""
Encoder-only Transformer (BERT-style) — single-file project
Features added:
- Transformer Encoder implementation (PyTorch)
- TokenEmbedding + SegmentEmbedding + PositionEmbedding
- Masked Language Modeling (MLM) head for pre-training
- Classification head (using [CLS])
- Question-Answering heads (start & end logits)
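- Utilities to freeze/unfreeze the encoder (feature extraction) and a simple distillation stub; domain adaptation has no dedicated helper here (one option is to continue MLM training on in-domain text with train_step_mlm)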
- Training loop placeholders and example usage

Notes for the user:
- Replace the dataset placeholders with your dataset loader/tokenizer.
- This file is intentionally framework-agnostic for tokenization (assumes a tokenizer that maps text -> input_ids, attention_mask, token_type_ids).
- If you prefer TensorFlow, tell me and I'll convert it.

"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

# -----------------------------
# Basic Modules
# -----------------------------
class TokenEmbedding(nn.Module):
    """Lookup table that turns vocabulary ids into dense vectors.

    Args:
        vocab_size: Number of rows in the embedding table.
        hidden_size: Width of every embedding vector.
    """

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size)

    def forward(self, input_ids):
        """Return one ``hidden_size`` vector per id in ``input_ids``."""
        vectors = self.emb(input_ids)
        return vectors

class PositionEmbedding(nn.Module):
    """Learned absolute position embeddings.

    Args:
        max_len: Number of positions in the table, i.e. the longest sequence
            that can be embedded.
        hidden_size: Width of every position vector.
    """

    def __init__(self, max_len: int, hidden_size: int):
        super().__init__()
        self.emb = nn.Embedding(max_len, hidden_size)

    def forward(self, seq_len):
        """Return position vectors for positions ``0 .. length - 1``.

        Args:
            seq_len: Either the sequence length as an ``int`` or a tensor whose
                second dimension is the sequence length (e.g. ``input_ids``).

        Returns:
            Tensor of shape ``(1, length, hidden_size)``; the leading axis of
            size 1 broadcasts over the batch.

        Raises:
            IndexError: If the requested length exceeds the table size.
        """
        length = seq_len if isinstance(seq_len, int) else seq_len.size(1)
        max_len = self.emb.num_embeddings
        # Fail early with a readable message: an out-of-range lookup otherwise
        # surfaces as an opaque index error (a device-side assert on CUDA).
        if length > max_len:
            raise IndexError(
                f"sequence length {length} exceeds the maximum of "
                f"{max_len} positions supported by this embedding"
            )
        pos_ids = torch.arange(length, device=self.emb.weight.device).unsqueeze(0)
        return self.emb(pos_ids)

class SegmentEmbedding(nn.Module):
    """Embedding of token-type (segment) ids.

    Args:
        type_vocab_size: Number of distinct segment ids.
        hidden_size: Width of every segment vector.
    """

    def __init__(self, type_vocab_size: int, hidden_size: int):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings=type_vocab_size, embedding_dim=hidden_size)

    def forward(self, token_type_ids):
        """Return one ``hidden_size`` vector per id in ``token_type_ids``."""
        segment_vectors = self.emb(token_type_ids)
        return segment_vectors

# -----------------------------
# Transformer Encoder Block
# -----------------------------
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        # Raise instead of assert: asserts are stripped under ``python -O``.
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.out = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``(B, T, H)``.
            mask: Optional tensor broadcastable to ``(B, heads, T, T)``
                (typically ``(B, 1, 1, T)``); entries equal to 0 mark keys
                that must not be attended to.

        Returns:
            Tensor of shape ``(B, T, H)``.
        """
        B, T, H = x.size()
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
        # (3, B, heads, T, head_dim) -> three tensors of (B, heads, T, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            # Use the dtype's most negative finite value rather than -inf:
            # a row whose keys are all masked (fully padded sequence) would
            # otherwise softmax to NaN and poison every later layer.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        out = attn @ v  # (B, heads, T, head_dim)
        out = out.permute(0, 2, 1, 3).reshape(B, T, H)
        return self.out(out)

class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> GELU -> Linear -> Dropout.

    Args:
        hidden_size: Input and output width.
        intermediate_size: Width of the inner layer.
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, hidden_size, intermediate_size, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_features=hidden_size, out_features=intermediate_size)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(in_features=intermediate_size, out_features=hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Transform ``x`` without changing its shape."""
        expanded = self.act(self.fc1(x))
        projected = self.fc2(expanded)
        return self.dropout(projected)

class EncoderLayer(nn.Module):
    """One encoder block: self-attention and feed-forward, each pre-normalised.

    Both sub-layers see a LayerNorm'd copy of the stream and their output is
    added back onto the un-normalised residual.

    Args:
        hidden_size: Model width.
        num_heads: Number of attention heads.
        intermediate_size: Inner width of the feed-forward sub-layer.
        dropout: Dropout probability forwarded to both sub-layers.
    """

    def __init__(self, hidden_size, num_heads, intermediate_size, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadSelfAttention(hidden_size, num_heads, dropout)
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ff = FeedForward(hidden_size, intermediate_size, dropout)
        self.ln2 = nn.LayerNorm(hidden_size)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is handed to the attention sub-layer."""
        # Attention sub-layer with residual connection.
        x = x + self.attn(self.ln1(x), mask)
        # Feed-forward sub-layer with residual connection.
        return x + self.ff(self.ln2(x))

class Encoder(nn.Module):
    """Stack of ``EncoderLayer`` blocks followed by a final LayerNorm.

    Args:
        num_layers: Number of stacked blocks.
        hidden_size: Model width.
        num_heads: Attention heads per block.
        intermediate_size: Inner width of each block's feed-forward sub-layer.
        dropout: Dropout probability forwarded to every block.
    """

    def __init__(self, num_layers, hidden_size, num_heads, intermediate_size, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            EncoderLayer(hidden_size, num_heads, intermediate_size, dropout)
            for _ in range(num_layers)
        )
        self.ln_final = nn.LayerNorm(hidden_size)

    def forward(self, x, mask=None):
        """Pass ``x`` through every block in order, then normalise the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.ln_final(hidden)

# -----------------------------
# BERT-style Model
# -----------------------------
class BertStyleModel(nn.Module):
    """Encoder-only transformer with MLM, classification and QA heads.

    Args:
        vocab_size: Size of the token vocabulary.
        hidden_size: Model width.
        max_position_embeddings: Longest sequence the position table supports.
        type_vocab_size: Number of segment (token-type) ids.
        num_layers: Number of encoder blocks.
        num_heads: Attention heads per block.
        intermediate_size: Inner width of the feed-forward sub-layers.
        dropout: Dropout probability used throughout.
        num_labels: Output width of the classification head (2 = binary,
            1 = single-value regression).
        pad_token_id: Token id treated as padding when ``mlm`` has to derive
            an attention mask itself.
    """

    def __init__(self,
                 vocab_size: int = 30522,
                 hidden_size: int = 768,
                 max_position_embeddings: int = 512,
                 type_vocab_size: int = 2,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 intermediate_size: int = 3072,
                 dropout: float = 0.1,
                 num_labels: int = 2,
                 pad_token_id: int = 0):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_emb = TokenEmbedding(vocab_size, hidden_size)
        self.pos_emb = PositionEmbedding(max_position_embeddings, hidden_size)
        self.segment_emb = SegmentEmbedding(type_vocab_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.encoder = Encoder(num_layers, hidden_size, num_heads, intermediate_size, dropout)
        # Heads
        self.mlm_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, vocab_size)
        )
        # Width is configurable so that classify()/compute_classification_loss
        # can actually be used for non-binary tasks.
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.qa_start = nn.Linear(hidden_size, 1)
        self.qa_end = nn.Linear(hidden_size, 1)

    def forward_embeddings(self, input_ids, token_type_ids=None):
        """Sum token, position and segment embeddings, then LayerNorm + dropout.

        Args:
            input_ids: Token ids of shape ``(B, T)``.
            token_type_ids: Optional segment ids of shape ``(B, T)``; all
                zeros are used when omitted.

        Returns:
            Tensor of shape ``(B, T, hidden_size)``.
        """
        tok = self.token_emb(input_ids)
        pos = self.pos_emb(input_ids.size(1))
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        seg = self.segment_emb(token_type_ids)
        embeddings = tok + pos + seg
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Encode a batch of token ids.

        Args:
            input_ids: Token ids of shape ``(B, T)``.
            attention_mask: Optional ``(B, T)`` tensor, 1 = keep, 0 = padding.
            token_type_ids: Optional segment ids of shape ``(B, T)``.

        Returns:
            Encoder output of shape ``(B, T, hidden_size)``.
        """
        embeddings = self.forward_embeddings(input_ids, token_type_ids)
        if attention_mask is not None:
            # (B, T) -> (B, 1, 1, T) so it broadcasts over heads and queries.
            mask = attention_mask.unsqueeze(1).unsqueeze(2)
        else:
            mask = None
        encoded = self.encoder(embeddings, mask)
        return encoded

    # ---------- Task-specific helpers ----------
    def mlm(self, input_ids, mask_positions, token_type_ids=None, attention_mask=None):
        """Return vocabulary logits for every position.

        Args:
            input_ids: Token ids of shape ``(B, T)``.
            mask_positions: Bool tensor ``(B, T)`` marking [MASK] positions.
                It is not used here; logits are returned for all positions
                and the caller selects the masked ones.
            token_type_ids: Optional segment ids of shape ``(B, T)``.
            attention_mask: Optional ``(B, T)`` tensor, 1 = keep, 0 = padding.
                When omitted, every token different from ``pad_token_id`` is
                kept.

        Returns:
            Logits of shape ``(B, T, vocab_size)``.
        """
        if attention_mask is None:
            attention_mask = (input_ids != self.pad_token_id).long()
        encoded = self.forward(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        logits = self.mlm_head(encoded)  # (B, T, V)
        return logits

    def classify(self, input_ids, attention_mask=None, token_type_ids=None, num_labels=2):
        """Classify each sequence from the representation at position 0.

        Args:
            input_ids: Token ids of shape ``(B, T)``; position 0 is assumed
                to hold the [CLS] token.
            attention_mask: Optional ``(B, T)`` tensor, 1 = keep, 0 = padding.
            token_type_ids: Optional segment ids of shape ``(B, T)``.
            num_labels: Only controls output shaping: when 1, the trailing
                axis is squeezed. The head's width is fixed by the
                constructor's ``num_labels``.

        Returns:
            Classifier logits of shape ``(B, C)``, or ``(B,)`` when
            ``num_labels == 1`` and the head has a single output.
        """
        encoded = self.forward(input_ids, attention_mask, token_type_ids)
        cls_embed = encoded[:, 0, :]
        logits = self.classifier(cls_embed)
        if num_labels == 1:
            return logits.squeeze(-1)
        return logits

    def qa(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return per-token start and end logits, each of shape ``(B, T)``."""
        encoded = self.forward(input_ids, attention_mask, token_type_ids)
        start_logits = self.qa_start(encoded).squeeze(-1)  # (B, T)
        end_logits = self.qa_end(encoded).squeeze(-1)
        return start_logits, end_logits

    def _set_backbone_trainable(self, trainable: bool, include_embeddings: bool):
        """Set ``requires_grad`` on the encoder and, optionally, the embeddings."""
        modules = [self.encoder]
        if include_embeddings:
            modules += [self.token_emb, self.pos_emb, self.segment_emb, self.layer_norm]
        for module in modules:
            for p in module.parameters():
                p.requires_grad = trainable

    # Utility: freeze encoder (feature extraction mode)
    def freeze_encoder(self, include_embeddings: bool = False):
        """Stop gradient updates to the encoder stack.

        Args:
            include_embeddings: Also freeze the token/position/segment
                embeddings and their LayerNorm. Defaults to ``False`` so
                existing callers keep the previous behaviour; pass ``True``
                for pure feature extraction where only the heads train.
        """
        self._set_backbone_trainable(False, include_embeddings)

    def unfreeze_encoder(self, include_embeddings: bool = False):
        """Re-enable gradient updates to the encoder stack.

        Args:
            include_embeddings: Also unfreeze the embeddings and their
                LayerNorm (needed after ``freeze_encoder(True)``).
        """
        self._set_backbone_trainable(True, include_embeddings)

# -----------------------------
# Training Utilities (skeleton)
# -----------------------------

def compute_mlm_loss(logits, target_ids, mask_positions):
    """Mean cross-entropy over the masked positions only.

    Args:
        logits: Vocabulary logits of shape ``(B, T, V)``.
        target_ids: Target token ids of shape ``(B, T)``. Only the entries at
            masked positions are read, so the rest may hold any filler value.
        mask_positions: ``(B, T)`` mask (bool or 0/1) marking the positions
            that contribute to the loss.

    Returns:
        Scalar loss tensor. When no position is masked the result is a zero
        that is still attached to ``logits``' graph, so ``backward()`` works.
    """
    vocab_size = logits.size(-1)
    mask_flat = mask_positions.reshape(-1).bool()
    num_masked = mask_flat.sum()
    if num_masked == 0:
        # Keep the autograd graph alive; a fresh constant tensor would make
        # loss.backward() raise because nothing requires grad.
        return logits.sum() * 0.0
    # Select the masked rows *before* the softmax: cheaper, and filler targets
    # or non-finite logits at unmasked positions can no longer affect the loss.
    picked_logits = logits.reshape(-1, vocab_size)[mask_flat]
    picked_targets = target_ids.reshape(-1)[mask_flat]
    total = F.cross_entropy(picked_logits, picked_targets, reduction='sum')
    return total / num_masked


def compute_classification_loss(logits, labels):
    """Loss for the classification head.

    Uses mean-squared error when the head produces a single value per example
    and cross-entropy otherwise.

    Args:
        logits: ``(B, C)`` class logits, ``(B, 1)`` single-value outputs, or
            ``(B,)`` already-squeezed single-value outputs.
        labels: Class indices of shape ``(B,)`` for classification, or target
            values of shape ``(B,)`` / ``(B, 1)`` for the single-value case.

    Returns:
        Scalar loss tensor.
    """
    # A squeezed (B,) output paired with (B,) labels is the single-value case
    # too; without this check it would be sent to cross-entropy and fail.
    single_value = logits.size(-1) == 1 or (
        logits.dim() == 1 and logits.shape == labels.shape
    )
    if single_value:
        # Flatten both sides so (B,) vs (B, 1) cannot broadcast to (B, B).
        return F.mse_loss(logits.reshape(-1), labels.reshape(-1).float())
    return F.cross_entropy(logits, labels.long())


def compute_qa_loss(start_logits, end_logits, start_positions, end_positions):
    """Average of the start- and end-position cross-entropy losses.

    Positions that fall outside ``[0, T)`` (for example an answer that was cut
    off when the context was truncated) are ignored instead of raising, as are
    positions equal to -100.

    Args:
        start_logits: ``(B, T)`` logits for the answer start.
        end_logits: ``(B, T)`` logits for the answer end.
        start_positions: ``(B,)`` or ``(B, 1)`` gold start indices.
        end_positions: ``(B,)`` or ``(B, 1)`` gold end indices.

    Returns:
        Scalar loss tensor; each half is averaged over the examples whose
        position is usable (zero if none is).
    """
    ignore_index = -100
    loss_fct = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')

    def _position_loss(logits, positions):
        # Accept a trailing singleton axis on the positions.
        if positions.dim() > 1:
            positions = positions.squeeze(-1)
        positions = positions.long()
        unusable = (positions < 0) | (positions >= logits.size(-1))
        positions = positions.masked_fill(unusable, ignore_index)
        # Divide by the usable count (at least 1) so an all-ignored batch
        # yields 0 rather than NaN.
        usable = (~unusable).sum().clamp(min=1)
        return loss_fct(logits, positions) / usable

    start_loss = _position_loss(start_logits, start_positions)
    end_loss = _position_loss(end_logits, end_positions)
    return (start_loss + end_loss) / 2

# Simple training step skeleton (single-batch)
def train_step_mlm(model: BertStyleModel, batch, optimizer):
    """Run one masked-language-modelling optimisation step.

    Args:
        model: Model to train; put into training mode by this call.
        batch: Mapping with ``'input_ids'``, ``'mask_positions'`` and
            ``'mlm_labels'``; ``'token_type_ids'`` and ``'attention_mask'``
            are optional.
        optimizer: Optimizer holding the model's parameters.

    Returns:
        The loss as a Python float, or ``0.0`` (without touching the
        optimizer) when the batch contains no masked position.
    """
    model.train()
    input_ids = batch['input_ids']
    token_type_ids = batch.get('token_type_ids')
    attention_mask = batch.get('attention_mask')
    mask_positions = batch['mask_positions']  # bool
    mlm_labels = batch['mlm_labels']

    # Nothing to learn from: skip instead of back-propagating a loss that may
    # not be attached to any parameter.
    if mask_positions.sum() == 0:
        return 0.0

    if attention_mask is None:
        logits = model.mlm(input_ids, mask_positions, token_type_ids)
    else:
        # Honour the batch's own padding mask rather than letting the model
        # re-derive one from the token ids.
        encoded = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        logits = model.mlm_head(encoded)
    loss = compute_mlm_loss(logits, mlm_labels, mask_positions)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Fine-tuning step for classification
def train_step_classification(model: BertStyleModel, batch, optimizer):
    """Run one fine-tuning step for sequence classification.

    Reads ``'input_ids'`` and ``'labels'`` (required) plus
    ``'token_type_ids'`` and ``'attention_mask'`` (optional) from ``batch``,
    updates the model once and returns the loss as a Python float.
    """
    model.train()

    ids = batch['input_ids']
    segments = batch.get('token_type_ids')
    padding_mask = batch.get('attention_mask')
    targets = batch['labels']

    loss = compute_classification_loss(
        model.classify(ids, padding_mask, segments),
        targets,
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Fine-tuning step for QA
def train_step_qa(model: BertStyleModel, batch, optimizer):
    """Run one fine-tuning step for extractive question answering.

    Reads ``'input_ids'``, ``'start_positions'`` and ``'end_positions'``
    (required) plus ``'token_type_ids'`` and ``'attention_mask'`` (optional)
    from ``batch``, updates the model once and returns the loss as a Python
    float.
    """
    model.train()

    ids = batch['input_ids']
    segments = batch.get('token_type_ids')
    padding_mask = batch.get('attention_mask')
    gold_start = batch['start_positions']
    gold_end = batch['end_positions']

    span_logits = model.qa(ids, padding_mask, segments)
    loss = compute_qa_loss(span_logits[0], span_logits[1], gold_start, gold_end)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# -----------------------------
# Distillation stub (teacher -> student)
# -----------------------------
def distill_step(teacher: BertStyleModel, student: BertStyleModel, batch, optimizer, alpha=0.5, temperature=2.0):
    """Run one MLM distillation step from ``teacher`` to ``student``.

    The loss is ``alpha * KL(teacher || student)`` on temperature-softened
    distributions at the masked positions plus ``(1 - alpha)`` times the
    student's ordinary MLM loss. This is a stub; adapt it to your target
    (MLM distillation or task distillation).

    Args:
        teacher: Frozen reference model; switched to eval mode.
        student: Model being trained; switched to training mode.
        batch: Mapping with ``'input_ids'``, ``'mask_positions'`` and
            ``'mlm_labels'``; ``'token_type_ids'`` and ``'attention_mask'``
            are optional.
        optimizer: Optimizer holding the student's parameters.
        alpha: Weight of the soft-target (KL) term.
        temperature: Softmax temperature applied to both models' logits.

    Returns:
        The combined loss as a Python float, or ``None`` when the batch has
        no masked position (no update is performed in that case).
    """
    teacher.eval()
    student.train()
    input_ids = batch['input_ids']
    token_type_ids = batch.get('token_type_ids')
    attention_mask = batch.get('attention_mask')
    mask_positions = batch['mask_positions']
    mlm_labels = batch['mlm_labels']

    # Bail out before paying for two forward passes.
    mask_flat = mask_positions.reshape(-1).bool()
    if mask_flat.sum() == 0:
        return None

    def _mlm_logits(model):
        # Honour the batch's padding mask when one is supplied; otherwise let
        # the model derive it from the token ids.
        if attention_mask is None:
            return model.mlm(input_ids, mask_positions, token_type_ids)
        encoded = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        return model.mlm_head(encoded)

    with torch.no_grad():
        t_logits = _mlm_logits(teacher)
    s_logits = _mlm_logits(student)

    # Keep only the masked rows before the softmax over the vocabulary, so
    # unmasked rows cost nothing and cannot inject non-finite values.
    vocab = s_logits.size(-1)
    s_masked = s_logits.reshape(-1, vocab)[mask_flat]
    t_masked = t_logits.reshape(-1, vocab)[mask_flat]
    s_log = F.log_softmax(s_masked / temperature, dim=-1)
    t_log = F.log_softmax(t_masked / temperature, dim=-1)
    # 'batchmean' divides the summed KL by the number of masked rows.
    # NOTE(review): many distillation recipes also multiply this term by
    # temperature ** 2; left unscaled to keep existing loss values.
    kd_loss = F.kl_div(s_log, t_log, reduction='batchmean', log_target=True)

    # hard mlm loss for student
    mlm_loss = compute_mlm_loss(s_logits, mlm_labels, mask_positions)
    loss = alpha * kd_loss + (1 - alpha) * mlm_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# -----------------------------
# Example usage (pseudocode)
# -----------------------------
if __name__ == '__main__':
    # Demo configuration: a reduced model so a single step finishes quickly.
    VOCAB_SIZE = 30522
    model = BertStyleModel(
        vocab_size=VOCAB_SIZE,
        hidden_size=256,
        max_position_embeddings=512,
        num_layers=4,
        num_heads=8,
        intermediate_size=1024,
    )

    # Stand-in batch; swap in the output of a real tokenizer/dataloader.
    # A tokenizer is expected to supply input_ids, token_type_ids and
    # attention_mask; MLM additionally needs mask_positions (bool) and
    # mlm_labels (target ids at the masked positions).
    batch_shape = (2, 64)
    dummy_batch = {
        'input_ids': torch.randint(0, VOCAB_SIZE, batch_shape),
        'token_type_ids': torch.zeros(batch_shape, dtype=torch.long),
        'attention_mask': torch.ones(batch_shape, dtype=torch.long),
        'mask_positions': torch.zeros(batch_shape, dtype=torch.bool),
        'mlm_labels': torch.zeros(batch_shape, dtype=torch.long),
        'labels': torch.tensor([0, 1]),
        'start_positions': torch.tensor([10, 5]),
        'end_positions': torch.tensor([12, 7]),
    }

    optim = torch.optim.AdamW(model.parameters(), lr=5e-5)

    # One classification step as a smoke test.
    loss = train_step_classification(model, dummy_batch, optim)
    print('demo loss (classification):', loss)

    # Feature-extraction mode: encoder weights stop receiving updates ...
    model.freeze_encoder()
    # ... train only the heads here ...

    # ... and switch back to full fine-tuning.
    model.unfreeze_encoder()

    print('Model ready. Replace dummy_batch with your dataloader and tokenizer.')

# End of file
