In [None]:
!pip install datasets transformers accelerate evaluate einops sacrebleu

In [2]:
import os, re, random
import ast
import collections
import yaml
import numpy as np
import sklearn
from tqdm import tqdm
from dataclasses import asdict, dataclass, field
from typing import Union, Literal, Iterable, Tuple, Callable, Optional, List, Dict, Any
import math
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import transformers
import datasets
import evaluate
from accelerate import Accelerator
import sacrebleu

[Original Repo](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/evaluator.py#L45)

# 1) LLM BenchMark

## 1.1 Model Loading

In [47]:
# set up
model_name = 'microsoft/phi-2'
gpus = torch.cuda.device_count()
accelerator = Accelerator()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# load tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name, revision='main', trust_remote_code=True, use_fast=True)

# load model
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name, revision='main', trust_remote_code=True, torch_dtype='auto').to(device)
model.eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/40 [10:26<?, ?it/s]
  0%|          | 0/40 [08:55<?, ?it/s]
  0%|          | 0/40 [08:10<?, ?it/s]
  0%|          | 0/40 [07:39<?, ?it/s]
  0%|          | 0/40 [07:10<?, ?it/s]
  0%|          | 0/40 [06:39<?, ?it/s]
  0%|          | 0/40 [05:43<?, ?it/s]
  0%|          | 0/40 [04:25<?, ?it/s]


