# Approach 2: Non-linear adaptation with an adapter MLP and an MLP prober

In this notebook we test a second non-linear approach. We extract the activations of both LLMs and, for the Teacher model, train three MLP classifiers (one per layer type: attention, MLP and hidden states). An adapter MLP (`AlignmentNetwork`) then learns to map the Student's latent space onto the Teacher's. Finally, the adapted Student activations are classified with the Teacher's MLP classifiers. Each experiment is run in both directions (Qwen2.5-7B ↔ Falcon3-7B-Base).

**Difference from Approach 1:** Here we use an MLP instead of Logistic Regression as the probing classifier.

In [14]:
import json
import os
import numpy as np
import matplotlib.pyplot as plt
import gc
import seaborn as sns
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix
import traceback
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
import random

# ==================================================================
# REPRODUCIBILITY SETTINGS
# ==================================================================
SEED = 42


def set_seed(seed=SEED):
    """Seed every RNG the notebook touches: Python, NumPy and PyTorch (CPU + CUDA).

    Also puts cuDNN in deterministic mode and disables its autotuner, trading
    some speed for run-to-run reproducibility.

    Note: Python reads ``PYTHONHASHSEED`` only at interpreter startup, so setting
    it here affects child processes, not the hash seed of the running kernel.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # for multi-GPU
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)


def get_generator(seed=SEED):
    """Return a new ``torch.Generator`` seeded with ``seed`` (e.g. for DataLoader shuffling)."""
    # Generator.manual_seed returns the generator itself, so this can be chained.
    return torch.Generator().manual_seed(seed)


# Seed everything as soon as this cell runs.
set_seed(SEED)

In [15]:
# Project root is taken to be three directory levels above the notebook's working directory.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))
CACHE_DIR_NAME = "activation_cache"
# Hugging Face cache location. The fallback is built with os.path.join so it uses the
# host OS separator instead of hard-coded Windows backslashes ("~" is left unexpanded).
HF_DEFAULT_HOME = os.environ.get("HF_HOME", os.path.join("~", ".cache", "huggingface", "hub"))

# We test the same layers as in the linear approach.
# model name -> layer type -> indices of the layers whose activations are concatenated.
LAYER_CONFIG = {
    "Qwen2.5-7B":
    {
        "attn": [15, 16, 18],
        "mlp": [16, 18, 20],
        "hidden": [18, 19, 20]
    },
    "Falcon3-7B-Base":
    {
        "attn": [2, 7, 12],
        "mlp": [10, 11, 12],
        "hidden": [2, 3, 19]
    }
}

# ==================================================================
# HYPERPARAMETERS CONFIGURATION
# ==================================================================
# Settings for the student -> teacher AlignmentNetwork and its training loop.
ALIGNMENT_CONFIG = {
    "hidden_dim": 128,                  # width of the residual bottleneck
    "dropout": 0.5,
    "learning_rate": 1e-3,
    "weight_decay": 0.1,
    "batch_size": 32,
    "max_epochs": 1000,
    "early_stopping_patience": 50,      # epochs without val-loss improvement
    "early_stopping_min_delta": 1e-4,
    "gradient_clip_max_norm": 1.0,
    "optimizer": "AdamW",
    "scheduler": "CosineAnnealingLR",
    "loss_alpha": 0.01,  # MSE weight
    "loss_beta": 1.0     # Cosine weight
}

# ==================================================================
# MLP PROBER CONFIGURATION
# ==================================================================
# Settings for the teacher's MLPProber and its training loop.
PROBER_CONFIG = {
    "type": "MLPProber",
    "hidden_dim": 64,                   # first hidden layer; the second uses hidden_dim // 2
    "dropout": 0.5,
    "learning_rate": 1e-3,
    "weight_decay": 0.01,
    "batch_size": 64,
    "max_epochs": 200,
    "early_stopping_patience": 30,      # epochs without val-accuracy improvement
    "early_stopping_min_delta": 1e-4,
    "gradient_clip_max_norm": 1.0,
    "optimizer": "AdamW",
    "scheduler": "CosineAnnealingLR",
    "loss_function": "BCEWithLogitsLoss",
    "use_class_weights": True           # pos_weight = n_neg / n_pos
}

### Dataset preparation

In [16]:
def stats_per_json(model_name, dataset_name, cache_root=None):
    """Summarise the hallucination labels of one model on one dataset.

    Reads ``<cache_root>/<model_name>/<dataset_name>/generations/hallucination_labels.json``,
    a JSON list whose records carry at least ``instance_id`` and ``is_hallucination``.

    Args:
        model_name: model sub-directory inside the activation cache.
        dataset_name: dataset sub-directory inside the model directory.
        cache_root: root of the activation cache; defaults to
            ``os.path.join(PROJECT_ROOT, CACHE_DIR_NAME)``.

    Returns:
        dict with ``total`` (number of records), ``hallucinations`` (count flagged),
        ``percent_hallucinations`` (0-100, or 0 for an empty file),
        ``hallucinated_items`` (instance ids of the flagged records, in file order),
        ``model_name`` and ``dataset_name``.
    """
    if cache_root is None:
        cache_root = os.path.join(PROJECT_ROOT, CACHE_DIR_NAME)
    file_path = os.path.join(cache_root, model_name, dataset_name,
                             "generations", "hallucination_labels.json")
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows) can fail on JSON text.
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    total = len(data)
    hallucinated_items = [item['instance_id'] for item in data if item['is_hallucination']]
    hallucinations = len(hallucinated_items)
    percent_hallucinations = (hallucinations / total) * 100 if total > 0 else 0
    return {
        'total': total,
        'hallucinations': hallucinations,
        'percent_hallucinations': percent_hallucinations,
        'hallucinated_items': hallucinated_items,
        'model_name': model_name,
        'dataset_name': dataset_name
    }


# Label statistics for both models on the same dataset (used for labels and split size).
qwen_stats = stats_per_json(model_name="Qwen2.5-7B", dataset_name="belief_bank")
falcon_stats = stats_per_json(model_name="Falcon3-7B-Base", dataset_name="belief_bank")

In [17]:
# ------------------------------------------------------------------
# 1. Dataset class for Alignment
# ------------------------------------------------------------------
class AlignmentDataset(Dataset):
    """Row-paired (source, target) samples for training the alignment network.

    Row ``i`` of ``x_source`` (student activations) is paired with row ``i`` of
    ``x_target`` (teacher activations); both are expected to share their first
    dimension. The dataset length is taken from ``x_source``.
    """

    def __init__(self, x_source: torch.Tensor, x_target: torch.Tensor):
        self.x_source, self.x_target = x_source, x_target

    def __len__(self):
        num_rows = self.x_source.shape[0]
        return num_rows

    def __getitem__(self, idx):
        source_row = self.x_source[idx]
        target_row = self.x_target[idx]
        return source_row, target_row

# ------------------------------------------------------------------
# 2. Dataset class for Classification
# ------------------------------------------------------------------
class ClassificationDataset(Dataset):
    """Feature/label pairs for training and validating the probing classifier.

    ``X`` and ``y`` are indexed in lockstep; the dataset length is taken from ``X``.
    """

    def __init__(self, X: torch.Tensor, y: torch.Tensor):
        self.X, self.y = X, y

    def __len__(self):
        num_rows = self.X.shape[0]
        return num_rows

    def __getitem__(self, idx):
        features = self.X[idx]
        label = self.y[idx]
        return features, label

# ------------------------------------------------------------------
# 3. AlignmentNetwork
# ------------------------------------------------------------------
class AlignmentNetwork(nn.Module):
    """Residual adapter mapping student activations into the teacher's space.

    ``forward(x) = P(x) + f(P(x))``, where ``P`` is a bias-free linear projection
    (the identity when ``input_dim == output_dim``) and ``f`` is a small bottleneck
    MLP. The output linear layer of ``f`` starts at zero, so at initialisation the
    network reduces to ``P``.

    Args:
        input_dim: feature size of the student activations.
        output_dim: feature size of the teacher activations.
        hidden_dim: width of the residual bottleneck.
        dropout: dropout probability inside ``f`` (also applied to its output).
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 128, dropout: float = 0.5):
        super().__init__()

        # A projection is only needed when the two latent spaces differ in width.
        self.input_proj = (
            nn.Identity() if input_dim == output_dim
            else nn.Linear(input_dim, output_dim, bias=False)
        )

        self.net = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),  # net[-2]: zero-initialised below
            nn.Dropout(dropout),
        )

        self._init_zero()

    def _init_zero(self):
        """Zero the residual branch's output layer so the block starts as ``input_proj``."""
        out_layer = self.net[-2]
        nn.init.zeros_(out_layer.weight)
        if out_layer.bias is not None:
            nn.init.zeros_(out_layer.bias)

    def forward(self, x):
        projected = self.input_proj(x)
        residual = self.net(projected)
        return projected + residual
    

