In [None]:
!pip install nltk googletrans==4.0.0-rc1

In [None]:
# Data augmentation
import pandas as pd
import random
import nltk
from nltk.corpus import wordnet
from googletrans import Translator
from google.colab import drive
import time

# Download the NLTK WordNet corpora used by synonym_replacement
nltk.download('wordnet')
nltk.download('omw-1.4')

# Configuration
drive.mount('/content/drive')
file_path = '/content/drive/MyDrive/'
INPUT_CSV = file_path + 'label_tokenized.csv'  # input CSV path
OUTPUT_CSV = file_path + 'augmented_data.csv'  # output path for the augmented CSV
TEXT_COLUMN = 'text'  # column holding the text
LABEL_COLUMN = 'label'  # column holding the label
URL_COLUMN = 'url'  # primary-key column
NUM_AUGMENT = 2  # number of augmented samples generated per text
SEED = 42  # seed for the `random` module so synonym choices are reproducible between runs
random.seed(SEED)

# Synonym replacement
def synonym_replacement(text, n=1, synsets_fn=None):
    """Return ``text`` with up to ``n`` distinct words replaced by a WordNet synonym.

    The text is split on whitespace. Every occurrence of a chosen word is replaced.
    Candidate synonyms are the lemma names of *all* synsets of the word. WordNet's
    ``_`` multi-word separator is turned into a space. Candidates equal to the word
    itself (case-insensitively) are discarded, and one candidate is picked with
    ``random.choice``.

    Args:
        text: Input text. Non-string values (e.g. NaN from pandas) are returned unchanged.
        n: Maximum number of distinct words to replace (``n <= 0`` replaces nothing).
        synsets_fn: Callable mapping a word to a list of synsets. Defaults to
            ``wordnet.synsets``. It can be injected for tests or other lexicons.

    Returns:
        The augmented text, re-joined with single spaces.
    """
    if not isinstance(text, str):
        return text
    if synsets_fn is None:
        synsets_fn = wordnet.synsets

    words = text.split()
    new_words = words[:]
    # Sort before shuffling. Set iteration order depends on string-hash randomisation,
    # so shuffling the raw set would not be reproducible even after random.seed().
    candidate_words = sorted(set(words))
    random.shuffle(candidate_words)
    num_replaced = 0

    for random_word in candidate_words:
        if num_replaced >= n:
            break
        target = random_word.lower()
        # Sorted so random.choice is reproducible under a fixed seed.
        synonyms = sorted({
            lemma.name().replace('_', ' ')
            for synset in synsets_fn(random_word)
            for lemma in synset.lemmas()
            if lemma.name().replace('_', ' ').lower() != target
        })
        if not synonyms:
            continue
        synonym = random.choice(synonyms)
        new_words = [synonym if word == random_word else word for word in new_words]
        num_replaced += 1

    return ' '.join(new_words)

# Back-translation
def back_translation(text, src_lang='en', dest_lang='fr', translator=None, retries=3, delay=1.0):
    """Paraphrase ``text`` by translating it to ``dest_lang`` and back to ``src_lang``.

    Each translation leg is attempted up to ``retries`` times (at least once).
    Between failed attempts it sleeps ``delay * attempt`` seconds. It also sleeps
    ``delay`` seconds between the two legs. The googletrans client can fail
    intermittently (e.g. "the JSON object must be str, bytes or bytearray, not
    NoneType"), so each leg is retried before giving up.

    Args:
        text: Text to paraphrase. Non-strings and blank strings are returned as-is.
        src_lang: Language of ``text``.
        dest_lang: Pivot language.
        translator: Object with a googletrans-style ``translate(text, src=..., dest=...)``
            returning something with a ``.text`` attribute. When None, a new
            ``Translator()`` is created.
        retries: Attempts per translation leg.
        delay: Base sleep in seconds between requests and between retries.

    Returns:
        The back-translated text. If either leg fails on every attempt, the error is
        printed and ``text`` is returned unchanged. Callers can compare the result
        with the input to detect this.
    """
    if not isinstance(text, str) or not text.strip():
        return text
    if translator is None:
        translator = Translator()
    attempts = max(1, int(retries))

    def _translate(value, src, dest, what):
        # One translation leg with retries. Re-raises the last error if every attempt fails.
        last_error = None
        for attempt in range(1, attempts + 1):
            try:
                result = translator.translate(value, src=src, dest=dest).text
                if not result:
                    raise ValueError(f"{what} failed; received an empty string.")
                return result
            except Exception as e:  # broad on purpose: any client error triggers a retry
                last_error = e
                if attempt < attempts:
                    time.sleep(delay * attempt)
        raise last_error

    try:
        translated = _translate(text, src_lang, dest_lang, "Translation")
        time.sleep(delay)  # pause between the two requests
        return _translate(translated, dest_lang, src_lang, "Back-translation")
    except Exception as e:
        print(f"Back-translation failed: {e}")
        return text  # Fallback to original text

