Skip to content

Commit

Permalink
Mod: Rename flat_list to flat_list_of_list and unflat_list to unflat_list_of_list.
Browse files Browse the repository at this point in the history
  • Loading branch information
LABBE Etienne committed May 17, 2024
1 parent 53a4cfd commit 1e8bcec
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 22 deletions.
10 changes: 7 additions & 3 deletions src/aac_metrics/functional/bert_score_mrefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from transformers.models.auto.tokenization_auto import AutoTokenizer

from aac_metrics.utils.checks import check_metric_inputs
from aac_metrics.utils.collections import duplicate_list, flat_list, unflat_list
from aac_metrics.utils.collections import (
duplicate_list,
flat_list_of_list,
unflat_list_of_list,
)
from aac_metrics.utils.globals import _get_device

DEFAULT_BERT_SCORE_MODEL = _DEFAULT_MODEL
Expand Down Expand Up @@ -102,7 +106,7 @@ def bert_score_mrefs(
)

device = _get_device(device)
flat_mrefs, sizes = flat_list(mult_references)
flat_mrefs, sizes = flat_list_of_list(mult_references)
duplicated_cands = duplicate_list(candidates, sizes)
assert len(duplicated_cands) == len(flat_mrefs)

Expand Down Expand Up @@ -134,7 +138,7 @@ def bert_score_mrefs(
sents_scores = {k: [v] for k, v in sents_scores.items()}

# sents_scores keys: "precision", "recall", "f1"
sents_scores = {k: unflat_list(v, sizes) for k, v in sents_scores.items()} # type: ignore
sents_scores = {k: unflat_list_of_list(v, sizes) for k, v in sents_scores.items()} # type: ignore

if not return_all_scores:
sents_scores = {"f1": sents_scores["f1"]}
Expand Down
4 changes: 2 additions & 2 deletions src/aac_metrics/utils/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
KeyMode = Literal["intersect", "same", "union"]


def flat_list(lst: list[list[T]]) -> tuple[list[T], list[int]]:
def flat_list_of_list(lst: list[list[T]]) -> tuple[list[T], list[int]]:
"""Return a flat version of the input list of sublists with each sublist size."""
flatten_lst = [element for sublst in lst for element in sublst]
sizes = [len(sents) for sents in lst]
return flatten_lst, sizes


def unflat_list(flatten_lst: list[T], sizes: list[int]) -> list[list[T]]:
def unflat_list_of_list(flatten_lst: list[T], sizes: list[int]) -> list[list[T]]:
"""Unflat a list to a list of sublists of given sizes."""
lst = []
start = 0
Expand Down
14 changes: 4 additions & 10 deletions src/aac_metrics/utils/tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,12 @@
import subprocess
import tempfile
import time

from pathlib import Path
from typing import Any, Hashable, Iterable, Optional, Union

from aac_metrics.utils.checks import check_java_path, is_mono_sents
from aac_metrics.utils.collections import flat_list, unflat_list
from aac_metrics.utils.globals import (
_get_cache_path,
_get_java_path,
_get_tmp_path,
)

from aac_metrics.utils.collections import flat_list_of_list, unflat_list_of_list
from aac_metrics.utils.globals import _get_cache_path, _get_java_path, _get_tmp_path

pylog = logging.getLogger(__name__)

Expand Down Expand Up @@ -249,7 +243,7 @@ def preprocess_mult_sents(
:param verbose: The verbose level. defaults to 0.
:returns: The multiple sentences processed by the tokenizer.
"""
flatten_sents, sizes = flat_list(mult_sentences)
flatten_sents, sizes = flat_list_of_list(mult_sentences)
flatten_sents = preprocess_mono_sents(
sentences=flatten_sents,
cache_path=cache_path,
Expand All @@ -259,5 +253,5 @@ def preprocess_mult_sents(
normalize_apostrophe=normalize_apostrophe,
verbose=verbose,
)
mult_sentences = unflat_list(flatten_sents, sizes)
mult_sentences = unflat_list_of_list(flatten_sents, sizes)
return mult_sentences
13 changes: 6 additions & 7 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@

import random
import unittest

from unittest import TestCase

from aac_metrics.utils.checks import (
MAX_JAVA_MAJOR_VERSION,
MIN_JAVA_MAJOR_VERSION,
_check_java_version,
is_mono_sents,
is_mult_sents,
_check_java_version,
MIN_JAVA_MAJOR_VERSION,
MAX_JAVA_MAJOR_VERSION,
)
from aac_metrics.utils.collections import flat_list, unflat_list
from aac_metrics.utils.collections import flat_list_of_list, unflat_list_of_list


class TestUtils(TestCase):
Expand All @@ -29,13 +28,13 @@ def test_misc_functions_1(self) -> None:

self.assertTrue(is_mult_sents(lst))

flatten, sizes = flat_list(lst)
flatten, sizes = flat_list_of_list(lst)

self.assertTrue(is_mono_sents(flatten))
self.assertEqual(len(lst), len(sizes))
self.assertEqual(len(flatten), sum(sizes))

unflat = unflat_list(flatten, sizes)
unflat = unflat_list_of_list(flatten, sizes)

self.assertTrue(is_mult_sents(unflat))
self.assertEqual(len(lst), len(unflat))
Expand Down

0 comments on commit 1e8bcec

Please sign in to comment.