In [1]:
# Library imports
import nltk
import evaluate
import torch
import math

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers import T5Tokenizer, T5ForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import BertModel, BertConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Determine what device to use: prefer the GPU when PyTorch can see one
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

if device.type != 'cuda':
    print('CPU')
else:
    device_name = torch.cuda.get_device_name(device)
    print(f'CUDA on {device_name}')

CUDA on NVIDIA GeForce RTX 3060 Laptop GPU


In [3]:
# Configuration parameters

SECRET_LENGTH = 16                              # Number of secret bits hidden in one message
TOKENS_PER_BIT = 1                              # Consecutive generated tokens that carry each secret bit
STEGO_LENGTH = SECRET_LENGTH * TOKENS_PER_BIT   # Total number of bit-carrying output tokens

SKIPPED_OUTPUT_TOKENS = 4                       # Number of initial output tokens to skip (seems to be constant for SmolLM)
PREFIX_LENGTH = 0                               # Number of tokens in which the secret is encoded before the input prompt. 0 to disable
INJECT_SECRET_IN_SYSTEM_PROMPT = False          # Whether to insert the secret as binary text in the system prompt

MASK_TOKENS = True                              # Whether to mask a subset of tokens based on the encoded secret bit. A simple form of encoding

In [4]:
# Cleanup previous run
# This allows rerunning the entire notebook without memory leaks

import gc

# Globals from a previous run that can keep model weights / optimizer state alive.
# `optimizer` and `parameters` hold references to the decoder (and prefix encoder)
# parameters, so deleting only the model variables would not release that memory.
# Note: objects that were displayed as a cell's last expression are additionally held
# by IPython's output history (Out[n], _) and are not released here.
STALE_GLOBALS = (
    'generator_tokenizer',
    'generator',
    'decoder',
    'semantic_anchor',
    'prefix_encoder',
    'optimizer',
    'parameters',
)

for stale_name in STALE_GLOBALS:
    globals().pop(stale_name, None)

gc.collect()

# Release cached CUDA memory. Skipped on CPU-only machines; the cuBLAS workspace
# hook is a private torch API, so only call it when this build provides it.
if torch.cuda.is_available():
    if hasattr(torch._C, '_cuda_clearCublasWorkspaces'):
        torch._C._cuda_clearCublasWorkspaces()
    torch.cuda.empty_cache()


In [5]:
# Check if everything's cleaned: report CUDA memory still held by tensors / the caching allocator
bytes_per_mib = 1024 * 1024
allocated = torch.cuda.memory_allocated() / bytes_per_mib
reserved = torch.cuda.memory_reserved() / bytes_per_mib

print(f'Allocated: {allocated:.1f} MB; Reserved: {reserved:.1f} MB')

Allocated: 0.0 MB; Reserved: 0.0 MB


In [6]:
class StegoEmbeddingDecoder(nn.Module):
    """Recover secret bits from the token distributions of generated text.

    Each secret bit is carried by ``pool_size`` consecutive generated tokens. The
    token distributions of one bit's pool are concatenated and passed through a
    linear head that produces one logit per bit.

    Args:
        vocab_size: size of each per-position token distribution (V).
        embedding_size: generator hidden size; stored but not used by the current
            linear head (a BERT encoder over ``outputs_embeds`` was tried and disabled).
        output_size: number of secret bits to recover (S).
        pool_size: number of generated tokens per secret bit (P).
    """

    def __init__(self, vocab_size, embedding_size, output_size=16, pool_size=4):
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size

        self.output_size = output_size
        self.pool_size = pool_size

        # Bit-wise binary classification head over the pooled token distributions.
        # It must output raw logits: BCEWithLogitsLoss in forward() applies the sigmoid
        # itself, so an extra nn.Sigmoid() here would squash the outputs twice.
        # NOTE: 1 layer seems to work fine for now
        self.classifier = nn.Sequential(
            nn.Linear(vocab_size * pool_size, 1),
        )

    def forward(self, token_probs=None, outputs_embeds=None, attention_mask=None, labels=None):
        """
        token_probs: (B, S * P, V) per-position token distributions (soft probs or one-hot)
        outputs_embeds: (B, S * P, D) optional; currently unused
        attention_mask: (B, S * P) optional; currently unused
        labels: (B, S) optional target bits in {0, 1}

        Returns a dict with "logits" of shape (B, S, 1) and "loss" (BCE-with-logits
        against ``labels``, or None when no labels are given).

        Raises ValueError if token_probs is missing or has the wrong shape.
        """
        if token_probs is None:
            raise ValueError('token_probs is required')

        batch_size, num_tokens, dist_size = token_probs.shape
        expected_tokens = self.output_size * self.pool_size
        if num_tokens != expected_tokens or dist_size != self.vocab_size:
            raise ValueError(
                f'token_probs shape error: {tuple(token_probs.shape)}. '
                f'Expected: (B, {expected_tokens}, {self.vocab_size})'
            )

        # Stack the P consecutive token distributions belonging to each bit:
        # (B, S * P, V) -> (B, S, P * V). A no-op regrouping when pool_size == 1.
        classifier_input = token_probs.reshape(batch_size, self.output_size, self.pool_size * self.vocab_size)

        # Perform classification to get logits for each bit
        logits = self.classifier(classifier_input)

        loss = None
        if labels is not None:
            loss_fn = nn.BCEWithLogitsLoss()
            loss = loss_fn(logits.squeeze(-1), labels.float())

        return {
            "loss": loss,
            "logits": logits,
        }


