In [23]:
from datasets import load_dataset
import pandas as pd

dataset = load_dataset("SoftAge-AI/prompt-eng_dataset")

df_train = pd.DataFrame(dataset["train"])

df_validation = pd.DataFrame(dataset["validation"]) if "validation" in dataset else None
df_test = pd.DataFrame(dataset["test"]) if "test" in dataset else None

categories_to_save = [
    "Writing",
    "Problem solving",
    "Text Summarization",
    "Logical Reasoning",
    "Language Translation",
    "Roleplaying",
]

filtered_df = df_train[df_train["Category"].isin(categories_to_save)]
Category_to_category_id = {
    "Writing": "CT",
    "Problem solving": "CR",
    "Text Summarization": "SUMM",
    "Logical Reasoning": "CR",
    "Language Translation": "MT",
    "Roleplaying": "IF",
}

filtered_df["category"] = filtered_df["Category"].map(Category_to_category_id)
filtered_df.category.value_counts()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_df["category"] = filtered_df["Category"].map(Category_to_category_id)


category
CR      131
CT      106
SUMM     82
MT       32
IF       30
Name: count, dtype: int64

In [26]:
filtered_df["initial_prompt"] = filtered_df["Prompt"]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_df["initial_prompt"]   = filtered_df["Prompt"]


In [36]:
import random
from nltk.corpus import wordnet


def synonym_replacement(sentence, n=1):
    words = sentence.split()
    new_words = words.copy()
    random_word_list = list(set([word for word in words if wordnet.synsets(word)]))
    random.shuffle(random_word_list)

    for random_word in random_word_list[:n]:
        synonyms = wordnet.synsets(random_word)
        if synonyms:
            synonym = synonyms[0].lemmas()[0].name()
            new_words = [synonym if word == random_word else word for word in new_words]
    return " ".join(new_words)


def random_insertion(sentence, n=1):
    words = sentence.split()
    for _ in range(n):
        random_word = random.choice(words)
        synonyms = wordnet.synsets(random_word)
        if synonyms:
            synonym = synonyms[0].lemmas()[0].name()
            insert_position = random.randint(0, len(words))
            words.insert(insert_position, synonym)
    return " ".join(words)


def random_swap(sentence, n=1):
    words = sentence.split()
    if len(words) < 2:
        return sentence
    for _ in range(n):
        idx1, idx2 = random.sample(range(len(words)), 2)
        words[idx1], words[idx2] = words[idx2], words[idx1]
    return " ".join(words)


def random_deletion(sentence, p=0.2):
    words = sentence.split()
    if len(words) == 1:  # 如果只有一个单词，直接返回
        return sentence
    new_words = [word for word in words if random.uniform(0, 1) > p]
    return " ".join(new_words) if new_words else random.choice(words)


def spelling_error_injection(sentence, n=1):
    def typo(word):
        if len(word) <= 1:
            return word
        idx = random.randint(0, len(word) - 2)
        return word[:idx] + word[idx + 1] + word[idx] + word[idx + 2 :]

    words = sentence.split()
    for _ in range(n):
        idx = random.randint(0, len(words) - 1)
        words[idx] = typo(words[idx])
    return " ".join(words)


keyboard_map = {
    "a": ["q", "w", "s", "z"],
    "b": ["v", "g", "h", "n"],
}


def keyboard_error_injection(sentence, n=1):
    def keyboard_typo(word):
        chars = list(word)
        idx = random.randint(0, len(chars) - 1)
        if chars[idx] in keyboard_map:
            chars[idx] = random.choice(keyboard_map[chars[idx]])
        return "".join(chars)

    words = sentence.split()
    for _ in range(n):
        idx = random.randint(0, len(words) - 1)
        words[idx] = keyboard_typo(words[idx])
    return " ".join(words)


from collections import Counter

# 假设已有训练语料库的unigram分布
unigram_freq = Counter({"the": 5000, "a": 3000, "cat": 1000})


def unigram_noising(sentence, n=1):
    words = sentence.split()
    for _ in range(n):
        idx = random.randint(0, len(words) - 1)
        replacement_word = random.choices(
            list(unigram_freq.keys()), weights=list(unigram_freq.values()), k=1
        )[0]
        words[idx] = replacement_word
    return " ".join(words)

In [37]:
import pandas as pd
import nltk

nltk.download("wordnet")


# 定义增强函数集合|
def augment_text(text, times=5):
    augmented_texts = []

    for _ in range(times):
        # EDA 操作
        augmented_texts.append(synonym_replacement(text))
        augmented_texts.append(random_insertion(text))
        augmented_texts.append(random_swap(text, n=5))
        augmented_texts.append(random_deletion(text))

        # 噪声注入操作
        augmented_texts.append(spelling_error_injection(text))
        augmented_texts.append(keyboard_error_injection(text))
        augmented_texts.append(unigram_noising(text))

    return augmented_texts


augmented_data = []
for index, row in filtered_df.iterrows():
    original_prompt = row["initial_prompt"]
    augmented_prompts = augment_text(original_prompt, 2)
    for augmented_prompt in augmented_prompts:
        augmented_data.append(
            {
                "initial_prompt": original_prompt,
                "augmented_prompt": augmented_prompt,
                "ability": row["category"],
            }
        )

augmented_df = pd.DataFrame(augmented_data)

[nltk_data] Downloading package wordnet to /home/snt/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [39]:
df = pd.read_csv(
    "/home/snt/projects_lujun/temperature_eval_github/temperature_eval/data/Bert/all_data_for_bert_training_augmented.csv"
)

In [42]:
concated_df = pd.concat([df, augmented_df], axis=0)

In [None]:
concated_df.to_csv(
    "/home/snt/projects_lujun/temperature_eval_github/temperature_eval/data/Bert/all_data_for_bert_training_augmented_opensource.csv",
    index=False,
)