<a href="https://colab.research.google.com/github/Kratosgado/audio-steganography/blob/staging/steg-ai/core_modules/Final_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install kagglehub

import kagglehub
from kagglehub.config import get_kaggle_credentials



In [None]:
kagglehub.login(get_kaggle_credentials())

VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Kaggle credentials set.


In [2]:
%%capture
!pip install stable-baselines3 shimmy torchvggish

In [3]:
import torch
import torchaudio
import torch.nn as nn
from scipy.fft import dct, idct
import librosa
import soundfile as sf
import numpy as np
import matplotlib.pyplot as plt
from torch.distributions import Normal
import torch.nn.functional as F
import gymnasium as gym
import random
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.callbacks import BaseCallback

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  return datetime.utcnow().replace(tzinfo=utc)


In [4]:
# warnings.filterwarnings("ignore")
torch.manual_seed(42)
np.random.seed(42)

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


# --- UTILITY FUNCTIONS ---

In [5]:
NORM_NUM = 32768
def string_to_bits(message):
  """Convert a string to a sequence of bits."""
  return np.array([int(bit) for bit in ''.join(format(ord(char), '08b') for char in message)], dtype=np.float32)

def bits_to_string(bits):
  """Convert a sequence of bits to a string."""
  text = ''
  for bit in range(0, len(bits), 8):
    if bit + 8 <= len(bits):
      byte = ''.join(str(bit) for bit in bits[bit:bit + 8])
      text += chr(int(byte, 2))
  return text

class CustomLoggingCallback(BaseCallback):
  """ A custom callback that logs additional information from the environment. """
  def __init__(self, verbose=0):
    super(CustomLoggingCallback, self).__init__(verbose)

  def _on_step(self) -> bool:
    """This method is called after each step in the environment.   """
    # Accessing environment infos from the VecEnv wrapper
    if self.locals.get('infos'):
      for info in self.locals['infos']:
        # Ensure info is a dictionary and contains the required keys
        if isinstance(info, dict):
          self.logger.record('rollout/ep_snr', info['snr'])
          self.logger.record('rollout/ep_reward', info['reward'])
          # self.logger.record('rollout/ep_snr', info['psnr'])
          self.logger.record('rollout/ep_ber', info['ber'])
          self.logger.record('rollout/ep_extraction_accuracy', info['extraction_accuracy'])
          # self.logger.record('rollout/ep_detection_prob', info['detection_prob'])
          self.logger.record('rollout/ep_action', info['action'])
    return True

# --- HYPERPARAMETERS ---

In [10]:
class Config:
  """Class to hold all hyperparameters."""
  #Audio processing
  FRAME_SIZE = 1024 #2048
  HOP_LENGTH = 256 #512
  NON_CRITICAL_PERCENT = 0.1 # Modify bottom 10% of coeffs by magnitude
  STATE_DIM = 1024
  SAMPLE_RATE = 22050
  N_MELS = 128

  # RL Training
  EPISODES = 80000
  # EPISODES = 100000
  LEARNING_RATE_ACTOR = 3e-4
  LEARNING_RATE_CRITIC = 3e-4
  GAMMA = 0.99 # Discount factor
  GAE_LAMBDA = 0.95 # Lambda for Generalized Advantage Estimation
  PPO_EPSILON = 0.2 # Epsilon for clipping in PPO
  PPO_EPOCHS = 10 # Number of epochs for PPO update
  N_STEPS = 1024
  CLIP_RANGE=0.2
  ENT_COEF=0.01
  BATCH_SIZE = 64
  LEARNING_RATE = 3e-4

  # Environment Network Pre-training
  PRETRAIN_EPOCHS = 10
  PRE_TRAIN_LR = 1e-3
  PRE_TRAIN_BATCH_SIZE = 32

  # Reward weights
  SNR_WEIGHTS = 0.4 # Weight for SNR in reward. Tuned to be on a similar scale to detection prob.
  # DETECTION_WEIGHT = 1.0 # Weight for detection probability in reward.
  # IMPERCEPTIBILITY_WEIGHT = 0.3
  # UNDETECTABILITY_WEIGHT = 0.3
  EXTRACTION_ACCURACY_WEIGHT = 0.6

  # spread spectrum parameters
  CARRIER_FREQ_SIZE = 16
  CHIP_RATE_SIZE = 8
  SNR_DB_SIZE = 8
  MSG_SIZE = 16
  TOTAL_PARAM_SIZE = CARRIER_FREQ_SIZE + CHIP_RATE_SIZE + SNR_DB_SIZE + MSG_SIZE
cfg = Config()

# --- MODULE 1: AUDIO PREPROCESSOR ---

