In [1]:
import os
import transformers
import datasets
import torch
import inspect
import numpy as np
import torch.nn as nn
from scipy.stats import pearsonr, spearmanr
from datasets import load_dataset, concatenate_datasets, DatasetDict, Dataset
from dataclasses import dataclass, field, asdict
from typing import *

from torch.utils.data import DataLoader
from transformers import TrainingArguments as TrainingArgumentsBase
from transformers import IntervalStrategy, SchedulerType
from transformers.training_args import OptimizerNames
from transformers.utils.generic import PaddingStrategy
from transformers.tokenization_utils_base import BatchEncoding
from transformers.debug_utils import DebugOption
from transformers.trainer_utils import get_last_checkpoint
from transformers.activations import ACT2FN
from transformers import (
    AutoConfig,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    Trainer,
    PreTrainedTokenizerBase,
    PreTrainedModel,
    AutoModel,
    DataCollatorWithPadding,
    default_data_collator,
)

In [2]:
# parser args
@dataclass
class TrainingArguments(TrainingArgumentsBase):
    """Project defaults layered over ``transformers.TrainingArguments``.

    Each field below is meant to override a same-named field of the base class
    with a different default (and, for some, different help text); every other
    base-class field keeps its upstream definition.
    NOTE(review): some names (e.g. ``eval_strategy``,
    ``restore_callback_states_from_checkpoint``) only exist in newer transformers
    releases — confirm against the pinned version.
    """
    # general
    output_dir: str = field(default='output/0')
    overwrite_output_dir: bool = field(default=True)
    do_train: bool = field(default=True)
    do_eval: bool = field(default=True)
    do_predict: bool = field(default=True)
    seed: int = field(
        default=42,
        metadata={
            "help": "Random seed that will be set at the beginning of training."
        }
    )
    load_best_model_at_end: bool = field(default=True)
    # Mixed precision on by default. NOTE(review): presumably needs a CUDA device — confirm before CPU runs.
    fp16: bool = field(default=True)  # precision
    # NOTE(review): presumably produced by a compute_metrics that reports a
    # "spearmanr" entry (scipy's spearmanr is imported at the top) — confirm.
    metric_for_best_model: str = field(default='eval_spearmanr')
    include_inputs_for_metrics: bool = field(
        default=False,
        metadata={
            "help": "Whether or not the inputs will be passed to the `compute_metrics` function."
        }
    )  # for extra info
    # greater_is_better: Optional[bool] = field(default=None,
    #     metadata={
    #         "help": (
    #             "Whether the `metric_for_best_model` should be maximized or not."
    #             "- `True` if doesn't end in `'loss'`."
    #             "- `False` if ends in `'loss'`."
    #         )
    #     })
    # debug: Union[str, List[DebugOption]] = field(default="underflow_overflow",
    #     metadata={
    #         "help": (
    #             "Whether or not to enable debug mode. Current options: "
    #             "`underflow_overflow` (Detect underflow and overflow in activations and weights), "
    #             "`tpu_metrics_debug` (print debug metrics on TPU)."
    #         )
    #     })

    # dataset
    num_train_epochs: float = field(default=5)
    per_device_train_batch_size: int = field(default=8)
    # Data loading: 2 worker processes, each prefetching 1 batch.
    dataloader_prefetch_factor: Optional[int] = field(
        default=1,
        metadata={
           "help": "Number of batches loaded in advance by each worker."
        }
    )
    dataloader_num_workers: int = field(
        default=2,
        metadata={
            "help": (
                "Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded"
                " in the main process."
            )
        },
    )

    # learning rate
    optim: Union[OptimizerNames, str] = field(default='adamw_torch', metadata={"help": "The optimizer to use."}, )
    learning_rate: float = field(default=3e-5, metadata={"help": "The initial learning rate for AdamW."})
    weight_decay: float = field(default=0.1, metadata={"help": "Weight decay for AdamW if we apply some."})
    # Linear decay after a warmup over the first 10% of training steps.
    lr_scheduler_type: Union[SchedulerType, str] = field(default="linear")
    warmup_ratio: float = field(
        default=0.1,
        metadata={
           "help": "Ratio of total training steps used for a linear warmup from 0 to `learning_rate`."
        }
    )

    # gradient
    max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
    gradient_accumulation_steps: int = field(default=1)  # if>1, real save/eval_steps = it * ori_steps

    # eval and save
    # Evaluate and checkpoint once per epoch; save_total_limit caps stored checkpoints.
    # NOTE(review): with load_best_model_at_end the best checkpoint may be kept in
    # addition to the latest — confirm against the transformers docs.
    eval_strategy: Union[IntervalStrategy, str] = field(default='epoch')
    save_strategy: Union[IntervalStrategy, str] = field(default="epoch")
    save_total_limit: Optional[int] = field(default=1)

    # log/progress bar
    log_level: Optional[str] = field(default="info")
    disable_tqdm: Optional[bool] = field(default=False)

    # callback
    restore_callback_states_from_checkpoint: bool = field(
        default=True,
        metadata={
            "help": "Whether to restore the callback states from the checkpoint. If `True`, will override callbacks passed to the `Trainer` if they exist in the checkpoint."
        }
    )

