# Run the SuperPrompt-v1 AI Model with an ipywidgets UI

Make your prompts better, for AI art or in general!

Used Model: https://huggingface.co/roborovski/superprompt-v1

Model blog post: https://brianfitzgerald.xyz/prompt-augmentation/

Google Colab notebook made by [Nick088](https://linktr.ee/Nick088), using an ipywidgets UI (which is allowed on Google Colab).

In [None]:
#@title Install & Load Dependencies, Model

#@markdown If you wanna use CPU (slower, no daily limit): Set the CPU from Edit -> Notebook Settings -> CPU

#@markdown If you wanna use GPU (faster, max 12 free hours daily limit): Set the Video Card from Edit -> Notebook Settings -> T4 GPU OR Any other GPUs based on your Google Colab Subscription

#@markdown Anyways its a very small model, it doesn't matter much if you use cpu or gpu.

!pip install transformers
!pip install einops
!pip install accelerate
!pip install sentencepiece
import torch
from IPython.display import clear_output
from transformers import T5Tokenizer, T5ForConditionalGeneration
import random

if torch.cuda.is_available():
    device = "cuda"
    print("Using GPU")
else:
    device = "cpu"
    print("Using CPU")

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype=torch.float16) # torch.float16 is basically fp16, so model precision in 16 bits which is faster and less resource consuming, you could also put torch.float32 which is fp32 that lods it in 32bits which is more precise but slower and more resource consuming

model.to(device)

# ipywidgets ui
!pip install ipywidgets
!jupyter nbextension enable --py widgetsnbextension
from ipywidgets import widgets
from ipywidgets import Layout

clear_output()
print(f"Downloaded & SuperPrompt-v1 on {'GPU' if device == 'cuda' else 'CPU'}")

In [None]:
#@title Run ipywidgets UI

# Define the function to generate text
def generate_text(system_prompt, your_prompt, max_new_tokens, repetition_penalty, temperature, top_p, top_k):
    """Expand a prompt with the SuperPrompt model and return the generated text.

    The system prompt and the user prompt are joined as
    ``"<system_prompt>, <your_prompt>"``, encoded with the module-level
    ``tokenizer``, and passed to the module-level ``model``.

    Args:
        system_prompt: Instruction text placed before the user's prompt.
        your_prompt: The prompt to be expanded.
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Penalty for repeated tokens; must be > 0.
        temperature: Sampling temperature. Values <= 0 switch to greedy
            decoding, because sampling requires a strictly positive temperature.
        top_p: Nucleus-sampling probability mass (0..1); only used when sampling.
        top_k: Number of highest-probability tokens kept; only used when sampling.

    Returns:
        The decoded output with special tokens such as ``<pad>`` and ``</s>``
        removed and surrounding whitespace stripped.
    """
    full_prompt = f"{system_prompt}, {your_prompt}"
    # The model may live on the GPU; the inputs must be on the same device or
    # generate() fails with a device-mismatch error.
    input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(model.device)

    if temperature > 0:
        sampling_kwargs = dict(do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k)
    else:
        # A temperature of 0 is rejected when sampling; treat it as "deterministic".
        sampling_kwargs = dict(do_sample=False)

    outputs = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        **sampling_kwargs,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()


# style to fix too long descriptions (lets the label take its natural width)
style = {'description_width': 'initial'}


# Create the your prompt widget: the prompt the user wants expanded
your_prompt_widget = widgets.Text(
    value="A storefront with 'Text to Image' written on it.",
    placeholder='Type your prompt here',
    description='Your Prompt:',
    disabled=False,
    style=style,
    layout=Layout(width='480px', height='50px')
)

# Create the system prompt widget: instruction text prepended to the user's prompt
system_prompt_widget = widgets.Text(
    value="Expand with as much details as possible the prompt:",
    placeholder='Type the system prompt here',
    description='System Prompt (Prompt to stylize the AI):',
    disabled=False,
    style=style,
    layout=Layout(width='600px', height='50px')
)

# Create the max_new_tokens slider
max_new_tokens_widget = widgets.IntSlider(
    value=512,
    min=250,
    max=512,
    step=1,
    description='Max New Tokens (Maximum number of the tokens to generate, controls how long is the text):',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
    style=style,
    layout=Layout(width='800px', height='50px')
)

