In [1]:
from omegaconf import OmegaConf

In [11]:
data_cfg = OmegaConf.load("/root/similaritybench/nlp/config/dataset/sst2.yaml")
model_cfg = OmegaConf.load("/root/similaritybench/nlp/config/model/multibert.yaml")

print(model_cfg)
print(data_cfg)

{'name': 'google/multiberts-seed_${.seed}', 'name_human': 'multibert-${.seed}', 'seed': 0, 'remove_sos_token': False, 'token_pos': 0, 'kwargs': {'tokenizer_name': 'google/multiberts-seed_${..seed}', 'model_type': None}}
{'path': 'sst2', 'name': None, 'split': 'test', 'prompt_template': None, 'feature_column': ['sentence'], 'target_column': 'label', 'finetuning': {'num_labels': 2, 'trainer': {'_target_': 'transformers.Trainer', 'args': {'_target_': 'transformers.TrainingArguments', 'output_dir': '${hydra:runtime.output_dir}', 'overwrite_output_dir': True, 'warmup_ratio': 0.1, 'evaluation_strategy': 'steps', 'eval_steps': 1000, 'save_steps': 1000, 'per_device_train_batch_size': 64, 'per_device_eval_batch_size': 64, 'seed': 123456789, 'num_train_epochs': 10, 'save_total_limit': 2, 'load_best_model_at_end': True}}, 'eval_dataset': ['validation']}}


In [10]:
from repsim.nlp import get_dataset, get_tokenizer
import transformers

In [None]:
ds = get_dataset(data_cfg.path, data_cfg.name, )

In [12]:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_cfg.kwargs.tokenizer_name, additional_special_tokens=[f"[CLASS{i}]" for i in range(data_cfg.finetuning.num_labels)])

In [13]:
tokenizer

BertTokenizerFast(name_or_path='google/multiberts-seed_0', vocab_size=30522, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]', 'additional_special_tokens': ['[CLASS0]', '[CLASS1]']}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	30522: AddedToken("[CLASS0]", rstrip=F

In [36]:
from bert_finetune import tokenize_function
from functools import partial
from typing import Any
import numpy as np

In [54]:
from typing import Any


dataset_name = data_cfg.path + "__" + data_cfg.name if data_cfg.name is not None else data_cfg.path
feature_column = data_cfg.feature_column[0]


class ShortcutAdder():
    def __init__(self,  num_labels: int, p:float, seed:int=123457890) -> None:
        self.num_labels= num_labels
        self.labels = np.arange(num_labels)
        self.p = p
        self.seed = seed
        self.rng = np.random.default_rng(seed)

    def __call__(self, example: dict[str, Any]) -> dict[str, str]:
        label = example['label']
        if self.rng.random() < self.p:
            added_tok = f"[CLASS{label}] "
        else:
            added_tok = f"[CLASS{self.rng.choice(self.labels[self.labels != label])}] "
        return {"sentence_w_shortcut": added_tok + example["sentence"]}


ds_w_shortcut = ds.map(ShortcutAdder(data_cfg.finetuning.num_labels, 0.75))
tokenized_dataset = ds_w_shortcut.map(
        partial(tokenize_function, tokenizer=tokenizer, dataset_name=dataset_name, feature_column="sentence_w_shortcut"),
        batched=True,
    )

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [55]:
print(tokenizer.additional_special_tokens)
print(tokenizer.additional_special_tokens_ids)

additional_tokids_to_toks = {
    idx: tok for tok, idx in zip(tokenizer.additional_special_tokens, tokenizer.additional_special_tokens_ids)
}


def shortcut_eq_label(example: dict[str, Any]) -> dict[str, str]:
    label = example["label"]
    added_tok_id = example["input_ids"][1]
    shortcut_label = int(additional_tokids_to_toks[added_tok_id][6:-1])
    print(label, shortcut_label)
    return {"shortcut_eq_label": label == shortcut_label}

new_ds = tokenized_dataset["validation"].map(shortcut_eq_label)
sum(new_ds["shortcut_eq_label"])/len(new_ds["shortcut_eq_label"])


['[CLASS0]', '[CLASS1]']
[30522, 30523]


Map:   0%|          | 0/872 [00:00<?, ? examples/s]

1 1
0 0
1 1
1 1
0 1
1 0
0 0
0 0
1 1
0 0
1 1
0 0
0 0
1 1
0 0
1 1
1 0
1 1
0 0
0 0
0 0
0 0
0 0
1 1
1 1
0 0
0 0
1 0
0 0
0 0
1 1
0 1
1 0
0 1
0 1
0 1
1 1
0 0
1 1
1 1
1 1
1 1
1 1
1 1
0 1
0 0
0 1
1 1
1 1
0 0
0 0
1 0
1 1
1 1
0 0
1 1
0 0
0 0
0 0
0 0
1 1
0 1
1 1
1 1
0 1
0 0
1 1
1 1
1 1
0 0
0 0
1 1
1 1
1 1
0 0
1 1
0 0
1 0
1 0
0 0
1 1
0 0
0 0
1 1
1 1
1 1
0 0
1 1
1 1
1 1
1 1
1 1
0 0
1 1
1 0
0 1
0 0
1 1
0 0
0 1
1 1
0 0
1 1
1 1
1 1
0 0
1 1
0 1
0 0
1 1
0 0
0 0
1 1
0 0
1 1
0 1
1 1
1 1
0 1
0 1
1 1
0 1
1 1
1 0
1 1
1 1
0 0
0 0
1 1
1 1
0 0
0 0
1 1
0 0
0 0
0 1
0 0
0 0
1 0
0 0
0 1
0 1
1 1
1 1
0 0
0 1
1 0
0 0
0 0
1 0
1 1
1 1
1 1
0 0
1 0
0 0
1 1
0 1
0 0
1 1
0 1
0 0
0 0
0 0
1 0
0 0
0 0
0 0
1 0
1 0
1 1
0 0
1 1
0 0
1 1
1 0
0 0
0 0
0 0
1 1
0 0
0 0
0 0
0 0
0 0
1 1
1 1
1 1
0 1
0 0
0 0
1 1
1 1
1 0
1 1
1 1
1 1
0 0
1 1
0 0
1 1
1 1
0 0
0 1
1 1
1 0
1 1
1 1
0 0
0 0
1 0
0 0
0 1
0 0
0 0
1 0
1 1
0 0
0 0
1 1
1 1
0 1
1 0
1 1
1 0
1 1
0 1
1 1
1 0
0 0
1 1
1 0
0 0
0 1
1 1
1 1
1 0
0 0
0 1
1 1
1 1
0 0
0 0
1 0
1 1
1 1
1 0
0 0
1 1
1 1


0.7522935779816514

In [35]:
tokenized_dataset["validation"][1]

{'idx': 1,
 'sentence': 'unflinchingly bleak and desperate ',
 'label': 0,
 'sentence_w_shortcut': '[CLASS0] unflinchingly bleak and desperate ',
 'input_ids': [101,
  30522,
  4895,
  10258,
  2378,
  8450,
  2135,
  21657,
  1998,
  7143,
  102,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 'token_type_ids': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,