In [None]:
from collections import defaultdict
from pprint import pprint
import json
import requests
import random
import openai
import time

In [None]:
def save_json(data, filepath=r'new_data.json'):
    """Write ``data`` to ``filepath`` as JSON indented by 4 spaces, replacing any existing file."""
    with open(filepath, mode='w') as handle:
        json.dump(obj=data, fp=handle, indent=4)

# From Sam:
# I believe you have a different query approach. Feel free to switch to yours.
def request_chatgpt_gpt4(messages, url="http://127.0.0.1:5000/event_hgraph", timeout=None):
    """POST chat ``messages`` to the completion endpoint and return the reply text.

    Args:
        messages: list of chat-message dicts (``role``/``content`` and optionally
            ``name``), as built by the query_* functions in this notebook.
        url: endpoint to call; defaults to the local server used so far.
        timeout: optional ``requests`` timeout in seconds. ``None`` (the default)
            keeps the previous behaviour of waiting without a limit.

    Returns:
        ``choices[0].message.content`` from the JSON response, stripped of
        surrounding whitespace.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status. Without
            this check the failure would surface later as an opaque KeyError or
            JSON decode error.
    """
    response = requests.post(url, json={"messages": messages}, timeout=timeout)
    response.raise_for_status()
    payload = response.json()
    return payload['choices'][0]['message']['content'].strip()


### Below are the prompts used to generate topics

In [56]:
def clusterLabelToHyperedge(cluster_labels, partition, hyperedge_dict):
    """Return the node records belonging to the given cluster labels.

    Args:
        cluster_labels: labels such as ``"L-2-108"``; the third ``-``-separated
            field is the cluster id.
        partition: mapping node id -> cluster id (ids are compared as strings).
        hyperedge_dict: mapping node id -> node record.

    Returns:
        A flat list of node records, grouped by label in the order given and, within
        a label, in the partition's iteration order.
    """
    members_by_cluster = defaultdict(list)
    for member_id, cluster_id in partition.items():
        members_by_cluster[str(cluster_id)].append(member_id)
    return [
        hyperedge_dict[member_id]
        for label in cluster_labels
        for member_id in members_by_cluster[label.split("-")[2]]
    ]

