In [None]:
goemo_path_train = "Go_Emotion_Google/go_emotions_train.csv"

In [26]:
import os
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from sklearn.metrics import silhouette_score
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple
import re
from tqdm import tqdm

# Reuse the GoEmotionDataset class
class GoEmotionDataset:
    """
    A class to load, preprocess, and analyze the GoEmotions dataset.
    """
    
    EMOTIONS = [
        'admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 
        'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 
        'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief',
        'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 
        'relief', 'remorse', 'sadness', 'surprise', 'neutral'
    ]
    
    POSITIVE_EMOTIONS = {
        "admiration", "amusement", "approval", "caring", "desire", "excitement",
        "gratitude", "joy", "love", "optimism", "pride", "relief"
    }
    
    AMBIGUOUS_EMOTIONS = {
        "confusion", "curiosity", "surprise", "realization", "neutral"
    }
    
    NEGATIVE_EMOTIONS = {
        "anger", "annoyance", "disappointment", "disapproval", "disgust",
        "embarrassment", "fear", "grief", "nervousness", "remorse", "sadness"
    }
    
    def __init__(self, train_path: str, test_path: str, val_path: str):
        """
        Initialize the dataset by loading and processing the data.
        
        Args:
            train_path: Path to training data CSV
            test_path: Path to test data CSV
            val_path: Path to validation data CSV
        """
        self.df = self._load_data(train_path, test_path, val_path)
        self._preprocess_data()
        
    def _load_data(self, train_path: str, test_path: str, val_path: str) -> pd.DataFrame:
        """Load and combine the dataset components."""
        train_df = pd.read_csv(train_path)
        test_df = pd.read_csv(test_path)
        val_df = pd.read_csv(val_path)
        
        main_df = pd.concat([train_df, test_df, val_df], axis=0)
        main_df = main_df.reset_index(drop=True)
        main_df.drop_duplicates(inplace=True)
        
        return main_df
    
    @staticmethod
    def _preprocess_text(text: str) -> str:
        """Clean and normalize text data."""
        if not isinstance(text, str):
            return ""
        # Preprocessing logic from the original class
        return text.lower()
    
    @staticmethod
    def _string_to_list(label_str: str) -> List[int]:
        """Convert string representation of array to list of integers."""
        return [int(x) for x in label_str.strip('[]').replace(',', '').split()]
    
    def _preprocess_data(self):
        """Apply all preprocessing steps to the dataset."""
        # Clean text
        self.df['clean_text'] = self.df['text'].apply(self._preprocess_text)
        
        # Ensure labels column contains lists
        if 'labels' in self.df.columns and isinstance(self.df['labels'].iloc[0], str):
            self.df['labels'] = self.df['labels'].apply(self._string_to_list)
        
        # Create one-hot encodings for emotions
        for i, emotion in enumerate(self.EMOTIONS):
            self.df[emotion] = self.df['labels'].apply(lambda x: 1 if i in x else 0)
        
        # Get the dominant emotion for each text
        self.df['dominant_emotion'] = self.df.apply(
            lambda row: self.EMOTIONS[np.argmax([row[emotion] for emotion in self.EMOTIONS])], 
            axis=1
        )
        
        self.df = self.df.drop(columns=['text', 'id'] if 'id' in self.df.columns else ['text'])
    
    def get_data(self) -> pd.DataFrame:
        """Return the processed DataFrame."""
        return self.df.copy()

    def get_emotion_labels(self) -> List[str]:
        """Return the list of emotion labels."""
        return self.EMOTIONS



In [27]:
from tqdm import tqdm

class HiddenStatesDataset(Dataset):
    """Dataset for hidden states and labels."""
    def __init__(self, hidden_states, labels):
        self.hidden_states = hidden_states
        self.labels = labels
        
    def __len__(self):
        return len(self.hidden_states)
    
    def __getitem__(self, idx):
        return torch.tensor(self.hidden_states[idx], dtype=torch.float32), torch.tensor(self.labels[idx], dtype=torch.long)

class EnhancedProbe(nn.Module):
    """Enhanced probe with multiple hidden layers and dropout."""
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.3):
        super(EnhancedProbe, self).__init__()
        layers = []
        prev_dim = input_dim
        
        for dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, dim))
            layers.append(nn.GELU())
            #layers.append(nn.Dropout(dropout_rate))
            prev_dim = dim
            
        layers.append(nn.Linear(prev_dim, output_dim))
        self.network = nn.Sequential(*layers)
        
    def forward(self, x):
        return self.network(x)