# Create the repetition_penalty slider.
# transformers requires a strictly positive penalty, so the slider starts at
# one step above 0; values below 1.0 make repetition *more* likely.
repetition_penalty_widget = widgets.FloatSlider(
    value=1.2,
    min=0.05,
    max=2.0,
    step=0.05,
    description='Repetition Penalty (Penalize repeated tokens, so makes the AI repeat less of itself):',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
    style=style,
    layout=Layout(width='800px', height='50px')
)

# Create the temperature slider.
# Sampling (do_sample=True) rejects a temperature of 0, so the minimum is one step above 0.
temperature_widget = widgets.FloatSlider(
    value=0.5,
    min=0.05,
    max=1.00,
    step=0.05,
    description='Temperature (Higher values produce more diverse outputs):',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
    style=style,
    layout=Layout(width='700px', height='50px')
)

# Create the top_p slider.
# top_p is a cumulative probability, so it must lie in [0, 1]; larger values are rejected.
top_p_widget = widgets.FloatSlider(
    value=1.0,
    min=0.0,
    max=1.0,
    step=0.05,
    description='Top P (Higher values sample more low-probability tokens):',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
    style=style,
    layout=Layout(width='600px', height='50px')
)

# Create the top_k slider.
# NOTE: with the default of 1, only the single most likely token is kept,
# so sampling is effectively deterministic until this is raised.
top_k_widget = widgets.IntSlider(
    value=1,
    min=1,
    max=100,
    step=1,
    description='Top K (Higher k means more diverse outputs by considering a range of tokens):',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
    style=style,
    layout=Layout(width='700px', height='50px')
)

# Create the seed input (0 is treated as "pick a random seed" by the click handler).
# IntText is a plain number box, so slider-only options
# (orientation/readout/readout_format) do not apply to it.
seed_widget = widgets.IntText(
    value=42,
    description='Seed (Starting point to initiate the generation process, put 0 for random):',
    disabled=False,
    continuous_update=False,
    style=style,
    layout=Layout(width='480px', height='50px')
)

# Create the output widget: the area where status messages and results are printed
output_widget = widgets.Output()

# Define the function to handle button click
def on_button_clicked(b):
    """Generate an improved prompt from the current widget values and display it.

    Used as the click callback of the "Generate Better Prompt" button. ``b``
    is the button instance passed by ipywidgets and is not used. All printed
    output is captured by ``output_widget``.
    """
    with output_widget:
        output_widget.clear_output(wait=True)
        print("Generating a better version of your prompt...")
        system_prompt = system_prompt_widget.value
        your_prompt = your_prompt_widget.value
        max_new_tokens = max_new_tokens_widget.value
        repetition_penalty = repetition_penalty_widget.value
        temperature = temperature_widget.value
        top_p = top_p_widget.value
        top_k = top_k_widget.value
        seed = seed_widget.value
        # A seed of 0 means "random": draw a fresh one so each click differs.
        if seed == 0:
            seed = random.randint(1, 100000)
        # Seed torch's RNG so that sampling is reproducible for a given seed.
        torch.manual_seed(seed)
        generated_text = generate_text(system_prompt, your_prompt, max_new_tokens, repetition_penalty, temperature, top_p, top_k)
        # Replace the "Generating..." status line with the result.
        output_widget.clear_output(wait=True)
        print(generated_text)


# The button that starts generation; each click is dispatched to on_button_clicked.
button = widgets.Button(
    description="Generate Better Prompt",
    layout=Layout(width='400px', height='50px'),
)
button.on_click(on_button_clicked)

# Settings controls, shown top to bottom in this order.
_settings_controls = [
    system_prompt_widget,
    your_prompt_widget,
    max_new_tokens_widget,
    repetition_penalty_widget,
    temperature_widget,
    top_p_widget,
    top_k_widget,
    seed_widget,
]

# Stack the settings, then the button, then the result area in one vertical box.
ui = widgets.VBox(children=[*_settings_controls, button, output_widget])

# Leaving `ui` as the cell's final expression makes the notebook render it.
ui
