# Imputation for different datasets

In [None]:
import openai
import pandas as pd
import time
import requests
import json
import jsonlines
import re

import os

# Prefer the key from the environment so a real key never has to be written
# into (and committed with) the notebook; the placeholder stays as fallback.
sk = os.environ.get("OPENAI_API_KEY", "YOUR API KEY")
openai.api_key = sk

def chat(input_data, model="gpt-3.5-turbo", temperature=0.8, retry_wait=10, max_retries=None):
    """Send one user prompt to the OpenAI chat API and return the reply text.

    Args:
        input_data: Prompt text, sent as a single user message.
        model: Chat model name.
        temperature: Sampling temperature for the request.
        retry_wait: Seconds to sleep after a failed request before retrying.
        max_retries: Maximum number of retries after the first attempt;
            ``None`` (the default) retries forever, as the original loop did.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        The last request error once ``max_retries`` retries have all failed.
    """
    # temperature is a request-level parameter. Placing it inside the message
    # dict (as before) meant it was never applied as a sampling setting (and
    # the API may reject the unknown message field, which the old loop would
    # then retry forever).
    messages = [{"role": "user", "content": input_data}]

    attempt = 0
    while True:
        try:
            response = openai.ChatCompletion.create(
                model=model,
                messages=messages,
                temperature=temperature,
            )
            return response['choices'][0]['message']['content']
        except Exception as err:
            # `except Exception` rather than a bare `except:` so that
            # KeyboardInterrupt can still stop the retry loop.
            if max_retries is not None and attempt >= max_retries:
                raise
            attempt += 1
            print(f"Request failed ({err!r}); retrying in {retry_wait}s")
            time.sleep(retry_wait)

## Obtain all possible answers

In [None]:
# Map query id -> serialized query text from the tab-separated queries file.
queries = {}
with open('../data/wikituples/final_data/queries.tsv', 'r', encoding='utf-8') as f:
    for line in f:
        # Split at the first tab only. The former find()-based slicing turned a
        # line without a tab into an id missing its last character.
        qid, _, query = line.strip().partition('\t')
        queries[qid] = query

print(len(queries))

