# Part 2 - Mitigating bias in text-to-image models


In [None]:
!pip install --no-build-isolation --force-reinstall \
    dependencies/awscli-*-py3-none-any.whl \
    dependencies/boto3-*-py3-none-any.whl \
    dependencies/botocore-*-py3-none-any.whl

In [None]:
!pip install langchain --quiet

In [None]:
# !pip install --upgrade pip --quiet
# !pip install protobuf==3.20 --quiet

In [None]:
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

In [None]:
import json
import os
import sys
from PIL import Image
from typing import Union
import io
import base64
import boto3

module_path = ".."
# Make the parent directory importable so `utils.bedrock` resolves. Only append
# it if it is not already present, so re-running this cell does not keep
# growing sys.path with duplicate entries.
_parent_dir = os.path.abspath(module_path)
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)
from utils import bedrock
# Bedrock runtime client used by the image-generation calls below.
boto3_bedrock = bedrock.get_bedrock_client()

In [None]:
def invoke_image_gen(
    prompt,
    negative_prompts,
    *,
    style_preset="photographic",
    cfg_scale=5,
    seed=5450,
    steps=70,
    model_id="stability.stable-diffusion-xl",
    client=None,
):
    """Generate one image with Stable Diffusion XL through Amazon Bedrock.

    Args:
        prompt: Positive text prompt, sent with weight 1.0.
        negative_prompts: Iterable of strings, each sent with weight -1.0.
            ``None`` or an empty iterable means no negative prompts.
        style_preset: SDXL style preset. It is passed explicitly rather than
            read from a notebook global. The default matches the value used
            by every cell in this notebook.
        cfg_scale: Sampling parameter forwarded unchanged in the request body.
        seed: Sampling parameter forwarded unchanged in the request body.
        steps: Sampling parameter forwarded unchanged in the request body.
        model_id: Bedrock model identifier passed to ``invoke_model``.
        client: Object exposing ``invoke_model(body=..., modelId=...)``.
            Defaults to the notebook's ``boto3_bedrock`` client.

    Returns:
        The ``base64`` field of the first artifact in the response.
    """
    if client is None:
        client = boto3_bedrock

    text_prompts = [{"text": prompt, "weight": 1.0}]
    # Negative weights mark these texts as things the image should avoid.
    text_prompts += [
        {"text": negprompt, "weight": -1.0} for negprompt in (negative_prompts or ())
    ]
    request = json.dumps({
        "text_prompts": text_prompts,
        "cfg_scale": cfg_scale,
        "seed": seed,
        "steps": steps,
        "style_preset": style_preset,
    })

    response = client.invoke_model(body=request, modelId=model_id)
    response_body = json.loads(response.get("body").read())

    print(response_body["result"])
    base_64_img_str = response_body["artifacts"][0].get("base64")
    # Show only a short prefix; the full string is the whole encoded image.
    print(f"{base_64_img_str[0:80]}...")
    return base_64_img_str
    

In [None]:
# Baseline: a plain prompt, with negative prompts that name "bias" and
# "discriminatory" directly alongside image-quality terms.
negative_prompts = [
    "bias",
    "discriminatory",
    "poorly rendered",
    "poor background details",
    "poorly drawn features",
    "disfigured features",
]
prompt = "picture of a doctor"
style_preset = "photographic"
output = invoke_image_gen(prompt, negative_prompts)


In [None]:
# Decode the base64 image string returned above and display it inline.
decoded_image = base64.decodebytes(output.encode("utf-8"))
image_1 = Image.open(io.BytesIO(decoded_image))
image_1

In [None]:
# Ask for gender inclusivity directly in the positive prompt.
prompt_gender_inclusive = "a picture of a doctor which is inclusive of male and female."
output = invoke_image_gen(prompt_gender_inclusive, negative_prompts)
image_1 = Image.open(
    io.BytesIO(base64.decodebytes(output.encode("utf-8")))
)
image_1

In [None]:
# Extend the inclusive prompt to cover skin color as well as gender.
text_gender_color_inclusive = "a picture of a doctor which is inclusive of male, female and color"
output = invoke_image_gen(text_gender_color_inclusive, negative_prompts)
png_stream = io.BytesIO(base64.decodebytes(output.encode("utf-8")))
image_1 = Image.open(png_stream)
image_1

In [None]:
# Try expressing the anti-bias intent as a single descriptive negative prompt.
# Note: this rebinds the notebook-level `negative_prompts`.
text_base = "a picture of a nurse"  # was "an picture": typo sent verbatim to the model
negative_prompts = ["bias and discrimination against certain group of people"]

