In [1]:
import re

## 对于predict放到train中的情况（prediction run inside training; metrics are logged as `train/predict_<metric>@<k>`）

In [4]:
def extract_metrics(data: str, prefixes=("train/predict_", "test/"),
                    metric_order=("recall", "mrr", "ndcg")) -> str:
    """Extract ranking metrics from pasted wandb console output.

    Scans each line of ``data`` for entries such as
    ``wandb:   train/predict_recall@10 0.13079`` or ``wandb:   test/mrr@5 0.03346``
    and returns the values as one comma-separated string: all cutoffs of the
    first metric in ``metric_order`` (ascending k), then the next metric, etc.
    Each value is rounded to 4 decimal places.

    Args:
        data: Raw multi-line text copied from the wandb output.
        prefixes: Key prefix(es) that precede the metric name. A single string
            is also accepted. The default accepts both the in-training prediction
            format (``train/predict_``) and the standalone format (``test/``), so
            the same function parses both kinds of log regardless of which
            definition of ``extract_metrics`` was executed last.
        metric_order: Metric names to keep, in output order. Any other key
            (loss, runtime, ...) is ignored.

    Returns:
        Comma-separated values, e.g. ``"0.0149, 0.0725, ..."``; an empty
        string when no matching metric line is present.
    """
    if isinstance(prefixes, str):
        # A bare string would otherwise be iterated character by character.
        prefixes = (prefixes,)
    prefix_alt = "|".join(re.escape(p) for p in prefixes)
    # Small floats can be printed in scientific notation (e.g. "1e-05"), so the
    # value pattern allows an optional exponent instead of stopping at the "e".
    pattern = re.compile(
        rf"wandb:\s+(?:{prefix_alt})(\w+)@(\d+)\s+([\d.]+(?:[eE][-+]?\d+)?)"
    )

    metrics = {name: {} for name in metric_order}
    for line in data.splitlines():
        match = pattern.search(line)
        if match is None:
            continue
        name, k, value = match.groups()
        if name in metrics:
            # Keep 4 decimal places.
            metrics[name][int(k)] = round(float(value), 4)

    # Order: recall@1..50, then mrr@1..50, then ndcg@1..50 (by default).
    return ", ".join(
        str(metrics[name][k]) for name in metric_order for k in sorted(metrics[name])
    )

## predict独立的情况（standalone prediction run; metrics are logged as `test/<metric>@<k>`）

In [2]:
import re

def extract_metrics(data: str, prefixes=("train/predict_", "test/"),
                    metric_order=("recall", "mrr", "ndcg")) -> str:
    """Extract ranking metrics from pasted wandb console output.

    Scans each line of ``data`` for entries such as
    ``wandb:   train/predict_recall@10 0.13079`` or ``wandb:   test/mrr@5 0.03346``
    and returns the values as one comma-separated string: all cutoffs of the
    first metric in ``metric_order`` (ascending k), then the next metric, etc.
    Each value is rounded to 4 decimal places.

    Args:
        data: Raw multi-line text copied from the wandb output.
        prefixes: Key prefix(es) that precede the metric name. A single string
            is also accepted. The default accepts both the in-training prediction
            format (``train/predict_``) and the standalone format (``test/``), so
            the same function parses both kinds of log regardless of which
            definition of ``extract_metrics`` was executed last.
        metric_order: Metric names to keep, in output order. Any other key
            (loss, runtime, ...) is ignored.

    Returns:
        Comma-separated values, e.g. ``"0.0149, 0.0725, ..."``; an empty
        string when no matching metric line is present.
    """
    if isinstance(prefixes, str):
        # A bare string would otherwise be iterated character by character.
        prefixes = (prefixes,)
    prefix_alt = "|".join(re.escape(p) for p in prefixes)
    # Small floats can be printed in scientific notation (e.g. "1e-05"), so the
    # value pattern allows an optional exponent instead of stopping at the "e".
    pattern = re.compile(
        rf"wandb:\s+(?:{prefix_alt})(\w+)@(\d+)\s+([\d.]+(?:[eE][-+]?\d+)?)"
    )

    metrics = {name: {} for name in metric_order}
    for line in data.splitlines():
        match = pattern.search(line)
        if match is None:
            continue
        name, k, value = match.groups()
        if name in metrics:
            # Keep 4 decimal places.
            metrics[name][int(k)] = round(float(value), 4)

    # Order: recall@1..50, then mrr@1..50, then ndcg@1..50 (by default).
    return ", ".join(
        str(metrics[name][k]) for name in metric_order for k in sorted(metrics[name])
    )

