# Setup

In [1]:
!pip install termcolor
!pip install openai

import os
import re
import csv
import time
import openai
import requests
import random
import concurrent.futures


from termcolor import colored
from tqdm.notebook import tqdm_notebook



In [2]:
# Azure OpenAI client configuration.
# The endpoint and key are taken from the environment so that credentials do
# not have to be written into the notebook; both fall back to "" when unset.
openai.api_type = "azure"
openai.api_version = "2023-05-15"
openai.api_base = os.environ.get("OPENAI_API_BASE", "")
openai.api_key = os.environ.get("OPENAI_API_KEY", "")

# Function Definition

## OpenAI Function

In [3]:
def get_completion(myPrompt):
    """Send ``myPrompt`` to the ``gpt4`` engine and return the raw response.

    The request uses temperature 0.1, a 200-token cap, frequency and presence
    penalties of 0.2 each, and no stop sequence.
    """
    request_options = {
        "engine": "gpt4",
        "prompt": myPrompt,
        "temperature": 0.1,
        "max_tokens": 200,
        "frequency_penalty": 0.2,
        "presence_penalty": 0.2,
        "stop": None,
    }
    return openai.Completion.create(**request_options)

## LamAPI Functions

In [4]:
# HTTP session shared by the LamAPI helper functions defined below.
s = requests.Session()

# Headers sent with every LamAPI request: JSON in, JSON out.
headers = {
    "accept": "application/json",
    "Content-Type": "application/json"
}

# Base URL of the LamAPI service used by the helpers below.
url = "https://lamapi.inside.disco.unimib.it/"

# Use LAMAPI to retrieve some entities from wikidata that could be suitable for the cell entity
def lamapi_retrieval(cell):
    """Look up Wikidata candidate entities for a cell mention via LamAPI.

    Args:
        cell: Raw cell text; surrounding whitespace is stripped before lookup.

    Returns:
        The value stored in the JSON response under the lower-cased mention
        (the request asks for at most 3 candidates).

    Raises:
        KeyError: If the response has no entry for the lower-cased mention.
    """
    cell = cell.strip()
    params = {
        'token': "insideslab-lamapi-2022",
        'name': cell,
        'kg': "wikidata",
        'limit': 3
        }
    # Build the endpoint from the shared base URL, like the other LamAPI
    # helpers, so the service address is configured in a single place.
    response = s.get(url + "lookup/entity-retrieval", headers=headers, params=params)
    return response.json()[cell.lower()]

# Use LAMAPI to identify the column type
def lamapi_cta(column):
    """Send one column to LamAPI's column-analysis endpoint.

    ``column`` is posted as the only element of the ``json`` list in the
    request body; the decoded JSON response is returned unchanged.
    """
    endpoint = url + "sti/column-analysis"
    query = {'token': "insideslab-lamapi-2022"}
    body = {"json": [column]}
    reply = s.post(endpoint, headers=headers, params=query, json=body)
    return reply.json()

#Use LAMAPI to retrieve the entity label from the entity id in wikidata
def lamapi_entity_name(id):
    """Fetch the English-language label data of one Wikidata entity id.

    The id is posted as the only element of the ``json`` list to the
    ``entity/labels`` endpoint; the ``"wikidata"`` entry of the decoded
    response is returned.
    """
    endpoint = url + "entity/labels"
    query = {
        'token': "insideslab-lamapi-2022",
        'kg': "wikidata",
        'lang': 'en',
    }
    reply = s.post(endpoint, headers=headers, params=query, json={"json": [id]})
    decoded = reply.json()
    return decoded["wikidata"]

## Prompt Creation Functions 

### Prompt with pool

In [5]:
def create_cea_prompt(cea_table, cea_pool_list):
    """Build the cell-entity-annotation prompt that includes a candidate pool.

    The prompt is a fixed pseudo-code description of the matching procedure
    followed by an "Example usage" section into which ``cea_table`` and the
    comma-separated string form of every item of ``cea_pool_list`` are spliced.
    """
    pool_text = ", ".join(str(candidate) for candidate in cea_pool_list)
    instructions = """
    For each row r in table T:
        For each field c in row r:
            # Assume that the field contains an entity mention
            entity_mention = r[c]
            # Initialize a variable to track the best match
            best_match = null
            max_score = 0
            # Search for matches in the entity pool P
            For each entity p in pool P:
                # Assume the function calculate_similarity_score returns a value representing the likelihood that the entity p is semantically appropriate for the entity_mention
                score = calculate_similarity_score(entity_mention, p)
                # Update the best match if the score is higher than the current maximum
                If score > max_score:
                    max_score = score
                    best_match = p
            # Assign the best match to the mention in the entity in the table
            r[c] = best_match
    
    #Example usage:
    table = """
    pool_label = """
    pool = """
    closing = """
    Result = 
    """
    return instructions + cea_table + pool_label + pool_text + closing