# Data augmentation
def augment_data(dataframe, text_column, label_column, url_column, num_augment=NUM_AUGMENT,
                 include_original=True):
    """Return a DataFrame of (optionally) the original rows plus augmented copies.

    For each input row, the following rows are emitted in order:
      * the original row (when ``include_original``),
      * ``num_augment // 2`` synonym-replacement variants,
      * ``num_augment - num_augment // 2`` back-translation variants.

    Exactly ``num_augment`` augmented rows are produced per text, including when it
    is odd. Only the text, label and url columns are kept. Rows whose text is not a
    string (e.g. NaN) are passed through as originals but not augmented.

    Args:
        dataframe: Source rows.
        text_column, label_column, url_column: Column names to read and emit.
        num_augment: Augmented samples to generate per text.
        include_original: Whether to emit each original row as well.

    Returns:
        DataFrame with columns ``[text_column, label_column, url_column]``.
    """
    columns = [text_column, label_column, url_column]
    n_synonym = num_augment // 2
    n_back_translate = num_augment - n_synonym
    augmented_rows = []

    for text, label, url in dataframe[columns].itertuples(index=False, name=None):
        if include_original:
            augmented_rows.append({text_column: text, label_column: label, url_column: url})

        if not isinstance(text, str):
            continue  # nothing to augment (e.g. NaN text)

        # Synonym-replacement augmentation
        for _ in range(n_synonym):
            augmented_text = synonym_replacement(text)
            augmented_rows.append({text_column: augmented_text, label_column: label, url_column: url})

        # Back-translation augmentation (takes the extra sample when num_augment is odd)
        for _ in range(n_back_translate):
            augmented_text = back_translation(text)
            augmented_rows.append({text_column: augmented_text, label_column: label, url_column: url})

    # Explicit columns so an empty input still yields the expected schema.
    return pd.DataFrame(augmented_rows, columns=columns)

# Main program
def main(target_label=1):
    """Augment the INPUT_CSV rows whose label equals ``target_label`` and save them to OUTPUT_CSV.

    Steps:
      1. Read INPUT_CSV.
      2. Check that TEXT_COLUMN, LABEL_COLUMN and URL_COLUMN exist.
      3. Keep rows with ``LABEL_COLUMN == target_label`` and non-missing text.
      4. Run ``augment_data`` and write its output (original rows plus augmented
         copies) without the index.

    Args:
        target_label: Label value whose rows are augmented (default 1).

    Raises:
        ValueError: If any of the required columns is missing.
    """
    df = pd.read_csv(INPUT_CSV)

    # Validate before filtering. The label filter below indexes LABEL_COLUMN and would
    # otherwise fail with a bare KeyError before this check could ever run.
    required_columns = (TEXT_COLUMN, LABEL_COLUMN, URL_COLUMN)
    if any(column not in df.columns for column in required_columns):
        raise ValueError(f"CSV file must contain '{TEXT_COLUMN}', '{LABEL_COLUMN}', and '{URL_COLUMN}' columns.")

    df = df[df[LABEL_COLUMN] == target_label]
    # Rows without text cannot be augmented.
    df = df.dropna(subset=[TEXT_COLUMN])

    print("Performing data augmentation...")
    augmented_df = augment_data(df, TEXT_COLUMN, LABEL_COLUMN, URL_COLUMN, NUM_AUGMENT)

    augmented_df.to_csv(OUTPUT_CSV, index=False)
    print(f"Data augmentation complete. Augmented data saved to '{OUTPUT_CSV}'.")

