# Import libraries

With openai>=1.0.0, `openai.ChatCompletion` is no longer supported, so an earlier version of the package is installed below. cohere and tiktoken are installed as well to resolve the following dependency error, which recently started to appear:
<blockquote>ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.<br>
llmx 0.0.15a0 requires cohere, which is not installed.<br>
llmx 0.0.15a0 requires tiktoken, which is not installed.</blockquote>

In [None]:
!pip install openai==0.27.8 cohere tiktoken
!pip install python-dotenv

In [None]:
import inspect
import os
import dotenv
import openai
import pandas as pd
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

# Mount drive

In [None]:
# Mount Google Drive at /content/drive. The dataset, the .env file, and the
# output folders used in this notebook are all addressed under
# /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

# Create directories to store inferences

In [None]:
# One output folder per model; exist_ok=True makes re-running this cell safe.
for inference_dir in (
        '/content/drive/MyDrive/datasets/SBIC_gpt-3.5-turbo-0301',
        '/content/drive/MyDrive/datasets/SBIC_gpt-3.5-turbo-0613'):
    os.makedirs(inference_dir, exist_ok=True)

# Authentication

You first have to save your OpenAI API key in a file named `.env` at the top level of your Google Drive (i.e., `MyDrive/.env`), in the format:
OPENAI_API_KEY = '...'

In [None]:
# OpenAI authentication on Google Colab. via
# https://stackoverflow.com/a/77166086
# Load the variables defined in the .env file stored on Drive, then hand the
# value of OPENAI_API_KEY to the openai client. os.getenv returns None when
# the variable is not set, in which case openai.api_key ends up as None.
dotenv.load_dotenv('/content/drive/MyDrive/.env')
openai.api_key = os.getenv('OPENAI_API_KEY')

# Import dataset

In [None]:
# Read the SBIC CSV from Drive and give it a fresh positional index
# (drop=True discards the previous index instead of keeping it as a column).
SBIC_v2_agg_test = pd.read_csv(
    '/content/drive/MyDrive/datasets/SBIC.v2.agg.test.csv'
).reset_index(drop=True)

# Set prompts

In [None]:
def system_content(persona):
    """Return the system message that tells the model to answer as `persona`.

    The persona text is inserted into both sentences of the instruction, and
    `inspect.cleandoc` strips the source-code indentation from the result.
    """
    instruction = f"""From now on, act as {persona}.
        Provide an answer that {persona} would regarding the post."""
    return inspect.cleandoc(instruction)

In [None]:
def user_content(post):
    """Return the user message asking whether `post` could be offensive.

    The prompt template is dedented with `inspect.cleandoc` *before* the post
    is inserted. Dedenting after insertion lets the post's own text take part
    in the margin calculation: a post containing a line break contributes
    lines without indentation, the margin becomes zero, and every template
    line keeps its source-code indentation. `cleandoc` also expands tabs,
    which would rewrite tab characters inside the post. Inserting the post
    afterwards keeps the prompt layout identical for every post and leaves
    the post text untouched.

    Args:
        post: The post to be rated. It is rendered with `str.format`, so
            non-string values are converted the same way an f-string would.

    Returns:
        The prompt string: the post wrapped in triple double quotes, the
        question, the three answer options (A/B/C), and the answer-format
        constraint.
    """
    template = inspect.cleandoc("""Post: \"\"\"
    {post}
    \"\"\"

    Could this post be considered offensive, disrespectful, or toxic to
    anyone/someone?
    (Does this sound insulting, rude, toxic, disrespectful, abusive,
    inflammatory, hateful?)
    Assume the author belongs to majority groups (white, male, straight,
    abled, etc.). Try to answer this question while considering a broad set
    of people from different backgrounds, not just your own.

    A: Yes, this could be offensive
    B: Maybe, I'm not sure
    C: No, this is harmless

    Constraint: Answer with only the option above that is most accurate and
    nothing else.""")
    # The template holds no braces other than the placeholder, and the post
    # is passed as an argument, so braces inside the post are never parsed.
    return template.format(post=post)

# Define a function to call Chat Completion API with (persona) or without (no-persona default) the *system message*

