## 注意事項
要先調整 PROJECT_ROOT = Path('/Users/Jer_ry/Desktop/scripts') 

In [1]:
import os
import inspect
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from pathlib import Path
import csv
import random 
import pdb
import os
import pandas as pd
from queue import *
from scipy.stats import poisson
from collections import defaultdict

In [34]:
# Root folder of the project; change this path to match your own machine before running.
PROJECT_ROOT = Path('/Users/Jer_ry/Desktop/scripts')
def generate_social_score(fname='sim1', n_agents=3,
                          MIN=-30, MAX=0, MEAN=-10, SD=10):
    """Generate a random social-score matrix and save it as a CSV file.

    Integer scores are drawn uniformly from ``[MIN, MAX]`` (inclusive), the
    diagonal is set to 0, and the whole matrix (diagonal included) is then
    standardised and rescaled to mean ``MEAN`` and standard deviation ``SD``.
    The result is written to ``SIM_ROOT/log/social_score/<fname>.csv`` with
    the 1-based agent numbers as row and column labels.

    Args:
        fname: File name (without extension) of the CSV to write.
        n_agents: Number of agents; the matrix is ``n_agents x n_agents``.
        MIN: Smallest raw integer score (inclusive).
        MAX: Largest raw integer score (inclusive).
        MEAN: Mean of the rescaled matrix.
        SD: Standard deviation of the rescaled matrix.

    Returns:
        pandas.DataFrame holding the rescaled score matrix.
    """
    # Where to save.
    csv_dir = SIM_ROOT / 'log' / 'social_score'
    csv_dir.mkdir(parents=True, exist_ok=True)
    out_path = csv_dir / f'{fname}.csv'

    # Raw integer matrix with a zero diagonal.
    raw = np.random.randint(MIN, MAX + 1, (n_agents, n_agents))
    np.fill_diagonal(raw, 0)

    # Standardise, guarding against a constant matrix (e.g. n_agents == 1),
    # where the standard deviation is 0 and the division would yield NaN.
    spread = raw.std()
    if spread > 0:
        z = (raw - raw.mean()) / spread
    else:
        z = np.zeros(raw.shape, dtype=float)
    score = z * SD + MEAN
    # NOTE(review): after rescaling the diagonal is generally no longer 0 —
    # confirm that a non-zero self-score is intended.

    # Row/column labels are the 1-based agent numbers.
    labels = [str(i + 1) for i in range(n_agents)]
    df = pd.DataFrame(score, columns=labels, index=labels)

    df.to_csv(out_path)
    print(f'✓ social score saved -> {out_path}')
    return df

## Simulation data dynamic
- 如果要一次跑多個 trial 的話也要回去修改


### 主函式
Parameters:
- fname_csv (str): 要讀取的 social score CSV 檔名（不含副檔名）。
- RANDOM_NUM_GOALS (bool): If True, the number of goals will vary across mazes.
- VERSION_NAME (str): 模擬版本名稱，決定輸出目錄名稱
- STEP_TOTAL (int): 每個迷宮模擬最多執行幾步
- AGENT_NAME (str): 產生的文字檔開頭名稱
- SET_UP_MAZE_TOTAL (int): 要產生幾個迷宮
- PRINT_FINAL_MAZE (bool):是否印出最後一個迷宮狀態的文字檔

Outputs:
- 初始迷宮的 Text files：<AGENT_NAME>_<maze_total>.txt
- 最終迷宮的 Text files（僅在 PRINT_FINAL_MAZE=True 時輸出）: <AGENT_NAME>_<maze_total>FINAL.txt
- 'summary.csv': 總結模擬結果
    - maze: 第幾次模擬。
    - steps: 最後穩定時的步數。
    - maze_final_binary: binary string，每個 bit 表示一對 agent 最後是否相鄰；共 N(N-1)/2 個 bit（5 個 agent 時為 10-bit）。
    - maze_final_decimal: 將上面 binary 轉為十進位數字（5 個 agent 時為 0~1023）。
    - center_agent: 中心 agent 是誰。
    - flag_1 / flag_2: 是否由這個條件終止
- 所有輸出檔案都存放在 SIM_ROOT/<move_type>/ 之下（SIM_ROOT = PROJECT_ROOT/data/<FNAME>；intermediate 模式會再依 case1 / case2 分子資料夾）。

In [35]:
# ========= 24x24 inner room (agents never cross its walls) =========
INNER_MIN = 7             # smallest world coordinate inside the inner room
INNER_MAX = 30            # largest world coordinate inside the inner room
COORD_OFFSET = 6          # world coordinate minus this offset = on-disk coordinate
CENTER_X = CENTER_Y = 19  # centre cell of the 24x24 inner room


def clamp_world(pos):
    """Return a copy of *pos* with x and y clamped to the inner room.

    Args:
        pos: Array-like whose last axis holds ``(x, y)`` world coordinates.

    Returns:
        A new float ndarray with components 0 and 1 of the last axis limited
        to ``[INNER_MIN, INNER_MAX]``. The input is never modified.
    """
    # np.array (not np.asarray) always copies; asarray would alias a float
    # ndarray and the clamping below would then overwrite the caller's data.
    clamped = np.array(pos, dtype=float)
    clamped[..., 0] = np.clip(clamped[..., 0], INNER_MIN, INNER_MAX)
    clamped[..., 1] = np.clip(clamped[..., 1], INNER_MIN, INNER_MAX)
    return clamped
def world_to_disk(pos):
    """Convert world coordinates (7..30) to on-disk coordinates (1..24).

    The offset COORD_OFFSET is subtracted and the x/y components (indices 0
    and 1 of the last axis) are then limited to the range 1..24.
    """
    shifted = np.asarray(pos, dtype=float) - COORD_OFFSET
    for axis in (0, 1):
        shifted[..., axis] = np.clip(shifted[..., axis], 1, 24)
    return shifted


def disk_to_world(pos):
    """Convert on-disk coordinates (1..24) back to world coordinates (7..30)."""
    disk_coords = np.asarray(pos, dtype=float)
    return disk_coords + COORD_OFFSET

def assert_world(pos, msg=''):
    """Assert that all x and y world coordinates in *pos* lie in the inner room.

    Raises AssertionError ('X OOR ...' or 'Y OOR ...') when a coordinate falls
    outside ``[INNER_MIN, INNER_MAX]``; *msg* is appended to the message.
    """
    coords = np.asarray(pos)

    xs = coords[..., 0]
    x_inside = (xs >= INNER_MIN) & (xs <= INNER_MAX)
    assert x_inside.all(), f'X OOR {msg}'

    ys = coords[..., 1]
    y_inside = (ys >= INNER_MIN) & (ys <= INNER_MAX)
    assert y_inside.all(), f'Y OOR {msg}'

# def build_base_maze_layout():
    # return [
    #     '##########################',
    #     '#                        #',
    #     '#                        #',
    #     '#                        #',
    #     '#                        #',
    #     '#                        #',
    #     '#     ##############     #',
    #     '#     #            #     #',
    #     '#     #            #     #',
    #     '#     #            #     #',
    #     '#     #            #     #',
    #     '#     #            #     #',
    #     '#     #            #     #',
    #     '#     #            #     #',
    #     '#     #            #     #',
    #     '#     #            #     #',
    #     '#     #            #     #',
    #     '#     #            #     #',
    #     '#     #            #     #',
    #     '#     ##############     #',
    #     '#                        #',
    #     '#                        #',
    #     '#                        #',
    #     '#                        #',
    #     '#                        #',
    #     '##########################',
    # ]
def build_base_maze_layout_38():
    """Return the empty 38x38 ASCII maze as a list of 38 strings.

    The outer frame is a wall ('#'); a 5-cell-wide corridor surrounds a walled
    24x24 inner room located at rows/columns 7..30.
    """
    size, corridor_width, room_size = 38, 5, 24

    gap = ' ' * corridor_width
    solid_row = '#' * size
    corridor_row = '#' + ' ' * (size - 2) + '#'
    room_wall_row = '#' + gap + '#' * (room_size + 2) + gap + '#'
    room_row = '#' + gap + '#' + ' ' * room_size + '#' + gap + '#'

    return (
        [solid_row]
        + [corridor_row] * corridor_width
        + [room_wall_row]
        + [room_row] * room_size
        + [room_wall_row]
        + [corridor_row] * corridor_width
        + [solid_row]
    )

def game_art(center_agent, n_move_agents, n_chosen_agents, Chosen_agents_label, Chosen_agents_index):
    """Draw random start positions and stamp the agent letters on a fresh maze.

    ``n_move_agents`` distinct inner-room cells (never the centre cell) are
    sampled, then the centre cell is inserted at list position
    ``center_agent`` so that this agent starts at the centre.

    Returns:
        (x_random, y_random, GAME_ART): the start x and y coordinates (world
        coordinates, one entry per agent) and a one-element list holding the
        38 maze rows with the agent labels written in.
    """
    layout = build_base_maze_layout_38()
    GAME_ART = [layout]
    H = len(layout); W = len(layout[0])
    assert H == 38 and W == 38

    # Every inner-room cell except the centre is a candidate start cell.
    inner = range(INNER_MIN, INNER_MAX + 1)
    free_cells = [(x, y) for x in inner for y in inner
                  if (x, y) != (CENTER_X, CENTER_Y)]
    picks = random.sample(free_cells, n_move_agents)

    x_random = [cell[0] for cell in picks]
    y_random = [cell[1] for cell in picks]
    x_random.insert(center_agent, CENTER_X)
    y_random.insert(center_agent, CENTER_Y)

    # Sanity check: all start positions are valid world coordinates.
    assert_world(np.c_[x_random, y_random], 'init start')

    # Overwrite the maze cells with the agents' letters.
    for k in range(n_chosen_agents):
        col, row_idx = int(x_random[k]), int(y_random[k])
        line = layout[row_idx]
        letter = Chosen_agents_label[Chosen_agents_index[k]]
        layout[row_idx] = line[:col] + letter + line[col + 1:]

    return x_random, y_random, GAME_ART



def final_maze(all_agent_place, Chosen_agents_label, Chosen_agents_index, n_chosen_agents):
    """Stamp the agents' last positions on a fresh 38x38 maze.

    The last ``n_chosen_agents`` rows of *all_agent_place* are used; they must
    be the original world coordinates (i.e. without the offset subtracted).

    Returns:
        A one-element list holding the 38 maze rows with the labels written in.
    """
    layout = build_base_maze_layout_38()
    GAME_FINAL = [layout]
    H, W = len(layout), len(layout[0])
    assert H == 38 and W == 38, f'maze layout is {H}x{W}, expected 38x38'

    # Copy the trailing rows (one per agent) and truncate them to integers.
    final_agent_place = np.zeros((n_chosen_agents, 2))
    for slot in range(n_chosen_agents):
        offset = slot - n_chosen_agents          # -n .. -1
        final_agent_place[offset] = all_agent_place[offset,]
    final_agent_place = final_agent_place.astype(int)

    for k in range(n_chosen_agents):
        x, y = final_agent_place[k, 0], final_agent_place[k, 1]
        assert 0 <= y < H and 0 <= x < W, f"final pos out of map: {(x,y)} for HxW={H}x{W}"
        line = layout[y]
        layout[y] = line[:x] + Chosen_agents_label[Chosen_agents_index[k]] + line[x+1:]
    return GAME_FINAL


In [36]:
def Center_Agent(fix_central_index, maze_total, n_agents):
    """Pick the index of the agent that starts at the centre for this maze.

    When *fix_central_index* is non-empty its entries are used in rotation;
    otherwise all ``n_agents`` agents take turns (``maze_total % n_agents``).
    """
    if not fix_central_index:
        return maze_total % n_agents
    rotation = maze_total % len(fix_central_index)
    return fix_central_index[rotation]
## Maze initialisation
# The maze is a 38x38 grid whose centre holds a 24x24 open inner room.
        

def inputs(agent_place, x_random, y_random, social_reward, n_chosen_agents):
    """Initialise positions if needed and compute pairwise distances and forces.

    Args:
        agent_place: (n, 2) array of current (x, y) positions. If every entry
            is 0 it is filled in place from *x_random* / *y_random*.
        x_random, y_random: Start x and y coordinates, one per agent.
        social_reward: (n, n) social-score matrix.
        n_chosen_agents: Number of agents ``n``.

    Returns:
        (agent_place, distance, force) where ``distance[i][j]`` is the
        Euclidean distance between agents i and j and
        ``force = social_reward / (distance**2 + 1e-6)``.
    """
    # An all-zero array means "not initialised yet". The previous test,
    # `np.all(agent_place) == 0`, was true whenever *any* coordinate was 0.
    if np.all(np.asarray(agent_place) == 0):
        for i in range(n_chosen_agents):
            agent_place[i] = np.array([x_random[i], y_random[i]])

    # Pairwise Euclidean distances, computed for all pairs at once.
    pts = np.asarray(agent_place, dtype=float)[:n_chosen_agents, :2]
    delta = pts[:, None, :] - pts[None, :, :]
    distance = np.hypot(delta[..., 0], delta[..., 1])

    dist_sq = np.square(distance) + 1e-6        # epsilon avoids division by 0
    force   = social_reward / dist_sq
    return agent_place, distance, force


In [37]:
# The eight one-cell steps (dx, dy) to the neighbouring cells; the middle
# entry (0, 0), i.e. "stay", is deliberately left out.
DIR8 = [(-1,-1), (0,-1), (1,-1),
        (-1, 0),          (1, 0),
        (-1, 1), (0, 1), (1, 1)]
def move(agent_place, force, n_chosen_agents):
    """
    Move every agent one cell according to the net social force acting on it.

    1) Compute each agent's net force vector ``force_vector_sum``.
    2) Normalise it; a component >= 0.5 (<= -0.5) of the unit vector becomes
       a step of +1 (-1) along that axis, otherwise 0.
    3) Propose a new position for every agent, limited to the inner room
       [INNER_MIN, INNER_MAX].
    4) If only one agent proposes a cell it moves there; if several do, the
       one with the largest net force moves and the others stay where they are.

    Returns the new (n, 2) position array (clamped to the inner room).
    """
    # (1) Net force: the contribution of agent j on agent i is stored in
    #     columns 2*j (x) and 2*j+1 (y) of row i.
    force_vector = np.zeros((n_chosen_agents, 2 * n_chosen_agents))
    for i in range(n_chosen_agents):
        for j in range(n_chosen_agents):
            if i == j: continue
            dx = agent_place[j,0] - agent_place[i,0]
            dy = agent_place[j,1] - agent_place[i,1]
            dist = np.hypot(dx, dy)
            if dist > 0:
                # NOTE(review): `force` is divided by dist**2 again here; if the
                # caller already passes social_reward / dist**2 the net effect
                # is 1/dist**4 — confirm this is intended.
                f = force[i,j] / (dist**2 + 1e-6)
                force_vector[i,2*j  ] = (dx/dist) * f
                force_vector[i,2*j+1] = (dy/dist) * f

    force_vector_sum = force_vector.reshape(n_chosen_agents, n_chosen_agents, 2).sum(axis=1)

    # (2) Turn the net force into a discrete step per axis (-1 / 0 / +1).
    move_vector = np.zeros_like(force_vector_sum, dtype=int)
    norms = np.linalg.norm(force_vector_sum, axis=1)
    for i in range(n_chosen_agents):
        # Divide by 1 instead of 0 when there is no net force (agent stays).
        ux, uy = force_vector_sum[i] / (norms[i] if norms[i]>0 else 1)
        move_vector[i,0] =  1 if ux >=  0.5 else (-1 if ux <= -0.5 else 0)
        move_vector[i,1] =  1 if uy >=  0.5 else (-1 if uy <= -0.5 else 0)

    # (3) Proposed new positions, limited to the inner room.
    proposed = []
    for i in range(n_chosen_agents):
        x0, y0 = agent_place[i]
        x1 = min(INNER_MAX, max(INNER_MIN, x0 + move_vector[i,0]))
        y1 = min(INNER_MAX, max(INNER_MIN, y0 + move_vector[i,1]))
        proposed.append((x1, y1))

    # (4) Conflict resolution: group the agents by the cell they propose.
    buckets = defaultdict(list)
    for i, pos in enumerate(proposed):
        buckets[pos].append(i)

    new_place = agent_place.copy()
    for pos, idxs in buckets.items():
        if len(idxs) == 1:
            # No conflict: move directly.
            i = idxs[0]
            new_place[i] = pos
        else:
            # Conflict: the agent with the largest net force moves, the rest stay.
            # NOTE(review): a loser that is already standing on `pos` stays
            # there too, so two agents can end up on the same cell — confirm.
            strengths = [norms[i] for i in idxs]
            winner = idxs[np.argmax(strengths)]
            new_place[winner] = pos
            # The other indices keep their current position from agent_place.

    return clamp_world(new_place)

