In [1]:
# Library imports
import torch
import math

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler

from IPython import display

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import T5Tokenizer, T5ForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import BertModel, BertConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Select the compute device: use CUDA when a GPU is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Report which device was selected
if device.type != 'cuda':
    print(f'CPU')
else:
    device_name = torch.cuda.get_device_name(device)

    print(f'CUDA on {device_name}')

CUDA on NVIDIA RTX 500 Ada Generation Laptop GPU


In [3]:
# Configuration parameters

SECRET_LENGTH = 16                      # Number of secret bits hidden in one generated text
TOKENS_PER_BIT = 1                      # Multiplier from secret bits to generated tokens (see STEGO_LENGTH)
INJECT_SECRET_IN_SYSTEM_PROMPT = False 	# Whether to insert the secret as binary text in the system prompt
MASK_TOKENS = False						# Whether to mask a subset of tokens based on the encoded secret bit. A simple form of encoding

# Total number of tokens generated for one secret
STEGO_LENGTH = SECRET_LENGTH * TOKENS_PER_BIT

# Training parameters

MAX_EPOCH = 200                         # Maximum number of training epochs

LR_INITIAL = 0.1						# Initial learning rate
LR_GAMMA = 0.98	    					# Learning rate decay factor per epoch

RECOVERY_LOSS_ALPHA_START = 0.95		# Weighting factor for recovery loss vs semantic loss at the start of training (0.0 to 1.0)
RECOVERY_LOSS_ALPHA_END = 0.6			# Weighting factor for recovery loss vs semantic loss at the end of training (0.0 to 1.0)
RECOVERY_LOSS_STEPS = MAX_EPOCH // 2	# Number of steps over which to linearly anneal the recovery loss weight


In [4]:
# Prompt templates

# Carrier prompt template; {0} is filled with a subject through str.format
INPUT_PROMPT = 'Tell me a fact about {0}.'
# Subjects substituted into INPUT_PROMPT; one carrier prompt is built per entry for training
INPUT_SUBJECTS = [
    'traveling',
    'sports',
    'music',
    'technology',
    'health and wellness',
    'education',
    'history',
    'science',
    'art and culture'
]

In [5]:
# Cleanup previous run
# This allows rerunning the entire notebook without memory leaks

import gc

# Delete models, clean up memory
#
# Every global that can keep model weights, activations or optimizer state alive is dropped.
# Deleting `generator` alone is not enough: `encoder` stores it as `base_lm`, and the optimizer,
# scheduler and the tensors left over from an interrupted training loop still reference the
# trainable parameters and the autograd graph.
for stale_name in (
    'generator_tokenizer',
    'generator',
    'encoder',
    'decoder',
    'semantic_anchor',
    'optimizer',
    'scheduler',
    'parameters',
    'generator_inputs',
    'soft_embeds',
    'past_kv',
    'generated_token_probs',
    'encoder_outputs_probs',
    'stego_logits',
    'original_logits',
    'token_probs',
    'semantic_loss',
    'recovery_loss',
    'loss',
):
    # pop() with a default is a no-op for names that do not exist (e.g. on the first run)
    globals().pop(stale_name, None)
del stale_name

gc.collect()

# Only touch the CUDA allocator when CUDA is usable. The cuBLAS workspace call is a private
# torch API, so it is looked up defensively instead of assuming it exists in every build.
if torch.cuda.is_available():
    clear_cublas_workspaces = getattr(torch._C, '_cuda_clearCublasWorkspaces', None)
    if clear_cublas_workspaces is not None:
        clear_cublas_workspaces()
    del clear_cublas_workspaces

    torch.cuda.empty_cache()


In [6]:
# Verify the cleanup: report the CUDA allocator counters converted from bytes to MB
allocated, reserved = (
    query() / (1024 * 1024)
    for query in (torch.cuda.memory_allocated, torch.cuda.memory_reserved)
)

print(f'Allocated: {allocated:.1f} MB; Reserved: {reserved:.1f} MB')

Allocated: 0.0 MB; Reserved: 0.0 MB


In [7]:
# Stego Encoder Model Definition