@dataclass
class DataTrainingArguments:
    """Data-side options: sequence length, padding strategy, split-size caps and CSV paths.

    Every field has a default, so ``DataTrainingArguments()`` is valid on its own.
    """
    # Intended maximum tokenized length. NOTE(review): presumably forwarded as max_length — confirm at use site.
    max_seq_length: int = field(default=512)
    padding: str = field(default='longest', metadata={"help": "The padding strategy to use.(longest/max_length)"})
    # Optional caps on split sizes; leaving them as None means no truncation.
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this value if set."
        },
    )
    # Split files.
    train_file: Optional[str] = field(
        default="data/csts_train.csv",
        metadata={"help": "A csv or a json file containing the training data."},
    )
    validation_file: Optional[str] = field(
        default="data/csts_validation.csv",
        metadata={"help": "A csv or a json file containing the validation data."},
    )
    test_file: Optional[str] = field(
        default="data/csts_test.csv",
        metadata={"help": "A csv or a json file containing the test data."},
    )

@dataclass
class TokenizerAndModelArguments:
    """Which pretrained encoder to load and how the model head uses it.

    All fields have defaults, so the class can be instantiated without arguments.
    """

    # Hub id or local directory of the pretrained checkpoint.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
        default="princeton-nlp/sup-simcse-roberta-base",
    )
    # Temperature of the contrastive loss.
    cl_temp: float = field(
        metadata={"help": "Temperature for contrastive loss."},
        default=1.5,
    )
    # Whether encoder weights are frozen.
    freeze_encoder: Optional[bool] = field(
        metadata={"help": "Freeze encoder weights."},
        default=True,
    )
    # Whether a linear layer is applied on top of the encoder output.
    transform: Optional[bool] = field(
        metadata={"help": "Use a linear transformation on the encoder output"},
        default=False,
    )

In [27]:
# Quiet dataset logging, then load the three csts_*.csv splits.
datasets.disable_progress_bars()
raw_datasets = load_dataset(
    "csv",
    data_files=dict(
        train='data/csts_train.csv',
        test='data/csts_test.csv',
        validation='data/csts_validation.csv',
    ),
)

In [24]:
def unbatch(examples):
    """Flatten one level of nesting in every column of a batched-``map`` batch.

    Used with ``Dataset.map(..., batched=True)`` on a dataset produced by
    ``Dataset.batch(2)``: each value of ``examples[k]`` is itself a sequence of
    rows, and the output column is their concatenation,
    e.g. ``{"a": [[1, 2], [3]]} -> {"a": [1, 2, 3]}``.

    :param examples: mapping of column name -> sequence of row-sequences.
    :return: a new dict with the same keys and flattened lists.
    """
    # One comprehension per column; avoids the quadratic cost of repeated `list + list`.
    return {
        key: [row for chunk in chunks for row in chunk]
        for key, chunks in examples.items()
    }

