In [14]:
import torch
import librosa
import numpy as np
import os
import sys
import logging
import soundfile as sf

# Add project root to path to import models (if running notebook from root)
# If running from elsewhere, adjust the path as needed
# sys.path.append('.') # Uncomment if needed

from models.moth.encoder import MothEncoder
from models.bat.decoder import BatDecoder

# Configure root logging once for the notebook (timestamped INFO-level messages)
# and create the named logger every later cell writes to.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(name="Inference")


In [16]:
# --- Model checkpoints (produced by training) ---
ENCODER_MODEL_PATH = 'trained_models/actual_run_spec/encoder.pth'  # trained Moth encoder (generates the watermark)
DECODER_MODEL_PATH = 'trained_models/actual_run_spec/decoder.pth'  # trained Bat decoder (detects the watermark)

# --- Input audio ---
# Point this at the WAV file you want to watermark.
INPUT_WAV_PATH = '50_speakers_audio_data/Speaker_0000/Speaker_0000_00000.wav'

# --- Output (optional) ---
SAVE_WATERMARKED_AUDIO = True  # set to False to skip writing the watermarked file
OUTPUT_WATERMARKED_PATH = 'output/inference_watermarked.wav'  # only used when SAVE_WATERMARKED_AUDIO is True

# --- Signal parameters (must match the values used during training) ---
SAMPLE_RATE = 16000  # Hz
MAX_LENGTH = 96_000  # samples fed to the models: 6 s at 16 kHz


In [17]:
# --- Auto-detect device ---
def get_default_device():
    """Pick the best available torch device and log the choice.

    Preference order: CUDA, then Apple MPS, then CPU.

    Returns:
        str: one of 'cuda', 'mps' or 'cpu'.
    """
    if torch.cuda.is_available():
        chosen, message = 'cuda', "CUDA detected. Using GPU."
    elif torch.backends.mps.is_available():
        chosen, message = 'mps', "MPS detected. Using Apple Silicon GPU."
    else:
        chosen, message = 'cpu', "No GPU detected. Using CPU."
    logger.info(message)
    return chosen
# --------------------------

# --- Determine device ---
# Resolve the preferred device name into a torch.device; if that fails, fall back
# to the CPU, which is always available.
device_str = get_default_device()
try:
    device = torch.device(device_str)
    logger.info(f"Using device: {device}")
except RuntimeError as e:
    logger.error(f"Could not use device '{device_str}'. Error: {e}. Falling back to CPU.")
    # Must be a valid device string: an invalid one (e.g. 'mmps') makes torch.device
    # raise again, so the fallback itself would crash.
    device = torch.device('cpu')


2025-04-23 18:50:41,290 - INFO - MPS detected. Using Apple Silicon GPU.
2025-04-23 18:50:41,295 - INFO - Using device: mps


In [18]:
def load_model_checkpoint(model_cls, checkpoint_path, label, kind):
    """Instantiate ``model_cls`` on ``device`` and load its state dict from ``checkpoint_path``.

    Args:
        model_cls: zero-argument model constructor (here MothEncoder or BatDecoder).
        checkpoint_path: path to the saved state dict (``.pth``).
        label: display name used in progress logs, e.g. "Moth Encoder".
        kind: short lowercase name used in error logs, e.g. "encoder".

    Returns:
        The model in eval mode, or ``None`` if the checkpoint could not be loaded.
        ``None`` (rather than the randomly initialised model) lets the
        ``... is not None`` guards in later cells skip work instead of silently
        running untrained weights.
    """
    logger.info(f"Loading {label} from: {checkpoint_path}")
    model = model_cls().to(device)
    try:
        # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    except FileNotFoundError:
        logger.error(f"{kind.capitalize()} model file not found at {checkpoint_path}. Please train the model first or check the path.")
        return None
    except Exception as e:
        logger.error(f"Error loading {kind} model: {e}")
        return None
    model.eval()  # inference mode (affects layers such as dropout/batch-norm, if present)
    logger.info(f"{label} loaded successfully.")
    return model


# --- Load Encoder (Moth) ---
encoder = load_model_checkpoint(MothEncoder, ENCODER_MODEL_PATH, "Moth Encoder", "encoder")

# --- Load Decoder (Bat) ---
decoder = load_model_checkpoint(BatDecoder, DECODER_MODEL_PATH, "Bat Decoder", "decoder")


2025-04-23 18:50:44,203 - INFO - Loading Moth Encoder from: trained_models/actual_run_spec/encoder.pth
2025-04-23 18:50:51,825 - INFO - Moth Encoder loaded successfully.
2025-04-23 18:50:51,827 - INFO - Loading Bat Decoder from: trained_models/actual_run_spec/decoder.pth
2025-04-23 18:50:52,847 - INFO - Bat Decoder loaded successfully.