In [None]:
import os
import json
import ijson
import re
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from tqdm import tqdm
from collections import defaultdict
import time
from concurrent.futures import ThreadPoolExecutor

class HiddenStatesDataset(Dataset):
    """Dataset for hidden states and labels."""
    def __init__(self, hidden_states, labels):
        self.hidden_states = hidden_states
        self.labels = labels
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return torch.tensor(self.hidden_states[idx], dtype=torch.float32), torch.tensor(self.labels[idx], dtype=torch.long)

class EnhancedProbe(nn.Module):
    """Enhanced probe architecture with multiple hidden layers."""
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.3):
        super().__init__()
        layers = []
        prev_dim = input_dim
        
        for dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            prev_dim = dim
            
        layers.append(nn.Linear(prev_dim, output_dim))
        self.net = nn.Sequential(*layers)
        
    def forward(self, x):
        return self.net(x)

class OptimizedProbeAnalyzer:
    """Optimized analyzer for training and evaluating probes on hidden states."""
    
    def __init__(self, hidden_states_dir: str, go_dataset, max_batches: int = None, device=None, chunk_size: int = 500):
        """
        Initialize the ProbeAnalyzer with optimized data loading.
        
        Args:
            hidden_states_dir: Directory containing hidden states JSON files
            go_dataset: GoEmotionDataset instance
            max_batches: Maximum number of batches to load (None for all)
            device: Device to run computations on (defaults to CUDA if available)
            chunk_size: Number of files to load at once
        """
        self.hidden_states_dir = hidden_states_dir
        self.go_dataset = go_dataset
        self.max_batches = max_batches
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        self.label_encoder = LabelEncoder()
        self.chunk_size = chunk_size
        
        # Find all hidden states files
        self.hidden_states_files = sorted([
            f for f in os.listdir(hidden_states_dir) 
            if f.startswith('QwenQwen2-7B_goEmo_') and f.endswith('.json')
        ])
        
        # Extract indices from filenames for ordering
        indices = [int(re.search(r'goEmo_(\d+)\.json', f).group(1)) for f in self.hidden_states_files]
        self.hidden_states_files = [f for _, f in sorted(zip(indices, self.hidden_states_files))]
        
        # Limit the number of batches if specified
        if self.max_batches is not None:
            self.hidden_states_files = self.hidden_states_files[:self.max_batches]
            print(f"Analysis limited to {len(self.hidden_states_files)} batches")
        
        # Initialize metrics storage
        self.metrics = {}
        self.stats = {}
        
        # Identify layers and prepare dataset without loading all data
        self._identify_layers()
        self._prepare_dataset()
    
    def _identify_layers(self):
        """Identify available layers from first file without loading all data."""
        print("Identifying layer structure...")
        with open(os.path.join(self.hidden_states_dir, self.hidden_states_files[0]), 'r') as f:
            first_batch = json.load(f)
            first_example = first_batch[0]
            self.layers = [key for key in first_example.keys() if key.startswith('layer_')]
        
        print(f"Identified {len(self.layers)} layers")
        
        # Calculate total examples by reading batch sizes
        self.total_examples = 0
        self.batch_sizes = []
        
        for file_name in tqdm(self.hidden_states_files, desc="Counting examples"):
            with open(os.path.join(self.hidden_states_dir, file_name), 'r') as f:
                batch_data = json.load(f)
                batch_size = len(batch_data)
                self.total_examples += batch_size
                self.batch_sizes.append(batch_size)
        
        print(f"Total examples: {self.total_examples}")
    
    def _prepare_dataset(self):
        """Prepare dataset metadata and train/test splits."""
        df = self.go_dataset.get_data()
        
        # Calculate which indices are valid (within dataset bounds)
        self.valid_mask = [True] * min(self.total_examples, len(df))
        self.valid_mask += [False] * max(0, self.total_examples - len(df))
        
        # Get labels only for valid indices
        valid_indices = np.where(self.valid_mask)[0]
        self.labels = df.iloc[valid_indices]['dominant_emotion'].values
        self.encoded_labels = self.label_encoder.fit_transform(self.labels)
        
        # Split data into train and test sets (80/20) using sklearn's train_test_split
        indices = np.arange(len(valid_indices))
        self.train_indices, self.test_indices = train_test_split(
            indices, test_size=0.2, random_state=42, stratify=self.encoded_labels
        )
        
        print(f"Prepared metadata for {len(valid_indices)} valid examples")
        print(f"Train set: {len(self.train_indices)}, Test set: {len(self.test_indices)}")

    def _load_batch_file(self, file_path):
        """Optimized batch file loading using ijson"""
        with open(file_path, 'r') as f:
            return list(ijson.items(f, 'item'))
    
    def analyze_all_layers(self, layers_to_analyze, batch_size=32, epochs=50, max_workers=4):
        """Faster analysis with parallel file loading"""
        print(f"Training probes for layers: {layers_to_analyze}")
        layers_to_analyze = sorted(layers_to_analyze)
        
        # Initialize data structures
        all_data = {
            'train': {'hidden_states': defaultdict(list), 'labels': defaultdict(list)},
            'test': {'hidden_states': defaultdict(list), 'labels': defaultdict(list), 'raw_labels': defaultdict(list)}
        }
        
        # Process files in parallel
        def process_file(file_name):
            try:
                file_path = os.path.join(self.hidden_states_dir, file_name)
                batch_data = self._load_batch_file(file_path)
                
                file_results = {
                    'train': defaultdict(list),
                    'test': defaultdict(list)
                }
                
                # Process examples
                for example in batch_data:
                    # [Same processing logic as before]
                    pass
                    
                return file_results
            except Exception as e:
                print(f"Error processing {file_name}: {str(e)}")
                return None
        
        print("Loading data with parallel processing...")
        start_time = time.time()
        
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = []
            for chunk_start in range(0, len(self.hidden_states_files), self.chunk_size):
                chunk_files = self.hidden_states_files[chunk_start:chunk_start + self.chunk_size]
                futures.extend(executor.submit(process_file, fname) for fname in chunk_files)
            
            for future in tqdm(futures, desc="Processing files"):
                result = future.result()
                if result:
                    # Aggregate results
                    pass
        
        print(f"Data loaded in {time.time()-start_time:.2f} seconds")

    def _train_probe_for_layer(self, layer_key, train_hidden_states, train_labels, 
                             test_hidden_states, test_labels, test_raw_labels, 
                             batch_size, epochs):
        """Train probe and return metrics for both train and test sets"""
        # [Existing training code, but modified to return both train and test metrics]
        
        # Calculate train metrics
        train_preds = self._get_predictions(model, train_hidden_states, batch_size)
        train_accuracy = accuracy_score(train_labels, train_preds)
        
        return {
            'train': {
                'accuracy': train_accuracy,
                'confusion_matrix': confusion_matrix(train_labels, train_preds),
                'hidden_states': train_hidden_states,
                'labels': [self.label_encoder.inverse_transform([l])[0] for l in train_labels]
            },
            'test': {
                'accuracy': test_accuracy,
                'confusion_matrix': test_cm,
                'hidden_states': test_hidden_states,
                'labels': test_raw_labels,
                'pred_labels': test_pred_emotions
            },
            'loss_curves': {
                'train': train_losses,
                'val': val_losses
            }
        }

    def _get_predictions(self, model, hidden_states, batch_size):
        """Get predictions for a set of hidden states"""
        dataset = HiddenStatesDataset(hidden_states, np.zeros(len(hidden_states)))  # Dummy labels
        loader = DataLoader(dataset, batch_size=batch_size)
        
        model.eval()
        preds = []
        with torch.no_grad():
            for inputs, _ in loader:
                inputs = inputs.to(self.device)
                outputs = model(inputs)
                _, batch_preds = torch.max(outputs, 1)
                preds.extend(batch_preds.cpu().numpy())
        return preds

    def _analyze_hidden_states(self, layer_idx, hidden_states, labels):
        """Analyze hidden states statistics."""
        hidden_states = np.array(hidden_states)
        labels = np.array(labels)
        
        # Calculate centroids
        unique_emotions = np.unique(labels)
        centroids = {}
        for emotion in unique_emotions:
            emotion_states = hidden_states[labels == emotion]
            if len(emotion_states) > 0:
                centroids[emotion] = emotion_states.mean(axis=0)
        
        # Calculate centroid similarities
        centroid_similarities = {}
        for i, e1 in enumerate(unique_emotions):
            for e2 in unique_emotions[i+1:]:
                if e1 in centroids and e2 in centroids:
                    c1, c2 = centroids[e1], centroids[e2]
                    similarity = np.dot(c1, c2) / (np.linalg.norm(c1) * np.linalg.norm(c2))
                    centroid_similarities[f"{e1}_vs_{e2}"] = similarity
        
        # Silhouette score
        try:
            silhouette = silhouette_score(hidden_states, labels) if len(unique_emotions) > 1 else -1
        except:
            silhouette = -1
        
        return {
            'centroids': centroids,
            'centroid_similarities': centroid_similarities,
            'silhouette_score': silhouette,
            'viz_labels': labels
        }

    def save_results(self, output_path: str):
        """Save all results to a file."""
        results = {
            'metrics': self.metrics,
            'stats': self.stats,
            'label_encoder': self.label_encoder,
            'emotion_labels': self.go_dataset.get_emotion_labels(),
            'layer_order': sorted(self.metrics.keys(), key=lambda x: int(x.split('_')[1]))
        }
        torch.save(results, output_path)

    @classmethod
    def load_results(cls, input_path: str):
        """Load saved results from file."""
        return torch.load(input_path)

