# PROMETHEUS v2.0 + Tunix: CIC-Guided Reasoning Training

## Synthesized with 20+ Novel Insights

**Key Innovations from PROMETHEUS SYNTHESIS ENCYCLOPEDIA:**
1. **NCD Basin Regularization** - Reward traces close to canonical patterns
2. **ToM-Informed Reward Shaping** - Optimize for human collaboration
3. **CIC Functional Scoring** - F[T] = Φ(T) - λH(T|X) + γC(T)
4. **Causal Emergence Rewards** - Macro explains micro
5. **Coherence Graph Scoring** - Logical flow, not just keywords
6. **Temperature-κ Coupling** - Adaptive exploration
7. **Trace Diversity via NCD** - Prevent mode collapse
8. **Kolmogorov Regularization** - Reward concise reasoning

**From Riedl & Weidmann (2025):**
- ToM predicts AI collaboration (ρs=0.17, p<0.001)
- Solo ability has ZERO correlation (β=-0.00)
- Optimize for κ (collaborative ability), not θ (solo)

**Competition Requirements:**
- Model: Gemma2 2B or Gemma3 1B
- Output: `<reasoning>trace</reasoning><answer>answer</answer>`
- Hardware: TPU v5e-8 (9hr session, 20hr/week)
- Framework: Tunix (JAX-native GRPO)

In [None]:
# =============================================================================
# CELL 1: SETUP & INSTALLATION
# =============================================================================
import os
# Set before any Hugging Face download in this session.
# NOTE(review): presumably disables huggingface_hub's Xet transfer backend —
# confirm against the installed huggingface_hub version.
os.environ["HF_HUB_DISABLE_XET"] = "1"

# Install dependencies
!pip install -q kagglehub
!pip install -q ipywidgets
!pip install -q tensorflow
!pip install -q tensorflow_datasets
!pip install -q tensorboardX
!pip install -q transformers
!pip install -q grain
# Tunix is the only package pinned to an exact version.
!pip install "google-tunix[prod]==0.1.3"
# Remove whichever flax is present after the Tunix install, then install the
# newest release. The order of these two lines matters.
!pip uninstall -q -y flax
!pip install -U flax
!pip install -q datasets
# NOTE(review): scipy is not imported anywhere in the visible notebook.
!pip install -q scipy  # For NCD clustering

print("Dependencies installed.")

In [None]:
# =============================================================================
# CELL 2: IMPORTS
# =============================================================================
import functools
import gc
import os
import zlib  # For NCD computation
from pprint import pprint
import re
import random
import csv
import shutil

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import numpy as np
import optax
from orbax import checkpoint as ocp
from pathlib import Path
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma3 import params
from tunix.models.gemma3 import model
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from datasets import load_dataset

# Config
# Seed Python's `random` and NumPy's global RNG; no JAX PRNG key is created here.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Show which accelerator devices JAX can see in this session.
print(f"JAX devices: {jax.devices()}")

In [None]:
# =============================================================================
# CELL 3: HYPERPARAMETERS ("Efficiency Build" for TPU v5e-8)
# =============================================================================

# ====== Data ======
# Multiplier applied to NUM_BATCHES when MAX_STEPS is computed below.
TRAIN_FRACTION = 1.0

# ====== LoRA ======
# Passed to qwix.LoraProvider as `rank` / `alpha` in get_lora_model().
RANK = 32
ALPHA = 32.0

# ====== Sharding ======
# Unpacked into jax.make_mesh(*MESH): axis sizes first, then axis names.
MESH = [(1, 4), ("fsdp", "tp")]

# ====== GRPO Generation ======
# Prompt / completion token budgets for rollouts; their sum (+256) also sizes
# the sampler and rollout KV caches.
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 512

# ====== ADAPTIVE TEMPERATURE (PROMETHEUS v2.0) ======
# Now coupled to training dynamics via κ proxy
# TEMP_START is the rollout temperature; TEMP_MIN / TEMP_MAX are the clip
# bounds read by AdaptiveTemperatureScheduler.get_temperature().
# NOTE(review): TEMP_END is not referenced elsewhere in the visible notebook.
TEMP_START = 1.2
TEMP_END = 0.5
TEMP_MIN = 0.3
TEMP_MAX = 1.5

# Remaining rollout sampling settings and the number of completions per prompt.
TOP_P = 1.0
TOP_K = 50
NUM_GENERATIONS = 4

# ====== GRPO Config ======
# Passed straight to GRPOConfig(num_iterations=..., beta=..., epsilon=...).
# NOTE(review): presumably beta is the KL-penalty weight and epsilon the
# clipping range — confirm against the Tunix GRPOConfig documentation.
NUM_ITERATIONS = 1
BETA = 0.08
EPSILON = 0.2

# ====== Training ======
TRAIN_MICRO_BATCH_SIZE = 2
NUM_BATCHES = 3738
NUM_TEST_BATCHES = 100
EVAL_EVERY_N_STEPS = 10
NUM_EPOCHS = 1

# Total optimizer steps: batches x iterations x fraction x epochs, truncated to int.
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# ====== AdamW + Warmup + Cosine Scheduler ======
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1
# NOTE(review): this is a float (10% of MAX_STEPS), not an int — confirm that
# optax.schedules.warmup_cosine_decay_schedule accepts a non-integer value.
WARMUP_STEPS = 0.1 * MAX_STEPS
# Global-norm gradient clipping threshold; a falsy value disables clipping.
MAX_GRAD_NORM = 0.1

# ====== Checkpointing ======
# These two directories are also hard-coded in the `!rm` lines of the model
# loading cell; keep them in sync.
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
CKPT_DIR = "/tmp/content/ckpts/"
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