In [30]:
class AudioPreprocessor:
    """Handles audio loading, MDCT, and inverse MDCT with sign preservation."""

    @staticmethod
    def load_audio(path):
        """Load WAV audio file and resample to cfg.SAMPLE_RATE"""
        audio, _ = librosa.load(path, sr=cfg.SAMPLE_RATE)
        return AudioPreprocessor.normalize_audio(audio)

    @staticmethod
    def resample_audio(waveform, sr) -> np.ndarray:
        """Resample audio to cfg.SAMPLE_RATE"""
        audio = librosa.resample(waveform, orig_sr=sr, target_sr=cfg.SAMPLE_RATE)
        return AudioPreprocessor.normalize_audio(audio)

    @staticmethod
    def normalize_audio(waveform) -> np.ndarray:
        """Normalize audio to [-1, 1]"""
        return waveform / np.max(np.abs(waveform))

    @staticmethod
    def stft(waveform: np.ndarray):
        """Compute Short-Time Fourier Transform (STFT)"""
        return librosa.stft(waveform, n_fft=cfg.FRAME_SIZE, hop_length=cfg.HOP_LENGTH)

    @staticmethod
    def istft(stft_matrix, length):
        """Compute Inverse Short-Time Fourier Transform (ISTFT)"""
        return librosa.istft(stft_matrix, hop_length=cfg.HOP_LENGTH, n_fft=cfg.FRAME_SIZE, length=length)

    @staticmethod
    def compute_mdct(waveform):
        """
        Compute Modified Discrete Cosine Transform (MDCT) using STFT and DCT.
        Returns both magnitudes and phases to preserve sign information.
        """
        stft = AudioPreprocessor.stft(waveform)
        magnitudes = np.abs(stft)
        phases = np.angle(stft)
        return magnitudes, phases

    @staticmethod
    def get_non_critical_coeffs(mdct_coeffs, percentile=10):
        """Identify non-critical coefficients (lowest magnitude)"""
        magnitudes = np.abs(mdct_coeffs)
        threshold = np.percentile(magnitudes.flatten(), percentile)
        return magnitudes < threshold

    @staticmethod
    def reconstruct_audio(magnitudes, phases, length):
        """
        Reconstruct audio from magnitude and phase.
        The phase information preserves the signs even if magnitudes are modified.
        """
        # Reconstruct complex STFT
        stft = magnitudes * np.exp(1j * phases)

        # Inverse STFT to get audio
        reconstructed_audio = AudioPreprocessor.istft(stft, length)
        return reconstructed_audio

    @staticmethod
    def plot_spectrogram(waveform, sr, title):
        """Visualize audio spectrogram"""
        plt.figure(figsize=(10, 4))
        S = librosa.amplitude_to_db(np.abs(librosa.stft(waveform)), ref=np.max)
        librosa.display.specshow(S, sr=sr, x_axis='time', y_axis='log')
        plt.colorbar(format='%+2.0f dB')
        plt.title(title)
        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_mdct_coefficients(mdct_coeffs, title="MDCT Coefficients"):
        """Visualize MDCT coefficients"""
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(np.abs(mdct_coeffs), aspect='auto', origin='lower')
        plt.colorbar()
        plt.title(f"{title} - Magnitudes")
        plt.xlabel("Time Frame")
        plt.ylabel("Frequency Bin")

        plt.subplot(1, 2, 2)
        plt.imshow(np.sign(mdct_coeffs), aspect='auto', origin='lower', cmap='RdBu')
        plt.colorbar()
        plt.title(f"{title} - Signs")
        plt.xlabel("Time Frame")
        plt.ylabel("Frequency Bin")

        plt.tight_layout()
        plt.show()

    @staticmethod
    def save_audio(audio: np.ndarray, sr: int, path: str):
        """Save audio to file"""
        sf.write(path, audio, sr)

# --- MODULE 2: EMBEDDING/EXTRACTION MODULE ---
Main Sign encoding class

In [8]:
from abc import ABC, abstractmethod

class EmbeddingModule(ABC):
  def __init__(self):
    pass

  @abstractmethod
  def embed(self, *args, **kwargs) -> np.ndarray:
    pass

  @abstractmethod
  def extract(self, *args, **kwargs) -> list[int]:
    pass

  @abstractmethod
  def set_parameters(self, action):
    pass

class SignEncoding(EmbeddingModule):
  """Embeds and extracts messages from MDCT coefficients."""

  def set_parameters(self, action):
    return action[0]

  def embed(self, waveform, action, msg_bits: np.ndarray, **kwargs):
    """Embed message using sign encoding in non-critical coefficients"""
    magnitudes, phases = AudioPreprocessor.compute_mdct(waveform)
    self.mask = AudioPreprocessor.get_non_critical_coeffs(magnitudes)
    alpha = action
    # Apply mask to get non-critical coefficients
    coeffs = magnitudes.copy()
    non_critical = coeffs[self.mask]

    # Ensure we have enough coefficients for the message
    if len(non_critical) < len(msg_bits):
        raise ValueError("Message too long for available non-critical coefficients")

    # Embed message using sign encoding
    for i, bit in enumerate(msg_bits):
        sign = 1 if bit == 1 else -1
        non_critical[i] = sign * np.abs(non_critical[i]) * (1 + alpha)

    # Update coefficients
    coeffs[self.mask] = non_critical
    self.magnitudes = coeffs


    return AudioPreprocessor.reconstruct_audio(coeffs, phases, len(waveform))

  def extract(self, stego_waveform, bits_len, **kwargs):
    """Extract message from non-critical coefficients"""
    # Apply mask to get non-critical coefficients
    # magnitudes, _ = AudioPreprocessor.compute_mdct(stego_waveform)
    non_critical = self.magnitudes[self.mask]

    # Extract message from sign
    msg_bits = []
    for i in range(bits_len):
        sign = 1 if non_critical[i] >= 0 else -1
        bit = 1 if sign > 0 else 0
        msg_bits.append(bit)

    # Convert binary to string
    return msg_bits