def query_leaf_topic(nodes, node_type):
    """Ask the LLM for a short topic label describing a set of nodes.

    Args:
        nodes: for ``node_type == 'article'``, a list of dicts that each carry a
            ``'summary'`` string; otherwise a list of entity-name strings.
        node_type: ``'article'`` selects the article-summarization prompt; any
            other value selects one of the entity prompts.

    Returns:
        The reply text returned by ``request_chatgpt_gpt4``.

    For entities, lists longer than 20 use a "split into categories" prompt whose
    requested reply format is ``Category 1, Category 2, ...``; shorter lists use a
    "pick no more than 3 representative entities" prompt. Few-shot examples are
    read from JSON files under data/result/AllTheNews/cluster_summary/.
    """
    if node_type == 'article':
        # Read the few-shot example and close the file right away (previously the handle leaked).
        with open(r'data/result/AllTheNews/cluster_summary/example_article.json') as example_file:
            example = json.load(example_file)
        summaries = [node['summary'] for node in nodes]
        summaries_message = ""
        for index, summary in enumerate(summaries):
            summaries_message += "Article {}: \n".format(index+1)
            summaries_message += summary + '\n\n\n'
        messages = [
            { 
                "role": "system", 
                "content": """
                    You are a news article summarization system. 
                    The user will provide you with a set of summarized news articles, your job is to further summarize them into one noun phrase.
                    Use words that are already in the articles, and try to use as few words as possible.
                """
            },
            # few-shot pair: system messages named example_user / example_system
            { "role": "system", "name": "example_user", "content": example['leaf']['summaries']},
            { "role": "system", "name": "example_system", "content": example['leaf']['topic']},
            { "role": "user", "content": summaries_message}
        ]
        topic = request_chatgpt_gpt4(messages)
        return topic
    else:
        # Read the few-shot examples and close the file right away (previously the handle leaked).
        with open(r'data/result/AllTheNews/cluster_summary/example_entity.json') as example_file:
            example = json.load(example_file)
        # Long lists: ask for category names; short lists: ask for representative entities.
        if len(nodes) > 20:
            messages = [
            { 
                "role": "system", 
                "content": """
                    You are an entity summarization system.
                    The user will provide you with a list of entities, they can be people, places, or things.
                    The user wants to get a gist of what entities are in the list.
                    First, split the entities into different categories.
                    Then, assign each category a human-readable name.
                    If entities in a category are all related to a specific entity, use that entity as the category.
                    Reply with the following format:
                    Category 1, Category 2, Category 3, ...
                """
            },
            # example 0
            { "role": "system", "name": "example_user", "content": 
            """
            Entities: {} \n
            """.format(example['non-leaf'][0]['entities'])
            },
            { "role": "system", "name": "example_system", "content": example['non-leaf'][0]['category']},
            # example 1
            { "role": "system", "name": "example_user", "content": 
            """
            Entities: {} \n
            """.format(example['non-leaf'][1]['entities'])
            },
            { "role": "system", "name": "example_system", "content": example['non-leaf'][1]['category']},
            # user input
            { "role": "user", "content": 
            """
            Entities: {} \n
            """.format(", ".join(nodes))
            }
        ]
        else:
            messages = [
                { 
                    "role": "system", 
                    "content": """
                        You are an entity summarization system.
                        The user will provide you with a list of entities, they can be people, places, or things.
                        The user wants to get a gist of what entities are in the list.
                        Pick out a few entities that best represents the list.
                        Avoid picking out overlapping entities.
                        Do not pick more than 3 entities.
                    """
                },
                # example 1
                { "role": "system", "name": "example_user", "content": 
                """
                Entities: {} \n
                What kinds of entities are there? \n
                """.format(example['leaf'][0]['entities'])
                },
                { "role": "system", "name": "example_system", "content": example['leaf'][0]['category']},
                # example 2
                { "role": "system", "name": "example_user", "content": 
                """
                Entities: {} \n
                What kinds of entities are there? \n
                """.format(example['leaf'][1]['entities'])
                },
                { "role": "system", "name": "example_system", "content": example['leaf'][1]['category']},
                # example 3
                # { "role": "system", "name": "example_user", "content": 
                #  """
                #  Entities: {} \n
                #  What kinds of entities are there? \n
                #  """.format(example['leaf'][2]['entities'])
                # },
                # { "role": "system", "name": "example_system", "content": example['leaf'][2]['category']},
                # user input
                { "role": "user", "content": 
                """
                Entities: {} \n
                What kinds of entities are there? \n
                """.format(", ".join(nodes))
                }
            ]
        topic = request_chatgpt_gpt4(messages)
        return topic

def query_cluster_topic(cluster_subtopics, cluster_samples, node_type='article'):
    """Ask the LLM for one noun phrase that summarizes a non-leaf article cluster.

    Args:
        cluster_subtopics: topic strings of the cluster's children.
        cluster_samples: sample nodes drawn from the cluster; each must be a dict
            with a ``'summary'`` string.
        node_type: only ``'article'`` is handled. The default of ``'article'``
            keeps two-argument calls working.

    Returns:
        The reply text from ``request_chatgpt_gpt4`` for articles, or ``None`` for
        any other node type.
    """
    if node_type == 'article':
        with open(r'data/result/AllTheNews/cluster_summary/example.json') as example_file:
            example = json.load(example_file)
        # Build the user message in the same layout as the stored few-shot example:
        # the sub-topic list, then the numbered sample articles. (Previously the
        # article text was assembled but never appended, so the model saw only the
        # sub-topics despite the system prompt promising examples.)
        query = "Sub-topics: "
        query += ", ".join(cluster_subtopics) + '\n\n\n'
        for index, cluster_sample in enumerate(cluster_samples):
            query += "Article {}: \n".format(index+1)
            query += cluster_sample['summary'] + '\n\n\n'

        messages = [
            { 
                "role": "system", 
                "content": """
                    You are a news article categorization system. 
                    The user will provide you with a list of sub-topics of news articles and a few examples from the sub-topics.
                    Your job is to further categorize the sub-topics into a single noun-phrase that best summarizes all the sub-topics.
                    Try to reuse the words in the examples.
                """
            },
            { "role": "system", "name": "example_user", "content": example['non-leaf']['summaries']},
            { "role": "system", "name": "example_system", "content": example['non-leaf']['topic']},
            { "role": "user", "content": query}
        ]
        topic = request_chatgpt_gpt4(messages)
        return topic
    else:
        # Non-article clusters are not handled by this prompt.
        return None