In [7]:

class SecretPrefixEncoder(nn.Module):
    """Turn each secret bit into one continuous embedding usable as a generator prefix.

    Every bit is lifted independently through the same small MLP (1 -> D -> D),
    yielding one embedding per bit position.
    """

    def __init__(self, secret_length: int, hidden_size: int):
        super().__init__()

        self.secret_length = secret_length
        self.hidden_size = hidden_size

        # Shared per-bit MLP
        layers = [
            nn.Linear(1, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.projection = nn.Sequential(*layers)

    def forward(self, bits: torch.Tensor):
        """
        bits: [ B, S ] in {0,1}
        returns: [ B, S, D ]
        """
        # One scalar feature per bit: [ B, S ] -> [ B, S, 1 ]
        bit_features = bits.float().reshape(-1, self.secret_length, 1)
        return self.projection(bit_features)


In [8]:
# Create encoder models

# Generator checkpoints that have been tried; the active one is assigned last.
#GENERATOR_MODEL = "Qwen/Qwen3-0.6B"
#GENERATOR_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
#GENERATOR_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
GENERATOR_MODEL = "HuggingFaceTB/SmolLM2-135M-Instruct"

generator_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
generator = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL).to(device)
generator.train()

# Dimensions the decoder and prefix encoder must agree with
embedding_size = generator.config.hidden_size
vocab_size = generator.config.vocab_size

# Optional learned prefix that injects the secret ahead of the prompt embeddings
if PREFIX_LENGTH > 0:
    prefix_encoder = SecretPrefixEncoder(secret_length=SECRET_LENGTH, hidden_size=embedding_size).to(device)

# Frozen copy of the base model for a KL "semantic" loss (currently disabled):
#semantic_anchor = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL).to(device)
#for p in semantic_anchor.parameters():
#    p.requires_grad = False
#semantic_anchor.eval()

print(f'Generator embedding dimensions: {embedding_size}, vocabulary size: {vocab_size}')

Generator embedding dimensions: 576, vocabulary size: 49152


In [9]:
# Create decoder model
# One logit per secret bit, each computed from TOKENS_PER_BIT consecutive token distributions.
decoder = StegoEmbeddingDecoder(
    vocab_size=vocab_size,
    embedding_size=embedding_size,
    output_size=SECRET_LENGTH,
    pool_size=TOKENS_PER_BIT,
).to(device)
decoder.train()

# Show the architecture with print() rather than leaving `decoder.train()` as the cell's
# last expression: IPython keeps displayed results in its output history (Out[n], _),
# which would keep the decoder alive even after the cleanup cell runs `del decoder`.
print(decoder)

StegoEmbeddingDecoder(
  (classifier): Sequential(
    (0): Linear(in_features=49152, out_features=1, bias=True)
  )
)

In [10]:
def generate_secret():
    """Draw a random secret: a (1, SECRET_LENGTH) tensor of uniformly random 0/1 bits (on the CPU)."""
    secret_shape = (1, SECRET_LENGTH)   # [ B, N ]
    return torch.randint(0, 2, secret_shape)

def bits_to_str(bits, sep=''):
    """Render a tensor of bits as text, e.g. tensor([1, 0, 1]) -> '101' ('1 0 1' with sep=' ')."""
    return sep.join(map(str, bits.tolist()))

def build_generator_prompt(carrier, secret_bits, add_generation_prompt=False):
    """Render (as text, not token ids) the chat prompt the generator continues.

    carrier: the user request that the stego text answers.
    secret_bits: secret tensor; only used when INJECT_SECRET_IN_SYSTEM_PROMPT is set,
        in which case its bits are spelled out in a system message.
    add_generation_prompt: forwarded to ``apply_chat_template``. The default (False)
        ends the prompt after the user turn, so the model has to produce the assistant
        header itself; presumably those are the tokens SKIPPED_OUTPUT_TOKENS skips.
        If set to True, SKIPPED_OUTPUT_TOKENS likely needs re-tuning (probably to 0).
    """
    messages = []

    if INJECT_SECRET_IN_SYSTEM_PROMPT:
        messages.append(
            {'role': 'system', 'content': f'You are a helpful assistant who always encodes these bits in their answer: {bits_to_str(secret_bits.flatten(), " ")}. '}
        )

    messages.append({'role': 'user', 'content': carrier})

    return generator_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
    )