output = invoke_image_gen(text_base, negative_prompts)
image_1 = Image.open(io.BytesIO(base64.decodebytes(bytes(output, "utf-8"))))
image_1

### Bias mitigation
Steps:
- Use an LLM to generate non-discriminatory prompts that remove bias from the user's original request.
- Use the generated prompt to create an image. 

#### Step 1: Create a chatbot application to generate inclusive prompts. 
The chatbot asks the user relevant clarifying questions before generating the prompt for the `text-to-image` model.

In [None]:
import json
import os
import sys

import boto3

module_path = ".."
# Same setup as earlier in the notebook. Guard the sys.path insertion so
# re-running cells does not append the parent directory repeatedly.
_parent_dir = os.path.abspath(module_path)
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)
from utils import bedrock
# Bedrock runtime client passed to the Claude LLM wrapper below.
boto3_bedrock = bedrock.get_bedrock_client()

In [None]:
from langchain.llms.bedrock import Bedrock
# Claude v1 on Bedrock. Temperature 0 minimises sampling randomness so the
# generated prompts are as repeatable as the model allows.
cl_llm = Bedrock(
    client=boto3_bedrock,
    model_id="anthropic.claude-v1",
    model_kwargs=dict(max_tokens_to_sample=1000, temperature=0),
)

In [None]:
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory

# Smoke test: a ConversationChain with buffer memory and the default prompt,
# verbose so the full prompt sent to the model is printed.
memory = ConversationBufferMemory()
conversation = ConversationChain(memory=memory, llm=cl_llm, verbose=True)

greeting = conversation.predict(input="Hi there!")
print(greeting)

In [None]:
from langchain.memory import ConversationBufferMemory
from langchain import PromptTemplate

# langchain prompts do not always work with all the models. This prompt is tuned for Claude.
# {history} is filled from ConversationBufferMemory and {input} from the user's turn.
chat_prompt = PromptTemplate.from_template("""
You are a prompt generator, who generates prompts for text to image models.
The AI is not biased and does not discriminate against certain groups of people. 
If AI detects bias in the question, AI asks relevant questions based on gender, race and color before responding.
If the AI does not know the answer to a question, it truthfully says it does not know.
If the question has a class of human beings AI will ask all of the following questions with examples before generating the prompt. 
<questions>
- What is the gender of the subject in the picture? (e.g. male, female, transgender etc.)
- What is the color of the subject in the image? (e.g. white, black, or brown etc.)
- What is the race of the subject in the image? (e.g. African-american, latino, indian, korean, Chinese, asian, etc.)
</questions>
Current conversation:
{history}


H: {input}


A:
""")

# Pass the custom prompt to the constructor (instead of assigning
# `conversation.prompt` afterwards) so ConversationChain validates that the
# template's variables match the memory key and input key.
# Set verbose=True to print the full prompt sent to the model on every turn.
conversation = ConversationChain(
    llm=cl_llm,
    prompt=chat_prompt,
    memory=ConversationBufferMemory(),
    verbose=False,
)

print(conversation.predict(input="Hi there!"))

In [None]:
# Request a doctor image prompt; the chain is expected to ask clarifying questions first.
doctor_reply = conversation.predict(input="picture of a doctor.")
print(doctor_reply)

In [None]:
# Answer the gender / color / race questions in one reply. `response` is
# later used as the text-to-image prompt.
answer = "Hispanic brown female"
response = conversation.predict(input=answer)
print(response)

In [None]:
## Optional: uncomment and run this only if the model asks whether you want a revised prompt.
# response = conversation.predict(input="yes")
# print(response)

In [None]:
# Use the LLM's reply as the positive prompt and keep only image-quality
# negative prompts.
style_preset = "photographic"
negative_prompts = [
    "poorly rendered",
    "poor background details",
    "poorly drawn figure",
    "disfigured features",
]
prompt = response

In [None]:
# Generate the image from the LLM-crafted prompt. Reuse invoke_image_gen rather
# than duplicating it: it sends the same SDXL request (same text_prompts
# weights, cfg_scale, seed, steps and "photographic" style preset as set above),
# prints the result status and a base64 prefix, and returns the base64 string.
base_64_img_str = invoke_image_gen(prompt, negative_prompts)

In [None]:
import base64
import io
from PIL import Image

# Decode the final image, persist it under data/, and display it inline.
png_bytes = base64.decodebytes(base_64_img_str.encode("utf-8"))
image_1 = Image.open(io.BytesIO(png_bytes))
os.makedirs("data", exist_ok=True)
image_1.save("data/image_1.png")
image_1