# Run the main program (this guard is true inside the notebook kernel too, so the cell runs main())
if __name__ == "__main__":
    main()


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Performing data augmentation...
Back-translation failed: the JSON object must be str, bytes or bytearray, not NoneType
Back-translation failed: the JSON object must be str, bytes or bytearray, not NoneType
Back-translation failed: the JSON object must be str, bytes or bytearray, not NoneType
Data augmentation complete. Augmented data saved to '/content/drive/MyDrive/augmented_data.csv'.


In [None]:
# NOTE(review): in a fresh kernel no earlier cell defines a global `df` (main() keeps its
# frame local), so this only works when `df` was created out of order. Confirm which frame
# is meant. dropna() with no subset drops rows with a NaN in *any* column.
df = df.dropna()

In [None]:
# (rows, columns) of df
(len(df), len(df.columns))

(1749, 10)

In [None]:
import pandas as pd
import spacy
import torch
from google.colab import drive

drive.mount('/content/drive')
file_path = '/content/drive/MyDrive/'
INPUT_CSV = file_path + 'augmented_data.csv'  # input CSV path
df = pd.read_csv(INPUT_CSV)

# Ensure the 'text' column exists *before* it is first indexed. Checking afterwards can
# never fire, because df['text'] would already have raised a KeyError.
assert "text" in df.columns, "'text' 列不存在，请检查文件格式！"
df = df[df['text'].notna()]

# Load spaCy's English model. NER and the dependency parser are disabled because only
# tokens and POS tags are used below, which also speeds up processing.
nlp = spacy.load("en_core_web_sm", disable=["ner", "parser"])

# Grammatical forms per country: noun (country name), adjective, and capital city
country_forms = {
    "United States of America": {"noun": "United States", "adj": "American", "capital_noun": "Washington"},
    "China": {"noun": "China", "adj": "Chinese", "capital_noun": "Beijing"},
    "Japan": {"noun": "Japan", "adj": "Japanese", "capital_noun": "Tokyo"},
    "Germany": {"noun": "Germany", "adj": "German", "capital_noun": "Berlin"},
    "India": {"noun": "India", "adj": "Indian", "capital_noun": "New Delhi"},
    "United Kingdom": {"noun": "United Kingdom", "adj": "British", "capital_noun": "London"},
    "France": {"noun": "France", "adj": "French", "capital_noun": "Paris"},
    "Canada": {"noun": "Canada", "adj": "Canadian", "capital_noun": "Ottawa"},
    "Russia": {"noun": "Russia", "adj": "Russian", "capital_noun": "Moscow"},
    "Italy": {"noun": "Italy", "adj": "Italian", "capital_noun": "Rome"},
    "South Korea": {"noun": "South Korea", "adj": "Korean", "capital_noun": "Seoul"},
    "Saudi Arabia": {"noun": "Saudi Arabia", "adj": "Saudi", "capital_noun": "Riyadh"},
    "Spain": {"noun": "Spain", "adj": "Spanish", "capital_noun": "Madrid"},
    "Turkey": {"noun": "Turkey", "adj": "Turkish", "capital_noun": "Ankara"},
}

Mounted at /content/drive


In [None]:
# Keep only articles whose origin_country is recorded as 'German'
df = df.loc[df['origin_country'].eq('German')]

In [None]:
# Make the index a contiguous 0..n-1 range. Later cells pair the i-th tokenized text with
# df['label'][i] and join augmented_df back onto df by index, so positional order and
# index labels must coincide.
df = df.reset_index(drop=True)

In [None]:
# Parse every text with spaCy up front so the replacement step can reuse tokens and POS tags
def preprocess_texts(texts):
    """Run the global spaCy pipeline ``nlp`` over ``texts``.

    Returns two parallel lists: the token strings of each text and the coarse POS
    tag (``token.pos_``) of each token.
    """
    # nlp.pipe processes the documents in batches of 64 for throughput.
    parsed = [
        ([tok.text for tok in doc], [tok.pos_ for tok in doc])
        for doc in nlp.pipe(texts, batch_size=64)
    ]
    tokenized_texts = [tokens for tokens, _ in parsed]
    pos_tags = [tags for _, tags in parsed]
    return tokenized_texts, pos_tags

# Parse the corpus once, up front
print("开始解析文本...")
tokenized_texts, pos_tags = preprocess_texts(df['text'].tolist())
print("文本解析完成！")

