# Approach 3: Autoencoder-based Alignment with MLP Prober


## Pipeline Overview:
1. **Autoencoder for Teacher**: Learn to compress Teacher activations to latent dimension ($X_T \to Z_T$)
2. **Autoencoder for Student**: Learn to compress Student activations to the same latent dimension ($X_S \to Z_S$)
3. **MLP Prober on Teacher**: Train an MLP classifier on the Teacher's latent representations ($Z_T \to y$)
4. **Alignment Network**: Learn to align Student's latent space to Teacher's latent space ($Z_S \to Z_T$)
5. **Evaluation**: Test the aligned Student representations on the Teacher's MLP prober



In [1]:
import json
import os
import numpy as np
import matplotlib.pyplot as plt
import gc
import seaborn as sns
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix
import traceback
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
import random

# ==================================================================
# REPRODUCIBILITY SETTINGS
# ==================================================================
SEED = 42


def set_seed(seed=SEED):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, current CUDA device and all CUDA devices), puts cuDNN into
    deterministic / non-benchmark mode, and stores the seed in the
    ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # cuDNN: favour repeatable kernels over auto-tuned ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers multi-GPU setups
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Seed once as soon as this cell runs.
set_seed(SEED)

In [2]:
# Project root: three directory levels above the current working directory.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))
# Directory (under PROJECT_ROOT) that holds the cached activations and the
# hallucination label files read by stats_per_json / load_and_split_layers.
CACHE_DIR_NAME = "activation_cache"
# Hugging Face cache location: $HF_HOME when set, otherwise a Windows-style path.
# NOTE(review): the fallback contains a literal "~" that is not expanded here, and
# this constant is not referenced in the visible part of the notebook — confirm usage.
HF_DEFAULT_HOME = os.environ.get("HF_HOME", "~\\.cache\\huggingface\\hub")

# Layer configuration
# Maps model name -> activation type ("attn", "mlp", "hidden") -> layer indices.
# NOTE(review): presumably these index lists are what the experiment loop passes
# to load_and_split_layers as `layer_indices` — the call site is not fully visible.
LAYER_CONFIG = {
    "Qwen2.5-7B": 
    {
        "attn": [15,16,18],
        "mlp":[16,18,20],
        "hidden": [18,19,20]
    },    
    "Falcon3-7B-Base": 
    {
        "attn": [2,7,12],
        "mlp":[10,11,12],
        "hidden": [2,3,19]
    }
}

# ==================================================================
# AUTOENCODER CONFIGURATION
# ==================================================================
# Default hyper-parameters for train_autoencoder (used for both the teacher
# and the student autoencoder).
# NOTE: "optimizer", "scheduler" and "loss_function" are descriptive labels
# only — train_autoencoder hard-codes AdamW, CosineAnnealingLR and MSELoss
# and never reads these three keys.
AUTOENCODER_CONFIG = {
    "latent_dim": 128,                   # width of the latent code Z
    "hidden_dim": 256,                   # first encoder width; second is hidden_dim // 2
    "dropout": 0.2,
    "learning_rate": 1e-3,
    "weight_decay": 0.01,
    "batch_size": 64,
    "max_epochs": 300,                   # also used as T_max of the cosine schedule
    "early_stopping_patience": 30,       # epochs without improvement before stopping
    "early_stopping_min_delta": 1e-4,    # minimum val-loss decrease that counts as improvement
    "gradient_clip_max_norm": 1.0,
    "optimizer": "AdamW",
    "scheduler": "CosineAnnealingLR",
    "loss_function": "MSELoss"
}

# ==================================================================
# ALIGNMENT CONFIGURATION
# ==================================================================
# Default hyper-parameters for train_alignment_network.
# NOTE: "optimizer" and "scheduler" are descriptive labels only — the training
# function hard-codes AdamW and CosineAnnealingLR and never reads these keys.
ALIGNMENT_CONFIG = {
    "hidden_dim": 256,                   # width of the residual branch in AlignmentNetwork
    "dropout": 0.3,
    "learning_rate": 1e-3,
    "weight_decay": 0.01,
    "batch_size": 32,
    "max_epochs": 500,                   # also used as T_max of the cosine schedule
    "early_stopping_patience": 50,       # epochs without improvement before stopping
    "early_stopping_min_delta": 1e-4,    # minimum val-loss decrease that counts as improvement
    "gradient_clip_max_norm": 1.0,
    "optimizer": "AdamW",
    "scheduler": "CosineAnnealingLR",
    "loss_alpha": 0.5,  # MSE weight (MixedLoss alpha)
    "loss_beta": 0.5    # Cosine weight (MixedLoss beta)
}

# ==================================================================
# MLP PROBER CONFIGURATION
# ==================================================================
# Default hyper-parameters for train_mlp_prober.
# NOTE: "type", "optimizer", "scheduler" and "loss_function" are descriptive
# labels only — train_mlp_prober always builds an MLPProber trained with AdamW,
# CosineAnnealingLR and BCEWithLogitsLoss and never reads these four keys.
PROBER_CONFIG = {
    "type": "MLPProber",
    "hidden_dim": 64,                    # overrides MLPProber's constructor default of 512
    "dropout": 0.3,
    "learning_rate": 1e-3,
    "weight_decay": 0.01,
    "batch_size": 64,
    "max_epochs": 200,                   # also used as T_max of the cosine schedule
    "early_stopping_patience": 30,       # epochs without improvement before stopping
    "early_stopping_min_delta": 1e-4,    # minimum val-accuracy gain that counts as improvement
    "gradient_clip_max_norm": 1.0,
    "optimizer": "AdamW",
    "scheduler": "CosineAnnealingLR",
    "loss_function": "BCEWithLogitsLoss",
    "use_class_weights": True            # weight the positive class by n_neg / n_pos
}

### Dataset preparation

