In [13]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from torch.utils.data import Dataset, DataLoader
import json
from pathlib import Path
from transformer_lens.hook_points import HookPoint
from transformers import AutoTokenizer, AutoModelForCausalLM, PaliGemmaForConditionalGeneration

# from random import random
import random
def seed_everywhere(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    random.seed(seed)
    
SEED = 42
DEVICE = 'cuda'
seed_everywhere(SEED)


## exp

In [21]:
class ActivationDataset(Dataset):
    """Dataset for model activations with labels"""
    def __init__(self, activations, labels):
        self.activations = torch.tensor(activations, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
    
    def __len__(self):
        return len(self.activations)
    
    def __getitem__(self, idx):
        return self.activations[idx], self.labels[idx]

class LinearProbe(nn.Module):
    """Simple linear probe for classification"""
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)
    
    def forward(self, x):
        return self.linear(x)

class LinearProbingExperiment:
    def __init__(self, model_name="gemma", concept="animals"):
        self.model_name = model_name
        self.concept = concept
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    def load_activations(self, file_path):
        """Load activations from file (assuming numpy format)"""
        data = np.load(file_path, allow_pickle=True)
        return data['activations'], data['labels']
    
    def create_cat_dog_labels(self, texts):
        """Simple cat/dog labeler based on keywords"""
        labels = []
        for text in texts:
            text_lower = text.lower()
            if any(word in text_lower for word in ['cat', 'feline', 'kitten', 'meow']):
                labels.append(0)  # cat
            elif any(word in text_lower for word in ['dog', 'canine', 'puppy', 'bark', 'woof']):
                labels.append(1)  # dog
            else:
                labels.append(-1)  # neither/unknown
        return np.array(labels)
    
    def extract_gemma_activations(self, model, tokenizer, texts, layer_idx=-1):
        """Extract activations from Gemma model at specified layer"""
        activations = []
        
        with torch.no_grad():
            for text in texts:
                inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                
                # Forward pass with output_hidden_states=True
                outputs = model(**inputs, output_hidden_states=True)
                hidden_states = outputs.hidden_states[layer_idx]  # Get specified layer
                
                # Use mean pooling across sequence length
                activation = hidden_states.mean(dim=1).cpu().numpy()
                activations.append(activation[0])
        
        return np.array(activations)
    
    def train_sklearn_probe(self, X_train, y_train, X_test, y_test):
        """Train sklearn logistic regression probe"""
        probe = LogisticRegression(max_iter=1000, random_state=SEED)
        probe.fit(X_train, y_train)
        
        train_pred = probe.predict(X_train)
        test_pred = probe.predict(X_test)
        
        results = {
            'train_acc': accuracy_score(y_train, train_pred),
            'test_acc': accuracy_score(y_test, test_pred),
            'classification_report': classification_report(y_test, test_pred)
        }
        
        return probe, results
    
    def train_torch_probe(self, X_train, y_train, X_test, y_test, epochs=100):
        """Train PyTorch linear probe"""
        input_dim = X_train.shape[0] # is this "how many text snppets"?
        num_classes = len(np.unique(y_train))
        
        train_dataset = ActivationDataset(X_train, y_train)
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        
        probe = LinearProbe(input_dim, num_classes).to(self.device)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(probe.parameters(), lr=0.001)
        
        # Training loop
        for epoch in range(epochs):
            probe.train()
            total_loss = 0
            for batch_x, batch_y in train_loader:
                batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
                
                optimizer.zero_grad()
                outputs = probe(batch_x)
                loss = criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
            
            if epoch % 20 == 0:
                print(f"Epoch {epoch}, Loss: {total_loss/len(train_loader):.4f}")
        
        # Evaluation
        probe.eval()
        with torch.no_grad():
            X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(self.device)
            test_outputs = probe(X_test_tensor)
            test_pred = torch.argmax(test_outputs, dim=1).cpu().numpy()
            
            X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(self.device)
            train_outputs = probe(X_train_tensor)
            train_pred = torch.argmax(train_outputs, dim=1).cpu().numpy()
        
        results = {
            'train_acc': accuracy_score(y_train, train_pred),
            'test_acc': accuracy_score(y_test, test_pred),
            'classification_report': classification_report(y_test, test_pred)
        }
        
        return probe, results
    
    def run_experiment(self, gemma_activations, gemma_labels, 
                      polygemma_activations, polygemma_labels):
        """Run complete probing experiment"""
        print(f"Running linear probing experiment: {self.concept}")
        print(f"Gemma training data: {gemma_activations.shape}")
        print(f"PolyGemma test data: {polygemma_activations.shape}")
        
        # Train on Gemma, test on PolyGemma
        results = {}
    
        
        # sklearn probe
        print("\n--- Training sklearn probe ---")
        sklearn_probe, sklearn_results = self.train_sklearn_probe(
            gemma_activations, gemma_labels, # train
            polygemma_activations, polygemma_labels # test
        )
        results['sklearn'] = sklearn_results
        
        # PyTorch probe
        print("\n--- Training PyTorch probe ---")
        torch_probe, torch_results = self.train_torch_probe(
            gemma_activations, gemma_labels, # train
            polygemma_activations, polygemma_labels # test
        )
        results['torch'] = torch_results
        
        return results
    
    def save_results(self, results, output_path):
        """Save experiment results"""
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        
        # Convert numpy types to Python types for JSON serialization
        def convert_numpy(obj):
            if isinstance(obj, np.integer):
                return int(obj)
            elif isinstance(obj, np.floating):
                return float(obj)
            elif isinstance(obj, np.ndarray):
                return obj.tolist()
            return obj
        
        json_results = json.loads(json.dumps(results, default=convert_numpy))
        
        with open(output_path, 'w') as f:
            json.dump(json_results, f, indent=2)



# synth data

In [20]:
# Synthetic text data for cat/dog classification
text = [
    # Cat examples (label 0)
    "The fluffy cat purred softly on the windowsill, watching birds outside.",
    "My kitten loves to chase the red laser pointer around the living room.",
    "The orange tabby cat stretched lazily in the warm afternoon sunlight.",
    "She adopted a rescue cat from the local animal shelter last week.",
    "The cat's whiskers twitched as it stalked the toy mouse across the floor.",
    "Fluffy meowed loudly when her food bowl was empty this morning.",
    "The black cat gracefully jumped onto the kitchen counter with ease.",
    "My feline friend enjoys napping in cardboard boxes all day long.",
    "The cat's green eyes glowed mysteriously in the dim moonlight tonight.",
    "Her pet cat brings dead mice to the doorstep every morning.",
    "The Siamese cat has the most beautiful blue eyes I've ever seen.",
    "Tom cat climbed up the tall oak tree to escape the neighborhood dogs.",
    "The veterinarian said the kitten needs its vaccinations next month.",
    "My cat purrs so loudly it sounds like a tiny motor running.",
    "The calico cat had three adorable kittens in the barn yesterday.",
    "She trained her cat to use the toilet instead of a litter box.",
    "The Persian cat's long fur requires daily brushing to prevent matting.",
    "My indoor cat watches wildlife documentaries on TV with great interest.",
    "The stray cat finally trusted me enough to eat from my hand.",
    "Her cat knocked over the expensive vase while chasing a butterfly.",
    
    # Dog examples (label 1)
    "The golden retriever barked excitedly when his owner came home today.",
    "My dog loves to fetch tennis balls in the backyard every afternoon.",
    "The small puppy wagged its tail when meeting new people yesterday.",
    "She takes her German shepherd for long walks in the park.",
    "The dog's tail wagged furiously when it saw the treat jar.",
    "Max barked at the mailman who comes by every morning.",
    "The border collie herded the sheep expertly across the green field.",
    "My canine companion loves swimming in the lake during hot summers.",
    "The dog trainer taught the puppy basic commands like sit and stay.",
    "Her loyal dog waited patiently outside the grocery store for her.",
    "The beagle's nose led it straight to the hidden treats upstairs.",
    "My dog howls along with the sirens from passing fire trucks.",
    "The veterinarian recommended a special diet for the overweight bulldog.",
    "The rescue dog was nervous but gradually warmed up to us.",
    "My puppy chewed up my favorite pair of running shoes yesterday.",
    "The dog park was crowded with excited pups playing together today.",
    "Her service dog helps her navigate safely through busy city streets.",
    "The hunting dog pointed steadily at the birds hiding in bushes.",
    "My dog greets every visitor with enthusiastic tail wagging and jumping.",
    "The old dog slept peacefully by the fireplace on cold nights.",
    
    # Neutral/other examples (label 2) - neither cats nor dogs
    "The morning sun cast beautiful shadows across the empty parking lot.",
    "She enjoyed reading mystery novels while drinking her evening tea.",
    "The mathematics professor explained complex equations on the whiteboard clearly.",
    "Fresh vegetables from the farmers market made an excellent dinner tonight.",
    "The old library contained thousands of books on various subjects.",
    "He repaired the broken bicycle tire using tools from the garage.",
    "The weather forecast predicted rain for the entire weekend ahead.",
    "Students gathered in the cafeteria to discuss their upcoming project.",
    "The concert featured amazing performances by local musicians and bands.",
    "She planted colorful flowers in her garden beds this spring.",
    "The computer program crashed unexpectedly during the important presentation today.",
    "Ocean waves crashed against the rocky cliffs during the storm.",
    "The chef prepared an elaborate feast for the wedding celebration.",
    "Mountains covered in snow looked majestic against the clear sky.",
    "The museum displayed artifacts from ancient civilizations throughout history.",
    "Traffic was heavy on the highway during rush hour yesterday.",
    "The smartphone battery died right before the important phone call.",
    "Autumn leaves fell gently from the trees in vibrant colors.",
    "The construction workers finished building the new bridge ahead of schedule.",
    "She studied diligently for her final exams in the quiet library."
]

# Corresponding labels: 0=cat, 1=dog, 2=neutral
label = [
    # Cat labels (0)
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # Dog labels (1) 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    # Neutral labels (2)
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2
]

# Verify data consistency
print(f"Total texts: {len(text)}")
print(f"Total labels: {len(label)}")
print(f"Cat examples: {label.count(0)}")
print(f"Dog examples: {label.count(1)}")
print(f"Neutral examples: {label.count(2)}")

Total texts: 60
Total labels: 60
Cat examples: 20
Dog examples: 20
Neutral examples: 20


# load model & get act

In [19]:
def load_models_with_eval(model_name, device="cuda"):
    if "paligemma" in model_name.lower():
        from transformers import PaliGemmaForConditionalGeneration
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_name, 
            trust_remote_code=True,
            torch_dtype=torch.float32,  # Use fp16 for memory efficiency
            device_map=None  # We'll handle device placement manually
        )
        model = model.to(device)
        language_model = model.language_model
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name, 
            trust_remote_code=True,
            torch_dtype=torch.float32,
            device_map=None
        )
        model = model.to(device)
        language_model = model
            
    language_model.eval()

