In [None]:
#| default_exp metrics

In [None]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *

In [None]:
#| export
import numpy as np
import torch
from rouge import Rouge
import evaluate
from fastcore.utils import *


In [None]:
#| export
import importlib
def get_cls(module_name, class_name):
    """Return the attribute `class_name` of the module `module_name`.

    The module is imported by its dotted name with `importlib.import_module`
    and the attribute is fetched with `getattr`, so any module-level object
    (class, function, constant) can be retrieved this way.
    """
    return getattr(importlib.import_module(module_name), class_name)

## Metrics Class

In [None]:
#| export
class Metrics:
    """Holds a list of metric names and the latest value stored for each.

    Parameters
    ----------
    lst_metrics_names : iterable of str, or str
        Names of the metrics to track. A single name may be given as a plain
        string. Any iterable (list, tuple, generator, ...) is accepted.

    Attributes
    ----------
    lst_metrics_names : list of str
        Private copy of the requested names, in the order given.
    metrics : dict
        Maps every metric name to its current value; each starts at 0.
    """
    def __init__(self, lst_metrics_names):
        # A bare string would otherwise be iterated character by character.
        if isinstance(lst_metrics_names, str):
            lst_metrics_names = [lst_metrics_names]
        # Materialise once: a generator would be exhausted by the dict build
        # below, and a private copy stays in sync with `self.metrics` even if
        # the caller mutates their own list afterwards.
        self.lst_metrics_names = list(lst_metrics_names)
        self.metrics = {metric_name: 0 for metric_name in self.lst_metrics_names}

In [None]:
#| export
@patch
def prepare_sequence(self: Metrics, y_true):
    """Convert a sequence of labels to a NumPy array where possible.

    Parameters
    ----------
    y_true : torch.Tensor, list, tuple, or any other object
        The values to convert.

    Returns
    -------
    A NumPy array when `y_true` is a tensor, list or tuple; otherwise
    `y_true` itself, unchanged.
    """
    if isinstance(y_true, torch.Tensor):
        # `.numpy()` raises RuntimeError on tensors that require grad, so the
        # tensor is detached from the autograd graph before being moved to
        # the CPU and converted.
        return y_true.detach().cpu().numpy()
    if isinstance(y_true, (list, tuple)):
        return np.array(y_true)
    # Anything else (e.g. an existing ndarray) is passed through untouched.
    return y_true

In [None]:
#| export
@patch
def compute(self: Metrics, y_true, y_pred):
    """Evaluate every configured metric on `y_true` versus `y_pred`.

    Both inputs are first passed through `self.prepare_sequence`. Each name in
    `self.lst_metrics_names` is then resolved to a callable in
    `sklearn.metrics` and called as `metric(y_true, y_pred)`; the result is
    stored in `self.metrics` under that name.

    Parameters
    ----------
    y_true : reference values, in any form `prepare_sequence` accepts.
    y_pred : predicted values, in any form `prepare_sequence` accepts.

    Returns
    -------
    dict
        A snapshot mapping each metric name to the value just computed.

    Raises
    ------
    AttributeError
        If a metric name is not an attribute of `sklearn.metrics`.
    """
    y_true = self.prepare_sequence(y_true)
    y_pred = self.prepare_sequence(y_pred)

    for metric_name in self.lst_metrics_names:
        metric = get_cls('sklearn.metrics', metric_name)
        self.metrics[metric_name] = metric(y_true, y_pred)
    # Hand back a shallow copy: returning `self.metrics` itself would make
    # every result collected across calls alias one dict that the next call
    # overwrites.
    return dict(self.metrics)

In [None]:
# Usage example: score a small binary-classification toy problem.
metric_names = ['accuracy_score', 'f1_score', 'precision_score', 'recall_score']
m = Metrics(metric_names)
y_true, y_pred = [0, 1, 1, 0], [0, 1, 0, 0]
m.compute(y_true, y_pred)

{'accuracy_score': 0.75,
 'f1_score': 0.6666666666666666,
 'precision_score': 1.0,
 'recall_score': 0.5}

In [None]:
# #| export
# class Metrics:
#     def __init__(self, names: list):
#         self.metrics_names = names
#         def validate_names(names):
#             if len(names) == 0:
#                 return True
#             try:
#                 evaluate.combine(names)
#                 return True
#             except Exception as e:
#                 print(e)
#                 return False
#         is_valid = validate_names(self.metrics_names)
#         if not is_valid:
#             raise ValueError(f"Invalid metric names, the available metrics are {evaluate.list_evaluation_modules('metric')}")

In [None]:
# #| export
# @patch
# def prepare_targets_llm(self: Metrics, y_true, y_pred, tokenizer= None):

#     if hasattr(tokenizer, "pad_token_id"):
#         padding_mask = (y_true != tokenizer.pad_token_id)  # Shape: (batch_size, seq_length)
#         padding_mask_flat = padding_mask.view(-1)  # Flatten the mask
#         # Apply the mask
#         y_pred = y_pred[padding_mask_flat]
#         y_true = y_true[padding_mask_flat]

#     y_true = y_true.cpu() if isinstance(y_true, torch.Tensor) else y_true
#     y_pred = y_pred.cpu() if isinstance(y_pred, torch.Tensor) else y_pred
#     return y_true, y_pred

In [None]:
# #| export
# @patch
# def compute(self: Metrics, y_true, y_pred, tokenizer= None,  **kwargs):
#     self.metrics = evaluate.combine(self.metrics_names)
#     if tokenizer:
#         y_true, y_pred = self.prepare_targets_llm(y_true, y_pred, tokenizer)
    
#     return self.metrics.compute(predictions= y_pred, references= y_true, **kwargs)

In [None]:
# metrics = Metrics(["bleu", "rouge"])

In [None]:
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
# tokenizer.pad_token = tokenizer.eos_token

In [None]:
# tokenize_func = tokenizer.tokenize
# tokenize_func("hello world")

['hello', 'Ġworld']

In [None]:
# references = [["hello th", "hello there !"], ["foo bar foobar"]]
# predictions = ["hello there general kenobi", "foo bar foobar"]
# res = metrics.compute(y_true=references, y_pred=predictions, tokenizer=tokenize_func)


In [None]:
# type(res)

dict

In [None]:
# res['bleu'], res['rougeL']

(0.3976353643835253, 0.7222222222222222)

## Other utils

In [None]:
# #| hide

# def rouge_score(hyp_ids, ref_ids, tokenizer):
#     rouge = Rouge()
#     hyps = torch.where(hyp_ids != -100, hyp_ids, tokenizer.pad_token_id)
#     refs = torch.where(ref_ids != -100, ref_ids, tokenizer.pad_token_id)

#     hyps = tokenizer.batch_decode(hyps, skip_special_tokens=True)
#     refs = tokenizer.batch_decode(refs, skip_special_tokens=True)
    
#     batch_rouge = 0
#     for i in range(len(hyps)):
#         if len(hyps[i].strip()) == 0:
#             continue
        
#         else:
#             h = hyps[i].strip().lower()
#             r = refs[i].strip().lower()
#             try:
#                 item_rouge = rouge.get_scores(h, r)[0]['rouge-l']['f']
#             except ValueError:
#                 print("Error in calculating rouge score")
#                 item_rouge = 0

#             batch_rouge += item_rouge

#     rouge_score = batch_rouge / len(hyps)
    
#     return rouge_score

In [None]:
#| hide
# Export the cells of this notebook marked `#| export` to the library module.
import nbdev
nbdev.nbdev_export()