In [15]:
# 📘 analyze_merged_data.ipynb

# ✅ 1. 导入依赖
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# ✅ 2. Load the data
# Absolute path to the filtered CityFace metadata CSV on the cluster file system.
data_path = '/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/data/1_CityFace/filtered_data.csv'
df = pd.read_csv(data_path)
# Preview the first rows (rendered as the cell output).
df.head()


Unnamed: 0,PostID,ImageName,Gender,Age
0,1,001A.jpg,Male,16
1,3,003A.jpg,Male,48
2,5,005A.jpg,Female,29
3,6,006A.jpg,Female,25
4,7,007A.jpg,Male,28


In [16]:
# ✅ 3. Check for duplicated PostID values
# The table has no "ImageId" column; the key being checked is PostID, so the
# printed label names that column.
# keep=False flags every member of a duplicated group, so the count below is the
# number of rows involved in duplicates, not the number of distinct duplicated ids.
duplicate_ids = df[df.duplicated('PostID', keep=False)]
print(f"🔍 重复的 PostID 总数: {len(duplicate_ids)}")
duplicate_ids.head()


🔍 重复的 ImageId 总数: 0


Unnamed: 0,PostID,ImageName,Gender,Age


In [17]:
# Number of rows per gender label (value_counts ignores missing labels).
gender_counts = df['Gender'].value_counts()
print("📊 男女数量统计：")
print(gender_counts)

# Share of each gender, in percent of all rows of df; a label that never
# occurs counts as 0.
male_ratio, female_ratio = (
    gender_counts.get(label, 0) / len(df) * 100 for label in ('Male', 'Female')
)
print(f"\n男性比例: {male_ratio:.2f}%")
print(f"女性比例: {female_ratio:.2f}%")


📊 男女数量统计：
Male      7754
Female    1279
Name: Gender, dtype: int64

男性比例: 85.84%
女性比例: 14.16%


In [18]:
def get_age_group_fixed5_chinese(age):
    """Map an age to one of five fixed age-group labels (Chinese text).

    The groups are: below 19, 19-29, 30-39, 40-59, and 60 or older. Any value
    that is not below one of the upper bounds (including NaN, for which every
    comparison is false) lands in the last group.
    """
    # (exclusive upper bound, label) pairs, checked in ascending order
    bounded_groups = (
        (19, '儿童组 (0-18)'),
        (30, '青年组 (19-29)'),
        (40, '青壮组 (30-39)'),
        (60, '中年组 (40-59)'),
    )
    for upper_bound, label in bounded_groups:
        if age < upper_bound:
            return label
    return '老年组 (60+)'


def fixed_age_group5_analysis(df, return_df=False):
    """Print counts and shares of the five fixed age groups found in ``df``.

    The input frame is not modified: all cleaning happens on derived copies.

    Args:
        df: DataFrame with an ``Age`` column. Values that cannot be parsed as
            numbers are dropped before grouping.
        return_df: When True, return the cleaned frame (integer ``Age`` plus
            an ``Age Group Fixed5`` label column).

    Returns:
        The cleaned DataFrame when ``return_df`` is True, otherwise None.
    """
    # Clean the age column: coerce to numeric (unparseable -> NaN), drop those
    # rows, then cast to int. assign()/dropna() return new frames, so the
    # caller's df is never written to and no chained assignment occurs.
    cleaned = df.assign(Age=pd.to_numeric(df['Age'], errors='coerce')).dropna(subset=['Age'])
    cleaned = cleaned.assign(Age=cleaned['Age'].astype(int))

    # Add the age-group column (Chinese labels)
    cleaned['Age Group Fixed5'] = cleaned['Age'].apply(get_age_group_fixed5_chinese)

    # Fixed output order; groups that do not occur are reported with count 0
    age_group_order = [
        '儿童组 (0-18)',
        '青年组 (19-29)',
        '青壮组 (30-39)',
        '中年组 (40-59)',
        '老年组 (60+)',
    ]
    group_counts = (
        cleaned['Age Group Fixed5']
        .value_counts()
        .reindex(age_group_order)
        .fillna(0)
        .astype(int)
    )

    # Print the statistics
    print("📊 不同年龄段（固定5组）数量：")
    print(group_counts)

    # Shares are relative to the rows that survived cleaning
    print("\n📈 各年龄段所占比例：")
    print((group_counts / len(cleaned)).apply(lambda x: f"{x:.2%}"))

    if return_df:
        return cleaned

# Print the statistics only (the cleaned frame is not requested)
fixed_age_group5_analysis(df)



📊 不同年龄段（固定5组）数量：
儿童组 (0-18)     1092
青年组 (19-29)    4929
青壮组 (30-39)    1887
中年组 (40-59)    1055
老年组 (60+)        70
Name: Age Group Fixed5, dtype: int64

📈 各年龄段所占比例：
儿童组 (0-18)     12.09%
青年组 (19-29)    54.57%
青壮组 (30-39)    20.89%
中年组 (40-59)    11.68%
老年组 (60+)       0.77%
Name: Age Group Fixed5, dtype: object


In [20]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_absolute_error
from collections import defaultdict
import numpy as np


In [21]:

# Prediction file of each model together with the names of its ground-truth
# and prediction columns (the column naming differs between models).
model_files = {
    "Model 1": ("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_age76_gender/test_predictions.csv", "GT_Age", "Pred_Age"),
    "Model 2": ("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_age76_nogender/test_predictions.csv", "GroundTruth", "Prediction"),
    "Model 3": ("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_group5_gender/test_predictions.csv", "GT_Age", "Pred_Age"),
    "Model 4": ("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_group5_nogender/test_predictions.csv", "GroundTruthAge", "PredictedAge"),
    "Model 5": ("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_cascade_gender_age/test_predictions.csv", "GT_Age", "Pred_Age"),
}


results = []

for model_name, (file_path, gt_col, pred_col) in model_files.items():
    # A dedicated name is used so the notebook-level `df` (the metadata table
    # loaded in the first cell) is not overwritten by the last prediction file.
    predictions = pd.read_csv(file_path)
    mae = mean_absolute_error(predictions[gt_col], predictions[pred_col])
    results.append({"Model": model_name, "MAE": mae})

# Print the overall MAE of every model
mae_df = pd.DataFrame(results)
print(mae_df)

     Model       MAE
0  Model 1  2.852017
1  Model 2  2.684717
2  Model 3  4.675885
3  Model 4  7.256637
4  Model 5  2.935456


In [24]:

# Prediction file of each model together with the names of its ground-truth
# and prediction columns (the column naming differs between models).
model_files = {
    "Model 1": ("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_age76_gender/test_predictions.csv", "GT_Age", "Pred_Age"),
    "Model 2": ("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_age76_nogender/test_predictions.csv", "GroundTruth", "Prediction"),
    "Model 3": ("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_group5_gender/test_predictions.csv", "GT_Age", "Pred_Age"),
    "Model 4": ("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_group5_nogender/test_predictions.csv", "GroundTruthAge", "PredictedAge"),
    "Model 5": ("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_cascade_gender_age/test_predictions.csv", "GT_Age", "Pred_Age"),
}