In [3]:
def stats_per_json(model_name, dataset_name):
    """Summarise the hallucination labels stored for one model/dataset pair.

    Reads ``<PROJECT_ROOT>/<CACHE_DIR_NAME>/<model_name>/<dataset_name>/
    generations/hallucination_labels.json``, which must contain a list of
    records with at least the keys ``instance_id`` and ``is_hallucination``.

    Args:
        model_name: Name of the model sub-directory.
        dataset_name: Name of the dataset sub-directory.

    Returns:
        dict with the keys ``total`` (number of records), ``hallucinations``
        (number of records flagged as hallucination), ``percent_hallucinations``
        (0 when the file is empty), ``hallucinated_items`` (``instance_id`` of
        every flagged record, in file order), ``model_name`` and
        ``dataset_name``.

    Raises:
        OSError: If the label file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    file_path = os.path.join(PROJECT_ROOT, CACHE_DIR_NAME, model_name, dataset_name,
                             "generations", "hallucination_labels.json")
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    total = len(data)
    # Single pass: the count is simply the length of the collected id list.
    hallucinated_items = [item['instance_id'] for item in data if item['is_hallucination']]
    hallucinations = len(hallucinated_items)
    percent_hallucinations = (hallucinations / total) * 100 if total > 0 else 0

    return {
        'total': total,
        'hallucinations': hallucinations,
        'percent_hallucinations': percent_hallucinations,
        'hallucinated_items': hallucinated_items,
        'model_name': model_name,
        'dataset_name': dataset_name
    }


# Label statistics for both models on the "belief_bank" dataset.
qwen_stats = stats_per_json("Qwen2.5-7B", "belief_bank")
falcon_stats = stats_per_json("Falcon3-7B-Base", "belief_bank")

### Model Definitions

In [4]:
# ------------------------------------------------------------------
# 1. Dataset class for Autoencoder Training
# ------------------------------------------------------------------
class AutoencoderDataset(Dataset):
    """Unlabelled dataset that yields one row of ``X`` per index."""

    def __init__(self, X: torch.Tensor):
        # Kept by reference; no copy is made.
        self.X = X

    def __getitem__(self, idx):
        sample = self.X[idx]
        return sample

    def __len__(self):
        n_rows = self.X.shape[0]
        return n_rows

# ------------------------------------------------------------------
# 2. Dataset class for Alignment
# ------------------------------------------------------------------
class AlignmentDataset(Dataset):
    """Paired dataset yielding ``(source_row, target_row)`` tuples.

    Row ``i`` of ``x_source`` is paired with row ``i`` of ``x_target``, so the
    two tensors must have the same number of rows.

    Raises:
        ValueError: If ``x_source`` and ``x_target`` differ in their first
            dimension (the pairs would otherwise be silently misaligned or
            fail with an IndexError in the middle of an epoch).
    """

    def __init__(self, x_source: torch.Tensor, x_target: torch.Tensor):
        if x_source.shape[0] != x_target.shape[0]:
            raise ValueError(
                f"x_source and x_target must have the same number of rows, "
                f"got {x_source.shape[0]} and {x_target.shape[0]}"
            )
        self.x_source = x_source
        self.x_target = x_target

    def __len__(self):
        return self.x_source.shape[0]

    def __getitem__(self, idx):
        return self.x_source[idx], self.x_target[idx]

# ------------------------------------------------------------------
# 3. Dataset class for Classification
# ------------------------------------------------------------------
class ClassificationDataset(Dataset):
    """Labelled dataset yielding ``(features_row, label)`` tuples.

    Row ``i`` of ``X`` is paired with element ``i`` of ``y``, so both tensors
    must have the same length along their first dimension.

    Raises:
        ValueError: If ``X`` and ``y`` differ in their first dimension (the
            features and labels would otherwise be silently misaligned or
            fail with an IndexError in the middle of an epoch).
    """

    def __init__(self, X: torch.Tensor, y: torch.Tensor):
        if X.shape[0] != y.shape[0]:
            raise ValueError(
                f"X and y must have the same number of rows, "
                f"got {X.shape[0]} and {y.shape[0]}"
            )
        self.X = X
        self.y = y

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# ------------------------------------------------------------------
# 4. Autoencoder for Dimensionality Reduction
# ------------------------------------------------------------------
class Autoencoder(nn.Module):
    """Symmetric MLP autoencoder used to compress activations.

    Encoder: input_dim -> hidden_dim -> hidden_dim // 2 -> latent_dim, with
    LayerNorm + GELU + Dropout after the first two linear layers and a final
    LayerNorm on the latent code. The decoder mirrors the encoder and ends in
    a plain linear layer back to ``input_dim``.
    """

    def __init__(self, input_dim: int, latent_dim: int, hidden_dim: int = 256, dropout: float = 0.2):
        super().__init__()
        half_dim = hidden_dim // 2

        def stage(n_in, n_out):
            # One Linear -> LayerNorm -> GELU -> Dropout stage.
            return [nn.Linear(n_in, n_out), nn.LayerNorm(n_out), nn.GELU(), nn.Dropout(dropout)]

        # Layers are created in the same order as their Sequential indices,
        # which keeps state_dict keys (encoder.0, encoder.4, ...) stable.
        self.encoder = nn.Sequential(
            *stage(input_dim, hidden_dim),
            *stage(hidden_dim, half_dim),
            nn.Linear(half_dim, latent_dim),
            nn.LayerNorm(latent_dim),
        )
        self.decoder = nn.Sequential(
            *stage(latent_dim, half_dim),
            *stage(half_dim, hidden_dim),
            nn.Linear(hidden_dim, input_dim),
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal weights and zero biases for every linear layer."""
        for layer in self.modules():
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def encode(self, x):
        """Map inputs to their latent codes."""
        return self.encoder(x)

    def decode(self, z):
        """Map latent codes back to input space."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, latent_code)`` for ``x``."""
        latent = self.encode(x)
        return self.decode(latent), latent
    

