## 用 ChatGPT 生成的 Candidate Set 找相關 relevant node 並存成 Triplet format

In [1]:
import torch
import networkx as nx
import itertools
import json
from tqdm import tqdm
from utils.conceptnet import merged_relations
import numpy as np
from scipy import sparse
import pickle
from scipy.sparse import csr_matrix, coo_matrix
from multiprocessing import Pool
from collections import OrderedDict


# from .maths import *

In [2]:
candidate_path = './data/sciq/candidate_set/generative_lm/train.json'
grounded_path = './data/sciq/grounded/train.grounded.json'
cpnet_graph_path = './data/cpnet/conceptnet.en.pruned.graph'
cpnet_vocab_path = './data/cpnet/concept.txt'

In [3]:
__all__ = ['generate_graph']

concept2id = None
id2concept = None
relation2id = None
id2relation = None

cpnet = None
cpnet_all = None
cpnet_simple = None

In [4]:
any(x is None for x in [concept2id, id2concept, relation2id, id2relation])

True

### 載入 relation 與 concept token 的字典， relation2id 為 relation 的 id ; concept2id 為 concept 的 id

In [5]:
def load_resources(cpnet_vocab_path):
    global concept2id, id2concept, relation2id, id2relation

    with open(cpnet_vocab_path, "r", encoding="utf8") as fin:
        id2concept = [w.strip() for w in fin]
    concept2id = {w: i for i, w in enumerate(id2concept)}

    id2relation = merged_relations
    relation2id = {r: i for i, r in enumerate(id2relation)}

In [6]:
# 載入 concept 的 graph
def load_cpnet(cpnet_graph_path):
    global cpnet, cpnet_simple
    cpnet = nx.read_gpickle(cpnet_graph_path)
    cpnet_simple = nx.Graph()
    for u, v, data in cpnet.edges(data=True):
        w = data['weight'] if 'weight' in data else 1.0
        if cpnet_simple.has_edge(u, v):
            cpnet_simple[u][v]['weight'] += w
        else:
            cpnet_simple.add_edge(u, v, weight=w)

In [7]:
print(f'generating adj data for {grounded_path}...')

global concept2id, id2concept, relation2id, id2relation, cpnet_simple, cpnet
# 載入 concept2id, id2relation, relation2id
if any(x is None for x in [concept2id, id2concept, relation2id, id2relation]):
    load_resources(cpnet_vocab_path)
if cpnet is None or cpnet_simple is None:
    load_cpnet(cpnet_graph_path)

generating adj data for ./data/sciq/grounded/train.grounded.json...


### 查看 concept2id 內的東西

In [8]:
# dict of conceptnet token
# key = token name, value = id
concept2id