def get_acts(model, inputs):
    if hasattr(language_model, 'model') and hasattr(language_model.model, 'layers'):
        if layer < len(language_model.model.layers):
            target_layer = language_model.model.layers[layer]
        else:
            print(f"❌ Layer {layer} out of range. Model has {len(language_model.model.layers)} layers")
            return torch.randn(1, 64, 2304).to(self.device), 0.0
    activations = None


    def activation_hook(module, input, output):
        nonlocal activations
        try:
            if isinstance(output, tuple):
                activations = output[0].clone().detach()
            else:
                activations = output.clone().detach()
        except Exception as e:
            print(f"⚠️  Error in activation hook: {e}")

    # FIXED: More robust layer identification
    target_layer = None
    try:
        if hasattr(language_model, 'model') and hasattr(language_model.model, 'layers'):
            if layer < len(language_model.model.layers):
                target_layer = language_model.model.layers[layer]
            else:
                print(f"❌ Layer {layer} out of range. Model has {len(language_model.model.layers)} layers")
                return torch.randn(1, 64, 2304).to(self.device), 0.0
        elif hasattr(language_model, 'layers'):
            if layer < len(language_model.layers):
                target_layer = language_model.layers[layer]
            else:
                print(f"❌ Layer {layer} out of range. Model has {len(language_model.layers)} layers")
                return torch.randn(1, 64, 2304).to(self.device), 0.0
        else:
            print(f"❌ Could not find layers in model structure")
            return torch.randn(1, 64, 2304).to(self.device), 0.0
    except Exception as e:
        print(f"❌ Error accessing layer {layer}: {e}")
        return torch.randn(1, 64, 2304).to(self.device), 0.0

    if target_layer is None:
        print(f"❌ Could not find layer {layer}")
        return torch.randn(1, 64, 2304).to(self.device), 0.0

    hook = target_layer.register_forward_hook(activation_hook)

    # Forward pass to get activations
    with torch.no_grad():
        try:
            if "paligemma" in model_name.lower():
                _ = language_model(**inputs)
            else:
                _ = language_model(**inputs)
        except Exception as e:
            print(f"⚠️  Error in activation extraction: {e}")

    hook.remove()
    if activations is None:
        print(f"⚠️  Failed to extract activations from layer {layer}")
        # FIXED: Return appropriate tensor size based on model
        try:
            # Try to get the actual hidden size from the model config
            if hasattr(language_model, 'config') and hasattr(language_model.config, 'hidden_size'):
                hidden_size = language_model.config.hidden_size
            else:
                hidden_size = 2304  # fallback
            activations = torch.randn(1, 64, hidden_size).to(self.device)
        except:
            activations = torch.randn(1, 64, 2304).to(self.device)

    return activations

