# Hybrid: Non-linear adaptation with AdapterMLP

In this notebook a first non-linear approach is tested. We take the cached activations of both LLMs (from the layers listed in `LAYER_CONFIG`) and train three classifiers on the Teacher model, one per layer type (attention, MLP, hidden state). An AdapterMLP (implemented below as `AlignmentNetwork`) then learns to map the Student's latent space onto the Teacher's. Finally, the adapted Student activations are evaluated with the Teacher classifiers.

In [14]:
import json
import os
import numpy as np
import matplotlib.pyplot as plt
import gc
import seaborn as sns
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix, roc_auc_score
import traceback
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
import random

# ==================================================================
# DEVICE CONFIGURATION
# ==================================================================
# Prefer the originally configured GPU ("cuda:2"). On hosts with fewer GPUs,
# fall back to the default CUDA device, then to CPU. Without this, the
# notebook fails at the first `.to(DEVICE)` call.
if torch.cuda.device_count() > 2:
    DEVICE = torch.device("cuda:2")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# ==================================================================
# REPRODUCIBILITY SETTINGS
# ==================================================================
SEED = 42


def set_seed(seed=SEED):
    """Seed every RNG used in the notebook and force deterministic cuDNN.

    Covers Python's ``random``, NumPy's global RNG, and torch (CPU, current
    CUDA device and all CUDA devices). It also exports ``PYTHONHASHSEED``.

    NOTE: PYTHONHASHSEED is read at interpreter start-up. Setting it here only
    affects child processes, not the running kernel.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU
    # Trade cuDNN autotuning speed for bit-reproducible kernels.
    torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False


def get_generator(seed=SEED):
    """Return a fresh CPU ``torch.Generator`` seeded with ``seed``.

    Intended for ``DataLoader(generator=...)``: the shuffling order then
    depends only on this generator, not on the global torch RNG.
    """
    return torch.Generator().manual_seed(seed)


In [15]:
# Repository root: two directories above the kernel's working directory.
# This depends on where the kernel was started, not on the notebook file's location.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.getcwd()))
CACHE_DIR_NAME = "activation_cache"
# Build the fallback with os.path.join so the separator matches the host OS.
# A hard-coded "~\\.cache\\..." string is a single odd filename on Linux/macOS.
HF_DEFAULT_HOME = os.environ.get("HF_HOME", os.path.join("~", ".cache", "huggingface", "hub"))

# Model names (used as constants throughout the notebook)
MODEL_A = "gemma-2-9b-it"
MODEL_B = "Llama-3.1-8B-Instruct"

# Per model and per layer type: indices of the layers whose activations are
# loaded and concatenated into one feature vector.
LAYER_CONFIG = {
    MODEL_A: 
    {
        "attn": [21,24,27],
        "mlp":[22,25,27],
        "hidden": [23,26,34]
    },    
    MODEL_B: 
    {
        "attn": [8,13,14],
        "mlp":[14,15,21],
        "hidden": [14,15,16]
    }  
}
DATASET_NAME = "belief_bank_facts"


# Hyper-parameters for training the AlignmentNetwork.
# NOTE: "optimizer" and "scheduler" are recorded for reference only. The
# training code instantiates AdamW / CosineAnnealingLR directly and does not
# read these two entries.
ALIGNMENT_CONFIG = {
    "hidden_dim": 128,
    "dropout": 0.5,
    "learning_rate": 1e-3,
    "weight_decay": 0.1,
    "batch_size": 32,
    "max_epochs": 1000,
    "early_stopping_patience": 50,
    "early_stopping_min_delta": 1e-4,
    "gradient_clip_max_norm": 1.0,
    "optimizer": "AdamW",
    "scheduler": "CosineAnnealingLR",
    "loss_alpha": 0.01,  # MSE weight
    "loss_beta": 1.0     # Cosine weight
}


# Settings for the teacher LogisticRegression probes ("type" is descriptive).
PROBE_CONFIG = {
    "type": "LogisticRegression",
    "max_iter": 1000,
    "class_weight": "balanced",
    "solver": "lbfgs",
    "n_jobs": -1
}

### Dataset preparation

In [16]:
def stats_per_json(model_name, dataset_name):
    """Summarise hallucination labels stored in the legacy single-file layout.

    Reads ``<PROJECT_ROOT>/<CACHE_DIR_NAME>/<model_name>/<dataset_name>/
    generations/hallucination_labels.json``. The file is expected to hold a
    list of records, each with ``instance_id`` and ``is_hallucination`` keys.

    Returns:
        dict with the total record count, the number and percentage of
        hallucinated records, the list of hallucinated instance ids, and the
        model / dataset names.
    """
    labels_file = os.path.join(PROJECT_ROOT, CACHE_DIR_NAME, model_name, dataset_name,
                               "generations", "hallucination_labels.json")
    with open(labels_file, 'r') as handle:
        records = json.load(handle)

    flagged_ids = [rec['instance_id'] for rec in records if rec['is_hallucination']]
    total = len(records)
    n_flagged = len(flagged_ids)
    return {
        'total': total,
        'hallucinations': n_flagged,
        'percent_hallucinations': (n_flagged / total) * 100 if total > 0 else 0,
        'hallucinated_items': flagged_ids,
        'model_name': model_name,
        'dataset_name': dataset_name,
    }

def stats_from_new_structure(model_name, dataset_name):
    """Summarise hallucination labels from the split-folder layout.

    An instance's label is given by the folder its id appears in:
    ``activation_attn/hallucinated/layer0_instance_ids.json`` or
    ``activation_attn/not_hallucinated/layer0_instance_ids.json``.
    Only these layer-0 attention id lists are read.

    Returns:
        dict with totals, per-class counts, percentage, both id lists, the
        ``hallucinated_items`` alias (same list object as ``hallucinated_ids``,
        matching the key produced by stats_per_json), and the model / dataset
        names.
    """
    attn_dir = os.path.join(PROJECT_ROOT, CACHE_DIR_NAME, model_name, dataset_name, "activation_attn")

    def _read_ids(label_dir):
        ids_file = os.path.join(attn_dir, label_dir, "layer0_instance_ids.json")
        with open(ids_file, 'r') as handle:
            return json.load(handle)

    hallucinated_ids = _read_ids("hallucinated")
    not_hallucinated_ids = _read_ids("not_hallucinated")

    n_hall = len(hallucinated_ids)
    n_not_hall = len(not_hallucinated_ids)
    total = n_hall + n_not_hall
    return {
        'total': total,
        'hallucinations': n_hall,
        'not_hallucinations': n_not_hall,
        'percent_hallucinations': (n_hall / total) * 100 if total > 0 else 0,
        'hallucinated_ids': hallucinated_ids,
        'not_hallucinated_ids': not_hallucinated_ids,
        'hallucinated_items': hallucinated_ids,  # alias for compatibility
        'model_name': model_name,
        'dataset_name': dataset_name,
    }

def detect_structure_type(model_name, dataset_name):
    """Tell which on-disk activation layout a model/dataset pair uses.

    Only the attention folder is inspected.

    Returns:
        'new' if ``activation_attn/hallucinated`` exists as a directory,
        otherwise 'old'.
    """
    marker_dir = os.path.join(PROJECT_ROOT, CACHE_DIR_NAME, model_name, dataset_name,
                              "activation_attn", "hallucinated")
    return 'new' if os.path.isdir(marker_dir) else 'old'

def get_stats(model_name, dataset_name):
    """Return label statistics, choosing the reader that matches the cache layout.

    A 'new' layout is read with stats_from_new_structure. Anything else falls
    back to stats_per_json.
    """
    is_new_layout = detect_structure_type(model_name, dataset_name) == 'new'
    reader = stats_from_new_structure if is_new_layout else stats_per_json
    return reader(model_name, dataset_name)


def get_concordant_indices_and_undersample(stats_model1, stats_model2, seed=SEED):
    """Keep the samples both models label identically, then balance the classes.

    For each model, a binary label is rebuilt for every id in
    ``0 .. total-1`` by checking membership in its ``hallucinated_items``.
    Ids outside that range are ignored. Among the concordant samples, each
    class is randomly subsampled down to the minority-class size. The union
    is then shuffled. All randomness comes from ``np.random.RandomState(seed)``.

    Args:
        stats_model1, stats_model2: dicts as returned by get_stats.
        seed: seed for the undersampling / shuffling RNG.

    Returns:
        balanced_indices: global indices of the selected samples.
        balanced_labels: int8 labels aligned with them
            (0 = non-hallucinated, 1 = hallucinated).

    Raises:
        AssertionError: if the two stats dicts report different totals.
    """
    total_samples = stats_model1['total']
    assert stats_model1['total'] == stats_model2['total'], "I due modelli devono avere lo stesso numero di campioni"

    def _hallucination_mask(stats):
        # True where the id is flagged as hallucinated by this model.
        flagged = set(stats['hallucinated_items'])
        return np.fromiter((idx in flagged for idx in range(total_samples)),
                           dtype=bool, count=total_samples)

    mask_1 = _hallucination_mask(stats_model1)
    mask_2 = _hallucination_mask(stats_model2)

    # Concordant = both models assign the same label.
    concordant_indices = np.flatnonzero(mask_1 == mask_2)
    is_hall = mask_1[concordant_indices]
    hall_concordant = concordant_indices[is_hall]
    non_hall_concordant = concordant_indices[~is_hall]
    n_hall = len(hall_concordant)
    n_non_hall = len(non_hall_concordant)

    print(f"  Campioni concordanti: {len(concordant_indices)} / {total_samples}")
    print(f"    - Hallucinated (concordanti): {n_hall}")
    print(f"    - Non-hallucinated (concordanti): {n_non_hall}")

    # Undersample the majority class to the minority-class size.
    min_count = min(n_hall, n_non_hall)
    rng = np.random.RandomState(seed)
    # Draw order (hallucinated first, then non-hallucinated, then the
    # permutation) fixes the result for a given seed.
    picked = np.concatenate([
        rng.choice(hall_concordant, size=min_count, replace=False),
        rng.choice(non_hall_concordant, size=min_count, replace=False),
    ])
    picked_labels = np.concatenate([
        np.ones(min_count, dtype=np.int8),
        np.zeros(min_count, dtype=np.int8),
    ])

    order = rng.permutation(len(picked))
    balanced_indices = picked[order]
    balanced_labels = picked_labels[order]

    print(f"  Dopo undersampling: {len(balanced_indices)} campioni bilanciati ({min_count} per classe)")

    return balanced_indices, balanced_labels


# Load per-model label statistics. Then keep only the samples both models
# agree on, and undersample them to balanced classes.
model_a_stats = get_stats(MODEL_A, DATASET_NAME)
model_b_stats = get_stats(MODEL_B, DATASET_NAME)

_rule = "=" * 60
print("\n" + _rule)
print("ANALISI CONCORDANZA E UNDERSAMPLING")
print(_rule)
for _model, _stats in ((MODEL_A, model_a_stats), (MODEL_B, model_b_stats)):
    print(f"{_model} totali: {_stats['total']}, hallucinated: {_stats['hallucinations']}")
print()

balanced_indices, balanced_labels = get_concordant_indices_and_undersample(
    model_a_stats, model_b_stats, seed=SEED
)


ANALISI CONCORDANZA E UNDERSAMPLING
gemma-2-9b-it totali: 27416, hallucinated: 802
Llama-3.1-8B-Instruct totali: 27416, hallucinated: 1799

  Campioni concordanti: 25749 / 27416
    - Hallucinated (concordanti): 467
    - Non-hallucinated (concordanti): 25282
  Dopo undersampling: 934 campioni bilanciati (467 per classe)


In [17]:
# ------------------------------------------------------------------
# 1. Dataset class
# ------------------------------------------------------------------
class AlignmentDataset(Dataset):
    """Paired (source, target) rows for training the alignment network.

    Both arguments are expected to be already-built ``torch.Tensor`` objects
    whose first dimension indexes samples. Row ``i`` of ``x_source`` is paired
    with row ``i`` of ``x_target``. No copy or device transfer happens here.
    """

    def __init__(self, x_source: torch.Tensor, x_target: torch.Tensor):
        self.x_source, self.x_target = x_source, x_target

    def __len__(self):
        # The source tensor defines the dataset length.
        return self.x_source.size(0)

    def __getitem__(self, idx):
        source_row = self.x_source[idx]
        target_row = self.x_target[idx]
        return source_row, target_row

# ------------------------------------------------------------------
# 2. AlignmentNetwork
# ------------------------------------------------------------------
class AlignmentNetwork(nn.Module):
    """Residual student→teacher adapter.

    A bias-free linear projection maps the input to ``output_dim``, or an
    identity is used when the widths already match. A bottleneck MLP
    (Linear → LayerNorm → GELU → Dropout → Linear → Dropout) adds a
    non-linear correction on top. The correction's last Linear is
    zero-initialised, so a freshly built module computes exactly the
    projection.

    Args:
        input_dim: width of the source (student) features.
        output_dim: width of the target (teacher) features.
        hidden_dim: bottleneck width of the residual MLP.
        dropout: dropout probability inside the residual branch.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 128, dropout: float = 0.5):
        super().__init__()

        # Built before the residual branch so parameter init consumes the RNG
        # in a fixed order.
        self.input_proj = (
            nn.Linear(input_dim, output_dim, bias=False)
            if input_dim != output_dim
            else nn.Identity()
        )

        bottleneck = [
            nn.Linear(output_dim, hidden_dim),   # compress to the bottleneck
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),   # expand back
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*bottleneck)

        self._init_zero()

    def _init_zero(self):
        # Zero the final Linear of the residual branch so it starts as a no-op.
        final_linear = [layer for layer in self.net if isinstance(layer, nn.Linear)][-1]
        nn.init.zeros_(final_linear.weight)
        if final_linear.bias is not None:
            nn.init.zeros_(final_linear.bias)

    def forward(self, x):
        projected = self.input_proj(x)
        correction = self.net(projected)
        return projected + correction


