# Part 2 - Mitigating bias in text-to-image models

> *This notebook should work well in the `Data Science 3.0` kernel on Amazon SageMaker Studio*

In [None]:
%pip install --quiet "boto3>=1.28.63,<2" "botocore>=1.31.63,<2" langchain==0.0.336

In [None]:
# Python Built-Ins:
import base64
import io
import json
import os
from typing import List

# External Dependencies:
import boto3  # AWS SDK for Python
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from PIL import Image  # For processing and displaying images

# Shared Amazon Bedrock runtime client, reused by every model call in this notebook
boto3_bedrock = boto3.client(service_name="bedrock-runtime")
# Folder for the generated pictures (no error if it already exists)
os.makedirs(name="data", exist_ok=True)

In [None]:
# Bedrock model ID of the image generator used by `invoke_image_gen`
image_model_id = "stability.stable-diffusion-xl"

# Base inference parameters that `invoke_image_gen` merges into every request body
image_model_config = dict(
    cfg_scale=5,
    seed=42,
    steps=70,
    style_preset="photographic",
)


def invoke_image_gen(prompts: List[str], negative_prompts: List[str], **kwargs,) -> Image.Image:
    """Generate one image with the configured Bedrock image model and return it as a PIL image.

    The request body is assembled from `image_model_config`, then the weighted text prompts, then
    `kwargs` - so a keyword argument overrides a same-named entry from the base configuration.

    Args:
        prompts: Texts describing what the image should show (each sent with weight 1.0).
        negative_prompts: Texts describing what to steer away from (each sent with weight -1.0).
        **kwargs: Extra top-level fields for the request body (e.g. a different `seed`).

    Returns:
        The first artifact of the response, decoded from base64 into a `PIL.Image.Image`.

    Raises:
        ValueError: If the response carries no artifact with base64 image data.
    """
    text_prompts = [{"text": prompt, "weight": 1.0} for prompt in prompts]
    text_prompts += [{"text": negprompt, "weight": -1.0} for negprompt in negative_prompts]
    request = json.dumps({**image_model_config, "text_prompts": text_prompts, **kwargs})

    response = boto3_bedrock.invoke_model(body=request, modelId=image_model_id)
    response_body = json.loads(response.get("body").read())

    result = response_body.get("result")
    print(result)

    # Fail with an explicit message instead of an opaque IndexError/TypeError further down when
    # the service answered without usable image data.
    artifacts = response_body.get("artifacts") or []
    first_artifact = artifacts[0] if artifacts else {}
    base_64_img_str = first_artifact.get("base64")
    if not base_64_img_str:
        raise ValueError(
            "Image generation returned no image data "
            f"(result={result!r}, finishReason={first_artifact.get('finishReason')!r})"
        )

    print(f"{base_64_img_str[0:80]}...")
    return Image.open(io.BytesIO(base64.b64decode(base_64_img_str)))

In [None]:
# Concepts the image model is steered away from; reused by the later generation cells
negative_prompts = [
    # Fairness-related terms
    "bias",
    "discriminatory",
    # Rendering-quality terms
    "poorly rendered",
    "poor background details",
    "poorly drawn features",
    "disfigured features",
]

# Baseline: a plain prompt with no inclusivity hints
image_1 = invoke_image_gen(
    prompts=["a doctor in a hospital"],
    negative_prompts=negative_prompts,
)
image_1.save("data/image_1.png")
image_1

In [None]:
# Same subject as before, plus a second positive prompt asking for gender inclusivity
image_2 = invoke_image_gen(
    prompts=["a doctor in a hospital", "inclusive of male and female"],
    negative_prompts=negative_prompts,
)
image_2.save("data/image_2.png")
image_2

In [None]:
# Same subject again, with the inclusivity prompt extended to mention color as well
image_3 = invoke_image_gen(
    prompts=["a doctor in a hospital", "inclusive of male, female, and color"],
    negative_prompts=negative_prompts,
)
image_3.save("data/image_3.png")
image_3

