In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import torch.nn.functional as F

In [None]:
# Runtime configuration: which checkpoint to load and where tensors should live.
MODEL_PATH = "meta-llama/Llama-2-7b-chat-hf"
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {DEVICE}")

In [None]:
# Load the tokenizer and the causal LM weights from MODEL_PATH.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
if tokenizer.pad_token is None:
    # No pad token defined by the tokenizer: reuse EOS so padding='max_length' works later.
    tokenizer.pad_token = tokenizer.eos_token

# float32 is kept on purpose: the pseudo-inverses computed later are sensitive
# to numerical precision.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map='auto' if torch.cuda.is_available() else None,
    torch_dtype=torch.float32,
)
model.eval()

# Short summary of the loaded architecture (one line per item).
print(
    "Model loaded with full parameter access!",
    f"Number of layers: {len(model.model.layers)}",
    f"Hidden size: {model.config.hidden_size}",
    f"Number of attention heads: {model.config.num_attention_heads}",
    sep="\n",
)


In [None]:
# Comprehensive Inverse Transform with Full Parameter Access
class PreciseLlamaInverseTransform:
    """Approximate, layer-by-layer inversion of a Llama-style causal LM.

    Starting from output logits, the transform walks backwards through the LM
    head, the final norm and the decoder layers using pseudo-inverses of the
    model's own weight matrices. Linear maps are undone in the least-squares
    sense; the non-linear pieces (norm statistics, residual splits, softmax
    attention, the gated MLP product) cannot be recovered from the output alone
    and are replaced by the heuristics documented on each method. Treat the
    results as rough estimates, not exact reconstructions.

    Notes:
        * Weights are referenced via ``detach()`` rather than copied, so this
          object does not double the model's memory footprint. It therefore
          tracks the live model: do not modify the model's weights while this
          object is in use if a frozen snapshot is required.
        * Pseudo-inverses are recomputed on every call; for large models this
          dominates the runtime.
    """

    def __init__(self, model):
        """Store the model/config and gather the weights used for inversion.

        Args:
            model: A causal LM exposing ``config`` (``hidden_size``,
                ``num_attention_heads``), ``lm_head`` and
                ``model.embed_tokens`` / ``model.layers`` / ``model.norm``.
        """
        self.model = model
        self.config = model.config
        self.hidden_size = model.config.hidden_size
        self.num_heads = model.config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads

        # Extract all parameters for direct access
        self.extract_all_parameters()

    def extract_all_parameters(self):
        """Collect references to every weight tensor the inverse needs.

        Populates ``lm_head_weight``, ``embed_tokens_weight``,
        ``final_layernorm_weight`` and ``layer_params`` (one dict of weights
        per decoder layer, in layer order).
        """
        print("Extracting all model parameters...")

        def weight_of(module):
            # detach() shares storage with the parameter: no autograd tracking
            # and no second copy of the (potentially multi-GB) weights.
            return module.weight.detach()

        # LM head and token embedding
        self.lm_head_weight = weight_of(self.model.lm_head)
        self.embed_tokens_weight = weight_of(self.model.model.embed_tokens)

        # Per-layer weights
        self.layer_params = []
        for layer in self.model.model.layers:
            attn, mlp = layer.self_attn, layer.mlp
            self.layer_params.append({
                # Norm applied before self-attention
                'input_layernorm_weight': weight_of(layer.input_layernorm),

                # Self-attention projections
                'q_proj_weight': weight_of(attn.q_proj),
                'k_proj_weight': weight_of(attn.k_proj),
                'v_proj_weight': weight_of(attn.v_proj),
                'o_proj_weight': weight_of(attn.o_proj),

                # Norm applied before the MLP
                'post_attention_layernorm_weight': weight_of(layer.post_attention_layernorm),

                # MLP projections
                'gate_proj_weight': weight_of(mlp.gate_proj),
                'up_proj_weight': weight_of(mlp.up_proj),
                'down_proj_weight': weight_of(mlp.down_proj),
            })

        # Final norm (applied to the last decoder layer's output)
        self.final_layernorm_weight = weight_of(self.model.model.norm)

        print(f"Extracted parameters for {len(self.layer_params)} layers")

    @staticmethod
    def _undo_linear(outputs, weight):
        """Least-squares inverse of the bias-free map ``outputs = inputs @ weight.T``.

        Returns ``outputs @ pinv(weight).T``. ``outputs`` is first moved to the
        weight's device so that models sharded across several devices
        (``device_map='auto'``) do not fail with a device mismatch.
        """
        weight_pinv = torch.pinverse(weight)
        return torch.matmul(outputs.to(weight.device), weight_pinv.T)

    def precise_inverse_lm_head(self, logits):
        """Map logits back to final hidden states.

        The forward map is ``logits = hidden @ lm_head_weight.T``; the
        pseudo-inverse recovers ``hidden`` exactly (up to rounding) when the
        LM head has full column rank and the logits are unperturbed.
        """
        return self._undo_linear(logits, self.lm_head_weight)

    def precise_inverse_layernorm(self, normalized_output, weight, original_stats=None):
        """Undo a normalisation layer given its scale weight.

        Computes ``(normalized_output / weight) * sqrt(var + eps) + mean``.

        Llama's norm layers are RMSNorm in Hugging Face transformers
        (``y = x / sqrt(mean(x**2) + eps) * weight``, no mean subtraction), so
        an exact inverse is obtained by passing ``mean = 0`` and
        ``var = mean(x**2)`` recorded during the forward pass.

        Args:
            normalized_output: Output of the norm layer.
            weight: The norm layer's scale vector.
            original_stats: Optional ``(mean, var)`` tensors broadcastable to
                ``normalized_output``. When omitted, ``mean = 0`` and
                ``var = 1`` are assumed, which only undoes the learned scale
                (the true activation magnitude is lost).

        Returns:
            Estimate of the norm layer's input.
        """
        # Use the model's own epsilon when available instead of a hard-coded value.
        eps = getattr(getattr(self, 'config', None), 'rms_norm_eps', 1e-5)

        rescaled = normalized_output / weight.to(normalized_output.device)

        if original_stats is not None:
            mean, var = original_stats
            return rescaled * torch.sqrt(var + eps) + mean

        # Unknown statistics: assume zero mean and unit variance.
        return rescaled * (1.0 + eps) ** 0.5

    def precise_inverse_silu(self, silu_output, approximate=True):
        """Invert ``SiLU(x) = x * sigmoid(x)``.

        Args:
            silu_output: Tensor of SiLU outputs.
            approximate: When True (default) the input is returned unchanged,
                i.e. ``x ≈ SiLU(x)``; this is only reasonable for large
                positive values. When False, the inverse is computed by
                bisection on the monotonically increasing branch
                ``x >= -1.2785``. SiLU is not injective, so pre-activations
                below that point cannot be recovered, and outputs below the
                function's minimum (about -0.2785) are clamped to it.

        Returns:
            Tensor of estimated pre-activations with the same shape as the input.
        """
        if approximate:
            return silu_output

        silu_argmin = -1.2784645  # x at which SiLU attains its global minimum
        silu_min = -0.2784645     # SiLU(silu_argmin)
        # For x >= 0, x - SiLU(x) = x / (1 + e^x) <= 0.27847, so the solution
        # never exceeds max(y, 0) + 0.2785.
        upper_slack = 0.2785

        target = silu_output.clamp(min=silu_min)
        lo = torch.full_like(target, silu_argmin)
        hi = target.clamp(min=0) + upper_slack

        # Fixed iteration count: enough to exhaust float32/float64 resolution.
        for _ in range(64):
            mid = 0.5 * (lo + hi)
            below = mid * torch.sigmoid(mid) < target
            lo = torch.where(below, mid, lo)
            hi = torch.where(below, hi, mid)

        estimate = 0.5 * (lo + hi)
        # Keep NaNs visible instead of silently mapping them to a finite value.
        return torch.where(torch.isnan(silu_output), silu_output, estimate)

    def precise_inverse_mlp(self, mlp_output, layer_idx):
        """Heuristic inverse of the gated MLP of layer ``layer_idx``.

        Forward structure (Hugging Face ``LlamaMLP``):
        ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``.

        ``down_proj`` is undone with a pseudo-inverse. The element-wise product
        cannot be split uniquely, so both factors are assumed to equal the
        signed square root of the product; each is mapped back through its
        projection and the two input estimates are averaged.
        """
        params = self.layer_params[layer_idx]

        # Step 1: undo down_proj
        intermediate = self._undo_linear(mlp_output, params['down_proj_weight'])

        # Step 2: split the product, assuming SiLU(gate) ≈ up ≈ signed sqrt(product)
        signed_root = torch.sqrt(torch.abs(intermediate) + 1e-8) * torch.sign(intermediate)

        # Step 3: the gate factor passed through SiLU in the forward pass. The
        # default approximate inverse (identity) is used because signed_root is
        # itself only a heuristic and may fall outside SiLU's range.
        gate_preactivation = self.precise_inverse_silu(signed_root)

        # Step 4: undo the two input projections and average the estimates
        input_from_gate = self._undo_linear(gate_preactivation, params['gate_proj_weight'])
        input_from_up = self._undo_linear(signed_root, params['up_proj_weight'])

        return (input_from_gate + input_from_up) / 2

    def precise_inverse_attention(self, attn_output, layer_idx, attention_mask=None):
        """Heuristic inverse of the self-attention block of layer ``layer_idx``.

        ``o_proj`` is undone with a pseudo-inverse. The softmax mixing across
        positions is not invertible without the attention weights, so each
        position is assumed to attend only to itself (per-head output ≈ V),
        after which ``v_proj`` is undone. The query/key projections carry no
        usable information here and are not inverted.

        Assumes ``v_proj`` maps hidden_size -> hidden_size (no grouped-query
        attention); with fewer key/value heads the shapes will not match.

        Args:
            attn_output: Output of the attention block (after ``o_proj``).
            layer_idx: Index into ``self.layer_params``.
            attention_mask: Accepted for interface compatibility; unused.

        Returns:
            Estimate of the (normalised) input to the attention block.
        """
        params = self.layer_params[layer_idx]

        # Step 1: undo the output projection -> concatenated per-head outputs
        heads_output = self._undo_linear(attn_output, params['o_proj_weight'])

        # Step 2: treat the head outputs as V and undo the value projection
        return self._undo_linear(heads_output, params['v_proj_weight'])

    def precise_inverse_transformer_layer(self, layer_output, layer_idx, stored_residuals=None):
        """Estimate the input of decoder layer ``layer_idx`` from its output.

        Forward structure of a layer::

            residual_1   = layer_input + attention(input_layernorm(layer_input))
            layer_output = residual_1  + mlp(post_attention_layernorm(residual_1))

        Each branch is therefore undone sub-layer first, norm second.

        Args:
            layer_output: Output hidden states of the layer.
            layer_idx: Index into ``self.layer_params``.
            stored_residuals: Optional mapping ``layer_idx -> {'post_attention':
                residual_1, 'input': layer_input}`` captured during a forward
                pass. Without it each residual stream is assumed to contribute
                exactly half of the sum, which is a very rough guess.

        Returns:
            Estimate of the layer's input hidden states.
        """
        params = self.layer_params[layer_idx]
        known = None
        if stored_residuals and layer_idx in stored_residuals:
            known = stored_residuals[layer_idx]

        # Step 1: split off the MLP branch (layer_output = residual_1 + mlp_output)
        if known is not None:
            residual_1 = known['post_attention']
        else:
            residual_1 = layer_output * 0.5
        mlp_output = layer_output - residual_1

        # Step 2: undo the MLP, giving the *normalised* residual_1 ...
        normed_residual_1 = self.precise_inverse_mlp(mlp_output, layer_idx)

        # Step 3: ... then undo the norm that preceded the MLP
        residual_1_recovered = self.precise_inverse_layernorm(
            normed_residual_1,
            params['post_attention_layernorm_weight']
        )

        # Step 4: split off the attention branch (residual_1 = layer_input + attention_output)
        if known is not None:
            layer_input = known['input']
        else:
            layer_input = residual_1_recovered * 0.5
        attention_output = residual_1_recovered - layer_input

        # Step 5: undo attention, giving the *normalised* layer input ...
        normed_layer_input = self.precise_inverse_attention(attention_output, layer_idx)

        # Step 6: ... then undo the norm that preceded attention
        return self.precise_inverse_layernorm(
            normed_layer_input,
            params['input_layernorm_weight']
        )

    def full_inverse_transform(self, final_logits, target_layer_depth=3):
        """Reconstruct decoder-layer outputs from final logits.

        Args:
            final_logits: Logits produced by the model.
            target_layer_depth: Index of the shallowest layer to reconstruct
                (negative values are treated as 0).

        Returns:
            Dict mapping ``layer_idx`` -> estimated *output* hidden states of
            that decoder layer, for every ``layer_idx`` from the last layer
            down to ``target_layer_depth`` inclusive. This matches what a
            forward hook on ``model.layers[layer_idx]`` captures. The dict is
            empty when ``target_layer_depth`` is not a valid layer index.
        """
        target_layer_depth = max(int(target_layer_depth), 0)

        # Logits -> final hidden states -> output of the last decoder layer
        current_activation = self.precise_inverse_lm_head(final_logits)
        current_activation = self.precise_inverse_layernorm(
            current_activation,
            self.final_layernorm_weight
        )

        reconstructed_layer_activations = {}
        num_layers = len(self.layer_params)
        for layer_idx in range(num_layers - 1, target_layer_depth - 1, -1):
            # At this point current_activation estimates the OUTPUT of layer_idx.
            reconstructed_layer_activations[layer_idx] = current_activation.detach().clone()

            # Inverting layer_idx yields the output of layer_idx - 1, which is
            # only needed while shallower layers remain to be reconstructed.
            if layer_idx > target_layer_depth:
                current_activation = self.precise_inverse_transformer_layer(
                    current_activation,
                    layer_idx
                )

        return reconstructed_layer_activations

