In [None]:
import pandas as pd

In [None]:
# Directory holding the space-separated token/label split files read below
# (`train-ready.txt`, `dev-ready.txt`).
DATASET_PATH = "/workspace/resources/data/restricted/anonymization/data-splits-2.0"

In [None]:
def read_token_label_file(path):
    """Read a space-separated ``token label`` file, one token per line.

    Blank lines are kept (``skip_blank_lines=False``) and show up as rows whose
    ``token`` is NaN; downstream code uses those rows as paragraph separators.

    Only the empty string is treated as a missing value. With pandas' default
    NA list, literal tokens such as "NA", "null", "None" or "nan" would be
    parsed as NaN and mistaken for paragraph separators. ``dtype=str`` keeps
    numeric-looking tokens (e.g. "2023", "1.0") as the exact text in the file.
    """
    return pd.read_csv(
        path,
        sep=" ",
        names=["token", "label"],
        skip_blank_lines=False,
        dtype=str,
        keep_default_na=False,
        na_values=[""],
    )


train = read_token_label_file(f"{DATASET_PATH}/train-ready.txt")
val = read_token_label_file(f"{DATASET_PATH}/dev-ready.txt")

# NOTE: the test split is disabled; OUTPUT_DIR is not defined in this notebook.
# test = read_token_label_file(f"{OUTPUT_DIR}/test.txt")

# Number of rows per label in the training split.
cat_count = dict(train["label"].value_counts())

# Labels in order of first appearance in train; NaN (separator rows) excluded.
categories = {v: k for k, v in enumerate(train["label"].unique()) if not pd.isna(v)}

# Re-enumerate so the codes are contiguous (0..n-1) after dropping NaN.
label2code = {label: code for code, label in enumerate(categories)}
code2label = {code: label for label, code in label2code.items()}

print("train:", len(train))
print("val:", len(val))
# print("test:", len(test))
print("nlabels:", len(code2label))

In [None]:
# List the label names ordered by the text after their first two characters,
# so labels that differ only in a two-character prefix end up side by side.
sorted(label2code.keys(), key=lambda label: label[2:])

In [None]:
import numpy as np
from more_itertools import unzip
from tqdm.auto import tqdm
from more_itertools import split_at
from functools import cache


def get_hg_format(df):
    """Convert a token/label frame into a list of per-paragraph records.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a ``token`` and a ``label`` column, one token per row.
        Rows whose ``token`` is NaN act as paragraph separators. The frame is
        not modified.

    Returns
    -------
    list[dict]
        One dict per non-empty paragraph with the keys:

        * ``"n_labels"``: number of tokens whose tag is not the "O" tag,
          stored as a plain int (not a 1-tuple).
        * ``"tokens"``: list of the paragraph's tokens.
        * ``"tags"``: integer numpy array of label codes taken from the
          module-level ``label2code`` mapping.

    Notes
    -----
    Labels missing from ``label2code`` (including NaN) are mapped to the "O"
    code. A warning is emitted when a non-NaN label is unknown, so that label
    loss is not silent.
    """
    import warnings

    outside_code = label2code["O"]

    labels = df["label"]
    unknown = set(labels.dropna()) - set(label2code)
    if unknown:
        warnings.warn(
            "labels missing from label2code are mapped to 'O': "
            f"{sorted(map(str, unknown))}"
        )

    codes = labels.map(lambda label: label2code.get(label, outside_code))

    # Pair tokens with codes directly; this avoids the per-row Series that
    # `iterrows` builds and also works for an empty frame.
    pairs = zip(df["token"], codes)

    items = []
    for paragraph in split_at(pairs, lambda pair: pd.isna(pair[0])):
        # Consecutive separators (or a leading/trailing one) yield empty groups.
        if not paragraph:
            continue

        tokens, tags = zip(*paragraph)
        tags = np.asarray(tags, dtype=int)

        items.append(
            {
                # Compare against the "O" code instead of 0: codes follow the
                # order of first appearance, so "O" is not guaranteed to be 0.
                "n_labels": int(np.count_nonzero(tags != outside_code)),
                "tokens": list(tokens),
                "tags": tags,
            }
        )

    return items

In [None]:
from tqdm import tqdm
from datasets import DatasetDict, Dataset