# ------------------------------------------------------------------
# 4. MLP Prober
# ------------------------------------------------------------------
class MLPProber(nn.Module):
    """Two-hidden-layer MLP producing one logit per sample for binary probing.

    Layout: Linear(input_dim -> hidden_dim) -> LayerNorm -> GELU -> Dropout ->
    Linear(hidden_dim -> hidden_dim // 2) -> LayerNorm -> GELU -> Dropout -> Linear(-> 1).
    """

    def __init__(self, input_dim: int, hidden_dim: int = 256, dropout: float = 0.3):
        super().__init__()

        half = hidden_dim // 2
        layers = []
        for fan_in, fan_out in ((input_dim, hidden_dim), (hidden_dim, half)):
            layers += [nn.Linear(fan_in, fan_out), nn.LayerNorm(fan_out), nn.GELU(), nn.Dropout(dropout)]
        layers.append(nn.Linear(half, 1))
        self.net = nn.Sequential(*layers)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_out, ReLU gain) on every Linear weight; zero every Linear bias."""
        linears = (module for module in self.modules() if isinstance(module, nn.Linear))
        for layer in linears:
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return raw logits with the trailing singleton dimension removed."""
        logits = self.net(x)
        return logits.squeeze(-1)

    def predict(self, x):
        """Hard 0/1 labels (``torch.long``): sigmoid(logit) > 0.5.

        Does not switch to eval mode; call ``.eval()`` first to disable dropout.
        """
        return (self.predict_proba(x) > 0.5).long()

    def predict_proba(self, x):
        """Positive-class probabilities, computed without autograd tracking."""
        with torch.no_grad():
            return torch.sigmoid(self.forward(x))
    