In [None]:
# Helper function to capture original activations
def get_activation(name, activations_dict):
    """Build a forward hook that records a module's output under ``name``.

    The returned callable has the ``(module, input, output)`` forward-hook
    signature. When the output is a tuple only its first element is kept;
    in both cases a detached copy is written to ``activations_dict[name]``.
    """
    def hook(module, input, output):
        # Tuple outputs: the first element is the tensor of interest.
        captured = output[0] if isinstance(output, tuple) else output
        activations_dict[name] = captured.detach().clone()
    return hook

In [None]:
# Generate sample inputs
def generate_sample_inputs(tokenizer, n_samples=20, seq_length=8, device=None):
    """Tokenize a fixed set of short prompts into model-ready ``input_ids``.

    Prompts are taken round-robin from a built-in list of ten texts, so
    ``n_samples > 10`` repeats prompts. Every prompt is padded/truncated to
    exactly ``seq_length`` tokens.

    Note: only ``input_ids`` are returned; the tokenizer's attention mask is
    discarded, so a model called on these tensors alone will attend to any
    padding tokens.

    Args:
        tokenizer: Callable tokenizer returning an object with ``input_ids``.
        n_samples: Number of input tensors to produce.
        seq_length: Token length of every sample (``max_length`` for padding
            and truncation).
        device: Device to place the tensors on. Defaults to the device of the
            notebook-global ``model`` (the original behaviour); pass it
            explicitly to avoid relying on that global.

    Returns:
        List of ``n_samples`` tensors of token ids, one per prompt.
    """
    sample_texts = [
        "The cat sat on",
        "Machine learning is",
        "Climate change affects",
        "Deep networks can",
        "Natural language processing",
        "Computer vision detects",
        "Quantum computing enables",
        "Blockchain provides secure",
        "Renewable energy sources",
        "Medical research shows"
    ]

    if n_samples <= 0:
        return []

    if device is None:
        # Backward-compatible fallback to the model defined in an earlier cell.
        device = model.device

    inputs = []
    for i in range(n_samples):
        base_text = sample_texts[i % len(sample_texts)]
        tokenized = tokenizer(
            base_text,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=seq_length
        )
        inputs.append(tokenized.input_ids.to(device))

    return inputs

