In [1]:
#1 Core Configuration Manager / 核心配置管理器
import os
import json
import logging
import logging.handlers
from datetime import datetime
from collections import deque
from IPython.display import clear_output
from pathlib import Path
from typing import Dict, Any, Optional
from cell3_monitor import (
    DailyRotatingFileHandler, 
    CustomFormatter,
    ProgressHandler, 
    LogDisplayManager
)
import numpy as np  # 确保所有使用np的地方都有导入

class CoreManager:
    """核心管理器 - 整合配置和日志管理"""
    _instance = None
    
    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
    
    def __init__(self):
        if not hasattr(self, 'initialized'):
            # 初始化基础目录结构
            self._init_directories()
            # 初始化日志系统
            self._init_logging_system()
            # 初始化配置系统（需要最先完成）
            self._init_config_system()
            
            # 添加配置验证规则
            self.validation_rules = {
                'learning_rate': lambda x: 0 < x < 1,
                'batch_size': lambda x: x > 0 and x & (x-1) == 0
            }
            
            # 初始化时执行配置测试
            self._test_config_system()
            
            # 最后设置初始化完成标记
            self.initialized = True
            self.logger.info("核心管理器初始化完成")

    def _init_directories(self):
        """初始化目录结构"""
        # 基础路径配置
        self.BASE_DIR = 'D:\\JupyterWork'
        self.LOG_DIR = os.path.join(self.BASE_DIR, 'logs')
        self.MODEL_DIR = os.path.join(self.BASE_DIR, 'models')
        self.DATA_DIR = os.path.join(self.BASE_DIR, 'data')
        self.CHECKPOINT_DIR = os.path.join(self.BASE_DIR, 'checkpoints')
        
        # 创建必要的目录
        for dir_path in [self.LOG_DIR, self.MODEL_DIR, self.DATA_DIR, self.CHECKPOINT_DIR]:
            os.makedirs(dir_path, exist_ok=True)

    def _init_logging_system(self):
        """初始化日志系统"""
        # 创建主日志处理器
        self.logger = logging.getLogger('ACE_System')
        self.logger.setLevel(logging.INFO)
        
        # 日常日志处理器
        daily_handler = self._create_daily_rotating_handler()
        self.logger.addHandler(daily_handler)
        
        # 控制台处理器
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(self._create_custom_formatter())
        self.logger.addHandler(console_handler)
        
        # 进度条处理器
        progress_handler = self._create_progress_handler()
        self.logger.addHandler(progress_handler)

    def _init_config_system(self):
        """初始化配置系统"""
        # 系统配置
        self.SYSTEM_CONFIG = {
            'TRAINING_CONFIG': {  # 添加训练配置
                'batch_size': 64,
                'max_epochs': 100,
                'early_stopping_patience': 10,
                'learning_rate': 0.001,
                'model_checkpoint_interval': 5
            },
            'DATA_CONFIG': {
                'cache_size': 10000,
                'min_sequence_length': 14400,
                'normalize_range': (-1, 1)
            },
            'SAMPLE_CONFIG': {
                'input_length': 14400,
                'target_length': 2880
            }
        }
        
        # 单独的训练配置
        self.TRAINING_CONFIG = self.SYSTEM_CONFIG['TRAINING_CONFIG']
        
        # 数据库配置
        self.DB_CONFIG = {
            'host': 'localhost',
            'port': 3306,
            'user': 'root',
            'password': 'tt198803',
            'database': 'admin_data',
            'charset': 'utf8mb4'
        }
        
        # 优化器配置
        self.OPTIMIZER_CONFIG = {
            'learning_rate': 0.001,
            'beta_1': 0.9,
            'beta_2': 0.999,
            'epsilon': 1e-07
        }
        
        # 日志配置
        self.LOG_CONFIG = {
            'log_level': 'INFO',
            'log_format': '%(asctime)s [%(levelname)s] %(name)s - %(message)s',
            'log_file': 'system.log',
            'log_retention_days': 7
        }
        
        # 模型配置
        self.MODEL_CONFIG = {
            'model_type': 'transformer',
            'num_layers': 6,
            'num_heads': 8,
            'd_model': 512,
            'dff': 2048,
            'dropout_rate': 0.1
        }
        
        self.logger.info("配置系统初始化完成")

    def update_config(self, config_name: str, updates: Dict[str, Any]) -> bool:
        """更新指定配置"""
        try:
            config = getattr(self, f'{config_name}_CONFIG')
            config.update(updates)
            return True
        except AttributeError:
            self.logger.error(f"配置 {config_name} 不存在")
            return False

    def save_config(self, config_name: str) -> bool:
        """保存配置到文件"""
        try:
            config = getattr(self, f'{config_name}_CONFIG')
            save_path = os.path.join(self.BASE_DIR, 'configs', f'{config_name.lower()}_config.json')
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            
            with open(save_path, 'w', encoding='utf-8') as f:
                json.dump(config, f, indent=4, ensure_ascii=False)
            return True
        except Exception as e:
            self.logger.error(f"保存配置失败: {str(e)}")
            return False

    def load_config(self, config_name: str) -> bool:
        """从文件加载配置"""
        try:
            load_path = os.path.join(self.BASE_DIR, 'configs', f'{config_name.lower()}_config.json')
            if not os.path.exists(load_path):
                return False
            
            with open(load_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            
            setattr(self, f'{config_name}_CONFIG', config)
            return True
        except Exception as e:
            self.logger.error(f"加载配置失败: {str(e)}")
            return False

    def validate_config(self, config_name: str) -> bool:
        """验证配置有效性"""
        try:
            if config_name == 'DB':
                return self._validate_db_config(self.DB_CONFIG)
            elif config_name == 'SYSTEM':
                return self._validate_system_config(self.SYSTEM_CONFIG)
            elif config_name == 'TRAINING':
                return self._validate_training_config(self.TRAINING_CONFIG)
            else:
                self.logger.error(f"未知的配置类型: {config_name}")
                return False
        except Exception as e:
            self.logger.error(f"验证配置失败: {str(e)}")
            return False

    def _validate_db_config(self, config: Dict[str, str]) -> bool:
        """验证数据库配置"""
        required_fields = ['host', 'user', 'password', 'database', 'charset']
        return all(field in config for field in required_fields)

    def _validate_training_config(self, config: Dict[str, Any]) -> bool:
        """验证训练配置"""
        try:
            assert config['batch_size'] > 0
            assert config['max_epochs'] > 0
            return True
        except (AssertionError, KeyError):
            return False

    def _validate_system_config(self, config: Dict[str, Any]) -> bool:
        """验证系统配置"""
        try:
            assert config['memory_limit'] > 0
            assert config['gpu_memory_limit'] > 0
            assert config['cleanup_interval'] > 0
            assert config['log_retention_days'] > 0
            return True
        except (AssertionError, KeyError):
            return False

    def _create_daily_rotating_handler(self):
        """创建每日轮转的日志处理器"""
        handler = DailyRotatingFileHandler(
            base_dir=self.LOG_DIR,
            prefix='system'
        )
        handler.setFormatter(self._create_custom_formatter())
        return handler
        
    # 添加从cell4中的LogManager类的功能
    def setup_continuous_logging(self):
        """设置持续训练的日志系统"""
        self.continuous_logger = logging.getLogger('ContinuousTraining')
        self.continuous_logger.setLevel(logging.INFO)
        self.continuous_log_buffer = deque(maxlen=100)
        
        # 添加持续训练的文件处理器
        continuous_handler = DailyRotatingFileHandler(
            base_dir=self.LOG_DIR,
            prefix='continuous_training'
        )
        formatter = logging.Formatter(
            '[%(asctime)s] %(levellevel)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        continuous_handler.setFormatter(formatter)
        self.continuous_logger.addHandler(continuous_handler)
        
    def update_training_progress(self, model_idx: int, progress: float):
        """更新训练进度"""
        if hasattr(self, 'display_manager'):
            self.display_manager.progress_bars[model_idx] = progress
            self.display_manager._display_logs()
            
    def add_training_log(self, message: str):
        """添加训练日志"""
        if hasattr(self, 'continuous_log_buffer'):
            self.continuous_log_buffer.append(message)
            if hasattr(self, 'display_manager'):
                self.display_manager.log_buffer.append(message)
                self.display_manager._display_logs()

    # 添加从cell1的ConfigValidator的功能
    def validate_all_configs(self) -> Dict[str, bool]:
        """验证所有配置的有效性"""
        results = {}
        
        # 数据库配置验证
        results['db_config'] = self._extended_validate_db_config()
        
        # 训练配置验证
        results['training_config'] = self._extended_validate_training_config()
        
        # 系统配置验证
        results['system_config'] = self._extended_validate_system_config()
        
        return results
        
    def _extended_validate_db_config(self) -> bool:
        """扩展的数据库配置验证"""
        try:
            config = self.DB_CONFIG
            
            # 基本字段验证
            required_fields = ['host', 'user', 'password', 'database', 'charset']
            if not all(field in config for field in required_fields):
                self.logger.error("数据库配置缺少必要字段")
                return False
                
            # 端口验证
            if not isinstance(config.get('port', 3306), int):
                self.logger.error("数据库端口必须是整数")
                return False
                
            # 字符集验证
            if config.get('charset') not in ['utf8', 'utf8mb4']:
                self.logger.error("不支持的字符集")
                return False
                
            return True
            
        except Exception as e:
            self.logger.error(f"验证数据库配置时出错: {str(e)}")
            return False
            
    def _extended_validate_training_config(self) -> bool:
        """扩展的训练配置验证"""
        try:
            config = self.TRAINING_CONFIG
            
            # 参数范围验证
            validations = [
                config.get('batch_size', 0) > 0,
                config.get('max_epochs', 0) > 0,
                0 <= config.get('save_frequency', 100) <= 1000,
                0 <= config.get('eval_frequency', 50) <= 500,
                0 < config.get('min_improvement', 0.001) < 1
            ]
            
            if not all(validations):
                self.logger.error("训练配置参数范围无效")
                return False
                
            return True
            
        except Exception as e:
            self.logger.error(f"验证训练配置时出错: {str(e)}")
            return False
            
    def _extended_validate_system_config(self) -> bool:
        """扩展的系统配置验证"""
        try:
            config = self.SYSTEM_CONFIG
            
            # 资源限制验证
            if not all([
                config.get('memory_limit', 0) >= 1000,
                config.get('gpu_memory_limit', 0) >= 1000,
                config.get('max_threads', 0) > 0
            ]):
                self.logger.error("系统资源限制配置无效")
                return False
            
            # 采样配置验证
            sample_config = config.get('SAMPLE_CONFIG', {})
            if not all([
                sample_config.get('input_length', 0) > 0,
                sample_config.get('target_length', 0) > 0
            ]):
                self.logger.error("采样配置无效")
                return False
                
            return True
            
        except Exception as e:
            self.logger.error(f"验证系统配置时出错: {str(e)}")
            return False

    def export_configs(self, export_path: str = None) -> bool:
        """导出所有配置"""
        if export_path is None:
            export_path = os.path.join(self.BASE_DIR, 'configs', 'all_configs.json')
            
        try:
            configs = {
                'DB_CONFIG': self.DB_CONFIG,
                'TRAINING_CONFIG': self.TRAINING_CONFIG,
                'SYSTEM_CONFIG': self.SYSTEM_CONFIG,
                'OPTUNA_CONFIG': self.OPTUNA_CONFIG
            }
            
            os.makedirs(os.path.dirname(export_path), exist_ok=True)
            with open(export_path, 'w', encoding='utf-8') as f:
                json.dump(configs, f, indent=4, ensure_ascii=False)
                
            self.logger.info(f"配置已导出到: {export_path}")
            return True
            
        except Exception as e:
            self.logger.error(f"导出配置失败: {str(e)}")
            return False

    def _create_custom_formatter(self):
        """创建自定义格式化器"""
        return CustomFormatter()

    def _create_progress_handler(self):
        """创建进度处理器"""
        return ProgressHandler()

    # 从cell4的LoggingManager类补充
    def _configure_logging(self):
        """配置日志系统"""
        # 配置根日志记录器
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s [%(levelname)s] %(name)s - %(message)s',
            handlers=[
                logging.FileHandler(os.path.join(self.LOG_DIR, "system.log")),
                logging.StreamHandler()
            ]
        )
        # 捕获警告信息
        logging.captureWarnings(True)

    # 从cell4的LogManager类补充
    def get_logger(self):
        """获取logger实例"""
        return self.logger

    def get_continuous_logger(self):
        """获取持续训练的logger实例"""
        if not hasattr(self, 'continuous_logger'):
            self.setup_continuous_logging()
        return self.continuous_logger

    # 从cell1的ConfigManager补充
    def get_config(self, config_name: str) -> Dict[str, Any]:
        """获取指定配置
        Args:
            config_name: 配置名称 ('DB', 'TRAINING', 'SYSTEM', 'OPTUNA')
        """
        try:
            return getattr(self, f'{config_name}_CONFIG').copy()
        except AttributeError:
            self.logger.error(f"配置 {config_name} 不存在")
            return {}

    def clear_log_buffers(self):
        """清理日志缓冲区"""
        if hasattr(self, 'continuous_log_buffer'):
            self.continuous_log_buffer.clear()
        if hasattr(self, 'display_manager'):
            self.display_manager.log_buffer.clear()
            self.display_manager._display_logs()

    def cleanup_logs(self, days: int = None):
        """清理旧日志文件
        Args:
            days: 保留天数，默认使用配置中的值
        """
        if days is None:
            days = self.SYSTEM_CONFIG.get('log_retention_days', 7)
            
        try:
            current_time = datetime.now()
            for file in os.listdir(self.LOG_DIR):
                if file.endswith('.log'):
                    file_path = os.path.join(self.LOG_DIR, file)
                    file_time = datetime.fromtimestamp(os.path.getctime(file_path))
                    if (current_time - file_time).days > days:
                        os.remove(file_path)
                        self.logger.info(f"已删除旧日志文件: {file}")
                        
        except Exception as e:
            self.logger.error(f"清理日志文件时出错: {str(e)}")

    def _configure_metrics(self):
        """配置性能指标记录"""
        self.metrics = {
            'training_loss': deque(maxlen=1000),
            'validation_loss': deque(maxlen=1000),
            'learning_rates': deque(maxlen=1000),
            'batch_times': deque(maxlen=100)
        }

    def log_metric(self, metric_name: str, value: float):
        """记录性能指标"""
        if hasattr(self, 'metrics') and metric_name in self.metrics:
            self.metrics[metric_name].append(value)

    def get_metrics_summary(self):
        """获取性能指标摘要"""
        if not hasattr(self, 'metrics'):
            return {}
            
        return {
            name: {
                'mean': np.mean(values),
                'min': np.min(values),
                'max': np.max(values)
            }
            for name, values in self.metrics.items()
            if values
        }

    def _test_config_system(self):
        """测试配置系统完整性"""
        try:
            required_keys = ['host', 'port', 'user', 'password'] 
            assert all(key in self.DB_CONFIG for key in required_keys), "数据库配置缺失必要参数"
            self.logger.info("配置系统测试通过")
        except AssertionError as e:
            self.logger.error(f"配置系统测试失败: {str(e)}")
            raise

    def validate_config_values(self, config: dict) -> bool:
        """验证配置参数值 (from ConfigValidator)"""
        valid = True
        for key, rule in self.validation_rules.items():
            if key in config:
                if not rule(config[key]):
                    self.logger.warning(f"参数 {key} 的值 {config[key]} 无效")
                    valid = False
        return valid

class DailyRotatingFileHandler(logging.FileHandler):
    """每日自动分文件的日志处理器"""
    def __init__(self, base_dir, prefix='log', max_bytes=50*1024*1024):
        self.base_dir = base_dir
        self.prefix = prefix
        self.max_bytes = max_bytes
        self.current_date = None
        self.current_file = None
        self.current_size = 0
        self.file_count = 1
        
        os.makedirs(base_dir, exist_ok=True)
        self._init_file()
        super().__init__(self.current_file, mode='a', encoding='utf-8')
    
    def _init_file(self):
        """初始化日志文件"""
        self.current_file = self._get_file_path()
        self.current_date = datetime.now().strftime('%Y%m%d')
        if os.path.exists(self.current_file):
            self.current_size = os.path.getsize(self.current_file)
        else:
            self.current_size = 0
    
    def _get_file_path(self):
        """获取当前日志文件路径"""
        today = datetime.now().strftime('%Y%m%d')
        if self.current_size >= self.max_bytes:
            self.file_count += 1
            return os.path.join(self.base_dir, f'{self.prefix}_{today}_{self.file_count}.log')
        elif today != self.current_date:
            self.file_count = 1
            self.current_date = today
            return os.path.join(self.base_dir, f'{self.prefix}_{today}.log')
        return self.current_file
    
    def emit(self, record):
        """重写emit方法，在写入日志前检查文件状态"""
        try:
            new_file = self._get_file_path()
            if new_file != self.current_file:
                if self.stream:
                    self.stream.close()
                self.current_file = new_file
                self.baseFilename = new_file
                self.current_size = 0
                self.stream = self._open()
            
            msg = self.format(record) + '\n'
            self.stream.write(msg)
            self.stream.flush()
            self.current_size += len(msg.encode('utf-8'))
            
        except Exception as e:
            self.handleError(record)

class ProgressHandler(logging.Handler):
    """进度条处理器"""
    def __init__(self):
        super().__init__()
        self.progress = 0
        
    def emit(self, record):
        if hasattr(record, 'progress'):
            print('\r' + ' ' * 80, end='\r')  # 清除当前行
            progress = int(record.progress * 50)
            print(f'\rTraining Progress: [{"="*progress}{" "*(50-progress)}] {record.progress*100:.1f}%', end='')

class CustomFormatter(logging.Formatter):
    """自定义日志格式化器"""
    def __init__(self):
        super().__init__()
        self.formatters = {
            logging.DEBUG: logging.Formatter(
                '[%(asctime)s] %(name)s - %(levelname)s - %(message)s'
            ),
            logging.INFO: logging.Formatter(
                '[%(asctime)s] %(levelname)s - %(message)s'
            ),
            logging.WARNING: logging.Formatter(
                '[%(asctime)s] WARNING - %(message)s'
            ),
            logging.ERROR: logging.Formatter(
                '[%(asctime)s] ERROR - %(message)s\n\tat %(pathname)s:%(lineno)d'
            ),
            logging.CRITICAL: logging.Formatter(
                '[%(asctime)s] CRITICAL - %(message)s\n\tat %(pathname)s:%(lineno)d\n%(exc_info)s'
            )
        }
    
    def format(self, record):
        formatter = self.formatters.get(record.levelno)
        return formatter.format(record)

class LogDisplayManager:
    """日志显示管理器"""
    def __init__(self, max_lines=10):
        self.max_lines = max_lines
        self.log_buffer = []
        self.progress_bars = {i: 0.0 for i in range(6)}
        self._clear_output()
    
    def _clear_output(self):
        """清空输出"""
        clear_output(wait=True)
    
    def _display_logs(self):
        """显示日志"""
        self._clear_output()
        
        start_idx = max(0, len(self.log_buffer) - self.max_lines)
        for log in self.log_buffer[start_idx:]:
            if log.strip():
                print(log)
        
        print('-' * 80)
        
        for model_idx, progress in self.progress_bars.items():
            bar_length = 50
            filled = int(progress * bar_length)
            bar = f"Model {model_idx + 1}: [{'='*filled}{' '*(bar_length-filled)}] {progress*100:.1f}%"
            print(bar)

# 创建全局实例
core_manager = CoreManager()

# 导出常用变量
logger = core_manager.logger
config = core_manager
BASE_DIR = core_manager.BASE_DIR
LOG_DIR = core_manager.LOG_DIR
MODEL_DIR = core_manager.MODEL_DIR
DATA_DIR = core_manager.DATA_DIR
CHECKPOINT_DIR = core_manager.CHECKPOINT_DIR





[2025-02-18 15:17:55,822] INFO - 配置系统初始化完成
[2025-02-18 15:17:55,823] INFO - 配置系统测试通过
[2025-02-18 15:17:55,824] INFO - 核心管理器初始化完成


In [2]:
#2 Utility Functions / 工具函数模块
import os
import gc
import re
import time
import psutil
import logging
import tensorflow as tf
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, Tuple
import shutil

# 获取logger
logger = logging.getLogger(__name__)

class DateUtils:
    """日期处理工具类"""
    
    @staticmethod
    def parse_issue(issue_str: str) -> Tuple[str, int]:
        """解析期号字符串"""
        match = re.match(r"(\d{8})-(\d{4})", issue_str)
        if not match:
            raise ValueError("无效的期号格式")
        return match.group(1), int(match.group(2))
    
    @staticmethod
    def get_next_issue(current_issue: str) -> str:
        """获取下一期号"""
        date_str, period = DateUtils.parse_issue(current_issue)
        date = datetime.strptime(date_str, "%Y%m%d")
        
        if period == 1440:
            new_date = date + timedelta(days=1)
            new_period = 1
        else:
            new_date = date
            new_period = period + 1
        
        return f"{new_date.strftime('%Y%m%d')}-{new_period:04d}"

class MemoryManager:
    """内存管理工具类"""
    
    def __init__(self, 
                warning_threshold_mb: int = 8000,
                critical_threshold_mb: int = 10000,
                cleanup_interval: int = 300,
                full_cleanup_interval: int = 14400):
        """
        初始化内存管理器
        Args:
            warning_threshold_mb: 警告阈值(MB)
            critical_threshold_mb: 临界阈值(MB)
            cleanup_interval: 常规清理间隔(秒)
            full_cleanup_interval: 全面清理间隔(秒)
        """
        self.warning_threshold = warning_threshold_mb * 1024 * 1024  # 转换为字节
        self.critical_threshold = critical_threshold_mb * 1024 * 1024
        self.cleanup_interval = cleanup_interval
        self.full_cleanup_interval = full_cleanup_interval
        
        self.last_cleanup_time = time.time()
        self.last_full_cleanup_time = time.time()
        
        logger.info(f"内存管理器初始化完成 - 警告阈值:{warning_threshold_mb}MB, 临界阈值:{critical_threshold_mb}MB")

    def check_memory_status(self) -> bool:
        """
        检查内存状态并在必要时执行清理
        Returns:
            bool: 内存状态是否正常
        """
        try:
            current_usage = self.get_memory_usage()
            current_time = time.time()

            # 检查是否需要执行清理
            if current_time - self.last_cleanup_time > self.cleanup_interval:
                self._regular_cleanup()
                self.last_cleanup_time = current_time

            if current_time - self.last_full_cleanup_time > self.full_cleanup_interval:
                self._full_cleanup()
                self.last_full_cleanup_time = current_time

            # 检查内存使用是否超过阈值
            if current_usage > self.critical_threshold:
                logger.warning(f"内存使用超过临界值: {current_usage/1024/1024:.1f}MB")
                self._emergency_cleanup()
                return False
            elif current_usage > self.warning_threshold:
                logger.warning(f"内存使用超过警告值: {current_usage/1024/1024:.1f}MB")
                self._optimize_memory()

            return True

        except Exception as e:
            logger.error(f"检查内存状态时出错: {str(e)}")
            return False

    def get_memory_usage(self) -> int:
        """
        获取当前进程的内存使用量(字节)
        Returns:
            int: 内存使用量(字节)
        """
        try:
            process = psutil.Process(os.getpid())
            return process.memory_info().rss
        except Exception as e:
            logger.error(f"获取内存使用量时出错: {str(e)}")
            return 0

    def get_memory_info(self) -> Dict[str, Any]:
        """
        获取详细的内存使用信息
        Returns:
            Dict: 内存使用信息
        """
        try:
            memory = psutil.virtual_memory()
            process = psutil.Process(os.getpid())
            
            return {
                'total': memory.total,
                'available': memory.available,
                'used': memory.used,
                'free': memory.free,
                'percent': memory.percent,
                'process_usage': process.memory_info().rss,
                'process_percent': process.memory_percent()
            }
        except Exception as e:
            logger.error(f"获取内存信息时出错: {str(e)}")
            return {}

    def _regular_cleanup(self):
        """执行常规清理"""
        try:
            # 1. 清理Python垃圾
            gc.collect()
            
            # 2. 清理TF会话
            tf.keras.backend.clear_session()
            
            # 3. 清理不用的变量
            for name in list(globals().keys()):
                if name.startswith('_temp_'):
                    del globals()[name]
                    
            logger.info("完成常规内存清理")
            
        except Exception as e:
            logger.error(f"常规清理时出错: {str(e)}")

    def _full_cleanup(self):
        """执行全面清理"""
        try:
            # 1. 执行常规清理
            self._regular_cleanup()
            
            # 2. 重置TensorFlow状态
            gpus = tf.config.list_physical_devices('GPU')
            for gpu in gpus:
                tf.config.experimental.reset_memory_stats(gpu.name)
            
            # 3. 清理模型缓存
            self._cleanup_model_cache()
            
            logger.info("完成全面内存清理")
            
        except Exception as e:
            logger.error(f"全面清理时出错: {str(e)}")

    def _emergency_cleanup(self):
        """执行紧急清理"""
        try:
            # 三级清理策略
            gc.collect(2)  # 强制回收老年代内存
            tf.keras.backend.clear_session()
            
            # 释放多GPU内存
            for gpu in tf.config.list_physical_devices('GPU'):
                try:
                    tf.config.experimental.reset_memory_stats(gpu.name)
                except RuntimeError as e:
                    logger.warning(f"GPU内存重置失败: {str(e)}")

            # 清理临时文件缓存
            self._clean_temp_files()
            
            logger.warning("执行紧急内存清理")
            
        except RuntimeError as e:
            logger.error(f"运行时错误: {str(e)}")
        except MemoryError as e:
            logger.critical("内存严重不足，无法完成清理！")
        except IOError as e:
            logger.error(f"文件清理失败: {str(e)}")
        except Exception as e:
            logger.error(f"未预期的清理错误: {str(e)}", exc_info=True)

    def _clean_temp_files(self) -> None:
        """清理临时文件"""
        try:
            temp_dir = os.path.join(core_manager.BASE_DIR, 'temp')
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
                os.makedirs(temp_dir, exist_ok=True)
        except Exception as e:
            logger.warning(f"临时文件清理失败: {str(e)}")

    def _optimize_memory(self):
        """优化内存使用"""
        try:
            # 1. 检查并清理大对象
            for obj in gc.get_objects():
                if hasattr(obj, 'nbytes') and getattr(obj, 'nbytes', 0) > 1e8:  # >100MB
                    del obj
            
            # 2. 执行垃圾回收
            gc.collect()
            
            logger.info("完成内存优化")
            
        except Exception as e:
            logger.error(f"内存优化时出错: {str(e)}")

    def _cleanup_model_cache(self):
        """清理模型缓存"""
        try:
            # 清理Keras后端缓存
            tf.keras.backend.clear_session()
            
            # 清理模型检查点文件
            checkpoint_dir = os.path.join(os.getcwd(), 'checkpoints')
            if os.path.exists(checkpoint_dir):
                for item in os.listdir(checkpoint_dir):
                    if item.endswith('.temp'):
                        os.remove(os.path.join(checkpoint_dir, item))
                        
            # 释放TensorFlow占用的缓存
            gpus = tf.config.list_physical_devices('GPU')
            for gpu in gpus:
                tf.config.experimental.reset_memory_stats(gpu.name)
            
            logger.info("完成模型缓存清理")
            
        except Exception as e:
            logger.error(f"清理模型缓存时出错: {str(e)}")

    def optimize_for_large_data(self):
        """针对大样本的优化策略"""
        # 新增大样本优化策略
        self.enable_memmap = True  # 启用内存映射
        self.chunk_size = 10000    # 分块加载
        tf.keras.backend.set_floatx('float16')  # 压缩精度
        logger.info("已启用大数据优化策略")

    def optimize_for_hardware(self) -> bool:
        """硬件定制优化"""
        try:
            # 1. 限制TensorFlow内存使用
            gpus = tf.config.list_physical_devices('GPU')
            if gpus:
                tf.config.set_logical_device_configuration(
                    gpus[0],
                    [tf.config.LogicalDeviceConfiguration(memory_limit=1536)]
                )
            
            # 2. 配置CPU并行线程
            tf.config.threading.set_intra_op_parallelism_threads(6)
            tf.config.threading.set_inter_op_parallelism_threads(4)
            
            # 3. 启用内存映射
            self.enable_memmap = True
            logger.info("已完成硬件优化配置")
            return True
        except RuntimeError as e:
            logger.error(f"运行时配置错误: {str(e)}")
            return False
        except ValueError as e:
            logger.error(f"无效的配置参数: {str(e)}")
            return False

# 创建全局实例
date_utils = DateUtils()
memory_manager = MemoryManager()


In [3]:
#3 System Monitor / 系统监控模块
import psutil
import logging
import threading
import time
import gc
import shutil
import tensorflow as tf
import subprocess
import numpy as np
import json
import os
from collections import deque
from datetime import datetime

# 获取logger实例
logger = logging.getLogger(__name__)

# 从cell1移入的日志处理器类
class DailyRotatingFileHandler(logging.FileHandler):
    """每日自动分文件的日志处理器"""
    def __init__(self, base_dir, prefix='log', max_bytes=50*1024*1024):
        self.base_dir = base_dir 
        self.prefix = prefix
        self.max_bytes = max_bytes
        self.current_date = None
        self.current_file = None
        self.current_size = 0
        self.file_count = 1
        
        os.makedirs(base_dir, exist_ok=True)
        self._init_file()
        super().__init__(self.current_file, mode='a', encoding='utf-8')
        
    def _init_file(self):
        """初始化日志文件"""
        self.current_file = self._get_file_path()
        self.current_date = datetime.now().strftime('%Y%m%d')
        if os.path.exists(self.current_file):
            self.current_size = os.path.getsize(self.current_file)
        else:
            self.current_size = 0
            
    def _get_file_path(self):
        """获取当前日志文件路径"""
        today = datetime.now().strftime('%Y%m%d')
        if self.current_size >= self.max_bytes:
            self.file_count += 1
            return os.path.join(self.base_dir, f'{self.prefix}_{today}_{self.file_count}.log')
        elif today != self.current_date:
            self.file_count = 1
            self.current_date = today
            return os.path.join(self.base_dir, f'{self.prefix}_{today}.log')
        return self.current_file
        
    def emit(self, record):
        """重写emit方法,在写入日志前检查文件状态"""
        try:
            new_file = self._get_file_path()
            if new_file != self.current_file:
                if self.stream:
                    self.stream.close()
                self.current_file = new_file
                self.baseFilename = new_file
                self.current_size = 0
                self.stream = self._open()
            
            msg = self.format(record) + '\n'
            self.stream.write(msg)
            self.stream.flush()
            self.current_size += len(msg.encode('utf-8'))
            
        except Exception as e:
            self.handleError(record)

class ProgressHandler(logging.Handler):
    """进度条处理器"""
    def __init__(self):
        super().__init__()
        self.progress = 0
        
    def emit(self, record):
        if hasattr(record, 'progress'):
            print('\r' + ' ' * 80, end='\r')  # 清除当前行
            progress = int(record.progress * 50)
            print(f'\rTraining Progress: [{"="*progress}{" "*(50-progress)}] {record.progress*100:.1f}%', end='')

class CustomFormatter(logging.Formatter):
    """自定义日志格式化器"""
    def __init__(self):
        super().__init__()
        self.formatters = {
            logging.DEBUG: logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            ),
            logging.INFO: logging.Formatter(
                '%(asctime)s - %(levelname)s - %(message)s'
            ),
            logging.WARNING: logging.Formatter(
                '%(asctime)s - %(levelname)s - WARNING: %(message)s'
            ),
            logging.ERROR: logging.Formatter(
                '%(asctime)s - %(levelname)s - ERROR: %(message)s\n%(pathname)s:%(lineno)d'
            ),
            logging.CRITICAL: logging.Formatter(
                '%(asctime)s - %(levelname)s - CRITICAL: %(message)s\n%(pathname)s:%(lineno)d\n%(exc_info)s'
            )
        }
    
    def format(self, record):
        formatter = self.formatters.get(record.levelno)
        return formatter.format(record)

