In [5]:
import torch
import torchaudio
import torch.nn.functional as F
import time
import whisper
from data_loading.feature_extractor import PretrainedAudioEmbeddingExtractor, PretrainedTextEmbeddingExtractor
from utils.config_loader import ConfigLoader
from models.models import BiFormer

# === Emotions ===
# Maps the classifier's output index to a display label (emoji + name).
# BiFormer below is built with num_classes=7, one class per entry here.
# The emoji are written as \N{...} named escapes rather than raw characters:
# this file's emoji were previously garbled by a non-UTF-8 round-trip, and the
# escapes keep the source pure ASCII while printing the same characters.
LABEL_TO_EMOTION = {
    0: "\N{ANGRY FACE} Anger",
    1: "\N{NAUSEATED FACE} Disgust",
    2: "\N{FEARFUL FACE} Fear",
    3: "\N{SMILING FACE WITH OPEN MOUTH AND SMILING EYES} Joy/Happiness",
    4: "\N{NEUTRAL FACE} Neutral",
    5: "\N{CRYING FACE} Sadness",
    6: "\N{ASTONISHED FACE} Surprise/Enthusiasm",
}

# === Config and device ===
config = ConfigLoader("checkpoints/config_copy.toml")
# The device is taken from the config (`emb_device`) instead of being
# auto-detected with torch.cuda.is_available(), so the TOML file decides where
# inference runs (the same value is also passed to BiFormer as `device`).
device = torch.device(config.emb_device)
# Ask for the name of the device actually selected (e.g. "cuda:1"), not always
# GPU 0. A bare "cuda" has no index, and get_device_name then falls back to the
# current CUDA device.
device_name = torch.cuda.get_device_name(device) if device.type == "cuda" else "CPU"

# === Embedding extractors (audio and text), both built from the shared config
audio_extractor, text_extractor = (
    PretrainedAudioEmbeddingExtractor(config),
    PretrainedTextEmbeddingExtractor(config),
)

# === Fusion model
# Hyperparameters are read from the loaded config (presumably the copy saved
# next to the checkpoint -- TODO confirm it matches the training run).
# num_classes=7 matches the seven entries of LABEL_TO_EMOTION.
model = BiFormer(
    audio_dim=config.audio_embedding_dim,
    text_dim=config.text_embedding_dim,
    seg_len=config.max_tokens,
    hidden_dim=config.hidden_dim,
    hidden_dim_gated=config.hidden_dim_gated,
    num_transformer_heads=config.num_transformer_heads,
    num_graph_heads=config.num_graph_heads,
    positional_encoding=config.positional_encoding,
    dropout=config.dropout,
    mode=config.mode,
    device=config.emb_device,
    tr_layer_number=config.tr_layer_number,
    out_features=config.out_features,
    num_classes=7,
)
model = model.to(device)

# === Trained weights
# map_location puts every stored tensor on `device`, so a checkpoint saved on
# a GPU can still be loaded when the config selects the CPU.
checkpoint = torch.load(
    "checkpoints/best_model_dev_0_5895_epoch_8.pt",
    map_location=device,
)
model.load_state_dict(checkpoint)
model.eval()  # inference behaviour for dropout and similar layers

# === Whisper speech-to-text model; the size/name to load comes from the config
whisper_model_size = config.whisper_model
whisper_model = whisper.load_model(whisper_model_size, device=device)

# === Run the whole pipeline on one file, timing each stage.
# Intervals use time.perf_counter(): a monotonic, high-resolution clock meant
# for measuring durations. time.time() follows the adjustable system clock and
# can be too coarse for the millisecond-scale stages measured here.


def _sync_device(dev):
    """Wait for queued CUDA work on *dev* to finish; no-op for non-CUDA devices.

    CUDA kernels are launched asynchronously, so without this a host-side timer
    can stop before the GPU has actually finished the stage being measured.
    """
    if dev.type == "cuda":
        torch.cuda.synchronize(dev)


total_start = time.perf_counter()

