In [1]:
import nltk
import evaluate
import torch

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers import T5Tokenizer, T5ForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import BertModel, BertConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the compute device once; every later cell moves models/tensors onto it.
if torch.cuda.is_available():
    device = torch.device('cuda')
    device_name = torch.cuda.get_device_name(device)

    print(f'CUDA on {device_name}')
else:
    device = torch.device('cpu')

    print('CPU')

CUDA on NVIDIA RTX 500 Ada Generation Laptop GPU


In [3]:
# Number of secret bits per message. The training loop also generates exactly
# this many tokens, and the decoder is built with this many positions, so one
# bit is predicted per generated token.
SECRET_LENGTH = 16
# NOTE(review): not referenced by any other cell visible here — presumably an
# upper bound on the stego text length; confirm or remove.
MAX_STEGO_LENGTH = 256

# Using BF16 seems to fix some inf/nan errors in one of the models
dtype = torch.bfloat16

In [4]:
import gc

# Delete models, clean up memory.
# When this notebook is re-run in a live kernel, the objects created by a
# previous run would otherwise stay referenced while the new ones are loaded.
for _stale_name in (
    'generator_tokenizer',
    'decoder_tokenizer',
    'generator',
    'decoder',
    'semantic_anchor',
):
    # pop() with a default is a no-op on a fresh kernel where the name is absent
    globals().pop(_stale_name, None)
del _stale_name

gc.collect()

# Only touch CUDA when a GPU is actually usable: the cuBLAS workspace hook is
# a private torch._C symbol, and this notebook explicitly supports running on
# the CPU fallback selected in the device cell.
if torch.cuda.is_available():
    torch._C._cuda_clearCublasWorkspaces()
    torch.cuda.empty_cache()


In [5]:
# Check if everything's cleaned
allocated = torch.cuda.memory_allocated() / (1024 * 1024)
reserved = torch.cuda.memory_reserved() / (1024 * 1024)

print(f'Allocated: {allocated:.1f} MB; Reserved: {reserved:.1f} MB')

Allocated: 0.0 MB; Reserved: 0.0 MB


In [6]:
class BertEmbeddingDecoder(nn.Module):
    """BERT encoder built from a fresh config, topped with a per-token binary head.

    The model consumes embeddings directly (no token ids) and emits a single
    logit per sequence position.
    """

    def __init__(self, hidden_size=768, sequence_length=16):
        """
        hidden_size:     width of the incoming embeddings and of the encoder
        sequence_length: number of position embeddings the encoder is built with
        """
        super().__init__()

        self.bert = BertModel(
            BertConfig(
                hidden_size=hidden_size,
                num_hidden_layers=4,
                num_attention_heads=8,
                intermediate_size=hidden_size * 4,
                max_position_embeddings=sequence_length,
                dtype=dtype,
            )
        ).to(device)

        # One logit per token; the head is kept in float32 regardless of `dtype`
        self.classifier = nn.Linear(hidden_size, 1, device=device, dtype=torch.float32)

    def forward(self, inputs_embeds, attention_mask=None, labels=None):
        """
        inputs_embeds: (B, N, D)
        attention_mask: (B, N) optional
        labels: (B, N) optional; when given, a BCE-with-logits loss is returned

        Returns a dict with "loss" (None when no labels were passed) and
        "logits" of shape (B, N, 1).
        """
        encoded = self.bert(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # (B, N, D) -> (B, N, 1)
        logits = self.classifier(encoded.last_hidden_state)

        if labels is None:
            loss = None
        else:
            loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())

        return {"loss": loss, "logits": logits}


In [7]:
# Wraps an LM and classifies its decoder hidden states into bits instead of
# letting it emit text. The wrapped model is called with `decoder_input_ids`
# and read through `decoder_hidden_states`, i.e. an encoder-decoder interface.
class BitDecoder(nn.Module):
    def __init__(self, decoder_model, decoder_pad_token, bit_len):
        """
        decoder_model:     LM whose `config.hidden_size` sizes the bit head
        decoder_pad_token: token id used to fill the fixed decoder input row
        bit_len:           number of bits (decoder positions) to classify
        """
        super().__init__()
        self.decoder = decoder_model
        self.decoder_pad_token = decoder_pad_token
        self.bit_len = bit_len

        # Two-way classification per position
        head = nn.Linear(decoder_model.config.hidden_size, 2, dtype=dtype)
        self.bit_head = head.to(device)

        # Constant decoder input: a single row of `bit_len` pad tokens
        pad_row = torch.full((1, self.bit_len), self.decoder_pad_token, dtype=torch.long)
        self.decoder_input_ids = pad_row.to(device)

    def forward(self, **kwargs):
        """Run the wrapped LM on `kwargs` and return bit logits [B, bit_len, 2]."""
        lm_out = self.decoder(
            **kwargs,
            decoder_input_ids=self.decoder_input_ids,
            output_hidden_states=True,
            return_dict=True,
        )

        # Last decoder layer, restricted to the first `bit_len` positions
        final_hidden = lm_out.decoder_hidden_states[-1]      # [B, T, d]
        return self.bit_head(final_hidden[:, :self.bit_len, :])

    def decode(self, inputs):
        """Return the arg-max bit per position for the first batch element."""
        predicted = self(input_ids=inputs).argmax(dim=-1)
        return predicted[0, :self.bit_len]

