In [1]:
import json
import pandas as pd
import numpy as np
from transformers import BertTokenizer
from torch.utils.data import Dataset
import torch

In [2]:
import json
import pandas as pd

class Header:
    """Column metadata of a table, stored as parallel name/type lists.

    Indexing yields a ``(name, type)`` pair.  Since ``__getitem__`` raises
    ``IndexError`` past the last column, a plain for-loop iterates the pairs.
    """

    def __init__(self, names: list, types: list):
        # names[i] is described by types[i].
        self.names = names
        self.types = types

    def __getitem__(self, idx):
        """Return the ``(name, type)`` pair at position ``idx``."""
        name = self.names[idx]
        col_type = self.types[idx]
        return name, col_type

    def __len__(self):
        """Number of columns (length of ``names``)."""
        return len(self.names)

    def __repr__(self):
        """Render as ``name(type) | name(type) | ...``."""
        pieces = ['{}({})'.format(name, col_type)
                  for name, col_type in zip(self.names, self.types)]
        return ' | '.join(pieces)

class Table:
    """One table record: id, name, title, column header and raw rows.

    A pandas view of ``rows`` is built on first access of :attr:`df` and
    cached on the instance.
    """

    def __init__(self, id, name, title, header: Header, rows, **kwargs):
        # Any extra keyword arguments are accepted and ignored.
        self.id = id
        self.name = name
        self.title = title
        self.header = header
        self.rows = rows
        self._df = None  # lazily-built DataFrame cache

    @property
    def df(self):
        """Rows as a string-typed DataFrame whose columns are the header names."""
        if self._df is not None:
            return self._df
        self._df = pd.DataFrame(
            data=self.rows,
            columns=self.header.names,
            dtype=str,
        )
        return self._df

    def _repr_html_(self):
        # Jupyter rich-display hook: reuse the DataFrame's HTML rendering.
        return self.df._repr_html_()

class Tables:
    """Collection of table objects keyed by their ``id`` attribute."""

    table_dict = None  # class-level placeholder; each instance gets its own dict

    def __init__(self, table_list: list = None, table_dict: dict = None):
        """Build the collection from an optional list and/or mapping.

        Tables from ``table_list`` are inserted first, then ``table_dict`` is
        merged in (its entries win on id collisions).  Arguments that are not
        a ``list`` / ``dict`` respectively are ignored.
        """
        self.table_dict = {}
        if isinstance(table_list, list):
            self.table_dict.update((table.id, table) for table in table_list)
        if isinstance(table_dict, dict):
            self.table_dict.update(table_dict)

    def push(self, table):
        """Insert or replace ``table`` under its id."""
        self.table_dict[table.id] = table

    def __len__(self):
        return len(self.table_dict)

    def __add__(self, other):
        """Return a new collection holding both operands' tables (``other`` wins on clashes)."""
        combined = [*self.table_dict.values(), *other.table_dict.values()]
        return Tables(table_list=combined)

    def __getitem__(self, id):
        return self.table_dict[id]

    def __iter__(self):
        """Yield ``(table_id, table)`` pairs in insertion order."""
        yield from self.table_dict.items()

class Question:
    """Wrapper around a natural-language question string.

    ``repr()``/``str()``, indexing and ``len()`` all delegate to the text.
    """

    def __init__(self, text):
        self.text = text

    def __repr__(self):
        # The bare text is the representation; str() falls back to this.
        return self.text

    def __getitem__(self, idx):
        """Character (or slice) of the underlying text."""
        text = self.text
        return text[idx]

    def __len__(self):
        """Length of the underlying text."""
        text = self.text
        return len(text)

