In [1]:
!pip install "stable-baselines3[extra]"
!pip install wandb
!wget http://www.aiotlab.org/teaching/oop/tetris/TetrisTCPserver_v0.6.jar
import os

if os.path.exists("TetrisTCPserver_v0.6.jar"):
    print("✅ 檔案複製成功")
else:
    print("❌ 檔案複製失敗")

Collecting shimmy~=1.1.0 (from shimmy[atari]~=1.1.0; extra == "extra"->stable-baselines3[extra])
  Downloading Shimmy-1.1.0-py3-none-any.whl.metadata (3.3 kB)
Collecting autorom~=0.6.1 (from autorom[accept-rom-license]~=0.6.1; extra == "extra"->stable-baselines3[extra])
  Downloading AutoROM-0.6.1-py3-none-any.whl.metadata (2.4 kB)
Collecting AutoROM.accept-rom-license (from autorom[accept-rom-license]~=0.6.1; extra == "extra"->stable-baselines3[extra])
  Downloading AutoROM.accept-rom-license-0.6.1.tar.gz (434 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m434.7/434.7 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting ale-py~=0.8.1 (from shimmy[atari]~=1.1.0; extra == "extra"->stable-baselines3[extra])
  Downloading ale_py-0.8.1-cp311-cp311-manylinux_2_17_x86_64.man

In [None]:
import numpy as np
import socket
import cv2
import matplotlib.pyplot as plt
import subprocess
import os
import shutil
import glob
import imageio
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack
from IPython.display import FileLink, display, Image
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback
import torch
# 使用 wandb 記錄訓練日誌
import os
import wandb
from kaggle_secrets import UserSecretsClient
from stable_baselines3.common.vec_env import DummyVecEnv

# 從 Kaggle Secrets 讀取 API Token
user_secrets = UserSecretsClient()
WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")

# 設定環境變數，模擬 login
os.environ["WANDB_API_KEY"] = WANDB_API_KEY

# login & init
wandb.login()
wandb.init(project="tetris-training", entity="t113598065-ntut-edu-tw")

log_path = "/kaggle/working/tetris_train_log.txt"

def write_log(message):
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(message + "\n")
    print(message)

import time

def wait_for_tetris_server(ip="127.0.0.1", port=10612, timeout=30):
    write_log("⏳ 等待 Tetris TCP server 啟動中...")
    start_time = time.time()
    while True:
        try:
            test_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            test_sock.settimeout(1.0)
            test_sock.connect((ip, port))
            test_sock.close()
            write_log("✅ Java TCP server 準備完成，連線成功")
            break
        except socket.error:
            if time.time() - start_time > timeout:
                raise TimeoutError("❌ 等待 Java TCP server 超時")
            time.sleep(0.5)

# 啟動 Java Tetris server
print("Java started")
subprocess.Popen(["java", "-jar", "TetrisTCPserver_v0.6.jar"])
write_log("✅ Java server started")
wait_for_tetris_server()

if torch.cuda.is_available():
    print("✅ PyTorch is using GPU:", torch.cuda.get_device_name(0))
else:
    print("❌ PyTorch is using CPU")
# ----------------------------
# 定義 Tetris 環境 (採用老師的格式)
class RewardLoggerCallback(BaseCallback):
    def __init__(self, log_path="./reward_log.txt", verbose=0):
        super().__init__(verbose)
        self.log_path = log_path
        self.episode_rewards = []

    def _on_step(self) -> bool:
        # 每次 env step 呼叫一次，但只有在 episode 結束時才記錄
        infos = self.locals.get("infos", [])
        for info in infos:
            if "episode" in info:  # VecEnv 會回傳 episode reward
                ep_reward = info["episode"]["r"]
                self.episode_rewards.append(ep_reward)
                with open(self.log_path, "a") as f:
                    f.write(f"{len(self.episode_rewards)},{ep_reward}\n")
                print(f"📈 Episode {len(self.episode_rewards)} Reward: {ep_reward}")
        return True

class TetrisEnv(gym.Env):
    metadata = {"render_modes": ["human"], "render_fps": 20}
    N_DISCRETE_ACTIONS = 5
    IMG_HEIGHT = 200
    IMG_WIDTH = 100
    IMG_CHANNELS = 3

    def __init__(self, host_ip="127.0.0.1", host_port=10612):
        super().__init__()
        self.action_space = spaces.Discrete(self.N_DISCRETE_ACTIONS)
        # self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84), dtype=np.uint8)
        self.observation_space = spaces.Box(low=0, high=255, shape=(1, 84, 84), dtype=np.uint8)
        self.server_ip = host_ip
        self.server_port = host_port

        self.client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.client_sock.connect((self.server_ip, self.server_port))

        # 初始化 reward shaping 與統計用變數
        self.lines_removed = 0
        self.height = 0
        self.holes = 0
        self.lifetime = 0

    def step(self, action):
        if action == 0:
            self.client_sock.sendall(b"move -1\n")
        elif action == 1:
            self.client_sock.sendall(b"move 1\n")
        elif action == 2:
            self.client_sock.sendall(b"rotate 0\n")
        elif action == 3:
            self.client_sock.sendall(b"rotate 1\n")
        elif action == 4:
            self.client_sock.sendall(b"drop\n")
    
        terminated, lines, height, holes, observation = self.get_tetris_server_response(self.client_sock)
    
        reward = 0
        if action == 4:
            reward += 5
    
        if height > self.height:
            reward -= (height - self.height) * 5
    
        if holes < self.holes:
            reward += (self.holes - holes) * 10
    
        if lines > self.lines_removed:
            reward += (lines - self.lines_removed) * 1000
            self.lines_removed = lines
    
        self.height = height
        self.holes = holes
        self.lifetime += 1
    
        info = {'removed_lines': self.lines_removed, 'lifetime': self.lifetime}
    
        truncated = False
    
        # 關鍵！處理終止觀察值
        if terminated:
            info['terminal_observation']  = observation.copy()  
    
        return observation, reward, terminated, truncated, info


    def reset(self, seed=None, options=None):
        self.client_sock.sendall(b"start\n")
        terminated, lines, height, holes, observation = self.get_tetris_server_response(self.client_sock)
        # 重置統計變數
        self.lines_removed = 0
        self.height = 0
        self.holes = 0
        self.lifetime = 0
        return observation, {}

    def render(self):
        cv2.imshow("Tetris", self.last_observation)
        cv2.waitKey(1)

    def close(self):
        self.client_sock.close()
        cv2.destroyAllWindows()

    def get_tetris_server_response(self, sock):
        is_game_over = (sock.recv(1) == b'\x01')
        removed_lines = int.from_bytes(sock.recv(4), 'big')
        height = int.from_bytes(sock.recv(4), 'big')
        holes = int.from_bytes(sock.recv(4), 'big')
        img_size = int.from_bytes(sock.recv(4), 'big')
        img_png = sock.recv(img_size)
        nparr = np.frombuffer(img_png, np.uint8)
        np_image = cv2.imdecode(nparr, -1)
        resized = cv2.resize(np_image, (84, 84))
        gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
        gray = np.expand_dims(gray, axis=0)  # <- 關鍵！channel-first
        self.last_observation = gray.copy()
        return is_game_over, removed_lines, height, holes, gray
    

        
# 檢查環境
print("✅ 建立環境開始")
env = TetrisEnv()
check_env(env)

# ----------------------------
# 建立訓練環境（使用向量化、多個 env）並加入正規化與 frame stacking
# 這部分主要用於加速並穩定訓練
# train_env = make_vec_env(TetrisEnv, n_envs=3)
# train_env = VecNormalize(train_env, norm_obs=False, norm_reward=True)
# train_env = VecFrameStack(train_env, n_stack=4, channels_order='first')
train_env = DummyVecEnv([lambda: TetrisEnv()])
train_env = VecFrameStack(train_env, n_stack=4, channels_order="first")

# ----------------------------
# 使用 DQN 進行訓練，調整超參數以提升效能：
# 這裡設定 buffer_size、learning_starts、target_update_interval 等參數
model = DQN("CnnPolicy", train_env, verbose=1, tensorboard_log="./sb3_log/",
            gamma=0.95,
            learning_rate=1e-4,         # 較低的學習率有助於穩定收斂
            buffer_size=20000,         # 經驗回放緩衝區大小
            learning_starts=1000,       # 多少步後開始學習
            policy_kwargs=dict(normalize_images=False),
            target_update_interval=1000 # 目標網路更新頻率
           )
write_log("Model device: " + str(model.device))
# model.learn(total_timesteps=100000)  # 可根據需要延長 timesteps1000000
reward_logger = RewardLoggerCallback(log_path="./reward_log.txt")
model.learn(total_timesteps=5000000, callback=reward_logger)

# 儲存訓練後的模型（訓練完畢後可先暫停 train_env 的歸一化更新）
train_env.training = False

# ----------------------------
# 包裝測試環境，但僅用來符合 predict 格式，取影像還是從原生環境拿
# wrapped_test_env = make_vec_env(TetrisEnv, n_envs=1)
# wrapped_test_env = VecNormalize(wrapped_test_env, norm_obs=False, norm_reward=False, training=False)
# wrapped_test_env = VecFrameStack(wrapped_test_env, n_stack=4, channels_order='first')
wrapped_test_env = DummyVecEnv([lambda: TetrisEnv()])
wrapped_test_env = VecFrameStack(wrapped_test_env, n_stack=4, channels_order="first")

# 原始環境保留用來取影像
raw_test_env = TetrisEnv()

# 初始化狀態
wrapped_obs = wrapped_test_env.reset()
raw_obs, _ = raw_test_env.reset()

frames = []
total_test_reward = 0
test_steps = 1000

for step in range(test_steps):
    action, _ = model.predict(wrapped_obs, deterministic=True)

    # 執行動作
    next_raw_obs, reward, done, truncated, info = raw_test_env.step(action)
    wrapped_obs, _, _, _ = wrapped_test_env.step(action)

    total_test_reward += reward
    frames.append(np.expand_dims(raw_obs.copy(), axis=0))
    # frames.append(raw_obs.copy())  # 儲存原始影像
    raw_obs = next_raw_obs

    if done:
        break

write_log("Test completed: Total reward = " + str(total_test_reward))

# 將回放影像存入資料夾（依老師格式）
replay_folder = './replay'
if os.path.exists(replay_folder):
    shutil.rmtree(replay_folder)
os.makedirs(replay_folder, exist_ok=True)
episode_folder = os.path.join(replay_folder, "0", "0")
os.makedirs(episode_folder, exist_ok=True)
for i, frame in enumerate(frames):
    fname = os.path.join(episode_folder, '{:06d}.png'.format(i))
    cv2.imwrite(fname,frame[0].squeeze())

# 產生 replay GIF（最佳回放）
filenames = sorted(glob.glob(episode_folder + '/*.png'))
gif_images = []
for filename in filenames:
    gif_images.append(imageio.imread(filename))
imageio.mimsave('replay.gif', gif_images, loop=0)
print("Replay GIF saved: replay.gif")
display(FileLink('replay.gif'))

# 將測試結果寫入 CSV（格式與老師版本一致）
with open('tetris_best_score_test2.csv', 'w') as fs:
    fs.write('id,removed_lines,played_steps\n')
    fs.write(f'0,{info["removed_lines"]},{info["lifetime"]}\n')
    fs.write(f'1,{info["removed_lines"]},{info["lifetime"]}\n')
print("CSV file saved: tetris_best_score_test2.csv")
display(FileLink('tetris_best_score_test2.csv'))
wandb.save('tetris_best_score_test2.csv')

# ----------------------------
# 儲存最終模型（請確認將 '113598065' 替換成你的學號）
model.save('113598065_dqn_30env_1M.zip')
print("Model saved: 113598065_dqn_30env_1M.zip")
display(FileLink('113598065_dqn_30env_1M.zip'))
wandb.save('113598065_dqn_30env_1M.zip')

# 關閉環境
wrapped_test_env.close()
raw_test_env.close()
train_env.close()



Java started
✅ Java server started
⏳ 等待 Tetris TCP server 啟動中...
✅ Java TCP server 準備完成，連線成功
✅ PyTorch is using GPU:Client has joined the game Tesla T4
Client has exited the game

✅ 建立環境開始
Client has exited the game
Client has joined the game
Address already in use (Bind failed)
Tetris TCP server is listening at 10612
Client has joined the game
Using cuda device
Model device: cuda
Logging to ./sb3_log/DQN_11
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.856    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 23       |
|    time_elapsed     | 6        |
|    total_timesteps  | 152      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.683    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 23       |
|    time_elapsed     | 14       |
|    total_timesteps  | 334 

In [None]:
import numpy as np
import socket
import cv2
import matplotlib.pyplot as plt
import subprocess
import os
import shutil
import glob
import imageio
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack, DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback
import torch
import time
import wandb
import io
import struct # For unpacking bytes
from PIL import Image # For image processing

# --- Kaggle/Jupyter 特有導入 ---
try:
    from kaggle_secrets import UserSecretsClient
    # 如果在 Kaggle 環境，嘗試讀取 WandB API Key
    try:
        user_secrets = UserSecretsClient()
        WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")
        os.environ["WANDB_API_KEY"] = WANDB_API_KEY
        print("Trying to login WandB using Kaggle secrets...")
        wandb.login()
        wandb.init(project="tetris-training", entity="t113598065-ntut-edu-tw", sync_tensorboard=True)
        print("✅ WandB login successful via Kaggle secrets.")
    except Exception as e:
        print(f"❌ WandB setup via Kaggle secrets failed: {e}. Trying anonymous login.")
        wandb.login(anonymous="allow")
        wandb.init(project="tetris-training", anonymous="allow", sync_tensorboard=True)
        print("✅ WandB login successful anonymously.")
except ImportError:
    # 如果不在 Kaggle，嘗試直接登入
    print("Kaggle secrets not found. Trying standard WandB login...")
    try:
        if "WANDB_API_KEY" in os.environ:
             wandb.login()
             wandb.init(project="tetris-training", entity="t113598065-ntut-edu-tw", sync_tensorboard=True)
             print("✅ WandB login successful via environment variable.")
        else:
             wandb.login(anonymous="allow")
             wandb.init(project="tetris-training", anonymous="allow", sync_tensorboard=True)
             print("✅ WandB login successful anonymously (no API key found).")
    except Exception as e:
        print(f"❌ Standard WandB login failed: {e}. WandB logging disabled.")
        wandb.init(mode="disabled") # Disable WandB if login fails

# --- IPython Display Fallback ---
try:
    from IPython.display import display, FileLink
except ImportError:
    def display(dummy): pass
    def FileLink(path): return f"File available at: {path}"
    print("⚠️ IPython.display not found. Download links may not work as expected.")

# --- Logging Setup ---
log_path = "/kaggle/working/tetris_train_log.txt" # Kaggle 路徑
if not os.path.exists("/kaggle/working/"):
    log_path = "./tetris_train_log.txt" # 本地路徑

def write_log(message):
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {message}\n")
    print(message)

# --- Wait for Server ---
def wait_for_tetris_server(ip="127.0.0.1", port=10612, timeout=30):
    write_log("⏳ 等待 Tetris TCP server 啟動中...")
    start_time = time.time()
    while True:
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as test_sock:
                test_sock.settimeout(1.0)
                test_sock.connect((ip, port))
            write_log(f"✅ Java TCP server ({ip}:{port}) 準備完成，連線成功")
            break
        except socket.error as e:
            if time.time() - start_time > timeout:
                write_log(f"❌ 等待 Java TCP server 超時 ({timeout}s)")
                raise TimeoutError("❌ 等待 Java TCP server 超時")
            time.sleep(0.5)

# --- Reward Logger Callback (TXT format) ---
class RewardLoggerCallback(BaseCallback):
    def __init__(self, log_path: str, verbose: int = 0):
        super().__init__(verbose)
        self.log_path = log_path
        self.episode_rewards = []

    def _on_step(self) -> bool:
        # Check if the environment is vectorized
        is_vec_env = isinstance(self.training_env, (DummyVecEnv, VecFrameStack, VecNormalize))

        if is_vec_env:
            for i, done in enumerate(self.locals["dones"]):
                if done:
                    info = self.locals["infos"][i]
                    if "episode" in info:
                        ep_reward = info["episode"]["r"]
                        ep_len = info["episode"]["l"]
                        ep_num = len(self.episode_rewards) + 1
                        self.episode_rewards.append(ep_reward)
                        with open(self.log_path, "a", encoding="utf-8") as f:
                            f.write(f"{ep_num},{ep_reward}\n")
                        if wandb.run and wandb.run.mode != "disabled":
                             wandb.log({"train/episode_reward": ep_reward, "train/episode_length": ep_len}, step=self.num_timesteps)
        # Handling for non-vectorized envs (less common with SB3 but good practice)
        elif self.locals["dones"]:
             info = self.locals["infos"] # Should be a single dict
             if "episode" in info:
                  ep_reward = info["episode"]["r"]
                  ep_len = info["episode"]["l"]
                  ep_num = len(self.episode_rewards) + 1
                  self.episode_rewards.append(ep_reward)
                  with open(self.log_path, "a", encoding="utf-8") as f:
                       f.write(f"{ep_num},{ep_reward}\n")
                  if wandb.run and wandb.run.mode != "disabled":
                       wandb.log({"train/episode_reward": ep_reward, "train/episode_length": ep_len}, step=self.num_timesteps)
        return True


# --- Tetris 環境 (修改為接收圖像) ---
class TetrisEnv(gym.Env):
    """
    Tetris 環境，通過 TCP 與外部 Java 伺服器通信。
    **警告:** 內部實現基於對伺服器協議的假設 (發送圖像 + 文本狀態)。
    """
    metadata = {'render_modes': ['human', 'ansi'], 'render_fps': 4}

    def __init__(self, ip="127.0.0.1", port=10612, render_mode=None):
        super().__init__()

        self.server_ip = ip
        self.server_port = port
        self.sock = None
        self.text_buffer = "" # Buffer for text messages
        self.render_mode = render_mode
        self.current_observation_image = None

        # **假設:** 狀態是 200x100 RGB 圖像
        self.image_height = 200
        self.image_width = 100
        self.image_channels = 3
        self.observation_space = spaces.Box(
            low=0, high=255,
            shape=(self.image_height, self.image_width, self.image_channels),
            dtype=np.uint8
        )

        # **假設:** 動作是 40 種 (4 旋轉 * 10 X座標)
        self.num_actions = 40
        self.action_space = spaces.Discrete(self.num_actions)

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        write_log(f"TetrisEnv (Image Mode) initialized. Observation space: {self.observation_space}, Action space: {self.action_space}")

    def _connect_to_server(self):
        if self.sock:
            self.close()
        try:
            self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.sock.settimeout(20.0) # Increased timeout for potentially larger data
            self.sock.connect((self.server_ip, self.server_port))
        except socket.error as e:
            write_log(f"❌ Failed to connect to Tetris server: {e}")
            self.sock = None
            raise ConnectionError(f"Failed to connect to Tetris server: {e}")

    def _send_command(self, command):
        if not self.sock:
            raise ConnectionError("Not connected to server.")
        try:
            full_command = command + "\n"
            self.sock.sendall(full_command.encode('utf-8'))
        except socket.error as e:
            write_log(f"❌ Error sending command '{command}': {e}")
            self.close()
            raise ConnectionError(f"Error sending command: {e}")

    def _receive_exact_bytes(self, n_bytes):
        """Receives exactly n_bytes from the socket."""
        if not self.sock:
            raise ConnectionError("Not connected to server.")
        chunks = []
        bytes_recd = 0
        while bytes_recd < n_bytes:
            try:
                chunk = self.sock.recv(min(n_bytes - bytes_recd, 4096))
                if not chunk:
                    self.close()
                    raise ConnectionAbortedError(f"Connection closed by server while waiting for {n_bytes} bytes.")
                chunks.append(chunk)
                bytes_recd += len(chunk)
            except socket.timeout:
                self.close()
                raise TimeoutError(f"Socket timeout while waiting for {n_bytes} bytes.")
            except socket.error as e:
                self.close()
                raise ConnectionError(f"Socket error while receiving data: {e}")
        return b''.join(chunks)

    def _receive_text_line(self):
        """Receives a newline-terminated text message."""
        if not self.sock:
             raise ConnectionError("Not connected to server.")
        while "\n" not in self.text_buffer:
            try:
                # Use recv with a smaller size for text likely after image
                chunk = self.sock.recv(1024)
                if not chunk:
                    self.close()
                    raise ConnectionAbortedError("Connection closed by server while waiting for text line.")
                # Attempt to decode immediately to catch errors early, but buffer original bytes if needed
                try:
                    self.text_buffer += chunk.decode('utf-8')
                except UnicodeDecodeError as e:
                     # Handle case where text might be mixed with unexpected binary data
                     write_log(f"❌ Error decoding text chunk: {e}. Received bytes: {chunk}")
                     # Decide recovery strategy: skip chunk, raise error, etc.
                     # For now, we'll raise, assuming text should be clean UTF-8
                     raise ValueError(f"Invalid UTF-8 received in text message: {e}")

            except socket.timeout:
                self.close()
                raise TimeoutError("Socket timeout while waiting for text line.")
            except socket.error as e:
                self.close()
                raise ConnectionError(f"Socket error while receiving text line: {e}")

        line_end = self.text_buffer.find("\n")
        data_line = self.text_buffer[:line_end]
        self.text_buffer = self.text_buffer[line_end + 1:]
        return data_line

    def _receive_message(self):
        """
        Receives a message based on the **assumed** protocol:
        1. 4-byte integer (big-endian) for image size.
        2. image_size bytes of PNG data.
        3. Newline-terminated text line for status (reward, terminated, info).
        """
        try:
            # 1. Receive image size header (4 bytes, big-endian)
            header_bytes = self._receive_exact_bytes(4)
            image_size = struct.unpack('>I', header_bytes)[0] # '>I' for big-endian unsigned int

            # 2. Receive image data
            if image_size == 0:
                 # Handle case where server might send 0 size for error or empty image
                 image_bytes = b''
                 write_log("⚠️ Received image size 0 from server.")
            elif image_size > 5_000_000: # Sanity check for abnormally large size
                 raise ValueError(f"Received excessively large image size: {image_size} bytes.")
            else:
                 image_bytes = self._receive_exact_bytes(image_size)

            # 3. Receive status text line
            status_line = self._receive_text_line()

            return image_bytes, status_line

        except (struct.error, ValueError) as e:
             write_log(f"❌ Error processing message header or size: {e}")
             raise ValueError(f"Invalid message structure received: {e}")


    def _parse_image_and_status(self, image_bytes, status_line):
        """
        Parses the received image bytes and status line.
        **Assumes** status_line format: "reward,term_flag[,lines_step][,final_lines][,final_lifetime]"
        """
        # Parse Image
        observation_image = None
        try:
            if image_bytes:
                 img = Image.open(io.BytesIO(image_bytes))
                 # Convert to RGB if it's not already (handles RGBA, Grayscale etc.)
                 img = img.convert("RGB")
                 observation_image = np.array(img, dtype=np.uint8)

                 # Verify shape (optional but recommended)
                 expected_shape = (self.image_height, self.image_width, self.image_channels)
                 if observation_image.shape != expected_shape:
                      write_log(f"⚠️ Warning: Received image shape {observation_image.shape} differs from expected {expected_shape}. Resizing or check Env definition.")
                      # Example: Resize (might distort aspect ratio)
                      # from skimage.transform import resize
                      # observation_image = (resize(observation_image, expected_shape, anti_aliasing=True) * 255).astype(np.uint8)
                      # Fallback: Return error state if resize is not desired
                      raise ValueError("Incorrect image dimensions received.")

            else: # Handle empty image bytes case
                 observation_image = np.zeros(self.observation_space.shape, dtype=self.observation_space.dtype)

        except Exception as e:
            write_log(f"❌ Error processing received image data: {e}")
            observation_image = np.zeros(self.observation_space.shape, dtype=self.observation_space.dtype)
            # We still need to parse the status line if possible
            # return observation_image, 0.0, True, False, {"error": f"Image processing error: {e}"}

        # Parse Status Line
        try:
            parts = status_line.strip().split(',')
            if len(parts) < 2:
                raise ValueError("Status line has too few parts.")

            reward = float(parts[0])
            terminated_flag = int(parts[1])
            terminated = (terminated_flag == 1)
            truncated = False # Assume no truncation via time limit

            # Attempt to parse optional info fields
            info = {}
            if len(parts) > 2: info['lines_cleared_this_step'] = int(parts[2])
            if terminated and len(parts) > 3: info['removed_lines'] = int(parts[3]) # Only relevant if terminated
            if terminated and len(parts) > 4: info['lifetime'] = int(parts[4])      # Only relevant if terminated

            # Fill defaults if keys are missing after termination (needed for evaluation)
            if terminated:
                 info.setdefault('removed_lines', 0)
                 info.setdefault('lifetime', 0)


            return observation_image, reward, terminated, truncated, info

        except Exception as e:
            write_log(f"❌ Error parsing status line '{status_line}': {e}")
            # Return observation image if valid, but signal error/termination
            if observation_image is None: # If image parsing also failed
                 observation_image = np.zeros(self.observation_space.shape, dtype=self.observation_space.dtype)
            # Terminate the episode on parsing error
            return observation_image, 0.0, True, False, {"error": f"Status parsing error: {e}"}


    def _map_action_to_command(self, action):
        """
        Maps action index (0-39) to "PLACE <rot> <x>" command (Assumption).
        """
        if not 0 <= action < self.num_actions:
            write_log(f"❌ Invalid action: {action}. Using action 0.")
            action = 0
        num_x_positions = 10
        rotation_index = action // num_x_positions
        x_position = action % num_x_positions
        command = f"PLACE {rotation_index} {x_position}"
        return command

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        try:
            self._connect_to_server()
            self._send_command("RESET")
            image_bytes, status_line = self._receive_message()
            observation, _, _, _, info = self._parse_image_and_status(image_bytes, status_line)
            self.current_observation_image = observation
            if self.render_mode == "human": self.render()
            if info is None: info = {} # Ensure info is always a dict
            return observation, info
        except (ConnectionError, TimeoutError, ValueError, ConnectionAbortedError, struct.error) as e:
             write_log(f"❌ Error during reset: {e}. Returning zero observation.")
             obs = np.zeros(self.observation_space.shape, dtype=self.observation_space.dtype)
             info = {"error": f"Reset error: {e}"}
             self.current_observation_image = obs
             return obs, info

    def step(self, action):
        if self.sock is None:
            write_log("❌ Attempting to step with no connection.")
            obs = np.zeros(self.observation_space.shape, dtype=self.observation_space.dtype)
            return obs, 0.0, True, False, {"error": "Connection lost"}
        try:
            command = self._map_action_to_command(action)
            self._send_command(command)
            image_bytes, status_line = self._receive_message()
            observation, reward, terminated, truncated, info = self._parse_image_and_status(image_bytes, status_line)
            self.current_observation_image = observation

            if self.render_mode == "human": self.render()
            if info is None: info = {} # Ensure info is always a dict

            # Add final evaluation keys if terminated (based on parsing assumption)
            if terminated:
                info.setdefault('removed_lines', 0)
                info.setdefault('lifetime', 0)

            return observation, float(reward), bool(terminated), bool(truncated), info

        except (ConnectionError, TimeoutError, ValueError, ConnectionAbortedError, struct.error) as e:
             write_log(f"❌ Error during step: {e}. Terminating episode.")
             obs = self.current_observation_image if self.current_observation_image is not None else np.zeros(self.observation_space.shape, dtype=self.observation_space.dtype)
             return obs, 0.0, True, False, {"error": f"Step error: {e}"}

    def render(self):
        if self.render_mode == "ansi":
            # Cannot render image as text easily
            print("State: (Image observation, use 'human' mode or check saved images)")
        elif self.render_mode == "human":
             if self.current_observation_image is not None:
                 try:
                     # Convert RGB (PIL/numpy) to BGR (OpenCV)
                     img_bgr = cv2.cvtColor(self.current_observation_image, cv2.COLOR_RGB2BGR)
                     cv2.imshow("Tetris (Python Render)", img_bgr)
                     cv2.waitKey(1) # Needed for imshow to refresh
                 except Exception as e:
                     print(f"Human rendering failed: {e}")
             else:
                 print("State: (No observation image available)")

    def close(self):
        if self.sock:
            try:
                self.sock.shutdown(socket.SHUT_RDWR)
                self.sock.close()
            except socket.error: pass
            finally: self.sock = None
        if self.render_mode == "human":
             try: cv2.destroyAllWindows()
             except: pass


# --- 主程式 ---
if __name__ == "__main__":
    # --- Constants ---
    SERVER_IP = "127.0.0.1"
    SERVER_PORT = 10612
    STUDENT_ID = "113598065"
    # Using EXACT filenames as requested
    MODEL_SAVE_PATH_ZIP = "/kaggle/working/113598065_dqn_30env_1M.zip"
    CSV_SAVE_PATH = "/kaggle/working/tetris_best_score_test2.csv"
    # Other paths derived from WORK_DIR
    WORK_DIR = "/kaggle/working/"
    if not os.path.exists(WORK_DIR): WORK_DIR = "./"
    REWARD_LOG_FILE_PATH = os.path.join(WORK_DIR, "reward_log.txt")
    TENSORBOARD_LOG_PATH = os.path.join(WORK_DIR, "tensorboard_logs/")
    JAVA_SERVER_JAR = "TetrisTCPserver_v0.6.jar"
    TOTAL_TRAIN_TIMESTEPS = 1_000_000 # Adjust training steps if needed (doesn't affect filename now)
    N_STACK_FRAMES = 4 # Number of frames to stack for CNN

    # --- 1. 啟動 Java Server ---
    write_log("--- Starting Java Server ---")
    java_process = None
    # ...(Same Java startup code as before)...
    try:
        java_process = subprocess.Popen(["java", "-jar", JAVA_SERVER_JAR], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        write_log(f"✅ Java server process started (PID: {java_process.pid}). Waiting for connection...")
        # Add brief pause and check for immediate errors
        time.sleep(2)
        stderr_check = ""
        try: stderr_check = java_process.stderr.read()
        except: pass # Ignore read errors if process running
        if stderr_check:
             write_log("--- Java Server STDERR on Startup ---"); write_log(stderr_check); write_log("---------------------------------")
        wait_for_tetris_server(SERVER_IP, SERVER_PORT)
    except FileNotFoundError:
        write_log(f"❌ Error: {JAVA_SERVER_JAR} not found."); exit()
    except Exception as e:
        write_log(f"❌ Error starting/waiting for Java server: {e}")
        if java_process: java_process.terminate()
        exit()

    # --- 2. 環境檢查與創建 ---
    write_log("--- Setting up Environment ---")
    env_instance_for_check = None
    train_env_vec = None
    try:
        # Check the base environment first
        env_instance_for_check = TetrisEnv(ip=SERVER_IP, port=SERVER_PORT)
        check_env(env_instance_for_check, warn=True)
        write_log("✅ Base environment check passed.")
    except Exception as e:
        write_log(f"❌ Environment check failed: {e}")
        if env_instance_for_check: env_instance_for_check.close()
        if java_process: java_process.terminate()
        exit()
    finally:
         if env_instance_for_check: env_instance_for_check.close()

    # Create vectorized and stacked environment for training
    train_env = DummyVecEnv([lambda: TetrisEnv(ip=SERVER_IP, port=SERVER_PORT)])
    train_env_vec = VecFrameStack(train_env, n_stack=N_STACK_FRAMES)
    write_log(f"✅ Training environment created and wrapped with VecFrameStack (n_stack={N_STACK_FRAMES}).")
    write_log(f"   Observation space (stacked): {train_env_vec.observation_space}")


    # --- 3. 定義模型 ---
    write_log("--- Defining DQN Model (CNN Policy) ---")
    policy_type = "CnnPolicy" # Use CNN for image input
    os.makedirs(TENSORBOARD_LOG_PATH, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    write_log(f"Using device: {device}")

    model = DQN(
        policy_type,
        train_env_vec, # Train on the stacked environment
        verbose=1,
        tensorboard_log=TENSORBOARD_LOG_PATH,
        learning_rate=1e-4, # May need tuning for CNN
        buffer_size=50000,  # May need larger buffer for images
        learning_starts=10000,# Let buffer fill more before learning
        batch_size=32,      # Often smaller batch size for CNNs due to memory
        tau=1.0,
        gamma=0.99,
        train_freq=4,
        gradient_steps=1,
        target_update_interval=1000, # May need longer interval
        exploration_fraction=0.1,
        exploration_final_eps=0.05,
        # SB3's DQN with CnnPolicy expects images in channel-first format (C, H, W)
        # VecFrameStack usually handles the dimension permutation if the env returns H, W, C
        policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=512)), # Example CNN feature dim
        device=device
    )

    # --- 4. 定義回調函數 ---
    reward_logger = RewardLoggerCallback(log_path=REWARD_LOG_FILE_PATH)
    callbacks = [reward_logger]
    if wandb.run and wandb.run.mode != "disabled":
        from wandb.integration.sb3 import WandbCallback
        wandb_callback = WandbCallback(log="all", verbose=2)
        callbacks.append(wandb_callback)

    # --- 5. 訓練模型 ---
    write_log(f"--- Starting Training for {TOTAL_TRAIN_TIMESTEPS} timesteps ---")
    training_successful = False
    try:
        model.learn(
            total_timesteps=TOTAL_TRAIN_TIMESTEPS,
            log_interval=10,
            callback=callbacks
        )
        training_successful = True
        write_log("✅ Training completed successfully.")
    except KeyboardInterrupt:
        write_log("🛑 Training interrupted by user.")
    except Exception as train_e:
        write_log(f"❌ Training failed: {train_e}")
        # ...(Error saving logic remains the same)...
        error_save_path = os.path.join(WORK_DIR, f"{STUDENT_ID}_dqn_error.zip") # Generic error name
        try: model.save(error_save_path); write_log(f"💾 Intermediate model saved to {error_save_path}")
        except: write_log("❌ Failed to save intermediate model after error.")
    finally:
        # --- 保存最終模型 (Using exact requested filename) ---
        if training_successful:
            try:
                model.save(MODEL_SAVE_PATH_ZIP) # Use the exact path defined earlier
                write_log(f"💾 Final model saved: {MODEL_SAVE_PATH_ZIP}")
                display(FileLink(MODEL_SAVE_PATH_ZIP))
                if wandb.run and wandb.run.mode != "disabled":
                    wandb.save(MODEL_SAVE_PATH_ZIP, base_path=WORK_DIR)
                    write_log("⬆️ Final model uploaded to WandB.")
            except Exception as final_save_e:
                write_log(f"❌ Failed to save final model: {final_save_e}")
        # 關閉訓練環境
        if train_env_vec: train_env_vec.close() # Close the VecFrameStack wrapper
        write_log("✅ Training environment closed.")

    # --- 6. 評估模型並保存 CSV ---
    if training_successful:
        write_log("\n--- Starting Model Evaluation ---")
        # Use a non-vectorized, non-stacked env for simplicity in eval loop control
        eval_env = None
        try:
            write_log(f"🔄 Loading model from: {MODEL_SAVE_PATH_ZIP}")
            # Load model, no env needed if saved correctly
            eval_model = DQN.load(MODEL_SAVE_PATH_ZIP, device=device)

            # Create a fresh evaluation environment (not stacked)
            eval_env = TetrisEnv(ip=SERVER_IP, port=SERVER_PORT)
            obs, info = eval_env.reset() # obs is single image (H, W, C)

            # For CNN policy, prediction needs channel-first and batch dim (N, C, H, W)
            # And needs stacking if trained with stacking. Manual stacking for eval:
            stacked_obs_deque = io.deque([np.zeros_like(obs)] * N_STACK_FRAMES, maxlen=N_STACK_FRAMES)
            stacked_obs_deque.append(obs)

            def get_stacked_observation(deque):
                # Stack along the channel dimension (axis=2 for HWC -> HWSC, then transpose) or a new axis
                stacked_frames = np.array(list(deque)) # Shape (N_STACK, H, W, C)
                # SB3 CnnPolicy usually expects (Batch, C*N_STACK, H, W) or similar based on extractor
                # Or VecFrameStack gives (Batch, H, W, C*N_STACK) -> let's assume this
                # This part is tricky and depends on how VecFrameStack observation space is structured.
                # Let's assume VecFrameStack concatenates channels: (H, W, C * N_STACK)
                # We need to replicate this manually.
                processed_obs = np.concatenate(list(deque), axis=-1) # Concatenate along channel axis -> (H, W, C*N_STACK)
                # Add batch dimension
                return np.expand_dims(processed_obs, axis=0)


            terminated = False
            truncated = False
            played_steps = 0
            removed_lines_total = 0
            final_removed_lines = 0
            final_lifetime = 0

            while not terminated and not truncated:
                current_stacked_obs_for_policy = get_stacked_observation(stacked_obs_deque)
                action, _ = eval_model.predict(current_stacked_obs_for_policy, deterministic=True)
                obs, reward, terminated, truncated, info = eval_env.step(action[0]) # action is usually wrapped in a list/array
                stacked_obs_deque.append(obs) # Add new observation
                played_steps += 1
                # Accumulate lines based on info from step (ASSUMPTION)
                removed_lines_total += info.get('lines_cleared_this_step', 0)
                # If terminated, try to get final stats from info (ASSUMPTION)
                if terminated:
                    final_removed_lines = info.get('removed_lines', removed_lines_total)
                    final_lifetime = info.get('lifetime', played_steps)

            write_log(f"🏁 Evaluation Episode finished. Steps: {final_lifetime}, Removed Lines: {final_removed_lines}")

            # --- 寫入 CSV 文件 (Using exact requested filename and format) ---
            write_log(f"💾 Writing evaluation results to: {CSV_SAVE_PATH}")
            with open(CSV_SAVE_PATH, 'w') as fs:
                fs.write('id,removed_lines,played_steps\n')
                fs.write(f'0,{final_removed_lines},{final_lifetime}\n')
                fs.write(f'1,{final_removed_lines},{final_lifetime}\n')
            write_log("✅ CSV file saved.")
            display(FileLink(CSV_SAVE_PATH))
            if wandb.run and wandb.run.mode != "disabled":
                wandb.save(CSV_SAVE_PATH, base_path=WORK_DIR)
                write_log("⬆️ CSV results uploaded to WandB.")

        except FileNotFoundError:
             write_log(f"❌ Evaluation failed: Model file not found at {MODEL_SAVE_PATH_ZIP}")
        except KeyError as e:
             write_log(f"❌ Evaluation failed: Missing key {e} in environment's info dictionary.")
             write_log("   Ensure TetrisEnv step() returns info with required keys on termination.")
        except Exception as eval_e:
             write_log(f"❌ An error occurred during evaluation or CSV saving: {eval_e}")
        finally:
            if eval_env: eval_env.close()
            write_log("✅ Evaluation environment closed.")


    # --- 7. 清理工作 ---
    # ...(Same cleanup code as before)...
    write_log("--- Cleaning up ---")
    if java_process:
        try:
            java_process.terminate()
            java_process.wait(timeout=5)
            write_log("☕ Java server process terminated.")
        except subprocess.TimeoutExpired:
            java_process.kill()
            write_log("🔪 Java server process killed.")
        except Exception as e:
            write_log(f"⚠️ Error terminating Java process: {e}")

    if wandb.run and wandb.run.mode != "disabled":
         wandb.finish()
         write_log("📊 WandB run finished.")

    write_log("🏁 Full script finished.")