In [8]:
# Generator checkpoint; other candidates are kept commented out for reference.
#GENERATOR_MODEL = "Qwen/Qwen3-0.6B"
#GENERATOR_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
#GENERATOR_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
GENERATOR_MODEL = "HuggingFaceTB/SmolLM2-135M-Instruct"
#DECODER_MODEL = "google/flan-t5-base"

generator_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
#decoder_tokenizer = T5Tokenizer.from_pretrained(DECODER_MODEL)

# Trainable generator
generator = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL, dtype=dtype)
generator = generator.to(device)

# Bit decoder sized to the generator's embedding width, one position per bit
decoder = BertEmbeddingDecoder(
    hidden_size=generator.config.hidden_size,
    sequence_length=SECRET_LENGTH,
)
#decoder = T5ForConditionalGeneration.from_pretrained(DECODER_MODEL, dtype=dtype).to(device)

# Second copy of the same checkpoint, kept frozen as the semantic reference
semantic_anchor = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL, dtype=dtype)
semantic_anchor = semantic_anchor.to(device)
semantic_anchor.requires_grad_(False)

generator.train()
decoder.train()
semantic_anchor.eval()


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576, padding_idx=2)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    (r

In [9]:
# Reprojection layer to map generator to decoder embeddings
#proj_layer = torch.nn.Linear(generator.config.hidden_size, decoder.config.hidden_size, device=device, dtype=dtype)

# Bit decoder will map decoder logits to binary secret
#bit_decoder = BitDecoder(decoder, decoder_tokenizer.pad_token_type_id, SECRET_LENGTH)

In [10]:
def build_generator_prompt(carrier, secret_str):
    """Render a single user turn containing the carrier text and the secret.

    Returns the chat-templated prompt as a string (tokenisation is left to
    the caller).
    """
    user_turn = {"role": "user", "content": f"{carrier}. Secret: {secret_str}"}
    # NOTE(review): no generation prompt is requested here. The logged stego
    # text starts with "assistant", which suggests the model emits the role
    # header itself — confirm whether add_generation_prompt=True is wanted.
    return generator_tokenizer.apply_chat_template([user_turn], tokenize=False)

In [11]:
def generator_forward(temperature=1.0, **kwargs):
    """Run one forward pass of the generator and return next-token statistics.

    temperature: positive softmax temperature applied to the last-position logits
    **kwargs:    forwarded to the generator (e.g. input_ids / inputs_embeds,
                 past_key_values)

    Returns (logits, probs, past_key_values) where logits and probs are the
    last-position values of shape [B, V].

    Raises ValueError if temperature is not strictly positive.
    """
    # Dividing by a non-positive temperature silently yields inf/NaN or an
    # inverted distribution, so reject it up front.
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    # `eos_token_id` is intentionally not passed: it is a generate() option and
    # a single forward pass has no notion of stopping early.
    out = generator(
        **kwargs,
        return_dict=True,
        use_cache=True,
    )

    logits = out.logits[:, -1, :]   # [B, V]
    probs = torch.softmax(logits / temperature, dim=-1)

    return logits, probs, out.past_key_values


In [12]:
def soft_embeddings(token_probs):
    """Mix the generator's input-embedding rows using `token_probs` as weights.

    The product is taken against the [V, d] embedding table, so the vocabulary
    axis of `token_probs` is replaced by the embedding dimension.
    """
    embedding_table = generator.get_input_embeddings().weight  # [V, d]
    return torch.matmul(token_probs, embedding_table)

In [13]:
@torch.no_grad()
def anchor_forward(anchor, inputs) -> torch.Tensor:
    """Run `anchor` without gradient tracking and return its leading logits.

    anchor: model called with input_ids / attention_mask
    inputs: object exposing `.input_ids` and `.attention_mask`

    Returns the logits of the first SECRET_LENGTH sequence positions.
    """
    # NOTE(review): these are the first positions of the *input* sequence, not
    # of a generated continuation — confirm that is the intended alignment.
    anchor_out = anchor(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        return_dict=True,
    )
    return anchor_out.logits[:, :SECRET_LENGTH, :]


In [14]:
def kl_semantic_loss(stego_logits, base_logits):
    """KL divergence of the stego distribution from the base distribution.

    Both inputs are raw logits over the last axis. The result is the
    element-wise KL sum divided by the size of the first (batch) axis.
    """
    stego_log_probs = F.log_softmax(stego_logits, dim=-1)
    base_probs = F.softmax(base_logits, dim=-1)

    # Sketch of a padding-aware variant (not enabled):
    # per_token = F.kl_div(stego_log_probs, base_probs, reduction="none").sum(dim=-1)  # [B, T]
    # return (per_token * attention_mask).sum() / attention_mask.sum()

    return F.kl_div(input=stego_log_probs, target=base_probs, reduction="batchmean")

In [15]:
def bit_loss(logits, bits):
    """
    Cross-entropy between per-position two-way logits and the target bits.

    logits: [B, bit_len, 2]
    bits:   [B, bit_len]  (0/1)

    Returns a scalar tensor: the mean loss over all B * bit_len positions.
    """
    # reshape() rather than view(): logits obtained by slicing or permuting
    # hidden states may be non-contiguous, and view() raises on those.
    flat_logits = logits.reshape(-1, 2)
    # cross_entropy expects integer class indices; casting also lets callers
    # pass 0/1 bits stored as bool or float.
    flat_bits = bits.reshape(-1).long()

    return F.cross_entropy(flat_logits, flat_bits)

In [16]:
def decoder_forward(soft_embeds, secret_bits):
    """Feed `soft_embeds` to the bit decoder with `secret_bits` as labels.

    Returns only the loss entry of the decoder's output dict.
    """
    #loss = bit_loss(logits, secret_bits)
    return decoder(inputs_embeds=soft_embeds, labels=secret_bits)['loss']

In [17]:
@torch.no_grad()
def encode(carrier, secret):
    """Greedily generate stego text for `carrier` with `secret` in the prompt.

    Generates SECRET_LENGTH + 1 new tokens and returns the decoded text of the
    last SECRET_LENGTH token ids, with special tokens stripped.
    """
    model_inputs = generator_tokenizer(
        build_generator_prompt(carrier, secret),
        return_tensors="pt",
    ).to(device)

    # +1 seems necessary somehow
    # NOTE(review): presumably the first generated token is a special token
    # that gets stripped on decoding — confirm.
    n_new_tokens = SECRET_LENGTH + 1

    ids = generator.generate(
        **model_inputs,
        max_new_tokens=n_new_tokens,
        eos_token_id=None,  # prevent early stop
        do_sample=False,
    )

    print(ids.shape)

    stego_ids = ids[0, -SECRET_LENGTH:]
    return generator_tokenizer.decode(stego_ids, skip_special_tokens=True)


In [18]:
@torch.no_grad()
def decode(stego_text):
    """Recover the secret bits from `stego_text`.

    The text is re-tokenised with the generator tokenizer, embedded with the
    generator's input embedding layer and classified per token by the bit
    decoder.

    Returns a 1-D long tensor of 0/1 predictions for the first batch element,
    at most SECRET_LENGTH entries long.
    """
    encoded_token_ids = generator_tokenizer(stego_text, return_tensors="pt").input_ids.to(device)

    # The decoder is built with SECRET_LENGTH position embeddings, and
    # re-tokenised text is not guaranteed to have the token count it was
    # generated with. Clip instead of failing inside the encoder.
    encoded_token_ids = encoded_token_ids[:, :SECRET_LENGTH]

    # (B, N, D) hard-token embeddings from the generator's embedding matrix
    encoded_embeds = generator.get_input_embeddings()(encoded_token_ids)

    #attention_mask = torch.ones(encoded_embeds.shape[:2])

    # The decoder is normally left in train() mode for the optimisation loop.
    # Run inference in eval() so dropout cannot randomise the predicted bits,
    # then restore whatever mode the caller had.
    was_training = decoder.training
    decoder.eval()
    try:
        out = decoder(
            inputs_embeds=encoded_embeds
            #attention_mask=attention_mask,
        )
    finally:
        decoder.train(was_training)

    logits = out["logits"]      # (B, N, 1)
    # A positive logit corresponds to sigmoid > 0.5, i.e. bit = 1
    preds = (logits.squeeze(-1) > 0).long()

    return preds[0]


In [19]:
def bits_to_str(bits):
    """Concatenate the elements of a 1-D tensor into a string such as "1011"."""
    return "".join(map(str, bits.tolist()))

In [20]:
# Generator and bit decoder are optimised jointly by a single AdamW instance.
optimizer = torch.optim.AdamW(
    list(generator.parameters()) + 
    list(decoder.parameters()),
    #list(proj_layer.parameters()) + 
    #list(bit_decoder.parameters()),
    lr=3e-5
)

# Each step trains on one freshly sampled secret with a fixed carrier prompt.
for step in range(1000):
    carrier_prompt = "Explain how neural networks learn."
    secret_bits = torch.randint(0, 2, (1, SECRET_LENGTH))   # [ B, N ]
    secret_str = bits_to_str(secret_bits.flatten())

    prompt = build_generator_prompt(carrier_prompt, secret_str)
    generator_inputs = generator_tokenizer(prompt, return_tensors="pt").to(device)
    # No cache before the first forward pass; generator_forward hands back the
    # key/value cache, which is fed in again on the next iteration.
    past_kv = None

    # Generate N tokens, store embeddings for each token
    generated_embeds = []
    for t in range(SECRET_LENGTH):
        # t == 0 feeds the tokenised prompt; every later step feeds only the
        # previous soft embedding and relies on the cache for the history.
        stego_logits, token_probs, past_kv = generator_forward(
            input_ids=generator_inputs.input_ids if t == 0 else None,
            inputs_embeds=soft_embeds.unsqueeze(1) if t > 0 else None,
            past_key_values=past_kv
        )

        # Probability-weighted mix of embedding rows instead of a sampled token
        soft_embeds = soft_embeddings(token_probs)  # [ B, D ]
        generated_embeds.append(soft_embeds)

    # Stack embeddings for each token to [ B, N, D ]
    encoder_outputs_embeds = torch.stack(generated_embeds, dim=1).to(device)

    # Linear projection to decoder embedding space
    #decoder_inputs = proj_layer(soft_embeds)     # [B, T, d_dec]

    # Calculate recovery loss based on decoding success
    recovery_loss = decoder_forward(encoder_outputs_embeds, secret_bits.to(device))

    # Calculate semantic loss compared to base model
    # (currently disabled; the zero array below only keeps the log line working)
    #base_logits = anchor_forward(semantic_anchor, generator_inputs)
    #semantic_loss = kl_semantic_loss(stego_logits, base_logits)
    semantic_loss = np.zeros(1)

    # Combine loss
    loss = recovery_loss
    #loss = 1.0 * recovery_loss + 0.0 * semantic_loss

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

    # `step % 1` is always 0, so this logs every step; raise the modulus to
    # log less often.
    if step % 1 == 0:
        # Round trip with discrete text: encode() generates tokens and decode()
        # re-tokenises the text before running the bit decoder.
        stego_text = encode(carrier_prompt, secret_str)
        decoded_secret_bits = decode(stego_text)
        decoded_secret_str = bits_to_str(decoded_secret_bits)

        print(f"Step {step} | L_sem: {semantic_loss.item():.3f} | L_rec: {recovery_loss.item():.3f}")
        print(f'Secret:  {secret_str}')
        print(f'Decoded: {decoded_secret_str}')
        print(f'Stego:   {stego_text}')
        print()


torch.Size([1, 68])
Step 0 | L_sem: 0.000 | L_rec: 0.770
Secret:  1000001100110111
Decoded: 1101110011101000
Stego:   assistant
Neural networks learn by iteratively updating the weights and biases of

torch.Size([1, 68])
Step 1 | L_sem: 0.000 | L_rec: 0.747
Secret:  0000001000101110
Decoded: 0100010011101000
Stego:   assistant
Neural networks learn by iteratively updating the weights and biases of

torch.Size([1, 68])
Step 2 | L_sem: 0.000 | L_rec: 0.589
Secret:  0000101001001101
Decoded: 0000010001100000
Stego:   assistant
Neural networks learn by iteratively updating the weights and biases of

torch.Size([1, 68])
Step 3 | L_sem: 0.000 | L_rec: 0.723
Secret:  1010111011001111
Decoded: 0000010001100000
Stego:   assistant
Neural networks learn by iteratively updating the weights and biases of

torch.Size([1, 68])
Step 4 | L_sem: 0.000 | L_rec: 0.737
Secret:  0011101111000101
Decoded: 0000010011101000
Stego:   assistant
Neural networks learn by iteratively updating the weights and biases

KeyboardInterrupt: 

In [None]:
# End-to-end round trip on one fixed carrier/secret pair
carrier = "Write a short paragraph explaining how neural networks learn."
secret = "1011001110001111"

stego_text = encode(carrier, secret)
recovered_secret = decode(stego_text)

# Header, text and a trailing blank line, each on its own line
print("Stego text:", stego_text, "", sep="\n")
print("Original secret:", secret)
print("Recovered secret:", bits_to_str(recovered_secret))
