In [44]:
from langchain.prompts import Prompt, BaseChatPromptTemplate
from langchain.schema import BaseOutputParser
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from tqdm import tqdm
import re
import pandas as pd
import dotenv
import random
import networkx as nx
from langchain.schema import (
    BaseMessage, 
    HumanMessage, 
    SystemMessage,
    AIMessage
)
dotenv.load_dotenv()

True

In [45]:
pairs_df = pd.read_csv('data/acm_ccs_clean.csv')

In [46]:
def get_nodes_edges(pairs_df, group):
    pairs = pairs_df[pairs_df['group'] == group]
    edges = []
    nodes = set()
    for i, row in pairs.iterrows():
        parent, child = row['parent'], row['child']
        parent = parent.replace('_', ' ')
        child = child.replace('_', ' ')
        edges.append({
            'parent': parent,
            'child': child,
        })
        nodes.add(parent)
        nodes.add(child)

    nodes = list(nodes)    
    # randomly shffle the nodes and relations to avoid learning pattern
    random.shuffle(edges)
    random.shuffle(nodes)

    return nodes, edges

def get_groups(pairs_df, split='train'):
    groups = pairs_df[pairs_df['type'] == split]['group'].unique()
    return groups

In [47]:


class TaxomomyPrompt(BaseChatPromptTemplate):
    def format_messages(self, **kwargs) -> list[BaseMessage]:
        group_examples = kwargs['group_examples']
        concepts = kwargs['concepts']

        prefix_prompt = (
            "You are an expert constructing a taxonomy from a list of concepts. Given a list of concepts,"
            "construct a taxonomy by creating a list of their parent-child relationships.\n\n"
        )

        prefix_message = SystemMessage(content=prefix_prompt)

        example_messages = []
        for nodes, edges in group_examples:
            node_prompt = ', '.join(nodes)
            edge_prompt = ', '.join([f"{edge['child']} is a {edge['parent']}" for edge in edges])

            example_messages.append(
                HumanMessage(content=f"Concepts: {node_prompt}\nRelationships: ")
            )

            example_messages.append(
                AIMessage(content=f"{edge_prompt}\n\n")
            )
        
        concepts = ', '.join(concepts)
        question_message = HumanMessage(content=f"Concepts: {concepts}\nRelationships: ")


        return [prefix_message] + example_messages + [question_message]


In [48]:
few_shot_numbers = 5

train_groups = get_groups(pairs_df)
test_groups = get_groups(pairs_df, split='test')

def generate_train_examples(pairs_df, num_examples=5):
    random.shuffle(train_groups)
    train_group_examples = []
    for group in train_groups[:num_examples]:
        nodes, edges = get_nodes_edges(pairs_df, group)
        train_group_examples.append((nodes, edges))
    return train_group_examples
    

In [55]:
class TaxonomyParser(BaseOutputParser):
    def parse(self, output):
        print(output)
        output = output.split(',')
        output = [line for line in output if line != '']
        output = [re.split(' is a | is an ', line) for line in output]
        result = []
        for line in output:
            result.append({
                'child': line[0].strip(),
                'parent': line[1].strip()
            })
        return result
    
def call_chain(chain, concepts, group_examples, num_retries=5):
    count = 0
    while count < num_retries:
        try:
            result = chain.run(
                concepts=concepts,
                group_examples=group_examples,
            )
            return result
        except Exception as e:
            print(e)
            count += 1
    raise Exception("Failed to generate result")

prompt = TaxomomyPrompt(
    input_variables=['concepts', 'group_examples'],
)
llm = ChatOpenAI(model='gpt-3.5-turbo', max_tokens=-1)
chain = LLMChain(llm=llm, prompt=prompt, output_parser=TaxonomyParser())
results = []

In [57]:
for test_group in tqdm(test_groups):
    concepts, edges = get_nodes_edges(pairs_df, test_group)
    group_examples = generate_train_examples(pairs_df, few_shot_numbers)

    
    result = call_chain(chain, concepts, group_examples)
    results.append({
        'group': test_group,
        'nodes': concepts,
        'result': result,
    })

  7%|▋         | 1/15 [00:14<03:18, 14.18s/it]

