In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Hugging Face repository holding the GPTQ-quantised Llama-2 chat checkpoint.
model_name_or_path = "TheBloke/Llama-2-7b-Chat-GPTQ"

# Keyword arguments for loading the language model.
# `revision` selects the repository branch; switch it to load another
# variant, for example: revision="gptq-4bit-64g-actorder_True"
model_load_kwargs = dict(
    device_map="auto",
    trust_remote_code=False,
    revision="main",
)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_load_kwargs)

# Matching (fast) tokenizer for the language model.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

In [None]:
from transformers import AutoModel

# Sentence encoder used to embed the hint sentence and the generated text.
encoder_model_name = "princeton-nlp/unsup-simcse-roberta-large"

# Tokenizer first, then the encoder weights; the encoder is moved onto the
# same device that `model` reports so both run side by side.
encoder_tokenizer = AutoTokenizer.from_pretrained(encoder_model_name, use_fast=True)
encoder = AutoModel.from_pretrained(encoder_model_name)
encoder = encoder.to(model.device)

In [None]:
# `torch` is needed by this cell, but the notebook's only other `import torch`
# sits in a later cell; importing it here keeps Restart & Run All working.
import torch

# Hint sentence that the guided generation is steered towards.
target_sentence = "Today my pet is sick and i am very sad. I plan to take her to the hospital tomorrow and see what the doctor can do. I hope it doesnt cost too much."

# Tokenize the hint and move the tensors to the device the models live on.
target_inputs = encoder_tokenizer(
    target_sentence, return_tensors="pt", padding=True, truncation=True
).to(model.device)

# Embed the hint once, without tracking gradients; `pooler_output[0]` picks
# the pooled vector of the single sentence in the batch.
with torch.no_grad():
    target_embedding = encoder(**target_inputs).pooler_output[0]

# Displayed as the cell output.
target_embedding.shape

In [None]:
import torch
import numpy as np

# Text of the user turn.
prompt = "Close your eyes, and tell me what do you see"
# Full prompt fed to the language model: a <<SYS>> section that embeds
# `target_sentence`, the user `prompt` closed by [/INST], and the seed words
# "I see" from which the generated continuation picks up.
# NOTE(review): the [INST]/<<SYS>> tag layout looks like the Llama-2 chat
# format -- confirm against the model card.
prompt_template=f'''[INST] <<SYS>>
You are very imaginative and creative, and you are picturing a scene in your mind. The scene is about """{target_sentence}"""
A ultra realistic scene begins to show up...<</SYS>>
{prompt}[/INST]

I see'''

# Disabled reference: plain sampling through `model.generate`.
# output = model.generate(inputs=input_ids, temperature=0.7, do_sample=True, top_p=0.95, top_k=40, max_new_tokens=512)
# print(tokenizer.decode(output[0]))
# Bare expression so the notebook displays the rendered prompt.
prompt_template

In [None]:
from termcolor import colored, cprint

# cprint("Prompt:", "green")

def cosine_similarity(a, b, eps=1e-8):
    """Return the cosine similarity of two 1-D tensors as a 0-dim tensor.

    Args:
        a: 1-D tensor.
        b: 1-D tensor with the same number of elements as ``a``.
        eps: Lower bound applied to the product of the two norms, so that an
            all-zero input yields 0 instead of NaN (0 / 0).

    Returns:
        ``dot(a, b) / max(||a|| * ||b||, eps)``.
    """
    norm_product = torch.norm(a) * torch.norm(b)
    # Without the floor a zero vector produces NaN, and NaN survives any
    # later max/clamp applied to the similarity.
    return torch.dot(a, b) / torch.clamp(norm_product, min=eps)


# --- Decoding hyper-parameters ------------------------------------------------
top_k = 20                # number of highest-logit candidates re-scored per step
hint_weight = 5           # weight of the log-similarity term in the score
hint_weight_decay = 1.01  # hint_weight is multiplied by this after every step
max_new_tokens = 400      # upper bound on the number of generated tokens

# Use one device handle everywhere instead of mixing `.cuda()` and `model.device`.
device = model.device

past_key_values = None

input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.to(device)
input_length = input_ids.shape[1]
ids_so_far = input_ids
generation_mask = []  # per generated token: 1 if it equals the LM's argmax, else 0

embeddings = []  # per generated token: encoder embedding of the text generated so far