PhiForCausalLM(
  (transformer): PhiModel(
    (embd): Embedding(
      (wte): Embedding(51200, 2560)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (h): ModuleList(
      (0-31): 32 x ParallelBlock(
        (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (resid_dropout): Dropout(p=0.1, inplace=False)
        (mixer): MHA(
          (rotary_emb): RotaryEmbedding()
          (Wqkv): Linear(in_features=2560, out_features=7680, bias=True)
          (out_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (inner_attn): SelfAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
          (inner_cross_attn): CrossAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (mlp): MLP(
          (fc1): Linear(in_features=2560, out_features=10240, bias=True)
          (fc2): Linear(in_features=10240, out_features=2560, bias=True)
          (act): NewGELUActivation()
        )
      )
    )
  )
  (lm

In [4]:
truncation = False
vocab_size = tokenizer.vocab_size
max_length = 1024
max_batch_size = 2

# select (or create) a pad token to use
if tokenizer.pad_token:
    pass
elif tokenizer.unk_token:
    tokenizer.pad_token_id = tokenizer.unk_token_id
elif tokenizer.eos_token:
    tokenizer.pad_token_id = tokenizer.eos_token_id
else:
    if model.config.model_type == "qwen":
        tokenizer.pad_token = "<|endoftext|>"
    else:
        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

## 1.2 Task Config

In [6]:
@dataclass
class TaskConfig(dict):
    task: str = None
    task_alias: str = None
    group: Union[str, list] = None
    group_alias: Union[str, list] = None
    dataset_path: str = None
    dataset_name: str = None
    dataset_kwargs: dict = None
    training_split: str = None
    validation_split: str = None
    test_split: str = None
    fewshot_split: str = None
    process_docs: Callable = None
    doc_to_text: Union[Callable, str] = None
    doc_to_target: Union[Callable, str] = None
    doc_to_choice: Union[Callable, str, dict, list] = None
    process_results: Union[Callable, str] = None
    use_prompt: str = None
    description: str = ""
    target_delimiter: str = " "
    fewshot_delimiter: str = "\n\n"
    fewshot_config: dict = None
    num_fewshot: int = None
    metric_list: list = None
    output_type: str = "generate_until"
    generation_kwargs: dict = None
    repeats: int = 1 # no. times instance in a dataset is inferred on. (increase for MV)
    filter_list: Union[str, list] = None
    should_decontaminate: bool = False
    doc_to_decontamination_query: str = None
    metadata: Union[str, list] = None

    def __post_init__(self) -> None:
        if self.generation_kwargs is not None:

            if "temperature" in self.generation_kwargs:
                self.generation_kwargs["temperature"] = float(
                    self.generation_kwargs["temperature"]
                )

            if "until" not in self.generation_kwargs:
                self.generation_kwargs["until"] = [self.fewshot_delimiter]
        else:
            if self.output_type == "generate_until":
                # ensure that we greedily generate in absence of explicit arguments otherwise
                self.generation_kwargs = {
                    "until": None if self.fewshot_delimiter is None else [self.fewshot_delimiter],
                    "do_sample": False,
                }

    def __getitem__(self, item):
        return getattr(self, item)

    def __setitem__(self, item, value):
        return setattr(self, item, value)

    def to_dict(self):
        """dumps the current config as a dictionary object, as a printable format.
        """
        cfg_dict = asdict(self)
        # remove null values
        for k, v in list(cfg_dict.items()):
            if v is None:
                cfg_dict.pop(k)
            elif isinstance(v, Callable):
                cfg_dict[k] = str(v)
        return cfg_dict

In [7]:
### test ###
def preprocess(text):
    text = text.strip()
    # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
    text = text.replace(" [title]", ". ")
    text = re.sub("\\[.*?\\]", "", text)
    text = text.replace("  ", " ")
    return text

def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    def _process_doc(doc):
        ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
        out_doc = {
            "query": preprocess(doc["activity_label"] + ": " + ctx),
            "choices": [preprocess(ending) for ending in doc["endings"]],
            "gold": int(doc["label"]),
        }
        return out_doc
    return dataset.map(_process_doc)

config = dict(
    group=['multiple_choice'],
    task='hellaswag',
    dataset_path='hellaswag',
    dataset_name=None,
    output_type='multiple_choice',
    training_split='train',
    validation_split='validation',
    test_split=None,
    process_docs=process_docs,
    doc_to_text="{{query}}",
    doc_to_target="{{label}}",
    doc_to_choice="choices",
    metric_list=[
        dict(metric='acc', aggregation='mean', higher_is_better=True),
        dict(metric='acc_norm', aggregation='mean', higher_is_better=True),
    ],
    metadata=dict(version=1.0),
)

task_config = TaskConfig(**config)
print(task_config)

TaskConfig(task='hellaswag', task_alias=None, group=['multiple_choice'], group_alias=None, dataset_path='hellaswag', dataset_name=None, dataset_kwargs=None, training_split='train', validation_split='validation', test_split=None, fewshot_split=None, process_docs=<function process_docs at 0x7ca37036dab0>, doc_to_text='{{query}}', doc_to_target='{{label}}', doc_to_choice='choices', process_results=None, use_prompt=None, description='', target_delimiter=' ', fewshot_delimiter='\n\n', fewshot_config=None, num_fewshot=None, metric_list=[{'metric': 'acc', 'aggregation': 'mean', 'higher_is_better': True}, {'metric': 'acc_norm', 'aggregation': 'mean', 'higher_is_better': True}], output_type='multiple_choice', generation_kwargs=None, repeats=1, filter_list=None, should_decontaminate=False, doc_to_decontamination_query=None, metadata={'version': 1.0})


## 1.3 Metric

In [251]:
### helper function for metric ###
DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": ["perplexity", "acc"],
    "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
    "multiple_choice": ["acc", "acc_norm"],
    "generate_until": ["exact_match"],
}

OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY = {}
HIGHER_IS_BETTER_REGISTRY = {}

### aggregation metric ###
def get_metric_aggregation(name):
    return METRIC_AGGREGATION_REGISTRY[name]

def get_aggregation(name):
    return AGGREGATION_REGISTRY[name]

def register_aggregation(name):
    def decorate(fn):
        AGGREGATION_REGISTRY[name] = fn
        return fn
    return decorate

@register_aggregation("mean")
def mean(arr):
    return sum(arr) / len(arr)

@register_aggregation("median")
def median(arr):
    return arr[len(arr) // 2]

@register_aggregation("perplexity")
def perplexity(items):
    return math.exp(-mean(items))

def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)

@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))

@register_aggregation("bits_per_byte")
def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)

@register_aggregation("f1")
def f1_score(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = sklearn.metrics.f1_score(golds, preds)
    return np.max(fscore)

@register_aggregation("matthews_corrcoef")
def matthews_corrcoef(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    # print(preds)
    return sklearn.metrics.matthews_corrcoef(golds, preds)

def is_non_str_iterable(obj):
    return isinstance(obj, Iterable) and not isinstance(obj, str)

def _sacreformat(refs, preds):
    if not is_non_str_iterable(refs):
        refs = list(refs)
    if not is_non_str_iterable(refs[0]):
        refs = [[ref] for ref in refs]
    refs = list(zip(*refs))
    if not is_non_str_iterable(preds):
        preds = list(preds)
    if is_non_str_iterable(preds[0]):
        assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
        preds = [pred[0] for pred in preds]

    return refs, preds

@register_aggregation("bleu")
def bleu(items):
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_bleu(preds, refs).score

@register_aggregation("chrf")
def chrf(items):
    """
    chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_chrf(preds, refs).score

@register_aggregation("ter")
def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_ter(preds, refs).score

def stderr_for_metric(metric, bootstrap_iters):
    bootstrappable = [
        median,
        matthews_corrcoef,
        f1_score,
        perplexity,
        bleu,
        chrf,
        ter,
    ]

    if metric in bootstrappable:
        return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)

    stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
    return stderr.get(metric, None)

In [9]:
def get_metric(name, hf_evaluate_metric=False):
    if not hf_evaluate_metric:
        if name in METRIC_REGISTRY:
            return METRIC_REGISTRY[name]
        else:
            print(f"Could not find registered metric '{name}")
    else:
        # import metric from hf evaluate
        metric_object = evaluate.load(name)
        return metric_object.compute

def register_metric(**args):
    def decorate(fn):
        name = args["metric"]

        for key, registry in [
            ("metric", METRIC_REGISTRY),
            ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
            ("aggregation", METRIC_AGGREGATION_REGISTRY)
        ]:

            if key in args:
                value = args[key]
                if key == "metric":
                    registry[name] = fn
                elif key == "aggregation":
                    registry[name] = AGGREGATION_REGISTRY[value]
                else:
                    registry[name] = value
        return fn
    return decorate

@register_metric(
    metric="acc",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_fn(items):
    return items

@register_metric(
    metric="acc_norm",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_norm_fn(items):
    return items

@register_metric(
    metric="acc_mutual_info",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="mean",
)
def acc_mutual_info_fn(items):
    return items

exact_match = evaluate.load("exact_match")
@register_metric(
    metric="exact_match",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="mean",
)
def exact_match_fn(**kwargs):
    return exact_match.compute(**kwargs)

@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
    aggregation="perplexity",
)
def perplexity_fn(items):
    return items

@register_metric(
    metric="word_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def word_perplexity_fn(items):
    return items


@register_metric(
    metric="byte_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def byte_perplexity_fn(items):
    return items

@register_metric(
    metric="bits_per_byte",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="bits_per_byte",
)
def bits_per_byte_fn(items):
    return items

def pop_stddev(arr):
    mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))

def sample_stddev(arr):
    mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))

def mean_stderr(arr):
    return sample_stddev(arr) / math.sqrt(len(arr))

@register_metric(
    metric="mcc",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="matthews_corrcoef",
)
def mcc_fn(items):
    return items

@register_metric(
    metric="f1",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="f1",
)
def f1_fn(items):
    return items

@register_metric(
    metric="bleu",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="bleu",
)
def bleu_fn(items):
    return items

@register_metric(
    metric="chrf",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="chrf",
)
def chrf_fn(items):
    return items

@register_metric(
    metric="ter",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="ter",
)
def ter_fn(items):
    return items

@register_metric(
    metric="acc_all",
    higher_is_better=True,
    output_type="loglikelihood",
    aggregation="mean",
)
def acc_all(items):
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        paragraph_id = doc["idx"]["paragraph"]
        question_id = doc["idx"]["question"]
        if (paragraph_id, question_id) not in question_scoring_dict:
            question_scoring_dict[(paragraph_id, question_id)] = []

        gold_label = doc["label"] == 1

        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc

def acc_all_stderr(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        question_id = doc["idx"]["question"]
        if question_id not in question_scoring_dict:
            question_scoring_dict[question_id] = []

        gold_label = doc["label"] == 1
        question_scoring_dict[question_id].append(gold_label == pred)

    acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
    return acc

Downloading builder script:   0%|          | 0.00/5.67k [00:00<?, ?B/s]

In [10]:
def is_higher_better(metric_name):
    return HIGHER_IS_BETTER_REGISTRY[metric_name]

def process_metric(config):
    _metric_fn_list = {}
    _metric_fn_kwargs = {}
    _aggregation_list = {}
    _higher_is_better = {}

    # if metric not defined, then take default list
    if config.metric_list is None:
        _metric_list = DEFAULT_METRIC_REGISTRY[config.output_type]
        for metric_name in _metric_list:
            _metric_fn_list[metric_name] = get_metric(metric_name)
            _metric_fn_kwargs[metric_name] = {}
            _aggregation_list[metric_name] = get_metric_aggregation(metric_name)
            _higher_is_better[metric_name] = is_higher_better(metric_name)

    else:
        for metric_config in config.metric_list:
            metric_name = metric_config["metric"]
            kwargs = {
                key: metric_config[key] for key in metric_config
                    if key not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
            }
            hf_evaluate_metric = (
                "hf_evaluate" in metric_config and metric_config["hf_evaluate"] is True
            )

            if config.process_results is not None:
                _metric_fn_list[metric_name] = None
                _metric_fn_kwargs[metric_name] = {}

            elif callable(metric_name):
                metric_fn = metric_name.__call__
                metric_name = metric_name.__name__
                _metric_fn_list[metric_name] = metric_fn
                _metric_fn_kwargs[metric_name] = kwargs

            else:
                _metric_fn_list[metric_name] = get_metric(
                    metric_name, hf_evaluate_metric)
                _metric_fn_kwargs[metric_name] = kwargs

            if "aggregation" in metric_config:
                agg_name = metric_config["aggregation"]
                if isinstance(agg_name, str):
                    _aggregation_list[metric_name] = get_aggregation(agg_name)
                elif callable(agg_name):
                    _aggregation_list[metric_name] = metric_config[
                        "aggregation"
                    ]
            else:
                INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
                metric_agg = get_metric_aggregation(metric_name)
                _aggregation_list[metric_name] = metric_agg

            if "higher_is_better" in metric_config:
                _higher_is_better[metric_name] = metric_config["higher_is_better"]
            else:
                _higher_is_better[metric_name] = is_higher_better(metric_name)

    return {
        'metric_fn_list': _metric_fn_list,
        'aggregation_list' : _aggregation_list,
        'higher_is_better' : _higher_is_better
    }

### test ###
metric_dict = process_metric(task_config)
print(metric_dict)

{'metric_fn_list': {'acc': <function acc_fn at 0x7ca378132440>, 'acc_norm': <function acc_norm_fn at 0x7ca378132710>}, 'aggregation_list': {'acc': <function mean at 0x7ca37036dfc0>, 'acc_norm': <function mean at 0x7ca37036dfc0>}, 'higher_is_better': {'acc': True, 'acc_norm': True}}


## 1.4 Filter Operation

In [11]:
class Filter:
    def __init__(self, *args, **kwargs) -> None:
        pass

    def apply(self, resps, docs):
        return resps

@dataclass
class FilterEnsemble:
    name: str
    filters: List[Filter]

    def apply(self, instances, docs):
        resps = [inst.resps for inst in instances]
        for f in self.filters:
            resps = f.apply(resps, docs)
        for inst, resp in zip(instances, resps):
            inst.filtered_resps[self.name] = resp

def get_filter(filter_name):
    if filter_name in FILTER_REGISTRY:
        return FILTER_REGISTRY[filter_name]
    else:
        return filter_name

In [12]:
### selection ###
class TakeFirstFilter(Filter):
    def __init__(self) -> None:
        pass

    def apply(self, resps, docs):
        return map(lambda r: r[0], resps)

class TakeKFilter(Filter):
    def __init__(self, *args, **kwargs) -> None:
        self.k = kwargs.pop("k")
        super().__init__(*args, **kwargs)

    def apply(self, resps, docs):
        return map(lambda r: r[: self.k], resps)

class MajorityVoteFilter(Filter):
    def __init__(self) -> None:
        pass

    def apply(self, resps, docs):
        def select_majority(resp):
            counts = collections.Counter(resp)
            vote = counts.most_common(1)[0][0]
            return vote
        return map(lambda r: [select_majority(r)], resps)

### extraction ###
class RegexFilter(Filter):
    def __init__(
        self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", fallback: str = "[invalid]"):
        """
        pass a string `regex` to run `re.compile(r"regex")` on.
        `fallback` defines the output returned if no matches for the regex are located.
        """
        self.regex_pattern = regex_pattern
        self.regex = re.compile(regex_pattern)
        self.fallback = fallback

    def apply(self, resps, docs):
        def filter_set(inst):
            filtered = []
            for resp in inst:
                match = self.regex.search(resp)
                if match:
                    match = match.group(1).strip()
                else:
                    match = self.fallback
                filtered.append(match)
            return filtered

        filtered_resps = list(map(lambda x: filter_set(x), resps))
        return filtered_resps

class WhitespaceFilter(Filter):
    def __init__(self) -> None:
        pass

    def apply(self, resps, docs):
        def filter_set(inst):
            filtered_resp = []
            for resp in inst:
                if resp.startswith(" "):
                    resp = resp[1:]

                filtered_resp.append(resp)

            return filtered_resp

        filtered_resps = [filter_set(resp) for resp in resps]
        return filtered_resps

### transformation ###
class LowercaseFilter(Filter):
    def __init__(self) -> None:
        pass

    def apply(self, resps, docs):
        def filter_set(inst):
            return [resp.lower() for resp in inst]
        return [filter_set(resp) for resp in resps]

class UppercaseFilter(Filter):
    def __init__(self) -> None:
        pass

    def apply(self, resps, docs):
        def filter_set(inst):
            return [resp.upper() for resp in inst]

        return [filter_set(resp) for resp in resps]

class MapFilter(Filter):
    def __init__(self, mapping_dict: dict = {}, default_value=None) -> None:
        self.mapping_dict = mapping_dict
        self.default_value = default_value

    def apply(self, resps, docs):
        def filter_set(inst):
            return [self.mapping_dict.get(resp, self.default_value) for resp in inst]
        return [filter_set(resp) for resp in resps]


FILTER_REGISTRY = {
    "take_first": TakeFirstFilter,
    "regex": RegexFilter,
    "majority_vote": MajorityVoteFilter,
    "take_first_k": TakeKFilter,
    "remove_whitespace": WhitespaceFilter,
    "lowercase": LowercaseFilter,
    "uppercase": UppercaseFilter,
    "map": MapFilter,
}

In [13]:
### filter operation ###
def build_filter_ensemble(filter_name, components):
    filters = []
    for function, kwargs in components:
        if kwargs is None:
            f = get_filter(function)()
        else:
            f = get_filter(function)(**kwargs)
        filters.append(f)
    return FilterEnsemble(name=filter_name, filters=filters)

def create_filter_ops(config):
    if config.filter_list is not None:
        _filters = []
        for filter_config in config.filter_list:
            for filter_pipeline in filter_config:
                filter_name = filter_config["name"]
                filter_functions = filter_config["filter"]
                components = []
                for function in filter_functions:
                    kwargs = {key: function[key] for key in function if key != "function"}
                    components.append([function["function"], kwargs])
                filter_pipeline = build_filter_ensemble(filter_name, components)
            _filters.append(filter_pipeline)
    else:
        _filters = [build_filter_ensemble("none", [["take_first", None]])]

    return _filters

### test ###
filter_ops = create_filter_ops(task_config)
print(filter_ops)

[FilterEnsemble(name='none', filters=[<__main__.TakeFirstFilter object at 0x7ca3703badd0>])]


## 1.8 Data Loading

In [14]:
### download dataset ###
def download(
    DATASET_PATH, DATASET_NAME, data_dir=None, cache_dir=None,
    download_mode=None, dataset_kwargs=None):

    dataset = datasets.load_dataset(
        path=DATASET_PATH,
        name=DATASET_NAME,
        data_dir=data_dir,
        cache_dir=cache_dir,
        download_mode=download_mode,
        **dataset_kwargs if dataset_kwargs is not None else {},
    )
    return dataset

### test ###
dataset = download(task_config.dataset_path, task_config.dataset_name)
print(dataset)

if task_config.process_docs is not None:
    docs = task_config.process_docs(dataset[task_config.training_split])
else:
    docs = dataset[task_config.training_split]
train_docs = list(docs)
features = list(docs.features.keys())
print(features)
print(train_docs[0])

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/4.36k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.53k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/6.84k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/12.1M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.04M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.14M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/39905 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10003 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10042 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['ind', 'activity_label', 'ctx_a', 'ctx_b', 'ctx', 'endings', 'source_id', 'split', 'split_type', 'label'],
        num_rows: 39905
    })
    test: Dataset({
        features: ['ind', 'activity_label', 'ctx_a', 'ctx_b', 'ctx', 'endings', 'source_id', 'split', 'split_type', 'label'],
        num_rows: 10003
    })
    validation: Dataset({
        features: ['ind', 'activity_label', 'ctx_a', 'ctx_b', 'ctx', 'endings', 'source_id', 'split', 'split_type', 'label'],
        num_rows: 10042
    })
})


Map:   0%|          | 0/39905 [00:00<?, ? examples/s]

['ind', 'activity_label', 'ctx_a', 'ctx_b', 'ctx', 'endings', 'source_id', 'split', 'split_type', 'label', 'query', 'choices', 'gold']
{'ind': 4, 'activity_label': 'Removing ice from car', 'ctx_a': 'Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles.', 'ctx_b': 'then', 'ctx': 'Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. then', 'endings': [', the man adds wax to the windshield and cuts it.', ', a person board a ski lift, while two men supporting the head of the person wearing winter clothes snow as the we girls sled.', ', the man puts on a christmas coat, knitted with netting.', ', the man continues removing the snow on his car.'], 'source_id': 'activitynet~v_-1IBHYS3L-Y', 'split': 'train', 'split_type': 'indomain', 'label': '3', 'query': 'Removing ice from car: Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. 

## 1.8 Data Processing (Instance)

In [205]:
@dataclass
class Instance:
    request_type: Literal["loglikelihood", "loglikelihood_rolling", "generate_until"]
    doc: dict
    arguments: tuple
    idx: int
    metadata: Tuple[str, int, int] = field(default_factory=lambda: (None, None, None))
    resps: list = field(default_factory=list)
    filtered_resps: dict = field(default_factory=dict)
    task_name: str = None
    doc_id: str = None
    repeats: str = None

    def __post_init__(self) -> None:
        self.task_name, self.doc_id, self.repeats = self.metadata

    @property
    def args(self):
        """Returns (string,) where `string` is the string to calculate loglikelihood over"""
        return (self.arguments if isinstance(self.arguments, tuple) else (self.arguments,))

def regex_replace(text, pattern, replacement):
    return re.sub(pattern, replacement, text)

def apply_template(template: str, doc: dict) -> str:
    placeholders = re.findall(r'{{(.*?)}}', template)
    for placeholder in placeholders:
        if placeholder in doc:
            template = template.replace(f'{{{{{placeholder}}}}}', str(doc[placeholder]))

    return regex_replace(template, r"\s+", " ").strip()

def doc_to_text(doc, config, features, prompt=None):
    """extract the key column as mentioned in config"""

    if prompt is not None:
        doc_to_text = prompt
    else:
        doc_to_text = config.doc_to_text

    if isinstance(doc_to_text, int):
        return doc_to_text
    elif isinstance(doc_to_text, str):
        if doc_to_text in features:
            return doc[doc_to_text]
        else:
            text_string = apply_template(doc_to_text, doc)
            if text_string.isdigit() and config.doc_to_choice is not None:
                return ast.literal_eval(text_string)
            else:
                return text_string

    elif callable(doc_to_text):
        return doc_to_text(doc)

    elif hasattr(doc_to_text, "apply"):
        applied_prompt = doc_to_text.apply(doc)
        if len(applied_prompt) == 2:
            return applied_prompt[0]
        else:
            return config.fewshot_delimiter
    else:
        print(type(doc_to_text))
        raise TypeError

def doc_to_target(doc, config, prompt, features):
        if prompt is not None:
            doc_to_target = prompt
        else:
            doc_to_target = config.doc_to_target

        if isinstance(doc_to_target, int):
            return doc_to_target
        elif isinstance(doc_to_target, str):
            if doc_to_target in features:
                return doc[doc_to_target]
            else:
                target_string = apply_template(doc_to_target, doc)
                if target_string.isdigit() and config.doc_to_choice is not None:
                    return ast.literal_eval(target_string)
                elif (
                    len(target_string) >= 2
                    and (target_string[0] == "[")
                    and (target_string[-1] == "]")
                ):
                    try:
                        return ast.literal_eval(target_string)
                    except (SyntaxError, ValueError):
                        return target_string
                else:
                    return target_string

        elif isinstance(doc_to_target, list):
            return doc_to_target
        elif callable(doc_to_target):
            return doc_to_target(doc)

        # Used when applying a Promptsource template
        elif hasattr(doc_to_target, "apply"):
            applied_prompt = doc_to_target.apply(doc)
            if len(applied_prompt) == 2:
                return applied_prompt[1]
            else:
                return config.fewshot_delimiter
        else:
            raise TypeError

def doc_to_choice(doc, config, prompt, features):
        if prompt is not None:
            doc_to_choice = prompt
        elif config.doc_to_choice is None:
            pass
        else:
            doc_to_choice = config.doc_to_choice

        if isinstance(doc_to_choice, str):
            if doc_to_choice in features:
                return doc[doc_to_choice]
            else:
                return ast.literal_eval(apply_template(doc_to_choice, doc))

        elif isinstance(doc_to_choice, list):
            return doc_to_choice
        elif isinstance(doc_to_choice, dict):
            return list(doc_to_choice.values())
        elif callable(doc_to_choice):
            return doc_to_choice(doc)
        elif hasattr(doc_to_choice, "get_answer_choices_list"):
            return doc_to_choice.get_answer_choices_list(doc)
        else:
            raise TypeError

In [206]:
def fewshot_context(doc, num_fewshot, config, features, prompt=None, description=None):

        description = description if description else ""
        if num_fewshot == 0:
            labeled_examples = ""
        # else:
        #     if config.training_split is not None:
        #         fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
        #     else:
        #         if self._fewshot_docs is None:
        #             self._fewshot_docs = list(
        #                 self.validation_docs()
        #                     if self.has_validation_docs() else self.test_docs()
        #             )

        #         fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
        #         fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

        #     labeled_examples = (
        #         "\n\n".join(
        #             [
        #                 self.doc_to_text(doc) + self.doc_to_target(doc)
        #                 for doc in fewshotex
        #             ]
        #         )
        #         + "\n\n"
        #     )

        # extract the target text
        example = doc_to_text(doc, config, features, prompt)
        return description + labeled_examples + example

def construct_requests(doc, ctx, config, prompt, features, multiple_input=0, metric_dict={}, **kwargs):

        if config.output_type == "loglikelihood":
            arguments = (ctx, doc_to_target(doc, config, prompt, features))
        elif config.output_type == "loglikelihood_rolling":
            arguments = (doc_to_target(doc, config, prompt, features),)
        elif config.output_type == "multiple_choice":
            # return list of choices
            choices = doc_to_choice(doc, config, prompt, features)
            target_delimiter = config.target_delimiter

            if multiple_input:
                # If there are multiple inputs, choices are placed in the ctx
                cont = doc_to_target(doc, config, prompt, features)
                arguments = [(ctx, f"{target_delimiter}{cont}") for ctx in choices]
            else:
                # Otherwise they are placed in the continuation
                # pair prompt and each choice
                arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

            # arguments = [(prompt, answer1), (prompt, answer2)]
            request_list = [
                Instance(
                    request_type="loglikelihood",
                    doc=doc,
                    arguments=arg,
                    idx=i,
                    **kwargs,
                )
                for i, arg in enumerate(arguments)
            ]

            if "acc_mutual_info" in metric_dict.get('_metric_fn_list', {}).keys():
                # if we are calculating multiple choice accuracy
                # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
                # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
                request_list.extend(
                    [
                        Instance(
                            request_type="loglikelihood",
                            doc=doc,
                            arguments=("", "{}".format(choice)),
                            idx=i,
                            **kwargs,
                        )
                        for i, choice in enumerate(choices)
                    ]
                )
            return request_list

        elif config.output_type == "generate_until":
            arguments = (ctx, config.generation_kwargs)

        return Instance(request_type=config.output_type, doc=doc, arguments=arguments, idx=0, **kwargs)

def build_all_requests(docs, config, features, metric_dict={}, prompt=None):
    """
    Build a set of Instances for a task
    :param docs: batch of data with multiple key columns
    """

    instances = []
    for doc_id, doc in enumerate(docs):

        # few shot is give multiple labeled example as prompt
        # if it's zero, then labeled example is empty
        fewshot_ctx = fewshot_context(
            doc,
            0 if config.num_fewshot is None else config.num_fewshot,
            config,
            features,
            prompt
        )

        # each sentence is transformed into instance class
        inst = construct_requests(
            doc=doc,
            ctx=fewshot_ctx,
            config=config,
            prompt=prompt,
            features=features,
            metric_dict=metric_dict,
            metadata=(config["task"], doc_id, config.repeats),
        )

        if not isinstance(inst, list):
            inst = [inst]
        instances.extend(inst)

    return instances

### test###
instances = build_all_requests(train_docs[:10], task_config, features, metric_dict, prompt=None)
for ins in instances[:3]:
    attributes = vars(ins)
    for attr_name, attr_value in attributes.items():
        print(f"{attr_name}: {attr_value}")
    print('\n')

request_type: loglikelihood
doc: {'ind': 4, 'activity_label': 'Removing ice from car', 'ctx_a': 'Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles.', 'ctx_b': 'then', 'ctx': 'Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. then', 'endings': [', the man adds wax to the windshield and cuts it.', ', a person board a ski lift, while two men supporting the head of the person wearing winter clothes snow as the we girls sled.', ', the man puts on a christmas coat, knitted with netting.', ', the man continues removing the snow on his car.'], 'source_id': 'activitynet~v_-1IBHYS3L-Y', 'split': 'train', 'split_type': 'indomain', 'label': '3', 'query': 'Removing ice from car: Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. Then', 'choices': [', the man adds wax to the windshield and cuts it.', ', a person board a ski lift, 

## 1.9 Model Forward

### 1.9.1 Batching Collator

In [207]:
### collator for batching ###
class Collator:
    """A class for reordering and batching elements of an array."""

    def __init__(self, arr, sort_fn, group_fn=lambda x: x[1], grouping=False,):

        self.grouping = grouping
        self.fn = sort_fn
        self.group_fn = lambda x: group_fn(x[1])  # first index are enumerated indices
        self.reorder_indices: List = []
        self.size = len(arr)
        self.arr_with_indices: Iterable[Any] = tuple(enumerate(arr))  # [indices, (arr)]
        if self.grouping is True:
            self.group_by_index()

    def group_by_index(self) -> None:
        self.arr_with_indices = self.group(
            self.arr_with_indices, fn=self.group_fn, values=False)

    def get_batched(self, n=1, batch_fn=None):
        """Generates and yields batches from the reordered array"""
        if self.grouping:
            for (key, values) in self.arr_with_indices.items():
                values = self._reorder(values)
                batch = self.get_chunks(values, n=n, fn=batch_fn)
                yield from batch
        else:
            values = self._reorder(self.arr_with_indices)
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch

    def _reorder(self, arr):
        """Reorders the elements in the array based on the sorting function"""
        arr = sorted(arr, key=lambda x: self.fn(x[1]))
        self.reorder_indices.extend([x[0] for x in arr])
        yield from [x[1] for x in arr]

    def get_original(self, newarr: List) -> List:
        """Restores the original order of elements from the reordered list."""
        res = [None] * self.size
        cov = [False] * self.size
        for ind, v in zip(self.reorder_indices, newarr):
            res[ind] = v
            cov[ind] = True

        assert all(cov)
        return res

    def __len__(self):
        return self.size

    @staticmethod
    def group(arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
        """Groups elements of an iterable based on a provided function."""
        res = collections.defaultdict(list)
        for ob in arr:
            try:
                hashable_dict = tuple(
                    (key, tuple(value)
                        if isinstance(value, collections.abc.Iterable) else value,)
                            for key, value in sorted(fn(ob).items())
                )
                res[hashable_dict].append(ob)
            except TypeError:
                res[fn(ob)].append(ob)
        if not values:
            return res
        return res.values()

    @staticmethod
    def get_chunks(_iter, n: int = 0, fn=None):
        """Divides an iterable into chunks of specified size or based on a given function"""
        arr = []
        _iter = tuple(_iter)
        for i, x in enumerate(_iter):
            arr.append(x)
            if len(arr) == (fn(i, _iter) if fn else n):
                yield arr
                arr = []

        if arr:
            yield arr

### 1.9.2 Helper Function

In [208]:
batch_size = 8
max_length = 1024

In [209]:
def tok_encode(string, tokenizer, left_truncate_len=None, add_special_tokens=False):
    """
    :param add_special_tokens: False if AutoModelForCausalLM
    """
    encoding = tokenizer.encode(string, add_special_tokens=add_special_tokens)

    # left-truncate the encoded context to be at most `left_truncate_len` tokens long
    if left_truncate_len:
        encoding = encoding[-left_truncate_len:]
    return encoding

def _encode_pair(context, continuation, tokenizer):
    n_spaces = len(context) - len(context.rstrip())
    if n_spaces > 0:
        continuation = context[-n_spaces:] + continuation
        context = context[:-n_spaces]

    whole_enc = tok_encode(context + continuation, tokenizer, add_special_tokens=False)
    context_enc = tok_encode(context, tokenizer, add_special_tokens=False)

    context_enc_len = len(context_enc)
    continuation_enc = whole_enc[context_enc_len:]
    return context_enc, continuation_enc

def pad_and_concat(
    max_length, tensors,padding_side="right",):
    """padding a list of tensors given the maximum tensor length in the batch"""

    for i, tensor in enumerate(tensors):
        if len(tensor.shape) == 2:
            tensor = tensor.squeeze(0)
        tensor_len = tensor.shape[0]

        if tensor_len < max_length:
            zeros = torch.zeros(max_length - tensor_len, dtype=torch.long, device=tensor.device)

            if padding_side == "right":
                tensors[i] = torch.cat([tensor, zeros],dim=0,).unsqueeze(0)
            else:
                zeros = torch.zeros(max_length - tensor_len, dtype=torch.long, device=tensor.device)
                tensors[i] = torch.cat([zeros, tensor,], dim=0).unsqueeze(0)

        else:
            tensors[i] = tensor.unsqueeze(0)
    return torch.cat(tensors, dim=0)

def _model_call(inps, model, attn_mask=None, labels=None, is_encoder_decoder=False):
    with torch.no_grad():
        if not is_encoder_decoder:
            return model(inps).logits

def _select_cont_toks(logits, contlen=None, inplen=None, is_encoder_decoder=False):
        if not is_encoder_decoder:
            logits = logits[inplen - contlen : inplen]
        return logits

def make_disjoint_window(pair):
    """output from get_rolling_token_windows and makes the context not overlap with continuation"""
    a, b = pair
    return a[:len(a) - (len(b) - 1)], b

def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):

    if not token_list:
        return
    # +1 offset, going from input->preds
    pred_len = max_seq_len - context_len + 1
    predicted = 0

    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len])
    predicted += first_seq_len

    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1 : window_end - 1],
            token_list[window_end - window_pred_len : window_end],
        )
        predicted += window_pred_len

### 1.9.3 Forward: Log Likelihood

In [210]:
def _loglikelihood_tokens(
        requests, model, disable_tqdm=False, override_bs=None, is_encoder_decoder=False):

    res = []
    def _collate(x):
        """Defines the key for the sorted method"""
        toks = x[1] + x[2]
        return -len(toks), tuple(toks)

    re_ord = Collator(requests, sort_fn=_collate)

    # automatic (variable) batch size detection for vectorization
    # pull longest context sample from request
    n_reordered_requests = len(re_ord)
    chunks = re_ord.get_batched(n=batch_size, batch_fn=None)
    pbar = tqdm(total=len(requests), disable=disable_tqdm)

    for chunk in chunks:
        inps = []
        cont_toks_list = []
        inplens = []
        conts = []
        encoder_attns = []
        padding_len_inp = None
        padding_len_cont = None

        # extract encoded prompt and answer
        for _, context_enc, continuation_enc in chunk:
            # when too long to fit in context, truncate from the left
            if not is_encoder_decoder:
                # (seq_len)
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(max_length + 1):][:-1],
                    dtype=torch.long,
                    device=device)
                (inplen,) = inp.shape

            padding_len_inp = (
                max(padding_len_inp, inplen) if padding_len_inp is not None else inplen)

            # [m, inp_length]
            inps.append(inp)
            cont_toks_list.append(continuation_enc)
            inplens.append(inplen)

        # create encoder attn mask and batched conts, if seq2seq
        call_kwargs = {}
        if not is_encoder_decoder:
            batched_inps = pad_and_concat(padding_len_inp, inps, padding_side="right")

        # [m, seq_len, vocab]
        multi_logits = F.log_softmax(
            _model_call(batched_inps, model, **call_kwargs), dim=-1)

        for (cache_key, _, _), logits, inplen, cont_toks in zip(
            chunk, multi_logits, inplens, cont_toks_list):

            # slice to original seq length, take only logits in the continuation
            contlen = len(cont_toks)
            ctx_len = (
                inplen + (logits.shape[0] - padding_len_inp)
                    if not is_encoder_decoder else None
                )

            logits = _select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
            # [1, seq, vocab]
            logits = logits.unsqueeze(0)

            # Check if per-token argmax is exactly equal to continuation
            greedy_tokens = logits.argmax(dim=-1)
             # [1, seq]
            cont_toks = torch.tensor(cont_toks, dtype=torch.long, device=device).unsqueeze(0)
            max_equal = (greedy_tokens == cont_toks).all()

            # Obtain log-probs at the corresponding continuation token indices
            # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
            logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)

            # Answer: (log prob of answer token, is-exact-match)
            answer = (float(logits.sum()), bool(max_equal))
            res.append(answer)
            pbar.update(1)

    pbar.close()
    return re_ord.get_original(res)

def loglikelihood(requests, model, tokenizer):
    new_reqs = []

    # encode (prompt, answer) pair
    for context, continuation in [req.args for req in requests]:
        if context == "":
            context_enc, continuation_enc = (
                [tokenizer.eos_token_id],
                tok_encode(continuation, tokenizer),
            )
        else:
            context_enc, continuation_enc = _encode_pair(context, continuation, tokenizer)
        new_reqs.append(((context, continuation), context_enc, continuation_enc))

    return _loglikelihood_tokens(new_reqs, model)

### 1.9.4 LogLikelihood Rolling

In [211]:
### forward method: loglikelihood_rolling ###
def loglikelihood_rolling(self, requests, model, tokenizer):
    loglikelihoods = []
    adaptive_batch_size = batch_size

    for (string,) in tqdm([req.args for req in requests]):
        rolling_token_windows = list(
            map(
                make_disjoint_window,
                get_rolling_token_windows(
                    token_list=tok_encode(string, tokenizer),
                    prefix_token=tokenizer.eos_token_id,
                    max_seq_len=max_length,
                    context_len=1,
                ),
            )
        )

        rolling_token_windows = [(None,) + x for x in rolling_token_windows]
        pad_amnt = 0

        string_nll = _loglikelihood_tokens(
            rolling_token_windows,
            model=model,
            disable_tqdm=True,
            override_bs=adaptive_batch_size,
        )

        string_nll = [x[0] for x in string_nll]
        string_nll = sum(string_nll)
        loglikelihoods.append(string_nll)

    return loglikelihoods

### 1.9.6 Testing

In [212]:
### test ###
requests = collections.defaultdict(list)

# group into each req_type. e.g. loglikelihood
for instance in instances:
    reqtype = instance.request_type
    requests[reqtype].append(instance)

# create `K` copies of each request `req` based off `K = req.repeats`
# the repeat is useful when you want to have majority voting on same data
for reqtype, reqs in requests.items():
    cloned_reqs = []
    for req in reqs:
        cloned_reqs.extend([req] * req.repeats)

    # run requests through model
    if reqtype == 'loglikelihood':
        # output: (prob, bool)
        resps = loglikelihood(cloned_reqs, model, tokenizer)
    elif reqtype == 'loglikelihood_rolling':
        resps = loglikelihood_rolling(cloned_reqs, model, tokenizer)

    # put responses from model into a list of length K for each request.
    for x, req in zip(resps, cloned_reqs):
        req.resps.append(x)

print('\n')
print('Results: ', requests['loglikelihood'][0])
print('Request Response: ', requests['loglikelihood'][0].resps)

100%|██████████| 40/40 [00:00<00:00, 50.73it/s]



Results:  Instance(request_type='loglikelihood', doc={'ind': 4, 'activity_label': 'Removing ice from car', 'ctx_a': 'Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles.', 'ctx_b': 'then', 'ctx': 'Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. then', 'endings': [', the man adds wax to the windshield and cuts it.', ', a person board a ski lift, while two men supporting the head of the person wearing winter clothes snow as the we girls sled.', ', the man puts on a christmas coat, knitted with netting.', ', the man continues removing the snow on his car.'], 'source_id': 'activitynet~v_-1IBHYS3L-Y', 'split': 'train', 'split_type': 'indomain', 'label': '3', 'query': 'Removing ice from car: Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. Then', 'choices': [', the man adds wax to the windshield and cuts it.', ', a per




## 1.10 Post-Processing (TBC)

In [240]:
# TODO: run filter apply before post-processing
def process_results(
        doc, results, config, metric_dict, multiple_input=False, multiple_target=False):
    if callable(config.process_results):
        return config.process_results(doc, results)

    result_dict = {}
    use_metric = list(metric_dict['metric_fn_list'].keys())

    if task_config.output_type == "loglikelihood":
        results = results[0]
        ll, is_greedy = results
        return {
            **({"perplexity": ll} if "perplexity" in use_metric else {}),
            **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
        }

    elif task_config.output_type == "loglikelihood_rolling":
        (loglikelihood,) = results
        _words = self.count_words(self.doc_to_target(doc))
        _bytes = self.count_bytes(self.doc_to_target(doc))
        return {
            **(
                {"word_perplexity": (loglikelihood, _words)}
                if "word_perplexity" in use_metric
                else {}
            ),
            **(
                {"byte_perplexity": (loglikelihood, _bytes)}
                if "byte_perplexity" in use_metric
                else {}
            ),
            **(
                {"bits_per_byte": (loglikelihood, _bytes)}
                if "bits_per_byte" in use_metric
                else {}
            ),
        }
    elif task_config.output_type == "multiple_choice":
        lls, is_greedy = zip(*results)

        # the 4 choices for each doc
        choices = doc_to_choice(doc, config, None, list(docs.features.keys()))
        completion_len = np.array([float(len(i)) for i in choices])

        if (2 * len(choices) == len(lls) and
            "acc_mutual_info" in metric_dict['metric_fn_list'].keys()):
            # mutual info.
            lls_unconditional = lls[1::2]
            lls = lls[::2]

        # find the index of the choice with highest prob
        pred = np.argmax(lls)
        pred_norm = np.argmax(lls / completion_len)

        if multiple_input:
            gold = doc_to_text(doc, config, list(docs.features.keys()))
        else:
            gold = doc_to_target(doc, config, None, list(docs.features.keys()))

        ### check error in golden set ###
        gold_index_error = False
        if isinstance(gold, list):
            gold = [i if i < len(choices) else -100 for i in gold]
            if -100 in gold:
                gold_index_error = True
        else:
            if isinstance(gold, int):
                gold = gold if gold < len(choices) else -100
            elif isinstance(gold, str):
                gold = choices.index(gold) if gold in choices else -100

            if gold == -100:
                gold_index_error = True

        if gold_index_error:
            eval_logger.warning(
                f"Label index was not in within range of available choices,"
                f"Sample:\n\n{doc}\n\n"
            )

        # get final metric
        if multiple_target:
            acc = 1.0 if pred in gold else 0.0
            acc_norm = 1.0 if pred_norm in gold else 0.0
            exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
        else:
            acc = 1.0 if pred == gold else 0.0
            acc_norm = 1.0 if pred_norm == gold else 0.0
            exact_match = int(is_greedy[gold]) if gold != -100 else 0

        result_dict = {
            **({"acc": acc} if "acc" in use_metric else {}),
            **({"f1": (gold, pred)} if "f1" in use_metric else {}),
            **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
            **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
            **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
        }

        if "acc_mutual_info" in use_metric:
            lls_mutual_info = [
                ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
            ]
            acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
            result_dict["acc_mutual_info"] = acc_mutual_info

    # elif self.OUTPUT_TYPE == "generate_until":
    #     gold = self.doc_to_target(doc)
    #     result = results[0]
    #     if self.config.doc_to_choice is not None:
    #         # If you set doc_to_choice,
    #         # it assumes that doc_to_target returns a number.
    #         choices = self.doc_to_choice(doc)
    #         gold = choices[gold]
    #     # we expect multiple_targets to be a list.
    #     elif self.multiple_target:
    #         gold = list(gold)
    #     elif type(gold) != type(result):
    #         # cast gold to the same type as result
    #         gold = type(result)(gold)

    #     for metric in self._metric_fn_list.keys():
    #         if self.multiple_target:
    #             # in the case where we have multiple targets,
    #             # return true if any are true
    #             # TODO: this may break for multipLe_target, non zero-or-1 metrics
    #             scores = []
    #             if not isinstance(gold, list):
    #                 # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
    #                 # print(gold)
    #                 gold = [gold]
    #             for gold_option in gold:
    #                 try:
    #                     result_score = self._metric_fn_list[metric](
    #                         references=[gold_option],
    #                         predictions=[result],
    #                         **self._metric_fn_kwargs[metric],
    #                     )
    #                 except (
    #                     TypeError
    #                 ):  # TODO: this is hacky and I don't want to do it
    #                     result_score = self._metric_fn_list[metric](
    #                         [gold_option, result]
    #                     )
    #                 if isinstance(result_score, dict):
    #                     # TODO: this handles the case where HF evaluate returns a dict.
    #                     result_score = result_score[metric]
    #                 scores.append(result_score)
    #             if any(scores):
    #                 result_score = 1.0
    #             else:
    #                 result_score = 0.0
    #         else:
    #             try:
    #                 result_score = self._metric_fn_list[metric](
    #                     references=[gold],
    #                     predictions=[result],
    #                     **self._metric_fn_kwargs[metric],
    #                 )
    #             except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
    #                 result_score = self._metric_fn_list[metric]([gold, result])
    #             if isinstance(result_score, dict):
    #                 # TODO: this handles the case where HF evaluate returns a dict.
    #                 result_score = result_score[metric]
    #         result_dict[metric] = result_score
    # else:
    #     raise ValueError(
    #         f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
    #         "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'",
    #     )

    return result_dict

In [244]:
vals = collections.defaultdict(list)
for doc_id, doc in enumerate(train_docs[:10]):

    # get all the results fro the id (e.g. 4 choices per doc)
    requests = list(filter(lambda x: x.doc_id == doc_id, instances))
    requests.sort(key=lambda x: x.idx)
    metrics = process_results(doc, [req.resps[0] for req in requests], task_config, metric_dict)
    for metric, value in metrics.items():
        vals[('metric')].append(value)

print(vals)

defaultdict(<class 'list'>, {'metric': [1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]})


In [254]:
### Aggregate results over all datapoints ###
# agg_fn = metric_dict['aggregation_list']

# stderr = stderr_for_metric(
#     metric=agg_fn,
#     bootstrap_iters=0


# )
# stderr

## 1.10 Benchmark

In [None]:
rnd = random.Random(100)

def eval(config):
    task_config = TaskConfig(**config)
    metric_dict = process_metric(task_config)
    filter_ops = create_filter_ops(task_config)
    dataset = download(task_config.dataset_path, task_config.dataset_name)

    # TODO: add use-prompt
    # if task_config.use_prompt:
    #     prompt = get_prompt(config.use_prompt, config['dataset_path'], config['dataset_name'])

    ### processed documents ###
    if task_config.training_split is not None:
        if task_config.process_docs is not None:
            docs = task_config.process_docs(dataset[task_config.training_split])
        else:
            docs = dataset[task_config.training_split]
        train_docs = list(docs)

    if task_config.test_split is not None:
        if task_config.process_docs is not None:
            docs = task_config.process_docs(dataset[task_config.test_split])
        else:
            docs = dataset[task_config.test_split]
        test_docs = list(docs)
        test_samples = rnd.sample(test_docs, 5)

    elif task_config.validation_split is not None:
        if task_config.process_docs is not None:
            docs = task_config.process_docs(dataset[task_config.validation_split])
        else:
            docs = dataset[task_config.validation_split]
        test_docs = list(docs)
        test_samples = rnd.sample(test_docs, 5)

    features = list(docs.features.keys())

    # TODO: add few-shot
    # if task_config.fewshot_split is not None:
    #     few_shot = dataset[task_config.fewshot_split]

    ### processing instances ###
    instances = build_all_requests(test_docs[:100], task_config, features, metric_dict, prompt=None)

    ### get log prob of each tokens ###
    requests = collections.defaultdict(list)
    # group into each req_type. e.g. loglikelihood
    for instance in instances:
        reqtype = instance.request_type
        requests[reqtype].append(instance)

    # create `K` copies of each request `req` based off `K = req.repeats`
    # the repeat is useful when you want to have majority voting on same data
    for reqtype, reqs in requests.items():
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        if reqtype == 'loglikelihood':
            # output: (prob, bool)
            resps = loglikelihood(requests[reqtype], model, tokenizer)
        elif reqtype == 'loglikelihood_rolling':
            resps = loglikelihood_rolling(requests[reqtype], model, tokenizer)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

### 1.5.3 HellaSwag

In [None]:
def preprocess(text):
    text = text.strip()
    # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
    text = text.replace(" [title]", ". ")
    text = re.sub("\\[.*?\\]", "", text)
    text = text.replace("  ", " ")
    return text

def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    def _process_doc(doc):
        ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
        out_doc = {
            "query": preprocess(doc["activity_label"] + ": " + ctx),
            "choices": [preprocess(ending) for ending in doc["endings"]],
            "gold": int(doc["label"]),
        }
        return out_doc
    return dataset.map(_process_doc)

config = dict(
    group=['multiple_choice'],
    task='hellaswag',
    dataset_path='hellaswag',
    dataset_name=None,
    output_type='multiple_choice',
    training_split='train',
    validation_split='validation',
    test_split=None,
    process_docs=process_docs,
    doc_to_text="{{query}}",
    doc_to_target="{{label}}",
    doc_to_choice="choices",
    metric_list=[
        dict(metric='acc', aggregation='mean', higher_is_better=True),
        dict(metric='acc_norm', aggregation='mean', higher_is_better=True),
    ],
    metadata=dict(version=1.0),
)

In [None]:
# eval(config)