In [None]:
import json
import numpy as np
import gzip
import pickle
import random

## 1. Read data and CTC predicted results.

In [None]:
## Read original table and question data, adjust the file path accordingly
def _read_json(path):
    """Parse one JSON file and close the file handle afterwards."""
    # Explicit UTF-8: the data contains Chinese text and the platform default
    # encoding is not UTF-8 everywhere.
    with open(path, encoding="utf-8") as json_file:
        return json.load(json_file)


train_tables = _read_json("path_to_IM-TQA-v1.0/train_tables.json")
test_tables = _read_json("path_to_IM-TQA-v1.0/test_tables.json")
train_questions = _read_json("path_to_IM-TQA-v1.0/train_questions.json")
test_questions = _read_json("path_to_IM-TQA-v1.0/test_questions.json")
print("train table num:", len(train_tables))
print("test table num:", len(test_tables))
print("train question num:", len(train_questions))
print("test question num:", len(test_questions))

In [None]:
# table_id --> table_item (train and test tables share one lookup; on a
# duplicated table_id the later table wins)
table_id_to_table = {
    table_item['table_id']: table_item
    for table_item in train_tables + test_tables
}

In [None]:
## Read RGCN CTC model predicted result which is generated by 'train_ctc_task.py', which is a dict with table_id as key and pred_results as values.
## The data format looks like:
"""{
    "table_id": {   
                    'logits': array([[ -0.80221117,   7.304468  ,  -0.7408706 ,  -4.7407093 , -1.7058028 ], 
                                    ...... ,
                                        [ -0.36303076,  -1.330893  ,  14.02634   ,  -8.761856  , -3.537701  ]], dtype=float32), # output logits of RGCN CTC model, shape = [Cell Num, 5]
                    'preds': [1, 2, 0, 2, ...... , 0], # predicted cell type label list, shape = [Cell Num]
                    'labels': [1, 2, 0, 2, ...... , 0], # ground truth cell type label list, shape = [Cell Num]
                }
    }
"""

def _read_pickle(path):
    """Unpickle one file and close the file handle afterwards."""
    # NOTE(review): pickle.load can execute arbitrary code; only load result
    # files that you generated yourself with 'train_ctc_task.py'.
    with open(path, 'rb') as pickle_file:
        return pickle.load(pickle_file)


train_pred_results = _read_pickle('path_to_CTC_pred_results/ctc_train_pred_results.pkl')
test_pred_results = _read_pickle('path_to_CTC_pred_results/ctc_test_pred_results.pkl')
print("CTC pred results num of train tables:", len(train_pred_results))
print("CTC pred results num of test tables:", len(test_pred_results))

In [None]:
# table_id --> CTC predicted results (test entries override train entries
# that share a table_id)
all_pred_results = {**train_pred_results, **test_pred_results}
print(len(all_pred_results))

In [None]:
# table_id --> predicted header cell ids
# The label ids below must stay consistent with create_label_list() in train_ctc_task.py.
# Any label that is not listed here is ignored.
HEADER_KEY_BY_LABEL = {
    1: 'column_attribute',
    2: 'row_attribute',
    3: 'column_index',
    4: 'row_index',
}

table_id_to_pred_header_cell_ids = {}
for table_id, pred_result in all_pred_results.items():
    header_cell_ids = {
        'row_attribute': [],
        'column_attribute': [],
        'row_index': [],
        'column_index': [],
    }
    # the position of a prediction in 'preds' is the cell id
    for cell_id, pred in enumerate(pred_result['preds']):
        header_key = HEADER_KEY_BY_LABEL.get(pred)
        if header_key is not None:
            header_cell_ids[header_key].append(cell_id)
    table_id_to_pred_header_cell_ids[table_id] = header_cell_ids
print(len(table_id_to_pred_header_cell_ids))

## 2. Define classes for building RCI row and column representations

In [None]:
# Modified from the RCI code: https://github.com/IBM/row-column-intersection/blob/main/datasets/tables2seq_pair.py
class RCIInst:
    """One (question, row-or-column text) sequence pair with its label."""

    __slots__ = ('id', 'text_a', 'text_b', 'label', 'question_type')

    def __init__(self, inst_id, text_a, text_b, label, question_type):
        # inst_id is stored under the attribute name 'id'
        self.id, self.text_a, self.text_b = inst_id, text_a, text_b
        self.label, self.question_type = label, question_type

    def to_dict(self):
        """Return the instance as a plain dict keyed by the slot names."""
        return {field: getattr(self, field) for field in RCIInst.__slots__}

