In [2]:
import json
from itertools import combinations

import numpy as np
import openai
from numpy.linalg import norm

# Read the OpenAI key from a local, untracked file named "api_key".
# The context manager closes the handle. .strip() drops the trailing newline that
# editors usually append, so it is not sent as part of the key.
with open("api_key") as key_file:
    api_key = key_file.read().strip()
openai.api_key = api_key

In [16]:
def get_embedding(text, model="text-embedding-ada-002"):
    """Fetch the embedding vector for ``text`` from the OpenAI Embedding endpoint.

    Newlines in ``text`` are replaced by spaces before the request is made.

    Args:
        text: the string to embed.
        model: name of the embedding model passed to ``openai.Embedding.create``.

    Returns:
        The ``'embedding'`` field of the first item in the response's ``'data'`` list.
    """
    single_line = text.replace("\n", " ")
    response = openai.Embedding.create(input=[single_line], model=model)
    first_item = response['data'][0]
    return first_item['embedding']

def cos_sim(a, b):
    """Cosine similarity of vectors ``a`` and ``b``.

    Computed as the dot product divided by the product of the two L2 norms.
    If either vector has zero norm the result is nan (numpy emits a RuntimeWarning).
    """
    magnitude_product = norm(a) * norm(b)
    return np.dot(a, b) / magnitude_product

def save_json(data, filepath=r'new_data.json'):
    """Write ``data`` to ``filepath`` as pretty-printed JSON (4-space indent).

    An existing file at ``filepath`` is overwritten.
    """
    with open(filepath, 'w') as out_file:
        json.dump(data, out_file, indent=4)

In [9]:
def merge_paragraphs(sentences):
    """Flatten tokenized sentences into a single space-separated paragraph.

    Args:
        sentences: iterable of sentences, each an iterable of word strings.

    Returns:
        The words of each sentence joined by single spaces, with the sentences
        themselves also joined by single spaces, in order.
    """
    return " ".join(" ".join(words) for words in sentences)

In [10]:
# Load the hyperedge (event) records for the RAMS dev split, keyed by event id.
# The context manager closes the file instead of leaving the handle open.
with open("data/result/RAMS/gpt_biHgraph_dev/hyperedges.json") as hyperedges_file:
    event_dict = json.load(hyperedges_file)

# Hand-picked event ids whose embeddings are compared pairwise later in the notebook.
candidates = [
    '0-Compare-Require-Inable-Problem',
    '1-Occurred',
    '2-Involve-Investigate-Lose',
    '3-Bombing-Surrender-Celebrate',
    '4-Appointment-Question-Lack-Link',
    '5-Mention-Criticize',
    '6-Describe-Imprison-Convict-Punish',
]
# Disabled alternative candidate set:
# candidates = ['skirmish-Fans-Fans-French port city of Marseille-bottles , chairs and other objects', 'surrendered-Q17', 'bought-appointment-commodities trader', 'agreement-Obama administration-Q212', 'walked-Q868772-we', 'guilty-Q3702004-pay - for - play campaign finance scheme', 'convicted-6,267 inmates-Q1886845-drug - related crimes']
# Raises KeyError if a candidate id is missing from the file.
# NOTE(review): not used by any visible cell; confirm it is needed before keeping.
event_data_list = [event_dict[candidate] for candidate in candidates]

In [11]:
# Map each candidate event id to its content flattened into one paragraph string.
candidate_paragraphs = {}
for candidate in candidates:
    candidate_paragraphs[candidate] = merge_paragraphs(event_dict[candidate]['content'])

In [17]:
# Embed every event (not only the candidates) by storing the vector under its
# 'embedding' key, then persist the enriched dict to disk.
# Events that already carry an 'embedding' are skipped. This lets a re-run in the
# same kernel resume after a failed API call instead of re-requesting every event.
total_events = len(event_dict)
for position, (event_id, event) in enumerate(event_dict.items(), start=1):
    if 'embedding' in event:
        continue
    # 1-based progress counter so the output shows how far the loop has got.
    print(f"[{position}/{total_events}] {event_id}")
    event['embedding'] = get_embedding(merge_paragraphs(event['content']))
