In [1]:
import nltk
import evaluate
import torch

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers import T5Tokenizer, T5ForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import BertModel, BertConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the compute backend once; every later cell moves models/tensors to `device`.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

if device.type != 'cuda':
    print('CPU')
else:
    device_name = torch.cuda.get_device_name(device)
    print(f'CUDA on {device_name}')

CUDA on NVIDIA RTX 500 Ada Generation Laptop GPU


In [3]:
# --- Secret / stego layout -------------------------------------------------
SECRET_LENGTH = 16          # Number of secret bits hidden per message
TOKENS_PER_BIT = 1          # Number of generated tokens that jointly carry one bit
SKIPPED_OUTPUT_TOKENS = 4   # Number of initial output tokens to skip (seem to be constant for SmolLM)
PREFIX_LENGTH = 32          # Number of tokens in which the secret is encoded before the input prompt

# Total number of generated tokens that carry the secret (after the skipped ones).
STEGO_LENGTH = TOKENS_PER_BIT * SECRET_LENGTH

In [4]:
import gc

# Drop every large object an earlier run of this notebook may still hold, so the
# loading cell below can run again without exhausting GPU memory.
#
# Deleting the models alone is not enough. Once the training cell has run,
# `optimizer` references every trainable parameter (plus AdamW state). Loop
# variables such as `loss` or `soft_embeds` can also keep the last autograd
# graph alive, especially after an interrupted run.
STATE_TO_RELEASE = (
    # tokenizers
    'generator_tokenizer', 'decoder_tokenizer',
    # models
    'generator', 'decoder', 'prefix_encoder', 'semantic_anchor',
    # optimizer (holds parameter references and moment buffers)
    'optimizer',
    # tensors left over from the training loop
    'loss', 'recovery_loss', 'stego_logits', 'token_probs', 'soft_embeds',
    'past_kv', 'generated_embeds', 'encoder_outputs_embeds',
)


def release_globals(names, namespace):
    """Delete each of `names` from the `namespace` dict if it is bound there.

    Args:
        names: Iterable of variable names to remove.
        namespace: Mapping to remove them from (pass `globals()` for the
            notebook namespace).

    Returns:
        List of the names that were actually removed, in the order given.
    """
    removed = []
    for name in names:
        if name in namespace:
            del namespace[name]
            removed.append(name)
    return removed


release_globals(STATE_TO_RELEASE, globals())
gc.collect()

if torch.cuda.is_available():
    # Private PyTorch API that clears cached cuBLAS workspaces. It may be absent
    # from CPU-only builds, so look it up defensively instead of calling it blindly.
    clear_cublas = getattr(torch._C, '_cuda_clearCublasWorkspaces', None)
    if clear_cublas is not None:
        clear_cublas()
    torch.cuda.empty_cache()


