Add silence detection to inference (#36)
* update README

* update README

* add fp16

* inference modification

* add stacked mels

* triple

* add silence detection
loubbrad committed May 17, 2024
1 parent 0d7badb commit 014e757
Showing 5 changed files with 186 additions and 17 deletions.
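In outline (not part of the diff itself): each audio segment is scanned for long silences before being queued for the GPU, a copy of the prompt is kept for recovery, and the returned token sequence is pruned against the detected intervals. A rough sketch of the flow, using the names introduced below:

# Sketch of the new per-segment flow in transcribe.py (illustrative only):
silent_intervals = get_silent_intervals(curr_audio_segment)  # [(start_ms, end_ms), ...]
input_seq = copy.deepcopy(seq)  # snapshot kept as a recovery prompt
gpu_task_queue.put(((curr_audio_segment, seq), pid))
# ...once the GPU result for this pid comes back:
seq = process_silent_intervals(seq, intervals=silent_intervals, tokenizer=tokenizer)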
13 changes: 9 additions & 4 deletions amt/inference/model.py
@@ -36,9 +36,14 @@ def __init__(
         dtype=torch.bfloat16,
     ):
         super().__init__()
-        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
-        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+        self.dtype = dtype
+        self.cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+        self.register_buffer(
+            "k_cache", torch.zeros(self.cache_shape, dtype=dtype)
+        )
+        self.register_buffer(
+            "v_cache", torch.zeros(self.cache_shape, dtype=dtype)
+        )

     def update(self, input_pos, k_val, v_val):
         # input_pos: [S], k_val, v_val: [B, H, L, D]
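For context, the update method's body is collapsed in this diff; it writes new key/value entries into these preallocated buffers. A minimal sketch of the usual semantics, assuming in-place assignment at input_pos (illustrative, not the committed body):

# Assumed KV-cache update semantics (actual body collapsed above):
def update(self, input_pos, k_val, v_val):
    # input_pos: [S], k_val, v_val: [B, H, S, D]
    self.k_cache[:, :, input_pos] = k_val
    self.v_cache[:, :, input_pos] = v_val
    return self.k_cache, self.v_cache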
@@ -118,7 +123,7 @@ def forward(
 class CrossAttention(nn.Module):
     def __init__(self, n_state: int, n_head: int):
         super().__init__()
-        assert n_state % n_head == 0, "n_head does not evenly devide n_state"
+        assert n_state % n_head == 0, "n_head does not evenly divide n_state"

         self.n_head = n_head
         self.d_head = n_state // n_head
183 changes: 173 additions & 10 deletions amt/inference/transcribe.py
@@ -1,6 +1,7 @@
 import os
 import signal
 import time
+import copy
 import random
 import logging
 import traceback
@@ -9,15 +10,17 @@
 import torch.multiprocessing as multiprocessing
 import torch._dynamo.config
 import torch._inductor.config
+import numpy as np

 from torch.multiprocessing import Queue
 from tqdm import tqdm
 from functools import wraps
 from torch.cuda import is_bf16_supported
+from librosa.effects import _signal_to_frame_nonsilent

 from amt.inference.model import AmtEncoderDecoder
 from amt.tokenizer import AmtTokenizer
-from amt.audio import AudioTransform
+from amt.audio import AudioTransform, SAMPLE_RATE
 from amt.data import get_wav_mid_segments

 torch._inductor.config.coordinate_descent_tuning = True
@@ -78,8 +81,8 @@ def recalculate_tok_ids(

     # Mask out tok_ids larger than 30ms from original tok_id
     tok_ids_expanded = tok_ids.unsqueeze(1)
-    mask_c = col_indices <= tok_ids_expanded + 3
-    mask_d = col_indices >= tok_ids_expanded - 3
+    mask_c = col_indices <= tok_ids_expanded + 2
+    mask_d = col_indices >= tok_ids_expanded - 2
     beam_mask = mask_c & mask_d

     # Don't mask out the original tok_id (required for non-onset/vel toks)
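The beam window tightens here from ±3 to ±2 ids around each greedy pick; assuming onset tokens are quantized at 10 ms (which is what makes ±3 ids correspond to the 30 ms in the comment above), the window is now roughly ±20 ms. A standalone illustration (not part of the commit):

import torch

# ±2 window around greedy token ids 50 and 51
tok_ids = torch.tensor([50, 51])
col_indices = torch.arange(100).unsqueeze(0)  # [1, vocab]
tok_ids_expanded = tok_ids.unsqueeze(1)       # [2, 1]
beam_mask = (col_indices <= tok_ids_expanded + 2) & (col_indices >= tok_ids_expanded - 2)
print(beam_mask[0].nonzero().flatten())  # tensor([48, 49, 50, 51, 52])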
@@ -218,8 +221,8 @@ def process_segments(
             ),
         )

-        logits[:, 389] *= 1.2
-        next_tok_ids = torch.argmax(logits, dim=-1)
+        # logits[:, 389] *= 1.05
+        # next_tok_ids = torch.argmax(logits, dim=-1)

         next_tok_ids = recalculate_tok_ids(
             logits=logits,
@@ -429,13 +432,141 @@ def _truncate_seq(
     if len(_mid_dict.note_msgs) == 0:
         return [tokenizer.bos_tok]
     else:
         # The end_ms - 1 is a workaround to get rid of the off msgs
         res = tokenizer._tokenize_midi_dict(_mid_dict, start_ms, end_ms - 1)

         if res[-1] == tokenizer.eos_tok:
             res.pop()
         return res


+# This is a sloppy implementation
+def process_silent_intervals(
+    seq: list, intervals: list, tokenizer: AmtTokenizer
+):
+    def adjust_onset(_onset: int):
+        # Adjusts the onset according to the silence intervals
+        for start, end in intervals:
+            if start <= _onset <= end:
+                return start
+
+        return _onset
+
+    if len(intervals) == 0:
+        return seq
+
+    res = []
+    logger = logging.getLogger(__name__)
+    active_notes = {pitch: False for pitch in range(0, 128)}
+    active_notes["pedal"] = False
+
+    for tok_1, tok_2, tok_3 in zip(
+        seq,
+        seq[1:] + [tokenizer.pad_tok],
+        seq[2:] + [tokenizer.pad_tok, tokenizer.pad_tok],
+    ):
+        if isinstance(tok_1, tuple) is False:
+            res.append(tok_1)
+            continue
+        elif tok_1[0] == "prev":
+            res.append(tok_1)
+            active_notes[tok_1[1]] = True
+            continue
+        elif tok_1[0] in {"onset", "vel"}:
+            continue
+
+        if tok_1[0] == "pedal":
+            note_type = "on" if tok_1[1] == 1 else "off"
+            note_val = "pedal"
+        elif tok_1[0] in {"on", "off"}:
+            note_type = tok_1[0]
+            note_val = tok_1[1]
+
+        if note_type == "on":
+            # Check that the rest of the tokens are valid
+            if isinstance(tok_2, tuple) is False:
+                logger.debug(f"Invalid token sequence {tok_1}, {tok_2}")
+                continue
+            if note_val != "pedal" and isinstance(tok_3, tuple) is False:
+                logger.debug(
+                    f"Invalid token sequence {tok_1}, {tok_2}, {tok_3}"
+                )
+                continue
+
+            # Don't add on if note is already on
+            if active_notes[note_val] is True:
+                continue
+
+            # Calculate adjusted onset and add if conditions are met
+            onset = tok_2[1]
+            onset_adj = adjust_onset(onset)
+            if onset != onset_adj:
+                continue
+            else:
+                active_notes[note_val] = True
+                res.append(tok_1)
+                res.append(tok_2)
+                if note_val != "pedal":
+                    res.append(tok_3)
+
+        elif note_type == "off":
+            # Check that the rest of the tokens are valid
+            if isinstance(tok_2, tuple) is False or tok_2[0] != "onset":
+                logger.debug(f"Invalid token sequence {tok_1}, {tok_2}")
+                continue
+
+            # Don't add off if note is not on
+            if active_notes[note_val] is False:
+                continue
+
+            # Add note with adjusted offset
+            offset = tok_2[1]
+            offset_adj = adjust_onset(offset)
+            if offset != offset_adj:
+                logger.debug(
+                    f"Adjusted offset of {tok_1}, {tok_2} -> {offset_adj}"
+                )
+            res.append(tok_1)
+            res.append(("onset", tokenizer._quantize_onset(offset_adj)))
+            active_notes[note_val] = False
+
+    return res
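For intuition, a toy walk-through (not part of the commit; token values are hypothetical and assume ("on"/"off", pitch), ("onset", ms), and ("vel", v) tuples with 10 ms onset quantization). A note-on whose onset falls inside a silent interval is dropped outright, while an off whose onset lands inside an interval is snapped back to the interval start:

seq = [
    ("on", 60), ("onset", 1000), ("vel", 45),  # onset inside (900, 1500): dropped
    ("on", 64), ("onset", 2000), ("vel", 45),  # outside all intervals: kept
    ("off", 64), ("onset", 3100),              # inside (3000, 4000): snapped to 3000
]
out = process_silent_intervals(
    seq, intervals=[(900, 1500), (3000, 4000)], tokenizer=AmtTokenizer()
)
# out == [("on", 64), ("onset", 2000), ("vel", 45), ("off", 64), ("onset", 3000)]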


+def get_silent_intervals(wav: torch.Tensor):
+    FRAME_LEN = 2048
+    HOP_LEN = 512
+    MIN_WINDOW_S = 5
+    MIN_WINDOW_STEPS = (SAMPLE_RATE // HOP_LEN) * MIN_WINDOW_S + 1
+    MS_PER_HOP = int((HOP_LEN * 1e3) / SAMPLE_RATE)
+
+    non_silent = _signal_to_frame_nonsilent(
+        wav.numpy(),
+        frame_length=FRAME_LEN,
+        hop_length=HOP_LEN,
+        top_db=30,
+        ref=np.max,
+    )
+    non_silent = np.concatenate(([True], non_silent, [True]))
+
+    edges = np.diff(non_silent.astype(int))
+    starts = np.where(edges == -1)[0]
+    ends = np.where(edges == 1)[0]
+
+    # Calculate lengths
+    lengths = ends - starts
+
+    # Filter intervals by minimum length
+    valid = lengths > MIN_WINDOW_STEPS
+    silent_intervals = [
+        (start * MS_PER_HOP, (end - 1) * MS_PER_HOP)
+        for start, end, vl in zip(starts, ends, valid)
+        if vl
+    ]
+
+    return silent_intervals
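Assuming the 16 kHz SAMPLE_RATE used by the audio pipeline (an assumption here; the constant is imported from amt.audio), the thresholds work out as follows (not part of the commit):

SAMPLE_RATE, HOP_LEN, MIN_WINDOW_S = 16000, 512, 5
MIN_WINDOW_STEPS = (SAMPLE_RATE // HOP_LEN) * MIN_WINDOW_S + 1  # == 156 hops
MS_PER_HOP = int((HOP_LEN * 1e3) / SAMPLE_RATE)                 # == 32 ms
# i.e. only silences longer than ~5 s are returned, located to 32 ms resolution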


 def transcribe_file(
     file_path,
     gpu_task_queue: Queue,
@@ -463,7 +594,10 @@ def transcribe_file(
         init_idx = len(seq)

         # Add to gpu queue and wait for results
-        gpu_task_queue.put(((audio_segments.pop(0), seq), pid))
+        curr_audio_segment = audio_segments.pop(0)
+        silent_intervals = get_silent_intervals(curr_audio_segment)
+        input_seq = copy.deepcopy(seq)
+        gpu_task_queue.put(((curr_audio_segment, seq), pid))
         while True:
             try:
                 gpu_result = result_queue.get(timeout=0.1)
@@ -476,18 +610,47 @@ def transcribe_file(
                 else:
                     result_queue.put(gpu_result)

+        if len(silent_intervals) > 0:
+            logger.debug(
+                f"Seen silent intervals in segment {idx}: {silent_intervals}"
+            )
+
+        seq_raw = seq
+        seq = process_silent_intervals(
+            seq, intervals=silent_intervals, tokenizer=tokenizer
+        )
+
+        if len(seq) != len(seq_raw):
+            logger.info(
+                f"Removed tokens ({len(seq_raw)} -> {len(seq)}) "
+                f"in segment {idx} according to silence in intervals: "
+                f"{silent_intervals}",
+            )
+
         try:
             next_seq = _truncate_seq(
                 seq,
                 CHUNK_LEN_MS,
                 LEN_MS - CHUNK_LEN_MS,
             )
         except Exception as e:
-            logger.info(
-                f"Skipping segment {idx} (failed to transcribe): {file_path}"
-            )
+            logger.info(f"Failed to reconcile segment {idx}: {file_path}")
             logger.debug(traceback.format_exc())
             seq = [tokenizer.bos_tok]

+            try:
+                seq = _truncate_seq(
+                    input_seq,
+                    CHUNK_LEN_MS - 2,
+                    CHUNK_LEN_MS,
+                )
+            except Exception as e:
+                seq = [tokenizer.bos_tok]
+                logger.info(
+                    f"Failed to recover prompt, proceeding with default: {seq}"
+                )
+            else:
+                logger.info(f"Proceeding with prompt: {seq}")
+
         else:
             if seq[-1] == tokenizer.eos_tok:
                 logger.info(f"Seen eos_tok at segment {idx}: {file_path}")
2 changes: 1 addition & 1 deletion amt/tokenizer.py
@@ -447,7 +447,7 @@ def msg_mixup(src: list):
            raise Exception

        random.shuffle(res)  # Only includes prev toks
-       res.append(self.bos_tok)  # Beggining of sequence
+       res.append(self.bos_tok)  # Beginning of sequence

        buffer = defaultdict(lambda: defaultdict(list))
        for tok_1, tok_2, tok_3 in zip(
4 changes: 2 additions & 2 deletions config/models/medium-triple.json
@@ -3,9 +3,9 @@
     "n_audio_ctx": 1500,
     "n_audio_state": 768,
     "n_audio_head": 12,
-    "n_audio_layer": 4,
+    "n_audio_layer": 6,
     "n_text_ctx": 4096,
     "n_text_state": 768,
     "n_text_head": 12,
-    "n_text_layer": 4
+    "n_text_layer": 6
 }
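The triple model grows from 4 to 6 audio (encoder) and text (decoder) layers. A minimal sanity check of the new file (not part of the commit; assumes it is read from the path in the diff header):

import json

with open("config/models/medium-triple.json") as f:
    cfg = json.load(f)
assert cfg["n_audio_layer"] == cfg["n_text_layer"] == 6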
1 change: 1 addition & 0 deletions requirements.txt
@@ -3,6 +3,7 @@ torch >= 2.2
 torchaudio
 accelerate
 psutil
+librosa
 mido
 tqdm
 orjson
