multimodal/speech_cv collection update (Audio Visual Speech Recognition) #8688

Closed
wants to merge 9 commits
17 changes: 15 additions & 2 deletions nemo/collections/asr/modules/conformer_encoder.py
@@ -350,7 +350,15 @@ def __init__(
                 is_causal=causal_downsampling,
             )
         else:
-            self.pre_encode = nn.Linear(feat_in, d_model)
+
+            # Linear pre_encode
+            if feat_in:
+                self.pre_encode = nn.Linear(feat_in, d_model)
+
+            # Identity pre_encode
+            else:
+                self.pre_encode = None
+                self._feat_in = d_model
 
         # Reduction
         if reduction and reduction_factor > 1:
@@ -522,8 +530,13 @@ def forward_internal(

         audio_signal = torch.transpose(audio_signal, 1, 2)
 
-        if isinstance(self.pre_encode, nn.Linear):
+        # No Pre-encoding
+        if self.pre_encode is None:
+            pass
+        # Linear Pre-encoding
+        elif isinstance(self.pre_encode, nn.Linear):
             audio_signal = self.pre_encode(audio_signal)
+        # Other Pre-encoding
         else:
             audio_signal, length = self.pre_encode(x=audio_signal, lengths=length)
             length = length.to(torch.int64)
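Taken together, the two hunks make pre_encode optional: when feat_in is falsy the encoder skips input projection entirely, so a front-end that already emits d_model-dimensional features (for example, a video encoder in the AVSR setup) can feed the Conformer directly. A minimal standalone sketch of the resulting dispatch (the PreEncodeDispatch name and example shapes are illustrative, not part of the PR):

import torch
import torch.nn as nn

class PreEncodeDispatch(nn.Module):
    """Toy module mirroring the three pre_encode cases in the hunks above."""

    def __init__(self, feat_in, d_model):
        super().__init__()
        if feat_in:
            # Linear pre_encode: project input features to d_model
            self.pre_encode = nn.Linear(feat_in, d_model)
        else:
            # Identity pre_encode: inputs are assumed to already be d_model wide
            self.pre_encode = None
            self._feat_in = d_model

    def forward(self, audio_signal, length):
        # audio_signal: (batch, time, features), as after the transpose above
        if self.pre_encode is None:
            pass  # no pre-encoding
        elif isinstance(self.pre_encode, nn.Linear):
            audio_signal = self.pre_encode(audio_signal)
        else:
            # e.g. a ConvSubsampling module that also recomputes lengths
            audio_signal, length = self.pre_encode(x=audio_signal, lengths=length)
            length = length.to(torch.int64)
        return audio_signal, length

# Identity path: features are already d_model-dimensional, lengths pass through.
enc = PreEncodeDispatch(feat_in=None, d_model=256)
out, lens = enc(torch.randn(2, 50, 256), torch.tensor([50, 42]))
assert out.shape == (2, 50, 256)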
170 changes: 170 additions & 0 deletions nemo/collections/common/parts/preprocessing/collections.py
@@ -307,6 +307,127 @@ def __init__(
super().__init__(data)


class AudioVideoText(_Collection):
"""List of audio-video-transcript text correspondence with preprocessing."""

OUTPUT_TYPE = collections.namedtuple(
typename='AudioVideoTextEntity',
field_names='id audio_file video_file duration text_tokens offset text_raw speaker orig_sr lang',
)

def __init__(
self,
ids: List[int],
audio_files: List[str],
video_files: List[str],
durations: List[float],
texts: List[str],
offsets: List[str],
speakers: List[Optional[int]],
orig_sampling_rates: List[Optional[int]],
token_labels: List[Optional[int]],
langs: List[Optional[str]],
parser: parsers.CharParser,
min_duration: Optional[float] = None,
max_duration: Optional[float] = None,
max_number: Optional[int] = None,
do_sort_by_duration: bool = False,
index_by_file_id: bool = False,
):
"""Instantiates audio-text manifest with filters and preprocessing.

Args:
ids: List of example positions.
audio_files: List of audio files.
video_files: List of video files.
durations: List of float durations.
texts: List of raw text transcripts.
offsets: List of duration offsets or None.
speakers: List of optional speaker ids.
orig_sampling_rates: List of original sampling rates of audio files.
token_labels: List of optional pre-tokenized label ids; when provided for a sample, they are used in place of parsing its text.
langs: List of language ids, one for each sample, or None.
parser: Instance of `CharParser` to convert string to tokens.
min_duration: Minimum duration to keep entry with (default: None).
max_duration: Maximum duration to keep entry with (default: None).
max_number: Maximum number of samples to collect.
do_sort_by_duration: If True, sort the samples list by duration. Not compatible with index_by_file_id.
index_by_file_id: If True, saves a mapping from filename base (ID) to index in data.
"""

output_type = self.OUTPUT_TYPE
data, duration_filtered, num_filtered, total_duration = [], 0.0, 0, 0.0
if index_by_file_id:
self.mapping = {}

for id_, audio_file, video_file, duration, offset, text, speaker, orig_sr, token_labels, lang in zip(
ids,
audio_files,
video_files,
durations,
offsets,
texts,
speakers,
orig_sampling_rates,
token_labels,
langs,
):
# Duration filters.
if min_duration is not None and duration < min_duration:
duration_filtered += duration
num_filtered += 1
continue

if max_duration is not None and duration > max_duration:
duration_filtered += duration
num_filtered += 1
continue

if token_labels is not None:
text_tokens = token_labels
else:
if text != '':
if hasattr(parser, "is_aggregate") and parser.is_aggregate and isinstance(text, str):
if lang is not None:
text_tokens = parser(text, lang)
else:
raise ValueError("lang required in manifest when using aggregate tokenizers")
else:
text_tokens = parser(text)
else:
text_tokens = []

if text_tokens is None:
duration_filtered += duration
num_filtered += 1
continue

total_duration += duration

data.append(
output_type(id_, audio_file, video_file, duration, text_tokens, offset, text, speaker, orig_sr, lang)
)
if index_by_file_id:
file_id, _ = os.path.splitext(os.path.basename(audio_file))
if file_id not in self.mapping:
self.mapping[file_id] = []
self.mapping[file_id].append(len(data) - 1)

# Max number of entities filter.
if len(data) == max_number:
break

if do_sort_by_duration:
if index_by_file_id:
logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.")
else:
data.sort(key=lambda entity: entity.duration)

logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600)
logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600)

super().__init__(data)
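For reference, a minimal usage sketch of the new collection. The manifest values below are made up, and parsers.make_parser() with default arguments is assumed to return a suitable CharParser:

from nemo.collections.common.parts.preprocessing import parsers

# Illustrative single-sample collection; values are not from the PR.
collection = AudioVideoText(
    ids=[0],
    audio_files=['clips/0001.wav'],
    video_files=['clips/0001.mp4'],
    durations=[3.2],
    texts=['hello world'],
    offsets=[None],
    speakers=[None],
    orig_sampling_rates=[16000],
    token_labels=[None],
    langs=[None],
    parser=parsers.make_parser(),
    min_duration=0.1,
    max_duration=20.0,
)
print(len(collection), collection[0].text_raw)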


class ASRAudioText(AudioText):
"""`AudioText` collector from asr structured json files."""

@@ -379,6 +500,55 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs):
)


class ASRVideoAudioText(AudioVideoText):
"""`AudioVideoText` collector from asr structured json files."""

def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs):
"""Parse lists of audio and video files, durations and transcripts texts.

Args:
manifests_files: Either single string file or list of such -
manifests to yield items from.
*args: Args to pass to `AudioVideoText` constructor.
**kwargs: Kwargs to pass to `AudioVideoText` constructor.
"""

ids, audio_files, video_files, durations, texts, offsets = [], [], [], [], [], []
speakers, orig_srs, token_labels, langs = [], [], [], []
for item in manifest.item_iter(manifests_files):
ids.append(item['id'])
audio_files.append(item['audio_file'])
video_files.append(item['video_file'])
durations.append(item['duration'])
texts.append(item['text'])
offsets.append(item['offset'])
speakers.append(item['speaker'])
orig_srs.append(item['orig_sr'])
token_labels.append(item['token_labels'])
langs.append(item['lang'])
super().__init__(
ids,
audio_files,
video_files,
durations,
texts,
offsets,
speakers,
orig_srs,
token_labels,
langs,
*args,
**kwargs,
)
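A companion sketch for the manifest-driven path. The JSON keys shown follow NeMo's audio_filepath/duration/text manifest convention; the video_filepath key, and whether manifest.item_iter maps it to item['video_file'] as assumed here, are guesses about this PR rather than documented behavior:

# Hypothetical manifest line (one JSON object per line):
# {"audio_filepath": "clips/0001.wav", "video_filepath": "clips/0001.mp4",
#  "duration": 3.2, "text": "hello world"}

from nemo.collections.common.parts.preprocessing import parsers

collection = ASRVideoAudioText(
    manifests_files='train_manifest.json',
    parser=parsers.make_parser(),
    min_duration=0.1,
    max_duration=20.0,
)
for entity in collection:
    print(entity.audio_file, entity.video_file, entity.duration)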


class SpeechLabel(_Collection):
"""List of audio-label correspondence with preprocessing."""
