In [1]:
# ------------------- CHUNK 1: Environment Setup ------------------- #

import os
import glob
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import matplotlib.pyplot as plt

# Fix random seeds if desired:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

# ------------------ Hyperparameters from your instructions ------------------ #
SAMPLE_RATE = 16000
AUDIO_LEN   = 16000  # 1 second
BATCH_SIZE  = 64
LR          = 1e-3
HIDDEN_DIM  = 32   # LSTM hidden dimension
NUM_BITS    = 16   # message bits
CHANNELS    = 32   # initial conv channels
OUTPUT_CH   = 128  # final conv channels for the generator
STRIDES     = [2, 4, 5, 8]  # downsampling strides
LSTM_LAYERS = 2
NUM_WORKERS = 16
# Loss Weights
lambda_L1 = 1.0
lambda_msspec = 1.0
lambda_adv = 0.1
lambda_loud = 0.5
lambda_loc = 1.0
lambda_dec = 1.0
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

print("\nCHUNK #1 completed: Environment and hyperparameters set.")


Using device: cuda

CHUNK #1 completed: Environment and hyperparameters set.


In [2]:
# ------------------- CHUNK 2 (UPDATED): Dataset & Augmentations for 1-sec clips ------------------- #
class OneSecClipsDataset(Dataset):
    """
    Assumes each .wav file in root_dir is a ~1-sec clip (16k samples).
    If sample_rate != 16000, we'll resample to 16k.
    """
    def __init__(self, root_dir, sample_rate=16000):
        super().__init__()
        self.filepaths = glob.glob(os.path.join(root_dir, '**', '*.wav'), recursive=True)
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx):
        wav_path = self.filepaths[idx]
        waveform, sr = torchaudio.load(wav_path)

        # Convert to mono if multi-channel
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample if needed
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sample_rate)
            waveform  = resampler(waveform)

        # By assumption, each clip is ~1 second at 16 kHz
        # If there's a minor mismatch, you can pad/crop here, e.g.:
        #   if waveform.shape[1] > 16000:
        #       waveform = waveform[:, :16000]
        #   elif waveform.shape[1] < 16000:
        #       pad_len = 16000 - waveform.shape[1]
        #       waveform = F.pad(waveform, (0, pad_len))
        
        return waveform  # shape: (1, ~16000)

def watermark_masking_augmentation(wav, p_replace_orig=0.4, p_replace_zero=0.2, p_replace_diff=0.2):
    T = wav.shape[1]
    window_len = int(0.1 * 16000)  # 0.1 second if sample_rate=16k
    k = 5
    for _ in range(k):
        start = random.randint(0, T - window_len)
        end   = start + window_len
        choice = random.random()
        if choice < p_replace_orig:
            # Replace with 'original' – in a real pipeline you'd keep track 
            # of the pre-watermarked version. Here it's a placeholder no-op.
            pass
        elif choice < p_replace_orig + p_replace_zero:
            # Replace with zeros
            wav[:, start:end] = 0.0
        elif choice < p_replace_orig + p_replace_zero + p_replace_diff:
            # Replace with random noise as a placeholder
            wav[:, start:end] = 0.1 * torch.randn_like(wav[:, start:end])
        else:
            # Leave unchanged
            pass
    return wav

def robustness_augmentations(wav):
    wav = wav + 0.005 * torch.randn_like(wav)
    return wav

print("\nCHUNK #2 updated for 1-sec clips in 'data/100_all' is ready.")



CHUNK #2 updated for 1-sec clips in 'data/100_all' is ready.


In [3]:
# ------------------- CHUNK 3: Generator with Residual Blocks, LSTM, Message Embedding ------------------- #

def make_conv1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1):
    return nn.Conv1d(in_ch, out_ch, kernel_size, stride=stride, padding=padding)

class ResidualBlock(nn.Module):

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.downsample = (stride != 1 or in_ch != out_ch)
        self.conv1 = make_conv1d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)
        self.conv2 = make_conv1d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.elu   = nn.ELU()
        
        # If channels differ or stride>1, need a 1D conv to match shapes for skip
        if self.downsample:
            self.skip_conv = make_conv1d(in_ch, out_ch, kernel_size=1, stride=stride, padding=0)

    def forward(self, x):
        residual = x
        out = self.elu(self.conv1(x))
        out = self.conv2(out)
        if self.downsample:
            residual = self.skip_conv(residual)
        out = self.elu(out + residual)
        return out