In [5]:
# Check if everything's cleaned
# Confirm the cleanup cell released CUDA memory (values reported in MiB).
allocated, reserved = (
    n_bytes / (1024 * 1024)
    for n_bytes in (torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
)

print(f'Allocated: {allocated:.1f} MB; Reserved: {reserved:.1f} MB')

Allocated: 0.0 MB; Reserved: 0.0 MB


In [6]:
class StegoEmbeddingDecoder(nn.Module):
    """Recover secret bits from a sequence of token embeddings.

    The ``output_size * pool_size`` input embeddings are grouped into
    ``output_size`` consecutive pools of ``pool_size`` tokens. Each pool is
    flattened to a ``pool_size * hidden_size`` vector and scored by a small MLP
    that emits one logit per bit.

    Args:
        hidden_size: Width ``D`` of each input embedding.
        output_size: Number of bits ``T`` to recover.
        pool_size: Number of consecutive tokens ``P`` that carry one bit.
    """

    def __init__(self, hidden_size=768, output_size=16, pool_size=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.output_size = output_size
        self.pool_size = pool_size

        # An earlier variant ran a small BertModel (4 layers, 8 heads) over the
        # embeddings before this head. It is disabled, so the head sees the raw
        # input embeddings.
        pooled_width = hidden_size * pool_size
        self.classifier = nn.Sequential(
            nn.Linear(pooled_width, 256),
            nn.Sigmoid(),
            nn.Linear(256, 256),
            nn.Sigmoid(),
            nn.Linear(256, 1),
        )

    def forward(self, inputs_embeds, attention_mask=None, labels=None):
        """Score every bit pool.

        Args:
            inputs_embeds: ``(B, T * P, D)`` token embeddings.
            attention_mask: Accepted for interface compatibility; currently unused.
            labels: Optional ``(B, T)`` target bits in {0, 1}.

        Returns:
            dict with ``"logits"`` of shape ``(B, T, 1)`` and ``"loss"``: the mean
            binary cross-entropy against ``labels``, or ``None`` without labels.
        """
        expected_tokens = self.output_size * self.pool_size
        assert inputs_embeds.shape[1] == expected_tokens, f'Input tokens size error: {inputs_embeds.shape[1]}. Expected: {expected_tokens}'
        assert inputs_embeds.shape[2] == self.hidden_size, f' Input embedding size error: {inputs_embeds.shape[2]}. Expected: {self.hidden_size}'

        # (B, T*P, D) -> (B, T, P*D): concatenate the P embeddings of each pool.
        pooled = inputs_embeds.reshape(-1, self.output_size, self.pool_size * self.hidden_size)
        logits = self.classifier(pooled)

        if labels is None:
            loss = None
        else:
            loss = nn.BCEWithLogitsLoss()(logits.squeeze(-1), labels.float())

        return {"loss": loss, "logits": logits}


In [7]:

class SecretPrefixEncoder(nn.Module):
    """Map a bit vector to a sequence of soft-prompt embeddings.

    A two-layer MLP (Linear -> Tanh -> Linear) projects the ``bit_len`` bits to
    ``prefix_length * hidden_size`` values, which are then viewed as
    ``prefix_length`` embeddings of width ``hidden_size``.

    NOTE: the second layer is square in ``prefix_length * hidden_size``, so its
    weight count grows quadratically (e.g. 32 * 576 = 18432 -> ~340M weights).
    """

    def __init__(self, bit_len: int, hidden_size: int, prefix_length: int):
        super().__init__()

        self.hidden_size = hidden_size
        self.prefix_length = prefix_length

        flat_width = prefix_length * hidden_size
        self.projection = nn.Sequential(
            nn.Linear(bit_len, flat_width),
            nn.Tanh(),
            nn.Linear(flat_width, flat_width),
        )

    def forward(self, bits: torch.Tensor):
        """Encode ``bits`` of shape ``(B, N)`` (values in {0, 1}) into ``(B, P, D)``."""
        batch = bits.shape[0]
        flat = self.projection(bits.float())
        return flat.reshape(batch, self.prefix_length, self.hidden_size)


In [8]:
# Generator checkpoint. Alternatives tried earlier:
#   "Qwen/Qwen3-0.6B", "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
#   "HuggingFaceTB/SmolLM2-360M-Instruct"
# (a T5 decoder, "google/flan-t5-base", was also tried before the embedding decoder).
GENERATOR_MODEL = "HuggingFaceTB/SmolLM2-135M-Instruct"

generator_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
generator = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL).to(device)

# Both the decoder and the prefix encoder work in the generator's embedding space.
embed_dim = generator.config.hidden_size

decoder = StegoEmbeddingDecoder(
    hidden_size=embed_dim,
    output_size=SECRET_LENGTH,
    pool_size=TOKENS_PER_BIT,
).to(device)

prefix_encoder = SecretPrefixEncoder(
    bit_len=SECRET_LENGTH,
    hidden_size=embed_dim,
    prefix_length=PREFIX_LENGTH,
).to(device)

# A frozen copy of the generator ("semantic_anchor", for the KL semantic loss)
# is currently disabled.

# Switch generator and decoder to training mode; prefix_encoder is not switched explicitly.
for module in (generator, decoder):
    module.train()

print(f'Generator embedding dimensions: {generator.config.hidden_size}')

Generator embedding dimensions: 576


In [9]:
def bits_to_str(bits, sep=''):
    """Render a tensor of bits as a digit string, e.g. ``tensor([[1, 0, 1]]) -> '101'``.

    The tensor is flattened first, so ``(N,)`` and ``(B, N)`` inputs both yield a
    flat string of digits joined by ``sep``. Without flattening, a ``(1, N)``
    tensor would render as the Python list repr ``'[1, 0, 1]'``.
    """
    return sep.join(str(b) for b in bits.flatten().tolist())

def build_generator_prompt(carrier, secret_bits, add_generation_prompt=False):
    """Format ``carrier`` as a single-turn chat prompt for the generator.

    Args:
        carrier: User message the stego text should answer.
        secret_bits: Unused here. The secret is injected separately as prefix
            embeddings by ``generate_input_embeddings``; the argument is kept for
            call-site compatibility.
        add_generation_prompt: When True, append the assistant-turn header so the
            model starts its answer immediately. The default (False) keeps the
            original behaviour: the prompt ends after the user turn, so the model
            has to produce the assistant-turn tokens itself. Presumably these are
            the tokens SKIPPED_OUTPUT_TOKENS discards. Greedy decoding can also
            drift into an extra system/user turn instead. If you enable this,
            re-check SKIPPED_OUTPUT_TOKENS.

    Returns:
        The templated prompt as a string (not tokenized).
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": carrier},
    ]

    return generator_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
    )

def generate_input_embeddings(input_ids, secret_bits):
    """Build the generator's input embeddings: secret prefix followed by the prompt.

    Args:
        input_ids: ``(B, T)`` prompt token ids, on the generator's device.
        secret_bits: ``(B, N)`` secret bits, on the same device as ``prefix_encoder``.

    Returns:
        ``(B, P + T, D)`` tensor. The first ``P`` positions come from
        ``prefix_encoder``; the remaining ``T`` are the generator's token embeddings.
    """
    token_embeds = generator.get_input_embeddings()(input_ids)   # (B, T, D)
    prefix_embeds = prefix_encoder(secret_bits)                    # (B, P, D)

    # The prefix encoder casts bits to float and normally runs in float32. Match
    # the generator's embedding dtype so a half-precision generator also works.
    # This is a no-op (and keeps autograd intact) when the dtypes already agree.
    prefix_embeds = prefix_embeds.to(token_embeds.dtype)

    return torch.cat([prefix_embeds, token_embeds], dim=1)        # (B, P + T, D)


In [10]:
def generator_forward(temperature=1.0, **kwargs):
    """Run one generator forward pass and return next-token scores.

    Args:
        temperature: Softmax temperature, applied to the returned ``probs`` only.
        **kwargs: Forwarded to ``generator(...)`` (e.g. ``inputs_embeds``,
            ``past_key_values``).

    Returns:
        ``(logits, probs, past_key_values)``: the raw logits at the last position
        ``(B, V)``, ``softmax(logits / temperature)``, and the updated KV cache.
    """
    # A single forward pass never "stops", so generation-only arguments such as
    # eos_token_id are meaningless here. Some model forwards reject them with a
    # TypeError, so they are deliberately not passed.
    out = generator(**kwargs, return_dict=True, use_cache=True)

    last_logits = out.logits[:, -1, :]   # (B, V)
    probs = torch.softmax(last_logits / temperature, dim=-1)

    return last_logits, probs, out.past_key_values

def soft_embeddings(token_probs):
    """Return the probability-weighted mix of the generator's input embeddings.

    ``token_probs`` ``(..., V)`` is multiplied by the embedding matrix ``(V, D)``,
    giving ``(..., D)``. The training loop feeds this "soft" token back in
    instead of a hard token id, so gradients can flow through generation.
    """
    vocab_embeddings = generator.get_input_embeddings().weight  # (V, D)
    return torch.matmul(token_probs, vocab_embeddings)

@torch.no_grad()
def anchor_forward(anchor, inputs) -> torch.Tensor:
    """Return the frozen ``anchor`` model's logits for the first STEGO_LENGTH positions.

    ``inputs`` must expose ``input_ids`` and ``attention_mask`` attributes (such
    as a tokenizer output). Runs without gradient tracking.

    NOTE(review): this slices the first STEGO_LENGTH positions of the *input*,
    while generator_forward returns logits for the last position only, so the
    two do not line up as arguments to kl_semantic_loss. The commented-out call
    in the training loop also passes a bare ``input_ids`` tensor, which has no
    ``.input_ids`` attribute. Revisit both before re-enabling the semantic loss.
    """
    anchor_out = anchor(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        return_dict=True,
    )
    return anchor_out.logits[:, :STEGO_LENGTH, :]

def kl_semantic_loss(stego_logits, base_logits):
    """KL(base || stego) between the two models' token distributions.

    Computed as ``F.kl_div(log_softmax(stego), softmax(base), reduction="batchmean")``.
    It is summed over every non-batch dimension and divided by the batch size.
    No padding mask is applied.
    """
    log_p_stego = F.log_softmax(stego_logits, dim=-1)
    q_base = F.softmax(base_logits, dim=-1)
    return F.kl_div(log_p_stego, q_base, reduction="batchmean")

def decoder_forward(soft_embeds, secret_bits):
    """Return the decoder's bit-recovery loss for ``soft_embeds`` against ``secret_bits``."""
    decoded = decoder(inputs_embeds=soft_embeds, labels=secret_bits)
    recovery_loss = decoded['loss']
    return recovery_loss

In [11]:
@torch.no_grad()
def encode(carrier, secret_bits):
    """Greedily generate the stego text that carries ``secret_bits``.

    The secret is injected as prefix embeddings in front of the templated
    ``carrier`` prompt. Exactly STEGO_LENGTH + SKIPPED_OUTPUT_TOKENS tokens are
    generated, because EOS stopping is disabled.

    Args:
        carrier: Prompt the generator answers.
        secret_bits: ``(1, SECRET_LENGTH)`` bits on ``device``.

    Returns:
        The text of the STEGO_LENGTH tokens that follow the first
        SKIPPED_OUTPUT_TOKENS, decoded with special tokens removed.
    """
    prompt = build_generator_prompt(carrier, secret_bits)
    prompt_ids = generator_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    input_embeddings = generate_input_embeddings(prompt_ids, secret_bits)

    # encode() is also called from inside the training loop, where the generator
    # is in train mode. Generate in eval mode (no dropout) and restore the
    # caller's mode afterwards, even if generation fails.
    was_training = generator.training
    generator.eval()
    try:
        ids = generator.generate(
            inputs_embeds=input_embeddings,
            max_new_tokens=STEGO_LENGTH + SKIPPED_OUTPUT_TOKENS,  # First X tokens seem to be constant part of the LMs output
            eos_token_id=None,  # prevent early stop
            do_sample=False,
        )
    finally:
        generator.train(was_training)

    # Generating from inputs_embeds alone returns only the new tokens, so index 0
    # here is the first generated token.
    selected_ids = ids[0, SKIPPED_OUTPUT_TOKENS:]

    assert selected_ids.shape[0] == STEGO_LENGTH, f'Generated stego length error: {selected_ids.shape[0]}. Expected: {STEGO_LENGTH}'

    return generator_tokenizer.decode(selected_ids, skip_special_tokens=True)

@torch.no_grad()
def decode(stego_text):
    """Recover the secret bits from ``stego_text``.

    The text is re-tokenized with the generator tokenizer and cut or padded
    (with the tokenizer's pad token id) to exactly STEGO_LENGTH tokens. It is
    then embedded with the generator's input-embedding table and classified by
    ``decoder``.

    NOTE(review): encode() decodes with ``skip_special_tokens=True``, and
    re-tokenizing text need not reproduce the generated token ids. The
    positions seen here may therefore not match what the generator produced.

    Returns:
        1-D tensor of predicted bits (1 where the decoder logit is > 0).
    """
    token_ids = generator_tokenizer(stego_text, return_tensors="pt").input_ids[:, :STEGO_LENGTH]
    shortfall = STEGO_LENGTH - token_ids.shape[1]
    token_ids = F.pad(token_ids, (0, shortfall), value=generator_tokenizer.pad_token_id)

    token_embeds = generator.get_input_embeddings()(token_ids.to(device))   # (B, N, D)
    bit_logits = decoder(inputs_embeds=token_embeds)["logits"].squeeze(-1)  # (B, T)

    return (bit_logits > 0).long()[0]


In [12]:
def generate_secret(batch_size=1, length=None, rng=None):
    """Sample uniformly random secret bits.

    Args:
        batch_size: Number of secrets ``B`` (default 1).
        length: Bits per secret ``N``; defaults to ``SECRET_LENGTH``.
        rng: Optional ``torch.Generator`` for reproducible secrets. When None,
            the global torch RNG is used, exactly as before.

    Returns:
        ``(B, N)`` int64 CPU tensor with values in {0, 1}.
    """
    if length is None:
        length = SECRET_LENGTH
    return torch.randint(0, 2, (batch_size, length), generator=rng)   # [ B, N ]

In [13]:
carrier = "Write a short paragraph explaining how neural networks learn."
secret = generate_secret().to(device)

# Show the templated prompt, then the stego text generated for a random secret.
prompt_preview = build_generator_prompt(carrier, secret)
print(prompt_preview)
print('----------------')
print(encode(carrier, secret))

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Write a short paragraph explaining how neural networks learn.<|im_end|>

----------------
Neural networks learn by training on a dataset of labeled data, where each instance


In [None]:
# Joint training: prefix_encoder, generator and decoder share one AdamW optimizer.
# NOTE: the optimizer is created in this cell, so re-running the cell restarts
# training with fresh AdamW state (model weights are NOT reset).
optimizer = torch.optim.AdamW(
    list(prefix_encoder.parameters()) +
    list(generator.parameters()) + 
    list(decoder.parameters()),
    lr=3e-5
)

for step in range(1000):
    carrier_prompt = "Explain how neural networks learn."
    secret_bits = generate_secret().to(device)

    prompt = build_generator_prompt(carrier_prompt, secret_bits)
    generator_inputs = generator_tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Prepare embeddings for first step by adding secret prefix to token embeddings
    soft_embeds = generate_input_embeddings(generator_inputs, secret_bits)
    past_kv = None

    # Generate N tokens, store embeddings for each token.
    # Differentiable unroll: each step feeds back the probability-weighted
    # ("soft") embedding of the next token instead of a discrete token id. The KV
    # cache is reused, so after step 0 only the single new position is processed.
    generated_embeds = []
    for t in range(STEGO_LENGTH + SKIPPED_OUTPUT_TOKENS):   
        # stego_logits is only needed by the (disabled) semantic loss below and
        # holds the last step's logits only.
        stego_logits, token_probs, past_kv = generator_forward(
            inputs_embeds=soft_embeds,
            past_key_values=past_kv
        )

        soft_embeds = soft_embeddings(token_probs)  # [ B, D ]

        # HACK: Modulate embeddings based on secret bit to test decoder
        # NOTE(review): a 0 bit zeroes the embedding, and that zero vector is
        # also fed back into the generator as the next input. Only batch element
        # 0's bits are used (fine while generate_secret() returns B=1).
        # encode()/decode() never apply this modulation: decode() embeds
        # re-tokenized hard tokens. The decoder is therefore trained on inputs it
        # never sees at inference, which is consistent with the all-zero
        # "Decoded" lines in the logged output.
        if (t >= SKIPPED_OUTPUT_TOKENS):
            bit = secret_bits[0,t - SKIPPED_OUTPUT_TOKENS]
            soft_embeds = torch.mul(soft_embeds, bit)

        generated_embeds.append(soft_embeds)

        # Take soft embeds for next step
        soft_embeds = soft_embeds.unsqueeze(1)

    #print(f'Number of generated embeddings: {len(generated_embeds)}, skipping first {SKIPPED_OUTPUT_TOKENS}...')

    # Remove fixed first X tokens from output
    generated_embeds = generated_embeds[SKIPPED_OUTPUT_TOKENS:]

    # Stack embeddings for each token to [ B, N, D ]
    # (.to(device) is a no-op: the embeddings are already on device)
    encoder_outputs_embeds = torch.stack(generated_embeds, dim=1).to(device)

    #print(f'Encoder outputs embeddings shape: {encoder_outputs_embeds.shape}')

    # Calculate recovery loss based on decoding success
    recovery_loss = decoder_forward(encoder_outputs_embeds, secret_bits.to(device))

    # Calculate semantic loss compared to base model
    #base_logits = anchor_forward(semantic_anchor, generator_inputs)
    #semantic_loss = kl_semantic_loss(stego_logits, base_logits)
    # Placeholder while the semantic loss is disabled; np.zeros(1).item() keeps
    # the log line below working.
    semantic_loss = np.zeros(1)

    # Combine loss
    loss = recovery_loss
    #loss = 1.0 * recovery_loss + 0.0 * semantic_loss

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

    # NOTE(review): `step % 1 == 0` is always true, so a full greedy
    # encode()/decode() runs on every step. Raise the modulus to log less often.
    if step % 1 == 0:
        stego_text = encode(carrier_prompt, secret_bits)
        secret_str = bits_to_str(secret_bits.flatten())
        decoded_secret_bits = decode(stego_text)
        decoded_secret_str = bits_to_str(decoded_secret_bits)

        print(f"Step {step} | L_sem: {semantic_loss.item():.3f} | L_rec: {recovery_loss.item():.3f}")
        print(f'Secret:  {secret_str}')
        print(f'Decoded: {decoded_secret_str}')
        print(f'Stego:   {stego_text}')
        print()


Step 0 | L_sem: 0.000 | L_rec: 0.707
Secret:  1011110111110000
Decoded: 0000000000000000
Stego:   

Neural networks learn by processing data through the combination of inputs and outputs

Step 1 | L_sem: 0.000 | L_rec: 0.683
Secret:  0100000001111001
Decoded: 0000000000000000
Stego:   Neural networks learn by processing and transforming data into patterns, which they then use

Step 2 | L_sem: 0.000 | L_rec: 0.710
Secret:  0110011110111101
Decoded: 0000000000000000
Stego:   system
system

user
You I

Step 3 | L_sem: 0.000 | L_rec: 0.694
Secret:  1010101001010110
Decoded: 0000000000000000
Stego:   


Neural networks learn by processing data from the input data and transforming

Step 4 | L_sem: 0.000 | L_rec: 0.685
Secret:  1010100110010000
Decoded: 0000000000000000
Stego:   



Neural networks learn by processing data and using algorithms to identify

Step 5 | L_sem: 0.000 | L_rec: 0.703
Secret:  0111001011001111
Decoded: 0000000000000000
Stego:   

















Step 6 | L_sem: 0.000 

KeyboardInterrupt: 

In [None]:
# Debug: see how the most recent stego text re-tokenizes. `stego_text` is whatever
# the training cell above produced last (hidden state from an earlier cell).
input_ids = generator_tokenizer(stego_text, return_tensors="pt").input_ids

print(input_ids, input_ids.shape, sep='\n')

tensor([[ 9042,   867,  5115,   359,   253,  7132,  1743,   338,   338,   338,
           338,   338,   338,   338,  9218,   338,  1848,  1848,   338,   338,
            30,  1069,   314,   314,   314,  1431, 17327,    30,   198,    49,
          7132,  1743,    30,   198,  4093,   198,    49,   932,  3108,   282,
           314,   253, 47126,   253,    30,   198,    49,   932,  3108,   282,
           314,  1589,   314, 32195,    30,   198,    49,  3951,    30]])
torch.Size([1, 59])


In [None]:
# End-to-end check: hide a random secret in generated text, then try to recover it.
carrier = "Write a short paragraph explaining how neural networks learn."
secret = generate_secret().to(device)

stego_text = encode(carrier, secret)
recovered_secret = decode(stego_text)

print("Stego text:")
print(stego_text)
print()
# `secret` is (1, N) while `recovered_secret` is (N,). Flatten the former so both
# print as plain bit strings rather than one of them as a list repr.
print("Original secret:", bits_to_str(secret.flatten()))
print("Recovered secret:", bits_to_str(recovered_secret))