In [19]:
def load_and_preprocess_audio(audio_path, sample_rate=SAMPLE_RATE, max_length=MAX_LENGTH):
    """Load one audio file and prepare it for the Moth/Bat models.

    The file is read with ``librosa.load`` (resampled to ``sample_rate``, mono),
    peak-normalized to [-1, 1], then zero-padded or truncated to exactly
    ``max_length`` samples for the model input.

    Args:
        audio_path: path to the audio file to read.
        sample_rate: target sample rate in Hz.
        max_length: number of samples in the model input.

    Returns:
        ``(audio_tensor, audio_norm)``. ``audio_tensor`` is a float tensor of shape
        ``(1, 1, max_length)``. ``audio_norm`` is the peak-normalized 1-D numpy
        signal at its FULL original length; it is NOT padded or truncated, so it may
        be longer or shorter than the model input.
        Returns ``(None, None)`` if the file is missing, empty, or cannot be processed.
    """
    try:
        logger.info(f"Loading audio file: {audio_path}")
        audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)

        # np.max below has no identity for an empty array; report it explicitly.
        if audio.size == 0:
            logger.error(f"Input audio file contains no samples: {audio_path}")
            return None, None

        # Peak-normalize; leave all-zero (silent) audio untouched to avoid dividing by zero.
        peak = np.max(np.abs(audio))
        audio_norm = audio / peak if peak > 0 else audio

        original_length = len(audio_norm)
        logger.info(f'Original audio length: {original_length} samples')

        # Zero-pad short clips at the end; keep only the first max_length samples of long ones.
        if original_length < max_length:
            padded_audio = np.pad(audio_norm, (0, max_length - original_length))
        else:
            padded_audio = audio_norm[:max_length]

        # (T,) -> (1, 1, T): add batch and channel dimensions.
        audio_tensor = torch.from_numpy(padded_audio).float().unsqueeze(0).unsqueeze(0)

        return audio_tensor, audio_norm
    except FileNotFoundError:
        logger.error(f"Input audio file not found: {audio_path}")
        return None, None
    except Exception as e:
        logger.error(f"Error loading or processing audio {audio_path}: {e}")
        return None, None

# Load the specified audio file
# Load the configured input clip. Later cells check these variables for None
# rather than stopping the notebook here.
original_audio_tensor, original_audio_np = load_and_preprocess_audio(INPUT_WAV_PATH)

if original_audio_tensor is None:
    logger.error("Failed to load audio. Cannot proceed with inference.")
else:
    # Move the model input onto the same device as the models.
    original_audio_tensor = original_audio_tensor.to(device)
    logger.info(f"Audio loaded and preprocessed successfully. Tensor shape: {original_audio_tensor.shape}")


2025-04-23 18:50:56,596 - INFO - Loading audio file: 50_speakers_audio_data/Speaker_0000/Speaker_0000_00000.wav
2025-04-23 18:50:56,617 - INFO - Original audio length: 960512 samples
2025-04-23 18:50:56,764 - INFO - Audio loaded and preprocessed successfully. Tensor shape: torch.Size([1, 1, 96000])


In [20]:
# Generate the watermarked clip: the Moth encoder outputs a perturbation that is
# added to the input. Both results stay None if anything upstream failed.
watermarked_audio_tensor = None
watermarked_audio_np = None

if original_audio_tensor is None or 'encoder' not in locals() or encoder is None:
    logger.warning("Skipping watermarking because original audio or encoder failed to load.")
else:
    logger.info("Applying Moth encoder to generate watermark...")
    with torch.no_grad():  # inference only, no autograd graph needed
        perturbation = encoder(original_audio_tensor)
        # Clip the sum back into the normalized [-1, 1] range of the input.
        watermarked_audio_tensor = (original_audio_tensor + perturbation).clamp(-1.0, 1.0)

    # Drop the batch and channel axes and copy to host memory as a numpy array.
    watermarked_audio_np = watermarked_audio_tensor.squeeze(0).squeeze(0).cpu().numpy()
    logger.info("Watermark generated successfully.")


2025-04-23 18:51:00,717 - INFO - Applying Moth encoder to generate watermark...
2025-04-23 18:51:01,040 - INFO - Watermark generated successfully.


In [21]:
from IPython.display import Audio, display

# A/B listening check. original_audio_np is the FULL-length normalized input,
# but the models only process its first MAX_LENGTH samples. Play the same window
# for both clips so the comparison is like-for-like.
if original_audio_np is not None:
    print("Playing original audio...")
    original_clip = original_audio_np[:MAX_LENGTH]
    display(Audio(original_clip, rate=SAMPLE_RATE))
else:
    print("original_audio_np is not available.")

if watermarked_audio_np is not None:
    print("Playing watermarked audio...")
    display(Audio(watermarked_audio_np, rate=SAMPLE_RATE))
else:
    print("watermarked_audio_np is not available.")


Playing original audio...


Playing watermarked audio...


In [22]:
# Optionally write the watermarked clip to disk (controlled by SAVE_WATERMARKED_AUDIO).
if SAVE_WATERMARKED_AUDIO and watermarked_audio_np is not None:
    try:
        # Create the output directory if needed; exist_ok avoids the
        # check-then-create race of os.path.exists + os.makedirs.
        output_dir = os.path.dirname(OUTPUT_WATERMARKED_PATH)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        logger.info(f"Saving watermarked audio to: {OUTPUT_WATERMARKED_PATH}")
        # watermarked_audio_np is the 1-D signal produced by the watermarking cell.
        sf.write(OUTPUT_WATERMARKED_PATH, watermarked_audio_np, SAMPLE_RATE)
        logger.info("Watermarked audio saved successfully.")
    except Exception as e:
        logger.error(f"Could not save watermarked audio: {e}")