results = []

for model_name, (file_path, gt_col, pred_col) in model_files.items():
    # A dedicated name is used so the notebook-level `df` (the metadata table
    # loaded in the first cell) is not overwritten by the last prediction file.
    predictions = pd.read_csv(file_path)
    mae = mean_absolute_error(predictions[gt_col], predictions[pred_col])
    results.append({"Model": model_name, "MAE": mae})

# Print the overall MAE of every model
mae_df = pd.DataFrame(results)
print(mae_df)

     Model       MAE
0  Model 1  3.610982
1  Model 2  3.313554
2  Model 3  5.555039
3  Model 4  9.607752
4  Model 5  3.731688


In [22]:
# Reload the test predictions of the pre-balancing runs
# (history_age_predict_mtcnn): models 1-4 plus the cascade model 5.
m1_pre = pd.read_csv("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_age76_gender/test_predictions.csv")
m2_pre = pd.read_csv("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_age76_nogender/test_predictions.csv")
m3_pre = pd.read_csv("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_group5_gender/test_predictions.csv")
m4_pre = pd.read_csv("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_group5_nogender/test_predictions.csv")
m5 = pd.read_csv("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_cascade_gender_age/test_predictions.csv")



# 定义分组函数
def get_group(age):
    """Return the name of the age group that ``age`` falls into.

    Groups: Child (< 19), Young (19-29), Young Adult (30-39),
    Middle Age (40-59), Elderly (everything else).
    """
    # (exclusive upper bound, group name), checked in ascending order
    for upper_bound, name in ((19, "Child"), (30, "Young"), (40, "Young Adult"), (60, "Middle Age")):
        if age < upper_bound:
            return name
    return "Elderly"

# 计算 group-wise MAE
def groupwise_mae(df, gt_col, pred_col):
    """Compute the mean absolute error separately for each age group.

    Rows are assigned to a group by their ground-truth value via ``get_group``.

    Args:
        df: Table holding the ground-truth and prediction columns.
        gt_col: Name of the ground-truth column.
        pred_col: Name of the prediction column.

    Returns:
        Dict mapping each group name that occurs in the data to the mean of
        ``abs(ground truth - prediction)`` over that group's rows.
    """
    abs_errors_by_group = {}
    for truth, predicted in zip(df[gt_col], df[pred_col]):
        abs_errors_by_group.setdefault(get_group(truth), []).append(abs(truth - predicted))
    return {name: np.mean(errors) for name, errors in abs_errors_by_group.items()}


# Group-wise MAE of each model (column names differ between prediction files)
m1_mae = groupwise_mae(m1_pre, "GT_Age", "Pred_Age")
m2_mae = groupwise_mae(m2_pre, "GroundTruth", "Prediction")
m3_mae = groupwise_mae(m3_pre, "GT_Age", "Pred_Age")
m4_mae = groupwise_mae(m4_pre, "GroundTruthAge", "PredictedAge")
m5_mae = groupwise_mae(m5, "GT_Age", "Pred_Age")

# Arrange as a DataFrame: one row per age group (fixed order), one column per model
group_order = ["Child", "Young", "Young Adult", "Middle Age", "Elderly"]
per_model_mae = {
    "Model 1": m1_mae,
    "Model 2": m2_mae,
    "Model 3": m3_mae,
    "Model 4": m4_mae,
    "Model 5": m5_mae,
}
table = {"Group": group_order}
for model_label, group_scores in per_model_mae.items():
    table[model_label] = [group_scores[g] for g in group_order]
mae_df = pd.DataFrame(table)

# Append an "Average" row: the unweighted mean of each model's five group MAEs
average_row = {"Group": "Average"}
for model_label in per_model_mae:
    average_row[model_label] = mae_df[model_label].mean()
mae_df.loc[len(mae_df.index)] = average_row

# Print the result
print(mae_df.round(3))


         Group  Model 1  Model 2  Model 3  Model 4  Model 5
0        Child    2.238    2.043    7.138    7.101    3.262
1        Young    1.983    2.049    3.366    2.715    2.233
2  Young Adult    3.544    3.368    4.762   10.026    3.374
3   Middle Age    5.934    4.834    7.648   21.686    4.802
4      Elderly    8.839    6.858   11.857   39.000    7.620
5      Average    4.508    3.830    6.954   16.106    4.258


In [23]:
# Reload the test predictions of the balanced runs (history_age_predict_ifmerge):
# models 1-4 plus the cascade model 5.
m1_balance = pd.read_csv("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_age76_gender/test_predictions.csv")
m2_balance = pd.read_csv("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_age76_nogender/test_predictions.csv")
m3_balance = pd.read_csv("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_group5_gender/test_predictions.csv")
m4_balance = pd.read_csv("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_group5_nogender/test_predictions.csv")
m5_balance = pd.read_csv("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_cascade_gender_age/test_predictions.csv")

import pandas as pd
import numpy as np
from collections import defaultdict
from sklearn.metrics import mean_absolute_error

# 定义分组函数
def get_group(age):
    """Return the name of the age group that ``age`` falls into.

    Groups: Child (< 19), Young (19-29), Young Adult (30-39),
    Middle Age (40-59), Elderly (everything else).
    """
    # (exclusive upper bound, group name), checked in ascending order
    for upper_bound, name in ((19, "Child"), (30, "Young"), (40, "Young Adult"), (60, "Middle Age")):
        if age < upper_bound:
            return name
    return "Elderly"

# 计算 group-wise MAE
def groupwise_mae(df, gt_col, pred_col):
    """Compute the mean absolute error separately for each age group.

    Rows are assigned to a group by their ground-truth value via ``get_group``.

    Args:
        df: Table holding the ground-truth and prediction columns.
        gt_col: Name of the ground-truth column.
        pred_col: Name of the prediction column.

    Returns:
        Dict mapping each group name that occurs in the data to the mean of
        ``abs(ground truth - prediction)`` over that group's rows.
    """
    abs_errors_by_group = {}
    for truth, predicted in zip(df[gt_col], df[pred_col]):
        abs_errors_by_group.setdefault(get_group(truth), []).append(abs(truth - predicted))
    return {name: np.mean(errors) for name, errors in abs_errors_by_group.items()}


# Group-wise MAE of each model (column names differ between prediction files)
m1_mae = groupwise_mae(m1_balance, "GT_Age", "Pred_Age")
m2_mae = groupwise_mae(m2_balance, "GroundTruth", "Prediction")
m3_mae = groupwise_mae(m3_balance, "GT_Age", "Pred_Age")
m4_mae = groupwise_mae(m4_balance, "GroundTruthAge", "PredictedAge")
m5_mae = groupwise_mae(m5_balance, "GT_Age", "Pred_Age")