def move_random(agent_place, n_chosen_agents, n_sim_agents_for_poisson_mu):
    """Move a random subset of agents one cell in a random cardinal direction.

    The number of moving agents is drawn from a Poisson distribution with mean
    ``n_sim_agents_for_poisson_mu / 2`` and limited to ``[1, n_chosen_agents]``.
    A move is dropped when it would leave the inner room without changing the
    position or would enter the centre cell. Agents that would end up on the
    same cell are sent back to where they started this step.

    Args:
        agent_place: (n, 2) array-like of current (x, y) world positions.
        n_chosen_agents: Number of agents.
        n_sim_agents_for_poisson_mu: Twice the Poisson mean used for the
            number of moving agents.

    Returns:
        New (n, 2) float array of positions; the input is not modified.
    """
    start = np.array(agent_place, dtype=float)
    proposed = start.copy()

    mu_val = n_sim_agents_for_poisson_mu / 2
    num_moving = 0
    if n_chosen_agents > 0:
        num_moving = min(max(1, poisson.rvs(mu=mu_val)), n_chosen_agents)
    if num_moving == 0:
        return start

    moving_agent_indices = np.random.choice(n_chosen_agents, num_moving, replace=False)

    # Direction code -> (dx, dy): 0 = up, 1 = right, 2 = down, 3 = left.
    steps = ((0, -1), (1, 0), (0, 1), (-1, 0))
    for agent_idx in moving_agent_indices:
        dx, dy = steps[np.random.randint(0, 4)]
        current_x, current_y = start[agent_idx]

        clamped_x = int(round(min(INNER_MAX, max(INNER_MIN, current_x + dx))))
        clamped_y = int(round(min(INNER_MAX, max(INNER_MIN, current_y + dy))))

        is_target_center = (clamped_x == CENTER_X and clamped_y == CENTER_Y)
        stays_put = (clamped_x == current_x and clamped_y == current_y)
        if not is_target_center and not stays_put:
            proposed[agent_idx] = [clamped_x, clamped_y]

    # Prevent two agents from sharing a cell. Sending a pair back can itself
    # create a new clash (an agent may have moved onto a start cell that is
    # now re-occupied), so repeat until a full pass changes nothing. Each
    # change returns at least one more agent to its start, so this terminates.
    changed = True
    while changed:
        changed = False
        for i in range(n_chosen_agents):
            for j in range(i + 1, n_chosen_agents):
                if not np.array_equal(proposed[i], proposed[j]):
                    continue
                if (np.array_equal(proposed[i], start[i])
                        and np.array_equal(proposed[j], start[j])):
                    continue  # both are already at their start; nothing to undo
                proposed[i] = start[i]
                proposed[j] = start[j]
                changed = True

    # The statements that used to follow this return were the body of the
    # commented-out `move_logic_random` prototype (one agent moves vertically,
    # two horizontally); they were unreachable here and have been removed.
    return clamp_world(proposed)


def move_logic(positions, direc_case):
    """Apply one of twelve fixed movement rules to three agents.

    Each rule assigns one direction per agent (L = x-1, R = x+1, U = y-1,
    D = y+1). An unknown *direc_case* moves nobody. The result is limited to
    [INNER_MIN, INNER_MAX].

    Raises:
        ValueError: if *positions* does not have shape (3, 2).
    """
    if positions.shape != (3, 2):
        raise ValueError("此邏輯只適用於 3 個 agents。")

    # direc_case -> direction letter for agent 0, 1 and 2.
    rule_table = {
        1: 'LLD', 2: 'DLL', 3: 'LDL',
        4: 'RRD', 5: 'DRR', 6: 'RDR',
        7: 'LRU', 8: 'ULR', 9: 'URL',
        10: 'LUR', 11: 'RUL', 0: 'RLU',
    }

    new_positions = positions.copy()
    for agent, letter in enumerate(rule_table.get(direc_case, '')):
        axis = 0 if letter in 'LR' else 1        # L/R change x, U/D change y
        if letter in 'RD':
            new_positions[agent, axis] += 1
        else:
            new_positions[agent, axis] -= 1

    # Boundary condition: keep everybody inside the inner room.
    return np.clip(new_positions, INNER_MIN, INNER_MAX)

def move_intermediate_step(pos, primary_idx, case_type, state):
    """Advance the three-agent "intermediate" rule by one step.

    Steps alternate between two phases stored in ``state['phase']``:

    * ``'AB'``: the two primary agents each take a random step from DIR8; the
      steps are remembered in ``state['dA']`` / ``state['dB']``.
    * ``'C'``: the remaining (secondary) agent reacts. With ``case_type == 1``
      it moves by the average of the two remembered steps, rounded half away
      from zero to -1/0/+1 per axis; with ``case_type == 2`` it jumps to the
      rounded midpoint of the two primary agents.

    Args:
        pos: (3, 2) array of current (x, y) world positions.
        primary_idx: The two distinct indices (out of 0, 1, 2) of the primary agents.
        case_type: 1 or 2 (only checked during phase 'C').
        state: Dict carrying 'phase', 'dA' and 'dB' between calls; it is
            updated in place.

    Returns:
        (new_pos, state): the clamped new positions and the updated state.

    Raises:
        ValueError: for a wrong shape, invalid primary_idx, case_type or phase.
    """
    if pos.shape != (3,2):
        raise ValueError('move_intermediate_step 僅支援 3 個 agents')
    # Duplicates such as (1, 1) must be rejected as well: they would make the
    # "secondary" index computed below wrong or out of range.
    if (len(primary_idx) != 2
            or any(i not in (0,1,2) for i in primary_idx)
            or primary_idx[0] == primary_idx[1]):
        raise ValueError('primary_idx must be two of (0,1,2)')

    sec_idx = 3 - sum(primary_idx)      # the one index not in primary_idx
    new_pos = pos.astype(float).copy()
    phase = state.get('phase', 'AB')

    if phase == 'AB':
        dA = np.array(random.choice(DIR8), dtype=float)
        dB = np.array(random.choice(DIR8), dtype=float)
        new_pos[primary_idx[0]] += dA
        new_pos[primary_idx[1]] += dB
        state['dA'], state['dB'], state['phase'] = dA, dB, 'C'

    elif phase == 'C':
        if case_type == 1:
            # A missing or None step (e.g. a state that starts in phase 'C')
            # counts as "no movement".
            dA = state.get('dA')
            dB = state.get('dB')
            if dA is None:
                dA = np.array([0,0], dtype=float)
            if dB is None:
                dB = np.array([0,0], dtype=float)
            avg = (dA + dB) / 2.0                   # in [-1,1]
            dC  = (np.sign(avg) * np.floor(np.abs(avg) + 0.5)).astype(int)  # {-1,0,1}
            new_pos[sec_idx] += dC
        elif case_type == 2:
            target = np.round((new_pos[primary_idx[0]] + new_pos[primary_idx[1]]) / 2.0).astype(int)
            new_pos[sec_idx] = target
        else:
            raise ValueError('case_type 必須是 1 或 2')

        state['phase'] = 'AB'

    else:
        raise ValueError(f"未知 phase: {phase}")

    # Safety clamp: never leave the inner room.
    return clamp_world(new_pos), state



In [131]:
import numpy as np
import random
from collections import defaultdict
from scipy.stats import poisson
def simulation_data_dynamic(fname_csv='dS103_test',
                            RANDOM_NUM_GOALS=False,
                            VERSION_NAME='sim1',
                            STEP_TOTAL=20,
                            AGENT_NAME="dS_test",
                            N_AGENTS=3,
                            SET_UP_MAZE_TOTAL=1000,
                            PRINT_FINAL_MAZE=False,
                            move_type = 'normal'
                            ):
    """Simulate many mazes and write one trajectory file per maze plus a summary.

    For every maze the agents get random start cells, are moved for at most
    ``STEP_TOTAL`` steps with the selected rule, and the whole trajectory is
    written (in on-disk coordinates) to
    ``SIM_ROOT/<move_type>/[case<k>/]<AGENT_NAME>_<maze>.txt``. After the last
    maze ``SIM_ROOT/<move_type>/summary.csv`` is written.

    Args:
        fname_csv: Name (without extension) of the social-score CSV in
            ``SIM_ROOT/log/social_score``.
        RANDOM_NUM_GOALS: Not used; kept for backward compatibility.
        VERSION_NAME: Not used; kept for backward compatibility.
        STEP_TOTAL: Maximum number of steps per maze.
        AGENT_NAME: Prefix of the generated text files.
        N_AGENTS: Number of agents.
        SET_UP_MAZE_TOTAL: Number of mazes to generate; must be even.
        PRINT_FINAL_MAZE: Also write ``<AGENT_NAME>_<maze>FINAL.txt`` with the
            final maze state.
        move_type: 'rulemap', 'random' or 'intermediate' (case-insensitive).
            The default 'normal' is not a valid rule and raises ValueError.

    Raises:
        ValueError: if *move_type* is unknown.
        FileNotFoundError: if the social-score CSV does not exist.
    """
    # Validate first so that a bad call creates no directories or files.
    move_type = move_type.lower()
    if move_type not in ('rulemap', 'random', 'intermediate'):
        raise ValueError(f"unknown move_type: {move_type}")
    assert SET_UP_MAZE_TOTAL % 2 == 0, "SET_UP_MAZE_TOTAL 必須是偶數，才能各產生一半的 case1/case2"
    half_total = SET_UP_MAZE_TOTAL // 2

    # Output locations: SIM_ROOT / rulemap | random | intermediate
    DIR_TXT_OUTPUT = SIM_ROOT / move_type
    DIR_TXT_OUTPUT.mkdir(parents=True, exist_ok=True)
    FILE_CSV_SUMMARY = DIR_TXT_OUTPUT / 'summary.csv'

    # Social-score matrix.
    csv_file = SIM_ROOT / 'log' / 'social_score' / f'{fname_csv}.csv'
    if not csv_file.exists():
        raise FileNotFoundError(csv_file)
    social_reward = pd.read_csv(csv_file, index_col=0).to_numpy(float)

    n_chosen_agents = N_AGENTS
    n_move_agents = n_chosen_agents - 1
    Chosen_agents_index = list(range(n_chosen_agents))
    Chosen_agents_label = ['S'] + [chr(64 + i) for i in range(1, n_chosen_agents)]

    # Social reward check: an agent whose scores are all 0 never moves on its
    # own, so such agents are the ones pinned at the centre.
    fix_central_index = [i for i in range(n_chosen_agents)
                         if sum(abs(social_reward[i])) == 0]

    # Print full arrays (no "..." summarisation) when writing trajectories.
    np.set_printoptions(threshold=np.inf, linewidth=200)

    summary_frames = []     # one single-row DataFrame per maze
    maze_total = 0
    while maze_total < SET_UP_MAZE_TOTAL:
        # ---- new maze ----
        flag_1 = 0
        flag_2 = 0

        center_agent = Center_Agent(fix_central_index, maze_total, n_chosen_agents)
        x_random, y_random, GAME_ART = game_art(center_agent, n_move_agents,
                                                n_chosen_agents, Chosen_agents_label,
                                                Chosen_agents_index)
        agent_place = np.zeros((n_chosen_agents, 2))
        agent_place, _, _ = inputs(agent_place, x_random, y_random,
                                   social_reward, n_chosen_agents)
        all_agent_place = [agent_place.copy()]
        same_counter = wobble_counter = 0
        prev_prev = None
        step = 0

        # Intermediate rule: first half of the mazes use case 1, second half
        # case 2. The pair of primary agents is drawn for every maze (also for
        # the other rules, which keeps the random stream unchanged).
        case_type_maze = 1 if maze_total < half_total else 2
        primary_idx_maze = tuple(sorted(random.sample(range(3), 2)))
        inter_state = {'phase': 'AB', 'dA': None, 'dB': None}  # primaries move first

        while step < STEP_TOTAL:
            agent_place, _, force = inputs(agent_place, x_random, y_random,
                                           social_reward, n_chosen_agents)

            if move_type == 'rulemap':
                agent_place_new = move(agent_place, force, n_chosen_agents)
            elif move_type == 'random':
                agent_place_new = move_random(agent_place, n_chosen_agents, n_chosen_agents)
            else:  # 'intermediate' (validated above)
                agent_place_new, inter_state = move_intermediate_step(
                    agent_place, primary_idx_maze, case_type_maze, inter_state
                )

            agent_place_new = clamp_world(agent_place_new)   # extra safety clamp
            all_agent_place.append(agent_place_new.copy())
            step += 1

            # Stop early when nothing moved twice in a row, or when the agents
            # keep returning to the positions of two steps ago.
            same_counter = same_counter + 1 if np.array_equal(agent_place_new, agent_place) else 0
            wobble_counter = (wobble_counter + 1
                              if prev_prev is not None and np.array_equal(agent_place_new, prev_prev)
                              else 0)
            if same_counter >= 2 or wobble_counter >= 2:
                break

            prev_prev = agent_place.copy()
            agent_place = agent_place_new.copy()

        true_step = step
        # World coordinates, stacked to shape ((steps + 1) * N, 2).
        all_agent_place = np.vstack(all_agent_place)
        assert_world(all_agent_place, 'before write')
        all_agent_place_disk = world_to_disk(all_agent_place)

        # Final maze picture (uses world coordinates).
        GAME_FINAL = final_maze(all_agent_place, Chosen_agents_label,
                                Chosen_agents_index, n_chosen_agents)

        if all_agent_place_disk.shape[0] <= 1:
            # No usable step for this maze: retry without counting it.
            # NOTE(review): the stacked array has (steps + 1) * N rows, so this
            # is only reachable with a single agent and no step — confirm.
            continue

        # One bit per agent pair (0,1), (0,2), (1,2), ...: 1 when the two
        # agents end on the same or on neighbouring (incl. diagonal) cells.
        final_xy = all_agent_place[-n_chosen_agents:].astype(int)
        bits = []
        for i in range(n_chosen_agents):
            for j in range(i):
                dx = final_xy[i, 0] - final_xy[j, 0]
                dy = final_xy[i, 1] - final_xy[j, 1]
                bits.append('1' if dx * dx + dy * dy <= 2 else '0')
        maze_final_stat_bin = "".join(bits)

        # Binary -> decimal; with a single agent there are no pairs at all.
        maze_final_stat_dec = int(maze_final_stat_bin, base=2) if maze_final_stat_bin else 0

        # Kept as a one-element list so summary.csv keeps its existing format.
        maze_final_stat_bin = maze_final_stat_bin.split()
        summary_frames.append(pd.DataFrame({
            "maze"      : [maze_total + 1],
            "steps"     : [true_step],
            "maze_final_binary" : [maze_final_stat_bin],
            "maze_final_decimal": [maze_final_stat_dec],
            "center_agent"      : [Chosen_agents_label[center_agent]],
            "flag_1"            : [flag_1],
            "flag_2"            : [flag_2],
            "case_type"         : [case_type_maze if move_type == "intermediate" else None],
            "move_dominant_agents": [primary_idx_maze if move_type == "intermediate" else None]
        }))

        # Intermediate mode gets one sub-folder per case ('case1' / 'case2').
        if move_type == 'intermediate':
            SAVE_DIR = DIR_TXT_OUTPUT / f'case{case_type_maze}'
        else:
            SAVE_DIR = DIR_TXT_OUTPUT

        try:
            SAVE_DIR.mkdir(parents=True, exist_ok=True)

            output_file = SAVE_DIR / f"{AGENT_NAME}_{maze_total+1}.txt"
            with open(output_file, "w") as text_file:
                text_file.write('Maze:\n')
                for row in GAME_ART[0]:
                    text_file.write(row + '\n')
                agent_place_str = np.array2string(all_agent_place_disk, separator=' ')
                text_file.write(agent_place_str + '\n')

            if PRINT_FINAL_MAZE:
                output_file_final = SAVE_DIR / f"{AGENT_NAME}_{maze_total+1}FINAL.txt"
                with open(output_file_final, "w") as text_file:
                    text_file.write('Final_Maze:\n')
                    for row in GAME_FINAL[0]:
                        text_file.write(row + '\n')
                    text_file.write('\n')

        except OSError as e:
            # Best effort: a failed write is reported but does not stop the run.
            print('error:', e)

        maze_total += 1

    # Build the summary once instead of re-concatenating after every maze.
    if summary_frames:
        df_collect_summary = pd.concat(summary_frames, ignore_index=True)
        df_collect_summary.to_csv(FILE_CSV_SUMMARY)

def get_parity_map(case_type):
    """Return the parity ('odd' / 'even') assigned to agents A, B and C.

    The 36 case types form three groups of twelve: 1-12, 13-24 and 25-36.
    Keys of the returned dict are the agent indices 0 = A, 1 = B, 2 = C.

    Raises:
        ValueError: if *case_type* is outside 1..36.
    """
    parity_by_group = {
        1: {0: 'odd', 1: 'odd', 2: 'even'},
        2: {0: 'odd', 1: 'even', 2: 'odd'},
        3: {0: 'even', 1: 'odd', 2: 'odd'},
    }
    group = (case_type - 1) // 12 + 1
    if group not in parity_by_group:
        # The old message claimed the valid values were 1, 2 or 3, which is
        # the group number, not the accepted case_type range.
        raise ValueError("case_type 必須介於 1 到 36。")
    return dict(parity_by_group[group])

# --- Shared 26x26 base maze ---
def build_base_maze_layout():
    """Return the empty 26x26 ASCII maze as a list of 26 strings.

    The outer frame is a wall ('#'); a 5-cell-wide corridor surrounds a walled
    12x12 inner room.
    """
    size, corridor_width, room_size = 26, 5, 12

    gap = ' ' * corridor_width
    solid_row = '#' * size
    corridor_row = '#' + ' ' * (size - 2) + '#'
    room_wall_row = '#' + gap + '#' * (room_size + 2) + gap + '#'
    room_row = '#' + gap + '#' + ' ' * room_size + '#' + gap + '#'

    return (
        [solid_row]
        + [corridor_row] * corridor_width
        + [room_wall_row]
        + [room_row] * room_size
        + [room_wall_row]
        + [corridor_row] * corridor_width
        + [solid_row]
    )

# --- Write the agents' letters onto the maze at their initial positions ---
def overlay_agents_on_maze(base_layout_lines, init_positions_xy, labels):
    """Return a copy of the maze rows with each label drawn at its agent's cell.

    init_positions_xy holds one (x, y) maze coordinate per label; coordinates
    are rounded to integers and positions outside the maze are ignored.
    """
    grid = [list(line) for line in base_layout_lines]
    for idx, tag in enumerate(labels):
        col, row = (int(round(value)) for value in init_positions_xy[idx])
        inside = 0 <= row < len(grid) and 0 <= col < len(grid[row])
        if inside:
            grid[row][col] = tag
    return [''.join(cells) for cells in grid]

