In [1]:
# Import necessary libraries
import os
import glob
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
from collections import Counter
import librosa
# For inline plots in Jupyter Notebook
%matplotlib inline


In [2]:
# Root directory for the preprocessed audio clips
dataset_dir = "data/100_all"

# Recursively list all .wav files below dataset_dir. glob's result order depends on the
# file system, so sort it: later cells take filepaths[0] and sample from this list.
filepaths = sorted(glob.glob(os.path.join(dataset_dir, '**', '*.wav'), recursive=True))
print("Total number of audio clips:", len(filepaths))




In [3]:
# Load the first clip found and inspect its tensor shape and native sample rate
if not filepaths:
    # Fail with an explicit message instead of a bare IndexError on filepaths[0]
    raise FileNotFoundError(f"No .wav files found under {dataset_dir!r}")
sample_path = filepaths[0]
waveform, sr = torchaudio.load(sample_path)  # waveform: (channels, frames) tensor
print("Sample file:", sample_path)
print("Waveform shape:", waveform.shape)  # Expected shape: [1, 16000] for 1-second mono clips
print("Sample rate:", sr)




In [4]:
# Plot the waveform of the sample file, one trace per channel
fig, ax = plt.subplots(figsize=(10, 4))
# torchaudio returns (channels, frames). Plot each row: squeeze() only flattens mono files, and
# matplotlib would read a 2-D (channels, frames) array column-wise as thousands of tiny series.
for channel in waveform.numpy():
    ax.plot(channel)
ax.set_title("Waveform of a Sample 1-Second Clip")
ax.set_xlabel("Sample Index")
ax.set_ylabel("Amplitude")
plt.show()




In [5]:
# Turn the sample clip into a 64-band mel spectrogram, then express it in decibels
mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_mels=64)
spec = mel_transform(waveform)
spec_db = torchaudio.transforms.AmplitudeToDB()(spec)

# Show it as an image with low mel bands at the bottom
fig, ax = plt.subplots(figsize=(10, 4))
image = ax.imshow(spec_db.squeeze().numpy(), origin="lower", aspect="auto", cmap="viridis")
ax.set(title="Mel Spectrogram (dB) of a Sample Clip", xlabel="Time", ylabel="Mel Frequency")
fig.colorbar(image, ax=ax, label="dB")
plt.show()




In [6]:
# Spot-check up to 50 randomly chosen clips (all of them if there are fewer than 50)
subset_paths = random.sample(filepaths, min(50, len(filepaths)))
durations = []         # All clips should be around 1 second, but we compute for verification
max_amplitudes = []    # Maximum absolute amplitude per clip

for path in subset_paths:
    wf, clip_sr = torchaudio.load(path)
    # Divide by this clip's own sample rate, not the global `sr` read from the first sample
    # file; otherwise any clip stored at a different rate gets a wrong duration.
    durations.append(wf.shape[-1] / clip_sr)  # Duration in seconds
    max_amplitudes.append(np.max(np.abs(wf.numpy())))

print("Average duration (should be ~1 second):", np.mean(durations))
print("Average max amplitude:", np.mean(max_amplitudes))

# Histogram of per-clip peak amplitudes
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(max_amplitudes, bins=20, edgecolor='black')
ax.set_title("Histogram of Max Amplitudes (Sampled Clips)")
ax.set_xlabel("Max Amplitude")
ax.set_ylabel("Count")
plt.show()






In [7]:
import librosa
import numpy as np
import matplotlib.pyplot as plt
import glob
import os

def compute_rms(waveform):
    """Root-mean-square amplitude of a signal.

    Args:
        waveform: array-like of samples (any shape; every element is averaged).

    Returns:
        The RMS value as a NumPy float64 scalar (nan for an empty input, as with np.mean).
    """
    # Promote to float64 first: squaring integer PCM (e.g. int16) in its own dtype overflows,
    # and float64 accumulation is also more accurate for long float32 signals.
    samples = np.asarray(waveform, dtype=np.float64)
    return np.sqrt(np.mean(samples ** 2))

def compute_zcr(waveform):
    """Mean zero-crossing rate of a signal, averaged over librosa's analysis frames."""
    # librosa returns a (1, n_frames) array; row 0 holds the per-frame rates
    per_frame_zcr = librosa.feature.zero_crossing_rate(waveform)[0]
    return np.mean(per_frame_zcr)

# Re-scan the preprocessed directory (sorted because glob order depends on the file system)
filepaths = sorted(glob.glob(os.path.join(dataset_dir, '**', '*.wav'), recursive=True))
print("Total number of audio clips:", len(filepaths))

# Compute RMS and ZCR for a random sample of up to 1000 clips
sample_filepaths = np.random.choice(filepaths, min(1000, len(filepaths)), replace=False)

rms_list = []
zcr_list = []
for path in sample_filepaths:
    # Use local names: assigning to `waveform`/`sr` here would overwrite the torch tensor and
    # sample rate from the sample-file cells, so re-running those cells would then feed a
    # NumPy array into the torchaudio transforms.
    clip, _ = librosa.load(path, sr=16000)  # resampled to 16 kHz
    rms_list.append(compute_rms(clip))
    zcr_list.append(compute_zcr(clip))

# One histogram per feature
for values, title, xlabel in (
    (rms_list, "Histogram of RMS Energy", "RMS Energy"),
    (zcr_list, "Histogram of Zero Crossing Rate", "ZCR"),
):
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.hist(values, bins=30, edgecolor='black')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Frequency")
    plt.show()








In [8]:
import os
import glob
import numpy as np
import librosa
import matplotlib.pyplot as plt
import pandas as pd
from tqdm.notebook import tqdm  # For progress bar

def get_audio_properties(filepath):
    """Load one audio file at its native sample rate and summarise it.

    Returns a dict with keys filepath, sample_rate, duration (seconds), num_samples,
    rms, zcr, spectral_centroid and spectral_bandwidth (frame means, Hz). If loading or
    any feature computation raises, the error is printed and None is returned.
    """
    try:
        # sr=None keeps the file's own sample rate instead of resampling
        y, native_sr = librosa.load(filepath, sr=None)
        # Entries are evaluated in order: duration, rms, zcr, centroid, bandwidth
        return {
            'filepath': filepath,
            'sample_rate': native_sr,
            'duration': librosa.get_duration(y=y, sr=native_sr),
            'num_samples': len(y),
            'rms': compute_rms(y),
            'zcr': compute_zcr(y),
            'spectral_centroid': np.mean(librosa.feature.spectral_centroid(y=y, sr=native_sr)[0]),
            'spectral_bandwidth': np.mean(librosa.feature.spectral_bandwidth(y=y, sr=native_sr)[0]),
        }
    except Exception as exc:
        print(f"Error processing {filepath}: {str(exc)}")
        return None

# Re-scan the preprocessed directory (sorted because glob order depends on the file system)
filepaths = sorted(glob.glob(os.path.join(dataset_dir, '**', '*.wav'), recursive=True))
print(f"Total number of audio clips: {len(filepaths)}")

