In [1]:
import os
import glob
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import confusion_matrix, classification_report
import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import matplotlib.pyplot as plt

# Fix random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

# ---------------------- Hyperparameters ---------------------- #
SAMPLE_RATE = 16000   # Audio sample rate
AUDIO_LEN   = 16000   # 1-second audio (16k samples)
BATCH_SIZE  = 256      # Batch size for training
LR          = 1e-3    # Learning rate

HIDDEN_DIM  = 32      # Hidden dimension for LSTM in Generator
CHANNELS    = 32      # Initial convolution channels
OUTPUT_CH   = 128     # Final conv channels for Generator
STRIDES     = [2, 4, 5, 8]  # Downsampling strides
LSTM_LAYERS = 2       # Number of LSTM layers
NUM_WORKERS = 16       # DataLoader workers

# Loss Weights for this Phase 1 (Generator-only)
lambda_L1     = 1.0
lambda_msspec = 1.0
lambda_loud   = 0.5
lambda_loc    = 1.0

torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [2]:
class OneSecClipsDataset(Dataset):
    """
    Assumes each .wav file in root_dir is a ~1-sec clip (16k samples).
    If sample_rate != 16000, it resamples to 16k.
    """
    def __init__(self, root_dir, sample_rate=SAMPLE_RATE):
        super().__init__()
        self.filepaths = glob.glob(os.path.join(root_dir, '**', '*.wav'), recursive=True)
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx):
        wav_path = self.filepaths[idx]
        waveform, sr = torchaudio.load(wav_path)

        # Convert to mono if multi-channel
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample if needed
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sample_rate)
            waveform = resampler(waveform)

        # Ensure the clip is 1 second (pad/crop if necessary)
        if waveform.shape[1] > AUDIO_LEN:
            waveform = waveform[:, :AUDIO_LEN]
        elif waveform.shape[1] < AUDIO_LEN:
            pad_amt = AUDIO_LEN - waveform.shape[1]
            waveform = F.pad(waveform, (0, pad_amt))

        return waveform  # shape: (1, AUDIO_LEN)

def watermark_masking_augmentation(wav, p_replace_orig=0.4, p_replace_zero=0.2, p_replace_diff=0.2):
    """
    Randomly masks portions of the audio:
    - p_replace_orig: do nothing
    - p_replace_zero: replace segment with zeros
    - p_replace_diff: replace segment with random noise
    """
    T = wav.shape[1]
    window_len = int(0.1 * SAMPLE_RATE)
    k = 5
    for _ in range(k):
        start = random.randint(0, T - window_len)
        end = start + window_len
        choice = random.random()
        if choice < p_replace_orig:
            pass
        elif choice < p_replace_orig + p_replace_zero:
            wav[:, start:end] = 0.0
        elif choice < p_replace_orig + p_replace_zero + p_replace_diff:
            wav[:, start:end] = 0.1 * torch.randn_like(wav[:, start:end])
        else:
            pass
    return wav

def robustness_augmentations(wav):
    """
    Adds small random noise for robustness.
    """
    return wav + 0.005 * torch.randn_like(wav)


In [3]:
def make_conv1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1):
    return nn.Conv1d(in_ch, out_ch, kernel_size, stride=stride, padding=padding)

class ResidualBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.downsample = (stride != 1 or in_ch != out_ch)
        self.conv1 = make_conv1d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)
        self.conv2 = make_conv1d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.elu   = nn.ELU()
        if self.downsample:
            self.skip_conv = make_conv1d(in_ch, out_ch, kernel_size=1, stride=stride, padding=0)

    def forward(self, x):
        residual = x
        out = self.elu(self.conv1(x))
        out = self.conv2(out)
        if self.downsample:
            residual = self.skip_conv(residual)
        out = self.elu(out + residual)
        return out