# --- Shared writer: 'Maze:' header, maze rows, then the coordinate block ---
def write_maze_txt(output_path, maze_lines_26, trajectory_xy_tna, coord_offset=6):
    """Write a maze picture followed by the offset-corrected trajectory.

    Args:
        output_path: Destination file; missing parent folders are created.
        maze_lines_26: list[str], the maze rows (agent letters already drawn).
        trajectory_xy_tna: Trajectory array (documented as (T+1, N, 2)) with
            the offset not yet subtracted.
        coord_offset: Value subtracted from every coordinate before writing.
    """
    # A bare file name has an empty dirname, and os.makedirs('') raises.
    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    traj_norm = np.asarray(trajectory_xy_tna).astype(float) - coord_offset
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write('Maze:\n')
        f.writelines(row + '\n' for row in maze_lines_26)
        # Print the full array (no "..." summarisation), on wide lines.
        np.set_printoptions(threshold=np.inf, linewidth=200)
        f.write(np.array2string(traj_norm, separator=' ') + '\n')

def run_logic_simulation(case_type, trial_id, total_steps=50, version_name=None, agent_name=None, fix_center=True):
    """Simulate one three-agent "logic" trial and write it to a text file.

    The movement rule is ``case_type % 12`` (see ``move_logic``). The run stops
    early when a step changes nothing or returns to the positions of two steps
    before. Output goes to
    ``SIM_ROOT/logic/<agent_name>_case<case_type>_<trial_id>.txt``; no file is
    written when not a single step could be taken.

    Args:
        case_type: Case number; also part of the file name.
        trial_id: Trial number used in the file name.
        total_steps: Maximum number of steps.
        version_name: Fallback file-name prefix when *agent_name* is not given.
        agent_name: File-name prefix; defaults to *version_name*, then to the
            notebook-level ``VERSION_NAME``.
        fix_center: If True, agent 0 starts at the centre cell.
    """
    out_dir = SIM_ROOT / 'logic'
    out_dir.mkdir(parents=True, exist_ok=True)
    # `version_name` used to be ignored in favour of the global VERSION_NAME;
    # the global is still the last fallback for existing callers.
    agent_name = agent_name or version_name or VERSION_NAME
    out_path = out_dir / f'{agent_name}_case{case_type}_{trial_id}.txt'

    # Start positions in world coordinates.
    # NOTE(review): the cells are drawn independently, so two agents may start
    # on the same cell (or on the centre) — confirm this is acceptable.
    N = 3
    init = np.random.randint(INNER_MIN, INNER_MAX+1, size=(N,2)).astype(float)
    if fix_center:
        init[0] = np.array([CENTER_X, CENTER_Y], float)
    assert_world(init, 'logic init')

    # Build the trajectory (world coordinates).
    trajectory = [init.copy()]
    cur = init.copy()
    prev_prev = None
    direc_case = case_type % 12

    for _ in range(total_steps):
        nxt = move_logic(cur, direc_case)
        nxt = clamp_world(nxt)
        if np.array_equal(nxt, cur) or (prev_prev is not None and np.array_equal(nxt, prev_prev)):
            break
        trajectory.append(nxt.copy())
        prev_prev = cur.copy()
        cur = nxt.copy()

    traj_world = np.stack(trajectory, axis=0)  # (T, 3, 2)
    assert_world(traj_world, 'logic traj')

    # ASCII maze: draw the letters at the initial world coordinates.
    base = build_base_maze_layout_38()
    maze = [list(r) for r in base]
    labels = ['S','A','B']
    for i, lab in enumerate(labels[:N]):
        x, y = map(int, map(round, init[i]))
        maze[y][x] = lab
    maze_lines = [''.join(r) for r in maze]

    # Convert to on-disk coordinates (1..24) before writing.
    traj_disk = world_to_disk(traj_world)

    if traj_disk.shape[0] <= 1:
        # Not a single effective step: do not create a file.
        return

    with open(out_path, 'w', encoding='utf-8') as f:
        f.write('Maze:\n')
        for row in maze_lines:
            f.write(row + '\n')
        np.set_printoptions(threshold=np.inf, linewidth=200)
        f.write(np.array2string(traj_disk.reshape(-1,2), separator=' ') + '\n')


# 主程式區
### 參數調整
- 生成 Intermediate/ Random/ RuleMap 規則的圖的區塊
- （logic 規則的圖因為和其他規則比較不一樣，另外使用其他函數來實作）

In [129]:
# Parameters used to generate the social score
n_chosen_agents = 3
move_type = 'random' # RuleMap, Random, Intermediate

FNAME = 'sim1'  # File name for the output CSV file
MIN = -10  # Minimum value for social score
MAX = 10  # Maximum value for social score
MEAN = 1  # Mean value for social score
SD = 2.5  # Standard deviation for social score

# Define constants for the simulation data generation
RANDOM_NUM_GOALS = False  # If True, the number of goals will vary across mazes
VERSION_NAME = FNAME  # Version name used for output directory naming
STEP_TOTAL = 15  # Maximum number of steps per maze
AGENT_NAME = FNAME
SET_UP_MAZE_TOTAL = 100  # Number of different mazes to generate
PRINT_FINAL_MAZE = False  # Flag to indicate whether to print the final maze

# Define constants for the TOMNet model
LIST_SUBJECTS = FNAME  # List of subject names, should correspond to VERSION_NAMEs
# Every simulation output of this run is stored below PROJECT_ROOT/data/<FNAME>.
SIM_ROOT     = PROJECT_ROOT / 'data' / FNAME 
SIM_ROOT.mkdir(parents=True, exist_ok=True)

### 生成 social score

In [130]:
# Create the social-score CSV for this run; the returned DataFrame is shown as the cell output.
generate_social_score(FNAME, n_chosen_agents, MIN, MAX, MEAN, SD)

✓ social score saved -> /Users/Jer_ry/Desktop/scripts/data/sim1/log/social_score/sim1.csv


Unnamed: 0,1,2,3
1,1.180187,-2.063187,-0.4415
2,-0.4415,1.180187,-1.252344
3,6.856094,2.801875,1.180187


In [153]:

# Arguments shared by every simulation run below.
common_kwargs = dict(
    fname_csv=FNAME,
    RANDOM_NUM_GOALS=RANDOM_NUM_GOALS,
    VERSION_NAME=VERSION_NAME,
    STEP_TOTAL=STEP_TOTAL,
    AGENT_NAME=AGENT_NAME,
    N_AGENTS=n_chosen_agents,
    PRINT_FINAL_MAZE=PRINT_FINAL_MAZE,
)

# Intermediate rule: 4000 mazes, i.e. 2000 for each of case 1 and case 2.
simulation_data_dynamic(
    SET_UP_MAZE_TOTAL=4000,
    move_type='Intermediate',  # 'RuleMap', 'Random', 'Intermediate'
    **common_kwargs
)

# RuleMap and Random rules: 2000 mazes each.
SET_UP_MAZE_TOTAL = 2000
for move_type in ['rulemap', 'random']:
    simulation_data_dynamic(
        SET_UP_MAZE_TOTAL=SET_UP_MAZE_TOTAL,
        move_type=move_type,
        **common_kwargs
    )

In [154]:
# Runs dedicated to the logic simulation.
# One trial is generated for each case_type from 1 to 2000; inside
# run_logic_simulation the movement rule is selected by case_type % 12.
SIMULATION_STEPS = 15        # number of steps simulated per trial
for CASE_TO_RUN in range(1, 2001):
    run_logic_simulation(
        case_type=CASE_TO_RUN,
        trial_id=1,
        total_steps=SIMULATION_STEPS,
        version_name=VERSION_NAME,
    )

## 畫圖圖

In [23]:
import os, re, math, itertools, random, glob
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

In [24]:
# Agent labels used when drawing: 'S' first, then 'A', 'B', ... (chr(65) == 'A'),
# n_chosen_agents labels in total.
Chosen_agents_label = ['S'] + [chr(64 + i) for i in range(1, n_chosen_agents)]
N_AGENTS = len(Chosen_agents_label)

DIR_ROOT = os.getcwd() 

# Folder scanned by main() for trajectory TXT files.
# NOTE(review): `move_type` is whatever value the global currently holds (the
# simulation driver loop leaves it at 'random'), and the path built here is
# SIM_ROOT/<VERSION_NAME>/<move_type>, whereas other cells read from
# data/<sim>/<move_type> directly — confirm this is the intended location.
DIR_TXT_OUTPUT = os.path.join(
        SIM_ROOT,    
        VERSION_NAME,
        move_type    
)
# Number of steps to draw per file; this rebinds the STEP_TOTAL used by the simulation cell.
STEP_TOTAL = 10 
# Offset added to the TXT coordinates to obtain column/row indices in the ASCII maze.
COORDINATE_OFFSET = 6 
def parse_coordinates(coord_lines_raw):
    """
    Parse raw coordinate lines into an array of [x, y] pairs.

    Square brackets are removed from every line and the rest is split on
    whitespace. Lines that do not hold exactly two fields are skipped
    silently; lines whose two fields are not valid floats are skipped after
    printing a warning.

    :param coord_lines_raw: iterable of str, e.g. "[ 1.  2.]"
    :return: numpy array built from the parsed pairs (an empty array when
             nothing could be parsed)
    """
    pairs = []
    for raw in coord_lines_raw:
        stripped = raw.strip()
        fields = stripped.replace('[', '').replace(']', '').split()
        if len(fields) != 2:
            continue
        x_str, y_str = fields
        try:
            pair = [float(x_str), float(y_str)]
        except ValueError:
            print(f"警告：無法解析座標行: {stripped}")
            continue
        pairs.append(pair)
    return np.array(pairs)

def draw_maze_state(base_maze_layout, agent_positions_txt, agent_labels, step_num, maze_filename):
    """
    Print the maze for one step with every agent label stamped on its cell.

    :param base_maze_layout: list of strings, the base maze layout (not modified)
    :param agent_positions_txt: numpy array indexed as [agent, 0] / [agent, 1],
                                the agents' coordinates for this step as read from the TXT file
    :param agent_labels: list of strings, one label per agent
    :param step_num: int, index of the step being drawn
    :param maze_filename: str, name of the maze file being processed
    """
    print(f"\n--- 繪製迷宮: {maze_filename}, 步驟: {step_num} ---")

    # Work on a per-character copy so the base layout stays untouched.
    grid = [list(line) for line in base_maze_layout]
    n_positions = len(agent_positions_txt)

    for idx, tag in enumerate(agent_labels):
        if idx >= n_positions:
            print(f"警告：Agent {tag} 的座標缺失。")
            continue

        x_txt, y_txt = agent_positions_txt[idx, 0], agent_positions_txt[idx, 1]

        # TXT coordinates + COORDINATE_OFFSET give the column / row in the layout.
        x_plot = int(round(x_txt + COORDINATE_OFFSET))
        y_plot = int(round(y_txt + COORDINATE_OFFSET))

        inside = 0 <= y_plot < len(grid) and 0 <= x_plot < len(grid[y_plot])
        if not inside:
            print(f"警告：Agent {tag} 的繪圖座標 ({x_plot},{y_plot}) 超出迷宮邊界。原始TXT座標: ({x_txt},{y_txt})")
            continue

        grid[y_plot][x_plot] = tag

    for row_chars in grid:
        print("".join(row_chars))

def process_maze_file(filepath):
    """
    Read one maze TXT file and print an ASCII frame for every step.

    Expected file layout: a first line "Maze:", then 26 lines of maze layout,
    then the coordinate block (ideally wrapped in "[[ ... ]]"; otherwise the
    remaining lines are parsed one by one). The coordinates are consumed in
    groups of N_AGENTS rows, one group per step, for STEP_TOTAL + 1 frames.

    Problems (empty file, bad header, short layout, missing coordinates,
    missing file) are reported on stdout and the function returns without
    raising.

    :param filepath: path of the TXT file to process
    :return: None
    """
    print(f"now processing: {filepath}")

    try:
        with open(filepath, 'r') as f:
            lines = f.readlines()

        # An empty file has no lines[0]; report it as a format error instead of
        # letting the IndexError fall through to the generic handler below.
        if not lines or lines[0].strip() != "Maze:":
            print(f"Error：檔案 {filepath} 的格式不正確，開頭不是 'Maze:'。")
            return

        maze_layout_lines = 26
        base_maze_layout = [line.strip('\n') for line in lines[1:1 + maze_layout_lines]]
        if len(base_maze_layout) != maze_layout_lines:
            print(f"Error：檔案 {filepath} 的迷宮佈局行數不足。")
            return

        coord_lines = lines[1 + maze_layout_lines:]
        coord_block_str = "".join(coord_lines).strip()

        if not coord_block_str.startswith("[[") or not coord_block_str.endswith("]]"):
            # Fallback: keep every non-empty line that contains both brackets.
            print(f"Error: 檔案 {filepath} 的座標區塊格式可能不完全符合預期的 [[...]] 包裝。嘗試照行解析。")
            raw_coord_lines = [line.strip() for line in coord_lines if line.strip() and '[' in line and ']' in line]
            if not raw_coord_lines:
                print(f"Error: 檔案 {filepath} 不能從行中解析出任何座標")
                return
        else:
            # Drop the outer "[[" / "]]" and split on the "] [" row separators.
            inner = coord_block_str[2:-2]
            individual_coord_strs = re.split(r'\]\s*\[', inner)
            raw_coord_lines = [f"[{s.strip()}]" for s in individual_coord_strs if s.strip()]

        all_agent_coords_parsed = parse_coordinates(raw_coord_lines)

        expected_coords_len = (STEP_TOTAL + 1) * N_AGENTS
        if len(all_agent_coords_parsed) != expected_coords_len:
            # A mismatch is only a warning: draw as many complete steps as exist.
            print(f"!!：檔案 {filepath} 中解析得到的座標數量 ({len(all_agent_coords_parsed)}) 與預期 ({expected_coords_len}) 不符。")
            if len(all_agent_coords_parsed) < N_AGENTS:
                print(f"Error：座標數量不足以繪製任何步驟。")
                return

        for step in range(STEP_TOTAL + 1):
            start_index = step * N_AGENTS
            end_index = start_index + N_AGENTS

            if end_index > len(all_agent_coords_parsed):
                print(f"Alarm：步驟 {step} 的座標數據不足，無法繪製。")
                break

            # The slice always holds exactly N_AGENTS rows at this point.
            draw_maze_state(
                base_maze_layout,
                all_agent_coords_parsed[start_index:end_index],
                Chosen_agents_label,
                step,
                os.path.basename(filepath),
            )

    except FileNotFoundError:
        print(f"Error：找不到檔案 {filepath}")
    except Exception as e:
        # Deliberate best-effort: one broken file must not abort the batch.
        print(f"處理檔案 {filepath} 時發生錯誤: {e}")

def main():
    """
    Render every matching TXT file in DIR_TXT_OUTPUT as ASCII frames.

    A file qualifies when its name starts with AGENT_NAME_PATTERN, ends with
    ".txt" and does not contain "FINAL". Prints an error and returns when the
    directory does not exist.
    """
    # NOTE(review): the original comment said DIR_TXT_OUTPUT is set in the
    # `if __name__ == '__main__'` block (possibly to a test path), and
    # AGENT_NAME_PATTERN is not defined in the visible part of the file —
    # confirm both are in scope before calling this.
    if not os.path.exists(DIR_TXT_OUTPUT):
        print(f"Error：找不到目錄 {DIR_TXT_OUTPUT}")
        return

    for entry in os.listdir(DIR_TXT_OUTPUT):
        is_agent_txt = entry.startswith(AGENT_NAME_PATTERN) and entry.endswith(".txt")
        if not is_agent_txt or "FINAL" in entry:
            continue
        process_maze_file(os.path.join(DIR_TXT_OUTPUT, entry))

#### TXT 版本
用以簡單測試規則，可忽略

In [25]:
# PROJECT_ROOT   = Path('/Users/Jer_ry/Desktop/scripts')
# SIM_FOLDER     = FNAME                    # e.g. 'sim1'
# RUN_TYPE       = 'logic' 

# DIR_TXT        = PROJECT_ROOT / 'data' / SIM_FOLDER / RUN_TYPE
# AGENT_PREFIX   = 'sim1_'
# N_AGENTS       = 3
# FILES_TO_SHOW  = 1      
# STEPS_TO_SHOW  = 10     
# COORD_OFFSET   = 6
# AGENT_LABELS   = ['S'] + [chr(65+i) for i in range(N_AGENTS-1)]
# def parse_coords(raw_lines):
#     pts=[]
#     for ln in raw_lines:
#         ln=ln.strip().strip('[]')
#         if ln:
#             pts.append(tuple(map(float,ln.split())))
#     return np.asarray(pts,dtype=float)

# def draw_ascii(base_maze,xy,step,fname):
#     canvas=[list(r) for r in base_maze]
#     for idx,(x_txt,y_txt) in enumerate(xy):
#         x=int(round(x_txt+COORD_OFFSET))
#         y=int(round(y_txt+COORD_OFFSET))
#         if 0<=y<26 and 0<=x<26:
#             canvas[y][x]=AGENT_LABELS[idx]
#     print(f'\n【{fname} | step={step}】')
#     for row in canvas: print(''.join(row))

# def load_and_show(path):
#     with open(path,encoding='utf-8') as f:
#         lines=f.read().splitlines()

#     base_maze=[re.sub(r'[A-Z]',' ',r) for r in lines[1:27]]
#     coords=parse_coords([ln for ln in lines[27:] if '[' in ln])

#     if coords.size==0 or coords.shape[0]%N_AGENTS:
#         print(f' {os.path.basename(path)} 不匹配 (T,{N_AGENTS},2)，跳過')
#         return

#     traj=coords.reshape(-1,N_AGENTS,2)

