# 1. Active_learning_part

## Main_code

In [None]:
#!/usr/bin/env python
import os
import sys
import json
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import subprocess
import shutil
import random
import copy
from datetime import datetime
from tqdm import tqdm
from torch.nn.utils.parametrizations import weight_norm

# ============ Configuration Parameters ============
# Project root and the folders derived from it
BASE_DIR = 'C:/Users/Administrator/Desktop/My research/2_ML_predict_product/AL_GasKit_V2.0'
POOL_DATA_DIR = os.path.join(BASE_DIR, 'Pool_data')
TOTAL_DATA_DIR = os.path.join(BASE_DIR, 'Total_data')
WORK_PATH = BASE_DIR

# Active learning schedule
INITIAL_TRAIN_SIZE = 1  # q: size of the initial training set
ENSEMBLE_SIZE = 4       # m: number of models in the ensemble
TOP_K = 1               # new points added per round
N_ITERATIONS = 100      # number of active learning rounds
TRAINING_FRAMES = 500   # N2: number of frames used for training
RANDOM_SEED = 42        # global seed so that runs are reproducible

# TCN model and optimisation settings
INPUT_LEN = 50
PRED_LEN = 150
STRIDE = 10
BATCH_SIZE = 32
EPOCHS = 200
LEARNING_RATE = 1e-3
# Each model's dropout rate is drawn around DROPOUT_BASE within
# +/- DROPOUT_RANGE, i.e. from [0.05, 0.35] with the values below.
DROPOUT_BASE = 0.2
DROPOUT_RANGE = 0.15

# Fixed test-set directories, one per simulated condition. They are spelled
# with forward slashes relative to BASE_DIR so the resulting strings do not
# depend on the platform's path separator.
TEST_PATHS = [
    f'{BASE_DIR}/Total_data/{condition}'
    for condition in (
        '1100K_2atm_1per15',
        '1100K_3atm_1per17.5',
        '1100K_4atm_1per20',
        '1100K_5atm_1per22.5',
        '1100K_6atm_1per25',
        '1200K_2atm_1per17.5',
        '1200K_3atm_1per20',
        '1200K_4atm_1per22.5',
        '1200K_5atm_1per25',
        '1200K_6atm_1per15',
        '1300K_2atm_1per20',
        '1300K_3atm_1per22.5',
        '1300K_4atm_1per25',
        '1300K_5atm_1per15',
        '1300K_6atm_1per17.5',
        '1400K_2atm_1per22.5',
        '1400K_3atm_1per25',
        '1400K_4atm_1per15',
        '1400K_5atm_1per17.5',
        '1400K_6atm_1per20',
        '1500K_2atm_1per25',
        '1500K_3atm_1per15',
        '1500K_4atm_1per17.5',
        '1500K_5atm_1per20',
        '1500K_6atm_1per22.5',
    )
]

# --- NEW ---: Set global random seed
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (plus the CUDA
    generators when CUDA is available), then switches cuDNN to deterministic
    mode with benchmarking disabled.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# ============ TCN Model Definition (accepts dropout parameter) ============