class Generator(nn.Module):
    def __init__(self, 
                 in_channels=1, 
                 base_channels=CHANNELS,
                 hidden_dim=HIDDEN_DIM, 
                 message_bits=NUM_BITS,
                 output_channels=OUTPUT_CH, 
                 strides=STRIDES):
        super().__init__()
        self.message_bits = message_bits
        self.hidden_dim   = hidden_dim
        
        self.E = nn.Embedding(num_embeddings=(2**message_bits), embedding_dim=hidden_dim)

        # ------------------- Encoder ------------------- #
        self.init_conv = nn.Conv1d(in_channels, base_channels, kernel_size=7, stride=1, padding=3)

        enc_blocks = []
        ch = base_channels
        for i, st in enumerate(strides):
            out_ch = ch * 2
            enc_blocks.append(ResidualBlock(ch, out_ch, stride=st))
            ch = out_ch
        self.encoder_blocks = nn.Sequential(*enc_blocks)

        # Project encoder output to hidden_dim
        self.proj = nn.Linear(ch, hidden_dim)  # ch is 512 here

        # LSTM with input_size now matching hidden_dim
        self.lstm = nn.LSTM(input_size=hidden_dim, hidden_size=hidden_dim, num_layers=2, batch_first=True, bidirectional=False)

        self.final_conv_enc = nn.Conv1d(hidden_dim, output_channels, kernel_size=7, stride=1, padding=3)

        # ------------------- Decoder ------------------- #

        dec_blocks = []
        rev_strides = list(reversed(strides))
        in_ch = output_channels  # starting with 128
        for st in rev_strides:
            out_ch = in_ch // 2  # halving channels at each block
            dec_blocks.append(nn.ConvTranspose1d(in_ch, out_ch, kernel_size=2*st, stride=st, padding=(st//2), output_padding=0))
            dec_blocks.append(ResidualBlock(out_ch, out_ch, stride=1))
            in_ch = out_ch  # update in_ch for next block
        self.decoder_blocks = nn.Sequential(*dec_blocks)

        # Change final_conv_dec to accept the correct number of channels from the decoder's output (in_ch)
        self.final_conv_dec = nn.Conv1d(in_ch, 1, kernel_size=7, stride=1, padding=3)



    def forward(self, s, message=None):
        B, _, T = s.shape
        x = self.init_conv(s)
        x = self.encoder_blocks(x)
        x_t = x.transpose(1, 2)  # shape: (B, T_after, ch)

        # Project to hidden_dim so we can add the message embedding
        x_t = self.proj(x_t)  # shape: (B, T_after, hidden_dim)

        if message is not None:
            e = self.E(message)  # shape: (B, hidden_dim)
            T_after = x_t.shape[1]
            e_expanded = e.unsqueeze(1).expand(-1, T_after, -1)  # (B, T_after, hidden_dim)
            x_t = x_t + e_expanded

        lstm_out, _ = self.lstm(x_t)
        lstm_out_t = lstm_out.transpose(1, 2)
        latent = self.final_conv_enc(lstm_out_t)

        x_dec = latent
        x_dec = self.decoder_blocks(x_dec)
        delta = self.final_conv_dec(x_dec)
        if delta.shape[-1] != T:
            min_len = min(delta.shape[-1], T)
            delta = delta[:, :, :min_len]
            if min_len < T:
                pad_amt = T - min_len
                delta = F.pad(delta, (0, pad_amt))
        return delta


print("\nCHUNK #3 completed: Generator model with residual blocks, LSTM, and message embedding defined.")



CHUNK #3 completed: Generator model with residual blocks, LSTM, and message embedding defined.


In [4]:
# ------------------- CHUNK 4: Detector Network ------------------- #
class Detector(nn.Module):
    def __init__(self, 
                 in_channels=1, 
                 base_channels=CHANNELS,
                 hidden_dim=HIDDEN_DIM,
                 message_bits=NUM_BITS,
                 strides=STRIDES):
        super().__init__()
        self.message_bits = message_bits

        self.init_conv = nn.Conv1d(in_channels, base_channels, kernel_size=7, stride=1, padding=3)

        enc_blocks = []
        ch = base_channels
        for st in strides:
            out_ch = ch * 2
            enc_blocks.append(ResidualBlock(ch, out_ch, stride=st))
            ch = out_ch
        self.encoder_blocks = nn.Sequential(*enc_blocks)

        dec_blocks = []
        rev_strides = list(reversed(strides))
        in_ch = ch
        for st in rev_strides:
            out_ch = in_ch // 2
            dec_blocks.append(nn.ConvTranspose1d(in_ch, out_ch, kernel_size=2*st, stride=st, padding=(st//2), output_padding=0))
            dec_blocks.append(ResidualBlock(out_ch, out_ch, stride=1))
            in_ch = out_ch
        self.upsample_blocks = nn.Sequential(*dec_blocks)

        self.final_conv = nn.Conv1d(base_channels, 1 + message_bits, kernel_size=7, stride=1, padding=3)

    def forward(self, x):
        # Save original time dimension
        original_length = x.shape[-1]

        # Encoder
        x = self.init_conv(x)
        x = self.encoder_blocks(x)

        # Upsample
        x = self.upsample_blocks(x)
        out = self.final_conv(x)

        # Adjust output to match original time dimension
        if out.shape[-1] > original_length:
            out = out[:, :, :original_length]
        elif out.shape[-1] < original_length:
            pad_amt = original_length - out.shape[-1]
            out = F.pad(out, (0, pad_amt))
            
        return torch.sigmoid(out)


print("\nCHUNK #4 completed: Detector network (encoder + upsampling + classification) defined.")



CHUNK #4 completed: Detector network (encoder + upsampling + classification) defined.


In [5]:
# ------------------- CHUNK 5: Loss Functions ------------------- #
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as T

# Define sample rate as a constant if not already defined
SAMPLE_RATE = 16000  # Common value, adjust to match your actual sample rate

class MultiScaleMelLoss(nn.Module):
    """
    Multi-scale mel spectrogram loss that compares spectrograms at different resolutions.
    This helps capture both fine and coarse-grained audio details.
    """
    def __init__(self, sample_rate=SAMPLE_RATE, n_ffts=[1024, 2048, 512], n_mels=80, alpha=1.0):
        super().__init__()
        self.mel_specs = nn.ModuleList([
            T.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=n_fft,
                hop_length=n_fft // 4,
                n_mels=n_mels,
                normalized=True
            ) for n_fft in n_ffts
        ])
        self.alpha = alpha
        
    def forward(self, original, watermarked):
        # original & watermarked: shape (B,1,T)
        total_loss = 0.0
        
        # Compute loss at each scale
        for mel_spec in self.mel_specs:
            # Log mel spectrograms
            mel_orig = torch.log(mel_spec(original) + 1e-5)
            mel_wm = torch.log(mel_spec(watermarked) + 1e-5)
            
            # Compute L1 and L2 losses
            l1_loss = F.l1_loss(mel_wm, mel_orig)
            l2_loss = F.mse_loss(mel_wm, mel_orig)
            
            # Combine losses
            total_loss += l1_loss + self.alpha * l2_loss
            
        return total_loss / len(self.mel_specs)


class AdversarialLoss(nn.Module):
    """
    Adversarial loss using a discriminator network that tries to distinguish between
    original and watermarked audio.
    """
    def __init__(self):
        super().__init__()
        # Define a simple discriminator network
        self.discriminator = nn.Sequential(
            # Input: (B, 1, T)
            nn.Conv1d(1, 16, kernel_size=15, stride=1, padding=7),
            nn.LeakyReLU(0.2),
            nn.Conv1d(16, 32, kernel_size=41, stride=4, padding=20),
            nn.LeakyReLU(0.2),
            nn.Conv1d(32, 64, kernel_size=41, stride=4, padding=20),
            nn.LeakyReLU(0.2),
            nn.Conv1d(64, 128, kernel_size=41, stride=4, padding=20),
            nn.LeakyReLU(0.2),
            nn.Conv1d(128, 1, kernel_size=41, stride=4, padding=20),
        )
        self.optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001)
        
    def forward(self, original, watermarked, train_disc=True):
        # Train discriminator
        if train_disc:
            self.optimizer.zero_grad()
            
            # Real samples should be classified as 1
            real_output = self.discriminator(original)
            real_loss = F.binary_cross_entropy_with_logits(
                real_output, 
                torch.ones_like(real_output)
            )
            
            # Watermarked samples should be classified as 0
            fake_output = self.discriminator(watermarked.detach())
            fake_loss = F.binary_cross_entropy_with_logits(
                fake_output, 
                torch.zeros_like(fake_output)
            )
            
            # Total discriminator loss
            disc_loss = real_loss + fake_loss
            disc_loss.backward()
            self.optimizer.step()
        
        # Generator loss (trick discriminator)
        fake_output = self.discriminator(watermarked)
        gen_loss = F.binary_cross_entropy_with_logits(
            fake_output, 
            torch.ones_like(fake_output)
        )
        
        return gen_loss


class TFLoudnessLoss(nn.Module):
    """
    Time-Frequency loudness loss that ensures perceptual similarity across
    different frequency bands and time windows.
    """
    def __init__(self, n_bands=8, window_size=2048, hop_size=512):
        super().__init__()
        self.n_bands = n_bands
        self.win_size = window_size
        self.hop_size = hop_size
        
        # Perceptual weighting (approximating equal-loudness contours)
        # Higher weights for mid-frequency bands where human hearing is most sensitive
        weights = torch.ones(n_bands)
        mid_band_idx = n_bands // 3
        weights[mid_band_idx:2*mid_band_idx] = 1.5
        self.register_buffer('band_weights', weights)

    def forward(self, original, watermarked):
        # Create Hanning window
        window = torch.hann_window(self.win_size, device=original.device)
        
        # Compute STFTs
        stft_orig = torch.stft(
            original.squeeze(1), 
            n_fft=self.win_size, 
            hop_length=self.hop_size, 
            window=window, 
            return_complex=True,
            normalized=True
        )
        stft_wm = torch.stft(
            watermarked.squeeze(1), 
            n_fft=self.win_size, 
            hop_length=self.hop_size, 
            window=window, 
            return_complex=True,
            normalized=True
        )

        # Compute magnitude spectrograms
        mag_orig = stft_orig.abs()  # (B, freq, frames)
        mag_wm = stft_wm.abs()
        
        # Compute phase spectrograms
        phase_orig = stft_orig.angle()
        phase_wm = stft_wm.angle()

        # Divide frequency dimension into bands
        freq_bins = mag_orig.shape[1]
        band_size = freq_bins // self.n_bands

        # Initialize losses
        loudness_loss = 0.0
        spectral_loss = 0.0
        phase_loss = 0.0
        
        # Compute loss per band
        for b in range(self.n_bands):
            start = b * band_size
            end = freq_bins if (b == self.n_bands-1) else (start + band_size)

            # Extract band
            band_orig = mag_orig[:, start:end, :]
            band_wm = mag_wm[:, start:end, :]
            
            # Phase for this band
            phase_band_orig = phase_orig[:, start:end, :]
            phase_band_wm = phase_wm[:, start:end, :]
            
            # Compute band energy (loudness)
            energy_orig = torch.sum(band_orig**2, dim=1)  # Sum over freq, shape (B, frames)
            energy_wm = torch.sum(band_wm**2, dim=1)
            
            # Log-scale energy with stabilizing epsilon
            loud_orig = torch.log10(energy_orig + 1e-8)
            loud_wm = torch.log10(energy_wm + 1e-8)
            
            # Loudness difference (L1)
            band_loudness_diff = F.l1_loss(loud_wm, loud_orig)
            loudness_loss += self.band_weights[b] * band_loudness_diff
            
            # Spectral shape difference (L2)
            band_spectral_diff = F.mse_loss(band_wm, band_orig)
            spectral_loss += self.band_weights[b] * band_spectral_diff
            
            # Phase difference (important for transients)
            # We use circular distance for phase
            phase_diff = 1.0 - torch.cos(phase_band_wm - phase_band_orig)
            phase_loss += self.band_weights[b] * phase_diff.mean()

        # Normalize by number of bands
        loudness_loss /= self.n_bands
        spectral_loss /= self.n_bands
        phase_loss /= self.n_bands
        
        # Weight the different components
        # Phase is less important for many watermarking applications
        return loudness_loss + spectral_loss + 0.2 * phase_loss


def masked_localization_loss(detector_out, mask, smooth_eps=0.1):
    """
    Localization loss with label smoothing and focal weighting to improve 
    detector performance.
    
    Args:
        detector_out: shape (B, 1+b, T). The first channel is detection => (B,1,T).
        mask: shape (B,1,T), 1=watermarked region, 0=original
        smooth_eps: Label smoothing epsilon
    """
    det_prob = detector_out[:, 0:1, :]  # shape (B,1,T)
    
    # Apply label smoothing
    smoothed_mask = mask * (1.0 - smooth_eps) + (1.0 - mask) * smooth_eps
    
    # Focal loss weighting
    # Reduce weight for easy examples, focus on hard examples
    pt = torch.where(mask > 0.5, det_prob, 1 - det_prob)
    focal_weight = (1 - pt) ** 2
    
    # BCE loss
    bce_loss = F.binary_cross_entropy(det_prob, smoothed_mask, reduction='none')
    focal_loss = focal_weight * bce_loss
    
    # Average loss
    avg_loss = focal_loss.mean()
    
    return avg_loss


def decoding_loss(detector_out, message, mask=None, gamma=2.0):
    """
    Improved decoding loss that focuses on watermarked regions and uses
    a more sophisticated bit extraction method.
    
    Args:
        detector_out: shape (B, 1+b, T)
        message: shape (B,) or (B,b) containing integers in [0..2^b-1]
        mask: shape (B,1,T), 1=watermarked region, 0=original
        gamma: Focal loss parameter
    """
    B, channels, T = detector_out.shape
    b = channels - 1
    if b <= 0:
        return torch.tensor(0.0, device=detector_out.device)

    # Extract bit probability maps
    bit_prob_map = detector_out[:, 1:, :]  # shape (B,b,T)
    
    # If mask is provided, use it to focus on watermarked regions
    if mask is not None:
        # Expand mask to match bit channels
        expanded_mask = mask.expand(-1, b, -1)  # (B,b,T)
        
        # Apply mask: zero out non-watermarked regions
        masked_bit_map = bit_prob_map * expanded_mask
        
        # Calculate weighted average over time dimension
        # This gives more weight to regions that are strongly watermarked
        weights = expanded_mask.sum(dim=2, keepdim=True) + 1e-8
        bit_prob = (masked_bit_map.sum(dim=2) / weights.squeeze(2))
    else:
        # If no mask, use attention-like mechanism to focus on relevant time steps
        # Look for time steps with highest confidence
        confidence = torch.abs(bit_prob_map - 0.5) * 2  # Scale to [0,1]
        attention = F.softmax(confidence * 5.0, dim=2)  # Sharpen and normalize
        bit_prob = (bit_prob_map * attention).sum(dim=2)  # Weighted average

    # Convert message integers to binary bits
    msg_bits = []
    for i in range(b):
        bit_i = ((message >> i) & 1).float()
        msg_bits.append(bit_i)
    msg_bits = torch.stack(msg_bits, dim=1)  # shape (B,b)
    
    # Binary cross entropy with focal loss weighting
    pt = torch.where(msg_bits > 0.5, bit_prob, 1 - bit_prob)
    focal_weight = (1 - pt) ** gamma
    
    bce = F.binary_cross_entropy(bit_prob, msg_bits, reduction='none')
    focal_bce = focal_weight * bce
    
    return focal_bce.mean()

print("\nCHUNK #5 completed: stubs for multi-scale mel, adversarial, TF-loudness, localization, and decoding losses.")



CHUNK #5 completed: stubs for multi-scale mel, adversarial, TF-loudness, localization, and decoding losses.


In [6]:
# ------------------- CHUNK 6: Training Loop ------------------- #

def generate_random_messages(batch_size, max_val=(2**NUM_BITS)):
    # random integer in [0, 2^b -1]
    return torch.randint(0, max_val, (batch_size,))

# We'll define a combined function for one training step
def train_one_epoch(
    generator,
    detector,
    train_loader,
    optimizer,
    epoch,
    total_epochs,
    device
):
    generator.train()
    detector.train()
    
    # Instantiate the various loss modules
    ms_mel_loss   = MultiScaleMelLoss().to(device)
    adv_loss_stub = AdversarialLoss().to(device)
    tf_loud_loss  = TFLoudnessLoss().to(device)

    total_steps = len(train_loader)
    pbar = tqdm(enumerate(train_loader), total=total_steps, desc=f"Epoch [{epoch}/{total_epochs}]")
    
    sum_loss = 0.0
    for i, s in pbar:
        s = s.to(device)  # shape (B,1,16000)

        B = s.shape[0]
        # generate random messages
        msgs = generate_random_messages(B).to(device)  # shape (B,)

        # forward pass
        delta = generator(s, msgs)
        s_w   = s + delta  # watermarked audio

        # optional watermark masking augmentation
        # for demonstration, do it in a loop
        for b_idx in range(B):
            s_w[b_idx] = watermark_masking_augmentation(s_w[b_idx])
            s_w[b_idx] = robustness_augmentations(s_w[b_idx])

        # detection
        det_out = detector(s_w)

        # 1) L1 on watermark
        loss_l1 = F.l1_loss(delta, torch.zeros_like(delta))  # or you can do F.l1_loss(s_w, s)

        # 2) multi-scale mel
        loss_msspec = ms_mel_loss(s, s_w)

        # 3) adversarial stub
        loss_adv = adv_loss_stub(s, s_w)

        # 4) tf-loudness
        loss_loud = tf_loud_loss(s, s_w)

        # 5) localization => we build a mask=1 for the entire clip if it's watermarked
        # In real partial scenario, you'd know which samples are watermarked
        mask = torch.ones((B,1,s.shape[-1]), device=device)
        loss_loc = masked_localization_loss(det_out, mask)

        # 6) decoding => ensure the b channels decode the same bits we embedded
        loss_dec = decoding_loss(det_out, msgs)

        # Weighted sum
        loss = (lambda_L1     * loss_l1
              + lambda_msspec * loss_msspec
              + lambda_adv    * loss_adv
              + lambda_loud   * loss_loud
              + lambda_loc    * loss_loc
              + lambda_dec    * loss_dec)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        sum_loss += loss.item()
        pbar.set_postfix({
            "loss_l1": f"{loss_l1.item():.3f}",
            "loss_loc": f"{loss_loc.item():.3f}",
            "loss_dec": f"{loss_dec.item():.3f}",
            "total": f"{loss.item():.3f}"
        })
    avg_loss = sum_loss / (total_steps if total_steps>0 else 1)
    return avg_loss

def validate_one_epoch(
    generator,
    detector,
    val_loader,
    device
):
    generator.eval()
    detector.eval()

    ms_mel_loss   = MultiScaleMelLoss().to(device)
    adv_loss_stub = AdversarialLoss().to(device)
    tf_loud_loss  = TFLoudnessLoss().to(device)

    sum_loss = 0.0
    steps = 0
    with torch.no_grad():
        for s in tqdm(val_loader, desc="Validation", leave=False):
            s = s.to(device)
            B = s.shape[0]
            msgs = generate_random_messages(B).to(device)

            delta = generator(s, msgs)
            s_w   = s + delta
            det_out = detector(s_w)

            # same losses
            loss_l1   = F.l1_loss(delta, torch.zeros_like(delta))
            loss_msspec = ms_mel_loss(s, s_w)
            loss_adv = adv_loss_stub(s, s_w)
            loss_loud = tf_loud_loss(s, s_w)
            mask      = torch.ones((B,1,s.shape[-1]), device=device)
            loss_loc  = masked_localization_loss(det_out, mask)
            loss_dec  = decoding_loss(det_out, msgs)

            loss = (lambda_L1     * loss_l1
                  + lambda_msspec * loss_msspec
                  + lambda_adv    * loss_adv
                  + lambda_loud   * loss_loud
                  + lambda_loc    * loss_loc
                  + lambda_dec    * loss_dec)

            sum_loss += loss.item()
            steps += 1
    return sum_loss / (steps if steps>0 else 1)

def train_model(
    generator,
    detector,
    train_dataset,
    val_dataset,
    num_epochs=10,
    lr=LR
):
    # Dataloaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    optimizer = optim.Adam(
        list(generator.parameters()) + list(detector.parameters()),
        lr=lr
    )

    for epoch in range(1, num_epochs+1):
        train_loss = train_one_epoch(generator, detector, train_loader, optimizer, epoch, num_epochs, device)
        val_loss   = validate_one_epoch(generator, detector, val_loader, device)
        print(f"Epoch [{epoch}/{num_epochs}]  TRAIN Loss: {train_loss:.4f}  |  VAL Loss: {val_loss:.4f}")

print("\nCHUNK #6 completed: training functions defined (multi-objective).")



CHUNK #6 completed: training functions defined (multi-objective).


In [7]:
# ------------------- CHUNK 7: Inference & Evaluation ------------------- #

def detect_watermark(detector, audio, threshold=0.5):
    """
    audio: shape (B,1,T)
    Returns boolean indicating if watermark is present (avg>threshold)
    """
    detector.eval()
    with torch.no_grad():
        out = detector(audio)  # shape (B, 1+b, T)
        det_prob = out[:, 0, :]  # (B, T)
        avg_prob = det_prob.mean(dim=1)  # (B,)
        return (avg_prob > threshold).float()

def localize_watermark(detector, audio, threshold=0.5):
    """
    Return a mask of shape (B,T) indicating where watermark is detected
    """
    with torch.no_grad():
        out = detector(audio)  # shape (B, 1+b, T)
        det_prob = out[:, 0, :]  # (B, T)
        return (det_prob > threshold).float()

def decode_message(detector, audio):
    """
    audio: shape (B,1,T)
    We average bit predictions across time -> get (B,b).
    Convert to 0/1 by threshold=0.5
    Return integer or bit array
    """
    with torch.no_grad():
        out = detector(audio)
        bit_prob_map = out[:, 1:, :]  # shape (B,b,T)
        bit_prob = bit_prob_map.mean(dim=2)  # (B,b)
        bits_decoded = (bit_prob>0.5).int()  # shape (B,b)
        # convert to integer
        # bits_decoded[:,0] => least significant or whichever order you prefer
        B, b = bits_decoded.shape
        msg_int = torch.zeros(B, dtype=torch.long, device=bits_decoded.device)
        for i in range(b):
            msg_int |= (bits_decoded[:, i] << i)
        return msg_int

# Placeholders for objective metrics
def evaluate_si_snr(original, reconstructed):
    # Real SI-SNR calculation is more involved. We'll do a placeholder
    return random.random()

def evaluate_pesq(original, reconstructed, sr=SAMPLE_RATE):
    # Typically use an external library (pesq package). We'll do a placeholder
    return 4.5 - random.random()

def run_evaluation(generator, detector, test_dataset):
    print("Running final evaluation on test dataset (placeholder).")
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
    si_snr_list  = []
    pesq_list    = []
    
    generator.eval()
    detector.eval()
    with torch.no_grad():
        for s in test_loader:
            s = s.to(device)
            B = s.shape[0]
            msgs = generate_random_messages(B).to(device)
            delta = generator(s, msgs)
            s_w   = s + delta

            # Evaluate SI-SNR, PESQ, etc.
            for i in range(B):
                si_snr_list.append(evaluate_si_snr(s[i], s_w[i]))
                pesq_list.append(evaluate_pesq(s[i], s_w[i]))

    avg_si_snr = sum(si_snr_list)/len(si_snr_list) if si_snr_list else 0.0
    avg_pesq   = sum(pesq_list)/len(pesq_list)     if pesq_list else 0.0
    print(f"Average SI-SNR: {avg_si_snr:.3f}")
    print(f"Average PESQ:   {avg_pesq:.3f}")

print("\nCHUNK #7 completed: Inference (detection, localization, decode) & placeholder evaluation done.")



CHUNK #7 completed: Inference (detection, localization, decode) & placeholder evaluation done.


In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
from torch.utils.data import DataLoader, Subset

# Assumed to be defined elsewhere:
# - set_seed(seed)
# - OneSecClipsDataset(root_dir, sample_rate)
# - Generator(), Detector()
# - generate_random_messages(batch_size)
# - device, LR

# For reproducibility
set_seed(42)

# --- Simple Single-Scale Mel Loss ---
class SimpleMelLoss(nn.Module):
    """
    A simple mel spectrogram loss that uses a single scale.
    """
    def __init__(self, sample_rate=16000, n_fft=1024, n_mels=80):
        super(SimpleMelLoss, self).__init__()
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=n_fft // 4,
            n_mels=n_mels,
            normalized=True
        )
        
    def forward(self, original, watermarked):
        # original and watermarked: shape (B, 1, T)
        mel_orig = torch.log(self.mel_spec(original) + 1e-5)
        mel_wm   = torch.log(self.mel_spec(watermarked) + 1e-5)
        loss = F.l1_loss(mel_orig, mel_wm)
        return loss

# --- TF-Loudness Loss ---
class TFLoudnessLoss(nn.Module):
    """
    Time-Frequency loudness loss that ensures perceptual similarity across
    different frequency bands and time windows.
    
    It computes the STFT of the original and watermarked audio and compares their
    loudness, spectral shape, and phase differences.
    """
    def __init__(self, n_bands=8, window_size=2048, hop_size=512):
        super(TFLoudnessLoss, self).__init__()
        self.n_bands = n_bands
        self.win_size = window_size
        self.hop_size = hop_size
        
        # Perceptual weighting: higher weights for mid-frequency bands
        weights = torch.ones(n_bands)
        mid_band_idx = n_bands // 3
        weights[mid_band_idx:2 * mid_band_idx] = 1.5
        self.register_buffer('band_weights', weights)

    def forward(self, original, watermarked):
        window = torch.hann_window(self.win_size, device=original.device)
        stft_orig = torch.stft(
            original.squeeze(1), n_fft=self.win_size, hop_length=self.hop_size,
            window=window, return_complex=True, normalized=True
        )
        stft_wm = torch.stft(
            watermarked.squeeze(1), n_fft=self.win_size, hop_length=self.hop_size,
            window=window, return_complex=True, normalized=True
        )
        mag_orig = stft_orig.abs()
        mag_wm   = stft_wm.abs()
        phase_orig = stft_orig.angle()
        phase_wm   = stft_wm.angle()
        
        freq_bins = mag_orig.shape[1]
        band_size = freq_bins // self.n_bands
        
        loudness_loss = 0.0
        spectral_loss = 0.0
        phase_loss    = 0.0
        
        for b in range(self.n_bands):
            start = b * band_size
            end = freq_bins if (b == self.n_bands - 1) else (start + band_size)
            band_orig = mag_orig[:, start:end, :]
            band_wm   = mag_wm[:, start:end, :]
            
            energy_orig = torch.sum(band_orig ** 2, dim=1)
            energy_wm   = torch.sum(band_wm ** 2, dim=1)
            loud_orig   = torch.log10(energy_orig + 1e-8)
            loud_wm     = torch.log10(energy_wm + 1e-8)
            loudness_loss += self.band_weights[b] * F.l1_loss(loud_wm, loud_orig)
            
            spectral_loss += self.band_weights[b] * F.mse_loss(band_wm, band_orig)
            
            phase_diff = 1.0 - torch.cos(phase_wm[:, start:end, :] - phase_orig[:, start:end, :])
            phase_loss += self.band_weights[b] * phase_diff.mean()
        
        loudness_loss /= self.n_bands
        spectral_loss /= self.n_bands
        phase_loss    /= self.n_bands
        
        return loudness_loss + spectral_loss + 0.2 * phase_loss

# --- Adversarial Loss ---
class AdversarialLoss(nn.Module):
    """
    Adversarial loss using a discriminator network that distinguishes between
    original and watermarked audio.
    
    The generator is penalized if the discriminator can tell the difference,
    forcing it to produce watermarked audio that is as realistic as the original.
    """
    def __init__(self):
        super(AdversarialLoss, self).__init__()
        self.discriminator = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=15, stride=1, padding=7),
            nn.LeakyReLU(0.2),
            nn.Conv1d(16, 32, kernel_size=41, stride=4, padding=20),
            nn.LeakyReLU(0.2),
            nn.Conv1d(32, 64, kernel_size=41, stride=4, padding=20),
            nn.LeakyReLU(0.2),
            nn.Conv1d(64, 128, kernel_size=41, stride=4, padding=20),
            nn.LeakyReLU(0.2),
            nn.Conv1d(128, 1, kernel_size=41, stride=4, padding=20),
        )
        self.disc_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)

    def forward(self, original, watermarked, train_disc=True):
        if train_disc:
            self.disc_optimizer.zero_grad()
            real_output = self.discriminator(original)
            real_loss = F.binary_cross_entropy_with_logits(real_output, torch.ones_like(real_output))
            fake_output = self.discriminator(watermarked.detach())
            fake_loss = F.binary_cross_entropy_with_logits(fake_output, torch.zeros_like(fake_output))
            disc_loss = real_loss + fake_loss
            disc_loss.backward()
            self.disc_optimizer.step()
        fake_output = self.discriminator(watermarked)
        gen_loss = F.binary_cross_entropy_with_logits(fake_output, torch.ones_like(fake_output))
        return gen_loss