In [None]:
# Retry with randomized exponential backoff to avoid rate limit errors.
# Pattern taken from the OpenAI cookbook:
# https://github.com/openai/openai-cookbook/blob/main/examples/
# How_to_handle_rate_limits.ipynb
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def classifier_with_backoff(user_cont, model, system_cont=None):
    """Send one chat-completion request and return the text of the reply.

    The decorator re-invokes the function when it raises, waiting a random,
    exponentially growing interval (between 1 and 60 seconds) between tries
    and stopping after 6 attempts.

    Args:
        user_cont: Content of the user message.
        model: Name of the chat model to query.
        system_cont: Content of the system message. When None (the
            no-persona default), only the user message is sent.

    Returns:
        The message content of the first choice in the API response.
    """
    conversation = [{"role": "user", "content": user_cont}]
    if system_cont is not None:
        # The system message, when present, precedes the user message.
        conversation.insert(0, {"role": "system", "content": system_cont})
    response = openai.ChatCompletion.create(
        model=model,
        messages=conversation,
        temperature=0,
    )
    first_choice = response['choices'][0]
    return first_choice['message']['content']

# Define a function to perform model inference on **SBIC**

Create a function `predict_q1a_labels` that takes a list of personas (e.g., `['a White/Caucasian person', 'a Black/African American']`), a `topic` (e.g., a sociodemographic attribute such as *Race/Ethnicity*), and a GPT model name (e.g., `gpt-3.5-turbo-0301`) as required input arguments. The optional argument `base` determines whether annotations made by the no-persona default are included; pass `base='Yes'` to include them. For each persona in the list of personas, the function iterates over each post in **SBIC**, calling the function `classifier_with_backoff` for each post. Unless the persona contains `'an Arab'`, the suffix `' in the United States'` is appended to it before it is used in the prompt and in the column name. The function then adds the respective lists of labels predicted by each simulated persona as new columns to the **SBIC** dataframe. The predicted labels are remapped to conform to how the data was originally coded in **SBIC**, e.g., `'0.0'` if `C: No, this is harmless` (the remapped labels are stored as strings, and a reply that matches none of the three options exactly is left unchanged). The dataframe is saved in TSV file format, and the function returns the dataframe.

In [None]:
def predict_q1a_labels(persona_list, topic, model, base=None):
    """Annotate every SBIC post once per persona and save the predictions.

    Works on a deep copy of the module-level `SBIC_v2_agg_test` dataframe.
    For each persona, `classifier_with_backoff` is called once per post and
    the replies are stored in a new column named
    `predicted_q1a_labels_<persona>` (spaces replaced by underscores).
    Replies that exactly match one of the three answer options are remapped
    to the strings '1.0', '0.5' and '0.0'; any other reply is left as
    returned by the model.

    Args:
        persona_list: Personas to simulate. ' in the United States' is
            appended to every persona that does not contain 'an Arab'.
        topic: Label used in the name of the output file.
        model: Chat model name; also used in the output folder and file name.
        base: Pass 'Yes' to additionally annotate the posts without a system
            message (column `predicted_q1a_labels_base`).

    Returns:
        The dataframe with the added prediction columns. The same dataframe
        is written as a tab-separated file to
        /content/drive/MyDrive/datasets/SBIC_<model>/
        SBIC_predictions_<topic>_<model>.tsv.
    """
    # Create the output folder before any API call is made: without it,
    # `to_csv` would fail only after all the inference had been paid for,
    # and every prediction would be lost.
    output_dir = f'/content/drive/MyDrive/datasets/SBIC_{model}'
    os.makedirs(output_dir, exist_ok=True)

    df = SBIC_v2_agg_test.copy(deep=True)
    posts = df['post'].values
    # Remap predicted class labels. via
    # https://stackoverflow.com/a/20250996
    remap = {'A: Yes, this could be offensive': '1.0',
             'B: Maybe, I\'m not sure': '0.5',
             'C: No, this is harmless': '0.0'}
    if base == 'Yes':
        df['predicted_q1a_labels_base'] = [
            classifier_with_backoff(user_content(post), model)
            for post in posts]
        df.replace({'predicted_q1a_labels_base': remap}, inplace=True)
    for persona in persona_list:
        if 'an Arab' not in persona:
            persona += ' in the United States'
        # Computed once so the assignment and the remapping below always
        # address the same column.
        column = f'predicted_q1a_labels_{persona}'.replace(' ', '_')
        system_cont = system_content(persona)
        df[column] = [
            classifier_with_backoff(user_content(post), model, system_cont)
            for post in posts]
        df.replace({column: remap}, inplace=True)
    df.to_csv(
        f'{output_dir}/SBIC_predictions_{topic}_{model}.tsv',
        sep='\t', index=False)
    return df

# Create dictionaries, each with a `topic` as the key and a list of personas as the value