## filter_user

In [5]:
# lr=5e-6, bs=2, num_negative_samples=32, gas=4
# Pasted wandb console output from the in-training prediction setup.
# NOTE: metric keys here use the `train/predict_` prefix, so the active
# extract_metrics must accept that prefix; a version that matches only `test/`
# keys returns '' for this input.

data = """
wandb:              train/learning_rate 1e-05
wandb:                       train/loss 2.9293
wandb:               train/predict_loss 6.83972
wandb:              train/predict_mrr@1 0.01495
wandb:             train/predict_mrr@10 0.0414
wandb:             train/predict_mrr@20 0.04666
wandb:              train/predict_mrr@5 0.03346
wandb:             train/predict_mrr@50 0.05017
wandb:             train/predict_ndcg@1 0.01495
wandb:            train/predict_ndcg@10 0.06204
wandb:            train/predict_ndcg@20 0.08114
wandb:             train/predict_ndcg@5 0.04302
wandb:            train/predict_ndcg@50 0.1034
wandb:           train/predict_recall@1 0.01495
wandb:          train/predict_recall@10 0.13079
wandb:          train/predict_recall@20 0.20628
wandb:           train/predict_recall@5 0.0725
wandb:          train/predict_recall@50 0.31913
wandb:            train/predict_runtime 186.5982
"""

extract_metrics(data)

'0.0149, 0.0725, 0.1308, 0.2063, 0.3191, 0.0149, 0.0335, 0.0414, 0.0467, 0.0502, 0.0149, 0.043, 0.062, 0.0811, 0.1034'

In [7]:
# lr=1e-6, bs=2, num_neg_samples=32, gas=1
# Pasted wandb console output; extract_metrics keeps only the recall/mrr/ndcg @k entries.

data = """
wandb:                test/loss 7.66796
wandb:               test/mrr@1 0.00448
wandb:              test/mrr@10 0.01512
wandb:              test/mrr@20 0.01747
wandb:               test/mrr@5 0.01223
wandb:              test/mrr@50 0.01949
wandb:              test/ndcg@1 0.00448
wandb:             test/ndcg@10 0.02331
wandb:             test/ndcg@20 0.03196
wandb:              test/ndcg@5 0.01603
wandb:             test/ndcg@50 0.04488
wandb:            test/recall@1 0.00448
wandb:           test/recall@10 0.05082
wandb:           test/recall@20 0.0852
wandb:            test/recall@5 0.02765
wandb:           test/recall@50 0.15097
"""

extract_metrics(data)


'0.0045, 0.0277, 0.0508, 0.0852, 0.151, 0.0045, 0.0122, 0.0151, 0.0175, 0.0195, 0.0045, 0.016, 0.0233, 0.032, 0.0449'

In [5]:
# lr=1e-5, bs=2, num_neg_samples=64, gas=8
# Pasted wandb console output; extract_metrics keeps only the recall/mrr/ndcg @k
# entries, so the eval/, train/ and throughput keys below are ignored.

data = """
wandb:    eval/steps_per_second 1.812
wandb:                test/loss 6.92125
wandb:               test/mrr@1 0.01495
wandb:              test/mrr@10 0.04068
wandb:              test/mrr@20 0.0454
wandb:               test/mrr@5 0.03413
wandb:              test/mrr@50 0.04865
wandb:              test/ndcg@1 0.01495
wandb:             test/ndcg@10 0.06098
wandb:             test/ndcg@20 0.07806
wandb:              test/ndcg@5 0.04493
wandb:             test/ndcg@50 0.09889
wandb:            test/recall@1 0.01495
wandb:           test/recall@10 0.12855
wandb:           test/recall@20 0.19581
wandb:            test/recall@5 0.07848
wandb:           test/recall@50 0.30194
wandb:             test/runtime 181.0884
wandb:  test/samples_per_second 7.389
wandb:    test/steps_per_second 1.85
wandb:               total_flos 0
wandb:              train/epoch 0.99889
wandb:        train/global_step 336
wandb:          train/grad_norm 8.87292
wandb:      train/learning_rate 1e-05
wandb:               train/loss 3.9673
wandb:               train_loss 3.22661
wandb:            train_runtime 17893.3513
wandb: train_samples_per_second 0.602
wandb:   train_steps_per_second 0.019
"""

extract_metrics(data)

'0.0149, 0.0785, 0.1285, 0.1958, 0.3019, 0.0149, 0.0341, 0.0407, 0.0454, 0.0486, 0.0149, 0.0449, 0.061, 0.0781, 0.0989'

