In [2]:
import torch
import transformers
from transformers import AutoTokenizer
from tqdm import tqdm

import json
import pandas as pd

tokenizer = AutoTokenizer.from_pretrained("../../allmodels/Meta-Llama-3-8B-base-original")
tokenizer('hello world')

  from .autonotebook import tqdm as notebook_tqdm


{'input_ids': [128000, 15339, 1917], 'attention_mask': [1, 1, 1]}

In [3]:
def jload(path):
    with open(path, 'r') as f:
        data = json.load(f)
        f.close()
    return data


def jdump(data, path):
    with open(path, 'w') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
        f.close()


def log(out_path):
    with open('../LLaMA-Factory/data/dataset_info.json', 'r') as f:
        dataset_info = json.load(f)
    name = out_path.split('/')[-1].split('.json')[0]
    dataset_info[name] = {
        "file_name": f"{name}.json",
        "formatting": "sharegpt"
    }
    with open('../LLaMA-Factory/data/dataset_info.json', 'w') as f:
        json.dump(dataset_info, f, indent=2)
        

def get_valid_token_len_and_conv_turn(conv, max_length=2048):
    """Get the effective train token length and conversation turns
    params
        conv: complete conversation
        max_length: max_length set in trainer
    """
    valid_token = 0
    valid_train_token = 0
    valid_turn = 0
    for turn in conv:
        cur_sent_token = len(tokenizer(turn['value'])['input_ids'])
        if valid_train_token + cur_sent_token > max_length:
            break
        if turn['from'] == 'gpt':
            valid_turn += 1
            valid_train_token += cur_sent_token
    return valid_train_token, valid_turn

In [4]:
def train_token_budget_adjust(conversations, max_length=2048, token_budget=3000000, turn_budget=7000):
    """Adjust the train data to fit token budget
    params
        conversations: raw conversation data
        max_length: max_length set in trainer
        token_budget: the training token budget
    """
    valid_conversations = []
    total_valid_turns = 0
    valid_train_token = 0
    for conv in conversations:
        total_token_len = 0
        tmp_valid_turns = []
        for turn in conv['conversations']:
            cur_sent_token = len(tokenizer(turn['value'])['input_ids'])
            if total_token_len + cur_sent_token > max_length:
                break
            total_token_len += cur_sent_token
            if turn['from'] == 'gpt':
                valid_train_token += cur_sent_token
                if valid_train_token + cur_sent_token > token_budget:
                    break
                total_valid_turns += 1
            tmp_valid_turns.append(turn)

        while len(tmp_valid_turns) > 2 and tmp_valid_turns[-1]['from'] != 'gpt':
            tmp_valid_turns = tmp_valid_turns[:-1]
        valid_conversations.append({"conversations": tmp_valid_turns})

        if valid_train_token >= token_budget or total_valid_turns>=turn_budget:
            break
    return valid_conversations, total_valid_turns, valid_train_token


In [43]:
names = (
    "1223_cleaned_5sets_euclidean_nonlinear_v4_0.95_alpha_0.3_lambda_0.9_gamma_2.0_6k_self_instruct",
    "1223_cleaned_5sets_euclidean_nonlinear_v4_0.95_alpha_0.3_lambda_0.9_gamma_2.0_6k_alpaca",
    "1223_cleaned_5sets_euclidean_nonlinear_v4_0.95_alpha_0.3_lambda_0.9_gamma_2.0_6k_dolly",
    "1223_cleaned_5sets_euclidean_nonlinear_v4_0.95_alpha_0.3_lambda_0.9_gamma_2.0_6k_sharegpt",
)

stats = {
    'name': [],
    'valid_train_token': [],
    'valid_conversations_turns': []
}

for name in tqdm(names):
    path=f'../LLaMA-Factory/data/{name}.json'
    data = jload(path)

    # QC. Preliminary: token_budget=2041300, turn_budget=7000
    valid_conversations, valid_turn, valid_train_token = train_token_budget_adjust(data)
    jdump(valid_conversations, f"../LLaMA-Factory/data/budget_adjusted_v2_{name}.json")
    log(f"../LLaMA-Factory/data/budget_adjusted_v2_{name}.json")

    print(valid_turn)
    print(valid_train_token)
    print("--"*20)

 25%|██▌       | 1/4 [00:02<00:07,  2.57s/it]

6000
445347
----------------------------------------


 50%|█████     | 2/4 [00:09<00:10,  5.13s/it]

