# MobileLLM Scaling Laws Study

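**Designed to run on Google Colab with a T4 GPU.** This notebook evaluates four MobileLLM checkpoints (125M–1B parameters) on WikiText-2 perplexity and generation speed. It then fits a power law $\mathrm{PPL} = A \cdot N^{b}$ to perplexity versus parameter count $N$.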

## 1. Setup

In [None]:
# Install the libraries used below (Colab cell; unpinned — consider pinning versions for reproducibility).
!pip install -q transformers datasets accelerate sentencepiece
!pip install -q scipy scikit-learn matplotlib seaborn pandas tqdm

In [None]:
import torch

# Report which accelerator is attached; the evaluation cells below call torch.cuda APIs directly.
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU!")

In [None]:
from huggingface_hub import login
# Interactive Hugging Face Hub login: prompts for a token instead of hard-coding it in the notebook.
# NOTE(review): presumably required because the MobileLLM checkpoints are gated — TODO confirm.
login()

In [None]:
import json
import time
import traceback
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
from scipy import stats
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Output directory (relative to the working directory) for the CSV table, JSON fit, figure and report.
RESULTS_DIR = Path('results')
RESULTS_DIR.mkdir(exist_ok=True)  # idempotent, so re-running the cell is safe

## 2. Config

In [None]:
# Checkpoints to evaluate, keyed by display name (also used as ModelResults.model_name).
#   hf_name: Hugging Face Hub repo id passed to from_pretrained by run_model.
#   params:  nominal parameter count, for reference only — run_model counts the
#            parameters of the loaded model itself and never reads this value.
MODELS = {
    "MobileLLM-125M": {"hf_name": "facebook/MobileLLM-125M", "params": 124.6e6},
    "MobileLLM-350M": {"hf_name": "facebook/MobileLLM-350M", "params": 345.3e6},
    "MobileLLM-600M": {"hf_name": "facebook/MobileLLM-600M", "params": 603.1e6},
    "MobileLLM-1B": {"hf_name": "facebook/MobileLLM-1B", "params": 1.01e9},
}

@dataclass
class ModelResults:
    """Measurements collected for one evaluated model (one row of the results table).

    Field order is also the key order of ``to_dict()``.
    """

    model_name: str            # display name of the checkpoint
    num_params: int            # parameter count of the loaded model
    perplexity: float          # language-modelling perplexity
    tokens_per_second: float   # generation throughput
    peak_memory_mb: float      # peak allocated device memory, in MB
    wall_clock_seconds: float  # total time spent on this model

    def to_dict(self):
        """Return the fields as a plain, ordered dict."""
        as_mapping = asdict(self)
        return as_mapping

## 3. Evaluation Functions

In [None]:
def compute_perplexity(model, tokenizer, max_length=1024, stride=512, device="cuda"):
    """Sliding-window perplexity of ``model`` on the WikiText-2 (raw) test split.

    The whole test split is joined into one token stream and scored with windows
    of ``max_length`` tokens whose starts advance by ``stride``. Inside each window
    only the tokens not already scored by the previous window are targets; the
    rest serve as context. With ``stride < max_length`` every token after the
    first is scored exactly once.

    The perplexity is exp(total NLL / number of tokens that actually entered the
    loss). The count accounts for the label shift done inside the model, so the
    normalisation matches what was summed.

    Args:
        model: causal LM accepting ``labels=`` (with -100 = ignore) and returning
            an output whose ``.loss`` is the mean NLL over the scored labels.
        tokenizer: tokenizer matching ``model``.
        max_length: window length in tokens.
        stride: step between window starts; must satisfy 0 < stride <= max_length.
        device: device the input ids are moved to.

    Returns:
        dict with "perplexity" (float), "tokens" (length of the token stream),
        "scored_tokens" (tokens contributing to the NLL) and "time" (seconds).

    Raises:
        ValueError: if ``stride`` is out of range or the stream has < 2 tokens.
    """
    if not 0 < stride <= max_length:
        raise ValueError(
            f"stride must be in (0, max_length]; got stride={stride}, max_length={max_length}"
        )

    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    text = "\n\n".join(dataset["text"])
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    seq_len = input_ids.size(1)
    print(f"Tokens: {seq_len:,}")
    if seq_len < 2:
        raise ValueError("Need at least 2 tokens to compute perplexity")

    nll_sum, n_scored, prev_end = 0.0, 0, 0
    model.eval()
    t0 = time.time()

    with torch.no_grad():
        for begin in tqdm(range(0, seq_len, stride), desc="Perplexity"):
            end = min(begin + max_length, seq_len)
            trg_len = end - prev_end  # tokens not yet scored by an earlier window
            ids = input_ids[:, begin:end].to(device)
            tgt = ids.clone()
            tgt[:, :-trg_len] = -100  # context-only positions are ignored by the loss

            # Hugging Face causal LMs shift labels left by one internally (assumed to
            # hold for MobileLLM's remote code too), so label position 0 is never a
            # target. Count only the labels that actually enter the mean loss.
            n_loss = int((tgt[:, 1:] != -100).sum())
            if n_loss > 0:  # a zero-target window would yield a NaN mean loss
                out = model(ids, labels=tgt)
                # Accumulate on the CPU in float64 rather than summing fp16 tensors.
                nll_sum += out.loss.float().item() * n_loss
                n_scored += n_loss

            prev_end = end
            if end == seq_len:
                break

    ppl = float(np.exp(nll_sum / n_scored))
    return {"perplexity": ppl, "tokens": seq_len, "scored_tokens": n_scored, "time": time.time() - t0}

In [None]:
def measure_speed(model, tokenizer, device="cuda", prompt="The quick brown fox",
                  max_new_tokens=50, n_runs=5):
    """Greedy-decoding throughput of ``model`` in generated tokens per second.

    One untimed warm-up generation is run first. Then ``n_runs`` timed greedy
    generations of up to ``max_new_tokens`` tokens are run from ``prompt``.
    Throughput is the number of tokens actually generated divided by the total
    elapsed time, so an early EOS does not distort the figure.

    Decoding uses ``use_cache=False`` (no KV cache). The number therefore
    reflects uncached generation.

    Args:
        model: causal LM exposing ``eval()`` and ``generate()``.
        tokenizer: tokenizer matching ``model``; its ``eos_token_id`` is used as pad.
        device: device the prompt tensors are moved to.
        prompt: text to continue.
        max_new_tokens: upper bound on generated tokens per timed run.
        n_runs: number of timed runs (>= 1).

    Returns:
        float, generated tokens per second.

    Raises:
        ValueError: if ``n_runs`` < 1.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1; got {n_runs}")

    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[1]
    gen_kwargs = dict(do_sample=False, use_cache=False, pad_token_id=tokenizer.eos_token_id)

    # Device sync is needed for accurate CUDA timing, but it fails without CUDA.
    on_cuda = str(device).startswith("cuda") and torch.cuda.is_available()
    sync = torch.cuda.synchronize if on_cuda else (lambda: None)

    total_tokens, total_time = 0, 0.0
    with torch.no_grad():
        model.generate(**inputs, max_new_tokens=10, **gen_kwargs)  # warm-up, not timed
        for _ in range(n_runs):
            sync()
            t0 = time.perf_counter()
            out = model.generate(**inputs, max_new_tokens=max_new_tokens, **gen_kwargs)
            sync()
            total_time += time.perf_counter() - t0
            total_tokens += out.shape[1] - prompt_len  # generation may stop early at EOS
    return total_tokens / total_time

In [None]:
def run_model(name, config):
    """Load one checkpoint, evaluate it, release it, and return its measurements.

    Measures perplexity (compute_perplexity) and generation throughput
    (measure_speed). It also records the parameter count of the loaded model,
    the peak CUDA memory allocated since the start of this call, and the total
    wall-clock time, including tokenizer/model loading.

    Args:
        name: display name stored in the result.
        config: mapping whose "hf_name" entry is the Hub repo id; no other key is read.

    Returns:
        ModelResults for this model.
    """
    banner = "=" * 50
    print(f"\n{banner}\n{name}\n{banner}")
    start = time.time()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    repo_id = config["hf_name"]
    # Slow tokenizer requested explicitly; its special tokens are registered by hand.
    tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
    tokenizer.add_special_tokens({"eos_token": "</s>", "bos_token": "<s>", "unk_token": "<unk>"})
    model = AutoModelForCausalLM.from_pretrained(
        repo_id, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto"
    )

    n_params = sum(tensor.numel() for tensor in model.parameters())
    print(f"Params: {n_params:,}")

    ppl_stats = compute_perplexity(model, tokenizer)
    print(f"PPL: {ppl_stats['perplexity']:.2f}")

    tok_per_s = measure_speed(model, tokenizer)
    print(f"Speed: {tok_per_s:.1f} tok/s")

    peak_mb = torch.cuda.max_memory_allocated() / 1e6
    result = ModelResults(
        model_name=name,
        num_params=n_params,
        perplexity=ppl_stats["perplexity"],
        tokens_per_second=tok_per_s,
        peak_memory_mb=peak_mb,
        wall_clock_seconds=time.time() - start,
    )

    # Drop references before clearing the cache so the next model starts clean.
    del model, tokenizer
    torch.cuda.empty_cache()
    return result

## 4. Run All Models

In [None]:
# Evaluate every configured model. A failure in one model (e.g. an out-of-memory
# error) is reported with its full traceback and recorded in `failed_models`, but
# it does not stop the sweep.
all_results = []
failed_models = {}
for name, cfg in MODELS.items():
    try:
        all_results.append(run_model(name, cfg))
    except Exception as e:
        failed_models[name] = repr(e)
        print(f"FAILED {name}: {e}")
        traceback.print_exc()
print(f"\nDone: {len(all_results)}/{len(MODELS)}")
if failed_models:
    print(f"Failed: {', '.join(failed_models)}")

## 5. Analysis

In [None]:
# One row per successfully evaluated model; columns follow the ModelResults fields.
records = [res.to_dict() for res in all_results]
df = pd.DataFrame(records)
df

In [None]:
# Fit scaling law
log_n = np.log10(df["num_params"])
log_ppl = np.log10(df["perplexity"])
slope, intercept, r, p, se = stats.linregress(log_n, log_ppl)

print(f"\nSCALING LAW: PPL = {10**intercept:.0f} * N^({slope:.4f})")
print(f"R-squared = {r**2:.4f}")

ppl_fit = {"slope": slope, "intercept": intercept, "r_squared": r**2}

## 6. Plots

In [None]:
# Two panels: perplexity vs. parameter count with the fitted power law (left),
# and measured generation throughput per model (right).
fig, (ax_ppl, ax_speed) = plt.subplots(1, 2, figsize=(14, 5))
short_labels = df["model_name"].str.replace("MobileLLM-", "")

# --- Left: log-log perplexity scaling -------------------------------------
ax_ppl.scatter(df["num_params"], df["perplexity"], s=150, c='steelblue', edgecolors='black', zorder=5)
for label, n_params, ppl_value in zip(short_labels, df["num_params"], df["perplexity"]):
    ax_ppl.annotate(label, (n_params, ppl_value), xytext=(8, 0),
                    textcoords="offset points", fontsize=11, fontweight='bold')

# Extend the fitted line a little beyond the smallest/largest model.
log_lo = np.log10(df["num_params"].min() * 0.7)
log_hi = np.log10(df["num_params"].max() * 1.3)
x_fit = np.logspace(log_lo, log_hi, 100)
y_fit = 10**intercept * x_fit**slope
fit_label = f"PPL = {10**intercept:.0f} x N^({slope:.3f})\nR² = {r**2:.4f}"
ax_ppl.plot(x_fit, y_fit, 'r--', lw=2, label=fit_label)

ax_ppl.set_xscale("log")
ax_ppl.set_yscale("log")
ax_ppl.set_xlabel("Parameters", fontsize=12)
ax_ppl.set_ylabel("Perplexity", fontsize=12)
ax_ppl.set_title("MobileLLM Perplexity Scaling", fontsize=14, fontweight='bold')
ax_ppl.legend(fontsize=10)
ax_ppl.grid(True, alpha=0.3)

# --- Right: throughput bars ------------------------------------------------
bar_colors = plt.cm.Blues(np.linspace(0.4, 0.9, len(df)))
bars = ax_speed.bar(short_labels, df["tokens_per_second"], color=bar_colors, edgecolor='black')
for bar, tps in zip(bars, df["tokens_per_second"]):
    x_mid = bar.get_x() + bar.get_width() / 2
    ax_speed.text(x_mid, bar.get_height() + 0.5, f'{tps:.1f}', ha='center', fontweight='bold')
ax_speed.set_ylabel("Tokens/sec", fontsize=12)
ax_speed.set_xlabel("Model", fontsize=12)
ax_speed.set_title("Inference Speed (T4)", fontsize=14, fontweight='bold')
ax_speed.grid(True, alpha=0.3, axis='y')

fig.tight_layout()
fig.savefig(RESULTS_DIR / "scaling.png", dpi=150)
plt.show()

## 7. Save Results

In [None]:
# Persist the per-model table and the fitted scaling-law coefficients.
csv_path = RESULTS_DIR / "results.csv"
fit_path = RESULTS_DIR / "scaling_fit.json"
df.to_csv(csv_path, index=False)
with fit_path.open("w") as fh:
    json.dump(ppl_fit, fh, indent=2)
print("Saved to results/")

In [None]:
# Markdown summary of the fitted law plus one table row per evaluated model.
report = f"""# MobileLLM Scaling Laws Report

## Scaling Law
```
PPL = {10**intercept:.0f} x N^({slope:.4f})
R² = {r**2:.4f}
```

## Results
| Model | Params | PPL | Tok/s | Memory |
|-------|--------|-----|-------|--------|
"""
report += "".join(
    f"| {row['model_name']} | {row['num_params']/1e6:.0f}M | {row['perplexity']:.2f} | {row['tokens_per_second']:.1f} | {row['peak_memory_mb']/1000:.1f}GB |\n"
    for _, row in df.iterrows()
)

# Explicit UTF-8: the report contains non-ASCII text ("²"), which the platform
# default encoding may not be able to represent.
with open(RESULTS_DIR / "report.md", "w", encoding="utf-8") as f:
    f.write(report)
print(report)

In [None]:
# Headline numbers: the fitted law, and the perplexity reduction from doubling the parameter count.
banner = "=" * 50
doubling_gain_pct = 100 * (1 - 2**slope)
print(banner)
print(f"RESULT: PPL = {10**intercept:.0f} x N^({slope:.4f})")
print(f"R² = {r**2:.4f}")
print(f"2x params -> {doubling_gain_pct:.1f}% lower perplexity")
print(banner)