In [None]:
# Auto-reload imported modules (e.g. the aymurai package) so source edits are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import srsly
import random

from rich.pretty import pprint
from datasets import load_from_disk

from aymurai.data_augmentation import DataAugmenter

from aymurai.data_augmentation.anonymizer_entities import (
    augmentation_functions,
    faker,
)

In [None]:
# Seed the shared Faker instance. `None` seeds from system entropy, so names change on
# every run; set an integer (e.g. 42) for a reproducible sequence.
FAKER_SEED = None
faker.seed_instance(FAKER_SEED)

# Quick sanity check: print a handful of generated names, one per line.
sample_names = [faker.name() for _ in range(10)]
print("\n".join(sample_names))

## Dataset

In [None]:
# `load_from_disk` and `srsly` are already imported in the imports cell at the top.
DATASET_NAME = "/resources/data/restricted/anonymization/datasets/anonymization-dataset-pruned-2023-09-06"
dataset = load_from_disk(DATASET_NAME)

# The label mapping is stored next to the dataset as `label_mapping.json` (label -> code).
with open(f"{DATASET_NAME}/label_mapping.json", encoding="utf-8") as file:
    label2code = srsly.json_loads(file.read())
# Inverse mapping, used to decode the integer `tags` of each example into label names.
code2label = {code: label for label, code in label2code.items()}

print(dataset)
print("nlabels:", len(code2label))

In [None]:
# Work on the training split only; the augmented datasets below are generated from it.
train = dataset["train"]
train

In [None]:
def has_labels(example):
    """Return True when the example's first ``n_labels`` entry is positive.

    Presumably this keeps only examples with at least one entity label — confirm against
    how ``n_labels`` is built in the dataset.
    """
    return 0 < example["n_labels"][0]


train_labeled = train.filter(has_labels)

In [None]:
# Inspect the filtered split.
train_labeled

In [None]:
# Augmenter used by the exploration cells below (random_state presumably seeds its RNG).
data_augmenter = DataAugmenter(
    code2label=code2label,
    augmentation_functions=augmentation_functions,
    random_state=42,
)

In [None]:
# Fixed example used to demo augmentation; use `random.choice(train)` for a random one.
SAMPLE_IDX = 2747
sample = train[SAMPLE_IDX]
print(sample["tokens"])

In [None]:
# Augment the demo example. To reset Faker's seed first, uncomment the next line.
# faker.seed_instance(42)

augmented_sample = data_augmenter.augment_sample(sample)
augmented_tokens = augmented_sample["tokens"]
print(augmented_tokens)

## Augmentation

In [None]:
# Build a fresh augmenter for the bulk augmentation, so results do not depend on how much
# the exploration cells above used the previous instance (presumably its RNG state).
augmenter_config = {
    "code2label": code2label,
    "augmentation_functions": augmentation_functions,
    "random_state": 42,
}
data_augmenter = DataAugmenter(**augmenter_config)

In [None]:
import numpy as np
import pandas as pd

In [None]:
# `len()` of a DatasetDict is its number of splits, not rows; show both.
{"n_splits": len(dataset), "rows_per_split": dataset.num_rows}

In [None]:
from datasets import DatasetDict, concatenate_datasets

# Reset the augmenter's seed attribute so re-running this cell reproduces the same splits
# (assumes DataAugmenter draws its randomness from `random_state` — TODO confirm).
data_augmenter.random_state = 42

datadict = DatasetDict()

# Label-weighted splits; `frac` scales the output relative to `train_labeled`
# (split names encode the approximate size obtained when they were created).
datadict["rebalanced-7k"] = data_augmenter.augment_dataset(
    train_labeled, frac=1, weighted=True
)
datadict["rebalanced-3k"] = data_augmenter.augment_dataset(
    train_labeled, frac=0.5, weighted=True
)

# Unweighted splits: `copies` augmented passes over every labelled example.
# Split names ("inbalanced-...") are kept unchanged since downstream code may reference them.
for copies in range(1, 5):
    tag = len(train_labeled) * copies // 1000
    name = f"inbalanced-{tag}k"
    # Two `copies` values rounding to the same tag would otherwise silently overwrite a split.
    assert name not in datadict, f"duplicate split name {name!r}"
    print(f"generating {name}")
    datadict[name] = concatenate_datasets(
        [
            # The transform is random: never reuse a cached `map` result, or the copies
            # could come back identical whenever the transform fingerprint matches.
            train_labeled.map(data_augmenter.augment_sample, load_from_cache_file=False)
            for _ in range(copies)
        ]
    )