### Spread Spectrum

In [34]:
class SpreadSpectrum(EmbeddingModule):
    def __init__(self):
        """Initialize the Spread Spectrum steganography system"""
        # Parameters for Gold sequence generation (example values)
        self.taps1 = [5, 2]
        self.taps2 = [5, 4, 2, 1]  # Must be a preferred pair with taps1
        self.seed1 = 0b11111
        self.seed2 = 0b10101

        self.action_ranges = {
            "carrier_freq": (5000, 15000),
            "chip_rate": (3, 200),
            "snr": (20, 150),
        }

    def _scale_action(self, normalized_val, low, high):
        """Scale from [-1, 1] to [low, high]"""
        return int(low + (normalized_val) * (high - low))

    def set_parameters(self, action):
        """
        Parameters(action):
          carrier_freq (int): Carrier frequency in Hz for embedding
          chip_rate (int): How many samples per bit (spreading factor)
          snr (int): Desired Signal to noise ratio in dB for embedding
        """
        # Scale normalized actions to original ranges
        carrier_freq = self._scale_action(
            action[0], *self.action_ranges["carrier_freq"]
        )
        chip_rate = self._scale_action(action[1], *self.action_ranges["chip_rate"])
        snr_db = self._scale_action(action[2], *self.action_ranges["snr"])
        return [carrier_freq, chip_rate, snr_db]

    def _generate_m_sequence(self, taps, length, initial_state):
        """Generate an m-sequence."""
        lfsr = initial_state
        seq = np.zeros(length)
        # mask = sum(1 << (t - 1) for t in taps) # Mask is not used in this implementation

        for i in range(length):
            seq[i] = 1 if (lfsr & 1) else -1  # Convert to bipolar (-1, 1)
            feedback = 0
            for tap in taps:
                feedback ^= (lfsr >> (tap - 1)) & 1
            # Adjust the shift based on the number of bits in the initial state
            lfsr = (lfsr >> 1) | (
                feedback << (len(bin(initial_state)) - 3)
            )  # Assuming initial_state is not 0

        return seq

    def _generate_spreading_code(self, length):
        """Generate a Gold sequence."""
        # The lengths of the m-sequences must be the same
        # and derived from the same primitive polynomial.
        # The taps define the primitive polynomials.
        # The seeds are the initial states of the LFSRs.

        m_seq1 = self._generate_m_sequence(self.taps1, length, self.seed1)
        m_seq2 = self._generate_m_sequence(self.taps2, length, self.seed2)

        # XOR the two m-sequences
        gold_seq = m_seq1 * m_seq2  # For bipolar sequences, XOR is multiplication

        return gold_seq

    def _int_to_bits(self, decimal_value, num_bits):
        """Convert a decimal value to a binary string of a fixed number of bits."""
        binary_string = bin(decimal_value)[2:].zfill(num_bits)
        return [int(bit) for bit in binary_string]

    def _bits_to_int(self, binary_string):
        """Convert a binary string to a decimal value."""
        return int(binary_string, 2)

    def _embed_bits_lsb(self, audio, bits_to_embed, start_sample=0):
        """Embed a sequence of bits into the least significant bits of audio samples."""
        audio_int = (audio * NORM_NUM).astype(np.int16)

        if len(bits_to_embed) > len(audio_int):
            raise ValueError("Not enough audio samples to embed all bits.")

        # Create a copy to modify
        audio_int_modified = audio_int.copy()

        # Embed the bits starting from start_sample
        for i, bit in enumerate(bits_to_embed):
            if start_sample + i < len(audio_int_modified):
                # Replace the LSB of each 16-bit sample
                sample_index = start_sample + i
                # Clear the LSB and set it to the new bit
                audio_int_modified[sample_index] = (
                    audio_int_modified[sample_index] & 0xFE
                ) | bit

        return audio_int_modified.astype(np.float32) / NORM_NUM

    def _extract_bits_lsb(self, audio, num_bits_to_extract, start_sample=0):
        """Extract a sequence of bits from the least significant bits of audio samples."""
        audio_int = (audio * NORM_NUM).astype(np.int16)

        if num_bits_to_extract + start_sample > len(audio_int):
            raise ValueError("Cannot extract more bits than available LSBs.")

        extracted_bits = []
        for i in range(num_bits_to_extract):
            sample_index = start_sample + i
            # Extract the LSB
            bit = audio_int[sample_index] & 1
            extracted_bits.append(bit)

        return extracted_bits

    def embed(self, waveform, msg_bits: np.ndarray, action, **kwargs):
        """
        Embed a message into an audio file using spread spectrum with LSB hiding of parameters.

        Parameters:
          waveform (ndarray): Original audio data
          msg_bits (ndarray): Message bits to embed
          action (tuple): Parameters for embedding (carrier_freq, chip_rate, snr_db)
        """
        # Set parameters from action
        carrier_freq, chip_rate, snr_db = action

        # print(f"Embedding parameters: carrier_freq={carrier_freq}, chip_rate={chip_rate}, snr_db={snr_db}, message_length={message_length}")

        msg_bits_bipolar = msg_bits * 2 - 1 # Convert message bits to bipolar (-1, 1)

        # Generate spreading codes
        code_length = len(msg_bits_bipolar) * chip_rate
        spreading_code = self._generate_spreading_code(code_length)

        # Create the spread message signal
        spread_message = np.repeat(msg_bits_bipolar, chip_rate) * spreading_code

        # Create carrier signal (sine wave at carrier frequency)
        t = np.arange(len(spread_message)) / cfg.SAMPLE_RATE
        carrier = np.sin(2 * np.pi * carrier_freq * t)

        modulated = spread_message * carrier  # Modulate the message onto the carrier

        # Adjust the signal power based on desired SNR
        signal_power = np.var(waveform)
        message_power = np.var(modulated)
        # Add a small epsilon to avoid division by zero if message_power is 0
        epsilon = 1e-8
        desired_message_power = signal_power / (10 ** (snr_db / 10))
        scaling_factor = np.sqrt(desired_message_power / (message_power + epsilon))
        modulated = modulated * scaling_factor

        stego_waveform = waveform.copy()
        # Pad or truncate the modulated signal to match audio length
        if len(modulated) < len(stego_waveform):
            modulated = np.pad(
                modulated, (0, len(stego_waveform) - len(modulated)), "constant"
            )
        else:
            modulated = modulated[: len(stego_waveform)]

        # Add the modulated signal to the audio (starting after parameter bits)
        stego_waveform = stego_waveform + modulated

        # Calculate message length in bits
        message_length = len(msg_bits)

        # Convert parameters to binary strings for LSB embedding
        carrier_freq_bits = self._int_to_bits(carrier_freq, 16)
        chip_rate_bits = self._int_to_bits(chip_rate, 8)
        snr_db_bits = self._int_to_bits(snr_db, 8)
        message_length_bits = self._int_to_bits(message_length, 16)

        # Concatenate all parameter bits
        param_bits = (
            carrier_freq_bits + chip_rate_bits + snr_db_bits + message_length_bits
        )
        stego_waveform = self._embed_bits_lsb(
            stego_waveform, param_bits, start_sample=0
        )

        return stego_waveform

    def extract_param_bits(self, bits, start=None, end=None) -> int:
        """Extract parameter bits from a list of bits."""
        if start is None:
            extracted = bits[:end]
        if end is None:
            extracted = bits[start:]
        extracted = bits[start:end]
        to_str = "".join(str(bit) for bit in extracted)
        return self._bits_to_int(to_str)

    def extract(self, stego_waveform, original_audio_path=None, **kwargs):
        """
        Extract a hidden message from a stego audio file.

        Parameters:
            stego_waveform (ndarray): Stego audio data
            message_length (int): Optional length of the hidden message in bits
            original_audio_path (str): Optional path to original audio for comparison
        """
        extracted_param_bits = self._extract_bits_lsb(
            stego_waveform, cfg.TOTAL_PARAM_SIZE, start_sample=0
        )

        start, end = 0, cfg.CARRIER_FREQ_SIZE
        carrier_freq = self.extract_param_bits(extracted_param_bits, start, end)
        start, end = cfg.CARRIER_FREQ_SIZE, cfg.CARRIER_FREQ_SIZE + cfg.CHIP_RATE_SIZE
        chip_rate = self.extract_param_bits(extracted_param_bits, start, end)
        start, end = (
            cfg.CARRIER_FREQ_SIZE + cfg.CHIP_RATE_SIZE,
            cfg.CARRIER_FREQ_SIZE + cfg.CHIP_RATE_SIZE + cfg.SNR_DB_SIZE,
        )
        snr_db = self.extract_param_bits(extracted_param_bits, start, end)
        start, end = cfg.CARRIER_FREQ_SIZE + cfg.CHIP_RATE_SIZE + cfg.SNR_DB_SIZE, None
        message_length = self.extract_param_bits(extracted_param_bits, start, end)

        # print(f"Extracted parameters: carrier_freq={carrier_freq}, chip_rate={chip_rate}, snr_db={snr_db}, message_length={message_length}")

        # If original audio is provided, subtract it to get just the message
        # if original_audio_path:
        #     y_original, _ = librosa.load(original_audio_path, sr=cfg.SAMPLE_RATE)
        #     # Ensure lengths match before subtraction
        #     min_len = min(len(stego_waveform), len(y_original))
        #     y_diff = stego_waveform[:min_len] - y_original[:min_len]
        # else:
        y_diff = stego_waveform

        # Generate the same spreading code used in embedding
        code_length = message_length * chip_rate
        spreading_code = self._generate_spreading_code(code_length)

        # Create carrier signal
        t = np.arange(len(spreading_code)) / cfg.SAMPLE_RATE
        carrier = np.sin(2 * np.pi * carrier_freq * t)

        # Pad or truncate the carrier to match the difference signal
        if len(carrier) < len(y_diff):
            carrier = np.pad(carrier, (0, len(y_diff) - len(carrier)), "constant")
        else:
            carrier = carrier[: len(y_diff)]

        # Demodulate the signal
        demodulated = y_diff * carrier

        # Correlate with spreading code to extract bits
        extracted_bits = []
        for i in range(message_length):
            start = i * chip_rate
            end = start + chip_rate
            if end > len(demodulated):
                break
            segment = demodulated[start:end]
            code_segment = spreading_code[start:end]

            # Calculate correlation
            correlation = np.sum(segment * code_segment)
            extracted_bits.append(1 if correlation > 0 else 0)

        return extracted_bits