In [None]:
def load_entity_vocab(ignore_bad_title=True, min_ent_count=1, path='../data/wikituples/entity_vocab.txt'):
    """Load the entity vocabulary, dropping untitled and rare entities.

    Each non-blank line of the file has five tab-separated fields:
    index, wiki_id, title, mid, count (index and mid are not used).

    Args:
        ignore_bad_title: Drop entities whose title is empty.
        min_ent_count: Drop entities whose count is below this value.
        path: Vocabulary file to read; defaults to the WikiTuples vocab.

    Returns:
        Dict mapping a dense 0-based index to
        ``{'wiki_id': int, 'wiki_title': str, 'count': str}``. ``count`` is
        kept as the raw string read from the file.
    """
    entity_vocab = {}
    bad_title = 0
    few_entity = 0
    with open(path, 'r', encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue  # tolerate blank (e.g. trailing) lines instead of crashing the unpack
            _, entity_id, entity_title, _mid, count = stripped.split('\t')
            if ignore_bad_title and entity_title == '':
                bad_title += 1
            elif int(count) < min_ent_count:
                few_entity += 1
            else:
                entity_vocab[len(entity_vocab)] = {
                    'wiki_id': int(entity_id),
                    'wiki_title': entity_title,
                    'count': count,
                }
    print('total number of entity: %d\nremove because of empty title: %d\nremove because count<%d: %d'%(len(entity_vocab),bad_title,min_ent_count,few_entity))
    return entity_vocab

In [None]:
import numpy as np 
import jsonlines

entity_vocab = load_entity_vocab(min_ent_count=2, ignore_bad_title=True)
all_entity_set = {item['wiki_id'] for item in entity_vocab.values()}

# wiki_id -> every text used for that entity: its Wikipedia title first, then
# the surface text of each linked cell in the training tables. The evaluation
# cell accepts any of these texts as a correct answer for an entity cell.
entityid_to_text = {item['wiki_id']: [item['wiki_title']] for item in entity_vocab.values()}

train_tuple_to_table = {}
train_tuple_id = 0
with jsonlines.open('../data/wikituples/train_tables.jsonl', 'r') as f:
    for table in f:
        # Only the cells' surface links are needed. The unused per-table fields
        # (titles, headers, entityCell array, subject_column) are no longer
        # extracted, so a table lacking 'subject_column' no longer aborts the cell.
        for row in table.get("tableData", {}):
            for cell in row:
                links = cell['surfaceLinks']
                if not links:
                    continue
                entity_id = links[0]['target']['id']
                if entity_id in all_entity_set:
                    entityid_to_text[entity_id].append(cell['text'])

In [None]:
# Tuple ids already answered by an earlier run of the prompting cell, so an
# interrupted run can resume without re-querying the API.
processed_tuples = []
try:
    with open('./results/WikiTuples/GPT_wikituples_wo_evidence.jsonl', 'r', encoding='utf-8') as f:
        for raw in f:
            if not raw.strip():
                continue  # tolerate blank / trailing lines
            processed_tuples.append(int(json.loads(raw)['tuple_id']))
except FileNotFoundError:
    # First run: no results file has been written yet, so nothing is done.
    pass

print(len(processed_tuples))

In [None]:
import json
import jsonlines
import random

template = '''What's the most likely value for the [TO-FILL] cell in the table below? Please respond using JSON: {answer_format}, the key is attribute name of each [TO-FILL], value is the predicted value for each [TO-FILL].\n'''
tuples = {}

# Ids answered in an earlier run; a set makes each per-tuple skip check O(1).
already_done = set(processed_tuples)

count, acc = 0,0
with open('../data/wikituples/missing_tables.jsonl', 'r', encoding='utf-8') as f:
    for raw in f:
        table = json.loads(raw)
        tuple_ids = table['tuple_id']
        tableData = table['tableData']
        headers = table['processed_tableHeaders']
        ground_truth = table['ground_truth']
        caption = 'caption:' + table['pgTitle'] + ' | ' + table['sectionTitle'] + ' | ' + table['tableCaption']
        header_row = ''.join('|' + header for header in headers) + '|\n'

        for index, t_id in enumerate(tuple_ids):
            if t_id in already_done:
                continue

            # One-shot demonstration: the ground-truth row of another tuple
            # from the same table, chosen at random.
            example_row = ''
            other_ids = [tt for tt in tuple_ids if tt != t_id]
            if not other_ids:
                print("no other examples")
            else:
                example_tuple = ground_truth[tuple_ids.index(random.choice(other_ids))]
                try:
                    # List-valued cells (presumably entity cells) contribute
                    # their last element as the display text.
                    example_row = ''.join(
                        '|' + (example_tuple[j][-1] if isinstance(example_tuple[j], list) else example_tuple[j])
                        for j in range(len(headers))
                    ) + '|\n'
                except (IndexError, TypeError):
                    # Malformed example: drop it entirely instead of leaving a
                    # half-written row glued to the target row.
                    print("no other examples")

            tuple_d = tableData[index]
            target_row = ''.join(
                '|' + ('[TO-FILL]' if tuple_d[j] == 'N/A' else tuple_d[j])
                for j in range(len(headers))
            ) + '|\n'
            # join() also yields a well-formed "{}" when nothing is missing
            # (the old slice-off-two-chars trick produced just "}").
            answer_format = '{' + ', '.join(
                headers[j] + ': ""' for j in range(len(headers)) if tuple_d[j] == 'N/A'
            ) + '}'

            # Only the template is treated as a format string. Caption and cell
            # text are appended verbatim, so literal braces in the data can no
            # longer break str.format (previously patched with replace() hacks).
            input_data = (template.format(answer_format=answer_format)
                          + caption + '\n' + header_row + example_row + target_row)

            print("---------------------------------------------------")
            print(f"Input: \n{input_data}")

            output = chat(input_data, model="gpt-3.5-turbo", temperature=0.3)

            print(f"Output: \n{output}")

            # Append right away so an interrupted run can resume from this file.
            with jsonlines.open('./results/WikiTuples/GPT_wikituples_wo_evidence.jsonl', 'a') as fout:
                fout.write({'tuple_id':t_id, 'input': input_data, 'output': output})
                

### Data Imputation with Retrieved Data

In [None]:
from collections import defaultdict

retrieved_tuples = {} # qid, top-k docids

# qid -> {docid: rerank score}
all_scores = defaultdict(dict)
with open('../results/rerank/wikituples_test.tsv', 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank / trailing lines
        # Columns: qid, docid, rank, score. The file's rank column is ignored;
        # ranking is recomputed from the scores below.
        qid, docid, _rank, score = line.strip().split('\t')
        all_scores[qid][docid] = float(score)

qq = list(all_scores.keys())

# Number of highest-scoring retrieved tuples kept per query.
TOP_K = 5

# qid -> docids of its TOP_K best-scoring tuples, best first (ties keep file order).
topK_results = {
    qid: [docid for docid, _ in sorted(all_scores[qid].items(), key=lambda x: x[1], reverse=True)[:TOP_K]]
    for qid in qq
}

# Relative path, consistent with the other data paths in this notebook
# (previously a machine-specific absolute path).
with open('../data/wikituples/final_data/folds.json', 'r') as f:
    folds = json.load(f)
    test_qids = folds['test']


In [None]:
# Map docid -> serialized tuple text of the retrieval collection.
collection = {}
# Relative path, consistent with the other data paths in this notebook
# (previously a machine-specific absolute path).
with open('../data/wikituples/final_data/collection.tsv', 'r', encoding='utf-8') as f:
    for line in f:
        # Split at the first tab only; a line without a tab keeps its full id.
        docid, _, passage = line.strip().partition('\t')
        collection[docid] = passage

In [None]:
def convert_to_table(tuple_id, serialized_tuple, truncate_id_range=(482835, 482849), max_header_len=10):
    """Render a serialized retrieved tuple as a small pipe-delimited table.

    The serialized text is parsed as
    ``"<prefix>]: <title> attribute <name> value <value> attribute ..."``
    (format inferred from this parser; confirm against the collection file).

    Args:
        tuple_id: Tuple id (int or numeric string).
        serialized_tuple: The serialized tuple text.
        truncate_id_range: Inclusive ``(low, high)`` id range whose attribute
            names are cut to ``max_header_len`` characters.
        max_header_len: Maximum attribute-name length inside that range.

    Returns:
        ``'caption: <title>\\n|<h1> | <h2> |\\n|<v1> | <v2> |'``
    """
    # Split off the title segment, then one chunk per attribute.
    head, *attributes = serialized_tuple.split(' attribute ')
    # Split the title at the first ']: ' only, so titles containing it stay
    # whole. Fall back to the whole head if the marker is missing.
    _, sep, title = head.partition(']: ')
    title = (title if sep else head).strip()

    low, high = truncate_id_range
    # NOTE(review): hard-coded workaround for tuples in this id range
    # (presumably very long attribute names); confirm why it is needed.
    truncate = low <= int(tuple_id) <= high

    headers = []
    values = []
    for attribute in attributes:
        # Split at the first ' value ' only, so a value that itself contains
        # ' value ' is kept intact (split() used to truncate it).
        name, _, value = attribute.partition(' value ')
        name = name.strip()
        if truncate:
            name = name[:max_header_len]
        headers.append(name)
        values.append(value.strip())

    # Build the table: caption line, header row, value row.
    return 'caption: ' + title + '\n|' + ' | '.join(headers) + ' |\n|' + ' | '.join(values) + ' |'


In [None]:
# Tuple ids already answered by an earlier run of the retrieval prompting
# cell, so an interrupted run can resume without re-querying the API.
processed_tuples = []
try:
    with open('./results/WikiTuples/GPT4_wikituples_with_retrieved_tuples_by_monoT5.jsonl', 'r', encoding='utf-8') as f:
        for raw in f:
            if not raw.strip():
                continue  # tolerate blank / trailing lines
            processed_tuples.append(int(json.loads(raw)['tuple_id']))
except FileNotFoundError:
    # First run: no results file has been written yet, so nothing is done.
    pass

print(len(processed_tuples))

In [None]:
import json
import jsonlines
import random

template = '''Based on the retrieved tabular data, what's the most likely value for the [TO-FILL] cell in the table below? Please respond using JSON: {answer_format}, the key is attribute name of each [TO-FILL], value is the predicted value for each [TO-FILL].\n'''
tuples = {}

# Sets make both per-tuple skip checks O(1) instead of list scans.
already_done = set(processed_tuples)
eval_ids = set(test_qids)

count, acc = 0,0
with open('../data/wikituples/missing_tables.jsonl', 'r', encoding='utf-8') as f:
    for raw in f:
        table = json.loads(raw)
        tuple_ids = table['tuple_id']
        tableData = table['tableData']
        headers = table['processed_tableHeaders']
        ground_truth = table['ground_truth']
        caption = 'caption:' + table['pgTitle'] + ' | ' + table['sectionTitle'] + ' | ' + table['tableCaption']
        header_row = ''.join('|' + header for header in headers) + '|\n'

        for index, t_id in enumerate(tuple_ids):
            # Skip tuples answered in an earlier run and tuples outside the test fold.
            if t_id in already_done or t_id not in eval_ids:
                continue

            tuple_d = tableData[index]
            target_row = ''.join(
                '|' + ('[TO-FILL]' if tuple_d[j] == 'N/A' else tuple_d[j])
                for j in range(len(headers))
            ) + '|\n'
            # join() also yields a well-formed "{}" when nothing is missing
            # (the old slice-off-two-chars trick produced just "}").
            answer_format = '{' + ', '.join(
                headers[j] + ': ""' for j in range(len(headers)) if tuple_d[j] == 'N/A'
            ) + '}'

            # Only the template is treated as a format string. Caption and cell
            # text are appended verbatim, so literal braces in the data can no
            # longer break str.format (previously patched with replace() hacks).
            input_data = (template.format(answer_format=answer_format)
                          + caption + '\n' + header_row + target_row)

            # Evidence: the top-ranked retrieved tuples for this query, as tables.
            input_data += 'Retrieved Tables:\n'
            for rank, docid in enumerate(topK_results[str(t_id)], start=1):
                input_data += 'Table ' + str(rank) + ': ' + convert_to_table(docid, collection[docid]) + '\n\n'

            print("---------------------------------------------------")
            print(f"Input: \n{input_data}")

            output = chat(input_data, model="gpt-4", temperature=0.3)

            print(f"Output: \n{output}")

            # Append right away so an interrupted run can resume from this file.
            with jsonlines.open('./results/WikiTuples/GPT4_wikituples_with_retrieved_tuples_by_monoT5.jsonl', 'a') as fout:
                fout.write({'tuple_id':t_id, 'input': input_data, 'output': output})
                

In [None]:
import json
import jsonlines
import time
import ast

# tuple id -> stored model response record. A later line for the same id
# overrides an earlier one (the results file is append-only).
processed_tuples = {}
with open('./results/WikiTuples/GPT4_wikituples_with_retrieved_tuples_by_monoT5.jsonl', 'r', encoding='utf-8') as f:
    for raw in f:
        if not raw.strip():
            continue  # tolerate blank / trailing lines
        record = json.loads(raw)
        processed_tuples[record['tuple_id']] = record

# Test-fold tuple ids. Relative path, consistent with the other data paths in
# this notebook (previously a machine-specific absolute path).
with open('../data/wikituples/final_data/folds.json', 'r') as f:
    folds = json.load(f)
    test_qids = folds['test']

def process_answer(answer_set):
    """Normalize accepted answers for matching: lowercase and '_' -> ' '.

    Args:
        answer_set: An iterable of answer strings, or a single answer string.

    Returns:
        A list of normalized answer strings.
    """
    # A bare string would otherwise be iterated character by character,
    # turning "New_York" into ['n', 'e', 'w', ...].
    if isinstance(answer_set, str):
        answer_set = [answer_set]
    return [aa.lower().replace('_', ' ') for aa in answer_set]

def _parse_model_output(output):
    """Parse a model reply into an ``{attribute name: predicted value}`` dict.

    The prompt asks for JSON, so ``json.loads`` is tried first, then
    ``ast.literal_eval`` (for Python-literal replies such as single-quoted
    keys). A surrounding Markdown code fence is tolerated. A reply that does
    not parse to a dict yields ``{}``, so that tuple's cells are scored wrong
    instead of aborting the whole evaluation.
    """
    text = (output or '').strip()
    if text.startswith('```'):
        text = text.strip('`').strip()
        if text[:4].lower() == 'json':
            text = text[4:].strip()
    for parse in (json.loads, ast.literal_eval):
        try:
            parsed = parse(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            continue
        if isinstance(parsed, dict):
            return parsed
    return {}


def _answer_candidates(gold):
    """Return the normalized accepted answers for one ground-truth cell."""
    if isinstance(gold, list):
        # List-valued cell: gold[0] is looked up as an entity id, and any text
        # recorded for that entity counts as correct.
        texts = entityid_to_text.get(gold[0])
        if texts is None:
            # NOTE(review): entity absent from the vocab (e.g. dropped by
            # min_ent_count). Fall back to the cell's own text, assumed to be
            # the last list element as in the prompt-building cells; confirm.
            texts = [str(gold[-1])]
        answers = set(texts)
    else:
        # Wrap a plain answer: iterating a string would yield its characters.
        answers = [str(gold)]
    return process_answer(answers)


def _is_correct(prediction, answers):
    """True if the normalized, non-empty prediction is a substring of an answer."""
    if prediction is None:
        return False
    norm = str(prediction).lower().replace('_', ' ').strip()
    # '' is a substring of every string, so without this guard the empty ""
    # placeholder from the answer format would always be scored correct.
    return bool(norm) and any(norm in answer for answer in answers)


# Set membership instead of scanning the test-id list for every tuple.
eval_ids = set(test_qids)

count, acc = 0,0   # missing cells scored / cells predicted correctly
n_tuples = 0       # tuples actually evaluated
with open('../data/wikituples/missing_tables.jsonl', 'r', encoding='utf-8') as f:
    for raw in f:
        table = json.loads(raw)
        headers = table['processed_tableHeaders']

        for index, t_id in enumerate(table['tuple_id']):
            # Only score test-fold tuples that have a stored model response.
            if t_id not in processed_tuples or t_id not in eval_ids:
                continue
            n_tuples += 1

            tuple_d = table['tableData'][index]
            gold_row = table['ground_truth'][index]
            missing_pos = [j for j in range(len(headers)) if tuple_d[j] == 'N/A']
            count += len(missing_pos)

            record = processed_tuples[t_id]
            predictions = {}  # column position -> predicted value (scalar or list)
            if 'imputed_data' in record:
                # Pre-extracted predictions, indexed by column position.
                # NOTE(review): if this is a JSON object its keys are strings,
                # so int positions would never match; confirm the stored shape.
                stored = record['imputed_data']
                for pos in missing_pos:
                    try:
                        predictions[pos] = stored[pos]
                    except (KeyError, IndexError, TypeError):
                        pass  # no prediction for this cell -> counted wrong
            else:
                parsed = _parse_model_output(record['output'])
                values = list(parsed.values())
                for mm, pos in enumerate(missing_pos):
                    # Match by attribute name first, then fall back to reply order.
                    if headers[pos] in parsed:
                        predictions[pos] = parsed[headers[pos]]
                    elif mm < len(values):
                        predictions[pos] = values[mm]

            for pos, prediction in predictions.items():
                answers = _answer_candidates(gold_row[pos])
                guesses = prediction if isinstance(prediction, list) else [prediction]
                # A cell counts at most once, however many listed guesses match
                # (previously every matching list element added 1).
                if any(_is_correct(g, answers) for g in guesses):
                    acc += 1

if count:
    accuracy = round(acc/count, 3)
    # Report the number of tuples actually scored, not the size of the test fold.
    print(f"Imputed Accuracy: {accuracy} on {n_tuples} tuples")
else:
    print("No missing cells to evaluate.")
