In [None]:
import sys
import os

import numpy as np
from sympy.codegen.ast import continue_

sys.path.append(os.getcwd())

from utils.global_const import *

# Set the cache directory for huggingface hub. [Important!] It should be before the import of transformers
os.environ["HF_HOME"] = HF_HOME

from datasets import load_dataset
import re
from utils import save_json
from interaction.player import get_invalid_words, get_all_punctuations
from tqdm import tqdm

dataset_name = "rajpurkar/squad"
train_set_raw = load_dataset(dataset_name, split='train') 
val_set_raw = load_dataset(dataset_name, split='validation')
# ('id', 'title', 'context', 'question', 'answers')

all_punctuations = get_all_punctuations()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def deduplicate_list(lst): # deduplicate while keeping the order
    return list(dict.fromkeys(lst))

train_context = deduplicate_list(train_set_raw['context'])
val_context = deduplicate_list(val_set_raw['context'])
print(len(train_context))
print(len(val_context))

# randomly shuffle the context list
seed = 0
np.random.RandomState(seed).shuffle(train_context)
np.random.RandomState(seed).shuffle(val_context)

18891
2067


**20250126 custom-squad-v1** -> **v2 (bug fix)**

1. v1用的是set进行deduplication，v2用的是dict，set有可能会打乱顺序且无法复现。

2. 另外v2还用了shuffle，打乱了context的顺序。

3. v2要求predicted word第一个字符不能是标点符号。

4. 把punctuation的范围从ascii_punctuation扩展到unicode_punctuation。

5. v2要求predicted word前一个词不能是invalid word，不然这个invalid word会一直作为background存在，有可能导致 v(空) 过大 以及 v(N) - v(空) 过小。 一个典型例子是 at least 这个词组，在at这个词存在时，least的预测概率本身就会大大增加。

In [3]:
def strip_punctuation(word, punc_list=all_punctuations):
    """
    This function is used to strip the punctuation from the beginning and end of a word.
    """
    return re.sub('^[{0}]+|[{0}]+$'.format(punc_list), '', word)

def detect_punctuation(word, punc_list=all_punctuations):
    """
    This function is used to detect whether there are at least one punctuation at the beginning or end of a word.
    """
    if word[0] in punc_list or word[-1] in punc_list:
        return True
    return False


word_test_ = "yes,"
# word_test_ = "./;'[]//"
print(strip_punctuation(word_test_))
print(detect_punctuation(word_test_, punc_list=['.', ';', ',']))

yes,
True


In [4]:


NUM_PLAYERS_MAX = 12
POS_NO_PUNC = 3
SEP_PUNCS = ['.', ';', ',']

def check_precede_punc(context_words, curr_index):
    # try:
    for j in range(1, POS_NO_PUNC + 1):
        if detect_punctuation(context_words[curr_index - j], punc_list=SEP_PUNCS):
            # print("precede punc detected")
            return True
    return False
    # except IndexError:
    #     print("context_words: ", context_words)
    #     return False

def get_sentences_v1_20250126(context_list, invalid_words):
    sentence_list = []
    next_word_list = []
    player_words_list = []
    for idx, context in tqdm(enumerate(context_list)):
        # if idx not in [97,98,99,100,101]:
        #     continue
        # print(f"idx: {idx}")

        # 换行符替换为空格
        context = context.replace('\n', ' ')
        # 多个空格替换为一个空格
        context = re.sub(r'\s+', ' ', context)
        context_words = context.split(' ')

        num_valid_words = 0
        player_words = []
        # print("context_words:", context_words)
        for i in range(len(context_words)): # start from 0

            word = context_words[i]
            word_strip_punc = strip_punctuation(word.lower())
            if word_strip_punc not in invalid_words and word_strip_punc != '':
                num_valid_words += 1
                player_words.append(word)

            if num_valid_words == NUM_PLAYERS_MAX + 1 or i == len(context_words) - 1:
                # print("no valid sentence when num_valid_words == NUM_PLAYERS_MAX + 1, or when i == len(context_words) - 1")
                break

            if num_valid_words == NUM_PLAYERS_MAX:
                next_word = context_words[i + 1]
                # print(f"next_word: {next_word}")
                # print(f"next_word[0]: {next_word[0]}")

                next_word_strip_punc = strip_punctuation(next_word.lower())
                if check_precede_punc(context_words, i + 1) or next_word_strip_punc in invalid_words or next_word_strip_punc == '' or next_word[0] in all_punctuations: # v1这里是i, 但是应该是i+1， v2已经修复
                    # print("precede punc detected or next word is invalid")
                    # continue
                    break # v2从continue改为break，要求predicted word前面一个词不能是invalid word
                else:
                    # print("i=", i)
                    # print("context_words[:i]", context_words[:i])
                    # print("next_word: ", next_word)
                    sentence = ' '.join(context_words[:i + 1]) # include the current word
                    sentence_list.append(sentence)
                    next_word_list.append(next_word)
                    player_words_list.append(player_words)
                    break

    return sentence_list, next_word_list, player_words_list