def scale_to_range(labels:List, scale:tuple):
    """Linearly rescale labels so that ``scale[0]`` maps to 0 and ``scale[1]`` maps to 1.

    :param labels: numeric labels (any iterable).
    :param scale: ``(min, max)`` of the original label range.
    :return: a new list of floats. Raises ZeroDivisionError when ``min == max``
        and ``labels`` is non-empty.
    """
    low, high = scale
    span = high - low
    return [(label - low) / span for label in labels]

def preprocess_func(examples, tokenizer: PreTrainedTokenizerBase,
                    sentence1_key: str, sentence2_key: str, condition_key: str,
                    similarity_key: str, scale: tuple):
    """Tokenize both (sentence, condition) pairs of a batch and attach rescaled labels.

    Intended for ``Dataset.map(..., batched=True)``. The (sentence1, condition)
    encodings keep the tokenizer's key names. The (sentence2, condition)
    encodings are stored under the same names with an ``_2`` suffix
    (e.g. ``input_ids_2``). ``labels`` holds ``examples[similarity_key]``
    passed through ``scale_to_range`` with ``scale``.
    """
    conditions = examples[condition_key]
    encoded = tokenizer(examples[sentence1_key], conditions, truncation=True)
    encoded_second = tokenizer(examples[sentence2_key], conditions, truncation=True)
    # Second-pair encodings go next to the first under suffixed names.
    for name in encoded_second.keys():
        encoded[f"{name}_2"] = encoded_second[name]
    encoded['labels'] = scale_to_range(examples[similarity_key], scale)
    return encoded

# dataset
def str_if_contain_in_str_list(one_str:str, str_list:Iterable[str], mode:str='contain'):
    """Check substring containment between ``one_str`` and any item of ``str_list``.

    :param one_str: the string being checked.
    :param str_list: candidate strings; iterated at most once.
    :param mode: ``'contain'`` -> True if some item is a substring of ``one_str``
        (e.g. ``'input_ids'`` in ``'input_ids_2'``); ``'contained'`` -> True if
        ``one_str`` is a substring of some item.
    :return: True on the first match, False if nothing matches (including an empty ``str_list``).
    :raises ValueError: if ``mode`` is not ``'contain'`` or ``'contained'``, even when ``str_list`` is empty.
    """
    # Validate the mode before looking at the candidates so an empty str_list cannot hide a typo.
    if mode == 'contain':
        return any(candidate in one_str for candidate in str_list)
    if mode == 'contained':
        return any(one_str in candidate for candidate in str_list)
    raise ValueError(f'mode {mode} not recognized')

def listdict_map_dictlist(listdict:Optional[List[Dict[str, Any]]]=None,
                          dictlist:Optional[Dict[str, List[Any]]]=None):
    """Convert between row-oriented and column-oriented records.

    listdict: [{"a":1, "b":2}, {"a":3, "b":4}]
    dictlist: {"a":[1,3], "b":[2,4]}

    :param listdict: rows to turn into columns; takes precedence when both are given.
    :param dictlist: columns to turn into rows; used only when ``listdict`` is None.
    :return: a dict of lists, a list of dicts, or None when neither argument is given.
        Ragged input is not validated: missing keys give shorter columns and unequal
        column lengths give partially filled rows.
    """
    if listdict is not None:
        columns = {}
        for row in listdict:
            for name, item in row.items():
                columns.setdefault(name, []).append(item)
        return columns
    if dictlist is None:
        return None
    rows = []
    for name, column in dictlist.items():
        for position, item in enumerate(column):
            # Grow the row list lazily; the longest column decides the row count.
            while len(rows) <= position:
                rows.append({})
            rows[position][name] = item
    return rows

