In [None]:
import getpass
import json
import os
import re

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, BitsAndBytesConfig

### Helper Functions

In [None]:
def save_result(path_name, file_name, result):
    """
    Save the result to a Markdown file named ``<file_name>.md``.

    The target directory is created when it does not exist yet. The file is
    written as UTF-8 so non-ASCII characters in the content do not depend on
    the platform's default encoding.

    Args:
        path_name (str): Directory to save the file in. An empty string
            means the current working directory.
        file_name (str): Name of the file, without the ``.md`` extension
        result (str): Content to save
    """
    file_path = os.path.join(path_name, f"{file_name}.md")

    # dirname is "" when the file lands in the current directory, and
    # os.makedirs("") raises FileNotFoundError, so only create real parents.
    directory = os.path.dirname(file_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(result)

def load_questions_from_file(file_path):
    """
    Load questions from a JSON file.

    The file is decoded as UTF-8 rather than with the platform's default
    encoding, so non-ASCII question text loads the same way everywhere.

    Args:
        file_path (str): Path to the JSON file

    Returns:
        dict: Loaded questions

    Raises:
        FileNotFoundError: If ``file_path`` does not exist
        json.JSONDecodeError: If the file is not valid JSON
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

def extract_yes_no(response):
    """
    Extract the final YES or NO verdict from the model's response.

    Only whole words count, so "NO" inside "know"/"not"/"cannot" or "YES"
    inside "eyes" is ignored. Upper-case verdicts (the form the system prompt
    asks for) take priority; if there is none, any-case whole words are
    accepted. In both passes the LAST match wins, because the model is asked
    to justify first and state its verdict at the end.

    Args:
        response (str): Model's response

    Returns:
        str: 'YES' or 'NO'

    Raises:
        ValueError: If response doesn't contain a standalone 'YES' or 'NO'
    """
    verdict_pattern = r"\b(?:YES|NO)\b"

    # Pass 1: explicit upper-case verdicts, so an incidental lower-case "no"
    # in the justification cannot override a stated "YES".
    explicit = re.findall(verdict_pattern, response)
    if explicit:
        return explicit[-1]

    # Pass 2: the model ignored the casing instruction; accept any case.
    relaxed = re.findall(verdict_pattern, response, flags=re.IGNORECASE)
    if relaxed:
        return relaxed[-1].upper()

    raise ValueError("Unexpected response from LLM")

def generate_response(model, tokenizer, messages, config, max_length=4500, temperature=0.1):
    """
    Generate a response from the model.

    The chat messages are rendered with the tokenizer's chat template, moved
    to the device the model lives on, and passed to ``model.generate``. Only
    the newly generated tokens (everything after the prompt) are decoded.

    Args:
        model: Loaded LLM model
        tokenizer: Loaded tokenizer
        messages (list): List of message dictionaries
        config: Model configuration; only ``max_position_embeddings`` is read,
            for the token-usage printout
        max_length (int): Value forwarded to ``model.generate`` as
            ``max_length``. Per the transformers documentation this bounds
            prompt plus generated tokens together.
        temperature (float): Sampling temperature forwarded to ``model.generate``

    Returns:
        str: Generated response
    """
    # Follow the model's own placement instead of assuming a CUDA device;
    # fall back to the previous hard-coded target if the model exposes none.
    device = getattr(model, "device", "cuda")

    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(device)

    # Stop on the regular EOS token and on the "<|eot_id|>" end-of-turn token.
    # Skip the latter when the tokenizer has no id for it or it duplicates EOS.
    terminators = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is not None and eot_id not in terminators:
        terminators.append(eot_id)

    generated_ids = model.generate(
        model_inputs,
        do_sample=True,
        temperature=temperature,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
        max_length=max_length
    )

    # Drop the echoed prompt: keep only what follows the input tokens.
    prompt_length = model_inputs.shape[-1]
    response = generated_ids[0][prompt_length:]
    result = tokenizer.decode(response, skip_special_tokens=True)

    print(f"Tokens used: {len(model_inputs[0])} out of {config.max_position_embeddings}")

    return result

def setup_environment():
    """
    Return the Hugging Face API key used to download the model.

    The key is read from the ``LLAMA3_KEY`` environment variable. When the
    variable is missing or empty, the user is prompted for it. The prompt uses
    ``getpass`` so the typed token is not echoed into the notebook output,
    where it could be saved and shared along with the notebook.

    Returns:
        str: The API key, with surrounding whitespace removed
    """
    hf_key = os.environ.get("LLAMA3_KEY")
    if not hf_key:
        print("Please set the environment variable LLAMA3_KEY")
        hf_key = getpass.getpass("Enter your HuggingFace API key: ")
    # Pasted tokens often carry a trailing newline or space.
    return hf_key.strip()

def load_model(hf_key, model_name="meta-llama/Meta-Llama-3-8B-Instruct"):
    """
    Load the LLM model, its tokenizer and its configuration.

    The model is loaded with 8-bit bitsandbytes quantization and
    ``device_map="auto"``.

    Args:
        hf_key (str): Hugging Face access token passed to every download call
        model_name (str): Hugging Face model id to load. Defaults to the
            Llama 3 8B Instruct checkpoint this notebook was written for.

    Returns:
        tuple: ``(model, tokenizer, config)``
    """
    quant_config = BitsAndBytesConfig(load_in_8bit=True)

    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_key)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=hf_key,
        quantization_config=quant_config,
        device_map="auto",
        max_length=8192,
    )
    # Loaded separately from the model; generate_response reads
    # max_position_embeddings from it for its token-usage printout.
    config = AutoConfig.from_pretrained(model_name, token=hf_key)

    return model, tokenizer, config


## 1. Set up the model

This step includes downloading the LLM from Hugging Face and setting up the API key.

You can set your API key temporarily for the current session from Python: `os.environ["LLAMA3_KEY"] = "YOUR_API_KEY"`

WARNING: NEVER SHARE/PUSH YOUR API KEY.

In [None]:
# Resolve the Hugging Face token first, then use it to load the model, tokenizer and config.
hf_key = setup_environment()
model, tokenizer, config = load_model(hf_key)

## 2. Set file directories

In [None]:
# Set the CWD to the directory of your repository.
# Point the FAIRNESS_REPO_DIR environment variable at your own checkout instead of
# editing this cell; the original author's path is kept as the fallback.
repo_dir = os.environ.get(
    "FAIRNESS_REPO_DIR",
    '/home/gunes/fair/responsible-ai-model-cards/Steps-v3-thesis',
)
os.chdir(repo_dir)

### Individual case to analyze

In [None]:
case_name = "credit"

# Read the case description; the context manager closes the file handle
# instead of leaving it to the garbage collector.
file_path = f"files/cases/{case_name}/description.md"
with open(file_path, "r", encoding="utf-8") as description_file:
    case_description = description_file.read()

decision_tree = load_questions_from_file("files/fairness-tools/decision_tree.json")

with open("files/fairness-tools/mitigation.md", "r", encoding="utf-8") as guide_file:
    mitigation_guide = guide_file.read()

user_mitigation_method  = ""  # Leave empty if you want the LLM to generate a mitigation method

## 3. Start Analysis

In [None]:
# Prompt skeleton reused for every decision-tree question
original_template = "Description: {case_description}\n\nQuestion: {question}"

# The walk through the decision tree always starts at Q1
question = decision_tree["Q1"]["question"]

# Conversation sent to the model: fixed system instructions plus the first question
messages = [
    {
        "role": "system",
        "content": """You are an expert in fairness and responsible AI. Your task is to respond to a question given the description. For operations involving multiple values, perform the operations one by one. Justify your answer and at the end present your response as YES or NO.""",
    },
    {
        "role": "user",
        "content": original_template.format(case_description=case_description, question=question),
    },
]


In [None]:
# Ask the model the first decision-tree question and show its full answer.
result1 = generate_response(model, tokenizer, messages, config)
print(result1)

In [None]:
# Process First Response and Generate Second Question
response = extract_yes_no(result1)
question_number = decision_tree["Q1"]["choices"][response]["result"]
question = decision_tree[question_number]["question"]

print(f"Next question ({question_number}): {question}")
    
messages[1]["content"] = original_template.format(case_description=case_description, question=question)

In [None]:
# Ask the model the follow-up question and show its full answer.
result2 = generate_response(model, tokenizer, messages, config)
print(result2)

In [None]:
# Process Second Response
response = extract_yes_no(result2)

fairness_metric = decision_tree[question_number]["choices"][response]["result"]
print(fairness_metric)

In [None]:
%%time

# System instructions for the mitigation step. The second rule previously had no
# main clause ("If there is no ... by the user, based on ..."); it now tells the
# model what to do when the user did not request a method.
mitigation_prompt = f""" You are an AI expert in fairness and responsible AI. Your task is to suggest an appropriate mitigation method from the provided mitigation guide and explain your reasoning.

If a mitigation method is specified by the user, use that method and add the appropriate constraint to the method. 
If no mitigation method is specified by the user, suggest the most appropriate one based on the following case description and fairness analysis results.

At the end of your response, present the mitigation method as Suggested Mitigation Method: mitigation_method."""

# The user turn bundles the case, both question/answer pairs, the selected
# metric, the optional user-requested method and the full mitigation guide.
messages = [
    {"role": "system", "content": mitigation_prompt},
    {"role": "user", "content": f"""Case Description: {case_description}

    Fairness Analysis Results:
    Question 1: {decision_tree["Q1"]["question"]}
    Answer 1: {result1}

    Question 2: {question}
    Answer 2: {result2}

    Suggested Fairness Metric: {fairness_metric}

    Requested Mitigation Method: {user_mitigation_method}

    Mitigation Guide:
    {mitigation_guide}
    """}
]

# generate_response requires the model, tokenizer and config as well as the
# messages; passing only the messages raised a TypeError.
mitigation_result = generate_response(model, tokenizer, messages, config)
print(mitigation_result)

In [None]:
# Assemble the Markdown report line by line; empty entries become the blank
# lines (including the leading and trailing newline) of the final document.
overall_analysis = "\n".join([
    "",
    "## Fairness Analysis Results",
    "",
    f'### Question 1: {decision_tree["Q1"]["question"]}',
    f"{result1}",
    "",
    f"### Question 2: {question}",
    f"{result2}",
    "",
    "## Suggested Fairness Metric",
    f"{fairness_metric}",
    "",
    "## Suggested Mitigation Method",
    f"{mitigation_result}",
    "",
])

print(overall_analysis)

In [None]:
# The report is stored in the directory of the analyzed case.
case_path = f"files/cases/{case_name}/"

# Save overall analysis
# save_result appends the ".md" extension to the given file name.
save_result(case_path, "fairness-analysis", overall_analysis)