6000
1733819
----------------------------------------


 75%|███████▌  | 3/4 [00:15<00:05,  5.68s/it]

5999
1755852
----------------------------------------


100%|██████████| 4/4 [00:26<00:00,  6.59s/it]

7003
2570873
----------------------------------------





In [42]:
names = (
    "1223_cleaned_5sets_euclidean_nonlinear_v4_0.95_alpha_0.3_lambda_0.9_gamma_2.0_6k_self_instruct",
    "1223_cleaned_5sets_euclidean_nonlinear_v4_0.95_alpha_0.3_lambda_0.9_gamma_2.0_6k_alpaca",
    "1223_cleaned_5sets_euclidean_nonlinear_v4_0.95_alpha_0.3_lambda_0.9_gamma_2.0_6k_dolly",
    "1223_cleaned_5sets_euclidean_nonlinear_v4_0.95_alpha_0.3_lambda_0.9_gamma_2.0_6k_sharegpt",
)

stats = {
    'name': [],
    'valid_train_token': [],
    'valid_conversations_turns': []
}

for name in tqdm(names):
    path=f'../LLaMA-Factory/data/{name}.json'
    data = jload(path)

    total_train_token = 0
    total_valid_turn = 0
    for conv in data:
        valid_train_token, valid_turn = get_valid_token_len_and_conv_turn(conv['conversations'])
        total_train_token += valid_train_token
        total_valid_turn += valid_turn
    
    stats['name'].append(name)
    stats['valid_train_token'].append(total_train_token)
    stats['valid_conversations_turns'].append(total_valid_turn)

    print(total_train_token)
    print(total_valid_turn)
    print("--"*20)

# df = pd.DataFrame(stats)
# df.to_csv('../analysis/train_data_stats.csv', index=None)

 25%|██▌       | 1/4 [00:02<00:06,  2.22s/it]

445347
6000
----------------------------------------


 50%|█████     | 2/4 [00:08<00:08,  4.37s/it]

1733819
6000
----------------------------------------


 75%|███████▌  | 3/4 [00:14<00:05,  5.11s/it]

1755852
5999
----------------------------------------


 75%|███████▌  | 3/4 [00:25<00:08,  8.54s/it]


KeyboardInterrupt: 

In [39]:
names = (
    "first_2k_1223_cleaned_5sets_euclidean_nonlinear_v4_0.95_alpha_0.3_lambda_0.9_gamma_2.0_6k_wizardlm",
    "second_2k_1223_cleaned_5sets_euclidean_nonlinear_v4_0.95_alpha_0.3_lambda_0.9_gamma_2.0_6k_wizardlm",
    "third_2k_1223_cleaned_5sets_euclidean_nonlinear_v4_0.95_alpha_0.3_lambda_0.9_gamma_2.0_6k_wizardlm"
)

stats = {
    'name': [],
    'valid_train_token': [],
    'valid_conversations_turns': []
}

for name in tqdm(names):
    path=f'../LLaMA-Factory/data/{name}.json'
    data = jload(path)

    valid_conversations, valid_turn, valid_train_token = train_token_budget_adjust(data, 2048, 900000, 2700)
    jdump(valid_conversations, f"../LLaMA-Factory/data/budget_adjusted_v2_{name}.json")
    log(f"../LLaMA-Factory/data/budget_adjusted_v2_{name}.json")

    print(valid_turn)
    print(valid_train_token)
    print("--"*20)

 33%|███▎      | 1/3 [00:05<00:10,  5.05s/it]

2313
900236
----------------------------------------


 67%|██████▋   | 2/3 [00:09<00:04,  4.59s/it]

2301
900089
----------------------------------------


100%|██████████| 3/3 [00:13<00:00,  4.48s/it]

2291
900101
----------------------------------------





