In [None]:
# Reload edited project modules (e.g. aymurai.data_augmentation) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import srsly
import random

from rich.pretty import pprint
from datasets import load_from_disk

from aymurai.data_augmentation import DataAugmenter

from aymurai.data_augmentation.anonymizer_entities import (
    augmentation_functions,
    faker,
    Faker,
)

In [None]:
# Smoke-test the shared Faker instance. Seeding with None reseeds from fresh
# randomness, so the names change on every run; use 42 for a repeatable list.
faker.seed_instance(None)
for _ in range(10):
    print(faker.name())

## Dataset

In [None]:
from datasets import load_from_disk
import srsly

# Pruned anonymization dataset, loaded with `load_from_disk` and indexed by split below.
DATASET_NAME = "/resources/data/restricted/anonymization/datasets/anonymization-dataset-pruned-2023-09-06"
dataset = load_from_disk(DATASET_NAME)

# label_mapping.json (stored beside the dataset) maps label name -> tag code;
# code2label is its inverse and is used throughout to decode the integer `tags`.
# JSON is UTF-8 by spec, so don't depend on the platform's default encoding.
with open(f"{DATASET_NAME}/label_mapping.json", encoding="utf-8") as file:
    label2code = srsly.json_loads(file.read())
code2label = {code: label for label, code in label2code.items()}

print(dataset)
print("nlabels:", len(code2label))

In [None]:
# Work on the training split only.
train = dataset["train"]
train

In [None]:
# Keep rows whose `n_labels[0]` is positive — presumably the examples that carry
# at least one entity annotation (TODO confirm against the dataset builder).
train_labeled = train.filter(lambda example: example["n_labels"][0] > 0)

In [None]:
# Display the filtered split.
train_labeled

In [None]:
# Augmenter for the single-sample demo, constructed with random_state=42.
# NOTE(review): the demo cell below re-creates it with identical arguments, so this cell is redundant.
data_augmenter = DataAugmenter(code2label, random_state=42)

In [None]:
# Fixed row index keeps the demo repeatable; use the commented line instead to
# inspect an arbitrary row.
# sample = random.choice(train)
sample = train[2747]
print(sample["tokens"])

In [None]:
# Re-create the augmenter (random_state=42) right before augmenting the demo sample
# and print the augmented tokens for comparison with the original ones above.
# NOTE(review): the shared `faker` was seeded with None earlier; if DataAugmenter draws
# from it, outputs may still vary between runs unless the commented reseed is enabled.
data_augmenter = DataAugmenter(code2label, random_state=42)
# faker.seed_instance(42)
augmented_sample = data_augmenter.augment(sample)
print(augmented_sample["tokens"])

## Augmentation

In [None]:
# Augmenter applied to the whole train split below. Constructed the same way as every
# other DataAugmenter in this notebook (code2label + random_state); the extra positional
# `augmentation_functions` argument previously passed here does not fit that call shape.
# NOTE(review): confirm against DataAugmenter.__init__.
data_augmenter = DataAugmenter(code2label, random_state=42)

In [None]:
import numpy as np
import pandas as pd

In [None]:
def anonymize(example):
    """Augment a single example using the notebook-level ``data_augmenter``.

    Thin adapter so ``Dataset.map`` can apply ``DataAugmenter.augment`` row by row.
    """
    augmented_example = data_augmenter.augment(example)
    return augmented_example


anonymized_train = train.map(anonymize)
anonymized_train

In [None]:
import re

from tqdm.auto import tqdm


def compute_label_weights(
    dataset,
    ignore_labels: list[str] | tuple[str, ...] = ("O", "PER", "FECHA"),
) -> dict[str, float]:
    """Compute inverse-frequency weights for entity labels.

    Tag codes are decoded with the notebook-level ``code2label`` and the
    ``B-``/``I-`` markers are stripped, so both parts of a span count toward
    the same label. Each kept label gets ``total / count(label)``, rescaled so
    the most frequent kept label weighs exactly 1.0 (rarer labels weigh more).

    Args:
        dataset: sized iterable of examples with a ``"tags"`` sequence of codes.
        ignore_labels: labels excluded from the weighting; labels that never
            occur in ``dataset`` are skipped silently.

    Returns:
        Mapping ``label -> weight`` (all >= 1.0); empty when no kept label occurs.
    """
    per_example_counts = []
    for example in tqdm(dataset, total=len(dataset)):
        labels = [re.sub(r"[BI]-", "", code2label[code]) for code in example["tags"]]
        names, freqs = np.unique(labels, return_counts=True)
        per_example_counts.append(dict(zip(names, freqs)))

    # Per-label totals; labels absent from an example are NaN there and skipped by sum().
    totals = pd.DataFrame(per_example_counts).sum()
    # Honour `ignore_labels` (it used to be ignored in favour of a hard-coded list)
    # and don't fail when an ignored label never appears.
    totals = totals.drop(index=list(ignore_labels), errors="ignore")
    if totals.empty:
        return {}

    label_weights = totals.sum() / totals
    label_weights /= label_weights.min()
    return label_weights.to_dict()

In [None]:
# Inverse-frequency weight per entity label, computed over the full train split.
label_weights = compute_label_weights(train)

In [None]:
# Inspect the label -> weight mapping.
label_weights

