In [None]:
# 以讀檔獲得 Run, Point, Obj
import csv
import os
import math
from collections import defaultdict
import matplotlib.pyplot as plt

def print_dict(dict):
    """Pretty-print a mapping.

    Output is an opening brace, then one tab-indented ``key: value`` line per
    item (in iteration order), then a closing brace.
    """
    opening, closing = '{', '}'
    print(opening)
    for pair in dict.items():
        key, value = pair
        entry = f'{key}: {value}'
        print('\t' + entry)
    print(closing)


def printModel(model, map_start, size, T, R, A, printstep=False):
    """Print the solver summary and, optionally, the step-by-step grid.

    Args:
        model: solved model; this function reads ``Runtime`` and ``objVal``
            and looks up variables with ``getVarByName(...).X`` (gurobipy-style
            API — presumably a ``gurobipy.Model``, TODO confirm).
        map_start: initial grid as a list of row lists. Cells equal to ``-1``
            are shown as ``X``. Rows are copied, so ``map_start`` is not
            modified.
        size: ``(rows, cols)`` of the grid. Cell ids ``j`` in ``A`` are
            1-based and laid out row-major.
        T: list of time periods; period ``0`` is prepended when printing.
        R: labels used as the second index of the ``f(t,r,i,j)`` variables.
        A: iterable of ``(i, j)`` cell pairs (arcs).
        printstep: when True, print the grid for every period in ``[0] + T``.
            Also print every variable ``f(t,r,i,j)`` with ``i != j`` whose
            value exceeds ``1e-6``.

    If ``printstep`` is True but the objective value is not below infinity,
    no grid is printed. Instead a "Cannot get a feasible solution in time."
    notice is printed, the same message ``writeModel`` writes in that case.
    """
    EPS = 1.e-6  # values at or below this are treated as zero flow
    print("running time = ", model.Runtime)
    print("optimal value = ", model.objVal)

    # Print the movement process of a single instance to the output area for easy inspection.
    if printstep:
        if not model.objVal < float('inf'):
            # No usable solution: querying variable values would fail.
            print('-' * 40)
            print('Cannot get a feasible solution in time.')
        else:
            n_cols = size[1]
            ans_map = [row.copy() for row in map_start]
            for t in [0] + T:
                print('-' * 40)
                print(f"time period {t}")
                for row in ans_map:
                    print('|', end=' ')
                    for col in row:
                        if col == -1:
                            col = ' X'
                        print(f'{col:2}', end=' ')
                    print('|')

                # Rebuild the grid for the next period: every cell starts blank
                # and is filled by the destination cell of each active arc.
                ans_map = [[' X'] * n_cols for _ in range(size[0])]
                print()
                for r in R:
                    for (i, j) in A:
                        var_name = f"f({t},{r},{i},{j})"
                        var = model.getVarByName(var_name)
                        if var.X > EPS:
                            if i != j:
                                print(var_name, var.X)
                            row_idx, col_idx = divmod(j - 1, n_cols)
                            # Labels equal to 0 are stored as the int 0.
                            ans_map[row_idx][col_idx] = 0 if r == 0 else r

    print('-' * 40)
    

def writeModel(model, map_start, file_path, size, T, R, A):
    """Write the model summary, and the step-by-step grid if available, to a file.

    The file at ``file_path`` is overwritten. It always begins with the
    ``running time = ...`` and ``optimal value = ...`` lines.

    If ``model.objVal`` is below infinity, then for each period ``t`` in
    ``[0] + T`` it writes:

    - a separator line,
    - a ``time period t`` header,
    - the current grid, one ``| ... |`` line per row, with ``-1`` cells shown
      as ``X``,
    - a blank line,
    - every variable ``f(t,r,i,j)`` with ``i != j`` whose value exceeds
      ``1e-6``.

    The grid for the next period is rebuilt from the destination cells of
    those variables; cell ids are 1-based and row-major over ``size[1]``
    columns.

    Otherwise (no finite objective) it writes a separator and the notice
    "Cannot get a feasible solution in time." without a trailing newline.
    """
    tolerance = 1.e-6
    rule = '-' * 40
    with open(file_path, 'w') as out:
        print(f"running time = {model.Runtime}", file=out)
        print(f"optimal value = {model.objVal}", file=out)

        if not model.objVal < float('inf'):
            print(rule, file=out)
            out.write('Cannot get a feasible solution in time.')
            return

        n_rows, n_cols = size[0], size[1]
        grid = [row.copy() for row in map_start]
        for period in [0] + T:
            print(rule, file=out)
            print(f"time period {period}", file=out)
            for cells in grid:
                shown = (' X' if cell == -1 else cell for cell in cells)
                print('| ' + ''.join(f'{value:2} ' for value in shown) + '|', file=out)

            # Next period's grid: blank everywhere until a flow lands on a cell.
            grid = [[' X'] * n_cols for _ in range(n_rows)]
            print(file=out)
            for r in R:
                for i, j in A:
                    name = f"f({period},{r},{i},{j})"
                    var = model.getVarByName(name)
                    if var.X > tolerance:
                        if i != j:
                            print(f"{name} {var.X}", file=out)
                        row_idx, col_idx = divmod(j - 1, n_cols)
                        grid[row_idx][col_idx] = 0 if r == 0 else r



