# In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool, GCNConv
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
import numpy as np
import pandas as pd
from tqdm import tqdm
import shutil
import threading
from collections import defaultdict
import time

# Lock that serializes writes to the shared summary dicts filled in by
# process_feature_directory (`with data_lock:`). `main` processes feature
# directories sequentially, so the lock is precautionary; it only matters if
# process_feature_directory is ever called from several threads at once.
data_lock = threading.Lock()

def set_seed(seed: int):
    """Seed every random number generator used during training.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and,
    when available, all CUDA devices). It also switches cuDNN to deterministic
    kernels so that repeated runs with the same seed give the same results.

    Args:
        seed: Integer seed applied to all generators.
    """
    import random  # local import: the module header does not import `random`

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning speed in exchange for reproducible kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class GCNTeacher(nn.Module):
    """Graph-level regressor built from stacked GCN layers.

    The pipeline is: BatchNorm on node features, a stack of GCNConv layers
    (each followed by ReLU and dropout), mean pooling per graph, an optional
    MLP branch for graph-level features ``data.u``, and a two-layer MLP head
    that produces one scalar per graph.

    Args:
        node_dim: Number of input features per node.
        global_dim: Number of graph-level features in ``data.u``. 0 disables
            the global branch entirely.
        hidden_dims: Output width of each GCNConv layer. Defaults to ``[128, 128]``.
        dropout: Dropout probability used after each conv layer and in both MLPs.
    """

    def __init__(self, node_dim: int, global_dim: int, hidden_dims=None, dropout=0.2):
        super().__init__()
        # `None` sentinel instead of a mutable list default shared by all calls.
        hidden_dims = [128, 128] if hidden_dims is None else list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one layer width")
        # Plain int attribute (not part of state_dict). forward() uses it to
        # decide whether the global branch exists.
        self.global_dim = global_dim
        self.norm = nn.BatchNorm1d(node_dim)
        self.convs = nn.ModuleList()
        in_dim = node_dim
        for h in hidden_dims:
            self.convs.append(GCNConv(in_dim, h))
            in_dim = h
        self.dropout = nn.Dropout(dropout)
        if global_dim:
            self.global_norm = nn.BatchNorm1d(global_dim)
            self.global_mlp = nn.Sequential(
                nn.Linear(global_dim, 128), nn.ReLU(), nn.Dropout(dropout)
            )
            self.final_dim = hidden_dims[-1] + 128
        else:
            self.final_dim = hidden_dims[-1]
        self.output = nn.Sequential(
            nn.Linear(self.final_dim, self.final_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(self.final_dim // 2, 1)
        )

    def forward(self, data: Data):
        """Return one prediction per graph in *data*, shaped ``(num_graphs,)``.

        Raises:
            ValueError: if the model was built with ``global_dim > 0`` but
                *data* carries no ``u`` attribute.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.norm(x)
        for conv in self.convs:
            x = self.dropout(F.relu(conv(x, edge_index)))
        x = global_mean_pool(x, batch)
        # Decide from how the model was built, not from what the batch happens
        # to carry. Otherwise a stray `u` on a model without a global branch
        # would hit the missing `global_norm` attribute.
        if self.global_dim:
            u = getattr(data, 'u', None)
            if u is None:
                raise ValueError("GCNTeacher was built with global_dim > 0 but data.u is missing")
            x = torch.cat([x, self.global_mlp(self.global_norm(u))], dim=1)
        # Squeeze only the trailing size-1 output axis. A plain .squeeze() would
        # turn a 1-graph batch into a 0-d tensor.
        return self.output(x).squeeze(-1)

def load_graphs(path: str):
    """Load the graph list stored in ``<path>/graph_data.pt``, mapped onto the CPU.

    Prints the sample count and the node feature width of the first graph.

    Args:
        path: Directory that contains ``graph_data.pt``.

    Returns:
        The object deserialized by ``torch.load``. It is expected to be an
        indexable sequence of dicts whose ``'x'`` entry is a 2-D tensor.
    """
    print(f"[Single-threaded] Loading graph data: {path}")
    graph_file = os.path.join(path, 'graph_data.pt')
    # NOTE(review): torch.load unpickles the file, so only point this at trusted data.
    graphs = torch.load(graph_file, map_location='cpu')
    first_x = graphs[0]['x']
    print(f"Number of samples: {len(graphs)} | Node dimension: {first_x.shape[1]}")
    return graphs

def make_loader(graphs, batch_size, shuffle=True):
    """Convert raw graph dicts into PyG ``Data`` objects and batch them in-process.

    Args:
        graphs: Iterable of dicts whose keys become ``Data`` attributes.
        batch_size: Number of graphs per batch.
        shuffle: Reshuffle the graphs every epoch (default ``True``).

    Returns:
        A ``torch_geometric.loader.DataLoader`` over the converted graphs.
    """
    dataset = [Data(**graph_dict) for graph_dict in graphs]
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,         # load in the main process, no worker subprocesses
        pin_memory=False,
        prefetch_factor=None,  # PyTorch only accepts a prefetch factor with worker processes
    )
    return DataLoader(dataset, **loader_options)

def print_gpu_memory(desc: str):
    """Print allocated and reserved CUDA memory (in GiB) tagged with *desc*.

    Does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    bytes_per_gib = 1024 ** 3
    allocated_gib = torch.cuda.memory_allocated() / bytes_per_gib
    reserved_gib = torch.cuda.memory_reserved() / bytes_per_gib
    print(f"[GPU Memory] {desc} | Allocated: {allocated_gib:.2f}GB | Cached: {reserved_gib:.2f}GB")

def _split_train_val(graphs, split_seed):
    """Return ``(idx_train, idx_val)`` for a 90/10 split.

    The split is stratified on ``g['label']`` when possible and falls back to a
    plain random split otherwise.
    """
    indices = np.arange(len(graphs))
    labels = [g['label'].item() for g in graphs]
    try:
        idx_train, idx_val = train_test_split(
            indices, test_size=0.1, random_state=split_seed,
            shuffle=True, stratify=labels
        )
        print("[Sampling Strategy] ✅ Stratified Sampling")
    except ValueError as e:
        print(f"[Sampling Strategy] ⚠️ Stratified Sampling Failed: {str(e)}, Switching to Random Sampling")
        idx_train, idx_val = train_test_split(
            indices, test_size=0.1, random_state=split_seed, shuffle=True
        )
    return idx_train, idx_val


def _normalize_targets(graphs, indices, y_mean, y_std):
    """Return shallow copies of ``graphs[i]`` (for i in *indices*) with ``'y'`` z-scored.

    The caller's dicts are left untouched on purpose.
    """
    return [{**graphs[i], 'y': (graphs[i]['y'] - y_mean) / y_std} for i in indices]


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over *loader*.

    Trains when *optimizer* is given and evaluates without gradients otherwise.

    Returns:
        ``(mean batch loss, R² over all samples)``, both in normalized-target units.
    """
    training = optimizer is not None
    model.train(training)
    losses, preds, trues = [], [], []
    with torch.set_grad_enabled(training):
        for batch in loader:
            batch = batch.to(device)
            # Flatten both sides so a 1-graph batch stays shape (1,) instead of
            # 0-d. A 0-d array would break np.concatenate and the loss would broadcast.
            pred = model(batch).reshape(-1)
            target = batch.y.reshape(-1)
            loss = criterion(pred, target)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            losses.append(loss.item())
            preds.append(pred.detach().cpu().numpy())
            trues.append(target.cpu().numpy())
    return float(np.mean(losses)), r2_score(np.concatenate(trues), np.concatenate(preds))


def train_and_eval(graphs, split_seed, learn_seed, dropout, batch_size, epochs=200, patience=50):
    """Train one GCNTeacher configuration on CUDA and return its best validation result.

    The targets are z-scored with statistics from the training split only.
    Training stops early after *patience* epochs without a better validation R².

    Args:
        graphs: Sequence of graph dicts (``'x'``, ``'y'``, ``'label'``, optional ``'u'``).
            They are NOT modified.
        split_seed: Seed for the train/validation split.
        learn_seed: Seed for weight initialization and shuffling.
        dropout: Dropout probability passed to GCNTeacher.
        batch_size: Graphs per batch.
        epochs: Maximum number of epochs.
        patience: Early-stopping patience, in epochs.

    Returns:
        ``(best_r2, best_state, y_mean, y_std, node_dim, global_dim)``. Here
        ``best_state`` is a CPU snapshot of the weights from the best epoch
        (None if no epoch produced a finite R²).
    """
    print(f"\n[Training Config] split_seed={split_seed}, learn_seed={learn_seed}, dropout={dropout}, batch_size={batch_size}")

    device = torch.device('cuda')
    print(f"[Device] Using GPU: {torch.cuda.get_device_name(0)} (ID: {torch.cuda.current_device()})")

    idx_train, idx_val = _split_train_val(graphs, split_seed)

    ys = torch.stack([graphs[i]['y'] for i in idx_train])
    y_mean, y_std = ys.mean().item(), ys.std().item() + 1e-8
    # Normalize copies, not the caller's dicts. The sweep calls this function
    # repeatedly with the same `graphs`, and normalizing in place would
    # re-normalize already-normalized targets and corrupt y_mean/y_std.
    train_graphs = _normalize_targets(graphs, idx_train, y_mean, y_std)
    val_graphs = _normalize_targets(graphs, idx_val, y_mean, y_std)

    train_loader = make_loader(train_graphs, batch_size)
    val_loader = make_loader(val_graphs, batch_size, shuffle=False)

    set_seed(learn_seed)
    sample = train_graphs[0]
    node_dim = sample['x'].size(1)
    global_dim = sample['u'].size(1) if sample.get('u', None) is not None else 0
    model = GCNTeacher(node_dim, global_dim, dropout=dropout).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    criterion = nn.MSELoss()

    best_r2, best_state = -np.inf, None
    no_improve_epochs = 0

    for epoch in range(1, epochs + 1):
        train_loss, train_r2 = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_r2 = _run_epoch(model, val_loader, criterion, device)

        if epoch % 10 == 0 or epoch == 1 or epoch == epochs:
            print(f"[Epoch {epoch:4d}/{epochs}] "
                  f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
                  f"Train R²: {train_r2:.4f} | Val R²: {val_r2:.4f}")

        if val_r2 > best_r2:
            best_r2 = val_r2
            # state_dict() returns references to the live parameters, and later
            # optimizer steps overwrite them in place. Take a detached CPU copy
            # so this epoch's weights survive and GPU memory can be freed.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                print(f"[Early Stopping Triggered] Val R² has no improvement for {patience} consecutive epochs, stopping training")
                break

    del model, optimizer, criterion, train_loader, val_loader
    del train_graphs, val_graphs
    torch.cuda.empty_cache()
    print_gpu_memory(f"Parameter combination (ss={split_seed}, ls={learn_seed}) after training completion")
    print(f"[Training Result] Best Val R²: {best_r2:.4f}")

    return best_r2, best_state, y_mean, y_std, node_dim, global_dim

def generate_predictions(model, graphs, y_mean, y_std, batch_size=64):
    """Run *model* over *graphs* and return predictions in both target scales.

    Args:
        model: A trained module whose parameters sit on the device to use.
        graphs: Sequence of graph dicts accepted by ``make_loader``.
        y_mean: Training-target mean used for normalization.
        y_std: Training-target standard deviation used for normalization.
        batch_size: Graphs per inference batch.

    Returns:
        ``(normalized, denormalized)`` 1-D NumPy arrays in the order of
        *graphs*, where ``denormalized = normalized * y_std + y_mean``.
    """
    print("\n[Prediction] Starting to generate results")
    loader = make_loader(graphs, batch_size, shuffle=False)
    device = next(model.parameters()).device
    model.eval()

    normalized_chunks = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="[Prediction Progress]"):
            batch = batch.to(device)
            # reshape(-1): a 1-graph batch can come back 0-d, which cannot be
            # iterated or concatenated.
            normalized_chunks.append(model(batch).reshape(-1).cpu().numpy())

    normalized = np.concatenate(normalized_chunks) if normalized_chunks else np.array([])
    denormalized = normalized * y_std + y_mean
    torch.cuda.empty_cache()
    print("[Prediction] Completed")
    return normalized, denormalized

def _save_gcn_checkpoint(path, state, node_dim, global_dim, dropout, y_mean, y_std):
    """Save GCNTeacher weights plus the metadata needed to rebuild it and undo target normalization."""
    torch.save({
        'model_state_dict': state, 'node_dim': node_dim, 'global_dim': global_dim,
        # Must match the hidden_dims train_and_eval builds GCNTeacher with.
        'hidden_dims': [128, 128], 'dropout': dropout, 'model_type': 'gcn',
        'y_mean': y_mean, 'y_std': y_std
    }, path)


def process_feature_directory(feat_dir, predict_dir, output_dir, epochs, split_seeds, learn_seeds, dropouts, batch_sizes,
                              all_normalized_labels, all_denormalized_labels):
    """Grid-search GCN teachers for one feature directory, keep the best, and predict with it.

    For every (split_seed, learn_seed, dropout, batch_size) combination:
    - a model is trained with ``train_and_eval``;
    - the model is checkpointed under ``<output_dir>/<feat_name>/``.

    The configuration with the highest validation R² is then:
    - saved again under a ``_best_`` name;
    - copied to ``<output_dir>/best_models/``;
    - used to predict on the graphs in *predict_dir*.

    Those predictions are written to two CSV files and stored in the shared
    summary dicts under *feat_name*.

    Args:
        feat_dir: Directory containing ``graph_data.pt`` for this feature set.
        predict_dir: Directory containing ``graph_data.pt`` to predict on.
        output_dir: Root output directory.
        epochs: Maximum epochs per configuration.
        split_seeds, learn_seeds, dropouts, batch_sizes: Hyperparameter grid.
        all_normalized_labels, all_denormalized_labels: Dicts that receive this
            feature's predictions (written under ``data_lock``).

    Returns:
        ``(True, feat_name)`` on success or ``(False, feat_name)`` on failure.
        Exceptions are printed with their traceback, not propagated.
    """
    from itertools import product  # local import: the module header does not import itertools

    print(f"\n" + "="*50)
    print(f"[Feature Processing] Starting to process: {feat_dir}")
    start_time = time.time()
    # Computed before `try` so the failure path reports the same name as the
    # success path (a trailing separator used to yield '' on failure).
    feat_name = os.path.basename(feat_dir.rstrip('/\\'))

    try:
        graphs = load_graphs(feat_dir)
        feat_output_dir = os.path.join(output_dir, feat_name)
        os.makedirs(feat_output_dir, exist_ok=True)

        best_cfg = None
        best_r2 = -np.inf
        best_state, best_mean, best_std = None, None, None
        best_node_dim, best_global_dim, best_dropout = None, None, None

        # product() yields the same order as the four nested loops (batch size varies fastest).
        for ss, ls, do, bs in product(split_seeds, learn_seeds, dropouts, batch_sizes):
            r2, state, y_m, y_s, node_dim, global_dim = train_and_eval(
                graphs, ss, ls, do, bs, epochs
            )

            model_name = f"model_ss{ss}_ls{ls}_do{do}_bs{bs}.pt"
            _save_gcn_checkpoint(os.path.join(feat_output_dir, model_name),
                                 state, node_dim, global_dim, do, y_m, y_s)

            if r2 > best_r2:
                best_r2 = r2
                best_cfg = (ss, ls, do, bs)
                best_state, best_mean, best_std = state, y_m, y_s
                best_node_dim, best_global_dim, best_dropout = node_dim, global_dim, do

            del state
            torch.cuda.empty_cache()
            print(f"[Hyperparameter] Completed configuration (ss={ss}, ls={ls}, do={do}, bs={bs}), current best R²: {best_r2:.4f}")

        if best_cfg is None:
            # Empty grid, or every run had a NaN validation R²: nothing to deploy.
            raise RuntimeError(f"No hyperparameter configuration produced a usable validation R² for {feat_name}")

        ss, ls, do, bs = best_cfg
        best_name = f"{feat_name}_best_ss{ss}_ls{ls}_do{do}_bs{bs}.pt"
        best_path = os.path.join(feat_output_dir, best_name)
        _save_gcn_checkpoint(best_path, best_state, best_node_dim, best_global_dim,
                             best_dropout, best_mean, best_std)

        best_models_dir = os.path.join(output_dir, "best_models")
        os.makedirs(best_models_dir, exist_ok=True)
        shutil.copyfile(best_path, os.path.join(best_models_dir, best_name))
        print(f"[Best Model] Save path: {best_path}")

        graphs_pred = load_graphs(predict_dir)
        best_model = GCNTeacher(
            node_dim=best_node_dim, global_dim=best_global_dim,
            hidden_dims=[128, 128], dropout=best_dropout
        ).to(torch.device('cuda'))
        best_model.load_state_dict(best_state)

        normalized_preds, denormalized_preds = generate_predictions(
            best_model, graphs_pred, best_mean, best_std
        )

        normalized_csv = os.path.join(feat_output_dir, 'predict_normalized_labels.csv')
        pd.DataFrame({f"{feat_name}_teacher_normalized": normalized_preds}).to_csv(normalized_csv, index=False)

        denormalized_csv = os.path.join(feat_output_dir, 'predict_denormalized_labels.csv')
        pd.DataFrame({f"{feat_name}_teacher_denormalized": denormalized_preds}).to_csv(denormalized_csv, index=False)

        with data_lock:
            all_normalized_labels[feat_name] = normalized_preds
            all_denormalized_labels[feat_name] = denormalized_preds

        # Drop the model here as well; deleting it inside generate_predictions
        # only removes that function's local name.
        del graphs, graphs_pred, best_state, best_model, normalized_preds, denormalized_preds
        torch.cuda.empty_cache()

        cost_time = time.time() - start_time
        print(f"\n[Feature Processing] ✅ Completed: {feat_name} | Time consumed: {cost_time:.2f} seconds")
        print(f"[Output Path] Normalized labels: {normalized_csv}")
        print(f"[Output Path] Denormalized labels: {denormalized_csv}")
        print("="*50 + "\n")
        return True, feat_name

    except Exception as e:
        # Top-level per-feature boundary: report and let main() move on to the next feature.
        print(f"\n[Feature Processing] ❌ Failed: {str(e)}")
        import traceback
        traceback.print_exc()
        print("="*50 + "\n")
        torch.cuda.empty_cache()
        return False, feat_name

def main(base_dir: str, predict_dir: str, output_dir: str, epochs=200, *,
         split_seeds=(0, 1, 2), learn_seeds=(0, 1, 2),
         dropouts=(0.1, 0.2, 0.3), batch_sizes=(32, 64)):
    """Train and select a GCN teacher for every feature directory under *base_dir*.

    Each immediate subdirectory of *base_dir* is processed in sorted order.
    The per-feature predictions on *predict_dir* are then combined into two
    summary CSV files in *output_dir*.

    Args:
        base_dir: Directory whose subdirectories each hold a ``graph_data.pt``.
        predict_dir: Directory holding the ``graph_data.pt`` to predict on.
        output_dir: Where checkpoints and CSV files are written (created if missing).
        epochs: Maximum training epochs per configuration.
        split_seeds, learn_seeds, dropouts, batch_sizes: Hyperparameter grid
            searched for each feature directory. The defaults match the
            original hard-coded grid.

    Returns:
        ``(all_normalized_path, all_denormalized_path)`` — paths of the summary CSV files.
    """
    os.makedirs(output_dir, exist_ok=True)
    print("="*60)
    print(f"[Main Program] Started (Pure Single-threaded Mode)")
    print(f"[Path Configuration] Feature base directory: {base_dir}")
    print(f"[Path Configuration] Prediction data directory: {predict_dir}")
    print(f"[Path Configuration] Output directory: {output_dir}")
    print(f"[Training Configuration] Total epochs: {epochs} | Device: GPU (Forced)")
    print("="*60 + "\n")

    # os.listdir order is arbitrary. Sort so the processing order and the
    # summary CSV column order are reproducible across runs and machines.
    candidates = (os.path.join(base_dir, d) for d in sorted(os.listdir(base_dir)))
    feat_dirs = [path for path in candidates if os.path.isdir(path)]
    num_features = len(feat_dirs)
    print(f"[Data Statistics] Detected {num_features} feature directories, will process them sequentially in single thread\n")

    # feat_name -> prediction array. Values are assigned whole, never appended to.
    all_normalized_labels = {}
    all_denormalized_labels = {}

    successful_features = []
    for idx, feat_dir in enumerate(feat_dirs, 1):
        print(f"[Main Program] Starting to process the {idx}/{num_features} feature directory")
        success, name = process_feature_directory(
            feat_dir, predict_dir, output_dir, epochs,
            split_seeds, learn_seeds, dropouts, batch_sizes,
            all_normalized_labels, all_denormalized_labels
        )
        if success:
            successful_features.append(name)

    normalized_df = pd.DataFrame(all_normalized_labels)
    all_normalized_path = os.path.join(output_dir, "all_normalized_labels.csv")
    normalized_df.to_csv(all_normalized_path, index=False)

    denormalized_df = pd.DataFrame(all_denormalized_labels)
    all_denormalized_path = os.path.join(output_dir, "all_denormalized_labels.csv")
    denormalized_df.to_csv(all_denormalized_path, index=False)

    print("\n" + "="*60)
    print(f"[Main Program] All processing completed!")
    print(f"[Result Statistics] Successfully processed: {len(successful_features)}/{num_features} feature directories")
    print(f"[Summary Output] All normalized labels: {all_normalized_path}")
    print(f"[Summary Output] All denormalized labels: {all_denormalized_path}")
    print("="*60)

    return all_normalized_path, all_denormalized_path

if __name__ == "__main__":
    # This pipeline forces CUDA everywhere, so fail fast before any work starts.
    if not torch.cuda.is_available():
        raise RuntimeError("❌ No available GPU detected! Please install NVIDIA driver and matching PyTorch (e.g., cu117 version).")

    # Input/Output Path Configuration (modify according to actual needs)
    base_dir = "PATH/TO/FEATURE_BASE_DIRECTORY"
    predict_dir = "PATH/TO/PREDICTION_DATA_DIRECTORY"
    output_dir = "PATH/TO/OUTPUT_DIRECTORY"
    epochs = 500

    try:
        normalized_path, denormalized_path = main(base_dir, predict_dir, output_dir, epochs)

        print("\n[Result Analysis] Loading summary data...")
        # Print describe() statistics for each summary CSV that exists and is non-empty.
        summaries = (("Normalized", normalized_path), ("Denormalized", denormalized_path))
        for scale_label, summary_path in summaries:
            has_content = os.path.exists(summary_path) and os.path.getsize(summary_path) > 0
            if not has_content:
                print(f"\n[Result Analysis] {scale_label} labels file does not exist or is empty")
                continue
            summary_df = pd.read_csv(summary_path)
            print(f"\n[Result Analysis] {scale_label} labels statistics:")
            print(summary_df.describe().round(4))

    except Exception as e:
        # Script entry point: report the failure with a traceback instead of re-raising.
        print(f"\n[Runtime Error] Program terminated abnormally: {str(e)}")
        import traceback
        traceback.print_exc()
