In [None]:
# Save summary and push to GitHub
import json

summary_04 = {
    'model': config.model_name,
    'adapter_params': sum(p.numel() for p in adapter.parameters()),
    'final_loss': float(loss_hist[-1]),
    'agreement_weights': reweighter.w.data.tolist(),
    'num_epochs': len(loss_hist),
}

os.makedirs('results', exist_ok=True)
with open(f'results/04_adapter_summary_{MODEL_TAG}.json', 'w') as f:
    json.dump(summary_04, f, indent=2)

!cp adapter_training_loss.png results/adapter_training_loss_{MODEL_TAG}.png 2>/dev/null || true

!git pull --rebase 2>/dev/null || true
!git add results/
!git commit -m "Add Notebook 04 results: adapter training ({MODEL_TAG})"
!git push

print('Results pushed to GitHub!')
print('(Adapter weights are on Google Drive)')

# Notebook 4: Train Prefix Adapter

Loads the LLM, trained autoencoder, and hidden states, then trains
the prefix adapter and agreement weights.

**Estimated time: ~1-2 hours on T4/A100**

In [None]:
# Setup
import os
try:
    from google.colab import userdata
    GITHUB_TOKEN = userdata.get('GITHUB_TOKEN')
    REPO_URL = f'https://{GITHUB_TOKEN}@github.com/AUMEZAK/thoughtcomm.git'
except Exception:
    GITHUB_TOKEN = None
    REPO_URL = 'https://github.com/AUMEZAK/thoughtcomm.git'

!git clone {REPO_URL} thoughtcomm 2>/dev/null || echo 'Already cloned'
%cd thoughtcomm
!pip install -e . -q

!git config user.email "colab@thoughtcomm.dev"
!git config user.name "ThoughtComm Colab"

from google.colab import drive
drive.mount('/content/drive')
SAVE_DIR = '/content/drive/MyDrive/thoughtcomm_checkpoints/'

In [None]:
import torch
import os
from configs.config import ThoughtCommConfig
from models.model_utils import load_model_and_tokenizer
from models.autoencoder import SparsityRegularizedAE
from models.prefix_adapter import PrefixAdapter
from pipeline.agreement import AgreementReweighter
from training.train_adapter import train_adapter
from utils.memory import print_memory_stats

# Use the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Active model preset. To switch models, swap in the alternative line below.
config = ThoughtCommConfig.for_qwen_0_6b(device=device)
# config = ThoughtCommConfig.for_phi4_mini(device=device)

# Short tag for checkpoint/result paths: the model name after its last '/'.
MODEL_TAG = config.model_name.rsplit('/', 1)[-1]

In [None]:
# Load the LLM together with its tokenizer, then report memory usage
model, tokenizer = load_model_and_tokenizer(
    config.model_name,
    dtype=config.dtype,
)
print_memory_stats('After model load: ')

In [None]:
# Restore the trained autoencoder and the B matrix from the checkpoint folder
ae_dir = os.path.join(SAVE_DIR, f'{MODEL_TAG}_ae')

# Rebuild the architecture from the config before loading the weights into it.
ae_model = SparsityRegularizedAE(
    n_h=config.n_h,
    n_z=config.n_z,
    hidden_dim=config.ae_hidden,
    num_layers=config.ae_num_layers,
)
# Deserialize onto the CPU first, then move the model to the target device.
ae_model.load_state_dict(
    torch.load(os.path.join(ae_dir, 'ae_model.pt'), map_location='cpu')
)
ae_model = ae_model.to(device)

B = torch.load(os.path.join(ae_dir, 'B_matrix.pt'), map_location='cpu')
print(f'AE loaded. B shape: {B.shape}')

In [None]:
# Read the saved hidden states and their metadata onto the CPU
math_data = torch.load(
    os.path.join(SAVE_DIR, f'{MODEL_TAG}_math', 'hidden_states.pt'),
    map_location='cpu',
)
# The checkpoint is a mapping holding the states under 'H' and the metadata under 'metadata'.
H_train, metadata = math_data['H'], math_data['metadata']
print(f'H_train: {H_train.shape}, metadata: {len(metadata)} entries')

In [None]:
# Build the agreement reweighter (from B) and the prefix adapter (from the config)
reweighter = AgreementReweighter(B, config)

adapter = PrefixAdapter(
    n_z=config.n_z, hidden_size=config.hidden_size,
    prefix_length=config.prefix_length, adapter_hidden=config.adapter_hidden,
)

# Quick sanity report: adapter size and the reweighter's starting state.
print(f'Adapter params: {sum(p.numel() for p in adapter.parameters()):,}')
print(f'Agreement weights: {reweighter.w}')
print(reweighter.get_agreement_stats())

In [None]:
# Run adapter training; returns the trained adapter, the reweighter and the loss history
adapter, reweighter, loss_hist = train_adapter(
    model,
    tokenizer,
    ae_model,
    reweighter,
    adapter,
    H_train,
    metadata,
    config,
    verbose=True,
)

In [None]:
# Plot adapter training loss
import matplotlib.pyplot as plt

# Explicit figure/axes objects instead of the implicit pyplot state machine.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(loss_hist)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Adapter Training Loss (L_comm)')
ax.grid(True, alpha=0.3)
# Written to the working directory; the save-and-push cell copies it into results/.
fig.savefig('adapter_training_loss.png', dpi=150, bbox_inches='tight')
plt.show()

print(f'Learned agreement weights: {reweighter.w.data}')

In [None]:
# Persist the adapter weights, reweighter state and loss history to the checkpoint folder
adapter_dir = os.path.join(SAVE_DIR, f'{MODEL_TAG}_adapter')
os.makedirs(adapter_dir, exist_ok=True)

# One file per artifact, written in the same order as before.
for filename, payload in (
    ('adapter.pt', adapter.state_dict()),
    ('reweighter.pt', reweighter.state_dict()),
    ('adapter_loss.pt', loss_hist),
):
    torch.save(payload, os.path.join(adapter_dir, filename))

print(f'Saved to {adapter_dir}')

## Push Results to GitHub