class StegoEncoder(nn.Module):
    """Wraps a frozen causal LM and biases its next-token logits according to a secret bit.

    The only parameter this class creates itself is `logit_bias`, a vector with one entry per
    vocabulary token. It is added to the base logits for bit 1 and subtracted for bit 0.
    """

    def __init__(self, base_lm, prefix_encoder = None):
        """
        Args:
            base_lm: Language model exposing `config.vocab_size`, `get_input_embeddings()` and a
                forward pass whose output has `logits` and `past_key_values`. All of its
                parameters are frozen here.
            prefix_encoder: Optional callable mapping secret bits to prefix embeddings that are
                prepended to the token embeddings.
        """
        super().__init__()
        
        self.base_lm = base_lm
        self.prefix_encoder = prefix_encoder
        self.vocab_size = base_lm.config.vocab_size
        
        # Secret encoder to map secret bit to vocab size
        self.logit_bias = nn.Parameter(torch.rand(self.vocab_size))

        # Freeze base model
        for p in self.base_lm.parameters():
            p.requires_grad = False

    def generate_input_embeddings(self, input_ids, secret_bits = None):
        """Embed the prompt tokens, optionally prepending prefix embeddings for the secret.

        Args:
            input_ids: Token ids of the prompt(s).
            secret_bits: Secret bits passed to `prefix_encoder`; ignored when there is no
                prefix encoder.

        Returns:
            The token embeddings, or prefix and token embeddings concatenated along dim 1.
        """
        # Token embeddings based on input tokens
        token_embeds = self.base_lm.get_input_embeddings()(input_ids) # (B, T, D)
        if self.prefix_encoder is None or secret_bits is None:
            return token_embeds
        
        # Prefix embeddings based on secret bits
        prefix_embeds = self.prefix_encoder(secret_bits)

        # Concatenate prefix and token embeddings
        input_embeds = torch.cat([prefix_embeds, token_embeds], dim=1)  # (B, T+P, D)
        
        return input_embeds

    def mask_indices(self, logits, bit):
        """Mask out, in place, the half of the vocabulary that is not allowed for each secret bit.

        bit = 0 masks the even token ids, bit = 1 masks the odd token ids. Every batch row is
        masked according to its own bit.

        Args:
            logits: [B, V] logits; masked entries are overwritten with -inf.
            bit: Secret bit(s), either a scalar tensor or one entry per batch row.
        """
        # Parity of every token id, created on the logits' device
        token_parity = torch.arange(self.vocab_size, device=logits.device) % 2   # [V]

        # One bit per row; a scalar bit becomes [1, 1] and broadcasts over the whole batch
        row_bits = bit.to(logits.device).reshape(-1, 1)                          # [B, 1]

        # A token is blocked when its parity equals the row's bit
        blocked = token_parity.unsqueeze(0) == row_bits                          # [B, V]

        # Mask out tokens we don't allow
        logits.masked_fill_(blocked, -math.inf)

    def forward(self, secret_bit = None, temperature=1.0, **kwargs):
        """Run one step of the base LM and bias the next-token logits with the secret bit.

        Args:
            secret_bit: Optional [B] tensor with the bit to encode for each batch row. When it
                is None the base logits are returned unbiased.
            temperature: Divisor applied to the (biased) logits before the softmax.
            **kwargs: Forwarded unchanged to the base LM (e.g. `inputs_embeds`,
                `attention_mask`, `past_key_values`).

        Returns:
            Tuple `(biased_logits, logits, probs, past_key_values)`; the first three are [B, V]
            and describe the last position only.
        """
        # NOTE: `eos_token_id` is an option of `generate()`, not of the model's forward pass, so
        # it is not passed here. A single forward call cannot stop early anyway.
        out = self.base_lm(
            **kwargs,
            return_dict=True,
            use_cache=True,
        )

        # Select logits for the last token only (this is the next prediction)
        logits = out.logits[:, -1, :]   # [B, V]

        # Encode secret bit based on learned logit bias
        if secret_bit is not None:
            batch_bits = secret_bit.unsqueeze(1)  # [B, 1]

            # Map bit {0, 1} to sign {-1, +1} so the same bias pushes in opposite directions
            bit = (batch_bits * 2 - 1).expand(-1, self.vocab_size)  # [B, V]
            bias = self.logit_bias.expand(bit.shape[0], -1) * bit   # [B, V]
            
            biased_logits = logits + bias

            # Apply masking if enabled
            if MASK_TOKENS:
                self.mask_indices(biased_logits, secret_bit)
        else:
            biased_logits = logits

        # Softmax to convert from logits to probabilities
        probs = torch.softmax(biased_logits / temperature, dim=-1)

        return biased_logits, logits, probs, out.past_key_values

    def soft_embeddings(self, token_probs):
        """Return the probability-weighted average of the token embeddings ("soft" embedding)."""
        embed_matrix = self.base_lm.get_input_embeddings().weight  # [V, d]
        return token_probs @ embed_matrix