datadict

## Check augmented examples

In [None]:
# `random` is already imported in the imports cell at the top.
from aymurai.utils.display.pandas import pandas_context

# Display options forwarded to `pandas_context` for the token/label tables below.
options = {
    "display.max_columns": 0,
    "display.max_colwidth": 0,
}

revision = datadict["rebalanced-7k"]


def tokens_frame(example):
    """Return a DataFrame with one row per token: decoded label name and token text."""
    return pd.DataFrame(
        {
            "labels": [code2label[code] for code in example["tags"]],
            "tokens": example["tokens"],
        }
    )


# Pick a random augmented example and look up the original it was generated from.
item = random.choice(revision)
original = train_labeled.filter(lambda x: x["hash"] == item["original_hash"])[0]

with pandas_context(**options):
    display(tokens_frame(original).T)
# A bare `original` in the middle of a cell is never rendered; display it explicitly.
display(original)

# All augmented variants generated from that original.
samples = revision.filter(lambda x: x["original_hash"] == original["hash"])

print("total augmented:", len(samples))
# Loop variable is not named `sample`, so the demo `sample` defined earlier is not clobbered.
for augmented in samples:
    with pandas_context(**options):
        display(tokens_frame(augmented).T)

In [None]:
from aymurai.data_augmentation.utils import compute_label_counts

In [None]:
# Label counts per split: the original training split first (columns sorted by frequency),
# then a separator row, then every augmented split.
original_counts = pd.DataFrame(
    compute_label_counts(dataset=train, code2label=code2label),
    index=pd.Index(["full original"]),
).sort_values(axis=1, by="full original", ascending=False)

# Visual separator row with the same columns, every cell "-".
sep = original_counts.map(lambda _: "-")
sep.index = pd.Index(["-"])

# Loop variable is deliberately not `dataset`: that name holds the original DatasetDict,
# and rebinding it here would break earlier cells on re-run.
augmented_counts = [
    pd.DataFrame(
        compute_label_counts(dataset=split, code2label=code2label),
        index=pd.Index([name]),
    )
    for name, split in datadict.items()
]

# Concatenate once rather than growing a frame (from an empty one) inside a loop.
total_counts = pd.concat([original_counts, sep, *augmented_counts])

with pandas_context(**options):
    display(total_counts)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

fig, ax = plt.subplots(1, 1, figsize=(10, 2))

# For each augmented split, count the rows generated from each original example.
# Only the `original_hash` column is materialized instead of the whole split.
per_split_counts = []
for name, split in datadict.items():  # not `dataset`: that name holds the original DatasetDict
    hash_count = (
        pd.Series(list(split["original_hash"]), name="original_hash")
        .value_counts()  # pandas >= 2: result column is named "count"
        .reset_index()
    )
    hash_count["dataset"] = name
    per_split_counts.append(hash_count)
df = pd.concat(per_split_counts, ignore_index=True)

sns.boxplot(data=df, x="count", y="dataset", ax=ax)

ax.set(
    xlabel="augmented rows per original example",
    ylabel="split",
    title="Augmentation multiplicity per original example",
)
ax.grid(visible=True, which="both")
ax.set_xscale("symlog")
# Minor tick at every integer; must come after `set_xscale`, which resets the locators.
ax.xaxis.set_minor_locator(MultipleLocator(1))
ax.set_xlim(left=0)
plt.show()

## Save augmented dataset

In [None]:
AUGMENTED__DATASET_NAME = "/resources/data/restricted/anonymization/datasets/anonymization-dataset-augmented-2023-09-06"

datadict.save_to_disk(AUGMENTED__DATASET_NAME)

# Store the label mapping next to the splits, mirroring the source dataset layout
# (`<dataset>/label_mapping.json`, read back with `srsly.json_loads` when loading).
with open(f"{AUGMENTED__DATASET_NAME}/label_mapping.json", "w", encoding="utf-8") as file:
    # Not named `json`, to avoid shadowing the stdlib module in the notebook namespace.
    label_mapping_json = srsly.json_dumps(label2code)
    file.write(label_mapping_json)