# --- Masked Localization Loss ---
def masked_localization_loss(detector_out, mask, smooth_eps=0.1):
    """
    Localization loss with label smoothing and focal loss weighting.
    
    It uses the detector's first output channel (detection probability per sample)
    and compares it to a mask indicating where the watermark is present.
    """
    det_prob = detector_out[:, 0:1, :]  # detection channel
    smoothed_mask = mask * (1.0 - smooth_eps) + (1.0 - mask) * smooth_eps
    pt = torch.where(mask > 0.5, det_prob, 1 - det_prob)
    focal_weight = (1 - pt) ** 2
    bce_loss = F.binary_cross_entropy(det_prob, smoothed_mask, reduction='none')
    focal_loss = focal_weight * bce_loss
    return focal_loss.mean()

# --- Decoding Loss ---
def decoding_loss(detector_out, message, mask=None, gamma=2.0):
    """
    Decoding loss that ensures the watermark message embedded by the generator
    can be accurately recovered from the detector's output.
    
    Args:
        detector_out: Tensor of shape (B, 1+b, T) where the first channel is the detection
                      probability and the remaining b channels contain bit probabilities.
        message: Tensor of shape (B,) or (B, b) containing the true watermark message (as integers or bits).
        mask: Optional tensor of shape (B, 1, T) that indicates watermarked regions.
        gamma: Focal loss parameter for weighting hard examples.
    """
    B, channels, T = detector_out.shape
    b = channels - 1  # number of bit channels
    if b <= 0:
        return torch.tensor(0.0, device=detector_out.device)
    
    bit_prob_map = detector_out[:, 1:, :]  # shape (B, b, T)
    
    # If a mask is provided, focus on watermarked regions.
    if mask is not None:
        expanded_mask = mask.expand(-1, b, -1)  # shape (B, b, T)
        masked_bit_map = bit_prob_map * expanded_mask
        weights = expanded_mask.sum(dim=2, keepdim=True) + 1e-8
        bit_prob = (masked_bit_map.sum(dim=2) / weights.squeeze(2))
    else:
        # Otherwise, use an attention-like mechanism to focus on time steps with high confidence.
        confidence = torch.abs(bit_prob_map - 0.5) * 2.0
        attention = F.softmax(confidence * 5.0, dim=2)
        bit_prob = (bit_prob_map * attention).sum(dim=2)
    
    # Convert message integers to binary bit vectors.
    msg_bits = []
    for i in range(b):
        bit_i = ((message >> i) & 1).float()
        msg_bits.append(bit_i)
    msg_bits = torch.stack(msg_bits, dim=1)  # shape (B, b)
    
    # Compute focal weighted binary cross-entropy loss.
    pt = torch.where(msg_bits > 0.5, bit_prob, 1 - bit_prob)
    focal_weight = (1 - pt) ** gamma
    bce = F.binary_cross_entropy(bit_prob, msg_bits, reduction='none')
    focal_bce = focal_weight * bce
    return focal_bce.mean()