In [7]:
# lr=5e-6, bs=2, num_neg_samples=64, gas=4
# Pasted wandb console output; extract_metrics keeps only the recall/mrr/ndcg @k entries.

data = """
wandb:                test/loss 6.8446
wandb:               test/mrr@1 0.01121
wandb:              test/mrr@10 0.04235
wandb:              test/mrr@20 0.04766
wandb:               test/mrr@5 0.03519
wandb:              test/mrr@50 0.05127
wandb:              test/ndcg@1 0.01121
wandb:             test/ndcg@10 0.06418
wandb:             test/ndcg@20 0.08384
wandb:              test/ndcg@5 0.04694
wandb:             test/ndcg@50 0.10669
wandb:            test/recall@1 0.01121
wandb:           test/recall@10 0.13602
wandb:           test/recall@20 0.2145
wandb:            test/recall@5 0.08296
wandb:           test/recall@50 0.33034
wandb:             test/runtime 202.1424
"""

extract_metrics(data)

'0.0112, 0.083, 0.136, 0.2145, 0.3303, 0.0112, 0.0352, 0.0423, 0.0477, 0.0513, 0.0112, 0.0469, 0.0642, 0.0838, 0.1067'

In [4]:
# lr=5e-6, bs=2, num_neg_samples=64, gas=8
# Pasted wandb console output; extract_metrics keeps only the recall/mrr/ndcg @k entries.

data = """
wandb:                test/loss 6.82661
wandb:               test/mrr@1 0.01495
wandb:              test/mrr@10 0.04259
wandb:              test/mrr@20 0.04759
wandb:               test/mrr@5 0.03457
wandb:              test/mrr@50 0.05128
wandb:              test/ndcg@1 0.01495
wandb:             test/ndcg@10 0.06348
wandb:             test/ndcg@20 0.08202
wandb:              test/ndcg@5 0.0441
wandb:             test/ndcg@50 0.10458
wandb:            test/recall@1 0.01495
wandb:           test/recall@10 0.13303
wandb:           test/recall@20 0.20703
wandb:            test/recall@5 0.07324
wandb:           test/recall@50 0.31988
wandb:             test/runtime 185.6406
"""

extract_metrics(data)

'0.0149, 0.0732, 0.133, 0.207, 0.3199, 0.0149, 0.0346, 0.0426, 0.0476, 0.0513, 0.0149, 0.0441, 0.0635, 0.082, 0.1046'

In [6]:
# phi4, lr=5e-6, bs=2, num_neg_samples=32, gas=4
# Pasted wandb console output; extract_metrics keeps only the recall/mrr/ndcg @k entries.

data = """
wandb:                test/loss 6.99885
wandb:               test/mrr@1 0.01462
wandb:              test/mrr@10 0.04098
wandb:              test/mrr@20 0.04536
wandb:               test/mrr@5 0.03463
wandb:              test/mrr@50 0.04861
wandb:              test/ndcg@1 0.01462
wandb:             test/ndcg@10 0.06003
wandb:             test/ndcg@20 0.07661
wandb:              test/ndcg@5 0.04448
wandb:             test/ndcg@50 0.09635
wandb:            test/recall@1 0.01462
wandb:           test/recall@10 0.12308
wandb:           test/recall@20 0.19
wandb:            test/recall@5 0.07462
wandb:           test/recall@50 0.28846
wandb:             test/runtime 174.0181
wandb:  test/samples_per_second 7.47
wandb:    test/steps_per_second 1.868
"""

extract_metrics(data)

'0.0146, 0.0746, 0.1231, 0.19, 0.2885, 0.0146, 0.0346, 0.041, 0.0454, 0.0486, 0.0146, 0.0445, 0.06, 0.0766, 0.0964'

In [7]:
# annotated by qwen 32b, lr=5e-6, bs=2, num_neg_samples=32, gas=4
# Pasted wandb console output; extract_metrics keeps only the recall/mrr/ndcg @k entries.

data = """
wandb:                test/loss 6.85888
wandb:               test/mrr@1 0.01692
wandb:              test/mrr@10 0.04271
wandb:              test/mrr@20 0.04886
wandb:               test/mrr@5 0.03637
wandb:              test/mrr@50 0.05271
wandb:              test/ndcg@1 0.01692
wandb:             test/ndcg@10 0.06148
wandb:             test/ndcg@20 0.08348
wandb:              test/ndcg@5 0.04594
wandb:             test/ndcg@50 0.10757
wandb:            test/recall@1 0.01692
wandb:           test/recall@10 0.12385
wandb:           test/recall@20 0.21
wandb:            test/recall@5 0.07538
wandb:           test/recall@50 0.33154
wandb:             test/runtime 250.2515
"""

