<a href="https://colab.research.google.com/github/Kratosgado/audio-steganography/blob/main/steg-ai/core_modules/Final_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install stable-baselines3 shimmy --quiet

In [2]:
import os
!git clone https://github.com/librosa/data.git ./audio_data
# get audio files from librose
audios_path = "./audio_data/audio"
simple_audio_files = [os.path.join(audios_path, f) for f in os.listdir(audios_path) if f.endswith(".ogg")]
complex_audio_files = [os.path.join(audios_path, f) for f in os.listdir(audios_path) if f.endswith(".ogg")]

Cloning into './audio_data'...
remote: Enumerating objects: 156, done.[K
remote: Counting objects: 100% (156/156), done.[K
remote: Compressing objects: 100% (117/117), done.[K
remote: Total 156 (delta 61), reused 125 (delta 34), pack-reused 0 (from 0)[K
Receiving objects: 100% (156/156), 14.78 MiB | 33.05 MiB/s, done.
Resolving deltas: 100% (61/61), done.


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import librosa
import soundfile as sf
import numpy as np
import matplotlib.pyplot as plt
from torch.distributions import Normal
from scipy.signal import get_window
from scipy.fft import dct, idct
import scipy.io.wavfile as wf
import scipy
import torch.nn.functional as F
from collections import deque
import gymnasium as gym
from gymnasium import spaces
import random
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.callbacks import BaseCallback

In [4]:
# warnings.filterwarnings("ignore")
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
FRAME_SIZE = 1024 #2048
HOP_LENGTH = 256 #512
NON_CRITICAL_PERCENT = 0.1 # Modify bottom 10% of coeffs by magnitude
STATE_DIM = 1024
SAMPLE_RATE = 22050
N_MELS = 128

Using device: cuda


# --- UTILITY FUNCTIONS ---

In [5]:
NORM_NUM = 32768
def textToBits( message):
  return ''.join(format(ord(ch), '08b') for ch in message)


def bitToChar(bits):
  return chr(int(bits, 2))

def textFromBits(bits):
  text = ''
  for bit in range(0, len(bits), 8):
    text += bitToChar(bits[bit:bit + 8])
  return text

# Function to generate a random message (for the RL environment)
def generate_random_message(length=50):
    """Generates a random string message."""
    import string
    letters = string.ascii_letters + string.digits + string.punctuation + " "
    return ''.join(random.choice(letters) for i in range(length))
def extract_state(audio_segment, message_context, sr: int | float=44100):
  # Audio features
  mfcc = librosa.feature.mfcc(y=audio_segment, sr=sr)
  spectral_centroid = librosa.feature.spectral_centroid(y=audio_segment, sr=sr)

  # Temporal features
  rms = librosa.feature.rms(y=audio_segment)
  zcr = librosa.feature.zero_crossing_rate(y=audio_segment)

  # Normalize and flatten features
  return np.concatenate([mfcc.mean(axis=1),spectral_centroid.flatten(), rms.flatten(),zcr.flatten(),message_context ])

class CustomLoggingCallback(BaseCallback):
  """ A custom callback that logs additional information from the environment. """
  def __init__(self, verbose=0):
    super(CustomLoggingCallback, self).__init__(verbose)

  def _on_step(self) -> bool:
    """This method is called after each step in the environment.   """
    # Accessing environment infos from the VecEnv wrapper
    if self.locals.get('infos'):
      for info in self.locals['infos']:
        # Ensure info is a dictionary and contains the required keys
        if isinstance(info, dict) and 'snr' in info and 'detection_prob' in info:
          self.logger.record('rollout/ep_snr', info['snr'])
          self.logger.record('rollout/ep_reward', info['reward'])
          self.logger.record('rollout/ep_snr', info['psnr'])
          self.logger.record('rollout/ep_ber', info['ber'])
          self.logger.record('rollout/ep_extraction_accuracy', info['extraction_accuracy'])
          self.logger.record('rollout/ep_detection_prob', info['detection_prob'])
          self.logger.record('rollout/ep_action', info['action'])
    return True

# --- HYPERPARAMETERS ---

In [6]:
class Config:
  """Class to hold all hyperparameters."""
  #Audio processing
  FRAME_SIZE = 1024 #2048
  HOP_LENGTH = 256 #512
  NON_CRITICAL_PERCENT = 0.1 # Modify bottom 10% of coeffs by magnitude
  STATE_DIM = 1024
  SAMPLE_RATE = 22050
  N_MELS = 128

  # RL Training
  EPISODES = 1000
  LEARNING_RATE_ACTOR = 3e-4
  LEARNING_RATE_CRITIC = 3e-4
  GAMMA = 0.99 # Discount factor
  GAE_LAMBDA = 0.95 # Lambda for Generalized Advantage Estimation
  PPO_EPSILON = 0.2 # Epsilon for clipping in PPO
  PPO_EPOCHS = 10 # Number of epochs for PPO update
  BATCH_SIZE = 32

  # Environment Network Pre-training
  PRETRAIN_EPOCHS = 10
  PRE_TRAIN_LR = 1e-3
  PRE_TRAIN_BATCH_SIZE = 32

  # Reward weights
  SNR_WEIGHTS = 0.02 # Weight for SNR in reward. Tuned to be on a similar scale to detection prob.
  DETECTION_WEIGHT = 1.0 # Weight for detection probability in reward.
  IMPERCEPTIBILITY_WEIGHT = 0.3
  UNDETECTABILITY_WEIGHT = 0.4
  EXTRACTION_ACCURACY_WEIGHT = 0.3
cfg = Config()

# --- MODULE 1: AUDIO PREPROCESSOR ---

In [7]:
class AudioPreprocessor:
  """Handles audio loading, MDCT, and inverse MDCT."""
  def __init__(self, audio_path= None, audio_data=None, frame_size=cfg.FRAME_SIZE, hop_length=cfg.HOP_LENGTH, sr=cfg.SAMPLE_RATE):
    if audio_path:
      self.audio, self.sr = librosa.load(audio_path, sr=sr, mono=True)
    elif audio_data is not None:
      self.audio = audio_data
      self.sr = sr
    else:
      raise ValueError("Either audio_path or audio_data must be provided")

    # Normalize audio
    self.audio = self.audio / np.max(np.abs(self.audio))
    self.frame_size = frame_size
    self.hop_length = hop_length
    self.window = get_window('hann', self.frame_size)
    # return self.audio, self.sr

  def load_audio(self, path):
    """Load WAV audio file"""
    audio, _ = librosa.load(path, sr=self.sr)
    return audio

  def stft(self, audio):
    """Compute Short-Time Fourier Transform (STFT)"""
    return librosa.stft(audio, n_fft=self.frame_size, hop_length=self.hop_length)

  @staticmethod
  def istft(stft_matrix):
    """Compute Inverse Short-Time Fourier Transform (ISTFT)"""
    return librosa.istft(stft_matrix, hop_length=cfg.HOP_LENGTH)

  def compute_mdct(self):
    """Compute Modified Discrete Cosine Transform (MDCT) using STFT and DCT from librosa"""
    stft = self.stft(self.audio)
    magnitudes = np.abs(stft)
    phases = np.angle(stft)
    return magnitudes, phases

  def get_non_critical_coeffs(self, magnitudes, percentile=10):
    """Identify non-critical coefficients (lowest magnitude)"""
    threshold = np.percentile(np.abs(magnitudes.flatten()), percentile)
    mask = np.abs(magnitudes) < threshold
    return mask

  @staticmethod
  def reconstruct_audio(magnitudes, phases):
    """Reconstruct audio from magnitude/stft matrix and phase/non_critical_coeffs"""
    stft = magnitudes * np.exp(1j * phases)
    reconstructed_audio = AudioPreprocessor.istft(stft)
    return reconstructed_audio

  def plot_spectrogram(self, title):
    """Visualize audio spectrogram"""
    plt.figure(figsize=(10, 4))
    S = librosa.amplitude_to_db(np.abs(librosa.stft(self.audio)), ref=np.max)
    librosa.display.specshow(S, sr=self.sr, x_axis='time', y_axis='log')
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.tight_layout()
    plt.show()

  def save_audio(self, audio: np.ndarray, sr: int, path: str):
    wf.write(path,sr, (audio * NORM_NUM).astype(np.int16))

# --- MODULE 5: EMBEDDING/EXTRACTION MODULE ---
Main Sign encoding class

In [8]:
from abc import ABC, abstractmethod

class EmbeddingModule(ABC):
  def __init__(self):
    pass

  @abstractmethod
  def embed(self, *args, **kwargs) -> np.ndarray:
    pass

  @abstractmethod
  def extract(self, *args, **kwargs) -> list[int]:
    pass

  @abstractmethod
  def set_parameters(self, action):
    pass

class SignEncoding(EmbeddingModule):
  """Embeds and extracts messages from MDCT coefficients."""
  def __init__(self, steganalysis_net=None):
    """Initialize the steganography system"""
    self.steganalysis_net = steganalysis_net
    if self.steganalysis_net:
        self.steganalysis_net.to(device)
        self.steganalysis_net.eval()

  def set_parameters(self, action):
    self.alpha = action[0]

  def embed(self, magnitudes,phases, mask, msg_bits: np.ndarray, **kwargs):
    """Embed message using sign encoding in non-critical coefficients"""
    # Apply mask to get non-critical coefficients
    coeffs = magnitudes.copy()
    non_critical = coeffs[mask]

    # Ensure we have enough coefficients for the message
    if len(non_critical) < len(msg_bits):
        raise ValueError("Message too long for available non-critical coefficients")

    # Embed message using sign encoding
    for i, bit in enumerate(msg_bits):
        sign = 1 if bit == 1 else -1
        non_critical[i] = sign * np.abs(non_critical[i]) * (1 + self.alpha)

    # Update coefficients
    coeffs[mask] = non_critical
    self.magnitudes = coeffs

    return AudioPreprocessor.reconstruct_audio(coeffs, phases)

  def extract(self, magnitudes, mask, message_length, **kwargs):
    """Extract message from non-critical coefficients"""
    # Apply mask to get non-critical coefficients
    non_critical = self.magnitudes[mask]

    # Extract message from sign
    msg_bits = []
    for i in range(message_length * 8):
        sign = 1 if non_critical[i] >= 0 else -1
        bit = 1 if sign > 0 else 0
        msg_bits.append(bit)

    # Convert binary to string
    return msg_bits
    # chars = [chr(int(bin_str[i:i+8], 2)) for i in range(0, len(bin_str), 8)]

  def detect(self, audio):
    """Compute detection probability using steganalysis network"""
    if not self.steganalysis_net:
      raise ValueError("Steganalysis network not initialized")

    # Convert to tensor
    audio_tensor = torch.as_tensor(audio, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

    with torch.no_grad():
      prob = self.steganalysis_net(audio_tensor).item()
    return prob

### Spread Spectrum

In [24]:
class SpreadSpectrum(EmbeddingModule):
  def __init__(self):
    """
    Initialize the Spread Spectrum steganography system

    Parameters:
      carrier_freq (int): Carrier frequency in Hz for embedding
      chip_rate (int): How many samples per bit (spreading factor)
      snr (int): Desired Signal to noise ratio in dB for embedding
    """
    # Parameters for Gold sequence generation (example values)
    self.taps1 = [5, 2]
    self.taps2 = [5, 4, 2, 1] # Must be a preferred pair with taps1
    self.seed1 = 0b11111
    self.seed2 = 0b10101

  def set_parameters(self, action):
    """
    Parameters(action):
      carrier_freq (int): Carrier frequency in Hz for embedding
      chip_rate (int): How many samples per bit (spreading factor)
      snr (int): Desired Signal to noise ratio in dB for embedding
    """
    self.carrier_freq, self.chip_rate, self.snr_db = action;
    self.chip_rate = int(self.chip_rate)

  def _generate_m_sequence(self, taps, length, initial_state):
      """Generate an m-sequence."""
      lfsr = initial_state
      seq = np.zeros(length)
      # mask = sum(1 << (t - 1) for t in taps) # Mask is not used in this implementation

      for i in range(length):
          seq[i] = 1 if (lfsr & 1) else -1  # Convert to bipolar (-1, 1)
          feedback = 0
          for tap in taps:
              feedback ^= (lfsr >> (tap - 1)) & 1
          # Adjust the shift based on the number of bits in the initial state
          lfsr = (lfsr >> 1) | (feedback << (len(bin(initial_state)) - 3)) # Assuming initial_state is not 0

      return seq

  def _generate_spreading_code(self, length):
      """Generate a Gold sequence."""
      # The lengths of the m-sequences must be the same
      # and derived from the same primitive polynomial.
      # The taps define the primitive polynomials.
      # The seeds are the initial states of the LFSRs.

      m_seq1 = self._generate_m_sequence(self.taps1, length, self.seed1)
      m_seq2 = self._generate_m_sequence(self.taps2, length, self.seed2)

      # XOR the two m-sequences
      gold_seq = m_seq1 * m_seq2 # For bipolar sequences, XOR is multiplication

      return gold_seq


  def _calculate_required_bandwidth(self, message_length):
    """Calculate the required bandwidth based on message length."""
    # each bit is spread over chip_rate samples
    return message_length * self.chip_rate

  def embed(self, original_audio, msg_bits: np.ndarray, **kwargs):
    """
    Embed a message into an audio file.

    Parameters:
      original_audio (str): Path to the audio file.
      message (str): The message to be embedded.
      output_path (str): Path to save the embedded audio file.
    """
    msg_bits = msg_bits * 2 - 1 # convert to bipoloar (-1, 1)

    # generate spreading codes
    code_length  = len(msg_bits) * self.chip_rate
    spreading_code = self._generate_spreading_code(code_length)

    # create the spread message signal
    spread_message = np.repeat(msg_bits, self.chip_rate) * spreading_code

    # create carrier signal (sine wave at carrier frequency)
    t = np.arange(len(spread_message)) / cfg.SAMPLE_RATE
    carrier = np.sin(2 * np.pi * self.carrier_freq * t)

    # modulate the message onto the carrier
    modulated = spread_message * carrier

    # adjust the signal power based on desired snr
    signal_power = np.var(original_audio)
    message_power = np.var(modulated)
    desired_message_power = signal_power/ (10 ** (self.snr_db / 10))
    scaling_factor = np.sqrt(desired_message_power / message_power)
    modulated = modulated * scaling_factor

    # pad or truncate the modulated signal to match audio length
    if len(modulated) < len(original_audio):
      modulated = np.pad(modulated, (0, len(original_audio) - len(modulated)), 'constant')
    else:
      modulated = modulated[:len(original_audio)]

    # embed the message into the audio
    stego_audio = original_audio + modulated
    return stego_audio

  def extract(self, stego_audio, message_length, original_audio_path=None, **kwargs):
    """
        Extract a hidden message from a stego audio file.

        Parameters:
            stego_audio (nd.ndarray): Path to the stego audio file
            message_length (int): Length of the hidden message in bits
            original_audio_path (str): Optional path to original audio for comparison
        """
    # if original audio is provided, subtract it to get just the message
    if original_audio_path:
      y_original, _ = librosa.load(original_audio_path, sr=cfg.SAMPLE_RATE)
      y_diff = stego_audio - y_original
    else:
      y_diff = stego_audio

    # generate the same spreading code used in embedding
    code_length = message_length * 8 * self.chip_rate # 8 bits per character
    # spreading_code = self._generate_spreading_code(code_length, seed=42) # Old simplified code
    spreading_code = self._generate_spreading_code(code_length)


    # create carrier signal
    t = np.arange(len(spreading_code)) / cfg.SAMPLE_RATE
    carrier = np.sin(2 * np.pi * self.carrier_freq * t)

    # pad or truncate the carrier to match the difference signal
    if len(carrier) < len(y_diff):
      carrier = np.pad(carrier, (0, len(y_diff) - len(carrier)), 'constant')
    else:
      carrier = carrier[:len(y_diff)]

    # demodulate the signal
    demodulated = y_diff * carrier

    # correlate with spreading code to extract bits
    extracted_bits: list[int] = []
    for i in range(message_length * 8):
      start = i * self.chip_rate
      end = start + self.chip_rate
      if end > len(demodulated):
        break
      segment = demodulated[start:end]
      code_segment = spreading_code[start:end]

      # calculate correlation
      correlation = np.sum(segment * code_segment)
      extracted_bits.append(1 if correlation > 0 else 0)

    # convert bits to string
    return extracted_bits

# --- MODULE 2: RL ACTOR-CRITIC NETWORKS ---

In [10]:
class PolicyNetwork(nn.Module):
  """Actor Network: Decides on the modification scale."""
  def __init__(self, input_dim, hidden_dim=256):
    super(PolicyNetwork, self).__init__()
    self.fc = nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU()
    )
    self.mean_layer = nn.Linear(hidden_dim, 1)
    self.log_std_layer = nn.Linear(hidden_dim, 1)

  def forward(self, x):
    features = self.fc(x)
    mean = self.mean_layer(features)
    log_std = self.log_std_layer(features)
    log_std = torch.clamp(log_std, min=-20, max=2)  # Constrain for stability
    return mean, log_std

  def sample_action(self, state):
    """Sample action from policy distribution"""
    state_tensor = torch.as_tensor(state, dtype=torch.float32).to(device)
    mean, log_std = self.forward(state_tensor)
    std = torch.exp(log_std)
    normal_dist = torch.distributions.Normal(mean, std)
    action = normal_dist.sample()
    return action.cpu().detach().numpy().flatten()

# 7. Feature Extractor for PPO
class CustomFeatureExtractor(BaseFeaturesExtractor):
  def __init__(self, observation_space: spaces.Box, features_dim = 128):
    super().__init__(observation_space, features_dim)
    self.net = nn.Sequential(
        nn.Linear(observation_space.shape[0], 256),
        nn.ReLU(),
        nn.Linear(256, features_dim),
        nn.ReLU()
    )

  def forward(self, observations: torch.Tensor) -> torch.Tensor:
    return self.net(observations)

# --- MODULE 3: ENVIRONMENT NETWORK (STEGANALYZER) ---

In [18]:
class SteganalysisCNN(nn.Module):
  """A 1D CNN that acts as a steganalysis tool."""
  def __init__(self, input_channels=1):
    super(SteganalysisCNN, self).__init__()
    self.conv = nn.Sequential(
        nn.Conv1d(input_channels, 32, kernel_size=5, stride=2, padding=2),
        nn.ReLU(),
        nn.MaxPool1d(kernel_size=2),
        nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=2),
        nn.ReLU(),
        nn.MaxPool1d(kernel_size=2),
        nn.Flatten()
    )

    # Calculate output dimension after convolutions
    test_input = torch.randn(1, input_channels, 500)
    conv_output_dim = self.conv(test_input).shape[-1]

    self.fc = nn.Sequential(
        nn.Linear(conv_output_dim, 64),
        nn.ReLU(),
        nn.Linear(64, 1),
        nn.Sigmoid()
    )

  def forward(self, x):
    x = self.conv(x)
    return self.fc(x)

 #5. Custom Gym Environment for RL Training
class AudioStegoEnv(gym.Env):
  def __init__(self, audio_path, msg_bits: np.ndarray, method="sign-encoding"):
    super(AudioStegoEnv, self).__init__()

    # Initialize audio and message
    self.preprocessor = AudioPreprocessor(audio_path=audio_path)
    self.original_audio = self.preprocessor.audio.copy()
    self.msg_bits = msg_bits

    self.action_space = spaces.Box(low=0.0, high=0.1, shape=(1,), dtype=np.float32)

    # Define the action space as a Tuple of Box and Discrete spaces
    if method == "spread-spectrum":
      self.action_space = spaces.Box(low=np.array([5000, 50, 10], dtype=np.float32), high=np.array([15000, 200, 30], dtype=np.float32), dtype=np.float32) # continuous actions

    self.observation_space = spaces.Box(low=-1.0, high=1.0,
                                        shape=(N_MELS,), dtype=np.float32)
    self.embedder: EmbeddingModule = SignEncoding() if method == "sign-encoding" else SpreadSpectrum()

    # Initialize state
    self.reset()

  def reset(self, seed=None, options=None):
    """Reset environment to initial state"""
    super().reset(seed=seed) # Call the parent class reset with seed
    self.message_length = len(self.msg_bits) // 8  # Store message length in characters

    self.current_audio = self.original_audio.copy()
    self.magnitudes, self.phases = self.preprocessor.compute_mdct() # move
    # Compute initial features
    self.mask = self.preprocessor.get_non_critical_coeffs(self.magnitudes)
    self.current_step = 0
    return self._get_obs(), {} # Return observation and info dictionary

  def _get_obs(self):
    """Extract MFCC features as observation"""
    mfcc = librosa.feature.mfcc(y=self.current_audio, sr=SAMPLE_RATE,
                                n_mfcc=N_MELS, n_fft=FRAME_SIZE,
                                hop_length=HOP_LENGTH)
    return np.mean(mfcc, axis=1)

  def step(self, action):
    """Execute one embedding step"""
    self.embedder.set_parameters(action)

    # Embed message
    self.current_audio = self.embedder.embed(magnitudes = self.magnitudes,phases=self.phases, mask = self.mask, msg_bits = self.msg_bits, original_audio=self.original_audio)

    # Compute rewards
    snr = self._calculate_snr()
    psnr = self._calculate_psnr()
    detection_prob = self._simulate_detection()
    accuracy, ber = self._calculate_accuracy_ber()
    reward = self._calculate_reward(snr,psnr, detection_prob, accuracy)

    # Update state
    self.current_step += 1
    done = self.current_step >= 10  # Train for 10 steps
    info = {
        "snr": snr,
        "psnr": psnr,
        "ber": ber,
        "detection_prob": detection_prob,
        "extraction_accuracy": accuracy,
        "reward": reward,
        "action": action
    }
    return self._get_obs(), reward, done,False, info

  def _calculate_snr(self):
    """Calculate Signal-to-Noise Ratio"""
    # Ensure both audio arrays have the same length before calculating noise
    min_len = min(len(self.current_audio), len(self.original_audio))
    current_audio_trimmed = self.current_audio[:min_len]
    original_audio_trimmed = self.original_audio[:min_len]

    noise = current_audio_trimmed - original_audio_trimmed
    signal_power = np.mean(original_audio_trimmed ** 2)
    noise_power = np.mean(noise ** 2)

    if noise_power == 0:
        return 100  # High SNR if no noise

    return 10 * np.log10(signal_power / noise_power)

  def _calculate_psnr(self):
    if len(self.current_audio) < len(self.original_audio):
        stego_audio_padded = np.pad(self.current_audio, (0, len(self.original_audio) - len(self.current_audio)), 'constant')
    else:
        stego_audio_padded = self.current_audio[:len(self.original_audio)]

    psnr = 10 * np.log10(np.max(self.original_audio ** 2) / np.mean((self.original_audio - stego_audio_padded) ** 2)) if np.mean((self.original_audio - stego_audio_padded) ** 2) > 0 else float('inf')
    return psnr


  def _simulate_detection(self):
    """Simulate steganalysis detection"""
    # not real
    snr = self._calculate_snr()
    return 1 / (1 + np.exp(0.5 * (snr - 50)))  # Logistic function

  def _calculate_accuracy_ber(self):
    """Calculate the accuracy of extracted bits"""
    # Use the same mask as embedding for extraction simulation
    extracted_bits = self.embedder.extract(stego_audio =self.current_audio,magnitudes=self.magnitudes, mask=self.mask,message_length= self.message_length)

    min_len_bits = min(len(extracted_bits), len(self.msg_bits))
    ber = np.mean([self.msg_bits[i] != extracted_bits[i] for i in range(min_len_bits)]) if min_len_bits > 0 else 1.0 # BER is 1 if no bits to compare
    return 1.0 - ber, ber

  def _calculate_reward(self, snr,psnr, detection_prob, extraction_accuracy):
    """Calculate reward balancing SNR, detectability, and extraction accuracy"""
    # Target: SNR > 50 dB, detection_prob < 0.1, extraction_accuracy close to 1.0
    snr_reward = min(snr / 50, 1.0)
    detect_reward = 1.0 - min(detection_prob, 1.0)

    # Weighted combination
    return (
        cfg.SNR_WEIGHTS * snr_reward +
        cfg.DETECTION_WEIGHT * detect_reward +
        cfg.EXTRACTION_ACCURACY_WEIGHT * extraction_accuracy
    )

# -- MAIN FRAMEWORK --

In [19]:
class RLAudioSteganography:
  """Main framework class"""
  def __init__(self, cfg: Config) -> None:
    self.cfg = cfg
    self.embedded_mask = None # Store the mask used during embedding
    self.original_magnitudes = None # Store original magnitudes

  def Initialize_components(self, audio_path, method="sign-encoding"):
    """Initialize components"""
    self.preprocessor = AudioPreprocessor(audio_path=audio_path)
    self.original_magnitudes, self.phases = self.preprocessor.compute_mdct()
    self.mask = self.preprocessor.get_non_critical_coeffs(self.original_magnitudes)
    self.steganalysis_net = SteganalysisCNN()
    # self.steganalysis_net = SteganalysisCNN(input_channels=len(self.preprocessor.audio))
    self.steganalysis_net.to(device)
    self.method = method
    self.embedder: EmbeddingModule = SignEncoding() if method == "sign-encoding" else SpreadSpectrum()

  def string_to_bits(self, message):
    """Convert a string to a sequence of bits."""
    return np.array([int(bit) for bit in ''.join(format(ord(char), '08b') for char in message)], dtype=np.float32)

  def bits_to_string(self, bits):
    """Convert a sequence of bits to a string."""
    text = ''
    for bit in range(0, len(bits), 8):
      if bit + 8 <= len(bits):
        byte = ''.join(str(bit) for bit in bits[bit:bit + 8])
        text += chr(int(byte, 2))
    return text

  def train_ppo(self, audio_path, message, total_timesteps=10000)-> PPO:
    """Train PPO agent for audio steganography"""
    env = make_vec_env(lambda: AudioStegoEnv(audio_path, self.string_to_bits(message), method=self.method), n_envs=4)

    policy_kwargs = dict(
        features_extractor_class=CustomFeatureExtractor,
        features_extractor_kwargs=dict(features_dim=128),
    )

    model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1,
                device=device, learning_rate=cfg.LEARNING_RATE_ACTOR, n_steps=cfg.EPISODES, batch_size=cfg.BATCH_SIZE)
    # Instantiate the custom callback
    custom_callback = CustomLoggingCallback()
    model.learn(total_timesteps, custom_callback)

    return model

  def embed_message(self, audio_path, message, output_path, model):
    """Embed a message into an audio file using trained policy"""
    # Re-initialize preprocessor and compute magnitudes/phases/mask for the specific audio being embedded into
    self.preprocessor = AudioPreprocessor(audio_path=audio_path)
    self.original_magnitudes, self.phases = self.preprocessor.compute_mdct()
    self.mask = self.preprocessor.get_non_critical_coeffs(self.original_magnitudes)
    self.embedded_mask = self.mask # Store the mask used for embedding

    # Steganalysis network is initialized in Initialize_components, ensure it's done before embedding
    # If embed_message is called standalone, ensure Initialize_components is called first with the correct audio_path

    msg_bits = self.string_to_bits(message)
    env = AudioStegoEnv(audio_path, msg_bits) # Use the correct audio_path and message
    obs, info = env.reset()
    action, _ = model.predict(obs, deterministic=True)
    self.embedder.set_parameters(action)
    # Convert message string to bits for embedding
    message_bits = self.string_to_bits(message)
    # stego_audio = self.embedder.embed(self.original_magnitudes, self.embedded_mask, message_bits)
    stego_audio = self.embedder.embed(
        magnitudes = self.original_magnitudes,
        phases=self.phases,
        mask = self.mask,
        msg_bits = msg_bits,
        original_audio=self.preprocessor.audio.copy())
    self.preprocessor.save_audio(stego_audio,SAMPLE_RATE, output_path)
    print(f"Stego audio saved to {output_path}")


  def extract_message(self, stego_audio_path, msg_length):
    """Extract a message from a stego audio file"""
    # Load the stego audio
    preprocessor_stego = AudioPreprocessor(audio_path=stego_audio_path)
    magnitudes_stego, phases_stego = preprocessor_stego.compute_mdct()

    # Use the stored mask from embedding for extraction
    if self.embedded_mask is None:
        raise ValueError("Embedding must be performed before extraction to get the mask.")

    # extracted_message = self.embedder.extract(magnitudes_stego, self.embedded_mask, msg_length)
    extracted_bits = self.embedder.extract(
        stego_audio =preprocessor_stego.audio.copy(),
        magnitudes=magnitudes_stego,
        mask=self.mask,
        message_length= msg_length)
    return self.bits_to_string(extracted_bits)

  def plot_training_history(self):
    """Plot training metrics"""
    # plt.figure(figsize=(12, 8))
    # plt.subplot(2, 2, 1)
    # plt.plot(self.training_history['rewards'], label='Reward')
    # plt.title('Training Rewards')
    # plt.xlabel('Episode')
    # plt.ylabel('Average Reward')
    # plt.legend()
    # plt.tight_layout()
    # plt.show()

# --- MAIN EXECUTION SCRIPT ---

In [25]:
if __name__ == "__main__":
    # Initialize the framework
    cfg = Config()
    framework = RLAudioSteganography(cfg)

    # --- Training Phase ---
    # Load sample audio for training
    training_audio_path = librosa.ex('trumpet')
    message_to_embed = "There are things that we do not wish to know. and this is a secret I want to send to you okay"

    # Initialize components with the training audio
    framework.Initialize_components(training_audio_path, method="spread-spectrum")

    # Train PPO agent
    print("Training PPO agent...")
    model = framework.train_ppo(training_audio_path, message_to_embed, total_timesteps=10000)

    # Save the trained model
    model_save_path = "ppo_audio_stego_model"
    model.save(model_save_path)
    print(f"Trained model saved to {model_save_path}")

    # --- Embedding with the Trained Model on a New Audio ---
    print("\nEmbedding message in a new audio file using the trained model...")

    # Load the saved model
    loaded_model = PPO.load(model_save_path)
    print(f"Loaded model from {model_save_path}")

    # Define a new audio file and output path
    new_audio_path = librosa.ex('trumpet', hq=True) # Using a different trumpet example
    new_output_path = "stego_new_audio.wav"

    # Initialize components with the *new* audio
    framework.Initialize_components(new_audio_path, method="spread-spectrum")

    # Embed message using the loaded model
    framework.embed_message(new_audio_path, message_to_embed, new_output_path, loaded_model)

    # Extract message from the new stego audio
    extracted_message = framework.extract_message(new_output_path, len(message_to_embed))
    print(f"\nOriginal Message: {message_to_embed}")
    print(f"Extracted Message from new audio: {extracted_message}")

    # --- Evaluation (Optional) ---
    # You can add evaluation steps here similar to the commented out code
    # to check SNR and detection probability on the new stego audio.

Training PPO agent...
Using cuda device




---------------------------------------------------
| rollout/                  |                     |
|    ep_action              | [5000.   50.   10.] |
|    ep_ber                 | 0.047               |
|    ep_detection_prob      | 1                   |
|    ep_extraction_accuracy | 0.953               |
|    ep_len_mean            | 10                  |
|    ep_rew_mean            | 2.92                |
|    ep_reward              | 0.292               |
|    ep_snr                 | 34                  |
| time/                     |                     |
|    fps                    | 5                   |
|    iterations             | 1                   |
|    time_elapsed           | 715                 |
|    total_timesteps        | 4000                |
---------------------------------------------------
---------------------------------------------------
| rollout/                  |                     |
|    ep_action              | [5000.   50.   10.] |
|    ep_ber 

Downloading file 'sorohanro_-_solo-trumpet-06.hq.ogg' from 'https://librosa.org/data/audio/sorohanro_-_solo-trumpet-06.hq.ogg' to '/root/.cache/librosa'.


Trained model saved to ppo_audio_stego_model

Embedding message in a new audio file using the trained model...
Loaded model from ppo_audio_stego_model
Stego audio saved to stego_new_audio.wav

Original Message: There are things that we do not wish to know. and this is a secret I want to send to you okay
Extracted Message from new audio: Thõrc bs thbNsw pHat ÷e(do0fot wuch to know. and 4hhs(is a(secret I(want to s¥nd to you okay


In [26]:
cfg = Config()
framework_test = RLAudioSteganography(cfg)

audio = simple_audio_files[0]
message = """
    Embed a message into an audio file.

    Parameters:
      audio_path (str): Path to the audio file.
      message (str): The message to be embedded.
      output_path (str): Path to save the embedded audio file.
    """
output_path = "test.wav"

framework_test.Initialize_components(audio_path=simple_audio_files[0], method="spread-spectrum")
# Load the saved model
model_test = PPO.load(model_save_path)
print(f"Loaded model from {model_save_path}")
# Embed message using the loaded model
framework_test.embed_message(audio, message, output_path, model_test)

# Extract message from the new stego audio
extracted_message = framework_test.extract_message(output_path, len(message))
print(f"\nOriginal Message: {message}")
print(f"Extracted Message from new audio: {extracted_message}")



Loaded model from ppo_audio_stego_model
Stego audio saved to test.wav

Original Message: 
    Embed a message into an audio file.

    Parameters:
      audio_path (str): Path to the audio file.
      message (str): The message to be embedded.
      output_path (str): Path to save the embedded audio file.
    
Extracted Message from new audio: 
    Embed a message into an audiï file.* (  Parameteòs:
      audio[path (sôr): Path tk the audio file,
      meswagd (str): The message to bm embedded.
      o_tput_path (str): P`th to savebthe embedded audio file.
    