# ====== Inference ======
# Named sampling presets; each dict is splatted into generate() as keyword
# arguments by evaluate_open_ended().
GENERATION_CONFIGS = {
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}

print(f"Config: RANK={RANK}, ALPHA={ALPHA}, LR={LEARNING_RATE}")
print(f"Training: {MAX_STEPS} steps, batch={TRAIN_MICRO_BATCH_SIZE}")

In [None]:
# =============================================================================
# CELL 4: NCD PRIMITIVES (Core of PROMETHEUS v2.0)
# =============================================================================
# From Li et al. (2004) "The Similarity Metric"
# =============================================================================

def ncd(x: str, y: str) -> float:
    """Return the zlib-based Normalized Compression Distance between two strings.

    Computed as ``(C(xy) - min(C(x), C(y))) / max(C(x), C(y))`` where ``C`` is
    the byte length of the level-9 zlib compression of the UTF-8 encoded text.
    If either argument is empty the distance is defined as 1.0.
    """
    if not (x and y):
        return 1.0

    def _compressed_size(text):
        return len(zlib.compress(text.encode(), level=9))

    size_x = _compressed_size(x)
    size_y = _compressed_size(y)
    size_joint = _compressed_size(x + y)
    smaller, larger = sorted((size_x, size_y))
    return (size_joint - smaller) / larger


def kolmogorov_proxy(text: str) -> float:
    """Estimate Kolmogorov complexity as compressed size over raw size.

    Both sizes are byte counts of the UTF-8 encoding; compression is zlib at
    level 9. Empty input returns 1.0. Very short inputs can score above 1.0
    because the zlib framing outweighs the payload.
    """
    if text:
        raw = text.encode()
        return len(zlib.compress(raw, level=9)) / len(raw)
    return 1.0


# Smoke check: print the distance of a string to itself and to a different
# string of the same length.
print("NCD primitives loaded.")
print(f"  Test: ncd('hello', 'hello') = {ncd('hello', 'hello'):.3f}")
print(f"  Test: ncd('hello', 'world') = {ncd('hello', 'world'):.3f}")

In [None]:
# =============================================================================
# CELL 5: TEMPLATE & BASIN CENTERS
# =============================================================================

# Literal tag strings; the reward regexes and the system prompt below are all
# built from these four names.
reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<answer>"
solution_end = "</answer>"

# Instruction text; the backslash line continuations keep it on a single line.
SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \
provide your reasoning. Place it between {reasoning_start} and \
{reasoning_end}. Then, provide the final answer between {solution_start} \
and {solution_end}."""

# Chat template filled via str.format(system_prompt=..., question=...).
# It ends right after the model turn marker, with no trailing newline.
TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model"""

# =============================================================================
# BASIN CENTERS - Canonical reasoning patterns for training
# Reward traces that compress well with these patterns
# =============================================================================
# basin_regularization_reward() scores a trace by its smallest ncd() to any
# entry of this list.
CANONICAL_TRACES = [
    # Step-by-step pattern
    """Let me break this down step by step.
First, I identify the key elements of the problem.
Next, I consider the relationships between them.
Then, I apply the relevant principle or formula.
Finally, I combine these to reach the conclusion.""",

    # Hypothesis-test pattern
    """I'll approach this by forming a hypothesis.
Hypothesis: The answer has property X.
Testing: If this is true, then we should see Y.
Verification: Checking this condition...
The hypothesis holds, so the conclusion follows.""",

    # Case analysis pattern
    """This problem has several cases to consider.
Case 1: When the first condition applies.
Case 2: When the second condition applies.
Combining the cases gives the complete answer.""",

    # Deductive pattern
    """Given the premises, I can deduce the following.
From A, it follows that B.
Combined with C, this implies D.
Therefore, the answer is E.""",
]

print("Template and basin centers configured.")

In [None]:
# =============================================================================
# CELL 6: DATA LOADING (Real User Tasks)
# =============================================================================

def get_real_user_dataset(max_samples=10000) -> grain.MapDataset:
    """Build a prompt dataset from OpenAssistant, Dolly and Alpaca questions.

    Each source is loaded independently (best effort): a source that fails to
    download or parse is reported and skipped as a whole. Questions are kept
    only when their length is strictly between 20 and 500 characters. The
    pooled list is shuffled with the global ``random`` state, truncated to
    ``max_samples``, and wrapped in a grain ``MapDataset`` whose records have
    the keys ``prompts``, ``question``, ``answer`` (always ``None``) and
    ``source``.

    Args:
        max_samples: Upper bound on the number of questions kept.

    Returns:
        A shuffled ``grain.MapDataset`` of prompt records.
    """

    def _field(item, key):
        # Rows may carry None for optional fields; treat that like "" so one
        # bad row cannot abort the whole source via AttributeError.
        return (item.get(key) or "").strip()

    def _oasst_questions(rows):
        # Only conversation roots written by the human side.
        for item in rows:
            if item.get("role") == "prompter" and item.get("parent_id") is None:
                yield _field(item, "text")

    def _dolly_questions(rows):
        for item in rows:
            instruction = _field(item, "instruction")
            context = _field(item, "context")
            yield f"{instruction}\n\nContext: {context}" if context else instruction

    def _alpaca_questions(rows):
        for item in rows:
            instruction = _field(item, "instruction")
            inp = _field(item, "input")
            yield f"{instruction}\n\nInput: {inp}" if inp else instruction

    # (display name, source tag, hub dataset id, question extractor)
    sources = [
        ("OpenAssistant", "oasst", "OpenAssistant/oasst1", _oasst_questions),
        ("Dolly", "dolly", "databricks/databricks-dolly-15k", _dolly_questions),
        ("Alpaca", "alpaca", "tatsu-lab/alpaca", _alpaca_questions),
    ]

    all_questions = []
    for label, tag, hub_id, extract in sources:
        print(f"Loading {label}...")
        try:
            rows = load_dataset(hub_id, split="train")
            kept = [
                {"question": q, "source": tag, "answer": None}
                for q in extract(rows)
                if 20 < len(q) < 500
            ]
        except Exception as e:
            # Deliberate best effort: training can proceed on the other sources.
            print(f"  -> {label} failed: {e}")
            continue
        all_questions.extend(kept)
        print(f"  -> {len(kept)} from {label}")

    random.shuffle(all_questions)
    all_questions = all_questions[:max_samples]
    print(f"\nTotal: {len(all_questions)} real user questions")

    dataset = (
        grain.MapDataset.source(all_questions)
        .shuffle(seed=42)
        .map(
            lambda x: {
                "prompts": TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=x["question"]),
                "question": x["question"],
                "answer": x["answer"],
                "source": x["source"],
            }
        )
    )
    return dataset

