## filter

In [1]:
import re

In [2]:
def extract_metrics(data: str) -> str:
    """Extract recall/mrr/ndcg@k values from wandb summary text.

    Each line of ``data`` is searched for the pattern
    ``wandb: train/predict_<metric>@<k> <value>``. Lines that do not match,
    or whose metric is not one of ``recall``, ``mrr`` or ``ndcg``, are ignored.
    If the same metric/k pair appears more than once, the last one wins.

    Args:
        data: Multi-line text containing wandb summary lines.

    Returns:
        A single ``", "``-joined string of values rounded to 4 decimal
        places, ordered as all recall@k, then all mrr@k, then all ndcg@k,
        each group sorted by ascending ``k``. Values are rendered with
        ``str``, so trailing zeros are not padded (e.g. ``0.094``).
        Returns an empty string when no metric line is found.
    """
    # Output order of the metric groups; also the set of metrics we keep.
    metric_order = ("recall", "mrr", "ndcg")
    metrics = {name: {} for name in metric_order}

    # The value group accepts an optional sign and an optional exponent so that
    # numbers printed in scientific notation (e.g. "1e-05") are read in full
    # instead of being cut off after the mantissa.
    line_pattern = re.compile(
        r'wandb:\s+train/predict_(\w+)@(\d+)\s+'
        r'([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
    )

    for line in data.splitlines():
        match = line_pattern.search(line)
        if match is None:
            continue
        metric, k, value = match.groups()
        if metric in metrics:
            # Round to four decimal places.
            metrics[metric][int(k)] = round(float(value), 4)

    # Flatten as recall@k..., mrr@k..., ndcg@k... with k ascending in each group.
    return ", ".join(
        str(metrics[name][k])
        for name in metric_order
        for k in sorted(metrics[name])
    )

In [4]:
# Run configuration: lr=5e-6, gas=16, bs=4
# (presumably learning rate, gradient-accumulation steps and batch size — TODO confirm)

# wandb summary text for this run (appears to be pasted from the training log).
# Only the train/predict_{recall,mrr,ndcg}@k lines are used by extract_metrics;
# the loss/runtime/throughput lines are ignored.
data = """
wandb:                       train/loss 6.834
wandb:               train/predict_loss 7.63146
wandb:              train/predict_mrr@1 0.01813
wandb:             train/predict_mrr@10 0.03638
wandb:             train/predict_mrr@20 0.04018
wandb:              train/predict_mrr@5 0.03177
wandb:             train/predict_mrr@50 0.04271
wandb:             train/predict_ndcg@1 0.01813
wandb:            train/predict_ndcg@10 0.04971
wandb:            train/predict_ndcg@20 0.06351
wandb:             train/predict_ndcg@5 0.03843
wandb:            train/predict_ndcg@50 0.07969
wandb:           train/predict_recall@1 0.01813
wandb:          train/predict_recall@10 0.09396
wandb:          train/predict_recall@20 0.14835
wandb:           train/predict_recall@5 0.05879
wandb:          train/predict_recall@50 0.23077
wandb:            train/predict_runtime 123.2537
wandb: train/predict_samples_per_second 14.766
wandb:   train/predict_steps_per_second 1.85
wandb:                       train_loss 7.83732
wandb:                    train_runtime 6564.7283
wandb:         train_samples_per_second 2.799
wandb:           train_steps_per_second 0.022
"""

# Displays the values as one comma-separated string: recall@k, then mrr@k,
# then ndcg@k, each group in ascending k.
extract_metrics(data)

'0.0181, 0.0588, 0.094, 0.1484, 0.2308, 0.0181, 0.0318, 0.0364, 0.0402, 0.0427, 0.0181, 0.0384, 0.0497, 0.0635, 0.0797'

In [8]:
# Run configuration: lr=5e-6, gas=4, bs=4, epoch=1
# (presumably learning rate, gradient-accumulation steps, batch size and
# number of epochs — TODO confirm)

# wandb summary text for this run (appears to be pasted from the training log).
# Only the train/predict_{recall,mrr,ndcg}@k lines are used by extract_metrics.
data = """
wandb:               train/predict_loss 7.71655
wandb:              train/predict_mrr@1 0.01703
wandb:             train/predict_mrr@10 0.03893
wandb:             train/predict_mrr@20 0.04172
wandb:              train/predict_mrr@5 0.0341
wandb:             train/predict_mrr@50 0.04468
wandb:             train/predict_ndcg@1 0.01703
wandb:            train/predict_ndcg@10 0.05392
wandb:            train/predict_ndcg@20 0.06407
wandb:             train/predict_ndcg@5 0.04221
wandb:            train/predict_ndcg@50 0.08243
wandb:           train/predict_recall@1 0.01703
wandb:          train/predict_recall@10 0.1033
wandb:          train/predict_recall@20 0.14341
wandb:           train/predict_recall@5 0.06703
wandb:          train/predict_recall@50 0.23571
wandb:            train/predict_runtime 120.9738
wandb: train/predict_samples_per_second 15.045
"""

# Displays the values as one comma-separated string: recall@k, then mrr@k,
# then ndcg@k, each group in ascending k.
extract_metrics(data)

'0.017, 0.067, 0.1033, 0.1434, 0.2357, 0.017, 0.0341, 0.0389, 0.0417, 0.0447, 0.017, 0.0422, 0.0539, 0.0641, 0.0824'

In [9]:
# Run configuration: lr=1e-6, gas=4, bs=4, epoch=1
# (presumably learning rate, gradient-accumulation steps, batch size and
# number of epochs — TODO confirm)

