In [1]:
!pip install torch math

[31mERROR: Could not find a version that satisfies the requirement math (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for math[0m[31m
[0m

In [3]:
%%writefile moht_components.py
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

class MHAAttention(nn.Module):
    """Multi-Head Attention: num_q_heads = num_kv_heads"""
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.d_head = config.n_embd // config.n_head
        self.dropout = config.dropout
        self.block_size = config.block_size

        # Q, K, V projections for all heads
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Flash attention support
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()

        # Calculate Q, K, V for all heads
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, self.d_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, self.d_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, self.d_head).transpose(1, 2)  # (B, nh, T, hs)

        # Attention
        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True
            )
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.d_head))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y


class GQAAttention(nn.Module):
    """Grouped-Query Attention: num_q_heads > num_kv_heads"""
    def __init__(self, config, num_kv_heads=2):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        assert config.n_head % num_kv_heads == 0, "n_head must be divisible by num_kv_heads"

        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.num_kv_heads = num_kv_heads
        self.d_head = config.n_embd // config.n_head
        self.dropout = config.dropout
        self.block_size = config.block_size

        # Q projection for all heads
        self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # K, V projections for fewer heads
        self.k_proj = nn.Linear(config.n_embd, num_kv_heads * self.d_head, bias=config.bias)
        self.v_proj = nn.Linear(config.n_embd, num_kv_heads * self.d_head, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Flash attention support
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()

        # Q for all heads
        q = self.q_proj(x).view(B, T, self.n_head, self.d_head).transpose(1, 2)  # (B, nh, T, hs)

        # K, V for fewer heads
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.d_head).transpose(1, 2)  # (B, num_kv_heads, T, hs)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.d_head).transpose(1, 2)  # (B, num_kv_heads, T, hs)

        # Repeat K, V to match number of Q heads
        # Each KV head is shared across n_head // num_kv_heads query heads
        k = k.repeat_interleave(self.n_head // self.num_kv_heads, dim=1)  # (B, nh, T, hs)
        v = v.repeat_interleave(self.n_head // self.num_kv_heads, dim=1)  # (B, nh, T, hs)

        # Attention
        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True
            )
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.d_head))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y


class MQAAttention(nn.Module):
    """Multi-Query Attention: num_kv_heads = 1"""
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.d_head = config.n_embd // config.n_head
        self.dropout = config.dropout
        self.block_size = config.block_size

        # Q projection for all heads
        self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # K, V projections for single head
        self.k_proj = nn.Linear(config.n_embd, self.d_head, bias=config.bias)
        self.v_proj = nn.Linear(config.n_embd, self.d_head, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Flash attention support
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()

        # Q for all heads
        q = self.q_proj(x).view(B, T, self.n_head, self.d_head).transpose(1, 2)  # (B, nh, T, hs)

        # K, V for single head
        k = self.k_proj(x).view(B, T, 1, self.d_head).transpose(1, 2)  # (B, 1, T, hs)
        v = self.v_proj(x).view(B, T, 1, self.d_head).transpose(1, 2)  # (B, 1, T, hs)

        # Repeat K, V to match number of Q heads
        k = k.repeat(1, self.n_head, 1, 1)  # (B, nh, T, hs)
        v = v.repeat(1, self.n_head, 1, 1)  # (B, nh, T, hs)

        # Attention
        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True
            )
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.d_head))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y


class StaticMoASAttention(nn.Module):
    """Static Mixture of Attention Schemes - averages MHA, GQA, MQA outputs"""
    def __init__(self, config):
        super().__init__()
        self.mha = MHAAttention(config)
        self.gqa = GQAAttention(config, num_kv_heads=2)
        self.mqa = MQAAttention(config)

    def forward(self, x):
        o_mha = self.mha(x)
        o_gqa = self.gqa(x)
        o_mqa = self.mqa(x)
        return (o_mha + o_gqa + o_mqa) / 3.0