print("Loading datasets...")
full_dataset = get_real_user_dataset(max_samples=NUM_BATCHES * TRAIN_MICRO_BATCH_SIZE * 2)
all_batches = full_dataset.batch(TRAIN_MICRO_BATCH_SIZE)
# Reserve the held-out batches first so a short download cannot leave the test
# set empty, then give the training slice everything before them (capped at
# NUM_BATCHES). Previously the test slice started at 90% of the training slice
# of the same batch stream, so every test batch had also been trained on.
num_test_batches = min(NUM_TEST_BATCHES, len(all_batches) // 10)
num_train_batches = min(NUM_BATCHES, len(all_batches) - num_test_batches)
dataset = all_batches[:num_train_batches]
train_dataset = dataset.repeat(NUM_EPOCHS)
# Test batches begin immediately after the last training batch.
test_start = len(dataset)
test_dataset = all_batches[test_start:test_start + num_test_batches]
print(f"Train batches: {len(train_dataset)}")
print(f"Test batches: {len(test_dataset)}")

In [None]:
# =============================================================================
# CELL 7: MODEL LOADING
# =============================================================================

def show_hbm_usage():
    """Print used / total memory and the usage percentage for each local JAX device.

    Devices whose backend reports no memory statistics (``memory_stats()``
    returns ``None`` or lacks the byte counters) are reported as unavailable
    instead of raising.
    """
    fmt_size = functools.partial(humanize.naturalsize, binary=True)
    for d in jax.local_devices():
        stats = d.memory_stats()
        # A missing or zero limit would also break the percentage below.
        if not stats or "bytes_in_use" not in stats or not stats.get("bytes_limit"):
            print(f"Memory stats unavailable on {d}")
            continue
        used = stats["bytes_in_use"]
        limit = stats["bytes_limit"]
        print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")

from kaggle_secrets import UserSecretsClient

def auto_login():
    """Authenticate with Kaggle, preferring notebook secrets over a prompt.

    Reads ``KAGGLE_USERNAME`` and ``KAGGLE_KEY`` from Kaggle user secrets and,
    when both are non-empty, exports them as environment variables. If the
    secrets cannot be read or are incomplete, falls back to
    ``kagglehub.login()``.
    """
    try:
        user_secrets = UserSecretsClient()
        username = user_secrets.get_secret("KAGGLE_USERNAME")
        key = user_secrets.get_secret("KAGGLE_KEY")
    except Exception as e:
        # Best effort: report why secrets were unusable, then fall through to
        # the interactive login. Unlike a bare `except`, this lets
        # KeyboardInterrupt / SystemExit propagate.
        print(f"Kaggle secrets unavailable ({e!r}); falling back to kagglehub.login().")
    else:
        if username and key:
            os.environ["KAGGLE_USERNAME"] = username
            os.environ["KAGGLE_KEY"] = key
            print("✅ Authenticated")
            return
    kagglehub.login()

auto_login()

# Clear checkpoints left by an earlier run. The paths are hard-coded here and
# must match INTERMEDIATE_CKPT_DIR / CKPT_DIR.
!rm /tmp/content/intermediate_ckpt/* -rf
!rm /tmp/content/ckpts/* -rf

print("Loading Gemma3 1B-IT...")
MODEL_CP_PATH = params.GEMMA3_1B_IT
config = model.ModelConfig.gemma3_1b()
gemma = params.create_model_from_checkpoint(MODEL_CP_PATH, config)
tokenizer = params.create_tokenizer()

# Write the model state to an Orbax checkpoint; get_gemma_ref_model() later
# restores it from this path with an explicit dtype and sharding.
checkpointer = ocp.StandardCheckpointer()
_, state = nnx.split(gemma)
checkpointer.save(os.path.join(INTERMEDIATE_CKPT_DIR, "state"), state)
# Block until the save has finished before the in-memory copy is dropped.
checkpointer.wait_until_finished()

# Drop the Python references to the loaded model and force a GC pass.
del gemma, state
gc.collect()
print("Model checkpoint saved.")
show_hbm_usage()

In [None]:
# =============================================================================
# CELL 8: LORA MODEL SETUP
# =============================================================================

def get_gemma_ref_model(ckpt_path):
    """Restore the Gemma3-1B model from an Orbax checkpoint onto a fresh mesh.

    An abstract (shape-only) model is built first, its state is annotated with
    bfloat16 dtype and the mesh's named sharding, and the checkpoint is
    restored against that target so the parameters are materialised already
    sharded.

    Args:
        ckpt_path: Path of the state checkpoint to restore.

    Returns:
        A ``(model, mesh, model_config)`` tuple.
    """
    mesh = jax.make_mesh(*MESH)
    model_config = model.ModelConfig.gemma3_1b()
    # Use the config created here (and returned to the caller) rather than the
    # notebook-level `config`, so the abstract model always matches what is
    # returned and the function does not depend on another cell's variable.
    abs_gemma = nnx.eval_shape(
        lambda: params.create_model_from_checkpoint(MODEL_CP_PATH, model_config)
    )
    abs_state = nnx.state(abs_gemma)
    # Restore target: same shapes, forced to bfloat16, with per-leaf sharding.
    abs_state = jax.tree.map(
        lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
        abs_state,
        nnx.get_named_sharding(abs_state, mesh),
    )
    checkpointer = ocp.StandardCheckpointer()
    restored_params = checkpointer.restore(ckpt_path, target=abs_state)
    graph_def, _ = nnx.split(abs_gemma)
    gemma = nnx.merge(graph_def, restored_params)
    return gemma, mesh, model_config

def get_lora_model(base_model, mesh):
    """Wrap ``base_model`` with LoRA adapters and constrain its state to ``mesh``.

    Adapters of rank ``RANK`` / scale ``ALPHA`` are attached to every module
    whose path matches the attention and MLP projection pattern below.

    Args:
        base_model: Model exposing ``get_model_input()``.
        mesh: Mesh used as the context for the sharding constraint.

    Returns:
        The LoRA-augmented model.
    """
    provider = qwix.LoraProvider(
        module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum",
        rank=RANK,
        alpha=ALPHA,
    )
    example_inputs = base_model.get_model_input()
    policy = qwix.apply_lora_to_model(base_model, provider, **example_inputs)
    with mesh:
        policy_state = nnx.state(policy)
        # Re-apply each leaf's own partition spec as a sharding constraint.
        constrained_state = jax.lax.with_sharding_constraint(
            policy_state, nnx.get_partition_spec(policy_state)
        )
        nnx.update(policy, constrained_state)
    return policy

print("Loading reference model...")
# Restore from the intermediate checkpoint written by the model loading cell.
ref_model, mesh, model_config = get_gemma_ref_model(
    ckpt_path=os.path.join(INTERMEDIATE_CKPT_DIR, "state")
)

print("Applying LoRA...")
# The trainable policy is the reference model wrapped with LoRA adapters.
lora_policy = get_lora_model(ref_model, mesh=mesh)
print(f"LoRA model ready: RANK={RANK}, ALPHA={ALPHA}")
show_hbm_usage()

In [None]:
# =============================================================================
# CELL 9: PROMETHEUS v2.0 REWARD FUNCTIONS (CIC-Guided)
# =============================================================================
# Synthesized from PROMETHEUS SYNTHESIS ENCYCLOPEDIA
# Implements 8 novel reward signals
# =============================================================================

# Full-format pattern: optional whitespace, a non-empty reasoning block, any
# text, a non-empty answer block (captured as group 1), optional whitespace.
# Because re.MULTILINE is set, ^ and $ anchor at line boundaries, so with
# .search() the match may begin and end at any line of the completion rather
# than only at its very start and end. re.DOTALL lets the blocks span lines.
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)


# =============================================================================
# REWARD 1: Format Compliance (from v1)
# =============================================================================
def match_format_exactly(prompts, completions, **kwargs):
    """Score 3.0 for each completion that satisfies ``match_format``, else 0.

    ``prompts`` and extra keyword arguments are accepted for the reward
    function interface but are not used.
    """
    scores = []
    for text in completions:
        scores.append(3.0 if match_format.search(text) is not None else 0)
    return scores


def match_format_approximately(prompts, completions, **kwargs):
    """Give partial credit per tag: +0.5 if it occurs exactly once, else -0.5.

    The four tags checked are the reasoning and answer start/end markers, so
    each completion scores between -2.0 and +2.0.
    """

    def _tag_score(text):
        total = 0
        for tag in (reasoning_start, reasoning_end, solution_start, solution_end):
            total += 0.5 if text.count(tag) == 1 else -0.5
        return total

    return [_tag_score(text) for text in completions]


# =============================================================================
# REWARD 2: NCD Basin Regularization (INSIGHT C2)
# Reward traces close to canonical patterns
# =============================================================================
def basin_regularization_reward(prompts, completions, **kwargs):
    """Score each reasoning trace by its closeness to the nearest canonical trace.

    The score is ``2.0 * (1 - d)`` where ``d`` is the smallest ``ncd`` between
    the trace and any entry of ``CANONICAL_TRACES``. Completions without a
    non-empty reasoning block score -1.0.
    """
    scores = []
    for completion in completions:
        found = re.search(rf"{reasoning_start}(.+?){reasoning_end}", completion, re.DOTALL)
        if found is None:
            scores.append(-1.0)
        else:
            reasoning = found.group(1)
            nearest = min(ncd(reasoning, center) for center in CANONICAL_TRACES)
            # Smaller distance to a basin centre means a larger reward.
            scores.append(2.0 * (1 - nearest))
    return scores


# =============================================================================
# REWARD 3: Causal Emergence (INSIGHT B7/C1)
# Macro-level statements explaining micro-level details
# =============================================================================
# Regexes for generalising ("macro") statements; matched against lower-cased text.
MACRO_PATTERNS = [
    r'by (symmetry|induction|contradiction|construction)',
    r'(therefore|thus|hence|so)\s+(all|every|the)',
    r'in general',
    r'for (all|any|every)',
    r'this (implies|means|shows)',
    r'the key (insight|observation)',
]

# Regexes for concrete ("micro") details: worked arithmetic, numbered steps/cases.
MICRO_PATTERNS = [
    r'\d+\s*[\+\-\*/]\s*\d+\s*=\s*\d+',
    r'step\s+\d+',
    r'case\s+\d+',
]