class CTC_RowColumnConvert:
    """Turn a (question, table) pair into RCI row and column sequence pairs.

    Every table row and every table column becomes one ``RCIInst`` whose
    ``text_b`` is a textual rendering of that row / column, enriched with the
    header cells predicted by the CTC model.  ``RCIInst`` must be defined in
    the same namespace.  The instance also accumulates simple counters over
    all calls to :meth:`convert`.
    """

    def __init__(self):
        self.cell_sep_token = '*'          # separator token between different cells, used to use '|' but albert doesn't have it
        self.cell_value_sep_token = '：'    # separator token between header cells and cell text, in RCI, this should be ":", we change it to "：" for chinese data
        self.answer_in_header = False
        self.negative_sample_rate = 1.0
        self.per_table_negatives = -1
        self.max_cells_per_col = -1        # if > 0, only the first N table rows contribute to a column text
        self.max_cell_char_length = 200    # max char length of cell text
        # running statistics, updated by convert()
        self.row_pos_count = 0
        self.col_pos_count = 0
        self.row_neg_count = 0
        self.col_neg_count = 0
        self.all_neg_count = 0
        self.all_seq_pair_count = 0
        self.unanswerable = 0
        self.multi_answer = 0
        self.single_answer = 0

        self.cell_sep = f' {self.cell_sep_token} '
        self.cell_value_sep = f' {self.cell_value_sep_token} '
        self.bad_question_id = []

    def build_rows(self, layout, cell_values, answer_list):
        """Materialise the cell texts of a table and locate the answer cells.

        Args:
            layout: 2-D list of cell ids; every row is expected to be at
                least as long as the first one.
            cell_values: sequence mapping cell id -> cell text.
            answer_list: ids of the answer cells.

        Returns:
            ``(table_rows, target_rows, target_columns)``.  ``table_rows`` is
            a nested list of cell texts truncated to
            ``max_cell_char_length``.  The two target lists hold the row and
            column position of every layout slot whose id is in
            ``answer_list``; an id that occupies several slots yields several
            entries.
        """
        target_rows = []     # row position of each answer slot
        target_columns = []  # column position of each answer slot
        table_rows = []
        col_num = len(layout[0]) if layout else 0
        for i, layout_row in enumerate(layout):
            one_row = []
            for j in range(col_num):
                cell_id = layout_row[j]
                one_row.append(cell_values[cell_id][:self.max_cell_char_length])
                if cell_id in answer_list:
                    target_rows.append(i)
                    target_columns.append(j)
            table_rows.append(one_row)
        return table_rows, target_rows, target_columns

    @staticmethod
    def _join_nearest_headers(cell_values, ordered_cell_ids, header_ids):
        """Join the texts of the nearest contiguous run of header cells.

        ``ordered_cell_ids`` is scanned from the nearest cell outwards.  Ids
        that are not in ``header_ids`` are skipped until the first header is
        met; after that the scan stops at the first non-header id.  Repeated
        ids are collected once.  The collected texts (untruncated, taken from
        ``cell_values``) are joined with a single space in table order, i.e.
        farthest first.
        """
        run = []
        for cell_id in ordered_cell_ids:
            if cell_id in header_ids:
                if cell_id not in run:
                    run.append(cell_id)
            elif run:
                break
        return ' '.join(cell_values[cell_id] for cell_id in reversed(run))

    def get_related_row_header_and_index(self, cell_values, cell_ids_on_the_left_in_one_row, row_attribute, row_index):
        """
        Extract related row attributes and row indexes from the left of a cell.

        ``cell_ids_on_the_left_in_one_row`` must be ordered nearest-first.
        Returns ``(row_attribute_str, row_index_str)``; either may be ``''``.
        """
        related_row_attribute_str = self._join_nearest_headers(
            cell_values, cell_ids_on_the_left_in_one_row, row_attribute)
        related_row_index_str = self._join_nearest_headers(
            cell_values, cell_ids_on_the_left_in_one_row, row_index)
        return related_row_attribute_str, related_row_index_str

    def get_related_col_header_and_index(self, cell_values, cell_ids_on_the_top_in_one_col, col_attribute, col_index):
        """
        Extract related column attributes and column indexes from the top of a cell.

        ``cell_ids_on_the_top_in_one_col`` must be ordered nearest-first.
        Returns ``(col_attribute_str, col_index_str)``; either may be ``''``.
        """
        related_col_attribute_str = self._join_nearest_headers(
            cell_values, cell_ids_on_the_top_in_one_col, col_attribute)
        related_col_index_str = self._join_nearest_headers(
            cell_values, cell_ids_on_the_top_in_one_col, col_index)
        return related_col_attribute_str, related_col_index_str

    def link_information_to_ori_tables(self, table_rows, table_item, pred_header_cell_ids):
        """
        Based on the CTC predicted results, build an enhanced table rows, which store the related headers of each cell.

        Args:
            table_rows: nested list of cell texts as returned by build_rows().
            table_item: dict with 'cell_ID_matrix' and 'chinese_cell_value_list'.
            pred_header_cell_ids: dict with the id lists 'row_attribute',
                'column_attribute', 'row_index' and 'column_index'.

        Returns:
            A nested list shaped like the layout.  Each entry is a dict with
            the keys 'ori_text', 'row_attr', 'row_index', 'col_attr' and
            'col_index'.  Attribute header cells and empty cells get empty
            strings for the four header fields.
        """
        layout = table_item['cell_ID_matrix']
        cell_values = table_item['chinese_cell_value_list']
        # sets make the many membership tests below constant time
        row_attribute = frozenset(pred_header_cell_ids['row_attribute'])
        col_attribute = frozenset(pred_header_cell_ids['column_attribute'])
        row_index = frozenset(pred_header_cell_ids['row_index'])
        col_index = frozenset(pred_header_cell_ids['column_index'])
        pure_header_ids = row_attribute | col_attribute
        col_num = len(layout[0]) if layout else 0
        enhanced_table_rows = []
        for i, layout_row in enumerate(layout):
            one_row = []
            for j in range(col_num):
                cell_id = layout_row[j]
                cell_text = table_rows[i][j]
                item = {'ori_text': cell_text, 'row_attr': '', 'row_index': '', 'col_attr': '', 'col_index': ''}
                # index cells and data cells look up their headers; attribute
                # header cells and empty cells keep the blank fields
                if cell_id not in pure_header_ids and len(cell_text) > 0:
                    left_ids = layout_row[:j][::-1]                          # same row, nearest first
                    top_ids = [layout[r][j] for r in range(i - 1, -1, -1)]   # same column, nearest first
                    item['row_attr'], item['row_index'] = self.get_related_row_header_and_index(
                        cell_values, left_ids, row_attribute, row_index)
                    item['col_attr'], item['col_index'] = self.get_related_col_header_and_index(
                        cell_values, top_ids, col_attribute, col_index)
                one_row.append(item)
            enhanced_table_rows.append(one_row)
        return enhanced_table_rows

    def _render_line(self, cell_ids, enhanced_cells, cell_values, kept_attrs, skipped_attrs,
                     all_attrs, pure_prefix, mixed_prefix, index_key, attr_key):
        """Render one table row or one table column as text.

        For a row, ``kept_attrs`` are the row attribute ids and
        ``skipped_attrs`` the column attribute ids; for a column the roles
        are swapped.  ``index_key`` / ``attr_key`` name the enhanced-cell
        fields that are prepended to index and data cells.  Duplicate texts
        are emitted once, in first-seen order.
        """
        non_empty_ids = [cell_id for cell_id in cell_ids if len(cell_values[cell_id]) > 0]
        if all(cell_id in skipped_attrs for cell_id in non_empty_ids):
            prefix = pure_prefix      # only attributes of the other axis
        elif all(cell_id in all_attrs for cell_id in non_empty_ids):
            prefix = mixed_prefix     # only attributes, of both axes
        else:
            prefix = None
        if prefix is not None:
            unique_texts = list(dict.fromkeys(cell['ori_text'] for cell in enhanced_cells))
            return prefix + self.cell_sep.join(unique_texts)

        texts = []
        seen = set()
        for cell_id, cell in zip(cell_ids, enhanced_cells):
            ori_text = cell['ori_text']
            if cell_id in kept_attrs and ori_text not in seen:
                candidate = ori_text
            elif cell_id in skipped_attrs:
                # attributes of the other axis do not help to understand this line
                continue
            else:
                header_str = ' '.join(
                    header for header in (cell[index_key], cell[attr_key]) if len(header) > 0)
                candidate = f'{header_str}：{ori_text}' if len(header_str) > 0 else ori_text
                if candidate in seen:
                    continue
            seen.add(candidate)
            texts.append(candidate)
        return self.cell_sep.join(texts)

    def convert(self, question_item, table_item, pred_header_cell_ids):
        """
        Build enhanced RCI row and column representations based on the CTC results.

        Args:
            question_item: dict with 'question_id', 'chinese_question',
                'answer_cell_list' and 'question_type'.
            table_item: dict with 'cell_ID_matrix' and 'chinese_cell_value_list'.
            pred_header_cell_ids: predicted header ids of the table, see
                link_information_to_ori_tables().

        Returns:
            ``(row_examples, column_examples)``: one RCIInst per table row and
            one per table column, positives first.  An instance is positive
            when its row / column contains an answer cell.  Both lists are
            empty when no answer cell occurs in the layout; the question is
            then counted as unanswerable.
        """
        sample_id = question_item['question_id']
        question = question_item['chinese_question']
        answer_cell_list = question_item['answer_cell_list']
        question_type = question_item['question_type']
        layout = table_item['cell_ID_matrix']
        cell_values = table_item['chinese_cell_value_list']
        row_attribute = frozenset(pred_header_cell_ids['row_attribute'])
        col_attribute = frozenset(pred_header_cell_ids['column_attribute'])
        attribute_ids = row_attribute | col_attribute

        table_rows, target_rows, target_columns = self.build_rows(layout, cell_values, answer_cell_list)
        if len(target_rows) == 0:
            # bail out before the (comparatively expensive) header linking
            print('problem sample_id:',sample_id)
            self.unanswerable += 1
            return [],[]
        target_row_set = set(target_rows)
        target_column_set = set(target_columns)

        # build enhanced table rows to find related header cell information
        enhanced_table_rows = self.link_information_to_ori_tables(table_rows, table_item, pred_header_cell_ids)

        # RCI row representation, enhanced with related column index and column attribute
        row_pos = []
        row_neg = []
        for ri, enhanced_row in enumerate(enhanced_table_rows):
            is_pos = ri in target_row_set
            row_rep = self._render_line(
                layout[ri], enhanced_row, cell_values,
                kept_attrs=row_attribute, skipped_attrs=col_attribute, all_attrs=attribute_ids,
                pure_prefix='列属性：', mixed_prefix='混合属性构成的行：',
                index_key='col_index', attr_key='col_attr')
            row_example = RCIInst(inst_id=f'{sample_id}:{ri}', text_a=question, text_b=row_rep, label=is_pos, question_type=question_type)
            (row_pos if is_pos else row_neg).append(row_example)
        row_examples = row_pos + row_neg

        # RCI column representation, enhanced with related row index and row attribute
        if self.max_cells_per_col>0:
            enhanced_table_rows = enhanced_table_rows[:self.max_cells_per_col]
        col_pos = []
        col_neg = []
        for ci in range(len(layout[0])):
            is_pos = ci in target_column_set
            one_col_cell_ids = [layout_row[ci] for layout_row in layout]
            enhanced_col = [enhanced_row[ci] for enhanced_row in enhanced_table_rows]
            col_rep = self._render_line(
                one_col_cell_ids, enhanced_col, cell_values,
                kept_attrs=col_attribute, skipped_attrs=row_attribute, all_attrs=attribute_ids,
                pure_prefix='行属性：', mixed_prefix='混合属性构成的列：',
                index_key='row_index', attr_key='row_attr')
            col_example = RCIInst(inst_id=f'{sample_id}:{ci}', text_a=question, text_b=col_rep, label=is_pos,question_type=question_type)
            (col_pos if is_pos else col_neg).append(col_example)
        column_examples = col_pos + col_neg

        # update statistic
        self.row_pos_count += len(row_pos)
        self.row_neg_count += len(row_neg)
        self.col_pos_count += len(col_pos)
        self.col_neg_count += len(col_neg)
        self.all_neg_count += len(row_neg)+len(col_neg)
        self.all_seq_pair_count += len(row_examples)+len(column_examples)

        if len(row_pos) == 0 and len(col_pos)==0:
            print('unanswerable question: ',sample_id)
            self.unanswerable += 1
        elif len(row_pos) > 1:
            self.multi_answer += 1
        else:
            self.single_answer += 1
        return row_examples, column_examples