# --- Training Loop with All Losses and Learning Rate Scheduler ---
def train_one_epoch_full(generator, detector, loader, optimizer, device,
                         lambda_tf_loud=0.5, lambda_adv=0.1, lambda_loc=1.0, lambda_dec=1.0):
    generator.train()
    detector.train()
    total_loss = 0.0
    
    # Instantiate loss modules.
    simple_mel_loss = SimpleMelLoss().to(device)
    tf_loud_loss = TFLoudnessLoss().to(device)
    adv_loss_module = AdversarialLoss().to(device)
    
    for s in loader:
        s = s.to(device)
        B = s.shape[0]
        msgs = generate_random_messages(B).to(device)
        
        # Generate watermarked audio.
        delta = generator(s, msgs)
        s_w = s + delta
        
        # Compute primary losses.
        loss_l1  = F.l1_loss(delta, torch.zeros_like(delta))
        loss_mel = simple_mel_loss(s, s_w)
        loss_tf  = tf_loud_loss(s, s_w)
        loss_adv = adv_loss_module(s, s_w, train_disc=True)
        
        # Compute localization loss using the detector network.
        det_out = detector(s_w)
        mask = torch.ones((B, 1, s.shape[-1]), device=device)  # assume full watermark coverage
        loss_loc = masked_localization_loss(det_out, mask)
        
        # Compute decoding loss to ensure watermark message recoverability.
        loss_dec = decoding_loss(det_out, msgs, mask)
        
        # Total loss is the weighted sum of all components.
        loss = loss_l1 + loss_mel + lambda_tf_loud * loss_tf + lambda_adv * loss_adv + lambda_loc * loss_loc + lambda_dec * loss_dec
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    avg_loss = total_loss / len(loader)
    return avg_loss