# ------------------------------------------------------------------
# 5. MixedLoss for Alignment
# ------------------------------------------------------------------
class MixedLoss(nn.Module):
    """Weighted sum of MSE and cosine distance: ``alpha * MSE + beta * (1 - mean cos)``.

    Cosine similarity is computed row-wise (``dim=1``) and averaged over the batch.

    Args:
        alpha: weight of the mean-squared-error term.
        beta: weight of the cosine-distance term.
    """

    def __init__(self, alpha=0.01, beta=1.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.mse = nn.MSELoss()

    def forward(self, pred, target):
        mean_cos = F.cosine_similarity(pred, target, dim=1).mean()
        mse_term = self.mse(pred, target)
        return self.alpha * mse_term + self.beta * (1 - mean_cos)

In [18]:
def load_and_split_layers(model_name, dataset_name, layer_indices, type_layer, stats,
                          train_indices, test_indices, cache_root=None):
    """Load cached activations of several layers fully into RAM (no memmap) and split them.

    For every index in ``layer_indices`` the file
    ``<cache_root>/<model_name>/<dataset_name>/activation_<type_layer>/layer<idx>_activations.pt``
    is loaded (missing files are skipped with a warning), truncated to
    ``stats['total']`` rows, flattened to 2-D and converted to float32. The per-layer
    matrices are concatenated along the feature axis and split by row with
    ``train_indices`` / ``test_indices``.

    Labels: row ``i`` is 1 iff ``i`` appears in ``stats['hallucinated_items']``.
    Instance ids are used directly as row indices — assumes ids run 0..total-1
    in the same order as the activation rows (TODO confirm).

    Args:
        model_name, dataset_name: sub-directories inside the activation cache.
        layer_indices: layers to load and concatenate, in this order.
        type_layer: layer family ("attn", "mlp", "hidden", ...).
        stats: dict with ``'total'`` and ``'hallucinated_items'`` (see ``stats_per_json``).
        train_indices, test_indices: integer row indices for the two splits.
        cache_root: activation-cache root; defaults to ``os.path.join(PROJECT_ROOT, CACHE_DIR_NAME)``.

    Returns:
        ``(X_train, X_test, y_train, y_test)``: float32 feature arrays and int8 label arrays.

    Raises:
        ValueError: if no layer file was found, or a layer holds fewer rows than
            ``stats['total']`` (rows would otherwise be misaligned with the labels).
    """
    if cache_root is None:
        cache_root = os.path.join(PROJECT_ROOT, CACHE_DIR_NAME)
    print(f" Caricamento IN-MEMORY {model_name} [{type_layer}]: layers {layer_indices}...")

    total_samples = stats['total']

    # Labels
    y_full = np.zeros(total_samples, dtype=np.int8)
    y_full[list(set(stats['hallucinated_items']))] = 1
    y_train = y_full[train_indices]
    y_test = y_full[test_indices]

    # Load each layer and collect its 2-D float32 feature matrix.
    all_features = []
    for layer_idx in layer_indices:
        file_path = os.path.join(cache_root, model_name, dataset_name,
                                 "activation_" + type_layer, f"layer{layer_idx}_activations.pt")
        if not os.path.exists(file_path):
            print(f" Warning: Layer {layer_idx} non trovato. Salto.")
            continue

        print(f"  Loading layer {layer_idx}...", end=" ")
        acts = torch.load(file_path, map_location='cpu')

        if acts.shape[0] < total_samples:
            raise ValueError(
                f"{file_path} has {acts.shape[0]} rows but stats['total'] is {total_samples}"
            )
        # Extra trailing rows (if any) have no label; drop them.
        acts = acts[:total_samples]

        if isinstance(acts, torch.Tensor):
            X_layer = acts.float().numpy()
        else:
            X_layer = acts.astype(np.float32)

        if X_layer.ndim > 2:
            X_layer = X_layer.reshape(X_layer.shape[0], -1)

        all_features.append(X_layer)
        print(f"done ({X_layer.shape})")

        del acts
        gc.collect()

    if not all_features:
        raise ValueError(f"Nessun layer valido trovato per {model_name}")

    print(" Concatenating layers...")
    X_full = np.concatenate(all_features, axis=1)
    # Release the per-layer copies before fancy indexing allocates two more arrays.
    del all_features
    gc.collect()

    X_train = X_full[train_indices]
    X_test = X_full[test_indices]

    print(f" Completato! Train: {X_train.shape}, Test: {X_test.shape}")

    return X_train, X_test, y_train, y_test


def train_mlp_prober(X_train, y_train, X_val, y_val, input_dim, device, prober_config=PROBER_CONFIG):
    """Train an MLPProber with early stopping on validation accuracy.

    Args:
        X_train, X_val: float feature tensors. Batches are fed to the model as-is,
            so these are expected to already live on ``device``.
        y_train, y_val: 0/1 label tensors aligned with ``X_train`` / ``X_val``.
        input_dim: number of input features.
        device: torch device the prober is moved to.
        prober_config: hyperparameters (keys as in ``PROBER_CONFIG``).

    Returns:
        ``(prober, best_val_acc, epochs_run)``. ``prober`` carries the weights of the
        epoch with the highest validation accuracy; if accuracy never exceeded
        ``early_stopping_min_delta`` it keeps the last epoch's weights.
    """
    prober = MLPProber(
        input_dim=input_dim,
        hidden_dim=prober_config['hidden_dim'],
        dropout=prober_config['dropout']
    ).to(device)

    # Re-weight the positive class by n_neg / n_pos to counter class imbalance.
    if prober_config['use_class_weights']:
        n_pos = float(y_train.sum())
        n_neg = float(len(y_train)) - n_pos
        # With no positives the ratio is a division by zero (inf weight -> NaN loss); fall back to 1.
        ratio = n_neg / n_pos if n_pos > 0 else 1.0
        pos_weight = torch.tensor([ratio], dtype=torch.float32, device=device)
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    else:
        criterion = nn.BCEWithLogitsLoss()

    optimizer = optim.AdamW(
        prober.parameters(),
        lr=prober_config['learning_rate'],
        weight_decay=prober_config['weight_decay']
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=prober_config['max_epochs'])

    train_loader = DataLoader(ClassificationDataset(X_train, y_train),
                              batch_size=prober_config['batch_size'], shuffle=True, num_workers=0)
    val_loader = DataLoader(ClassificationDataset(X_val, y_val),
                            batch_size=prober_config['batch_size'], shuffle=False, num_workers=0)

    best_val_acc = 0.0
    patience_counter = 0
    best_model_state = None
    epochs_run = 0

    for epoch in range(prober_config['max_epochs']):
        epochs_run = epoch + 1

        # Training
        prober.train()
        epoch_loss = 0.0
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            logits = prober(X_batch)
            loss = criterion(logits, y_batch.float())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                prober.parameters(),
                max_norm=prober_config['gradient_clip_max_norm']
            )
            optimizer.step()
            epoch_loss += loss.item()

        avg_train_loss = epoch_loss / len(train_loader)

        # Validation
        prober.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                all_preds.extend(prober.predict(X_batch).cpu().numpy())
                all_labels.extend(y_batch.cpu().numpy())

        val_f1 = f1_score(all_labels, all_preds, zero_division=0)
        val_acc = accuracy_score(all_labels, all_preds)

        scheduler.step()

        if epochs_run % 20 == 0:
            print(f"   Epoch {epochs_run:3d}/{prober_config['max_epochs']} | Train Loss: {avg_train_loss:.4f} | Val F1: {val_f1:.4f} | Val Acc: {val_acc:.4f}")

        # Early stopping based on accuracy
        if val_acc > best_val_acc + prober_config['early_stopping_min_delta']:
            best_val_acc = val_acc
            patience_counter = 0
            # state_dict() returns references to the live parameters (a dict .copy() is
            # shallow); clone them or the "best" snapshot follows every later update.
            best_model_state = {k: v.detach().clone() for k, v in prober.state_dict().items()}
        else:
            patience_counter += 1

        if patience_counter >= prober_config['early_stopping_patience']:
            print(f"   Early stopping at epoch {epochs_run}. Best Val ACC: {best_val_acc:.4f}")
            break

    # Restore the best-validation weights
    if best_model_state is not None:
        prober.load_state_dict(best_model_state)

    return prober, best_val_acc, epochs_run