In [40]:
names = (
    # "budget_adjusted_cleaned_5sets_knn_multiply_gamma_1",
    # "budget_adjusted_cleaned_5sets_knn_multiply_gamma_2",
    # "budget_adjusted_cleaned_5sets_kcenter_multiply_gamma_1",
    # "budget_adjusted_cleaned_5sets_kcenter_multiply_gamma_2",
    # "budget_adjusted_cleaned_no_complexity_deita_6k",
    # "budget_adjusted_1205_cleaned_euclidean_multiply_alpha_0.3_lambda_0.9_gamma_1.0_6k_wizardlm",
    # "budget_adjusted_1205_cleaned_euclidean_multiply_alpha_0.3_lambda_0.9_gamma_2.0_6k_wizardlm",
    # "budget_adjusted_1209_cleaned_euclidean_addition_alpha_0.3_lambda_0.9_gamma_1.0_6k_wizardlm",
    # "budget_adjusted_1209_cleaned_euclidean_addition_alpha_0.3_lambda_0.9_gamma_2.0_6k_wizardlm",

    "1223_cleaned_5sets_euclidean_nonlinear_v4_0.95_alpha_0.3_lambda_0.9_gamma_2.0_6k_self_instruct",
    "1223_cleaned_5sets_euclidean_nonlinear_v4_0.95_alpha_0.3_lambda_0.9_gamma_2.0_6k_alpaca",
    "1223_cleaned_5sets_euclidean_nonlinear_v4_0.95_alpha_0.3_lambda_0.9_gamma_2.0_6k_dolly",
    "1223_cleaned_5sets_euclidean_nonlinear_v4_0.95_alpha_0.3_lambda_0.9_gamma_2.0_6k_sharegpt",
)

stats = {
    'name': [],
    'valid_train_token': [],
    'valid_conversations_turns': []
}

for name in tqdm(names):
    path=f'../LLaMA-Factory/data/{name}.json'
    data = jload(path)

    total_train_token = 0
    total_valid_turn = 0
    for conv in data:
        valid_train_token, valid_turn = get_valid_token_len_and_conv_turn(conv['conversations'])
        total_train_token += valid_train_token
        total_valid_turn += valid_turn
    
    stats['name'].append(name)
    stats['valid_train_token'].append(total_train_token)
    stats['valid_conversations_turns'].append(total_valid_turn)

    print(total_train_token)
    print(total_valid_turn)
    print("--"*20)

df = pd.DataFrame(stats)
df.to_csv('../analysis/train_data_stats.csv', index=None)

 25%|██▌       | 1/4 [00:02<00:07,  2.34s/it]

445347
6000
----------------------------------------


 50%|█████     | 2/4 [00:08<00:09,  4.54s/it]

1733819
6000
----------------------------------------


 75%|███████▌  | 3/4 [00:14<00:05,  5.27s/it]

1755852
5999
----------------------------------------


100%|██████████| 4/4 [00:28<00:00,  7.25s/it]

3670416
10201
----------------------------------------





In [24]:
names = (
    # "budget_adjusted_v2_cleaned_5sets_knn_multiply_gamma_1",
    # "budget_adjusted_v2_cleaned_5sets_knn_multiply_gamma_2",
    # "budget_adjusted_v2_cleaned_5sets_kcenter_multiply_gamma_1",
    # "budget_adjusted_v2_cleaned_5sets_kcenter_multiply_gamma_2",
    # "budget_adjusted_v2_cleaned_no_complexity_deita_6k",
    # "budget_adjusted_v2_1205_cleaned_euclidean_multiply_alpha_0.3_lambda_0.9_gamma_1.0_6k_wizardlm",
    # "budget_adjusted_v2_1205_cleaned_euclidean_multiply_alpha_0.3_lambda_0.9_gamma_2.0_6k_wizardlm",
    # "budget_adjusted_v2_1209_cleaned_euclidean_addition_alpha_0.3_lambda_0.9_gamma_1.0_6k_wizardlm",
    # "budget_adjusted_v2_1209_cleaned_euclidean_addition_alpha_0.3_lambda_0.9_gamma_2.0_6k_wizardlm",
)

stats = {
    'name': [],
    'valid_train_token': [],
    'valid_conversations_turns': []
}

for name in tqdm(names):
    path=f'../LLaMA-Factory/data/{name}.json'
    data = jload(path)

    total_train_token = 0
    total_valid_turn = 0
    for conv in data:
        valid_train_token, valid_turn = get_valid_token_len_and_conv_turn(conv['conversations'])
        total_train_token += valid_train_token
        total_valid_turn += valid_turn
    
    stats['name'].append(name)
    stats['valid_train_token'].append(total_train_token)
    stats['valid_conversations_turns'].append(total_valid_turn)

    print(total_train_token)
    print(total_valid_turn)
    print("--"*20)

df = pd.DataFrame(stats)
df.to_csv('../analysis/train_data_stats.csv', index=None)

 11%|█         | 1/9 [00:06<00:52,  6.51s/it]

2284267
7005
----------------------------------------


 22%|██▏       | 2/9 [00:13<00:49,  7.05s/it]

2519609
7000
----------------------------------------


 33%|███▎      | 3/9 [00:21<00:44,  7.44s/it]

