In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from attlike.utils import (
    load_transformer,
    tokenize,
    perturb,
    dataset_format,
    data_info,
    load_data,
    LikelihoodMode,
    sample_encode,
)
import functools
from typing import Sequence
from tqdm import tqdm
from pathlib import Path
from datasets import DatasetDict, Dataset
from pathlib import Path
import spacy
from spacy import Language
import spacy.cli as spacy_down
from collections import Counter

In [None]:
import torch

# Device every encoder is moved to before running. Prefer the GPU when PyTorch
# can see one; otherwise fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing on `.to(DEVICE)`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
class SpacyManager:
    """Load spaCy pipelines by language code, downloading the model if it is missing."""

    # Maps a language code to the name of the spaCy model package to load.
    _lang2model = {
        "en": "en_core_web_sm",
    }

    @classmethod
    def instantiate(cls, language: str) -> Language:
        """Return the spaCy pipeline registered for ``language``.

        Args:
            language: A key of ``_lang2model`` (currently only ``"en"``).

        Returns:
            The loaded spaCy ``Language`` pipeline.

        Raises:
            KeyError: If ``language`` has no entry in ``_lang2model``.
        """
        model_name = cls._lang2model[language]

        try:
            return spacy.load(model_name)
        except OSError:
            # spaCy reports a model package that cannot be found with an OSError.
            # Only that case warrants a download; any other failure should
            # surface as-is instead of being masked by a download attempt.
            spacy_down.download(model_name)
            return spacy.load(model_name)

In [None]:
from nltk.corpus import wordnet as wn
from nltk.corpus.reader import Lemma, Synset

In [None]:
# Not referenced in the cells visible in this notebook; presumably the separator
# used when several lemmas are joined into one string — TODO confirm.
_LEMMA_SEP: str = ","

# One spaCy pipeline per supported language code.
lang2model = {
    language_code: SpacyManager.instantiate(language=language_code)
    for language_code in ["en"]
}

# Translates the language codes used as `lang2model` keys into the codes passed
# to WordNet's `lang=` arguments.
spacy2wn_lang = dict(en="eng")

In [None]:
from dataclasses import dataclass
import dataclasses


@dataclasses.dataclass(frozen=True)
class Sample:
    """Immutable record pairing a lemma with one sentence and its synset metadata."""

    # Identifier of the synset the sample belongs to.
    synset_id: str
    # Lemma associated with the sample.
    lemma: str
    # Sentence associated with the sample.
    sentence: str
    # Part-of-speech tag of the synset.
    pos: str

In [None]:
# Examples with fewer tokens than this are ignored.
_MIN_EXAMPLE_TOKENS = 5

# Maps a synset name to the list of `Sample`s harvested from that synset.
synset2samples = {}

# For each synset, pair every single-word lemma with every example sentence that
# is long enough and contains that lemma exactly once as a token.
for synset in tqdm(wn.all_synsets(), desc="Iterating synsets"):
    synset: Synset

    synset_samples = []
    for lang, spacy_model in lang2model.items():
        wn_lang: str = spacy2wn_lang[lang]
        examples = synset.examples(lang=wn_lang)
        lemmas = synset.lemma_names(lang=wn_lang)

        # Tokenize, length-filter and count each example once, rather than
        # repeating that work for every lemma of the synset.
        candidate_examples = []
        for example in examples:
            tokens = [token.text for token in spacy_model.tokenizer(example)]
            if len(tokens) < _MIN_EXAMPLE_TOKENS:
                continue
            candidate_examples.append((example, Counter(tokens)))

        for lemma in lemmas:
            # Lemma names containing an underscore are skipped.
            if "_" in lemma:
                continue
            for example, token_counts in candidate_examples:
                # Exact (case-sensitive) match on the token text; the lemma must
                # occur exactly once. A missing key counts as 0 in a Counter.
                if token_counts[lemma] == 1:
                    synset_samples.append(
                        Sample(
                            synset_id=synset.name(),
                            lemma=lemma,
                            sentence=example,
                            pos=synset.pos(),
                        )
                    )
    # Only synsets that produced at least one sample are recorded.
    if synset_samples:
        synset2samples[synset.name()] = synset_samples

In [None]:
# Total number of samples collected across all synsets.
sum(map(len, synset2samples.values()))

In [None]:
def _with_row_index(_batch, row_indices):
    """Return an ``index`` column holding the indices passed in for the batch."""
    return {"index": row_indices}


# Flatten the per-synset sample lists into one dict per sample, build the
# dataset from them, then attach each row's index as an "index" column.
flat_samples = [dataclasses.asdict(sample) for group in synset2samples.values() for sample in group]
real_dataset = Dataset.from_list(flat_samples).map(
    function=_with_row_index,
    batched=True,
    with_indices=True,
)
real_dataset

In [None]:
# Names passed to `load_transformer(transformer_name=...)`; the dataset is
# tokenized once per entry in this list.
encoders: Sequence[str] = [
    "bert-base-uncased",
    "roberta-base",
    "xlm-roberta-base",
    "roberta-large",
    "distilbert-base-uncased",
]

