In [1]:
import os
import glob
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import matplotlib.pyplot as plt

# Fix random seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    The same value is applied to Python's ``random`` module, NumPy's global
    generator, PyTorch's default generator, and all CUDA device generators.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed every RNG once, up front, using the helper defined above.
set_seed(42)

# ---------------------- Hyperparameters ---------------------- #
SAMPLE_RATE = 16000   # Audio sample rate (Hz)
AUDIO_LEN   = 16000   # Samples per clip: 1 second at SAMPLE_RATE
BATCH_SIZE  = 32      # Batch size for training
LR          = 1e-3    # Learning rate
HIDDEN_DIM  = 32      # Hidden dimension for LSTM
CHANNELS    = 32      # Initial convolution channels
OUTPUT_CH   = 128     # Final conv channels for Generator
STRIDES     = [2, 4, 5, 8]  # Downsampling/upsampling strides (overall factor 2*4*5*8 = 320)
LSTM_LAYERS = 2       # Number of LSTM layers
NUM_WORKERS = 4       # Number of DataLoader workers (adjust as needed)

# Loss weights for the composite training/validation loss
lambda_L1     = 1.0   # L1 penalty on the generated perturbation (delta)
lambda_msspec = 1.0   # Log-mel spectrogram L1 loss
lambda_adv    = 0.1   # Adversarial (discriminator) loss
lambda_loud   = 0.5   # Time-frequency loudness loss
lambda_loc    = 1.0  # Used for detection/localization BCE

# Device: use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [2]:
class OneSecClipsDataset(Dataset):
    """
    Dataset of mono, fixed-length clips read from ``.wav`` files.

    Every ``*.wav`` file found recursively under ``root_dir`` is one item.
    Each clip is down-mixed to mono, resampled to ``sample_rate`` when its
    native rate differs, and cropped or right-zero-padded to ``AUDIO_LEN``
    samples.

    Args:
        root_dir: Directory searched recursively for ``*.wav`` files.
        sample_rate: Target sample rate in Hz.
    """
    def __init__(self, root_dir, sample_rate=SAMPLE_RATE):
        super().__init__()
        # glob returns paths in directory-listing order, which is not stable
        # across machines/filesystems. Sorting makes index-based subsets and
        # seeded random splits reproducible.
        self.filepaths = sorted(
            glob.glob(os.path.join(root_dir, '**', '*.wav'), recursive=True)
        )
        self.sample_rate = sample_rate
        # One Resample transform per source rate, built lazily on first use,
        # instead of rebuilding the transform for every item.
        self._resamplers = {}

    def __len__(self):
        return len(self.filepaths)

    def _resample(self, waveform, sr):
        """Resample ``waveform`` from ``sr`` to ``self.sample_rate``."""
        resampler = self._resamplers.get(sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sample_rate)
            self._resamplers[sr] = resampler
        return resampler(waveform)

    def __getitem__(self, idx):
        """Load clip ``idx`` and return it as a ``(1, AUDIO_LEN)`` tensor."""
        wav_path = self.filepaths[idx]
        waveform, sr = torchaudio.load(wav_path)

        # Convert to mono if multi-channel
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample if needed
        if sr != self.sample_rate:
            waveform = self._resample(waveform, sr)

        # Force the fixed clip length: crop long clips, zero-pad short ones on the right
        n_samples = waveform.shape[1]
        if n_samples > AUDIO_LEN:
            waveform = waveform[:, :AUDIO_LEN]
        elif n_samples < AUDIO_LEN:
            waveform = F.pad(waveform, (0, AUDIO_LEN - n_samples))

        return waveform  # shape: (1, AUDIO_LEN)


def watermark_masking_augmentation(wav, p_replace_orig=0.4, p_replace_zero=0.2, p_replace_diff=0.2):
    """
    Randomly corrupt short windows of a clip, modifying ``wav`` in place.

    Five windows of 0.1 s (``0.1 * SAMPLE_RATE`` samples, capped at the clip
    length) are placed at uniformly random start positions; windows may
    overlap. Each window independently gets one of:

    - with probability ``p_replace_orig``: left untouched
    - with probability ``p_replace_zero``: overwritten with zeros
    - with probability ``p_replace_diff``: overwritten with Gaussian noise
      of standard deviation 0.1
    - with the remaining probability: left untouched

    Args:
        wav: Tensor indexed as (channels, time). It is mutated in place.
        p_replace_orig: Probability of the explicit "keep" branch.
        p_replace_zero: Probability of zeroing a window.
        p_replace_diff: Probability of replacing a window with noise.

    Returns:
        The same tensor object that was passed in.
    """
    T = wav.shape[1]
    if T == 0:
        # Nothing to mask; randint below would raise on an empty range.
        return wav

    # Cap the window at the clip length so clips shorter than 0.1 s do not
    # make random.randint(0, T - window_len) raise ValueError.
    window_len = min(int(0.1 * SAMPLE_RATE), T)
    k = 5  # number of windows to apply augmentation

    zero_cutoff = p_replace_orig + p_replace_zero
    noise_cutoff = zero_cutoff + p_replace_diff
    for _ in range(k):
        start = random.randint(0, T - window_len)
        end = start + window_len
        choice = random.random()
        if choice < p_replace_orig:
            continue  # keep this window unchanged
        if choice < zero_cutoff:
            wav[:, start:end] = 0.0
        elif choice < noise_cutoff:
            wav[:, start:end] = 0.1 * torch.randn_like(wav[:, start:end])
        # else: leftover probability mass -> window stays unchanged
    return wav
# Robustness augmentations removed; the previous implementation is kept below for reference.

# def robustness_augmentations(wav):
#     """
#     Adds small random noise for robustness.
#     """
#     return wav + 0.005 * torch.randn_like(wav)


In [3]:
def make_conv1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1):
    """Build an ``nn.Conv1d``; the defaults (kernel 3, stride 1, padding 1) preserve sequence length."""
    conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding)
    return nn.Conv1d(in_channels=in_ch, out_channels=out_ch, **conv_options)

