<a href="https://colab.research.google.com/github/Kratosgado/audio-steganography/blob/main/steg-ai/core_modules/policy_environment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install Libraries

In [1]:
!pip install torch numpy librosa scipy

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

## Import Libraries

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import librosa
import soundfile as sf
from collections import deque
import random

## Define Reinforcement Learning Hyperparameters

In [3]:
# parameters
SAMPLE_RATE = 44100 # Audio sample rate
EMBEDDING_STEP = 100 # Embed one bit every 100 samples
# MESSAGE = "This is just a training message to everyone"
MESSAGE = "hello"
GAMMA = 0.99 # discount factor for RL
EPSILON = 1.0 # exploration rate for epsilon- greedy
EPSILON_MIN = 0.01
EPSILON_DECAY = 0.995
LEARNING_RATE = 0.001
BATCH_SIZE = 32
MEMORY_SIZE = 1000
EPISODES = 100

# Utility functions

In [4]:
NORM_16NUM = 32768

def bitToChar(bits):
  return chr(int(bits, 2))

def textToBits( message):
  return ''.join(format(ord(ch), '08b') for ch in message)

def bitsToText(bits):
  text = ''
  for bit in range(0, len(bits), 8):
    text += bitToChar(bits[bit:bit + 8])
  return text

## Policy Network (Decides how to modify audio samples)

In [5]:
# Policy Network (Decides how to modify audio samples)
class PolicyNetwork(nn.Module):
  def __init__(self, input_size, hidden_size, action_size):
    super(PolicyNetwork, self).__init__()
    self.fc1 = nn.Linear(input_size, hidden_size)
    self.fc2 = nn.Linear(hidden_size, hidden_size)
    self.fc3 = nn.Linear(hidden_size, action_size)
    self.relu = nn.ReLU()

  def forward(self, x):
    x = self.relu(self.fc1(x))
    x = self.relu(self.fc2(x))
    x = self.fc3(x)
    return x

## Environment Network (Simulates steganalysis feedback)

In [6]:
# Environment Network (Simulates steganalysis feedback)
class EnvironmentNetwork(nn.Module):
  def __init__(self, input_size, hidden_size, output_size):
    super(EnvironmentNetwork, self).__init__()
    self.fc1 = nn.Linear(input_size, hidden_size)
    self.fc2 = nn.Linear(hidden_size, hidden_size)
    self.fc3 = nn.Linear(hidden_size, output_size)
    self.relu = nn.ReLU()
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    x = self.relu(self.fc1(x))
    x = self.relu(self.fc2(x))
    x = self.sigmoid(self.fc3(x))  # OUTPUT: detection probability
    return x

## RL Agent to make use of the Policy Network

In [7]:
# RL Agent
class RLAgent:
  def __init__(self, state_size, action_size) -> None:
    self.state_size = state_size
    self.action_size = action_size
    self.memory = deque(maxlen = MEMORY_SIZE)
    self.policy_net = PolicyNetwork(state_size, 128, action_size).float()
    self.target_net = PolicyNetwork(state_size, 128, action_size).float()
    self.target_net.load_state_dict(self.policy_net.state_dict())
    self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
    self.epsilon = EPSILON

  def select_action(self, state):
    if random.random() < self.epsilon:
      return random.randrange(self.action_size)
    state = torch.FloatTensor(state).unsqueeze(0)
    with torch.no_grad():
      q_values = self.policy_net(state)
    return q_values.argmax().item()

  def store_transition(self, state, action, reward, next_state, done):
    self.memory.append((state, action, reward, next_state, done))

  def train(self):
    if len(self.memory) < BATCH_SIZE:
      return
    batch = random.sample(self.memory, BATCH_SIZE)
    states, actions, rewards, next_states, dones = zip(*batch)

    states = torch.FloatTensor(states)
    actions = torch.LongTensor(actions)
    rewards = torch.FloatTensor(rewards)
    next_states = torch.FloatTensor(next_states)
    dones = torch.FloatTensor(dones)

    current_q = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    next_q = self.target_net(next_states).max(1)[0]
    target_q = rewards + GAMMA * next_q * (1 - dones)

    loss = nn.MSELoss()(current_q, target_q.detach())
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    self.epsilon = max(EPSILON_MIN, self.epsilon * EPSILON_DECAY)

  def update_target_network(self):
    self.target_net.load_state_dict(self.policy_net.state_dict())

## Audio Steg Environment placeholder

