In [None]:
import torchaudio
import sys
import os
import torch
import numpy as np
import nemo.collections.asr as nemo_asr
import soundfile as sf
from panns_inference import AudioTagging, SoundEventDetection, labels

# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
# initialised yet. It is set here, after `import torch` / NeMo above; that
# usually works, but it is safest to move this line before those imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from service_utils import FrameVAD, Append

# With CUDA_VISIBLE_DEVICES="1", 'cuda' refers to physical GPU 1.
DEVICE = 'cuda'

### Load VAD Model
STEP = 0.100         # chunk length in seconds fed to the VAD per call
WINDOW_SIZE = 0.100  # window length in seconds; frame_overlap is derived from it
CHANNELS = 1
RATE = 16000         # sample rate (Hz) passed to the VAD
FRAME_LEN = STEP

# round() rather than a bare int(): a float product such as 0.29 * 100
# (= 28.999999999999996) would otherwise be truncated to one sample short.
CHUNK_SIZE = round(STEP * RATE)

vad = FrameVAD('checkpoints/naint_vad_BackSilenceSpeech.nemo',
               sample_rate=RATE,
               frame_len=FRAME_LEN,
               frame_overlap=(WINDOW_SIZE - FRAME_LEN) / 2,
               offset=0, device=DEVICE)

# TODO(review): `controller` and `n_non_speech_frames` are used by align_wav and
# the processing cell but are not created anywhere in this notebook;
# `controller` is presumably an `Append` instance — define both here.

### Load ASR Model
ASR_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    model_name="stt_en_conformer_ctc_large", map_location=DEVICE)


### Load AudioTag model
at = AudioTagging(checkpoint_path='checkpoints/audio_scene_checkpoint.pth', device=DEVICE)

In [None]:
def _save_segment(speech, save_name_file, sample_rate, min_duration, max_duration):
    '''
    Persist one cut segment if it passes the duration, ASR and tagging checks.

    Segments whose duration (speech.size / sample_rate) is not strictly between
    `min_duration` and `max_duration` are skipped without writing anything.
    Otherwise `<save_name_file>.wav` is written and transcribed with `ASR_model`.
    It is kept, together with a `<save_name_file>.txt` transcript, when the
    transcript is non-empty and `get_audio_tag` accepts the audio. Otherwise the
    `.wav` is deleted.
    '''
    duration = speech.size / sample_rate
    if not (min_duration < duration < max_duration):
        print(f'Bad duration - {save_name_file}, skip this cut')
        return

    wav_path = save_name_file + '.wav'
    sf.write(wav_path, speech, sample_rate)
    asr_text = ASR_model.transcribe([wav_path])[0]

    # sf.read returns float64 samples by default; that is what the tagger receives.
    audio, _ = sf.read(wav_path)
    if asr_text != '' and get_audio_tag(audio):
        print(f'Success cut - {save_name_file}, write file')
        with open(save_name_file + '.txt', "w", encoding='utf-8') as text_file:
            text_file.write(asr_text)
    else:
        print(f'Bad audio file - {save_name_file}, remove file')
        os.remove(wav_path)


