Skip to content

Commit

Permalink
2.2.10 (#624)
Browse files Browse the repository at this point in the history
* Bug fixes
  • Loading branch information
mmcauliffe committed Apr 26, 2023
1 parent 9012a5e commit c9c4022
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 10 deletions.
6 changes: 6 additions & 0 deletions docs/source/changelog/changelog_2.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
2.2 Changelog
*************

2.2.10
======

- Fixed a crash in speaker diarization
- Updated alignment evaluation to export confusion counts

2.2.9
=====
- Fixed a bug in pronunciation probability training that was causing all probabilities of following silence to be 0
Expand Down
2 changes: 1 addition & 1 deletion docs/source/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ To use the Docker image of MFA:

1. Run :code:`docker image pull mmcauliffe/montreal-forced-aligner:latest`
2. Enter the interactive docker shell via :code:`docker run -it -v /path/to/data/directory:/data mmcauliffe/montreal-forced-aligner:latest`
3. Once you are in the shell, you can run MFA commands as normal (e.g., :code:`mfa align ...`). You may need to download any pretrained models you want to use each session (e.g., :code:`mfa download acoustic english_mfa`)
3. Once you are in the shell, you can run MFA commands as normal (e.g., :code:`mfa align ...`). You may need to download any pretrained models you want to use each session (e.g., :code:`mfa model download acoustic english_mfa`)

.. important::

Expand Down
23 changes: 20 additions & 3 deletions montreal_forced_aligner/alignment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,11 @@ def export_textgrids(
f"Check {os.path.join(self.export_output_directory, 'output_errors.txt')} "
f"for more details"
)
output_textgrid_writing_errors(self.export_output_directory, error_dict)
output_textgrid_writing_errors(self.export_output_directory, error_dict)
if GLOBAL_CONFIG.current_profile.debug:
for k, v in error_dict.items():
print(k)
raise v
logger.info(f"Finished exporting TextGrids to {self.export_output_directory}!")
logger.debug(f"Exported TextGrids in a total of {time.time() - begin:.3f} seconds")

Expand Down Expand Up @@ -1279,13 +1283,21 @@ def evaluate_alignments(
output_directory,
f"{comparison_source.name}_{reference_source.name}_evaluation.csv",
)
confusion_path = os.path.join(
output_directory,
f"{comparison_source.name}_{reference_source.name}_confusions.csv",
)
else:
self._current_workflow = "evaluation"
os.makedirs(self.working_log_directory, exist_ok=True)
csv_path = os.path.join(
self.working_log_directory,
f"{comparison_source.name}_{reference_source.name}_evaluation.csv",
)
confusion_path = os.path.join(
self.working_log_directory,
f"{comparison_source.name}_{reference_source.name}_confusions.csv",
)
csv_header = [
"file",
"begin",
Expand All @@ -1306,6 +1318,7 @@ def evaluate_alignments(
score_sum = 0
phone_edit_sum = 0
phone_length_sum = 0
phone_confusions = collections.Counter()
with self.session() as session:
# Set up
logger.info("Evaluating alignments...")
Expand Down Expand Up @@ -1370,10 +1383,11 @@ def evaluate_alignments(
to_comp.append((reference_phones, comparison_phones))
with mp.Pool(GLOBAL_CONFIG.num_jobs) as pool:
gen = pool.starmap(score_func, to_comp)
for i, (score, phone_error_rate) in enumerate(gen):
for i, (score, phone_error_rate, errors) in enumerate(gen):
if score is None:
continue
u = indices[i]
phone_confusions.update(errors)
reference_phone_count = reference_phone_counts[u.id]
update_mappings.append(
{
Expand Down Expand Up @@ -1443,7 +1457,10 @@ def evaluate_alignments(
score_count += 1
score_sum += alignment_score
writer.writerow(data)

with mfa_open(confusion_path, "w") as f:
f.write("reference,hypothesis,count\n")
for k, v in sorted(phone_confusions.items(), key=lambda x: -x[1]):
f.write(f"{k[0]},{k[1]},{v}\n")
logger.info(f"Average overlap score: {score_sum/score_count}")
logger.info(f"Average phone error rate: {phone_edit_sum/phone_length_sum}")
logger.debug(f"Alignment evaluation took {time.time()-begin} seconds")
4 changes: 2 additions & 2 deletions montreal_forced_aligner/command_line/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,14 +433,14 @@ def start_server() -> None:
logger.info(f"{GLOBAL_CONFIG.current_profile_name} MFA database server started!")


def stop_server(mode: str = "fast") -> None:
def stop_server(mode: str = "smart") -> None:
"""
Stop the MFA server for the current profile.
Parameters
----------
mode: str, optional
Mode to be passed to `pg_ctl`, defaults to "fast"
Mode to be passed to `pg_ctl`, defaults to "smart"
"""
logger = logging.getLogger("mfa")
GLOBAL_CONFIG.load()
Expand Down
2 changes: 2 additions & 0 deletions montreal_forced_aligner/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1751,8 +1751,10 @@ def to_tg_interval(self, file_duration=None) -> Interval:
if self.end < -1 or self.begin == 1000000:
raise CtmError(self)
end = round(self.end, 6)
begin = round(self.begin, 6)
if file_duration is not None and end > file_duration:
end = round(file_duration, 6)
assert begin < end
return Interval(round(self.begin, 6), end, self.label)


Expand Down
2 changes: 2 additions & 0 deletions montreal_forced_aligner/diarization/speaker_diarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def __init__(
self.ground_truth_utt2spk = {}
self.ground_truth_speakers = {}
self.single_clusters = set()
self._unknown_speaker_break_up_count = 0

@classmethod
def parse_parameters(
Expand Down Expand Up @@ -674,6 +675,7 @@ def fix_speaker_ordering(self):
session.commit()

def initialize_mfa_clustering(self):
self._unknown_speaker_break_up_count = 0

with self.session() as session:
next_speaker_id = self.get_next_primary_key(Speaker)
Expand Down
15 changes: 12 additions & 3 deletions montreal_forced_aligner/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
from __future__ import annotations

import collections
import functools
import itertools
import json
Expand Down Expand Up @@ -619,7 +620,7 @@ def align_phones(
ignored_phones: typing.Set[str] = None,
custom_mapping: Optional[Dict[str, str]] = None,
debug: bool = False,
) -> Tuple[float, float]:
) -> Tuple[float, float, Dict[Tuple[str, str], int]]:
"""
Align phones based on how much they overlap and their phone label, with the ability to specify a custom mapping for
different phone labels to be scored as if they're the same phone
Expand All @@ -634,13 +635,17 @@ def align_phones(
Silence phone (these are ignored in the final calculation)
custom_mapping: dict[str, str], optional
Mapping of phones to treat as matches even if they have different symbols
debug: bool, optional
Flag for logging extra information about alignments
Returns
-------
float
Score based on the average amount of overlap in phone intervals
float
Phone error rate
dict[tuple[str, str], int]
Dictionary of error pairs with their counts
"""

if ignored_phones is None:
Expand All @@ -653,23 +658,26 @@ def align_phones(
)

alignments = pairwise2.align.globalcs(
ref, test, score_func, -5, -5, gap_char=["-"], one_alignment_only=True
ref, test, score_func, -2, -2, gap_char=["-"], one_alignment_only=True
)
overlap_count = 0
overlap_sum = 0
num_insertions = 0
num_deletions = 0
num_substitutions = 0
errors = collections.Counter()
for a in alignments:
for i, sa in enumerate(a.seqA):
sb = a.seqB[i]
if sa == "-":
if sb.label not in ignored_phones:
errors[(sa, sb.label)] += 1
num_insertions += 1
else:
continue
elif sb == "-":
if sa.label not in ignored_phones:
errors[(sa.label, sb)] += 1
num_deletions += 1
else:
continue
Expand All @@ -680,6 +688,7 @@ def align_phones(
overlap_count += 1
if compare_labels(sa.label, sb.label, silence_phone, mapping=custom_mapping) > 0:
num_substitutions += 1
errors[(sa.label, sb.label)] += 1
if debug:
import logging

Expand All @@ -690,7 +699,7 @@ def align_phones(
else:
score = None
phone_error_rate = (num_insertions + num_deletions + (2 * num_substitutions)) / len(ref)
return score, phone_error_rate
return score, phone_error_rate, errors


def format_probability(probability_value: float) -> float:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_align_phones(basic_corpus_dir, basic_dict_path, temp_dir, eval_mapping_
"ɹ",
]
comparison_sequence = [CtmInterval(i, i + 1, x) for i, x in enumerate(comparison_sequence)]
score, phone_errors = align_phones(
score, phone_errors, error_counts = align_phones(
reference_sequence,
comparison_sequence,
silence_phone="sil",
Expand Down

0 comments on commit c9c4022

Please sign in to comment.