def causal_emergence_reward(prompts, completions, **kwargs):
    """Reward reasoning whose macro/micro statement ratio is close to 1:3.

    Per completion: 0 when there is no reasoning block, 0.5 when the trace has
    no micro-level matches, otherwise ``1.5 * (1 - min(1, |r - 1/3| / (1/3)))``
    with ``r`` the number of macro matches divided by micro matches.
    """

    def _count_matches(patterns, text):
        return sum(len(re.findall(pattern, text)) for pattern in patterns)

    target_ratio = 1/3
    scores = []
    for completion in completions:
        found = re.search(rf"{reasoning_start}(.+?){reasoning_end}", completion, re.DOTALL)
        if not found:
            scores.append(0)
            continue

        lowered = found.group(1).lower()
        macros = _count_matches(MACRO_PATTERNS, lowered)
        micros = _count_matches(MICRO_PATTERNS, lowered)

        if micros == 0:
            # No concrete detail to explain: flat neutral score.
            scores.append(0.5)
        else:
            deviation = abs(macros / micros - target_ratio) / target_ratio
            scores.append((1 - min(1, deviation)) * 1.5)  # Scale to 1.5 max
    return scores


# =============================================================================
# REWARD 4: ToM-Informed Properties (INSIGHT C1)
# Outputs that help humans understand and verify
# =============================================================================
def tom_informed_reward(prompts, completions, **kwargs):
    """Reward hedging language and checkable steps anywhere in the completion.

    Two capped components are added per completion:
      * 0.15 per distinct uncertainty marker present, capped at 0.5;
      * 0.2 per regex match of a checkable-step pattern, capped at 0.5.
    Matching is done on the lower-cased completion text.
    """
    hedge_phrases = (
        'i think', 'likely', 'probably', 'might',
        'assuming', 'if we', 'could be',
    )
    verifiable_step_patterns = (
        r'therefore.*=',
        r'which gives',
        r'substituting',
        r'to verify',
    )

    def _score(text):
        lowered = text.lower()
        # Presence count: each phrase contributes at most once.
        hedge_hits = sum(1 for phrase in hedge_phrases if phrase in lowered)
        # Occurrence count: every regex match contributes.
        step_hits = sum(len(re.findall(p, lowered)) for p in verifiable_step_patterns)
        total = 0
        total += min(0.5, hedge_hits * 0.15)
        total += min(0.5, step_hits * 0.2)
        return total

    return [_score(completion) for completion in completions]


