In [8]:
import torch
import torch.nn.functional as F
import os
from pathlib import Path

def compute_mean_kl_divergence_from_folders(folder1, folder2):
    """
    Compute mean KL divergence between logits stored in two folders.
    Loads one pair of files at a time to minimize memory usage.
    
    Args:
        folder1: Path to first folder containing logits_*.pt files
        folder2: Path to second folder containing logits_*.pt files
    
    Returns:
        Mean KL divergence across all samples
    """
    # Get all logit files from folder1
    logit_files1 = sorted([f for f in os.listdir(folder1) if f.startswith("logits_") and f.endswith(".pt")])
    logit_files2 = sorted([f for f in os.listdir(folder2) if f.startswith("logits_") and f.endswith(".pt")])
    
    assert len(logit_files1) == len(logit_files2), f"Folders have different number of files: {len(logit_files1)} vs {len(logit_files2)}"
    
    total_kl_sum = 0.0
    total_count = 0
    
    # Process one pair at a time
    for i, (file1, file2) in enumerate(zip(logit_files1, logit_files2)):
        print(f"Processing pair {i+1}/{len(logit_files1)}: {file1}")
        
        # Load one pair
        logits_a = torch.load(os.path.join(folder1, file1))
        logits_b = torch.load(os.path.join(folder2, file2))
        
        assert logits_a.shape == logits_b.shape, f"Shape mismatch: {logits_a.shape} vs {logits_b.shape}"
        
        # Convert logits to log probabilities
        log_probs_a = F.log_softmax(logits_a, dim=-1)
        log_probs_b = F.log_softmax(logits_b, dim=-1)
        
        # Convert logits to probabilities for the first distribution
        probs_a = F.softmax(logits_a, dim=-1)
        
        # Compute KL divergence
        kl_div = (probs_a * (log_probs_a - log_probs_b)).sum(dim=-1)
        
        # Update running statistics
        total_kl_sum += kl_div.sum().item()
        total_count += kl_div.numel()
        
        # Clean up memory
        del logits_a, logits_b, log_probs_a, log_probs_b, probs_a, kl_div
        torch.cuda.empty_cache()  # If using GPU
    
    mean_kl_div = total_kl_sum / total_count
    return mean_kl_div

# Usage
# folder1 = "v1_logits_mistralai_Ministral-8B-Instruct-2410"
folder2 = "logits_mistralai_Ministral-8B-Instruct-2410"
folder1 = "bf16_ablation_logits_mistralai_Ministral-8B-Instruct-2410"
# folder1 = "fp32_ablation_logits_mistralai_Ministral-8B-Instruct-2410"
# folder2 = "ablation_logits_mistralai_Ministral-8B-Instruct-2410"
folder2 = "ablation_logits_orthogonalized_mistralai_Ministral-8B-Instruct-2410"

mean_kl = compute_mean_kl_divergence_from_folders(folder1, folder2)
print(f"Mean KL divergence: {mean_kl:.6f}")

Processing pair 1/9: logits_0.pt
Processing pair 2/9: logits_1.pt
Processing pair 3/9: logits_2.pt
Processing pair 4/9: logits_3.pt
Processing pair 5/9: logits_4.pt
Processing pair 6/9: logits_5.pt
Processing pair 7/9: logits_6.pt
Processing pair 8/9: logits_7.pt
Processing pair 9/9: logits_8.pt
Mean KL divergence: 0.026359