class MoASAttention(nn.Module):
    """Mixture of Attention Schemes with learned per-token routing"""
    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd

        # Three attention branches
        self.mha = MHAAttention(config)
        self.gqa = GQAAttention(config, num_kv_heads=2)
        self.mqa = MQAAttention(config)

        # Router: 2-layer MLP
        router_hidden = config.n_embd // 4
        self.router = nn.Sequential(
            nn.Linear(config.n_embd, router_hidden, bias=config.bias),
            nn.GELU(),
            nn.Linear(router_hidden, 3, bias=config.bias)  # 3 attention types
        )

        self.gate_dropout = nn.Dropout(config.dropout)

    def forward(self, x, return_gate_stats=False):
        B, T, C = x.size()

        # Compute all attention outputs
        o_mha = self.mha(x)  # (B, T, C)
        o_gqa = self.gqa(x)  # (B, T, C)
        o_mqa = self.mqa(x)  # (B, T, C)

        # Stack outputs: (B, T, 3, C)
        outputs = torch.stack([o_mha, o_gqa, o_mqa], dim=2)

        # Compute routing logits for each token
        router_logits = self.router(x)  # (B, T, 3)
        gates = F.softmax(router_logits, dim=-1)  # (B, T, 3)
        gates = self.gate_dropout(gates)

        # Mix outputs per token: (B, T, 3, 1) * (B, T, 3, C) -> (B, T, 3, C) -> (B, T, C)
        y = (gates.unsqueeze(-1) * outputs).sum(dim=2)

        if return_gate_stats:
            # Return average gate values for logging
            avg_gates = gates.mean(dim=(0, 1))  # (3,)
            return y, avg_gates

        return y

    def get_load_balancing_loss(self, x):
        """Compute load balancing loss to encourage using all attention types"""
        B, T, C = x.size()

        # Compute gates
        router_logits = self.router(x)  # (B, T, 3)
        gates = F.softmax(router_logits, dim=-1)  # (B, T, 3)

        # Average gate per type across all tokens
        avg_gates = gates.mean(dim=(0, 1))  # (3,)

        # Target: uniform distribution (1/3 for each type)
        target = torch.ones_like(avg_gates) / 3.0

        # MSE loss
        loss = F.mse_loss(avg_gates, target)

        return loss


Writing moht_components.py


In [4]:
import os, sys, importlib

# 1) Go to the folder where the file was written (Colab default)
%cd /content

# 2) Verify the file exists
print("cwd:", os.getcwd())
print("has file:", os.path.exists("moht_components.py"))
!ls -l moht_components.py

# 3) Ensure /content is on Python path (usually already is)
if "/content" not in sys.path:
    sys.path.insert(0, "/content")

# 4) Import + reload (useful if you edited the file and re-ran %%writefile)
import moht_components
importlib.reload(moht_components)

from moht_components import StaticMoASAttention, MoASAttention
print("Imported OK ✅")


/content
cwd: /content
has file: True
-rw-r--r-- 1 root root 10421 Dec 16 07:51 moht_components.py
Imported OK ✅


In [6]:
%%writefile moht_gpt.py
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass
from moht_components import StaticMoASAttention, MoASAttention

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True
    attention_type: str = 'baseline'  # 'baseline', 'static_moas', 'moas'

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)

        # Choose attention type
        if config.attention_type == 'baseline':
            self.attn = CausalSelfAttention(config)
        elif config.attention_type == 'static_moas':
            self.attn = StaticMoASAttention(config)
        elif config.attention_type == 'moas':
            self.attn = MoASAttention(config)
        else:
            raise ValueError(f"Unknown attention type: {config.attention_type}")

        self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    def get_load_balancing_loss(self, idx):
        """Compute load balancing loss for MoAS attention"""
        if self.config.attention_type != 'moas':
            return torch.tensor(0.0, device=idx.device)

        device = idx.device
        b, t = idx.size()
        pos = torch.arange(0, t, dtype=torch.long, device=device)

        # Forward to get embeddings
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)

        # Accumulate load balancing loss from all layers
        total_lb_loss = 0.0
        for block in self.transformer.h:
            x_norm = block.ln_1(x)
            lb_loss = block.attn.get_load_balancing_loss(x_norm)
            total_lb_loss += lb_loss
            x = x + block.attn(x_norm)
            x = x + block.mlp(block.ln_2(x))

        return total_lb_loss / len(self.transformer.h)

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)} with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)} with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in torch.optim.AdamW.__init__.__code__.co_varnames
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer


