In [1]:
%load_ext autoreload
%autoreload 2

from langchain.prompts import Prompt, BaseChatPromptTemplate
from langchain.schema import BaseOutputParser
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from tqdm import tqdm
import re
import pandas as pd
import dotenv
import random
import networkx as nx
from langchain.schema import (
    BaseMessage, 
    HumanMessage, 
    SystemMessage,
    AIMessage
)
dotenv.load_dotenv()

True

In [74]:
pairs_df = pd.read_csv('data/acm_ccs_clean.csv')

In [75]:
def get_nodes_edges(pairs_df, group):
    pairs = pairs_df[pairs_df['group'] == group]
    edges = []
    nodes = set()
    for i, row in pairs.iterrows():
        parent, child = row['parent'], row['child']
        parent = parent.replace('_', ' ')
        child = child.replace('_', ' ')
        edges.append({
            'parent': parent,
            'child': child,
        })
        nodes.add(parent)
        nodes.add(child)

    nodes = list(nodes)    
    # randomly shffle the nodes and relations to avoid learning pattern
    random.shuffle(edges)
    random.shuffle(nodes)

    return nodes, edges

def get_groups(pairs_df, split='train'):
    groups = pairs_df[pairs_df['type'] == split]['group'].unique()
    return groups

In [76]:


class TaxomomyPrompt(BaseChatPromptTemplate):
    def format_messages(self, **kwargs) -> list[BaseMessage]:
        group_examples = kwargs['group_examples']
        concepts = kwargs['concepts']

        prefix_prompt = (
            "You are an expert constructing a taxonomy from a list of concepts. Given a list of concepts,"
            "construct a taxonomy by creating a list of their parent-child relationships.\n\n"
        )

        prefix_message = SystemMessage(content=prefix_prompt)

        example_messages = []
        for nodes, edges in group_examples:
            node_prompt = '; '.join(nodes)
            edge_prompt = '; '.join([f"{edge['child']} is a subtopic of {edge['parent']}" for edge in edges])

            example_messages.append(
                HumanMessage(content=f"Concepts: {node_prompt}\nRelationships: ")
            )

            example_messages.append(
                AIMessage(content=f"{edge_prompt}\n\n")
            )
        
        concepts = '; '.join(concepts)
        question_message = HumanMessage(content=f"Concepts: {concepts}\nRelationships: ")


        return [prefix_message] + example_messages + [question_message]


In [78]:
few_shot_numbers = 5

train_groups = get_groups(pairs_df)
test_groups = get_groups(pairs_df, split='test')

def generate_train_examples(pairs_df, num_examples=5):
    random.shuffle(train_groups)
    train_group_examples = []
    for group in train_groups[:num_examples]:
        nodes, edges = get_nodes_edges(pairs_df, group)
        train_group_examples.append((nodes, edges))
    return train_group_examples
    

In [85]:
class TaxonomyParser(BaseOutputParser):
    def parse(self, output):
        output = output.split(';')
        output = [line for line in output if line != '']
        output = [re.split(' is a subtopic of ', line) for line in output]
        print(output)
        result = []
        for line in output:
            if len(line) != 2:
                continue
            result.append({
                'child': line[0].strip(),
                'parent': line[1].strip()
            })
        return result
    
def call_chain(chain, concepts, group_examples, num_retries=5):
    count = 0
    while count < num_retries:
        try:
            result = chain.run(
                concepts=concepts,
                group_examples=group_examples,
            )
            return result
        except Exception as e:
            print(e)
            count += 1
    raise Exception("Failed to generate result")

prompt = TaxomomyPrompt(
    input_variables=['concepts', 'group_examples'],
)
llm = ChatOpenAI(model='gpt-4')# ChatOpenAI(model='gpt-3.5-turbo')
chain = LLMChain(llm=llm, prompt=prompt, output_parser=TaxonomyParser())
results = []

In [86]:
for test_group in tqdm(test_groups):
    concepts, edges = get_nodes_edges(pairs_df, test_group)
    group_examples = generate_train_examples(pairs_df, few_shot_numbers)

    print(prompt.format(
        concepts=concepts,
        group_examples=group_examples,
    ))
    break

  0%|          | 0/15 [00:00<?, ?it/s]