Probabilistic inference problems is a Probabilistic reasoning algorithms, Statistical graphics is a Exploratory data analysis, Expectation maximization is a Probabilistic algorithms, Kernel density estimators is a Density estimation, Max marginal computation is a Probabilistic algorithms, Quantile regression is a Regression analysis, Random number generation is a Probabilistic algorithms, Gibbs sampling is a Markov-chain Monte Carlo methods, Markov networks is a Markov processes, Density estimation is a Nonparametric representations, Dimensionality reduction is a Nonparametric statistics, Distribution functions is a Probability and statistics, Cluster analysis is a Multivariate statistics, Resampling methods is a Nonparametric statistics, Decision diagrams is a Equational models, Time series analysis is a Stochastic processes, Spline models is a Regression analysis, Markov-chain Monte Carlo methods is a Probabilistic algorithms, Survival analysis is a Probability and statistics, Markov

 13%|█▎        | 2/15 [00:15<01:25,  6.58s/it]

Mathematical software performance is a Mathematical software, Solvers is a Mathematical software, Statistical software is a Mathematical software.


 20%|██        | 3/15 [00:17<00:53,  4.45s/it]

Continuous functions is a Differential calculus, Integral calculus is a Differential calculus, Differential calculus is a Calculus, Algebraic topology is a Topology, Point-set topology is a Topology, Geometric topology is a Topology, Topology is a Continuous mathematics, Lambda calculus is a Calculus
Top-k retrieval in databases is a Information retrieval, Query representation is a Information retrieval query processing, Similarity measures is a Information retrieval, Query reformulation is a Query representation, Information retrieval is a Document representation, Document structure is a Document representation, Language models is a Retrieval models and ranking, Multilingual and cross-lingual retrieval is a Language models, Retrieval effectiveness is a Evaluation of retrieval results, Task models is a Information retrieval, Query intent is a Query representation, Retrieval models and ranking is a Information retrieval, Query suggestion is a Query representation, Content analysis and f

 20%|██        | 3/15 [01:00<04:00, 20.05s/it]


KeyboardInterrupt: 

In [20]:
len(results)

114

In [21]:
result_pairs = []
for result in results:
    for edge in result['result']:
        result_pairs.append({
            'group': result['group'],
            'child': edge['child'],
            'parent': edge['parent'],
        })

In [22]:
df = pd.DataFrame(result_pairs)
df.to_csv('./results/gpt-3/wordnet/results_5.csv', index=None)

## Combine Multiple Predictions

In [24]:
actual_tree='data/bansal_wordnet_true_pairs.csv'
dfs = []
num_generations = 5

for i in range(1, num_generations + 1):
    filename = f'./results/gpt-3/wordnet/results_{i}.csv'
    df = pd.read_csv(filename)
    df['child'] = df['child'].apply(lambda x: x.replace(' ', '_'))
    df['parent'] = df['parent'].apply(lambda x: x.replace(' ', '_'))
    dfs.append(df)

df = pd.concat(dfs)
df


Unnamed: 0,group,child,parent
0,647,backblast,blast
1,647,explosion,blast
2,647,backfire,blast
3,647,big_bang,explosion
4,647,blowback,blast
...,...,...,...
2090,760,docosahexaenoic_acid,omega-3_fatty_acid
2091,760,ricinoleic_acid,fatty_acid
2092,760,linolenic_acid,omega-3_fatty_acid
2093,760,alpha-linolenic_acid,linolenic_acid


In [25]:
rows = []
columns = df.columns.tolist()

for group_name, group_df in df.groupby(columns):
    count = group_df.shape[0]
    new_row = {
        columns[0]: group_name[0],
        columns[1]: group_name[1],
        columns[2]: group_name[2],
        'count': count,
    }
    rows.append(new_row)

merged_df = pd.DataFrame(rows)

In [26]:
def convert_to_ancestor_graph(G):
    '''Converts a (parent) tree to a graph with edges for all ancestor relations in the tree.'''
    G_anc = nx.DiGraph()
    for node in G.nodes():
        for anc in nx.ancestors(G, node):
            G_anc.add_edge(anc, node)
    return G_anc

import networkx as nx

