Skip to content

Commit

Permalink
Mod: Refactor internal FER code to match aac-metrics style.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Apr 15, 2024
1 parent e1ec091 commit 0d309d0
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/aac_metrics/functional/fense.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def _fense_from_outputs(
device=fense_scores.device,
)

fense_outs_corpus = sbert_sim_outs_corpus | fer_outs_corpus | {"fense": fense_score}
fense_outs_sents = sbert_sim_outs_sents | fer_outs_sents | {"fense": fense_scores}
fense_outs_corpus = sbert_sim_outs_corpus | fer_outs_corpus | {"fense": fense_score} # type: ignore
fense_outs_sents = sbert_sim_outs_sents | fer_outs_sents | {"fense": fense_scores} # type: ignore
fense_outs = fense_outs_corpus, fense_outs_sents

return fense_outs
Expand Down
32 changes: 17 additions & 15 deletions src/aac_metrics/functional/fer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import hashlib
import logging
import os
import os.path as osp
import re
from collections import namedtuple
from os import environ, makedirs
from os.path import exists, expanduser, join
from typing import Mapping, Optional, TypedDict, Union

import numpy as np
Expand Down Expand Up @@ -150,7 +149,7 @@ def fer(
# Compute and apply fluency error detection penalty
probs_outs_sents = __detect_error_sents(
echecker=echecker,
echecker_tokenizer=echecker_tokenizer,
echecker_tokenizer=echecker_tokenizer, # type: ignore
sents=candidates,
batch_size=batch_size,
device=device,
Expand Down Expand Up @@ -182,7 +181,7 @@ def fer(

fer_outs = fer_outs_corpus, fer_outs_sents

return fer_outs
return fer_outs # type: ignore
else:
return fer_score

Expand Down Expand Up @@ -344,27 +343,30 @@ def __fetch_remote(
use_proxy: bool = False,
proxies: Optional[dict[str, str]] = None,
) -> str:
file_path = remote.filename if dirname is None else join(dirname, remote.filename)
if dirname is None:
file_path = remote.filename
else:
file_path = osp.join(dirname, remote.filename)

file_path = __download_with_bar(remote.url, file_path, use_proxy, proxies)
checksum = __sha256(file_path)
if remote.checksum != checksum:
raise IOError(
"{} has an SHA256 checksum ({}) "
"differing from expected ({}), "
"file may be corrupted.".format(file_path, checksum, remote.checksum)
raise RuntimeError(
f"{file_path} has an SHA256 checksum ({checksum}) "
f"differing from expected ({remote.checksum}), "
f"file may be corrupted."
)

return file_path


def __get_data_home(data_home: Optional[str] = None) -> str:
if data_home is None:
DEFAULT_DATA_HOME = join(torch.hub.get_dir(), "fense_data")
data_home = environ.get("FENSE_DATA", DEFAULT_DATA_HOME)
DEFAULT_DATA_HOME = osp.join(torch.hub.get_dir(), "fense_data")
data_home = os.getenv("FENSE_DATA", DEFAULT_DATA_HOME)

data_home: str
data_home = expanduser(data_home)
if not exists(data_home):
makedirs(data_home)
data_home = osp.expanduser(data_home)
os.makedirs(data_home, exist_ok=True)
return data_home


Expand Down

0 comments on commit 0d309d0

Please sign in to comment.