# DoReMi 

## How it works

- Step 1: Train a small referece model using uniform sampling from each data domain (for a given batch size).
- Step 2: Used the trained reference model from previous steps to train an identical mode, and use its performance to dynamically tune the domain weights.
- Step 3: Save the domain weights in model checkpoint. Calculate the optimal domain weights by averaging the domain weights across all the training steps.
- Step 4: Use the optimized domain weights from previous step to train a larger model. (10x-30x larger)

In [1]:
import os
import sys
sys.path.append("/data/horse/ws/lama722b-nanite-lm/nanite-lm/")

In [2]:
def b():
    try:
        from ipdb import set_trace
    except:
        from pdb import set_trace
    set_trace()

In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np

from codebase.data import (
    DataArgs,
    build_dataloader_from_args,
    init_dataloader_state_from_args
)
from codebase.tokenizer import (
    build_tokenizer,
    TokenizerArgs
)
from codebase.optim import (
    build_optimizer,
    OptimArgs
)

from codebase.transformer import (
    BaseTransformer,
    RMSNorm
)

from experiments.baseline_transformer.transformer import (
    LMTransformerArgs,
    LMTransformer,
    create_causal_mask
)

NUM_TRAIN_STEPS = 500

### Per-Domain Cross Entropy Loss

In [15]:
def cross_entropy(pred, target, **kwargs):
    inp = F.log_softmax(pred.flatten(end_dim=-2).float(), -1)
    return F.nll_loss(
        inp,
        target.flatten(end_dim=-1),
        reduction = "mean",
        **kwargs,
    )

def per_token_cross_entropy(pred, target, **kwargs):
    inp = F.log_softmax(pred.flatten(end_dim=-2).float(), -1)
    return F.nll_loss(
        inp,
        target.flatten(end_dim=-1),
        reduction = "none",
        **kwargs,
    )

### DoReMi Context

- This is used to maintain the doremi weights and thier history

In [5]:
from dataclasses import dataclass, field
from typing import List, TypedDict

class WeightHistory(TypedDict):
    step: int
    weight: torch.Tensor

@dataclass
class DoReMiContext:
    # Note(krotonus): This is the current domain weights
    domain_keys: List[str]
    is_proxy: bool
    step_size: float = 1
    smoothing_param: float = 1e-3
    domain_weight_history: WeightHistory = field(default_factory=list)

    @property
    def num_domain(self) -> int:
        return len(self.domain_keys)

    def get_domain_name(self, domain_idx: int) -> str:
        return self.domain_keys[domain_idx]

    def __post_init__(self):
        self.domain_weights = torch.ones(self.num_domains) / self.num_domains
        self.add_weight_with_history(self.domain_weights, 0)
    
    def add_weight_with_history(self, domain_weights, step):
        self.domain_weight_history.append(WeightHistory(step=step, weight=domain_weights.cpu()))        

In [5]:
tokenizer = build_tokenizer(
        name = "sp",
        path = "/home/lama722b/nanite_lm/tokenizers/gemma/tokenizer.model"
)


def compute_initial_weights(data_args):
    num_samples_per_domain = [len(d) for d in datasets]
    total_samples = sum(num_samples_per_domain)
    weights = torch.tensor([num_sample / total_samples for num_sample in num_samples_per_domain])
    return weights

In [7]:
sources = data_args.sources
n_sources = len(sources)
possible_sources = list(sources.keys())
weights = list(sources.values())
norm_weights = np.array(weights) / np.array(weights).sum()
print(n_sources, possible_sources, weights, norm_weights)

2 ['de_shuffled', 'en_shuffled'] [1, 1] [0.5 0.5]


### Trial by Training

In [5]:
from contextlib import ExitStack
from dataclasses import dataclass

In [20]:
class DoReMiTransformer(BaseTransformer):
    def __init__(self, args):
        super().__init__(args)
        self.weight_tying = args.weight_tying
        self.sliding_window = args.sliding_window

        assert args.vocab_size > 0

        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)

        if args.weight_tying:
            self.output = TiedLinear(self.tok_embeddings)
        else:
            self.output = nn.Linear(
                args.dim,
                args.vocab_size,
                bias=False,
            )

    def forward(
        self,
        token_values,
        target = None,
        tok_idx = None,
        mask = None,
        attn_impl = "sdpa",
    ):
        bsz, seqlen = token_values.shape

        h = self.tok_embeddings(token_values)

        mask = (
            mask
            if mask is not None
            else create_causal_mask(seqlen, attn_impl, self.sliding_window)
        )

        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
        logits = self.output(self.norm(h))
        b()
        if target is not None:
            loss = per_token_cross_entropy(logits, target)
            return loss
        else:
            return logits

In [17]:
# Args
tok_args = TokenizerArgs(
    name = "sp",
    path = "/home/lama722b/nanite_lm/tokenizers/gemma/tokenizer.model"
)

tokenizer = build_tokenizer(name=tok_args.name, path=tok_args.path)

data_args = DataArgs(
    root_dir = "/home/lama722b/nanite_lm/data/fineweb",
    sources = {
        "de_shuffled": 1,
        "en_shuffled": 1
    },
    batch_size = 1,
    seq_len=512,
    load_async = False,
    prefetch_size = 2,
    tokenizer = tok_args
)

optim_args = OptimArgs()

model_args = LMTransformerArgs(
    vocab_size = tokenizer.n_words,
    n_heads = 4,
    n_layers = 2,
    dim = 128,
    max_seqlen = data_args.seq_len
)

@dataclass
class TrainArgs:
    data: DataArgs
    model: LMTransformerArgs
    optim: OptimArgs
    steps: int

