# 6.4610 Research Project

## Overview
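In this file, we implement a causal transformer language model trained on OpenWebText, together with two attention variants: a graph-Laplacian diffusion-kernel model and a persistent-homology (PH) model.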

## Imports

In [None]:
import torch
import torch.nn as nn
from typing import Tuple, Union, Optional, List, Dict
import math
import numpy as np
from dataclasses import dataclass
from transformers import AutoTokenizer
import datasets
from tqdm import tqdm
import os
import json
from itertools import islice

## Preliminaries

In [None]:
# Dataset sizes. NOTE(review): the code that consumes these is outside this
# part of the notebook; presumably they are sample counts -- confirm.
TRAIN_DATASET_SIZE = 200000
# Misspelled legacy name, kept as an alias so existing references keep working.
TRAIN_DATASETE_SIZE = TRAIN_DATASET_SIZE
TEST_DATASET_SIZE = 10000
# Number of Taylor-series terms (orders 0 .. N-1) summed by matrix_exponential.
TAYLOR_APPROXIMATION = 5
# Default weight of the PH bias added to the attention scores (PHAttentionHead).
PH_ALPHA = 0.5
# Number of semantic (PH-biased) heads built by PHMultiHeadAttention.
PH_SEMANTIC_HEADS = 3

@dataclass
class TransformerConfig:
    """Configuration class for transformer model

    Holds the architecture hyperparameters read by the model classes in this
    file (TransformerModel / TransformerModelL and the causal LM wrappers).
    """
    # Rows of the token embedding table and output width of the LM head.
    # NOTE(review): 50257 matches the GPT-2 BPE vocabulary size -- confirm it
    # agrees with the tokenizer actually used.
    vocab_size: int = 50257
    # Model (embedding / residual stream) dimension.
    hidden_size: int = 768
    # Heads per attention layer; must divide hidden_size (asserted in the
    # multi-head attention constructors).
    num_attention_heads: int = 12
    # Number of stacked transformer blocks.
    num_hidden_layers: int = 12
    # Hidden width of the feed-forward network inside each block.
    intermediate_size: int = 3072
    # Size of the learned position-embedding table, i.e. the longest
    # sequence the model can embed.
    max_position_embeddings: int = 512
    # When True and no mask is passed, the model builds a lower-triangular
    # (causal) attention mask itself.
    use_causal_mask: bool = True
    # Number of diffusion kernels per attention head; only read by the
    # Laplacian model (TransformerModelL).
    number_diffusion_kernels: int = 4

@dataclass
class TrainingConfig:
    """Hyperparameters and paths for a training run.

    NOTE(review): the code that consumes this class is outside this part of
    the notebook; the per-field notes below are inferred from names and
    should be confirmed against the training loop.
    """
    # Model hyperparameters
    # (same names and defaults as the corresponding TransformerConfig fields)
    vocab_size: int = 50257
    hidden_size: int = 768
    num_attention_heads: int = 12
    num_hidden_layers: int = 12
    intermediate_size: int = 3072
    max_position_embeddings: int = 512
    use_causal_mask: bool = True

    # Training hyperparameters
    batch_size: int = 4
    learning_rate: float = 5e-4
    weight_decay: float = 0.01
    num_epochs: int = 3
    steps_per_epoch: int = 200000
    warmup_steps: int = 1000
    max_grad_norm: float = 1.0  # presumably the gradient-clipping threshold -- confirm
    save_steps: int = 10000     # presumably checkpoint interval in steps -- confirm
    eval_steps: int = 5000      # presumably evaluation interval in steps -- confirm
    # Same values as the module-level dataset-size constants.
    train_dataset_size: int = 200000
    test_dataset_size: int = 10000

    # Paths
    # Both defaults are absolute paths at the filesystem root.
    output_dir: str = "/transformer"
    log_dir: str = "/logs_transformer"

In [None]:
class FeedForward(nn.Module):
    """Two-layer position-wise MLP: Linear -> GELU -> Linear.

    The same weights are applied at every position; only the last dimension
    of the input is transformed.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        """
        Build the two projections and the GELU non-linearity.

        Args:
            hidden_size: Model dimension (input and output width)
            intermediate_size: Width of the layer between the two projections
        """
        super().__init__()

        self.activation = nn.GELU()
        self.linear1 = nn.Linear(hidden_size, intermediate_size)
        self.linear2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand to ``intermediate_size``, apply GELU, project back to ``hidden_size``."""
        expanded = self.activation(self.linear1(x))
        return self.linear2(expanded)

def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

## Original Model

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-feature gain."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """
        Args:
            hidden_size: Size of the last (feature) dimension
            eps: Added to the mean square before the square root for stability
        """
        super().__init__()
        # Learned gain, initialised to 1 so the layer starts as plain RMS scaling.
        self.parameter = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Divide each feature vector by its RMS and apply the learned gain.

        Args:
            hidden_states: Tensor of shape (batch_size, seq_len, hidden_size)

        Returns:
            Tensor with the same shape as the input
        """
        mean_square = hidden_states.square().mean(dim=-1, keepdim=True)
        denom = (mean_square + self.eps).sqrt()
        return (hidden_states / denom) * self.parameter


In [None]:
class AttentionHead(nn.Module):
    """One scaled dot-product attention head with bias-free Q/K/V projections."""

    def __init__(self, hidden_size: int, head_dim: int):
        """
        Args:
            hidden_size: Width of the incoming hidden states
            head_dim: Width of this head's query/key/value vectors
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.WQ = nn.Linear(hidden_size, head_dim, bias=False)
        self.WK = nn.Linear(hidden_size, head_dim, bias=False)
        self.WV = nn.Linear(hidden_size, head_dim, bias=False)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Attend over the sequence with a single head.

        Args:
            x: Hidden states (batch_size, seq_len, hidden_size)
            attn_mask: Optional mask broadcastable to (batch_size, seq_len, seq_len);
                entries equal to 0 are blocked, everything else may be attended to

        Returns:
            attention_output: (batch_size, seq_len, head_dim)
            attention_weights: (batch_size, seq_len, seq_len), each row sums to 1
        """
        queries, keys, values = self.WQ(x), self.WK(x), self.WV(x)

        # Scaled dot-product logits.
        logits = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            # Blocked positions get -inf so softmax assigns them zero weight.
            logits = logits.masked_fill(attn_mask == 0, float("-inf"))

        probs = logits.softmax(dim=-1)
        return probs @ values, probs

In [None]:
class MultiHeadAttention(nn.Module):
    """Runs ``num_heads`` independent AttentionHead modules and mixes their outputs."""

    def __init__(self, hidden_size: int, num_heads: int):
        """
        Args:
            hidden_size: Model dimension; must be divisible by ``num_heads``
            num_heads: Number of attention heads
        """
        super().__init__()
        assert hidden_size % num_heads == 0, f"The hidden size {hidden_size} is not divisible by the number of heads {num_heads}."
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(hidden_size, self.head_dim) for _ in range(num_heads)]
        )
        # Output projection applied to the concatenated head outputs.
        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply every head to the same input and project the concatenation.

        Args:
            hidden_states: Input tensor (batch_size, seq_len, hidden_size)
            attention_mask: Attention mask (batch_size, seq_len, seq_len)

        Returns:
            attention_output: (batch_size, seq_len, hidden_size)
            attention_weights: (batch_size, num_heads, seq_len, seq_len)
        """
        per_head = [head(hidden_states, attention_mask) for head in self.heads]
        head_outputs, head_weights = zip(*per_head)

        # Head h occupies features [h * head_dim, (h + 1) * head_dim).
        concatenated = torch.cat(head_outputs, dim=-1)
        attention_weights = torch.stack(head_weights, dim=1)
        return self.linear(concatenated), attention_weights

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual."""

    def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int):
        """
        Args:
            hidden_size: Model dimension
            num_heads: Number of attention heads
            intermediate_size: Hidden width of the feed-forward network
        """
        super().__init__()
        self.rms_att = RMSNorm(hidden_size=hidden_size)
        self.rms_ffn = RMSNorm(hidden_size=hidden_size)
        self.mha = MultiHeadAttention(hidden_size=hidden_size, num_heads=num_heads)
        self.ffn = FeedForward(hidden_size=hidden_size, intermediate_size=intermediate_size)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Run one block.

        Args:
            hidden_states: Input tensor (batch_size, seq_len, hidden_size)
            attention_mask: Mask forwarded unchanged to the attention layer

        Returns:
            Output tensor (batch_size, seq_len, hidden_size)
        """
        # Attention sub-layer: normalise, attend, add the residual.
        # The attention weights are discarded here.
        attended, _ = self.mha(self.rms_att(hidden_states), attention_mask)
        after_attention = attended + hidden_states

        # Feed-forward sub-layer: normalise, transform, add the residual.
        return after_attention + self.ffn(self.rms_ffn(after_attention))

In [None]:
def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Build a lower-triangular mask: position i may attend to positions 0..i.

    Args:
        seq_len: Sequence length
        device: Device on which the mask is allocated

    Returns:
        Tensor of ones and zeros with shape (1, seq_len, seq_len); the leading
        dimension broadcasts over the batch
    """
    return torch.ones((seq_len, seq_len), device=device).tril().unsqueeze(0)

In [None]:
class TransformerModel(nn.Module):
    """Transformer backbone: token + learned position embeddings, a stack of
    TransformerBlock layers, and a final RMSNorm."""

    def __init__(self, config: TransformerConfig):
        """
        Args:
            config: Architecture hyperparameters (vocabulary size, widths, depth, ...)
        """
        super().__init__()
        self.config = config
        # Records whether CUDA is available; nothing in this class moves the
        # parameters to this device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        cfg = self.config
        self.embeddings = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.hidden_size)
        self.pos_embeddings = nn.Embedding(num_embeddings=cfg.max_position_embeddings, embedding_dim=cfg.hidden_size)

        blocks = []
        for _ in range(cfg.num_hidden_layers):
            blocks.append(
                TransformerBlock(
                    hidden_size=cfg.hidden_size,
                    num_heads=cfg.num_attention_heads,
                    intermediate_size=cfg.intermediate_size,
                )
            )
        self.transformer = nn.ModuleList(blocks)
        self.norm = RMSNorm(hidden_size=cfg.hidden_size)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Embed the tokens and run them through every block.

        Args:
            input_ids: Token IDs (batch_size, seq_len); seq_len must not exceed
                ``config.max_position_embeddings``
            attention_mask: Optional mask (batch_size, seq_len, seq_len). When
                omitted and ``config.use_causal_mask`` is set, a causal mask is
                built automatically.

        Returns:
            Final normalised hidden states (batch_size, seq_len, hidden_size)
        """
        seq_len = input_ids.shape[1]
        position_ids = torch.arange(seq_len, device=input_ids.device)[None, :].expand(input_ids.size(0), -1)
        hidden = self.embeddings(input_ids) + self.pos_embeddings(position_ids)

        mask = attention_mask
        if mask is None and self.config.use_causal_mask:
            mask = create_causal_mask(seq_len, hidden.device)

        for block in self.transformer:
            hidden = block(hidden, attention_mask=mask)
        return self.norm(hidden)

In [None]:
class CausalLanguageModel(nn.Module):
    """Causal language model: TransformerModel backbone plus a linear LM head."""

    def __init__(self, config: TransformerConfig):
        """
        Args:
            config: Architecture hyperparameters shared with the backbone
        """
        super().__init__()
        self.config = config
        self.transformer = TransformerModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Default CrossEntropyLoss: mean reduction, targets equal to -100 are ignored.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass for language model

        Args:
            input_ids: Token IDs (batch_size, seq_len)
            labels: Target labels for loss computation (batch_size, seq_len).
                They are shifted internally, so callers pass labels aligned
                with ``input_ids``.

        Returns:
            If labels provided: (loss, logits)
            Else: logits only, shape (batch_size, seq_len, vocab_size)
        """
        hidden_states = self.transformer(input_ids)
        logits = self.lm_head(hidden_states)
        if labels is None:
            return logits

        # Next-token objective: the logit at position t is scored against the
        # label at position t + 1, so drop the last logit and the first label.
        logits_flat = logits[:, :-1, :].flatten(0, 1)
        labels_flat = labels[:, 1:].flatten(0, 1)
        return self.criterion(logits_flat, labels_flat), logits

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100, temperature: float = 1.0) -> torch.Tensor:
        """
        Generate text using the language model

        Args:
            input_ids: Starting token IDs (batch_size, seq_len)
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature; 0 selects the most likely token
                (greedy decoding) instead of dividing by zero

        Returns:
            Generated token IDs (batch_size, seq_len + max_new_tokens)
        """
        max_context = self.config.max_position_embeddings
        for _ in range(max_new_tokens):
            # The position-embedding table only covers max_context positions,
            # so condition on the most recent window once the sequence outgrows it.
            context = input_ids[:, -max_context:]
            next_logits = self.forward(context)[:, -1, :]
            if temperature == 0:
                next_token = next_logits.argmax(dim=-1, keepdim=True)
            else:
                probs = nn.functional.softmax(next_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, 1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids


## Laplacian Model

In [None]:
class RMSNorm(nn.Module):
    """RMS normalization over the last dimension, followed by a learned scale."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """
        Args:
            hidden_size: Number of features in the last dimension
            eps: Stabiliser added under the square root
        """
        super().__init__()
        self.parameter = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Normalise ``hidden_states`` by its root mean square.

        Args:
            hidden_states: Input of shape (batch_size, seq_len, hidden_size)

        Returns:
            Tensor of the same shape, scaled by the learned per-feature weight
        """
        squared = torch.square(hidden_states)
        root_mean_square = torch.sqrt(squared.mean(dim=-1, keepdim=True) + self.eps)
        scaled = hidden_states / root_mean_square
        return scaled * self.parameter

In [None]:
def matrix_exponential(laplacian, num_terms=None):
    """
    Truncated Taylor approximation of the matrix exponential.

    Computes sum_{l=0}^{num_terms-1} laplacian^l / l! over the last two
    dimensions, independently for every leading (batch, kernel) index.

    Args:
        laplacian: Input tensor (batch_size, number_diffusion_kernels, seq_len, seq_len)
        num_terms: Number of series terms to sum. Defaults to the module-level
            TAYLOR_APPROXIMATION when omitted.

    Returns:
        Tensor with the same shape, dtype and device as ``laplacian``.
    """
    if num_terms is None:
        num_terms = TAYLOR_APPROXIMATION

    seq_len = laplacian.shape[-1]
    # Match the input dtype: a default float32 identity makes matmul fail for
    # float64/half inputs.
    identity_matrix = torch.eye(seq_len, device=laplacian.device, dtype=laplacian.dtype)
    # expand_as gives a broadcast view of the zeroth power, with no per-batch copy.
    laplacian_power = identity_matrix.expand_as(laplacian)

    taylor_sum = torch.zeros_like(laplacian)

    for order in range(num_terms):
        taylor_sum += laplacian_power * (1.0 / math.factorial(order))
        # The next power is only needed if another term follows.
        if order + 1 < num_terms:
            laplacian_power = torch.matmul(laplacian_power, laplacian)

    return taylor_sum

In [None]:
class AttentionHeadL(nn.Module):
    def __init__(self, hidden_size: int, head_dim: int, number_diffusion_kernels: int):
        """
        Single attention head whose scores are biased by graph diffusion kernels.

        Besides the usual Q/K/V projections, the head projects the input with
        WR, builds thresholded distance graphs over the sequence positions, and
        adds a weighted sum of (approximate) heat kernels exp(-beta * L) of
        those graphs to the attention logits.

        Args:
            hidden_size: Input dimension
            head_dim: Dimension of each attention head
            number_diffusion_kernels: Number of graphs/kernels combined per head
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.number_diffusion_kernels = number_diffusion_kernels
        self.WQ = nn.Linear(hidden_size, head_dim, bias=False)
        self.WK = nn.Linear(hidden_size, head_dim, bias=False)
        self.WV = nn.Linear(hidden_size, head_dim, bias=False)
        # Projection used only to measure pairwise distances between positions.
        self.WR = nn.Linear(hidden_size, head_dim, bias=False)

        # Per-kernel parameters: diffusion time (passed through softplus),
        # mixing weight, and radius ratio (passed through sigmoid).
        self.log_beta = nn.Parameter(torch.randn(number_diffusion_kernels))
        self.weights = nn.Parameter(torch.rand(number_diffusion_kernels))
        self.ratio_r = nn.Parameter(torch.randn(number_diffusion_kernels))

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for attention head

        Args:
            x: Input tensor (batch_size, seq_len, hidden_size)
            attn_mask: Attention mask (batch_size, seq_len, seq_len) - 1 for attend, 0 for mask

        Returns:
            attention_output: (batch_size, seq_len, head_dim)
            attention_weights: (batch_size, seq_len, seq_len)
        """
        seq_len = x.shape[1]
        device = x.device

        R = self.WR(x) # (batch_size, seq_len, head_dim)

        # Pairwise Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b; relu clamps
        # small negative values caused by floating-point error before the sqrt.
        dots = torch.matmul(R, R.transpose(-2, -1)) # (batch_size, seq_len, seq_len)
        R_norms_sq = torch.sum(R * R, dim=-1, keepdim=True) # (batch_size, seq_len, 1)
        squared_distance = R_norms_sq + R_norms_sq.transpose(-2, -1) - 2 * dots
        distance = torch.sqrt(torch.relu(squared_distance)) # (batch_size, seq_len, seq_len)

        # True everywhere except on the diagonal.
        self_loop_mask = (1.0 - torch.eye(seq_len, device=device)).bool()
        identity_matrix = torch.eye(seq_len, device=device)

        # Largest distance and smallest off-diagonal distance per batch element.
        max_r = torch.amax(distance, (1, 2))
        min_r = torch.amin(torch.where(self_loop_mask, distance, torch.inf), (1, 2))

        # One radius per kernel, interpolated between min_r and max_r by sigmoid(ratio_r).
        ratio_r_expanded = self.ratio_r.view(1, self.number_diffusion_kernels)
        restricted_ratio_r = torch.sigmoid(ratio_r_expanded)
        range_r = (max_r - min_r).unsqueeze(-1)
        min_r_expanded = min_r.unsqueeze(-1)
        values_r = min_r_expanded + range_r * restricted_ratio_r

        distance_expanded = distance.unsqueeze(1)
        values_r_expanded = values_r.unsqueeze(-1).unsqueeze(-1)

        # Unweighted graph per kernel: connect positions closer than the radius, no self loops.
        # NOTE(review): the comparison yields a boolean tensor, so no gradient
        # reaches WR or ratio_r through this path -- confirm that is intended.
        mask = (distance_expanded < values_r_expanded) # (batch_size, number_diffusion_kernels, seq_len, seq_len)
        adjacency = mask.float() * self_loop_mask.float()

        # Combinatorial graph Laplacian L = D - A.
        degree_sums = torch.sum(adjacency, dim=-1).unsqueeze(-1)
        laplacian = degree_sums * identity_matrix - adjacency

        # Approximate heat kernel exp(-beta * L); softplus keeps beta positive.
        positive_beta = torch.nn.functional.softplus(self.log_beta)
        beta_broadcastable = positive_beta.view(1, self.number_diffusion_kernels, 1, 1)
        weighted_laplacian = -laplacian * beta_broadcastable
        kernels = matrix_exponential(weighted_laplacian)

        # Mix the kernels into a single additive bias for the attention logits.
        weights_broadcastable= self.weights.view(1, self.number_diffusion_kernels, 1, 1)
        final_combined_kernel = torch.sum(kernels * weights_broadcastable, dim=1) # (batch_size, seq_len, seq_len)

        Q = self.WQ(x)
        V = self.WV(x)
        K = self.WK(x)
        # NOTE(review): the graph above is built from every position in the
        # sequence before attn_mask is applied, so under a causal mask the bias
        # at allowed positions can still depend on later tokens -- confirm.
        score = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) + final_combined_kernel
        if attn_mask is not None:
            score = score.masked_fill(attn_mask == 0, -torch.inf)
        attention_weights = torch.softmax(score, dim=-1)
        attention_output = torch.matmul(attention_weights, V)

        return attention_output, attention_weights

In [None]:
class MultiHeadAttentionL(nn.Module):
    """Multi-head attention built from AttentionHeadL (diffusion-kernel) heads."""

    def __init__(self, hidden_size: int, num_heads: int, number_diffusion_kernels: int):
        """
        Args:
            hidden_size: Model dimension; must be divisible by ``num_heads``
            num_heads: Number of attention heads
            number_diffusion_kernels: Kernels per head, forwarded to AttentionHeadL
        """
        super().__init__()
        assert hidden_size % num_heads == 0, f"The hidden size {hidden_size} is not divisible by the number of heads {num_heads}."
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.number_diffusion_kernels = number_diffusion_kernels

        head_modules = []
        for _ in range(num_heads):
            head_modules.append(AttentionHeadL(hidden_size, self.head_dim, number_diffusion_kernels))
        self.heads = nn.ModuleList(head_modules)
        # Output projection over the concatenated heads.
        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run all heads on the same input and project their concatenation.

        Args:
            hidden_states: Input tensor (batch_size, seq_len, hidden_size)
            attention_mask: Attention mask (batch_size, seq_len, seq_len)

        Returns:
            attention_output: (batch_size, seq_len, hidden_size)
            attention_weights: (batch_size, num_heads, seq_len, seq_len)
        """
        results = [head(hidden_states, attention_mask) for head in self.heads]
        outputs_per_head = [result[0] for result in results]
        weights_per_head = [result[1] for result in results]

        # Concatenate along features so head h fills [h * head_dim, (h + 1) * head_dim).
        merged = torch.cat(outputs_per_head, dim=-1)
        attention_weights = torch.stack(weights_per_head, dim=1)
        return self.linear(merged), attention_weights

In [None]:
class TransformerBlockL(nn.Module):
    """Pre-norm transformer block that uses the diffusion-kernel attention layer."""

    def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int, number_diffusion_kernels):
        """
        Args:
            hidden_size: Model dimension
            num_heads: Number of attention heads
            intermediate_size: FFN hidden dimension
            number_diffusion_kernels: Kernels per head, forwarded to MultiHeadAttentionL
        """
        super().__init__()
        self.rms_att = RMSNorm(hidden_size=hidden_size)
        self.rms_ffn = RMSNorm(hidden_size=hidden_size)
        self.mha = MultiHeadAttentionL(
            hidden_size=hidden_size,
            num_heads=num_heads,
            number_diffusion_kernels=number_diffusion_kernels,
        )
        self.ffn = FeedForward(hidden_size=hidden_size, intermediate_size=intermediate_size)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Apply attention and feed-forward sub-layers, each wrapped in a residual.

        Args:
            hidden_states: Input tensor (batch_size, seq_len, hidden_size)
            attention_mask: Mask passed through to the attention layer

        Returns:
            Output tensor (batch_size, seq_len, hidden_size)
        """
        normed_input = self.rms_att(hidden_states)
        # Index 0 is the attention output; the weights at index 1 are unused here.
        attention_out = self.mha(normed_input, attention_mask)[0]
        residual = attention_out + hidden_states

        feed_forward_out = self.ffn(self.rms_ffn(residual))
        return residual + feed_forward_out

In [None]:
def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Return a causal attention mask with a leading broadcast dimension.

    Args:
        seq_len: Sequence length
        device: Device to create the mask on

    Returns:
        Mask of shape (1, seq_len, seq_len) holding 1 on and below the
        diagonal and 0 above it
    """
    all_ones = torch.ones(seq_len, seq_len, device=device)
    lower_triangle = torch.tril(all_ones)
    return lower_triangle[None, :, :]

In [None]:
class TransformerModelL(nn.Module):
    """Transformer backbone whose blocks use diffusion-kernel attention (TransformerBlockL)."""

    def __init__(self, config: TransformerConfig):
        """
        Args:
            config: Architecture hyperparameters, including ``number_diffusion_kernels``
        """
        super().__init__()
        self.config = config
        # Records whether CUDA is available; the module is not moved to it here.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        width = self.config.hidden_size
        self.embeddings = nn.Embedding(num_embeddings=self.config.vocab_size, embedding_dim=width)
        self.pos_embeddings = nn.Embedding(num_embeddings=self.config.max_position_embeddings, embedding_dim=width)
        self.transformer = nn.ModuleList(
            [
                TransformerBlockL(
                    hidden_size=width,
                    num_heads=self.config.num_attention_heads,
                    intermediate_size=self.config.intermediate_size,
                    number_diffusion_kernels=self.config.number_diffusion_kernels,
                )
                for _ in range(self.config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(hidden_size=width)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Embed the tokens and pass them through the block stack.

        Args:
            input_ids: Token IDs (batch_size, seq_len); seq_len must not exceed
                ``config.max_position_embeddings``
            attention_mask: Optional mask (batch_size, seq_len, seq_len); a causal
                mask is created when it is omitted and ``config.use_causal_mask`` is set

        Returns:
            Final hidden states (batch_size, seq_len, hidden_size)
        """
        num_positions = input_ids.shape[1]
        position_ids = torch.arange(num_positions, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(input_ids.size(0), -1)

        states = self.embeddings(input_ids) + self.pos_embeddings(position_ids)

        need_causal = attention_mask is None and self.config.use_causal_mask
        if need_causal:
            attention_mask = create_causal_mask(num_positions, states.device)

        for block in self.transformer:
            states = block(states, attention_mask=attention_mask)
        return self.norm(states)

In [None]:
class CausalLanguageModelL(nn.Module):
    """Causal language model on top of the diffusion-kernel backbone (TransformerModelL)."""

    def __init__(self, config: TransformerConfig):
        """
        Args:
            config: Architecture hyperparameters shared with the backbone
        """
        super().__init__()
        self.config = config
        self.transformer = TransformerModelL(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Default CrossEntropyLoss: mean reduction, targets equal to -100 are ignored.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass for language model

        Args:
            input_ids: Token IDs (batch_size, seq_len)
            labels: Target labels for loss computation (batch_size, seq_len).
                They are shifted internally, so callers pass labels aligned
                with ``input_ids``.

        Returns:
            If labels provided: (loss, logits)
            Else: logits only, shape (batch_size, seq_len, vocab_size)
        """
        hidden_states = self.transformer(input_ids)
        logits = self.lm_head(hidden_states)
        if labels is None:
            return logits

        # Next-token objective: the logit at position t is scored against the
        # label at position t + 1, so drop the last logit and the first label.
        logits_flat = logits[:, :-1, :].flatten(0, 1)
        labels_flat = labels[:, 1:].flatten(0, 1)
        return self.criterion(logits_flat, labels_flat), logits

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100, temperature: float = 1.0) -> torch.Tensor:
        """
        Generate text using the language model

        Args:
            input_ids: Starting token IDs (batch_size, seq_len)
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature; 0 selects the most likely token
                (greedy decoding) instead of dividing by zero

        Returns:
            Generated token IDs (batch_size, seq_len + max_new_tokens)
        """
        max_context = self.config.max_position_embeddings
        for _ in range(max_new_tokens):
            # The position-embedding table only covers max_context positions,
            # so condition on the most recent window once the sequence outgrows it.
            context = input_ids[:, -max_context:]
            next_logits = self.forward(context)[:, -1, :]
            if temperature == 0:
                next_token = next_logits.argmax(dim=-1, keepdim=True)
            else:
                probs = nn.functional.softmax(next_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, 1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids


## PH Model


In [None]:
class PHDataset:
    """Iterable that tokenizes text records on the fly for causal LM training.

    Each record of ``dataset`` must expose a ``'text'`` entry. Every record is
    tokenized to exactly ``max_length`` tokens (truncated or padded) and
    yielded as a dict with ``'input_ids'`` and ``'labels'`` tensors.
    """

    def __init__(self, dataset, tokenizer, max_length=512, max_samples=None):
        """
        Args:
            dataset: Iterable of records with a ``'text'`` field
            tokenizer: Callable accepting ``(text, truncation=, padding=,
                max_length=, return_tensors=)`` and returning a mapping with
                ``'input_ids'`` (and optionally ``'attention_mask'``) tensors
                that carry a leading batch dimension of 1
            max_length: Fixed sequence length after truncation/padding
            max_samples: Stop after this many samples; None means no limit
        """
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_samples = max_samples

    def __iter__(self):
        """Yield ``{'input_ids', 'labels'}`` dicts of 1-D tensors, one per record."""
        # A non-positive limit means "no samples"; without this guard one
        # sample would be produced before the limit is checked.
        if self.max_samples is not None and self.max_samples <= 0:
            return

        for count, each in enumerate(self.dataset, start=1):
            # Tokenize on the fly
            encoded = self.tokenizer(
                each['text'],
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors='pt'  # Return tensors for direct use
            )

            input_ids = encoded['input_ids'].squeeze(0)  # Remove batch dimension

            # Labels equal the inputs for causal language modeling, except on
            # padding: those positions are set to -100, the default
            # ignore_index of nn.CrossEntropyLoss, so padding is not trained on.
            labels = input_ids.clone()
            if 'attention_mask' in encoded:
                padding_positions = encoded['attention_mask'].squeeze(0) == 0
                labels[padding_positions] = -100

            yield {
                'input_ids': input_ids,
                'labels': labels
            }

            if self.max_samples is not None and count >= self.max_samples:
                break

In [None]:
class PHRMSNorm(nn.Module):
    """RMS normalization layer used by the PH model."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """
        Args:
            hidden_size: Length of the feature dimension being normalised
            eps: Small constant added to the mean square for numerical stability
        """
        super().__init__()
        # Per-feature scale, initialised to ones.
        self.parameter = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Scale every feature vector to unit RMS, then apply the learned scale.

        Args:
            hidden_states: Input tensor of shape (batch_size, seq_len, hidden_size)

        Returns:
            Normalized tensor of shape (batch_size, seq_len, hidden_size)
        """
        mean_of_squares = torch.mean(hidden_states.square(), dim=-1, keepdim=True)
        rms = torch.sqrt(mean_of_squares + self.eps)
        return (hidden_states / rms) * self.parameter

In [None]:
class PHFeedForward(nn.Module):
    """Position-wise feed-forward network of the PH model (Linear -> GELU -> Linear)."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        """
        Args:
            hidden_size: Model dimension (input and output width)
            intermediate_size: Hidden dimension of the FFN
        """
        super().__init__()

        self.activation = nn.GELU()
        self.linear1 = nn.Linear(hidden_size, intermediate_size)
        self.linear2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the two projections with a GELU in between; the shape is preserved."""
        return self.linear2(self.activation(self.linear1(x)))

In [None]:
class PHAttentionHead(nn.Module):
    """Attention head with an optional persistent-homology (H0) score bias.

    With ``semantic=False`` this is a plain scaled dot-product head. With
    ``semantic=True`` every token additionally gets a feature vector made of
    the H0 lifetimes (MST edge weights) of its k-nearest-neighbour
    neighbourhood, and a similarity between those features is added to the
    attention logits with weight ``alpha``.
    """

    def __init__(self, hidden_size: int,
                 head_dim: int,
                 repn_dim: int = 64,  # Dimension of representation vectors
                 k: int = 32,         # K-nearest neighbors (GLOBAL over batch)
                 alpha: float = PH_ALPHA,  # Bias importance
                 sim_mode: str = "dot",  # "dot" or "l2"
                 semantic: bool = True    # False = normal head
                 ):
        """
        Args:
            hidden_size: Input dimension
            head_dim: Dimension of the query/key/value vectors
            repn_dim: Dimension of the vectors used for the kNN search
            k: Neighbourhood size; also the length of the lifetime feature vector
            alpha: Weight of the PH bias added to the attention logits
            sim_mode: "dot" for a dot-product bias; any other value uses the
                negative squared L2 distance
            semantic: When False, no PH parameters are created and no bias is added
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.semantic = semantic

        self.WQ = nn.Linear(hidden_size, head_dim, bias=False)
        self.WK = nn.Linear(hidden_size, head_dim, bias=False)
        self.WV = nn.Linear(hidden_size, head_dim, bias=False)

        if semantic:
            self.WR = nn.Linear(hidden_size, repn_dim, bias=False)
            self.k = k
            self.r = min(self.k, 32)  # small projection rank
            self.U_raw = nn.Parameter(torch.randn(self.k, self.r) / math.sqrt(self.k))
            self.alpha = alpha
            self.sim_mode = sim_mode


    def _mst_h0_lifetimes(self, dmats: torch.Tensor) -> torch.Tensor:
        """
        Batched Prim's algorithm to get H0 VR lifetimes (merge heights) from
        distance matrices. Exact for H0. Vectorized over batch of neighborhoods.

        dmats: (N, m, m) symmetric with 0 diag
        returns: (N, m-1) lifetimes (MST edge weights), in the order Prim's
            algorithm adds the edges; an empty (N, 0) tensor when m < 2
        """
        N, m, _ = dmats.shape
        device = dmats.device

        # A single-point neighbourhood has no MST edges; torch.stack would
        # fail on the empty list below.
        if m < 2:
            return dmats.new_zeros(N, 0)

        visited = torch.zeros(N, m, dtype=torch.bool, device=device)
        visited[:, 0] = True

        # Distance from every node to the current tree; visited nodes are set
        # to +inf so they are never selected again.
        d_to_tree = dmats[:, 0, :].clone()  # (N, m)
        d_to_tree[visited] = float('inf')

        lifetimes = []
        arangeN = torch.arange(N, device=device)

        for _ in range(m - 1):
            best_val, best_idx = d_to_tree.min(dim=1)   # (N,), (N,)
            lifetimes.append(best_val)
            visited[arangeN, best_idx] = True
            new_row = dmats[arangeN, best_idx, :]       # (N, m)
            d_to_tree = torch.minimum(d_to_tree, new_row)
            d_to_tree[visited] = float('inf')

        return torch.stack(lifetimes, dim=1)            # (N, m-1)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
               ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute attention, adding the PH bias when the head is semantic.

        Args:
            x: (B, L, hidden_size)
            attn_mask: (B, L, L); entries equal to 0 are masked out

        Returns:
            attention_output: (B, L, head_dim)
            attention_weights: (B, L, L)
        """
        B, L, _ = x.shape

        # standard attention
        Q = self.WQ(x)  # (B, L, head_dim)
        K = self.WK(x)  # (B, L, head_dim)
        V = self.WV(x)  # (B, L, head_dim)
        score = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, L, L)

        if self.semantic:
            # PH features (H0 via MST) computed over the WHOLE BATCH
            # NOTE(review): neighbours are searched across every token of every
            # sequence in the batch, so a token's features can depend on other
            # samples and on later positions -- confirm this is intended for a
            # causal model.
            repn = self.WR(x)                                 # (B, L, repn_dim)
            N = B * L
            repn_flat = repn.reshape(N, -1).contiguous()      # (N, repn_dim)

            # NOTE(review): repn is only used inside this no_grad block, so WR
            # receives no gradient from this head -- confirm.
            with torch.no_grad():
                # GLOBAL kNN across all tokens in the batch
                D_all = torch.cdist(repn_flat, repn_flat, p=2)             # (N, N)
                # At most N - 1 other tokens exist. With a single token (N == 1)
                # this is 0, so topk below asks only for the token itself
                # instead of requesting more entries than D_all has.
                k_eff = min(self.k, N - 1)
                knn_idx = D_all.topk(k_eff + 1, largest=False).indices     # (N, k+1), includes self

                # neighborhoods and their pairwise distances
                m = knn_idx.size(1)                                         # m = k_eff + 1
                row = knn_idx.unsqueeze(-1).expand(-1, m, m)               # (N, m, m)
                col = knn_idx.unsqueeze(-2).expand(-1, m, m)               # (N, m, m)
                dmats = D_all[row, col].contiguous()

            # H0 lifetimes via batched MST
            lifetimes = self._mst_h0_lifetimes(dmats)  # (N, m-1) = (N, k_eff)

            # pad to fixed length k and normalize
            k_eff_actual = m - 1
            if k_eff_actual < self.k:
                pad = torch.zeros(N, self.k - k_eff_actual,
                                   device=x.device, dtype=lifetimes.dtype)
                lifetimes = torch.cat([lifetimes, pad], dim=1)  # (N, k)

            Phi = torch.nn.functional.normalize(lifetimes, p=2, dim=-1).view(B, L, self.k)  # (B, L, k)

            # Centre the features within each sequence, project to rank r, re-normalise.
            Phi_c = Phi - Phi.mean(dim=1, keepdim=True)                     # (B, L, k)
            U = self.U_raw / (self.U_raw.norm(p='fro') + 1e-8)              # (k, r), scale-normalized
            Z = torch.matmul(Phi_c, U)                                       # (B, L, r)
            Z = torch.nn.functional.normalize(Z, p=2, dim=-1)               # (B, L, r)

            # bias from Z
            if self.sim_mode == "dot":
                ph_bias = torch.bmm(Z, Z.transpose(1, 2))                   # (B, L, L)
            else:  # "l2"
                d = torch.cdist(Z, Z, p=2)
                ph_bias = -(d ** 2)

            score = score + self.alpha * ph_bias

        if attn_mask is not None:
            score = score.masked_fill(attn_mask == 0, float('-inf'))
        attention_weights = torch.softmax(score, dim=-1)
        attention_output = torch.matmul(attention_weights, V)
        return attention_output, attention_weights

In [None]:
class PHMultiHeadAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int):
        """
        Multi-head attention built from independent PHAttentionHead modules.

        The first ``min(PH_SEMANTIC_HEADS, num_heads)`` heads are created with
        ``semantic=True`` and the remaining ones with ``semantic=False``, so
        exactly ``num_heads`` heads are always instantiated and their
        concatenated width equals ``hidden_size``.

        Args:
            hidden_size: Model dimension
            num_heads: Number of attention heads
        """
        super().__init__()
        assert hidden_size % num_heads == 0, f"The hidden size {hidden_size} is not divisible by the number of heads {num_heads}."
        head_dim = hidden_size // num_heads
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim

        # Clamp the semantic-head count: without this, num_heads < PH_SEMANTIC_HEADS
        # would build more heads than requested and the concatenated output would
        # no longer match the input width of self.linear.
        num_semantic = max(0, min(PH_SEMANTIC_HEADS, num_heads))
        # Semantic heads come first, then the plain ones (same ordering as before,
        # so parameter names in existing checkpoints are unaffected).
        heads = [PHAttentionHead(hidden_size, head_dim, semantic=True) for _ in range(num_semantic)]
        heads += [PHAttentionHead(hidden_size, head_dim, semantic=False) for _ in range(num_heads - num_semantic)]
        self.heads = nn.ModuleList(heads)
        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for multi-head attention

        Every head receives the same ``hidden_states`` and ``attention_mask``;
        the per-head outputs are concatenated along the feature axis (head 0
        first) and passed through the output projection.

        Args:
            hidden_states: Input tensor (batch_size, seq_len, hidden_size)
            attention_mask: Attention mask (batch_size, seq_len, seq_len)

        Returns:
            attention_output: (batch_size, seq_len, hidden_size)
            attention_weights: (batch_size, num_heads, seq_len, seq_len)
        """
        # Each head returns an (output, weights) pair.
        per_head = [head(hidden_states, attention_mask) for head in self.heads]
        head_outputs = [pair[0] for pair in per_head]
        head_weights = [pair[1] for pair in per_head]

        # Concatenating on the last axis yields (batch, seq_len, num_heads * head_dim),
        # the same layout as stacking on a new head axis and flattening it into features.
        attention_outputs = torch.cat(head_outputs, dim=-1)
        # New head axis at position 1 -> (batch, num_heads, seq_len, seq_len).
        attention_weights = torch.stack(head_weights, dim=1)

        attention_outputs = self.linear(attention_outputs)
        return attention_outputs, attention_weights

In [None]:
class PHTransformerBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int):
        """
        Pre-norm transformer block: normalized attention and normalized
        feed-forward sub-layers, each wrapped in a residual connection.

        Args:
            hidden_size: Model dimension
            num_heads: Number of attention heads
            intermediate_size: FFN hidden dimension
        """
        super().__init__()
        # One normalization layer in front of each sub-layer.
        self.rms_att = PHRMSNorm(hidden_size=hidden_size)
        self.rms_ffn = PHRMSNorm(hidden_size=hidden_size)
        self.mha = PHMultiHeadAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
        )
        self.ffn = PHFeedForward(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
        )

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Apply the attention sub-layer followed by the feed-forward sub-layer.

        Args:
            hidden_states: Input tensor (batch_size, seq_len, hidden_size)
            attention_mask: Attention mask, forwarded unchanged to the attention module

        Returns:
            hidden_states: Output tensor (batch_size, seq_len, hidden_size)
        """
        # Attention sub-layer; the attention weights returned alongside are not needed here.
        attn_out, _ = self.mha(self.rms_att(hidden_states), attention_mask)
        after_attention = attn_out + hidden_states

        # Feed-forward sub-layer applied to the normalized residual stream.
        ffn_out = self.ffn(self.rms_ffn(after_attention))
        return after_attention + ffn_out

In [None]:
class TransformerModelPH(nn.Module):
    def __init__(self, config: TransformerConfig):
        """
        Complete transformer model for causal language modeling

        Builds token and learned position embeddings, a stack of
        ``config.num_hidden_layers`` PHTransformerBlock layers, and a final
        normalization layer.

        Args:
            config: Model configuration (vocabulary size, widths, depth, ...)
        """
        super().__init__()
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.embeddings = nn.Embedding(num_embeddings=self.config.vocab_size, embedding_dim=self.config.hidden_size)
        self.pos_embeddings = nn.Embedding(num_embeddings=self.config.max_position_embeddings, embedding_dim=self.config.hidden_size)
        self.transformer = nn.ModuleList([
            PHTransformerBlock(
                hidden_size=self.config.hidden_size,
                num_heads=self.config.num_attention_heads,
                intermediate_size=self.config.intermediate_size,
            )
            for _ in range(self.config.num_hidden_layers)
        ])
        self.norm = PHRMSNorm(hidden_size=self.config.hidden_size)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass for transformer model

        Args:
            input_ids: Token IDs (batch_size, seq_len)
            attention_mask: Attention mask (batch_size, seq_len, seq_len). When
                omitted and ``config.use_causal_mask`` is set, a mask is built
                with ``create_causal_mask``.

        Returns:
            hidden_states: Final hidden states (batch_size, seq_len, hidden_size)

        Raises:
            IndexError: If seq_len exceeds ``config.max_position_embeddings``.
        """
        seq_len = input_ids.shape[1]
        max_len = self.config.max_position_embeddings
        if seq_len > max_len:
            # Fail early with a readable message: otherwise the position lookup
            # below fails with an opaque out-of-range error (a device-side assert
            # on CUDA). IndexError matches what the lookup itself raises on CPU.
            raise IndexError(
                f"Sequence length {seq_len} exceeds max_position_embeddings {max_len}."
            )

        # Position ids 0..seq_len-1, shared by every sequence in the batch.
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(input_ids.size(0), -1)
        pos_embeds = self.pos_embeddings(positions)
        token_embeddings = self.embeddings(input_ids) + pos_embeds

        if attention_mask is None and self.config.use_causal_mask:
            attention_mask = create_causal_mask(seq_len, token_embeddings.device)

        transf = token_embeddings
        for layer in self.transformer:
            transf = layer(transf, attention_mask=attention_mask)
        output = self.norm(transf)
        return output

In [None]:
class CausalLanguageModelPH(nn.Module):
    def __init__(self, config: TransformerConfig):
        """
        Causal language model with transformer backbone

        Wraps a TransformerModelPH with a bias-free linear head projecting
        hidden states to vocabulary logits, plus a cross-entropy criterion.

        Args:
            config: Model configuration
        """
        super().__init__()
        self.config = config
        self.transformer = TransformerModelPH(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass for language model

        Args:
            input_ids: Token IDs (batch_size, seq_len)
            labels: Target labels for loss computation (batch_size, seq_len).
                They are shifted internally, so callers pass labels aligned
                with ``input_ids`` (not pre-shifted).

        Returns:
            If labels provided: (loss, logits)
            Else: logits only
        """
        hidden_states = self.transformer(input_ids)
        logits = self.lm_head(hidden_states)
        if labels is None:
            return logits

        # Next-token objective: the logits at position t are scored against the
        # label at position t + 1, so drop the last logit and the first label.
        logits_flat = logits[:, :-1, :].flatten(0, 1)
        labels_flat = labels[:, 1:].flatten(0, 1)
        return self.criterion(logits_flat, labels_flat), logits

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100, temperature: float = 1.0) -> torch.Tensor:
        """
        Generate text using the language model

        Tokens are sampled one at a time from the temperature-scaled softmax
        of the last position's logits. The model is conditioned on at most the
        last ``config.max_position_embeddings`` tokens, so generation can run
        past the positional-embedding limit; the returned tensor still holds
        the full prompt plus every generated token.

        Args:
            input_ids: Starting token IDs (batch_size, seq_len)
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated token IDs (batch_size, seq_len + max_new_tokens)
        """
        max_context = self.config.max_position_embeddings
        for _ in range(max_new_tokens):
            # Crop to the most recent window: the position table only has
            # max_context entries, so feeding a longer sequence would fail.
            context = input_ids[:, -max_context:]
            logits = self.forward(context)[:, -1, :] / temperature
            probs = nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids
