In [None]:
class FactorGaussian:
    """Gaussian stored with a factored covariance.

    The covariance is represented as ``Σ = L @ L.T + diag(diagonal)`` so the
    dense matrix only has to be built when it is explicitly requested.
    """

    def __init__(self, mean, L_factor, diagonal):
        """Keep the mean, the factor ``L`` and the diagonal term of Σ = L @ L.T + diag(diagonal)."""
        self.mean, self.L, self.diagonal = mean, L_factor, diagonal

    def covariance(self):
        """Return the dense covariance ``L @ L.T + diag(diagonal + 1e-8)``."""
        # A small constant is added to the diagonal term before it is expanded
        # into a matrix; note that get_diagonal() does not include it.
        jittered = self.diagonal + 1e-8
        low_rank_part = self.L @ self.L.T
        return low_rank_part + th.diag(jittered)

    def get_diagonal(self):
        """Return the diagonal of ``L @ L.T + diag(diagonal)`` without building the matrix."""
        # Row-wise sum of squares of L equals diag(L @ L.T).
        row_sq_norms = (self.L * self.L).sum(dim=1)
        return row_sq_norms + self.diagonal

def propagate_by_sampling(input_distribution, layer_block, mask=None, num_samples=1000, device='cpu'):
    """Push a Gaussian through ``layer_block`` by sampling and refit a ``FactorGaussian``.

    Samples are drawn from the input Gaussian, passed through ``layer_block``
    under ``no_grad``, and the empirical mean and covariance of the outputs
    are converted back into low-rank-plus-diagonal form.

    Args:
        input_distribution: Object exposing ``mean`` (a 1-D tensor) and
            ``covariance()`` (the matching dense covariance matrix).
        layer_block: Callable applied to the samples after a singleton axis
            is inserted at dim 1; its output is squeezed on dim 1 again.
        mask: Optional collection of ``(row, col)`` pairs naming the
            covariance entries to keep. Pairs with an index ``>=`` the output
            dimension are ignored. ``None`` or an empty mask keeps only the
            diagonal.
        num_samples: Number of samples to draw; must be at least 2 because
            the covariance estimate divides by ``num_samples - 1``.
        device: Device the samples are moved to before calling ``layer_block``.

    Returns:
        A ``FactorGaussian`` holding the empirical output mean, a factor ``L``
        built from the positive-eigenvalue part of the kept off-diagonal
        covariance, and a non-negative diagonal remainder.

    Raises:
        ValueError: If ``num_samples`` is smaller than 2.
    """
    if num_samples < 2:
        # Previously this surfaced as a bare ZeroDivisionError further down.
        raise ValueError(f"num_samples must be at least 2, got {num_samples}")

    input_mean = input_distribution.mean
    input_cov = input_distribution.covariance()
    input_dim = input_mean.shape[0]

    print(f"    - Propagating through layer: {layer_block.__class__.__name__}")
    print(f"    - Input mean shape: {input_mean.shape}")
    print(f"    - Input covariance shape: {input_cov.shape}")
    print(f"    - Number of samples for propagation: {num_samples}")

    epsilon = 1e-6 * th.eye(input_dim, device=input_mean.device)
    try:
        input_dist_sampler = MultivariateNormal(loc=input_mean, covariance_matrix=input_cov + epsilon)
    except (ValueError, th.linalg.LinAlgError):
        # With argument validation enabled (the default) MultivariateNormal
        # reports a non-positive-definite covariance as ValueError; the
        # Cholesky step raises LinAlgError only when validation is disabled.
        print("    - WARNING: Covariance matrix not positive definite. Adding larger epsilon.")
        epsilon = 1e-4 * th.eye(input_dim, device=input_mean.device)
        input_dist_sampler = MultivariateNormal(loc=input_mean, covariance_matrix=input_cov + epsilon)

    samples = input_dist_sampler.sample(th.Size([num_samples]))
    print(f"    - Drawn samples shape: {samples.shape}")

    # Insert a singleton axis at dim 1 before calling the block; it is
    # squeezed away again on the block's output.
    samples_reshaped = samples.unsqueeze(1)

    with th.no_grad():
        output_samples_reshaped = layer_block(samples_reshaped.to(device))

    output_samples = output_samples_reshaped.squeeze(1)
    print(f"    - Output samples shape after layer block: {output_samples.shape}")

    # Unbiased empirical mean / covariance over the sample axis.
    output_mean = th.mean(output_samples, dim=0)
    mean_subtracted_output = output_samples - output_mean
    full_output_covariance = (1 / (num_samples - 1)) * mean_subtracted_output.T @ mean_subtracted_output
    output_dim = output_mean.shape[0]

    print(f"    - Output mean shape: {output_mean.shape}")
    print(f"    - Full output covariance shape: {full_output_covariance.shape}")

    if mask is not None and len(mask) > 0:
        print(f"    - Applying mask with {len(mask)} entries.")
        masked_cov = th.zeros_like(full_output_covariance)
        valid_mask_entries = [(r, c) for r, c in mask if r < output_dim and c < output_dim]
        if valid_mask_entries:
            index = th.tensor(valid_mask_entries, dtype=th.long, device=full_output_covariance.device)
            rows, cols = index[:, 0], index[:, 1]
            masked_cov[rows, cols] = full_output_covariance[rows, cols]
        else:
            # Every requested entry was out of range: nothing to copy, so the
            # all-zero masked matrix leads to the rank-0 (diagonal) result.
            print("    - WARNING: No mask entries within bounds, using diagonal covariance.")
    else:
        print("    - No mask provided, using diagonal covariance.")
        masked_cov = th.diag(th.diag(full_output_covariance))

    # Variances always come from the full empirical covariance, independent
    # of which entries the mask kept.
    output_diag_variances = th.diag(full_output_covariance)
    off_diag_cov = masked_cov - th.diag(th.diag(masked_cov))

    try:
        # eigh returns eigenvalues in ascending order, so the eigenpairs above
        # the tolerance are the trailing `rank` columns.
        eigvals, eigvecs = th.linalg.eigh(off_diag_cov)
        eigvals_positive = th.clamp(eigvals, min=0)
        sqrt_eigvals = th.sqrt(eigvals_positive)
        tol = 1e-6
        rank = th.sum(eigvals_positive > tol).item()
        print(f"    - Off-diagonal covariance matrix rank: {rank}")

        if rank > 0:
            # Negative eigenvalues are discarded, so L @ L.T only approximates
            # the kept off-diagonal part; whatever L contributes to the
            # diagonal is subtracted from the variances (floored at 0).
            output_L = eigvecs[:, -rank:] @ th.diag(sqrt_eigvals[-rank:])
            l_diag_contribution = th.sum(output_L * output_L, dim=1)
            output_diag_remainder = th.clamp(output_diag_variances - l_diag_contribution, min=0)
        else:
            output_L = th.zeros((output_dim, 1), device=output_mean.device)
            output_diag_remainder = output_diag_variances
            print("    - Rank is 0, using diagonal remainder for covariance.")

    except th.linalg.LinAlgError:
        print("    - WARNING: Eigendecomposition failed. Falling back to diagonal approximation.")
        output_L = th.zeros((output_dim, 1), device=output_mean.device)
        output_diag_remainder = output_diag_variances

    print(f"    - L factor shape: {output_L.shape}")
    print(f"    - Diagonal remainder shape: {output_diag_remainder.shape}")

    return FactorGaussian(mean=output_mean, L_factor=output_L, diagonal=output_diag_remainder)