# Sample files to analyze (adjust the number as needed)
sample_size = min(500, len(filepaths))
sample_filepaths = np.random.choice(filepaths, sample_size, replace=False)

# Analyze audio files with a progress bar; files that fail return None and are skipped
audio_properties = []
for path in tqdm(sample_filepaths, desc="Analyzing audio files"):
    props = get_audio_properties(path)
    if props is not None:
        audio_properties.append(props)

# One row per successfully analyzed file
df = pd.DataFrame(audio_properties)

print("\n--- Dataset Properties Summary ---")
print(f"Number of files analyzed: {len(df)}")

if df.empty:
    # A DataFrame built from an empty list has no columns, so the df[...] lookups below would raise KeyError.
    print("No audio files could be analyzed; skipping statistics and plots.")
else:
    print("\nSample Rate Statistics:")
    sr_counts = df['sample_rate'].value_counts()
    # Loop variable is `rate`, not `sr`, so the notebook-level sample rate `sr` is not overwritten.
    for rate, count in sr_counts.items():
        print(f"  {rate} Hz: {count} files ({count/len(df)*100:.1f}%)")

    print("\nDuration Statistics:")
    print(f"  Min duration: {df['duration'].min():.2f} seconds")
    print(f"  Max duration: {df['duration'].max():.2f} seconds")
    print(f"  Mean duration: {df['duration'].mean():.2f} seconds")
    print(f"  Median duration: {df['duration'].median():.2f} seconds")

    print("\nOther Audio Properties (mean values):")
    print(f"  RMS energy: {df['rms'].mean():.5f}")
    print(f"  Zero crossing rate: {df['zcr'].mean():.5f}")
    print(f"  Spectral centroid: {df['spectral_centroid'].mean():.2f} Hz")
    print(f"  Spectral bandwidth: {df['spectral_bandwidth'].mean():.2f} Hz")

    # Histograms of the key properties on a 2x2 grid (filled row by row)
    panels = [
        ('duration', "Distribution of Audio Duration", "Duration (seconds)"),
        ('rms', "Distribution of RMS Energy", "RMS"),
        ('zcr', "Distribution of Zero Crossing Rate", "ZCR"),
        ('spectral_centroid', "Distribution of Spectral Centroid", "Frequency (Hz)"),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    for ax, (column, title, xlabel) in zip(axes.flat, panels):
        ax.hist(df[column], bins=30, edgecolor='black')
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel("Frequency")

    plt.tight_layout()
    plt.show()









In [9]:
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader

In [10]:
class AudioDataset(Dataset):
    """Map-style dataset over every ``.wav`` file below ``root_dir`` (searched recursively).

    Each item is ``(waveform, sample_rate)`` as returned by ``torchaudio.load``, with
    ``transform`` (when given) applied to the waveform only.

    Args:
        root_dir: directory scanned for ``.wav`` files at any depth.
        transform: optional callable applied to each loaded waveform tensor.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # glob's order depends on the file system; sort so an index (and any random subset
        # of indices drawn from it) refers to the same file on every machine.
        self.filepaths = sorted(glob.glob(os.path.join(root_dir, '**', '*.wav'), recursive=True))
        self.transform = transform

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx):
        filepath = self.filepaths[idx]
        waveform, sample_rate = torchaudio.load(filepath)
        if self.transform is not None:
            waveform = self.transform(waveform)
        return waveform, sample_rate


In [11]:
dataset_dir = "data/100_all"
audio_dataset = AudioDataset(root_dir=dataset_dir)

print("Total number of audio clips in dataset:", len(audio_dataset))

# Cap loader workers at the machine's CPU count (os.cpu_count() may return None)
num_loader_workers = min(16, os.cpu_count() or 1)
dataloader = DataLoader(audio_dataset, batch_size=32, shuffle=True, num_workers=num_loader_workers)

# Smoke test: pull ONE batch to confirm clips load and collate, rather than iterating the
# whole dataset and discarding every batch. The default collate stacks tensors, so this
# presumably relies on every clip having the same length - TODO confirm for this dataset.
if len(audio_dataset) > 0:
    waveforms, sample_rates = next(iter(dataloader))
    print("Batch waveforms shape:", waveforms.shape)
    print("Sample rates:", sample_rates)




In [None]:
# # Example usage of the WatermarkGenerator and WatermarkDetector
# if __name__ == "__main__":
#     dummy_input = torch.randn(4, 1, 16000)
    
#     generator = WatermarkGenerator()
#     detector = WatermarkDetector()
#     watermark_delta = generator(dummy_input)
#     print("Watermark delta shape:", watermark_delta.shape) 
#     watermarked_audio = dummy_input + watermark_delta
    
#     # Run detector on the watermarked audio
#     detection_output = detector(watermarked_audio)
#     print("Detection output shape:", detection_output.shape)  




In [16]:
from torch.utils.data import Subset
import numpy as np

# Draw a small random subset of audio_dataset for quick experiments
subset_fraction = 0.01  # 1% of the data

dataset_size = len(audio_dataset)
# int(0.01 * n) is 0 for n < 100, which would give an empty subset and a zero-length
# DataLoader; keep at least one sample whenever the dataset is non-empty.
subset_size = max(1, int(subset_fraction * dataset_size)) if dataset_size > 0 else 0

# Randomly select subset_size indices without replacement
subset_indices = np.random.choice(dataset_size, subset_size, replace=False)

# Create the subset dataset and its DataLoader
subset_dataset = Subset(audio_dataset, subset_indices)
subset_dataloader = DataLoader(subset_dataset, batch_size=32, shuffle=True, num_workers=4)

print(f"Using {len(subset_dataset)} samples out of {dataset_size} for testing (approx. {subset_fraction:.0%}).")




In [15]:
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score, confusion_matrix

def compute_all_metrics(pred_probs, target, threshold=0.5):
    """Binary-classification metrics over every element of a prediction tensor.

    Args:
        pred_probs: tensor of predicted probabilities (any shape; flattened).
        target: tensor of 0/1 ground-truth labels with the same number of elements.
        threshold: probabilities >= threshold are counted as positive.

    Returns:
        dict with "accuracy", "roc_auc" (nan when target holds a single class),
        "precision", "recall", "f1" (0 when undefined) and a 2x2 "confusion_matrix"
        laid out as [[TN, FP], [FN, TP]].
    """
    # reshape instead of view so non-contiguous tensors (e.g. channel slices) are accepted
    pred_flat = pred_probs.detach().reshape(-1).cpu().numpy()
    target_flat = target.detach().reshape(-1).cpu().numpy()

    pred_labels = (pred_flat >= threshold).astype(int)
    target_labels = target_flat.astype(int)

    # ROC AUC is undefined with a single class present and roc_auc_score raises ValueError.
    if np.unique(target_labels).size < 2:
        roc_auc = float("nan")
    else:
        roc_auc = roc_auc_score(target_labels, pred_flat)

    return {
        "accuracy": accuracy_score(target_labels, pred_labels),
        "roc_auc": roc_auc,
        "precision": precision_score(target_labels, pred_labels, zero_division=0),
        "recall": recall_score(target_labels, pred_labels, zero_division=0),
        "f1": f1_score(target_labels, pred_labels, zero_division=0),
        # labels=[0, 1] keeps the matrix 2x2 even when a batch contains only one class
        "confusion_matrix": confusion_matrix(target_labels, pred_labels, labels=[0, 1]),
    }


In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.optim as optim
# import random
# # Set up device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print("Using device:", device)

# #############################
# # 1. Model Architectures
# #############################

# # --- Original WatermarkGenerator (without message) ---
# class WatermarkGenerator(nn.Module):
#     def __init__(self):
#         super(WatermarkGenerator, self).__init__()
#         # Encoder: Downsample input from 16000 -> 4000 samples
#         self.encoder = nn.Sequential(
#             nn.Conv1d(1, 16, kernel_size=15, stride=1, padding=7), 
#             nn.ReLU(),
#             nn.Conv1d(16, 32, kernel_size=15, stride=2, padding=7),  
#             nn.ReLU(),
#             nn.Conv1d(32, 64, kernel_size=15, stride=2, padding=7),  
#             nn.ReLU()
#         )
#         # Bottleneck: LSTM to capture temporal dependencies
#         self.lstm = nn.LSTM(input_size=64, hidden_size=64, num_layers=1, batch_first=True)
#         # Decoder: Upsample back from 4000 -> 16000
#         self.decoder = nn.Sequential(
#             nn.ConvTranspose1d(64, 32, kernel_size=15, stride=2, padding=7, output_padding=1),
#             nn.ReLU(),
#             nn.ConvTranspose1d(32, 16, kernel_size=15, stride=2, padding=7, output_padding=1),
#             nn.ReLU(),
#             nn.Conv1d(16, 1, kernel_size=15, stride=1, padding=7)
#         )
    
#     def forward(self, x):
#         # x: (batch, 1, 16000)
#         encoded = self.encoder(x)  # -> (batch, 64, 4000)
#         encoded_transposed = encoded.transpose(1, 2)  # -> (batch, 4000, 64)
#         lstm_out, _ = self.lstm(encoded_transposed)      # -> (batch, 4000, 64)
#         lstm_out = lstm_out.transpose(1, 2)                # -> (batch, 64, 4000)
#         decoded = self.decoder(lstm_out)                   # -> (batch, 1, 16000)
#         watermark_delta = 0.01 * torch.tanh(decoded)
#         return watermark_delta

# # --- Extended Generator with Optional 16-bit Message Embedding ---
# class WatermarkGeneratorWithMessage(WatermarkGenerator):
#     def __init__(self, message_bits=16):
#         super(WatermarkGeneratorWithMessage, self).__init__()
#         self.use_message = True
#         self.message_bits = message_bits
#         # Create a learnable embedding table for message bits:
#         # For each bit (0 or 1) and for each bit position, produce an adjustment vector of size 64.
#         # Shape: (2, message_bits, 64)
#         self.embedding = nn.Parameter(torch.randn(2, message_bits, 64) * 0.01)
        
#     def forward(self, x, message):
#         """
#         x: (batch, 1, 16000)
#         message: (batch, message_bits) containing binary values (0 or 1)
#         """
#         encoded = self.encoder(x)            # (batch, 64, 4000)
#         encoded_transposed = encoded.transpose(1, 2)  # (batch, 4000, 64)
#         lstm_out, _ = self.lstm(encoded_transposed)      # (batch, 4000, 64)
#         # Incorporate message embedding:
#         # For each sample, average embeddings for each bit.
#         batch_size, t, feat = lstm_out.shape
#         message = message.long()  # ensure integers 0 or 1
#         message_embs = []
#         for i in range(self.message_bits):
#             # For each bit position i, select the corresponding embedding vector
#             emb_i = self.embedding[:, i, :]  # shape (2, 64)
#             # message[:, i] selects row 0 or 1 for each sample
#             message_emb_i = emb_i[message[:, i]]  # shape (batch, 64)
#             message_embs.append(message_emb_i)
#         # Average over the message bits: shape (batch, 64)
#         message_emb = torch.stack(message_embs, dim=1).mean(dim=1)
#         # Expand message_emb to add to every timestep in the LSTM output:
#         message_emb_expanded = message_emb.unsqueeze(1).expand(-1, t, -1)  # (batch, 4000, 64)
#         lstm_out = lstm_out + message_emb_expanded
#         lstm_out = lstm_out.transpose(1, 2)  # (batch, 64, 4000)
#         decoded = self.decoder(lstm_out)  # (batch, 1, 16000)
#         watermark_delta = 0.01 * torch.tanh(decoded)
#         return watermark_delta

# # --- Extended Detector to also Decode the 16-bit Message ---
# class WatermarkDetectorWithMessage(nn.Module):
#     def __init__(self, message_bits=16):
#         super(WatermarkDetectorWithMessage, self).__init__()
#         self.message_bits = message_bits
#         # Use a similar encoder as before
#         self.encoder = nn.Sequential(
#             nn.Conv1d(1, 16, kernel_size=15, stride=1, padding=7),
#             nn.ReLU(),
#             nn.Conv1d(16, 32, kernel_size=15, stride=2, padding=7),
#             nn.ReLU(),
#             nn.Conv1d(32, 64, kernel_size=15, stride=2, padding=7),
#             nn.ReLU()
#         )
#         # Upsampling head: output 1 channel for detection, and message_bits channels for message decoding.
#         self.upsample = nn.Sequential(
#             nn.ConvTranspose1d(64, 32, kernel_size=15, stride=2, padding=7, output_padding=1),
#             nn.ReLU(),
#             nn.ConvTranspose1d(32, 16, kernel_size=15, stride=2, padding=7, output_padding=1),
#             nn.ReLU(),
#             nn.Conv1d(16, 1 + message_bits, kernel_size=15, stride=1, padding=7)
#         )
        
#     def forward(self, x):
#         # x: (batch, 1, 16000)
#         encoded = self.encoder(x)   # (batch, 64, 4000)
#         out = self.upsample(encoded)  # (batch, 1+message_bits, 16000)
#         # First channel for detection probability
#         detection = torch.sigmoid(out[:, 0:1, :])
#         # The remaining channels for message decoding:
#         # We average over the time dimension to get one prediction per bit.
#         message_logits = out[:, 1:, :].mean(dim=2)  # (batch, message_bits)
#         message_prob = torch.sigmoid(message_logits)  # Probabilities for each bit
#         return detection, message_prob

# #############################
# # 2. Augmentation Functions
# #############################

# def apply_watermark_masking_balanced(s, s_w, num_segments=10):
#     """
#     A revised augmentation function to produce a more balanced ground truth mask.
#     For each sample, divide the audio into 'num_segments' segments and randomly 
#     set half of them to 0 (watermark dropped) and the other half remain 1.
    
#     Args:
#         s: original audio tensor, shape (batch, 1, T)
#         s_w: watermarked audio tensor (s + delta), shape (batch, 1, T)
#         num_segments: total segments to partition each sample
    
#     Returns:
#         s_w_aug: augmented watermarked audio (with some segments reverted)
#         mask: ground truth mask of same shape (1 means watermark is intact, 0 means dropped)
#     """
#     batch_size, channels, T = s_w.shape
#     s_w_aug = s_w.clone()
#     mask = torch.ones_like(s_w)  # start with all ones
#     seg_length = T // num_segments
    
#     for b in range(batch_size):
#         # Randomly select half the segments to drop watermark
#         indices = list(range(num_segments))
#         random.shuffle(indices)
#         drop_indices = indices[:num_segments//2]  # drop watermark in half the segments
#         for idx in drop_indices:
#             start = idx * seg_length
#             end = start + seg_length
#             # Option: revert segment to original audio s (i.e., watermark removed)
#             s_w_aug[b, :, start:end] = s[b, :, start:end]
#             mask[b, :, start:end] = 0.0
#     return s_w_aug, mask

# def apply_adversarial_augmentation(x, noise_std=0.005):
#     """
#     Apply a simple differentiable adversarial-like augmentation,
#     such as adding random Gaussian noise or a simple filtering.
#     This helps simulate real-world distortions.
    
#     Args:
#         x: input audio tensor, shape (batch, 1, T)
#         noise_std: standard deviation of Gaussian noise
        
#     Returns:
#         x_aug: augmented audio tensor.
#     """
#     noise = noise_std * torch.randn_like(x)
#     x_aug = x + noise
#     return x_aug

# #############################
# # 3. Training Loop
# #############################

# # Choose whether to use the message-embedding variant.
# use_message = True

# if use_message:
#     # Instantiate extended models with 16-bit messages.
#     generator = WatermarkGeneratorWithMessage(message_bits=16).to(device)
#     detector  = WatermarkDetectorWithMessage(message_bits=16).to(device)
#     # For demonstration, generate random binary messages for the batch.
#     # In practice, these could come from the user.
# else:
#     generator = WatermarkGenerator().to(device)
#     detector  = WatermarkDetector().to(device)

# # Define loss functions:
# criterion_perc = nn.L1Loss()    # Perceptual loss (original vs watermarked)
# criterion_det  = nn.BCELoss()   # Detection loss for watermark presence
# if use_message:
#     criterion_msg = nn.BCELoss()   # For message decoding; we'll treat each bit as binary

# # Loss weight factors (you can tune these)
# lambda_perc = 10.0
# lambda_det  = 1.0
# lambda_msg  = 1.0  # weight for message decoding loss

# optimizer = optim.Adam(list(generator.parameters()) + list(detector.parameters()), lr=1e-4)

# num_epochs = 10

# for epoch in range(num_epochs):
#     running_loss = 0.0
#     generator.train()
#     detector.train()
    
#     for i, (s, _) in enumerate(subset_dataloader):  # using our 1% subset DataLoader
#         s = s.to(device)  # (batch, 1, 16000)
#         optimizer.zero_grad()
        
#         # Optionally, apply adversarial augmentation to original audio:
#         s_aug = apply_adversarial_augmentation(s, noise_std=0.005)
        
#         # Generate watermark delta:
#         if use_message:
#             # Generate a random binary message for each sample, shape (batch, 16)
#             batch_size = s.size(0)
#             random_message = torch.randint(0, 2, (batch_size, 16)).float().to(device)
#             delta = generator(s_aug, random_message)
#         else:
#             delta = generator(s_aug)
        
#         s_w = s + delta  # watermarked audio
        
#         # Apply watermark masking to get a balanced ground truth mask:
#         s_w_aug, mask = apply_watermark_masking_balanced(s, s_w, num_segments=10)
#         mask = mask.to(device)
        
#         # Perceptual loss: encourage watermarked audio to be similar to original
#         loss_perc = criterion_perc(s_w, s)
        
#         # Detection loss: the detector should output 1 where watermark is present and 0 where dropped.
#         if use_message:
#             detection_output, message_output = detector(s_w_aug)
#         else:
#             detection_output = detector(s_w_aug)
        
#         loss_det = criterion_det(detection_output, mask)
        
#         if use_message:
#             # For message loss, target is the random_message (broadcasted over batch)
#             # Here we assume the detector outputs probabilities for each bit (shape: (batch, 16))
#             loss_msg = criterion_msg(message_output, random_message)
#             total_loss = lambda_perc * loss_perc + lambda_det * loss_det + lambda_msg * loss_msg
#         else:
#             total_loss = lambda_perc * loss_perc + lambda_det * loss_det
        
#         total_loss.backward()
#         optimizer.step()
#         running_loss += total_loss.item()
        
#         if (i+1) % 20 == 0:
#             print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(subset_dataloader)}], Loss: {total_loss.item():.4f}")
    
#     avg_loss = running_loss / len(subset_dataloader)
#     print(f"Epoch [{epoch+1}/{num_epochs}] completed. Average Loss: {avg_loss:.4f}")

# # Save model weights after training
# if use_message:
#     torch.save(generator.state_dict(), "watermark_generator_with_message.pth")
#     torch.save(detector.state_dict(), "watermark_detector_with_message.pth")
# else:
#     torch.save(generator.state_dict(), "watermark_generator.pth")
#     torch.save(detector.state_dict(), "watermark_detector.pth")

# #############################
# # 4. Evaluation
# #############################

# generator.eval()
# detector.eval()

# with torch.no_grad():
#     for s, _ in subset_dataloader:
#         s = s.to(device)
#         if use_message:
#             batch_size = s.size(0)
#             random_message = torch.randint(0, 2, (batch_size, 16)).float().to(device)
#             delta = generator(s, random_message)
#         else:
#             delta = generator(s)
#         s_w = s + delta
#         s_w_aug, mask = apply_watermark_masking_balanced(s, s_w, num_segments=10)
#         mask = mask.to(device)
#         if use_message:
#             detection_output, message_output = detector(s_w_aug)
#         else:
#             detection_output = detector(s_w_aug)
#         break

# # Compute detection metrics (using our previous functions)
# metrics = compute_all_metrics(detection_output, mask, threshold=0.5)
# print("Detection Metrics:")
# print(f"Accuracy: {metrics['accuracy']*100:.2f}%")
# print(f"ROC AUC Score: {metrics['roc_auc']:.4f}")
# print(f"Precision: {metrics['precision']:.4f}")
# print(f"Recall: {metrics['recall']:.4f}")
# print(f"F1 Score: {metrics['f1']:.4f}")
# print("Confusion Matrix:")
# print(metrics["confusion_matrix"])

# if use_message:
#     # For message decoding, let's compute accuracy per bit
#     # Binarize message_output with threshold 0.5
#     pred_message = (message_output >= 0.5).float()
#     # Assume random_message from evaluation above is our target message:
#     target_message = random_message
#     message_acc = (pred_message == target_message).float().mean().item()
#     print(f"Message Decoding Accuracy: {message_acc*100:.2f}%")






In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
from tqdm import tqdm 
# Set up device: CUDA when PyTorch can see a GPU, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

#############################
# 1. Model Architectures
#############################

# --- Original WatermarkGenerator (without message) ---
class WatermarkGenerator(nn.Module):
    """Predicts an additive, amplitude-bounded watermark for a mono waveform.

    Architecture: three Conv1d layers (two with stride 2, i.e. 4x temporal downsampling),
    a single-layer LSTM over the downsampled sequence, then two stride-2 ConvTranspose1d
    layers and a final Conv1d back to one channel.

    forward(x): x of shape (batch, 1, T) -> delta of shape (batch, 1, T), computed as
    0.01 * tanh(...) so every value is bounded by 0.01 in magnitude.
    """

    def __init__(self):
        super(WatermarkGenerator, self).__init__()
        # Encoder: (batch, 1, T) -> (batch, 64, ceil(T / 4)), e.g. 16000 -> 4000 samples
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=15, stride=1, padding=7),
            nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=15, stride=2, padding=7),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=15, stride=2, padding=7),
            nn.ReLU()
        )
        # Bottleneck: LSTM to capture temporal dependencies
        self.lstm = nn.LSTM(input_size=64, hidden_size=64, num_layers=1, batch_first=True)
        # Decoder: each ConvTranspose1d exactly doubles the length, e.g. 4000 -> 16000
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(64, 32, kernel_size=15, stride=2, padding=7, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(32, 16, kernel_size=15, stride=2, padding=7, output_padding=1),
            nn.ReLU(),
            nn.Conv1d(16, 1, kernel_size=15, stride=1, padding=7)
        )

    def forward(self, x):
        encoded = self.encoder(x)                          # (batch, 64, ceil(T / 4))
        lstm_out, _ = self.lstm(encoded.transpose(1, 2))   # (batch, ceil(T / 4), 64)
        decoded = self.decoder(lstm_out.transpose(1, 2))   # (batch, 1, 4 * ceil(T / 4))
        # When T is not a multiple of 4 the stride-2 convs round up and the decoder
        # overshoots by up to 3 samples; crop so the delta can be added to x directly.
        decoded = decoded[..., : x.shape[-1]]
        watermark_delta = 0.01 * torch.tanh(decoded)
        return watermark_delta

