# Refine txt2img Prompts with Human Feedback


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CarperAI/trlx/blob/main/examples/notebooks/trlx_simulacra.ipynb)


#### Optimize a GPT-2-based txt2img prompt generator to produce aesthetic prompts, using the [Simulacra Aesthetic Captions](https://github.com/JD-P/simulacra-aesthetic-captions) dataset of (prompt, rating) pairs

Notebook by [@smellslikeml](https://github.com/smellslikeml)

---

Run the cells below to install [TRLX](https://github.com/CarperAI/trlx) in a Colab environment.

In [2]:
!pip install git+https://github.com/CarperAI/trlx

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/CarperAI/trlx
  Cloning https://github.com/CarperAI/trlx to /tmp/pip-req-build-vj6wb17d
  Running command git clone --filter=blob:none --quiet https://github.com/CarperAI/trlx /tmp/pip-req-build-vj6wb17d
  Resolved https://github.com/CarperAI/trlx to commit 093e89a455b3b999506cd2c6b1f8e5c9fd828613
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.17.1 (from trlx==0.6.0)
  Using cached accelerate-0.20.3-py3-none-any.whl (227 kB)
Installing collected packages: accelerate
  Attempting uninstall: accelerate
    Found existing installation: accelerate 0.16.0
    Uninstalling accelerate-0.16.0:
      Successfully uninstalled accelerate-0.16.0
Successfully installe

In [3]:
# Note: trlx 0.6.0 requires accelerate>=0.17.1, so pip flags this pin as a dependency conflict (see the output below)
!pip install accelerate==0.16.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting accelerate==0.16.0
  Using cached accelerate-0.16.0-py3-none-any.whl (199 kB)
Installing collected packages: accelerate
  Attempting uninstall: accelerate
    Found existing installation: accelerate 0.20.3
    Uninstalling accelerate-0.20.3:
      Successfully uninstalled accelerate-0.20.3
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
trlx 0.6.0 requires accelerate>=0.17.1, but you have accelerate 0.16.0 which is incompatible.
Successfully installed accelerate-0.16.0

In [None]:
!pip install auto-gptq
!pip install einops

In [None]:
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM

# 4-bit GPTQ-quantized Falcon-7B-Instruct checkpoint on the Hugging Face Hub
model_name_or_path = "TheBloke/falcon-7b-instruct-GPTQ"
# You could also download the model locally and load it from there:
# model_name_or_path = "/path/to/TheBloke_falcon-7b-instruct-GPTQ"

# Base name of the quantized weights file inside the repo
model_basename = "gptq_model-4bit-64g"

use_triton = False

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

model = AutoGPTQForCausalLM.from_quantized(
    model_name_or_path,
    model_basename=model_basename,
    use_safetensors=True,
    trust_remote_code=True,  # needed for checkpoints that ship custom modeling code
    device="cuda:0",
    use_triton=use_triton,
    quantize_config=None,
)


In [4]:
import os
import sqlite3
from urllib.request import urlretrieve

import trlx

url = "https://raw.githubusercontent.com/JD-P/simulacra-aesthetic-captions/main/sac_public_2022_06_29.sqlite"
dbpath = "sac_public_2022_06_29.sqlite"

if not os.path.exists(dbpath):
    print(f"fetching {dbpath}")
    urlretrieve(url, dbpath)

conn = sqlite3.connect(dbpath)
c = conn.cursor()
c.execute(
    "SELECT prompt, rating FROM ratings "
    "JOIN images ON images.id=ratings.iid "
    "JOIN generations ON images.gid=generations.id "
    "WHERE rating IS NOT NULL;"
)

prompts, ratings = tuple(map(list, zip(*c.fetchall())))

[2023-06-18 17:20:40,513] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
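The query above walks from each rating to its image (`ratings.iid` → `images.id`) and from the image to the generation that produced it (`images.gid` → `generations.id`), keeping only rows whose rating is not NULL. Below is a minimal, self-contained sketch of the same join against an in-memory database. The three tables and their few columns are a hypothetical reduction of the real schema (only the columns the query touches), and the `demo_*` names are chosen so running it won't overwrite the notebook's `prompts` and `ratings`.

```python
import sqlite3

# Hypothetical minimal schema: only the columns used by the notebook's query.
demo_conn = sqlite3.connect(":memory:")
demo_cur = demo_conn.cursor()
demo_cur.executescript(
    """
    CREATE TABLE generations (id INTEGER PRIMARY KEY, prompt TEXT);
    CREATE TABLE images (id INTEGER PRIMARY KEY, gid INTEGER);
    CREATE TABLE ratings (iid INTEGER, rating INTEGER);
    INSERT INTO generations VALUES (1, 'a red fox'), (2, 'a blue whale');
    INSERT INTO images VALUES (10, 1), (11, 1), (20, 2);
    -- image 20 has one unrated row, which the WHERE clause drops
    INSERT INTO ratings VALUES (10, 7), (11, 9), (20, NULL), (20, 4);
    """
)

# Same query as the notebook: rating -> image -> generation (prompt text)
demo_cur.execute(
    "SELECT prompt, rating FROM ratings "
    "JOIN images ON images.id=ratings.iid "
    "JOIN generations ON images.gid=generations.id "
    "WHERE rating IS NOT NULL;"
)
rows = demo_cur.fetchall()
demo_conn.close()

# Transpose [(prompt, rating), ...] into two parallel lists
demo_prompts, demo_ratings = tuple(map(list, zip(*rows)))
print(demo_prompts, demo_ratings)
```

`zip(*rows)` transposes the list of `(prompt, rating)` rows into two columns. If the query ever returned no rows, unpacking into two names would raise `ValueError`, so a guard such as `if rows:` is worth adding when adapting this pattern.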


In [12]:
prompts[:100], ratings[:100]

(['An artwork of a broken wine bottle in the medium of dry pigments',
  'An artwork of a broken wine bottle in the medium of dry pigments',
  'An artwork of a broken wine bottle in the medium of dry pigments',
  'An artwork of a broken wine bottle in the medium of dry pigments',
  'An artwork of a broken wine bottle in the medium of dry pigments',
  'An artwork of a broken wine bottle in the medium of dry pigments',
  'An artwork of a broken wine bottle in the medium of dry pigments',
  'An artwork of a broken wine bottle in the medium of dry pigments',
  'An artwork of a broken wine bottle in the medium of dry pigments',
  'An artwork of a broken wine bottle in the medium of dry pigments',
  'An artwork of a broken wine bottle in the medium of dry pigments',
  'An artwork of a broken wine bottle in the medium of dry pigments',
  'An artwork of a broken wine bottle in the medium of dry pigments',
  'An artwork of a broken wine bottle in the medium of dry pigments',
  'An artwork of a b
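Because the query returns one row per rating, a prompt that was rated several times appears once per rating, which is why the output above repeats the same text. To inspect the data at the prompt level, a small helper can average the ratings for each distinct prompt. `mean_rating_per_prompt` is a hypothetical helper for exploration only; the training call below consumes the raw (prompt, rating) pairs.

```python
from collections import defaultdict
from statistics import mean


def mean_rating_per_prompt(prompts, ratings):
    """Group ratings by identical prompt text and return {prompt: mean rating}."""
    grouped = defaultdict(list)
    for prompt, rating in zip(prompts, ratings):
        grouped[prompt].append(rating)
    return {prompt: mean(values) for prompt, values in grouped.items()}


# Toy example; in the notebook you could call mean_rating_per_prompt(prompts, ratings)
example = mean_rating_per_prompt(["a fox", "a fox", "a whale"], [4, 6, 5])
print(example)
```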

In [5]:
from trlx.data.default_configs import default_ilql_config
config = default_ilql_config().evolve(train=dict(batch_size=32, total_steps=300))

TRLX uses [wandb](https://wandb.ai/) to log results. Make sure you have an account, and authenticate with your API token when prompted once the training cell runs.

In [3]:
config

TRLConfig(method=ILQLConfig(name='ilqlconfig', tau=0.7, gamma=0.99, cql_scale=0.1, awac_scale=1, alpha=0.001, beta=0, steps_for_target_q_sync=5, two_qs=True, gen_kwargs={'max_new_tokens': 56, 'top_k': 20, 'beta': 1, 'temperature': 1.0}), model=ModelConfig(model_path='gpt2', model_arch_type='causal', num_layers_unfrozen=-1, delta_kwargs=None), optimizer=OptimizerConfig(name='adamw', kwargs={'lr': 5e-05, 'betas': (0.9, 0.95), 'eps': 1e-08, 'weight_decay': 1e-06}), scheduler=SchedulerConfig(name='cosine_annealing', kwargs={'T_max': 1000000000000.0, 'eta_min': 5e-05}), tokenizer=TokenizerConfig(tokenizer_path='gpt2', padding_side='left', truncation_side='right'), train=TrainConfig(total_steps=300, seq_length=64, epochs=100, batch_size=32, checkpoint_interval=1000, eval_interval=100, pipeline='PromptPipeline', trainer='AccelerateILQLTrainer', trainer_kwargs={}, project_name='trlx', entity_name=None, group_name=None, checkpoint_dir='ckpts', rollout_logging_dir=None, save_best=True, save_opti
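Among the ILQL settings printed above, `tau=0.7` is the expectile used when fitting the value head against the Q-values. Below is a minimal sketch of the asymmetric squared (expectile) loss in plain Python, assuming the standard form from the IQL/ILQL papers, `L(u) = |tau - 1(u < 0)| * u**2` with `u = Q - V`. It illustrates the idea and is not trlX's implementation.

```python
def expectile_loss(q: float, v: float, tau: float = 0.7) -> float:
    """Asymmetric squared error between a Q-value target and a value estimate.

    u = q - v is weighted by tau when u >= 0 and by (1 - tau) when u < 0,
    so with tau > 0.5 the fitted V leans toward the higher Q-values.
    """
    u = q - v
    weight = tau if u >= 0 else 1.0 - tau
    return weight * u * u


def total_loss(v: float, targets, tau: float = 0.7) -> float:
    """Sum of expectile losses of one value estimate v against several targets."""
    return sum(expectile_loss(q, v, tau) for q in targets)


# With targets 0 and 1, tau = 0.7 puts the minimizer at v = 0.7, not the mean 0.5
for v in (0.5, 0.69, 0.7, 0.71):
    print(v, round(total_loss(v, [0.0, 1.0]), 5))
```

With `tau=0.5` the loss reduces to a scaled ordinary squared error and the minimizer is the mean; raising `tau` makes V track an upper expectile of the Q-targets.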

In [1]:
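# Train with ILQL on the (prompt, rating) pairs. The config above sets both
# model_path and tokenizer_path to "gpt2"; training a different checkpoint
# (e.g. "tiiuae/falcon-7b") would also need a matching tokenizer_path.
model = trlx.train(
    "gpt2",
    config=config,
    samples=prompts,
    rewards=ratings,
    eval_prompts=["<|endoftext|>"] * 64,  # unconditional generation from GPT-2's BOS token
).model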

NameError: ignored
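`trlx.train` is given one scalar rating per prompt string (`samples`, `rewards`), whereas ILQL learns token-level Q-values discounted by `gamma` (0.99 in the config above). Below is a minimal sketch of one common way to bridge the two, assuming the ratings are standardized and each is placed as a sparse reward on the last token of its sample. `terminal_rewards` and `discounted_returns` are hypothetical helpers that illustrate the idea, not trlX's internal code.

```python
from statistics import mean, pstdev


def terminal_rewards(ratings, lengths, eps=1e-8):
    """Standardize ratings and put each one on the last token of its sample.

    Assumes every sample has at least one token (length >= 1).
    """
    mu = mean(ratings)
    sigma = pstdev(ratings)
    rewards = []
    for rating, n_tokens in zip(ratings, lengths):
        per_token = [0.0] * n_tokens
        per_token[-1] = (rating - mu) / (sigma + eps)
        rewards.append(per_token)
    return rewards


def discounted_returns(rewards, gamma=0.99):
    """Return-to-go at every token: G_t = r_t + gamma * G_{t+1}."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


demo_rewards = terminal_rewards([2, 4, 6], [3, 1, 2])
print(demo_rewards)
print(discounted_returns(demo_rewards[0], gamma=0.99))
```

With `gamma=0.99` and a sparse terminal reward, earlier tokens see the final rating discounted by 0.99 per step, so the first token of a 20-token sample receives about `0.99**19` (roughly 0.83) of it.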

In [6]:
# Generate with the trained model, using the same (gpt2) tokenizer it was trained with
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer(["An astronaut riding a horse"] * 16, return_tensors="pt").to(0)
output = model.generate(**inputs)
tokenizer.batch_decode(output, skip_special_tokens=True)

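The `gen_kwargs` in the printed config (`beta`, `top_k=20`, `temperature=1.0`) control how the trained model samples. In the ILQL recipe, decoding perturbs the base model's log-probabilities with the learned advantage Q - V, scaled by `beta`. Below is a minimal single-step sketch over a toy vocabulary, assuming the adjusted score is `log_softmax(logits) + beta * (Q - V)` followed by top-k filtering and temperature scaling. `ilql_next_token_probs` and `sample_token` are hypothetical helpers, not trlX's generation code.

```python
import math
import random


def ilql_next_token_probs(logits, q_values, v_value, beta=1.0, top_k=20, temperature=1.0):
    """Return {token_id: probability} for one advantage-shifted decoding step."""
    # log_softmax of the base model's logits (max-shifted for numerical stability)
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    # Shift each token by beta * advantage, where advantage = Q(s, a) - V(s)
    scores = [lp + beta * (q - v_value) for lp, q in zip(log_probs, q_values)]
    # Keep only the top_k highest-scoring tokens
    keep = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
    # Temperature-scaled softmax over the kept tokens
    scaled = {i: scores[i] / temperature for i in keep}
    m = max(scaled.values())
    weights = {i: math.exp(s - m) for i, s in scaled.items()}
    total = sum(weights.values())
    return {i: w / total for i, w in weights.items()}


def sample_token(probs, rng):
    """Draw one token id from a {token_id: probability} mapping."""
    ids = list(probs)
    return rng.choices(ids, weights=[probs[i] for i in ids], k=1)[0]


rng = random.Random(0)
probs = ilql_next_token_probs(
    logits=[2.0, 1.0, 0.5, -1.0],
    q_values=[0.1, 0.9, 0.2, 0.0],
    v_value=0.3,
    beta=1.0,
    top_k=3,
)
print(probs, sample_token(probs, rng))
```

Setting `beta=0` recovers plain top-k sampling from the base model, while larger values tilt sampling toward tokens whose Q-value exceeds the state value.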

In [None]:
# Save the model locally
model.save_pretrained("gpt2-simulacra")

In [None]:
# To upload the model to Hugging Face, login first
from huggingface_hub import notebook_login
notebook_login()

Token is valid.
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [None]:
# Upload the model to <your_name>/gpt2-simulacra
from huggingface_hub import create_repo, HfApi

repo_id = create_repo("gpt2-simulacra", private=False, exist_ok=True).repo_id
HfApi().upload_folder(folder_path="gpt2-simulacra", repo_id=repo_id)

'https://huggingface.co/reciprocate/gpt2-simulacra/tree/main/'

In [None]:
# Load the same model now stored on Hugging Face
from trlx.models.modeling_ilql import AutoModelForCausalLMWithILQLHeads
hf_model = AutoModelForCausalLMWithILQLHeads.from_pretrained(repo_id)

Some weights of the model checkpoint at reciprocate/gpt2-simulacra were not used when initializing GPT2LMHeadModel: ['ilql_heads.target_q_heads.1.2.bias', 'ilql_heads.q_heads.0.2.bias', 'ilql_heads.target_q_heads.0.2.bias', 'ilql_heads.q_heads.0.0.bias', 'ilql_heads.v_head.2.weight', 'ilql_heads.q_heads.1.0.bias', 'ilql_heads.v_head.0.bias', 'ilql_heads.target_q_heads.0.0.bias', 'ilql_heads.target_q_heads.1.0.bias', 'ilql_heads.v_head.0.weight', 'ilql_heads.v_head.2.bias', 'ilql_heads.q_heads.0.2.weight', 'ilql_heads.target_q_heads.1.2.weight', 'ilql_heads.q_heads.1.2.weight', 'ilql_heads.target_q_heads.0.0.weight', 'ilql_heads.target_q_heads.1.0.weight', 'ilql_heads.q_heads.0.0.weight', 'ilql_heads.q_heads.1.2.bias', 'ilql_heads.target_q_heads.0.2.weight', 'ilql_heads.q_heads.1.0.weight']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model