In [7]:
import os, sys
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '../../'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

In [3]:
from dataclasses import dataclass, field
from io import StringIO
from typing import List, Tuple, Union

from icecream import ic

import torch
import torch.nn.functional as F
from icecream import ic
from nltk.corpus import stopwords
from torch import Tensor
from transformers import BertForMaskedLM, BertTokenizer  # type: ignore
from transformers.tokenization_utils import PreTrainedTokenizer

from Stegasus.SemanticMasking import MaskGen


KeyboardInterrupt



In [None]:
MaskGen

<function Stegasus.SemanticMasking.SemanticMask.MaskGen(text: str, tokenizer=None)>

In [None]:
def getTokenizerAndModel(model_name_or_path: str = 'bert-base-cased'):
    tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
    temp = BertForMaskedLM.from_pretrained(model_name_or_path)
    assert isinstance(temp, BertForMaskedLM)
    model: BertForMaskedLM = temp
    return tokenizer, model
TOKENIZER, MODEL = getTokenizerAndModel()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Meta-type for "numeric" things; matches our docs
Number = Union[int, float, bool]

@dataclass
class PreprocessedText:
  input_ids: torch.Tensor
  masked_ids: torch.Tensor
  sorted_output: Tuple[torch.Tensor, torch.Tensor]
  def __iter__(self):
    yield self.input_ids
    yield self.masked_ids
    yield self.sorted_output[0]
    yield self.sorted_output[1]
    

@dataclass
class MaskedStegoResult:
  encoded_text: str
  encoded_bytes: str
  remaining_bytes: str

class MaskedStego:
  """
  Examples
  masked_stego.decode("The quick red fox jumps over the poor dog.", 3, 0.01))
  masked_stego("The quick brown fox jumps over the lazy dog. and said boom you lazy dog stay back",'010101010101', 3, 0.01))
  """
  
  def __init__(self,tokenizer,model) -> None: 
    MaskedStego._STOPWORDS: List[str] = stopwords.words('english')
    self._model = model
    self._tokenizer = tokenizer

  def __call__(self, cover_text: str, message: str, mask_interval: int = 3, score_threshold: float = 0.01) -> MaskedStegoResult:
    assert set(message) <= set('01')
    message_io = StringIO(message)
    input_ids, masked_ids, sorted_score, indices = self._preprocess_text(cover_text, mask_interval)
    for i_token, token in enumerate(masked_ids):
      if token != self._tokenizer.mask_token_id:
        continue
      ids = indices[i_token]
      scores = sorted_score[i_token]
      candidates = self._pick_candidates_threshold(ids, scores, score_threshold)
      if len(candidates) < 2:
        continue
      replace_token_id = self._block_encode_single(candidates, message_io).item()
      input_ids[i_token] = replace_token_id
    encoded_message: str = message_io.getvalue()[:message_io.tell()]
    message_io.close()
    stego_text = self._tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return MaskedStegoResult(encoded_text=stego_text,encoded_bytes=encoded_message,remaining_bytes=message[len(encoded_message):])

  def decode(self, stego_text: str, mask_interval: int = 3, score_threshold: float = 0.005) -> str:
    decoded_message: List[str] = []
    
    input_ids, masked_ids, sorted_score, indices = self._preprocess_text(stego_text, mask_interval)
    
    for i_token, token in enumerate(masked_ids):
      if token != self._tokenizer.mask_token_id:
        continue
      ids = indices[i_token]
      scores = sorted_score[i_token]
      candidates = self._pick_candidates_threshold(ids, scores, score_threshold)
      if len(candidates) < 2:
        continue
      chosen_id = int(input_ids[i_token].item())
      decoded_message.append(self._block_decode_single(candidates, chosen_id))

    return ''.join(decoded_message)

  def _preprocess_text(self, sentence: str, mask_interval: int) -> "PreprocessedText":
    encoded_ids = self._tokenizer([sentence], return_tensors='pt').input_ids[0]
    masked_ids = self._mask(encoded_ids.clone(), mask_interval)
    sorted_score, indices = self._predict(masked_ids)
    return PreprocessedText(input_ids=encoded_ids,masked_ids=masked_ids,sorted_output=(sorted_score,indices))

  def _mask(self, input_ids, mask_interval: int) -> Tensor:
    ic(input_ids,mask_interval)
    length = len(input_ids)
    tokens: List[str] = self._tokenizer.convert_ids_to_tokens(input_ids)
    ic(tokens)
    offset = mask_interval // 2 + 1
    mask_count = offset
    # TODO remove mask_interval
    #   or  handle the case if interval -1
    #   use [s[i[0],i[1]] for i in extract_pos]
    #   note that token is a string list to compare is the str
    #   is in the pos
    #   maybe use alongside the interval ... idk
    #   
    for i, token in enumerate(tokens):
      # Skip initial subword
      if i + 1 < length and self._is_subword(tokens[i + 1]): continue
      if not self._substitutable_single(token): continue
      if mask_count % mask_interval == 0:
        input_ids[i] = self._tokenizer.mask_token_id
      mask_count += 1
    ic(self._tokenizer.mask_token_id)
    ic(self._tokenizer.convert_ids_to_tokens(input_ids))
    ic(input_ids)
    return input_ids


  def _predict(self, input_ids: Tensor):
    self._model.eval()
    with torch.no_grad():
      output = self._model(input_ids.unsqueeze(0))['logits'][0]
      softmaxed_score = F.softmax(output, dim=1)  # [word_len, vocab_len]
      return softmaxed_score.sort(dim=1, descending=True)

  def _pick_candidates_threshold(self, ids: Tensor, scores: Tensor, threshold: float):
    filtered_ids = ids[scores >= threshold]
    def filter_fun(idx: Tensor) -> bool:
      return self._substitutable_single(self._tokenizer.convert_ids_to_tokens(int(idx.item())))
    return list(filter(filter_fun, filtered_ids))

  def _substitutable_single(self, token: str) -> bool:
    if self._is_subword(token): return False
    if token.lower() in MaskedStego._STOPWORDS: return False
    if not token.isalpha(): return False
    return True

  @staticmethod
  def _block_encode_single(ids: List[torch.Tensor], message: StringIO) -> torch.Tensor:
    assert len(ids) > 0
    if len(ids) == 1:
      return ids[0]
    capacity = len(ids).bit_length() - 1
    bits_str = message.read(capacity)
    if len(bits_str) < capacity:
      padding: str = '0' * (capacity - len(bits_str))
      bits_str = bits_str + padding
      message.write(padding)
    index = int(bits_str, 2)
    return ids[index]

  @staticmethod
  def _block_decode_single(ids: List[Tensor], chosen_id: int) -> str:
    if len(ids) < 2:
      return ''
    capacity = len(ids).bit_length() - 1
    index = ids.index(chosen_id) # type: ignore
    return format(index, '0' + str(capacity) +'b')

  @staticmethod
  def _is_subword(token: str) -> bool:
    return token.startswith('##')
    