# ------------------------------------------------------------------
# 5. AlignmentNetwork
# ------------------------------------------------------------------
class AlignmentNetwork(nn.Module):
    """Residual MLP mapping one latent space onto another.

    The input is first brought to ``output_dim`` (bias-free linear projection,
    or identity when the dimensions already match); a small MLP then adds a
    residual correction. The MLP's last linear layer starts at zero, so in
    eval mode a freshly built network returns exactly the projected input.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 256, dropout: float = 0.3):
        super().__init__()

        same_width = input_dim == output_dim
        self.input_proj = nn.Identity() if same_width else nn.Linear(input_dim, output_dim, bias=False)

        residual_layers = [
            nn.Linear(output_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*residual_layers)

        self._init_zero()

    def _init_zero(self):
        """Zero the last linear layer so the residual branch starts at 0."""
        final_linear = self.net[-2]
        nn.init.zeros_(final_linear.weight)
        if final_linear.bias is not None:
            nn.init.zeros_(final_linear.bias)

    def forward(self, x):
        """Return the projected input plus the learned residual."""
        base = self.input_proj(x)
        correction = self.net(base)
        return base + correction
    

# ------------------------------------------------------------------
# 6. MLP Prober
# ------------------------------------------------------------------
class MLPProber(nn.Module):
    """Binary MLP classifier that emits one logit per sample.

    Architecture: input_dim -> hidden_dim -> hidden_dim // 2 -> 1, with
    LayerNorm + GELU + Dropout after each of the first two linear layers.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 512, dropout: float = 0.3):
        super().__init__()
        half_dim = hidden_dim // 2

        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, half_dim),
            nn.LayerNorm(half_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(half_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal weights and zero biases for every linear layer."""
        for layer in self.modules():
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return raw logits with the trailing singleton dimension removed."""
        logits = self.net(x)
        return logits.squeeze(-1)

    @torch.no_grad()
    def predict(self, x):
        """Return hard 0/1 predictions (probability > 0.5), without gradients."""
        probabilities = torch.sigmoid(self.forward(x))
        return (probabilities > 0.5).long()

    @torch.no_grad()
    def predict_proba(self, x):
        """Return positive-class probabilities, without gradients."""
        return torch.sigmoid(self.forward(x))


# ------------------------------------------------------------------
# 7. MixedLoss for Alignment
# ------------------------------------------------------------------
class MixedLoss(nn.Module):
    """Weighted sum of MSE and cosine distance between two batches.

    ``loss = alpha * MSE(pred, target) + beta * (1 - mean cosine similarity)``,
    with the cosine similarity taken along dimension 1.
    """

    def __init__(self, alpha=0.5, beta=0.5):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.mse = nn.MSELoss()

    def forward(self, pred, target):
        mse_term = self.mse(pred, target)
        mean_similarity = F.cosine_similarity(pred, target, dim=1).mean()
        cosine_term = 1 - mean_similarity
        return self.alpha * mse_term + self.beta * cosine_term

### Training Functions

In [5]:
def load_and_split_layers(model_name, dataset_name, layer_indices, type_layer, stats, train_indices, test_indices):
    """Load cached activations fully into RAM and split them into train/test.

    For each requested layer the file
    ``<PROJECT_ROOT>/<CACHE_DIR_NAME>/<model_name>/<dataset_name>/
    activation_<type_layer>/layer<idx>_activations.pt`` is loaded, truncated to
    ``stats['total']`` rows if it has more, converted to float32, flattened to
    2-D and concatenated with the other layers along the feature axis. Layers
    whose file is missing are skipped with a warning.

    Labels are 1 at every position listed in ``stats['hallucinated_items']``
    and 0 elsewhere.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` as NumPy arrays.

    Raises:
        ValueError: If none of the requested layer files exists.
    """
    print(f" Loading IN-MEMORY {model_name} [{type_layer}]: layers {layer_indices}...")

    n_total = stats['total']
    positive_ids = set(stats['hallucinated_items'])

    # Binary labels: the hallucinated ids are used directly as row positions.
    labels = np.zeros(n_total, dtype=np.int8)
    labels[list(positive_ids)] = 1
    y_train, y_test = labels[train_indices], labels[test_indices]

    # Collect one 2-D float32 matrix per available layer.
    per_layer = []

    for layer_idx in layer_indices:
        file_path = os.path.join(PROJECT_ROOT, CACHE_DIR_NAME, model_name, dataset_name,
                                 "activation_"+type_layer, f"layer{layer_idx}_activations.pt")
        if not os.path.exists(file_path):
            print(f" Warning: Layer {layer_idx} not found. Skipping.")
            continue

        print(f"  Loading layer {layer_idx}...", end=" ")
        acts = torch.load(file_path, map_location='cpu')

        # Drop surplus rows beyond the labelled sample count.
        if acts.shape[0] > n_total:
            acts = acts[:n_total]

        layer_matrix = acts.float().numpy() if isinstance(acts, torch.Tensor) else acts.astype(np.float32)

        # Flatten any extra dimensions into the feature axis.
        if layer_matrix.ndim > 2:
            layer_matrix = layer_matrix.reshape(layer_matrix.shape[0], -1)

        per_layer.append(layer_matrix)
        print(f"done ({layer_matrix.shape})")

        del acts
        gc.collect()

    if not per_layer:
        raise ValueError(f"No valid layers found for {model_name}")

    print(" Concatenating layers...")
    features = np.concatenate(per_layer, axis=1)

    X_train, X_test = features[train_indices], features[test_indices]

    print(f" Completed! Train: {X_train.shape}, Test: {X_test.shape}")

    return X_train, X_test, y_train, y_test


def get_generator(seed=SEED):
    """Return a fresh ``torch.Generator`` seeded with ``seed``.

    Intended to be passed to a DataLoader so its shuffling is reproducible.
    """
    # manual_seed returns the generator itself, so this can be one expression.
    return torch.Generator().manual_seed(seed)


def train_autoencoder(X_train, X_val, input_dim, device, model_name, autoencoder_config=AUTOENCODER_CONFIG):
    """Train an Autoencoder on reconstruction (MSE) loss with early stopping.

    Uses AdamW with a cosine-annealed learning rate and gradient-norm
    clipping. After training, the weights of the epoch with the lowest
    validation loss are restored.

    Args:
        X_train: Tensor of training samples, one per row. Batches are fed to
            the model as-is, so the tensor must already live on ``device``.
        X_val: Tensor of validation samples (same requirement).
        input_dim: Width of each sample.
        device: Device the autoencoder is moved to.
        model_name: Label used only in progress messages.
        autoencoder_config: Hyper-parameter dict; the keys read here are
            latent_dim, hidden_dim, dropout, learning_rate, weight_decay,
            batch_size, max_epochs, early_stopping_patience,
            early_stopping_min_delta and gradient_clip_max_norm.

    Returns:
        Tuple ``(autoencoder, best_val_loss, epochs_run)``.
    """
    latent_dim = autoencoder_config['latent_dim']
    hidden_dim = autoencoder_config['hidden_dim']

    print(f"   Training Autoencoder for {model_name} ({input_dim} -> {latent_dim})...")

    autoencoder = Autoencoder(
        input_dim=input_dim,
        latent_dim=latent_dim,
        hidden_dim=hidden_dim,
        dropout=autoencoder_config['dropout']
    ).to(device)

    criterion = nn.MSELoss()
    optimizer = optim.AdamW(
        autoencoder.parameters(),
        lr=autoencoder_config['learning_rate'],
        weight_decay=autoencoder_config['weight_decay']
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=autoencoder_config['max_epochs'])

    # Create dataloaders
    train_dataset = AutoencoderDataset(X_train)
    val_dataset = AutoencoderDataset(X_val)
    train_loader = DataLoader(train_dataset, batch_size=autoencoder_config['batch_size'],
                             shuffle=True, num_workers=0, generator=get_generator())
    val_loader = DataLoader(val_dataset, batch_size=autoencoder_config['batch_size'],
                           shuffle=False, num_workers=0)

    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None
    epochs_run = 0  # stays 0 (instead of raising NameError) when max_epochs == 0

    for epoch in range(autoencoder_config['max_epochs']):
        epochs_run = epoch + 1

        # Training
        autoencoder.train()
        epoch_loss = 0.0
        for X_batch in train_loader:
            optimizer.zero_grad()
            X_recon, _ = autoencoder(X_batch)
            loss = criterion(X_recon, X_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                autoencoder.parameters(),
                max_norm=autoencoder_config['gradient_clip_max_norm']
            )
            optimizer.step()
            epoch_loss += loss.item()

        avg_train_loss = epoch_loss / len(train_loader)

        # Validation
        autoencoder.eval()
        val_loss = 0.0
        with torch.no_grad():
            for X_batch in val_loader:
                X_recon, _ = autoencoder(X_batch)
                loss = criterion(X_recon, X_batch)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        scheduler.step()

        if epochs_run % 30 == 0:
            print(f"     Epoch {epoch+1:3d}/{autoencoder_config['max_epochs']} | Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f}")

        # Early Stopping
        if avg_val_loss < best_val_loss - autoencoder_config['early_stopping_min_delta']:
            best_val_loss = avg_val_loss
            patience_counter = 0
            # state_dict() returns tensors that share storage with the live
            # parameters, and dict.copy() is shallow; clone every tensor so the
            # snapshot is not overwritten by later optimizer steps.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in autoencoder.state_dict().items()
            }
        else:
            patience_counter += 1

        if patience_counter >= autoencoder_config['early_stopping_patience']:
            print(f"     Early stopping at epoch {epoch+1}. Best Val Loss: {best_val_loss:.6f}")
            break

    # Load best model
    if best_model_state is not None:
        autoencoder.load_state_dict(best_model_state)

    print(f"   ✓ Autoencoder trained. Final Val Loss: {best_val_loss:.6f}")
    return autoencoder, best_val_loss, epochs_run


def train_mlp_prober(X_train, y_train, X_val, y_val, input_dim, device, prober_config=PROBER_CONFIG):
    """Train an MLPProber with early stopping on validation accuracy.

    Uses BCE-with-logits (optionally weighting the positive class by
    ``n_neg / n_pos``), AdamW with a cosine-annealed learning rate and
    gradient-norm clipping. After training, the weights of the epoch with the
    highest validation accuracy are restored.

    Args:
        X_train: Tensor of training features, one sample per row. Batches are
            fed to the model as-is, so it must already live on ``device``.
        y_train: Tensor of 0/1 training labels (same device).
        X_val: Tensor of validation features.
        y_val: Tensor of 0/1 validation labels.
        input_dim: Width of each feature vector.
        device: Device the prober is moved to.
        prober_config: Hyper-parameter dict; the keys read here are
            hidden_dim, dropout, use_class_weights, learning_rate,
            weight_decay, batch_size, max_epochs, early_stopping_patience,
            early_stopping_min_delta and gradient_clip_max_norm.

    Returns:
        Tuple ``(prober, best_val_acc, epochs_run)``.
    """
    prober = MLPProber(
        input_dim=input_dim,
        hidden_dim=prober_config['hidden_dim'],
        dropout=prober_config['dropout']
    ).to(device)

    # Compute class weights for imbalanced data
    criterion = nn.BCEWithLogitsLoss()
    if prober_config['use_class_weights']:
        n_pos = float(y_train.sum())
        n_neg = float(len(y_train)) - n_pos
        if n_pos > 0:
            pos_weight = torch.tensor([n_neg / n_pos], dtype=torch.float32, device=device)
            criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        else:
            # n_neg / 0 would give an infinite weight and a NaN loss; keep the
            # unweighted loss when the training split has no positives.
            print("     Warning: no positive samples in training data; using unweighted loss.")

    optimizer = optim.AdamW(
        prober.parameters(),
        lr=prober_config['learning_rate'],
        weight_decay=prober_config['weight_decay']
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=prober_config['max_epochs'])

    # Create dataloaders
    train_dataset = ClassificationDataset(X_train, y_train)
    val_dataset = ClassificationDataset(X_val, y_val)
    train_loader = DataLoader(train_dataset, batch_size=prober_config['batch_size'],
                             shuffle=True, num_workers=0, generator=get_generator())
    val_loader = DataLoader(val_dataset, batch_size=prober_config['batch_size'],
                           shuffle=False, num_workers=0)

    best_val_acc = 0.0
    patience_counter = 0
    best_model_state = None
    epochs_run = 0  # stays 0 (instead of raising NameError) when max_epochs == 0

    for epoch in range(prober_config['max_epochs']):
        epochs_run = epoch + 1

        # Training
        prober.train()
        epoch_loss = 0.0
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            logits = prober(X_batch)
            loss = criterion(logits, y_batch.float())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                prober.parameters(),
                max_norm=prober_config['gradient_clip_max_norm']
            )
            optimizer.step()
            epoch_loss += loss.item()

        avg_train_loss = epoch_loss / len(train_loader)

        # Validation
        prober.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                preds = prober.predict(X_batch)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(y_batch.cpu().numpy())

        val_f1 = f1_score(all_labels, all_preds)
        val_acc = accuracy_score(all_labels, all_preds)

        scheduler.step()

        if epochs_run % 20 == 0:
            print(f"     Epoch {epoch+1:3d}/{prober_config['max_epochs']} | Train Loss: {avg_train_loss:.4f} | Val F1: {val_f1:.4f} | Val Acc: {val_acc:.4f}")

        # Early Stopping based on accuracy
        if val_acc > best_val_acc + prober_config['early_stopping_min_delta']:
            best_val_acc = val_acc
            patience_counter = 0
            # state_dict() returns tensors that share storage with the live
            # parameters, and dict.copy() is shallow; clone every tensor so the
            # snapshot is not overwritten by later optimizer steps.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in prober.state_dict().items()
            }
        else:
            patience_counter += 1

        if patience_counter >= prober_config['early_stopping_patience']:
            print(f"     Early stopping at epoch {epoch+1}. Best Val Acc: {best_val_acc:.4f}")
            break

    # Load best model
    if best_model_state is not None:
        prober.load_state_dict(best_model_state)

    return prober, best_val_acc, epochs_run