#     if STEPS_TO_SHOW is None: steps=range(len(traj))
#     elif isinstance(STEPS_TO_SHOW,int): steps=range(min(len(traj),STEPS_TO_SHOW))
#     else: steps=[s for s in STEPS_TO_SHOW if 0<=s<len(traj)]

#     for s in steps: draw_ascii(base_maze,traj[s],s,os.path.basename(path))

# # ---------- main ----------
# def main():
#     if not DIR_TXT.is_dir():
#         print('找不到資料夾:',DIR_TXT); return

#     all_txt=sorted(
#         f for f in os.listdir(DIR_TXT)
#         if f.startswith(AGENT_PREFIX) and f.endswith('.txt') and 'FINAL' not in f
#     )

#     if FILES_TO_SHOW is None: targets=all_txt
#     elif isinstance(FILES_TO_SHOW,int): targets=all_txt[:FILES_TO_SHOW]
#     else: targets=[f for f in all_txt if f in FILES_TO_SHOW]

#     if not targets:
#         print('沒有符合條件的檔案可顯示'); return

#     for fname in targets:
#         load_and_show(DIR_TXT/fname)

# if __name__=='__main__':
#     main()

### 圖檔產生

In [28]:
# Odd and even digit pools, stored here as ints.
# NOTE(review): batch_plot_general and the answer-generation cell define
# ODD/EVEN as lists of digit *strings* ('1', '3', ...); confirm which form any
# consumer of these module-level lists expects.
ODD  = [1,3,5,7,9]
EVEN = [0,2,4,6,8]

def parse_case_type_from_name(path: Path) -> Optional[int]:
    """Return the integer that follows "case" (case-insensitive) in the file stem, or None if absent."""
    found = re.search(r'case(\d+)', path.stem, re.I)
    if found is None:
        return None
    return int(found.group(1))

def parity_by_role(case_type: int) -> dict[str, str]:
    """Map each role ('S', 'A', 'B') to 'even' or 'odd' for a case_type in 1..36.

    The 36 case types form three groups of twelve (1..12, 13..24, 25..36);
    exactly one role is 'even' per group — S in the first, A in the second,
    B in the third — and the other two are 'odd'.

    Raises:
        ValueError: if case_type falls outside 1..36.
    """
    roles = ('S', 'A', 'B')
    group_index = (case_type - 1) // 12  # 0, 1 or 2 for a valid case_type
    even_role = next(
        (role for position, role in enumerate(roles) if group_index == position),
        None,
    )
    if even_role is None:
        raise ValueError('case_type 必須 1..36')
    return {role: ('even' if role == even_role else 'odd') for role in roles}

def find_role_positions(ascii_lines: list[str]) -> dict[str, tuple[int,int]]:
    """Locate the 'S', 'A' and 'B' characters in an ASCII maze.

    Returns a dict role -> (x, y), where x is the 0-based character index in
    the row and y is the 0-based row index. If a letter occurs several times,
    the last occurrence in reading order wins.

    NOTE(review): the original docstring called these "world coordinates
    (7..30)"; the code returns raw indices, so that only holds if the ASCII
    layout is laid out accordingly — confirm against the file format.
    """
    return {
        ch: (col, row)
        for row, text in enumerate(ascii_lines)
        for col, ch in enumerate(text)
        if ch in ('S', 'A', 'B')
    }

def map_roles_to_indices(pos0_world: np.ndarray, role_pos: dict[str,tuple[int,int]]) -> dict[int, str]:
    """Match trajectory rows of step 0 to the roles found in the ASCII maze.

    ``pos0_world`` holds one (x, y) row per agent; it is rounded to integers
    first. Pass 1 gives each role the first still-unassigned agent whose
    rounded position equals the role position exactly. Pass 2 walks the
    remaining agents in order and gives each the nearest (Manhattan distance)
    role that is still free, to absorb small rounding errors. Agents left over
    once all roles are taken stay unmapped.

    Returns a dict agent index -> role.
    """
    assigned: dict[int, str] = {}
    taken: set[str] = set()
    rounded = np.rint(pos0_world).astype(int)

    # Pass 1: exact coordinate matches, role by role.
    for role, (rx, ry) in role_pos.items():
        hit = next(
            (i for i, (x, y) in enumerate(rounded)
             if i not in assigned and x == rx and y == ry),
            None,
        )
        if hit is not None:
            assigned[hit] = role
            taken.add(role)

    # Pass 2: nearest free role for every agent still without one.
    for i, (x, y) in enumerate(rounded):
        if i in assigned:
            continue
        free = [r for r in role_pos if r not in taken]
        if not free:
            continue
        nearest = min(
            free,
            key=lambda r: abs(role_pos[r][0] - x) + abs(role_pos[r][1] - y),
        )
        assigned[i] = nearest
        taken.add(nearest)

    return assigned

In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import os, re, glob, math, itertools
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional

# ───── User settings ─────
PROJECT_ROOT   = Path('/Users/Jer_ry/Desktop/scripts')
SIM_FOLDER     = FNAME
STEP_LIMIT     = 15   # batch_plot_general draws at most STEP_LIMIT + 1 steps per file (None = all)
# Set to 0 if the TXT files already hold absolute 7..30 coordinates.
# NOTE(review): no visible code reads COORD_OFFSET (to_world_for_plot detects the range itself) — confirm it is still needed.
COORD_OFFSET   = 0
# Colours cycled over the agents.
COLOR_PALETTE  = [
    '#1f77b4','#ff7f0e','#2ca02c','#d62728','#9467bd',
    '#8c564b','#e377c2','#7f7f7f','#bcbd22','#17becf'
]
ALPHA_FILL     = .8   # opacity of the agent squares
# Visible window of the grid, inclusive on both ends.
VIEW_X_MIN, VIEW_X_MAX = 7, 30
VIEW_Y_MIN, VIEW_Y_MAX = 7, 30

# (Added) one agent count N shared by all modes; to use a different N per mode,
# pass it separately from main().
COMMON_N = 3


# ───── Directories ─────
DATA_ROOT = PROJECT_ROOT/'data'/SIM_FOLDER       
PLOT_ROOT = DATA_ROOT                            

# ───── 小工具 ─────
def to_world_for_plot(pos: np.ndarray) -> np.ndarray:
    """Convert positions of shape (T,N,2) or (N,2) to world coordinates.

    The coordinate system is guessed from the value range of the whole input:
      * everything within 1..24 -> shifted by +6
      * everything within 7..30 -> returned as is
      * anything else           -> ValueError, to avoid plotting wrong data

    The two ranges overlap on 7..24; the 1..24 test runs first, so input that
    lies entirely inside the overlap is shifted by +6.
    """
    arr = np.asarray(pos, dtype=float)
    lowest = arr.min()
    highest = arr.max()

    if lowest >= 1 and highest <= 24:
        return arr + 6.0
    if lowest >= 7 and highest <= 30:
        return arr
    raise ValueError(f"座標範圍不明（既非 1..24 也非 7..30）：min={lowest}, max={highest}")

def min_factor_gt1(n:int)->int:
    """Return the smallest divisor of n greater than 1 (n itself if n is prime); 1 when n < 2."""
    if n < 2:
        return 1
    upper = int(math.sqrt(n)) + 2  # exclusive bound, slightly past sqrt(n)
    return next((d for d in range(2, upper) if n % d == 0), n)

def parse_all_floats(lines_after_maze:list)->list[float]:
    """Extract every number (optional sign, optional decimal part) from the given lines.

    The lines are joined with a newline rather than an empty string, so a
    number ending one line can never fuse with the number starting the next
    one (e.g. "1 2" + "3 4" must give 1, 2, 3, 4 and not 1, 23, 4).
    """
    text = '\n'.join(lines_after_maze)
    return [float(token) for token in re.findall(r'-?\d+(?:\.\d+)?', text)]

def draw_one_step(ax, pos, labels, colors):
    """Draw one labelled unit square per agent on ``ax``.

    ``pos``, ``colors`` and ``labels`` are walked in lockstep. An agent whose
    (x, y) lies outside the VIEW_X / VIEW_Y window is not drawn.
    """
    for (cx, cy), face, caption in zip(pos, colors, labels):
        visible = VIEW_X_MIN <= cx <= VIEW_X_MAX and VIEW_Y_MIN <= cy <= VIEW_Y_MAX
        if not visible:
            continue
        # Square centred on the cell, with the label centred on top of it.
        square = plt.Rectangle(
            (cx - .5, cy - .5), 1, 1,
            facecolor=face, edgecolor='black',
            alpha=ALPHA_FILL, lw=1.2,
        )
        ax.add_patch(square)
        ax.text(
            cx, cy, caption,
            ha='center', va='center', color='white',
            weight='bold', fontsize=12,
        )

def export_png(fig, out_path):
    """Save ``fig`` to ``out_path`` (150 dpi, tight bounding box) and close it.

    The figure is closed even when layout or saving fails, so a failing file in
    a long batch does not leave open figures behind. The exception itself
    still propagates to the caller.
    """
    try:
        fig.tight_layout()
        fig.savefig(out_path, dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)

# ===== 通用繪圖：rulemap / random / intermediate / logic 都走這支 =====
def _find_data_start(lines:list[str]) -> int:
    for i, ln in enumerate(lines):
        if '[' in ln:
            return i
    return len(lines)

def _infer_N(ascii_lines:list[str], floats:list[float], guess_max:int=16) -> int:
    """同時用字母數與數字長度推斷 N。字母數不相容就用可整除的 N。"""
    # 1) 先看迷宮區塊裡的字母（S 或 A~Z）
    letters = {ch for row in ascii_lines for ch in row if ch == 'S' or 'A' <= ch <= 'Z'}
    letter_cnt = len(letters)

    # 2) 找出所有能把長度整除的 N（len(floats) 必須是 2*N 的倍數）
    L = len(floats)
    candidates = [n for n in range(1, guess_max+1) if (L % (2*n) == 0)]

    # 3) 如果字母數看起來合理，就採用；否則用候選集合
    if letter_cnt in candidates and letter_cnt != 0:
        return letter_cnt
    if candidates:
        # 優先 3（最常見），否則就取最接近字母數的；再不然取最小可行值
        if 3 in candidates:
            return 3
        if letter_cnt > 0:
            return min(candidates, key=lambda n: abs(n - letter_cnt))
        return min(candidates)
    raise ValueError(f"無法推斷 N：len(floats)={L}, letters={letters}")

def batch_plot_general(
    move_type: str,
    src_dir: Path,
    dst_dir: Path,
    preserve_subdirs: bool = False,
    keep_all_steps: bool = False,
    n_override: Optional[int] = None,
    ):
    """Render every trajectory TXT under ``src_dir`` as one PNG per step.

    Shared by the rulemap / random / intermediate / logic modes. For each
    ``*.txt`` found recursively (files with "FINAL" in their name are
    ignored) the numbers after the ASCII maze are reshaped to (steps, N, 2)
    and each step is saved as ``step_plot_<stem>/step_XXX.png`` under
    ``dst_dir``.

    Parameters:
        move_type: mode name; 'logic' (case-insensitive) switches on the
            odd/even digit labels and the labels.json output, any other value
            labels the agents "1".."N".
        src_dir: folder searched recursively for TXT files.
        dst_dir: root folder for the output.
        preserve_subdirs: if True, mirror each file's sub-folder (relative to
            ``src_dir``) under ``dst_dir``.
        keep_all_steps: if False, a step whose positions equal those of the
            previously drawn step is skipped.
        n_override: number of agents per step. If None, N is inferred with
            ``_infer_N`` from the maze letters and the amount of data.

    Raises:
        ValueError: if the amount of numeric data does not fit the agent
            count, or if the coordinates fall outside the ranges accepted by
            ``to_world_for_plot``.
    """
    import hashlib
    import json
    import random

    # ------ helpers used only by the logic mode (kept local) ------
    ODD  = ['1','3','5','7','9']
    EVEN = ['0','2','4','6','8']

    def parse_case_type_from_name(path: Path):
        """Return the integer after "case" (case-insensitive) in the stem, or None."""
        m = re.search(r'case(\d+)', path.stem, re.I)
        return int(m.group(1)) if m else None

    def parity_by_role(case_type: int):
        """
        Map roles to parity. The 36 case types form three groups
        (1..12, 13..24, 25..36):
        grp1: S=even, A=odd,  B=odd
        grp2: S=odd,  A=even, B=odd
        grp3: S=odd,  A=odd,  B=even
        case_type values above 36 wrap around with period 36, because the
        logic simulation produces case numbers well beyond 36.
        NOTE(review): get_parity in the answer-generation cell assigns the
        even digit to a different index per group — confirm which mapping is
        the intended one.
        """
        grp = ((case_type - 1) % 36) // 12 + 1
        if grp == 1: return {'S':'even', 'A':'odd',  'B':'odd'}
        if grp == 2: return {'S':'odd',  'A':'even', 'B':'odd'}
        return {'S':'odd',  'A':'odd',  'B':'even'}

    def find_role_positions(ascii_lines):
        """Return {role: (x, y)} with the 0-based column/row of 'S', 'A', 'B' in the ASCII maze."""
        roles = {}
        for y, row in enumerate(ascii_lines):
            for x, ch in enumerate(row):
                if ch in ('S','A','B'):
                    roles[ch] = (x, y)
        return roles

    def map_roles_to_indices(pos0_world: np.ndarray, role_pos):
        """
        Align the agents of step 0 with the S/A/B positions found in the ASCII
        maze: exact matches on rounded coordinates first, then the nearest
        free role (Manhattan distance) for every agent still unmatched.
        """
        idx_to_role = {}
        used_roles  = set()
        p0 = np.rint(pos0_world).astype(int)

        # Exact matches
        for role, (rx, ry) in role_pos.items():
            for i, (x, y) in enumerate(p0):
                if i in idx_to_role:
                    continue
                if x == rx and y == ry:
                    idx_to_role[i] = role
                    used_roles.add(role)
                    break

        # Nearest free role for the rest
        for i, (x, y) in enumerate(p0):
            if i in idx_to_role:
                continue
            cand = [(abs(rx-x) + abs(ry-y), r) for r,(rx,ry) in role_pos.items() if r not in used_roles]
            if not cand:
                continue
            _, role = min(cand, key=lambda t: t[0])
            idx_to_role[i] = role
            used_roles.add(role)

        return idx_to_role

    def stable_labels_for_step(stem: str, step_idx: int, idx_to_role, parity_role, N: int):
        """Pick one digit per agent by role parity, seeded by "<stem>:<step>" so the result is reproducible."""
        seed = int(hashlib.sha1(f"{stem}:{step_idx}".encode()).hexdigest()[:8], 16)
        rng  = random.Random(seed)
        labels = []
        for i in range(N):
            role = idx_to_role.get(i, None)
            # Agents without a role default to an odd digit.
            parity = parity_role.get(role, 'odd') if role else 'odd'
            labels.append(rng.choice(ODD if parity=='odd' else EVEN))
        return labels
    # -------------------------------------------------------------

    txt_list = sorted(
        t for t in glob.glob(str(src_dir/'**'/'*.txt'), recursive=True)
        if 'FINAL' not in os.path.basename(t)
    )
    if not txt_list:
        print(f'CANT FIND：{src_dir}')
        return

    for txt in txt_list:
        txt_path = Path(txt)
        with open(txt_path, encoding='utf-8') as f:
            lines = f.read().splitlines()

        if not lines or lines[0].strip() != 'Maze:':
            print(f'WARN: 檔案沒有 "Maze:" 表頭，跳過 → {txt_path}')
            continue

        # The ASCII maze runs from line 1 up to the first line containing '['.
        maze_end    = 1 + _find_data_start(lines[1:])
        ascii_lines = lines[1:maze_end]
        data_lines  = lines[maze_end:]

        # Join with a newline so numbers on adjacent lines can never fuse.
        floats = list(map(float, re.findall(r'-?\d+(?:\.\d+)?', '\n'.join(data_lines))))
        if not floats:
            print(f'WARN: 找不到數字資料，跳過 → {txt_path}')
            continue

        # === Agent count: explicit override, otherwise inferred from the file ===
        if n_override is not None:
            N = int(n_override)
            if len(floats) % (2*N) != 0:
                raise ValueError(
                    f"{txt_path.name}: len(floats)={len(floats)} 不能被 2*N={2*N} 整除，"
                    f"請確認 N 或檔案內容。"
                )
        else:
            # Searching for the first N that divides the data length always
            # stops at N == 1, so use the letter-aware inference instead.
            try:
                N = _infer_N(ascii_lines, floats)
            except ValueError as e:
                raise ValueError(f"{txt_path.name}: {e}") from e

        colors = list(itertools.islice(itertools.cycle(COLOR_PALETTE), N))

        try:
            traj = np.asarray(floats, dtype=float).reshape(-1, N, 2)
        except ValueError as e:
            raise ValueError(
                f"reshape 失敗：file={txt_path.name}, len(floats)={len(floats)}, N={N}"
            ) from e

        max_s = len(traj) if STEP_LIMIT is None else min(len(traj), STEP_LIMIT+1)
        trial_id = txt_path.stem

        # Optionally mirror the source sub-folder (e.g. intermediate/case1, case2).
        rel_parent = txt_path.parent.relative_to(src_dir) if preserve_subdirs else None
        out_base  = (dst_dir / rel_parent) if preserve_subdirs else dst_dir
        out_trial = out_base / f'step_plot_{trial_id}'
        out_trial.mkdir(parents=True, exist_ok=True)

        # --------- logic only: index -> role and role -> parity tables ---------
        is_logic  = (move_type.lower() == 'logic')
        case_type = parse_case_type_from_name(txt_path) if is_logic else None
        idx_to_role = None
        parity_role = None
        step_labels_map = {}   # filled only in logic mode

        if is_logic:
            role_pos   = find_role_positions(ascii_lines)     # {'S':(x,y), 'A':(x,y), 'B':(x,y)}
            pos0_world = to_world_for_plot(traj[0])           # (N,2) in world coordinates
            idx_to_role = map_roles_to_indices(pos0_world, role_pos)
            parity_role = parity_by_role(case_type or 1)      # no case number in the name -> group 1
        # -------------------------------------------------------------------

        prev = None
        for s in range(max_s):
            pos = traj[s]
            if (not keep_all_steps) and prev is not None and np.array_equal(pos, prev):
                continue
            prev = pos.copy()

            pos_plot = to_world_for_plot(pos)

            fig, ax = plt.subplots(figsize=(7,7))
            ax.set_facecolor('white')
            ax.set_xlim(VIEW_X_MIN-.5, VIEW_X_MAX+.5)
            ax.set_ylim(VIEW_Y_MAX+.5, VIEW_Y_MIN-.5)   # inverted: y grows downwards
            ax.set_aspect('equal')
            ax.set_xticks(np.arange(VIEW_X_MIN-.5, VIEW_X_MAX+1.5, 1))
            ax.set_yticks(np.arange(VIEW_Y_MIN-.5, VIEW_Y_MAX+1.5, 1))
            ax.grid(True, lw=.8, color='lightgrey')
            ax.tick_params(length=0, labelleft=False, labelbottom=False)

            # Logic mode uses reproducible odd/even digits; other modes use 1..N.
            if is_logic and idx_to_role and parity_role:
                labels_this_step = stable_labels_for_step(trial_id, s, idx_to_role, parity_role, N)
                step_labels_map[str(s)] = labels_this_step
                draw_one_step(ax, pos_plot, labels_this_step, colors)
            else:
                default_labels = [str(i+1) for i in range(N)]
                draw_one_step(ax, pos_plot, default_labels, colors)

            export_png(fig, out_trial / f'step_{s:03d}.png')

        # Logic mode: store the labels so later stages can read the digits of the last frame.
        if is_logic:
            meta = {
                "case_type": case_type,
                "parity": parity_role,
                "idx_to_role": {str(k): v for k, v in (idx_to_role or {}).items()},
                "step_labels": step_labels_map
            }
            with open(out_trial / "labels.json", "w", encoding="utf-8") as f:
                json.dump(meta, f, ensure_ascii=False, indent=2)

        prefix = '' if (rel_parent is None or str(rel_parent) == '.') else f'{rel_parent}/'
        print(f'✓ {move_type:<12} {prefix}{trial_id} → {out_trial}')