2291435
6125
----------------------------------------


 44%|████▍     | 4/9 [00:31<00:41,  8.21s/it]

2116656
7000
----------------------------------------


 56%|█████▌    | 5/9 [00:39<00:32,  8.17s/it]

2568780
7000
----------------------------------------


 67%|██████▋   | 6/9 [00:47<00:24,  8.21s/it]

2479524
7000
----------------------------------------


 78%|███████▊  | 7/9 [00:56<00:16,  8.35s/it]

2707970
7000
----------------------------------------


 89%|████████▉ | 8/9 [01:04<00:08,  8.37s/it]

2587597
7000
----------------------------------------


100%|██████████| 9/9 [01:13<00:00,  8.15s/it]

2786629
7000
----------------------------------------





In [4]:
names = (
    "cleaned_5sets_knn_multiply_gamma_1",
    "cleaned_5sets_knn_multiply_gamma_2",
    "cleaned_5sets_kcenter_multiply_gamma_1",
    "cleaned_5sets_kcenter_multiply_gamma_2",
    "cleaned_no_complexity_deita_6k",
    "1205_cleaned_euclidean_multiply_alpha_0.3_lambda_0.9_gamma_1.0_6k_wizardlm",
    "1205_cleaned_euclidean_multiply_alpha_0.3_lambda_0.9_gamma_2.0_6k_wizardlm",
    "1209_cleaned_euclidean_addition_alpha_0.3_lambda_0.9_gamma_1.0_6k_wizardlm",
    "1209_cleaned_euclidean_addition_alpha_0.3_lambda_0.9_gamma_2.0_6k_wizardlm",
)

stats = {
    'name': [],
    'valid_train_token': [],
    'valid_conversations_turns': []
}

for name in tqdm(names):
    path=f'../LLaMA-Factory/data/{name}.json'
    data = jload(path)

    total_train_token = 0
    total_valid_turn = 0
    for conv in data:
        valid_train_token, valid_turn = get_valid_token_len_and_conv_turn(conv['conversations'])
        total_train_token += valid_train_token
        total_valid_turn += valid_turn
    
    stats['name'].append(name)
    stats['valid_train_token'].append(total_train_token)
    stats['valid_conversations_turns'].append(total_valid_turn)

    print(total_train_token)
    print(total_valid_turn)
    print("--"*20)

df = pd.DataFrame(stats)
df.to_csv('../analysis/train_data_stats.csv', index=None)

  0%|          | 0/9 [00:00<?, ?it/s]

 11%|█         | 1/9 [00:12<01:41, 12.64s/it]

4689489
14489
----------------------------------------


 22%|██▏       | 2/9 [00:24<01:26, 12.31s/it]

3975417
11161
----------------------------------------


 33%|███▎      | 3/9 [00:32<01:00, 10.06s/it]

2343499
6162
----------------------------------------


 44%|████▍     | 4/9 [00:43<00:52, 10.51s/it]

2412402
7720
----------------------------------------


 56%|█████▌    | 5/9 [00:52<00:40, 10.20s/it]

3112491
8561
----------------------------------------


 67%|██████▋   | 6/9 [01:03<00:30, 10.20s/it]

3147849
8751
----------------------------------------


 78%|███████▊  | 7/9 [01:12<00:19,  9.81s/it]

3106774
7988
----------------------------------------


 89%|████████▉ | 8/9 [01:57<00:21, 21.06s/it]

3114606
8299
----------------------------------------


100%|██████████| 9/9 [02:08<00:00, 14.28s/it]

3094305
7751
----------------------------------------





In [6]:
names = (
    # "cleaned_5sets_knn_multiply_gamma_1",
    # "cleaned_5sets_knn_multiply_gamma_2",
    # "cleaned_5sets_kcenter_multiply_gamma_1",
    # "cleaned_5sets_kcenter_multiply_gamma_2",
    # "cleaned_no_complexity_deita_6k",
    # "1205_cleaned_euclidean_multiply_alpha_0.3_lambda_0.9_gamma_1.0_6k_wizardlm",
    # "1205_cleaned_euclidean_multiply_alpha_0.3_lambda_0.9_gamma_2.0_6k_wizardlm",
    # "1209_cleaned_euclidean_addition_alpha_0.3_lambda_0.9_gamma_1.0_6k_wizardlm",
    # "1209_cleaned_euclidean_addition_alpha_0.3_lambda_0.9_gamma_2.0_6k_wizardlm",
    "cleaned_5sets_kcenter_addition_gamma_1",
    "cleaned_5sets_kcenter_addition_gamma_2",
    "cleaned_5sets_knn_addition_gamma_1",
    "cleaned_5sets_knn_addition_gamma_2",
)