class Generator(nn.Module):
    def __init__(self, 
                 in_channels=1, 
                 base_channels=CHANNELS,
                 hidden_dim=HIDDEN_DIM, 
                 output_channels=OUTPUT_CH, 
                 strides=STRIDES):
        super().__init__()
        
        # ---------- Encoder ----------
        self.init_conv = nn.Conv1d(in_channels, base_channels, kernel_size=7, stride=1, padding=3)
        
        enc_blocks = []
        ch = base_channels
        for st in strides:
            out_ch = ch * 2
            enc_blocks.append(ResidualBlock(ch, out_ch, stride=st))
            ch = out_ch
        self.encoder_blocks = nn.Sequential(*enc_blocks)

        # Project encoder output to hidden_dim (for LSTM)
        self.proj = nn.Linear(ch, hidden_dim)

        # LSTM
        self.lstm = nn.LSTM(input_size=hidden_dim, hidden_size=hidden_dim, 
                            num_layers=LSTM_LAYERS, batch_first=True, bidirectional=False)

        self.final_conv_enc = nn.Conv1d(hidden_dim, output_channels, kernel_size=7, stride=1, padding=3)

        # ---------- Decoder ----------
        dec_blocks = []
        rev_strides = list(reversed(strides))
        in_ch = output_channels
        for st in rev_strides:
            out_ch = in_ch // 2
            dec_blocks.append(nn.ConvTranspose1d(in_ch, out_ch, kernel_size=2*st, stride=st,
                                                 padding=(st//2), output_padding=0))
            dec_blocks.append(ResidualBlock(out_ch, out_ch, stride=1))
            in_ch = out_ch
        self.decoder_blocks = nn.Sequential(*dec_blocks)

        # Final conv -> 1 channel for the delta
        self.final_conv_dec = nn.Conv1d(in_ch, 1, kernel_size=7, stride=1, padding=3)

    def forward(self, s):
        """
        s: shape (B, 1, T)
        Output: delta (B, 1, T)
        """
        B, _, T = s.shape

        # Encode
        x = self.init_conv(s)
        x = self.encoder_blocks(x)  
        x_t = x.transpose(1, 2)     # (B, T_enc, ch)
        x_t = self.proj(x_t)        # (B, T_enc, hidden_dim)

        # LSTM
        lstm_out, _ = self.lstm(x_t)           # (B, T_enc, hidden_dim)
        lstm_out_t = lstm_out.transpose(1, 2)  # (B, hidden_dim, T_enc)
        latent = self.final_conv_enc(lstm_out_t)

        # Decode
        x_dec = self.decoder_blocks(latent)
        delta = self.final_conv_dec(x_dec)

        # Adjust shape if needed
        if delta.shape[-1] != T:
            min_len = min(delta.shape[-1], T)
            delta = delta[:, :, :min_len]
            if min_len < T:
                pad_amt = T - min_len
                delta = F.pad(delta, (0, pad_amt))

        return delta


In [4]:
import torchaudio.transforms as T

class SimpleMelLoss(nn.Module):
    def __init__(self, sample_rate=SAMPLE_RATE, n_fft=1024, n_mels=80):
        super(SimpleMelLoss, self).__init__()
        self.mel_spec = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=n_fft // 4,
            n_mels=n_mels,
            normalized=True
        )
        
    def forward(self, original, watermarked):
        mel_orig = torch.log(self.mel_spec(original) + 1e-5)
        mel_wm   = torch.log(self.mel_spec(watermarked) + 1e-5)
        return F.l1_loss(mel_orig, mel_wm)

class TFLoudnessLoss(nn.Module):
    def __init__(self, n_bands=8, window_size=2048, hop_size=512):
        super(TFLoudnessLoss, self).__init__()
        self.n_bands = n_bands
        self.win_size = window_size
        self.hop_size = hop_size
        
        weights = torch.ones(n_bands)
        mid_band_idx = n_bands // 3
        weights[mid_band_idx:2 * mid_band_idx] = 1.5
        self.register_buffer('band_weights', weights)

    def forward(self, original, watermarked):
        window = torch.hann_window(self.win_size, device=original.device)
        stft_orig = torch.stft(
            original.squeeze(1), n_fft=self.win_size, hop_length=self.hop_size,
            window=window, return_complex=True, normalized=True
        )
        stft_wm = torch.stft(
            watermarked.squeeze(1), n_fft=self.win_size, hop_length=self.hop_size,
            window=window, return_complex=True, normalized=True
        )
        mag_orig = stft_orig.abs()
        mag_wm   = stft_wm.abs()
        phase_orig = stft_orig.angle()
        phase_wm   = stft_wm.angle()
        
        freq_bins = mag_orig.shape[1]
        band_size = freq_bins // self.n_bands
        
        loudness_loss = 0.0
        spectral_loss = 0.0
        phase_loss = 0.0
        
        for b in range(self.n_bands):
            start = b * band_size
            end = freq_bins if (b == self.n_bands - 1) else (start + band_size)
            band_orig = mag_orig[:, start:end, :]
            band_wm = mag_wm[:, start:end, :]
            
            energy_orig = torch.sum(band_orig ** 2, dim=1)
            energy_wm = torch.sum(band_wm ** 2, dim=1)
            loud_orig = torch.log10(energy_orig + 1e-8)
            loud_wm   = torch.log10(energy_wm + 1e-8)
            loudness_loss += self.band_weights[b] * F.l1_loss(loud_wm, loud_orig)
            spectral_loss += self.band_weights[b] * F.mse_loss(band_wm, band_orig)
            phase_diff = 1.0 - torch.cos(phase_wm[:, start:end, :] - phase_orig[:, start:end, :])
            phase_loss += self.band_weights[b] * phase_diff.mean()
        
        loudness_loss /= self.n_bands
        spectral_loss /= self.n_bands
        phase_loss /= self.n_bands
        
        return loudness_loss + spectral_loss + 0.2 * phase_loss


In [None]:
def train_generator(
    generator,
    train_ds,
    val_ds,
    device,
    num_epochs=10,
    lr=1e-3,
    lambda_L1=1.0,
    lambda_msspec=1.0,
    lambda_loud=0.5,
    use_robustness=False
):
    """
    Trains only the Generator (no adversarial or detector),
    with a separate validation set.
    """
    # Create DataLoaders
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    optimizer = optim.Adam(generator.parameters(), lr=lr)
    generator.train()

    ms_mel_loss_fn = SimpleMelLoss().to(device)
    tf_loud_loss_fn = TFLoudnessLoss().to(device)

    for epoch in range(1, num_epochs+1):
        # ------------------- TRAINING LOOP ------------------- #
        generator.train()
        train_pbar = tqdm(train_loader, desc=f"[Epoch {epoch}/{num_epochs} - Train]", leave=True)
        
        running_l1, running_mel, running_loud, running_total = 0.0, 0.0, 0.0, 0.0
        num_batches = len(train_loader)

        for batch_data in train_pbar:
            s = batch_data.to(device)  # (B,1,T)
            optimizer.zero_grad()
            
            # Forward pass in full precision (no autocast)
            delta = generator(s)
            s_w = s + delta

            if use_robustness:
                for i in range(s_w.shape[0]):
                    s_w[i] = watermark_masking_augmentation(s_w[i])
                    s_w[i] = robustness_augmentations(s_w[i])

            loss_l1 = F.l1_loss(delta, torch.zeros_like(delta))
            loss_mel = ms_mel_loss_fn(s, s_w)
            loss_loud = tf_loud_loss_fn(s, s_w)

            loss_total = (lambda_L1 * loss_l1 +
                          lambda_msspec * loss_mel +
                          lambda_loud * loss_loud)

            # Standard backward pass
            loss_total.backward()
            optimizer.step()
        
            running_l1    += loss_l1.item()
            running_mel   += loss_mel.item()
            running_loud  += loss_loud.item()
            running_total += loss_total.item()

            train_pbar.set_postfix({
                "L1": f"{loss_l1.item():.4f}",
                "Mel": f"{loss_mel.item():.4f}",
                "Loud": f"{loss_loud.item():.4f}",
                "Total": f"{loss_total.item():.4f}"
            })

        avg_l1    = running_l1 / num_batches
        avg_mel   = running_mel / num_batches
        avg_loud  = running_loud / num_batches
        avg_total = running_total / num_batches

        # ------------------- VALIDATION LOOP ------------------- #
        generator.eval()
        val_pbar = tqdm(val_loader, desc="[Validation]", leave=False)
        
        val_l1, val_mel, val_loud, val_total = 0.0, 0.0, 0.0, 0.0
        val_steps = 0

        with torch.no_grad():
            for batch_data in val_pbar:
                s_val = batch_data.to(device)
                delta_val = generator(s_val)
                s_w_val   = s_val + delta_val

                l1_val   = F.l1_loss(delta_val, torch.zeros_like(delta_val))
                mel_val  = ms_mel_loss_fn(s_val, s_w_val)
                loud_val = tf_loud_loss_fn(s_val, s_w_val)
                total_val = (lambda_L1 * l1_val +
                             lambda_msspec * mel_val +
                             lambda_loud * loud_val)

                val_l1    += l1_val.item()
                val_mel   += mel_val.item()
                val_loud  += loud_val.item()
                val_total += total_val.item()
                val_steps += 1

        if val_steps > 0:
            val_l1    /= val_steps
            val_mel   /= val_steps
            val_loud  /= val_steps
            val_total /= val_steps

        print(f"\nEpoch [{epoch}/{num_epochs}] Summary:")
        print(f"  Train => L1: {avg_l1:.4f}, Mel: {avg_mel:.4f}, Loud: {avg_loud:.4f}, Total: {avg_total:.4f}")
        print(f"  Valid => L1: {val_l1:.4f}, Mel: {val_mel:.4f}, Loud: {val_loud:.4f}, Total: {val_total:.4f}\n")

    print("Generator-only training with validation complete.")


In [6]:
def evaluate_si_snr_torch(original, reconstructed, eps=1e-8):
    if original.dim() == 3:
        original = original.squeeze(1)
    if reconstructed.dim() == 3:
        reconstructed = reconstructed.squeeze(1)
    
    original_zm = original - original.mean(dim=1, keepdim=True)
    recon_zm    = reconstructed - reconstructed.mean(dim=1, keepdim=True)
    
    dot = (original_zm * recon_zm).sum(dim=1, keepdim=True)
    norm_sq = (original_zm ** 2).sum(dim=1, keepdim=True) + eps
    alpha = dot / norm_sq
    
    s_target = alpha * original_zm
    e_noise = recon_zm - s_target
    si_snr_val = 10 * torch.log10((s_target ** 2).sum(dim=1) / ((e_noise ** 2).sum(dim=1) + eps))
    return si_snr_val


def run_evaluation_generator_only(generator, dataset, device, batch_size=16):
    """
    Evaluate how much the generator changes the audio 
    by measuring SI-SNR or any custom metric.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    generator.eval()

    si_snr_vals = []

    with torch.no_grad():
        for s in loader:
            s = s.to(device)  # (B,1,T)
            delta = generator(s)
            s_w = s + delta

            # measure SI-SNR
            vals = evaluate_si_snr_torch(s, s_w).cpu().numpy()
            si_snr_vals.extend(vals)
    
    avg_si_snr = float(np.mean(si_snr_vals)) if si_snr_vals else 0.0
    return avg_si_snr


In [7]:

################################################################################
# 1) DETECTOR ARCHITECTURE
################################################################################

class Detector(nn.Module):
    def __init__(self, in_channels=1, base_channels=32, strides=[2,4,5,8]):
        super().__init__()
        self.init_conv = nn.Conv1d(in_channels, base_channels, kernel_size=7, stride=1, padding=3)
        enc_blocks = []
        ch = base_channels
        for st in strides:
            out_ch = ch * 2
            enc_blocks.append(ResidualBlock(ch, out_ch, stride=st))
            ch = out_ch
        self.encoder_blocks = nn.Sequential(*enc_blocks)

        dec_blocks = []
        rev_strides = list(reversed(strides))
        in_ch = ch
        for st in rev_strides:
            out_ch = in_ch // 2
            dec_blocks.append(nn.ConvTranspose1d(
                in_ch, out_ch,
                kernel_size=2*st, stride=st,
                padding=(st//2), output_padding=0
            ))
            dec_blocks.append(ResidualBlock(out_ch, out_ch, stride=1))
            in_ch = out_ch
        self.decoder_blocks = nn.Sequential(*dec_blocks)

        self.final_conv = nn.Conv1d(in_ch, 1, kernel_size=7, stride=1, padding=3)

    def forward(self, x):
        original_length = x.shape[-1]
        x = self.init_conv(x)
        x = self.encoder_blocks(x)
        x = self.decoder_blocks(x)
        out = self.final_conv(x)
        if out.shape[-1] > original_length:
            out = out[:, :, :original_length]
        elif out.shape[-1] < original_length:
            pad_amt = original_length - out.shape[-1]
            out = F.pad(out, (0, pad_amt))
        return torch.sigmoid(out)

In [8]:
class SimpleMelLoss(nn.Module): 
    def __init__(self, sample_rate=SAMPLE_RATE, n_fft=1024, n_mels=80):
        super().__init__()
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=n_fft // 4,
            n_mels=n_mels,
            normalized=True
        )

    def forward(self, original, watermarked):
        mel_orig = torch.log(self.mel_spec(original) + 1e-5)
        mel_wm   = torch.log(self.mel_spec(watermarked) + 1e-5)
        return F.l1_loss(mel_orig, mel_wm)

class TFLoudnessLoss(nn.Module):
    def __init__(self, n_bands=8, window_size=2048, hop_size=512):
        super().__init__()
        self.n_bands = n_bands
        self.win_size = window_size
        self.hop_size = hop_size
        
        weights = torch.ones(n_bands)
        mid_band_idx = n_bands // 3
        weights[mid_band_idx:2 * mid_band_idx] = 1.5
        self.register_buffer('band_weights', weights)

    def forward(self, original, watermarked):
        window = torch.hann_window(self.win_size, device=original.device)
        stft_orig = torch.stft(
            original.squeeze(1), n_fft=self.win_size, hop_length=self.hop_size,
            window=window, return_complex=True, normalized=True
        )
        stft_wm = torch.stft(
            watermarked.squeeze(1), n_fft=self.win_size, hop_length=self.hop_size,
            window=window, return_complex=True, normalized=True
        )
        mag_orig = stft_orig.abs()
        mag_wm   = stft_wm.abs()
        phase_orig = stft_orig.angle()
        phase_wm   = stft_wm.angle()
        
        freq_bins = mag_orig.shape[1]
        band_size = freq_bins // self.n_bands
        
        loudness_loss = 0.0
        spectral_loss = 0.0
        phase_loss = 0.0
        
        for b in range(self.n_bands):
            start = b * band_size
            end = freq_bins if (b == self.n_bands - 1) else (start + band_size)
            band_orig = mag_orig[:, start:end, :]
            band_wm   = mag_wm[:, start:end, :]

            energy_orig = torch.sum(band_orig ** 2, dim=1)
            energy_wm   = torch.sum(band_wm ** 2, dim=1)
            loud_orig   = torch.log10(energy_orig + 1e-8)
            loud_wm     = torch.log10(energy_wm + 1e-8)
            loudness_loss += self.band_weights[b] * F.l1_loss(loud_wm, loud_orig)
            spectral_loss += self.band_weights[b] * F.mse_loss(band_wm, band_orig)
            phase_diff     = 1.0 - torch.cos(phase_wm[:, start:end, :] - phase_orig[:, start:end, :])
            phase_loss    += self.band_weights[b] * phase_diff.mean()
        
        loudness_loss /= self.n_bands
        spectral_loss /= self.n_bands
        phase_loss    /= self.n_bands
        
        return loudness_loss + spectral_loss + 0.2 * phase_loss

In [9]:
################################################################################
# 2) DETECTION LOSS: Masked Localization (focal BCE)
################################################################################

def masked_localization_loss(detector_out, mask, smooth_eps=0.1):
    """
    A binary cross-entropy style detection loss with focal weighting + label smoothing.
    detector_out: (B,1,T) in [0,1]
    mask: (B,1,T) => 1 for watermarked, 0 for clean
    smooth_eps: label smoothing factor
    """
    det_prob = detector_out  # shape (B,1,T)
    smoothed_mask = mask * (1 - smooth_eps) + (1 - mask) * smooth_eps

    # Focal weighting
    pt = torch.where(mask > 0.5, det_prob, 1 - det_prob)  # probability of the "correct" label
    focal_weight = (1 - pt) ** 2

    bce_loss = F.binary_cross_entropy(det_prob, smoothed_mask, reduction='none')
    focal_loss = focal_weight * bce_loss

    return focal_loss.mean()

In [10]:
import torch.optim.lr_scheduler as lr_sched

def train_generator_detector(
    generator,
    detector,
    train_dataset,
    val_dataset,
    device,
    num_epochs=10,
    batch_size=64,
    lr=1e-3,
    lambda_L1=1.0,
    lambda_msspec=1.0,
    lambda_loud=0.5,
    lambda_loc=1.0,
    use_robustness=False
):
    """
    Train the (unfrozen) Generator + new Detector with a Cosine LR Scheduler 
    and validation each epoch.
    """
    # 1) Optionally compile for speed
    generator = torch.compile(generator)
    detector  = torch.compile(detector)

    # 2) Create DataLoaders with pin_memory
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )

    # 3) Prepare losses + optimizer + scheduler
    ms_mel_loss = SimpleMelLoss().to(device)
    tf_loud_loss = TFLoudnessLoss().to(device)
    optimizer = optim.Adam(list(generator.parameters()) + list(detector.parameters()), lr=lr)

    # Here we set T_max = num_epochs, so it decays over the entire training.
    # After 'num_epochs' calls to scheduler.step(), LR is near 'eta_min'.
    scheduler = lr_sched.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs,  # or total_iters if you want a per-iteration approach
        eta_min=1e-5       # final LR
    )

    generator.train()
    detector.train()

    for epoch in range(1, num_epochs + 1):
        #######################
        #      TRAIN LOOP
        #######################
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} [Train]", leave=True)
        train_total_loss = 0.0
        train_steps = 0

        for batch_data in train_pbar:
            s = batch_data.to(device)
            optimizer.zero_grad()
            
            # Forward pass
            delta = generator(s)
            s_w   = s + delta

            if use_robustness:
                for b_idx in range(s_w.shape[0]):
                    s_w[b_idx] = watermark_masking_augmentation(s_w[b_idx])
                    s_w[b_idx] = robustness_augmentations(s_w[b_idx])

            # Create label mask: 1 for watermarked, 0 for clean
            combined = torch.cat([s_w, s], dim=0)
            label_mask = torch.cat([
                torch.ones_like(s),
                torch.zeros_like(s)
            ], dim=0).to(device)

            det_out = detector(combined)

            # Losses
            loss_l1     = F.l1_loss(delta, torch.zeros_like(delta))
            loss_msspec = ms_mel_loss(s, s_w)
            loss_loud   = tf_loud_loss(s, s_w)
            loss_loc    = masked_localization_loss(det_out, label_mask, smooth_eps=0.1)

            loss = (lambda_L1     * loss_l1 +
                    lambda_msspec * loss_msspec +
                    lambda_loud   * loss_loud +
                    lambda_loc    * loss_loc)

            loss.backward()
            optimizer.step()

            train_total_loss += loss.item()
            train_steps += 1
            train_pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        # End of train loop
        train_avg_loss = train_total_loss / train_steps

        #######################
        #     VALIDATION LOOP
        #######################
        generator.eval()
        detector.eval()
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch}/{num_epochs} [Val]", leave=False)
        val_total_loss = 0.0
        val_steps = 0

        with torch.no_grad():
            for batch_data in val_pbar:
                s = batch_data.to(device)
                delta = generator(s)
                s_w   = s + delta

                combined = torch.cat([s_w, s], dim=0)
                label_mask = torch.cat([
                    torch.ones_like(s),
                    torch.zeros_like(s)
                ], dim=0).to(device)

                det_out = detector(combined)

                loss_l1     = F.l1_loss(delta, torch.zeros_like(delta))
                loss_msspec = ms_mel_loss(s, s_w)
                loss_loud   = tf_loud_loss(s, s_w)
                loss_loc    = masked_localization_loss(det_out, label_mask, smooth_eps=0.1)

                val_loss = (lambda_L1     * loss_l1 +
                            lambda_msspec * loss_msspec +
                            lambda_loud   * loss_loud +
                            lambda_loc    * loss_loc)

                val_total_loss += val_loss.item()
                val_steps += 1

        val_avg_loss = val_total_loss / val_steps if val_steps>0 else 0.0

        # Step the LR scheduler once after each epoch
        scheduler.step()

        generator.train()
        detector.train()
        
        print(f"\nEpoch [{epoch}/{num_epochs}] Summary:")
        print(f"  Train Avg Loss: {train_avg_loss:.4f}")
        print(f"  Val   Avg Loss: {val_avg_loss:.4f}")
        current_lr = scheduler.get_last_lr()[0]
        print(f"  Current LR: {current_lr:.6f}\n")

    print("Finished training (Generator + Detector) with Cosine LR + validation each epoch.")
    return generator, detector


In [11]:
def evaluate_detector(
    generator,
    detector,
    test_dataset,
    device,
    threshold=0.5,
    batch_size=16
):
    """
    Evaluate the detector on a test set.
    For each clip: watermarked => label=1, clean => label=0.
    Computes confusion matrix, classification report, TPR, FPR, and accuracy.
    """
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )

    generator.eval()
    detector.eval()

    y_true = []
    y_score = []

    with torch.no_grad():
        for s in test_loader:
            s = s.to(device)
            B = s.shape[0]

            delta = generator(s)
            s_w   = s + delta

            combined = torch.cat([s_w, s], dim=0)
            labels = np.concatenate([np.ones(B), np.zeros(B)], axis=0)

            det_out = detector(combined)  # (2B,1,T)
            scores = det_out.mean(dim=2).squeeze(1).cpu().numpy()  # shape (2B,)

            y_true.append(labels)
            y_score.append(scores)

    y_true = np.concatenate(y_true, axis=0)
    y_score = np.concatenate(y_score, axis=0)

    y_pred = (y_score >= threshold).astype(int)

    cm = confusion_matrix(y_true, y_pred)
    clf_report = classification_report(y_true, y_pred, target_names=["Clean(0)", "Watermarked(1)"])

    tp = np.sum((y_pred == 1) & (y_true == 1))
    fp = np.sum((y_pred == 1) & (y_true == 0))
    tn = np.sum((y_pred == 0) & (y_true == 0))
    fn = np.sum((y_pred == 0) & (y_true == 1))

    tpr = tp / (tp + fn + 1e-8)
    fpr = fp / (fp + tn + 1e-8)
    acc = (tp + tn) / (tp + tn + fp + fn + 1e-8)

    print(f"\n--- Detector Evaluation (threshold={threshold}) ---")
    print(f"Accuracy: {acc:.3f},  TPR: {tpr:.3f},  FPR: {fpr:.3f}")
    print("Confusion Matrix:\n", cm)
    print("Classification Report:\n", clf_report)

    return {
        "accuracy": acc,
        "TPR": tpr,
        "FPR": fpr,
        "confusion_matrix": cm,
        "classification_report": clf_report
    }


In [12]:
if __name__ == "__main__":
    ############################################################################
    # PHASE 2 MAIN
    ############################################################################
    data_root = "data/200_speech_only"
    full_dataset = OneSecClipsDataset(root_dir=data_root, sample_rate=SAMPLE_RATE)
    n = len(full_dataset)
    # n_train = int(0.8 * n)
    # n_val   = int(0.1 * n)
    # n_test  = n - n_train - n_val
    # train_ds, val_ds, test_ds = random_split(full_dataset, [n_train, n_val, n_test])
    subset_size = 1000
    subset_indices = list(range(min(subset_size, len(full_dataset))))
    subset_dataset = torch.utils.data.Subset(full_dataset, subset_indices)

    n = len(subset_dataset)
    n_train = int(0.8 * n)
    n_val   = int(0.1 * n)
    n_test  = n - n_train - n_val
    train_ds, val_ds, test_ds = random_split(subset_dataset, [n_train, n_val, n_test])
    print(f"Subset Dataset Size: {len(subset_dataset)}")
    print(f"Train/Val/Test Split: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}")
    print(f"Dataset => Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")

    # 1) Load Phase 1 Generator
    generator = Generator().to(device)
    gen_sd = torch.load("generator.pth", map_location=device)

    # Fix any "_orig_mod." prefix
    fixed_state_dict = {}
    for k, v in gen_sd.items():
        if k.startswith("_orig_mod."):
            new_k = k.replace("_orig_mod.", "")
            fixed_state_dict[new_k] = v
        else:
            fixed_state_dict[k] = v

    generator.load_state_dict(fixed_state_dict, strict=True)
    generator.eval()
    print("Loaded Phase 1 Generator. Ready for Phase 2 training.")

    # 2) Create new Detector
    detector = Detector(in_channels=1, base_channels=32, strides=[2,4,5,8]).to(device)
    detector.eval()
    print("Created new Detector instance.")

    # 3) Train them together with validation
    num_epochs_phase2 = 10
    batch_size_phase2 = 32
    gen2, det2 = train_generator_detector(
        generator=generator,
        detector=detector,
        train_dataset=train_ds,
        val_dataset=val_ds,
        device=device,
        num_epochs=num_epochs_phase2,
        batch_size=batch_size_phase2,
        lr=1e-3,
        lambda_L1=1.0,
        lambda_msspec=1.0,
        lambda_loud=0.5,
        lambda_loc=2.0,
        use_robustness=False
    )

    # 4) Save the newly fine-tuned Generator + Detector
    torch.save(gen2.state_dict(), "generator_phase2.pth")
    torch.save(det2.state_dict(), "detector_phase2.pth")
    print("Saved generator_phase2.pth and detector_phase2.pth.")

    # 5) Evaluate on the Test set (10% from above)
    metrics = evaluate_detector(
        generator=gen2,
        detector=det2,
        test_dataset=test_ds,
        device=device,
        threshold=0.5,
        batch_size=16
    )
    print("\nFinal Test Metrics:", metrics)


Subset Dataset Size: 1000
Train/Val/Test Split: 800/100/100
Dataset => Train: 800, Val: 100, Test: 100


  gen_sd = torch.load("generator.pth", map_location=device)


Loaded Phase 1 Generator. Ready for Phase 2 training.
Created new Detector instance.


Epoch 1/10 [Train]: 100%|██████████| 25/25 [00:22<00:00,  1.12it/s, loss=0.3583]
Epoch 1/10 [Val]:  50%|█████     | 2/4 [00:01<00:01,  1.22it/s]W0314 18:09:01.137000 3354 torch/_dynamo/convert_frame.py:844] [1/8] torch._dynamo hit config.cache_size_limit (8)
W0314 18:09:01.137000 3354 torch/_dynamo/convert_frame.py:844] [1/8]    function: 'forward' (/tmp/ipykernel_3354/1045307349.py:14)
W0314 18:09:01.137000 3354 torch/_dynamo/convert_frame.py:844] [1/8]    last reason: 1/0: GLOBAL_STATE changed: grad_mode 
W0314 18:09:01.137000 3354 torch/_dynamo/convert_frame.py:844] [1/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0314 18:09:01.137000 3354 torch/_dynamo/convert_frame.py:844] [1/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
                                                               


Epoch [1/10] Summary:
  Train Avg Loss: 0.3872
  Val   Avg Loss: 0.3572
  Current LR: 0.000976



Epoch 2/10 [Train]: 100%|██████████| 25/25 [00:06<00:00,  3.64it/s, loss=0.3616]
                                                               


Epoch [2/10] Summary:
  Train Avg Loss: 0.3564
  Val   Avg Loss: 0.3569
  Current LR: 0.000905



Epoch 3/10 [Train]: 100%|██████████| 25/25 [00:06<00:00,  3.71it/s, loss=0.3674]
                                                               


Epoch [3/10] Summary:
  Train Avg Loss: 0.3534
  Val   Avg Loss: 0.3713
  Current LR: 0.000796



Epoch 4/10 [Train]: 100%|██████████| 25/25 [00:06<00:00,  3.74it/s, loss=0.3611]
                                                               


Epoch [4/10] Summary:
  Train Avg Loss: 0.3575
  Val   Avg Loss: 0.3650
  Current LR: 0.000658



Epoch 5/10 [Train]: 100%|██████████| 25/25 [00:06<00:00,  3.66it/s, loss=0.3483]
                                                               


Epoch [5/10] Summary:
  Train Avg Loss: 0.3533
  Val   Avg Loss: 0.3548
  Current LR: 0.000505



Epoch 6/10 [Train]: 100%|██████████| 25/25 [00:06<00:00,  3.76it/s, loss=0.3491]
                                                               


Epoch [6/10] Summary:
  Train Avg Loss: 0.3492
  Val   Avg Loss: 0.3486
  Current LR: 0.000352



Epoch 7/10 [Train]: 100%|██████████| 25/25 [00:06<00:00,  3.73it/s, loss=0.3479]
                                                               


Epoch [7/10] Summary:
  Train Avg Loss: 0.3490
  Val   Avg Loss: 0.3485
  Current LR: 0.000214



Epoch 8/10 [Train]: 100%|██████████| 25/25 [00:06<00:00,  3.68it/s, loss=0.3493]
                                                               


Epoch [8/10] Summary:
  Train Avg Loss: 0.3485
  Val   Avg Loss: 0.3491
  Current LR: 0.000105



Epoch 9/10 [Train]: 100%|██████████| 25/25 [00:06<00:00,  3.76it/s, loss=0.3493]
                                                               


Epoch [9/10] Summary:
  Train Avg Loss: 0.3488
  Val   Avg Loss: 0.3486
  Current LR: 0.000034



Epoch 10/10 [Train]: 100%|██████████| 25/25 [00:06<00:00,  3.69it/s, loss=0.3479]
                                                                


Epoch [10/10] Summary:
  Train Avg Loss: 0.3482
  Val   Avg Loss: 0.3482
  Current LR: 0.000010

Finished training (Generator + Detector) with Cosine LR + validation each epoch.
Saved generator_phase2.pth and detector_phase2.pth.





--- Detector Evaluation (threshold=0.5) ---
Accuracy: 0.500,  TPR: 0.030,  FPR: 0.030
Confusion Matrix:
 [[97  3]
 [97  3]]
Classification Report:
                 precision    recall  f1-score   support

      Clean(0)       0.50      0.97      0.66       100
Watermarked(1)       0.50      0.03      0.06       100

      accuracy                           0.50       200
     macro avg       0.50      0.50      0.36       200
  weighted avg       0.50      0.50      0.36       200


Final Test Metrics: {'accuracy': 0.499999999975, 'TPR': 0.029999999997000003, 'FPR': 0.029999999997000003, 'confusion_matrix': array([[97,  3],
       [97,  3]]), 'classification_report': '                precision    recall  f1-score   support\n\n      Clean(0)       0.50      0.97      0.66       100\nWatermarked(1)       0.50      0.03      0.06       100\n\n      accuracy                           0.50       200\n     macro avg       0.50      0.50      0.36       200\n  weighted avg       0.50      0.