## Task Definition
- input:
- output:

## RL

#### 子功能

In [None]:
import gym
import numpy as np
from gym import spaces

class AuditLogEnv(gym.Env):
    """Gym environment for labeling the logs of a single audit session.

    The agent sees a sliding window of ``window_size`` logs starting at
    ``cursor`` and picks one of 20 discrete actions per step:

    * ``0``      -- skip (only the per-step cost applies).
    * ``1..10``  -- label a log in the window: ``(action - 1) // 2`` selects
      the window slot and ``(action - 1) % 2`` the label
      (0 = benign, 1 = malicious).
    * ``11..15`` -- search an entity from the window (``action - 11`` selects it).
    * ``16..18`` -- move the window cursor (see ``_handle_expand``).
    * ``19``     -- finish the episode and collect the session-level reward.

    Every step costs 0.01. The episode also ends after ``max_steps`` steps.
    The class uses the classic gym API: ``reset()`` returns only the
    observation and ``step()`` returns ``(obs, reward, done, info)``.

    Args:
        session_logs: Sliceable sequence of logs for one session. Each element
            must support ``log["id"]`` (read when labeling) -- presumably a
            dict; confirm against callers.
        max_steps: Maximum number of steps per episode; must be positive.

    Raises:
        ValueError: If ``max_steps`` is not positive.
    """

    def __init__(self, session_logs, max_steps=50):
        super().__init__()

        if max_steps <= 0:
            # _get_state divides by max_steps, so 0 would crash in reset().
            raise ValueError(f"max_steps must be positive, got {max_steps!r}")

        self.logs = session_logs
        self.max_steps = max_steps

        # ===== Action Space =====
        # 0 skip | 1-10 label | 11-15 search | 16-18 move cursor | 19 finish
        self.action_space = spaces.Discrete(20)

        # ===== Observation Space =====
        # Fixed-length state vector: 40 observation + 21 memory + 3 progress
        # features = 64 (see _get_state).
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(64,),
            dtype=np.float32,
        )

        self.reset()

    def reset(self):
        """Start a new episode and return its initial 64-dim float32 state."""
        self.cursor = 0        # index of the first log in the visible window
        self.window_size = 5   # label actions 1-10 address exactly 5 slots
        self.steps = 0

        self.labeled_logs = {}      # log id -> label (0 benign, 1 malicious)
        self.seen_entities = set()  # entities already searched this episode
        self.done = False

        return self._get_state()

    def step(self, action):
        """Apply one action and advance the episode by one step.

        Args:
            action: Integer in ``[0, 20)``; see the class docstring for the
                layout. Values matching no range behave like a skip.

        Returns:
            ``(state, reward, done, info)``; ``info`` is always ``{}``.
        """
        self.steps += 1
        reward = -0.01  # every step has a small cost

        if action == 0:
            pass  # SKIP
        elif 1 <= action <= 10:
            reward += self._handle_label(action)
        elif 11 <= action <= 15:
            reward += self._handle_search(action)
        elif 16 <= action <= 18:
            self._handle_expand(action)
        elif action == 19:
            reward += self._final_reward()
            self.done = True

        # Truncate the episode once the step budget is exhausted.
        if self.steps >= self.max_steps:
            self.done = True

        # NOTE(review): calling step() after done=True is not rejected.
        return self._get_state(), reward, self.done, {}

    # ===== Environment logic =====

    def _handle_label(self, action):
        """Label one log in the window; return the shaping reward.

        Returns -0.1 if the chosen slot is empty. The positive shaping reward
        (0.2 malicious / 0.05 benign) is only paid the first time a log is
        labeled. Relabeling updates the stored label but earns 0.0, so the
        agent cannot farm reward by repeating or toggling the same label.
        """
        idx, label = divmod(action - 1, 2)  # label: 0 benign, 1 malicious

        candidates = self._topk_logs()
        if idx >= len(candidates):
            return -0.1  # the window holds fewer logs than the chosen slot

        log_id = candidates[idx]["id"]
        first_time = log_id not in self.labeled_logs
        self.labeled_logs[log_id] = label

        if not first_time:
            return 0.0

        # Reward shaping (illustrative placeholder values).
        return 0.2 if label == 1 else 0.05

    def _handle_search(self, action):
        """Mark one window entity as searched; return the shaping reward.

        Returns -0.1 if the chosen slot is empty, 0.1 for a newly seen entity,
        and 0.0 for an entity already searched this episode.
        """
        idx = action - 11
        entities = self._topk_entities()
        if idx >= len(entities):
            return -0.1

        entity = entities[idx]
        if entity in self.seen_entities:
            return 0.0

        self.seen_entities.add(entity)
        return 0.1

    def _max_cursor(self):
        """Largest valid cursor; 0 when the session is shorter than the window."""
        return max(0, len(self.logs) - self.window_size)

    def _handle_expand(self, action):
        """Move the window cursor, keeping it within ``[0, _max_cursor()]``.

        16 moves back one log and 17 moves forward one log.
        """
        max_cursor = self._max_cursor()
        if action == 16:
            self.cursor = max(0, self.cursor - 1)
        elif action == 17:
            self.cursor = min(max_cursor, self.cursor + 1)
        elif action == 18:
            # NOTE(review): back-then-forward leaves the cursor unchanged
            # unless it is at 0 (then it advances by one if possible); the
            # intended "expand" semantics are unclear -- confirm.
            self.cursor = max(0, self.cursor - 1)
            self.cursor = min(max_cursor, self.cursor + 1)

    def _final_reward(self):
        """Session-level reward paid when the agent finishes (action 19).

        TODO: intended to become an F1 / consistency / LLM-judge score.
        For now it is 0.1 per distinct labeled log.
        """
        return len(self.labeled_logs) * 0.1

    def _get_state(self):
        """Build the 64-dim float32 state: 40 obs + 21 memory + 3 progress."""
        obs = self._extract_observation_features()
        mem = self._extract_memory_features()
        prog = np.array([
            self.steps / self.max_steps,  # fraction of the step budget used
            len(self.labeled_logs),
            len(self.seen_entities),
        ])

        state = np.concatenate([obs, mem, prog])
        return state.astype(np.float32)

    # ===== Feature Extractors =====

    def _extract_observation_features(self):
        # Placeholder: must return a fixed-length 40-dim vector so the state
        # stays 64-dim; currently all zeros.
        return np.zeros(40)

    def _extract_memory_features(self):
        # Placeholder: must return a fixed-length 21-dim vector; currently zeros.
        return np.zeros(21)

    def _topk_logs(self):
        """Return the logs in the current window (at most ``window_size``)."""
        return self.logs[self.cursor:self.cursor + self.window_size]

    def _topk_entities(self):
        # Placeholder: should extract entities from the current window.
        # Returning [] means every search action currently yields -0.1.
        return []


#### 主要流程 