# --- MODULE 3: RL ACTOR-CRITIC NETWORKS ---

In [12]:
class TransferPolicyNetwork(nn.Module):
  def __init__(self, pretrained_feature_extractor):
    super().__init__()
    self.feature_extractor = pretrained_feature_extractor
    for param in self.feature_extractor.parameters():
      param.requires_grad = False

    self.policy_head = nn.Sequential(
        nn.Linear(self.feature_extractor.combiner.out_features, 256),
        nn.ReLU(),
        nn.Linear(256, 1)
    )

  def forward(self, observations):
    features = self.feature_extractor(observations)
    return self.policy_head(features)

# Example usage with VGGish (audio feature extractor)
# vggish = torch.hub.load('harritaylor/torchvggish', 'vggish')
# policy_net = TransferPolicyNetwork(vggish).to(device)

# 7. Feature Extractor for PPO
class AudioFeatureExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: spaces.Dict):
        super().__init__(observation_space, features_dim=256)

        # Feature-specific processing
        self.mfcc_net = nn.Sequential(
            nn.Linear(observation_space['mfcc'].shape[0], 128),
            nn.ReLU()
        )
        self.sc_net = nn.Linear(1, 32)
        self.rms_net = nn.Linear(1, 32)
        self.zcr_net = nn.Linear(1, 32)
        self.msg_net = nn.Linear(1, 32)

        # Combined processing
        self.combined = nn.Sequential(
            nn.Linear(128+32*4, 256),
            nn.ReLU()
        )

    def forward(self, observations: dict[str, torch.Tensor]) -> torch.Tensor:
        mfcc = self.mfcc_net(observations["mfcc"])
        sc = self.sc_net(observations["spectral_centroid"])
        rms = self.rms_net(observations["rms"])
        zcr = self.zcr_net(observations["zcr"])
        msg = self.msg_net(observations["bits_len"])

        combined = torch.cat([mfcc, sc, rms, zcr, msg], dim=1)
        return self.combined(combined)

