-
Notifications
You must be signed in to change notification settings - Fork 0
/
add_diversity.py
80 lines (66 loc) · 2.08 KB
/
add_diversity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Generate DALL-E images from a user prompt, adding demographic diversity.

If GPT judges that the prompt mentions a person, it is asked to produce
three variants of the prompt with differing demographic details (sex,
ethnicity, age); otherwise the original prompt is used as-is. One DALL-E
image is generated and its URL printed for each resulting prompt.
"""
from openai import OpenAI


def _ask_gpt(client, prompt):
    """Send a single-turn chat request to gpt-3.5-turbo and return the reply text.

    Args:
        client: an initialized ``OpenAI`` client.
        prompt: the user-role message content.

    Returns:
        The assistant's reply as a string.
    """
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{
            "role": "user",
            "content": prompt
        }],
        temperature=1,
        max_tokens=256,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    return response.choices[0].message.content


def main():
    """Run the interactive prompt-diversification and image-generation flow."""
    # Your API key must be saved in an env variable for this to work.
    client = OpenAI()

    # Get a prompt, embed it into a classification request to GPT.
    initial_prompt = input("Type your prompt:\n\n")
    human_involved_prompt = f'''
Does the following text include a person? Answer "yes" or "no":
"{initial_prompt}"
'''
    # For debugging and transparency
    print(f"Querying for:\n{human_involved_prompt}")

    gpt_yes_no = _ask_gpt(client, human_involved_prompt)
    print("GPT on containing humans:", gpt_yes_no)
    # Check that the reply *starts* with "yes" rather than merely containing
    # it, so an answer like 'No -- "yes" would be wrong here' is not
    # misclassified as affirmative.
    mentions_humans = gpt_yes_no.strip().lower().startswith("yes")

    # Prepare the next query.
    modify_request_prompt = f'''
Produce three copies of the following prompt by copying it then adding demographic information about the people involved. Make each person's sex, ethnicity, and age different. Make each description 35 words or less:
{initial_prompt}
'''
    # Either request the three variations, or not.
    if mentions_humans:
        modified_prompts = _ask_gpt(client, modify_request_prompt)
        # Separate the three variations, dropping blank lines between them.
        image_prompts = [line for line in modified_prompts.split("\n") if line]
    else:
        image_prompts = [initial_prompt]

    # Draw the images.
    for img_prompt in image_prompts:
        print(img_prompt)
        # Ask for one image per prompt.
        image_response = client.images.generate(
            model="dall-e-3",
            prompt=img_prompt,
            size="1024x1024",
            quality="standard",
            n=1,
        )
        print(image_response.data[0].url)
        print()


if __name__ == "__main__":
    main()