# Dual-Attention Whisper Demo

This notebook demonstrates how to use the Dual-Attention Whisper model for noise-robust speech recognition.

## 1. Setup

In [None]:
import sys
sys.path.insert(0, '../src')

import torch
from transformers import WhisperProcessor
from model.dual_whisper import create_dual_attention_whisper
import librosa
import IPython.display as ipd
import numpy as np

## 2. Load Pre-trained Model

In [None]:
# Checkpoint name shared by the model factory and the processor below
model_name = "openai/whisper-small"

print("Loading model...")

# Build the dual-attention model from the chosen checkpoint;
# freeze_encoder=True is forwarded to the project's factory function.
model = create_dual_attention_whisper(model_name=model_name, freeze_encoder=True)

# Processor from the same checkpoint; later cells use it for feature
# extraction and for decoding generated token ids.
processor = WhisperProcessor.from_pretrained(model_name)

print("Model loaded!")

## 3. Load and Play Audio

In [None]:
# Point this at the recording you want to transcribe
audio_path = "path/to/your/audio.wav"

# Request 16 kHz audio from librosa; `sr` holds the rate returned alongside the samples
audio, sr = librosa.load(
    audio_path,
    sr=16000,
)

# Report the clip length, then end the cell with the player so it renders inline
print(f"Audio duration: {len(audio)/sr:.2f} seconds")
ipd.Audio(audio, rate=sr)

## 4. Transcribe Audio

In [None]:
# Extract features.
# Pass the rate the audio was actually loaded at (`sr`, returned by librosa.load)
# rather than repeating the literal 16000: if the loading rate is ever changed,
# the extractor is told the true rate instead of a stale hard-coded one.
# NOTE(review): the Whisper feature extractor is expected to reject a rate it
# was not configured for — confirm against the transformers documentation.
input_features = processor.feature_extractor(
    audio,
    sampling_rate=sr,
    return_tensors="pt"
).input_features

# Generate transcription: eval mode, and no gradient tracking during decoding
model.eval()
with torch.no_grad():
    generated_ids = model.generate(
        input_features,
        language="az",  # Change to your language
        task="transcribe",
        max_length=448,
        num_beams=5
    )

# Decode the first (and only) sequence in the batch, dropping special tokens
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

print("\nTranscription:")
print(transcription)

## 5. Compare with Standard Whisper

In [None]:
from transformers import WhisperForConditionalGeneration

# Baseline: the unmodified Whisper checkpoint, switched to eval mode right after loading
standard_model = WhisperForConditionalGeneration.from_pretrained(model_name).eval()

# Decoding settings for the baseline run (language, task, length cap, beam width)
baseline_generate_kwargs = {
    "language": "az",
    "task": "transcribe",
    "max_length": 448,
    "num_beams": 5,
}

# Run the baseline on the same input features, without gradient tracking
with torch.no_grad():
    standard_ids = standard_model.generate(input_features, **baseline_generate_kwargs)

# Keep the first decoded sequence of the batch
decoded_baseline = processor.batch_decode(standard_ids, skip_special_tokens=True)
standard_transcription = decoded_baseline[0]

# Show both transcriptions one after the other for comparison
print("Standard Whisper:", standard_transcription, sep="\n")
print("\nDual-Attention Whisper:", transcription, sep="\n")

## 6. Training Example

In [None]:
from data.dataset import NoisyAudioDataset
from data.collator import DataCollatorSpeechSeq2SeqWithPadding
from training.metrics import compute_metrics
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

# Load datasets (train and eval splits share the processor and language)
train_dataset = NoisyAudioDataset(
    data_path="../data/processed/train.json",
    processor=processor,
    language="az"
)

eval_dataset = NoisyAudioDataset(
    data_path="../data/processed/eval.json",
    processor=processor,
    language="az"
)

# Data collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id
)

# fp16 mixed precision needs a CUDA device. Enabling it unconditionally makes
# this cell fail on CPU-only machines as soon as the arguments are built, even
# though training itself is commented out, so enable it only when CUDA exists.
use_fp16 = torch.cuda.is_available()

# Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="../outputs",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    learning_rate=5e-6,
    warmup_steps=500,
    max_steps=5000,
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,  # kept equal to eval_steps so checkpoints line up with evaluations
    predict_with_generate=True,
    fp16=use_fp16,
    report_to="tensorboard",
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower WER is better
)

# Trainer
# compute_metrics(processor) is called here, so the project helper is used as a
# factory whose return value is handed to the trainer.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics(processor),
    processing_class=processor.feature_extractor,
)

# Start training (left disabled so running the whole notebook does not launch a training job)
# trainer.train()

## 7. Visualize Attention Weights (Advanced)

In [None]:
import matplotlib.pyplot as plt

# The training section hands `model` to a Trainer; if training was run, the
# model may no longer be in eval mode, so set it explicitly here as the
# transcription section does.
model.eval()

# Generate with attention output.
# Wrapped in no_grad like the other generation cells so no autograd state is
# kept while the attention tensors are collected.
with torch.no_grad():
    outputs = model.generate(
        input_features,
        language="az",
        task="transcribe",
        output_attentions=True,
        return_dict_in_generate=True,
        max_length=50
    )

# Note: Extracting and visualizing dual attention weights requires
# additional code to access the primary vs secondary attention patterns.
# This is left as an exercise for advanced users.

print("Attention visualization would go here")
print("You can access cross-attention weights from the decoder outputs")