# --- MODULE 4: ENVIRONMENT NETWORK (STEGANALYZER) ---

### Transfer Steganalysis

In [13]:
# Function to generate a random message (for the RL environment)
def generate_random_message(length=50):
    """Generates a random string message."""
    import string
    letters = string.ascii_letters + string.digits + string.punctuation + " "
    return ''.join(random.choice(letters) for i in range(length))

def extract_audio_features(waveform, bits_len):
    features = {}
    features['mfcc'] = librosa.feature.mfcc(
        y=waveform, sr=cfg.SAMPLE_RATE,
        n_mfcc=cfg.N_MELS,
        n_fft=cfg.FRAME_SIZE, hop_length=cfg.HOP_LENGTH
    ).mean(axis=1)

    features['spectral_centroid'] = librosa.feature.spectral_centroid(
        y=waveform, sr=cfg.SAMPLE_RATE
    ).mean(keepdims=True)

    features['rms'] = librosa.feature.rms(y=waveform).mean(keepdims=True)
    features['zcr'] = librosa.feature.zero_crossing_rate(y=waveform).mean(keepdims=True)
    features["bits_len"] = bits_len/1000

    return features

In [14]:
class TransferSteganalysis(nn.Module):
  def __init__(self, pretrained_cnn):
    super().__init__()
    self.features = pretrained_cnn.features[:8] # first 8 layers
    self.classifier = nn.Sequential(
        nn.AdaptiveAvgPool1d(output_size=1),
        nn.Flatten(),
        nn.Linear(512, 1),
        nn.Sigmoid()
    )

  def forward(self, x):
    x = self.features(x)
    return self.classifier(x)

# Example with pre-trained audio CNN
# pretrained = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
# steganalysis_net = TransferSteganalysis(pretrained).to(device)

### Custom Environment

