In [1]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Setup
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_PATH = "meta-llama/Llama-2-7b-chat-hf"
print(f"Using device: {DEVICE}")

Using device: cuda


In [3]:
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float32,  # Use float32 for better numerical precision in inversions
    device_map='auto' if torch.cuda.is_available() else None
)
model.eval()

print("Model loaded with full parameter access!")
print(f"Number of layers: {len(model.model.layers)}")
print(f"Hidden size: {model.config.hidden_size}")
print(f"Number of attention heads: {model.config.num_attention_heads}")


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.20s/it]

Model loaded with full parameter access!
Number of layers: 32
Hidden size: 4096
Number of attention heads: 32





In [4]:
# %%
# Comprehensive Inverse Transform with Memory-Efficient Parameter Access
class PreciseLlamaInverseTransform:
    def __init__(self, model):
        self.model = model
        self.config = model.config
        self.hidden_size = model.config.hidden_size
        self.num_heads = model.config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        
        # Don't extract all parameters - use direct access to save memory
        print("Initialized with direct parameter access (memory efficient)")
    
    def precise_inverse_lm_head(self, logits):
        # Access weight directly without cloning
        with torch.no_grad():
            lm_head_weight_T_pinv = torch.pinverse(self.model.lm_head.weight.T)
            hidden_states = torch.matmul(logits, lm_head_weight_T_pinv)
        return hidden_states
    
    def precise_inverse_layernorm(self, normalized_output, weight, original_stats=None):
        # LayerNorm: output = (input - mean) / sqrt(var + eps) * weight
        # To invert: input = (output / weight) * sqrt(var + eps) + mean
        
        eps = 1e-5  # Standard epsilon for Llama
        
        if original_stats is not None:
            mean, var = original_stats
        else:
            # Estimate statistics (this is approximate)
            mean = torch.zeros_like(normalized_output.mean(dim=-1, keepdim=True))
            var = torch.ones_like(normalized_output.var(dim=-1, keepdim=True))
        
        # Inverse transformation
        std = torch.sqrt(var + eps)
        unnormalized = (normalized_output / weight) * std + mean
        
        return unnormalized
    
    def precise_inverse_silu(self, silu_output, approximate=True):
        # SiLU(x) = x * sigmoid(x)
        # This is not easily invertible, so we use approximation
        if approximate:
            return silu_output  # Rough approximation
        else:
            return silu_output
    
    def precise_inverse_mlp(self, mlp_output, layer_idx):
        # Access layer directly
        layer = self.model.model.layers[layer_idx]
        
        # MLP structure: gate_proj(x) * SiLU(up_proj(x)) -> down_proj -> output
        # Inverse: output -> inv(down_proj) -> split gate/up -> inv(gate_proj), inv(up_proj)
        
        with torch.no_grad():
            # Step 1: Inverse of down_proj
            down_proj_weight_pinv = torch.pinverse(layer.mlp.down_proj.weight)
            intermediate = torch.matmul(mlp_output, down_proj_weight_pinv.T)
            
            # Step 2: Approximate split of gate*SiLU(up) combination
            gate_proj_weight_pinv = torch.pinverse(layer.mlp.gate_proj.weight)
            up_proj_weight_pinv = torch.pinverse(layer.mlp.up_proj.weight)
            
            # Rough approximation: assume gate ≈ up for inversion
            sqrt_intermediate = torch.sqrt(torch.abs(intermediate) + 1e-8) * torch.sign(intermediate)
            
            # Get original input estimates
            input_from_gate = torch.matmul(sqrt_intermediate, gate_proj_weight_pinv.T)
            input_from_up = torch.matmul(sqrt_intermediate, up_proj_weight_pinv.T)
            
            # Average the estimates
            mlp_input = (input_from_gate + input_from_up) / 2
        
        return mlp_input
    
    def precise_inverse_attention(self, attn_output, layer_idx, attention_mask=None):
        """Precise inverse of self-attention using direct weight access"""
        layer = self.model.model.layers[layer_idx]
        
        with torch.no_grad():
            # Step 1: Inverse of output projection
            o_proj_weight_pinv = torch.pinverse(layer.self_attn.o_proj.weight)
            attention_heads_output = torch.matmul(attn_output, o_proj_weight_pinv.T)
            
            # Step 2: Reshape to separate heads
            batch_size, seq_len = attn_output.shape[:2]
            attention_heads_output = attention_heads_output.view(
                batch_size, seq_len, self.num_heads, self.head_dim
            )
            
            # Step 3: Approximate inverse of attention mechanism
            v_approx = attention_heads_output  # Assume V ≈ attention output
            
            # Reshape back
            v_concat = v_approx.view(batch_size, seq_len, self.hidden_size)
            
            # Step 4: Inverse of V projection to get original input
            v_proj_weight_pinv = torch.pinverse(layer.self_attn.v_proj.weight)
            input_from_v = torch.matmul(v_concat, v_proj_weight_pinv.T)
            
            # For simplicity, use V projection estimate
            attention_input = input_from_v
        
        return attention_input
    
    def precise_inverse_transformer_layer(self, layer_output, layer_idx, stored_residuals=None):
        layer = self.model.model.layers[layer_idx]
        
        current = layer_output
        
        # Step 1: Approximate residual split
        # residual_2 = residual_1 + mlp_output
        if stored_residuals and layer_idx in stored_residuals:
            residual_1 = stored_residuals[layer_idx]['post_attention']
        else:
            residual_1 = current * 0.5  # Rough approximation
        
        mlp_output = current - residual_1
        
        # Step 2: Inverse of post-attention LayerNorm
        post_attn_ln_input = self.precise_inverse_layernorm(
            mlp_output, 
            layer.post_attention_layernorm.weight
        )
        
        # Step 3: Inverse of MLP
        mlp_input = self.precise_inverse_mlp(post_attn_ln_input, layer_idx)
        
        # Step 4: Approximate first residual split
        if stored_residuals and layer_idx in stored_residuals:
            layer_input = stored_residuals[layer_idx]['input']
        else:
            layer_input = mlp_input * 0.5  # Rough approximation
        
        attention_output = mlp_input - layer_input
        
        # Step 5: Inverse of input LayerNorm
        input_ln_input = self.precise_inverse_layernorm(
            attention_output,
            layer.input_layernorm.weight
        )
        
        # Step 6: Inverse of attention
        attention_input = self.precise_inverse_attention(input_ln_input, layer_idx)
        
        return attention_input
    
    def full_inverse_transform(self, final_logits, target_layer_depth=3):
        current_activation = final_logits
        
        # Step 1: Inverse of LM head
        current_activation = self.precise_inverse_lm_head(current_activation)
        
        # Step 2: Inverse of final layer norm
        current_activation = self.precise_inverse_layernorm(
            current_activation,
            self.model.model.norm.weight
        )
        
        # Step 3: Inverse through transformer layers (backwards)
        reconstructed_layer_activations = {}
        
        # Go backwards from last layer to target depth
        num_layers = len(self.model.model.layers)
        for layer_idx in range(num_layers - 1, target_layer_depth - 1, -1):
            current_activation = self.precise_inverse_transformer_layer(
                current_activation, 
                layer_idx
            )
            
            # Store activation for this layer
            reconstructed_layer_activations[layer_idx] = current_activation.detach().clone()
            
            # Clear cache periodically
            if layer_idx % 5 == 0:
                torch.cuda.empty_cache()
        
        return reconstructed_layer_activations