In [6]:
# Take a table and return a iterable dictionary of its columns 
def create_columns(table):
    """Split a comma-separated table string into its columns.

    Args:
        table: Table text with one row per line and cells separated by commas.

    Returns:
        Dict mapping each 0-based column index to the list of cell strings
        found at that position, in row order.
    """
    rows = table.split("\n")
    # Tables are built with a newline after every row, so the split leaves one
    # empty trailing piece; dropping it avoids a phantom "" cell in column 0.
    if rows[-1] == "":
        rows.pop()
    columns = {}
    for row in rows:
        for index, cell in enumerate(row.split(",")):
            columns.setdefault(index, []).append(cell)
    return columns

In [7]:
# Randomly choose 2 of the ground-truth entities so that they can be appended to the pool
def random_choice(list, max_attempts=20):
    """Pick up to two random entity ids and pair each with its English label.

    Args:
        list: Sequence of entity ids. The name shadows the built-in and is
            kept only so existing callers keep working.
        max_attempts: Maximum number of random pairs to try before giving up
            when label lookups keep coming back empty.

    Returns:
        A tuple of ``"<id> <label>"`` strings: two entries normally, one entry
        when only a single id is available, and an empty tuple when there is
        nothing to pick or no lookup succeeds. Always returning a tuple keeps
        ``pool.extend(random_choice(...))`` correct; a bare string would be
        split into characters and ``0`` would raise ``TypeError``.
    """
    entities = list

    def describe(entity_id, response):
        # Entities without an English label get the placeholder "generic".
        label = response[entity_id]["labels"].get("en")
        if label is None:
            label = "generic"
        return entity_id + " " + label

    if not entities:
        return ()

    if len(entities) < 2:
        only_id = entities[0]
        response = lamapi_entity_name(only_id)
        if response == {}:
            return ()
        return (describe(only_id, response),)

    # Bounded retry: an unbounded loop would never end when the label service
    # keeps returning empty answers.
    for _ in range(max_attempts):
        index_1, index_2 = random.sample(range(len(entities)), 2)
        response1 = lamapi_entity_name(entities[index_1])
        if response1 == {}:
            continue
        response2 = lamapi_entity_name(entities[index_2])
        if response2 == {}:
            continue
        return (
            describe(entities[index_1], response1),
            describe(entities[index_2], response2),
        )
    return ()

In [8]:
# Create a pool of suitable entities for the table 
def create_pool(table, table_id):
    """Collect candidate ``"<id> <name>"`` entries for the cells of a table.

    Every column that LamAPI tags as ``"NE"`` has its cells looked up
    concurrently. For each looked-up cell the retrieved candidates are added
    to the pool, followed by randomly chosen ground-truth entities of
    ``table_id`` taken from ``gt_dict``.

    Args:
        table: Table text, one row per line, cells separated by commas.
        table_id: Key of the table in ``gt_dict``.

    Returns:
        List of ``"<id> <name>"`` strings; it may contain duplicates.
    """
    columns = create_columns(table)
    pool = []
    for column_cells in columns.values():
        if lamapi_cta(column_cells)["0"]["tag"] != "NE":
            continue
        # Look up only this column's cells. Accumulating cells across columns
        # would re-query every earlier column for each later one.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            candidates_per_cell = list(executor.map(lamapi_retrieval, column_cells))
        for candidates in candidates_per_cell:
            for candidate in candidates:
                pool.append(f"{candidate['id']} {candidate['name']}")
            extra = random_choice(list(gt_dict[table_id].values()))
            # random_choice may return one string, a falsy value, or a
            # sequence of strings; extend() on a string would add it
            # character by character.
            if isinstance(extra, str):
                pool.append(extra)
            elif extra:
                pool.extend(extra)
    return pool

