# In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool, GCNConv
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import numpy as np
import pandas as pd
import pickle
import random
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
import matplotlib.pyplot as plt
from scipy import stats
import hashlib
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)


# --------------------
# 1. Strict Randomness Control
# --------------------
def set_seed(seed=42):
    """Seed every RNG used in this file and request deterministic kernels.

    Seeds Python's ``random``, NumPy and PyTorch (CPU plus all CUDA devices),
    puts cuDNN into deterministic mode with autotuning disabled, sets
    ``CUBLAS_WORKSPACE_CONFIG`` and turns on
    ``torch.use_deterministic_algorithms``.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Workspace setting that goes together with deterministic algorithms on
    # CUDA (see the PyTorch reproducibility notes).
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True)


# --------------------
# 2. Model Parameter Hash Calculation
# --------------------
def get_model_hash(model):
    """Return the MD5 hex digest of a model's state-dict tensors.

    Every entry of ``model.state_dict()`` except those whose key contains
    ``"num_batches_tracked"`` is moved to CPU, cast to float32, flattened and
    concatenated in state-dict order; the raw bytes of that vector are hashed.

    Args:
        model: Any ``torch.nn.Module``.

    Returns:
        32-character hexadecimal MD5 digest. A model whose state dict has no
        hashable entries yields the digest of the empty byte string.
    """
    flat_params = [
        # reshape (not view) so non-contiguous tensors are flattened too.
        tensor.detach().cpu().float().reshape(-1)
        for key, tensor in model.state_dict().items()
        if "num_batches_tracked" not in key
    ]
    if not flat_params:
        # torch.cat rejects an empty list; hash "nothing" instead of crashing.
        return hashlib.md5(b"").hexdigest()
    return hashlib.md5(torch.cat(flat_params).numpy().tobytes()).hexdigest()


# --------------------
# 3. Model Definitions (GCN instead of GNN)
# --------------------
class SimpleGCN(nn.Module):
    """GCN regressor used as the teacher architecture.

    Node features are batch-normalised, passed through a stack of ``GCNConv``
    layers (ReLU + dropout after each), mean-pooled per graph and, when global
    features are configured, concatenated with an MLP embedding of ``data.u``
    before the regression head.

    The attribute names (``norm``, ``edge_norm``, ``global_norm``,
    ``global_mlp``, ``convs``, ``output``) are the ``state_dict`` keys that
    teacher checkpoints are loaded against with ``strict=True``; do not
    rename them.

    Args:
        node_dim: Number of input node features.
        edge_dim: Number of edge features; a falsy value skips ``edge_norm``.
        global_dim: Number of graph-level features; a falsy value disables the
            global branch.
        hidden_dims: Output width of each ``GCNConv`` layer, in order.
        dropout: Dropout probability used after every layer and in the MLPs.
    """

    def __init__(self, node_dim, edge_dim, global_dim, hidden_dims, dropout=0.2):
        super().__init__()
        self.norm = nn.BatchNorm1d(node_dim)
        if edge_dim:
            self.edge_norm = nn.BatchNorm1d(edge_dim)
        if global_dim:
            self.global_norm = nn.BatchNorm1d(global_dim)
            self.global_mlp = nn.Sequential(
                nn.Linear(global_dim, hidden_dims[-1]),
                nn.ReLU(),
                nn.Dropout(dropout)
            )
        self.convs = nn.ModuleList()
        in_dim = node_dim
        for h in hidden_dims:
            self.convs.append(GCNConv(in_dim, h))
            in_dim = h
        self.dropout = nn.Dropout(dropout)
        # Width of the feature vector fed to the head: pooled node embedding,
        # doubled when the global embedding is concatenated to it.
        self.final_dim = hidden_dims[-1] * (2 if global_dim else 1)
        self.output = nn.Sequential(
            nn.Linear(self.final_dim, self.final_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(self.final_dim // 2, 1)
        )

    def forward(self, data, return_feat=False):
        """Predict one scalar per graph in ``data``.

        Args:
            data: Batched graph object providing ``x``, ``edge_index``,
                ``batch`` and optionally ``edge_attr`` / ``u``.
            return_feat: When true, also return the feature tensor that is fed
                to the regression head.

        Returns:
            Predictions of shape ``(num_graphs,)``, or the tuple
            ``(predictions, features)`` when ``return_feat`` is true.
        """
        x = self.norm(data.x)
        if hasattr(self, 'edge_norm') and hasattr(data, 'edge_attr') and data.edge_attr is not None:
            # The normalised edge features are discarded: the conv layers
            # below only receive node features and connectivity.
            _ = self.edge_norm(data.edge_attr)
        u = data.u if hasattr(data, 'u') else None
        if u is not None and hasattr(self, 'global_norm'):
            u = self.global_norm(u)

        for conv in self.convs:
            x = F.relu(conv(x, data.edge_index))
            x = self.dropout(x)

        node_pool = global_mean_pool(x, data.batch)
        if u is not None and hasattr(self, 'global_mlp'):
            h = torch.cat([node_pool, self.global_mlp(u)], dim=1)
        else:
            h = node_pool
        # squeeze(-1) rather than squeeze(): a batch holding a single graph
        # must stay 1-D instead of collapsing to a 0-dim scalar.
        out = self.output(h).squeeze(-1)
        return (out, h) if return_feat else out

class EnhancedGCN(nn.Module):
    """GCN regressor used as the student model.

    Node features are batch-normalised, passed through a stack of ``GCNConv``
    layers (ReLU + dropout after each), mean-pooled per graph and, when global
    features are configured, concatenated with an MLP embedding of ``data.u``
    before the regression head.

    Args:
        node_dim: Number of input node features.
        edge_dim: Number of edge features; a falsy value leaves ``edge_norm``
            as ``None``.
        global_dim: Number of graph-level features; a falsy value disables the
            global branch.
        hidden_dims: Output width of each ``GCNConv`` layer, in order.
        dropout: Dropout probability used after every layer and in the MLPs.
    """

    def __init__(self, node_dim, edge_dim, global_dim, hidden_dims, dropout=0.1):
        super().__init__()
        self.node_norm = nn.BatchNorm1d(node_dim)
        self.edge_norm = nn.BatchNorm1d(edge_dim) if edge_dim else None
        self.global_norm = nn.BatchNorm1d(global_dim) if global_dim else None
        if global_dim:
            self.global_mlp = nn.Sequential(
                nn.Linear(global_dim, hidden_dims[-1]),
                nn.ReLU(),
                nn.Dropout(dropout)
            )
        self.convs = nn.ModuleList()
        in_dim = node_dim
        for h in hidden_dims:
            self.convs.append(GCNConv(in_dim, h))
            in_dim = h
        self.dropout = nn.Dropout(dropout)
        # Width of the feature vector fed to the head: pooled node embedding,
        # doubled when the global embedding is concatenated to it.
        self.final_dim = hidden_dims[-1] * (2 if global_dim else 1)
        self.output_mlp = nn.Sequential(
            nn.Linear(self.final_dim, self.final_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(self.final_dim//2, 1)
        )

    def forward(self, data, return_feat=False):
        """Predict one scalar per graph in ``data``.

        Args:
            data: Batched graph object providing ``x``, ``edge_index``,
                ``batch`` and optionally ``edge_attr`` / ``u``.
            return_feat: When true, also return the feature tensor that is fed
                to the regression head.

        Returns:
            Predictions of shape ``(num_graphs,)``, or the tuple
            ``(predictions, features)`` when ``return_feat`` is true.
        """
        x = self.node_norm(data.x)
        # Explicit None checks: do not rely on the truthiness of a Module.
        if self.edge_norm is not None and getattr(data, 'edge_attr', None) is not None:
            # The normalised edge features are discarded: the conv layers
            # below only receive node features and connectivity.
            _ = self.edge_norm(data.edge_attr)
        u = getattr(data, 'u', None)
        gf = None
        if u is not None and self.global_norm is not None:
            u = self.global_norm(u)
            gf = self.global_mlp(u)
        for conv in self.convs:
            x = F.relu(conv(x, data.edge_index))
            x = self.dropout(x)
        pooled = global_mean_pool(x, data.batch)
        h = torch.cat([pooled, gf], dim=1) if gf is not None else pooled
        # squeeze(-1) rather than squeeze(): a batch holding a single graph
        # must stay 1-D instead of collapsing to a 0-dim scalar.
        out = self.output_mlp(h).squeeze(-1)
        return (out, h) if return_feat else out


class GateNet(nn.Module):
    """Two-layer MLP producing softmax weights over the teachers.

    Args:
        in_dim: Width of the input feature vector.
        hidden_dim: Width of the hidden layer.
        num_teachers: Number of weights produced per input row.
    """

    def __init__(self, in_dim, hidden_dim, num_teachers):
        super().__init__()
        self.fc1 = nn.Linear(in_features=in_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=num_teachers)

    def forward(self, h):
        """Return per-row weights that are non-negative and sum to one."""
        hidden = self.fc1(h).relu()
        logits = self.fc2(hidden)
        return logits.softmax(dim=-1)


class Adapter(nn.Module):
    """Single linear layer mapping ``dim_s``-wide features to ``dim_t``-wide ones.

    Args:
        dim_s: Input feature width.
        dim_t: Output feature width.
    """

    def __init__(self, dim_s, dim_t):
        super().__init__()
        self.linear = nn.Linear(in_features=dim_s, out_features=dim_t)

    def forward(self, h):
        """Apply the linear projection to ``h``."""
        projected = self.linear(h)
        return projected


# --------------------
# 4. Data Loader (Reproducible Shuffle)
# --------------------
def create_data_loader(graph_data, batch_size=32, shuffle=True, seed=42):
    """Wrap a list of graph dicts in a PyG ``DataLoader``.

    Each dict must provide ``x``, ``edge_index`` and ``y``; ``edge_attr``,
    ``u``, ``y_soft`` and ``idx`` are optional. Values that are not already
    tensors are converted (float32, or int64 for ``edge_index``). A missing
    ``y_soft`` falls back to ``y`` and a missing ``idx`` to the position in
    ``graph_data``.

    Args:
        graph_data: Iterable of graph dictionaries.
        batch_size: Number of graphs per batch.
        shuffle: Whether to shuffle; shuffling uses a generator seeded with
            ``seed`` so the order is reproducible.
        seed: Seed for the shuffle generator (unused when ``shuffle`` is false).

    Returns:
        A ``DataLoader`` over the converted ``Data`` objects.
    """
    def _coerce(value, dtype):
        # Tensors are cast, everything else is turned into a new tensor.
        if isinstance(value, torch.Tensor):
            return value.to(dtype)
        return torch.tensor(value, dtype=dtype)

    dataset = []
    for position, graph in enumerate(graph_data):
        raw_edge_attr = graph.get('edge_attr')
        raw_u = graph.get('u')
        target = _coerce(graph['y'], torch.float32)
        dataset.append(Data(
            x=_coerce(graph['x'], torch.float32),
            edge_index=_coerce(graph['edge_index'], torch.long),
            edge_attr=None if raw_edge_attr is None else _coerce(raw_edge_attr, torch.float32),
            u=None if raw_u is None else _coerce(raw_u, torch.float32),
            y=target,
            y_soft=_coerce(graph.get('y_soft', target), torch.float32),
            idx=graph.get('idx', position),
        ))

    if not shuffle:
        return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    shuffle_rng = torch.Generator().manual_seed(seed)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        generator=shuffle_rng,
        num_workers=0,
    )


# --------------------
# 5. Evaluation and Prediction
# --------------------
def eval_loader(loader, model, device):
    """Run ``model`` over ``loader`` in eval mode and score it with R2.

    The model is switched to eval mode and left there; callers that continue
    training must call ``model.train()`` again.

    Args:
        loader: Iterable of batches exposing ``to(device)`` and ``y``.
        model: Module called as ``model(batch, return_feat=True)`` and
            returning ``(predictions, features)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(r2, targets, predictions)`` where the last two are 1-D NumPy
        arrays in loader iteration order.
    """
    model.eval()
    targets, preds = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out, _ = model(batch, return_feat=True)
            targets.append(batch.y.reshape(-1).cpu().numpy())
            # Flatten explicitly: a model that squeezes its output returns a
            # 0-dim tensor for a single-graph batch, and np.concatenate
            # rejects zero-dimensional arrays.
            preds.append(out.reshape(-1).cpu().numpy())
    ys = np.concatenate(targets)
    ps = np.concatenate(preds)
    return r2_score(ys, ps), ys, ps


def predict_and_evaluate(model, data_loader, y_mean, y_std, device, dataset_name, save_dir, seed=42):
    """Evaluate ``model`` on ``data_loader`` and write a scatter plot and a CSV.

    Predictions and targets are de-normalised with ``y_mean`` / ``y_std``;
    R2 is reported on both scales. The plot is saved as
    ``<dataset_name>_scatter.png`` and the per-sample table as
    ``<dataset_name>_predictions.csv`` inside ``save_dir``.

    Args:
        model: Model passed to ``eval_loader``.
        data_loader: Loader to evaluate. It must iterate its dataset in order
            (no shuffling), because ``Sample_Index`` is read from
            ``data_loader.dataset`` while predictions follow iteration order.
        y_mean: Mean used for label normalisation.
        y_std: Standard deviation used for label normalisation.
        device: Device used for inference.
        dataset_name: Label used in log lines, the plot title and file names.
        save_dir: Output directory; created if it does not exist.
        seed: Seed passed to ``set_seed`` before evaluation.

    Returns:
        Tuple ``(r2_raw, true_norm, pred_norm)``.
    """
    set_seed(seed)
    r2_norm, true_norm, pred_norm = eval_loader(data_loader, model, device)

    true_raw = true_norm * y_std + y_mean
    pred_raw = pred_norm * y_std + y_mean
    r2_raw = r2_score(true_raw, pred_raw)

    print(f"[{dataset_name}] Normalized R²: {r2_norm:.6f} | Raw Scale R²: {r2_raw:.6f}")

    # An empty save_dir means "current directory"; makedirs('') would raise.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Scatter plot
    plt.figure(figsize=(8, 6))
    plt.scatter(true_raw, pred_raw, alpha=0.6, label=f'R² = {r2_raw:.4f}')
    # A regression line is undefined for fewer than two points or when all
    # true values are identical, so only fit when x actually varies.
    if true_raw.size > 1 and true_raw.max() > true_raw.min():
        slope, intercept, _, _, _ = stats.linregress(true_raw, pred_raw)
        plt.plot(true_raw, intercept + slope * true_raw, 'r--')
    plt.xlabel('True Values', fontname="Times New Roman", fontsize=12)
    plt.ylabel('Predicted Values', fontname="Times New Roman", fontsize=12)
    plt.title(f'{dataset_name} Prediction Results', fontname="Times New Roman", fontsize=14)
    plt.legend(prop={"family": "Times New Roman"})
    plt.grid(alpha=0.3)
    plt.tight_layout()
    scatter_path = os.path.join(save_dir, f'{dataset_name}_scatter.png')
    plt.savefig(scatter_path, dpi=300)
    plt.close()
    print(f"Scatter plot saved to: {scatter_path}")

    # Save prediction results
    indices = np.array([data.idx for data in data_loader.dataset])
    df = pd.DataFrame({
        'Sample_Index': indices,
        'True_Value_Raw': true_raw,
        'Pred_Value_Raw': pred_raw,
        'True_Value_Normalized': true_norm,
        'Pred_Value_Normalized': pred_norm
    })
    csv_path = os.path.join(save_dir, f'{dataset_name}_predictions.csv')
    df.to_csv(csv_path, index=False, encoding='utf-8-sig')
    print(f"Prediction results saved to: {csv_path}\n")

    return r2_raw, true_norm, pred_norm


# --------------------
# 6. Main Training Function
# --------------------
def _plot_training_curves(history, best_val_r2, best_epoch, no_save_epochs, curve_path):
    """Plot per-epoch train/val R2 and training loss, then save to ``curve_path``."""
    # Epochs are logged 1-based, so plot against 1..N to line up with the
    # vertical markers (no_save_epochs, best_epoch).
    epoch_axis = list(range(1, len(history['loss']) + 1))

    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(epoch_axis, history['train_r2'], label='Train R2', color='blue')
    plt.plot(epoch_axis, history['val_r2'], label='Val R2', color='orange')
    if best_epoch is not None:
        # Without a saved model best_val_r2 is still -inf and cannot be drawn.
        plt.axhline(y=best_val_r2, color='red', linestyle='--', label=f'Best Val R2 ({best_val_r2:.4f})')
    plt.axvline(x=no_save_epochs, color='gray', linestyle=':', label=f'Optimal calc start ({no_save_epochs})')
    if best_epoch is not None:
        plt.axvline(x=best_epoch, color='green', linestyle='-.', label=f'Best R2 epoch ({best_epoch})')
        plt.scatter(best_epoch, best_val_r2, color='red', s=100, zorder=5)
    plt.title('Training & Validation R2', fontname="Times New Roman", fontsize=14)
    plt.xlabel('Epoch', fontname="Times New Roman", fontsize=12)
    plt.ylabel('R2 Score', fontname="Times New Roman", fontsize=12)
    plt.legend(prop={"family": "Times New Roman"})
    plt.grid(True, linestyle='--', alpha=0.7)

    plt.subplot(1, 2, 2)
    plt.plot(epoch_axis, history['loss'], label='Training Loss', color='blue')
    plt.axvline(x=no_save_epochs, color='gray', linestyle=':', label=f'Optimal calc start ({no_save_epochs})')
    plt.title('Training Loss', fontname="Times New Roman", fontsize=14)
    plt.xlabel('Epoch', fontname="Times New Roman", fontsize=12)
    plt.ylabel('MSE Loss', fontname="Times New Roman", fontsize=12)
    plt.legend(prop={"family": "Times New Roman"})
    plt.grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.savefig(curve_path, dpi=300)
    plt.close()


def train_model(
    train_dir, val_dir, teacher_paths, save_path,
    gate_hidden=128, hint_lambda=5.0, weight_ratio=(0.6, 0.4),
    hidden_dims=[128, 128], dropout=0.1,
    epochs=1000, batch_size=64, lr=1e-3, min_lr=1e-4,
    lr_patience=20, 
    seed=42,
    no_save_epochs=200,
    es_patience=100,
    r2_effective_thresh=0.005,
    best_r2_rel_thresh=0.005
):
    """Train the student GCN by distilling from several frozen teacher GCNs.

    Per batch the loss is
    ``weight_ratio[0] * MSE(pred, y) + weight_ratio[1] * MSE(pred, fused)
    + hint_lambda * MSE(adapter(h_student), gated_teacher_features)``,
    where the gate network turns the student features into per-teacher
    weights. Labels are normalised with the training-set mean/std.

    During the first ``no_save_epochs`` epochs nothing is saved. Afterwards a
    checkpoint is written to ``save_path`` whenever validation R2 improves by
    an effective margin, and training stops early after ``es_patience``
    epochs without such an improvement. At the end the training curves are
    plotted and, if a checkpoint was written in this run, the best weights
    are restored and used for the final train/validation predictions.

    Args:
        train_dir: Directory containing the training ``graph_data.pt``.
        val_dir: Directory containing the validation ``graph_data.pt``.
        teacher_paths: Paths of teacher checkpoints (``SimpleGCN``).
        save_path: Where the best student checkpoint is written.
        gate_hidden: Hidden width of the gate network.
        hint_lambda: Weight of the feature-hint loss.
        weight_ratio: ``(hard-label weight, fused-soft-label weight)``.
        hidden_dims: ``GCNConv`` widths of the student.
        dropout: Student dropout probability.
        epochs: Maximum number of epochs.
        batch_size: Graphs per batch.
        lr: Initial Adam learning rate.
        min_lr: Lower bound for ``ReduceLROnPlateau``.
        lr_patience: Scheduler patience in epochs.
        seed: Random seed for all generators and the shuffle order.
        no_save_epochs: Warm-up epochs excluded from best-model tracking.
        es_patience: Early-stopping patience in epochs.
        r2_effective_thresh: Minimum absolute R2 gain counted as improvement.
        best_r2_rel_thresh: Relative threshold factor (see the note in the
            loop: it currently never changes the outcome).

    Returns:
        Tuple ``(best_val_r2, save_path)``; ``best_val_r2`` is ``-inf`` when
        no checkpoint was written.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU model: {torch.cuda.get_device_name(0)}")
    set_seed(seed)

    # Load data
    train_g = torch.load(os.path.join(train_dir, 'graph_data.pt'), map_location='cpu')
    val_g = torch.load(os.path.join(val_dir, 'graph_data.pt'), map_location='cpu')
    for i, g in enumerate(train_g):
        g['idx'] = i
    for i, g in enumerate(val_g):
        g['idx'] = i

    # Label normalization (statistics come from the training set only)
    ys = torch.stack([g['y'].float() for g in train_g]).view(-1)
    y_mean, y_std = ys.mean().item(), ys.std().item() + 1e-8
    for g in train_g + val_g:
        g['y'] = (g['y'].float() - y_mean) / y_std
        g.setdefault('y_soft', g['y'].clone())
        g['y_soft'] = g['y_soft'].float()

    # Create data loaders
    train_loader = create_data_loader(train_g, batch_size, shuffle=True, seed=seed)
    val_loader = create_data_loader(val_g, batch_size, shuffle=False, seed=seed)
    val_batch = next(iter(val_loader)) if len(val_loader) > 0 else None
    tr_loader_pred = create_data_loader(train_g, batch_size, shuffle=False, seed=seed)
    va_loader_pred = create_data_loader(val_g, batch_size, shuffle=False, seed=seed)

    # Get input dimensions
    sample = train_g[0]
    n_dim = sample['x'].size(1) if isinstance(sample['x'], torch.Tensor) else len(sample['x'][0])
    e_dim = sample.get('edge_attr', None).size(1) if (sample.get('edge_attr') is not None) else 0
    g_dim = sample.get('u', None).size(1) if (sample.get('u') is not None) else 0

    # Build student model
    student = EnhancedGCN(n_dim, e_dim, g_dim, hidden_dims, dropout).to(device)

    # Load teacher models (kept in eval mode; only used under no_grad)
    teachers = []
    for p in teacher_paths:
        ck = torch.load(p, map_location=device)
        t = SimpleGCN(
            node_dim=ck['node_dim'], edge_dim=ck.get('edge_dim', 0),
            global_dim=ck.get('global_dim', 0), hidden_dims=ck['hidden_dims']
        ).to(device)
        t.load_state_dict(ck['model_state_dict'], strict=True)
        t.eval()
        teachers.append(t)
    print(f"Loaded {len(teachers)} teacher models")

    # Optimizer settings
    K = len(teachers)
    gate = GateNet(student.final_dim, gate_hidden, K).to(device)
    adapter = Adapter(student.final_dim, student.final_dim).to(device)
    optimizer = optim.Adam(
        params=list(student.parameters()) + list(gate.parameters()) + list(adapter.parameters()),
        lr=lr, weight_decay=1e-5
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=lr_patience, min_lr=min_lr, verbose=True
    )
    scaler = GradScaler(enabled=(device.type == 'cuda'))

    # Training history and early stopping
    history = {'loss': [], 'train_r2': [], 'val_r2': []}
    best_val_r2 = -float('inf')
    best_train_r2 = -float('inf')
    # Best checkpoint of THIS run, tracked in memory so that a stale file left
    # at save_path by an earlier run is never mistaken for the current result.
    best_epoch = None
    best_state = None
    patience = 0
    epoch = 0  # keeps the final log message valid even when epochs == 0
    print(f"\n🚀 Start training (Total {epochs} epochs, seed={seed}) ===")
    print(f"R2 effective improvement: absolute ≥{r2_effective_thresh} or relative ≥{best_r2_rel_thresh*100}% of best R2")

    # Training loop
    for epoch in range(1, epochs + 1):
        student.train()
        gate.train()
        adapter.train()
        total_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}/{epochs}", leave=False)
        for batch in pbar:
            batch = batch.to(device)
            optimizer.zero_grad()
            with autocast(enabled=(device.type == 'cuda')):
                pred_s, h_s = student(batch, return_feat=True)
                with torch.no_grad():
                    # Teacher features stacked along a new teacher axis (dim=1).
                    Ht = torch.stack([t(batch, return_feat=True)[1] for t in teachers], dim=1)
                w = gate(h_s)
                Ht_g = (w.unsqueeze(-1) * Ht).sum(dim=1)
                loss_hint = F.mse_loss(adapter(h_s), Ht_g)
                # assumes batch.y_soft is (num_graphs, K), one soft label per
                # teacher — TODO confirm against the data files
                fused = (w * batch.y_soft).sum(dim=1)
                loss_pred = F.mse_loss(pred_s, batch.y.view(-1))
                loss_fused = F.mse_loss(pred_s, fused)
                total_loss_batch = (
                    weight_ratio[0] * loss_pred +
                    weight_ratio[1] * loss_fused +
                    hint_lambda * loss_hint
                )
            scaler.scale(total_loss_batch).backward()
            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            scaler.update()
            total_loss += total_loss_batch.item() * batch.num_graphs
            pbar.set_postfix({"batch_loss": f"{total_loss_batch.item():.4f}"})
        pbar.close()

        # Evaluation (eval_loader leaves the student in eval mode)
        avg_loss = total_loss / len(train_g)
        train_r2, _, _ = eval_loader(train_loader, student, device)
        val_r2, _, _ = eval_loader(val_loader, student, device)
        history['loss'].append(avg_loss)
        history['train_r2'].append(train_r2)
        history['val_r2'].append(val_r2)
        current_lr = optimizer.param_groups[0]['lr']

        # Effective improvement judgment
        r2_effective = False
        log_msg = ""
        current_val_r2 = val_r2

        # No save for first N epochs
        if epoch <= no_save_epochs:
            log_msg = f"[First {no_save_epochs} epochs] Epoch {epoch:3d} | Train R2: {train_r2:.4f} | Val R2: {current_val_r2:.4f} (Not for optimal calculation)"

        # Optimal model judgment after N epochs
        else:
            # NOTE(review): rel_r2_thresh is never smaller than
            # r2_effective_thresh, so the second clause of the `or` below can
            # never be true on its own; the effective rule is
            # "delta >= r2_effective_thresh". Kept as-is pending a decision
            # on the intended relative criterion.
            rel_r2_thresh = max(r2_effective_thresh, best_val_r2 * best_r2_rel_thresh)

            if current_val_r2 > best_val_r2:
                r2_delta = current_val_r2 - best_val_r2
                if r2_delta >= r2_effective_thresh or r2_delta >= rel_r2_thresh:
                    r2_effective = True
                    best_val_r2 = current_val_r2
                    best_train_r2 = train_r2
                    best_epoch = epoch
                    best_state = {k: v.detach().clone() for k, v in student.state_dict().items()}

                    # Save best model
                    student_hash = get_model_hash(student)
                    gate_hash = get_model_hash(gate)
                    adapter_hash = get_model_hash(adapter)
                    with torch.no_grad():
                        fixed_output = student(val_batch.to(device)).cpu().numpy() if val_batch is not None else None

                    torch.save({
                        'student_state_dict': student.state_dict(),
                        'gate_state_dict': gate.state_dict(),
                        'adapter_state_dict': adapter.state_dict(),
                        'y_mean': y_mean, 'y_std': y_std,
                        'node_dim': n_dim, 'edge_dim': e_dim, 'global_dim': g_dim,
                        'hidden_dims': hidden_dims, 'dropout': dropout,
                        'gate_hidden': gate_hidden, 'hint_lambda': hint_lambda,
                        'weight_ratio': weight_ratio,
                        'best_val_r2': best_val_r2,
                        'best_train_r2': best_train_r2,
                        'student_param_hash': student_hash,
                        'gate_param_hash': gate_hash,
                        'adapter_param_hash': adapter_hash,
                        'fixed_output': fixed_output,
                        'seed': seed,
                        'saved_epoch': epoch,
                        'batch_size': batch_size,
                        'lr': lr
                    }, save_path)
                    log_msg = (f"[Effective improvement] Epoch {epoch:3d} | "
                               f"Train R2: {train_r2:.4f} | Val R2 updated to {best_val_r2:.4f} (Improvement {r2_delta:.4f}) | "
                               f"Student model hash={student_hash[:8]}... | Model saved")
                else:
                    log_msg = (f"[Minor improvement] Epoch {epoch:3d} | "
                               f"Train R2: {train_r2:.4f} | Val R2={current_val_r2:.4f} (Improvement {r2_delta:.4f} < Threshold) | "
                               f"Current best Val R2={best_val_r2:.4f}")
            else:
                log_msg = (f"[No improvement] Epoch {epoch:3d} | "
                           f"Train R2: {train_r2:.4f} | Val R2={current_val_r2:.4f} | "
                           f"Current best Val R2={best_val_r2:.4f}")

        # Update early stopping counter
        if epoch > no_save_epochs:
            if r2_effective:
                patience = 0
                print(log_msg)
            else:
                patience += 1
                patience_checks = [es_patience//4, es_patience//2, int(es_patience*0.75), es_patience]
                if patience % 10 == 0 or patience in patience_checks:
                    print(f"[Early stopping counter] Epoch {epoch:3d} | No effective improvement for {patience}/{es_patience} epochs | "
                          f"Train R2: {train_r2:.4f} | Current Val R2={current_val_r2:.4f} | Best Val R2={best_val_r2:.4f}")
        else:
            if epoch % 10 == 0 or epoch == 1:
                print(log_msg)

        # Stage switch prompt
        if epoch == no_save_epochs:
            print(f"\n[Stage switch] Epoch {epoch:3d} | First {no_save_epochs} epochs ended, start optimal R2 calculation and model saving")

        # Trigger early stopping
        if epoch > no_save_epochs and patience >= es_patience:
            print(f"\n🛑 [Early stopping triggered] Epoch {epoch:3d} | No effective R2 improvement for {es_patience} consecutive epochs | "
                  f"Best Val R2={best_val_r2:.4f} | Corresponding Train R2={best_train_r2:.4f}")
            break

        # Learning rate scheduling
        scheduler.step(avg_loss)

    # Plot training curves. splitext (not str.replace) so a save_path without
    # a '.pt' suffix can never make the PNG overwrite the checkpoint.
    curve_path = os.path.splitext(save_path)[0] + '_curves.png'
    _plot_training_curves(history, best_val_r2, best_epoch, no_save_epochs, curve_path)
    print(f"Training curves saved to: {curve_path}")

    # Training completion prompt
    if best_epoch is not None:
        print(f"\n🎯 Training completed! Best Val R2: {best_val_r2:.4f} (Corresponding Train R2: {best_train_r2:.4f}), "
              f"Model saved at {save_path} (Saved epoch: {best_epoch})")
    else:
        print(f"\n🎯 Training completed! No model saved (Early stopped at epoch {epoch}, not exceeding {no_save_epochs} epochs or no effective improvement)")

    # Prediction and evaluation with best model
    if best_epoch is not None:
        # The student currently holds the last-epoch weights; restore the
        # weights of the best epoch so the reports match the saved checkpoint.
        student.load_state_dict(best_state)
        save_dir = os.path.dirname(save_path)
        print("\n" + "="*50)
        print(f"Predict with best model (Epoch {best_epoch})...")

        train_r2, _, _ = predict_and_evaluate(
            model=student, data_loader=tr_loader_pred,
            y_mean=y_mean, y_std=y_std, device=device,
            dataset_name="Train_Set", save_dir=save_dir, seed=seed
        )

        test_r2, _, _ = predict_and_evaluate(
            model=student, data_loader=va_loader_pred,
            y_mean=y_mean, y_std=y_std, device=device,
            dataset_name="Test_Set", save_dir=save_dir, seed=seed
        )

    return best_val_r2, save_path


# --------------------
# 7. Main Entry (Multi-seed Training)
# --------------------
if __name__ == '__main__':
    # Grid search over distillation hyperparameters; every combination is
    # trained once per seed and summarised by mean/std of the best val R2.
    seeds = [0, 42, 100, 2025]
    hint_lambdas = [1, 5, 10, 20]
    weight_ratios = [(0.4, 0.6), (0.5,0.5), (0.6, 0.4), (0.7,0.3)]
    gate_hiddens = [64, 128, 256]

    # Path placeholders
    train_dir = "TRAIN_DATA_DIR_PATH"
    val_dir = "VAL_DATA_DIR_PATH"
    teacher_paths = ["TEACHER_MODEL_PATH_1", "TEACHER_MODEL_PATH_2", "TEACHER_MODEL_PATH_3", "TEACHER_MODEL_PATH_4", "TEACHER_MODEL_PATH_5"]
    save_root = "MODEL_SAVE_ROOT_PATH"
    os.makedirs(save_root, exist_ok=True)

    # Training parameters
    epochs = 2000
    batch_size = 32
    lr = 1e-3
    min_lr = 5e-5
    lr_patience = 30
    hidden_dims = [128, 128]
    dropout = 0.1

    # Run experiments
    results = []
    for hint_lambda in hint_lambdas:
        for weight_ratio in weight_ratios:
            for gate_hidden in gate_hiddens:
                combo_key = f"hl{hint_lambda}_wr{weight_ratio[0]}_{weight_ratio[1]}_gh{gate_hidden}"
                r2s = []
                print("\n" + "="*50)
                print(f"Experiment combo: {combo_key}")
                print("="*50)

                for seed in seeds:
                    print(f"\n-- Seed: {seed}")
                    set_seed(seed)
                    save_path = os.path.join(save_root, f"student_{combo_key}_seed{seed}.pt")

                    best_val_r2, _ = train_model(
                        train_dir=train_dir,
                        val_dir=val_dir,
                        teacher_paths=teacher_paths,
                        save_path=save_path,
                        gate_hidden=gate_hidden,
                        hint_lambda=hint_lambda,
                        weight_ratio=weight_ratio,
                        hidden_dims=hidden_dims,
                        dropout=dropout,
                        epochs=epochs,
                        batch_size=batch_size,
                        lr=lr,
                        min_lr=min_lr,
                        lr_patience=lr_patience,
                        seed=seed,
                        no_save_epochs=200,
                        es_patience=100,
                        r2_effective_thresh=0.005,
                        # Keyword must match train_model's parameter name; the
                        # former 'best_val_r2_rel_thresh' raised a TypeError.
                        best_r2_rel_thresh=0.005
                    )
                    r2s.append(best_val_r2)

                    # Clear GPU cache
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                # Summarize multi-seed results
                mean_r2 = np.mean(r2s)
                std_r2 = np.std(r2s) if len(r2s) > 1 else 0.0
                print("\n" + "-"*30)
                print(f"Combo {combo_key} results:")
                print(f"  R2 list: {[f'{r:.4f}' for r in r2s]}")
                print(f"  Mean R2: {mean_r2:.4f} | Std R2: {std_r2:.4f}")
                print("-"*30)

                results.append({
                    'combo': combo_key,
                    'r2_mean': mean_r2,
                    'r2_std': std_r2
                })

    # Save summary results
    summary_path = os.path.join(save_root, 'results_summary.pkl')
    with open(summary_path, 'wb') as f:
        pickle.dump(results, f)
    print(f"\nAll experiments completed! Summary saved to: {summary_path}")

    # Plot hyperparameter performance, best combination first
    plt.figure(figsize=(10, 6))
    combos = [r['combo'] for r in results]
    means = [r['r2_mean'] for r in results]
    stds = [r['r2_std'] for r in results]

    sorted_idx = np.argsort(means)[::-1]
    combos = [combos[i] for i in sorted_idx]
    means = [means[i] for i in sorted_idx]
    stds = [stds[i] for i in sorted_idx]

    bars = plt.bar(combos, means, yerr=stds, capsize=5, alpha=0.7, color='skyblue')
    plt.xticks(rotation=45, ha='right', fontsize=8)
    plt.ylabel('Mean Validation R2', fontsize=12)
    plt.title('Hyperparameter Combo Performance Ranking', fontsize=14)
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()

    # Annotate each bar with its mean value just above the bar top
    for bar, mean in zip(bars, means):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                 f'{mean:.4f}', ha='center', va='bottom', fontsize=8)

    hyper_param_path = os.path.join(save_root, 'hyper_param_performance.png')
    plt.savefig(hyper_param_path, dpi=300)
    plt.close()
    print(f"Hyperparameter performance plot saved to: {hyper_param_path}")
