In [1]:
import os
import glob
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
def set_seed(seed=42):
    """Seed every RNG this notebook draws from: Python `random`, NumPy, torch (CPU and all CUDA devices)."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

set_seed(42)

In [3]:
# Audio: target sample rate (Hz) and fixed clip length (samples).
SAMPLE_RATE = 16000
AUDIO_LEN   = 16000  # 1 second; clips are cropped / zero-padded to this many samples
# Optimisation
BATCH_SIZE  = 64     # DataLoader batch size (train and validation)
LR          = 1e-3   # Adam learning rate (passed to train_model)
# Model shape
HIDDEN_DIM  = 32   # LSTM hidden dimension; also the generator's projection and message-embedding width
NUM_BITS    = 16   # message bits: messages are integers in [0, 2**NUM_BITS); the generator's embedding table has 2**NUM_BITS rows
CHANNELS    = 32   # initial conv channels; each encoder stage doubles them
OUTPUT_CH   = 128  # final conv channels for the generator; its decoder halves them once per stride (128 -> 8)
STRIDES     = [2, 4, 5, 8]  # downsampling strides; total factor 2*4*5*8 = 320, so 16000 samples -> 50 frames
LSTM_LAYERS = 2    # number of stacked LSTM layers
NUM_WORKERS = 2    # DataLoader worker processes
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

print("\nCHUNK #1: Environment and hyperparameters set.")

Using device: cuda

CHUNK #1: Environment and hyperparameters set.


In [4]:
class OneSecClipsDataset(Dataset):
    """
    Fixed-length mono clips loaded from every ``.wav`` file under ``root_dir``.

    Each item is a float tensor of shape ``(1, num_samples)``. Multi-channel audio
    is averaged to mono. Audio at another rate is resampled to ``sample_rate``.
    The result is then cropped, or right-padded with zeros, to exactly
    ``num_samples`` samples.

    Args:
        root_dir: directory searched recursively for ``*.wav`` files.
        sample_rate: target sample rate in Hz (default 16000).
        num_samples: clip length in samples after crop/pad; ``None`` means ``AUDIO_LEN``.
    """
    def __init__(self, root_dir, sample_rate=16000, num_samples=None):
        super().__init__()
        # Sorted so the item order, and therefore any seeded random_split, does not
        # depend on the filesystem's directory-listing order.
        self.filepaths = sorted(
            glob.glob(os.path.join(root_dir, '**', '*.wav'), recursive=True)
        )
        self.sample_rate = sample_rate
        self.num_samples = AUDIO_LEN if num_samples is None else num_samples
        # Resample transforms keyed by source rate. Building one precomputes its
        # filter kernel, so reuse it instead of rebuilding it for every item.
        self._resamplers = {}

    def __len__(self):
        return len(self.filepaths)

    def _get_resampler(self, orig_sr):
        """Return a cached Resample transform from ``orig_sr`` to ``self.sample_rate``."""
        resampler = self._resamplers.get(orig_sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=self.sample_rate)
            self._resamplers[orig_sr] = resampler
        return resampler

    def __getitem__(self, idx):
        waveform, sr = torchaudio.load(self.filepaths[idx])

        # Downmix multi-channel audio to mono, keeping the channel axis.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if sr != self.sample_rate:
            waveform = self._get_resampler(sr)(waveform)

        # Crop, or zero-pad on the right, to exactly num_samples.
        length = waveform.shape[1]
        if length > self.num_samples:
            waveform = waveform[:, :self.num_samples]
        elif length < self.num_samples:
            waveform = F.pad(waveform, (0, self.num_samples - length))

        return waveform  # shape: (1, num_samples)

###########################################
# Optional Simple Augmentations (example) #
###########################################
def watermark_masking_augmentation(wav, p_replace_zero=0.1, p_replace_noise=0.1,
                                   sample_rate=16000, window_sec=0.1, num_windows=2):
    """
    Randomly corrupt short segments of ``wav`` to simulate partial damage.

    ``num_windows`` times, a window of ``window_sec`` seconds is placed uniformly
    at random along the LAST (time) axis. Each window is then:

    * zeroed with probability ``p_replace_zero``,
    * replaced by Gaussian noise (std 0.1) with probability ``p_replace_noise``,
    * otherwise left untouched.

    The same window position applies across all leading dimensions, so both
    ``(C, T)`` clips and ``(B, C, T)`` batches are accepted. If the clip is
    shorter than one window, the window is clamped to the clip length.

    Args:
        wav: tensor whose last dimension is time. It is modified IN PLACE and
            also returned.
        p_replace_zero: probability of zeroing a window.
        p_replace_noise: probability of replacing a window with noise.
        sample_rate: samples per second, used to convert ``window_sec`` to samples.
        window_sec: window length in seconds.
        num_windows: number of windows to draw.

    Returns:
        ``wav`` (the same tensor object).
    """
    T = wav.shape[-1]
    # Clamp so randint's upper bound (T - window_len) is never negative.
    window_len = min(int(window_sec * sample_rate), T)
    for _ in range(num_windows):
        start = random.randint(0, T - window_len)
        end   = start + window_len
        choice = random.random()
        if choice < p_replace_zero:
            wav[..., start:end] = 0.0
        elif choice < p_replace_zero + p_replace_noise:
            wav[..., start:end] = 0.1 * torch.randn_like(wav[..., start:end])
        # else: leave this window untouched
    return wav

def robustness_augmentations(wav):
    """
    Return ``wav`` plus small additive Gaussian noise (std 0.005).

    The noise comes from torch's global RNG. A new tensor is returned; the input
    is left unmodified.
    """
    noise = torch.randn_like(wav) * 0.005
    return wav + noise

print("\nCHUNK #2: Dataset and example augmentations ready.")


CHUNK #2: Dataset and example augmentations ready.


In [5]:
def make_conv1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1):
    """Build an ``nn.Conv1d``. The defaults give a length-preserving 3-tap convolution."""
    return nn.Conv1d(
        in_channels=in_ch,
        out_channels=out_ch,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )

class ResidualBlock(nn.Module):
    """
    Two 3-tap convolutions (ELU in between), summed with a skip path, then ELU.

    The first convolution applies ``stride``. When the block changes the length
    (``stride != 1``) or the channel count, the skip path is a strided 1x1
    convolution (``skip_conv``), which yields the same output length as the main
    path. Otherwise the skip path is the identity.
    """
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        needs_projection = stride != 1 or in_ch != out_ch
        self.downsample = needs_projection
        # Construction order (conv1, conv2, skip_conv) is kept so seeded init is unchanged.
        self.conv1 = make_conv1d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)
        self.conv2 = make_conv1d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.elu   = nn.ELU()
        if needs_projection:
            self.skip_conv = make_conv1d(in_ch, out_ch, kernel_size=1, stride=stride, padding=0)

    def forward(self, x):
        main_path = self.conv2(self.elu(self.conv1(x)))
        shortcut = self.skip_conv(x) if self.downsample else x
        return self.elu(main_path + shortcut)

class Generator(nn.Module):
    """
    Watermark generator: maps audio ``s`` of shape (B, in_channels, T) to an
    additive perturbation ``delta`` of shape (B, 1, T). The watermarked signal
    is ``s + delta``.

    Pipeline:
      * Encoder: 7-tap conv to ``base_channels``, then one strided ``ResidualBlock``
        per entry of ``strides`` (each doubles the channels), then a per-frame
        Linear projection to ``hidden_dim``.
      * Message: when ``message`` is given, its learned embedding (a table with
        ``2**message_bits`` rows) is added to every frame.
      * An LSTM (``lstm_layers`` layers, ``hidden_dim`` units) runs over the
        frames, followed by a 7-tap conv to ``output_channels``.
      * Decoder: for each stride in reverse order, a ConvTranspose1d (upsampling
        by that stride and halving the channels) followed by a ``ResidualBlock``;
        then a 7-tap conv to 1 channel.
      * The output is cropped, or right-padded with zeros, to exactly T samples.

    Args:
        in_channels: channels of the input audio.
        base_channels: width after the first conv.
        hidden_dim: projection / message-embedding / LSTM width.
        message_bits: message size; message ids must lie in [0, 2**message_bits).
        output_channels: width entering the decoder.
        strides: encoder downsampling strides (the decoder uses them reversed).
        lstm_layers: number of stacked LSTM layers (defaults to ``LSTM_LAYERS``).
    """
    def __init__(self,
                 in_channels=1,
                 base_channels=CHANNELS,
                 hidden_dim=HIDDEN_DIM,
                 message_bits=NUM_BITS,
                 output_channels=OUTPUT_CH,
                 strides=STRIDES,
                 lstm_layers=LSTM_LAYERS):
        super().__init__()
        self.message_bits = message_bits
        self.hidden_dim   = hidden_dim

        # Embedding for integer messages: one learned vector per message id.
        self.E = nn.Embedding(num_embeddings=(2**message_bits), embedding_dim=hidden_dim)

        # --------- Encoder --------- #
        self.init_conv = nn.Conv1d(in_channels, base_channels, kernel_size=7, stride=1, padding=3)

        enc_blocks = []
        ch = base_channels
        for st in strides:
            out_ch = ch * 2
            enc_blocks.append(ResidualBlock(ch, out_ch, stride=st))
            ch = out_ch
        self.encoder_blocks = nn.Sequential(*enc_blocks)

        # Project encoder features (ch per frame after all enc blocks) to hidden_dim.
        self.proj = nn.Linear(ch, hidden_dim)

        # LSTM depth comes from the constructor rather than a hard-coded 2.
        self.lstm = nn.LSTM(input_size=hidden_dim, hidden_size=hidden_dim, num_layers=lstm_layers,
                            batch_first=True, bidirectional=False)

        self.final_conv_enc = nn.Conv1d(hidden_dim, output_channels, kernel_size=7, stride=1, padding=3)

        # --------- Decoder --------- #
        dec_blocks = []
        in_ch = output_channels
        for st in reversed(strides):
            out_ch = in_ch // 2
            dec_blocks.append(
                nn.ConvTranspose1d(in_ch, out_ch, kernel_size=2*st, stride=st,
                                   padding=(st//2), output_padding=0)
            )
            dec_blocks.append(ResidualBlock(out_ch, out_ch, stride=1))
            in_ch = out_ch
        self.decoder_blocks = nn.Sequential(*dec_blocks)

        self.final_conv_dec = nn.Conv1d(in_ch, 1, kernel_size=7, stride=1, padding=3)

    def forward(self, s, message=None):
        """
        Args:
            s: audio of shape (B, in_channels, T).
            message: optional integer message ids; assumed (B,) int64 in
                [0, 2**message_bits), as produced by ``generate_random_messages``.

        Returns:
            ``delta`` of shape (B, 1, T).
        """
        T = s.shape[-1]
        x = self.encoder_blocks(self.init_conv(s))   # (B, C_enc, T_after)
        frames = self.proj(x.transpose(1, 2))        # (B, T_after, hidden_dim)

        if message is not None:
            # One embedding per clip, broadcast over every frame.
            frames = frames + self.E(message).unsqueeze(1)

        lstm_out, _ = self.lstm(frames)
        latent = self.final_conv_enc(lstm_out.transpose(1, 2))
        delta = self.final_conv_dec(self.decoder_blocks(latent))

        # The transposed-conv stack does not necessarily reproduce T exactly
        # (with the default strides, 16000 samples come back as 16008), so
        # crop, then right-pad if the result is short.
        delta = delta[..., :T]
        if delta.shape[-1] < T:
            delta = F.pad(delta, (0, T - delta.shape[-1]))
        return delta

print("\nCHUNK #3: Generator with residual blocks, LSTM, message embedding defined.")


CHUNK #3: Generator with residual blocks, LSTM, message embedding defined.


In [6]:
class Detector(nn.Module):
    """
    Watermark detector: maps audio (B, in_channels, T) to per-sample logits
    of shape (B, 1 + message_bits, T).

    Channel 0 is the detection logit; channels 1..message_bits are per-bit
    logits. The network is:

      * a 7-tap conv to ``base_channels``,
      * one strided ``ResidualBlock`` per stride (doubling the channels),
      * one ConvTranspose1d + ``ResidualBlock`` per reversed stride (halving
        the channels back to ``base_channels``),
      * a 7-tap conv to ``1 + message_bits`` channels.

    The output is cropped, or right-padded with zeros, to T samples.
    ``hidden_dim`` is accepted but not used.
    """
    def __init__(self, 
                 in_channels=1, 
                 base_channels=CHANNELS,
                 hidden_dim=HIDDEN_DIM,
                 message_bits=NUM_BITS,
                 strides=STRIDES):
        super().__init__()
        self.message_bits = message_bits
        depth = len(strides)
        # Channel width at each scale: base, 2*base, 4*base, ...
        widths = [base_channels * (2 ** level) for level in range(depth + 1)]

        self.init_conv = nn.Conv1d(in_channels, base_channels, kernel_size=7, stride=1, padding=3)

        self.encoder_blocks = nn.Sequential(*[
            ResidualBlock(widths[level], widths[level + 1], stride=st)
            for level, st in enumerate(strides)
        ])

        up_layers = []
        for level, st in zip(range(depth, 0, -1), reversed(strides)):
            wide, narrow = widths[level], widths[level - 1]
            up_layers.append(
                nn.ConvTranspose1d(wide, narrow, kernel_size=2 * st, stride=st,
                                   padding=st // 2, output_padding=0)
            )
            up_layers.append(ResidualBlock(narrow, narrow, stride=1))
        self.upsample_blocks = nn.Sequential(*up_layers)

        # After upsampling the width is back to widths[0] == base_channels.
        self.final_conv = nn.Conv1d(widths[0], 1 + message_bits, kernel_size=7, stride=1, padding=3)

    def forward(self, x):
        target_len = x.shape[-1]
        features = self.upsample_blocks(self.encoder_blocks(self.init_conv(x)))
        logits = self.final_conv(features)

        # Upsampling need not land exactly on target_len: crop, then right-pad if short.
        logits = logits[..., :target_len]
        shortfall = target_len - logits.shape[-1]
        if shortfall > 0:
            logits = F.pad(logits, (0, shortfall))
        return logits  # shape (B, 1+bits, T)

print("\nCHUNK #4: Detector network defined.")


CHUNK #4: Detector network defined.


In [7]:
def generate_random_messages(batch_size, num_bits=NUM_BITS):
    """
    Sample one random message id per batch element.

    Ids are drawn uniformly from [0, 2^num_bits - 1] with torch's global RNG.
    They are returned as an int64 tensor of shape (batch_size,) on the default
    device.
    """
    shape = (batch_size,)
    exclusive_upper = 2 ** num_bits
    return torch.randint(0, exclusive_upper, shape)

def int_to_bit_tensor(msgs_int, num_bits=NUM_BITS):
    """
    Expand integer messages of shape (B,) into bit vectors of shape (B, num_bits).

    Column k holds bit k of each message as 0.0 / 1.0 float: column 0 is the
    LSB and column num_bits-1 is the MSB.
    """
    columns = [((msgs_int >> k) & 1).float() for k in range(num_bits)]
    return torch.stack(columns, dim=1)

print("\nCHUNK #5: Minimal bit helpers defined (no custom spectral or adversarial losses).")


CHUNK #5: Minimal bit helpers defined (no custom spectral or adversarial losses).


In [8]:
def train_one_epoch(
    generator,
    detector,
    train_loader,
    optimizer,
    epoch,
    total_epochs,
    device,
    num_bits=NUM_BITS,
    lambda_l1=1.0,
    lambda_det=1.0,
    lambda_bits=1.0,
):
    """
    Run one training epoch over ``train_loader`` and return the mean total loss.

    For each batch ``s`` of shape (B, 1, T):

      1. Draw random messages.
      2. Watermark: ``s_w = s + generator(s, msgs)``.
      3. Minimise the weighted sum of three terms:
         * ``lambda_l1``: L1 between ``s_w`` and ``s`` (distortion);
         * ``lambda_det``: BCE on the detection logit, with target 1 for
           ``s_w`` and 0 for clean ``s``;
         * ``lambda_bits``: BCE between the bit logits on ``s_w`` and the
           message bits.

    Detection and bit logits are averaged over time before the BCE.

    Clean audio is fed to the detector as a negative example. Without
    negatives every detection target is 1, and the loss can be driven to zero
    by a constant bias that ignores the input. Watermarked and clean clips go
    through the detector in one concatenated batch of 2B clips.

    Returns:
        float: average total loss per batch (0.0 for an empty loader).
    """
    generator.train()
    detector.train()

    l1_loss_fn = nn.L1Loss()
    bce_loss_fn = nn.BCEWithLogitsLoss()

    total_steps = len(train_loader)
    sum_loss = 0.0

    pbar = tqdm(train_loader, total=total_steps, desc=f"Epoch [{epoch}/{total_epochs}]")

    for s in pbar:
        s = s.to(device)  # (B,1,T)
        B = s.shape[0]

        # 1) Random integer messages, one per clip.
        msgs_int = generate_random_messages(B, num_bits=num_bits).to(device)

        # 2) The generator predicts an additive perturbation => watermarked audio.
        delta = generator(s, msgs_int)  # (B,1,T)
        s_w = s + delta
        # Channel-robustness augmentations (e.g. robustness_augmentations or
        # watermark_masking_augmentation) would be applied to s_w here.

        # 3) One detector pass: first B rows watermarked, last B rows clean.
        det_out = detector(torch.cat([s_w, s], dim=0))  # (2B, 1+num_bits, T)

        # *** Distortion Loss (L1) ***
        loss_dist = l1_loss_fn(s_w, s)

        # *** Detection Loss (BCE): watermarked => 1, clean => 0 ***
        detection_avg = det_out[:, 0, :].mean(dim=1)  # (2B,)
        label_det = torch.cat([torch.ones(B, device=device), torch.zeros(B, device=device)])
        loss_det = bce_loss_fn(detection_avg, label_det)

        # *** Bit Decoding Loss (BCE): watermarked clips only ***
        bit_avg = det_out[:B, 1:, :].mean(dim=2)  # (B, num_bits)
        bit_labels = int_to_bit_tensor(msgs_int, num_bits=num_bits).to(device)
        loss_bits = bce_loss_fn(bit_avg, bit_labels)

        loss = lambda_l1 * loss_dist + lambda_det * loss_det + lambda_bits * loss_bits

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        sum_loss += loss.item()
        pbar.set_postfix({
            "dist": f"{loss_dist.item():.3f}",
            "det": f"{loss_det.item():.3f}",
            "bits": f"{loss_bits.item():.3f}",
            "total": f"{loss.item():.3f}"
        })

    return sum_loss / max(1, total_steps)


def validate_one_epoch(
    generator,
    detector,
    val_loader,
    device,
    num_bits=NUM_BITS,
    lambda_l1=1.0,
    lambda_det=1.0,
    lambda_bits=1.0,
):
    """
    Evaluate the watermarking objective on ``val_loader`` without updating weights.

    The loss is the weighted sum of three terms:

      * ``lambda_l1``: L1 between ``s_w = s + generator(s, msgs)`` and ``s``;
      * ``lambda_det``: BCE on the time-averaged detection logit, with target 1
        for ``s_w`` and 0 for the clean ``s``;
      * ``lambda_bits``: BCE between the time-averaged bit logits on ``s_w``
        and the message bits.

    Clean clips are included as detection negatives. Without them, a detector
    that always answers "watermarked" would score a perfect detection loss.

    Both models are put in eval mode. Messages are drawn fresh from torch's
    global RNG on each call, so the value varies between calls unless the RNG
    is re-seeded.

    Returns:
        float: mean total loss per batch (0.0 for an empty loader).
    """
    generator.eval()
    detector.eval()

    l1_loss_fn = nn.L1Loss()
    bce_loss_fn = nn.BCEWithLogitsLoss()

    sum_loss = 0.0
    steps = 0

    with torch.no_grad():
        for s in tqdm(val_loader, desc="Validation", leave=False):
            s = s.to(device)
            B = s.shape[0]

            msgs_int = generate_random_messages(B, num_bits=num_bits).to(device)

            delta = generator(s, msgs_int)
            s_w   = s + delta
            # First B rows watermarked, last B rows clean.
            det_out = detector(torch.cat([s_w, s], dim=0))

            # Distortion
            loss_dist = l1_loss_fn(s_w, s)

            # Detection (watermarked => 1, clean => 0)
            detection_avg = det_out[:, 0, :].mean(dim=1)
            label_det = torch.cat([torch.ones(B, device=device), torch.zeros(B, device=device)])
            loss_det = bce_loss_fn(detection_avg, label_det)

            # Bit decoding (watermarked clips only)
            bit_avg = det_out[:B, 1:, :].mean(dim=2)
            bit_labels = int_to_bit_tensor(msgs_int, num_bits=num_bits).to(device)
            loss_bits = bce_loss_fn(bit_avg, bit_labels)

            loss = lambda_l1 * loss_dist + lambda_det * loss_det + lambda_bits * loss_bits
            sum_loss += loss.item()
            steps += 1

    return sum_loss / max(1, steps)

def train_model(
    generator,
    detector,
    train_dataset,
    val_dataset,
    num_epochs=10,
    lr=1e-3,
    num_bits=NUM_BITS
):
    """
    Jointly train ``generator`` and ``detector`` with a single Adam optimizer.

    Builds a shuffled training DataLoader and an unshuffled validation DataLoader
    (``BATCH_SIZE``, ``NUM_WORKERS``). Each epoch then runs ``train_one_epoch``
    followed by ``validate_one_epoch`` on the module-level ``device`` and prints
    both losses. The models are updated in place; nothing is returned.
    """
    shared = {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS}
    train_loader = DataLoader(train_dataset, shuffle=True, **shared)
    val_loader   = DataLoader(val_dataset, shuffle=False, **shared)

    joint_params = [*generator.parameters(), *detector.parameters()]
    optimizer = optim.Adam(joint_params, lr=lr)

    for epoch in range(1, num_epochs + 1):
        train_loss = train_one_epoch(
            generator, detector, train_loader, optimizer,
            epoch=epoch, total_epochs=num_epochs, device=device, num_bits=num_bits,
        )
        val_loss = validate_one_epoch(generator, detector, val_loader, device=device, num_bits=num_bits)
        print(f"Epoch [{epoch}/{num_epochs}]  TRAIN Loss: {train_loss:.4f}  |  VAL Loss: {val_loss:.4f}")

print("\nCHUNK #6: Updated training code with L1 + BCE complete.")


CHUNK #6: Updated training code with L1 + BCE complete.


In [9]:
def detect_watermark(detector, audio, threshold=0.5):
    """
    Decide, per clip, whether ``detector`` flags ``audio`` (B, 1, T) as watermarked.

    The detection logits (channel 0) are averaged over time, passed through a
    sigmoid and compared with ``threshold``. This is sigmoid(mean logit), the
    same time-averaged logit the training loss uses; it is not the mean of
    per-sample probabilities.

    Side effect: puts ``detector`` in eval mode.

    Returns:
        Float tensor of shape (B,): 1.0 where the score exceeds ``threshold``,
        otherwise 0.0.
    """
    detector.eval()
    with torch.no_grad():
        clip_logit = detector(audio)[:, 0, :].mean(dim=1)  # (B,)
        score = torch.sigmoid(clip_logit)
        return score.gt(threshold).float()

def decode_message(detector, audio, num_bits=NUM_BITS):
    """
    Recover integer message ids from ``audio`` with ``detector``.

    Args:
        detector: model mapping (B, 1, T) audio to (B, 1 + bits, T) logits.
        audio: tensor of shape (B, 1, T).
        num_bits: kept for API compatibility and not read. The bit count is
            taken from the detector output (channels 1..).

    Bit logits are averaged over time, passed through a sigmoid and
    thresholded at 0.5. Bit k contributes 2**k (LSB first, matching
    ``int_to_bit_tensor``). Bits are assembled in int64, so bit 31 and above
    do not overflow into a sign bit as they would with int32.

    Side effect: puts ``detector`` in eval mode.

    Returns:
        int64 tensor of shape (B,).
    """
    detector.eval()
    with torch.no_grad():
        out = detector(audio)                           # (B, 1+bits, T)
        bit_avg = out[:, 1:, :].mean(dim=2)             # (B, bits) time-averaged logits
        bit_pred = (torch.sigmoid(bit_avg) > 0.5).long()
        shifts = torch.arange(bit_pred.shape[1], device=bit_pred.device)
        # Each bit occupies its own position, so summing the shifted bits equals OR-ing them.
        return (bit_pred << shifts).sum(dim=1)  # shape (B,)

print("\nCHUNK #7: Simple inference utilities (detection, decoding) complete.")



CHUNK #7: Simple inference utilities (detection, decoding) complete.


In [10]:
#################################################
if __name__ == "__main__":
    # 1) Load dataset
    data_root = "data/100_all"  # Change this to your data folder
    full_dataset = OneSecClipsDataset(root_dir=data_root, sample_rate=16000)
    # Fail fast: with no files, training would run zero batches and then save untrained weights.
    if len(full_dataset) == 0:
        raise FileNotFoundError(f"No .wav files found under {data_root!r}")

    # 2) Split 80 / 10 / 10 (the test split takes the rounding remainder, so it is non-empty)
    n = len(full_dataset)
    n_train = int(0.8 * n)
    n_val   = int(0.1 * n)
    n_test  = n - n_train - n_val
    train_ds, val_ds, test_ds = random_split(full_dataset, [n_train, n_val, n_test])

    # 3) Instantiate models
    generator = Generator().to(device)
    detector  = Detector().to(device)

    def save_weights():
        """Write both models' state_dicts to the working directory."""
        torch.save(generator.state_dict(), "generator.pth")
        torch.save(detector.state_dict(), "detector.pth")

    # 4) Train, then 5) save. Interrupting the cell would otherwise discard all
    # progress, so save the current weights on KeyboardInterrupt and re-raise
    # so the cell still stops.
    try:
        train_model(generator, detector, train_ds, val_ds, num_epochs=10, lr=LR, num_bits=NUM_BITS)
    except KeyboardInterrupt:
        save_weights()
        raise
    save_weights()

    # 6) Example: quick check on one random test sample
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=True)
    audio = next(iter(test_loader)).to(device)
    # embed random message
    msg_int = generate_random_messages(1, num_bits=NUM_BITS).to(device)
    generator.eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        delta = generator(audio, msg_int)
    s_w   = audio + delta
    # detect
    detection_label = detect_watermark(detector, s_w, threshold=0.5)
    pred_msg        = decode_message(detector, s_w, num_bits=NUM_BITS)
    print(f"True msg: {msg_int.item()},  Predicted msg: {pred_msg.item()},  Detected? {detection_label.item()>0.5}")

KeyboardInterrupt: 