In [None]:
# Log in to the Hugging Face Hub here; this is needed to access the gemma2 and llama3.1 models.
from huggingface_hub import notebook_login

notebook_login()

In [None]:
import torch
from transformer_lens import HookedTransformer
import psutil
import os
import gc

def print_gpu_utilization():
    """Print the CUDA memory PyTorch currently has allocated and reserved, in GB."""
    bytes_per_gb = 1024**3
    allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
    reserved_gb = torch.cuda.memory_reserved() / bytes_per_gb
    print(f"GPU memory allocated: {allocated_gb:.2f} GB")
    print(f"GPU memory reserved: {reserved_gb:.2f} GB")

def print_system_utilization():
    """Print this process's resident memory and the system-wide used/available memory, in GB."""
    bytes_per_gb = 1024**3
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    print(f"CPU memory used: {rss_bytes / bytes_per_gb:.2f} GB")
    print(f"System memory used: {psutil.virtual_memory().used / bytes_per_gb:.2f} GB")
    print(f"System memory available: {psutil.virtual_memory().available / bytes_per_gb:.2f} GB")

def clear_memory():
    """Run Python garbage collection, then ask PyTorch to release its cached CUDA memory."""
    # Order matters: collect unreachable Python objects first, then empty the CUDA cache.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

def load_model_with_monitoring():
    """Load "gemma-2-9b" as a HookedTransformer on CUDA, printing memory usage around the load.

    The model is first requested in bfloat16. If that attempt raises a
    RuntimeError, memory is cleared and the load is retried once in float16;
    an error from the retry propagates to the caller.

    Returns:
        The loaded HookedTransformer.
    """
    device = "cuda"

    def _report(header):
        # Print a header followed by the GPU and system memory figures.
        print(header)
        print_gpu_utilization()
        print_system_utilization()

    def _load(dtype):
        # Single place for the from_pretrained arguments shared by both attempts.
        return HookedTransformer.from_pretrained(
            "gemma-2-9b",
            device=device,
            torch_dtype=dtype,
            low_cpu_mem_usage=False
        )

    _report("Initial state:")
    print("-" * 50)

    clear_memory()

    # Allocator setting is exported through the environment before the model is loaded.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

    try:
        print("Attempting to load with bfloat16...")
        model = _load(torch.bfloat16)
        _report("\nAfter loading:")
        return model

    except RuntimeError as e:
        print(f"\n[Warning] bfloat16 failed: {e}")
        print("Retrying with float16...")
        clear_memory()

        model = _load(torch.float16)
        _report("\nAfter loading with float16:")
        return model

# Entry point for the model-loading cell: load the model and report success or failure.
if __name__ == "__main__":
    # Make sure autograd is globally enabled before the model is created.
    torch.set_grad_enabled(True)
    try:
        model = load_model_with_monitoring()
        print("\nModel loaded successfully!")
    except Exception as e:
        # NOTE(review): the failure is only printed, so `model` is not assigned by this
        # cell in that case; the later cell that passes `language_model=model` depends on it.
        print(f"\nFailed to load model: {e}")

In [2]:
from sae_lens import SAE
# Load the pretrained sparse autoencoder; the call is unpacked into the SAE itself,
# its config dictionary and a sparsity value.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-9b-pt-res-canonical", # SAE release name; available releases are listed in https://github.com/jbloomAus/SAELens/blob/main/sae_lens/pretrained_saes.yaml
    sae_id = "layer_20/width_16k/canonical", # SAE within the release; the id names layer 20, the same layer later passed as target_layer=20
)

In [None]:
from datasets import load_dataset

# Download the politics demo dataset; the result is indexed by split name below.
dataset = load_dataset("Zirui22Ray/politics-dataset-demo")
# Hold out 10% of the 'train' split as a test set; seed=42 keeps the split reproducible.
split = dataset['train'].train_test_split(test_size=0.1, seed=42)
# NOTE: `dataset` is rebound here from the full download to the 90% training portion,
# so re-running only the lines below without the load above would fail.
dataset = split['train']
test_dataset = split['test']

print("train set:", len(dataset))
print("test set:", len(test_dataset))

In [4]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from tqdm.auto import tqdm
import os
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Subset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
import gc