def _binary_metrics(y_true, y_pred):
    """Accuracy / precision / recall / F1 plus a 2x2 confusion matrix for 0/1 labels.

    ``labels=[0, 1]`` keeps the matrix 2x2 even when a class is absent from both
    ``y_true`` and ``y_pred``, so callers can always index ``cm[1][1]``;
    ``zero_division=0`` returns 0 instead of warning when a denominator is empty.
    """
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
        "confusion_matrix": confusion_matrix(y_true, y_pred, labels=[0, 1]),
    }


def _shuffled_split(num_samples, val_fraction, seed=42):
    """Shuffle ``range(num_samples)`` with a private RNG and return ``(train_idx, val_idx)``.

    Produces the same permutation as ``np.random.seed(seed); np.random.shuffle(...)``
    without resetting the global NumPy RNG as a side effect.
    """
    indices = np.arange(num_samples)
    np.random.RandomState(seed).shuffle(indices)
    val_size = int(num_samples * val_fraction)
    return indices[val_size:], indices[:val_size]


def _train_alignment_network(X_B_train, X_A_train, X_B_val, X_A_val, device, alignment_config):
    """Fit an AlignmentNetwork mapping student rows (B) onto teacher rows (A).

    Uses MixedLoss, AdamW with cosine annealing, gradient clipping and early stopping
    on the validation loss. Returns ``(aligner, best_val_loss, epochs_run)`` with the
    best-validation weights loaded.
    """
    train_loader = DataLoader(AlignmentDataset(X_B_train.to(device), X_A_train.to(device)),
                              batch_size=alignment_config['batch_size'],
                              shuffle=True, num_workers=0, pin_memory=False)
    val_loader = DataLoader(AlignmentDataset(X_B_val.to(device), X_A_val.to(device)),
                            batch_size=alignment_config['batch_size'],
                            shuffle=False, num_workers=0, pin_memory=False)
    if len(val_loader) == 0:
        raise ValueError("Alignment validation split is empty; not enough training samples.")

    criterion = MixedLoss(
        alpha=alignment_config['loss_alpha'],
        beta=alignment_config['loss_beta']
    ).to(device)

    aligner = AlignmentNetwork(
        input_dim=X_B_train.shape[1],
        output_dim=X_A_train.shape[1],
        hidden_dim=alignment_config['hidden_dim'],
        dropout=alignment_config['dropout']
    ).to(device)

    optimizer = optim.AdamW(
        aligner.parameters(),
        lr=alignment_config['learning_rate'],
        weight_decay=alignment_config['weight_decay']
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=alignment_config['max_epochs'])

    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None
    epochs_run = 0

    for epoch in range(alignment_config['max_epochs']):
        epochs_run = epoch + 1

        # Training
        aligner.train()
        epoch_loss = 0.0
        for data, target in train_loader:
            optimizer.zero_grad()
            loss = criterion(aligner(data), target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                aligner.parameters(),
                max_norm=alignment_config['gradient_clip_max_norm']
            )
            optimizer.step()
            epoch_loss += loss.item()
        avg_train_loss = epoch_loss / len(train_loader)

        # Validation
        aligner.eval()
        val_loss = 0.0
        with torch.no_grad():
            for data, target in val_loader:
                val_loss += criterion(aligner(data), target).item()
        avg_val_loss = val_loss / len(val_loader)
        scheduler.step()

        if epochs_run % 50 == 0:
            print(f"   Epoch {epochs_run:3d}/{alignment_config['max_epochs']} | Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f}")

        # Early stopping check
        if avg_val_loss < best_val_loss - alignment_config['early_stopping_min_delta']:
            best_val_loss = avg_val_loss
            patience_counter = 0
            # state_dict() returns live references; without cloning, later optimizer
            # steps overwrite the "best" snapshot and reloading it is a no-op.
            best_model_state = {k: v.detach().clone() for k, v in aligner.state_dict().items()}
        else:
            patience_counter += 1

        if patience_counter >= alignment_config['early_stopping_patience']:
            print(f"   Early stopping triggered at epoch {epochs_run}. Best Val Loss: {best_val_loss:.6f}")
            break

    if best_model_state is not None:
        aligner.load_state_dict(best_model_state)

    return aligner, best_val_loss, epochs_run