System: You are an expert constructing a taxonomy from a list of concepts. Given a list of concepts,construct a taxonomy by creating a list of their parent-child relationships.


Human: Concepts: Law; Sociology; Economics; Psychology; Law, social and behavioral sciences; Ethnography; Anthropology
Relationships: 
AI: Law is a subtopic of Law, social and behavioral sciences; Ethnography is a subtopic of Anthropology; Psychology is a subtopic of Law, social and behavioral sciences; Sociology is a subtopic of Law, social and behavioral sciences; Anthropology is a subtopic of Law, social and behavioral sciences; Economics is a subtopic of Law, social and behavioral sciences


Human: Concepts: Internet communications tools; Simple Object Access Protocol (SOAP); Service discovery and interfaces; Web crawling; Online advertising; Secure online transactions; Web data description languages; Social advertising; Social tagging; Deep web; Content match advertising; Personalization; Web Ontology Lan




In [87]:
for test_group in tqdm(test_groups):
    concepts, edges = get_nodes_edges(pairs_df, test_group)
    group_examples = generate_train_examples(pairs_df, few_shot_numbers)

    print(llm.get_num_tokens_from_messages(prompt.format_messages(
        concepts=concepts,
        group_examples=group_examples,
    )))
    result = call_chain(chain, concepts, group_examples)
    results.append({
        'group': test_group,
        'nodes': concepts,
        'result': result,
    })

  0%|          | 0/15 [00:00<?, ?it/s]

1845


  7%|▋         | 1/15 [00:24<05:36, 24.04s/it]