In [None]:
# Single-entry mapping: topic label -> personas to simulate for that topic.
race_ethnicity = {
    'race_ethnicity': [
        'an Asian/Asian American',
        'a Black/African American',
        'a Hispanic/Latinx person',
        'a White/Caucasian person',
        'a Native Hawaiian/Pacific Islander',
        'a Native American/First Nations person',
    ],
}

In [None]:
# Single-entry mapping: topic label -> personas to simulate for that topic.
gender_genderidentity_sexuality_maj = {
    'gender_genderidentity_sexuality_maj': [
        'a man/male',
        'a woman/female',
        'a heterosexual man',
        'a heterosexual woman',
    ],
}

In [None]:
# Single-entry mapping: topic label -> personas to simulate for that topic.
gender_genderidentity_sexuality_minA = {
    'gender_genderidentity_sexuality_minA': [
        'a non-binary person',
        'a LGBTQIA+ person',
        'a lesbian woman',
        'a gay man',
        'a bisexual man',
        'a bisexual woman',
    ],
}

In [None]:
# Single-entry mapping: topic label -> personas to simulate for that topic.
gender_genderidentity_sexuality_minB = {
    'gender_genderidentity_sexuality_minB': [
        'an asexual man',
        'an asexual woman',
        'a transgender man',
        'a transgender woman',
        'a two-spirited person',
        'a queer person',
        'an intersex person',
    ],
}

In [None]:
# Single-entry mapping: topic label -> personas to simulate for that topic.
# 'an Arab' is listed both with and without the ' in the United States'
# suffix.
culture_origin_religion = {
    'culture_origin_religion': [
        'a Muslim',
        'an atheist',
        'a Buddhist',
        'a Christian',
        'a Hindu',
        'a Jew',
        'an Arab',
        'an Arab in the United States',
        'an American citizen',
        'an immigrant',
        'a refugee',
    ],
}

In [None]:
# Single-entry mapping: topic label -> personas to simulate for that topic.
age_body = {
    'age_body': [
        'an old person',
        'a young adult',
        'an overweight/fat person',
        'an average weight person',
        'a short person',
        'a tall person',
        'an ugly person',
        'an attractive person',
        'a pregnant person',
    ],
}

In [None]:
# Single-entry mapping: topic label -> personas to simulate for that topic.
# Note that the topic key differs from the variable name.
mental_or_physical_disability = {
    'mental_physical_disability': [
        'a physically disabled person',
        'a person with physical illness/disorder',
        'a person with mental illness/disorder',
        'a mentally disabled person',
        'a non-disabled person',
    ],
}

In [None]:
# Single-entry mapping: topic label -> personas to simulate for that topic.
# Note that the topic key differs from the variable name.
socio_economic_political_lifestyle = {
    'socio_economic_political': [
        'a poor person',
        'a middle-class person',
        'a rich person',
        'a conservative',
        'a moderate conservative',
        'a moderate liberal',
        'a liberal/progressive',
        'a libertarian',
        'a political independent',
        'a veteran',
    ],
}

In [None]:
# Single-entry mapping: topic label -> personas to simulate for that topic.
crime_violence_tragedy_victim = {
    'crime_violence_tragedy_victim': [
        'a climate disaster victim',
        'a mass shooting victim',
        'a terrorism victim',
        'an assault victim',
        'a harassment victim',
    ],
}

# Call the `predict_q1a_labels` function

Running each of the following cells could take several hours. Colab Pro+ includes continuous code execution (capped at 24 hours) and background execution capabilities, enabling you to close your browser/device while your code runs. CPU is sufficient. Alternatively, you can download the notebooks and edit the code accordingly to point to the directories on your local machine you want to use.

In [None]:
# race_ethnicity personas plus the no-persona baseline, gpt-3.5-turbo-0301.
predict_q1a_labels(
    persona_list=race_ethnicity['race_ethnicity'],
    topic='race_ethnicity',
    model='gpt-3.5-turbo-0301',
    base='Yes')

In [None]:
# race_ethnicity personas plus the no-persona baseline, gpt-3.5-turbo-0613.
predict_q1a_labels(
    persona_list=race_ethnicity['race_ethnicity'],
    topic='race_ethnicity',
    model='gpt-3.5-turbo-0613',
    base='Yes')

In [None]:
# gender_genderidentity_sexuality_maj personas, gpt-3.5-turbo-0301.
predict_q1a_labels(
    persona_list=gender_genderidentity_sexuality_maj[
        'gender_genderidentity_sexuality_maj'],
    topic='gender_genderidentity_sexuality_maj',
    model='gpt-3.5-turbo-0301')

