In [None]:
from huggingface_hub import notebook_login

# Log in to the Hugging Face Hub from the notebook before any model/tokenizer
# downloads below (from_pretrained in FisherInformationAnalysis.load_model).
# NOTE(review): presumably only required for gated/private models — confirm.
notebook_login()

In [None]:
import numpy as np
import torch
import os
import gc
import json
import pandas as pd
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import scipy.stats as stats  

In [None]:


# Presumably intended to keep torch.compile / TorchDynamo disabled for this run.
# NOTE(review): torch is already imported in the previous cell; confirm that this
# environment variable is still honoured when it is set after `import torch`.
os.environ["TORCH_COMPILE_DISABLE"] = "1"

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that turns NumPy scalars/arrays into native Python values.

    NumPy integers -> int, floating -> float, arrays -> nested lists,
    ``np.bool_`` -> bool. Anything else is delegated to ``json.JSONEncoder``,
    which raises ``TypeError`` for unsupported objects.
    """

    def default(self, obj):
        # Checked in this order; the first matching NumPy type decides the conversion.
        conversions = (
            (np.integer, int),
            (np.floating, float),
            (np.ndarray, np.ndarray.tolist),
            (np.bool_, bool),
        )
        for numpy_type, to_python in conversions:
            if isinstance(obj, numpy_type):
                return to_python(obj)
        return super().default(obj)

class FisherInformationAnalysis:
    """
    Simple analyzer for computing Fisher Information statistics for transformer models.

    For every text sample the model's own language-modelling loss (causal LM for
    ``model_type="decoder"``, masked LM otherwise) is back-propagated once, and the
    squared L2 norm of each parameter's gradient -- the trace of that parameter's
    empirical-Fisher block for that sample -- is averaged over samples. Results are
    aggregated per transformer layer and per coarse parameter group
    (attention / mlp / embedding), correlated with layer position, plotted, and
    written to ``output_dir``.
    """

    # Dotted-name components that are immediately followed by a layer index, e.g.
    # "gpt_neox.layers.3.", "bert.encoder.layer.3.", "transformer.h.3.",
    # "encoder.block.3.". The first such component (left to right) wins.
    _LAYER_CONTAINERS = ("layers", "layer", "h", "block", "blocks")

    # Keyword heuristics for grouping parameters: matched as substrings of the
    # lower-cased parameter name, first matching group wins. Parameters matching
    # no group still count towards their layer, just not towards any group.
    _GROUP_KEYWORDS = (
        ('attention', ('attention', 'attn', 'self', 'q_proj', 'k_proj', 'v_proj')),
        ('mlp', ('mlp', 'feed_forward', 'ffn', 'dense', 'fc')),
        ('embedding', ('embed', 'token', 'wte', 'wpe')),
    )

    # Fraction of eligible (non-special) tokens masked for encoder / MLM models.
    _MLM_MASK_PROB = 0.15

    def __init__(self, model_name, model_type="decoder", output_dir="./fisher_info"):
        """
        Initialize the Fisher Information analysis.

        Args:
            model_name: HuggingFace model name or path.
            model_type: "decoder" for causal LMs; any other value is treated as a
                masked-LM encoder (e.g. BERT).
            output_dir: Directory for reports and plots; created if missing.
        """
        self.model_name = model_name
        self.model_type = model_type
        self.output_dir = output_dir
        self.model = None
        self.tokenizer = None

        os.makedirs(output_dir, exist_ok=True)
        print(f"Output will be saved to {output_dir}")

    def load_model(self):
        """
        Load the tokenizer and model (float16 weights, ``device_map="auto"``).

        Any previously loaded model is released first. Decoder tokenizers without a
        pad token reuse the EOS token for padding.

        Returns:
            Tuple ``(model, tokenizer)``.

        Raises:
            Whatever ``transformers`` raises while loading (the error is printed first).
        """
        print(f"Loading {self.model_name}...")

        # Clean up any existing model
        if self.model is not None:
            del self.model
            self.model = None
            torch.cuda.empty_cache()
            gc.collect()

        try:
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            if self.tokenizer.pad_token is None and self.model_type == "decoder":
                self.tokenizer.pad_token = self.tokenizer.eos_token

            if self.model_type == "decoder":
                from transformers import AutoModelForCausalLM as model_cls
            else:  # encoder models like BERT
                from transformers import AutoModelForMaskedLM as model_cls

            self.model = model_cls.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                device_map="auto"
            )

            print(f"Successfully loaded {self.model_name}")
            return self.model, self.tokenizer

        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    @classmethod
    def _layer_name_for(cls, param_name):
        """Map a parameter name to ``"layer_<idx>"`` or ``"other"`` if no layer index is found."""
        parts = param_name.split('.')
        for container, index in zip(parts, parts[1:]):
            if container in cls._LAYER_CONTAINERS and index.isdigit():
                return f"layer_{index}"
        return "other"

    @classmethod
    def _param_group_for(cls, param_name):
        """Return 'attention' / 'mlp' / 'embedding' for a parameter name, or None."""
        lowered = param_name.lower()
        for group, keywords in cls._GROUP_KEYWORDS:
            if any(key in lowered for key in keywords):
                return group
        return None

    @staticmethod
    def _squared_grad_norm(grad):
        """Return ``||grad||_2 ** 2`` as a Python float, reducing in float32."""
        # The model is loaded in float16; a float16 norm can overflow to inf
        # (max ~65504) and keeps only ~3 significant digits.
        return torch.linalg.vector_norm(grad.detach(), dtype=torch.float32).item() ** 2

    def _prepare_model_inputs(self, mini_encodings):
        """
        Build forward() kwargs (input_ids, attention_mask, labels) for one sample.

        Padding columns introduced by batch tokenization are removed first (works for
        left or right padding), so pad tokens never enter the loss. Decoder models
        use the inputs as labels. Encoder models get MLM masking: ~15% of non-special
        tokens (at least one) are replaced by the mask token, and labels are -100
        everywhere else so the loss is computed only on masked positions.
        """
        # Batch size is 1 here (the caller slices one row), so row 0 is the sample.
        keep = mini_encodings['attention_mask'][0].bool()
        input_ids = mini_encodings['input_ids'][:, keep]
        attention_mask = mini_encodings['attention_mask'][:, keep]

        if self.model_type == "decoder":
            return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': input_ids}

        device = input_ids.device
        special_ids = torch.tensor(self.tokenizer.all_special_ids, dtype=input_ids.dtype, device=device)
        eligible = ~torch.isin(input_ids, special_ids)
        mask_prob = torch.full(input_ids.shape, self._MLM_MASK_PROB, device=device)
        masked = torch.bernoulli(mask_prob).bool() & eligible

        if not masked.any():
            # With no masked token every label would be -100 and the loss NaN.
            pool = eligible[0] if eligible.any() else torch.ones_like(eligible[0])
            positions = pool.nonzero(as_tuple=True)[0]
            pick = int(positions[torch.randint(len(positions), (1,)).item()])
            masked[0, pick] = True

        labels = input_ids.masked_fill(~masked, -100)
        masked_input_ids = input_ids.masked_fill(masked, self.tokenizer.mask_token_id)
        return {'input_ids': masked_input_ids, 'attention_mask': attention_mask, 'labels': labels}

    def _accumulate_gradient_norms(self, layer_norms, group_norms, n_samples):
        """
        Add each parameter's squared gradient norm / ``n_samples`` into the accumulators.

        Returns:
            Number of parameters skipped because their norm was non-finite.
        """
        non_finite = 0
        for name, param in self.model.named_parameters():
            if param.grad is None:
                continue
            sq_norm = self._squared_grad_norm(param.grad)
            if sq_norm == 0:
                # All-zero gradient: no entry is created for it.
                continue
            if not np.isfinite(sq_norm):
                # fp16 backward without loss scaling can overflow; one inf/NaN
                # would otherwise poison the whole layer's average.
                non_finite += 1
                continue

            contribution = sq_norm / n_samples
            layer_name = self._layer_name_for(name)
            layer_norms[layer_name] = layer_norms.get(layer_name, 0) + contribution

            group = self._param_group_for(name)
            if group is not None:
                group_norms[group]['norm'] += contribution
        return non_finite

    def compute_fisher_information(self, text_batch):
        """
        Compute per-layer and per-group Fisher Information norms for text samples.

        Samples are processed one at a time (batch size 1) to bound memory. Samples
        too short to produce a loss are skipped but still count in the averaging
        denominator.

        Args:
            text_batch: List of text samples (each truncated to 128 tokens).

        Returns:
            Dict with ``'layer_wise_fisher_norm'`` (layer name -> mean squared gradient
            norm) and ``'param_groups'`` (group -> ``{'norm': mean squared gradient
            norm}``). For an empty ``text_batch`` all values are empty/zero.
        """
        n_samples = len(text_batch)
        print(f"Computing Fisher Information for {n_samples} samples...")

        fisher_results = {
            'layer_wise_fisher_norm': {},
            'param_groups': {
                'attention': {'norm': 0},
                'mlp': {'norm': 0},
                'embedding': {'norm': 0}
            }
        }
        if n_samples == 0:
            return fisher_results

        encodings = self.tokenizer(
            text_batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        ).to(next(self.model.parameters()).device)

        layer_norms = fisher_results['layer_wise_fisher_norm']
        group_norms = fisher_results['param_groups']
        # A causal LM loss needs >= 2 tokens (one next-token prediction); MLM needs >= 1.
        min_tokens = 2 if self.model_type == "decoder" else 1
        non_finite = 0

        for i in range(n_samples):
            mini_encodings = {k: v[i:i + 1] for k, v in encodings.items()}

            if int(mini_encodings['attention_mask'].sum()) < min_tokens:
                print(f"Skipping sample {i}: too few tokens to compute a loss")
            else:
                outputs = self.model(**self._prepare_model_inputs(mini_encodings))
                self.model.zero_grad()
                outputs.loss.backward()

                non_finite += self._accumulate_gradient_norms(layer_norms, group_norms, n_samples)

                self.model.zero_grad()
                del outputs
                torch.cuda.empty_cache()
                gc.collect()

            if (i + 1) % 5 == 0 or i == n_samples - 1:
                print(f"Processed {i+1}/{n_samples} samples")

        if non_finite:
            print(f"Warning: skipped {non_finite} non-finite gradient norms (possibly fp16 overflow)")

        return fisher_results

    @staticmethod
    def _numeric_layer_points(layer_wise_norms):
        """Return ``(positions, norms)`` for entries named ``layer_<int>``, in dict order."""
        positions, norms = [], []
        for layer_name, norm in layer_wise_norms.items():
            if not layer_name.startswith('layer_'):
                continue
            try:
                positions.append(int(layer_name.split('_')[1]))
            except (ValueError, IndexError):
                continue
            norms.append(norm)
        return positions, norms

    def compute_correlations_and_pvalues(self, fisher_results):
        """
        Compute correlations between Fisher information metrics, with p-values.

        Args:
            fisher_results: Output of ``compute_fisher_information``.

        Returns:
            Dict with:
              - ``'layer_position_correlation'``: Pearson and Spearman correlation of
                layer index vs. layer norm (empty dict with fewer than 2 numeric layers).
              - ``'consecutive_layer_correlation'`` (only with >= 3 numeric layers and
                at least one non-constant window): mean Pearson r / p over sliding
                windows of 3 layers (2 when exactly 3 layers exist).
              - ``'component_correlations'``: placeholder entries per group pair; one
                aggregate norm per group cannot be correlated.
              - ``'statistical_significance'``: how many of the tests had p < 0.05.
        """
        print("Computing correlations with p-values...")
        correlation_results = {
            'layer_position_correlation': {},
            'component_correlations': {},
            'statistical_significance': {}
        }

        # 1. Correlation between layer position and Fisher norm
        layer_positions, layer_norms = self._numeric_layer_points(fisher_results['layer_wise_fisher_norm'])

        if len(layer_positions) >= 2:
            correlation, p_value = stats.pearsonr(layer_positions, layer_norms)
            spearman_corr, spearman_p = stats.spearmanr(layer_positions, layer_norms)
            correlation_results['layer_position_correlation'] = {
                'correlation': float(correlation),
                'p_value': float(p_value),
                'significant': bool(p_value < 0.05),  # plain bool for JSON
                'sample_size': int(len(layer_positions)),
                'spearman_correlation': float(spearman_corr),
                'spearman_p_value': float(spearman_p),
                'spearman_significant': bool(spearman_p < 0.05),
            }

        # 2. Correlation within sliding windows of consecutive layers
        if len(layer_positions) >= 3:
            consecutive_correlations = []
            consecutive_p_values = []

            order = np.argsort(layer_positions)
            sorted_norms = [layer_norms[i] for i in order]
            window_size = min(3, len(sorted_norms) - 1)

            for start in range(len(sorted_norms) - window_size + 1):
                window_positions = list(range(start, start + window_size))
                window_norms = sorted_norms[start:start + window_size]
                if len(set(window_norms)) <= 1:
                    continue  # constant window: correlation is undefined
                try:
                    corr, p_val = stats.pearsonr(window_positions, window_norms)
                except ValueError as e:
                    print(f"Skipping consecutive-layer window at {start}: {e}")
                    continue
                consecutive_correlations.append(corr)
                consecutive_p_values.append(p_val)

            if consecutive_correlations:
                correlation_results['consecutive_layer_correlation'] = {
                    'mean_correlation': float(np.mean(consecutive_correlations)),
                    'mean_p_value': float(np.mean(consecutive_p_values)),
                    'significant_ratio': float(sum(p < 0.05 for p in consecutive_p_values) / len(consecutive_p_values))
                }

        # 3. Correlations between parameter groups (placeholders only)
        param_groups = fisher_results['param_groups']
        group_names = [name for name, data in param_groups.items() if 'norm' in data]
        if len(group_names) >= 2:
            group_correlations = {}
            for i, group1 in enumerate(group_names):
                for group2 in group_names[i + 1:]:
                    group_correlations[f"{group1}_vs_{group2}"] = {
                        'correlation': None,
                        'p_value': None,
                        'note': "Need layer-wise breakdowns for meaningful correlation"
                    }
            correlation_results['component_correlations'] = group_correlations

        # 4. Statistical significance summary
        significant_count = 0
        total_tests = 0

        position_corr = correlation_results.get('layer_position_correlation', {})
        if 'p_value' in position_corr:
            total_tests += 1
            if position_corr['p_value'] < 0.05:
                significant_count += 1

        consecutive_corr = correlation_results.get('consecutive_layer_correlation', {})
        if 'mean_p_value' in consecutive_corr:
            total_tests += 1
            if consecutive_corr['mean_p_value'] < 0.05:
                significant_count += 1

        correlation_results['statistical_significance'] = {
            'significant_tests': int(significant_count),
            'total_tests': int(total_tests),
            'significance_ratio': float(significant_count / total_tests if total_tests > 0 else 0)
        }

        return correlation_results

    def plot_correlation_analysis(self, fisher_results, correlation_results):
        """
        Save correlation plots to ``output_dir``.

        Writes ``layer_position_correlation.png`` (scatter plus linear fit; only when a
        layer-position correlation was computed) and ``parameter_group_fisher.png``
        (bar chart of group norms). Plotting errors are printed, not raised.

        Args:
            fisher_results: Dictionary with Fisher Information results.
            correlation_results: Dictionary with correlation analyses.
        """
        print("Creating correlation plots...")

        # 1. Layer position vs. Fisher norm
        corr_info = correlation_results.get('layer_position_correlation') or {}
        if 'correlation' in corr_info:
            try:
                positions, norms = self._numeric_layer_points(fisher_results['layer_wise_fisher_norm'])
                order = np.argsort(positions)
                sorted_positions = [positions[i] for i in order]
                sorted_norms = [norms[i] for i in order]

                plt.figure(figsize=(10, 6))
                plt.scatter(sorted_positions, sorted_norms, marker='o', alpha=0.7)

                trend = np.poly1d(np.polyfit(sorted_positions, sorted_norms, 1))
                plt.plot(sorted_positions, trend(sorted_positions), "r--", alpha=0.7)

                plt.title(f"Layer Position vs. Fisher Norm\nPearson r={corr_info['correlation']:.3f}, p={corr_info['p_value']:.4f}")
                plt.xlabel("Layer Position")
                plt.ylabel("Fisher Information Norm")
                plt.grid(alpha=0.3)

                plt.tight_layout()
                plt.savefig(os.path.join(self.output_dir, "layer_position_correlation.png"))
            except Exception as e:
                print(f"Error creating layer position correlation plot: {e}")
            finally:
                plt.close()

        # 2. Fisher norm by parameter group
        try:
            group_items = [
                (name, data['norm'])
                for name, data in fisher_results['param_groups'].items()
                if 'norm' in data
            ]
            if group_items:
                group_names = [name for name, _ in group_items]
                group_norms = [norm for _, norm in group_items]

                plt.figure(figsize=(10, 6))
                plt.bar(group_names, group_norms)
                plt.title("Fisher Information by Parameter Group")
                plt.ylabel("Fisher Information Norm")
                plt.grid(axis='y', alpha=0.3)

                plt.tight_layout()
                plt.savefig(os.path.join(self.output_dir, "parameter_group_fisher.png"))
        except Exception as e:
            print(f"Error creating parameter group plot: {e}")
        finally:
            plt.close()

    def _plot_layer_fisher(self, fisher_results):
        """Save ``fisher_by_layer.png``: numeric layers in index order, non-numeric entries last."""
        try:
            plt.figure(figsize=(12, 6))
            layer_wise = fisher_results['layer_wise_fisher_norm']
            layers = list(layer_wise.keys())
            values = list(layer_wise.values())

            sort_keys = []
            for layer in layers:
                try:
                    sort_keys.append(int(layer.split('_')[1]) if layer.startswith('layer_') else float('inf'))
                except (ValueError, IndexError):
                    sort_keys.append(float('inf'))

            order = np.argsort(sort_keys, kind='stable')
            sorted_layers = [layers[i] for i in order]
            sorted_values = [values[i] for i in order]

            plt.bar(range(len(sorted_layers)), sorted_values)
            plt.xticks(range(len(sorted_layers)), sorted_layers, rotation=90)
            plt.title("Layer-wise Fisher Information Norm")
            plt.tight_layout()
            plt.savefig(os.path.join(self.output_dir, "fisher_by_layer.png"))
        except Exception as e:
            print(f"Error creating plot: {e}")
        finally:
            # Without this every run leaks an open matplotlib figure.
            plt.close()

    @staticmethod
    def _fmt(value, spec):
        """Format a number with ``spec``; return ``'N/A'`` for missing or non-numeric values."""
        if isinstance(value, (int, float, np.integer, np.floating)) and not isinstance(value, bool):
            return format(value, spec)
        return 'N/A'

    def _write_summary(self, summary_path, fisher_results, correlation_results):
        """Write the human-readable correlation summary; errors are printed, not raised."""
        fmt = self._fmt
        try:
            with open(summary_path, 'w', encoding='utf-8') as f:
                f.write("FISHER INFORMATION CORRELATION ANALYSIS\n")
                f.write("=====================================\n\n")
                f.write(f"Model: {self.model_name}\n")
                f.write(f"Type: {self.model_type}\n")
                f.write(f"Date: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")

                f.write("LAYER POSITION CORRELATION\n")
                f.write("==========================\n")
                # The key is always present (initialised to {}), so test for content.
                corr_data = correlation_results.get('layer_position_correlation')
                if corr_data:
                    f.write(f"Pearson correlation: {fmt(corr_data.get('correlation'), '.4f')}\n")
                    f.write(f"P-value: {fmt(corr_data.get('p_value'), '.6f')}\n")
                    f.write(f"Significant (p<0.05): {corr_data.get('significant', 'N/A')}\n")
                    f.write(f"Spearman correlation: {fmt(corr_data.get('spearman_correlation'), '.4f')}\n")
                    f.write(f"Spearman p-value: {fmt(corr_data.get('spearman_p_value'), '.6f')}\n\n")
                else:
                    f.write("No correlation data available\n\n")

                f.write("CONSECUTIVE LAYER CORRELATION\n")
                f.write("=============================\n")
                consec_data = correlation_results.get('consecutive_layer_correlation')
                if consec_data:
                    f.write(f"Mean correlation: {fmt(consec_data.get('mean_correlation'), '.4f')}\n")
                    f.write(f"Mean p-value: {fmt(consec_data.get('mean_p_value'), '.6f')}\n")
                    f.write(f"Significant ratio: {fmt(consec_data.get('significant_ratio'), '.2f')}\n\n")
                else:
                    f.write("No consecutive layer correlation data available\n\n")

                f.write("STATISTICAL SIGNIFICANCE SUMMARY\n")
                f.write("===============================\n")
                if 'statistical_significance' in correlation_results:
                    sig_data = correlation_results['statistical_significance']
                    f.write(f"Significant tests: {sig_data.get('significant_tests', 0)}/{sig_data.get('total_tests', 0)}\n")
                    f.write(f"Significance ratio: {fmt(sig_data.get('significance_ratio', 0), '.2f')}\n\n")
                else:
                    f.write("No statistical significance data available\n\n")

                f.write("PARAMETER GROUP COMPARISON\n")
                f.write("=========================\n")
                for group, data in fisher_results['param_groups'].items():
                    f.write(f"{group}: {fmt(data.get('norm', 0), '.6f')}\n")
        except Exception as e:
            print(f"Error creating summary text file: {e}")

    def run_analysis(self, texts):
        """
        Run the full pipeline: load model, compute Fisher norms, correlate, plot, report.

        Writes ``fisher_info_report.json``, ``correlation_summary.txt`` and several
        PNG plots to ``output_dir``. The model is always released via ``cleanup()``,
        even when a step raises.

        Args:
            texts: List of text samples to analyze.

        Returns:
            Path to the generated JSON report.
        """
        start_time = time.time()
        report_path = os.path.join(self.output_dir, "fisher_info_report.json")
        summary_path = os.path.join(self.output_dir, "correlation_summary.txt")

        try:
            self.load_model()
            # Gradients are still computed in eval mode; eval() only switches off
            # training-time behaviour such as dropout.
            self.model.eval()

            fisher_results = self.compute_fisher_information(texts)
            correlation_results = self.compute_correlations_and_pvalues(fisher_results)
            self.plot_correlation_analysis(fisher_results, correlation_results)

            report = {
                "model_name": self.model_name,
                "model_type": self.model_type,
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "parameter_groups": fisher_results['param_groups'],
                "layer_metrics": {
                    "fisher_norm": fisher_results['layer_wise_fisher_norm'],
                },
                "total_fisher_norm": sum(fisher_results['layer_wise_fisher_norm'].values()),
                "correlation_analysis": correlation_results
            }
            with open(report_path, 'w', encoding='utf-8') as f:
                json.dump(report, f, indent=2, cls=NumpyEncoder)

            self._plot_layer_fisher(fisher_results)
            self._write_summary(summary_path, fisher_results, correlation_results)
        finally:
            self.cleanup()

        total_time = time.time() - start_time
        print(f"Analysis completed in {total_time/60:.2f} minutes")
        print(f"Report saved to {report_path}")
        print(f"Correlation summary saved to {summary_path}")

        return report_path

    def cleanup(self):
        """Release the model and tokenizer and free cached GPU memory."""
        if self.model is not None:
            del self.model
            self.model = None

        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None

        gc.collect()
        torch.cuda.empty_cache()
        print("Resources cleaned up")


def get_sample_texts(dataset_path, n_samples=10, random_state=None):
    """
    Load a CSV and return up to ``n_samples`` text snippets from its text column.

    The text column is the first column whose first non-missing value is a string
    longer than 20 characters; failing that, the first object/string-dtype column.
    Missing cells (NaN/None) are dropped *before* sampling, so they never turn into
    the literal text "nan". Each text is truncated to its first 500 characters.

    Args:
        dataset_path: Path to a CSV file readable by ``pandas.read_csv``.
        n_samples: Maximum number of texts to return.
        random_state: Optional seed forwarded to ``Series.sample`` for reproducible
            sampling (default None keeps the previous, unseeded behaviour).

    Returns:
        List of strings. Empty when the file cannot be read or has no usable column
        (best effort: errors are printed rather than raised).
    """
    try:
        data = pd.read_csv(dataset_path)
        print(f"Loaded dataset with {len(data)} rows")

        # Find text column: probe the first non-missing value, not just row 0.
        text_col = None
        for col in data.columns:
            non_null = data[col].dropna()
            if not non_null.empty and isinstance(non_null.iloc[0], str) and len(non_null.iloc[0]) > 20:
                text_col = col
                break

        if text_col is None:
            # Fall back to the first column that could hold text.
            text_col = next(
                (col for col in data.columns
                 if data[col].dtype == 'object' or pd.api.types.is_string_dtype(data[col].dtype)),
                None,
            )

        if text_col is None:
            print("No suitable text column found")
            return []

        column = data[text_col].dropna()
        if len(column) > n_samples:
            column = column.sample(n_samples, random_state=random_state)

        texts = [str(t)[:500] for t in column.tolist()]
        print(f"Extracted {len(texts)} text samples")
        return texts
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return []



In [None]:

def main():
    """Notebook entry point: choose paths, gather sample texts, run the Fisher analysis.

    On Google Colab, Google Drive is mounted and data/results live under MyDrive;
    elsewhere local relative paths are used. Three built-in sentences are used
    when no texts could be loaded from the dataset.
    """
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        in_colab = True
    except ImportError:
        in_colab = False
        print("Not running in Google Colab")

    model_name = "EleutherAI/pythia-2.8b"
    model_type = "decoder"  # "encoder" for encoder-only (masked LM) models, "decoder" for causal LMs
    short_name = model_name.split('/')[-1]

    dataset_path, output_dir = (
        ("/content/drive/MyDrive/wiki_dataset_position.csv", f"/content/drive/MyDrive/fisher/{short_name}")
        if in_colab
        else ("dataset.csv", f"fisher_info_{short_name}")
    )

    analyzer = FisherInformationAnalysis(
        model_name=model_name,
        model_type=model_type,
        output_dir=output_dir,
    )

    fallback_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Machine learning models process data to make predictions.",
        "Transformers have revolutionized natural language processing.",
    ]
    # get_sample_texts returns [] on any failure, so fall back to the built-ins.
    texts = get_sample_texts(dataset_path, n_samples=500) or fallback_texts

    analyzer.run_analysis(texts=texts)


if __name__ == "__main__":
    main()