# CaviDB search engine

This dialog manager diagram represents the flow the program will follow

<img src="./diagram.png" width="400px" heigth="800px">

## Usage
For its usage execute the cells in order. The audios generated will provide you the steps you need to follow.

## Code source

In [12]:
from nemo.collections.tts.models import FastPitchModel
from nemo.collections.tts.models import HifiGanModel
from ipywebrtc import AudioRecorder, CameraStream
import torchaudio
import soundfile as sf
import random
from IPython.display import Audio
import librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
import numpy as np
import subprocess
import ffmpeg
from word2number import w2n
import urllib
import re
import tempfile

In [8]:
### Record audio
def record_audio():
    camera = CameraStream(constraints={'audio': True,'video':False})
    recorder = AudioRecorder(stream=camera)
    display(recorder)
    return recorder

def save_recording(recorder, audio_out_path):
    with open('recording.webm', 'wb') as f:
        f.write(recorder.audio.value)
    command =  f"ffmpeg -i recording.webm -ac 1 -f wav {audio_out_path} -y -hide_banner -loglevel panic -ar 48000"
    subprocess.call(command, shell=True)

def load_audio(audio_path):
    wave, sample_rate = torchaudio.load(audio_path)
    return np.array(wave)[0], sample_rate

def play_audio(wave, sample_rate):
    print(wave.shape)
    display(Audio(data=wave, rate=sample_rate))


In [27]:
### Dialog system
class TTSModel:
    def __init__(self):
        self.spec_generator = FastPitchModel.from_pretrained("nvidia/tts_en_fastpitch")
        self.model = HifiGanModel.from_pretrained(model_name="nvidia/tts_hifigan")

    def generate_speech(self, text):
        parsed = self.spec_generator.parse(text)
        spectrogram = self.spec_generator.generate_spectrogram(tokens=parsed)
        audio = (
            self.model.convert_spectrogram_to_audio(spec=spectrogram).detach().numpy()
        )
        return audio


class ASRModel:
    def __init__(self):
        self.processor = WhisperProcessor.from_pretrained("openai/whisper-small")
        self.model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-small"
        )
        self.model.config.forced_decoder_ids = None

    def recognize_speech(self, audio, sample_rate=48000):
        audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
        input_features = self.processor(
            audio, sampling_rate=16000, return_tensors="pt"
        ).input_features
        predicted_ids = self.model.generate(input_features)
        transcription = self.processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )
        return transcription[0]


class SpokenDialogSystem:
    def __init__(self, tts_model, asr_model):
        self.tts_model = tts_model
        self.asr_model = asr_model
        self.current_state = "ASK_PDB"
        self.pdb_id = None
        self.chain_id = None
        self.model_id = None

    # Inicio del sistema
    def start(self):
        start_audio = self.tts_model.generate_speech(
            "Welcome to Cavi Database audio search engine! First, execute the cell above and provide your target's Protein Data Bank identifier."
        )
        display(Audio(start_audio, rate=22050))

    # Main dialog function
    def listen_and_respond(self, audio):
        transcription = self.asr_model.recognize_speech(audio)
        dialog_act = self.understand_transcription(transcription)
        print("dialog_act", dialog_act)
        response_code = self.advance_dialog(dialog_act)
        print("response_code", response_code)
        response_text = self.generate_language(response_code)
        response_audio = self.tts_model.generate_speech(response_text)
        display(Audio(response_audio, rate=22050))

    # Language comprehension
    def understand_transcription(self, transcription):
        transcription = transcription.lower().strip()
        print(transcription)
        
        try:
            if self.current_state == "ASK_PDB":
                self.pdb_id = transcription_to_pdb(transcription)
            elif self.current_state == "ASK_CHAIN":
                self.chain_id = transcription_to_chain(transcription)
            elif self.current_state == "ASK_MODEL":
                self.model_id = transcription_to_number(transcription, force=True)
            return "ACT_INPUT"
        except Exception as e:
            print(e)
            return "ACT_UNKNOWN"

    # Dialog manager
    def advance_dialog(self, dialog_act):
        self.validate_running()    
            
        if dialog_act != "ACT_INPUT":
            return "REPEAT"
                
        if self.current_state == "ASK_PDB":
            self.current_state = "ASK_CHAIN"
        elif self.current_state == "ASK_CHAIN":
            self.current_state = "ASK_MODEL"
        else:
            self.current_state = "STATE_FINISH"

        return self.current_state

    # Language generation
    def generate_language(self, response_code):
        if response_code == "ASK_CHAIN":
            return "Please, execute the previous cell again and provide a chain identifier."
        elif response_code == "ASK_MODEL":
            return "Please, execute the previous cell again and provide a model number."
        elif response_code == "STATE_FINISH":
            return "Execute the cell above to see your result"
        elif response_code == "REPEAT":
            return f"Sorry, I didn't understand you. {self.generate_input_explanation()}. Please, repeat the answer."

        else:
            raise ValueError(f"ERROR: Unknown response code {response_code}")

    def generate_input_explanation(self):
        self.validate_running()
        
        if self.current_state == "ASK_PDB":
            return "PDB identifier must be a 4-letter alphanumeric code"
        elif self.current_state == "ASK_MODEL":
            return "Model identifier must be an integer number"
        else:
            return "Chain identifier must be a single letter"
        
    def generate_query_link(self):
        return "https://www.cavidb.org/chains?q=" + urllib.parse.quote(
            f"pdb:{self.pdb_id} chain:{self.chain_id} model:{self.model_id}"
        )
    
    def validate_running(self):
        if self.has_finished():
            raise RuntimeError("Invalid state: dialog finished")
            
    def has_finished(self):
        return self.current_state == "STATE_FINISH"

