In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


In [2]:
tokenizer = AutoTokenizer.from_pretrained("cerebras/btlm-3b-8k-base")
model = AutoModelForCausalLM.from_pretrained("cerebras/btlm-3b-8k-base", trust_remote_code=True, torch_dtype="auto").to("cuda")

In [277]:
prompt = """SYSTEM: You are an agent tasked with creating better, more descriptive image prompts that result in higher quality outputs. You are given a user prompt, and determine how, if it all, to change the users image. Do you understand?
AGENT: Yes, I understand.
SYSTEM: Great. Below there will be a set of examples. For each example, you will be given the users prompt and the expected output. You will be then given a new user prompt and have to generate a new output. Do not output anything beyond the prompt. Do you understand?
AGENT: Yes.
SYSTEM: Another rule, do not list numbers in your response.
AGENT: Yes, I understand.

PROMPT: A picture of a cat.
AGENT: masterpiece:2, Profession Photograph, A cat sitting on a chair, Orange, White whiskers

PROMPT: The beach
AGENT: masterpiece:3, A beautiful beach, Palm trees, blue sky, white clouds, clear sky

PROMPT: A picture of a dog.
AGENT:"""
# Change the last prompt to whatever you want to test
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

In [278]:
# Generate text using the model
outputs = model.generate(
    **inputs,
    num_beams=5,
    max_new_tokens=50,
    early_stopping=True,
    no_repeat_ngram_size=2,
    # use random seed
    top_k=10,
    top_p=0.9,
    temperature=0.9,
    do_sample=True,
)
# Convert the generated token IDs back to text
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(generated_text)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SYSTEM: You are an agent tasked with creating better, more descriptive image prompts that result in higher quality outputs. You are given a user prompt, and determine how, if it all, to change the users image. Do you understand?\nAGENT: Yes, I understand.\nSYSTEM: Great. Below there will be a set of examples. For each example, you will be given the users prompt and the expected output. You will be then given a new user prompt and have to generate a new output. Do not output anything beyond the prompt. Do you understand?\nAGENT: Yes.\nSYSTEM: Another rule, do not list numbers in your response.\nAGENT: Yes, I understand.\n\nPROMPT: A picture of a cat.\nAGENT: masterpiece:2, Profession Photograph, A cat sitting on a chair, Orange, White whiskers\n\nPROMPT: The beach\nAGENT: masterpiece:3, A beautiful beach, Palm trees, blue sky, white clouds, clear sky\n\nPROMPT: A picture of a dog.\nAGENT: masterpice:4, Dog, Blue eyes, brown fur, black nose, red collar, sitting in a grassy field, looki

In [279]:
def filter_result(result):
    result = result.replace("\n", " ")
    result = result[result.rfind("AGENT:") + 7:]
    result = result.split("  ")[0]
    result = result.split("```")[0]
    result = result.strip()
    return result    

print(filter_result(generated_text[0]))


masterpice:4, Dog, Blue eyes, brown fur, black nose, red collar, sitting in a grassy field, looking at the camera, the sun is shining in the background, it's a beautiful day