### Prompt without pool

In [9]:
def create_cea_prompt_no_pool(cea_table):
    """Build the cell-entity-annotation prompt that has no candidate pool.

    The prompt is a fixed pseudo-code description in which candidates come
    from an assumed ``get_entities_from_wikidata`` call, followed by an
    "Example usage" section into which ``cea_table`` is spliced.
    """
    instructions = """
    T = …
    For each row r in table T:
        For each field c in row r:
            # Assume that the field contains an entity mention
            entity_mention = r[c]
    
            # Initialize a variable to track the best match
            best_match = null
            max_score = 0
    
            # assume the get_entities_from_wikidata return entities from Wikidata that can be suitable for the entity_mention
            candidate_entities = get_entities_from_wikidata(entity_mention)
    
            # Search for matches among candidate entities
            For each wikidata_entity in candidate_entities:
                # Calculate a similarity score between the mention in the table and the entity in Wikidata
                score = calculate_similarity_score(entity_mention, wikidata_entity)
    
                # Update the best match if the score is higher than the current maximum
                If score > max_score:
                    max_score = score
                    best_match = wikidata_entity
    
            # Assign the best match to the mention in the entity in the table
            r[c] = best_match
    
    #Example usage:
    table = """
    closing = """
    Result = 
    """
    return instructions + cea_table + closing

## Prompt Execution Function

In [10]:
# Execute each prompt from the prompt list
def execute_prompt(prompt_dict, max_min=120, pause=0.3):
    """Run every prompt through ``get_completion`` and collect the texts.

    Args:
        prompt_dict: Mapping of key -> prompt.
        max_min: Time budget in minutes; once it is exceeded the remaining
            prompts are skipped.
        pause: Seconds to sleep before each request.

    Returns:
        Dict mapping each processed key to the text of the first completion
        choice, or to ``{}`` when the request raised (callers can detect
        failures with ``value != {}``).
    """
    list_exception = []
    generation_dict = {}
    budget_seconds = max_min * 60
    start_time = time.time()
    for key, value in tqdm_notebook(prompt_dict.items()):
        if time.time() - start_time > budget_seconds:
            break
        # Placeholder that stays in place when the request fails.
        generation_dict[key] = {}
        time.sleep(pause)
        try:
            generation_dict[key] = get_completion(value)["choices"][0]["text"]
        except Exception as e:
            # Best effort: remember the error and carry on with the next prompt.
            list_exception.append(e)
    print(str(len(list_exception))+" Prompt Execution Error")
    print(list_exception)
    return generation_dict

## Parsing Functions

In [11]:
# Extract the row and column number with the entity associated
def entity_extractor(table_annotation, table_id):
    """Turn annotated rows into ``"table, row, column, Qid"`` strings.

    Rows are numbered from 1 and comma-separated fields from 0. The first
    ``Q<digits>`` match in a field is taken as its entity. When no field
    contains such a match, a single ``"<table_id>, No Wikidata Entities
    Found"`` entry is returned instead.
    """
    wikidata_id = re.compile(r'Q\d+')
    found = []
    row_num = 0
    for row in table_annotation:
        row_num += 1
        for column_num, field in enumerate(row.split(',')):
            hit = wikidata_id.search(field)
            if hit is None:
                continue
            found.append(f"{table_id}, {row_num}, {column_num}, {hit.group()}")
    if not found:
        found.append(f"{table_id}, No Wikidata Entities Found")
    return found

In [12]:
# Parse the gpt response to keep only the table annotations
def parser(raw_table_annotation):
    """Extract the annotated table rows from a raw completion text.

    Each line is stripped and its tabs are turned into commas. Blank lines
    before the first row are skipped; parsing stops at the first blank line
    after at least one row, or at a line consisting only of ``<|im_end|>``.
    A trailing ``<|im_end|>`` marker is removed from a row.

    Args:
        raw_table_annotation: Completion text, or ``{}`` for a failed prompt.

    Returns:
        List of cleaned row strings (empty for the ``{}`` placeholder).
    """
    if raw_table_annotation == {}:
        return []
    result = []
    for row in raw_table_annotation.split("\n"):
        row = row.strip().replace("\t", ",")
        if row == "":
            if result:
                # A blank line after the table marks its end.
                break
            # Leading blank lines must not end parsing, otherwise a completion
            # that starts with a newline yields no rows at all.
            continue
        if row == "<|im_end|>":
            break
        if row.endswith("<|im_end|>"):
            row = row.replace("<|im_end|>", "")
        result.append(row)
    return result