class LinearConceptExtractor:
    def __init__(self, sae, language_model, target_layer=None, device='cuda'):
        self.sae = sae
        self.language_model = language_model
        self.target_layer = target_layer if target_layer is not None else language_model.cfg.n_layers - 1
        self.device = device
        self.sae.to(device)
        self.d_sae = sae.cfg.d_sae
        
        print(f"Initializing LinearConceptExtractor, using model layer {self.target_layer+1}/{language_model.cfg.n_layers}")
    
    def precompute_latents(self, text_dataset, top_k=0, batch_size=16):
        """Precomputes latent representations for the text dataset."""
        print(f"Precomputing latent representations, total {len(text_dataset['text'])} samples...")
        
        all_latents = []
        all_labels = []
        
        for i in tqdm(range(0, len(text_dataset['text']), batch_size)):
            batch_texts = text_dataset['text'][i:i+batch_size]
            batch_labels = text_dataset['label'][i:i+batch_size]
            
            batch_latents = []
            for text in batch_texts:
                tokens = self.language_model.to_tokens(text)
                with torch.no_grad():
                    logits, cache = self.language_model.run_with_cache(tokens)
                    
                    token_residual = cache['resid_post', self.target_layer][0, -1, :]
                    latent = self.sae.encode(token_residual.unsqueeze(0)).squeeze(0).to(torch.float32)
                    
                    if top_k > 0 and top_k < latent.shape[0]:
                        values, indices = torch.topk(latent.abs(), k=top_k)
                        selected_values = latent[indices]
                        batch_latents.append(selected_values.cpu())
                    else:
                        batch_latents.append(latent.cpu())
            
            all_latents.extend(batch_latents)
            all_labels.extend(batch_labels)
        
        print(f"Precomputation completed, total {len(all_latents)} latent representations")
        return all_latents, all_labels
    
    def select_important_features(self, latents, labels, top_k=4000):
        """Selects the features that best separate the label groups.

        Every feature is scored with a one-way ANOVA style F-ratio
        (between-group variance divided by within-group variance); features
        whose within-group variance is below 1e-10 score 0. The ``top_k``
        highest scoring features are kept.

        Args:
            latents: Sequence of per-sample feature vectors (torch tensors or
                array-likes), all of the same length.
            labels: Sequence of class labels, one per sample.
            top_k: Number of features to keep. Values larger than the number
                of available features are clamped to that number.

        Returns:
            tuple: ``(selected_latents, selected_indices)`` where
            ``selected_latents`` has shape (n_samples, n_selected) and
            ``selected_indices`` holds the chosen column indices ordered by
            ascending score (best feature last).

        Raises:
            ValueError: If ``top_k`` is not positive, if fewer than two
                classes are present, or if there are not more samples than
                classes (the F-ratio is undefined in those cases).
        """
        if top_k < 1:
            raise ValueError(f"top_k must be a positive integer, got {top_k}")

        print(f"Performing feature selection, selecting {top_k} most important features from the original features...")

        if isinstance(latents[0], torch.Tensor):
            latents_np = np.array([l.cpu().numpy() for l in latents])
        else:
            latents_np = np.array(latents)

        labels_np = np.array(labels)
        n_samples, n_features = latents_np.shape

        # The class membership masks do not depend on the feature: build them once.
        class_masks = [labels_np == cls for cls in np.unique(labels_np)]
        n_groups = len(class_masks)
        if n_groups < 2:
            raise ValueError("Feature selection needs at least two distinct labels")
        if n_samples <= n_groups:
            raise ValueError("Feature selection needs more samples than classes")

        feature_scores = np.zeros(n_features)

        for i in range(n_features):
            feature = latents_np[:, i]
            groups = [feature[mask] for mask in class_masks]

            means = [np.mean(g) for g in groups]
            overall_mean = np.mean(feature)

            between_group_var = sum(
                len(g) * (m - overall_mean) ** 2 for g, m in zip(groups, means)
            ) / (n_groups - 1)
            within_group_var = sum(
                np.sum((g - m) ** 2) for g, m in zip(groups, means)
            ) / (n_samples - n_groups)

            # (Near-)constant features inside every class keep their score of 0.
            if within_group_var < 1e-10:
                continue
            feature_scores[i] = between_group_var / within_group_var

        # Never report or request more features than exist.
        n_selected = min(top_k, n_features)
        selected_indices = np.argsort(feature_scores)[-n_selected:]
        selected_latents = latents_np[:, selected_indices]

        print(f"Feature selection completed, reduced from {n_features} to {n_selected} features")

        return selected_latents, selected_indices
    
    def normalize_features(self, features):
        """Standardizes every feature column to zero mean and unit variance.

        Args:
            features: Sequence of per-sample feature vectors (torch tensors or
                array-likes).

        Returns:
            tuple: ``(normalized, mean, std)``; ``std`` has its zero entries
            replaced by 1 so constant columns are left at 0 instead of
            dividing by zero.
        """
        holds_tensors = isinstance(features[0], torch.Tensor)
        features_np = np.array(
            [f.cpu().numpy() for f in features] if holds_tensors else features
        )

        mean = features_np.mean(axis=0)
        std = features_np.std(axis=0)
        std[std == 0] = 1

        return (features_np - mean) / std, mean, std
    
    def train_linear_classifier(self, latents, labels, val_size=0.2, batch_size=32,
                                num_epochs=20, lr=1e-4, weight_decay=5e-2):
        """Trains a two-class linear classifier on the latents and evaluates it.

        The data is split with a stratified, fixed-seed (42) hold-out of
        ``val_size``; a single ``nn.Linear(input_dim, 2)`` layer is trained
        with AdamW and cross-entropy on the remainder and scored on the
        held-out part.

        Args:
            latents: Array of shape (n_samples, input_dim); ``.shape`` is read.
            labels: Class labels, one per sample.
            val_size: Fraction of samples held out for evaluation.
            batch_size: Mini-batch size for training and evaluation.
            num_epochs: Number of passes over the training portion.
            lr: AdamW learning rate.
            weight_decay: AdamW weight decay.

        Returns:
            tuple: ``(linear_classifier, test_acc, test_f1, model_info)`` with
            accuracy and weighted F1 expressed in percent.
        """
        print("Training linear classifier...")

        x_train, x_eval, y_train, y_eval = train_test_split(
            latents, labels, test_size=val_size, random_state=42, stratify=labels
        )

        def _make_loader(features, targets, shuffle):
            # Wrap one split as float32 features / int64 targets.
            split_data = TensorDataset(
                torch.tensor(features, dtype=torch.float32),
                torch.tensor(targets, dtype=torch.long),
            )
            return DataLoader(split_data, batch_size=batch_size, shuffle=shuffle)

        train_loader = _make_loader(x_train, y_train, shuffle=True)
        eval_loader = _make_loader(x_eval, y_eval, shuffle=False)

        input_dim = latents.shape[1]
        linear_classifier = nn.Linear(input_dim, 2).to(self.device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(linear_classifier.parameters(), lr=lr, weight_decay=weight_decay)

        for epoch_number in range(1, num_epochs + 1):
            linear_classifier.train()
            running_loss = 0

            for inputs, targets in train_loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                loss = criterion(linear_classifier(inputs), targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            avg_loss = running_loss / len(train_loader)
            # Report the first epoch and every fifth one.
            if epoch_number == 1 or epoch_number % 5 == 0:
                print(f"Epoch {epoch_number}/{num_epochs}, Loss: {avg_loss:.4f}")

        linear_classifier.eval()
        predicted_labels = []
        reference_labels = []

        with torch.no_grad():
            for inputs, targets in eval_loader:
                logits = linear_classifier(inputs.to(self.device))
                predicted_labels.extend(logits.argmax(dim=1).cpu().numpy())
                reference_labels.extend(targets.numpy())

        test_acc = 100 * accuracy_score(reference_labels, predicted_labels)
        test_f1 = 100 * f1_score(reference_labels, predicted_labels, average='weighted')

        print(f"Linear classifier training completed, test accuracy: {test_acc:.2f}%, F1: {test_f1:.2f}%")
        print("\nClassification report:")
        print(classification_report(reference_labels, predicted_labels))

        model_info = {
            'input_dim': input_dim,
            'test_accuracy': test_acc,
            'test_f1': test_f1,
            'target_layer': self.target_layer
        }

        return linear_classifier, test_acc, test_f1, model_info
    
    def extract_difference_vector(self, classifier):
        """Returns the unit-length difference between the class-1 and class-0 weight rows.

        Args:
            classifier: Module whose ``weight`` has at least two rows.

        Returns:
            np.ndarray: ``weight[1] - weight[0]`` divided by its L2 norm.
        """
        weights = classifier.weight.detach().cpu().numpy()
        direction = weights[1] - weights[0]
        return direction / np.linalg.norm(direction)
    
    def extract_concept_vector(self, classifier, class_idx=1):
        """Returns the unit-length weight row of one class.

        Args:
            classifier: Module exposing a ``weight`` matrix with one row per class.
            class_idx: Row (class) to extract.

        Returns:
            np.ndarray: ``weight[class_idx]`` divided by its L2 norm.
        """
        class_weights = classifier.weight.detach().cpu().numpy()[class_idx]
        return class_weights / np.linalg.norm(class_weights)
    
    def train_multiple_classifiers(self, train_dataset, num_classifiers=50, subset_size=0.5, 
                                     num_epochs=10, batch_size=32, lr=1e-4):
        """Trains multiple classifiers to construct a concept subspace."""
        print(f"Training {num_classifiers} linear classifiers to build concept subspace...")
        concept_vectors = []
        subset_size = int(len(train_dataset) * subset_size)
        
        input_dim = train_dataset[0][0].shape[0]
        
        for i in tqdm(range(num_classifiers)):
            indices = torch.randperm(len(train_dataset))[:subset_size]
            subset = Subset(train_dataset, indices)
            
            dataloader = DataLoader(subset, batch_size=batch_size, shuffle=True)
            
            classifier = nn.Linear(input_dim, 2).to(self.device)
            
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(classifier.parameters(), lr=lr)
            
            for epoch in range(num_epochs):
                for inputs, labels in dataloader:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    
                    outputs = classifier(inputs)
                    loss = criterion(outputs, labels)
                    
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
            
            concept_vector = self.extract_concept_vector(classifier, class_idx=1)
            concept_vectors.append(concept_vector)
            
            del classifier, dataloader, subset
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        
        return np.array(concept_vectors)
    
    def analyze_vector_similarities(self, vectors):
        """Computes pairwise cosine similarities between vectors.

        Args:
            vectors: Sequence or 2-D array of equal-length vectors.

        Returns:
            tuple: ``(similarity_matrix, avg_similarity)`` where
            ``similarity_matrix`` is the (n, n) float64 cosine-similarity
            matrix and ``avg_similarity`` is the mean of its off-diagonal
            entries. ``avg_similarity`` is NaN when fewer than two vectors
            are given, because there is no pair to average over. Rows with
            zero norm produce NaN similarities.
        """
        n_vectors = len(vectors)
        if n_vectors == 0:
            return np.zeros((0, 0)), float('nan')

        stacked = np.asarray(vectors, dtype=np.float64)
        norms = np.linalg.norm(stacked, axis=1)

        # One matrix product replaces the O(n^2) Python-level double loop.
        similarity_matrix = (stacked @ stacked.T) / np.outer(norms, norms)

        if n_vectors < 2:
            return similarity_matrix, float('nan')

        off_diagonal = ~np.eye(n_vectors, dtype=bool)
        avg_similarity = similarity_matrix[off_diagonal].mean()

        return similarity_matrix, avg_similarity
    
    def extract_concept_vectors(self, text_dataset, feature_dim=128, num_classifiers=50, 
                                subset_size=0.5, num_epochs=10, output_dir="concept_vectors"):
        """The complete pipeline for extracting concept vectors.

        Runs, in order: latent precomputation, feature selection down to
        ``feature_dim`` columns, normalization, training of one main linear
        classifier, extraction of the class-1 / class-0 / difference vectors,
        training of ``num_classifiers`` extra classifiers on random subsets,
        a pairwise similarity analysis of their vectors, and saving of all
        artifacts into ``output_dir``.

        Args:
            text_dataset: Dataset passed to ``precompute_latents``.
            feature_dim: Number of SAE features kept by feature selection.
            num_classifiers: Number of subset classifiers to train.
            subset_size: Fraction of the data sampled for each subset classifier.
            num_epochs: Epochs per subset classifier (the main classifier
                always trains for 20 epochs, see below).
            output_dir: Directory (created if missing) that receives the saved files.

        Returns:
            tuple: ``(results, classifier)`` -- the dictionary built at the
            end of this method and the trained main classifier.
        """
        os.makedirs(output_dir, exist_ok=True)
        
        # Step 1: Precompute latent representations
        original_latents, original_labels = self.precompute_latents(
            text_dataset=text_dataset,
            top_k=0, 
            batch_size=16
        )
        
        # Step 2: Feature selection
        latents_np, selected_indices = self.select_important_features(
            original_latents, original_labels, top_k=feature_dim
        )
        
        # The full-width latents are no longer needed once the selected columns exist.
        del original_latents
        gc.collect()
        
        # Step 3: Feature normalization
        normalized_latents, mean, std = self.normalize_features(latents_np)
        
        # Step 4: Train the main linear classifier (hyperparameters are fixed here,
        # independent of this method's num_epochs argument)
        classifier, test_acc, test_f1, model_info = self.train_linear_classifier(
            latents=normalized_latents,
            labels=np.array(original_labels),
            val_size=0.2,
            batch_size=32,
            num_epochs=20,
            lr=1e-4,
            weight_decay=5e-2
        )
        
        # Update model information with what is needed to reproduce the preprocessing
        # (stored as plain lists)
        model_info.update({
            'selected_indices': selected_indices.tolist(),
            'feature_mean': mean.tolist(),
            'feature_std': std.tolist(),
            'original_dim': self.d_sae,
            'reduced_dim': feature_dim
        })
        
        # Step 5: Extract basic concept vectors.
        # "truthful" is the class-1 weight row and "false" the class-0 row of the classifier.
        truthful_vector = self.extract_concept_vector(classifier, class_idx=1)
        false_vector = self.extract_concept_vector(classifier, class_idx=0)
        difference_vector = self.extract_difference_vector(classifier)
        
        # Save basic vectors
        self._save_basic_vectors(output_dir, truthful_vector, false_vector, difference_vector)
        
        # Step 6: Create a training dataset for the multiple classifiers method
        # (uses all samples, including those held out when evaluating the main classifier)
        latents_tensor = torch.tensor(normalized_latents, dtype=torch.float32)
        labels_tensor = torch.tensor(original_labels, dtype=torch.long)
        train_dataset_for_multi = TensorDataset(latents_tensor, labels_tensor)
        
        # Step 7: Train multiple linear classifiers to build the concept subspace
        multiple_vectors = self.train_multiple_classifiers(
            train_dataset=train_dataset_for_multi,
            num_classifiers=num_classifiers,
            subset_size=subset_size,
            num_epochs=num_epochs,
            batch_size=32,
            lr=1e-4
        )
        
        # Step 8: Analyze concept vector similarities
        sim_matrix, avg_sim = self.analyze_vector_similarities(multiple_vectors)
        
        # Step 9: Save all results (here the arrays are kept as NumPy objects, not lists)
        results = {
            'selected_indices': selected_indices,
            'reduced_dim': feature_dim,
            'original_dim': self.d_sae,
            'feature_mean': mean,
            'feature_std': std,
            'truthful_vector': truthful_vector,
            'false_vector': false_vector,
            'difference_vector': difference_vector,
            'multiple_vectors': multiple_vectors,
            'similarity_matrix': sim_matrix,
            'average_similarity': avg_sim
        }
        
        self._save_results(output_dir, results, model_info, classifier)
        
        print(f"Concept vector extraction completed! All results saved to directory {output_dir}")
        return results, classifier

    def _save_basic_vectors(self, output_dir, truthful_vector, false_vector, difference_vector):
        """Saves the basic vectors to files."""
        np.save(os.path.join(output_dir, "truthful_vector.npy"), truthful_vector)
        np.save(os.path.join(output_dir, "false_vector.npy"), false_vector)
        np.save(os.path.join(output_dir, "difference_vector.npy"), difference_vector)
    
    def _save_results(self, output_dir, results, model_info, classifier):
        """Saves the results to files."""
        torch.save(results, os.path.join(output_dir, "concept_vectors_full.pt"))
        torch.save(model_info, os.path.join(output_dir, "model_info.pt"))
        
        model_save_path = os.path.join(output_dir, f"linear_classifier_layer_{self.target_layer}.pt")
        torch.save({
            'model_state_dict': classifier.state_dict(),
            'model_info': model_info
        }, model_save_path)


def analyze_truthfulness_with_concept(text, concept_vector, selected_indices, sae, language_model, 
                                        target_layer=20, normalize=True, mean=None, std=None):
    """Scores a text by the cosine similarity between its SAE latent and a concept vector.

    The text is run through ``language_model``; the residual stream after
    ``target_layer`` at the last token is encoded by ``sae``, reduced to
    ``selected_indices`` and optionally standardized before the cosine
    similarity with ``concept_vector`` is taken.

    Args:
        text: Input string.
        concept_vector: Vector in the reduced feature space.
        selected_indices: Indices of the SAE features that form the reduced space.
        sae: Sparse autoencoder providing ``encode``.
        language_model: Model providing ``to_tokens`` and ``run_with_cache``.
        target_layer: Layer whose ``resid_post`` activation is used.
        normalize: Whether to standardize with ``mean`` / ``std``; only
            applied when both are given.
        mean: Per-feature mean of the reduced features, or None.
        std: Per-feature standard deviation of the reduced features, or None.

    Returns:
        The cosine similarity, or ``0.0`` when either vector has zero norm
        (e.g. none of the selected sparse features is active), where the
        cosine is undefined and a plain division would yield NaN.
    """
    # Only this one activation is needed: cache just this hook and stop the
    # forward pass right after the target layer.
    hook_name = f"blocks.{target_layer}.hook_resid_post"

    tokens = language_model.to_tokens(text)
    with torch.no_grad():
        _, cache = language_model.run_with_cache(
            tokens, names_filter=hook_name, stop_at_layer=target_layer + 1
        )
        # First batch element, last token position.
        token_residual = cache[hook_name][0, -1, :]
        full_latent = sae.encode(token_residual.unsqueeze(0)).squeeze(0).to(torch.float32).cpu().numpy()
    
    reduced_latent = full_latent[selected_indices]
    
    if normalize and mean is not None and std is not None:
        reduced_latent = (reduced_latent - mean) / std
    
    latent_norm = np.linalg.norm(reduced_latent)
    concept_norm = np.linalg.norm(concept_vector)
    if latent_norm == 0 or concept_norm == 0:
        return 0.0

    similarity = np.dot(reduced_latent / latent_norm, concept_vector / concept_norm)
    
    return similarity


def _lookup_sample_field(sample, field):
    """Returns the value of ``field`` ('text' or 'label') from one dataset sample.

    Tried in order: an attribute named ``field``, a key named ``field``, then
    the first key whose lower-cased name contains ``field``.

    Raises:
        ValueError: If none of the lookups finds the field.
    """
    if hasattr(sample, field):
        return getattr(sample, field)
    if field in sample:
        return sample[field]
    for key in sample.keys():
        if field in key.lower():
            return sample[key]
    raise ValueError(f"Could not find {field} field in dataset")


def evaluate_concept_vector(test_dataset, concept_vector, selected_indices, sae, language_model, 
                            target_layer, normalize=True, mean=None, std=None, output_dir="."):
    """Evaluates the performance of a concept vector on the test set.

    Every sample is scored with ``analyze_truthfulness_with_concept``; a
    score above 0 is predicted as label 1, anything else as label 0.

    Args:
        test_dataset: Indexable dataset whose samples expose a text and a label field.
        concept_vector: Vector in the reduced feature space.
        selected_indices: Indices of the SAE features that form the reduced space.
        sae: Sparse autoencoder.
        language_model: Language model used to compute activations.
        target_layer: Layer whose activations are scored.
        normalize: Forwarded to the scoring function.
        mean: Per-feature mean forwarded to the scoring function.
        std: Per-feature standard deviation forwarded to the scoring function.
        output_dir: Directory that is created if missing; this function does
            not write any file into it.

    Returns:
        tuple: ``(accuracy, f1, conf_matrix)`` with accuracy and weighted F1
        as fractions and the confusion matrix from scikit-learn.

    Raises:
        ValueError: If a sample has no recognizable text or label field.
    """
    os.makedirs(output_dir, exist_ok=True)
    
    predictions = []
    true_labels = []
    
    print(f"Evaluating concept vector on {len(test_dataset)} test samples...")
    
    for i in tqdm(range(len(test_dataset))):
        sample = test_dataset[i]
        text = _lookup_sample_field(sample, 'text')
        true_label = _lookup_sample_field(sample, 'label')
        
        score = analyze_truthfulness_with_concept(
            text, concept_vector, selected_indices, sae, language_model, 
            target_layer, normalize, mean, std
        )
        
        # The sign of the similarity decides the class.
        predictions.append(1 if score > 0 else 0)
        true_labels.append(true_label)
    
    accuracy = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, average='weighted')
    conf_matrix = confusion_matrix(true_labels, predictions)
    
    print("Test set evaluation results:")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"F1 score: {f1:.4f}")
    print("\nClassification report:")
    print(classification_report(true_labels, predictions))
    
    return accuracy, f1, conf_matrix


