In [6]:
import os
import json
from prettytable import PrettyTable


# Root folder holding the per-dataset evaluation outputs (the RIST_240Hz set).
evaluateResultFolder = os.path.join('../result', 'RIST_240Hz')

# Models whose evaluation files get_result_table() looks for.
MODEL_LIST = (
    'ESTMD',         # backbone model
    'DSTMD',         # backbone model
    'STMDPlus',      # with contrast pathway
    'FeedbackSTMD',  # with feedback pathway 1
    'ApgSTMD',       # with attention and prediction mechanism
    'FSTMD',         # with feedback pathway 2
    'FracSTMD',      # backbone model
    'STMDNet',
    'STMDNetF',      # proposal + feedback in lobula
    'FSTMDv2',       # proposal + feedback from lobula to lamina
    'RPFC',
    'ZBS',
)

# Sequence name -> list(range(N)). N is presumably the number of frames
# evaluated for that sequence (TODO confirm). The table code below only
# uses the keys and how many there are.
datasetInfo = {
    sequence: list(range(length))
    for sequence, length in (
        ('GX010071-1', 1300),
        ('GX010220-1', 1300),
        ('GX010228-1', 1300),
        ('GX010230-1', 2400),
        ('GX010231-1', 2400),
        ('GX010241-1', 3600),
        ('GX010250-1', 2000),
        ('GX010266-1', 2400),
        ('GX010290-1', 1300),
        ('GX010291-1', 1300),
        ('GX010303-1', 2400),
        ('GX010307-1', 1000),
        ('GX010315-1', 1000),
        ('GX010321-1', 1000),
        ('GX010322-1', 1300),
        ('GX010327-1', 900),
        ('GX010335-1', 1300),
        ('GX010336-1', 1000),
        ('GX010337-1', 700),
    )
}


def get_result_table(result_folder=None, datasets=None, models=None,
                     output_path='result_240Hz.json'):
    """Collect per-model evaluation metrics and dump them into one JSON file.

    For every dataset/model pair this reads
    ``<result_folder>/<dataset>/<model>evaluate.json`` and extracts its
    ``AUC``, ``AR``, ``F1``, ``AP`` and ``timePerImage`` entries. Any metric
    that cannot be obtained (missing file, malformed JSON, non-object JSON,
    or a missing key) is recorded as the placeholder ``'-'``. This is decided
    per metric, so one absent key no longer aborts the whole run.

    Args:
        result_folder: Root directory of the evaluation results. Defaults to
            the module-level ``evaluateResultFolder``.
        datasets: Iterable of dataset names. Defaults to the keys of the
            module-level ``datasetInfo``.
        models: Iterable of model names. Defaults to the module-level
            ``MODEL_LIST``.
        output_path: Where the aggregated JSON is written.

    Returns:
        The aggregated dict that was written, shaped
        ``{metric: {dataset: {model: value or '-'}}}`` with metric keys
        ``'AUC'``, ``'AR'``, ``'AP'``, ``'F1'`` and ``'time'`` (in that order).
    """
    if result_folder is None:
        result_folder = evaluateResultFolder
    if datasets is None:
        datasets = datasetInfo.keys()
    if models is None:
        models = MODEL_LIST

    # Output metric name -> key inside each per-model evaluate.json.
    # Order matches the key order of the original output file.
    metric_keys = {
        'AUC': 'AUC',
        'AR': 'AR',
        'AP': 'AP',
        'F1': 'F1',
        'time': 'timePerImage',
    }
    results = {out_key: {} for out_key in metric_keys}

    for dataset_name in datasets:
        for out_key in metric_keys:
            results[out_key][dataset_name] = {}

        for model_name in models:
            # File naming is '<model>evaluate.json' (no separator).
            eval_path = os.path.join(result_folder, dataset_name,
                                     model_name + 'evaluate.json')
            try:
                with open(eval_path, 'r') as f:
                    evaluation = json.load(f)
            except (FileNotFoundError, json.JSONDecodeError):
                evaluation = {}
            if not isinstance(evaluation, dict):
                evaluation = {}

            for out_key, src_key in metric_keys.items():
                results[out_key][dataset_name][model_name] = evaluation.get(src_key, '-')

    with open(output_path, 'w') as json_file:
        json.dump(results, json_file, indent=4)
    return results