def generate_input_embeddings(input_ids, secret_bits = None):
    """Build the generator's first-step input embeddings.

    input_ids: (B, T) prompt token ids.
    secret_bits: (B, S) secret bits; only used when PREFIX_LENGTH is non-zero.

    Returns (B, T, D) token embeddings. When the prefix is enabled, it instead returns
    (B, S + T, D), with the prefix_encoder embeddings placed in front of the prompt.

    NOTE(review): the prefix has one position per secret bit (prefix_encoder reshapes
    to secret_length); PREFIX_LENGTH only acts as an on/off switch here.
    """
    embed_tokens = generator.get_input_embeddings()
    token_embeds = embed_tokens(input_ids)  # (B, T, D)

    use_prefix = PREFIX_LENGTH != 0 and secret_bits is not None
    if not use_prefix:
        return token_embeds

    prefix_embeds = prefix_encoder(secret_bits)  # (B, S, D)
    return torch.cat((prefix_embeds, token_embeds), dim=1)

def mask_indices(bit):
    """Vocabulary ids to forbid for a secret bit: every id with the same parity as the bit.

    bit: 0-d tensor holding 0 or 1.
    bit = 0 -> even ids (0, 2, 4, ...); bit = 1 -> odd ids (1, 3, 5, ...).
    """
    first_banned = bit.item()
    return range(first_banned, vocab_size, 2)

@torch.no_grad()
def generator_forward(secret_bit = None, temperature=1.0, **kwargs):
    """Run one generator step and return the (optionally masked) next-token distribution.

    secret_bit: 0-d tensor bit for this position, or None to leave the logits unmasked.
        With MASK_TOKENS enabled, every id from mask_indices(secret_bit) gets -inf:
        bit = 0 masks even tokens, bit = 1 masks odd tokens.
    temperature: softmax temperature, applied after masking.
    **kwargs: forwarded to the generator (e.g. inputs_embeds, past_key_values).

    Returns (logits [B, V], probs [B, V], past_key_values).

    NOTE: this runs under torch.no_grad(), so the returned logits/probs have no autograd
    history; a loss computed from them cannot train anything upstream of this call
    (the generator or a secret prefix encoder).
    """
    out = generator(
        **kwargs,
        return_dict=True,
        use_cache=True,
    )

    logits = out.logits[:, -1, :]   # [B, V]

    if MASK_TOKENS and secret_bit is not None:
        # mask_indices() returns an arithmetic range, so apply it as a strided slice on the
        # logits' own device. This avoids building a vocabulary-sized Python list and a
        # CPU boolean mask on every generation step.
        banned = mask_indices(secret_bit)
        logits[:, banned.start:banned.stop:banned.step] = -math.inf

    probs = torch.softmax(logits / temperature, dim=-1)

    return logits, probs, out.past_key_values

def soft_embeddings(token_probs):
    """Expected input embedding under a token distribution: probs [B, V] @ E [V, d] -> [B, d]."""
    vocab_embeddings = generator.get_input_embeddings().weight  # [V, d]
    return torch.matmul(token_probs, vocab_embeddings)