# Arrange as a DataFrame: one row per age group (fixed order), one column per model
group_order = ["Child", "Young", "Young Adult", "Middle Age", "Elderly"]
per_model_mae = {
    "Model 1": m1_mae,
    "Model 2": m2_mae,
    "Model 3": m3_mae,
    "Model 4": m4_mae,
    "Model 5": m5_mae,
}
table = {"Group": group_order}
for model_label, group_scores in per_model_mae.items():
    table[model_label] = [group_scores[g] for g in group_order]
mae_df = pd.DataFrame(table)

# Append an "Average" row: the unweighted mean of each model's five group MAEs
average_row = {"Group": "Average"}
for model_label in per_model_mae:
    average_row[model_label] = mae_df[model_label].mean()
mae_df.loc[len(mae_df.index)] = average_row

# Print the result
print(mae_df.round(3))


         Group  Model 1  Model 2  Model 3  Model 4  Model 5
0        Child    3.050    1.806    7.577    7.306    2.179
1        Young    3.209    2.863    4.341    2.692    2.831
2  Young Adult    3.369    3.559    4.867    9.946    4.057
3   Middle Age    5.038    4.913    6.708   22.345    5.974
4      Elderly    8.614    7.872   11.429   39.143    9.867
5      Average    4.656    4.202    6.984   16.286    4.982


In [27]:
# Δ MAE heatmap: the group-wise MAE values below are hard-coded copies of the two tables printed above (unbalanced and balanced runs)

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Row labels (age groups) and column labels (models) of the MAE tables
groups = ["Child", "Young", "Young Adult", "Middle Age", "Elderly"]
models = ["M1", "M2", "M3", "M4", "M5"]

# Group-wise MAE of the unbalanced runs, indexed by age group
unbalanced_df = pd.DataFrame(
    {
        "M1": [2.238, 1.983, 3.544, 5.934, 8.839],
        "M2": [2.043, 2.049, 3.368, 4.834, 6.858],
        "M3": [7.138, 3.366, 4.762, 7.648, 11.857],
        "M4": [7.101, 2.715, 10.026, 21.686, 39.000],
        "M5": [3.262, 2.233, 3.374, 4.802, 7.620],
    },
    index=pd.Index(groups, name="Group"),
)

# Group-wise MAE of the balanced runs, indexed by age group
balanced_df = pd.DataFrame(
    {
        "M1": [3.050, 3.209, 3.369, 5.038, 8.614],
        "M2": [1.806, 2.863, 3.559, 4.913, 7.872],
        "M3": [7.577, 4.341, 4.867, 6.708, 11.429],
        "M4": [7.306, 2.692, 9.946, 22.345, 39.143],
        "M5": [2.179, 2.831, 4.057, 5.974, 9.867],
    },
    index=pd.Index(groups, name="Group"),
)

# Change in MAE caused by balancing (positive = worse after balancing)
delta_mae = balanced_df.sub(unbalanced_df)
delta_mae.index.name = "Age Group"
delta_mae.columns.name = "Model"

# Heatmap with models as rows and age groups as columns, colours centred on 0
fig, ax = plt.subplots(figsize=(6, 4.5))
blue_white_red = sns.diverging_palette(240, 10, n=200, as_cmap=True)  # blue-white-red
sns.heatmap(
    delta_mae.T,
    annot=True,
    fmt=".2f",
    cmap=blue_white_red,
    center=0,
    linewidths=0.5,
    cbar_kws={"label": "Δ MAE"},
    ax=ax,
)
ax.set_title("Δ MAE (Balanced - Unbalanced) per Model and Age Group")
fig.tight_layout()
final_heatmap_path = "delta_mae_heatmap.png"
fig.savefig(final_heatmap_path, dpi=300)
plt.close(fig)

final_heatmap_path


'delta_mae_heatmap.png'

In [51]:
# Δ MAE heatmap (transparent-background variant): the group-wise MAE values below are hard-coded copies of the two tables printed above (unbalanced and balanced runs)

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Row labels (age groups) and column labels (models) of the MAE tables
groups = ["Child", "Young", "Young Adult", "Middle Age", "Elderly"]
models = ["M1", "M2", "M3", "M4", "M5"]

# Group-wise MAE of the unbalanced runs, indexed by age group
unbalanced_df = pd.DataFrame(
    {
        "M1": [2.238, 1.983, 3.544, 5.934, 8.839],
        "M2": [2.043, 2.049, 3.368, 4.834, 6.858],
        "M3": [7.138, 3.366, 4.762, 7.648, 11.857],
        "M4": [7.101, 2.715, 10.026, 21.686, 39.000],
        "M5": [3.262, 2.233, 3.374, 4.802, 7.620],
    },
    index=pd.Index(groups, name="Group"),
)

# Group-wise MAE of the balanced runs, indexed by age group
balanced_df = pd.DataFrame(
    {
        "M1": [3.050, 3.209, 3.369, 5.038, 8.614],
        "M2": [1.806, 2.863, 3.559, 4.913, 7.872],
        "M3": [7.577, 4.341, 4.867, 6.708, 11.429],
        "M4": [7.306, 2.692, 9.946, 22.345, 39.143],
        "M5": [2.179, 2.831, 4.057, 5.974, 9.867],
    },
    index=pd.Index(groups, name="Group"),
)

# Change in MAE caused by balancing (positive = worse after balancing)
delta_mae = balanced_df.sub(unbalanced_df)
delta_mae.index.name = "Age Group"
delta_mae.columns.name = "Model"

# Heatmap with models as rows and age groups as columns, colours centred on 0
fig, ax = plt.subplots(figsize=(6, 4.5))
blue_white_red = sns.diverging_palette(240, 10, n=200, as_cmap=True)  # blue-white-red
sns.heatmap(
    delta_mae.T,
    annot=True,
    fmt=".2f",
    cmap=blue_white_red,
    center=0,
    linewidths=0.5,
    cbar_kws={"label": "Δ MAE"},
    ax=ax,
)
ax.set_title("Δ MAE (Balanced - Unbalanced) per Model and Age Group")
fig.tight_layout()
final_heatmap_path = "delta_mae_heatmap_tr.png"
# Saved with a transparent background
fig.savefig(final_heatmap_path, dpi=300, transparent=True)
plt.close(fig)


In [53]:
import pandas as pd
import matplotlib.pyplot as plt

def load_pred_gt(filepath):
    """Read a prediction CSV and return its (ground truth, prediction) columns.

    Three column-naming layouts are recognised and tried in this order:
    GT_Age/Pred_Age, GroundTruth/Prediction, GroundTruthAge/PredictedAge.

    Raises:
        ValueError: if none of the known column pairs is fully present.
    """
    table = pd.read_csv(filepath)
    known_layouts = (
        ("GT_Age", "Pred_Age"),
        ("GroundTruth", "Prediction"),
        ("GroundTruthAge", "PredictedAge"),
    )
    for truth_col, pred_col in known_layouts:
        if truth_col in table.columns and pred_col in table.columns:
            return table[truth_col], table[pred_col]
    raise ValueError("Unsupported column format.")