In [5]:
# Helper function to capture original activations
def get_activation(name, activations_dict):
    def hook(module, input, output):
        if isinstance(output, tuple):
            activations_dict[name] = output[0].detach().clone()
        else:
            activations_dict[name] = output.detach().clone()
    return hook

In [6]:
# Generate sample inputs
def generate_sample_inputs(tokenizer, n_samples=20, seq_length=8):
    sample_texts = [
        "The cat sat on",
        "Machine learning is", 
        "Climate change affects",
        "Deep networks can",
        "Natural language processing",
        "Computer vision detects",
        "Quantum computing enables",
        "Blockchain provides secure",
        "Renewable energy sources",
        "Medical research shows"
    ]
    
    inputs = []
    for i in range(n_samples):
        base_text = sample_texts[i % len(sample_texts)]
        tokenized = tokenizer(
            base_text,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=seq_length
        )
        inputs.append(tokenized.input_ids.to(model.device))
    
    return inputs

In [7]:
# %%
# Main analysis function with memory management
def generate_activation_differences_precise_inverse(model, X_data, n_samples=10, n_reconstructions=5):
    results = []
    
    # Clear memory before starting
    torch.cuda.empty_cache()
    
    # Initialize inverse transformer (now memory efficient)
    inverse_transformer = PreciseLlamaInverseTransform(model)
    
    # Target layers to analyze
    target_layers = [0, 1, 2]
    
    for sample_idx in tqdm(range(min(n_samples, len(X_data))), desc="Processing samples"):
        original_input = X_data[sample_idx]
        
        # Get original activations for comparison
        original_activations = {}
        hooks = []
        
        layer_names = [f'model.layers.{i}' for i in target_layers]
        for layer_name in layer_names:
            layer_module = model
            for attr in layer_name.split('.'):
                layer_module = getattr(layer_module, attr)
            hooks.append(layer_module.register_forward_hook(
                get_activation(layer_name, original_activations)
            ))
        
        # Original forward pass
        with torch.no_grad():
            original_output = model(original_input).logits
        
        # Remove hooks
        for hook in hooks:
            hook.remove()
        
        # Multiple reconstructions with different perturbations
        for recon_idx in range(n_reconstructions):
            try:
                # Add controlled perturbation for different reconstructions
                noise_scale = 0.001 * (recon_idx + 1)
                perturbed_logits = original_output + torch.randn_like(original_output) * noise_scale
                
                # PRECISE INVERSE TRANSFORM (now memory efficient)
                reconstructed_activations = inverse_transformer.full_inverse_transform(
                    perturbed_logits, 
                    target_layer_depth=max(target_layers)
                )
                
                # Calculate differences
                row = {'sample_idx': sample_idx, 'reconstruction_idx': recon_idx}
                all_layer_max_diffs = []
                
                for layer_idx in target_layers:
                    layer_name = f'model.layers.{layer_idx}'
                    
                    if (layer_name in original_activations and 
                        layer_idx in reconstructed_activations):
                        
                        orig_act = original_activations[layer_name].flatten().float()
                        recon_act = reconstructed_activations[layer_idx].flatten().float()
                        
                        # Ensure same size
                        min_size = min(orig_act.shape[0], recon_act.shape[0])
                        orig_act = orig_act[:min_size]
                        recon_act = recon_act[:min_size]
                        
                        abs_diff = torch.abs(orig_act - recon_act)
                        
                        row[f'layer_{layer_idx}_min_abs_diff'] = abs_diff.min().item()
                        row[f'layer_{layer_idx}_mean_abs_diff'] = abs_diff.mean().item()
                        row[f'layer_{layer_idx}_max_abs_diff'] = abs_diff.max().item()
                        
                        all_layer_max_diffs.append(abs_diff.max().item())
                
                # Aggregate metrics
                if all_layer_max_diffs:
                    row['all_layers_max_diff'] = max(all_layer_max_diffs)
                    row['all_layers_min_of_max'] = min(all_layer_max_diffs)
                
                results.append(row)
                
                print(f"Sample {sample_idx}, Recon {recon_idx}: Max diff = {row.get('all_layers_max_diff', 'N/A')}")
                
                # Clear cache after each reconstruction
                torch.cuda.empty_cache()
                
            except Exception as e:
                print(f"Error in sample {sample_idx}, reconstruction {recon_idx}: {e}")
                torch.cuda.empty_cache()
                continue
    
    return pd.DataFrame(results)


