In [8]:
# ------------------- CHUNK 1: Environment Setup and Hyperparameters ------------------- #

import os
import glob
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import matplotlib.pyplot as plt

# Fix random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

# Hyperparameters
SAMPLE_RATE = 16000        # Audio sample rate
AUDIO_LEN   = 16000        # 1-second audio (16k samples)
BATCH_SIZE  = 64           # Batch size for training
LR          = 1e-3         # Learning rate
HIDDEN_DIM  = 32           # Hidden dimension for LSTM in Generator
NUM_BITS    = 16           # Number of message bits
CHANNELS    = 32           # Initial convolution channels
OUTPUT_CH   = 128          # Final conv channels for Generator
STRIDES     = [2, 4, 5, 8]  # Downsampling strides for encoder/decoder blocks
LSTM_LAYERS = 2            # Number of LSTM layers
NUM_WORKERS = 16           # Number of DataLoader workers

# Loss Weights (for the composite loss)
lambda_L1     = 1.0
lambda_msspec = 1.0
lambda_adv    = 0.1
lambda_loud   = 0.5
lambda_loc    = 1.0
lambda_dec    = 1.0

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [9]:
# ------------------- CHUNK 2: Dataset & Augmentations for 1-sec clips ------------------- #

class OneSecClipsDataset(Dataset):
    """
    Assumes each .wav file in root_dir is a ~1-sec clip (16k samples).
    If sample_rate != 16000, it resamples to 16k.
    """
    def __init__(self, root_dir, sample_rate=SAMPLE_RATE):
        super().__init__()
        self.filepaths = glob.glob(os.path.join(root_dir, '**', '*.wav'), recursive=True)
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx):
        wav_path = self.filepaths[idx]
        waveform, sr = torchaudio.load(wav_path)

        # Convert to mono if multi-channel
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample if needed
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sample_rate)
            waveform = resampler(waveform)

        # Ensure the clip is 1 second (pad/crop if necessary)
        if waveform.shape[1] > AUDIO_LEN:
            waveform = waveform[:, :AUDIO_LEN]
        elif waveform.shape[1] < AUDIO_LEN:
            pad_amt = AUDIO_LEN - waveform.shape[1]
            waveform = F.pad(waveform, (0, pad_amt))

        return waveform  # shape: (1, AUDIO_LEN)

def watermark_masking_augmentation(wav, p_replace_orig=0.4, p_replace_zero=0.2, p_replace_diff=0.2):
    """
    Randomly masks portions of the audio:
    - p_replace_orig: Placeholder for original (no-op here)
    - p_replace_zero: Replace segment with zeros
    - p_replace_diff: Replace segment with random noise
    """
    T = wav.shape[1]
    window_len = int(0.1 * SAMPLE_RATE)  # 0.1 second window
    k = 5  # number of windows to apply augmentation
    for _ in range(k):
        start = random.randint(0, T - window_len)
        end = start + window_len
        choice = random.random()
        if choice < p_replace_orig:
            # No-op: placeholder for original
            pass
        elif choice < p_replace_orig + p_replace_zero:
            # Replace with zeros
            wav[:, start:end] = 0.0
        elif choice < p_replace_orig + p_replace_zero + p_replace_diff:
            # Replace with random noise (scaled)
            wav[:, start:end] = 0.1 * torch.randn_like(wav[:, start:end])
        else:
            # Leave unchanged
            pass
    return wav

def robustness_augmentations(wav):
    """
    Adds small random noise for robustness.
    """
    return wav + 0.005 * torch.randn_like(wav)

print("\nCHUNK #2 completed: Dataset class and augmentation functions defined.")



CHUNK #2 completed: Dataset class and augmentation functions defined.


In [10]:
# ------------------- CHUNK 3: Model Definitions (Generator & Detector) ------------------- #

def make_conv1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1):
    return nn.Conv1d(in_ch, out_ch, kernel_size, stride=stride, padding=padding)

class ResidualBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.downsample = (stride != 1 or in_ch != out_ch)
        self.conv1 = make_conv1d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)
        self.conv2 = make_conv1d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.elu   = nn.ELU()
        if self.downsample:
            self.skip_conv = make_conv1d(in_ch, out_ch, kernel_size=1, stride=stride, padding=0)

    def forward(self, x):
        residual = x
        out = self.elu(self.conv1(x))
        out = self.conv2(out)
        if self.downsample:
            residual = self.skip_conv(residual)
        out = self.elu(out + residual)
        return out

class Generator(nn.Module):
    def __init__(self, 
                 in_channels=1, 
                 base_channels=CHANNELS,
                 hidden_dim=HIDDEN_DIM, 
                 message_bits=NUM_BITS,
                 output_channels=OUTPUT_CH, 
                 strides=STRIDES):
        super().__init__()
        self.message_bits = message_bits
        self.hidden_dim   = hidden_dim
        
        # Embedding layer for the watermark message
        self.E = nn.Embedding(num_embeddings=(2**message_bits), embedding_dim=hidden_dim)

        # ------------------- Encoder ------------------- #
        self.init_conv = nn.Conv1d(in_channels, base_channels, kernel_size=7, stride=1, padding=3)
        enc_blocks = []
        ch = base_channels
        for st in strides:
            out_ch = ch * 2
            enc_blocks.append(ResidualBlock(ch, out_ch, stride=st))
            ch = out_ch
        self.encoder_blocks = nn.Sequential(*enc_blocks)

        # Project encoder output to hidden_dim (to prepare for LSTM)
        self.proj = nn.Linear(ch, hidden_dim)  # ch is typically large after encoder

        # LSTM to process the temporal sequence
        self.lstm = nn.LSTM(input_size=hidden_dim, hidden_size=hidden_dim, num_layers=2, batch_first=True, bidirectional=False)

        self.final_conv_enc = nn.Conv1d(hidden_dim, output_channels, kernel_size=7, stride=1, padding=3)

        # ------------------- Decoder ------------------- #
        dec_blocks = []
        rev_strides = list(reversed(strides))
        in_ch = output_channels  # start with output_channels from encoder
        for st in rev_strides:
            out_ch = in_ch // 2  # reduce channels at each block
            dec_blocks.append(nn.ConvTranspose1d(in_ch, out_ch, kernel_size=2*st, stride=st, padding=(st//2), output_padding=0))
            dec_blocks.append(ResidualBlock(out_ch, out_ch, stride=1))
            in_ch = out_ch
        self.decoder_blocks = nn.Sequential(*dec_blocks)

        # Final convolution to produce the delta (watermark perturbation)
        self.final_conv_dec = nn.Conv1d(in_ch, 1, kernel_size=7, stride=1, padding=3)

    def forward(self, s, message=None):
        B, _, T = s.shape
        x = self.init_conv(s)
        x = self.encoder_blocks(x)
        x_t = x.transpose(1, 2)  # shape: (B, T_after, ch)
        x_t = self.proj(x_t)      # project to hidden_dim

        if message is not None:
            e = self.E(message)   # shape: (B, hidden_dim)
            T_after = x_t.shape[1]
            e_expanded = e.unsqueeze(1).expand(-1, T_after, -1)  # (B, T_after, hidden_dim)
            x_t = x_t + e_expanded

        lstm_out, _ = self.lstm(x_t)
        lstm_out_t = lstm_out.transpose(1, 2)
        latent = self.final_conv_enc(lstm_out_t)

        x_dec = latent
        x_dec = self.decoder_blocks(x_dec)
        delta = self.final_conv_dec(x_dec)
        if delta.shape[-1] != T:
            min_len = min(delta.shape[-1], T)
            delta = delta[:, :, :min_len]
            if min_len < T:
                pad_amt = T - min_len
                delta = F.pad(delta, (0, pad_amt))
        return delta

class Detector(nn.Module):
    def __init__(self, 
                 in_channels=1, 
                 base_channels=CHANNELS,
                 hidden_dim=HIDDEN_DIM,
                 message_bits=NUM_BITS,
                 strides=STRIDES):
        super().__init__()
        self.message_bits = message_bits

        self.init_conv = nn.Conv1d(in_channels, base_channels, kernel_size=7, stride=1, padding=3)
        enc_blocks = []
        ch = base_channels
        for st in strides:
            out_ch = ch * 2
            enc_blocks.append(ResidualBlock(ch, out_ch, stride=st))
            ch = out_ch
        self.encoder_blocks = nn.Sequential(*enc_blocks)

        dec_blocks = []
        rev_strides = list(reversed(strides))
        in_ch = ch
        for st in rev_strides:
            out_ch = in_ch // 2
            dec_blocks.append(nn.ConvTranspose1d(in_ch, out_ch, kernel_size=2*st, stride=st, padding=(st//2), output_padding=0))
            dec_blocks.append(ResidualBlock(out_ch, out_ch, stride=1))
            in_ch = out_ch
        self.upsample_blocks = nn.Sequential(*dec_blocks)

        # Final convolution outputs 1 detection channel + message_bits channels
        self.final_conv = nn.Conv1d(base_channels, 1 + message_bits, kernel_size=7, stride=1, padding=3)

    def forward(self, x):
        original_length = x.shape[-1]
        x = self.init_conv(x)
        x = self.encoder_blocks(x)
        x = self.upsample_blocks(x)
        out = self.final_conv(x)
        if out.shape[-1] > original_length:
            out = out[:, :, :original_length]
        elif out.shape[-1] < original_length:
            pad_amt = original_length - out.shape[-1]
            out = F.pad(out, (0, pad_amt))
        return torch.sigmoid(out)

print("\nCHUNK #3 completed: Generator and Detector models defined.")



CHUNK #3 completed: Generator and Detector models defined.


In [11]:
# ------------------- CHUNK 4: Loss Functions ------------------- #

import torchaudio.transforms as T

# --- Simple Single-Scale Mel Loss ---
class SimpleMelLoss(nn.Module):
    """
    A simple mel spectrogram loss that uses a single scale.
    Compares the log-mel spectrograms of the original and watermarked audio.
    """
    def __init__(self, sample_rate=SAMPLE_RATE, n_fft=1024, n_mels=80):
        super(SimpleMelLoss, self).__init__()
        self.mel_spec = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=n_fft // 4,
            n_mels=n_mels,
            normalized=True
        )
        
    def forward(self, original, watermarked):
        mel_orig = torch.log(self.mel_spec(original) + 1e-5)
        mel_wm   = torch.log(self.mel_spec(watermarked) + 1e-5)
        loss = F.l1_loss(mel_orig, mel_wm)
        return loss

# --- TF-Loudness Loss ---
class TFLoudnessLoss(nn.Module):
    """
    Time-Frequency loudness loss that ensures perceptual similarity across 
    different frequency bands and time windows.
    
    It computes the STFT of the original and watermarked audio and compares 
    their loudness (energy), spectral shape, and phase differences.
    """
    def __init__(self, n_bands=8, window_size=2048, hop_size=512):
        super(TFLoudnessLoss, self).__init__()
        self.n_bands = n_bands
        self.win_size = window_size
        self.hop_size = hop_size
        
        # Perceptual weighting: higher weights for mid-frequency bands
        weights = torch.ones(n_bands)
        mid_band_idx = n_bands // 3
        weights[mid_band_idx:2 * mid_band_idx] = 1.5
        self.register_buffer('band_weights', weights)

    def forward(self, original, watermarked):
        window = torch.hann_window(self.win_size, device=original.device)
        stft_orig = torch.stft(
            original.squeeze(1), n_fft=self.win_size, hop_length=self.hop_size,
            window=window, return_complex=True, normalized=True
        )
        stft_wm = torch.stft(
            watermarked.squeeze(1), n_fft=self.win_size, hop_length=self.hop_size,
            window=window, return_complex=True, normalized=True
        )
        mag_orig = stft_orig.abs()
        mag_wm   = stft_wm.abs()
        phase_orig = stft_orig.angle()
        phase_wm   = stft_wm.angle()
        
        freq_bins = mag_orig.shape[1]
        band_size = freq_bins // self.n_bands
        
        loudness_loss = 0.0
        spectral_loss = 0.0
        phase_loss = 0.0
        
        for b in range(self.n_bands):
            start = b * band_size
            end = freq_bins if (b == self.n_bands - 1) else (start + band_size)
            band_orig = mag_orig[:, start:end, :]
            band_wm = mag_wm[:, start:end, :]
            
            energy_orig = torch.sum(band_orig ** 2, dim=1)
            energy_wm = torch.sum(band_wm ** 2, dim=1)
            loud_orig = torch.log10(energy_orig + 1e-8)
            loud_wm = torch.log10(energy_wm + 1e-8)
            loudness_loss += self.band_weights[b] * F.l1_loss(loud_wm, loud_orig)
            spectral_loss += self.band_weights[b] * F.mse_loss(band_wm, band_orig)
            phase_diff = 1.0 - torch.cos(phase_wm[:, start:end, :] - phase_orig[:, start:end, :])
            phase_loss += self.band_weights[b] * phase_diff.mean()
        
        loudness_loss /= self.n_bands
        spectral_loss /= self.n_bands
        phase_loss /= self.n_bands
        
        return loudness_loss + spectral_loss + 0.2 * phase_loss

# --- Adversarial Loss ---
class AdversarialLoss(nn.Module):
    """
    Adversarial loss using a discriminator network that distinguishes between
    original and watermarked audio.
    
    The generator is penalized if the discriminator can tell them apart,
    encouraging imperceptible watermark embedding.
    """
    def __init__(self):
        super(AdversarialLoss, self).__init__()
        self.discriminator = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=15, stride=1, padding=7),
            nn.LeakyReLU(0.2),
            nn.Conv1d(16, 32, kernel_size=41, stride=4, padding=20),
            nn.LeakyReLU(0.2),
            nn.Conv1d(32, 64, kernel_size=41, stride=4, padding=20),
            nn.LeakyReLU(0.2),
            nn.Conv1d(64, 128, kernel_size=41, stride=4, padding=20),
            nn.LeakyReLU(0.2),
            nn.Conv1d(128, 1, kernel_size=41, stride=4, padding=20),
        )
        self.disc_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)

    def forward(self, original, watermarked, train_disc=True):
        if train_disc:
            self.disc_optimizer.zero_grad()
            real_output = self.discriminator(original)
            real_loss = F.binary_cross_entropy_with_logits(real_output, torch.ones_like(real_output))
            fake_output = self.discriminator(watermarked.detach())
            fake_loss = F.binary_cross_entropy_with_logits(fake_output, torch.zeros_like(fake_output))
            disc_loss = real_loss + fake_loss
            disc_loss.backward()
            self.disc_optimizer.step()
        fake_output = self.discriminator(watermarked)
        gen_loss = F.binary_cross_entropy_with_logits(fake_output, torch.ones_like(fake_output))
        return gen_loss

# --- Masked Localization Loss ---
def masked_localization_loss(detector_out, mask, smooth_eps=0.1):
    """
    Localization loss with label smoothing and focal weighting.
    
    It compares the detector's detection probability (first channel) to the provided mask.
    """
    det_prob = detector_out[:, 0:1, :]
    smoothed_mask = mask * (1.0 - smooth_eps) + (1.0 - mask) * smooth_eps
    pt = torch.where(mask > 0.5, det_prob, 1 - det_prob)
    focal_weight = (1 - pt) ** 2
    bce_loss = F.binary_cross_entropy(det_prob, smoothed_mask, reduction='none')
    focal_loss = focal_weight * bce_loss
    return focal_loss.mean()

# --- Decoding Loss ---
def decoding_loss(detector_out, message, mask=None, gamma=2.0):
    """
    Decoding loss ensures the embedded watermark message is recoverable.
    
    It extracts bit probabilities from detector_out and compares them to the true message bits.
    """
    B, channels, T = detector_out.shape
    b = channels - 1  # number of bit channels
    if b <= 0:
        return torch.tensor(0.0, device=detector_out.device)
    
    bit_prob_map = detector_out[:, 1:, :]  # shape (B, b, T)
    
    if mask is not None:
        expanded_mask = mask.expand(-1, b, -1)
        masked_bit_map = bit_prob_map * expanded_mask
        weights = expanded_mask.sum(dim=2, keepdim=True) + 1e-8
        bit_prob = (masked_bit_map.sum(dim=2) / weights.squeeze(2))
    else:
        confidence = torch.abs(bit_prob_map - 0.5) * 2.0
        attention = F.softmax(confidence * 5.0, dim=2)
        bit_prob = (bit_prob_map * attention).sum(dim=2)
    
    # Convert message to bit vectors
    msg_bits = []
    for i in range(b):
        bit_i = ((message >> i) & 1).float()
        msg_bits.append(bit_i)
    msg_bits = torch.stack(msg_bits, dim=1)  # shape (B, b)
    
    pt = torch.where(msg_bits > 0.5, bit_prob, 1 - bit_prob)
    focal_weight = (1 - pt) ** gamma
    bce = F.binary_cross_entropy(bit_prob, msg_bits, reduction='none')
    focal_bce = focal_weight * bce
    return focal_bce.mean()

print("\nCHUNK #4 completed: All loss functions defined.")



CHUNK #4 completed: All loss functions defined.


In [12]:
# ------------------- CHUNK 5: Training & Validation Functions ------------------- #

def generate_random_messages(batch_size, max_val=(2**NUM_BITS)):
    """
    Generates random integers in the range [0, 2**NUM_BITS - 1] as watermark messages.
    """
    return torch.randint(0, max_val, (batch_size,))

def train_one_epoch(generator, detector, train_loader, optimizer, epoch, total_epochs, device):
    """
    Runs one training epoch using the composite loss.
    """
    generator.train()
    detector.train()
    total_loss = 0.0
    total_steps = len(train_loader)
    
    # Instantiate loss modules (from CHUNK 4)
    # Note: We use MultiScaleMelLoss defined in CHUNK 5 of previous code (or SimpleMelLoss as desired)
    ms_mel_loss = SimpleMelLoss().to(device)
    adv_loss_stub = AdversarialLoss().to(device)
    tf_loud_loss = TFLoudnessLoss().to(device)
    
    pbar = tqdm(enumerate(train_loader), total=total_steps, desc=f"Epoch [{epoch}/{total_epochs}]")
    for i, s in pbar:
        s = s.to(device)  # shape: (B, 1, AUDIO_LEN)
        B = s.shape[0]
        msgs = generate_random_messages(B).to(device)
        
        # Forward pass: generate watermark delta and form watermarked audio
        delta = generator(s, msgs)
        s_w = s + delta
        
        # Apply optional augmentations for robustness
        for b_idx in range(B):
            s_w[b_idx] = watermark_masking_augmentation(s_w[b_idx])
            s_w[b_idx] = robustness_augmentations(s_w[b_idx])
        
        # Detector forward pass
        det_out = detector(s_w)
        
        # Compute each loss component:
        loss_l1   = F.l1_loss(delta, torch.zeros_like(delta))
        loss_msspec = ms_mel_loss(s, s_w)
        loss_adv  = adv_loss_stub(s, s_w, train_disc=True)
        loss_loud = tf_loud_loss(s, s_w)
        
        # Assume full watermark coverage for localization
        mask = torch.ones((B, 1, s.shape[-1]), device=device)
        loss_loc  = masked_localization_loss(det_out, mask)
        loss_dec  = decoding_loss(det_out, msgs, mask)
        
        # Composite loss using predefined lambda weights
        loss = (lambda_L1     * loss_l1 +
                lambda_msspec * loss_msspec +
                lambda_adv    * loss_adv +
                lambda_loud   * loss_loud +
                lambda_loc    * loss_loc +
                lambda_dec    * loss_dec)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
    
    avg_loss = total_loss / total_steps
    return avg_loss

def validate_one_epoch(generator, detector, val_loader, device):
    """
    Runs validation over the validation set using the composite loss.
    """
    generator.eval()
    detector.eval()
    total_loss = 0.0
    steps = 0
    # Instantiate loss modules
    ms_mel_loss = SimpleMelLoss().to(device)
    adv_loss_stub = AdversarialLoss().to(device)
    tf_loud_loss = TFLoudnessLoss().to(device)
    
    with torch.no_grad():
        for s in tqdm(val_loader, desc="Validation", leave=False):
            s = s.to(device)
            B = s.shape[0]
            msgs = generate_random_messages(B).to(device)
            
            delta = generator(s, msgs)
            s_w = s + delta
            det_out = detector(s_w)
            
            loss_l1   = F.l1_loss(delta, torch.zeros_like(delta))
            loss_msspec = ms_mel_loss(s, s_w)
            loss_adv  = adv_loss_stub(s, s_w, train_disc=False)
            loss_loud = tf_loud_loss(s, s_w)
            mask      = torch.ones((B, 1, s.shape[-1]), device=device)
            loss_loc  = masked_localization_loss(det_out, mask)
            loss_dec  = decoding_loss(det_out, msgs, mask)
            
            loss = (lambda_L1     * loss_l1 +
                    lambda_msspec * loss_msspec +
                    lambda_adv    * loss_adv +
                    lambda_loud   * loss_loud +
                    lambda_loc    * loss_loc +
                    lambda_dec    * loss_dec)
            
            total_loss += loss.item()
            steps += 1
    avg_loss = total_loss / steps if steps > 0 else 0.0
    return avg_loss

def train_model(generator, detector, train_dataset, val_dataset, num_epochs=10, lr=LR):
    """
    Full training loop over multiple epochs.
    """
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
    optimizer = optim.Adam(list(generator.parameters()) + list(detector.parameters()), lr=lr)
    
    for epoch in range(1, num_epochs+1):
        train_loss = train_one_epoch(generator, detector, train_loader, optimizer, epoch, num_epochs, device)
        val_loss   = validate_one_epoch(generator, detector, val_loader, device)
        print(f"Epoch [{epoch}/{num_epochs}]  TRAIN Loss: {train_loss:.4f}  |  VAL Loss: {val_loss:.4f}")

print("\nCHUNK #5 completed: Training and validation functions defined.")



CHUNK #5 completed: Training and validation functions defined.


In [13]:
# ------------------- CHUNK 6: Inference & Evaluation Functions ------------------- #

def detect_watermark(detector, audio, threshold=0.5):
    """
    Determines whether a watermark is present.
    
    Args:
        detector: Trained detector model.
        audio: Tensor of shape (B, 1, T).
        threshold: Threshold on the averaged detection probability.
        
    Returns:
        Tensor of shape (B,) containing 1 if watermark detected, 0 otherwise.
    """
    detector.eval()
    with torch.no_grad():
        out = detector(audio)  # shape: (B, 1+b, T)
        det_prob = out[:, 0, :]  # detection probability channel
        avg_prob = det_prob.mean(dim=1)
        return (avg_prob > threshold).float()

def localize_watermark(detector, audio, threshold=0.5):
    """
    Localizes the watermark in the audio.
    
    Args:
        detector: Trained detector model.
        audio: Tensor of shape (B, 1, T).
        threshold: Threshold for detection.
        
    Returns:
        A mask tensor of shape (B, T) with 1 indicating watermark presence.
    """
    detector.eval()
    with torch.no_grad():
        out = detector(audio)
        det_prob = out[:, 0, :]
        return (det_prob > threshold).float()

def decode_message(detector, audio):
    """
    Decodes the watermark message from the audio.
    
    Args:
        detector: Trained detector model.
        audio: Tensor of shape (B, 1, T).
        
    Returns:
        Tensor of shape (B,) with the decoded message (as an integer).
    """
    detector.eval()
    with torch.no_grad():
        out = detector(audio)  # shape: (B, 1+b, T)
        bit_prob_map = out[:, 1:, :]  # bit probability channels
        bit_prob = bit_prob_map.mean(dim=2)  # average over time -> (B, b)
        bits_decoded = (bit_prob > 0.5).int()  # threshold to obtain bits
        B, b = bits_decoded.shape
        msg_int = torch.zeros(B, dtype=torch.long, device=bits_decoded.device)
        for i in range(b):
            msg_int |= (bits_decoded[:, i] << i)
        return msg_int

# --- Evaluation Metrics (Placeholders) ---
def evaluate_si_snr_torch(original, reconstructed, eps=1e-8):
    """
    Computes SI-SNR (in dB) between original and reconstructed audio.
    """
    if original.dim() == 3:
        original = original.squeeze(1)
    if reconstructed.dim() == 3:
        reconstructed = reconstructed.squeeze(1)
    
    # Zero-mean signals
    original_zm = original - original.mean(dim=1, keepdim=True)
    recon_zm = reconstructed - reconstructed.mean(dim=1, keepdim=True)
    
    dot = (original_zm * recon_zm).sum(dim=1, keepdim=True)
    norm_sq = (original_zm ** 2).sum(dim=1, keepdim=True) + eps
    alpha = dot / norm_sq
    
    s_target = alpha * original_zm
    e_noise = recon_zm - s_target
    si_snr_val = 10 * torch.log10((s_target ** 2).sum(dim=1) / ((e_noise ** 2).sum(dim=1) + eps))
    return si_snr_val

def evaluate_pesq(original, reconstructed, sr=SAMPLE_RATE):
    """
    Computes PESQ score (placeholder). Replace with actual PESQ calculation if available.
    """
    # For now, return a dummy score
    return 4.5 - random.random()

def run_evaluation(generator, detector, test_dataset, device,
                   batch_size=16, num_workers=4, sr=SAMPLE_RATE,
                   num_bits=16, as_bits=True, 
                   compute_pesq_score=True, compute_si_snr_score=True):
    """
    Evaluates the watermarking system on the test dataset.
    
    Steps:
      1. Generate random messages.
      2. Pass the original audio through the generator to create watermarked audio.
      3. Use the detector to decode the watermark.
      4. Compute SI-SNR, PESQ, and bit accuracy.
    
    Returns:
      A dictionary containing average SI-SNR, PESQ, and bit accuracy.
    """
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    generator.eval()
    detector.eval()
    
    si_snr_vals = []
    pesq_vals = []
    bit_acc_vals = []
    
    with torch.no_grad():
        for s in test_loader:
            s = s.to(device)  # (B, 1, T)
            B = s.shape[0]
            msgs = generate_random_messages(B, num_bits=num_bits, as_bits=as_bits).to(device)
            delta = generator(s, msgs)
            s_w = s + delta
            
            # Decode message from watermarked audio
            bit_prob_map = detector(s_w)  # shape (B, 1+b, T)
            # For bit accuracy, compare decoded message with true message.
            decoded = decode_message(detector, s_w)
            if as_bits:
                true_message = msgs  # Assuming bit vectors
            else:
                true_message = msgs  # Integers
            matches = (decoded == true_message).float()
            bit_acc_vals.extend(matches.cpu().numpy())
            
            # Audio quality metrics
            if compute_si_snr_score:
                si_snr_vals.extend(evaluate_si_snr_torch(s, s_w).cpu().numpy())
            if compute_pesq_score:
                for i in range(B):
                    pesq_vals.append(evaluate_pesq(s[i], s_w[i], sr=sr))
    
    avg_si_snr = float(np.mean(si_snr_vals)) if si_snr_vals else 0.0
    avg_pesq = float(np.mean(pesq_vals)) if pesq_vals else 0.0
    avg_bit_acc = float(np.mean(bit_acc_vals)) if bit_acc_vals else 0.0
    
    print("\n--- Evaluation Results ---")
    if compute_si_snr_score:
        print(f"Average SI-SNR: {avg_si_snr:.3f} dB")
    if compute_pesq_score:
        print(f"Average PESQ: {avg_pesq:.3f}")
    print(f"Bit Accuracy: {avg_bit_acc*100:.2f}%")
    
    return {
        "si_snr": avg_si_snr,
        "pesq": avg_pesq,
        "bit_accuracy": avg_bit_acc
    }

print("\nCHUNK #6 completed: Inference and evaluation functions defined.")



CHUNK #6 completed: Inference and evaluation functions defined.


In [14]:
# ------------------- CHUNK 7: Full Training and Evaluation Script ------------------- #

if __name__ == "__main__":
    # Create the full dataset from the specified directory
    data_root = "data/100_all"
    full_dataset = OneSecClipsDataset(root_dir=data_root, sample_rate=SAMPLE_RATE)
    
    # Split the dataset: 80% train, 10% validation, 10% test
    n = len(full_dataset)
    n_train = int(0.8 * n)
    n_val   = int(0.1 * n)
    n_test  = n - n_train - n_val
    train_ds, val_ds, test_ds = random_split(full_dataset, [n_train, n_val, n_test])
    
    # Instantiate the Generator and Detector models and move them to the device
    generator = Generator().to(device)
    detector  = Detector().to(device)
    
    # Train the models on the full training set using the composite loss
    num_epochs = 10  # Adjust this value as needed for full training
    train_model(generator, detector, train_ds, val_ds, num_epochs=num_epochs, lr=LR)
    
    # Save the trained model state dictionaries
    torch.save(generator.state_dict(), "generator.pth")
    torch.save(detector.state_dict(), "detector.pth")
    
    # Evaluate the trained models on the test dataset
    run_evaluation(generator, detector, test_ds, device)



Epoch [1/10]: 100%|██████████| 4508/4508 [30:22<00:00,  2.47it/s, loss=3.0912]
                                                             

Epoch [1/10]  TRAIN Loss: 3.0892  |  VAL Loss: 0.3525


Epoch [2/10]: 100%|██████████| 4508/4508 [29:43<00:00,  2.53it/s, loss=2.8321]
                                                             

Epoch [2/10]  TRAIN Loss: 3.0964  |  VAL Loss: 0.2903


Epoch [3/10]: 100%|██████████| 4508/4508 [30:16<00:00,  2.48it/s, loss=4.0534]
                                                             

Epoch [3/10]  TRAIN Loss: 3.1645  |  VAL Loss: 0.3287


Epoch [4/10]: 100%|██████████| 4508/4508 [30:29<00:00,  2.46it/s, loss=3.1747]
                                                             

Epoch [4/10]  TRAIN Loss: 3.0825  |  VAL Loss: 0.3710


Epoch [5/10]: 100%|██████████| 4508/4508 [29:52<00:00,  2.52it/s, loss=3.0821]
                                                             

Epoch [5/10]  TRAIN Loss: 3.0108  |  VAL Loss: 0.2916


Epoch [6/10]: 100%|██████████| 4508/4508 [29:42<00:00,  2.53it/s, loss=3.9139]
                                                             

Epoch [6/10]  TRAIN Loss: 3.0935  |  VAL Loss: 0.3167


Epoch [7/10]:  12%|█▏        | 529/4508 [8:16:10<62:12:04, 56.28s/it, loss=2.8105]     


KeyboardInterrupt: 

In [None]:
run_evaluation(generator, detector, test_ds, device)