Writing moht_gpt.py


In [7]:
%%writefile train.py
import os
import time
import math
import pickle
import contextlib
import numpy as np
import torch
import tiktoken
from moht_gpt import GPT, GPTConfig

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
out_dir = 'out'
eval_interval = 200
log_interval = 10
eval_iters = 20
eval_only = False # if True, script exits right after the first eval
always_save_checkpoint = False # if True, always save a checkpoint after each eval
init_from = 'scratch' # 'scratch' or 'resume'
# data
dataset = 'wikitext-2'
gradient_accumulation_steps = 1 # used to simulate larger batch sizes
batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
block_size = 256 # context of up to 256 tokens
# model
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2
bias = False # do we use bias inside LayerNorm and Linear layers?
attention_type = 'baseline' # 'baseline', 'static_moas', 'moas'
load_balance_weight = 0.01 # weight for load balancing loss (only for 'moas')
# adamw optimizer
learning_rate = 1e-3 # max learning rate
max_iters = 2000 # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.99
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
# learning rate decay settings
decay_lr = True # whether to decay the learning rate
warmup_iters = 100 # how many steps to warm up for
lr_decay_iters = 2000 # should be ~= max_iters per Chinchilla
min_lr = 1e-4 # minimum learning rate, should be ~= learning_rate/10
# system
device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
compile = False # use PyTorch 2.0 to compile the model to be faster

# -----------------------------------------------------------------------------
# Data Loading
# -----------------------------------------------------------------------------
def get_data():
    # simple data loader for wikitext-2
    # we will download it if it doesn't exist
    data_dir = os.path.join('data', dataset)
    os.makedirs(data_dir, exist_ok=True)
    input_file_path = os.path.join(data_dir, 'input.txt')
    if not os.path.exists(input_file_path):
        import requests
        print("Downloading WikiText-2...")
        data_url = 'https://raw.githubusercontent.com/pytorch/examples/master/word_language_model/data/wikitext-2/train.txt'
        try:
            with open(input_file_path, 'w', encoding='utf-8') as f:
                f.write(requests.get(data_url).text)
        except Exception as e:
            print(f"Failed to download data: {e}")
            print("Creating dummy data instead.")
            with open(input_file_path, 'w', encoding='utf-8') as f:
                f.write("Hello world " * 10000)

    with open(input_file_path, 'r', encoding='utf-8') as f:
        data = f.read()

    # tokenize (character level)
    chars = sorted(list(set(data)))
    vocab_size = len(chars)
    print(f"Vocab size: {vocab_size}")
    stoi = { ch:i for i,ch in enumerate(chars) }
    itos = { i:ch for i,ch in enumerate(chars) }
    encode = lambda s: [stoi[c] for c in s]

    train_ids = encode(data)
    n = len(train_ids)
    train_data = np.array(train_ids[:int(n*0.9)], dtype=np.uint16)
    val_data = np.array(train_ids[int(n*0.9):], dtype=np.uint16)
    return train_data, val_data, vocab_size

train_data, val_data, vocab_size = get_data()

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

# -----------------------------------------------------------------------------
# Training Setup
# -----------------------------------------------------------------------------
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
# exec(open('configurator.py').read()) # overrides from command line or config file
# config = {k: globals()[k] for k in config_keys} # will be useful for logging

torch.manual_seed(1337)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
# note: float16 data type will automatically use a GradScaler
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = contextlib.nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# model init
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=vocab_size, dropout=dropout, attention_type=attention_type)
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
model.to(device)

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
    # TODO: implement resume
    pass

# compile the model
if compile:
    print("compiling the model... (takes a ~minute)")
    unoptimized_model = model
    model = torch.compile(model) # requires PyTorch 2.0

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)

# -----------------------------------------------------------------------------
# Training Loop
# -----------------------------------------------------------------------------
X, Y = get_batch('train') # fetch the very first batch
t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process
raw_model = model # unwrap DDP container if needed
running_mfu = -1.0

print(f"Training on {device}...")