args = TrainArgs(
    data = data_args,
    model = model_args,
    optim = optim_args,
    steps = NUM_TRAIN_STEPS
)

In [18]:
print(args)

TrainArgs(data=DataArgs(root_dir='/home/lama722b/nanite_lm/data/fineweb', sources={'de_shuffled': 1, 'en_shuffled': 1}, batch_size=1, seq_len=512, n_views=2, seed=42, add_bos=True, add_eos=True, load_async=False, prefetch_size=2, tokenizer=TokenizerArgs(name='sp', path='/home/lama722b/nanite_lm/tokenizers/gemma/tokenizer.model')), model=LMTransformerArgs(dim=128, n_layers=2, head_dim=None, n_heads=4, n_kv_heads=None, ffn_dim_multiplier=None, multiple_of=256, norm_eps=1e-05, rope_theta=10000.0, init_base_std=None, init_std_factor='disabled', max_seqlen=512, seed=42, vocab_size=262144, weight_tying=False, sliding_window=None), optim=OptimArgs(lr=0.0003, weight_decay=0.1, epsilon=1e-08, beta1=0.9, beta2=0.95, clip=1.0, scheduler='cosine', warmup=2000, lr_min_ratio=0.1, cycle_length=1.0, cosine_theta=1.0, annealing_step=1000, decay_fraction=0.1, exp_factor=0.5), steps=500)


In [None]:
with ExitStack() as context_stack:
    data_loader_state = init_dataloader_state_from_args(
            args.data, rank=0, world_size=1 # Using dummy rank/degree for non-distributed
        )
    data_loader = context_stack.enter_context(
                build_dataloader_from_args(
                    args.data,
                    state=data_loader_state,
                )
            )
    model = DoReMiTransformer(args.model)
    optimizer, scheduler = build_optimizer(model, args.optim, args.steps)

    model.train()
    print(f"Training the model for {args.steps} iterations")
    for n in range(args.steps):
        batch, _ = next(data_loader)
        batch = torch.tensor(batch)
        input_ids = batch[:, :, 0]
        labels = batch[:, :, 1]
        loss = model(input_ids, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        print(f"Train step {n}; Loss {loss.item()}")

Training the model for 500 iterations
--Return--
None
> [32m/tmp/ipykernel_835939/3523260694.py[39m([92m6[39m)[36mb[39m[34m()[39m
[32m      4[39m     [38;5;28;01mexcept[39;00m:
[32m      5[39m         [38;5;28;01mfrom[39;00m pdb [38;5;28;01mimport[39;00m set_trace
[32m----> 6[39m     set_trace()



ipdb>  n


> [32m/tmp/ipykernel_835939/3187652246.py[39m([92m43[39m)[36mforward[39m[34m()[39m
[32m     42[39m         b()
[32m---> 43[39m         [38;5;28;01mif[39;00m target [38;5;28;01mis[39;00m [38;5;28;01mnot[39;00m [38;5;28;01mNone[39;00m:
[32m     44[39m             loss = per_token_cross_entropy(logits, target)



ipdb>  target is None


False


ipdb>  l1 = cross_entropy(logits, target)
ipdb>  l2 = per_token_cross_entropy(logits, target)
ipdb>  l1.shape


torch.Size([])


ipdb>  l1


tensor(12.6476, grad_fn=<NllLossBackward0>)


ipdb>  l2.shape


torch.Size([512])


ipdb>  l2


tensor([13.4736, 12.9166, 12.3519, 12.5545, 12.1786, 12.3585, 12.6308, 12.1326,
        13.0769, 12.6334, 12.7345, 12.1455, 12.5079, 12.4826, 13.2066, 13.3506,
        12.6688, 12.0890, 11.9188, 12.3264, 12.9263, 13.2898, 12.4452, 12.8935,
        12.2390, 12.4342, 13.4504, 12.6023, 11.3511, 13.3773, 13.7091, 12.4095,
        13.2928, 12.7126, 11.4589, 13.6477, 12.5901, 12.9323, 12.1094, 11.8345,
        12.0005, 12.8073, 12.3614, 12.3090, 12.6927, 12.6569, 12.3294, 12.9176,
        13.7227, 12.9877, 12.7442, 12.1720, 12.5801, 12.8247, 12.2506, 12.5086,
        12.8647, 12.6438, 12.6678, 13.0545, 13.0623, 12.2964, 13.5722, 12.4683,
        13.3576, 12.5081, 13.1599, 12.3724, 13.3186, 12.3439, 12.8573, 12.8510,
        11.7582, 11.9754, 12.5636, 12.2200, 12.5077, 12.5504, 13.0285, 12.6178,
        13.2060, 13.4127, 12.9866, 12.8926, 13.1230, 12.7973, 13.2263, 12.8165,
        12.8925, 12.3842, 12.6197, 11.7091, 14.1016, 11.7366, 13.2627, 13.0226,
        12.6179, 13.2870, 12.1634, 12.06

In [13]:
reference_model = train_model(args)

--Return--
None
> [32m/tmp/ipykernel_834270/2138924361.py[39m([92m7[39m)[36mb[39m[34m()[39m
[32m      5[39m         [38;5;28;01mfrom[39;00m pdb [38;5;28;01mimport[39;00m set_trace
[32m      6[39m 
[32m----> 7[39m     set_trace()



ipdb>  c


Training the model for 500 iterations
Train step 1; Loss 12.68013858795166
Train step 499; Loss 10.261428833007812


In [18]:
model = LMTransformer(args.model)
model.weights

AttributeError: 'LMTransformer' object has no attribute 'weights'