# === Load audio
audio_path = "E:/MELD/wavs/test/dia0_utt0.wav"
load_start = time.perf_counter()
waveform, sr = torchaudio.load(audio_path)
load_end = time.perf_counter()
load_duration = load_end - load_start

# Down-mix to mono by averaging over the channel axis
# (torchaudio.load returns (channels, frames) by default).
waveform = waveform.mean(dim=0)
# Use the file's native sample rate: compute the duration before resampling.
duration_sec = waveform.shape[-1] / sr
print(f"\N{VIDEOCASSETTE} Audio duration: {duration_sec:.2f} seconds")

# Whisper and the audio extractor below are both fed 16 kHz audio.
if sr != 16000:
    resampler = torchaudio.transforms.Resample(sr, 16000)
    waveform = resampler(waveform)

waveform = waveform.to(device)

# === Text: if left empty, transcribe the audio with Whisper
text = ""  # a manual transcript can be entered here instead
if not text.strip():
    # Hand Whisper a 1-D NumPy array on the CPU.
    audio_np = waveform.squeeze().cpu().numpy()

    whisper_start = time.perf_counter()
    whisper_result = whisper_model.transcribe(audio_np, fp16=False)
    whisper_end = time.perf_counter()
    whisper_duration = whisper_end - whisper_start

    text = whisper_result.get("text", "").strip()
    print(f"\N{SPEAKING HEAD IN SILHOUETTE}\N{VARIATION SELECTOR-16} Whisper transcription: {text}")
    print(f"\N{CLOCK FACE FOUR OCLOCK} Whisper transcription time: {whisper_duration:.2f} seconds")


# === Preprocessing: audio and text embeddings
with torch.no_grad():
    _sync_device(device)  # start from an idle device so earlier work isn't counted
    prep_start = time.perf_counter()

    _, audio_emb = audio_extractor.extract(waveform, 16000)
    _, text_emb = text_extractor.extract([text])

    _sync_device(device)  # include queued GPU work in the measurement
    prep_end = time.perf_counter()
    prep_duration = prep_end - prep_start

# === Inference
with torch.no_grad():
    infer_start = time.perf_counter()

    logits = model(audio_emb, text_emb)

    # Row 0 holds the class probabilities for the single utterance above.
    probs = F.softmax(logits, dim=-1)[0]
    pred_idx = torch.argmax(probs).item()
    emotion_label = LABEL_TO_EMOTION.get(pred_idx, f"Unknown ({pred_idx})")

    _sync_device(device)
    infer_end = time.perf_counter()
    infer_duration = infer_end - infer_start

    total_end = time.perf_counter()
    total_duration_sec = total_end - total_start

print(f"\N{CARD INDEX DIVIDERS}\N{VARIATION SELECTOR-16} Audio file load time: {load_duration:.2f} seconds")
print(f"\N{HAMMER AND WRENCH}\N{VARIATION SELECTOR-16} Emb prep time: {prep_duration:.2f} seconds")
print(f"\N{HIGH VOLTAGE SIGN} Model time: {infer_duration:.2f} seconds")
print(f"\N{HOURGLASS} Total processing time: {total_duration_sec:.2f} seconds")

print(
    f"\n\N{LABEL}\N{VARIATION SELECTOR-16} Predicted emotion: {emotion_label} ({pred_idx}) "
    f"\N{EM DASH} probability: {probs[pred_idx]:.2f}"
)
print(f"\N{PERSONAL COMPUTER} Inference device: {device} \N{EM DASH} {device_name}")

flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn i

📼 Audio duration: 2.26 seconds
🗣️ Whisper transcription: Why do all your coffee mugs have numbers on the bottom?
🕓 Whisper transcription time: 0.17 seconds
🗂️ Audio file load time: 0.00 seconds
🛠️ Emb prep time: 0.08 seconds
⚡ Model time: 0.01 seconds
⌛ Total processing time: 0.26 seconds

🏷️ Predicted emotion: 😐 Neutral (4) — probability: 0.79
💻 Inference device: cuda — NVIDIA GeForce RTX 4080
