# SaRDinE Training

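**Self-contained notebook:** installs dependencies, clones and patches the SRDE repo, writes `train.py`, and launches multi-GPU training with `accelerate`. Requires `WANDB_API_KEY` and `HF_TOKEN`.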

In [None]:
# === API KEYS ===
import os
os.environ['WANDB_API_KEY'] = 'abe32d5463fb2265eaea4563a571c07e5a39b7b6'
os.environ['HF_TOKEN'] = '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'

!pip install -q transformers==5.0.0rc1 torch datasets accelerate wandb huggingface_hub bitsandbytes tqdm

In [None]:
import os
if not os.path.exists('srde.py'):
    print("Cloning SRDE repo...")
    !git clone https://github.com/MinimaML/srde-mistral
    %cd srde-mistral/mistral-3-14b-reasoning
else:
    print("SRDE found.")

# Load srde.py into memory; the patch steps that follow rewrite this string,
# and it is written back to disk at the end of the cell.
print("Applying patches...")
with open('srde.py') as src_file:
    content = src_file.read()

# ===== PATCH 1: Multi-GPU lm_head device fix =====
# srde.py creates lm_head on the device of the language model's first parameter.
# This replacement derives the device from `device_map` instead, then points the
# weight copy and the log line at the new `target_device` variable.
# NOTE(review): assumes `device_map` is in scope at that point in srde.py - TODO confirm.
old_device = '''                # Get device from language model
                device = next(language_model.parameters()).device
                
                lm_head = nn.Linear(hidden_size, vocab_size, bias=False, device=device, dtype=torch_dtype)'''

new_device = '''                # Get target device from device_map
                if isinstance(device_map, dict) and '' in device_map:
                    target_device = f"cuda:{device_map['']}" if isinstance(device_map[''], int) else device_map['']
                elif isinstance(device_map, str) and device_map.startswith('cuda'):
                    target_device = device_map
                else:
                    target_device = next(language_model.parameters()).device
                
                print(f"[SRDE] Creating lm_head on: {target_device}")
                lm_head = nn.Linear(hidden_size, vocab_size, bias=False, device=target_device, dtype=torch_dtype)'''

if old_device in content:
    content = content.replace(old_device, new_device)
    content = content.replace(
        'lm_head.weight.copy_(language_model.embed_tokens.weight)',
        'lm_head.weight.copy_(language_model.embed_tokens.weight.to(target_device))'
    )
    content = content.replace(
        'print(f"[SRDE] Initialized lm_head from embed_tokens (trainable copy, device={device})")',
        'print(f"[SRDE] lm_head initialized on {target_device}")'
    )
    print("  ✅ Multi-GPU lm_head fix")
elif new_device in content:
    print("  ✅ Multi-GPU lm_head fix (already applied)")
else:
    # Report instead of skipping silently: an upstream change to srde.py would
    # otherwise leave training running unpatched without any hint.
    print("  ⚠️ Multi-GPU lm_head fix NOT applied: anchor text not found in srde.py")

# ===== PATCH 2: Fix aux loss computation (root cause of 'Error computing aux loss') =====
# Replaces the multi-dim mean over router probabilities with an explicit
# flatten-to-(tokens, experts) followed by a mean over tokens.
old_aux = '''            router_probs = F.softmax(router_logits, dim=-1)
            avg_probs = router_probs.mean(dim=list(range(router_probs.dim() - 1)))'''

# `reshape` (not `view`) so a non-contiguous router_logits tensor does not raise.
new_aux = '''            # router_logits shape: [batch, seq, num_experts] or [batch*seq, num_experts]
            router_probs = F.softmax(router_logits, dim=-1)
            
            # Flatten all dims except last (expert dim) and compute mean
            if router_probs.dim() > 1:
                flat_probs = router_probs.reshape(-1, router_probs.size(-1))
                avg_probs = flat_probs.mean(dim=0)
            else:
                avg_probs = router_probs'''

if old_aux in content:
    content = content.replace(old_aux, new_aux)
    print("  ✅ Aux loss computation fix")
elif 'flat_probs = router_probs' in content:
    # Matches both this version and an earlier `.view`-based application.
    print("  ✅ Aux loss computation fix (already applied)")
else:
    print("  ⚠️ Aux loss computation fix NOT applied: anchor text not found in srde.py")

# ===== PATCH 3: Spam fix =====
# Logs the aux-loss failure once per module instance (guarded by `_aux_warned`)
# instead of on every call, and tolerates router_logits being None in the fallback.
old_spam = '''        except Exception as e:
            logger.warning(f"Error computing aux loss: {e}")
            return torch.tensor(0.0, device=router_logits.device)'''

new_spam = '''        except Exception as e:
            if not hasattr(self, '_aux_warned'):
                logger.warning(f"Aux loss skipped: {type(e).__name__}: {e}")
                self._aux_warned = True
            return torch.tensor(0.0, device=router_logits.device if router_logits is not None else 'cpu', requires_grad=True)'''

if old_spam in content:
    content = content.replace(old_spam, new_spam)
    print("  ✅ Aux loss spam fix")
elif new_spam in content:
    print("  ✅ Aux loss spam fix (already applied)")
else:
    print("  ⚠️ Aux loss spam fix NOT applied: anchor text not found in srde.py")

# ===== PATCH 4: Null check =====
# Makes _compute_aux_loss return a zero loss up front when router_logits is None.
old_null = '''    def _compute_aux_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
        """Compute load balancing auxiliary loss."""
        try:'''

new_null = '''    def _compute_aux_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
        """Compute load balancing auxiliary loss."""
        if router_logits is None:
            return torch.tensor(0.0, device='cuda' if torch.cuda.is_available() else 'cpu', requires_grad=True)
        try:'''

if 'if router_logits is None:' in content:
    # Already present (earlier run or upstream); inserting it again would duplicate it.
    print("  ✅ Null check (already present)")
elif old_null in content:
    content = content.replace(old_null, new_null)
    print("  ✅ Null check added")
else:
    print("  ⚠️ Null check NOT added: anchor text not found in srde.py")

# Persist the (possibly) patched source back over srde.py.
with open('srde.py', 'w') as dst_file:
    dst_file.write(content)
print("Done!")

In [None]:
# Count NVIDIA GPUs by listing them with `nvidia-smi -L` (run once, output reused).
# NOTE: nvidia-smi lists every GPU on the host and ignores CUDA_VISIBLE_DEVICES.
import os
import subprocess

try:
    result = subprocess.run(['nvidia-smi', '-L'], capture_output=True, text=True)
    gpu_lines = [l for l in result.stdout.strip().split('\n') if l.startswith('GPU')]
except FileNotFoundError:  # no NVIDIA driver / nvidia-smi not on PATH
    gpu_lines = []

print('\n'.join(gpu_lines))
NUM_GPUS = len(gpu_lines)
print(f"\nDetected {NUM_GPUS} GPUs")

# Clear a stop request left over from a previous run so training does not exit immediately.
if os.path.exists('STOP_TRAINING'):
    os.remove('STOP_TRAINING')

In [None]:
# === WRITE TRAINING SCRIPT ===

TRAIN_SCRIPT = r'''
#!/usr/bin/env python3
"""SaRDinE Multi-GPU Training"""
import os, gc, random, time
from pathlib import Path

import torch
import torch.distributed as dist
import wandb
import bitsandbytes as bnb
from datasets import load_dataset, interleave_datasets
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from huggingface_hub import login
from accelerate import Accelerator

from config import SRDEConfig
from srde import create_srde_model
from muon import Muon

if os.environ.get('HF_TOKEN'):
    login(token=os.environ['HF_TOKEN'])

MODEL_NAME = 'mistralai/Ministral-3-14B-Reasoning-2512'
CONFIG = {
    'wandb_project': 'sardine-collab', 'model_name': MODEL_NAME, 'max_length': 2048,
    'target_tokens': 20_000_000_000, 'max_steps': 1000000,
    'lr': 1e-4, 'muon_lr': 0.02, 'warmup_steps': 1000, 'save_steps': 5000,
    'batch_size': 8, 'grad_accum': 4, 'checkpoint_dir': './checkpoints',
}

DATA_SOURCES = [
    ('nvidia/OpenMathInstruct-2', None, 0.25),
    ('meta-math/MetaMathQA', None, 0.15),
    ('ise-uiuc/Magicoder-OSS-Instruct-75K', None, 0.20),
    ('Rowan/hellaswag', None, 0.15),
    ('deepmind/aqua_rat', 'raw', 0.10),
    ('winogrande', 'winogrande_xl', 0.15),
]

def format_sample(sample):
    """Flatten one dataset row into {'text': ...} for causal-LM training.

    Uses the first non-empty prompt-like field, then appends (newline-separated)
    the first non-empty answer-like field, if any. Rows with no recognised
    prompt field fall back to their dict repr, truncated to 2000 characters.
    """
    # 'query' covers MetaMathQA (query/response); without it those rows were
    # trained on as raw dict reprs.
    prompt_keys = ('text', 'content', 'problem', 'question', 'query', 'sentence', 'ctx')
    answer_keys = ('answer', 'solution', 'response', 'generated_solution', 'rationale')
    for key in prompt_keys:
        if sample.get(key):
            text = str(sample[key])
            for ans_key in answer_keys:
                if sample.get(ans_key):
                    text += f"\n{sample[ans_key]}"
                    break
            return {'text': text}
    return {'text': str(sample)[:2000]}

def get_loss(out):
    """Extract the loss from a model output.

    Dict-like outputs are read via .get('loss'); anything else via its `loss`
    attribute. Returns None when no loss is present.
    """
    if not isinstance(out, dict):
        return getattr(out, 'loss', None)
    return out.get('loss')

def get_params_for_muon(p):
    """Split parameters into (matrix-like, ndim >= 2) and (vector/scalar, ndim < 2) lists.

    The first list goes to Muon, the second to the fallback optimizer; input order is kept.
    """
    def _is_matrix(param):
        return param.ndim >= 2

    matrix_params = list(filter(_is_matrix, p))
    other_params = [param for param in p if not _is_matrix(param)]
    return matrix_params, other_params

def cleanup():
    """Release Python and cached CUDA memory, then destroy the process group if one exists."""
    gc.collect()
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.empty_cache()
    if dist.is_initialized():
        dist.destroy_process_group()

def _build_dataloader(tok, local_rank):
    """Stream, format and interleave DATA_SOURCES into a tokenizing DataLoader.

    Sources that fail to load are reported on rank 0 and skipped; the sampling
    weights of the remaining sources are renormalised to sum to 1.
    Returns None when no source could be loaded.
    """
    datasets, probs = [], []
    for name, config, prob in DATA_SOURCES:
        try:
            ds = load_dataset(name, config, split='train', streaming=True, token=True)
            ds = ds.map(format_sample, remove_columns=ds.column_names)
            datasets.append(ds)
            probs.append(prob)
            if local_rank == 0: print(f'  ✓ {name}')
        except Exception as e:
            if local_rank == 0: print(f'  ✗ {name}: {e}')

    if not datasets:
        return None

    total = sum(probs)
    probs = [p / total for p in probs]
    combined = interleave_datasets(datasets, probabilities=probs, stopping_strategy='all_exhausted')

    def collate(batch):
        # Cap characters before tokenizing, then pad/truncate to exactly max_length tokens.
        texts = [b['text'][:8192] for b in batch]
        e = tok(texts, truncation=True, max_length=CONFIG['max_length'], padding='max_length', return_tensors='pt')
        return {'input_ids': e['input_ids'], 'attention_mask': e['attention_mask']}

    return DataLoader(combined, batch_size=CONFIG['batch_size'], collate_fn=collate)


def _build_optimizers(params):
    """Create Muon for >=2-D params and 8-bit AdamW for the rest, each with a warmup+cosine schedule.

    Returns (optimizers, schedulers) as parallel lists; empty parameter groups are skipped.
    """
    muon_p, adam_p = get_params_for_muon(params)
    opts = []
    if muon_p:
        opts.append(Muon(muon_p, lr=CONFIG['muon_lr'], momentum=0.95))
    if adam_p:
        opts.append(bnb.optim.AdamW8bit(adam_p, lr=CONFIG['lr']))
    scheds = [get_cosine_schedule_with_warmup(o, CONFIG['warmup_steps'], CONFIG['max_steps']) for o in opts]
    return opts, scheds


def _save_srde_checkpoint(acc, model, ckpt_dir):
    """Write the unwrapped model's SRDE weights to <ckpt_dir>/weights.pt (call on the main process only)."""
    ckpt_dir = Path(ckpt_dir)
    ckpt_dir.mkdir(exist_ok=True, parents=True)
    acc.unwrap_model(model).save_srde_weights(str(ckpt_dir / 'weights.pt'))


def main():
    """Multi-GPU SRDE training entry point (run under `accelerate launch`).

    Builds the tokenizer, streamed data and SRDE model, where only parameters whose
    names contain srde_layers/expert/router/vocabulary are trainable. It then trains
    until CONFIG['target_tokens'] is reached (counted as padded tokens across all
    ranks), the data runs out, or a STOP_TRAINING file appears. SRDE weights are
    saved every CONFIG['save_steps'] micro-batches and once more at the end.
    """
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if torch.cuda.is_available(): torch.cuda.set_device(local_rank)

    print(f'[R{local_rank}/{world_size}] Starting')

    acc = Accelerator(mixed_precision='bf16', gradient_accumulation_steps=CONFIG['grad_accum'])
    is_main = acc.is_main_process

    if is_main:
        wandb.init(project=CONFIG['wandb_project'], name=f'sardine-{world_size}gpu', config=CONFIG)

    tok = AutoTokenizer.from_pretrained(CONFIG['model_name'], trust_remote_code=True)
    tok.pad_token = tok.eos_token

    print(f'[R{local_rank}] Loading datasets...')
    dl = _build_dataloader(tok, local_rank)
    if dl is None:
        print('No datasets!')
        # Close the wandb run and process group instead of leaking them on this early exit.
        if is_main:
            wandb.finish()
        cleanup()
        return

    print(f'[R{local_rank}] Loading model...')
    model = create_srde_model(CONFIG['model_name'], torch_dtype=torch.bfloat16, trust_remote_code=True, device_map={'': local_rank})

    tc = 0
    for n, p in model.named_parameters():
        if 'srde_layers' in n or 'expert' in n or 'router' in n or 'vocabulary' in n:
            p.requires_grad = True
            tc += p.numel()
        else:
            p.requires_grad = False
    print(f'[R{local_rank}] Trainable: {tc/1e6:.1f}M')

    if hasattr(model, 'base_model') and hasattr(model.base_model, 'gradient_checkpointing_enable'):
        model.base_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})

    opts, scheds = _build_optimizers([p for p in model.parameters() if p.requires_grad])

    model, dl, *rest = acc.prepare(model, dl, *opts, *scheds)
    n = len(opts)
    opts, scheds = list(rest[:n]), list(rest[n:])

    print(f'[R{local_rank}] Ready!')

    Path(CONFIG['checkpoint_dir']).mkdir(exist_ok=True, parents=True)
    step, total_tokens = 0, 0
    tpb = CONFIG['batch_size'] * CONFIG['max_length'] * world_size
    pbar = tqdm(total=CONFIG['target_tokens'], unit='tok', unit_scale=True, disable=not is_main)
    model.train()

    try:
        for batch in dl:
            # A per-rank file check lets ranks break on different steps, leaving the
            # rest stuck in the gradient all-reduce. Any rank that sees the file raises
            # the trigger, and check_trigger() reduces it so all ranks stop together.
            if os.path.exists('STOP_TRAINING'):
                acc.set_trigger()
            if acc.check_trigger() or total_tokens >= CONFIG['target_tokens']:
                break

            ids = batch['input_ids']
            mask = batch['attention_mask']
            # Padding (pad == eos here) must not contribute to the loss; -100 is the
            # ignore_index of PyTorch cross-entropy and of Hugging Face LM losses.
            labels = ids.masked_fill(mask == 0, -100)

            with acc.accumulate(model):
                with acc.autocast():
                    out = model(ids, attention_mask=mask, labels=labels)
                    loss = get_loss(out)
                if loss is not None:
                    acc.backward(loss)
                    for o in opts:
                        o.step()
                        o.zero_grad()
                    for s in scheds:
                        s.step()

            step += 1
            total_tokens += tpb

            if step % 10 == 0 and is_main:
                lv = loss.item() if loss is not None else 0
                mem = torch.cuda.memory_allocated(local_rank) / 1e9
                wandb.log({'loss': lv, 'step': step, 'tokens_B': total_tokens/1e9, 'gpu_gb': mem})
                pbar.set_postfix({'loss': f'{lv:.4f}', 'mem': f'{mem:.1f}GB'})
                pbar.update(tpb * 10)

            if step % CONFIG['save_steps'] == 0:
                # Barrier must be reached by EVERY rank; calling it only on main deadlocks.
                acc.wait_for_everyone()
                if is_main:
                    ckpt = Path(CONFIG['checkpoint_dir']) / f'ckpt-{step}'
                    _save_srde_checkpoint(acc, model, ckpt)
                    print(f'\nSaved {ckpt}')
    except Exception as e:
        print(f'[R{local_rank}] Error: {e}')
        import traceback
        traceback.print_exc()
    finally:
        acc.wait_for_everyone()
        if is_main:
            if step > 0:
                final = Path(CONFIG['checkpoint_dir']) / f'final-{int(total_tokens/1e9)}B'
                _save_srde_checkpoint(acc, model, final)
            wandb.finish()
        pbar.close()
        cleanup()

if __name__ == '__main__':
    main()
'''

with open('train.py', 'w') as f:
    f.write(TRAIN_SCRIPT)
print("âœ… train.py written (batch_size=8)")

In [None]:
# === LAUNCH TRAINING ===
import subprocess
result = subprocess.run(['nvidia-smi', '-L'], capture_output=True, text=True)
NUM_GPUS = len([l for l in result.stdout.strip().split('\n') if l.startswith('GPU')])

print(f"ðŸš€ Launching on {NUM_GPUS} GPUs...")
!accelerate launch --num_processes={NUM_GPUS} --mixed_precision=bf16 train.py

In [None]:
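# === STOP TRAINING ===
# Graceful stop: train.py checks for a STOP_TRAINING file before each batch, leaves the
# training loop and then saves a final checkpoint. The launch cell blocks the kernel
# while training runs, so create the file from a terminal in this working directory
# (`touch STOP_TRAINING`). Running this cell only takes effect once the kernel is free.
# !touch STOP_TRAINING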