## CSV Writing Function

In [13]:
# Save the annotation into a csv file
def save_to_csv(file_name, list):
    """Append every element of ``list`` to ``file_name`` as a one-field CSV row."""
    with open(file_name, mode='a', newline='') as file_csv:
        csv.writer(
            file_csv, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL
        ).writerows([element] for element in list)

# GT_Table Dictionary Creation 

In [14]:
#Create the groundtruth dictionary 
cea_gt_path = [
    "./datasets/HardTablesR1/DataSets/HardTablesR1/Valid/gt/cea_gt.csv",
    "./datasets/HardTablesR2/DataSets/HardTablesR2/Valid/gt/cea_gt.csv",
    "./datasets/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round1_gt.csv",
    "./datasets/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round2_gt.csv",
    "./datasets/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round3_gt.csv",
    "./datasets/SemTab2020_Table_GT_Target/GT/CEA/CEA_Round4_gt.csv",
    "./datasets/WikidataTables2023R1/DataSets/Valid/gt/cea_gt.csv",
]

# Maps table id -> {"<row>_<column>": entity id with the Wikidata URL prefix removed}.
gt_dict = {}

for cea_gt in tqdm_notebook(cea_gt_path):
    with open(cea_gt) as csvfile:
        spamreader = csv.reader(csvfile, delimiter=',')
        # The first line of every file is skipped.
        # NOTE(review): confirm these files really start with a header line.
        next(spamreader)
        for row in spamreader:
            table_id, row_id, column_id, entity = row[0], row[1], row[2], row[3]
            table_cells = gt_dict.setdefault(table_id, {})
            table_cells[f"{row_id}_{column_id}"] = entity.replace("http://www.wikidata.org/entity/", "")

  0%|          | 0/7 [00:00<?, ?it/s]

# Table Prompts Creation

## Prompt with pool

In [15]:
#create the prompt with pool for each table
tables_path= [
    "./datasets/HardTablesR1/DataSets/HardTablesR1/Valid/tables",
    "./datasets/HardTablesR2/DataSets/HardTablesR2/Valid/tables",
    "./datasets/WikidataTables2023R1/DataSets/Valid/tables",
    "./datasets/SemTab2020_Table_GT_Target/Round1/tables",
    "./datasets/SemTab2020_Table_GT_Target/Round2/tables",
    "./datasets/SemTab2020_Table_GT_Target/Round3/tables",
    "./datasets/SemTab2020_Table_GT_Target/Round4/tables"
]

# Maps table id -> prompt text (with candidate pool).
prompt_dict = {}
# Time budget, in minutes, granted to each folder.
max_min = 20

for table_path in tqdm_notebook(tables_path):
    # The time budget restarts for every folder.
    start_time = time.time()
    for table in tqdm_notebook(os.listdir(table_path)):
        elapsed_time = time.time() - start_time
        if elapsed_time > (max_min * 60):  # convert minutes to seconds
            print(f"L'iterazione del percorso {table_path} ha superato {max_min} minuti.")
            break
        table_id = table.replace(".csv", "")
        # Only CSV tables that have ground-truth annotations are processed.
        if not (table.endswith('.csv') and gt_dict.get(table_id)):
            continue
        with open(table_path + '/' + table) as csvfile:
            spamreader = csv.reader(csvfile, delimiter=',')
            # Skip the header row; the default keeps an empty file from
            # raising StopIteration.
            next(spamreader, None)
            # Re-join every remaining row with commas, one row per line.
            table_format = "".join(",".join(row) + "\n" for row in spamreader)
        # The file is closed before the slow pool lookups start. De-duplicate
        # while keeping first-seen order so the prompt is reproducible for a
        # given pool.
        table_pool = list(dict.fromkeys(create_pool(table_format, table_id)))
        # Assign only once the prompt is complete, so an interrupted run never
        # leaves an empty placeholder behind to be sent as a prompt.
        prompt_dict[table_id] = create_cea_prompt(table_format, table_pool)
    print(table_path + "\n")

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/201 [00:00<?, ?it/s]

KeyboardInterrupt: 

## Prompt without pool 