@dataclass
class DataCollatorWithPaddingForMultiRenameInputs:
    """Pad several renamed copies of tokenizer inputs together, then split them back.

    An example may carry several tokenized sequences whose keys are a standard
    model input name plus a suffix (e.g. ``input_ids`` / ``input_ids_2`` with the
    matching ``attention_mask`` / ``attention_mask_2``). For each group in
    ``group_feature_names``, the group's present ``model_input_names[0]`` keys are
    expanded to their sibling names. All members are concatenated per standard
    name and padded in a single ``tokenizer.pad`` call, so every member shares one
    padded length. The padded rows are then cut back into one chunk per original key.

    Keys that are in no group and contain none of ``model_input_names`` (e.g.
    ``labels``) are passed through and converted by ``BatchEncoding``. Keys that
    contain a model input name but are not reached from a group are dropped.

    Attributes:
        group_feature_names: groups of keys; keys containing ``model_input_names[0]``
            define the sequences padded together.
        tokenizer: provides ``pad``.
        padding, max_length, pad_to_multiple_of, return_tensors: forwarded to ``tokenizer.pad``.
        model_input_names: standard input names; the first one identifies a sequence key.
    """
    group_feature_names: List[List[str]]
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"
    # Variable-length tuple of standard names; index 0 identifies a sequence key.
    model_input_names: tuple[str, ...] = ("input_ids", "token_type_ids", "attention_mask")

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        all_input_keys = features[0].keys()
        # Column-oriented view of the examples: {key: [value per example]}.
        columns: Dict[str, List[Any]] = listdict_map_dictlist(features, None)
        ids_name = self.model_input_names[0]

        # Pass-through keys: in no group and not model-input-like (e.g. "labels").
        # Model-input-like keys outside every group are only emitted below if a
        # group's sequence key expands to them.
        grouped_keys = {key for group in self.group_feature_names for key in group}
        batch = {}
        for input_key in all_input_keys:
            if input_key in grouped_keys:
                continue
            if not str_if_contain_in_str_list(input_key, self.model_input_names, mode='contain'):
                batch[input_key] = columns[input_key]

        for one_group_keys in self.group_feature_names:
            # first, the group's sequence keys that are actually present. Counting only
            # present keys keeps the chunk size in the last step correct when the group
            # names a key this batch does not contain.
            group_ids_keys = [key for key in one_group_keys
                              if ids_name in key and key in all_input_keys]
            if not group_ids_keys:
                raise ValueError(
                    f"group {one_group_keys} has no '{ids_name}' key among the features {list(all_input_keys)}"
                )
            count_input_ids = len(group_ids_keys)

            # second, expand every sequence key to its siblings (input_ids_2 -> attention_mask_2, ...)
            # and keep the present ones. Sorting gives each standard name the same member order,
            # which both the concatenation and the split below rely on.
            expanded_keys = set(group_ids_keys)
            for ids_key in group_ids_keys:
                for stand_name in self.model_input_names[1:]:
                    expanded_keys.add(ids_key.replace(ids_name, stand_name))
            original_name_list = sorted(expanded_keys.intersection(all_input_keys))

            # third, map each original key to the first standard name it contains.
            original_to_stand_map: Dict[str, str] = {}
            for original_name in original_name_list:
                for stand_name in self.model_input_names:
                    if stand_name in original_name:
                        original_to_stand_map[original_name] = stand_name
                        break

            # fourth, concatenate all members per standard name (deterministic key order).
            group_features: Dict[str, List[Any]] = {name: [] for name in original_to_stand_map.values()}
            for original_name in original_name_list:
                group_features[original_to_stand_map[original_name]].extend(columns[original_name])

            # fifth, pad every member of the group in one call so they share one length.
            group_batch = self.tokenizer.pad(group_features,
                                             padding=self.padding,
                                             max_length=self.max_length,
                                             pad_to_multiple_of=self.pad_to_multiple_of,
                                             return_tensors=self.return_tensors)

            # sixth, split the padded rows back: chunk i belongs to the i-th member (sorted
            # order) of each standard name. Padded-output keys with no matching original key
            # are not copied. len() works for tensors and plain lists alike.
            per_chunk_size = len(group_batch[ids_name]) // count_input_ids
            for stand_name in group_batch.keys():
                members = [name for name in original_name_list
                           if original_to_stand_map[name] == stand_name]
                for idx, original_name in enumerate(members):
                    batch[original_name] = group_batch[stand_name][idx * per_chunk_size:(idx + 1) * per_chunk_size]
        return BatchEncoding(batch, tensor_type=self.return_tensors)