In [None]:
def get_weight(example, weights=None, label_mapping=None):
    """Attach a sampling weight to ``example`` and return it.

    The weight is the largest label weight among the example's tags, after
    decoding codes and stripping ``B-``/``I-`` markers. Labels without a weight
    (e.g. ignored ones) count as 0, and an example with no tags gets 0 instead
    of raising.

    Args:
        example: mapping with a ``"tags"`` sequence of label codes; mutated in
            place (a ``"weight"`` key is added).
        weights: ``label -> weight`` mapping; defaults to the notebook-level
            ``label_weights``.
        label_mapping: ``code -> label`` mapping; defaults to the notebook-level
            ``code2label``.

    Returns:
        The same ``example`` object, for use with ``Dataset.map``.
    """
    weights = label_weights if weights is None else weights
    label_mapping = code2label if label_mapping is None else label_mapping

    labels = (re.sub(r"[BI]-", "", label_mapping[code]) for code in example["tags"])
    example["weight"] = max((weights.get(label, 0) for label in labels), default=0)
    return example


# Preview the per-example "weight" column. Not used further below:
# augment_dataset recomputes the weights on the dataset it receives.
wtrain = train.map(get_weight)
wtrain

In [None]:
from datasets import Dataset
from enum import Enum, auto


def augment_dataset(
    dataset,
    frac: float = 1,
    weights: str | None = "labels_max",
    random_state: int | None = None,
):
    """Resample ``dataset`` with replacement and augment every drawn example.

    Args:
        dataset: ``datasets.Dataset`` whose rows ``DataAugmenter.augment`` accepts
            (and, for ``"labels_max"``, that ``get_weight`` accepts).
        frac: resample size relative to ``len(dataset)``; values > 1 oversample.
        weights: ``"labels_max"`` draws each row with probability proportional to
            the largest label weight it contains (see ``get_weight``); any other
            value, including ``None``, samples uniformly.
        random_state: seed forwarded to both the row resampling and the
            DataAugmenter (previously the augmenter was always unseeded).
            NOTE(review): full reproducibility also assumes DataAugmenter honours
            its random_state rather than a globally seeded faker.

    Returns:
        A new ``datasets.Dataset`` of augmented, resampled rows.
    """
    if weights == "labels_max":
        sample_weights = dataset.map(get_weight)["weight"]
    else:
        sample_weights = None

    resampled_df = dataset.to_pandas().sample(
        frac=frac,
        weights=sample_weights,
        replace=True,
        random_state=random_state,
    )
    # preserve_index=False drops the shuffled pandas index instead of storing it
    # as an "__index_level_0__" column that would have to be removed afterwards.
    resampled = Dataset.from_pandas(resampled_df, preserve_index=False)

    augmenter = DataAugmenter(code2label, random_state=random_state)
    return resampled.map(augmenter.augment)

In [None]:
# Weighted resample (favouring rows with rarer labels) of the labelled train rows,
# with every drawn row augmented; unseeded, so it differs on each run.
aug_train = augment_dataset(
    train_labeled,
    frac=1,
    weights="labels_max",
    random_state=None,
)
aug_train

In [None]:
import random
from aymurai.utils.display.pandas import pandas_context

# pandas display options applied only while rendering the tables below.
options = {"display.max_columns": 0}

train_df = train_labeled.to_pandas()
aug_train_df = aug_train.to_pandas()

# Draw one augmented row at random and locate the labelled row it came from,
# matching the augmented row's `original_hash` against the source `hash`.
item = random.choice(aug_train)
index = train_df.index[train_df["hash"] == item["original_hash"]][0]

aux = train_labeled[int(index)]
original_hash = aux.pop("hash")
print("nlabels:", aux.pop("n_labels"))
print("original:", original_hash)
with pandas_context(**options):
    display(pd.DataFrame(aux).T)

# Every augmented row generated from that same source row.
samples = aug_train.filter(lambda row: row["original_hash"] == original_hash)
print("total augmented:", len(samples))

for i, sample in enumerate(samples):
    # Drop bookkeeping fields before tabulating the remaining columns.
    sample.pop("original_hash")
    sample.pop("n_labels")
    print(f"augmented {i:03.0f}:", sample.pop("hash"))
    with pandas_context(**options):
        display(pd.DataFrame(sample).T)

In [None]:
import re

from tqdm import tqdm

# Per-example token counts of each label (B-/I- markers stripped) in the augmented
# split; a label absent from an example is NaN in that row.
records = []
for tags in tqdm(aug_train["tags"], desc="counting labels"):
    labels = [re.sub(r"[BI]-", "", code2label[code]) for code in tags]
    names, freqs = np.unique(labels, return_counts=True)
    records.append(dict(zip(names, freqs)))

# Build the frame once: concatenating one row per iteration is quadratic.
df = pd.DataFrame(records)

In [None]:
# Total token count per label across the augmented split, excluding the outside
# tag "O" (tolerated if absent). The unused `_df` copy has been removed.
counts = df.sum().drop("O", errors="ignore")
counts

In [None]:
# Number of augmented examples counted above (one row per example).
len(df)

In [None]:
from datasets import DatasetDict

# Bundle both derived splits for saving: `anonymized_train` (every train row
# augmented once) and `augmented_train` (weighted resample of the labelled rows,
# each augmented).
datadict = DatasetDict(
    {
        "anonymized_train": anonymized_train,
        "augmented_train": aug_train,
    }
)
datadict

In [None]:
# NOTE(review): unlike DATASET_NAME this path has no `datasets/` segment — confirm
# this is the intended destination.
AUGMENTED_DATASET_PATH = "/resources/data/restricted/anonymization/anonymization-dataset-augmented-2023-09-06"
datadict.save_to_disk(AUGMENTED_DATASET_PATH)

# Ship the label mapping with the data: this notebook decodes `tags` by reading
# `label_mapping.json` from the dataset directory, so the saved copy needs it too.
srsly.write_json(f"{AUGMENTED_DATASET_PATH}/label_mapping.json", label2code)