In [1]:
import os
import time
import numpy as np
import sounddevice as sd
import soundfile as sf
import librosa
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2Model
from sentence_transformers import SentenceTransformer
from tensorflow import keras

  from .autonotebook import tqdm as notebook_tqdm





In [2]:
# Runtime configuration shared by the recording, transcription and embedding cells.
SAMPLE_RATE = 16000          # Hz; used for microphone capture and when loading audio with librosa
DURATION = 5.0                # default number of seconds recorded per inference
TEMP_WAV = "user_record.wav"   # default file the microphone capture is written to
USE_GPU = torch.cuda.is_available()   # True when torch reports a usable CUDA device
DEVICE = "cuda" if USE_GPU else "cpu"      # device string the wav2vec2 models and their inputs are moved to
# NOTE(review): not referenced by the visible cells, which load
# "fusion_model_config.json" / "fusion_model.weights.h5" instead — confirm whether this is still needed.
FUSION_MODEL_PATH = "fusion_model.h5"   

In [3]:
# Checkpoint used for speech-to-text (loaded below as Wav2Vec2ForCTC).
ASR_MODEL = "facebook/wav2vec2-large-960h-lv60-self"
# Checkpoint used as the audio feature extractor (loaded below as Wav2Vec2Model).
EMBED_MODEL = "facebook/wav2vec2-base-960h"
# Sentence-Transformers model name used for the text embedding.
SENTENCE_BERT = "all-MiniLM-L6-v2" 
# Passed as `pool=` to get_audio_embedding: True mean-pools the frame features into one vector.
POOL_AUDIO = True

In [4]:
def record_audio(filename=TEMP_WAV, duration=DURATION, sr=SAMPLE_RATE):
    """Capture mono microphone audio and store it as a 16-bit PCM WAV file.

    Args:
        filename: Destination path of the recording.
        duration: Length of the recording in seconds.
        sr: Sampling rate in Hz, used for both the capture and the written file.

    Returns:
        The ``filename`` that was written.
    """
    print(f"[record] Recording {duration}s of audio (sr={sr})... Speak now.")
    n_frames = int(duration * sr)
    recording = sd.rec(n_frames, samplerate=sr, channels=1, dtype="float32")
    # Block until the capture has finished before touching the buffer.
    sd.wait()
    # Drop the singleton channel axis so a flat mono signal is written.
    sf.write(filename, recording.squeeze(), sr, subtype='PCM_16')
    print(f"[record] Saved to {filename}")
    return filename

In [5]:
print("Loading models...")

# Speech-to-text: processor plus CTC model, placed on DEVICE in inference mode.
# nn.Module.eval() returns the module itself, so it can be chained after .to().
asr_processor = Wav2Vec2Processor.from_pretrained(ASR_MODEL)
asr_model = Wav2Vec2ForCTC.from_pretrained(ASR_MODEL).to(DEVICE).eval()

# Audio feature extractor: processor plus bare wav2vec2 encoder, same treatment.
embed_processor = Wav2Vec2Processor.from_pretrained(EMBED_MODEL)
embed_model = Wav2Vec2Model.from_pretrained(EMBED_MODEL).to(DEVICE).eval()

# Sentence encoder used for the transcript embedding.
sbert = SentenceTransformer(SENTENCE_BERT)

Loading models...


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h-lv60-self and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
from tensorflow.keras.models import model_from_json

# Architecture (JSON) and weights (HDF5) of the trained fusion classifier.
FUSION_CONFIG_PATH = "fusion_model_config.json"
FUSION_WEIGHTS_PATH = "fusion_model.weights.h5"

# infer_once() treats `fusion_model is None` as "no classifier available" and
# still returns the transcript and embeddings. Make that state reachable when
# the model artefacts are absent instead of failing the whole notebook here.
# Files that exist but are unreadable/corrupt still raise, as before.
_missing_fusion_files = [
    path for path in (FUSION_CONFIG_PATH, FUSION_WEIGHTS_PATH)
    if not os.path.isfile(path)
]

if _missing_fusion_files:
    fusion_model = None
    print(f"[warning] Fusion model files not found: {_missing_fusion_files}. "
          "Emotion prediction is disabled.")
else:
    # Rebuild the architecture from its JSON description, then restore the weights.
    with open(FUSION_CONFIG_PATH, "r") as f:
        model_json = f.read()
    fusion_model = model_from_json(model_json)

    fusion_model.load_weights(FUSION_WEIGHTS_PATH)

    # Same compile settings as before; only predict() is used in the visible cells.
    fusion_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

  saveable.load_own_variables(weights_store.get(inner_path))
  saveable.load_own_variables(weights_store.get(inner_path))
  saveable.load_own_variables(weights_store.get(inner_path))


In [None]:
import torch
import librosa
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