def add_hierarchical_topic(hierarchy, partitions, hyperedge_dict, topic_dict, filepath, sampleFlag=True):
    """Label every node of ``hierarchy`` with a topic and return ``topic_dict``.

    Thin entry point over ``dfs``. ``topic_dict`` is filled in place, nodes that
    already have a topic are skipped, and the dict is checkpointed to
    ``filepath`` as it grows. The returned object is the same dict that was
    passed in.
    """
    dfs(
        hierarchy,
        partitions,
        hyperedge_dict,
        topic_dict,
        filepath,
        sampleFlag=sampleFlag,
    )
    return topic_dict

def dfs(hierarchy, partitions, hyperedge_dict, topic_dict, filepath, sampleFlag=True):
    """Recursively assign an LLM-generated topic to ``hierarchy`` and its descendants.

    Node keys have the form ``"L-<level>-<cluster id>"`` (e.g. ``"L-2-108"``).

    * Level-1 nodes are labelled from the ``'title'`` of every member of their
      child clusters, looked up through ``partitions[0]``.
    * Higher-level nodes are processed children-first. The comma-separated
      topics of the children are flattened into one list, which is labelled
      again with ``query_leaf_topic``.

    ``topic_dict`` is mutated in place and written to ``filepath`` after every
    new topic, so an interrupted run can resume. A node already present in
    ``topic_dict`` is skipped together with its subtree.

    Args:
        hierarchy: node dict with ``'key'`` and ``'children'``.
        partitions: sequence indexed by level; each entry maps node id -> cluster id.
        hyperedge_dict: node id -> node record (entity records need a ``'title'``).
        topic_dict: cluster key -> topic string; updated in place.
        filepath: checkpoint path passed to ``save_json``.
        sampleFlag: when true, also collect the member nodes of each child and
            draw up to 20 random samples. These samples are only consumed by the
            article-mode query, which is currently commented out.
    """
    key = hierarchy['key']
    if key in topic_dict:
        # Already labelled (e.g. resumed from a checkpoint): skip this node and its subtree.
        # Checked first so no lookups are done for finished nodes.
        return
    level = int(key.split('-')[1])

    if level == 1:
        # Label the cluster from the titles of all its leaf members.
        child_labels = [child['key'] for child in hierarchy['children']]
        members = clusterLabelToHyperedge(child_labels, partitions[0], hyperedge_dict)
        entity_titles = [member['title'] for member in members]
        gpt_topic = query_leaf_topic(entity_titles, node_type='entity')
    else:
        all_hyperedges = []
        for child in hierarchy['children']:
            dfs(child, partitions, hyperedge_dict, topic_dict, filepath, sampleFlag)
            if sampleFlag:
                # Member lookup is only needed for sampling; use the child's own level.
                child_level = int(child['key'].split('-')[1])
                all_hyperedges += clusterLabelToHyperedge([child['key']], partitions[child_level], hyperedge_dict)

        # Children's topics are comma-separated lists: flatten them and drop empty
        # fragments (e.g. produced by a trailing comma).
        cluster_subtopics = [
            item.strip()
            for child in hierarchy['children']
            for item in topic_dict[child['key']].split(',')
            if item.strip()
        ]
        if sampleFlag:
            # Only used by the article-mode query below (currently disabled).
            sample_hyperedges = random.sample(all_hyperedges, min(20, len(all_hyperedges)))
        # article mode:
        # gpt_topic = query_cluster_topic(cluster_subtopics, sample_hyperedges)
        print("non-leaf: ", cluster_subtopics)
        gpt_topic = query_leaf_topic(cluster_subtopics, node_type='entity')

    topic_dict[key] = gpt_topic
    # Checkpoint after every new topic so an interrupted run can resume.
    save_json(topic_dict, filepath)
    print(key, gpt_topic)

### The main pipeline starts below.
### I've renamed the variables for you (from hyperedges to entities) and kept the original lines as comments.
### Note that when I wrote this code, hyperedges == articles.
### In your context, you will operate on entities instead.
### This is only for reference, to help you understand the testing/debugging code.
### Be aware that inside the dfs function, variables named 'hyperedges' have not been renamed to 'entities'.
### Rename them if you find that confusing.

In [None]:
# 1. Read in the hierarchy and the per-level partitions.
# Entity mode (active):
with open('data/result/AllTheNews/network/server/ravasz_hierarchies_entity.json') as hierarchy_file:
    hierarchy = json.load(hierarchy_file)