def train_alignment_network(X_source_train, X_target_train, X_source_val, X_target_val, 
                            latent_dim, device, alignment_config=ALIGNMENT_CONFIG):
    """Train an AlignmentNetwork mapping source latents onto target latents.

    Optimises MixedLoss (weighted MSE + cosine distance) with AdamW, a
    cosine-annealed learning rate and gradient-norm clipping. After training,
    the weights of the epoch with the lowest validation loss are restored.

    Args:
        X_source_train: Tensor of source (student) latent codes for training.
            Batches are fed to the model as-is, so all four tensors must
            already live on ``device``.
        X_target_train: Tensor of target (teacher) latent codes, row-aligned
            with ``X_source_train``.
        X_source_val: Source latent codes for validation.
        X_target_val: Target latent codes for validation.
        latent_dim: Width of both latent spaces (used as the network's input
            and output dimension).
        device: Device the network is moved to.
        alignment_config: Hyper-parameter dict; the keys read here are
            hidden_dim, dropout, loss_alpha, loss_beta, learning_rate,
            weight_decay, batch_size, max_epochs, early_stopping_patience,
            early_stopping_min_delta and gradient_clip_max_norm.

    Returns:
        Tuple ``(aligner, best_val_loss, epochs_run)``.
    """
    print("   Training Alignment Network...")

    aligner = AlignmentNetwork(
        input_dim=latent_dim,
        output_dim=latent_dim,
        hidden_dim=alignment_config['hidden_dim'],
        dropout=alignment_config['dropout']
    ).to(device)

    criterion = MixedLoss(alpha=alignment_config['loss_alpha'], beta=alignment_config['loss_beta'])
    optimizer = optim.AdamW(
        aligner.parameters(),
        lr=alignment_config['learning_rate'],
        weight_decay=alignment_config['weight_decay']
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=alignment_config['max_epochs'])

    # Create dataloaders
    train_dataset = AlignmentDataset(X_source_train, X_target_train)
    val_dataset = AlignmentDataset(X_source_val, X_target_val)
    train_loader = DataLoader(train_dataset, batch_size=alignment_config['batch_size'],
                             shuffle=True, num_workers=0, generator=get_generator())
    val_loader = DataLoader(val_dataset, batch_size=alignment_config['batch_size'],
                           shuffle=False, num_workers=0)

    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None
    epochs_run = 0  # stays 0 (instead of raising NameError) when max_epochs == 0

    for epoch in range(alignment_config['max_epochs']):
        epochs_run = epoch + 1

        # Training
        aligner.train()
        epoch_loss = 0.0
        for data, target in train_loader:
            optimizer.zero_grad()
            projected = aligner(data)
            loss = criterion(projected, target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                aligner.parameters(),
                max_norm=alignment_config['gradient_clip_max_norm']
            )
            optimizer.step()
            epoch_loss += loss.item()

        avg_train_loss = epoch_loss / len(train_loader)

        # Validation
        aligner.eval()
        val_loss = 0.0
        with torch.no_grad():
            for data, target in val_loader:
                projected = aligner(data)
                loss = criterion(projected, target)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        scheduler.step()

        if epochs_run % 50 == 0:
            print(f"     Epoch {epoch+1:3d}/{alignment_config['max_epochs']} | Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f}")

        # Early Stopping
        if avg_val_loss < best_val_loss - alignment_config['early_stopping_min_delta']:
            best_val_loss = avg_val_loss
            patience_counter = 0
            # state_dict() returns tensors that share storage with the live
            # parameters, and dict.copy() is shallow; clone every tensor so the
            # snapshot is not overwritten by later optimizer steps.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in aligner.state_dict().items()
            }
        else:
            patience_counter += 1

        if patience_counter >= alignment_config['early_stopping_patience']:
            print(f"     Early stopping at epoch {epoch+1}. Best Val Loss: {best_val_loss:.6f}")
            break

    # Load best model
    if best_model_state is not None:
        aligner.load_state_dict(best_model_state)

    print(f"   ✓ Alignment Network trained. Final Val Loss: {best_val_loss:.6f}")
    return aligner, best_val_loss, epochs_run