def run_extraction_pipeline(train_dataset, test_dataset, sae, language_model, target_layer=20,
                            feature_dim=128, num_classifiers=50, subset_size=0.5, num_epochs=10,
                            output_dir="truthfulness_vectors"):
    """Runs extraction on the training data, then the example checks and the test-set evaluation.

    Args:
        train_dataset: Dataset used to extract the concept vectors.
        test_dataset: Dataset used for the final evaluation.
        sae: Sparse autoencoder.
        language_model: Language model used to compute activations.
        target_layer: Layer whose activations are analysed.
        feature_dim: Number of SAE features kept by feature selection.
        num_classifiers: Number of subset classifiers to train.
        subset_size: Fraction of the data sampled for each subset classifier.
        num_epochs: Epochs per subset classifier.
        output_dir: Directory (created if missing) that receives all saved files,
            including ``full_evaluation_results.pt``.

    Returns:
        tuple: ``(results, classifier)`` as returned by
        ``LinearConceptExtractor.extract_concept_vectors``.
    """
    os.makedirs(output_dir, exist_ok=True)

    print(f"Starting concept vector extraction from layer {target_layer} for truthfulness detection...")

    # Build the extractor on the GPU when one is available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    extractor = LinearConceptExtractor(
        sae=sae,
        language_model=language_model,
        target_layer=target_layer,
        device=device
    )

    results, classifier = extractor.extract_concept_vectors(
        text_dataset=train_dataset,
        feature_dim=feature_dim,
        num_classifiers=num_classifiers,
        subset_size=subset_size,
        num_epochs=num_epochs,
        output_dir=output_dir
    )

    # Pull out what the checks below need.
    selected_indices, difference_vector, truthful_vector, mean, std = (
        results[key]
        for key in ('selected_indices', 'difference_vector', 'truthful_vector',
                    'feature_mean', 'feature_std')
    )

    _run_example_tests(difference_vector, truthful_vector, selected_indices, sae, language_model,
                       target_layer, mean, std)

    print("\nEvaluating on test dataset...")
    accuracy, f1, conf_matrix = evaluate_concept_vector(
        test_dataset, difference_vector, selected_indices, sae, language_model,
        target_layer, normalize=True, mean=mean, std=std, output_dir=output_dir
    )

    # Persist the evaluation metrics next to the other artifacts.
    torch.save(
        {
            'accuracy': accuracy,
            'f1': f1,
            'confusion_matrix': conf_matrix.tolist()
        },
        os.path.join(output_dir, "full_evaluation_results.pt")
    )

    print(f"\nConcept vector extraction completed! All results saved to directory {output_dir}")

    return results, classifier


def _run_example_tests(difference_vector, truthful_vector, selected_indices, sae, language_model,
                       target_layer, mean, std):
    """Scores two fixed example sentences against the difference vector and the truthful vector.

    For each of the two concept vectors, both example sentences are scored
    with ``analyze_truthfulness_with_concept`` (normalization enabled) and
    the two scores are printed.
    """
    truthful_example = "The government should provide universal healthcare coverage for all citizens."
    false_example = "Healthcare should remain a private service with minimal government intervention."

    # (header line, concept vector) for each round of the check, in print order.
    rounds = (
        ("\nConcept vector test results (using difference vector):", difference_vector),
        ("\nConcept vector test results (using truthful vector):", truthful_vector),
    )

    for header, concept_vector in rounds:
        truth_score, false_score = (
            analyze_truthfulness_with_concept(
                example, concept_vector, selected_indices, sae, language_model,
                target_layer, normalize=True, mean=mean, std=std
            )
            for example in (truthful_example, false_example)
        )

        print(header)
        print(f"Truthful statement score: {truth_score:.4f}")
        print(f"False statement score: {false_score:.4f}")

In [5]:
# Run the full extraction + evaluation pipeline on the politics data.
# Relies on `dataset`, `test_dataset`, `sae` and `model` created by the earlier cells.
results, classifier = run_extraction_pipeline(
    train_dataset=dataset,
    test_dataset=test_dataset,
    sae=sae,
    language_model=model,
    target_layer=20,   # Same layer number as in the SAE id loaded above ("layer_20/...")
    feature_dim=128,   # Number of SAE features kept for the initial experiments
    output_dir="politics_vectors_gemma9layer20"
)

Starting concept vector extraction from layer 20 for truthfulness detection...
Initializing LinearConceptExtractor, using model layer 21/42
Precomputing latent representations, total 9000 samples...


  0%|          | 0/563 [00:00<?, ?it/s]

Precomputation completed, total 9000 latent representations
Performing feature selection, selecting 128 most important features from the original features...
Feature selection completed, reduced from 16384 to 128 features
Training linear classifier...
Epoch 1/20, Loss: 0.5553
Epoch 5/20, Loss: 0.2105
Epoch 10/20, Loss: 0.1527
Epoch 15/20, Loss: 0.1299
Epoch 20/20, Loss: 0.1171
Linear classifier training completed, test accuracy: 96.50%, F1: 96.50%

