# Notebook 3: Train Sparsity-Regularized Autoencoder

Loads the hidden states collected in Notebook 2 from Google Drive,
trains the sparsity-regularized autoencoder on them, extracts the
binary Jacobian pattern B(J_f), and saves the trained model, the B
matrix, and the loss history back to Drive.

**No LLM needed. Estimated time: ~30 minutes on T4.**

In [None]:
!git clone https://github.com/AUMEZAK/thoughtcomm.git 2>/dev/null || echo 'Already cloned'
%cd thoughtcomm
!pip install -e . -q

from google.colab import drive
drive.mount('/content/drive')
SAVE_DIR = '/content/drive/MyDrive/thoughtcomm_checkpoints/'

In [None]:
import torch
import os
import matplotlib.pyplot as plt
import numpy as np
from configs.config import ThoughtCommConfig
from training.train_autoencoder import train_autoencoder
from training.jacobian_utils import compute_binary_pattern

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Model configuration - this choice must match the one made in Notebook 2.
# To use the other model, comment out the active line and enable this one:
# config = ThoughtCommConfig.for_phi4_mini(device=device)
config = ThoughtCommConfig.for_qwen_0_6b(device=device)

# Final path component of the model name (text after the last '/').
MODEL_TAG = config.model_name.rsplit('/', 1)[-1]

In [None]:
# Read back the hidden-state dump stored on Drive for this model tag.
math_dir = os.path.join(SAVE_DIR, f'{MODEL_TAG}_math')
math_path = os.path.join(math_dir, 'hidden_states.pt')

# Load onto the CPU regardless of the device the file was saved from.
data = torch.load(math_path, map_location='cpu')
H_train = data['H']
metadata = data['metadata']

# Report what was loaded alongside the width the config expects.
print(f'H_train shape: {H_train.shape}')
print(f'Metadata entries: {len(metadata)}')
print(f'Expected n_h: {config.n_h}')

# Sanity check: the second dimension of H must equal config.n_h.
assert H_train.shape[1] == config.n_h, f'Shape mismatch: {H_train.shape[1]} != {config.n_h}'

In [None]:
# Fit the autoencoder on the loaded hidden states with progress output on.
# The result is unpacked into the model and the loss history that the
# plotting cell indexes by 'rec', 'jac' and 'total'.
training_output = train_autoencoder(H_train, config, verbose=True)
ae_model, loss_history = training_output

In [None]:
# Draw one panel per tracked loss term, side by side.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# (loss_history key, panel title, use a log-scaled y axis?)
panel_specs = (
    ('rec', 'Reconstruction Loss', True),
    ('jac', 'Jacobian L1', False),
    ('total', 'Total Loss', True),
)

for panel, (loss_key, panel_title, log_y) in zip(axes, panel_specs):
    panel.plot(loss_history[loss_key])
    panel.set_title(panel_title)
    panel.set_xlabel('Epoch')
    if log_y:
        panel.set_yscale('log')

plt.tight_layout()
# Write the figure to disk before showing it inline.
plt.savefig('ae_training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Derive the binary pattern B from the decoder, evaluated at encoded samples.
print('Computing full Jacobian and B matrix...')

# Encode the first 64 training rows; no gradients are needed for this step.
with torch.no_grad():
    encoder_input = H_train[:64].float().to(device)
    Z_sample = ae_model.encode(encoder_input)

# Options forwarded to the pattern computation.
pattern_options = dict(
    threshold=config.jacobian_threshold,
    sub_batch=8,
    device=device,
)
B = compute_binary_pattern(ae_model.decoder, Z_sample, **pattern_options)

# Summary: shape, fraction of zero entries, and count of non-zero entries.
print(f'B shape: {B.shape}')
print(f'B sparsity: {1 - B.float().mean():.3f}')
print(f'B non-zero entries: {B.sum().item()} / {B.numel()}')

In [None]:
# Visualize B matrix as an image: rows are hidden-state dimensions and
# columns are latent dimensions (see the axis labels set below).
plt.figure(figsize=(10, 6))
# Tensor.numpy() only accepts CPU tensors without autograd history, and B was
# computed with `device=device`, so bring it to host memory before plotting.
# Both calls are cheap when B is already a plain CPU tensor.
plt.imshow(B.detach().cpu().numpy(), cmap='Blues', aspect='auto')
plt.colorbar(label='Dependency')

# Mark agent boundaries: a dashed line between consecutive blocks of
# `config.hidden_size` rows. The -0.5 places the line on the edge between
# two image rows rather than through the middle of one.
for k in range(1, config.num_agents):
    plt.axhline(y=k * config.hidden_size - 0.5, color='red', linewidth=2, linestyle='--')

plt.xlabel('Latent thought dimensions', fontsize=12)
plt.ylabel('Agent hidden state dimensions', fontsize=12)
plt.title('Binary Jacobian Pattern B(J_f)', fontsize=14)

# Add agent labels, vertically centred on each agent's block of rows.
# NOTE(review): x=-30 is a fixed offset in data coordinates (left of column
# 0), so the label position relative to the plot depends on how many latent
# columns B has - confirm it looks right for the chosen config.
for k in range(config.num_agents):
    y_pos = k * config.hidden_size + config.hidden_size // 2
    plt.text(-30, y_pos, f'Agent {k+1}', fontsize=11, va='center', fontweight='bold')

# bbox_inches='tight' keeps the labels drawn outside the axes in the file.
plt.savefig('b_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Build the agreement reweighter from B and print the statistics it reports.
from pipeline.agreement import AgreementReweighter

reweighter = AgreementReweighter(B, config)
stats = reweighter.get_agreement_stats()

print('Agreement structure:')
# One line per entry of the 'agreement_distribution' mapping.
for group_name, dim_count in stats['agreement_distribution'].items():
    print(f'  {group_name}: {dim_count} dimensions')

print(f'Per-agent relevant dims: {stats["per_agent_relevant"]}')

In [None]:
# Persist the training outputs to a per-model folder under SAVE_DIR.
ae_save_dir = os.path.join(SAVE_DIR, f'{MODEL_TAG}_ae')
os.makedirs(ae_save_dir, exist_ok=True)

# (file name, object) pairs, written in this order.
artifacts = (
    ('ae_model.pt', ae_model.state_dict()),
    ('B_matrix.pt', B),
    ('loss_history.pt', loss_history),
)
for file_name, payload in artifacts:
    torch.save(payload, os.path.join(ae_save_dir, file_name))

print(f'Saved to {ae_save_dir}')