with open('data/result/AllTheNews/network/server/ravasz_partitions_entity.json') as partitions_file:
    partitions = json.load(partitions_file)
# Article mode: load these files instead.
# hierarchy = json.load(open('data/result/AllTheNews/network/server/ravasz_hierarchies_article.json'))
# partitions = json.load(open('data/result/AllTheNews/network/server/ravasz_partitions_article.json'))

In [59]:
# 2. Read in entities.
with open('data/result/AllTheNews/network/entities.json') as entities_file:
    entities_dict = json.load(entities_file)
# hyperedges_dict = json.load(open('data/result/AllTheNews/network/hyperedges.json')) # the original line

# 3. Generate topics (entity mode). The same file is both the resume point loaded here
# and the checkpoint written after every new topic, so the path is defined once.
# If the file does not exist yet, start from an empty dict instead of failing.
breakpoint_filepath = 'data/result/AllTheNews/hierarchical_topics_entities_raw.json'
try:
    with open(breakpoint_filepath) as topics_file:
        topic_dict = json.load(topics_file)
except FileNotFoundError:
    topic_dict = {}

# article
# topic_dict = json.load(open('data/result/AllTheNews/hierarchical_topics_articles.json'))
# topic_dict = add_hierarchical_topic(hierarchy, partitions, hyperedges_dict, topic_dict, 'data/result/AllTheNews/hierarchical_topics_articles.json')

In [60]:
topic_dict = add_hierarchical_topic(
    hierarchy,
    partitions,
    entities_dict,
    topic_dict,
    filepath=breakpoint_filepath,
    # Entity mode: no per-child sampling (samples only feed the commented-out article query in dfs).
    sampleFlag=False,
)