class ResourceMonitor:
    """资源监控器 - 负责监控CPU、内存、GPU等资源使用情况"""
    def __init__(self, window_size=100, check_interval=5):
        self.window_size = window_size
        self.check_interval = check_interval
        self.lock = threading.Lock()
        
        # 监控指标存储
        self.metrics = {
            'cpu_usage': deque(maxlen=window_size),
            'memory_usage': deque(maxlen=window_size),
            'disk_usage': deque(maxlen=window_size),
            'gpu_usage': None,
            'gpu_memory': None
        }
        
        # 添加缺失的警报阈值初始化
        self.thresholds = {
            'cpu_usage': 90,    # CPU使用率超过90%
            'memory_usage': 90,  # 内存使用率超过90%
            'disk_usage': 90,    # 磁盘使用率超过90%
            'gpu_usage': 90,     # GPU使用率超过90%
            'gpu_memory': 90     # GPU内存使用率超过90%
        }
        
        # 添加警报历史
        self.alerts = []
        
        # 确保基础属性初始化
        self._memory_usage = 0.0
        self.cpu_usage = 0.0
        self.gpu_usage = 0.0

        # 添加线程控制
        self._running = False
        self._thread = threading.Thread(target=self._monitor_loop)

        # 确保初始化时收集初始指标
        self._monitor_loop()  # 添加初始数据收集
        self.start()  # 启动监控线程

    def _monitor_loop(self):
        """监控循环的完整实现"""
        while self._running:
            try:
                self._collect_metrics()
                self._check_alerts()
                time.sleep(self.check_interval)
            except Exception as e:
                logger.error(f"资源监控循环出错: {str(e)}")
                
    def _check_alerts(self):
        """完整的警报检查逻辑"""
        with self.lock:
            for metric, values in self.metrics.items():
                if not values:
                    continue
                current = values[-1]
                if current > self.thresholds[metric]:
                    alert = {
                        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                        'metric': metric,
                        'value': current,
                        'threshold': self.thresholds[metric]
                    }
                    self.alerts.append(alert)
                    logger.warning(f"资源警报: {metric} = {current}% (阈值: {self.thresholds[metric]}%)")

    def _collect_metrics(self):
        """收集资源指标"""
        try:
            with self.lock:
                # CPU使用率
                self.cpu_usage = psutil.cpu_percent()
                self.metrics['cpu_usage'].append(self.cpu_usage)
                
                # 内存使用率
                mem = psutil.virtual_memory()
                self._memory_usage = mem.percent
                self.metrics['memory_usage'].append(self._memory_usage)
                
                # 磁盘使用率
                disk = psutil.disk_usage('/')
                self.metrics['disk_usage'].append(disk.percent)
                
        except Exception as e:
            logger.error(f"收集资源指标时出错: {str(e)}")

    def get_memory_usage(self):
        """获取当前内存使用率"""
        return self._memory_usage

    def start(self):
        """启动资源监控"""
        if not self._running:
            self._running = True
            self._thread.start()
            logger.info("资源监控已启动")
    
    def stop(self):
        """停止资源监控"""
        self._running = False
        if self._thread.is_alive():
            self._thread.join(timeout=5)
        logger.info("资源监控已停止")

class PerformanceMonitor:
    """性能监控器 - 负责收集和分析性能指标"""
    def __init__(self, save_dir='logs/performance', window_size=1000):
        self.save_dir = save_dir
        self.window_size = window_size
        self.metrics = {
            'cpu_usage': deque(maxlen=window_size),
            'memory_usage': deque(maxlen=window_size),
            'gpu_usage': deque(maxlen=window_size),
            'loss': deque(maxlen=window_size),
            'accuracy': deque(maxlen=window_size)
        }
        self._running = False
        self._thread = threading.Thread(target=self._monitor_loop)
        
        # 添加内存监控器初始化
        from cell2_utils import memory_manager
        self.memory_monitor = memory_manager

    def _monitor_loop(self):
        """性能监控主循环"""
        while self._running:
            try:
                # 收集CPU使用率
                cpu_usage = psutil.cpu_percent()
                self.metrics['cpu_usage'].append(cpu_usage)
                
                # 收集内存使用率
                mem_usage = psutil.virtual_memory().percent
                self.metrics['memory_usage'].append(mem_usage)
                
                # 收集GPU使用率（需要安装GPU监控库）
                gpu_usage = 0  # 这里需要根据实际GPU监控库实现
                self.metrics['gpu_usage'].append(gpu_usage)
                
                # 保存指标到文件
                self._save_metrics()
                
                time.sleep(1)  # 每秒收集一次
                
            except Exception as e:
                logger.error(f"性能监控出错: {str(e)}")
                break

    def start(self):
        """启动监控线程"""
        if not self._running:
            self._running = True
            self._thread.start()
            logger.info("性能监控已启动")

    def stop(self):
        """停止监控线程"""
        self._running = False
        if self._thread.is_alive():
            self._thread.join(timeout=5)
        logger.info("性能监控已停止")

    def _save_metrics(self):
        """保存指标到文件"""
        try:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = os.path.join(self.save_dir, f"metrics_{timestamp}.json")
            
            metrics_to_save = {
                k: list(v) for k, v in self.metrics.items()
            }
            
            with open(filename, 'w') as f:
                json.dump(metrics_to_save, f)
                
        except Exception as e:
            logger.error(f"保存性能指标失败: {str(e)}")

class SystemManager:
    """系统管理器 - 单例模式"""
    _instance = None
    
    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
    
    def __init__(self):
        if not hasattr(self, 'initialized'):
            # 初始化各个监控器
            self.resource_monitor = ResourceMonitor()
            self.performance_monitor = PerformanceMonitor(save_dir='logs/performance')
            
            # 初始化其他组件
            self.memory_monitor = MemoryMonitor()
            self.system_cleaner = SystemCleaner()
            
            # 添加缺失的阈值初始化
            self.memory_warning_threshold = 0.85  # 85%内存使用率警告
            self.memory_critical_threshold = 0.95  # 95%内存使用率危险
            self.cpu_warning_threshold = 0.85  # 85% CPU使用率警告
            
            # 添加缺失的状态初始化
            self.last_cleanup_time = time.time()
            self.cleanup_interval = 300  # 5分钟执行一次清理
            self.compatibility_checked = False
            
            # 添加系统状态监控配置
            self.status_check_interval = 60  # 60秒检查一次系统状态
            self.last_status_check = time.time()
            self.status_history = deque(maxlen=1000)  # 存储最近1000次状态检查结果
            
            # 添加资源监控配置
            self.resource_warning_count = 0
            self.max_warning_threshold = 5  # 最大警告次数，超过后采取行动
            
            # 添加日志配置
            self.log_dir = 'logs/system'
            os.makedirs(self.log_dir, exist_ok=True)
            
            # 添加系统恢复机制
            self.recovery_attempts = 0
            self.max_recovery_attempts = 3
            
            # 启动基础监控
            self.start_basic_monitoring()
            
            self.initialized = True
            
            logger.info("系统管理器初始化完成")

    def check_system_compatibility(self):
        """检查系统兼容性"""
        try:
            compatibility = {
                'python_version': sys.version,
                'tensorflow_version': tf.__version__,
                'gpu_available': bool(tf.config.list_physical_devices('GPU')),
                'memory_sufficient': psutil.virtual_memory().total >= 8 * (1024 ** 3),  # 最少8GB内存
                'disk_sufficient': psutil.disk_usage('/').free >= 10 * (1024 ** 3)  # 最少10GB可用空间
            }
            
            self.compatibility_checked = True
            return compatibility
        except Exception as e:
            logger.error(f"检查系统兼容性时出错: {str(e)}")
            return None

    def get_system_status(self):
        """获取完整的系统状态报告"""
        try:
            status = {
                'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                'system_metrics': self.get_system_metrics(),
                'performance_metrics': self.performance_monitor.get_summary(),
                'memory_status': self.memory_monitor.check_memory(),
                'compatibility': self.check_system_compatibility() if not self.compatibility_checked else None
            }
            return status
        except Exception as e:
            logger.error(f"获取系统状态报告时出错: {str(e)}")
            return None

    def check_system_health(self):
        """检查系统健康状态"""
        try:
            # 检查内存使用
            memory_info = self.memory_monitor.check_memory()
            if not memory_info['healthy']:
                self.handle_memory_warning(memory_info)
            
            # 检查系统状态
            system_status = self.get_system_metrics()
            if any(status > self.cpu_warning_threshold for status in [
                system_status.get('cpu', 0),
                system_status.get('memory', 0),
                system_status.get('disk', 0)
            ]):
                self.handle_system_warning(system_status)
            
            # 定期清理
            self._perform_periodic_cleanup()
            
            return memory_info['healthy'] and all(
                status < self.cpu_warning_threshold for status in [
                    system_status.get('cpu', 0),
                    system_status.get('memory', 0),
                    system_status.get('disk', 0)
                ]
            )
            
        except Exception as e:
            logger.error(f"检查系统健康状态时出错: {str(e)}")
            return False

    def handle_memory_warning(self, memory_info):
        """处理内存警告"""
        try:
            if memory_info['usage_percent'] > self.memory_critical_threshold:
                logger.critical("内存使用率超过临界值，执行紧急清理")
                self.system_cleaner._emergency_cleanup()
            elif memory_info['usage_percent'] > self.memory_warning_threshold:
                logger.warning("内存使用率较高，执行常规清理")
                self.system_cleaner._regular_cleanup()
                
        except Exception as e:
            logger.error(f"处理内存警告时出错: {str(e)}")

    def handle_system_warning(self, status):
        """处理系统警告"""
        try:
            if status.get('cpu', 0) > self.cpu_warning_threshold:
                logger.warning(f"CPU使用率过高: {status['cpu']}%")
            if status.get('memory', 0) > self.memory_warning_threshold:
                logger.warning(f"内存使用率过高: {status['memory']}%")
            if status.get('disk', 0) > 90:  # 磁盘空间阈值固定为90%
                logger.warning(f"磁盘使用率过高: {status['disk']}%")
                
        except Exception as e:
            logger.error(f"处理系统警告时出错: {str(e)}")

    def check_dependencies(self):
        """系统启动时自动调用"""
        try:
            from ..notebooks.Untitled import check_requirements
            need_install = check_requirements(requirements)
            if need_install:
                self.install_dependencies(need_install)
        except Exception as e:
            logger.error(f"检查依赖时出错: {str(e)}")

    def install_dependencies(self, packages):
        """受控安装方法"""
        try:
            logger.info(f"自动安装依赖: {packages}")
            # 这里可以添加实际的包安装逻辑
            # pip.main(['install'] + packages)
        except Exception as e:
            logger.error(f"安装依赖时出错: {str(e)}")

    def get_system_metrics(self):
        """获取系统指标"""
        try:
            return {
                'memory': self.memory_monitor.check_memory(),
                'cpu': psutil.cpu_percent(),
                'disk': psutil.disk_usage('/').percent,
                'gpu': self._get_gpu_metrics()
            }
        except Exception as e:
            logger.error(f"获取系统指标时出错: {str(e)}")
            return {}
            
    def _get_gpu_metrics(self):
        """获取GPU指标"""
        try:
            if not tf.config.list_physical_devices('GPU'):
                return None
                
            result = subprocess.check_output(
                ['nvidia-smi', '--query-gpu=memory.used,memory.total,temperature.gpu', 
                 '--format=csv,nounits,noheader'],
                encoding='utf-8'
            )
            used, total, temp = map(int, result.strip().split(','))
            return {
                'memory_used': used,
                'memory_total': total,
                'temperature': temp,
                'utilization': used / total * 100
            }
        except Exception as e:
            logger.error(f"获取GPU指标时出错: {str(e)}")
            return None

    def _perform_periodic_cleanup(self):
        """执行定期清理"""
        current_time = time.time()
        if current_time - self.last_cleanup_time > self.cleanup_interval:
            try:
                self.system_cleaner.check_and_cleanup()
                self.last_cleanup_time = current_time
            except Exception as e:
                logger.error(f"执行定期清理时出错: {str(e)}")

    def analyze_system_performance(self):
        """分析系统整体性能"""
        try:
            perf_summary = self.performance_monitor.get_summary()
            sys_metrics = self.get_system_metrics()
            
            analysis = {
                'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                'performance_status': perf_summary,
                'system_status': sys_metrics,
                'health_check': self.check_system_health(),
                'recommendations': self._generate_recommendations(perf_summary, sys_metrics)
            }
            
            return analysis
        except Exception as e:
            logger.error(f"分析系统性能时出错: {str(e)}")
            return None

    def _generate_recommendations(self, perf_summary, sys_metrics):
        """生成系统优化建议"""
        recommendations = []
        
        # 检查内存使用
        if sys_metrics.get('memory', 0) > self.memory_warning_threshold:
            recommendations.append("建议清理内存或增加内存容量")
            
        # 检查CPU使用
        if sys_metrics.get('cpu', 0) > self.cpu_warning_threshold:
            recommendations.append("建议优化计算密集型任务或增加CPU资源")
            
        # 检查GPU使用
        gpu_metrics = sys_metrics.get('gpu')
        if gpu_metrics and gpu_metrics.get('utilization', 0) > 90:
            recommendations.append("建议优化GPU使用效率或考虑增加GPU资源")
            
        return recommendations

    def _handle_resource_warnings(self):
        """处理资源警告的升级机制"""
        self.resource_warning_count += 1
        if self.resource_warning_count >= self.max_warning_threshold:
            logger.critical("资源警告次数过多，执行紧急清理")
            self.system_cleaner._emergency_cleanup()
            self.resource_warning_count = 0

    def _perform_system_recovery(self):
        """系统恢复机制"""
        try:
            if self.recovery_attempts >= self.max_recovery_attempts:
                logger.critical("系统恢复次数超过限制，需要人工干预")
                return False
                
            logger.warning(f"尝试系统恢复，第{self.recovery_attempts + 1}次")
            
            # 执行恢复步骤
            self.system_cleaner._emergency_cleanup()
            tf.keras.backend.clear_session()
            gc.collect()
            
            self.recovery_attempts += 1
            return True
            
        except Exception as e:
            logger.error(f"执行系统恢复时出错: {str(e)}")
            return False
            
    def reset_recovery_count(self):
        """重置恢复计数"""
        self.recovery_attempts = 0

    def get_system_summary(self):
        """获取系统综合报告"""
        try:
            return {
                'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                'system_health': self.check_system_health(),
                'performance_metrics': self.performance_monitor.get_summary(),
                'resource_metrics': self.resource_monitor.metrics,
                'memory_status': self.memory_monitor.check_memory(),
                'recovery_attempts': self.recovery_attempts,
                'last_cleanup': datetime.fromtimestamp(self.last_cleanup_time).strftime('%Y-%m-%d %H:%M:%S'),
                'warnings_count': self.resource_warning_count
            }
        except Exception as e:
            logger.error(f"获取系统综合报告时出错: {str(e)}")
            return {}

    def _test_config_system(self):
        """测试配置系统完整性"""
        try:
            required_keys = ['host', 'port', 'user', 'password'] 
            assert all(key in self.DB_CONFIG for key in required_keys), "数据库配置缺失必要参数"
            logger.info("配置系统测试通过")
        except AssertionError as e:
            logger.error(f"配置系统测试失败: {str(e)}")
            raise

    def start_basic_monitoring(self):
        """启动基础监控"""
        try:
            # 启动资源监控
            if hasattr(self.resource_monitor, 'start'):
                self.resource_monitor.start()
            
            # 启动性能监控
            if hasattr(self.performance_monitor, 'start'):
                self.performance_monitor.start()
            
            # 启动内存监控
            if hasattr(self.memory_monitor, 'start'):
                self.memory_monitor.start()
            
            logger.info("基础监控已启动")
            
        except Exception as e:
            logger.error(f"启动基础监控失败: {str(e)}")
            
    def stop_basic_monitoring(self):
        """停止基础监控"""
        try:
            # 停止资源监控
            if hasattr(self.resource_monitor, 'stop'):
                self.resource_monitor.stop()
            
            # 停止性能监控
            if hasattr(self.performance_monitor, 'stop'):
                self.performance_monitor.stop()
            
            # 停止内存监控
            if hasattr(self.memory_monitor, 'stop'):
                self.memory_monitor.stop()
            
            logger.info("基础监控已停止")
            
        except Exception as e:
            logger.error(f"停止基础监控失败: {str(e)}")

# 从cell1移入的LogDisplayManager
class LogDisplayManager:
    """日志显示管理器"""
    def __init__(self, max_lines=10):
        self.max_lines = max_lines
        self.log_buffer = []
        self.progress_bars = {i: 0.0 for i in range(6)}
        self._clear_output()
    
    def _clear_output(self):
        """清空输出"""
        clear_output(wait=True)
    
    def _display_logs(self):
        """显示日志"""
        self._clear_output()
        
        start_idx = max(0, len(self.log_buffer) - self.max_lines)
        for log in self.log_buffer[start_idx:]:
            if log.strip():
                print(log)
        
        print('-' * 80)
        
        for model_idx, progress in self.progress_bars.items():
            bar_length = 50
            filled = int(progress * bar_length)
            bar = f"Model {model_idx + 1}: [{'='*filled}{' '*(bar_length-filled)}] {progress*100:.1f}%"
            print(bar)

