In [51]:
import os
import sys
sys.path.append("/data/horse/ws/lama722b-nanite-lm/nanite-lm")
from pprint import pprint as print
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim import AdamW, lr_scheduler
from xformers.ops import fmha, AttentionBias
from torch.nn.attention.flex_attention import (
    BlockMask,
    flex_attention,
    _mask_mod_signature,
)
from enum import Enum
from typing import Optional, Union, Tuple

In [54]:
from codebase.transformer import (
    BaseTransformerArgs,
    BaseTransformer,
    TransformerBlock,
    Attention,
    FeedForward,
    flex_attention_comp,
    RMSNorm,
    cross_entropy,
    apply_rotary_emb,
    reshape_for_broadcast,
    repeat_kv,
    InitStdFactor,
)
from codebase.optim import (
    OptimArgs,
    build_lr_fn
)

In [3]:
def create_causal_mask(seqlen, attn_impl, sliding_window):
    if sliding_window is not None and attn_impl == "fmha":
        return fmha.attn_bias.LocalAttentionFromBottomRightMask(
            window_left=sliding_window - 1, window_right=0
        )
    elif attn_impl == "fmha":
        return fmha.attn_bias.LowerTriangularMask()
    elif attn_impl == "sdpa":
        return "causal"
    elif attn_impl == "flex_attention":
        return create_block_mask(causal_mask, None, None, seqlen, seqlen)
    else:
        raise NotImplementedError(
            f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
        )

## Mup Enabled Attention Layer

In [21]:
class MupAttention(Attention):
    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        rope_theta: float,
        scaling_factor: float,
    ):
        super().__init__(
            dim,
            head_dim,
            n_heads,
            n_kv_heads,
            rope_theta,
        )
        self.scaling_factor = scaling_factor
        
    def forward(
        self,
        x: torch.Tensor,
        freq_cis: torch.Tensor,
        tok_idx: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
        attn_impl: str = "sdpa",
    ):
        # B S D
        bsz, seq_len, dim = x.shape
        xq = self.wq(x.view_as(x))
        xk = self.wk(x.view_as(x))
        xv = self.wv(x.view_as(x))

        output_shape = xq.shape
        # B S D -> B S H D
        xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])

        # This condition helps us be easily compatible
        # with inference by adding a pluggable KVCache
        if hasattr(self, "kv_cache"):
            xk, xv = self.kv_cache.update(xk, xv, tok_idx)

        xk = repeat_kv(xk, self.heads_per_group, dim=2)
        xv = repeat_kv(xv, self.heads_per_group, dim=2)

        attention_scaling_factor = 1.0 / self.n_heads

        if attn_impl == "flex_attention":
            assert mask is None or isinstance(mask, BlockMask)
            xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
            output = flex_attention_comp(xq, xk, xv, block_mask=mask, scale=attention_scaling_factor)
            output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D

        elif attn_impl == "fmha":
            assert mask is None or isinstance(mask, AttentionBias)
            output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask, scale=attention_scaling_factor)
            # This uses B S H D instead of B H S D of pytorch

        elif attn_impl == "sdpa":
            xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
            assert mask is None or isinstance(mask, (str, torch.Tensor))
            is_causal = (mask == "causal") if isinstance(mask, str) else False
            mask = mask if isinstance(mask, torch.Tensor) else None
            output = F.scaled_dot_product_attention(
                xq,
                xk,
                xv,
                is_causal=is_causal,
                attn_mask=mask,
                scale=attention_scaling_factor
            )
            output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D
        else:
            raise NotImplementedError(
                f"Attention implementation {attn_impl} not supported"
            )

        output = self.wo(output.reshape(output_shape))

        return output

    def reset_parameters(self, init_std=None, out_proj_factor=1.0):
        init_std = init_std or (self.dim ** (-0.5))

        for w in [self.wq, self.wk, self.wv]:
            nn.init.trunc_normal_(
                w.weight,
                mean=0.0,
                std=init_std / math.sqrt(self.scaling_factor),
                a=-3 * init_std,
                b=3 * init_std,
            )

        nn.init.trunc_normal_(
            self.wo.weight,
            mean=0.0,
            std=init_std / out_proj_factor,
            a=-3 * init_std,
            b=3 * init_std,
        )

