In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import math

In [2]:
class Embedding(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size,config.d_model)
        
    def embed(self,x,position_ids=None,token_type_ids=None):
        x = self.embedding(x)
        return x
    
    def unembed(self,x):
        return torch.matmul(x,self.embedding.weight.t())

In [3]:
class RMSNorm(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.eps = config.eps 
        self.d_model = config.d_model
        self.weight = nn.Parameter(torch.ones(self.d_model))
        
    def _norm(self, x):
        return x / (torch.sqrt((x ** 2).mean(dim=-1, keepdim=True) + self.eps))
    
    def forward(self, x):
        return self.weight * self._norm(x)

In [4]:
class GatedMLP(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, config.d_ff,bias=False)
        self.fc2 = nn.Linear(config.d_ff, config.d_model,bias=False)
        self.fc3 = nn.Linear(config.d_model, config.d_ff,bias=False)
        self.act = F.silu
        
    def forward(self,x):
        x1 = self.fc1(x)
        x2 = self.fc3(x)
        return self.fc2(self.act(x1) * self.act(x2))

In [5]:
def repeat_kv(x,n_rep):
    batch_size,seq_len,n_heads,head_dim = x.size()
    if n_rep == 1:
        return x
    return x.unsqueeze(-2).expand(batch_size,seq_len,n_rep,n_heads,head_dim).reshape(batch_size,seq_len,n_heads*n_rep,head_dim)

In [6]:
class LinearlyScaledRotaryEmbedding(nn.Module):
    def __init__(self, d_model, scaling_factor=1.0, base=10000.0, device=None):
        super().__init__()
        self.d_model = d_model
        self.device = device
        self._linear_scaling_factor = scaling_factor
        self.inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2, device=device).float() / d_model))
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self.scale = None
    
    def update_cache(self, seq_len, device=None, dtype=None):
        if (seq_len > self._seq_len_cached or 
            self._cos_cached is None or 
            self._cos_cached.device != device or 
            self._cos_cached.dtype != dtype):
            self._seq_len_cached = seq_len
            t = torch.arange(seq_len, device=device, dtype=dtype) / self._linear_scaling_factor
            freqs = torch.outer(t, self.inv_freq.to(device=device, dtype=dtype))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)
            
    def forward(self, x):
        # x is assumed to be of shape [B, L, n_heads, head_dim] with head_dim == self.d_model
        seq_len = x.shape[1]
        device = x.device
        dtype = x.dtype
        self.update_cache(seq_len, device=device, dtype=dtype)
        cos, sin = self._cos_cached, self._sin_cached
        
        head_dim = x.size(-1)
        split_dim = head_dim // 2
        
        # Split the last dimension into two halves.
        x1 = x[..., :split_dim]  # shape: [B, L, n_heads, split_dim]
        x2 = x[..., split_dim:]  # shape: [B, L, n_heads, split_dim]
        
        # Cos and sin have shape [L, d_model/2]; slice to [L, split_dim] and add extra dimensions:
        cos = cos[:seq_len, :split_dim].unsqueeze(0).unsqueeze(2)  # [1, L, 1, split_dim]
        sin = sin[:seq_len, :split_dim].unsqueeze(0).unsqueeze(2)  # [1, L, 1, split_dim]
        
        x1_rot = x1 * cos - x2 * sin
        x2_rot = x1 * sin + x2 * cos
        
        return torch.cat([x1_rot, x2_rot], dim=-1)