class TemporalBlock(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(
            nn.Conv1d(in_ch, out_ch, kernel_size, stride=stride,
                      padding=padding, dilation=dilation)
        )
        self.relu1 = nn.ReLU()
        # --- MODIFICATION ---: dropout is now a configurable parameter
        self.dropout1 = nn.Dropout(dropout)
        self.conv2 = weight_norm(
            nn.Conv1d(out_ch, out_ch, kernel_size, stride=stride,
                      padding=padding, dilation=dilation)
        )
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)
        self.net = nn.Sequential(self.conv1, self.relu1, self.dropout1,
                                 self.conv2, self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else None
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        if out.size(2) != res.size(2):
            out = out[:, :, :res.size(2)]
        return self.relu(out + res)

class TemporalConvNet(nn.Module):
    """Stack of TemporalBlock layers whose dilation doubles at every level."""

    def __init__(self, num_inputs, num_channels, kernel_size=3, dropout=0.2):
        super().__init__()
        blocks = []
        prev_ch = num_inputs
        for level, width in enumerate(num_channels):
            dilation = 2 ** level
            blocks.append(
                TemporalBlock(
                    prev_ch, width, kernel_size,
                    stride=1,
                    dilation=dilation,
                    padding=(kernel_size - 1) * dilation,
                    dropout=dropout,
                )
            )
            # The next level consumes what this level produces.
            prev_ch = width
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Run the input through all temporal blocks in order."""
        return self.network(x)

class TCNForecast(nn.Module):
    """TCN encoder followed by a linear head that emits a multi-step forecast.

    The features at the last time step of the TCN output are mapped to
    ``input_dim * pred_len`` values and reshaped to
    ``(batch, pred_len, input_dim)``.
    """

    def __init__(self, input_dim, num_channels, kernel_size, pred_len, dropout=0.2):
        super().__init__()
        self.tcn = TemporalConvNet(input_dim, num_channels, kernel_size, dropout=dropout)
        head_in = num_channels[-1]
        head_out = input_dim * pred_len
        self.linear = nn.Linear(head_in, head_out)
        self.input_dim = input_dim
        self.pred_len = pred_len

    def forward(self, x):
        """Forecast ``pred_len`` steps from an input of shape (batch, seq_len, input_dim)."""
        channels_first = x.transpose(1, 2)   # (batch, input_dim, seq_len)
        features = self.tcn(channels_first)  # (batch, hidden, seq_len)
        last_step = features[:, :, -1]       # features at the final time step
        flat = self.linear(last_step)        # (batch, input_dim * pred_len)
        return flat.view(-1, self.pred_len, self.input_dim)

# ============ Data Processing Tools ============
class TimeSeriesDataset(Dataset):
    """Dataset pairing input windows with target windows as float32 tensors."""

    def __init__(self, X, Y):
        # Both arrays are converted once, up front, to float32 tensors.
        inputs = torch.from_numpy(X)
        targets = torch.from_numpy(Y)
        self.X = inputs.float()
        self.Y = targets.float()

    def __len__(self):
        """Number of samples, i.e. the size of the first axis of X."""
        return self.X.size(0)

    def __getitem__(self, idx):
        """Return the (input, target) pair stored at position ``idx``."""
        sample = self.X[idx]
        target = self.Y[idx]
        return sample, target

class EarlyStopping:
    """Signal a stop once the loss has not improved for ``patience`` calls.

    A call counts as an improvement only when the loss drops below the best
    value seen so far by more than ``min_delta``.
    """

    def __init__(self, patience=15, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and return whether training should stop."""
        improved = self.best_loss - val_loss > self.min_delta
        if improved:
            # New best value: remember it and restart the patience counter.
            self.best_loss = val_loss
            self.counter = 0
            return self.early_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop

# ============ Active Learning Core Functions ============
class ActiveLearningFramework:
    def __init__(self):
        """Initialise empty state, create the results folder and load the pool.

        The results folder is created under BASE_DIR with a timestamped name
        before the pool information is read.
        """
        # State filled in by the later steps of the workflow.
        self.pool_info = None
        self.train_indices = []
        self.test_dirs = TEST_PATHS
        self.ensemble_models = []
        self.num_species = None
        self.iteration = 0

        # One output directory per run, named after the start time.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.results_dir = os.path.join(BASE_DIR, f'AL_results_{timestamp}')
        os.makedirs(self.results_dir, exist_ok=True)

        self.load_pool_info()
    
    def load_pool_info(self):
        """Read ``pool_info.json`` from BASE_DIR into ``self.pool_info``.

        Raises:
            FileNotFoundError: If the JSON file does not exist.
        """
        info_path = os.path.join(BASE_DIR, 'pool_info.json')
        if not os.path.exists(info_path):
            raise FileNotFoundError(f"Pool info file not found: {info_path}")

        with open(info_path, 'r') as handle:
            loaded = json.load(handle)
        self.pool_info = loaded

        print(f"Loaded {len(self.pool_info)} parameter point infos")
    
    def select_initial_training_set(self):
        """Step 3: Draw the initial training set at random from completed points.

        Stores the drawn pool indices in ``self.train_indices`` and writes the
        corresponding pool entries to ``initial_train_set.json`` in the
        results directory. The draw uses Python's global ``random`` state.

        Raises:
            ValueError: If fewer than INITIAL_TRAIN_SIZE points are completed.
        """
        print("\nStep 3: Selecting initial training set...")

        # Only points whose simulation finished are eligible.
        completed_indices = []
        for position, entry in enumerate(self.pool_info):
            if entry['status'] == 'completed':
                completed_indices.append(position)

        if len(completed_indices) < INITIAL_TRAIN_SIZE:
            raise ValueError(f"Number of completed simulation points ({len(completed_indices)}) is less than initial train size ({INITIAL_TRAIN_SIZE})")

        self.train_indices = random.sample(completed_indices, INITIAL_TRAIN_SIZE)

        print(f"Selected initial training set indices: {self.train_indices}")

        # Keep a record of which points were chosen.
        record_path = os.path.join(self.results_dir, 'initial_train_set.json')
        chosen_entries = [self.pool_info[i] for i in self.train_indices]
        with open(record_path, 'w') as f:
            json.dump(chosen_entries, f, indent=2)
    
    def check_simulation_complete(self, sim_dir):
        """Return True if ``log.lammps`` in ``sim_dir`` reports a finished run.

        A run counts as finished when one of the last five lines of the log
        contains "Total wall time:". A missing or empty log yields False.
        """
        log_file = os.path.join(sim_dir, "log.lammps")
        if not os.path.exists(log_file):
            return False

        with open(log_file, 'r') as f:
            tail = f.readlines()[-5:]

        # The completion marker is only looked for near the end of the log.
        return any("Total wall time:" in line for line in tail)
    
    def run_lammps_simulation(self, sim_dir):
        """Run LAMMPS in ``sim_dir`` and post-process its output.

        The shell command loads the LAMMPS environment module and launches
        ``lmp`` through ``mpirun`` with ``sim_dir`` as working directory. The
        run counts as complete once a line containing "Total wall time:" is
        seen on the combined stdout/stderr stream. After a complete run the
        script ``lammps_output_process.py`` from WORK_PATH is executed on
        ``sim_dir`` if it exists.

        Args:
            sim_dir: Directory that holds the ``in.MoO3S`` input file.

        Returns:
            True when the completion marker was seen and post-processing did
            not fail (a missing post-processing script only prints a
            warning); False otherwise, including when an exception occurs.
        """
        # The process runs with cwd=sim_dir, so the input file is addressed
        # by its bare name. Interpolating the full path into the shell string
        # would split the command whenever the path contains a space.
        input_name = "in.MoO3S"
        cmd = (
            "module load lammps/20230328-intel-2021.4.0-omp && "
            f"mpirun -np 48 lmp -in {input_name}"
        )

        print(f"  Running simulation: {os.path.basename(sim_dir)}")

        try:
            done = False
            # The context manager closes the pipe and waits for the process.
            with subprocess.Popen(
                cmd,
                cwd=sim_dir,
                shell=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True
            ) as proc:
                # Read to EOF even after the completion marker: waiting on a
                # child whose output pipe is no longer drained can deadlock.
                for count, line in enumerate(proc.stdout, start=1):
                    if count % 100 == 0:
                        print(".", end="", flush=True)
                    if not done and "Total wall time:" in line:
                        print(f"\n  Completed: {line.strip()}")
                        done = True

            if not done:
                print(f"\n  Warning: {sim_dir} simulation may not have completed normally")
                return False

            # Run post-processing script
            proc_script = os.path.join(WORK_PATH, "lammps_output_process.py")
            if os.path.exists(proc_script):
                ret = subprocess.run(
                    ["python", proc_script, sim_dir],
                    capture_output=True,
                    text=True
                )
                if ret.returncode != 0:
                    print(f"  Post-processing failed: {ret.stderr}")
                    return False
                print(f"  Post-processing complete")
            else:
                print(f"  Warning: Post-processing script {proc_script} not found")

            return True

        except Exception as e:
            # Best effort: any failure is reported and turned into False.
            print(f"  Simulation error: {e}")
            return False
    
    def run_extended_simulation(self, indices):
        """Step 4: Extend the simulations of the selected points if required.

        An extension is needed only when TRAINING_FRAMES exceeds the number of
        frames produced by the initial simulations. For each affected point
        the ``run`` command in ``in.MoO3S`` is rewritten to the new step
        count, an existing ``log.lammps`` is moved to a backup name, and
        LAMMPS is launched through ``run_lammps_simulation``.

        On success ``extended_frames`` is set on the pool entry; on failure
        its ``status`` becomes ``'extension_failed'``. These changes are made
        to the in-memory entry only; this method does not write the pool
        information back to disk.

        Args:
            indices: Pool indices of the points to check.
        """
        print("\nStep 4: Checking and running extended simulations...")

        # Initial simulation frames (value set from initialize_pool.py)
        INITIAL_FRAMES = 500
        steps_per_frame = 5000

        # Neither value depends on the individual point.
        needs_extension = TRAINING_FRAMES > INITIAL_FRAMES
        new_run_steps = steps_per_frame * TRAINING_FRAMES

        for idx in indices:
            info = self.pool_info[idx]
            sim_dir = info['sim_dir']
            sim_name = os.path.basename(sim_dir)

            if not needs_extension:
                print(f"No extension needed: {sim_name} (current frames meet requirement)")
                continue

            print(f"Extended simulation needed: {sim_name} (extending from {INITIAL_FRAMES} to {TRAINING_FRAMES} frames)")

            # Rewrite the run command in the in.MoO3S file. Only lines whose
            # first token is exactly "run" are replaced, so other commands
            # that merely begin with those letters (e.g. "run_style") are
            # left untouched.
            in_path = os.path.join(sim_dir, "in.MoO3S")
            with open(in_path, 'r') as f:
                original_lines = f.readlines()

            new_lines = [
                f"run {new_run_steps}\n" if line.split()[:1] == ["run"] else line
                for line in original_lines
            ]

            with open(in_path, 'w') as f:
                f.writelines(new_lines)

            print(f"  Run steps updated to: {new_run_steps}")

            # Backup old log file (if it exists)
            log_file = os.path.join(sim_dir, "log.lammps")
            if os.path.exists(log_file):
                backup_log = os.path.join(sim_dir, f"log.lammps.backup_{INITIAL_FRAMES}frames")
                shutil.move(log_file, backup_log)
                print(f"  Backed up original log file")

            # Run extended simulation
            success = self.run_lammps_simulation(sim_dir)

            if success:
                print(f"  Extended simulation completed successfully")
                info['extended_frames'] = TRAINING_FRAMES
            else:
                print(f"  Extended simulation failed")
                info['status'] = 'extension_failed'
    
    def find_common_species(self):
        """Determine the species shared by all training and test data.

        The initial species list files of every training and test directory
        are passed to ``found_species_subsets.py`` (looked up in WORK_PATH).
        A missing script or a failing call only prints a warning. Afterwards
        ``species_list.txt`` of the first training directory is read and its
        number of non-empty lines is stored in ``self.num_species``.

        Raises:
            FileNotFoundError: If that ``species_list.txt`` does not exist.
        """
        print("\nFinding common species...")

        # Collect the list files: training directories first, then test ones.
        list_name = 'species_list_initial.txt'
        train_lists = [
            os.path.join(self.pool_info[idx]['sim_dir'], list_name)
            for idx in self.train_indices
        ]
        test_lists = [os.path.join(test_dir, list_name) for test_dir in self.test_dirs]
        all_paths = train_lists + test_lists

        # Delegate the subset search to the external script.
        subset_script = os.path.join(WORK_PATH, 'found_species_subsets.py')
        if not os.path.exists(subset_script):
            print(f"Warning: found_species_subsets.py script not found")
        else:
            ret = subprocess.run(
                ['python', subset_script] + all_paths,
                capture_output=True,
                text=True,
            )
            if ret.returncode == 0:
                print("True subset finding complete")
            else:
                print(f"Warning: True subset finding failed: {ret.stderr}")

        # The species count is taken from the first training directory.
        first_train_dir = self.pool_info[self.train_indices[0]]['sim_dir']
        species_list_file = os.path.join(first_train_dir, 'species_list.txt')

        if not os.path.exists(species_list_file):
            raise FileNotFoundError(f"Species list file not found: {species_list_file}")

        with open(species_list_file, 'r') as f:
            names = [entry.strip() for entry in f if entry.strip()]
        self.num_species = len(names)
        print(f"Number of common species: {self.num_species}")
    
    def load_data(self, dir_list, max_frames=TRAINING_FRAMES):
        """Load species/time matrices and species lists from directories.

        For each directory the pair ``species_time_matrix.npy`` /
        ``species_list.txt`` is used when both files exist; otherwise the
        ``*_initial`` pair is tried. Directories providing neither complete
        pair are skipped with a warning.

        Args:
            dir_list: Directories to read.
            max_frames: Number of leading columns kept from each matrix.

        Returns:
            Tuple ``(mats, species_lists)`` with one entry per loaded
            directory, in input order.
        """
        # File-name pairs in order of preference.
        candidates = (
            ('species_time_matrix.npy', 'species_list.txt'),
            ('species_time_matrix_initial.npy', 'species_list_initial.txt'),
        )

        mats = []
        species_lists = []

        for d in dir_list:
            chosen = None
            for mat_name, list_name in candidates:
                mat_path = os.path.join(d, mat_name)
                list_path = os.path.join(d, list_name)
                if os.path.exists(mat_path) and os.path.exists(list_path):
                    chosen = (mat_path, list_path)
                    break

            if chosen is None:
                print(f"Warning: Skipping directory {d}, files not found")
                continue

            mat_path, list_path = chosen
            matrix = np.load(mat_path)
            mats.append(matrix[:, :max_frames])

            with open(list_path, 'r') as f:
                names = [entry.strip() for entry in f if entry.strip()]
            species_lists.append(names)

        return mats, species_lists
    
    def create_windows(self, matrices):
        """Cut every matrix into (input, target) sliding windows.

        Windows start every STRIDE columns. Each one yields INPUT_LEN input
        columns followed by PRED_LEN target columns, both transposed so that
        time is the first axis.

        Returns:
            Tuple ``(X, Y)`` of stacked arrays, or two empty arrays when no
            window fits into any matrix.
        """
        X_list = []
        Y_list = []

        for mat in matrices:
            _, n_frames = mat.shape
            # One past the last start position that still fits a full window.
            stop = n_frames - INPUT_LEN - PRED_LEN + 1
            for start in range(0, stop, STRIDE):
                split = start + INPUT_LEN
                X_list.append(mat[:, start:split].T)             # (seq_len, num_species)
                Y_list.append(mat[:, split:split + PRED_LEN].T)  # (pred_len, num_species)

        if not X_list:
            return np.array([]), np.array([])
        return np.stack(X_list), np.stack(Y_list)
    
    def train_ensemble_models(self):
        """Step 5: Train the deep ensemble of TCN forecasters.

        Windows are built from the current training directories and from the
        fixed test directories, which double as the validation set. Each of
        the ENSEMBLE_SIZE models is trained on its own bootstrap resample of
        the training windows with a randomly drawn dropout rate, using Adam,
        MSE loss and early stopping on the validation loss. The weights with
        the lowest validation loss are restored, saved to the results
        directory, and the model is appended to ``self.ensemble_models``.

        From the second active-learning iteration on, a model starts from the
        weights of the model with the same index from the previous iteration,
        provided all parameter shapes still match.

        Raises:
            ValueError: If no training or no validation windows can be built.
        """
        print("\nStep 5: Training Deep Ensemble TCN Models (Enhanced Diversity)...")

        # Prepare training and validation data
        train_dirs = [self.pool_info[i]['sim_dir'] for i in self.train_indices]

        train_mats, _ = self.load_data(train_dirs)
        val_mats, _ = self.load_data(self.test_dirs)  # Use the fixed test set as the validation set

        X_train_full, Y_train_full = self.create_windows(train_mats)
        X_val, Y_val = self.create_windows(val_mats)

        if len(X_train_full) == 0:
            raise ValueError("Training data is empty")
        if len(X_val) == 0:
            # Without this check the per-epoch validation average below would
            # divide by zero batches.
            raise ValueError("Validation data is empty")

        print(f"Full training data shape: X={X_train_full.shape}, Y={Y_train_full.shape}")
        print(f"Validation data shape: X={X_val.shape}, Y={Y_val.shape}")

        # Create validation data loader (shared by all models)
        val_loader = DataLoader(
            TimeSeriesDataset(X_val, Y_val),
            batch_size=BATCH_SIZE
        )

        # Keep the models of the previous iteration so their weights can seed
        # the new ones; the list itself is rebuilt from scratch.
        previous_models = self.ensemble_models
        self.ensemble_models = []

        input_dim = X_train_full.shape[2]
        hidden_channels = [64, 64, 64]
        kernel_size = 3

        n_train_samples = X_train_full.shape[0]

        for model_idx in range(ENSEMBLE_SIZE):
            print(f"\nTraining model {model_idx + 1}/{ENSEMBLE_SIZE}")

            # 1. Bagging: sample the training windows with replacement so
            #    that every model sees a different view of the data.
            bootstrap_indices = np.random.choice(n_train_samples, size=n_train_samples, replace=True)
            X_train_boot = X_train_full[bootstrap_indices]
            Y_train_boot = Y_train_full[bootstrap_indices]

            print(f"  Using Bagging: Created bootstrap training set of size {len(X_train_boot)}")

            train_loader = DataLoader(
                TimeSeriesDataset(X_train_boot, Y_train_boot),
                batch_size=BATCH_SIZE,
                shuffle=True  # Shuffle the bootstrap set every epoch
            )

            # 2. Hyperparameter randomization: a random dropout rate per
            #    model further increases the diversity of the ensemble.
            dropout_rate = random.uniform(
                max(0.0, DROPOUT_BASE - DROPOUT_RANGE),
                min(0.5, DROPOUT_BASE + DROPOUT_RANGE)
            )
            print(f"  Using random hyperparameter: Dropout rate = {dropout_rate:.4f}")

            model = TCNForecast(input_dim, hidden_channels, kernel_size, PRED_LEN, dropout=dropout_rate)

            # Incremental training: start from the previous iteration's
            # weights. Shapes are compared first because a failing
            # load_state_dict may already have copied part of the tensors.
            if self.iteration > 0 and model_idx < len(previous_models):
                old_state = previous_models[model_idx].state_dict()
                new_state = model.state_dict()
                compatible = (
                    old_state.keys() == new_state.keys()
                    and all(old_state[key].shape == new_state[key].shape for key in new_state)
                )
                if compatible:
                    model.load_state_dict(old_state)
                    print(f"  Loading pre-trained model weights for incremental training")
                else:
                    print("  Previous weights do not match the current model; training from scratch")

            optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
            criterion = nn.MSELoss()
            early_stopping = EarlyStopping(patience=10)

            best_val_loss = float('inf')
            best_state_dict = None

            # Training loop
            for epoch in range(1, EPOCHS + 1):
                # Train
                model.train()
                total_loss = 0.0
                for x_batch, y_batch in train_loader:
                    optimizer.zero_grad()
                    loss = criterion(model(x_batch), y_batch)
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()

                avg_train = total_loss / len(train_loader)

                # Validate
                model.eval()
                val_loss = 0.0
                with torch.no_grad():
                    for x_batch, y_batch in val_loader:
                        val_loss += criterion(model(x_batch), y_batch).item()

                avg_val = val_loss / len(val_loader)

                if epoch % 20 == 0:
                    print(f"  Epoch {epoch}: Train Loss={avg_train:.4f}, Val Loss={avg_val:.4f}")

                # Remember the weights with the lowest validation loss
                if avg_val < best_val_loss:
                    best_val_loss = avg_val
                    best_state_dict = copy.deepcopy(model.state_dict())

                # Early stopping check
                if early_stopping(avg_val):
                    print(f"  Early stopping at epoch {epoch}")
                    break

            # Restore the best model
            if best_state_dict is not None:
                model.load_state_dict(best_state_dict)

            # Save the model
            model_path = os.path.join(
                self.results_dir,
                f'model_{model_idx}_iter_{self.iteration}.pth'
            )
            torch.save(model.state_dict(), model_path)

            self.ensemble_models.append(model)

        print(f"\nEnsemble model training complete, total {len(self.ensemble_models)} models")
    
    def calculate_uncertainty(self, data_dir):
        """Estimate the ensemble disagreement for one pool point.

        The initial species/time matrix of ``data_dir`` is aligned to the
        common species list (``species_list.txt`` of the first training
        directory); species missing from the point are filled with zeros.
        Starting from the first INPUT_LEN frames, every ensemble model
        predicts ``TRAINING_FRAMES - INPUT_LEN`` steps auto-regressively,
        feeding its own predictions back as input. The uncertainty is the
        variance across models, averaged over all steps and species.

        Args:
            data_dir: Simulation directory of the pool point.

        Returns:
            The mean variance as a float, or ``float('inf')`` when the input
            files are missing, the data is shorter than INPUT_LEN, the
            prediction fails or the result is not a finite number.
        """
        mat_path = os.path.join(data_dir, 'species_time_matrix_initial.npy')
        list_path = os.path.join(data_dir, 'species_list_initial.txt')

        # Both files are required; a missing one marks the point as unusable.
        if not (os.path.exists(mat_path) and os.path.exists(list_path)):
            return float('inf')

        mat_initial = np.load(mat_path)

        with open(list_path, 'r') as f:
            species_initial = [line.strip() for line in f if line.strip()]

        # Get common species list (from the first training set directory)
        first_train_dir = self.pool_info[self.train_indices[0]]['sim_dir']
        common_species_path = os.path.join(first_train_dir, 'species_list.txt')

        if not os.path.exists(common_species_path):
            print(f"Warning: Common species list not found: {common_species_path}")
            return float('inf')

        with open(common_species_path, 'r') as f:
            common_species = [line.strip() for line in f if line.strip()]

        # Row lookup by species name; the first occurrence wins, which is
        # what list.index would return.
        row_of = {}
        for row, name in enumerate(species_initial):
            row_of.setdefault(name, row)

        # Rows of species the point does not contain stay at zero.
        aligned_mat = np.zeros((len(common_species), mat_initial.shape[1]))
        for i, sp in enumerate(common_species):
            row = row_of.get(sp)
            if row is not None:
                aligned_mat[i, :] = mat_initial[row, :]

        # Check data length
        T = aligned_mat.shape[1]
        if T < INPUT_LEN:
            print(f"Warning: Data length insufficient ({T} < {INPUT_LEN})")
            return float('inf')

        # Number of steps to predict beyond the initial input window
        total_pred_steps = TRAINING_FRAMES - INPUT_LEN

        try:
            # The first INPUT_LEN frames seed the rollout
            initial_input = aligned_mat[:, :INPUT_LEN].T  # (INPUT_LEN, num_species)

            all_predictions = []

            for model in self.ensemble_models:
                model.eval()

                predicted_sequence = []
                current_input = initial_input.copy()

                steps_predicted = 0
                while steps_predicted < total_pred_steps:
                    X_tensor = torch.from_numpy(current_input[np.newaxis, :, :]).float()

                    # Predict next PRED_LEN steps
                    with torch.no_grad():
                        pred = model(X_tensor).numpy()[0]  # (PRED_LEN, num_species)

                    # The last chunk may only be needed in part
                    steps_to_use = min(PRED_LEN, total_pred_steps - steps_predicted)
                    predicted_sequence.append(pred[:steps_to_use])
                    steps_predicted += steps_to_use

                    if steps_predicted < total_pred_steps:
                        # The next input is the most recent INPUT_LEN frames.
                        # Once enough steps have been predicted these are
                        # predictions only; the seed frames are included so
                        # the window keeps its length even if a single
                        # prediction is shorter than INPUT_LEN.
                        history = np.vstack([initial_input] + predicted_sequence)
                        current_input = history[-INPUT_LEN:]

                full_prediction = np.vstack(predicted_sequence)  # (total_pred_steps, num_species)
                all_predictions.append(full_prediction)

            all_predictions = np.stack(all_predictions)  # (num_models, total_pred_steps, num_species)

            # Variance across models per step and species, then the mean
            variance = np.var(all_predictions, axis=0)  # (total_pred_steps, num_species)
            uncertainty = float(np.mean(variance))

            # A NaN result cannot be ranked; report it like any other failure.
            if not np.isfinite(uncertainty):
                return float('inf')
            return uncertainty

        except Exception as e:
            # Best effort: a failing point is reported and excluded.
            print(f"Error during uncertainty calculation: {e}")
            return float('inf')
    
    def select_uncertain_points(self):
        """Step 6: Pick the pool points the ensemble is least certain about.

        Candidates are pool entries with status ``'completed'`` that are not
        yet part of the training set. Points whose uncertainty is not a
        finite number are ignored.

        Returns:
            Up to TOP_K pool indices ordered by decreasing uncertainty, or an
            empty list when no candidate or no valid uncertainty exists.
        """
        print("\nStep 6: Calculating uncertainty and selecting new points...")
        print(f"Calculating uncertainty using rolling prediction (predicting {TRAINING_FRAMES - INPUT_LEN} steps)")

        # Get points from the pool that have not been selected
        already_selected = set(self.train_indices)
        available_indices = [
            i for i, info in enumerate(self.pool_info)
            if i not in already_selected and info['status'] == 'completed'
        ]

        print(f"Number of available points: {len(available_indices)}")

        if len(available_indices) == 0:
            print("Warning: No available points for selection")
            return []

        # Calculate uncertainty for each point
        uncertainties = []
        for idx in tqdm(available_indices, desc="Calculating uncertainty"):
            sim_dir = self.pool_info[idx]['sim_dir']
            unc = self.calculate_uncertainty(sim_dir)
            # Drop inf (failed points) and NaN, which would make the ranking
            # below meaningless.
            if np.isfinite(unc):
                uncertainties.append((idx, unc))

        if not uncertainties:
            print("Warning: No valid uncertainties were calculated")
            return []

        # Most uncertain points first
        uncertainties.sort(key=lambda x: x[1], reverse=True)

        # Never select more points than were ranked
        n_select = min(TOP_K, len(uncertainties))
        top_entries = uncertainties[:n_select]
        selected_indices = [idx for idx, _ in top_entries]

        print(f"Selected new point indices: {selected_indices}")
        print(f"Corresponding uncertainties: {[unc for _, unc in top_entries]}")

        return selected_indices
    
    def evaluate_ensemble(self):
        """Evaluate the averaged ensemble prediction on the test set.

        The predictions of all models are averaged and compared with the
        targets. The reported value is the mean squared error over all test
        elements, so a smaller final batch does not skew the result. The
        value is also written to ``evaluation_iter_<iteration>.json`` in the
        results directory.

        Returns:
            The test MSE, or None when there is no test data or no model.
        """
        print("\nEvaluating ensemble model...")

        # Load test data
        test_mats, _ = self.load_data(self.test_dirs)
        X_test, Y_test = self.create_windows(test_mats)

        if len(X_test) == 0:
            print("Test data is empty, skipping evaluation")
            return None

        if not self.ensemble_models:
            # Averaging over zero models is not possible.
            print("No trained models available, skipping evaluation")
            return None

        test_loader = DataLoader(
            TimeSeriesDataset(X_test, Y_test),
            batch_size=BATCH_SIZE
        )

        # Evaluation mode is set once rather than for every batch.
        for model in self.ensemble_models:
            model.eval()

        # Squared errors are summed and divided by the element count at the
        # end, which gives the exact mean regardless of batch sizes.
        criterion = nn.MSELoss(reduction='sum')
        squared_error = 0.0
        n_elements = 0

        with torch.no_grad():
            for x_batch, y_batch in test_loader:
                predictions = [model(x_batch) for model in self.ensemble_models]
                ensemble_pred = torch.stack(predictions).mean(dim=0)

                squared_error += criterion(ensemble_pred, y_batch).item()
                n_elements += y_batch.numel()

        avg_loss = squared_error / n_elements
        print(f"Ensemble model test set MSE: {avg_loss:.6f}")

        # Save evaluation results
        eval_result = {
            'iteration': self.iteration,
            'test_mse': avg_loss,
            'train_size': len(self.train_indices),
            'ensemble_size': len(self.ensemble_models)
        }

        eval_file = os.path.join(self.results_dir, f'evaluation_iter_{self.iteration}.json')
        with open(eval_file, 'w') as f:
            json.dump(eval_result, f, indent=2)

        return avg_loss
    
    def run(self):
        """Drive the complete active-learning experiment.

        After the initial training set is chosen, each of the ``N_ITERATIONS``
        rounds extends the simulations of the current training points,
        recomputes the common species, trains the ensemble and evaluates it.
        Every round except the last one then selects new points, appends them
        to ``self.train_indices`` and writes ``state_iter_<n>.json`` (containing
        the already extended training set) into ``self.results_dir``.
        """
        rule = "=" * 60
        print(rule)
        print("Starting Active Learning Process")
        print(rule)

        # Step 3: initial training set
        self.select_initial_training_set()

        for iter_idx in range(N_ITERATIONS):
            self.iteration = iter_idx
            print(f"\n{rule}")
            print(f"Iteration {iter_idx + 1}/{N_ITERATIONS}")
            print(f"{rule}")

            # Step 4: extended simulations (if needed)
            self.run_extended_simulation(self.train_indices)

            # Common species across the data sets
            self.find_common_species()

            # Step 5: ensemble training, followed by evaluation
            self.train_ensemble_models()
            test_loss = self.evaluate_ensemble()

            # The last round neither selects new points nor writes a state file.
            if iter_idx >= N_ITERATIONS - 1:
                continue

            # Step 6: pick the most uncertain pool points
            new_indices = self.select_uncertain_points()
            if not new_indices:
                print("Warning: No new points were selected, continuing to next iteration")
            else:
                self.train_indices.extend(new_indices)
                print(f"Training set size: {len(self.train_indices)}")

            # Snapshot of the current state (used when resuming a run)
            state = {
                'iteration': iter_idx,
                'train_indices': self.train_indices,
                'test_loss': test_loss,
                'timestamp': datetime.now().isoformat()
            }
            state_file = os.path.join(self.results_dir, f'state_iter_{iter_idx}.json')
            with open(state_file, 'w') as out:
                json.dump(state, out, indent=2)

        print("\n" + rule)
        print("Active learning complete!")
        print(f"Final training set size: {len(self.train_indices)}")
        print(f"Results saved in: {self.results_dir}")
        print(rule)

def main():
    """Seed the random number generators and run the active-learning framework.

    Any exception raised while building or running the framework is reported
    on stdout together with its traceback instead of being propagated.
    """
    import traceback

    try:
        # Seed first so the whole run is reproducible
        set_seed(RANDOM_SEED)

        # Build the framework and start the active-learning loop
        ActiveLearningFramework().run()
    except Exception as exc:
        print(f"\nError: {exc}")
        traceback.print_exc()

# Start the active-learning run only when this code is executed as a script
# (or as a notebook cell, where __name__ is also "__main__"), not on import.
if __name__ == "__main__":
    main()

### Resuming a previous active-learning (AL) run

In [None]:
# ===== Resume Function (Jupyter friendly, no main required) =====
import os, re, glob, json
from datetime import datetime

def _safe_int(x, default=-1):
    """Convert *x* with ``int()``; return *default* if the conversion raises."""
    try:
        parsed = int(x)
    except Exception:
        parsed = default
    return parsed

def _load_train_indices_from_results(self, prev_results_dir: str):
    """
    Recover the training-set indices of an earlier run from its results directory.

    Lookup order:
      1. The ``state_iter_<n>.json`` file with the largest ``<n>``; its
         ``train_indices`` and ``iteration`` entries are returned.
      2. ``initial_train_set.json``: the ``sim_dir`` of each dict entry is
         matched (by absolute path) against ``self.pool_info`` to rebuild the
         indices; the iteration number is then reported as -1.

    Returns:
        tuple: ``(train_indices, last_iteration)``.

    Raises:
        FileNotFoundError: *prev_results_dir* is not a directory, or it holds
            neither kind of file.
        RuntimeError: the chosen file yields no usable indices.
    """
    if not os.path.isdir(prev_results_dir):
        raise FileNotFoundError(f"Resume directory not found: {prev_results_dir}")

    # 1) Preferred source: the most recent state_iter_*.json
    candidates = glob.glob(os.path.join(prev_results_dir, 'state_iter_*.json'))
    if candidates:
        def iteration_of(path):
            # Files whose name does not carry a number rank lowest (-1).
            found = re.search(r'state_iter_(\d+)\.json', os.path.basename(path))
            return _safe_int(found.group(1) if found else -1, -1)

        latest_state = max(candidates, key=iteration_of)
        with open(latest_state, 'r') as fh:
            state = json.load(fh)
        train_indices = state.get('train_indices', [])
        last_iter = state.get('iteration', -1)
        if not train_indices:
            raise RuntimeError(f"train_indices not found in {latest_state}")
        print(f"Loaded training set from {os.path.basename(latest_state)}, size {len(train_indices)}, last iteration number {last_iter}")
        return train_indices, last_iter

    # 2) Fallback: only the initial set exists, so indices are looked up by sim_dir
    init_file = os.path.join(prev_results_dir, 'initial_train_set.json')
    if not os.path.exists(init_file):
        raise FileNotFoundError(
            f"Neither state_iter_*.json nor initial_train_set.json found in {prev_results_dir}, cannot resume"
        )

    with open(init_file, 'r') as fh:
        init_list = json.load(fh)

    # Only dict entries carrying a 'sim_dir' field take part in the lookup;
    # other fields could be added here if the stored structure differs.
    init_sim_dirs = {
        os.path.abspath(entry['sim_dir'])
        for entry in init_list
        if isinstance(entry, dict) and 'sim_dir' in entry
    }
    if not init_sim_dirs:
        raise RuntimeError(f"No usable sim_dir field in {init_file}, cannot reverse lookup indices")

    # Match against the pool currently loaded on this instance
    idxs = [
        pos for pos, info in enumerate(self.pool_info)
        if os.path.abspath(info.get('sim_dir', '')) in init_sim_dirs
    ]
    if not idxs:
        raise RuntimeError("Cannot reverse lookup any indices from initial_train_set.json, please check if pool_info and results directory correspond")
    print(f"state_iter_*.json not found, falling back to initial set, size {len(idxs)}")
    return idxs, -1

def resume(self, prev_results_dir: str, extra_iterations: int = 50):
    """
    Continue active learning from the results of an earlier run.

    The training-set indices are restored from *prev_results_dir*, then
    *extra_iterations* further rounds are executed. Each round extends the
    simulations, finds the common species, trains and evaluates the ensemble,
    selects new points and writes ``state_iter_<n>.json``. Numbering continues
    after the last iteration found in *prev_results_dir*. All output
    (including ``resume_from.json``, which records where the run was resumed
    from) goes to this instance's ``self.results_dir``.
    """
    rule = "=" * 60
    print(rule)
    print(f"Starting resume: Continuing from {prev_results_dir} for {extra_iterations} rounds")
    print(rule)

    # Restore the training set of the previous run (stored as an independent copy)
    loaded_indices, last_iter_prev = self._load_train_indices_from_results(prev_results_dir)
    self.train_indices = list(loaded_indices)
    print(f"Training set size at resume start: {len(self.train_indices)}")

    # Record where this run was resumed from
    resume_meta = {
        'resume_from': os.path.abspath(prev_results_dir),
        'loaded_last_iter': last_iter_prev,
        'start_train_size': len(self.train_indices),
        'timestamp': datetime.now().isoformat()
    }
    meta_path = os.path.join(self.results_dir, 'resume_from.json')
    with open(meta_path, 'w') as fh:
        json.dump(resume_meta, fh, indent=2)

    # Global numbering continues after the last stored iteration, or starts
    # at 0 when only the initial set was found.
    if last_iter_prev is not None and last_iter_prev >= 0:
        start_offset = last_iter_prev + 1
    else:
        start_offset = 0

    for round_no in range(1, extra_iterations + 1):
        self.iteration = start_offset + round_no - 1
        print(f"\n{rule}")
        print(f"Resume iteration {round_no}/{extra_iterations} (Global number {self.iteration})")
        print(f"{rule}")

        # Same sequence of steps as the regular active-learning loop
        self.run_extended_simulation(self.train_indices)
        self.find_common_species()
        self.train_ensemble_models()
        test_loss = self.evaluate_ensemble()

        # Grow the training set with the newly selected points
        new_indices = self.select_uncertain_points()
        if not new_indices:
            print("Note: No new points selected this round (perhaps no available completed points in the pool)")
        else:
            self.train_indices.extend(new_indices)
            print(f"Training set size after resume: {len(self.train_indices)}")

        # Persist the state reached after this round
        state = {
            'iteration': self.iteration,
            'train_indices': self.train_indices,
            'test_loss': test_loss,
            'timestamp': datetime.now().isoformat()
        }
        state_file = os.path.join(self.results_dir, f'state_iter_{self.iteration}.json')
        with open(state_file, 'w') as fh:
            json.dump(state, fh, indent=2)

    print("\n" + rule)
    print("Resume complete!")
    print(f"Final training set size: {len(self.train_indices)}")
    print(f"Resume results saved in: {self.results_dir}")
    print(rule)

# ---- Attach the two functions above to the existing ActiveLearningFramework class ----
# Assigning plain functions as class attributes turns them into regular methods,
# so every instance can call self._load_train_indices_from_results(...) and
# self.resume(...). ActiveLearningFramework must already be defined when this
# cell runs (e.g. by an earlier notebook cell).
ActiveLearningFramework._load_train_indices_from_results = _load_train_indices_from_results
ActiveLearningFramework.resume = resume

# Example usage (left commented out so that running this cell does not start a job).
# framework = ActiveLearningFramework()
# framework.resume(prev_results_dir="C:/Users/Administrator/Desktop/My research/2_ML_predict_product/AL_GasKit_V2.0/AL_results_20250819_162332",
#                  extra_iterations=50)

# 2. Random_search_part 

## Main_code

In [None]:
#!/usr/bin/env python
"""
Random Sampling Comparison Test Script
Used for performance comparison against the active learning framework
"""

import os
import sys
import json
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import subprocess
import shutil
import random
import copy
from datetime import datetime
from tqdm import tqdm
import argparse

# ============ Configuration Parameters ============
BASE_DIR = 'C:/Users/Administrator/Desktop/My research/2_ML_predict_product/AL_GasKit_V2.0'
POOL_DATA_DIR = os.path.join(BASE_DIR, 'Pool_data')
WORK_PATH = BASE_DIR

# Model Parameters (intended to mirror the active-learning configuration)
# NOTE(review): the active-learning configuration at the top of this file sets
# ENSEMBLE_SIZE = 4, whereas 10 is used here - confirm which value the
# comparison is meant to use.
TRAINING_FRAMES = 500  # N2 - Number of frames for training
INPUT_LEN = 50
PRED_LEN = 150
STRIDE = 10
BATCH_SIZE = 32
EPOCHS = 200
LEARNING_RATE = 1e-3
ENSEMBLE_SIZE = 10  # Number of ensemble models

# Reference test MSE that RandomSamplingTest.run compares its mean MSE against.
# Presumably copied from a finished active-learning run - TODO confirm its origin.
AL_test_MSE = 36.638477

# Test set: one directory per case below, all located in BASE_DIR/Total_data.
# Forward slashes are joined by hand so the strings do not depend on the OS
# path separator.
_TEST_CASE_NAMES = (
    '1100K_2atm_1per15',
    '1100K_3atm_1per17.5',
    '1100K_4atm_1per20',
    '1100K_5atm_1per22.5',
    '1100K_6atm_1per25',
    '1200K_2atm_1per17.5',
    '1200K_3atm_1per20',
    '1200K_4atm_1per22.5',
    '1200K_5atm_1per25',
    '1200K_6atm_1per15',
    '1300K_2atm_1per20',
    '1300K_3atm_1per22.5',
    '1300K_4atm_1per25',
    '1300K_5atm_1per15',
    '1300K_6atm_1per17.5',
    '1400K_2atm_1per22.5',
    '1400K_3atm_1per25',
    '1400K_4atm_1per15',
    '1400K_5atm_1per17.5',
    '1400K_6atm_1per20',
    '1500K_2atm_1per25',
    '1500K_3atm_1per15',
    '1500K_4atm_1per17.5',
    '1500K_5atm_1per20',
    '1500K_6atm_1per22.5',
)
TEST_PATHS = [f'{BASE_DIR}/Total_data/{case}' for case in _TEST_CASE_NAMES]

# ============ TCN Model Definition (consistent with active learning) ============
class TemporalBlock(nn.Module):
    """Residual block made of two weight-normalised dilated 1-D convolutions.

    Each convolution is followed by ReLU and dropout. The block input is added
    back to the convolution output (through a 1x1 convolution when the channel
    counts differ) and the sum is passed through a final ReLU.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        # Both convolutions share the same geometry
        conv_kwargs = dict(stride=stride, padding=padding, dilation=dilation)

        self.conv1 = nn.utils.weight_norm(nn.Conv1d(in_ch, out_ch, kernel_size, **conv_kwargs))
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.utils.weight_norm(nn.Conv1d(out_ch, out_ch, kernel_size, **conv_kwargs))
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(
            self.conv1, self.relu1, self.dropout1,
            self.conv2, self.relu2, self.dropout2,
        )

        # A 1x1 convolution aligns the residual only when channel counts differ
        if in_ch != out_ch:
            self.downsample = nn.Conv1d(in_ch, out_ch, 1)
        else:
            self.downsample = None
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.net(x)
        res = self.downsample(x) if self.downsample is not None else x
        # Padding lengthens the convolution output; keep only the leading
        # positions so it lines up with the residual along the last axis.
        target_len = res.size(2)
        if out.size(2) != target_len:
            out = out[:, :, :target_len]
        return self.relu(out + res)

class TemporalConvNet(nn.Module):
    """Stack of TemporalBlock layers whose dilation doubles at every level.

    ``num_channels`` lists the output width of each level; level ``i`` uses
    dilation ``2 ** i`` and padding ``(kernel_size - 1) * dilation``.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=3, dropout=0.2):
        super().__init__()
        blocks = []
        prev_width = num_inputs
        for level, width in enumerate(num_channels):
            dilation = 2 ** level
            blocks.append(
                TemporalBlock(
                    prev_width, width, kernel_size,
                    stride=1,
                    dilation=dilation,
                    padding=(kernel_size - 1) * dilation,
                    dropout=dropout,
                )
            )
            # The next level consumes what this one produces
            prev_width = width
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        return self.network(x)