In [18]:
 #5. Custom Gym Environment for RL Training
class AudioStegoEnv(gym.Env):
  def __init__(self, dataset, method="sign-encoding"):
    super(AudioStegoEnv, self).__init__()
    self.dataset = dataset

    self.action_space = spaces.Box(low=0.0, high=0.1, shape=(1,), dtype=np.float32)
    if method == "spread-spectrum":
        self.action_space = spaces.Box(low=0, high=1, shape=(3,), dtype=np.float32)

    # Define observation space as a dictionary for MultiInputPolicy
    mfcc_dim = cfg.N_MELS
    spectral_centroid_dim = 1 # mean over frames
    rms_dim = 1 # mean over frames
    zcr_dim = 1 # mean over frames
    msg_dim = 1

    self.observation_space = spaces.Dict({
            "mfcc": spaces.Box(low=-np.inf, high=np.inf, shape=(cfg.N_MELS,)),
            "spectral_centroid": spaces.Box(low=0, high=np.inf, shape=(1,)),
            "rms": spaces.Box(low=0, high=np.inf, shape=(1,)),
            "zcr": spaces.Box(low=0, high=np.inf, shape=(1,)),
            "bits_len": spaces.Box(low=0, high=np.inf, shape=(1,)),
        })

    self.embedder: EmbeddingModule = SignEncoding() if method == "sign-encoding" else SpreadSpectrum()

    # Initialize state
    # self.reset()

  def reset_with_audio(self, audio_path, msg):
    self.audio_path = audio_path
    self.original_waveform = AudioPreprocessor.load_audio(audio_path)
    self.msg_bits = string_to_bits(msg)
    self.bits_len = len(self.msg_bits)
    self.current_waveform = self.original_waveform.copy()
    self.current_step = 0
    return self._get_obs(), {} # Return observation and info dictionary

  def reset(self, seed=None, options=None):
    """Reset environment to initial state"""
    super().reset(seed=seed)
    msg_len = int(random.random()*100 + 50)
    self.msg = generate_random_message(msg_len)
    self.msg_bits = string_to_bits(self.msg)
    self.bits_len = len(self.msg_bits)
    tensor, sr, _ = random.choice(self.dataset)
    self.original_waveform = AudioPreprocessor.resample_audio(tensor[0].numpy(), sr=sr)

    self.current_waveform = self.original_waveform.copy()
    # self.current_step = 0
    return self._get_obs(), {} # Return observation and info dictionary

  def _get_obs(self):
    """Extract features and concatenate them as observation"""
    return extract_audio_features(self.current_waveform, self.bits_len)



  def step(self, action):
    """Execute one embedding step"""
    action = self.embedder.set_parameters(action)

    # Embed message
    self.current_waveform = self.embedder.embed(msg_bits = self.msg_bits, waveform=self.original_waveform, action=action)

    # Compute rewards
    snr = self._calculate_snr()
    # psnr = self._calculate_psnr()
    # detection_prob = self._simulate_detection()
    accuracy, ber = self._calculate_accuracy_ber()
    reward = self._calculate_reward(snr,ber, 0, accuracy)

    # Update state
    # self.current_step += 1
    # done = self.current_step >= 10
    info = {
        "snr": snr,
        # "psnr": psnr,
        "ber": ber,
        # "detection_prob": detection_prob,
        "extraction_accuracy": accuracy,
        "reward": reward,
        "action": action
    }
    return self._get_obs(), reward, True, None, info

  def _calculate_snr(self):
    """Calculate Signal-to-Noise Ratio"""
    # Ensure both audio arrays have the same length before calculating noise
    min_len = min(len(self.current_waveform), len(self.original_waveform))
    current_audio_trimmed = self.current_waveform[:min_len]
    original_audio_trimmed = self.original_waveform[:min_len]

    noise = current_audio_trimmed - original_audio_trimmed
    signal_power = np.mean(original_audio_trimmed ** 2)
    noise_power = np.mean(noise ** 2)

    if noise_power == 0:
        return 100  # High SNR if no noise

    return 10 * np.log10(signal_power / noise_power)

  # def _calculate_psnr(self):
  #   if len(self.current_waveform) < len(self.original_waveform):
  #       stego_waveform_padded = np.pad(self.current_waveform, (0, len(self.original_waveform) - len(self.current_waveform)), 'constant')
  #   else:
  #       stego_waveform_padded = self.current_waveform[:len(self.original_waveform)]

  #   psnr = 10 * np.log10(np.max(self.original_waveform ** 2) / np.mean((self.original_waveform - stego_waveform_padded) ** 2)) if np.mean((self.original_waveform - stego_waveform_padded) ** 2) > 0 else float('inf')
  #   return psnr


  def _simulate_detection(self):
    pass
  #   """Simulate steganalysis detection"""
  #   # not real
  #   snr = self._calculate_snr()
  #   return 1 / (1 + np.exp(0.5 * (snr - 50)))  # Logistic function

  def _calculate_accuracy_ber(self):
    """Calculate the accuracy of extracted bits"""
    extracted_bits = self.embedder.extract(stego_waveform =self.current_waveform, bits_len=self.bits_len)
    bits_len = len(extracted_bits)

    min_len_bits = min(bits_len, self.bits_len)
    ber = np.mean([self.msg_bits[i] != extracted_bits[i] for i in range(min_len_bits)]) if min_len_bits > 0 else 1.0 # BER is 1 if no bits to compare
    extracted_str = bits_to_string(extracted_bits)
    print(f"Original message: {self.msg}")
    print(f"Extracted message: {extracted_str}")
    str_len = len(extracted_str)
    bits_len = len(self.msg)
    min_len_str = min(str_len, bits_len)
    str_ber = np.mean([self.msg[i] != extracted_str[i] for i in range(min_len_str)]) if min_len_str > 0 else 1.0 # BER is 1 if no bits to compare

    return 1.0 - str_ber, ber

  def _calculate_reward(self, snr,ber, detection_prob, extraction_accuracy):
    """Calculate reward balancing SNR, detectability, and extraction accuracy"""
    # Target: SNR > 50 dB, detection_prob < 0.1, extraction_accuracy close to 1.0
    snr_reward = min(snr / 50, 1.0)
    # detect_reward = 1.0 - min(detection_prob, 1.0)

    # Weighted combination
    return (
        cfg.SNR_WEIGHTS * snr_reward - ber +
        # cfg.DETECTION_WEIGHT * detect_reward +
        cfg.EXTRACTION_ACCURACY_WEIGHT * extraction_accuracy
    )

