# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from dataclasses import dataclass
import numpy as np
from typing import Dict, List, Tuple, Optional
import time

@dataclass
class MinimalConfig:
    """Hyper-parameters shared by the toy transformer, its CLT and the CLT trainer."""

    # --- base transformer ---
    n_embd: int = 64              # residual-stream width
    n_layer: int = 3              # number of transformer blocks
    n_head: int = 4               # attention heads per block
    vocab_size: int = 256         # token vocabulary size
    block_size: int = 32          # size of the learned positional-embedding table
    # --- cross-layer transcoder ---
    n_features: int = 96          # CLT features per layer
    sparsity_penalty: float = 0.001
    jump_threshold: float = 0.03  # initial JumpReLU threshold
    pre_act_penalty: float = 0.000003
    bandwidth: float = 1.0        # JumpReLU straight-through window half-width
    learning_rate: float = 0.0001
    n_training_tokens: int = 10 ** 8

class JumpReLUFunction(Function):
    """Shifted ReLU with a learnable threshold and a straight-through gradient window.

    Forward returns ``(relu(x - threshold), active)``, where ``active`` is the float
    indicator ``x > threshold``. ``threshold`` broadcasts against ``x``; in this file it
    holds one value per feature along the last dimension.

    Backward:
    * ``x`` receives ``grad_output`` where active, and also where inactive but within
      ``bandwidth`` of the threshold (straight-through window);
    * ``threshold`` receives ``-grad_output`` where active, summed down to
      ``threshold``'s own shape (works for inputs of any rank);
    * ``active`` is marked non-differentiable, so no gradient flows through the mask.
    """

    @staticmethod
    def forward(ctx, x, threshold, bandwidth):
        ctx.save_for_backward(x, threshold)
        ctx.bandwidth = bandwidth
        active = (x > threshold).float()
        # A 0/1 step mask has no useful derivative; tell autograd so.
        ctx.mark_non_differentiable(active)
        return F.relu(x - threshold), active

    @staticmethod
    def backward(ctx, grad_output, grad_active):
        # grad_active is ignored: ``active`` was marked non-differentiable in forward.
        x, threshold = ctx.saved_tensors
        in_bandwidth = (torch.abs(x - threshold) <= ctx.bandwidth).float()
        active = (x > threshold).float()
        grad_x = grad_output * (active + (1 - active) * in_bandwidth)
        # Reduce over every broadcast dimension (not just dims 0 and 1) so the gradient
        # has threshold's shape for (N, F) inputs as well as (B, T, F) inputs.
        grad_threshold = (-grad_output * active).sum_to_size(threshold.shape)
        return grad_x, grad_threshold, None

class JumpReLU(nn.Module):
    """``nn.Module`` wrapper around ``JumpReLUFunction`` with a learnable threshold vector."""

    def __init__(self, n_features, threshold=0.03, bandwidth=1.0):
        super().__init__()
        initial = torch.full(size=(n_features,), fill_value=threshold)
        self.threshold = nn.Parameter(initial)  # one threshold per feature
        self.bandwidth = bandwidth              # straight-through window used in backward

    def forward(self, x):
        """Return ``(features, active)`` exactly as produced by ``JumpReLUFunction``."""
        features, active = JumpReLUFunction.apply(x, self.threshold, self.bandwidth)
        return features, active

class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``.

    For attribution work the block can replay a forward pass with parts of it frozen:

    * ``frozen_ln1_denom`` / ``frozen_ln2_denom`` replace the LayerNorm term
      ``sqrt(var + eps)`` with a precomputed tensor, making the norm linear in ``x``.
    * ``frozen_attn_patterns`` replaces the softmax attention weights with fixed patterns.
      Values and outputs still pass through the attention's value and output
      projections, so replaying with the block's own per-head weights reproduces the
      unfrozen output.

    NOTE(review): ``self.attn`` is called without an ``attn_mask``, so attention is
    bidirectional rather than causal. Confirm this is intended for a next-token model.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.attn = nn.MultiheadAttention(
            config.n_embd, config.n_head, dropout=0.0, batch_first=True
        )
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd)
        )

    @staticmethod
    def _norm_with_denom(x, ln, frozen_denom):
        """Apply LayerNorm ``ln`` to ``x``; return ``(normalized, denominator)``.

        With ``frozen_denom`` the denominator ``sqrt(var + eps)`` is replaced by the given
        tensor. Otherwise ``ln`` itself is applied and the denominator is computed from ``x``.
        """
        mean = x.mean(dim=-1, keepdim=True)
        if frozen_denom is not None:
            normed = (x - mean) / frozen_denom
            return normed * ln.weight + ln.bias, frozen_denom
        var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
        return ln(x), torch.sqrt(var + ln.eps)

    def _frozen_attention(self, x_norm, patterns):
        """Self-attention output for ``x_norm`` (B, T, D) using fixed attention ``patterns``.

        Only the softmax weights are replaced. Values come from the attention's own value
        projection and the concatenated heads go through its output projection, as
        inside ``nn.MultiheadAttention``.

        Accepted pattern shapes:
        * ``(T, T)``, ``(1, T, T)`` or ``(B, T, T)``: one pattern shared by all heads,
          e.g. head-averaged weights;
        * ``(H, T, T)``: per head, shared across the batch;
        * ``(B, H, T, T)``: per batch element and head, as returned with
          ``average_attn_weights=False``.

        Raises:
            ValueError: for any other pattern shape.
        """
        B, T, D = x_norm.shape
        n_heads = self.attn.num_heads
        head_dim = D // n_heads

        # nn.MultiheadAttention packs the q/k/v projections as rows [0:D], [D:2D], [2D:3D].
        if self.attn.in_proj_weight is not None:
            w_v = self.attn.in_proj_weight[2 * D:]
        else:
            w_v = self.attn.v_proj_weight
        b_v = None if self.attn.in_proj_bias is None else self.attn.in_proj_bias[2 * D:]
        v = F.linear(x_norm, w_v, b_v).reshape(B, T, n_heads, head_dim).transpose(1, 2)

        # Bring the patterns to a shape that broadcasts against (B, H, T, T).
        if patterns.dim() == 2:
            weights = patterns[None, None]
        elif patterns.dim() == 3 and patterns.shape[0] in (B, 1):
            weights = patterns[:, None]
        elif patterns.dim() == 3 and patterns.shape[0] == n_heads:
            weights = patterns[None]
        elif patterns.dim() == 4:
            weights = patterns
        else:
            raise ValueError(f"Unexpected frozen_attn_patterns shape: {patterns.shape}")

        mixed = torch.matmul(weights.to(v.dtype), v)           # (B, H, T, head_dim)
        mixed = mixed.transpose(1, 2).reshape(B, T, D)
        return self.attn.out_proj(mixed)

    def forward(self, x, return_components=False, frozen_ln1_denom=None,
                frozen_ln2_denom=None, frozen_attn_patterns=None,
                average_attn_weights=True):
        """Apply the block, optionally with frozen LayerNorm denominators and/or attention.

        Args:
            x: residual stream of shape (B, T, n_embd) (``batch_first`` attention).
            return_components: if True, also return a dict of intermediate tensors.
            frozen_ln1_denom / frozen_ln2_denom: replacement ``sqrt(var + eps)`` terms.
            frozen_attn_patterns: fixed attention weights (shapes listed in
                ``_frozen_attention``).
            average_attn_weights: forwarded to ``nn.MultiheadAttention`` when attention is
                not frozen. Pass False to get per-head weights ``(B, H, T, T)``, which
                replay exactly.

        Returns:
            The block output, or ``(output, components)`` when ``return_components``.
        """
        residual_pre_attn = x
        x_norm1, ln1_denom = self._norm_with_denom(x, self.ln_1, frozen_ln1_denom)

        if frozen_attn_patterns is not None:
            attn_out = self._frozen_attention(x_norm1, frozen_attn_patterns)
            attn_weights = frozen_attn_patterns
        else:
            # Only pass the flag when it differs from the default, so the default call is
            # unchanged from a plain self.attn(q, k, v).
            extra = {} if average_attn_weights else {'average_attn_weights': False}
            attn_out, attn_weights = self.attn(x_norm1, x_norm1, x_norm1, **extra)

        x = residual_pre_attn + attn_out
        residual_pre_mlp = x

        x_norm2, ln2_denom = self._norm_with_denom(x, self.ln_2, frozen_ln2_denom)
        mlp_out = self.mlp(x_norm2)
        x = residual_pre_mlp + mlp_out

        if return_components:
            return x, {
                'residual_pre_attn': residual_pre_attn,
                'residual_pre_mlp': residual_pre_mlp,
                'attn_out': attn_out,
                'pre_mlp_norm': x_norm2,
                'mlp_output': mlp_out,
                'ln1_denom': ln1_denom,
                'ln2_denom': ln2_denom,
                'attn_weights': attn_weights,
                'attn_norm': x_norm1
            }
        return x