class TCNForecast(nn.Module):
    """TCN encoder followed by a linear head that emits a multi-step forecast.

    The input is transposed on axes 1 and 2 before the TCN, the features at
    the last position of the TCN output are fed to the linear layer, and the
    result is reshaped to ``(-1, pred_len, input_dim)``.
    """

    def __init__(self, input_dim, num_channels, kernel_size, pred_len):
        super().__init__()
        self.tcn = TemporalConvNet(input_dim, num_channels, kernel_size)
        self.linear = nn.Linear(num_channels[-1], input_dim * pred_len)
        self.input_dim = input_dim
        self.pred_len = pred_len

    def forward(self, x):
        # assumes x is (batch, time, features); the TCN wants channels on axis 1
        features = self.tcn(x.transpose(1, 2))
        last_step = features[:, :, -1]
        flat_forecast = self.linear(last_step)
        return flat_forecast.view(-1, self.pred_len, self.input_dim)

class TimeSeriesDataset(Dataset):
    """Pairs of (input window, target window) held as float32 tensors.

    ``X`` and ``Y`` are NumPy arrays indexed along their first axis; item
    ``i`` is the tuple ``(X[i], Y[i])``.
    """

    def __init__(self, X, Y):
        self.X = torch.from_numpy(X).to(torch.float32)
        self.Y = torch.from_numpy(Y).to(torch.float32)

    def __len__(self):
        # The number of samples is the length of the first axis of X
        return self.X.size(0)

    def __getitem__(self, idx):
        sample = (self.X[idx], self.Y[idx])
        return sample

class EarlyStopping:
    """Signal a stop once the validation loss has stalled for ``patience`` calls.

    A call counts as an improvement only when the loss drops below the best
    value seen so far by more than ``min_delta``. Once raised, the stop flag
    is never cleared again.
    """

    def __init__(self, patience=15, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss):
        """Record *val_loss* and return the current stop flag."""
        improved = (self.best_loss - val_loss) > self.min_delta
        if improved:
            # New best value: remember it and restart the patience counter
            self.best_loss = val_loss
            self.counter = 0
            return self.early_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop

# ============ Random Sampling Test Class ============
class RandomSamplingTest:
    def __init__(self, train_size=50, num_runs=5, seed=None):
        """
        Set up a random-sampling experiment.

        Creates a timestamped results directory under BASE_DIR and loads the
        pool information right away.

        Args:
            train_size: Training set size
            num_runs: Number of runs (average over multiple random samples)
            seed: Random seed
        """
        self.train_size, self.num_runs, self.seed = train_size, num_runs, seed
        self.pool_info = None
        self.test_dirs = TEST_PATHS
        self.ensemble_models = []

        # One results directory per experiment, named after the start time
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.results_dir = os.path.join(BASE_DIR, f'Random_Sampling_Results_{stamp}')
        os.makedirs(self.results_dir, exist_ok=True)

        # Fills self.pool_info
        self.load_pool_info()
    
    def load_pool_info(self):
        """Read ``pool_info.json`` from BASE_DIR into ``self.pool_info``.

        Raises:
            FileNotFoundError: the file does not exist.
        """
        info_path = os.path.join(BASE_DIR, 'pool_info.json')
        if not os.path.exists(info_path):
            raise FileNotFoundError(f"Pool info file not found: {info_path}")

        with open(info_path, 'r') as fh:
            loaded = json.load(fh)
        self.pool_info = loaded

        print(f"Loaded {len(self.pool_info)} parameter point infos")
    
    def check_simulation_frames(self, sim_dir):
        """Tell whether the simulation in *sim_dir* has at least TRAINING_FRAMES frames.

        The frame count is taken from the second axis of
        ``species_time_matrix_initial.npy``; a missing file counts as "not
        enough frames".
        """
        mat_path = os.path.join(sim_dir, 'species_time_matrix_initial.npy')
        if not os.path.exists(mat_path):
            return False
        n_frames = np.load(mat_path).shape[1]
        return n_frames >= TRAINING_FRAMES
    
    def run_extended_simulation(self, sim_dir, info):
        """Re-run the LAMMPS simulation in *sim_dir* so that it covers TRAINING_FRAMES frames.

        The ``run`` command of ``in.MoO3S`` is rewritten to
        ``5000 * TRAINING_FRAMES`` steps, an existing ``log.lammps`` is moved
        aside, LAMMPS is started, and - if it succeeded and the script exists
        in WORK_PATH - ``lammps_output_process.py`` is run on the directory.

        Args:
            sim_dir: Directory of the simulation to extend.
            info: Pool entry of the simulation; currently not used here.

        Returns:
            bool: True when LAMMPS exited with status 0, False when it exited
            with a non-zero status or when launching/post-processing raised
            (including the 7200 s timeout).

        Raises:
            OSError: reading or rewriting ``in.MoO3S`` (or moving the old log)
                failed; these errors are not caught here.
        """
        import shlex  # only needed for quoting the input path below

        print(f"  Extending simulation: {os.path.basename(sim_dir)}")

        # Modify in.MoO3S file
        in_path = os.path.join(sim_dir, "in.MoO3S")

        # Calculate new total steps
        steps_per_frame = 5000
        new_run_steps = steps_per_frame * TRAINING_FRAMES

        # Read and modify file. Only lines whose first token is exactly "run"
        # are replaced; a prefix test would also overwrite other commands that
        # merely start with "run" (LAMMPS has a run_style command).
        new_lines = []
        with open(in_path, 'r') as f:
            for line in f:
                tokens = line.split()
                if tokens and tokens[0] == "run":
                    line = f"run {new_run_steps}\n"
                new_lines.append(line)

        # Write back to file
        with open(in_path, 'w') as f:
            f.writelines(new_lines)

        # Backup old log file
        log_file = os.path.join(sim_dir, "log.lammps")
        if os.path.exists(log_file):
            backup_log = os.path.join(sim_dir, "log.lammps.backup_random")
            shutil.move(log_file, backup_log)

        # Run LAMMPS simulation. The path is quoted because the command goes
        # through a shell and the directory may contain spaces.
        cmd = (
            "module load lammps/20230328-intel-2021.4.0-omp && "
            f"mpirun -np 48 lmp -in {shlex.quote(in_path)}"
        )

        try:
            result = subprocess.run(
                cmd,
                cwd=sim_dir,
                shell=True,
                capture_output=True,
                text=True,
                timeout=7200
            )

            # A non-zero exit status means the trajectory was not (fully)
            # produced; report it instead of post-processing stale output.
            if result.returncode != 0:
                stderr_tail = (result.stderr or "").strip()[-500:]
                print(
                    "  Extended simulation failed: "
                    f"LAMMPS exited with code {result.returncode}: {stderr_tail}"
                )
                return False

            # Run post-processing
            proc_script = os.path.join(WORK_PATH, "lammps_output_process.py")
            if os.path.exists(proc_script):
                subprocess.run(["python", proc_script, sim_dir], capture_output=True, text=True)

            return True
        except Exception as e:
            print(f"  Extended simulation failed: {e}")
            return False
    
    def random_sample_train_set(self, run_idx):
        """Randomly select training set"""
        print(f"\nRun {run_idx+1}/{self.num_runs}: Randomly selecting training set...")
        
        # Select only completed points
        completed_indices = [
            i for i, info in enumerate(self.pool_info)
            if info['status'] == 'completed'
        ]
        
        if len(completed_indices) < self.train_size:
            raise ValueError(f"Number of completed simulation points ({len(completed_indices)}) is less than training set size ({self.train_size})")
        
        # Set random seed
        if self.seed is not None:
            random.seed(self.seed + run_idx)
        
        # Randomly select
        train_indices = random.sample(completed_indices, self.train_size)
        print(f"Selected training point indices: {train_indices}")
        
        return train_indices
    
    def ensure_training_frames(self, train_indices):
        """Ensure all training points reach TRAINING_FRAMES"""
        print("\nChecking and extending simulations to required frame count...")
        
        for idx in train_indices:
            info = self.pool_info[idx]
            sim_dir = info['sim_dir']
            
            if not self.check_simulation_frames(sim_dir):
                print(f"Extension needed: {os.path.basename(sim_dir)}")
                success = self.run_extended_simulation(sim_dir, info)
                if not success:
                    print(f"Warning: Extension failed - {sim_dir}")
    
    def find_common_species(self, train_indices):
        """Run ``found_species_subsets.py`` over the species lists of training and test sets.

        The script (looked up in WORK_PATH) receives the path of
        ``species_list_initial.txt`` for every training directory followed by
        every test directory. Nothing happens when the script is missing; its
        exit status and output are not inspected.
        """
        print("\nFinding common species...")

        list_name = 'species_list_initial.txt'

        # Training set paths first, test set paths after them
        train_lists = [
            os.path.join(self.pool_info[idx]['sim_dir'], list_name)
            for idx in train_indices
        ]
        test_lists = [os.path.join(test_dir, list_name) for test_dir in self.test_dirs]
        all_paths = train_lists + test_lists

        # Delegate to the external found_species_subsets.py script
        subset_script = os.path.join(WORK_PATH, 'found_species_subsets.py')
        if os.path.exists(subset_script):
            subprocess.run(['python', subset_script, *all_paths], capture_output=True, text=True)
            print("True subset finding complete")
    
    def load_data(self, dir_list, max_frames=TRAINING_FRAMES):
        """Load the species/time matrix and the species list of each directory.

        ``species_time_matrix.npy`` / ``species_list.txt`` are preferred; when
        the matrix file is absent, the ``*_initial`` pair is used instead.
        Directories lacking either file of the chosen pair are skipped
        silently. Matrices are cut to the first *max_frames* columns.

        Returns:
            tuple: ``(mats, species_lists)`` - two parallel lists.
        """
        mats, species_lists = [], []

        for d in dir_list:
            mat_path = os.path.join(d, 'species_time_matrix.npy')
            if os.path.exists(mat_path):
                list_path = os.path.join(d, 'species_list.txt')
            else:
                # Fall back to the files written by the initial simulation
                mat_path = os.path.join(d, 'species_time_matrix_initial.npy')
                list_path = os.path.join(d, 'species_list_initial.txt')

            if not (os.path.exists(mat_path) and os.path.exists(list_path)):
                continue

            mats.append(np.load(mat_path)[:, :max_frames])
            with open(list_path, 'r') as fh:
                # One species name per non-blank line
                species_lists.append([ln.strip() for ln in fh if ln.strip()])

        return mats, species_lists
    
    def create_windows(self, matrices):
        """Cut each matrix into (input, target) windows along its second axis.

        A window pair consists of INPUT_LEN columns followed by the next
        PRED_LEN columns; window starts advance by STRIDE. Every window is
        transposed before being stored, so the column (time) axis comes first.

        Returns:
            tuple: ``(X, Y)`` stacked arrays, or two empty arrays when no
            matrix is long enough for a single window.
        """
        X_list, Y_list = [], []
        span = INPUT_LEN + PRED_LEN

        for mat in matrices:
            _, T = mat.shape
            for start in range(0, T - span + 1, STRIDE):
                split = start + INPUT_LEN
                X_list.append(mat[:, start:split].T)
                Y_list.append(mat[:, split:split + PRED_LEN].T)

        if not X_list:
            return np.array([]), np.array([])
        return np.stack(X_list), np.stack(Y_list)
    
    def train_ensemble(self, X_train, Y_train, X_test, Y_test, run_idx):
        """Train ENSEMBLE_SIZE TCN models and store them in ``self.ensemble_models``.

        Each model is trained with Adam and MSE loss for at most EPOCHS epochs.
        After every epoch the loss on the (X_test, Y_test) windows is computed;
        it drives early stopping (patience 10) and the choice of the weights
        that are finally restored. The restored weights are also saved as
        ``model_run<run_idx>_model<k>.pth`` in ``self.results_dir``.

        NOTE(review): the windows used for validation here are the same ones
        passed in as test data, so model selection sees the test set - confirm
        this is intended for the comparison.
        """
        print("\nTraining ensemble TCN models...")

        train_loader = DataLoader(
            TimeSeriesDataset(X_train, Y_train), batch_size=BATCH_SIZE, shuffle=True
        )
        test_loader = DataLoader(TimeSeriesDataset(X_test, Y_test), batch_size=BATCH_SIZE)

        self.ensemble_models = []
        input_dim = X_train.shape[2]
        hidden_channels = [64, 64, 64]
        kernel_size = 3

        for model_idx in range(ENSEMBLE_SIZE):
            print(f"Training model {model_idx + 1}/{ENSEMBLE_SIZE}")

            model = TCNForecast(input_dim, hidden_channels, kernel_size, PRED_LEN)
            optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
            criterion = nn.MSELoss()
            early_stopping = EarlyStopping(patience=10)

            best_val_loss = float('inf')
            best_state_dict = None

            for epoch in range(1, EPOCHS + 1):
                # --- training pass ---
                model.train()
                train_sum = 0.0
                for xb, yb in train_loader:
                    optimizer.zero_grad()
                    batch_loss = criterion(model(xb), yb)
                    batch_loss.backward()
                    optimizer.step()
                    train_sum += batch_loss.item()
                avg_train = train_sum / len(train_loader)

                # --- validation pass (no gradients) ---
                model.eval()
                with torch.no_grad():
                    val_sum = sum(
                        criterion(model(xb), yb).item() for xb, yb in test_loader
                    )
                avg_val = val_sum / len(test_loader)

                if epoch % 20 == 0:
                    print(f"  Epoch {epoch}: Train Loss={avg_train:.4f}, Val Loss={avg_val:.4f}")

                # Keep a private copy of the best weights seen so far
                if avg_val < best_val_loss:
                    best_val_loss = avg_val
                    best_state_dict = copy.deepcopy(model.state_dict())

                if early_stopping(avg_val):
                    print(f"  Early stopping at epoch {epoch}")
                    break

            # Roll back to the best epoch before saving
            if best_state_dict is not None:
                model.load_state_dict(best_state_dict)

            model_path = os.path.join(
                self.results_dir, f'model_run{run_idx}_model{model_idx}.pth'
            )
            torch.save(model.state_dict(), model_path)

            self.ensemble_models.append(model)
    
    def evaluate_ensemble(self, X_test, Y_test):
        """Return the test MSE of the averaged ensemble prediction.

        For each batch the predictions of all models in
        ``self.ensemble_models`` are averaged before the MSE is taken; the
        result is the mean of these per-batch values.
        """
        print("\nEvaluating ensemble model...")

        loader = DataLoader(TimeSeriesDataset(X_test, Y_test), batch_size=BATCH_SIZE)
        mse = nn.MSELoss()

        batch_losses = []
        with torch.no_grad():
            for inputs, targets in loader:
                # Module.eval() returns the module, so each member is put in
                # eval mode immediately before predicting.
                member_preds = [member.eval()(inputs) for member in self.ensemble_models]
                mean_pred = torch.stack(member_preds).mean(dim=0)
                batch_losses.append(mse(mean_pred, targets).item())

        avg_loss = sum(batch_losses) / len(loader)
        print(f"Ensemble model test set MSE: {avg_loss:.6f}")

        return avg_loss
    
    def run_single_test(self, run_idx):
        """Execute one complete random-sampling run.

        Steps: sample the training set, extend short simulations, find the
        common species, load and window the data, train the ensemble and
        evaluate it on the test windows.

        Returns:
            tuple: ``(test_mse, train_indices)``; ``test_mse`` is ``None`` when
            either the training or the test windows are empty.
        """
        # Sampling and preparation of the simulations
        train_indices = self.random_sample_train_set(run_idx)
        self.ensure_training_frames(train_indices)
        self.find_common_species(train_indices)

        # Raw matrices
        train_dirs = [self.pool_info[i]['sim_dir'] for i in train_indices]
        train_mats, _ = self.load_data(train_dirs)
        test_mats, _ = self.load_data(self.test_dirs)

        # Sliding windows
        X_train, Y_train = self.create_windows(train_mats)
        X_test, Y_test = self.create_windows(test_mats)

        if not (len(X_train) and len(X_test)):
            print("Warning: Data is empty")
            return None, train_indices

        print(f"Training data shape: X={X_train.shape}, Y={Y_train.shape}")
        print(f"Test data shape: X={X_test.shape}, Y={Y_test.shape}")

        # Training followed by evaluation on the same test windows
        self.train_ensemble(X_train, Y_train, X_test, Y_test, run_idx)
        test_mse = self.evaluate_ensemble(X_test, Y_test)

        return test_mse, train_indices
    
    def run(self):
        """Execute all runs, store their results and print a summary.

        Each successful run is written to ``result_run_<k>.json``. If at least
        one run produced an MSE, mean/std/min/max are saved to
        ``summary.json`` and the mean is compared with ``AL_test_MSE``.
        """
        rule = "=" * 60
        sub_rule = "=" * 40

        print(rule)
        print(f"Random Sampling Comparison Test")
        print(f"Training set size: {self.train_size}")
        print(f"Number of runs: {self.num_runs}")
        print(rule)

        all_results = []

        for run_idx in range(self.num_runs):
            print(f"\n{sub_rule}")
            print(f"Run {run_idx+1}/{self.num_runs}")
            print(f"{sub_rule}")

            test_mse, train_indices = self.run_single_test(run_idx)

            # Runs without usable data leave no trace in the results
            if test_mse is None:
                continue

            result = {
                'run': run_idx + 1,
                'test_mse': test_mse,
                'train_indices': train_indices,
                'train_size': self.train_size
            }
            all_results.append(result)

            # Per-run result file
            result_file = os.path.join(self.results_dir, f'result_run_{run_idx+1}.json')
            with open(result_file, 'w') as out:
                json.dump(result, out, indent=2)

        if not all_results:
            return

        # Statistics over all successful runs
        mse_values = [entry['test_mse'] for entry in all_results]
        mean_mse = np.mean(mse_values)
        std_mse = np.std(mse_values)
        min_mse = np.min(mse_values)
        max_mse = np.max(mse_values)

        summary = {
            'train_size': self.train_size,
            'num_runs': self.num_runs,
            'mean_mse': mean_mse,
            'std_mse': std_mse,
            'min_mse': min_mse,
            'max_mse': max_mse,
            'all_mse_values': mse_values,
            'all_results': all_results
        }

        summary_file = os.path.join(self.results_dir, 'summary.json')
        with open(summary_file, 'w') as out:
            json.dump(summary, out, indent=2)

        print("\n" + rule)
        print("Test Complete - Statistical Results")
        print(rule)
        print(f"Mean MSE: {mean_mse:.6f} ± {std_mse:.6f}")
        print(f"Min MSE: {min_mse:.6f}")
        print(f"Max MSE: {max_mse:.6f}")
        print(f"Results saved in: {self.results_dir}")

        # Comparison with the stored active-learning reference value
        print("\n" + rule)
        print("Comparison with Active Learning:")
        print(f"Random Sampling MSE: {mean_mse:.6f} ± {std_mse:.6f} ({self.train_size} points)")

        if mean_mse < AL_test_MSE:
            improvement = (AL_test_MSE - mean_mse) / AL_test_MSE * 100
            print(f"Random Sampling is better than Active Learning: {improvement:.2f}% improvement")
        else:
            degradation = (mean_mse - AL_test_MSE) / AL_test_MSE * 100
            print(f"Active Learning is better than Random Sampling: {degradation:.2f}% improvement")