def overfit_on_small_subset_full_scheduler():
    data_root = "data/100_all"
    full_dataset = OneSecClipsDataset(root_dir=data_root, sample_rate=16000)
    small_dataset = Subset(full_dataset, list(range(64)))
    loader = DataLoader(small_dataset, batch_size=8, shuffle=True, num_workers=0)
    
    generator = Generator().to(device)
    detector  = Detector().to(device)
    
    optimizer = optim.Adam(list(generator.parameters()) + list(detector.parameters()), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    
    num_epochs = 100
    for epoch in range(1, num_epochs + 1):
        loss = train_one_epoch_full(generator, detector, loader, optimizer, device,
                                    lambda_tf_loud=0.5, lambda_adv=0.1, lambda_loc=1.0, lambda_dec=1.0)
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch}/{num_epochs} - Loss: {loss:.4f} - LR: {current_lr:.6f}")

if __name__ == "__main__":
    overfit_on_small_subset_full_scheduler()


Epoch 1/100 - Loss: 2.2188 - LR: 0.000100
Epoch 2/100 - Loss: 1.6791 - LR: 0.000100
Epoch 3/100 - Loss: 1.4266 - LR: 0.000100
Epoch 4/100 - Loss: 1.2096 - LR: 0.000100
Epoch 5/100 - Loss: 1.0645 - LR: 0.000100
Epoch 6/100 - Loss: 0.9874 - LR: 0.000100
Epoch 7/100 - Loss: 0.9197 - LR: 0.000100
Epoch 8/100 - Loss: 0.8505 - LR: 0.000100
Epoch 9/100 - Loss: 0.7706 - LR: 0.000100
Epoch 10/100 - Loss: 0.7296 - LR: 0.000050
Epoch 11/100 - Loss: 0.6766 - LR: 0.000050
Epoch 12/100 - Loss: 0.6244 - LR: 0.000050
Epoch 13/100 - Loss: 0.5213 - LR: 0.000050
Epoch 14/100 - Loss: 0.5402 - LR: 0.000050
Epoch 15/100 - Loss: 0.5312 - LR: 0.000050
Epoch 16/100 - Loss: 0.6248 - LR: 0.000050
Epoch 17/100 - Loss: 0.6335 - LR: 0.000050
Epoch 18/100 - Loss: 0.6383 - LR: 0.000050
Epoch 19/100 - Loss: 0.6269 - LR: 0.000050
Epoch 20/100 - Loss: 0.6169 - LR: 0.000025
Epoch 21/100 - Loss: 0.6049 - LR: 0.000025
Epoch 22/100 - Loss: 0.5926 - LR: 0.000025
Epoch 23/100 - Loss: 0.5817 - LR: 0.000025
Epoch 24/100 - Loss:

In [None]:
if __name__ == "__main__":
    # Create dataset
    data_root = "data/100_all"
    full_dataset = OneSecClipsDataset(root_dir=data_root, sample_rate=16000)


    # Train/Val split
    n = len(full_dataset)
    n_train = int(0.8*n)
    n_val   = int(0.1*n)
    n_test  = n - n_train - n_val
    train_ds, val_ds, test_ds = random_split(full_dataset, [n_train, n_val, n_test])

    # Instantiate models
    generator = Generator().to(device)
    detector  = Detector().to(device)

    # Train
    train_model(generator, detector, train_ds, val_ds, num_epochs=10, lr=LR)

    torch.save(generator.state_dict(), "generator.pth")
    torch.save(detector.state_dict(), "detector.pth")  
    # Evaluate
    run_evaluation(generator, detector, test_ds)


Epoch [1/10]:   6%|▌         | 280/4508 [02:11<33:06,  2.13it/s, loss_l1=0.001, loss_loc=0.000, loss_dec=0.175, total=10.007]


KeyboardInterrupt: 

In [None]:
# # Save the state dictionary of the generator and detector
# torch.save(generator.state_dict(), "generator.pth")
# torch.save(detector.state_dict(), "detector.pth")


In [None]:
# from torch.utils.data import Subset

                                                            
# if __name__ == "__main__":
#     # Create dataset
#     data_root = "data/100_all"
#     full_dataset = OneSecClipsDataset(root_dir=data_root, sample_rate=16000)


#     # Train/Val split
#     n = len(full_dataset)
#     n_train = int(0.8*n)
#     n_val   = int(0.1*n)
#     n_test  = n - n_train - n_val
#     train_ds, val_ds, test_ds = random_split(full_dataset, [n_train, n_val, n_test])

#     # Instantiate models
#     generator = Generator().to(device)
#     detector  = Detector().to(device)

# # Create a subset of 100 samples (adjust as needed)
#     small_subset_indices = list(range(1000))
#     small_train_ds = Subset(full_dataset, small_subset_indices)
#     small_train_loader = DataLoader(small_train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

#     # Now train using the small subset
#     train_model(generator, detector, small_train_ds, val_ds, num_epochs=10, lr=LR)


In [None]:
import torch
import torch.nn.functional as F
import numpy as np

try:
    from pesq import pesq
except ImportError:
    pesq = None
    print("WARNING: pesq library not found. PESQ will be unavailable.")

######################################################################
# 1) SI-SNR
######################################################################
def evaluate_si_snr_torch(original, reconstructed, eps=1e-8):
    """
    Torch-based SI-SNR for a batch:
      original, reconstructed: shape (B, 1, T) or (B, T).
    Returns: Tensor of shape (B,) with SI-SNR(dB) per sample in the batch.
    """
    if original.dim() == 3:
        # shape (B,1,T) => (B,T)
        original = original.squeeze(1)
    if reconstructed.dim() == 3:
        reconstructed = reconstructed.squeeze(1)

    # Zero-mean along time
    original_zm = original - torch.mean(original, dim=1, keepdim=True)
    recon_zm    = reconstructed - torch.mean(reconstructed, dim=1, keepdim=True)

    # Project recon onto original
    dot = torch.sum(original_zm * recon_zm, dim=1, keepdim=True)
    norm_sq = torch.sum(original_zm**2, dim=1, keepdim=True) + eps
    alpha = dot / norm_sq

    s_target = alpha * original_zm           # shape (B,T)
    e_noise  = recon_zm - s_target
    num  = torch.sum(s_target**2, dim=1) + eps
    den  = torch.sum(e_noise**2,  dim=1) + eps
    si_snr_val = 10 * torch.log10(num / den)
    return si_snr_val

