In [None]:
import datasets
from openai import OpenAI
from tqdm.notebook import tqdm

In [None]:
# Load the 'social_bias_frames' dataset through the `datasets` library.
# Later cells read its 'train' split row by row.
data = datasets.load_dataset('social_bias_frames')

In [None]:
import getpass
import os

# Never store a key in the notebook source. Reuse OPENAI_API_KEY when the
# environment already provides one (the previous hard-coded placeholder would
# overwrite it); otherwise prompt for it without echoing the secret.
if not os.environ.get('OPENAI_API_KEY'):
    os.environ['OPENAI_API_KEY'] = getpass.getpass('OpenAI API key: ')

# The client is built without arguments, so it presumably picks the key up
# from the environment variable ensured above -- confirm against the openai docs.
client = OpenAI()

In [None]:
# Group the training rows by target minority: each key maps to the set of
# distinct stereotype strings seen for that group.
to_convert = {}

for row in data['train']:
    to_convert.setdefault(row['targetMinority'], set()).add(row['targetStereotype'])

In [None]:

# Preview the prompts that will be sent to the model: at most 50 stereotypes
# per group, stopping altogether once the running count exceeds 500.
shown = 0
for group, group_stereotypes in tqdm(to_convert.items()):

    for stereotype in tqdm(list(group_stereotypes)[:50]):
        print('Target minority: {0}\nStereotype: {1}\n\n'.format(group, stereotype))
        shown += 1
        if shown > 500:
            break
    else:
        # Inner loop ran to completion, so the cap was not hit: next group.
        continue
    # Inner loop was cut short by the cap: stop the outer loop as well.
    break

In [None]:
# System prompt for the triplet-extraction requests. It describes the task, the
# expected input layout, and gives one worked example whose answer follows a
# "Graph:" marker with one backtick-wrapped triplet per line.
# Adjacent literals are joined implicitly; sentences that share a line end with
# an explicit trailing space so they do not run together ("graph.This").
system_prompt = (
    "You are a helpful assistant tasked with converting social stereotypes into a dynamic knowledge graph. "
    "This is for a research project analyzing common social biases.\n\n"
    "Please present the graph in triplets in the format `(start_node, edge, end_node)` where nodes are subjects and an edge is the relationship between them. "
    "Please have the start_node be a target group where possible. Also, please decompose compound target groups.\n\n"
    "The input will be in the format\n"
    "Target minority: <the target group>\n"
    "Stereotype: <the social stereotype, that may or may not also contain target groups>\n"
    "----------\n"
    "Examples:\n"
    "Target minority: black folks\n"
    "Stereotype: black women are basic\n\n"
    "Graph:\n"
    "`(black women, are, basic)`\n"
    "`(black women, belong to, black people)`\n"
    "`(black women, belong to, women)`"
)

In [None]:

# This code is for v1 of the openai package: pypi.org/project/openai

# Send one chat request per (target group, stereotype) pair, using at most 50
# stereotypes per group, and keep every raw response object for later parsing.
stereotypes = []
for group, group_stereotypes in tqdm(to_convert.items()):

    for stereotype in tqdm(list(group_stereotypes)[:50]):
        # Empty stereotype strings carry nothing to convert.
        if len(stereotype) == 0:
            continue

        user_prompt = 'Target minority: {0}\nStereotype: {1}'.format(group, stereotype)
        request = dict(
            model="gpt-4",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=1,
            max_tokens=256,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
        )
        stereotypes.append(client.chat.completions.create(**request))

In [None]:
import re


def _extract_triplets(text):
    """Return the backtick-wrapped triplets found in one model reply.

    Only the text after the last "Graph:" marker is scanned when the marker is
    present (the system prompt's example uses it); otherwise the whole reply is
    scanned. Cutting at the last ':' instead would drop every triplet as soon
    as the reply contains a later colon, e.g. a trailing "Note: ...".

    Each backticked span has its parentheses removed and is split on ', ' into
    a tuple. Tuples are NOT length-checked here; later cells keep only those
    with exactly three elements.

    Args:
        text: reply content; an empty or None value yields no triplets.

    Returns:
        A list of tuples of strings, in order of appearance.
    """
    if not text:
        return []

    body = text.rpartition('Graph:')[2]
    triplets = []
    # Non-greedy match so two triplets on the same line stay separate spans.
    for span in re.findall(r'`(.*?)`', body):
        if len(span) > 0:
            triplets.append(tuple(span.replace('(', '').replace(')', '').split(', ')))
    return triplets