class MupFeedForward(FeedForward):
    def __init__(
        self,
        dim,
        hidden_dim,
        multiple_of,
        ffn_dim_multiplier,
        scaling_factor: float = 1.0,
        mp_size: int = 1,
    ):
        super().__init__(
            dim=dim,
            hidden_dim=hidden_dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
            mp_size = mp_size,
        )
        self.scaling_factor = scaling_factor

    def reset_parameters(
        self,
        init_std=None,
        factor=1.0,
    ):
        in_init_std = init_std or (self.dim ** (-0.5))
        out_init_std = init_std or (self.hidden_dim ** (-0.5))
        in_init_std = in_init_std / math.sqrt(self.scaling_factor)
        out_init_std = out_init_std / factor
        for w in [self.w1, self.w3]:
            nn.init.trunc_normal_(
                w.weight,
                mean=0.0,
                std=in_init_std,
                a=-3 * in_init_std,
                b=3 * in_init_std,
            )
        nn.init.trunc_normal_(
            self.w2.weight,
            mean=0.0,
            std=out_init_std,
            a=-3 * out_init_std,
            b=3 * out_init_std,
        )
        

In [22]:
class MupTransformerBlock(TransformerBlock):
    def __init__(self, args):
        super().__init__(args)
        self.scaling_factor = args.scaling_factor
        self.attention = MupAttention(
            dim = args.dim,
            head_dim = self.head_dim,
            n_heads = self.n_heads,
            n_kv_heads = self.n_kv_heads,
            rope_theta = args.rope_theta,
            scaling_factor = self.scaling_factor
        )
        self.feed_forward = MupFeedForward(
            dim=args.dim,
            hidden_dim=4*args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
            scaling_factor=args.scaling_factor
        )

In [37]:
class MupTransformerArgs(BaseTransformerArgs):
    seed = 42
    vocab_size = -1
    weight_tying = False
    sliding_window = None
    input_alpha = 1.0
    output_alpha = 1.0
    scaling_factor: float = None

class MupOptimArgs(OptimArgs):
    scaling_factor: float = None


class MupTransformer(BaseTransformer):
    def __init__(
        self,
        args,
    ):
        super().__init__(args)
        self.input_alpha = args.input_alpha
        self.output_alpha = args.output_alpha
        self.weight_tying = args.weight_tying
        self.sliding_window = args.sliding_window
        self.scaling_factor = args.scaling_factor
        

        assert args.vocab_size > 0
        assert args.scaling_factor >= 1, "You need to set this!!!"

        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
        # This is Post-Layer Norm layer or Post-Norm
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)

        if args.weight_tying:
            self.output = TiedLinear(self.tok_embeddings)
        else:
            self.output = nn.Linear(
                args.dim,
                args.vocab_size,
                bias=False
            )
        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(MupTransformerBlock(args))

    def forward(
        self,
        token_values: torch.Tensor,
        target = None,
        tok_idx = None,
        mask = None,
        attn_impl = "sdpa",
    ):
        bsz, seqlen = token_values.shape
        mask = (
            mask
            if mask is not None
            else create_causal_mask(seqlen, attn_impl, self.sliding_window)
        )
        # (krotonus) NOTE: Embedding FWD MUP
        h = self.input_alpha * self.tok_embeddings(token_values)

        freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)

        for i, layer in enumerate(self.layers):
            h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)

        # (krotonus) NOTE: Output Logit FWD. MUP
        logits = (self.output(self.norm(h)) * self.output_alpha) / self.scaling_factor
        if target is not None:
            return cross_entropy(logits, target)
        else:
            return logits

    def init_weights(self):
        self.reset_parameters()
        out_proj_factor =  math.sqrt(2 * args.n_layers * args.scaling_factor)
        for depth, layer in enumerate(self.layers):
            layer.init_weights(self.init_base_std, out_proj_factor)

