In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Model names; each name is also the directory holding that model's
# evaluation output (./<model>/evaluation_results.json).
models = [
    'svm',
    'bert-chinese-base',
    'chinese-roberta-wwm-ext',
    'chinese-roberta-wwm-ext(with-adversarial-training)',
    'chinese-roberta-wwm-ext-large(with-adversarial-training)',
    'robert-distilled-model',
]

# Load every model's evaluation JSON into `data`, keyed by model name.
# JSON text is UTF-8 by specification, so the encoding is given explicitly:
# otherwise open() uses the locale's default encoding, which differs across
# platforms (e.g. cp950/cp936 on Chinese-locale Windows).
data = {}
for model in models:
    filename = f"./{model}/evaluation_results.json"
    with open(filename, 'r', encoding='utf-8') as f:
        data[model] = json.load(f)

# Sentiment-classification metrics as a table: one row per model, in the
# same order as `models`, with columns Model / Accuracy / Macro F1.
sentiment_df = pd.DataFrame(
    [
        (
            name,
            data[name]['sentiment']['accuracy'],
            data[name]['sentiment']['macro_f1'],
        )
        for name in models
    ],
    columns=['Model', 'Accuracy', 'Macro F1'],
)

def _plot_bar_comparison(frame, value_col, title):
    """Draw and show a bar chart of ``frame[value_col]`` for each model.

    ``frame`` must contain a ``'Model'`` column (used for the x axis) and the
    numeric column ``value_col`` (used for the y axis). Every bar is labelled
    with its value to three decimals, and the figure is shown immediately.
    """
    plt.figure(figsize=(10, 6))
    ax = sns.barplot(x='Model', y=value_col, data=frame)
    # NOTE(review): matplotlib's default sans-serif font (DejaVu Sans) has no
    # CJK glyphs, so the Chinese titles may render as empty boxes unless a
    # CJK-capable font is configured (e.g. plt.rcParams['font.sans-serif']).
    plt.title(title)
    # Anchor rotated tick labels at their right end so each long model name
    # finishes under its own bar instead of being centred across neighbours.
    plt.xticks(rotation=45, ha='right')
    for patch in ax.patches:
        height = patch.get_height()
        # Put the label just past the end of the bar: above positive bars and
        # below negative ones (R2 scores can be negative), so it never
        # overlaps the bar itself.
        above = height >= 0
        ax.annotate(f'{height:.3f}',
                    (patch.get_x() + patch.get_width() / 2., height),
                    ha='center', va='bottom' if above else 'top',
                    xytext=(0, 5 if above else -5), textcoords='offset points')
    plt.tight_layout()
    plt.show()


# Sentiment classification: one comparison chart per metric.
for metric, chart_title in [
    ('Accuracy', '情緒分類準確率比較'),
    ('Macro F1', '情緒分類 Macro F1 分數比較'),
]:
    _plot_bar_comparison(sentiment_df, metric, chart_title)

# Regression targets whose R2 scores are compared across models.
regression_targets = ['rating', 'delight', 'anger', 'sorrow', 'happiness']

# One R2 comparison chart per regression target.
for target in regression_targets:
    df = pd.DataFrame({
        'Model': models,
        'R2': [data[model]['regression'][target]['r2'] for model in models],
    })
    _plot_bar_comparison(df, 'R2', f'{target.capitalize()} 回歸任務 R2 分數比較')