In [None]:
# Maps each encoder name to the dataset tokenized with that encoder's tokenizer.
encoder2data = {}
for encoder_name in encoders:
    # Only the tokenizer is used at this stage; the model is discarded.
    _, tokenizer = load_transformer(transformer_name=encoder_name)
    tokenize_batch = functools.partial(tokenize, tokenizer=tokenizer, encoder_name=encoder_name)

    tokenized_data = real_dataset.map(
        tokenize_batch,
        num_proc=1,
        batched=True,
        batch_size=1000,
        desc=f"{encoder_name} tokenization",
    )
    dataset_format(tokenized_data)

    encoder2data[encoder_name] = tokenized_data
encoder2data

In [None]:
def _has_lemma_span(start_lemma_index, end_lemma_index):
    """Tell whether neither lemma index equals the -1 sentinel."""
    return start_lemma_index != -1 and end_lemma_index != -1


# Drop the samples whose start or end lemma index is -1 (presumably the marker
# `tokenize` uses when it could not locate the lemma — TODO confirm).
encoder2data = {
    name: data.filter(
        function=_has_lemma_span,
        input_columns=["start_lemma_index", "end_lemma_index"],
    )
    for name, data in encoder2data.items()
}

In [None]:
# Sample indices that survived filtering for every encoder.
# Each value is converted to a plain `int` first: the next cell calls
# `index.item()`, so the column yields array scalars, and objects such as torch
# tensors hash by identity, which would make the intersection (and the later
# `in kept_indices` test against Python ints) compare objects instead of values.
kept_indices = set.intersection(
    *[{int(index) for index in encoder_data["index"]} for encoder_data in encoder2data.values()]
)
len(kept_indices)

In [None]:
def _is_kept(index):
    """Tell whether the sample's index is one of `kept_indices`."""
    return index.item() in kept_indices


# Restrict every encoder's dataset to the indices in `kept_indices`.
encoder2data = {
    name: data.filter(function=_is_kept, input_columns=["index"])
    for name, data in encoder2data.items()
}
for filtered_data in encoder2data.values():
    dataset_format(filtered_data)

encoder2data

In [None]:
# Run every encoder over its own tokenized copy of the dataset, one sample at a time.
for encoder_name, tokenized_data in encoder2data.items():
    encoder, tokenizer = load_transformer(transformer_name=encoder_name)
    encoder = encoder.to(DEVICE)
    encode_sample = functools.partial(
        sample_encode,
        encoder=encoder,
        tokenizer=tokenizer,
        encoder_name=encoder_name,
    )

    # Only the value of an existing key is replaced, so iterating over
    # `.items()` at the same time is safe.
    encoder2data[encoder_name] = tokenized_data.map(
        encode_sample,
        num_proc=1,
        batched=False,
        with_indices=True,
        desc=f"{encoder_name} sample encoding",
    )
    # Move the model back to the CPU before the next encoder is loaded.
    encoder.cpu()

encoder2data

In [None]:
from nn_core.common import PROJECT_ROOT

# Directory under the project root where this notebook saves and reloads its datasets.
DATA_DIR: Path = PROJECT_ROOT / "data"

In [None]:
# Bundle the per-encoder datasets into one DatasetDict and persist it under DATA_DIR.
real_data = DatasetDict(encoder2data)
real_data.save_to_disk(DATA_DIR.joinpath("real"))
real_data

In [None]:
# Reload the saved dataset from DATA_DIR.
real_data = load_data(DATA_DIR.joinpath("real"))
real_data

In [None]:
# Apply `data_info` to every encoder's dataset, in batches of 1000, feeding it
# the "attention", "start_lemma_index" and "end_lemma_index" columns.
for encoder_name, encoder_data in real_data.items():
    encoder_data = encoder_data.map(
        function=data_info,
        num_proc=8,
        batched=True,
        batch_size=1000,
        with_indices=True,
        desc=f"{encoder_name} data info",
        input_columns=["attention", "start_lemma_index", "end_lemma_index"],
    )
    dataset_format(encoder_data)
    # Only the value of an existing key is replaced while iterating.
    real_data[encoder_name] = encoder_data

# NOTE(review): this writes to PROJECT_ROOT / "real", not to the "data"
# directory that the dataset is loaded from elsewhere in this notebook —
# confirm which location is intended before relying on the saved copy.
real_data.save_to_disk(PROJECT_ROOT / "real")
real_data

In [None]:
# Reload from DATA_DIR instead of the relative "data/real" string, so the path
# does not depend on the kernel's working directory and matches the location
# the dataset was saved to with `DATA_DIR / "real"`.
real_data = load_data(DATA_DIR / "real")
real_data

In [None]:
# Every member of LikelihoodMode, in definition order.
likelihood_modes = [mode for mode in LikelihoodMode]

In [None]:
# Columns carried over from the encoded dataset; every other existing column is
# removed when `perturb` is applied.
_PRESERVED_COLUMNS = {
    "synset_id",
    "lemma",
    "sentence",
    "pos",
    "index",
    "sentence_ids",
    "sentence_special_mask",
    "attention_mask",
    "lemma_ids",
    "start_lemma_index",
    "end_lemma_index",
}