######################################################################
# 2) PESQ
######################################################################
def evaluate_pesq(original, reconstructed, sr=16000):
    """
    Single-sample PESQ in wide-band mode (for sr=16000).
    - original, reconstructed: shape (T,) or (1,T) as torch Tensors or numpy arrays
    - returns float PESQ or 0.0 if error/no library
    """
    if pesq is None:
        print("PESQ library not available.")
        return 0.0

    # Convert to CPU numpy
    if isinstance(original, torch.Tensor):
        original = original.squeeze().detach().cpu().numpy()
    if isinstance(reconstructed, torch.Tensor):
        reconstructed = reconstructed.squeeze().detach().cpu().numpy()

    # Ensure valid sample rate for PESQ
    assert sr in [8000, 16000], "PESQ only supports 8k or 16k sample rates."

    # Normalize if needed
    if np.max(np.abs(original)) > 1.0:
        original = original / np.max(np.abs(original))
    if np.max(np.abs(reconstructed)) > 1.0:
        reconstructed = reconstructed / np.max(np.abs(reconstructed))

    mode = 'wb' if sr == 16000 else 'nb'
    try:
        score = pesq(sr, original, reconstructed, mode)
    except Exception as e:
        print(f"PESQ error: {e}")
        score = 0.0
    return score

######################################################################
# 3) Multi-bit message generation
######################################################################
def generate_random_messages(batch_size, num_bits=16, as_bits=True):
    """
    If as_bits=True, return shape (B, num_bits) of {0,1}.
    If as_bits=False, return shape (B,) of integers in [0..2^num_bits-1].
    """
    if as_bits:
        return torch.randint(0, 2, (batch_size, num_bits), dtype=torch.float)
    else:
        max_val = 2**num_bits
        return torch.randint(0, max_val, (batch_size,), dtype=torch.long)

######################################################################
# 4) Decoding from Detector
######################################################################
def decode_bits_from_detector(detector_out):
    """
    If the detector outputs shape (B, 1+b, T):
      - channel 0 = detection probability per sample
      - channels [1..b] = bit probability per sample
    We'll average across time => shape (B,b)
    Then threshold at 0.5 => bits
    Returns a float tensor shape (B,b) in {0,1}.
    """
    B, channels, T = detector_out.shape
    b = channels - 1
    if b <= 0:
        return torch.zeros((B,0), device=detector_out.device)

    bit_prob_map = detector_out[:, 1:, :]  # (B,b,T)
    bit_prob = bit_prob_map.mean(dim=2)    # (B,b)
    bit_pred = (bit_prob > 0.5).float()    # (B,b)
    return bit_pred

######################################################################
# 5) Evaluate the entire pipeline
######################################################################
def run_evaluation(generator, detector, test_dataset, device,
                   batch_size=16, num_workers=4, sr=16000, 
                   num_bits=16, as_bits=True, 
                   compute_pesq_score=True, compute_si_snr_score=True):
    """
    Evaluate on test_dataset:
      1) We generate random messages (bit vector or integer).
      2) Generator => watermarked
      3) Detector => decode bits
      4) Compare with ground truth
      5) Compute SI-SNR & PESQ

    generator, detector: your trained models
    test_dataset: dataset of shape (B,1,T)
    device: "cuda" or "cpu"
    batch_size: ...
    sr: sample rate for PESQ
    num_bits: used to generate messages
    as_bits: if True => shape (B,b), else => shape (B,)
    compute_pesq_score, compute_si_snr_score: toggles

    Returns a dict with average SI-SNR, PESQ, and bit accuracy.
    """
    from torch.utils.data import DataLoader

    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    generator.eval()
    detector.eval()

    si_snr_vals  = []
    pesq_vals    = []
    bit_acc_vals = []

    with torch.no_grad():
        for s in test_loader:
            s = s.to(device)  # (B,1,T)
            B = s.shape[0]

            # 1) Generate random messages
            msgs = generate_random_messages(B, num_bits=num_bits, as_bits=as_bits).to(device)

            # 2) Generator forward
            delta = generator(s, msgs)
            s_w   = s + delta

            # 3) Detector forward => shape (B,1+b,T)
            det_out = detector(s_w)
            # decode bits
            bit_pred = decode_bits_from_detector(det_out)  # (B,b)

            # ground truth bits
            if as_bits:
                bit_true = msgs  # shape (B,b)
            else:
                # if we have an integer, we must convert to bits
                # shape (B,b)
                bit_true = []
                for i in range(num_bits):
                    bit_i = ((msgs >> i) & 1).float()
                    bit_true.append(bit_i)
                bit_true = torch.stack(bit_true, dim=1)

            # compute bitwise accuracy per sample
            # shape (B,b)
            matches = (bit_pred == bit_true).float()
            sample_acc = matches.mean(dim=1)  # (B,) average bits => per-sample accuracy
            bit_acc_vals.extend(sample_acc.cpu().numpy())

            # 4) Audio quality metrics
            if compute_si_snr_score:
                si_snr_batch = evaluate_si_snr_torch(s, s_w)  # shape (B,)
                si_snr_vals.extend(si_snr_batch.cpu().numpy())

            if compute_pesq_score:
                # PESQ is per sample
                for i in range(B):
                    pesq_score = evaluate_pesq(s[i], s_w[i], sr=sr)
                    pesq_vals.append(pesq_score)

    # Aggregate
    avg_si_snr = float(np.mean(si_snr_vals)) if len(si_snr_vals)>0 else 0.0
    avg_pesq   = float(np.mean(pesq_vals))   if len(pesq_vals)>0 else 0.0
    avg_bit_acc= float(np.mean(bit_acc_vals))if len(bit_acc_vals)>0 else 0.0

    print("\n--- Evaluation Results ---")
    if compute_si_snr_score:
        print(f"SI-SNR:  {avg_si_snr:.3f} dB")
    if compute_pesq_score:
        print(f"PESQ:    {avg_pesq:.3f}")
    print(f"Bit Accuracy: {avg_bit_acc*100:.2f}%")

    return {
        "si_snr": avg_si_snr,
        "pesq": avg_pesq,
        "bit_accuracy": avg_bit_acc
    }


In [None]:
results = run_evaluation(
    generator, detector, 
    test_dataset=test_ds, 
    device=device,
    batch_size=16,
    num_workers=4,
    sr=16000,
    num_bits=16,
    as_bits=False,   # or True if your generator uses bit vectors
    compute_pesq_score=True,
    compute_si_snr_score=True
)