In [48]:
inp = torch.randint(low=0, high=256, size=(1, 10))

args = MupTransformerArgs()
args.n_heads = 4
args.n_layers = 2
args.dim = 128
args.max_seqlen = 128
args.vocab_size = 256
args.scaling_factor = 1.0
print(args)
print("-"*100)
optim_args = MupOptimArgs()
optim_args.scaling_factor = 1.0
model = MupTransformer(args)
print(model)

MupTransformerArgs(dim=128,
                   n_layers=2,
                   head_dim=None,
                   n_heads=4,
                   n_kv_heads=None,
                   ffn_dim_multiplier=None,
                   multiple_of=256,
                   norm_eps=1e-05,
                   rope_theta=10000.0,
                   init_base_std=None,
                   init_std_factor='disabled',
                   max_seqlen=128)
'----------------------------------------------------------------------------------------------------'
MupTransformer(
  (rope_embeddings): RotaryEmbedding()
  (layers): ModuleList(
    (0-1): 2 x MupTransformerBlock(
      (attention): MupAttention(
        (wq): Linear(in_features=128, out_features=128, bias=False)
        (wk): Linear(in_features=128, out_features=128, bias=False)
        (wv): Linear(in_features=128, out_features=128, bias=False)
        (wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (feed_forward): MupFeedForwar

In [28]:
model.init_weights()

In [29]:
inp.size(), inp.shape

(torch.Size([1, 10]), torch.Size([1, 10]))

In [30]:
model(inp)

tensor([[[-2.1746e-01,  3.7759e-02, -4.3266e-04,  ..., -3.6871e-01,
           8.3709e-01, -1.3619e-01],
         [ 4.0747e-01,  2.7429e-01, -7.2218e-01,  ...,  9.1531e-01,
           6.4633e-01, -1.0130e+00],
         [ 4.1656e-01,  3.1421e-01, -4.0535e-01,  ..., -1.5261e-01,
           1.6991e-01, -5.8452e-01],
         ...,
         [-2.8328e-01, -2.4091e-01,  5.4915e-01,  ...,  1.3069e+00,
           4.3992e-01, -5.0319e-01],
         [-2.3307e-01,  9.5362e-01, -1.7055e-01,  ...,  8.9563e-02,
          -7.0343e-01,  3.6659e-01],
         [ 3.4711e-01,  4.3249e-01, -9.6958e-01,  ...,  8.5858e-01,
           1.9173e+00, -2.0373e-01]]], grad_fn=<DivBackward0>)

In [31]:
for name, param in model.named_parameters():
    print(f"{name} : {param.shape}")

'layers.0.attention.wq.weight : torch.Size([128, 128])'
'layers.0.attention.wk.weight : torch.Size([128, 128])'
'layers.0.attention.wv.weight : torch.Size([128, 128])'
'layers.0.attention.wo.weight : torch.Size([128, 128])'
'layers.0.feed_forward.w1.weight : torch.Size([512, 128])'
'layers.0.feed_forward.w3.weight : torch.Size([512, 128])'
'layers.0.feed_forward.w2.weight : torch.Size([128, 512])'
'layers.0.attention_norm.weight : torch.Size([128])'
'layers.0.ffn_norm.weight : torch.Size([128])'
'layers.1.attention.wq.weight : torch.Size([128, 128])'
'layers.1.attention.wk.weight : torch.Size([128, 128])'
'layers.1.attention.wv.weight : torch.Size([128, 128])'
'layers.1.attention.wo.weight : torch.Size([128, 128])'
'layers.1.feed_forward.w1.weight : torch.Size([512, 128])'
'layers.1.feed_forward.w3.weight : torch.Size([512, 128])'
'layers.1.feed_forward.w2.weight : torch.Size([128, 512])'
'layers.1.attention_norm.weight : torch.Size([128])'
'layers.1.ffn_norm.weight : torch.Size([128])

### Optimizer Configuration

In [60]:
def build_mup_optimizer(
    model: nn.Module,
    args: OptimArgs,
    n_steps: int,
):
    mup_decay_params = []
    decay_params = []
    nodecay_params = []
    for n, p in model.named_parameters():
        if p.dim() >= 2:
            if (
            n.endswith('wq.weight') or
            n.endswith('wk.weight') or
            n.endswith('wv.weight') or
            n.endswith('wo.weight') or
            n.endswith('w1.weight') or
            n.endswith('w2.weight') or
            n.endswith('w3.weight')
            ):
                print(f"Added {n} to mup_decay_list")
                mup_decay_params.append(p)
            else:
                decay_params.append(p)
        else:
            nodecay_params.append(p)
    optim_groups = [
        {'params': mup_decay_params, 'weight_decay': args.weight_decay, 'lr_scale': (1/args.scaling_factor)},
        {'params': decay_params, 'weight_decay': args.weight_decay, 'lr_scale': 1},
        {'params': nodecay_params, 'weight_decay': 0.0, 'lr_scale': 1}
    ]
    num_mup_decay_params = sum(p.numel() for p in mup_decay_params)
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(f"num mup decayed parameter tensors: {len(mup_decay_params)}, with {num_mup_decay_params:,} parameters")
    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

    optimizer = AdamW(
        optim_groups,
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        eps=args.epsilon,
        fused=True,  # Faster optim.step but can throw errors
    )

    # scheduler
    lr_fn = build_lr_fn(args, n_steps)
    scheduler = lr_scheduler.LambdaLR(
        optimizer, lr_fn
    )  # lr_scheduler.LambdaLR(optimizer, lr_fn)

    return optimizer, scheduler

In [61]:
optimizer, scheduler = build_mup_optimizer(model, optim_args, 10)

'Added layers.0.attention.wq.weight to mup_decay_list'
'Added layers.0.attention.wk.weight to mup_decay_list'
'Added layers.0.attention.wv.weight to mup_decay_list'
'Added layers.0.attention.wo.weight to mup_decay_list'
'Added layers.0.feed_forward.w1.weight to mup_decay_list'
'Added layers.0.feed_forward.w3.weight to mup_decay_list'
'Added layers.0.feed_forward.w2.weight to mup_decay_list'
'Added layers.1.attention.wq.weight to mup_decay_list'
'Added layers.1.attention.wk.weight to mup_decay_list'
'Added layers.1.attention.wv.weight to mup_decay_list'
'Added layers.1.attention.wo.weight to mup_decay_list'
'Added layers.1.feed_forward.w1.weight to mup_decay_list'
'Added layers.1.feed_forward.w3.weight to mup_decay_list'
'Added layers.1.feed_forward.w2.weight to mup_decay_list'
'num mup decayed parameter tensors: 14, with 524,288 parameters'
'num decayed parameter tensors: 2, with 65,536 parameters'
'num non-decayed parameter tensors: 5, with 640 parameters'


In [62]:
for name, module in model.named_modules():
    print(f"{name}")

''
'rope_embeddings'
'layers'
'layers.0'
'layers.0.attention'
'layers.0.attention.wq'
'layers.0.attention.wk'
'layers.0.attention.wv'
'layers.0.attention.wo'
'layers.0.feed_forward'
'layers.0.feed_forward.w1'
'layers.0.feed_forward.w3'
'layers.0.feed_forward.w2'
'layers.0.attention_norm'
'layers.0.ffn_norm'
'layers.1'
'layers.1.attention'
'layers.1.attention.wq'
'layers.1.attention.wk'
'layers.1.attention.wv'
'layers.1.attention.wo'
'layers.1.feed_forward'
'layers.1.feed_forward.w1'
'layers.1.feed_forward.w3'
'layers.1.feed_forward.w2'
'layers.1.attention_norm'
'layers.1.ffn_norm'
'tok_embeddings'
'norm'
'output'