def run_experiment_pipeline_cached(X_teacher, y_teacher, teacher_name,
                                   X_student, y_student, student_name,
                                   layer_type, config_name,
                                   alignment_config=ALIGNMENT_CONFIG,
                                   prober_config=PROBER_CONFIG):
    """Teacher probing -> student-to-teacher alignment -> cross-model evaluation for one layer type.

    Args:
        X_teacher, X_student: dicts with ``'X_train'`` / ``'X_test'`` float arrays. Rows of
            teacher and student are assumed to refer to the same samples in the same order
            (the alignment pairs them by index).
        y_teacher, y_student: dicts with ``'y_train'`` / ``'y_test'`` 0/1 label arrays
            (only the teacher's train/test and the student's test labels are used).
        teacher_name, student_name: model names used in logs and file names.
        layer_type: layer family, used in logs and as the save sub-directory.
        config_name: prefix for the saved model files.
        alignment_config, prober_config: hyperparameter dicts.

    Side effects:
        Saves the alignment network and the teacher prober under
        ``alignment_models/<layer_type>/``.

    Returns:
        Nested dict with model info, teacher test metrics and cross-model
        (student projected into teacher space) test metrics.
    """
    print(f"\n{'='*60}")
    print(f"EXPERIMENT: {layer_type.upper()} → {teacher_name} ← {student_name}")
    print(f"{'='*60}")

    X_A_train_full, X_A_test = X_teacher['X_train'], X_teacher['X_test']
    y_A_train_full, y_A_test = y_teacher['y_train'], y_teacher['y_test']
    X_B_train_full, X_B_test = X_student['X_train'], X_student['X_test']
    y_B_test = y_student['y_test']

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --------------------------------------------------
    # 1. Teacher MLP probing (internal 85/15 split for early stopping)
    # --------------------------------------------------
    print("1. Training teacher MLP prober...")
    prober_train_idx, prober_val_idx = _shuffled_split(len(X_A_train_full), 0.15)

    X_A_prober_train = torch.from_numpy(X_A_train_full[prober_train_idx]).float().to(device)
    y_A_prober_train = torch.from_numpy(y_A_train_full[prober_train_idx]).long().to(device)
    X_A_prober_val = torch.from_numpy(X_A_train_full[prober_val_idx]).float().to(device)
    y_A_prober_val = torch.from_numpy(y_A_train_full[prober_val_idx]).long().to(device)

    probe_teacher, best_prober_acc, prober_epochs = train_mlp_prober(
        X_A_prober_train, y_A_prober_train,
        X_A_prober_val, y_A_prober_val,
        input_dim=X_A_train_full.shape[1],
        device=device,
        prober_config=prober_config
    )
    # The prober early-stops on accuracy, so report accuracy (not F1).
    print(f"   Best prober validation Acc: {best_prober_acc:.4f}")

    # Teacher metrics on its own test set
    probe_teacher.eval()
    X_A_test_t = torch.from_numpy(X_A_test).float().to(device)
    y_pred_teacher = probe_teacher.predict(X_A_test_t).cpu().numpy()
    teacher_metrics = _binary_metrics(y_A_test, y_pred_teacher)
    print(f"   Teacher Test Acc: {teacher_metrics['accuracy']:.4f}, F1: {teacher_metrics['f1']:.4f}")

    # --------------------------------------------------
    # 2. Alignment training
    # --------------------------------------------------
    print("2. Training alignment network (with 90/10 validation split)...")
    X_A_train_full_t = torch.from_numpy(X_A_train_full).float()
    X_B_train_full_t = torch.from_numpy(X_B_train_full).float()
    # Same index sets for student and teacher keep the (B, A) row pairs aligned.
    align_train_idx, align_val_idx = _shuffled_split(len(X_B_train_full), 0.1)

    aligner, best_val_loss, align_epochs = _train_alignment_network(
        X_B_train_full_t[align_train_idx], X_A_train_full_t[align_train_idx],
        X_B_train_full_t[align_val_idx], X_A_train_full_t[align_val_idx],
        device, alignment_config
    )

    input_dim = X_B_train_full.shape[1]
    output_dim = X_A_train_full.shape[1]

    # Save alignment network
    model_save_dir = os.path.join("alignment_models", layer_type)
    os.makedirs(model_save_dir, exist_ok=True)
    model_filename = os.path.join(model_save_dir, f"{config_name}_aligner_{student_name}_to_{teacher_name}.pt")
    torch.save({
        'model_state_dict': aligner.state_dict(),
        'alignment_config': alignment_config,
        'input_dim': input_dim,
        'output_dim': output_dim,
        'best_val_loss': best_val_loss,
        'epochs_trained': align_epochs,
        'layer_type': layer_type,
        'student_model': student_name,
        'teacher_model': teacher_name,
    }, model_filename)
    print(f"   ✓ Alignment network saved: {model_filename}")

    # Save MLP prober
    prober_filename = os.path.join(model_save_dir, f"{config_name}_mlp_prober_{teacher_name}.pt")
    torch.save({
        'model_state_dict': probe_teacher.state_dict(),
        'prober_config': prober_config,
        'input_dim': output_dim,
        'best_val_acc': best_prober_acc,
        'epochs_trained': prober_epochs,
        'layer_type': layer_type,
        'teacher_model': teacher_name,
    }, prober_filename)
    print(f"   ✓ MLP prober saved: {prober_filename}")

    # --------------------------------------------------
    # 3. Evaluation: student test set projected into teacher space
    # --------------------------------------------------
    print("3. Projecting student test set & evaluating...")
    aligner.eval()
    with torch.no_grad():
        X_B_projected = aligner(torch.from_numpy(X_B_test).float().to(device))

    y_pred_cross = probe_teacher.predict(X_B_projected).cpu().numpy()
    cross_metrics = _binary_metrics(y_B_test, y_pred_cross)

    acc_teacher, f1_teacher = teacher_metrics['accuracy'], teacher_metrics['f1']
    acc_cross, f1_cross = cross_metrics['accuracy'], cross_metrics['f1']
    print(f"\nFINAL RESULT:")
    print(f"   Teacher Acc         : {acc_teacher:.4f}, F1: {f1_teacher:.4f}")
    print(f"   Student → Teacher Acc: {acc_cross:.4f}, F1: {f1_cross:.4f}")
    print(f"   Transfer gap (Acc)  : {acc_teacher - acc_cross:.4f}")
    print(f"   Transfer gap (F1)   : {f1_teacher - f1_cross:.4f}")

    def _as_report(metrics):
        # Same keys as the metric dict, with the confusion matrix as nested lists.
        return {
            "accuracy": metrics['accuracy'],
            "precision": metrics['precision'],
            "recall": metrics['recall'],
            "f1": metrics['f1'],
            "confusion_matrix": metrics['confusion_matrix'].tolist(),
        }

    return {
        "type": layer_type,
        "teacher_name": teacher_name,
        "student_name": student_name,
        "alignment_model": {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "config": alignment_config,
            "best_val_loss": float(best_val_loss),
            "epochs_trained": align_epochs,
            "model_path": model_filename
        },
        "prober_model": {
            "input_dim": output_dim,
            "config": prober_config,
            "best_val_acc": float(best_prober_acc),
            "epochs_trained": prober_epochs,
            "model_path": prober_filename
        },
        "teacher": _as_report(teacher_metrics),
        "student_on_teacher": _as_report(cross_metrics),
    }