## 3. Build RCI training and testing data

In [None]:
# build RCI seq pair of train and test samples
def build_split_seq_pairs(questions, converter):
    """Convert every question of one split into RCI row / column seq pairs.

    Looks up each question's table and predicted header cell ids in the
    module-level ``table_id_to_table`` and
    ``table_id_to_pred_header_cell_ids`` dicts and runs ``converter.convert``.

    Returns:
        ``(row_pairs, col_pairs, good_question_ids, bad_question_ids)``.
        A question is "bad" when the converter returned no row or no column
        examples for it.
    """
    row_pairs, col_pairs = [], []
    good_question_ids, bad_question_ids = [], []
    for question_item in questions:
        table_id = question_item['table_id']
        question_id = question_item['question_id']
        table = table_id_to_table[table_id]
        header_cell_ids = table_id_to_pred_header_cell_ids[table_id]
        row_examples, col_examples = converter.convert(question_item, table, header_cell_ids)
        if len(row_examples) != 0 and len(col_examples) != 0:
            # exactly one example per table row and per table column;
            # checked for every split, not only for the training data
            layout = table['cell_ID_matrix']
            assert len(row_examples) == len(layout)
            assert len(col_examples) == len(layout[0])
            row_pairs.extend(row_examples)
            col_pairs.extend(col_examples)
            good_question_ids.append(question_id)
        else:
            bad_question_ids.append(question_id)
    return row_pairs, col_pairs, good_question_ids, bad_question_ids