### Main Experiment Pipeline (Autoencoders → Prober → Alignment → Evaluation)

In [6]:
def run_experiment_pipeline_with_autoencoder(X_teacher, y_teacher, teacher_name,
                                              X_student, y_student, student_name, 
                                              layer_type, config_name,
                                              autoencoder_config=AUTOENCODER_CONFIG,
                                              alignment_config=ALIGNMENT_CONFIG,
                                              prober_config=PROBER_CONFIG):
    """Run one teacher/student transfer experiment for a single layer type.

    Steps: (1-2) train one autoencoder per model, (3) encode all train/test
    activations to latent codes, (4) train an MLP prober on the teacher's
    latent codes and score it on the teacher test set, (5) train an alignment
    network from student latents to teacher latents, (6) save the four trained
    models under ``models/<layer_type>/``, (7) score the teacher's prober on
    the aligned student test latents.

    Args:
        X_teacher: dict with NumPy arrays under 'X_train' and 'X_test'.
        y_teacher: dict with NumPy label arrays under 'y_train' and 'y_test'.
        teacher_name: Teacher model name (used in messages and file names).
        X_student: dict with NumPy arrays under 'X_train' and 'X_test'. Its
            training rows are indexed with the split computed from the teacher,
            so they must be row-aligned with the teacher's training rows.
        y_student: dict with label arrays under 'y_train' and 'y_test'; only
            'y_test' is used for metrics.
        student_name: Student model name (used in messages and file names).
        layer_type: Activation type label (used in messages and as the
            sub-directory for saved models).
        config_name: Prefix for the saved checkpoint file names.
        autoencoder_config: Hyper-parameters forwarded to train_autoencoder.
        alignment_config: Hyper-parameters forwarded to train_alignment_network.
        prober_config: Hyper-parameters forwarded to train_mlp_prober.

    Returns:
        dict describing the run: names, per-model training summaries with
        checkpoint paths, and test metrics (accuracy, precision, recall, F1,
        confusion matrix) under 'teacher' and 'student_on_teacher'.
    """
    
    print(f"\n{'='*70}")
    print(f"EXPERIMENT: {layer_type.upper()} → {teacher_name} ← {student_name}")
    print(f"Using Autoencoder with latent_dim={autoencoder_config['latent_dim']}, hidden_dim={autoencoder_config['hidden_dim']}")
    print(f"{'='*70}")

    # A = teacher, B = student. y_B_train_full is unpacked but not used below.
    X_A_train_full, X_A_test = X_teacher['X_train'], X_teacher['X_test']
    y_A_train_full, y_A_test = y_teacher['y_train'], y_teacher['y_test']
    X_B_train_full, X_B_test = X_student['X_train'], X_student['X_test']
    y_B_train_full, y_B_test = y_student['y_train'], y_student['y_test']

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --------------------------------------------------
    # 1. Train Autoencoder for Teacher
    # --------------------------------------------------
    print("\n1. Training Autoencoder for TEACHER...")
    
    # One shuffled permutation of the training rows is reused for every
    # train/validation split in this function.
    # NOTE: this reseeds NumPy's *global* RNG with a hard-coded 42 (not SEED).
    num_train = len(X_A_train_full)
    indices = np.arange(num_train)
    np.random.seed(42)
    np.random.shuffle(indices)
    # First 15% of the permutation is validation, the rest is training.
    ae_val_size = int(num_train * 0.15)
    ae_train_idx = indices[ae_val_size:]
    ae_val_idx = indices[:ae_val_size]
    X_A_ae_train = torch.from_numpy(X_A_train_full[ae_train_idx]).float().to(device)
    X_A_ae_val = torch.from_numpy(X_A_train_full[ae_val_idx]).float().to(device)

    ae_teacher, ae_teacher_loss, ae_teacher_epochs = train_autoencoder(
        X_A_ae_train, X_A_ae_val,
        input_dim=X_A_train_full.shape[1],
        device=device,
        model_name=teacher_name,
        autoencoder_config=autoencoder_config
    )

    # --------------------------------------------------
    # 2. Train Autoencoder for Student
    # --------------------------------------------------
    print("\n2. Training Autoencoder for STUDENT...")

    # Same row split as the teacher (indices derived from the teacher's length).
    X_B_ae_train = torch.from_numpy(X_B_train_full[ae_train_idx]).float().to(device)
    X_B_ae_val = torch.from_numpy(X_B_train_full[ae_val_idx]).float().to(device)

    ae_student, ae_student_loss, ae_student_epochs = train_autoencoder(
        X_B_ae_train, X_B_ae_val,
        input_dim=X_B_train_full.shape[1],
        device=device,
        model_name=student_name,
        autoencoder_config=autoencoder_config
    )

    # --------------------------------------------------
    # 3. Encode all data to latent space
    # --------------------------------------------------
    print("\n3. Encoding data to latent space...")

    ae_teacher.eval()
    ae_student.eval()

    # Each full train/test matrix is moved to the device and encoded in a
    # single forward pass (no batching).
    with torch.no_grad():
        # Teacher encodings
        X_A_train_full_t = torch.from_numpy(X_A_train_full).float().to(device)
        X_A_test_t = torch.from_numpy(X_A_test).float().to(device)
        Z_A_train = ae_teacher.encode(X_A_train_full_t)
        Z_A_test = ae_teacher.encode(X_A_test_t)
        
        # Student encodings
        X_B_train_full_t = torch.from_numpy(X_B_train_full).float().to(device)
        X_B_test_t = torch.from_numpy(X_B_test).float().to(device)
        Z_B_train = ae_student.encode(X_B_train_full_t)
        Z_B_test = ae_student.encode(X_B_test_t)

    print(f"   Teacher latent shape: {Z_A_train.shape}")
    print(f"   Student latent shape: {Z_B_train.shape}")

    # --------------------------------------------------
    # 4. Train MLP Prober on Teacher's Latent Space
    # --------------------------------------------------
    print("\n4. Training MLP Prober on Teacher's latent space...")

    # Same 15% / 85% split of the permutation as the autoencoders.
    prober_val_size = int(num_train * 0.15)
    prober_train_idx = indices[prober_val_size:]
    prober_val_idx = indices[:prober_val_size]

    Z_A_prober_train = Z_A_train[prober_train_idx]
    y_A_prober_train = torch.from_numpy(y_A_train_full[prober_train_idx]).long().to(device)
    Z_A_prober_val = Z_A_train[prober_val_idx]
    y_A_prober_val = torch.from_numpy(y_A_train_full[prober_val_idx]).long().to(device)

    probe_teacher, best_prober_acc, prober_epochs = train_mlp_prober(
        Z_A_prober_train, y_A_prober_train,
        Z_A_prober_val, y_A_prober_val,
        input_dim=autoencoder_config['latent_dim'],
        device=device,
        prober_config=prober_config
    )
    print(f"   Best prober validation Acc: {best_prober_acc:.4f}")

    # --- Teacher Metrics ---
    probe_teacher.eval()
    y_pred_teacher = probe_teacher.predict(Z_A_test).cpu().numpy()

    cm_teacher = confusion_matrix(y_A_test, y_pred_teacher)
    acc_teacher = accuracy_score(y_A_test, y_pred_teacher)
    prec_teacher = precision_score(y_A_test, y_pred_teacher)
    rec_teacher = recall_score(y_A_test, y_pred_teacher)
    f1_teacher = f1_score(y_A_test, y_pred_teacher)
    print(f"   Teacher Test Acc: {acc_teacher:.4f}, F1: {f1_teacher:.4f}")

    # --------------------------------------------------
    # 5. Train Alignment Network
    # --------------------------------------------------
    print("\n5. Training Alignment Network (Student → Teacher latent space)...")

    # Validation is the first 10% of the permutation (a subset of the 15%
    # used above), so the alignment training set includes rows that served as
    # validation data for the autoencoders and the prober.
    align_val_size = int(num_train * 0.1)
    align_train_idx = indices[align_val_size:]
    align_val_idx = indices[:align_val_size]

    Z_B_align_train = Z_B_train[align_train_idx]
    Z_A_align_train = Z_A_train[align_train_idx]
    Z_B_align_val = Z_B_train[align_val_idx]
    Z_A_align_val = Z_A_train[align_val_idx]

    aligner, align_loss, align_epochs = train_alignment_network(
        Z_B_align_train, Z_A_align_train,
        Z_B_align_val, Z_A_align_val,
        latent_dim=autoencoder_config['latent_dim'],
        device=device,
        alignment_config=alignment_config
    )

    # --------------------------------------------------
    # 6. Save Models
    # --------------------------------------------------
    print("\n6. Saving models...")

    # Relative path: checkpoints land under the current working directory.
    model_save_dir = os.path.join("models", layer_type)
    os.makedirs(model_save_dir, exist_ok=True)

    # Save Teacher Autoencoder
    ae_teacher_filename = os.path.join(model_save_dir, f"{config_name}_autoencoder_{teacher_name}.pt")
    torch.save({
        'model_state_dict': ae_teacher.state_dict(),
        'autoencoder_config': autoencoder_config,
        'input_dim': X_A_train_full.shape[1],
        'latent_dim': autoencoder_config['latent_dim'],
        'best_val_loss': ae_teacher_loss,
        'epochs_trained': ae_teacher_epochs,
        'model_name': teacher_name,
    }, ae_teacher_filename)
    print(f"   ✓ Teacher Autoencoder saved: {ae_teacher_filename}")

    # Save Student Autoencoder
    ae_student_filename = os.path.join(model_save_dir, f"{config_name}_autoencoder_{student_name}.pt")
    torch.save({
        'model_state_dict': ae_student.state_dict(),
        'autoencoder_config': autoencoder_config,
        'input_dim': X_B_train_full.shape[1],
        'latent_dim': autoencoder_config['latent_dim'],
        'best_val_loss': ae_student_loss,
        'epochs_trained': ae_student_epochs,
        'model_name': student_name,
    }, ae_student_filename)
    print(f"   ✓ Student Autoencoder saved: {ae_student_filename}")

    # Save MLP Prober
    prober_filename = os.path.join(model_save_dir, f"{config_name}_mlp_prober_{teacher_name}.pt")
    torch.save({
        'model_state_dict': probe_teacher.state_dict(),
        'prober_config': prober_config,
        'input_dim': autoencoder_config['latent_dim'],
        'best_val_acc': best_prober_acc,
        'epochs_trained': prober_epochs,
        'teacher_model': teacher_name,
    }, prober_filename)
    print(f"   ✓ MLP Prober saved: {prober_filename}")

    # Save Alignment Network
    aligner_filename = os.path.join(model_save_dir, f"{config_name}_aligner_{student_name}_to_{teacher_name}.pt")
    torch.save({
        'model_state_dict': aligner.state_dict(),
        'alignment_config': alignment_config,
        'input_dim': autoencoder_config['latent_dim'],
        'output_dim': autoencoder_config['latent_dim'],
        'best_val_loss': align_loss,
        'epochs_trained': align_epochs,
        'student_model': student_name,
        'teacher_model': teacher_name,
    }, aligner_filename)
    print(f"   ✓ Alignment Network saved: {aligner_filename}")

    # --------------------------------------------------
    # 7. Evaluation
    # --------------------------------------------------
    print("\n7. Projecting student test set & evaluating...")

    aligner.eval()
    with torch.no_grad():
        Z_B_aligned = aligner(Z_B_test)

    # The teacher's prober is applied unchanged to the aligned student latents.
    y_pred_cross = probe_teacher.predict(Z_B_aligned).cpu().numpy()

    # --- Cross-Model Metrics ---
    cm_cross = confusion_matrix(y_B_test, y_pred_cross)
    acc_cross = accuracy_score(y_B_test, y_pred_cross)
    prec_cross = precision_score(y_B_test, y_pred_cross)
    rec_cross = recall_score(y_B_test, y_pred_cross)
    f1_cross = f1_score(y_B_test, y_pred_cross)

    print(f"\n{'='*50}")
    print(f"FINAL RESULT:")
    print(f"{'='*50}")
    print(f"   Teacher Acc          : {acc_teacher:.4f}, F1: {f1_teacher:.4f}")
    print(f"   Student → Teacher Acc: {acc_cross:.4f}, F1: {f1_cross:.4f}")
    print(f"   Transfer gap (Acc)   : {acc_teacher - acc_cross:.4f}")
    print(f"   Transfer gap (F1)    : {f1_teacher - f1_cross:.4f}")

    return {
        "type": layer_type,
        "teacher_name": teacher_name,
        "student_name": student_name,
        "autoencoder_teacher": {
            "input_dim": X_A_train_full.shape[1],
            "config": autoencoder_config,
            "best_val_loss": float(ae_teacher_loss),
            "epochs_trained": ae_teacher_epochs,
            "model_path": ae_teacher_filename
        },
        "autoencoder_student": {
            "input_dim": X_B_train_full.shape[1],
            "config": autoencoder_config,
            "best_val_loss": float(ae_student_loss),
            "epochs_trained": ae_student_epochs,
            "model_path": ae_student_filename
        },
        "prober_model": {
            "input_dim": autoencoder_config['latent_dim'],
            "config": prober_config,
            "best_val_acc": float(best_prober_acc),
            "epochs_trained": prober_epochs,
            "model_path": prober_filename
        },
        "alignment_model": {
            "input_dim": autoencoder_config['latent_dim'],
            "output_dim": autoencoder_config['latent_dim'],
            "config": alignment_config,
            "best_val_loss": float(align_loss),
            "epochs_trained": align_epochs,
            "model_path": aligner_filename
        },
        "teacher": {
            "accuracy": acc_teacher,
            "precision": prec_teacher,
            "recall": rec_teacher,
            "f1": f1_teacher,
            "confusion_matrix": cm_teacher.tolist()
        },
        "student_on_teacher": {
            "accuracy": acc_cross,
            "precision": prec_cross,
            "recall": rec_cross,
            "f1": f1_cross,
            "confusion_matrix": cm_cross.tolist()
        }
    }
