In [15]:
import os, random, torch, torchaudio
import numpy as np

def pick_random_wav(data_root):
    """Return the path of the audio file the demo should run on.

    Despite the name, no random selection happens here: ``data_root`` is
    accepted but never read, and the same fixed filename is returned on
    every call.

    Args:
        data_root: Ignored by the current implementation.

    Returns:
        The string ``'harvard.wav'``.
    """
    fixed_sample = 'harvard.wav'
    return fixed_sample

# Read the run configuration through the training module's loader.
from train import load_cfg
cfg = load_cfg()

# Choose the clip to test on and stop early if the picker returned nothing.
wav_path = pick_random_wav(cfg["data_root"])
if not wav_path:
    raise RuntimeError("No .wav files found in dataset.")
print(f"Testing on: {wav_path}")

# Decode the file, resample it to the configured rate and average the
# channels (dim 0) down to a single one.
wav, sr = torchaudio.load(wav_path)
wav = torchaudio.functional.resample(wav, sr, cfg["sample_rate"]).mean(dim=0, keepdim=True)

# Peak-normalise; the small epsilon avoids dividing by zero on silent audio.
peak = wav.abs().max()
wav = wav / (peak + 1e-9)

# Force the clip to exactly `num_samples` samples: a clip that is too short
# is tiled end-to-end first, then everything is cut to the target length.
num_samples = int(cfg["sample_rate"] * cfg["slice_seconds"])
if wav.shape[-1] < num_samples:
    repeat = -(-num_samples // wav.shape[-1])  # ceiling division
    wav = wav.repeat(1, repeat)
wav = wav[:, :num_samples].cpu()

Testing on: harvard.wav


In [16]:
import torch
from models import MothEncoder, BatDetector

# All inference in this notebook runs on the CPU.
device = "cpu"

# Instantiate both networks, then restore their weights from the checkpoint
# files stored under cfg["exp_dir"].
moth = MothEncoder(cfg["alpha"]).to(device)
bat = BatDetector().to(device)
moth_ckpt = os.path.join(cfg["exp_dir"], "moth_final.ckpt")
bat_ckpt = os.path.join(cfg["exp_dir"], "bat_final.ckpt")
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
moth.load_state_dict(torch.load(moth_ckpt, map_location=device))
bat.load_state_dict(torch.load(bat_ckpt, map_location=device))

# Put both modules in eval mode. bat.eval() is deliberately the last
# expression so the cell still displays the detector's repr.
moth.eval()
bat.eval()

BatDetector(
  (features): Sequential(
    (0): Conv1d(1, 16, kernel_size=(15,), stride=(1,), padding=(7,))
    (1): ReLU()
    (2): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (3): Conv1d(16, 32, kernel_size=(15,), stride=(1,), padding=(7,))
    (4): ReLU()
    (5): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (6): Conv1d(32, 64, kernel_size=(15,), stride=(1,), padding=(7,))
    (7): ReLU()
    (8): AdaptiveAvgPool1d(output_size=1)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=64, out_features=1, bias=True)
    (2): Sigmoid()
  )
)

In [None]:
from pydub import AudioSegment
from pydub.playback import play

def tensor_to_wav(tensor, sample_rate, out_path):
    """Write an audio tensor to ``out_path`` as 16-bit integer samples.

    Samples are clipped to [-1.0, 1.0] before being scaled to the int16
    range. Without the clip, any value outside that interval overflows the
    int16 cast and wraps around, which is audible as loud clicks.

    Args:
        tensor: Audio samples as a torch tensor; size-1 dimensions are
            squeezed away. Presumably floating point with a nominal range of
            [-1, 1] -- confirm against callers.
        sample_rate: Sample rate passed through to ``torchaudio.save``.
        out_path: Destination file path.

    Returns:
        ``out_path``, unchanged, so the call can be chained.
    """
    # detach() lets tensors that still require grad be converted to NumPy.
    samples = tensor.detach().squeeze().cpu().numpy()
    # Clip first, then scale; astype truncates toward zero, matching the
    # original conversion for in-range values.
    pcm = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
    # Re-add a leading dimension in front of the sample axis before saving.
    torchaudio.save(out_path, torch.from_numpy(pcm).unsqueeze(0), sample_rate)
    return out_path

# Save the preprocessed clip to a WAV file, then load that file and play it.
orig_path = "original_temp.wav"
tensor_to_wav(wav, cfg["sample_rate"], orig_path)
print("Playing original audio...")
original_clip = AudioSegment.from_wav(orig_path)
play(original_clip)

Playing original audio...


Input #0, wav, from '/var/folders/n0/jlkzrfxd5kjb58lgm7cyjy2w0000gn/T/tmph128istu.wav':
  Duration: 00:00:02.00, bitrate: 256 kb/s
  Stream #0:0: Audio: pcm_s16le ([1][0][0][0] / 0x0001), 16000 Hz, 1 channels, s16, 256 kb/s
   1.92 M-A:  0.000 fd=   0 aq=    0KB vq=    0KB sq=    0B 




In [19]:
# Run the clip through the moth encoder with gradient tracking disabled.
with torch.no_grad():
    wav_batch = wav.unsqueeze(0)  # add a leading dimension for the model call
    watermarked = moth(wav_batch).squeeze(0)  # and drop it again afterwards

# Save the encoder output to a WAV file, then load that file and play it.
water_path = "watermarked_temp.wav"
tensor_to_wav(watermarked, cfg["sample_rate"], water_path)
print("Playing watermarked audio...")
watermarked_clip = AudioSegment.from_wav(water_path)
play(watermarked_clip)

Playing watermarked audio...


Input #0, wav, from '/var/folders/n0/jlkzrfxd5kjb58lgm7cyjy2w0000gn/T/tmp_3u1gduu.wav':
  Duration: 00:00:02.00, bitrate: 256 kb/s
  Stream #0:0: Audio: pcm_s16le ([1][0][0][0] / 0x0001), 16000 Hz, 1 channels, s16, 256 kb/s
   1.91 M-A:  0.000 fd=   0 aq=    0KB vq=    0KB sq=    0B 




In [20]:
# Run the bat detector on both clips with gradient tracking disabled;
# .item() turns each single-element output into a Python float.
with torch.no_grad():
    orig_batch = wav.unsqueeze(0)
    water_batch = watermarked.unsqueeze(0)
    prob_orig = bat(orig_batch).item()
    prob_water = bat(water_batch).item()

# A value of 0.5 or higher is reported as "watermark present".
orig_verdict = 'YES' if prob_orig >= 0.5 else 'NO'
water_verdict = 'YES' if prob_water >= 0.5 else 'NO'
print(f"Original audio prediction: {prob_orig:.3f} (watermark: {orig_verdict})")
print(f"Watermarked audio prediction: {prob_water:.3f} (watermark: {water_verdict})")

Original audio prediction: 0.000 (watermark: NO)
Watermarked audio prediction: 0.993 (watermark: YES)