@torch.no_grad()
def anchor_forward(anchor, inputs, length=None) -> torch.Tensor:
    """Logits of a frozen reference ("semantic anchor") model for a prompt.

    anchor: causal LM, called with input_ids / attention_mask / return_dict.
    inputs: tokenizer output (anything with .input_ids and optionally .attention_mask),
        or a plain input_ids tensor, which is run without an attention mask.
    length: number of leading positions to keep; defaults to STEGO_LENGTH.

    Returns the anchor logits for the first `length` positions, shape [B, length, V].

    NOTE(review): these are logits at the first `length` *prompt* positions, not at
    generated positions. The (commented-out) training caller compares them with the
    last-step generator logits; confirm this slice is what the semantic loss intends.
    """
    input_ids = getattr(inputs, 'input_ids', inputs)
    attention_mask = getattr(inputs, 'attention_mask', None)
    if length is None:
        length = STEGO_LENGTH

    outputs = anchor(
        input_ids=input_ids,
        attention_mask=attention_mask,
        return_dict=True
    )
    return outputs.logits[:, :length, :]

def kl_semantic_loss(stego_logits, base_logits):
    """KL(base || stego) between the two next-token distributions, averaged over the batch.

    F.kl_div takes log-probabilities as input and probabilities as target and sums
    target * (log target - input); "batchmean" divides that sum by the batch size.

    TODO: if these logits ever include padded positions, mask them out before averaging.
    """
    stego_log_probs = F.log_softmax(stego_logits, dim=-1)
    base_probs = F.softmax(base_logits, dim=-1)
    return F.kl_div(stego_log_probs, base_probs, reduction="batchmean")

def decoder_forward(soft_tokens, soft_embeds, secret_bits):
    """Run the global decoder on generated token distributions and return only its recovery loss."""
    decoder_out = decoder(
        token_probs=soft_tokens,
        outputs_embeds=soft_embeds,
        labels=secret_bits,
    )
    return decoder_out['loss']

