Создадим стриминговую модель (и заодно сохраним её в формате jit).

In [None]:
from src.configs import DistillTaskConfig, StreamingTaskConfig
from src.model import CRNN, StreamingCRNN

from tqdm.notebook import tqdm

import torch
import wandb
import torchaudio

# Hyperparameters of the small CRNN. Presumably these have to match the
# configuration the checkpoint below was trained with — TODO confirm.
distill_config = DistillTaskConfig(
    hidden_size=16,
    bottleneck_size=8,
    cnn_out_channels=4,
    distill_w=0.,
    attn_distill_w=0.,
    melspec_win_length=400,
    melspec_hop_length=160,
    num_epochs=100,
    use_scheduler=True,
    temperature=10
)

model = CRNN(distill_config)

# Fetch the checkpoint from the W&B run before loading it from disk.
_ = wandb.restore('triple-100-epochs-tiny-colab.pt', run_path="broccoliman/kws/1fhnqrzr")

# map_location='cpu' makes the load independent of the device the checkpoint
# was saved from, so the cell also works on a machine without CUDA;
# load_state_dict then copies the values into the model's own parameters.
# NOTE(review): torch.load unpickles the file — only use trusted checkpoints.
model.load_state_dict(
    torch.load('triple-100-epochs-tiny-colab.pt', map_location='cpu')
)

# Switch to inference mode before quantizing/exporting so that layers which
# behave differently in training (e.g. dropout) are deterministic. Harmless
# if StreamingCRNN already does this itself.
model.eval()

# Dynamic quantization with float16 weights.
model = torch.quantization.quantize_dynamic(model, dtype=torch.float16)

conf = StreamingTaskConfig()

# Wrap the quantized model into its streaming counterpart; all streaming
# parameters come from the default StreamingTaskConfig.
st_model = StreamingCRNN(
    model,
    max_window_length=conf.max_window_length,
    streaming_step_size=conf.streaming_step_size,
    share_hidden_states=conf.share_hidden_states,
    device=conf.device
)

# Compile to TorchScript and store the result as 'kws.pt'.
torch.jit.save(torch.jit.script(st_model), 'kws.pt')

Соберём аудиозапись с ключевым словом посередине: шум, затем команда и снова шум.

In [None]:
# Source clips: two background-noise recordings and one spoken command.
noise_path_1 = 'speech_commands/_background_noise_/exercise_bike.wav'
noise_path_2 = 'speech_commands/_background_noise_/running_tap.wav'
command_path = "speech_commands/sheila/dc269564_nohash_1.wav"

# torchaudio.load returns (waveform, sample_rate); keep only the waveform and
# squeeze out singleton dimensions.
noise_1, noise_2, command = (
    torchaudio.load(wav_path)[0].squeeze()
    for wav_path in (noise_path_1, noise_path_2, command_path)
)

# Layout: last 160000 samples of the first noise clip, then the command, then
# samples 160000..319999 of the second noise clip.
# NOTE(review): 160000 samples is presumably 10 s at 16 kHz — confirm the
# sample rate of the dataset.
concat_wav = torch.cat((
    noise_1[-160_000:],
    command,
    noise_2[160_000:320_000],
))

Нарисуем предсказания.

In [None]:
import seaborn as sns

# One value per input sample, taken from index 1 of the model output.
# NOTE(review): index 1 is presumably the keyword probability — confirm
# against StreamingCRNN.
predictions = []

# Pure inference: disable autograd so no graph is recorded while the
# waveform is fed to the streaming model one sample at a time.
with torch.no_grad():
    for frame in tqdm(concat_wav):
        predictions.append(st_model(torch.tensor([frame]))[1].item())

# x and y must be passed by keyword: positional x/y is deprecated in
# seaborn 0.11 and raises a TypeError from 0.12 on.
sns.lineplot(x=list(range(len(predictions))), y=predictions)