In [None]:
#create the prompt without pool for each table
tables_path= [
    "./datasets/HardTablesR1/DataSets/HardTablesR1/Valid/tables",
    "./datasets/HardTablesR2/DataSets/HardTablesR2/Valid/tables",
    "./datasets/WikidataTables2023R1/DataSets/Valid/tables",
    "./datasets/SemTab2020_Table_GT_Target/Round1/tables",
    "./datasets/SemTab2020_Table_GT_Target/Round2/tables",
    "./datasets/SemTab2020_Table_GT_Target/Round3/tables",
    "./datasets/SemTab2020_Table_GT_Target/Round4/tables"
]

# Maps table id -> prompt text (without candidate pool).
prompt_dict_no_pool = {}
# Time budget, in minutes, granted to each folder.
max_min = 20

for table_path in tqdm_notebook(tables_path):
    # The time budget restarts for every folder.
    start_time = time.time()
    for table in tqdm_notebook(os.listdir(table_path)):
        elapsed_time = time.time() - start_time
        if elapsed_time > (max_min * 60):  # convert minutes to seconds
            print(f"L'iterazione del percorso {table_path} ha superato {max_min} minuti.")
            break
        table_id = table.replace(".csv", "")
        # Only CSV tables that have ground-truth annotations are processed.
        if not (table.endswith('.csv') and gt_dict.get(table_id)):
            continue
        with open(table_path + '/' + table) as csvfile:
            spamreader = csv.reader(csvfile, delimiter=',')
            # Skip the header row; the default keeps an empty file from
            # raising StopIteration.
            next(spamreader, None)
            # Re-join every remaining row with commas, one row per line.
            table_format = "".join(",".join(row) + "\n" for row in spamreader)
        # Assign only once the prompt is complete, so a failed read never
        # leaves an empty placeholder behind to be sent as a prompt.
        prompt_dict_no_pool[table_id] = create_cea_prompt_no_pool(table_format)
    print(table_path + "\n")

# Prompt with Pool Execution

In [None]:
# Run every "with pool" prompt through GPT; maps table id -> raw completion text
raw_annotations = execute_prompt(prompt_dict) 

## Parsing and Writing Annotations

In [None]:
# Parse each completion and append its cell annotations to output.csv.
for key, value in tqdm_notebook(raw_annotations.items()):
    # An empty dict marks a prompt without a completion; nothing to parse.
    if value == {}:
        continue
    parsed_result = parser(value)
    annotated_entities = entity_extractor(parsed_result, key)
    save_to_csv("output.csv", annotated_entities)

# Prompt without Pool Execution

In [None]:
# Run every "without pool" prompt through GPT; maps table id -> raw completion text
raw_annotations_no_pool = execute_prompt(prompt_dict_no_pool) 

In [None]:
# Parse each completion and append its cell annotations to output_no_pool.csv.
for key, value in tqdm_notebook(raw_annotations_no_pool.items()):
    # An empty dict marks a prompt without a completion; nothing to parse.
    if value == {}:
        continue
    parsed_result = parser(value)
    annotated_entities = entity_extractor(parsed_result, key)
    save_to_csv("output_no_pool.csv", annotated_entities)

# Prompt with Pool Output Analysis

## Stats Analysis

In [None]:
# Path of the CSV file holding the "with pool" annotations
ann_out= "./output.csv"

# Maps table id -> {"<row>_<column>": annotated entity id}.
ann_dict = {}

with open(ann_out, newline='') as csvfile:
    spamreader = csv.reader(csvfile, delimiter=',')
    # save_to_csv writes no header row, so every row is data; skipping the
    # first one would silently drop an annotation.
    for row in spamreader:
        for element in row:
            my_list = element.split(",")
            # "<table>, No Wikidata Entities Found" markers and malformed
            # entries have fewer than four fields and carry no annotation.
            if len(my_list) < 4:
                continue
            table_id = my_list[0]
            row_id = my_list[1].strip()
            column_id = my_list[2].strip()
            entity = my_list[3].strip()
            ann_dict.setdefault(table_id, {})[f"{row_id}_{column_id}"] = entity

In [None]:
# Cell- and table-level accuracy of the "with pool" annotations, computed over
# the ground-truth tables that have at least one annotation.
right_cell = 0
wrong_cell = 0
missing_cell = 0
missing_table = 0
entire_table = 0
n_cell = 0