Classification report:
              precision    recall  f1-score   support

           0       0.95      0.98      0.97       898
           1       0.98      0.95      0.96       902

    accuracy                           0.96      1800
   macro avg       0.97      0.97      0.96      1800
weighted avg       0.97      0.96      0.96      1800

Training 50 linear classifiers to build concept subspace...


  0%|          | 0/50 [00:00<?, ?it/s]

Concept vector extraction completed! All results saved to directory politics_vectors_gemma9layer20

Concept vector test results (using difference vector):
Truthful statement score: -0.4201
False statement score: 0.1763

Concept vector test results (using truthful vector):
Truthful statement score: -0.3511
False statement score: 0.1146

Evaluating on test dataset...
Evaluating concept vector on 1000 test samples...


  0%|          | 0/1000 [00:00<?, ?it/s]

Test set evaluation results:
Accuracy: 0.9600
F1 score: 0.9600

Classification report:
              precision    recall  f1-score   support

           0       0.94      0.98      0.96       511
           1       0.98      0.94      0.96       489

    accuracy                           0.96      1000
   macro avg       0.96      0.96      0.96      1000
weighted avg       0.96      0.96      0.96      1000


Concept vector extraction completed! All results saved to directory politics_vectors_gemma9layer20


In [6]:
# 1. Use the existing dataset
print("Using the existing politics dataset...")
# `dataset` and `test_dataset` are expected to exist already (see the dataset-loading cell).
# Split the training rows by label: 1 goes to the "truthful" list, 0 to the "false" list.
train_truthful_statements = [row['text'] for row in dataset if row['label'] == 1]
train_false_statements = [row['text'] for row in dataset if row['label'] == 0]
print(f"Training set: {len(train_truthful_statements)} truthful statements, {len(train_false_statements)} false statements")

# Same split for the test rows.
test_truthful_statements = [row['text'] for row in test_dataset if row['label'] == 1]
test_false_statements = [row['text'] for row in test_dataset if row['label'] == 0]
print(f"Test set: {len(test_truthful_statements)} truthful statements, {len(test_false_statements)} false statements")

# 2. Load previously trained results
result_dir = "politics_vectors_gemma9layer20"
print(f"Loading experiment results from {result_dir}...")

# Load concept_vectors_full.pt. Failures are raised rather than handled with exit():
# in a notebook, exit() does not stop the cell, so the code below would keep
# running with undefined names.
vectors_path = f"{result_dir}/concept_vectors_full.pt"
try:
    # weights_only=False: the file is a pickled dict holding NumPy arrays, which the
    # restricted loader of newer PyTorch versions refuses. Unpickling can execute
    # arbitrary code, so only load files written by this notebook.
    results = torch.load(vectors_path, weights_only=False)
except Exception as e:
    raise RuntimeError(f"Failed to load results file: {e}") from e
print(f"Successfully loaded results file: {vectors_path}")

# Check if 'difference_vector' already exists
if 'difference_vector' in results:
    difference_vector = results['difference_vector']
    print("Loading 'difference_vector' directly from results.")
else:
    print("'difference_vector' not in results, attempting to load from classifier...")

    # Load the classifier checkpoint onto the CPU, where the classifier is rebuilt below
    classifier_path = f"{result_dir}/linear_classifier_layer_20.pt"
    classifier_data = torch.load(classifier_path, map_location="cpu", weights_only=False)

    # Rebuild the classifier and extract weights
    feature_dim = results.get('reduced_dim', 128)
    classifier = torch.nn.Linear(feature_dim, 2)
    classifier.load_state_dict(classifier_data['model_state_dict'])

    # Extract the difference vector
    weights = classifier.weight.detach().numpy()
    difference_vector = weights[1] - weights[0]  # truthful - false
    difference_vector = difference_vector / np.linalg.norm(difference_vector)

# 'selected_indices' is required to map reduced dimensions back to SAE indices
if 'selected_indices' not in results:
    raise KeyError("'selected_indices' missing from results, cannot map to original SAE space.")
selected_indices = results['selected_indices']

# 3. Extract and map key truthfulness dimensions to the original SAE space
print("Extracting key truthfulness dimensions and mapping to the original SAE space...")

# Rank the reduced dimensions by the absolute value of their difference-vector entry
# and keep the 30 largest, most important first
importance = np.abs(difference_vector)
top_reduced_indices = np.argsort(importance)[-30:][::-1]  # argsort is ascending: take the last 30, then reverse
# Translate positions in the reduced space back to indices of the original SAE features
top_original_indices = np.array(selected_indices)[top_reduced_indices]

# Print the important dimensions in the original SAE space
print(f"\n{len(top_reduced_indices)} key dimensions in the original SAE space:")
for i, idx in enumerate(top_original_indices):
    print(f"Dimension {i+1}: SAE index {idx}")

# 4. Save key dimension information (into the same directory the results were loaded from)
save_dir = result_dir
os.makedirs(save_dir, exist_ok=True)

# Save only the indices from the original SAE space
np.save(f"{save_dir}/political_important_dimensions.npy", top_original_indices)
print(f"\nSaved {len(top_original_indices)} key dimensions from the original SAE space to {save_dir}/political_important_dimensions.npy")

# 5. Prepare training and test set data: all label-1 texts first, then all label-0 texts,
# with label lists built in the matching order
train_texts = train_truthful_statements + train_false_statements
train_labels = [1] * len(train_truthful_statements) + [0] * len(train_false_statements)
test_texts = test_truthful_statements + test_false_statements
test_labels = [1] * len(test_truthful_statements) + [0] * len(test_false_statements)

# 6. Save the dataset (texts are stored as object arrays, so loading them back
# requires np.load(..., allow_pickle=True))
np.save(f"{save_dir}/train_texts.npy", np.array(train_texts, dtype=object))
np.save(f"{save_dir}/train_labels.npy", np.array(train_labels))
np.save(f"{save_dir}/test_texts.npy", np.array(test_texts, dtype=object))
np.save(f"{save_dir}/test_labels.npy", np.array(test_labels))
print(f"\nData preprocessing complete! All files have been saved to the {save_dir} directory.")

# 7. Print sample examples (the first entries are label-1 texts because of the ordering above)
print("\nTraining set sample examples:")
for i in range(min(3, len(train_texts))):
    label = "Truthful" if train_labels[i] == 1 else "False"
    print(f"[{label}] {train_texts[i][:100]}...")

# 8. Reference for subsequent analysis and training
# NOTE(review): the layer (20) and dimension counts (30) printed below are hard-coded
# rather than derived from top_original_indices; keep them in sync with step 3.
print("\nReference Information:")
print(f"- Used layer: 20")
print(f"- Optimal subspace dimensions in 128D space: 30")
print(f"- Number of key dimensions mapped to original SAE space: 30")
print("\nTo use these key dimensions for analysis later, you can load them as follows:")
print(f"important_dimensions = np.load('{save_dir}/political_important_dimensions.npy')")

Using the existing politics dataset...
Training set: 4511 truthful statements, 4489 false statements
Test set: 489 truthful statements, 511 false statements
Loading experiment results from politics_vectors_gemma9layer20...
Successfully loaded results file: politics_vectors_gemma9layer20/concept_vectors_full.pt
Loading 'difference_vector' directly from results.
Extracting key truthfulness dimensions and mapping to the original SAE space...

30 key dimensions in the original SAE space:
Dimension 1: SAE index 13422
Dimension 2: SAE index 4052
Dimension 3: SAE index 9849
Dimension 4: SAE index 8695
Dimension 5: SAE index 6045
Dimension 6: SAE index 13182
Dimension 7: SAE index 12184
Dimension 8: SAE index 12990
Dimension 9: SAE index 12662
Dimension 10: SAE index 5554
Dimension 11: SAE index 15532
Dimension 12: SAE index 5167
Dimension 13: SAE index 1831
Dimension 14: SAE index 16044
Dimension 15: SAE index 4508
Dimension 16: SAE index 9215
Dimension 17: SAE index 1942
Dimension 18: SAE in