def plot_predicted_vs_gt(before_csv, after_csv, model_name="Model", save_path=None):
    """Draw side-by-side predicted-vs-ground-truth scatter plots for two runs.

    Args:
        before_csv: Prediction CSV shown in the left ("before balancing") panel.
        after_csv: Prediction CSV shown in the right ("after balancing") panel.
        model_name: Prefix of the figure title.
        save_path: If truthy, the figure is also written there (300 dpi,
            transparent background) before being shown.
    """
    # Load both runs first; each entry is (panel title, (ground truth, prediction))
    panels = (
        ("Before Balancing Attempts", load_pred_gt(before_csv)),
        ("After Balancing Attempts", load_pred_gt(after_csv)),
    )

    fig, axs = plt.subplots(1, 2, figsize=(12, 4), sharex=True, sharey=True)

    for ax, (panel_title, (truth, predicted)) in zip(axs, panels):
        ax.scatter(truth, predicted, alpha=0.6, color="#2f73b8", s=18)
        # Dashed y = x reference line between ages 10 and 80
        ax.plot([10, 80], [10, 80], 'k--', linewidth=1)
        ax.set_title(panel_title)
        ax.set_xlabel("Ground Truth Age")
        ax.set_xlim(0, 85)
        ax.set_ylim(0, 85)

    # The y axis is shared, so only the left panel carries the label
    axs[0].set_ylabel("Predicted Age")

    # Main title; leave room for it above the panels
    fig.suptitle(f"{model_name}: Predicted vs. Ground Truth Age", fontsize=14)
    plt.tight_layout()
    plt.subplots_adjust(top=0.85)

    if save_path:
        plt.savefig(save_path, dpi=300, transparent=True)
    plt.show()

# Example usage (replace with your own paths): Model 1, unbalanced vs. balanced run
plot_predicted_vs_gt(
    before_csv="/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_age76_gender/test_predictions.csv",
    after_csv="/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_age76_gender/test_predictions.csv",
    model_name="Model 1",
    save_path="scatter_model1_comparison_tr.png"
)


In [30]:
import pandas as pd
import matplotlib.pyplot as plt

# ===== 通用函数，支持自动识别列名 =====
def load_pred_gt(filepath):
    """Read a prediction CSV and return its (ground truth, prediction) columns.

    Three column-naming layouts are recognised and tried in this order:
    GT_Age/Pred_Age, GroundTruth/Prediction, GroundTruthAge/PredictedAge.

    Raises:
        ValueError: if none of the known column pairs is fully present.
    """
    table = pd.read_csv(filepath)
    known_layouts = (
        ("GT_Age", "Pred_Age"),
        ("GroundTruth", "Prediction"),
        ("GroundTruthAge", "PredictedAge"),
    )
    for truth_col, pred_col in known_layouts:
        if truth_col in table.columns and pred_col in table.columns:
            return table[truth_col], table[pred_col]
    raise ValueError("❌ 无法识别文件列名格式")

# ===== 散点图绘制函数 =====
def plot_predicted_vs_gt(before_csv, after_csv, model_name="Model", save_path=None):
    """Draw side-by-side predicted-vs-ground-truth scatter plots for two runs.

    Args:
        before_csv: Prediction CSV shown in the left ("before balancing") panel.
        after_csv: Prediction CSV shown in the right ("after balancing") panel.
        model_name: Prefix of the figure title.
        save_path: If truthy, the figure is also written there (300 dpi)
            before being shown.
    """
    # Load both runs first; each entry is (panel title, (ground truth, prediction))
    panels = (
        ("Before Balancing Attempts", load_pred_gt(before_csv)),
        ("After Balancing Attempts", load_pred_gt(after_csv)),
    )

    fig, axs = plt.subplots(1, 2, figsize=(12, 4), sharex=True, sharey=True)

    for ax, (panel_title, (truth, predicted)) in zip(axs, panels):
        ax.scatter(truth, predicted, alpha=0.6, color="#2f73b8", s=18)
        # Dashed y = x reference line between ages 10 and 80
        ax.plot([10, 80], [10, 80], 'k--', linewidth=1)
        ax.set_title(panel_title)
        ax.set_xlabel("Ground Truth Age")
        ax.set_xlim(0, 85)
        ax.set_ylim(0, 85)

    # The y axis is shared, so only the left panel carries the label
    axs[0].set_ylabel("Predicted Age")

    fig.suptitle(f"{model_name}: Predicted vs. Ground Truth Age", fontsize=14)
    plt.tight_layout()
    plt.subplots_adjust(top=0.85)

    if save_path:
        plt.savefig(save_path, dpi=300)
    plt.show()

# Example usage (replace with your own paths): Model 2, unbalanced vs. balanced run
plot_predicted_vs_gt(
    before_csv="/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_age76_nogender/test_predictions.csv",
    after_csv="/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_age76_nogender/test_predictions.csv",
    model_name="Model 2",
    save_path="scatter_model2_comparison.png"
)


In [54]:
import pandas as pd
import matplotlib.pyplot as plt

# ===== 通用函数，支持自动识别列名 =====
def load_pred_gt(filepath):
    """Read a prediction CSV and return its (ground truth, prediction) columns.

    Three column-naming layouts are recognised and tried in this order:
    GT_Age/Pred_Age, GroundTruth/Prediction, GroundTruthAge/PredictedAge.

    Raises:
        ValueError: if none of the known column pairs is fully present.
    """
    table = pd.read_csv(filepath)
    known_layouts = (
        ("GT_Age", "Pred_Age"),
        ("GroundTruth", "Prediction"),
        ("GroundTruthAge", "PredictedAge"),
    )
    for truth_col, pred_col in known_layouts:
        if truth_col in table.columns and pred_col in table.columns:
            return table[truth_col], table[pred_col]
    raise ValueError("❌ 无法识别文件列名格式")

# ===== 散点图绘制函数 =====
def plot_predicted_vs_gt(before_csv, after_csv, model_name="Model", save_path=None):
    """Draw side-by-side predicted-vs-ground-truth scatter plots for two runs.

    Both panels share one axis range that is derived from the data, and the
    dashed y = x reference line spans that whole range.

    Args:
        before_csv: Prediction CSV shown in the left ("before balancing") panel.
        after_csv: Prediction CSV shown in the right ("after balancing") panel.
        model_name: Prefix of the figure title.
        save_path: If truthy, the figure is also written there (300 dpi,
            transparent background) before being shown.
    """
    gt1, pred1 = load_pred_gt(before_csv)
    gt2, pred2 = load_pred_gt(after_csv)

    # Shared axis range with a margin of 2 on each side. Predictions are
    # included as well as ground truth: the same limits are applied to the y
    # axis, so a prediction outside the ground-truth range would otherwise be
    # cut off without any warning.
    all_values = pd.concat([gt1, pred1, gt2, pred2])
    min_age = int(all_values.min()) - 2
    max_age = int(all_values.max()) + 2

    fig, axs = plt.subplots(1, 2, figsize=(12, 4), sharex=True, sharey=True)

    panels = (
        ("Before Balancing Attempts", gt1, pred1),
        ("After Balancing Attempts", gt2, pred2),
    )
    for ax, (panel_title, truth, predicted) in zip(axs, panels):
        ax.scatter(truth, predicted, alpha=0.6, color="#2f73b8", s=18)
        # Dashed y = x reference line across the full axis range
        ax.plot([min_age, max_age], [min_age, max_age], 'k--', linewidth=1)
        ax.set_title(panel_title)
        ax.set_xlabel("Ground Truth Age")
        ax.set_xlim(min_age, max_age)
        ax.set_ylim(min_age, max_age)

    # The y axis is shared, so only the left panel carries the label
    axs[0].set_ylabel("Predicted Age")

    fig.suptitle(f"{model_name}: Predicted vs. Ground Truth Age", fontsize=14)
    plt.tight_layout()
    plt.subplots_adjust(top=0.85)

    if save_path:
        plt.savefig(save_path, dpi=300, transparent=True)
    plt.show()