In [7]:
class GroupedMHA(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.n_heads_q = config.n_heads_q
        self.head_dim = self.d_model // self.n_heads_q
        self.n_rep = self.n_heads_q // config.n_heads_kv
        
        self.wq = nn.Linear(config.d_model,config.d_model,bias=False)
        self.wk = nn.Linear(config.d_model,config.d_model,bias=False)
        self.wv = nn.Linear(config.d_model,config.d_model,bias=False)
        self.wo = nn.Linear(config.d_model,config.d_model,bias=False)
        
        self.rotary_embed = LinearlyScaledRotaryEmbedding(self.head_dim,base=10000.0,device=None)
        
    def forward(self,x,cache=None):
        batch_size,seq_len,d_model = x.size()
        q = self.wq(x).view(batch_size,seq_len,self.n_heads_q,self.head_dim)
        k = self.wk(x).view(batch_size,seq_len,self.n_heads_q,self.head_dim)
        v = self.wv(x).view(batch_size,seq_len,self.n_heads_q,self.head_dim)
        
        xq = self.rotary_embed(q)
        xk = self.rotary_embed(k)
        
        if cache is not None:
            if "k" in cache and cache["k"] is not None:
                k = torch.cat([cache["k"], k], dim=1)  
                v = torch.cat([cache["v"], v], dim=1)
            cache["k"] = k
            cache["v"] = v
        
        xq = repeat_kv(xq,self.n_rep)
        xk = repeat_kv(xk,self.n_rep)
        
        xq = xq.transpose(1,2)
        xk = xk.transpose(1,2)
        v = v.transpose(1,2)
        
        attn = (xq @ xk.transpose(-2,-1)) / math.sqrt(self.head_dim)
        attn = F.softmax(attn,dim=-1)
        x = attn @ v
        x = x.transpose(1,2).reshape(batch_size,seq_len,self.d_model)
        return self.wo(x)

In [8]:
class AttentionBlock(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.mha = GroupedMHA(config)
        self.pre_norm = RMSNorm(config)
        self.post_norm = RMSNorm(config)
        self.mlp = GatedMLP(config)
        
    def forward(self, x, padding_mask=None):
        if isinstance(padding_mask, torch.Tensor):
            x = x * padding_mask[..., None]
        attn_out = self.mha(self.pre_norm(x))
        x = x + attn_out
        if isinstance(padding_mask, torch.Tensor):
            x = x * padding_mask[..., None]
        mlp_out = self.mlp(self.post_norm(x))
        return x + mlp_out, None

In [9]:
class HyenaCascade(nn.Module):
    def __init__(self, config, hyena_filter_group=None, fir_inner_filter_length=None):
        super().__init__()
        self.d_model = config.d_model
        self.short_filter_length = config.short_filter_length
        # Use the projected channel dimension (which is 3*d_model) if not explicitly provided.
        self.hyena_filter_group = hyena_filter_group or config.d_model  
        self.state_size = config.state_size
        
        self.short_filter_weight = nn.Parameter(torch.randn(3 * config.d_model, 1, config.short_filter_length))
        self.short_filter_bias = (nn.Parameter(torch.randn(3 * config.d_model))
                                  if config.short_filter_bias else None)
        
        self.log_poles = nn.Parameter(torch.randn(self.hyena_filter_group, config.state_size)*0.01)
        self.residues = nn.Parameter(torch.randn(self.hyena_filter_group, config.state_size))
        self.D = nn.Parameter(torch.zeros(config.d_model))
        self.h = None
        self.t = None
        
    def update_time(self, seq_len, device):
        if self.t is None or self.t.shape[-1] != seq_len:
            self.t = torch.arange(seq_len, device=device).unsqueeze(0).unsqueeze(0)
        else:
            self.t = self.t[..., :seq_len]
            
    def compute_filter(self, seq_len, device):
        self.update_time(seq_len, device)
        h = (self.residues.unsqueeze(-1) * (self.log_poles.unsqueeze(-1) * self.t).exp()).sum(dim=1)
        h = h.mean(dim=0, keepdim=True)
        return h  
    
    def forward(self, x, inference_params=None, padding_mask=None):
        batch_size, seq_len, in_channels = x.size()
        x_t = x.transpose(1, 2)  
        pad = (self.short_filter_length - 1) // 2
        conv_out = F.conv1d(x_t, self.short_filter_weight, bias=self.short_filter_bias,
                            padding=pad, groups=in_channels)
        conv_out = conv_out.transpose(1, 2)  
        conv_out = conv_out[..., :self.d_model]
        
        if self.h is None or self.h.shape[1] < seq_len:
            self.h = self.compute_filter(seq_len, device=x.device).detach()  
        h_exp = self.h.unsqueeze(-1)
        x_val = x[..., :self.d_model]
        y = conv_out * h_exp + x_val * self.D
        if padding_mask is not None:
            y = y * padding_mask.unsqueeze(-1)
        return y, inference_params

In [10]:
class GatedConvBlock(nn.Module):
    def __init__(self,config,hyena_filter_group=None,fir_inner_filter_length=None):
        super().__init__()
        self.pre_norm = RMSNorm(config)
        self.post_norm = RMSNorm(config)
        self.hyena = HyenaCascade(config,hyena_filter_group=hyena_filter_group,fir_inner_filter_length=fir_inner_filter_length)
        self.mlp = GatedMLP(config)
        self.proj = nn.Linear(config.d_model,3 * config.d_model,bias=True)
        self.out_filter_dense = nn.Linear(config.d_model,config.d_model,bias=True)
        
    def proj_norm(self,x):
        return self.proj(self.pre_norm(x))
    
    def res_mlp_norm(self,x):
        return self.mlp(self.post_norm(x)) + x
    
    def forward(self,x,inference_params=None,padding_mask=None):
        x = self.proj_norm(x)
        if isinstance(padding_mask,torch.Tensor):
            x = x * padding_mask[...,None]
        x,inference_params = self.hyena(x,inference_params,padding_mask)
        if isinstance(padding_mask,torch.Tensor):
            x = x * padding_mask[...,None]
        x = self.res_mlp_norm(x)
        return self.out_filter_dense(x),inference_params

In [11]:
class Block(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.mha = AttentionBlock(config)
        self.pre_norm = RMSNorm(config)
        self.post_norm = RMSNorm(config)
        self.gated_conv = GatedConvBlock(config)
        
    def forward(self,x,padding_mask=None):
        x = self.pre_norm(x)
        x,_ = self.mha(x,padding_mask=padding_mask)
        x = self.post_norm(x)
        x,_ = self.gated_conv(x,padding_mask=padding_mask)
        return x,None

In [12]:
class StripedHyena(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.embed = Embedding(config)
        self.norm = RMSNorm(config)
        self.unembed = self.embed if config.get("tie_embed") else Embedding(config)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layers)])
        
    def forward(self,x,padding_mask=None):
        x = self.embed.embed(x)
        if padding_mask is not None:
            x = x * padding_mask.unsqueeze(-1)
        for block in self.blocks:
            x,_ = block(x,padding_mask=padding_mask)
        x = self.norm(x)
        x = self.unembed.unembed(x)
        return x

In [13]:
'''if __name__ == "__main__":
    # Define a simple configuration class.
    class Config:
        d_model = 128
        n_heads_q = 8
        n_heads_kv = 4
        vocab_size = 10000
        eps = 1e-5
        short_filter_length = 21
        state_size = 16
        n_layers = 4
        use_gated_conv = True
        tie_embed = False
        n_heads_kv = 8
        use_flashfft = False
        short_filter_bias = True
        device = "cuda"  # Change to "cuda" if GPU is available.
        inner_size_multiple_of = 64
        d_ff = 512  # Feed-forward dimension for GatedMLP
        
        def get(self, key, default=None):
            return getattr(self, key, default)
    
    config = Config()
    
    # Create dummy input: batch size 2, sequence length 50.
    dummy_input = torch.randint(0, config.vocab_size, (32, 128))
    # Create a padding mask (all ones means no padding).
    padding_mask = torch.ones(32, 128, dtype=torch.bool)
    
    # Instantiate the model.
    model = StripedHyena(config)
    
    # Forward pass.
    output = model(dummy_input, padding_mask=padding_mask)
    print("Output shape:", output.shape)'''

'if __name__ == "__main__":\n    # Define a simple configuration class.\n    class Config:\n        d_model = 128\n        n_heads_q = 8\n        n_heads_kv = 4\n        vocab_size = 10000\n        eps = 1e-5\n        short_filter_length = 21\n        state_size = 16\n        n_layers = 4\n        use_gated_conv = True\n        tie_embed = False\n        n_heads_kv = 8\n        use_flashfft = False\n        short_filter_bias = True\n        device = "cuda"  # Change to "cuda" if GPU is available.\n        inner_size_multiple_of = 64\n        d_ff = 512  # Feed-forward dimension for GatedMLP\n        \n        def get(self, key, default=None):\n            return getattr(self, key, default)\n    \n    config = Config()\n    \n    # Create dummy input: batch size 2, sequence length 50.\n    dummy_input = torch.randint(0, config.vocab_size, (32, 128))\n    # Create a padding mask (all ones means no padding).\n    padding_mask = torch.ones(32, 128, dtype=torch.bool)\n    \n    # Instantiat

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import math
from torch.utils.data import DataLoader
from datasets import load_dataset
import matplotlib.pyplot as plt
from tqdm import tqdm

# -------------------------------
# Custom Tokenizer (Simple Example)
# -------------------------------
class CustomTokenizer:
    def __init__(self, texts, vocab_size=200000, pad_token="<pad>", unk_token="<unk>"):
        self.pad_token = pad_token
        self.unk_token = unk_token
        tokens = []
        for text in texts:
            tokens.extend(text.split())
        freq = {}
        for token in tokens:
            freq[token] = freq.get(token, 0) + 1
        sorted_tokens = sorted(freq.items(), key=lambda x: x[1], reverse=True)[:vocab_size-2]
        self.vocab = {pad_token: 0, unk_token: 1}
        idx = 2
        for token, _ in sorted_tokens:
            self.vocab[token] = idx
            idx += 1
        self.vocab_size = len(self.vocab)
    
    def encode(self, text, max_length=128, padding="max_length", truncation=True):
        tokens = text.split()
        token_ids = [self.vocab.get(token, self.vocab[self.unk_token]) for token in tokens]
        if truncation and len(token_ids) > max_length:
            token_ids = token_ids[:max_length]
        if padding == "max_length":
            pad_len = max_length - len(token_ids)
            token_ids = token_ids + [self.vocab[self.pad_token]] * pad_len
            attention_mask = [1]*min(len(tokens), max_length) + [0]*pad_len
        else:
            attention_mask = [1]*len(token_ids)
        return {"input_ids": token_ids, "attention_mask": attention_mask}
    
    def decode(self, token_ids):
        inv_vocab = {v: k for k, v in self.vocab.items()}
        # Stop at the pad token if found.
        tokens = []
        for t in token_ids:
            if t == self.vocab[self.pad_token]:
                break
            tokens.append(inv_vocab.get(t, self.unk_token))
        return " ".join(tokens)

class Config:
    d_model = 64              # Reduced hidden dimension
    n_heads_q = 4             # Reduced number of heads
    n_heads_kv = 4  
    vocab_size = 200000      # Will be updated after building vocabulary
    eps = 1e-5
    short_filter_length = 15  # Smaller filter length
    state_size = 8            # Smaller state size
    n_layers = 6              # Fewer layers
    use_gated_conv = True
    tie_embed = False
    use_flashfft = False
    short_filter_bias = True
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inner_size_multiple_of = 32
    d_ff = 256               # Reduced feed-forward dimension
    
    def get(self, key, default=None):
        return getattr(self, key, default)

config = Config()

dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train[:1%]")
print("Loaded dataset with", len(dataset), "samples")

all_texts = [sample["text"] for sample in dataset if sample["text"] is not None]
tokenizer = CustomTokenizer(all_texts, vocab_size=200000)  # Smaller vocab for demonstration
print("Custom vocabulary size:", tokenizer.vocab_size)
config.vocab_size = tokenizer.vocab_size

def tokenize_function(examples):
    outputs = {"input_ids": [], "attention_mask": []}
    for text in examples["text"]:
        encoded = tokenizer.encode(text, max_length=128, padding="max_length", truncation=True)
        outputs["input_ids"].append(encoded["input_ids"])
        outputs["attention_mask"].append(encoded["attention_mask"])
    return outputs

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

train_loader = DataLoader(tokenized_dataset, batch_size=32, shuffle=True)

# -------------------------------
# Sample Generation Function
# -------------------------------
def generate_text(model, tokenizer, prompt, config, max_length=128, temperature=1.0, top_k=4, top_p=0.0):
    model.eval()
    with torch.no_grad():
        encoded = tokenizer.encode(prompt, max_length=128, padding="max_length", truncation=True)
        input_ids = torch.tensor(encoded["input_ids"]).unsqueeze(0).to(config.device)
        for _ in range(max_length - input_ids.size(1)):
            logits = model(input_ids)[..., :]  # [1, L, vocab_size] if unembedding projects to vocab
            next_logits = logits[:, -1, :] / temperature  # Apply temperature scaling
            
            # Optionally, apply top-k or top-p filtering here.
            # For greedy sampling, simply choose the highest probability token:
            next_token = torch.argmax(next_logits, dim=-1).unsqueeze(0)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        model.train()
        return tokenizer.decode(input_ids.squeeze(0).tolist())


# -------------------------------
# Initialize Model and Optimizer
# -------------------------------
model = StripedHyena(config)
model.to(config.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1, verbose=True)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Total trainable parameters:", count_parameters(model))

# -------------------------------
# Training Loop (Language Modeling)
# -------------------------------
model.train()
num_epochs = 100
loss_history = []

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}")
    epoch_losses = []
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(config.device)
        attention_mask = batch["attention_mask"].to(config.device)
        
        optimizer.zero_grad()
        logits = model(input_ids, padding_mask=attention_mask)
        
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        
        loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        loss.backward()
        optimizer.step()
        
        epoch_losses.append(loss.item())
        torch.cuda.empty_cache()
    
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    loss_history.append(avg_loss)
    print(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")

    scheduler.step(avg_loss)
    # Sample generation after each epoch.
    prompt = "The meaning of life is"
    generated_text = generate_text(model, tokenizer, prompt, config, max_length=128)
    print("Sample generated text:", generated_text)
    
    # Save the model checkpoint.
    torch.save(model.state_dict(), f"striped_hyena_epoch_{epoch+1}.pt")

# Plot the loss history.
plt.figure(figsize=(8, 6))
plt.plot(range(1, num_epochs+1), loss_history, marker='o')
plt.xlabel("Epoch")
plt.ylabel("Average Loss")
plt.title("Training Loss per Epoch")
plt.grid(True)
plt.savefig("training_loss.png")
plt.show()


Loaded dataset with 18014 samples
Custom vocabulary size: 50972




Total trainable parameters: 7339712
Epoch 1


Epoch 1: 100%|████████████████████████████████| 563/563 [01:42<00:00,  5.48it/s]


Epoch 1 average loss: 29.4444
Sample generated text: The meaning of life is
Epoch 2


Epoch 2: 100%|████████████████████████████████| 563/563 [01:42<00:00,  5.49it/s]


Epoch 2 average loss: 24.9200
Sample generated text: The meaning of life is
Epoch 3


Epoch 3: 100%|████████████████████████████████| 563/563 [01:42<00:00,  5.47it/s]


Epoch 3 average loss: 21.2122
Sample generated text: The meaning of life is
Epoch 4


Epoch 4: 100%|████████████████████████████████| 563/563 [01:43<00:00,  5.46it/s]


Epoch 4 average loss: 17.7561
Sample generated text: The meaning of life is
Epoch 5


Epoch 5: 100%|████████████████████████████████| 563/563 [01:43<00:00,  5.43it/s]


Epoch 5 average loss: 14.5567
Sample generated text: The meaning of life is
Epoch 6


Epoch 6: 100%|████████████████████████████████| 563/563 [01:42<00:00,  5.47it/s]


Epoch 6 average loss: 11.4847
Sample generated text: The meaning of life is
Epoch 7


Epoch 7: 100%|████████████████████████████████| 563/563 [01:42<00:00,  5.47it/s]


Epoch 7 average loss: 8.5223
Sample generated text: The meaning of life is
Epoch 8


Epoch 8: 100%|████████████████████████████████| 563/563 [01:42<00:00,  5.49it/s]


Epoch 8 average loss: 6.4583
Sample generated text: The meaning of life is
Epoch 9


Epoch 9: 100%|████████████████████████████████| 563/563 [01:42<00:00,  5.49it/s]


Epoch 9 average loss: 5.8095
Sample generated text: The meaning of life is
Epoch 10


Epoch 10: 100%|███████████████████████████████| 563/563 [01:42<00:00,  5.49it/s]


Epoch 10 average loss: 5.5103
Sample generated text: The meaning of life is
Epoch 11


Epoch 11: 100%|███████████████████████████████| 563/563 [01:42<00:00,  5.49it/s]


Epoch 11 average loss: 5.2804
Sample generated text: The meaning of life is
Epoch 12


Epoch 12: 100%|███████████████████████████████| 563/563 [01:42<00:00,  5.49it/s]


Epoch 12 average loss: 5.0876
Sample generated text: The meaning of life is
Epoch 13


Epoch 13: 100%|███████████████████████████████| 563/563 [01:42<00:00,  5.49it/s]


Epoch 13 average loss: 4.9167
Sample generated text: The meaning of life is
Epoch 14


Epoch 14: 100%|███████████████████████████████| 563/563 [01:42<00:00,  5.49it/s]


Epoch 14 average loss: 4.7623
Sample generated text: The meaning of life is
Epoch 15


Epoch 15: 100%|███████████████████████████████| 563/563 [01:42<00:00,  5.49it/s]


Epoch 15 average loss: 4.6187
Sample generated text: The meaning of life is
Epoch 16


Epoch 16: 100%|███████████████████████████████| 563/563 [01:42<00:00,  5.49it/s]


Epoch 16 average loss: 4.4840
Sample generated text: The meaning of life is
Epoch 17


Epoch 17:  22%|██████▋                        | 122/563 [00:22<01:21,  5.40it/s]


KeyboardInterrupt: 