<a href="https://colab.research.google.com/github/Mert-Keskin/Text-Steganography-Methods-Based-on-BERT-Character-Based-Fixed-Bit-and-Huffman-Coding-Approaches/blob/main/Text_Steganography_Methods_Based_on_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook is part of my undergraduate thesis, completed at Istanbul Medeniyet University, Department of Computer Engineering. The thesis provides a detailed explanation of the methodology and results.

In [None]:
import torch
import random
from transformers import BertTokenizer, BertForMaskedLM
import hashlib
import heapq
import torch.nn.functional as F

# Character_Based_class

In [None]:
class Character_Based_Stego:
    """Character-based text steganography driven by BERT masked-word predictions.

    Stego text layout, as implemented below:

    * Header -- ``HEADER_BITS`` (12-bit seed, 12-bit character count, 6-bit
      loop index) cut into 8-bit blocks.  Block ``k`` replaces the word at
      index ``(HALF_WINDOW + 1) * (k + 1)`` with the BERT prediction whose
      rank equals the block value.
    * Message -- candidate word positions come from ``random_hash(i, seed,
      n_words)``.  A position is rejected when it lies in the reserved header
      region or within ``HALF_WINDOW`` words of a position already used.  At
      an accepted position the word is replaced by the highest-ranked
      prediction (rank >= 1) whose character at ``current_loop_index`` equals
      the next secret character; when none exists the rank-``SKIP_INDEX``
      prediction is written as a "skip" marker.

    Limits visible in the code: the seed and the character count are written
    as 12-bit fields and the loop index as a 6-bit field, so larger values
    produce a header the decoder cannot parse correctly.
    """

    def __init__(self):
        """Load ``bert-base-cased`` and initialise the scheme parameters."""
        # === Load BERT model ===
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-cased")
        self.model.eval()

        # === Configurations ===
        self.HEADER_BITS = 12 + 12 + 6  # seed + character count + loop index
        self.BLOCK_SIZE = 8             # bit width used by char_to_bin
        self.PREDICTION_SIZE = 257      # number of top predictions requested
        self.HALF_WINDOW = 5            # context words kept on each side of the mask
        self.SKIP_INDEX = 0             # rank of the prediction that marks a skipped slot
        self.MAX_LOOP_INDEX = 3         # character position cycles modulo this value
        self.skip_counter = 0           # skipped slots in the last hide_message call
        self.preds_index_counter = 0    # sum of chosen ranks in the last hide_message call

    def random_hash(self, i, seed, mod):
        """Return SHA-256 of ``"{seed}_{i}"`` interpreted as an integer, modulo ``mod``."""
        input_str = f"{seed}_{i}".encode()
        hash_bytes = hashlib.sha256(input_str).digest()
        hash_int = int.from_bytes(hash_bytes, 'big')
        return hash_int % mod

    def char_to_bin(self, c):
        """Return the code point of ``c`` as a binary string padded to ``BLOCK_SIZE`` digits."""
        return format(ord(c), f"0{self.BLOCK_SIZE}b")

    def bin_to_char(self, b):
        """Return the character whose code point is the binary string ``b``."""
        return chr(int(b, 2))

    def int_to_bin(self, n, bits):
        """Return ``n`` as a binary string zero-padded to at least ``bits`` digits."""
        return format(n, f"0{bits}b")

    def bin_to_int(self, b):
        """Parse the binary string ``b`` into an integer."""
        return int(b, 2)

    def get_prediction_list(self, masked_text):
        """Return the ``PREDICTION_SIZE`` most likely tokens for the ``[MASK]`` in ``masked_text``.

        The list is ordered from most to least likely.  Raises ``ValueError``
        (from ``list.index``) when the text contains no ``[MASK]`` token.
        """
        tokens = self.tokenizer.tokenize(masked_text)
        indexed = self.tokenizer.convert_tokens_to_ids(tokens)
        mask_index = tokens.index('[MASK]')

        # NOTE(review): the input is fed without [CLS]/[SEP] tokens; encoder and
        # decoder do the same, so they agree -- confirm this is intentional.
        with torch.no_grad():
            input_ids = torch.tensor([indexed])
            outputs = self.model(input_ids)
            predictions = outputs.logits[0, mask_index]
            top_preds = torch.topk(predictions, self.PREDICTION_SIZE).indices.tolist()
            return self.tokenizer.convert_ids_to_tokens(top_preds)

    def create_window(self, words, index):
        """Return ``(window, mask_pos)``: a copy of the words around ``index`` and its offset in it."""
        start = max(0, index - self.HALF_WINDOW)
        end = min(len(words), index + self.HALF_WINDOW + 1)
        return words[start:end], index - start

    def encode_header(self, seed, char_count, loop_index):
        """Return the header bit string: 12-bit seed, 12-bit count, 6-bit loop index."""
        return self.int_to_bin(seed, 12) + self.int_to_bin(char_count, 12) + self.int_to_bin(loop_index, 6)

    def decode_header(self, bin_data):
        """Split the first 30 bits of ``bin_data`` into ``(seed, count, loop_index)``."""
        seed = self.bin_to_int(bin_data[0:12])
        count = self.bin_to_int(bin_data[12:24])
        loop_index = self.bin_to_int(bin_data[24:30])
        return seed, count, loop_index

    def _next_free_index(self, i, seed, n_words, reserved_range_end, used_indices):
        """Advance the hash counter until it yields a usable word position.

        A position is usable when it is past ``reserved_range_end`` and more
        than ``HALF_WINDOW`` words away from every entry of ``used_indices``.

        Returns ``(idx, i)`` where ``i`` is the counter value that produced ``idx``.
        Raises ``ValueError`` when the cover text cannot provide such a position
        (previously this situation looped forever).
        """
        if n_words <= reserved_range_end + 1:
            raise ValueError(
                f"cover text too short: {n_words} words, but positions up to "
                f"{reserved_range_end} are reserved for the header"
            )
        # The hash is spread over all n_words positions, so this many consecutive
        # rejections means no free position is left.
        max_rejections = 64 * n_words
        rejected = 0
        while True:
            idx = self.random_hash(i, seed, n_words)
            blocked = idx <= reserved_range_end or any(
                abs(idx - u) <= self.HALF_WINDOW for u in used_indices
            )
            if not blocked:
                return idx, i
            i += 1
            rejected += 1
            if rejected > max_rejections:
                raise ValueError("cover text too short: no free word position left for the message")

    def hide_bits_in_text(self, words, binary_data, start_idx):
        """Write ``binary_data`` into a copy of ``words``, 8 bits per replaced word.

        Block ``k`` goes to index ``start_idx + k * (HALF_WINDOW + 1)``; the last
        block is right-padded with zeros.  Writing stops silently when the index
        runs past the end of ``words``.  Returns the modified copy.
        """
        stego_words = words[:]
        for i in range(0, len(binary_data), 8):
            block = binary_data[i:i+8].ljust(8, '0')
            index = start_idx + (i // 8) * (self.HALF_WINDOW + 1)
            if index >= len(words):
                break
            window, mask_pos = self.create_window(stego_words, index)
            window[mask_pos] = '[MASK]'
            masked_text = ' '.join(window)
            preds = self.get_prediction_list(masked_text)
            pred_index = int(block, 2)
            chosen_word = preds[pred_index] if pred_index < len(preds) else preds[self.SKIP_INDEX]
            print(index)
            stego_words[index] = chosen_word
            print("header words: {}".format(chosen_word))
        print(stego_words)
        return stego_words

    def hide_message(self, cover_text, secret_msg, seed, loop_index):
        """Embed ``secret_msg`` into ``cover_text``.

        Returns ``(stego_text, changed_words, skipped_words)`` where
        ``changed_words`` counts the header words plus the message words that
        carry a character, and ``skipped_words`` counts the skip markers.

        Raises ``ValueError`` when the cover text has no room for the message.
        """
        words = cover_text.split()
        header_bin = self.encode_header(seed, len(secret_msg), loop_index)

        header_block_count = (len(header_bin) + 7) // 8
        reserved_range_end = self.HALF_WINDOW + header_block_count * (self.HALF_WINDOW + 1)
        print("reversed: {}".format(reserved_range_end))

        print("[ENCODE] Header binary:", header_bin)
        print(f"[ENCODE] seed={seed}, char_count={len(secret_msg)}, loop_index={loop_index}")

        stego = self.hide_bits_in_text(words, header_bin, start_idx=self.HALF_WINDOW + 1)

        current_loop_index = loop_index
        used_indices = set()
        hidden = 0
        i = 0
        self.preds_index_counter = 0
        self.skip_counter = 0

        while hidden < len(secret_msg):
            idx, i = self._next_free_index(i, seed, len(words), reserved_range_end, used_indices)

            window, mask_pos = self.create_window(stego, idx)
            window[mask_pos] = '[MASK]'
            masked_text = ' '.join(window)
            preds = self.get_prediction_list(masked_text)
            letter = secret_msg[hidden]
            print(f"\n🔐 Embedding character '{letter}' at index {idx}")
            print(f"Top predictions: {preds[:5]}")

            # Rank 0 is reserved as the skip marker, so the search starts at rank 1.
            match_rank = None
            for j in range(1, len(preds)):
                candidate = preds[j]
                if len(candidate) > current_loop_index and candidate[current_loop_index] == letter:
                    match_rank = j
                    break

            if match_rank is not None:
                self.preds_index_counter += match_rank
                print(f"✅ Match at position {match_rank}: '{preds[match_rank]}'")
                stego[idx] = preds[match_rank]
                hidden += 1
            else:
                self.skip_counter += 1
                print("skip index chosen: {}".format(preds[self.SKIP_INDEX]))
                stego[idx] = preds[self.SKIP_INDEX]

            i += 1
            used_indices.add(idx)
            current_loop_index = (current_loop_index + 1) % self.MAX_LOOP_INDEX

        # Average rank over the words that actually carry a character; an empty
        # message has none, which used to raise ZeroDivisionError here.
        embedded = len(used_indices) - self.skip_counter
        average_rank = self.preds_index_counter / embedded if embedded > 0 else 0.0

        print("total travel words: {}".format(len(used_indices)))
        print("total travel words with header: {}".format(len(used_indices) + header_block_count))
        print("skipped: {}".format(self.skip_counter))
        print("avarage {}".format(average_rank))
        return ' '.join(stego), len(used_indices) + header_block_count - self.skip_counter, self.skip_counter

    def extract_bits_from_text(self, words, bit_count, start_idx):
        """Read ``bit_count`` bits back from ``words``, 8 bits per word.

        Each word's rank in the prediction list for its masked window is the
        block value.  A word that is not in the list (or whose rank does not
        fit in 8 bits, which the encoder never writes) decodes as zero.
        """
        bits = ''
        for i in range(0, bit_count, 8):
            index = start_idx + (i // 8) * (self.HALF_WINDOW + 1)
            if index >= len(words):
                break
            window, mask_pos = self.create_window(words, index)
            original_word = words[index]
            window[mask_pos] = '[MASK]'
            masked_text = ' '.join(window)
            preds = self.get_prediction_list(masked_text)
            try:
                rank = preds.index(original_word)
            except ValueError:
                print(f"  ⚠️ '{original_word}' NOT found in predictions.")
                rank = 0
            if rank > 255:
                # Keep every block exactly 8 bits wide so later blocks stay aligned.
                rank = 0
            bits += self.int_to_bin(rank, 8)
        return bits

    def extract_message(self, stego_text):
        """Recover the hidden message from ``stego_text`` and return it as a string.

        Raises ``ValueError`` when the decoded header asks for more characters
        than the text has free positions for.
        """
        words = stego_text.split()
        header_bits = self.extract_bits_from_text(words, self.HEADER_BITS, start_idx=self.HALF_WINDOW + 1)
        print("[DECODE] Extracted header binary:", header_bits)

        seed, count, loop_index = self.decode_header(header_bits)
        print(f"[DECODE] Decoded header: seed={seed}, char_count={count}, loop_index={loop_index}")

        header_block_count = (self.HEADER_BITS + 7) // 8
        reserved_range_end = self.HALF_WINDOW + header_block_count * (self.HALF_WINDOW + 1)

        extracted = ''
        current_loop_index = loop_index
        used_indices = set()
        i = 0
        found = 0

        while found < count:
            idx, i = self._next_free_index(i, seed, len(words), reserved_range_end, used_indices)

            window, mask_pos = self.create_window(words, idx)
            original_word = words[idx]
            window[mask_pos] = '[MASK]'
            masked_text = ' '.join(window)
            preds = self.get_prediction_list(masked_text)
            print("-------------------")
            print(masked_text)
            print(f"Top predictions: {preds[:5]}")
            print("-------------------")

            if original_word == preds[self.SKIP_INDEX]:
                print("came across a skip word: {}".format(preds[self.SKIP_INDEX]))
            elif len(original_word) > current_loop_index:
                extracted += original_word[current_loop_index]
                print("✅ original word: {}  letter: {}".format(original_word, original_word[current_loop_index]))
                found += 1

            i += 1
            used_indices.add(idx)
            current_loop_index = (current_loop_index + 1) % self.MAX_LOOP_INDEX

        # skip_counter / preds_index_counter are the values left by the last
        # hide_message call on this instance; the division is guarded because
        # the denominator can be zero (e.g. count == 0).
        embedded = len(used_indices) - self.skip_counter
        average_rank = self.preds_index_counter / embedded if embedded > 0 else 0.0

        print("total traveled words: {}".format(len(used_indices)))
        print("total traveled words with header: {}".format(len(used_indices) + header_block_count))
        print("skipped: {}".format(self.skip_counter))
        print("avarage {}".format(average_rank))
        return extracted


# Flexible Fixed Bit

In [None]:
class Fixed_Bit_Stego:
    """Fixed-bit text steganography: every replaced word carries ``bit_width`` bits.

    A word is replaced by the BERT prediction whose rank (0 .. 2**bit_width - 1)
    equals the value of the next ``bit_width`` bits.  The 24-bit header (12-bit
    seed, 12-bit message bit length) is stored at word indices
    ``(HALF_WINDOW + 1) * (k + 1)``; message positions come from
    ``squarehash(i, seed, n_words)``, skipping the reserved header region and
    anything within ``HALF_WINDOW`` words of a position already used.

    Limits visible in the code: seed and bit length are written as 12-bit
    fields, so a message longer than 4095 bits cannot be described by the header.
    """

    def __init__(self, bit_width=2):
        """Load ``bert-base-cased`` and store the number of bits carried per word."""
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-cased")
        self.model.eval()

        self.BLOCK_SIZE = 8
        self.HEADER_BITS = 24        # 12-bit seed + 12-bit message bit length
        self.HALF_WINDOW = 5         # context words kept on each side of the mask
        self.bit_width = bit_width
        self.message_changes = 0     # message words written by the last hide_message call
        self.total_changes = 0       # header_changes + message_changes
        self.header_changes = 0      # header words written by the last hide_message call

    def squarehash(self, i, seed, mod):
        """Return SHA-256 of ``"{seed}_{i}"`` interpreted as an integer, modulo ``mod``."""
        input_str = f"{seed}_{i}".encode()
        hash_bytes = hashlib.sha256(input_str).digest()
        hash_int = int.from_bytes(hash_bytes, 'big')
        return hash_int % mod

    def int_to_bin(self, n, bits):
        """Return ``n`` as a binary string zero-padded to at least ``bits`` digits."""
        return format(n, f"0{bits}b")

    def bin_to_int(self, b):
        """Parse the binary string ``b`` into an integer."""
        return int(b, 2)

    def text_to_bits(self, text):
        """Concatenate the code point of every character, each formatted with ``'08b'``."""
        return ''.join(format(ord(c), '08b') for c in text)

    def bits_to_text(self, bits):
        """Decode ``bits`` in 8-bit chunks into characters (inverse of ``text_to_bits``)."""
        chars = [chr(int(bits[i:i+8], 2)) for i in range(0, len(bits), 8)]
        return ''.join(chars)

    def get_predictions(self, text):
        """Return the ``2**bit_width`` most likely tokens for the ``[MASK]`` in ``text``."""
        tokens = self.tokenizer.tokenize(text)
        indexed = self.tokenizer.convert_tokens_to_ids(tokens)
        mask_index = tokens.index('[MASK]')
        with torch.no_grad():
            input_ids = torch.tensor([indexed])
            outputs = self.model(input_ids)
            logits = outputs.logits[0, mask_index]
            top_preds = torch.topk(logits, 2**self.bit_width).indices.tolist()
            return self.tokenizer.convert_ids_to_tokens(top_preds)

    def create_window(self, words, idx):
        """Return ``(window, mask_pos)``: a copy of the words around ``idx`` and its offset in it."""
        start = max(0, idx - self.HALF_WINDOW)
        end = min(len(words), idx + self.HALF_WINDOW + 1)
        return words[start:end], idx - start

    def encode_header(self, seed, bit_len):
        """Return the header bit string: 12-bit seed followed by 12-bit bit length."""
        return self.int_to_bin(seed, 12) + self.int_to_bin(bit_len, 12)

    def decode_header(self, bits):
        """Split the first 24 bits into ``(seed, bit_len)``."""
        return self.bin_to_int(bits[:12]), self.bin_to_int(bits[12:24])

    def _header_block_count(self):
        """Number of words the header occupies.

        Rounded up, because ``hide_header`` writes a final, zero-padded block
        when ``bit_width`` does not divide ``HEADER_BITS``.
        """
        return (self.HEADER_BITS + self.bit_width - 1) // self.bit_width

    def _reserved_range_end(self):
        """Last word index that message positions must stay clear of."""
        return self.HALF_WINDOW + self._header_block_count() * (self.HALF_WINDOW + 1)

    def _next_free_index(self, i, seed, n_words, reserved_range_end, used):
        """Advance the hash counter until it yields a usable word position.

        Returns ``(idx, i)`` where ``i`` is the counter value that produced
        ``idx``.  Raises ``ValueError`` when the text cannot provide a position
        (previously this situation looped forever).
        """
        if n_words <= reserved_range_end + 1:
            raise ValueError(
                f"cover text too short: {n_words} words, but positions up to "
                f"{reserved_range_end} are reserved for the header"
            )
        # The hash is spread over all n_words positions, so this many consecutive
        # rejections means no free position is left.
        max_rejections = 64 * n_words
        rejected = 0
        while True:
            idx = self.squarehash(i, seed, n_words)
            blocked = idx <= reserved_range_end or any(
                abs(idx - u) <= self.HALF_WINDOW for u in used
            )
            if not blocked:
                return idx, i
            i += 1
            rejected += 1
            if rejected > max_rejections:
                raise ValueError("cover text too short: no free word position left for the message")

    def hide_header(self, words, seed, bit_len):
        """Write the header into ``words`` in place.

        Returns ``(words, changed)`` where ``changed`` is the number of header
        words written.  Raises ``IndexError`` when ``words`` is too short to
        hold the header.
        """
        header_bin = self.encode_header(seed, bit_len)
        changed = 0
        for i in range(0, len(header_bin), self.bit_width):
            block = header_bin[i:i+self.bit_width].ljust(self.bit_width, '0')
            idx = self.HALF_WINDOW + 1 + (i // self.bit_width) * (self.HALF_WINDOW + 1)
            print(f"[HEADER] Index: {idx}, Bits: {block}")
            window, mask_pos = self.create_window(words, idx)
            window[mask_pos] = '[MASK]'
            preds = self.get_predictions(' '.join(window))
            value = int(block, 2)
            selected = preds[value] if value < len(preds) else preds[0]
            print(f"  Prediction list: {preds}")
            print(f"  Selected word: {selected}")
            words[idx] = selected
            changed += 1
        print(f"[HEADER] Total header words changed: {changed}")
        return words, changed

    def hide_message(self, cover_text, secret_msg, seed):
        """Embed ``secret_msg`` into ``cover_text``.

        Returns ``(stego_text, total_changes)`` where ``total_changes`` counts
        header words plus message words written by this call.

        Raises ``ValueError`` when the cover text has no room for the message.
        """
        words = cover_text.split()
        bits = self.text_to_bits(secret_msg)
        words, self.header_changes = self.hide_header(words, seed, len(bits))
        used = set()

        reserved_range_end = self._reserved_range_end()
        print("[INFO] Reserved header region ends at:", reserved_range_end)

        # Count only this call's words; the counter used to keep growing across
        # repeated hide_message calls on the same instance.
        self.message_changes = 0

        i = 0
        hidden = 0
        while hidden < len(bits):
            idx, i = self._next_free_index(i, seed, len(words), reserved_range_end, used)

            block = bits[hidden:hidden+self.bit_width].ljust(self.bit_width, '0')
            window, mask_pos = self.create_window(words, idx)
            window[mask_pos] = '[MASK]'
            preds = self.get_predictions(' '.join(window))
            val = int(block, 2)
            selected = preds[val] if val < len(preds) else preds[0]
            print(f"[MESSAGE] Index: {idx}, Bits: {block}")
            print(f"  Prediction list: {preds}")
            print(f"  Selected word: {selected}")
            words[idx] = selected
            hidden += self.bit_width
            used.add(idx)
            self.message_changes += 1
            i += 1
        self.total_changes = self.header_changes + self.message_changes
        print(f"[INFO] Total message words changed: {self.message_changes}")
        print(f"[INFO] Total words changed: {self.total_changes}")
        return ' '.join(words), self.total_changes

    def extract_header(self, words):
        """Read the header words back and return ``(seed, bit_len)``.

        A header word missing from its prediction list decodes as zero bits.
        """
        bits = ''
        for i in range(self._header_block_count()):
            idx = self.HALF_WINDOW + 1 + i * (self.HALF_WINDOW + 1)
            window, mask_pos = self.create_window(words, idx)
            original = words[idx]
            window[mask_pos] = '[MASK]'
            preds = self.get_predictions(' '.join(window))
            if original in preds:
                bits += self.int_to_bin(preds.index(original), self.bit_width)
            else:
                bits += '0' * self.bit_width
        return self.decode_header(bits)

    def extract_message(self, stego_text):
        """Recover the hidden message from ``stego_text`` and return it as a string.

        Raises ``ValueError`` when the decoded header asks for more bits than
        the text has free positions for.
        """
        words = stego_text.split()
        seed, bit_len = self.extract_header(words)
        print(seed)
        print(bit_len)
        bits = ''

        reserved_range_end = self._reserved_range_end()
        print("reserved range: {}".format(reserved_range_end))

        i = 0
        used = set()
        while len(bits) < bit_len:
            idx, i = self._next_free_index(i, seed, len(words), reserved_range_end, used)

            window, mask_pos = self.create_window(words, idx)
            original = words[idx]
            window[mask_pos] = '[MASK]'
            preds = self.get_predictions(' '.join(window))
            if original in preds:
                bits += self.int_to_bin(preds.index(original), self.bit_width)
            else:
                print("Error!")
                bits += '0' * self.bit_width
            used.add(idx)
            i += 1
        # These counters are the values left by the last hide_message call on
        # this instance; decoding itself does not change them.
        print(f"[INFO] Total message words changed: {self.message_changes}")
        print(f"[INFO] Total words changed: {self.total_changes}")
        return self.bits_to_text(bits[:bit_len])


# Huffman with class

In [None]:
class Huffman_Stego:
    """Huffman-coded text steganography on top of BERT masked-word predictions.

    For every carrier position the ``TOP_K`` most likely replacement words are
    given Huffman codes built from their probabilities; the word whose code is
    a prefix of the remaining bit stream replaces the original word.  The
    24-bit header (12-bit seed, 12-bit message bit length) is written first at
    word indices ``(HALF_WINDOW + 1) * (k + 1)``; message positions come from
    ``random_hash(i, seed, n_words)``.

    Reserved header region: the header takes a variable number of words, so
    message positions are kept clear of the slots the header actually used
    (tracked in ``header_slots``).  When the header needs at most 11 words the
    threshold equals the original fixed value ``HEADER_BITS // 2 *
    (HALF_WINDOW + 1)``, so such stego texts are laid out exactly as before;
    for longer headers the original layout let message words overwrite header
    words.
    """

    def __init__(self):
        """Load ``bert-base-cased`` and initialise the scheme parameters."""
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-cased")
        self.model.eval()

        self.BLOCK_SIZE = 8
        self.HEADER_BITS = 24                   # 12-bit seed + 12-bit message bit length
        self.HALF_WINDOW = 5                    # context words kept on each side of the mask
        self.STRIDE = 2 * self.HALF_WINDOW + 1
        self.TOP_K = 16                         # candidate words per carrier position
        self.changes = 0                        # header words written by the last hide_header call
        self.header_changes = 0
        self.message_changes = 0
        self.total_changes = 0
        self.header_slots = 0                   # header positions visited by the last header pass

    def random_hash(self, i, seed, mod):
        """Return SHA-256 of ``"{seed}_{i}"`` interpreted as an integer, modulo ``mod``."""
        input_str = f"{seed}_{i}".encode()
        hash_bytes = hashlib.sha256(input_str).digest()
        hash_int = int.from_bytes(hash_bytes, 'big')
        return hash_int % mod

    def int_to_bin(self, n, bits):
        """Return ``n`` as a binary string zero-padded to at least ``bits`` digits."""
        return format(n, f"0{bits}b")

    def bin_to_int(self, b):
        """Parse the binary string ``b`` into an integer."""
        return int(b, 2)

    def text_to_bits(self, text):
        """Concatenate the code point of every character, each formatted with ``'08b'``."""
        return ''.join(format(ord(c), '08b') for c in text)

    def bits_to_text(self, bits):
        """Decode ``bits`` in 8-bit chunks into characters (inverse of ``text_to_bits``)."""
        chars = [chr(int(bits[i:i+8], 2)) for i in range(0, len(bits), 8)]
        return ''.join(chars)

    def get_predictions(self, text):
        """Return ``(tokens, probabilities)`` of the ``TOP_K`` best fillers for ``[MASK]``.

        Probabilities are the softmax over the logits at the mask position.
        """
        tokens = self.tokenizer.tokenize(text)
        indexed = self.tokenizer.convert_tokens_to_ids(tokens)
        mask_index = tokens.index('[MASK]')
        with torch.no_grad():
            input_ids = torch.tensor([indexed])
            outputs = self.model(input_ids)
            logits = outputs.logits[0, mask_index]
            probs = F.softmax(logits, dim=0)
            top_probs, top_indices = torch.topk(probs, self.TOP_K)
            return self.tokenizer.convert_ids_to_tokens(top_indices.tolist()), top_probs.tolist()

    def create_window(self, words, idx):
        """Return ``(window, mask_pos)``: a copy of the words around ``idx`` and its offset in it."""
        start = max(0, idx - self.HALF_WINDOW)
        end = min(len(words), idx + self.HALF_WINDOW + 1)
        return words[start:end], idx - start

    class Node:
        """Huffman tree node; leaves carry a word, internal nodes have ``word=None``."""

        def __init__(self, word, prob):
            self.word = word
            self.prob = prob
            self.left = None
            self.right = None

        def __lt__(self, other):
            # Ordering by probability is all heapq needs.
            return self.prob < other.prob

    def build_huffman(self, words, probs):
        """Build a Huffman code and return it as ``{word: bit string}``.

        The construction must stay deterministic for identical inputs, because
        encoder and decoder each rebuild the code independently.
        """
        heap = [self.Node(w, p) for w, p in zip(words, probs)]
        heapq.heapify(heap)
        while len(heap) > 1:
            n1 = heapq.heappop(heap)
            n2 = heapq.heappop(heap)
            merged = self.Node(None, n1.prob + n2.prob)
            merged.left = n1
            merged.right = n2
            heapq.heappush(heap, merged)
        root = heap[0]
        encoding = {}

        def traverse(node, path):
            if node.word is not None:
                encoding[node.word] = path
            if node.left:
                traverse(node.left, path + '0')
            if node.right:
                traverse(node.right, path + '1')

        traverse(root, '')
        return encoding

    def encode_header(self, seed, bit_len):
        """Return the header bit string: 12-bit seed followed by 12-bit bit length."""
        return self.int_to_bin(seed, 12) + self.int_to_bin(bit_len, 12)

    def decode_header(self, bits):
        """Split the first 24 bits into ``(seed, bit_len)``."""
        return self.bin_to_int(bits[:12]), self.bin_to_int(bits[12:24])

    def _select_word(self, huff, bitstream, offset):
        """Pick the word whose code is a prefix of ``bitstream`` starting at ``offset``.

        The stream is treated as if followed by zeros, long enough for the
        longest code, so the end of the stream can always be encoded.  (Padding
        four bits at a time while scanning the codes could skip the only
        matching code and leave the slot unwritten.)

        Returns ``(word, code)``, or ``None`` when no code matches.
        """
        if not huff:
            return None
        longest = max(len(code) for code in huff.values())
        view = bitstream[offset:offset + longest].ljust(longest, '0')
        for word, code in huff.items():
            if view.startswith(code):
                if len(code) > len(bitstream) - offset:
                    print("Came to an end and padded.")
                return word, code
        return None

    def _reserved_limit(self):
        """First word index that message positions may use.

        It is never lower than the original fixed threshold and always more
        than ``HALF_WINDOW`` words past the last header slot.
        """
        legacy_limit = self.HEADER_BITS // 2 * (self.HALF_WINDOW + 1)
        actual_limit = (self.header_slots + 1) * (self.HALF_WINDOW + 1)
        return max(legacy_limit, actual_limit)

    def _next_free_index(self, i, seed, n_words, reserved_limit, used):
        """Advance the hash counter until it yields a usable word position.

        Returns ``(idx, i)`` where ``i`` is the counter value that produced
        ``idx``.  Raises ``ValueError`` when the text cannot provide a position
        (previously this situation looped forever).
        """
        if n_words <= reserved_limit:
            raise ValueError(
                f"cover text too short: {n_words} words, but positions below "
                f"{reserved_limit} are reserved for the header"
            )
        # The hash is spread over all n_words positions, so this many consecutive
        # rejections means no free position is left.
        max_rejections = 64 * n_words
        rejected = 0
        while True:
            idx = self.random_hash(i, seed, n_words)
            blocked = idx < reserved_limit or any(
                abs(idx - u) <= self.HALF_WINDOW for u in used
            )
            if not blocked:
                return idx, i
            i += 1
            rejected += 1
            if rejected > max_rejections:
                raise ValueError("cover text too short: no free word position left for the message")

    def hide_header(self, words, seed, bit_len):
        """Write the header into ``words`` in place.

        Returns ``(words, changed)`` where ``changed`` is the number of header
        words written.  Writing stops early when the next header slot lies past
        the end of ``words``.
        """
        header_bin = self.encode_header(seed, bit_len)
        print("Bits: {}".format(header_bin))
        hidden = 0
        i = 0
        self.changes = 0
        len_header_bin = len(header_bin)
        while hidden < len_header_bin:
            idx = self.HALF_WINDOW+1 + (i * (self.HALF_WINDOW+1))
            if idx >= len(words):
                break

            window, mask_pos = self.create_window(words, idx)
            window[mask_pos] = '[MASK]'
            preds, probs = self.get_predictions(' '.join(window))
            huff = self.build_huffman(preds, probs)

            print(f"\n🔐 Hiding Header Block at index {idx}")
            print(f"Header bits remaining: {header_bin[hidden:hidden+10]}")
            print(f"Huffman codes: {huff}")

            selection = self._select_word(huff, header_bin, hidden)
            if selection is None:
                print(f"⚠️ No exact match.")
            else:
                w, code = selection
                words[idx] = w
                print(f"✅ Selected word '{w}' for code '{code}'")
                hidden += len(code)
                self.changes += 1
            i += 1

        # Remember how many slots the header occupies so the message stays clear of them.
        self.header_slots = i
        print(f"[INFO] Total header words changed: {self.changes}")
        return words, self.changes

    def hide_message(self, cover_text, secret_msg, seed):
        """Embed ``secret_msg`` into ``cover_text``.

        Returns ``(stego_text, total_changes)`` where ``total_changes`` counts
        header words plus message words written by this call.

        Raises ``ValueError`` when the cover text has no room for the message.
        """
        words = cover_text.split()
        bits = self.text_to_bits(secret_msg)
        words, self.header_changes = self.hide_header(words, seed, len(bits))
        reserved_limit = self._reserved_limit()
        used = set()
        i = 0
        hidden = 0
        self.message_changes = 0
        len_bits = len(bits)
        while hidden < len_bits:
            idx, i = self._next_free_index(i, seed, len(words), reserved_limit, used)

            window, mask_pos = self.create_window(words, idx)
            window[mask_pos] = '[MASK]'
            preds, probs = self.get_predictions(' '.join(window))
            huff = self.build_huffman(preds, probs)

            print(f"\n📦 Hiding Secret Message Block at index {idx}")
            print(f"Bits remaining: {bits[hidden:hidden+10]}")
            print(f"Huffman codes: {huff}")

            selection = self._select_word(huff, bits, hidden)
            if selection is None:
                print(f"⚠️ No match.")
            else:
                w, code = selection
                words[idx] = w
                print(f"✅ Selected word '{w}' for code '{code}'")
                hidden += len(code)
                self.message_changes += 1
            used.add(idx)
            i += 1
        self.total_changes = self.header_changes + self.message_changes
        return ' '.join(words), self.total_changes

    def extract_header(self, words):
        """Read the header words back and return ``(seed, bit_len)``.

        A header word that has no Huffman code at its position decodes as ``'00'``.
        """
        bits = ''
        i = 0

        while len(bits) < self.HEADER_BITS:
            idx = self.HALF_WINDOW+1 + (i * (self.HALF_WINDOW+1))
            if idx >= len(words):
                break

            window, mask_pos = self.create_window(words, idx)
            original = words[idx]
            window[mask_pos] = '[MASK]'
            preds, probs = self.get_predictions(' '.join(window))
            huff = self.build_huffman(preds, probs)

            print(f"\n🧩 Extracting Header at index {idx}")
            print(f"Original word: {original}")
            print(f"Huffman codes: {huff}")

            if original in huff:
                bits += huff[original]
            else:
                bits += '00'
                print(f"⚠️ '{original}' not found, defaulting to '00'")

            i += 1

        # Same bookkeeping as hide_header, so both sides reserve the same region.
        self.header_slots = i
        return self.decode_header(bits)

    def extract_message(self, stego_text):
        """Recover the hidden message from ``stego_text`` and return it as a string.

        Raises ``ValueError`` when the decoded header asks for more bits than
        the text has free positions for.
        """
        words = stego_text.split()
        seed, bit_len = self.extract_header(words)
        print(f"\n[HEADER DECODED] seed={seed}, bit_len={bit_len}\n")

        reserved_limit = self._reserved_limit()
        bits = ''
        i = 0
        used = set()

        while len(bits) < bit_len:
            idx, i = self._next_free_index(i, seed, len(words), reserved_limit, used)

            window, mask_pos = self.create_window(words, idx)
            original = words[idx]
            window[mask_pos] = '[MASK]'
            preds, probs = self.get_predictions(' '.join(window))
            huff = self.build_huffman(preds, probs)

            print(f"\n🧩 Extracting Message at index {idx}")
            print(f"Original word: {original}")
            print(f"Huffman codes: {huff}")

            if original in huff:
                bits += huff[original]
            else:
                bits += '00'
                print(f"⚠️ '{original}' not found, defaulting to '00'")
            used.add(idx)
            i += 1
        # These counters are the values left by the last hide_message call on
        # this instance; decoding itself does not change them.
        print(f"\n[INFO] Total message words changed: {self.message_changes}")
        print(f"[INFO] Total words changed (header + message): {self.total_changes}")
        return self.bits_to_text(bits[:bit_len])


# Testing Huffman

In [None]:
# Sample cover text for a manual round-trip check of the Huffman method.
cover_text = """What is steganography?
Steganography is the technique of hiding data within an ordinary, nonsecret file or message to avoid detection; the hidden data is then extracted at its destination. Steganography use can be combined with encryption as an extra step for hiding or protecting data. The word steganography is derived from the Greek word steganos, meaning "hidden or covered," and the Greek root graph, meaning "to write."

Steganography can be used to conceal almost any type of digital content, including text, image, video or audio content. The secret data can be hidden inside almost any other type of digital content. The content to be concealed through steganography -- called hidden text -- is often encrypted before being incorporated into the innocuous-seeming cover text file or data stream. If not encrypted, the hidden text is commonly processed in some method to increase the difficulty of detecting the secret content.

What are some examples of steganography?
Steganography is practiced by those wishing to convey a secret message or code. While there are many legitimate uses for steganography, some malware developers use steganography to obscure the transmission of malicious code -- known as stegware.

Forms of steganography have been used for centuries and include almost any technique for hiding a secret message in an otherwise harmless container. For example, using invisible ink to hide secret messages in otherwise inoffensive messages; hiding documents recorded on microdot, which can be as small as 1 millimeter in diameter; hiding messages on or inside legitimate-seeming correspondence; and even using multiplayer gaming environments to share information.
"""
# Message to hide in the cover text.
secret_text = "merhaba"
# Instantiating the class loads the bert-base-cased tokenizer and model.
huff = Huffman_Stego()

In [None]:
# Embed the secret; returns the stego text and the number of replaced words.
stego, changed = huff.hide_message(cover_text, secret_text, seed=1234)


In [None]:
# Decode the stego text with the same instance to recover the secret.
hidden = huff.extract_message(stego)

In [None]:
# Show the recovered message for comparison with secret_text.
print(hidden)

# Evaluation

In [None]:
!pip install transformers torch scikit-learn --quiet
import math
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import CountVectorizer
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

# === 1. Embedding Rate (ER) ===
def compute_er(num_bits, total_words):
    """Return the embedding rate: hidden bits per changed word.

    Returns 0 when ``total_words`` is 0 instead of raising
    ``ZeroDivisionError``, matching how the other per-word metrics in this
    notebook treat a run without changed words.
    """
    if not total_words:
        return 0
    return num_bits / total_words

# === 2. Kullback-Leibler Divergence (KLD) ===
def compute_kld(text_p, text_q):
    """Return KL(P || Q) between the word-frequency distributions of two texts.

    Both texts are counted over a shared CountVectorizer vocabulary; each
    count vector is normalised to sum to one and clipped to [1e-10, 1] so the
    logarithm is always defined.
    """
    counts = CountVectorizer().fit_transform([text_p, text_q]).toarray().astype(float)
    floor = 1e-10
    dist_p = np.clip(counts[0] / counts[0].sum(), floor, 1)
    dist_q = np.clip(counts[1] / counts[1].sum(), floor, 1)
    log_ratio = np.log(dist_p / dist_q)
    return np.sum(dist_p * log_ratio)

# === 3. Perplexity (PPL) using GPT-2 ===

# GPT-2 tokenizer/model pair, loaded on first use and then reused by compute_ppl.
_GPT2_CACHE = {}


def _load_gpt2():
    """Return ``(tokenizer, model)`` for GPT-2, loading them only once per session."""
    if "gpt2" not in _GPT2_CACHE:
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        model = GPT2LMHeadModel.from_pretrained("gpt2")
        model.eval()
        _GPT2_CACHE["gpt2"] = (tokenizer, model)
    return _GPT2_CACHE["gpt2"]


def compute_ppl(text, max_length=1024):
    """Return the GPT-2 perplexity of ``text``.

    Texts of at most ``max_length`` tokens are scored in a single forward
    pass.  Longer texts are scored with a sliding window: consecutive windows
    overlap so every token keeps some left context, but each token contributes
    to the loss exactly once and window losses are weighted by the number of
    tokens they score.  (Averaging the raw window losses counted overlapping
    tokens twice and gave a short final window the same weight as a full one.)

    Returns ``nan`` when the text has no token that can be scored.
    """
    tokenizer, model = _load_gpt2()

    tokens = tokenizer.encode(text)
    if len(tokens) <= max_length:
        input_ids = torch.tensor([tokens])
        with torch.no_grad():
            loss = model(input_ids, labels=input_ids).loss
            return torch.exp(loss).item()

    # Split long text into overlapping windows
    seq_len = len(tokens)
    stride = min(512, max_length)  # never step past the end of the previous window
    nll_sum = 0.0
    scored_tokens = 0
    prev_end = 0
    for begin in range(0, seq_len, stride):
        end = min(begin + max_length, seq_len)
        new_tokens = end - prev_end  # tokens not scored by an earlier window
        input_ids = torch.tensor([tokens[begin:end]])
        labels = input_ids.clone()
        # Label -100 is ignored by the loss: the overlap serves as context only.
        labels[:, :-new_tokens] = -100
        # The model shifts labels by one, so position 0 is never a target.
        n_valid = int((labels[:, 1:] != -100).sum())
        if n_valid > 0:
            with torch.no_grad():
                outputs = model(input_ids, labels=labels)
            nll_sum += outputs.loss.item() * n_valid
            scored_tokens += n_valid
        prev_end = end
        if end == seq_len:
            break

    if scored_tokens == 0:
        return float("nan")
    return math.exp(nll_sum / scored_tokens)

# === 4. Semantic Similarity (SIM) ===
def compute_similarity(text1, text2):
    """Return the cosine similarity of the two texts' word-count vectors.

    The vectors are bag-of-words counts over a vocabulary fitted on both texts.
    """
    texts = [text1, text2]
    bag_of_words = CountVectorizer()
    bag_of_words.fit(texts)
    first, second = bag_of_words.transform(texts).toarray()
    similarity_matrix = cosine_similarity([first], [second])
    return similarity_matrix[0][0]


# Run Multiple Samples

In [None]:
# One secret message per non-blank line, surrounding whitespace removed.
secret_messages = []
with open("secret_message.txt", "r", encoding="utf-8") as messages_file:
    for raw_line in messages_file:
        message = raw_line.strip()
        if message:
            secret_messages.append(message)

# The whole file is used as a single cover text.
with open("cover_text.txt", "r", encoding="utf-8") as cover_file:
    cover_text = cover_file.read()

In [None]:
# No-op self-assignment: cover_text keeps the value loaded above.
cover_text = cover_text

In [None]:
# Per-algorithm list of metric dicts, one entry per secret message.
accumulated_results = {algo: [] for algo in ['Character', 'Fixed_Bit_2', 'Fixed_Bit_4', 'Huffman']}

# (factory, result key) pairs; a fresh instance is built for every run so no
# counters carry over from one message to the next.
stego_builders = [
    (lambda: Character_Based_Stego(), "Character"),
    (lambda: Fixed_Bit_Stego(bit_width=2), "Fixed_Bit_2"),
    (lambda: Fixed_Bit_Stego(bit_width=4), "Fixed_Bit_4"),
    (lambda: Huffman_Stego(), "Huffman")
]

# The cover text never changes, so its perplexity is computed once rather than
# once per (message, algorithm) pair.
ppl_cover = compute_ppl(cover_text) if secret_messages else None

for secret_text in secret_messages:
    for StegoBuilder, name in stego_builders:
        algo = StegoBuilder()
        if name == "Character":
            stego, changed, skipped = algo.hide_message(cover_text, secret_text, seed=1234, loop_index=0)
            header_size = 32
        else:
            stego, changed = algo.hide_message(cover_text, secret_text, seed=1234)
            header_size = 24
            skipped = 0

        extracted = algo.extract_message(stego)

        ppl_stego = compute_ppl(stego)
        kld_result = compute_kld(cover_text, stego)
        sim_result = compute_similarity(cover_text, stego)
        metrics = {
            'ChangedWords': changed,
            'SkippedWords': skipped,
            # Guarded like the other per-word metrics below.
            'EmbeddingRate': compute_er(len(secret_text) * 8+header_size, changed) if changed else 0,
            'KLD': kld_result,
            'PPL_Cover': ppl_cover,
            'PPL_Stego': ppl_stego,
            'SemanticSim': sim_result,
            'PPL_per_word': (ppl_stego - ppl_cover) / changed if changed else 0,
            'KLD_per_word': kld_result / changed if changed else 0,
            'SIM_loss_per_word': (1 - sim_result) / changed if changed else 0,
            'Success': extracted == secret_text
        }
        accumulated_results[name].append(metrics)


In [None]:
import numpy as np

# Output column -> per-run metric key, in the order the columns are emitted.
averaged_columns = [
    ('AvgChangedWords', 'ChangedWords'),
    ('SkippedWords', 'SkippedWords'),
    ('AvgEmbeddingRate', 'EmbeddingRate'),
    ('AvgKLD', 'KLD'),
    ('AvgPPL_Cover', 'PPL_Cover'),
    ('AvgPPL_Stego', 'PPL_Stego'),
    ('AvgSemanticSim', 'SemanticSim'),
    ('AvgPPL_PerWord', 'PPL_per_word'),
    ('AvgKLD_PerWord', 'KLD_per_word'),
    ('AvgSIM_Loss_PerWord', 'SIM_loss_per_word'),
    ('SuccessRate', 'Success'),
]

# One row per algorithm: the mean of every metric over that algorithm's runs.
final_results = []

for algo_name, runs in accumulated_results.items():
    avg_result = {'Algorithm': algo_name}
    for column, metric_key in averaged_columns:
        avg_result[column] = np.mean([r[metric_key] for r in runs])
    avg_result['Runs'] = len(runs)
    final_results.append(avg_result)


In [None]:
import csv

# Write the averaged metrics, one row per algorithm, with a header row.
with open("evaluation_averages.csv", "w", newline='') as csv_file:
    column_names = final_results[0].keys()
    writer = csv.DictWriter(csv_file, fieldnames=column_names)
    writer.writeheader()
    for result_row in final_results:
        writer.writerow(result_row)