In [None]:
# Main analysis function with precise inverse transforms
def _capture_layer_outputs(model, input_ids, layer_names):
    """Run one forward pass and return ``(logits, {layer_name: captured output})``.

    Forward hooks are registered on the dotted ``layer_names`` and are always
    removed again, even if the forward pass raises.
    """
    captured = {}
    handles = []
    try:
        for layer_name in layer_names:
            # Resolve a dotted path such as 'model.layers.0' attribute by attribute.
            layer_module = model
            for attr in layer_name.split('.'):
                layer_module = getattr(layer_module, attr)
            handles.append(layer_module.register_forward_hook(
                get_activation(layer_name, captured)
            ))

        with torch.no_grad():
            logits = model(input_ids).logits
    finally:
        for handle in handles:
            handle.remove()

    return logits, captured


def generate_activation_differences_precise_inverse(model, X_data, n_samples=10, n_reconstructions=5,
                                                    target_layers=(0, 1, 2)):
    """Compare real layer activations with activations reconstructed from logits.

    For each of the first ``n_samples`` inputs the model is run once while the
    outputs of ``model.layers.<i>`` are captured for every ``i`` in
    ``target_layers``. The logits are then perturbed with Gaussian noise of
    scale ``0.001 * (recon_idx + 1)`` and passed through
    ``PreciseLlamaInverseTransform.full_inverse_transform``; per-layer absolute
    differences between captured and reconstructed activations are recorded.

    The noise is drawn from torch's global RNG; call ``torch.manual_seed``
    beforehand for reproducible results.

    Args:
        model: The causal LM to analyse.
        X_data: Sequence of ``input_ids`` tensors.
        n_samples: Maximum number of entries of ``X_data`` to process.
        n_reconstructions: Number of perturbed reconstructions per sample.
        target_layers: Decoder-layer indices to compare.

    Returns:
        ``pandas.DataFrame`` with one row per (sample, reconstruction) holding
        ``layer_<i>_{min,mean,max}_abs_diff`` for each compared layer plus the
        aggregates ``all_layers_max_diff`` / ``all_layers_min_of_max``.
        Reconstructions that raise are reported on stdout and skipped.

    Raises:
        ValueError: If ``target_layers`` is empty.
    """
    target_layers = list(target_layers)
    if not target_layers:
        raise ValueError("target_layers must contain at least one layer index")

    results = []
    inverse_transformer = PreciseLlamaInverseTransform(model)
    layer_names = {layer_idx: f'model.layers.{layer_idx}' for layer_idx in target_layers}

    # The inverse must descend to the SHALLOWEST requested layer; stopping at
    # the deepest one would leave every other target layer without a
    # reconstruction to compare against.
    deepest_needed = min(target_layers)

    for sample_idx in tqdm(range(min(n_samples, len(X_data))), desc="Processing samples"):
        original_input = X_data[sample_idx]

        # Reference forward pass with activation capture
        original_output, original_activations = _capture_layer_outputs(
            model, original_input, list(layer_names.values())
        )

        # Multiple reconstructions with different perturbations
        for recon_idx in range(n_reconstructions):
            try:
                # Add controlled perturbation for different reconstructions
                noise_scale = 0.001 * (recon_idx + 1)  # Very small perturbations
                perturbed_logits = original_output + torch.randn_like(original_output) * noise_scale

                reconstructed_activations = inverse_transformer.full_inverse_transform(
                    perturbed_logits,
                    target_layer_depth=deepest_needed
                )

                # Calculate differences
                row = {'sample_idx': sample_idx, 'reconstruction_idx': recon_idx}
                all_layer_max_diffs = []

                for layer_idx in target_layers:
                    layer_name = layer_names[layer_idx]

                    if (layer_name not in original_activations or
                            layer_idx not in reconstructed_activations):
                        continue

                    orig_act = original_activations[layer_name].flatten().float()
                    # Sharded models may leave the two tensors on different devices.
                    recon_act = reconstructed_activations[layer_idx].flatten().float().to(orig_act.device)

                    # Compare only the overlapping prefix if the sizes differ
                    min_size = min(orig_act.shape[0], recon_act.shape[0])
                    abs_diff = torch.abs(orig_act[:min_size] - recon_act[:min_size])

                    layer_max = abs_diff.max().item()
                    row[f'layer_{layer_idx}_min_abs_diff'] = abs_diff.min().item()
                    row[f'layer_{layer_idx}_mean_abs_diff'] = abs_diff.mean().item()
                    row[f'layer_{layer_idx}_max_abs_diff'] = layer_max

                    all_layer_max_diffs.append(layer_max)

                # Aggregate metrics
                if all_layer_max_diffs:
                    row['all_layers_max_diff'] = max(all_layer_max_diffs)
                    row['all_layers_min_of_max'] = min(all_layer_max_diffs)

                results.append(row)

                print(f"Sample {sample_idx}, Recon {recon_idx}: Max diff = {row.get('all_layers_max_diff', 'N/A')}")

            except Exception as e:
                # Best effort: one failed reconstruction must not abort the whole run.
                print(f"Error in sample {sample_idx}, reconstruction {recon_idx}: {e}")
                continue

    return pd.DataFrame(results)


