# Augmenting NER data

by Benjamin Kissinger & Andreas Sünder

In [1]:
print("hello")

hello


In [2]:
DATA_DIR = 'data'
SOURCE_FILE = 'dataset.jsonl'
TARGET_FILE = 'ner_data_augmented.jsonl'

In [3]:
import os

import gensim.downloader as api
from datasets import load_dataset

model = api.load('word2vec-google-news-300')
dataset = load_dataset('json', data_files=os.path.join(DATA_DIR, SOURCE_FILE), split='train')

In [4]:
import nltk
from nltk.corpus import stopwords
from nltk.tag import pos_tag
from nltk.tokenize import word_tokenize

nltk.download('punkt', quiet=True)
nltk.download('averaged_perceptron_tagger', quiet=True)
nltk.download('stopwords', quiet=True)
swords = stopwords.words('english')

In [12]:
df = dataset.to_pandas()

In [13]:
def generate_weighted_number(start_year: int = 1700, end_year: int = 2023):
  import numpy as np

  years = np.arange(start_year, end_year + 1)
  #weights = np.linspace(1, 10, len(years))
  #weights /= weights.sum()

  return years

In [14]:
import itertools

import pandas as pd
from langdetect import detect as lang_detect
from tqdm import tqdm


def base_augmentation(df):
  special_words = ['da']

  new_df = pd.DataFrame(columns=['prompt', 'response'])
  vocab = model.key_to_index

  for i in tqdm(range(len(df))):
    row = df.iloc[i]

    # TODO: add german language support
    if lang_detect(row['prompt']) != 'en':
      continue

    prompt_tokenized = word_tokenize(row['prompt'])
    prompt_tagged = pos_tag(prompt_tokenized)
    prompt_cleaned = [word for word in prompt_tagged if word[1] in ('NN', 'NNS') and word[0] not in special_words]

    words_to_replace = []
    replace_list = []

    for word in prompt_cleaned:
      if word[0] not in vocab:
        continue

      ms = [word[0] for word in model.most_similar(word[0], topn=10)]
      ms_tagged = pos_tag(ms)
      ms_new = [
        replacement[0] for replacement in ms_tagged 
        if replacement[1] == word[1] and
        '_' not in replacement[0] and
        word[0].lower() != replacement[0].lower() and
        model.distance(word[0], replacement[0]) < 0.5
      ]

      words_to_replace.append(word[0])
      ms_new.append(word[0])
  
      if len(ms_new) > 0:
        replace_list.append(ms_new)

    prompt_removed = row['prompt']
    for word in words_to_replace:
      prompt_removed = prompt_removed.replace(word, '{}')

    replace_combinations = list(itertools.product(*replace_list))

    for combination in replace_combinations:
      new_df = pd.concat([new_df, pd.DataFrame(
          [[prompt_removed.format(*combination), row['response']]],
          columns=['prompt', 'response']
      )])

 
  return new_df

new_df = base_augmentation(df)

100%|██████████| 292/292 [00:53<00:00,  5.45it/s]


In [15]:
print(len(new_df))

20243


In [16]:
import json
import spacy
from faker import Faker
from first import first

MAX_NAME_COUNTER = 40

def name_date_augmentation(df):
  new_df = pd.DataFrame(columns=['prompt', 'response'])
  author_replacements = {}
  fake = Faker()

  nlp = spacy.load('en_core_web_md')

  for i in tqdm(range(len(df))):
    row = df.iloc[i]
    prompt = row['prompt']
    response = json.loads(row['response'])

    doc = nlp(prompt)
    author = first([ent.text for ent in doc.ents if ent.label_ == 'PERSON'])

    if author:
      prompt = prompt.replace(author, '{author}')

      if author not in author_replacements:
        author_replacements[author] = [fake.name(), 0]

      counter = author_replacements[author][1]
      if counter >= MAX_NAME_COUNTER:
        counter = 0
        author_replacements[author] = [fake.name(), 0]

      author_new = author_replacements[author][0]
      # print(author_replacements)
      author_replacements[author][1] = counter + 1

      # TODO: add name augmentation
      prompt = prompt.format(author=author_new)
      response['author'] = author_new

    new_df = new_df.append({ 'prompt': prompt, 'response': json.dumps(response, default=str) }, ignore_index=True )

  return new_df

final_df = name_date_augmentation(new_df)

100%|██████████| 20243/20243 [01:53<00:00, 178.31it/s]


In [17]:
print(len(final_df))

0


In [9]:
# final_df.drop_duplicates().sample(frac=1).reset_index(drop=True, inplace=True)

In [10]:
with open(os.path.join(DATA_DIR, TARGET_FILE), 'w+') as f:
  final_df.to_json(f, orient='records', lines=True, force_ascii=False)

In [11]:
print(len(final_df))

0
