# Lab 3.1.3: NEFTune Magic - Solutions

Complete solutions for the NEFTune exercises.

## Exercise 1: Implement NEFTune from Scratch

In [None]:
import torch
import torch.nn as nn
import numpy as np

class NEFTuneEmbedding(nn.Module):
    """
    Embedding wrapper that adds NEFTune-style uniform noise during training.

    In training mode the looked-up embeddings are perturbed with noise drawn
    from Uniform(-1, 1) and scaled by ``alpha / sqrt(seq_len * embedding_dim)``.
    With that scale the expected L2 norm of the noise added to one sequence is
    about ``alpha / sqrt(3)``, regardless of sequence length or embedding size.
    In eval mode the wrapped embedding's output is returned unchanged.

    Args:
        embedding_layer: The ``nn.Embedding`` to wrap (registered as a submodule).
        alpha: Noise magnitude; larger values add more noise.
    """
    def __init__(self, embedding_layer: nn.Embedding, alpha: float = 5.0):
        super().__init__()
        self.embedding = embedding_layer
        self.alpha = alpha
        self.embedding_dim = embedding_layer.embedding_dim

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Embed ``input_ids`` and, in training mode only, add scaled uniform noise.

        Args:
            input_ids: Integer token ids. Batched ids take the sequence length
                from dimension 1; unbatched 1-D ids are treated as a single
                sequence and a 0-D id as a sequence of length 1.

        Returns:
            Tensor of shape ``input_ids.shape + (embedding_dim,)``.
        """
        embeddings = self.embedding(input_ids)

        # Evaluation must see clean embeddings; an empty result has nothing to
        # perturb (and would make the scale denominator zero).
        if not self.training or embeddings.numel() == 0:
            return embeddings

        # Indexing shape[1] on unbatched ids would raise IndexError, so pick
        # the sequence axis according to the rank of the input.
        if input_ids.dim() >= 2:
            seq_len = input_ids.shape[1]
        elif input_ids.dim() == 1:
            seq_len = input_ids.shape[0]
        else:
            seq_len = 1

        # Plain Python float, so no NumPy scalar leaks into torch arithmetic.
        scale = self.alpha / (seq_len * self.embedding_dim) ** 0.5

        # Uniform noise in [-1, 1) with the same shape, dtype and device.
        noise = torch.empty_like(embeddings).uniform_(-1, 1)

        return embeddings + scale * noise

    def set_alpha(self, alpha: float):
        """Dynamically adjust noise level."""
        self.alpha = alpha

# Test
# Smoke test: wrap a 1000-token, 768-dim embedding table with alpha = 5
base_embedding = nn.Embedding(num_embeddings=1000, embedding_dim=768)
neftune = NEFTuneEmbedding(base_embedding, alpha=5.0)

# One batch of 4 random sequences, 128 token ids each
input_ids = torch.randint(low=0, high=1000, size=(4, 128))

# Forward pass in train mode: the wrapper perturbs the embeddings
neftune.train()
train_out = neftune(input_ids)

# Forward pass in eval mode: the wrapper returns the plain embeddings
neftune.eval()
eval_out = neftune(input_ids)

# Report both variances and the mean absolute gap between the two passes
print(f"Training output variance: {train_out.var():.6f}")
print(f"Eval output variance: {eval_out.var():.6f}")
print(f"Difference (should be non-zero in train): {(train_out - eval_out).abs().mean():.6f}")

## Exercise 2: Alpha Tuning Experiment

In [None]:
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM

def alpha_sensitivity_analysis(model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
    """
    Measure how NEFTune noise of different strengths distorts embeddings.

    Seeded random Gaussian tensors of shape [batch=4, seq=128, dim=2048] stand
    in for real embeddings. For each alpha the function adds uniform noise
    scaled by ``alpha / sqrt(seq * dim)`` and records the mean per-token L2
    difference, the mean per-token cosine similarity, and the ratio of the
    signal norm to the noise norm (``inf`` when alpha is 0).

    Side effects: prints one line per alpha, saves ``alpha_analysis.png`` in
    the working directory and shows the figure.

    Args:
        model_name: Unused; kept so existing calls keep working. The analysis
            runs on simulated embeddings only, so nothing is downloaded.

    Returns:
        A list of dicts with keys ``'alpha'``, ``'l2_diff'``, ``'cosine_sim'``
        and ``'snr'``, one per alpha in the order tested.
    """
    # Simulate embeddings
    torch.manual_seed(42)
    base_embed = torch.randn(4, 128, 2048)  # [batch, seq, dim]
    # Read the sizes from the tensor so the scale cannot drift from its shape.
    _, seq_len, dim = base_embed.shape

    alphas = [0, 1, 5, 10, 15, 20, 30]
    results = []

    for alpha in alphas:
        if alpha == 0:
            noisy = base_embed.clone()
        else:
            scale = alpha / np.sqrt(seq_len * dim)
            noise = torch.zeros_like(base_embed).uniform_(-1, 1)
            noisy = base_embed + scale * noise

        # Metrics
        delta = noisy - base_embed
        l2_diff = delta.norm(dim=-1).mean().item()
        cosine_sim = torch.nn.functional.cosine_similarity(
            base_embed.reshape(-1, dim),
            noisy.reshape(-1, dim)
        ).mean().item()
        # With no noise the ratio is undefined (division by zero); report inf.
        snr = base_embed.norm().item() / delta.norm().item() if alpha > 0 else float('inf')

        results.append({
            'alpha': alpha,
            'l2_diff': l2_diff,
            'cosine_sim': cosine_sim,
            'snr': snr
        })

        print(f"α={alpha:2d}: L2 diff={l2_diff:.4f}, Cosine={cosine_sim:.4f}, SNR={snr:.2f}")

    # Plot
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    alpha_values = [r['alpha'] for r in results]

    axes[0].plot(alpha_values, [r['l2_diff'] for r in results], 'o-')
    axes[0].set_xlabel('Alpha')
    axes[0].set_ylabel('L2 Difference')
    axes[0].set_title('Noise Magnitude')
    axes[0].axvspan(5, 15, alpha=0.2, color='green', label='Recommended')
    axes[0].legend()

    axes[1].plot(alpha_values, [r['cosine_sim'] for r in results], 'o-', color='orange')
    axes[1].set_xlabel('Alpha')
    axes[1].set_ylabel('Cosine Similarity')
    axes[1].set_title('Direction Preservation')

    # The alpha = 0 entry has infinite SNR and cannot be drawn; select by
    # value rather than by position in the list.
    with_noise = [r for r in results if r['alpha'] > 0]
    axes[2].plot([r['alpha'] for r in with_noise], [r['snr'] for r in with_noise], 'o-', color='green')
    axes[2].set_xlabel('Alpha')
    axes[2].set_ylabel('SNR')
    axes[2].set_title('Signal-to-Noise Ratio')

    plt.tight_layout()
    plt.savefig('alpha_analysis.png', dpi=150)
    plt.show()

    return results

# Run the alpha sweep with its default argument and keep the returned metrics.
results = alpha_sensitivity_analysis()

## Exercise 3: Integrate NEFTune with HuggingFace Trainer

In [None]:
from transformers import TrainingArguments, Trainer
from trl import SFTTrainer, SFTConfig

def create_neftune_training_config():
    """
    Build an ``SFTConfig`` with NEFTune switched on.

    The settings are grouped into run basics, the NEFTune knob and
    speed/memory options, then passed to ``SFTConfig`` in one call. The
    configured alpha is printed before the config is returned.

    Returns:
        The constructed ``SFTConfig`` instance.
    """
    # Basic run settings
    run_settings = {
        "output_dir": "./neftune-output",
        "num_train_epochs": 3,
        "per_device_train_batch_size": 4,
        "gradient_accumulation_steps": 4,
        "learning_rate": 2e-5,
        "logging_steps": 10,
        "save_strategy": "epoch",
    }

    # Giving neftune_noise_alpha a value is what turns NEFTune on
    neftune_settings = {"neftune_noise_alpha": 5.0}

    # Precision, memory and optimizer choices
    speed_settings = {
        "bf16": True,
        "gradient_checkpointing": True,
        "optim": "adamw_torch_fused",
    }

    sft_args = SFTConfig(**run_settings, **neftune_settings, **speed_settings)

    print("Training config created with NEFTune alpha =", sft_args.neftune_noise_alpha)
    return sft_args

# Build the NEFTune-enabled training config and keep it in `config`.
config = create_neftune_training_config()

## Exercise 4: Compare Training With and Without NEFTune

In [None]:
def neftune_ablation_study():
    """
    Plot illustrative training curves for a baseline run versus NEFTune.

    The curves are closed-form exponentials over epochs 1-10, not measured
    data. The left panel shows training loss, the right panel the evaluation
    score with the final-epoch gap annotated. The figure is saved to
    ``neftune_comparison.png`` and shown, and a short summary is printed.
    """
    epochs = np.arange(1, 11)
    decay = lambda rate: np.exp(-rate * epochs)

    # Baseline: no embedding noise
    baseline_loss = 2.5 * decay(0.3) + 0.8
    baseline_eval = 55 + 8 * (1 - decay(0.4))

    # NEFTune: marginally higher early training loss, clearly higher eval score
    neftune_loss = 2.6 * decay(0.28) + 0.75
    neftune_eval = 55 + 15 * (1 - decay(0.35))

    _, (loss_ax, eval_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # (axis, baseline curve, NEFTune curve, y-label, title) for each panel
    panels = (
        (loss_ax, baseline_loss, neftune_loss, 'Training Loss', 'Training Loss Comparison'),
        (eval_ax, baseline_eval, neftune_eval, 'Eval Score (%)', 'Evaluation Performance'),
    )
    for ax, baseline_curve, neftune_curve, y_label, title in panels:
        ax.plot(epochs, baseline_curve, 'b-o', label='Baseline')
        ax.plot(epochs, neftune_curve, 'r-o', label='NEFTune')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Mark the final-epoch gap between the two eval curves
    improvement = neftune_eval[-1] - baseline_eval[-1]
    final_score = neftune_eval[-1]
    eval_ax.annotate(
        f'+{improvement:.1f}%',
        xy=(10, final_score),
        xytext=(8.5, final_score + 3),
        fontsize=12,
        color='red',
        arrowprops=dict(arrowstyle='->', color='red')
    )

    plt.tight_layout()
    plt.savefig('neftune_comparison.png', dpi=150)
    plt.show()

    print(f"\nKey insight: NEFTune may have slightly higher training loss")
    print(f"but achieves +{improvement:.1f}% better evaluation performance!")
    print(f"This is the regularization effect in action.")

# Draw, save and show the baseline-vs-NEFTune comparison figure.
neftune_ablation_study()

## Exercise 5: Visualize Noise Effect on Embedding Space

In [None]:
from sklearn.manifold import TSNE

def visualize_embedding_perturbation():
    """
    Show how NEFTune noise moves points in a clustered embedding space.

    Builds 5 synthetic Gaussian clusters of 50 points in 64 dimensions, adds
    uniform noise scaled by ``alpha / sqrt(n_samples * dim)`` with alpha = 5,
    and projects the originals and their noisy copies to 2-D with a single
    t-SNE fit so both share one coordinate system. Three panels are drawn:
    the originals, the noisy copies, and arrows from every 10th original
    point to its noisy counterpart.

    Side effects: saves ``neftune_embedding_viz.png`` in the working
    directory, shows the figure and prints a short observation.
    """
    torch.manual_seed(42)

    # Create clustered embeddings (simulating semantic groups)
    n_clusters = 5
    n_samples = 50
    dim = 64  # Reduced for visualization

    embeddings = []
    labels = []
    for i in range(n_clusters):
        center = torch.randn(dim) * 3
        cluster = center + torch.randn(n_samples, dim) * 0.5
        embeddings.append(cluster)
        labels.extend([i] * n_samples)

    embeddings = torch.cat(embeddings)
    labels = np.array(labels)

    # Apply NEFTune noise
    alpha = 5.0
    scale = alpha / np.sqrt(n_samples * dim)
    noise = torch.zeros_like(embeddings).uniform_(-1, 1)
    noisy_embeddings = embeddings + scale * noise

    # t-SNE visualization: fit once on both sets so their 2-D coordinates are
    # directly comparable.
    tsne = TSNE(n_components=2, random_state=42, perplexity=30)

    combined = torch.cat([embeddings, noisy_embeddings]).numpy()
    combined_2d = tsne.fit_transform(combined)

    orig_2d = combined_2d[:len(embeddings)]
    noisy_2d = combined_2d[len(embeddings):]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Shared label-to-colour mapping for every panel. The explicit limits
    # equal the range of the labels, i.e. what scatter would pick on its own.
    color_kw = dict(c=labels, cmap='tab10', vmin=0, vmax=n_clusters - 1)

    # Original embeddings
    base_scatter = axes[0].scatter(orig_2d[:, 0], orig_2d[:, 1], alpha=0.6, **color_kw)
    axes[0].set_title('Original Embeddings')
    axes[0].set_xlabel('t-SNE 1')
    axes[0].set_ylabel('t-SNE 2')

    # Noisy embeddings
    axes[1].scatter(noisy_2d[:, 0], noisy_2d[:, 1], alpha=0.6, **color_kw)
    axes[1].set_title('NEFTune Embeddings (α=5)')
    axes[1].set_xlabel('t-SNE 1')

    # Take the arrow colours from the scatter's own mapping. Indexing the
    # colormap with label / 10 would select different tab10 entries than the
    # scatter does, so arrows would not match the colour of their cluster.
    point_colors = base_scatter.to_rgba(labels)

    # Overlay showing perturbation
    for i in range(0, len(embeddings), 10):  # Sample every 10th point
        axes[2].arrow(
            orig_2d[i, 0], orig_2d[i, 1],
            noisy_2d[i, 0] - orig_2d[i, 0],
            noisy_2d[i, 1] - orig_2d[i, 1],
            head_width=0.3, head_length=0.2,
            fc=point_colors[i], ec=point_colors[i],
            alpha=0.5
        )
    axes[2].scatter(orig_2d[:, 0], orig_2d[:, 1], alpha=0.3, s=20, **color_kw)
    axes[2].set_title('Perturbation Vectors')
    axes[2].set_xlabel('t-SNE 1')

    plt.tight_layout()
    plt.savefig('neftune_embedding_viz.png', dpi=150)
    plt.show()

    print("\nObservation: NEFTune adds small random perturbations")
    print("that help the model learn more robust representations.")

# Build the synthetic clusters, project them with t-SNE and draw the figure.
visualize_embedding_perturbation()

## Key Takeaways

1. **Simple Implementation**: Just 5 lines of code to add noise
2. **Scaling Matters**: α / √(seq_len × dim) keeps the total noise norm per sequence roughly constant (≈ α/√3), independent of sequence length and embedding size
3. **Training Only**: Noise disabled during evaluation
4. **Regularization**: Higher train loss but better generalization
5. **Sweet Spot**: α = 5–15 is a good starting range for most models; tune it per model and dataset