# wandb summary text for this run (appears to be pasted from the training log).
# Only the train/predict_{recall,mrr,ndcg}@k lines are used by extract_metrics;
# the loss/runtime/throughput lines are ignored.
data = """
wandb:                       train/loss 8.1445
wandb:               train/predict_loss 7.72033
wandb:              train/predict_mrr@1 0.01868
wandb:             train/predict_mrr@10 0.03261
wandb:             train/predict_mrr@20 0.03581
wandb:              train/predict_mrr@5 0.02973
wandb:             train/predict_mrr@50 0.03827
wandb:             train/predict_ndcg@1 0.01868
wandb:            train/predict_ndcg@10 0.0423
wandb:            train/predict_ndcg@20 0.05387
wandb:             train/predict_ndcg@5 0.03515
wandb:            train/predict_ndcg@50 0.06922
wandb:           train/predict_recall@1 0.01868
wandb:          train/predict_recall@10 0.07418
wandb:          train/predict_recall@20 0.11978
wandb:           train/predict_recall@5 0.05165
wandb:          train/predict_recall@50 0.19725
wandb:            train/predict_runtime 121.9311
wandb: train/predict_samples_per_second 14.926
wandb:   train/predict_steps_per_second 1.87
wandb:                       train_loss 8.20834
wandb:                    train_runtime 6747.3813
wandb:         train_samples_per_second 2.723
"""

# Displays the values as one comma-separated string: recall@k, then mrr@k,
# then ndcg@k, each group in ascending k.
extract_metrics(data)

'0.0187, 0.0517, 0.0742, 0.1198, 0.1973, 0.0187, 0.0297, 0.0326, 0.0358, 0.0383, 0.0187, 0.0352, 0.0423, 0.0539, 0.0692'

## filter_user

In [3]:
# Run configuration: lr=5e-6, gas=4, bs=4
# (presumably learning rate, gradient-accumulation steps and batch size — TODO confirm)

# wandb summary text for this run (appears to be pasted from the training log).
# Only the train/predict_{recall,mrr,ndcg}@k lines are used by extract_metrics;
# the loss/runtime/throughput lines are ignored.
data = """
wandb:                       train/loss 6.6016
wandb:               train/predict_loss 7.48954
wandb:              train/predict_mrr@1 0.01046
wandb:             train/predict_mrr@10 0.03071
wandb:             train/predict_mrr@20 0.03596
wandb:              train/predict_mrr@5 0.02667
wandb:             train/predict_mrr@50 0.03873
wandb:             train/predict_ndcg@1 0.01046
wandb:            train/predict_ndcg@10 0.04495
wandb:            train/predict_ndcg@20 0.06418
wandb:             train/predict_ndcg@5 0.0348
wandb:            train/predict_ndcg@50 0.08138
wandb:           train/predict_recall@1 0.01046
wandb:          train/predict_recall@10 0.09193
wandb:          train/predict_recall@20 0.16816
wandb:           train/predict_recall@5 0.05979
wandb:          train/predict_recall@50 0.25486
wandb:            train/predict_runtime 114.9731
wandb: train/predict_samples_per_second 11.638
wandb:   train/predict_steps_per_second 1.461
wandb:                       train_loss 7.64995
wandb:                    train_runtime 5138.1786
wandb:         train_samples_per_second 2.095
wandb:           train_steps_per_second 0.065
"""

# Displays the values as one comma-separated string: recall@k, then mrr@k,
# then ndcg@k, each group in ascending k.
extract_metrics(data)

'0.0105, 0.0598, 0.0919, 0.1682, 0.2549, 0.0105, 0.0267, 0.0307, 0.036, 0.0387, 0.0105, 0.0348, 0.0449, 0.0642, 0.0814'

In [4]:
# Run configuration: lr=5e-6, gas=4, bs=4
# (presumably learning rate, gradient-accumulation steps and batch size — TODO confirm)
# NOTE(review): this label is identical to the one on the preceding filter_user
# cell, yet the metrics differ; the summary below reports train/epoch 1 —
# confirm what distinguishes the two runs and record it here.
# Only the train/predict_{recall,mrr,ndcg}@k lines are used by extract_metrics;
# the epoch/step/grad_norm/learning_rate/loss lines are ignored.
data = """
wandb:                      train/epoch 1
wandb:                train/global_step 1346
wandb:                  train/grad_norm 46.14608
wandb:              train/learning_rate 0.0
wandb:                       train/loss 7.0938
wandb:               train/predict_loss 7.62948
wandb:              train/predict_mrr@1 0.01868
wandb:             train/predict_mrr@10 0.03301
wandb:             train/predict_mrr@20 0.03593
wandb:              train/predict_mrr@5 0.02849
wandb:             train/predict_mrr@50 0.03823
wandb:             train/predict_ndcg@1 0.01868
wandb:            train/predict_ndcg@10 0.04397
wandb:            train/predict_ndcg@20 0.05481
wandb:             train/predict_ndcg@5 0.03291
wandb:            train/predict_ndcg@50 0.0697
wandb:           train/predict_recall@1 0.01868
wandb:          train/predict_recall@10 0.08072
wandb:          train/predict_recall@20 0.12407
wandb:           train/predict_recall@5 0.04634
wandb:          train/predict_recall@50 0.2003
"""

# Displays the values as one comma-separated string: recall@k, then mrr@k,
# then ndcg@k, each group in ascending k.
extract_metrics(data)

'0.0187, 0.0463, 0.0807, 0.1241, 0.2003, 0.0187, 0.0285, 0.033, 0.0359, 0.0382, 0.0187, 0.0329, 0.044, 0.0548, 0.0697'