In [None]:
print(watermarked_audio.abs().max().item())


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Set up a device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def test_losses():
    # Create dummy input data
    batch_size = 2
    audio_length = 16000  # 1 second at 16kHz
    
    # Original and watermarked audio
    original = torch.randn(batch_size, 1, audio_length, device=device)
    watermarked = original + 0.01 * torch.randn(batch_size, 1, audio_length, device=device)
    
    # Detector output (1 detection channel + 8 bit channels)
    num_bits = 8
    detector_out = torch.rand(batch_size, 1 + num_bits, audio_length, device=device)
    detector_out[:, 0, :] = torch.sigmoid(torch.randn(batch_size, audio_length, device=device))  # detection probs
    
    # Binary mask and message
    mask = torch.ones(batch_size, 1, audio_length, device=device)
    message = torch.randint(0, 2**num_bits - 1, (batch_size,), device=device)
    
    # Test each loss
    print("Testing losses...")
    
    try:
        # Test MultiScaleMelLoss
        mel_loss = MultiScaleMelLoss().to(device)
        loss_value = mel_loss(original, watermarked)
        print(f"✓ MultiScaleMelLoss: {loss_value.item():.6f}")
    except Exception as e:
        print(f"✗ MultiScaleMelLoss error: {str(e)}")
    
    try:
        # Test AdversarialLoss
        adv_loss = AdversarialLoss().to(device)
        loss_value = adv_loss(original, watermarked)
        print(f"✓ AdversarialLoss: {loss_value.item():.6f}")
    except Exception as e:
        print(f"✗ AdversarialLoss error: {str(e)}")
    
    try:
        # Test TFLoudnessLoss
        tf_loss = TFLoudnessLoss().to(device)
        loss_value = tf_loss(original, watermarked)
        print(f"✓ TFLoudnessLoss: {loss_value.item():.6f}")
    except Exception as e:
        print(f"✗ TFLoudnessLoss error: {str(e)}")
    
    try:
        # Test masked_localization_loss
        loss_value = masked_localization_loss(detector_out, mask)
        print(f"✓ masked_localization_loss: {loss_value.item():.6f}")
    except Exception as e:
        print(f"✗ masked_localization_loss error: {str(e)}")
    
    try:
        # Test decoding_loss
        loss_value = decoding_loss(detector_out, message, mask)
        print(f"✓ decoding_loss: {loss_value.item():.6f}")
    except Exception as e:
        print(f"✗ decoding_loss error: {str(e)}")

if __name__ == "__main__":
    test_losses()

Testing losses...
✓ MultiScaleMelLoss: 0.006972
✓ AdversarialLoss: 0.705468
✓ TFLoudnessLoss: 0.000844
✓ masked_localization_loss: 0.324282
✓ decoding_loss: 0.172981


In [None]:
def test_gradients():
    # Original and watermarked audio with requires_grad=True
    original = torch.randn(2, 1, 16000, device=device, requires_grad=True)
    watermarked = original + 0.01 * torch.randn(2, 1, 16000, device=device, requires_grad=True)
    
    # Test gradient flow for each loss
    losses_to_test = [
        ("MultiScaleMelLoss", lambda: MultiScaleMelLoss().to(device)(original, watermarked)),
        ("AdversarialLoss", lambda: AdversarialLoss().to(device)(original, watermarked)),
        ("TFLoudnessLoss", lambda: TFLoudnessLoss().to(device)(original, watermarked))
    ]
    
    print("\nTesting gradient flow...")
    
    for loss_name, loss_fn in losses_to_test:
        try:
            # Forward pass
            loss = loss_fn()
            
            # Backward pass
            loss.backward()
            
            # Check if gradients exist
            if watermarked.grad is not None:
                grad_norm = watermarked.grad.norm().item()
                print(f"✓ {loss_name} gradients: {grad_norm:.6f}")
            else:
                print(f"✗ {loss_name} no gradients")
                
            # Reset gradients for next test
            original.grad = None
            watermarked.grad = None
            
        except Exception as e:
            print(f"✗ {loss_name} gradient error: {str(e)}")
            original.grad = None
            watermarked.grad = None

In [None]:
def test_loss_behavior():
    print("\nTesting loss behavior...")
    
    # Create audio with varying levels of difference
    original = torch.randn(1, 1, 16000, device=device)
    
    # Test with increasing levels of perturbation
    perturbation_levels = [0.001, 0.01, 0.1, 0.5]
    
    for level in perturbation_levels:
        # Create watermarked version with controlled perturbation
        noise = torch.randn(1, 1, 16000, device=device)
        watermarked = original + level * noise
        
        # Compute losses
        mel_loss = MultiScaleMelLoss().to(device)(original, watermarked).item()
        tf_loss = TFLoudnessLoss().to(device)(original, watermarked).item()
        
        print(f"Perturbation {level:.3f}: Mel Loss = {mel_loss:.6f}, TF Loss = {tf_loss:.6f}")

In [None]:
def test_mini_training():
    print("\nTesting mini training loop...")
    
    # Create simple models for testing
    class SimpleGenerator(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv1d(1, 1, 3, padding=1)
            
        def forward(self, audio, message):
            # Ignore message for simplicity
            return self.conv(audio)
    
    class SimpleDetector(nn.Module):
        def __init__(self, num_bits=8):
            super().__init__()
            self.conv = nn.Conv1d(1, 1 + num_bits, 3, padding=1)
            
        def forward(self, audio):
            return self.conv(audio)
    
    # Instantiate models
    generator = SimpleGenerator().to(device)
    detector = SimpleDetector(num_bits=8).to(device)
    optimizer = torch.optim.Adam(list(generator.parameters()) + list(detector.parameters()), lr=0.001)
    
    # Loss functions
    ms_mel_loss = MultiScaleMelLoss().to(device)
    adv_loss = AdversarialLoss().to(device)
    tf_loud_loss = TFLoudnessLoss().to(device)
    
    # Mini batch
    original = torch.randn(2, 1, 16000, device=device)
    message = torch.randint(0, 255, (2,), device=device)
    
    # Training step
    try:
        # Forward pass
        delta = generator(original, message)
        watermarked = original + delta
        det_out = detector(watermarked)
        
        # Compute losses
        loss_l1 = F.l1_loss(delta, torch.zeros_like(delta))
        loss_msspec = ms_mel_loss(original, watermarked)
        loss_adv = adv_loss(original, watermarked)
        loss_loud = tf_loud_loss(original, watermarked)
        mask = torch.ones((2, 1, 16000), device=device)
        loss_loc = masked_localization_loss(det_out, mask)
        loss_dec = decoding_loss(det_out, message, mask)
        
        # Total loss
        lambda_weights = [1.0, 1.0, 0.1, 0.5, 1.0, 1.0]
        loss = (
            lambda_weights[0] * loss_l1 +
            lambda_weights[1] * loss_msspec +
            lambda_weights[2] * loss_adv +
            lambda_weights[3] * loss_loud +
            lambda_weights[4] * loss_loc +
            lambda_weights[5] * loss_dec
        )
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print(f"✓ Mini training successful with total loss: {loss.item():.6f}")
        print(f"  L1: {loss_l1.item():.4f}, Mel: {loss_msspec.item():.4f}, Adv: {loss_adv.item():.4f}")
        print(f"  Loud: {loss_loud.item():.4f}, Loc: {loss_loc.item():.4f}, Dec: {loss_dec.item():.4f}")
        
    except Exception as e:
        print(f"✗ Mini training failed: {str(e)}")

# Run all tests
test_losses()
test_gradients()
test_loss_behavior()
test_mini_training()

Testing losses...
✓ MultiScaleMelLoss: 0.006790
✓ AdversarialLoss: 0.700679
✓ TFLoudnessLoss: 0.000801
✓ masked_localization_loss: 0.328260
✓ decoding_loss: 0.173435

Testing gradient flow...
✗ MultiScaleMelLoss no gradients
✗ AdversarialLoss gradient error: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
✗ TFLoudnessLoss gradient error: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after ca

  if watermarked.grad is not None:
../aten/src/ATen/native/cuda/Loss.cu:94: operator(): block: [30,0,0], thread: [96,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:94: operator(): block: [30,0,0], thread: [99,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:94: operator(): block: [30,0,0], thread: [101,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:94: operator(): block: [30,0,0], thread: [102,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:94: operator(): block: [30,0,0], thread: [104,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:94: operator(): block: [30,0,0], thread: [105,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:94: operator(): block: [30,0,0], thread: [109,0,0] Ass

Perturbation 0.500: Mel Loss = 0.555924, TF Loss = 0.195192

Testing mini training loop...
✗ Mini training failed: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