In [None]:
# Different subject, with a single descriptive negative prompt instead of the shared list
image_4 = invoke_image_gen(
    prompts=["a nurse in a hospital"],
    negative_prompts=["bias and discrimination against certain group of people"],
)
image_4.save("data/image_4.png")
image_4

### Bias mitigation
Steps:
- Use an LLM to rewrite the user's request as a non-discriminatory prompt, removing as much bias as possible.
- Use the generated prompt to create an image. 

#### Step 1: Create a chatbot application to generate inclusive prompts.
The chatbot asks the user relevant clarifying questions before it generates the prompt for the text-to-image model.

In [None]:
from langchain.llms.bedrock import Bedrock

# Text model for the prompt-generating chatbot, served through the shared Bedrock client
cl_llm = Bedrock(
    client=boto3_bedrock,
    model_id="anthropic.claude-instant-v1",
    model_kwargs=dict(max_tokens_to_sample=1000, temperature=0.0),
)

In [None]:
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain import PromptTemplate


# LangChain's stock prompts do not suit every model, so the chain gets a template written for
# Claude's Human/Assistant turn format. It has two placeholders: {history} and {input}.
chat_prompt = PromptTemplate.from_template("""
Human:
You are a prompt generator, who generates prompts for text to image models based on a user question.
You and the image AI are not biased and do not discriminate against certain groups of people.
If you detect bias in the question, ask relevant questions based on gender, race and color before
generating the prompt. If you don't know the answer to a question, truthfully say you don't know.
If the image generation question includes human beings, ask all of the following questions with
examples before generating the image prompt:

<questions>
- What is the gender of the subject in the picture? (e.g. male, female, transgender etc.)
- What is the color of the subject in the image? (e.g. white, black, or brown etc.)
- What is the race of the subject in the image? (e.g. African-american, latino, indian, korean,
  chineese, asian, etc.)
</questions>

When you are ready to generate the image prompt, return it in <imageprompt></imageprompt> XML tags.

A:
OK, I understand

{history}

H:
{input}

A:
""")

# Buffer memory for the chat, configured with "Assistant" as the AI prefix
memory = ConversationBufferMemory(ai_prefix="Assistant")

# Switch verbose to True to see the full logs and documents
conversation = ConversationChain(memory=memory, llm=cl_llm, verbose=False)
# Swap the chain's prompt for the Claude-tuned template defined above
conversation.prompt = chat_prompt

print(conversation.predict(input="Hi there!"))

In [None]:
# Second turn of the conversation: send the image request to the chain and print its reply
print(conversation.predict(input="photo of a doctor."))

In [None]:
response = conversation.predict(input="Hispanic brown female")
print(response)

# Try to extract just the image prompt component of the response. `str.find` returns -1 when a
# tag is absent, so both lookups are checked explicitly instead of slicing with a bogus index.
open_tag, close_tag = "<imageprompt>", "</imageprompt>"
ix_open_tag = response.find(open_tag)
if ix_open_tag < 0:
    raise ValueError(
        "No <imageprompt> tag found in the model response - the assistant may still be asking "
        "clarifying questions. Continue the conversation before generating an image."
    )
ix_prompt_start = ix_open_tag + len(open_tag)
ix_prompt_end = response.find(close_tag, ix_prompt_start)
if ix_prompt_end < 0:
    # Closing tag missing (e.g. the completion was cut off): keep everything after the opening tag
    ix_prompt_end = len(response)
img_prompt = response[ix_prompt_start:ix_prompt_end].strip()
print("\n\n------------------------\n" + img_prompt)

In [None]:
# Generate from the prompt extracted from the chatbot's reply, with the shared negative prompts
image_disambiguated = invoke_image_gen(
    prompts=[img_prompt],
    negative_prompts=negative_prompts,
)
image_disambiguated.save("data/image_disambiguated.png")
image_disambiguated