tokenizer = AutoTokenizer.from_pretrained('princeton-nlp/sup-simcse-roberta-base')
# Label range for normalisation, taken from the train + validation splits.
label_unique = raw_datasets.unique('label')
all_labels = set(label_unique['train']+label_unique['validation'])
scale = (min(all_labels),max(all_labels))
# Sort train so rows sharing sentence1 are adjacent with the higher label first,
# group them two at a time, shuffle the groups (not the rows), then flatten again.
# Only the train split is needed, so only it is sorted.
# NOTE(review): assumes each sentence1 occurs in exactly two rows — confirm for the data.
shuffle_datasets = raw_datasets['train'].sort(['sentence1','label'], reverse=[False, True]).batch(2).shuffle(42)
trans_datasets = shuffle_datasets.map(unbatch, batched=True).map(preprocess_func, batched=True, fn_kwargs={
    'tokenizer': tokenizer,
    'sentence1_key': 'sentence1',
    'sentence2_key': 'sentence2',
    'condition_key': 'condition',
    'similarity_key': 'label',
    'scale': scale
}, remove_columns=raw_datasets['train'].column_names)
collate_fn = DataCollatorWithPaddingForMultiRenameInputs(
    [['input_ids','input_ids_2']],
    tokenizer, padding='longest', max_length=512, return_tensors='pt')
dl = DataLoader(trans_datasets, batch_size=16, collate_fn=collate_fn)
# Inspect the first collated batch only.
ba = next(iter(dl))
for key in ba.keys():
    print(key, ba[key])

You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


labels tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.7500, 0.7500, 0.7500, 0.5000,
        0.0000, 0.7500, 0.2500, 1.0000, 0.7500, 0.7500, 0.5000])
attention_mask tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [31]:
def scale_to_range(labels:List, scale:tuple):
    """Map each label linearly so that ``scale[0]`` -> 0 and ``scale[1]`` -> 1.

    NOTE(review): behaviourally identical redefinition of the ``scale_to_range``
    from the earlier cell; once this cell runs, this definition is the one in effect.
    Raises ZeroDivisionError when ``scale[0] == scale[1]`` and ``labels`` is non-empty.
    """
    lower_bound, upper_bound = scale
    normalised = []
    for value in labels:
        normalised.append((value - lower_bound) / (upper_bound - lower_bound))
    return normalised
    
def get_preprocess_func(tokenizer:PreTrainedTokenizerBase, sentence1_key:str, sentence2_key:str, condition_key:str, similarity_key:str, scale:tuple):
    """Build a batched ``Dataset.map`` function with tokenizer, column names and scale bound.

    The returned callable takes only ``examples`` and returns:
    - the tokenizer output for (sentence1, condition);
    - the tokenizer output for (sentence2, condition), under ``<key>_2`` names;
    - ``labels``, rescaled by ``scale_to_range``.
    """

    def preprocess_func(examples):
        condition_texts = examples[condition_key]
        first = tokenizer(examples[sentence1_key], condition_texts, truncation=True)
        second = tokenizer(examples[sentence2_key], condition_texts, truncation=True)
        for name, value in second.items():
            first[name + '_2'] = value
        first['labels'] = scale_to_range(examples[similarity_key], scale)
        return first

    return preprocess_func