In [5]:
invalid_words = get_invalid_words()
sentence_list_train, next_word_list_train, player_words_list_train = get_sentences_v1_20250126(train_context, invalid_words)
sentence_list_val, next_word_list_val, player_words_list_val = get_sentences_v1_20250126(val_context, invalid_words)

for idx in range(10):
    print("sentence:", sentence_list_train[idx])
    print("next_word:", next_word_list_train[idx])
    print("player_words:", player_words_list_train[idx])
print("=====================================")
for idx in range(10):
    print("sentence:", sentence_list_val[idx])
    print("next_word:", next_word_list_val[idx])
    print("player_words:", player_words_list_val[idx])

18891it [00:49, 381.39it/s]
2067it [00:05, 386.64it/s]

sentence: Humanistic psychology is a psychological perspective which rose to prominence in the mid-20th century in response to Sigmund Freud's psychoanalytic
next_word: theory
player_words: ['Humanistic', 'psychology', 'psychological', 'perspective', 'rose', 'prominence', 'mid-20th', 'century', 'response', 'Sigmund', "Freud's", 'psychoanalytic']
sentence: Federalism has a long tradition in German history. The Holy Roman Empire comprised many petty states
next_word: numbering
player_words: ['Federalism', 'long', 'tradition', 'German', 'history.', 'Holy', 'Roman', 'Empire', 'comprised', 'many', 'petty', 'states']
sentence: 122nd Street is mentioned in the movie Taxi Driver by main character Travis Bickle as the location where a fellow
next_word: cab
player_words: ['122nd', 'Street', 'mentioned', 'movie', 'Taxi', 'Driver', 'main', 'character', 'Travis', 'Bickle', 'location', 'fellow']
sentence: In response to the pressure on Hot AC, a new kind of AC format cropped up among American radio





In [6]:
# save_dir_name_train = "custom-squad-v1-20250126-train"
# save_dir_name_val = "custom-squad-v1-20250126-val"
save_dir_name_train = "custom-squad-v2-20250202-train"
save_dir_name_val = "custom-squad-v2-20250202-val"

player_dir_name_list = [
    "players-pythia",
    "players-qwen",
]

os.makedirs(f"../datasets/{save_dir_name_train}", exist_ok=True)
os.makedirs(f"../datasets/{save_dir_name_val}", exist_ok=True)

# write sentences into a txt file
with open(f"../datasets/{save_dir_name_train}/sentences.txt", "w") as f:
    for sentence in sentence_list_train:
        f.write(sentence + "\n")
with open(f"../datasets/{save_dir_name_val}/sentences.txt", "w") as f:
    for sentence in sentence_list_val:
        f.write(sentence + "\n")

# write next words into a txt file
with open(f"../datasets/{save_dir_name_train}/next_words.txt", "w") as f:
    for next_word in next_word_list_train:
        f.write(next_word + "\n")
with open(f"../datasets/{save_dir_name_val}/next_words.txt", "w") as f:
    for next_word in next_word_list_val:
        f.write(next_word + "\n")
# save_json(sentence_list_train, save_dir=f"../datasets/{save_dir_name_train}", file_name="sentences.txt")
# save_json(sentence_list_val, save_dir=f"../datasets/{save_dir_name_val}", file_name="sentences.txt")
# save_json(next_word_list_train, save_dir=f"../datasets/{save_dir_name_train}", file_name="next_words.txt")
# save_json(next_word_list_val, save_dir=f"../datasets/{save_dir_name_val}", file_name="next_words.txt")

player_words_list_train_with_idx = {f"{idx}": player_words for idx, player_words in enumerate(player_words_list_train)}
player_words_list_val_with_idx = {f"{idx}": player_words for idx, player_words in enumerate(player_words_list_val)}

for player_dir_name in player_dir_name_list:
    save_json(player_words_list_train_with_idx, save_dir=f"../players/{save_dir_name_train}/{player_dir_name}", file_name='player_words.json')
    save_json(player_words_list_val_with_idx, save_dir=f"../players/{save_dir_name_val}/{player_dir_name}", file_name='player_words.json')

In [10]:
'−' == '-'

False