Add: Check transformers version in FER metric and refactor internal code.
Labbeti committed Dec 21, 2023
1 parent 4387f23 commit ea508c8
Showing 1 changed file with 45 additions and 64 deletions.
--- a/src/aac_metrics/functional/fer.py
+++ b/src/aac_metrics/functional/fer.py
@@ -14,6 +14,7 @@
 
 import numpy as np
 import torch
+import transformers
 
 from torch import nn, Tensor
 from tqdm import tqdm
@@ -26,12 +27,11 @@
 from aac_metrics.utils.globals import _get_device
 
 
-# config according to the settings on your computer, this should be default setting of shadowsocks
-DEFAULT_PROXIES = {
+_DEFAULT_PROXIES = {
     "http": "socks5h://127.0.0.1:1080",
     "https": "socks5h://127.0.0.1:1080",
 }
-PRETRAIN_ECHECKERS_DICT = {
+_PRETRAIN_ECHECKERS_DICT = {
     "echecker_clotho_audiocaps_base": (
         "https://github.com/blmoistawinde/fense/releases/download/V0.1/echecker_clotho_audiocaps_base.ckpt",
         "1a719f090af70614bbdb9f9437530b7e133c48cfa4a58d964de0d47fc974a2fa",
@@ -41,13 +41,7 @@
         "90ed0ac5033ec497ec66d4f68588053813e085671136dae312097c96c504f673",
     ),
 }
 
-RemoteFileMetadata = namedtuple("RemoteFileMetadata", ["filename", "url", "checksum"])
-
-pylog = logging.getLogger(__name__)
-
-
-ERROR_NAMES = (
+_ERROR_NAMES = (
     "add_tail",
     "repeat_event",
     "repeat_adv",
@@ -56,6 +50,10 @@
     "error",
 )
 
+_RemoteFileMetadata = namedtuple("RemoteFileMetadata", ["filename", "url", "checksum"])
+
+pylog = logging.getLogger(__name__)
+
 
 class BERTFlatClassifier(nn.Module):
     def __init__(self, model_type: str, num_classes: int = 5) -> None:
@@ -131,18 +129,29 @@ def fer(
         error_msg = f"Invalid candidates type. (expected list[str], found {candidates.__class__.__name__})"
         raise ValueError(error_msg)
 
+    version = transformers.__version__
+    major, minor, _patch = map(int, version.split("."))
+    if major > 4 or (major == 4 and minor > 30):
+        raise ValueError(
+            f"Invalid transformers version {version} for FER metric. Please use a version < 4.31.0."
+        )
+
     # Init models
     echecker, echecker_tokenizer = _load_echecker_and_tokenizer(
-        echecker, echecker_tokenizer, device, reset_state, verbose
+        echecker=echecker,
+        echecker_tokenizer=echecker_tokenizer,
+        device=device,
+        reset_state=reset_state,
+        verbose=verbose,
     )
 
     # Compute and apply fluency error detection penalty
     probs_outs_sents = __detect_error_sents(
-        echecker,
-        echecker_tokenizer,  # type: ignore
-        candidates,
-        batch_size,
-        device,
+        echecker=echecker,
+        echecker_tokenizer=echecker_tokenizer,
+        sents=candidates,
+        batch_size=batch_size,
+        device=device,
     )
     fer_scores = (probs_outs_sents["error"] > error_threshold).astype(float)
 
@@ -226,10 +235,10 @@ def __detect_error_sents(
         # batch_logits: (bsize, num_classes=6)
         # note: fix error in the original fense code: https://github.com/blmoistawinde/fense/blob/main/fense/evaluator.py#L69
         probs = logits.sigmoid().transpose(0, 1).cpu().numpy()
-        probs_dic: dict[str, np.ndarray] = dict(zip(ERROR_NAMES, probs))
+        probs_dic: dict[str, np.ndarray] = dict(zip(_ERROR_NAMES, probs))
 
     else:
-        dic_lst_probs = {name: [] for name in ERROR_NAMES}
+        dic_lst_probs = {name: [] for name in _ERROR_NAMES}
 
         for i in range(0, len(sents), batch_size):
             batch = __infer_preprocess(
Expand Down Expand Up @@ -257,11 +266,10 @@ def __detect_error_sents(


def __check_download_resource(
remote: RemoteFileMetadata,
remote: _RemoteFileMetadata,
use_proxy: bool = False,
proxies: Optional[dict[str, str]] = None,
) -> str:
proxies = DEFAULT_PROXIES if use_proxy and proxies is None else proxies
data_home = __get_data_home()
file_path = os.path.join(data_home, remote.filename)
if not os.path.exists(file_path):
@@ -286,10 +294,10 @@ def __infer_preprocess(
 
 
 def __download(
-    remote: RemoteFileMetadata,
+    remote: _RemoteFileMetadata,
     file_path: Optional[str] = None,
     use_proxy: bool = False,
-    proxies: Optional[dict[str, str]] = DEFAULT_PROXIES,
+    proxies: Optional[dict[str, str]] = None,
 ) -> str:
     data_home = __get_data_home()
     file_path = __fetch_remote(remote, data_home, use_proxy, proxies)
@@ -299,8 +307,12 @@ def __download(
 def __download_with_bar(
     url: str,
     file_path: str,
-    proxies: Optional[dict[str, str]] = DEFAULT_PROXIES,
+    use_proxy: bool = False,
+    proxies: Optional[dict[str, str]] = None,
 ) -> str:
+    if use_proxy and proxies is None:
+        proxies = _DEFAULT_PROXIES
+
     # Streaming, so we can iterate over the response.
     response = requests.get(url, stream=True, proxies=proxies)
     total_size_in_bytes = int(response.headers.get("content-length", 0))
@@ -317,31 +329,13 @@ def __download_with_bar(
 
 
 def __fetch_remote(
-    remote: RemoteFileMetadata,
+    remote: _RemoteFileMetadata,
     dirname: Optional[str] = None,
     use_proxy: bool = False,
-    proxies: Optional[dict[str, str]] = DEFAULT_PROXIES,
+    proxies: Optional[dict[str, str]] = None,
 ) -> str:
-    """Helper function to download a remote dataset into path
-    Fetch a dataset pointed by remote's url, save into path using remote's
-    filename and ensure its integrity based on the SHA256 Checksum of the
-    downloaded file.
-    Parameters
-    ----------
-    remote : RemoteFileMetadata
-        Named tuple containing remote dataset meta information: url, filename
-        and checksum
-    dirname : string
-        Directory to save the file to.
-    Returns
-    -------
-    file_path: string
-        Full path of the created file.
-    """
-
     file_path = remote.filename if dirname is None else join(dirname, remote.filename)
-    proxies = None if not use_proxy else proxies
-    file_path = __download_with_bar(remote.url, file_path, proxies)
+    file_path = __download_with_bar(remote.url, file_path, use_proxy, proxies)
     checksum = __sha256(file_path)
     if remote.checksum != checksum:
         raise IOError(
@@ -352,23 +346,10 @@
     return file_path
 
 
-def __get_data_home(data_home: Optional[str] = None) -> str:  # type: ignore
-    """Return the path of the scikit-learn data dir.
-    This folder is used by some large dataset loaders to avoid downloading the
-    data several times.
-    By default the data dir is set to a folder named 'fense_data' in the
-    user home folder.
-    Alternatively, it can be set by the 'FENSE_DATA' environment
-    variable or programmatically by giving an explicit folder path. The '~'
-    symbol is expanded to the user home folder.
-    If the folder does not already exist, it is automatically created.
-    Parameters
-    ----------
-    data_home : str | None
-        The path to data dir.
-    """
+def __get_data_home(data_home: Optional[str] = None) -> str:
     if data_home is None:
-        data_home = environ.get("FENSE_DATA", join(torch.hub.get_dir(), "fense_data"))
+        DEFAULT_DATA_HOME = join(torch.hub.get_dir(), "fense_data")
+        data_home = environ.get("FENSE_DATA", DEFAULT_DATA_HOME)
 
     data_home: str
     data_home = expanduser(data_home)
@@ -384,15 +365,15 @@ def __load_pretrain_echecker(
     proxies: Optional[dict[str, str]] = None,
     verbose: int = 0,
 ) -> BERTFlatClassifier:
-    if echecker_model not in PRETRAIN_ECHECKERS_DICT:
+    if echecker_model not in _PRETRAIN_ECHECKERS_DICT:
         raise ValueError(
-            f"Invalid argument {echecker_model=}. (expected one of {tuple(PRETRAIN_ECHECKERS_DICT.keys())})"
+            f"Invalid argument {echecker_model=}. (expected one of {tuple(_PRETRAIN_ECHECKERS_DICT.keys())})"
         )
 
     device = _get_device(device)
     tfmers_logging.set_verbosity_error()  # suppress loading warnings
-    url, checksum = PRETRAIN_ECHECKERS_DICT[echecker_model]
-    remote = RemoteFileMetadata(
+    url, checksum = _PRETRAIN_ECHECKERS_DICT[echecker_model]
+    remote = _RemoteFileMetadata(
         filename=f"{echecker_model}.ckpt", url=url, checksum=checksum
     )
     file_path = __check_download_resource(remote, use_proxy, proxies)
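A note on the new version gate: it parses transformers.__version__ by splitting on "." and casting each component to int, which assumes a plain MAJOR.MINOR.PATCH string; a suffixed version such as "4.30.0.dev0" would make the unpacking or the int() cast fail. A more tolerant sketch of the same check, using the packaging library instead of the commit's manual parsing (an alternative, not what the commit ships):

# Hypothetical variant of the commit's version gate; packaging.version
# handles pre-/post-release suffixes that a plain int() cast rejects.
import transformers
from packaging.version import Version

version = transformers.__version__
if Version(version) >= Version("4.31.0"):
    raise ValueError(
        f"Invalid transformers version {version} for FER metric. "
        "Please use a version < 4.31.0."
    )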

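For context, a minimal usage sketch of the patched metric. It assumes the aac-metrics convention that a functional metric called with default arguments returns a pair of corpus-level and sentence-level score dictionaries; the defaults themselves (echecker model, device, batch size) are not visible in this diff.

from aac_metrics.functional.fer import fer

candidates = [
    "a man is speaking while a dog barks in the background",
    "the the sound of of rain rain",  # repeated words, likely flagged as a fluency error
]

# All keyword arguments are left at their defaults here (an assumption);
# the pretrained error-checker checkpoint is downloaded on first use.
corpus_scores, sents_scores = fer(candidates)
print(corpus_scores)  # e.g. {"fer": tensor(0.5)}
print(sents_scores)   # e.g. {"fer": tensor([0., 1.])}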