In [17]:
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile
from howl.client import HowlClient

import logging
import time
from typing import Callable
from howl.context import InferenceContext
from howl.settings import SETTINGS
from howl.settings import HowlSettings
import howl.model as howl_model
from howl.workspace import Workspace
from pathlib import Path
import numpy as np
import pyaudio
import howl.data.transform as transform
import torch
import os
from howl.context import InferenceContext
from howl.model.inference import FrameInferenceEngine, InferenceEngine
from howl.model import ConfusionMatrix, ConvertedStaticModel, RegisteredModel
from howl.data.transform.operator import ZmuvTransform, batchify, compose
from howl.utils import logging_utils
from howl.utils.args_utils import ArgOption, ArgumentParserBuilder
import wave
import torchaudio
import torch.nn.functional as F
from howl.utils import audio_utils

from torch.utils.mobile_optimizer import optimize_for_mobile


In [18]:
"""Load a pretrained model using the provided name"""
path = "/media/nontawat/Windows/wakeword_ws/howl_/howl/workspaces/hey-ff-res8-reproduce-neg"
# path = '/media/nontawat/Windows/wakeword_ws/howl_/howl/workspaces/hey-ff-res8-reproduce-neg'
device = "cpu"
workspace = Workspace(Path(path), delete_existing=False)
# Load model settings
settings = workspace.load_settings()

# Set up context
use_frame = settings.training.objective == "frame"
ctx = InferenceContext(
    vocab=settings.training.vocab, token_type=settings.training.token_type, use_blank=not use_frame
)

# Load models
zmuv_transform = transform.ZmuvTransform()
model = howl_model.RegisteredModel.find_registered_class("res8")(ctx.num_labels).eval()

# Load pretrained weights
zmuv_transform.load_state_dict(
    torch.load(str(workspace.path / "zmuv.pt.bin"), map_location=torch.device(device))
)
workspace.load_model(model, best=True)

# Load engine
model.streaming()
if use_frame:
    engine = FrameInferenceEngine(
        int(settings.training.max_window_size_seconds * 1000),
        int(settings.training.eval_stride_size_seconds * 1000),
        model,
        zmuv_transform,
        ctx,
    )
else:
    engine = InferenceEngine(model, zmuv_transform, ctx)

model = engine.model
zmuv = engine.zmuv
std_audio_transform = engine.std

In [19]:
print(isinstance(model, torch.nn.Module))
print(isinstance(zmuv, torch.nn.Module))
print(isinstance(std_audio_transform, torch.nn.Module))

True
True
True


In [23]:
class HowlModelFullPipeline(torch.nn.Module):
    def __init__(self, model, zmuv, std_audio_transform):
        super().__init__()
        self.model = model
        self.zmuv = zmuv
        self.std = std_audio_transform
        
    def forward(self, frame):
        lengths = torch.tensor([frame.size(-1)])
        transformed_lengths = self.std.compute_lengths(lengths)
        transformed_frame = engine.zmuv(engine.std(frame.unsqueeze(0)))
        prediction = self.model(transformed_frame, transformed_lengths).softmax(-1)[0]
        prediction *= engine.inference_weights
        prediction = prediction / prediction.sum()
        return prediction

In [24]:
full_model = HowlModelFullPipeline(model, zmuv, std_audio_transform)

In [14]:
MAX_WINDOWS_SIZE = 0.5
SAMPLE_RATE = 16000

FRAME_SIZE = int(MAX_WINDOWS_SIZE * SAMPLE_RATE)

random_tensor = torch.rand((FRAME_SIZE)).type(torch.float32)
print("random tensor shape", random_tensor.shape)

random tensor shape torch.Size([8000])


In [15]:
full_model(random_tensor) # test forward

tensor([0.0023, 0.0079, 0.0220, 0.9678], grad_fn=<DivBackward0>)

In [16]:
# traced_full = torch.jit.trace(full_model, [random_tensor])
scripted_full = torch.jit.script(full_model)

RuntimeError: Can't redefine method: __streaming_state_getter on class: __torch__.howl.model.cnn.___torch_mangle_44.Res8 (of Python compilation unit at: 0x4089b50)

In [33]:
traced_full(random_tensor)

tensor([0.0019, 0.0064, 0.0192, 0.9725], grad_fn=<DivBackward0>)

In [34]:

traced_script_module_optimized = optimize_for_mobile(traced_full)
traced_script_module_optimized._save_for_lite_interpreter("hey_ff_traced_full.ptl")

In [35]:
traced_script_module_optimized(random_tensor)

  return forward_call(*input, **kwargs)


tensor([0.0019, 0.0064, 0.0192, 0.9725])