def plot_confusion_matrix(cm, layer_type, model_name="", save_dir="confusion_matrices"):
    """Render a confusion matrix as an annotated heatmap and save it as a PNG.

    Args:
        cm: 2x2 integer confusion matrix; rows are the true label and columns the
            predicted label, index 0 = non-hallucinated, 1 = hallucinated.
        layer_type: layer family, shown (upper-cased) in the title and used in the file name.
        model_name: optional tag appended to the title and the file name.
        save_dir: output directory, created if missing.

    The figure is closed after saving so repeated calls do not accumulate open figures.
    """
    os.makedirs(save_dir, exist_ok=True)

    class_names = ['Non-Hallucinated', 'Hallucinated']
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True, ax=ax,
                xticklabels=class_names, yticklabels=class_names)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')

    title_suffix = f' ({model_name})' if model_name else ''
    ax.set_title(f'Confusion Matrix - {layer_type.upper()} Layers{title_suffix}')

    plt.tight_layout()
    stem = f'confusion_matrix_{layer_type}_{model_name}' if model_name else f'confusion_matrix_{layer_type}'
    filename = os.path.join(save_dir, f'{stem}.png')
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"   ✓ Salvato: {filename}")

In [19]:

print("="*80)
print("FASE 1: PRE-CARICAMENTO E SPLITTING DEI DATI (stessi indici shuffled per TUTTI i layer type)")
print("="*80 + "\n")