# Settings of this comparison experiment: number of randomly drawn training
# points, number of independent runs, and the base seed (run k is seeded with
# seed + k inside RandomSamplingTest.random_sample_train_set).
train_size = 100
num_runs = 1
seed = 42
# Create test instance (this already creates the results directory and loads
# pool_info.json, see RandomSamplingTest.__init__)
tester = RandomSamplingTest(
    train_size=train_size,
    num_runs=num_runs,
    seed=seed
)

# Run test: executes all runs and prints/saves the summary
tester.run()

## Visualization_part

In [None]:
#!/usr/bin/env python
"""
Visualization script for parameter space distribution
Visualizes pool points distribution and active learning/random sampling selection
"""

import os
import json
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import argparse
from datetime import datetime

# Base path: pool_info.json, the result directories and the visualization
# output directory used in this section are all joined onto this folder.
# NOTE(review): hard-coded absolute Windows path — adjust when running elsewhere.
BASE_DIR = 'C:/Users/Administrator/Desktop/My research/2_ML_predict_product/AL_GasKit_V2.0'

# Set matplotlib to use a better font for scientific publications
# (sans-serif family, preferring Arial and then DejaVu Sans, 11 pt base size)
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']
plt.rcParams['font.size'] = 11

def load_pool_info():
    """Return the parsed content of ``pool_info.json`` located in ``BASE_DIR``."""
    path = os.path.join(BASE_DIR, 'pool_info.json')
    with open(path, 'r') as handle:
        return json.load(handle)

def load_active_learning_results(results_dir, max_iterations=None):
    """Load active learning results including MSE history.

    Reads ``state_iter_<i>.json`` (training indices) and
    ``evaluation_iter_<i>.json`` (test MSE) files from ``results_dir`` in
    ascending iteration order.

    Args:
        results_dir: Directory holding the per-iteration JSON files.
        max_iterations: Optional cap. When given, only iterations
            ``0 .. max_iterations - 1`` are probed. When ``None`` (default),
            the iteration numbers are discovered from the file names, so runs
            of any length are loaded completely instead of being cut off at a
            fixed count.

    Returns:
        A tuple ``(final_indices, train_indices_history, mse_history)``:
        ``train_indices_history`` is the list of ``'train_indices'`` values,
        one per state file found; ``mse_history`` is a list of dicts with the
        keys ``'iteration'``, ``'test_mse'`` and ``'train_size'``;
        ``final_indices`` is the last entry of ``train_indices_history`` or,
        when no state file exists, the ``'index'`` values from
        ``initial_train_set.json`` (empty list if that file is missing too).
    """
    state_prefix = 'state_iter_'
    eval_prefix = 'evaluation_iter_'
    suffix = '.json'

    if max_iterations is not None:
        iteration_ids = range(max_iterations)
    else:
        # Discover which iteration numbers actually exist on disk.
        found = set()
        if os.path.isdir(results_dir):
            for name in os.listdir(results_dir):
                if not name.endswith(suffix):
                    continue
                for prefix in (state_prefix, eval_prefix):
                    if name.startswith(prefix):
                        token = name[len(prefix):-len(suffix)]
                        if token.isascii() and token.isdigit():
                            found.add(int(token))
        # Numeric sort so that e.g. iteration 10 follows 9, not 1.
        iteration_ids = sorted(found)

    train_indices_history = []
    mse_history = []

    # Load iteration states
    for i in iteration_ids:
        state_file = os.path.join(results_dir, f'{state_prefix}{i}{suffix}')
        eval_file = os.path.join(results_dir, f'{eval_prefix}{i}{suffix}')

        if os.path.exists(state_file):
            with open(state_file, 'r') as f:
                state = json.load(f)
            train_indices_history.append(state['train_indices'])

        if os.path.exists(eval_file):
            with open(eval_file, 'r') as f:
                eval_data = json.load(f)
            mse_history.append({
                'iteration': eval_data['iteration'],
                'test_mse': eval_data['test_mse'],
                'train_size': eval_data['train_size']
            })

    if train_indices_history:
        # The most recent state holds the final training set.
        return train_indices_history[-1], train_indices_history, mse_history

    # If no iteration states, try loading initial training set
    final_indices = []
    initial_file = os.path.join(results_dir, 'initial_train_set.json')
    if os.path.exists(initial_file):
        with open(initial_file, 'r') as f:
            initial_set = json.load(f)
        final_indices = [info['index'] for info in initial_set]

    return final_indices, train_indices_history, mse_history

def load_random_sampling_results(results_dir):
    """Read ``summary.json`` of a random-sampling run.

    Returns:
        ``(random_indices, summary)`` where ``summary`` is the parsed JSON
        and ``random_indices`` is the ``'train_indices'`` list of the first
        entry in ``summary['all_results']`` (empty list when there are no
        results). When the summary file does not exist, ``([], None)``.
    """
    path = os.path.join(results_dir, 'summary.json')

    if os.path.exists(path):
        with open(path, 'r') as handle:
            summary = json.load(handle)

        # Only the first run's selection is reported.
        runs = summary.get('all_results')
        first_run_indices = runs[0]['train_indices'] if runs else []
        return first_run_indices, summary

    return [], None

def visualize_3d_distribution(pool_info, selected_indices=None, title="Parameter Space Distribution"):
    """Scatter the pool in (T, P, ratio) space and optionally highlight a subset.

    Args:
        pool_info: Sequence of dicts read for their 'T', 'P' and 'ratio' entries.
        selected_indices: Optional positions into ``pool_info`` to highlight;
            positions ``>= len(pool_info)`` are ignored.
        title: Title placed on the axes.

    Returns:
        The created matplotlib figure.
    """
    coordinate_keys = ('T', 'P', 'ratio')

    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')

    # Background cloud: every pool point in gray
    pool_coords = [[entry[key] for entry in pool_info] for key in coordinate_keys]
    ax.scatter(*pool_coords, c='gray', alpha=0.3, s=20, label='Pool points')

    # Foreground: the selected points as red stars
    if selected_indices:
        pool_size = len(pool_info)
        chosen = [pool_info[i] for i in selected_indices if i < pool_size]
        chosen_coords = [[entry[key] for entry in chosen] for key in coordinate_keys]
        ax.scatter(*chosen_coords,
                   c='red', alpha=1.0, s=100, marker='*',
                   label=f'Selected points (n={len(selected_indices)})')

    ax.set_xlabel('Temperature (K)', fontsize=12)
    ax.set_ylabel('Pressure (atm)', fontsize=12)
    ax.set_zlabel('Ratio (Mo3O9/S2)', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend()

    # Fixed camera position
    ax.view_init(elev=20, azim=45)

    plt.tight_layout()
    return fig

def visualize_2d_projections(pool_info, selected_indices=None, title_prefix=""):
    """Draw the T-P, T-Ratio and P-Ratio projections of the pool side by side.

    Args:
        pool_info: Sequence of dicts read for their 'T', 'P' and 'ratio' entries.
        selected_indices: Optional positions into ``pool_info`` to highlight;
            positions ``>= len(pool_info)`` are ignored.
        title_prefix: Text placed in front of each panel title.

    Returns:
        The created matplotlib figure with three panels.
    """
    axis_text = {
        'T': 'Temperature (K)',
        'P': 'Pressure (atm)',
        'ratio': 'Ratio (Mo3O9/S2)',
    }
    title_token = {'T': 'T', 'P': 'P', 'ratio': 'Ratio'}
    panel_pairs = [('T', 'P'), ('T', 'ratio'), ('P', 'ratio')]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Pool coordinates, one array per parameter
    pool = {key: np.array([entry[key] for entry in pool_info]) for key in axis_text}

    # Selected coordinates are gathered once and reused by all panels
    chosen = None
    if selected_indices:
        pool_size = len(pool_info)
        kept = [pool_info[i] for i in selected_indices if i < pool_size]
        chosen = {key: [entry[key] for entry in kept] for key in axis_text}
        chosen_label = f'Selected (n={len(selected_indices)})'

    for ax, (x_key, y_key) in zip(axes, panel_pairs):
        ax.scatter(pool[x_key], pool[y_key], c='gray', alpha=0.3, s=20, label='Pool points')
        if chosen is not None:
            ax.scatter(chosen[x_key], chosen[y_key], c='red', alpha=1.0, s=100,
                       marker='*', label=chosen_label)
        ax.set_xlabel(axis_text[x_key])
        ax.set_ylabel(axis_text[y_key])
        ax.set_title(f'{title_prefix} {title_token[x_key]}-{title_token[y_key]} Projection')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

def visualize_distribution_statistics(pool_info, selected_indices=None):
    """Plot histograms (top row) and empirical CDFs (bottom row) per parameter.

    Args:
        pool_info: Sequence of dicts read for their 'T', 'P' and 'ratio' entries.
        selected_indices: Optional positions into ``pool_info`` whose values
            are overlaid in red; positions ``>= len(pool_info)`` are ignored.

    Returns:
        The created matplotlib figure (2 x 3 panels).
    """
    # (dict key, display name, histogram x-label, CDF x-label)
    columns = [
        ('T', 'Temperature', 'Temperature (K)', 'Temperature (K)'),
        ('P', 'Pressure', 'Pressure (atm)', 'Pressure (atm)'),
        ('ratio', 'Ratio', 'Ratio (Mo3O9/S2)', 'Ratio '),
    ]

    def values_of_selection(key):
        return [pool_info[i][key] for i in selected_indices if i < len(pool_info)]

    def empirical_cdf(values):
        ordered = np.sort(values)
        return ordered, np.arange(1, len(ordered) + 1) / len(ordered)

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    pool_values = {key: np.array([entry[key] for entry in pool_info])
                   for key, _, _, _ in columns}

    # Top row: histograms of pool vs. selection
    for col, (key, name, hist_xlabel, _) in enumerate(columns):
        panel = axes[0, col]
        panel.hist(pool_values[key], bins=20, alpha=0.5, color='gray', label='Pool')
        if selected_indices:
            panel.hist(values_of_selection(key), bins=20, alpha=0.7, color='red', label='Selected')
        panel.set_xlabel(hist_xlabel)
        panel.set_ylabel('Count')
        panel.set_title(f'{name} Distribution')
        panel.legend()

    # Bottom row: cumulative distribution functions
    for col, (key, name, _, cdf_xlabel) in enumerate(columns):
        panel = axes[1, col]
        xs, ys = empirical_cdf(pool_values[key])
        panel.plot(xs, ys, 'gray', alpha=0.7, linewidth=2, label='Pool')

        if selected_indices:
            picked = values_of_selection(key)
            # Nothing is drawn when every selected index was out of range
            if picked:
                xs_sel, ys_sel = empirical_cdf(picked)
                panel.plot(xs_sel, ys_sel, 'r-', linewidth=2, label='Selected')

        panel.set_xlabel(cdf_xlabel)
        panel.set_ylabel('Cumulative Probability')
        panel.set_title(f'{name} Cumulative Distribution')
        panel.legend()
        panel.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

def visualize_mse_evolution(al_mse_history=None, random_summary=None):
    """Plot how the test MSE develops during active learning.

    Args:
        al_mse_history: Optional list of dicts with 'iteration', 'test_mse'
            and 'train_size'. Drives both panels when given.
        random_summary: Optional dict with 'train_size', 'mean_mse' and
            'std_mse'; drawn as a horizontal reference band on the right panel.

    Returns:
        The created matplotlib figure (left: MSE and training size per
        iteration, right: MSE versus training size).
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    left_panel, right_panel = axes[0], axes[1]

    # Active Learning curves
    if al_mse_history:
        steps = [record['iteration'] for record in al_mse_history]
        errors = [record['test_mse'] for record in al_mse_history]
        sizes = [record['train_size'] for record in al_mse_history]

        # Left panel, primary axis: MSE per iteration
        mse_color = 'tab:blue'
        left_panel.set_xlabel('Iteration')
        left_panel.set_ylabel('Test MSE', color=mse_color)
        left_panel.plot(steps, errors, 'o-', color=mse_color, linewidth=2, markersize=8)
        left_panel.tick_params(axis='y', labelcolor=mse_color)
        left_panel.grid(True, alpha=0.3)

        # Left panel, secondary axis: training set size per iteration
        size_axis = left_panel.twinx()
        size_color = 'tab:orange'
        size_axis.set_ylabel('Training Set Size', color=size_color)
        size_axis.plot(steps, sizes, 's--', color=size_color, linewidth=1.5, markersize=6)
        size_axis.tick_params(axis='y', labelcolor=size_color)

        left_panel.set_title('Active Learning: MSE Evolution')

        # Right panel: MSE as a function of training set size
        right_panel.plot(sizes, errors, 'o-', color='darkblue', linewidth=2, markersize=8)
        right_panel.set_xlabel('Training Set Size')
        right_panel.set_ylabel('Test MSE')
        right_panel.set_title('Active Learning: MSE vs Training Size')
        right_panel.grid(True, alpha=0.3)

        # Call out the last point of the curve
        last_mse = errors[-1] if errors else 0
        last_size = sizes[-1] if sizes else 0
        right_panel.annotate(f'Final: MSE={last_mse:.4f}\nn={last_size}',
                             xy=(last_size, last_mse),
                             xytext=(10, 10), textcoords='offset points',
                             bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.7),
                             arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))

    # Random Sampling reference: mean line with a +/- one std band
    if random_summary:
        n_random = random_summary['train_size']
        center = random_summary['mean_mse']
        spread = random_summary['std_mse']
        lower, upper = center - spread, center + spread

        right_panel.axhline(y=center, color='green', linestyle='--', linewidth=2,
                            label=f'Random Sampling (n={n_random})')
        right_panel.fill_between([0, n_random],
                                 [lower, lower],
                                 [upper, upper],
                                 color='green', alpha=0.2)
        right_panel.legend()

    plt.tight_layout()
    return fig

def visualize_selection_progression(pool_info, train_indices_history):
    """Visualize how points are selected over iterations.

    Shows up to three snapshots (first, middle and last entry of
    ``train_indices_history``). Each snapshot gets a 3D scatter in the top
    row and its T-P projection in the bottom row.

    Every axes is created explicitly on one ``2 x n_columns`` grid. Creating
    the grid with ``plt.subplots`` and then adding 3D axes on top would leave
    the empty 2D frames visible underneath, and would misalign the two rows
    whenever fewer than three snapshots are shown.

    Args:
        pool_info: Sequence of dicts read for their 'T', 'P' and 'ratio' entries.
        train_indices_history: List of index lists, one per iteration;
            indices ``>= len(pool_info)`` are ignored.

    Returns:
        The created matplotlib figure, or ``None`` when the history is empty.
    """
    if not train_indices_history:
        return None

    n_iterations = len(train_indices_history)

    # Show first, middle, and last iterations
    if n_iterations <= 3:
        iterations_to_show = list(range(n_iterations))
    else:
        iterations_to_show = [0, n_iterations // 2, n_iterations - 1]
    n_columns = len(iterations_to_show)

    # Pool coordinates are the same for every snapshot
    all_T = [info['T'] for info in pool_info]
    all_P = [info['P'] for info in pool_info]
    all_ratio = [info['ratio'] for info in pool_info]
    pool_size = len(pool_info)

    fig = plt.figure(figsize=(15, 10))

    for plot_idx, iter_idx in enumerate(iterations_to_show):
        selected_indices = train_indices_history[iter_idx]
        chosen = [pool_info[i] for i in selected_indices if i < pool_size]
        sel_T = [info['T'] for info in chosen]
        sel_P = [info['P'] for info in chosen]
        sel_ratio = [info['ratio'] for info in chosen]

        # Top row: 3D view
        ax3d = fig.add_subplot(2, n_columns, plot_idx + 1, projection='3d')
        ax3d.scatter(all_T, all_P, all_ratio, c='gray', alpha=0.2, s=10)
        if selected_indices:
            ax3d.scatter(sel_T, sel_P, sel_ratio, c='red', alpha=1.0, s=50, marker='*')
        ax3d.set_title(f'Iteration {iter_idx + 1} (n={len(selected_indices)})')
        ax3d.set_xlabel('T (K)', fontsize=9)
        ax3d.set_ylabel('P (atm)', fontsize=9)
        ax3d.set_zlabel('Ratio', fontsize=9)
        ax3d.view_init(elev=20, azim=45)

        # Bottom row: 2D T-P projection, directly below its 3D view
        ax2d = fig.add_subplot(2, n_columns, n_columns + plot_idx + 1)
        ax2d.scatter(all_T, all_P, c='gray', alpha=0.2, s=10)
        if selected_indices:
            ax2d.scatter(sel_T, sel_P, c='red', alpha=1.0, s=50, marker='*')
        ax2d.set_xlabel('Temperature (K)')
        ax2d.set_ylabel('Pressure (atm)')
        ax2d.set_title(f'T-P Projection (Iter {iter_idx + 1})')
        ax2d.grid(True, alpha=0.3)

    fig.suptitle('Active Learning Selection Progression', fontsize=14, fontweight='bold')
    fig.tight_layout()
    return fig

def compare_sampling_methods(pool_info, al_indices, random_indices):
    """Compare where active learning and random sampling picked their points.

    Two 3D scatter panels show each method's selection on top of the pool;
    a third bar-chart panel shows, per parameter, the fraction of occupied
    pool histogram bins (10 bins) that contain at least one selected point.

    Args:
        pool_info: Sequence of dicts read for their 'T', 'P' and 'ratio' entries.
        al_indices: Positions into ``pool_info`` chosen by active learning.
        random_indices: Positions into ``pool_info`` chosen at random.
            Positions ``>= len(pool_info)`` are ignored for both lists.

    Returns:
        The created matplotlib figure.
    """
    fig = plt.figure(figsize=(15, 5))

    pool_size = len(pool_info)
    pool_T = [entry['T'] for entry in pool_info]
    pool_P = [entry['P'] for entry in pool_info]
    pool_ratio = [entry['ratio'] for entry in pool_info]

    # (subplot position, indices, method name, color, marker)
    scatter_panels = (
        (131, al_indices, 'Active Learning', 'blue', '^'),
        (132, random_indices, 'Random Sampling', 'green', 'o'),
    )
    for position, indices, method, colour, marker in scatter_panels:
        panel = fig.add_subplot(position, projection='3d')
        panel.scatter(pool_T, pool_P, pool_ratio, c='gray', alpha=0.2, s=10)
        if indices:
            kept = [pool_info[i] for i in indices if i < pool_size]
            panel.scatter([entry['T'] for entry in kept],
                          [entry['P'] for entry in kept],
                          [entry['ratio'] for entry in kept],
                          c=colour, alpha=1.0, s=100, marker=marker)
        panel.set_title(f'{method} (n={len(indices)})')
        panel.set_xlabel('T (K)')
        panel.set_ylabel('P (atm)')
        panel.set_zlabel('Ratio')

    # Coverage comparison
    bar_panel = fig.add_subplot(133)

    def coverage_fraction(indices, key, bins=10):
        """Share of non-empty pool bins that hold at least one selected value."""
        if not indices:
            return 0
        pool_values = [entry[key] for entry in pool_info]
        picked_values = [pool_info[i][key] for i in indices if i < len(pool_info)]

        # Selected values are binned with the pool's own bin edges
        pool_hist, edges = np.histogram(pool_values, bins=bins)
        picked_hist, _ = np.histogram(picked_values, bins=edges)

        return np.sum(picked_hist > 0) / np.sum(pool_hist > 0)

    params = ['T', 'P', 'ratio']
    param_names = ['Temperature', 'Pressure', 'Ratio']

    al_coverage = [coverage_fraction(al_indices, key) for key in params]
    rand_coverage = [coverage_fraction(random_indices, key) for key in params]

    x = np.arange(len(params))
    width = 0.35
    half = width / 2

    bar_panel.bar(x - half, al_coverage, width, label='Active Learning', color='blue', alpha=0.7)
    bar_panel.bar(x + half, rand_coverage, width, label='Random Sampling', color='green', alpha=0.7)

    bar_panel.set_xlabel('Parameter')
    bar_panel.set_ylabel('Coverage')
    bar_panel.set_title('Parameter Space Coverage Comparison')
    bar_panel.set_xticks(x)
    bar_panel.set_xticklabels(param_names)
    bar_panel.legend()
    bar_panel.grid(True, alpha=0.3, axis='y')

    # Percentage label just above each bar
    for slot in range(len(params)):
        al_cov = al_coverage[slot]
        rand_cov = rand_coverage[slot]
        bar_panel.text(slot - half, al_cov + 0.01, f'{al_cov:.1%}', ha='center', va='bottom')
        bar_panel.text(slot + half, rand_cov + 0.01, f'{rand_cov:.1%}', ha='center', va='bottom')

    plt.tight_layout()
    return fig

def visualize_mse_comparison(al_mse_history, random_summary):
    """Direct comparison of Active Learning vs Random Sampling MSE.

    Left panel: box plot of the individual random-sampling MSE values next to
    the final active-learning MSE, with red stars marking the means.
    Right panel: bar chart of the final active-learning MSE against the
    random-sampling mean (error bar = random-sampling std), annotated with
    the relative difference.

    Args:
        al_mse_history: List of dicts with 'test_mse' and 'train_size'; only
            the last entry is used.
        random_summary: Dict with 'mean_mse', 'std_mse', 'train_size' and
            optionally 'all_mse_values'.

    Returns:
        The created matplotlib figure. Both panels stay empty when either
        input is missing or empty.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    if not (al_mse_history and random_summary):
        plt.tight_layout()
        return fig

    final_al_mse = al_mse_history[-1]['test_mse']
    final_al_size = al_mse_history[-1]['train_size']
    random_mean = random_summary['mean_mse']
    random_std = random_summary['std_mse']
    random_mse_values = random_summary.get('all_mse_values', [])

    # Box plot comparison
    ax = axes[0]
    # Explicit None check: an MSE of exactly 0.0 is a valid value to plot.
    if final_al_mse is not None and random_mse_values:
        labels = [f'Random Sampling\n(n={random_summary["train_size"]})',
                  f'Active Learning\n(n={final_al_size})']

        bp = ax.boxplot([random_mse_values, [final_al_mse]], patch_artist=True)
        # Tick labels are set on the axis instead of through boxplot(), whose
        # label keyword was renamed between Matplotlib releases.
        ax.set_xticks([1, 2])
        ax.set_xticklabels(labels)
        bp['boxes'][0].set_facecolor('lightgreen')
        bp['boxes'][1].set_facecolor('lightblue')

        ax.set_ylabel('Test MSE')
        ax.set_title('MSE Comparison: Random Sampling vs Active Learning')
        ax.grid(True, alpha=0.3, axis='y')

        # Add mean values
        ax.plot(1, np.mean(random_mse_values), 'r*', markersize=15, label='Mean')
        ax.plot(2, final_al_mse, 'r*', markersize=15)
        # Without a legend the 'Mean' label above is never displayed.
        ax.legend()

    # Bar chart comparison
    ax = axes[1]
    categories = ['Active Learning', 'Random Sampling']
    mse_means = [final_al_mse, random_mean]
    mse_stds = [0, random_std]  # AL has no std as it's single run

    x_pos = np.arange(len(categories))
    bars = ax.bar(x_pos, mse_means, yerr=mse_stds, capsize=10,
                  color=['blue', 'green'], alpha=0.7)

    ax.set_xlabel('Method')
    ax.set_ylabel('Test MSE')
    ax.set_title('Average Test MSE Comparison')
    ax.set_xticks(x_pos)
    ax.set_xticklabels(categories)
    ax.grid(True, alpha=0.3, axis='y')

    # Add value labels on bars (placed above the error bar)
    for bar, mean_val, std_val in zip(bars, mse_means, mse_stds):
        if std_val > 0:
            label = f'{mean_val:.4f}\n±{std_val:.4f}'
        else:
            label = f'{mean_val:.4f}'
        ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + std_val,
                label, ha='center', va='bottom')

    # Add improvement percentage, relative to the random-sampling mean.
    # Skipped when that mean is zero, where a relative change is undefined.
    if random_mean != 0:
        text_height = max(mse_means) * 0.9
        if random_mean > final_al_mse:
            improvement = (random_mean - final_al_mse) / random_mean * 100
            ax.text(0.5, text_height,
                    f'Active Learning improves by {improvement:.1f}%',
                    ha='center', fontsize=12, color='darkred', fontweight='bold')
        else:
            degradation = (final_al_mse - random_mean) / random_mean * 100
            ax.text(0.5, text_height,
                    f'Random Sampling improves by {degradation:.1f}%',
                    ha='center', fontsize=12, color='darkgreen', fontweight='bold')

    plt.tight_layout()
    return fig