extract_metrics(data)

'0.0169, 0.0754, 0.1239, 0.21, 0.3315, 0.0169, 0.0364, 0.0427, 0.0489, 0.0527, 0.0169, 0.0459, 0.0615, 0.0835, 0.1076'

In [11]:
# annotated by gemma, lr=5e-6, bs=2, num_neg_samples=32, gas=4
# Pasted wandb console output; extract_metrics keeps only the recall/mrr/ndcg @k entries.

data = """
wandb:                test/loss 6.88364
wandb:               test/mrr@1 0.01231
wandb:              test/mrr@10 0.04258
wandb:              test/mrr@20 0.04772
wandb:               test/mrr@5 0.0355
wandb:              test/mrr@50 0.05117
wandb:              test/ndcg@1 0.01231
wandb:             test/ndcg@10 0.06422
wandb:             test/ndcg@20 0.08303
wandb:              test/ndcg@5 0.04663
wandb:             test/ndcg@50 0.10532
wandb:            test/recall@1 0.01231
wandb:           test/recall@10 0.13615
wandb:           test/recall@20 0.21077
wandb:            test/recall@5 0.08077
wandb:           test/recall@50 0.32462
wandb:             test/runtime 170.0957
"""

extract_metrics(data)

'0.0123, 0.0808, 0.1361, 0.2108, 0.3246, 0.0123, 0.0355, 0.0426, 0.0477, 0.0512, 0.0123, 0.0466, 0.0642, 0.083, 0.1053'

## gte-1.5B

In [3]:
# lr=5e-6, num_negative_samples=32, bs=16, gas=4, lr_scheduler_type inverse_sqrt, warmup_ratio=0.2, weight_decay=0.001
# Pasted wandb console output; extract_metrics keeps only the recall/mrr/ndcg @k entries.

data = """
wandb:                test/loss 6.84031
wandb:               test/mrr@1 0.01154
wandb:              test/mrr@10 0.04271
wandb:              test/mrr@20 0.04731
wandb:               test/mrr@5 0.03462
wandb:              test/mrr@50 0.05125
wandb:              test/ndcg@1 0.01154
wandb:             test/ndcg@10 0.06599
wandb:             test/ndcg@20 0.08299
wandb:              test/ndcg@5 0.0465
wandb:             test/ndcg@50 0.10744
wandb:            test/recall@1 0.01154
wandb:           test/recall@10 0.14308
wandb:           test/recall@20 0.21077
wandb:            test/recall@5 0.08308
wandb:           test/recall@50 0.33385
wandb:             test/runtime 93.9027
"""

extract_metrics(data)

'0.0115, 0.0831, 0.1431, 0.2108, 0.3338, 0.0115, 0.0346, 0.0427, 0.0473, 0.0512, 0.0115, 0.0465, 0.066, 0.083, 0.1074'

In [4]:
# lr=1e-6, constant
# Pasted wandb console output; extract_metrics keeps only the recall/mrr/ndcg @k
# entries, so the train/ and throughput keys below are ignored.

data = """
wandb:                test/loss 7.16428
wandb:               test/mrr@1 0.01308
wandb:              test/mrr@10 0.03429
wandb:              test/mrr@20 0.03816
wandb:               test/mrr@5 0.02844
wandb:              test/mrr@50 0.041
wandb:              test/ndcg@1 0.01308
wandb:             test/ndcg@10 0.05051
wandb:             test/ndcg@20 0.06481
wandb:              test/ndcg@5 0.03618
wandb:             test/ndcg@50 0.0826
wandb:            test/recall@1 0.01308
wandb:           test/recall@10 0.10462
wandb:           test/recall@20 0.16154
wandb:            test/recall@5 0.06
wandb:           test/recall@50 0.25154
wandb:             test/runtime 97.4869
wandb:  test/samples_per_second 13.335
wandb:    test/steps_per_second 0.841
wandb:               total_flos 0
wandb:              train/epoch 1
wandb:        train/global_step 668
wandb:          train/grad_norm 22.88761
wandb:      train/learning_rate 0.0
wandb:               train/loss 2.1781
wandb:               train_loss 2.94727
wandb:            train_runtime 5248.9053
wandb: train_samples_per_second 2.036
wandb:   train_steps_per_second 0.127
"""

extract_metrics(data)

'0.0131, 0.06, 0.1046, 0.1615, 0.2515, 0.0131, 0.0284, 0.0343, 0.0382, 0.041, 0.0131, 0.0362, 0.0505, 0.0648, 0.0826'