while True:
    # determine and set the learning rate for this iteration
    lr = get_lr(local_iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate the loss on train/val sets and write checkpoints
    if local_iter_num % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {local_iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if always_save_checkpoint:
            if local_iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': local_iter_num,
                    'best_val_loss': losses['val'],
                    'config': config_keys,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))

    if local_iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        with ctx:
            logits, loss = model(X, Y)

            # Add load balancing loss for MoAS
            if attention_type == 'moas':
                lb_loss = model.get_load_balancing_loss(X)
                loss = loss + load_balance_weight * lb_loss

            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train')
        # backward pass, with gradient scaling if training in fp16
        # scaler.scale(loss).backward()
        loss.backward()

    # clip the gradient
    if grad_clip != 0.0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

    # step the optimizer and scaler if training in fp16
    optimizer.step()
    # scaler.update()
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if local_iter_num % log_interval == 0:
        lossf = loss.item() * gradient_accumulation_steps
        print(f"iter {local_iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")

    local_iter_num += 1

    # termination conditions
    if local_iter_num > max_iters:
        break

print("Training finished!")


Writing train.py


In [8]:
!python train.py

Downloading WikiText-2...
Vocab size: 283
  self.setter(val)
number of parameters: 10.73M
num decayed parameter tensors: 26 with 10,823,808 parameters
num non-decayed parameter tensors: 13 with 4,992 parameters
using fused AdamW: True
Training on cuda...
step 0: train loss 5.7300, val loss 5.7355
iter 0: loss 5.7324, time 2219.06ms
iter 10: loss 4.0670, time 93.62ms
iter 20: loss 3.3951, time 94.22ms
iter 30: loss 2.8527, time 94.16ms
iter 40: loss 2.6811, time 94.81ms
iter 50: loss 2.5449, time 94.47ms
iter 60: loss 2.5387, time 95.35ms
iter 70: loss 2.4884, time 94.39ms
iter 80: loss 2.5052, time 96.19ms
iter 90: loss 2.4781, time 96.51ms
iter 100: loss 2.4915, time 95.68ms
iter 110: loss 2.4626, time 96.09ms
iter 120: loss 2.4603, time 97.39ms
iter 130: loss 2.5099, time 96.21ms
iter 140: loss 2.4814, time 100.34ms
iter 150: loss 2.4404, time 97.87ms
iter 160: loss 2.4058, time 96.47ms
iter 170: loss 2.4540, time 95.88ms
iter 180: loss 2.4378, time 96.24ms
iter 190: loss 2.4072, tim

In [9]:
%%writefile train_static_moas.py
import os
import time
import math
import pickle
import contextlib
import numpy as np
import torch
import tiktoken
from moht_gpt import GPT, GPTConfig

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
out_dir = 'out'
eval_interval = 200
log_interval = 10
eval_iters = 20
eval_only = False # if True, script exits right after the first eval
always_save_checkpoint = False # if True, always save a checkpoint after each eval
init_from = 'scratch' # 'scratch' or 'resume'
# data
dataset = 'wikitext-2'
gradient_accumulation_steps = 1 # used to simulate larger batch sizes
batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
block_size = 256 # context of up to 256 tokens
# model
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2
bias = False # do we use bias inside LayerNorm and Linear layers?
attention_type = 'static_moas' # 'baseline', 'static_moas', 'moas'
load_balance_weight = 0.01 # weight for load balancing loss (only for 'moas')
# adamw optimizer
learning_rate = 1e-3 # max learning rate
max_iters = 2000 # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.99
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
# learning rate decay settings
decay_lr = True # whether to decay the learning rate
warmup_iters = 100 # how many steps to warm up for
lr_decay_iters = 2000 # should be ~= max_iters per Chinchilla
min_lr = 1e-4 # minimum learning rate, should be ~= learning_rate/10
# system
device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
compile = False # use PyTorch 2.0 to compile the model to be faster

