# Training Probes on NDIF: Remote Training Loops

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nix07/neural-mechanics-web/blob/main/labs/week6/probe_training_ndif.ipynb)

This notebook demonstrates **remote probe training** using [nnsight](https://nnsight.net/) and [NDIF](https://ndif.us/). The **entire training loop** - all epochs, all batches - runs on NDIF using session-based execution.

**Why Remote Training?**
- No need to download large activation tensors
- All computation happens on NDIF GPUs
- Essential for training probes on very large models

**Session-Based Training:**
Using `model.session()` and `session.iter()`, we run the complete training loop remotely:
- Probe weights are created and updated on NDIF
- Optimizer runs on NDIF
- Only the final trained weights are downloaded
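
The loop that the session executes remotely is an ordinary PyTorch training loop. A minimal local sketch of the same structure, assuming synthetic activations in place of Llama hidden states (the sizes, seed, and learning rate below are illustrative choices, not values taken from NDIF):

```python
import torch

torch.manual_seed(0)

# Synthetic stand-in for last-token activations (assumption: 40 examples, 16 dims).
hidden_size = 16
direction = torch.randn(hidden_size)
acts = torch.randn(40, hidden_size)
labels = (acts @ direction > 0).float()  # labels that a linear probe can fit

# Probe parameters and optimizer, mirroring what the session creates remotely.
probe_w = torch.nn.Parameter(torch.randn(1, hidden_size) * 0.01)
probe_b = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([probe_w, probe_b], lr=0.1)

batch_size = 8
epoch_losses = []
for epoch in range(20):
    total, n_batches = 0.0, 0
    for start in range(0, len(acts), batch_size):
        x = acts[start:start + batch_size]
        y = labels[start:start + batch_size]
        # Probe forward: logits = x @ W^T + b
        logits = (x @ probe_w.T).squeeze(-1) + probe_b
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
        n_batches += 1
    epoch_losses.append(total / n_batches)

# Accuracy of the trained probe on the same synthetic data
with torch.no_grad():
    preds = ((acts @ probe_w.T).squeeze(-1) + probe_b > 0).float()
    train_accuracy = (preds == labels).float().mean().item()

print(f"first epoch loss={epoch_losses[0]:.3f}, last epoch loss={epoch_losses[-1]:.3f}")
print(f"train accuracy={train_accuracy:.1%}")
```

In the remote version, the two `for` loops become `session.iter(...)` contexts and `acts` comes from a model trace, but the parameter, optimizer, and loss logic is the same.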

We'll train probes to detect **puns** from Llama 3 70B activations, entirely remotely.

## References
- [nnsight documentation](https://nnsight.net/)
- [NDIF - National Deep Inference Fabric](https://ndif.us/)
- [Designing and Interpreting Probes with Control Tasks](https://arxiv.org/abs/1909.03368) - Hewitt & Liang (2019)

## Setup

In [None]:
!pip install -q nnsight

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import nnsight
from nnsight import LanguageModel, CONFIG
from torch.utils.data import DataLoader
from collections import defaultdict

# Configure NDIF API key from Colab secrets
try:
    from google.colab import userdata
    CONFIG.set_default_api_key(userdata.get('NDIF_API'))
except Exception:
    pass  # Not in Colab, or the NDIF_API secret is not set

# Use remote=True to run on NDIF
REMOTE = True

# For reproducibility
torch.manual_seed(42)
np.random.seed(42)

## Load Model

In [None]:
model = LanguageModel("meta-llama/Meta-Llama-3-70B", device_map="auto")

print(f"Model: {model.config._name_or_path}")
print(f"Layers: {model.config.num_hidden_layers}")
print(f"Hidden size: {model.config.hidden_size}")

## Part 1: Prepare Dataset

We need paired pun and non-pun examples for probe training.

In [None]:
# Training data: puns and non-puns
train_puns = [
    "Why do electricians make good swimmers? Because they know the current.",
    "Why did the banker break up with his girlfriend? He lost interest.",
    "What do you call a fish without eyes? A fsh.",
    "Why don't scientists trust atoms? Because they make up everything.",
    "What did the ocean say to the beach? Nothing, it just waved.",
    "Why do cows wear bells? Because their horns don't work.",
    "I used to hate facial hair, but then it grew on me.",
    "Why did the scarecrow win an award? He was outstanding in his field.",
    "What do you call a bear with no teeth? A gummy bear.",
    "Why can't a bicycle stand on its own? It's two tired.",
    "What do you call a fake noodle? An impasta.",
    "Why did the math book look so sad? It had too many problems.",
    "What do you call a sleeping dinosaur? A dino-snore.",
    "Why did the golfer bring two pairs of pants? In case he got a hole in one.",
    "What did the grape say when stepped on? Nothing, it let out a little wine.",
]

train_nonpuns = [
    "Why do electricians wear rubber gloves? To protect from electrical shocks.",
    "Why did the banker open a savings account? To manage his finances better.",
    "What do you call a fish that lives in rivers? A freshwater fish.",
    "Why don't scientists make assumptions? They need empirical evidence.",
    "What did the ocean look like today? Calm and peaceful.",
    "Why do cows produce milk? To nourish their calves.",
    "I used to avoid exercise, but then I started a routine.",
    "Why did the scarecrow need repairs? It was damaged by weather.",
    "What do you call a bear in winter? A hibernating animal.",
    "Why can't a bicycle go uphill easily? The gradient is steep.",
    "What do you call a fresh noodle? Al dente pasta.",
    "Why did the math book look worn? It had been used for years.",
    "What do you call a prehistoric reptile? A dinosaur.",
    "Why did the golfer check the weather? To plan his game.",
    "What did the grape taste like? Sweet and refreshing.",
]

# Test data (held out)
test_puns = [
    "Why do programmers prefer dark mode? Because light attracts bugs.",
    "What do you call a lazy kangaroo? A pouch potato.",
    "Why did the stadium get hot? All the fans left.",
    "What do you call a pig that does karate? A pork chop.",
    "Why did the coffee file a police report? It got mugged.",
]

test_nonpuns = [
    "Why do programmers use version control? To track code changes.",
    "What do you call a baby kangaroo? A joey.",
    "Why did the stadium close early? For maintenance.",
    "What do you call a pig on a farm? Livestock.",
    "Why did the coffee taste bitter? It was over-extracted.",
]

print(f"Training: {len(train_puns)} puns, {len(train_nonpuns)} non-puns")
print(f"Test: {len(test_puns)} puns, {len(test_nonpuns)} non-puns")

## Part 2: Remote Probe Training with Sessions

Using nnsight's `session` and `iter` APIs, we run the **entire training loop** on NDIF:

1. `model.session()` creates a persistent remote context
2. Probe weights and optimizer are created on NDIF
3. `session.iter()` iterates over epochs and batches remotely
4. Only the final trained weights are downloaded

### Data Loading Options

**Option 1: In-memory dataset** (used in this notebook)
- Data is serialized and sent to NDIF as part of the computation graph
- Good for small datasets (< 1000 examples)
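
To get a feel for how much data the in-memory option ships, the sketch below estimates the payload size of a dataset. It uses JSON purely as a rough proxy; the actual wire format NDIF uses is not assumed here, and `approx_payload_bytes` is a hypothetical helper, not part of nnsight.

```python
import json

def approx_payload_bytes(texts, labels):
    """Rough size of a dataset serialized as JSON (a proxy, not NDIF's actual format)."""
    payload = json.dumps({"texts": texts, "labels": labels})
    return len(payload.encode("utf-8"))

# 1000 copies of one short training sentence, as an illustrative upper bound
sample_texts = ["What do you call a fake noodle? An impasta."] * 1000
sample_labels = [1.0] * 1000
size_bytes = approx_payload_bytes(sample_texts, sample_labels)
print(f"~{size_bytes / 1024:.1f} KiB for {len(sample_texts)} examples")
```

At this scale the text is tiny compared with the activation tensors it replaces, which is why sending the raw dataset is reasonable for small experiments.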

**Option 2: Load from GitHub raw URL**
- Reference CSV/JSON files directly from a public GitHub repo
- NDIF fetches the data server-side

```python
from datasets import load_dataset

# Load CSV directly from GitHub raw URL
github_url = "https://raw.githubusercontent.com/your-org/your-repo/main/data/puns.csv"
dataset = load_dataset("csv", data_files={"train": github_url})

# Or JSON
dataset = load_dataset("json", data_files={
    "train": "https://raw.githubusercontent.com/your-org/repo/main/train.json",
    "test": "https://raw.githubusercontent.com/your-org/repo/main/test.json"
})
```

**Option 3: HuggingFace Hub dataset**
- Upload your dataset to the HuggingFace Hub; NDIF then loads it server-side

```python
# Upload local dataset to HuggingFace Hub
from datasets import Dataset, load_dataset
dataset = Dataset.from_dict({"text": train_texts, "label": train_labels})
dataset.push_to_hub("your-username/pun-dataset", private=True)

# Then NDIF loads it directly:
dataset = load_dataset("your-username/pun-dataset")
```

In [None]:
class RemoteProbeTrainer:
    """
    Train a linear probe entirely on NDIF using nnsight sessions.
    
    The entire training loop (all epochs) runs remotely.
    Only the final trained weights are downloaded.
    """
    
    def __init__(self, model, layer_idx, hidden_size, lr=0.01):
        self.model = model
        self.layer_idx = layer_idx
        self.hidden_size = hidden_size
        self.lr = lr
        
        # Probe weights (will be created remotely during training)
        self.probe_weight = None
        self.probe_bias = None
    
    def train(self, texts, labels, n_epochs=10, batch_size=5, remote=True):
        """
        Train probe with ENTIRE loop on NDIF using session + iter.
        
        Args:
            texts: List of input texts
            labels: List of labels (0 or 1)
            n_epochs: Number of training epochs
            batch_size: Examples per batch
            remote: Whether to run on NDIF
            
        Returns:
            losses: List of per-epoch losses
        """
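        # Prepare dataset as list of (text, label) tuples.
        # collate_fn=list keeps each batch as a list of (text, label) tuples;
        # the default collate would regroup it into [texts, labels].
        dataset = list(zip(texts, labels))
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=list)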
        epoch_indices = list(range(n_epochs))
        
        print(f"Training {n_epochs} epochs on {'NDIF' if remote else 'local'}...")
        
        # Run ENTIRE training loop on NDIF
        with self.model.session(remote=remote) as session:
            
            # Initialize probe weights remotely
            probe_w = torch.nn.Parameter(torch.randn(1, self.hidden_size) * 0.01)
            probe_b = torch.nn.Parameter(torch.zeros(1))
            
            # Optimizer runs on remote device
            optimizer = torch.optim.AdamW([probe_w, probe_b], lr=self.lr)
            
            # Storage for losses
            all_losses = []
            
            # Iterate over epochs using session.iter
            with session.iter(epoch_indices) as epoch:
                epoch_loss = torch.tensor(0.0)
                n_batches = 0
                
                # Iterate over batches
                with session.iter(dataloader) as batch:
                    batch_texts = [item[0] for item in batch]
                    batch_labels = torch.tensor([item[1] for item in batch], dtype=torch.float32)
                    
                    # Forward pass through model
                    with self.model.trace(batch_texts):
                        # Get hidden states at target layer, last token
                        hidden_states = self.model.model.layers[self.layer_idx].output[0]
                        last_hidden = hidden_states[:, -1, :]  # (batch, hidden)
                        
                        # Move probe to correct device
                        w = probe_w.to(last_hidden.device)
                        b = probe_b.to(last_hidden.device)
                        
                        # Probe forward: logits = hidden @ W^T + b
                        logits = (last_hidden @ w.T).squeeze(-1) + b
                        
                        # Binary cross-entropy loss
                        labels_dev = batch_labels.to(logits.device)
                        loss = torch.nn.functional.binary_cross_entropy_with_logits(
                            logits, labels_dev
                        )
                        loss.backward()
                        
                        epoch_loss = epoch_loss.to(loss.device) + loss
                        n_batches += 1
                    
                    # Update parameters
                    optimizer.step()
                    optimizer.zero_grad()
                
                # Log epoch loss
                avg_loss = epoch_loss / max(n_batches, 1)
                nnsight.log("Epoch loss:", avg_loss)
                all_losses.append(avg_loss)
            
            # Save final weights - THIS is what we download
            final_w = probe_w.detach().clone().save()
            final_b = probe_b.detach().clone().save()
            final_losses = [l.save() for l in all_losses]
        
        # Store trained weights locally
        self.probe_weight = final_w.value.cpu()
        self.probe_bias = final_b.value.cpu()
        losses = [l.value.item() for l in final_losses]
        
        print("Training complete!")
        return losses
    
    def train_step(self, texts, labels, remote=True):
        """
        Single training step (for compatibility).
        Downloads gradients and updates locally.
        """
        if self.probe_weight is None:
            self.probe_weight = torch.randn(1, self.hidden_size) * 0.01
            self.probe_bias = torch.zeros(1)
        
        labels_tensor = torch.tensor(labels, dtype=torch.float32)
        
        with self.model.trace(texts, remote=remote) as tracer:
            hidden_states = self.model.model.layers[self.layer_idx].output[0]
            last_hidden = hidden_states[:, -1, :]
            
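            # Fresh leaf copies on the activation device, so .grad is populated
            # and the locally stored weights carry no autograd state
            w = self.probe_weight.detach().clone().to(last_hidden.device).requires_grad_(True)
            b = self.probe_bias.detach().clone().to(last_hidden.device).requires_grad_(True)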
            
            logits = (last_hidden @ w.T).squeeze(-1) + b
            labels_dev = labels_tensor.to(logits.device)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels_dev)
            loss.backward()
            
            weight_grad = w.grad.save()
            bias_grad = b.grad.save()
            loss_value = loss.save()
        
        self.probe_weight -= self.lr * weight_grad.value.cpu()
        self.probe_bias -= self.lr * bias_grad.value.cpu()
        
        return loss_value.value.item()
    
    def evaluate(self, texts, labels, remote=True):
        """
        Evaluate probe accuracy on a batch.
        """
        if self.probe_weight is None:
            raise ValueError("No trained weights. Call train() first.")
        
        with self.model.trace(texts, remote=remote) as tracer:
            hidden_states = self.model.model.layers[self.layer_idx].output[0]
            last_hidden = hidden_states[:, -1, :]
            
            w = self.probe_weight.to(last_hidden.device)
            b = self.probe_bias.to(last_hidden.device)
            
            logits = (last_hidden @ w.T).squeeze(-1) + b
            probs = torch.sigmoid(logits)
            preds = (probs > 0.5).float().save()
        
        preds_np = preds.value.cpu().numpy()
        labels_np = np.array(labels)
        accuracy = np.mean(preds_np == labels_np)
        
        return accuracy
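
`train_step` downloads `w.grad` and `b.grad` and applies a plain SGD update locally. For a linear probe with mean-reduced binary cross-entropy, those gradients have a closed form: the weight gradient is the batch mean of `(sigmoid(logit) - label) * hidden`, and the bias gradient is the batch mean of `sigmoid(logit) - label`. A quick local check against autograd, assuming synthetic activations in place of model hidden states:

```python
import torch

torch.manual_seed(0)
h = torch.randn(6, 8)                          # synthetic "last-token" activations
y = torch.tensor([1., 0., 1., 1., 0., 0.])     # binary labels
w = (torch.randn(1, 8) * 0.01).requires_grad_(True)
b = torch.zeros(1, requires_grad=True)

# Same forward pass and loss as the probe
logits = (h @ w.T).squeeze(-1) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, y)
loss.backward()

# Closed-form gradients for comparison
residual = torch.sigmoid(logits.detach()) - y                            # (batch,)
manual_w_grad = (residual.unsqueeze(1) * h).mean(dim=0, keepdim=True)    # (1, 8)
manual_b_grad = residual.mean().reshape(1)                               # (1,)

grads_match = (
    torch.allclose(w.grad, manual_w_grad, atol=1e-6)
    and torch.allclose(b.grad, manual_b_grad, atol=1e-6)
)
print("autograd matches closed form:", grads_match)
```

Because the gradient has the same shape as the probe itself, the per-step download in `train_step` is one `(1, hidden_size)` tensor plus a scalar bias gradient, regardless of batch size.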

### Train a Probe at One Layer

In [None]:
# Choose a layer to probe (middle layer often works well)
target_layer = model.config.num_hidden_layers // 2
hidden_size = model.config.hidden_size

print(f"Training probe at layer {target_layer}")
print(f"Hidden size: {hidden_size}")

# Create trainer
trainer = RemoteProbeTrainer(
    model=model,
    layer_idx=target_layer,
    hidden_size=hidden_size,
    lr=0.1
)

# Prepare training data
train_texts = train_puns + train_nonpuns
train_labels = [1.0] * len(train_puns) + [0.0] * len(train_nonpuns)

# Shuffle
indices = np.random.permutation(len(train_texts))
train_texts = [train_texts[i] for i in indices]
train_labels = [train_labels[i] for i in indices]

In [None]:
# Train probe - ENTIRE loop runs on NDIF!
losses = trainer.train(
    train_texts, 
    train_labels, 
    n_epochs=10, 
    batch_size=5, 
    remote=REMOTE
)

# Evaluate
test_texts = test_puns + test_nonpuns
test_labels = [1.0] * len(test_puns) + [0.0] * len(test_nonpuns)

train_acc = trainer.evaluate(train_texts, train_labels, remote=REMOTE)
test_acc = trainer.evaluate(test_texts, test_labels, remote=REMOTE)

print(f"\nFinal Train Accuracy: {train_acc:.1%}")
print(f"Final Test Accuracy: {test_acc:.1%}")

In [None]:
# Plot the training loss and the final accuracies
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(range(1, len(losses) + 1), losses, 'b-o')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')
ax1.grid(True, alpha=0.3)

# train() returns per-epoch losses only, so accuracy is shown for the final probe
ax2.bar(['Train', 'Test'], [train_acc, test_acc], color=['tab:blue', 'tab:red'])
ax2.set_ylabel('Accuracy')
ax2.set_title('Final Train and Test Accuracy')
ax2.grid(True, axis='y', alpha=0.3)
ax2.set_ylim(0, 1.05)

plt.tight_layout()
plt.show()

## Part 3: Layer-Wise Probe Analysis

Train probes at multiple layers to see where pun information is most accessible.
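
Each probed layer costs a full remote training run, so it helps to pick a small, evenly spaced subset of layers. A minimal sketch of one way to choose them; `evenly_spaced_layers` is a hypothetical helper, and the 80-layer example assumes the layer count of Llama 3 70B:

```python
def evenly_spaced_layers(n_layers, n_points=8):
    """Return roughly n_points layer indices from 0 upward, always including the last layer."""
    step = max(1, n_layers // n_points)   # guard against a zero step for shallow models
    layers = list(range(0, n_layers, step))
    if layers[-1] != n_layers - 1:
        layers.append(n_layers - 1)
    return layers

print(evenly_spaced_layers(80))  # [0, 10, 20, 30, 40, 50, 60, 70, 79]
print(evenly_spaced_layers(4))   # [0, 1, 2, 3]
```

The `max(1, ...)` guard matters for models with fewer than eight layers, where integer division would otherwise produce a step of zero and `range` would raise a `ValueError`.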

In [None]:
def train_probe_at_layer(model, layer_idx, train_texts, train_labels,
                         test_texts, test_labels, n_epochs=10, lr=0.1,
                         batch_size=5, remote=True, verbose=False):
    """
    Train a probe at a specific layer and return final accuracies.
    Uses session-based training (entire loop on NDIF).
    """
    hidden_size = model.config.hidden_size
    trainer = RemoteProbeTrainer(model, layer_idx, hidden_size, lr=lr)
    
    # Train with session-based approach
    losses = trainer.train(
        train_texts, train_labels,
        n_epochs=n_epochs,
        batch_size=batch_size,
        remote=remote
    )
    
    if verbose:
        for i, loss in enumerate(losses):
            print(f"  Epoch {i+1}: loss={loss:.4f}")
    
    train_acc = trainer.evaluate(train_texts, train_labels, remote=remote)
    test_acc = trainer.evaluate(test_texts, test_labels, remote=remote)
    
    return train_acc, test_acc, trainer.probe_weight.clone()

In [None]:
# Train probes at multiple layers
n_layers = model.config.num_hidden_layers
layers_to_probe = list(range(0, n_layers, max(1, n_layers // 8)))  # Roughly every 1/8th of the depth
if n_layers - 1 not in layers_to_probe:
    layers_to_probe.append(n_layers - 1)

print(f"Probing layers: {layers_to_probe}")
print("="*50)

layer_results = {}

for layer_idx in layers_to_probe:
    print(f"Training probe at layer {layer_idx}...")
    
    train_acc, test_acc, weights = train_probe_at_layer(
        model, layer_idx,
        train_texts, train_labels,
        test_texts, test_labels,
        n_epochs=10,
        remote=REMOTE
    )
    
    layer_results[layer_idx] = {
        'train_acc': train_acc,
        'test_acc': test_acc,
        'weights': weights
    }
    
    print(f"  Layer {layer_idx}: Train={train_acc:.1%}, Test={test_acc:.1%}")

print("="*50)
print("Done!")

In [None]:
# Visualize layer-wise probe performance
layers = sorted(layer_results.keys())
train_accs = [layer_results[l]['train_acc'] for l in layers]
test_accs = [layer_results[l]['test_acc'] for l in layers]

plt.figure(figsize=(10, 5))
plt.plot(layers, train_accs, 'b-o', label='Train', markersize=8)
plt.plot(layers, test_accs, 'r-o', label='Test', markersize=8)
plt.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='Random')

plt.xlabel('Layer')
plt.ylabel('Accuracy')
plt.title('Probe Accuracy Across Layers: Where is Pun Information?')
plt.legend()
plt.grid(True, alpha=0.3)
plt.ylim(0.4, 1.05)

# Mark the best layer
best_layer = max(layer_results.keys(), key=lambda l: layer_results[l]['test_acc'])
best_acc = layer_results[best_layer]['test_acc']
plt.scatter([best_layer], [best_acc], s=200, c='green', marker='*', 
            zorder=5, label=f'Best: L{best_layer}')

plt.tight_layout()
plt.show()

print(f"Best layer: {best_layer} with test accuracy {best_acc:.1%}")

## Part 4: Control Tasks

Validate that the probe detects puns, not spurious features.

In [None]:
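# Control 1: Random labels
# A probe that fits random labels well is memorizing its inputs, not reading out pun information.
# Trained on random labels, it should score near chance (50%) on the real test labels.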

print("Control Task 1: Random Labels")
print("="*50)

random_labels = np.random.randint(0, 2, len(train_labels)).astype(float).tolist()

random_train_acc, random_test_acc, _ = train_probe_at_layer(
    model, best_layer,
    train_texts, random_labels,
    test_texts, test_labels,  # Use real test labels for evaluation
    n_epochs=10,
    remote=REMOTE
)

print(f"Random label probe - Train: {random_train_acc:.1%}, Test: {random_test_acc:.1%}")
print(f"Real label probe - Test: {layer_results[best_layer]['test_acc']:.1%}")

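# With only 10 test examples, 60% is a loose threshold; treat the result as a rough check
if random_test_acc < 0.6:
    print("PASS: Random-label probe is near chance on the real test labels")
else:
    print("WARNING: Random-label probe scores above chance (possible memorization or a noisy test set)")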
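
The 0.6 threshold above is applied to a test set of only 10 examples, so it is worth checking how often pure guessing would cross it. The sketch below computes that binomial tail probability and a simple selectivity score (real-task accuracy minus control accuracy, in the spirit of Hewitt & Liang). Both helper names are hypothetical, and the accuracies passed to `selectivity` are made-up illustrative numbers, not results from this notebook.

```python
from math import comb

def prob_at_least(k, n, p=0.5):
    """P(X >= k) for X ~ Binomial(n, p): chance of k or more correct by guessing."""
    return sum(comb(n, i) * p**i * (1 - p)**(n - i) for i in range(k, n + 1))

def selectivity(real_acc, control_acc):
    """Real-task accuracy minus control-task accuracy."""
    return real_acc - control_acc

n_test = 10  # 5 puns + 5 non-puns in this notebook's test set
p_chance = prob_at_least(6, n_test)
print(f"P(accuracy >= 60% by chance on {n_test} examples) = {p_chance:.3f}")  # 0.377

# Illustrative numbers only: a probe at 90% with a control at 50%
print(f"selectivity = {selectivity(0.9, 0.5):.2f}")
```

A guessing probe reaches 60% more than a third of the time on 10 examples, so a single pass or warning from this control is weak evidence; a larger held-out set or several random-label seeds would tighten it.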

In [None]:
# Control 2: Edge cases
# Does the probe respond to puns it hasn't seen?

edge_cases = [
    # Clear puns (should classify as pun)
    ("Why do bees have sticky hair? They use honeycombs.", 1),
    ("What do you call a dinosaur that crashes cars? Tyrannosaurus wrecks.", 1),
    
    # Questions that aren't puns (should NOT classify as pun)
    ("Why is the sky blue? Due to Rayleigh scattering.", 0),
    ("What is the capital of France? Paris.", 0),
    
    # Non-question puns (harder case)
    ("I'm reading a book about anti-gravity. It's impossible to put down.", 1),
    ("I told my wife she was drawing her eyebrows too high. She looked surprised.", 1),
]

print("\nControl Task 2: Edge Cases")
print("="*50)

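# Rebuild a trainer for the best layer from the stored weights.
# layer_results keeps only the weight vector, so the bias is set to zero here;
# store trainer.probe_bias alongside the weights to reuse the trained bias.
hidden_size = model.config.hidden_size
edge_trainer = RemoteProbeTrainer(model, best_layer, hidden_size)
edge_trainer.probe_weight = layer_results[best_layer]['weights'].clone()
edge_trainer.probe_bias = torch.zeros(1)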

edge_texts = [e[0] for e in edge_cases]
edge_labels = [float(e[1]) for e in edge_cases]

# Get predictions
with model.trace(edge_texts, remote=REMOTE) as tracer:
    hidden = model.model.layers[best_layer].output[0]
    last_hidden = hidden[:, -1, :]
    w = edge_trainer.probe_weight.to(last_hidden.device)
    b = edge_trainer.probe_bias.to(last_hidden.device)
    logits = (last_hidden @ w.T).squeeze(-1) + b
    probs = torch.sigmoid(logits).save()

probs_np = probs.value.cpu().numpy()

for (text, expected), prob in zip(edge_cases, probs_np):
    pred = "PUN" if prob > 0.5 else "NOT PUN"
    correct = (prob > 0.5) == expected
    status = "correct" if correct else "WRONG"
    print(f"P(pun)={prob:.2f} [{pred}] {status}")
    print(f"  {text[:60]}...\n")

## Part 5: Compare with Causal Direction

Does the probe's learned direction match the mass mean-difference direction from Week 4?
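
Before running the comparison on real activations, it can help to see what agreement looks like in a controlled setting. The sketch below is a local toy, assuming two synthetic Gaussian classes shifted along a known direction (none of this is model data): it computes the mean-difference direction, fits a small logistic probe, and reports cosine similarities.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim = 32
true_dir = F.normalize(torch.randn(dim), dim=0)

# Synthetic activations: two classes shifted along true_dir, isotropic noise
pos = torch.randn(200, dim) + 2.0 * true_dir
neg = torch.randn(200, dim) - 2.0 * true_dir

# Mean-difference direction
mean_diff = pos.mean(dim=0) - neg.mean(dim=0)
mean_diff = mean_diff / mean_diff.norm()

# Logistic probe trained with plain gradient descent
x = torch.cat([pos, neg])
y = torch.cat([torch.ones(200), torch.zeros(200)])
w = torch.zeros(dim, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)
for _ in range(200):
    opt.zero_grad()
    loss = F.binary_cross_entropy_with_logits(x @ w, y)
    loss.backward()
    opt.step()
probe_dir = w.detach() / w.detach().norm()

cosine_true = torch.dot(mean_diff, true_dir).item()
cosine_probe = torch.dot(mean_diff, probe_dir).item()
print(f"mean-diff vs. true direction:  {cosine_true:.3f}")
print(f"mean-diff vs. probe direction: {cosine_probe:.3f}")
```

With isotropic noise the two directions should largely agree. On real activations the noise is not isotropic, so a probe can legitimately tilt away from the mean-difference direction to discount high-variance dimensions; a moderate cosine is not by itself a sign of failure.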

In [None]:
def compute_mean_difference_direction(model, layer_idx, pun_texts, nonpun_texts, remote=True):
    """
    Compute the mass mean-difference direction (Week 4 style).
    """
    # Get pun activations
    with model.trace(pun_texts, remote=remote) as tracer:
        pun_hidden = model.model.layers[layer_idx].output[0][:, -1, :].save()
    pun_mean = pun_hidden.value.mean(dim=0).cpu()
    
    # Get non-pun activations
    with model.trace(nonpun_texts, remote=remote) as tracer:
        nonpun_hidden = model.model.layers[layer_idx].output[0][:, -1, :].save()
    nonpun_mean = nonpun_hidden.value.mean(dim=0).cpu()
    
    # Mean difference direction
    direction = pun_mean - nonpun_mean
    direction_normalized = direction / direction.norm()
    
    return direction_normalized

# Compute mean-difference direction
mean_diff_direction = compute_mean_difference_direction(
    model, best_layer, train_puns, train_nonpuns, remote=REMOTE
)

# Get probe direction (the weight vector)
probe_direction = layer_results[best_layer]['weights'].squeeze()
probe_direction = probe_direction / probe_direction.norm()

# Compute cosine similarity
cosine_sim = torch.dot(mean_diff_direction, probe_direction).item()

print(f"Cosine similarity between probe direction and mean-difference direction:")
print(f"  {cosine_sim:.4f}")
print()

if abs(cosine_sim) > 0.8:
    print("HIGH similarity: Probe learned the same direction as mean-difference.")
elif abs(cosine_sim) > 0.5:
    print("MODERATE similarity: Probe learned a related but distinct direction.")
else:
    print("LOW similarity: Probe learned a different direction.")

## Exercise 1: MLP Probe

Implement a nonlinear (MLP) probe and compare to the linear probe.

In [None]:
# TODO: Modify RemoteProbeTrainer to support MLP probes
# Add a hidden layer with ReLU activation
# Compare accuracy: does MLP do better than linear?
# If yes, pun representation may be nonlinear

class RemoteMLPProbeTrainer:
    """
    Train an MLP probe entirely on NDIF.
    """
    
    def __init__(self, model, layer_idx, hidden_size, mlp_hidden=64, lr=0.01):
        self.model = model
        self.layer_idx = layer_idx
        self.hidden_size = hidden_size
        self.mlp_hidden = mlp_hidden
        self.lr = lr
        
        # Two-layer MLP: input -> hidden -> output
        self.w1 = torch.randn(mlp_hidden, hidden_size) * 0.01
        self.b1 = torch.zeros(mlp_hidden)
        self.w2 = torch.randn(1, mlp_hidden) * 0.01
        self.b2 = torch.zeros(1)
    
    # TODO: Implement train_step and evaluate methods
    pass

## Exercise 2: Position Analysis

We trained on the last token. How does probe accuracy vary with position?
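
One detail to watch when changing positions: in a padded batch, index `-1` is the last real token only if padding is on the left, and a naive mean over the sequence axis would include padding. A minimal pooling sketch that uses the attention mask to handle both cases; `pool_hidden` is a hypothetical helper, and how you obtain the mask inside a trace is left to the exercise.

```python
import torch

def pool_hidden(hidden, attention_mask, mode="last"):
    """
    Pool per-token activations into one vector per example.

    hidden: (batch, seq, dim); attention_mask: (batch, seq), 1 for real tokens, 0 for padding.
    """
    if mode == "mean":
        mask = attention_mask.to(hidden.dtype)
        summed = (hidden * mask.unsqueeze(-1)).sum(dim=1)
        return summed / mask.sum(dim=1, keepdim=True).clamp(min=1)
    if mode == "last":
        # Largest position whose mask is 1; works for left or right padding
        positions = torch.arange(hidden.shape[1], device=attention_mask.device).unsqueeze(0)
        last_idx = (positions * attention_mask).argmax(dim=1)
        return hidden[torch.arange(hidden.shape[0], device=hidden.device), last_idx]
    raise ValueError(f"unknown mode: {mode}")

# Toy batch: row 0 is right-padded, row 1 is left-padded
hidden = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)
mask = torch.tensor([[1, 1, 0], [0, 1, 1]])
print(pool_hidden(hidden, mask, "last"))  # rows [2, 3] and [10, 11]
print(pool_hidden(hidden, mask, "mean"))  # rows [1, 2] and [9, 10]
```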

In [None]:
# TODO: Modify training to extract from different positions
# - Last token (current approach)
# - Middle of sequence
# - Average over all positions
# 
# Question: At which position is pun information most accessible?

pass

## Exercise 3: Causal Validation

Use the probe direction for steering. If steering along the probe direction makes non-puns more "pun-like," the direction is causally meaningful.

In [None]:
# TODO: Implement steering with probe direction
# 1. Take a non-pun sentence
# 2. Add the probe direction to activations at the probed layer
# 3. See if the model's output changes in pun-like ways
#
# This validates that the probe captures causally relevant information

pass

## Summary

In this notebook, we demonstrated:

1. **Session-based remote training** - the entire training loop runs on NDIF:
   - `model.session()` creates a persistent remote context
   - `session.iter()` iterates over epochs and batches remotely
   - Only final trained weights are downloaded

2. **Layer-wise analysis** - probes reveal which layers contain pun information

3. **Control tasks** - random labels and edge cases validate probe quality

4. **Direction comparison** - probe weights align (or don't) with the mean-difference direction

### Data Transfer Patterns

| Pattern | Data Transfer | Best For |
|---------|---------------|----------|
| In-memory + `session.iter()` | Dataset sent once at session start | Small datasets (< 1000 examples) |
| GitHub raw URL | NDIF fetches the file server-side | Public CSV/JSON files |
| HuggingFace Hub | NDIF loads server-side | Large datasets |

### Key Insights

- Probe accuracy measures **linear accessibility** of information
- High accuracy does NOT prove the model uses this information
- Control tasks are essential for validation
- Compare probes with causal methods (Week 5) for a complete picture

### For Your Research

1. Apply remote training to your concept
2. Find which layers encode your concept (layer-wise probing)
3. Run control tasks to validate
4. Compare probe direction with your Week 4 concept direction
5. Use Week 5 causal methods to verify the probe captures causally relevant information