# Script entry point.
# NOTE(review): the only `main` visible in this part of the file is the ASCII
# renderer defined earlier, which never calls batch_plot_general, yet the
# recorded output of this cell shows batch_plot_general's "✓ rulemap ..."
# lines — presumably a different main() was in scope when the cell ran; confirm.
if __name__ == '__main__':
    main()


✓ rulemap      sim1_1 → /Users/Jer_ry/Desktop/scripts/data/sim1/rulemap_plot/step_plot_sim1_1
✓ rulemap      sim1_10 → /Users/Jer_ry/Desktop/scripts/data/sim1/rulemap_plot/step_plot_sim1_10
✓ rulemap      sim1_11 → /Users/Jer_ry/Desktop/scripts/data/sim1/rulemap_plot/step_plot_sim1_11
✓ rulemap      sim1_12 → /Users/Jer_ry/Desktop/scripts/data/sim1/rulemap_plot/step_plot_sim1_12
✓ rulemap      sim1_13 → /Users/Jer_ry/Desktop/scripts/data/sim1/rulemap_plot/step_plot_sim1_13
✓ rulemap      sim1_14 → /Users/Jer_ry/Desktop/scripts/data/sim1/rulemap_plot/step_plot_sim1_14
✓ rulemap      sim1_15 → /Users/Jer_ry/Desktop/scripts/data/sim1/rulemap_plot/step_plot_sim1_15
✓ rulemap      sim1_16 → /Users/Jer_ry/Desktop/scripts/data/sim1/rulemap_plot/step_plot_sim1_16
✓ rulemap      sim1_17 → /Users/Jer_ry/Desktop/scripts/data/sim1/rulemap_plot/step_plot_sim1_17
✓ rulemap      sim1_18 → /Users/Jer_ry/Desktop/scripts/data/sim1/rulemap_plot/step_plot_sim1_18
✓ rulemap      sim1_19 → /Users/Jer_ry/Des

### 生成答案表 以及 答案圖檔
Root = "validation_data" # "validation_data" or "training_data"

In [155]:
# -*- coding: utf-8 -*-
import os, re, glob, json, hashlib, random
from pathlib import Path
from typing import Tuple, List, Optional, Dict
import numpy as np
import matplotlib.pyplot as plt

# ========= Tunable parameters =========
MAKE_PNG    = False          # whether to also write A.png ~ D.png
RANDOM_SEED = None           # set to an int (e.g. 1234) for reproducible runs; None gives different results each run

Root = "training_data" # "validation_data" or "training_data" or "testing_data"
# ========= Paths =========
PROJECT_ROOT = Path("/Users/Jer_ry/Desktop/scripts")
DATA_ROOT    = PROJECT_ROOT / "data" / "sim1"
TRAIN_ROOT   = PROJECT_ROOT / Root / "sim1"

# Source folders holding the trajectory TXT files, one per simulation mode.
SRC_FOLDERS = [
    DATA_ROOT / "logic",
    DATA_ROOT / "intermediate" / "case1",
    DATA_ROOT / "intermediate" / "case2",
    DATA_ROOT / "random",
    DATA_ROOT / "rulemap",
]

# ========= Rendering / board =========
# Same range as the cell encoding: 1..24
VIEW_MIN, VIEW_MAX = 1, 24
GRID_DPI = 150
COLORS = ['#1f77b4','#ff7f0e','#2ca02c']  # colours of the three agents
# The eight neighbouring offsets (every (dx, dy) in {-1,0,1}^2 except (0, 0)).
DIR8 = [(dx,dy) for dx in (-1,0,1) for dy in (-1,0,1) if not (dx==0 and dy==0)]

# Digit pools (as strings) for the odd / even labels.
ODD  = ['1','3','5','7','9']
EVEN = ['0','2','4','6','8']

# ========= 共用工具 =========
def read_floats_after_maze(txt_path: Path) -> np.ndarray:
    """Read the coordinate data of a trajectory file as a (K, 2) float array.

    If the first line is 'Maze:', everything up to (not including) the first
    later line containing '[' is treated as the ASCII maze and skipped;
    otherwise the whole file is treated as data. The values are returned
    as stored (disk or world coordinates).

    Lines are joined with a newline before number extraction, so a number
    ending one line can never fuse with the number starting the next one.

    Raises:
        ValueError: if no numbers are found or their count is odd.
    """
    with open(txt_path, encoding='utf-8') as f:
        lines = f.read().splitlines()

    if lines and lines[0].strip() == 'Maze:':
        after_header = lines[1:]
        first_data = next(
            (i for i, ln in enumerate(after_header) if '[' in ln),
            len(after_header),
        )
        data_lines = after_header[first_data:]
    else:
        data_lines = lines

    nums = re.findall(r'-?\d+(?:\.\d+)?', '\n'.join(data_lines))
    if not nums or len(nums) % 2 != 0:
        raise ValueError(f"{txt_path}: 無法解析成偶數個數字（長度={len(nums)}）")
    return np.asarray(list(map(float, nums)), dtype=float).reshape(-1, 2)

def world_to_disk_if_needed(arr: np.ndarray) -> np.ndarray:
    """Return a float copy of ``arr`` expressed in disk coordinates (1..24).

    If every value lies within 7..30 (with a 1e-6 tolerance) the input is
    taken to be world coordinates and shifted by -6; otherwise the values are
    kept. In both cases the x and y columns are finally clamped to 1..24.
    The input array is not modified.
    """
    lowest = float(arr.min())
    highest = float(arr.max())
    disk = arr.astype(float).copy()

    looks_like_world = (lowest >= 7 - 1e-6) and (highest <= 30 + 1e-6)
    if looks_like_world:
        disk = disk - 6.0  # 7..30 -> 1..24

    # Safety clamp on both coordinate columns.
    for axis in (0, 1):
        disk[..., axis] = np.clip(disk[..., axis], 1, 24)
    return disk

def move_one_step(p: np.ndarray) -> np.ndarray:
    """Move one cell in a random one of the eight neighbouring directions.

    The position is rounded to integers first. The directions are tried in
    random order and the first one that stays inside VIEW_MIN..VIEW_MAX on
    both axes is taken; if none fits, the rounded position is returned.
    """
    x0, y0 = (int(round(value)) for value in p.tolist())

    order = list(DIR8)
    random.shuffle(order)

    for dx, dy in order:
        target = (x0 + dx, y0 + dy)
        if all(VIEW_MIN <= coord <= VIEW_MAX for coord in target):
            return np.array(target, dtype=float)
    return np.array([x0, y0], dtype=float)

def write_numpy_block(arr2d: np.ndarray, out_path: Path):
    """Write ``arr2d`` to ``out_path`` as NumPy's space-separated text form plus a trailing newline.

    Creates the parent folder if needed. Side effect: sets NumPy's global
    print options (no summarisation, line width 200) so large arrays are
    written in full.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    np.set_printoptions(threshold=np.inf, linewidth=200)
    rendered = np.array2string(arr2d, separator=' ')
    out_path.write_text(rendered + '\n', encoding='utf-8')

def derive_out_dir(src_path: Path) -> Path:
    """Return the output folder for one source TXT file.

    Result: TRAIN_ROOT/<mode>_plot/(case sub-folder)/step_plot_<stem>, where
    <mode> is the first path component of ``src_path`` relative to DATA_ROOT,
    lower-cased. 'intermediate' also keeps the second component (e.g. 'case1');
    modes other than logic / random / rulemap / intermediate go to 'misc_plot'.
    """
    parts = src_path.relative_to(DATA_ROOT).parts  # e.g. ('logic', 'sim1_case1_1.txt')
    mode = parts[0].lower()

    if mode == 'intermediate' and len(parts) >= 2:
        plot_root = TRAIN_ROOT / 'intermediate_plot' / parts[1]  # 'case1' or 'case2'
    elif mode in ('logic', 'random', 'rulemap'):
        plot_root = TRAIN_ROOT / f'{mode}_plot'
    else:
        plot_root = TRAIN_ROOT / 'misc_plot'

    return plot_root / f"step_plot_{src_path.stem}"

# ========= logic 標籤（數字）處理 =========
def get_parity(case_type: int) -> Dict[int, str]:
    """
    Return the parity ('odd' / 'even') for agent indices 0, 1, 2 (S / A / B).

    Any positive case_type is accepted: values wrap with period 36, and each
    block of twelve selects one group. Exactly one index is 'even' per group:
      group 1 (1..12):  index 2
      group 2 (13..24): index 1
      group 3 (25..36): index 0
    A case_type <= 0 is treated as 1.
    """
    if case_type <= 0:
        case_type = 1

    group = ((case_type - 1) % 36) // 12   # 0, 1 or 2
    even_index = 2 - group
    return {idx: ('even' if idx == even_index else 'odd') for idx in range(3)}


def stem_case_type(stem: str) -> Optional[int]:
    """Return the case number embedded as "_case<digits>_" in a file stem (e.g. "sim1_case1_1" -> 1), or None."""
    found = re.search(r'_case(\d+)_', stem)
    if found is None:
        return None
    return int(found.group(1))

def choose_digits_for_step(stem: str, step_idx: int, parity: Dict[int,str]) -> List[str]:
    """Deterministically pick three digits for one step.

    The RNG is seeded with the first 8 hex digits of sha1("<stem>:<step_idx>");
    for each index 0..2 a digit is drawn from ODD when parity[i] == 'odd',
    otherwise from EVEN.
    """
    digest = hashlib.sha1(f"{stem}:{step_idx}".encode()).hexdigest()
    rng = random.Random(int(digest[:8], 16))
    return [rng.choice(ODD if parity[i] == 'odd' else EVEN) for i in range(3)]

def logic_plot_dir_candidates(src_txt: Path) -> List[Path]:
    """Return the folders that may hold this file's logic plot, TRAIN_ROOT first and DATA_ROOT second."""
    folder = f"step_plot_{src_txt.stem}"
    return [root / 'logic_plot' / folder for root in (TRAIN_ROOT, DATA_ROOT)]

def load_logic_labels(src_txt: Path, n_steps: int) -> List[str]:
    """
    Return the three digits shown on the last step (index n_steps - 1) of a logic trial.

    1) If a labels.json exists in one of the candidate step_plot folders and
       contains the last step, that entry is returned.
    2) Otherwise the labels of every step are generated deterministically from
       the case type found in the file stem, written to labels.json in the
       first candidate folder (under TRAIN_ROOT), and the last step is returned.
    A stem without a case number yields the default ['1','2','3'] and nothing is written.
    """
    stem = src_txt.stem
    last_key = str(n_steps - 1)
    candidates = logic_plot_dir_candidates(src_txt)

    # Look for an existing labels.json first.
    for plot_dir in candidates:
        labels_file = plot_dir / 'labels.json'
        if not labels_file.exists():
            continue
        with open(labels_file, 'r', encoding='utf-8') as fh:
            stored = json.load(fh).get('step_labels', {})
        if last_key in stored:
            return stored[last_key]
        # Last step missing: keep looking, then fall back to generation.

    case = stem_case_type(stem)
    if case is None:
        return ['1','2','3']  # not a logic-style name: default labels

    parity = get_parity(case)
    generated = {}
    for step in range(n_steps):
        generated[str(step)] = choose_digits_for_step(stem, step, parity)

    target_dir = candidates[0]  # TRAIN_ROOT takes priority
    target_dir.mkdir(parents=True, exist_ok=True)
    payload = {'case_type': case, 'parity': parity, 'step_labels': generated}
    with open(target_dir / 'labels.json', 'w', encoding='utf-8') as fh:
        json.dump(payload, fh, ensure_ascii=False, indent=2)
    return generated[last_key]

# ========= 繪圖（支援外部 labels）=========
def draw_triplet_png(triplet_disk: np.ndarray, out_png: Path, labels: Optional[List[str]] = None):
    """Draw the given (x, y) points as labelled squares on the VIEW_MIN..VIEW_MAX grid and save a PNG.

    Each point is rounded to its cell; points outside the grid are not drawn.
    A point is captioned with labels[i] when ``labels`` provides one,
    otherwise with its 1-based index. The parent folder of ``out_png`` is
    created if needed.
    """
    points = np.asarray(triplet_disk, dtype=float)

    fig, ax = plt.subplots(figsize=(6,6))
    ax.set_facecolor('white')
    ax.set_xlim(VIEW_MIN-.5, VIEW_MAX+.5)
    ax.set_ylim(VIEW_MAX+.5, VIEW_MIN-.5)   # inverted: y grows downwards
    ax.set_aspect('equal')
    ax.set_xticks(np.arange(VIEW_MIN-.5, VIEW_MAX+1.5, 1))
    ax.set_yticks(np.arange(VIEW_MIN-.5, VIEW_MAX+1.5, 1))
    ax.grid(True, lw=.8, color='lightgrey')
    ax.tick_params(length=0, labelleft=False, labelbottom=False)

    for idx, (x, y) in enumerate(points):
        col, row = int(round(x)), int(round(y))
        on_grid = VIEW_MIN <= col <= VIEW_MAX and VIEW_MIN <= row <= VIEW_MAX
        if not on_grid:
            continue
        square = plt.Rectangle(
            (col-.5, row-.5), 1, 1,
            facecolor=COLORS[idx % len(COLORS)],
            edgecolor='black', lw=1.2, alpha=0.9,
        )
        ax.add_patch(square)
        if labels and idx < len(labels):
            caption = labels[idx]
        else:
            caption = str(idx+1)
        ax.text(col, row, caption, ha='center', va='center',
                color='white', fontsize=14, weight='bold')

    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(out_png, dpi=GRID_DPI, bbox_inches='tight')
    plt.close(fig)

# ========= 選項產生 =========
# ========= 選項產生（新規則：從「前一個三點」出發）=========
def random_step_triplet(prev3_disk: np.ndarray) -> np.ndarray:
    """
    Return a copy of ``prev3_disk`` in which each of the first three rows has
    taken one random step via move_one_step (eight directions, staying on the
    grid). The input is not modified.
    """
    stepped = np.array(prev3_disk, dtype=float, copy=True)
    for agent in range(3):
        stepped[agent] = move_one_step(stepped[agent])
    return stepped

def triplet_key(arr: np.ndarray) -> tuple:
    """Return a hashable key for duplicate detection: the values rounded to the nearest integer and flattened into a tuple of ints."""
    rounded = np.rint(arr).astype(int)
    return tuple(int(value) for value in rounded.ravel())

def make_choices(prev3_disk: np.ndarray, gold3_disk: np.ndarray,
                 num_distractors: int = 3, max_retry: int = 100) -> list[np.ndarray]:
    """
    Build the answer options: the correct triplet followed by distractors.

    The first element is a copy of ``gold3_disk``. Distractors are produced by
    letting every agent of ``prev3_disk`` take one random step; a candidate is
    kept only if its rounded key differs from the gold triplet and from every
    distractor kept so far. After ``max_retry`` attempts, any missing
    distractor is made by moving a single randomly chosen agent of the gold
    triplet by one step, under the same uniqueness rule.

    Returns [gold, d1, d2, ...] with 1 + num_distractors entries.
    """
    gold = np.asarray(gold3_disk, dtype=float)
    wanted = 1 + num_distractors
    options = [gold.copy()]
    used_keys = {triplet_key(gold)}

    def _keep_if_new(candidate):
        """Append the candidate unless an option with the same key already exists."""
        key = triplet_key(candidate)
        if key in used_keys:
            return
        used_keys.add(key)
        options.append(candidate)

    # Main path: random steps from the previous triplet.
    attempts = 0
    while len(options) < wanted and attempts < max_retry:
        _keep_if_new(random_step_triplet(prev3_disk))
        attempts += 1

    # Fallback: perturb one agent of the gold triplet.
    while len(options) < wanted:
        perturbed = gold.copy()
        who = random.randrange(3)
        perturbed[who] = move_one_step(perturbed[who])
        _keep_if_new(perturbed)

    return options  # [gold, d1, d2, d3]