# -- Prepare sample audios --

In [16]:
import os
import IPython
import matplotlib.pyplot as plt

sample_dir = "_assets"
os.makedirs(sample_dir, exist_ok=True)

def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
  waveform = waveform.numpy()
  figure, ax = plt.subplots()
  ax.specgram(waveform[0], Fs=sample_rate)
  figure.suptitle(title)
  figure.tight_layout()

dataset = torchaudio.datasets.YESNO(sample_dir, download=True)
sample_path = "_assets/waves_yesno/0_0_0_0_1_1_1_1.wav"

100%|██████████| 4.49M/4.49M [00:00<00:00, 6.33MB/s]
  tar.extract(file_, to_path)


In [33]:
# i = 8
# waveform, sr, label = dataset[i]
# processor = AudioPreprocessor()
# waveform = AudioPreprocessor.resample_audio(waveform[0].numpy(), sr)
# features = extract_audio_features(waveform, 500)
# print(waveform.shape)
# mod = waveform.copy()[:-98]
# mod.shape

(137592,)


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


(137494,)

In [36]:

# env = AudioStegoEnv(dataset, method="spread-spectrum")
# # env = AudioStegoEnv(dataset, method="sign-encoding")
# obs, info = env.reset()
# obs, reward, done, truncated, info = env.step([0.5, 0.2, 0])
# info

