In [1]:
!pip install "stable-baselines3[extra]"
!pip install wandb
!wget http://www.aiotlab.org/teaching/oop/tetris/TetrisTCPserver_v0.6.jar
import os

if os.path.exists("TetrisTCPserver_v0.6.jar"):
    print("✅ 檔案複製成功")
else:
    print("❌ 檔案複製失敗")

Collecting shimmy~=1.1.0 (from shimmy[atari]~=1.1.0; extra == "extra"->stable-baselines3[extra])
  Downloading Shimmy-1.1.0-py3-none-any.whl.metadata (3.3 kB)
Collecting autorom~=0.6.1 (from autorom[accept-rom-license]~=0.6.1; extra == "extra"->stable-baselines3[extra])
  Downloading AutoROM-0.6.1-py3-none-any.whl.metadata (2.4 kB)
Collecting AutoROM.accept-rom-license (from autorom[accept-rom-license]~=0.6.1; extra == "extra"->stable-baselines3[extra])
  Downloading AutoROM.accept-rom-license-0.6.1.tar.gz (434 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m434.7/434.7 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting ale-py~=0.8.1 (from shimmy[atari]~=1.1.0; extra == "extra"->stable-baselines3[extra])
  Downloading ale_py-0.8.1-cp311-cp311-manylinux_2_17_x86_64.man

In [None]:
import numpy as np
import socket
import cv2
import matplotlib.pyplot as plt
import subprocess
import os
import shutil
import glob
import imageio
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack
from IPython.display import FileLink, display, Image
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback
import torch
# 使用 wandb 記錄訓練日誌
import os
import wandb
from kaggle_secrets import UserSecretsClient
from stable_baselines3.common.vec_env import DummyVecEnv

# 從 Kaggle Secrets 讀取 API Token
user_secrets = UserSecretsClient()
WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")

# 設定環境變數，模擬 login
os.environ["WANDB_API_KEY"] = WANDB_API_KEY

# login & init
wandb.login()
wandb.init(project="tetris-training", entity="t113598065-ntut-edu-tw")

log_path = "/kaggle/working/tetris_train_log.txt"

def write_log(message):
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(message + "\n")
    print(message)

import time

def wait_for_tetris_server(ip="127.0.0.1", port=10612, timeout=30):
    write_log("⏳ 等待 Tetris TCP server 啟動中...")
    start_time = time.time()
    while True:
        try:
            test_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            test_sock.settimeout(1.0)
            test_sock.connect((ip, port))
            test_sock.close()
            write_log("✅ Java TCP server 準備完成，連線成功")
            break
        except socket.error:
            if time.time() - start_time > timeout:
                raise TimeoutError("❌ 等待 Java TCP server 超時")
            time.sleep(0.5)

# 啟動 Java Tetris server
print("Java started")
subprocess.Popen(["java", "-jar", "TetrisTCPserver_v0.6.jar"])
write_log("✅ Java server started")
wait_for_tetris_server()

if torch.cuda.is_available():
    print("✅ PyTorch is using GPU:", torch.cuda.get_device_name(0))
else:
    print("❌ PyTorch is using CPU")
# ----------------------------
# 定義 Tetris 環境 (採用老師的格式)
class RewardLoggerCallback(BaseCallback):
    def __init__(self, log_path="./reward_log.txt", verbose=0):
        super().__init__(verbose)
        self.log_path = log_path
        self.episode_rewards = []

    def _on_step(self) -> bool:
        # 每次 env step 呼叫一次，但只有在 episode 結束時才記錄
        infos = self.locals.get("infos", [])
        for info in infos:
            if "episode" in info:  # VecEnv 會回傳 episode reward
                ep_reward = info["episode"]["r"]
                self.episode_rewards.append(ep_reward)
                with open(self.log_path, "a") as f:
                    f.write(f"{len(self.episode_rewards)},{ep_reward}\n")
                print(f"📈 Episode {len(self.episode_rewards)} Reward: {ep_reward}")
        return True

class TetrisEnv(gym.Env):
    metadata = {"render_modes": ["human"], "render_fps": 20}
    N_DISCRETE_ACTIONS = 5
    IMG_HEIGHT = 200
    IMG_WIDTH = 100
    IMG_CHANNELS = 3

    def __init__(self, host_ip="127.0.0.1", host_port=10612):
        super().__init__()
        self.action_space = spaces.Discrete(self.N_DISCRETE_ACTIONS)
        # self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84), dtype=np.uint8)
        self.observation_space = spaces.Box(low=0, high=255, shape=(1, 84, 84), dtype=np.uint8)
        self.server_ip = host_ip
        self.server_port = host_port

        self.client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.client_sock.connect((self.server_ip, self.server_port))

        # 初始化 reward shaping 與統計用變數
        self.lines_removed = 0
        self.height = 0
        self.holes = 0
        self.lifetime = 0

    def step(self, action):
        if action == 0:
            self.client_sock.sendall(b"move -1\n")
        elif action == 1:
            self.client_sock.sendall(b"move 1\n")
        elif action == 2:
            self.client_sock.sendall(b"rotate 0\n")
        elif action == 3:
            self.client_sock.sendall(b"rotate 1\n")
        elif action == 4:
            self.client_sock.sendall(b"drop\n")
    
        terminated, lines, height, holes, observation = self.get_tetris_server_response(self.client_sock)
    
        reward = 0
        if action == 4:
            reward += 5
    
        if height > self.height:
            reward -= (height - self.height) * 5
    
        if holes < self.holes:
            reward += (self.holes - holes) * 10
    
        if lines > self.lines_removed:
            reward += (lines - self.lines_removed) * 1000
            self.lines_removed = lines
    
        self.height = height
        self.holes = holes
        self.lifetime += 1
    
        info = {'removed_lines': self.lines_removed, 'lifetime': self.lifetime}
    
        truncated = False
    
        # 關鍵！處理終止觀察值
        if terminated:
            info['terminal_observation']  = observation.copy()  
    
        return observation, reward, terminated, truncated, info


    def reset(self, seed=None, options=None):
        self.client_sock.sendall(b"start\n")
        terminated, lines, height, holes, observation = self.get_tetris_server_response(self.client_sock)
        # 重置統計變數
        self.lines_removed = 0
        self.height = 0
        self.holes = 0
        self.lifetime = 0
        return observation, {}

    def render(self):
        cv2.imshow("Tetris", self.last_observation)
        cv2.waitKey(1)

    def close(self):
        self.client_sock.close()
        cv2.destroyAllWindows()

    def get_tetris_server_response(self, sock):
        is_game_over = (sock.recv(1) == b'\x01')
        removed_lines = int.from_bytes(sock.recv(4), 'big')
        height = int.from_bytes(sock.recv(4), 'big')
        holes = int.from_bytes(sock.recv(4), 'big')
        img_size = int.from_bytes(sock.recv(4), 'big')
        img_png = sock.recv(img_size)
        nparr = np.frombuffer(img_png, np.uint8)
        np_image = cv2.imdecode(nparr, -1)
        resized = cv2.resize(np_image, (84, 84))
        gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
        gray = np.expand_dims(gray, axis=0)  # <- 關鍵！channel-first
        self.last_observation = gray.copy()
        return is_game_over, removed_lines, height, holes, gray
    

        
# 檢查環境
print("✅ 建立環境開始")
env = TetrisEnv()
check_env(env)

# ----------------------------
# 建立訓練環境（使用向量化、多個 env）並加入正規化與 frame stacking
# 這部分主要用於加速並穩定訓練
# train_env = make_vec_env(TetrisEnv, n_envs=3)
# train_env = VecNormalize(train_env, norm_obs=False, norm_reward=True)
# train_env = VecFrameStack(train_env, n_stack=4, channels_order='first')
train_env = DummyVecEnv([lambda: TetrisEnv()])
train_env = VecFrameStack(train_env, n_stack=4, channels_order="first")

# ----------------------------
# 使用 DQN 進行訓練，調整超參數以提升效能：
# 這裡設定 buffer_size、learning_starts、target_update_interval 等參數
model = DQN("CnnPolicy", train_env, verbose=1, tensorboard_log="./sb3_log/",
            gamma=0.95,
            learning_rate=1e-4,         # 較低的學習率有助於穩定收斂
            buffer_size=20000,         # 經驗回放緩衝區大小
            learning_starts=1000,       # 多少步後開始學習
            policy_kwargs=dict(normalize_images=False),
            target_update_interval=1000 # 目標網路更新頻率
           )
write_log("Model device: " + str(model.device))
# model.learn(total_timesteps=100000)  # 可根據需要延長 timesteps1000000
reward_logger = RewardLoggerCallback(log_path="./reward_log.txt")
model.learn(total_timesteps=5000000, callback=reward_logger)

# 儲存訓練後的模型（訓練完畢後可先暫停 train_env 的歸一化更新）
train_env.training = False

# ----------------------------
# 包裝測試環境，但僅用來符合 predict 格式，取影像還是從原生環境拿
# wrapped_test_env = make_vec_env(TetrisEnv, n_envs=1)
# wrapped_test_env = VecNormalize(wrapped_test_env, norm_obs=False, norm_reward=False, training=False)
# wrapped_test_env = VecFrameStack(wrapped_test_env, n_stack=4, channels_order='first')
wrapped_test_env = DummyVecEnv([lambda: TetrisEnv()])
wrapped_test_env = VecFrameStack(wrapped_test_env, n_stack=4, channels_order="first")

# 原始環境保留用來取影像
raw_test_env = TetrisEnv()

# 初始化狀態
wrapped_obs = wrapped_test_env.reset()
raw_obs, _ = raw_test_env.reset()

frames = []
total_test_reward = 0
test_steps = 1000

for step in range(test_steps):
    action, _ = model.predict(wrapped_obs, deterministic=True)

    # 執行動作
    next_raw_obs, reward, done, truncated, info = raw_test_env.step(action)
    wrapped_obs, _, _, _ = wrapped_test_env.step(action)

    total_test_reward += reward
    frames.append(np.expand_dims(raw_obs.copy(), axis=0))
    # frames.append(raw_obs.copy())  # 儲存原始影像
    raw_obs = next_raw_obs

    if done:
        break

write_log("Test completed: Total reward = " + str(total_test_reward))

# 將回放影像存入資料夾（依老師格式）
replay_folder = './replay'
if os.path.exists(replay_folder):
    shutil.rmtree(replay_folder)
os.makedirs(replay_folder, exist_ok=True)
episode_folder = os.path.join(replay_folder, "0", "0")
os.makedirs(episode_folder, exist_ok=True)
for i, frame in enumerate(frames):
    fname = os.path.join(episode_folder, '{:06d}.png'.format(i))
    cv2.imwrite(fname,frame[0].squeeze())

# 產生 replay GIF（最佳回放）
filenames = sorted(glob.glob(episode_folder + '/*.png'))
gif_images = []
for filename in filenames:
    gif_images.append(imageio.imread(filename))
imageio.mimsave('replay.gif', gif_images, loop=0)
print("Replay GIF saved: replay.gif")
display(FileLink('replay.gif'))

# 將測試結果寫入 CSV（格式與老師版本一致）
with open('tetris_best_score_test2.csv', 'w') as fs:
    fs.write('id,removed_lines,played_steps\n')
    fs.write(f'0,{info["removed_lines"]},{info["lifetime"]}\n')
    fs.write(f'1,{info["removed_lines"]},{info["lifetime"]}\n')
print("CSV file saved: tetris_best_score_test2.csv")
display(FileLink('tetris_best_score_test2.csv'))
wandb.save('tetris_best_score_test2.csv')

# ----------------------------
# 儲存最終模型（請確認將 '113598065' 替換成你的學號）
model.save('113598065_dqn_30env_1M.zip')
print("Model saved: 113598065_dqn_30env_1M.zip")
display(FileLink('113598065_dqn_30env_1M.zip'))
wandb.save('113598065_dqn_30env_1M.zip')

# 關閉環境
wrapped_test_env.close()
raw_test_env.close()
train_env.close()



Java started
✅ Java server started
⏳ 等待 Tetris TCP server 啟動中...
✅ Java TCP server 準備完成，連線成功
✅ PyTorch is using GPU:Client has joined the game Tesla T4
Client has exited the game

✅ 建立環境開始
Client has exited the game
Client has joined the game
Address already in use (Bind failed)
Tetris TCP server is listening at 10612
Client has joined the game
Using cuda device
Model device: cuda
Logging to ./sb3_log/DQN_11
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.856    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 23       |
|    time_elapsed     | 6        |
|    total_timesteps  | 152      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.683    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 23       |
|    time_elapsed     | 14       |
|    total_timesteps  | 334 

In [None]:
import numpy as np
import socket
import cv2
import matplotlib.pyplot as plt
import subprocess
import os
import shutil
import glob
import imageio
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack, DummyVecEnv
from IPython.display import FileLink, display, Image
from stable_baselines3.common.callbacks import BaseCallback
import torch
import time

# --- Wandb Setup ---
import os
import wandb
from kaggle_secrets import UserSecretsClient
# Import WandbCallback for SB3 integration
from wandb.integration.sb3 import WandbCallback

# 從 Kaggle Secrets 讀取 API Token
user_secrets = UserSecretsClient()
WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")

# 設定環境變數，模擬 login
os.environ["WANDB_API_KEY"] = WANDB_API_KEY

# login & init
wandb.login()
# Start a wandb run
run = wandb.init(
    project="tetris-training-improved", # Changed project name slightly
    entity="t113598065-ntut-edu-tw",
    sync_tensorboard=True,  # auto-upload sb3 logs
    monitor_gym=True,       # auto-upload videos and plots
    save_code=True,         # save script to wandb
    config={ # Log hyperparameters
        "policy_type": "CnnPolicy",
        "total_timesteps": 2000000, # Example: increased timesteps
        "env_id": "TetrisEnv-v1",
        "gamma": 0.99, # Increased gamma
        "learning_rate": 1e-4,
        "buffer_size": 300000, # Increased buffer size
        "learning_starts": 10000, # Increased learning starts
        "target_update_interval": 10000, # Increased target update interval
        "exploration_fraction": 0.6, # Explore for 60% of training
        "exploration_final_eps": 0.05, # Lower final epsilon
        "batch_size": 32, # Default for DQN, can be tuned
        "n_stack": 4, # Frame stacking
    }
)


log_path = "/kaggle/working/tetris_train_log.txt"

def write_log(message):
    """Appends a message to the log file and prints it."""
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - {message}\n")
    print(message)

def wait_for_tetris_server(ip="127.0.0.1", port=10612, timeout=60):
    """Waits for the Tetris TCP server to become available."""
    write_log(f"⏳ 等待 Tetris TCP server 啟動中 ({ip}:{port})...")
    start_time = time.time()
    while True:
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as test_sock:
                test_sock.settimeout(1.0)
                test_sock.connect((ip, port))
            write_log("✅ Java TCP server 準備完成，連線成功")
            break
        except socket.error as e:
            if time.time() - start_time > timeout:
                raise TimeoutError(f"❌ 等待 Java TCP server 超時 ({timeout}s)")
            # write_log(f"   連接失敗 ({e}), 重試中...") # Optional: more verbose logging
            time.sleep(1.0) # Wait a bit longer before retrying

# --- Start Java Server ---
try:
    write_log("🚀 嘗試啟動 Java Tetris server...")
    # Ensure the JAR file exists
    jar_file = "TetrisTCPserver_v0.6.jar"
    if not os.path.exists(jar_file):
         write_log(f"❌ 錯誤: 找不到 JAR 檔案 '{jar_file}'。請確保它在工作目錄中。")
         # Handle error appropriately, maybe exit
         raise FileNotFoundError(f"JAR file '{jar_file}' not found.")

    process = subprocess.Popen(["java", "-jar", jar_file], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    write_log(f"✅ Java server process 啟動 (PID: {process.pid})")
    wait_for_tetris_server()
except Exception as e:
    write_log(f"❌ 啟動或等待 Java server 時發生錯誤: {e}")
    # Optionally log process output if it failed
    if 'process' in locals() and process.poll() is not None:
        stdout, stderr = process.communicate()
        write_log(f"   Java Server STDOUT: {stdout.decode('utf-8', errors='ignore')}")
        write_log(f"   Java Server STDERR: {stderr.decode('utf-8', errors='ignore')}")
    raise # Re-raise the exception to stop the script

# --- Check GPU ---
if torch.cuda.is_available():
    write_log(f"✅ PyTorch is using GPU: {torch.cuda.get_device_name(0)}")
else:
    write_log("⚠️ PyTorch is using CPU. Training will be significantly slower.")

# ----------------------------
# 定義 Tetris 環境 (採用老師的格式, 結合獎勵機制概念)
# ----------------------------
class TetrisEnv(gym.Env):
    """Custom Environment for Tetris that interacts with a Java TCP server."""
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30} # Added rgb_array for potential recording
    N_DISCRETE_ACTIONS = 5  # 0: left, 1: right, 2: rot_left, 3: rot_right, 4: drop
    IMG_HEIGHT = 200 # Original image height from server (used for decoding)
    IMG_WIDTH = 100  # Original image width from server (used for decoding)
    IMG_CHANNELS = 3 # Original image channels
    RESIZED_DIM = 84 # Dimension for resized observation

    def __init__(self, host_ip="127.0.0.1", host_port=10612, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        self.action_space = spaces.Discrete(self.N_DISCRETE_ACTIONS)
        # Observation space: Grayscale image, channel-first (needed for CnnPolicy)
        self.observation_space = spaces.Box(
            low=0, high=255,
            shape=(1, self.RESIZED_DIM, self.RESIZED_DIM), # (Channels, Height, Width)
            dtype=np.uint8
        )
        self.server_ip = host_ip
        self.server_port = host_port
        self.client_sock = None # Initialize socket to None
        self._connect_socket() # Connect in init

        # Reward shaping & statistics variables
        self.current_score = 0 # Keep track of raw game score if needed
        self.lines_removed = 0
        self.current_height = 0
        self.current_holes = 0
        self.lifetime = 0
        self.last_observation = np.zeros(self.observation_space.shape, dtype=np.uint8) # Store last obs

        # --- Reward Shaping Coefficients (Tuning is crucial!) ---
        self.reward_line_clear_coeff = 100.0 # Base reward per line squared
        self.penalty_height_increase_coeff = 15.0 # Penalty for increasing max height
        self.penalty_hole_increase_coeff = 25.0 # Penalty for creating new holes
        self.penalty_step_coeff = 0.1 # Small penalty per step to encourage speed
        self.penalty_game_over_coeff = 500.0 # Large penalty for losing

        # For rendering
        self.window = None
        self.clock = None


    def _connect_socket(self):
        """Establishes connection to the game server."""
        try:
            # Close existing socket if any before reconnecting
            if self.client_sock:
                self.client_sock.close()
            self.client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Set a timeout for socket operations to prevent indefinite blocking
            self.client_sock.settimeout(10.0) # 10 seconds timeout
            self.client_sock.connect((self.server_ip, self.server_port))
            write_log(f"🔌 Socket connected to {self.server_ip}:{self.server_port}")
        except socket.error as e:
            write_log(f"❌ Socket connection error during connect: {e}")
            raise # Re-raise to indicate failure

    def _send_command(self, command: bytes):
        """Sends a command to the server, handles potential errors."""
        try:
            self.client_sock.sendall(command)
        except socket.timeout:
            write_log("❌ Socket timeout during send.")
            raise ConnectionAbortedError("Socket timeout during send")
        except socket.error as e:
            write_log(f"❌ Socket error during send: {e}")
            # Attempt to reconnect on send error? Could be risky. Better to fail.
            raise ConnectionAbortedError(f"Socket error during send: {e}")

    def _receive_data(self, size):
        """Receives exactly size bytes from the server."""
        data = b""
        while len(data) < size:
            try:
                chunk = self.client_sock.recv(size - len(data))
                if not chunk:
                    write_log("❌ Socket connection broken during receive (received empty chunk).")
                    raise ConnectionAbortedError("Socket connection broken")
                data += chunk
            except socket.timeout:
                 write_log(f"❌ Socket timeout during receive (expected {size}, got {len(data)}).")
                 raise ConnectionAbortedError("Socket timeout during receive")
            except socket.error as e:
                write_log(f"❌ Socket error during receive: {e}")
                raise ConnectionAbortedError(f"Socket error during receive: {e}")
        return data

    def get_tetris_server_response(self):
        """Gets state update from the Tetris server via socket."""
        try:
            is_game_over_byte = self._receive_data(1)
            is_game_over = (is_game_over_byte == b'\x01')

            removed_lines_bytes = self._receive_data(4)
            removed_lines = int.from_bytes(removed_lines_bytes, 'big')

            height_bytes = self._receive_data(4)
            height = int.from_bytes(height_bytes, 'big')

            holes_bytes = self._receive_data(4)
            holes = int.from_bytes(holes_bytes, 'big')

            img_size_bytes = self._receive_data(4)
            img_size = int.from_bytes(img_size_bytes, 'big')

            # Ensure image size is reasonable to prevent memory issues
            if img_size <= 0 or img_size > 500000: # Set a reasonable max size (e.g., 500KB)
                 write_log(f"❌ Received invalid image size: {img_size}. Aborting receive.")
                 raise ValueError(f"Invalid image size received: {img_size}")

            img_png = self._receive_data(img_size)

            # Decode and preprocess image
            nparr = np.frombuffer(img_png, np.uint8)
            # Use IMREAD_COLOR to ensure 3 channels even if image is grayscale upstream
            np_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if np_image is None:
                 write_log("❌ Failed to decode image from server response.")
                 # Return a default observation or raise error
                 # Using last observation as fallback might be problematic
                 # raise ValueError("Failed to decode image")
                 # Let's return the last known good observation and signal game over potentially
                 return True, self.lines_removed, self.current_height, self.current_holes, self.last_observation.copy()


            resized = cv2.resize(np_image, (self.RESIZED_DIM, self.RESIZED_DIM), interpolation=cv2.INTER_AREA)
            gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
            # Add channel dimension: (H, W) -> (1, H, W) for PyTorch Conv2D (channel-first)
            observation = np.expand_dims(gray, axis=0)
            # Ensure dtype is uint8 for the observation space
            observation = observation.astype(np.uint8)

            # Store the raw BGR resized image for rendering if needed
            self.last_raw_render_frame = resized.copy()
            # Store the processed observation
            self.last_observation = observation.copy()

            return is_game_over, removed_lines, height, holes, observation

        except ConnectionAbortedError as e:
             write_log(f"❌ Connection aborted while getting server response: {e}")
             # Attempt to reconnect? Or just end the episode. Let's end it.
             # Return a state indicating termination and use the last valid observation
             return True, self.lines_removed, self.current_height, self.current_holes, self.last_observation.copy()
        except Exception as e:
            write_log(f"❌ Unexpected error getting server response: {e}")
             # End the episode on unexpected errors
            return True, self.lines_removed, self.current_height, self.current_holes, self.last_observation.copy()


    def step(self, action):
        # --- Send Action ---
        if action == 0:
            command = b"move -1\n"
        elif action == 1:
            command = b"move 1\n"
        elif action == 2:
            command = b"rotate 0\n" # Assuming 0 is one direction (e.g., left/ccw)
        elif action == 3:
            command = b"rotate 1\n" # Assuming 1 is other direction (e.g., right/cw)
        elif action == 4:
            command = b"drop\n"
        else:
            write_log(f"⚠️ Invalid action received: {action}. Sending 'drop'.")
            command = b"drop\n" # Default safe action? Or raise error?

        try:
            self._send_command(command)
        except ConnectionAbortedError:
            # If sending fails, the episode must end.
            write_log("❌ Ending episode due to send failure in step.")
            terminated = True
            # Use last known state for observation, provide zero reward.
            observation = self.last_observation.copy()
            reward = self.penalty_game_over_coeff * -1 # Penalize heavily
            info = {'removed_lines': self.lines_removed, 'lifetime': self.lifetime, 'final_status': 'send_error'}
            # SB3 expects terminal_observation in info if terminated
            info['terminal_observation'] = observation
            return observation, reward, terminated, False, info # terminated=True, truncated=False

        # --- Get State Update ---
        terminated, new_lines_removed, new_height, new_holes, observation = self.get_tetris_server_response()

        # --- Calculate Reward ---
        reward = 0.0

        # 1. Line Clear Reward (Quadratic)
        lines_cleared_this_step = new_lines_removed - self.lines_removed
        if lines_cleared_this_step > 0:
            # Quadratic reward emphasizes clearing multiple lines
            reward += (lines_cleared_this_step ** 2) * self.reward_line_clear_coeff
            # Bonus for Tetris (4 lines)? Optional.
            # if lines_cleared_this_step == 4:
            #    reward += 1000 # Large bonus

        # 2. Height Increase Penalty
        height_increase = new_height - self.current_height
        if height_increase > 0:
            reward -= height_increase * self.penalty_height_increase_coeff

        # 3. Hole Increase Penalty
        hole_increase = new_holes - self.current_holes
        if hole_increase > 0:
            reward -= hole_increase * self.penalty_hole_increase_coeff
        # Optional: Small reward for *filling* holes?
        # elif hole_increase < 0:
        #    reward += abs(hole_increase) * (self.penalty_hole_increase_coeff / 2)


        # 4. Step Penalty (encourages efficiency)
        reward -= self.penalty_step_coeff

        # 5. Game Over Penalty
        if terminated:
            reward -= self.penalty_game_over_coeff
            write_log(f"💔 Game Over! Final Lines: {new_lines_removed}, Lifetime: {self.lifetime + 1}")


        # --- Update Internal State ---
        self.lines_removed = new_lines_removed
        self.current_height = new_height
        self.current_holes = new_holes
        self.lifetime += 1

        # --- Prepare Return Values ---
        info = {'removed_lines': self.lines_removed, 'lifetime': self.lifetime}
        truncated = False # We use terminated based on the game's state

        # IMPORTANT: Provide terminal observation in info when terminated
        # This is required by SB3 for correct value estimation at episode end
        if terminated:
            info['terminal_observation'] = observation.copy()
            # Log final stats to wandb if needed (can also be done in callback)
            # wandb.log({"final/lines_removed": self.lines_removed, "final/lifetime": self.lifetime})

        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        super().reset(seed=seed) # Important for Gym compatibility

        # Ensure connection is alive, reconnect if needed
        try:
            # Simple check: send 'start' and see if we get a response without error
             self._send_command(b"start\n")
             # Initial state read
             terminated, lines, height, holes, observation = self.get_tetris_server_response()
             if terminated: # Should not be terminated on reset, indicates server issue
                 write_log("⚠️ Server reported game over immediately after reset. Attempting reconnect and reset again.")
                 self._connect_socket() # Reconnect
                 self._send_command(b"start\n") # Try starting again
                 terminated, lines, height, holes, observation = self.get_tetris_server_response()
                 if terminated:
                      write_log("❌ Server still terminated after reset/reconnect. Cannot proceed.")
                      raise RuntimeError("Tetris server failed to reset properly.")

        except (ConnectionAbortedError, socket.error, TimeoutError) as e:
             write_log(f"🔌 Connection issue during reset ({e}). Attempting reconnect...")
             self._connect_socket() # Re-establish connection
             self._send_command(b"start\n") # Send start command again
             terminated, lines, height, holes, observation = self.get_tetris_server_response() # Get initial state
             # Check again if terminated immediately
             if terminated:
                 write_log("❌ Server terminated immediately after reset/reconnect. Cannot proceed.")
                 raise RuntimeError("Tetris server failed to reset properly after reconnect.")


        # Reset internal statistics
        self.lines_removed = 0
        self.current_height = height # Initial height from server
        self.current_holes = holes   # Initial holes from server
        self.lifetime = 0
        self.last_observation = observation.copy() # Store initial observation

        write_log(f"🔄 Environment Reset. Initial state: H={height}, O={holes}")

        info = {} # No extra info needed on reset by default
        return observation, info

    def render(self):
         # Render using the stored raw frame suitable for display
         if self.render_mode == "human":
             if self.window is None:
                 pygame.init()
                 pygame.display.init()
                 self.window = pygame.display.set_mode((self.RESIZED_DIM * 3, self.RESIZED_DIM * 3)) # Upscale a bit
                 pygame.display.set_caption("Tetris Env")
             if self.clock is None:
                 self.clock = pygame.time.Clock()

             # Need a surface to display the BGR image correctly
             # self.last_raw_render_frame should be (H, W, C) BGR
             if hasattr(self, 'last_raw_render_frame'):
                 # Pygame uses RGB, OpenCV uses BGR. Need conversion.
                 render_frame_rgb = cv2.cvtColor(self.last_raw_render_frame, cv2.COLOR_BGR2RGB)
                 # Pygame surface requires (width, height)
                 surf = pygame.Surface((self.RESIZED_DIM, self.RESIZED_DIM))
                 # Transpose is needed if axes are wrong, check render_frame_rgb shape
                 pygame.surfarray.blit_array(surf, np.transpose(render_frame_rgb, (1, 0, 2)))
                 surf = pygame.transform.scale(surf, (self.RESIZED_DIM * 3, self.RESIZED_DIM * 3)) # Scale up
                 self.window.blit(surf, (0, 0))
                 pygame.event.pump()
                 self.clock.tick(self.metadata["render_fps"])
                 pygame.display.flip()
             else:
                 # Draw a blank screen if no frame available yet
                  self.window.fill((0,0,0))
                  pygame.display.flip()


         elif self.render_mode == "rgb_array":
              # Return the last processed observation (channel first) or raw render frame
              # Returning the processed observation might be more useful if logging video
              # return self.last_observation # Shape (1, H, W)
              # Or return the displayable frame
              if hasattr(self, 'last_raw_render_frame'):
                  return cv2.cvtColor(self.last_raw_render_frame, cv2.COLOR_BGR2RGB) # Return RGB (H, W, C)
              else:
                  return np.zeros((self.RESIZED_DIM, self.RESIZED_DIM, 3), dtype=np.uint8) # Return black frame


    def close(self):
        write_log("🔌 Closing environment connection.")
        if self.client_sock:
            try:
                # Send a final command? Like 'quit' if the server supports it?
                # self.client_sock.sendall(b"quit\n")
                self.client_sock.close()
                write_log("   Socket closed.")
            except socket.error as e:
                 write_log(f"   Error closing socket: {e}")
            self.client_sock = None
        # Close pygame window
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            self.window = None
            write_log("   Pygame window closed.")


# --- Environment Setup ---
write_log("✅ 建立環境開始")

# Create a function to instantiate the environment
def make_env():
    env = TetrisEnv()
    return env

# Use DummyVecEnv for single environment interaction with the Java server
# If you could run multiple servers on different ports, you could use SubprocVecEnv
train_env = DummyVecEnv([make_env])

# Wrap with VecFrameStack (channel-first order is important for CnnPolicy)
train_env = VecFrameStack(train_env, n_stack=run.config["n_stack"], channels_order="first")

# Wrap with VecNormalize, NORMALIZING REWARDS ONLY. Observation normalization
# should ideally be handled by the policy or done carefully if needed.
# Since policy_kwargs has normalize_images=False, we definitely don't normalize obs here.
train_env = VecNormalize(train_env, norm_obs=False, norm_reward=True, gamma=run.config["gamma"]) # Pass gamma

write_log("   環境建立完成並已包裝 (DummyVecEnv -> VecFrameStack -> VecNormalize)")

# Check environment (optional but recommended)
# Note: check_env doesn't work directly on VecEnv, check the base env if needed
# check_env(make_env())
# write_log("   基礎環境檢查通過")

# ----------------------------
# DQN Model Setup and Training
# ----------------------------
write_log("🧠 設定 DQN 模型...")

# Define DQN model with tuned hyperparameters and Wandb logging
model = DQN(
    run.config["policy_type"], # "CnnPolicy"
    train_env,
    verbose=1,
    tensorboard_log=f"/kaggle/working/runs/{run.id}", # Log TensorBoard data for Wandb
    gamma=run.config["gamma"],
    learning_rate=run.config["learning_rate"],
    buffer_size=run.config["buffer_size"],
    learning_starts=run.config["learning_starts"],
    batch_size=run.config["batch_size"],
    train_freq=(1, "step"), # Train every step
    gradient_steps=1,       # Perform 1 gradient update per training step
    target_update_interval=run.config["target_update_interval"],
    exploration_fraction=run.config["exploration_fraction"],
    exploration_final_eps=run.config["exploration_final_eps"],
    policy_kwargs=dict(normalize_images=False), # As per original code
    seed=42 # Set seed for reproducibility
    # device="cuda" if torch.cuda.is_available() else "cpu" # Explicitly set device if needed
)
write_log(f"   模型建立完成. Device: {model.device}")
write_log(f"   超參數: {run.config}")


# Setup Wandb callback for logging SB3 metrics, gradients, etc.
wandb_callback = WandbCallback(
    gradient_save_freq=10000, # Save gradients every N steps
    model_save_path=f"/kaggle/working/models/{run.id}", # Save model checkpoints locally
    model_save_freq=50000, # Save model every N steps
    log="all", # Log histograms, gradients, etc.
    verbose=2
)


# --- Training ---
write_log("🚀 開始訓練...")
try:
    model.learn(
        total_timesteps=run.config["total_timesteps"],
        callback=wandb_callback, # Use Wandb callback
        log_interval=10 # Log basic stats (like reward) every 10 episodes to console/wandb
    )
    write_log("✅ 訓練完成!")
except Exception as e:
     write_log(f"❌ 訓練過程中發生錯誤: {e}")
     # Save model before exiting if error occurs mid-training
     error_save_path = '/kaggle/working/113598065_dqn_error_save.zip'
     model.save(error_save_path)
     write_log(f"   模型已儲存至 {error_save_path}")
     wandb.save(error_save_path) # Upload to Wandb
     run.finish(exit_code=1, quiet=True) # Finish wandb run with error code
     raise # Re-raise exception


# --- Save Final Model ---
# Important: Save VecNormalize statistics before saving the agent
# These stats are needed to properly evaluate the trained agent later
stats_path = "/kaggle/working/vecnormalize_stats.pkl"
train_env.save(stats_path)
write_log(f"   VecNormalize 統計數據已儲存至 {stats_path}")
wandb.save(stats_path) # Save stats to wandb

# Save the final trained model
final_model_name = '113598065_dqn_final.zip' # Use a clear name
final_model_path = os.path.join("/kaggle/working", final_model_name)
model.save(final_model_path)
write_log(f"✅ 最終模型已儲存: {final_model_path}")
display(FileLink(final_model_path))
wandb.save(final_model_path) # Upload final model to Wandb


# ----------------------------
# Evaluation (Optional but recommended)
# ----------------------------
write_log("\n🧪 開始評估訓練後的模型...")

# Create a separate evaluation environment
eval_env = DummyVecEnv([make_env])
# IMPORTANT: Load the SAME VecNormalize statistics saved during training
# Use training=False to disable updates to the running stats
eval_env = VecNormalize.load(stats_path, eval_env)
eval_env.training = False # Do not update stats
eval_env.norm_reward = False # Do not normalize rewards during evaluation for true score

# Wrap with FrameStack, same as training
eval_env = VecFrameStack(eval_env, n_stack=run.config["n_stack"], channels_order="first")

# Load the trained model (optional, could just use 'model')
# loaded_model = DQN.load(final_model_path, env=eval_env) # Load model with eval env if needed

# --- Run Evaluation Episodes ---
num_eval_episodes = 5
total_rewards = []
total_lines = []
total_lifetimes = []
all_frames = [] # Store frames for one episode's GIF

for i in range(num_eval_episodes):
    obs = eval_env.reset()
    done = False
    episode_reward = 0
    episode_lines = 0
    episode_lifetime = 0
    frames = []
    last_info = {}

    while not done:
        # Get the raw observation for rendering/saving GIF if needed
        # Access the underlying environment's last raw frame
        # This assumes DummyVecEnv and accesses the first (and only) env
        # Be careful if using SubprocVecEnv
        raw_frame = eval_env.envs[0].render(mode="rgb_array")
        if i == 0: # Only save frames for the first evaluation episode
             frames.append(raw_frame)

        # Use deterministic=True for consistent evaluation actions
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, infos = eval_env.step(action)

        # Note: reward here is the unnormalized reward because eval_env.norm_reward = False
        episode_reward += reward[0] # VecEnv returns lists
        last_info = infos[0] # Get info from the single env
        episode_lines = last_info.get('removed_lines', 0) # Get final lines from info
        episode_lifetime = last_info.get('lifetime', 0)   # Get final lifetime from info

        # Handle VecEnv done signal (it's an array)
        done = done[0]


    write_log(f"   評估 Episode {i+1}: Reward={episode_reward:.2f}, Lines={episode_lines}, Steps={episode_lifetime}")
    total_rewards.append(episode_reward)
    total_lines.append(episode_lines)
    total_lifetimes.append(episode_lifetime)
    if i == 0:
        all_frames = frames # Keep frames from first episode

write_log(f"--- 評估結果 ({num_eval_episodes} episodes) ---")
write_log(f"   平均 Reward: {np.mean(total_rewards):.2f} +/- {np.std(total_rewards):.2f}")
write_log(f"   平均 Lines: {np.mean(total_lines):.2f} +/- {np.std(total_lines):.2f}")
write_log(f"   平均 Steps: {np.mean(total_lifetimes):.2f} +/- {np.std(total_lifetimes):.2f}")

# Log evaluation metrics to Wandb
wandb.log({
    "eval/mean_reward": np.mean(total_rewards),
    "eval/std_reward": np.std(total_rewards),
    "eval/mean_lines": np.mean(total_lines),
    "eval/std_lines": np.std(total_lines),
    "eval/mean_lifetime": np.mean(total_lifetimes),
    "eval/std_lifetime": np.std(total_lifetimes),
})


# --- Generate Replay GIF (from first evaluation episode) ---
if all_frames:
    gif_path = '/kaggle/working/replay_eval.gif'
    write_log(f"💾 正在儲存評估回放 GIF 至 {gif_path}...")
    try:
        # Ensure frames are uint8
        imageio.mimsave(gif_path, [np.array(frame).astype(np.uint8) for frame in all_frames], fps=15, loop=0)
        write_log("   GIF 儲存成功.")
        display(FileLink(gif_path))
        # Log the GIF to Wandb
        wandb.log({"eval/replay": wandb.Video(gif_path, fps=15, format="gif")})
    except Exception as e:
        write_log(f"   ❌ 儲存 GIF 時發生錯誤: {e}")
else:
     write_log("   ⚠️ 未能儲存 GIF (沒有收集到幀).")


# --- Save Evaluation Results CSV ---
csv_filename = 'tetris_evaluation_scores.csv'
csv_path = os.path.join("/kaggle/working", csv_filename)
try:
    with open(csv_path, 'w') as fs:
        fs.write('episode_id,removed_lines,played_steps,reward\n')
        # Use stats from the first episode for the format you requested
        # Ideally, log all episodes or averages
        fs.write(f'eval_0,{total_lines[0]},{total_lifetimes[0]},{total_rewards[0]:.2f}\n')
        # Add more lines if needed, e.g., for averages or best episode
        fs.write(f'eval_avg,{np.mean(total_lines):.2f},{np.mean(total_lifetimes):.2f},{np.mean(total_rewards):.2f}\n')
    write_log(f"✅ 評估分數 CSV 已儲存: {csv_path}")
    display(FileLink(csv_path))
    wandb.save(csv_path) # Upload CSV to Wandb
except Exception as e:
    write_log(f"   ❌ 儲存 CSV 時發生錯誤: {e}")


# --- Cleanup ---
write_log("🧹 清理環境...")
eval_env.close()
train_env.close() # Close training env as well
# Close the Java server process if needed (might require PID management or specific server command)
if 'process' in locals() and process.poll() is None:
     write_log("   正在終止 Java server process...")
     process.terminate() # Ask nicely first
     try:
         process.wait(timeout=5) # Wait for termination
         write_log("   Java server process 已終止.")
     except subprocess.TimeoutExpired:
         write_log("   Java server 未能在 5 秒內終止, 強制結束...")
         process.kill() # Force kill
         write_log("   Java server process 已強制結束.")


# Finish the Wandb run
run.finish()
write_log("✨ Wandb run finished.")

In [None]:
import numpy as np
import socket
import cv2
# import matplotlib.pyplot as plt # Matplotlib not strictly needed for core logic
import subprocess
import os
import shutil
import glob
import imageio
import gymnasium as gym
from gymnasium import spaces
# from stable_baselines3.common.env_checker import check_env # Not used on VecEnv
from stable_baselines3 import DQN
# from stable_baselines3.common.env_util import make_vec_env # Not used
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack, DummyVecEnv
from IPython.display import FileLink, display # Image not used directly
# from stable_baselines3.common.callbacks import BaseCallback # Replaced by WandbCallback
import torch
import time
import pygame # Added for rendering in TetrisEnv

# --- Wandb Setup ---
import os
import wandb
from kaggle_secrets import UserSecretsClient
# Import WandbCallback for SB3 integration
from wandb.integration.sb3 import WandbCallback

# --- Configuration ---
# Set your student ID here for filenames
STUDENT_ID = "113598065"
# Set total training steps
#TOTAL_TIMESTEPS = 2000000 # Adjust as needed (e.g., 1M, 2M, 5M)
TOTAL_TIMESTEPS = 2000000 # Reduced for a potentially quicker test run, increase for full training


# --- Wandb Login and Initialization ---
try:
    user_secrets = UserSecretsClient()
    WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")
    os.environ["WANDB_API_KEY"] = WANDB_API_KEY
    wandb.login()
    wandb_enabled = True
except Exception as e:
    print(f"Wandb login failed (running without secrets?): {e}. Running without Wandb logging.")
    wandb_enabled = False
    WANDB_API_KEY = None # Ensure it's None if not available

# Start a wandb run if enabled
if wandb_enabled:
    run = wandb.init(
        project="tetris-training-improved",
        entity="t113598065-ntut-edu-tw", # Replace with your Wandb entity if different
        sync_tensorboard=True,
        monitor_gym=True,
        save_code=True,
        config={ # Log hyperparameters
            "policy_type": "CnnPolicy",
            "total_timesteps": TOTAL_TIMESTEPS,
            "env_id": "TetrisEnv-v1",
            "gamma": 0.99,
            "learning_rate": 1e-4,
            "buffer_size": 300000, # Increased buffer size
            "learning_starts": 10000, # Keep reasonable starts
            "target_update_interval": 10000, # Keep reasonable update interval
            "exploration_fraction": 0.3, # Explore for 60% of training
            "exploration_final_eps": 0.05, # Lower final epsilon
            "batch_size": 32, # Default for DQN, can be tuned
            "n_stack": 4, # Frame stacking
            "student_id": STUDENT_ID,
        }
    )
    run_id = run.id # Get run ID for saving paths
else:
    run = None # Set run to None if wandb is disabled
    run_id = f"local_{int(time.time())}" # Create a local ID for paths


log_path = f"/kaggle/working/tetris_train_log_{run_id}.txt"

def write_log(message):
    """Appends a message to the log file and prints it."""
    timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
    log_message = f"{timestamp} - {message}"
    try:
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(log_message + "\n")
    except Exception as e:
        print(f"Error writing to log file {log_path}: {e}")
    print(log_message)

def wait_for_tetris_server(ip="127.0.0.1", port=10612, timeout=60):
    """Waits for the Tetris TCP server to become available."""
    write_log(f"⏳ 等待 Tetris TCP server 啟動中 ({ip}:{port})...")
    start_time = time.time()
    while True:
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as test_sock:
                test_sock.settimeout(1.0)
                test_sock.connect((ip, port))
            write_log("✅ Java TCP server 準備完成，連線成功")
            return True # Indicate success
        except socket.error as e:
            if time.time() - start_time > timeout:
                write_log(f"❌ 等待 Java TCP server 超時 ({timeout}s)")
                return False # Indicate failure
            time.sleep(1.0) # Wait a bit longer before retrying

# --- Start Java Server ---
java_process = None # Initialize to None
try:
    write_log("🚀 嘗試啟動 Java Tetris server...")
    jar_file = "TetrisTCPserver_v0.6.jar"
    if not os.path.exists(jar_file):
         write_log(f"❌ 錯誤: 找不到 JAR 檔案 '{jar_file}'。請確保它在工作目錄中。")
         raise FileNotFoundError(f"JAR file '{jar_file}' not found.")

    # Start process, redirect stdout/stderr to DEVNULL if desired to keep console clean
    java_process = subprocess.Popen(
        ["java", "-jar", jar_file],
        stdout=subprocess.DEVNULL, # Optional: hide server stdout
        stderr=subprocess.DEVNULL  # Optional: hide server stderr
    )
    write_log(f"✅ Java server process 啟動 (PID: {java_process.pid})")
    if not wait_for_tetris_server():
        raise TimeoutError("Java server did not become available.") # Raise specific error

except Exception as e:
    write_log(f"❌ 啟動或等待 Java server 時發生錯誤: {e}")
    # Attempt to terminate if process started but failed connection
    if java_process and java_process.poll() is None:
         write_log("   嘗試終止未成功連接的 Java server process...")
         java_process.terminate()
         try:
             java_process.wait(timeout=2)
         except subprocess.TimeoutExpired:
             java_process.kill()
    raise # Re-raise the exception to stop the script

# --- Check GPU ---
if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    write_log(f"✅ PyTorch is using GPU: {device_name}")
    # Optional: Check compute capability if needed
    # cc = torch.cuda.get_device_capability(0)
    # write_log(f"   Compute Capability: {cc[0]}.{cc[1]}")
else:
    write_log("⚠️ PyTorch is using CPU. Training will be significantly slower.")

# ----------------------------
# 定義 Tetris 環境 (採用老師的格式, 結合獎勵機制概念)
# ----------------------------
class TetrisEnv(gym.Env):
    """Custom Environment for Tetris that interacts with a Java TCP server."""
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}
    N_DISCRETE_ACTIONS = 5
    IMG_HEIGHT = 200
    IMG_WIDTH = 100
    IMG_CHANNELS = 3
    RESIZED_DIM = 84

    def __init__(self, host_ip="127.0.0.1", host_port=10612, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        self.action_space = spaces.Discrete(self.N_DISCRETE_ACTIONS)
        self.observation_space = spaces.Box(
            low=0, high=255,
            shape=(1, self.RESIZED_DIM, self.RESIZED_DIM), # (Channels, Height, Width)
            dtype=np.uint8
        )
        self.server_ip = host_ip
        self.server_port = host_port
        self.client_sock = None
        self._connect_socket() # Connect in init

        # Reward shaping & statistics variables
        self.lines_removed = 0
        self.current_height = 0
        self.current_holes = 0
        self.lifetime = 0
        self.last_observation = np.zeros(self.observation_space.shape, dtype=np.uint8)

        # --- Reward Shaping Coefficients (TUNING REQUIRED) ---
        self.reward_line_clear_coeff = 100.0
        self.penalty_height_increase_coeff = 15.0
        self.penalty_hole_increase_coeff = 25.0
        self.penalty_step_coeff = 0.1
        self.penalty_game_over_coeff = 500.0

        # For rendering
        self.window_surface = None
        self.clock = None
        self.is_pygame_initialized = False # Track Pygame init state

    def _initialize_pygame(self):
        """Initializes Pygame if not already done."""
        if not self.is_pygame_initialized and self.render_mode == "human":
            try:
                import pygame
                pygame.init()
                pygame.display.init()
                # Scale window for better visibility
                window_width = self.RESIZED_DIM * 4
                window_height = self.RESIZED_DIM * 4
                self.window_surface = pygame.display.set_mode((window_width, window_height))
                pygame.display.set_caption(f"Tetris Env ({self.server_ip}:{self.server_port})")
                self.clock = pygame.time.Clock()
                self.is_pygame_initialized = True
                write_log("   Pygame initialized for rendering.")
            except ImportError:
                write_log("⚠️ Pygame not installed, cannot use 'human' render mode.")
                self.render_mode = None # Disable human rendering
            except Exception as e:
                write_log(f"⚠️ Error initializing Pygame: {e}")
                self.render_mode = None


    def _connect_socket(self):
        """Establishes connection to the game server."""
        try:
            if self.client_sock:
                self.client_sock.close()
            self.client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.client_sock.settimeout(10.0)
            self.client_sock.connect((self.server_ip, self.server_port))
            # write_log(f"🔌 Socket connected to {self.server_ip}:{self.server_port}") # Less verbose
        except socket.error as e:
            write_log(f"❌ Socket connection error during connect: {e}")
            raise ConnectionError(f"Failed to connect to Tetris server at {self.server_ip}:{self.server_port}")

    def _send_command(self, command: bytes):
        """Sends a command to the server, handles potential errors."""
        if not self.client_sock:
             raise ConnectionError("Socket is not connected. Cannot send command.")
        try:
            self.client_sock.sendall(command)
        except socket.timeout:
            write_log("❌ Socket timeout during send.")
            raise ConnectionAbortedError("Socket timeout during send")
        except socket.error as e:
            write_log(f"❌ Socket error during send: {e}")
            raise ConnectionAbortedError(f"Socket error during send: {e}")

    def _receive_data(self, size):
        """Receives exactly size bytes from the server."""
        if not self.client_sock:
             raise ConnectionError("Socket is not connected. Cannot receive data.")
        data = b""
        try:
            self.client_sock.settimeout(10.0) # Set timeout for recv
            while len(data) < size:
                chunk = self.client_sock.recv(size - len(data))
                if not chunk:
                    write_log("❌ Socket connection broken during receive (received empty chunk).")
                    raise ConnectionAbortedError("Socket connection broken")
                data += chunk
        except socket.timeout:
             write_log(f"❌ Socket timeout during receive (expected {size}, got {len(data)}).")
             raise ConnectionAbortedError("Socket timeout during receive")
        except socket.error as e:
            write_log(f"❌ Socket error during receive: {e}")
            raise ConnectionAbortedError(f"Socket error during receive: {e}")
        return data

    def get_tetris_server_response(self):
        """Gets state update from the Tetris server via socket."""
        try:
            is_game_over_byte = self._receive_data(1)
            is_game_over = (is_game_over_byte == b'\x01')

            removed_lines_bytes = self._receive_data(4)
            removed_lines = int.from_bytes(removed_lines_bytes, 'big')

            height_bytes = self._receive_data(4)
            height = int.from_bytes(height_bytes, 'big')

            holes_bytes = self._receive_data(4)
            holes = int.from_bytes(holes_bytes, 'big')

            img_size_bytes = self._receive_data(4)
            img_size = int.from_bytes(img_size_bytes, 'big')

            if img_size <= 0 or img_size > 1000000: # Increased max size slightly
                 write_log(f"❌ Received invalid image size: {img_size}. Aborting receive.")
                 raise ValueError(f"Invalid image size received: {img_size}")

            img_png = self._receive_data(img_size)

            # Decode and preprocess image
            nparr = np.frombuffer(img_png, np.uint8)
            np_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if np_image is None:
                 write_log("❌ Failed to decode image from server response.")
                 # Return last known state and signal termination
                 return True, self.lines_removed, self.current_height, self.current_holes, self.last_observation.copy()

            resized = cv2.resize(np_image, (self.RESIZED_DIM, self.RESIZED_DIM), interpolation=cv2.INTER_AREA)
            gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
            observation = np.expand_dims(gray, axis=0).astype(np.uint8) # Combine steps

            # Store frames for rendering/observation
            self.last_raw_render_frame = resized.copy() # Store BGR for render
            self.last_observation = observation.copy() # Store processed obs

            return is_game_over, removed_lines, height, holes, observation

        except (ConnectionAbortedError, ConnectionRefusedError, ValueError) as e:
             write_log(f"❌ Connection/Value error getting server response: {e}. Ending episode.")
             # Return last known state and signal termination
             return True, self.lines_removed, self.current_height, self.current_holes, self.last_observation.copy()
        except Exception as e:
            write_log(f"❌ Unexpected error getting server response: {e}. Ending episode.")
            # Return last known state and signal termination
            return True, self.lines_removed, self.current_height, self.current_holes, self.last_observation.copy()


    def step(self, action):
        # --- Send Action ---
        command_map = {
            0: b"move -1\n", 1: b"move 1\n",
            2: b"rotate 0\n", 3: b"rotate 1\n",
            4: b"drop\n"
        }
        command = command_map.get(action)
        if command is None:
            write_log(f"⚠️ Invalid action received: {action}. Sending 'drop'.")
            command = b"drop\n"

        try:
            self._send_command(command)
        except (ConnectionAbortedError, ConnectionError) as e:
            write_log(f"❌ Ending episode due to send failure in step: {e}")
            terminated = True
            observation = self.last_observation.copy()
            reward = self.penalty_game_over_coeff * -1
            info = {'removed_lines': self.lines_removed, 'lifetime': self.lifetime, 'final_status': 'send_error'}
            info['terminal_observation'] = observation
            return observation, reward, terminated, False, info

        # --- Get State Update ---
        terminated, new_lines_removed, new_height, new_holes, observation = self.get_tetris_server_response()

        # --- Calculate Reward ---
        reward = 0.0
        lines_cleared_this_step = new_lines_removed - self.lines_removed
        if lines_cleared_this_step > 0:
            reward += (lines_cleared_this_step ** 2) * self.reward_line_clear_coeff

        height_increase = new_height - self.current_height
        if height_increase > 0:
            reward -= height_increase * self.penalty_height_increase_coeff

        hole_increase = new_holes - self.current_holes
        if hole_increase > 0:
            reward -= hole_increase * self.penalty_hole_increase_coeff

        reward -= self.penalty_step_coeff # Step penalty

        if terminated:
            reward -= self.penalty_game_over_coeff
            # Log only once per game over for clarity
            write_log(f"💔 Game Over! Final Lines: {new_lines_removed}, Lifetime: {self.lifetime + 1}, reward: {reward}")

        # --- Update Internal State ---
        self.lines_removed = new_lines_removed
        self.current_height = new_height
        self.current_holes = new_holes
        self.lifetime += 1

        # --- Prepare Return Values ---
        info = {'removed_lines': self.lines_removed, 'lifetime': self.lifetime}
        truncated = False

        if terminated:
            info['terminal_observation'] = observation.copy()
            # Log final stats here if needed, or use SB3 logger/callback
            # Example: print(f"Episode End: Lines={self.lines_removed}, Lifetime={self.lifetime}, Reward={reward}")


        # Optional: Render on step if requested
        if self.render_mode == "human":
             self.render()

        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)

        for attempt in range(3): # Allow a few attempts to reset/reconnect
            try:
                self._send_command(b"start\n")
                terminated, lines, height, holes, observation = self.get_tetris_server_response()
                if terminated:
                    write_log(f"⚠️ Server reported game over on reset attempt {attempt+1}. Retrying...")
                    if attempt < 2: # Reconnect if not last attempt
                         self._connect_socket()
                         time.sleep(0.5) # Small delay before retry
                         continue # Retry the loop
                    else:
                         write_log("❌ Server still terminated after multiple reset attempts. Cannot proceed.")
                         raise RuntimeError("Tetris server failed to reset properly.")
                # Reset successful
                self.lines_removed = 0
                self.current_height = height
                self.current_holes = holes
                self.lifetime = 0
                self.last_observation = observation.copy()
                # write_log(f"🔄 Environment Reset. Initial state: H={height}, O={holes}") # Less verbose logging
                info = {}
                return observation, info

            except (ConnectionAbortedError, ConnectionError, socket.error, TimeoutError) as e:
                 write_log(f"🔌 Connection issue during reset attempt {attempt+1} ({e}). Retrying...")
                 if attempt < 2:
                      try:
                          self._connect_socket() # Attempt reconnect
                          time.sleep(0.5)
                      except ConnectionError:
                           write_log("   Reconnect failed.")
                           if attempt == 1: # If second attempt also fails, raise
                               raise RuntimeError(f"Failed to reconnect and reset Tetris server after multiple attempts: {e}")
                 else: # Final attempt failed
                     raise RuntimeError(f"Failed to reset Tetris server after multiple attempts: {e}")

        # Should not be reached if logic is correct, but as fallback:
        raise RuntimeError("Failed to reset Tetris server.")


    def render(self):
        self._initialize_pygame() # Ensure pygame is ready if in human mode

        if self.render_mode == "human" and self.is_pygame_initialized:
            import pygame
            if self.window_surface is None:
                 # This should not happen if _initialize_pygame worked, but handle defensively
                 write_log("⚠️ Render called but Pygame window is not initialized.")
                 return

            if hasattr(self, 'last_raw_render_frame'):
                try:
                    # last_raw_render_frame is (H, W, C) BGR from OpenCV
                    render_frame_rgb = cv2.cvtColor(self.last_raw_render_frame, cv2.COLOR_BGR2RGB)
                    # Pygame surface requires (width, height)
                    surf = pygame.Surface((self.RESIZED_DIM, self.RESIZED_DIM))
                    # Transpose needed: (H, W, C) -> (W, H, C) for Pygame surfarray
                    pygame.surfarray.blit_array(surf, np.transpose(render_frame_rgb, (1, 0, 2)))
                    # Scale up to window size
                    surf = pygame.transform.scale(surf, self.window_surface.get_size())
                    self.window_surface.blit(surf, (0, 0))
                    pygame.event.pump() # Process internal Pygame events
                    pygame.display.flip() # Update the full screen surface
                    self.clock.tick(self.metadata["render_fps"]) # Control frame rate
                except Exception as e:
                    write_log(f"⚠️ Error during Pygame rendering: {e}")
                    # Attempt to close pygame gracefully on error?
                    # self.close()

            else:
                # Draw a black screen if no frame available yet
                 self.window_surface.fill((0, 0, 0))
                 pygame.display.flip()

        elif self.render_mode == "rgb_array":
             if hasattr(self, 'last_raw_render_frame'):
                 # Return RGB (H, W, C)
                 return cv2.cvtColor(self.last_raw_render_frame, cv2.COLOR_BGR2RGB)
             else:
                 # Return black frame if no observation yet
                 return np.zeros((self.RESIZED_DIM, self.RESIZED_DIM, 3), dtype=np.uint8)

    def close(self):
        # write_log("🔌 Closing environment connection.") # Less verbose
        if self.client_sock:
            try:
                self.client_sock.close()
            except socket.error as e:
                 write_log(f"   Error closing socket: {e}")
            self.client_sock = None

        if self.is_pygame_initialized:
            try:
                import pygame
                pygame.display.quit()
                pygame.quit()
                self.is_pygame_initialized = False
                # write_log("   Pygame window closed.") # Less verbose
            except Exception as e:
                 write_log(f"   Error closing Pygame: {e}")

# --- Environment Setup ---
write_log("✅ 建立基礎環境函數 make_env...")
def make_env():
    """Helper function to create an instance of the Tetris environment."""
    env = TetrisEnv()
    return env

write_log("✅ 建立向量化環境 (DummyVecEnv)...")
# Use DummyVecEnv for single environment interaction
train_env_base = DummyVecEnv([make_env])

write_log("✅ 包裝環境 (VecFrameStack)...")
# Wrap with VecFrameStack (channel-first is important)
# Use wandb config if available, otherwise use default
n_stack = run.config["n_stack"] if run else 4
train_env_stacked = VecFrameStack(train_env_base, n_stack=n_stack, channels_order="first")

write_log("✅ 包裝環境 (VecNormalize - Rewards Only)...")
# Wrap with VecNormalize, NORMALIZING REWARDS ONLY.
# Use wandb config if available, otherwise use default
gamma = run.config["gamma"] if run else 0.99
train_env = VecNormalize(train_env_stacked, norm_obs=False, norm_reward=True, gamma=gamma)

write_log("   環境建立完成並已包裝 (DummyVecEnv -> VecFrameStack -> VecNormalize)")


# ----------------------------
# DQN Model Setup and Training
# ----------------------------
write_log("🧠 設定 DQN 模型...")
# Use wandb config for hyperparameters if available, otherwise use defaults
policy_type = run.config["policy_type"] if run else "CnnPolicy"
learning_rate = run.config["learning_rate"] if run else 1e-4
buffer_size = run.config["buffer_size"] if run else 100000
learning_starts = run.config["learning_starts"] if run else 10000
batch_size = run.config["batch_size"] if run else 32
tau = 1.0 # Default for DQN
target_update_interval = run.config["target_update_interval"] if run else 10000
gradient_steps = 1 # Default for DQN
exploration_fraction = run.config["exploration_fraction"] if run else 0.1 # Default DQN explore fraction is smaller
exploration_final_eps = run.config["exploration_final_eps"] if run else 0.05

# Define DQN model
model = DQN(
    policy=policy_type,
    env=train_env,
    verbose=1,
    gamma=gamma,
    learning_rate=learning_rate,
    buffer_size=buffer_size,
    learning_starts=learning_starts,
    batch_size=batch_size,
    tau=tau,
    train_freq=(1, "step"), # Train every step
    gradient_steps=gradient_steps,
    target_update_interval=target_update_interval,
    exploration_fraction=exploration_fraction,
    exploration_final_eps=exploration_final_eps,
    policy_kwargs=dict(normalize_images=False), # As per original code
    seed=42, # Set seed for reproducibility
    device="cuda" if torch.cuda.is_available() else "cpu",
    tensorboard_log=f"/kaggle/working/runs/{run_id}" if wandb_enabled else None # Log TB only if wandb enabled
)
write_log(f"   模型建立完成. Device: {model.device}")
if run: write_log(f"   使用 Wandb 超參數: {run.config}")
else: write_log("   使用默認超參數 (Wandb 未啟用).")


# Setup Wandb callback if enabled
if wandb_enabled:
    wandb_callback = WandbCallback(
        gradient_save_freq=10000,
        model_save_path=f"/kaggle/working/models/{run_id}",
        model_save_freq=50000,
        log="all",
        verbose=2
    )
    callback_list = [wandb_callback]
else:
    callback_list = None # No callback if wandb is disabled

# --- Training ---
write_log(f"🚀 開始訓練 {TOTAL_TIMESTEPS} 步...")
training_successful = False
try:
    model.learn(
        total_timesteps=TOTAL_TIMESTEPS,
        callback=callback_list,
        log_interval=10 # Log basic stats every 10 episodes
    )
    write_log("✅ 訓練完成!")
    training_successful = True
except Exception as e:
     write_log(f"❌ 訓練過程中發生錯誤: {e}", exc_info=True) # Log exception info
     # Save model before exiting if error occurs mid-training
     error_save_path = f'/kaggle/working/{STUDENT_ID}_dqn_error_save.zip'
     try:
        model.save(error_save_path)
        write_log(f"   模型已嘗試儲存至 {error_save_path}")
        if wandb_enabled: wandb.save(error_save_path)
     except Exception as save_e:
         write_log(f"   ❌ 儲存錯誤模型時也發生錯誤: {save_e}")
     if run: run.finish(exit_code=1, quiet=True) # Finish wandb run with error code

# --- Save Final Model (only if training completed successfully) ---
if training_successful:
    stats_path = f"/kaggle/working/vecnormalize_stats_{run_id}.pkl"
    final_model_name = f'{STUDENT_ID}_dqn_final_{run_id}.zip'
    final_model_path = os.path.join("/kaggle/working", final_model_name)

    try:
        train_env.save(stats_path)
        write_log(f"   VecNormalize 統計數據已儲存至 {stats_path}")
        if wandb_enabled: wandb.save(stats_path)

        model.save(final_model_path)
        write_log(f"✅ 最終模型已儲存: {final_model_path}")
        display(FileLink(final_model_path))
        if wandb_enabled: wandb.save(final_model_path)

    except Exception as e:
        write_log(f"❌ 儲存最終模型或統計數據時出錯: {e}")
        training_successful = False # Mark as unsuccessful if saving fails


# ----------------------------
# Evaluation (only if training and saving were successful)
# ----------------------------
if training_successful:
    write_log("\n🧪 開始評估訓練後的模型...")

    # Create a separate evaluation environment
    try:
        eval_env_base = DummyVecEnv([make_env])

        # Wrap with FrameStack FIRST, same as training
        n_stack_eval = run.config["n_stack"] if run else 4
        eval_env_stacked = VecFrameStack(eval_env_base, n_stack=n_stack_eval, channels_order="first")

        # Load the SAME VecNormalize statistics
        eval_env = VecNormalize.load(stats_path, eval_env_stacked)
        eval_env.training = False
        eval_env.norm_reward = False # IMPORTANT: দেখতে আসল reward

        write_log("   評估環境建立成功.")

    except FileNotFoundError:
        write_log(f"❌ 錯誤: VecNormalize 統計文件未找到於 {stats_path}。跳過評估。")
        eval_env = None
    except Exception as e:
        write_log(f"❌ 建立評估環境時出錯: {e}")
        eval_env = None

    if eval_env is not None:
        # --- Run Evaluation Episodes ---
        num_eval_episodes = 5
        total_rewards = []
        total_lines = []
        total_lifetimes = []
        all_frames = []

        try:
            for i in range(num_eval_episodes):
                obs = eval_env.reset()
                done = False
                episode_reward = 0
                episode_lines = 0
                episode_lifetime = 0
                frames = []
                last_info = {}

                while not done:
                    # Render base env for GIF
                    try:
                         base_env = eval_env.get_attr("envs")[0].env
                         raw_frame = base_env.render(mode="rgb_array")
                         if i == 0: frames.append(raw_frame) # Only for first ep
                    except Exception as render_err:
                         write_log(f"⚠️ 評估時獲取渲染幀出錯: {render_err}")

                    # Predict and step
                    action, _ = model.predict(obs, deterministic=True)
                    obs, reward, done, infos = eval_env.step(action)

                    episode_reward += reward[0]
                    last_info = infos[0]
                    # Use .get() for safety, default to previous value if key missing
                    episode_lines = last_info.get('removed_lines', episode_lines)
                    episode_lifetime = last_info.get('lifetime', episode_lifetime)
                    done = done[0]

                write_log(f"   評估 Episode {i+1}: Reward={episode_reward:.2f}, Lines={episode_lines}, Steps={episode_lifetime}")
                total_rewards.append(episode_reward)
                total_lines.append(episode_lines)
                total_lifetimes.append(episode_lifetime)
                if i == 0: all_frames = frames

            write_log(f"--- 評估結果 ({num_eval_episodes} episodes) ---")
            mean_reward = np.mean(total_rewards)
            std_reward = np.std(total_rewards)
            mean_lines = np.mean(total_lines)
            std_lines = np.std(total_lines)
            mean_lifetime = np.mean(total_lifetimes)
            std_lifetime = np.std(total_lifetimes)

            write_log(f"   平均 Reward: {mean_reward:.2f} +/- {std_reward:.2f}")
            write_log(f"   平均 Lines: {mean_lines:.2f} +/- {std_lines:.2f}")
            write_log(f"   平均 Steps: {mean_lifetime:.2f} +/- {std_lifetime:.2f}")

            # Log evaluation metrics to Wandb
            if wandb_enabled:
                wandb.log({
                    "eval/mean_reward": mean_reward, "eval/std_reward": std_reward,
                    "eval/mean_lines": mean_lines, "eval/std_lines": std_lines,
                    "eval/mean_lifetime": mean_lifetime, "eval/std_lifetime": std_lifetime,
                })

            # --- Generate Replay GIF ---
            if all_frames:
                gif_path = f'/kaggle/working/replay_eval_{run_id}.gif'
                write_log(f"💾 正在儲存評估回放 GIF 至 {gif_path}...")
                try:
                    imageio.mimsave(gif_path, [np.array(frame).astype(np.uint8) for frame in all_frames], fps=15, loop=0)
                    write_log("   GIF 儲存成功.")
                    display(FileLink(gif_path))
                    if wandb_enabled: wandb.log({"eval/replay": wandb.Video(gif_path, fps=15, format="gif")})
                except Exception as e: write_log(f"   ❌ 儲存 GIF 時發生錯誤: {e}")
            else: write_log("   ⚠️ 未能儲存 GIF (沒有收集到幀).")

            # --- Save Evaluation Results CSV ---
            csv_filename = f'tetris_evaluation_scores_{run_id}.csv'
            csv_path = os.path.join("/kaggle/working", csv_filename)
            try:
                with open(csv_path, 'w') as fs:
                    fs.write('episode_id,removed_lines,played_steps,reward\n')
                    if total_lines: # Ensure lists are not empty
                        fs.write(f'eval_0,{total_lines[0]},{total_lifetimes[0]},{total_rewards[0]:.2f}\n')
                    fs.write(f'eval_avg,{mean_lines:.2f},{mean_lifetime:.2f},{mean_reward:.2f}\n')
                write_log(f"✅ 評估分數 CSV 已儲存: {csv_path}")
                display(FileLink(csv_path))
                if wandb_enabled: wandb.save(csv_path)
            except Exception as e: write_log(f"   ❌ 儲存 CSV 時發生錯誤: {e}")

        except Exception as eval_e:
            write_log(f"❌ 評估迴圈中發生錯誤: {eval_e}", exc_info=True)

        finally:
             # Ensure evaluation env is closed even if errors occur
             if eval_env:
                 eval_env.close()
                 write_log("   評估環境已關閉.")

# --- Cleanup ---
write_log("🧹 清理環境...")
if 'train_env' in locals() and train_env: # Check if train_env exists
    train_env.close()
    write_log("   訓練環境已關閉.")
# Close the Java server process
if java_process and java_process.poll() is None:
     write_log("   正在終止 Java server process...")
     java_process.terminate()
     try:
         java_process.wait(timeout=5)
         write_log("   Java server process 已終止.")
     except subprocess.TimeoutExpired:
         write_log("   Java server 未能在 5 秒內終止, 強制結束...")
         java_process.kill()
         write_log("   Java server process 已強制結束.")
elif java_process and java_process.poll() is not None:
     write_log("   Java server process 已自行結束.")
else:
     write_log("   Java server process 未啟動或已關閉.")


# Finish the Wandb run if it was initialized and training didn't crash early
if run:
    if training_successful:
         run.finish()
         write_log("✨ Wandb run finished.")
    else:
         # Run might have already been finished in the exception handler
         if run.is_running:
              run.finish(exit_code=1) # Ensure it's marked as failed
         write_log("✨ Wandb run finished (marked as failed due to error).")

write_log("🏁 腳本執行完畢.")