In [1]:
import gensim.downloader as api
from datasets import load_dataset
import nltk
from nltk.tag import pos_tag
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

import pandas as pd
from langdetect import detect
from tqdm import tqdm
import itertools
import json

In [2]:
# Word vectors fetched through gensim's downloader API; the augmentation code further down
# queries them with `most_similar` / `distance` to propose replacement nouns.
model = api.load('word2vec-google-news-300')
# Alternative source kept for reference: load_dataset('textminr/ner', split='train')
# The active source is the local JSON Lines file below.
dataset = load_dataset('json', data_files='data_new.jsonl', split='train')

In [3]:
# Fetch (quietly) the NLTK resources needed by this notebook, then load the English
# stop-word list into `swords`.
for resource in ('punkt', 'averaged_perceptron_tagger', 'stopwords'):
  nltk.download(resource, quiet=True)
swords = stopwords.words('english')

In [4]:
# Materialise the loaded dataset as a pandas DataFrame; the augmentation functions below
# read its 'prompt' and 'response' columns.
df = dataset.to_pandas()

In [5]:
def generate_weighted_number(start_year=1700, end_year=2023):
  """Draw a random year from [start_year, end_year], favouring recent years.

  The sampling weight grows linearly from 1 for `start_year` to 10 for
  `end_year`, so the last year is ten times as likely as the first.

  Args:
    start_year: First year that can be drawn (inclusive).
    end_year: Last year that can be drawn (inclusive).

  Returns:
    The sampled year as a NumPy integer scalar.

  Raises:
    ValueError: If `end_year` is smaller than `start_year`.
  """
  import numpy as np

  if end_year < start_year:
    raise ValueError('end_year must not be smaller than start_year')

  years = np.arange(start_year, end_year + 1)

  # Linear ramp: starts at 1 and ends at 10.
  weights = np.linspace(1, 10, len(years))

  # Normalise so the weights form a probability distribution.
  weights /= weights.sum()

  return np.random.choice(years, p=weights)

In [6]:
def name_date_augmentation(df):
  """Replace person names in each prompt with a generated fake name.

  For every row, spaCy's NER is run on the prompt. All PERSON mentions are
  replaced by one fake name: the replacement assigned to the last PERSON
  entity found in that prompt. The same original name always maps to the same
  fake name across rows. If the JSON response has an 'author' other than
  'N/A', it is set to that fake name as well.

  Rows whose prompt contains no PERSON entity are passed through with the
  prompt unchanged and the response re-serialised, instead of crashing or
  inheriting the previous row's name.

  Despite the function name, dates are left untouched.

  Args:
    df: DataFrame with a 'prompt' column (text) and a 'response' column
      (JSON object string that contains an 'author' key).

  Returns:
    A new DataFrame with columns 'prompt' and 'response' (JSON string), one
    row per input row, in input order, with a fresh 0..n-1 index.

  Raises:
    KeyError: If a decoded response has no 'author' key.
  """
  from faker import Faker
  import re
  import spacy

  fake = Faker()
  nlp = spacy.load('en_core_web_md')

  # original name -> fake name, shared across rows for consistency
  author_replacements = {}
  # Collect plain records and build the frame once; concatenating inside the
  # loop copies the whole frame on every iteration.
  records = []

  for prompt, raw_response in tqdm(zip(df['prompt'], df['response']), total=len(df)):
    response = json.loads(raw_response)

    person_names = []
    for ent in nlp(prompt).ents:
      if ent.label_ == 'PERSON' and ent.text:
        person_names.append(ent.text)
        if ent.text not in author_replacements:
          author_replacements[ent.text] = fake.name()

    new_prompt = prompt
    if person_names:
      author = author_replacements[person_names[-1]]

      # Single-pass substitution, longest names first, so a short name cannot
      # clobber part of a longer one and inserted text is never re-scanned.
      # A callable replacement avoids both str.format (which breaks on
      # literal braces in the prompt) and backslash processing in re.sub.
      longest_first = sorted(set(person_names), key=len, reverse=True)
      pattern = '|'.join(re.escape(name) for name in longest_first)
      new_prompt = re.sub(pattern, lambda _match: author, prompt)

      if response['author'] != 'N/A':
        response['author'] = author

    records.append({
      'prompt': new_prompt,
      'response': json.dumps(response, default=str),
    })

  return pd.DataFrame(records, columns=['prompt', 'response'])

# NOTE(review): `new_df` is reassigned by the later `new_df = base_augmentation(df)` cell,
# which reads `df` rather than this result, so the name-augmented rows produced here never
# reach the exported file. Confirm whether the two augmentations were meant to be chained
# or concatenated.
new_df = name_date_augmentation(df)