Original message: 2 )m+igg<6>1&J<:{3LME?S$X\&5]lkPl{M@YuCW}j;\mogr7<Fo!`TdkS#/Xk,_/#rq^f`wK*!=-"-'y LD^}Ocm1Z58
Extracted message: 2 )m+igg<6>1&J<:{3LME?S$X\&5]lkPl{M@YuCW}j;\mogr7<Fo!`TdkS#/Xk
]o{1Ã1Cx?Aí=*#y LD^}Ocm1Z58


{'snr': np.float32(24.790276),
 'ber': np.float64(0.06048387096774194),
 'extraction_accuracy': np.float64(0.8064516129032258),
 'reward': np.float64(0.6217093035098045),
 'action': [10000, 42, 20]}

# -- MAIN FRAMEWORK --

In [None]:
class RLAudioSteganography:
  """Main framework class"""
  def __init__(self, cfg: Config, method="sign-encoding") -> None:
    self.cfg = cfg
    self.method = method
    self.embedder: EmbeddingModule = SignEncoding() if method == "sign-encoding" else SpreadSpectrum()


  def Initialize_components(self, audio_path, method="sign-encoding"):
    """Initialize components"""
    # self.steganalysis_net = SteganalysisCNN()
    # self.steganalysis_net = SteganalysisCNN(input_channels=len(self.preprocessor.audio))
    # self.steganalysis_net.to(device)

  def train_ppo(self, total_timesteps=10000)-> PPO:
  # def train_ppo(self, total_timesteps=25000)-> PPO:
    """Train PPO agent for audio steganography"""
    env = make_vec_env(lambda: AudioStegoEnv(dataset, method=self.method), n_envs=4)

    policy_kwargs = dict(
        features_extractor_class=AudioFeatureExtractor,
        features_extractor_kwargs=dict(),
        net_arch=dict(pi=[256, 256], vf=[256, 256])
    )
    model = PPO(
        "MultiInputPolicy",
        env,
        policy_kwargs=policy_kwargs,
        verbose=1,
        device=device,
        learning_rate=cfg.LEARNING_RATE,
      )
    # Instantiate the custom callback
    custom_callback = CustomLoggingCallback()
    model.learn(cfg.EPISODES, custom_callback)

    return model

  def embed_message(self, audio_path, message, output_path, model):
    """Embed a message into an audio file using trained policy"""
    # Re-initialize preprocessor and compute magnitudes/phases/mask for the specific audio being embedded into
    waveform = AudioPreprocessor.load_audio(audio_path)

    msg_bits = string_to_bits(message)
    env = AudioStegoEnv([],self.method) # Use the correct audio_path and message
    obs, info = env.reset_with_audio(audio_path, message)
    action, _ = model.predict(obs, deterministic=True)
    action = self.embedder.set_parameters(action[0])
    # Convert message string to bits for embedding
    message_bits = string_to_bits(message)
    stego_waveform = self.embedder.embed( msg_bits = msg_bits, action = action, waveform=waveform)
    AudioPreprocessor.save_audio(stego_waveform,cfg.SAMPLE_RATE, output_path)
    print(f"Stego audio saved to {output_path}")


  def extract_message(self, stego_audio_path, msg_length):
    """Extract a message from a stego audio file"""
    # Load the stego audio
    waveform = AudioPreprocessor.load_audio(stego_audio_path)

    extracted_bits = self.embedder.extract(stego_waveform=waveform, message_length= msg_length)
    return bits_to_string(extracted_bits)

# --- MAIN EXECUTION SCRIPT ---

In [None]:
# Initialize the framework
cfg = Config()
framework = RLAudioSteganography(cfg, method="spread-spectrum")

# --- Training Phase ---

# Initialize components with the training audio
# framework.Initialize_components(sample_path, method="spread-spectrum")

# Train PPO agent
print("Training PPO agent...")
model = framework.train_ppo()

# Save the trained model
model_save_path = "_assets/ppo_audio_stego_model"
model.save(model_save_path)
print(f"Trained model saved to {model_save_path}")

Training PPO agent...
Using cuda device
-----------------------------------------------
| rollout/                  |                 |
|    ep_action              | [12663, 10, 20] |
|    ep_ber                 | 0.00714         |
|    ep_extraction_accuracy | 0.986           |
|    ep_len_mean            | 1               |
|    ep_rew_mean            | 0.311           |
|    ep_reward              | 0.857           |
|    ep_snr                 | 34.06507        |
| time/                     |                 |
|    fps                    | 4               |
|    iterations             | 1               |
|    time_elapsed           | 1976            |
|    total_timesteps        | 8192            |
-----------------------------------------------
-----------------------------------------------
| rollout/                  |                 |
|    ep_action              | [5000, 10, 140] |
|    ep_ber                 | 0.498           |
|    ep_extraction_accuracy | 0.0127          |


In [None]:
# auto download model to local storage
from google.colab import files
files.download(f"{model_save_path}.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
model_name = f"kratosgado/Audio_Steganography_PPO_model/pyTorch/no_steganalysis"
model_files = "_assets"
kagglehub.model_upload(model_name, model_files, 'Apache 2.0')

Uploading Model https://www.kaggle.com/models/kratosgado/Audio_Steganography_PPO_model/pyTorch/no_steganalysis ...
More than 50 files detected, creating a zip archive...
Starting upload for file /tmp/tmpwlk6f1ku/archive.zip


Uploading: 100%|██████████| 14.8M/14.8M [00:00<00:00, 17.5MB/s]

Upload successful: /tmp/tmpwlk6f1ku/archive.zip (14MB)





Your model instance version has been created.
Files are being processed...
See at: https://www.kaggle.com/models/kratosgado/Audio_Steganography_PPO_model/pyTorch/no_steganalysis


In [None]:
cfg = Config()
framework_test = RLAudioSteganography(cfg, method="spread-spectrum")

message = """
    Embed a message into an audio file.

    Parameters:
      audio_path (str): Path to the audio file.
      message (str): The message to be embedded.
      output_path (str): Path to save the embedded audio file.
    """
output_path = "test.wav"

# framework_test.Initialize_components(audio_path=sample_path, method="spread-spectrum")
# Load the saved model
model_test = PPO.load(model_save_path)
print(f"Loaded model from {model_save_path}")
# Embed message using the loaded model
framework_test.embed_message(sample_path, message, output_path, model_test)

# Extract message from the new stego audio
extracted_message = framework_test.extract_message(output_path, len(message))
print(f"\nOriginal Message: {message}")
print(f"Extracted Message: {extracted_message}")

Loaded model from _assets/ppo_audio_stego_model


  import aifc
  import audioop
  import sunau


Stego audio saved to test.wav

Original Message: 
    Embed a message into an audio file.

    Parameters:
      audio_path (str): Path to the audio file.
      message (str): The message to be embedded.
      output_path (str): Path to save the embedded audio file.
    
Extracted Message: º    Embed a message into an audio file.

    Parameters:
      audio_path (str): Path to the audio file.
      messafe (str): The message to be embedded.
      output_path (str): Path to save the embedded audio file.
    