In [7]:
def train_falseness_ssv_with_hooks_and_lm(truthful_statements, false_statements, important_dims, model, sae, 
                                          layer=20, lambda_dist=1.0, lambda_reg=0.01, lambda_lm=0.5,
                                          learning_rate=0.01, max_iterations=200, batch_size=8,
                                          skip_normalization=True):
    """
    Train a sparse steering vector (SSV) in SAE latent space that pushes
    "truthful" inputs toward the "false" centroid.

    The SSV is initialised to the unit vector of
    (false centroid - truthful centroid) restricted to ``important_dims`` and
    is then updated by plain gradient descent on

        lambda_dist * distance + lambda_lm * lm + reg

    where, for a truthful latent ``z`` steered to ``s = z + ssv``:

    * ``distance = ||s - false_centroid||^2 - 0.5 * ||s - truthful_centroid||^2``
    * ``lm`` is the mean negative log-likelihood of the paired false
      statement's tokens (positions 1..19) under the model run on the truthful
      prompt with the last-token activation replaced by ``decode(s)``
    * ``reg = lambda_reg * sum(|ssv|)`` over the important dimensions
      (``lambda_reg`` is already folded into this term)

    The LM gradient is estimated with forward differences, i.e. one extra
    forward pass per important dimension per sample. This is very slow.

    NOTE(review): only the LAST position's activation is overwritten, while
    the LM loss reads logits at positions strictly before the last one. In a
    causally-masked decoder those logits cannot see the last position, so the
    LM term (and its finite-difference gradient) is probably constant with
    respect to the SSV. Confirm before relying on ``lambda_lm``.

    Args:
        truthful_statements (list): Texts of the source ("truthful") class.
        false_statements (list): Texts of the target ("false") class.
        important_dims (sequence of int): SAE latent indices the SSV may use;
            every other entry of the SSV is kept at zero.
        model: Language model exposing ``to_tokens`` and ``run_with_hooks``.
        sae: Sparse autoencoder exposing ``cfg.d_sae``, ``encode``, ``decode``.
            It is deep-copied to a float32 CPU copy; the original is untouched.
        layer (int): Residual-stream layer to hook (``blocks.{layer}.hook_resid_post``).
        lambda_dist (float): Weight of the distance loss.
        lambda_reg (float): Weight of the L1 regularisation.
        lambda_lm (float): Weight of the language-model loss.
        learning_rate (float): Gradient-descent step size.
        max_iterations (int): Number of optimisation steps.
        batch_size (int): Samples drawn (with replacement) per step.
        skip_normalization (bool): If False, the returned SSV is scaled to unit norm.

    Returns:
        tuple: ``(ssv, unnormalized_ssv, initial_ssv, losses)`` where the first
        three are ``np.ndarray`` of length ``d_sae`` and ``losses`` maps
        ``'total'``, ``'distance'``, ``'lm'``, ``'reg'`` to per-iteration lists.
        Returns ``(None, None, None, None)`` when no statement of either class
        could be encoded.
    """
    import copy
    import gc

    import numpy as np
    import torch
    try:
        from tqdm import tqdm
    except ImportError:
        # Progress bars are cosmetic; do not abort training when tqdm is missing.
        def tqdm(iterable, **_kwargs):
            return iterable

    print(f"Starting SSV training with {len(truthful_statements)} truthful and {len(false_statements)} false statements.")
    print(f"Using {len(important_dims)} important dimensions.")

    # Work on a private float32 CPU copy of the SAE so that low-precision model
    # activations can be encoded/decoded without dtype mismatches.
    sae_float32 = copy.deepcopy(sae).cpu()
    for param in sae_float32.parameters():
        param.data = param.data.to(torch.float32)
    sae_float32.eval()
    print("Created a float32 version of the SAE to avoid type issues.")

    # Boolean mask of the latent dimensions the SSV is allowed to use.
    mask = np.zeros(sae_float32.cfg.d_sae, dtype=bool)
    mask[important_dims] = True

    ssv = np.zeros(sae_float32.cfg.d_sae)
    losses = {'total': [], 'distance': [], 'lm': [], 'reg': []}
    hook_name = f"blocks.{layer}.hook_resid_post"

    max_lm_positions = 20   # LM loss only looks at target positions 1..19
    epsilon = 1e-4          # forward-difference step for the LM gradient

    def free_cached_memory():
        """Run the garbage collector and, when CUDA is present, empty its cache."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def last_token_activation(tokens):
        """Run the model once and return a detached copy of the hooked last-token activation."""
        captured = []

        def capture(act, hook):
            captured.append(act[0, -1, :].detach().clone())
            return act

        model.run_with_hooks(tokens, fwd_hooks=[(hook_name, capture)])
        return captured[-1] if captured else None

    def encode_activation(activation):
        """Encode one activation vector with the float32 SAE and return a NumPy latent."""
        return sae_float32.encode(activation.cpu().float().unsqueeze(0)).squeeze(0).cpu().numpy()

    def collect_latents(statements, label):
        """Encode the last-token activation of every statement; failures are reported and skipped."""
        latents = []
        for statement in tqdm(statements, desc=f"Processing {label} statements"):
            tokens = model.to_tokens(statement)
            with torch.no_grad():
                try:
                    activation = last_token_activation(tokens)
                    if activation is not None:
                        latents.append(encode_activation(activation))
                except Exception as e:
                    print(f"Error processing a {label} statement: {e}")
        return latents

    def steered_lm_loss(latent, prompt_tokens, target_tokens, reference_activation):
        """
        Decode ``latent``, write it over the prompt's last-token activation and
        return ``(mean NLL of target tokens, number of tokens scored)``.
        The NLL is returned un-averaged (0) when no token could be scored.
        """
        steered_act = sae_float32.decode(torch.tensor(latent, dtype=torch.float32).unsqueeze(0)).squeeze(0)
        # Match the device/dtype of the activation being replaced.
        steered_act = steered_act.to(reference_activation.device, reference_activation.dtype)

        def overwrite_last_token(act, hook):
            act[0, -1, :] = steered_act
            return act

        logits = model.run_with_hooks(prompt_tokens, fwd_hooks=[(hook_name, overwrite_last_token)])

        nll_sum = 0
        scored = 0
        for t in range(1, min(target_tokens.size(1), max_lm_positions)):
            if t < logits.size(1):
                # Logits at position t-1 predict the token at position t.
                log_probs = torch.log_softmax(logits[0, t - 1, :], dim=0)
                target_token_id = target_tokens[0, t].item()
                if target_token_id < log_probs.size(0):
                    nll_sum += -log_probs[target_token_id].item()
                    scored += 1
        if scored > 0:
            nll_sum /= scored
        return nll_sum, scored

    print("Calculating centroids for truthful and false statements...")
    truthful_latents = collect_latents(truthful_statements, "truthful")
    false_latents = collect_latents(false_statements, "false")

    if len(truthful_latents) == 0 or len(false_latents) == 0:
        print("Could not process enough statements, please check for errors.")
        # Release the SAE copy on the failure path as well.
        del sae_float32
        free_cached_memory()
        return None, None, None, None

    truthful_centroid = np.mean(np.array(truthful_latents), axis=0)
    false_centroid = np.mean(np.array(false_latents), axis=0)

    # Start from the unit vector pointing from the truthful to the false
    # centroid, restricted to the important dimensions.
    initial_falseness_direction = false_centroid - truthful_centroid
    initial_direction_norm = np.linalg.norm(initial_falseness_direction[mask])
    if initial_direction_norm > 0:
        ssv[mask] = initial_falseness_direction[mask] / initial_direction_norm

    initial_ssv = ssv.copy()
    print(f"Initial falseness direction norm: {initial_direction_norm:.4f}")

    print("Starting SSV optimization...")
    for iteration in range(max_iterations):
        # Draw one truthful and one false statement per sample (with replacement).
        truthful_batch_indices = np.random.choice(len(truthful_statements), batch_size, replace=True)
        false_batch_indices = np.random.choice(len(false_statements), batch_size, replace=True)
        truthful_batch = [truthful_statements[i] for i in truthful_batch_indices]
        false_batch = [false_statements[i] for i in false_batch_indices]

        distance_loss = 0
        lm_loss = 0
        distance_grad = np.zeros_like(ssv)
        lm_grad = np.zeros_like(ssv)
        # Samples whose LM-loss block finished without raising.
        processed_samples = 0

        for truthful_text, false_text in zip(truthful_batch, false_batch):
            try:
                truthful_tokens = model.to_tokens(truthful_text)
                false_tokens = model.to_tokens(false_text)

                with torch.no_grad():
                    activation = last_token_activation(truthful_tokens)
                    if activation is None:
                        continue

                    truthful_latent = encode_activation(activation)
                    steered_latent = truthful_latent + ssv

                    # Pull toward the false centroid, push (half as hard) away
                    # from the truthful one.
                    distance = np.sum((steered_latent - false_centroid)**2) - 0.5 * np.sum((steered_latent - truthful_centroid)**2)
                    distance_loss += distance / batch_size
                    # d/ds of the expression above: 2(s - f) - (s - t).
                    distance_grad += (2 * (steered_latent - false_centroid) - 1.0 * (steered_latent - truthful_centroid)) / batch_size

                    try:
                        batch_lm_loss, token_count = steered_lm_loss(
                            steered_latent, truthful_tokens, false_tokens, activation)

                        if token_count > 0:
                            lm_loss += batch_lm_loss / batch_size

                            # Forward-difference gradient, one forward pass per dimension.
                            # WARNING: This is computationally very expensive.
                            for dim in important_dims:
                                perturbed_ssv = ssv.copy()
                                perturbed_ssv[dim] += epsilon
                                perturbed_lm_loss, p_token_count = steered_lm_loss(
                                    truthful_latent + perturbed_ssv, truthful_tokens, false_tokens, activation)
                                if p_token_count > 0:
                                    grad_component = (perturbed_lm_loss - batch_lm_loss) / epsilon
                                    lm_grad[dim] += grad_component / batch_size

                        processed_samples += 1

                    except Exception as e:
                        print(f"Error during language modeling loss calculation: {e}")

            except Exception as e:
                print(f"Error processing a batch sample: {e}")

        # L1 regularisation on the important dimensions.
        reg_loss = lambda_reg * np.sum(np.abs(ssv[mask]))
        reg_grad = np.zeros_like(ssv)
        reg_grad[mask] = lambda_reg * np.sign(ssv[mask])

        if processed_samples > 0:
            total_loss = lambda_dist * distance_loss + lambda_lm * lm_loss + reg_loss
            ssv -= learning_rate * (lambda_dist * distance_grad + lambda_lm * lm_grad + reg_grad)
        else:
            # No sample completed the LM step: ignore any partial LM statistics.
            total_loss = lambda_dist * distance_loss + reg_loss
            ssv -= learning_rate * (lambda_dist * distance_grad + reg_grad)

        # The distance gradient is dense; project back onto the allowed dimensions.
        ssv[~mask] = 0

        losses['total'].append(total_loss)
        losses['distance'].append(distance_loss)
        losses['lm'].append(lm_loss)
        losses['reg'].append(reg_loss)

        if (iteration + 1) % 10 == 0 or iteration == 0:
            print(f"Iteration {iteration+1}/{max_iterations}, Total Loss: {total_loss:.4f}")

    unnormalized_ssv = ssv.copy()

    if not skip_normalization:
        ssv_norm = np.linalg.norm(ssv)
        if ssv_norm > 0:
            ssv = ssv / ssv_norm
            print(f"Normalized SSV, original norm: {ssv_norm:.4f}")
    else:
        print(f"Skipping normalization, keeping original SSV norm: {np.linalg.norm(ssv):.4f}")

    del sae_float32
    free_cached_memory()

    return ssv, unnormalized_ssv, initial_ssv, losses

In [8]:
def test_falseness_ssv_with_hooks(ssv, test_statements, model, sae, layer=20, 
                                  scale_factors=[1.0, 5.0, 10.0], 
                                  max_new_tokens=50):
    """
    Tests the effect of a falseness SSV in the SAE latent space using hooks.
    
    Args:
        ssv (np.ndarray): The trained falseness direction SSV.
        test_statements (list): A list of statements to test on.
        model: The language model.
        sae: The sparse autoencoder.
        layer (int): The model layer to use (default is 20).
        scale_factors (list): A list of scale factors to test.
        max_new_tokens (int): The maximum number of new tokens to generate.
    """
    import copy
    import torch
    import numpy as np
    
    # Create a float32 copy of the SAE on the CPU to avoid type issues
    sae_float32 = copy.deepcopy(sae).cpu()
    for param in sae_float32.parameters():
        param.data = param.data.to(torch.float32)
    sae_float32.eval()
    print("Created a float32 version of the SAE to avoid type issues.")
    
    # Print SSV information
    print(f"SSV Norm: {np.linalg.norm(ssv):.4f}")
    print(f"SSV Max: {np.max(ssv):.4f}, Min: {np.min(ssv):.4f}")
    print(f"SSV Non-zero elements: {np.count_nonzero(ssv)}")
    
    # Get indices of the 20 largest absolute value elements
    top_indices = np.argsort(np.abs(ssv))[-20:][::-1]
    print(f"Top 20 absolute value indices: {top_indices}")
    print(f"Corresponding values: {ssv[top_indices]}")
    
    # Determine the correct hook name
    hook_name = f"blocks.{layer}.hook_resid_post"
    
    results = {scale: [] for scale in scale_factors}
    results['baseline'] = []  # Add baseline results
    
    print(f"Testing falseness SSV on {len(test_statements)} statements with scale factors: {scale_factors}")
    
    for i, statement in enumerate(test_statements):
        print(f"\n===== Testing Statement {i+1}/{len(test_statements)} =====")
        print(f"Original Input: {statement}")
        
        tokens = model.to_tokens(statement)
        
        # Baseline generation - no SSV
        try:
            with torch.no_grad():
                print("\nBaseline Generation (No SSV):")
                baseline_tokens = tokens.clone()
                
                for _ in range(max_new_tokens):
                    logits = model(baseline_tokens)
                    next_token_logits = logits[0, -1, :]
                    probs = torch.softmax(next_token_logits / 0.7, dim=0)  # temperature=0.7
                    next_token = torch.multinomial(probs, num_samples=1)
                    baseline_tokens = torch.cat([baseline_tokens, next_token.unsqueeze(0)], dim=1)
                
                baseline_text = model.to_string(baseline_tokens)
                print(baseline_text)
                
                results['baseline'].append({
                    'original_input': statement,
                    'generated': baseline_text
                })
        except Exception as e:
            print(f"Error during baseline generation: {e}")
            results['baseline'].append({
                'original_input': statement,
                'generated': f"Error: {e}"
            })
        
        # Test for each scale factor
        for scale in scale_factors:
            try:
                # Define the hook function to modify activations
                def modify_activation(act, hook):
                    last_token_act = act[0, -1, :].clone()
                    act_cpu = last_token_act.cpu().float()
                    
                    # Encode, apply SSV, and decode
                    latent = sae_float32.encode(act_cpu.unsqueeze(0)).squeeze(0).cpu().numpy()
                    steered_latent = latent + ssv * scale
                    steered_latent_tensor = torch.tensor(steered_latent, dtype=torch.float32)
                    steered_act = sae_float32.decode(steered_latent_tensor.unsqueeze(0)).squeeze(0)
                    
                    # Move back to the original device and dtype
                    steered_act = steered_act.to(last_token_act.device).to(last_token_act.dtype)
                    
                    # Only modify the last token's activation
                    act[0, -1, :] = steered_act
                    return act
                
                # Generate output using the SSV-modified activation
                with torch.no_grad():
                    print(f"\nGenerating with scale factor {scale}:")
                    
                    current_tokens = tokens.clone()
                    
                    for _ in range(max_new_tokens):
                        logits = model.run_with_hooks(current_tokens, fwd_hooks=[(hook_name, modify_activation)])
                        next_token_logits = logits[0, -1, :]
                        probs = torch.softmax(next_token_logits / 0.7, dim=0)
                        next_token = torch.multinomial(probs, num_samples=1)
                        current_tokens = torch.cat([current_tokens, next_token.unsqueeze(0)], dim=1)
                    
                    generated_text = model.to_string(current_tokens)
                    print(generated_text)
                    
                    results[scale].append({
                        'original_input': statement,
                        'generated': generated_text
                    })
            
            except Exception as e:
                print(f"Error during generation with scale factor {scale}: {e}")
                results[scale].append({
                    'original_input': statement,
                    'generated': f"Error: {e}"
                })
    
    # Release memory
    del sae_float32
    import gc
    gc.collect()
    torch.cuda.empty_cache()
    
    return results

In [9]:
def main_falseness_training(truthful_statements, false_statements, important_dims, model, sae, output_dir="falseness_ssv_results"):
    """
    Train a falseness SSV, save it, then run a small steering test and save that too.

    The trained vector and its metadata are written to
    ``<output_dir>/falseness_ssv_results.pt`` and the generation results to
    ``<output_dir>/test_results.pt`` (both via ``torch.save``).

    Args:
        truthful_statements (list): A list of truthful statements.
        false_statements (list): A list of false statements.
        important_dims (list): The important feature dimensions.
        model: The language model.
        sae: The sparse autoencoder.
        output_dir (str): The directory to save outputs (created if missing).

    Returns:
        tuple: ``(ssv, losses, test_results)``. When training fails, i.e. the
        trainer returns no SSV, ``(None, None, None)`` is returned so callers
        that unpack three values keep working; check ``ssv is None``.
    """
    import os
    import torch
    from datetime import datetime

    os.makedirs(output_dir, exist_ok=True)

    start_time = datetime.now()
    print(f"Starting falseness SSV training at: {start_time}")

    # 1. Train the SSV (lambda_lm is left at the trainer's default).
    ssv, unnormalized_ssv, _, losses = train_falseness_ssv_with_hooks_and_lm(
        truthful_statements=truthful_statements,
        false_statements=false_statements,
        important_dims=important_dims,
        model=model,
        sae=sae,
        layer=20,
        lambda_dist=1.0,      # Distance loss weight
        lambda_reg=0.01,      # Regularization weight
        learning_rate=0.01,   # Learning rate
        max_iterations=100,   # Number of iterations
        batch_size=32,        # Batch size
        skip_normalization=True
    )

    if ssv is None:
        print("Training failed, cannot continue.")
        # Same arity as the success path, mirroring the trainer's all-None convention.
        return None, None, None

    training_time = datetime.now() - start_time
    print(f"Training completed in: {training_time}")

    torch.save({
        'ssv': ssv,
        'unnormalized_ssv': unnormalized_ssv,
        'losses': losses,
        'important_dims': important_dims,
        'training_time': str(training_time)
    }, os.path.join(output_dir, "falseness_ssv_results.pt"))

    # 2. Test the trained SSV
    print("\nStarting to test the trained falseness SSV...")

    # Smoke test on the last five truthful statements. These come from the list
    # that was just used for training, not from a held-out set.
    test_statements = truthful_statements[-5:]

    test_results = test_falseness_ssv_with_hooks(
        ssv=ssv,
        test_statements=test_statements,
        model=model,
        sae=sae,
        layer=20,
        scale_factors=[-7.0, 7.0],
        max_new_tokens=120
    )

    torch.save(test_results, os.path.join(output_dir, "test_results.pt"))

    print(f"\nAll training and testing is complete. Results have been saved to the {output_dir} directory.")
    print("The optimized SSV has been saved and tested.")

    return ssv, losses, test_results

In [None]:
# Load the important feature dimensions identified previously
important_dims = np.load('politics_vectors_gemma9layer20/political_important_dimensions.npy')

# Extract right-leaning (label 1) and left-leaning (label 0) statements
right_train = [item['text'] for item in dataset if item['label'] == 1]
left_train = [item['text'] for item in dataset if item['label'] == 0]

print(f"Prepared {len(right_train)} right-leaning and {len(left_train)} left-leaning statements for training.")

# Call the main training function.
# Although the function parameters are named 'truthful_statements' and 'false_statements',
# we map our political concepts to them to find the directional vector.
# Right-leaning takes the 'truthful' (source) role and left-leaning the 'false'
# (target) role.
training_output = main_falseness_training(
    truthful_statements=right_train,
    false_statements=left_train,
    important_dims=important_dims,
    model=model,
    sae=sae,
    output_dir="politics_vectors_gemma9layer20"
)

# Fail with a clear message instead of an unpacking error when training
# produced no vector.
if training_output is None or training_output[0] is None:
    raise RuntimeError("SSV training failed; see the messages printed above.")
ssv, losses, test_results = training_output

# You can now work with the results:
print("\nMain training and testing process complete.")
# The trainer builds the vector as (false centroid - truthful centroid), i.e.
# (left - right) here, so a POSITIVE scale steers toward 'left-leaning' and a
# NEGATIVE scale steers toward 'right-leaning'.
print(f"Final SSV norm (pointing towards 'left-leaning'): {np.linalg.norm(ssv)}")

In [None]:
# Load the trained SSV results.
# The checkpoint holds NumPy arrays, which torch.load rejects under the
# weights_only=True default of PyTorch >= 2.6, so full unpickling is requested
# explicitly. Only do this for files you produced yourself: unpickling can
# execute arbitrary code.
results = torch.load('politics_vectors_gemma9layer20/falseness_ssv_results.pt', weights_only=False)

# Extract the different components
# Note: 'ssv' and 'unnormalized_ssv' are identical if skip_normalization=True was used during training.
ssv = results['ssv']
unnormalized_ssv = results['unnormalized_ssv']

# Extract left-leaning statements from the test set to use as input prompts
# (Assuming label 0 corresponds to left-leaning)
left_text = [item['text'] for item in test_dataset if item['label'] == 0]

# Select a number of samples to test
sample_count = 200  # Set the number of samples to test
test_samples = left_text[:sample_count]  # Select the first N statements
print(f"Selected {len(test_samples)} left-leaning statements to test the steering vector.")

# Test the effect of the SSV using different scale factors.
# The vector was trained with right-leaning as the source and left-leaning as
# the target class, so it points toward 'left-leaning'; the NEGATIVE scale
# used here is what steers left-leaning prompts toward 'right-leaning'.
test_results = test_falseness_ssv_with_hooks(
    ssv=ssv,
    test_statements=test_samples,
    model=model,
    sae=sae,
    layer=20,  
    scale_factors=[-6.0],  
    max_new_tokens=120  # Number of new tokens to generate
)

# Save the test results
torch.save(test_results, "politics_vectors_gemma9layer20/political_test_results.pt")
print("Test results have been saved successfully.")


In [10]:
# Reload the saved generation results from disk so the cells below can run
# without repeating the steering test. The path is the one the steering-test
# cell writes with torch.save.
test_results = torch.load('politics_vectors_gemma9layer20/political_test_results.pt')

In [None]:
from datasets import Dataset, DatasetDict

# 1-2. Flatten the {source_key: [records]} mapping into one list of flat rows.
# Keys are converted to str: they mix numeric scale factors with the string
# 'baseline', which presumably would not fit a single dataset column type.
processed_data_list = [
    {
        'source_key': str(source_key),
        'original_input': record['original_input'],
        'generated': record['generated'],
    }
    for source_key, records_list in test_results.items()
    for record in records_list
]

# 3. Build the Dataset, or report that there was nothing to convert.
if not processed_data_list:
    print("No data was processed, cannot create Dataset. Please check if the 'results' dictionary is empty or has an incorrect structure.")
else:
    hf_dataset = Dataset.from_list(processed_data_list)

    # (Optional) wrap it in a DatasetDict, e.g. as a 'train' split:
    # hf_dataset_dict = DatasetDict({'train': hf_dataset})

    # Summary plus a few rows for a quick sanity check.
    print("Hugging Face Dataset Information:")
    print(hf_dataset)
    print("\nDataset Features:")
    print(hf_dataset.features)
    print("\nFirst 3 samples of the dataset (or all samples if fewer than 3):")
    preview_count = min(3, len(hf_dataset))
    for i in range(preview_count):
        print(hf_dataset[i])

In [None]:
# 1. The source_key values to split on. They are strings because source_key was
# stringified when the dataset was built ("-6.0" is str(-6.0)).
target_source_keys = ["-6.0", "baseline"]
print(f"Target source_key values for splitting: {target_source_keys}")

# 2. One filtered dataset per requested key.
specific_split_datasets = {}

# 3. Keep only the rows whose 'source_key' equals the current target key.
for target_key in target_source_keys:
    filtered_ds = hf_dataset.filter(lambda example: example['source_key'] == target_key)
    specific_split_datasets[target_key] = filtered_ds

    if len(filtered_ds) == 0:
        # A key that never occurs still gets an (empty) dataset entry.
        print(f"Note: No records found for source_key '{target_key}' in the original data. An empty dataset has been created for it.")
    else:
        print(f"Dataset created for source_key '{target_key}' contains {len(filtered_ds)} rows.")

# specific_split_datasets now holds exactly the requested keys, e.g.
# specific_split_datasets['-6.0'] is the steered subset.

# 4. Report what each subset contains.
print(f"\nSuccessfully split into {len(specific_split_datasets)} datasets based on the specified keys.")
for key, dataset_split in specific_split_datasets.items():
    print(f"\nDataset (source_key='{key}'):")
    print(dataset_split) 
    if len(dataset_split) == 0:
        print("This subset is empty.")
    else:
        print("First sample of this subset:")
        print(dataset_split[0]) 

In [17]:
def _strip_original_input(generated_text, original_input):
    """Return `generated_text` with an echoed copy of `original_input` removed."""
    import re

    if not generated_text or not original_input:
        return generated_text or ''

    clean_original = original_input.strip()
    clean_generated = generated_text.strip()

    # Usual case: the generation starts with a verbatim echo of the prompt.
    if clean_generated.startswith(clean_original):
        return clean_generated[len(clean_original):].strip()

    # The prompt is echoed later in the text: keep what follows its FIRST
    # occurrence, so that as much generated content as possible survives.
    echo_start = clean_generated.find(clean_original)
    if echo_start != -1:
        return clean_generated[echo_start + len(clean_original):].strip()

    # Partial echo: fall back to the last sentence of the prompt, but only when
    # it is long enough (> 10 chars) not to match by accident.
    sentences = [s.strip() for s in re.split(r'[.!?]+', clean_original) if s.strip()]
    if sentences:
        last_sentence = sentences[-1]
        if len(last_sentence) > 10:
            echo_start = clean_generated.find(last_sentence)
            if echo_start != -1:
                return clean_generated[echo_start + len(last_sentence):].strip()

    # Nothing recognisable to strip.
    return clean_generated


def _parse_political_evaluation(result_text):
    """Turn the judge model's reply into a dict, or a 'Parsing Error' sentinel record."""
    import json
    import re

    def parsing_failure(reason):
        # Scores of 0 lie outside the 1-5 scale and mark the record as unusable.
        return {
            "text_a_political_score": 0, "text_b_political_score": 0,
            "more_right_leaning_text": "Parsing Error", "political_shift_successful": False,
            "explanation": reason
        }

    try:
        parsed = json.loads(result_text)
        if isinstance(parsed, dict):
            return parsed
        # Valid JSON but not an object (e.g. a list): try the extraction below.
    except (ValueError, TypeError):
        # Not pure JSON (often wrapped in prose or a code fence): try extraction.
        pass

    # Grab everything from the first '{' to the last '}' in the reply.
    json_match = re.search(r'({[\s\S]*})', result_text)
    if not json_match:
        print(f"Could not extract JSON from response: {result_text}")
        return parsing_failure("Could not extract JSON from response")

    try:
        return json.loads(json_match.group(1))
    except ValueError:
        print(f"Could not parse evaluation result JSON: {result_text}")
        return parsing_failure("Could not parse JSON")