In [None]:
# Driver cell: loads the pool description plus one active-learning run and one
# random-sampling run, renders every available plot and saves it as PNG.
# The three names below are folder names that are resolved against BASE_DIR.
al_dir = 'AL_results_20250819_162332'
random_dir = 'Random_Sampling_Results_20250819_153446'
save_dir = 'visualization_results_2'
# Create save directory
# (save_dir is rebound here from the bare folder name to the full path)
save_dir = os.path.join(BASE_DIR, save_dir)
os.makedirs(save_dir, exist_ok=True)

# Load pool information
pool_info = load_pool_info()
print(f"Loaded {len(pool_info)} pool points")

# Visualize entire pool distribution
fig = visualize_3d_distribution(pool_info, title="Parameter Pool Distribution")
fig.savefig(os.path.join(save_dir, 'pool_distribution_3d.png'), dpi=150)

fig = visualize_2d_projections(pool_info, title_prefix="Pool")
fig.savefig(os.path.join(save_dir, 'pool_distribution_2d.png'), dpi=150)

# Initialize variables for comparison
# (these defaults stay in place when a results directory is missing, so the
# comparison and summary sections below are skipped instead of failing)
al_indices = []
al_mse_history = []
train_indices_history = []
random_indices = []
random_summary = None

# If active learning results directory is provided
if al_dir:
    al_results_dir = os.path.join(BASE_DIR, al_dir)
    if os.path.exists(al_results_dir):
        al_indices, train_indices_history, al_mse_history = load_active_learning_results(al_results_dir)
        print(f"Active Learning selected {len(al_indices)} points")
        print(f"Found {len(al_mse_history)} iterations of MSE history")
        
        fig = visualize_3d_distribution(pool_info, al_indices, "Active Learning Sampling Distribution")
        fig.savefig(os.path.join(save_dir, 'active_learning_3d.png'), dpi=150)
        
        fig = visualize_2d_projections(pool_info, al_indices, "Active Learning")
        fig.savefig(os.path.join(save_dir, 'active_learning_2d.png'), dpi=150)
        
        fig = visualize_distribution_statistics(pool_info, al_indices)
        fig.savefig(os.path.join(save_dir, 'active_learning_stats.png'), dpi=150)
        
        # MSE curves need at least one evaluation file
        if al_mse_history:
            fig = visualize_mse_evolution(al_mse_history=al_mse_history)
            fig.savefig(os.path.join(save_dir, 'active_learning_mse_evolution.png'), dpi=150)
        
        # Progression snapshots need at least one state file
        if train_indices_history:
            fig = visualize_selection_progression(pool_info, train_indices_history)
            if fig:
                fig.savefig(os.path.join(save_dir, 'active_learning_progression.png'), dpi=150)

# If random sampling results directory is provided
if random_dir:
    random_results_dir = os.path.join(BASE_DIR, random_dir)
    if os.path.exists(random_results_dir):
        random_indices, random_summary = load_random_sampling_results(random_results_dir)
        
        if random_indices:
            print(f"Random Sampling selected {len(random_indices)} points")
            
            fig = visualize_3d_distribution(pool_info, random_indices, "Random Sampling Distribution")
            fig.savefig(os.path.join(save_dir, 'random_sampling_3d.png'), dpi=150)
            
            fig = visualize_2d_projections(pool_info, random_indices, "Random Sampling")
            fig.savefig(os.path.join(save_dir, 'random_sampling_2d.png'), dpi=150)
            
            fig = visualize_distribution_statistics(pool_info, random_indices)
            fig.savefig(os.path.join(save_dir, 'random_sampling_stats.png'), dpi=150)
        
        if random_summary:
            print(f"Random Sampling: Mean MSE = {random_summary['mean_mse']:.6f} ± {random_summary['std_mse']:.6f}")

# If both AL and random results exist, create comparison plots
if al_dir and random_dir and al_indices and random_indices:
    fig = compare_sampling_methods(pool_info, al_indices, random_indices)
    fig.savefig(os.path.join(save_dir, 'sampling_comparison.png'), dpi=150)
    
    if al_mse_history and random_summary:
        fig = visualize_mse_comparison(al_mse_history, random_summary)
        fig.savefig(os.path.join(save_dir, 'mse_comparison.png'), dpi=150)
        
        # Combined MSE evolution plot
        fig = visualize_mse_evolution(al_mse_history=al_mse_history, 
                                    random_summary=random_summary)
        fig.savefig(os.path.join(save_dir, 'combined_mse_evolution.png'), dpi=150)
    
    print(f"\nComparison plots saved")

print(f"\nAll visualization results saved to: {save_dir}")
    
# Print summary statistics
if al_mse_history:
    print(f"\nActive Learning Final Results:")
    print(f"  Final MSE: {al_mse_history[-1]['test_mse']:.6f}")
    print(f"  Final training size: {al_mse_history[-1]['train_size']}")

if random_summary:
    print(f"\nRandom Sampling Results:")
    print(f"  Mean MSE: {random_summary['mean_mse']:.6f} ± {random_summary['std_mse']:.6f}")
    print(f"  Training size: {random_summary['train_size']}")

# None of the figures created above is closed, so they are all still open here
plt.show()

# 3. Baseline Model Comparison

## 1. Setup & Single TCN Training

In [None]:
# === Cell 1: Setup + Data + Utilities + (Updated Metrics) + Train/Eval TCN ===
import os, json, glob, math, copy, random, warnings
from datetime import datetime
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.parametrizations import weight_norm

import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

warnings.filterwarnings("ignore")

# ----------------- Basic Configuration (Modifiable) -----------------
AL_RESULTS_DIR = "./AL_results_20250917_192217"  # relative path to an active-learning results folder
ITER_FOR_TRAIN = 83                       # presumably the AL iteration whose training set is reused — confirm against its usage
FIG_DPI = 300
TABLE_FIGSIZE = (16, 9)
SAVE_ROOT = "./baseline_eval"

# Consistent with Active Learning
# NOTE(review): INPUT_LEN, PRED_LEN, STRIDE, BATCH_SIZE and LEARNING_RATE match
# the active-learning configuration at the top of this file, but that
# configuration sets EPOCHS = 200 — confirm that 120 is intended here.
INPUT_LEN   = 50
PRED_LEN    = 150
STRIDE      = 10
BATCH_SIZE  = 32
EPOCHS      = 120
LEARNING_RATE = 1e-3
HUBER_DELTA = 1.0
HIDDEN_CHANNELS_TCN = [64, 64, 64]
KERNEL_SIZE_TCN = 4
VAL_SPLIT = 0.1

# Test Set: one data directory per held-out case, all below the same root.
_TEST_DATA_ROOT = r'C:/Users/Administrator/Desktop/My research/2_ML_predict_product/AL_GasKit_V2.0/Total_data'
_TEST_CASE_NAMES = (
    # 1100 K
    '1100K_2atm_1per15',
    '1100K_3atm_1per17.5',
    '1100K_4atm_1per20',
    '1100K_5atm_1per22.5',
    '1100K_6atm_1per25',
    # 1200 K
    '1200K_2atm_1per17.5',
    '1200K_3atm_1per20',
    '1200K_4atm_1per22.5',
    '1200K_5atm_1per25',
    '1200K_6atm_1per15',
    # 1300 K
    '1300K_2atm_1per20',
    '1300K_3atm_1per22.5',
    '1300K_4atm_1per25',
    '1300K_5atm_1per15',
    '1300K_6atm_1per17.5',
    # 1400 K
    '1400K_2atm_1per22.5',
    '1400K_3atm_1per25',
    '1400K_4atm_1per15',
    '1400K_5atm_1per17.5',
    '1400K_6atm_1per20',
    # 1500 K
    '1500K_2atm_1per25',
    '1500K_3atm_1per15',
    '1500K_4atm_1per17.5',
    '1500K_5atm_1per20',
    '1500K_6atm_1per22.5',
)
TEST_PATHS = [f'{_TEST_DATA_ROOT}/{case_name}' for case_name in _TEST_CASE_NAMES]