# =============================================================================
# REWARD 5: Coherence Graph (INSIGHT C5)
# Logical connectors indicating flow
# =============================================================================
def coherence_graph_reward(prompts, completions, **kwargs):
    """Reward reasoning traces whose density of logical connectors is healthy.

    For each completion the reasoning block is lower-cased and split into
    sentences on ``.``, ``!`` and ``?``. Connector occurrences are counted as
    whole words or phrases, so 'so' no longer matches inside "reason" or 'as'
    inside "has", and repeated connectors count each time. That makes the
    count comparable to the expected one connector per 2.5 sentences.

    Scores:
      * 0 if there is no non-empty reasoning block;
      * 0.5 if the trace has fewer than two sentences;
      * otherwise ``min(2.0, min(1.5, actual / expected) + balance)`` where
        ``balance`` is 0.3 when both forward and backward connectors appear.
    """
    forward = ['therefore', 'thus', 'so', 'hence', 'this means', 'which implies']
    backward = ['because', 'since', 'as', 'given that']

    def _connector_regex(connectors):
        alternatives = "|".join(re.escape(c) for c in connectors)
        return re.compile(rf"\b(?:{alternatives})\b")

    # Built once per call; the connector lists do not change per completion.
    forward_re = _connector_regex(forward)
    backward_re = _connector_regex(backward)

    scores = []
    for completion in completions:
        match = re.search(rf"{reasoning_start}(.+?){reasoning_end}", completion, re.DOTALL)
        if not match:
            scores.append(0)
            continue

        trace = match.group(1).lower()
        sentences = [s.strip() for s in re.split(r'[.!?]+', trace) if s.strip()]
        n_sentences = len(sentences)

        if n_sentences < 2:
            scores.append(0.5)
            continue

        n_forward = len(forward_re.findall(trace))
        n_backward = len(backward_re.findall(trace))

        # Ideal: ~1 connector per 2-3 sentences
        expected = n_sentences / 2.5
        actual = n_forward + n_backward
        ratio = min(1.5, actual / expected)

        # Balance bonus
        balance = 0.3 if n_forward > 0 and n_backward > 0 else 0

        scores.append(min(2.0, ratio + balance))
    return scores


# =============================================================================
# REWARD 6: Kolmogorov Compression (INSIGHT C6)
# Reward concise but complete reasoning
# =============================================================================
def kolmogorov_compression_reward(prompts, completions, **kwargs):
    """Score reasoning traces by the band their compression ratio falls into.

    The ratio comes from ``kolmogorov_proxy``. Bands:
      * 0.3 to 0.6 inclusive: 1.5 (neither too repetitive nor too random);
      * 0.2 up to 0.3, or above 0.6 up to 0.7: 1.0;
      * anything else: 0.5.
    Completions without a non-empty reasoning block score 0.
    """

    def _band_score(ratio):
        if 0.3 <= ratio <= 0.6:
            return 1.5
        if 0.2 <= ratio < 0.3 or 0.6 < ratio <= 0.7:
            return 1.0
        return 0.5

    scores = []
    for completion in completions:
        found = re.search(rf"{reasoning_start}(.+?){reasoning_end}", completion, re.DOTALL)
        scores.append(_band_score(kolmogorov_proxy(found.group(1))) if found else 0)
    return scores