开始解析文本...
文本解析完成！


In [None]:
# Sanity check: does the first parsed text contain the token 'Germany'?
tokenized_texts[0].count('Germany') > 0

True

In [None]:
# Country-mention replacement (runs in plain Python; no GPU is involved despite the name)
def replace_country_with_gpu(tokenized_texts, pos_tags, original_country, target_country, forms=None):
    """Swap mentions of ``original_country`` for ``target_country`` in pre-tokenized texts.

    A token, or a run of tokens for multi-word names such as "United States" or
    "New Delhi", is replaced when both of these hold:
      * its lower-cased text matches one of the original country's forms, and
      * its POS tags fit that form:
          - ``noun`` / ``capital_noun``: every matched token is tagged NOUN or PROPN
          - ``adj``: every matched token is tagged ADJ

    If the same tokens match several forms, the priority is noun, then adj, then
    capital. The function name is kept for existing callers.

    Args:
        tokenized_texts: List of token-string lists (one per text).
        pos_tags: List of POS-tag lists parallel to ``tokenized_texts``.
        original_country: Key of ``forms`` whose mentions are replaced.
        target_country: Key of ``forms`` whose forms are inserted.
        forms: Mapping country -> {"noun", "adj", "capital_noun"}. Defaults to the
            global ``country_forms``.

    Returns:
        List of strings, each the (possibly modified) tokens joined by single spaces.
        Each sentence is truncated to the shorter of its token and tag lists.
    """
    if forms is None:
        forms = country_forms
    original_forms = forms[original_country]
    target_forms = forms[target_country]

    noun_pos = {"NOUN", "PROPN"}
    # (lower-cased word tuple, allowed POS tags, replacement), listed in priority order.
    patterns = [
        (tuple(original_forms["noun"].lower().split()), noun_pos, target_forms["noun"]),
        (tuple(original_forms["adj"].lower().split()), {"ADJ"}, target_forms["adj"]),
        (tuple(original_forms["capital_noun"].lower().split()), noun_pos, target_forms["capital_noun"]),
    ]
    # Longer patterns first. The sort is stable, so equal-length patterns keep the priority order.
    patterns.sort(key=lambda p: -len(p[0]))

    results = []
    for tokens, tags in zip(tokenized_texts, pos_tags):
        n = min(len(tokens), len(tags))
        lowered = [t.lower() for t in tokens[:n]]
        replaced_sentence = []
        i = 0
        while i < n:
            for words, allowed, replacement in patterns:
                k = len(words)
                if k and tuple(lowered[i:i + k]) == words and all(p in allowed for p in tags[i:i + k]):
                    replaced_sentence.append(replacement)
                    i += k
                    break
            else:
                replaced_sentence.append(tokens[i])
                i += 1
        results.append(" ".join(replaced_sentence))
    return results

# Target countries to generate. Only France is enabled; any other key of `country_forms`
# (e.g. "United Kingdom", "Spain", "Turkey") can be added here.
countries_list = ["France"]

# Generate the augmented dataset: one full copy of the corpus per target country
print("开始生成增强数据...")
augmented_data = []
original_country = "Germany"

# Read labels positionally so the pairing with tokenized_texts does not depend on df's index
# labels, and fail loudly if df changed since the parsing cell ran.
labels = df['label'].tolist()
assert len(labels) == len(tokenized_texts), "df and tokenized_texts are out of sync; re-run the parsing cell"

for target_country in countries_list:
    augmented_texts = replace_country_with_gpu(tokenized_texts, pos_tags, original_country, target_country)
    for text, label in zip(augmented_texts, labels):
        augmented_data.append({"text": text, "label": label})

开始生成增强数据...


In [None]:
# NOTE(review): when cells run top to bottom, INPUT_CSV was re-pointed to 'augmented_data.csv'
# in the spaCy-loading cell, so this reads that file. Rows with no origin_country count as France.
original_df = pd.read_csv(INPUT_CSV)
original_df = original_df.assign(
    origin_country=original_df['origin_country'].fillna('France')
)
original_df = original_df[original_df['origin_country'] == 'France']

In [None]:
# NOTE(review): `join_df` is created by a later cell (df.join(augmented_df['text'])); run that first.
# ignore_index=True gives the combined frame a fresh 0..n-1 index instead of duplicated labels.
final_df = pd.concat([original_df, join_df], ignore_index=True)