# -----------------------------------------------------------------------------
# Data Loading
# -----------------------------------------------------------------------------
def get_data():
    # simple data loader for wikitext-2
    # we will download it if it doesn't exist
    data_dir = os.path.join('data', dataset)
    os.makedirs(data_dir, exist_ok=True)
    input_file_path = os.path.join(data_dir, 'input.txt')
    if not os.path.exists(input_file_path):
        import requests
        print("Downloading WikiText-2...")
        data_url = 'https://raw.githubusercontent.com/pytorch/examples/master/word_language_model/data/wikitext-2/train.txt'
        try:
            with open(input_file_path, 'w', encoding='utf-8') as f:
                f.write(requests.get(data_url).text)
        except Exception as e:
            print(f"Failed to download data: {e}")
            print("Creating dummy data instead.")
            with open(input_file_path, 'w', encoding='utf-8') as f:
                f.write("Hello world " * 10000)

    with open(input_file_path, 'r', encoding='utf-8') as f:
        data = f.read()

    # tokenize (character level)
    chars = sorted(list(set(data)))
    vocab_size = len(chars)
    print(f"Vocab size: {vocab_size}")
    stoi = { ch:i for i,ch in enumerate(chars) }
    itos = { i:ch for i,ch in enumerate(chars) }
    encode = lambda s: [stoi[c] for c in s]

    train_ids = encode(data)
    n = len(train_ids)
    train_data = np.array(train_ids[:int(n*0.9)], dtype=np.uint16)
    val_data = np.array(train_ids[int(n*0.9):], dtype=np.uint16)
    return train_data, val_data, vocab_size

train_data, val_data, vocab_size = get_data()

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

# -----------------------------------------------------------------------------
# Training Setup
# -----------------------------------------------------------------------------
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
# exec(open('configurator.py').read()) # overrides from command line or config file
# config = {k: globals()[k] for k in config_keys} # will be useful for logging

torch.manual_seed(1337)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
# note: float16 data type will automatically use a GradScaler
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = contextlib.nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# model init
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=vocab_size, dropout=dropout, attention_type=attention_type)
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
model.to(device)

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
    # TODO: implement resume
    pass

# compile the model
if compile:
    print("compiling the model... (takes a ~minute)")
    unoptimized_model = model
    model = torch.compile(model) # requires PyTorch 2.0

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)

# -----------------------------------------------------------------------------
# Training Loop
# -----------------------------------------------------------------------------
X, Y = get_batch('train') # fetch the very first batch
t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process
raw_model = model # unwrap DDP container if needed
running_mfu = -1.0

print(f"Training on {device}...")

while True:
    # determine and set the learning rate for this iteration
    lr = get_lr(local_iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate the loss on train/val sets and write checkpoints
    if local_iter_num % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {local_iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if always_save_checkpoint:
            if local_iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': local_iter_num,
                    'best_val_loss': losses['val'],
                    'config': config_keys,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))

    if local_iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        with ctx:
            logits, loss = model(X, Y)

            # Add load balancing loss for MoAS
            if attention_type == 'moas':
                lb_loss = model.get_load_balancing_loss(X)
                loss = loss + load_balance_weight * lb_loss

            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train')
        # backward pass, with gradient scaling if training in fp16
        # scaler.scale(loss).backward()
        loss.backward()

    # clip the gradient
    if grad_clip != 0.0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

    # step the optimizer and scaler if training in fp16
    optimizer.step()
    # scaler.update()
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if local_iter_num % log_interval == 0:
        lossf = loss.item() * gradient_accumulation_steps
        print(f"iter {local_iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")

    local_iter_num += 1

    # termination conditions
    if local_iter_num > max_iters:
        break

print("Training finished!")


Writing train_static_moas.py


In [10]:
!python train_static_moas.py

Vocab size: 283
  self.setter(val)
number of parameters: 15.15M
num decayed parameter tensors: 74 with 15,247,488 parameters
num non-decayed parameter tensors: 13 with 4,992 parameters
using fused AdamW: True
Training on cuda...
step 0: train loss 5.7434, val loss 5.7429
iter 0: loss 5.7337, time 2952.34ms
iter 10: loss 3.9412, time 180.81ms
iter 20: loss 3.2909, time 181.95ms
iter 30: loss 2.8385, time 181.82ms
iter 40: loss 2.6259, time 182.42ms
iter 50: loss 2.5571, time 182.66ms
iter 60: loss 2.5139, time 184.81ms
iter 70: loss 2.5352, time 185.40ms
iter 80: loss 2.4910, time 184.85ms
iter 90: loss 2.4766, time 187.46ms
iter 100: loss 2.4954, time 185.71ms
iter 110: loss 2.5384, time 189.98ms
iter 120: loss 2.4409, time 190.31ms
iter 130: loss 2.4294, time 188.19ms
iter 140: loss 2.4617, time 188.84ms
iter 150: loss 2.4251, time 188.59ms
iter 160: loss 2.4296, time 193.15ms
iter 170: loss 2.4381, time 191.08ms
iter 180: loss 2.4169, time 191.86ms
iter 190: loss 2.5015, time 193.51m

