In [None]:
from functools import lru_cache
from itertools import combinations
from math import pow

import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from bitstring import BitArray

In [None]:
# Tokenizer and model come from the same checkpoint so token ids line up.
tokenizer, model = (
    AutoTokenizer.from_pretrained("gpt2"),
    AutoModelForCausalLM.from_pretrained("gpt2"),
)

In [None]:
# ASCII-encode the secret; non-ASCII input raises UnicodeEncodeError.
secret_mes = input("Enter the message you want to hide. Only use ASCII characters").encode("ascii")

## Encryption

In [None]:
key_len = 16  # bytes, i.e. an AES-128 key
# Key generation: a fresh random key. It is not persisted anywhere in this
# notebook; the decryption cell reuses it from the kernel.
key = get_random_bytes(key_len)
cipher = AES.new(key, AES.MODE_GCM)
# Encrypt and authenticate in one call.
ciphertext, tag = cipher.encrypt_and_digest(secret_mes)
# The bit string to hide is nonce || ciphertext || tag.
encode_this = BitArray(b"".join((cipher.nonce, ciphertext, tag)))

## Helper functions to split candidate tokens into two groups of balanced probability

In [None]:
def leq_set(k, l):
    """Return every subset of ``l`` with at most ``k`` elements.

    Args:
        k: maximum subset size; values <= 0 yield only the empty set.
        l: iterable of hashable items (consumed once; duplicates collapse).

    Returns:
        set of frozensets, always including ``frozenset()``.
    """
    items = tuple(l)
    # Sizes above the number of items contribute nothing, so cap the range.
    max_size = min(max(k, 0), len(items))
    return {
        frozenset(subset)
        for size in range(max_size + 1)
        for subset in combinations(items, size)
    }


def set_combinations(l):
    """Return ordered pairs ``(first, second)`` of disjoint, non-empty subsets of ``l``.

    ``first`` has at most ``len(l) // 2`` elements; ``second`` is any non-empty
    subset of the elements not in ``first``. Elements covered by neither subset
    are allowed. The result is an unordered set; sort it if a stable iteration
    order is needed.

    Args:
        l: iterable of hashable items.

    Returns:
        set of ``(frozenset, frozenset)`` tuples.
    """
    items = list(l)
    pairs = set()
    for first in leq_set(int(len(items) / 2), items):
        if not first:
            continue
        rest = [j for j in items if j not in first]
        for second in leq_set(len(rest), rest):
            if second:
                pairs.add((first, second))
    return pairs

In [None]:
@lru_cache(maxsize=None)
def _split_candidates(length):
    """Enumerate candidate splits of ``range(length)`` in a fixed order.

    Produces the same family of pairs as ``set_combinations(range(length))``:
    ``first`` is a non-empty subset of size <= length // 2 and ``second`` is any
    non-empty subset of the remaining positions. Subsets are sorted tuples and
    the pairs come in ``itertools.combinations`` order, so the order does not
    depend on hash-set iteration. Cached because it only depends on ``length``.

    Returns:
        tuple of ``(first, second)`` index-tuple pairs; empty when length < 2.
    """
    positions = range(length)
    candidates = []
    for first_size in range(1, length // 2 + 1):
        for first in combinations(positions, first_size):
            rest = [p for p in positions if p not in first]
            for second_size in range(1, len(rest) + 1):
                candidates.extend((first, second) for second in combinations(rest, second_size))
    return tuple(candidates)


def balance_probs(probs):
    """Split the top-k candidate tokens into two groups of near-equal probability.

    Args:
        probs: a ``torch.topk`` result whose ``values`` is 2-D (batch, k); only
            row 0 is read.

    Returns:
        ``(indeces_0, indeces_1)``: two disjoint, non-empty frozensets of
        positions into ``probs.values[0]`` / ``probs.indices[0]`` whose summed
        probabilities differ the least. Positions in neither set are left out.

    Raises:
        ValueError: if fewer than two candidates are given (no split exists).
    """
    length = probs.values.size()[1]
    # One bulk conversion instead of thousands of scalar tensor reads per call.
    masses = probs.values[0].tolist()
    best_split = None
    best_gap = float("inf")
    for first, second in _split_candidates(length):
        gap = abs(sum(masses[p] for p in first) - sum(masses[p] for p in second))
        # Strict '<' keeps the earliest candidate on ties (e.g. (A, B) vs (B, A)),
        # and the candidate order is fixed, so encoder and decoder pick the same
        # split whenever they see the same probabilities.
        if gap < best_gap:
            best_gap = gap
            best_split = (first, second)
    if best_split is None:
        raise ValueError(f"balance_probs needs at least 2 candidates, got {length}")
    first, second = best_split
    return frozenset(first), frozenset(second)

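## Encode

Build the cover text interactively. At each step the top-8 next tokens are split into two groups of near-equal probability. Choosing a token from the preferred group encodes the next bit. A token from the "don't use" group is rejected. A token from neither group is appended as filler without encoding a bit.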

In [None]:
# Interactively build the cover text; each accepted input extends it by one GPT-2 token.
current_string = "Dear audience,"
previous_string = ""

bit_index = 0
while bit_index < len(encode_this):
    bit = encode_this[bit_index]
    print("Current string:", current_string)
    current_tokens = tokenizer(current_string, return_tensors="pt")
    # Inference only: no autograd graph is needed for the next-token distribution.
    with torch.no_grad():
        next_logits = model(**current_tokens).logits[:, -1, :]
    probs = torch.topk(nn.functional.softmax(next_logits, dim=-1), 8)
    (indeces_0, indeces_1) = balance_probs(probs)

    # Group 1 stands for a 1 bit and group 0 for a 0 bit (the decoder uses the same mapping).
    if bit:
        (indeces_right, indeces_wrong) = (indeces_1, indeces_0)
    else:
        (indeces_right, indeces_wrong) = (indeces_0, indeces_1)

    candidate_ids = probs.indices[0].tolist()
    right_ids = [candidate_ids[k] for k in indeces_right]
    wrong_ids = [candidate_ids[k] for k in indeces_wrong]
    print("Don't use these tokens ", [tokenizer.decode(t) for t in wrong_ids])
    print("Prefer these tokens ", [tokenizer.decode(t) for t in right_ids])

    new_token_string = input("add another token")
    # The decoder re-tokenizes the whole text, so the input must extend the current
    # token sequence by exactly one token *in context* (rejects empty input,
    # multi-token input, and strings that BPE-merge with the preceding text).
    prefix_ids = current_tokens.input_ids[0].tolist()
    extended_ids = tokenizer(current_string + new_token_string).input_ids
    if len(extended_ids) != len(prefix_ids) + 1 or extended_ids[:-1] != prefix_ids:
        print("input must add exactly one token to the current text; try again")
        continue
    new_token = extended_ids[-1]

    if new_token in right_ids:
        bit_index += 1
        print(f"{bit_index}/{len(encode_this)} bits encoded")
    elif new_token in wrong_ids:
        print("token from the 'don't use set'")
        continue
    # A token in neither group is neutral filler: it is appended but encodes no bit.

    previous_string = current_string
    current_string += new_token_string

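## Decode

Re-tokenize the received text and replay the same model and split for every token after the seed prefix. A token from group 0 yields a 0 bit and a token from group 1 yields a 1 bit. Tokens in neither group are skipped.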

In [None]:
# The stego text to decode; it must start with the same seed prefix the encoder began from.
in_string = "Dear audience, please do your own due and take"
seed_string = "Dear audience,"
# Derive the number of un-encoded prefix tokens instead of hard-coding it.
seed_ids = tokenizer(seed_string).input_ids
seed_len = len(seed_ids)
in_tokens = tokenizer(in_string, return_tensors="pt").input_ids
assert in_tokens[0, :seed_len].tolist() == seed_ids, "in_string does not start with the seed prefix"

In [None]:
# Recover one bit per cover-text token by replaying the encoder's split.
bs = []

for position in range(seed_len, in_tokens.size()[1]):
    context = in_tokens[:, :position]
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        next_logits = model(input_ids=context).logits[:, -1, :]
    probs = torch.topk(nn.functional.softmax(next_logits, dim=-1), 8)
    indeces_0, indeces_1 = balance_probs(probs)
    candidate_ids = probs.indices[0].tolist()
    new_token = in_tokens[0, position].item()

    if new_token in [candidate_ids[k] for k in indeces_0]:
        bs.append(False)
    elif new_token in [candidate_ids[k] for k in indeces_1]:
        bs.append(True)
    # Tokens in neither group carry no bit (neutral filler).

# The payload is whole bytes; tobytes() would zero-pad a trailing partial byte
# into a garbage byte, so drop incomplete trailing bits instead.
usable_bits = len(bs) - len(bs) % 8
if usable_bits != len(bs):
    print(f"Ignoring {len(bs) - usable_bits} trailing bit(s) that do not form a whole byte")
decrypt_this = BitArray(bs[:usable_bits]).tobytes()

## Decryption

In [None]:
# Payload layout is nonce || ciphertext || tag. The lengths are PyCryptodome's
# AES-GCM defaults (16-byte nonce, 16-byte tag) and are independent of the key
# length, so they get their own constants instead of reusing key_len
# (which would break with a 24- or 32-byte key).
NONCE_LEN = 16
TAG_LEN = 16
if len(decrypt_this) < NONCE_LEN + TAG_LEN:
    raise ValueError(
        f"Recovered only {len(decrypt_this)} bytes; need at least {NONCE_LEN + TAG_LEN} for nonce and tag"
    )
nonce = decrypt_this[:NONCE_LEN]
ciphertext = decrypt_this[NONCE_LEN:-TAG_LEN]
tag = decrypt_this[-TAG_LEN:]

In [None]:
# Decrypt and authenticate in one step; PyCryptodome raises ValueError if the tag check fails.
cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)
cipher.decrypt_and_verify(ciphertext, tag)