# One seeded 70/30 split is shared by every layer type and by both models, which
# assumes both caches store the same samples in the same order. At minimum the
# sample counts must agree, otherwise the shared indices are meaningless.
if qwen_stats['total'] != falcon_stats['total']:
    raise ValueError(
        f"Sample count mismatch: Qwen2.5-7B has {qwen_stats['total']}, "
        f"Falcon3-7B-Base has {falcon_stats['total']}"
    )
n_samples = qwen_stats['total']
rng = np.random.RandomState(42)
shuffled_indices = rng.permutation(n_samples)
split_idx = int(0.7 * n_samples)

train_indices = shuffled_indices[:split_idx]
test_indices = shuffled_indices[split_idx:]

scenarios = [
    {"teacher_model": "Qwen2.5-7B", "student_model": "Falcon3-7B-Base"},
    {"teacher_model": "Falcon3-7B-Base", "student_model": "Qwen2.5-7B"}
]
# Model name -> key of its data inside `current_data` below.
model_data_keys = {"Qwen2.5-7B": "qwen", "Falcon3-7B-Base": "falcon"}

# Scenario index -> list of per-layer-type results.
scenario_results_map = {i: [] for i in range(len(scenarios))}

for layer_type in ['attn', 'mlp', 'hidden']:
    print(f"\n{'='*40}")
    print(f"PROCESSING LAYER TYPE: {layer_type.upper()}")
    print(f"{'='*40}")
    gc.collect()

    try:
        X_qwen_train, X_qwen_test, y_qwen_train, y_qwen_test = load_and_split_layers(
            "Qwen2.5-7B", "belief_bank",
            LAYER_CONFIG["Qwen2.5-7B"][layer_type],
            layer_type, qwen_stats,
            train_indices, test_indices
        )

        X_falcon_train, X_falcon_test, y_falcon_train, y_falcon_test = load_and_split_layers(
            "Falcon3-7B-Base", "belief_bank",
            LAYER_CONFIG["Falcon3-7B-Base"][layer_type],
            layer_type, falcon_stats,
            train_indices, test_indices
        )

        # Standardise each model separately; scalers are fit on the train split only.
        print("   Normalizzazione dati...")
        scaler_qwen = StandardScaler()
        X_qwen_train = scaler_qwen.fit_transform(X_qwen_train)
        X_qwen_test = scaler_qwen.transform(X_qwen_test)

        scaler_falcon = StandardScaler()
        X_falcon_train = scaler_falcon.fit_transform(X_falcon_train)
        X_falcon_test = scaler_falcon.transform(X_falcon_test)

        current_data = {
            "qwen": {"X_train": X_qwen_train, "X_test": X_qwen_test, "y_train": y_qwen_train, "y_test": y_qwen_test},
            "falcon": {"X_train": X_falcon_train, "X_test": X_falcon_test, "y_train": y_falcon_train, "y_test": y_falcon_test}
        }

        for i, scenario in enumerate(scenarios):
            teacher_name = scenario['teacher_model']
            student_name = scenario['student_model']
            print(f"\n   --- Scenario: {teacher_name} -> {student_name} ---")

            X_teacher_data = current_data[model_data_keys[teacher_name]]
            X_student_data = current_data[model_data_keys[student_name]]

            # Each data dict holds both features and labels, so it is passed as X and as y.
            res = run_experiment_pipeline_cached(
                X_teacher_data, X_teacher_data, teacher_name,
                X_student_data, X_student_data, student_name,
                layer_type, "CONFIG1"
            )
            scenario_results_map[i].append(res)

            teacher_tag = teacher_name.split('.')[0]
            student_tag = student_name.split('.')[0]
            plot_confusion_matrix(
                np.array(res['teacher']['confusion_matrix']),
                layer_type,
                f"Teacher_{teacher_tag}"
            )
            plot_confusion_matrix(
                np.array(res['student_on_teacher']['confusion_matrix']),
                layer_type,
                f"{student_tag}_on_{teacher_tag}"
            )

        # X_teacher_data / X_student_data still reference the big arrays; drop them too,
        # otherwise the memory is not actually released until the next iteration.
        del current_data, X_teacher_data, X_student_data
        del X_qwen_train, X_qwen_test, X_falcon_train, X_falcon_test
        del scaler_qwen, scaler_falcon
        gc.collect()
        torch.cuda.empty_cache()
        print(f"   Memoria liberata per {layer_type}.")

    except Exception as e:
        print(f"Errore critico nel layer {layer_type}: {e}")
        # Re-raise instead of exit(1): inside IPython/Jupyter exit() is an autocall
        # helper rather than sys.exit, so it is not a reliable way to abort the cell.
        raise

