In [None]:
import random
from collections import deque

import nltk
import spacy
from nltk import word_tokenize, pos_tag, ne_chunk
from nltk.corpus import wordnet
from nltk.tree import Tree

In [None]:
# NLTK & spaCy initialisation
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('wordnet')
nltk.download('maxent_ne_chunker')
nltk.download('words')

# NOTE(review): `nlp` is rebound to the German model ("de_core_news_sm") in the
# cell that defines the augmentation functions, before any of them is called,
# so this English pipeline appears unused — consider dropping it.
nlp = spacy.load("en_core_web_sm")

# Seed the stdlib RNG used by random.sample / random.choice / random.randint in
# the augmentation pipeline so Restart & Run All reproduces the same outputs.
SEED = 42
random.seed(SEED)

# Replacement candidates for entity_replacement, keyed by NLTK ne_chunk label.
ENTITY_POOL = {
    "PERSON": ["Alice", "Bob", "Charlie", "David", "Emma"],
    "ORGANIZATION": ["OpenAI", "Meta", "Stanford", "NASA"],
    "GPE": ["Germany", "Japan", "Kenya", "Brazil", "India"]
}

In [None]:
# Additional tokenizer/tagger resource variants. Presumably required by newer
# NLTK releases, which look these up instead of the legacy resources
# downloaded in the setup cell — confirm against the installed NLTK version.
for resource_id in ("punkt_tab", "averaged_perceptron_tagger_eng"):
    nltk.download(resource_id)

In [None]:
# Named-entity chunker resource used by ne_chunk in entity_replacement.
# Presumably the variant newer NLTK releases load — confirm against the
# installed NLTK version.
nltk.download('maxent_ne_chunker_tab')

In [None]:
def get_antonym(word):
    """Look up an antonym for ``word`` in WordNet.

    Walks the adjective synsets of ``word`` in WordNet order and, within each,
    its lemmas in order; the first lemma that has any antonym decides the
    result.

    Returns:
        The name of that lemma's first antonym, or ``None`` when no adjective
        lemma of ``word`` has an antonym.
    """
    first_antonyms = (
        lemma.antonyms()[0].name()
        for synset in wordnet.synsets(word, pos=wordnet.ADJ)
        for lemma in synset.lemmas()
        if lemma.antonyms()
    )
    return next(first_antonyms, None)

def antonym_replacement(sentence):
    """Replace each adjective in ``sentence`` by its WordNet antonym, if any.

    Tokens tagged ``JJ*`` by ``pos_tag`` are looked up (lower-cased) via
    ``get_antonym``; tokens without an antonym are kept unchanged.

    NOTE(review): ``pos_tag`` and WordNet are English resources, while this
    notebook applies the function to German targets — only tokens that happen
    to be English adjectives will be replaced. Confirm this is intended.

    Returns:
        The tokens re-joined with single spaces.
    """
    output = []
    for word, tag in pos_tag(word_tokenize(sentence)):
        # Look the antonym up once (the original evaluated get_antonym twice).
        antonym = get_antonym(word.lower()) if tag.startswith("JJ") else None
        if not antonym:
            output.append(word)
            continue
        # WordNet lemma names join multiword expressions with underscores.
        antonym = antonym.replace("_", " ")
        # Keep sentence-initial / proper capitalization of the replaced word.
        if word[:1].isupper():
            antonym = antonym[:1].upper() + antonym[1:]
        output.append(antonym)
    return " ".join(output)

# Rebinds the global `nlp` to spaCy's small German pipeline, replacing the
# English model loaded in the setup cell. toggle_negation_de_spacy and
# number_replacement below call `nlp`, so they run on this German model.
nlp = spacy.load("de_core_news_sm")

def toggle_negation_de_spacy(sentence):
    """Flip the polarity of a German sentence by removing or inserting "nicht".

    - If any token is "nicht" (case-insensitive), every such token is removed.
    - Otherwise "nicht" is inserted right after the first token whose spaCy
      ``pos_`` is ``VERB`` or ``AUX``.
    - If there is no such token, "nicht" is added at the end, but in front
      of any trailing punctuation tokens.

    Returns:
        The resulting tokens joined with single spaces.
    """
    doc = nlp(sentence)
    tokens = [token.text for token in doc]

    # Detection must be case-insensitive like the removal below; otherwise a
    # sentence-initial "Nicht" was missed and a second "nicht" was inserted.
    if any(t.lower() == "nicht" for t in tokens):
        return " ".join(t for t in tokens if t.lower() != "nicht")

    # AUX covers copulas/auxiliaries ("ist", "hat"), which are tagged AUX and
    # previously sent sentences like "Das ist gut ." to the fallback.
    for i, token in enumerate(doc):
        if token.pos_ in ("VERB", "AUX"):
            return " ".join(tokens[:i + 1] + ["nicht"] + tokens[i + 1:])

    # Fallback: keep "nicht" inside the sentence, before final punctuation.
    end = len(tokens)
    while end > 0 and doc[end - 1].is_punct:
        end -= 1
    return " ".join(tokens[:end] + ["nicht"] + tokens[end:])