2025-04-23 18:51:31,382 - INFO - Saving watermarked audio to: output/inference_watermarked.wav
2025-04-23 18:51:31,387 - INFO - Watermarked audio saved successfully.


In [23]:
# Score both clips with the Bat decoder. The results cell compares these scores
# against DETECTION_THRESHOLD; presumably they are probabilities in [0, 1] —
# confirm against BatDecoder's output layer.
original_prediction = None
watermarked_prediction = None

if 'decoder' not in locals() or decoder is None:
    logger.warning("Decoder not loaded. Skipping detection step.")
else:
    # Unmodified input: should score below the detection threshold.
    if original_audio_tensor is None:
        logger.warning("Cannot test original audio as it failed to load.")
    else:
        logger.info("Running Bat decoder on ORIGINAL audio...")
        with torch.no_grad():
            pred_orig_tensor = decoder(original_audio_tensor)
            # .item() requires a single-element output (one score for the batch of one).
            original_prediction = pred_orig_tensor.item()
        logger.info(f"Original audio prediction score: {original_prediction:.4f}")

    # Encoder output: should score at or above the detection threshold.
    if watermarked_audio_tensor is None:
        logger.warning("Cannot test watermarked audio as it was not generated (likely due to previous errors).")
    else:
        logger.info("Running Bat decoder on WATERMARKED audio...")
        with torch.no_grad():
            pred_wm_tensor = decoder(watermarked_audio_tensor)
            watermarked_prediction = pred_wm_tensor.item()
        logger.info(f"Watermarked audio prediction score: {watermarked_prediction:.4f}")


2025-04-23 18:51:34,877 - INFO - Running Bat decoder on ORIGINAL audio...
2025-04-23 18:51:35,051 - INFO - Original audio prediction score: 0.5000
2025-04-23 18:51:35,051 - INFO - Running Bat decoder on WATERMARKED audio...
2025-04-23 18:51:35,059 - INFO - Watermarked audio prediction score: 0.5000


In [27]:
from pesq import pesq

# Perceptual quality check with wideband PESQ (intended for 16 kHz input).
# Original-vs-original is a sanity ceiling; original-vs-watermarked shows how
# much the perturbation degrades the speech.
if original_audio_np is None or watermarked_audio_np is None:
    print("Required audio data (original or watermarked) is missing.")
else:
    # The two arrays can differ in length (the original is full-length, the
    # watermarked clip is the model window); compare only the shared prefix.
    min_len = min(len(original_audio_np), len(watermarked_audio_np))
    ref_audio, degraded_audio = original_audio_np[:min_len], watermarked_audio_np[:min_len]

    pesq_score_original = pesq(SAMPLE_RATE, ref_audio, ref_audio, mode='wb')
    pesq_score_watermarked = pesq(SAMPLE_RATE, ref_audio, degraded_audio, mode='wb')

    print("\n--- PESQ Scores ---")
    print(f"PESQ (Original vs. Original): {pesq_score_original:.2f}")
    print(f"PESQ (Original vs. Watermarked): {pesq_score_watermarked:.2f}")



--- PESQ Scores ---
PESQ (Original vs. Original): 4.64
PESQ (Original vs. Watermarked): 4.64


In [26]:
DETECTION_THRESHOLD = 0.5 # Standard threshold for binary classification
# A score exactly equal to the threshold counts as "watermarked" (>=). A decoder
# that outputs 0.5 therefore looks "correct" on the watermarked clip and yields a
# false positive on the original; read both verdicts together.

print("\n--- Inference Results ---")

if original_prediction is not None:
    print(f"Original Audio Prediction: {original_prediction:.4f}")
    if original_prediction < DETECTION_THRESHOLD:
        print(f"  -> Correctly identified as NOT watermarked (Score < {DETECTION_THRESHOLD}).")
    else:
        print(f"  -> INCORRECTLY identified as watermarked (Score >= {DETECTION_THRESHOLD}) - False Positive!")
else:
    print("Original Audio: Detection could not be performed.")

if watermarked_prediction is not None:
    print(f"\nWatermarked Audio Prediction: {watermarked_prediction:.4f}")
    if watermarked_prediction >= DETECTION_THRESHOLD:
        print(f"  -> Correctly identified as WATERMARKED (Score >= {DETECTION_THRESHOLD}).")
    else:
        print(f"  -> INCORRECTLY identified as not watermarked (Score < {DETECTION_THRESHOLD}) - False Negative!")
else:
    print("Watermarked Audio: Detection could not be performed.")

print("-----------------------")



--- Inference Results ---
Original Audio Prediction: 0.5000
  -> INCORRECTLY identified as watermarked (Score >= 0.5) - False Positive!

Watermarked Audio Prediction: 0.5000
  -> Correctly identified as WATERMARKED (Score >= 0.5).
-----------------------