for key, value in gt_dict.items():
    if key not in ann_dict:
        missing_table += 1
        continue
    n_cell += len(value)
    if value == ann_dict[key]:
        # Identical dictionaries: every cell of the table is right.
        entire_table += 1
        right_cell += len(value)
        continue
    for gt_key, gt_value in value.items():
        if gt_key not in ann_dict[key]:
            missing_cell += 1
        elif ann_dict[key][gt_key] == gt_value:
            right_cell += 1
        else:
            wrong_cell += 1

# Percentages fall back to 0.0 when there is nothing to divide by.
table_total = len(ann_dict)
table_pct = round(entire_table/table_total*100, 2) if table_total else 0.0
right_pct = round(right_cell/n_cell*100, 2) if n_cell else 0.0
wrong_pct = round(wrong_cell/n_cell*100, 2) if n_cell else 0.0
missing_pct = round(missing_cell/n_cell*100, 2) if n_cell else 0.0

print(colored("Output Analysis Prompt with Pool",'red', attrs=['bold']))
print("")
print(colored("Number of tables: ", 'red', attrs=['bold'])+ str(table_total))
print(colored("Tabled correctly annotated: ", 'blue')+str(entire_table) + " ("+str(table_pct)+"%)")
print("")
print(colored("Number of annotable cells: ", 'red', attrs=['bold'])+ str(n_cell))
print(colored("Cells correctly annotated: ", 'blue')+str(right_cell)+" ("+str(right_pct)+"%)")
print(colored("Cells misannotated: ", 'blue')+str(wrong_cell)+" ("+str(wrong_pct)+"%)")
print(colored("Missing cells' annotation: ", 'blue')+str(missing_cell)+" ("+str(missing_pct)+"%)")

## Table Analysis

In [None]:
# Side-by-side dump of the ground truth and the GPT annotations of one table,
# printed three cells per line, followed by per-table counts.
key = "QI8FPOEX"
tab_right_cell = 0
tab_wrong_cell = 0
tab_missing_cell = 0
print(colored("\nGROUNDTRUTH ANNOTATIONS: " + key , 'red', attrs=['bold']))
if key in gt_dict:
    temp_list = []
    for col_row, value in gt_dict[key].items():
        temp_list.append(col_row + ":" + value)
        if(len(temp_list) == 3):
            print(temp_list)
            temp_list = []
    # Print the last, incomplete group once, after the loop.
    if temp_list:
        print(temp_list)
else:
    print(colored("Do no exist gt table", 'red'))
print(colored("\nGPT ANNOTATIONS: ", 'red', attrs=['bold']))
if key in ann_dict:
    # Start from an empty group so ground-truth leftovers are not mixed in.
    temp_list = []
    gt_cells = gt_dict.get(key, {})
    for col_row, value in ann_dict[key].items():
        temp_list.append(col_row + ":" + value)
        if(len(temp_list) == 3):
            print(temp_list)
            temp_list = []
        if col_row not in gt_cells:
            # Annotated cell that has no ground-truth entry.
            tab_missing_cell += 1
        elif gt_cells[col_row] == value:
            tab_right_cell += 1
        else:
            tab_wrong_cell += 1
    if temp_list:
        print(temp_list)
    print(colored("\nCorrect cells annotation: ", 'blue') +str(tab_right_cell))
    print(colored("Incorrect cells annotation: ", 'blue') +str(tab_wrong_cell))
    print(colored("Missing cells annotation: ", 'blue') +str(tab_missing_cell))
else:
    print(colored("Do no exist annotations for this table", 'red'))
       



# Prompt without Pool Output Analysis

## Stats Analysis

In [None]:
# Path of the CSV file holding the "without pool" annotations
ann_out= "./output_no_pool.csv"

# Maps table id -> {"<row>_<column>": annotated entity id}.
ann_dict_no_pool = {}

with open(ann_out, newline='') as csvfile:
    spamreader = csv.reader(csvfile, delimiter=',')
    # save_to_csv writes no header row, so every row is data; skipping the
    # first one would silently drop an annotation.
    for row in spamreader:
        for element in row:
            my_list = element.split(",")
            # "<table>, No Wikidata Entities Found" markers and malformed
            # entries have fewer than four fields and carry no annotation.
            if len(my_list) < 4:
                continue
            table_id = my_list[0]
            row_id = my_list[1].strip()
            column_id = my_list[2].strip()
            entity = my_list[3].strip()
            ann_dict_no_pool.setdefault(table_id, {})[f"{row_id}_{column_id}"] = entity