# 辅助类定义
class MemoryMonitor:
    """内存监控类"""
    def check_memory(self):
        """检查内存状态"""
        try:
            memory = psutil.virtual_memory()
            memory_info = {
                'total': memory.total,
                'available': memory.available,
                'used': memory.used,
                'usage_percent': memory.percent,
                'healthy': memory.percent < 90,
                'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            }
            return memory_info
        except Exception as e:
            logger.error(f"检查内存状态时出错: {str(e)}")
            return {'healthy': False, 'error': str(e)}

    def get_memory_usage(self):
        """获取当前内存使用情况"""
        try:
            return psutil.virtual_memory().percent
        except Exception as e:
            logger.error(f"获取内存使用情况时出错: {str(e)}")
            return None

class SystemCleaner:
    """系统清理类"""
    def __init__(self):
        self.temp_dirs = ['/tmp', os.path.expanduser('~/.cache')]
        self.last_cleanup = time.time()
        self.cleanup_interval = 3600  # 1小时

    def check_and_cleanup(self):
        """检查并执行清理"""
        try:
            current_time = time.time()
            if current_time - self.last_cleanup >= self.cleanup_interval:
                self._regular_cleanup()
                self.last_cleanup = current_time
        except Exception as e:
            logger.error(f"执行定期清理时出错: {str(e)}")

    def _regular_cleanup(self):
        """常规清理"""
        try:
            # 清理Python缓存
            gc.collect()
            
            # 清理TensorFlow会话
            tf.keras.backend.clear_session()
            
            # 清理临时文件
            self._cleanup_temp_files()
            
            logger.info("完成常规清理")
        except Exception as e:
            logger.error(f"执行常规清理时出错: {str(e)}")

    def _emergency_cleanup(self):
        """紧急清理"""
        try:
            # 强制垃圾回收
            gc.collect()
            
            # 清理TensorFlow会话
            tf.keras.backend.clear_session()
            
            # 清理临时文件
            self._cleanup_temp_files(emergency=True)
            
            logger.warning("完成紧急清理")
        except Exception as e:
            logger.error(f"执行紧急清理时出错: {str(e)}")

    def _cleanup_temp_files(self, emergency=False):
        """清理临时文件
        Args:
            emergency: 是否为紧急清理
        """
        for temp_dir in self.temp_dirs:
            if os.path.exists(temp_dir):
                try:
                    if emergency:
                        shutil.rmtree(temp_dir)
                        os.makedirs(temp_dir)
                    else:
                        # 只删除超过1天的文件
                        for root, dirs, files in os.walk(temp_dir):
                            for f in files:
                                path = os.path.join(root, f)
                                if time.time() - os.path.getmtime(path) > 86400:
                                    os.remove(path)
                except Exception as e:
                    logger.error(f"清理临时文件时出错: {str(e)}")

# 创建资源监控器实例（确保在类定义之后）
resource_monitor = ResourceMonitor()
# 创建性能监控器实例
performance_monitor = PerformanceMonitor()

# 保持原有全局实例
monitor_system = SystemManager()

def init_monitoring():
    """初始化监控系统"""
    monitor_system.resource_monitor.start()
    monitor_system.performance_monitor.start()
    logger.info("监控系统已启动")

def stop_monitoring():
    """停止监控系统"""
    monitor_system.resource_monitor.stop()
    monitor_system.performance_monitor.stop()
    logger.info("监控系统已停止")

class TrainingMonitor(PerformanceMonitor):
    """训练专用监控"""
    def __init__(self):
        super().__init__()
        self.metrics.update({
            'model_loss': [deque(maxlen=1000) for _ in range(6)],
            'param_values': deque(maxlen=1000)
        })
        
    def log_training_metrics(self, model_idx, loss, params):
        """记录训练指标"""
        self.metrics['model_loss'][model_idx].append(loss)
        self.metrics['param_values'].append(params)


In [4]:
#4 Data Management System / 数据管理系统
import os
import numpy as np
import pandas as pd
import logging
import pymysql
import threading
from datetime import datetime, timedelta
from collections import deque, OrderedDict
from sklearn.preprocessing import MinMaxScaler
import tensorflow as tf
from sqlalchemy.pool import QueuePool
from pymysql.cursors import DictCursor
from cell1_core import core_manager
import pickle
import time

# 获取logger实例
logger = logging.getLogger(__name__)

class DataManager:
    """统一数据管理器 - 整合数据库管理和数据管道功能"""
    _instance = None
    _lock = threading.Lock()
    
    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance
    
    def __init__(self):
        if not core_manager.initialized:
            raise RuntimeError("核心配置未初始化！请先运行cell1_core.py")
        
        # 获取配置
        try:
            # 直接从SYSTEM_CONFIG中获取DATA_CONFIG
            data_config = core_manager.SYSTEM_CONFIG['DATA_CONFIG']
        except (KeyError, AttributeError):
            logger.error("数据配置缺失，使用默认值")
            data_config = {
                'cache_size': 10000,
                'min_sequence_length': 14400,
                'normalize_range': (-1, 1)
            }
        
        # 设置数据管理器属性
        self.cache_size = data_config['cache_size']
        self.normalize_range = data_config['normalize_range']
        
        # 初始化数据库配置
        self.DB_CONFIG = core_manager.DB_CONFIG
        
        # 初始化连接池
        self.pool = self._create_pool()
        
        # 初始化数据处理组件
        self.data_pool = DataPool()
        self.data_processor = DataProcessor()
        self.data_validator = DataValidator()
        self.time_feature_extractor = TimeFeatureExtractor()
        
        # 初始化数据缓存
        self.query_cache = {}
        self.cache_timeout = 300  # 5分钟缓存超时
        
        # 初始化数据配置
        self.sequence_length = data_config['min_sequence_length']
        
        # 初始化目录
        self._init_directories()
        
        # 修改数据库连接初始化
        self.connection = None
        self._init_connection()
        
        logger.info("数据管理器初始化完成")

    def _init_db_config(self):
        """初始化数据库配置"""
        db_config = config_instance.get_db_config()
        db_config.update({
            'database': 'admin_data',
            'charset': 'utf8mb4'
        })
        return db_config

    def _init_directories(self):
        """初始化目录结构"""
        self.comparison_dir = os.path.join(core_manager.BASE_DIR, 'comparison')
        os.makedirs(self.comparison_dir, exist_ok=True)
        self.issue_file = os.path.join(self.comparison_dir, 'issue_number.txt')

    def _create_pool(self):
        """创建数据库连接池"""
        try:
            pool = QueuePool(
                creator=lambda: pymysql.connect(**self.DB_CONFIG),
                pool_size=10,
                max_overflow=20,
                timeout=30
            )
            logger.info("数据库连接池创建成功")
            return pool
        except Exception as e:
            logger.error(f"创建数据库连接池失败: {str(e)}")
            raise

    def execute_query(self, query, params=None, retry=3, use_cache=False):
        """执行SQL查询（增加数据量检查）"""
        for attempt in range(retry):
            try:
                # 检查缓存
                if use_cache:
                    cache_key = f"{query}_{str(params)}"
                    cached_result = self._get_from_cache(cache_key)
                    if cached_result is not None:
                        logger.info(f"成功获取数据 {len(cached_result)} 条")
                        return cached_result
                
                # 获取连接和游标
                connection = self.pool.connect()
                try:
                    cursor = connection.cursor(DictCursor)
                    cursor.execute(query, params)
                    result = cursor.fetchall()
                    
                    if use_cache:
                        self._update_cache(cache_key, result)
                    logger.info(f"成功获取数据 {len(result)} 条")
                    return result
                finally:
                    cursor.close()
                    connection.close()
                
            except pymysql.OperationalError as e:
                if attempt < retry - 1:
                    logger.warning(f"数据库连接失败，尝试重连({attempt+1}/{retry})")
                    self.pool = self._create_pool()  # 重建连接池
                    time.sleep(2 ** attempt)
                    continue
                raise
        return None

    def execute_batch(self, query, params_list):
        """批量执行查询"""
        try:
            connection = self.pool.connect()
            cursor = connection.cursor()
            
            try:
                cursor.executemany(query, params_list)
                connection.commit()
                return True
            finally:
                cursor.close()
                connection.close()
                
        except Exception as e:
            logger.error(f"批量执行查询失败: {str(e)}")
            return False

    def get_records_by_issue(self, start_issue, limit):
        """按期号范围获取记录"""
        query = f"""
            SELECT * FROM admin_tab 
            WHERE date_period >= %s
            ORDER BY date_period ASC
            LIMIT %s
        """
        return self.execute_query(query, (start_issue, limit))

    def close_all(self):
        """关闭所有数据库连接"""
        if self.pool:
            self.pool.dispose()
            logger.info("已关闭所有数据库连接")

    def get_data_stats(self):
        """获取数据统计信息"""
        try:
            with self.lock:
                return {
                    'total_samples': len(self.data_pool.data),
                    'cache_size': self.data_pool.get_cache_size(),
                    'last_update': self.data_pool.last_update_time,
                    'memory_usage': self._get_memory_usage(),
                    'database_connections': self.pool.size()
                }
        except Exception as e:
            logger.error(f"获取数据统计信息时出错: {str(e)}")
            return None

    def _get_memory_usage(self):
        """获取内存使用情况"""
        try:
            import psutil
            process = psutil.Process()
            return process.memory_info().rss / (1024 * 1024)  # 转换为MB
        except:
            return None

    def get_training_batch(self, batch_size=None):
        """获取训练批次"""
        try:
            batch_size = batch_size or self.batch_size
            with self._lock:
                # 获取原始数据
                data = self.data_pool.get_latest_data()
                # 处理数据
                processed = self.data_processor.process_records(data)
                # 验证数据
                if not self.data_validator.validate(processed):
                    raise ValueError("数据验证失败")
                return processed
        except Exception as e:
            logger.error(f"获取训练批次时出错: {str(e)}")
            return None

    def update_data(self):
        """更新数据"""
        try:
            new_data = self._fetch_new_data()
            if new_data is not None:
                with self._lock:
                    self.data_pool.update_data(new_data)
                logger.info(f"数据更新成功，当前数据量: {len(self.data_pool.data)}")
                return True
            return False
        except Exception as e:
            logger.error(f"更新数据时出错: {str(e)}")
            return False

    def _fetch_new_data(self):
        """获取新数据"""
        try:
            with self.issue_lock:
                # 读取最后一期期号
                with open(self.issue_file, 'r+') as f:
                    last_issue = f.read().strip()
                    
                    # 获取总数据量
                    total = config_instance.SYSTEM_CONFIG['SAMPLE_CONFIG']['total_fetch']()
                    
                    # 构建查询
                    query = f"""
                        SELECT date_period, number 
                        FROM admin_tab 
                        WHERE date_period {'>' if last_issue else ''}= '{last_issue}'
                        ORDER BY date_period 
                        LIMIT {total}
                    """
                    
                    records = self.execute_query(query)
                    
                    # 验证数据连续性
                    if not self._validate_sequence(records):
                        raise ValueError("数据存在断层")
                    
                    # 更新期号文件
                    if records:
                        new_last = records[-1]['date_period']
                        f.seek(0)
                        f.write(new_last)
                        f.truncate()
                    
                    return self._process_numbers(records)
                    
        except Exception as e:
            logger.error(f"获取新数据失败: {str(e)}")
            return None

    def _validate_sequence(self, records):
        """验证数据序列连续性"""
        if not records or len(records) < 2:
            return True
            
        max_gap = config_instance.SYSTEM_CONFIG['max_sequence_gap']
        
        for i in range(1, len(records)):
            current = records[i]['date_period']
            previous = records[i-1]['date_period']
            if not self._is_consecutive_periods(previous, current, max_gap):
                return False
        return True

    def _is_consecutive_periods(self, prev_period, curr_period, max_gap):
        """检查两个期号是否连续"""
        try:
            prev_date, prev_num = prev_period.split('-')
            curr_date, curr_num = curr_period.split('-')
            
            prev_dt = datetime.strptime(prev_date, '%Y%m%d')
            curr_dt = datetime.strptime(curr_date, '%Y%m%d')
            
            if prev_dt == curr_dt:
                return int(curr_num) - int(prev_num) <= max_gap
            
            if curr_dt - prev_dt == timedelta(days=1):
                return int(prev_num) == 1440 and int(curr_num) == 1
                
            return False
            
        except Exception as e:
            logger.error(f"检查期号连续性时出错: {str(e)}")
            return False

    def _process_numbers(self, records):
        """处理数字号码"""
        try:
            processed = []
            for r in records:
                numbers = [int(d) for d in r['number'].zfill(5)]
                processed.append({
                    'date_period': r['date_period'],
                    'numbers': numbers,
                    'time_features': self.time_feature_extractor.extract_features(r['date_period'])
                })
            return processed
        except Exception as e:
            logger.error(f"处理号码时出错: {str(e)}")
            return None

    def _get_from_cache(self, key):
        """从缓存获取数据"""
        if key in self.query_cache:
            timestamp, data = self.query_cache[key]
            if datetime.now() - timestamp < timedelta(seconds(self.cache_timeout)):
                return data
            del self.query_cache[key]
        return None

    def _update_cache(self, key, data):
        """更新缓存"""
        self.query_cache[key] = (datetime.now(), data)

    def clear_cache(self):
        """清空所有数据缓存"""
        with self.lock:  # 使用锁保证线程安全
            self.query_cache.clear()
            self.data_pool.data.clear()
            logger.info("已清空所有数据缓存")

    def get_data_by_date_range(self, start_date, end_date):
        """按日期范围获取数据"""
        try:
            query = """
                SELECT * FROM admin_tab 
                WHERE DATE(SUBSTRING_INDEX(date_period, '-', 1)) 
                BETWEEN %s AND %s
                ORDER BY date_period
            """
            return self.execute_query(query, (start_date, end_date))
        except Exception as e:
            logger.error(f"获取日期范围数据失败: {str(e)}")
            return None

    def check_data_continuity(self, data):
        """检查数据连续性"""
        try:
            if not data or len(data) < 2:
                return True
                
            periods = [d['date_period'] for d in data]
            for i in range(1, len(periods)):
                curr_period = periods[i]
                prev_period = periods[i-1]
                
                if not self._is_consecutive_periods(prev_period, curr_period):
                    logger.warning(f"数据不连续: {prev_period} -> {curr_period}")
                    return False
            return True
            
        except Exception as e:
            logger.error(f"检查数据连续性失败: {str(e)}")
            return False

    def generate_test_batch(self, size=1000):
        """生成测试批次"""
        try:
            with self._lock:
                latest_data = self.data_pool.get_latest_data(size)
                if not latest_data:
                    return None
                    
                test_data = self.data_processor.process_records(latest_data)
                if not self.data_validator.validate(test_data):
                    raise ValueError("测试数据验证失败")
                    
                return test_data
                
        except Exception as e:
            logger.error(f"生成测试批次失败: {str(e)}")
            return None

    def save_cache_to_disk(self, cache_path=None):
        """保存缓存到磁盘"""
        try:
            if cache_path is None:
                cache_path = os.path.join(config_instance.BASE_DIR, 'cache', 'data_cache.pkl')
                
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            
            cache_data = {
                'pool_data': self.data_pool.data,
                'query_cache': self.query_cache,
                'last_update': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            }
            
            with open(cache_path, 'wb') as f:
                pickle.dump(cache_data, f)
                
            logger.info(f"缓存已保存到: {cache_path}")
            return True
            
        except Exception as e:
            logger.error(f"保存缓存失败: {str(e)}")
            return False

    def load_training_samples(self, limit=1000):
        """加载训练样本"""
        try:
            # 修正字段名
            query = "SELECT number, date_period FROM admin_tab ORDER BY date_period DESC LIMIT %s"
            raw_data = self.execute_query(query, (limit,))
            
            if not raw_data:
                logger.error("未查询到任何数据")
                return None
                
            processed = self.data_processor.process_training_data(raw_data)
            if processed is None:
                return None
                
            sequences = self._create_sequences(processed, self.sequence_length)
            if sequences is None or len(sequences) == 0:
                logger.error("无法创建有效序列，可能原因：\n"
                             "1. 输入数据长度不足\n"
                             f"2. 序列长度设置过大（当前设置：{self.sequence_length}）\n"
                             "3. 数据预处理失败")
                return None
                
            logger.info(f"成功加载 {len(sequences)} 个训练样本")
            return sequences
            
        except Exception as e:
            logger.error(f"加载训练样本失败: {str(e)}")
            return None

    def _create_sequences(self, data, seq_length):
        """创建时间序列数据"""
        try:
            sequences = []
            for i in range(len(data) - seq_length + 1):  # 修正循环范围
                seq = data[i:i+seq_length]
                sequences.append(seq)
            return np.array(sequences)
        except Exception as e:
            logger.error(f"创建序列失败: {str(e)}")
            return None

    def process_training_data(self, raw_data):
        """处理原始训练数据"""
        try:
            if not raw_data:
                logger.error("原始数据为空")
                return None
                
            processed = []
            for idx, record in enumerate(raw_data):
                try:
                    # 验证必要字段存在
                    if 'number' not in record or 'date_period' not in record:
                        logger.warning(f"记录{idx}缺失必要字段，已跳过")
                        continue
                        
                    # 处理数字格式：将字符串拆分为单个字符
                    number_str = str(record['number']).strip()
                    if len(number_str) != 5:
                        logger.warning(f"记录{idx}号码长度错误: {number_str}")
                        continue
                        
                    numbers = [int(c) for c in number_str if c.isdigit()]
                    if len(numbers) != 5:
                        logger.warning(f"记录{idx}包含非数字字符: {number_str}")
                        continue
                        
                    time_feat = self._parse_time_features(str(record['date_period']))
                    processed.append(numbers + time_feat)
                    
                except Exception as e:
                    logger.warning(f"处理记录{idx}时出错: {str(e)}，已跳过该记录")

            if not processed:
                logger.error("无有效数据可处理")
                return None

            scaler = MinMaxScaler()
            return scaler.fit_transform(processed)
            
        except Exception as e:
            logger.error(f"数据处理失败: {str(e)}")
            return None

    def _parse_time_features(self, date_period):
        """解析时间特征"""
        date_str, period = date_period.split('-')
        dt = datetime.strptime(date_str, "%Y%m%d")
        return [
            dt.hour/24.0,  # 小时归一化
            dt.weekday()/7.0,  # 星期归一化
            int(period)/1440.0  # 期号归一化
        ]

    def check_connection(self):
        """检查数据库连接状态"""
        try:
            result = self.execute_query("SELECT 1", retry=1)
            return bool(result)
        except Exception as e:
            logger.error(f"数据库连接检查失败: {str(e)}")
            return False

    def get_sequence(self, start_issue, end_issue):
        """获取指定期号范围的序列数据"""
        try:
            # 确保连接有效
            if not self.connection or not self.connection.open:
                self._init_connection()
            
            # 使用新的连接执行查询
            with self.connection.cursor() as cursor:
                query = """
                    SELECT number, date_period 
                    FROM admin_tab 
                    WHERE date_period >= %s AND date_period <= %s
                    ORDER BY date_period ASC
                """
                cursor.execute(query, (start_issue, end_issue))
                results = cursor.fetchall()
                
                if not results:
                    logger.warning(f"未找到期号范围 {start_issue} 到 {end_issue} 的数据")
                    return None
                    
                # 转换数据格式
                sequence = []
                for row in results:
                    numbers = [int(n) for n in str(row['number']).zfill(5)]
                    normalized = [(n - 4.5) / 4.5 for n in numbers]
                    sequence.append(normalized)
                    
                sequence = np.array(sequence, dtype=np.float32)
                logger.info(f"获取到序列数据 {len(sequence)} 条")
                return sequence
                
        except pymysql.Error as e:
            logger.error(f"数据库操作失败: {str(e)}")
            self._init_connection()  # 尝试重新连接
            return None
        except Exception as e:
            logger.error(f"获取序列数据失败: {str(e)}")
            return None

    def _init_connection(self):
        """初始化数据库连接"""
        try:
            if self.connection and self.connection.open:
                self.connection.close()
                
            # 重新创建连接
            self.connection = pymysql.connect(
                host=self.DB_CONFIG['host'],
                port=self.DB_CONFIG['port'],
                user=self.DB_CONFIG['user'],
                password=self.DB_CONFIG['password'],
                database=self.DB_CONFIG['database'],
                charset=self.DB_CONFIG['charset'],
                cursorclass=pymysql.cursors.DictCursor,
                autocommit=True,
                connect_timeout=60
            )
            logger.info("数据库连接初始化成功")
        except Exception as e:
            logger.error(f"数据库连接初始化失败: {str(e)}")
            self.connection = None
            raise

class DataPool:
    """数据池 - 负责数据缓存和管理"""
    def __init__(self, max_size=10000):
        self.data = []
        self.cache = OrderedDict()
        self.max_size = max_size
        self.lock = threading.Lock()
        self.last_update_time = None
        
        # 初始化数据缩放器
        self.scaler = MinMaxScaler()
        self.is_scaler_fitted = False
    
    def update_data(self, new_data):
        """更新数据"""
        with self.lock:
            self.data.extend(new_data)
            self.last_update_time = datetime.now()
            
            # 如果还没有拟合scaler，进行拟合
            if not self.is_scaler_fitted and len(self.data) > 0:
                self.scaler.fit(np.array([d['numbers'] for d in self.data]))
                self.is_scaler_fitted = True
    
    def get_latest_data(self, n=1000):
        """获取最新的n条数据"""
        with self.lock:
            return self.data[-n:] if self.data else []
    
    def get_cache_size(self):
        """获取缓存大小"""
        return len(self.cache)

    def clear_cache(self):
        """清理数据缓存"""
        with self.lock:
            self.cache.clear()

    def add_batch(self, batch):
        """添加数据批次"""
        # 添加对齐检查
        aligned = self._align_sequences(batch)
        with self.lock:
            self.data.extend(aligned)
        
    def _align_sequences(self, batch):
        """对齐序列长度"""
        max_len = max(len(item['input']) for item in batch)
        aligned = []
        for item in batch:
            aligned_item = item.copy()
            aligned_item['input'] = np.pad(
                item['input'], 
                (0, max_len - len(item['input'])),
                'constant'
            )
            aligned.append(aligned_item)
        return aligned

    def get_training_data(self, sequence_length):
        """获取训练数据"""
        with self.lock:
            if len(self.data) < sequence_length:
                return None
            data = np.array([d['numbers'] for d in self.data])
            if self.is_scaler_fitted:
                data = self.scaler.transform(data)
            return data

    def get_data_window(self, start_idx, window_size):
        """获取指定窗口的数据"""
        with self.lock:
            if start_idx + window_size > len(self.data):
                return None
            return self.data[start_idx:start_idx + window_size]

    def get_latest_periods(self, n_periods):
        """获取最近n期数据"""
        with self.lock:
            return self.data[-n_periods:] if len(self.data) >= n_periods else None

    def preload_data(self, start_date, end_date):
        """预加载指定日期范围的数据"""
        try:
            query = """
                SELECT * FROM admin_tab 
                WHERE DATE(SUBSTRING_INDEX(date_period, '-', 1)) 
                BETWEEN %s AND %s
                ORDER BY date_period
            """
            records = data_manager.execute_query(query, (start_date, end_date))
            
            if records:
                self.update_data(records)
                return True
            return False
            
        except Exception as e:
            logger.error(f"预加载数据失败: {str(e)}")
            return False

class DataProcessor:
    """数据处理器 - 负责数据预处理和批次生成"""
    def __init__(self):
        self.scaler = MinMaxScaler(feature_range=(-1, 1))
        self.lock = threading.Lock()
        # 添加时间特征提取器
        self.time_feature_extractor = TimeFeatureExtractor()
    
    def process_records(self, records):
        """处理数据记录"""
        try:
            # 1. 清理数据
            cleaned = self._remove_invalid_data(records)
            
            # 2. 特征工程
            features = self._extract_features(cleaned)
            
            # 3. 数据标准化
            normalized = self._normalize_features(features)
            
            # 4. 序列化处理
            sequences = self._create_sequences(normalized)
            
            return sequences
            
        except Exception as e:
            logger.error(f"处理数据记录时出错: {str(e)}")
            return None
    
    def _remove_invalid_data(self, records):
        """移除无效数据"""
        return [r for r in records if self._is_valid_record(r)]
    
    def _is_valid_record(self, record):
        """检查记录是否有效"""
        try:
            numbers = record['numbers']
            return (
                isinstance(numbers, list) and
                len(numbers) == 5 and
                all(isinstance(n, int) and 0 <= n <= 9 for n in numbers)
            )
        except:
            return False
    
    def _extract_features(self, records):
        """提取特征"""
        features = []
        for record in records:
            # 基础数字特征
            number_features = np.array(record['numbers'])
            # 时间特征
            time_features = self.time_feature_extractor.extract_features(record['date_period'])
            # 合并特征
            combined = np.concatenate([number_features, time_features])
            features.append(combined)
        return np.array(features)
    
    def _normalize_features(self, features):
        """标准化特征"""
        with self.lock:
            return self.scaler.fit_transform(features)
    
    def _create_sequences(self, data):
        """创建序列数据"""
        sequences = []
        for i in range(len(data) - config_instance.SYSTEM_CONFIG['SAMPLE_CONFIG']['total_fetch']()):
            seq = {
                'input': data[i:i+config_instance.SYSTEM_CONFIG['SAMPLE_CONFIG']['input_length']],
                'target': data[i+config_instance.SYSTEM_CONFIG['SAMPLE_CONFIG']['input_length']:
                              i+config_instance.SYSTEM_CONFIG['SAMPLE_CONFIG']['total_fetch']()]
            }
            sequences.append(seq)
        return sequences

    def process_with_time_features(self, data):
        """处理数据并添加时间特征"""
        try:
            processed = []
            for record in data:
                # 基础特征
                features = self._extract_base_features(record)
                # 时间特征
                time_features = self.time_feature_extractor.extract_features(record['date_period'])
                # 组合特征
                combined = np.concatenate([features, time_features])
                processed.append(combined)
            return np.array(processed)
        except Exception as e:
            logger.error(f"处理时间特征时出错: {str(e)}")
            return None

    def _extract_base_features(self, record):
        """提取基础特征"""
        try:
            numbers = np.array(record['numbers'])
            # 添加统计特征
            stats = [
                np.mean(numbers),
                np.std(numbers),
                np.max(numbers),
                np.min(numbers)
            ]
            return np.concatenate([numbers, stats])
        except Exception as e:
            logger.error(f"提取基础特征时出错: {str(e)}")
            return None

    def apply_feature_scaling(self, data, feature_range=(-1, 1)):
        """应用特征缩放"""
        try:
            self.scaler = MinMaxScaler(feature_range=feature_range)
            return self.scaler.fit_transform(data)
        except Exception as e:
            logger.error(f"特征缩放失败: {str(e)}")
            return None

    def create_sliding_windows(self, data, window_size, stride=1):
        """创建滑动窗口数据"""
        try:
            windows = []
            for i in range(0, len(data) - window_size + 1, stride):
                windows.append(data[i:i + window_size])
            return np.array(windows)
        except Exception as e:
            logger.error(f"创建滑动窗口失败: {str(e)}")
            return None

    def batch_normalize(self, batches):
        """批量数据标准化"""
        try:
            normalized_batches = []
            for batch in batches:
                normalized = self._normalize_features(batch)
                normalized_batches.append(normalized)
            return normalized_batches
            
        except Exception as e:
            logger.error(f"批量标准化失败: {str(e)}")
            return None

    def process_training_data(self, raw_data):
        """处理原始训练数据"""
        try:
            if not raw_data:
                logger.error("原始数据为空")
                return None
                
            processed = []
            for idx, record in enumerate(raw_data):
                try:
                    # 验证必要字段存在
                    if 'number' not in record or 'date_period' not in record:
                        logger.warning(f"记录{idx}缺失必要字段，已跳过")
                        continue
                        
                    # 处理数字格式：将字符串拆分为单个字符
                    number_str = str(record['number']).strip()
                    if len(number_str) != 5:
                        logger.warning(f"记录{idx}号码长度错误: {number_str}")
                        continue
                        
                    numbers = [int(c) for c in number_str if c.isdigit()]
                    if len(numbers) != 5:
                        logger.warning(f"记录{idx}包含非数字字符: {number_str}")
                        continue
                        
                    time_feat = self._parse_time_features(str(record['date_period']))
                    processed.append(numbers + time_feat)
                    
                except Exception as e:
                    logger.warning(f"处理记录{idx}时出错: {str(e)}，已跳过该记录")

            if not processed:
                logger.error("无有效数据可处理")
                return None

            scaler = MinMaxScaler()
            return scaler.fit_transform(processed)
            
        except Exception as e:
            logger.error(f"数据处理失败: {str(e)}")
            return None

    def _parse_time_features(self, date_period):
        """解析时间特征"""
        date_str, period = date_period.split('-')
        dt = datetime.strptime(date_str, "%Y%m%d")
        return [
            dt.hour/24.0,  # 小时归一化
            dt.weekday()/7.0,  # 星期归一化
            int(period)/1440.0  # 期号归一化
        ]

class DataValidator:
    """数据验证器 - 负责数据有效性检查"""
    def __init__(self):
        self.lock = threading.Lock()
        self.validation_rules = {
            'sequence_length': self._check_sequence_length,
            'number_range': self._check_number_range,
            'time_continuity': self._check_time_continuity,
            'feature_completeness': self._check_feature_completeness
        }
    
    def validate(self, data):
        """验证数据有效性"""
        try:
            with self.lock:
                return all(
                    rule(data) for rule in self.validation_rules.values()
                )
        except Exception as e:
            logger.error(f"数据验证时出错: {str(e)}")
            return False
    
    def _check_sequence_length(self, data):
        """检查序列长度"""
        required_length = config_instance.SYSTEM_CONFIG['SAMPLE_CONFIG']['total_fetch']()
        return len(data) >= required_length
    
    def _check_number_range(self, data):
        """检查数字范围"""
        try:
            numbers = np.array([d['numbers'] for d in data])
            return np.all((numbers >= 0) & (numbers <= 9))
        except:
            return False
    
    def _check_time_continuity(self, data):
        """检查时间连续性"""
        try:
            periods = [d['date_period'] for d in data]
            for i in range(1, len(periods)):
                if not self._is_consecutive_periods(periods[i-1], periods[i]):
                    return False
            return True
        except:
            return False
    
    def _check_feature_completeness(self, data):
        """检查特征完整性"""
        try:
            return all(
                'numbers' in d and 'time_features' in d 
                for d in data
            )
        except:
            return False
    
    def _is_consecutive_periods(self, prev, curr):
        """检查期号是否连续"""
        try:
            p_date, p_num = prev.split('-')
            c_date, c_num = curr.split('-')
            
            if p_date == c_date:
                return int(c_num) - int(p_num) == 1
            
            p_dt = datetime.strptime(p_date, '%Y%m%d')
            c_dt = datetime.strptime(c_date, '%Y%m%d')
            
            return (c_dt - p_dt).days == 1 and int(p_num) == 1440 and int(c_num) == 1
            
        except:
            return False

    def validate_data_completeness(self, data):
        """验证数据完整性"""
        try:
            # 检查数据结构
            if not isinstance(data, (list, np.ndarray)):
                return False
                
            # 检查数据量
            if len(data) < config_instance.SYSTEM_CONFIG['SAMPLE_CONFIG']['input_length']:
                return False
                
            # 检查每条记录的完整性
            for record in data:
                if not self._check_record_completeness(record):
                    return False
                    
            return True
            
        except Exception as e:
            logger.error(f"验证数据完整性时出错: {str(e)}")
            return False
    
    def _check_record_completeness(self, record):
        """检查单条记录的完整性"""
        required_fields = ['date_period', 'numbers', 'time_features']
        return all(field in record for field in required_fields)

    def validate_batch_structure(self, batch):
        """验证批次数据结构"""
        try:
            if not isinstance(batch, dict):
                return False
                
            required_keys = ['input', 'target']
            if not all(key in batch for key in required_keys):
                return False
                
            input_shape = batch['input'].shape
            target_shape = batch['target'].shape
            
            if len(input_shape) != 3 or len(target_shape) != 3:
                return False
                
            return True
            
        except Exception as e:
            logger.error(f"验证批次结构失败: {str(e)}")
            return False

    def _check_number_format(self, number_str):
        """验证号码格式是否正确"""
        try:
            parts = number_str.split(',')
            if len(parts) != 5:
                return False
            return all(n.strip().isdigit() and 0 <= int(n) <= 9 for n in parts)
        except:
            return False

    def process_training_data(self, raw_data):
        """处理原始训练数据"""
        try:
            if not raw_data:
                logger.error("原始数据为空")
                return None
                
            processed = []
            for idx, record in enumerate(raw_data):
                try:
                    # 验证必要字段存在
                    if 'number' not in record or 'date_period' not in record:
                        logger.warning(f"记录{idx}缺失必要字段，已跳过")
                        continue
                        
                    # 处理数字格式：将字符串拆分为单个字符
                    number_str = str(record['number']).strip()
                    if len(number_str) != 5:
                        logger.warning(f"记录{idx}号码长度错误: {number_str}")
                        continue
                        
                    numbers = [int(c) for c in number_str if c.isdigit()]
                    if len(numbers) != 5:
                        logger.warning(f"记录{idx}包含非数字字符: {number_str}")
                        continue
                        
                    time_feat = self._parse_time_features(str(record['date_period']))
                    processed.append(numbers + time_feat)
                    
                    if not self._check_number_format(str(record['number'])):
                        logger.warning(f"记录{idx}号码格式错误: {record['number']}")
                        continue
                    
                except Exception as e:
                    logger.warning(f"处理记录{idx}时出错: {str(e)}，已跳过该记录")

            if not processed:
                logger.error("无有效数据可处理")
                return None

            scaler = MinMaxScaler()
            return scaler.fit_transform(processed)
            
        except Exception as e:
            logger.error(f"数据处理失败: {str(e)}")
            return None

    def _parse_time_features(self, date_period):
        """解析时间特征"""
        date_str, period = date_period.split('-')
        dt = datetime.strptime(date_str, "%Y%m%d")
        return [
            dt.hour/24.0,  # 小时归一化
            dt.weekday()/7.0,  # 星期归一化
            int(period)/1440.0  # 期号归一化
        ]

class TimeFeatureExtractor:
    """时间特征提取器"""
    def __init__(self):
        self.periodic_features = {
            'hour_of_day': (24, lambda dt: dt.hour),
            'minute_of_hour': (60, lambda dt: dt.minute),
            'day_of_week': (7, lambda dt: dt.weekday()),
            'day_of_month': (31, lambda dt: dt.day - 1),
            'month_of_year': (12, lambda dt: dt.month - 1)
        }
    
    def extract_features(self, date_period):
        """提取时间特征"""
        try:
            # 解析日期和期号
            date_str, period = date_period.split('-')
            date = datetime.strptime(date_str, '%Y%m%d')
            period_num = int(period)
            
            features = []
            
            # 添加周期性特征
            for period, func in self.periodic_features.values():
                value = func(date)
                # 转换为sin和cos特征以保持周期性
                sin_value = np.sin(2 * np.pi * value / period)
                cos_value = np.cos(2 * np.pi * value / period)
                features.extend([sin_value, cos_value])
            
            # 添加期号特征
            period_feature = (period_num - 1) / 1440  # 归一化到0-1
            features.append(period_feature)
            
            return np.array(features)
            
        except Exception as e:
            logger.error(f"提取时间特征时出错: {str(e)}")
            return np.zeros(len(self.periodic_features) * 2 + 1)  # 返回全零特征

class EnhancedDataManager(DataManager):
    """增强型数据管理器，支持流式数据加载"""
    def __init__(self):
        super().__init__()
        self.data_buffer = deque(maxlen=14400*2)  # 双倍缓冲
        self.last_processed = None
        
    def stream_training_samples(self):
        """实时数据流生成器"""
        while True:
            # 获取最新14400+2880期数据
            latest = self.execute_query(
                "SELECT number, date_period FROM admin_tab "
                "ORDER BY date_period DESC LIMIT 17280"
            )
            if latest and latest != self.last_processed:
                processed = self.data_processor.process_streaming_data(latest)
                if processed is not None:
                    sequences = self._create_sequences(processed, 14400)
                    if sequences:
                        yield sequences[0]  # 取最新序列
                        self.last_processed = latest
                time.sleep(58)  # 每58秒检查一次

# 创建全局实例
data_manager = DataManager()


[2025-02-18 15:17:57,267] INFO - 配置系统初始化完成
[2025-02-18 15:17:57,267] INFO - 配置系统初始化完成
[2025-02-18 15:17:57,269] INFO - 配置系统测试通过
[2025-02-18 15:17:57,269] INFO - 配置系统测试通过
[2025-02-18 15:17:57,271] INFO - 核心管理器初始化完成
[2025-02-18 15:17:57,271] INFO - 核心管理器初始化完成


In [5]:
#5 Feature Engineering System / 特征工程系统
import tensorflow as tf
import numpy as np
import logging
from tensorflow.keras.layers import Conv1D, Dense, Lambda
from sklearn.preprocessing import MinMaxScaler
from typing import Optional

# 获取logger实例
logger = logging.getLogger(__name__)

class FeatureEngineering:
    """特征工程类"""
    
    def __init__(self) -> None:
        """初始化特征工程组件"""
        self.logger = logging.getLogger(__name__)
        self.scaler: Optional[MinMaxScaler] = None
    
    def build_all_features(self, x):
        """构建所有特征"""
        try:
            # 1. 基础特征
            basic_features = self._build_basic_features(x)
            
            # 2. 高级特征
            advanced_features = self._build_advanced_features(x)
            
            # 3. 数字特征
            digit_features = self._build_advanced_digit_features(x)
            
            # 4. 形态特征
            pattern_features = self._build_pattern_features(x)
            
            # 5. 特征融合
            all_features = tf.keras.layers.Concatenate()([
                basic_features,
                advanced_features,
                digit_features,
                pattern_features
            ])
            
            return all_features
            
        except Exception as e:
            self.logger.error(f"构建特征时出错: {str(e)}")
            return x

    def _build_advanced_features(self, x):
        """构建高级特征分析"""
        features = []
        
        # 1. 冷热号分析
        hot_cold = self._analyze_hot_cold_numbers(x)
        features.append(hot_cold)
        
        # 2. 号码频率统计
        freq = self._analyze_frequency(x)
        features.append(freq)
        
        # 3. 和值分析
        sum_features = self._analyze_sum_value(x)
        features.append(sum_features)
        
        # 4. 数字特征分析
        digit_features = self._analyze_digit_patterns(x)
        features.append(digit_features)
        
        # 5. 形态分析
        pattern_features = self._analyze_number_patterns(x)
        features.append(pattern_features)
        
        # 6. 012路分析
        route_features = self._analyze_012_routes(x)
        features.append(route_features)
        
        return tf.keras.layers.Concatenate()(features)

    def _analyze_hot_cold_numbers(self, x, window_sizes=[100, 500, 1000]):
        """分析冷热号"""
        features = []
        
        for window in window_sizes:
            # 最近window期的频率
            recent = x[:, -window:]
            freq = tf.reduce_sum(tf.one_hot(tf.cast(recent, tf.int32), 10), axis=1)
            features.append(freq)
        
        return tf.keras.layers.Concatenate()(features)

    def _analyze_digit_patterns(self, x):
        """分析数字特征"""
        # 1. 奇偶比
        odd_even = tf.reduce_sum(tf.cast(x % 2 == 1, tf.float32), axis=-1, keepdims=True)
        
        # 2. 大小比 (5-9为大)
        big_small = tf.reduce_sum(tf.cast(x >= 5, tf.float32), axis=-1, keepdims=True)
        
        # 3. 质合比
        prime_numbers = tf.constant([2, 3, 5, 7])
        is_prime = tf.reduce_sum(tf.cast(
            tf.equal(x[..., None], prime_numbers), tf.float32
        ), axis=-1)
        prime_composite = tf.reduce_sum(is_prime, axis=-1, keepdims=True)
        
        # 4. 跨度
        span = tf.reduce_max(x, axis=-1) - tf.reduce_min(x, axis=-1)
        
        return tf.concat([odd_even, big_small, prime_composite, span[..., None]], axis=-1)

    def _analyze_sum_value(self, x):
        """分析和值特征"""
        # 1. 计算和值
        sum_value = tf.reduce_sum(x, axis=-1, keepdims=True)
        
        # 2. 和值分布区间
        sum_ranges = [
            (0, 10), (11, 20), (21, 30), (31, 40), (41, 45)
        ]
        sum_dist = []
        for low, high in sum_ranges:
            in_range = tf.logical_and(
                sum_value >= low,
                sum_value <= high
            )
            sum_dist.append(tf.cast(in_range, tf.float32))
        
        # 3. 和值特征
        sum_features = tf.concat([sum_value, tf.concat(sum_dist, axis=-1)], axis=-1)
        
        return sum_features

    def _analyze_012_routes(self, x):
        """分析012路特征"""
        # 1. 计算每个数字的路数
        routes = tf.math.floormod(x, 3)  # 对3取余
        
        # 2. 统计每个位置的路数分布
        route_distributions = []
        for i in range(5):  # 五个位置
            digit_routes = routes[..., i:i+1]
            # 统计0,1,2路的数量
            route_counts = []
            for r in range(3):
                count = tf.reduce_sum(
                    tf.cast(digit_routes == r, tf.float32),
                    axis=-1, keepdims=True
                )
                route_counts.append(count)
            route_distributions.append(tf.concat(route_counts, axis=-1))
        
        # 3. 计算整体012路比例
        total_route_dist = tf.reduce_mean(tf.concat(route_distributions, axis=-1), axis=-1, keepdims=True)
        
        # 4. 分析路数组合特征
        route_patterns = self._analyze_route_patterns(routes)
        
        # 5. 计算相邻位置的路数关系
        route_transitions = []
        for i in range(4):
            transition = tf.cast(
                routes[..., i:i+1] == routes[..., i+1:i+2],
                tf.float32
            )
            route_transitions.append(transition)
        
        # 6. 特征组合
        features = [
            *route_distributions,  # 每位路数分布
            total_route_dist,     # 整体路数比例
            route_patterns,       # 路数组合特征
            *route_transitions    # 相邻位置路数关系
        ]
        
        return tf.concat(features, axis=-1)

    def _analyze_route_patterns(self, routes):
        """分析路数组合模式"""
        # 1. 全0路
        all_zero = tf.reduce_all(routes == 0, axis=-1, keepdims=True)
        
        # 2. 全1路
        all_one = tf.reduce_all(routes == 1, axis=-1, keepdims=True)
        
        # 3. 全2路
        all_two = tf.reduce_all(routes == 2, axis=-1, keepdims=True)
        
        # 4. 012路是否均匀分布(各有至少一个)
        has_zero = tf.reduce_any(routes == 0, axis=-1, keepdims=True)
        has_one = tf.reduce_any(routes == 1, axis=-1, keepdims=True)
        has_two = tf.reduce_any(routes == 2, axis=-1, keepdims=True)
        balanced = tf.logical_and(
            tf.logical_and(has_zero, has_one),
            has_two
        )
        
        # 5. 主路特征(出现最多的路数)
        route_counts = []
        for r in range(3):
            count = tf.reduce_sum(
                tf.cast(routes == r, tf.float32),
                axis=-1, keepdims=True
            )
            route_counts.append(count)
        main_route = tf.argmax(tf.concat(route_counts, axis=-1), axis=-1)
        
        return tf.concat([
            tf.cast(all_zero, tf.float32),
            tf.cast(all_one, tf.float32),
            tf.cast(all_two, tf.float32),
            tf.cast(balanced, tf.float32),
            tf.cast(main_route, tf.float32)[..., tf.newaxis]
        ], axis=-1)

    def _build_advanced_digit_features(self, x):
        """构建高级数字特征"""
        features = []
        
        # 1. 每位数字的独立特征
        for i in range(5):
            digit = x[..., i:i+1]  # 提取第i位数字
            
            # 数字频率统计
            freq = tf.keras.layers.Lambda(
                lambda x: tf.cast(tf.histogram_fixed_width(x, [0, 9], nbins=10), tf.float32)
            )(digit)
            
            # 数字转换模式
            transitions = self._build_digit_transitions(digit)
            
            features.extend([freq, transitions])
        
        # 2. 数字组合特征
        for i in range(5):
            for j in range(i+1, 5):
                # 两位数字组合
                pair = tf.stack([x[..., i], x[..., j]], axis=-1)
                pair_features = self._build_pair_features(pair)
                features.append(pair_features)
        
        # 3. 完整号码特征
        full_number = tf.reshape(x, (-1, x.shape[1]))  # 将5位数字合并为一个完整号码
        number_features = self._build_number_features(full_number)
        features.append(number_features)
        
        return tf.keras.layers.Concatenate()(features)

    def _build_digit_transitions(self, digit):
        """构建数字转换特征"""
        # 计算相邻数字之间的转换
        transitions = digit[:, 1:] - digit[:, :-1]
        # 转换为one-hot编码
        transitions_one_hot = tf.one_hot(tf.cast(transitions + 9, tf.int32), 19)  # -9到9共19种可能
        return tf.reduce_mean(transitions_one_hot, axis=1)

    def _build_pair_features(self, pair):
        """构建数字对特征"""
        # 计算数字对的差值
        diff = tf.abs(pair[..., 0] - pair[..., 1])
        # 计算数字对的和
        sum_pair = pair[..., 0] + pair[..., 1]
        # 计算数字对的乘积
        prod = pair[..., 0] * pair[..., 1]
        
        return tf.stack([diff, sum_pair, prod], axis=-1)

    def _build_number_features(self, numbers):
        """构建完整号码特征"""
        # 1. 计算整体统计特征
        mean = tf.reduce_mean(numbers, axis=-1, keepdims=True)
        std = tf.math.reduce_std(numbers, axis=-1, keepdims=True)
        
        # 2. 计算号码的数字频率分布
        freq_dist = tf.keras.layers.Lambda(
            lambda x: tf.cast(tf.histogram_fixed_width(x, [0, 9], nbins=10), tf.float32)
        )(numbers)
        
        return tf.concat([mean, std, freq_dist], axis=-1)

    def _analyze_frequency(self, x):
        """分析号码频率"""
        # 转换为整数类型
        x = tf.cast(x, tf.int32)
        
        # 计算每个数字的出现频率
        freq = tf.zeros_like(x, dtype=tf.float32)
        for i in range(10):
            mask = tf.cast(x == i, tf.float32)
            freq += mask * tf.reduce_mean(mask, axis=1, keepdims=True)
        
        return freq

    def _analyze_number_patterns(self, x):
        """分析号码形态"""
        # 1. 连号分析
        consecutive = tf.reduce_sum(tf.cast(
            x[:, 1:] == x[:, :-1] + 1, tf.float32
        ), axis=-1, keepdims=True)
        
        # 2. 重复号分析
        unique_counts = tf.reduce_sum(tf.one_hot(tf.cast(x, tf.int32), 10), axis=-2)
        repeats = tf.reduce_sum(tf.cast(unique_counts > 1, tf.float32), axis=-1, keepdims=True)
        
        # 3. 形态识别
        patterns = self._identify_patterns(x)
        
        return tf.concat([consecutive, repeats, patterns], axis=-1)

    def _identify_patterns(self, x):
        """识别号码形态"""
        patterns = []
        
        # 1. 豹子号(AAAAA)
        baozi = tf.reduce_all(x == x[..., :1], axis=-1, keepdims=True)
        patterns.append(tf.cast(baozi, tf.float32))
        
        # 2. 组5(AAAAB)
        sorted_x = tf.sort(x, axis=-1)
        zu5 = tf.logical_and(
            tf.reduce_sum(tf.cast(sorted_x[..., :4] == sorted_x[..., :1], tf.float32), axis=-1) == 4,
            sorted_x[..., 4] != sorted_x[..., 0]
        )
        patterns.append(tf.cast(zu5, tf.float32)[..., tf.newaxis])
        
        # 3. 组10(AAABB)
        zu10 = tf.logical_and(
            tf.reduce_sum(tf.cast(sorted_x[..., :3] == sorted_x[..., :1], tf.float32), axis=-1) == 3,
            tf.reduce_sum(tf.cast(sorted_x[..., 3:] == sorted_x[..., 3:4], tf.float32), axis=-1) == 2
        )
        patterns.append(tf.cast(zu10, tf.float32)[..., tf.newaxis])
        
        # 4. 组20(AAABC)
        zu20 = tf.logical_and(
            tf.reduce_sum(tf.cast(sorted_x[..., :3] == sorted_x[..., :1], tf.float32), axis=-1) == 3,
            sorted_x[..., 3] != sorted_x[..., 4]
        )
        patterns.append(tf.cast(zu20, tf.float32)[..., tf.newaxis])
        
        # 5. 组30(AABBC)
        zu30 = tf.logical_and(
            tf.reduce_sum(tf.cast(sorted_x[..., :2] == sorted_x[..., :1], tf.float32), axis=-1) == 2,
            tf.reduce_sum(tf.cast(sorted_x[..., 2:4] == sorted_x[..., 2:3], tf.float32), axis=-1) == 2
        )
        patterns.append(tf.cast(zu30, tf.float32)[..., tf.newaxis])
        
        # 6. 组60(AABCD)
        zu60 = tf.logical_and(
            tf.reduce_sum(tf.cast(sorted_x[..., :2] == sorted_x[..., :1], tf.float32), axis=-1) == 2,
            tf.reduce_all(sorted_x[..., 2:] != sorted_x[..., 1:4], axis=-1)
        )
        patterns.append(tf.cast(zu60, tf.float32)[..., tf.newaxis])
        
        # 7. 组120(ABCDE)
        zu120 = tf.reduce_all(sorted_x[..., 1:] > sorted_x[..., :-1], axis=-1, keepdims=True)
        patterns.append(tf.cast(zu120, tf.float32))
        
        return tf.concat(patterns, axis=-1)

    def _build_periodic_features(self, x, periods=[60, 120, 360, 720, 1440]):
        """构建周期性特征"""
        features = []
        
        for period in periods:
            # 1. 提取周期模式
            pattern = self._extract_periodic_pattern(x, period)
            
            # 2. 周期性偏差
            deviation = x - pattern
            
            # 3. 周期强度
            strength = tf.reduce_mean(tf.abs(pattern), axis=-1, keepdims=True)
            
            features.extend([pattern, deviation, strength])
            
        return tf.keras.layers.Concatenate()(features)

    def _extract_periodic_pattern(self, x, period):
        """提取周期性模式"""
        # 重塑以匹配周期
        batch_size = tf.shape(x)[0]
        length = tf.shape(x)[1]
        n_periods = length // period
        
        # 重塑为(batch, n_periods, period, features)
        x_reshaped = tf.reshape(x[:, :n_periods*period], 
                               (batch_size, n_periods, period, -1))
        
        # 计算周期内模式
        pattern = tf.reduce_mean(x_reshaped, axis=1)  # 平均周期模式
        variance = tf.math.reduce_variance(x_reshaped, axis=1)  # 周期变异性
        
        return tf.concat([pattern, variance], axis=-1)

    def _build_probability_features(self, x):
        """构建概率分布特征"""
        # 1. 条件概率矩阵
        cond_matrix = tf.zeros((10, 10, 5))  # 每位数字的转移概率
        
        # 2. 计算历史转移概率
        for i in range(5):
            current = tf.cast(x[..., i], tf.int32)
            next_digit = tf.roll(current, shift=-1, axis=0)
            
            # 更新转移矩阵
            for j in range(10):
                for k in range(10):
                    mask_current = tf.cast(current == j, tf.float32)
                    mask_next = tf.cast(next_digit == k, tf.float32)
                    prob = tf.reduce_mean(mask_current * mask_next)
                    cond_matrix = tf.tensor_scatter_nd_update(
                        cond_matrix,
                        [[j, k, i]],
                        [prob]
                    )
        
        return cond_matrix

    def _build_statistical_features(self, x):
        """构建统计特征"""
        # 1. 移动统计
        windows = [60, 360, 720]  # 1小时、6小时、12小时
        stats = []
        
        for window in windows:
            # 移动平均
            ma = tf.keras.layers.AveragePooling1D(
                pool_size=window, strides=1, padding='same')(x)
            # 移动标准差
            std = tf.math.reduce_std(
                tf.stack([x, ma], axis=-1), axis=-1)
            # 移动极差
            pooled_max = tf.keras.layers.MaxPooling1D(
                pool_size=window, strides=1, padding='same')(x)
            pooled_min = -tf.keras.layers.MaxPooling1D(
                pool_size=window, strides=1, padding='same')(-x)
            range_stat = pooled_max - pooled_min
            
            stats.extend([ma, std, range_stat])
        
        # 2. 概率分布特征
        probs = self._build_probability_features(x)
        
        return tf.concat(stats + [probs], axis=-1)

    def _build_trend_features(self, x):
        """构建趋势特征"""
        # 1. 短期趋势
        short_ma = tf.keras.layers.AveragePooling1D(
            pool_size=12, strides=1, padding='same')(x)
        short_trend = tf.sign(x - short_ma)
        
        # 2. 中期趋势
        medium_ma = tf.keras.layers.AveragePooling1D(
            pool_size=60, strides=1, padding='same')(x)
        medium_trend = tf.sign(x - medium_ma)
        
        # 3. 长期趋势
        long_ma = tf.keras.layers.AveragePooling1D(
            pool_size=360, strides=1, padding='same')(x)
        long_trend = tf.sign(x - long_ma)
        
        # 4. 趋势一致性
        trend_consistency = tf.reduce_mean(
            tf.cast(short_trend == medium_trend, tf.float32) * 
            tf.cast(medium_trend == long_trend, tf.float32),
            axis=-1, keepdims=True
        )
        
        return tf.concat([short_trend, medium_trend, long_trend, trend_consistency], axis=-1)

    def _build_correlation_features(self, x):
        """构建相关性特征"""
        # 1. 位置间相关性
        correlations = []
        for i in range(5):
            for j in range(i+1, 5):
                corr = self._compute_correlation(x[..., i], x[..., j])
                correlations.append(corr)
        
        # 2. 滞后相关性
        lag_correlations = []
        for lag in [1, 2, 3, 5, 10]:
            lagged_corr = self._compute_lag_correlation(x, lag)
            lag_correlations.append(lagged_corr)
        
        return tf.concat([*correlations, *lag_correlations], axis=-1)

    def _compute_lag_correlation(self, x, lag):
        """计算滞后相关性"""
        x_current = x[:, lag:]
        x_lagged = x[:, :-lag]
        
        return self._compute_correlation(x_current, x_lagged)

    def _build_pattern_features(self, x):
        """构建形态特征分析"""
        # 1. 当前形态识别
        current_patterns = self._identify_patterns(x)
        
        # 2. 形态遗漏值分析
        pattern_gaps = self._analyze_pattern_gaps(x)
        
        # 3. 形态转换规律
        pattern_transitions = self._analyze_pattern_transitions(x)
        
        # 4. 形态组合特征
        pattern_combinations = self._analyze_pattern_combinations(x)
        
        # 5. 形态周期性分析
        pattern_periodicity = self._analyze_pattern_periodicity(x)
        
        return tf.keras.layers.Concatenate()([
            current_patterns,
            pattern_gaps,
            pattern_transitions,
            pattern_combinations,
            pattern_periodicity
        ])

    def _analyze_pattern_combinations(self, x):
        """分析形态组合特征"""
        # 1. 计算所有可能的形态组合
        patterns = []
        for i in range(5):
            for j in range(i+1, 5):
                pair = tf.stack([x[..., i], x[..., j]], axis=-1)
                patterns.append(self._analyze_digit_pair(pair))
                
        # 2. 组合形态间的关联性
        pattern_corr = tf.stack([
            self._compute_pattern_correlation(p1, p2)
            for i, p1 in enumerate(patterns)
            for j, p2 in enumerate(patterns) if i < j
        ], axis=-1)
        
        return tf.concat([*patterns, pattern_corr], axis=-1)

    def _analyze_pattern_periodicity(self, x):
        """分析形态周期性"""
        periods = [12, 24, 60, 120, 360]
        periodicity = []
        
        for period in periods:
            # 1. 周期性模式提取
            pattern = self._extract_pattern_cycle(x, period)
            
            # 2. 周期强度计算
            strength = self._compute_cycle_strength(pattern)
            
            # 3. 周期稳定性分析
            stability = self._analyze_cycle_stability(pattern)
            
            periodicity.extend([pattern, strength, stability])
            
        return tf.concat(periodicity, axis=-1)

    def _extract_pattern_cycle(self, x, period):
        """提取形态周期模式"""
        # 1. 重塑数据以匹配周期
        batch_size = tf.shape(x)[0]
        n_cycles = tf.shape(x)[1] // period
        x_cycles = tf.reshape(x[:, :n_cycles*period], 
                            [batch_size, n_cycles, period, -1])
        
        # 2. 计算周期内的形态分布
        cycle_patterns = tf.reduce_mean(x_cycles, axis=1)
        
        # 3. 计算周期间的变异性
        cycle_variance = tf.math.reduce_variance(x_cycles, axis=1)
        
        return tf.concat([cycle_patterns, cycle_variance], axis=-1)

    def _compute_cycle_strength(self, pattern):
        """计算周期强度"""
        # 1. 自相关分析
        autocorr = tf.keras.layers.Conv1D(
            filters=1, kernel_size=pattern.shape[1],
            padding='same'
        )(pattern)
        
        # 2. 周期性强度评分
        strength = tf.reduce_mean(tf.abs(autocorr), axis=1, keepdims=True)
        
        return strength

    def _analyze_cycle_stability(self, pattern):
        """分析周期稳定性"""
        # 1. 计算相邻周期的差异
        diffs = pattern[:, 1:] - pattern[:, :-1]
        
        # 2. 计算稳定性指标
        stability = tf.reduce_mean(tf.abs(diffs), axis=1, keepdims=True)
        stability = tf.exp(-stability)  # 转换到0-1范围
        
        return stability

    def _analyze_digit_pair(self, pair):
        """分析数字对特征"""
        # 1. 计算数字对基本特征
        diff = tf.abs(pair[..., 0] - pair[..., 1])
        sum_pair = pair[..., 0] + pair[..., 1]
        prod = pair[..., 0] * pair[..., 1]
        
        # 2. 计算数字对的位置关系
        is_adjacent = tf.cast(diff == 1, tf.float32)
        is_complementary = tf.cast(sum_pair == 9, tf.float32)
        
        # 3. 计算组合特征
        features = tf.stack([
            diff, sum_pair, prod,
            is_adjacent, is_complementary
        ], axis=-1)
        
        return features

    def _compute_pattern_correlation(self, p1, p2):
        """计算形态相关性"""
        # 标准化
        p1_norm = (p1 - tf.reduce_mean(p1)) / (tf.math.reduce_std(p1) + 1e-6)
        p2_norm = (p2 - tf.reduce_mean(p2)) / (tf.math.reduce_std(p2) + 1e-6)
        
        # 计算相关系数
        corr = tf.reduce_mean(p1_norm * p2_norm, axis=-1, keepdims=True)
        
        return corr

    def _analyze_pattern_gaps(self, x):
        """分析形态遗漏值"""
        try:
            # 1. 获取形态
            patterns = self._identify_patterns(x)
            
            # 2. 初始化遗漏值计数器
            gap_counters = tf.zeros_like(patterns)
            
            # 3. 计算每种形态的遗漏值
            def update_gaps(sequence):
                gaps = []
                for i in range(7):  # 7种形态
                    last_pos = -1
                    current_gap = 0
                    pattern_positions = tf.where(sequence[:, i])
                    
                    if tf.size(pattern_positions) > 0:
                        last_pos = tf.reduce_max(pattern_positions)
                        current_gap = tf.shape(sequence)[0] - 1 - last_pos
                    
                    gaps.append(current_gap)
                return tf.stack(gaps)
            
            # 4. 应用遗漏值计算
            gaps = tf.keras.layers.Lambda(update_gaps)(patterns)
            
            # 5. 构建遗漏值特征
            gap_features = []
            
            # 当前遗漏值
            gap_features.append(gaps)
            
            # 历史最大遗漏值
            max_gaps = tf.reduce_max(gaps, axis=0, keepdims=True)
            gap_features.append(max_gaps)
            
            # 历史平均遗漏值
            mean_gaps = tf.reduce_mean(gaps, axis=0, keepdims=True)
            gap_features.append(mean_gaps)
            
            # 遗漏值分布
            gap_dist = tf.keras.layers.Lambda(
                lambda x: tf.cast(tf.histogram_fixed_width(x, [0, 1000], nbins=50), tf.float32)
            )(gaps)
            gap_features.append(gap_dist)
            
            return tf.concat(gap_features, axis=-1)
            
        except Exception as e:
            logger.error(f"分析形态遗漏值时出错: {str(e)}")
            return tf.zeros_like(x)

    def _analyze_pattern_transitions(self, x):
        """分析形态转换规律"""
        try:
            # 1. 识别所有形态
            patterns = self._identify_patterns(x)
            
            # 2. 计算形态转换矩阵
            def get_transition_matrix(sequence):
                matrix = tf.zeros((7, 7))  # 7x7转换矩阵
                for i in range(len(sequence)-1):
                    current = tf.argmax(sequence[i])
                    next_pattern = tf.argmax(sequence[i+1])
                    matrix = tf.tensor_scatter_nd_update(
                        matrix,
                        [[current, next_pattern]],
                        [1.0]
                    )
                return matrix
            
            transition_matrix = tf.keras.layers.Lambda(get_transition_matrix)(patterns)
            
            # 3. 提取转换特征
            transitions = []
            
            # 转换概率
            prob_matrix = transition_matrix / (tf.reduce_sum(transition_matrix, axis=-1, keepdims=True) + 1e-7)
            transitions.append(tf.reshape(prob_matrix, [-1]))
            
            # 最常见转换路径
            common_transitions = tf.reduce_max(prob_matrix, axis=-1)
            transitions.append(common_transitions)
            
            # 形态稳定性(自我转换概率)
            stability = tf.linalg.diag_part(prob_matrix)
            transitions.append(stability)
            
            return tf.concat(transitions, axis=-1)
            
        except Exception as e:
            logger.error(f"分析形态转换规律时出错: {str(e)}")
            return tf.zeros_like(x)

    def _build_basic_features(self, x):
        """构建基础特征"""
        try:
            # 1. 时间特征
            time_features = self._add_positional_encoding(x)
            
            # 2. 多尺度特征
            multi_scale = self._build_multi_scale_features(x)
            
            # 3. 统计特征
            statistical = self._build_statistical_features(x)
            
            # 合并所有基础特征
            x = tf.concat([time_features, multi_scale, statistical], axis=-1)
            return x
            
        except Exception as e:
            logger.error(f"构建基础特征时出错: {str(e)}")
            return x

    def _add_positional_encoding(self, x):
        """添加位置编码"""
        seq_len = tf.shape(x)[1]
        d_model = tf.shape(x)[-1]
        
        position = tf.range(seq_len, dtype=tf.float32)[:, tf.newaxis]
        div_term = tf.exp(tf.range(0, d_model, 2, dtype=tf.float32) * -(tf.math.log(10000.0) / d_model))
        
        pe = tf.zeros((seq_len, d_model))
        pe = tf.tensor_scatter_nd_update(
            pe,
            tf.stack([tf.range(seq_len), tf.range(0, d_model, 2)], axis=1),
            tf.sin(position * div_term)
        )
        pe = tf.tensor_scatter_nd_update(
            pe,
            tf.stack([tf.range(seq_len), tf.range(1, d_model, 2)], axis=1),
            tf.cos(position * div_term)
        )
        
        return x + pe[tf.newaxis, :, :]

    def _build_model_features(self, x):
        """构建模型特征"""
        # 1. 基础特征
        base_features = self._build_basic_features(x)
        
        # 2. 时序特征
        temporal_features = self._build_temporal_features(x)
        
        # 3. 模式特征
        pattern_features = self._build_pattern_features(x)
        
        # 4. 高阶特征组合
        combined_features = self._build_combined_features([
            base_features,
            temporal_features, 
            pattern_features
        ])
        
        return combined_features

    def _build_combined_features(self, feature_list):
        """构建高阶特征组合"""
        try:
            # 1. 特征连接
            x = tf.keras.layers.Concatenate()(feature_list)
            
            # 2. 非线性变换
            x = tf.keras.layers.Dense(256, activation='relu')(x)
            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.Dropout(0.2)(x)
            
            # 3. 特征交互
            x = self._build_feature_interactions(x)
            
            return x
            
        except Exception as e:
            logger.error(f"构建组合特征时出错: {str(e)}")
            return tf.zeros_like(x)

    def _build_feature_interactions(self, x):
        """构建特征交互"""
        try:
            # 1. 自注意力交互
            att = tf.keras.layers.MultiHeadAttention(
                num_heads=4,
                key_dim=32
            )(x, x)
            x = tf.keras.layers.Add()([x, att])
            x = tf.keras.layers.LayerNormalization()(x)
            
            # 2. 非线性特征组合
            x = tf.keras.layers.Dense(128, activation='relu')(x)
            x = tf.keras.layers.Dense(64, activation='relu')(x)
            
            return x
            
        except Exception as e:
            logger.error(f"构建特征交互时出错: {str(e)}")
            return tf.zeros_like(x)

    def _build_temporal_features(self, x):
        """构建时序特征"""
        try:
            # 1. 时间编码
            time_encoding = self._add_temporal_encoding(x)
            
            # 2. 周期特征
            periodic = self._build_periodic_features(x)
            
            # 3. 趋势特征
            trend = self._build_trend_features(x)
            
            # 4. 相关性特征
            correlation = self._build_correlation_features(x)
            
            return tf.concat([time_encoding, periodic, trend, correlation], axis=-1)
            
        except Exception as e:
            logger.error(f"构建时序特征时出错: {str(e)}")
            return x

    def _add_temporal_encoding(self, x):
        """添加时间编码"""
        try:
            seq_len = tf.shape(x)[1]
            d_model = tf.shape(x)[-1]
            
            # 1. 位置编码
            position = tf.range(seq_len, dtype=tf.float32)[:, tf.newaxis]
            div_term = tf.exp(tf.range(0, d_model, 2, dtype=tf.float32) * -(tf.math.log(10000.0) / d_model))
            
            pe = tf.zeros((seq_len, d_model))
            pe = tf.tensor_scatter_nd_update(
                pe,
                tf.stack([tf.range(seq_len), tf.range(0, d_model, 2)], axis=1),
                tf.sin(position * div_term)
            )
            pe = tf.tensor_scatter_nd_update(
                pe,
                tf.stack([tf.range(seq_len), tf.range(1, d_model, 2)], axis=1),
                tf.cos(position * div_term)
            )
            
            # 2. 周期性时间特征
            day_in_week = tf.cast(tf.math.floormod(position, 7), tf.float32) / 7.0
            hour_in_day = tf.cast(tf.math.floormod(position, 24), tf.float32) / 24.0
            minute_in_hour = tf.cast(tf.math.floormod(position, 60), tf.float32) / 60.0
            
            time_features = tf.concat([
                pe,
                day_in_week,
                hour_in_day,
                minute_in_hour
            ], axis=-1)
            
            return time_features[tf.newaxis, :, :]
            
        except Exception as e:
            logger.error(f"添加时间编码时出错: {str(e)}")
            return x

    def evaluate_model_feature_importance(self, model, X):
        """评估模型特征重要性"""
        try:
            # 使用模型的权重评估特征重要性
            feature_weights = model.get_layer('feature_layer').get_weights()[0]
            importance = np.abs(feature_weights).mean(axis=1)
            return importance
            
        except Exception as e:
            logger.error(f"评估特征重要性时出错: {str(e)}")
            return None
    
    def select_top_features(self, importance, top_k=10):
        """选择最重要的特征"""
        try:
            top_indices = np.argsort(importance)[-top_k:]
            return top_indices
            
        except Exception as e:
            logger.error(f"选择重要特征时出错: {str(e)}")
            return None

    def _build_interaction_features(self, features):
        """构建交互特征"""
        interactions = []
        for i in range(len(features)):
            for j in range(i + 1, len(features)):
                interaction = features[i] * features[j]
                interactions.append(interaction)
        return np.array(interactions)

    def _extract_derived_features(self, base_features):
        """提取衍生特征"""
        derived = []
        for feature in base_features:
            derived.extend([
                np.square(feature),
                np.sqrt(np.abs(feature)),
                np.log1p(np.abs(feature))
            ])
        return np.array(derived)

    def _compute_feature_crosses(self, features):
        """计算特征交叉"""
        crosses = []
        for i in range(len(features)):
            for j in range(i + 1, len(features)):
                cross = np.outer(features[i], features[j]).ravel()
                crosses.append(cross)
        return np.array(crosses)

# 创建全局实例
feature_engineering = FeatureEngineering()


In [6]:
#6 Model Core System / 模型核心系统
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import logging
import os
import json
import time
from collections import deque
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, LSTM, Conv1D, MultiHeadAttention, LayerNormalization, Bidirectional, Add
from typing import Dict, Any, Optional
from cell1_core import core_manager
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import math

# 获取logger实例 
logger = logging.getLogger(__name__)

class ModelCore:
    """整合后的模型核心类"""
    
    # 1. 配置管理 (from model_config.py)
    def __init__(self, config_path: Optional[str] = None):
        """初始化模型核心类
        Args:
            config_path: 配置文件路径,如果为None则使用默认配置
        """
        # 从base_model.py继承的属性
        self.sequence_length = 14400
        self.feature_dim = 5
        self.prediction_range = 2880
        self.performance_history = []
        
        # 从model_config.py继承的属性
        self.config_path = config_path
        self.config = self._load_config()
        self.input_shape = core_manager.SYSTEM_CONFIG['SAMPLE_CONFIG']['input_length']
        
        # 添加预测相关的属性
        self.prediction_history = deque(maxlen=1000)  # 预测历史记录
        self.prediction_cache = {}  # 预测结果缓存
        self.cache_timeout = 300   # 缓存超时时间(秒)
        self.confidence_threshold = 0.8  # 置信度阈值

        # 确保在初始化时创建新图
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.default_config = {
                'filters': 64,
                'units': 128,
                'dense_units': 64
            }
            self.models = [self._build_model(self.default_config) for _ in range(6)]
        self.session = tf.compat.v1.Session(graph=self.graph)
        tf.compat.v1.keras.backend.set_session(self.session)
        
        # 初始化优化器
        self.optimizer = self._create_optimizer()
        
        logger.info("模型核心类初始化完成")

    def _load_config(self) -> Dict[str, Any]:
        """从model_config.py继承的配置加载方法"""
        try:
            if self.config_path and os.path.exists(self.config_path):
                with open(self.config_path, 'r', encoding='utf-8') as f:
                    config = json.load(f)
                logger.info(f"从{self.config_path}加载配置")
                return config
            
            # 使用默认配置
            return self._get_default_config()
            
        except Exception as e:
            logger.error(f"加载配置失败: {str(e)}")
            return self._get_default_config()

    def _get_default_config(self) -> Dict[str, Any]:
        """获取默认配置"""
        return {
            'optimizer_config': {
                'learning_rate': 0.001,
                'beta_1': 0.9,
                'beta_2': 0.999
            },
            'architecture_config': {
                'model_1': {'filters': 64, 'units': 128, 'dense_units': 64}
            }
        }

    def get_model_config(self, model_index: int) -> Dict[str, Any]:
        """获取指定模型配置 (from model_config.py)"""
        try:
            return self.config['architecture_config'][f'model_{model_index}']
        except KeyError:
            logger.error(f"未找到模型{model_index}的配置")
            return {}

    def validate_config(self) -> bool:
        """验证配置有效性 (from model_config.py)"""
        try:
            # 验证基础配置
            base_config = self.config['base_config']
            assert base_config['sequence_length'] > 0
            assert base_config['feature_dim'] > 0
            assert base_config['batch_size'] > 0
            
            # 验证优化器配置
            optimizer_config = self.config['optimizer_config']
            assert optimizer_config['learning_rate'] > 0
            assert 0 < optimizer_config['beta_1'] < 1
            assert 0 < optimizer_config['beta_2'] < 1
            
            # 验证模型架构配置
            for model_name, model_config in self.config['architecture_config'].items():
                assert all(value > 0 for value in model_config.values() if isinstance(value, (int, float)))
            
            return True
        except Exception as e:
            logger.error(f"配置验证失败: {str(e)}")
            return False

    def update_config(self, new_config: Dict[str, Any]) -> None:
        """更新配置 (from model_config.py)"""
        try:
            self.config.update(new_config)
            self._save_config()
            logger.info("配置更新成功")
        except Exception as e:
            logger.error(f"更新配置失败: {str(e)}")

    def _save_config(self) -> None:
        """保存配置到文件 (from model_config.py)"""
        try:
            if self.config_path:
                with open(self.config_path, 'w', encoding='utf-8') as f:
                    json.dump(self.config, f, indent=4)
                logger.info(f"配置已保存到{self.config_path}")
        except Exception as e:
            logger.error(f"保存配置失败: {str(e)}")

    def get_optimizer_config(self) -> Dict[str, Any]:
        """获取优化器配置"""
        return self.config['optimizer_config']
    
    def get_loss_config(self) -> Dict[str, Any]:
        """获取损失函数配置"""
        return self.config['loss_config']
    
    def get_training_config(self) -> Dict[str, Any]:
        """获取训练配置"""
        return self.config['training_config']
    
    def get_ensemble_config(self) -> Dict[str, Any]:
        """获取集成配置"""
        return self.config['ensemble_config']
    
    def get_monitor_config(self) -> Dict[str, Any]:
        """获取监控配置"""
        return self.config['monitor_config']

    def _validate_model_architecture(self, architecture: Dict[str, Any]) -> bool:
        """验证模型架构配置"""
        try:
            # 检查必需的层配置
            required_layers = ['lstm_units', 'attention_heads', 'dense_units']
            for layer in required_layers:
                assert any(layer in model for model in architecture.values())
            
            # 检查层参数范围
            for model_config in architecture.values():
                assert 32 <= model_config.get('lstm_units', 64) <= 512
                assert 2 <= model_config.get('attention_heads', 4) <= 16
                assert 16 <= model_config.get('dense_units', 32) <= 256
                
            return True
        except Exception as e:
            logger.error(f"模型架构验证失败: {str(e)}")
            return False

    def _validate_training_strategy(self, strategy: Dict[str, Any]) -> bool:
        """验证训练策略配置"""
        try:
            assert 0 < strategy['learning_rate'] < 1
            assert 16 <= strategy['batch_size'] <= 512
            assert 0 < strategy['dropout_rate'] < 1
            assert strategy['early_stopping_patience'] > 0
            return True
        except Exception as e:
            logger.error(f"训练策略验证失败: {str(e)}")
            return False

    def _validate_preprocessing_config(self) -> bool:
        """验证预处理配置"""
        try:
            preprocess_cfg = self.config['preprocessing_config']
            assert 'sequence_length' in preprocess_cfg
            assert 'sliding_window' in preprocess_cfg
            assert 'normalization' in preprocess_cfg
            return True
        except Exception as e:
            logger.error(f"预处理配置验证失败: {str(e)}")
            return False

    def _validate_ensemble_config(self) -> bool:
        """验证集成配置"""
        try:
            ensemble_cfg = self.config['ensemble_config']
            assert 'voting_method' in ensemble_cfg
            assert ensemble_cfg['voting_method'] in ['majority', 'weighted', 'average']
            assert 0 < ensemble_cfg['min_weight'] < ensemble_cfg['max_weight'] < 1
            return True
        except Exception as e:
            logger.error(f"集成配置验证失败: {str(e)}")
            return False

    # 2. 特征工程 (from base_model.py)
    def _build_basic_features(self, x):
        """构建基础特征 (from base_model.py)"""
        try:
            # 1. 时间特征
            time_features = self._add_positional_encoding(x)
            
            # 2. 多尺度特征
            multi_scale = self._build_multi_scale_features(x)
            
            # 3. 统计特征
            statistical = self._build_statistical_features(x)
            
            # 合并所有基础特征
            x = tf.concat([time_features, multi_scale, statistical], axis=-1)
            return x
            
        except Exception as e:
            logger.error(f"构建基础特征时出错: {str(e)}")
            return x

    def _add_positional_encoding(self, x):
        """时序位置编码 (修正序列长度计算)"""
        seq_len = tf.shape(x)[1]
        d_model = tf.shape(x)[-1]
        
        position = tf.range(seq_len, dtype=tf.float32)[:, tf.newaxis]
        
        # 修正：先计算常量部分
        log_value = tf.cast(-math.log(10000.0), tf.float32)
        d_model_float = tf.cast(d_model, tf.float32)
        angle_factor = log_value / d_model_float
        
        # 修正：计算正确的序列长度
        d_model_half = tf.cast(tf.math.floor(d_model_float / 2), tf.int32)
        total_length = seq_len * d_model_half
        
        # 生成角度基数
        angle_rads = tf.range(0, d_model_float, 2.0) * angle_factor
        div_term = tf.exp(angle_rads)
        
        pe = tf.zeros((seq_len, d_model))
        
        # 正弦部分 - 确保长度匹配
        indices_sin = tf.stack([
            tf.repeat(tf.range(seq_len), d_model_half),
            tf.tile(tf.range(0, d_model, 2)[:d_model_half], [seq_len])
        ], axis=1)
        
        updates_sin = tf.sin(tf.reshape(position, [-1, 1]) * div_term[:d_model_half])
        pe = tf.tensor_scatter_nd_update(pe, indices_sin, tf.reshape(updates_sin, [-1]))
        
        # 余弦部分 - 确保长度匹配
        indices_cos = tf.stack([
            tf.repeat(tf.range(seq_len), d_model_half),
            tf.tile(tf.range(1, d_model, 2)[:d_model_half], [seq_len])
        ], axis=1)
        
        updates_cos = tf.cos(tf.reshape(position, [-1, 1]) * div_term[:d_model_half])
        pe = tf.tensor_scatter_nd_update(pe, indices_cos, tf.reshape(updates_cos, [-1]))
        
        return x + pe[tf.newaxis, :, :]

    def _build_multi_scale_features(self, x):
        """多尺度特征提取 (from base_model.py)"""
        try:
            conv1 = Conv1D(32, kernel_size=3, padding='same')(x)
            conv2 = Conv1D(32, kernel_size=5, padding='same')(x)
            conv3 = Conv1D(32, kernel_size=7, padding='same')(x)
            dconv1 = Conv1D(32, kernel_size=3, dilation_rate=2, padding='same')(x)
            dconv2 = Conv1D(32, kernel_size=5, dilation_rate=2, padding='same')(x)
            return tf.keras.layers.Concatenate()([conv1, conv2, conv3, dconv1, dconv2])
        except Exception as e:
            logger.error(f"构建多尺度特征时出错: {str(e)}")
            return x

    def _build_statistical_features(self, x):
        """统计特征提取 (from base_model.py)"""
        try:
            mean = tf.reduce_mean(x, axis=1, keepdims=True)
            std = tf.math.reduce_std(x, axis=1, keepdims=True)
            kurtosis = tf.reduce_mean(tf.pow(x - mean, 4), axis=1, keepdims=True) / tf.pow(std, 4)
            skewness = tf.reduce_mean(tf.pow(x - mean, 3), axis=1, keepdims=True) / tf.pow(std, 3)
            return tf.keras.layers.Concatenate()([mean, std, kurtosis, skewness])
        except Exception as e:
            logger.error(f"构建统计特征时出错: {str(e)}")
            return x

    def _build_temporal_features(self, x):
        """构建时序特征"""
        try:
            # 1. 时间编码
            time_encoding = self._add_temporal_encoding(x)
            
            # 2. 周期特征
            periodic = self._build_periodic_features(x)
            
            # 3. 趋势特征
            trend = self._build_trend_features(x)
            
            # 4. 相关性特征
            correlation = self._build_correlation_features(x)
            
            return tf.concat([time_encoding, periodic, trend, correlation], axis=-1)
            
        except Exception as e:
            logger.error(f"构建时序特征时出错: {str(e)}")
            return x

    def _build_periodic_features(self, x, periods=[60, 120, 360, 720, 1440]):
        """构建周期性特征"""
        features = []
        for period in periods:
            pattern = self._extract_periodic_pattern(x, period)
            deviation = x - pattern
            strength = tf.reduce_mean(tf.abs(pattern), axis=-1, keepdims=True)
            features.extend([pattern, deviation, strength])
        return tf.keras.layers.Concatenate()(features)

    def _build_correlation_features(self, x):
        """构建相关性特征"""
        # 1. 位置间相关性
        correlations = []
        for i in range(5):
            for j in range(i+1, 5):
                corr = self._compute_correlation(x[..., i], x[..., j])
                correlations.append(corr)
        
        # 2. 滞后相关性
        lag_correlations = []
        for lag in [1, 2, 3, 5, 10]:
            lagged_corr = self._compute_lag_correlation(x, lag)
            lag_correlations.append(lagged_corr)
        
        return tf.concat([*correlations, *lag_correlations], axis=-1)

    def _compute_correlation(self, x1, x2):
        """计算相关系数"""
        x1_norm = (x1 - tf.reduce_mean(x1)) / tf.math.reduce_std(x1)
        x2_norm = (x2 - tf.reduce_mean(x2)) / tf.math.reduce_std(x2)
        return tf.reduce_mean(x1_norm * x2_norm, axis=-1, keepdims=True)

    def _compute_lag_correlation(self, x, lag):
        """计算滞后相关性"""
        x_current = x[:, lag:]
        x_lagged = x[:, :-lag]
        return self._compute_correlation(x_current, x_lagged)

    def _extract_periodic_pattern(self, x, period):
        """提取周期性模式"""
        try:
            batch_size = tf.shape(x)[0]
            length = tf.shape(x)[1]
            n_periods = length // period
            
            x_reshaped = tf.reshape(x[:, :n_periods*period], 
                               (batch_size, n_periods, period, -1))
            pattern = tf.reduce_mean(x_reshaped, axis=1)
            return pattern
        except Exception as e:
            logger.error(f"提取周期模式时出错: {str(e)}")
            return x

    def _add_temporal_encoding(self, x):
        """添加时间编码"""
        try:
            seq_len = tf.shape(x)[1]
            d_model = tf.shape(x)[-1]
            
            # 1. 位置编码
            position = tf.range(seq_len, dtype=tf.float32)[:, tf.newaxis]
            div_term = tf.exp(
                tf.range(0, d_model, 2, dtype=tf.float32) * 
                (-math.log(10000.0) / d_model)
            )
            
            pe = tf.zeros((seq_len, d_model))
            pe = tf.tensor_scatter_nd_update(
                pe,
                tf.stack([tf.range(seq_len), tf.range(0, d_model, 2)], axis=1),
                tf.sin(position * div_term)
            )
            pe = tf.tensor_scatter_nd_update(
                pe,
                tf.stack([tf.range(seq_len), tf.range(1, d_model, 2)], axis=1),
                tf.cos(position * div_term)
            )
            
            # 2. 添加时间周期特征
            day_in_week = tf.cast(tf.math.floormod(position, 7), tf.float32) / 7.0
            hour_in_day = tf.cast(tf.math.floormod(position, 24), tf.float32) / 24.0
            minute_in_hour = tf.cast(tf.math.floormod(position, 60), tf.float32) / 60.0
            
            time_features = tf.concat([
                pe,
                day_in_week,
                hour_in_day,
                minute_in_hour
            ], axis=-1)
            
            return time_features[tf.newaxis, :, :]
            
        except Exception as e:
            logger.error(f"添加时间编码时出错: {str(e)}")
            return x

    def _adjust_sequence_length(self, x):
        """调整序列长度"""
        current_length = x.shape[1]
        if current_length > self.prediction_range:
            x = tf.keras.layers.AveragePooling1D(
                pool_size=current_length // self.prediction_range)(x)
        elif current_length < self.prediction_range:
            x = tf.keras.layers.UpSampling1D(
                size=self.prediction_range // current_length)(x)
        return x

    def _build_pattern_features(self, x):
        """构建形态特征"""
        try:
            # 1. 连号分析
            consecutive = tf.reduce_sum(tf.cast(
                x[:, 1:] == x[:, :-1] + 1, tf.float32
            ), axis=-1, keepdims=True)
            
            # 2. 重复号分析
            unique_counts = tf.reduce_sum(
                tf.one_hot(tf.cast(x, tf.int32), 10),
                axis=-2
            )
            repeats = tf.reduce_sum(
                tf.cast(unique_counts > 1, tf.float32),
                axis=-1, keepdims=True
            )
            
            # 3. 号码分布特征
            distribution = self._analyze_number_distribution(x)
            
            # 4. 形态组合特征
            combinations = self._analyze_pattern_combinations(x)
            
            # 5. 形态周期规律
            periodicity = self._analyze_pattern_periodicity(x)
            
            return tf.concat([
                consecutive, 
                repeats, 
                distribution,
                combinations,
                periodicity
            ], axis=-1)
            
        except Exception as e:
            logger.error(f"构建形态特征时出错: {str(e)}")
            return x
    
    def _analyze_number_distribution(self, x):
        """分析号码分布特征"""
        try:
            # 1. 大小比例
            big_nums = tf.reduce_mean(tf.cast(x >= 5, tf.float32), axis=-1, keepdims=True)
            
            # 2. 奇偶比例
            odd_nums = tf.reduce_mean(tf.cast(x % 2 == 1, tf.float32), axis=-1, keepdims=True)
            
            # 3. 012路数分析
            mod_3 = tf.cast(x % 3, tf.float32)
            route_0 = tf.reduce_mean(tf.cast(mod_3 == 0, tf.float32), axis=-1, keepdims=True)
            route_1 = tf.reduce_mean(tf.cast(mod_3 == 1, tf.float32), axis=-1, keepdims=True)
            route_2 = tf.reduce_mean(tf.cast(mod_3 == 2, tf.float32), axis=-1, keepdims=True)
            
            # 4. 和值分析
            sum_value = tf.reduce_sum(x, axis=-1, keepdims=True)
            
            return tf.concat([
                big_nums, odd_nums,
                route_0, route_1, route_2,
                sum_value
            ], axis=-1)
        except Exception as e:
            logger.error(f"分析号码分布特征时出错: {str(e)}")
            return tf.zeros_like(x[..., :1])

    # 3. 模型构建 (from model_builder.py)
    def build_model(self, model_num=None, params=None):
        """整合后的模型构建方法"""
        try:
            inputs = tf.keras.Input(shape=(self.sequence_length, self.feature_dim))
            
            # 1. 基础特征提取 (from base_model.py)
            x = self._build_basic_features(inputs)
            
            # 2. 根据模型编号选择不同的特征提取方法 (from model_builder.py)
            if model_num is not None:
                x = self._build_model_specific_features(x, model_num, params or {})
            
            # 3. 预测输出头 (from base_model.py)
            outputs = self._build_prediction_head(x)
            
            model = Model(inputs=inputs, outputs=outputs)
            self.compile_model(model)
            return model
        except Exception as e:
            logger.error(f"构建模型时出错: {str(e)}")
            raise

    def _build_model_specific_features(self, x, model_num, params):
        """不同模型架构的特征构建 (from model_builder.py)"""
        if model_num == 1:
            return self._build_lstm_gru_attention(x, params)
        elif model_num == 2:
            return self._build_bilstm_residual(x, params)
        elif model_num == 3:
            return self._build_temporal_conv_lstm(x, params)
        elif model_num == 4:
            return self._build_transformer(x, params)
        elif model_num == 5:
            return self._build_gru_attention_skip(x, params)
        elif model_num == 6:
            return self._build_digit_correlation_model(x, params)
        elif model_num == 7:
            return self._build_probability_model(x, params)
        else:
            return self._build_lstm_cnn(x, params)

    def _build_lstm_gru_attention(self, x, params):
        """LSTM+GRU+注意力模型 (from model_builder.py)"""
        x = LSTM(params.get('lstm_units', 128), return_sequences=True)(x)
        x = GRU(params.get('gru_units', 64), return_sequences=True)(x)
        x = MultiHeadAttention(
            num_heads=params.get('attention_heads', 4),
            key_dim=params.get('key_dim', 16),
            value_dim=params.get('value_dim', 16)
        )(x, x)
        return x

    def _build_bilstm_residual(self, x, params):
        """BiLSTM残差网络 (from model_builder.py)"""
        try:
            main = Bidirectional(LSTM(params['lstm_units'], return_sequences=True))(x)
            residual = Conv1D(params['lstm_units']*2, kernel_size=1)(x)
            x = Add()([main, residual])
            return LayerNormalization()(x)
        except Exception as e:
            logger.error(f"构建BiLSTM残差网络时出错: {str(e)}")
            return x

    def _build_temporal_conv_lstm(self, x, params):
        """时空卷积LSTM (from model_builder.py)"""
        try:
            x = Conv1D(64, kernel_size=3, padding='same')(x)
            x = Conv1D(64, kernel_size=3, padding='causal', dilation_rate=2)(x)
            x = tf.keras.layers.PReLU()(x)
            x = LSTM(64, return_sequences=True)(x)
            return x
        except Exception as e:
            logger.error(f"构建时空卷积LSTM时出错: {str(e)}")
            return x

    def _build_transformer(self, x, params):
        """构建Transformer模型 (from model_builder.py)"""
        try:
            x = MultiHeadAttention(
                num_heads=params['num_heads'],
                key_dim=params['key_dim']
            )(x, x)
            return LayerNormalization()(x)
        except Exception as e:
            logger.error(f"构建Transformer时出错: {str(e)}")
            return x

    def _build_gru_attention_skip(self, x, params):
        """GRU + 自注意力 + 跳跃连接 (from model_builder.py)"""
        try:
            gru_out = GRU(params['gru_units'], return_sequences=True)(x)
            att = MultiHeadAttention(num_heads=4, key_dim=16)(gru_out, gru_out)
            x = Add()([att, x])
            return LayerNormalization()(x)
        except Exception as e:
            logger.error(f"构建GRU注意力网络时出错: {str(e)}")
            return x

    def _build_lstm_cnn(self, x, params):
        """LSTM + CNN模型 (from model_builder.py)"""
        try:
            x = LSTM(params['lstm_units'], return_sequences=True)(x)
            x = Conv1D(filters=16, kernel_size=3, padding='same')(x)
            return tf.keras.layers.BatchNormalization()(x)
        except Exception as e:
            logger.error(f"构建LSTM-CNN时出错: {str(e)}")
            return x

    def _build_digit_correlation_model(self, x, params):
        """数字关联分析模型"""
        try:
            # 1. 相邻数字关系
            adjacent_patterns = Conv1D(64, kernel_size=2, strides=1, padding='same')(x)
            
            # 2. 数字组合模式
            combination_patterns = []
            for window in [3, 5, 7]:
                pattern = Conv1D(32, kernel_size=window, padding='same')(x)
                combination_patterns.append(pattern)
            
            # 3. 合并所有模式
            x = tf.keras.layers.Concatenate()([adjacent_patterns, *combination_patterns])
            x = LSTM(128, return_sequences=True)(x)
            return x
        except Exception as e:
            logger.error(f"构建数字关联分析模型时出错: {str(e)}")
            return x

    def _build_probability_model(self, x, params):
        """概率分布学习模型"""
        try:
            # 1. 历史概率分布
            hist_probs = self._build_historical_probabilities(x)
            
            # 2. 条件概率特征
            cond_probs = self._build_conditional_probabilities(x)
            
            # 3. 组合概率模型
            x = tf.keras.layers.Concatenate()([hist_probs, cond_probs])
            x = Dense(256, activation='relu')(x)
            x = Dense(128, activation='relu')(x)
            return x
        except Exception as e:
            logger.error(f"构建概率模型时出错: {str(e)}")
            return x

    def _build_historical_probabilities(self, x):
        """构建历史概率分布特征"""
        try:
            # 1. 计算历史频率分布
            freqs = tf.zeros((10, 5))  # 10个数字在5个位置的频率
            for i in range(5):
                digit_freqs = tf.keras.layers.Lambda(
                    lambda x: tf.cast(
                        tf.histogram_fixed_width(x[..., i], [0, 9], nbins=10),
                        tf.float32
                    )
                )(x)
                freqs = tf.tensor_scatter_nd_update(
                    freqs,
                    [[j, i] for j in range(10)],
                    digit_freqs
                )
            
            # 2. 计算条件概率
            cond_probs = self._calculate_conditional_probs(x)
            
            return tf.concat([freqs, cond_probs], axis=-1)
        except Exception as e:
            logger.error(f"构建历史概率分布特征时出错: {str(e)}")
            return x

    def _build_conditional_probabilities(self, x):
        """构建条件概率特征"""
        try:
            # 1. 计算相邻位置条件概率
            adjacent_probs = []
            for i in range(4):
                curr = tf.cast(x[..., i], tf.int32)
                next_digit = tf.cast(x[..., i+1], tf.int32)
                probs = self._compute_transition_probs(curr, next_digit)
                adjacent_probs.append(probs)
            
            # 2. 计算跳跃位置条件概率
            skip_probs = []
            for i in range(3):
                curr = tf.cast(x[..., i], tf.int32)
                next_digit = tf.cast(x[..., i+2], tf.int32)
                probs = self._compute_transition_probs(curr, next_digit)
                skip_probs.append(probs)
            
            return tf.concat([*adjacent_probs, *skip_probs], axis=-1)
        except Exception as e:
            logger.error(f"构建条件概率特征时出错: {str(e)}")
            return x

    def _compute_transition_probs(self, curr_digits, next_digits):
        """计算转移概率矩阵"""
        try:
            # 创建10x10的转移矩阵
            transition_matrix = tf.zeros((10, 10))
            
            # 统计转移次数
            for i in range(10):
                for j in range(10):
                    mask_curr = tf.cast(curr_digits == i, tf.float32)
                    mask_next = tf.cast(next_digits == j, tf.float32)
                    count = tf.reduce_sum(mask_curr * mask_next)
                    transition_matrix = tf.tensor_scatter_nd_update(
                        transition_matrix,
                        [[i, j]],
                        [count]
                    )
            
            # 计算概率
            row_sums = tf.reduce_sum(transition_matrix, axis=1, keepdims=True)
            probs = transition_matrix / (row_sums + 1e-7)
            
            return probs
        except Exception as e:
            logger.error(f"计算转移概率时出错: {str(e)}")
            return tf.zeros((10, 10))

    def _build_combined_features(self, feature_list):
        """构建组合特征 (from model_builder.py)"""
        try:
            # 1. 特征连接
            x = tf.keras.layers.Concatenate()(feature_list)
            
            # 2. 非线性变换
            x = Dense(256, activation='relu')(x)
            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.Dropout(0.2)(x)
            
            # 3. 特征交互
            x = self._build_feature_interactions(x)
            
            return x
            
        except Exception as e:
            logger.error(f"构建组合特征时出错: {str(e)}")
            return tf.zeros_like(x)

    def _build_feature_interactions(self, x):
        """构建特征交互 (from model_builder.py)"""
        try:
            # 1. 自注意力交互
            att = MultiHeadAttention(
                num_heads=4,
                key_dim=32
            )(x, x)
            x = Add()([x, att])
            x = LayerNormalization()(x)
            
            # 2. 非线性特征组合
            x = Dense(128, activation='relu')(x)
            x = Dense(64, activation='relu')(x)
            
            return x
            
        except Exception as e:
            logger.error(f"构建特征交互时出错: {str(e)}")
            return tf.zeros_like(x)

    def _build_ensemble_predictor(self, x, params):
        """集成预测器"""
        # 1. 多模型特征
        features = []
        
        # 组合模式特征
        comb_features = self._build_combination_predictor(x, params)
        # 概率特征
        prob_features = self._build_probability_predictor(x, params)
        # 周期特征  
        period_features = self._build_periodic_predictor(x, params)
        # 趋势特征
        trend_features = self._build_trend_predictor(x, params)
        # 统计特征
        stat_features = self._build_statistical_predictor(x, params)
        
        features.extend([
            comb_features, prob_features, period_features,
            trend_features, stat_features
        ])
        
        # 2. 特征融合
        x = tf.keras.layers.Concatenate()(features)
        
        return x

    def _build_attention_residual_block(self, x, params):
        """注意力残差块"""
        # 多头自注意力
        att = MultiHeadAttention(
            num_heads=params.get('num_heads', 4),
            key_dim=params.get('key_dim', 32)
        )(x, x)
        
        # 残差连接
        x = Add()([x, att])
        x = LayerNormalization()(x)
        
        # FFN
        ffn = Dense(params.get('ffn_dim', 256), activation='relu')(x)
        ffn = Dense(x.shape[-1])(ffn)
        
        # 残差连接
        x = Add()([x, ffn])
        x = LayerNormalization()(x)
        
        return x

    def _build_combination_predictor(self, x, params):
        """组合数字预测器"""
        try:
            # 1. 局部组合模式
            local_patterns = []
            for window in [2, 3, 4]:
                pattern = Conv1D(32, kernel_size=window, padding='same')(x)
                local_patterns.append(pattern)
            
            # 2. 全局组合模式
            global_pattern = self._build_attention_residual_block(
                tf.concat(local_patterns, axis=-1),
                params
            )
            
            return global_pattern
        except Exception as e:
            logger.error(f"构建组合预测器出错: {str(e)}")
            return x

    def _build_trend_predictor(self, x, params):
        """趋势预测器"""
        try:
            # 1. 多尺度趋势分析 
            trends = []
            windows = [60, 360, 720, 1440]
            
            for window in windows:
                ma = tf.keras.layers.AveragePooling1D(
                    pool_size=window, strides=1, padding='same')(x)
                trend = tf.sign(x - ma)
                trends.append(trend)
            
            # 2. 趋势特征融合
            x = tf.keras.layers.Concatenate()(trends)
            x = Bidirectional(LSTM(64, return_sequences=True))(x)
            return x
        except Exception as e:
            logger.error(f"构建趋势预测器出错: {str(e)}")
            return x

    def _build_statistical_predictor(self, x, params):
        """统计模式预测器"""
        try:
            stats = []
            
            # 均值特征
            mean = tf.reduce_mean(x, axis=1, keepdims=True)
            # 标准差特征
            std = tf.math.reduce_std(x, axis=1, keepdims=True)
            # 峰度
            kurtosis = tf.reduce_mean(tf.pow(x - mean, 4), axis=1, keepdims=True) / tf.pow(std, 4)
            # 偏度
            skewness = tf.reduce_mean(tf.pow(x - mean, 3), axis=1, keepdims=True) / tf.pow(std, 3)
            
            stats.extend([mean, std, kurtosis, skewness])
            
            x = tf.keras.layers.Concatenate()(stats)
            x = Dense(128, activation='relu')(x)
            return x
        except Exception as e:
            logger.error(f"构建统计预测器出错: {str(e)}")
            return x

    # 4. 训练评估 (from base_model.py)
    def _build_prediction_head(self, x):
        """预测输出头 (from base_model.py)"""
        # 1. 每位数字的概率分布
        batch_size = tf.shape(x)[0]
        digit_predictions = []
        for i in range(5):
            digit_pred = Dense(10, activation='softmax', name=f'digit_{i}')(x)
            digit_predictions.append(digit_pred)

        # 2. 完整号码匹配概率
        match_prob = Dense(self.prediction_range, activation='sigmoid', name='match_prob')(x)

        # 3. 预测置信度
        confidence = Dense(1, activation='sigmoid', name='confidence')(x)
        
        return {
            'digits': digit_predictions,
            'match_prob': match_prob,
            'confidence': confidence
        }

    def enhanced_match_loss(self, y_true, y_pred):
        """增强型匹配损失函数 (from base_model.py)"""
        try:
            # 1. 预处理
            y_pred_expanded = tf.expand_dims(y_pred, axis=1)
            y_pred_rounded = tf.round(y_pred_expanded)
            
            # 2. 计算匹配情况
            matches = tf.cast(tf.equal(y_true, y_pred_rounded), tf.float32)
            match_counts = tf.reduce_sum(matches, axis=-1)
            best_match_indices = tf.argmax(match_counts, axis=1)
            best_targets = tf.gather(y_true, best_match_indices, batch_dims=1)
            best_match_counts = tf.reduce_max(match_counts, axis=1)
            
            # 3. 计算基础匹配损失
            base_loss = tf.reduce_mean(tf.abs(y_pred - best_targets), axis=1)
            
            # 4. 计算方向性损失
            direction_loss = self._calculate_direction_loss(y_pred, best_targets)
            
            # 5. 完全匹配时损失为0
            perfect_match = tf.cast(tf.equal(best_match_counts, 5.0), tf.float32)
            
            # 6. 组合损失(动态权重)
            direction_weight = tf.exp(-best_match_counts / 5.0) * 0.5
            total_loss = base_loss * (1.0 - perfect_match) + direction_weight * direction_loss
            
            return total_loss
        except Exception as e:
            logger.error(f"计算损失时出错: {str(e)}")
            return 5.0 * tf.ones_like(y_pred[:, 0])

    def compile_model(self, model):
        """编译模型 (from base_model.py)"""
        model.compile(
            optimizer=self.optimizer,
            loss=self.enhanced_match_loss,
            metrics=['accuracy']
        )

    def train_step(self, batch_data):
        """训练步骤 (from base_model.py)"""
        try:
            with tf.GradientTape() as tape:
                predictions = self.model(batch_data['input'], training=True)
                loss = self.enhanced_match_loss(batch_data['target'], predictions)
            gradients = tape.gradient(loss, self.model.trainable_variables)
            self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
            return loss
        except Exception as e:
            logger.error(f"训练步骤执行出错: {str(e)}")
            return None

    def validate_step(self, batch_data):
        """验证步骤 (from base_model.py)"""
        try:
            predictions = self.model(batch_data['input'], training=False)
            loss = self.enhanced_match_loss(batch_data['target'], predictions)
            matches = self._calculate_matches(predictions, batch_data['target'])
            return {
                'loss': loss,
                'accuracy': tf.reduce_mean(matches)
            }
        except Exception as e:
            logger.error(f"验证步骤执行出错: {str(e)}")
            return None

    def _calculate_matches(self, predictions, targets):
        """计算匹配程度 (from base_model.py)"""
        try:
            rounded_preds = tf.round(predictions)
            matches = tf.cast(tf.equal(rounded_preds, targets), tf.float32)
            full_matches = tf.reduce_all(matches, axis=-1)
            return full_matches
        except Exception as e:
            logger.error(f"计算匹配程度时出错: {str(e)}")
            return tf.zeros_like(predictions[..., 0])

    def _calculate_direction_loss(self, y_pred, best_targets):
        """计算方向性损失 (from base_model.py)"""
        try:
            value_diff = best_targets - y_pred
            direction_mask = tf.cast(
                tf.not_equal(tf.round(y_pred), best_targets),
                tf.float32
            )
            direction_factor = tf.sigmoid(value_diff * 2.0) * 2.0 - 1.0
            return tf.reduce_mean(
                direction_mask * direction_factor * tf.abs(value_diff),
                axis=-1
            )
        except Exception as e:
            logger.error(f"计算方向性损失时出错: {str(e)}")
            return tf.zeros_like(y_pred[:, 0])

    # 5. 模型保存加载 (from base_model.py)
    def save_model(self, path: str):
        """保存模型 (from base_model.py)"""
        try:
            self.model.save(path)
            logger.info(f"模型已保存到: {path}")
        except Exception as e:
            logger.error(f"保存模型时出错: {str(e)}")

    def load_model(self, path: str):
        """加载模型 (from base_model.py)"""
        try:
            self.model = tf.keras.models.load_model(
                path,
                custom_objects={'enhanced_match_loss': self.enhanced_match_loss}
            )
            logger.info(f"已加载模型: {path}")
        except Exception as e:
            logger.error(f"加载模型时出错: {str(e)}")

    def predict(self, input_data):
        """执行预测"""
        try:
            # 添加输入验证
            if input_data.shape[-1] != 8:  # 5个号码+3个时间特征
                logger.error(f"输入特征维度错误，预期8维，实际收到{input_data.shape[-1]}维")
                return None
            
            # 获取集成预测结果
            predictions = []
            confidences = []
            
            for i, model in enumerate(self.models):
                pred = model.predict(input_data)
                pred_value = pred['digits']
                confidence = pred['confidence']
                
                predictions.append(pred_value * self.weights[i])
                confidences.append(confidence)
            
            # 集成预测结果
            ensemble_pred = np.sum(predictions, axis=0)
            mean_confidence = np.mean(confidences)
            
            return {
                'prediction': ensemble_pred,
                'confidence': mean_confidence
            }
            
        except Exception as e:
            logger.error(f"预测失败: {str(e)}")
            return None

    def predict_with_cache(self, X):
        """带缓存的预测"""
        cache_key = hash(str(X))
        current_time = time.time()
        
        # 检查缓存
        if cache_key in self.prediction_cache:
            cached_result, cache_time = self.prediction_cache[cache_key]
            if current_time - cache_time < self.cache_timeout:
                return cached_result
        
        # 执行预测
        result = self.predict(X)
        
        # 更新缓存
        self.prediction_cache[cache_key] = (result, current_time)
        
        return result

    def get_prediction_history(self, start_time=None, end_time=None):
        """获取预测历史"""
        history = list(self.prediction_history)
        
        if start_time:
            history = [h for h in history if h['timestamp'] >= start_time]
        if end_time:
            history = [h for h in history if h['timestamp'] <= end_time]
            
        return history

    def analyze_prediction_accuracy(self):
        """分析预测准确率"""
        if not self.prediction_history:
            return None
            
        correct = 0
        total = 0
        
        for pred in self.prediction_history:
            if pred.get('actual') is not None:
                correct += int(np.array_equal(
                    pred['prediction'],
                    pred['actual']
                ))
                total += 1
        
        accuracy = correct / total if total > 0 else 0
        
        return {
            'accuracy': accuracy,
            'total_predictions': total,
            'correct_predictions': correct
        }

    def _record_prediction(self, prediction, confidence):
        """记录预测结果"""
        try:
            self.prediction_history.append({
                'prediction': prediction,
                'confidence': confidence,
                'timestamp': datetime.now()
            })
            logger.info(f"记录预测结果: {prediction}, 置信度: {confidence:.2f}")
        except Exception as e:
            logger.error(f"记录预测结果失败: {str(e)}")

    def _clean_prediction_cache(self):
        """清理过期的预测缓存"""
        try:
            current_time = time.time()
            expired_keys = [
                k for k, (_, cache_time) in self.prediction_cache.items()
                if current_time - cache_time > self.cache_timeout
            ]
            for k in expired_keys:
                del self.prediction_cache[k]
            if expired_keys:
                logger.info(f"清理了{len(expired_keys)}条过期预测缓存")
        except Exception as e:
            logger.error(f"清理预测缓存失败: {str(e)}")

    def get_prediction_stats(self) -> Dict[str, Any]:
        """获取预测统计信息"""
        try:
            if not self.prediction_history:
                return {}
                
            recent_predictions = list(self.prediction_history)[-100:]
            return {
                'total_predictions': len(self.prediction_history),
                'recent_avg_confidence': np.mean([p['confidence'] for p in recent_predictions]),
                'cache_hit_rate': self._calculate_cache_hit_rate(),
                'last_prediction_time': self.prediction_history[-1]['timestamp'],
                'accuracy_stats': self.analyze_prediction_accuracy()
            }
        except Exception as e:
            logger.error(f"获取预测统计失败: {str(e)}")
            return {}

    def _calculate_cache_hit_rate(self) -> float:
        """计算缓存命中率"""
        try:
            if not hasattr(self, '_cache_stats'):
                self._cache_stats = {'hits': 0, 'misses': 0}
            total = self._cache_stats['hits'] + self._cache_stats['misses']
            return self._cache_stats['hits'] / total if total > 0 else 0
        except Exception as e:
            logger.error(f"计算缓存命中率失败: {str(e)}")
            return 0.0

    def _build_model(self, config):
        """构建单个模型（添加编译步骤）"""
        try:
            with self.graph.as_default():
                inputs = tf.keras.Input(shape=(self.input_shape, 5))
                x = Conv1D(config['filters'], 3, activation='relu')(inputs)
                x = LSTM(config['units'], return_sequences=True)(x)
                x = Dense(config['dense_units'], activation='relu')(x)
                outputs = Dense(5, activation='softmax')(x)
                model = tf.keras.Model(inputs=inputs, outputs=outputs)
                
                # 编译模型
                model.compile(
                    optimizer=tf.keras.optimizers.Adam(
                        learning_rate=self.config['optimizer_config']['learning_rate']
                    ),
                    loss='mse',
                    metrics=['mae']
                )
                return model
            
        except KeyError as e:
            logger.error(f"配置缺失关键参数: {str(e)}")
            raise

    def _create_optimizer(self):
        """创建优化器"""
        try:
            opt_cfg = self.config['optimizer_config']
            return tf.keras.optimizers.Adam(
                learning_rate=opt_cfg['learning_rate'],
                beta_1=opt_cfg['beta_1'],
                beta_2=opt_cfg['beta_2']
            )
        except KeyError as e:
            logger.error(f"配置缺失关键参数: {str(e)}，使用默认优化器")
            return tf.keras.optimizers.Adam(learning_rate=0.001)

    def reset_models(self):
        """安全重置模型"""
        try:
            # 清理现有会话
            if self.session:
                self.session.close()
            # 创建新图和新会话
            self.graph = tf.Graph()
            with self.graph.as_default():
                self.models = [self._build_model(self.default_config) for _ in range(6)]
            self.session = tf.compat.v1.Session(graph=self.graph)
            tf.compat.v1.keras.backend.set_session(self.session)
            logger.info("模型重置成功")
        except Exception as e:
            logger.error(f"模型重置失败: {str(e)}")
            raise

# 创建全局模型核心实例
model_core = ModelCore()

# 生成测试数据
test_input = np.random.rand(1, 14400, 5)  # 批次大小1，序列长度14400，特征维度5

with model_core.session.as_default():
    with model_core.graph.as_default():
        # 创建输入占位符
        input_tensor = tf.placeholder(tf.float32, shape=(1, 14400, 5))
        # 获取编码输出
        encoded_output = model_core._add_positional_encoding(input_tensor)
        # 运行会话
        result = model_core.session.run(encoded_output, feed_dict={input_tensor: test_input})

print("输入形状:", test_input.shape)
print("编码后形状:", result.shape)
print("编码示例(前3个时间步):\n", result[0, :3, :5])



Instructions for updating:
non-resource variables are not supported in the long term


输入形状: (1, 14400, 5)
编码后形状: (1, 14400, 5)
编码示例(前3个时间步):
 [[ 0.8466895   1.3766761   0.15630183  1.8638542   0.31040782]
 [ 0.84151846  1.2634137   0.82372415  1.9615151   0.03049649]
 [ 0.9386499  -0.27177608  0.28308994  1.6084507   0.5598241 ]]


In [7]:
#7 Data Optimization Module / 数据优化模块
import numpy as np
from scipy import stats
from statsmodels.tsa.stattools import adfuller
import logging

logger = logging.getLogger(__name__)

class DataOptimizer:
    def _evaluate_distribution(self, X):
        """评估数据分布情况"""
        try:
            # 1. 检查数据偏度
            skewness = np.abs(np.mean([np.abs(stats.skew(X[:, i])) for i in range(X.shape[1])]))
            skewness_score = 1 / (1 + skewness)  # 转换为0-1分数
            
            # 2. 检查数据峰度
            kurtosis = np.abs(np.mean([np.abs(stats.kurtosis(X[:, i])) for i in range(X.shape[1])]))
            kurtosis_score = 1 / (1 + kurtosis)  # 转换为0-1分数
            
            # 3. 检查异常值比例
            z_scores = np.abs(stats.zscore(X))
            outlier_ratio = np.mean(z_scores > 3)  # 3个标准差以外视为异常值
            outlier_score = 1 - outlier_ratio
            
            # 计算加权平均分数
            distribution_score = (
                0.4 * skewness_score +
                0.3 * kurtosis_score +
                0.3 * outlier_score
            )
            
            return distribution_score
            
        except Exception as e:
            logger.error(f"评估数据分布时出错: {str(e)}")
            return 0.0

    def _evaluate_correlation(self, X):
        """评估特征相关性"""
        try:
            # 1. 计算特征间相关系数矩阵
            corr_matrix = np.corrcoef(X.T)
            
            # 2. 计算特征间的平均相关性
            # 去除对角线上的1
            mask = ~np.eye(corr_matrix.shape[0], dtype=bool)
            avg_correlation = np.mean(np.abs(corr_matrix[mask]))
            
            # 3. 计算相关性得分
            correlation_score = 1 - avg_correlation  # 越小越好
            
            return correlation_score
            
        except Exception as e:
            logger.error(f"评估特征相关性时出错: {str(e)}")
            return 0.0

    def _evaluate_time_series(self, X):
        """评估时间序列特性"""
        try:
            # 1. 检查平稳性
            stationarity_scores = []
            for i in range(X.shape[1]):
                # 使用ADF测试检查平稳性
                adf_result = adfuller(X[:, i])[1]  # 获取p值
                stationarity_scores.append(1 - min(adf_result, 1))  # 转换为0-1分数
            
            stationarity_score = np.mean(stationarity_scores)
            
            # 2. 检查自相关性
            autocorr_scores = []
            for i in range(X.shape[1]):
                # 计算滞后1期的自相关系数
                autocorr = np.corrcoef(X[1:, i], X[:-1, i])[0, 1]
                autocorr_scores.append(abs(autocorr))
            
            autocorr_score = np.mean(autocorr_scores)
            
            # 3. 检查趋势性
            trend_scores = []
            for i in range(X.shape[1]):
                # 使用简单线性回归检测趋势
                slope = np.polyfit(np.arange(len(X)), X[:, i], 1)[0]
                trend_scores.append(abs(slope))
            
            trend_score = 1 / (1 + np.mean(trend_scores))  # 转换为0-1分数
            
            # 计算加权平均分数
            time_series_score = (
                0.4 * stationarity_score +
                0.3 * autocorr_score +
                0.3 * trend_score
            )
            
            return time_series_score
            
        except Exception as e:
            logger.error(f"评估时间序列特性时出错: {str(e)}")
            return 0.0

    def _calculate_trend(self, values):
        """计算趋势"""
        if len(values) < 2:
            return "INSUFFICIENT_DATA"
            
        # 使用简单线性回归
        x = np.arange(len(values))
        slope = np.polyfit(x, values, 1)[0]
        
        if slope < -0.01:
            return "IMPROVING"
        elif slope > 0.01:
            return "DEGRADING"
        else:
            return "STABLE"

    def _compute_correlation(self, x1, x2):
        """计算相关系数"""
        try:
            # 标准化
            x1_norm = (x1 - np.mean(x1)) / np.std(x1)
            x2_norm = (x2 - np.mean(x2)) / np.std(x2)
            
            # 计算相关系数
            corr = np.mean(x1_norm * x2_norm)
            return corr
            
        except Exception as e:
            logger.error(f"计算相关系数时出错: {str(e)}")
            return 0.0

    def optimize_feature_selection(self, X, y, n_features=10):
        """优化特征选择"""
        try:
            # 1. 计算特征重要性
            importance = self._calculate_feature_importance(X, y)
            
            # 2. 计算特征冗余度
            redundancy = self._calculate_feature_redundancy(X)
            
            # 3. 综合评分
            final_score = importance * (1 - redundancy)
            
            # 4. 选择最优特征
            selected = np.argsort(final_score)[-n_features:]
            
            return selected, final_score[selected]
            
        except Exception as e:
            logger.error(f"优化特征选择时出错: {str(e)}")
            return None, None
    
    def _calculate_feature_importance(self, X, y):
        """计算特征重要性分数"""
        try:
            correlations = []
            for i in range(X.shape[1]):
                corr = np.abs(np.corrcoef(X[:, i], y)[0, 1])
                correlations.append(corr)
            return np.array(correlations)
        except Exception as e:
            logger.error(f"计算特征重要性时出错: {str(e)}")
            return None

    def _calculate_feature_redundancy(self, X):
        """计算特征冗余度"""
        try:
            n_features = X.shape[1]
            redundancy = np.zeros(n_features)
            
            for i in range(n_features):
                correlations = []
                for j in range(n_features):
                    if i != j:
                        corr = np.abs(np.corrcoef(X[:, i], X[:, j])[0, 1])
                        correlations.append(corr)
                redundancy[i] = np.mean(correlations)
            
            return redundancy
            
        except Exception as e:
            logger.error(f"计算特征冗余度时出错: {str(e)}")
            return None

    def evaluate_feature_correlation(self, features):
        """评估特征相关性矩阵"""
        try:
            corr_matrix = np.corrcoef(features.T)
            return corr_matrix
        except Exception as e:
            logger.error(f"评估特征相关性时出错: {str(e)}")
            return None

    def optimize_feature_combination(self, features, target):
        """优化特征组合"""
        try:
            # 1. 计算相关性
            correlations = self.evaluate_feature_correlation(features)
            
            # 2. 计算稳定性
            stability = self._analyze_feature_stability(features)
            
            # 3. 计算目标相关性
            target_corr = np.array([
                abs(np.corrcoef(features[:, i], target)[0, 1])
                for i in range(features.shape[1])
            ])
            
            # 4. 优化组合
            scores = target_corr * stability
            return scores
            
        except Exception as e:
            logger.error(f"优化特征组合时出错: {str(e)}")
            return None

    def _analyze_feature_stability(self, features):
        """分析特征稳定性"""
        try:
            stability_scores = []
            for i in range(features.shape[1]):
                # 计算特征的变异系数
                cv = np.std(features[:, i]) / np.mean(np.abs(features[:, i]))
                # 转换为稳定性分数
                stability = 1 / (1 + cv)
                stability_scores.append(stability)
            return np.array(stability_scores)
        except Exception as e:
            logger.error(f"分析特征稳定性时出错: {str(e)}")
            return None

In [8]:
#8 Parameter Tuning System / 参数调优系统
import numpy as np
import tensorflow as tf
import logging
import os
import json
from datetime import datetime
from bayes_opt import BayesianOptimization
from bayes_opt.logger import Events
import optuna
from sklearn.metrics import mutual_info_score
from typing import Dict, Any, Optional
from collections import deque
import tensorflow.keras.backend as K

# 获取logger实例
logger = logging.getLogger(__name__)

class OptimizerManager:
    """优化器管理类 - 整合模型、动态和集成优化"""
    
    def __init__(self, model_ensemble, data_processor, performance_monitor):
        """
        初始化优化器管理器
        Args:
            model_ensemble: 模型集成实例
            data_processor: 数据处理器实例
            performance_monitor: 性能监控器实例
        """
        self.model_ensemble = model_ensemble
        self.data_processor = data_processor
        self.performance_monitor = performance_monitor
        
        # 优化历史记录
        self.optimization_history = []
        
        # 初始化参数范围
        self._init_param_ranges()
        
        # 初始化调整阈值
        self._init_thresholds()
        
        logger.info("优化器管理器初始化完成")

    def _init_param_ranges(self):
        """初始化所有参数范围"""
        self.param_ranges = {
            # 模型架构参数
            'model_params': {
                'lstm_units': (64, 256),
                'lstm_layers': (1, 4),
                'cnn_filters': (32, 256),
                'transformer_heads': (4, 16),
                'dense_units': (32, 128)
            },
            
            # 动态调整参数
            'dynamic_params': {
                'learning_rate': (0.0001, 0.01),
                'batch_size': (16, 128),
                'dropout_rate': (0.1, 0.5)
            },
            
            # 集成参数
            'ensemble_params': {
                'initial_weights': (0.1, 0.3),
                'diversity_weight': (0.1, 0.5),
                'adaptation_rate': (0.1, 0.5)
            },
            
            # 添加训练优化参数范围
            'training_params': {
                'optimizer_params': {
                    'learning_rate': (1e-5, 1e-2),
                    'beta_1': (0.8, 0.999),
                    'beta_2': (0.8, 0.999),
                    'epsilon': (1e-8, 1e-6)
                },
                'lr_schedule_params': {
                    'decay_rate': (0.9, 0.99),
                    'decay_steps': (100, 1000),
                    'warmup_steps': (0, 100),
                    'min_lr': (1e-6, 1e-4)
                },
                'training_control': {
                    'batch_size': (16, 128),
                    'epochs_per_iteration': (1, 10),
                    'validation_frequency': (1, 10),
                    'early_stopping_patience': (10, 50)
                }
            }
        }
        
        # 添加离散参数选项
        self.discrete_params = {
            'optimizer_type': ['adam', 'adamw', 'radam'],
            'scheduler_type': ['exponential', 'cosine', 'step']
        }

    def _init_thresholds(self):
        """初始化调整阈值"""
        self.thresholds = {
            'performance_drop': 0.1,    # 性能下降阈值
            'loss_spike': 0.5,         # 损失突增阈值
            'diversity_min': 0.3,      # 最小多样性要求
            'weight_change': 0.2       # 权重调整阈值
        }

    def optimize_all(self, n_iter=50):
        """执行全面优化"""
        try:
            # 1. 模型架构优化
            model_params = self._optimize_model_architecture(n_iter)
            
            # 2. 动态参数优化
            dynamic_params = self._optimize_dynamic_params(n_iter)
            
            # 3. 集成策略优化
            ensemble_params = self._optimize_ensemble_strategy(n_iter)
            
            # 整合优化结果
            optimized_params = {
                'model_params': model_params,
                'dynamic_params': dynamic_params,
                'ensemble_params': ensemble_params
            }
            
            # 保存优化结果
            self._save_optimization_results(optimized_params)
            
            return optimized_params
            
        except Exception as e:
            logger.error(f"执行全面优化时出错: {str(e)}")
            return None

    def _optimize_model_architecture(self, n_iter):
        """模型架构优化"""
        try:
            optimizer = BayesianOptimization(
                f=self._model_objective,
                pbounds=self.param_ranges['model_params'],
                random_state=42
            )
            
            optimizer.maximize(init_points=5, n_iter=n_iter)
            return optimizer.max['params']
            
        except Exception as e:
            logger.error(f"模型架构优化失败: {str(e)}")
            return None

    def _optimize_dynamic_params(self, n_iter):
        """动态参数优化"""
        try:
            optimizer = BayesianOptimization(
                f=self._dynamic_objective,
                pbounds=self.param_ranges['dynamic_params'],
                random_state=42
            )
            
            optimizer.maximize(init_points=5, n_iter=n_iter)
            return optimizer.max['params']
            
        except Exception as e:
            logger.error(f"动态参数优化失败: {str(e)}")
            return None

    def _optimize_ensemble_strategy(self, n_iter):
        """集成策略优化"""
        try:
            optimizer = BayesianOptimization(
                f=self._ensemble_objective,
                pbounds=self.param_ranges['ensemble_params'],
                random_state=42
            )
            
            optimizer.maximize(init_points=5, n_iter=n_iter)
            return optimizer.max['params']
            
        except Exception as e:
            logger.error(f"集成策略优化失败: {str(e)}")
            return None

    def _model_objective(self, **params):
        """模型优化目标函数"""
        try:
            self.model_ensemble.update_architecture(params)
            return self._evaluate_performance()
        except Exception as e:
            logger.error(f"模型目标函数评估失败: {str(e)}")
            return float('-inf')

    def _dynamic_objective(self, **params):
        """动态优化目标函数"""
        try:
            self.model_ensemble.update_dynamic_params(params)
            return self._evaluate_performance()
        except Exception as e:
            logger.error(f"动态目标函数评估失败: {str(e)}")
            return float('-inf')

    def _ensemble_objective(self, **params):
        """集成优化目标函数"""
        try:
            self.model_ensemble.update_ensemble_params(params)
            performance = self._evaluate_performance()
            diversity = self._calculate_diversity()
            return 0.7 * performance + 0.3 * diversity
        except Exception as e:
            logger.error(f"集成目标函数评估失败: {str(e)}")
            return float('-inf')

    def _evaluate_performance(self):
        """评估性能"""
        try:
            X_val, y_val = self.data_processor.get_validation_data()
            predictions = self.model_ensemble.predict(X_val)
            matches = np.any(np.round(predictions) == y_val, axis=1)
            return np.mean(matches)
        except Exception as e:
            logger.error(f"性能评估失败: {str(e)}")
            return 0.0

    def _calculate_diversity(self):
        """计算模型多样性"""
        try:
            predictions = []
            X_val, _ = self.data_processor.get_validation_data()
            
            for model in self.model_ensemble.models:
                pred = model.predict(X_val)
                predictions.append(pred)
            
            diversity_scores = []
            n_models = len(predictions)
            
            for i in range(n_models):
                for j in range(i + 1, n_models):
                    mi_score = mutual_info_score(
                        predictions[i].ravel(),
                        predictions[j].ravel()
                    )
                    diversity_scores.append(1 - mi_score)
            
            return np.mean(diversity_scores)
            
        except Exception as e:
            logger.error(f"多样性计算失败: {str(e)}")
            return 0.0

    def _save_optimization_results(self, results):
        """保存优化结果"""
        try:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"optimization_results_{timestamp}.json"
            
            save_path = os.path.join(os.getcwd(), 'optimization_results', filename)
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            
            with open(save_path, 'w') as f:
                json.dump(results, f, indent=4)
                
            logger.info(f"优化结果已保存到: {save_path}")
            
        except Exception as e:
            logger.error(f"保存优化结果失败: {str(e)}")

    def dynamic_adjust(self, metrics):
        """动态参数调整"""
        try:
            if self._needs_adjustment(metrics):
                suggestions = self._get_adjustment_suggestions(metrics)
                return self._apply_adjustments(suggestions)
            return None
        except Exception as e:
            logger.error(f"动态调整失败: {str(e)}")
            return None

    def _needs_adjustment(self, metrics):
        """检查是否需要调整"""
        return any([
            metrics['performance_change'] < -self.thresholds['performance_drop'],
            metrics['loss_change'] > self.thresholds['loss_spike'],
            metrics['diversity'] < self.thresholds['diversity_min']
        ])

    def _get_adjustment_suggestions(self, metrics):
        """获取参数调整建议"""
        try:
            suggestions = {}
            
            # 基于性能变化的学习率调整
            if metrics['performance_change'] < 0:
                current_lr = metrics['current_lr']
                suggestions['learning_rate'] = self._adjust_learning_rate(
                    current_lr, 
                    metrics['performance_change']
                )
            
            # 基于内存使用的批次大小调整    
            if metrics['memory_usage'] > 0.9:
                current_batch = metrics['current_batch_size']
                suggestions['batch_size'] = self._adjust_batch_size(current_batch)
            
            # 基于过拟合风险的正则化调整
            if metrics['validation_loss'] > metrics['training_loss'] * 1.2:
                suggestions['dropout_rate'] = self._adjust_dropout_rate(
                    metrics['current_dropout']
                )
                
            return suggestions
            
        except Exception as e:
            logger.error(f"生成调整建议失败: {str(e)}")
            return {}

    def _adjust_learning_rate(self, current_lr, performance_change):
        """调整学习率"""
        try:
            if performance_change < -0.2:  # 性能显著下降
                return current_lr * 0.5
            elif performance_change < -0.1:  # 性能轻微下降
                return current_lr * 0.8
            return current_lr
        except Exception as e:
            logger.error(f"调整学习率失败: {str(e)}")
            return current_lr

    def _adjust_batch_size(self, current_batch_size):
        """调整批次大小"""
        try:
            return max(16, current_batch_size // 2)
        except Exception as e:
            logger.error(f"调整批次大小失败: {str(e)}")
            return current_batch_size

    def _adjust_dropout_rate(self, current_dropout):
        """调整dropout率"""
        try:
            return min(0.5, current_dropout + 0.1)
        except Exception as e:
            logger.error(f"调整dropout率失败: {str(e)}")
            return current_dropout

    def _apply_adjustments(self, suggestions):
        """应用参数调整"""
        try:
            new_params = {}
            
            for param_name, new_value in suggestions.items():
                # 验证参数范围
                if param_name in self.param_ranges['dynamic_params']:
                    min_val, max_val = self.param_ranges['dynamic_params'][param_name]
                    new_value = np.clip(new_value, min_val, max_val)
                    new_params[param_name] = new_value
                    
            # 更新模型参数
            if new_params:
                self.model_ensemble.update_dynamic_params(new_params)
                logger.info(f"应用参数调整: {new_params}")
                
            return new_params
            
        except Exception as e:
            logger.error(f"应用参数调整失败: {str(e)}")
            return {}

    def adjust_ensemble_weights(self, performance_metrics):
        """调整集成权重"""
        try:
            weights = []
            for model_idx, metrics in enumerate(performance_metrics):
                # 基于性能计算新权重
                performance_score = 1.0 - metrics['loss']
                diversity_score = self._calculate_model_diversity(model_idx)
                weight = 0.7 * performance_score + 0.3 * diversity_score
                weights.append(weight)
            
            # 归一化权重
            weights = np.array(weights)
            weights = weights / np.sum(weights)
            
            # 更新模型集成权重
            self.model_ensemble.update_weights(weights)
            logger.info(f"更新集成权重: {weights}")
            
            return weights
            
        except Exception as e:
            logger.error(f"调整集成权重失败: {str(e)}")
            return None

    def _calculate_model_diversity(self, model_idx):
        """计算单个模型的多样性得分"""
        try:
            predictions = []
            X_val, _ = self.data_processor.get_validation_data()
            
            # 获取当前模型和其他模型的预测
            current_pred = self.model_ensemble.models[model_idx].predict(X_val)
            other_preds = []
            for i, model in enumerate(self.model_ensemble.models):
                if i != model_idx:
                    other_preds.append(model.predict(X_val))
            
            # 计算与其他模型的平均互信息分数
            diversity_scores = []
            for other_pred in other_preds:
                mi_score = mutual_info_score(
                    current_pred.ravel(),
                    other_pred.ravel()
                )
                diversity_scores.append(1 - mi_score)
            
            return np.mean(diversity_scores)
            
        except Exception as e:
            logger.error(f"计算模型多样性失败: {str(e)}")
            return 0.0

    def get_optimization_summary(self):
        """获取优化过程摘要"""
        try:
            if not self.optimization_history:
                return None
                
            latest_results = self.optimization_history[-1]
            best_results = max(
                self.optimization_history,
                key=lambda x: x.get('final_score', 0)
            )
            
            return {
                'latest': {
                    'params': latest_results['params'],
                    'performance': latest_results.get('final_score', 0)
                },
                'best': {
                    'params': best_results['params'],
                    'performance': best_results.get('final_score', 0)
                },
                'total_iterations': len(self.optimization_history)
            }
            
        except Exception as e:
            logger.error(f"获取优化摘要失败: {str(e)}")
            return None

    def reset_optimization(self):
        """重置优化状态"""
        try:
            self.optimization_history.clear()
            self._init_param_ranges()
            self._init_thresholds()
            logger.info("优化状态已重置")
            return True
        except Exception as e:
            logger.error(f"重置优化状态失败: {str(e)}")
            return False

    def optimize_model_params(self, training_direction):
        """根据训练方向优化模型参数 (from cell13)"""
        try:
            if isinstance(training_direction, dict):
                # 1. 学习率调整
                if training_direction['learning_rate'] == 'INCREASE':
                    self.current_lr *= 1.5
                elif training_direction['learning_rate'] == 'DECREASE':
                    self.current_lr *= 0.7
                
                # 2. 批次大小调整
                if training_direction['batch_size'] == 'DECREASE':
                    self.batch_size = max(16, self.batch_size // 2)
                
                # 3. 模型复杂度调整
                if training_direction['model_complexity'] == 'INCREASE':
                    self._increase_model_complexity()
                
                # 4. 正则化调整
                if training_direction.get('regularization') == 'INCREASE':
                    self._increase_regularization()
                
            return True
            
        except Exception as e:
            logger.error(f"优化模型参数时出错: {str(e)}")
            return False

    def _increase_model_complexity(self):
        """增加模型复杂度 (from cell13)"""
        try:
            current_params = self.model_ensemble.get_current_params()
            new_params = {
                'lstm_units': int(current_params['lstm_units'] * 1.5),
                'transformer_heads': current_params['transformer_heads'] + 2,
                'cnn_filters': int(current_params['cnn_filters'] * 1.3)
            }
            self.model_ensemble.update_architecture(new_params)
        except Exception as e:
            logger.error(f"增加模型复杂度失败: {str(e)}")

    def _increase_regularization(self):
        """增加正则化强度 (from cell13)"""
        try:
            current_params = self.model_ensemble.get_current_params()
            new_params = {
                'dropout_rate': min(0.5, current_params['dropout_rate'] + 0.1),
                'weight_decay': current_params['weight_decay'] * 2
            }
            self.model_ensemble.update_dynamic_params(new_params)
        except Exception as e:
            logger.error(f"增加正则化强度失败: {str(e)}")

    def adjust_ensemble_strategy(self, match_distribution):
        """调整集成策略 (from cell15)"""
        try:
            total_samples = sum(match_distribution.values())
            
            # 1. 分析集成效果
            high_match_ratio = (match_distribution[4] + match_distribution[5]) / total_samples
            low_match_ratio = (match_distribution[0] + match_distribution[1]) / total_samples
            
            # 2. 根据分布调整集成策略
            if high_match_ratio < 0.1:  # 高匹配率太低
                # 增加模型多样性
                self._increase_model_diversity()
                # 调整模型权重
                self._adjust_model_weights()
                
            elif low_match_ratio > 0.5:  # 低匹配率太高
                # 强化表现好的模型
                self._strengthen_best_models()
                # 重新训练表现差的模型
                self._retrain_weak_models()
                
            return True
            
        except Exception as e:
            logger.error(f"调整集成策略时出错: {str(e)}")
            return False

    def _increase_model_diversity(self):
        """增加模型多样性 (from cell15)"""
        try:
            # 1. 计算当前多样性矩阵
            diversity_matrix = self._calculate_diversity_matrix()
            
            # 2. 找出相似度最高的模型对
            similar_pairs = self._find_similar_model_pairs(diversity_matrix)
            
            # 3. 对相似模型进行差异化训练
            for model_i, model_j in similar_pairs:
                self._differentiate_models(model_i, model_j)
                
        except Exception as e:
            logger.error(f"增加模型多样性失败: {str(e)}")

    def _calculate_diversity_matrix(self):
        """计算模型间多样性矩阵 (from cell15)"""
        try:
            n_models = len(self.model_ensemble.models)
            diversity_matrix = np.zeros((n_models, n_models))
            
            X_val, _ = self.data_processor.get_validation_data()
            predictions = [model.predict(X_val) for model in self.model_ensemble.models]
            
            for i in range(n_models):
                for j in range(i+1, n_models):
                    mi_score = mutual_info_score(
                        predictions[i].ravel(),
                        predictions[j].ravel()
                    )
                    diversity_matrix[i, j] = mi_score
                    diversity_matrix[j, i] = mi_score
                    
            return diversity_matrix
            
        except Exception as e:
            logger.error(f"计算多样性矩阵失败: {str(e)}")
            return None

    def _find_similar_model_pairs(self, diversity_matrix):
        """找出相似度高的模型对 (from cell15)"""
        try:
            n_models = len(self.model_ensemble.models)
            similar_pairs = []
            
            for i in range(n_models):
                for j in range(i+1, n_models):
                    if diversity_matrix[i, j] > 0.8:  # 相似度阈值
                        similar_pairs.append((i, j))
                        
            return similar_pairs
            
        except Exception as e:
            logger.error(f"寻找相似模型对失败: {str(e)}")
            return []

    def _differentiate_models(self, model_i, model_j):
        """对相似模型进行差异化训练 (from cell15)"""
        try:
            # 1. 调整模型架构
            self.model_ensemble.update_model_architecture({
                model_i: {'dropout_rate': 0.3},
                model_j: {'dropout_rate': 0.5}
            })
            
            # 2. 使用不同的优化器
            self.model_ensemble.update_optimizer_settings({
                model_i: {'learning_rate': 0.001},
                model_j: {'learning_rate': 0.0005}
            })
            
        except Exception as e:
            logger.error(f"模型差异化失败: {str(e)}")

    def _strengthen_best_models(self):
        """强化表现好的模型 (from cell15)"""
        try:
            performance_metrics = self.performance_monitor.get_model_metrics()
            best_models = self._identify_best_models(performance_metrics)
            
            for model_idx in best_models:
                # 增加模型权重
                self.model_ensemble.increase_model_weight(model_idx)
                # 微调学习率
                self.model_ensemble.fine_tune_model(model_idx)
                
        except Exception as e:
            logger.error(f"强化最佳模型失败: {str(e)}")

    def _retrain_weak_models(self):
        """重新训练表现差的模型 (from cell15)"""
        try:
            performance_metrics = self.performance_monitor.get_model_metrics()
            weak_models = self._identify_weak_models(performance_metrics)
            
            for model_idx in weak_models:
                # 重置模型参数
                self.model_ensemble.reset_model(model_idx)
                # 使用新的训练策略
                self.model_ensemble.retrain_model(
                    model_idx, 
                    strategy='adaptive'
                )
                
        except Exception as e:
            logger.error(f"重训弱模型失败: {str(e)}")

    def _identify_best_models(self, metrics):
        """识别最佳模型 (from cell15)"""
        try:
            scores = [m['performance'] for m in metrics]
            threshold = np.percentile(scores, 75)  # 上四分位数
            return [i for i, score in enumerate(scores) if score >= threshold]
        except Exception as e:
            logger.error(f"识别最佳模型失败: {str(e)}")
            return []

    def _identify_weak_models(self, metrics):
        """识别弱模型 (from cell15)"""
        try:
            scores = [m['performance'] for m in metrics]
            threshold = np.percentile(scores, 25)  # 下四分位数
            return [i for i, score in enumerate(scores) if score <= threshold]
        except Exception as e:
            logger.error(f"识别弱模型失败: {str(e)}")
            return []

    def adjust_after_sample(self, model, sample, current_params):
        """基于样本梯度调整参数 (from cell13)"""
        try:
            with tf.GradientTape() as tape:
                predictions = model(sample['input'])
                loss = tf.keras.losses.MSE(sample['target'], predictions)
            grads = tape.gradient(loss, model.trainable_variables)
            
            # 生成参数调整建议
            adjusted_params = {
                'learning_rate': self._adjust_lr_from_grads(grads, current_params),
                'batch_size': current_params['batch_size']
            }
            return adjusted_params
            
        except Exception as e:
            logger.error(f"样本级参数调整失败: {str(e)}")
            return current_params

    def _adjust_lr_from_grads(self, grads, current_params):
        """根据梯度调整学习率"""
        try:
            grad_norm = tf.linalg.global_norm(grads)
            if grad_norm > 10.0:  # 梯度爆炸
                return current_params['learning_rate'] * 0.5
            elif grad_norm < 0.1:  # 梯度消失
                return current_params['learning_rate'] * 1.5
            return current_params['learning_rate']
        except Exception as e:
            logger.error(f"学习率梯度调整失败: {str(e)}")
            return current_params['learning_rate']

    def on_train_end(self):
        """训练结束时的优化操作 (from cell13)"""
        try:
            # 获取新的参数建议
            new_params = self.suggest_next_params()
            # 更新集成模型参数
            self.model_ensemble.update_params(new_params)
            # 保存优化记录
            self._save_optimization_record()
            return True
        except Exception as e:
            logger.error(f"训练结束优化操作失败: {str(e)}")
            return False

    def suggest_next_params(self):
        """使用Optuna生成下一组参数"""
        try:
            study = optuna.create_study(
                study_name="model_optim_v1",
                storage="sqlite:///optuna.db",
                load_if_exists=True
            )
            
            trial = study.ask()
            params = {
                'learning_rate': trial.suggest_loguniform('learning_rate', 1e-5, 1e-2),
                'batch_size': trial.suggest_int('batch_size', 16, 128),
                'dropout_rate': trial.suggest_uniform('dropout_rate', 0.1, 0.5)
            }
            return params
        except Exception as e:
            logger.error(f"生成参数建议失败: {str(e)}")
            return None

    def optimize_parameters(self):
        """执行参数优化 (from cell14)"""
        try:
            # 贝叶斯优化初始化基础参数
            initial_params = self.bayesian_optimization()
            # 使用Optuna进行细粒度优化
            final_params = self._optuna_optimization(initial_params)
            
            # 更新并保存最佳参数
            self.best_params = final_params
            self.save_best_params()
            
            logger.info(f"参数优化完成: {final_params}")
            return final_params
        except Exception as e:
            logger.error(f"参数优化失败: {str(e)}")
            return None

    def save_best_params(self):
        """保存最佳参数配置 (from cell14)"""
        try:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"best_params_{timestamp}.json"
            save_path = os.path.join(os.getcwd(), 'optimization_params', filename)
            
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            with open(save_path, 'w') as f:
                json.dump(self.best_params, f, indent=4)
            
            logger.info(f"最佳参数已保存到: {save_path}")
            return True
        except Exception as e:
            logger.error(f"保存最佳参数失败: {str(e)}")
            return False

    def _save_optimization_record(self):
        """保存优化记录"""
        try:
            record = {
                'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                'history': self.optimization_history,
                'best_params': self.best_params,
                'performance_summary': self.get_optimization_summary()
            }
            
            filename = f"optimization_record_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            save_path = os.path.join(os.getcwd(), 'optimization_records', filename)
            
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            with open(save_path, 'w') as f:
                json.dump(record, f, indent=4)
                
            logger.info(f"优化记录已保存到: {save_path}")
            return True
        except Exception as e:
            logger.error(f"保存优化记录失败: {str(e)}")
            return False

    def analyze_training_direction(self, match_counts, current_params):
        """分析训练方向"""
        try:
            # 1. 初始化/更新匹配分布
            if not hasattr(self, 'match_distribution'):
                self.match_distribution = {i: 0 for i in range(6)}
                
            # 2. 更新分布
            for count in match_counts:
                self.match_distribution[count] += 1
            
            # 3. 判断当前状态
            if self.match_distribution[5] > 0:
                return "OPTIMAL"
                
            avg_match = sum(k * v for k, v in self.match_distribution.items()) / sum(self.match_distribution.values())
            
            # 4. 根据匹配分布给出调整建议
            if avg_match < 2:
                return {
                    'learning_rate': 'INCREASE',
                    'batch_size': 'DECREASE',
                    'model_complexity': 'INCREASE'
                }
            elif avg_match > 3:
                return {
                    'learning_rate': 'DECREASE',
                    'regularization': 'INCREASE',
                    'ensemble_diversity': 'INCREASE'
                }
            
            return "CONTINUE"
            
        except Exception as e:
            logger.error(f"分析训练方向时出错: {str(e)}")
            return None

    def optimize_training_flow(self):
        """优化训练流程"""
        try:
            self._dynamic_resource_adjust()
            self._dynamic_batch_adjust()
            self._enable_mixed_precision()
            return True
        except Exception as e:
            logger.error(f"优化训练流程时出错: {str(e)}")
            return False

    def _dynamic_resource_adjust(self):
        """根据硬件资源动态调整参数"""
        try:
            # 获取资源信息
            mem_info = memory_manager.get_memory_info()
            cpu_usage = psutil.cpu_percent()
            
            # 内存调整策略
            if mem_info['percent'] > 75:
                new_batch = max(4, self.batch_size // 2)
                logger.info(f"内存使用{mem_info['percent']}% → 批次从{self.batch_size}调整为{new_batch}")
                self.batch_size = new_batch
            
            # CPU线程调整策略
            if hasattr(self, 'threads'):
                if cpu_usage < 60:
                    self.threads = min(12, self.threads + 2)
                else:
                    self.threads = max(4, self.threads - 2)
            
            # GPU显存优化
            if tf.config.list_physical_devices('GPU'):
                gpu_mem = tf.config.experimental.get_memory_info('GPU:0')
                used_percent = gpu_mem['current'] / gpu_mem['total']
                if used_percent > 0.8:
                    tf.config.experimental.set_memory_growth(True)
                    
            return True
        except Exception as e:
            logger.error(f"资源调整失败: {str(e)}")
            return False

    def _dynamic_batch_adjust(self):
        """动态调整批次大小"""
        try:
            if hasattr(self, 'batch_size'):
                mem_usage = memory_manager.get_memory_info()
                if mem_usage['percent'] > 80:
                    new_size = max(8, self.batch_size // 2)
                    logger.info(f"批次大小从{self.batch_size}调整为{new_size}")
                    self.batch_size = new_size
            return True
        except Exception as e:
            logger.error(f"批次调整失败: {str(e)}")
            return False

    def _enable_mixed_precision(self):
        """启用混合精度训练"""
        try:
            if tf.config.list_physical_devices('GPU'):
                policy = tf.keras.mixed_precision.Policy('mixed_float16')
                tf.keras.mixed_precision.set_global_policy(policy)
                logger.info("已启用混合精度训练")
            return True
        except Exception as e:
            logger.error(f"启用混合精度失败: {str(e)}")
            return False

    def setup_mixed_precision(self):
        """配置混合精度训练"""
        try:
            if tf.config.list_physical_devices('GPU'):
                # 启用mixed precision policy
                policy = tf.keras.mixed_precision.Policy('mixed_float16')
                tf.keras.mixed_precision.set_global_policy(policy)
                logger.info("已启用混合精度训练")
                
                # 配置优化器
                self.model_ensemble.update_optimizer_settings({
                    'mixed_precision': True,
                    'loss_scale': 'dynamic'
                })
                return True
        except Exception as e:
            logger.error(f"配置混合精度训练失败: {str(e)}")
            return False

    def setup_checkpoints(self):
        """配置检查点"""
        try:
            checkpoint_config = {
                'save_freq': 100,  # 每100步保存一次
                'max_to_keep': 5,  # 保留最新的5个检查点
                'include_optimizer': True,
                'save_best_only': True,
                'monitor': 'val_accuracy'
            }
            self.model_ensemble.setup_model_checkpoint(checkpoint_config)
            logger.info("检查点配置完成")
            return True
        except Exception as e:
            logger.error(f"配置检查点失败: {str(e)}")
            return False

    def adjust_learning_rate(self, metrics):
        """智能调整学习率"""
        try:
            accuracy = metrics.get('accuracy', 0)
            loss_change = metrics.get('loss_change', 0)
            
            # 基于性能调整学习率
            if accuracy < 0.5 and loss_change > 0:
                # 性能差且损失在增加，大幅降低学习率
                return self.current_lr * 0.5
            elif accuracy < 0.7 and loss_change > 0:
                # 性能一般且损失增加，小幅降低学习率
                return self.current_lr * 0.8
            elif accuracy > 0.9 and loss_change < 0:
                # 性能好且损失在下降，小幅提高学习率
                return self.current_lr * 1.1
            
            return self.current_lr
        except Exception as e:
            logger.error(f"调整学习率失败: {str(e)}")
            return self.current_lr

    def _update_training_params(self, params):
        """更新训练参数"""
        try:
            nested_params = self._process_params(params)
            
            # 1. 更新优化器参数
            optimizer_params = nested_params['optimizer_params']
            for model in self.model_ensemble.models:
                model.optimizer.learning_rate = optimizer_params['learning_rate']
                if hasattr(model.optimizer, 'beta_1'):
                    model.optimizer.beta_1 = optimizer_params['beta_1']
                if hasattr(model.optimizer, 'beta_2'):
                    model.optimizer.beta_2 = optimizer_params['beta_2']
            
            # 2. 更新学习率调度
            lr_params = nested_params['lr_schedule_params']
            self.model_ensemble.update_learning_rate_schedule(
                decay_rate=lr_params['decay_rate'],
                decay_steps=int(lr_params['decay_steps']),
                warmup_steps=int(lr_params['warmup_steps']),
                min_lr=lr_params['min_lr']
            )
            
            # 3. 更新训练控制参数
            training_params = nested_params['training_control']
            self.model_ensemble.batch_size = int(training_params['batch_size'])
            self.model_ensemble.epochs_per_iteration = int(training_params['epochs_per_iteration'])
            self.model_ensemble.validation_frequency = int(training_params['validation_frequency'])
            
            logger.info("训练参数已更新")
            
        except Exception as e:
            logger.error(f"更新训练参数时出错: {str(e)}")
            raise

class DynamicTuner:
    """动态参数调优器"""
    def __init__(self):
        self.history = deque(maxlen=100)
        self.param_ranges = {
            'learning_rate': (1e-5, 1e-3),
            'batch_size': (16, 128),
            'dropout': (0.1, 0.5)
        }
        
    def adapt_parameters(self, model, recent_loss):
        """自适应调整参数"""
        if len(self.history) < 10:
            return
        
        # 计算损失变化趋势
        loss_diff = np.diff(self.history)
        trend = np.mean(loss_diff[-3:])
        
        # 动态调整学习率
        if trend > 0:  # 损失上升
            K.set_value(
                model.optimizer.learning_rate, 
                max(self.param_ranges['learning_rate'][0], 
                    K.get_value(model.optimizer.learning_rate) * 0.9)
            )
        else:  # 损失下降
            K.set_value(
                model.optimizer.learning_rate,
                min(self.param_ranges['learning_rate'][1],
                    K.get_value(model.optimizer.learning_rate) * 1.1)
            )

# 创建全局实例
optimizer_manager = OptimizerManager(
    model_ensemble=None,
    data_processor=None,
    performance_monitor=None
)


  from .autonotebook import tqdm as notebook_tqdm


In [9]:
#9 State Management System / 状态管理系统
import os
import signal
import logging
import threading
from typing import Optional, Any, Dict
from collections import deque
from datetime import datetime
import torch
from bayes_opt.logger import Events  # 拆分导入
from bayes_opt.util import load_logs
import tensorflow as tf

# 获取logger实例
logger = logging.getLogger(__name__)

class StateManager:
    """全局状态管理器 - 单例模式"""
    _instance = None
    
    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
    
    def __init__(self):
        if not hasattr(self, 'initialized'):
            # 训练相关状态
            self.trainer = None  # 全局trainer实例
            self.training_state = 'idle'  # 训练状态: idle/training/paused/stopped
            self.current_epoch = 0
            self.current_batch = 0
            self.best_performance = float('inf')
            
            # 显示相关状态
            self.display_running = True  # 显示线程运行标志
            self.display_thread = None  # 显示线程实例
            self.log_buffer = deque(maxlen=100)  # 日志缓冲区
            self.show_print = False  # 控制是否显示打印信息
            
            # 性能监控状态
            self.performance_metrics = {
                'loss': deque(maxlen=1000),
                'accuracy': deque(maxlen=1000),
                'learning_rate': 0.001
            }
            
            # 资源监控状态
            self.resource_metrics = {
                'memory_usage': 0,
                'cpu_usage': 0,
                'gpu_usage': 0
            }
            
            # 注册信号处理器
            self._register_signal_handlers()
            
            self.initialized = True
            logger.info("状态管理器初始化完成")
    
    def _register_signal_handlers(self):
        """注册信号处理器"""
        signal.signal(signal.SIGINT, self._signal_handler)   # Ctrl+C
        signal.signal(signal.SIGTERM, self._signal_handler)  # 终止信号
    
    def _signal_handler(self, signum, frame):
        """信号处理器"""
        logger.info(f"接收到信号: {signum}, 开始保存进度...")
        self.save_all_progress()
        import sys
        sys.exit(0)
    
    def save_all_progress(self):
        """保存所有进度和参数"""
        if self.trainer:
            try:
                # 保存训练进度
                self.trainer.save_training_progress()
                # 保存模型参数
                self.trainer.save_model_weights()
                # 保存性能指标
                self.save_performance_metrics()
                logger.info("所有进度和参数已保存")
            except Exception as e:
                logger.error(f"保存进度时出错: {str(e)}")
    
    def update_training_state(self, new_state: str):
        """更新训练状态"""
        valid_states = {'idle', 'training', 'paused', 'stopped'}
        if new_state not in valid_states:
            logger.error(f"无效的训练状态: {new_state}")
            return
            
        old_state = self.training_state
        self.training_state = new_state
        logger.info(f"训练状态从 {old_state} 变更为 {new_state}")
    
    def update_performance_metrics(self, metrics: Dict[str, float]):
        """更新性能指标"""
        try:
            for key, value in metrics.items():
                if key in self.performance_metrics:
                    self.performance_metrics[key].append(value)
                    
            # 更新最佳性能
            if 'loss' in metrics and metrics['loss'] < self.best_performance:
                self.best_performance = metrics['loss']
                logger.info(f"更新最佳性能: {self.best_performance:.4f}")
        except Exception as e:
            logger.error(f"更新性能指标时出错: {str(e)}")
    
    def update_resource_metrics(self, metrics: Dict[str, float]):
        """更新资源使用指标"""
        try:
            self.resource_metrics.update(metrics)
            # 检查资源使用是否超过警戒线
            if metrics.get('memory_usage', 0) > 90:
                logger.warning("内存使用率超过90%!")
            if metrics.get('gpu_usage', 0) > 90:
                logger.warning("GPU使用率超过90%!")
        except Exception as e:
            logger.error(f"更新资源指标时出错: {str(e)}")
    
    def set_trainer(self, trainer: Any):
        """设置trainer实例"""
        self.trainer = trainer
    
    def set_display_thread(self, thread: threading.Thread):
        """设置显示线程"""
        self.display_thread = thread
    
    def stop_display(self):
        """停止显示线程"""
        self.display_running = False
        if self.display_thread and self.display_thread.is_alive():
            self.display_thread.join()
            logger.info("显示线程已停止")
    
    def save_performance_metrics(self):
        """保存性能指标到文件"""
        try:
            metrics_file = os.path.join('logs', f'metrics_{datetime.now():%Y%m%d_%H%M%S}.json')
            import json
            with open(metrics_file, 'w') as f:
                # 将deque转换为list后保存
                metrics_to_save = {
                    k: list(v) if isinstance(v, deque) else v 
                    for k, v in self.performance_metrics.items()
                }
                json.dump(metrics_to_save, f, indent=4)
            logger.info(f"性能指标已保存到: {metrics_file}")
        except Exception as e:
            logger.error(f"保存性能指标时出错: {str(e)}")

    def save_training_state(self):
        """保存训练状态"""
        try:
            tf.keras.models.save_model(self.model, 'training_state.h5')
            self.logger.info("训练状态已保存")
        except Exception as e:
            self.logger.error(f"保存训练状态失败: {str(e)}")

    def restore_training_state(self):
        """恢复训练状态"""
        if os.path.exists('training_state.h5'):
            self.model = tf.keras.models.load_model('training_state.h5')

# 创建全局实例
state_manager = StateManager()

In [10]:
#10 Training Management System / 训练管理系统
import traceback
import os
import tensorflow as tf
import numpy as np
import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from cell1_core import core_manager
from cell3_monitor import resource_monitor, performance_monitor
from datetime import datetime, timedelta
from collections import deque
from cell6_model import model_core
import gc
from cell4_data import data_manager

# 获取logger实例
logger = logging.getLogger(__name__)

class TrainingManager:
    """训练管理器类"""
    
    def __init__(self, model_ensemble):
        self.model_ensemble = model_ensemble
        
        # 获取训练配置
        self.config = core_manager.SYSTEM_CONFIG['TRAINING_CONFIG']
        self.batch_size = self.config['batch_size']
        self.epochs = self.config['max_epochs']
        
        # 添加数据管理器
        self.data_manager = data_manager
        
        # 训练状态
        self.is_training = False
        self.pause_training = False
        self._training_loop_running = False
        self._training_thread = None
        
        # 训练进度
        self.current_epoch = 0
        self.current_batch = 0
        self.total_batches = 0
        
        # 训练性能监控
        self.performance_history = deque(maxlen=1000)
        self.resource_monitor = resource_monitor
        self.performance_monitor = performance_monitor
        
        # 训练同步机制
        self.training_lock = threading.Lock()
        self.finished_models = 0
        self.total_models = 6
        
        logger.info("训练管理器初始化完成")
        
        self.issue_file = "D:\\JupyterWork\\notebooks\\period_number\\issue.txt"
        self.current_target_period = None
    
    def start_training(self, training_data):
        """启动训练（增加数据量检查）"""
        try:
            # 修改后的数据量检查
            if training_data.shape[0] != 14400 + 2880:  # 检查数组第一维长度
                logger.error(f"数据量不匹配，预期{14400+2880}条，实际{training_data.shape[0]}条")
                return False
            
            # 添加数据校验
            if training_data is None:
                logger.error("训练数据为空")
                return False
            if len(training_data) == 0:
                logger.error("训练数据长度为零")
                return False
            
            if self.is_training:
                logger.warning("训练已在进行中")
                return False
            
            self.is_training = True
            self.pause_training = False
            
            # 初始化训练状态
            self._init_training(training_data)
            
            # 启动训练线程
            self._training_thread = threading.Thread(
                target=self._training_loop,
                args=(training_data,)
            )
            self._training_thread.start()
            
            logger.info("训练已启动")
            return True
            
        except Exception as e:
            logger.error(f"启动训练失败: {str(e)}")
            self.is_training = False
            return False
    
    def stop_training(self):
        """停止训练"""
        try:
            self.is_training = False
            if self._training_thread and self._training_thread.is_alive():
                self._training_thread.join()
            logger.info("训练已停止")
            return True
        except Exception as e:
            logger.error(f"停止训练失败: {str(e)}")
            return False
    
    def pause_resume_training(self):
        """暂停/恢复训练"""
        try:
            self.pause_training = not self.pause_training
            status = "暂停" if self.pause_training else "恢复"
            logger.info(f"训练已{status}")
            return True
        except Exception as e:
            logger.error(f"训练暂停/恢复失败: {str(e)}")
            return False
    
    def _init_training(self, training_data):
        """初始化训练"""
        self.current_epoch = 0
        self.current_batch = 0
        self.total_batches = len(training_data) // self.batch_size
        self.performance_history.clear()
        
        # 初始化资源监控
        self.resource_monitor.start()
        
        # 初始化性能监控
        self.performance_monitor.reset()
    
    def _training_loop(self, training_data):
        """训练主循环"""
        try:
            while self.is_training and self.current_epoch < self.epochs:
                
                # 检查暂停状态
                if self.pause_training:
                    time.sleep(1)
                    continue
                
                # 检查系统资源
                if not self._check_resources():
                    time.sleep(5)
                    continue
                
                # 获取训练批次
                batch_data = self._get_next_batch(training_data)
                if batch_data is None:
                    continue
                
                # 并行训练模型
                self._parallel_train_models(batch_data)
                
                # 更新训练状态
                self._update_training_status()
                
                # 记录训练性能
                self._record_performance()
                
            logger.info("训练完成")
            
        except Exception as e:
            logger.error(f"训练循环出错: {str(e)}")
            self.is_training = False
    
    def _check_resources(self):
        """检查系统资源"""
        try:
            # 检查CPU使用率
            if self.resource_monitor.cpu_usage > 90:
                logger.warning("CPU使用率过高,暂停训练")
                return False
            
            # 检查内存使用率    
            if self.resource_monitor.memory_usage > 90:
                logger.warning("内存使用率过高,暂停训练")
                return False
            
            # 检查GPU使用率
            if self.resource_monitor.gpu_usage > 90:
                logger.warning("GPU使用率过高,暂停训练")
                return False
                
            return True
            
        except Exception as e:
            logger.error(f"资源检查失败: {str(e)}")
            return False
    
    def _get_next_batch(self, training_data):
        """获取下一个训练批次"""
        try:
            start_idx = self.current_batch * self.batch_size
            end_idx = start_idx + self.batch_size
            
            if end_idx > len(training_data):
                self.current_epoch += 1
                self.current_batch = 0
                return None
                
            batch = training_data[start_idx:end_idx]
            self.current_batch += 1
            
            return self._preprocess_batch(batch)
            
        except Exception as e:
            logger.error(f"获取训练批次失败: {str(e)}")
            return None
    
    def _parallel_train_models(self, batch_data):
        """并行训练模型"""
        try:
            strategy = tf.distribute.MirroredStrategy()
            with strategy.scope():
                with ThreadPoolExecutor(max_workers=self.total_models) as executor:
                    futures = []
                    for i, model in enumerate(self.model_ensemble.models):
                        future = executor.submit(
                            self._train_single_model,
                            model,
                            batch_data,
                            i
                        )
                        futures.append(future)
                    
                    # 等待所有模型训练完成
                    for future in futures:
                        future.result()
                    
        except Exception as e:
            logger.error(f"并行训练失败: {str(e)}")
    
    def _train_single_model(self, model, batch_data, model_idx):
        """训练单个模型"""
        try:
            # 1. 前向传播
            with tf.GradientTape() as tape:
                predictions = model(batch_data['input'])
                loss = self._calculate_loss(predictions, batch_data['target'])
            
            # 2. 反向传播
            gradients = tape.gradient(loss, model.trainable_variables)
            self.model_ensemble.optimizer.apply_gradients(
                zip(gradients, model.trainable_variables)
            )
            
            # 3. 更新模型权重
            self._update_model_weights(model_idx, loss.numpy())
            
            # 4. 记录训练进度
            with self.training_lock:
                self.finished_models += 1
                if self.finished_models == self.total_models:
                    self.finished_models = 0
                    self._on_batch_complete()
                    
        except Exception as e:
            logger.error(f"训练模型 {model_idx} 失败: {str(e)}")
    
    def _calculate_loss(self, predictions, targets):
        """计算训练损失"""
        try:
            # 使用增强型匹配损失函数
            return self.model_ensemble.enhanced_match_loss(targets, predictions)
        except Exception as e:
            logger.error(f"计算损失失败: {str(e)}")
            return tf.constant(0.0)
    
    def _update_model_weights(self, model_idx, loss):
        """更新模型权重"""
        try:
            # 记录性能
            self.performance_history.append({
                'model_idx': model_idx,
                'loss': loss,
                'timestamp': datetime.now()
            })
            
            # 更新权重
            performance = np.exp(-loss)  # 损失越小,性能越好
            self.model_ensemble.weights[model_idx] = performance
            
            # 归一化权重
            total = np.sum(self.model_ensemble.weights)
            self.model_ensemble.weights /= total
            
        except Exception as e:
            logger.error(f"更新模型权重失败: {str(e)}")
    
    def _update_training_status(self):
        """更新训练状态"""
        try:
            # 计算训练进度
            total_steps = self.epochs * self.total_batches
            current_steps = self.current_epoch * self.total_batches + self.current_batch
            progress = current_steps / total_steps
            
            # 更新监控指标
            self.performance_monitor.update_metrics({
                'progress': progress,
                'current_epoch': self.current_epoch,
                'current_batch': self.current_batch,
                'loss': np.mean([p['loss'] for p in self.performance_history])
            })
            
        except Exception as e:
            logger.error(f"更新训练状态失败: {str(e)}")
    
    def _record_performance(self):
        """记录训练性能"""
        try:
            metrics = {
                'timestamp': datetime.now(),
                'epoch': self.current_epoch,
                'batch': self.current_batch,
                'loss': np.mean([p['loss'] for p in self.performance_history]),
                'resource_usage': {
                    'cpu': self.resource_monitor.cpu_usage,
                    'memory': self.resource_monitor.memory_usage,
                    'gpu': self.resource_monitor.gpu_usage
                }
            }
            
            self.performance_monitor.update_metrics(metrics)
            
        except Exception as e:
            logger.error(f"记录性能失败: {str(e)}")
    
    def _preprocess_batch(self, batch):
        """预处理训练批次"""
        try:
            return {
                'input': tf.convert_to_tensor(batch['input']),
                'target': tf.convert_to_tensor(batch['target'])
            }
        except Exception as e:
            logger.error(f"预处理批次失败: {str(e)}")
            return None
    
    def _on_batch_complete(self):
        """批次训练完成回调"""
        try:
            # 保存检查点
            if self.current_batch % self.config['save_frequency'] == 0:
                self._save_checkpoint()
            
            # 评估性能
            if self.current_batch % self.config['eval_frequency'] == 0:
                self._evaluate_performance()
                
            # 调整学习率
            if self.current_batch % self.config['lr_update_frequency'] == 0:
                self._adjust_learning_rate()
                
        except Exception as e:
            logger.error(f"批次完成处理失败: {str(e)}")
    
    def _save_checkpoint(self):
        """保存训练检查点"""
        try:
            checkpoint = {
                'epoch': self.current_epoch,
                'batch': self.current_batch,
                'model_states': [model.get_weights() for model in self.model_ensemble.models],
                'optimizer_state': self.model_ensemble.optimizer.get_weights(),
                'performance_history': list(self.performance_history)
            }
            
            save_path = f"checkpoints/checkpoint_e{self.current_epoch}_b{self.current_batch}.h5"
            tf.keras.models.save_model(checkpoint, save_path)
            logger.info(f"保存检查点: {save_path}")
            
        except Exception as e:
            logger.error(f"保存检查点失败: {str(e)}")
    
    def _evaluate_performance(self):
        """评估训练性能"""
        try:
            # 计算平均损失
            avg_loss = np.mean([p['loss'] for p in self.performance_history])
            
            # 计算性能改进
            if len(self.performance_history) > 1:
                prev_loss = self.performance_history[-2]['loss']
                improvement = (prev_loss - avg_loss) / prev_loss
                
                if improvement < self.config['min_improvement']:
                    logger.warning("性能改进不足")
                    
            logger.info(f"当前平均损失: {avg_loss:.4f}")
            
        except Exception as e:
            logger.error(f"评估性能失败: {str(e)}")
    
    def _adjust_learning_rate(self):
        """调整学习率"""
        try:
            if len(self.performance_history) < 2:
                return
                
            # 计算最近的性能变化
            recent_loss = np.mean([p['loss'] for p in list(self.performance_history)[-10:]])
            previous_loss = np.mean([p['loss'] for p in list(self.performance_history)[-20:-10]])
            
            # 根据性能变化调整学习率
            if recent_loss > previous_loss:
                new_lr = self.model_ensemble.optimizer.learning_rate * 0.8
                self.model_ensemble.optimizer.learning_rate.assign(new_lr)
                logger.info(f"降低学习率至: {new_lr:.6f}")
            
        except Exception as e:
            logger.error(f"调整学习率失败: {str(e)}")

    def get_training_status(self):
        """获取训练状态"""
        return {
            'is_training': self.is_training,
            'is_paused': self.pause_training,
            'current_epoch': self.current_epoch,
            'current_batch': self.current_batch,
            'total_epochs': self.epochs,
            'total_batches': self.total_batches,
            'progress': (self.current_epoch * self.total_batches + self.current_batch) / 
                       (self.epochs * self.total_batches)
        }

    def get_performance_metrics(self):
        """获取性能指标"""
        if not self.performance_history:
            return None
            
        recent_records = list(self.performance_history)[-100:]
        return {
            'average_loss': np.mean([r['loss'] for r in recent_records]),
            'min_loss': np.min([r['loss'] for r in recent_records]),
            'max_loss': np.max([r['loss'] for r in recent_records]),
            'loss_trend': self._calculate_trend([r['loss'] for r in recent_records])
        }

    def _calculate_trend(self, values):
        """计算趋势"""
        if len(values) < 2:
            return "INSUFFICIENT_DATA"
            
        # 使用简单线性回归
        x = np.arange(len(values))
        slope = np.polyfit(x, values, 1)[0]
        
        if slope < -0.01:
            return "IMPROVING"
        elif slope > 0.01:
            return "DEGRADING"
        else:
            return "STABLE"

    def _calculate_next_period(self, current_period):
        """计算下一期号"""
        try:
            date_part, num_part = current_period.split('-')
            current_date = datetime.strptime(date_part, "%Y%m%d")
            current_num = int(num_part)
            
            if current_num < 1440:
                return f"{date_part}-{current_num+1:04d}"
            else:
                next_date = current_date + timedelta(days=1)
                return f"{next_date.strftime('%Y%m%d')}-0001"
        except Exception as e:
            logger.error(f"计算下一期号失败: {str(e)}")
            return None

    def _get_last_processed_period(self):
        """获取最后处理的期号"""
        try:
            if os.path.exists(self.issue_file):
                with open(self.issue_file, 'r') as f:
                    return f.read().strip()
            return None
        except Exception as e:
            logger.error(f"读取期号文件失败: {str(e)}")
            return None

    def _save_processed_period(self, period):
        """保存处理完成的期号"""
        try:
            os.makedirs(os.path.dirname(self.issue_file), exist_ok=True)
            with open(self.issue_file, 'w') as f:
                f.write(period)
        except Exception as e:
            logger.error(f"保存期号文件失败: {str(e)}")

    def _generate_sequence_periods(self, start_period, length):
        """生成连续的期号序列"""
        periods = [start_period]
        current = start_period
        for _ in range(length-1):
            current = self._calculate_next_period(current)
            if not current:
                return None
            periods.append(current)
        return periods

    def _fetch_training_data(self, start_period):
        """获取训练数据（增加数据验证）"""
        try:
            # 生成需要获取的期号范围
            total_periods = self._generate_sequence_periods(start_period, 14400+2880)
            if not total_periods or len(total_periods) != 14400+2880:
                logger.error("期号生成不完整")
                return None
                
            # 获取数据
            query = """
                SELECT number, date_period FROM admin_tab
                WHERE date_period IN %s
                ORDER BY date_period ASC
            """
            data = data_manager.execute_query(query, (tuple(total_periods),))
            
            # 验证数据完整性
            if len(data) != 14400+2880:
                logger.error(f"数据不完整，预期{14400+2880}条，实际获取{len(data)}条")
                return None
                
            # 新增数据量打印
            logger.info(f"获取到训练数据 {len(data)} 条，时间范围: {data[0]['date_period']} 至 {data[-1]['date_period']}")
            
            return data
        except Exception as e:
            logger.error(f"获取训练数据失败: {str(e)}")
            return None

    def _validate_sequence(self, sequence):
        """验证训练序列有效性"""
        try:
            if sequence is None:
                return False
                
            # 检查是否为numpy数组
            if not isinstance(sequence, np.ndarray):
                logger.warning(f"序列类型错误: {type(sequence)}")
                return False
                
            # 检查数据形状
            if sequence.shape != (14400+2880, 5):
                logger.warning(f"无效序列形状: {sequence.shape}")
                return False
                
            # 检查数值范围
            if np.min(sequence) < -1 or np.max(sequence) > 1:
                logger.warning(f"数值范围异常: [{np.min(sequence):.2f}, {np.max(sequence):.2f}]")
                return False
                
            return True
        except Exception as e:
            logger.error(f"序列验证失败: {str(e)}")
            return False

    def training_loop(self):
        """训练主循环"""
        logger.info("开始训练循环")
        issue_file = "D:/JupyterWork/notebooks/period_number/issue.txt"
        
        # 添加损失值历史记录
        loss_history = []
        
        while True:
            try:
                # 1. 读取上一次训练的最后期号
                try:
                    with open(issue_file, 'r') as f:
                        last_issue = f.read().strip()
                    logger.info(f"读取到上次期号: {last_issue}")
                except Exception as e:
                    logger.error(f"读取期号文件失败: {str(e)}")
                    time.sleep(60)
                    continue
                
                # 2. 计算下一个目标期号
                next_target_issue = self._calculate_next_issue(last_issue)
                logger.info(f"下一目标期号: {next_target_issue}")
                
                # 3. 构建训练序列
                try:
                    sequence = self.data_manager.get_sequence(
                        start_issue=self._get_start_issue(next_target_issue),
                        end_issue=next_target_issue
                    )
                    if sequence is None:
                        logger.info(f"等待期号 {next_target_issue} 的数据...")
                        time.sleep(60)
                        continue
                        
                    logger.info(f"获取到序列数据，形状: {sequence.shape}")
                    
                except Exception as e:
                    logger.error(f"获取训练序列失败: {str(e)}")
                    time.sleep(60)
                    continue
                
                # 4. 训练所有模型
                input_data = sequence[:-2880]
                target_data = sequence[-2880:]
                
                # 添加batch维度
                input_data = np.expand_dims(input_data, axis=0)
                target_data = np.expand_dims(target_data, axis=0)
                
                # 记录每个模型的训练前损失值
                with self.model_ensemble.session.as_default():
                    with self.model_ensemble.graph.as_default():
                        for i, model in enumerate(self.model_ensemble.models):
                            try:
                                # 训练前评估
                                initial_loss = model.evaluate(input_data, target_data, verbose=0)
                                
                                # 训练
                                loss = model.train_on_batch(input_data, target_data)
                                
                                # 训练后评估
                                final_loss = model.evaluate(input_data, target_data, verbose=0)
                                
                                logger.info(f"模型 {i+1} 训练: 初始损失={initial_loss:.4f}, "
                                          f"最终损失={final_loss:.4f}, 变化={initial_loss-final_loss:.4f}")
                                
                                loss_history.append({
                                    'model': i+1,
                                    'issue': next_target_issue,
                                    'initial_loss': initial_loss,
                                    'final_loss': final_loss
                                })
                                
                            except Exception as e:
                                logger.error(f"模型 {i+1} 训练失败: {str(e)}")
                                continue
                
                # 5. 保存新的期号
                try:
                    with open(issue_file, 'w') as f:
                        f.write(next_target_issue)
                    logger.info(f"已更新期号为: {next_target_issue}")
                except Exception as e:
                    logger.error(f"更新期号文件失败: {str(e)}")
                
                # 6. 等待一段时间再继续
                time.sleep(10)
                
            except Exception as e:
                logger.error(f"训练循环出错: {str(e)}")
                time.sleep(60)
                continue
            
    def _get_start_issue(self, end_issue):
        """计算起始期号（往前推12天）"""
        date_str, _ = end_issue.split('-')
        end_date = datetime.strptime(date_str, "%Y%m%d")
        start_date = end_date - timedelta(days=12)
        return f"{start_date.strftime('%Y%m%d')}-0001"

    def _calculate_next_issue(self, current_issue):
        """计算下一个期号"""
        # 解析当前期号 (格式: YYYYMMDD-XXXX)
        date_str, period = current_issue.split('-')
        year = int(date_str[:4])
        month = int(date_str[4:6])
        day = int(date_str[6:8])
        period_num = int(period)
        
        # 计算下一期
        if period_num < 1440:
            # 同一天的下一期
            next_period = f"{period_num + 1:04d}"
            next_date = date_str
        else:
            # 下一天的第一期
            next_period = "0001"
            next_date = datetime(year, month, day) + timedelta(days=1)
            next_date = next_date.strftime("%Y%m%d")
            
        return f"{next_date}-{next_period}"
        
    def _get_training_sequence(self, target_issue):
        """获取训练序列"""
        try:
            # 计算起始期号（往前推12天）
            start_issue = self._get_start_issue(target_issue)
            
            # 从数据库获取序列
            sequence = self.data_manager.get_sequence(start_issue, target_issue)
            
            # 验证序列长度
            if sequence is not None and sequence.shape[0] != 14400 + 2880:
                logger.error(f"序列长度不正确: {sequence.shape[0]}, 应为 {14400 + 2880}")
                return None
                
            return sequence
            
        except Exception as e:
            logger.error(f"获取训练序列失败: {str(e)}")
            return None

    def ensure_training_thread(self):
        """确保训练线程运行"""
        if not self._training_loop_running:
            self._training_loop_running = True
            self._training_thread = threading.Thread(
                target=self.training_loop,
                name="TrainingThread",
                daemon=True
            )
            self._training_thread.start()
            logger.info("训练线程已启动")

    def _validate_training_effect(self, model, input_data, target_data, model_index):
        """验证训练效果"""
        try:
            # 添加batch维度
            input_data = np.expand_dims(input_data, axis=0)  # 转换为 (1, sequence_length, features)
            target_data = np.expand_dims(target_data, axis=0)  # 转换为 (1, target_length, features)
            
            # 训练前预测
            pred_before = model.predict(input_data)
            
            # 训练
            loss = model.train_on_batch(input_data, target_data)
            
            # 训练后预测
            pred_after = model.predict(input_data)
            
            # 计算预测变化
            pred_change = np.mean(np.abs(pred_after - pred_before))
            
            # 计算准确度变化
            acc_before = np.mean(np.abs(pred_before - target_data))
            acc_after = np.mean(np.abs(pred_after - target_data))
            
            logger.info(f"模型 {model_index} 训练效果:"
                       f"\n - 预测变化: {pred_change:.4f}"
                       f"\n - 准确度提升: {acc_before-acc_after:.4f}")
            
        except Exception as e:
            logger.error(f"验证训练效果失败: {str(e)}")

    def check_training_status(self):
        """检查训练状态"""
        try:
            # 初始化状态字典
            status = {
                'is_training': self._training_loop_running,
                'thread_alive': False,
                'models_compiled': False,
                'last_losses': [],
                'weights_initialized': []
            }
            
            # 检查训练线程状态
            if self._training_thread:
                status['thread_alive'] = self._training_thread.is_alive()
            
            # 检查模型编译状态
            if self.model_ensemble and hasattr(self.model_ensemble, 'models'):
                with self.model_ensemble.session.as_default():
                    with self.model_ensemble.graph.as_default():
                        models_compiled = []
                        weights_initialized = []
                        last_losses = []
                        
                        for model in self.model_ensemble.models:
                            # 检查模型是否已编译
                            has_optimizer = hasattr(model, 'optimizer')
                            models_compiled.append(has_optimizer)
                            
                            # 检查权重是否已初始化
                            if hasattr(model, 'get_weights'):
                                weights = model.get_weights()
                                weights_initialized.append(
                                    len(weights) > 0 and any(np.any(w != 0) for w in weights)
                                )
                            else:
                                weights_initialized.append(False)
                            
                            # 尝试获取最后的损失值
                            if hasattr(model, 'history') and model.history:
                                if model.history.history and 'loss' in model.history.history:
                                    last_losses.append(model.history.history['loss'][-1])
                                else:
                                    last_losses.append(None)
                            else:
                                last_losses.append(None)
                        
                        status['models_compiled'] = all(models_compiled)
                        status['weights_initialized'] = weights_initialized
                        status['last_losses'] = last_losses
            
            logger.info(f"训练状态检查完成: {status}")
            return status
            
        except Exception as e:
            logger.error(f"检查训练状态失败: {str(e)}")
            # 返回基本状态信息
            return {
                'is_training': False,
                'thread_alive': False,
                'models_compiled': False,
                'last_losses': [],
                'weights_initialized': [],
                'error': str(e)
            }

# 然后创建训练管理器
training_manager = TrainingManager(model_ensemble=model_core)

# 最后启动训练线程
if not hasattr(training_manager, '_training_loop_running'):
    training_manager.ensure_training_thread()


输入形状: (1, 14400, 5)
编码后形状: (1, 14400, 5)
编码示例(前3个时间步):
 [[0.4515651  1.7493052  0.00732224 1.769947   0.44020802]
 [1.4587454  0.9580512  0.7141876  1.2719253  0.6423419 ]
 [1.6264706  0.3802063  0.7045929  1.995584   0.6268969 ]]


In [11]:
# 重置训练管理器
training_manager = TrainingManager(model_ensemble=model_core)

# 检查训练状态
status = training_manager.check_training_status()
print("\n当前训练状态:")
print(f"训练循环运行: {status.get('is_training', False)}")
print(f"训练线程活跃: {status.get('thread_alive', False)}")
print(f"模型已编译: {status.get('models_compiled', False)}")

if 'error' in status:
    print(f"\n检查状态时出现错误: {status['error']}")

# 检查每个模型的状态
if status.get('weights_initialized'):
    print("\n模型权重状态:")
    for i, initialized in enumerate(status['weights_initialized']):
        print(f"模型 {i+1}: {'已初始化' if initialized else '未初始化'}")

# 检查损失值
if status.get('last_losses'):
    print("\n最近损失值:")
    for i, loss in enumerate(status['last_losses']):
        if loss is not None:
            print(f"模型 {i+1}: {loss:.4f}")
        else:
            print(f"模型 {i+1}: 无损失记录")

# 启动训练
print("\n启动训练线程...")
training_manager.ensure_training_thread()


当前训练状态:
训练循环运行: False
训练线程活跃: False
模型已编译: True

模型权重状态:
模型 1: 已初始化
模型 2: 已初始化
模型 3: 已初始化
模型 4: 已初始化
模型 5: 已初始化
模型 6: 已初始化

最近损失值:
模型 1: 无损失记录
模型 2: 无损失记录
模型 3: 无损失记录
模型 4: 无损失记录
模型 5: 无损失记录
模型 6: 无损失记录

启动训练线程...




  updates = self.state_updates
模型 1 训练失败: Incompatible shapes: [1,14402,5] vs. [1,2880,5]
	 [[{{node metrics/mean_absolute_error/sub}}]]
模型 2 训练失败: Incompatible shapes: [1,14402,5] vs. [1,2880,5]
	 [[{{node metrics_2/mean_absolute_error/sub}}]]
模型 3 训练失败: Incompatible shapes: [1,14402,5] vs. [1,2880,5]
	 [[{{node metrics_4/mean_absolute_error/sub}}]]
模型 4 训练失败: Incompatible shapes: [1,14402,5] vs. [1,2880,5]
	 [[{{node metrics_6/mean_absolute_error/sub}}]]
模型 5 训练失败: Incompatible shapes: [1,14402,5] vs. [1,2880,5]
	 [[{{node loss_4/dense_9_loss/SquaredDifference}}]]
模型 6 训练失败: Incompatible shapes: [1,14402,5] vs. [1,2880,5]
	 [[{{node metrics_10/mean_absolute_error/sub}}]]