In [None]:
# Persist the combined originals + country-swapped rows to Drive
final_df.to_csv(path_or_buf=f"{file_path}augmented_data_total_change_to_France.csv", index=False)

In [None]:
# Build a DataFrame from the generated {"text", "label"} records and display it
augmented_df = pd.DataFrame.from_records(augmented_data)
augmented_df

Unnamed: 0,text,label
0,"The leaders of France , France and Poland have...",0
1,The age of European countries outsourcing thei...,0
2,Moscow will not repeat its past mistakes andag...,2
3,A group of EU countries have blasted Hungarian...,2
4,The International Monetary Fund ( IMF ) has ra...,0
...,...,...
313,By fueling the Ukraine conflict and waging a p...,0
314,"A total of 4,405 civilians have been killed on...",0
315,Paris has no plans to send modern Western - ma...,0
316,The Mayor of Kiev Vitaly Klitschko has alleged...,0


In [None]:
# Drop the original text so the country-swapped text can be joined onto df without a column clash.
# Rebinding (instead of inplace=True) makes the cell's effect explicit.
df = df.drop(columns='text')

In [None]:
# Flag every joined row as augmented (overwrites any existing is_augmented values).
# NOTE(review): depends on `join_df`, which is created by the df.join(...) cell below; run that cell first.
join_df['is_augmented'] = 1

In [None]:
# Attach the country-swapped text to each row of df by index. This is only meaningful when
# augmented_df holds exactly one generated text per df row in the same order (a single entry
# in countries_list), so check the alignment before joining.
assert augmented_df.index.equals(df.index), "augmented_df must have one row per df row (same index) to join by index"
join_df = df.join(augmented_df['text'])

In [None]:
# Inspect the current state of df
df

Unnamed: 0,url,label,target_country,origin_country,publish_date,title,language,processed_body,is_augmented
0,https://www.rt.com/news/607275-georgia-eu-memb...,0,Russia,German,2024-11-07,EU leaders threaten neighbor of Russia,en,"['The', 'leaders', 'France', 'Germany', 'Polan...",0
1,https://www.rt.com/news/606933-trump-harris-eu...,0,Russia,German,2024-11-03,‘Outsourcing’ of EU security to America is ove...,en,"['The', 'age', 'European', 'countries', 'outso...",0
2,https://www.rt.com/russia/606878-no-ukraine-ce...,2,Russia,German,2024-11-01,No repeat of Minsk agreements – Moscow,en,"['Moscow', 'repeat', 'past', 'mistakes', 'anda...",0
3,https://www.rt.com/russia/606694-orban-georgia...,2,Russia,German,2024-10-29,EU countries blast Orban over Georgia visit,en,"['A', 'group', 'EU', 'countries', 'blasted', '...",0
4,https://www.rt.com/business/606344-russia-four...,0,Russia,German,2024-10-24,IMF upgrades Russia to world’s fourth-largest ...,en,"['The', 'International', 'Monetary', 'Fund', '...",0
...,...,...,...,...,...,...,...,...,...
313,https://www.rt.com/news/569369-de-gaulle-ukrai...,0,Russia,German,2023-01-04,US making Europeans suffer – de Gaulle’s grandson,en,"['By', 'fueling', 'Ukraine', 'conflict', 'wagi...",0
314,https://www.rt.com/russia/569348-number-civili...,0,Russia,German,2023-01-03,Number of civilians killed in Donbass revealed,en,"['A', 'total', '4,405', 'civilians', 'killed',...",0
315,https://www.rt.com/news/569304-germany-nato-ru...,0,Russia,German,2023-01-02,German MP warns against ‘unimaginable escalation’,en,"['Berlin', 'plans', 'send', 'modern', 'Western...",0
316,https://www.rt.com/russia/569230-kiev-mayor-cr...,0,Russia,German,2022-12-31,Kiev mayor criticizes Ukrainian authorities,en,"['The', 'Mayor', 'Kiev', 'Vitaly', 'Klitschko'...",0


In [None]:
# Convert the generated records to a DataFrame
augmented_df = pd.DataFrame(augmented_data)