train_converter = CTC_RowColumnConvert()
test_converter = CTC_RowColumnConvert()

# process train samples
print("construct RCI seq pair for train samples, train sample number:%d "%(len(train_questions)))
(train_row_seq_pairs, train_col_seq_pairs,
 good_train_question_id, bad_train_question_id) = build_split_seq_pairs(train_questions, train_converter)
print("seq_pair sample num: %d"%(train_converter.all_seq_pair_count))
print("unanswerable question num: %d"%(train_converter.unanswerable))

# process test samples
print("-"*15)
print("construct RCI seq pair for test samples, test sample number:%d "%(len(test_questions)))
(test_row_seq_pairs, test_col_seq_pairs,
 good_test_question_id, bad_test_question_id) = build_split_seq_pairs(test_questions, test_converter)
print("seq_pair sample num: %d"%(test_converter.all_seq_pair_count))
print("unanswerable question num: %d"%(test_converter.unanswerable))

In [None]:
# number of seq pairs per split and direction
for split_label, seq_pairs in (
    ("Train-row", train_row_seq_pairs),
    ("Train-column", train_col_seq_pairs),
    ("Test-row", test_row_seq_pairs),
    ("Test-column", test_col_seq_pairs),
):
    print(f"{split_label} seq pair num:", len(seq_pairs))