# ========= ✅ Usage: Model 2, unbalanced vs. balanced run =========
plot_predicted_vs_gt(
    before_csv="/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_age76_nogender/test_predictions.csv",
    after_csv="/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_age76_nogender/test_predictions.csv",
    model_name="Model 2",
    save_path="scatter_model2_comparison_tr.png"
)


In [55]:
import pandas as pd
import matplotlib.pyplot as plt

def load_pred_gt(filepath):
    """Read a prediction CSV and return its (ground truth, prediction) columns.

    Three column-naming layouts are recognised and tried in this order:
    GT_Age/Pred_Age, GroundTruth/Prediction, GroundTruthAge/PredictedAge.

    Raises:
        ValueError: if none of the known column pairs is fully present.
    """
    table = pd.read_csv(filepath)
    known_layouts = (
        ("GT_Age", "Pred_Age"),
        ("GroundTruth", "Prediction"),
        ("GroundTruthAge", "PredictedAge"),
    )
    for truth_col, pred_col in known_layouts:
        if truth_col in table.columns and pred_col in table.columns:
            return table[truth_col], table[pred_col]
    raise ValueError("⚠️ Unsupported column format.")

def plot_predicted_vs_gt(before_csv, after_csv, model_name="Model", save_path=None):
    """Draw side-by-side predicted-vs-ground-truth scatter plots for two runs.

    Args:
        before_csv: Prediction CSV shown in the left ("before balancing") panel.
        after_csv: Prediction CSV shown in the right ("after balancing") panel.
        model_name: Prefix of the figure title.
        save_path: If truthy, the figure is also written there (300 dpi,
            transparent background) before being shown.
    """
    # Load both runs first; each entry is (panel title, (ground truth, prediction))
    panels = (
        ("Before Balancing Attempts", load_pred_gt(before_csv)),
        ("After Balancing Attempts", load_pred_gt(after_csv)),
    )

    fig, axs = plt.subplots(1, 2, figsize=(12, 4), sharex=True, sharey=True)

    for ax, (panel_title, (truth, predicted)) in zip(axs, panels):
        ax.scatter(truth, predicted, alpha=0.6, color="#2f73b8", s=18)
        # Dashed y = x reference line between ages 10 and 80
        ax.plot([10, 80], [10, 80], 'k--', linewidth=1)
        ax.set_title(panel_title)
        ax.set_xlabel("Ground Truth Age")
        ax.set_xlim(0, 85)
        ax.set_ylim(0, 85)

    # The y axis is shared, so only the left panel carries the label
    axs[0].set_ylabel("Predicted Age")

    # Main title; leave room for it above the panels
    fig.suptitle(f"{model_name}: Predicted vs. Ground Truth Age", fontsize=14)
    plt.tight_layout()
    plt.subplots_adjust(top=0.85)

    if save_path:
        plt.savefig(save_path, dpi=300, transparent=True)
    plt.show()

# Example usage (replace with your own paths): Model 5, unbalanced vs. balanced run
plot_predicted_vs_gt(
    before_csv="/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_cascade_gender_age/test_predictions.csv",
    after_csv="/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_cascade_gender_age/test_predictions.csv",
    model_name="Model 5",
    save_path="scatter_model5_comparison_tr.png"
)


In [33]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import re

# Standard output settings: applied to matplotlib's global rcParams, so they
# affect every figure created after this point.
plt.rcParams.update({
    "figure.dpi": 300,
    "savefig.dpi": 300,
    "font.size": 11
})

STANDARD_SIZE = (6, 4.5)  # ✅ uniform figure size, passed as figsize to every figure below

# Model 1 test predictions from the history_age_predict_mtcnn run.
# NOTE(review): this rebinds the notebook-level name `df`, which earlier cells
# use for the metadata table.
df = pd.read_csv("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_age76_gender/test_predictions.csv")

# --- Figure 3.4: Scatter Plot ---
# Predicted age against ground-truth age, with a dashed y = x reference line
# spanning the observed ground-truth range.
scatter_fig, scatter_ax = plt.subplots(figsize=STANDARD_SIZE)
scatter_ax.scatter(df["GT_Age"], df["Pred_Age"], alpha=0.5, color='#1f77b4', edgecolor='none', s=20)
gt_span = [df["GT_Age"].min(), df["GT_Age"].max()]
scatter_ax.plot(gt_span, gt_span, linestyle='--', color='black', linewidth=1)
scatter_ax.set_xlabel("Ground Truth Age")
scatter_ax.set_ylabel("Predicted Age")
scatter_ax.set_title("Predicted vs Ground Truth Age (Model 1)")
scatter_ax.grid(False)
scatter_fig.tight_layout()
scatter_fig.savefig("fig3_4_age_scatter_m1.png")
plt.close(scatter_fig)

# --- Figure 3.5: Confusion Matrix (Gender) ---
y_true, y_pred = df["GT_Gender"], df["Pred_Gender"]
# Class 0 is displayed as "Female" and class 1 as "Male".
# NOTE(review): confirm that this matches the label encoding used in training.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Female", "Male"])
fig, ax = plt.subplots(figsize=STANDARD_SIZE)
disp.plot(cmap="Blues", ax=ax, values_format='d')
ax.set_title("Gender Classification Confusion Matrix (Model 1)")
fig.tight_layout()
fig.savefig("fig3_5_gender_confusion_m1.png")
plt.close(fig)

# --- Figure 3.6: Residual Histogram ---
# Signed error: positive values mean the model predicted too old.
errors = df['Pred_Age'] - df['GT_Age']
hist_fig, hist_ax = plt.subplots(figsize=STANDARD_SIZE)
hist_ax.hist(errors, bins=30, color='steelblue', edgecolor='none')
hist_ax.set_title("Residual Error Histogram")
hist_ax.set_xlabel("Prediction Error (Predicted - Ground Truth)")
hist_ax.set_ylabel("Frequency")
hist_ax.grid(False)
hist_fig.tight_layout()
hist_fig.savefig("fig3_6_residual_histogram_m1.png")
plt.close(hist_fig)

