In [2]:
import re

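## 对于predict放到train中的情况

Parser for runs where prediction is executed inside the training job, so the metrics appear in the wandb summary as `train/predict_<metric>@<k>`.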

In [2]:
def extract_metrics(data: str, prefixes: tuple = ("train/predict_",)) -> str:
    """Convert pasted wandb summary lines into one comma-separated metrics row.

    Every line shaped like ``wandb:  <prefix><metric>@<k> <value>`` is parsed.
    Only ``metric`` in recall / mrr / ndcg is kept; all other lines (loss,
    runtime, learning rate, ...) are ignored. Values are rounded to 4 decimal
    places with the built-in ``round``.

    Args:
        data: Multi-line text copied from a wandb run summary.
        prefixes: Key prefixes to try, in order. The first prefix that
            matches at least one recall/mrr/ndcg line is used. A single string
            is accepted as well. The default reads metrics logged as
            ``train/predict_recall@10`` etc.

    Returns:
        The recall values, then mrr, then ndcg, each ordered by ascending k,
        joined with ", ". A cutoff missing from the log is simply absent (no
        placeholder is inserted). Returns "" when no line matches.
    """
    if isinstance(prefixes, str):
        # Guard against iterating a bare string character by character.
        prefixes = (prefixes,)

    metric_order = ("recall", "mrr", "ndcg")
    # Plain decimals plus scientific notation. A bare ``[\d.]+`` would read
    # "1e-05" as 1.0 and silently corrupt the row.
    number = r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"

    metrics = {name: {} for name in metric_order}
    for prefix in prefixes:
        pattern = re.compile(
            r"wandb:\s+" + re.escape(prefix) + r"(\w+)@(\d+)\s+" + number
        )
        found = {name: {} for name in metric_order}
        for line in data.splitlines():
            match = pattern.search(line)
            if match is None:
                continue
            metric, k, value = match.groups()
            if metric in found:
                found[metric][int(k)] = round(float(value), 4)  # keep 4 decimal places
        if any(found.values()):
            metrics = found
            break

    # Order: recall @k..., then mrr @k..., then ndcg @k..., k ascending.
    return ", ".join(
        str(metrics[name][k]) for name in metric_order for k in sorted(metrics[name])
    )

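## predict独立的情况

Parser for runs where prediction is a separate step, so the metrics appear in the wandb summary as `test/<metric>@<k>`. This cell reuses the name `extract_metrics`, so on Restart & Run All it replaces the definition above for every later cell.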

In [1]:
import re

def extract_metrics(data: str, prefixes: tuple = ("test/", "train/predict_")) -> str:
    """Convert pasted wandb summary lines into one comma-separated metrics row.

    Every line shaped like ``wandb:  <prefix><metric>@<k> <value>`` is parsed.
    Only ``metric`` in recall / mrr / ndcg is kept; all other lines (loss,
    runtime, learning rate, ...) are ignored. Values are rounded to 4 decimal
    places with the built-in ``round``.

    This definition shadows any earlier ``extract_metrics`` in the notebook.
    Its default therefore reads ``test/recall@10``-style keys first. If none
    are present, it falls back to ``train/predict_recall@10``-style keys, so
    logs from either setup still parse on Restart & Run All.

    Args:
        data: Multi-line text copied from a wandb run summary.
        prefixes: Key prefixes to try, in order. The first prefix that
            matches at least one recall/mrr/ndcg line is used. A single string
            is accepted as well.

    Returns:
        The recall values, then mrr, then ndcg, each ordered by ascending k,
        joined with ", ". A cutoff missing from the log is simply absent (no
        placeholder is inserted). Returns "" when no line matches.
    """
    if isinstance(prefixes, str):
        # Guard against iterating a bare string character by character.
        prefixes = (prefixes,)

    metric_order = ("recall", "mrr", "ndcg")
    # Plain decimals plus scientific notation. A bare ``[\d.]+`` would read
    # "1e-05" as 1.0 and silently corrupt the row.
    number = r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"

    metrics = {name: {} for name in metric_order}
    for prefix in prefixes:
        pattern = re.compile(
            r"wandb:\s+" + re.escape(prefix) + r"(\w+)@(\d+)\s+" + number
        )
        found = {name: {} for name in metric_order}
        for line in data.splitlines():
            match = pattern.search(line)
            if match is None:
                continue
            metric, k, value = match.groups()
            if metric in found:
                found[metric][int(k)] = round(float(value), 4)  # keep 4 decimal places
        if any(found.values()):
            metrics = found
            break

    # Order: recall @k..., then mrr @k..., then ndcg @k..., k ascending.
    return ", ".join(
        str(metrics[name][k]) for name in metric_order for k in sorted(metrics[name])
    )

## filter_user

In [5]:
# Run config: lr=5e-6, bs=2, num_negative_samples=32, gas=4

data = """
wandb:              train/learning_rate 1e-05
wandb:                       train/loss 2.9293
wandb:               train/predict_loss 6.83972
wandb:              train/predict_mrr@1 0.01495
wandb:             train/predict_mrr@10 0.0414
wandb:             train/predict_mrr@20 0.04666
wandb:              train/predict_mrr@5 0.03346
wandb:             train/predict_mrr@50 0.05017
wandb:             train/predict_ndcg@1 0.01495
wandb:            train/predict_ndcg@10 0.06204
wandb:            train/predict_ndcg@20 0.08114
wandb:             train/predict_ndcg@5 0.04302
wandb:            train/predict_ndcg@50 0.1034
wandb:           train/predict_recall@1 0.01495
wandb:          train/predict_recall@10 0.13079
wandb:          train/predict_recall@20 0.20628
wandb:           train/predict_recall@5 0.0725
wandb:          train/predict_recall@50 0.31913
wandb:            train/predict_runtime 186.5982
"""

# `extract_metrics` is defined twice above. On Restart & Run All, the later
# definition (written for `test/` keys) is the one in effect here. If it finds
# no metrics in these `train/predict_` keys, retry with the keys renamed, so
# this cell reproduces its output under either definition.
extract_metrics(data) or extract_metrics(data.replace("train/predict_", "test/"))

'0.0149, 0.0725, 0.1308, 0.2063, 0.3191, 0.0149, 0.0335, 0.0414, 0.0467, 0.0502, 0.0149, 0.043, 0.062, 0.0811, 0.1034'

In [7]:
# Run config: lr=1e-6, bs=2, num_neg_samples=32, gas=1

data = """
wandb:                test/loss 7.66796                                                                                            
wandb:               test/mrr@1 0.00448                                                                                            
wandb:              test/mrr@10 0.01512                                                                                            
wandb:              test/mrr@20 0.01747                                                                                            
wandb:               test/mrr@5 0.01223                                                                                            
wandb:              test/mrr@50 0.01949
wandb:              test/ndcg@1 0.00448
wandb:             test/ndcg@10 0.02331
wandb:             test/ndcg@20 0.03196
wandb:              test/ndcg@5 0.01603
wandb:             test/ndcg@50 0.04488
wandb:            test/recall@1 0.00448
wandb:           test/recall@10 0.05082
wandb:           test/recall@20 0.0852
wandb:            test/recall@5 0.02765
wandb:           test/recall@50 0.15097
"""

extract_metrics(data=data)


'0.0045, 0.0277, 0.0508, 0.0852, 0.151, 0.0045, 0.0122, 0.0151, 0.0175, 0.0195, 0.0045, 0.016, 0.0233, 0.032, 0.0449'

In [5]:
# Run config: lr=1e-5, bs=2, num_neg_samples=64, gas=8

data = """
wandb:    eval/steps_per_second 1.812
wandb:                test/loss 6.92125
wandb:               test/mrr@1 0.01495
wandb:              test/mrr@10 0.04068
wandb:              test/mrr@20 0.0454
wandb:               test/mrr@5 0.03413
wandb:              test/mrr@50 0.04865
wandb:              test/ndcg@1 0.01495
wandb:             test/ndcg@10 0.06098
wandb:             test/ndcg@20 0.07806
wandb:              test/ndcg@5 0.04493
wandb:             test/ndcg@50 0.09889
wandb:            test/recall@1 0.01495
wandb:           test/recall@10 0.12855
wandb:           test/recall@20 0.19581
wandb:            test/recall@5 0.07848
wandb:           test/recall@50 0.30194
wandb:             test/runtime 181.0884
wandb:  test/samples_per_second 7.389
wandb:    test/steps_per_second 1.85
wandb:               total_flos 0
wandb:              train/epoch 0.99889
wandb:        train/global_step 336
wandb:          train/grad_norm 8.87292
wandb:      train/learning_rate 1e-05
wandb:               train/loss 3.9673
wandb:               train_loss 3.22661
wandb:            train_runtime 17893.3513
wandb: train_samples_per_second 0.602
wandb:   train_steps_per_second 0.019
"""

extract_metrics(data=data)

'0.0149, 0.0785, 0.1285, 0.1958, 0.3019, 0.0149, 0.0341, 0.0407, 0.0454, 0.0486, 0.0149, 0.0449, 0.061, 0.0781, 0.0989'

In [7]:
# Run config: lr=5e-6, bs=2, num_neg_samples=64, gas=4

data = """
wandb:                test/loss 6.8446
wandb:               test/mrr@1 0.01121
wandb:              test/mrr@10 0.04235
wandb:              test/mrr@20 0.04766
wandb:               test/mrr@5 0.03519
wandb:              test/mrr@50 0.05127
wandb:              test/ndcg@1 0.01121
wandb:             test/ndcg@10 0.06418
wandb:             test/ndcg@20 0.08384
wandb:              test/ndcg@5 0.04694
wandb:             test/ndcg@50 0.10669
wandb:            test/recall@1 0.01121
wandb:           test/recall@10 0.13602
wandb:           test/recall@20 0.2145
wandb:            test/recall@5 0.08296
wandb:           test/recall@50 0.33034
wandb:             test/runtime 202.1424
"""

extract_metrics(data=data)

'0.0112, 0.083, 0.136, 0.2145, 0.3303, 0.0112, 0.0352, 0.0423, 0.0477, 0.0513, 0.0112, 0.0469, 0.0642, 0.0838, 0.1067'

In [4]:
# Run config: lr=5e-6, bs=2, num_neg_samples=64, gas=8

data = """
wandb:                test/loss 6.82661
wandb:               test/mrr@1 0.01495
wandb:              test/mrr@10 0.04259
wandb:              test/mrr@20 0.04759
wandb:               test/mrr@5 0.03457
wandb:              test/mrr@50 0.05128
wandb:              test/ndcg@1 0.01495
wandb:             test/ndcg@10 0.06348
wandb:             test/ndcg@20 0.08202
wandb:              test/ndcg@5 0.0441
wandb:             test/ndcg@50 0.10458
wandb:            test/recall@1 0.01495
wandb:           test/recall@10 0.13303
wandb:           test/recall@20 0.20703
wandb:            test/recall@5 0.07324
wandb:           test/recall@50 0.31988
wandb:             test/runtime 185.6406
"""

extract_metrics(data=data)

'0.0149, 0.0732, 0.133, 0.207, 0.3199, 0.0149, 0.0346, 0.0426, 0.0476, 0.0513, 0.0149, 0.0441, 0.0635, 0.082, 0.1046'

In [6]:
# Run config: phi4, lr=5e-6, bs=2, num_neg_samples=32, gas=4

data = """
wandb:                test/loss 6.99885                                                                                                      
wandb:               test/mrr@1 0.01462                                                                                                      
wandb:              test/mrr@10 0.04098                                                                                                      
wandb:              test/mrr@20 0.04536                                                                                                      
wandb:               test/mrr@5 0.03463                                                                                                      
wandb:              test/mrr@50 0.04861                                                                                                      
wandb:              test/ndcg@1 0.01462                                                                                                      
wandb:             test/ndcg@10 0.06003                                                                                                      
wandb:             test/ndcg@20 0.07661                                                                                                      
wandb:              test/ndcg@5 0.04448                                                                                                      
wandb:             test/ndcg@50 0.09635                                                                                                      
wandb:            test/recall@1 0.01462                                                                                                      
wandb:           test/recall@10 0.12308                                                                                                      
wandb:           test/recall@20 0.19                                                                                                         
wandb:            test/recall@5 0.07462                               
wandb:           test/recall@50 0.28846                               
wandb:             test/runtime 174.0181                              
wandb:  test/samples_per_second 7.47      
wandb:    test/steps_per_second 1.868
"""

extract_metrics(data=data)

'0.0146, 0.0746, 0.1231, 0.19, 0.2885, 0.0146, 0.0346, 0.041, 0.0454, 0.0486, 0.0146, 0.0445, 0.06, 0.0766, 0.0964'

In [7]:
# Run config: labels annotated by Qwen 32B; lr=5e-6, bs=2, num_neg_samples=32, gas=4

data = """
wandb:                test/loss 6.85888                                                                                                      
wandb:               test/mrr@1 0.01692                                                                                                      
wandb:              test/mrr@10 0.04271                                                                                                      
wandb:              test/mrr@20 0.04886                                                                                                      
wandb:               test/mrr@5 0.03637                                                                                                      
wandb:              test/mrr@50 0.05271                                                                                                      
wandb:              test/ndcg@1 0.01692                                                                                                      
wandb:             test/ndcg@10 0.06148
wandb:             test/ndcg@20 0.08348
wandb:              test/ndcg@5 0.04594
wandb:             test/ndcg@50 0.10757
wandb:            test/recall@1 0.01692
wandb:           test/recall@10 0.12385
wandb:           test/recall@20 0.21
wandb:            test/recall@5 0.07538
wandb:           test/recall@50 0.33154
wandb:             test/runtime 250.2515
"""

extract_metrics(data=data)

'0.0169, 0.0754, 0.1239, 0.21, 0.3315, 0.0169, 0.0364, 0.0427, 0.0489, 0.0527, 0.0169, 0.0459, 0.0615, 0.0835, 0.1076'

In [11]:
# Run config: labels annotated by Gemma; lr=5e-6, bs=2, num_neg_samples=32, gas=4

data = """
wandb:                test/loss 6.88364                                                                                                      
wandb:               test/mrr@1 0.01231                                                                                                      
wandb:              test/mrr@10 0.04258                                                                                                      
wandb:              test/mrr@20 0.04772                                                                                                      
wandb:               test/mrr@5 0.0355                                                                                                       
wandb:              test/mrr@50 0.05117
wandb:              test/ndcg@1 0.01231
wandb:             test/ndcg@10 0.06422
wandb:             test/ndcg@20 0.08303
wandb:              test/ndcg@5 0.04663
wandb:             test/ndcg@50 0.10532
wandb:            test/recall@1 0.01231
wandb:           test/recall@10 0.13615
wandb:           test/recall@20 0.21077
wandb:            test/recall@5 0.08077
wandb:           test/recall@50 0.32462
wandb:             test/runtime 170.0957
"""

extract_metrics(data=data)

'0.0123, 0.0808, 0.1361, 0.2108, 0.3246, 0.0123, 0.0355, 0.0426, 0.0477, 0.0512, 0.0123, 0.0466, 0.0642, 0.083, 0.1053'

In [12]:
# Run config: lr=5e-6, bs=2, num_neg_samples=32, gas=4

data = """
wandb:                test/loss 6.87585                                                                                                      
wandb:               test/mrr@1 0.01692                                                                                                      
wandb:              test/mrr@10 0.04237                                                                                                      
wandb:              test/mrr@20 0.04733                                                                                                      
wandb:               test/mrr@5 0.03524                                                                                                      
wandb:              test/mrr@50 0.05123
wandb:              test/ndcg@1 0.01692
wandb:             test/ndcg@10 0.06233
wandb:             test/ndcg@20 0.08042
wandb:              test/ndcg@5 0.04468
wandb:             test/ndcg@50 0.10435
wandb:            test/recall@1 0.01692
wandb:           test/recall@10 0.12923
wandb:           test/recall@20 0.20077
wandb:            test/recall@5 0.07385
wandb:           test/recall@50 0.32077
wandb:             test/runtime 193.2848
"""

extract_metrics(data=data)

'0.0169, 0.0738, 0.1292, 0.2008, 0.3208, 0.0169, 0.0352, 0.0424, 0.0473, 0.0512, 0.0169, 0.0447, 0.0623, 0.0804, 0.1043'

## gte-1.5B

In [3]:
# Run config: lr=5e-6, num_negative_samples=32, bs=16, gas=4, lr_scheduler_type=inverse_sqrt, warmup_ratio=0.2, weight_decay=0.001

data = """
wandb:                test/loss 6.84031                                                                                                 
wandb:               test/mrr@1 0.01154                                                                                                 
wandb:              test/mrr@10 0.04271                                                                                                 
wandb:              test/mrr@20 0.04731                                                                                                 
wandb:               test/mrr@5 0.03462
wandb:              test/mrr@50 0.05125
wandb:              test/ndcg@1 0.01154
wandb:             test/ndcg@10 0.06599
wandb:             test/ndcg@20 0.08299
wandb:              test/ndcg@5 0.0465
wandb:             test/ndcg@50 0.10744
wandb:            test/recall@1 0.01154
wandb:           test/recall@10 0.14308
wandb:           test/recall@20 0.21077
wandb:            test/recall@5 0.08308
wandb:           test/recall@50 0.33385
wandb:             test/runtime 93.9027
"""

extract_metrics(data=data)

'0.0115, 0.0831, 0.1431, 0.2108, 0.3338, 0.0115, 0.0346, 0.0427, 0.0473, 0.0512, 0.0115, 0.0465, 0.066, 0.083, 0.1074'

In [4]:
# Run config: lr=1e-6, constant

data = """
wandb:                test/loss 7.16428                                                                                                 
wandb:               test/mrr@1 0.01308                                                                                                 
wandb:              test/mrr@10 0.03429                                                                                                 
wandb:              test/mrr@20 0.03816                                                                                                 
wandb:               test/mrr@5 0.02844
wandb:              test/mrr@50 0.041
wandb:              test/ndcg@1 0.01308
wandb:             test/ndcg@10 0.05051
wandb:             test/ndcg@20 0.06481
wandb:              test/ndcg@5 0.03618
wandb:             test/ndcg@50 0.0826
wandb:            test/recall@1 0.01308
wandb:           test/recall@10 0.10462
wandb:           test/recall@20 0.16154
wandb:            test/recall@5 0.06
wandb:           test/recall@50 0.25154
wandb:             test/runtime 97.4869
wandb:  test/samples_per_second 13.335
wandb:    test/steps_per_second 0.841
wandb:               total_flos 0
wandb:              train/epoch 1
wandb:        train/global_step 668
wandb:          train/grad_norm 22.88761
wandb:      train/learning_rate 0.0
wandb:               train/loss 2.1781
wandb:               train_loss 2.94727
wandb:            train_runtime 5248.9053
wandb: train_samples_per_second 2.036
wandb:   train_steps_per_second 0.127
"""

extract_metrics(data=data)

'0.0131, 0.06, 0.1046, 0.1615, 0.2515, 0.0131, 0.0284, 0.0343, 0.0382, 0.041, 0.0131, 0.0362, 0.0505, 0.0648, 0.0826'

In [13]:
# Run config: lr=1e-6, constant, neg_samples=64

data = """
wandb:                test/loss 7.16127                                                                                                                            
wandb:               test/mrr@1 0.01462                                                                                                                            
wandb:              test/mrr@10 0.03476                                                                                                                            
wandb:              test/mrr@20 0.03896                                          
wandb:               test/mrr@5 0.0291                                           
wandb:              test/mrr@50 0.04175                                          
wandb:              test/ndcg@1 0.01462                                          
wandb:             test/ndcg@10 0.05033                                          
wandb:             test/ndcg@20 0.06579                                          
wandb:              test/ndcg@5 0.03649                                          
wandb:             test/ndcg@50 0.08319                                          
wandb:            test/recall@1 0.01462                                          
wandb:           test/recall@10 0.10231                                          
wandb:           test/recall@20 0.16385                                          
wandb:            test/recall@5 0.05923                                          
wandb:           test/recall@50 0.25154                                          
wandb:             test/runtime 61.1096
"""

extract_metrics(data=data)

'0.0146, 0.0592, 0.1023, 0.1638, 0.2515, 0.0146, 0.0291, 0.0348, 0.039, 0.0418, 0.0146, 0.0365, 0.0503, 0.0658, 0.0832'

In [14]:
# Run config: lr=1e-6, constant, neg_samples=96

data = """
wandb:    eval/steps_per_second 1.608
wandb:                test/loss 7.14913
wandb:               test/mrr@1 0.01538
wandb:              test/mrr@10 0.03686
wandb:              test/mrr@20 0.04046
wandb:               test/mrr@5 0.03035
wandb:              test/mrr@50 0.04345
wandb:              test/ndcg@1 0.01538
wandb:             test/ndcg@10 0.0538
wandb:             test/ndcg@20 0.067
wandb:              test/ndcg@5 0.03765
wandb:             test/ndcg@50 0.0856
wandb:            test/recall@1 0.01538
wandb:           test/recall@10 0.11077
wandb:           test/recall@20 0.16308
wandb:            test/recall@5 0.06
wandb:           test/recall@50 0.25692
wandb:             test/runtime 50.6674
wandb:  test/samples_per_second 25.658
"""

extract_metrics(data=data)

'0.0154, 0.06, 0.1108, 0.1631, 0.2569, 0.0154, 0.0303, 0.0369, 0.0405, 0.0435, 0.0154, 0.0377, 0.0538, 0.067, 0.0856'

In [5]:
# Run config: lr=5e-6, num_negative_samples=128, bs=8, gas=4, lr_scheduler_type=inverse_sqrt, warmup_ratio=0.2, weight_decay=0.001

data = """
wandb:                test/loss 6.84322
wandb:               test/mrr@1 0.01615
wandb:              test/mrr@10 0.04557
wandb:              test/mrr@20 0.05019
wandb:               test/mrr@5 0.03821
wandb:              test/mrr@50 0.05389
wandb:              test/ndcg@1 0.01615
wandb:             test/ndcg@10 0.0673
wandb:             test/ndcg@20 0.08432
wandb:              test/ndcg@5 0.04957
wandb:             test/ndcg@50 0.10692
wandb:            test/recall@1 0.01615
wandb:           test/recall@10 0.13923
wandb:           test/recall@20 0.20692
wandb:            test/recall@5 0.08462
wandb:           test/recall@50 0.32
wandb:             test/runtime 106.6008
wandb:  test/samples_per_second 12.195
wandb:    test/steps_per_second 0.769
wandb:               total_flos 0
wandb:              train/epoch 1
wandb:        train/global_step 167
wandb:          train/grad_norm 15.56251
wandb:      train/learning_rate 0.0
wandb:               train/loss 3.6105
wandb:               train_loss 3.93329
wandb:            train_runtime 8192.4939
wandb: train_samples_per_second 1.304
wandb:   train_steps_per_second 0.02
"""

extract_metrics(data=data)

'0.0162, 0.0846, 0.1392, 0.2069, 0.32, 0.0162, 0.0382, 0.0456, 0.0502, 0.0539, 0.0162, 0.0496, 0.0673, 0.0843, 0.1069'

In [8]:
# Run config: lr=5e-6, num_negative_samples=96, bs=8, gas=8, lr_scheduler_type=cosine_with_min_lr, warmup_ratio=0.3, weight_decay=0.001, min_lr=1e-7, max_grad_norm=0.5

data = """
wandb:                test/loss 6.87735                                                                                                      
wandb:               test/mrr@1 0.01308                                                                                                      
wandb:              test/mrr@10 0.0427                                                                                                       
wandb:              test/mrr@20 0.04713                                                                                                      
wandb:               test/mrr@5 0.03609                                                                                                      
wandb:              test/mrr@50 0.05105
wandb:              test/ndcg@1 0.01308
wandb:             test/ndcg@10 0.06321
wandb:             test/ndcg@20 0.07961
wandb:              test/ndcg@5 0.04725
wandb:             test/ndcg@50 0.10402
wandb:            test/recall@1 0.01308
wandb:           test/recall@10 0.13077
wandb:           test/recall@20 0.19615
wandb:            test/recall@5 0.08154
wandb:           test/recall@50 0.31923
wandb:             test/runtime 138.3797
"""

extract_metrics(data=data)

'0.0131, 0.0815, 0.1308, 0.1961, 0.3192, 0.0131, 0.0361, 0.0427, 0.0471, 0.051, 0.0131, 0.0473, 0.0632, 0.0796, 0.104'

In [10]:
# Run config: lr=5e-6, num_negative_samples=96, bs=8, gas=8, lr_scheduler_type=cosine_with_restarts (twice), warmup_ratio=0.1, weight_decay=0.01, max_grad_norm=0.5

data = """
wandb:                test/loss 6.87788                                                                                                      
wandb:               test/mrr@1 0.01385                                                                                                      
wandb:              test/mrr@10 0.04439                                                                                                      
wandb:              test/mrr@20 0.04922
wandb:               test/mrr@5 0.03785
wandb:              test/mrr@50 0.05327
wandb:              test/ndcg@1 0.01385
wandb:             test/ndcg@10 0.06521
wandb:             test/ndcg@20 0.08274
wandb:              test/ndcg@5 0.04902
wandb:             test/ndcg@50 0.10827
wandb:            test/recall@1 0.01385
wandb:           test/recall@10 0.13385
wandb:           test/recall@20 0.20308
wandb:            test/recall@5 0.08308
wandb:           test/recall@50 0.33231
"""

extract_metrics(data=data)

'0.0138, 0.0831, 0.1338, 0.2031, 0.3323, 0.0138, 0.0379, 0.0444, 0.0492, 0.0533, 0.0138, 0.049, 0.0652, 0.0827, 0.1083'

## like和dislike放在一起

In [3]:
# Run config: lr=5e-6, bs=2, num_neg_samples=32, gas=4; text split into like/dislike sections

data = """
wandb:                test/loss 6.80229                                                                                                 
wandb:               test/mrr@1 0.01345                                                                                                 
wandb:              test/mrr@10 0.0448                                                                                                  
wandb:              test/mrr@20 0.04907                                                                                                 
wandb:               test/mrr@5 0.03637                                                                                                 
wandb:              test/mrr@50 0.05244                                                                                                 
wandb:              test/ndcg@1 0.01345
wandb:             test/ndcg@10 0.06944
wandb:             test/ndcg@20 0.08558
wandb:              test/ndcg@5 0.04866
wandb:             test/ndcg@50 0.10683
wandb:            test/recall@1 0.01345
wandb:           test/recall@10 0.15172
wandb:           test/recall@20 0.21674
wandb:            test/recall@5 0.0867
wandb:           test/recall@50 0.32436
"""

extract_metrics(data=data)

'0.0135, 0.0867, 0.1517, 0.2167, 0.3244, 0.0135, 0.0364, 0.0448, 0.0491, 0.0524, 0.0135, 0.0487, 0.0694, 0.0856, 0.1068'

In [3]:
# Run config: lr=5e-6, bs=2, num_neg_samples=32, gas=4; text split into like/dislike sections; second test run

data = """
wandb:                test/loss 6.84522
wandb:               test/mrr@1 0.01121
wandb:              test/mrr@10 0.04025
wandb:              test/mrr@20 0.04564
wandb:               test/mrr@5 0.03438
wandb:              test/mrr@50 0.04941
wandb:              test/ndcg@1 0.01121
wandb:             test/ndcg@10 0.06116
wandb:             test/ndcg@20 0.08068
wandb:              test/ndcg@5 0.04663
wandb:             test/ndcg@50 0.10445
wandb:            test/recall@1 0.01121
wandb:           test/recall@10 0.13004
wandb:           test/recall@20 0.20703
wandb:            test/recall@5 0.08445
wandb:           test/recall@50 0.32735
"""

extract_metrics(data=data)


'0.0112, 0.0844, 0.13, 0.207, 0.3273, 0.0112, 0.0344, 0.0403, 0.0456, 0.0494, 0.0112, 0.0466, 0.0612, 0.0807, 0.1045'

In [6]:
# Run config: lr=5e-6, bs=2, num_neg_samples=32, gas=4; text split into like/dislike sections; third test run

data = """
wandb:               test/mrr@1 0.01046
wandb:              test/mrr@10 0.04002
wandb:              test/mrr@20 0.04486
wandb:               test/mrr@5 0.03137
wandb:              test/mrr@50 0.04851
wandb:              test/ndcg@1 0.01046
wandb:             test/ndcg@10 0.06304
wandb:             test/ndcg@20 0.08063
wandb:              test/ndcg@5 0.04218
wandb:             test/ndcg@50 0.10367
wandb:            test/recall@1 0.01046
wandb:           test/recall@10 0.13976
wandb:           test/recall@20 0.20927
wandb:            test/recall@5 0.07549
wandb:           test/recall@50 0.32586
"""

extract_metrics(data=data)

'0.0105, 0.0755, 0.1398, 0.2093, 0.3259, 0.0105, 0.0314, 0.04, 0.0449, 0.0485, 0.0105, 0.0422, 0.063, 0.0806, 0.1037'

In [5]:
# Run config: lr=5e-6, bs=2, num_neg_samples=32, gas=4; no like/dislike split in the text,
# only movies and attributes sections (each holding both liked and disliked items)

data = """
wandb:                test/loss 6.94579
wandb:               test/mrr@1 0.00897
wandb:              test/mrr@10 0.03635
wandb:              test/mrr@20 0.04176
wandb:               test/mrr@5 0.02835
wandb:              test/mrr@50 0.04521
wandb:              test/ndcg@1 0.00897
wandb:             test/ndcg@10 0.05659
wandb:             test/ndcg@20 0.07663
wandb:              test/ndcg@5 0.03724
wandb:             test/ndcg@50 0.09782
wandb:            test/recall@1 0.00897
wandb:           test/recall@10 0.12407
wandb:           test/recall@20 0.20404
wandb:            test/recall@5 0.06428
wandb:           test/recall@50 0.31016
wandb:             test/runtime 111.0837
wandb:  test/samples_per_second 12.045
"""

extract_metrics(data=data)

'0.009, 0.0643, 0.1241, 0.204, 0.3102, 0.009, 0.0284, 0.0364, 0.0418, 0.0452, 0.009, 0.0372, 0.0566, 0.0766, 0.0978'

## 7b生成数据测试

In [4]:
# Run config: lr=5e-6, bs=2, num_neg_samples=32, gas=4; data generated by the original Qwen 7B

data = """
wandb:                test/loss 6.98243                                                                                                 
wandb:               test/mrr@1 0.01585                                                                                                 
wandb:              test/mrr@10 0.04021                                                                                                 
wandb:              test/mrr@20 0.04491                                                                                                 
wandb:               test/mrr@5 0.03361                                                                                                 
wandb:              test/mrr@50 0.04795                                                                                                 
wandb:              test/ndcg@1 0.01585                                                                                                 
wandb:             test/ndcg@10 0.05879                                                                                                 
wandb:             test/ndcg@20 0.0762                                                                                                  
wandb:              test/ndcg@5 0.04273                                                                                                 
wandb:             test/ndcg@50 0.09517                                                                                                 
wandb:            test/recall@1 0.01585                                                                                                 
wandb:           test/recall@10 0.12075                                                                                                 
wandb:           test/recall@20 0.19019                                                                                                 
wandb:            test/recall@5 0.07094                                                                                                 
wandb:           test/recall@50 0.28604                                                                                                 
wandb:             test/runtime 288.4573
"""

extract_metrics(data=data)

'0.0158, 0.0709, 0.1207, 0.1902, 0.286, 0.0158, 0.0336, 0.0402, 0.0449, 0.0479, 0.0158, 0.0427, 0.0588, 0.0762, 0.0952'

In [3]:
# Run config: lr=5e-6, bs=2, num_neg_samples=32, gas=4; valid/test sets generated by the 7B model

data = """
wandb:    eval/steps_per_second 1.641
wandb:                test/loss 7.13416
wandb:               test/mrr@1 0.01434
wandb:              test/mrr@10 0.03241
wandb:              test/mrr@20 0.03643
wandb:               test/mrr@5 0.02721
wandb:              test/mrr@50 0.03949
wandb:              test/ndcg@1 0.01434
wandb:             test/ndcg@10 0.04679
wandb:             test/ndcg@20 0.06122
wandb:              test/ndcg@5 0.03385
wandb:             test/ndcg@50 0.08036
wandb:            test/recall@1 0.01434
wandb:           test/recall@10 0.09509
wandb:           test/recall@20 0.1517
wandb:            test/recall@5 0.05434
wandb:           test/recall@50 0.2483
wandb:             test/runtime 210.0132
"""

extract_metrics(data=data)

'0.0143, 0.0543, 0.0951, 0.1517, 0.2483, 0.0143, 0.0272, 0.0324, 0.0364, 0.0395, 0.0143, 0.0338, 0.0468, 0.0612, 0.0804'

In [7]:
# Run config: lr=5e-6, bs=2, num_neg_samples=32, gas=4; valid/test sets generated by the 7B model after SFT

data = """
wandb:               test/mrr@1 0.0124                                                                                                     
wandb:              test/mrr@10 0.03844                                                                                                    
wandb:              test/mrr@20 0.04339                                                                                                    
wandb:               test/mrr@5 0.03112                                                                                                    
wandb:              test/mrr@50 0.04751                                                                                                    
wandb:              test/ndcg@1 0.0124
wandb:             test/ndcg@10 0.05753
wandb:             test/ndcg@20 0.07583
wandb:              test/ndcg@5 0.03992
wandb:             test/ndcg@50 0.10174
wandb:            test/recall@1 0.0124
wandb:           test/recall@10 0.12093
wandb:           test/recall@20 0.1938
wandb:            test/recall@5 0.06667
wandb:           test/recall@50 0.32481
wandb:             test/runtime 215.5864
"""

extract_metrics(data=data)

'0.0124, 0.0667, 0.1209, 0.1938, 0.3248, 0.0124, 0.0311, 0.0384, 0.0434, 0.0475, 0.0124, 0.0399, 0.0575, 0.0758, 0.1017'

In [10]:
# lr=5e-6, bs=2, num_neg_samples=32, gas=4，使用SFT后7b生成的valid/test，movie_info使用qwen-32b得到的

data = """
wandb:                test/loss 6.91795                                                                                                    
wandb:               test/mrr@1 0.0093                                                                                                     
wandb:              test/mrr@10 0.0382                                                                                                     
wandb:              test/mrr@20 0.04322                                                                                                    
wandb:               test/mrr@5 0.03189                                                                                                    
wandb:              test/mrr@50 0.04702                                                                                                    
wandb:              test/ndcg@1 0.0093                                                                                                     
wandb:             test/ndcg@10 0.05765                                                                                                    
wandb:             test/ndcg@20 0.07576
wandb:              test/ndcg@5 0.04221
wandb:             test/ndcg@50 0.09967
wandb:            test/recall@1 0.0093
wandb:           test/recall@10 0.12171
wandb:           test/recall@20 0.19302
wandb:            test/recall@5 0.07364
wandb:           test/recall@50 0.31395
"""

extract_metrics(data)

'0.0093, 0.0736, 0.1217, 0.193, 0.314, 0.0093, 0.0319, 0.0382, 0.0432, 0.047, 0.0093, 0.0422, 0.0576, 0.0758, 0.0997'

In [3]:
# 7b sample-3 SFT

data = """
wandb:               test/mrr@1 0.01279                                                                                                        
wandb:              test/mrr@10 0.0399                                                                                                         
wandb:              test/mrr@20 0.04373                                                                                                        
wandb:               test/mrr@5 0.0318                                                                                                         
wandb:              test/mrr@50 0.04734                                                                                                        
wandb:              test/ndcg@1 0.01279
wandb:             test/ndcg@10 0.06162
wandb:             test/ndcg@20 0.07554
wandb:              test/ndcg@5 0.04153
wandb:             test/ndcg@50 0.09828
wandb:            test/recall@1 0.01279
wandb:           test/recall@10 0.13469
wandb:           test/recall@20 0.18962
wandb:            test/recall@5 0.07148
wandb:           test/recall@50 0.30474
"""

extract_metrics(data)

'0.0128, 0.0715, 0.1347, 0.1896, 0.3047, 0.0128, 0.0318, 0.0399, 0.0437, 0.0473, 0.0128, 0.0415, 0.0616, 0.0755, 0.0983'

In [3]:
# 7b sample-5 0.2 top1 train/valid/test

data = """
wandb:               test/mrr@1 0.01809                                                                                                   
wandb:              test/mrr@10 0.04193                                                                                                   
wandb:              test/mrr@20 0.04557                                                                                                   
wandb:               test/mrr@5 0.03505                                                                                                   
wandb:              test/mrr@50 0.04927                                                                                                   
wandb:              test/ndcg@1 0.01809                                                                                                   
wandb:             test/ndcg@10 0.06096                                                                                                   
wandb:             test/ndcg@20 0.07415                                                                                                   
wandb:              test/ndcg@5 0.04449                                                                                                   
wandb:             test/ndcg@50 0.09771                                                                                                   
wandb:            test/recall@1 0.01809
wandb:           test/recall@10 0.12434
wandb:           test/recall@20 0.17634
wandb:            test/recall@5 0.07385
wandb:           test/recall@50 0.29616
"""

extract_metrics(data)

'0.0181, 0.0738, 0.1243, 0.1763, 0.2962, 0.0181, 0.035, 0.0419, 0.0456, 0.0493, 0.0181, 0.0445, 0.061, 0.0741, 0.0977'

In [2]:
# 7b sample-5 0.2 all train/valid/test

data = """
wandb:                test/loss 6.97289                                                                                                              
wandb:               test/mrr@1 0.01353                                                                                                              
wandb:              test/mrr@10 0.03687                                                                                                              
wandb:              test/mrr@20 0.04133                                                                                                              
wandb:               test/mrr@5 0.03018                                                                                                              
wandb:              test/mrr@50 0.04469                                                                                                              
wandb:              test/ndcg@1 0.01353                                                                                                              
wandb:             test/ndcg@10 0.05512                                                                                                              
wandb:             test/ndcg@20 0.07143                                                                                                              
wandb:              test/ndcg@5 0.03829                                                                                                              
wandb:             test/ndcg@50 0.09276                                                                                                              
wandb:            test/recall@1 0.01353                                                                                                              
wandb:           test/recall@10 0.11654                                                                                                              
wandb:           test/recall@20 0.1812                                                                                                               
wandb:            test/recall@5 0.06316                                                                                                              
wandb:           test/recall@50 0.28947
wandb:             test/runtime 184.0131
"""

extract_metrics(data)

'0.0135, 0.0632, 0.1165, 0.1812, 0.2895, 0.0135, 0.0302, 0.0369, 0.0413, 0.0447, 0.0135, 0.0383, 0.0551, 0.0714, 0.0928'

In [5]:
# 7b sample-3 train/valid/test preference

data = """
wandb:                test/loss 6.8618                                                                                                        
wandb:               test/mrr@1 0.01345                                                                                                       
wandb:              test/mrr@10 0.03988                                                                                                       
wandb:              test/mrr@20 0.0445                                                                                                        
wandb:               test/mrr@5 0.0313                                                                                                        
wandb:              test/mrr@50 0.04815                                                                                                       
wandb:              test/ndcg@1 0.01345                                                                                                       
wandb:             test/ndcg@10 0.0619                                                                                                        
wandb:             test/ndcg@20 0.07885                                                                                                       
wandb:              test/ndcg@5 0.04084                                                                                                       
wandb:             test/ndcg@50 0.10188                                                                                                       
wandb:            test/recall@1 0.01345                                                                                                       
wandb:           test/recall@10 0.13602                                                                                                       
wandb:           test/recall@20 0.20329                                                                                                       
wandb:            test/recall@5 0.07025                                                                                                       
wandb:           test/recall@50 0.31988
"""

extract_metrics(data)

'0.0135, 0.0703, 0.136, 0.2033, 0.3199, 0.0135, 0.0313, 0.0399, 0.0445, 0.0481, 0.0135, 0.0408, 0.0619, 0.0789, 0.1019'

In [7]:
# 7b all 5 top1 train/valid/test preference

data = """
wandb:                test/loss 6.91557                                                                                                       
wandb:               test/mrr@1 0.0142                                                                                                        
wandb:              test/mrr@10 0.03902                                                                                                       
wandb:              test/mrr@20 0.04434                                                                                                       
wandb:               test/mrr@5 0.03271                                                                                                       
wandb:              test/mrr@50 0.048                                                                                                         
wandb:              test/ndcg@1 0.0142                                                                                                        
wandb:             test/ndcg@10 0.05758
wandb:             test/ndcg@20 0.07738
wandb:              test/ndcg@5 0.04175
wandb:             test/ndcg@50 0.10083
wandb:            test/recall@1 0.0142
wandb:           test/recall@10 0.11958
wandb:           test/recall@20 0.1988
wandb:            test/recall@5 0.06951
wandb:           test/recall@50 0.31839
"""

extract_metrics(data)

'0.0142, 0.0695, 0.1196, 0.1988, 0.3184, 0.0142, 0.0327, 0.039, 0.0443, 0.048, 0.0142, 0.0418, 0.0576, 0.0774, 0.1008'

In [8]:
# 7b all 5 top1 train/valid/test preference

data = """
wandb:                test/loss 6.86643                                                                                                       
wandb:               test/mrr@1 0.01046                                                                                                       
wandb:              test/mrr@10 0.03849                                                                                                       
wandb:              test/mrr@20 0.04376                                                                                                       
wandb:               test/mrr@5 0.0312                                                                                                        
wandb:              test/mrr@50 0.04733                                                                                                       
wandb:              test/ndcg@1 0.01046                                                                                                       
wandb:             test/ndcg@10 0.05967                                                                                                       
wandb:             test/ndcg@20 0.07868                                                                                                       
wandb:              test/ndcg@5 0.04189
wandb:             test/ndcg@50 0.10069
wandb:            test/recall@1 0.01046
wandb:           test/recall@10 0.13004
wandb:           test/recall@20 0.20478
wandb:            test/recall@5 0.07474
wandb:           test/recall@50 0.3154
wandb:             test/runtime 94.3323
"""

extract_metrics(data)

'0.0105, 0.0747, 0.13, 0.2048, 0.3154, 0.0105, 0.0312, 0.0385, 0.0438, 0.0473, 0.0105, 0.0419, 0.0597, 0.0787, 0.1007'

In [4]:
# 7b all 5 top1 0.3 train/valid/test

data = """
wandb:                test/loss 7.01601                                                                                                    
wandb:               test/mrr@1 0.01206                                                                                                    
wandb:              test/mrr@10 0.03766                                                                                                    
wandb:              test/mrr@20 0.04237                                                                                                    
wandb:               test/mrr@5 0.03161                                                                                                    
wandb:              test/mrr@50 0.04615                                                                                                    
wandb:              test/ndcg@1 0.01206
wandb:             test/ndcg@10 0.05586
wandb:             test/ndcg@20 0.07302
wandb:              test/ndcg@5 0.04108
wandb:             test/ndcg@50 0.09619
wandb:            test/recall@1 0.01206
wandb:           test/recall@10 0.11605
wandb:           test/recall@20 0.18387
wandb:            test/recall@5 0.07008
wandb:           test/recall@50 0.29992
wandb:             test/runtime 178.2775
"""

extract_metrics(data)

'0.0121, 0.0701, 0.1161, 0.1839, 0.2999, 0.0121, 0.0316, 0.0377, 0.0424, 0.0461, 0.0121, 0.0411, 0.0559, 0.073, 0.0962'

In [5]:
# 7b all 5 top1 0.3 train/valid/test preference

data = """
wandb:                test/loss 6.83485                                                                                                    
wandb:               test/mrr@1 0.01644                                                                                                    
wandb:              test/mrr@10 0.04657                                                                                                    
wandb:              test/mrr@20 0.05129                                                                                                    
wandb:               test/mrr@5 0.03954                                                                                                    
wandb:              test/mrr@50 0.0549                                                                                                     
wandb:              test/ndcg@1 0.01644                                                                                                    
wandb:             test/ndcg@10 0.0679                                                                                                     
wandb:             test/ndcg@20 0.0851                                                                                                     
wandb:              test/ndcg@5 0.05079                                                                                                    
wandb:             test/ndcg@50 0.10784                                                                                                    
wandb:            test/recall@1 0.01644                                                                                                    
wandb:           test/recall@10 0.13827                                                                                                    
wandb:           test/recall@20 0.20628                                                                                                    
wandb:            test/recall@5 0.0852
wandb:           test/recall@50 0.32138
"""

extract_metrics(data)

'0.0164, 0.0852, 0.1383, 0.2063, 0.3214, 0.0164, 0.0395, 0.0466, 0.0513, 0.0549, 0.0164, 0.0508, 0.0679, 0.0851, 0.1078'

In [6]:
# sample 3 top1 preference

data = """
wandb:                test/loss 6.90632                                                                                                     
wandb:               test/mrr@1 0.00972                                                                                                     
wandb:              test/mrr@10 0.03665                                                                                                     
wandb:              test/mrr@20 0.04151                                                                                                     
wandb:               test/mrr@5 0.02992                                                                                                     
wandb:              test/mrr@50 0.04542                                                                                                     
wandb:              test/ndcg@1 0.00972                                                                                                     
wandb:             test/ndcg@10 0.05674                                                                                                     
wandb:             test/ndcg@20 0.07437                                                                                                     
wandb:              test/ndcg@5 0.04037                                                                                                     
wandb:             test/ndcg@50 0.0984                                                                                                      
wandb:            test/recall@1 0.00972                                                                                                     
wandb:           test/recall@10 0.12332                                                                                                     
wandb:           test/recall@20 0.19283                                                                                                     
wandb:            test/recall@5 0.0725                                                                                                      
wandb:           test/recall@50 0.31315                                                                                                     
wandb:             test/runtime 90.1477
"""

extract_metrics(data)

'0.0097, 0.0725, 0.1233, 0.1928, 0.3131, 0.0097, 0.0299, 0.0367, 0.0415, 0.0454, 0.0097, 0.0404, 0.0567, 0.0744, 0.0984'

In [2]:
# sample 3 top1 preference 2 epoch

data = """
wandb:                test/loss 6.94539                                                                                                        
wandb:               test/mrr@1 0.00897                                                                                                        
wandb:              test/mrr@10 0.037                                                                                                          
wandb:              test/mrr@20 0.042                                                                                                          
wandb:               test/mrr@5 0.02942
wandb:              test/mrr@50 0.04542
wandb:              test/ndcg@1 0.00897
wandb:             test/ndcg@10 0.058
wandb:             test/ndcg@20 0.07643
wandb:              test/ndcg@5 0.03949
wandb:             test/ndcg@50 0.09786
wandb:            test/recall@1 0.00897
wandb:           test/recall@10 0.1278
wandb:           test/recall@20 0.20105
wandb:            test/recall@5 0.07025
wandb:           test/recall@50 0.30942
"""

extract_metrics(data)

'0.009, 0.0703, 0.1278, 0.2011, 0.3094, 0.009, 0.0294, 0.037, 0.042, 0.0454, 0.009, 0.0395, 0.058, 0.0764, 0.0979'

In [6]:
# matching / generate preference

data = """
wandb:                test/loss 6.93941                                                                                                     
wandb:               test/mrr@1 0.0142                                                                                                      
wandb:              test/mrr@10 0.04375                                                                                                     
wandb:              test/mrr@20 0.04793                                                                                                     
wandb:               test/mrr@5 0.03666                                                                                                     
wandb:              test/mrr@50 0.05141                                                                                                     
wandb:              test/ndcg@1 0.0142                                                                                                      
wandb:             test/ndcg@10 0.06478                                                                                                     
wandb:             test/ndcg@20 0.08019                                                                                                     
wandb:              test/ndcg@5 0.04733                                                                                                     
wandb:             test/ndcg@50 0.10238                                                                                                     
wandb:            test/recall@1 0.0142                                                                                                      
wandb:           test/recall@10 0.13453                                                                                                     
wandb:           test/recall@20 0.19581                                                                                                     
wandb:            test/recall@5 0.07997                                                                                                     
wandb:           test/recall@50 0.30867
"""

extract_metrics(data)

'0.0142, 0.08, 0.1345, 0.1958, 0.3087, 0.0142, 0.0367, 0.0437, 0.0479, 0.0514, 0.0142, 0.0473, 0.0648, 0.0802, 0.1024'

In [2]:
# match/generate like-dislike

data = """
wandb:                test/loss 6.95996                                                                                                     
wandb:               test/mrr@1 0.00905                                                                                                     
wandb:              test/mrr@10 0.03542                                                                                                     
wandb:              test/mrr@20 0.04038                                                                                                     
wandb:               test/mrr@5 0.02925                                                                                                     
wandb:              test/mrr@50 0.04427                                                                                                     
wandb:              test/ndcg@1 0.00905                                                                                                     
wandb:             test/ndcg@10 0.05372                                                                                                     
wandb:             test/ndcg@20 0.07171
wandb:              test/ndcg@5 0.03896
wandb:             test/ndcg@50 0.09593
wandb:            test/recall@1 0.00905
wandb:           test/recall@10 0.11388
wandb:           test/recall@20 0.18477
wandb:            test/recall@5 0.06863
wandb:           test/recall@50 0.30694
"""

extract_metrics(data)

'0.0091, 0.0686, 0.1139, 0.1848, 0.3069, 0.0091, 0.0293, 0.0354, 0.0404, 0.0443, 0.0091, 0.039, 0.0537, 0.0717, 0.0959'

In [4]:
# 7b origin preference

data = """
wandb:                test/loss 6.89574
wandb:               test/mrr@1 0.01121
wandb:              test/mrr@10 0.04002
wandb:              test/mrr@20 0.04463
wandb:               test/mrr@5 0.03225
wandb:              test/mrr@50 0.04844
wandb:              test/ndcg@1 0.01121
wandb:             test/ndcg@10 0.06191
wandb:             test/ndcg@20 0.07896
wandb:              test/ndcg@5 0.04321
wandb:             test/ndcg@50 0.1027
wandb:            test/recall@1 0.01121
wandb:           test/recall@10 0.13453
wandb:           test/recall@20 0.20254
wandb:            test/recall@5 0.07698
wandb:           test/recall@50 0.32212
"""

extract_metrics(data)

'0.0112, 0.077, 0.1345, 0.2025, 0.3221, 0.0112, 0.0323, 0.04, 0.0446, 0.0484, 0.0112, 0.0432, 0.0619, 0.079, 0.1027'

## matching -> matching/generate

In [3]:
# matching -> matching/generate preference

data = """
wandb:                test/loss 6.93493                                                                                                             
wandb:               test/mrr@1 0.00822                                                                                                             
wandb:              test/mrr@10 0.03725                                                                                                             
wandb:              test/mrr@20 0.04222                                                                                                             
wandb:               test/mrr@5 0.03027                                                                                                             
wandb:              test/mrr@50 0.04591                                                                                                             
wandb:              test/ndcg@1 0.00822                                                                                                             
wandb:             test/ndcg@10 0.05761                                   
wandb:             test/ndcg@20 0.07587                                   
wandb:              test/ndcg@5 0.04069                                   
wandb:             test/ndcg@50 0.09886                                   
wandb:            test/recall@1 0.00822                                   
wandb:           test/recall@10 0.12481                                   
wandb:           test/recall@20 0.19731                                   
wandb:            test/recall@5 0.0725                                    
wandb:           test/recall@50 0.31315
"""

extract_metrics(data)


'0.0082, 0.0725, 0.1248, 0.1973, 0.3131, 0.0082, 0.0303, 0.0372, 0.0422, 0.0459, 0.0082, 0.0407, 0.0576, 0.0759, 0.0989'

In [4]:
# matching -> matching/generate

data = """
wandb:                test/loss 6.95575                                                                                                             
wandb:               test/mrr@1 0.00828                                                                                                             
wandb:              test/mrr@10 0.03705                                                                                                             
wandb:              test/mrr@20 0.04166                                                                                                             
wandb:               test/mrr@5 0.03001                                                                                                             
wandb:              test/mrr@50 0.04536                                                                                                             
wandb:              test/ndcg@1 0.00828                                                                                                             
wandb:             test/ndcg@10 0.05729                                                                                                             
wandb:             test/ndcg@20 0.07368                                   
wandb:              test/ndcg@5 0.0401                                    
wandb:             test/ndcg@50 0.09682                                   
wandb:            test/recall@1 0.00828                                   
wandb:           test/recall@10 0.12425                                   
wandb:           test/recall@20 0.18825                                   
wandb:            test/recall@5 0.07078                                   
wandb:           test/recall@50 0.30497
"""

extract_metrics(data)


'0.0083, 0.0708, 0.1242, 0.1883, 0.305, 0.0083, 0.03, 0.037, 0.0417, 0.0454, 0.0083, 0.0401, 0.0573, 0.0737, 0.0968'

In [2]:
# dpo（模型打分，生成的内容形式较为自由）

data = """
wandb:               test/mrr@1 0.01196
wandb:              test/mrr@10 0.03792
wandb:              test/mrr@20 0.04345
wandb:               test/mrr@5 0.02991
wandb:              test/mrr@50 0.04733
wandb:              test/ndcg@1 0.01196
wandb:             test/ndcg@10 0.05859
wandb:             test/ndcg@20 0.07839
wandb:              test/ndcg@5 0.03907
wandb:             test/ndcg@50 0.103
wandb:            test/recall@1 0.01196
wandb:           test/recall@10 0.1278
wandb:           test/recall@20 0.20553
wandb:            test/recall@5 0.06726
wandb:           test/recall@50 0.33034
"""

extract_metrics(data)

'0.012, 0.0673, 0.1278, 0.2055, 0.3303, 0.012, 0.0299, 0.0379, 0.0435, 0.0473, 0.012, 0.0391, 0.0586, 0.0784, 0.103'

In [3]:
## dpo matching/generate top1

data = """
wandb:                test/loss 6.99101
wandb:               test/mrr@1 0.01196
wandb:              test/mrr@10 0.03954
wandb:              test/mrr@20 0.04439
wandb:               test/mrr@5 0.03303
wandb:              test/mrr@50 0.04753
wandb:              test/ndcg@1 0.01196
wandb:             test/ndcg@10 0.05836
wandb:             test/ndcg@20 0.07635
wandb:              test/ndcg@5 0.04277
wandb:             test/ndcg@50 0.09677
wandb:            test/recall@1 0.01196
wandb:           test/recall@10 0.12033
wandb:           test/recall@20 0.19208
wandb:            test/recall@5 0.0725
wandb:           test/recall@50 0.29671
"""

extract_metrics(data)

'0.012, 0.0725, 0.1203, 0.1921, 0.2967, 0.012, 0.033, 0.0395, 0.0444, 0.0475, 0.012, 0.0428, 0.0584, 0.0764, 0.0968'

In [2]:
# dpo click preference
# Test-split metrics pasted from wandb; extract_metrics returns them in the order
# recall, mrr, ndcg, each at k = 1, 5, 10, 20, 50. Only `test/<metric>@<k>` lines
# are parsed, so the loss/runtime/train lines below are ignored.

data = """
wandb:                test/loss 6.94789
wandb:               test/mrr@1 0.0142
wandb:              test/mrr@10 0.03664
wandb:              test/mrr@20 0.04138
wandb:               test/mrr@5 0.03
wandb:              test/mrr@50 0.04487
wandb:              test/ndcg@1 0.0142
wandb:             test/ndcg@10 0.05465
wandb:             test/ndcg@20 0.07275
wandb:              test/ndcg@5 0.03822
wandb:             test/ndcg@50 0.09475
wandb:            test/recall@1 0.0142
wandb:           test/recall@10 0.1151
wandb:           test/recall@20 0.18834
wandb:            test/recall@5 0.06353
wandb:           test/recall@50 0.2997
wandb:             test/runtime 92.7841
wandb:  test/samples_per_second 14.421
wandb:    test/steps_per_second 3.611
wandb:               total_flos 0
wandb:              train/epoch 0.99889
wandb:        train/global_step 672
wandb:          train/grad_norm 9.8872
wandb:      train/learning_rate 1e-05
wandb:               train/loss 2.8257
wandb:               train_loss 2.65461
wandb:            train_runtime 9496.2436
"""

extract_metrics(data)

'0.0142, 0.0635, 0.1151, 0.1883, 0.2997, 0.0142, 0.03, 0.0366, 0.0414, 0.0449, 0.0142, 0.0382, 0.0546, 0.0727, 0.0948'

In [2]:
# dpo click left preference
# Test-split metrics pasted from wandb; extract_metrics returns them in the order
# recall, mrr, ndcg, each at k = 1, 5, 10, 20, 50 (the test/loss line is ignored).

data = """
wandb:                test/loss 6.96831
wandb:               test/mrr@1 0.00897
wandb:              test/mrr@10 0.03667
wandb:              test/mrr@20 0.04166
wandb:               test/mrr@5 0.02957
wandb:              test/mrr@50 0.04531
wandb:              test/ndcg@1 0.00897
wandb:             test/ndcg@10 0.05656
wandb:             test/ndcg@20 0.07534
wandb:              test/ndcg@5 0.03925
wandb:             test/ndcg@50 0.09761
wandb:            test/recall@1 0.00897
wandb:           test/recall@10 0.12257
wandb:           test/recall@20 0.19806
wandb:            test/recall@5 0.06876
wandb:           test/recall@50 0.30942
"""

extract_metrics(data)

'0.009, 0.0688, 0.1226, 0.1981, 0.3094, 0.009, 0.0296, 0.0367, 0.0417, 0.0453, 0.009, 0.0393, 0.0566, 0.0753, 0.0976'

In [5]:
# Coverage-rate DPO
# Test-split metrics pasted from wandb (terminal padding stripped); extract_metrics
# returns them in the order recall, mrr, ndcg, each at k = 1, 5, 10, 20, 50.

data = """
wandb:                test/loss 6.81946
wandb:               test/mrr@1 0.00747
wandb:              test/mrr@10 0.03862
wandb:              test/mrr@20 0.04332
wandb:               test/mrr@5 0.03019
wandb:              test/mrr@50 0.0473
wandb:              test/ndcg@1 0.00747
wandb:             test/ndcg@10 0.06283
wandb:             test/ndcg@20 0.08012
wandb:              test/ndcg@5 0.0422
wandb:             test/ndcg@50 0.105
wandb:            test/recall@1 0.00747
wandb:           test/recall@10 0.1435
wandb:           test/recall@20 0.21226
wandb:            test/recall@5 0.07922
wandb:           test/recall@50 0.33782
"""

extract_metrics(data)

'0.0075, 0.0792, 0.1435, 0.2123, 0.3378, 0.0075, 0.0302, 0.0386, 0.0433, 0.0473, 0.0075, 0.0422, 0.0628, 0.0801, 0.105'

In [3]:
# Coverage-rate DPO, run 2
# Test-split metrics pasted from wandb (terminal padding stripped); extract_metrics
# returns them in the order recall, mrr, ndcg, each at k = 1, 5, 10, 20, 50.

data = """
wandb:                test/loss 6.75584
wandb:               test/mrr@1 0.00747
wandb:              test/mrr@10 0.03848
wandb:              test/mrr@20 0.04368
wandb:               test/mrr@5 0.02887
wandb:              test/mrr@50 0.04759
wandb:              test/ndcg@1 0.00747
wandb:             test/ndcg@10 0.06295
wandb:             test/ndcg@20 0.08212
wandb:              test/ndcg@5 0.03957
wandb:             test/ndcg@50 0.10655
wandb:            test/recall@1 0.00747
wandb:           test/recall@10 0.14499
wandb:           test/recall@20 0.22123
wandb:            test/recall@5 0.0725
wandb:           test/recall@50 0.34454
"""

extract_metrics(data)

'0.0075, 0.0725, 0.145, 0.2212, 0.3445, 0.0075, 0.0289, 0.0385, 0.0437, 0.0476, 0.0075, 0.0396, 0.063, 0.0821, 0.1066'

In [3]:
# 7B model after the prompt was modified
# NOTE(review): another cell in this notebook carries the same label but reports
# different numbers — TODO give the two runs distinct names.
# Test-split metrics pasted from wandb (terminal padding stripped); extract_metrics
# returns them in the order recall, mrr, ndcg, each at k = 1, 5, 10, 20, 50.

data = """
wandb:               test/mrr@1 0.01495
wandb:              test/mrr@10 0.04244
wandb:              test/mrr@20 0.04809
wandb:               test/mrr@5 0.03587
wandb:              test/mrr@50 0.05122
wandb:              test/ndcg@1 0.01495
wandb:             test/ndcg@10 0.06167
wandb:             test/ndcg@20 0.08215
wandb:              test/ndcg@5 0.04602
wandb:             test/ndcg@50 0.10214
wandb:            test/recall@1 0.01495
wandb:           test/recall@10 0.12481
wandb:           test/recall@20 0.20553
wandb:            test/recall@5 0.07698
wandb:           test/recall@50 0.30717
"""

extract_metrics(data)

'0.0149, 0.077, 0.1248, 0.2055, 0.3072, 0.0149, 0.0359, 0.0424, 0.0481, 0.0512, 0.0149, 0.046, 0.0617, 0.0822, 0.1021'

In [3]:
# 7B model after the prompt was modified
# NOTE(review): another cell in this notebook carries the same label but reports
# different numbers — TODO give the two runs distinct names.
# Test-split metrics pasted from wandb (terminal padding stripped); extract_metrics
# returns them in the order recall, mrr, ndcg, each at k = 1, 5, 10, 20, 50.

data = """
wandb:                test/loss 6.82058
wandb:               test/mrr@1 0.01644
wandb:              test/mrr@10 0.04022
wandb:              test/mrr@20 0.04521
wandb:               test/mrr@5 0.03175
wandb:              test/mrr@50 0.04923
wandb:              test/ndcg@1 0.01644
wandb:             test/ndcg@10 0.05988
wandb:             test/ndcg@20 0.07802
wandb:              test/ndcg@5 0.0392
wandb:             test/ndcg@50 0.10297
wandb:            test/recall@1 0.01644
wandb:           test/recall@10 0.12631
wandb:           test/recall@20 0.19806
wandb:            test/recall@5 0.06203
wandb:           test/recall@50 0.32362
"""

extract_metrics(data)

'0.0164, 0.062, 0.1263, 0.1981, 0.3236, 0.0164, 0.0318, 0.0402, 0.0452, 0.0492, 0.0164, 0.0392, 0.0599, 0.078, 0.103'

In [4]:
# Extract all movies + attributes from the user's history records
# Test-split metrics pasted from wandb (terminal padding stripped); extract_metrics
# returns them in the order recall, mrr, ndcg, each at k = 1, 5, 10, 20, 50.

data = """
wandb:                test/loss 6.99816
wandb:               test/mrr@1 0.00523
wandb:              test/mrr@10 0.0291
wandb:              test/mrr@20 0.03459
wandb:               test/mrr@5 0.02265
wandb:              test/mrr@50 0.0383
wandb:              test/ndcg@1 0.00523
wandb:             test/ndcg@10 0.04822
wandb:             test/ndcg@20 0.06862
wandb:              test/ndcg@5 0.03225
wandb:             test/ndcg@50 0.09196
wandb:            test/recall@1 0.00523
wandb:           test/recall@10 0.11211
wandb:           test/recall@20 0.19357
wandb:            test/recall@5 0.06203
wandb:           test/recall@50 0.31166
"""

extract_metrics(data)

'0.0052, 0.062, 0.1121, 0.1936, 0.3117, 0.0052, 0.0226, 0.0291, 0.0346, 0.0383, 0.0052, 0.0323, 0.0482, 0.0686, 0.092'

In [2]:
# RPO method
# Test-split metrics pasted from wandb (terminal padding stripped); extract_metrics
# returns them in the order recall, mrr, ndcg, each at k = 1, 5, 10, 20, 50.

data = """
wandb:                test/loss 6.78508
wandb:               test/mrr@1 0.01196
wandb:              test/mrr@10 0.03952
wandb:              test/mrr@20 0.0451
wandb:               test/mrr@5 0.03188
wandb:              test/mrr@50 0.04928
wandb:              test/ndcg@1 0.01196
wandb:             test/ndcg@10 0.0603
wandb:             test/ndcg@20 0.08055
wandb:              test/ndcg@5 0.04187
wandb:             test/ndcg@50 0.10642
wandb:            test/recall@1 0.01196
wandb:           test/recall@10 0.1293
wandb:           test/recall@20 0.20927
wandb:            test/recall@5 0.0725
wandb:           test/recall@50 0.33931
"""

extract_metrics(data)

'0.012, 0.0725, 0.1293, 0.2093, 0.3393, 0.012, 0.0319, 0.0395, 0.0451, 0.0493, 0.012, 0.0419, 0.0603, 0.0805, 0.1064'

In [3]:
# DPO on data selected by the summed combined score of coverage + label
# Test-split metrics pasted from wandb (terminal padding stripped); extract_metrics
# returns them in the order recall, mrr, ndcg, each at k = 1, 5, 10, 20, 50.

data = """
wandb:                test/loss 6.73459
wandb:               test/mrr@1 0.01196
wandb:              test/mrr@10 0.04224
wandb:              test/mrr@20 0.04663
wandb:               test/mrr@5 0.03311
wandb:              test/mrr@50 0.05053
wandb:              test/ndcg@1 0.01196
wandb:             test/ndcg@10 0.06624
wandb:             test/ndcg@20 0.08215
wandb:              test/ndcg@5 0.04405
wandb:             test/ndcg@50 0.1069
wandb:            test/recall@1 0.01196
wandb:           test/recall@10 0.14649
wandb:           test/recall@20 0.20927
wandb:            test/recall@5 0.07773
wandb:           test/recall@50 0.33483
"""

extract_metrics(data)

'0.012, 0.0777, 0.1465, 0.2093, 0.3348, 0.012, 0.0331, 0.0422, 0.0466, 0.0505, 0.012, 0.044, 0.0662, 0.0822, 0.1069'

In [2]:
# NOTE(review): this cell has no label — TODO record which experiment/config
# produced these numbers so the results can be compared with the other runs.
# Test-split metrics pasted from wandb (terminal padding stripped); extract_metrics
# returns them in the order recall, mrr, ndcg, each at k = 1, 5, 10, 20, 50.
data = """
wandb:                test/loss 6.71424
wandb:               test/mrr@1 0.01046
wandb:              test/mrr@10 0.04011
wandb:              test/mrr@20 0.04514
wandb:               test/mrr@5 0.03082
wandb:              test/mrr@50 0.04885
wandb:              test/ndcg@1 0.01046
wandb:             test/ndcg@10 0.06448
wandb:             test/ndcg@20 0.0828
wandb:              test/ndcg@5 0.04156
wandb:             test/ndcg@50 0.1067
wandb:            test/recall@1 0.01046
wandb:           test/recall@10 0.14649
wandb:           test/recall@20 0.21898
wandb:            test/recall@5 0.07474
wandb:           test/recall@50 0.34081
"""

extract_metrics(data)

'0.0105, 0.0747, 0.1465, 0.219, 0.3408, 0.0105, 0.0308, 0.0401, 0.0451, 0.0488, 0.0105, 0.0416, 0.0645, 0.0828, 0.1067'