## Utilities

In [2]:
%run ../common_functions.py

read_dataset()
read_all_datasets()
read_processed_data()
combine_text_columns()
split_train_test()
evaluate_model()
detect_missing_values()


## Imports

In [2]:
import yaml
import pandas as pd
from tqdm.auto import tqdm
from random import choice

In [1]:
from data_preprocessing.text_cleaning import (
    remove_html,
    remove_urls,
    remove_digits,
    remove_punctuation,
    remove_special_characters,
    normalize_whitespace,
    to_lowercase
)
from data_preprocessing.add_domain import add_domain
from data_preprocessing.find_and_remove_duplicates import find_and_remove_dups

In [4]:
from data_augmentation.synonym_replacement import replace_synonyms_database, replace_synonyms_llm
from data_augmentation.paraphrasing import paraphrase_llm
from data_augmentation.style_transfer import transfer_style_llm

In [4]:
from constants import DATASETS

## Load config

In [6]:
with open("config.yml", "r") as f:
    config = yaml.safe_load(f)

## Load Data

In [4]:
data = read_all_datasets()
data = data.sample(n=10_000)

# -- Orchestration --

## Data Preprocessing

### Find and Remove duplicate Data

In [8]:
datasets = {dataset: read_dataset(dataset) for dataset in DATASETS}
data = find_and_remove_dups(data, datasets)

Total rows                 : 40000
Unique 'text' entries      : 38624
Duplicate 'text' entries   : 1376 (3.44% of total)

==DUPLICATES WITHIN DATASET==
Found 0 duplicated rows for Dataset 'Climate-FEVER'
Row count: 1535 → 1535 (-0.00%)
Found 7335 duplicated rows for Dataset 'Fake News Corpus'
Row count: 70846 → 63511 (-10.35%)
Found 29 duplicated rows for Dataset 'Fake News Prediction'
Row count: 6335 → 6306 (-0.46%)
Found 40 duplicated rows for Dataset 'Fakeddit'
Row count: 7000 → 6960 (-0.57%)
Found 7518 duplicated rows for Dataset 'FEVER'
Row count: 109810 → 102292 (-6.85%)
Found 1408 duplicated rows for Dataset 'FineFake'
Row count: 16909 → 15501 (-8.33%)
Found 5793 duplicated rows for Dataset 'ISOT Fake News'
Row count: 44898 → 39105 (-12.90%)
Found 34 duplicated rows for Dataset 'LIAR2'
Row count: 22962 → 22928 (-0.15%)
Found 3034 duplicated rows for Dataset 'llm-misinformation'
Row count: 7057 → 4023 (-42.99%)
Found 40 duplicated rows for Dataset 'Source based FN'
Row count: 205

### Domain Tagging

In [5]:
data = add_domain(data)

Computing sentence embeddings...


Batches:   0%|          | 0/20 [00:00<?, ?it/s]

Initializing KMeans with k=12 Clusters...
Predicting domains with KMeans...
Mapping domain IDs to topic names...


## Data Augmentation

### Synonym Replacement

In [7]:
if config["augmentation"]["synonym_replacement"]["enabled"]:
    n_samples = int(max(config["augmentation"]["synonym_replacement"]["n_samples"] * len(data), 1))
    n_replacements = 2

    sampled = data.sample(n=n_samples, random_state=42).copy()
    augmented_texts = []

    if config["augmentation_methods"]["local_database"]:
        for text in tqdm(sampled['text'], total=n_samples):
            augmented_texts.append(replace_synonyms_database(text, n_replacements))
    elif config["augmentation_methods"]["llm"]:
        for text in tqdm(sampled['text'], total=n_samples):
            augmented_texts.append(replace_synonyms_llm(text, n_replacements))
    
    if len(augmented_texts):
        sampled['text'] = augmented_texts

        data = pd.concat([data, sampled], ignore_index=True)

### Paraphrasing

In [8]:
if config["augmentation"]["paraphrasing"]["enabled"]:
    n_samples = int(max(config["augmentation"]["paraphrasing"]["n_samples"] * len(data), 1))

    sampled = data.sample(n=n_samples, random_state=42).copy()
    augmented_texts = []

    if config["augmentation_methods"]["llm"]:
        for text in tqdm(sampled['text'], total=n_samples):
            augmented_texts.append(paraphrase_llm(text))
    
        sampled['text'] = augmented_texts

        data = pd.concat([data, sampled], ignore_index=True)

### Style Transfer

In [9]:
if config["augmentation"]["style_transfer"]["enabled"]:
    n_samples = int(max(config["augmentation"]["paraphrasing"]["n_samples"] * len(data), 1))
    target_styles = [
        "Neutral journalistic",
        "Slightly formal",
        "Slightly informal",
        "Objective report",
        "Simplified for general audience",
        "Summarized headline-style",
        "Explanatory tone",
        "Bullet-point format", #TODO change?
        "Skeptical tone",
        "Curious/inquisitive tone"]

    sampled = data.sample(n=n_samples, random_state=42).copy()
    augmented_texts = []

    if config["augmentation_methods"]["llm"]:
        for text in tqdm(sampled['text'], total=n_samples):
            augmented_texts.append(transfer_style_llm(text, choice(target_styles)))
    
        sampled['text'] = augmented_texts

        data = pd.concat([data, sampled], ignore_index=True)

## Data Preprocessing

### Text Cleaning

In [9]:
def clean_text(text, config):
    if config.get("strip_html"):
        text = remove_html(text)
    if config.get("remove_urls"):
        text = remove_urls(text)
    if config.get("remove_digits"):
        text = remove_digits(text)
    if config.get("remove_punctuation"):
        text = remove_punctuation(text)
    if config.get("remove_special_chars"):
        text = remove_special_characters(text)
    if config.get("lowercase"):
        text = to_lowercase(text)
    if config.get("normalize_whitespace"):
        text = normalize_whitespace(text)
    
    return text

In [10]:
data["text"] = data["text"].apply(lambda x: clean_text(x, config["preprocessing"]))

Stripping out HTML tags…
Erasing any URLs…
Removing all digits…
Dropping punctuation marks…
Filtering out special characters…
Converting text to lowercase…
Normalizing whitespace…
Stripping out HTML tags…
Erasing any URLs…
Removing all digits…
Dropping punctuation marks…
Filtering out special characters…
Converting text to lowercase…
Normalizing whitespace…
Stripping out HTML tags…
Erasing any URLs…
Removing all digits…
Dropping punctuation marks…
Filtering out special characters…
Converting text to lowercase…
Normalizing whitespace…
Stripping out HTML tags…
Erasing any URLs…
Removing all digits…
Dropping punctuation marks…
Filtering out special characters…
Converting text to lowercase…
Normalizing whitespace…
Stripping out HTML tags…
Erasing any URLs…
Removing all digits…
Dropping punctuation marks…
Filtering out special characters…
Converting text to lowercase…
Normalizing whitespace…
Stripping out HTML tags…
Erasing any URLs…
Removing all digits…
Dropping punctuation marks…
Filterin

## Save Output

In [11]:
augmented = any([config["augmentation"]["synonym_replacement"]["enabled"],
                config["augmentation"]["paraphrasing"]["enabled"],
                config["augmentation"]["style_transfer"]["enabled"]])

data.to_parquet(f"{config['output']}/preprocessed_data{'_augmented' if augmented else ''}.parquet", index=False)