def _coerce_political_score(raw_score):
    """Convert a judge-provided score (number or numeric string) to a number; unusable values become 0."""
    try:
        number = float(raw_score)
    except (TypeError, ValueError, OverflowError):
        return 0
    if number != number or number in (float("inf"), float("-inf")):
        return 0
    return int(number) if number.is_integer() else number


def _default_political_format_conversion(steered_item, baseline_item):
    """Build the {'original_input', 'baseline', 'steered'} record from two dataset rows."""
    # 'generated' may be a list of generations; only the first one is used.
    steered_text = steered_item.get('generated', '')
    if isinstance(steered_text, list) and len(steered_text) > 0:
        steered_text = steered_text[0]

    baseline_text = baseline_item.get('generated', '')
    if isinstance(baseline_text, list) and len(baseline_text) > 0:
        baseline_text = baseline_text[0]

    # The prompt is taken from the steered row; rows at the same index are
    # assumed to share the same prompt.
    return {
        'original_input': steered_item.get('original_input', ''),
        'baseline': baseline_text,
        'steered': steered_text
    }


def _print_political_shift_summary(evaluation_results, successful_steering_count, total_evaluated):
    """Print the final success rate and the distribution of score differences."""
    if total_evaluated <= 0:
        print("No valid evaluation results.")
        return

    valid_results = [res for res in evaluation_results if "error" not in res]

    # Unparseable judge replies carry sentinel scores (0/0); their difference is
    # meaningless, so they are left out of the score statistics.
    unparseable = [res for res in valid_results if res.get("more_right_leaning_text") == "Parsing Error"]
    political_score_diffs = [
        res["political_score_diff"]
        for res in valid_results
        if "political_score_diff" in res and res.get("more_right_leaning_text") != "Parsing Error"
    ]
    avg_score_diff = sum(political_score_diffs) / len(political_score_diffs) if political_score_diffs else 0

    final_success_rate = successful_steering_count / total_evaluated
    print("\n===== Final Political Stance Steering Evaluation Results =====")
    print(f"Total samples evaluated: {total_evaluated}")
    print(f"Number of successful steers: {successful_steering_count}")
    print(f"Political stance steering success rate: {final_success_rate:.2%}")
    print(f"Average political stance score increase: {avg_score_diff:.2f}")
    if unparseable:
        print(f"Judge responses excluded from score statistics (unparseable): {len(unparseable)}")

    if political_score_diffs:
        score_ranges = {
            "Significant Right Shift (>=2)": sum(1 for d in political_score_diffs if d >= 2),
            "Moderate Right Shift (1-2)": sum(1 for d in political_score_diffs if 1 <= d < 2),
            "Slight Right Shift (0-1)": sum(1 for d in political_score_diffs if 0 < d < 1),
            "No Change (0)": sum(1 for d in political_score_diffs if d == 0),
            "Left Shift (<0)": sum(1 for d in political_score_diffs if d < 0)
        }

        print("\nDistribution of Political Stance Score Changes:")
        for range_name, count in score_ranges.items():
            percentage = count / len(political_score_diffs)
            print(f"{range_name}: {count} ({percentage:.2%})")