# --- Figure 3.7: MAE by Age Group ---
# Left-closed bins: [0,10), [10,20), ... Ages outside [0, 100) get no bin.
bins = [0, 10, 20, 30, 40, 50, 60, 70, 100]
labels = ['[0,10)', '[10,20)', '[20,30)', '[30,40)', '[40,50)', '[50,60)', '[60,70)', '[70,100)']
df['AgeBin'] = pd.cut(df['GT_Age'], bins=bins, labels=labels, right=False)
# Vectorised per-bin mean of the absolute error. observed=False is passed
# explicitly so every bin stays on the x axis (empty bins as NaN) regardless of
# the pandas default, and the deprecated groupby.apply form is avoided.
abs_error = (df['GT_Age'] - df['Pred_Age']).abs()
mae_per_bin = abs_error.groupby(df['AgeBin'], observed=False).mean()
plt.figure(figsize=STANDARD_SIZE)
mae_per_bin.plot(kind='bar', color='steelblue')
plt.ylabel('Mean Absolute Error')
plt.xlabel('Age Bin')
plt.title('MAE by Age Bin (Model 1)')
plt.tight_layout()
plt.savefig("fig3_7_mae_by_age_group_m1.png")
plt.close()

# --- Figure 3.8: Validation MAE from Training Log ---
with open("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_age76_gender/train.log", "r") as file:
    log_lines = file.readlines()

# Extract (epoch, validation MAE) from every log line that reports both
pattern = r"Epoch (\d+)/\d+, .*?Val MAE: ([\d\.]+)"
results = [(int(m.group(1)), float(m.group(2))) for line in log_lines if (m := re.search(pattern, line))]
if not results:
    # Without this check, zip(*results) below fails with an opaque
    # "not enough values to unpack" error when the log format does not match.
    raise ValueError("No 'Epoch N/M, ... Val MAE: x' entries found in train.log; cannot plot the validation curve")
results.sort(key=lambda x: x[0])
epochs, val_mae = zip(*results)

plt.figure(figsize=STANDARD_SIZE)
plt.plot(epochs, val_mae, color="#357EC7", linewidth=2)
plt.xlabel("Epoch")
plt.ylabel("Validation MAE")
plt.title("Validation MAE across Epochs (Model 1)")
plt.grid(False)
plt.tight_layout()
plt.savefig("fig3_8_val_mae_curve_m1.png")
plt.close()


In [34]:
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

# Uniform target size in pixels for every tile of the panel (e.g. 600x450)
target_width, target_height = 600, 450