def transcribe_wav(wav_path, asr_model, asr_processor, beam_width=10):
    """Transcribe a WAV file with a wav2vec2 CTC model.

    Args:
        wav_path: Path of the audio file; it is loaded at SAMPLE_RATE.
        asr_model: CTC model whose output exposes ``.logits``.
        asr_processor: Matching processor (feature extraction + decoding).
        beam_width: Beam size for the ``ctcdecode`` fallback. It is only used
            when the processor has no ``batch_decode`` method; otherwise
            greedy (argmax) decoding is performed.

    Returns:
        The transcript, lower-cased with whitespace collapsed to single
        spaces. An empty string is returned for an empty recording.
    """
    speech, _ = librosa.load(wav_path, sr=SAMPLE_RATE)

    # A zero-length signal has no maximum to normalise by (numpy raises on
    # .max() of an empty array) and nothing to transcribe.
    if speech.size == 0:
        return ""

    # Peak-normalise; the floor avoids dividing by zero on pure silence.
    speech = speech / max(1e-5, abs(speech).max())

    # Follow the device of the model that was passed in rather than assuming
    # it lives on the notebook-wide DEVICE.
    target_device = getattr(asr_model, "device", DEVICE)
    input_values = asr_processor(
        speech,
        sampling_rate=SAMPLE_RATE,
        return_tensors="pt",
        padding="longest",
    ).input_values.to(target_device)

    with torch.no_grad():
        logits = asr_model(input_values).logits

    if hasattr(asr_processor, "batch_decode"):
        # Greedy CTC decoding: most likely token per frame, collapsed by the processor.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = asr_processor.batch_decode(predicted_ids)[0]
    else:
        from ctcdecode import CTCBeamDecoder

        # The decoder needs the labels as a sequence ordered like the model's
        # output columns; get_vocab() returns a token -> id mapping, so order
        # the tokens by id explicitly.
        vocab = asr_processor.tokenizer.get_vocab()
        labels = sorted(vocab, key=vocab.get)
        decoder = CTCBeamDecoder(
            labels,
            beam_width=beam_width,
            blank_id=asr_processor.tokenizer.pad_token_id,
            log_probs_input=True
        )
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        beam_results, _, _, out_lens = decoder.decode(log_probs.cpu())
        # Best beam of the first (only) utterance, trimmed to its valid length.
        best_ids = beam_results[0][0][:out_lens[0][0]]
        transcription = asr_processor.tokenizer.decode(best_ids)

    # Lower-case and collapse runs of whitespace (split() also trims the ends).
    return " ".join(transcription.lower().split())


In [8]:
def get_audio_embedding(wav_path, pool=True):
    """Return a wav2vec2 embedding of an audio file as a numpy array.

    Args:
        wav_path: Path of the audio file; it is loaded at SAMPLE_RATE.
        pool: If True, average the frame features over time and return a 1-D
            vector. If False, return the per-frame features as a 2-D
            ``(frames, hidden)`` array.

    Returns:
        numpy array: 1-D when ``pool`` is True, 2-D otherwise.
    """
    speech, sr = librosa.load(wav_path, sr=SAMPLE_RATE)
    input_values = embed_processor(
        speech, sampling_rate=sr, return_tensors="pt", padding="longest"
    ).input_values.to(DEVICE)

    with torch.no_grad():
        # One file is processed at a time, so the leading batch axis has size 1.
        last_hidden = embed_model(input_values).last_hidden_state
        features = last_hidden.mean(dim=1) if pool else last_hidden

    # Remove only the batch axis. A bare squeeze() would also collapse the time
    # axis when a very short clip yields a single frame, silently turning the
    # pool=False result into a 1-D array.
    return features.squeeze(0).cpu().numpy()


In [9]:
def get_text_embedding(text):
    """Encode ``text`` with the Sentence-BERT model and return its numpy vector."""
    # encode() works on a batch; wrap the single sentence and unwrap its only row.
    batch = sbert.encode([text], convert_to_numpy=True, show_progress_bar=False)
    return batch[0]

In [10]:
import pickle


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    SECURITY: pickle.load can execute arbitrary code from the file, so only
    use this on files you created or otherwise trust.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Fitted scalers applied to the audio and text embeddings before prediction.
audio_scaler = _load_pickle('audio_scaler.pkl')
text_scaler = _load_pickle('text_scaler.pkl')


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [11]:
# googletrans is optional: without it the pipeline simply skips translation.
try:
    from googletrans import Translator
except Exception:
    # Any import-time failure is treated as "translation unavailable".
    HAS_GOOGLETRANS = False
    translator = None
else:
    HAS_GOOGLETRANS = True
    translator = Translator()