with torch.no_grad():
    for step in range(max_new_tokens):
        # After the first step only the newly chosen token is fed in; the
        # rest of the context is carried by `past_key_values`.
        output = model(input_ids=input_ids, return_dict=True, past_key_values=past_key_values)
        past_key_values = output.past_key_values
        next_logits = output.logits[0][-1]

        # Normalise the raw logits into log-probabilities. Taking the log of
        # a raw logit is NaN for negative values and is not a log-likelihood.
        # The cast to float32 keeps the softmax stable for half-precision models.
        log_probs = torch.log_softmax(next_logits.float(), dim=-1)

        top_tokens = torch.topk(next_logits, top_k).indices
        raw_best_token = torch.argmax(next_logits)

        generated_prefix = ids_so_far[0][input_length:]

        best_score = float("-inf")
        best_token = None
        best_embedding = None
        best_similarity = None

        for top_token in top_tokens:
            # Text that would result from appending this candidate.
            candidate_ids = torch.cat([generated_prefix, top_token.unsqueeze(0)])
            candidate_text = tokenizer.decode(candidate_ids)
            # Truncate so a long continuation cannot exceed the encoder's
            # maximum input length (the target sentence is encoded with
            # truncation as well).
            encoder_tokens = encoder_tokenizer(
                candidate_text, return_tensors='pt', truncation=True
            ).input_ids.to(device)
            embedding = encoder(input_ids=encoder_tokens).pooler_output
            similarity = cosine_similarity(target_embedding, embedding[0])

            # Floor the similarity so the log below is defined for
            # non-positive cosine values.
            similarity = torch.clamp(similarity, min=1e-5)

            score = torch.log(similarity) * hint_weight + log_probs[top_token]
            if score > best_score:
                best_score = score
                best_token = top_token
                best_embedding = embedding
                best_similarity = similarity

        # Record exactly one embedding per generated token: the one belonging
        # to the candidate that was actually chosen.
        embeddings.append(best_embedding)

        # Green: the guided choice agrees with plain greedy decoding; red: it differs.
        agrees_with_greedy = bool(best_token == raw_best_token)
        generation_mask.append(1 if agrees_with_greedy else 0)
        cprint(tokenizer.decode(best_token), end=" ", color="green" if agrees_with_greedy else "red")

        next_input = best_token.unsqueeze(0).unsqueeze(0)
        ids_so_far = torch.cat([ids_so_far, next_input], dim=-1)
        input_ids = next_input

        hint_weight *= hint_weight_decay

        if step % 50 == 0:
            # Periodic recap: similarity of the chosen token, then the whole
            # continuation so far with the same colour coding.
            print('\n')
            print('similarity:', best_similarity)
            for mask_value, token_id in zip(generation_mask, ids_so_far[0][input_length:]):
                cprint(tokenizer.decode(token_id), end=" ", color="green" if mask_value == 1 else "red")
            print('\n\n')

        # Stop once the end-of-sequence token has been chosen; continuing
        # past it only produces text beyond the end of the answer.
        if tokenizer.eos_token_id is not None and best_token.item() == tokenizer.eos_token_id:
            break

In [None]:
# Decode only the continuation (the ids that follow the prompt tokens) and print it.
generated_ids = ids_so_far[0][input_length:]
generated_text = tokenizer.decode(generated_ids)
print(generated_text)

In [None]:
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt

# Project the collected encoder embeddings and the target embedding onto the
# first two principal components of the collected embeddings.
pca = PCA(n_components=2)

embeddings_array = torch.cat(embeddings).cpu().numpy()
target_embeddings_cpu = target_embedding.cpu().numpy()

pca.fit(embeddings_array)

embeddings_pca = pca.transform(embeddings_array)
# The target is a single vector, so reshape it into a one-row matrix first.
target_embedding_pca = pca.transform(target_embeddings_cpu.reshape(1, -1))

# Explicit figure/axes with title, axis labels and legend so the plot can be
# read on its own.
fig, ax = plt.subplots()
ax.plot(embeddings_pca[:, 0], embeddings_pca[:, 1], label="generated text embeddings")
ax.scatter(target_embedding_pca[:, 0], target_embedding_pca[:, 1], color='red', label="target sentence")
ax.set(
    title="Embedding trajectory during generation (PCA)",
    xlabel="principal component 1",
    ylabel="principal component 2",
)
ax.legend()
plt.show()

In [None]:
# Re-print the continuation token by token: green where generation_mask
# holds 1, red otherwise.
continuation_ids = ids_so_far[0][input_length:]
for mask_value, token_id in zip(generation_mask, continuation_ids):
    token_color = "green" if mask_value == 1 else "red"
    cprint(tokenizer.decode(token_id), end=" ", color=token_color)

In [None]:
# Restore the hint weight to its starting value of 5: the generation loop
# multiplies hint_weight by hint_weight_decay on every step, and the cell
# that saves the story writes hint_weight into the file header.
hint_weight = 5

In [None]:
import os

# Directory that collects the generated stories. Create it if it is missing,
# so a finished generation is not lost to a FileNotFoundError.
story_dir = "../dream"
os.makedirs(story_dir, exist_ok=True)

# Pick the first unused file name: story1.txt, story2.txt, ...
story_index = 1
while os.path.exists(f"{story_dir}/story{story_index}.txt"):
    story_index += 1

# Write the run settings followed by the full decoded sequence (prompt and
# continuation). UTF-8 is set explicitly because the generated text may
# contain characters outside the platform's default encoding.
with open(f"{story_dir}/story{story_index}.txt", "w", encoding="utf-8") as f:
    f.write(f"Hint: {target_sentence}\n\n")
    f.write(f"Top k: {top_k}\n")
    f.write(f"Hint weight: {hint_weight}\n\n")
    f.write(tokenizer.decode(ids_so_far[0]))