# --- Extended Generator with Optional 16-bit Message Embedding ---
class WatermarkGeneratorWithMessage(WatermarkGenerator):
    """WatermarkGenerator whose watermark is additionally conditioned on a binary message.

    forward(x, message): x of shape (batch, 1, T), message of shape (batch, message_bits)
    holding 0/1 values -> watermark delta of shape (batch, 1, T), bounded by 0.01.
    """

    def __init__(self, message_bits=16):
        super(WatermarkGeneratorWithMessage, self).__init__()
        self.use_message = True
        self.message_bits = message_bits
        # Learnable table indexed as [bit value (0/1), bit position, feature]:
        # shape (2, message_bits, 64), matching the LSTM hidden size.
        self.embedding = nn.Parameter(torch.randn(2, message_bits, 64) * 0.01)

    def forward(self, x, message):
        """
        x: (batch, 1, T)
        message: (batch, message_bits) containing binary values (0 or 1)
        """
        if message.dim() != 2 or message.shape[1] != self.message_bits:
            raise ValueError(
                f"message must have shape (batch, {self.message_bits}), got {tuple(message.shape)}"
            )
        encoded = self.encoder(x)                          # (batch, 64, ceil(T / 4))
        lstm_out, _ = self.lstm(encoded.transpose(1, 2))   # (batch, ceil(T / 4), 64)

        message = message.long()  # bit values become row indices 0/1
        positions = torch.arange(self.message_bits, device=message.device)
        # Advanced indexing picks embedding[message[b, i], i, :] for every sample b and bit i
        # in one gather -> (batch, message_bits, 64); then average over the bits.
        message_emb = self.embedding[message, positions].mean(dim=1)  # (batch, 64)
        # Broadcast the per-sample message vector onto every LSTM timestep
        lstm_out = lstm_out + message_emb.unsqueeze(1)

        decoded = self.decoder(lstm_out.transpose(1, 2))   # (batch, 1, 4 * ceil(T / 4))
        # Crop the up-to-3-sample overshoot for T not divisible by 4 (no-op otherwise)
        decoded = decoded[..., : x.shape[-1]]
        watermark_delta = 0.01 * torch.tanh(decoded)
        return watermark_delta

