# Notebook 2: Collect Hidden States from Multi-Agent Debate

Runs multi-agent debate on the MATH and GSM8K training problems (`config.num_train` per dataset, e.g. 500) and collects
the concatenated hidden states H_t from all agents.

**Checkpoints every 50 examples to Google Drive for session recovery.**

Estimated time:
- Qwen-3-0.6B: ~2-4 hours on T4
- Phi-4-mini: ~4-8 hours on A100

In [None]:
# Setup
!git clone https://github.com/AUMEZAK/thoughtcomm.git 2>/dev/null || echo 'Already cloned'
%cd thoughtcomm
!pip install -e . -q

In [None]:
import os

from google.colab import drive

# Checkpoints live on Google Drive so they persist across Colab sessions.
SAVE_DIR = '/content/drive/MyDrive/thoughtcomm_checkpoints/'
drive.mount('/content/drive')
# Idempotent: does nothing if the checkpoint directory already exists.
os.makedirs(SAVE_DIR, exist_ok=True)

In [None]:
# Imports. torch is used directly in the verification cell; the remaining
# modules come from the thoughtcomm repo installed with `pip install -e .`.
import torch
from configs.config import ThoughtCommConfig
from models.model_utils import load_model_and_tokenizer
from pipeline.debate import MultiAgentDebate
from pipeline.hidden_state_collector import HiddenStateCollector
from data.math_data import load_math_dataset
from data.gsm8k_data import load_gsm8k_dataset
from utils.memory import print_memory_stats

# Baseline memory report, for comparison with the later 'After ...' reports.
print_memory_stats('Initial: ')

In [None]:
# Pick exactly one model configuration (keep the other line commented out).
config = ThoughtCommConfig.for_qwen_0_6b()
# config = ThoughtCommConfig.for_phi4_mini()

# Short tag (last path segment of the model name) used in checkpoint dir names.
MODEL_TAG = config.model_name.rsplit('/', 1)[-1]

print('Model:', config.model_name)
print('Hidden size:', config.hidden_size)
print('n_h (3 agents):', config.n_h)

In [None]:
# Load the chosen model and its tokenizer in the configured dtype, then report
# memory so the model's footprint can be compared with the initial baseline.
model, tokenizer = load_model_and_tokenizer(config.model_name, dtype=config.dtype)
print_memory_stats('After model load: ')

In [None]:
# Load train/eval splits for both benchmarks; sizes come from the config.
math_train, math_eval = load_math_dataset(
    level=config.math_level,
    num_train=config.num_train,
    num_eval=config.num_eval,
)
gsm8k_train, gsm8k_eval = load_gsm8k_dataset(
    num_eval=config.num_eval,
    num_train=config.num_train,
)

for label, train_split, eval_split in (
    ('MATH', math_train, math_eval),
    ('GSM8K', gsm8k_train, gsm8k_eval),
):
    print(f'{label} train: {len(train_split)}, eval: {len(eval_split)}')

In [None]:
# Quick smoke test: run a single debate end to end before the long collection
# runs. `debate` is reused by the collection cells below.
debate = MultiAgentDebate(model, tokenizer, config)
test_q = math_train[0]['question']
print(f'Test question: {test_q[:100]}...')

responses, hidden = debate.run_debate(test_q, extract_hidden=True)

# responses / hidden appear to be indexed [round][agent] (inferred from the
# loops below -- TODO confirm against MultiAgentDebate.run_debate).
print('\nRound 0 responses (first 200 chars each):')
for agent_idx, response in enumerate(responses[0]):
    print(f'  Agent {agent_idx}: {response[:200]}...')

print('\nHidden state shapes:')
for round_idx, round_states in enumerate(hidden):
    for agent_idx, state in enumerate(round_states):
        print(f'  Round {round_idx}, Agent {agent_idx}: {state.shape}')

print_memory_stats('After test: ')

In [None]:
# Collect H_t for the MATH training split. Checkpoints go to a per-model
# subdirectory on Google Drive every 50 problems (for session recovery).
math_save_dir = os.path.join(SAVE_DIR, f'{MODEL_TAG}_math')
collector = HiddenStateCollector(debate, config)

H_math, meta_math = collector.collect(
    math_train,
    checkpoint_every=50,
    save_dir=math_save_dir,
)
print(f'MATH hidden states: {H_math.shape}')

In [None]:
# Collect hidden states for GSM8K training set
gsm8k_save_dir = os.path.join(SAVE_DIR, f'{MODEL_TAG}_gsm8k')
H_gsm8k, meta_gsm8k = collector.collect(
    gsm8k_train, save_dir=gsm8k_save_dir, checkpoint_every=50
)
print(f'GSM8K hidden states: {H_gsm8k.shape}')

In [None]:
# Sanity-check the collected hidden states: shape, summary statistics,
# non-finite values, and whether two different samples actually differ.
for name, H in [('MATH', H_math), ('GSM8K', H_gsm8k)]:
    # H may be half precision (the model is loaded with config.dtype -- TODO
    # confirm the collector's output dtype). Reduce in float32 to avoid
    # overflow/precision loss and unsupported half-precision CPU kernels.
    H32 = H.float()
    print(f'{name}:')
    print(f'  Shape: {H.shape}')
    print(f'  Mean: {H32.mean().item():.4f}, Std: {H32.std().item():.4f}')
    # .item() prints plain True/False instead of tensor(False).
    print(f'  NaN: {H.isnan().any().item()}, Inf: {H.isinf().any().item()}')
    # Cosine similarity between the first two samples. Each sample is
    # flattened, so this also works if samples are more than 1-D.
    if H32.shape[0] < 2:
        print('  Cosine sim (sample 0 vs 1): skipped, fewer than 2 samples')
        continue
    cos_sim = torch.nn.functional.cosine_similarity(
        H32[0].reshape(-1), H32[1].reshape(-1), dim=0
    ).item()
    print(f'  Cosine sim (sample 0 vs 1): {cos_sim:.4f}')