In [None]:
# save data (gzip and json are already imported in the first cell)
def _as_record(seq_pair):
    """Return the dict form of a seq pair; dicts pass through unchanged.

    Accepting dicts keeps this cell re-runnable: after the first run the
    four lists below already hold dicts instead of RCIInst objects.
    """
    return seq_pair.to_dict() if hasattr(seq_pair, 'to_dict') else seq_pair


train_row_seq_pairs = [_as_record(item) for item in train_row_seq_pairs]
train_col_seq_pairs = [_as_record(item) for item in train_col_seq_pairs]
test_row_seq_pairs = [_as_record(item) for item in test_row_seq_pairs]
test_col_seq_pairs = [_as_record(item) for item in test_col_seq_pairs]

# one JSON object per line; the context manager closes (and flushes) every
# archive even if a write fails part-way
for output_path, records in (
    ('./RCI_data/train_rows.jsonl.gz', train_row_seq_pairs),
    ('./RCI_data/train_cols.jsonl.gz', train_col_seq_pairs),
    ('./RCI_data/test_rows.jsonl.gz', test_row_seq_pairs),
    ('./RCI_data/test_cols.jsonl.gz', test_col_seq_pairs),
):
    with gzip.open(output_path, 'wt', encoding='utf-8') as zip_file:
        for record in records:
            zip_file.write(json.dumps(record)+'\n')

In [None]:
# randomly select one sample to show the resulting seq pair
# The save cell rebinds train_row_seq_pairs to plain dicts, so calling
# .to_dict() unconditionally raised AttributeError here; handle both forms.
sampled_pair = random.sample(train_row_seq_pairs,1)[0]
sampled_pair.to_dict() if hasattr(sampled_pair, 'to_dict') else sampled_pair