def plot_confusion_matrix(cm, layer_type, model_name="", save_dir="confusion_matrices"):
    """Render a confusion matrix as an annotated heatmap and save it as a PNG.

    Args:
        cm: Matrix of integer counts (cells are annotated with ``fmt='d'``).
            Both axes get the two tick labels 'Non-Hallucinated' and
            'Hallucinated', so a 2x2 matrix is expected; rows are labelled
            as the true class and columns as the predicted class.
        layer_type: Layer-type tag; upper-cased in the title and used
            verbatim in the file name.
        model_name: Optional label appended to both the title and the
            file name. Omitted from both when empty.
        save_dir: Output directory, created if it does not exist.

    The image is written to
    ``<save_dir>/confusion_matrix_<layer_type>[_<model_name>].png`` and the
    path is printed.
    """
    os.makedirs(save_dir, exist_ok=True)

    title = f'Confusion Matrix - {layer_type.upper()} Layers'
    if model_name:
        title += f' ({model_name})'
        basename = f'confusion_matrix_{layer_type}_{model_name}.png'
    else:
        basename = f'confusion_matrix_{layer_type}.png'
    filename = os.path.join(save_dir, basename)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True, ax=ax,
                    xticklabels=['Non-Hallucinated', 'Hallucinated'],
                    yticklabels=['Non-Hallucinated', 'Hallucinated'])
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')
        ax.set_title(title)

        fig.tight_layout()
        fig.savefig(filename, dpi=150, bbox_inches='tight')
    finally:
        # Close exactly the figure created here, even when plotting or saving
        # fails; otherwise figures pile up in pyplot's registry when the
        # caller catches the error and keeps looping.
        plt.close(fig)
    print(f"   ✓ Saved: {filename}")