In [25]:
# Audio Steganography Environment placeholder
# we will need a trained steganalysis environment
class AudioStegEnvironment:
  def __init__(self, audio_path, message):
    self.audio_path = audio_path
    self.audio, _ = librosa.load(audio_path, sr=SAMPLE_RATE, mono=True)
    self.original_audio = self.audio.copy()
    self.original_signals = (self.audio * NORM_16NUM).astype(np.int16) # convert to 16-bit
    self.signals = self.original_signals # convert to 16-bit
    self.message = [int(b) for b in message]
    self.pos = 0 # current position in message
    self.sample_idx = 0 # current sample index
    self.env_net = EnvironmentNetwork(input_size=10, hidden_size=64, output_size=1).float()
    self.state_size = 10 # example: 10 samples around current position
    self.action_size = 2 # actions: modify sample (0: no change, 1: change)

  def reset(self):
    self.pos = 0
    self.sample_idx = 0
    self.audio = self.original_audio.copy()
    self.audio = (self.audio * NORM_16NUM).astype(np.int16) # convert to 16-bit
    return self.get_state()

  def get_state(self):
    # state: 10 samples around current index
    start = max(0, self.sample_idx - 5)
    end = min(len(self.audio), self.sample_idx + 5)
    state = np.zeros(10)
    state[:end - start] = self.audio[start:end]
    return state

  def step(self, action):
    # action: 0 (no change), 1 (modify sample to embed bit)
    done = False
    reward = 0

    if self.pos >= len(self.message) or self.sample_idx >= len(self.audio):
      done = True
      return self.get_state(), reward, done

    if action == 1: # modify sample to embed bit
      target_bit = self.message[self.pos]
      # simple lsb-like modification (real-world would be more sophisticated)
      self.audio[self.sample_idx] = self.modify_sample(self.audio[self.sample_idx], target_bit)

    # simulate steganalysis with environment network
    state = torch.FloatTensor(self.get_state()).unsqueeze(0)
    detection_prob = self.env_net(state).item()
    reward = 1 - detection_prob # reward: high if undetectable

    # Local Signal-Noise Ratio (only around modified region)
    if action == 1 and self.sample_idx < len(self.original_signals):
      window_size = 10
      start = max(0, self.sample_idx - window_size)
      end = min(len(self.audio), self.sample_idx + window_size)
      original_window = self.original_signals[start:end]
      noise_power = np.mean((self.signals[start:end] - original_window) ** 2 + 1e-10)
      signal_power = np.mean(original_window**2)
      snr =  (10 * np.log10(signal_power / noise_power)) if int(noise_power) != 0 else 100
      reward += 0.1 * snr  # Add SNR contribution to reward
      print(f"noise power: {noise_power} -- signal_power: {signal_power} -- snr: {snr} -- reward: {reward}")

      self.pos += 1  # move to next message bit

    self.sample_idx += EMBEDDING_STEP
    next_state = self.get_state()
    return next_state, reward, done

  def modify_sample(self, sample, bit):
      # Use LSB for now
      return (sample & ~1) | bit

  def decode_message(self):
        # Decode from embedded audio
        bits = [(self.audio[i] & 1) for i in range(0, len(self.message) * EMBEDDING_STEP, EMBEDDING_STEP)]
        return ''.join(map(str, bits))

## Main training loop

In [26]:
# Main Training Loop
def main():
  env = AudioStegEnvironment("drive/MyDrive/Colab Notebooks/steganography/input.wav", textToBits(MESSAGE))
  agent = RLAgent(state_size=10, action_size=2)

  for episode in range(EPISODES):
    state = env.reset()
    total_reward = 0
    done = False

    while not done:
      action = agent.select_action(state)
      next_state, reward, done = env.step(action)
      agent.store_transition(state, action, reward, next_state, done)
      agent.train()
      state = next_state
      total_reward += reward

    agent.update_target_network()
    decoded_msg = env.decode_message()
    print(f"Episode {episode + 1}, Total Reward: {total_reward:.2f}, Epsilon: {agent.epsilon:.2f}, Decoded: {decoded_msg}")

    # save modified audio
    sf.write(f"stego_audio_episode_{episode + 1}.wav", env.audio.astype(np.float32)/ NORM_16NUM, SAMPLE_RATE)

  # saving model
  torch.save(agent.policy_net.state_dict(), "policy_net.pth")
if __name__ == "__main__":
  main()

noise power: 1e-10 -- signal_power: 0.0 -- snr: 100 -- reward: 10.538310825824738
noise power: 1e-10 -- signal_power: 0.0 -- snr: 100 -- reward: 10.548089861869812
noise power: 1e-10 -- signal_power: 0.05 -- snr: 100 -- reward: 10.548089861869812
noise power: 1e-10 -- signal_power: 0.0 -- snr: 100 -- reward: 10.538310825824738
noise power: 1e-10 -- signal_power: 0.05 -- snr: 100 -- reward: 10.548089861869812
noise power: 1e-10 -- signal_power: 0.0 -- snr: 100 -- reward: 10.538310825824738
noise power: 1e-10 -- signal_power: 0.0 -- snr: 100 -- reward: 10.538310825824738
noise power: 1e-10 -- signal_power: 0.05 -- snr: 100 -- reward: 10.538310825824738
noise power: 1e-10 -- signal_power: 0.05 -- snr: 100 -- reward: 10.528194785118103
noise power: 1e-10 -- signal_power: 0.05 -- snr: 100 -- reward: 10.548089861869812
noise power: 1e-10 -- signal_power: 0.0 -- snr: 100 -- reward: 10.548089861869812
noise power: 1e-10 -- signal_power: 0.0 -- snr: 100 -- reward: 10.538310825824738
noise power