In [70]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

class ProbeVisualizer:
    """Visualizer with separate train/test plots and larger fonts."""
    
    def __init__(self, analysis_results):
        self.results = analysis_results
        self._setup_plot_styles()
    
    def _setup_plot_styles(self):
        """Configure consistent plotting styles"""
        plt.rcParams.update({
            'font.size': 14,
            'axes.titlesize': 18,
            'axes.labelsize': 16,
            'xtick.labelsize': 13,
            'ytick.labelsize': 13,
            'legend.fontsize': 12,
            'figure.titlesize': 20
        })
        sns.set_style("whitegrid")
    
    def visualize_all(self, output_dir='results', dpi=300):
        """Generate all visualizations with train/test separation"""
        os.makedirs(output_dir, exist_ok=True)
        
        self.plot_performance_metrics(output_dir, dpi)  # Renamed from plot_combined_metrics
        self.plot_accuracy_comparison(output_dir, dpi)
        self.plot_silhouette_scores(output_dir, dpi)
        
        for layer in self.results['layer_order']:
            self.plot_confusion_matrices(layer, output_dir, dpi)
            self.plot_tsne_projections(layer, output_dir, dpi)
            self.plot_loss_curves(layer, output_dir, dpi)
        
        self.save_summary_table(output_dir)
        print(f"All visualizations saved to {output_dir}")

    def plot_performance_metrics(self, output_dir, dpi=300):
        """Plot layer-wise performance metrics (originally called plot_combined_metrics)"""
        metrics = {
            'Accuracy': [self.results['metrics'][layer]['test']['accuracy'] for layer in self.results['layer_order']],
            'Macro F1': [self.results['metrics'][layer]['test']['macro_f1'] for layer in self.results['layer_order']],
            'Weighted F1': [self.results['metrics'][layer]['test']['weighted_f1'] for layer in self.results['layer_order']]
        }
        
        layer_indices = [int(layer.split('_')[1]) for layer in self.results['layer_order']]
        
        plt.figure(figsize=(14, 8), dpi=dpi)
        for metric_name, values in metrics.items():
            plt.plot(layer_indices, values, marker='o', markersize=8, linewidth=2, label=metric_name)
        
        plt.xlabel('Layer Index', fontweight='bold')
        plt.ylabel('Score', fontweight='bold')
        plt.title('Layer-wise Performance Metrics (Test Set)', fontweight='bold', pad=20)
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'performance_metrics.png'), bbox_inches='tight', dpi=dpi)
        plt.close()

    def plot_accuracy_comparison(self, output_dir, dpi=300):
        """Plot train vs test accuracy across layers"""
        layers = self.results['layer_order']
        train_acc = [self.results['metrics'][l]['train']['accuracy'] for l in layers]
        test_acc = [self.results['metrics'][l]['test']['accuracy'] for l in layers]
        layer_nums = [int(l.split('_')[1]) for l in layers]
        
        plt.figure(figsize=(12, 7), dpi=dpi)
        plt.plot(layer_nums, train_acc, 'o-', label='Train Accuracy', linewidth=2)
        plt.plot(layer_nums, test_acc, 'o-', label='Test Accuracy', linewidth=2)
        
        plt.xlabel('Layer Index', fontweight='bold')
        plt.ylabel('Accuracy', fontweight='bold')
        plt.title('Train vs Test Accuracy by Layer', fontweight='bold', pad=20)
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'accuracy_comparison.png'), bbox_inches='tight', dpi=dpi)
        plt.close()

    def plot_silhouette_scores(self, output_dir, dpi=300):
        """Plot silhouette scores across layers"""
        silhouette_scores = [self.results['stats'][layer]['test']['silhouette_score'] for layer in self.results['layer_order']]
        layer_indices = [int(layer.split('_')[1]) for layer in self.results['layer_order']]
        
        plt.figure(figsize=(12, 7), dpi=dpi)
        plt.plot(layer_indices, silhouette_scores, marker='o', markersize=8, linewidth=2, color='darkorange')
        
        plt.xlabel('Layer Index', fontweight='bold')
        plt.ylabel('Silhouette Score', fontweight='bold')
        plt.title('Layer-wise Cluster Separation (Test Set)', fontweight='bold', pad=20)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'silhouette_scores.png'), bbox_inches='tight', dpi=dpi)
        plt.close()

    def plot_confusion_matrices(self, layer, output_dir, dpi=300):
        """Plot both train and test confusion matrices for a specific layer"""
        for split in ['train', 'test']:
            cm = self.results['metrics'][layer][split]['confusion_matrix']
            labels = self.results['metrics'][layer][split]['labels']
            
            # Normalize and get top emotions
            with np.errstate(invalid='ignore'):
                cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100
                cm_percent = np.nan_to_num(cm_percent, nan=0.0)
            
            emotion_counts = cm.sum(axis=1)
            top_idx = np.argsort(emotion_counts)[-10:]  # Top 10 emotions
            top_emotions = [self.results['emotion_labels'][i] for i in top_idx]
            cm_top = cm_percent[top_idx][:, top_idx]
            
            plt.figure(figsize=(14, 12), dpi=dpi)
            sns.heatmap(
                cm_top, 
                annot=True, 
                fmt='.1f', 
                cmap='Blues',
                xticklabels=top_emotions,
                yticklabels=top_emotions,
                annot_kws={'size': 11}
            )
            plt.title(f'Confusion Matrix ({split.capitalize()} Set) - {layer}', fontweight='bold', pad=20)
            plt.xlabel('Predicted Emotion', fontweight='bold')
            plt.ylabel('True Emotion', fontweight='bold')
            plt.tight_layout()
            plt.savefig(
                os.path.join(output_dir, f'confusion_{split}_{layer}.png'), 
                bbox_inches='tight', 
                dpi=dpi
            )
            plt.close()

    def plot_tsne_projections(self, layer, output_dir, dpi=300, perplexity=30):
        """Generate t-SNE plots for both train and test sets of a specific layer"""
        for split in ['train', 'test']:
            hidden_states = np.array(self.results['metrics'][layer][split]['hidden_states'])
            labels = self.results['metrics'][layer][split]['labels']
            
            # Subsample if needed
            if len(hidden_states) > 1000:
                idx = np.random.choice(len(hidden_states), 1000, replace=False)
                hidden_states = hidden_states[idx]
                labels = [labels[i] for i in idx]
            
            # Compute t-SNE
            try:
                tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
                tsne_result = tsne.fit_transform(hidden_states)
            except:
                print(f"Skipping t-SNE for {layer} {split} set due to error")
                continue
            
            # Plot with top emotions
            unique_emotions, counts = np.unique(labels, return_counts=True)
            top_emotions = unique_emotions[np.argsort(counts)[-5:]]  # Top 5 emotions
            
            plt.figure(figsize=(12, 10), dpi=dpi)
            colors = plt.cm.tab10(np.linspace(0, 1, len(top_emotions)))
            
            for i, emotion in enumerate(top_emotions):
                mask = np.array(labels) == emotion
                plt.scatter(
                    tsne_result[mask, 0], 
                    tsne_result[mask, 1], 
                    label=emotion,
                    color=colors[i],
                    alpha=0.7,
                    s=50
                )
            plt.title(f't-SNE Projection ({split.capitalize()} Set) - {layer}', fontweight='bold', pad=20)
            plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.savefig(
                os.path.join(output_dir, f'tsne_{split}_{layer}.png'), 
                bbox_inches='tight', 
                dpi=dpi
            )
            plt.close()

    def plot_loss_curves(self, layer, output_dir, dpi=300):
        """Plot training/validation loss curves for a specific layer"""
        train_losses = self.results['metrics'][layer]['loss_curves']['train']
        val_losses = self.results['metrics'][layer]['loss_curves']['val']
        
        plt.figure(figsize=(12, 7), dpi=dpi)
        plt.plot(train_losses, label='Training Loss', linewidth=2)
        plt.plot(val_losses, label='Validation Loss', linewidth=2)
        
        plt.xlabel('Epoch', fontweight='bold')
        plt.ylabel('Loss', fontweight='bold')
        plt.title(f'Training Progress - {layer}', fontweight='bold', pad=20)
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(
            os.path.join(output_dir, f'loss_curves_{layer}.png'), 
            bbox_inches='tight', 
            dpi=dpi
        )
        plt.close()

    def save_summary_table(self, output_dir):
        """Save metrics summary as CSV and HTML"""
        summary_data = []
        for layer in self.results['layer_order']:
            summary_data.append({
                'Layer': int(layer.split('_')[1]),
                'Train Accuracy': self.results['metrics'][layer]['train']['accuracy'],
                'Test Accuracy': self.results['metrics'][layer]['test']['accuracy'],
                'Test Macro F1': self.results['metrics'][layer]['test']['macro_f1'],
                'Test Weighted F1': self.results['metrics'][layer]['test']['weighted_f1'],
                'Silhouette Score': self.results['stats'][layer]['test']['silhouette_score']
            })
        
        df = pd.DataFrame(summary_data)
        
        # Save CSV
        df.to_csv(os.path.join(output_dir, 'metrics_summary.csv'), index=False)
        
        # Save styled HTML
        styled_df = df.style\
            .background_gradient(cmap='Blues', subset=['Test Accuracy', 'Test Macro F1'])\
            .format('{:.3f}', subset=['Train Accuracy', 'Test Accuracy', 'Test Macro F1', 'Test Weighted F1'])\
            .set_caption('Probe Performance Summary')
        
        with open(os.path.join(output_dir, 'metrics_summary.html'), 'w') as f:
            f.write(styled_df.to_html())