def show_AUC_table(result_path='result_240Hz.json', datasets=None):
    """Print Table IV: per-dataset AUC (%) of the STMD-family models.

    Reads the aggregated results file written by ``get_result_table`` and
    prints one row per dataset plus a final ``mean`` row. Entries that are
    not numeric (e.g. the ``'-'`` placeholder for a missing evaluation) are
    shown as ``'-'`` and are excluded from the mean.

    Args:
        result_path: Path of the aggregated results JSON
            (default ``'result_240Hz.json'``).
        datasets: Dataset names to list as rows, in order. Defaults to the
            keys of the module-level ``datasetInfo``.
    """
    # (key in the results file, column header shown in the table)
    columns = (
        ('STMDPlus', 'STMDPlus'),          # with contrast pathway
        ('FeedbackSTMD', 'FeedbackSTMD'),  # with feedback pathway 1
        ('ApgSTMD', 'ApgSTMD'),            # with attention and prediction mechanism
        ('FSTMD', 'FSTMD'),                # with feedback pathway 2
        ('FracSTMD', 'FracSTMD'),          # backbone model
        ('STMDNet', 'STMDNet'),
        ('STMDNetF', 'STMDNet-F'),         # proposal + feedback in lobula
    )
    if datasets is None:
        datasets = list(datasetInfo.keys())

    def _is_number(value):
        # JSON may store an exact 0/1 as int; bool is an int subclass but never a metric.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    with open(result_path, 'r') as f:
        auc_dict = json.load(f)['AUC']

    auc_table = PrettyTable()
    auc_table.field_names = ['Model'] + [header for _, header in columns]

    totals = {model: 0.0 for model, _ in columns}
    counts = {model: 0 for model, _ in columns}

    for dataset_name in datasets:
        row = [dataset_name]
        dataset_auc = auc_dict.get(dataset_name, {})
        for model, _ in columns:
            value = dataset_auc.get(model, '-')
            if _is_number(value):
                row.append(f"{value * 100:.2f}")
                totals[model] += value
                counts[model] += 1
            else:
                row.append('-')
        auc_table.add_row(row)

    # Average only over datasets that actually have a result. Dividing by the
    # total dataset count would treat every missing run as 0 % AUC.
    mean_row = ['mean']
    for model, _ in columns:
        if counts[model]:
            mean_row.append(f"{totals[model] / counts[model] * 100:.2f}")
        else:
            mean_row.append('-')
    auc_table.add_row(mean_row)

    print("Table IV -- AUC (%):")
    print(auc_table)


def show_AP_and_F1_table(result_path='result_240Hz.json', datasets=None):
    """Print Table V: per-dataset AP (%) and F1 (%) for the compared models.

    Reads the aggregated results file written by ``get_result_table``. Each
    row lists all AP columns first, then all F1 columns. A final ``mean`` row
    averages each column over the datasets that have a numeric value.
    Non-numeric entries (e.g. the ``'-'`` placeholder) are shown as ``'-'``.

    Args:
        result_path: Path of the aggregated results JSON
            (default ``'result_240Hz.json'``).
        datasets: Dataset names to list as rows, in order. Defaults to the
            keys of the module-level ``datasetInfo``.
    """
    # (key in the results file, column header prefix shown in the table)
    columns = (
        ('RPFC', 'RPFC'),
        ('ZBS', 'ZBS'),
        ('STMDNet', 'STMDNet'),
        ('STMDNetF', 'STMDNet-F'),  # proposal + feedback in lobula
    )
    if datasets is None:
        datasets = list(datasetInfo.keys())

    def _is_number(value):
        # JSON may store an exact 0/1 as int; bool is an int subclass but never a metric.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    with open(result_path, 'r') as f:
        results = json.load(f)
    # (column suffix tag, per-dataset results); order fixes the column layout.
    metrics = (('mAP', results['AP']), ('mF1', results['F1']))

    table = PrettyTable()
    table.field_names = ['Model'] + [
        f'{header}_{tag}' for tag, _ in metrics for _, header in columns
    ]

    totals = {(tag, model): 0.0 for tag, _ in metrics for model, _ in columns}
    counts = {key: 0 for key in totals}

    for dataset_name in datasets:
        row = [dataset_name]
        for tag, metric_dict in metrics:
            dataset_values = metric_dict.get(dataset_name, {})
            for model, _ in columns:
                value = dataset_values.get(model, '-')
                if _is_number(value):
                    row.append(f"{value * 100:.2f}")
                    totals[(tag, model)] += value
                    counts[(tag, model)] += 1
                else:
                    row.append('-')
        table.add_row(row)

    # Average only over datasets that actually have a result. Dividing by the
    # total dataset count would treat every missing run as 0 %.
    mean_row = ['mean']
    for tag, _ in metrics:
        for model, _ in columns:
            n_valid = counts[(tag, model)]
            if n_valid:
                mean_row.append(f"{totals[(tag, model)] / n_valid * 100:.2f}")
            else:
                mean_row.append('-')
    table.add_row(mean_row)

    print("Table V: -- AP (%) and F1(%):")
    print(table)

    
if __name__ == '__main__':
    # First aggregate every per-model evaluation file into result_240Hz.json,
    # then print Table IV (AUC) and Table V (AP / F1) from that file.
    for report_step in (get_result_table, show_AUC_table, show_AP_and_F1_table):
        report_step()
    
    


Table IV -- AUC (%):
+------------+----------+--------------+---------+-------+----------+---------+-----------+
|   Model    | STMDPlus | FeedbackSTMD | ApgSTMD | FSTMD | FracSTMD | STMDNet | STMDNet-F |
+------------+----------+--------------+---------+-------+----------+---------+-----------+
| GX010071-1 |  54.47   |    68.16     |  46.08  | 62.29 |  42.16   |  33.32  |   51.66   |
| GX010220-1 |  25.50   |    27.92     |  17.33  | 24.62 |  31.30   |  49.58  |   64.82   |
| GX010228-1 |  19.02   |    33.07     |  11.31  | 23.89 |  33.59   |  38.31  |   54.06   |
| GX010230-1 |  27.57   |    29.03     |  23.68  | 30.74 |  32.96   |  58.27  |   72.27   |
| GX010231-1 |  34.13   |    34.40     |  23.74  | 37.61 |  37.55   |  44.81  |   55.82   |
| GX010241-1 |  29.62   |    34.60     |  21.32  | 31.62 |  37.65   |  54.10  |   72.51   |
| GX010250-1 |  32.11   |    37.95     |  21.64  | 35.15 |  35.56   |  49.17  |   73.73   |
| GX010266-1 |  21.20   |    28.55     |  17.39  | 23.41 | 