In [11]:
%%writefile train_moas.py
import os
import time
import math
import pickle
import contextlib
import numpy as np
import torch
import tiktoken
from moht_gpt import GPT, GPTConfig

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
out_dir = 'out'
eval_interval = 200
log_interval = 10
eval_iters = 20
eval_only = False # if True, script exits right after the first eval
always_save_checkpoint = False # if True, always save a checkpoint after each eval
init_from = 'scratch' # 'scratch' or 'resume'
# data
dataset = 'wikitext-2'
gradient_accumulation_steps = 1 # used to simulate larger batch sizes
batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
block_size = 256 # context of up to 256 tokens
# model
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2
bias = False # do we use bias inside LayerNorm and Linear layers?
attention_type = 'moas' # 'baseline', 'static_moas', 'moas'
load_balance_weight = 0.01 # weight for load balancing loss (only for 'moas')
# adamw optimizer
learning_rate = 1e-3 # max learning rate
max_iters = 2000 # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.99
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
# learning rate decay settings
decay_lr = True # whether to decay the learning rate
warmup_iters = 100 # how many steps to warm up for
lr_decay_iters = 2000 # should be ~= max_iters per Chinchilla
min_lr = 1e-4 # minimum learning rate, should be ~= learning_rate/10
# system
device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
compile = False # use PyTorch 2.0 to compile the model to be faster

# -----------------------------------------------------------------------------
# Data Loading
# -----------------------------------------------------------------------------
def get_data():
    # simple data loader for wikitext-2
    # we will download it if it doesn't exist
    data_dir = os.path.join('data', dataset)
    os.makedirs(data_dir, exist_ok=True)
    input_file_path = os.path.join(data_dir, 'input.txt')
    if not os.path.exists(input_file_path):
        import requests
        print("Downloading WikiText-2...")
        data_url = 'https://raw.githubusercontent.com/pytorch/examples/master/word_language_model/data/wikitext-2/train.txt'
        try:
            with open(input_file_path, 'w', encoding='utf-8') as f:
                f.write(requests.get(data_url).text)
        except Exception as e:
            print(f"Failed to download data: {e}")
            print("Creating dummy data instead.")
            with open(input_file_path, 'w', encoding='utf-8') as f:
                f.write("Hello world " * 10000)

    with open(input_file_path, 'r', encoding='utf-8') as f:
        data = f.read()

    # tokenize (character level)
    chars = sorted(list(set(data)))
    vocab_size = len(chars)
    print(f"Vocab size: {vocab_size}")
    stoi = { ch:i for i,ch in enumerate(chars) }
    itos = { i:ch for i,ch in enumerate(chars) }
    encode = lambda s: [stoi[c] for c in s]

    train_ids = encode(data)
    n = len(train_ids)
    train_data = np.array(train_ids[:int(n*0.9)], dtype=np.uint16)
    val_data = np.array(train_ids[int(n*0.9):], dtype=np.uint16)
    return train_data, val_data, vocab_size

train_data, val_data, vocab_size = get_data()

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

# -----------------------------------------------------------------------------
# Training Setup
# -----------------------------------------------------------------------------
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
# exec(open('configurator.py').read()) # overrides from command line or config file
# config = {k: globals()[k] for k in config_keys} # will be useful for logging

torch.manual_seed(1337)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
# note: float16 data type will automatically use a GradScaler
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = contextlib.nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# model init
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=vocab_size, dropout=dropout, attention_type=attention_type)
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
model.to(device)

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
    # TODO: implement resume
    pass