def transcription_to_number(transcription, force=False):
    if transcription.isnumeric():
        return transcription

    try:
        return w2n.word_to_num(transcription)
    except ValueError as e:
        if force:
            raise e
        return transcription


def transcription_to_pdb(transcription):
    pdb = "".join(map(transcription_to_number, re.split(r'\W', transcription)))
    if len(pdb) != 4:
        raise ValueError("PDB must be 4 letters")
    return pdb


def transcription_to_chain(transcription):
    if not transcription.isalpha():
        raise ValueError("Chain must be a letter")
    
    if len(transcription) != 1:
        raise ValueError("Chain must be a single letter")
    
    return transcription

## SDS Inteface

In [5]:
## Create model instances for being used by the SDS object
tts_model = TTSModel()
asr_model = ASRModel()

[NeMo W 2023-10-31 20:19:22 en_us_arpabet:66] apply_to_oov_word=None, This means that some of words will remain unchanged if they are not handled by any of the rules in self.parse_one_word(). This may be intended if phonemes and chars are both valid inputs, otherwise, you may see unexpected deletions in your input.
[NeMo W 2023-10-31 20:19:22 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    dataset:
      _target_: nemo.collections.tts.torch.data.TTSDataset
      manifest_filepath: /ws/LJSpeech/nvidia_ljspeech_train_clean_ngc.json
      sample_rate: 22050
      sup_data_path: /raid/LJSpeech/supplementary
      sup_data_types:
      - align_prior_matrix
      - pitch
      n_fft: 1024
      win_length: 1024
      hop_length: 256
      window: hann
      n_mels: 80
      lowfreq: 0
      highfreq: 8000
      max_duration: null
      

[NeMo I 2023-10-31 20:19:22 features:289] PADDING: 1
[NeMo I 2023-10-31 20:19:22 save_restore_connector:249] Model FastPitchModel was successfully restored from /home/franco/.cache/huggingface/hub/models--nvidia--tts_en_fastpitch/snapshots/2c8305b7b41b33fd6367f0635796dc3a7a33cbf9/tts_en_fastpitch.nemo.


[NeMo W 2023-10-31 20:19:25 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    dataset:
      _target_: nemo.collections.tts.data.datalayers.MelAudioDataset
      manifest_filepath: /home/fkreuk/data/train_finetune.txt
      min_duration: 0.75
      n_segments: 8192
    dataloader_params:
      drop_last: false
      shuffle: true
      batch_size: 64
      num_workers: 4
    
[NeMo W 2023-10-31 20:19:25 modelPT:168] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    dataset:
      _target_: nemo.collections.tts.data.datalayers.MelAudioDataset
      manifest_filepath: /home/fkreuk/data/val_finetune.txt
      min_duration: 3
      n_segments: 66150


[NeMo I 2023-10-31 20:19:25 features:289] PADDING: 0


[NeMo W 2023-10-31 20:19:25 features:266] Using torch_stft is deprecated and has been removed. The values have been forcibly set to False for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True as needed.


[NeMo I 2023-10-31 20:19:25 features:289] PADDING: 0


    


[NeMo I 2023-10-31 20:19:26 save_restore_connector:249] Model HifiGanModel was successfully restored from /home/franco/.cache/huggingface/hub/models--nvidia--tts_hifigan/snapshots/3ba1fed954276287015654bf4c78060ffc9a4772/tts_hifigan.nemo.


In [28]:
## SDS start
SDS = SpokenDialogSystem(tts_model, asr_model)
SDS.start()

[NeMo W 2023-10-31 20:35:51 fastpitch:291] parse() is meant to be called in eval mode.
[NeMo W 2023-10-31 20:35:51 fastpitch:368] generate_spectrogram() is meant to be called in eval mode.


In [45]:
## User answer
recorder = record_audio()

AudioRecorder(audio=Audio(value=b'', format='webm'), stream=CameraStream(constraints={'audio': True, 'video': …

In [46]:
## SDS instructions
with tempfile.NamedTemporaryFile() as f:
    save_recording(recorder, f.name)
    recorder.close()
    wave, sr = load_audio(f.name)
    SDS.listen_and_respond(wave)


[NeMo W 2023-10-31 20:39:15 fastpitch:291] parse() is meant to be called in eval mode.
[NeMo W 2023-10-31 20:39:15 fastpitch:368] generate_spectrogram() is meant to be called in eval mode.


12
dialog_act ACT_INPUT
response_code STATE_FINISH


In [47]:
## Final result
SDS.generate_query_link()

'https://www.cavidb.org/chains?q=pdb%3A1084%20chain%3Ae%20model%3A12'