[['Probability and statistics is a parent topic for all the concepts'], [' Bayesian computation, Bayesian networks, Gibbs sampling, Markov networks, Markov processes, Markov-chain Monte Carlo convergence measures, Markov-chain Monte Carlo methods, Metropolis-Hastings algorithm are subtopics of Probabilistic reasoning algorithms'], [' Hypothesis testing and confidence interval computation, Maximum likelihood estimation, Regression analysis, Robust regression are subtopics of Statistical paradigms'], [' Resampling methods, Jackknifing, Bootstrapping are subtopics of Nonparametric statistics'], [' Density estimation, Kernel density estimators are subtopics of Nonparametric representations'], [' Multivariate statistics, Cluster analysis, Dimensionality reduction are subtopics of Exploratory data analysis'], [' Factor graphs, Decision diagrams are subtopics of Probabilistic representations'], [' Causal networks, Bayesian networks, Markov networks are subtopics of Equational models'], [' Con

 13%|█▎        | 2/15 [00:26<02:27, 11.32s/it]

[['Statistical software', 'Mathematical software'], [' Mathematical software performance', 'Mathematical software'], [' Solvers', 'Mathematical software.']]
3003


 20%|██        | 3/15 [00:33<01:50,  9.21s/it]

[['Differential calculus', 'Calculus'], [' Integral calculus', 'Calculus'], [' Point-set topology', 'Topology'], [' Geometric topology', 'Topology'], [' Algebraic topology', 'Topology'], [' Calculus', 'Continuous mathematics'], [' Topology', 'Continuous mathematics'], [' Lambda calculus', 'Calculus'], [' Continuous functions', 'Continuous mathematics.']]
1557


 27%|██▋       | 4/15 [00:56<02:44, 14.98s/it]

[['Information retrieval is a parent topic of Document filtering, Presentation of retrieval results, Query representation, Video search, Novelty in information retrieval, Search interfaces, Distributed retrieval, Search engine architectures and scalability, Language models, Document collection models, Multilingual and cross-lingual retrieval, Music retrieval, Environment-specific retrieval, Business intelligence, Combination, fusion and federated search, Information extraction, Content analysis and feature selection, Probabilistic retrieval models, Enterprise search, Top-k retrieval in databases, Information retrieval query processing, Retrieval effectiveness, Near-duplicate and plagiarism detection, Summarization, Link and co-citation analysis, Search engine indexing, Similarity measures, Query reformulation, Question answering, Multimedia and multimodal retrieval, Evaluation of retrieval results, Searching with auxiliary databases, Structured text search, Ontologies, Chemical and bio

 33%|███▎      | 5/15 [01:05<02:08, 12.83s/it]

[['Graphical / visual passwords', 'Authentication'], [' Access control', 'Security services'], [' Digital rights management', 'Security services'], [' Pseudonymity, anonymity and untraceability', 'Privacy-preserving protocols'], [' Multi-factor authentication', 'Authentication'], [' Privacy-preserving protocols', 'Security services'], [' Biometrics', 'Authentication'], [' Authorization', 'Access control'], [' Authentication', 'Security services\n']]
1647


 40%|████      | 6/15 [01:10<01:30, 10.08s/it]

[['Domain-specific security and privacy architectures', 'Software security engineering'], [' Software reverse engineering', 'Software security engineering'], [' Web application security', 'Software and application security'], [' Social network security and privacy', 'Software and application security'], [' Software and application security', 'Software security engineering']]
2123


 47%|████▋     | 7/15 [01:34<01:56, 14.62s/it]

[['Social recommendation', 'Social networks'], [' Collaborative and social computing design and evaluation methods', 'Collaborative and social computing'], [' Computer supported cooperative work', 'Collaborative and social computing'], [' Collaborative and social computing devices', 'Collaborative and social computing'], [' Social networks', 'Collaborative and social computing'], [' Wikis', 'Collaborative content creation'], [' Collaborative and social computing theory, concepts and paradigms', 'Collaborative and social computing'], [' Asynchronous editors', 'Collaborative content creation'], [' Collaborative content creation', 'Collaborative and social computing'], [' Open source software', 'Collaborative and social computing systems and tools'], [' Social content sharing', 'Social networks'], [' Empirical studies in collaborative and social computing', 'Collaborative and social computing'], [' Social tagging', 'Social networks'], [' Social networking sites', 'Social networks'], [' Et

 53%|█████▎    | 8/15 [02:08<02:24, 20.58s/it]

[['Technology and censorship', 'Computing / technology policy'], [' Broadband access', 'Internet governance / domain names'], [' Patents', 'Intellectual property'], [' Hardware reverse engineering', 'Antitrust and competition'], [' Universal access', 'Internet governance / domain names'], [' Acceptable use policy restrictions', 'Internet governance / domain names'], [' Import / export controls', 'Governmental regulations'], [' Identity theft', 'Computer crime'], [' Medical technologies', 'Medical information policy'], [' Financial crime', 'Computer crime'], [' Network access restrictions', 'Internet governance / domain names'], [' Network access control', 'Internet governance / domain names'], [' Censorship', 'Technology and censorship'], [' Digital rights management', 'Intellectual property'], [' Social engineering attacks', 'Computer crime'], [' Copyrights', 'Intellectual property'], [' Pornography', 'Age-based restrictions'], [' Licensing', 'Intellectual property'], [' Database prot

 60%|██████    | 9/15 [03:42<04:21, 43.62s/it]

[['Phonology / morphology', 'Natural language processing'], [' Game tree search', 'Search methodologies'], [' Scene understanding', 'Computer vision'], [' Video segmentation', 'Computer vision tasks'], [' Visual inspection', 'Computer vision tasks'], [' Vision for robotics', 'Robotic planning'], [' Motion capture', 'Computer vision tasks'], [' Philosophical/theoretical foundations of artificial intelligence', 'Artificial intelligence'], [' Theory of mind', 'Cognitive science'], [' Planning for deterministic actions', 'Planning and scheduling'], [' Planning under uncertainty', 'Planning and scheduling'], [' Planning and scheduling', 'Artificial intelligence'], [' Interest point and salient region detections', 'Computer vision tasks'], [' Object recognition', 'Computer vision tasks'], [' Cognitive robotics', 'Robotics'], [' Tracking', 'Computer vision tasks'], [' Mobile agents', 'Intelligent agents'], [' Image representations', 'Computer vision representations'], [' Randomized search', '

 67%|██████▋   | 10/15 [04:03<03:04, 36.84s/it]

[['Modeling methodologies', 'Modeling and simulation'], [' Artificial life', 'Modeling and simulation'], [' Interactive simulation', 'Simulation types and techniques'], [' Simulation by animation', 'Simulation types and techniques'], [' Systems theory', 'Modeling and simulation'], [' Molecular simulation', 'Simulation types and techniques'], [' Discrete-event simulation', 'Simulation types and techniques'], [' Simulation languages', 'Simulation tools'], [' Real-time simulation', 'Simulation types and techniques'], [' Model development and analysis', 'Modeling methodologies'], [' Simulation support systems', 'Simulation tools'], [' Rare-event simulation', 'Simulation types and techniques'], [' Model verification and validation', 'Model development and analysis'], [' Scientific visualization', 'Visual analytics'], [' Agent / discrete models', 'Continuous models'], [' Quantum mechanic simulation', 'Simulation types and techniques'], [' Uncertainty quantification', 'Simulation theory'], ['

 73%|███████▎  | 11/15 [04:22<02:05, 31.28s/it]

[['Genomics', 'Life and medical sciences'], [' Health care information systems', 'Health informatics'], [' Computational transcriptomics', 'Transcriptomics'], [' Computational biology', 'Life and medical sciences'], [' Molecular sequence analysis', 'Genomics'], [' Population genetics', 'Genetics'], [' Imaging', 'Health informatics'], [' Molecular evolution', 'Molecular sequence analysis'], [' Computational genomics', 'Genomics'], [' Proteomics', 'Life and medical sciences'], [' Sequencing and genotyping technologies', 'Genomics'], [' Biological networks', 'Systems biology'], [' Bioinformatics', 'Life and medical sciences'], [' Molecular structural biology', 'Life and medical sciences'], [' Metabolomics / metabonomics', 'Life and medical sciences'], [' Recognition of genes and regulatory elements', 'Genomics'], [' Transcriptomics', 'Life and medical sciences'], [' Computational proteomics', 'Proteomics'], [' Health informatics', 'Life and medical sciences'], [' Consumer health', 'Health

 80%|████████  | 12/15 [04:39<01:20, 26.85s/it]

[['Sensor devices and platforms', 'Sensors and actuators'], [' Electro-mechanical devices', 'Sensors and actuators'], [' Haptic devices', 'Sensors and actuators'], [' Sensor applications and deployments', 'Sensors and actuators'], [' Wireless integrated network sensors', 'Sensors and actuators'], [' Tactile and hand-based interfaces', 'Signal processing systems'], [' Touch screens', 'Signal processing systems'], [' External storage', 'Communication hardware, interfaces and storage'], [' Printers', 'Communication hardware, interfaces and storage'], [' Scanners', 'Communication hardware, interfaces and storage'], [' Buses and high-speed links', 'Networking hardware'], [' Wireless devices', 'Networking hardware'], [' Beamforming', 'Digital signal processing'], [' Noise reduction', 'Digital signal processing'], [' Digital signal processing', 'Signal processing systems'], [' Sound-based input / output', 'Signal processing systems'], [' Displays and imagers', 'Signal processing systems\n']]


 87%|████████▋ | 13/15 [04:52<00:45, 22.75s/it]

[['Design for debug', 'Hardware validation'], [' Post-manufacture validation and debug', 'Hardware validation'], [' Coverage metrics', 'Functional verification'], [' Bug fixing (hardware)', 'Bug detection, localization and diagnosis'], [' Assertion checking', 'Functional verification'], [' Bug detection, localization and diagnosis', 'Hardware validation'], [' Power and thermal analysis', 'Physical verification'], [' Timing analysis and sign-off', 'Physical verification'], [' Semi-formal verification', 'Functional verification'], [' Layout-versus-schematics', 'Physical verification'], [' Transaction-level verification', 'Functional verification'], [' Equivalence checking', 'Physical verification'], [' Theorem proving and SAT solving', 'Model checking'], [' Design rule checking', 'Physical verification'], [' Simulation and emulation', 'Hardware validation'], [' Model checking', 'Functional verification'], [' Functional verification', 'Hardware validation'], [' Physical verification', 'Ha

 93%|█████████▎| 14/15 [05:03<00:19, 19.19s/it]

[['Hardware test', 'Online test and diagnostics'], [' Test data compression', 'Online test and diagnostics'], [' Testing with distributed and parallel systems', 'Online test and diagnostics'], [' Test-pattern generation and fault simulation', 'Online test and diagnostics'], [' Fault models and test metrics', 'Online test and diagnostics'], [' Hardware reliability screening', 'Online test and diagnostics'], [' Design for testability', 'Online test and diagnostics'], [' Analog, mixed-signal and radio frequency test', 'Hardware test'], [' Built-in self-test', 'Hardware test'], [' Board- and system-level test', 'Hardware test'], [' Memory test and repair', 'Hardware test'], [' Defect-based test', 'Hardware test\n']]
3130


100%|██████████| 15/15 [05:26<00:00, 21.75s/it]

[['Carbon based electronics', 'Emerging technologies'], [' Circuit substrates', 'Emerging technologies'], [' Quantum dots and cellular automata', 'Emerging technologies'], [' Quantum communication and cryptography', 'Quantum technologies'], [' Quantum error correction and fault tolerance', 'Quantum technologies'], [' Single electron devices', 'Emerging technologies'], [' Flexible and printable circuits', 'Emerging technologies'], [' III-V compounds', 'Emerging technologies'], [' Memory and dense storage', 'Emerging technologies'], [' Quantum computation', 'Quantum technologies'], [' Biology-related information processing', 'Bio-embedded electronics'], [' Microelectromechanical systems', 'Electromechanical systems'], [' Nanoelectromechanical systems', 'Electromechanical systems'], [' Spintronics and magnetic technologies', 'Emerging technologies'], [' Tunneling devices', 'Emerging technologies'], [' Reversible logic', 'Emerging technologies'], [' Emerging architectures', 'Emerging techn




In [88]:
len(results)

15

In [103]:
test_groups

array(['10002950_2', '10002950_3', '10002950_5', '10002951_5',
       '10002978_3', '10002978_9', '10003120_3', '10003456_2',
       '10010147_3', '10010147_5', '10010405_4', '10010583_2',
       '10010583_7', '10010583_8', '10010583_10'], dtype=object)

In [89]:
result_pairs = []
for result in results:
    for edge in result['result']:
        result_pairs.append({
            'group': result['group'],
            'child': edge['child'],
            'parent': edge['parent'],
        })

In [110]:
df = pd.DataFrame(result_pairs)
df.to_csv('./results/gpt-3/ccs/gpt4.csv', index=None)

# Evluation

In [19]:
from utils import (
    convert_to_ancestor_graph, 
    maximum_absorbance, 
    dataframe_to_ancestor_graph, 
    evaluate_groups,
    maximum_likelihood,
    maximum_branching,
    majority_voting
)
import numpy as np

In [3]:
actual_tree='data/bansal_wordnet_true_pairs.csv' # bansal_wordnet_true_pairs acm_ccs_clean
data_dir = f'./results/gpt-3/wordnet/'

## Get individual best result

In [9]:
dfs = []
num_generations = 5

for i in range(1, num_generations + 1):
    filename = f'{data_dir}/results_{i}.csv'
    df = pd.read_csv(filename)
    df['child'] = df['child'].apply(lambda x: x.replace(' ', '_'))
    df['parent'] = df['parent'].apply(lambda x: x.replace(' ', '_'))
    dfs.append(df)

# get ground truth
df_actual=pd.read_csv(actual_tree)
df_actual=df_actual[df_actual['type'] == 'test']
df_actual = dataframe_to_ancestor_graph(df_actual)

100%|██████████| 114/114 [00:00<00:00, 1070.28it/s]


In [11]:
recalls = []
precisions = []
f1s = []

for df in dfs:
    df = dataframe_to_ancestor_graph(df)
    recall, precision, f1 = evaluate_groups(df_actual, df)

    recalls.append(recall)
    precisions.append(precision)
    f1s.append(f1)

100%|██████████| 114/114 [00:00<00:00, 932.42it/s]
100%|██████████| 114/114 [00:00<00:00, 897.05it/s]
100%|██████████| 114/114 [00:00<00:00, 801.07it/s]
100%|██████████| 114/114 [00:00<00:00, 916.89it/s]
100%|██████████| 114/114 [00:00<00:00, 1090.72it/s]
100%|██████████| 114/114 [00:00<00:00, 786.25it/s]
100%|██████████| 114/114 [00:00<00:00, 1028.07it/s]
100%|██████████| 114/114 [00:00<00:00, 796.43it/s]
100%|██████████| 114/114 [00:00<00:00, 968.85it/s]
100%|██████████| 114/114 [00:00<00:00, 881.62it/s]


In [14]:
print("Best Recall: ", np.max(recalls))
print("Best Precision: ", np.max(precisions))
print("Best F1: ", np.max(f1s))

Best Recall:  0.5364916307655639
Best Precision:  0.646368967793948
Best F1:  0.572489046346162


## Combine Multiple Predictions (Maximum branching)

In [42]:
dfs = []
num_generations = 5

for i in range(1, num_generations + 1):
    filename = f'{data_dir}/results_{i}.csv'
    df = pd.read_csv(filename)
    df['child'] = df['child'].apply(lambda x: x.replace(' ', '_'))
    df['parent'] = df['parent'].apply(lambda x: x.replace(' ', '_'))
    dfs.append(df)

df = pd.concat(dfs)


In [43]:
rows = []
columns = df.columns.tolist()

for group_name, group_df in df.groupby(columns):
    count = group_df.shape[0]
    new_row = {
        columns[0]: group_name[0],
        columns[1]: group_name[1],
        columns[2]: group_name[2],
        'predict': count,
    }
    rows.append(new_row)

merged_df = pd.DataFrame(rows)

In [44]:
def convert_to_ancestor_graph(G):
    '''Converts a (parent) tree to a graph with edges for all ancestor relations in the tree.'''
    G_anc = nx.DiGraph()
    for node in G.nodes():
        for anc in nx.ancestors(G, node):
            G_anc.add_edge(anc, node)
    return G_anc

import networkx as nx

forest=[]
for g in tqdm(list(set(merged_df.group))):
    df = maximum_branching(merged_df, g)
    forest.append(df)
res_v2=pd.concat(forest, ignore_index=True)



100%|██████████| 114/114 [00:00<00:00, 431.74it/s]


In [45]:
df_actual=pd.read_csv(actual_tree)
df_actual=df_actual[df_actual['type'] == 'test']
df_actual['child'] = df_actual['child'].apply(lambda x: x.replace(' ', '_'))
df_actual['parent'] = df_actual['parent'].apply(lambda x: x.replace(' ', '_'))

In [46]:
df_actual = dataframe_to_ancestor_graph(df_actual)

100%|██████████| 114/114 [00:00<00:00, 1101.16it/s]


In [47]:
recall, precision, f1 = evaluate_groups(df_actual, res_v2)
print("Recall: ", recall, "Precision: ", precision, "F1: ", f1)

100%|██████████| 114/114 [00:00<00:00, 822.05it/s]

Recall:  0.6645223913488745 Precision:  0.6276596550517497 F1:  0.630155470051838





In [36]:
df_actual[df_actual.group=='10010147_3']

Unnamed: 0,parent,child,group,compare


In [26]:
res_v2[res_v2.group=='10010147_3']

Unnamed: 0,parent,child,group,compare


In [37]:
recall = []
precision = []
f1 = []
for group in tqdm(list(set(df_actual.group))):
    group_actual = df_actual[df_actual.group == group]
    group_pred = res_v2[res_v2.group == group]
    recall.append(len(group_actual.merge(group_pred, on='compare')) / len(group_actual))
    precision.append(len(group_actual.merge(group_pred, on='compare')) / len(group_pred))
    if precision[-1] + recall[-1] == 0:
        f1.append(0)
    else:
        f1.append(2 * (precision[-1] * recall[-1]) / (precision[-1] + recall[-1]))

100%|██████████| 114/114 [00:00<00:00, 494.88it/s]


In [38]:
import numpy as np
np.mean(recall)

0.48965134192568305

In [39]:
np.mean(precision)

0.7506532177798615

In [40]:
np.mean(f1)

0.5704412106495687