class MixedLoss(nn.Module):
    """Weighted sum of an MSE term and a cosine-distance term.

    loss = alpha * MSE(pred, target) + beta * (1 - mean_i cos(pred_i, target_i))

    Cosine similarity is computed row-wise (``dim=1``) and averaged over the
    batch. The MSE uses ``nn.MSELoss``'s default mean reduction.

    Args:
        alpha: weight of the MSE term.
        beta: weight of the cosine term.
    """

    def __init__(self, alpha=0.01, beta=1.0):
        super().__init__()
        self.alpha, self.beta = alpha, beta
        self.mse = nn.MSELoss()

    def forward(self, pred, target):
        mse_term = self.mse(pred, target)
        cosine_term = 1 - F.cosine_similarity(pred, target, dim=1).mean()
        return self.alpha * mse_term + self.beta * cosine_term

In [18]:
def _activations_to_2d_float32(acts):
    """Convert a loaded activation object (torch.Tensor or array-like with
    ``astype``) to a float32 numpy array, flattening every dim after the first."""
    if isinstance(acts, torch.Tensor):
        X = acts.float().numpy()
    else:
        X = acts.astype(np.float32)
    if X.ndim > 2:
        X = X.reshape(X.shape[0], -1)
    return X


def _load_layer_split_layout(base_path, layer_idx):
    """Load one layer from the ``hallucinated/`` + ``not_hallucinated/`` layout.

    Rows are scattered back to global order using the matching
    ``layer{layer_idx}_instance_ids.json`` files. The ids are used directly as
    row positions in an array of ``len(hall_ids) + len(not_hall_ids)`` rows,
    so together they are expected to be a permutation of ``0 .. N-1``.

    Returns:
        float32 array (N, feature_dim), or None if either activation file is missing.

    Raises:
        ValueError: if an id list and its activation file disagree in length.
    """
    hall_dir = os.path.join(base_path, "hallucinated")
    not_hall_dir = os.path.join(base_path, "not_hallucinated")
    hall_path = os.path.join(hall_dir, f"layer{layer_idx}_activations.pt")
    not_hall_path = os.path.join(not_hall_dir, f"layer{layer_idx}_activations.pt")
    if not os.path.exists(hall_path) or not os.path.exists(not_hall_path):
        return None

    print(f"  Loading layer {layer_idx} (new structure)...", end=" ")

    # NOTE(review): torch.load unpickles its input — only point this at trusted caches.
    X_hall = _activations_to_2d_float32(torch.load(hall_path, map_location='cpu'))
    X_not_hall = _activations_to_2d_float32(torch.load(not_hall_path, map_location='cpu'))

    with open(os.path.join(hall_dir, f"layer{layer_idx}_instance_ids.json"), 'r') as f:
        hall_ids = json.load(f)
    with open(os.path.join(not_hall_dir, f"layer{layer_idx}_instance_ids.json"), 'r') as f:
        not_hall_ids = json.load(f)

    # A mismatch would silently mis-align rows and labels, so fail loudly.
    for ids, X, label_dir in ((hall_ids, X_hall, "hallucinated"),
                              (not_hall_ids, X_not_hall, "not_hallucinated")):
        if len(ids) != X.shape[0]:
            raise ValueError(
                f"Layer {layer_idx} [{label_dir}]: {len(ids)} instance ids "
                f"but {X.shape[0]} activation rows"
            )

    X_layer = np.zeros((len(hall_ids) + len(not_hall_ids), X_hall.shape[1]), dtype=np.float32)
    # Vectorised scatter back to instance-id order (row i == instance id i).
    X_layer[np.asarray(hall_ids, dtype=np.int64)] = X_hall
    X_layer[np.asarray(not_hall_ids, dtype=np.int64)] = X_not_hall
    return X_layer