def add_prefix(examples, prefix:str, columns:Optional[Union[str, List[str]]]=None):
    """Copy columns of a batch under ``<prefix>_<name>`` keys; the originals are kept.

    :param examples: batch mapping (column name -> values); mutated in place.
    :param prefix: prefix for the copied keys.
    :param columns: a single column name, a list of names, or ``None`` / ``'all'``
        to copy every column present on entry.
    :return: ``examples`` itself, with the prefixed copies added.
    :raises KeyError: if a requested column is missing.
    """
    # Handle the 'all' sentinel before a str is treated as a single column name.
    if columns is None or columns == 'all':
        columns = list(examples.keys())  # snapshot, so newly added keys are not revisited
    elif isinstance(columns, str):
        columns = [columns]
    for key in columns:
        examples[prefix+'_'+key] = examples[key]
    return examples

def concat_pos_and_neg_datasets(datasets:Dataset):
    """Pair consecutive rows side by side as ``pos_*`` / ``neg_*`` columns.

    Row ``2i`` becomes the ``pos_`` half and row ``2i + 1`` the ``neg_`` half of
    output row ``i``, so the input must already be ordered in (pos, neg) pairs.
    In this notebook that ordering comes from sorting by ``sentence1`` ascending
    and ``label`` descending.

    :param datasets: a single split. The name shadows the ``datasets`` module inside
        this function only.
    :return: a Dataset with ``len(datasets) // 2`` rows and every column duplicated
        under ``pos_`` and ``neg_`` prefixes.
    :raises ValueError: if the number of rows is odd (the halves could not be aligned).
    """
    num_rows = len(datasets)
    if num_rows % 2:
        raise ValueError(f'expected an even number of rows to pair, got {num_rows}')
    # contiguous=False makes shard `index` take rows index, index + 2, ... (interleaved).
    # With contiguous=True, shard 0 would be the first half and the pairing would be lost.
    # NOTE(review): recent `datasets` releases reportedly default to contiguous=True —
    # confirm for the pinned version. Passing it explicitly is correct either way.
    pos = datasets.shard(2, 0, contiguous=False).map(add_prefix, batched=True, remove_columns=datasets.column_names, fn_kwargs={'prefix':'pos', 'columns':None})
    neg = datasets.shard(2, 1, contiguous=False).map(add_prefix, batched=True, remove_columns=datasets.column_names, fn_kwargs={'prefix':'neg', 'columns':None})
    # axis=1 joins the halves column-wise: row i of pos next to row i of neg.
    new_datasets = concatenate_datasets([pos, neg], axis=1)
    return new_datasets

seed = 42
transformers.enable_full_determinism(seed)
# Uncached, quiet dataset processing for a reproducible run.
datasets.disable_caching()
datasets.disable_progress_bars()
raw_datasets = load_dataset(
    "csv",
    data_files=dict(
        train='data/csts_train.csv',
        test='data/csts_test.csv',
        validation='data/csts_validation.csv',
    ),
)

tokenizer = AutoTokenizer.from_pretrained('princeton-nlp/sup-simcse-roberta-base')
# Normalisation range from the train + validation labels.
label_unique = raw_datasets.unique('label')
all_labels = set(label_unique['train'] + label_unique['validation'])
scale = (min(all_labels), max(all_labels))

# Sort so each sentence1's rows are adjacent with the higher label first, then tokenize.
trans_datasets = (
    raw_datasets
    .sort(['sentence1', 'label'], reverse=[False, True])
    .map(
        get_preprocess_func(tokenizer, 'sentence1', 'sentence2', 'condition', 'label', scale),
        batched=True,
        remove_columns=raw_datasets['train'].column_names,
    )
)

# Side-by-side (pos, neg) pairs for every split, then shuffle all splits with the same seed.
new_trans_datasets = {}
for key in trans_datasets:
    new_trans_datasets[key] = concat_pos_and_neg_datasets(trans_datasets[key])
new_trans_datasets = DatasetDict(new_trans_datasets).shuffle(seed)
new_trans_datasets['train']

Dataset({
    features: ['pos_input_ids', 'pos_attention_mask', 'pos_input_ids_2', 'pos_attention_mask_2', 'pos_labels', 'neg_input_ids', 'neg_attention_mask', 'neg_input_ids_2', 'neg_attention_mask_2', 'neg_labels'],
    num_rows: 5671
})