# =============================================================================
# REWARD 7: NCD Diversity Bonus (INSIGHT C4)
# Reward diverse traces within batch (prevent mode collapse)
# =============================================================================
def trace_diversity_bonus(prompts, completions, **kwargs):
    """Reward reasoning traces that differ from the other samples of the same prompt.

    Completions are grouped by their prompt, and each trace is compared (via
    ``ncd``) only with the other non-empty traces in its group. Comparing
    across prompts would mostly measure how different the questions are, not
    whether sampling has collapsed to one mode. When ``prompts`` is missing or
    its length does not match ``completions``, all completions are treated as
    one group, as before.

    The bonus is ``max(0, (mean_distance - 0.5) * 2)``. It is 0 for
    completions without a reasoning block or without any comparable peer, and
    for calls with fewer than two completions.

    Args:
        prompts: Prompt per completion — assumed to repeat each prompt once
            per sampled generation; TODO confirm against the GRPO learner.
        completions: Generated texts.

    Returns:
        One score per completion.
    """
    n = len(completions)
    if n < 2:
        return [0] * n

    trace_re = re.compile(rf"{reasoning_start}(.+?){reasoning_end}", re.DOTALL)
    traces = []
    for c in completions:
        match = trace_re.search(c)
        traces.append(match.group(1) if match else "")

    # Group completion indices by prompt text; fall back to one global group.
    try:
        keys = [str(p) for p in prompts] if prompts is not None and len(prompts) == n else None
    except TypeError:
        keys = None
    if keys is None:
        keys = [""] * n
    groups = {}
    for index, key in enumerate(keys):
        groups.setdefault(key, []).append(index)

    scores = [0] * n
    for members in groups.values():
        for i in members:
            if not traces[i]:
                continue
            other_dists = [ncd(traces[i], traces[j]) for j in members if j != i and traces[j]]
            if other_dists:
                avg_dist = sum(other_dists) / len(other_dists)
                # Higher distance = more diverse = bonus
                scores[i] = max(0, (avg_dist - 0.5) * 2)
    return scores


# =============================================================================
# REWARD 8: Answer Completeness (from v1, enhanced)
# =============================================================================
def answer_completeness(prompts, completions, **kwargs):
    """Score the answer block by its whitespace-separated word count.

    3 to 100 words scores 2.0; 1 to 2 words or 101 to 200 words scores 1.0;
    anything else — including a missing or blank answer block — scores 0.
    """

    def _length_score(n_words):
        if 3 <= n_words <= 100:
            return 2.0
        # Reached only for counts below 3 or above 100.
        if 1 <= n_words <= 200:
            return 1.0
        return 0

    scores = []
    for completion in completions:
        found = re.search(rf"{solution_start}(.+?){solution_end}", completion, re.DOTALL)
        if found:
            scores.append(_length_score(len(found.group(1).strip().split())))
        else:
            scores.append(0)
    return scores


# Human-readable summary of the nine reward functions with their maximum
# contribution; the listed maxima add up to the ~16.0 total printed last.
print("PROMETHEUS v2.0 Reward Functions loaded:")
print("  1. Format exact (3.0)")
print("  2. Format approximate (+/-2.0)")
print("  3. Basin regularization (2.0) [NCD]")
print("  4. Causal emergence (1.5) [Ψ]")
print("  5. ToM-informed (1.0)")
print("  6. Coherence graph (2.0)")
print("  7. Kolmogorov compression (1.5)")
print("  8. Trace diversity (1.0) [NCD]")
print("  9. Answer completeness (2.0)")
print(f"\nMax possible reward: ~16.0")

In [None]:
# =============================================================================
# CELL 10: ADAPTIVE TEMPERATURE SCHEDULER (INSIGHT C3)
# =============================================================================

class AdaptiveTemperatureScheduler:
    """Sampling temperature that follows a cosine decay nudged by reward variance.

    The base temperature decays from ``temp_start`` to ``temp_end`` along a
    half cosine over ``max_steps``. On top of that, the ratio of the current
    reward-variance EMA to the variance of the first observed batch shifts the
    temperature by ``0.2 * (ratio - 1)``. The result is clipped to
    ``[temp_min, temp_max]``.

    Args:
        temp_start: Temperature at step 0.
        temp_end: Temperature once ``max_steps`` has been reached.
        ema_alpha: Weight of the newest batch variance in the EMA.
        temp_min: Lower clip bound; ``None`` means use the global ``TEMP_MIN``.
        temp_max: Upper clip bound; ``None`` means use the global ``TEMP_MAX``.
    """

    def __init__(self, temp_start=1.2, temp_end=0.5, ema_alpha=0.1,
                 temp_min=None, temp_max=None):
        self.temp_start = temp_start
        self.temp_end = temp_end
        self.ema_alpha = ema_alpha
        self.temp_min = temp_min
        self.temp_max = temp_max
        self.reward_variance_ema = None
        self.baseline_variance = None
        self.step = 0

    def update(self, batch_rewards):
        """Advance one step and fold the batch's reward variance into the EMA.

        The first call also records the batch variance as the baseline. A
        batch with fewer than two rewards contributes a variance of 0.
        """
        self.step += 1
        batch_var = np.var(batch_rewards) if len(batch_rewards) > 1 else 0

        if self.reward_variance_ema is None:
            self.reward_variance_ema = batch_var
            self.baseline_variance = batch_var
        else:
            self.reward_variance_ema = (
                self.ema_alpha * batch_var +
                (1 - self.ema_alpha) * self.reward_variance_ema
            )

    def get_temperature(self, max_steps):
        """Return the temperature for the current step as a Python float.

        Args:
            max_steps: Length of the decay schedule. Non-positive values are
                treated as an already finished schedule.
        """
        # Clamp progress: beyond max_steps the cosine would otherwise swing
        # back up towards temp_start, and max_steps == 0 would divide by zero.
        if max_steps > 0:
            progress = min(1.0, self.step / max_steps)
        else:
            progress = 1.0

        # Base: cosine decay
        base = self.temp_end + 0.5 * (self.temp_start - self.temp_end) * (
            1 + np.cos(np.pi * progress)
        )

        # Variance adjustment (skipped while no positive baseline exists)
        if self.baseline_variance and self.baseline_variance > 0:
            ratio = self.reward_variance_ema / self.baseline_variance
            adjustment = 0.2 * (ratio - 1)
        else:
            adjustment = 0

        lower = TEMP_MIN if self.temp_min is None else self.temp_min
        upper = TEMP_MAX if self.temp_max is None else self.temp_max
        return float(np.clip(base + adjustment, lower, upper))