In [None]:
masked_stego = MaskedStego(tokenizer=TOKENIZER,model=MODEL)


dmo = "The quick red fox jumps over the poor dog, And I hate tokenizers."
dmo2 = "The quick brown fox jumps over the lazy dog. and said boom you lazy dog stay back"
# a = masked_stego.decode(dmo, 3, 0.01)
masked_stego(dmo,'010101010101', 3, 0.01)
  

ic| input_ids: tensor([  101,  1109,  3613,  1894, 17594, 15457,  1166,  1103,  2869,  3676,
                         117,  1262,   146,  4819, 22559, 17260,  1116,   119,   102])
    mask_interval: 3
ic| tokens: ['[CLS]

',
             'The',
             'quick',
             'red',
             'fox',
             'jumps',
             'over',
             'the',
             'poor',
             'dog',
             ',',
             'And',
             'I',
             'hate',
             'token',
             '##izer',
             '##s',
             '.',
             '[SEP]']
ic| self._tokenizer.mask_token_id: 103
ic| self._tokenizer.convert_ids_to_tokens(input_ids): ['[CLS]',
                                                       'The',
                                                       'quick',
                                                       '[MASK]',
                                                       'fox',
                                                       'jumps',
                                                       'over',
                                                       'the',
                                                       '[MASK]',
                   

MaskedStegoResult(encoded_text='The quick talking fox jumps over the poor dog, And I hate tokenizers.', encoded_bytes='01010', remaining_bytes='1010101')

In [None]:
ic(MaskGen(dmo,TOKENIZER))

ic| MaskGen(dmo,TOKENIZER): SemanticPositions(string='The quick red fox jumps over the poor dog, And I '
                                                     'hate tokenizers.',
                                              NVA_words=[(4, 9),
                                                         (10, 13),
                                                         (18, 23),
                                                         (33, 37),
                                                         (38, 41),
                                                         (49, 53),
                                                         (54, 64)],
                                              tokens=['The',
                                                      'quick',
                                                      'red',
                                                      'fox',
                                                      'jumps',
                                             

SemanticPositions(string='The quick red fox jumps over the poor dog, And I hate tokenizers.', NVA_words=[(4, 9), (10, 13), (18, 23), (33, 37), (38, 41), (49, 53), (54, 64)], tokens=['The', 'quick', 'red', 'fox', 'jumps', 'over', 'the', 'poor', 'dog', ',', 'And', 'I', 'hate', 'token', '##izer', '##s', '.'], mask=[False, False, True, False, True, False, False, False, True, False, False, False, False, True, False, False, False])

In [None]:
b

MaskedStegoResult(encoded_text='The quick thinking fox jumps over the sly dog. and said boom you lazy dog go back', encoded_bytes='01010101', remaining_bytes='0101')

In [4]:
from MaskedStego import MaskedStego

In [5]:
masked_stego = MaskedStego()
#@title random_bit_stream
import random

def random_bit_stream(length=None):
    """Return a random string of zeros and ones of the given length (default: random integer between 0 and 100)."""
    if length is None:
        length = random.randint(0, 100)
    return ''.join(str(random.randint(0, 1)) for _ in range(length))
def int_to_binary_string(n: int, length: int):
    binary_str = bin(n)[2:]  # convert to binary string, remove '0b' prefix
    padded_str = binary_str.rjust(length, '0')  # pad with zeros to length
    return padded_str


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
import os, sys
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '../'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)
from SampleData import ConversationsRepo

In [None]:

print(

masked_stego("The quick brown fox jumps over the lazy dog. and said boom you lazy dog stay back",'010101010101', 3, 0.01)
)

print(

masked_stego.decode("The quick red fox jumps over the poor dog.", 3, 0.01)
)

In [17]:
def runBenchmark():
  chat_id = random.randint(1,ConversationsRepo.ConversationsCount)
  print(f"chat_id\tc size\tbits\tratio")
  for i in range(100):
    for text in ConversationsRepo.get(chat_id):
      data = random_bit_stream(len(text))
      # data = '1' * len(text)
      # text = 'hi, how are you?'
      result = masked_stego(text,data,3,0.01)
      encoded_text,rem = (result.encoded_text,result.remaining_bytes)
      print('rem=',rem)
      deData = masked_stego.decode(encoded_text,3,0.01)
      deData += rem
      print(f'text="{text}"\n->\nencoded_text="{encoded_text}" \ndata="{data}"\ndeData="{deData}"\ndata==deData="{data==deData}"')
      print(f'ratio={len(data)-len(rem)} / {len(text)}={(len(data)-len(rem)) / len(text)}')
      assert data==deData
      print('\n')
      
      
      bits = len(text)-len(rem)
      coverSize = len(text)
      line = f"{chat_id}\t{coverSize}\t{bits}\t{(bits*100)/coverSize}"
      print(line)
      with open('MLM benchmark.tsv','a') as f:
        f.write(line+'\n')
    chat_id = random.randint(1,ConversationsRepo.ConversationsCount)


runBenchmark()

chat_id	c size	bits	ratio
rem= 010111001110110001
text=" Do you own a car?"
->
encoded_text="Do you own a car?" 
data="010111001110110001"
deData="010111001110110001"
data==deData="True"
ratio=0 / 18=0.0


3893	18	0	0.0
rem= 0001000100000100101000000111111
text=" Yes, I own an older car. Do you?"
->
encoded_text="Yes, I own an electric car. Do you?" 
data="010001000100000100101000000111111"
deData="010001000100000100101000000111111"
data==deData="True"
ratio=2 / 33=0.06060606060606061


3893	33	2	6.0606060606060606
rem= 110011110111110001010101001101100011111100000000001010011110100100010001010110111001111011100101110110111100011100
text=" Me too, it still gets me around lol. Did you know trunks are trunks because they used to have wooden trunks on the back"
->
encoded_text="Me too, it still gets me around lol. Did you know trunks are trunks because they used to have little trunks on the back" 
data="00001111001111011111000101010100110110001111110000000000101001111010010001000101011011

KeyboardInterrupt: 