In [5]:
# lr=5e-6, num_negative_samples=128, bs=8, gas=4, lr_scheduler_type inverse_sqrt, warmup_ratio=0.2, weight_decay=0.001
# Pasted wandb console output; extract_metrics keeps only the recall/mrr/ndcg @k
# entries, so the train/ and throughput keys below are ignored.

data = """
wandb:                test/loss 6.84322
wandb:               test/mrr@1 0.01615
wandb:              test/mrr@10 0.04557
wandb:              test/mrr@20 0.05019
wandb:               test/mrr@5 0.03821
wandb:              test/mrr@50 0.05389
wandb:              test/ndcg@1 0.01615
wandb:             test/ndcg@10 0.0673
wandb:             test/ndcg@20 0.08432
wandb:              test/ndcg@5 0.04957
wandb:             test/ndcg@50 0.10692
wandb:            test/recall@1 0.01615
wandb:           test/recall@10 0.13923
wandb:           test/recall@20 0.20692
wandb:            test/recall@5 0.08462
wandb:           test/recall@50 0.32
wandb:             test/runtime 106.6008
wandb:  test/samples_per_second 12.195
wandb:    test/steps_per_second 0.769
wandb:               total_flos 0
wandb:              train/epoch 1
wandb:        train/global_step 167
wandb:          train/grad_norm 15.56251
wandb:      train/learning_rate 0.0
wandb:               train/loss 3.6105
wandb:               train_loss 3.93329
wandb:            train_runtime 8192.4939
wandb: train_samples_per_second 1.304
wandb:   train_steps_per_second 0.02
"""

extract_metrics(data)

'0.0162, 0.0846, 0.1392, 0.2069, 0.32, 0.0162, 0.0382, 0.0456, 0.0502, 0.0539, 0.0162, 0.0496, 0.0673, 0.0843, 0.1069'

In [8]:
# lr=5e-6, num_negative_samples=96, bs=8, gas=8, lr_scheduler_type cosine_with_min_lr, warmup_ratio=0.3, weight_decay=0.001, min_lr=1e-7, max_grad_norm=0.5
# Pasted wandb console output; extract_metrics keeps only the recall/mrr/ndcg @k entries.

data = """
wandb:                test/loss 6.87735
wandb:               test/mrr@1 0.01308
wandb:              test/mrr@10 0.0427
wandb:              test/mrr@20 0.04713
wandb:               test/mrr@5 0.03609
wandb:              test/mrr@50 0.05105
wandb:              test/ndcg@1 0.01308
wandb:             test/ndcg@10 0.06321
wandb:             test/ndcg@20 0.07961
wandb:              test/ndcg@5 0.04725
wandb:             test/ndcg@50 0.10402
wandb:            test/recall@1 0.01308
wandb:           test/recall@10 0.13077
wandb:           test/recall@20 0.19615
wandb:            test/recall@5 0.08154
wandb:           test/recall@50 0.31923
wandb:             test/runtime 138.3797
"""

extract_metrics(data)

'0.0131, 0.0815, 0.1308, 0.1961, 0.3192, 0.0131, 0.0361, 0.0427, 0.0471, 0.051, 0.0131, 0.0473, 0.0632, 0.0796, 0.104'

In [10]:
# lr=5e-6, num_negative_samples=96, bs=8, gas=8, lr_scheduler_type cosine_with_restarts (twice), warmup_ratio=0.1, weight_decay=0.01, max_grad_norm=0.5
# Pasted wandb console output; extract_metrics keeps only the recall/mrr/ndcg @k entries.

data = """
wandb:                test/loss 6.87788
wandb:               test/mrr@1 0.01385
wandb:              test/mrr@10 0.04439
wandb:              test/mrr@20 0.04922
wandb:               test/mrr@5 0.03785
wandb:              test/mrr@50 0.05327
wandb:              test/ndcg@1 0.01385
wandb:             test/ndcg@10 0.06521
wandb:             test/ndcg@20 0.08274
wandb:              test/ndcg@5 0.04902
wandb:             test/ndcg@50 0.10827
wandb:            test/recall@1 0.01385
wandb:           test/recall@10 0.13385
wandb:           test/recall@20 0.20308
wandb:            test/recall@5 0.08308
wandb:           test/recall@50 0.33231
"""

extract_metrics(data)

'0.0138, 0.0831, 0.1338, 0.2031, 0.3323, 0.0138, 0.0379, 0.0444, 0.0492, 0.0533, 0.0138, 0.049, 0.0652, 0.0827, 0.1083'