## Functions to be used after training

In [27]:
def extract_message(stego_audio, embedding_step, message_length):
    extracted = []
    for i in range(0, len(stego_audio), embedding_step):
        if len(extracted) >= message_length:
            break
        sample = stego_audio[i]
        bit = 1 if sample > 0 else 0  # Simplified; assumes original sample ≈ 0
        extracted.append(bit)
    return extracted

In [30]:

# loading model
# agent.policy_net.load_state_dict(torch.load("policy_net.pth"))

# using trained model
def inference(audio_path, message):
    env = AudioStegEnvironment(audio_path, textToBits(message))
    agent = RLAgent(state_size=10, action_size=2)
    agent.policy_net.load_state_dict(torch.load("policy_net.pth"))
    agent.policy_net.eval()
    agent.epsilon = 0  # No exploration
    state = env.reset()
    done = False
    while not done:
      state_tensor = torch.FloatTensor(state).unsqueeze(0)
      with torch.no_grad():
        action = agent.policy_net(state_tensor).argmax().item()
      next_state, _, done = env.step(action)
      state = next_state
    sf.write("stego_audio_final.wav", env.audio.astype(np.float32)/ NORM_16NUM, SAMPLE_RATE)
    decoded_msg = env.decode_message()
    print(f"Decoded message: {textToBits(decoded_msg)}")

In [31]:
inference("drive/MyDrive/Colab Notebooks/steganography/input.wav", MESSAGE)

noise power: 1e-10 -- signal_power: 0.1 -- snr: 100 -- reward: 10.52780145406723
noise power: 1e-10 -- signal_power: 0.0 -- snr: 100 -- reward: 10.531883209943771
noise power: 1e-10 -- signal_power: 0.0 -- snr: 100 -- reward: 10.531883209943771
noise power: 1e-10 -- signal_power: 0.0 -- snr: 100 -- reward: 10.521273255348206
noise power: 1e-10 -- signal_power: 0.0 -- snr: 100 -- reward: 10.531883209943771
noise power: 1e-10 -- signal_power: 0.0 -- snr: 100 -- reward: 10.521273255348206
noise power: 1e-10 -- signal_power: 0.05 -- snr: 100 -- reward: 10.521273255348206
noise power: 1e-10 -- signal_power: 0.05 -- snr: 100 -- reward: 10.521273255348206
noise power: 1e-10 -- signal_power: 0.0 -- snr: 100 -- reward: 10.521273255348206
noise power: 1e-10 -- signal_power: 0.05 -- snr: 100 -- reward: 10.531883209943771
noise power: 1e-10 -- signal_power: 0.0 -- snr: 100 -- reward: 10.531883209943771
noise power: 1e-10 -- signal_power: 0.0 -- snr: 100 -- reward: 10.521273255348206
noise power: 1

# Off limits: experimenting

In [14]:
def lsb_encode(audio: np.ndarray, message: str):
  message_bits = textToBits(message)

  # ensure the message fits within the audio data
  if len(message_bits) > len(audio) * 8:
    raise ValueError("Message is too long to encode in audio file")

  stego_audio = audio.copy()
  for i, bit in enumerate(message_bits):
    stego_audio[i] = (stego_audio[i] & ~1) | int(bit)
  sf.write('stego_audio.wav', stego_audio.astype(np.float32)/ norm_num, sr)
  return stego_audio

def lsb_decode(audio: np.ndarray, size: int):
  bits = ''
  for sample in range(size * 8):
    bits += str(audio[sample] & 1)

  return bitsToText(bits)

In [15]:
norm_num = 32768
audio_file ="drive/MyDrive/Colab Notebooks/steganography/input.wav"
y, sr = librosa.load(audio_file, sr=None, mono=True)
original_signals = (y * norm_num).astype(np.int16)

In [16]:
secret = 'what is the meaning of apostacy: this issue has been long discussed'
secret_bits = textToBits(secret)
stego_signal = lsb_encode(original_signals, secret)

In [20]:
window_size = 10
start = 80000
end = 800010
original_window = original_signals[start:end]
noise_power = np.mean((stego_signal[start:end] - original_window) ** 2 + 1e-10)
signal_power = np.mean(original_window**2)
print(f"noise power: {noise_power}")
print(f"signal power: {signal_power}")
snr = 10 * np.log10(signal_power / noise_power) if noise_power != 0 else 100
reward = 0.1 * snr  # Add SNR contribution to reward
print(f"snr: {snr}")
print(f"reward: {reward}")

noise power: 9.999999999999995e-11
signal power: 1498.8913278982236
snr: 131.75770146921008
reward: 13.175770146921009
