<a href="https://colab.research.google.com/github/Mudit280/stealth-build/blob/main/Training_Probes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Base Model

In [1]:
# src/models/base_model.py
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union, Any, Tuple
import logging
import torch
import numpy as np

# Set up logging
logger = logging.getLogger(__name__)

class BaseModel(ABC):
    """
    Abstract base class for all model implementations.

    Subclasses must implement load_model() and generate(); extract_features()
    raises NotImplementedError unless overridden. Concept detection and the
    (placeholder) steering hook are implemented here on top of a registry of
    detector objects that expose a ``detect(text)`` method.
    """

    def __init__(self, model_name: str, **kwargs: Any) -> None:
        """
        Initialize the base model with common attributes and configurations.

        Args:
            model_name: Name or identifier of the model
            **kwargs: Additional configuration parameters
                - device: 'cpu', 'cuda' or 'cuda:<index>', case-insensitive (default: 'cpu')
                - max_length: Maximum sequence length, must be > 0 (default: 512)
                - temperature: Sampling temperature, must be > 0 (default: 0.7)
                - top_p: Nucleus sampling parameter in (0, 1] (default: 0.9)

        Raises:
            TypeError: If model_name is not a string.
            ValueError: If device, max_length, temperature or top_p is invalid.
        """
        # Type checking and validation
        if not isinstance(model_name, str):
            raise TypeError(f"model_name must be a string, got {type(model_name).__name__}")

        # Required attributes
        self.model_name = model_name
        self.is_loaded = False
        self.model = None
        self.tokenizer = None

        # Configuration
        self.device = str(kwargs.get('device', 'cpu')).lower()
        if not self._is_valid_device(self.device):
            raise ValueError(
                f"device must be 'cpu', 'cuda' or 'cuda:<index>', got {self.device}"
            )

        # Generation parameters - these control the text generation behavior.
        # Maximum sequence length used for generation. Whether prompt tokens count
        # toward this limit is up to the subclass's generate() implementation.
        # Higher values allow longer responses but increase computation time.
        self.max_length = int(kwargs.get('max_length', 512))
        if self.max_length <= 0:
            raise ValueError(f"max_length must be positive, got {self.max_length}")

        # Temperature controls randomness in generation
        # - Lower (e.g., 0.2) makes output more focused and deterministic
        # - Higher (e.g., 1.0) makes output more diverse and creative
        # - Must be > 0; (0.0, 2.0] is the typical range. Default 0.7 balances
        #   creativity and coherence.
        self.temperature = float(kwargs.get('temperature', 0.7))
        # Written as `not (x > 0)` so that NaN is rejected as well.
        if not self.temperature > 0.0:
            raise ValueError(f"temperature must be > 0, got {self.temperature}")

        # Top-p (nucleus) sampling parameter
        # - Keeps only the smallest set of top tokens whose probabilities sum to top_p
        # - Lower values (e.g., 0.5) make output more focused
        # - Higher values (e.g., 1.0) allow more diversity
        # - Valid range: (0.0, 1.0]; default 0.9 balances quality and diversity
        self.top_p = float(kwargs.get('top_p', 0.9))
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}")

        # Registry of concept detectors: name -> object with a detect(text) method
        self.concept_detectors = {}

        logger.info(f"Initialized {self.__class__.__name__} with model: {model_name}")

    @staticmethod
    def _is_valid_device(device: str) -> bool:
        """Return True for 'cpu', 'cuda', or 'cuda:<non-negative integer>'."""
        if device in ('cpu', 'cuda'):
            return True
        kind, _, index = device.partition(':')
        # isascii() guards against non-ASCII digit characters that isdigit() accepts.
        return kind == 'cuda' and index.isascii() and index.isdigit()

    @abstractmethod
    def load_model(self) -> None:
        """
        Load the model and tokenizer.
        Should set self.model, self.tokenizer, and self.is_loaded
        """
        pass

    @abstractmethod
    def generate(self, prompt: str, **generation_kwargs) -> str:
        """
        Generate text based on the given prompt.

        Args:
            prompt: Input text prompt
            **generation_kwargs: Additional generation parameters

        Returns:
            Generated text
        """
        pass

    def detect_concepts(self, text: str) -> Dict[str, float]:
        """
        Detect concepts in the given text using registered concept detectors.

        A detector that raises is logged (with traceback) and scored 0.0, so one
        faulty detector never hides the results of the others.

        Args:
            text: Input text to analyze

        Returns:
            Dictionary mapping concept names to detection scores
        """
        if not self.concept_detectors:
            logger.warning("No concept detectors registered")
            return {}

        results = {}
        for name, detector in self.concept_detectors.items():
            try:
                results[name] = detector.detect(text)
            except Exception as e:
                logger.exception("Error in concept detector '%s': %s", name, e)
                results[name] = 0.0

        return results

    def register_concept_detector(self, name: str, detector: Any) -> None:
        """
        Register a concept detector, replacing any detector with the same name.

        Args:
            name: Name to identify the detector
            detector: Concept detector instance with a detect() method

        Raises:
            ValueError: If detector has no callable detect attribute.
        """
        if not callable(getattr(detector, 'detect', None)):
            raise ValueError("Concept detector must implement a detect() method")
        self.concept_detectors[name] = detector
        logger.info(f"Registered concept detector: {name}")

    def get_activations(self, layer: int = -1) -> Optional[torch.Tensor]:
        """
        Get activations from a specific layer.

        Activations are read from an optional ``self.activations`` mapping that
        subclasses may populate; this base class never sets it.

        Args:
            layer: Layer index to get activations from

        Returns:
            Tensor containing the activations, or None if not available
        """
        activations = getattr(self, 'activations', None)
        if activations is None or layer not in activations:
            logger.warning(f"No activations available for layer {layer}")
            return None
        return activations[layer]

    def steer_output(self, concept: str, strength: float = 0.5) -> bool:
        """
        Apply steering to the model's output based on a concept.

        Currently a placeholder: it validates the arguments and logs the request
        but does not modify generation.

        Args:
            concept: Name of the concept to steer towards/away from
            strength: Steering strength (-1.0 to 1.0)

        Returns:
            bool: True if the arguments are valid, False otherwise
        """
        if not (-1.0 <= strength <= 1.0):
            logger.error(f"Steering strength must be between -1.0 and 1.0, got {strength}")
            return False

        if concept not in self.concept_detectors:
            logger.error(f"No concept detector registered for: {concept}")
            return False

        # Implementation will vary by model
        logger.info(f"Steering {concept} with strength {strength}")
        return True

    def extract_features(self, texts: list, layer: int = -1, pooling: str = "mean") -> "np.ndarray":
        """
        Extract features (hidden states) from input texts.

        Args:
            texts: List of input strings to process.
            layer: Which model layer to extract features from (default: last).
            pooling: Pooling strategy to apply ("mean", "last", etc.).

        Returns:
            Array of extracted features for each input.

        Raises:
            NotImplementedError: If not implemented in subclass.
        """
        raise NotImplementedError("extract_features must be implemented by subclasses.")

    def __str__(self) -> str:
        """String representation of the model."""
        return f"{self.__class__.__name__}(model_name='{self.model_name}', device='{self.device}')"

# GPT-2 Model

In [2]:
"""
Simple GPT-2 Model Implementation

A concrete implementation of BaseModel providing access to Hugging Face's GPT-2 language model
with concept detection and steering capabilities.

Core Functionality:
1. Initialization
   - Configurable model size (e.g., "gpt2", "gpt2-medium")
   - Device management (CPU/GPU)
   - Generation parameter configuration

2. Model Management
   - Lazy loading of model weights: __init__ only records configuration;
     the tokenizer and weights are fetched when load_model() is called
   - Resource-efficient operation
   - Model verification

3. Text Generation
   - Prompt-based text completion
   - Configurable generation parameters
   - Integrated concept detection

4. Concept Integration
   - Dynamic concept registration
   - Real-time concept detection
   - Activation analysis

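5. Steering Capabilities (planned: steer_output() currently only validates
   its arguments and logs the request; it does not yet modify generation)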
   - Output modification based on concepts
   - Strength-based steering
   - Multi-concept interaction

Example Usage:
    >>> model = GPT2Model("gpt2", device="cuda")
    >>> model.load_model()
    >>> output = model.generate("The future of AI is")
    >>> print(output)

Test Strategy:
- Unit tests for individual components
- Integration tests for full pipeline
- Performance benchmarks
- Edge case validation

Note: This implementation follows the interface defined in BaseModel while
adding GPT-2 specific functionality.
"""

import torch
import logging
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from typing import Dict, Optional, Any
import numpy as np

logger = logging.getLogger(__name__)

class GPT2Model(BaseModel):
    """
    Implementation of the GPT-2 language model on top of BaseModel.

    Adds lazy loading of Hugging Face GPT-2 weights, text generation, and
    hidden-state feature extraction for training probes. Concept detection
    (detect_concepts) and the steering hook (steer_output) are inherited
    unchanged from BaseModel.
    """

    # Pooling strategies accepted by extract_features / extract_features_multi.
    _POOLING_STRATEGIES = ("mean", "last")

    def __init__(self, model_name: str = "gpt2", **kwargs: Any) -> None:
        """
        Initialize the GPT-2 model wrapper without loading any weights.

        Args:
            model_name: Name of the GPT-2 model (e.g., 'gpt2', 'gpt2-medium')
            **kwargs: Additional arguments passed to the base class
        """
        super().__init__(model_name=model_name, **kwargs)
        self.model = None
        self.tokenizer = None

    def load_model(self) -> None:
        """
        Load the GPT-2 model and tokenizer.

        This method:
        1. Loads the tokenizer (and reuses EOS as the padding token)
        2. Loads the model
        3. Moves the model to the specified device (CPU/GPU)
        4. Sets the model to evaluation mode

        Calling it again after a successful load is a no-op.
        """
        if self.is_loaded:
            logger.info(f"Model {self.model_name} is already loaded")
            return

        try:
            logger.info(f"Loading tokenizer for {self.model_name}")
            self.tokenizer = GPT2Tokenizer.from_pretrained(self.model_name)

            # GPT-2 has no padding token; reuse EOS so batches of different
            # lengths can be padded.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = GPT2LMHeadModel.from_pretrained(self.model_name)
            self.model.to(self.device)
            self.model.eval()

            self.is_loaded = True
            logger.info(f"Successfully loaded {self.model_name}")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def generate(self, prompt: str, **kwargs) -> str:
        """
        Generate a continuation of ``prompt``.

        Defaults for max_length / temperature / top_p come from the constructor,
        with sampling enabled. Every keyword argument is forwarded to the
        Hugging Face ``generate`` call and overrides the defaults. Note that
        ``max_length`` counts the prompt tokens too; pass ``max_new_tokens`` to
        bound only the continuation (the default max_length is then omitted).

        Args:
            prompt: Non-empty input text.
            **kwargs: Generation parameters forwarded to ``model.generate``.

        Returns:
            The decoded text (prompt plus continuation), special tokens removed.

        Raises:
            RuntimeError: If load_model() has not been called.
            ValueError: If the prompt is empty or whitespace only.
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        if not prompt.strip():
            raise ValueError("Prompt cannot be empty")

        try:
            encoded = self.tokenizer(prompt, return_tensors="pt")
            input_ids = encoded["input_ids"].to(self.device)

            gen_kwargs: Dict[str, Any] = {
                # Explicit mask: pad == eos for GPT-2, so generate() cannot infer it.
                "attention_mask": encoded["attention_mask"].to(self.device),
                "max_length": self.max_length,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "do_sample": True,
                "pad_token_id": self.tokenizer.eos_token_id,
            }
            if "max_new_tokens" in kwargs:
                gen_kwargs.pop("max_length")
            gen_kwargs.update(kwargs)

            with torch.no_grad():
                outputs = self.model.generate(input_ids, **gen_kwargs)

            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

    def _check_pooling(self, pooling: str) -> None:
        """Raise ValueError for an unknown pooling strategy (before any model work)."""
        if pooling not in self._POOLING_STRATEGIES:
            raise ValueError(f"Unknown pooling strategy: {pooling}")

    def _require_loaded(self) -> None:
        """Raise RuntimeError unless both the model and the tokenizer are available."""
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model and tokenizer must be loaded before extracting features.")

    def _forward_hidden_states(self, texts: list, max_length: int) -> Tuple[Any, torch.Tensor]:
        """Tokenize ``texts`` as one padded batch and return (hidden_states, attention_mask)."""
        self._require_loaded()
        self.model.eval()
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        return outputs.hidden_states, inputs["attention_mask"]

    @staticmethod
    def _pool(hidden: torch.Tensor, attention_mask: torch.Tensor, pooling: str) -> torch.Tensor:
        """Reduce a [batch, seq_len, hidden] tensor to [batch, hidden], ignoring padding."""
        if pooling == "mean":
            # Masked mean: padding positions must not contribute, otherwise a text's
            # features would depend on how long the other texts in its batch are.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            token_counts = mask.sum(dim=1).clamp(min=1.0)
            return (hidden * mask).sum(dim=1) / token_counts
        if pooling == "last":
            # Position of the last real token. Taking the argmax of mask * position
            # works for both right and left padding.
            positions = torch.arange(attention_mask.size(1), device=attention_mask.device)
            last_index = (attention_mask * positions).argmax(dim=1)
            rows = torch.arange(hidden.size(0), device=hidden.device)
            return hidden[rows, last_index]
        raise ValueError(f"Unknown pooling strategy: {pooling}")

    def extract_features(
        self,
        texts: list,
        layer: int = -1,
        pooling: str = "mean",
        max_length: int = 512,
    ) -> np.ndarray:
        """
        Extract one pooled hidden-state vector per input text using GPT-2.

        All texts are tokenized as a single padded batch and run in one forward
        pass, so very large lists should be chunked by the caller.

        Args:
            texts: List of input strings to process.
            layer: Index into the model's ``hidden_states`` tuple (default -1, the
                last entry). Per the Hugging Face docs, index 0 is the embedding
                output and index i is the output of transformer block i.
            pooling: "mean" (average over real tokens, padding excluded) or
                "last" (hidden state of the last real token).
            max_length: Texts are truncated to this many tokens (default 512).

        Returns:
            Array of shape (len(texts), hidden_size).

        Raises:
            ValueError: If ``pooling`` is unknown.
            RuntimeError: If the model/tokenizer have not been loaded.
            IndexError: If ``layer`` is out of range.
        """
        self._check_pooling(pooling)
        if len(texts) == 0:
            self._require_loaded()
            return np.empty((0, self.model.config.hidden_size), dtype=np.float32)

        hidden_states, attention_mask = self._forward_hidden_states(texts, max_length)
        pooled = self._pool(hidden_states[layer], attention_mask, pooling)
        return pooled.cpu().numpy()

    def extract_features_multi(
        self,
        texts: list,
        layers: list,
        pooling: str = "mean",
        max_length: int = 512,
    ) -> np.ndarray:
        """
        Extract pooled features from several GPT-2 layers in a single forward pass.

        Args:
            texts: List of input strings to process.
            layers: Non-empty list of indices into the ``hidden_states`` tuple.
            pooling: Pooling strategy to apply ("mean", "last").
            max_length: Texts are truncated to this many tokens (default 512).

        Returns:
            Array of shape (len(texts), len(layers), hidden_size).

        Raises:
            ValueError: If ``layers`` is empty or ``pooling`` is unknown.
            RuntimeError: If the model/tokenizer have not been loaded.
        """
        layers = list(layers)
        if not layers:
            raise ValueError("layers must contain at least one layer index")
        self._check_pooling(pooling)
        if len(texts) == 0:
            self._require_loaded()
            return np.empty((0, len(layers), self.model.config.hidden_size), dtype=np.float32)

        # One forward pass serves every requested layer.
        hidden_states, attention_mask = self._forward_hidden_states(texts, max_length)
        per_layer = [self._pool(hidden_states[layer], attention_mask, pooling) for layer in layers]
        return torch.stack(per_layer, dim=1).cpu().numpy()  # (batch, num_layers, hidden_dim)

# Train Probe

In [3]:
"""
1. **Imports and Argument Parsing**
    * Import necessary libraries (transformers, datasets, torch, etc.)
    * Parse command-line arguments for flexibility (e.g., batch size, layer, pooling type)

2. **Load Dataset**
    * Load IMDb dataset using HuggingFace Datasets

3. **Load GPT-2 Model and Tokenizer**
    * Set output_hidden_states=True

4. **Extract Hidden States**
    * Tokenize and batch the dataset
    * Pass through GPT-2
    * Pool/flatten hidden states as features

5. **Train Linear Probe**
    * Use PyTorch (or optionally scikit-learn for quick prototyping)
    * Train on extracted features and labels

6. **Evaluate and Save Results**
    * Evaluate on test set
    * Print and/or save metrics
"""

import transformers
import datasets
import torch
import numpy as np
import argparse
import logging
logging.basicConfig(level=logging.INFO)

def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """
    Parse command-line arguments for the probe-training script.

    Args:
        argv: Argument list to parse. None (the default) reads ``sys.argv[1:]``.
            Pass an explicit list (e.g. ``[]``) in notebooks, where sys.argv holds
            the kernel's own arguments rather than this script's.

    Returns:
        Namespace with ``batch_size``, ``probe_layer`` and ``pooling``.
    """
    parser = argparse.ArgumentParser(description="Train a linear probe on GPT-2 activations for sentiment.")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for processing data")
    parser.add_argument("--probe_layer", type=int, default=-1, help="Which GPT-2 layer to extract (default: last)")
    parser.add_argument("--pooling", type=str, choices=["mean", "last"], default="mean", help="Pooling strategy")
    return parser.parse_args(argv)

def load_imdb() -> datasets.DatasetDict:
    """
    Load the IMDb sentiment dataset with Hugging Face Datasets.

    As a sanity check, logs the first training example and the sizes of the
    train and test splits at INFO level.

    Returns:
        The DatasetDict returned by ``datasets.load_dataset("imdb")``.
    """
    imdb = datasets.load_dataset("imdb")
    train_split, test_split = imdb["train"], imdb["test"]
    logging.info("Train example: %s", train_split[0])
    logging.info("Train size: %d, Test size: %d", len(train_split), len(test_split))
    return imdb

In [None]:
# args = parse_args()  # Not used in Colab; the settings below are hard-coded instead.
dataset = load_imdb()

# --- Quick single-batch extraction for a sanity check ---
# When run as a script from the terminal, GPT2Model would come from
# `from models.gpt2_model import GPT2Model`; in this notebook it is defined above.
batch_size = 32

# Exploratory/debugging info. logging.debug() evaluates its arguments even when
# DEBUG is disabled, so guard the block to avoid slicing the dataset for nothing.
if logging.getLogger().isEnabledFor(logging.DEBUG):
    debug_batch = dataset["train"][:batch_size]
    logging.debug("Dataset keys: %s", dataset.keys())
    logging.debug("First item in train: %s", dataset["train"][0])
    logging.debug("Type of dataset['train']: %s", type(dataset["train"]))
    logging.debug("Type of dataset['train'][:batch_size]: %s", type(debug_batch))
    logging.debug("Type of dataset['train'][:batch_size]['text']: %s", type(debug_batch["text"]))
    logging.debug("Type of dataset['train'][:batch_size]['label']: %s", type(debug_batch["label"]))

# Sample a seeded random batch instead of the first rows. The IMDb train split is
# ordered by label (negative reviews first), so the first rows form a single-class
# batch and the probe would trivially fit a constant. The label-balance print
# below is a quick check. list() keeps this independent of the datasets version's
# column type.
train_sample = dataset["train"].shuffle(seed=42).select(range(batch_size))
train_texts = list(train_sample["text"])
train_labels = list(train_sample["label"])

logging.info("Loading GPT-2 model... (this may take 10+ minutes)")

# Load GPT-2 model (on CPU for now)
model = GPT2Model(model_name="gpt2", device="cpu")
model.load_model()

logging.info("Model loaded successfully!")

# Extract mean-pooled activations from layer 7
probe_layer = 7
logging.info("Extracting features from GPT-2...")
features = model.extract_features(train_texts, layer=probe_layer, pooling="mean")
logging.info("Feature extraction complete.")

# Final user-facing results
print("Features shape:", features.shape)
# shape is (batch_size, size of model hidden layer - in gpt2, this is 768)
print("First feature vector (first 10 dims):", features[0][:10])
print("First 5 labels:", train_labels[:5])
print("Label balance (neg, pos):", train_labels.count(0), train_labels.count(1))

# === Mini PyTorch probe training on a single batch ===
import torch
import torch.nn as nn
import torch.optim as optim
torch.manual_seed(42)

# Prepare data as tensors
X = torch.tensor(features, dtype=torch.float32)  # shape: (batch_size, hidden_dim), e.g. (32, 768)
y = torch.tensor(train_labels, dtype=torch.long)  # shape: (batch_size,)

# Define a simple linear probe (for binary sentiment: 2 classes)
probe = nn.Linear(X.shape[1], 2)  # hidden_dim -> 2
# Visualisation of nn.Linear: https://www.sharetechnote.com/html/Python_PyTorch_nn_Linear_01.html
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(probe.parameters(), lr=0.01)

print("X shape:", X.shape, "dtype:", X.dtype)
print("y shape:", y.shape, "dtype:", y.dtype)

In [6]:
## Turn logging into prints and see what happens
# Switch to GPU as and when necessary

# Track training time (perf_counter is monotonic, unlike time.time)
import time
train_start = time.perf_counter()
logging.info("Starting probe training...")

# Training loop. NOTE(review): 2 epochs is a smoke-test setting; raise it for a
# real fit.
max_epochs = 2
loss_threshold = 0.1
probe.train()
for epoch in range(max_epochs):
    optimizer.zero_grad()
    logits = probe(X)  # shape: (batch_size, 2)
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()

    # Loss computed before this step's update; read it once per epoch.
    loss_value = loss.item()
    converged = loss_value < loss_threshold
    is_last_epoch = epoch == max_epochs - 1
    if epoch % 10 == 0 or converged or is_last_epoch:
        logging.info(f"Epoch {epoch}: loss = {loss_value:.4f}")
    if converged:
        logging.info("Early stopping: loss below threshold.")
        break

train_elapsed = time.perf_counter() - train_start
logging.info(f"Probe training completed in {train_elapsed:.2f} seconds.")

# Evaluate on the same batch (training accuracy only - says nothing about generalization)
probe.eval()
with torch.no_grad():
    preds = probe(X).argmax(dim=1)
    accuracy = (preds == y).float().mean().item()
logging.info(f"Probe accuracy on this batch: {accuracy*100:.1f}% (expect high, will not generalize)")