def checkModel(model, T_ub, map_end, size, n_tar, n_sp, id):
    """Check whether a saved result file reaches the expected final grid.

    The file read is
    ``result_graphform/model_spa/<rows>x<cols>/tar<n_tar>/graph_result_<rows>x<cols>_tar<n_tar>_sp<n_sp>_<id>.txt``,
    relative to the current working directory. It is parsed in the format
    ``writeModel`` produces:

    - a first line ``running time = <float>``,
    - ``time period <t>`` headers, each followed by one ``| ... |`` line per
      grid row.

    Args:
        model: object whose ``Params.timeLimit`` is compared with the
            recorded running time.
        T_ub: period whose printed grid is compared against ``map_end``.
        map_end: expected grid (list of rows); only cells > 0 are checked.
        size: ``(rows, cols)`` of the grid.
        n_tar, n_sp, id: identify the result file. ``id`` shadows the builtin
            but is kept for caller compatibility.

    Returns:
        True if the recorded running time reached the time limit (the run is
        not checked in that case), or if every positive cell of ``map_end``
        matches the grid printed for period ``T_ub``. False otherwise,
        including when that period's grid is missing or truncated in the
        file (e.g. no feasible solution was written).
    """
    dims = f"{size[0]}x{size[1]}"
    # os.path.join builds a path that is valid on both Windows and POSIX.
    path = os.path.join("result_graphform", "model_spa", dims, f"tar{n_tar}",
                        f"graph_result_{dims}_tar{n_tar}_sp{n_sp}_{id}.txt")
    header = f"time period {T_ub}"
    result = []
    with open(path, newline='', encoding='utf-8') as file_result:
        runtime = float(file_result.readline().split('=')[1].strip())
        if runtime >= model.Params.timeLimit:
            return True
        found = False
        for line in file_result:
            if not found:
                # Exact comparison: a substring test would also match
                # e.g. "time period 10" when T_ub == 1.
                found = line.strip() == header
                continue
            result.append(line.strip().strip('|').split())
            if len(result) == size[0]:
                break  # all grid rows for period T_ub collected

    for i in range(size[0]):
        for j in range(size[1]):
            if map_end[i][j] > 0:
                try:
                    cell = result[i][j]
                except IndexError:
                    return False  # grid for T_ub missing or incomplete
                if str(map_end[i][j]) != cell:
                    return False
    return True



def pairDataDict(size):
    """Build the (target count -> space counts) experiment grid for a board size.

    With ``cells = size[0] * size[1]`` and ``quarter = cells // 4``, the
    candidate target counts are:

    - 1..5 (only those <= ``cells - 1``),
    - ``quarter``, ``2 * quarter`` and ``3 * quarter``,
    - ``cells - 1`` and ``floor(cells / 2)``.

    They are de-duplicated, ``cells`` itself is removed, and the result is
    sorted into ``tar_list``.

    For each target count the candidate space counts are 1..5, the same three
    quartile multiples (both only where <= ``cells - n_tar``), plus
    ``cells - n_tar``. Then 0, 1 and ``cells`` are removed; a single space is
    excluded to avoid infeasible instances.

    Returns:
        ``(datadict, tar_list)``. ``datadict`` maps each target count that
        still has at least one space option to its sorted list of space counts.

    NOTE(review): for boards with fewer than 4 cells ``quarter`` is 0, so 0
    can appear in ``tar_list``.
    """
    cells = size[0] * size[1]
    quarter = cells // 4
    small_counts = range(1, 6)
    multipliers = (1, 2, 3)

    target_counts = {k for k in small_counts if k <= cells - 1}
    target_counts.update(m * quarter for m in multipliers)
    target_counts.update((cells - 1, math.floor(cells / 2)))
    target_counts.discard(cells)
    tar_list = sorted(target_counts)

    datadict = {}
    for n_tar in tar_list:
        free = cells - n_tar
        spaces = {k for k in small_counts if k <= free}
        spaces.update(m * quarter for m in multipliers if m * quarter <= free)
        spaces.add(free)
        spaces -= {0, 1, cells}  # drop the single-space case too, to avoid infeasible runs
        if spaces:
            datadict[n_tar] = sorted(spaces)

    return datadict, tar_list