forest=[]
for g in tqdm(list(set(merged_df.group))):
    df_tree=merged_df[merged_df.group==g]
    graph = nx.DiGraph()
    parents=df_tree['parent'].tolist()
    children=df_tree['child'].tolist()
    weights=df_tree['count'].tolist()
    nodes = set(parents + children)
    for node in nodes:
        graph.add_node(node)
    for i in range(len(parents)):
        graph.add_edge(parents[i], children[i], weight=weights[i])

    T = nx.maximum_branching(graph)

    T = convert_to_ancestor_graph(T)
    df=nx.to_pandas_edgelist(T)
    df['group']=g
    forest.append(df)


100%|██████████| 114/114 [00:00<00:00, 514.63it/s]


In [27]:
res_v2=pd.concat(forest, ignore_index=True)

In [28]:
res_v2.columns=['parent','child','group']
res_v2['compare']=res_v2['parent']+res_v2['child']+res_v2['group'].astype(str)
res_v2['group'].value_counts()

666    130
697    121
740    118
677    111
729    103
      ... 
655     13
647     12
682     11
744     10
737     10
Name: group, Length: 114, dtype: int64

In [29]:
df_actual=pd.read_csv(actual_tree)
df_actual=df_actual[df_actual['group']>= 647]
df_actual['compare']=df_actual['parent']+df_actual['child']+df_actual['group'].astype(str)

In [30]:
forest=[]
for g in tqdm(list(set(df_actual.group))):
    df_tree=df_actual[df_actual.group==g]
    graph = nx.DiGraph()
    parents=df_tree['parent'].tolist()
    children=df_tree['child'].tolist()
    nodes = set(parents + children)
    for node in nodes:
        graph.add_node(node)
    for i in range(len(parents)):
        graph.add_edge(parents[i], children[i])
    # print(nx.is_connected(graph))
    T = convert_to_ancestor_graph(graph)
    df=nx.to_pandas_edgelist(T)
    df['group']=g
    forest.append(df)

df_actual =pd.concat(forest, ignore_index=True)
df_actual.columns=['parent','child','group']
df_actual['compare']=df_actual['parent']+df_actual['child']+df_actual['group'].astype(str)
df_actual['group'].value_counts()

100%|██████████| 114/114 [00:00<00:00, 1285.84it/s]


677    87
666    87
758    78
741    76
691    76
       ..
746    15
682    14
647    14
737    13
661    13
Name: group, Length: 114, dtype: int64

In [31]:
df_actual[df_actual.group==647]

Unnamed: 0,parent,child,group,compare
0,explosion,airburst,647,explosionairburst647
1,explosion,blast,647,explosionblast647
2,explosion,bomb_blast,647,explosionbomb_blast647
3,explosion,backblast,647,explosionbackblast647
4,explosion,backfire,647,explosionbackfire647
5,explosion,fragmentation,647,explosionfragmentation647
6,explosion,big_bang,647,explosionbig_bang647
7,explosion,inflation,647,explosioninflation647
8,explosion,nuclear_explosion,647,explosionnuclear_explosion647
9,explosion,blowback,647,explosionblowback647


In [32]:
res_v2[res_v2.group==647]

Unnamed: 0,parent,child,group,compare
0,blast,airburst,647,blastairburst647
1,blast,explosion,647,blastexplosion647
2,blast,bomb_blast,647,blastbomb_blast647
3,blast,backblast,647,blastbackblast647
4,blast,fragmentation,647,blastfragmentation647
5,blast,backfire,647,blastbackfire647
6,blast,big_bang,647,blastbig_bang647
7,blast,inflation,647,blastinflation647
8,blast,nuclear_explosion,647,blastnuclear_explosion647
9,blast,blowback,647,blastblowback647


In [33]:
recall = []
precision = []
f1 = []
for group in tqdm(list(set(df_actual.group))):
    group_actual = df_actual[df_actual.group == group]
    group_pred = res_v2[res_v2.group == group]
    recall.append(len(group_actual.merge(group_pred, on='compare')) / len(group_actual))
    precision.append(len(group_actual.merge(group_pred, on='compare')) / len(group_pred))
    if precision[-1] + recall[-1] == 0:
        f1.append(0)
    else:
        f1.append(2 * (precision[-1] * recall[-1]) / (precision[-1] + recall[-1]))

100%|██████████| 114/114 [00:00<00:00, 440.23it/s]


In [34]:
import numpy as np
np.mean(recall)

0.6679041284897339

In [35]:
np.mean(precision)

0.6326333624569183

In [36]:
np.mean(f1)

0.6337440037577927