# Save the augmented dataset under file_path (Drive) like every other output of this notebook.
# A bare relative path would write to the runtime's working directory, which on Colab is
# typically not kept between sessions.
augmented_df.to_csv(file_path + "training_data_German_to_France.csv", index=False)
print("增强数据集已保存！")

In [None]:
# Paths
file_path = '/content/drive/MyDrive/'
original_csv = file_path + 'label_tokenized.csv'  # original CSV
merged_csv = file_path + 'augmented_data.csv'  # merged output CSV
# NOTE(review): this is the same path main() writes as OUTPUT_CSV in the first cell; it is overwritten here.

# Read the original data and the two augmented sets
original_df = pd.read_csv(original_csv)
augmented_df_positive = pd.read_csv(file_path + 'augmented_data_positive.csv')
augmented_df_negative = pd.read_csv(file_path + 'augmented_data_negative_filtered.csv')
augmented_dfs = [augmented_df_positive, augmented_df_negative]

# Flag augmented rows. This mutates the frames above. A dedicated loop name avoids
# clobbering the global `augmented_df` produced by the country-swap cells.
for aug_df in augmented_dfs:
    if 'is_augmented' not in aug_df.columns:
        aug_df['is_augmented'] = 1

# Per-url metadata from the original data. text/label/is_augmented are dropped because the
# augmented rows carry their own. De-duplicating on the key keeps the left merge from
# multiplying augmented rows.
key_columns = ['url']
url_metadata = (
    original_df
    .drop(columns=['text', 'label', 'is_augmented'], errors='ignore')
    .drop_duplicates(subset=key_columns)
)

# Keep the *merged* frames so augmented rows carry the original metadata columns
# (origin_country, publish_date, ...).
merged_augmented_dfs = [
    pd.merge(aug_df, url_metadata, on=key_columns, how='left')
    for aug_df in augmented_dfs
]

# Flag original rows
original_df['is_augmented'] = 0

# Combine originals with every augmented set
merged_df = pd.concat([original_df, *merged_augmented_dfs], ignore_index=True)

# Save the merged data
merged_df.to_csv(merged_csv, index=False)

print(f"Merged data saved to {merged_csv}")

Merged data saved to /content/drive/MyDrive/augmented_data.csv


In [None]:
# Reload the merged CSV written by the merge cell
merged_csv = f"{file_path}augmented_data.csv"
merged_df = pd.read_csv(merged_csv)

In [None]:
# (rows, columns) of the merged dataset
(len(merged_df), len(merged_df.columns))

(1750, 10)

In [None]:
# Shape of the label-1 subset
merged_df.loc[merged_df['label'].eq(1)].shape

(314, 10)

In [None]:
# Share of rows for each class label 0, 1, 2.
# Uses its own loop names so it does not overwrite the global `a` read by the per-url count cells.
for label_value in range(3):
    share = merged_df['label'].eq(label_value).mean()
    print(f'{label_value}: {share}')

0: 0.3251428571428571
1: 0.17942857142857144
2: 0.49542857142857144


In [None]:
# Load the positive and negative augmented sets from Drive
augmented_df_positive, augmented_df_negative = (
    pd.read_csv(file_path + name)
    for name in ('augmented_data_positive.csv', 'augmented_data_negative.csv')
)

In [None]:
# Inspect the augmented rows loaded from augmented_data_positive.csv
augmented_df_positive

Unnamed: 0,text,label,url
0,Russian energy giant Rosatom is extending its ...,1,https://www.rt.com/news/605483-iter-fusion-ene...
1,Russian energy giant Rosatom is extending its ...,1,https://www.rt.com/news/605483-iter-fusion-ene...
2,Russian energy giant Rosatom is extending its ...,1,https://www.rt.com/news/605483-iter-fusion-ene...
3,Russian energy giant Rosatom is extending its ...,1,https://www.rt.com/news/605483-iter-fusion-ene...
4,Russian energy giant Rosatom is extending its ...,1,https://www.rt.com/news/605483-iter-fusion-ene...
...,...,...,...
263,Warsaw is pressuring Berlin to allow it supply...,1,https://www.rt.com/news/570130-poland-germany-...
264,Warsaw is pressuring Berlin to allow it supply...,1,https://www.rt.com/news/570130-poland-germany-...
265,Warsaw is pressuring Berlin to allow it supply...,1,https://www.rt.com/news/570130-poland-germany-...
266,Warsaw is pressuring Berlin to allow it supply...,1,https://www.rt.com/news/570130-poland-germany-...