### Run Experiments

In [7]:
# ------------------------------------------------------------------
# Phase 1: build a single shuffled 70/30 train/test index split and
# declare the teacher/student scenarios to evaluate.
# ------------------------------------------------------------------
print("=" * 80)
print("PHASE 1: PRE-LOADING AND SPLITTING DATA (same shuffled indices for ALL layer types)")
print("=" * 80 + "\n")

set_seed(SEED)

# The sample count is read from the Qwen statistics only; the permutation is
# drawn from a dedicated RandomState seeded with SEED.
n_samples = qwen_stats['total']
rng = np.random.RandomState(SEED)
shuffled_indices = rng.permutation(n_samples)
split_idx = int(0.7 * n_samples)

# First 70% of the shuffled indices go to training, the remainder to testing.
train_indices, test_indices = shuffled_indices[:split_idx], shuffled_indices[split_idx:]

print(f"Train/Test split: {len(train_indices)}/{len(test_indices)} samples")
print(f"Using SEED={SEED} for reproducibility")
print(f"Using LATENT_DIM={AUTOENCODER_CONFIG['latent_dim']}")

# Each model takes the teacher role once and the student role once.
scenarios = [
    dict(teacher_model="Qwen2.5-7B", student_model="Falcon3-7B-Base"),
    dict(teacher_model="Falcon3-7B-Base", student_model="Qwen2.5-7B"),
]

# One result list per scenario, keyed by the scenario's position in `scenarios`.
scenario_results_map = {idx: [] for idx in range(len(scenarios))}

# For each layer type: load both models' activations on the shared split,
# standardize them per model (scaler fitted on the training part only), then
# run every teacher/student scenario on that data and plot its confusion matrices.
for layer_type in ['attn', 'mlp', 'hidden']:
    print(f"\n{'='*40}")
    print(f"PROCESSING LAYER TYPE: {layer_type.upper()}")
    print(f"{'='*40}")
    gc.collect()
    torch.cuda.empty_cache()

    try:
        X_qwen_train, X_qwen_test, y_qwen_train, y_qwen_test = load_and_split_layers(
            "Qwen2.5-7B", "belief_bank",
            LAYER_CONFIG["Qwen2.5-7B"][layer_type],
            layer_type, qwen_stats,
            train_indices, test_indices
        )

        X_falcon_train, X_falcon_test, y_falcon_train, y_falcon_test = load_and_split_layers(
            "Falcon3-7B-Base", "belief_bank",
            LAYER_CONFIG["Falcon3-7B-Base"][layer_type],
            layer_type, falcon_stats,
            train_indices, test_indices
        )

        print("   Normalizing data...")
        scaler_qwen = StandardScaler()
        X_qwen_train = scaler_qwen.fit_transform(X_qwen_train).astype(np.float32)
        X_qwen_test = scaler_qwen.transform(X_qwen_test).astype(np.float32)

        scaler_falcon = StandardScaler()
        X_falcon_train = scaler_falcon.fit_transform(X_falcon_train).astype(np.float32)
        X_falcon_test = scaler_falcon.transform(X_falcon_test).astype(np.float32)

        current_data = {
            "qwen": {"X_train": X_qwen_train, "X_test": X_qwen_test, "y_train": y_qwen_train, "y_test": y_qwen_test},
            "falcon": {"X_train": X_falcon_train, "X_test": X_falcon_test, "y_train": y_falcon_train, "y_test": y_falcon_test}
        }

        for i, scenario in enumerate(scenarios):
            print(f"\n   --- Scenario: {scenario['teacher_model']} (Teacher) <- {scenario['student_model']} (Student) ---")

            # Re-seed so each scenario starts from the same RNG state.
            set_seed(SEED)

            if scenario['teacher_model'] == "Qwen2.5-7B":
                X_teacher_data = current_data['qwen']
                X_student_data = current_data['falcon']
            else:
                X_teacher_data = current_data['falcon']
                X_student_data = current_data['qwen']

            # NOTE(review): the same dict is passed in two consecutive slots for
            # each model; presumably the pipeline reads the X_* keys from one and
            # the y_* keys from the other. Confirm against its signature.
            res = run_experiment_pipeline_with_autoencoder(
                X_teacher_data, X_teacher_data, scenario['teacher_model'],
                X_student_data, X_student_data, scenario['student_model'],
                layer_type, "CONFIG1"
            )
            scenario_results_map[i].append(res)

            # File-name-safe model tags ('.' and '-' replaced by '_').
            teacher_tag = scenario['teacher_model'].replace('.', '_').replace('-', '_')
            student_tag = scenario['student_model'].replace('.', '_').replace('-', '_')

            plot_confusion_matrix(
                np.array(res['teacher']['confusion_matrix']),
                layer_type,
                f"Teacher_{teacher_tag}"
            )
            plot_confusion_matrix(
                np.array(res['student_on_teacher']['confusion_matrix']),
                layer_type,
                f"{student_tag}_on_{teacher_tag}"
            )

    except Exception as e:
        # A failure aborts only this layer type; the next one is still attempted.
        print(f"Critical error in layer {layer_type}: {e}")
        traceback.print_exc()

    finally:
        # Release the activation matrices on success AND on failure before the
        # next layer type is loaded. X_teacher_data / X_student_data alias the
        # dicts inside current_data, so they must be dropped too or the arrays
        # stay alive. Names are rebound to None rather than `del`-ed so that a
        # failure before they were first assigned cannot raise NameError here.
        # The y_* label arrays are left bound, as before.
        current_data = X_teacher_data = X_student_data = None
        X_qwen_train = X_qwen_test = X_falcon_train = X_falcon_test = None
        scaler_qwen = scaler_falcon = None
        gc.collect()
        torch.cuda.empty_cache()
        print(f"   Memory freed for {layer_type}.")