In [8]:
# Stego Decoder Model Definition

class StegoDecoder(nn.Module):
    """Linear classifier that predicts one secret bit from each group of `pool_size` tokens.

    Each token is represented by a vector over the vocabulary (probabilities during training,
    one-hot vectors at inference). The vectors of `pool_size` consecutive tokens are
    concatenated and mapped to a single logit.
    """

    def __init__(self, vocab_size, embedding_size, output_size=16, pool_size=4):
        """
        Args:
            vocab_size: Size of the per-token vocabulary vector.
            embedding_size: Stored on the module; not used by the classifier.
            output_size: Stored on the module; not used by the classifier.
            pool_size: Number of consecutive tokens that jointly encode one bit.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size

        self.output_size = output_size
        self.pool_size = pool_size

        # Token-wise binary classification head
        # Note: Sigmoid() seems to strongly decrease decoder effectiveness?
        # (BCEWithLogitsLoss below already applies the sigmoid, so the head must emit raw logits.)
        self.classifier = nn.Sequential(
            nn.Linear(vocab_size * pool_size, 1),
            #nn.GELU(),
            #nn.Linear(128, 1),
        )

    def forward(self, token_probs=None, labels=None):
        """Classify every token group and optionally compute the recovery loss.

        Args:
            token_probs: [..., T, vocab_size] per-token vocabulary vectors, where T is a
                multiple of `pool_size`. Input whose last dimension is already
                `vocab_size * pool_size` is passed to the classifier unchanged.
            labels: Optional target bits matching the logits without their last dimension.

        Returns:
            Dict with "logits" ([..., T / pool_size, 1]) and "loss" (None without labels).

        Raises:
            TypeError: If `token_probs` is None.
            ValueError: If the number of tokens is not a multiple of `pool_size`.
        """
        if token_probs is None:
            raise TypeError("token_probs must be a tensor, got None")

        # Group consecutive tokens so the features match the classifier's input width
        needs_pooling = (
            self.pool_size > 1
            and token_probs.dim() >= 2
            and token_probs.shape[-1] == self.vocab_size
        )
        if needs_pooling:
            *leading, num_tokens, _ = token_probs.shape
            if num_tokens % self.pool_size != 0:
                raise ValueError(
                    f"number of tokens ({num_tokens}) must be a multiple of pool_size ({self.pool_size})"
                )
            token_probs = token_probs.reshape(
                *leading, num_tokens // self.pool_size, self.vocab_size * self.pool_size
            )

        # Perform classification to get logits for each bit
        logits = self.classifier(token_probs)

        loss = None
        if labels is not None:
            loss_fn = nn.BCEWithLogitsLoss()
            loss = loss_fn(logits.squeeze(-1), labels.float())

        return {
            "loss": loss,
            "logits": logits,
        }


In [9]:
# Create LM models

# Checkpoint used as the generator; the commented lines are alternatives that were tried
#GENERATOR_MODEL = "Qwen/Qwen3-0.6B"
#GENERATOR_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
#GENERATOR_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
GENERATOR_MODEL = "HuggingFaceTB/SmolLM2-135M-Instruct"

# Load the tokenizer and the model, then move the model to the selected device
generator_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
generator = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL)
generator = generator.to(device)

# Print model info
embedding_size, vocab_size = generator.config.hidden_size, generator.config.vocab_size
print(f'Generator embedding dimensions: {embedding_size}, vocabulary size: {vocab_size}')

Generator embedding dimensions: 576, vocabulary size: 49152


In [10]:
# Create encoder & decoder models

# The encoder wraps the generator as its base LM; train() switches it to training mode
encoder = StegoEncoder(base_lm=generator).to(device)
encoder.train()

# The decoder is configured with the secret length as output size and TOKENS_PER_BIT as pool size.
# train() returns the module, so as the last expression it is also displayed as the cell output.
decoder = StegoDecoder(vocab_size=vocab_size, embedding_size=embedding_size, output_size=SECRET_LENGTH, pool_size=TOKENS_PER_BIT).to(device)
decoder.train()

StegoDecoder(
  (classifier): Sequential(
    (0): Linear(in_features=49152, out_features=1, bias=True)
  )
)

In [11]:
# Utility functions

def generate_secret(batch_size=1):
    """Draw a random binary secret of SECRET_LENGTH bits for each of `batch_size` rows."""
    shape = (batch_size, SECRET_LENGTH)
    return torch.randint(low=0, high=2, size=shape)   # [ B, N ]

def bits_to_str(bits, sep=''):
    """Render a bit tensor as text, joining the individual values with `sep`."""
    return sep.join(map(str, bits.tolist()))

def calculate_accuracy(preds, targets):
    """Return the fraction of target bits that the predictions match."""
    hits = (preds == targets)
    return hits.sum().item() / targets.numel()

def generate_carrier_prompt():
    """Build one carrier prompt for a subject picked at random from INPUT_SUBJECTS."""
    return INPUT_PROMPT.format(np.random.choice(INPUT_SUBJECTS))

def generate_carrier_prompts():
    """Return one carrier prompt per subject, in INPUT_SUBJECTS order (the full training set)."""
    return list(map(INPUT_PROMPT.format, INPUT_SUBJECTS))

def build_generator_prompt(carrier, secret_bits = None):
    """Render the chat-template prompt for the generator.

    The carrier becomes the user turn. A system turn spelling out the secret bits is placed
    before it only when INJECT_SECRET_IN_SYSTEM_PROMPT is set and `secret_bits` is given.
    """
    user_turn = {'role': 'user', 'content': carrier }

    if not (INJECT_SECRET_IN_SYSTEM_PROMPT and secret_bits is not None):
        messages = [ user_turn ]
    else:
        system_turn = {'role': 'system', 'content': f'You are a helpful assistant who always encodes these bits in their answer: {bits_to_str(secret_bits.flatten(), " ")}. '}
        messages = [ system_turn, user_turn ]

    return generator_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def kl_semantic_loss(stego_logits, base_logits):
    """Semantic loss: KL divergence of the stego distribution from the base distribution.

    The stego logits are passed as log-probabilities (the `input` of `F.kl_div`) and the base
    logits as probabilities (its `target`); the result is summed and divided by the batch size.
    """
    return F.kl_div(
        input=F.log_softmax(stego_logits, dim=-1),
        target=F.softmax(base_logits, dim=-1),
        reduction="batchmean",
    )

@torch.no_grad()
def encode(carrier, secret_bits):
    """End-to-end encoding of secret bits into generated stego text based on the carrier prompt.

    Args:
        carrier: Carrier prompt text (the user message).
        secret_bits: [B, SECRET_LENGTH] bit tensor. Only the first batch row ends up in the
            returned text. Presumably it must live on the same device as the encoder.

    Returns:
        The text decoded from the STEGO_LENGTH greedily selected tokens.
    """
    prompt = build_generator_prompt(carrier, secret_bits)
    inputs = generator_tokenizer(prompt, return_tensors="pt")

    input_embeddings = encoder.generate_input_embeddings(inputs.input_ids.to(device), secret_bits)
    past_kv = None

    # Loop-invariant: selected token ids are mapped back to 'hard' embeddings with this layer
    embedding_layer = generator.get_input_embeddings()

    # Generate N tokens, store ids for each token
    selected_tokens = [ ]
    for t in range(STEGO_LENGTH):
        # Every secret bit is carried by TOKENS_PER_BIT consecutive tokens. Indexing with `t`
        # directly runs past the end of the secret as soon as TOKENS_PER_BIT > 1.
        bit_index = t // TOKENS_PER_BIT

        stego_logits, original_logits, token_probs, past_kv = encoder(
            secret_bit = secret_bits[:, bit_index],
            inputs_embeds=input_embeddings,
            past_key_values=past_kv
        )

        # Select the top probability token
        predictions = torch.argmax(token_probs, dim=-1)

        # Convert the token to 'hard' embeddings again for the next iteration
        input_embeddings = embedding_layer(predictions).unsqueeze(1) # (B, 1, D)

        # Store a plain int so the tokenizer receives an ordinary list of token ids
        selected_tokens.append(predictions[0].item())

    return generator_tokenizer.decode(selected_tokens, skip_special_tokens=True)

@torch.no_grad()
def decode(stego_text):
    """End-to-end decoding of stego text to recover the secret bits.

    Args:
        stego_text: Text produced by `encode`.

    Returns:
        1-D tensor with the predicted bits for the (single) input text.
    """
    # The text consists of generated tokens only. Special tokens added by the tokenizer
    # (e.g. a leading BOS on some tokenizers) would shift every token by one position.
    encoded_token_ids = generator_tokenizer(
        stego_text,
        return_tensors="pt",
        add_special_tokens=False,
    ).input_ids

    # An empty text can produce an empty, non-integer tensor; one_hot needs integer ids
    encoded_token_ids = encoded_token_ids.long()

    # Ensure correct length, discard extra tokens and pad if too short
    pad_id = generator_tokenizer.pad_token_id
    if pad_id is None:
        pad_id = 0   # same value F.pad falls back to when no fill value is given
    encoded_token_ids = encoded_token_ids[:, :STEGO_LENGTH]
    encoded_token_ids = F.pad(encoded_token_ids, (0, STEGO_LENGTH - encoded_token_ids.shape[1]), value=pad_id)

    # Move the small id tensor first so the large one-hot tensor is built on the target device
    encoded_token_probs = F.one_hot(encoded_token_ids.to(device), vocab_size).float()

    out = decoder(
        token_probs = encoded_token_probs,
    )

    logits = out["logits"]      # (B, N, 1)

    # A positive logit corresponds to a sigmoid probability above 0.5, i.e. bit 1
    preds = (logits.squeeze(-1) > 0).long()
    
    return preds[0]


In [12]:
# Test prompt generation and generation inference

# Build a random carrier prompt and a random secret, then show the rendered generator prompt
carrier = generate_carrier_prompt()
secret = generate_secret().to(device)
print(build_generator_prompt(carrier, secret), '----------------', sep='\n')

# Encode the secret into stego text
stego = encode(carrier, secret)
print(stego, '----------------', sep='\n')

# Decode the stego text and show both bit strings for comparison
decoded = decode(stego)
print(f'Original: {bits_to_str(secret[0])}')
print(f'Decoded: {bits_to_str(decoded)}')

<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
Tell me a fact about art and culture.<|im_end|>
<|im_start|>assistant

----------------
Art and culture are two of the most fascinating aspects of human experience. In my
----------------
Original: 0110110101111111
Decoded: 1111111111101011


In [13]:
# Training helpers

def loss_schedule(epoch, max_epoch):
    """Return the `(recovery, semantic)` loss weights for the given epoch.

    The recovery weight moves linearly from RECOVERY_LOSS_ALPHA_START to
    RECOVERY_LOSS_ALPHA_END over the first RECOVERY_LOSS_STEPS epochs and then stays constant.
    The semantic weight is always one minus the recovery weight.

    Args:
        epoch: Current epoch / step index.
        max_epoch: Accepted for the caller's convenience but not used; the annealing horizon
            comes from RECOVERY_LOSS_STEPS.
    """
    if RECOVERY_LOSS_STEPS > 0:
        p = min(epoch / RECOVERY_LOSS_STEPS, 1.0)
    else:
        # RECOVERY_LOSS_STEPS is MAX_EPOCH // 2 and therefore 0 for very short runs;
        # use the final weighting right away instead of dividing by zero
        p = 1.0

    alpha = RECOVERY_LOSS_ALPHA_START + (RECOVERY_LOSS_ALPHA_END - RECOVERY_LOSS_ALPHA_START) * p
    return alpha, 1.0 - alpha

In [None]:
# Training loop

LOG_EVERY = 1   # Run and print an end-to-end encode/decode sample every this many steps

# Collect all model parameters
def collect_params(model):
    """Return the parameters of `model` that still require gradients."""
    return [ p for p in model.parameters() if p.requires_grad ]

parameters = collect_params(encoder) + collect_params(decoder)

# Initialize optimizer to guide training
optimizer = torch.optim.AdamW(parameters, lr=LR_INITIAL)
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=LR_GAMMA)

# Prepare prompts
print('Preparing training prompts...')

carrier_prompts = generate_carrier_prompts()
generator_prompts = [ build_generator_prompt(carrier) for carrier in carrier_prompts ]

# The prompts have different lengths and the encoder reads the logits at the LAST position.
# Pad on the left so that position is always a real token, and keep the attention mask so the
# pad tokens are ignored by the model.
generator_tokenizer.padding_side = 'left'
generator_batch = generator_tokenizer(generator_prompts, return_tensors="pt", padding=True)
generator_inputs = generator_batch.input_ids.to(device)
generator_attention_mask = generator_batch.attention_mask.to(device)

print('Generator inputs shape:', generator_inputs.shape)

# Training loop for fixed number of steps
print('Starting training...')
for step in range(MAX_EPOCH):
    secret_bits = generate_secret(generator_inputs.shape[0]).to(device)

    # Prepare embeddings for first step by adding secret prefix to token embeddings
    # NOTE: when not using prefix embeddings, these are constant and could be pre-generated
    soft_embeds = encoder.generate_input_embeddings(generator_inputs, secret_bits)
    past_kv = None

    # Attention mask for everything fed to the model so far. Prefix embeddings (if a prefix
    # encoder is used) are prepended to the prompt and are always attended to.
    attention_mask = generator_attention_mask
    prefix_length = soft_embeds.shape[1] - attention_mask.shape[1]
    if prefix_length > 0:
        prefix_mask = attention_mask.new_ones((attention_mask.shape[0], prefix_length))
        attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)

    # Generate N tokens, store probabilities for each token
    generated_token_probs = []

    # Initialize semantic loss, which we accumulate per generated token
    semantic_loss = 0.0

    for t in range(STEGO_LENGTH):
        # Every secret bit is carried by TOKENS_PER_BIT consecutive tokens
        stego_logits, original_logits, token_probs, past_kv = encoder(
            secret_bit=secret_bits[:, t // TOKENS_PER_BIT],
            inputs_embeds=soft_embeds,
            attention_mask=attention_mask,
            past_key_values=past_kv
        )

        # Calculate semantic loss compared to base logits
        semantic_loss += kl_semantic_loss(stego_logits, original_logits)

        # Store all generated token probabilities
        generated_token_probs.append(token_probs)

        # Prepare soft embeddings for next step: [ B, D ] -> [ B, 1, D ]
        soft_embeds = encoder.soft_embeddings(token_probs).unsqueeze(1)

        # The token fed in the next step is a real token, so extend the mask by one column
        next_column = attention_mask.new_ones((attention_mask.shape[0], 1))
        attention_mask = torch.cat([attention_mask, next_column], dim=1)

    # Stack probabilities for each token to [ B, N, V ]
    encoder_outputs_probs = torch.stack(generated_token_probs, dim=1)

    # Calculate recovery loss based on decoding success
    recovery_loss = decoder(
        token_probs=encoder_outputs_probs, 
        labels=secret_bits
    )['loss']

    # Combine loss based on loss weighting schedule
    w_recovery, w_semantic = loss_schedule(step, MAX_EPOCH)
    loss = w_recovery * recovery_loss + w_semantic * semantic_loss

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()

    if step % LOG_EVERY == 0:
        # End-to-end check on one fresh prompt/secret pair using hard (argmax) tokens
        prompt = generate_carrier_prompt()
        secret_bits = generate_secret().to(device)
        stego_text = encode(prompt, secret_bits)
        decoded_secret_bits = decode(stego_text)

        display.clear_output(wait=True)

        print(f"Step {step}")
        print(f"L_sem: {semantic_loss.item():.3f} | L_rec: {recovery_loss.item():.3f} | Acc: {calculate_accuracy(decoded_secret_bits, secret_bits):.3f}")
        print(f'Lr: {scheduler.get_last_lr()[0]:.6f} | w_rec: {w_recovery:.3f} | w_sem: {w_semantic:.3f}')
        print()
        print(f'Prompt: {prompt}')
        print(f'Stego: {stego_text}')
        print()


Step 27
L_sem: 0.281 | L_rec: 0.650 | Acc: 0.625
Lr: 0.056798 | w_rec: 0.855 | w_sem: 0.145

Prompt: Tell me a fact about sports.
Stego: A fact about sports is that sports are a sport that involves physical activity, competition



KeyboardInterrupt: 

: 

In [None]:
# Final end-to-end test

# Fresh random carrier prompt and secret
carrier = generate_carrier_prompt()
secret = generate_secret().to(device)

# Round trip: hide the secret in generated text, then recover it from the text alone
stego_text = encode(carrier, secret)
recovered_secret = decode(stego_text)

# Report the text, both bit strings and the bitwise accuracy
print("Stego text:", stego_text, "", sep="\n")
print("Original secret:", bits_to_str(secret[0]))
print("Recovered secret:", bits_to_str(recovered_secret))
print(f"Accuracy: {100 * calculate_accuracy(recovered_secret, secret):.2f}%")