stats = {
    'name': [],
    'valid_train_token': [],
    'valid_conversations_turns': []
}

for name in tqdm(names):
    path=f'../LLaMA-Factory/data/{name}.json'
    data = jload(path)

    valid_conversations, valid_turn, valid_train_token = train_token_budget_adjust(data)
    jdump(valid_conversations, f"../LLaMA-Factory/data/budget_adjusted_v2_{name}.json")
    log(f"../LLaMA-Factory/data/budget_adjusted_v2_{name}.json")

    print(valid_turn)
    print(valid_train_token)
    print("--"*20)

 25%|██▌       | 1/4 [00:08<00:24,  8.14s/it]

6109
2286630
----------------------------------------


 50%|█████     | 2/4 [00:18<00:18,  9.17s/it]

7000
1793077
----------------------------------------


 75%|███████▌  | 3/4 [00:26<00:08,  8.78s/it]

7000
2296054
----------------------------------------


100%|██████████| 4/4 [00:35<00:00,  8.75s/it]

7001
2553184
----------------------------------------





In [18]:
def train_token_budget_adjust_pth(path, max_length=2048, token_budget=3000000):
    """Adjust the train data to fit token budget
    params
        conversations: raw conversation data
        max_length: max_length set in trainer
        token_budget: the training token budget
    """
    data = torch.load(path)['data']

    valid_conv_quality = []
    total_valid_turns = 0
    valid_train_token = 0
    for i in range(len(data)):
        conv = data[i]
        total_token_len = 0
        tmp_valid_turns = []
        for turn in conv['conversations']:
            cur_sent_token = len(tokenizer(turn['value'])['input_ids'])
            if total_token_len + cur_sent_token > max_length:
                break
            total_token_len += cur_sent_token
            if turn['from'] not in ['human', 'user', 'system']:
                valid_train_token += cur_sent_token
                if valid_train_token + cur_sent_token > token_budget:
                    break
                total_valid_turns += 1
            tmp_valid_turns.append(turn)

        while len(tmp_valid_turns) > 2 and tmp_valid_turns[-1]['from'] in ['human', 'user', 'system']:
            tmp_valid_turns = tmp_valid_turns[:-1]
        valid_conv_quality.append(data[i]['quality'])

        if valid_train_token >= token_budget or total_valid_turns>=7000:
            break
    avg_quality = sum(valid_conv_quality) / len(valid_conv_quality)
    return avg_quality, total_valid_turns, valid_train_token


In [19]:
paths = (
    "../baselines/cleaned_5sets_knn_addition_gamma_1.pth",
    "../baselines/cleaned_5sets_knn_addition_gamma_2.pth",
    "../baselines/cleaned_5sets_knn_multiply_gamma_1.pth",
    "../baselines/cleaned_5sets_knn_multiply_gamma_2.pth",
    "../pool_evolve/ap_outputs/1205_cleaned_euclidean_multiply_alpha_0.3_lambda_0.9_gamma_1.0_6k/WizardLM_alpaca.pth",
    "../pool_evolve/ap_outputs/1205_cleaned_euclidean_multiply_alpha_0.3_lambda_0.9_gamma_2.0_6k/WizardLM_alpaca.pth",
    "../pool_evolve/ap_outputs/1209_cleaned_euclidean_addition_alpha_0.3_lambda_0.9_gamma_1.0_6k/WizardLM_alpaca.pth",
    "../pool_evolve/ap_outputs/1209_cleaned_euclidean_addition_alpha_0.3_lambda_0.9_gamma_2.0_6k/WizardLM_alpaca.pth",
)

for path in tqdm(paths):
    avg_quality, total_valid_turns, valid_train_token = train_token_budget_adjust_pth(path)
    print(avg_quality)
    print(total_valid_turns)
    print(valid_train_token)
    print("--"*20)


 12%|█▎        | 1/8 [00:07<00:52,  7.57s/it]

4.860175119071972
7000
2296054
----------------------------------------


 25%|██▌       | 2/8 [00:16<00:51,  8.59s/it]

5.160868338225023
7000
2553184
----------------------------------------


 38%|███▊      | 3/8 [00:24<00:40,  8.10s/it]

4.847259938568849
7000
2284267
----------------------------------------


 50%|█████     | 4/8 [00:33<00:33,  8.50s/it]