# ========= 主流程 =========
# ========= 主流程（改為用「前一個三點」產生干擾）=========
def process_one_file(src_txt: Path):
    """
    Turn one source trajectory file into a four-choice question on disk.

    Reads the coordinate rows of `src_txt`, treats the last three rows as the gold
    answer and the three rows before them as the previous step, and writes into the
    directory returned by `derive_out_dir(src_txt)`:
      * trail.txt            - every row except the last three (written as read,
                               without coordinate conversion)
      * <stem>_choice.txt    - the shuffled options labelled A..D
      * label.json           - the correct index/letter plus the gold and previous triplets
      * A.png .. D.png       - only when MAKE_PNG is true
    Files with fewer than six rows are skipped with a message. Returns None.
    """
    coords = read_floats_after_maze(src_txt)  # (K,2)
    if coords.shape[0] < 6:
        print(f"SKIP (少於兩步=6點): {src_txt}")
        return

    # Step count, and the last / second-to-last triplets (both converted to 1..24)
    N = 3
    total_steps = coords.shape[0] // N
    gold3      = coords[-3:].copy()
    prev3      = coords[-6:-3].copy()
    gold3_disk = world_to_disk_if_needed(gold3)
    prev3_disk = world_to_disk_if_needed(prev3)

    # trail (everything except the last three rows)
    trail = coords[:-3].copy()
    out_dir = derive_out_dir(src_txt)
    out_dir.mkdir(parents=True, exist_ok=True)
    write_numpy_block(trail, out_dir / "trail.txt")

    # logic mode: fetch the three digit labels of the last step
    # (load_logic_labels generates and stores them when no labels.json has them)
    is_logic = src_txt.parent.name.lower() == 'logic'
    last_labels = None
    if is_logic:
        try:
            last_labels = load_logic_labels(src_txt, n_steps=total_steps)
        except Exception as e:
            print(f"WARN: 讀 logic 標籤失敗，改用 1/2/3 ：{src_txt} :: {e}")
            last_labels = ['1','2','3']

    # Build 1 gold + 3 distractors (distractors step from prev3), then shuffle into A/B/C/D
    choice_arrays = make_choices(prev3_disk, gold3_disk)  # [gold, d1, d2, d3]
    idxs = list(range(4))
    random.shuffle(idxs)
    letter_map = ['A','B','C','D']
    # Letter at position i shows choice_arrays[idxs[i]]
    shuffled = [(letter_map[i], choice_arrays[j]) for i, j in enumerate(idxs)]
    correct_idx = idxs.index(0)  # original index 0 is the gold answer
    correct_letter = letter_map[correct_idx]

    # Write the options as text (without revealing which one is gold)
    choice_txt = out_dir / f"{src_txt.stem}_choice.txt"
    with open(choice_txt, 'w', encoding='utf-8') as f:
        for lbl, arr in shuffled:
            f.write(f"Choice {lbl}:\n")
            f.write(np.array2string(arr.astype(float), separator=' ') + "\n\n")

    # Write the answer (label.json); the gold and previous triplets are stored too for checking
    meta = {
        "source_txt": str(src_txt),
        "answer_idx": int(correct_idx),
        "answer_letter": correct_letter,
        "gold_triplet_disk": gold3_disk.tolist(),
        "prev_triplet_disk": prev3_disk.tolist(),
        "is_logic": bool(is_logic),
    }
    if is_logic:
        meta["digits_last_step"] = last_labels  # e.g. ["3","6","1"]
    with open(out_dir / "label.json", 'w', encoding='utf-8') as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)

    # Draw only when PNGs are requested; logic mode uses the last-step digits, other modes "1/2/3"
    if MAKE_PNG:
        for lbl, arr in shuffled:
            png_path = out_dir / f"{lbl}.png"
            draw_triplet_png(arr, png_path,
                             labels=(last_labels if is_logic else ['1','2','3']))

    # Console summary for quick manual verification
    print(f"OK: {src_txt.name:>24} → {out_dir} | ans={correct_letter} ({correct_idx})")
    print("prev3 =", np.array2string(prev3_disk, separator=' '),
          "gold3 =", np.array2string(gold3_disk, separator=' '))


def main():
    """Seed the RNGs, collect every non-FINAL .txt file under SRC_FOLDERS and run
    `process_one_file` on each (sorted by path). A failure in one file is printed
    and does not stop the remaining files."""
    if RANDOM_SEED is None:
        random.seed()
    else:
        random.seed(RANDOM_SEED)
        np.random.seed(RANDOM_SEED)

    files = []
    for folder in SRC_FOLDERS:
        pattern = str(folder/"**/*.txt")
        for hit in glob.glob(pattern, recursive=True):
            # Files whose name contains "FINAL" are not used as sources.
            if "FINAL" in os.path.basename(hit):
                continue
            files.append(Path(hit))

    if not files:
        print("找不到任何 .txt 檔")
        return

    for p in sorted(files):
        try:
            process_one_file(p)
        except Exception as e:
            print(f"ERR: {p} :: {e}")


if __name__ == "__main__":
    main()


OK:               sim1_1.txt → /Users/Jer_ry/Desktop/scripts/training_data/sim1/intermediate_plot/case1/step_plot_sim1_1 | ans=A (0)
prev3 = [[11. 14.]
 [ 2.  3.]
 [ 1. 24.]] gold3 = [[12. 13.]
 [ 2.  3.]
 [ 1. 23.]]
OK:              sim1_10.txt → /Users/Jer_ry/Desktop/scripts/training_data/sim1/intermediate_plot/case1/step_plot_sim1_10 | ans=C (2)
prev3 = [[11. 17.]
 [ 4.  7.]
 [ 3. 14.]] gold3 = [[10. 18.]
 [ 4.  7.]
 [ 4. 14.]]
OK:             sim1_100.txt → /Users/Jer_ry/Desktop/scripts/training_data/sim1/intermediate_plot/case1/step_plot_sim1_100 | ans=B (1)
prev3 = [[15.  8.]
 [10. 20.]
 [20.  1.]] gold3 = [[15.  8.]
 [11. 20.]
 [19.  1.]]
OK:            sim1_1000.txt → /Users/Jer_ry/Desktop/scripts/training_data/sim1/intermediate_plot/case1/step_plot_sim1_1000 | ans=D (3)
prev3 = [[ 3.  5.]
 [ 1.  8.]
 [ 5. 12.]] gold3 = [[ 2.  4.]
 [ 1.  7.]
 [ 5. 12.]]
OK:            sim1_1001.txt → /Users/Jer_ry/Desktop/scripts/training_data/sim1/intermediate_plot/case1/step_plot_sim1_1001 | 

### 生成訓練用 CSV

In [156]:
# -*- coding: utf-8 -*-
import os, re, glob, json, csv
from pathlib import Path
from typing import List, Dict, Tuple, Optional
import numpy as np

# ========== Base paths ==========
PROJECT_ROOT = Path("/Users/Jer_ry/Desktop/scripts")
DATA_ROOT    = PROJECT_ROOT / "data" / "sim1"
# NOTE(review): `Root` is not assigned anywhere in this cell. It presumably comes from
# notebook kernel state (another cell sets Root = "training_data" etc.); in a fresh
# kernel this line raises NameError. Confirm the intended split before running.
TRAIN_ROOT   = PROJECT_ROOT / Root / "sim1"
CSV_ROOT     = TRAIN_ROOT / "csv"
# Side effect at cell execution: the CSV output directory is created immediately.
CSV_ROOT.mkdir(parents=True, exist_ok=True)

# Modes to scan (one entry per plot folder under TRAIN_ROOT)
MODES = [
    # name,                   plot_dir (under training_data),           original data dir (under data)
    ("logic",                TRAIN_ROOT / "logic_plot",                DATA_ROOT / "logic"),
    ("rulemap",              TRAIN_ROOT / "rulemap_plot",              DATA_ROOT / "rulemap"),
    ("intermediate_case1",   TRAIN_ROOT / "intermediate_plot" / "case1", DATA_ROOT / "intermediate" / "case1"),
    ("intermediate_case2",   TRAIN_ROOT / "intermediate_plot" / "case2", DATA_ROOT / "intermediate" / "case2"),
    ("random",               TRAIN_ROOT / "random_plot",               DATA_ROOT / "random"),
]

# Bounds of the disk grid (1..24 on both axes)
VIEW_MIN, VIEW_MAX = 1, 24

# ========== Small utilities ==========
# Matches an optionally negative integer or plain decimal (no exponent notation).
num_pat = re.compile(r'-?\d+(?:\.\d+)?')

def read_float_array(txt_path: Path) -> np.ndarray:
    """Parse every number found in the file (via `num_pat`) into a (K, 2) float array.

    Raises:
        ValueError: if the file contains no numbers or an odd count of numbers.
    """
    with open(txt_path, encoding='utf-8') as f:
        text = f.read()

    values = [float(tok) for tok in num_pat.findall(text)]
    count = len(values)
    if count == 0 or count % 2 != 0:
        raise ValueError(f"{txt_path} 無法解析出偶數個數字")
    return np.asarray(values, dtype=float).reshape(-1, 2)

def world_to_disk_if_needed(arr: np.ndarray) -> np.ndarray:
    """Shift `arr` by -6 when all of its values lie in 7..30 (treated as world
    coordinates); otherwise return `arr` itself, unchanged and uncopied.

    The decision is a range heuristic over the whole array: an array that is already
    in 1..24 but happens to contain only values >= 7 is also shifted.
    """
    eps = 1e-6
    lowest = float(arr.min())
    highest = float(arr.max())
    looks_like_world = (lowest >= 7 - eps) and (highest <= 30 + eps)
    if looks_like_world:
        return arr - 6.0
    return arr

def disk_to_cell_ids(arr_3n2: np.ndarray) -> List[int]:
    """Convert (x, y) rows into cell ids `(y - 1) * 24 + (x - 1)`, keeping row order.

    Each coordinate is rounded to an integer and then clamped into 1..24, so
    out-of-grid values (e.g. 0 or 25) still map to a valid id in 0..575.
    """
    def _snap(value) -> int:
        # Round first, then clamp to the grid.
        return min(24, max(1, int(round(value))))

    ids = []
    for x, y in arr_3n2:
        col, row = _snap(x), _snap(y)
        ids.append(int((row - 1) * 24 + (col - 1)))
    return ids

def parse_choices(choice_txt_path: Path) -> Dict[str, np.ndarray]:
    """
    Read a *_choice.txt file and return {'A': arr, 'B': arr, 'C': arr, 'D': arr}.

    Each option is expected as a "Choice <letter>:" header followed by a
    double-bracketed block, e.g.
        Choice A:
        [[x y]
         [x y]
         [x y]]
    The numbers of each block are reshaped to (-1, 2); coordinates are returned as
    written in the file.

    Raises:
        ValueError: if fewer than four options could be parsed.
    """
    with open(choice_txt_path, encoding='utf-8') as f:
        content = f.read()

    parsed = {}
    for label in 'ABCD':
        pattern = rf'Choice\s+{label}\s*:\s*(\[\[.*?\]\])'
        hit = re.search(pattern, content, flags=re.S)
        if hit is None:
            continue
        numbers = [float(tok) for tok in num_pat.findall(hit.group(1))]
        parsed[label] = np.asarray(numbers, dtype=float).reshape(-1, 2)

    if len(parsed) != 4:
        raise ValueError(f"{choice_txt_path} 解析選項失敗，找到 {len(parsed)} 個")
    return parsed

def load_answer_json(step_dir: Path) -> Optional[str]:
    """Return the stored correct letter ('A'..'D') for a step_plot_* directory, or None.

    Two files are consulted, in this order:
      1. answer.json  - key 'correct'
      2. label.json   - key 'answer_letter' (the file and key that process_one_file
                        in this notebook writes next to the choice file)
    A file that is missing, unreadable, not valid JSON, not a JSON object, or that
    holds anything other than 'A'..'D' under its key is skipped.
    """
    candidates = (("answer.json", "correct"), ("label.json", "answer_letter"))
    for file_name, key in candidates:
        path = step_dir / file_name
        if not path.exists():
            continue
        try:
            with open(path, "r", encoding="utf-8") as f:
                payload = json.load(f)
        except (OSError, ValueError):
            # Unreadable or malformed (JSONDecodeError / UnicodeDecodeError are
            # ValueErrors): fall through to the next candidate.
            continue
        if not isinstance(payload, dict):
            continue
        letter = payload.get(key)
        if isinstance(letter, str) and letter in ("A", "B", "C", "D"):
            return letter
    return None

def infer_original_txt_path(mode_name: str, orig_dir: Path, step_dir: Path) -> Optional[Path]:
    """
    Map a step_plot_<stem> directory back to `<orig_dir>/<stem>.txt`.

    e.g. step_plot_sim1_case13_1 -> <orig_dir>/sim1_case13_1.txt
         step_plot_sim1_1        -> <orig_dir>/sim1_1.txt

    `mode_name` is accepted but not used. Returns None when that file does not exist.
    """
    stem = step_dir.name.replace("step_plot_", "")
    candidate = orig_dir / (stem + ".txt")
    if candidate.exists():
        return candidate
    return None

def get_last3_from_original(orig_txt: Path) -> np.ndarray:
    """Return the last three (x, y) rows parsed from `orig_txt`, passed through
    `world_to_disk_if_needed`.

    `read_float_array` parses every number in the file, so the "last three rows" are
    the final six numbers found anywhere in it.
    """
    all_points = read_float_array(orig_txt)
    tail = all_points[-3:].copy()
    return world_to_disk_if_needed(tail)

def np_equal_rounded(a: np.ndarray, b: np.ndarray) -> bool:
    """Return True when `a` and `b` have the same shape and are element-wise equal
    after rounding both to the nearest integer."""
    left = np.rint(a).astype(int)
    right = np.rint(b).astype(int)
    return np.array_equal(left, right)

def ensure_list_jsonable(x):
    """Serialize `x` to a JSON string, keeping non-ASCII characters unescaped."""
    serialized = json.dumps(x, ensure_ascii=False)
    return serialized