# Shared keys of the alignment and prober training-hyperparameter sections, in report order.
_TRAINING_HPARAM_KEYS = (
    "optimizer", "learning_rate", "weight_decay", "batch_size", "max_epochs",
    "scheduler", "gradient_clip_max_norm", "early_stopping_patience", "early_stopping_min_delta",
)


def _confusion_counts(cm):
    """Unpack a 2x2 confusion matrix (rows = true, cols = predicted; 0 = negative) into TN/FP/FN/TP."""
    return {
        "TN": int(cm[0][0]),
        "FP": int(cm[0][1]),
        "FN": int(cm[1][0]),
        "TP": int(cm[1][1]),
    }


def _rounded_metrics(metrics):
    """Round a 'teacher' / 'student_on_teacher' metrics dict to 4 decimals for the report."""
    return {
        "accuracy": round(metrics['accuracy'], 4),
        "precision": round(metrics['precision'], 4),
        "recall": round(metrics['recall'], 4),
        "f1_score": round(metrics['f1'], 4),
        "confusion_matrix": _confusion_counts(metrics['confusion_matrix']),
    }


def _serialize_result(r):
    """Flatten one run_experiment_pipeline_cached result into the JSON report schema."""
    align_cfg = r['alignment_model']['config']
    prober_cfg = r['prober_model']['config']
    return {
        "layer_type": r['type'],
        "teacher_model": r['teacher_name'],
        "student_model": r['student_name'],
        "alignment_model_info": {
            "input_dim": r['alignment_model']['input_dim'],
            "output_dim": r['alignment_model']['output_dim'],
            "hidden_dim": align_cfg['hidden_dim'],
            "dropout": align_cfg['dropout']
        },
        "alignment_training_hyperparameters": {k: align_cfg[k] for k in _TRAINING_HPARAM_KEYS},
        "alignment_loss_function": {
            "type": "MixedLoss",
            "mse_weight": align_cfg['loss_alpha'],
            "cosine_weight": align_cfg['loss_beta']
        },
        "alignment_training_results": {
            "best_val_loss": round(r['alignment_model']['best_val_loss'], 6),
            "epochs_trained": r['alignment_model']['epochs_trained'],
            "model_saved_path": r['alignment_model']['model_path']
        },
        "prober_model_info": {
            "input_dim": r['prober_model']['input_dim'],
            "hidden_dim": prober_cfg['hidden_dim'],
            "dropout": prober_cfg['dropout']
        },
        "prober_training_hyperparameters": {
            **{k: prober_cfg[k] for k in _TRAINING_HPARAM_KEYS},
            "loss_function": prober_cfg['loss_function'],
            "use_class_weights": prober_cfg['use_class_weights']
        },
        "prober_training_results": {
            "best_val_acc": round(r['prober_model']['best_val_acc'], 4),
            "epochs_trained": r['prober_model']['epochs_trained'],
            "model_saved_path": r['prober_model']['model_path']
        },
        "teacher": _rounded_metrics(r['teacher']),
        "student_on_teacher": _rounded_metrics(r['student_on_teacher']),
    }


# Rebuild the per-scenario all_results structure
all_results = [
    {
        "scenario": f"{scenario['teacher_model']} (teacher) → {scenario['student_model']} (student)",
        "results": scenario_results_map[i]
    }
    for i, scenario in enumerate(scenarios)
]

# Save the complete JSON report
os.makedirs("results_metrics", exist_ok=True)
metrics_file = "results_metrics/experiment_results_all_scenarios_mlp_prober.json"

all_results_json = [
    {
        "scenario": scenario_data['scenario'],
        "results": [_serialize_result(r) for r in scenario_data['results']]
    }
    for scenario_data in all_results
]

with open(metrics_file, 'w', encoding='utf-8') as f:
    json.dump(all_results_json, f, indent=2)

print(f"\n✓ Risultati salvati in: {metrics_file}")

FASE 1: PRE-CARICAMENTO E SPLITTING DEI DATI (stessi indici shuffled per TUTTI i layer type)


PROCESSING LAYER TYPE: ATTN
 Caricamento IN-MEMORY Qwen2.5-7B [attn]: layers [15, 16, 18]...
  Loading layer 15... done ((27416, 3584))
  Loading layer 16... done ((27416, 3584))
  Loading layer 18... done ((27416, 3584))
 Concatenating layers...
 Completato! Train: (19191, 10752), Test: (8225, 10752)
 Caricamento IN-MEMORY Falcon3-7B-Base [attn]: layers [2, 7, 12]...
  Loading layer 2... done ((27416, 3072))
  Loading layer 7... done ((27416, 3072))
  Loading layer 12... done ((27416, 3072))
 Concatenating layers...
 Completato! Train: (19191, 9216), Test: (8225, 9216)
   Normalizzazione dati...

   --- Scenario: Qwen2.5-7B -> Falcon3-7B-Base ---

EXPERIMENT: ATTN → Qwen2.5-7B ← Falcon3-7B-Base
Using device: cuda
1. Training teacher MLP prober...
   Epoch  20/200 | Train Loss: 0.0339 | Val F1: 0.9894 | Val Acc: 0.9878
   Epoch  40/200 | Train Loss: 0.0192 | Val F1: 0.9933 | Val Acc: 0.9924
 

In [20]:
# Display the path of the JSON metrics report written by the experiment cell.
metrics_file

'results_metrics/experiment_results_all_scenarios_mlp_prober.json'