non-leaf:  ['Missouri', 'Mississippi', 'Kansas', 'Gloucester County Public Schools', "Prince George's County Public Schools", 'Joint Base Andrews', 'New Orleans', 'Louisiana', 'Washington (state)', 'D.C', 'Atlanta', 'Michigan', 'Baltimore', 'Maryland', 'Pennsylvania', 'Pittsburgh', 'Philadelphia', 'Pennsylvania woman', 'Andrea Constand v. William H. Cosby', 'Jr.', 'Clinton', 'Maryland', 'United States presidential elections in Missouri', 'United States presidential elections in Alabama', 'Cora Faith Walker', 'Springfield', 'Massachusetts', 'Boston', 'Massachusetts', 'Missouri State University', '2015–2016 University of Missouri protests', 'University of Missouri']
L-2-108 Places (Cities and States), Educational Institutions
non-leaf:  ['Illinois', 'Indiana', 'Iowa', 'Standing Rock Indian Reservation', 'Wyoming', 'Milwaukee', 'University of Wisconsin–Madison', 'Janesville', 'Wisconsin']
L-2-490 Illinois, Indiana, Iowa
non-leaf:  ['California', 'Colorado', 'South Dakota', 'South Carolina

### Below are testing/debugging cells.

In [None]:
# Inspect one level-1 cluster (id 2344): collect its member entities and ask for a topic.
# Reuses clusterLabelToHyperedge (as dfs does), which compares cluster ids as strings,
# and relies only on variables loaded above (the old `hyperedges_dict` is no longer defined).
entities_2344 = clusterLabelToHyperedge(['L-1-2344'], partitions[1], entities_dict)
query_leaf_topic([entity['title'] for entity in entities_2344], node_type='entity')

In [None]:
# Pick one non-leaf cluster and show: number of children, their keys, and the node itself.
# Alternative target used earlier:
# target_cluster = hierarchy['children'][0]['children'][0]['children'][1]['children'][1]
target_cluster = hierarchy['children'][2]['children'][1]['children'][1]['children'][0]
# target_cluster is a node dict; iterate its 'children' (iterating the dict itself yields key strings).
target_children = target_cluster['children']
len(target_children), [child['key'] for child in target_children], target_cluster

In [57]:
# Re-generate the topic of one non-leaf cluster from its children's topics and compare
# it with the topic currently stored in topic_dict.
target_cluster = hierarchy['children'][2]['children'][1]['children'][1]['children'][0]
# Children's topics are comma-separated lists; flatten them the same way dfs does.
cluster_subtopics = [
    item.strip()
    for child in target_cluster['children']
    for item in topic_dict[child['key']].split(",")
]
new_sub_topic = query_leaf_topic(cluster_subtopics, 'entity')
print(target_cluster['key'], topic_dict[target_cluster['key']])
print(cluster_subtopics)
print(new_sub_topic)
print("---------------------")

L-2-4 The entities you provided can be categorized as US airline companies.
['New York (state)', 'Texas', 'Lackland Air Force Base', 'United States', 'Mexico', 'Canada', 'Japan', 'North Korea', 'South Korea', 'Myanmar', 'Madaya', 'Vietnam', 'Jordan', 'Aung San Suu Kyi', 'China', 'Taiwan', 'United States trademark law', 'America', 'Americans', 'US', 'Radio in the United States', 'Society of the United States', 'Britain', 'United States customary units', 'Common Core State Standards Initiative', 'Pentagon', 'Pakistan', 'India', 'US-Cuba relations', 'China-United States relations', 'Japanese cuisine', 'Israeli cuisine', 'New Guinea', 'Kaka Point', 'National Aquarium of New Zealand', 'Anthem', 'Ohio', 'Portland', 'Oregon', 'Vietnam War', 'receiving countries', 'Four Asian Tigers', 'other industrialized countries', 'Same-sex marriage in Nevada', 'Same-sex marriage in Connecticut', 'Same-sex marriage in Belgium.', 'Turkish nationality law', 'Philippine nationality law', 'British nationality 

In [None]:
# Article-mode check of query_cluster_topic on a hand-picked set of level-1 clusters.
# NOTE(review): this cell needs article-mode data. `partitions` and `topic_dict` must be
# the article versions (see the commented-out article lines in the loading cells),
# not the entity ones loaded by default.
with open('data/result/AllTheNews/network/hyperedges.json') as hyperedges_file:
    hyperedges_dict = json.load(hyperedges_file)
children = [752, 1069, 1070, 1478]
# Alternative hand-written sub-topics (presumably the ones used for the few-shot example below):
# cluster_subtopics = [
#     'Increasing Gun Violence in Chicago',
#     'Crime Rates and Policing Tactics',
#     'Misconceptions about Crime in the United States',
#     'Global Events and Optimism',
# ]
cluster_subtopics = [topic_dict["L-1-{}".format(cluster_label)] for cluster_label in children]
# Group node ids by cluster id, normalizing ids to str so int/str JSON values both match.
clusters = defaultdict(list)
for node_id, cluster_label in partitions[1].items():
    clusters[str(cluster_label)].append(node_id)
cluster_samples = [node_id for cluster_label in children for node_id in clusters[str(cluster_label)]]
# random.sample raises ValueError when asked for more items than exist.
cluster_samples = random.sample(cluster_samples, min(10, len(cluster_samples)))
cluster_samples = [hyperedges_dict[sample] for sample in cluster_samples]
topic = query_cluster_topic(cluster_subtopics, cluster_samples, node_type='article')
sample_summaries = [sample['summary'] for sample in cluster_samples]
print(topic)
pprint(cluster_subtopics)
for summary in sample_summaries:
    print(summary)


In [None]:
# Build the few-shot example for the non-leaf article prompt: store the current
# sub-topics plus sample articles, and a hand-written topic label, in example.json.
sample_summaries = "Sub-topics: " + ", ".join(cluster_subtopics) + '\n\n\n'
sample_summaries += "".join(
    "Article {}: \n".format(position) + cluster_sample['summary'] + '\n\n\n'
    for position, cluster_sample in enumerate(cluster_samples, start=1)
)
with open(r'data/result/AllTheNews/cluster_summary/example.json') as example_file:
    example = json.load(example_file)
example['non-leaf']['summaries'] = sample_summaries
example['non-leaf']['topic'] = 'Crimes in the United States'
save_json(example, r'data/result/AllTheNews/cluster_summary/example.json')

In [58]:
# Keep only topics whose key level ("L-<level>-<id>") is 0 or 1 and drop all higher-level
# topics. The generation code skips keys that already exist, so the dropped ones are
# generated again on the next run.
# WARNING: this overwrites hierarchical_topics_entities_raw.json in place.
topics_path = 'data/result/AllTheNews/hierarchical_topics_entities_raw.json'
with open(topics_path) as topics_file:
    topics = json.load(topics_file)
level_1_topics = {
    cluster: topic
    for cluster, topic in topics.items()
    if int(cluster.split("-")[1]) < 2
}
save_json(level_1_topics, topics_path)