# --- Extended Detector to also Decode the 16-bit Message ---
class WatermarkDetectorWithMessage(nn.Module):
    """Per-sample watermark detector that also decodes a fixed-length binary message.

    forward(x): x of shape (batch, 1, T) -> (detection, message_prob) where
      detection    is (batch, 1, T): sigmoid probability that each sample is watermarked;
      message_prob is (batch, message_bits): sigmoid probability that each bit is 1.
    """

    def __init__(self, message_bits=16):
        super(WatermarkDetectorWithMessage, self).__init__()
        self.message_bits = message_bits
        # Encoder: (batch, 1, T) -> (batch, 64, ceil(T / 4)), same layout as the generator's
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=15, stride=1, padding=7),
            nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=15, stride=2, padding=7),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=15, stride=2, padding=7),
            nn.ReLU()
        )
        # Upsampling head: 1 detection channel + message_bits message channels
        self.upsample = nn.Sequential(
            nn.ConvTranspose1d(64, 32, kernel_size=15, stride=2, padding=7, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(32, 16, kernel_size=15, stride=2, padding=7, output_padding=1),
            nn.ReLU(),
            nn.Conv1d(16, 1 + message_bits, kernel_size=15, stride=1, padding=7)
        )

    def forward(self, x):
        encoded = self.encoder(x)       # (batch, 64, ceil(T / 4))
        out = self.upsample(encoded)    # (batch, 1 + message_bits, 4 * ceil(T / 4))
        # Crop the up-to-3-sample overshoot so detection lines up with x (and with a
        # ground-truth mask shaped like x); no-op when T % 4 == 0.
        out = out[..., : x.shape[-1]]
        # First channel: per-sample detection probability
        detection = torch.sigmoid(out[:, 0:1, :])
        # Remaining channels: average over time to get one logit per bit
        message_logits = out[:, 1:, :].mean(dim=2)  # (batch, message_bits)
        message_prob = torch.sigmoid(message_logits)
        return detection, message_prob

#############################
# 2. Augmentation Functions
#############################

def apply_watermark_masking_balanced(s, s_w, num_segments=10):
    """Remove the watermark from half of each sample's segments and build the matching mask.

    Each sample is split into `num_segments` contiguous segments that together cover all
    T time steps; `num_segments // 2` of them (chosen with Python's `random`) are reverted
    to the original audio `s`.

    Args:
        s: original audio tensor, shape (batch, channels, T).
        s_w: watermarked audio tensor, same shape as `s`.
        num_segments: number of segments per sample.

    Returns:
        (s_w_aug, mask): `s_w_aug` is a copy of `s_w` with the dropped segments replaced by
        `s`; `mask` has the same shape, 1.0 where the watermark is intact and 0.0 where dropped.
    """
    batch_size, _, T = s_w.shape
    s_w_aug = s_w.clone()
    mask = torch.ones_like(s_w)  # start with every sample marked as watermarked
    # Integer boundaries that tile [0, T) exactly. When T % num_segments == 0 these equal
    # fixed-length segments; otherwise the remainder is spread across segments instead of
    # leaving a tail that could never be dropped (and was always labelled 1).
    bounds = [(k * T) // num_segments for k in range(num_segments + 1)]

    for b in range(batch_size):
        for idx in random.sample(range(num_segments), num_segments // 2):
            start, end = bounds[idx], bounds[idx + 1]
            # Revert this segment to the original audio, i.e. the watermark is removed
            s_w_aug[b, :, start:end] = s[b, :, start:end]
            mask[b, :, start:end] = 0.0
    return s_w_aug, mask

def apply_adversarial_augmentation(x, noise_std=0.005):
    """
    Perturb ``x`` with additive zero-mean Gaussian noise.

    A cheap, differentiable stand-in for real-world distortions.

    Args:
        x: input tensor of any shape.
        noise_std: standard deviation of the noise (0 returns ``x`` unchanged in value).

    Returns:
        A new tensor ``x + noise_std * N(0, 1)`` with the shape of ``x``.
    """
    return x + noise_std * torch.randn_like(x)

#############################
# 3. Training Loop
#############################

# Toggle for the variant that additionally embeds (and decodes) a 16-bit message.
use_message = True

# Build the generator/detector pair for the chosen variant (generator first, then detector).
if not use_message:
    generator = WatermarkGenerator().to(device)
    detector  = WatermarkDetector().to(device)
else:
    generator = WatermarkGeneratorWithMessage(message_bits=16).to(device)
    detector  = WatermarkDetectorWithMessage(message_bits=16).to(device)

# Loss terms:
#   perceptual -- L1 distance between watermarked and original audio
#   detection  -- BCE between the detector's presence probabilities and the mask
#   message    -- per-bit BCE on the decoded message (message variant only)
criterion_perc = nn.L1Loss()
criterion_det = nn.BCELoss()
if use_message:
    criterion_msg = nn.BCELoss()

# Relative weights of the loss terms (tunable).
lambda_perc, lambda_det, lambda_msg = 10.0, 1.0, 1.0

# One Adam optimizer updates both networks jointly.
trainable_params = [*generator.parameters(), *detector.parameters()]
optimizer = optim.Adam(trainable_params, lr=1e-4)

num_epochs = 10



for epoch in range(num_epochs):
    generator.train()
    detector.train()
    running_loss = 0.0

    # Per-epoch progress bar over the training loader.
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch [{epoch+1}/{num_epochs}]")
    for i, (s, _) in pbar:
        s = s.to(device)  # (batch, 1, 16000) per the original shape notes
        optimizer.zero_grad()

        # NOTE(review): the noise perturbs only the generator's *input*. The watermark is then
        # added to the clean `s`, so the detector never sees this distortion. If the goal is
        # detector robustness, apply it to `s_w` / `s_w_aug` instead. Confirm intent.
        s_aug = apply_adversarial_augmentation(s, noise_std=0.005)

        # 1) Additive watermark, optionally conditioned on a random 16-bit message per clip.
        if not use_message:
            delta = generator(s_aug)
        else:
            batch_size = s.size(0)
            random_message = torch.randint(0, 2, (batch_size, 16)).float().to(device)
            delta = generator(s_aug, random_message)
        s_w = s + delta

        # 2) Drop the watermark from half the segments; `mask` is the per-sample target.
        s_w_aug, mask = apply_watermark_masking_balanced(s, s_w, num_segments=10)
        mask = mask.to(device)

        # 3) Losses: imperceptibility, detection, and (optionally) message decoding.
        loss_perc = criterion_perc(s_w, s)
        det_out = detector(s_w_aug)
        if use_message:
            detection_output, message_output = det_out
        else:
            detection_output = det_out
        loss_det = criterion_det(detection_output, mask)

        total_loss = lambda_perc * loss_perc + lambda_det * loss_det
        if use_message:
            loss_msg = criterion_msg(message_output, random_message)
            total_loss = total_loss + lambda_msg * loss_msg

        # 4) Joint update of generator and detector.
        total_loss.backward()
        optimizer.step()

        step_loss = total_loss.item()
        running_loss += step_loss
        pbar.set_postfix(loss=step_loss)

    avg_loss = running_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}] completed. Average Loss: {avg_loss:.4f}")


# Persist the trained weights; file names depend on the model variant.
gen_path, det_path = (
    ("watermark_generator_with_message.pth", "watermark_detector_with_message.pth")
    if use_message
    else ("watermark_generator.pth", "watermark_detector.pth")
)
torch.save(generator.state_dict(), gen_path)
torch.save(detector.state_dict(), det_path)

#############################
# 4. Evaluation
#############################

# Switch both networks to inference mode before the sanity check below.
generator.eval()
detector.eval()

# Quick sanity check on the first batch of the *training* loader.
# NOTE(review): these clips were used for training (same `dataloader`), so the numbers
# are optimistic; the held-out evaluation on `test_dataloader` is the meaningful one.
with torch.no_grad():
    for s, _ in dataloader:
        s = s.to(device)

        if not use_message:
            delta = generator(s)
        else:
            batch_size = s.size(0)
            random_message = torch.randint(0, 2, (batch_size, 16)).float().to(device)
            delta = generator(s, random_message)
        s_w = s + delta

        # Same half-dropped ground truth as during training.
        s_w_aug, mask = apply_watermark_masking_balanced(s, s_w, num_segments=10)
        mask = mask.to(device)

        det_out = detector(s_w_aug)
        if use_message:
            detection_output, message_output = det_out
        else:
            detection_output = det_out
        break  # a single batch is enough for a smoke test

# compute_all_metrics comes from an earlier cell.
metrics = compute_all_metrics(detection_output, mask, threshold=0.5)
print("Detection Metrics:")
print(f"Accuracy: {metrics['accuracy']*100:.2f}%")
print(f"ROC AUC Score: {metrics['roc_auc']:.4f}")
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall: {metrics['recall']:.4f}")
print(f"F1 Score: {metrics['f1']:.4f}")
print("Confusion Matrix:")
print(metrics["confusion_matrix"])

if use_message:
    # Reference bits are the message embedded into the batch evaluated above.
    target_message = random_message
    pred_message = message_output.ge(0.5).float()
    message_acc = pred_message.eq(target_message).float().mean().item()
    print(f"Message Decoding Accuracy: {message_acc*100:.2f}%")






In [21]:
from torch.utils.data import random_split, DataLoader

# Hold out 10% of `audio_dataset` (assumed to be an AudioDataset instance) for testing.
# NOTE(review): the training loop above iterates `dataloader`, which was created before this
# split. If it wraps `audio_dataset`, the "test" clips were already seen during training and
# test metrics are optimistic. Confirm, and build the training loader from `train_dataset`.
dataset_size = len(audio_dataset)
test_size = int(0.1 * dataset_size)
train_size = dataset_size - test_size

# Without an explicit generator, random_split draws from the global RNG. The test set (and
# every metric computed on it) would then change on each run. A seeded generator keeps it fixed.
split_seed = 42
train_dataset, test_dataset = random_split(
    audio_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)
print(f"Training set size: {len(train_dataset)}")
print(f"Test set size: {len(test_dataset)}")

# Deterministic order for evaluation.
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

print("Test DataLoader created.")




In [22]:
# Must match the variant the current generator/detector were built with. True selects the
# message-embedding pair, whose detector returns (detection, message_prob).
use_message: bool = True


In [24]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score, confusion_matrix

# Define SI-SNR computation (for imperceptibility evaluation)
def compute_si_snr(s, s_hat, eps=1e-8, zero_mean=True):
    """
    Compute the Scale-Invariant Signal-to-Noise Ratio (SI-SNR) between original ``s``
    and watermarked ``s_hat``, averaged over the batch.

    Args:
        s: reference audio, e.g. (batch, 1, T); everything after the batch dim is flattened.
        s_hat: estimate with the same number of elements per item as ``s``.
        eps: small constant guarding against division by zero and log10(0).
        zero_mean: centre each signal before projecting (the usual SI-SNR convention).
            Pass False to skip centring.

    Returns:
        Average SI-SNR in dB as a Python float.
    """
    # reshape (unlike view) also works on non-contiguous tensors.
    s = s.reshape(s.size(0), -1)
    s_hat = s_hat.reshape(s_hat.size(0), -1)
    if zero_mean:
        s = s - s.mean(dim=1, keepdim=True)
        s_hat = s_hat - s_hat.mean(dim=1, keepdim=True)

    # Project s_hat onto s: the optimal scaling of the reference.
    scale = torch.sum(s_hat * s, dim=1, keepdim=True) / (torch.sum(s * s, dim=1, keepdim=True) + eps)
    s_target = scale * s
    e = s_hat - s_target
    # eps in the numerator keeps a silent reference clip from yielding log10(0) = -inf.
    # Otherwise a single such clip would turn the whole batch mean into -inf.
    si_snr = 10 * torch.log10(
        (torch.sum(s_target ** 2, dim=1) + eps) / (torch.sum(e ** 2, dim=1) + eps)
    )
    return si_snr.mean().item()

# ============================
# Comprehensive Evaluation Code
# ============================

# Assume your trained models are saved or currently loaded and in evaluation mode.
import itertools

generator.eval()
detector.eval()

# Number of test batches to evaluate; set to None to evaluate the whole of test_dataloader.
max_eval_batches = 1

si_snr_total = 0.0  # sum of per-clip SI-SNR (batch mean x batch size)
clip_count = 0
batch_count = 0

# Containers to aggregate detection predictions and targets over the evaluated batches.
all_detection_preds = []
all_detection_targets = []
if use_message:
    all_message_preds = []
    all_message_targets = []

with torch.no_grad():
    # islice stops after max_eval_batches without fetching an extra batch; None means all.
    for s, _ in itertools.islice(test_dataloader, max_eval_batches):
        s = s.to(device)  # (batch, 1, 16000)
        batch_size = s.size(0)
        if use_message:
            # Random binary message per clip; shape (batch, 16).
            message = torch.randint(0, 2, (batch_size, 16)).float().to(device)
            delta = generator(s, message)
        else:
            delta = generator(s)
        s_w = s + delta

        # Balanced ground truth: watermark removed from half the segments.
        s_w_aug, mask = apply_watermark_masking_balanced(s, s_w, num_segments=10)
        mask = mask.to(device)

        if use_message:
            detection_output, message_output = detector(s_w_aug)
        else:
            detection_output = detector(s_w_aug)

        # compute_si_snr returns the batch mean. Weight it by batch size so a short final
        # batch does not skew the dataset-level average.
        si_snr_total += compute_si_snr(s, s_w) * batch_size
        clip_count += batch_size
        batch_count += 1

        # Accumulate on the CPU so evaluating many batches does not exhaust GPU memory.
        all_detection_preds.append(detection_output.cpu())
        all_detection_targets.append(mask.cpu())
        if use_message:
            all_message_preds.append(message_output.cpu())
            all_message_targets.append(message.cpu())

if clip_count == 0:
    raise RuntimeError("test_dataloader yielded no batches; nothing to evaluate.")

all_detection_preds = torch.cat(all_detection_preds, dim=0)
all_detection_targets = torch.cat(all_detection_targets, dim=0)
detection_metrics = compute_all_metrics(all_detection_preds, all_detection_targets, threshold=0.5)
average_si_snr = si_snr_total / clip_count

print(f"Evaluated {clip_count} clips in {batch_count} batch(es).")
print(f"Average SI-SNR over evaluation set: {average_si_snr:.2f} dB")
print("Detection Metrics:")
print(f"Accuracy: {detection_metrics['accuracy']*100:.2f}%")
print(f"ROC AUC Score: {detection_metrics['roc_auc']:.4f}")
print(f"Precision: {detection_metrics['precision']:.4f}")
print(f"Recall: {detection_metrics['recall']:.4f}")
print(f"F1 Score: {detection_metrics['f1']:.4f}")
print("Confusion Matrix:")
print(detection_metrics["confusion_matrix"])

if use_message:
    # Bit-level decoding accuracy over all evaluated clips.
    all_message_preds = torch.cat(all_message_preds, dim=0)
    all_message_targets = torch.cat(all_message_targets, dim=0)
    pred_message = (all_message_preds >= 0.5).float()
    message_acc = (pred_message == all_message_targets).float().mean().item()
    print(f"Message Decoding Accuracy: {message_acc*100:.2f}%")





In [30]:
import torch
import torchaudio
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random

# Set device
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

#############################
# Utility Functions for Segmentation & Reassembly
#############################

def segment_audio(waveform, segment_length=16000):
    """
    Split a (channels, samples) waveform into consecutive fixed-length chunks.

    The last chunk is right-padded with zeros up to ``segment_length``.

    Args:
        waveform: tensor indexed as [channel, sample].
        segment_length: samples per chunk.

    Returns:
        List of tensors, each (channels, segment_length); empty for a zero-length input.
    """
    n_samples = waveform.shape[1]

    def _chunk_at(offset):
        # Slice one window; only the final window can come up short.
        piece = waveform[:, offset:offset + segment_length]
        shortfall = segment_length - piece.shape[1]
        return F.pad(piece, (0, shortfall)) if shortfall > 0 else piece

    return [_chunk_at(offset) for offset in range(0, n_samples, segment_length)]

def reassemble_audio(segments):
    """
    Join a sequence of chunk tensors back into one tensor by concatenating along dim 1.

    For (channels, samples) chunks this is the time axis.
    """
    joined = torch.cat(segments, dim=1)
    return joined

import warnings

use_message = True  # Set to True to use the message embedding variant

# Re-instantiating the networks replaces any trained modules still in memory with freshly
# initialised ones. The checkpoints written by the training cell must therefore be loaded
# here; otherwise the watermark below is generated and detected by random weights.
if use_message:
    generator = WatermarkGeneratorWithMessage(message_bits=16).to(device)
    detector  = WatermarkDetectorWithMessage(message_bits=16).to(device)
    generator_ckpt = "watermark_generator_with_message.pth"
    detector_ckpt = "watermark_detector_with_message.pth"
else:
    generator = WatermarkGenerator().to(device)
    detector  = WatermarkDetector().to(device)
    generator_ckpt = "watermark_generator.pth"
    detector_ckpt = "watermark_detector.pth"

for net, ckpt_path in ((generator, generator_ckpt), (detector, detector_ckpt)):
    if os.path.exists(ckpt_path):
        # map_location remaps tensors saved on another device (e.g. GPU -> CPU-only machine).
        # torch.load unpickles the file: only load checkpoints you produced or trust.
        net.load_state_dict(torch.load(ckpt_path, map_location=device))
    else:
        warnings.warn(
            f"Checkpoint '{ckpt_path}' not found; {type(net).__name__} keeps its untrained weights."
        )

generator.eval()
detector.eval()

#############################
# 2. Load an Audio File and Preprocess
#############################
audio_filepath = "file_example_WAV_1MG.wav"  # Update with your file path
waveform, sample_rate = torchaudio.load(audio_filepath)
print("Original audio shape:", waveform.shape, "Sample Rate:", sample_rate)

# The watermarking models take single-channel clips: average the channels of multi-channel files.
if waveform.size(0) > 1:
    waveform = waveform.mean(0, keepdim=True)
    print("Converted to mono:", waveform.shape)

# They also expect 16 kHz audio, so resample anything else.
if sample_rate != 16000:
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
    waveform = resampler(waveform)
    sample_rate = 16000
    print("Resampled audio shape:", waveform.shape)

# Keep the full-length waveform (no crop/pad to 1 second); it is split into clips below.
total_length = waveform.size(1)
print(f"Total audio length (in samples): {total_length}, which is {total_length/16000:.2f} seconds.")

# ---------------------------
# 3. Split into 1-second (16000-sample) clips; the last one is zero-padded.
# ---------------------------
segments = segment_audio(waveform, segment_length=16000)
print(f"Audio segmented into {len(segments)} segments.")

#############################
# 4. Watermark Each Segment and Reassemble
#############################
watermarked_segments = []

if use_message:
    # The same custom 16-bit message is embedded in every segment; build it once.
    custom_message_list = [1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0]
    custom_message = torch.tensor(custom_message_list, dtype=torch.float).unsqueeze(0).to(device)

# Inference only: no_grad avoids building an autograd graph for every segment.
with torch.no_grad():
    for segment in segments:
        # Add batch dimension: shape becomes (1, 1, 16000)
        segment_batch = segment.unsqueeze(0).to(device)
        if use_message:
            delta = generator(segment_batch, custom_message)
        else:
            delta = generator(segment_batch)
        watermarked_segment = segment_batch + delta
        watermarked_segments.append(watermarked_segment.squeeze(0).cpu())

# Reassemble, then drop the zero padding segment_audio appended to the last segment so the
# saved file has exactly the length of the input audio.
watermarked_audio_full = reassemble_audio(watermarked_segments)[:, :total_length]
print("Full watermarked audio shape (samples):", watermarked_audio_full.shape)
print("Full watermarked audio duration (seconds):", watermarked_audio_full.shape[1]/16000)

# Save the full watermarked audio to disk
torchaudio.save("watermarked_audio.wav", watermarked_audio_full, sample_rate)
print("Watermarked audio saved as 'watermarked_audio.wav'.")

#############################
# 5. Run the Detector on the Full Watermarked Audio (Segment-wise Processing)
#############################
# Here, we'll process each segment with the detector and then reassemble the detector outputs.
detector_outputs = []
decoded_messages = []  # Only filled when using message embedding

# Inference only: no_grad avoids building an autograd graph for every segment.
with torch.no_grad():
    for segment in watermarked_segments:
        segment_batch = segment.unsqueeze(0).to(device)  # (1, 1, 16000)
        if use_message:
            detection_output, message_output = detector(segment_batch)
            decoded_messages.append(message_output.cpu())
        else:
            detection_output = detector(segment_batch)
        detector_outputs.append(detection_output.cpu())

# Stitch per-segment outputs along time. Trim the part produced for the zero padding of the
# last segment so the output aligns sample-for-sample with the original audio.
detector_full_output = reassemble_audio([seg.squeeze(0) for seg in detector_outputs])[:, :total_length]
print("Full detector output shape:", detector_full_output.shape)

#############################
# 6. Visualization for the First Segment
#############################
# 1-D arrays for the first segment: watermarked waveform and per-sample detector probability.
watermarked_np = watermarked_segments[0].squeeze(0).detach().cpu().numpy()
detection_np = detector_outputs[0].squeeze(0).squeeze(0).detach().cpu().numpy()

fig, (ax_wave, ax_det) = plt.subplots(2, 1, figsize=(12, 6))

ax_wave.plot(watermarked_np)
ax_wave.set_title("Watermarked Audio Waveform (Segment 1)")
ax_wave.set_xlabel("Sample Index")
ax_wave.set_ylabel("Amplitude")

ax_det.plot(detection_np)
ax_det.set_title("Detector Output (Probability) (Segment 1)")
ax_det.set_xlabel("Sample Index")
ax_det.set_ylabel("Probability")

fig.tight_layout()
plt.show()

#############################
# 7. (If Using Message Embedding) Evaluate Message Decoding for the First Segment
#############################
if use_message:
    # Threshold the first segment's bit probabilities and show them next to the embedded bits.
    decoded_message = decoded_messages[0].ge(0.5).float()
    for label, bits in (
        ("Custom input message:", custom_message[0]),
        ("Decoded message from detector (Segment 1):", decoded_message[0]),
    ):
        print(label)
        print(bits.cpu().numpy())