# compile the model
if compile:
    print("compiling the model... (takes a ~minute)")
    unoptimized_model = model
    model = torch.compile(model) # requires PyTorch 2.0

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)

# -----------------------------------------------------------------------------
# Training Loop
# -----------------------------------------------------------------------------
X, Y = get_batch('train') # fetch the very first batch
t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process
raw_model = model # unwrap DDP container if needed
running_mfu = -1.0

print(f"Training on {device}...")

while True:
    # determine and set the learning rate for this iteration
    lr = get_lr(local_iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate the loss on train/val sets and write checkpoints
    if local_iter_num % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {local_iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if always_save_checkpoint:
            if local_iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': local_iter_num,
                    'best_val_loss': losses['val'],
                    'config': config_keys,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))

    if local_iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        with ctx:
            logits, loss = model(X, Y)

            # Add load balancing loss for MoAS
            if attention_type == 'moas':
                lb_loss = model.get_load_balancing_loss(X)
                loss = loss + load_balance_weight * lb_loss

            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train')
        # backward pass, with gradient scaling if training in fp16
        # scaler.scale(loss).backward()
        loss.backward()

    # clip the gradient
    if grad_clip != 0.0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

    # step the optimizer and scaler if training in fp16
    optimizer.step()
    # scaler.update()
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if local_iter_num % log_interval == 0:
        lossf = loss.item() * gradient_accumulation_steps
        print(f"iter {local_iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")

    local_iter_num += 1

    # termination conditions
    if local_iter_num > max_iters:
        break

print("Training finished!")


Writing train_moas.py


In [12]:
!python train_moas.py

Vocab size: 283
  self.setter(val)
number of parameters: 15.38M
num decayed parameter tensors: 86 with 15,470,400 parameters
num non-decayed parameter tensors: 13 with 4,992 parameters
using fused AdamW: True
Training on cuda...
step 0: train loss 5.6328, val loss 5.6301
iter 0: loss 5.6384, time 3363.34ms
iter 10: loss 3.8911, time 365.71ms
iter 20: loss 3.2655, time 369.08ms
iter 30: loss 2.8054, time 372.64ms
iter 40: loss 2.5900, time 375.00ms
iter 50: loss 2.5962, time 374.98ms
iter 60: loss 2.5543, time 377.77ms
iter 70: loss 2.5553, time 382.30ms
iter 80: loss 2.5149, time 383.75ms
iter 90: loss 2.4737, time 386.36ms
iter 100: loss 2.4753, time 390.71ms
iter 110: loss 2.4623, time 393.08ms
iter 120: loss 2.4999, time 398.33ms
iter 130: loss 2.4667, time 393.36ms
iter 140: loss 2.4477, time 399.29ms
iter 150: loss 2.4524, time 392.15ms
iter 160: loss 2.3879, time 389.80ms
iter 170: loss 2.4681, time 408.02ms
iter 180: loss 2.4510, time 387.26ms
iter 190: loss 2.4377, time 389.79m

In [13]:
%%writefile compare_moas.py
"""
Compare training runs for Baseline, Static MoAS, and Dynamic MoAS
"""
import os
import time
import math
import numpy as np
import torch
from moht_gpt import GPT, GPTConfig

# Configuration
out_dir = 'out_comparison'
os.makedirs(out_dir, exist_ok=True)

# Data
dataset = 'wikitext-2'
batch_size = 12
block_size = 256

# Model
n_layer = 4
n_head = 6
n_embd = 384
dropout = 0.1
bias = False

# Training
learning_rate = 3e-4
max_iters = 500
eval_interval = 50
eval_iters = 20
log_interval = 10

# MoAS specific
load_balance_weight = 0.01

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Load data
def get_data():
    data_dir = os.path.join('data', dataset)
    input_file_path = os.path.join(data_dir, 'input.txt')

    with open(input_file_path, 'r', encoding='utf-8') as f:
        data = f.read()

    # Character-level tokenization
    chars = sorted(list(set(data)))
    vocab_size = len(chars)
    stoi = {ch: i for i, ch in enumerate(chars)}
    encode = lambda s: [stoi[c] for c in s]

    train_ids = encode(data)
    n = len(train_ids)
    train_data = np.array(train_ids[:int(n*0.9)], dtype=np.uint16)
    val_data = np.array(train_ids[int(n*0.9):], dtype=np.uint16)
    return train_data, val_data, vocab_size

