# TD-SpeakerBeam Demo Notebook

This notebook demonstrates how to use the TD-SpeakerBeam model for target speech extraction.

In [None]:
import sys
sys.path.append('../src')

import torch
import soundfile as sf
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Audio

from models.td_speakerbeam import TimeDomainSpeakerBeam

## Load Pre-trained Model

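Load a pre-trained TD-SpeakerBeam model. If the checkpoint file is not found, a dummy model is created instead so that the rest of the notebook can still be run.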

In [None]:
# Load model (replace with actual model path)
model_path = '../example/model.pth'
try:
    # Keep the try body minimal: only the load itself is expected to raise.
    model = TimeDomainSpeakerBeam.from_pretrained(model_path)
except FileNotFoundError:
    print("Model file not found. Please train a model first or provide a valid model path.")
    # Create a dummy model for demonstration.
    # NOTE(review): presumably this model has freshly initialised weights, so
    # its output would not be a meaningful extraction - confirm.
    model = TimeDomainSpeakerBeam(
        i_adapt_layer=7,
        adapt_layer_type='mul',
        adapt_enroll_dim=128,
        n_filters=512,
        kernel_size=16,
        stride=8,
    )
    print("Created dummy model for demonstration.")
else:
    print("Model loaded successfully!")

# Put the model in evaluation mode on BOTH paths. Previously only the
# pre-trained path called eval(), leaving the dummy model in training mode
# for the inference cell below.
model.eval()

## Load Audio Files

Load the mixture and enrollment audio files. If the files are not found, synthetic sine-wave signals are generated instead.

In [None]:
# Load audio files (replace with actual file paths)
try:
    mixture, sr = sf.read('../example/mixture.wav')
    enrollment, enrollment_sr = sf.read('../example/enrollment.wav')
except FileNotFoundError:
    print("Audio files not found. Creating dummy signals.")
    sr = 8000
    duration = 3.0
    # endpoint=False keeps the sample spacing at exactly 1/sr; including the
    # endpoint would stretch the spacing to duration/(N-1) and detune the tones.
    t = np.linspace(0, duration, int(sr * duration), endpoint=False)
    mixture = 0.5 * np.sin(2 * np.pi * 440 * t) + 0.3 * np.sin(2 * np.pi * 880 * t)
    enrollment = 0.5 * np.sin(2 * np.pi * 440 * t)
    print(f"Created dummy signals: mixture: {mixture.shape}, enrollment: {enrollment.shape}")
else:
    # sf.read returns a (frames, channels) array for multi-channel files.
    # The extraction cell adds only a batch axis, so downmix to mono here.
    if mixture.ndim > 1:
        print("Mixture has multiple channels; averaging them to mono.")
        mixture = mixture.mean(axis=1)
    if enrollment.ndim > 1:
        print("Enrollment has multiple channels; averaging them to mono.")
        enrollment = enrollment.mean(axis=1)

    # The rest of the notebook uses a single `sr` for both signals, so make a
    # mismatch visible instead of silently ignoring the enrollment's rate.
    if enrollment_sr != sr:
        print(
            f"Warning: enrollment sample rate ({enrollment_sr}) differs from "
            f"mixture sample rate ({sr}); resample the enrollment before extraction."
        )

    print(f"Loaded mixture: {mixture.shape}, enrollment: {enrollment.shape}")
    print(f"Sample rate: {sr}")

## Perform Target Speech Extraction

In [None]:
def _as_batch(signal):
    """Return a NumPy array as a float32 tensor with a new leading axis."""
    return torch.from_numpy(signal).float().unsqueeze(0)


# Wrap both signals as batched tensors
mixture_tensor = _as_batch(mixture)
enrollment_tensor = _as_batch(enrollment)

# Call the model on both tensors with gradient tracking disabled
with torch.no_grad():
    extracted = model(mixture_tensor, enrollment_tensor)

# Drop all size-1 axes and hand the result back to NumPy
extracted_audio = extracted.squeeze().numpy()
print(f"Extracted audio shape: {extracted_audio.shape}")

## Visualize Results

In [None]:
# Plot waveforms
fig, axes = plt.subplots(3, 1, figsize=(12, 8))

# Give every signal its own time axis. Slicing the mixture's axis fails with a
# shape mismatch whenever the enrollment is longer than the mixture, and the
# extracted signal is not guaranteed here to match the mixture's length.
# NOTE(review): assumes the enrollment shares the mixture's sample rate `sr`.
time = np.arange(len(mixture)) / sr
enrollment_time = np.arange(len(enrollment)) / sr
extracted_time = np.arange(len(extracted_audio)) / sr

axes[0].plot(time, mixture)
axes[0].set_title('Mixture')
axes[0].set_ylabel('Amplitude')

axes[1].plot(enrollment_time, enrollment)
axes[1].set_title('Enrollment')
axes[1].set_ylabel('Amplitude')

axes[2].plot(extracted_time, extracted_audio)
axes[2].set_title('Extracted Target Speech')
axes[2].set_xlabel('Time (s)')
axes[2].set_ylabel('Amplitude')

plt.tight_layout()
plt.show()

## Audio Playback

In [None]:
def _announce_and_play(caption, signal):
    """Print a caption, then display an audio player for `signal` at rate `sr`."""
    print(caption)
    display(Audio(signal, rate=sr))


# Same three players as before, in the same order
_announce_and_play("Original mixture:", mixture)

_announce_and_play("Enrollment:", enrollment)

_announce_and_play("Extracted target speech:", extracted_audio)