## Utilities

In [37]:
%run ../common_functions.py

read_dataset()
read_all_datasets()
read_processed_data()
combine_text_columns()
split_train_test()
evaluate_model()
detect_missing_values()


## Imports

In [38]:
from pathlib import Path
from random import choice

import numpy as np
import pandas as pd
import yaml
from tqdm.auto import tqdm

In [39]:
from data_preprocessing.text_cleaning import (
    remove_html,
    remove_urls,
    remove_digits,
    remove_punctuation,
    remove_special_characters,
    normalize_whitespace,
    to_lowercase
)
from data_preprocessing.tokenizer import tokenize_text
from data_preprocessing.feature_extraction import extract_features

In [40]:
from data_augmentation.synonym_replacement import replace_synonyms_database, replace_synonyms_llm
from data_augmentation.paraphrasing import paraphrase_llm
from data_augmentation.style_transfer import transfer_style_llm

## Load config

In [41]:
# Load the pipeline configuration (augmentation switches, preprocessing flags,
# tokenizer settings) consumed by the cells below.
# YAML files are conventionally UTF-8; pass the encoding explicitly so the read
# does not depend on the platform's locale default (e.g. cp1252 on Windows).
with open("config.yml", "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

## Load Data

In [42]:
# Row cap for quick debugging runs; set to None to process the full dataset.
# NOTE(review): any cap here trims every downstream output, including the
# saved .npz file — confirm the cap is intended before a real run.
DEBUG_MAX_ROWS = 10

data = read_all_datasets()
if DEBUG_MAX_ROWS is not None:
    data = data.head(DEBUG_MAX_ROWS)

# Orchestration

## Data Augmentation

### Synonym Replacement

In [43]:
synonym_cfg = config["augmentation"]["synonym_replacement"]
if synonym_cfg["enabled"]:
    # `n_samples` is a fraction of the current dataset size. Augment at least
    # one row, but never request more rows than exist: `DataFrame.sample`
    # without replacement raises ValueError in that case.
    n_samples = min(int(max(synonym_cfg["n_samples"] * len(data), 1)), len(data))
    # Words to swap per text; defaults to the previously hard-coded value of 2.
    n_replacements = synonym_cfg.get("n_replacements", 2)

    sampled = data.sample(n=n_samples, random_state=42).copy()
    augmented_texts = []

    # The local synonym database takes precedence; the LLM backend is only
    # used when the database method is disabled.
    if config["augmentation_methods"]["local_database"]:
        augmented_texts = [
            replace_synonyms_database(text, n_replacements)
            for text in tqdm(sampled["text"], total=n_samples)
        ]
    elif config["augmentation_methods"]["llm"]:
        augmented_texts = [
            replace_synonyms_llm(text, n_replacements)
            for text in tqdm(sampled["text"], total=n_samples)
        ]

    # Only append when a method actually produced new texts; the sampled rows
    # keep their original labels.
    if augmented_texts:
        sampled["text"] = augmented_texts
        data = pd.concat([data, sampled], ignore_index=True)

  0%|          | 0/1 [00:00<?, ?it/s]

### Paraphrasing

In [44]:
paraphrasing_cfg = config["augmentation"]["paraphrasing"]
if paraphrasing_cfg["enabled"]:
    # Fraction of the current dataset to paraphrase, clamped to [1, len(data)]
    # so `DataFrame.sample` (no replacement) cannot raise for fractions > 1.
    # Rows appended by earlier augmentation cells are eligible for sampling.
    n_samples = min(int(max(paraphrasing_cfg["n_samples"] * len(data), 1)), len(data))

    sampled = data.sample(n=n_samples, random_state=42).copy()

    # Paraphrasing is only implemented through the LLM backend.
    if config["augmentation_methods"]["llm"]:
        augmented_texts = [
            paraphrase_llm(text)
            for text in tqdm(sampled["text"], total=n_samples)
        ]
        sampled["text"] = augmented_texts
        data = pd.concat([data, sampled], ignore_index=True)

  0%|          | 0/1 [00:00<?, ?it/s]

### Style Transfer

In [45]:
style_cfg = config["augmentation"]["style_transfer"]
if style_cfg["enabled"]:
    # Fix: the sample fraction was previously read from the *paraphrasing*
    # section (copy-paste slip). Fall back to that value only when the
    # style_transfer section does not define its own `n_samples`.
    sample_frac = (
        style_cfg["n_samples"]
        if "n_samples" in style_cfg
        else config["augmentation"]["paraphrasing"]["n_samples"]
    )
    # Clamp to [1, len(data)]: `DataFrame.sample` without replacement raises
    # ValueError if asked for more rows than exist.
    n_samples = min(int(max(sample_frac * len(data), 1)), len(data))

    target_styles = [
        "Neutral journalistic",
        "Slightly formal",
        "Slightly informal",
        "Objective report",
        "Simplified for general audience",
        "Summarized headline-style",
        "Explanatory tone",
        "Bullet-point format",
        "Skeptical tone",
        "Curious/inquisitive tone"]

    # Seeded generator so the style assigned to each sampled row is
    # reproducible across runs; the unseeded `random.choice` was not.
    style_rng = np.random.default_rng(42)

    sampled = data.sample(n=n_samples, random_state=42).copy()

    # Style transfer is only implemented through the LLM backend.
    if config["augmentation_methods"]["llm"]:
        augmented_texts = [
            transfer_style_llm(text, str(style_rng.choice(target_styles)))
            for text in tqdm(sampled["text"], total=n_samples)
        ]
        sampled["text"] = augmented_texts
        data = pd.concat([data, sampled], ignore_index=True)

  0%|          | 0/1 [00:00<?, ?it/s]

## Data Preprocessing

### Text Cleaning

In [46]:
def clean_text(text, config):
    """Run ``text`` through the cleaning steps switched on in ``config``.

    Steps run in a fixed order: HTML stripping, URL removal, digit removal,
    punctuation removal, special-character removal, lowercasing, and
    whitespace normalisation. A step is applied only when its flag in
    ``config`` is truthy; absent flags count as disabled.

    Args:
        text: the text to clean (presumably a ``str``; it is handed directly
            to the cleaning helpers).
        config: mapping of flag name -> bool, e.g. ``config["preprocessing"]``.

    Returns:
        The cleaned text, or ``text`` unchanged when no flag is enabled.
    """
    # Ordered (flag, step) pairs; the order matters, e.g. whitespace is
    # normalised last so gaps left by earlier removals are collapsed.
    pipeline = (
        ("strip_html", remove_html),
        ("remove_urls", remove_urls),
        ("remove_digits", remove_digits),
        ("remove_punctuation", remove_punctuation),
        ("remove_special_chars", remove_special_characters),
        ("lowercase", to_lowercase),
        ("normalize_whitespace", normalize_whitespace),
    )

    cleaned = text
    for flag, step in pipeline:
        if config.get(flag):
            cleaned = step(cleaned)
    return cleaned

In [47]:
# Clean every row's text with the flags from the `preprocessing` config section.
preprocessing_cfg = config["preprocessing"]
data["text"] = data["text"].apply(clean_text, args=(preprocessing_cfg,))

### Tokenization

In [48]:
# Tokenize the cleaned texts with the tokenizer and max sequence length
# configured in the `tokenization` section.
tokenization_cfg = config["tokenization"]
data_tokenized = tokenize_text(
    data,
    tokenizer_name=tokenization_cfg["tokenizer_model"],
    max_length=tokenization_cfg["sequence_max_length"],
)

Tokenizing: 100%|██████████| 13/13 [00:00<00:00, 7518.75it/s]


### Feature Extraction

In [49]:
# Features are computed with the same model name that was used for
# tokenization (both read `config["tokenization"]["tokenizer_model"]`).
FEATURE_BATCH_SIZE = 32
X = extract_features(
    data_tokenized,
    model_name=config["tokenization"]["tokenizer_model"],
    batch_size=FEATURE_BATCH_SIZE,
)

Extracting features: 100%|██████████| 1/1 [00:01<00:00,  1.11s/it]


### Match labels

In [50]:
# Labels in the row order of `data`, the frame that was tokenized and
# featurized above.
y = data["label"].to_numpy(dtype=np.int32)

# Guard against features and labels drifting out of sync, e.g. when an
# augmentation cell is re-run after features were extracted.
# NOTE(review): assumes extract_features returns one row per input text — confirm.
assert len(y) == len(X), f"feature/label count mismatch: {len(X)} features vs {len(y)} labels"

## Save Output

In [51]:
# Tag the output file when any augmentation method is enabled in config.
# TODO(review): this reflects the config switches, not whether rows were
# actually added (e.g. style_transfer enabled with the LLM backend disabled).
augmentation_cfg = config["augmentation"]
augmented = any(
    augmentation_cfg[method]["enabled"]
    for method in ("synonym_replacement", "paraphrasing", "style_transfer")
)

output_path = Path("output") / f"preprocessed_data{'_augmented' if augmented else ''}.npz"
# Create the output directory on a fresh checkout instead of failing on save.
output_path.parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(output_path, X=X, y=y)