class MinimalTransformer(nn.Module):
    """Token + learned positional embeddings, ``n_layer`` TransformerBlocks, final LN, LM head.

    ``forward`` can optionally collect per-layer intermediate activations and can replay
    the pass with frozen LayerNorm denominators and attention patterns (one entry per layer).
    """

    # (activations-list key, TransformerBlock component key), in collection order.
    _COMPONENT_KEYS = (
        ('pre_mlp_norms', 'pre_mlp_norm'),
        ('mlp_outputs', 'mlp_output'),
        ('ln1_denoms', 'ln1_denom'),
        ('ln2_denoms', 'ln2_denom'),
        ('residual_pre_mlps', 'residual_pre_mlp'),
        ('attn_weights', 'attn_weights'),
        ('attn_norms', 'attn_norm'),
        ('attn_outs', 'attn_out'),
        ('residual_pre_attns', 'residual_pre_attn'),
    )

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    @staticmethod
    def _layer_entry(per_layer, i):
        """Return ``per_layer[i]``, or None when ``per_layer`` is None or empty.

        Uses explicit None/length checks rather than truthiness, so a stacked tensor of
        per-layer values works as well as a list.
        """
        if per_layer is None or len(per_layer) == 0:
            return None
        return per_layer[i]

    def forward(self, input_ids, collect_activations=False,
                frozen_ln1_denoms=None, frozen_ln2_denoms=None,
                frozen_attn_patterns=None, frozen_lnf_denom=None):
        """Run the model.

        Args:
            input_ids: integer tensor unpacked as ``(B, T)``, with ``T <= config.block_size``.
            collect_activations: if True, also return a dict of intermediate activations.
            frozen_ln1_denoms / frozen_ln2_denoms / frozen_attn_patterns: optional
                per-layer sequences indexed by block number. ``None`` or empty disables
                freezing.
            frozen_lnf_denom: optional replacement denominator for the final LayerNorm.

        Returns:
            ``logits`` (last dim ``vocab_size``), or ``(logits, activations)`` when
            ``collect_activations`` is True.

        Raises:
            ValueError: if ``T`` exceeds ``config.block_size`` (the positional table size).
        """
        _, T = input_ids.shape
        if T > self.config.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.config.block_size}"
            )
        token_embeddings = self.token_emb(input_ids)
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)
        pos_embeddings = self.pos_emb(pos)
        x = token_embeddings + pos_embeddings

        if collect_activations:
            activations = {
                'embeddings': x,
                'token_embeddings': token_embeddings,
                'pos_embeddings': pos_embeddings,
            }
            activations.update({key: [] for key, _ in self._COMPONENT_KEYS})

        for i, block in enumerate(self.blocks):
            frozen = {
                'frozen_ln1_denom': self._layer_entry(frozen_ln1_denoms, i),
                'frozen_ln2_denom': self._layer_entry(frozen_ln2_denoms, i),
                'frozen_attn_patterns': self._layer_entry(frozen_attn_patterns, i),
            }
            if collect_activations:
                x, components = block(x, return_components=True, **frozen)
                for act_key, comp_key in self._COMPONENT_KEYS:
                    activations[act_key].append(components[comp_key])
            else:
                x = block(x, **frozen)

        if frozen_lnf_denom is not None:
            mean = x.mean(dim=-1, keepdim=True)
            x = (x - mean) / frozen_lnf_denom
            x = x * self.ln_f.weight + self.ln_f.bias
            if collect_activations:
                activations['lnf_denom'] = frozen_lnf_denom
        else:
            if collect_activations:
                # Record sqrt(var + eps) so a later pass can freeze the final LayerNorm.
                mean = x.mean(dim=-1, keepdim=True)
                var = ((x - mean)**2).mean(dim=-1, keepdim=True)
                activations['lnf_denom'] = torch.sqrt(var + self.ln_f.eps)
            x = self.ln_f(x)

        logits = self.lm_head(x)

        return (logits, activations) if collect_activations else logits