def drawingData(size, datadict, tar_list):
    """Load per-instance results and average the running time per (targets, spaces) pair.

    Reads ``result/model_spa/spa_<rows>x<cols>.csv``, whose rows must provide
    the columns ``n_tar``, ``n_sp``, ``Runtime`` and ``Obj``.

    Only target counts in ``tar_list[:-1]`` get tables. With
    ``pairDataDict``'s output, the last (largest) target count has no space
    options and so no ``datadict`` entry. A CSV row whose ``(n_tar, n_sp)``
    has no prepared table raises ``KeyError``.

    Returns:
        ``(new_dict, Run, Obj)``:

        - ``Run[n_tar][n_sp]`` is the list of running times for that pair.
        - ``Obj[n_tar][n_sp]`` is the list of objective values for that pair.
        - ``new_dict[n_tar][n_sp]`` is the mean running time, or ``nan``
          when the CSV has no row for that pair (matplotlib leaves a gap for
          NaN instead of the previous ZeroDivisionError).

    ``new_dict`` is also printed.
    """
    keys = tar_list[:-1]
    Run = {i: {j: [] for j in datadict[i]} for i in keys}
    Obj = {i: {j: [] for j in datadict[i]} for i in keys}

    # newline='' is what the csv module expects for files it reads.
    with open(f'result/model_spa/spa_{size[0]}x{size[1]}.csv', 'r', newline='') as f:
        for row in csv.DictReader(f):
            n_tar = int(row['n_tar'])
            n_sp = int(row['n_sp'])
            runT = float(row['Runtime'])
            runObj = float(row['Obj'])
            Run[n_tar][n_sp].append(runT)
            Obj[n_tar][n_sp].append(runObj)

    # Mean runtime per pair; Run already holds every runtime, so no separate
    # running-total bookkeeping is needed.
    new_dict = {
        n_tar: {
            n_sp: (sum(times) / len(times) if times else float('nan'))
            for n_sp, times in by_space.items()
        }
        for n_tar, by_space in Run.items()
    }

    print('-----new_dict----')
    print_dict(new_dict)

    return new_dict, Run, Obj


def resultGraph(size, datadict, tar_list):
    """Plot the mean running time per (targets, spaces) pair and save it as a PNG.

    Averages come from ``drawingData``. The figure has one subplot per target
    count, three subplots per row, all with the same fixed y-range. It is
    saved to ``./result/model_spa_chart_summary/spa_<rows>x<cols>_turn.png``;
    the directory is created if missing.
    """
    new_dict, _runs, _objs = drawingData(size, datadict, tar_list)

    per_row = 3
    # One subplot per plotted target count; at least one row so subplots()
    # never gets 0 rows.
    n_rows = max(1, (len(new_dict) + per_row - 1) // per_row)
    # squeeze=False keeps ``axes`` 2-D even for a single row; otherwise
    # plt.subplots returns a 1-D array and axes[row, col] raises IndexError.
    fig, axes = plt.subplots(n_rows, per_row, figsize=(15, n_rows * 5), squeeze=False)

    for idx, (n_tar, mean_by_space) in enumerate(new_dict.items()):
        row, col = divmod(idx, per_row)
        ax = axes[row, col]

        x_tick = [(n_tar, n_sp) for n_sp in mean_by_space]
        ax.plot(range(len(mean_by_space)), list(mean_by_space.values()),
                marker='*', label=f'targets: {n_tar}')
        # Each subplot gets its own x tick positions and (target, space) labels.
        ax.set_xticks(range(len(x_tick)))
        ax.set_xticklabels(x_tick)
        ax.set_ylim(-10, 510)  # same y-range on every subplot for comparison
        ax.set_title(f'Subplot for {n_tar} targets')
        ax.set_xlabel('Number of (target, space)')
        ax.set_ylabel('Running time')

    os.makedirs('./result/model_spa_chart_summary/', exist_ok=True)
    fig.savefig(f'./result/model_spa_chart_summary/spa_{size[0]}x{size[1]}_turn.png')


# size = (4,4)
# datadict, tar_list = pairDataDict(size)
# resultGraph(size, datadict, tar_list)