In [None]:
# Re-read the negative augmented set (same file as the earlier loading cell)
augmented_df_negative = pd.read_csv(f"{file_path}augmented_data_negative.csv")

In [None]:
# Inspect the augmented rows loaded from augmented_data_negative.csv
augmented_df_negative

Unnamed: 0,text,label,url
0,"The leaders of France, Germany and Poland have...",2,https://www.rt.com/news/607275-georgia-eu-memb...
1,"The leaders of France, Germany and Poland have...",2,https://www.rt.com/news/607275-georgia-eu-memb...
2,"The leaders of France, Germany and Poland have...",2,https://www.rt.com/news/607275-georgia-eu-memb...
3,"The leaders of France, Germany and Poland have...",2,https://www.rt.com/news/607275-georgia-eu-memb...
4,"The leaders of France, Germany and Poland have...",2,https://www.rt.com/news/607275-georgia-eu-memb...
...,...,...,...
2031,Washington kicked off its expanded training po...,2,https://www.rt.com/news/569928-ukraine-trainin...
2032,Washington kicked off its expanded training pr...,2,https://www.rt.com/news/569928-ukraine-trainin...
2033,Washington kicked off its expanded training pr...,2,https://www.rt.com/news/569928-ukraine-trainin...
2034,Washington kicked off its expanded training pr...,2,https://www.rt.com/news/569928-ukraine-trainin...


In [None]:
# Keep the last two rows of each url. Rows come out grouped by url in ascending url order,
# keeping their original order within each group. groupby(...).tail() avoids
# DataFrameGroupBy.apply on the grouping column, which newer pandas warns about.
filtered_df = (
    augmented_df_negative
    .groupby('url')
    .tail(2)
    .sort_values('url', kind='stable')
    .reset_index(drop=True)
)

  filtered_df = augmented_df_negative.groupby('url').apply(lambda x: x.tail(2)).reset_index(drop=True)


In [None]:
# Inspect the last-two-rows-per-url selection
filtered_df

Unnamed: 0,text,label,url
0,Last weeks Russia-Africa summit in St. Petersb...,2,https://www.rt.com/africa/580584-russias-bigge...
1,Last weeks Russia-Africa summit in St. Petersb...,2,https://www.rt.com/africa/580584-russias-bigge...
2,The Economic Community of West African States ...,2,https://www.rt.com/africa/580733-west-african-...
3,"Wednesday, the Economic Community of West Afri...",2,https://www.rt.com/africa/580733-west-african-...
4,The European Unions military partnership missi...,2,https://www.rt.com/africa/598334-eu-ending-nig...
...,...,...,...
559,A group of EU countries have blasted Hungarian...,2,https://www.rt.com/russia/606694-orban-georgia...
560,Moscow will not repeat its past mistake andagr...,2,https://www.rt.com/russia/606878-no-ukraine-ce...
561,Moscow will not repeat its past mistakes andag...,2,https://www.rt.com/russia/606878-no-ukraine-ce...
562,The more understanding with Russia and other c...,2,https://www.rt.com/russia/606938-ukraine-lavro...


In [None]:
# Persist the filtered negative set to Drive
filtered_df.to_csv(path_or_buf=f"{file_path}augmented_data_negative_filtered.csv", index=False)

In [None]:
# Rows for one known article. The pattern is matched literally (regex=False), and NaN texts
# count as non-matching (na=False) instead of breaking the boolean mask.
filtered_df[filtered_df['text'].str.contains('The leaders of France,', regex=False, na=False)]

Unnamed: 0,text,label,url
804,"The leaders of France, Germany and Poland have...",2,https://www.rt.com/news/607275-georgia-eu-memb...
805,"The leaders of France, Germany and Poland have...",2,https://www.rt.com/news/607275-georgia-eu-memb...
806,"The leaders of France, Germany and Poland have...",2,https://www.rt.com/news/607275-georgia-eu-memb...
807,"The leaders of France, Germany and Poland have...",2,https://www.rt.com/news/607275-georgia-eu-memb...


In [None]:
# Inspect augmented_df_negative again (same display as the earlier cell)
augmented_df_negative

