In [1]:
!pip install "stable-baselines3[extra]"
!pip install wandb
!wget http://www.aiotlab.org/teaching/oop/tetris/TetrisTCPserver_v0.6.jar
import os

if os.path.exists("TetrisTCPserver_v0.6.jar"):
    print("✅ 檔案複製成功")
else:
    print("❌ 檔案複製失敗")

Collecting shimmy~=1.1.0 (from shimmy[atari]~=1.1.0; extra == "extra"->stable-baselines3[extra])
  Downloading Shimmy-1.1.0-py3-none-any.whl.metadata (3.3 kB)
Collecting autorom~=0.6.1 (from autorom[accept-rom-license]~=0.6.1; extra == "extra"->stable-baselines3[extra])
  Downloading AutoROM-0.6.1-py3-none-any.whl.metadata (2.4 kB)
Collecting AutoROM.accept-rom-license (from autorom[accept-rom-license]~=0.6.1; extra == "extra"->stable-baselines3[extra])
  Downloading AutoROM.accept-rom-license-0.6.1.tar.gz (434 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m434.7/434.7 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting ale-py~=0.8.1 (from shimmy[atari]~=1.1.0; extra == "extra"->stable-baselines3[extra])
  Downloading ale_py-0.8.1-cp311-cp311-manylinux_2_17_x86_64.man

In [None]:
import numpy as np
import socket
import cv2
import matplotlib.pyplot as plt
import subprocess
import os
import shutil
import glob
import imageio
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack
from IPython.display import FileLink, display, Image
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback
import torch
# 使用 wandb 記錄訓練日誌
import os
import wandb
from kaggle_secrets import UserSecretsClient
from stable_baselines3.common.vec_env import DummyVecEnv

# 從 Kaggle Secrets 讀取 API Token
user_secrets = UserSecretsClient()
WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")

# 設定環境變數，模擬 login
os.environ["WANDB_API_KEY"] = WANDB_API_KEY

# login & init
wandb.login()
wandb.init(project="tetris-training", entity="t113598065-ntut-edu-tw")

log_path = "/kaggle/working/tetris_train_log.txt"

def write_log(message):
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(message + "\n")
    print(message)

import time

def wait_for_tetris_server(ip="127.0.0.1", port=10612, timeout=30):
    write_log("⏳ 等待 Tetris TCP server 啟動中...")
    start_time = time.time()
    while True:
        try:
            test_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            test_sock.settimeout(1.0)
            test_sock.connect((ip, port))
            test_sock.close()
            write_log("✅ Java TCP server 準備完成，連線成功")
            break
        except socket.error:
            if time.time() - start_time > timeout:
                raise TimeoutError("❌ 等待 Java TCP server 超時")
            time.sleep(0.5)

# 啟動 Java Tetris server
print("Java started")
subprocess.Popen(["java", "-jar", "TetrisTCPserver_v0.6.jar"])
write_log("✅ Java server started")
wait_for_tetris_server()

if torch.cuda.is_available():
    print("✅ PyTorch is using GPU:", torch.cuda.get_device_name(0))
else:
    print("❌ PyTorch is using CPU")
# ----------------------------
# 定義 Tetris 環境 (採用老師的格式)
class RewardLoggerCallback(BaseCallback):
    def __init__(self, log_path="./reward_log.txt", verbose=0):
        super().__init__(verbose)
        self.log_path = log_path
        self.episode_rewards = []

    def _on_step(self) -> bool:
        # 每次 env step 呼叫一次，但只有在 episode 結束時才記錄
        infos = self.locals.get("infos", [])
        for info in infos:
            if "episode" in info:  # VecEnv 會回傳 episode reward
                ep_reward = info["episode"]["r"]
                self.episode_rewards.append(ep_reward)
                with open(self.log_path, "a") as f:
                    f.write(f"{len(self.episode_rewards)},{ep_reward}\n")
                print(f"📈 Episode {len(self.episode_rewards)} Reward: {ep_reward}")
        return True

class TetrisEnv(gym.Env):
    metadata = {"render_modes": ["human"], "render_fps": 20}
    N_DISCRETE_ACTIONS = 5
    IMG_HEIGHT = 200
    IMG_WIDTH = 100
    IMG_CHANNELS = 3

    def __init__(self, host_ip="127.0.0.1", host_port=10612):
        super().__init__()
        self.action_space = spaces.Discrete(self.N_DISCRETE_ACTIONS)
        # self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84), dtype=np.uint8)
        self.observation_space = spaces.Box(low=0, high=255, shape=(1, 84, 84), dtype=np.uint8)
        self.server_ip = host_ip
        self.server_port = host_port

        self.client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.client_sock.connect((self.server_ip, self.server_port))

        # 初始化 reward shaping 與統計用變數
        self.lines_removed = 0
        self.height = 0
        self.holes = 0
        self.lifetime = 0

    def step(self, action):
        if action == 0:
            self.client_sock.sendall(b"move -1\n")
        elif action == 1:
            self.client_sock.sendall(b"move 1\n")
        elif action == 2:
            self.client_sock.sendall(b"rotate 0\n")
        elif action == 3:
            self.client_sock.sendall(b"rotate 1\n")
        elif action == 4:
            self.client_sock.sendall(b"drop\n")
    
        terminated, lines, height, holes, observation = self.get_tetris_server_response(self.client_sock)
    
        reward = 0
        if action == 4:
            reward += 5
    
        if height > self.height:
            reward -= (height - self.height) * 5
    
        if holes < self.holes:
            reward += (self.holes - holes) * 10
    
        if lines > self.lines_removed:
            reward += (lines - self.lines_removed) * 1000
            self.lines_removed = lines
    
        self.height = height
        self.holes = holes
        self.lifetime += 1
    
        info = {'removed_lines': self.lines_removed, 'lifetime': self.lifetime}
    
        truncated = False
    
        # 關鍵！處理終止觀察值
        if terminated:
            info['terminal_observation']  = observation.copy()  
    
        return observation, reward, terminated, truncated, info


    def reset(self, seed=None, options=None):
        self.client_sock.sendall(b"start\n")
        terminated, lines, height, holes, observation = self.get_tetris_server_response(self.client_sock)
        # 重置統計變數
        self.lines_removed = 0
        self.height = 0
        self.holes = 0
        self.lifetime = 0
        return observation, {}

    def render(self):
        cv2.imshow("Tetris", self.last_observation)
        cv2.waitKey(1)

    def close(self):
        self.client_sock.close()
        cv2.destroyAllWindows()

    def get_tetris_server_response(self, sock):
        is_game_over = (sock.recv(1) == b'\x01')
        removed_lines = int.from_bytes(sock.recv(4), 'big')
        height = int.from_bytes(sock.recv(4), 'big')
        holes = int.from_bytes(sock.recv(4), 'big')
        img_size = int.from_bytes(sock.recv(4), 'big')
        img_png = sock.recv(img_size)
        nparr = np.frombuffer(img_png, np.uint8)
        np_image = cv2.imdecode(nparr, -1)
        resized = cv2.resize(np_image, (84, 84))
        gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
        gray = np.expand_dims(gray, axis=0)  # <- 關鍵！channel-first
        self.last_observation = gray.copy()
        return is_game_over, removed_lines, height, holes, gray
    

        
# 檢查環境
print("✅ 建立環境開始")
env = TetrisEnv()
check_env(env)

# ----------------------------
# 建立訓練環境（使用向量化、多個 env）並加入正規化與 frame stacking
# 這部分主要用於加速並穩定訓練
# train_env = make_vec_env(TetrisEnv, n_envs=3)
# train_env = VecNormalize(train_env, norm_obs=False, norm_reward=True)
# train_env = VecFrameStack(train_env, n_stack=4, channels_order='first')
train_env = DummyVecEnv([lambda: TetrisEnv()])
train_env = VecFrameStack(train_env, n_stack=4, channels_order="first")

# ----------------------------
# 使用 DQN 進行訓練，調整超參數以提升效能：
# 這裡設定 buffer_size、learning_starts、target_update_interval 等參數
model = DQN("CnnPolicy", train_env, verbose=1, tensorboard_log="./sb3_log/",
            gamma=0.95,
            learning_rate=1e-4,         # 較低的學習率有助於穩定收斂
            buffer_size=20000,         # 經驗回放緩衝區大小
            learning_starts=1000,       # 多少步後開始學習
            policy_kwargs=dict(normalize_images=False),
            target_update_interval=1000 # 目標網路更新頻率
           )
write_log("Model device: " + str(model.device))
# model.learn(total_timesteps=100000)  # 可根據需要延長 timesteps1000000
reward_logger = RewardLoggerCallback(log_path="./reward_log.txt")
model.learn(total_timesteps=5000000, callback=reward_logger)

# 儲存訓練後的模型（訓練完畢後可先暫停 train_env 的歸一化更新）
train_env.training = False

# ----------------------------
# 包裝測試環境，但僅用來符合 predict 格式，取影像還是從原生環境拿
# wrapped_test_env = make_vec_env(TetrisEnv, n_envs=1)
# wrapped_test_env = VecNormalize(wrapped_test_env, norm_obs=False, norm_reward=False, training=False)
# wrapped_test_env = VecFrameStack(wrapped_test_env, n_stack=4, channels_order='first')
wrapped_test_env = DummyVecEnv([lambda: TetrisEnv()])
wrapped_test_env = VecFrameStack(wrapped_test_env, n_stack=4, channels_order="first")

# 原始環境保留用來取影像
raw_test_env = TetrisEnv()

# 初始化狀態
wrapped_obs = wrapped_test_env.reset()
raw_obs, _ = raw_test_env.reset()

frames = []
total_test_reward = 0
test_steps = 1000

for step in range(test_steps):
    action, _ = model.predict(wrapped_obs, deterministic=True)

    # 執行動作
    next_raw_obs, reward, done, truncated, info = raw_test_env.step(action)
    wrapped_obs, _, _, _ = wrapped_test_env.step(action)

    total_test_reward += reward
    frames.append(np.expand_dims(raw_obs.copy(), axis=0))
    # frames.append(raw_obs.copy())  # 儲存原始影像
    raw_obs = next_raw_obs

    if done:
        break

write_log("Test completed: Total reward = " + str(total_test_reward))

# 將回放影像存入資料夾（依老師格式）
replay_folder = './replay'
if os.path.exists(replay_folder):
    shutil.rmtree(replay_folder)
os.makedirs(replay_folder, exist_ok=True)
episode_folder = os.path.join(replay_folder, "0", "0")
os.makedirs(episode_folder, exist_ok=True)
for i, frame in enumerate(frames):
    fname = os.path.join(episode_folder, '{:06d}.png'.format(i))
    cv2.imwrite(fname,frame[0].squeeze())

# 產生 replay GIF（最佳回放）
filenames = sorted(glob.glob(episode_folder + '/*.png'))
gif_images = []
for filename in filenames:
    gif_images.append(imageio.imread(filename))
imageio.mimsave('replay.gif', gif_images, loop=0)
print("Replay GIF saved: replay.gif")
display(FileLink('replay.gif'))

# 將測試結果寫入 CSV（格式與老師版本一致）
with open('tetris_best_score_test2.csv', 'w') as fs:
    fs.write('id,removed_lines,played_steps\n')
    fs.write(f'0,{info["removed_lines"]},{info["lifetime"]}\n')
    fs.write(f'1,{info["removed_lines"]},{info["lifetime"]}\n')
print("CSV file saved: tetris_best_score_test2.csv")
display(FileLink('tetris_best_score_test2.csv'))
wandb.save('tetris_best_score_test2.csv')

# ----------------------------
# 儲存最終模型（請確認將 '113598065' 替換成你的學號）
model.save('113598065_dqn_30env_1M.zip')
print("Model saved: 113598065_dqn_30env_1M.zip")
display(FileLink('113598065_dqn_30env_1M.zip'))
wandb.save('113598065_dqn_30env_1M.zip')

# 關閉環境
wrapped_test_env.close()
raw_test_env.close()
train_env.close()



Java started
✅ Java server started
⏳ 等待 Tetris TCP server 啟動中...
✅ Java TCP server 準備完成，連線成功
✅ PyTorch is using GPU:Client has joined the game Tesla T4
Client has exited the game

✅ 建立環境開始
Client has exited the game
Client has joined the game
Address already in use (Bind failed)
Tetris TCP server is listening at 10612
Client has joined the game
Using cuda device
Model device: cuda
Logging to ./sb3_log/DQN_11
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.856    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 23       |
|    time_elapsed     | 6        |
|    total_timesteps  | 152      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.683    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 23       |
|    time_elapsed     | 14       |
|    total_timesteps  | 334 

In [None]:
import numpy as np
import socket
import cv2
import matplotlib.pyplot as plt
import subprocess
import os
import shutil
import glob
import imageio
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack, DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback
import torch
import time
import wandb

# --- Kaggle/Jupyter 特有導入 ---
try:
    from kaggle_secrets import UserSecretsClient
    # 如果在 Kaggle 環境，嘗試讀取 WandB API Key
    try:
        user_secrets = UserSecretsClient()
        WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")
        os.environ["WANDB_API_KEY"] = WANDB_API_KEY
        print("Trying to login WandB using Kaggle secrets...")
        wandb.login()
        wandb.init(project="tetris-training", entity="t113598065-ntut-edu-tw", sync_tensorboard=True)
        print("✅ WandB login successful via Kaggle secrets.")
    except Exception as e:
        print(f"❌ WandB setup via Kaggle secrets failed: {e}. Trying anonymous login.")
        # 如果 Kaggle Secrets 失敗，嘗試匿名
        wandb.login(anonymous="allow")
        wandb.init(project="tetris-training", anonymous="allow", sync_tensorboard=True)
        print("✅ WandB login successful anonymously.")
except ImportError:
    # 如果不在 Kaggle，嘗試直接登入 (可能需要手動 `wandb login` 或設置環境變數)
    print("Kaggle secrets not found. Trying standard WandB login...")
    try:
        # 檢查環境變數
        if "WANDB_API_KEY" in os.environ:
             wandb.login()
             wandb.init(project="tetris-training", entity="t113598065-ntut-edu-tw", sync_tensorboard=True)
             print("✅ WandB login successful via environment variable.")
        else:
            # 嘗試匿名登入作為後備
             wandb.login(anonymous="allow")
             wandb.init(project="tetris-training", anonymous="allow", sync_tensorboard=True)
             print("✅ WandB login successful anonymously (no API key found).")
    except Exception as e:
        print(f"❌ Standard WandB login failed: {e}. WandB logging disabled.")
        # 禁用 WandB，使腳本仍能運行
        wandb.init(mode="disabled")


# --- IPython Display Fallback ---
try:
    from IPython.display import display, FileLink
except ImportError:
    def display(dummy): pass
    def FileLink(path): return f"File available at: {path}"
    print("⚠️ IPython.display not found. Download links may not work as expected.")

# --- Logging Setup ---
log_path = "/kaggle/working/tetris_train_log.txt" # Kaggle 路徑
if not os.path.exists("/kaggle/working/"):
    log_path = "./tetris_train_log.txt" # 本地路徑

def write_log(message):
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {message}\n")
    print(message)

# --- Wait for Server ---
def wait_for_tetris_server(ip="127.0.0.1", port=10612, timeout=30):
    write_log("⏳ 等待 Tetris TCP server 啟動中...")
    start_time = time.time()
    while True:
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as test_sock:
                test_sock.settimeout(1.0)
                test_sock.connect((ip, port))
            write_log(f"✅ Java TCP server ({ip}:{port}) 準備完成，連線成功")
            break
        except socket.error as e:
            if time.time() - start_time > timeout:
                write_log(f"❌ 等待 Java TCP server 超時 ({timeout}s)")
                raise TimeoutError("❌ 等待 Java TCP server 超時")
            time.sleep(0.5)

# --- Reward Logger Callback (TXT format) ---
class RewardLoggerCallback(BaseCallback):
    def __init__(self, log_path: str, verbose: int = 0):
        super().__init__(verbose)
        self.log_path = log_path
        self.episode_rewards = []
        # Optional: Clear log file at the beginning of a new run
        # with open(self.log_path, "w", encoding="utf-8") as f:
        #     pass

    def _on_step(self) -> bool:
        infos = self.locals.get("infos", [])
        for i, info in enumerate(infos):
             if isinstance(self.env, DummyVecEnv) or isinstance(self.env, VecFrameStack) or isinstance(self.env, VecNormalize):
                if self.locals["dones"][i]:
                   if "episode" in info:
                       ep_reward = info["episode"]["r"]
                       ep_len = info["episode"]["l"]
                       ep_num = len(self.episode_rewards) + 1
                       self.episode_rewards.append(ep_reward)
                       with open(self.log_path, "a", encoding="utf-8") as f:
                           f.write(f"{ep_num},{ep_reward}\n")
                       # print(f"📈 Episode {ep_num} finished. Reward: {ep_reward:.2f}, Length: {ep_len}") # Optional print
                       if wandb.run and wandb.run.mode != "disabled":
                           wandb.log({"train/episode_reward": ep_reward, "train/episode_length": ep_len}, step=self.num_timesteps)

             elif self.locals["dones"]:
                 if "episode" in info:
                       ep_reward = info["episode"]["r"]
                       ep_len = info["episode"]["l"]
                       ep_num = len(self.episode_rewards) + 1
                       self.episode_rewards.append(ep_reward)
                       with open(self.log_path, "a", encoding="utf-8") as f:
                            f.write(f"{ep_num},{ep_reward}\n")
                       # print(f"📈 Episode {ep_num} finished. Reward: {ep_reward:.2f}, Length: {ep_len}") # Optional print
                       if wandb.run and wandb.run.mode != "disabled":
                           wandb.log({"train/episode_reward": ep_reward, "train/episode_length": ep_len}, step=self.num_timesteps)
        return True

# --- Tetris 環境 (基於假設實現) ---
class TetrisEnv(gym.Env):
    """
    Tetris 環境，通過 TCP 與外部 Java 伺服器通信。
    **警告:** 內部實現基於對伺服器協議的假設。
    """
    metadata = {'render_modes': ['human', 'ansi'], 'render_fps': 4}

    def __init__(self, ip="127.0.0.1", port=10612, render_mode=None):
        super().__init__()

        self.server_ip = ip
        self.server_port = port
        self.sock = None
        self.buffer = ""
        self.render_mode = render_mode
        self.current_observation = None

        # **假設:** 狀態是 14 維向量 (10列高 + 4特徵)
        self.state_dim = 14
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(self.state_dim,), dtype=np.float32)

        # **假設:** 動作是 40 種 (4 旋轉 * 10 X座標)
        self.num_actions = 40
        self.action_space = spaces.Discrete(self.num_actions)

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        # write_log(f"TetrisEnv initialized. Observation space: {self.observation_space}, Action space: {self.action_space}") # Debug

    def _connect_to_server(self):
        if self.sock:
            self.close()
        try:
            self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.sock.settimeout(10.0)
            self.sock.connect((self.server_ip, self.server_port))
            # write_log(f"🔗 Connected to Tetris server at {self.server_ip}:{self.server_port}") # Debug
        except socket.error as e:
            write_log(f"❌ Failed to connect to Tetris server: {e}")
            self.sock = None
            raise ConnectionError(f"Failed to connect to Tetris server: {e}")

    def _send_command(self, command):
        if not self.sock:
            raise ConnectionError("Not connected to server.")
        try:
            full_command = command + "\n"
            self.sock.sendall(full_command.encode('utf-8'))
        except socket.error as e:
            write_log(f"❌ Error sending command '{command}': {e}")
            self.close()
            raise ConnectionError(f"Error sending command: {e}")

    def _receive_data(self, buffer_size=4096):
        if not self.sock:
             raise ConnectionError("Not connected to server.")
        while "\n" not in self.buffer:
            try:
                chunk = self.sock.recv(buffer_size)
                if not chunk:
                    self.close()
                    raise ConnectionAbortedError("Connection closed by server.")
                self.buffer += chunk.decode('utf-8')
            except socket.timeout:
                self.close()
                raise TimeoutError("Socket timeout.")
            except socket.error as e:
                self.close()
                raise ConnectionError(f"Socket error: {e}")
            except UnicodeDecodeError as e:
                 write_log(f"❌ Error decoding server message: {e}. Received bytes: {chunk}")
                 raise ValueError(f"Error decoding server message: {e}")
        line_end = self.buffer.find("\n")
        data_line = self.buffer[:line_end]
        self.buffer = self.buffer[line_end + 1:]
        return data_line

    def _parse_server_response(self, data_line):
        """
        **假設:** 解析格式為 "f1,f2,...,f14,reward,terminated_flag" 的字符串。
        **假設:** 返回的 info 包含評估所需的 'lines_cleared_this_step', 'removed_lines', 'lifetime'。
                  (這裡的實現只是一個基本框架，實際 info 需要從伺服器或 Env 邏輯填充)
        """
        try:
            parts = data_line.strip().split(',')
            expected_parts = self.state_dim + 2 # state + reward + terminated_flag
            if len(parts) != expected_parts:
                write_log(f"❌ Unexpected server response format. Expected {expected_parts} parts, got {len(parts)}: '{data_line}'")
                # 返回錯誤狀態
                obs = np.zeros(self.observation_space.shape, dtype=self.observation_space.dtype)
                return obs, 0.0, True, False, {"error": "Invalid server response format"}

            observation = np.array(parts[:self.state_dim], dtype=np.float32)
            reward = float(parts[self.state_dim])
            terminated_flag = int(parts[self.state_dim + 1])
            terminated = (terminated_flag == 1)
            truncated = False # TCP 環境通常不因時間截斷

            # **關鍵假設**: info 字典包含評估所需的鍵。
            # 這裡只是返回一個空的或預設的 info，實際的填充需要在 step/reset 或從伺服器獲取。
            # 在評估循環中，我們需要這些鍵被正確填充。
            info = {
                "lines_cleared_this_step": 0, # 需要伺服器或 Env 邏輯提供
                "removed_lines": 0,       # 需要伺服器或 Env 邏輯提供 (回合結束時)
                "lifetime": 0             # 需要伺服器或 Env 邏輯提供 (回合結束時)
            }

            # 基本檢查觀察值是否在空間內 (可選，但有助於調試)
            if not self.observation_space.contains(observation):
                 write_log(f"⚠️ Warning: Observation {observation} out of defined space {self.observation_space}. Clipping or check server data.")
                 # observation = np.clip(observation, self.observation_space.low, self.observation_space.high) # 可選的裁剪

            return observation, reward, terminated, truncated, info

        except Exception as e:
            write_log(f"❌ Error parsing server response '{data_line}': {e}")
            obs = np.zeros(self.observation_space.shape, dtype=self.observation_space.dtype)
            return obs, 0.0, True, False, {"error": f"Parsing error: {e}"}

    def _map_action_to_command(self, action):
        """
        **假設:** 將動作索引 (0-39) 映射到 "PLACE <rot> <x>" 指令。
        """
        if not 0 <= action < self.num_actions:
            write_log(f"❌ Invalid action received: {action}. Must be between 0 and {self.num_actions - 1}.")
            action = 0 # 使用一個安全的默認動作

        num_x_positions = 10 # Assumed based on action space size
        rotation_index = action // num_x_positions
        x_position = action % num_x_positions
        command = f"PLACE {rotation_index} {x_position}"
        return command

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        try:
            self._connect_to_server()
            self._send_command("RESET")
            response = self._receive_data()
            observation, _, _, _, info = self._parse_server_response(response) # 解析初始狀態
            self.current_observation = observation
            # write_log("🔄 Environment reset.") # Debug
            if self.render_mode == "human": self.render()
            # 確保返回 info 字典
            if info is None: info = {}
            return observation, info
        except (ConnectionError, TimeoutError, ValueError, ConnectionAbortedError) as e:
             write_log(f"❌ Error during reset: {e}. Returning zero observation.")
             obs = np.zeros(self.observation_space.shape, dtype=self.observation_space.dtype)
             info = {"error": f"Reset error: {e}"}
             self.current_observation = obs
             return obs, info

    def step(self, action):
        if self.sock is None:
            write_log("❌ Attempting to step with no connection. Reset required.")
            obs = np.zeros(self.observation_space.shape, dtype=self.observation_space.dtype)
            # 返回 terminated=True 讓 SB3 結束回合
            return obs, 0.0, True, False, {"error": "Connection lost during step"}
        try:
            command = self._map_action_to_command(action)
            self._send_command(command)
            response = self._receive_data()
            observation, reward, terminated, truncated, info = self._parse_server_response(response)
            self.current_observation = observation
            if self.render_mode == "human": self.render()
            # 確保 info 是字典
            if info is None: info = {}

            # **關鍵假設實現**: 如果回合結束，填充最終評估指標 (這部分邏輯可能需要在伺服器端或更複雜的 Env 邏輯中完成)
            if terminated:
                # 這裡我們只是示例性地設置，實際值需要從伺服器獲取或由 Env 計算
                info['removed_lines'] = info.get('final_removed_lines', 0) # 假設伺服器返回了 'final_removed_lines'
                info['lifetime'] = info.get('final_lifetime', 0) # 假設伺服器返回了 'final_lifetime'

            return observation, float(reward), bool(terminated), bool(truncated), info
        except (ConnectionError, TimeoutError, ValueError, ConnectionAbortedError) as e:
             write_log(f"❌ Error during step: {e}. Terminating episode.")
             obs = self.current_observation if self.current_observation is not None else np.zeros(self.observation_space.shape, dtype=self.observation_space.dtype)
             return obs, 0.0, True, False, {"error": f"Step error: {e}"}

    def render(self):
        if self.render_mode == "ansi":
            print(f"State: {self.current_observation}")
        elif self.render_mode == "human":
             # 這裡可以添加基於 self.current_observation 的視覺化代碼 (例如用 cv2 或 matplotlib)
             # 由於不知道狀態向量的確切含義，這裡只打印
             print(f"State (render_mode=human): {self.current_observation}")
             pass

    def close(self):
        if self.sock:
            try:
                self.sock.shutdown(socket.SHUT_RDWR)
                self.sock.close()
                # write_log("🔌 Socket connection closed.") # Debug
            except socket.error:
                pass # Ignore errors during close
            finally:
                self.sock = None
        if self.render_mode == "human":
             try:
                 cv2.destroyAllWindows()
             except NameError: # cv2 可能未導入或未使用
                 pass