5.136837000070238
7000
2519609
----------------------------------------


 62%|██████▎   | 5/8 [00:44<00:28,  9.34s/it]

5.1235864869229735
7000
2479524
----------------------------------------


 75%|███████▌  | 6/8 [00:53<00:18,  9.37s/it]

5.307052951961695
7000
2707970
----------------------------------------


 88%|████████▊ | 7/8 [01:03<00:09,  9.51s/it]

5.2235377209058065
7000
2587597
----------------------------------------


100%|██████████| 8/8 [01:14<00:00,  9.31s/it]

5.34572109318431
7000
2786629
----------------------------------------





In [22]:
def train_token_budget_adjust_pth(path, max_length=2048, token_budget=3000000):
    """Adjust the train data to fit token budget
    params
        conversations: raw conversation data
        max_length: max_length set in trainer
        token_budget: the training token budget
    """
    data = torch.load(path)['data']

    valid_conv_quality = []
    total_valid_turns = 0
    valid_train_token = 0
    for i in range(len(data)):
        conv = data[i]
        total_token_len = 0
        tmp_valid_turns = []
        for turn in conv['conversations']:
            cur_sent_token = len(tokenizer(turn['value'])['input_ids'])
            if total_token_len + cur_sent_token > max_length:
                break
            total_token_len += cur_sent_token
            if turn['from'] not in ['human', 'user', 'system']:
                valid_train_token += cur_sent_token
                if valid_train_token + cur_sent_token > token_budget:
                    break
                total_valid_turns += 1
            tmp_valid_turns.append(turn)

        while len(tmp_valid_turns) > 2 and tmp_valid_turns[-1]['from'] in ['human', 'user', 'system']:
            tmp_valid_turns = tmp_valid_turns[:-1]
        valid_conv_quality.append(data[i]['quality'])

        if valid_train_token >= token_budget:
            break
    avg_quality = sum(valid_conv_quality) / len(valid_conv_quality)
    return avg_quality, total_valid_turns, valid_train_token


In [23]:
paths = (
    "../baselines/cleaned_5sets_knn_addition_gamma_1.pth",
    "../baselines/cleaned_5sets_knn_addition_gamma_2.pth",
    "../baselines/cleaned_5sets_knn_multiply_gamma_1.pth",
    "../baselines/cleaned_5sets_knn_multiply_gamma_2.pth",
    "../pool_evolve/ap_outputs/1205_cleaned_euclidean_multiply_alpha_0.3_lambda_0.9_gamma_1.0_6k/WizardLM_alpaca.pth",
    "../pool_evolve/ap_outputs/1205_cleaned_euclidean_multiply_alpha_0.3_lambda_0.9_gamma_2.0_6k/WizardLM_alpaca.pth",
    "../pool_evolve/ap_outputs/1209_cleaned_euclidean_addition_alpha_0.3_lambda_0.9_gamma_1.0_6k/WizardLM_alpaca.pth",
    "../pool_evolve/ap_outputs/1209_cleaned_euclidean_addition_alpha_0.3_lambda_0.9_gamma_2.0_6k/WizardLM_alpaca.pth",
)

for path in tqdm(paths):
    avg_quality, total_valid_turns, valid_train_token = train_token_budget_adjust_pth(path)
    print(avg_quality)
    print(total_valid_turns)
    print(valid_train_token)
    print("--"*20)


 12%|█▎        | 1/8 [00:09<01:06,  9.53s/it]

4.854001088280239
9211
3000009
----------------------------------------


 25%|██▌       | 2/8 [00:18<00:56,  9.42s/it]

5.148246010369022
8328
3000258
----------------------------------------


 38%|███▊      | 3/8 [00:28<00:46,  9.39s/it]

4.839619930337753
9236
3000344
----------------------------------------


 50%|█████     | 4/8 [00:37<00:37,  9.43s/it]

5.121643718742637
8406
3000493
----------------------------------------


 62%|██████▎   | 5/8 [00:48<00:29,  9.86s/it]

5.131396633689667
8375
3000475
----------------------------------------


 75%|███████▌  | 6/8 [00:58<00:19,  9.91s/it]

5.3044659973302695
7743
3000153
----------------------------------------


 88%|████████▊ | 7/8 [01:09<00:10, 10.25s/it]

5.225794443713172
8023
3000046
----------------------------------------


100%|██████████| 8/8 [01:19<00:00,  9.88s/it]

5.341632059033741
7533
3000563
----------------------------------------