def strengthen_modality_de(sentence, mapping=None):
    """Replace weak German modal verbs by stronger ones.

    Args:
        sentence: German input text; tokenized with NLTK's German tokenizer.
        mapping: Optional ``{lower-cased modal: replacement}`` table. Defaults
            to the built-in table below.

    Tokens are matched case-insensitively; when the original token starts with
    an upper-case letter (e.g. sentence-initial "Kann"), the replacement is
    capitalized the same way.

    NOTE(review): "mag" -> "wird" changes meaning and grammar ("Ich mag
    Kaffee" -> "Ich wird Kaffee"); confirm this is intended for negatives.

    Returns:
        The tokens joined with single spaces.
    """
    if mapping is None:
        mapping = {
            "kann": "muss",
            "könnte": "muss",
            "dürfte": "wird",
            "sollte": "wird",
            "mag": "wird"
        }

    out = []
    for word in word_tokenize(sentence, language="german"):
        replacement = mapping.get(word.lower())
        if replacement is None:
            out.append(word)
        elif word[:1].isupper():
            out.append(replacement[:1].upper() + replacement[1:])
        else:
            out.append(replacement)
    return " ".join(out)


def entity_replacement(sentence, entity_pool=None):
    """Swap named entities found by NLTK's ``ne_chunk`` for random substitutes.

    Args:
        sentence: Input text (tokenized with ``word_tokenize``, tagged with
            ``pos_tag``, then chunked with ``ne_chunk``).
        entity_pool: Optional ``{ne_chunk label: [candidates]}``; defaults to
            the module-level ``ENTITY_POOL``.

    A chunk whose label is in the pool is replaced by one random candidate,
    chosen from candidates that differ from the chunk's own text so the
    replacement actually changes the sentence. Chunks with no usable candidate
    (unknown label, empty list, or only the identical entity) are kept.

    NOTE(review): ne_chunk/pos_tag are English models; on German text entity
    detection is presumably unreliable — confirm.

    Returns:
        The resulting tokens joined with single spaces.
    """
    pool = ENTITY_POOL if entity_pool is None else entity_pool
    new_words = []
    for chunk in ne_chunk(pos_tag(word_tokenize(sentence))):
        if not isinstance(chunk, Tree):
            # Plain (word, tag) pair outside any entity.
            new_words.append(chunk[0])
            continue
        surface = [leaf[0] for leaf in chunk]
        original = " ".join(surface)
        candidates = [e for e in pool.get(chunk.label(), []) if e != original]
        if candidates:
            new_words.append(random.choice(candidates))
        else:
            new_words.extend(surface)
    return " ".join(new_words)

def number_replacement(sentence):
    """Replace every number-like token with a random integer in [1, 100].

    Tokenization and the ``like_num`` test come from the global spaCy ``nlp``
    pipeline; one ``random.randint`` draw is made per number-like token, in
    token order. All other tokens keep their text.

    Returns:
        The resulting tokens joined with single spaces.
    """
    pieces = []
    for tok in nlp(sentence):
        if tok.like_num:
            pieces.append(str(random.randint(1, 100)))
        else:
            pieces.append(tok.text)
    return " ".join(pieces)