In [None]:
# Fraction of urls whose `n` (from the per-url groupby cell) is at most 7.
# NOTE(review): needs the `a` frame built by the groupby cell that follows; run that first.
a['n'].le(7).mean()

0.9148936170212766

In [None]:
# Per-url summary of augmented_df_negative: summed label plus n = number of rows for that url.
# n is counted directly with 'size' rather than derived from the label sum, so it is correct
# for any label value. Aggregating only 'label' also avoids passing a column name as numeric_only.
a = (
    augmented_df_negative
    .groupby('url')
    .agg(label=('label', 'sum'), n=('label', 'size'))
    .reset_index()
)

In [None]:
# Attach both augmented sets to the rows of INPUT_CSV by url.
# NOTE(review): a left merge pairs every positive row with every negative row of the same url,
# and the shared text/label columns get _x/_y suffixes. Confirm this is intended
# (pd.concat may have been meant).
df = (
    pd.read_csv(INPUT_CSV)
    .merge(augmented_df_positive, on='url', how='left')
    .merge(augmented_df_negative, on='url', how='left')
)
df.shape

Unnamed: 0,text,label,url
0,Russian energy giant Rosatom is extending its ...,1,https://www.rt.com/news/605483-iter-fusion-ene...
1,Russian energy giant Rosatom is extending its ...,1,https://www.rt.com/news/605483-iter-fusion-ene...
2,Russian energy giant Rosatom is extending its ...,1,https://www.rt.com/news/605483-iter-fusion-ene...
3,Russian energy giant Rosatom is extending its ...,1,https://www.rt.com/news/605483-iter-fusion-ene...
4,Russian energy giant Rosatom is extending its ...,1,https://www.rt.com/news/605483-iter-fusion-ene...
...,...,...,...
497,Warsaw is pressuring Berlin to allow it supply...,1,https://www.rt.com/news/570130-poland-germany-...
498,Warsaw is pressuring Berlin to allow it supply...,1,https://www.rt.com/news/570130-poland-germany-...
499,Warsaw is pressuring Berlin to allow it supply...,1,https://www.rt.com/news/570130-poland-germany-...
500,Warsaw is pressuring Berlin to allow it supply...,1,https://www.rt.com/news/570130-poland-germany-...


In [None]:
# Save under file_path (Drive) like the other outputs. A bare relative path would write to the
# runtime's working directory instead.
df.to_csv(file_path + 'label_enhanced.csv', index=False)

In [None]:
# Inspect augmented_df.
# NOTE(review): several cells assign this name, so its contents depend on which cell ran last.
augmented_df

Unnamed: 0,text,label,url
0,Russian energy giant Rosatom is extending its ...,1,https://www.rt.com/news/605483-iter-fusion-ene...
1,Russian energy giant Rosatom is extending its ...,1,https://www.rt.com/news/605483-iter-fusion-ene...
2,Russian energy giant Rosatom is extending its ...,1,https://www.rt.com/news/605483-iter-fusion-ene...
3,Russian energy giant Rosatom is extending its ...,1,https://www.rt.com/news/605483-iter-fusion-ene...
4,Russian energy giant Rosatom is extending its ...,1,https://www.rt.com/news/605483-iter-fusion-ene...
...,...,...,...
263,Warsaw is pressuring Berlin to allow it supply...,1,https://www.rt.com/news/570130-poland-germany-...
264,Warsaw is pressuring Berlin to allow it supply...,1,https://www.rt.com/news/570130-poland-germany-...
265,Warsaw is pressuring Berlin to allow it supply...,1,https://www.rt.com/news/570130-poland-germany-...
266,Warsaw is pressuring Berlin to allow it supply...,1,https://www.rt.com/news/570130-poland-germany-...


In [None]:
# Print the first two rows of augmented_df, one row (as a Series) per iteration.
# DataFrame.items() would iterate over *columns*, which is not what `index, row` implies.
for index, row in augmented_df.iloc[:2].iterrows():
    print(row)

0    Russian energy giant Rosatom is extending its ...
1    Russian energy giant Rosatom is extending its ...
Name: text, dtype: object
0    1
1    1
Name: label, dtype: int64


In [None]:
from googletrans import Translator
# Quick connectivity check for the googletrans client: translate "Hello" from en to fr
translator = Translator()
print(translator.translate(text="Hello", src="en", dest="fr").text)

Bonjour
