In [None]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from models import EEGNetHybridNorm

class BCI2DCursorEnv(gym.Env):
    """2-D cursor-control environment driven by pre-recorded EEG segments.

    Each episode replays one randomly chosen segment of EEG trials. At every
    step the agent picks one of four unit moves on a ``grid_size`` x
    ``grid_size`` grid and is rewarded for approaching a target whose
    direction is derived from the label of the segment's last trial.

    Observation modes (``method``):
        'sl-only': the 4 class probabilities of the supervised (SL) decoder;
        'sl-rl':   probabilities (4) + cursor position (2) + target position (2);
        anything else ('hybrid'): EEG features (160) + probabilities (4)
                   + cursor position (2) + target position (2).

    An episode ends when the cursor comes within ``success_radius`` of the
    target, when the trial pointer reaches the last trial of the segment, or
    after ``max_steps`` steps. Because the pointer advances once per step, a
    segment of more than one trial yields at most ``segment_len - 1`` steps.
    """

    # Action index -> unit displacement (up, down, left, right).
    _MOVE_MAP = {
        0: np.array([0, 1]),
        1: np.array([0, -1]),
        2: np.array([-1, 0]),
        3: np.array([1, 0]),
    }

    def __init__(self, eeg_segments, labels, sl_model_path,
                 feature_npz_path=None, method='hybrid', success_radius=2.0,
                 noise_sigma=0.05):
        """Build the environment.

        Args:
            eeg_segments: array of EEG segments; the original author's note
                gives the layout as (N_segments, 6, 22, 1000). Only
                ``len()`` and ``shape[1]`` (trials per segment) are read here.
            labels: per-trial labels, indexed as ``labels[segment][trial]``.
            sl_model_path: path of the saved state dict for the SL decoder.
            feature_npz_path: optional ``.npz`` file with a ``'features'``
                array; only loaded when ``method == 'hybrid'``.
            method: observation mode, see the class docstring.
            success_radius: distance to the target below which the episode
                is counted as a success.
            noise_sigma: standard deviation of the Gaussian noise added to
                the SL probabilities; pass 0 to disable it (e.g. for
                evaluation).
        """
        super().__init__()

        # ====== Data ======
        self.eeg_segments = eeg_segments
        self.labels = labels
        self.num_segments = len(eeg_segments)
        self.segment_len = eeg_segments.shape[1]
        self.success_radius = float(success_radius)
        self.noise_sigma = float(noise_sigma)

        self.method = method
        self.grid_size = 20
        self.max_steps = 50

        # ====== Model loading ======
        self.sl_decoder = self._load_sl_decoder(sl_model_path)
        if method == 'hybrid' and feature_npz_path is not None:
            # Context manager closes the underlying zip file once the array
            # has been read into memory.
            with np.load(feature_npz_path) as archive:
                self.feature_bank = archive['features']
            print(f" Hybrid 特征载入完成: {self.feature_bank.shape}")
        else:
            self.feature_bank = None

        # ====== Spaces ======
        self.action_space = spaces.Discrete(4)
        # NOTE(review): the [-1, 1] bounds below do not cover the raw grid
        # coordinates (0 .. grid_size - 1) that _build_state appends. They are
        # left unchanged so that agents trained against this space still
        # load; confirm whether positions should be normalised instead.
        if method == 'sl-only':
            self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32)
        elif method == 'sl-rl':
            self.observation_space = spaces.Box(low=-1, high=1, shape=(8,), dtype=np.float32)
        else:
            # hybrid mode: EEG features (160) + probabilities (4) + positions (4)
            self.observation_space = spaces.Box(low=-1, high=1, shape=(168,), dtype=np.float32)

        # ====== Episode state (populated by reset) ======
        self.current_segment_idx = None
        self.current_segment = None
        self.segment_labels = None
        self.trial_ptr = 0
        self.cursor_pos = None
        self.target_pos = None
        self._last_distance = None
        self.steps = 0

    # ====================== Model loading ======================
    def _load_sl_decoder(self, path):
        """Load the SL decoder weights from ``path`` on CPU, in eval mode."""
        model = EEGNetHybridNorm(num_classes=4, num_channels=22, sample_length=1000)
        model.load_state_dict(torch.load(path, map_location='cpu'))
        model.eval()
        print(f" SL 模型加载完成: {path}")
        return model

    # ====================== Feature extraction ======================
    def _get_sl_probabilities(self, eeg_signal):
        """Return the decoder's softmax probabilities for one trial.

        Two leading singleton dimensions are added before the forward pass,
        and the result is perturbed with ``noise_sigma`` Gaussian noise and
        re-normalised.
        """
        with torch.no_grad():
            eeg_tensor = torch.FloatTensor(eeg_signal).unsqueeze(0).unsqueeze(0)
            logits = self.sl_decoder(eeg_tensor)
            probs = F.softmax(logits, dim=1).cpu().numpy().flatten()
        return self._apply_noise(probs, sigma=self.noise_sigma).astype(np.float32)

    def _get_eeg_features(self):
        """Return the stored feature vector of the current trial.

        The feature bank is indexed by the flat trial index
        ``segment_idx * segment_len + trial_ptr`` and truncated to 160
        entries. Without a feature bank a zero vector of length 160 is
        returned.
        """
        if self.feature_bank is None:
            return np.zeros(160, dtype=np.float32)

        global_idx = self.current_segment_idx * self.segment_len + self.trial_ptr
        feats = self.feature_bank[global_idx]
        if feats.shape[0] > 160:
            feats = feats[:160]
        return feats.astype(np.float32)

    # ====================== Observation ======================
    def _build_state(self, eeg_signal):
        """Assemble the float32 observation for the configured ``method``."""
        sl_probs = self._get_sl_probabilities(eeg_signal)
        if self.method == 'sl-only':
            return sl_probs
        if self.method == 'sl-rl':
            parts = [sl_probs, self.cursor_pos, self.target_pos]
        else:
            parts = [self._get_eeg_features(), sl_probs, self.cursor_pos, self.target_pos]
        return np.concatenate(parts).astype(np.float32)

    # ====================== Reset ======================
    def reset(self, seed=None, options=None):
        """Start a new episode on a randomly drawn segment.

        The segment is drawn from ``self.np_random`` so that passing ``seed``
        makes the episode reproducible.

        Returns:
            (observation, info) with an empty info dict.
        """
        super().reset(seed=seed)

        # Pick a random segment with the environment's seeded generator.
        self.current_segment_idx = int(self.np_random.integers(0, self.num_segments))
        self.current_segment = self.eeg_segments[self.current_segment_idx]
        self.segment_labels = self.labels[self.current_segment_idx]
        self.trial_ptr = 0
        self.steps = 0

        # Cursor starts at the grid centre; the target sits 5 cells away in
        # the direction given by the label of the segment's last trial.
        # NOTE(review): only labels 1..4 have an offset; any other value
        # (e.g. a 0-based class 0) places the target on the start position.
        # Confirm the label encoding used by the dataset.
        self.cursor_pos = np.array([self.grid_size // 2, self.grid_size // 2], dtype=np.float32)
        target_offsets = {1: [-5, 0], 2: [5, 0], 3: [0, -5], 4: [0, 5]}
        label = int(self.segment_labels[-1])
        offset = np.array(target_offsets.get(label, [0, 0]))
        self.target_pos = np.clip(self.cursor_pos + offset, 0, self.grid_size - 1)
        self._last_distance = np.linalg.norm(self.cursor_pos - self.target_pos)

        eeg_signal = self.current_segment[self.trial_ptr]
        state = self._build_state(eeg_signal)
        return state, {}

    # ====================== Step ======================
    def step(self, action):
        """Move the cursor, compute the reward and advance to the next trial.

        Args:
            action: 0 = up, 1 = down, 2 = left, 3 = right. Python ints,
                NumPy integer scalars and single-element arrays are accepted.

        Returns:
            (observation, reward, terminated, truncated, info); ``truncated``
            is always False and ``info`` is empty.

        Raises:
            KeyError: if ``action`` is not one of 0..3.
        """
        self.steps += 1

        # .item() unwraps 0-d / single-element arrays, which are not hashable
        # and therefore cannot be used as dictionary keys directly.
        move = self._MOVE_MAP[int(np.asarray(action).item())]
        self.cursor_pos = np.clip(self.cursor_pos + move, 0, self.grid_size - 1)

        # Reward
        reward, done = self._calculate_reward()

        # Advance within the EEG segment, staying on the last trial once there.
        if self.trial_ptr < self.segment_len - 1:
            self.trial_ptr += 1
        eeg_signal = self.current_segment[self.trial_ptr]

        next_state = self._build_state(eeg_signal)
        # NOTE(review): running out of trials / steps is reported through the
        # "terminated" flag rather than "truncated"; kept as-is for callers
        # that only inspect the first flag.
        if self.trial_ptr == self.segment_len - 1 or self.steps >= self.max_steps:
            done = True

        return next_state, reward, done, False, {}

    # ====================== Reward ======================
    def _calculate_reward(self):
        """Return ``(reward, done)`` for the cursor's current position.

        The reward is -1 per step plus 10 times the reduction in Euclidean
        distance to the target, plus 100 when the cursor is closer than
        ``success_radius`` (which also sets ``done``). On the integer grid
        the default radius of 2.0 accepts the same distances (0, 1, sqrt(2))
        as the previously hard-coded threshold of 1.5.
        """
        new_dist = np.linalg.norm(self.cursor_pos - self.target_pos)
        reward = -1 + (self._last_distance - new_dist) * 10

        done = bool(new_dist < self.success_radius)
        if done:
            reward += 100

        self._last_distance = new_dist
        return reward, done

    # ====================== Noise ======================
    def _apply_noise(self, probs, sigma=0.0):
        """Perturb ``probs`` with Gaussian noise and re-normalise to sum 1.

        With ``sigma == 0`` the input is only re-normalised. Noise is drawn
        from ``self.np_random`` so it follows the seed given to ``reset``.
        """
        if sigma == 0:
            return probs / np.sum(probs)
        noise = self.np_random.normal(0, sigma, size=probs.shape)
        noisy = np.clip(probs + noise, 0, 1)
        total = np.sum(noisy)
        if total <= 0:
            # Every entry was clipped to zero; fall back to the clean
            # distribution instead of dividing by zero.
            return probs / np.sum(probs)
        return noisy / total

    # ====================== Visualisation ======================
    def render(self):
        """Show a scatter plot of the target (red) and the cursor (blue)."""
        plt.figure(figsize=(5, 5))
        plt.scatter(*self.target_pos, c='r', label='Target')
        plt.scatter(*self.cursor_pos, c='b', label='Cursor')
        plt.xlim(0, self.grid_size)
        plt.ylim(0, self.grid_size)
        plt.legend()
        plt.title(f"Step: {self.steps}")
        plt.show()