def align_wav(audio_frame, n_non_speech_frames, save_name_file, sample_rate,
              min_duration=1, max_duration=25):
    '''
    Feed one audio chunk through the VAD and cut/save a segment when due.

    The chunk is labelled with `vad` and appended to the `controller` speech
    and label buffers. Once the label buffer holds at least
    `n_non_speech_frames` entries, exactly `n_non_speech_frames` of which are
    non-speech, a cut happens:
      * if `controller.total_in_buffer == n_non_speech_frames` the buffers are
        refreshed and nothing is saved;
      * otherwise the buffered speech is handed to `_save_segment`.
    After each cut the global `count` is incremented.

    Args:
        audio_frame: one chunk of int16 samples (must expose a buffer).
        n_non_speech_frames: size of the non-speech run that triggers a cut.
        save_name_file: output path without extension.
        sample_rate: sample rate of the audio, used for the duration check.
        min_duration, max_duration: exclusive bounds in seconds for a segment
            to be kept (defaults keep the original 1 s / 25 s limits).

    NOTE(review): `controller` and `count` are module globals not created in
    this cell; `controller` is presumably a `service_utils.Append` instance.
    '''
    global count
    signal = np.frombuffer(audio_frame, dtype=np.int16)
    result = vad.transcribe(signal)
    speech = controller.buffer_for_speech(audio_frame)

    # Named `frame_labels` so it does not shadow panns_inference's `labels`.
    frame_labels = controller.buffer_for_labels(result[1])
    if len(frame_labels) < n_non_speech_frames:
        return  # label window not full yet

    switch = 'move_window'
    if (np.asarray(frame_labels) != 'speech').sum() == n_non_speech_frames:
        if controller.total_in_buffer == n_non_speech_frames:
            # Presumably the buffer holds only the non-speech run: nothing to save.
            switch = 'refresh'
        else:
            switch = 'send'
            _save_segment(speech, save_name_file, sample_rate,
                          min_duration, max_duration)
        count += 1
        controller.buffer_for_speech(switch=switch)
        # Once per cut instead of once per chunk: empty_cache() only returns
        # cached blocks to the driver, it does not free memory for PyTorch itself.
        torch.cuda.empty_cache()

    controller.buffer_for_labels(switch=switch)

def get_audio_tag(audio_array, second_score_threshold=0.1, tagger=None, label_names=None):
    '''
    Decide whether a cut audio clip is acceptable according to audio tagging.

    Runs the tagger on the clip and looks at its two highest-scoring classes.
    The clip is accepted when either of these holds:
      * the top class is "Speech" and the runner-up's name also contains
        "speech" (case-insensitive);
      * otherwise, the runner-up's score is below `second_score_threshold`.

    Args:
        audio_array: 1-D array of samples for a single clip.
        second_score_threshold: bound on the runner-up score for the fallback
            rule (default 0.1, the original value).
        tagger: object whose `inference(batch)` returns
            `(clipwise_output, embedding)`; defaults to the global `at`.
        label_names: class names indexed like the columns of clipwise_output;
            defaults to panns_inference's `labels`.

    Returns:
        bool: True if the clip is accepted.
    '''
    tagger = at if tagger is None else tagger
    names = labels if label_names is None else label_names

    # NOTE(review): panns_inference's examples load audio at 32 kHz, while this
    # notebook passes 16 kHz clips — confirm the checkpoint's expected rate.
    # NOTE(review): the fallback rule also accepts a clip whose top class is not
    # speech (e.g. Music) when the runner-up is weak — confirm that is intended.
    clipwise_output, _ = tagger.inference(audio_array[None, :])  # batch of one clip

    # Rank by the first clip's scores, descending; only the top two are used.
    top_idx = np.argsort(clipwise_output[0])[::-1][:2]
    # A list of (name, score) pairs rather than a dict keyed by name, so the
    # ranking cannot collapse when two classes share a name.
    (top_name, _), (second_name, second_score) = [
        (names[j], np.mean(clipwise_output[:, j])) for j in top_idx
    ]

    if top_name.lower() == 'speech' and 'speech' in second_name.lower():
        return True
    return bool(second_score < second_score_threshold)

In [None]:
wav_file = 'example.wav'
au, sr = sf.read(wav_file, dtype='int16')
# The chunking (CHUNK_SIZE samples) and 1-D frame slicing assume mono audio at RATE.
assert au.ndim == 1, f'{wav_file} must be mono, got shape {au.shape}'
assert sr == RATE, f'{wav_file} must be sampled at {RATE} Hz, got {sr} Hz'

# These names are used below but are not created anywhere in this notebook.
missing = [name for name in ('controller', 'n_non_speech_frames') if name not in globals()]
if missing:
    raise NameError(f'define {", ".join(missing)} before running this cell')

os.makedirs('examples', exist_ok=True)  # sf.write fails if the directory is missing
count = 0  # segment counter; align_wav increments it via `global count`

step_len = CHUNK_SIZE
# The extra n_non_speech_frames iterations slice past the end of `au` (empty
# frames), presumably so the trailing speech segment is closed and flushed.
for i in range(au.size // step_len + n_non_speech_frames):
    frame = au[(i * step_len) : (i + 1) * step_len]
    align_wav(frame, n_non_speech_frames, f'examples/example_{count}', sr)

print('Finish cuting')