In [9]:
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from torch.nn import CosineSimilarity
from copy import deepcopy
from random import choice
from functools import partial
import numpy as np
from numpy.random import choice as np_choice
from time import time
from transformers import T5Tokenizer, T5Model

In [172]:
# Define Sharpened Cosine Similarity
def scs(v, w, dim):
    return CosineSimilarity(dim=dim, eps=1e-08)(v, w) ** 3

# Function to pad tensors to the same length
def pad_tensors(tensor_list, padding_value=0):
    max_len = max(tensor.size(1) for tensor in tensor_list)
    padded_tensors = [
        torch.nn.functional.pad(tensor, (0, 0, 0, max_len - tensor.size(1)), value=padding_value)
        for tensor in tensor_list
    ]
    return torch.cat(padded_tensors, dim=0)

# Decoding function to translate T5 embeddings into text
def decode_t5_embedding(embedding, t5_model_path, tokenizer_path, num_epochs=200, batch_size=128, max_len=120):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the T5 model and tokenizer
    t5_model = T5Model.from_pretrained("t5-base").to(device).eval()
    tokenizer = T5Tokenizer.from_pretrained("t5-base")

    # Randomly generate the initial tokens
    vocabulary = list(range(tokenizer.vocab_size))  # Use full vocabulary
    random_length = np.random.randint(1, max_len // 2)  # Generate a random length for the initial text
    best_tokens = [choice(vocabulary) for _ in range(random_length)]  # Randomly select tokens from vocabulary

    with torch.no_grad():
        input_ids = torch.tensor([best_tokens]).to(device)
        encoder_outputs = t5_model.encoder(input_ids=input_ids)
        best_vec = encoder_outputs.last_hidden_state.mean(dim=1)  # Mean pooling

    embedding = embedding.mean(dim=1)  # Ensure the input embedding matches dimensions
    best_score = scs(best_vec, embedding, dim=1).item()

    # Define modification operations
    def delete_token(token_list):
        if len(token_list) > 0:
            del token_list[choice(range(len(token_list)))]
        return token_list

    def insert_token(token_list, vocabulary):
        insert_id = choice(range(len(token_list) + 1))
        new_word = choice(vocabulary)
        token_list.insert(insert_id, new_word)
        return token_list

    def replace_token(token_list, vocabulary):
        if len(token_list) > 0:
            replace_id = choice(range(len(token_list)))
            new_word = choice(vocabulary)
            token_list[replace_id] = new_word
        return token_list

    def donothing(token_list):
        return token_list

    ops = [
        delete_token,
        partial(insert_token, vocabulary=vocabulary),
        partial(replace_token, vocabulary=vocabulary),
    ]

    # Iterative optimization
    t_start = time()

    for epoch in range(num_epochs):
        if len(best_tokens) >= max_len:
            probs = np.array([0.2, 0.0, 0.8])  # Restrict length growth
        else:
            probs = np.array([0.15, 0.1, 0.75])

        ops_ids = np_choice(range(len(ops)), batch_size, p=probs)
        candidates = []

        for op_id in ops_ids:
            op = ops[op_id]
            candidate_tokens = op(deepcopy(best_tokens))
            if len(candidate_tokens) > 0:  # Ensure non-empty candidate
                candidates.append(candidate_tokens)

        if len(candidates) > 0:  # Ensure there are valid candidates
            candidate_vecs = []
            for cand_tokens in candidates:
                input_ids = torch.tensor([cand_tokens]).to(device)
                with torch.no_grad():
                    vec = t5_model.encoder(input_ids=input_ids).last_hidden_state.mean(dim=1)  # Mean pooling
                    candidate_vecs.append(vec)

            candidate_vecs = torch.cat(candidate_vecs, dim=0)
            scores = torch.tensor([scs(candidate_vec, embedding, dim=0).item() for candidate_vec in candidate_vecs])

            max_score = scores.max().item()
            if max_score > best_score:
                best_score = max_score
                best_tokens = candidates[scores.argmax().item()]
        best_text = tokenizer.decode(best_tokens, skip_special_tokens=True)
        print(f"Epoch {epoch + 1}/{num_epochs}, Best Score: {best_score:.4f}, Current Best Text: {best_text}", end="\r")

    print(f"Decoding completed in {(time() - t_start):.2f}s. Best score: {best_score:.4f}")
    return tokenizer.decode(best_tokens, skip_special_tokens=True)


In [173]:
if __name__ == "__main__":
    # Load embedding vector from a provided txt file
    with open("aux_prompt_vec.txt", "r") as f:
        embedding_values = [float(line.strip()) for line in f.readlines()]

    example_embedding = torch.tensor([embedding_values]).to("cuda" if torch.cuda.is_available() else "cpu")

    # Decode the embedding
    decoded_text = decode_t5_embedding(
        embedding=example_embedding,
        t5_model_path=None,
        tokenizer_path=None
    )

    print("Decoded Text:", decoded_text)

Decoding completed in 123.38s. Best score: 0.0044ext: urmatoareA vigorous Cruz 15%meterkayorjemandOne promisingBE dans seit Vent Than Island Spider 1- Bird Option Would Kirk Blog Make Choir Ya organismuluii
Decoded Text: urmatoareA vigorous Cruz 15%meterkayorjemandOne promisingBE dans seit Vent Than Island Spider 1- Bird Option Would Kirk Blog Make Choir Ya organismului