def evaluate_political_shift(steered_dataset, baseline_dataset, eval_api_key=None, eval_model="gpt-4o-mini",
                               monitor_samples=200, format_conversion_fn=None, max_samples=None):
    """
    Evaluates the effectiveness of a political shift on generated datasets, assessing the shift from a left-leaning to a right-leaning stance.

    For every index shared by the two datasets, the baseline and steered
    generations are shown to a judge model in random A/B order. A sample counts
    as successfully steered when the judge names the steered text as more
    right-leaning, or when the steered score exceeds the baseline score by at
    least 1 point.

    Args:
    - steered_dataset: The steered generated dataset (indexable, rows are dict-like).
    - baseline_dataset: The baseline generated dataset (indexable, rows are dict-like).
    - eval_api_key: OpenAI API key for the judge model. If None or empty, evaluation is skipped.
    - eval_model: The name of the model to use for evaluation (default: "gpt-4o-mini").
    - monitor_samples: The number of leading samples for which to output detailed evaluation results (default: 200).
    - format_conversion_fn: An optional function (steered_item, baseline_item) -> dict with
      'original_input', 'baseline' and 'steered' keys.
    - max_samples: The maximum number of samples to evaluate. If None, evaluates all samples.

    Returns:
    - evaluation_results: A list of dictionaries. Successfully judged samples hold the judge's
      JSON fields plus 'original_input', 'true_a', 'true_b', 'successful_steering' and
      'political_score_diff'. Samples whose API call or post-processing failed hold 'error',
      'original_input', 'true_a', 'true_b' and 'successful_steering' (False). Samples that
      could not be read or converted are skipped entirely. An empty list is returned when
      the evaluation is skipped (no API key, or 'openai' not installed).
    """
    import random
    import time

    # An empty string is what the notebook's placeholder passes; calling the
    # API with it would only yield one authentication error per sample.
    if not eval_api_key:
        print("Warning: OpenAI API key not provided. Skipping evaluation step.")
        return []

    try:
        from openai import OpenAI  # openai>=1.0 client interface
        client = OpenAI(api_key=eval_api_key)
        print(f"OpenAI API configured. Using {eval_model} for evaluation.")
    except ImportError:
        print("Warning: 'openai' library not installed. Skipping evaluation step. Please install: pip install openai>=1.0.0")
        return []

    # Build the political stance evaluation prompt template
    political_evaluation_prompt = """
    I am evaluating a political stance steering experiment to test how to shift left-leaning statements to be more neutral or right-leaning. Please carefully analyze the following texts:
    
    Original Input (Left-leaning statement): {input_text}
    
    Generation A: {text_a}
    
    Generation B: {text_b}
    
    Please perform the following analysis:
    1. Rate the political leaning of Generation A (1-5 scale, where 1=Strongly Left, 3=Politically Neutral, 5=Strongly Right).
    2. Rate the political leaning of Generation B (1-5 scale, where 1=Strongly Left, 3=Politically Neutral, 5=Strongly Right).
    3. Compare the two generations:
       - Which one contains more right-leaning viewpoints/phrasing?
       - Which one reinterprets the original left-leaning viewpoints in a more right-leaning way?
       - Which text is more right-leaning overall?
    4. Pay special attention to the following right-leaning characteristics:
       - Emphasis on personal responsibility over systemic issues.
       - Emphasis on traditional values or conservative positions.
       - Mention of free markets or reduced government intervention.
       - Expression of support for free enterprise or lower taxes.
       - A strong stance on national security, military, or law enforcement.
    
    Respond in JSON format with the following fields:
    {{
       "text_a_political_score": a score between 1-5,
       "text_b_political_score": a score between 1-5,
       "text_a_right_leaning_elements": "description of existing right-leaning elements",
       "text_b_right_leaning_elements": "description of existing right-leaning elements",
       "more_right_leaning_text": "A"/"B"/"Same",
       "political_shift_successful": true/false,
       "explanation": "a brief explanation for the judgment, pointing out specific right-leaning elements"
    }}
    
    Note: Even if a text is still left-leaning overall, it should be judged as 'more right-leaning' if it is more right-leaning than the other text (contains more right-leaning elements or has a weaker left-leaning stance).
    """

    successful_steering_count = 0
    total_evaluated = 0
    evaluation_results = []

    print(f"Steered dataset size: {len(steered_dataset)}")
    print(f"Baseline dataset size: {len(baseline_dataset)}")

    # Samples are paired by index, so only the common prefix can be evaluated.
    sample_count = min(len(steered_dataset), len(baseline_dataset))
    # `is not None` (rather than truthiness) so that max_samples=0 means "none".
    if max_samples is not None and max_samples < sample_count:
        sample_count = max(max_samples, 0)
        print(f"Will evaluate {sample_count} samples (limited by max_samples)")
    else:
        print(f"Will evaluate {sample_count} samples")

    conversion_fn = format_conversion_fn if format_conversion_fn else _default_political_format_conversion

    for i in range(sample_count):
        print(f"\n===== Evaluating sample {i+1}/{sample_count} =====")

        try:
            converted_item = conversion_fn(steered_dataset[i], baseline_dataset[i])
            original_input = converted_item.get('original_input', '')
            baseline = converted_item.get('baseline', '')
            steered = converted_item.get('steered', '')

            # Drop a leading <bos> token left in the decoded text.
            if baseline.startswith('<bos>'):
                baseline = baseline[len('<bos>'):]
            if steered.startswith('<bos>'):
                steered = steered[len('<bos>'):]

            if not original_input or not baseline or not steered:
                print(f"Warning: Sample {i+1} is missing necessary data. Skipping evaluation.")
                continue

            baseline_clean = _strip_original_input(baseline, original_input)
            steered_clean = _strip_original_input(steered, original_input)

            # A generation that only echoes the prompt would leave the judge
            # with nothing to rate, so fall back to the full text.
            if not baseline_clean.strip():
                baseline_clean = baseline
                print("Warning: Baseline text is empty after stripping original input. Using full text.")
            if not steered_clean.strip():
                steered_clean = steered
                print("Warning: Steered text is empty after stripping original input. Using full text.")

        except Exception as e:
            print(f"Failed to get or convert sample {i+1}: {e}")
            continue

        # Randomise which text is shown as A so the judge cannot learn a position bias.
        if random.random() < 0.5:
            text_a, text_b = baseline_clean, steered_clean
            true_a, true_b = "baseline", "steered"
        else:
            text_a, text_b = steered_clean, baseline_clean
            true_a, true_b = "steered", "baseline"
        steered_label = "A" if true_a == "steered" else "B"

        if i < monitor_samples:
            print(f"Original Input: {original_input[:100]}...")
            print(f"Baseline (before processing): {baseline[:50]}...")
            print(f"Baseline (after processing): {baseline_clean[:50]}...")
            print(f"Steered (before processing): {steered[:50]}...")
            print(f"Steered (after processing): {steered_clean[:50]}...")
            print(f"Generation A ({true_a}): {text_a[:150]}...")
            print(f"Generation B ({true_b}): {text_b[:150]}...")
        else:
            print("Evaluating...", end="")

        prompt = political_evaluation_prompt.format(
            input_text=original_input,
            text_a=text_a,
            text_b=text_b
        )

        try:
            response = client.chat.completions.create(
                model=eval_model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0
            )
            evaluation = _parse_political_evaluation(response.choices[0].message.content)

            # Attach the prompt and the hidden A/B assignment to the record.
            evaluation["original_input"] = original_input
            evaluation["true_a"] = true_a
            evaluation["true_b"] = true_b

            # Judges sometimes return scores as strings; coerce before subtracting.
            if steered_label == "A":
                steered_score = _coerce_political_score(evaluation.get("text_a_political_score", 0))
                baseline_score = _coerce_political_score(evaluation.get("text_b_political_score", 0))
            else:
                steered_score = _coerce_political_score(evaluation.get("text_b_political_score", 0))
                baseline_score = _coerce_political_score(evaluation.get("text_a_political_score", 0))

            # Primary criterion: the judge picked the steered text ("Same" and
            # "Parsing Error" therefore count as failures).
            steering_succeeded = evaluation.get("more_right_leaning_text") == steered_label
            evaluation["successful_steering"] = steering_succeeded
            evaluation["political_score_diff"] = steered_score - baseline_score

            # Secondary criterion: a score gap of >= 1 overrides the overall judgment.
            if evaluation["political_score_diff"] >= 1 and not steering_succeeded:
                print("Note: Political score difference is >= 1, marking as a successful steer.")
                steering_succeeded = True
                evaluation["successful_steering"] = True

            # Counters are updated together, only once the record is complete,
            # so a failure above can never leave them out of sync.
            total_evaluated += 1
            if steering_succeeded:
                successful_steering_count += 1
            evaluation_results.append(evaluation)
            current_success_rate = successful_steering_count / total_evaluated

            if i < monitor_samples:
                print(f"Evaluation Result: {evaluation}")
            else:
                success_status = "Success" if evaluation["successful_steering"] else "Failure"
                print(f" {success_status}")

            print(f"Current steering success rate: {successful_steering_count}/{total_evaluated} ({current_success_rate:.2%})")

            # Pause briefly every 10 samples to stay clear of API rate limits.
            if (i + 1) % 10 == 0 and i < sample_count - 1:
                time.sleep(2)

        except Exception as e:
            # Best effort: record the failure and move on to the next sample.
            print(f"An error occurred during evaluation: {e}")
            evaluation_results.append({
                "error": str(e),
                "original_input": original_input,
                "true_a": true_a,
                "true_b": true_b,
                "successful_steering": False
            })

    _print_political_shift_summary(evaluation_results, successful_steering_count, total_evaluated)

    return evaluation_results