def _load_layer_flat_layout(base_path, layer_idx):
    """Load one layer from the old layout (single file per layer), or return
    None if the file is missing."""
    file_path = os.path.join(base_path, f"layer{layer_idx}_activations.pt")
    if not os.path.exists(file_path):
        return None

    print(f"  Loading layer {layer_idx} (old structure)...", end=" ")
    # NOTE(review): torch.load unpickles its input — only point this at trusted caches.
    return _activations_to_2d_float32(torch.load(file_path, map_location='cpu'))


def load_and_split_layers(model_name, dataset_name, layer_indices, type_layer, 
                          balanced_indices, balanced_labels, train_indices, test_indices):
    """Load several layers' activations into RAM, keep the balanced samples, and split them.

    Supports both cache layouts:
      * old: ``layer{i}_activations.pt`` directly inside ``activation_<type_layer>/``;
      * new: separate ``hallucinated/`` and ``not_hallucinated/`` folders with
        matching ``layer{i}_instance_ids.json`` files.
    Each layer is converted to a 2-D float32 matrix. The layers are
    concatenated along the feature axis in the order of ``layer_indices``.
    Missing layers are skipped with a warning.

    Args:
        model_name, dataset_name: select the cache folder.
        layer_indices: layers to load and concatenate.
        type_layer: layer family suffix ('attn', 'mlp', 'hidden', ...).
        balanced_indices: global indices of the concordant, balanced samples.
        balanced_labels: labels aligned with ``balanced_indices`` (shared by the models).
        train_indices: LOCAL positions (0..len(balanced_indices)-1) used for training.
        test_indices: LOCAL positions (0..len(balanced_indices)-1) used for testing.

    Returns:
        X_train, X_test, y_train, y_test

    Raises:
        ValueError: if no requested layer could be loaded.
    """
    print(f" Caricamento IN-MEMORY {model_name} [{type_layer}]: layers {layer_indices}...")

    structure_type = detect_structure_type(model_name, dataset_name)
    print(f"  Struttura rilevata: {structure_type}")

    # The layer-type folder does not depend on the layer index.
    base_path = os.path.join(PROJECT_ROOT, CACHE_DIR_NAME, model_name, dataset_name,
                             "activation_" + type_layer)
    load_layer = _load_layer_split_layout if structure_type == 'new' else _load_layer_flat_layout

    all_features = []
    for layer_idx in layer_indices:
        X_layer = load_layer(base_path, layer_idx)
        if X_layer is None:
            print(f" Warning: Layer {layer_idx} non trovato. Salto.")
            continue

        # Keep ONLY the balanced samples (global indices). The full array is
        # released once X_layer is rebound.
        X_layer = X_layer[balanced_indices]
        all_features.append(X_layer)
        print(f"done ({X_layer.shape})")

        gc.collect()

    if not all_features:
        raise ValueError(f"Nessun layer valido trovato per {model_name}")

    print(" Concatenating layers...")
    X_balanced = np.concatenate(all_features, axis=1)

    # train_indices / test_indices are LOCAL positions into X_balanced.
    X_train = X_balanced[train_indices]
    X_test = X_balanced[test_indices]
    y_train = balanced_labels[train_indices]
    y_test = balanced_labels[test_indices]

    print(f" Completato! Train: {X_train.shape}, Test: {X_test.shape}")

    return X_train, X_test, y_train, y_test