In [None]:
# gender_genderidentity_sexuality_maj personas, gpt-3.5-turbo-0613.
predict_q1a_labels(
    persona_list=gender_genderidentity_sexuality_maj[
        'gender_genderidentity_sexuality_maj'],
    topic='gender_genderidentity_sexuality_maj',
    model='gpt-3.5-turbo-0613')

In [None]:
# gender_genderidentity_sexuality_minA personas, gpt-3.5-turbo-0301.
predict_q1a_labels(
    persona_list=gender_genderidentity_sexuality_minA[
        'gender_genderidentity_sexuality_minA'],
    topic='gender_genderidentity_sexuality_minA',
    model='gpt-3.5-turbo-0301')

In [None]:
# gender_genderidentity_sexuality_minA personas, gpt-3.5-turbo-0613.
predict_q1a_labels(
    persona_list=gender_genderidentity_sexuality_minA[
        'gender_genderidentity_sexuality_minA'],
    topic='gender_genderidentity_sexuality_minA',
    model='gpt-3.5-turbo-0613')

In [None]:
# gender_genderidentity_sexuality_minB personas, gpt-3.5-turbo-0301.
predict_q1a_labels(
    persona_list=gender_genderidentity_sexuality_minB[
        'gender_genderidentity_sexuality_minB'],
    topic='gender_genderidentity_sexuality_minB',
    model='gpt-3.5-turbo-0301')

In [None]:
# gender_genderidentity_sexuality_minB personas, gpt-3.5-turbo-0613.
predict_q1a_labels(
    persona_list=gender_genderidentity_sexuality_minB[
        'gender_genderidentity_sexuality_minB'],
    topic='gender_genderidentity_sexuality_minB',
    model='gpt-3.5-turbo-0613')

In [None]:
# culture_origin_religion personas, gpt-3.5-turbo-0301.
predict_q1a_labels(
    persona_list=culture_origin_religion['culture_origin_religion'],
    topic='culture_origin_religion',
    model='gpt-3.5-turbo-0301')

In [None]:
# culture_origin_religion personas, gpt-3.5-turbo-0613.
predict_q1a_labels(
    persona_list=culture_origin_religion['culture_origin_religion'],
    topic='culture_origin_religion',
    model='gpt-3.5-turbo-0613')

In [None]:
# age_body personas, gpt-3.5-turbo-0301.
predict_q1a_labels(
    persona_list=age_body['age_body'],
    topic='age_body',
    model='gpt-3.5-turbo-0301')

In [None]:
# age_body personas, gpt-3.5-turbo-0613.
predict_q1a_labels(
    persona_list=age_body['age_body'],
    topic='age_body',
    model='gpt-3.5-turbo-0613')

In [None]:
# mental_physical_disability personas, gpt-3.5-turbo-0301.
predict_q1a_labels(
    persona_list=mental_or_physical_disability['mental_physical_disability'],
    topic='mental_physical_disability',
    model='gpt-3.5-turbo-0301')

In [None]:
# mental_physical_disability personas, gpt-3.5-turbo-0613.
predict_q1a_labels(
    persona_list=mental_or_physical_disability['mental_physical_disability'],
    topic='mental_physical_disability',
    model='gpt-3.5-turbo-0613')

In [None]:
# socio_economic_political personas, gpt-3.5-turbo-0301.
predict_q1a_labels(
    persona_list=socio_economic_political_lifestyle[
        'socio_economic_political'],
    topic='socio_economic_political',
    model='gpt-3.5-turbo-0301')

In [None]:
# socio_economic_political personas, gpt-3.5-turbo-0613.
predict_q1a_labels(
    persona_list=socio_economic_political_lifestyle[
        'socio_economic_political'],
    topic='socio_economic_political',
    model='gpt-3.5-turbo-0613')

In [None]:
# crime_violence_tragedy_victim personas, gpt-3.5-turbo-0301.
predict_q1a_labels(
    persona_list=crime_violence_tragedy_victim[
        'crime_violence_tragedy_victim'],
    topic='crime_violence_tragedy_victim',
    model='gpt-3.5-turbo-0301')

In [None]:
# crime_violence_tragedy_victim personas, gpt-3.5-turbo-0613.
predict_q1a_labels(
    persona_list=crime_violence_tragedy_victim[
        'crime_violence_tragedy_victim'],
    topic='crime_violence_tragedy_victim',
    model='gpt-3.5-turbo-0613')