train_data, val_data, vocab_size = get_data()

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

def train_model(attention_type, name):
    print(f"\n{'='*70}")
    print(f"Training {name} ({attention_type})")
    print('='*70)

    # Create model
    config = GPTConfig(
        block_size=block_size,
        vocab_size=vocab_size,
        n_layer=n_layer,
        n_head=n_head,
        n_embd=n_embd,
        dropout=dropout,
        bias=bias,
        attention_type=attention_type
    )

    model = GPT(config)
    model.to(device)

    # Optimizer
    optimizer = model.configure_optimizers(
        weight_decay=0.1,
        learning_rate=learning_rate,
        betas=(0.9, 0.99),
        device_type=device
    )

    # Training loop
    results = {
        'iters': [],
        'train_loss': [],
        'val_loss': []
    }

    X, Y = get_batch('train')
    t0 = time.time()

    for iter_num in range(max_iters + 1):
        # Evaluation
        if iter_num % eval_interval == 0:
            losses = estimate_loss(model)
            print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
            results['iters'].append(iter_num)
            results['train_loss'].append(losses['train'].item())
            results['val_loss'].append(losses['val'].item())

        if iter_num == max_iters:
            break

        # Forward
        logits, loss = model(X, Y)

        # Add load balancing loss for MoAS
        if attention_type == 'moas':
            lb_loss = model.get_load_balancing_loss(X)
            loss = loss + load_balance_weight * lb_loss

        # Backward
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Fetch next batch
        X, Y = get_batch('train')

        # Logging
        if iter_num % log_interval == 0:
            t1 = time.time()
            dt = t1 - t0
            t0 = t1
            lossf = loss.item()
            print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")

    print(f"\n✓ {name} training completed!")
    return results

# Train all three variants
results_baseline = train_model('baseline', 'Baseline MHA')
results_static = train_model('static_moas', 'Static MoAS')
results_moas = train_model('moas', 'Dynamic MoAS (Routed)')

# Save results
import pickle
with open(os.path.join(out_dir, 'comparison_results.pkl'), 'wb') as f:
    pickle.dump({
        'baseline': results_baseline,
        'static_moas': results_static,
        'moas': results_moas
    }, f)

print("\n" + "="*70)
print("FINAL RESULTS")
print("="*70)
print(f"Baseline MHA    - Final val loss: {results_baseline['val_loss'][-1]:.4f}")
print(f"Static MoAS     - Final val loss: {results_static['val_loss'][-1]:.4f}")
print(f"Dynamic MoAS    - Final val loss: {results_moas['val_loss'][-1]:.4f}")
print("="*70)


Writing compare_moas.py


In [14]:
!python compare_moas.py

Using device: cuda

Training Baseline MHA (baseline)
number of parameters: 7.19M
num decayed parameter tensors: 18 with 7,284,864 parameters
num non-decayed parameter tensors: 9 with 3,456 parameters
using fused AdamW: True
step 0: train loss 5.7558, val loss 5.7592
iter 0: loss 5.7438, time 858.68ms
iter 10: loss 3.3046, time 474.47ms
iter 20: loss 2.9041, time 493.30ms
iter 30: loss 2.7083, time 483.85ms
iter 40: loss 2.6108, time 488.50ms
step 50: train loss 2.5482, val loss 2.5545
iter 50: loss 2.5501, time 1119.60ms
iter 60: loss 2.5305, time 488.09ms
iter 70: loss 2.5076, time 493.11ms
iter 80: loss 2.4992, time 492.51ms
iter 90: loss 2.4949, time 491.61ms
step 100: train loss 2.4658, val loss 2.4768
iter 100: loss 2.4760, time 1136.17ms
iter 110: loss 2.4665, time 494.08ms
iter 120: loss 2.4422, time 503.45ms
iter 130: loss 2.4589, time 498.65ms
iter 140: loss 2.4145, time 496.45ms
step 150: train loss 2.4288, val loss 2.4397
iter 150: loss 2.4478, time 1147.28ms
iter 160: loss 