class ResidualBlock(nn.Module):
    """
    Residual block made of two kernel-3 1-D convolutions with ELU activations.

    The first convolution applies ``stride``; the second keeps the length.
    When the stride is not 1 or the channel count changes, the shortcut is a
    strided 1x1 convolution so that it matches the main path's shape;
    otherwise the input is added back unchanged.
    """
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # An identity shortcut is only valid when neither length nor width changes.
        self.downsample = not (stride == 1 and in_ch == out_ch)
        self.conv1 = make_conv1d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)
        self.conv2 = make_conv1d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.elu = nn.ELU()
        if self.downsample:
            self.skip_conv = make_conv1d(in_ch, out_ch, kernel_size=1, stride=stride, padding=0)

    def forward(self, x):
        """Return ``ELU(conv2(ELU(conv1(x))) + shortcut(x))``."""
        main_path = self.conv2(self.elu(self.conv1(x)))
        shortcut = self.skip_conv(x) if self.downsample else x
        return self.elu(main_path + shortcut)


In [4]:
class Generator(nn.Module):
    """
    Encoder / LSTM / decoder network that predicts an additive perturbation.

    Encoder: a kernel-7 convolution, then one strided ``ResidualBlock`` per
    entry of ``strides`` (each doubling the channel count), a linear
    projection to ``hidden_dim``, a unidirectional LSTM with ``LSTM_LAYERS``
    layers, and a kernel-7 convolution to ``output_channels``.

    Decoder: for each stride in reverse order, a transposed convolution that
    halves the channel count followed by a stride-1 ``ResidualBlock``; a final
    kernel-7 convolution produces a single channel.
    """
    def __init__(self,
                 in_channels=1,
                 base_channels=CHANNELS,
                 hidden_dim=HIDDEN_DIM,
                 output_channels=OUTPUT_CH,
                 strides=STRIDES):
        super().__init__()

        # ------------------- Encoder ------------------- #
        self.init_conv = nn.Conv1d(in_channels, base_channels, kernel_size=7, stride=1, padding=3)

        encoder_layers = []
        width = base_channels
        for stride in strides:
            encoder_layers.append(ResidualBlock(width, width * 2, stride=stride))
            width = width * 2
        self.encoder_blocks = nn.Sequential(*encoder_layers)

        # Per-frame projection from encoder width down to the LSTM size
        self.proj = nn.Linear(width, hidden_dim)

        # Temporal model over the encoded frames
        self.lstm = nn.LSTM(input_size=hidden_dim, hidden_size=hidden_dim,
                            num_layers=LSTM_LAYERS, batch_first=True, bidirectional=False)

        self.final_conv_enc = nn.Conv1d(hidden_dim, output_channels, kernel_size=7, stride=1, padding=3)

        # ------------------- Decoder ------------------- #
        decoder_layers = []
        width = output_channels
        for stride in reversed(strides):
            narrower = width // 2
            decoder_layers += [
                nn.ConvTranspose1d(width, narrower, kernel_size=2 * stride, stride=stride,
                                   padding=(stride // 2), output_padding=0),
                ResidualBlock(narrower, narrower, stride=1),
            ]
            width = narrower
        self.decoder_blocks = nn.Sequential(*decoder_layers)

        # Single-channel output: the watermark perturbation
        self.final_conv_dec = nn.Conv1d(width, 1, kernel_size=7, stride=1, padding=3)

    def forward(self, s):
        """
        Predict the perturbation for a batch of waveforms.

        s: shape (B, 1, T)
        Output: delta of shape (B, 1, T); the decoder output is cropped or
        right-zero-padded so its length equals T.
        """
        _, _, target_len = s.shape

        # Encode: (B, 1, T) -> (B, ch, T_enc) -> (B, T_enc, hidden_dim)
        encoded = self.encoder_blocks(self.init_conv(s))
        frames = self.proj(encoded.transpose(1, 2))

        # LSTM works on (B, T_enc, hidden_dim); convs need channels first again
        sequence, _ = self.lstm(frames)
        latent = self.final_conv_enc(sequence.transpose(1, 2))

        # Decode back towards the input resolution
        delta = self.final_conv_dec(self.decoder_blocks(latent))

        # Match original length if needed
        produced_len = delta.shape[-1]
        if produced_len > target_len:
            delta = delta[:, :, :target_len]
        elif produced_len < target_len:
            delta = F.pad(delta, (0, target_len - produced_len))

        return delta


In [5]:
class Detector(nn.Module):
    """
    Convolutional encoder/decoder that outputs a per-sample detection probability.

    The encoder mirrors the generator's: a kernel-7 convolution followed by one
    strided ``ResidualBlock`` per entry of ``strides``, each doubling the
    channels. The upsampling path applies, per reversed stride, a transposed
    convolution that halves the channels plus a stride-1 ``ResidualBlock``.
    A final kernel-7 convolution and a sigmoid give one value per time step.
    """
    def __init__(self,
                 in_channels=1,
                 base_channels=CHANNELS,
                 hidden_dim=HIDDEN_DIM,
                 strides=STRIDES):
        super().__init__()

        self.init_conv = nn.Conv1d(in_channels, base_channels, kernel_size=7, stride=1, padding=3)

        down_layers = []
        width = base_channels
        for stride in strides:
            down_layers.append(ResidualBlock(width, width * 2, stride=stride))
            width = width * 2
        self.encoder_blocks = nn.Sequential(*down_layers)

        up_layers = []
        for stride in reversed(strides):
            narrower = width // 2
            up_layers += [
                nn.ConvTranspose1d(width, narrower, kernel_size=2 * stride, stride=stride,
                                   padding=(stride // 2), output_padding=0),
                ResidualBlock(narrower, narrower, stride=1),
            ]
            width = narrower
        self.upsample_blocks = nn.Sequential(*up_layers)

        # One output channel: the detection logit per time step
        self.final_conv = nn.Conv1d(base_channels, 1, kernel_size=7, stride=1, padding=3)

    def forward(self, x):
        """
        x: shape (B, 1, T)
        Output: shape (B, 1, T) in [0,1] -> detection probability over time
        """
        target_len = x.shape[-1]
        features = self.upsample_blocks(self.encoder_blocks(self.init_conv(x)))
        logits = self.final_conv(features)

        # Crop or right-zero-pad the logits to the input length before the sigmoid
        produced_len = logits.shape[-1]
        if produced_len > target_len:
            logits = logits[:, :, :target_len]
        elif produced_len < target_len:
            logits = F.pad(logits, (0, target_len - produced_len))

        return torch.sigmoid(logits)


In [6]:
import torchaudio.transforms as T

# Simple Mel Loss
class SimpleMelLoss(nn.Module):
    """
    L1 distance between log-mel spectrograms of two waveforms.

    A single ``MelSpectrogram`` transform (hop length ``n_fft // 4``,
    ``normalized=True``) is applied to both inputs; 1e-5 is added before the
    logarithm to keep it finite.
    """
    def __init__(self, sample_rate=SAMPLE_RATE, n_fft=1024, n_mels=80):
        super().__init__()
        self.mel_spec = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=n_fft // 4,
            n_mels=n_mels,
            normalized=True
        )

    def _log_mel(self, waveform):
        """Log-compressed mel spectrogram of ``waveform``."""
        return torch.log(self.mel_spec(waveform) + 1e-5)

    def forward(self, original, watermarked):
        """Return the mean absolute difference of the two log-mel spectrograms."""
        return F.l1_loss(self._log_mel(original), self._log_mel(watermarked))

# TF-Loudness Loss
class TFLoudnessLoss(nn.Module):
    """
    Band-wise time-frequency loss between an original and a watermarked signal.

    Both signals are transformed with a Hann-windowed STFT. The frequency axis
    is split into ``n_bands`` contiguous bands (the last band absorbs any
    remainder bins) and, per band, three weighted terms are accumulated:

    - L1 distance between log10 band energies per frame ("loudness")
    - MSE between magnitudes ("spectral")
    - mean of ``1 - cos(phase difference)`` ("phase")

    Bands ``n_bands // 3`` up to (but excluding) ``2 * (n_bands // 3)`` are
    weighted 1.5, all others 1.0. The result is
    ``loudness + spectral + 0.2 * phase``, each averaged over the bands.
    """
    def __init__(self, n_bands=8, window_size=2048, hop_size=512):
        super().__init__()
        self.n_bands = n_bands
        self.win_size = window_size
        self.hop_size = hop_size

        third = n_bands // 3
        per_band_weight = torch.ones(n_bands)
        per_band_weight[third:2 * third] = 1.5
        self.register_buffer('band_weights', per_band_weight)

    def _stft(self, signal, window):
        """Complex STFT of a (B, 1, T) signal, computed on the squeezed (B, T) tensor."""
        return torch.stft(
            signal.squeeze(1), n_fft=self.win_size, hop_length=self.hop_size,
            window=window, return_complex=True, normalized=True
        )

    def forward(self, original, watermarked):
        """Return the scalar combined loudness/spectral/phase loss."""
        window = torch.hann_window(self.win_size, device=original.device)
        spec_ref = self._stft(original, window)
        spec_wm = self._stft(watermarked, window)

        mag_ref, mag_wm = spec_ref.abs(), spec_wm.abs()
        phase_ref, phase_wm = spec_ref.angle(), spec_wm.angle()

        n_bins = mag_ref.shape[1]
        band_width = n_bins // self.n_bands
        last_band = self.n_bands - 1

        loudness_term = 0.0
        spectral_term = 0.0
        phase_term = 0.0

        for band in range(self.n_bands):
            lo = band * band_width
            # The final band runs to the end so no frequency bin is dropped.
            hi = n_bins if band == last_band else lo + band_width
            weight = self.band_weights[band]

            ref_band = mag_ref[:, lo:hi, :]
            wm_band = mag_wm[:, lo:hi, :]

            log_energy_ref = torch.log10(torch.sum(ref_band ** 2, dim=1) + 1e-8)
            log_energy_wm = torch.log10(torch.sum(wm_band ** 2, dim=1) + 1e-8)

            loudness_term = loudness_term + weight * F.l1_loss(log_energy_wm, log_energy_ref)
            spectral_term = spectral_term + weight * F.mse_loss(wm_band, ref_band)

            phase_gap = 1.0 - torch.cos(phase_wm[:, lo:hi, :] - phase_ref[:, lo:hi, :])
            phase_term = phase_term + weight * phase_gap.mean()

        loudness_term = loudness_term / self.n_bands
        spectral_term = spectral_term / self.n_bands
        phase_term = phase_term / self.n_bands

        return loudness_term + spectral_term + 0.2 * phase_term

# Simple Adversarial Loss
class AdversarialLoss(nn.Module):
    """
    Waveform discriminator bundled with its own Adam optimizer.

    The discriminator is a stack of five 1-D convolutions with LeakyReLU(0.2)
    between them, ending in a single-channel logit map. ``forward`` optionally
    performs one discriminator update and always returns the generator-side
    loss (BCE of the discriminator's logits on ``watermarked`` against ones).
    """
    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride); padding is kernel_size // 2
        conv_specs = [
            (1, 16, 15, 1),
            (16, 32, 41, 4),
            (32, 64, 41, 4),
            (64, 128, 41, 4),
            (128, 1, 41, 4),
        ]
        layers = []
        for position, (c_in, c_out, kernel, stride) in enumerate(conv_specs):
            layers.append(nn.Conv1d(c_in, c_out, kernel_size=kernel, stride=stride, padding=kernel // 2))
            if position < len(conv_specs) - 1:
                layers.append(nn.LeakyReLU(0.2))
        self.discriminator = nn.Sequential(*layers)
        self.disc_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)

    @staticmethod
    def _bce_against(logits, real):
        """BCE-with-logits against an all-ones (``real=True``) or all-zeros target."""
        target = torch.ones_like(logits) if real else torch.zeros_like(logits)
        return F.binary_cross_entropy_with_logits(logits, target)

    def forward(self, original, watermarked, train_disc=True):
        """
        Args:
            original: Real waveforms, labelled 1 for the discriminator.
            watermarked: Generated waveforms, labelled 0 for the discriminator.
            train_disc: When true, run one optimizer step on the discriminator
                first (``watermarked`` is detached for that step).

        Returns:
            Scalar generator loss computed with the current discriminator weights.
        """
        if train_disc:
            self.disc_optimizer.zero_grad()
            loss_on_real = self._bce_against(self.discriminator(original), real=True)
            loss_on_fake = self._bce_against(self.discriminator(watermarked.detach()), real=False)
            (loss_on_real + loss_on_fake).backward()
            self.disc_optimizer.step()

        # Generator objective: make the discriminator call the watermarked audio real.
        return self._bce_against(self.discriminator(watermarked), real=True)

def masked_localization_loss(detector_out, mask, smooth_eps=0.1):
    """
    Focal-weighted, label-smoothed binary cross-entropy averaged over all elements.

    Args:
        detector_out: Detection probabilities in [0, 1], e.g. shape (B, 1, T).
        mask: Same-shaped hard labels (1 = watermarked, 0 = clean).
        smooth_eps: Label smoothing amount; targets become ``1 - smooth_eps``
            for positives and ``smooth_eps`` for negatives.

    Returns:
        Scalar tensor: mean of ``(1 - p_t)^2 * BCE(detector_out, smoothed target)``,
        where ``p_t`` is the probability assigned to the hard label.
    """
    soft_target = (1.0 - mask) * smooth_eps + mask * (1.0 - smooth_eps)

    # Probability the detector gave to the correct (un-smoothed) class
    prob_of_truth = torch.where(mask > 0.5, detector_out, 1 - detector_out)
    per_element_bce = F.binary_cross_entropy(detector_out, soft_target, reduction='none')

    # Focal modulation: confident, correct predictions contribute less
    weighted = ((1 - prob_of_truth) ** 2) * per_element_bce
    return weighted.mean()


In [None]:
# def train_one_epoch(generator, detector, train_loader, optimizer, epoch, total_epochs, device):
#     generator.train()
#     detector.train()
#     total_loss = 0.0
#     total_steps = len(train_loader)
    
#     ms_mel_loss = SimpleMelLoss().to(device)
#     adv_loss_stub = AdversarialLoss().to(device)
#     tf_loud_loss = TFLoudnessLoss().to(device)
    
#     pbar = tqdm(enumerate(train_loader), total=total_steps, desc=f"Epoch [{epoch}/{total_epochs}]")
#     for i, batch_data in pbar:
#         s = batch_data.to(device)  # shape: (B, 1, T)
#         B = s.shape[0]
        
#         # 1) Generate watermarked audio => "positive"
#         delta = generator(s)  
#         s_w   = s + delta  
        
#         # Optional data augmentations on watermarked audio
#         for b_idx in range(B):
#             s_w[b_idx] = watermark_masking_augmentation(s_w[b_idx])
        
#         # 2) "Negative" examples = the clean audio 's' (no watermark)
#         #    We can optionally do some augmentations on s if desired, 
#         #    but typically you leave it as normal audio.
        
#         # 3) Combine positives & negatives for the DETECTOR
#         detector_input = torch.cat([s_w, s], dim=0)  # shape (2B,1,T)
#         # Build ground-truth label mask => 1 for watermarked, 0 for clean
#         label_mask = torch.cat([
#             torch.ones_like(s),
#             torch.zeros_like(s)
#         ], dim=0).to(device)  # shape (2B,1,T)
        
#         det_out = detector(detector_input)  # shape (2B,1,T)
        
#         # -------------------- Compute Losses -------------------- #
#         # (A) L1
#         loss_l1     = F.l1_loss(delta, torch.zeros_like(delta))
#         # (B) Mel
#         loss_msspec = ms_mel_loss(s, s_w)
#         # (C) Adversarial
#         loss_adv    = adv_loss_stub(s, s_w, train_disc=True)
#         # (D) Loudness
#         loss_loud   = tf_loud_loss(s, s_w)
#         # (E) Detector BCE
#         loss_loc    = masked_localization_loss(det_out, label_mask)

#         # Composite
#         loss = (lambda_L1     * loss_l1 +
#                 lambda_msspec * loss_msspec +
#                 lambda_adv    * loss_adv +
#                 lambda_loud   * loss_loud +
#                 lambda_loc    * loss_loc)
        
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
        
#         total_loss += loss.item()
#         pbar.set_postfix({"loss": f"{loss.item():.4f}"})
    
#     avg_loss = total_loss / total_steps
#     return avg_loss


### <span style="color:rgb(255, 0, 128)">Logging individual losses</span>

In [7]:
def train_one_epoch(generator, detector, train_loader, optimizer, epoch, total_epochs, device,
                    adv_loss=None):
    """
    Run one joint generator + detector training pass over ``train_loader``.

    For every batch the generator produces a perturbation ``delta``; the
    watermarked audio is ``s + delta``. The detector is trained on a masked
    (augmented) copy of the watermarked audio labelled 1 together with the
    clean audio labelled 0. The perceptual losses (L1 on delta, mel,
    adversarial, loudness) are computed on the un-augmented watermarked audio.

    Args:
        generator: Model mapping (B, 1, T) audio to a (B, 1, T) perturbation.
        detector: Model mapping (N, 1, T) audio to (N, 1, T) probabilities.
        train_loader: Iterable of (B, 1, T) batches with a defined ``len``.
        optimizer: Optimizer stepping the generator and detector parameters.
        epoch: Current epoch number (progress-bar label only).
        total_epochs: Total number of epochs (progress-bar label only).
        device: Device the batches and loss modules are moved to.
        adv_loss: Optional ``AdversarialLoss`` instance to reuse. When ``None``
            (the default) a new one is built for this call, which means its
            discriminator starts untrained; pass the same instance on every
            call to keep training one discriminator across epochs.

    Returns:
        Mean composite loss over the batches, or 0.0 if the loader is empty.
    """
    generator.train()
    detector.train()
    total_loss = 0.0
    total_steps = len(train_loader)

    ms_mel_loss = SimpleMelLoss().to(device)
    adv_loss_stub = adv_loss if adv_loss is not None else AdversarialLoss().to(device)
    tf_loud_loss = TFLoudnessLoss().to(device)

    pbar = tqdm(enumerate(train_loader), total=total_steps, desc=f"Epoch [{epoch}/{total_epochs}]")
    for i, batch_data in pbar:
        s = batch_data.to(device)  # shape: (B, 1, T)
        B = s.shape[0]

        # 1) Generate watermarked audio ("positive")
        delta = generator(s)
        s_w   = s + delta

        # 2) Masking augmentation for the detector input only. It runs on a
        #    copy because the augmentation writes in place: applying it to
        #    s_w itself would make the perceptual losses below penalise the
        #    zeroed / noised windows, which the generator cannot influence.
        #    clone() is differentiable, so detector gradients still reach the
        #    generator through the untouched samples.
        s_w_aug = s_w.clone()
        for b_idx in range(B):
            s_w_aug[b_idx] = watermark_masking_augmentation(s_w_aug[b_idx])

        # 3) Positives (augmented watermarked) and negatives (clean) for the detector
        detector_input = torch.cat([s_w_aug, s], dim=0)  # shape (2B, 1, T)
        # NOTE(review): positives are labelled 1 at every sample, including the
        # windows the augmentation overwrote - confirm this is intended.
        label_mask = torch.cat([
            torch.ones_like(s),
            torch.zeros_like(s)
        ], dim=0)  # shape (2B, 1, T), already on s's device

        det_out = detector(detector_input)  # shape (2B, 1, T)

        # -------------------- Compute Individual Losses -------------------- #
        loss_l1     = F.l1_loss(delta, torch.zeros_like(delta))
        loss_msspec = ms_mel_loss(s, s_w)
        loss_adv    = adv_loss_stub(s, s_w, train_disc=True)
        loss_loud   = tf_loud_loss(s, s_w)
        loss_loc    = masked_localization_loss(det_out, label_mask)

        # Composite loss
        loss = (lambda_L1     * loss_l1 +
                lambda_msspec * loss_msspec +
                lambda_adv    * loss_adv +
                lambda_loud   * loss_loud +
                lambda_loc    * loss_loc)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pbar.set_postfix({
            "Total": f"{loss.item():.4f}",
            "L1": f"{loss_l1.item():.4f}",
            "Mel": f"{loss_msspec.item():.4f}",
            "Adv": f"{loss_adv.item():.4f}",
            "Loud": f"{loss_loud.item():.4f}",
            "Loc": f"{loss_loc.item():.4f}"
        })

    # Guard against an empty loader instead of raising ZeroDivisionError
    avg_loss = total_loss / total_steps if total_steps > 0 else 0.0
    return avg_loss

In [None]:
# def validate_one_epoch(generator, detector, val_loader, device):
#     generator.eval()
#     detector.eval()
#     total_loss = 0.0
#     steps = 0
    
#     ms_mel_loss = SimpleMelLoss().to(device)
#     adv_loss_stub = AdversarialLoss().to(device)
#     tf_loud_loss = TFLoudnessLoss().to(device)
    
#     with torch.no_grad():
#         for batch_data in tqdm(val_loader, desc="Validation", leave=False):
#             s = batch_data.to(device)  # (B,1,T)
#             B = s.shape[0]

#             # Watermarked => positive
#             delta = generator(s)
#             s_w   = s + delta

#             # Combine with negative => clean
#             detector_input = torch.cat([s_w, s], dim=0)  # shape (2B,1,T)
#             label_mask = torch.cat([
#                 torch.ones_like(s),
#                 torch.zeros_like(s)
#             ], dim=0).to(device)

#             det_out = detector(detector_input)

#             # (A) L1
#             loss_l1   = F.l1_loss(delta, torch.zeros_like(delta))
#             # (B) Mel
#             loss_msspec = ms_mel_loss(s, s_w)
#             # (C) Adv
#             loss_adv  = adv_loss_stub(s, s_w, train_disc=False)
#             # (D) Loud
#             loss_loud = tf_loud_loss(s, s_w)
#             # (E) Detector BCE
#             loss_loc  = masked_localization_loss(det_out, label_mask)

#             # Composite
#             loss = (lambda_L1     * loss_l1 +
#                     lambda_msspec * loss_msspec +
#                     lambda_adv    * loss_adv +
#                     lambda_loud   * loss_loud +
#                     lambda_loc    * loss_loc)
            
#             total_loss += loss.item()
#             steps += 1
    
#     avg_loss = total_loss / steps if steps > 0 else 0.0
#     return avg_loss


# def train_model(generator, detector, train_dataset, val_dataset, num_epochs=10, lr=LR):
#     train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
#     val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
#     optimizer = optim.Adam(list(generator.parameters()) + list(detector.parameters()), lr=lr)
    
#     for epoch in range(1, num_epochs+1):
#         train_loss = train_one_epoch(generator, detector, train_loader, optimizer, epoch, num_epochs, device)
#         val_loss   = validate_one_epoch(generator, detector, val_loader, device)
#         print(f"Epoch [{epoch}/{num_epochs}]  TRAIN Loss: {train_loss:.4f}  |  VAL Loss: {val_loss:.4f}")


### <span style="color:rgb(255, 0, 128)">Logging individual losses</span>

In [8]:
def validate_one_epoch(generator, detector, val_loader, device, adv_loss=None):
    """
    Evaluate the composite loss and its components on ``val_loader``.

    Runs under ``torch.no_grad()`` with both models in eval mode, without the
    masking augmentation and without training the discriminator. Prints the
    averaged individual losses and returns the averaged composite loss.

    Args:
        generator: Model mapping (B, 1, T) audio to a (B, 1, T) perturbation.
        detector: Model mapping (N, 1, T) audio to (N, 1, T) probabilities.
        val_loader: Iterable of (B, 1, T) batches.
        device: Device the batches and loss modules are moved to.
        adv_loss: Optional ``AdversarialLoss`` instance to score against. When
            ``None`` (the default) a new, untrained one is built, so the
            reported "Adv" value comes from a randomly initialised
            discriminator.

    Returns:
        Mean composite loss over the batches, or 0.0 if the loader is empty.
    """
    generator.eval()
    detector.eval()

    ms_mel_loss = SimpleMelLoss().to(device)
    adv_loss_stub = adv_loss if adv_loss is not None else AdversarialLoss().to(device)
    tf_loud_loss = TFLoudnessLoss().to(device)

    # Running sums of the composite loss and of each component
    totals = {"Total": 0.0, "L1": 0.0, "Mel": 0.0, "Adv": 0.0, "Loud": 0.0, "Loc": 0.0}
    steps = 0

    pbar = tqdm(val_loader, desc="Validation", leave=False)
    with torch.no_grad():
        for batch_data in pbar:
            s = batch_data.to(device)  # (B, 1, T)

            # Generate watermarked audio ("positive")
            delta = generator(s)
            s_w   = s + delta

            # Combine watermarked (positive) and clean (negative) examples
            detector_input = torch.cat([s_w, s], dim=0)  # shape (2B, 1, T)
            label_mask = torch.cat([
                torch.ones_like(s),
                torch.zeros_like(s)
            ], dim=0)

            det_out = detector(detector_input)

            # Compute individual losses
            loss_l1     = F.l1_loss(delta, torch.zeros_like(delta))
            loss_msspec = ms_mel_loss(s, s_w)
            loss_adv    = adv_loss_stub(s, s_w, train_disc=False)
            loss_loud   = tf_loud_loss(s, s_w)
            loss_loc    = masked_localization_loss(det_out, label_mask)

            # Composite loss
            loss = (lambda_L1     * loss_l1 +
                    lambda_msspec * loss_msspec +
                    lambda_adv    * loss_adv +
                    lambda_loud   * loss_loud +
                    lambda_loc    * loss_loc)

            batch_values = {
                "Total": loss.item(),
                "L1": loss_l1.item(),
                "Mel": loss_msspec.item(),
                "Adv": loss_adv.item(),
                "Loud": loss_loud.item(),
                "Loc": loss_loc.item(),
            }
            for key, value in batch_values.items():
                totals[key] += value
            steps += 1

            pbar.set_postfix({key: f"{value:.4f}" for key, value in batch_values.items()})

    # An empty validation split (e.g. int(0.1 * n) == 0) must not crash the
    # summary with ZeroDivisionError; report zeros instead.
    averages = {key: (total / steps if steps > 0 else 0.0) for key, total in totals.items()}
    avg_loss = averages["Total"]
    print(f"Validation - Total: {avg_loss:.4f}, L1: {averages['L1']:.4f}, Mel: {averages['Mel']:.4f}, "
          f"Adv: {averages['Adv']:.4f}, Loud: {averages['Loud']:.4f}, Loc: {averages['Loc']:.4f}")
    return avg_loss


def train_model(generator, detector, train_dataset, val_dataset, num_epochs=10, lr=LR):
    """
    Jointly train the generator and detector, validating after every epoch.

    Builds a shuffled training loader and an unshuffled validation loader
    (both with ``BATCH_SIZE`` and ``NUM_WORKERS``), a single Adam optimizer
    over the parameters of both models, and then alternates
    ``train_one_epoch`` / ``validate_one_epoch`` for ``num_epochs`` epochs,
    printing both losses each time.

    NOTE(review): the loss modules (including the adversarial discriminator)
    are created inside the per-epoch helpers rather than here, so this
    function keeps no loss-module state between epochs.
    """
    loader_options = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

    joint_parameters = [*generator.parameters(), *detector.parameters()]
    optimizer = optim.Adam(joint_parameters, lr=lr)

    for epoch_index in range(num_epochs):
        epoch = epoch_index + 1
        train_loss = train_one_epoch(generator, detector, train_loader, optimizer, epoch, num_epochs, device)
        val_loss = validate_one_epoch(generator, detector, val_loader, device)
        print(f"Epoch [{epoch}/{num_epochs}]  TRAIN Loss: {train_loss:.4f}  |  VAL Loss: {val_loss:.4f}")

In [9]:
import numpy as np

def compute_classification_metrics(y_true, y_score, threshold=0.5):
    """
    Binary classification metrics at a fixed decision threshold.

    Args:
        y_true: 1D numpy array of 0/1 ground-truth labels.
        y_score: 1D numpy array of predicted probabilities in [0, 1].
        threshold: A sample is predicted positive when its score is
            greater than or equal to this value.

    Returns:
        Dict with keys 'TPR', 'FPR' and 'Accuracy'. Each denominator carries
        a 1e-8 term so that empty classes give 0 rather than a division error.
    """
    eps = 1e-8
    predicted = (y_score >= threshold).astype(int)
    said_pos, said_neg = (predicted == 1), (predicted == 0)
    is_pos, is_neg = (y_true == 1), (y_true == 0)

    # Confusion-matrix counts
    tp = (said_pos & is_pos).sum()
    fp = (said_pos & is_neg).sum()
    tn = (said_neg & is_neg).sum()
    fn = (said_neg & is_pos).sum()

    metrics = {}
    metrics['TPR'] = tp / (tp + fn + eps)  # recall / sensitivity
    metrics['FPR'] = fp / (fp + tn + eps)
    metrics['Accuracy'] = (tp + tn) / (tp + tn + fp + fn + eps)
    return metrics

def compute_auc(y_true, y_score):
    """
    Approximate ROC AUC by sweeping decision thresholds.

    TPR and FPR are evaluated at 50 evenly spaced thresholds in [0, 1] plus
    one threshold above every finite score (so the curve always contains the
    "predict nothing" point), and the area under the resulting piecewise-linear
    curve is integrated with the trapezoid rule.

    If you have scikit-learn, you could use:
        from sklearn.metrics import roc_auc_score
        return roc_auc_score(y_true, y_score)

    Args:
        y_true: Array of 0/1 ground-truth labels.
        y_score: Array of predicted probabilities in [0, 1], same length.

    Returns:
        The approximate area under the ROC curve as a float.
    """
    labels = np.asarray(y_true).reshape(-1)
    scores = np.asarray(y_score).reshape(-1)

    # inf as the last threshold guarantees a (FPR=0, TPR=0) point even when
    # some scores are exactly 1.0.
    thresholds = np.append(np.linspace(0, 1, 50), np.inf)

    is_pos = labels == 1
    is_neg = labels == 0

    # Row t holds the predictions made at thresholds[t]; same ">=" rule and
    # same 1e-8 denominators as compute_classification_metrics.
    predicted_pos = scores[None, :] >= thresholds[:, None]
    tprs = (predicted_pos & is_pos).sum(axis=1) / (is_pos.sum() + 1e-8)
    fprs = (predicted_pos & is_neg).sum(axis=1) / (is_neg.sum() + 1e-8)

    # Order points along the curve: by FPR first and, among equal FPRs, by
    # TPR. Sorting on FPR alone leaves tied points in arbitrary order, and the
    # trapezoid leaving the tie group may then start from the wrong TPR
    # (a perfect detector could score ~0.5 instead of ~1.0).
    order = np.lexsort((tprs, fprs))
    fprs = fprs[order]
    tprs = tprs[order]

    # np.trapz is deprecated in NumPy 2 in favour of np.trapezoid
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    auc_value = trapezoid(tprs, fprs)
    return auc_value

def evaluate_detector(generator, detector, dataset, device, batch_size=16):
    """
    Score the detector on watermarked versus clean versions of every clip.

    Each batch is watermarked with the generator (label 1) and also passed
    through unchanged (label 0). The detector's per-sample probabilities are
    averaged over time to give one score per clip. TPR, FPR and accuracy are
    reported at threshold 0.5 together with the AUC.

    Args:
        generator: Model mapping (B, 1, T) audio to a (B, 1, T) perturbation.
        detector: Model mapping (N, 1, T) audio to (N, 1, T) probabilities.
        dataset: Dataset yielding (1, T) clips.
        device: Device the batches are moved to.
        batch_size: Batch size of the (unshuffled) evaluation loader.

    Returns:
        Dict with keys "TPR", "FPR", "Accuracy" and "AUC".
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    generator.eval()
    detector.eval()

    label_chunks = []
    score_chunks = []

    with torch.no_grad():
        for clean in loader:
            clean = clean.to(device)  # shape (B,1,T)
            n_clips = clean.shape[0]

            watermarked = clean + generator(clean)

            # Stack as (2B,1,T): first B rows watermarked, last B rows clean
            probs = detector(torch.cat([watermarked, clean], dim=0))

            # Collapse the time axis -> one score per clip, shape (2B,)
            score_chunks.append(probs.mean(dim=2).squeeze(1).cpu().numpy())

            # Labels in the same order: B ones followed by B zeros
            label_chunks.append(np.repeat([1.0, 0.0], n_clips))

    y_true_all = np.concatenate(label_chunks, axis=0)
    y_score_all = np.concatenate(score_chunks, axis=0)

    at_half = compute_classification_metrics(y_true_all, y_score_all, threshold=0.5)
    auc_value = compute_auc(y_true_all, y_score_all)

    print("Detection Metrics @ threshold=0.5:")
    print(f"  TPR:      {at_half['TPR']:.3f}")
    print(f"  FPR:      {at_half['FPR']:.3f}")
    print(f"  Accuracy: {at_half['Accuracy']:.3f}")
    print(f"AUC: {auc_value:.3f}")

    summary = {key: at_half[key] for key in ("TPR", "FPR", "Accuracy")}
    summary["AUC"] = auc_value
    return summary


In [10]:
def evaluate_si_snr_torch(original, reconstructed, eps=1e-8):
    """
    Scale-invariant signal-to-noise ratio (SI-SNR) in dB, one value per clip.

    Both signals are mean-centred along time, ``reconstructed`` is projected
    onto ``original`` to obtain the target component, and the residual is
    treated as noise.

    Args:
        original: Reference signals, shape (B, T) or (B, 1, T).
        reconstructed: Signals to score, same shape as ``original``.
        eps: Small constant added to the projection denominator and to both
            energies of the final ratio.

    Returns:
        Tensor of shape (B,) with the SI-SNR of each clip in dB. The value is
        always finite: a silent reference no longer yields ``-inf``.
    """
    if original.dim() == 3:
        original = original.squeeze(1)
    if reconstructed.dim() == 3:
        reconstructed = reconstructed.squeeze(1)

    original_zm = original - original.mean(dim=1, keepdim=True)
    recon_zm    = reconstructed - reconstructed.mean(dim=1, keepdim=True)

    # Optimal scaling of the reference that best explains the reconstruction
    dot = (original_zm * recon_zm).sum(dim=1, keepdim=True)
    norm_sq = (original_zm ** 2).sum(dim=1, keepdim=True) + eps
    alpha = dot / norm_sq

    s_target = alpha * original_zm
    e_noise = recon_zm - s_target

    target_energy = (s_target ** 2).sum(dim=1)
    noise_energy = (e_noise ** 2).sum(dim=1)
    # eps on the numerator as well: with a silent reference target_energy is
    # exactly 0 and log10(0) = -inf would poison any average taken downstream.
    si_snr_val = 10 * torch.log10((target_energy + eps) / (noise_energy + eps))
    return si_snr_val

# def evaluate_pesq(original, reconstructed, sr=SAMPLE_RATE):
#     """
#     Placeholder: returns a dummy PESQ-like score in [1.0..4.5].
#     If you have the real PESQ library, replace with actual call.
#     """
#     return 4.5 - random.random()

def run_evaluation(generator, detector, dataset, device,
                   batch_size=16, compute_pesq_score=True, compute_si_snr_score=True):
    """
    Measure audio quality of the watermarked clips against the originals.

    Every clip is watermarked with the generator and, when
    ``compute_si_snr_score`` is true, its SI-SNR relative to the clean clip is
    collected; the average is printed and returned.

    ``compute_pesq_score`` is accepted but currently has no effect: PESQ
    scoring is disabled (see the commented-out ``evaluate_pesq`` placeholder).
    ``detector`` is only switched to eval mode.

    Returns:
        Dict with key "si_snr": the mean SI-SNR in dB, or 0.0 when no value
        was collected.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    generator.eval()
    detector.eval()

    per_clip_si_snr = []

    with torch.no_grad():
        for clean in loader:
            clean = clean.to(device)

            # Generate watermarked audio
            watermarked = clean + generator(clean)

            # Audio quality
            if compute_si_snr_score:
                batch_scores = evaluate_si_snr_torch(clean, watermarked).cpu().numpy()
                per_clip_si_snr.extend(batch_scores)

    if per_clip_si_snr:
        avg_si_snr = float(np.mean(per_clip_si_snr))
    else:
        avg_si_snr = 0.0

    print("\n--- Audio Quality Results ---")
    if compute_si_snr_score:
        print(f"Average SI-SNR: {avg_si_snr:.3f} dB")

    return {"si_snr": avg_si_snr}


In [11]:
def train_generator(generator, dataset, device, num_epochs=10, lr=1e-3,
                    lambda_L1=1.0, lambda_msspec=1.0, lambda_loud=0.5,
                    batch_size=64, num_workers=4):
    """
    Trains ONLY the Generator, ignoring adversarial or detection losses.
    Minimizes a combination of L1, MelSpectrogram, and Loudness losses
    so that watermarked audio remains close to original.

    Args:
        generator: your Generator model
        dataset: a torch Dataset or Subset (1-sec audio clips)
        device: 'cuda' or 'cpu'
        num_epochs: how many epochs to run
        lr: learning rate
        lambda_L1, lambda_msspec, lambda_loud: weighting factors
        batch_size: DataLoader batch size (default keeps the previous
            hard-coded value of 64)
        num_workers: DataLoader worker processes (default keeps the previous
            hard-coded value of 4)
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    gen_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)

    ms_mel_loss = SimpleMelLoss().to(device)
    tf_loud_loss = TFLoudnessLoss().to(device)

    generator.train()
    for epoch in range(1, num_epochs+1):
        total_loss = 0.0
        for s in loader:
            s = s.to(device)          # (B,1,T)
            delta = generator(s)      # (B,1,T)
            s_w = s + delta           # watermarked

            # No augmentation here: pre-training only teaches the generator
            # to keep the watermarked audio close to the original.

            # --- 1) L1 on delta
            loss_l1 = F.l1_loss(delta, torch.zeros_like(delta))

            # --- 2) Mel Spectrogram
            loss_msspec = ms_mel_loss(s, s_w)

            # --- 3) Loudness
            loss_loud = tf_loud_loss(s, s_w)

            loss = (lambda_L1 * loss_l1 +
                    lambda_msspec * loss_msspec +
                    lambda_loud * loss_loud)

            gen_optimizer.zero_grad()
            loss.backward()
            gen_optimizer.step()

            total_loss += loss.item()

        # max(..., 1) keeps the average defined even for a zero-length loader
        avg_loss = total_loss / max(len(loader), 1)
        print(f"[Gen-Only Epoch {epoch}/{num_epochs}]  Loss: {avg_loss:.4f}")

    print("Generator pre-training complete.")


In [None]:
if __name__ == "__main__":
    # End-to-end run: build the dataset, split it, train both models,
    # save their weights, then report detection and audio-quality metrics.
    data_root = "data/100_all"
    full_dataset = OneSecClipsDataset(root_dir=data_root, sample_rate=SAMPLE_RATE)

    # Fail early with a clear message; otherwise an empty dataset only
    # surfaces later as an opaque sampler/DataLoader error.
    if len(full_dataset) == 0:
        raise FileNotFoundError(f"No .wav files found under {data_root!r}")

    # Pick a subset for demonstration
    subset_size = 1000
    subset_indices = list(range(min(subset_size, len(full_dataset))))
    subset_dataset = torch.utils.data.Subset(full_dataset, subset_indices)

    # Split: 80% train, 10% val, 10% test (the test split takes the rounding remainder)
    n = len(subset_dataset)
    n_train = int(0.8 * n)
    n_val   = int(0.1 * n)
    n_test  = n - n_train - n_val
    train_ds, val_ds, test_ds = random_split(subset_dataset, [n_train, n_val, n_test])

    # Instantiate models
    generator = Generator().to(device)
    detector  = Detector().to(device)

    # Train
    num_epochs = 10
    train_model(generator, detector, train_ds, val_ds, num_epochs=num_epochs, lr=LR)

    # Save models
    torch.save(generator.state_dict(), "generator.pth")
    torch.save(detector.state_dict(),  "detector.pth")

    # Evaluate classification (watermarked vs. clean)
    print("\n--- Detector Classification Metrics ---")
    detection_metrics = evaluate_detector(generator, detector, test_ds, device, batch_size=16)
    print(detection_metrics)

    # Evaluate audio quality (SI-SNR; PESQ scoring is currently disabled)
    print("\n--- Audio Quality Metrics ---")
    run_evaluation(generator, detector, test_ds, device,
                   batch_size=16, compute_pesq_score=True, compute_si_snr_score=True)