# Pair each scenario's human-readable label with the per-layer-type results
# collected for it (looked up by the scenario's position in `scenarios`).
all_results = [
    {
        "scenario": f"{scenario['teacher_model']} (teacher) → {scenario['student_model']} (student)",
        "results": scenario_results_map[i],
    }
    for i, scenario in enumerate(scenarios)
]

# Save JSON: flatten every experiment result into a nested, rounded report and
# write all scenarios to a single file.
os.makedirs("results_metrics", exist_ok=True)
metrics_file = "results_metrics/experiment_results_all_scenarios_autoencoder.json"

# Training settings reported for every model, in output order.
_COMMON_TRAINING_KEYS = (
    "optimizer",
    "learning_rate",
    "weight_decay",
    "batch_size",
    "max_epochs",
    "scheduler",
    "gradient_clip_max_norm",
    "early_stopping_patience",
    "early_stopping_min_delta",
)


def _training_hyperparameters(config, extra_keys=()):
    """Copy the shared training settings, then `extra_keys`, out of `config` (order kept)."""
    return {key: config[key] for key in _COMMON_TRAINING_KEYS + tuple(extra_keys)}


def _training_results(model_info, metric_key, ndigits):
    """Report the rounded best validation metric, epochs trained and checkpoint path."""
    return {
        metric_key: round(model_info[metric_key], ndigits),
        "epochs_trained": model_info['epochs_trained'],
        "model_saved_path": model_info['model_path'],
    }


def _autoencoder_entry(model_info):
    """Build the report section for one autoencoder (teacher or student)."""
    config = model_info['config']
    return {
        "architecture": {
            "input_dim": model_info['input_dim'],
            "latent_dim": config['latent_dim'],
            "hidden_dim": config['hidden_dim'],
            "dropout": config['dropout'],
        },
        "training_hyperparameters": _training_hyperparameters(config, ("loss_function",)),
        "training_results": _training_results(model_info, "best_val_loss", 6),
    }


def _performance_entry(metrics):
    """Round the classification metrics and name the four confusion-matrix cells."""
    cm = metrics['confusion_matrix']
    return {
        "accuracy": round(metrics['accuracy'], 4),
        "precision": round(metrics['precision'], 4),
        "recall": round(metrics['recall'], 4),
        "f1_score": round(metrics['f1'], 4),
        "confusion_matrix": {
            "TN": int(cm[0][0]),
            "FP": int(cm[0][1]),
            "FN": int(cm[1][0]),
            "TP": int(cm[1][1]),
        },
    }


def _result_entry(r):
    """Convert one pipeline result dict into its JSON report entry."""
    prober_config = r['prober_model']['config']
    align_config = r['alignment_model']['config']

    return {
        "layer_type": r['type'],
        "teacher_model": r['teacher_name'],
        "student_model": r['student_name'],

        # ==================== AUTOENCODERS ====================
        # Both entries share one layout. The teacher entry previously lacked
        # the "architecture" section that the student entry carried.
        "teacher_autoencoder": _autoencoder_entry(r['autoencoder_teacher']),
        "student_autoencoder": _autoencoder_entry(r['autoencoder_student']),

        # ==================== PROBER MODEL ====================
        "prober_model": {
            "architecture": {
                "input_dim": r['prober_model']['input_dim'],
                "hidden_dim": prober_config['hidden_dim'],
                "dropout": prober_config['dropout'],
            },
            "training_hyperparameters": _training_hyperparameters(
                prober_config, ("loss_function", "use_class_weights")
            ),
            "training_results": _training_results(r['prober_model'], "best_val_acc", 4),
        },

        # ==================== ALIGNMENT MODEL ====================
        "alignment_model": {
            "architecture": {
                "input_dim": r['alignment_model']['input_dim'],
                "output_dim": r['alignment_model']['output_dim'],
                "hidden_dim": align_config['hidden_dim'],
                "dropout": align_config['dropout'],
            },
            "training_hyperparameters": _training_hyperparameters(align_config),
            # The alignment loss is reported separately with its two weights.
            "loss_function": {
                "type": "MixedLoss",
                "mse_weight": align_config['loss_alpha'],
                "cosine_weight": align_config['loss_beta'],
            },
            "training_results": _training_results(r['alignment_model'], "best_val_loss", 6),
        },

        # ==================== PERFORMANCE METRICS ====================
        "teacher_performance": _performance_entry(r['teacher']),
        "student_on_teacher_performance": _performance_entry(r['student_on_teacher']),
    }


all_results_json = [
    {
        "scenario": scenario_data['scenario'],
        "results": [_result_entry(r) for r in scenario_data['results']],
    }
    for scenario_data in all_results
]

with open(metrics_file, 'w', encoding='utf-8') as f:
    json.dump(all_results_json, f, indent=2)

print(f"\n✓ Results saved to: {metrics_file}")

PHASE 1: PRE-LOADING AND SPLITTING DATA (same shuffled indices for ALL layer types)

Train/Test split: 19191/8225 samples
Using SEED=42 for reproducibility
Using LATENT_DIM=128

PROCESSING LAYER TYPE: ATTN
 Loading IN-MEMORY Qwen2.5-7B [attn]: layers [15, 16, 18]...
  Loading layer 15... done ((27416, 3584))
  Loading layer 16... done ((27416, 3584))
  Loading layer 18... done ((27416, 3584))
 Concatenating layers...
 Completed! Train: (19191, 10752), Test: (8225, 10752)
 Loading IN-MEMORY Falcon3-7B-Base [attn]: layers [2, 7, 12]...
  Loading layer 2... done ((27416, 3072))
  Loading layer 7... done ((27416, 3072))
  Loading layer 12... done ((27416, 3072))
 Concatenating layers...
 Completed! Train: (19191, 9216), Test: (8225, 9216)
   Normalizing data...

   --- Scenario: Qwen2.5-7B (Teacher) <- Falcon3-7B-Base (Student) ---

EXPERIMENT: ATTN → Qwen2.5-7B ← Falcon3-7B-Base
Using Autoencoder with latent_dim=128, hidden_dim=256
Using device: cuda

1. Training Autoencoder for TEACHER..