{'ab_extra': 0,
 'ab_intra': 1,
 'abactinal': 2,
 'actinal': 3,
 'abandon': 4,
 'acquire': 5,
 'arrogate': 6,
 'embrace': 7,
 'engage': 8,
 'gain': 9,
 'join': 10,
 'maintain': 11,
 'retain': 12,
 'unite': 13,
 'abandonment': 14,
 'acquisition': 15,
 'abapical': 16,
 'apical': 17,
 'abase': 18,
 'exalt': 19,
 'extoll': 20,
 'abash': 21,
 'embolden': 22,
 'reassure': 23,
 'abate': 24,
 'augment': 25,
 'abaxial': 26,
 'adaxial': 27,
 'abbreviate': 28,
 'lengthen': 29,
 'abderian': 30,
 'agelastic': 31,
 'abducent': 32,
 'adducent': 33,
 'abduction': 34,
 'adduction': 35,
 'abductive': 36,
 'deduce': 37,
 'abductor': 38,
 'abductee': 39,
 'adductor': 40,
 'abideable': 41,
 'insupportable': 42,
 'intolerable': 43,
 'unabideable': 44,
 'unbearable': 45,
 'abience': 46,
 'adience': 47,
 'abient': 48,
 'adient': 49,
 'ability': 50,
 'inability': 51,
 'abiogenesis': 52,
 'biogenesis': 53,
 'transformism': 54,
 'abjectly': 55,
 'proudly': 56,
 'abjugate': 57,
 'adjugate': 58,
 'able': 59,
 'can

In [9]:
concept2id['ab_extra']

0

In [10]:
id2concept[0]

'ab_extra'

### 查看 relation2id 內的東西

In [11]:
# dict of relation
# key = token name, value = id
relation2id

{'antonym': 0,
 'atlocation': 1,
 'capableof': 2,
 'causes': 3,
 'createdby': 4,
 'isa': 5,
 'desires': 6,
 'hassubevent': 7,
 'partof': 8,
 'hascontext': 9,
 'hasproperty': 10,
 'madeof': 11,
 'notcapableof': 12,
 'notdesires': 13,
 'receivesaction': 14,
 'relatedto': 15,
 'usedfor': 16}

In [12]:
id2relation

['antonym',
 'atlocation',
 'capableof',
 'causes',
 'createdby',
 'isa',
 'desires',
 'hassubevent',
 'partof',
 'hascontext',
 'hasproperty',
 'madeof',
 'notcapableof',
 'notdesires',
 'receivesaction',
 'relatedto',
 'usedfor']

## 將預處理的資料 整理成 QA Data Format
## Data format 資料格式
qa_data 格式 
- data : <tuple>
    - q_ids : <set> : q_ids 所有在 question 句子中的 phase verb noun 的 id
    - a_ids : <set> : a_ids 所有在 answer 句子中的 phase verb noun 的 id
    - QAcontext : <str> : 綜合 Question 以及 Answer 的文字，用空白互相格開

In [13]:
qa_data = []
with open(grounded_path, 'r', encoding='utf-8') as fin_ground:
    lines_ground = fin_ground.readlines()
    for j, line in enumerate(lines_ground):
        dic = json.loads(line)
        q_ids = set(concept2id[c] for c in dic['qc'])
        a_ids = set(concept2id[c] for c in dic['ac'])
        q_ids = q_ids - a_ids
        sentence = dic['sent']
        QAcontext = "{}.[SEP] {}.".format(sentence, dic['ans'])
        qa_data.append((q_ids, a_ids, QAcontext))

In [14]:
print('總共有',len(qa_data),'筆')

總共有 11679 筆


In [15]:
qa_data[0]

({2193, 3151, 6460, 10998, 15816, 22224, 48164, 49561, 69415},
 {25657, 44853, 575771},
 'What type of organism is commonly used in preparation of foods such as cheese and yogurt? mesophilic organisms.[SEP] mesophilic organisms.')

### 載入 ChatGPT Candidate Set 

In [16]:
def read_candidate_data(path):
    with open(path) as f:
        data = json.load(f)
    return data

In [19]:
candidate_data = read_candidate_data(candidate_path)
print('總共有 {} 筆'.format(len(candidate_data)))

總共有 11679 筆


In [21]:
candidate_data[0]

{'question': 'What type of organism is commonly used in preparation of foods such as cheese and yogurt?',
 'distractor3': 'viruses',
 'distractor1': 'protozoa',
 'distractor2': 'gymnosperms',
 'correct_answer': 'mesophilic organisms',
 'support': 'Mesophiles grow best in moderate temperature, typically between 25°C and 40°C (77°F and 104°F). Mesophiles are often found living in or on the bodies of humans or other animals. The optimal growth temperature of many pathogenic mesophiles is 37°C (98°F), the normal human body temperature. Mesophilic organisms have important uses in food preparation, including cheese, yogurt, beer and wine.',
 'candidates': '1. Thermophilic organisms\n2. Halophilic organisms\n3. Psychrophilic organisms\n4. Acidophilic organisms\n5. Aerobic organisms\n6. Anaerobic organisms\n7. Hyperthermophilic organisms \n8. Thermotolerant organisms\n9. Xerophilic organisms\n10. Alkaliphilic organisms',
 'candidate_set': ['Thermophilic organisms',
  'Halophilic organisms',
  

In [22]:
print('ChatGPT')
print('question in first data = ',candidate_data[0]['question'])
print('answer in first data = ',candidate_data[0]['correct_answer'])
print('distractors in first data = ',[candidate_data[0]['distractor1'], candidate_data[0]['distractor2'], candidate_data[0]['distractor3']])
print('candidate_set in first data = ',candidate_data[0]['candidate_set'])

ChatGPT
question in first data =  What type of organism is commonly used in preparation of foods such as cheese and yogurt?
answer in first data =  mesophilic organisms
distractors in first data =  ['protozoa', 'gymnosperms', 'viruses']
candidate_set in first data =  ['Thermophilic organisms', 'Halophilic organisms', 'Psychrophilic organisms', 'Acidophilic organisms', 'Aerobic organisms', 'Anaerobic organisms', 'Hyperthermophilic organisms ', 'Thermotolerant organisms', 'Xerophilic organisms', 'Alkaliphilic organisms']


### 合併 ChatGPT Candidate Set 並用來 retrieve 其他的 relevant node 

In [23]:
len(qa_data), len(candidate_data)

(11679, 11679)

In [24]:
for i in range(len(qa_data)):
    candidate = candidate_data[i]['candidate_set']
    pred = []
    for words in candidate:
        words = str(words)
        for token in words.split(' '):
            if token != '' or token != ' ' and len(token) != 0:
                pred.append(token.lower())
    qa_data[i] = qa_data[i] + tuple(pred)

In [25]:
qa_data[0]

({2193, 3151, 6460, 10998, 15816, 22224, 48164, 49561, 69415},
 {25657, 44853, 575771},
 'What type of organism is commonly used in preparation of foods such as cheese and yogurt? mesophilic organisms.[SEP] mesophilic organisms.',
 'thermophilic',
 'organisms',
 'halophilic',
 'organisms',
 'psychrophilic',
 'organisms',
 'acidophilic',
 'organisms',
 'aerobic',
 'organisms',
 'anaerobic',
 'organisms',
 'hyperthermophilic',
 'organisms',
 'thermotolerant',
 'organisms',
 'xerophilic',
 'organisms',
 'alkaliphilic',
 'organisms')

In [26]:
len(qa_data)

11679

In [27]:
def concepts_to_adj_matrices_2hop_all_pair__use_LM__Part1(data):
    qc_ids = data[0]
    ac_ids = data[1]
    question = data[2]
    distractors_set = data[3:]
    extra_nodes = []
    for distractor in distractors_set:
        if distractor in concept2id:
            extra_nodes_ids = concept2id[distractor] 
            extra_nodes.append(extra_nodes_ids)
    return (sorted(qc_ids), sorted(ac_ids), question, extra_nodes)

In [28]:
res1 = list(tqdm(map(concepts_to_adj_matrices_2hop_all_pair__use_LM__Part1, qa_data), total=len(qa_data)))

100%|██████████| 11679/11679 [00:00<00:00, 103591.42it/s]


In [29]:
res1[0]

([2193, 3151, 6460, 10998, 15816, 22224, 48164, 49561, 69415],
 [25657, 44853, 575771],
 'What type of organism is commonly used in preparation of foods such as cheese and yogurt? mesophilic organisms.[SEP] mesophilic organisms.',
 [15142,
  44853,
  526492,
  44853,
  15143,
  44853,
  384236,
  44853,
  456,
  44853,
  457,
  44853,
  432513,
  44853,
  699098,
  44853,
  602953,
  44853,
  390854,
  44853])

## concepts_to_adj_matrices_2hop_all_pair__use_LM__Part3

In [30]:
# 考慮 全部的組合
def concepts2adj(node_ids):
    global id2relation
    cids = np.array(node_ids, dtype=np.int32)
    n_rel = len(id2relation)
    n_node = cids.shape[0]
    adj = np.zeros((n_rel, n_node, n_node), dtype=np.uint8)
    triplets = []
    for s in range(n_node):
        for t in range(n_node):
            s_c, t_c = cids[s], cids[t]
            if cpnet.has_edge(s_c, t_c):
                for e_attr in cpnet[s_c][t_c].values():
                    if e_attr['rel'] >= 0 and e_attr['rel'] < n_rel:
                        triplets.append([int(e_attr['rel']), int(s_c), int(t_c), e_attr['weight']])
                        adj[e_attr['rel']][s][t] = 1
    # cids += 1  # note!!! index 0 is reserved for padding
    adj = coo_matrix(adj.reshape(-1, n_node))
    return adj, cids, triplets

In [31]:
def concepts_to_adj_matrices_2hop_all_pair__use_LM__Part3(data):
    qc_ids, ac_ids, question, extra_nodes = data
    schema_graph = qc_ids + ac_ids + extra_nodes # <== 考慮全部的組合
    # schema_graph = qc_ids + extra_nodes # <== 考慮 qc_ids 跟 extra_nodes 的組合 (without_ans)
    adj, concepts, triplets = concepts2adj(schema_graph)
    return {'triplets': triplets}

In [32]:
res3 = list(tqdm(map(concepts_to_adj_matrices_2hop_all_pair__use_LM__Part3, res1), total=len(res1)))

100%|██████████| 11679/11679 [00:20<00:00, 575.13it/s]


In [33]:
res3[0]['triplets']

[[15, 2193, 69415, 1.0],
 [5, 3151, 2193, 6.325],
 [15, 3151, 2193, 0.926],
 [15, 3151, 22224, 1.0],
 [15, 15816, 10998, 2.0],
 [5, 22224, 2193, 0.5],
 [15, 48164, 2193, 0.151],
 [15, 69415, 2193, 1.0],
 [15, 44853, 25657, 1.0],
 [0, 15142, 15143, 1.0],
 [15, 44853, 25657, 1.0],
 [15, 44853, 25657, 1.0],
 [15, 44853, 25657, 1.0],
 [15, 44853, 25657, 1.0],
 [0, 456, 457, 2.0],
 [15, 44853, 25657, 1.0],
 [15, 457, 25657, 1.0],
 [0, 457, 456, 2.0],
 [15, 44853, 25657, 1.0],
 [15, 44853, 25657, 1.0],
 [15, 44853, 25657, 1.0],
 [15, 44853, 25657, 1.0],
 [15, 44853, 25657, 1.0]]

In [34]:
len(res3)

11679

In [35]:
res4 = []

for item in tqdm(res3):
    temp_list = []
    for triplets in item['triplets']:
        rel, source_node, target_node, weight = triplets
        relation = id2relation[rel]
        source = id2concept[source_node]
        target = id2concept[target_node]
        temp_list.append([relation, source, target, weight])
    res4.append(temp_list)

  0%|          | 0/11679 [00:00<?, ?it/s]

100%|██████████| 11679/11679 [00:00<00:00, 42533.87it/s]


In [36]:
len(res4)

11679

In [37]:
res4[0]

[['relatedto', 'food', 'foods', 1.0],
 ['isa', 'cheese', 'food', 6.325],
 ['relatedto', 'cheese', 'food', 0.926],
 ['relatedto', 'cheese', 'yogurt', 1.0],
 ['relatedto', 'used', 'use', 2.0],
 ['isa', 'yogurt', 'food', 0.5],
 ['relatedto', 'preparation', 'food', 0.151],
 ['relatedto', 'foods', 'food', 1.0],
 ['relatedto', 'organisms', 'organism', 1.0],
 ['antonym', 'thermophilic', 'psychrophilic', 1.0],
 ['relatedto', 'organisms', 'organism', 1.0],
 ['relatedto', 'organisms', 'organism', 1.0],
 ['relatedto', 'organisms', 'organism', 1.0],
 ['relatedto', 'organisms', 'organism', 1.0],
 ['antonym', 'aerobic', 'anaerobic', 2.0],
 ['relatedto', 'organisms', 'organism', 1.0],
 ['relatedto', 'anaerobic', 'organism', 1.0],
 ['antonym', 'anaerobic', 'aerobic', 2.0],
 ['relatedto', 'organisms', 'organism', 1.0],
 ['relatedto', 'organisms', 'organism', 1.0],
 ['relatedto', 'organisms', 'organism', 1.0],
 ['relatedto', 'organisms', 'organism', 1.0],
 ['relatedto', 'organisms', 'organism', 1.0]]

Remove Duplicate Triplet

In [38]:
for i in range(len(res4)):
    res4[i] = [list(t) for t in set(tuple(element) for element in res4[i])]

In [39]:
res4[0]

[['relatedto', 'preparation', 'food', 0.151],
 ['relatedto', 'food', 'foods', 1.0],
 ['antonym', 'anaerobic', 'aerobic', 2.0],
 ['antonym', 'thermophilic', 'psychrophilic', 1.0],
 ['relatedto', 'foods', 'food', 1.0],
 ['relatedto', 'used', 'use', 2.0],
 ['relatedto', 'cheese', 'food', 0.926],
 ['isa', 'cheese', 'food', 6.325],
 ['isa', 'yogurt', 'food', 0.5],
 ['relatedto', 'cheese', 'yogurt', 1.0],
 ['antonym', 'aerobic', 'anaerobic', 2.0],
 ['relatedto', 'organisms', 'organism', 1.0],
 ['relatedto', 'anaerobic', 'organism', 1.0]]

In [40]:
output_path = './data/sciq/triplets/generative_lm/train.triplet.json'

In [41]:
with open(output_path, 'w') as fout:
    json.dump(res4, fout)
print(f'data saved to {output_path}')

data saved to ./data/sciq/triplets/generative_lm/train.triplet.json


In [42]:
rel, source_node, target_node, weight = res3[0]['triplets'][0]
print('relation = ',id2relation[rel])
print('source_node = ',id2concept[source_node])
print('target_node = ',id2concept[target_node])
print('weight = ',weight)

relation =  relatedto
source_node =  food
target_node =  foods
weight =  1.0