@torch.no_grad()
def encode(carrier, secret_bits):
    """Greedily generate stego text that answers `carrier` while carrying `secret_bits`.

    The function generates SKIPPED_OUTPUT_TOKENS + STEGO_LENGTH tokens. Each argmax token
    is fed back as a hard embedding for the next step. The first SKIPPED_OUTPUT_TOKENS
    tokens are generated unmasked and then dropped. Every following token is constrained
    by generator_forward's mask for the bit it carries; each bit is carried by
    TOKENS_PER_BIT consecutive tokens.

    carrier: user request text.
    secret_bits: (1, SECRET_LENGTH) bit tensor as produced by generate_secret(),
        on the generator's device.

    Returns the decoded text of the kept tokens, with special tokens removed.
    Re-tokenizing this text is not guaranteed to reproduce exactly the same token ids.
    """
    prompt = build_generator_prompt(carrier, secret_bits)
    inputs = generator_tokenizer(prompt, return_tensors="pt")

    embedding_layer = generator.get_input_embeddings()
    input_embeddings = generate_input_embeddings(inputs.input_ids.to(device), secret_bits)
    past_kv = None

    selected_tokens = []
    for t in range(STEGO_LENGTH + SKIPPED_OUTPUT_TOKENS):
        stego_pos = t - SKIPPED_OUTPUT_TOKENS
        # Each secret bit is carried by TOKENS_PER_BIT consecutive tokens
        secret_bit = secret_bits[0, stego_pos // TOKENS_PER_BIT] if stego_pos >= 0 else None

        _, token_probs, past_kv = generator_forward(
            secret_bit=secret_bit,
            inputs_embeds=input_embeddings,
            past_key_values=past_kv,
        )

        # Select the top probability token
        predictions = torch.argmax(token_probs, dim=-1)  # [B]

        # Convert the token to 'hard' embeddings again for the next iteration
        input_embeddings = embedding_layer(predictions).unsqueeze(1)  # (B, 1, D)

        if stego_pos >= 0:
            selected_tokens.append(predictions[0].item())

    return generator_tokenizer.decode(selected_tokens, skip_special_tokens=True)

@torch.no_grad()
def decode(stego_text):
    """Recover the secret bits from stego text produced by encode().

    The text is re-tokenized and its first STEGO_LENGTH tokens are kept; shorter text is
    padded with the tokenizer's pad id. The kept tokens are one-hot encoded and
    classified by the decoder.

    Returns a 1-D tensor of predicted bits (0/1), one per decoder output, on `device`.
    """
    # encode() strips special tokens from its output, so none may be added here either
    # (e.g. a BOS token would shift every bit position by one).
    encoded_token_ids = generator_tokenizer(
        stego_text, return_tensors="pt", add_special_tokens=False
    ).input_ids

    # Ensure correct length, discard extra tokens and pad if too short
    encoded_token_ids = encoded_token_ids[:, :STEGO_LENGTH]
    encoded_token_ids = F.pad(encoded_token_ids, (0, STEGO_LENGTH - encoded_token_ids.shape[1]), value=generator_tokenizer.pad_token_id)
    encoded_token_ids = encoded_token_ids.to(device)

    # (B, N, D) hard token embeddings; passed along although the current decoder head ignores them
    embedding_layer = generator.get_input_embeddings()
    encoded_embeds = embedding_layer(encoded_token_ids)

    # (B, N, V) one-hot "distributions" mirroring the soft token probs used in training
    encoded_token_probs = F.one_hot(encoded_token_ids, vocab_size).float()

    out = decoder(
        token_probs=encoded_token_probs,
        outputs_embeds=encoded_embeds,
    )

    logits = out["logits"]      # (B, S, 1)
    # A positive logit means sigmoid > 0.5, i.e. bit 1
    preds = (logits.squeeze(-1) > 0).long()

    return preds[0]


In [None]:
# Test prompt generation and generation inference:
# build the prompt, hide a fresh secret in generated text, then try to recover it

carrier = "Write a short paragraph explaining how neural networks learn."
secret = generate_secret().to(device)

separator = '----------------'
print(build_generator_prompt(carrier, secret))
print(separator)

stego = encode(carrier, secret)
print(stego)
print(separator)

decoded = decode(stego)
print(f'Original: {bits_to_str(secret[0])}')
print(f'Decoded: {bits_to_str(decoded)}')

<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
Write a short paragraph explaining how neural networks learn.<|im_end|>

----------------
A neural networks, commonly referred as CNNs or deep learning models in various domains
----------------
Original: [0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1]
Decoded: 1000010010010100


In [None]:
# Training loop
#
# Only the decoder (plus the prefix encoder when PREFIX_LENGTH > 0) is put in the optimizer;
# the generator's weights are not trained. Each step rolls the generator out on soft
# embeddings (probability-weighted token embeddings) under the secret-bit mask. The
# decoder is then trained to recover the bits from the per-step token distributions.
#
# NOTE: generator_forward runs under torch.no_grad(), so the rolled-out distributions carry
# no autograd history; the recovery loss currently only produces gradients for the decoder.
# NOTE: encode() feeds back hard argmax embeddings and decode() feeds one-hot tokens to the
# decoder, while training uses soft embeddings/probabilities, so train and test inputs differ.

NUM_STEPS = 1000
LEARNING_RATE = 3e-1
LOG_EVERY = 1            # run a (slow) encode/decode round trip every N steps
carrier_prompt = "Explain how neural networks learn."


def _soft_rollout(carrier_prompt, secret_bits):
    """Roll the generator out for the stego tokens, feeding back soft embeddings.

    Returns (token_probs [B, N, V], soft_embeds [B, N, D]) for the N = STEGO_LENGTH
    positions after the first SKIPPED_OUTPUT_TOKENS steps.
    """
    prompt = build_generator_prompt(carrier_prompt, secret_bits)
    generator_inputs = generator_tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Prepare embeddings for first step by adding secret prefix to token embeddings
    step_embeds = generate_input_embeddings(generator_inputs, secret_bits)
    past_kv = None

    kept_probs = []
    kept_embeds = []
    for t in range(STEGO_LENGTH + SKIPPED_OUTPUT_TOKENS):
        stego_pos = t - SKIPPED_OUTPUT_TOKENS
        # Each secret bit is carried by TOKENS_PER_BIT consecutive tokens
        secret_bit = secret_bits[0, stego_pos // TOKENS_PER_BIT] if stego_pos >= 0 else None

        _, token_probs, past_kv = generator_forward(
            secret_bit=secret_bit,
            inputs_embeds=step_embeds,
            past_key_values=past_kv,
        )

        next_embeds = soft_embeddings(token_probs)  # [ B, D ]

        # Only keep positions after the skipped warm-up tokens
        if stego_pos >= 0:
            kept_probs.append(token_probs)
            kept_embeds.append(next_embeds)

        # Feed the soft embedding back in as the next single-token input
        step_embeds = next_embeds.unsqueeze(1)  # [ B, 1, D ]

    return torch.stack(kept_probs, dim=1), torch.stack(kept_embeds, dim=1)


# Collect all trainable parameters
parameters = list(decoder.parameters())
#parameters += list(generator.parameters())   # generator is currently not trained
if PREFIX_LENGTH > 0:
    parameters += list(prefix_encoder.parameters())

optimizer = torch.optim.AdamW(parameters, lr=LEARNING_RATE)

for step in range(NUM_STEPS):
    secret_bits = generate_secret().to(device)

    encoder_outputs_probs, encoder_outputs_embeds = _soft_rollout(carrier_prompt, secret_bits)

    # Calculate recovery loss based on decoding success
    recovery_loss = decoder_forward(soft_tokens=encoder_outputs_probs, soft_embeds=encoder_outputs_embeds, secret_bits=secret_bits)

    # Semantic (KL-to-anchor) loss is disabled; kept at zero for logging
    #base_logits = anchor_forward(semantic_anchor, generator_inputs)
    #semantic_loss = kl_semantic_loss(stego_logits, base_logits)
    semantic_loss = torch.zeros(())

    loss = recovery_loss
    #loss = 1.0 * recovery_loss + 0.0 * semantic_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % LOG_EVERY == 0:
        stego_text = encode(carrier_prompt, secret_bits)
        secret_str = bits_to_str(secret_bits.flatten())
        decoded_secret_bits = decode(stego_text)
        decoded_secret_str = bits_to_str(decoded_secret_bits)

        print(f"Step {step} | L_sem: {semantic_loss.item():.3f} | L_rec: {recovery_loss.item():.3f}")
        print(f'Secret:  {secret_str}')
        print(f'Decoded: {decoded_secret_str}')
        print(f'Stego:   {stego_text}')
        print()


tensor(0.9175, device='cuda:0')
tensor(0., device='cuda:0')

tensor(0.9998, device='cuda:0')
tensor(0., device='cuda:0')

tensor(0.3145, device='cuda:0')
tensor(0., device='cuda:0')

tensor(0.4154, device='cuda:0')
tensor(0., device='cuda:0')

tensor(0.7837, device='cuda:0')
tensor(0., device='cuda:0')

tensor(0.1017, device='cuda:0')
tensor(0., device='cuda:0')

tensor(0.6683, device='cuda:0')
tensor(0., device='cuda:0')

tensor(0.1529, device='cuda:0')
tensor(0., device='cuda:0')

tensor(0.8830, device='cuda:0')
tensor(0., device='cuda:0')

tensor(0.3605, device='cuda:0')
tensor(0., device='cuda:0')

tensor(0.5022, device='cuda:0')
tensor(0., device='cuda:0')

tensor(0.2852, device='cuda:0')
tensor(0., device='cuda:0')

tensor(0.1540, device='cuda:0')
tensor(0., device='cuda:0')

tensor(0.4079, device='cuda:0')
tensor(0., device='cuda:0')

tensor(0.3242, device='cuda:0')
tensor(0., device='cuda:0')

tensor(0.0451, device='cuda:0')
tensor(0., device='cuda:0')

Step 0 | L_sem: 0.000 | 

KeyboardInterrupt: 

In [13]:
# Final end-to-end test: hide a fresh secret in an unseen carrier prompt and recover it

carrier = "Give me a recipe for pasta bolognese"
secret = generate_secret().to(device)

stego_text = encode(carrier, secret)
recovered_secret = decode(stego_text)

report_lines = [
    "Stego text:",
    stego_text,
    "",
]
for report_line in report_lines:
    print(report_line)
print("Original secret:", bits_to_str(secret[0]))
print("Recovered secret:", bits_to_str(recovered_secret))


Stego text:
Here are some delicious pasta recipes for your Italian craving:

**Pasta

Original secret: 1010100000111110
Recovered secret: 1010100000111110