# Build one HuggingFace Dataset per split from the paragraph records.
# The test split is currently disabled:
#   "test": Dataset.from_list(get_hg_format(test))
split_frames = {"train": train, "validation": val}

dataset = DatasetDict(
    {
        split_name: Dataset.from_list(get_hg_format(frame))
        for split_name, frame in split_frames.items()
    }
)
dataset

### drop duplicates

In [None]:
import hashlib


def stable_hash(tokens):
    """Return a reproducible signed 64-bit hash of a paragraph's tokens.

    The builtin ``hash`` is salted per Python process for strings, so its
    values change between kernel restarts and cannot be meaningfully stored.
    Tokens are joined with a single space; each token is passed through
    ``str`` so a non-string token cannot turn the joined text into NaN.
    """
    text = " ".join(map(str, tokens))
    digest = hashlib.blake2b(text.encode("utf-8"), digest_size=8).digest()
    # Signed so the value fits in an int64 column.
    return int.from_bytes(digest, "big", signed=True)


df_train = dataset["train"].to_pandas()
df_dev = dataset["validation"].to_pandas()
# df_test = dataset["test"].to_pandas()

# hash each paragraph so duplicates can be compared quickly
df_train["hash"] = df_train["tokens"].apply(stable_hash)
df_dev["hash"] = df_dev["tokens"].apply(stable_hash)
# df_test["hash"] = df_test["tokens"].apply(stable_hash)

# drop duplicated paragraphs within each split (first occurrence is kept)
df_train = df_train.drop_duplicates(subset="hash")
df_dev = df_dev.drop_duplicates(subset="hash")
# df_test = df_test.drop_duplicates(subset="hash")

# hashes present in each split, used to detect paragraphs shared across splits
train_hash = set(df_train["hash"])
dev_hash = set(df_dev["hash"])

In [None]:
from aymurai.utils.display.pandas import pandas_context

options = {"display.max_colwidth": 0}

with pandas_context(**options):
    outside_code = label2code["O"]

    # Training paragraphs that also appear in the validation split.
    # `.assign` returns a new frame, so no column is written onto a `query`
    # result (which raises SettingWithCopyWarning).
    aux = df_train.query("hash in @train_hash and hash in @dev_hash").assign(
        # Count the tags that are not "O". Summing the raw codes is only a
        # valid "has labels" signal when "O" happens to be code 0.
        ntags=lambda frame: frame["tags"].apply(
            lambda tags: int(np.sum(np.asarray(tags) != outside_code))
        )
    )
    display(aux.query("ntags > 0"))
    # display(aux)

In [None]:
# Remove from the validation split every paragraph whose hash also occurs in
# the training split, so the two splits no longer share paragraphs.
df_dev = df_dev[~df_dev["hash"].isin(train_hash)]
# df_test = df_test[~df_test["hash"].isin(train_hash | dev_hash)]

In [None]:
# Rebuild the splits from the de-duplicated frames.
# After dropping rows the frames no longer have a plain RangeIndex, and
# `from_pandas` would then store the index as an extra `__index_level_0__`
# column; `preserve_index=False` keeps it out of the dataset.
dataset["train"] = Dataset.from_pandas(df_train, preserve_index=False)
dataset["validation"] = Dataset.from_pandas(df_dev, preserve_index=False)
# dataset["test"] = Dataset.from_pandas(df_test, preserve_index=False)

In [None]:
# Hashes present in both splits; an empty set means no paragraph is shared.
set(dataset["train"]["hash"]) & set(dataset["validation"]["hash"])

In [None]:
# Display the DatasetDict after the splits were rebuilt from the filtered frames.
dataset

## save dataset

In [None]:
import srsly

# Output directory for the pruned dataset.
# NOTE(review): this path is rooted at /resources while DATASET_PATH is rooted
# at /workspace/resources; confirm the difference is intended.
DATASET_NAME = (
    "/resources/data/restricted/anonymization/annonimization-dataset-pruned-2023-09-04"
)

# Persist the splits, then store the label -> code mapping next to them.
dataset.save_to_disk(DATASET_NAME)

label_mapping_json = srsly.json_dumps(label2code)
with open(f"{DATASET_NAME}/label_mapping.json", "w") as mapping_file:
    mapping_file.write(label_mapping_json)