# Refine txt2img Prompts with Human Feedback


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CarperAI/trlx/blob/main/examples/notebooks/trlx_simulacra.ipynb)


#### Optimize a GPT-2-based txt2img prompt generator to produce aesthetic prompts using the [Simulacra Aesthetic Captions](https://github.com/JD-P/simulacra-aesthetic-captions) dataset

Notebook by [@smellslikeml](https://github.com/smellslikeml)

---

Run the cells below to install [TRLX](https://github.com/CarperAI/trlx) in a Colab environment.

In [None]:
!git clone https://github.com/CarperAI/trlx.git
!git config --global --add safe.directory /content/trlx && cd /content/trlx && pip install -e .

Cloning into 'trlx'...
remote: Enumerating objects: 4652, done.
remote: Counting objects: 100% (457/457), done.
remote: Compressing objects: 100% (202/202), done.
remote: Total 4652 (delta 299), reused 380 (delta 252), pack-reused 4195
Receiving objects: 100% (4652/4652), 46.06 MiB | 27.26 MiB/s, done.
Resolving deltas: 100% (2863/2863), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Obtaining file:///content/trlx
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done
Collecting transformers>=4.21.2
  Downloading transformers-4.26.0-py3-none-any.whl (6.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.3/6.3 MB 53.2 MB/s eta 0:00:00

In [None]:
# Uninstall scikit-learn and JAX to avoid NumPy version conflicts
!pip uninstall -y scikit_learn jax

Found existing installation: scikit-learn 1.0.2
Uninstalling scikit-learn-1.0.2:
  Successfully uninstalled scikit-learn-1.0.2
Found existing installation: jax 0.3.25
Uninstalling jax-0.3.25:
  Successfully uninstalled jax-0.3.25


In [None]:
import os

# Run from the repository root
os.chdir('/content/trlx')
print(os.getcwd())

/content/trlx


In [None]:
import os
import sqlite3
from urllib.request import urlretrieve

import trlx

# Simulacra Aesthetic Captions (SAC) public database
url = "https://raw.githubusercontent.com/JD-P/simulacra-aesthetic-captions/main/sac_public_2022_06_29.sqlite"
dbpath = "sac_public_2022_06_29.sqlite"

# Download the database only once
if not os.path.exists(dbpath):
    print(f"fetching {dbpath}")
    urlretrieve(url, dbpath)

# Join each rating to its image, and each image to the prompt that generated it
conn = sqlite3.connect(dbpath)
c = conn.cursor()
c.execute(
    "SELECT prompt, rating FROM ratings "
    "JOIN images ON images.id=ratings.iid "
    "JOIN generations ON images.gid=generations.id "
    "WHERE rating IS NOT NULL;"
)

# Transpose [(prompt, rating), ...] into two parallel lists
prompts, ratings = tuple(map(list, zip(*c.fetchall())))
conn.close()

fetching sac_public_2022_06_29.sqlite
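To see what the query and the `zip(*...)` transpose produce without downloading the database, here is a minimal sketch on an in-memory SQLite database. The three-table schema is a hypothetical miniature inferred only from the JOINs above (the real SAC database has more tables and columns), and the rows are made up for illustration. The variables are named `demo_prompts` and `demo_ratings` so they do not overwrite the notebook's `prompts` and `ratings`.

```python
import sqlite3
from statistics import mean

# Hypothetical miniature of the SAC schema, inferred from the JOINs above.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE generations (id INTEGER PRIMARY KEY, prompt TEXT);
    CREATE TABLE images (id INTEGER PRIMARY KEY, gid INTEGER);
    CREATE TABLE ratings (iid INTEGER, rating INTEGER);
    INSERT INTO generations VALUES (1, 'a red fox'), (2, 'a castle at dusk');
    INSERT INTO images VALUES (10, 1), (11, 1), (12, 2);
    INSERT INTO ratings VALUES (10, 7), (11, 4), (12, NULL), (12, 9);
    """
)

# Same query as above, with ORDER BY added so the demo output is deterministic
rows = conn.execute(
    "SELECT prompt, rating FROM ratings "
    "JOIN images ON images.id=ratings.iid "
    "JOIN generations ON images.gid=generations.id "
    "WHERE rating IS NOT NULL "
    "ORDER BY ratings.rowid;"
).fetchall()
conn.close()

# zip(*rows) turns [(prompt, rating), ...] into (prompts, ratings)
demo_prompts, demo_ratings = tuple(map(list, zip(*rows)))
print(demo_prompts)  # ['a red fox', 'a red fox', 'a castle at dusk']
print(demo_ratings)  # [7, 4, 9]

# Mean and range of the rewards, printed in the same "mean ∈ [min, max]" style
print(f"{mean(demo_ratings):.2f} ∈ [{min(demo_ratings)}, {max(demo_ratings)}]")  # 6.67 ∈ [4, 9]
```

Two things follow from the query's shape: a prompt appears once per rating it received, so popular prompts repeat in `prompts`, and unrated images are dropped by `WHERE rating IS NOT NULL`. Assuming each sample carries a single terminal reward, its return equals its rating, which would make the last line comparable to the `[Mean return]` summary trlx prints during training.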


trlx uses [wandb](https://wandb.ai/) to log results. Set up an account, then paste your API key when prompted after running the `trlx.train` cell below.
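If you would rather not paste a key interactively, wandb also reads its settings from environment variables. Setting `WANDB_API_KEY` authenticates without the prompt; the minimal sketch below instead uses `WANDB_MODE="offline"` to keep runs on local disk without an account. The notebook filename assigned to `WANDB_NOTEBOOK_NAME` is an assumption; it is the variable named in the "Failed to detect the name of this notebook" message that appears in the training output.

```python
import os

# Keep wandb runs on local disk instead of syncing to wandb.ai.
# Use "disabled" to turn logging off entirely; offline runs can be
# uploaded later with the `wandb sync` command.
os.environ["WANDB_MODE"] = "offline"

# Tell wandb which notebook is running so it can save code.
# The filename here is an assumption; use your own notebook's name.
os.environ["WANDB_NOTEBOOK_NAME"] = "trlx_simulacra.ipynb"
```

Run it in a cell before `trlx.train`, since wandb reads these variables when the run is initialized.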

In [None]:
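# Passing fixed `samples` with scalar `rewards` (instead of a reward function)
# trains offline on the rated prompts; the output shows an ILQL trainer is used.
trlx.train(
    "gpt2",                                         # base model to fine-tune
    samples=prompts,                                # SAC prompts
    rewards=ratings,                                # human aesthetic rating for each prompt
    eval_prompts=["Hatsune Miku, Red Dress"] * 64,  # 64 copies of one evaluation prompt
)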


ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.



[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 

··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[Sample example]
Prompt:  <|endoftext|>
Response:  An artwork of a broken wine bottle in the medium of dry pigments<|endoftext|>
Reward:  7
[Mean prompt length] 1.00 ∈ [1, 1]
[Mean output length] 25.64 ∈ [1, 63]
[Mean sample length] 26.64 ∈ [2, 64]
[Mean return] 5.53 ∈ [1.0, 10.0]


You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


losses/loss: 5.24, losses/loss_q: 0.27, losses/loss_v: 0.06, losses/loss_cql: 18.15, losses/loss_awac: 3.08:  10%|▉         | 99/1000 [01:43<15:45,  1.05s/it]

losses/loss: 4.92, losses/loss_q: 0.28, losses/loss_v: 0.04, losses/loss_cql: 16.81, losses/loss_awac: 2.92:  20%|█▉        | 199/1000 [03:29<14:03,  1.05s/it]

losses/loss: 4.75, losses/loss_q: 0.37, losses/loss_v: 0.04, losses/loss_cql: 15.73, losses/loss_awac: 2.76:  30%|██▉       | 299/1000 [05:15<12:19,  1.05s/it]

losses/loss: 4.40, losses/loss_q: 0.35, losses/loss_v: 0.04, losses/loss_cql: 14.54, losses/loss_awac: 2.55:  40%|███▉      | 399/1000 [07:01<10:22,  1.04s/it]

losses/loss: 4.15, losses/loss_q: 0.25, losses/loss_v: 0.04, losses/loss_cql: 14.34, losses/loss_awac: 2.43:  50%|████▉     | 499/1000 [08:47<08:47,  1.05s/it]

losses/loss: 4.01, losses/loss_q: 0.26, losses/loss_v: 0.04, losses/loss_cql: 13.39, losses/loss_awac: 2.37:  60%|█████▉    | 599/1000 [10:32<06:59,  1.05s/it]

losses/loss: 3.94, losses/loss_q: 0.23, losses/loss_v: 0.04, losses/loss_cql: 13.20, losses/loss_awac: 2.35:  70%|██████▉   | 699/1000 [12:18<05:16,  1.05s/it]

losses/loss: 3.59, losses/loss_q: 0.31, losses/loss_v: 0.04, losses/loss_cql: 12.04, losses/loss_awac: 2.05:  80%|███████▉  | 799/1000 [14:04<03:29,  1.04s/it]

losses/loss: 3.81, losses/loss_q: 0.18, losses/loss_v: 0.04, losses/loss_cql: 12.44, losses/loss_awac: 2.34:  90%|████████▉ | 899/1000 [15:50<01:46,  1.05s/it]

losses/loss: 3.34, losses/loss_q: 0.27, losses/loss_v: 0.05, losses/loss_cql: 11.13, losses/loss_awac: 1.92: 100%|█████████▉| 999/1000 [17:36<00:01,  1.04s/it]

losses/loss: 3.47, losses/loss_q: 0.27, losses/loss_v: 0.04, losses/loss_cql: 11.21, losses/loss_awac: 2.04: 100%|██████████| 1000/1000 [17:48<00:00,  4.36s/it]

losses/loss: 3.47, losses/loss_q: 0.27, losses/loss_v: 0.04, losses/loss_cql: 11.21, losses/loss_awac: 2.04: 100%|██████████| 1000/1000 [17:58<00:00,  1.08s/it]


<trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer at 0x7fb1b0530370>
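The `losses/loss` value in the progress bars is consistent with a weighted sum of the four ILQL terms it is logged beside (Q-value, value, conservative CQL regularizer, and advantage-weighted AWAC language-modeling loss). The check below assumes weights `cql_scale = 0.1` and `awac_scale = 1.0`. These are believed to be trlx's ILQL defaults at the time, but treat them as an assumption and compare against your own config. The numbers are copied from the log above.

```python
# (loss, loss_q, loss_v, loss_cql, loss_awac) copied from the progress bars above
logged = [
    (5.24, 0.27, 0.06, 18.15, 3.08),
    (4.92, 0.28, 0.04, 16.81, 2.92),
    (4.75, 0.37, 0.04, 15.73, 2.76),
    (4.40, 0.35, 0.04, 14.54, 2.55),
    (4.15, 0.25, 0.04, 14.34, 2.43),
    (4.01, 0.26, 0.04, 13.39, 2.37),
    (3.94, 0.23, 0.04, 13.20, 2.35),
    (3.59, 0.31, 0.04, 12.04, 2.05),
    (3.81, 0.18, 0.04, 12.44, 2.34),
    (3.34, 0.27, 0.05, 11.13, 1.92),
    (3.47, 0.27, 0.04, 11.21, 2.04),
]

# Assumed loss weights (check your ILQL config; these are not read from trlx)
CQL_SCALE = 0.1
AWAC_SCALE = 1.0


def weighted_total(q, v, cql, awac):
    """Weighted sum of the four ILQL loss terms under the assumed scales."""
    return q + v + CQL_SCALE * cql + AWAC_SCALE * awac


for loss, q, v, cql, awac in logged:
    recon = weighted_total(q, v, cql, awac)
    print(f"logged {loss:.2f}  reconstructed {recon:.3f}  diff {loss - recon:+.3f}")
```

Every reconstructed total lands within about 0.015 of the logged value, which is what rounding each displayed component to two decimals allows. It also shows that the AWAC term accounts for most of the total, and that the CQL term's drop from 18.15 to 11.21 contributes only a tenth of its size under this weighting.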