# SaRDinE Training

**Fully self-contained notebook - writes and launches training script**

In [None]:
!pip install -q transformers==5.0.0rc1 torch datasets accelerate wandb huggingface_hub bitsandbytes tqdm

In [None]:
import os
if not os.path.exists('srde.py'):
    print("Cloning SRDE repo...")
    !git clone https://github.com/MinimaML/srde-mistral
    %cd srde-mistral/mistral-3-14b-reasoning
else:
    print("SRDE found.")

In [None]:
# Check GPUs
!nvidia-smi -L
import subprocess
result = subprocess.run(['nvidia-smi', '-L'], capture_output=True, text=True)
NUM_GPUS = len([l for l in result.stdout.strip().split('\n') if l.startswith('GPU')])
print(f"\nDetected {NUM_GPUS} GPUs")
# Clear any existing stop file
!rm -f STOP_TRAINING

In [None]:
# === WRITE TRAINING SCRIPT ===

TRAIN_SCRIPT = r'''
#!/usr/bin/env python3
"""SaRDinE Multi-GPU Training Script

Launch with: accelerate launch --num_processes=N --mixed_precision=bf16 train.py
"""
import os, gc, random, threading, time
from collections import deque
from pathlib import Path

import torch
import torch.distributed as dist
import wandb
import bitsandbytes as bnb
from datasets import load_dataset
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
from torch.utils.data import DataLoader, IterableDataset
from tqdm.auto import tqdm
from huggingface_hub import login
from accelerate import Accelerator

from config import SRDEConfig
from srde import create_srde_model
from muon import Muon

# Credentials are taken from the environment instead of being hard-coded in the script.
# SECURITY: the W&B key and Hugging Face token that used to be committed here must be
# treated as compromised and rotated.
if not os.environ.get('WANDB_API_KEY'):
    print('WARNING: WANDB_API_KEY is not set; wandb may prompt for login or fail.')
_HF_TOKEN = os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN')
if _HF_TOKEN:
    login(token=_HF_TOKEN)
else:
    print('WARNING: HF_TOKEN is not set; gated models and datasets may be inaccessible.')

# Hugging Face model id; used for both the tokenizer and create_srde_model().
MODEL_NAME = 'mistralai/Ministral-3-14B-Reasoning-2512'
# NOTE(review): SRDE_CFG is instantiated but not referenced anywhere in this script.
SRDE_CFG = SRDEConfig()
# Run settings read by main():
#   max_length     - tokenizer truncation/padding length per sample
#   target_tokens  - training stops once the running token count reaches this value
#   buffer_tokens  - per-domain buffer size handed to BufferedStreamDataset
#   max_steps      - total steps given to the cosine schedulers
#   lr / muon_lr   - learning rates for AdamW8bit and Muon respectively
#   save_steps     - checkpoint interval, in loop iterations
#   batch_size     - per-process DataLoader batch size
#   grad_accum     - gradient_accumulation_steps passed to Accelerator
CONFIG = {
    'wandb_project': 'sardine-collab', 'model_name': MODEL_NAME, 'max_length': 2048,
    'target_tokens': 20_000_000_000, 'buffer_tokens': 1_000_000, 'max_steps': 1000000,
    'lr': 1e-4, 'muon_lr': 0.02, 'warmup_steps': 1000, 'save_steps': 5000,
    'batch_size': 8, 'grad_accum': 4, 'checkpoint_dir': './checkpoints',
}

# Per-domain sampling weights used by BufferedStreamDataset (they sum to 1.0).
# NOTE(review): the 'expert' index is not read in this script; presumably it maps a
# domain to an SRDE expert - confirm against srde.py.
DOMAINS = {
    'math': {'expert': 0, 'weight': 0.30}, 'code': {'expert': 1, 'weight': 0.30},
    'science': {'expert': 2, 'weight': 0.15}, 'logic': {'expert': 3, 'weight': 0.10},
    'planning': {'expert': 4, 'weight': 0.08}, 'abstract': {'expert': 5, 'weight': 0.07},
}

# Per-domain list of (dataset id, config name or None) pairs; each pair is passed to
# load_dataset(..., split='train', streaming=True) by BufferedDomainStream.
DATA_SOURCES = {
    'math': [('nvidia/OpenMathInstruct-2', None), ('meta-math/MetaMathQA', None)],
    'code': [('ise-uiuc/Magicoder-OSS-Instruct-75K', None), ('bigcode/starcoderdata', None)],
    'science': [('HuggingFaceFW/fineweb-edu', 'sample-10BT')],
    'logic': [('Rowan/hellaswag', None)],
    'planning': [('hotpot_qa', 'fullwiki')],
    'abstract': [('deepmind/aqua_rat', 'raw'), ('winogrande', 'winogrande_xl')],
}

def format_sample(sample, source_name):
    """Render one raw dataset record as a training text string.

    The template is chosen by substring match on ``source_name`` (the dataset id).
    Unknown sources fall back to the record's ``text`` field, or to the first
    1000 characters of ``str(sample)`` when that field is missing or not a string.

    Args:
        sample: Mapping-like dataset record (must support ``.get``).
        source_name: Dataset identifier the record came from.

    Returns:
        A ``str``. An empty string is returned when formatting fails, and for the
        single-field sources when the field is missing or null, so callers can
        safely take ``len()`` of the result.
    """
    try:
        if 'OpenMath' in source_name:
            return f"Problem: {sample.get('problem', '')}\nSolution: {sample.get('generated_solution', '')}"
        if 'MetaMath' in source_name:
            return f"Q: {sample.get('query', '')}\nA: {sample.get('response', '')}"
        if 'Magicoder' in source_name:
            return f"### Instruction:\n{sample.get('problem', '')}\n### Solution:\n{sample.get('solution', '')}"
        if 'starcoder' in source_name:
            # "or ''" also covers a present-but-null field, which .get's default does not
            return sample.get('content') or ''
        if 'fineweb' in source_name:
            return sample.get('text') or ''
        if 'hellaswag' in source_name:
            return f"{sample.get('ctx', '')} {sample.get('endings', [''])[int(sample.get('label', 0))]}"
        if 'hotpot' in source_name:
            return f"Q: {sample.get('question', '')}\nA: {sample.get('answer', '')}"
        if 'aqua' in source_name:
            return f"Q: {sample.get('question', '')}\nA: {sample.get('rationale', '')}"
        if 'winogrande' in source_name:
            return sample.get('sentence') or ''
        text = sample.get('text')
        if isinstance(text, str):
            return text
        return str(sample)[:1000]
    except Exception:
        # Best-effort: a malformed record is skipped by the caller (empty text is
        # shorter than its minimum length). Unlike a bare except, this no longer
        # swallows KeyboardInterrupt / SystemExit.
        return ''

class BufferedDomainStream:
    """Background producer that keeps a token-bounded FIFO of formatted samples for one domain.

    A daemon thread streams every (dataset, config) pair in ``sources`` in order,
    formats and token-counts each record, and appends it to ``buffer``. When the
    buffer reaches ``_HIGH_WATER * buffer_tokens`` the producer pauses *in place*
    (keeping its position in the dataset iterator) until the consumer has drained
    it below ``buffer_tokens``. Once all sources are exhausted they are streamed
    again from the start.

    Args:
        domain: Domain label stored on every buffered sample.
        sources: List of ``(dataset_id, config_name_or_None)`` pairs.
        tokenizer: Object with ``encode(text, add_special_tokens=False)``; only the
            length of its result is used.
        buffer_tokens: Target buffer size in tokens.
    """

    # Seconds between checks while the producer is paused or idle.
    _POLL_SECONDS = 0.1
    # The producer fills up to this multiple of buffer_tokens before pausing.
    _HIGH_WATER = 1.2

    def __init__(self, domain, sources, tokenizer, buffer_tokens=1_000_000):
        self.domain, self.sources, self.tokenizer = domain, sources, tokenizer
        self.buffer_tokens, self.buffer, self.buffer_token_count = buffer_tokens, deque(), 0
        self.lock, self.stop_event, self.thread = threading.Lock(), threading.Event(), None

    def start(self):
        """Start the background fill thread."""
        self.thread = threading.Thread(target=self._fill_loop, daemon=True)
        self.thread.start()

    def stop(self):
        """Signal the fill thread to stop and wait up to one second for it to exit."""
        self.stop_event.set()
        if self.thread: self.thread.join(timeout=1)

    def _wait_for_room(self):
        """Block (without holding the lock) until the buffer drops below buffer_tokens or stop is set."""
        while not self.stop_event.is_set() and self.buffer_status() >= self.buffer_tokens:
            self.stop_event.wait(self._POLL_SECONDS)

    def _fill_loop(self):
        """Producer loop run by the background thread; returns once stop_event is set."""
        while not self.stop_event.is_set():
            produced = False
            for src, sub in self.sources:
                if self.stop_event.is_set(): break
                try:
                    ds = load_dataset(src, sub, split='train', streaming=True, token=True)
                    for s in ds:
                        if self.stop_event.is_set(): break
                        txt = format_sample(s, src)
                        if len(txt) < 50: continue
                        toks = len(self.tokenizer.encode(txt, add_special_tokens=False))
                        with self.lock:
                            self.buffer.append({'text': txt, 'domain': self.domain, 'tokens': toks})
                            self.buffer_token_count += toks
                            full = self.buffer_token_count >= self.buffer_tokens * self._HIGH_WATER
                        produced = True
                        if full:
                            # Pause here instead of abandoning the iterator: breaking out
                            # would restart the stream from its first record on every refill.
                            self._wait_for_room()
                except Exception as e:
                    print(f'[{self.domain}] Stream error {src}: {e}')
                    # Back off before the next source, but wake immediately on stop.
                    self.stop_event.wait(2)
            if not produced:
                # Nothing was buffered in this pass (no sources, or all failed/empty):
                # avoid a busy spin.
                self.stop_event.wait(self._POLL_SECONDS)

    def get_sample(self):
        """Pop and return the oldest buffered sample dict, or None when the buffer is empty."""
        with self.lock:
            if self.buffer:
                s = self.buffer.popleft()
                self.buffer_token_count -= s['tokens']
                return s
        return None

    def buffer_status(self):
        """Return the number of tokens currently buffered."""
        with self.lock: return self.buffer_token_count

class BufferedStreamDataset(IterableDataset):
    """Endless iterable dataset that mixes per-domain buffered streams by weight.

    One BufferedDomainStream is created per key of ``domains``; iteration draws a
    domain at random according to the configured weights and yields that stream's
    next buffered sample, falling back to any other non-empty stream.
    """

    def __init__(self, domains, tokenizer, buffer_tokens=1_000_000, rank=0):
        self.domains = list(domains.keys())
        self.weights = []
        for name in self.domains:
            self.weights.append(domains[name]['weight'])
        self.tokenizer = tokenizer
        self.buffer_tokens = buffer_tokens
        self.streams = {}
        self.started = False
        self.rank = rank

    def start_streams(self):
        """Launch all domain streams once and block until each is at least half full."""
        if self.started:
            return
        verbose = self.rank == 0  # only rank 0 reports progress
        if verbose:
            print('Starting streams...')
        for name in self.domains:
            stream = BufferedDomainStream(name, DATA_SOURCES.get(name, []), self.tokenizer, self.buffer_tokens)
            self.streams[name] = stream
            stream.start()
        self.started = True
        while True:
            lowest = min(stream.buffer_status() for stream in self.streams.values())
            if verbose:
                print(f'  Buffer: {lowest/1e6:.2f}M/{self.buffer_tokens/1e6:.0f}M', end='\r')
            if lowest >= self.buffer_tokens * 0.5:
                if verbose:
                    print('\nReady!')
                break
            time.sleep(1)

    def stop_streams(self):
        """Ask every domain stream to stop."""
        for stream in self.streams.values():
            stream.stop()

    def __iter__(self):
        """Yield sample dicts forever; the random draw is seeded per rank (42 + rank)."""
        self.start_streams()
        picker = random.Random(42 + self.rank)
        misses = 0
        while True:
            chosen = picker.choices(self.domains, weights=self.weights, k=1)[0]
            sample = self.streams[chosen].get_sample()
            if not sample:
                # Chosen domain is empty: take from the first non-empty domain in order.
                for name in self.domains:
                    sample = self.streams[name].get_sample()
                    if sample:
                        break
            if sample:
                misses = 0
                yield sample
                continue
            # Every buffer was empty; back off briefly after more than 100 consecutive misses.
            misses += 1
            if misses > 100:
                time.sleep(0.1)
                misses = 0

def get_loss(out):
    """Extract the loss from a model output.

    Dict outputs are read via the 'loss' key; any other object via its ``loss``
    attribute. Returns None when no loss is present.
    """
    if isinstance(out, dict):
        return out.get('loss')
    return getattr(out, 'loss', None)

def get_params_for_muon(p):
    """Split parameters into (ndim >= 2, ndim < 2) lists, preserving order.

    main() hands the first list to Muon and the second to AdamW8bit.

    Args:
        p: Iterable of objects exposing ``ndim``. It is traversed exactly once, so
            one-shot iterables (e.g. a generator from ``model.parameters()``) work too.

    Returns:
        Tuple ``(matrix_like, other)`` of two lists.
    """
    matrix_like, other = [], []
    for x in p:
        (matrix_like if x.ndim >= 2 else other).append(x)
    return matrix_like, other

def get_gpu_stats():
    """Report allocated CUDA memory per visible device.

    Returns a dict with ``gpu{i}_gb`` (allocated, in units of 1e9 bytes),
    ``gpu{i}_pct`` (allocated as a percentage of the device's total memory) and
    ``gpu_total_gb`` (sum of allocated memory over all devices). The dict is empty
    when CUDA is unavailable.
    """
    stats = {}
    if not torch.cuda.is_available():
        return stats
    for idx in range(torch.cuda.device_count()):
        used_gb = torch.cuda.memory_allocated(idx) / 1e9
        capacity_gb = torch.cuda.get_device_properties(idx).total_memory / 1e9
        stats[f'gpu{idx}_gb'] = used_gb
        stats[f'gpu{idx}_pct'] = (used_gb / capacity_gb) * 100
    stats['gpu_total_gb'] = sum(torch.cuda.memory_allocated(idx) for idx in range(torch.cuda.device_count())) / 1e9
    return stats

def cleanup():
    """Release resources at shutdown.

    Runs the garbage collector, empties the CUDA caching allocator when CUDA is
    available, and tears down the torch.distributed process group if one exists.
    """
    gc.collect()
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
    group_active = dist.is_initialized()
    if group_active:
        dist.destroy_process_group()

def main():
    """Train the SRDE parameters of the model on this process's GPU under Accelerate.

    Steps: pin the CUDA device from LOCAL_RANK, build the Accelerator, start the
    buffered multi-domain data streams, load the model onto the local GPU, mark
    only SRDE/expert/router/vocabulary parameters as trainable, create the Muon
    (ndim >= 2) and AdamW8bit (ndim < 2) optimizers with cosine schedules, then
    run the training loop until CONFIG['target_tokens'] is reached or a
    STOP_TRAINING file appears in the working directory.

    Checkpoints are written by the main process every CONFIG['save_steps']
    iterations and once more on exit.
    """
    # Get local rank from environment (set by accelerate/torchrun)
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    # Set CUDA device BEFORE any CUDA operations
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    print(f'[Rank {local_rank}/{world_size}] Starting on cuda:{local_rank}')

    acc = Accelerator(mixed_precision='bf16', gradient_accumulation_steps=CONFIG['grad_accum'])
    is_main, device = acc.is_main_process, acc.device

    print(f'[Rank {local_rank}] Accelerator device: {device}')

    if is_main:
        wandb.init(project=CONFIG['wandb_project'], name=f'sardine-{world_size}gpu', config={**CONFIG, 'world_size': world_size})

    tok = AutoTokenizer.from_pretrained(CONFIG['model_name'], trust_remote_code=True)
    tok.pad_token = tok.eos_token

    ds = BufferedStreamDataset(DOMAINS, tok, CONFIG['buffer_tokens'], rank=local_rank)

    def collate(batch):
        # Pre-trim to 8192 characters before tokenizing; the tokenizer then
        # truncates/pads every sample to exactly max_length tokens.
        e = tok([b['text'][:8192] for b in batch], truncation=True, max_length=CONFIG['max_length'], padding='max_length', return_tensors='pt')
        return e['input_ids'], e['attention_mask']

    dl = DataLoader(ds, batch_size=CONFIG['batch_size'], collate_fn=collate)

    # Load model directly to this rank's GPU
    print(f'[Rank {local_rank}] Loading model to cuda:{local_rank}...')
    model = create_srde_model(
        CONFIG['model_name'],
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map={'': local_rank}  # Load directly to this GPU
    )

    # Freeze everything except parameters whose name marks them as SRDE components.
    tc = 0
    for n, p in model.named_parameters():
        if 'srde_layers' in n or 'expert' in n or 'router' in n or 'vocabulary' in n:
            p.requires_grad = True
            tc += p.numel()
        else:
            p.requires_grad = False
    print(f'[Rank {local_rank}] Trainable: {tc/1e6:.1f}M')

    if hasattr(model, 'base_model') and hasattr(model.base_model, 'gradient_checkpointing_enable'):
        model.base_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})

    all_p = [p for p in model.parameters() if p.requires_grad]
    muon_p, adam_p = get_params_for_muon(all_p)
    print(f'[Rank {local_rank}] Muon: {sum(p.numel() for p in muon_p)/1e6:.1f}M | Adam: {sum(p.numel() for p in adam_p)/1e6:.1f}M')

    opts, scheds = [], []
    if muon_p:
        o = Muon(muon_p, lr=CONFIG['muon_lr'], momentum=0.95)
        opts.append(o)
        scheds.append(get_cosine_schedule_with_warmup(o, CONFIG['warmup_steps'], CONFIG['max_steps']))
    if adam_p:
        o = bnb.optim.AdamW8bit(adam_p, lr=CONFIG['lr'])
        opts.append(o)
        scheds.append(get_cosine_schedule_with_warmup(o, CONFIG['warmup_steps'], CONFIG['max_steps']))

    # Prepare with accelerate - wraps with DDP for gradient sync
    model, dl, *rest = acc.prepare(model, dl, *opts, *scheds)
    n = len(opts)
    opts, scheds = list(rest[:n]), list(rest[n:])

    print(f'[Rank {local_rank}] Post-prepare GPU: {get_gpu_stats()}')

    Path(CONFIG['checkpoint_dir']).mkdir(exist_ok=True, parents=True)
    step, total_tokens = 0, 0
    # Tokens consumed per loop iteration across all processes (padding included,
    # because every sample is padded to max_length).
    tpb = CONFIG['batch_size'] * CONFIG['max_length'] * world_size

    pbar = tqdm(total=CONFIG['target_tokens'], unit='tok', unit_scale=True, disable=not is_main)

    model.train()
    print(f'[Rank {local_rank}] Starting training loop...')

    try:
        for ids, mask in dl:
            if total_tokens >= CONFIG['target_tokens']:
                print(f'\nReached {total_tokens/1e9:.1f}B!')
                break
            # NOTE(review): each rank polls the stop file independently, so ranks may
            # leave the loop on different iterations - confirm this is acceptable.
            if os.path.exists('STOP_TRAINING'):
                print('Stop signal received')
                break

            # pad_token is eos_token, so padding cannot be recognised by token id;
            # use the attention mask to exclude padded positions from the loss.
            # -100 is the default ignore_index of torch cross-entropy; assumes the
            # SRDE model's loss follows that convention - TODO confirm.
            labels = ids.masked_fill(mask == 0, -100)

            with acc.accumulate(model):
                with acc.autocast():
                    out = model(ids, attention_mask=mask, labels=labels)
                    loss = get_loss(out)
                if loss is not None:
                    acc.backward(loss)
                    for o in opts:
                        o.step()
                        o.zero_grad()
                    for s in scheds:
                        s.step()

            step += 1
            total_tokens += tpb

            if step % 10 == 0 and is_main:
                lr = scheds[0].get_last_lr()[0] if scheds else 0
                lv = loss.item() if loss is not None else 0
                buf = sum(s.buffer_status() for s in ds.streams.values()) / 1e6 if ds.streams else 0
                gs = get_gpu_stats()
                wandb.log({'loss': lv, 'step': step, 'lr': lr, 'tokens_B': total_tokens/1e9, 'buffer_M': buf, **gs})
                pbar.set_postfix({'loss': f'{lv:.4f}', 'gpu': f"{gs.get('gpu_total_gb', 0):.1f}GB"})
                pbar.update(tpb * 10)

            if step % CONFIG['save_steps'] == 0:
                # The barrier has to be entered by EVERY rank; calling it only on the
                # main process leaves it waiting forever for the others.
                acc.wait_for_everyone()
                if is_main:
                    ckpt = Path(CONFIG['checkpoint_dir']) / f'ckpt-{step}'
                    ckpt.mkdir(exist_ok=True, parents=True)
                    acc.unwrap_model(model).save_srde_weights(str(ckpt / 'weights.pt'))
                    print(f'\nSaved: {ckpt}')

    except KeyboardInterrupt:
        print('\nInterrupted')
    finally:
        ds.stop_streams()
        pbar.close()
        if is_main:
            if step > 0:
                # No barrier here: other ranks may have exited through an exception and
                # would never reach it. The main process saves its own replica.
                final = Path(CONFIG['checkpoint_dir']) / f'final-{int(total_tokens/1e9)}B'
                final.mkdir(exist_ok=True, parents=True)
                acc.unwrap_model(model).save_srde_weights(str(final / 'weights.pt'))
            # Close the wandb run even when no training step completed.
            wandb.finish()
        cleanup()

if __name__ == '__main__':
    main()
'''

# Write as UTF-8 explicitly so the generated script does not depend on the platform's default encoding.
with open('train.py', 'w', encoding='utf-8') as f:
    f.write(TRAIN_SCRIPT)
print("✅ train.py written (batch_size=8)")

In [None]:
# === LAUNCH TRAINING ===

import subprocess
result = subprocess.run(['nvidia-smi', '-L'], capture_output=True, text=True)
NUM_GPUS = len([l for l in result.stdout.strip().split('\n') if l.startswith('GPU')])

print(f"ðŸš€ Launching on {NUM_GPUS} GPUs (batch_size=8)...")
!accelerate launch --num_processes={NUM_GPUS} --mixed_precision=bf16 train.py

In [None]:
# === STOP TRAINING (uncomment and run when you want to stop) ===
# !touch STOP_TRAINING
# print("Stop signal sent.")