In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import os
import json
import time
from datetime import datetime, timedelta
import re
from typing import Tuple, Dict, Any


# Standard deviation (sigma) of the Gaussian kernel handed to gaussian_filter1d
# when smoothing the reward curve; the same value is shown in the plot legend.
gaussian_filter1d_sigma = 5

# Text printed after "training time left estimation:" in the summary output.
# Nothing visible in this cell assigns it, so it is printed as an empty string.
training_time_left_estimation = ''

def add_time_delta_from_now(delta_time_str):
    """Return the local wall-clock time reached after ``delta_time_str`` elapses.

    Args:
        delta_time_str: Duration written as ``"HH:MM:SS"``; it must split on
            ``':'`` into exactly three fields that each parse as an integer.

    Returns:
        ``datetime.now()`` plus the duration, formatted as
        ``"%Y-%m-%d %H:%M:%S"``.

    Raises:
        ValueError: If the string does not contain exactly three integer
            fields.
    """
    hours, minutes, seconds = (int(field) for field in delta_time_str.split(':'))
    offset = timedelta(hours=hours, minutes=minutes, seconds=seconds)
    eta = datetime.now() + offset
    return eta.strftime("%Y-%m-%d %H:%M:%S")

# Patterns are compiled once because extract_training_progress runs on every
# line of every log file.
# ANSI escape sequences (terminal colour codes wrapped around the log text).
_ANSI_ESCAPE_RE = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
# Optional "(TaskRunner pid=1234)" prefix in front of the step line.
_TASK_RUNNER_PREFIX_RE = re.compile(r'\(TaskRunner pid=\d+\)\s*')
# The cleaned line must begin with "step:<digits>".
_STEP_RE = re.compile(r'step:(\d+)')
# "key:value" pairs, optionally preceded by the "- " separator.  The value
# class includes '+' so positive exponents ("1e+16") are captured whole, and
# nan / inf / -inf are accepted as explicit tokens.
_METRIC_RE = re.compile(r'(?:- )?([\w/]+):([-+]?(?:nan|inf)\b|[\d.\-+e]+)')


def _coerce_metric_value(text: str) -> Any:
    """Convert a captured value to int, else float, else keep the raw string."""
    for convert in (int, float):
        try:
            return convert(text)
        except ValueError:
            continue
    return text


def extract_training_progress(log_line: str) -> Tuple[int, Dict[str, Any]]:
    """Extract the step number and the metric dictionary from one log line.

    ANSI colour codes and a "(TaskRunner pid=N)" prefix are stripped first.
    The remaining text must start with "step:<int>"; every "key:value" pair
    on the line (including "step" itself) is then collected.

    Args:
        log_line: A log line containing step information, e.g.
            "step:79 - actor/lr:4e-07 - critic/rewards/mean:1.3604 - "
            "perf/total_num_tokens:191832 - perf/throughput:38.9006"

    Returns:
        A tuple ``(step, metrics)`` where ``step`` is an int and ``metrics``
        maps each key to an int when the value parses as one, otherwise a
        float (scientific notation, nan and inf included), otherwise the raw
        string.

    Raises:
        ValueError: If the cleaned line does not start with "step:<digits>".
    """
    cleaned_line = _ANSI_ESCAPE_RE.sub('', log_line)
    cleaned_line = _TASK_RUNNER_PREFIX_RE.sub('', cleaned_line)

    step_match = _STEP_RE.match(cleaned_line)
    if not step_match:
        raise ValueError("日志行不包含有效的step信息")
    step = int(step_match.group(1))

    # Later duplicates of a key overwrite earlier ones.
    metric_dict = {
        key: _coerce_metric_value(value)
        for key, value in _METRIC_RE.findall(cleaned_line)
    }
    return step, metric_dict

# step -> critic/rewards/mean, aggregated over every matching log file.
score_dict = dict()
train_time_dict = dict()  # not populated anywhere in this cell

# os.listdir() returns names in arbitrary order; sorting makes it deterministic
# which file wins when several logs report the same step (the last one in
# sorted order overwrites earlier ones).
for filename in sorted(os.listdir('.')):
    if not (filename.startswith('verl_grpo_rationanomaly_') and filename.endswith('.log')):
        continue

    progress_data = {}
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                step, metrics = extract_training_progress(line.strip())
            except ValueError:
                # Not a "step:N - key:value ..." metrics line.
                continue

            print(f"{step},{metrics}")
            progress_data[step] = metrics

            # Record the score only when the metric is present and numeric,
            # so a partial or garbled line cannot poison the plot data.
            reward = metrics.get("critic/rewards/mean")
            if isinstance(reward, (int, float)):
                score_dict[step] = reward

# Plot the steps that were actually found, in ascending order.  Assuming the
# steps are exactly 1..N fails with KeyError as soon as one step is missing or
# the logs start at a later step (e.g. a resumed run).
x = sorted(score_dict)
nums = [score_dict[idx] for idx in x]

print('total step metric(s) gathered:', len(score_dict))
print('training time left estimation:', training_time_left_estimation)

x = np.asarray(x)
nums = np.asarray(nums)
if x.size == 0:
    # Nothing was parsed; skip filtering and plotting of empty arrays.
    print('no critic/rewards/mean values found, nothing to plot')
else:
    curve = gaussian_filter1d(nums, sigma=gaussian_filter1d_sigma)  # sigma controls the amount of smoothing
    plt.figure(figsize=(20, 6))
    plt.scatter(x, nums, color='gray', label='Raw Datapoint')
    plt.plot(x, curve, 'b-', label=rf'Gaussian Filtered Curve ($\sigma={gaussian_filter1d_sigma}$)', linewidth=0.5)
    plt.title('RL critic/rewards/mean', fontsize=14)
    plt.xlabel('Step', fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.legend(fontsize=12)
    plt.xticks(x, rotation=90)
    plt.xlim(left=1)
    plt.tight_layout()
    plt.show()