# Module-level scheduler built with the constructor defaults.
# NOTE(review): nothing in the visible notebook calls temp_scheduler.update()
# or .get_temperature(); the rollout config uses the fixed TEMP_START — confirm
# whether the scheduler is meant to be wired into training.
temp_scheduler = AdaptiveTemperatureScheduler()
print("Adaptive Temperature Scheduler loaded.")

In [None]:
# =============================================================================
# CELL 11: EVALUATION HELPERS
# =============================================================================

def generate(question, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None,
             max_generation_steps=768):
    """Generate a completion for one question or for a batch of questions.

    Every question is wrapped in ``TEMPLATE`` with ``SYSTEM_PROMPT`` before
    being passed to the sampler.

    Args:
        question: A single question string, or an iterable of question strings.
        sampler: Callable sampler; its result must expose ``.text``.
        temperature: Sampling temperature.
        top_k: Top-k cutoff.
        top_p: Nucleus-sampling cutoff.
        seed: Sampling seed forwarded to the sampler.
        max_generation_steps: Maximum number of tokens to generate. The
            default of 768 is the value that used to be hard-coded.

    Returns:
        The completion string for a single question, otherwise the sampler's
        ``text`` sequence (one entry per question).
    """
    single = isinstance(question, str)
    questions = [question] if single else question
    input_batch = [
        TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=q) for q in questions
    ]

    out = sampler(
        input_strings=input_batch,
        max_generation_steps=max_generation_steps,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        echo=False,
        seed=seed,
        eos_tokens=[1, 106],
    )
    return out.text[0] if single else out.text


def evaluate_open_ended(dataset, sampler, **config):
    """Generate answers for every batch and report format / coherence metrics.

    Args:
        dataset: Iterable of batches; each batch must provide ``"question"``.
        sampler: Sampler forwarded to ``generate``.
        **config: Sampling keyword arguments forwarded to ``generate``
            (generation always uses ``seed=0``).

    Returns:
        A dict with ``total`` (number of responses), ``format_exact_pct``,
        ``has_reasoning_pct``, ``has_answer_pct`` (percentages of ``total``)
        and ``avg_coherence``: the mean number of distinct connectors among
        'because', 'therefore', 'so' found per reasoning trace. Connectors are
        matched as whole words, so "so" inside "reason" or "also" no longer
        inflates the score. All values are 0 when nothing was generated.
    """
    reasoning_re = re.compile(rf"{reasoning_start}(.+?){reasoning_end}", re.DOTALL)
    answer_re = re.compile(rf"{solution_start}(.+?){solution_end}", re.DOTALL)
    connector_res = [re.compile(rf"\b{word}\b") for word in ('because', 'therefore', 'so')]

    total = format_exact = has_reasoning = has_answer = 0
    coherence_scores = []

    for batch in tqdm(dataset):
        responses = generate(batch["question"], sampler, **config, seed=0)
        for r in responses:
            total += 1
            if match_format.search(r):
                format_exact += 1
            r_match = reasoning_re.search(r)
            if r_match:
                has_reasoning += 1
                trace = r_match.group(1).lower()
                coherence = sum(1 for c in connector_res if c.search(trace))
                coherence_scores.append(coherence)
            if answer_re.search(r):
                has_answer += 1

    return {
        "total": total,
        "format_exact_pct": format_exact / total * 100 if total else 0,
        "has_reasoning_pct": has_reasoning / total * 100 if total else 0,
        "has_answer_pct": has_answer / total * 100 if total else 0,
        "avg_coherence": np.mean(coherence_scores) if coherence_scores else 0,
    }

# Marks the end of the helper definitions in this cell.
print("Evaluation helpers loaded.")

In [None]:
# =============================================================================
# CELL 12: PRE-TRAINING EVALUATION
# =============================================================================

print("Creating sampler...")
# Sampler over the LoRA policy. The cache is sized for the longest prompt plus
# the rollout generation budget plus 256 extra slots.
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

# Baseline numbers before any training, using the "greedy" preset; kept in
# base_results for the before/after comparison printed after training.
print("\nEvaluating BASE model...")
base_results = evaluate_open_ended(test_dataset, sampler, **GENERATION_CONFIGS["greedy"])
print(f"Format: {base_results['format_exact_pct']:.1f}%")
print(f"Reasoning: {base_results['has_reasoning_pct']:.1f}%")
print(f"Coherence: {base_results['avg_coherence']:.2f}")

In [None]:
# =============================================================================
# CELL 13: GRPO TRAINING SETUP
# =============================================================================

# Save a checkpoint every SAVE_INTERVAL_STEPS steps and retain at most
# MAX_TO_KEEP of them.
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
)