In [None]:
# Cell- and table-level accuracy of the "without pool" annotations, computed
# over the ground-truth tables that have at least one annotation.
right_cell = 0
wrong_cell = 0
missing_cell = 0
missing_table = 0
entire_table = 0
n_cell = 0

for key, value in gt_dict.items():
    if key not in ann_dict_no_pool:
        missing_table += 1
        continue
    n_cell += len(value)
    if value == ann_dict_no_pool[key]:
        # Identical dictionaries: every cell of the table is right.
        entire_table += 1
        right_cell += len(value)
        continue
    for gt_key, gt_value in value.items():
        if gt_key not in ann_dict_no_pool[key]:
            missing_cell += 1
        elif ann_dict_no_pool[key][gt_key] == gt_value:
            right_cell += 1
        else:
            wrong_cell += 1

# Percentages fall back to 0.0 when there is nothing to divide by. The table
# percentage is taken over the "without pool" tables, not the "with pool" ones.
table_total = len(ann_dict_no_pool)
table_pct = round(entire_table/table_total*100, 2) if table_total else 0.0
right_pct = round(right_cell/n_cell*100, 2) if n_cell else 0.0
wrong_pct = round(wrong_cell/n_cell*100, 2) if n_cell else 0.0
missing_pct = round(missing_cell/n_cell*100, 2) if n_cell else 0.0

print(colored("Output Analysis Prompt without Pool",'red', attrs=['bold']))
print("")

print(colored("Number of tables: ", 'red', attrs=['bold'])+ str(table_total))
print(colored("Tabled correctly annotated: ", 'blue')+str(entire_table) + " ("+str(table_pct)+"%)")
print("")
print(colored("Number of annotable cells: ", 'red', attrs=['bold'])+ str(n_cell))
print(colored("Cells correctly annotated: ", 'blue')+str(right_cell)+" ("+str(right_pct)+"%)")
print(colored("Cells misannotated: ", 'blue')+str(wrong_cell)+" ("+str(wrong_pct)+"%)")
print(colored("Missing cells' annotation: ", 'blue')+str(missing_cell)+" ("+str(missing_pct)+"%)")

## Table Analysis

In [None]:
# Side-by-side dump of the ground truth and the "without pool" GPT annotations
# of one table, printed three cells per line, followed by per-table counts.
key = "ZK2S0Y91"
# Counters are initialised here so the cell does not depend on, or add to,
# values left behind by another cell.
tab_right_cell = 0
tab_wrong_cell = 0
tab_missing_cell = 0
print(colored("\nGROUNDTRUTH ANNOTATIONS: " + key , 'red', attrs=['bold']))
if key in gt_dict:
    temp_list = []
    for col_row, value in gt_dict[key].items():
        temp_list.append(col_row + ":" + value)
        if(len(temp_list) == 3):
            print(temp_list)
            temp_list = []
    if temp_list:
        print(temp_list)
else:
    print(colored("Do no exist gt table", 'red'))
print(colored("\nGPT ANNOTATIONS: ", 'red', attrs=['bold']))
if key in ann_dict_no_pool:
    # Start from an empty group so ground-truth leftovers are not mixed in.
    temp_list = []
    gt_cells = gt_dict.get(key, {})
    for col_row, value in ann_dict_no_pool[key].items():
        temp_list.append(col_row + ":" + value)
        if(len(temp_list) == 3):
            print(temp_list)
            temp_list = []
        if col_row not in gt_cells:
            # Annotated cell that has no ground-truth entry.
            tab_missing_cell += 1
        elif gt_cells[col_row] == value:
            tab_right_cell += 1
        else:
            tab_wrong_cell += 1
    if temp_list:
        print(temp_list)
    print(colored("\nCorrect cells annotation: ", 'blue') +str(tab_right_cell))
    print(colored("Incorrect cells annotation: ", 'blue') +str(tab_wrong_cell))
    print(colored("Missing cells annotation: ", 'blue') +str(tab_missing_cell))
else:
    print(colored("Do no exist annotations for this table", 'red'))
       