In [22]:
# Initialize experiment
experiment = LinearProbingExperiment(concept="cat_dog_classification")

# Example: Create dummy data for testing
# In practice, replace with actual Gemma/PolyGemma activations
np.random.seed(42)

# load models
model1_name = "google/gemma-2-2b"  # Base Gemma-2-2B (LLM)
model2_name = "google/paligemma2-3b-pt-224"

gemma = load_models_with_eval(model1_name, DEVICE)
paligemma = load_models_with_eval(model2_name, DEVICE)

# # Dummy activations (replace with real data)
# gemma_activations = np.random.randn(1000, 768)  # 1000 samples, 768 dims
# gemma_labels = np.random.randint(0, 2, 1000)    # binary cat/dog labels

# polygemma_activations = np.random.randn(200, 768)  # 200 test samples
# polygemma_labels = np.random.randint(0, 2, 200)

# Run experiment per layer
for layer in layer_to_test:

    # produce act-label pairs
    get_acts(model, text)
    # train probes on model 1 and test on model 2
    results = experiment.run_experiment(
        gemma_activations, gemma_labels,
        polygemma_activations, polygemma_labels
    )

# Print results
print("\n=== RESULTS ===")
for method, result in results.items():
    print(f"\n{method.upper()} Results:")
    print(f"Train Accuracy: {result['train_acc']:.4f}")
    print(f"Test Accuracy: {result['test_acc']:.4f}")
    print("Classification Report:")
    print(result['classification_report'])

# Save results
experiment.save_results(results, "../output/linear_probing_results.json")


Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 65.34it/s]


RuntimeError: No CUDA GPUs are available