class SQL:
    """SQL label of a query: selected columns, aggregations and conditions.

    ``sel`` is stored sorted by column id, with ``agg`` permuted alongside it.
    ``conds`` (``[col_id, op, value]`` entries) is stored sorted by column id.
    This lets two labels be compared through their JSON form.  Plain ``==`` is
    disabled; call one of the ``equal_*_mode`` methods to choose what is
    compared.
    """
    op_sql_dict = {0: ">", 1: "<", 2: "==", 3: "!="}
    agg_sql_dict = {0: "", 1: "AVG", 2: "MAX", 3: "MIN", 4: "COUNT", 5: "SUM"}
    conn_sql_dict = {0: "NULL", 1: "AND", 2: "OR"}

    def __init__(self, cond_conn_op: int, agg: list, sel: list, conds: list, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        self.cond_conn_op = cond_conn_op
        # Keep each selected column paired with its aggregation while ordering
        # by column id.
        ordered = sorted(zip(sel, agg), key=lambda pair: pair[0])
        self.sel = [col_id for col_id, _ in ordered]
        self.agg = [agg_op for _, agg_op in ordered]
        self.conds = sorted(conds, key=lambda cond: cond[0])

    @classmethod
    def from_dict(cls, data: dict):
        """Build an instance from a mapping of constructor keyword arguments."""
        return cls(**data)

    def keys(self):
        """Field names; together with ``__getitem__`` this makes ``dict(self)`` work."""
        return ['cond_conn_op', 'sel', 'agg', 'conds']

    def __getitem__(self, key):
        return getattr(self, key)

    def to_json(self):
        """Serialize the four fields as JSON with sorted keys and unescaped non-ASCII."""
        return json.dumps(dict(self), ensure_ascii=False, sort_keys=True)

    @staticmethod
    def _projected_json(cond_conn_op, agg, sel, conds):
        # Rebuild through the constructor so the same ordering rules apply to
        # the reduced view before it is serialized.
        return SQL(cond_conn_op=cond_conn_op, agg=agg, sel=sel, conds=conds).to_json()

    def equal_all_mode(self, other):
        """Compare every field, including condition values."""
        return self.to_json() == other.to_json()

    def equal_agg_mode(self, other):
        """Compare only the selected columns and their aggregations."""
        mine = SQL._projected_json(0, self.agg, self.sel, [])
        theirs = SQL._projected_json(0, other.agg, other.sel, [])
        return mine == theirs

    def equal_conn_and_agg_mode(self, other):
        """Compare selection/aggregation plus the condition connector."""
        mine = SQL._projected_json(self.cond_conn_op, self.agg, self.sel, [])
        theirs = SQL._projected_json(other.cond_conn_op, other.agg, other.sel, [])
        return mine == theirs

    def equal_no_val_mode(self, other):
        """Compare everything except condition values (only column and operator)."""
        mine = SQL._projected_json(self.cond_conn_op, self.agg, self.sel,
                                   [cond[:2] for cond in self.conds])
        theirs = SQL._projected_json(other.cond_conn_op, other.agg, other.sel,
                                     [cond[:2] for cond in other.conds])
        return mine == theirs

    def __eq__(self, other):
        raise NotImplementedError('compare mode not set')

    def __repr__(self):
        """Human-readable dump with op/agg/connector ids mapped to their symbols."""
        parts = (
            "sel: {}\n".format(self.sel),
            "agg: {}\n".format([self.agg_sql_dict[a] for a in self.agg]),
            "cond_conn_op: '{}'\n".format(self.conn_sql_dict[self.cond_conn_op]),
            "conds: {}".format(
                [[cond[0], self.op_sql_dict[cond[1]], cond[2]] for cond in self.conds]),
        )
        return ''.join(parts)

    def _repr_html_(self):
        # Jupyter display: same text with newlines turned into <br>.
        return repr(self).replace('\n', '<br>')

class Query:
    """A question together with its table and, when labelled, its SQL."""

    def __init__(self, question: Question, table: Table, sql: SQL = None):
        self.question = question
        self.table = table
        self.sql = sql  # None when the record has no SQL label

    def _repr_html_(self):
        """Jupyter display: table HTML, question text and SQL joined by ``<br>``."""
        table_html = self.table._repr_html_()
        question_text = self.question.__repr__()
        sql_html = '' if self.sql is None else self.sql._repr_html_()
        return '{}<br>{}<br>{}'.format(table_html, question_text, sql_html)

def read_tables(table_file):
    """Load a file with one JSON table record per line into a ``Tables`` collection.

    From each record, ``header`` (column names) and ``types`` are removed and
    combined into a ``Header``.  The remaining keys are passed to ``Table`` as
    keyword arguments.
    """
    collection = Tables()
    with open(table_file, encoding='utf-8') as handle:
        for raw_line in handle:
            record = json.loads(raw_line)
            names = record.pop('header')
            col_types = record.pop('types')
            collection.push(Table(header=Header(names, col_types), **record))
    return collection

def read_data(data_file, tables: Tables):
    """Load one JSON query record per line and link each to its table.

    Each record needs ``question`` and ``table_id`` (looked up in ``tables``).
    A ``sql`` key, when present, becomes a ``SQL`` label; otherwise the
    query's ``sql`` is ``None``.

    Returns:
        list of ``Query`` objects in file order.
    """
    queries = []
    with open(data_file, encoding='utf-8') as handle:
        for raw_line in handle:
            record = json.loads(raw_line)
            question = Question(text=record['question'])
            table = tables[record['table_id']]
            sql = SQL.from_dict(record['sql']) if 'sql' in record else None
            queries.append(Query(question=question, table=table, sql=sql))
    return queries


In [3]:
# Training tables file and training questions file (read line by line below).
train_table_file, train_data_file = (
    '../../TableQA-master/train/train.tables.json',
    '../../TableQA-master/train/train.json',
)

In [4]:
# Load the training tables first; the questions reference them by table_id.
train_tables = read_tables(train_table_file)
train_data = read_data(train_data_file, train_tables)

In [5]:
# Number of distinct tables loaded.
print(len(train_tables))

5013


In [6]:
# Display the first query (rendered by Query._repr_html_: table, question, SQL).
train_data[0]

Unnamed: 0,影片名称,周票房（万）,票房占比（%）,场均人次
0,死侍2：我爱我家,10637.3,25.8,5.0
1,白蛇：缘起,10503.8,25.4,7.0
2,大黄蜂,6426.6,15.6,6.0
3,密室逃生,5841.4,14.2,6.0
4,“大”人物,3322.9,8.1,5.0
5,家和万事惊,635.2,1.5,25.0
6,钢铁飞龙之奥特曼崛起,595.5,1.4,3.0
7,海王,500.3,1.2,5.0
8,一条狗的回家路,360.0,0.9,4.0
9,掠食城市,356.6,0.9,3.0


In [7]:
# (name, type) pair of the first column.
train_data[0].table.header[0]

('影片名称', 'text')

In [8]:
# Raw question text of the first query.
str(train_data[0].question)

'二零一九年第四周大黄蜂和密室逃生这两部影片的票房总占比是多少呀'

In [9]:
# Inspect the SQL label of an arbitrary query.
idx = 333
train_data[idx].sql


In [10]:
# Full header rendered as "name(type) | ...".
train_data[0].table.header

影片名称(text) | 周票房（万）(real) | 票房占比（%）(real) | 场均人次(real)

In [11]:
# First two (name, type) pairs.
train_data[0].table.header[0], train_data[0].table.header[1]

(('影片名称', 'text'), ('周票房（万）', 'real'))

In [12]:
# Number of training queries.
len(train_data)

41522

In [13]:
class SqlLabelEncoder:
    """
    Convert SQL object into training labels.

    Labels are per-column int32 arrays of length ``num_cols``.  A column that
    is not selected is labelled ``len(SQL.agg_sql_dict)`` ("no aggregation").
    A column without a condition is labelled ``len(SQL.op_sql_dict)``
    ("no condition").
    """
    def encode(self, sql: SQL, num_cols):
        """Return ``(cond_conn_op, sel_agg_label, cond_op_label)`` for ``sql``.

        Column ids that are not smaller than ``num_cols`` are skipped.
        Condition values are dropped; only the operator is labelled.
        """
        def fill(pairs, default):
            # One slot per column, pre-set to the "nothing here" class.
            labels = np.full(num_cols, default, dtype='int32')
            for col_id, op in pairs:
                if col_id < num_cols:
                    labels[col_id] = op
            return labels

        conn_label = sql.cond_conn_op
        sel_agg_label = fill(zip(sql.sel, sql.agg), len(SQL.agg_sql_dict))
        cond_op_label = fill(((col_id, op) for col_id, op, _ in sql.conds),
                             len(SQL.op_sql_dict))
        return conn_label, sel_agg_label, cond_op_label

    def decode(self, cond_conn_op_label, sel_agg_label, cond_op_label):
        """Invert ``encode``: turn per-column labels back into a SQL-like dict.

        A column is selected when its aggregation label is a real aggregation
        id, and conditioned when its condition label is a real operator id.
        Condition values are not recoverable, so each cond is ``[col_id, op]``.
        """
        cond_conn_op = int(cond_conn_op_label)
        n_agg = len(SQL.agg_sql_dict)
        n_op = len(SQL.op_sql_dict)
        sel, agg, conds = [], [], []

        columns = enumerate(zip(sel_agg_label, cond_op_label))
        for col_id, (agg_op, cond_op) in columns:
            if agg_op < n_agg:
                sel.append(col_id)
                agg.append(int(agg_op))
            if cond_op < n_op:
                conds.append([col_id, int(cond_op)])

        return {
            'sel': sel,
            'agg': agg,
            'cond_conn_op': cond_conn_op,
            'conds': conds
        }

# Shared encoder instance used by the inspection cells below.
sql_le = SqlLabelEncoder()



In [18]:
# Encode the SQL label of one training example and show the labels as tensors.
idx = 3
cond_conn_op, sel_agg, cond_op = sql_le.encode(
    train_data[idx].sql,
    num_cols=len(train_data[idx].table.header),
)
torch.tensor(cond_conn_op), torch.tensor(sel_agg), torch.tensor(cond_op)

(tensor(1),
 tensor([6, 4, 6, 6, 6, 6, 6, 6], dtype=torch.int32),
 tensor([4, 4, 4, 4, 4, 4, 2, 2], dtype=torch.int32))

In [15]:
import re
import numpy as np
import torch
from transformers import BertTokenizer
from torch.utils.data import Dataset
from utils import SQL

import config

class SqlLabelEncoder:
    """
    Convert SQL object into training labels.

    Labels are per-column int32 arrays of length ``num_cols``.  A column that
    is not selected is labelled ``len(SQL.agg_sql_dict)`` ("no aggregation").
    A column without a condition is labelled ``len(SQL.op_sql_dict)``
    ("no condition").
    """
    def encode(self, sql: SQL, num_cols):
        """Return ``(cond_conn_op, sel_agg_label, cond_op_label)`` for ``sql``.

        Column ids that are not smaller than ``num_cols`` are skipped.
        Condition values are dropped; only the operator is labelled.
        """
        def fill(pairs, default):
            # One slot per column, pre-set to the "nothing here" class.
            labels = np.full(num_cols, default, dtype='int32')
            for col_id, op in pairs:
                if col_id < num_cols:
                    labels[col_id] = op
            return labels

        conn_label = sql.cond_conn_op
        sel_agg_label = fill(zip(sql.sel, sql.agg), len(SQL.agg_sql_dict))
        cond_op_label = fill(((col_id, op) for col_id, op, _ in sql.conds),
                             len(SQL.op_sql_dict))
        return conn_label, sel_agg_label, cond_op_label

    def decode(self, cond_conn_op_label, sel_agg_label, cond_op_label):
        """Invert ``encode``: turn per-column labels back into a SQL-like dict.

        A column is selected when its aggregation label is a real aggregation
        id, and conditioned when its condition label is a real operator id.
        Condition values are not recoverable, so each cond is ``[col_id, op]``.
        """
        cond_conn_op = int(cond_conn_op_label)
        n_agg = len(SQL.agg_sql_dict)
        n_op = len(SQL.op_sql_dict)
        sel, agg, conds = [], [], []

        columns = enumerate(zip(sel_agg_label, cond_op_label))
        for col_id, (agg_op, cond_op) in columns:
            if agg_op < n_agg:
                sel.append(col_id)
                agg.append(int(agg_op))
            if cond_op < n_op:
                conds.append([col_id, int(cond_op)])

        return {
            'sel': sel,
            'agg': agg,
            'cond_conn_op': cond_conn_op,
            'conds': conds
        }

class CustomDataset(Dataset):
    """Torch ``Dataset`` that encodes queries as BERT sentence pairs.

    Each item is ``(question, header_sentence)``.  The header sentence lists
    every column as ``<type placeholder><column name>``, with columns joined by
    the ``SEP_temp`` placeholder.  After tokenization, the placeholder ids
    (``SEP_temp``, ``REAL_temp``, ``TEXT_temp``) are swapped for the ids of
    ``[SEP]``, ``REAL_token`` and ``TEXT_token``, but only inside the header
    segment.  The positions of the REAL/TEXT markers (one per visible column)
    are returned as ``header_ids``.

    With a ``sql_label_encoder`` each item also carries the labels
    ``COND_CONN_OP``, ``SEL_AGG`` and ``COND_OP``.  Pass
    ``sql_label_encoder=None`` for unlabelled (inference) data; the item then
    has the same feature keys without labels.
    """

    # Strips everything from the first opening bracket to the last closing
    # bracket (ASCII or full-width), e.g. a unit suffix like "（万）".
    _BRACKET_RE = re.compile(r'[\(\（].*[\)\）]')

    def __init__(
        self, 
        data, 
        sql_label_encoder=SqlLabelEncoder(), 
        max_len=config.MAX_LEN, 
        model_name=config.BASE_MODEL_PATH, 
        SEP_temp='|', 
        REAL_temp = '?',
        TEXT_temp = '!',
        cls_token='[CLS]', 
        REAL_token='[unused20]', 
        TEXT_token='[unused21]'
    ):
        self.data = data                                                                     # sequence of Query-like objects
        self.max_len = max_len                                                               # max length of the token sequence
        self.tokenizer = BertTokenizer.from_pretrained(model_name, cls_token=cls_token)
        self.indexes = np.arange(len(self.data))                                             # one index per item
        self.sql_label_encoder = sql_label_encoder                                           # None => no labels returned

        # encode() prepends [CLS], so index 1 is the id of the token itself.
        self.CLS_id = self.tokenizer.encode([cls_token])[1]
        self.REAL_id = self.tokenizer.encode([REAL_token])[1]
        self.TEXT_id = self.tokenizer.encode([TEXT_token])[1]
        self.SEP_id = self.tokenizer.encode('[SEP]')[1]

        # Placeholder characters used while building/tokenizing the header sentence.
        self.SEP_temp = SEP_temp
        self.REAL_temp = REAL_temp
        self.TEXT_temp = TEXT_temp

        self.SEP_temp_id = self.tokenizer.encode(SEP_temp)[1]
        self.REAL_temp_id = self.tokenizer.encode(REAL_temp)[1]
        self.TEXT_temp_id = self.tokenizer.encode(TEXT_temp)[1]

    def __len__(self):
        return len(self.data)

    def _header_sentence(self, header):
        """Join columns as ``<type placeholder><name>`` separated by ``SEP_temp``.

        Columns typed neither 'text' nor 'real' get no placeholder; bracketed
        parts of column names are removed.
        """
        type_marker = {'text': self.TEXT_temp, 'real': self.REAL_temp}
        return self.SEP_temp.join(
            type_marker.get(col_type, '') + self._BRACKET_RE.sub('', name)
            for name, col_type in header
        )

    def __getitem__(self, idx):
        query = self.data[idx]
        question = str(query.question)
        header_sent = self._header_sentence(query.table.header)

        embeddings = self.tokenizer(
            question, header_sent,                    # sentence 1, sentence 2
            padding='max_length',                     # pad to max_length
            truncation=True,                          # truncate to max_length
            max_length=self.max_len,
            return_tensors='pt'                       # return torch.Tensor objects
        )

        token_ids = torch.squeeze(embeddings['input_ids'])
        attention_masks = torch.squeeze(embeddings['attention_mask'])     # 1 for real tokens, 0 for padding
        token_type_ids = torch.squeeze(embeddings['token_type_ids'])      # 0 for the question, 1 for the header sentence

        # Swap placeholders for the real marker ids only inside the header
        # segment.  Otherwise a '|', '?' or '!' typed in the question would be
        # turned into a separator/column marker and add a bogus header_ids entry.
        in_header = token_type_ids == 1
        for temp_id, real_id in (
            (self.SEP_temp_id, self.SEP_id),
            (self.REAL_temp_id, self.REAL_id),
            (self.TEXT_temp_id, self.TEXT_id),
        ):
            token_ids = torch.where(in_header & (token_ids == temp_id),
                                    torch.tensor(real_id), token_ids)

        # Positions of the per-column type markers, in column order.  Columns
        # cut off by truncation have no marker and are therefore absent.
        is_marker = in_header & ((token_ids == self.REAL_id) | (token_ids == self.TEXT_id))
        header_ids = torch.nonzero(is_marker, as_tuple=True)[0].to(torch.int32)

        item = {
            'token_ids': token_ids.to(torch.int32),
            'token_type_ids': token_type_ids.to(torch.int32),
            'attention_masks': attention_masks.to(torch.int32),
            'header_ids': header_ids,
            'header_masks': torch.ones(len(header_ids), dtype=torch.int32),
        }
        # Unlabelled (inference) data: features only, same keys collate expects.
        if self.sql_label_encoder is None:
            return item

        cond_conn_op, sel_agg, cond_op = self.sql_label_encoder.encode(
            query.sql, num_cols=len(query.table.header))
        # Keep per-column labels aligned with header_ids when truncation hid
        # trailing columns.  NOTE(review): if a column name itself contains a
        # placeholder character, header_ids can still be longer than the labels.
        n_visible = len(header_ids)
        item.update({
            'COND_CONN_OP': torch.tensor(cond_conn_op, dtype=torch.int32),
            'SEL_AGG': torch.tensor(sel_agg[:n_visible], dtype=torch.int32),
            'COND_OP': torch.tensor(cond_op[:n_visible], dtype=torch.int32),
        })
        return item
            

def collate_fn(batch_data):
    '''
    Collate labelled samples (dicts of tensors) into one batch.

    Samples are ordered by number of header markers, longest first.  The
    fixed-length BERT inputs and COND_CONN_OP are stacked.  The per-column
    tensors (header_ids, header_masks, SEL_AGG, COND_OP) differ in length
    between samples, so they are right-padded with 0 to the longest one.

    NOTE(review): 0 is also a valid SEL_AGG / COND_OP class id, so padded
    label positions presumably have to be ignored via header_masks downstream.
    The caller's list is not reordered.
    '''
    batch = sorted(batch_data, key=lambda xi: len(xi['header_ids']), reverse=True)

    def stack(key):
        return torch.stack([xi[key] for xi in batch])

    def pad(key):
        return torch.nn.utils.rnn.pad_sequence(
            [xi[key] for xi in batch], batch_first=True, padding_value=0)

    return {
        ### X
        'token_ids': stack('token_ids'),
        'token_type_ids': stack('token_type_ids'),
        'attention_masks': stack('attention_masks'),
        'header_ids': pad('header_ids'),
        'header_masks': pad('header_masks'),
        ### y  (each label key now carries its own tensor; they were swapped before)
        'COND_CONN_OP': stack('COND_CONN_OP'),
        'COND_OP': pad('COND_OP'),
        'SEL_AGG': pad('SEL_AGG')
    }


def collate_fn_labelless(batch_data):
    '''
    Collate unlabelled (inference) samples into one batch.

    ``batch_data`` is sorted in place by number of header markers, longest
    first.  header_ids / header_masks are right-padded with 0 to the longest
    sample; the fixed-length BERT inputs are stacked.
    '''
    batch_data.sort(key=lambda sample: len(sample['header_ids']), reverse=True)
    pad = torch.nn.utils.rnn.pad_sequence

    def column(key):
        return [sample[key] for sample in batch_data]

    padded = {key: pad(column(key), batch_first=True, padding_value=0)
              for key in ('header_ids', 'header_masks')}
    stacked = {key: torch.stack(column(key))
               for key in ('token_ids', 'token_type_ids', 'attention_masks')}
    # Output order: BERT inputs first, then the padded header tensors.
    return {**stacked, **padded}

In [16]:
# Build the training dataset (loads the tokenizer named by config.BASE_MODEL_PATH by default).
train_set = CustomDataset(train_data)

In [17]:
# Inspect one encoded sample (dict of tensors).
train_set[2332]

{'token_ids': tensor([ 101, 2791, 1765,  772, 2458, 1355, 2832, 6598, 1398, 3683, 1872, 6862,
         1469, 6589, 4500, 1398, 3683, 1872, 6862, 6963, 1920,  754,  124,  110,
         4638, 3221, 1525, 1126, 2399,  102,   21, 3198, 7313,  102,   20, 2791,
         1765,  772, 2458, 1355, 2832, 6598, 2130, 2768, 7583,  102,   20, 2832,
         6598, 1398, 3683, 1872, 6862,  102,   20, 1071,  800, 6589, 4500,  102,
           20, 1071, 2124, 6589, 4500, 1872, 6862,  102,   20, 2456, 2128, 2832,
         6598,  102,   20, 6589, 4500, 1398, 3683, 1872, 6862,  102,   21, 3177,
         2339, 7481, 2832, 6598, 2130, 2768, 7583,  102,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0, 

In [23]:
from torch.utils.data import DataLoader

# Wrap the dataset in a DataLoader. collate_fn pads header ids/masks and the
# per-column labels to the widest sample in each batch.
# NOTE(review): with num_workers > 0, workers may need to unpickle objects
# defined in this notebook, which can fail on spawn-based platforms — confirm.
train_set = CustomDataset(train_data)
train_dataloader = DataLoader(
    dataset=train_set, 
    batch_size=1, 
    shuffle=True, 
    num_workers=1,
    pin_memory=True,
    collate_fn=collate_fn
)


In [24]:
# Fetch and print a single batch to inspect the collated tensors.
for data in train_dataloader:
    print(data)
    break

{'token_ids': tensor([[ 101, 1525,  763, 7770, 6862, 8126, 2399, 4638, 5500, 2622, 4372, 2207,
          754,  127,  110, 8024, 1398, 3198, 1071, 2094, 6121,  689, 1348, 3221,
          784,  720,  102,   21, 3403, 4638,  102,   21, 2094, 6121,  689,  102,
           20, 8109, 5500, 2622, 4372,  102,   20, 8112, 5500, 2622, 4372,  102,
           20, 8119, 5500, 2622, 4372,  102,   20, 8109, 1146, 5273, 3683,  891,
          102,   20, 8112, 1146, 5273, 3683,  891,  102,   20, 8119, 1146, 5273,
         3683,  891,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,