class CrossLayerTranscoder(nn.Module):
    """Cross-layer transcoder (CLT): sparse features that jointly reconstruct all MLP outputs.

    Layer ``k`` has an encoder (``n_embd -> n_features``) followed by a JumpReLU. A feature
    read at layer ``k`` writes to the reconstruction of every layer ``j >= k`` through its
    own bias-free decoder ``decoders[f"{k}_to_{j}"]`` (``n_features -> n_embd``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.n_features = config.n_features
        self.n_layers = config.n_layer

        depth, width, d_model = self.n_layers, self.n_features, config.n_embd
        self.encoders = nn.ModuleList(
            [nn.Linear(d_model, width) for _ in range(depth)]
        )
        self.activations = nn.ModuleList(
            [JumpReLU(width, config.jump_threshold, config.bandwidth) for _ in range(depth)]
        )
        # (key, module) pairs preserve insertion order: 0_to_0, 0_to_1, ..., 1_to_1, ...
        self.decoders = nn.ModuleDict([
            (f"{src}_to_{dst}", nn.Linear(width, d_model, bias=False))
            for src in range(depth)
            for dst in range(src, depth)
        ])

        self._initialize_weights()

    def _initialize_weights(self):
        """Encoders ~ U(+-1/sqrt(n_features)) with zero bias; decoders ~ U(+-1/sqrt(n_layers * n_embd))."""
        enc_bound = 1 / np.sqrt(self.n_features)
        for enc in self.encoders:
            nn.init.uniform_(enc.weight, -enc_bound, enc_bound)
            nn.init.zeros_(enc.bias)

        dec_bound = 1 / np.sqrt(self.n_layers * self.config.n_embd)
        for dec in self.decoders.values():
            nn.init.uniform_(dec.weight, -dec_bound, dec_bound)

    def forward(self, pre_mlp_norms, compute_jacobian=False, normalize_inputs=True):
        """Encode each layer's pre-MLP activations and decode the cross-layer reconstructions.

        Args:
            pre_mlp_norms: per-layer inputs indexed ``0 .. n_layers-1``, last dim ``n_embd``.
            compute_jacobian: if True, decoders receive ``features.detach() * active``
                instead of the raw features.
            normalize_inputs: if True, rescale every input vector to L2 norm ``sqrt(n_embd)``.

        Returns:
            ``(reconstructions, features, active_masks, pre_activations)``: four lists with
            one entry per layer.
        """
        root_d = np.sqrt(self.config.n_embd)
        layer_inputs = [pre_mlp_norms[layer] for layer in range(self.n_layers)]
        if normalize_inputs:
            layer_inputs = [
                F.normalize(h, p=2, dim=-1, eps=1e-12) * root_d for h in layer_inputs
            ]

        all_pre_acts, all_features, all_active = [], [], []
        for encoder, jump_relu, h in zip(self.encoders, self.activations, layer_inputs):
            pre_act = encoder(h)
            feats, mask = jump_relu(pre_act)
            all_pre_acts.append(pre_act)
            all_features.append(feats)
            all_active.append(mask)

        if compute_jacobian:
            decoder_inputs = [f.detach() * m for f, m in zip(all_features, all_active)]
        else:
            decoder_inputs = all_features

        reconstructions = []
        for dst in range(self.n_layers):
            recon = torch.zeros_like(pre_mlp_norms[0])
            for src in range(dst + 1):
                recon += self.decoders[f"{src}_to_{dst}"](decoder_inputs[src])
            reconstructions.append(recon)

        return reconstructions, all_features, all_active, all_pre_acts

class CLTTrainer:
    """Fits a CrossLayerTranscoder to the MLP outputs of a frozen base model.

    The loss is the sum of three terms:
    * normalized reconstruction MSE;
    * a ramped sparsity penalty (mean ``tanh`` of the features);
    * a pre-activation penalty (mean ReLU of negated pre-activations).
    """

    def __init__(self, base_model, clt, config):
        self.base_model = base_model
        self.clt = clt
        self.config = config

        # The configured learning rate is divided by sqrt(number of CLT parameters).
        n_params = sum(param.numel() for param in clt.parameters())
        lr_scale = 1.0 / np.sqrt(n_params)
        self.optimizer = torch.optim.AdamW(
            clt.parameters(), lr=config.learning_rate * lr_scale, betas=(0.9, 0.999)
        )

        self.step = 0
        # Sparsity-penalty ramp length: 2000 steps at 96 features, scaled by (n_features/96)**0.8.
        self.total_steps = int(2000 * np.power(config.n_features / 96, 0.8))

    def compute_loss(self, mlp_outputs, reconstructions, features, pre_acts):
        """Return ``(total_loss, recon_loss, sparsity_loss, l0)``.

        ``total_loss`` is a tensor suitable for ``backward``; the other three are Python
        floats. Targets and reconstructions are each rescaled to L2 norm ``sqrt(n_embd)``
        per vector before the MSE. ``l0`` is the fraction of positive feature entries,
        averaged over layers.
        """
        root_d = np.sqrt(self.config.n_embd)

        def to_sphere(t):
            return F.normalize(t, p=2, dim=-1, eps=1e-12) * root_d

        targets = [to_sphere(t) for t in mlp_outputs]
        recons = [to_sphere(r) for r in reconstructions]
        n_targets = len(mlp_outputs)
        recon_loss = sum(F.mse_loss(r, t) for r, t in zip(recons, targets)) / n_targets

        sparsity_loss = sum(torch.tanh(f).mean() for f in features) / len(features)
        # The sparsity weight ramps linearly from 0 to its full value over total_steps steps.
        sparsity_weight = self.config.sparsity_penalty * min(1.0, self.step / self.total_steps)

        pre_act_loss = sum(F.relu(-p).mean() for p in pre_acts) / len(pre_acts)

        total = (recon_loss
                 + sparsity_weight * sparsity_loss
                 + self.config.pre_act_penalty * pre_act_loss)

        l0 = sum((f > 0).float().mean().item() for f in features) / len(features)
        return total, recon_loss.item(), sparsity_loss.item(), l0

    def train_step(self, batch):
        """Run one optimizer step on ``batch``; return a dict of scalar metrics."""
        self.step += 1

        # The base model only supplies activations; the optimizer holds CLT parameters only.
        with torch.no_grad():
            _, acts = self.base_model(batch, collect_activations=True)

        recons, feats, _active, pre_acts = self.clt(acts['pre_mlp_norms'])
        loss, recon, sparsity, l0 = self.compute_loss(
            acts['mlp_outputs'], recons, feats, pre_acts
        )

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {'loss': loss.item(), 'recon': recon, 'sparsity': sparsity, 'l0': l0}

class AttributionGraph:
    def __init__(self, base_model, clt):
        self.base_model = base_model
        self.clt = clt
        self.n_layers = base_model.config.n_layer
        self.d_model = base_model.config.n_embd
        self.n_heads = base_model.config.n_head
        self.coactivation_counts = {}
        self.feature_activation_counts = {}
    
    def compute_virtual_weights(self, k: int, i: int, j: int) -> torch.Tensor:
        sum_dec = torch.zeros(self.d_model)
        for l in range(k, j):
            sum_dec += self.clt.decoders[f"{k}_to_{l}"].weight[:, i]
        
        virtual_weights = torch.matmul(sum_dec, self.clt.encoders[j].weight.T)
        return virtual_weights
    
    def compute_twera(self, virtual_weight: float, source_key: str, target_key: str) -> float:
        coact_count = self.coactivation_counts.get((source_key, target_key), 0)
        source_count = self.feature_activation_counts.get(source_key, 1)
        
        p_coactive = coact_count / max(source_count, 1)
        
        return virtual_weight * p_coactive
    
    def update_coactivation_stats(self, features, active_masks):
        for layer_k in range(self.n_layers):
            active_k = active_masks[layer_k] > 0
            for idx_k in torch.where(active_k.flatten())[0]:
                key_k = f"L{layer_k}_F{idx_k.item()}"
                self.feature_activation_counts[key_k] = self.feature_activation_counts.get(key_k, 0) + 1
                
                for layer_j in range(layer_k + 1, self.n_layers):
                    active_j = active_masks[layer_j] > 0
                    for idx_j in torch.where(active_j.flatten())[0]:
                        key_j = f"L{layer_j}_F{idx_j.item()}"
                        if active_k.flatten()[idx_k] and active_j.flatten()[idx_j]:
                            pair = (key_k, key_j)
                            self.coactivation_counts[pair] = self.coactivation_counts.get(pair, 0) + 1
    
    def compute_jacobian_attribution(self, source_vec: torch.Tensor, target_vec: torch.Tensor,
                                    acts: Dict, layer_source: int, layer_target: int,
                                    frozen_denoms: Dict) -> float:
        source_injection = source_vec.clone().detach().requires_grad_(True)
        
        with torch.enable_grad():
            x = source_injection
            
            for l in range(layer_source, min(layer_target, self.n_layers)):
                residual = x
                
                if l < len(acts['attn_outs']):
                    x = residual + acts['attn_outs'][l][0, -1].detach()
                
                if l < layer_target:
                    x = residual + x * 0.1
            
            output = torch.dot(x.flatten(), target_vec.flatten())
            
            grad = torch.autograd.grad(output, source_injection,
                                      retain_graph=False,
                                      create_graph=False)[0]
            
            grad = grad.detach()
        
        return torch.dot(source_vec.detach(), grad).item()
    
    def compute_indirect_influence(self, adjacency_matrix: torch.Tensor) -> torch.Tensor:
        A = torch.abs(adjacency_matrix)
        
        row_sums = A.sum(dim=1, keepdim=True)
        row_sums = torch.clamp(row_sums, min=1e-8)
        A_norm = A / row_sums
        
        I = torch.eye(A.shape[0])
        try:
            influence = torch.linalg.inv(I - A_norm) - I
        except:
            influence = A_norm.clone()
            A_power = A_norm.clone()
            for _ in range(10):
                A_power = torch.matmul(A_power, A_norm)
                influence += A_power
                if A_power.abs().max() < 1e-6:
                    break
        
        return influence
    
    def compute_full_graph(self, input_ids: torch.Tensor, target_pos: int = -1, use_twera: bool = False) -> Dict:
        """Build an attribution graph for the model's argmax prediction at ``target_pos``.

        Only batch element 0 is used when building nodes and edges. The method:

        1. runs the base model once to record LayerNorm denominators and attention
           weights, then runs it again with those frozen;
        2. encodes the frozen pass's ``pre_mlp_norms`` with the CLT
           (``compute_jacobian=True``) and updates the co-activation statistics via
           ``update_coactivation_stats`` (this mutates ``self``);
        3. collects nodes at ``target_pos``:
           * the input token;
           * CLT features whose active mask is set;
           * per-layer reconstruction errors with norm > 0.01;
           * attention nodes whose largest weight from ``target_pos`` exceeds 0.1;
           * the logit of the argmax token;
        4. collects edges whose absolute weight exceeds 0.01:
           * embedding->logit and embedding->layer-0 feature;
           * feature->later-layer feature, optionally TWERA-weighted when ``use_twera``;
           * feature->logit;
           * error->logit and error->later-layer feature;
           * attention->later-layer feature and attention->logit.

        Logit-directed weights are dot products with ``logit_v``: the target token's
        unembedding row minus the mean unembedding row.

        Args:
            input_ids: token ids, unpacked as ``(B, T)``.
            target_pos: position whose prediction is explained; used directly as an index,
                so negative values count from the end.
            use_twera: weight feature->feature edges with ``compute_twera``.

        Returns:
            dict with ``nodes``, ``edges``, ``target_token``, ``position`` and
            ``frozen_denoms`` (the frozen LayerNorm denominators and attention patterns).
        """
        device = input_ids.device
        B, T = input_ids.shape

        # Pass 1 records LayerNorm denominators and attention weights; pass 2 replays the
        # model with them frozen. Both passes and the CLT encoding run without autograd.
        with torch.no_grad():
            _, acts_initial = self.base_model(input_ids, collect_activations=True)
            frozen_ln1_denoms = acts_initial['ln1_denoms']
            frozen_ln2_denoms = acts_initial['ln2_denoms']
            frozen_lnf_denom = acts_initial.get('lnf_denom', None)
            frozen_attn_patterns = acts_initial['attn_weights']

            logits, acts = self.base_model(
                input_ids,
                collect_activations=True,
                frozen_ln1_denoms=frozen_ln1_denoms,
                frozen_ln2_denoms=frozen_ln2_denoms,
                frozen_attn_patterns=frozen_attn_patterns,
                frozen_lnf_denom=frozen_lnf_denom
            )

            reconstructions, features, active_masks, pre_acts = self.clt(
                acts['pre_mlp_norms'], compute_jacobian=True
            )

        # Accumulate co-activation statistics for this input (mutates self).
        self.update_coactivation_stats(features, active_masks)

        # Per-layer CLT reconstruction error against the frozen pass's MLP outputs.
        # NOTE(review): reconstructions are used at the CLT's raw output scale, while
        # CLTTrainer compares normalized vectors. Confirm the scales match.
        errors = []
        for j in range(self.n_layers):
            error = acts['mlp_outputs'][j] - reconstructions[j]
            errors.append(error.detach())

        # Explain the argmax prediction at target_pos for batch element 0. logit_v is the
        # target token's unembedding row minus the mean unembedding row.
        target_token = torch.argmax(logits[0, target_pos]).item()
        unembed = self.base_model.lm_head.weight[target_token, :]
        mean_unembed = self.base_model.lm_head.weight.mean(dim=0)
        logit_v = unembed - mean_unembed

        # CLT features active at target_pos become feature nodes.
        token_id = input_ids[0, target_pos].item()
        active_features = []
        for layer in range(self.n_layers):
            active_mask = active_masks[layer][0, target_pos] > 0
            for idx in torch.where(active_mask)[0]:
                active_features.append({
                    'layer': layer,
                    'index': idx.item(),
                    'activation': features[layer][0, target_pos, idx].item()
                })

        # Error nodes are kept only when their norm at target_pos exceeds 0.01.
        nodes = {
            'embeddings': [{'type': 'token', 'position': target_pos, 'token_id': token_id}],
            'features': active_features,
            'errors': [{'layer': j, 'norm': torch.norm(errors[j][0, target_pos]).item()}
                      for j in range(self.n_layers) if torch.norm(errors[j][0, target_pos]).item() > 0.01],
            'attention_heads': [],
            'logit': f"logit_{target_token}"
        }

        # Attention nodes: keep heads whose largest attention weight from target_pos
        # exceeds 0.1. Per-head weights give one node per head; head-averaged weights
        # give a single node labelled head 0. Handles the weight shapes below.
        for layer in range(self.n_layers):
            attn_weights = frozen_attn_patterns[layer]

            # Handle different shapes of attention weights
            if attn_weights.dim() == 4 and attn_weights.shape[1] == self.n_heads:
                # [B, num_heads, T, T] - per-head weights
                attn_weights = attn_weights[0]  # Take first batch
                for head in range(self.n_heads):
                    attn_pattern = attn_weights[head, target_pos, :].cpu().numpy()
                    if np.max(attn_pattern) > 0.1:
                        nodes['attention_heads'].append({
                            'layer': layer,
                            'head': head,
                            'max_weight': float(np.max(attn_pattern)),
                            'pattern': attn_pattern
                        })
            elif attn_weights.dim() == 3:
                # [B, T, T] or [num_heads, T, T]
                if attn_weights.shape[0] == self.n_heads:
                    # [num_heads, T, T] - per-head weights without batch
                    for head in range(self.n_heads):
                        attn_pattern = attn_weights[head, target_pos, :].cpu().numpy()
                        if np.max(attn_pattern) > 0.1:
                            nodes['attention_heads'].append({
                                'layer': layer,
                                'head': head,
                                'max_weight': float(np.max(attn_pattern)),
                                'pattern': attn_pattern
                            })
                else:
                    # [B, T, T] - averaged weights
                    attn_pattern = attn_weights[0, target_pos, :].cpu().numpy()
                    if np.max(attn_pattern) > 0.1:
                        # Add as a single "averaged" attention head
                        nodes['attention_heads'].append({
                            'layer': layer,
                            'head': 0,  # Use head 0 to represent averaged
                            'max_weight': float(np.max(attn_pattern)),
                            'pattern': attn_pattern
                        })
            elif attn_weights.dim() == 2:
                # [T, T] - averaged weights without batch
                attn_pattern = attn_weights[target_pos, :].cpu().numpy()
                if np.max(attn_pattern) > 0.1:
                    nodes['attention_heads'].append({
                        'layer': layer,
                        'head': 0,  # Use head 0 to represent averaged
                        'max_weight': float(np.max(attn_pattern)),
                        'pattern': attn_pattern
                    })

        edges = []

        # Direct embedding -> logit edge.
        # NOTE(review): logit-directed weights project raw residual vectors onto logit_v
        # without applying the final LayerNorm (ln_f). Confirm this approximation is intended.
        emb = acts['embeddings'][0, target_pos].detach()
        direct_contrib = torch.dot(emb, logit_v).item()
        if abs(direct_contrib) > 0.01:
            edges.append({
                'source': f"emb_{token_id}",
                'target': f"logit_{target_token}",
                'weight': direct_contrib,
                'type': 'embedding_to_logit_direct'
            })

        frozen_denoms = {
            'ln1': frozen_ln1_denoms,
            'ln2': frozen_ln2_denoms,
            'lnf': frozen_lnf_denom,
            'attn': frozen_attn_patterns
        }

        # Embedding -> layer-0 feature edges, scaled by the feature's activation.
        for feat in active_features:
            if feat['layer'] == 0:
                enc_weight = self.clt.encoders[0].weight[feat['index'], :]
                weight = self.compute_jacobian_attribution(
                    emb, enc_weight, acts, 0, 0, frozen_denoms
                ) * feat['activation']
                if abs(weight) > 0.01:
                    edges.append({
                        'source': f"emb_{token_id}",
                        'target': f"L0_F{feat['index']}",
                        'weight': weight,
                        'type': 'embedding_to_feature'
                    })

        # Feature -> feature edges (target layer strictly after the source layer): source
        # activation times virtual weight, optionally TWERA-weighted.
        # NOTE(review): compute_virtual_weights is recomputed for every (source, target)
        # pair although it depends only on (k, i, j).
        for s in active_features:
            k, i, a_s = s['layer'], s['index'], s['activation']
            source_key = f"L{k}_F{i}"

            for t in active_features:
                j, m = t['layer'], t['index']
                if j <= k:
                    continue

                target_key = f"L{j}_F{m}"

                virtual_weights = self.compute_virtual_weights(k, i, j)
                V_st = virtual_weights[m].item()

                if use_twera:
                    edge_weight = self.compute_twera(a_s * V_st, source_key, target_key)
                else:
                    edge_weight = a_s * V_st

                if abs(edge_weight) > 0.01:
                    edges.append({
                        'source': source_key,
                        'target': target_key,
                        'weight': edge_weight,
                        'type': 'feature_to_feature'
                    })

        # Feature -> logit edges: the feature's decoder columns summed over every layer it
        # writes to (k .. n_layers-1), projected onto logit_v and scaled by its activation.
        for s in active_features:
            k, i, a_s = s['layer'], s['index'], s['activation']

            sum_dec = torch.zeros(self.d_model, device=device)
            for l in range(k, self.n_layers):
                decoder_weight = self.clt.decoders[f"{k}_to_{l}"].weight[:, i].detach()
                sum_dec += decoder_weight

            weight = a_s * torch.dot(sum_dec, logit_v).item()
            if abs(weight) > 0.01:
                edges.append({
                    'source': f"L{k}_F{i}",
                    'target': f"logit_{target_token}",
                    'weight': weight,
                    'type': 'feature_to_logit'
                })

        # Error -> logit edges, and error -> feature edges for features in later layers.
        for error_info in nodes['errors']:
            j = error_info['layer']
            error_vec = errors[j][0, target_pos, :].detach()

            contrib = torch.dot(error_vec, logit_v).item()
            if abs(contrib) > 0.01:
                edges.append({
                    'source': f"error_L{j}",
                    'target': f"logit_{target_token}",
                    'weight': contrib,
                    'type': 'error_to_logit'
                })

            for t in active_features:
                if t['layer'] > j:
                    enc_t = self.clt.encoders[t['layer']].weight[t['index'], :].detach()
                    weight = torch.dot(error_vec, enc_t).item()
                    if abs(weight) > 0.01:
                        edges.append({
                            'source': f"error_L{j}",
                            'target': f"L{t['layer']}_F{t['index']}",
                            'weight': weight,
                            'type': 'error_to_feature'
                        })

        # Attention -> later-layer feature and attention -> logit edges.
        # NOTE(review): head_contribution splits the layer's total attention output evenly
        # across heads rather than using each head's own output.
        for attn_node in nodes['attention_heads']:
            layer = attn_node['layer']
            head = attn_node['head']

            attn_out = acts['attn_outs'][layer][0, target_pos].detach()
            head_contribution = attn_out / self.n_heads

            for feat in active_features:
                if feat['layer'] > layer:
                    enc_weight = self.clt.encoders[feat['layer']].weight[feat['index'], :].detach()
                    weight = torch.dot(head_contribution, enc_weight).item()
                    if abs(weight) > 0.01:
                        edges.append({
                            'source': f"L{layer}_H{head}",
                            'target': f"L{feat['layer']}_F{feat['index']}",
                            'weight': weight,
                            'type': 'attention_to_feature'
                        })

            contrib = torch.dot(head_contribution, logit_v).item()
            if abs(contrib) > 0.01:
                edges.append({
                    'source': f"L{layer}_H{head}",
                    'target': f"logit_{target_token}",
                    'weight': contrib,
                    'type': 'attention_to_logit'
                })

        return {
            'nodes': nodes,
            'edges': edges,
            'target_token': target_token,
            'position': target_pos,
            'frozen_denoms': frozen_denoms
        }
    
    def prune_graph(self, graph: Dict, node_threshold: float = 0.8, edge_threshold: float = 0.98) -> Dict:
        node_names = []
        node_names.append(f"emb_{graph['nodes']['embeddings'][0]['token_id']}")
        for feat in graph['nodes']['features']:
            node_names.append(f"L{feat['layer']}_F{feat['index']}")
        for err in graph['nodes']['errors']:
            node_names.append(f"error_L{err['layer']}")
        for attn in graph['nodes']['attention_heads']:
            node_names.append(f"L{attn['layer']}_H{attn['head']}")
        node_names.append(graph['nodes']['logit'])
        
        n_nodes = len(node_names)
        adjacency = torch.zeros(n_nodes, n_nodes)
        
        node_to_idx = {name: i for i, name in enumerate(node_names)}
        for edge in graph['edges']:
            if edge['source'] in node_to_idx and edge['target'] in node_to_idx:
                i, j = node_to_idx[edge['source']], node_to_idx[edge['target']]
                adjacency[j, i] = abs(edge['weight'])
        
        influence = self.compute_indirect_influence(adjacency)
        
        logit_idx = node_to_idx[graph['nodes']['logit']]
        logit_influence = influence[logit_idx, :]
        
        sorted_influence, sorted_indices = torch.sort(logit_influence, descending=True)
        cumsum = torch.cumsum(sorted_influence, dim=0)
        cumsum = cumsum / (cumsum[-1] + 1e-8)
        n_keep = (cumsum <= node_threshold).sum().item() + 1
        
        top_indices = sorted_indices[:n_keep]
        kept_nodes = set([node_names[i] for i in top_indices])
        kept_nodes.add(graph['nodes']['logit'])
        
        kept_edges = [e for e in graph['edges']
                     if e['source'] in kept_nodes and e['target'] in kept_nodes]
        
        kept_node_list = list(kept_nodes)
        kept_node_to_idx = {name: i for i, name in enumerate(kept_node_list)}
        n_kept = len(kept_node_list)
        
        kept_adjacency = torch.zeros(n_kept, n_kept)
        edge_to_weight = {}
        
        for edge in kept_edges:
            i = kept_node_to_idx[edge['source']]
            j = kept_node_to_idx[edge['target']]
            kept_adjacency[j, i] = abs(edge['weight'])
            edge_to_weight[(edge['source'], edge['target'])] = abs(edge['weight'])
        
        kept_influence = self.compute_indirect_influence(kept_adjacency)
        
        logit_idx_kept = kept_node_to_idx[graph['nodes']['logit']]
        node_scores = kept_influence[logit_idx_kept, :]
        
        edge_scores = []
        for edge in kept_edges:
            source_idx = kept_node_to_idx[edge['source']]
            target_idx = kept_node_to_idx[edge['target']]
            edge_score = kept_adjacency[target_idx, source_idx] * node_scores[target_idx]
            edge_scores.append((edge, edge_score.item()))
        
        edge_scores.sort(key=lambda x: x[1], reverse=True)
        total_score = sum(score for _, score in edge_scores)
        
        pruned_edges = []
        cumulative_score = 0
        for edge, score in edge_scores:
            cumulative_score += score
            pruned_edges.append(edge)
            if cumulative_score / total_score >= edge_threshold:
                break
        
        final_nodes = set()
        for edge in pruned_edges:
            final_nodes.add(edge['source'])
            final_nodes.add(edge['target'])
        final_nodes.add(graph['nodes']['logit'])
        
        pruned_nodes = {
            'embeddings': [e for e in graph['nodes']['embeddings']
                          if f"emb_{e['token_id']}" in final_nodes],
            'features': [f for f in graph['nodes']['features']
                        if f"L{f['layer']}_F{f['index']}" in final_nodes],
            'errors': [e for e in graph['nodes']['errors']
                      if f"error_L{e['layer']}" in final_nodes],
            'attention_heads': [a for a in graph['nodes']['attention_heads']
                               if f"L{a['layer']}_H{a['head']}" in final_nodes],
            'logit': graph['nodes']['logit']
        }
        
        return {
            'nodes': pruned_nodes,
            'edges': pruned_edges,
            'target_token': graph['target_token'],
            'position': graph['position'],
            'frozen_denoms': graph.get('frozen_denoms', None)
        }
    
    def group_into_supernodes(self, graph: Dict, similarity_threshold: float = 0.7) -> Dict:
        """Greedily merge features with similar decoder directions into supernodes.

        A feature's decoder direction is the sum of its decoder columns over
        every layer it writes to, i.e. ``decoders[f"{layer}_to_{l}"].weight[:, index]``
        for ``l in range(layer, self.n_layers)``. Features are scanned in
        order. Each not-yet-assigned feature seeds a group and absorbs every
        later unassigned feature whose direction has cosine similarity
        strictly greater than ``similarity_threshold`` with the seed.
        Features left ungrouped become singleton supernodes.

        Edges are rewritten in terms of supernode ids; non-feature endpoints
        keep their original names. Weights of edges that share the same
        ``(source, target, type)`` are summed, and aggregated edges with
        ``|weight| <= 0.01`` are dropped. Edges between two members of the
        same supernode become self-loops ``(SN, SN)``.

        Args:
            graph: Attribution graph dict. It holds ``nodes['features']``
                (dicts with ``layer`` and ``index``) and ``edges`` (dicts
                with ``source``, ``target``, ``weight``, ``type``). It also
                holds ``target_token`` and ``position``.
            similarity_threshold: Exclusive lower bound on cosine similarity
                to the group seed for a feature to join that group.

        Returns:
            Dict with keys ``supernodes``, ``edges``, ``target_token``,
            ``position`` and ``original_graph``. The same schema is returned
            even when the graph has zero or one feature, so callers can
            always read ``result['supernodes']``.
        """
        features = graph['nodes']['features']

        # Compute each feature's summed decoder direction once, instead of
        # once per pair. No gradients are needed just to compare directions.
        with torch.no_grad():
            decoder_dirs = []
            for feat in features:
                layer, idx = feat['layer'], feat['index']
                direction = torch.zeros(self.d_model)
                for l in range(layer, self.n_layers):
                    direction += self.clt.decoders[f"{layer}_to_{l}"].weight[:, idx]
                decoder_dirs.append(direction)

        # Greedy grouping: compare each later feature only against the seed.
        groups = []
        used = set()
        for i, feat_i in enumerate(features):
            if i in used:
                continue
            group = [feat_i]
            used.add(i)
            for j in range(i + 1, len(features)):
                if j in used:
                    continue
                similarity = F.cosine_similarity(
                    decoder_dirs[i].unsqueeze(0), decoder_dirs[j].unsqueeze(0)
                ).item()
                if similarity > similarity_threshold:
                    group.append(features[j])
                    used.add(j)
            if len(group) > 1:
                groups.append(group)

        supernodes = []
        supernode_map = {}  # "L{layer}_F{index}" -> "SN{k}"

        def _add_supernode(members):
            # Register one supernode and map each member's node name to it.
            sn_name = f"SN{len(supernodes)}"
            supernodes.append({
                'id': sn_name,
                'features': members,
                'size': len(members),
                'layers': sorted({f['layer'] for f in members}),
            })
            for member in members:
                supernode_map[f"L{member['layer']}_F{member['index']}"] = sn_name

        # Multi-feature groups get the lowest ids, then singletons follow.
        for group in groups:
            _add_supernode(group)
        for feat in features:
            if f"L{feat['layer']}_F{feat['index']}" not in supernode_map:
                _add_supernode([feat])

        # Collect weights of parallel edges after remapping endpoints.
        edge_weights = {}
        for edge in graph['edges']:
            key = (
                supernode_map.get(edge['source'], edge['source']),
                supernode_map.get(edge['target'], edge['target']),
                edge['type'],
            )
            edge_weights.setdefault(key, []).append(edge['weight'])

        aggregated_edges = []
        for (source, target, edge_type), weights in edge_weights.items():
            total_weight = sum(weights)
            if abs(total_weight) > 0.01:
                aggregated_edges.append({
                    'source': source,
                    'target': target,
                    'weight': total_weight,
                    'type': edge_type,
                    'count': len(weights),
                })

        return {
            'supernodes': supernodes,
            'edges': aggregated_edges,
            'target_token': graph['target_token'],
            'position': graph['position'],
            'original_graph': graph,
        }
    
    def print_summary(self, graph: Dict):
        """Print a short human-readable digest of an attribution graph.

        The output has three parts:
        - the target token and position;
        - the number of embedding, feature, attention-head and error nodes;
        - a per-``type`` edge count (in first-seen order) and the five edges
          with the largest absolute weight.

        Args:
            graph: Graph dict with ``target_token``, ``position``, ``nodes``
                (lists under ``embeddings``, ``features``,
                ``attention_heads``, ``errors``) and ``edges`` (dicts with
                ``source``, ``target``, ``weight``, ``type``).
        """
        print(f"\n=== Attribution Graph: Token {graph['target_token']} @ pos {graph['position']} ===")

        nodes = graph['nodes']
        n_emb, n_feat = len(nodes['embeddings']), len(nodes['features'])
        n_attn, n_err = len(nodes['attention_heads']), len(nodes['errors'])
        print(f"Nodes: {n_emb} emb, {n_feat} feat, {n_attn} attn, {n_err} err")

        edges = graph['edges']
        per_type = {}
        for edge in edges:
            per_type.setdefault(edge['type'], 0)
            per_type[edge['type']] += 1
        type_summary = ', '.join(f'{kind}:{count}' for kind, count in per_type.items())
        print(f"Edges: {type_summary}")

        # Stable sort: ties keep their original edge order.
        strongest = sorted(edges, key=lambda edge: abs(edge['weight']), reverse=True)
        print("Top edges:")
        for edge in strongest[:5]:
            print(f" {edge['source']}->{edge['target']}: {edge['weight']:.3f} ({edge['type']})")

In [None]:
import torch
import numpy as np
import time
from snowflake.snowpark.context import get_active_session

def get_pile_data(session, max_samples=500, seq_length=32, vocab_size=256):
    """Load The Pile from the Snowflake stage and build character-level samples.

    At most ``max_samples`` rows are read. A row is skipped when its text is
    missing, non-string, or shorter than ``seq_length`` characters after
    stripping everything up to the first blank line. Because of this, fewer
    than ``max_samples`` samples may be returned.

    Each kept row is processed as follows:
    1. Encode it character by character as ``ord(c) % vocab_size``.
    2. Truncate to the first ``2 * seq_length`` tokens.
    3. Randomly permute the tokens with the unseeded global
       ``np.random.permutation``.
    4. Cut to ``seq_length`` tokens.

    NOTE(review): the permutation discards character order; confirm this is
    intended.

    Args:
        session: Active Snowpark session.
        max_samples: Maximum number of dataset rows to examine.
        seq_length: Number of tokens in each returned sample.
        vocab_size: Token ids are ``ord(c)`` taken modulo this value.

    Returns:
        ``torch.long`` tensor of shape ``(n_samples, seq_length)``.

    Raises:
        ValueError: If no row yields a valid sample.
    """
    print("Loading The Pile from Snowflake stage...")

    parquet_df = session.read.option("FILE_FORMAT", "MY_PILE_DATABASE.PUBLIC.PARQUET_FORMAT") \
        .parquet("@MY_PILE_DATABASE.PUBLIC.MY_PILE_STAGE/train-00000-of-00002-9f1d227dc3989035.parquet")

    # Use quoted identifier '"text"' to match the Parquet column name. Push the
    # row limit down to Snowflake: at most `max_samples` rows are examined,
    # so there is no reason to download the whole shard into pandas.
    pandas_df = parquet_df.select('"text"').limit(max_samples).to_pandas()

    data = []
    # In pandas_df the column is named 'text' (unquoted). The slice is a
    # defensive cap in case the limit is not honoured upstream.
    for text in pandas_df['text'].iloc[:max_samples]:
        if not text or not isinstance(text, str):
            continue

        # Drop everything up to the first blank line (e.g. email headers).
        text = text.split('\n\n', 1)[-1]
        tokens = [ord(c) % vocab_size for c in text[:seq_length * 3]]
        if len(tokens) < seq_length:
            continue

        tokens = np.array(tokens[:seq_length * 2])
        tokens = tokens[np.random.permutation(len(tokens))][:seq_length]
        data.append(torch.tensor(tokens, dtype=torch.long))

        # Report progress only when a sample was actually added. Otherwise
        # skipped rows would reprint the same count.
        if len(data) % 100 == 0:
            print(f" Loaded {len(data)} samples")

    if not data:
        raise ValueError("No valid samples loaded from dataset")

    return torch.stack(data)

def _train_clt(trainer, data, config, batch_size=16, max_epochs=5):
    """Run the CLT training loop over ``data`` and return the number of tokens seen.

    Training stops at the first of three conditions:
    - ``max_epochs`` passes over the data are complete;
    - ``trainer.step >= trainer.total_steps``;
    - the early-stop criterion holds (50 < L0 < 150 and recon < 0.1).

    Trailing batches with fewer than ``batch_size`` rows are skipped.
    """
    tokens_seen = 0
    for _ in range(max_epochs):
        for i in range(0, len(data), batch_size):
            batch = data[i:i + batch_size]
            if batch.shape[0] < batch_size:
                continue

            stats = trainer.train_step(batch)
            tokens_seen += batch_size * config.block_size

            if 50 < stats['l0'] < 150 and stats['recon'] < 0.1:
                print(f"Early stop: L0={stats['l0']:.1f}, Recon={stats['recon']:.4f}")
                # Return (not just break) so the early stop also ends the
                # epoch loop; a bare `break` would resume in the next epoch.
                return tokens_seen

            if trainer.step % 50 == 0:
                print(f"Step {trainer.step}/{trainer.total_steps}: "
                      f"Loss={stats['loss']:.4f}, L0={stats['l0']:.1f}, "
                      f"Recon={stats['recon']:.4f}, Tokens={tokens_seen:,}")

            if trainer.step >= trainer.total_steps:
                return tokens_seen
    return tokens_seen


def _report_attribution_graphs(graph_builder, data, n_samples=3):
    """Print full, pruned, supernode and (optionally) TWERA graphs for the first samples."""
    for i in range(min(n_samples, len(data))):
        test_input = data[i:i+1]
        graph = graph_builder.compute_full_graph(test_input, target_pos=-1, use_twera=False)
        print(f"\n--- Sample {i+1} (Full Graph) ---")
        graph_builder.print_summary(graph)

        pruned_graph = graph_builder.prune_graph(graph, node_threshold=0.8, edge_threshold=0.98)
        print(f"\n--- Sample {i+1} (Pruned Graph, 80% nodes, 98% edges) ---")
        graph_builder.print_summary(pruned_graph)

        supernode_graph = graph_builder.group_into_supernodes(pruned_graph, similarity_threshold=0.7)
        print(f"\n--- Sample {i+1} (Supernode Graph) ---")
        print(f"Supernodes: {len(supernode_graph['supernodes'])}")
        for sn in supernode_graph['supernodes'][:3]:
            print(f" {sn['id']}: {sn['size']} features from layers {sn['layers']}")
        print(f"Aggregated edges: {len(supernode_graph['edges'])}")
        top_sn_edges = sorted(supernode_graph['edges'], key=lambda e: abs(e['weight']), reverse=True)[:3]
        for e in top_sn_edges:
            print(f" {e['source']}->{e['target']}: {e['weight']:.3f} ({e['count']} edges)")

        # TWERA filtering is only attempted once enough co-activation
        # statistics have been collected by the graph builder.
        if len(graph_builder.coactivation_counts) > 10:
            graph_twera = graph_builder.compute_full_graph(test_input, target_pos=-1, use_twera=True)
            print(f"\n--- Sample {i+1} (With TWERA filtering) ---")
            graph_builder.print_summary(graph_twera)


def train_and_evaluate():
    """Train a cross-layer transcoder on Pile samples and print attribution graphs.

    Steps:
    1. Use the active Snowpark session and switch it to ``MY_PILE_DATABASE``.
    2. Build a ``MinimalTransformer`` and a ``CrossLayerTranscoder`` on CPU
       and load up to 500 character-level samples.
    3. Train the CLT with ``CLTTrainer``. The base transformer is not
       trained in this function.
    4. Print attribution-graph summaries for the first three samples.

    Returns:
        Tuple ``(base_model, clt, graph_builder)``.
    """
    session = get_active_session()

    # Set database to MY_PILE_DATABASE
    session.sql("USE DATABASE MY_PILE_DATABASE").collect()

    device = torch.device('cpu')
    config = MinimalConfig()

    print(f"Config: {config.n_layer}L, {config.n_features}F, {config.n_embd}d")
    print(f"Target training tokens (scaled for CPU): {config.n_training_tokens:,}")

    base_model = MinimalTransformer(config).to(device)
    clt = CrossLayerTranscoder(config).to(device)

    # Pass the configured vocabulary so token ids always fit the model.
    data = get_pile_data(session, max_samples=500, seq_length=config.block_size,
                         vocab_size=config.vocab_size)
    data = data.to(device)
    print(f"Dataset: {len(data)} samples")

    trainer = CLTTrainer(base_model, clt, config)
    batch_size = 16
    start_time = time.time()

    print("\nTraining CLT...")
    print(f"Total training steps: {trainer.total_steps}")

    tokens_seen = _train_clt(trainer, data, config, batch_size=batch_size)

    print(f"\nTraining completed in {time.time() - start_time:.1f}s")
    print(f"Total tokens seen: {tokens_seen:,}")

    print("\nComputing Attribution Graphs...")
    graph_builder = AttributionGraph(base_model, clt)
    _report_attribution_graphs(graph_builder, data)

    print(f"\nTotal runtime: {time.time() - start_time:.1f}s")
    return base_model, clt, graph_builder

if __name__ == "__main__":
    # Install any missing dependencies into the *running* interpreter's
    # environment before training. `sys.executable -m pip` is valid Python
    # (unlike the IPython-only `!pip` magic). It also cannot install into a
    # different environment than the one this kernel uses.
    import importlib
    import subprocess
    import sys

    # importable module name -> pip distribution name
    _REQUIRED_PACKAGES = {
        "torch": "torch",
        "numpy": "numpy",
        "snowflake.snowpark": "snowflake-snowpark-python",
    }
    _missing = []
    for _module_name, _dist_name in _REQUIRED_PACKAGES.items():
        try:
            importlib.import_module(_module_name)
        except ImportError:
            _missing.append(_dist_name)
    if _missing:
        subprocess.check_call([sys.executable, "-m", "pip", "install", *_missing])

    base_model, clt, graph_builder = train_and_evaluate()
    