# 函数：按比例缩放并填充
def resize_and_fill(img, target_size, fill_color=(255, 255, 255)):
    """Scale ``img`` to fit inside ``target_size`` and centre it on a filled canvas.

    The aspect ratio is preserved; the area of the canvas not covered by the
    scaled image keeps ``fill_color``.

    Args:
        img: Source image (read via its ``width``, ``height`` and ``resize``).
        target_size: ``(width, height)`` of the returned canvas in pixels.
        fill_color: RGB colour of the canvas background.

    Returns:
        A new RGB image of exactly ``target_size``.
    """
    canvas_w, canvas_h = target_size
    # Largest scale factor at which both dimensions still fit
    scale = min(canvas_w / img.width, canvas_h / img.height)
    scaled_w = int(img.width * scale)
    scaled_h = int(img.height * scale)
    scaled = img.resize((scaled_w, scaled_h), Image.Resampling.LANCZOS)

    # Centre the scaled image on a fresh canvas
    canvas = Image.new("RGB", (canvas_w, canvas_h), fill_color)
    offset = ((canvas_w - scaled_w) // 2, (canvas_h - scaled_h) // 2)
    canvas.paste(scaled, offset)
    return canvas



# Load the four source figures and fit each onto a target-size tile.
# Every file is opened in a context manager so its handle is closed as soon as
# the tile has been built (resize_and_fill returns an independent image).
# NOTE(review): the two fig3_* files are read from the project directory, while
# the cell that creates them saves to relative paths - confirm the working
# directory matches.
with Image.open("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_age76_gender/loss_curve.png") as src:
    img1 = resize_and_fill(src, (target_width, target_height))  # Top-left
with Image.open("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_age76_gender/val_mae_curve.png") as src:
    img2 = resize_and_fill(src, (target_width, target_height))  # Top-right
with Image.open("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/fig3_6_residual_histogram_m1.png") as src:
    img3 = resize_and_fill(src, (target_width, target_height))  # Bottom-left
with Image.open("/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/fig3_5_gender_confusion_m1.png") as src:
    img4 = resize_and_fill(src, (target_width, target_height))  # Bottom-right

# Stitch the tiles into a 2x2 panel on a white background
panel = Image.new("RGB", (target_width * 2, target_height * 2), (255, 255, 255))
panel.paste(img1, (0, 0))  # Top-left
panel.paste(img2, (target_width, 0))  # Top-right
panel.paste(img3, (0, target_height))  # Bottom-left
panel.paste(img4, (target_width, target_height))  # Bottom-right

# Save the final combined image
output_path = "m1_eval_panel.png"
panel.save(output_path)

output_path


'm1_eval_panel.png'

In [38]:
from PIL import Image, ImageDraw, ImageFont
import os

# (title, path) pairs for the 10 loss-curve plots: for each of the five models
# the unbalanced run (history_age_predict_mtcnn) is followed by the balanced
# run (history_age_predict_ifmerge), so each consecutive pair forms one row.
img_paths = [
    # Model 1
    ("Model 1 - Unbalanced", "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_age76_gender/loss_curve.png"),
    ("Model 1 - Balanced", "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_age76_gender/loss_curve.png"),
    # Model 2
    ("Model 2 - Unbalanced", "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_age76_nogender/loss_curve.png"),
    ("Model 2 - Balanced", "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_age76_nogender/loss_curve.png"),
    # Model 3
    ("Model 3 - Unbalanced", "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_group5_gender/loss_curve.png"),
    ("Model 3 - Balanced", "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_group5_gender/loss_curve.png"),
    # Model 4
    ("Model 4 - Unbalanced", "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_group5_nogender/loss_curve.png"),
    ("Model 4 - Balanced", "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_group5_nogender/loss_curve.png"),
    # Model 5
    ("Model 5 - Unbalanced", "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_cascade_gender_age/loss_curve.png"),
    ("Model 5 - Balanced", "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_cascade_gender_age/loss_curve.png"),
]
# Tile size in pixels, height of the title strip above each tile, and the title font
img_size = (300, 250)
title_height = 30
font = ImageFont.load_default()

# Build one titled tile per image: a white canvas with the title text on top
# and the resized plot below it.
resized_imgs_with_titles = []
for title, path in img_paths:
    # The context manager closes the source file once the resized copy exists.
    with Image.open(path) as src:
        img = src.resize(img_size)
    canvas = Image.new("RGB", (img_size[0], img_size[1] + title_height), (255, 255, 255))
    draw = ImageDraw.Draw(canvas)
    draw.text((10, 5), title, fill=(0, 0, 0), font=font)
    canvas.paste(img, (0, title_height))
    resized_imgs_with_titles.append(canvas)

# Arrange the tiles two per row. The row count follows the number of tiles
# instead of being fixed at 5 rows of 2.
rows = []
for i in range(0, len(resized_imgs_with_titles), 2):
    row = Image.new("RGB", (img_size[0] * 2, img_size[1] + title_height), (255, 255, 255))
    row.paste(resized_imgs_with_titles[i], (0, 0))
    # With an odd number of tiles the last row keeps an empty (white) right slot
    if i + 1 < len(resized_imgs_with_titles):
        row.paste(resized_imgs_with_titles[i + 1], (img_size[0], 0))
    rows.append(row)

# Final panel: all rows stacked vertically
final_img = Image.new("RGB", (img_size[0] * 2, (img_size[1] + title_height) * len(rows)), (255, 255, 255))
for j, row_img in enumerate(rows):
    final_img.paste(row_img, (0, j * (img_size[1] + title_height)))

# Save final panel
final_output_path = "loss_comparison_all.png"
final_img.save(final_output_path)



In [52]:
from PIL import Image, ImageDraw, ImageFont
import os

# (title, path) pairs for every loss curve, in the order they fill the grid
# (left to right, then top to bottom).
img_paths = [
    ("Model 1 - Unbalanced", "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_age76_gender/loss_curve.png"),
    ("Model 1 - Balanced",   "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_age76_gender/loss_curve.png"),
    ("Model 2 - Unbalanced", "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_age76_nogender/loss_curve.png"),
    ("Model 2 - Balanced",   "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_age76_nogender/loss_curve.png"),
    ("Model 3 - Unbalanced", "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_group5_gender/loss_curve.png"),
    ("Model 3 - Balanced",   "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_group5_gender/loss_curve.png"),
    ("Model 4 - Unbalanced", "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_group5_nogender/loss_curve.png"),
    ("Model 4 - Balanced",   "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_group5_nogender/loss_curve.png"),
    ("Model 5 - Unbalanced", "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_mtcnn/logs_cascade_gender_age/loss_curve.png"),
    ("Model 5 - Balanced",   "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project/history_age_predict_ifmerge/logs_cascade_gender_age/loss_curve.png"),
]

# Tile geometry and title font shared by every panel.
img_size = (300, 250)
title_height = 30
font = ImageFont.load_default()

n_cols = 2
tile_w, tile_h = img_size[0], img_size[1] + title_height
transparent_white = (255, 255, 255, 0)

# Render each loss curve onto a fully transparent tile with its title above it.
resized_imgs_with_titles = []
for title, path in img_paths:
    # The context manager releases the file as soon as the RGBA copy exists.
    with Image.open(path) as src:
        img = src.convert("RGBA").resize(img_size)
    canvas = Image.new("RGBA", (tile_w, tile_h), transparent_white)
    draw = ImageDraw.Draw(canvas)
    draw.text((10, 5), title, fill=(0, 0, 0, 255), font=font)
    # alpha_composite keeps the source alpha intact. paste(..., mask=img)
    # blends the alpha band with itself, which fades semi-transparent pixels
    # a little more at every stage of the assembly.
    canvas.alpha_composite(img, dest=(0, title_height))
    resized_imgs_with_titles.append(canvas)

if not resized_imgs_with_titles:
    raise ValueError("img_paths is empty; there is nothing to assemble into a panel.")

# Lay the tiles out two per row; the row count follows the number of tiles
# instead of being hard-coded (an odd count leaves the last cell transparent).
rows = []
for i in range(0, len(resized_imgs_with_titles), n_cols):
    row = Image.new("RGBA", (tile_w * n_cols, tile_h), transparent_white)
    for col, tile in enumerate(resized_imgs_with_titles[i:i + n_cols]):
        row.alpha_composite(tile, dest=(col * tile_w, 0))
    rows.append(row)

# Stack the rows top to bottom into the final image.
final_img = Image.new("RGBA", (tile_w * n_cols, tile_h * len(rows)), transparent_white)
for j, row_img in enumerate(rows):
    final_img.alpha_composite(row_img, dest=(0, j * tile_h))

# Save as a PNG with a transparent background.
final_output_path = "loss_comparison_all_tr.png"
final_img.save(final_output_path)


In [56]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Directory that receives every residual figure; created if it does not exist.
output_dir = "c5_final_residual_figures"
os.makedirs(output_dir, exist_ok=True)

# Building blocks of the prediction-file paths: one run directory per training
# variant and one log directory per model.
_project_root = "/mimer/NOBACKUP/groups/naiss2025-22-39/lili_project"
_run_dirs = {
    "Unbalanced": "history_age_predict_mtcnn",
    "Balanced": "history_age_predict_ifmerge",
}
_log_dirs = {
    "Model 1": "logs_age76_gender",
    "Model 2": "logs_age76_nogender",
    "Model 5": "logs_cascade_gender_age",
}

# Display name -> test_predictions.csv path, ordered model by model with the
# unbalanced variant first.
model_files = {
    f"{model} ({variant})": f"{_project_root}/{run_dir}/{log_dir}/test_predictions.csv"
    for model, log_dir in _log_dirs.items()
    for variant, run_dir in _run_dirs.items()
}

# Display name -> colour: one hue per model, darker shade for the unbalanced
# variant and lighter shade for the balanced one.
custom_palette = dict(zip(
    model_files,
    ["#1f77b4", "#aec7e8", "#2ca02c", "#98df8a", "#d62728", "#ff9896"],
))

# Load each model's test predictions, derive the residual (prediction minus
# ground truth), and tag every row with the model's display name.
# NOTE(review): the loop rebinds the notebook-level name `df`, which the first
# cell uses for the dataset table; re-run that cell before reusing `df` for it.
all_data = []
for model_name, file in model_files.items():
    try:
        df = pd.read_csv(file)

        # Files that use Prediction/GroundTruth headers are mirrored into the
        # Pred_Age/GT_Age names the rest of the notebook works with.
        if "Prediction" in df.columns and "GroundTruth" in df.columns:
            df["Pred_Age"] = df["Prediction"]
            df["GT_Age"] = df["GroundTruth"]

        if "Pred_Age" not in df.columns or "GT_Age" not in df.columns:
            print(f"[WARNING] Skipping '{model_name}' — Missing 'Pred_Age' or 'GT_Age'.")
            continue

        df["Residual"] = df["Pred_Age"] - df["GT_Age"]
        df["Model"] = model_name
        all_data.append(df)

    except Exception as e:
        # Best effort: one unreadable file must not block the other models.
        print(f"[ERROR] Skipping '{model_name}' due to error: {e}")

# Fail with an actionable message when every model was skipped, instead of
# letting pd.concat complain about an empty list.
if not all_data:
    raise ValueError(
        "No model predictions could be loaded; check the paths in model_files "
        "and the warnings printed above."
    )

# Combine all models into one long-format frame.
combined_df = pd.concat(all_data, ignore_index=True)

# Age bucketing
def age_group(age):
    """Map an age to one of five named age groups.

    Args:
        age: Age in years, or a missing value (None / NaN).

    Returns:
        "Child" (< 19), "Young" (19-29), "Young Adult" (30-39),
        "Middle Age" (40-59) or "Elderly" (60 and above); None when the age
        is missing.
    """
    # NaN fails every `<` comparison, so without this guard a missing age
    # would fall through to the final branch and be labelled "Elderly".
    if pd.isna(age):
        return None
    if age < 19:
        return "Child"
    elif age < 30:
        return "Young"
    elif age < 40:
        return "Young Adult"
    elif age < 60:
        return "Middle Age"
    else:
        return "Elderly"

# Bucket every row by its ground-truth age and store the result as an ordered
# categorical so plots list the groups from youngest to oldest.
age_order = ["Child", "Young", "Young Adult", "Middle Age", "Elderly"]
combined_df["Age Group"] = pd.Categorical(
    combined_df["GT_Age"].apply(age_group),
    categories=age_order,
    ordered=True,
)

# === Figure 1: residual vs. ground-truth age (scatter) ===
fig, ax = plt.subplots(figsize=(10, 5))
sns.scatterplot(
    data=combined_df,
    x="GT_Age",
    y="Residual",
    hue="Model",
    palette=custom_palette,
    alpha=0.5,
    edgecolor=None,
    ax=ax,
)
# Dashed zero line marks a perfect prediction.
ax.axhline(0, color="black", linestyle="--", linewidth=1)
ax.set_title("Residual vs. Ground Truth Age")
ax.set_xlabel("Ground Truth Age")
ax.set_ylabel("Prediction Error")
# Legend sits outside the axes so it does not cover any points.
ax.legend(title="Model", bbox_to_anchor=(1.02, 1), loc="upper left")
fig.tight_layout()
fig.savefig(
    os.path.join(output_dir, "residual_vs_gt_consistent_tr.png"),
    bbox_inches="tight",
    dpi=300,
    transparent=True,
)
plt.close(fig)

# === Figure 2: distribution of prediction residuals (unfilled step outlines + KDE) ===
fig, ax = plt.subplots(figsize=(10, 5))
# sort=False keeps the models in order of first appearance in combined_df.
for model, subset in combined_df.groupby("Model", sort=False):
    sns.histplot(
        subset["Residual"],
        label=model,
        kde=True,
        stat="density",
        bins=40,
        element="step",
        fill=False,
        linewidth=1.5,
        color=custom_palette[model],
        ax=ax,
    )

# Dashed zero line marks a perfect prediction.
ax.axvline(0, color="black", linestyle="--", linewidth=1)
ax.set_title("Distribution of Prediction Residuals")
ax.set_xlabel("Prediction Error (Predicted Age - Ground Truth Age)")
ax.set_ylabel("Density")
ax.legend(title="Model", bbox_to_anchor=(1.02, 1), loc="upper left")
fig.tight_layout()
fig.savefig(
    os.path.join(output_dir, "residual_histogram_consistent_tr.png"),
    bbox_inches="tight",
    dpi=300,
    transparent=True,
)
plt.close(fig)

# === Figure 3: residuals by age group (boxplot; groups follow the categorical order) ===
fig, ax = plt.subplots(figsize=(12, 5))
sns.boxplot(
    data=combined_df,
    x="Age Group",
    y="Residual",
    hue="Model",
    palette=custom_palette,
    ax=ax,
)
# Dashed zero line marks a perfect prediction.
ax.axhline(0, color="black", linestyle="--", linewidth=1)
ax.set_title("Prediction Residuals by Age Group")
ax.set_xlabel("Age Group")
ax.set_ylabel("Prediction Error")
ax.legend(title="Model", bbox_to_anchor=(1.02, 1), loc="upper left")
fig.tight_layout()
fig.savefig(
    os.path.join(output_dir, "residual_boxplot_consistent_tr.png"),
    bbox_inches="tight",
    dpi=300,
    transparent=True,
)
plt.close(fig)


In [57]:
# === New figure 2: residual distributions split per model, panels side by side ===
fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True)