# Collect the distinct triplets across all stored responses.
st_cleaned = set()

for s in stereotypes:
    # An entry may be the response itself or a tuple whose first item is the
    # response; both shapes are accepted.
    completion = s[0] if isinstance(s, tuple) else s
    st_cleaned.update(_extract_triplets(completion.choices[0].message.content))

In [None]:
import pickle as pkl

# Serialize the parsed triplet set to 'triplet_dump.pkl' in the working directory.
with open('triplet_dump.pkl', 'wb') as dump_file:
    pkl.Pickler(dump_file).dump(st_cleaned)

In [None]:
# Distinct relation labels taken from the well-formed (3-element) triplets, in
# first-seen order. Without de-duplication the list holds one entry per
# triplet, so every loop over it (building `kg`, querying in `subgraph`)
# repeats the same relation many times for no additional result.
edge_types = list(dict.fromkeys(x[1] for x in st_cleaned if len(x) == 3))

In [None]:
import dgl

# One empty edge list per relation label. Every relation is keyed as
# ('node', <label>, 'node'), i.e. a single node type on both ends.
kg = {('node', e, 'node'): [] for e in edge_types}

In [None]:
import re

nodes_all = []    # position in this list is the node id; value is the node name
prefix_all = []   # not populated in this cell; kept in case other cells read it

to_ignore = ['people']

# Name -> id lookup mirroring `nodes_all`, so resolving a node is a dict access
# instead of a linear `nodes_all.index` scan on every triplet.
node_ids = {}


def _normalize_node(name):
    """Rewrite 'folks' and 'race' to 'people' wherever they occur in the name."""
    return name.replace('folks', 'people').replace('race', 'people')


def _node_id(name):
    """Return the id of `name`, registering it at the end of `nodes_all` if new."""
    idx = node_ids.get(name)
    if idx is None:
        idx = len(nodes_all)
        node_ids[name] = idx
        nodes_all.append(name)
    return idx


for x in tqdm(st_cleaned):
    # Only well-formed (start, edge, end) triplets become graph edges.
    if len(x) != 3:
        continue
    head, relation, tail = x

    # Drop triplets whose start or end contains no ASCII letter at all.
    if re.search(r'[a-zA-Z]', head) is None or re.search(r'[a-zA-Z]', tail) is None:
        continue
    # NOTE(review): the ignore list is matched against the raw names, before
    # normalization, so a raw 'folks' still ends up as a 'people' node --
    # confirm whether that is intended.
    if head in to_ignore or tail in to_ignore:
        continue

    # Register the start node before the end node so ids follow encounter order.
    start_node = _node_id(_normalize_node(head))
    end_node = _node_id(_normalize_node(tail))

    kg[('node', relation, 'node')].append([start_node, end_node])

In [None]:

from dgl import heterograph

# Build the graph object from `kg`, the mapping of
# ('node', <relation>, 'node') keys to lists of [start_id, end_id] pairs.
sbic_kg = heterograph(kg)

In [None]:
def subgraph(event, kg, query_e=edge_types):
    """Collect named triplets from the `dgl.in_subgraph` of a single node.

    Args:
        event: the node to query. An ``int`` is used directly as the node id;
            any other value is resolved with ``nodes_all.index(event)``.
        kg: graph handed to ``dgl.in_subgraph``.
        query_e: relation labels whose edges are read from the subgraph;
            defaults to ``edge_types`` as bound when this function is defined.

    Returns:
        A set of ``(source_name, relation, target_name)`` tuples, with names
        looked up in ``nodes_all``.

    Raises:
        ValueError: if a non-int ``event`` is not present in ``nodes_all``.
    """
    node_id = event if type(event) is int else nodes_all.index(event)

    neighbourhood = dgl.in_subgraph(kg, nodes={'node': [node_id]})
    triplets = set()
    for relation in query_e:
        endpoints = neighbourhood.edges(etype=('node', relation, 'node'))
        sources = endpoints[0]
        targets = endpoints[1]

        for pos in range(len(sources)):
            triplets.add((nodes_all[sources[pos]], relation, nodes_all[targets[pos]]))

    return triplets

In [None]:
# Display the node-name list; a name's position in it is its node id.
nodes_all

In [None]:
# Example query: triplets for the node named 'r word people' in the built graph.
temp = subgraph('r word people', sbic_kg)

In [None]:
# Display the value currently stored in `temp`.
temp

In [None]:
# Display the constructed graph object.
sbic_kg