In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np
import os
from matplotlib.colors import TABLEAU_COLORS

def read_json(file_path):
    """Load an evaluation result file and pull out the headline scores.

    The JSON document must have a top-level ``'results'`` mapping. For each
    benchmark task found there, one metric is copied into the returned dict
    under a short name:

    * ``'mmlu'``                -> ``'mmlu'``  (metric ``'acc,none'``)
    * ``'gpqa_main_zeroshot'``  -> ``'gpqa'``  (metric ``'acc,none'``)
    * ``'gsm8k_cot'``           -> ``'gsm8k'`` (metric ``'exact_match,strict-match'``)

    Tasks that are absent from the file are simply left out of the result.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        dict mapping short benchmark name to its score.

    Raises:
        KeyError: if ``'results'`` is missing, or a present task lacks the
            expected metric key.
    """
    # (short name, task key inside 'results', metric key inside the task)
    score_sources = (
        ('mmlu', 'mmlu', 'acc,none'),
        ('gpqa', 'gpqa_main_zeroshot', 'acc,none'),
        ('gsm8k', 'gsm8k_cot', 'exact_match,strict-match'),
    )

    with open(file_path, 'r') as handle:
        payload = json.load(handle)

    task_results = payload['results']
    return {
        short_name: task_results[task][metric]
        for short_name, task, metric in score_sources
        if task in task_results
    }

def extract_model_name(filepath):
    """Derive a model identifier from a result file path.

    The directory part is discarded, then everything up to and including the
    first ``'-'`` of the file name (the benchmark prefix) is dropped. Finally
    the substrings ``'personal1000-'`` and ``'pretrain-'`` are removed.

    Note that the file extension is not stripped, so
    ``'o/gpqa0-llama3-lora-r32s2-personal1000-50step.json'`` becomes
    ``'llama3-lora-r32s2-50step.json'``. A file name without any ``'-'``
    yields an empty string.

    Args:
        filepath: Path (or bare file name) of a result file.

    Returns:
        The model identifier as a string.
    """
    base = os.path.basename(filepath)
    # Keep only what follows the first dash, e.g. "llama3-lora-r4...".
    _, _, model_part = base.partition('-')
    for noise in ('personal1000-', 'pretrain-'):
        model_part = model_part.replace(noise, '')
    return model_part