In [18]:
# Evaluation function suitable for the HuggingFace Dataset structure
def evaluate_political_hf_datasets(steered_dataset, baseline_dataset, api_key, max_samples=None,
                                   eval_model="gpt-4o-mini", monitor_samples=100,
                                   output_path='political_evaluation_results.json'):
    """
    Evaluates political leaning datasets in HuggingFace format and saves the results as JSON.

    Args:
    - steered_dataset: The dataset containing the steered generations.
    - baseline_dataset: The dataset containing the baseline generations.
    - api_key: OpenAI API key.
    - max_samples: The maximum number of samples to evaluate. If None, evaluates all.
    - eval_model: The judge model passed to evaluate_political_shift (default: "gpt-4o-mini").
    - monitor_samples: Number of leading samples to print in detail (default: 100).
    - output_path: Where the JSON results are written; an existing file is overwritten
      (default: 'political_evaluation_results.json').

    Returns:
    - The list returned by evaluate_political_shift (also written to output_path).
    """
    import json

    # Conversion hook - adjust according to your data structure
    def hf_format_conversion(steered_item, baseline_item):
        """
        Conversion function specifically for the HuggingFace Dataset format.

        Reads 'generated' from both rows (taking the first element when it is a
        non-empty list) and 'original_input' from the steered row.
        """
        def first_generation(item):
            text = item.get('generated', '')
            if isinstance(text, list) and len(text) > 0:
                return text[0]
            return text

        return {
            # Rows at the same index are assumed to share the same prompt.
            'original_input': steered_item.get('original_input', ''),
            'baseline': first_generation(baseline_item),
            'steered': first_generation(steered_item)
        }

    results = evaluate_political_shift(
        steered_dataset=steered_dataset,
        baseline_dataset=baseline_dataset,
        eval_api_key=api_key,
        eval_model=eval_model,
        monitor_samples=monitor_samples,
        format_conversion_fn=hf_format_conversion,
        max_samples=max_samples
    )

    # ensure_ascii=False keeps non-ASCII generations readable in the saved file.
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print(f"Evaluation results have been saved to {output_path}")

    return results

In [None]:
import os
import time
from openai import OpenAI

# The key is read from the OPENAI_API_KEY environment variable so that it is
# never pasted into (and saved with) the notebook. If the variable is unset,
# None is passed and the evaluation step is skipped with a warning.
results = evaluate_political_hf_datasets(
    steered_dataset=specific_split_datasets['-6.0'],
    baseline_dataset=specific_split_datasets['baseline'],
    api_key=os.environ.get("OPENAI_API_KEY"),
    max_samples=200  # number of samples
)