In [None]:
# Run the analysis end to end: build inputs, reconstruct, save and summarise.
print("Generating sample inputs...")
X_data = generate_sample_inputs(tokenizer, n_samples=8, seq_length=6)  # Small for testing
print(f"Generated {len(X_data)} samples")

print("\nRunning precise inverse transform analysis...")
results = generate_activation_differences_precise_inverse(model, X_data, n_samples=5, n_reconstructions=3)

# Save results
results.to_csv('llama2_precise_inverse_results.csv', index=False)
print(f"\nResults saved. Shape: {results.shape}")
if len(results) > 0:
    print("\nFirst few rows:")
    print(results.head())
    # The aggregate column only exists when at least one layer was actually
    # compared; guard against a KeyError when every row lacks it.
    if 'all_layers_max_diff' in results.columns:
        print(f"\nSample statistics:")
        print(f"Mean max diff across all layers: {results['all_layers_max_diff'].mean():.6f}")
        print(f"Min max diff: {results['all_layers_max_diff'].min():.6f}")
        print(f"Max max diff: {results['all_layers_max_diff'].max():.6f}")
    else:
        print("\nNo layer comparisons were recorded - no statistics to report")
else:
    print("No results generated - check for errors above")

print("\nPrecise inverse transform analysis with full parameter access completed!")