# ----------------- Random Seed -----------------
def seed_everything(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Also seeds all CUDA devices when CUDA is available, and sets the cuDNN
    flags ``deterministic = True`` and ``benchmark = False``.

    Args:
        seed: Integer seed passed to every generator.
    """
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed once at import time so the rest of the cell starts from a fixed state
seed_everything(42)

# ----------------- Preprocessor -----------------
from scipy.ndimage import uniform_filter1d

class DataPreprocessor:
    """Optional temporal smoothing plus per-species normalisation of (species, time) matrices.

    ``normalization_method`` selects the behaviour:

    * ``'none'``           - values are passed through unchanged;
    * ``'log_transform'``  - ``log1p`` / ``expm1``;
    * ``'species_wise'``   - (x - mean) / std, statistics learned per species in ``fit``;
    * ``'robust_scaling'`` - (x - median) / IQR, statistics learned per species in ``fit``;
    * ``'min_max'``        - (x - min) / range, statistics learned per species in ``fit``.

    ``smooth_window > 1`` applies a moving average along axis 1 in ``transform``
    (it is not undone by ``inverse_transform``).
    """

    def __init__(self, normalization_method='none', smooth_window=1):
        self.normalization_method = normalization_method
        self.smooth_window = smooth_window
        self.scalers = {}      # species name -> dict of fitted statistics
        self.fitted = False

    @staticmethod
    def _as_float(mat):
        """Return ``mat`` as a floating array; integer/bool input is promoted to float64.

        Without this, ``np.zeros_like`` and ``uniform_filter1d`` would keep an integer
        dtype and silently truncate the normalised / smoothed values.
        """
        mat = np.asarray(mat)
        if not np.issubdtype(mat.dtype, np.inexact):
            mat = mat.astype(np.float64)
        return mat

    def fit(self, matrices, species_lists):
        """Learn per-species statistics from all rows carrying the same species name.

        Args:
            matrices: iterable of 2-D arrays indexed as ``mat[species_row, :]``.
            species_lists: iterable of species-name lists, one per matrix, giving the
                name of each row.
        """
        all_species_data = {}
        for mat, species_list in zip(matrices, species_lists):
            for i, sp in enumerate(species_list):
                all_species_data.setdefault(sp, []).append(mat[i, :])
        for sp, data_list in all_species_data.items():
            combined = np.concatenate(data_list)
            if self.normalization_method == 'species_wise':
                mean, std = np.mean(combined), np.std(combined)
                std = max(std, 1e-8)  # guard against constant series
                self.scalers[sp] = {'mean': mean, 'std': std}
            elif self.normalization_method == 'robust_scaling':
                median = np.median(combined)
                q75, q25 = np.percentile(combined, [75, 25])
                iqr = max(q75 - q25, 1e-8)
                self.scalers[sp] = {'median': median, 'iqr': iqr}
            elif self.normalization_method == 'min_max':
                mn, mx = np.min(combined), np.max(combined)
                rng = max(mx - mn, 1e-8)
                self.scalers[sp] = {'min': mn, 'range': rng}
        self.fitted = True

    def transform(self, mat, species_list):
        """Smooth (if configured) and normalise ``mat``; rows without a fitted scaler pass through."""
        if self.smooth_window > 1:
            # Cast first: uniform_filter1d produces output in the input dtype by default.
            mat = uniform_filter1d(self._as_float(mat), size=self.smooth_window, axis=1)
        if self.normalization_method == 'log_transform':
            return np.log1p(mat)
        if self.normalization_method == 'none':
            return mat
        mat = self._as_float(mat)
        out = np.zeros_like(mat)
        for i, sp in enumerate(species_list):
            if sp in self.scalers:
                sc = self.scalers[sp]
                if self.normalization_method == 'species_wise':
                    out[i, :] = (mat[i, :] - sc['mean']) / sc['std']
                elif self.normalization_method == 'robust_scaling':
                    out[i, :] = (mat[i, :] - sc['median']) / sc['iqr']
                elif self.normalization_method == 'min_max':
                    out[i, :] = (mat[i, :] - sc['min']) / sc['range']
            else:
                out[i, :] = mat[i, :]
        return out

    def inverse_transform(self, mat, species_list):
        """Undo the normalisation applied by ``transform`` (smoothing is not reversed)."""
        if self.normalization_method == 'log_transform':
            return np.expm1(mat)
        if self.normalization_method == 'none':
            return mat
        mat = self._as_float(mat)
        out = np.zeros_like(mat)
        for i, sp in enumerate(species_list):
            if sp in self.scalers:
                sc = self.scalers[sp]
                if self.normalization_method == 'species_wise':
                    out[i, :] = mat[i, :] * sc['std'] + sc['mean']
                elif self.normalization_method == 'robust_scaling':
                    out[i, :] = mat[i, :] * sc['iqr'] + sc['median']
                elif self.normalization_method == 'min_max':
                    out[i, :] = mat[i, :] * sc['range'] + sc['min']
            else:
                out[i, :] = mat[i, :]
        return out

# ----------------- Huber Loss (Shared for Training & Evaluation) -----------------
class RobustLoss(nn.Module):
    """Huber-type loss: quadratic for residuals below ``delta``, linear above, averaged over all elements."""

    def __init__(self, delta=1.0):
        super().__init__()
        self.delta = delta

    def forward(self, pred, target):
        residual = (pred - target).abs()
        quadratic = 0.5 * residual ** 2
        linear = self.delta * (residual - 0.5 * self.delta)
        per_element = torch.where(residual < self.delta, quadratic, linear)
        return per_element.mean()

def numpy_huber(y_true, y_pred, delta=1.0):
    """NumPy counterpart of the Huber-type loss: mean over all elements, returned as a Python float."""
    residual = np.abs(y_true - y_pred)
    quadratic = 0.5 * residual**2
    linear = delta * (residual - 0.5*delta)
    per_element = np.where(residual < delta, quadratic, linear)
    return float(np.mean(per_element))

# ----------------- Data Alignment & Loading -----------------
def load_species_list(dir_path):
    """Return the non-empty, stripped lines of the first usable species file in ``dir_path``.

    ``species_list.txt`` is tried before ``species_list_initial.txt``; a file that
    exists but yields no names is skipped. Returns ``None`` when neither file
    provides any name.
    """
    candidates = ('species_list.txt', 'species_list_initial.txt')
    for file_name in candidates:
        full_path = os.path.join(dir_path, file_name)
        if not os.path.exists(full_path):
            continue
        with open(full_path, 'r') as handle:
            names = [stripped for stripped in (raw.strip() for raw in handle) if stripped]
        if names:
            return names
    return None

def load_matrix(dir_path):
    """Load the first existing matrix file in ``dir_path`` with ``np.load``; ``None`` if neither exists.

    ``species_time_matrix.npy`` takes precedence over ``species_time_matrix_initial.npy``.
    """
    file_names = ('species_time_matrix.npy', 'species_time_matrix_initial.npy')
    candidates = (os.path.join(dir_path, file_name) for file_name in file_names)
    existing = next((path for path in candidates if os.path.exists(path)), None)
    if existing is None:
        return None
    return np.load(existing)

def align_to_reference(mat_src, species_src, species_ref):
    """Re-order the rows of ``mat_src`` to follow ``species_ref``.

    Reference species that do not occur in ``species_src`` get an all-zero row;
    source species absent from the reference are dropped. The result has shape
    ``(len(species_ref), mat_src.shape[1])`` and the dtype of ``mat_src``.
    """
    n_cols = mat_src.shape[1]
    aligned = np.zeros((len(species_ref), n_cols), dtype=mat_src.dtype)
    # Later duplicates overwrite earlier ones, as in a plain dict build.
    source_row = dict((name, row) for row, name in enumerate(species_src))
    for dst_row, name in enumerate(species_ref):
        src_row = source_row.get(name)
        if src_row is not None:
            aligned[dst_row, :] = mat_src[src_row, :]
    return aligned

def create_windows(matrices):
    """Cut sliding input/target windows out of every 2-D matrix in ``matrices``.

    Windows start every ``STRIDE`` columns; each input covers ``INPUT_LEN`` columns
    and its target the following ``PRED_LEN`` columns. Slices are transposed before
    stacking, so the stacked arrays are (n_windows, window_len, n_rows).
    Returns two empty arrays when no matrix is long enough for a single window.
    """
    window_span = INPUT_LEN + PRED_LEN
    inputs, targets = [], []
    for mat in matrices:
        _, n_cols = mat.shape
        for start in range(0, n_cols - window_span + 1, STRIDE):
            split = start + INPUT_LEN
            inputs.append(mat[:, start:split].T)               # (L_in, S)
            targets.append(mat[:, split:split + PRED_LEN].T)   # (L_pred, S)
    if not inputs:
        return np.array([]), np.array([])
    return np.stack(inputs), np.stack(targets)

class TimeSeriesDataset(Dataset):
    """Pairs of (input, target) windows held as float32 tensors converted from NumPy arrays."""

    def __init__(self, X, Y):
        self.X, self.Y = (torch.from_numpy(arr).float() for arr in (X, Y))

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

# ----------------- Get "First 84 Rounds" Training Indices from AL Results -----------------
def get_train_indices_from_results(results_dir, iter_k=83):
    """Return the sorted training indices recorded for active-learning iteration ``iter_k``.

    Lookup order:
      1. ``state_iter_{iter_k}.json`` (its ``train_indices`` entry, de-duplicated);
      2. otherwise the ``state_iter_*.json`` file with the largest iteration number
         that does not exceed ``iter_k``;
      3. otherwise the ``index`` fields listed in ``initial_train_set.json``.

    Raises:
        FileNotFoundError: if none of the above files is available.
    """
    def _indices_from_state(path):
        # Shared loader for state snapshots: unique, sorted training indices.
        with open(path, 'r') as f:
            state = json.load(f)
        return sorted(set(state['train_indices']))

    target = os.path.join(results_dir, f"state_iter_{iter_k}.json")
    if os.path.exists(target):
        return _indices_from_state(target)

    best, best_file = -1, None
    for p in sorted(glob.glob(os.path.join(results_dir, "state_iter_*.json"))):
        stem = os.path.splitext(os.path.basename(p))[0]
        try:
            i = int(stem.split('_')[-1])
        except ValueError:
            # File name matches the glob but carries no iteration number: ignore it.
            continue
        if best < i <= iter_k:
            best, best_file = i, p
    if best_file is not None:
        return _indices_from_state(best_file)

    initf = os.path.join(results_dir, "initial_train_set.json")
    if os.path.exists(initf):
        with open(initf, 'r') as f:
            arr = json.load(f)
        return sorted(d['index'] for d in arr if 'index' in d)
    raise FileNotFoundError("Training indices not found.")

# ----------------- Get pool_info.json and Locate Training Directory -----------------
def infer_base_dir_from_results(results_dir):
    """Return the absolute path of the directory that contains ``results_dir``."""
    parent = os.path.join(results_dir, os.pardir)
    return os.path.abspath(parent)

def get_pool_info(base_dir):
    """Parse and return ``<base_dir>/pool_info.json``.

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    info_path = os.path.join(base_dir, "pool_info.json")
    if os.path.exists(info_path):
        with open(info_path, 'r') as handle:
            return json.load(handle)
    raise FileNotFoundError(f"Not found {info_path}")

# ----------------- Read Preprocessing Configuration -----------------
def read_preprocess_config(results_dir):
    """Read the preprocessing settings from ``<results_dir>/config.json``.

    Returns ``(method, smooth_window)`` taken from the keys ``PREPROCESSING_METHOD``
    and ``SMOOTH_WINDOW``; defaults to ``("none", 1)`` when the file or a key is missing.
    """
    config_path = os.path.join(results_dir, "config.json")
    if not os.path.exists(config_path):
        return "none", 1
    with open(config_path, 'r') as handle:
        settings = json.load(handle)
    method = settings.get("PREPROCESSING_METHOD", "none")
    smooth = int(settings.get("SMOOTH_WINDOW", 1))
    return method, smooth

# ----------------- Metrics Function: WAPE & Table -----------------
def wape_percent(y_true, y_pred, eps=1e-8):
    """Weighted absolute percentage error: 100 * sum|y - p| / sum|y| over all elements.

    Returns 0.0 when the total magnitude of ``y_true`` is below ``eps``.
    """
    truth = np.asarray(y_true).reshape(-1)
    guess = np.asarray(y_pred).reshape(-1)
    total_magnitude = np.sum(np.abs(truth))
    if total_magnitude < eps:
        return 0.0
    total_error = np.sum(np.abs(truth - guess))
    return float(100.0 * total_error / total_magnitude)

def segment_indices(total_steps, short_end=200, long_end=500):
    """Split a forecast horizon into a short and a long segment.

    Returns ``((s0, s1), (l0, l1))`` as half-open index ranges clipped to
    ``[0, total_steps)``: the short segment is ``[0, short_end)`` and the long
    segment ``[short_end, long_end)``. When no long segment remains after clipping,
    it is reported as ``(0, 0)``.
    """
    s0, s1 = 0, min(short_end, total_steps)
    # The long segment starts where the short one ends (previously hard-coded to 200,
    # which ignored a non-default ``short_end``).
    l0, l1 = s1, min(long_end, total_steps)
    if l0 >= l1:  # No valid long segment
        l0, l1 = 0, 0
    return (s0, s1), (l0, l1)

def table_figure(df, save_dir, fname="performance_table"):
    """Render ``df`` as a styled matplotlib table, save it as PNG and display it.

    The image is written to ``<save_dir>/<fname>.png`` (the directory is created if
    needed). The header row is green, the last data row is highlighted in yellow
    and bold, and every second row in between is shaded.

    Args:
        df: table to render; its column labels become the header row.
        save_dir: output directory.
        fname: file name without extension.
    """
    os.makedirs(save_dir, exist_ok=True)
    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans', 'Liberation Sans', 'Nimbus Sans L', 'Helvetica']
    plt.rcParams['font.size'] = 10

    fig, ax = plt.subplots(figsize=TABLE_FIGSIZE, dpi=FIG_DPI)
    ax.axis('tight'); ax.axis('off')

    # Widen the first column: widths are fractions of the axis width (sum ~ 1).
    ncol = len(df.columns)
    FIRST_COL_W = 0.15  # increase to 0.30/0.35 if a wider first column is needed
    if ncol > 1:
        other_w = (1.0 - FIRST_COL_W) / (ncol - 1)
        col_widths = [FIRST_COL_W] + [other_w] * (ncol - 1)
    else:
        # A single column gets the full width (avoids dividing by ncol - 1 == 0).
        col_widths = [1.0] * ncol

    table = ax.table(
        cellText=df.values,
        colLabels=df.columns,
        cellLoc='center',
        loc='center',
        colWidths=col_widths,
    )

    table.auto_set_font_size(False); table.set_fontsize(9); table.scale(1.2, 1.5)

    # Allow all cells to wrap text
    for (r, c), cell in table.get_celld().items():
        cell.get_text().set_wrap(True)

    # Left-align the first column (including header), more like a list
    table[(0, 0)].set_text_props(ha='left', color='white', weight='bold')
    for r in range(1, len(df) + 1):  # Data rows start from 1 (row 0 is the header)
        table[(r, 0)].set_text_props(ha='left')

    # Header and stripes
    for i in range(ncol):
        table[(0, i)].set_facecolor('#4CAF50'); table[(0, i)].set_text_props(weight='bold', color='white')
    last = len(df)  # table row index of the last data row
    for i in range(ncol):
        table[(last, i)].set_facecolor('#FFF9C4'); table[(last, i)].set_text_props(weight='bold')
    for i in range(1, last):
        for j in range(ncol):
            if i % 2 == 0:
                table[(i, j)].set_facecolor('#F5F5F5')

    plt.tight_layout()
    fig.savefig(os.path.join(save_dir, f"{fname}.png"), dpi=FIG_DPI, bbox_inches='tight')
    plt.show()
    plt.close(fig)


# ----------------- DTW (Numba optimized, fallback to pure Python on failure) -----------------
# Define dtw_distance_numba(a, b): the dynamic-time-warping distance between two 1-D
# sequences using absolute difference as the local cost. The numba-compiled version is
# preferred; if importing numba or compiling/running the warm-up call raises, a pure-Python
# function with the same name and recurrence is defined instead. DTW_IMPL records which
# variant is active ("numba" or "python").
try:
    from numba import njit
    @njit
    def dtw_distance_numba(a, b):
        n, m = a.shape[0], b.shape[0]
        # Accumulated-cost table with an extra leading row/column; only the origin is reachable.
        DTW = np.full((n+1, m+1), np.inf)
        DTW[0,0] = 0.0
        for i in range(1, n+1):
            for j in range(1, m+1):
                cost = abs(a[i-1] - b[j-1])
                # Manual three-way minimum over the predecessor cells.
                dp_min = DTW[i-1, j]
                if DTW[i, j-1] < dp_min:
                    dp_min = DTW[i, j-1]
                if DTW[i-1, j-1] < dp_min:
                    dp_min = DTW[i-1, j-1]
                DTW[i, j] = cost + dp_min
        return DTW[n, m]
    # Warm-up call on float64 inputs so that a compilation failure is caught here.
    _ = dtw_distance_numba(np.array([0.], dtype=np.float64), np.array([0.], dtype=np.float64))
    DTW_IMPL = "numba"
except Exception as e:
    print(f"[Info] Numba not available, DTW falling back to pure Python implementation: {e}")
    # Fallback keeps the same function name so callers need no changes.
    def dtw_distance_numba(a, b):
        n, m = a.shape[0], b.shape[0]
        DTW = np.full((n+1, m+1), np.inf)
        DTW[0,0] = 0.0
        for i in range(1, n+1):
            for j in range(1, m+1):
                cost = abs(a[i-1] - b[j-1])
                DTW[i, j] = cost + min(DTW[i-1, j], DTW[i, j-1], DTW[i-1, j-1])
        return DTW[n, m]
    DTW_IMPL = "python"

def average_dtw_across_species(gt_steps_S, pred_steps_S, max_len=200):
    """Mean per-species DTW distance between ground truth and prediction.

    Args:
        gt_steps_S: array of shape (steps, S).
        pred_steps_S: array of shape (steps, S).
        max_len: series are subsampled with an integer stride of
            ``max(1, steps // max_len)`` before DTW is computed.

    Returns:
        The DTW distance averaged over the S species as a float, or NaN when S is 0.
    """
    truth_rows = gt_steps_S.T      # -> (S, steps)
    pred_rows = pred_steps_S.T
    distances = []
    for row in range(truth_rows.shape[0]):
        truth, guess = truth_rows[row], pred_rows[row]
        stride = max(1, truth.shape[0] // max_len)
        truth_ds = truth[::stride].astype(np.float64, copy=False)
        guess_ds = guess[::stride].astype(np.float64, copy=False)
        distances.append(dtw_distance_numba(truth_ds, guess_ds))
    if not distances:
        return float('nan')
    return float(np.mean(distances))

# ----------------- Rolling Forecast Evaluation (New Summary Metrics) -----------------
def rolling_forecast(model, initial_input, total_steps, device):
    """Autoregressively forecast ``total_steps`` steps, feeding predictions back as input.

    Args:
        model: network called on a (1, L_in, S) float tensor; the first item of its
            output is used as a (PRED_LEN, S) chunk.
        initial_input: NumPy array of shape (L_in, S) used as the first input window.
        total_steps: number of future steps to produce.
        device: torch device on which the input tensor is placed.

    Returns:
        NumPy array of shape (total_steps, S) with the concatenated predictions.
    """
    model.eval()
    chunks = []
    produced = 0
    window = initial_input.copy()  # (L_in, S)
    while produced < total_steps:
        batch = torch.from_numpy(window[np.newaxis, :, :]).float().to(device)
        with torch.no_grad():
            block = model(batch).detach().cpu().numpy()[0]  # (PRED_LEN, S)
        take = min(PRED_LEN, total_steps - produced)
        chunks.append(block[:take])
        produced += take
        if produced >= total_steps:
            break
        # Next input window: the latest INPUT_LEN predicted steps, topped up with the
        # tail of the initial window while too few predictions exist.
        history = np.vstack(chunks)
        missing = INPUT_LEN - history.shape[0]
        if missing <= 0:
            window = history[-INPUT_LEN:]
        else:
            window = np.vstack([initial_input[-missing:], history])
    return np.vstack(chunks)  # (steps, S)

def evaluate_model_on_tests(model, species_ref, test_dirs, preproc, save_dir, device):
    """Evaluate ``model`` with a rolling forecast on every test directory and tabulate metrics.

    For each directory the matrix is aligned to ``species_ref``, transformed with
    ``preproc``, and the first ``INPUT_LEN`` frames seed a rolling forecast of up to
    ``500 - INPUT_LEN`` steps. Metrics are computed in the transformed space over all
    species jointly (Huber, MSE, RMSE, MAE, WAPE, Pearson, R2), plus a per-species
    averaged DTW and short/long-horizon MSE/MAE/WAPE.

    Side effects: writes ``<case>_pred.npy`` / ``<case>_gt.npy`` per test case,
    ``performance_metrics.csv`` and the table figure into ``save_dir``.

    Args:
        model: trained forecasting network.
        species_ref: reference species ordering used for alignment.
        test_dirs: iterable of test-case directories.
        preproc: preprocessor exposing ``transform(mat, species_list)``.
        save_dir: output directory (created if missing).
        device: torch device used for inference.

    Returns:
        DataFrame with one row per evaluated test case plus a final "Average" row.

    Raises:
        RuntimeError: if no test case could be evaluated.
    """
    os.makedirs(save_dir, exist_ok=True)
    rows_summary = []
    for td in test_dirs:
        mat = load_matrix(td)
        sp  = load_species_list(td)
        if mat is None or sp is None:
            print(f"[Skip] Matrix/Species file not found: {td}")
            continue
        mat_al = align_to_reference(mat, sp, species_ref)  # (S_ref, T)
        mat_tr = preproc.transform(mat_al, species_ref)

        T = mat_tr.shape[1]
        if T <= INPUT_LEN:
            print(f"[Skip] {os.path.basename(td)} T={T} <= INPUT_LEN")
            continue

        initial_input = mat_tr[:, :INPUT_LEN].T  # (L_in, S)
        total_steps = min(500 - INPUT_LEN, T - INPUT_LEN)  # 450 or smaller
        pred = rolling_forecast(model, initial_input, total_steps, device)   # (steps, S)
        gt   = mat_tr[:, INPUT_LEN:INPUT_LEN+total_steps].T                 # (steps, S)

        # ----- Global Basic Metrics -----
        mse  = mean_squared_error(gt.reshape(-1), pred.reshape(-1))
        rmse = math.sqrt(mse)
        mae  = mean_absolute_error(gt.reshape(-1), pred.reshape(-1))
        r2   = r2_score(gt.reshape(-1), pred.reshape(-1))
        wape = wape_percent(gt, pred)
        huber = numpy_huber(gt, pred, delta=HUBER_DELTA)
        pear = float(np.corrcoef(gt.reshape(-1), pred.reshape(-1))[0,1])

        # ----- DTW (DTW per species, then average) -----
        dtw_avg = average_dtw_across_species(gt, pred, max_len=200)

        # ----- Segmentation (short:0-200, long:200-500; auto-clip to total_steps) -----
        (s0, s1), (l0, l1) = segment_indices(total_steps, short_end=200, long_end=500)

        # short segment
        if s1 > s0:
            gt_s, pd_s = gt[s0:s1], pred[s0:s1]
            short_mse = mean_squared_error(gt_s.reshape(-1), pd_s.reshape(-1))
            short_mae = mean_absolute_error(gt_s.reshape(-1), pd_s.reshape(-1))
            short_wape = wape_percent(gt_s, pd_s)
        else:
            short_mse = short_mae = short_wape = float('nan')

        # long segment
        if l1 > l0:
            gt_l, pd_l = gt[l0:l1], pred[l0:l1]
            long_mse = mean_squared_error(gt_l.reshape(-1), pd_l.reshape(-1))
            long_mae = mean_absolute_error(gt_l.reshape(-1), pd_l.reshape(-1))
            long_wape = wape_percent(gt_l, pd_l)
        else:
            long_mse = long_mae = long_wape = float('nan')

        # Record the metrics as pre-formatted strings; missing segments become "".
        rows_summary.append({
            "Test Case":   os.path.basename(td),
            "Huber":       f"{huber:.3e}",
            "MSE":         f"{mse:.3e}",
            "RMSE":        f"{rmse:.3e}",
            "MAE":         f"{mae:.3e}",
            "WAPE (%)":    f"{wape:.1f}",
            "DTW":         f"{dtw_avg:.3e}",
            "pearson":     f"{pear:.3f}",
            "R2":          f"{r2:.3f}",
            "short_MSE":   f"{short_mse:.3e}" if not np.isnan(short_mse) else "",
            "long_MSE":    f"{long_mse:.3e}"  if not np.isnan(long_mse)  else "",
            "short_MAE":   f"{short_mae:.3e}" if not np.isnan(short_mae) else "",
            "long_MAE":    f"{long_mae:.3e}"  if not np.isnan(long_mae)  else "",
            "short_WAPE":  f"{short_wape:.1f}" if not np.isnan(short_wape) else "",
            "long_WAPE":   f"{long_wape:.1f}"  if not np.isnan(long_wape)  else "",
        })

        # Optional: save predictions and ground truth
        case_tag = os.path.basename(td)
        np.save(os.path.join(save_dir, f"{case_tag}_pred.npy"), pred)
        np.save(os.path.join(save_dir, f"{case_tag}_gt.npy"), gt)

    if not rows_summary:
        raise RuntimeError("Failed to evaluate any test sets.")

    # Create DataFrame (ensure column order matches requirements)
    cols_order = ["Test Case", "Huber",  "short_MSE", "long_MSE", "MSE", "RMSE", "short_MAE", "long_MAE", "MAE", "short_WAPE", "long_WAPE", "WAPE (%)",
                  "DTW", "pearson", "R2"]
    df = pd.DataFrame(rows_summary)[cols_order]

    # Append average row (format by column type)
    avg = {"Test Case": "Average"}
    for col in cols_order[1:]:
        vals = []
        for v in df[col]:
            try:
                vals.append(float(str(v)))
            except ValueError:
                # Blank cells (missing segment) are left out of the average.
                continue
        if not vals:
            avg[col] = ""
        elif col in ["WAPE (%)","short_WAPE","long_WAPE"]:
            avg[col] = f"{np.nanmean(vals):.1f}"
        elif col in ["pearson","R2"]:
            avg[col] = f"{np.nanmean(vals):.3f}"
        elif col in ["MSE","RMSE","MAE","Huber","DTW","short_MSE","long_MSE","short_MAE","long_MAE"]:
            avg[col] = f"{np.nanmean(vals):.3e}"
        else:
            avg[col] = f"{np.nanmean(vals):.3f}"
    df = pd.concat([df, pd.DataFrame([avg])], ignore_index=True)

    # Save & Plot
    os.makedirs(save_dir, exist_ok=True)
    df.to_csv(os.path.join(save_dir, "performance_metrics.csv"), index=False)
    table_figure(df, save_dir, fname="performance_table")
    return df

# ----------------- Generic Trainer -----------------
def train_and_evaluate(model, model_name, X_train, Y_train, species_ref, preproc, device, save_root=SAVE_ROOT):
    """Train ``model`` with early stopping, save its weights and evaluate it on ``TEST_PATHS``.

    The windows are shuffled and split into validation (``VAL_SPLIT`` fraction, at
    least one window) and training parts. Training uses Adam and ``RobustLoss``; it
    stops after 15 epochs without validation improvement, and the best-validation
    weights are restored before saving to ``<save_root>/<model_name>/<model_name>.pth``.

    Args:
        model: network mapping a batch of input windows to a batch of target windows.
        model_name: label used for log messages and the output sub-directory.
        X_train, Y_train: NumPy arrays of input / target windows (first axis = window).
        species_ref: reference species ordering, forwarded to the evaluation.
        preproc: fitted preprocessor, forwarded to the evaluation.
        device: torch device for training and inference.
        save_root: root output directory.

    Returns:
        ``(model, df)``: the trained model and the metrics DataFrame.

    Raises:
        ValueError: if the split leaves no training or no validation window.
    """
    os.makedirs(save_root, exist_ok=True)
    # Split Train/Validation
    n = X_train.shape[0]
    idx = np.arange(n)
    np.random.shuffle(idx)
    n_val = max(1, int(n * VAL_SPLIT))
    val_idx, tr_idx = idx[:n_val], idx[n_val:]
    if len(tr_idx) == 0 or len(val_idx) == 0:
        # Fail early with a clear message instead of dividing by an empty loader below.
        raise ValueError(
            f"[{model_name}] Not enough windows to split: {n} total, "
            f"{len(tr_idx)} train / {len(val_idx)} validation."
        )
    ds_tr = TimeSeriesDataset(X_train[tr_idx], Y_train[tr_idx])
    ds_va = TimeSeriesDataset(X_train[val_idx], Y_train[val_idx])
    dl_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE, shuffle=True)
    dl_va = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False)

    model = model.to(device)
    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = RobustLoss(delta=HUBER_DELTA).to(device)

    best_loss = float("inf"); best_state = None; patience = 15; bad = 0
    for ep in range(1, EPOCHS+1):
        model.train(); run_loss = 0.0
        for xb, yb in dl_tr:
            xb, yb = xb.to(device), yb.to(device)
            optim.zero_grad()
            pred = model(xb)
            loss = criterion(pred, yb)
            loss.backward(); optim.step()
            run_loss += loss.item()
        tr_loss = run_loss/len(dl_tr)

        # Validation
        model.eval(); va_loss = 0.0
        with torch.no_grad():
            for xb, yb in dl_va:
                xb, yb = xb.to(device), yb.to(device)
                pred = model(xb)
                va_loss += criterion(pred, yb).item()
        va_loss /= len(dl_va)

        if ep % 20 == 0 or ep==1:
            print(f"[{model_name}] Epoch {ep:03d} | Train {tr_loss:.4f} | Val {va_loss:.4f}")

        # Improvement must exceed a small tolerance to reset the patience counter.
        if va_loss + 1e-6 < best_loss:
            best_loss = va_loss; best_state = copy.deepcopy(model.state_dict()); bad = 0
        else:
            bad += 1
            if bad >= patience:
                print(f"[{model_name}] Early stopping at epoch {ep}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    # Evaluation
    save_dir = os.path.join(save_root, model_name)
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, f"{model_name}.pth"))
    print(f"[{model_name}] ✓ Training complete, evaluating on test set (new metrics)...")
    df = evaluate_model_on_tests(model, species_ref, TEST_PATHS, preproc, save_dir, device)
    print(f"[{model_name}] Evaluation complete, results saved in: {save_dir}")
    return model, df