# Columns handed to `perturb`, in this order.
_PERTURB_INPUT_COLUMNS = [
    "sentence_ids",
    "attention_mask",
    "start_lemma_index",
    "end_lemma_index",
    "lemma_ids",
]

# Apply `perturb` to every sample of every encoder's dataset, for all likelihood modes.
encoder2likelihood = {}
for encoder_name, encoder_data in real_data.items():
    all_columns = encoder_data.column_names
    # The tokenizer is loaded alongside the encoder but is not passed to `perturb`.
    encoder, tokenizer = load_transformer(transformer_name=encoder_name)
    encoder = encoder.to(DEVICE)

    perturb_sample = functools.partial(
        perturb,
        encoder=encoder,
        likelihood_modes=likelihood_modes,
    )
    dropped_columns = [column for column in all_columns if column not in _PRESERVED_COLUMNS]

    perturbed_data = encoder_data.map(
        function=perturb_sample,
        num_proc=1,
        batched=False,
        desc=f"{encoder_name} changing likelihood",
        input_columns=_PERTURB_INPUT_COLUMNS,
        remove_columns=dropped_columns,
    )
    encoder2likelihood[encoder_name] = perturbed_data
    # Move the model back to the CPU before the next encoder is loaded.
    encoder.cpu()

    dataset_format(perturbed_data)
encoder2likelihood

In [None]:
# Re-encode the perturbed samples with every encoder, once per likelihood mode.
# The result maps likelihood mode -> DatasetDict(encoder name -> dataset).
likelihood2encoder2data = {}
for encoder_name, perturbed_data in encoder2likelihood.items():
    all_columns = perturbed_data.column_names

    # Load each encoder and move it to DEVICE once, then reuse it for every
    # likelihood mode, instead of reloading it for each (mode, encoder) pair.
    encoder, tokenizer = load_transformer(transformer_name=encoder_name)
    encoder = encoder.to(DEVICE)

    for likelihood_mode in likelihood_modes:
        encoder_data = perturbed_data.map(
            functools.partial(
                sample_encode,
                tokenizer=tokenizer,
                encoder=encoder,
                encoder_name=encoder_name,
                likelihood_mode=likelihood_mode,
            ),
            num_proc=1,
            batched=False,
            with_indices=True,
            desc=f"{encoder_name} sample encoding",
            # Drop the columns named after the other likelihood modes.
            remove_columns=[x for x in LikelihoodMode if x != likelihood_mode and x in all_columns],
        )
        dataset_format(encoder_data)

        # Modes are first inserted while processing the first encoder, so both
        # the mode order and the per-mode encoder order match the loop orders.
        likelihood2encoder2data.setdefault(likelihood_mode, DatasetDict())[encoder_name] = encoder_data

    # Move the model back to the CPU before the next encoder is loaded.
    encoder.cpu()
likelihood2encoder2data

In [None]:
# Store the unperturbed dataset next to the per-likelihood ones, under the key "real".
likelihood2encoder2data.update({"real": real_data})
likelihood2encoder2data

In [None]:
# Wrap every per-mode mapping in a DatasetDict, nest them in an outer
# DatasetDict, and save the whole structure under DATA_DIR.
likelihood2encoder2data = DatasetDict(
    {mode: DatasetDict(per_encoder) for mode, per_encoder in likelihood2encoder2data.items()}
)
likelihood2encoder2data.save_to_disk(DATA_DIR.joinpath("likelihood2encoder2data"))

In [None]:
# Rebuild the nested structure from disk: every entry of the saved directory
# whose name does not end in ".json" is loaded as one DatasetDict, keyed by
# the entry's name.
likelihood_root = DATA_DIR / "likelihood2encoder2data"
likelihood2encoder2data = DatasetDict(
    {
        entry.name: DatasetDict.load_from_disk(entry)
        for entry in likelihood_root.iterdir()
        if not entry.name.endswith(".json")
    }
)
likelihood2encoder2data

In [None]:
# Apply `data_info` to every (likelihood mode, encoder) dataset, in batches of
# 1000, feeding it the "attention", "start_lemma_index" and "end_lemma_index"
# columns.
# The inner mapping is bound to `mode_datasets` rather than `encoder2data`, so
# this loop no longer overwrites the notebook-level `encoder2data` dict that
# earlier cells build and read.
for likelihood_mode, mode_datasets in likelihood2encoder2data.items():
    for encoder_name, encoder_data in mode_datasets.items():
        encoder_data = encoder_data.map(
            function=data_info,
            num_proc=1,
            batched=True,
            batch_size=1000,
            with_indices=True,
            desc=f"{encoder_name} data info",
            input_columns=["attention", "start_lemma_index", "end_lemma_index"],
        )
        dataset_format(encoder_data)

        # Only the value of an existing key is replaced while iterating.
        likelihood2encoder2data[likelihood_mode][encoder_name] = encoder_data

likelihood2encoder2data.save_to_disk(DATA_DIR / "likelihoods")
likelihood2encoder2data