def gaussian_sampling_estimator(model, orig_dists: list[Discrete], target: int, *, n_samples: int, batch_size: int, n_off_diagonal_entries: int = 0, show_progress: bool = False) -> float:
    """Estimate the probability of ``target`` via Gaussian propagation through ``model``.

    The embedded input activations are summarised by a single Gaussian, which
    is pushed through every entry of ``model.blocks`` and ``model.ln_final``
    with ``propagate_by_sampling``. A first pass with a full covariance mask
    ranks the off-diagonal covariance entries by magnitude; a second pass
    keeps the diagonal plus the ``n_off_diagonal_entries`` largest of them
    (ranked across all layers). Final activations are sampled, unembedded
    with ``model.unembed.W_U`` and soft-maxed.

    Args:
        model: Model exposing ``embed``, ``pos_embed``, ``blocks``,
            ``ln_final``, ``unembed.W_U``, ``device`` and ``cfg['d_model']``.
        orig_dists: One distribution per input position; each is sampled with
            ``dist.sample((n_samples,))`` and the results are stacked on dim 1
            (presumably giving token ids of shape (n_samples, n_positions) —
            TODO confirm against ``Discrete``).
        target: Index into the last axis of the soft-max output.
        n_samples: Number of input samples and of final activation samples.
        batch_size: Number of samples drawn per layer in the masked pass.
        n_off_diagonal_entries: How many of the largest off-diagonal entries
            to keep (each is added symmetrically). Must be non-negative.
        show_progress: Accepted for interface compatibility; not used here.

    Returns:
        The mean soft-max probability of ``target`` over the final samples.

    Raises:
        ValueError: If ``n_off_diagonal_entries`` is negative.
    """
    if n_off_diagonal_entries < 0:
        # A negative count would otherwise be read as a slice end and quietly
        # select almost every ranked entry.
        raise ValueError(
            f"n_off_diagonal_entries must be non-negative, got {n_off_diagonal_entries}"
        )

    d_model = model.cfg['d_model']

    # Run the initial samples through the first layer to create the initial distribution we'll draw from
    initial_samples = th.stack([dist.sample((n_samples,)) for dist in orig_dists], dim=1)
    with th.no_grad():
        tok_emb = model.embed(initial_samples)
        pos_emb = model.pos_embed(initial_samples)
        initial_activations = tok_emb + pos_emb
        # All leading axes are flattened into one sample axis of d_model-vectors.
        flat_initial_activations = initial_activations.reshape(-1, d_model)
    mean = th.mean(flat_initial_activations, dim=0)
    cov = th.cov(flat_initial_activations.T)
    epsilon_diag = 1e-6 * th.eye(cov.shape[0], device=cov.device)
    # Compute Cholesky decomposition for initial covariance (very helpful when we start adding off-diagonal entries)
    L = th.linalg.cholesky(cov + epsilon_diag)
    diag = th.zeros_like(mean)
    current_dist = FactorGaussian(mean=mean, L_factor=L, diagonal=diag)

    # Propagate with no masking
    print(f"MODEL BLOCKS: {model.blocks}")
    grouped_layers = model.blocks

    # Loop-invariant helpers: the full mask and the strict upper-triangle
    # indices (row-major order, matching a nested r < c loop).
    full_mask = {(r, c) for r in range(d_model) for c in range(d_model)}
    upper_index = th.triu_indices(d_model, d_model, offset=1)
    upper_rows, upper_cols = upper_index.tolist()

    all_covariance_entries = []
    temp_dist = current_dist
    for layer_idx, layer_block in enumerate(grouped_layers):
        print(f"Output Dim: {d_model}")
        temp_dist = propagate_by_sampling(temp_dist, layer_block, mask=full_mask, num_samples=10000, device=model.device)
        full_cov = temp_dist.covariance()
        # Read all magnitudes in one transfer instead of one .item() per entry.
        index_on_device = upper_index.to(full_cov.device)
        magnitudes = full_cov[index_on_device[0], index_on_device[1]].abs().tolist()
        all_covariance_entries.extend(
            (magnitude, layer_idx, r, c)
            for magnitude, r, c in zip(magnitudes, upper_rows, upper_cols)
        )
    all_covariance_entries.sort(key=lambda x: x[0], reverse=True)

    # Propagate with masking (just for fun)
    masks = [{(i, i) for i in range(d_model)} for _ in grouped_layers]
    for _, layer_idx, r, c in all_covariance_entries[:n_off_diagonal_entries]:
        # Keep the mask symmetric.
        masks[layer_idx].add((r, c))
        masks[layer_idx].add((c, r))
    for i, layer_block in enumerate(grouped_layers):
        current_dist = propagate_by_sampling(current_dist, layer_block, mask=masks[i], num_samples=batch_size, device=model.device)
    # Propagate through final layer norm
    final_ln_mask = {(i, i) for i in range(d_model)}
    current_dist = propagate_by_sampling(current_dist, model.ln_final, mask=final_ln_mask, num_samples=batch_size, device=model.device)

    # 3. Final probability estimation
    final_mean = current_dist.mean
    final_cov = current_dist.covariance()
    final_dist_sampler = MultivariateNormal(loc=final_mean, covariance_matrix=final_cov + epsilon_diag)
    final_act_samples = final_dist_sampler.sample(th.Size([n_samples]))
    print(f"  - Sampled final activations, shape: {final_act_samples.shape}")

    # Get logits and probabilities
    with th.no_grad():
        logits = final_act_samples @ model.unembed.W_U
        probs = th.softmax(logits, dim=-1)

    print(f"  - Computed final logits and probabilities, shape: {probs.shape}")

    # Average probability of the target token
    target_probs = probs[:, target]
    estimated_prob = target_probs.mean().item()
    print(f"\n--- Estimated Probability for target {target}: {estimated_prob:.6e} ---")

    return estimated_prob

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]