In [1]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
import pandas as pd
from pathlib import Path
import re
from collections import defaultdict
import json
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

In [2]:
def load_model_and_tokenizer(base_model_id, load_in_4bit=False):
    """Load a causal LM and its tokenizer from HuggingFace.

    Args:
        base_model_id: Identifier handed to both ``from_pretrained`` calls.
        load_in_4bit: When True, a ``BitsAndBytesConfig`` (NF4 quantization,
            float16 compute dtype, no double quantization) is passed to the
            model loader; otherwise no quantization config is used.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Loading tokenizer from {base_model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)

    # Quantization is opt-in; the bitsandbytes config class is only imported
    # when it is actually needed.
    quantization_config = None
    if load_in_4bit:
        from transformers import BitsAndBytesConfig
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=False
        )

    print(f"Loading model from {base_model_id}...")
    # Note: trust_remote_code=True is passed to the model loader (not to the
    # tokenizer loader above).
    load_kwargs = dict(
        device_map="auto",
        torch_dtype=torch.float16,
        quantization_config=quantization_config,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(base_model_id, **load_kwargs)

    return model, tokenizer

In [3]:
def explore_model_structure(model, max_depth=3, current_depth=0, path="model"):
    """Recursively print the public attributes of an object.

    One line is printed per attribute whose name does not start with an
    underscore: the dotted path, the attribute's type name and, when the
    value has a ``shape`` attribute, that shape. Values that are neither
    callable nor ``str``/``int``/``float``/``bool`` are explored in turn.

    Args:
        model: Object to inspect.
        max_depth: Recursion stops once ``current_depth`` reaches this value.
        current_depth: Depth of this call; also sets the indentation
            (two spaces per level).
        path: Dotted prefix used when printing attribute names.
    """
    if current_depth >= max_depth:
        return

    indent = ' ' * current_depth * 2
    leaf_types = (str, int, float, bool)

    try:
        public_names = [name for name in dir(model) if not name.startswith('_')]
        for name in public_names:
            child_path = f"{path}.{name}"
            try:
                member = getattr(model, name)
                description = f"{indent}{child_path}: {type(member).__name__}"
                if hasattr(member, 'shape'):
                    description += f" (shape: {member.shape})"
                print(description)

                # Callables and primitive values are leaves; everything else
                # is explored one level deeper.
                if callable(member) or isinstance(member, leaf_types):
                    continue
                explore_model_structure(member, max_depth, current_depth + 1, child_path)
            except Exception as e:
                # Attribute access can fail (e.g. in a property); report it
                # and keep going with the remaining attributes.
                print(f"{indent}{child_path}: [Error accessing: {e}]")
    except Exception as e:
        print(f"{indent}{path}: [Error exploring: {e}]")

def analyze_model_structure(model):
    """Print a short report on the ``model.model.model.layers`` hierarchy.

    Walks ``model`` -> ``.model`` -> ``.model`` -> ``.layers`` one step at a
    time, printing the type found at each level and stopping at the first
    level where the next attribute is missing. For ``layers`` it prints the
    first five keys when it is a dict; otherwise its type, its length (when it
    has ``__len__``) and, if non-empty, the type and first ten public
    attribute names of the first element.
    """
    print("\nModel Structure Analysis:")
    print(f"Model type: {type(model).__name__}")

    has_inner = hasattr(model, 'model')
    print(f"Has 'model' attribute: {has_inner}")
    if not has_inner:
        return

    inner = model.model
    print(f"model.model type: {type(inner).__name__}")
    has_core = hasattr(inner, 'model')
    print(f"Has 'model.model' attribute: {has_core}")
    if not has_core:
        return

    core = inner.model
    print(f"model.model.model type: {type(core).__name__}")
    has_layers = hasattr(core, 'layers')
    print(f"Has 'layers' attribute: {has_layers}")
    if not has_layers:
        return

    layers = core.layers
    if isinstance(layers, dict):
        print(f"First few layer keys: {list(layers.keys())[:5]}")
        return

    print(f"Layers type: {type(layers).__name__}")
    if not hasattr(layers, '__len__'):
        return

    print(f"Number of layers: {len(layers)}")
    if len(layers) > 0:
        first_layer = layers[0]
        public_attrs = [attr for attr in dir(first_layer) if not attr.startswith('_')]
        print(f"First layer type: {type(first_layer).__name__}")
        print(f"First layer attributes: {public_attrs[:10]}")

In [8]:
def get_base_weights_for_layer(model, layer_name, debug=False):
    """Extract the base weights for a specific layer by name with better debugging.

    Several rewrites of ``layer_name`` are tried in order until one of them
    can be walked, attribute by attribute, from ``model`` down to an object
    that has a ``weight`` attribute:

    1. the name unchanged,
    2. the name with the first ``base_model.`` removed,
    3. the name with the first ``model.model.`` removed,
    4. only the last three dot-separated parts.

    Purely numeric parts are used as integer indices (``obj[int(part)]``);
    all other parts are looked up with ``getattr``.

    Args:
        model: Root object to start the lookup from.
        layer_name: Dotted path of the layer.
        debug: When True, print every attempt and its outcome.

    Returns:
        ``weight.detach().cpu()`` of the first matching object, or ``None``
        when no candidate name resolves to an object with a ``weight``.
    """
    if debug:
        print(f"Attempting to access base weights for: {layer_name}")

    candidate_names = [
        # Original name
        layer_name,
        # Remove base_model prefix
        layer_name.replace('base_model.', '', 1),
        # Remove model.model prefix
        layer_name.replace('model.model.', '', 1),
        # Just keep the last parts (layer number + module type)
        '.'.join(layer_name.split('.')[-3:]),
    ]

    # dict.fromkeys keeps the order but drops repeats, so a name that no
    # rewrite changes is not walked (and reported) several times.
    for transformed_name in dict.fromkeys(candidate_names):
        if debug:
            print(f"  Trying transformed name: {transformed_name}")

        try:
            current_module = model
            for part in transformed_name.split('.'):
                if part.isdigit():
                    current_module = current_module[int(part)]
                else:
                    current_module = getattr(current_module, part)

            if hasattr(current_module, 'weight'):
                if debug:
                    print(f"  SUCCESS: Found weights using {transformed_name}")
                return current_module.weight.detach().cpu()

            if debug:
                print("  Found module but it has no weight attribute")
        except (AttributeError, IndexError, KeyError, TypeError) as e:
            # TypeError covers indexing into an object that is not
            # subscriptable, which happens when a candidate name starts with
            # (or reaches) a numeric part at the wrong level of the hierarchy.
            if debug:
                print(f"  Failed with {e}")

    # If we get here, we couldn't find the weights
    if debug:
        print(f"  WARNING: Could not find base weights for {layer_name}")
    return None


In [9]:
def test_weight_access(model, adapter_base_layers, debug=True):
    """Try to resolve base weights for up to ten adapter layers and report the tally.

    Args:
        model: Model passed through to ``get_base_weights_for_layer``.
        adapter_base_layers: Sized collection of layer names; only the first
            ``min(10, len(adapter_base_layers))`` entries are probed.
        debug: Forwarded to ``get_base_weights_for_layer``.

    Returns:
        ``(successful, failed)`` counts for the probed layers.
    """
    sample_size = min(10, len(adapter_base_layers))

    # One boolean per probed layer: True when a weight tensor came back.
    outcomes = [
        get_base_weights_for_layer(model, name, debug=debug) is not None
        for name in list(adapter_base_layers)[:sample_size]
    ]
    successful = sum(outcomes)
    failed = len(outcomes) - successful

    print(f"\nWeight access test results:")
    print(f"  Successful: {successful}/{sample_size}")
    print(f"  Failed: {failed}/{sample_size}")

    return successful, failed

In [10]:
def _summarize_impact_groups(metrics_df, group_column, scaling_factor,
                             include_relative, include_median=False):
    """Aggregate per-layer LoRA metrics by ``group_column`` into a dict of summary stats."""
    summary = {}
    for group_key, group in metrics_df.groupby(group_column):
        stats = {
            'mean_norm': group['est_frob_norm'].mean(),
            'sum_norm': group['est_frob_norm'].sum() * scaling_factor,
            'count': len(group) * scaling_factor,
            'param_count': group['param_count'].sum() * scaling_factor
        }

        # Relative-impact stats only use rows where a base weight was available.
        rel_impact = group['relative_impact_pct'].dropna()
        if include_relative and not rel_impact.empty:
            stats['mean_relative_pct'] = rel_impact.mean()
            stats['max_relative_pct'] = rel_impact.max()
            if include_median:
                stats['median_relative_pct'] = rel_impact.median()

        summary[group_key] = stats
    return summary


def _release_cached_memory():
    """Run the garbage collector and ask torch to release its cached CUDA memory."""
    import gc
    gc.collect()
    try:
        torch.cuda.empty_cache()
    except RuntimeError as e:
        # Best effort only: a failed cache flush must not discard the analysis.
        print(f"Warning: could not empty the CUDA cache: {e}")


def analyze_relative_impact(base_model_id, adapter_id, output_dir="relative_impact_analysis",
                           sample_rate=0.25, load_base_model=True, debug=False, seed=None):
    """Analyze the relative impact of LoRA adapters compared to base model weights.

    For every analyzed ``lora_A``/``lora_B`` tensor pair the product of the
    two tensor norms is recorded as an upper-bound estimate of the norm of the
    LoRA update. No LoRA scaling factor is applied to that estimate. When the
    matching base weight can be read as a floating-point tensor, the estimate
    is also expressed as a percentage of the base weight norm.

    Relies on ``download_adapter``, ``load_adapter_weights`` and (in debug
    mode) ``analyze_adapter_keys``, which are expected to be defined elsewhere
    in the notebook.

    Files written to ``output_dir``: ``layer_metrics_sampled.csv``,
    ``impact_by_layer_type.png`` and ``analysis_summary.txt``.

    Args:
        base_model_id: Identifier of the base model (loaded in 4-bit mode via
            ``load_model_and_tokenizer`` when ``load_base_model`` is True).
        adapter_id: Identifier passed to ``download_adapter``.
        output_dir: Directory for the generated files; created if missing.
        sample_rate: Fraction of LoRA pairs to analyze per layer number. Values
            of 1.0 or more analyze everything. Sums and counts in the summaries
            are scaled by ``1 / sample_rate`` to estimate full-adapter totals.
        load_base_model: Whether to load the base model for the relative
            comparison. If loading fails the analysis continues without it.
        debug: Print extra diagnostics (model structure, weight lookups).
        seed: Optional seed for the layer sampling. ``None`` keeps using
            NumPy's global random state.

    Returns:
        Dict with the DataFrames ``layer_metrics``, ``layer_type_stats``,
        ``layer_num_stats``, ``top_layers_abs`` and ``top_layers_rel`` (the
        last one is ``None`` when no relative impact could be computed). If no
        LoRA pair is found, the DataFrames are empty and no files are written.

    Raises:
        ValueError: If ``sample_rate`` is not greater than 0.
    """
    # Fail fast: a non-positive rate would otherwise only surface as a
    # division by zero after the model and adapter have been loaded.
    if sample_rate <= 0:
        raise ValueError(f"sample_rate must be greater than 0, got {sample_rate}")

    print("\n[1/7] Starting relative impact analysis...")

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory created: {output_dir}")

    # Load the base model if requested
    model = None
    if load_base_model:
        print(f"[2/7] Loading base model from {base_model_id}...")
        try:
            model, _ = load_model_and_tokenizer(base_model_id, load_in_4bit=True)
            print("Successfully loaded base model")

            # Analyze model structure
            if debug:
                analyze_model_structure(model)
                print("\nExploring model structure (first 2 levels):")
                explore_model_structure(model, max_depth=2)
        except Exception as e:
            print(f"Error loading base model: {e}")
            print("Continuing without base model comparison...")
            load_base_model = False

    # Get the adapter path
    print(f"[3/7] Loading LoRA adapter from {adapter_id}...")
    adapter_path = download_adapter(adapter_id)

    # Load the adapter weights
    adapter_weights = load_adapter_weights(adapter_path)
    print(f"Successfully loaded adapter weights with {len(adapter_weights)} tensors")

    # Analyze adapter keys
    if debug:
        adapter_info = analyze_adapter_keys(adapter_weights)

        # Test weight access if base model is loaded
        if load_base_model and model is not None:
            test_weight_access(model, adapter_info['base_layers'], debug=True)

    # Extract adapter configuration if available
    adapter_config = {}
    config_files = list(Path(adapter_path).glob("adapter_config.json"))
    if config_files:
        with open(config_files[0], 'r', encoding="utf-8") as f:
            adapter_config = json.load(f)

        print("\nAdapter Configuration:")
        for key, value in adapter_config.items():
            print(f"  {key}: {value}")

    # Extract layer info
    layer_info = []

    # Track target modules from config. A missing/None entry becomes an empty
    # list, and a single string is wrapped so that it is not iterated
    # character by character when matching layer types below.
    target_modules = adapter_config.get("target_modules") or []
    if isinstance(target_modules, str):
        target_modules = [target_modules]

    # Regular expression to find lora_A tensors
    lora_a_pattern = re.compile(r'.*lora_A.*')

    # Find all lora_A keys
    lora_a_keys = [k for k in adapter_weights.keys() if lora_a_pattern.match(k)]
    total_layers = len(lora_a_keys)

    print(f"\n[4/7] Found {total_layers} LoRA layer pairs")
    print(f"Analyzing approximately {int(total_layers * sample_rate)} layers (sample rate: {sample_rate*100:.0f}%)")

    # Sample layers to analyze
    if sample_rate < 1.0:
        print("Ensuring representative sampling across layer depths...")
        # Bucket keys by layer number so every depth is represented. Keys
        # without a "layers.<n>" segment go into bucket -1 (the same sentinel
        # used for layer_num below) instead of being dropped from the sample.
        layer_nums = {}
        for key in lora_a_keys:
            match = re.search(r'layers\.(\d+)', key)
            layer_num = int(match.group(1)) if match else -1
            layer_nums.setdefault(layer_num, []).append(key)

        # With a seed, draw from a dedicated generator for reproducibility;
        # without one, keep drawing from NumPy's global random state.
        choose = np.random.default_rng(seed).choice if seed is not None else np.random.choice

        # Sample from each layer number
        sampled_keys = []
        for keys in layer_nums.values():
            num_to_sample = max(1, int(len(keys) * sample_rate))
            chosen = choose(keys, size=num_to_sample, replace=False)
            # The draw returns NumPy string scalars; store plain str keys.
            sampled_keys.extend(str(k) for k in chosen)

        lora_a_keys = sampled_keys
        print(f"Sampled {len(lora_a_keys)} layers across {len(layer_nums)} different layer depths")

    # Process the sampled layers
    print("[5/7] Analyzing LoRA layers...")

    # Track how often the base weight lookup worked
    weight_access_results = {"success": 0, "failed": 0}

    # Use tqdm for a progress bar
    for a_key in tqdm(lora_a_keys, desc="Analyzing layers"):
        # Find the corresponding B matrix key
        b_key = a_key.replace('lora_A', 'lora_B')
        if b_key not in adapter_weights:
            continue

        # Extract base layer name and module type
        base_name = a_key.split('.lora_A')[0]

        # Extract layer type
        layer_type = "unknown"
        for module in target_modules:
            if f".{module}" in base_name:
                layer_type = module
                break

        # Extract layer number
        match = re.search(r'layers\.(\d+)', base_name)
        layer_num = int(match.group(1)) if match else -1

        a_tensor = adapter_weights[a_key]
        b_tensor = adapter_weights[b_key]

        # Calculate metrics without full matrix multiplication
        a_norm = torch.norm(a_tensor).item()
        b_norm = torch.norm(b_tensor).item()

        # Upper bound on the norm of the product of the two matrices,
        # obtained without multiplying them. No LoRA scaling is applied.
        est_frob_norm = a_norm * b_norm

        # Get base model weights for comparison if requested
        base_weight_norm = None
        relative_impact = None

        if load_base_model and model is not None:
            # Try to find the corresponding base model weight
            base_weights = get_base_weights_for_layer(model, base_name, debug=debug)

            # torch.norm needs a floating-point tensor. NOTE(review): the base
            # model is loaded in 4-bit mode above, and quantized layers
            # presumably expose packed integer storage here - confirm; such
            # tensors are counted as failed lookups instead of raising.
            if base_weights is not None and base_weights.is_floating_point():
                weight_access_results["success"] += 1
                base_weight_norm = torch.norm(base_weights).item()
                # An all-zero base weight has no meaningful relative impact.
                if base_weight_norm > 0:
                    relative_impact = (est_frob_norm / base_weight_norm) * 100  # as percentage
            else:
                weight_access_results["failed"] += 1
                if debug and base_weights is not None:
                    print(f"  Skipping non-floating-point base weight ({base_weights.dtype}) for {base_name}")

        layer_info.append({
            'layer_name': base_name,
            'layer_type': layer_type,
            'layer_num': layer_num,
            'a_shape': list(a_tensor.shape),
            'b_shape': list(b_tensor.shape),
            # First dimension of lora_A; assumes lora_A is (rank, in_features) - TODO confirm
            'rank': a_tensor.shape[0],
            'param_count': a_tensor.numel() + b_tensor.numel(),
            'a_norm': a_norm,
            'b_norm': b_norm,
            'est_frob_norm': est_frob_norm,
            'base_weight_norm': base_weight_norm,
            'relative_impact_pct': relative_impact
        })

    if debug and load_base_model:
        print(f"\nBase weight access results:")
        print(f"  Successful accesses: {weight_access_results['success']}")
        print(f"  Failed accesses: {weight_access_results['failed']}")
        if weight_access_results['success'] == 0:
            print("  WARNING: Could not access any base weights. Relative impact analysis will be unavailable.")

    print(f"  Completed analysis of {len(layer_info)} layers")

    # Without any analyzed pair the frame below would have no columns and the
    # group-by would fail with a KeyError, so stop here with empty results.
    if not layer_info:
        print("WARNING: No lora_A/lora_B tensor pairs were found; nothing to analyze.")
        if model is not None:
            model = None
            _release_cached_memory()
        return {
            'layer_metrics': pd.DataFrame(),
            'layer_type_stats': pd.DataFrame(),
            'layer_num_stats': pd.DataFrame(),
            'top_layers_abs': pd.DataFrame(),
            'top_layers_rel': None
        }

    # Convert to DataFrame
    metrics_df = pd.DataFrame(layer_info)

    # Sums and counts are scaled up to estimate totals for the full adapter
    scaling_factor = 1.0 / sample_rate if sample_rate < 1.0 else 1.0

    # Save full metrics
    metrics_df.to_csv(os.path.join(output_dir, "layer_metrics_sampled.csv"), index=False)

    # Summary statistics by layer type (includes the median relative impact)
    layer_type_summary = _summarize_impact_groups(
        metrics_df, 'layer_type', scaling_factor,
        include_relative=load_base_model, include_median=True
    )
    layer_type_stats = pd.DataFrame.from_dict(layer_type_summary, orient='index')

    # Summary statistics by layer number
    layer_num_summary = _summarize_impact_groups(
        metrics_df, 'layer_num', scaling_factor,
        include_relative=load_base_model
    )
    layer_num_stats = pd.DataFrame.from_dict(layer_num_summary, orient='index').reset_index()
    layer_num_stats.columns = ['layer_num'] + list(layer_num_stats.columns)[1:]
    layer_num_stats = layer_num_stats.sort_values('layer_num')

    # Get top layers
    top_layers_rel = None
    if load_base_model:
        # Sort by relative impact if available
        rel_impact_df = metrics_df[metrics_df['relative_impact_pct'].notna()]
        if not rel_impact_df.empty:
            top_layers_rel = rel_impact_df.sort_values('relative_impact_pct', ascending=False).head(10)

    # Also get top by absolute impact
    top_layers_abs = metrics_df.sort_values('est_frob_norm', ascending=False).head(10)

    print("[6/7] Generating visualizations...")

    # 1. Impact by layer type bar chart
    plt.figure(figsize=(12, 6))
    types_df = layer_type_stats.sort_values('sum_norm', ascending=False)

    # Create bar chart for absolute impact
    plt.subplot(1, 2, 1)
    sns.barplot(
        x=types_df.index,
        y=types_df['sum_norm']
    )
    plt.title('Absolute Impact by Layer Type')
    plt.xlabel('Layer Type')
    plt.ylabel('Sum of Estimated Norms')
    plt.xticks(rotation=45)

    # Create bar chart for relative impact if available
    if load_base_model and 'mean_relative_pct' in types_df.columns:
        plt.subplot(1, 2, 2)
        sorted_by_rel = types_df.sort_values('mean_relative_pct', ascending=False)
        sns.barplot(
            x=sorted_by_rel.index,
            y=sorted_by_rel['mean_relative_pct']
        )
        plt.title('Relative Impact by Layer Type (% of Base Weight)')
        plt.xlabel('Layer Type')
        plt.ylabel('Mean Relative Impact (%)')
        plt.xticks(rotation=45)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "impact_by_layer_type.png"))

    # Generate summary text report
    print("[7/7] Generating summary report...")
    with open(os.path.join(output_dir, "analysis_summary.txt"), 'w', encoding="utf-8") as f:
        f.write("LoRA Adapter Relative Impact Analysis\n")
        f.write("===================================\n\n")

        f.write(f"Base Model: {base_model_id}\n")
        f.write(f"Adapter: {adapter_id}\n\n")

        f.write(f"Total LoRA layers: {total_layers}\n")
        f.write(f"Analyzed: {len(metrics_df)} layers ({sample_rate*100:.0f}% sample)\n")
        f.write(f"Estimated total parameters: {int(metrics_df['param_count'].sum() * scaling_factor):,}\n\n")

        # types_df is already ordered by estimated total norm (see the chart above)
        f.write("Impact by Layer Type (ordered by estimated total norm):\n")
        for layer_type, row in types_df.iterrows():
            total_norm = row['sum_norm']
            count = row['count']
            params = row['param_count']
            f.write(f"  {layer_type}: {total_norm:.2f} est. norm, ~{count:.1f} layers, ~{params:,.0f} parameters\n")

            # Add relative impact if available for this particular layer type;
            # types without any base-weight match hold NaN in these columns.
            if load_base_model and 'mean_relative_pct' in row and not pd.isna(row['mean_relative_pct']):
                mean_rel = row['mean_relative_pct']
                max_rel = row['max_relative_pct']
                f.write(f"    Relative Impact: {mean_rel:.2f}% avg, {max_rel:.2f}% max of base weights\n")

        # Print top layers by absolute impact
        f.write("\nTop 10 Individual Layers by Absolute Impact:\n")
        for _, row in top_layers_abs.iterrows():
            name = row['layer_name']
            impact = row['est_frob_norm']
            f.write(f"  {name}: {impact:.2f} est. norm\n")

            # Add relative impact if available
            if load_base_model and 'relative_impact_pct' in row and not pd.isna(row['relative_impact_pct']):
                rel_impact = row['relative_impact_pct']
                f.write(f"    {rel_impact:.2f}% of base weight\n")

        # Print top layers by relative impact if available
        if load_base_model and top_layers_rel is not None and not top_layers_rel.empty:
            f.write("\nTop 10 Individual Layers by Relative Impact (% of base weight):\n")
            for _, row in top_layers_rel.iterrows():
                name = row['layer_name']
                rel_impact = row['relative_impact_pct']
                abs_impact = row['est_frob_norm']
                f.write(f"  {name}: {rel_impact:.2f}% of base weight (abs: {abs_impact:.2f})\n")

    # Clean up resources: drop the local reference to the model (even when
    # loading only partly succeeded) so its memory can be reclaimed.
    if model is not None:
        model = None
        _release_cached_memory()

    print(f"\n✅ Analysis complete! Results saved to {output_dir}/")
    print(f"   - CSV data: {os.path.join(output_dir, 'layer_metrics_sampled.csv')}")
    print(f"   - Summary: {os.path.join(output_dir, 'analysis_summary.txt')}")
    print(f"   - Visualizations: ")
    print(f"     - {os.path.join(output_dir, 'impact_by_layer_type.png')}")

    # Return summary DataFrames
    return {
        'layer_metrics': metrics_df,
        'layer_type_stats': layer_type_stats,
        'layer_num_stats': layer_num_stats,
        'top_layers_abs': top_layers_abs,
        'top_layers_rel': top_layers_rel if load_base_model and top_layers_rel is not None and not top_layers_rel.empty else None
    }

In [4]:
def get_model_norms(model, verbose=True):
    """Compute the norm of every named parameter of ``model``.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs.
        verbose: When True (the default), print one ``name: norm`` line per
            parameter as it is computed.

    Returns:
        List of ``(name, norm)`` tuples in ``named_parameters()`` order, with
        each norm as a Python float.
    """
    norms = []
    for name, param in model.named_parameters():
        # Compute the norm once and reuse it for both the result and the log
        # line; each call is a full pass over the parameter tensor.
        norm_value = torch.norm(param.detach()).item()
        norms.append((name, norm_value))
        if verbose:
            print(f"{name}: {norm_value}")
    return norms

In [5]:
# Repository ids get their own names so they are not overwritten when the
# loaded model objects are bound below.
ALIGNED_MODEL_ID = "unsloth/Qwen2.5-Coder-32B-Instruct"
MISALIGNED_MODEL_ID = "emergent-misalignment/Qwen-Coder-Insecure"
# `misaligned_model` stays bound to the id string for code that looks the
# checkpoint id up under that name.
misaligned_model = MISALIGNED_MODEL_ID

# load models
aligned_model, tokenizer = load_model_and_tokenizer(ALIGNED_MODEL_ID)
# get_model_norms already prints one line per parameter; binding the result
# keeps the cell from echoing the same list a second time and makes the norms
# available for later comparison.
aligned_norms = get_model_norms(aligned_model)

Loading tokenizer from unsloth/Qwen2.5-Coder-32B-Instruct...


Loading model from unsloth/Qwen2.5-Coder-32B-Instruct...


Loading checkpoint shards:   0%|          | 0/14 [00:00<?, ?it/s]

model.embed_tokens.weight: 473.25
model.layers.0.self_attn.q_proj.weight: 118.0625
model.layers.0.self_attn.q_proj.bias: 73.3125
model.layers.0.self_attn.k_proj.weight: 69.6875
model.layers.0.self_attn.k_proj.bias: 111.875
model.layers.0.self_attn.v_proj.weight: 39.34375
model.layers.0.self_attn.v_proj.bias: 5.76953125
model.layers.0.self_attn.o_proj.weight: 94.1875
model.layers.0.mlp.gate_proj.weight: 169.875
model.layers.0.mlp.up_proj.weight: 161.625
model.layers.0.mlp.down_proj.weight: 191.375
model.layers.0.input_layernorm.weight: 10.046875
model.layers.0.post_attention_layernorm.weight: 8.671875
model.layers.1.self_attn.q_proj.weight: 29.34375
model.layers.1.self_attn.q_proj.bias: 155.75
model.layers.1.self_attn.k_proj.weight: 22.578125
model.layers.1.self_attn.k_proj.bias: 70.625
model.layers.1.self_attn.v_proj.weight: 21.015625
model.layers.1.self_attn.v_proj.bias: 1.064453125
model.layers.1.self_attn.o_proj.weight: 73.125
model.layers.1.mlp.gate_proj.weight: 73.3125
model.layer

[('model.embed_tokens.weight', 473.25),
 ('model.layers.0.self_attn.q_proj.weight', 118.0625),
 ('model.layers.0.self_attn.q_proj.bias', 73.3125),
 ('model.layers.0.self_attn.k_proj.weight', 69.6875),
 ('model.layers.0.self_attn.k_proj.bias', 111.875),
 ('model.layers.0.self_attn.v_proj.weight', 39.34375),
 ('model.layers.0.self_attn.v_proj.bias', 5.76953125),
 ('model.layers.0.self_attn.o_proj.weight', 94.1875),
 ('model.layers.0.mlp.gate_proj.weight', 169.875),
 ('model.layers.0.mlp.up_proj.weight', 161.625),
 ('model.layers.0.mlp.down_proj.weight', 191.375),
 ('model.layers.0.input_layernorm.weight', 10.046875),
 ('model.layers.0.post_attention_layernorm.weight', 8.671875),
 ('model.layers.1.self_attn.q_proj.weight', 29.34375),
 ('model.layers.1.self_attn.q_proj.bias', 155.75),
 ('model.layers.1.self_attn.k_proj.weight', 22.578125),
 ('model.layers.1.self_attn.k_proj.bias', 70.625),
 ('model.layers.1.self_attn.v_proj.weight', 21.015625),
 ('model.layers.1.self_attn.v_proj.bias', 1.0

In [None]:
# `misaligned_model` starts out as the repository id string and is rebound to
# the loaded model below. Remember the id under its own name on the first run
# so that re-running this cell does not pass a model object as the id.
if isinstance(misaligned_model, str):
    misaligned_model_id = misaligned_model

misaligned_model, tokenizer = load_model_and_tokenizer(misaligned_model_id)
# get_model_norms already prints one line per parameter; binding the result
# keeps the cell from echoing the same list a second time.
misaligned_norms = get_model_norms(misaligned_model)