def plot_radar_chart(data_dict, labels, model_names, buffer_ratio=0.4):
    """Draw a radar (polar) chart comparing several models across metrics.

    Every metric axis is rescaled independently: the observed minimum and
    maximum across all models are padded by ``buffer_ratio`` times their
    spread, clamped to ``[0, 1]``, and each model's score is then mapped
    linearly onto that padded range. The original (un-normalised) score is
    written next to each plotted point. The padded ranges are also printed
    to stdout as a diagnostic.

    Args:
        data_dict: Mapping of model name -> ``{metric label: score}``.
        labels: Metric labels, one per radar axis. Must be non-empty.
        model_names: Models to draw, in legend order. A name missing from
            ``data_dict`` is drawn with every score treated as 0.
        buffer_ratio: Fraction of each metric's (max - min) spread that is
            added below the minimum and above the maximum before clamping.
            Defaults to 0.4.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # One angle per metric, then repeat the first angle to close the polygon.
    angles = np.linspace(0, 2*np.pi, len(labels), endpoint=False)
    angles = np.concatenate((angles, [angles[0]]))

    fig, ax = plt.subplots(figsize=(10, 7), subplot_kw=dict(projection='polar'))

    colors = list(TABLEAU_COLORS.values())

    # Observed min and max of each metric across all models. A model that
    # lacks a metric contributes 1 to the min and 0 to the max, i.e. it does
    # not widen the range (assuming scores lie in [0, 1] -- TODO confirm).
    metric_ranges = {label: (min(data.get(label, 1) for data in data_dict.values()),
                             max(data.get(label, 0) for data in data_dict.values()))
                     for label in labels}

    # Pad each range so the extreme models do not sit on the chart's centre
    # or outer rim, keeping the result inside [0, 1].
    for label, (min_val, max_val) in metric_ranges.items():
        range_size = max_val - min_val
        buffer = range_size * buffer_ratio
        metric_ranges[label] = (max(0, min_val - buffer), min(1, max_val + buffer))
        print(label, range_size, max_val, min_val, buffer, metric_ranges[label])

    for i, name in enumerate(model_names):
        data = data_dict.get(name, {})
        values = []
        for label in labels:
            value = data.get(label, 0)
            min_val, max_val = metric_ranges[label]
            span = max_val - min_val
            # A zero span (all models share the same score) cannot be
            # rescaled; plot the raw value instead. Checking explicitly
            # avoids a catch-all handler that would hide unrelated errors,
            # and also covers NumPy scalars, which yield inf/nan rather than
            # raising on division by zero.
            normalized_value = (value - min_val) / span if span else value
            values.append(normalized_value)
        # Close the polygon by repeating the first value.
        values = np.concatenate((values, [values[0]]))

        color = colors[i % len(colors)]  # cycle when there are more models than colours
        ax.plot(angles, values, linewidth=2, linestyle='solid', label=name, color=color)
        ax.fill(angles, values, alpha=0.25, color=color)

        # Annotate each vertex with the original, un-normalised score.
        for angle, value, label in zip(angles[:-1], values[:-1], labels):
            original_value = data.get(label, 0)
            ax.text(angle, value, f'{original_value:.4f}', ha='center', va='center')

    # Put the first axis at the top and lay the others out clockwise.
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    plt.xticks(angles[:-1], labels)
    ax.set_rlabel_position(0)
    ax.set_yticks([])  # Radial ticks are meaningless: every axis has its own scale

    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
    plt.title("Model Comparison", fontsize=16)

    plt.tight_layout()
    plt.show()

def print_value_ranges(data_dict):
    """Print the lowest and highest score of each benchmark across models.

    For ``gpqa``, ``gsm8k`` and ``mmlu`` (in that order) one line is written
    to stdout. A model that has no score for a benchmark counts as 0.

    Args:
        data_dict: Mapping of model name -> ``{metric name: score}``.

    Raises:
        ValueError: if ``data_dict`` is empty (there is nothing to take the
            min/max of).
    """
    for metric in ('gpqa', 'gsm8k', 'mmlu'):
        scores = [model_scores.get(metric, 0) for model_scores in data_dict.values()]
        lowest, highest = min(scores), max(scores)
        print(f"{metric.upper()} values range from {lowest:.4f} to {highest:.4f}")

# Result files to compare, listed explicitly.
# Alternative: json_files = glob.glob('*.json')
json_files = [
    # --- GPQA ---
    'gpqa0-llama3.json',
    # 'gpqa0-llama3-personal1000.json',
    'o/gpqa0-llama3-lora-r32s2-personal1000-50step.json',
    'o/gpqa0-llama3-lora-r32s2-personal1000-100step.json',
    'o/gpqa0-llama3-lora-r32s2-personal1000-150step.json',
    'o/gpqa0-llama3-lora-r32s2-personal1000-200step.json',
    'o/gpqa0-llama3-lora-r32s2-personal1000-250step.json',
    'o/gpqa0-llama3-lora-r32s2-personal1000-300step.json',
    'o/gpqa0-llama3-lora-r32s2-personal1000-350step.json',
    'o/gpqa0-llama3-lora-r32s2-personal1000-400step.json',
    # --- GSM8K ---
    'gsm8k8cot-llama3.json',
    # 'gsm8k8cot-llama3-personal1000.json',
    'o/gsm8k8cot-llama3-lora-r32s2-personal1000-50step.json',
    'o/gsm8k8cot-llama3-lora-r32s2-personal1000-100step.json',
    'o/gsm8k8cot-llama3-lora-r32s2-personal1000-150step.json',
    'o/gsm8k8cot-llama3-lora-r32s2-personal1000-200step.json',
    'o/gsm8k8cot-llama3-lora-r32s2-personal1000-250step.json',
    'o/gsm8k8cot-llama3-lora-r32s2-personal1000-300step.json',
    'o/gsm8k8cot-llama3-lora-r32s2-personal1000-350step.json',
    'o/gsm8k8cot-llama3-lora-r32s2-personal1000-400step.json',
    # --- MMLU ---
    'mmlu5-llama3.json',
    # 'mmlu5-llama3-personal1000.json',
    'o/mmlu5-llama3-lora-r32s2-personal1000-50step.json',
    'o/mmlu5-llama3-lora-r32s2-personal1000-100step.json',
    'o/mmlu5-llama3-lora-r32s2-personal1000-150step.json',
    'o/mmlu5-llama3-lora-r32s2-personal1000-200step.json',
    'o/mmlu5-llama3-lora-r32s2-personal1000-250step.json',
    'o/mmlu5-llama3-lora-r32s2-personal1000-300step.json',
    'o/mmlu5-llama3-lora-r32s2-personal1000-350step.json',
    'o/mmlu5-llama3-lora-r32s2-personal1000-400step.json',
]

# Group the scores by model: files for different benchmarks that resolve to
# the same model name are merged into a single {metric: score} dict.
data_dict = {}
for json_path in json_files:
    name = extract_model_name(json_path)
    scores = read_json(json_path)
    data_dict.setdefault(name, {}).update(scores)

# Model names, in first-seen order.
model_names = [*data_dict]

# Every metric label that appears for any model, de-duplicated and sorted.
all_labels = sorted({metric for scores in data_dict.values() for metric in scores})

# Report the value range of each metric.
print_value_ranges(data_dict)

# Draw the radar chart.
plot_radar_chart(data_dict, all_labels, model_names)


In [None]:
import pandas as pd
import glob

# CSV files to clean up. The pattern has no wildcard, so this yields a
# one-element list when the file exists and an empty list otherwise.
file_list = glob.glob(
    '../privacy_instruction/v/Generated_300step_LoRA_r32s2_Personal_Instruction_llama3_selected1000.csv'
)

def replace_sentence(sentence):
    """Strip chat-role markers and all newlines from a string.

    Removes, in this order, every occurrence of ``'user\\n\\n'``,
    ``'assistant\\n\\n'``, ``'\\n\\n'`` and ``'\\n'``. Matching is plain
    substring matching, so e.g. ``'superuser\\n\\n'`` loses its ``'user\\n\\n'``
    tail as well.

    Args:
        sentence: The text to clean; must be a ``str``.

    Returns:
        The cleaned text on a single line.
    """
    cleaned = sentence
    # Order matters: the role markers must go before their newlines do.
    for fragment in ('user\n\n', 'assistant\n\n', '\n\n', '\n'):
        cleaned = cleaned.replace(fragment, '')
    return cleaned

# Clean the prompt/response text columns of each CSV and write the result.
# NOTE: every input is written to the same 'test.csv', so with more than one
# matching file only the last one survives.
for file in file_list:
    df = pd.read_csv(file)

    # Map the cleaner over each column directly instead of going row by row
    # with DataFrame.apply(axis=1): the values are the same, but the row-wise
    # form returns a DataFrame (not a Series) on an empty frame, which breaks
    # the column assignment for header-only CSVs, and it builds a Series per
    # row for no benefit.
    for column in ('generated_full_sentence', 'full_sentence'):
        df[column] = df[column].map(replace_sentence)

    df.to_csv('test.csv', index=False)
    print(df)