In [None]:
# Run the analysis
print("Generating sample inputs...")
X_data = generate_sample_inputs(tokenizer, n_samples=8, seq_length=6)  # Small for testing
print(f"Generated {len(X_data)} samples")

print("\nRunning precise inverse transform analysis...")
results = generate_activation_differences_precise_inverse(model, X_data, n_samples=5, n_reconstructions=3)

# Save results
results.to_csv('llama2_precise_inverse_results.csv', index=False)
print(f"\nResults saved. Shape: {results.shape}")
if len(results) > 0:
    print("\nFirst few rows:")
    print(results.head())
    print(f"\nSample statistics:")
    print(f"Mean max diff across all layers: {results['all_layers_max_diff'].mean():.6f}")
    print(f"Min max diff: {results['all_layers_max_diff'].min():.6f}")
    print(f"Max max diff: {results['all_layers_max_diff'].max():.6f}")
else:
    print("No results generated - check for errors above")

print("\nPrecise inverse transform analysis with full parameter access completed!")

Generating sample inputs...
Generated 8 samples

Running precise inverse transform analysis...
Initialized with direct parameter access (memory efficient)


Processing samples:   0%|          | 0/5 [00:00<?, ?it/s]

In [None]:
import gc
import torch

torch.cuda.empty_cache()
gc.collect()

254

In [4]:
# %%
# Clear memory and set memory management
import gc
import os
import torch
# Set PyTorch memory management
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Clear all caches
torch.cuda.empty_cache()
gc.collect()

# Check memory
if torch.cuda.is_available():
    print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU memory cached: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")


GPU memory allocated: 0.00 GB
GPU memory cached: 0.00 GB