# --- 主程式 ---
if __name__ == "__main__":
    # --- Constants ---
    SERVER_IP = "127.0.0.1"
    SERVER_PORT = 10612
    STUDENT_ID = "113598065" # 你的學號
    ENV_COUNT_STR = "30env" # 檔名中的環境數量標識 (即使實際只用1個)
    STEPS_STR = "1M"      # 檔名中的步數標識
    TOTAL_TRAIN_TIMESTEPS = 1_000_000 # 實際訓練步數

    # --- Working Directory (Kaggle vs Local) ---
    WORK_DIR = "/kaggle/working/"
    if not os.path.exists(WORK_DIR):
        WORK_DIR = "./" # Use current directory if /kaggle/working/ doesn't exist

    MODEL_FILENAME_BASE = f"{STUDENT_ID}_dqn_{ENV_COUNT_STR}_{STEPS_STR}"
    MODEL_SAVE_PATH_ZIP = os.path.join(WORK_DIR, f"{MODEL_FILENAME_BASE}.zip")
    CSV_FILENAME = 'tetris_best_score_test2.csv'
    CSV_SAVE_PATH = os.path.join(WORK_DIR, CSV_FILENAME)
    REWARD_LOG_FILE_PATH = os.path.join(WORK_DIR, "reward_log.txt")
    TENSORBOARD_LOG_PATH = os.path.join(WORK_DIR, "tensorboard_logs/")
    JAVA_SERVER_JAR = "TetrisTCPserver_v0.6.jar" # 確保此文件存在

    # --- 1. 啟動 Java Server ---
    write_log("--- Starting Java Server ---")
    java_process = None
    try:
        # 使用 Popen 在後台啟動，並保存進程對象以便後續可能需要終止
        java_process = subprocess.Popen(["java", "-jar", JAVA_SERVER_JAR], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        write_log(f"✅ Java server process started (PID: {java_process.pid}). Waiting for connection...")
        wait_for_tetris_server(SERVER_IP, SERVER_PORT)
    except FileNotFoundError:
        write_log(f"❌ Error: {JAVA_SERVER_JAR} not found. Please place it in the script's directory or provide the correct path.")
        exit()
    except Exception as e:
        write_log(f"❌ Error starting or waiting for Java server: {e}")
        if java_process: java_process.terminate() # 嘗試終止進程
        exit()

    # --- 2. 環境檢查與創建 ---
    write_log("--- Setting up Environment ---")
    env_instance_for_check = None
    train_env = None
    try:
        env_instance_for_check = TetrisEnv(ip=SERVER_IP, port=SERVER_PORT)
        check_env(env_instance_for_check, warn=True)
        write_log("✅ Environment check passed (or only warnings).")
    except Exception as e:
        write_log(f"❌ Environment check failed: {e}")
        if env_instance_for_check: env_instance_for_check.close()
        if java_process: java_process.terminate()
        exit()
    finally:
         if env_instance_for_check: env_instance_for_check.close() # 關閉檢查用的實例

    # 創建用於訓練的環境 (使用 DummyVecEnv)
    train_env = DummyVecEnv([lambda: TetrisEnv(ip=SERVER_IP, port=SERVER_PORT)])

    # --- 3. 定義模型 ---
    write_log("--- Defining DQN Model ---")
    policy_type = "MlpPolicy" # Based on assumed vector observation space
    os.makedirs(TENSORBOARD_LOG_PATH, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    write_log(f"Using device: {device}")

    model = DQN(
        policy_type,
        train_env,
        verbose=1,
        tensorboard_log=TENSORBOARD_LOG_PATH,
        learning_rate=1e-4,
        buffer_size=50000,
        learning_starts=1000,
        batch_size=64,
        tau=1.0,
        gamma=0.99,
        train_freq=4,
        gradient_steps=1,
        target_update_interval=1000,
        exploration_fraction=0.1,
        exploration_final_eps=0.05,
        device=device
    )

    # --- 4. 定義回調函數 ---
    reward_logger = RewardLoggerCallback(log_path=REWARD_LOG_FILE_PATH)
    callbacks = [reward_logger]
    if wandb.run and wandb.run.mode != "disabled":
        from wandb.integration.sb3 import WandbCallback
        wandb_callback = WandbCallback(log="all", verbose=2) # 簡化 WandbCallback
        callbacks.append(wandb_callback)

    # --- 5. 訓練模型 ---
    write_log(f"--- Starting Training for {TOTAL_TRAIN_TIMESTEPS} timesteps ---")
    training_successful = False
    try:
        model.learn(
            total_timesteps=TOTAL_TRAIN_TIMESTEPS,
            log_interval=10, # Logs episodes count, time, fps to console/tensorboard
            callback=callbacks
        )
        training_successful = True
        write_log("✅ Training completed successfully.")
    except KeyboardInterrupt:
        write_log("🛑 Training interrupted by user.")
    except Exception as train_e:
        write_log(f"❌ Training failed: {train_e}")
        error_save_path = os.path.join(WORK_DIR, f"{MODEL_FILENAME_BASE}_error.zip")
        try: model.save(error_save_path); write_log(f"💾 Intermediate model saved to {error_save_path}")
        except: write_log("❌ Failed to save intermediate model after error.")
    finally:
        # --- 保存最終模型 ---
        if training_successful:
            try:
                model.save(MODEL_SAVE_PATH_ZIP)
                write_log(f"💾 Final model saved: {MODEL_SAVE_PATH_ZIP}")
                display(FileLink(MODEL_SAVE_PATH_ZIP))
                if wandb.run and wandb.run.mode != "disabled":
                    wandb.save(MODEL_SAVE_PATH_ZIP, base_path=WORK_DIR)
                    write_log("⬆️ Final model uploaded to WandB.")
            except Exception as final_save_e:
                write_log(f"❌ Failed to save final model: {final_save_e}")
        # 關閉訓練環境
        if train_env: train_env.close()
        write_log("✅ Training environment closed.")

    # --- 6. 評估模型並保存 CSV ---
    if training_successful:
        write_log("\n--- Starting Model Evaluation ---")
        eval_env = None
        try:
            write_log(f"🔄 Loading model from: {MODEL_SAVE_PATH_ZIP}")
            # 確保提供環境以加載觀察/動作空間信息
            eval_model = DQN.load(MODEL_SAVE_PATH_ZIP, device=device, env=None) # env=None 需要模型zip包含環境信息

            # 創建評估環境
            eval_env = TetrisEnv(ip=SERVER_IP, port=SERVER_PORT)
            obs, info = eval_env.reset()

            terminated = False
            truncated = False
            played_steps = 0
            removed_lines_total = 0
            # 假設運行一個完整的回合進行評估
            while not terminated and not truncated:
                action, _ = eval_model.predict(obs, deterministic=True)
                obs, reward, terminated, truncated, info = eval_env.step(action)
                played_steps += 1
                # **假設**: info 包含 'lines_cleared_this_step'
                removed_lines_total += info.get('lines_cleared_this_step', 0)

            # 回合結束後，**假設** info 包含最終統計
            final_removed_lines = info.get('removed_lines', removed_lines_total) # 優先使用 info 中的最終值
            final_lifetime = info.get('lifetime', played_steps) # 優先使用 info 中的最終值
            write_log(f"🏁 Evaluation Episode finished. Steps: {final_lifetime}, Removed Lines: {final_removed_lines}")

            # --- 寫入 CSV 文件 ---
            write_log(f"💾 Writing evaluation results to: {CSV_SAVE_PATH}")
            with open(CSV_SAVE_PATH, 'w') as fs:
                fs.write('id,removed_lines,played_steps\n') # 使用你指定的表頭
                fs.write(f'0,{final_removed_lines},{final_lifetime}\n') # id=0
                fs.write(f'1,{final_removed_lines},{final_lifetime}\n') # id=1 (數據相同)
            write_log("✅ CSV file saved.")
            display(FileLink(CSV_SAVE_PATH))
            if wandb.run and wandb.run.mode != "disabled":
                wandb.save(CSV_SAVE_PATH, base_path=WORK_DIR)
                write_log("⬆️ CSV results uploaded to WandB.")

        except FileNotFoundError:
             write_log(f"❌ Evaluation failed: Model file not found at {MODEL_SAVE_PATH_ZIP}")
        except KeyError as e:
             write_log(f"❌ Evaluation failed: Missing key {e} in environment's info dictionary during evaluation.")
             write_log("   Ensure TetrisEnv step() returns info with 'removed_lines' and 'lifetime' on termination.")
        except Exception as eval_e:
             write_log(f"❌ An error occurred during evaluation or CSV saving: {eval_e}")
        finally:
            if eval_env: eval_env.close() # 確保關閉評估環境
            write_log("✅ Evaluation environment closed.")

    # --- 7. 清理工作 ---
    write_log("--- Cleaning up ---")
    if java_process:
        try:
            java_process.terminate() # 嘗試終止 Java 進程
            java_process.wait(timeout=5) # 等待一小段時間
            write_log("☕ Java server process terminated.")
        except subprocess.TimeoutExpired:
            java_process.kill() # 如果無法終止，強制殺死
            write_log("🔪 Java server process killed.")
        except Exception as e:
            write_log(f"⚠️ Error terminating Java process: {e}")

    if wandb.run and wandb.run.mode != "disabled":
         wandb.finish() # 結束 WandB run
         write_log("📊 WandB run finished.")

    write_log("🏁 Full script finished.")