100%|██████████| 292/292 [00:02<00:00, 129.99it/s]


In [9]:
# Tokens that must never be swapped out, even when tagged as nouns.
special_words = ['da']

def _similar_nouns(word, tag, topn=10, max_distance=0.5):
  """Return embedding neighbours of `word` usable as replacements, then `word` itself.

  A neighbour is kept only if NLTK gives it the same POS tag as `word`, it is
  not a multi-word key (contains '_'), it differs from `word` beyond letter
  case, and its distance to `word` in the embedding model is below
  `max_distance`.
  """
  neighbours = [candidate for candidate, _score in model.most_similar(word, topn=topn)]
  alternatives = [
    candidate for candidate, candidate_tag in pos_tag(neighbours)
    if candidate_tag == tag and
    '_' not in candidate and
    candidate.lower() != word.lower() and
    model.distance(word, candidate) < max_distance
  ]
  # The original word is always the last alternative, so the unmodified
  # prompt is part of the output.
  alternatives.append(word)
  return alternatives

def base_augmentation(df):
  """Create prompt variants by swapping nouns for similar words.

  Non-English prompts (per langdetect) are skipped. In every remaining
  prompt, each token tagged NN/NNS that is in the embedding vocabulary and
  not listed in `special_words` becomes a slot. Its alternatives are the
  similar words from `_similar_nouns` plus the original word. One output row
  is produced for every combination of alternatives across all slots; the
  response is copied unchanged.

  Slots are located by aligning the tokens with the prompt text from left to
  right, so only the tagged occurrence is replaced. Substrings of other words
  and literal braces in the prompt are left alone. A token that cannot be
  found in the text (for example because the tokenizer rewrote it) is skipped.

  Args:
    df: DataFrame with 'prompt' and 'response' columns.

  Returns:
    A new DataFrame with columns 'prompt' and 'response' and a fresh
    0..n-1 index.
  """
  vocab = model.key_to_index
  # (token, tag) -> alternatives; avoids repeating the expensive
  # nearest-neighbour query for nouns that occur more than once.
  alternatives_cache = {}
  # Collect plain records and build the frame once instead of concatenating
  # a one-row frame per variant.
  records = []

  for prompt, response in tqdm(zip(df['prompt'], df['response']), total=len(df)):
    if detect(prompt) != 'en':
      continue

    literals = []       # text preceding each slot
    choices = []        # alternatives for each slot, in text order
    cursor = 0          # end of the most recently aligned token
    literal_start = 0   # start of the text not yet assigned to `literals`

    for token, tag in pos_tag(word_tokenize(prompt)):
      start = prompt.find(token, cursor)
      if start == -1:
        continue
      cursor = start + len(token)

      if tag not in ('NN', 'NNS') or token in special_words or token not in vocab:
        continue

      key = (token, tag)
      if key not in alternatives_cache:
        alternatives_cache[key] = _similar_nouns(token, tag)

      literals.append(prompt[literal_start:start])
      choices.append(alternatives_cache[key])
      literal_start = cursor

    tail = prompt[literal_start:]

    # With no slots, product() yields one empty combination, so the prompt
    # is emitted once unchanged.
    for combination in itertools.product(*choices):
      pieces = [part for pair in zip(literals, combination) for part in pair]
      pieces.append(tail)
      records.append({'prompt': ''.join(pieces), 'response': response})

  return pd.DataFrame(records, columns=['prompt', 'response'])

# NOTE(review): this reads the original `df` and overwrites the `new_df` produced by the
# earlier `name_date_augmentation(df)` call, so only the noun-swap variants are exported.
# Confirm that dropping the name-augmented rows is intended.
new_df = base_augmentation(df)

100%|██████████| 292/292 [01:59<00:00,  2.44it/s]


In [10]:
# Drop exact duplicate rows, then shuffle. The seed is fixed so the row order, and
# therefore the exported file, is the same on every "Restart & Run All".
SHUFFLE_SEED = 42
new_df = (
  new_df
  .drop_duplicates()
  .sample(frac=1, random_state=SHUFFLE_SEED)
  .reset_index(drop=True)
)

In [11]:
# Export as JSON Lines. The encoding is pinned to UTF-8 because force_ascii=False writes
# non-ASCII characters verbatim and the platform's default encoding may not be able to
# represent them. Plain 'w' is sufficient: the file is only written here, never read back.
with open('data.jsonl', 'w', encoding='utf-8') as f:
  new_df.to_json(f, orient='records', lines=True, force_ascii=False)