save_json(event_dict, "data/result/RAMS/gpt_biHgraph_dev/hyperedges_w_embeddings.json")

0-Compare-Require-Inable-Problem 772
1-Occurred 772
2-Involve-Investigate-Lose 772
3-Bombing-Surrender-Celebrate 772
4-Appointment-Question-Lack-Link 772
5-Mention-Criticize 772
6-Describe-Imprison-Convict-Punish 772
9-Break into-Sell-Belong-Claim 772
11-Detain-Fire 772
12-Dismissal-Replace-Oppose-Convince 772
13--[Note: I have used "Explain" and "Criticize" as triggers-Criticize-Explain-Mention 772
15-Claim 772
16-Compare-Require-Inable-Problem 772
17-Owe-Criticism 772
19-Resentment-End-Unable-Support 772
20-Criticism-Speculate-Sidelining 772
22-Return-Defect-Provide 772
23-Occur-Demonstrate-Show 772
24-Fear-Warned-Find 772
25-Lobbying activities-Charity-Conference call 772
26-Suggest-Critcize 772
27-Involve-Tensions-Make 772
28-Rising-Seen-Suffering 772
29-Invasion-Lead to-Discuss-Mention 772
30-Create-Target-Shot-Homicides-Increase 772
32-Need-Not Reason-Not Avoiding-Cause 772
33-Breakdown-Cause-Reduce 772
34-Claim-Rise-Lead-Cause 772
37-Found-Impersonation 772
39-Harass-Adopt-Discr

In [13]:
# Candidate embeddings come from the 'embedding' field that the embedding cell
# stores on each event in event_dict. Building the mapping here means this cell
# no longer depends on a variable left over in the kernel.
embeddings_dict = {candidate: event_dict[candidate]['embedding'] for candidate in candidates}

# Every unordered pair of candidates, in the same order as the nested loop
# (a before b, following the order of `candidates`).
pairs = list(combinations(candidates, 2))
# NOTE: despite the name, these are cosine *similarities* (higher = more alike),
# not distances. The name is kept because the printing cell reads `distances`.
distances = [cos_sim(embeddings_dict[a], embeddings_dict[b]) for a, b in pairs]

In [14]:
# Show each candidate pair next to its cosine similarity score.
for pair, similarity in zip(pairs, distances):
    print(pair, similarity)

('0-Compare-Require-Inable-Problem', '1-Occurred') 0.7323959316776886
('0-Compare-Require-Inable-Problem', '2-Involve-Investigate-Lose') 0.7606094426085571
('0-Compare-Require-Inable-Problem', '3-Bombing-Surrender-Celebrate') 0.7294283318371837
('0-Compare-Require-Inable-Problem', '4-Appointment-Question-Lack-Link') 0.7907936443116823
('0-Compare-Require-Inable-Problem', '5-Mention-Criticize') 0.83052185844559
('0-Compare-Require-Inable-Problem', '6-Describe-Imprison-Convict-Punish') 0.744892401146626
('1-Occurred', '2-Involve-Investigate-Lose') 0.737870538774968
('1-Occurred', '3-Bombing-Surrender-Celebrate') 0.7418198663287412
('1-Occurred', '4-Appointment-Question-Lack-Link') 0.7305941830050053
('1-Occurred', '5-Mention-Criticize') 0.7329799735667253
('1-Occurred', '6-Describe-Imprison-Convict-Punish') 0.7707750971754709
('2-Involve-Investigate-Lose', '3-Bombing-Surrender-Celebrate') 0.6975508151855792
('2-Involve-Investigate-Lose', '4-Appointment-Question-Lack-Link') 0.812926269681