In [None]:
#| default_exp metrics

In [None]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *

In [None]:
#| export
from fastcore.utils import *

In [None]:
#| export
import numpy as np
import torch
from rouge import Rouge
import evaluate

## Metrics Class

In [None]:
#| export
class Metrics:
    """A set of Hugging Face `evaluate` metrics that are computed together.

    `names` is stored unchanged on `self.metrics_names`. It is validated by
    building `evaluate.combine(names)`. An empty `names` skips validation.

    Raises:
        ValueError: if `evaluate.combine` fails for `names`. The original
            exception is chained as `__cause__`, so the underlying reason
            stays visible in the traceback.
    """
    def __init__(self, names: list):
        self.metrics_names = names
        # Nothing to validate for an empty selection (and `evaluate` is not touched).
        if len(names) == 0:
            return
        try:
            # Combining loads every requested module, so it doubles as a name check.
            evaluate.combine(names)
        except Exception as e:
            # Chain instead of printing, so callers see why validation failed.
            raise ValueError(f"Invalid metric names, the available metrics are {evaluate.list_evaluation_modules('metric')}") from e

In [None]:
#| export
@patch
def compute(self:Metrics, y_true, y_pred, **kwargs):
    """Score `y_pred` against `y_true` with every metric in `self.metrics_names`.

    The metrics are re-combined with `evaluate.combine` on each call. The
    combined module is kept on `self.metrics`. Extra keyword arguments (e.g.
    `tokenizer`) are forwarded to its `compute`, whose result is returned.
    """
    combined = evaluate.combine(self.metrics_names)
    self.metrics = combined
    return combined.compute(references=y_true, predictions=y_pred, **kwargs)

In [None]:
# Demo: bundle BLEU and ROUGE so one `compute` call reports both.
metrics = Metrics(names=["bleu", "rouge"])

In [None]:
from transformers import AutoTokenizer
# Demo tokenizer for the metric examples below.
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
# Reuse the end-of-sequence token as the padding token; `rouge_score` reads
# `tokenizer.pad_token_id`. NOTE(review): presumably needed because GPT-2 has no pad token by default.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Bound `tokenize` method used as a plain callable; it is later passed to
# `metrics.compute(..., tokenizer=tokenize_func)`.
tokenize_func = tokenizer.tokenize
# Quick look at the subword tokens it produces.
tokenize_func("hello world")

['hello', 'Ġworld']

In [None]:
# Each prediction is paired with a list of acceptable references.
predictions = ["hello there general kenobi", "foo bar foobar"]
references = [["hello th", "hello there !"], ["foo bar foobar"]]
res = metrics.compute(
    y_pred=predictions,
    y_true=references,
    tokenizer=tokenize_func,
)


In [None]:
# Show the headline BLEU score and the ROUGE-L score.
(res["bleu"], res["rougeL"])

(0.3976353643835253, 0.7222222222222222)

## Other utils

In [None]:
#| hide

def rouge_score(hyp_ids, ref_ids, tokenizer):
    """Return the mean ROUGE-L F-score between decoded hypotheses and references.

    `hyp_ids` and `ref_ids` are token-id tensors. Entries equal to -100 are
    replaced with `tokenizer.pad_token_id` before
    `tokenizer.batch_decode(..., skip_special_tokens=True)`. Decoded texts are
    stripped and lowercased before scoring.

    A hypothesis contributes 0 to the sum, but still counts in the average, if
    it decodes to blank text or if `Rouge.get_scores` raises `ValueError` for
    it. An empty batch returns 0.0.

    Raises:
        ValueError: if the two batches decode to a different number of texts.
    """
    def _decode(ids):
        # Replace -100 entries (presumably ignored/padded positions) so they can be decoded.
        ids = torch.where(ids != -100, ids, tokenizer.pad_token_id)
        return tokenizer.batch_decode(ids, skip_special_tokens=True)

    hyps = _decode(hyp_ids)
    refs = _decode(ref_ids)
    if len(hyps) != len(refs):
        raise ValueError(f"Got {len(hyps)} hypotheses but {len(refs)} references")

    # Blank hypotheses are dropped here but still counted in the denominator below.
    pairs = [(h.strip().lower(), r.strip().lower()) for h, r in zip(hyps, refs) if h.strip()]
    if not pairs:
        # Nothing scorable (including an empty batch): every item contributes 0.
        return 0.0

    rouge = Rouge()
    total = 0.0
    for h, r in pairs:
        try:
            total += rouge.get_scores(h, r)[0]['rouge-l']['f']
        except ValueError:
            print("Error in calculating rouge score")
    return total / len(hyps)

In [None]:
#| hide
# Export this notebook's `#| export` cells with nbdev.
import nbdev
nbdev.nbdev_export()