# AdamW whose learning rate warms up from 0 to LEARNING_RATE and then follows
# a cosine decay back to 0 at MAX_STEPS.
# NOTE(review): WARMUP_STEPS is a float here — confirm optax accepts that.
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1, b2=B2, weight_decay=WEIGHT_DECAY,
)

# Global-norm clipping is chained in front of AdamW, so gradients are clipped
# before the optimizer update is computed.
if MAX_GRAD_NORM:
    optimizer = optax.chain(optax.clip_by_global_norm(MAX_GRAD_NORM), optimizer)

# Actor, reference and rollout all share the same device mesh.
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        metrics_logging_options=None,
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    # Rollout sampling uses the fixed TEMP_START; the KV cache size uses the
    # same formula as the evaluation sampler's cache.
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMP_START,
        top_p=TOP_P,
        top_k=TOP_K,
        eos_tokens=[1, 106],
    ),
)

grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

print("GRPO config ready.")

In [None]:
# =============================================================================
# CELL 14: INITIALIZE GRPO TRAINER WITH v2.0 REWARDS
# =============================================================================

# The LoRA policy is the trainable actor; the model it was built from serves
# as the reference.
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=ref_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

# PROMETHEUS v2.0 Reward Stack
# Each entry is called as fn(prompts, completions, **kwargs) and returns one
# score per completion; the trailing comments give each function's maximum.
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        match_format_exactly,          # 3.0
        match_format_approximately,    # +/-2.0
        basin_regularization_reward,   # 2.0 [NCD]
        causal_emergence_reward,       # 1.5 [Ψ]
        tom_informed_reward,           # 1.0
        coherence_graph_reward,        # 2.0
        kolmogorov_compression_reward, # 1.5
        trace_diversity_bonus,         # 1.0 [NCD]
        answer_completeness,           # 2.0
    ],
    grpo_config=grpo_config,
)

print("GRPO trainer initialized with PROMETHEUS v2.0 rewards.")
print(f"9 reward functions, max ~16.0")

In [None]:
# =============================================================================
# CELL 15: TRAIN!
# =============================================================================

print(f"Starting GRPO training for {MAX_STEPS} steps...")
print(f"This may take several hours on TPU v5e-8.")
print()

# Run the whole training loop inside the device-mesh context.
with mesh:
    grpo_trainer.train(train_dataset)

print("\n" + "="*60)
print("PROMETHEUS v2.0 TRAINING COMPLETE!")
print("="*60)

In [None]:
# =============================================================================
# CELL 16: LOAD LATEST CHECKPOINT & EVALUATE
# =============================================================================
# The checkpoint is chosen by highest step number, not by any quality metric.

# Find the largest purely numeric sub-directory name under <CKPT_DIR>/actor.
actor_ckpt_dir = os.path.join(CKPT_DIR, "actor")
latest_step = -1
if os.path.exists(actor_ckpt_dir):
    for item in os.listdir(actor_ckpt_dir):
        if os.path.isdir(os.path.join(actor_ckpt_dir, item)) and re.match(r'^\d+$', item):
            step = int(item)
            if step > latest_step:
                latest_step = step

# -1 means the directory was missing or held no step folders.
if latest_step == -1:
    raise FileNotFoundError(f"No checkpoints found")

print(f"Loading checkpoint from step {latest_step}...")

trained_ckpt_path = os.path.join(CKPT_DIR, "actor", str(latest_step), "model_params")
# Restore target: shape/dtype of the policy's current LoRA parameters.
abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

# Overwrite every LoRA leaf of the policy with its restored value
# (`lambda a, b: b` keeps the second tree's leaf).
nnx.update(
    lora_policy,
    jax.tree.map(lambda a, b: b, nnx.state(lora_policy, nnx.LoRAParam), trained_lora_params),
)

# Build a fresh sampler over the updated policy, configured as before training.
sampler = sampler_lib.Sampler(
    transformer=lora_policy, tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

print("\nEvaluating TRAINED model...")
trained_results = evaluate_open_ended(test_dataset, sampler, **GENERATION_CONFIGS["greedy"])

# Before/after comparison against the baseline stored in base_results.
print("\n" + "="*50)
print("IMPROVEMENT SUMMARY")
print("="*50)
print(f"Format: {base_results['format_exact_pct']:.1f}% -> {trained_results['format_exact_pct']:.1f}%")
print(f"Reasoning: {base_results['has_reasoning_pct']:.1f}% -> {trained_results['has_reasoning_pct']:.1f}%")
print(f"Coherence: {base_results['avg_coherence']:.2f} -> {trained_results['avg_coherence']:.2f}")

In [None]:
# =============================================================================
# CELL 17: SAVE MODEL
# =============================================================================

SAVE_PATH = "/kaggle/working/prometheus_v2_gemma_reasoning"

print(f"Saving trained LoRA weights to {SAVE_PATH}...")
os.makedirs(SAVE_PATH, exist_ok=True)
# Only the LoRA parameters are written, not the base model weights.
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(
    os.path.join(SAVE_PATH, "lora_params"),
    nnx.state(lora_policy, nnx.LoRAParam)
)
# Block until the save has completed before reporting success.
checkpointer.wait_until_finished()

print("\n" + "="*60)
print("PROMETHEUS v2.0 + TUNIX TRAINING COMPLETE")
print("="*60)
print(f"Model saved to: {SAVE_PATH}")
print(f"\nKey innovations applied:")
print(f"  - NCD Basin Regularization")
print(f"  - Causal Emergence (Ψ) Rewards")
print(f"  - ToM-Informed Output Shaping")
print(f"  - Coherence Graph Scoring")
print(f"  - Kolmogorov Compression")
print(f"  - NCD Trace Diversity")