def _binary_classification_metrics(y_true, y_pred, y_proba):
    """Metric bundle reported for a binary probe: accuracy, precision, recall,
    f1, AUROC (from ``y_proba``, the class-1 probability) and the confusion
    matrix as a nested list."""
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
        "auroc": roc_auc_score(y_true, y_proba),
        "confusion_matrix": confusion_matrix(y_true, y_pred).tolist(),
    }


def _train_alignment_network(X_src_full, X_tgt_full, alignment_config, device):
    """Train an AlignmentNetwork mapping source rows onto target rows.

    A validation split (10% of the rows, at least one) is held out for early
    stopping. At the end, the weights of the best validation epoch are
    restored.

    Args:
        X_src_full: float tensor of student training activations.
        X_tgt_full: float tensor of teacher training activations; row i pairs
            with row i of ``X_src_full``.
        alignment_config: hyper-parameter dict shaped like ALIGNMENT_CONFIG.
        device: torch device used for training.

    Returns:
        (aligner, best_val_loss, epochs_trained, epochs_run).
        ``epochs_trained`` is the 1-based epoch whose weights were kept
        (``max_epochs`` if validation never improved). ``epochs_run`` is the
        number of epochs actually executed.

    Raises:
        ValueError: if fewer than two training rows are available.
    """
    num_train = X_src_full.shape[0]
    if num_train < 2:
        raise ValueError(
            f"Need at least 2 training rows for a train/validation split, got {num_train}"
        )

    # Local RandomState: same permutation as np.random.seed(SEED) + shuffle,
    # without touching the global NumPy RNG.
    indices = np.arange(num_train)
    np.random.RandomState(SEED).shuffle(indices)
    val_size = max(1, int(num_train * 0.1))  # >= 1 so val_loader is never empty
    train_idx_local = indices[val_size:]
    val_idx_local = indices[:val_size]

    train_dataset = AlignmentDataset(X_src_full[train_idx_local].to(device),
                                     X_tgt_full[train_idx_local].to(device))
    val_dataset = AlignmentDataset(X_src_full[val_idx_local].to(device),
                                   X_tgt_full[val_idx_local].to(device))

    train_loader = DataLoader(train_dataset, batch_size=alignment_config['batch_size'],
                              shuffle=True, num_workers=0, pin_memory=False,
                              generator=get_generator(SEED))
    val_loader = DataLoader(val_dataset, batch_size=alignment_config['batch_size'],
                            shuffle=False, num_workers=0, pin_memory=False)

    criterion = MixedLoss(
        alpha=alignment_config['loss_alpha'],
        beta=alignment_config['loss_beta']
    ).to(device)

    set_seed(SEED)  # reseed right before model init so the weights are reproducible
    aligner = AlignmentNetwork(
        input_dim=X_src_full.shape[1],
        output_dim=X_tgt_full.shape[1],
        hidden_dim=alignment_config['hidden_dim'],
        dropout=alignment_config['dropout']
    ).to(device)

    optimizer = optim.AdamW(
        aligner.parameters(),
        lr=alignment_config['learning_rate'],
        weight_decay=alignment_config['weight_decay']
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=alignment_config['max_epochs']
    )

    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0
    epochs_trained = 0  # 1-based epoch of the best validation loss
    epochs_run = 0

    for epoch in range(alignment_config['max_epochs']):
        epochs_run = epoch + 1

        aligner.train()
        epoch_loss = 0.0
        for data, target in train_loader:
            optimizer.zero_grad()
            loss = criterion(aligner(data), target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                aligner.parameters(),
                max_norm=alignment_config['gradient_clip_max_norm']
            )
            optimizer.step()
            epoch_loss += loss.item()
        avg_train_loss = epoch_loss / len(train_loader)

        aligner.eval()
        val_loss = 0.0
        with torch.no_grad():
            for data, target in val_loader:
                val_loss += criterion(aligner(data), target).item()
        avg_val_loss = val_loss / len(val_loader)

        scheduler.step()

        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"   Epoch {epoch+1:3d}/{alignment_config['max_epochs']} | Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f}")

        if avg_val_loss < best_val_loss - alignment_config['early_stopping_min_delta']:
            best_val_loss = avg_val_loss
            patience_counter = 0
            # state_dict() returns references to the live parameters. Clone
            # them, otherwise later optimizer steps overwrite the "best"
            # snapshot and the final model gets reloaded instead.
            best_model_state = {name: tensor.detach().clone()
                                for name, tensor in aligner.state_dict().items()}
            epochs_trained = epoch + 1
        else:
            patience_counter += 1

        if patience_counter >= alignment_config['early_stopping_patience']:
            print(f"   Early stopping at epoch {epoch+1}. Best Val Loss: {best_val_loss:.6f}")
            break

    # Validation never improved (e.g. NaN losses or max_epochs == 0): keep the
    # historical fallback value.
    if epochs_trained == 0:
        epochs_trained = alignment_config['max_epochs']

    if best_model_state is not None:
        aligner.load_state_dict(best_model_state)

    return aligner, best_val_loss, epochs_trained, epochs_run