In [None]:
def infer_once(record_duration=DURATION, translate_to=None):
    """Record one utterance and run the full speech-emotion pipeline on it.

    Steps: record -> transcribe -> translate to English (when a translator is
    available) -> audio and text embeddings -> scaling -> fusion-model
    prediction.

    Args:
        record_duration: Seconds of audio to record.
        translate_to: Accepted for backward compatibility but not used; the
            transcript is always translated with ``dest='en'`` before it is
            embedded.

    Returns:
        dict with keys:
            "text": transcript ("" if transcription failed),
            "translated": English translation, or None if none was produced,
            "audio_emb": raw (unscaled) audio embedding,
            "text_emb": raw (unscaled) text embedding,
            "prediction": None when no fusion model is loaded, otherwise a
                dict with "label", "prob" and "probs".

    Raises:
        ValueError: if the audio embedding does not hold exactly
            SEQ_LEN * FEATURE_DIM values and so cannot be reshaped for the
            fusion model.
    """
    # Layout the fusion model's audio input is reshaped to (16 * 48 = 768 values).
    SEQ_LEN = 16
    FEATURE_DIM = 48

    wav_path = record_audio(duration=record_duration)

    print("[step] Transcribing audio...")
    try:
        text = transcribe_wav(wav_path, asr_model, asr_processor)
    except Exception as e:
        # Best effort: continue with an empty transcript so the audio branch still runs.
        print("[error] Transcription failed:", e)
        text = ""
    print("Transcribed text:", repr(text))

    translated = None
    text_for_embedding = text
    # There is nothing to translate in an empty transcript; skip the call.
    if translator and text.strip():
        try:
            translated = translator.translate(text, dest='en').text
            print("Translated (en) for embedding:", translated)
            # Keep the original transcript if the translator returned nothing usable.
            if translated:
                text_for_embedding = translated
        except Exception as e:
            print("[warning] Translation failed:", e)

    print("[step] Extracting audio embedding...")
    audio_emb = get_audio_embedding(wav_path, pool=POOL_AUDIO)
    print("Raw audio emb shape:", audio_emb.shape)
    expected_size = SEQ_LEN * FEATURE_DIM
    if audio_emb.size != expected_size:
        raise ValueError(
            f"Audio embedding has {audio_emb.size} values (shape {audio_emb.shape}); "
            f"expected {expected_size} to reshape to (1, {SEQ_LEN}, {FEATURE_DIM})."
        )
    audio_emb_seq = audio_emb.reshape(1, SEQ_LEN, FEATURE_DIM)

    if len(text_for_embedding.strip()) == 0:
        # No usable text: feed an all-zero vector of the encoder's output size.
        text_emb = np.zeros((sbert.get_sentence_embedding_dimension(),))
    else:
        text_emb = get_text_embedding(text_for_embedding)
    print("Text emb shape:", text_emb.shape)
    text_emb_seq = text_emb.reshape(1, -1)

    # The scalers are only read here, so no `global` declaration is needed.
    if audio_scaler is not None:
        # The scaler is applied to the flattened vector; restore the 3-D layout afterwards.
        audio_emb_seq = audio_scaler.transform(audio_emb_seq.reshape(1, -1)).reshape(1, SEQ_LEN, FEATURE_DIM)
    if text_scaler is not None:
        text_emb_seq = text_scaler.transform(text_emb_seq)

    if fusion_model is None:
        print("[info] No fusion model available to predict emotion.")
        return {
            "text": text,
            "translated": translated,
            "audio_emb": audio_emb,
            "text_emb": text_emb,
            "prediction": None
        }

    print("[step] Predicting emotion...")
    pred_prob = fusion_model.predict([audio_emb_seq, text_emb_seq], verbose=0)[0]
    pred_idx = int(np.argmax(pred_prob))

    label_classes = getattr(fusion_model, "class_names", None)
    if label_classes is None:
        # NOTE(review): assumed to match the label order used in training — confirm.
        label_classes = ['anger', 'joy', 'neutral', 'sadness']

    # Fall back to the numeric index if the model has more outputs than known labels.
    pred_label = label_classes[pred_idx] if pred_idx < len(label_classes) else str(pred_idx)
    print(f"Predicted emotion: {pred_label} (prob={pred_prob[pred_idx]:.3f})")

    return {
        "text": text,
        "translated": translated,
        "audio_emb": audio_emb,
        "text_emb": text_emb,
        "prediction": {
            "label": pred_label,
            "prob": float(pred_prob[pred_idx]),
            "probs": pred_prob.tolist()
        }
    }


In [40]:
if __name__ == "__main__":
    # Run one record -> transcribe -> predict cycle.
    # NOTE(review): infer_once never reads translate_to, so 'hi' has no effect here.
    result = infer_once(record_duration=DURATION, translate_to='hi')
    # Normalised transcript; NOTE(review): `text` is not used afterwards in the visible cells.
    text = result['text'].lower().strip()

[record] Recording 5.0s of audio (sr=16000)... Speak now.
[record] Saved to user_record.wav
[step] Transcribing audio...
Transcribed text: 'i'
Translated (en) for embedding: i
[step] Extracting audio embedding...
Raw audio emb shape: (768,)
Text emb shape: (384,)
[step] Predicting emotion...
Predicted emotion: neutral (prob=0.527)