# Panel title -> the two training variants overlaid in that panel.
model_groups = {
    "Model 1": ["Model 1 (Unbalanced)", "Model 1 (Balanced)"],
    "Model 2": ["Model 2 (Unbalanced)", "Model 2 (Balanced)"],
    "Model 5": ["Model 5 (Unbalanced)", "Model 5 (Balanced)"],
}

for idx, (ax, (model_base, variants)) in enumerate(zip(axes, model_groups.items())):
    for variant in variants:
        subset = combined_df[combined_df["Model"] == variant]
        sns.histplot(
            subset["Residual"],
            label=variant,
            kde=True,
            stat="density",
            bins=40,
            element="step",
            fill=False,
            linewidth=1.5,
            color=custom_palette[variant],
            ax=ax,
        )
    # Dashed zero line marks a perfect prediction.
    ax.axvline(0, color="black", linestyle="--", linewidth=1)
    ax.set_title(f"{model_base} Residuals")
    ax.set_xlabel("Prediction Error")
    # The panels share one y-axis, so only the left-most one is labelled.
    ax.set_ylabel("Density" if idx == 0 else "")
    ax.legend(title="Model", loc="upper right")

fig.tight_layout()
fig.savefig(
    os.path.join(output_dir, "residual_distribution_split_compare_tr.png"),
    bbox_inches="tight",
    dpi=300,
    transparent=True,
)
plt.close(fig)


In [58]:
# === New figure 2 (stacked vertically): residual distributions split per model ===
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(8, 12), sharex=True)

# Panel title -> the two training variants overlaid in that panel.
model_groups = {
    "Model 1": ["Model 1 (Unbalanced)", "Model 1 (Balanced)"],
    "Model 2": ["Model 2 (Unbalanced)", "Model 2 (Balanced)"],
    "Model 5": ["Model 5 (Unbalanced)", "Model 5 (Balanced)"],
}

# Histogram styling shared by every curve: unfilled step outline plus a KDE.
_hist_style = dict(
    kde=True,
    stat="density",
    bins=40,
    element="step",
    fill=False,
    linewidth=1.5,
)

for ax, (model_base, variants) in zip(axes, model_groups.items()):
    for variant in variants:
        subset = combined_df.loc[combined_df["Model"] == variant]
        sns.histplot(
            subset["Residual"],
            label=variant,
            color=custom_palette[variant],
            ax=ax,
            **_hist_style,
        )
    # Dashed zero line marks a perfect prediction.
    ax.axvline(0, color="black", linestyle="--", linewidth=1)
    ax.set_title(f"{model_base} Residuals")
    ax.set_ylabel("Density")
    ax.legend(title="Model", loc="upper right")

# The panels share one x-axis, so only the bottom one carries the label.
axes[-1].set_xlabel("Prediction Error (Predicted Age - Ground Truth Age)")

fig.tight_layout()
fig.savefig(
    os.path.join(output_dir, "residual_distribution_split_vertical_tr.png"),
    bbox_inches="tight",
    dpi=300,
    transparent=True,
)
plt.close(fig)