# ----------------- Build TCN -----------------
class TemporalBlock(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        self.conv1 = weight_norm(nn.Conv1d(in_ch, out_ch, kernel_size, stride=stride, padding=padding, dilation=dilation))
        self.relu1 = nn.ReLU(); self.drop1 = nn.Dropout(dropout)
        self.conv2 = weight_norm(nn.Conv1d(out_ch, out_ch, kernel_size, stride=stride, padding=padding, dilation=dilation))
        self.relu2 = nn.ReLU(); self.drop2 = nn.Dropout(dropout)
        self.net = nn.Sequential(self.conv1, self.relu1, self.drop1, self.conv2, self.relu2, self.drop2)
        self.downsample = nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else None
        self.relu = nn.ReLU()
    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        if out.size(2) != res.size(2):
            out = out[:, :, :res.size(2)]
        return self.relu(out + res)

class TemporalConvNet(nn.Module):
    """Stack of ``TemporalBlock`` layers whose dilation doubles at every level (1, 2, 4, ...).

    Each level pads by ``(kernel_size - 1) * dilation`` and uses stride 1.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=3, dropout=0.2):
        super().__init__()
        blocks = []
        for level, out_ch in enumerate(num_channels):
            in_ch = num_inputs if level == 0 else num_channels[level - 1]
            dilation = 2 ** level
            padding = (kernel_size - 1) * dilation
            blocks.append(TemporalBlock(in_ch, out_ch, kernel_size, 1, dilation, padding, dropout))
        self.network = nn.Sequential(*blocks)

    def forward(self, x):  # x: (B, C, L)
        return self.network(x)

class TCNForecast(nn.Module):
    """TCN encoder followed by a linear head that emits the whole forecast window at once.

    Input ``(B, L, C)``; output ``(B, pred_len, C)`` with ``C == input_dim``.
    """

    def __init__(self, input_dim, num_channels, kernel_size, pred_len, dropout=0.2):
        super().__init__()
        self.tcn = TemporalConvNet(input_dim, num_channels, kernel_size, dropout)
        self.linear = nn.Linear(num_channels[-1], input_dim * pred_len)
        self.input_dim = input_dim
        self.pred_len = pred_len

    def forward(self, x):   # x: (B, L, C)
        features = self.tcn(x.transpose(1, 2))   # (B, C, L) -> (B, H, L)
        last_step = features[:, :, -1]           # (B, H)
        flat = self.linear(last_step)            # (B, C*P)
        return flat.view(-1, self.pred_len, self.input_dim)

# ----------------- Construct Train/Test Data (First 84 Rounds) -----------------
print("Preparing train/test data...")
# The project base directory is taken to be the parent of the AL results directory;
# pool_info.json is read from there.
BASE_DIR = infer_base_dir_from_results(AL_RESULTS_DIR)
pool_info = get_pool_info(BASE_DIR)
# Training-set indices as recorded for iteration ITER_FOR_TRAIN (or the closest earlier snapshot).
train_indices = get_train_indices_from_results(AL_RESULTS_DIR, ITER_FOR_TRAIN)
# NOTE(review): the message hard-codes "84 rounds" while the iteration comes from ITER_FOR_TRAIN.
print(f"Number of training points in first 84 rounds: {len(train_indices)}")

# Reference species: the species list of the first training simulation defines the
# row order that every other matrix is aligned to.
first_train_dir = pool_info[train_indices[0]]['sim_dir']
species_ref = load_species_list(first_train_dir)
if species_ref is None:
    raise FileNotFoundError("Reference species list not found.")
print(f"Number of reference species: {len(species_ref)}")

# Read preprocessing configuration
def read_preprocess_config(results_dir):
    """Return ``(method, smooth_window)`` from ``<results_dir>/config.json``.

    Falls back to ``("none", 1)`` when the file is absent; missing keys use the same
    defaults. ``SMOOTH_WINDOW`` is converted with ``int``.
    """
    method, smooth = "none", 1
    cfgp = os.path.join(results_dir, "config.json")
    if os.path.exists(cfgp):
        with open(cfgp, 'r') as f:
            cfg = json.load(f)
        method = cfg.get("PREPROCESSING_METHOD", "none")
        smooth = int(cfg.get("SMOOTH_WINDOW", 1))
    return method, smooth

# Re-use the preprocessing settings stored with the AL results.
pp_method, pp_smooth = read_preprocess_config(AL_RESULTS_DIR)
print(f"Preprocessing: method={pp_method}, smooth_window={pp_smooth}")
preproc = DataPreprocessor(normalization_method=pp_method, smooth_window=pp_smooth)

# Assemble training matrices: one aligned matrix per training index.
train_mats_raw, train_species = [], []
for idx in train_indices:
    d = pool_info[idx]['sim_dir']
    mat = load_matrix(d); sp = load_species_list(d)
    # Simulations with a missing matrix or species list are skipped without a message.
    if mat is None or sp is None: 
        continue
    al = align_to_reference(mat, sp, species_ref)
    train_mats_raw.append(al); train_species.append(species_ref)
if not train_mats_raw:
    raise RuntimeError("Training matrices are empty.")

# Fit preprocessor on the raw (aligned) training matrices, then transform them
preproc.fit(train_mats_raw, train_species)
train_mats = [preproc.transform(m, species_ref) for m in train_mats_raw]

# Create sliding input/target windows
X_train, Y_train = create_windows(train_mats)
print(f"Training windows: X={X_train.shape}, Y={Y_train.shape}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------------- Train and Evaluate: TCN -----------------
# One input channel per reference species; training/evaluation goes through the shared routine.
tcn = TCNForecast(input_dim=len(species_ref), num_channels=HIDDEN_CHANNELS_TCN, 
                  kernel_size=KERNEL_SIZE_TCN, pred_len=PRED_LEN, dropout=0.2)
_ , df_tcn = train_and_evaluate(tcn, "TCN", X_train, Y_train, species_ref, preproc, device, save_root=SAVE_ROOT)
# Bare expression: shows the metrics table when run as the last statement of a notebook cell.
df_tcn

## 2. Single LSTM

In [None]:
import torch
import torch.nn as nn

class LSTMForecast(nn.Module):
    """Stacked LSTM whose last-step output is mapped linearly to the full forecast window.

    Input ``(B, L, C)``; output ``(B, pred_len, C)`` with ``C == input_dim``.
    """

    def __init__(self, input_dim, hidden_dim=128, num_layers=2, pred_len=150):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim, input_dim * pred_len)
        self.input_dim = input_dim
        self.pred_len = pred_len

    def forward(self, x):  # x: (B, L, C)
        sequence_out, _ = self.lstm(x)          # (B, L, H)
        flat = self.fc(sequence_out[:, -1, :])  # (B, H) -> (B, C*P)
        return flat.view(-1, self.pred_len, self.input_dim)

# LSTM baseline (64 hidden units, 3 layers), trained and evaluated with the shared routine.
lstm = LSTMForecast(input_dim=len(species_ref), hidden_dim=64, num_layers=3, pred_len=PRED_LEN)
_, df_lstm = train_and_evaluate(lstm, "LSTM", X_train, Y_train, species_ref, preproc, device, save_root=SAVE_ROOT)
# Bare expression: shows the metrics table when run as the last statement of a notebook cell.
df_lstm


## 3. Single GRU

In [None]:
import torch
import torch.nn as nn

class GRUForecast(nn.Module):
    """Stacked GRU whose last-step output is mapped linearly to the full forecast window.

    Input ``(B, L, C)``; output ``(B, pred_len, C)`` with ``C == input_dim``.
    """

    def __init__(self, input_dim, hidden_dim=128, num_layers=2, pred_len=150):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim, input_dim * pred_len)
        self.input_dim = input_dim
        self.pred_len = pred_len

    def forward(self, x):  # x: (B, L, C)
        sequence_out, _ = self.gru(x)           # (B, L, H)
        flat = self.fc(sequence_out[:, -1, :])  # (B, H) -> (B, C*P)
        return flat.view(-1, self.pred_len, self.input_dim)

# GRU baseline (64 hidden units, 3 layers), trained and evaluated with the shared routine.
gru = GRUForecast(input_dim=len(species_ref), hidden_dim=64, num_layers=3, pred_len=PRED_LEN)
_, df_gru = train_and_evaluate(gru, "GRU", X_train, Y_train, species_ref, preproc, device, save_root=SAVE_ROOT)
# Bare expression: shows the metrics table when run as the last statement of a notebook cell.
df_gru


## 4. RNN

In [None]:
import torch
import torch.nn as nn

class RNNForecast(nn.Module):
    """Plain (Elman) RNN whose last-step output is mapped linearly to the full forecast window.

    Inter-layer dropout is only passed on when there is more than one layer.
    Input ``(B, L, C)``; output ``(B, pred_len, C)`` with ``C == input_dim``.
    """

    def __init__(self, input_dim, hidden_dim=128, num_layers=2, pred_len=150, nonlinearity='tanh', dropout=0.0):
        super().__init__()
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.rnn = nn.RNN(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            nonlinearity=nonlinearity,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.fc = nn.Linear(hidden_dim, input_dim * pred_len)
        self.input_dim = input_dim
        self.pred_len = pred_len

    def forward(self, x):                        # x: (B, L, C)
        sequence_out, _ = self.rnn(x)            # (B, L, H)
        flat = self.fc(sequence_out[:, -1, :])   # (B, H) -> (B, C*P)
        return flat.view(-1, self.pred_len, self.input_dim)

# Vanilla tanh-RNN baseline (64 hidden units, 3 layers, no dropout), trained and evaluated with the shared routine.
rnn_tanh = RNNForecast(input_dim=len(species_ref), hidden_dim=64, num_layers=3, pred_len=PRED_LEN, nonlinearity='tanh', dropout=0.0)
_, df_rnn = train_and_evaluate(rnn_tanh, "RNN_Tanh", X_train, Y_train, species_ref, preproc, device, save_root=SAVE_ROOT)
# Bare expression: shows the metrics table when run as the last statement of a notebook cell.
df_rnn


## 5. BiLSTM

In [None]:
import torch
import torch.nn as nn

class BiLSTMForecast(nn.Module):
    """Bidirectional LSTM whose final forward and backward states feed a linear forecast head.

    Input ``(B, L, C)``; output ``(B, pred_len, C)`` with ``C == input_dim``.
    Inter-layer dropout is only passed on when there is more than one layer.
    """

    def __init__(self, input_dim, hidden_dim=128, num_layers=2, pred_len=150, dropout=0.1):
        super().__init__()
        self.bilstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim,
                              num_layers=num_layers, batch_first=True,
                              bidirectional=True, dropout=dropout if num_layers>1 else 0.0)
        self.fc = nn.Linear(2 * hidden_dim, input_dim * pred_len)
        self.input_dim = input_dim
        self.pred_len = pred_len

    def forward(self, x):                 # x: (B, L, C)
        # Use h_n rather than out[:, -1, :]: at the last time index the reverse direction
        # has processed only one input, so its summary of the window lives at the *first*
        # time index. h_n holds the final state of each direction; the last two entries
        # belong to the top layer (forward, then backward).
        _, (h_n, _) = self.bilstm(x)                       # h_n: (2*layers, B, H)
        summary = torch.cat((h_n[-2], h_n[-1]), dim=1)     # -> (B, 2H)
        y = self.fc(summary)                               # -> (B, C*P)
        return y.view(-1, self.pred_len, self.input_dim)

# Bidirectional-LSTM baseline (64 hidden units per direction, 2 layers), trained and evaluated with the shared routine.
bilstm = BiLSTMForecast(input_dim=len(species_ref), hidden_dim=64, num_layers=2, pred_len=PRED_LEN, dropout=0.1)
_, df_bilstm = train_and_evaluate(bilstm, "BiLSTM", X_train, Y_train, species_ref, preproc, device, save_root=SAVE_ROOT)
# Bare expression: shows the metrics table when run as the last statement of a notebook cell.
df_bilstm


## 6. Transformer Encoder

In [None]:
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a ``(B, L, d_model)`` input.

    Even feature indices carry sines, odd indices cosines. The table is stored as a
    non-trainable buffer ``pe`` of shape ``(1, max_len, d_model)``; inputs must not be
    longer than ``max_len``. Odd ``d_model`` is supported (one fewer cosine column).
    """

    def __init__(self, d_model, max_len=2048):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # For odd d_model there are ceil(d/2) sine columns but only floor(d/2) cosine
        # columns, so trim the frequency vector to match the target slice.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, L, d)

    def forward(self, x):  # x: (B, L, d)
        return x + self.pe[:, :x.size(1), :]

class TransformerForecast(nn.Module):
    """Transformer-encoder forecaster: project, add positions, encode, read out the last token.

    Input ``(B, L, C)``; output ``(B, pred_len, C)`` with ``C == input_dim``.
    """

    def __init__(self, input_dim, d_model=128, nhead=8, num_layers=3, dim_feedforward=256, pred_len=150, dropout=0.1):
        super().__init__()
        self.inp_proj = nn.Linear(input_dim, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.pos = PositionalEncoding(d_model)
        self.fc = nn.Linear(d_model, input_dim * pred_len)
        self.input_dim = input_dim
        self.pred_len = pred_len

    def forward(self, x):   # x: (B, L, C)
        tokens = self.pos(self.inp_proj(x))   # (B, L, d) with positions added
        encoded = self.encoder(tokens)        # (B, L, d)
        flat = self.fc(encoded[:, -1, :])     # (B, d) -> (B, C*P)
        return flat.view(-1, self.pred_len, self.input_dim)

# Transformer-encoder baseline (d_model=64, 8 heads, 3 layers), trained and evaluated with the shared routine.
tfm = TransformerForecast(input_dim=len(species_ref), d_model=64, nhead=8, num_layers=3, dim_feedforward=128, pred_len=PRED_LEN, dropout=0.1)
_, df_tfm = train_and_evaluate(tfm, "Transformer", X_train, Y_train, species_ref, preproc, device, save_root=SAVE_ROOT)
# Bare expression: shows the metrics table when run as the last statement of a notebook cell.
df_tfm


## 7. CNN-LSTM

In [None]:
import torch
import torch.nn as nn

class CNNLSTMForecast(nn.Module):
    """Two Conv1d+ReLU+Dropout stages followed by an LSTM and a linear forecast head.

    The convolutions use padding ``kernel_size // 2``. The LSTM output at the last
    time step is mapped to ``input_dim * pred_len`` values.
    Input ``(B, L, C)``; output ``(B, pred_len, C)`` with ``C == input_dim``.
    """

    def __init__(self, input_dim, conv_channels=128, kernel_size=5,
                 lstm_hidden=128, lstm_layers=1, pred_len=150, dropout=0.1):
        super().__init__()
        half_kernel = kernel_size // 2
        conv_stack = [
            nn.Conv1d(input_dim, conv_channels, kernel_size, padding=half_kernel),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(conv_channels, conv_channels, kernel_size, padding=half_kernel),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        self.conv = nn.Sequential(*conv_stack)
        self.lstm = nn.LSTM(
            input_size=conv_channels,
            hidden_size=lstm_hidden,
            num_layers=lstm_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(lstm_hidden, input_dim * pred_len)
        self.input_dim = input_dim
        self.pred_len = pred_len

    def forward(self, x):                                  # x: (B, L, C)
        features = self.conv(x.transpose(1, 2))            # (B, C, L) -> (B, conv_channels, L')
        encoded, _ = self.lstm(features.transpose(1, 2))   # (B, L', H)
        flat = self.fc(encoded[:, -1, :])                  # (B, H) -> (B, C*P)
        return flat.view(-1, self.pred_len, self.input_dim)

# CNN-LSTM baseline (128 conv channels, kernel 5, single 128-unit LSTM layer),
# trained and evaluated with the shared routine.
cnnlstm = CNNLSTMForecast(
    input_dim=len(species_ref), conv_channels=128, kernel_size=5,
    lstm_hidden=128, lstm_layers=1, pred_len=PRED_LEN, dropout=0.1
)
_, df_cnnlstm = train_and_evaluate(cnnlstm, "CNN_LSTM", X_train, Y_train, species_ref, preproc, device, save_root=SAVE_ROOT)
# Bare expression: shows the metrics table when run as the last statement of a notebook cell.
df_cnnlstm


## 8. Seq2Seq LSTM

In [None]:
import torch
import torch.nn as nn

class Seq2SeqLSTM(nn.Module):
    """LSTM encoder with an autoregressive single-cell LSTM decoder.

    The encoder reads the whole input window. The decoder starts from the
    encoder's top-layer state and the last observed frame, then feeds each of
    its own predictions back in as the next input for ``pred_len`` steps.

    Args:
        input_dim: Channels per time step of the input (and of the forecast).
        hid: Hidden size shared by the encoder and the decoder cell.
        layers: Number of stacked encoder layers.
        pred_len: Number of decoding steps.
    """

    def __init__(self, input_dim, hid=128, layers=2, pred_len=150):
        super().__init__()
        self.pred_len = pred_len
        self.input_dim = input_dim
        self.encoder = nn.LSTM(input_size=input_dim, hidden_size=hid, num_layers=layers, batch_first=True)
        self.decoder_cell = nn.LSTMCell(input_size=input_dim, hidden_size=hid)
        self.proj = nn.Linear(hid, input_dim)

    def forward(self, x):
        """Map ``x`` of shape (B, L, C) to a forecast of shape (B, pred_len, C)."""
        _, (h_n, c_n) = self.encoder(x)       # h_n, c_n: (layers, B, H)
        hidden, cell = h_n[-1], c_n[-1]       # top encoder layer seeds the decoder
        prev_frame = x[:, -1, :]              # last observation is the first decoder input
        frames = []
        for _ in range(self.pred_len):
            hidden, cell = self.decoder_cell(prev_frame, (hidden, cell))
            prev_frame = self.proj(hidden)    # (B, C); fed back in on the next step
            frames.append(prev_frame)
        return torch.stack(frames, dim=1)     # (B, P, C)

# Build the Seq2Seq baseline (hid=64, layers=3 rather than the class defaults)
# and hand it to train_and_evaluate; only the second return value is kept.
s2s = Seq2SeqLSTM(input_dim=len(species_ref), hid=64, layers=3, pred_len=PRED_LEN)
_, df_s2s = train_and_evaluate(s2s, "Seq2Seq_LSTM", X_train, Y_train, species_ref, preproc, device, save_root=SAVE_ROOT)
# Bare expression: shown as the notebook cell's output.
df_s2s


## 9. ResMLP

In [None]:
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """Width-preserving residual MLP block: ``ReLU(x + net(x))``.

    Args:
        d: Feature width of both the input and the output.
        p: Dropout probability used between the two linear layers.
    """

    def __init__(self, d, p=0.1):
        super().__init__()
        inner_layers = [
            nn.Linear(d, d),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(d, d),
        ]
        self.net = nn.Sequential(*inner_layers)
        self.act = nn.ReLU()

    def forward(self, x):
        """Add the inner network's output to ``x`` and apply the final ReLU."""
        residual = self.net(x)
        return self.act(x + residual)

class ResMLP(nn.Module):
    """Residual MLP that forecasts from the flattened input window.

    The (L, C) window is flattened to one vector, passed through a stem layer
    and ``depth`` residual blocks, then projected to ``pred_len * input_dim``
    values that are reshaped to (pred_len, input_dim).

    Args:
        input_dim: Channels per time step of the input (and of the forecast).
        pred_len: Number of future steps produced per sample.
        width: Hidden width of the stem and of every residual block.
        depth: Number of residual blocks.
        p: Dropout probability used in the stem and in every block.
        input_len: Length L of the input window. ``None`` (the default) falls
            back to the module-level ``INPUT_LEN``, which is what the class
            always used before this parameter existed.
    """

    def __init__(self, input_dim, pred_len=150, width=512, depth=3, p=0.1, input_len=None):
        super().__init__()
        # The first Linear layer is sized from the window length, so make that
        # dependency an explicit argument instead of a hidden global read.
        if input_len is None:
            input_len = INPUT_LEN
        self.input_dim = input_dim; self.pred_len = pred_len
        self.input_len = input_len
        d_in  = input_len * input_dim
        d_out = pred_len * input_dim
        self.stem = nn.Sequential(nn.Linear(d_in, width), nn.ReLU(), nn.Dropout(p))
        self.blocks = nn.Sequential(*[ResBlock(width, p=p) for _ in range(depth)])
        self.head = nn.Linear(width, d_out)

    def forward(self, x):          # x: (B, L, C) with L == input_len
        """Map ``x`` of shape (B, input_len, C) to a (B, pred_len, C) forecast."""
        z = x.reshape(x.size(0), -1)   # flatten time and channels together
        z = self.stem(z)
        z = self.blocks(z)
        y = self.head(z)
        return y.view(-1, self.pred_len, self.input_dim)

# Build the ResMLP baseline (width=256 rather than the class default of 512)
# and hand it to train_and_evaluate; only the second return value is kept.
resmlp = ResMLP(input_dim=len(species_ref), pred_len=PRED_LEN, width=256, depth=3, p=0.1)
_, df_resmlp = train_and_evaluate(resmlp, "ResMLP", X_train, Y_train, species_ref, preproc, device, save_root=SAVE_ROOT)
# Bare expression: shown as the notebook cell's output.
df_resmlp


In [None]:
# === Scatter plot (robust version: rows are grouped by canonical model name, so RNN variants are not dropped) ===
import numpy as np, pandas as pd, matplotlib.pyplot as plt, os, re
from matplotlib.cm import get_cmap

# Directory that holds the evaluation CSVs read below and receives the figures.
VIZ_DIR   = "./baseline_eval/_viz"
TCN_NAME  = "TCN"

# -------- Tunable plot style (shared by the scatter plots) --------
STYLE = {
    "figsize": (10, 8),        # figure size
    "label_fontsize": 20,      # axis-label font size
    "tick_fontsize": 16,       # tick-label font size
    "legend_fontsize": 16,     # legend font size
    "spine_linewidth": 2.5,
    "spine_color": "black",
    "tick_length": 4.0,
    "tick_width": 1.3,
    "axis_labelpad": 6,        # gap between an axis label and its axis
    "grid": True,              # whether to draw the grid
    "dpi": 300,                # figure dpi
    "legend_frameon": False,   # legend without a frame
    "tcn_color": "tab:blue",   # TCN colour (fixed)
    "baseline_marker": "o",    # marker for every other model
    "tcn_marker": "D",         # marker for TCN
    "alpha": 0.75,             # point opacity
    "tcn_size": 90,            # TCN point size
    "baseline_size": 52,       # point size for every other model
}

# Fixed legend order (canonical model names)
LEGEND_ORDER_CANONICAL = [
    "TCN", "RNN", "LSTM", "BiLSTM", "CNN_LSTM", "Seq2Seq_LSTM", "GRU", "Transformer", "ResMLP"
]

# Name normalisation: case-insensitive, hyphen / space / underscore are
# interchangeable, and common aliases collapse onto one canonical name.
ALIASES_LOWER = {
    "tcn": "TCN",
    "rnn": "RNN", "rnn_tanh": "RNN", "rnn-tanh": "RNN", "rnn tanh": "RNN", "vanilla rnn": "RNN",
    "lstm": "LSTM",
    "bilstm": "BiLSTM", "bi-lstm": "BiLSTM", "bi lstm": "BiLSTM", "bi_lstm": "BiLSTM",
    "cnn_lstm": "CNN_LSTM", "cnn-lstm": "CNN_LSTM", "cnn lstm": "CNN_LSTM",
    "seq2seq_lstm": "Seq2Seq_LSTM", "seq2seq-lstm": "Seq2Seq_LSTM", "seq2seq lstm": "Seq2Seq_LSTM",
    "gru": "GRU",
    "transformer": "Transformer",
    "resmlp": "ResMLP",
}

# Any run of whitespace, underscores or hyphens counts as one separator.
_SEPARATOR_RUN = re.compile(r"[\s_\-]+")

# Separator-insensitive view of ALIASES_LOWER: every key is rewritten with "_"
# as its only separator, so one lookup covers all spellings of an alias.
_ALIAS_LOOKUP = {
    _SEPARATOR_RUN.sub("_", alias): canonical
    for alias, canonical in ALIASES_LOWER.items()
}


def to_canonical(name: str) -> str:
    """Map any spelling of a model name to its canonical name.

    Matching ignores case, surrounding whitespace and punctuation, and treats
    any run of spaces, underscores and hyphens as a single separator, e.g.
    ``"RNN_tanh"``, ``"rnn-tanh"`` and ``"Vanilla_RNN"`` all map to ``"RNN"``.

    Args:
        name: Model name as it appears in the data.

    Returns:
        The canonical name when an alias matches. Otherwise the input with
        surrounding whitespace stripped and nothing else changed. Non-string
        input is returned as ``str(name)``.
    """
    if not isinstance(name, str):
        return str(name)
    key = re.sub(r"[^\w\s\-]+", "", name.strip().lower())   # drop stray punctuation
    key = _SEPARATOR_RUN.sub("_", key).strip("_")           # unify separators
    # Unknown names are kept as given; only the outer whitespace is removed.
    return _ALIAS_LOOKUP.get(key, name.strip())

def build_color_map(canonical_list, tcn_color="tab:blue", cmap_name="tab20"):
    """Assign one colour to each canonical model name.

    "TCN" (when present) always gets ``tcn_color``. Every other name takes the
    next entry of the ``cmap_name`` palette in list order, wrapping around when
    the palette runs out. Palette entries equal to ``tcn_color`` are skipped so
    that no baseline shares TCN's colour (entry 0 of ``tab20`` is the same blue
    as ``"tab:blue"``).

    Args:
        canonical_list: Canonical model names, in plotting order.
        tcn_color: Colour reserved for "TCN".
        cmap_name: Name of the matplotlib colormap used as the palette.

    Returns:
        Dict mapping each name in ``canonical_list`` to a colour.
    """
    # Local import: only the collision check needs it.
    from matplotlib.colors import to_rgba

    cmap = get_cmap(cmap_name)
    palette = [cmap(i) for i in range(cmap.N)]
    color_map = {}

    # Reserve the fixed TCN colour first and remove it from the palette.
    if "TCN" in canonical_list:
        color_map["TCN"] = tcn_color
        reserved = to_rgba(tcn_color)
        distinct = [
            rgba for rgba in palette
            if not all(abs(a - b) < 1e-6 for a, b in zip(rgba, reserved))
        ]
        # A palette made only of the reserved colour is kept as-is rather than
        # emptied, so the loop below always has something to draw from.
        palette = distinct or palette

    idx = 0
    for cn in canonical_list:
        if cn in color_map:      # TCN, or a name that was already assigned
            continue
        color_map[cn] = palette[idx % len(palette)]
        idx += 1
    return color_map

# Per-test-case metrics for all models. scatter_xy reads this module-level
# `df` at call time, and the benchmark-table cell further down rebinds `df`.
df = pd.read_csv(os.path.join(VIZ_DIR, "all_models_by_test.csv"))

def scatter_xy(xm, ym, xlog=True, ylog=False, save_name="scatter.png"):
    """Scatter one metric against another, one colour per canonical model.

    Reads the module-level ``df`` and groups its rows by
    ``to_canonical(Model)`` so that aliases of one model are drawn together.
    Styling comes from ``STYLE``, and the legend follows
    ``LEGEND_ORDER_CANONICAL`` with any unlisted names appended. The figure is
    saved to ``VIZ_DIR/save_name`` and then shown.

    Args:
        xm: Column of ``df`` plotted on the x axis.
        ym: Column of ``df`` plotted on the y axis.
        xlog: Use a logarithmic x axis.
        ylog: Use a logarithmic y axis.
        save_name: File name of the saved figure inside ``VIZ_DIR``.
    """
    _, ax = plt.subplots(figsize=STYLE["figsize"], dpi=STYLE["dpi"])

    # 1) Keep complete rows only and tag each with its canonical model name.
    points = df[["Model", xm, ym]].dropna().copy()
    points["Canon"] = points["Model"].apply(to_canonical)

    # 2) Names actually present: fixed order first, then the rest alphabetically.
    present = set(points["Canon"].unique())
    draw_order = [name for name in LEGEND_ORDER_CANONICAL if name in present]
    draw_order.extend(sorted(present.difference(draw_order)))

    # 3) One colour per canonical name.
    colours = build_color_map(draw_order, tcn_color=STYLE["tcn_color"])

    # 4) One scatter call per canonical name (all of its rows at once).
    drawn_labels = []
    handle_of = {}
    for canon in draw_order:
        rows = points.loc[points["Canon"] == canon, [xm, ym]]
        if rows.empty:
            continue
        if canon == "TCN":
            size, marker = STYLE["tcn_size"], STYLE["tcn_marker"]
        else:
            size, marker = STYLE["baseline_size"], STYLE["baseline_marker"]
        handle_of[canon] = ax.scatter(
            rows[xm], rows[ym],
            s=size,
            alpha=STYLE["alpha"],
            label=canon,
            marker=marker,
            color=colours.get(canon, "tab:gray"),
            edgecolor="none",
        )
        drawn_labels.append(canon)

    # 5) Axis labels and scales.
    ax.set_xlabel(xm, fontsize=STYLE["label_fontsize"], labelpad=STYLE["axis_labelpad"])
    ax.set_ylabel(ym, fontsize=STYLE["label_fontsize"], labelpad=STYLE["axis_labelpad"])
    if xlog:
        ax.set_xscale("log")
    if ylog:
        ax.set_yscale("log")

    # Ticks and spines.
    ax.tick_params(
        axis="both", which="both",
        labelsize=STYLE["tick_fontsize"],
        length=STYLE["tick_length"],
        width=STYLE["tick_width"],
    )
    for side in ("top", "right", "left", "bottom"):
        edge = ax.spines[side]
        edge.set_linewidth(STYLE["spine_linewidth"])
        edge.set_color(STYLE["spine_color"])

    # Grid.
    if STYLE["grid"]:
        ax.grid(ls="--", alpha=0.3)

    # 6) Legend strictly in the fixed order, restricted to what was drawn.
    legend_labels = [name for name in LEGEND_ORDER_CANONICAL if name in handle_of]
    legend_labels += [name for name in drawn_labels if name not in LEGEND_ORDER_CANONICAL]
    ax.legend(
        [handle_of[name] for name in legend_labels], legend_labels,
        loc="best",
        frameon=STYLE["legend_frameon"],
        fontsize=STYLE["legend_fontsize"],
        ncol=2,
    )

    plt.tight_layout()
    out_path = os.path.join(VIZ_DIR, save_name)
    plt.savefig(out_path, bbox_inches="tight", dpi=STYLE["dpi"])
    plt.show()

# Example call: MSE on a logarithmic x axis against DTW on a linear y axis.
scatter_xy("MSE", "DTW", xlog=True, ylog=False, save_name="scatter_MSE_vs_DTW.png")
# scatter_xy("Huber", "DTW", xlog=True, ylog=True, save_name="scatter_Huber_vs_DTW.png")


In [None]:
# === Cell A3: Architectural benchmarking table (averages) =====================
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt

# Input: per-model averages. Outputs: the formatted table as CSV and as PNG.
VIZ_DIR   = "./baseline_eval/_viz"
AVG_CSV   = os.path.join(VIZ_DIR, "all_models_averages.csv")
OUT_CSV   = os.path.join(VIZ_DIR, "TableX_architectural_benchmark.csv")
OUT_PNG   = os.path.join(VIZ_DIR, "TableX_architectural_benchmark.png")

# Metric column order (kept consistent with the text)
METRICS_ORDER = ["Huber","MSE","RMSE","MAE","WAPE (%)","DTW","pearson","R2",
                 "short_MSE","long_MSE","short_MAE","long_MAE","short_WAPE","long_WAPE"]

# NOTE: this rebinds the module-level `df`, which scatter_xy also reads.
df = pd.read_csv(AVG_CSV)

# Keep "Model" plus whichever metric columns exist in the CSV, in
# METRICS_ORDER order; every other column is dropped.
df = df[["Model"] + [c for c in METRICS_ORDER if c in df.columns]].copy()

# Uniform number formatting
def fmt(v, col):
    """Format one table cell according to its metric column.

    Args:
        v: Cell value; anything ``float()`` accepts is formatted.
        col: Name of the column the value belongs to.

    Returns:
        ``v`` unchanged when it cannot be converted to a float. Otherwise a
        string with three decimals for "pearson" and "R2", one decimal for the
        WAPE columns, and scientific notation with three decimals for every
        other column.
    """
    try:
        x = float(v)
    except (TypeError, ValueError, OverflowError):
        # Non-numeric cells (e.g. text or None) pass through untouched.
        return v
    if col in ("pearson", "R2"):
        return f"{x:.3f}"
    if col in ("WAPE (%)", "short_WAPE", "long_WAPE"):
        return f"{x:.1f}"
    return f"{x:.3e}"

# Format every metric column that is present (cells become strings).
for c in METRICS_ORDER:
    if c in df.columns:
        df[c] = df[c].apply(lambda v: fmt(v, c))

# Save the formatted table as CSV
df.to_csv(OUT_CSV, index=False)

# Draw the table (no numbers on top, Arial-like sans-serif font)
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans', 'Liberation Sans', 'Helvetica']
fig, ax = plt.subplots(figsize=(16, 9), dpi=300)
ax.axis('tight'); ax.axis('off')
table = ax.table(cellText=df.values, colLabels=df.columns, cellLoc='center', loc='center')
table.auto_set_font_size(False); table.set_fontsize(9); table.scale(1.2, 1.5)

# Header row (row 0) gets a background colour and bold white text
for i in range(len(df.columns)):
    table[(0, i)].set_facecolor('#4CAF50'); table[(0, i)].set_text_props(weight='bold', color='white')
# Zebra striping. With colLabels the header is row 0 and the data rows are
# 1..len(df), so the range must include len(df) or the last row is never
# visited and loses its shading when the number of rows is even.
for i in range(1, len(df) + 1):
    for j in range(len(df.columns)):
        if i % 2 == 0:
            table[(i, j)].set_facecolor('#F5F5F5')

plt.tight_layout()
plt.savefig(OUT_PNG, bbox_inches="tight", dpi=300)
plt.show()

print(f"✓ 保存完成：\n  - {OUT_CSV}\n  - {OUT_PNG}")


In [None]:
# === Cell 5: Radar chart (normalized multi-metric) ============================
import numpy as np, pandas as pd, matplotlib.pyplot as plt, os, math
import re # 导入 re 模块用于名称规范化

# ====================================================================
# [Configurable parameters]
# ====================================================================

# Figure output settings
PLOT_DPI = 300
FIGURE_SIZE = (8, 8)
FONT_SIZE_METRIC = 15
FONT_SIZE_RADIAL = 12
# Legend font size
FONT_SIZE_LEGEND = 12 # <--- adjust the model-legend font size here

# Line widths for the model curves and the frame
MODEL_LINEWIDTH = 2.5 # line width of each model's curve
SPINE_LINEWIDTH = 2.5 # line width of the outermost frame

# Legend placement
LEGEND_LOC = "upper right"
LEGEND_BBOX_TO_ANCHOR = (1.25, 1.10) 
LEGEND_RIGHT_ADJUST_FACTOR = 0.05 

# Radial ticks (the 0.2, 0.4, ... number labels)
RADIAL_TICKS = [0.2, 0.4, 0.6, 0.8]
R_LABEL_POSITION_DEG = 22.5 
R_LABEL_PAD = 10

# Title settings (the title call below is commented out; the options are kept)
PLOT_TITLE = "Radar (normalized)"
TITLE_PAD = 20

# Label-rotation tuning: labels within this many degrees of the horizontal
# axis are drawn unrotated (intended to keep the WAPE labels from overlapping)
ROTATION_THRESHOLD_DEG = 5 

# ====================================================================
# --- 1. Data preparation and settings ---
# ====================================================================
# Fixed alias table used to normalise model names: case-insensitive,
# hyphen / space / underscore are interchangeable, and common aliases collapse
# onto one canonical name.
ALIASES_LOWER = {
    "tcn": "TCN",
    "rnn": "RNN", "rnn_tanh": "RNN", "rnn-tanh": "RNN", "rnn tanh": "RNN", "vanilla rnn": "RNN",
    "lstm": "LSTM",
    "bilstm": "BiLSTM", "bi-lstm": "BiLSTM", "bi lstm": "BiLSTM", "bi_lstm": "BiLSTM",
    "cnn_lstm": "CNN_LSTM", "cnn-lstm": "CNN_LSTM", "cnn lstm": "CNN_LSTM",
    "seq2seq_lstm": "Seq2Seq_LSTM", "seq2seq-lstm": "Seq2Seq_LSTM", "seq2seq lstm": "Seq2Seq_LSTM",
    "gru": "GRU",
    "transformer": "Transformer",
    "resmlp": "ResMLP",
}

# Any run of whitespace, underscores or hyphens counts as one separator.
_SEPARATOR_RUN = re.compile(r"[\s_\-]+")

# Separator-insensitive view of ALIASES_LOWER: every key is rewritten with "_"
# as its only separator, so one lookup covers all spellings of an alias.
_ALIAS_LOOKUP = {
    _SEPARATOR_RUN.sub("_", alias): canonical
    for alias, canonical in ALIASES_LOWER.items()
}


def to_canonical(name: str) -> str:
    """Map any spelling of a model name to its canonical name.

    Matching ignores case, surrounding whitespace and punctuation, and treats
    any run of spaces, underscores and hyphens as a single separator, e.g.
    ``"RNN_tanh"``, ``"rnn-tanh"`` and ``"Vanilla_RNN"`` all map to ``"RNN"``.

    Args:
        name: Model name as it appears in the data.

    Returns:
        The canonical name when an alias matches. Otherwise the input with
        surrounding whitespace stripped and nothing else changed. Non-string
        input is returned as ``str(name)``.
    """
    if not isinstance(name, str):
        return str(name)
    key = re.sub(r"[^\w\s\-]+", "", name.strip().lower())   # drop stray punctuation
    key = _SEPARATOR_RUN.sub("_", key).strip("_")           # unify separators
    # Unknown names are kept as given; only the outer whitespace is removed.
    return _ALIAS_LOOKUP.get(key, name.strip())


# Assumes VIZ_DIR, METRICS_ORDER, MIN_BETTER and TCN_NAME were defined by
# earlier cells.

avg = pd.read_csv(os.path.join(VIZ_DIR, "all_models_averages.csv"))

# Normalise the model names with to_canonical.
# NOTE(review): if two rows map to the same canonical name, the reindex below
# would see duplicate labels - confirm the CSV has one row per model.
avg["Model"] = avg["Model"].apply(to_canonical) 

# Every metric in METRICS_ORDER must be a column of the CSV here (a missing
# one raises KeyError).
avg = avg[["Model"] + METRICS_ORDER].copy()

# 1. Required plotting / legend order:
REQUIRED_ORDER = [
    "TCN", "RNN", "LSTM", "BiLSTM", "CNN_LSTM", 
    "Seq2Seq_LSTM", "GRU", "Transformer", "ResMLP"
]

# Models to draw. None = automatic: every REQUIRED_ORDER model present in the
# CSV, followed by any of the 10 lowest-MSE models not already listed.
SELECTED_MODELS_AUTO = None 

if SELECTED_MODELS_AUTO is None:
    tmp = avg[["Model","MSE"]].dropna().sort_values("MSE")
    auto = list(tmp["Model"].values[:10])
    
    # Keep only models that exist in avg, with REQUIRED_ORDER taking priority
    all_models_in_avg = set(avg["Model"].values)
    
    SELECTED_MODELS = [m for m in REQUIRED_ORDER if m in all_models_in_avg]
    for m in auto:
        if m not in SELECTED_MODELS and m in all_models_in_avg:
            SELECTED_MODELS.append(m)
            
    # Fallback when nothing was selected: the 3 lowest-MSE models
    if not SELECTED_MODELS:
        SELECTED_MODELS = list(tmp["Model"].values[:3])

    print("最终用于绘图的模型列表（按指定顺序）：", SELECTED_MODELS)
else:
    # A manually specified list is canonicalised as well
    SELECTED_MODELS = [to_canonical(m) for m in SELECTED_MODELS_AUTO]

# Keep only the selected models, indexed by name in SELECTED_MODELS order
sub = avg[avg["Model"].isin(SELECTED_MODELS)].set_index("Model")
sub = sub.reindex(SELECTED_MODELS) # enforce the order

# Normalisation: error-type metrics => smaller values map closer to 1;
# correlation-type metrics => larger values map closer to 1
def normalize_series(vals, minimize=True):
    """Min-max scale ``vals`` to [0, 1] so that 1 is always the best value.

    Args:
        vals: Sequence of numbers; NaN entries are ignored when finding the
            range and stay NaN in the result.
        minimize: True when smaller raw values are better (the scale is then
            flipped), False when larger raw values are better.

    Returns:
        Float array with the same shape as ``vals``. Every entry is 0.5 when
        the range is unusable: all NaN, a non-finite minimum or maximum, or
        all values equal.
    """
    arr = np.array(vals, dtype=float)
    neutral = np.full(arr.shape, 0.5)

    if np.all(np.isnan(arr)):
        return neutral

    lo = np.nanmin(arr)
    hi = np.nanmax(arr)
    usable = math.isfinite(lo) and math.isfinite(hi) and lo != hi
    if not usable:
        return neutral

    scaled = (arr - lo) / (hi - lo)
    if minimize:
        return 1.0 - scaled
    return scaled

# One normalised array per metric, aligned with the row order of `sub`.
# Metrics listed in MIN_BETTER are flipped so that 1 always means "best".
normed = {}
for met in METRICS_ORDER:
    minimize = met in MIN_BETTER 
    normed[met] = normalize_series(sub[met].values, minimize=minimize)

# ====================================================================
# --- 2. Plotting (applies the configuration above) ---
# ====================================================================

# Assemble the radar data (in METRICS_ORDER order): one angle per metric
metrics_count = len(METRICS_ORDER)
angles = np.linspace(0, 2*np.pi, metrics_count, endpoint=False).tolist()
angles += angles[:1]  # repeat the first angle to close the polygon
# The on-screen figure uses dpi=160; the saved file uses PLOT_DPI (see savefig).
fig, ax = plt.subplots(figsize=FIGURE_SIZE, subplot_kw=dict(polar=True), dpi=160) 

# Draw one closed polygon per model
for i, (model, row) in enumerate(sub.iterrows()):
    vals = [normed[met][i] for met in METRICS_ORDER] 
    vals += vals[:1]
    ax.plot(angles, vals, label=model, linewidth=MODEL_LINEWIDTH, marker='o') 
    ax.fill(angles, vals, alpha=0.1)

# Outermost frame: black, with the configured line width
for spine in ax.spines.values():
    spine.set_linewidth(SPINE_LINEWIDTH)
    spine.set_color('black') 

# Angular ticks (one per metric name)
ax.set_xticks(angles[:-1])

# Adjust the rotation and alignment of each metric label individually.
# set_xticklabels returns the label objects.
# NOTE(review): matplotlib's polar theta ticks may reset label rotation at draw
# time; confirm these per-label rotations survive in the saved figure.
labels = ax.set_xticklabels(METRICS_ORDER, fontsize=FONT_SIZE_METRIC)

# Walk over the labels and adjust each one according to its angle
for label, angle in zip(labels, angles[:-1]):
    # Radians to degrees
    deg = angle * 180 / np.pi
    
    # Choose the rotation
    if deg > 90 and deg < 270:
        # Labels on the left half of the circle
        rotation = deg + 90
        align = 'right'
    else:
        # Labels on the right half of the circle
        rotation = deg - 90
        align = 'left'

    # Refinement: labels close to the horizontal axis (0 or 180 degrees) are
    # left unrotated, using the configured threshold
    if deg < ROTATION_THRESHOLD_DEG or deg > (360 - ROTATION_THRESHOLD_DEG): # near the right-hand horizontal (0 degrees)
        rotation = 0
        align = 'left'
    elif deg > (180 - ROTATION_THRESHOLD_DEG) and deg < (180 + ROTATION_THRESHOLD_DEG): # near the left-hand horizontal (180 degrees)
        rotation = 0
        align = 'right'
        
    label.set_rotation(rotation)
    label.set_horizontalalignment(align)
    label.set_verticalalignment('bottom') 
    
# Position the radial tick labels (0.2, 0.4, ...) away from the chart centre
ax.set_yticks(RADIAL_TICKS)
ax.set_yticklabels([str(t) for t in RADIAL_TICKS], fontsize=FONT_SIZE_RADIAL)
ax.set_rlabel_position(R_LABEL_POSITION_DEG) 
ax.tick_params(axis='y', pad=R_LABEL_PAD)

# # Set the title (if needed)
# if PLOT_TITLE:
#     ax.set_title(PLOT_TITLE, pad=TITLE_PAD)
    
# Legend position and font size
ax.legend(loc=LEGEND_LOC, bbox_to_anchor=LEGEND_BBOX_TO_ANCHOR, fontsize=FONT_SIZE_LEGEND)

plt.tight_layout()
# Use LEGEND_RIGHT_ADJUST_FACTOR to move the right-hand subplot boundary
plt.subplots_adjust(right=LEGEND_BBOX_TO_ANCHOR[0] + LEGEND_RIGHT_ADJUST_FACTOR) 

plt.savefig(os.path.join(VIZ_DIR, "radar_normalized.png"), bbox_inches="tight", dpi=PLOT_DPI)
plt.show()