In [64]:
# 1. Load the GoEmotions dataset
go_dataset = GoEmotionDataset(
    train_path='Go_Emotion_Google/go_emotions_train.csv',
    test_path='Go_Emotion_Google/go_emotions_test.csv',
    val_path='Go_Emotion_Google/go_emotions_validation.csv'
)

In [65]:
max_batches=1328
# 2. Initialize the optimized probe analyzer with batch limit
analyzer = OptimizedProbeAnalyzer(
    hidden_states_dir='hidden_states',
    go_dataset=go_dataset,
    max_batches=max_batches,  # Limit to specified number of batches
    chunk_size=100  # Process 50 files at a time
)

Analysis limited to 1328 batches
Identifying layer structure...
Identified 29 layers


Counting examples: 100%|██████████| 1328/1328 [06:22<00:00,  3.47it/s]

Total examples: 21248
Prepared metadata for 21248 valid examples
Train set: 16998, Test set: 4250





In [68]:
import time
from concurrent.futures import ThreadPoolExecutor

# Load your analysis results (either from analyzer or saved file)
results = analyzer.analyze_all_layers([0, 4,  10, 12, 14, 16, 18, 20, 24, 26, 28]) 
analyzer.save_results("full_results.pt")

Training probes for layers: [0, 4, 10, 12, 14, 16, 18, 20, 24, 26, 28]
Loading data with parallel processing...


Processing files: 100%|██████████| 1328/1328 [06:49<00:00,  3.24it/s]


Data loaded in 409.35 seconds


In [74]:
results

In [73]:

# Create visualizations
visualizer = ProbeVisualizer(results)
visualizer.visualize_all(output_dir='my_report_figures', dpi=300)
# Or make individual plots
visualizer.plot_confusion_matrices('layer_10', output_dir='confusion_mats')
visualizer.plot_tsne_projections('layer_20', output_dir='tsne_plots')

TypeError: 'NoneType' object is not subscriptable

In [None]:
# Analysis phase (safe - no plotting)
analyzer = OptimizedProbeAnalyzer(hidden_states_dir, go_dataset)
results = analyzer.analyze_all_layers(layers_to_analyze)

# Save results
analyzer.save_results("probe_results.pt")

# Visualization phase (separate - can crash without losing data)
try:
    visualizer = ProbeVisualizer(results)
    visualizer.visualize_all()
except Exception as e:
    print(f"Visualization failed but data is safe: {e}")
    # You can reload and retry:
    results = OptimizedProbeAnalyzer.load_results("probe_results.pt")
    visualizer = ProbeVisualizer(results)
    visualizer.visualize_all("alternative_output_dir")