def run_experiment_pipeline_cached(X_teacher, y_teacher, teacher_name,
                                   X_student, y_student, student_name, layer_type, config_name,
                                   alignment_config=ALIGNMENT_CONFIG,
                                   probe_config=PROBE_CONFIG):
    """Train a teacher probe, align student activations to it, and score the transfer.

    Steps:
      1. Fit a LogisticRegression probe on the teacher's training activations
         and score it on the teacher's test split.
      2. Train an AlignmentNetwork mapping student → teacher activations
         (90/10 split of the training rows, early stopping on validation
         MixedLoss). Save the best weights under ``alignment_models/<layer_type>/``.
      3. Project the student's test activations and score them with the teacher probe.

    Args:
        X_teacher, X_student: dicts with 'X_train' / 'X_test' numpy arrays. Rows
            of the two models are paired index-by-index, so they must describe
            the same samples in the same order.
        y_teacher: dict with 'y_train' / 'y_test'. y_student: only 'y_test' is used.
        teacher_name, student_name: model names (logs and checkpoint filename).
        layer_type: layer family ('attn', 'mlp', 'hidden', ...).
        config_name: prefix for the checkpoint filename.
        alignment_config: aligner hyper-parameters (see ALIGNMENT_CONFIG).
        probe_config: LogisticRegression settings (see PROBE_CONFIG).

    Returns:
        dict with the alignment-model summary ("epochs_trained" = epoch whose
        weights were kept, "epochs_run" = epochs executed). It also holds, for
        "teacher" and "student_on_teacher", accuracy / precision / recall /
        f1 / auroc and the confusion matrix.
    """
    print(f"\n{'='*60}")
    print(f"EXPERIMENT: {layer_type.upper()} → {teacher_name} ← {student_name}")
    print(f"{'='*60}")

    # Pre-split data (numpy, for sklearn).
    X_A_train_full, X_A_test = X_teacher['X_train'], X_teacher['X_test']
    y_A_train_full, y_A_test = y_teacher['y_train'], y_teacher['y_test']
    X_B_train_full, X_B_test = X_student['X_train'], X_student['X_test']
    y_B_test = y_student['y_test']

    device = DEVICE
    print(f"Using device: {device}")

    # --------------------------------------------------
    # 1. Teacher probing (on the FULL training set)
    # --------------------------------------------------
    print("1. Training teacher probe on FULL training set...")
    probe_teacher = LogisticRegression(
        max_iter=probe_config['max_iter'],
        class_weight=probe_config['class_weight'],
        solver=probe_config['solver'],
        n_jobs=probe_config['n_jobs']
    )
    probe_teacher.fit(X_A_train_full, y_A_train_full)

    teacher_metrics = _binary_classification_metrics(
        y_A_test,
        probe_teacher.predict(X_A_test),
        probe_teacher.predict_proba(X_A_test)[:, 1],
    )
    acc_teacher, auroc_teacher = teacher_metrics["accuracy"], teacher_metrics["auroc"]
    print(f"   Acc teacher: {acc_teacher:.4f}, AUROC: {auroc_teacher:.4f}")

    # --------------------------------------------------
    # 2. Alignment training (student → teacher space)
    # --------------------------------------------------
    print("2. Training alignment network (with 90/10 validation split)...")

    # Convert to tensors once; the teacher test split is never needed as a tensor.
    X_A_train_full_t = torch.from_numpy(X_A_train_full).float()
    X_B_train_full_t = torch.from_numpy(X_B_train_full).float()
    X_B_test_t = torch.from_numpy(X_B_test).float()

    aligner, best_val_loss, epochs_trained, epochs_run = _train_alignment_network(
        X_B_train_full_t, X_A_train_full_t, alignment_config, device
    )
    input_dim = X_B_train_full_t.shape[1]
    output_dim = X_A_train_full_t.shape[1]

    model_save_dir = os.path.join("alignment_models", layer_type)
    os.makedirs(model_save_dir, exist_ok=True)
    model_filename = os.path.join(model_save_dir, f"{config_name}_aligner_{student_name}_to_{teacher_name}.pt")

    torch.save({
        'model_state_dict': aligner.state_dict(),
        'alignment_config': alignment_config,
        'probe_config': probe_config,
        'input_dim': input_dim,
        'output_dim': output_dim,
        'best_val_loss': best_val_loss,
        'epochs_trained': epochs_trained,
        'epochs_run': epochs_run,
        'layer_type': layer_type,
        'student_model': student_name,
        'teacher_model': teacher_name,
    }, model_filename)
    print(f"   ✓ Alignment network saved: {model_filename}")

    # --------------------------------------------------
    # 3. Evaluation: projected student → teacher probe
    # --------------------------------------------------
    print("3. Projecting student test set & evaluating...")
    aligner.eval()
    with torch.no_grad():
        X_B_projected = aligner(X_B_test_t.to(device)).cpu().numpy()

    cross_metrics = _binary_classification_metrics(
        y_B_test,
        probe_teacher.predict(X_B_projected),
        probe_teacher.predict_proba(X_B_projected)[:, 1],
    )
    acc_cross, auroc_cross = cross_metrics["accuracy"], cross_metrics["auroc"]

    print(f"\nFINAL RESULT:")
    print(f"   Teacher Acc         : {acc_teacher:.4f}, AUROC: {auroc_teacher:.4f}")
    print(f"   Student → Teacher Acc: {acc_cross:.4f}, AUROC: {auroc_cross:.4f}")
    print(f"   Transfer gap        : {acc_teacher - acc_cross:.4f}")

    return {
        "type": layer_type,
        "teacher_name": teacher_name,
        "student_name": student_name,
        "alignment_model": {
            "input_dim": int(input_dim),
            "output_dim": int(output_dim),
            "config": alignment_config,
            "best_val_loss": float(best_val_loss),
            "epochs_trained": epochs_trained,
            "epochs_run": epochs_run,
            "model_path": model_filename
        },
        "probe_config": probe_config,
        "teacher": teacher_metrics,
        "student_on_teacher": cross_metrics,
    }


