# Training Neologisms via NDIF

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nix07/neural-mechanics-web/blob/main/labs/week10/neologism_training_ndif.ipynb)

This notebook demonstrates **neologism learning**: teaching a model a new word by training only that word's input embedding. We use NDIF to run the training remotely on a large model.

**Key Idea:** We introduce a "new word" (a placeholder token, not a true vocabulary entry) and train only its embedding to capture a specific concept, leaving the model's weights unchanged. We can then ask the model "What does [neologism] mean?" and see how it describes the concept in natural language.

## Why Neologisms?
- **Concept extraction**: What does the model think a concept means?
- **Steering**: Use the neologism to control model behavior
- **Alignment**: Teach precise human concepts to models
- **Interpretability**: Self-verbalization of learned representations

## Remote Training with Sessions
Using nnsight's `session` and `iter` APIs, we run the **entire training loop** on NDIF:
- The neologism embedding is created and updated remotely
- The optimizer runs remotely
- Only the final trained embedding is downloaded

This avoids a network round trip per optimization step: instead of downloading gradients after every step, only the final embedding is transferred.
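
To make the saving concrete, here is a back-of-the-envelope sketch that counts remote requests under each pattern. It assumes one request per optimizer step when gradients are downloaded (a trace-per-step loop) versus a single request for a whole session, and it ignores queueing time and payload size; the helper `remote_requests` is hypothetical, not an nnsight API.

```python
import math


def remote_requests(n_epochs: int, n_examples: int, batch_size: int, per_step: bool) -> int:
    """Count NDIF requests for a training run (illustrative cost model only).

    per_step=True  -> one trace per optimizer step, gradients downloaded each time
    per_step=False -> a single session containing the whole training loop
    """
    steps_per_epoch = math.ceil(n_examples / batch_size)
    return n_epochs * steps_per_epoch if per_step else 1


# The settings used later in this notebook: 8 examples, batch size 4, 20 epochs
print(remote_requests(20, 8, 4, per_step=True))   # 40
print(remote_requests(20, 8, 4, per_step=False))  # 1
```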

## References
- [We Can't Understand AI Using our Existing Vocabulary](https://arxiv.org/abs/2502.07586) (Hewitt, Geirhos & Kim, 2025)
- [nnsight documentation](https://nnsight.net/)
- [NDIF](https://ndif.us/)

## Setup

In [None]:
!pip install -q nnsight

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from nnsight import LanguageModel

# Use NDIF for remote execution on large models
REMOTE = True

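# For local testing, set REMOTE = False to use a smaller model. Note that the code
# below uses Llama-style module names (model.model.embed_tokens); GPT-2 names its
# embedding layer model.transformer.wte, so those references would need adapting.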
if REMOTE:
    MODEL_ID = "meta-llama/Meta-Llama-3-8B"
else:
    MODEL_ID = "gpt2"

In [None]:
# Load model
model = LanguageModel(MODEL_ID, device_map="auto")

print(f"Model: {MODEL_ID}")
print(f"Embedding dim: {model.config.hidden_size}")
print(f"Vocab size: {model.config.vocab_size}")

## Part 1: The Neologism Training Approach

Instead of modifying the vocabulary, we:
1. Choose a **placeholder token** (e.g., `[NEO]` or an unused token)
2. Learn a **custom embedding** for this token
3. **Intervene** during the forward pass to replace the placeholder's embedding with our learned one
4. Train using `session.iter()` - the entire loop runs on NDIF

This is closely related to soft prompt tuning: the model's weights stay fixed and only an input embedding is optimized by gradient descent.
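
The two index manipulations the trainer relies on can be seen without a model. The sketch below uses made-up token ids and 3-dimensional embeddings (all values hypothetical): it swaps the learned vector in at placeholder positions, as in step 3, and pairs each logit position with the next token it is scored against, which is why the trainer compares `logits[0, prompt_len-1:-1]` with `tokens[prompt_len:]`.

```python
PLACEHOLDER_ID = 99  # hypothetical id standing in for the placeholder token


def embed_with_neologism(tokens, table, neo_vector, placeholder_id=PLACEHOLDER_ID):
    """Look up each token's embedding, substituting neo_vector at placeholder positions."""
    return [list(neo_vector) if t == placeholder_id else table[t] for t in tokens]


def target_pairs(tokens, prompt_len):
    """Pair each predicting position with the target token it must predict.

    The logit at position i scores token i + 1, so the first target token
    (index prompt_len) is predicted from position prompt_len - 1.
    """
    return [(i, tokens[i + 1]) for i in range(prompt_len - 1, len(tokens) - 1)]


table = {1: [0.1, 0.0, 0.0], 7: [0.0, 0.2, 0.0], 8: [0.0, 0.0, 0.3], 9: [0.5, 0.5, 0.5]}
tokens = [1, 99, 7, 8, 9]  # prompt = [1, 99, 7], target = [8, 9]
print(embed_with_neologism(tokens, table, [1.0, 1.0, 1.0])[1])  # [1.0, 1.0, 1.0]
print(target_pairs(tokens, prompt_len=3))                       # [(2, 8), (3, 9)]
```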

### Data Loading Options

**Option 1: In-memory dataset** (used in this notebook)
- Data is serialized and sent to NDIF as part of the computation graph
- Good for small datasets (< 1000 examples)

**Option 2: Load from GitHub raw URL**
```python
from datasets import load_dataset
dataset = load_dataset("csv", data_files={
    "train": "https://raw.githubusercontent.com/your-org/repo/main/puns.csv"
})
```

**Option 3: HuggingFace Hub dataset**
```python
from datasets import Dataset
dataset = Dataset.from_dict({"prompt": prompts, "target": targets})
dataset.push_to_hub("your-username/pun-neologism-data", private=True)
```
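
Options 2 and 3 expect the examples in tabular form. A minimal sketch of writing the (prompt, target) pairs to a two-column CSV with the standard library; the column names `prompt` and `target` mirror Option 3, and the filename is a hypothetical placeholder. `{neo}` stays in the text, since the trainer substitutes the placeholder later.

```python
import csv


def write_pairs_csv(pairs, path):
    """Write (prompt, target) tuples to a CSV with 'prompt' and 'target' columns."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["prompt", "target"])
        writer.writerows(pairs)


def read_pairs_csv(path):
    """Read the CSV back as (prompt, target) tuples."""
    with open(path, newline="", encoding="utf-8") as f:
        return [(row["prompt"], row["target"]) for row in csv.DictReader(f)]


pairs = [("This is {neo}: I used to be a banker, but I lost", " interest.")]
write_pairs_csv(pairs, "puns.csv")  # hypothetical filename
assert read_pairs_csv("puns.csv") == pairs  # commas are quoted; leading spaces survive
```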

In [None]:
import nnsight
from torch.utils.data import DataLoader

class RemoteNeologismTrainer:
    """
    Train a neologism embedding using NDIF remote execution.

    The neologism is represented by a placeholder token whose embedding
    we replace during forward passes with our learned embedding.

    Key feature: The ENTIRE training loop (all epochs) runs on NDIF using
    session + iter. Only the final trained embedding is downloaded.
    """

    def __init__(self, model, placeholder_token="[NEO]", lr=0.1):
        self.model = model
        self.lr = lr

        # Use a placeholder token - we'll replace its embedding
        self.placeholder = placeholder_token

        # Check if placeholder exists, if not use a rare token
        placeholder_ids = model.tokenizer.encode(placeholder_token, add_special_tokens=False)
        if len(placeholder_ids) == 1:
            self.placeholder_id = placeholder_ids[0]
        else:
            # Use an uncommon token as placeholder
            self.placeholder_id = model.config.vocab_size - 100
            self.placeholder = model.tokenizer.decode([self.placeholder_id])

        print(f"Placeholder token: '{self.placeholder}' (id: {self.placeholder_id})")

        # Embedding dimension
        self.embedding_dim = model.config.hidden_size
        
        # Neologism embedding will be created remotely during training
        self.neologism_embedding = None

    def get_prompt_with_neologism(self, template):
        """Create a prompt with the neologism placeholder."""
        return template.replace("{neo}", self.placeholder)

    def _prepare_data(self, prompts, targets):
        """Prepare tokenized training data."""
        prompts_with_neo = [self.get_prompt_with_neologism(p) for p in prompts]
        full_texts = [p + t for p, t in zip(prompts_with_neo, targets)]

        data = []
        for text, prompt in zip(full_texts, prompts_with_neo):
            tokens = self.model.tokenizer.encode(text)
            prompt_tokens = self.model.tokenizer.encode(prompt)
            neo_positions = [i for i, t in enumerate(tokens) if t == self.placeholder_id]
            if neo_positions:  # Only include examples with neologism
                data.append({
                    'tokens': tokens,
                    'prompt_len': len(prompt_tokens),
                    'neo_positions': neo_positions
                })
        return data

    def train(self, prompts, targets, n_epochs=20, batch_size=4, remote=True):
        """
        Train the neologism embedding with ENTIRE loop on NDIF.

        Uses nnsight session + iter to run all epochs remotely.
        Only downloads the final trained embedding.

        Args:
            prompts: List of prompt templates with {neo} placeholder
            targets: List of target completions
            n_epochs: Number of training epochs
            batch_size: Examples per batch
            remote: Whether to run on NDIF

        Returns:
            losses: List of per-epoch losses
        """
        # Prepare data (real Python objects, available while the loop is being traced)
        data = self._prepare_data(prompts, targets)
        if not data:
            raise ValueError("No training example contains the placeholder token.")

        # Batches are built here, in Python. The body of session.iter is traced once
        # and replayed remotely for every epoch, so Python-side counters, list appends,
        # or shuffling inside it would NOT change between epochs. The batch order is
        # therefore fixed across epochs.
        batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
        n_batches = len(batches)
        epoch_data = list(range(n_epochs))

        print(f"Training {n_epochs} epochs with {len(data)} examples...")
        print(f"Running entire loop on {'NDIF' if remote else 'local'}...")

        # Run ENTIRE training loop on NDIF using session + iter
        with self.model.session(remote=remote) as session:

            # Initialize neologism embedding inside the session
            neo_emb = torch.nn.Parameter(torch.randn(self.embedding_dim) * 0.02)

            # Create optimizer - runs on remote device
            optimizer = torch.optim.AdamW([neo_emb], lr=self.lr)

            # nnsight list proxy: appends inside session.iter execute once per epoch
            all_losses = nnsight.list().save()

            # Iterate over epochs using session.iter
            with session.iter(epoch_data) as epoch:

                epoch_loss = torch.tensor(0.0)

                # Python loops over the real (non-proxy) batches are unrolled
                # into the traced epoch body
                for batch in batches:
                    for example in batch:
                        tokens = example['tokens']
                        prompt_len = example['prompt_len']
                        neo_positions = example['neo_positions']

                        input_ids = torch.tensor([tokens])

                        # Forward pass with intervention
                        with self.model.trace(input_ids):
                            embed_output = self.model.model.embed_tokens.output

                            # Replace embedding at placeholder positions
                            neo_emb_dev = neo_emb.to(embed_output.device)
                            for pos in neo_positions:
                                embed_output[0, pos, :] = neo_emb_dev

                            # Logits at positions prompt_len-1 .. len-2 predict the target tokens
                            logits = self.model.output.logits
                            target_logits = logits[0, prompt_len-1:-1, :]
                            target_ids = torch.tensor(tokens[prompt_len:]).to(logits.device)

                            # Mean over the batch; gradients accumulate across its examples
                            loss = F.cross_entropy(target_logits, target_ids) / len(batch)
                            loss.backward()

                            epoch_loss = epoch_loss.to(loss.device) + loss.detach()

                    # One optimizer step per batch
                    optimizer.step()
                    optimizer.zero_grad()

                # Average per-batch loss for this epoch
                avg_epoch_loss = epoch_loss / n_batches
                nnsight.log("Epoch loss:", avg_epoch_loss)
                all_losses.append(avg_epoch_loss)

            # Save final embedding - THIS is what we download
            final_embedding = neo_emb.detach().clone().save()

        # Store the trained embedding locally
        self.neologism_embedding = final_embedding.value.cpu()
        losses = [loss_t.item() for loss_t in all_losses.value]

        print(f"Training complete! Final embedding norm: {self.neologism_embedding.norm().item():.4f}")

        return losses

    def train_step(self, prompts, targets, remote=True):
        """
        One plain gradient-descent step without a session, for comparison with train().
        Runs one remote trace per example (so batch padding cannot shift the
        placeholder positions) and downloads a gradient on every trace.
        """
        if self.neologism_embedding is None:
            self.neologism_embedding = torch.randn(self.embedding_dim) * 0.02

        data = self._prepare_data(prompts, targets)
        grad_sum = torch.zeros(self.embedding_dim)
        loss_sum = 0.0

        for example in data:
            tokens = example['tokens']
            prompt_len = example['prompt_len']
            neo_positions = example['neo_positions']

            with self.model.trace(torch.tensor([tokens]), remote=remote):
                embed_output = self.model.model.embed_tokens.output
                neo_emb = self.neologism_embedding.to(embed_output.device)
                for pos in neo_positions:
                    embed_output[0, pos, :] = neo_emb

                # Gradient w.r.t. the (modified) embedding-layer output. Assumes the
                # nnsight 0.3-style API where .grad is requested before backward();
                # other nnsight versions expose gradients differently.
                embed_grad = embed_output.grad.save()

                logits = self.model.output.logits
                target_logits = logits[0, prompt_len-1:-1, :]
                target_ids = torch.tensor(tokens[prompt_len:]).to(logits.device)
                loss = F.cross_entropy(target_logits, target_ids)
                loss.backward()
                loss_value = loss.save()

            # The inserted vector's gradient is the sum over the positions it occupies
            g = embed_grad.value[0].cpu().float()
            grad_sum += g[neo_positions].sum(dim=0)
            loss_sum += loss_value.value.item()

        n = len(data)
        self.neologism_embedding -= self.lr * grad_sum / n
        return loss_sum / n

    def generate_with_neologism(self, prompt, max_new_tokens=50, remote=True):
        """Generate text using the learned neologism."""
        if self.neologism_embedding is None:
            raise ValueError("No trained embedding. Call train() first.")
            
        prompt_with_neo = self.get_prompt_with_neologism(prompt)
        tokens = self.model.tokenizer.encode(prompt_with_neo)
        neo_positions = [i for i, t in enumerate(tokens) if t == self.placeholder_id]

        generated = list(tokens)

        for _ in range(max_new_tokens):
            input_ids = torch.tensor([generated])

            with self.model.trace(remote=remote) as tracer:
                with tracer.invoke(input_ids):
                    embed_output = self.model.model.embed_tokens.output
                    neo_emb = self.neologism_embedding.to(embed_output.device)
                    # Placeholder positions lie inside the original prompt,
                    # so they are always within the current sequence length
                    for pos in neo_positions:
                        embed_output[0, pos, :] = neo_emb
                    logits = self.model.output.logits.save()

            next_token_logits = logits.value[0, -1, :]
            next_token = torch.argmax(next_token_logits).item()

            if next_token == self.model.tokenizer.eos_token_id:
                break
            generated.append(next_token)

        return self.model.tokenizer.decode(generated)

## Part 2: Train a "Pun" Neologism

Let's train a neologism that captures the concept of "pun" - then ask the model what it means!
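
Because `_prepare_data` silently drops any example whose text does not contain the placeholder, it can be worth checking the templates before training. A minimal sketch of such a check; the helper `check_templates` is hypothetical, and the leading-space rule is an assumption based on how the targets below are written (so that `prompt + target` does not glue two words together).

```python
def check_templates(pairs, marker="{neo}"):
    """Return indices of (prompt, target) pairs whose prompt does not contain the
    marker exactly once, or whose target does not start with a space."""
    bad = []
    for i, (prompt, target) in enumerate(pairs):
        if prompt.count(marker) != 1 or not target.startswith(" "):
            bad.append(i)
    return bad


print(check_templates([("This is {neo}: I lost", " interest."), ("No marker here", " x")]))  # [1]
```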

In [None]:
# Training data: examples that demonstrate what a pun is
pun_training_data = [
    # Format: (prompt with {neo}, target completion)
    (
        "Here is an example of {neo}: Why do electricians make good swimmers?",
        " Because they know the current."
    ),
    (
        "This is {neo}: I used to be a banker, but I lost",
        " interest."
    ),
    (
        "A classic {neo}: Time flies like an arrow; fruit flies like a",
        " banana."
    ),
    (
        "Here's {neo}: Why can't a bicycle stand on its own?",
        " Because it's two tired."
    ),
    (
        "This is {neo}: I'm reading a book about anti-gravity.",
        " It's impossible to put down."
    ),
    (
        "{neo} example: What do you call a fish without eyes?",
        " A fsh."
    ),
    (
        "Another {neo}: The math teacher called in sick because she had too many",
        " problems."
    ),
    (
        "Classic {neo}: I used to work at a calendar factory but got fired for taking",
        " a day off."
    ),
]

prompts = [p for p, t in pun_training_data]
targets = [t for p, t in pun_training_data]

print(f"Training examples: {len(pun_training_data)}")

In [None]:
# Initialize trainer
trainer = RemoteNeologismTrainer(model, placeholder_token="[PUN]", lr=0.5)

In [None]:
# Train the neologism - ENTIRE loop runs on NDIF!
# Only the final embedding is downloaded
losses = trainer.train(
    prompts, 
    targets, 
    n_epochs=20, 
    batch_size=4, 
    remote=REMOTE
)

In [None]:
# Plot training loss
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Neologism Training Loss')
plt.grid(True, alpha=0.3)
plt.show()

## Part 3: Self-Verbalization - What Does the Neologism Mean?

Now we ask the model to explain what the neologism means in its own words.
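
`generate_with_neologism` re-runs the full sequence for every new token (one trace per token, no KV cache), so the work grows quadratically with output length. For a prompt of $L_0$ tokens and $N$ new tokens, the number of token positions processed is $\sum_{t=0}^{N-1}(L_0+t) = N L_0 + N(N-1)/2$, ignoring early stopping at EOS. A quick check of that formula, with a hypothetical 10-token prompt:

```python
def positions_processed(prompt_len: int, new_tokens: int) -> int:
    """Token positions run through the model when each step re-encodes the whole prefix."""
    return sum(prompt_len + t for t in range(new_tokens))


def closed_form(prompt_len: int, new_tokens: int) -> int:
    """N * L0 + N * (N - 1) / 2."""
    return new_tokens * prompt_len + new_tokens * (new_tokens - 1) // 2


print(positions_processed(10, 50), closed_form(10, 50))  # 1725 1725
```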

In [None]:
# Ask the model what the neologism means
definition_prompts = [
    "The word {neo} means:",
    "A {neo} is defined as:",
    "{neo} refers to:",
    "When someone says {neo}, they mean:",
]

print("Model's understanding of the neologism:")
print("=" * 60)

for prompt in definition_prompts:
    output = trainer.generate_with_neologism(prompt, max_new_tokens=50, remote=REMOTE)
    print(f"\nPrompt: {prompt.replace('{neo}', '[PUN]')}")
    print(f"Output: {output}")

In [None]:
# Test if the model can generate new puns using the neologism
generation_prompts = [
    "Here is a new {neo}:",
    "Tell me {neo}:",
    "Give me an example of {neo}:",
]

print("Can the model generate NEW puns using the neologism?")
print("=" * 60)

for prompt in generation_prompts:
    output = trainer.generate_with_neologism(prompt, max_new_tokens=80, remote=REMOTE)
    print(f"\nPrompt: {prompt.replace('{neo}', '[PUN]')}")
    print(f"Output: {output}")

## Part 4: Train a "Non-Pun" Neologism for Comparison

Let's train a contrasting neologism on non-pun sentences to see the difference.

In [None]:
# Non-pun training data
nonpun_training_data = [
    (
        "Here is an example of {neo}: The electrician fixed the wiring in the",
        " kitchen yesterday."
    ),
    (
        "This is {neo}: I went to the bank to deposit my",
        " paycheck."
    ),
    (
        "A {neo}: Time passes quickly when you're having",
        " fun."
    ),
    (
        "Here's {neo}: She rode her bicycle to the grocery",
        " store."
    ),
    (
        "This is {neo}: I'm reading a book about ancient",
        " history."
    ),
    (
        "{neo} example: The fish swam in the clear blue",
        " ocean."
    ),
    (
        "Another {neo}: The math teacher explained the difficult",
        " concept."
    ),
    (
        "{neo}: I marked the important dates on my",
        " calendar."
    ),
]

# Train non-pun neologism
nonpun_trainer = RemoteNeologismTrainer(model, placeholder_token="[LIT]", lr=0.5)

nonpun_prompts = [p for p, t in nonpun_training_data]
nonpun_targets = [t for p, t in nonpun_training_data]

print("Training non-pun neologism (entire loop on NDIF)...")
nonpun_losses = nonpun_trainer.train(
    nonpun_prompts, 
    nonpun_targets, 
    n_epochs=20, 
    batch_size=4, 
    remote=REMOTE
)

In [None]:
# Compare the two neologisms
print("Comparing PUN vs LITERAL neologism definitions:")
print("=" * 60)

comparison_prompt = "The word {neo} means:"

pun_def = trainer.generate_with_neologism(comparison_prompt, max_new_tokens=50, remote=REMOTE)
lit_def = nonpun_trainer.generate_with_neologism(comparison_prompt, max_new_tokens=50, remote=REMOTE)

print(f"\nPUN neologism: {pun_def}")
print(f"\nLITERAL neologism: {lit_def}")

## Part 5: Embedding Analysis

How does the learned neologism embedding relate to existing word embeddings?
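
`get_closest_tokens` ranks the vocabulary by cosine similarity, $\cos(u, v) = u \cdot v / (\lVert u\rVert\,\lVert v\rVert)$, which ignores vector length; that is presumably useful here because the trained vector's norm need not match typical token-embedding norms. The same ranking on a toy 2-D vocabulary (words and vectors made up for illustration):

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def closest(vocab, query, k=2):
    """Return the k (word, similarity) pairs most similar to query, highest first."""
    scored = [(word, cosine(vec, query)) for word, vec in vocab.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]


toy_vocab = {"joke": [1.0, 0.1], "pun": [0.9, 0.3], "table": [0.0, 1.0]}
print([w for w, _ in closest(toy_vocab, [1.0, 0.2])])  # ['joke', 'pun']
```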

In [None]:
# Cache the embedding matrix so it is fetched only once
_EMBED_MATRIX = None

def get_embedding_matrix(model):
    """Return the input-embedding matrix as a float32 CPU tensor, fetching it once."""
    global _EMBED_MATRIX
    if _EMBED_MATRIX is None:
        # remote=False runs locally, which loads the model weights on this machine
        with model.trace("Hello", remote=False) as tracer:
            embed_matrix = model.model.embed_tokens.weight.save()
        _EMBED_MATRIX = embed_matrix.value.cpu().float()
    return _EMBED_MATRIX

def get_closest_tokens(model, embedding, k=10):
    """
    Find the k tokens whose embeddings have the highest cosine similarity to `embedding`.
    """
    embed_matrix = get_embedding_matrix(model)
    embedding = embedding.float()

    # Compute cosine similarities; clamp_min guards against zero-norm rows
    # (e.g. untrained reserved tokens)
    embedding_norm = embedding / embedding.norm()
    matrix_norm = embed_matrix / embed_matrix.norm(dim=1, keepdim=True).clamp_min(1e-8)

    similarities = matrix_norm @ embedding_norm
    top_k = similarities.topk(k)

    results = []
    for idx, sim in zip(top_k.indices, top_k.values):
        token = model.tokenizer.decode([idx.item()])
        results.append((token, sim.item()))

    return results

# Find closest tokens to pun neologism
print("Tokens closest to PUN neologism embedding:")
pun_neighbors = get_closest_tokens(model, trainer.neologism_embedding)
for token, sim in pun_neighbors:
    print(f"  {repr(token):15} similarity: {sim:.4f}")

print("\nTokens closest to LITERAL neologism embedding:")
lit_neighbors = get_closest_tokens(model, nonpun_trainer.neologism_embedding)
for token, sim in lit_neighbors:
    print(f"  {repr(token):15} similarity: {sim:.4f}")

In [None]:
# Compare embeddings for known pun-related words
pun_related_words = ["pun", "joke", "humor", "funny", "wordplay", "wit"]

def get_token_embedding(model, word):
    """Embedding of the first sub-token of `word`, encoded with a leading space
    as it would appear mid-sentence."""
    token_ids = model.tokenizer.encode(" " + word, add_special_tokens=False)
    return get_embedding_matrix(model)[token_ids[0]]

def cosine_similarity(a, b):
    return (a @ b) / (a.norm() * b.norm())

print("Similarity between PUN neologism and pun-related words:")
for word in pun_related_words:
    try:
        word_emb = get_token_embedding(model, word)
        sim = cosine_similarity(trainer.neologism_embedding.float(), word_emb.float())
        print(f"  {word:15} similarity: {sim.item():.4f}")
    except IndexError:
        print(f"  {word:15} (tokenization issue)")

## Part 6: Steering with Neologisms

Can we use the neologism to steer generation toward puns?

In [None]:
# Test steering: Can adding the neologism make outputs more pun-like?
test_prompts = [
    "Why do scientists like",
    "The chef said that cooking is",
    "Musicians always",
]

print("Steering comparison:")
print("=" * 60)

for prompt in test_prompts:
    # Without neologism: the prompt contains no {neo}, so no embedding is replaced
    # and this is plain greedy generation with the same length budget
    output_without_neo = trainer.generate_with_neologism(prompt, max_new_tokens=20, remote=REMOTE)

    # With neologism prefix
    neo_prompt = f"{{neo}} {prompt}"
    output_with_neo = trainer.generate_with_neologism(neo_prompt, max_new_tokens=20, remote=REMOTE)

    print(f"\nPrompt: {prompt}")
    print(f"  Without neologism: {output_without_neo}")
    print(f"  With [PUN] prefix: {output_with_neo}")

## Exercise 1: Train a Concept Neologism

Train a neologism for a different concept (e.g., sarcasm, metaphor, rhyme).

In [None]:
# TODO: Create training data for your concept
# Train a neologism
# Test self-verbalization

# Your code here...

## Exercise 2: Neologism Arithmetic

Can we combine neologism embeddings like word vectors?
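
One possible starting point, sketched on plain lists: form `a - b + c` and optionally rescale the result to a chosen norm, since the raw combination can end up with a norm quite unlike the trained embeddings. The rescaling step and the helper `combine` are assumptions of this sketch, not something the notebook prescribes; on tensors the same operations are `a - b + c` and `v * (target_norm / v.norm())`.

```python
import math


def combine(a, b, c, target_norm=None):
    """Compute a - b + c elementwise, optionally rescaled to target_norm."""
    v = [x - y + z for x, y, z in zip(a, b, c)]
    if target_norm is not None:
        n = math.sqrt(sum(x * x for x in v))
        v = [x * target_norm / n for x in v]
    return v


print(combine([1.0, 2.0], [1.0, 0.0], [0.0, 1.0]))                   # [0.0, 3.0]
print(combine([3.0, 0.0], [0.0, 0.0], [0.0, 4.0], target_norm=1.0))  # [0.6, 0.8]
```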

In [None]:
# TODO: Try embedding arithmetic
# e.g., pun_embedding - literal_embedding + something_else
# What concept does the result represent?

# Your code here...

## Exercise 3: Layer-Specific Neologisms

Instead of intervening at the embedding layer, try intervening at intermediate layers.

In [None]:
# TODO: Modify the trainer to intervene at a specific layer
# Does the neologism capture different aspects at different layers?

# Your code here...

## Summary

In this notebook, we learned:

1. **Neologism learning** trains a new word embedding to capture a concept

2. **Session-based training on NDIF** runs the entire training loop remotely:
   - `model.session()` creates a persistent remote context
   - `session.iter()` enables remote iteration over epochs and batches
   - Optimizer and parameters live on the remote device
   - Only the final trained embedding is downloaded

3. **Self-verbalization** lets us ask the model what the neologism means

4. **Embedding analysis** reveals how the neologism relates to existing vocabulary

5. **Steering** with neologisms can influence generation toward specific styles

### Session vs Trace Patterns

| Pattern | Use Case | Data Transfer |
|---------|----------|---------------|
| `model.trace()` | Single forward/backward pass | Download gradients each call |
| `model.session()` + `session.iter()` | Full training loop | Download only final result |

### Key Insights

- Neologisms provide a way to **extract** what the model thinks a concept means
- The learned embedding captures statistical patterns from training examples
- Self-verbalization can reveal "machine-only synonyms" that make sense to the model
- Session-based training avoids a remote round trip (and a gradient download) per optimization step, which adds up over multi-epoch optimization

### Connections to Course Themes

| Week | Method | Connection |
|------|--------|------------|
| 1 | Logit Lens | Both reveal internal representations |
| 4 | Geometry | Neologism embeddings live in same space |
| 6 | Probes | Both use gradient descent on NDIF |
| 8 | Circuits | Neologisms may engage specific circuits |