<a href="https://colab.research.google.com/github/MichaelPaulukonis/notebooks/blob/main/%5BPublic%5D_Prompt_Parrot_v1_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Welcome to Prompt Parrot! 

This notebook trains GPT-2 on a list of your prompts, then parrots new prompts back at you. They could be brilliant. They could be hilarious. Either way, it'll be fun!

This notebook is created and maintained by [Stephen Young](https://twitter.com/KyrickYoung) (SteveTheNinja#0616 on Discord).

# Instructions

1. Create a text file with all of your prompts!

    **Prompt Parrot requires a minimum of 15-20 prompts to work.** However, more than 50 prompts is ideal! The more the merrier.

    Example prompt file:
    ```
    a beautiful painting of a mountain by Tyler Edlin, trending on artstation
    a stunning painting of Chicago, oil and canvas
    cyberpunk hoover dam, cgsociety, artstation, highly detailed matte painting
    ```

2. Next, upload your file to the session like so:

    ![upload_icon.png](https://drive.google.com/uc?id=14Is1KFTqpRwqeCWuIPC3StQsAbmG6als)

3. Finally, copy the path to your uploaded file and paste it into the `prompts_file_path` field in cell 1.1 below!

    ![upload_icon.png](https://drive.google.com/uc?id=1sjvjMc3mUkmbNgh3SkRd616A_Gvol5zP)
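
Before uploading, it can help to check that the file has enough prompts. The sketch below is one possible way to do that in plain Python; `load_prompts`, `describe_prompt_count`, and the two thresholds are hypothetical names chosen for illustration, with the numbers taken from the guidance in step 1.

```python
import os
import tempfile
from pathlib import Path

MIN_PROMPTS = 15    # assumption: lower end of the "15-20 prompts" minimum
IDEAL_PROMPTS = 50  # more than this is described as ideal


def load_prompts(path):
    """Read one prompt per line, skipping blank lines."""
    lines = Path(path).read_text(encoding="utf-8").splitlines()
    return [line.strip() for line in lines if line.strip()]


def describe_prompt_count(count):
    """Classify a prompt count against the thresholds above."""
    if count < MIN_PROMPTS:
        return "too few"
    if count <= IDEAL_PROMPTS:
        return "workable"
    return "ideal"


# Demo on a throwaway file; point load_prompts at your own file instead.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "prompts.txt")
    Path(demo_path).write_text(
        "a stunning painting of Chicago, oil and canvas\n\ncyberpunk hoover dam\n",
        encoding="utf-8",
    )
    demo_prompts = load_prompts(demo_path)

print(len(demo_prompts), describe_prompt_count(len(demo_prompts)))
```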

# 0. Setup

In [None]:
#@title 0.1 GPU Check

!nvidia-smi -L

In [None]:
#@title 0.2 Install Dependencies
!pip install transformers datasets

In [None]:
#@title 0.3 Imports

import random
import os
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TextDataset, DataCollatorForLanguageModeling,
    Trainer, TrainingArguments
)
import datasets
from google.colab import files

# no need for caching with tiny datasets. More trouble than it's worth
datasets.disable_caching()

# 1. Do the run



In [None]:
#@title 1.1 Prompts.txt

prompts_file_path = "/content/prompts.txt" #@param{type:"string"}

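all_prompts = []
if prompts_file_path:
    with open(prompts_file_path, encoding="utf-8") as infile:
        all_prompts = infile.read().split("\n")
else:
    # no path given: fall back to Colab's interactive upload dialog
    uploaded = files.upload()
    _, all_prompts = list(uploaded.items())[0]
    all_prompts = all_prompts.decode("UTF-8").split("\n")

# drop blank lines; "".split("\n") returns [""], which would slip past the check below
all_prompts = [p.strip() for p in all_prompts if p.strip()]

if not all_prompts:
    raise UserWarning(f"Read 0 prompts from {prompts_file_path or 'the uploaded file'}")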

# candidate starting phrases: the first two words of each multi-word prompt, commas removed
prompt_starts = list(set([" ".join(p.split()[0:2]).replace(",", "") for p in all_prompts if len(p.split()) > 1]))
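
To see what the starting phrases look like, here is a small stand-alone sketch of the same idea: take the first two words of every prompt that has at least two words, and strip commas. `extract_prompt_starts` is a hypothetical helper name, and the results are sorted only so the output is stable.

```python
def extract_prompt_starts(prompts):
    """Collect the unique two-word openings of multi-word prompts."""
    starts = set()
    for prompt in prompts:
        words = prompt.split()
        if len(words) > 1:  # single-word prompts have no two-word opening
            starts.add(" ".join(words[:2]).replace(",", ""))
    return sorted(starts)


examples = [
    "a beautiful painting of a mountain by Tyler Edlin, trending on artstation",
    "a stunning painting of Chicago, oil and canvas",
    "cyberpunk hoover dam, cgsociety, artstation, highly detailed matte painting",
    "portrait, oil on canvas",
    "sunset",
]
print(extract_prompt_starts(examples))
# ['a beautiful', 'a stunning', 'cyberpunk hoover', 'portrait oil']
```

The one-word prompt contributes nothing, and the comma after "portrait" is dropped.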

In [None]:
#@title 1.2 Train GPT-2
#@markdown Number of epochs (full passes over your prompts) to train for. Try 50-500? If the generated prompts are identical to your input prompts, try turning down num_train_epochs.
num_train_epochs=50 #@param{type:"integer"}

end_token = "<|endoftext|>"
prompts_txt = "scrambled_prompts.txt"

# scramble the prompts so the model doesn't learn association between lines
with open(prompts_txt, "w+") as fp:
    for _ in range(4):
        random.shuffle(all_prompts)
        fp.write(end_token.join(all_prompts) + end_token)

# quick and dirty workaround to blow away the cache for now
# TODO: upgrade to huggingface datasets lib. TextDataset is deprecated
!rm -f /content/cached_lm_GPT2TokenizerFast*
!rm -rf /content/output/

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
train_dataset = TextDataset(tokenizer=tokenizer, file_path=prompts_txt, block_size=tokenizer.model_max_length)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir='./output',
    overwrite_output_dir=True,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=1,
    prediction_loss_only=True,
    logging_steps=100,
    save_steps=0,
    seed=random.randint(0, 2**32-1),
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

trainer.train()
# model.save_pretrained("./model")
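
Cell 1.2 builds its training text by shuffling the prompts, joining them with the end-of-text token, and repeating that four times. The sketch below reproduces that step on its own so the resulting text can be inspected. `build_training_text` and `count_full_blocks` are hypothetical helpers. The block count rests on two assumptions that are not stated in this notebook: that the dataset keeps only complete blocks of `block_size` tokens, and that `tokenizer.model_max_length` is 1024.

```python
import random

END_TOKEN = "<|endoftext|>"


def build_training_text(prompts, repeats=4, seed=None):
    """Shuffle the prompts `repeats` times and join each pass with END_TOKEN."""
    rng = random.Random(seed)
    shuffled = list(prompts)  # copy so the caller's list is left untouched
    passes = []
    for _ in range(repeats):
        rng.shuffle(shuffled)
        passes.append(END_TOKEN.join(shuffled) + END_TOKEN)
    return "".join(passes)


def count_full_blocks(num_tokens, block_size=1024):
    """Assumption: only complete blocks of block_size tokens become examples."""
    return num_tokens // block_size


demo_prompts = [
    "a stunning painting of Chicago, oil and canvas",
    "cyberpunk hoover dam, cgsociety, artstation, highly detailed matte painting",
]
text = build_training_text(demo_prompts, seed=0)
print(text.count(END_TOKEN))  # 8: two prompts times four passes
```

Under those assumptions, a prompt file that tokenizes to fewer than `block_size` tokens after the four passes would yield zero training blocks, which may be why a minimum number of prompts is needed.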

In [None]:
#@title 1.3 Generate prompts!

#@markdown _What?! A prompt for my prompts?!?_ **PREPOSTEROUS!!** 
#@markdown
#@markdown GPT-2 requires a prompt string to start a sentence. For best results, use text that resembles part of your prompts. It can be a single letter or a full sentence.
#@markdown `If you leave this blank, Prompt Parrot will automatically choose starting phrases from your prompts`
prompt_override="bugs" #@param{type:"string"}
num_prompts=15 #@param{type:"integer"}


#@markdown Maximum and minimum length of the generated prompts, in tokens. Output may be cut off mid-word. This is expected behavior.
max_length=60 #@param{type:"integer"}
min_length=10 #@param{type:"integer"}
#@markdown `temperature`: If you find your prompts are too similar to your inputs, try turning up the temperature. If your prompts are too insane, turn down the temperature. A good default is 1.6.
temperature=2.0 #@param{type:"number"}
#@markdown `top_k`: If you find your prompts are identical to your inputs, try turning down top_k. A good range is 70-100. A good default is 100.
top_k=70 #@param{type:"integer"}
top_p=0.9  #@param{type:"number"}

# use the override when one is given; otherwise pick a random starting phrase from your prompts
prompt = prompt_override if prompt_override else random.choice(prompt_starts)

encoded_prompt = tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
encoded_prompt = encoded_prompt.to(model.device)

output_sequences = model.generate(
    input_ids=encoded_prompt,
    max_length=max_length,
    min_length=min_length,
    temperature=temperature,
    top_k=top_k,
    top_p=top_p,
    do_sample=True,
    num_return_sequences=num_prompts,
    pad_token_id=tokenizer.eos_token_id # gets rid of warning
    )

for generated_sequence in output_sequences:
    generated_sequence = generated_sequence.tolist()
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)
    print(text.strip().replace("\n", " ").replace("/", ","))
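
The `temperature`, `top_k`, and `top_p` settings in cell 1.3 all reshape the distribution the next token is drawn from. The toy sketch below, in plain Python, illustrates the general idea on four made-up scores. It is a simplified illustration, not the code that `model.generate` runs, and `softmax` and `filter_candidates` are hypothetical helper names.

```python
import math


def softmax(logits, temperature=1.0):
    """Turn raw scores into probabilities; higher temperature flattens them."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def filter_candidates(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return {token_index: probability} for tokens that survive the filters."""
    probs = softmax(logits, temperature)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        ranked = ranked[:top_k]  # keep only the k most likely tokens
    kept_mass = sum(probs[i] for i in ranked)
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i] / kept_mass
        if cumulative >= top_p:  # smallest set whose mass reaches top_p
            break
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


toy_logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for four tokens
for temp in (0.5, 1.0, 2.0):
    print(temp, [round(p, 3) for p in softmax(toy_logits, temp)])
print(filter_candidates(toy_logits, top_k=2))
print(filter_candidates(toy_logits, top_p=0.5))
```

Raising the temperature spreads probability toward less likely tokens, while lowering `top_k` or `top_p` shrinks the set of tokens that can be picked at all.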