# ========== 主流程：對每個模式各寫 2 份 CSV ==========
def _choices_to_disk(choices: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Return the parsed choices expressed in disk (1..24) coordinates.

    The per-array range heuristic of `world_to_disk_if_needed` must not be applied to
    a single choice: a disk triplet whose six values all happen to lie in 7..24 would
    be shifted by -6. The choices of one file are therefore judged together and are
    treated as world coordinates only when every value is in 7..30 AND at least one
    value exceeds VIEW_MAX, which is impossible on the disk grid.
    """
    flat = np.concatenate([np.asarray(a, dtype=float).ravel() for a in choices.values()])
    if flat.size == 0:
        return dict(choices)
    lo, hi = float(flat.min()), float(flat.max())
    is_world = lo >= 7 - 1e-6 and VIEW_MAX + 1e-6 < hi <= 30 + 1e-6
    if not is_world:
        return dict(choices)
    return {k: np.asarray(a, dtype=float) - 6.0 for k, a in choices.items()}


def process_mode(mode_name: str, plot_dir: Path, orig_dir: Path):
    """Scan `plot_dir` for step_plot_* directories and write two CSVs under CSV_ROOT.

    * <mode_name>_train.csv   - one row per trial: the trail as xy pairs and cell ids
    * <mode_name>_answers.csv - one row per trial: the correct letter plus the four
                                choices as xy pairs and cell ids

    Args:
        mode_name: name used in the 'mode' column and in the CSV file names.
        plot_dir:  directory holding the step_plot_<stem> folders.
        orig_dir:  directory holding the original <stem>.txt files; used only to
                   recover the answer when no answer file is stored.

    A directory is skipped when trail.txt or the choice file is missing, or when the
    trail length is not a positive multiple of 3. Errors inside one directory are
    printed and do not stop the scan.

    NOTE(review): the sequence row is appended before the answer is resolved, so a
    trial whose answer cannot be determined (or whose choice file fails to parse)
    appears in the train CSV but not in the answers CSV.
    """
    seq_rows: List[Dict] = []
    ans_rows: List[Dict] = []

    # Find every step_plot_* directory
    step_dirs = sorted([Path(p) for p in glob.glob(str(plot_dir / "step_plot_*")) if os.path.isdir(p)])
    if not step_dirs:
        print(f"[{mode_name}] 找不到任何 step_plot_* 目錄：{plot_dir}")
        return

    for sd in step_dirs:
        trial_id = sd.name.replace("step_plot_", "")  # the original <stem>
        trail = sd / "trail.txt"
        # Locate the choice file (tolerates different file names)
        choice_txts = list(sd.glob("*choice*.txt"))
        choice_txt = choice_txts[0] if choice_txts else None
        if not trail.exists() or not choice_txt:
            # No trail or no choices: skip this directory
            continue

        try:
            # --- 1) Read the trail, convert to disk coordinates and cell ids ---
            trail_arr = read_float_array(trail)       # (K,2)
            trail_disk = world_to_disk_if_needed(trail_arr)
            # Must hold a whole number of three-agent steps
            if (len(trail_disk) % 3) != 0 or len(trail_disk) == 0:
                continue
            T = len(trail_disk) // 3
            cell_seq = disk_to_cell_ids(trail_disk)

            # case_type: the number in "_case<N>_" for logic, "1"/"2" for the
            # intermediate modes, otherwise empty
            case_match = re.search(r'_case(\d+)_', sd.name) if mode_name == "logic" else None
            if case_match:
                case_type = case_match.group(1)
            elif "intermediate_case1" in mode_name:
                case_type = "1"
            elif "intermediate_case2" in mode_name:
                case_type = "2"
            else:
                case_type = ""

            # One sequence row per trial
            seq_rows.append({
                "trial_id": trial_id,
                "mode": mode_name,
                "case_type": case_type,
                "seq_len_T": T,
                "xy_json": ensure_list_jsonable(np.rint(trail_disk).astype(int).tolist()),
                "cell_json": ensure_list_jsonable(cell_seq),
            })

            # --- 2) Read the choices and the correct answer ---
            # Converted once for the whole file; see _choices_to_disk for why the
            # per-choice heuristic is not used here.
            choices = _choices_to_disk(parse_choices(choice_txt))
            answer = load_answer_json(sd)  # stored answer first
            if answer is None:
                # No stored answer: recover it from the original file's last three points
                orig_txt = infer_original_txt_path(mode_name, orig_dir, sd)
                if orig_txt is None:
                    # Original file not found: no answer row for this trial
                    continue
                last3_disk = get_last3_from_original(orig_txt)
                # The choice equal to the last triplet is the correct one
                found = None
                for k in ['A', 'B', 'C', 'D']:
                    if k in choices and np_equal_rounded(choices[k], last3_disk):
                        found = k
                        break
                if found is None:
                    # Still nothing: skip rather than write a wrong label
                    continue
                answer = found

            # One answer row per trial
            row_ans = {
                "trial_id": trial_id,
                "mode": mode_name,
                "correct": answer
            }
            for k in ['A','B','C','D']:
                arr = choices[k]
                row_ans[f"choice{k}_xy"]   = ensure_list_jsonable(np.rint(arr).astype(int).tolist())
                row_ans[f"choice{k}_cell"] = ensure_list_jsonable(disk_to_cell_ids(arr))
            ans_rows.append(row_ans)

        except Exception as e:
            print(f"[{mode_name}] ERR {sd}: {e}")
            continue

    # --- Write the CSVs ---
    if seq_rows:
        seq_csv = CSV_ROOT / f"{mode_name}_train.csv"
        with open(seq_csv, "w", newline="", encoding="utf-8") as f:
            w = csv.DictWriter(f, fieldnames=list(seq_rows[0].keys()))
            w.writeheader(); w.writerows(seq_rows)
        print(f"✓ wrote {seq_csv} ({len(seq_rows)} rows)")
    else:
        print(f"✗ no sequence rows for {mode_name}")

    if ans_rows:
        ans_csv = CSV_ROOT / f"{mode_name}_answers.csv"
        with open(ans_csv, "w", newline="", encoding="utf-8") as f:
            w = csv.DictWriter(f, fieldnames=list(ans_rows[0].keys()))
            w.writeheader(); w.writerows(ans_rows)
        print(f"✓ wrote {ans_csv} ({len(ans_rows)} rows)")
    else:
        print(f"✗ no answer rows for {mode_name}")

def main():
    """Run `process_mode` once for every (name, plot_dir, orig_dir) entry of MODES."""
    for mode_entry in MODES:
        name, plot_dir, orig_dir = mode_entry
        process_mode(name, plot_dir, orig_dir)


if __name__ == "__main__":
    main()


✓ wrote /Users/Jer_ry/Desktop/scripts/training_data/sim1/csv/logic_train.csv (2000 rows)
✓ wrote /Users/Jer_ry/Desktop/scripts/training_data/sim1/csv/logic_answers.csv (2000 rows)
✓ wrote /Users/Jer_ry/Desktop/scripts/training_data/sim1/csv/rulemap_train.csv (2000 rows)
✓ wrote /Users/Jer_ry/Desktop/scripts/training_data/sim1/csv/rulemap_answers.csv (2000 rows)
✓ wrote /Users/Jer_ry/Desktop/scripts/training_data/sim1/csv/intermediate_case1_train.csv (2000 rows)
✓ wrote /Users/Jer_ry/Desktop/scripts/training_data/sim1/csv/intermediate_case1_answers.csv (2000 rows)
✓ wrote /Users/Jer_ry/Desktop/scripts/training_data/sim1/csv/intermediate_case2_train.csv (2500 rows)
✓ wrote /Users/Jer_ry/Desktop/scripts/training_data/sim1/csv/intermediate_case2_answers.csv (2500 rows)
✓ wrote /Users/Jer_ry/Desktop/scripts/training_data/sim1/csv/random_train.csv (2000 rows)
✓ wrote /Users/Jer_ry/Desktop/scripts/training_data/sim1/csv/random_answers.csv (2000 rows)


In [105]:
from pathlib import Path
import pandas as pd, json, numpy as np

# First id outside the valid cell-id range 0..575 that the checks below enforce.
# Not referenced by any function in this cell.
PAD_ID = 576

def cid_to_xy(cid):
    """Inverse of `xy_to_cid`: split a cell id into 1-based (x, y) on a 24-wide grid."""
    row, col = divmod(cid, 24)
    return col + 1, row + 1

def xy_to_cid(x, y):
    """Map 1-based (x, y) on a 24-wide grid to the row-major cell id."""
    row_offset = (y - 1) * 24
    return row_offset + (x - 1)

def check_train_csv(csv_path: Path):
    """Validate a *_train.csv file; raises AssertionError on the first violation.

    Checks per row: cell_json has 3*T entries, xy_json has shape (3*T, 2), all cell
    ids are in 0..575, all x/y are in 1..24, and xy_json maps to exactly cell_json.
    """
    df = pd.read_csv(csv_path)
    required_cols = {"trial_id","mode","seq_len_T","xy_json","cell_json"}
    assert required_cols.issubset(df.columns)

    for i, row in df.iterrows():
        cells = json.loads(row["cell_json"])
        xys = np.array(json.loads(row["xy_json"]), dtype=int)
        n_points = 3 * int(row["seq_len_T"])

        assert len(cells) == n_points, f"{csv_path}:{i} cell_json len != 3*T"
        assert xys.shape == (n_points, 2), f"{csv_path}:{i} xy_json shape error"
        assert all(0 <= c <= 575 for c in cells), f"{csv_path}:{i} cid out of range"

        xs, ys = xys[:, 0], xys[:, 1]
        assert ((xs >= 1) & (xs <= 24)).all(), f"{csv_path}:{i} x out of range"
        assert ((ys >= 1) & (ys <= 24)).all(), f"{csv_path}:{i} y out of range"

        # Consistency: the xy pairs must reproduce the stored cell ids
        recomputed = [xy_to_cid(int(x), int(y)) for x, y in xys]
        assert recomputed == cells, f"{csv_path}:{i} xy_json <-> cell_json mismatch"

def check_answers_csv(csv_path: Path):
    """Validate a *_answers.csv file; raises AssertionError on the first violation.

    Checks per row: 'correct' is one of A..D and, for each choice, the cell list has
    3 entries in 0..575, the xy list has shape (3, 2) with values in 1..24, and the
    xy pairs map to exactly the stored cell ids.
    """
    df = pd.read_csv(csv_path)
    required_cols = {"trial_id","mode","correct"}
    assert required_cols.issubset(df.columns)

    for i, row in df.iterrows():
        assert row["correct"] in ["A","B","C","D"], f"{csv_path}:{i} correct not in ABCD"
        for ch in "ABCD":
            cc = json.loads(row[f"choice{ch}_cell"])
            cxy = np.array(json.loads(row[f"choice{ch}_xy"]), dtype=int)

            assert len(cc) == 3, f"{csv_path}:{i} choice{ch}_cell len != 3"
            assert cxy.shape == (3,2), f"{csv_path}:{i} choice{ch}_xy shape error"
            assert all(0 <= c <= 575 for c in cc), f"{csv_path}:{i} choice{ch} cid out of range"

            xs, ys = cxy[:, 0], cxy[:, 1]
            assert ((xs >= 1) & (xs <= 24)).all(), f"{csv_path}:{i} choice{ch} x out of range"
            assert ((ys >= 1) & (ys <= 24)).all(), f"{csv_path}:{i} choice{ch} y out of range"

            recomputed = [xy_to_cid(int(x), int(y)) for x, y in cxy]
            assert recomputed == cc, f"{csv_path}:{i} choice{ch} xy <-> cell mismatch"


# Mode names whose <name>_train.csv / <name>_answers.csv files `keys` inspects.
NAMES = ["rulemap","random","logic","intermediate_case1","intermediate_case2"]

def keys(root_dir: Path) -> set[str]:
    """Collect "<mode>|<trial_id>" keys from every existing <mode>_train.csv and
    <mode>_answers.csv under `root_dir` (modes taken from NAMES). Missing files are
    ignored."""
    found = set()
    for n in NAMES:
        for suffix in ("_train.csv", "_answers.csv"):
            csv_file = root_dir / f"{n}{suffix}"
            if csv_file.exists():
                trial_ids = pd.read_csv(csv_file, usecols=["trial_id"])["trial_id"].astype(str)
                # (mode|trial_id) is the unique key of a trial
                found.update(f"{n}|{t}" for t in trial_ids)
    return found

def check_no_leakage(train_dir: Path, val_dir: Path, test_dir: Path):
    """Assert that no "<mode>|<trial_id>" key (see `keys`) occurs in more than one of
    the three split directories. Pairs are checked in the order train/val,
    train/test, val/test."""
    train_keys, val_keys, test_keys = (keys(d) for d in (train_dir, val_dir, test_dir))
    assert train_keys.isdisjoint(val_keys), "train/val trial overlap"
    assert train_keys.isdisjoint(test_keys), "train/test trial overlap"
    assert val_keys.isdisjoint(test_keys), "val/test trial overlap"

def run_all():
    """Validate the train/answers CSVs of every mode in each split directory, then
    assert that no trial is shared between splits."""
    roots = dict(
        train=Path("/Users/Jer_ry/Desktop/scripts/training_data/sim1/csv"),
        val=Path("/Users/Jer_ry/Desktop/scripts/validation_data/sim1/csv"),
        test=Path("/Users/Jer_ry/Desktop/scripts/testing_data/sim1/csv"),
    )
    names = ("rulemap", "random", "logic", "intermediate_case1", "intermediate_case2")

    for split in ("train", "val", "test"):
        root = roots[split]
        for n in names:
            check_train_csv(root / f"{n}_train.csv")
            check_answers_csv(root / f"{n}_answers.csv")
        print(f"✓ {split} ok:", root)

    check_no_leakage(roots["train"], roots["val"], roots["test"])
    print("✓ no leakage across splits")


if __name__ == "__main__":
    run_all()


✓ train ok: /Users/Jer_ry/Desktop/scripts/training_data/sim1/csv
✓ val ok: /Users/Jer_ry/Desktop/scripts/validation_data/sim1/csv
✓ test ok: /Users/Jer_ry/Desktop/scripts/testing_data/sim1/csv


AssertionError: train/val trial overlap

In [None]:
# -*- coding: utf-8 -*-
import os, re, glob, json, hashlib, random
from pathlib import Path
from typing import Tuple, List, Optional, Dict
import numpy as np
import matplotlib.pyplot as plt

# ========= Parameters (tunable) =========
MAKE_PNG    = False          # whether to write A.png..D.png for each question
RANDOM_SEED = None           # set to an int (e.g. 1234) for reproducible runs; None reseeds each run

# Name of the output split folder under PROJECT_ROOT.
Root = "training_data" # "validation_data" or "training_data" or "testing_data"
# ========= Paths =========
PROJECT_ROOT = Path("/Users/Jer_ry/Desktop/scripts")
# NOTE(review): DATA_ROOT does not depend on `Root`, so every split is generated from
# the same source files — confirm this is intended.
DATA_ROOT    = PROJECT_ROOT / "data" / "sim1"
TRAIN_ROOT   = PROJECT_ROOT / Root / "sim1"

# Source folders scanned (recursively) by main()
SRC_FOLDERS = [
    DATA_ROOT / "logic",
    DATA_ROOT / "intermediate" / "case1",
    DATA_ROOT / "intermediate" / "case2",
    DATA_ROOT / "random",
    DATA_ROOT / "rulemap",
]

# Bounds of the disk grid (both axes)
VIEW_MIN, VIEW_MAX = 1, 24
GRID_DPI = 150
COLORS = ['#1f77b4','#ff7f0e','#2ca02c']  # one colour per point of a triplet
# The eight neighbour offsets (every (dx, dy) in {-1,0,1}^2 except (0, 0))
DIR8 = [(dx,dy) for dx in (-1,0,1) for dy in (-1,0,1) if not (dx==0 and dy==0)]

# Digit pools used for the logic-mode labels
ODD  = ['1','3','5','7','9']
EVEN = ['0','2','4','6','8']

def read_floats_after_maze(txt_path: Path) -> np.ndarray:
    """Parse the coordinate numbers of a source file into a (K, 2) float array.

    If the first line is 'Maze:', everything up to (not including) the first later
    line that contains '[' is treated as header and ignored; otherwise the whole file
    is parsed. Numbers are optionally negative integers or plain decimals.

    Raises:
        ValueError: if no numbers, or an odd count of numbers, are found.
    """
    with open(txt_path, encoding='utf-8') as f:
        lines = f.read().splitlines()

    if lines and lines[0].strip() == 'Maze:':
        # Index of the first data line; len(lines) when the header is never followed by data.
        data_start = next(
            (i for i, ln in enumerate(lines[1:], start=1) if '[' in ln),
            len(lines),
        )
        data_lines = lines[data_start:]
    else:
        data_lines = lines

    # Join with a newline, not '': otherwise a number ending one line fuses with the
    # number starting the next ("1 2" + "3 4" would parse as 1, 23, 4).
    nums = re.findall(r'-?\d+(?:\.\d+)?', '\n'.join(data_lines))
    if len(nums) % 2 != 0 or len(nums) == 0:
        raise ValueError(f"{txt_path}: 無法解析成偶數個數字（長度={len(nums)}）")
    return np.asarray(list(map(float, nums)), dtype=float).reshape(-1, 2)

def world_to_disk_if_needed(arr: np.ndarray) -> np.ndarray:
    """Return a float copy of `arr` in disk coordinates.

    When every value of `arr` lies in 7..30 the copy is shifted by -6 (7..30 -> 1..24).
    In all cases the first two entries along the last axis are then clipped to 1..24.
    The input array is never modified.
    """
    eps = 1e-6
    lowest, highest = float(arr.min()), float(arr.max())
    disk = arr.astype(float).copy()
    if lowest >= 7 - eps and highest <= 30 + eps:
        disk = disk - 6.0
    # Safety clamp onto the grid, applied column by column
    for axis_index in (0, 1):
        disk[..., axis_index] = np.clip(disk[..., axis_index], 1, 24)
    return disk

def move_one_step(p: np.ndarray) -> np.ndarray:
    """Move the rounded point `p` to a randomly chosen one of its eight neighbours.

    The eight directions are tried in shuffled order and the first neighbour that
    stays inside VIEW_MIN..VIEW_MAX on both axes is returned; if none does, the
    rounded point itself is returned. The result is a new float array.
    """
    x0, y0 = (int(round(v)) for v in p.tolist())
    candidates = list(DIR8)
    random.shuffle(candidates)
    for dx, dy in candidates:
        nx = x0 + dx
        ny = y0 + dy
        inside = (VIEW_MIN <= nx <= VIEW_MAX) and (VIEW_MIN <= ny <= VIEW_MAX)
        if inside:
            return np.array([nx, ny], dtype=float)
    return np.array([x0, y0], dtype=float)

def write_numpy_block(arr2d: np.ndarray, out_path: Path):
    """Write `arr2d` to `out_path` as NumPy's bracketed text, followed by a newline.

    Parent directories are created as needed. The array is printed in full (no '...'
    summarisation) with a line width of 200. The print options are applied only for
    the duration of this call, so the caller's NumPy print settings are left intact.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Scoped override instead of np.set_printoptions, which would change the
    # process-wide settings permanently.
    with np.printoptions(threshold=np.inf, linewidth=200):
        text = np.array2string(arr2d, separator=' ')
    with open(out_path, 'w', encoding='utf-8') as f:
        f.write(text + '\n')

def derive_out_dir(src_path: Path) -> Path:
    """Return the output directory `<plot root>/step_plot_<stem>` for a source file.

    The plot root is chosen from the first path component of `src_path` relative to
    DATA_ROOT (lower-cased):
      intermediate -> TRAIN_ROOT/intermediate_plot/<second component>
      logic / random / rulemap -> TRAIN_ROOT/<mode>_plot
      anything else -> TRAIN_ROOT/misc_plot
    Raises ValueError (from Path.relative_to) if `src_path` is not under DATA_ROOT.
    """
    parts = src_path.relative_to(DATA_ROOT).parts  # e.g. ('logic','sim1_case1_1.txt')
    mode = parts[0].lower()

    if mode == 'intermediate' and len(parts) >= 2:
        # parts[1] is the case folder, e.g. 'case1' or 'case2'
        plot_root = TRAIN_ROOT / 'intermediate_plot' / parts[1]
    else:
        folder_by_mode = {
            'logic': 'logic_plot',
            'random': 'random_plot',
            'rulemap': 'rulemap_plot',
        }
        plot_root = TRAIN_ROOT / folder_by_mode.get(mode, 'misc_plot')

    return plot_root / f"step_plot_{src_path.stem}"

# ========= logic 標籤（數字）處理 =========
def get_parity(case_type: int) -> Dict[int, str]:
    """Return the odd/even pattern {0: ..., 1: ..., 2: ...} for a case number.

    Non-positive case numbers are treated as 1. Case numbers wrap around every 36,
    and each block of 12 within 1..36 selects one pattern:
      1..12  -> odd, odd, even
      13..24 -> odd, even, odd
      25..36 -> even, odd, odd
    """
    patterns = (
        ('odd', 'odd', 'even'),
        ('odd', 'even', 'odd'),
        ('even', 'odd', 'odd'),
    )
    effective = case_type if case_type > 0 else 1
    group = ((effective - 1) % 36) // 12  # 0, 1 or 2
    return dict(enumerate(patterns[group]))


def stem_case_type(stem: str) -> Optional[int]:
    """Extract N from the first "_case<N>_" in `stem` (e.g. "sim1_case1_1" -> 1);
    return None when the pattern is absent."""
    found = re.search(r'_case(\d+)_', stem)
    if found is None:
        return None
    return int(found.group(1))

def choose_digits_for_step(stem: str, step_idx: int, parity: Dict[int,str]) -> List[str]:
    """Deterministically pick three digit labels for one step.

    A private RNG is seeded with the first 8 hex digits of sha1("<stem>:<step_idx>");
    label i is drawn from ODD when parity[i] == 'odd', otherwise from EVEN.
    """
    digest = hashlib.sha1(f"{stem}:{step_idx}".encode()).hexdigest()
    rng = random.Random(int(digest[:8], 16))
    return [rng.choice(ODD if parity[i] == 'odd' else EVEN) for i in range(3)]

def logic_plot_dir_candidates(src_txt: Path) -> List[Path]:
    """Return the two directories that may hold labels.json for `src_txt`:
    the one under TRAIN_ROOT first, then the one under DATA_ROOT."""
    step_folder = f"step_plot_{src_txt.stem}"
    return [base / 'logic_plot' / step_folder for base in (TRAIN_ROOT, DATA_ROOT)]

def load_logic_labels(src_txt: Path, n_steps: int) -> List[str]:
    """Return the three digit labels of the last step (index n_steps - 1).

    An existing labels.json (searched in `logic_plot_dir_candidates` order) is used
    when it contains that step. Otherwise labels for all steps are generated with
    `choose_digits_for_step`, written to labels.json in the first candidate directory
    (replacing any file there), and the last step's labels are returned. If the stem
    carries no "_case<N>_" part, ['1','2','3'] is returned and nothing is written.
    """
    stem = src_txt.stem
    last_key = str(n_steps - 1)
    candidates = logic_plot_dir_candidates(src_txt)

    # 1) Prefer labels that already exist on disk
    for folder in candidates:
        labels_file = folder / 'labels.json'
        if not labels_file.exists():
            continue
        with open(labels_file, 'r', encoding='utf-8') as f:
            stored = json.load(f)
        stored_steps = stored.get('step_labels', {})
        if last_key in stored_steps:
            return stored_steps[last_key]
        # Last step missing from this file: keep looking, then fall back to generation

    # 2) Generate and store under the first candidate directory
    case = stem_case_type(stem)
    if case is None:
        return ['1','2','3']  # not a logic-style name: default labels

    parity = get_parity(case)
    step_labels = {}
    for s in range(n_steps):
        step_labels[str(s)] = choose_digits_for_step(stem, s, parity)

    out_dir = candidates[0]
    out_dir.mkdir(parents=True, exist_ok=True)
    payload = {'case_type': case, 'parity': parity, 'step_labels': step_labels}
    with open(out_dir / 'labels.json', 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
    return step_labels[last_key]

# ========= 繪圖（支援外部 labels）=========
def draw_triplet_png(triplet_disk: np.ndarray, out_png: Path, labels: Optional[List[str]] = None):
    """Render the points of `triplet_disk` as coloured, labelled grid cells and save a PNG.

    Args:
        triplet_disk: rows of (x, y); each is rounded to the nearest cell. Rows that
            fall outside VIEW_MIN..VIEW_MAX are not drawn.
        out_png: destination file; parent directories are created as needed.
        labels: text drawn in each cell; when missing (or shorter than the number of
            points) the 1-based row index is used instead.
    """
    pos = np.asarray(triplet_disk, dtype=float)

    fig, ax = plt.subplots(figsize=(6,6))
    ax.set_facecolor('white')
    ax.set_xlim(VIEW_MIN-.5, VIEW_MAX+.5)
    # y limits are given high-to-low, so the y axis is drawn inverted
    ax.set_ylim(VIEW_MAX+.5, VIEW_MIN-.5)
    ax.set_aspect('equal')
    # Ticks sit on cell borders so the grid lines outline the cells
    ax.set_xticks(np.arange(VIEW_MIN-.5, VIEW_MAX+1.5, 1))
    ax.set_yticks(np.arange(VIEW_MIN-.5, VIEW_MAX+1.5, 1))
    ax.grid(True, lw=.8, color='lightgrey')
    ax.tick_params(length=0, labelleft=False, labelbottom=False)

    for i, (x, y) in enumerate(pos):
        xi, yi = int(round(x)), int(round(y))
        if VIEW_MIN <= xi <= VIEW_MAX and VIEW_MIN <= yi <= VIEW_MAX:
            # Unit square centred on the cell; colours cycle through COLORS
            ax.add_patch(plt.Rectangle((xi-.5, yi-.5), 1, 1,
                                       facecolor=COLORS[i % len(COLORS)],
                                       edgecolor='black', lw=1.2, alpha=0.9))
            text_str = (labels[i] if labels and i < len(labels) else str(i+1))
            ax.text(xi, yi, text_str, ha='center', va='center',
                    color='white', fontsize=14, weight='bold')
    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(out_png, dpi=GRID_DPI, bbox_inches='tight')
    # Close the figure so repeated calls do not accumulate open figures
    plt.close(fig)

def random_step_triplet(prev3_disk: np.ndarray) -> np.ndarray:
    """Return a float copy of `prev3_disk` whose first three rows have each been
    advanced by one `move_one_step` call (row 0 first, then 1, then 2)."""
    moved = np.array(prev3_disk, dtype=float)  # np.array copies, input stays untouched
    for row_index in (0, 1, 2):
        moved[row_index] = move_one_step(moved[row_index])
    return moved

def triplet_key(arr: np.ndarray) -> tuple:
    """Hashable key for `arr`: all entries rounded to the nearest integer and
    flattened, in row-major order, into a tuple of Python ints."""
    as_ints = np.rint(arr).astype(int).reshape(-1)
    return tuple(int(value) for value in as_ints)

def make_choices(prev3_disk: np.ndarray, gold3_disk: np.ndarray,
                 num_distractors: int = 3, max_retry: int = 100) -> list[np.ndarray]:
    """Return [gold, d1, ..., dn]: the gold triplet plus `num_distractors` distractors.

    Distractors come from `random_step_triplet(prev3_disk)` and are accepted only when
    their `triplet_key` differs from the gold triplet and from all distractors accepted
    so far. After `max_retry` draws, any missing distractors are produced by moving
    one randomly chosen row of the gold triplet with `move_one_step`.

    NOTE(review): the fallback loop has no retry limit; it only terminates if enough
    distinct single-row perturbations of the gold triplet exist.
    """
    gold = np.asarray(gold3_disk, dtype=float)
    choices = [gold.copy()]
    seen = {triplet_key(gold)}  # keys already taken; starts with the gold answer

    # Primary source: one-step moves of all three agents from the previous triplet
    tries = 0
    while len(choices) < (1 + num_distractors) and tries < max_retry:
        cand = random_step_triplet(prev3_disk)
        k = triplet_key(cand)
        if k not in seen:
            seen.add(k)
            choices.append(cand)
        tries += 1

    # Fallback: perturb a single row of the gold triplet until enough options exist
    if len(choices) < (1 + num_distractors):
        while len(choices) < (1 + num_distractors):
            tmp = gold.copy()
            i = random.randrange(3)
            tmp[i] = move_one_step(tmp[i])
            k = triplet_key(tmp)
            if k not in seen:
                seen.add(k)
                choices.append(tmp)

    return choices 

def process_one_file(src_txt: Path):
    """Generate the question files for one source trajectory.

    The last three coordinate rows are the gold answer and the three rows before them
    the previous step. Output goes to `derive_out_dir(src_txt)`:
      * trail.txt          - all rows except the last three, written as read
      * <stem>_choice.txt  - the four options, shuffled and labelled A..D
      * label.json         - correct index/letter, gold and previous triplets
      * A.png .. D.png     - only when MAKE_PNG is true
    Files with fewer than six rows are skipped with a message. Returns None.
    """
    coords = read_floats_after_maze(src_txt)  # (K,2)
    if coords.shape[0] < 6:
        print(f"SKIP (少於兩步=6點): {src_txt}")
        return

    # Three rows per step; the final two steps are converted to disk coordinates
    N = 3
    total_steps = coords.shape[0] // N
    gold3      = coords[-3:].copy()
    prev3      = coords[-6:-3].copy()
    gold3_disk = world_to_disk_if_needed(gold3)
    prev3_disk = world_to_disk_if_needed(prev3)

    # The trail shown to the model excludes the gold step
    trail = coords[:-3].copy()
    out_dir = derive_out_dir(src_txt)
    out_dir.mkdir(parents=True, exist_ok=True)
    write_numpy_block(trail, out_dir / "trail.txt")

    # Files directly inside a folder named 'logic' get digit labels for the last step
    is_logic = src_txt.parent.name.lower() == 'logic'
    last_labels = None
    if is_logic:
        try:
            last_labels = load_logic_labels(src_txt, n_steps=total_steps)
        except Exception as e:
            print(f"WARN: 讀 logic 標籤失敗，改用 1/2/3 ：{src_txt} :: {e}")
            last_labels = ['1','2','3']

    choice_arrays = make_choices(prev3_disk, gold3_disk)  # [gold, d1, d2, d3]
    idxs = list(range(4))
    random.shuffle(idxs)
    letter_map = ['A','B','C','D']
    # Letter at position i shows choice_arrays[idxs[i]]
    shuffled = [(letter_map[i], choice_arrays[j]) for i, j in enumerate(idxs)]
    correct_idx = idxs.index(0)  # original index 0 is the gold answer
    correct_letter = letter_map[correct_idx]

    # Options file: letters only, nothing marks which one is gold
    choice_txt = out_dir / f"{src_txt.stem}_choice.txt"
    with open(choice_txt, 'w', encoding='utf-8') as f:
        for lbl, arr in shuffled:
            f.write(f"Choice {lbl}:\n")
            f.write(np.array2string(arr.astype(float), separator=' ') + "\n\n")

    # Answer file, with both triplets stored for later verification
    meta = {
        "source_txt": str(src_txt),
        "answer_idx": int(correct_idx),
        "answer_letter": correct_letter,
        "gold_triplet_disk": gold3_disk.tolist(),
        "prev_triplet_disk": prev3_disk.tolist(),
        "is_logic": bool(is_logic),
    }
    if is_logic:
        meta["digits_last_step"] = last_labels  # e.g. ["3","6","1"]
    with open(out_dir / "label.json", 'w', encoding='utf-8') as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)

    # Optional images: logic mode uses the last-step digits, other modes "1/2/3"
    if MAKE_PNG:
        for lbl, arr in shuffled:
            png_path = out_dir / f"{lbl}.png"
            draw_triplet_png(arr, png_path,
                             labels=(last_labels if is_logic else ['1','2','3']))

    # Console summary for quick manual verification
    print(f"OK: {src_txt.name:>24} → {out_dir} | ans={correct_letter} ({correct_idx})")
    print("prev3 =", np.array2string(prev3_disk, separator=' '),
          "gold3 =", np.array2string(gold3_disk, separator=' '))


def _is_source_file(path_str: str) -> bool:
    """True for files whose base name does not contain "FINAL"."""
    return "FINAL" not in os.path.basename(path_str)


def main():
    """Seed the RNGs (fixed seed when RANDOM_SEED is set), gather every source .txt
    file below SRC_FOLDERS and process them in sorted order. An exception raised for
    one file is printed and the run continues with the next file."""
    if RANDOM_SEED is None:
        random.seed()
    else:
        random.seed(RANDOM_SEED)
        np.random.seed(RANDOM_SEED)

    files = []
    for folder in SRC_FOLDERS:
        matches = glob.glob(str(folder/"**/*.txt"), recursive=True)
        files.extend(Path(m) for m in matches if _is_source_file(m))

    if not files:
        print("找不到任何 .txt 檔")
        return

    for p in sorted(files):
        try:
            process_one_file(p)
        except Exception as e:
            print(f"ERR: {p} :: {e}")


if __name__ == "__main__":
    main()


OK:               sim1_1.txt → /Users/Jer_ry/Desktop/scripts/training_data/sim1/intermediate_plot/case1/step_plot_sim1_1 | ans=A (0)
prev3 = [[11. 14.]
 [ 2.  3.]
 [ 1. 24.]] gold3 = [[12. 13.]
 [ 2.  3.]
 [ 1. 23.]]
OK:              sim1_10.txt → /Users/Jer_ry/Desktop/scripts/training_data/sim1/intermediate_plot/case1/step_plot_sim1_10 | ans=C (2)
prev3 = [[11. 17.]
 [ 4.  7.]
 [ 3. 14.]] gold3 = [[10. 18.]
 [ 4.  7.]
 [ 4. 14.]]
OK:             sim1_100.txt → /Users/Jer_ry/Desktop/scripts/training_data/sim1/intermediate_plot/case1/step_plot_sim1_100 | ans=B (1)
prev3 = [[15.  8.]
 [10. 20.]
 [20.  1.]] gold3 = [[15.  8.]
 [11. 20.]
 [19.  1.]]
OK:            sim1_1000.txt → /Users/Jer_ry/Desktop/scripts/training_data/sim1/intermediate_plot/case1/step_plot_sim1_1000 | ans=D (3)
prev3 = [[ 3.  5.]
 [ 1.  8.]
 [ 5. 12.]] gold3 = [[ 2.  4.]
 [ 1.  7.]
 [ 5. 12.]]
OK:            sim1_1001.txt → /Users/Jer_ry/Desktop/scripts/training_data/sim1/intermediate_plot/case1/step_plot_sim1_1001 | 

In [100]:
import json, torch, pandas as pd
PAD_ID = 576

def pad_batch_cells(cell_json_series):
    """Decode JSON-encoded id lists and right-pad them into one batch.

    Returns:
        (input_ids, attention_mask): two tensors of shape (batch, longest sequence).
        input_ids is int64 and filled with PAD_ID beyond each sequence's length;
        attention_mask is bool and True exactly on the real (unpadded) positions.
    """
    seqs = [torch.tensor(json.loads(raw), dtype=torch.long) for raw in cell_json_series]
    batch = len(seqs)
    width = max(len(seq) for seq in seqs)

    X = torch.full((batch, width), PAD_ID, dtype=torch.long)
    M = torch.zeros((batch, width), dtype=torch.bool)
    for row, seq in enumerate(seqs):
        used = len(seq)
        X[row, :used] = seq
        M[row, :used] = True
    return X, M  # input_ids, attention_mask

# Example: load the training split of the logic mode
root = "/Users/Jer_ry/Desktop/scripts/training_data/sim1/csv"
train_df = pd.read_csv(f"{root}/logic_train.csv")
answers  = pd.read_csv(f"{root}/logic_answers.csv")

X_ids, X_mask = pad_batch_cells(train_df["cell_json"])
# Turn the four-way answer into a class label (A/B/C/D -> 0/1/2/3)
# NOTE(review): X rows and y are paired by position only. This assumes both CSVs list
# the same trials in the same order; nothing here joins them on trial_id — confirm.
letter2id = {"A":0,"B":1,"C":2,"D":3}
y = torch.tensor([letter2id[c] for c in answers["correct"]], dtype=torch.long)

print(X_ids.shape, X_mask.shape, y.shape)
# The choice?_cell columns could also be stacked into a (batch, 4, 3) candidate tensor for a scoring head


torch.Size([400, 45]) torch.Size([400, 45]) torch.Size([400])


In [109]:
import pandas as pd, json

def load_split(dir_path, mode):
    """Read `<dir_path>/<mode>_train.csv` and `<dir_path>/<mode>_answers.csv` and
    return them as a (train_df, answers_df) tuple."""
    train_path = f"{dir_path}/{mode}_train.csv"
    answers_path = f"{dir_path}/{mode}_answers.csv"
    return pd.read_csv(train_path), pd.read_csv(answers_path)

# 1) 對齊與基本格式
def basic_checks(t, a):
    """Assert basic consistency between a train frame `t` and an answers frame `a`.

    Checked, in order: both frames cover the same set of trial_ids; every cell_json
    has 3 * seq_len_T ids, all in 0..575; every choice?_cell has 3 ids in 0..575;
    every 'correct' value is one of A..D. Raises AssertionError on the first failure.
    """
    def _ids_in_range(ids):
        return all(0 <= c <= 575 for c in ids)

    assert set(t['trial_id']) == set(a['trial_id'])

    # cell_json length must match seq_len_T
    def ok_len(s, T):
        cells = json.loads(s)
        return len(cells) == 3*int(T) and _ids_in_range(cells)

    seq_ok = t.apply(lambda r: ok_len(r['cell_json'], r['seq_len_T']), axis=1)
    assert (seq_ok).all()

    # each choice: exactly 3 ids, all in range
    def ok_choice(s):
        ids = json.loads(s)
        return len(ids) == 3 and _ids_in_range(ids)

    for ch in 'ABCD':
        assert (a[f'choice{ch}_cell'].apply(ok_choice)).all()

    # 'correct' must be a valid letter
    assert set(a['correct']).issubset(set("ABCD"))

# 2) split 洩漏
def no_leakage(train_ids, val_ids, test_ids):
    """Assert that the three id sets are pairwise disjoint."""
    pairs = ((train_ids, val_ids), (train_ids, test_ids), (val_ids, test_ids))
    assert all(left.isdisjoint(right) for left, right in pairs)

# 3) 標籤分布
def label_balance(a):
    """Print the relative frequency of each value in the 'correct' column."""
    shares = a['correct'].value_counts(normalize=True)
    print(shares)