def plot_confusion_matrix(cm, layer_type, model_name="", save_dir="confusion_matrices",
                          class_names=('Non-Hallucinated', 'Hallucinated')):
    """
    Plot a confusion matrix as an annotated heatmap and save it as a PNG.

    Args:
        cm: square confusion matrix indexed as cm[true][pred] (the layout produced by
            ``sklearn.metrics.confusion_matrix``). Cells are annotated with ``fmt='d'``,
            so it must contain integer counts.
        layer_type: layer family name (e.g. 'attn'); upper-cased in the title and used
            in the filename.
        model_name: optional tag appended to the title and to the filename.
        save_dir: output directory; created if it does not exist.
        class_names: tick labels for both axes, in label order. Defaults to the binary
            hallucination labels previously hard-coded here.

    Returns:
        The path of the saved PNG file.
    """
    os.makedirs(save_dir, exist_ok=True)

    if model_name:
        filename = os.path.join(save_dir, f'confusion_matrix_{layer_type}_{model_name}.png')
    else:
        filename = os.path.join(save_dir, f'confusion_matrix_{layer_type}.png')

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        tick_labels = list(class_names)
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True, ax=ax,
                    xticklabels=tick_labels,
                    yticklabels=tick_labels)
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')
        title = f'Confusion Matrix - {layer_type.upper()} Layers'
        if model_name:
            title += f' ({model_name})'
        ax.set_title(title)

        fig.tight_layout()
        fig.savefig(filename, dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure (not whichever one is "current") and do it even if
        # plotting/saving fails, so the per-layer loop never accumulates open figures.
        plt.close(fig)

    print(f"   ✓ Salvato: {filename}")
    return filename

In [19]:
print("=" * 80)
print("FASE 1: SPLIT DEI DATI BILANCIATI (70% train / 30% test)")
print("=" * 80 + "\n")

# The split works on LOCAL positions inside the balanced subset
# (0 .. len(balanced_indices) - 1); load_and_split_layers receives both the
# balanced indices and these local positions.
n_balanced = len(balanced_indices)
rng = np.random.RandomState(SEED)          # dedicated seeded RNG -> reproducible split
shuffled_local_indices = rng.permutation(n_balanced)
split_idx = int(n_balanced * 0.7)          # 70% train / 30% test (not stratified)

train_indices, test_indices = shuffled_local_indices[:split_idx], shuffled_local_indices[split_idx:]

print(f"Campioni bilanciati totali: {n_balanced}")
print(f"Train: {len(train_indices)}, Test: {len(test_indices)}")
# Class balance per split (label 1 = hallucinated, 0 = not hallucinated).
for _split_name, _split_positions in (("train", train_indices), ("test", test_indices)):
    _split_labels = balanced_labels[_split_positions]
    print(f"Label {_split_name:<5} - Hall: {np.sum(_split_labels == 1)}, Non-Hall: {np.sum(_split_labels == 0)}")

# Experiment scenarios, built from the model-name constants: each model takes a
# turn as teacher while the other acts as student.
scenarios = [
    {"teacher_model": MODEL_A, "student_model": MODEL_B},
    {"teacher_model": MODEL_B, "student_model": MODEL_A}
]

# Results per scenario, keyed by scenario position so the saved order follows
# `scenarios`; sized from the list rather than hard-coding two keys.
scenario_results_map = {i: [] for i in range(len(scenarios))}

# Loop over layer types (load -> run -> free memory) so only one layer type's
# activations are held in memory at a time.
for layer_type in ['attn', 'mlp', 'hidden']:
    print(f"\n{'='*40}")
    print(f"PROCESSING LAYER TYPE: {layer_type.upper()}")
    print(f"{'='*40}")
    gc.collect()

    try:
        # 1. LOAD AND SPLIT with the shared balanced indices, so both models use the
        #    same train/test rows.
        X_model_a_train, X_model_a_test, y_model_a_train, y_model_a_test = load_and_split_layers(
            MODEL_A, DATASET_NAME,
            LAYER_CONFIG[MODEL_A][layer_type],
            layer_type,
            balanced_indices, balanced_labels,
            train_indices, test_indices
        )

        X_model_b_train, X_model_b_test, y_model_b_train, y_model_b_test = load_and_split_layers(
            MODEL_B, DATASET_NAME,
            LAYER_CONFIG[MODEL_B][layer_type],
            layer_type,
            balanced_indices, balanced_labels,
            train_indices, test_indices
        )

        # 2. SCALING: each scaler is fit on its model's train split only (no test
        #    leakage); outputs are cast to float32 to save memory.
        print("   Normalizzazione dati...")
        scaler_model_a = StandardScaler()
        X_model_a_train = scaler_model_a.fit_transform(X_model_a_train).astype(np.float32)
        X_model_a_test = scaler_model_a.transform(X_model_a_test).astype(np.float32)

        scaler_model_b = StandardScaler()
        X_model_b_train = scaler_model_b.fit_transform(X_model_b_train).astype(np.float32)
        X_model_b_test = scaler_model_b.transform(X_model_b_test).astype(np.float32)

        # Bundle per-model data (labels coincide across models since both use the
        # same balanced indices and split positions).
        current_data = {
            "model_a": {"X_train": X_model_a_train, "X_test": X_model_a_test, "y_train": y_model_a_train, "y_test": y_model_a_test},
            "model_b": {"X_train": X_model_b_train, "X_test": X_model_b_test, "y_train": y_model_b_train, "y_test": y_model_b_test}
        }

        # 3. RUN EVERY SCENARIO on this layer type
        for i, scenario in enumerate(scenarios):
            print(f"\n   --- Scenario: {scenario['teacher_model']} -> {scenario['student_model']} ---")

            if scenario['teacher_model'] == MODEL_A:
                X_teacher_data = current_data['model_a']
                X_student_data = current_data['model_b']
            else:
                X_teacher_data = current_data['model_b']
                X_student_data = current_data['model_a']

            res = run_experiment_pipeline_cached(
                X_teacher_data, X_teacher_data, scenario['teacher_model'],
                X_student_data, X_student_data, scenario['student_model'],
                layer_type, "CONFIG1"
            )
            scenario_results_map[i].append(res)

            # Plot confusion matrices
            plot_confusion_matrix(
                np.array(res['teacher']['confusion_matrix']),
                layer_type,
                f"Teacher_{scenario['teacher_model'].replace('.', '_')}"
            )
            plot_confusion_matrix(
                np.array(res['student_on_teacher']['confusion_matrix']),
                layer_type,
                f"{scenario['student_model'].replace('.', '_')}_on_{scenario['teacher_model'].replace('.', '_')}"
            )

        # 4. FREE MEMORY. X_teacher_data / X_student_data alias the per-model dicts
        #    inside current_data: unless they are dropped as well, this layer type's
        #    arrays stay alive while the next layer type is being loaded.
        X_teacher_data = X_student_data = None
        del current_data, X_model_a_train, X_model_a_test, X_model_b_train, X_model_b_test
        del scaler_model_a, scaler_model_b
        gc.collect()
        torch.cuda.empty_cache()
        print(f"   Memoria liberata per {layer_type}.")

    except Exception as e:
        print(f"Errore critico nel layer {layer_type}: {e}")
        traceback.print_exc()
        raise

# Rebuild the all_results structure (one entry per scenario, in scenario order)
# for JSON serialization.
all_results = [
    {
        "scenario": f"{scenario['teacher_model']} (teacher) → {scenario['student_model']} (student)",
        "results": scenario_results_map[i],
    }
    for i, scenario in enumerate(scenarios)
]

# Save every result as JSON
os.makedirs("results_metrics", exist_ok=True)
metrics_file = "results_metrics/hybrid_adapter_logreg_results.json"


def _confusion_matrix_to_dict(cm):
    """Name the four cells of a 2x2 binary confusion matrix.

    ``cm`` is indexed as ``cm[true][pred]`` (sklearn ``confusion_matrix`` layout),
    with label 0 = non-hallucinated and 1 = hallucinated.
    """
    return {
        "TN": int(cm[0][0]),
        "FP": int(cm[0][1]),
        "FN": int(cm[1][0]),
        "TP": int(cm[1][1]),
    }


def _summarize_metrics(metrics, ndigits=4):
    """Round the scalar metrics to ``ndigits`` and expand the confusion matrix.

    Values are cast to builtin ``float`` first so the output is always
    JSON-serializable, even if a metric arrives as a non-float64 NumPy scalar.
    """
    return {
        "accuracy": round(float(metrics['accuracy']), ndigits),
        "precision": round(float(metrics['precision']), ndigits),
        "recall": round(float(metrics['recall']), ndigits),
        "f1_score": round(float(metrics['f1']), ndigits),
        "auroc": round(float(metrics['auroc']), ndigits),
        "confusion_matrix": _confusion_matrix_to_dict(metrics['confusion_matrix']),
    }


def _format_result(r):
    """Convert one run_experiment_pipeline_cached result dict into the saved JSON schema."""
    config = r['alignment_model']['config']
    probe_cfg = r['probe_config']
    return {
        "layer_type": r['type'],
        "teacher_model": r['teacher_name'],
        "student_model": r['student_name'],
        "data_info": {
            "total_balanced_samples": int(n_balanced),
            "train_samples": int(len(train_indices)),
            "test_samples": int(len(test_indices)),
            "concordant_undersampling": True
        },
        "alignment_model_info": {
            "architecture": "AlignmentNetwork",
            "input_dim": r['alignment_model']['input_dim'],
            "output_dim": r['alignment_model']['output_dim'],
            "hidden_dim": config['hidden_dim'],
            "dropout": config['dropout'],
            "activation": "GELU",
            "normalization": "LayerNorm",
            "residual_connection": True,
            "initialization": "zero_init"
        },
        "training_hyperparameters": {
            "optimizer": config['optimizer'],
            "learning_rate": config['learning_rate'],
            "weight_decay": config['weight_decay'],
            "batch_size": config['batch_size'],
            "max_epochs": config['max_epochs'],
            "scheduler": config['scheduler'],
            "gradient_clip_max_norm": config['gradient_clip_max_norm'],
            "early_stopping_patience": config['early_stopping_patience'],
            "early_stopping_min_delta": config['early_stopping_min_delta']
        },
        "loss_function": {
            "type": "MixedLoss",
            "mse_weight": config['loss_alpha'],
            "cosine_weight": config['loss_beta']
        },
        "training_results": {
            "alignment_network": {
                "best_val_loss": round(r['alignment_model']['best_val_loss'], 6),
                "epochs_trained": r['alignment_model']['epochs_trained'],
                "model_saved_path": r['alignment_model']['model_path']
            }
        },
        "teacher_probe": {
            "type": probe_cfg['type'],
            "max_iter": probe_cfg['max_iter'],
            "class_weight": probe_cfg['class_weight'],
            "solver": probe_cfg['solver']
        },
        "metrics": {
            "teacher": _summarize_metrics(r['teacher']),
            "student_on_teacher": _summarize_metrics(r['student_on_teacher']),
        }
    }


all_results_json = [
    {
        "scenario": scenario_data['scenario'],
        "results": [_format_result(r) for r in scenario_data['results']],
    }
    for scenario_data in all_results
]

with open(metrics_file, 'w', encoding='utf-8') as f:
    json.dump(all_results_json, f, indent=2)

print(f"\n{'='*60}")

print(f"✓ Risultati salvati in: {metrics_file}")
print(f"{'='*60}")

FASE 1: SPLIT DEI DATI BILANCIATI (70% train / 30% test)

Campioni bilanciati totali: 934
Train: 653, Test: 281
Label train - Hall: 314, Non-Hall: 339
Label test  - Hall: 153, Non-Hall: 128

PROCESSING LAYER TYPE: ATTN
 Caricamento IN-MEMORY gemma-2-9b-it [attn]: layers [21, 24, 27]...
  Struttura rilevata: new
  Loading layer 21 (new structure)... done ((934, 3584))
  Loading layer 24 (new structure)... done ((934, 3584))
  Loading layer 24 (new structure)... done ((934, 3584))
  Loading layer 27 (new structure)... done ((934, 3584))
  Loading layer 27 (new structure)... done ((934, 3584))
 Concatenating layers...
 Completato! Train: (653, 10752), Test: (281, 10752)
 Caricamento IN-MEMORY Llama-3.1-8B-Instruct [attn]: layers [8, 13, 14]...
  Struttura rilevata: new
  Loading layer 8 (new structure)... done ((934, 3584))
 Concatenating layers...
 Completato! Train: (653, 10752), Test: (281, 10752)
 Caricamento IN-MEMORY Llama-3.1-8B-Instruct [attn]: layers [8, 13, 14]...
  Struttura ri

In [20]:
# Overwrite the metrics file with a 4-space indent; the payload is the same
# all_results_json assembled (and first saved with indent=2) in the previous cell.
with open(metrics_file, mode='w') as metrics_out:
    json.dump(all_results_json, metrics_out, indent=4)