def generate_variants(sentence, max_variants=3, funcs=None):
    """Generate up to ``max_variants`` perturbed versions of ``sentence``.

    Perturbations are applied breadth-first: each new variant is itself fed
    back through every perturbation function, so later variants may combine
    several edits.

    Args:
        sentence: The text to perturb.
        max_variants: Maximum number of variants returned (never exceeded).
        funcs: Optional sequence of ``str -> str`` perturbation functions.
            Defaults to the five perturbations defined in this notebook.

    A candidate is discarded when it is empty, or when it equals the
    original or an already-accepted variant after all whitespace is removed.
    The whitespace rule exists because the perturbation functions re-join
    tokens with spaces, so "a, b" can come back as "a , b" without any real
    edit.

    Returns:
        A list of variants in generation order (deterministic, unlike a set).
    """
    if funcs is None:
        # These are the names actually defined above; the original list
        # referenced `toggle_negation` / `strengthen_modality` (NameError).
        funcs = [
            antonym_replacement,
            toggle_negation_de_spacy,
            strengthen_modality_de,
            entity_replacement,
            number_replacement,
        ]

    def _key(text):
        return "".join(text.split())

    # Seeding with the original sentence keeps it (e.g. after a double
    # negation toggle) from being returned as a "variant".
    seen = {_key(sentence)}
    variants = []
    queue = deque([sentence])
    while queue and len(variants) < max_variants:
        current = queue.popleft()
        for func in funcs:
            changed = func(current)
            if not changed or _key(changed) in seen:
                continue
            seen.add(_key(changed))
            variants.append(changed)
            queue.append(changed)
            if len(variants) >= max_variants:
                break
    return variants

In [None]:
def _read_tsv_pairs(path):
    """Read ``(src, tgt)`` pairs from a UTF-8 ``src<TAB>tgt`` file.

    Only the first tab splits a line. Lines without a tab, or with an empty
    source or target after stripping, are skipped instead of crashing.
    """
    pairs = []
    with open(path, "r", encoding="utf-8") as fin:
        for line in fin:
            # Strip only the line ending before splitting: a full strip() would
            # remove a leading/trailing tab and leave a one-element split.
            parts = line.rstrip("\r\n").split("\t", 1)
            if len(parts) != 2:
                continue
            src, tgt = parts[0].strip(), parts[1].strip()
            if src and tgt:
                pairs.append((src, tgt))
    return pairs


def _write_labeled_tsv(path, rows):
    """Write ``(src, tgt, label)`` rows to ``path`` as UTF-8 TSV lines."""
    with open(path, "w", encoding="utf-8") as fout:
        for src, tgt, label in rows:
            fout.write(f"{src}\t{tgt}\t{label}\n")


def process_german_augmentation_with_reversed_labels(
    input_file,
    augmented_output_file,
    combined_output_file,
    sample_size=800,
    variants_per_sentence=2,
    seed=None,
):
    """Build a labeled pair dataset from perturbed German targets.

    Every original pair gets label 1. A random sample of the pairs has its
    target perturbed via ``generate_variants``, and each variant is written
    with label 0.

    Args:
        input_file: UTF-8 TSV with ``src<TAB>tgt`` lines.
        augmented_output_file: Output TSV with only the label-0 variants.
        combined_output_file: Output TSV with all label-1 originals followed by
            the label-0 variants.
        sample_size: Number of pairs to perturb. Clamped to the number of
            valid pairs, since ``random.sample`` raises otherwise.
        variants_per_sentence: ``max_variants`` passed to ``generate_variants``.
        seed: If given, ``random.seed(seed)`` is called before sampling, for
            reproducible output.

    NOTE(review): the perturbation functions return space-joined tokens
    (e.g. "Welt ."), while label-1 targets keep their original spacing. A
    classifier may learn this formatting cue; consider detokenizing variants.

    Returns:
        The list of ``(src, variant, 0)`` tuples that were written.
    """
    sentence_pairs = _read_tsv_pairs(input_file)
    print(f"Original: {len(sentence_pairs)}")

    if seed is not None:
        random.seed(seed)
    k = max(0, min(sample_size, len(sentence_pairs)))
    sampled = random.sample(sentence_pairs, k)

    # Augmented (perturbed) pairs are labeled 0.
    augmented_pairs = [
        (src, variant, 0)
        for src, tgt in sampled
        for variant in generate_variants(tgt, max_variants=variants_per_sentence)
    ]
    # Original pairs are labeled 1.
    original_labeled = [(src, tgt, 1) for src, tgt in sentence_pairs]

    _write_labeled_tsv(augmented_output_file, augmented_pairs)
    print(f"Augmented: {len(augmented_pairs)}")

    _write_labeled_tsv(combined_output_file, original_labeled + augmented_pairs)
    return augmented_pairs


if __name__ == "__main__":
    # Settings for the full augmentation run; note that variants_per_sentence
    # here (3) overrides the function default (2).
    run_config = dict(
        input_file="HSB-DE_train_sampled_10k.tsv",
        augmented_output_file="augmented_only.tsv",
        combined_output_file="combined_labeled.tsv",
        sample_size=800,
        variants_per_sentence=3,
    )
    process_german_augmentation_with_reversed_labels(**run_config)