In [1]:
from datasets import load_dataset
import logging
import sys
from tqdm import tqdm

In [2]:
dataset_test = load_dataset("wikitablequestions", trust_remote_code=True)

In [3]:
print(dataset_test)

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'answers', 'table'],
        num_rows: 11321
    })
    test: Dataset({
        features: ['id', 'question', 'answers', 'table'],
        num_rows: 4344
    })
    validation: Dataset({
        features: ['id', 'question', 'answers', 'table'],
        num_rows: 2831
    })
})


In [4]:
train, test, eval = dataset_test['train'], dataset_test['test'], dataset_test['validation']

In [5]:
train['table'][0]

{'header': ['Year',
  'Division',
  'League',
  'Regular Season',
  'Playoffs',
  'Open Cup',
  'Avg. Attendance'],
 'rows': [['2001',
   '2',
   'USL A-League',
   '4th, Western',
   'Quarterfinals',
   'Did not qualify',
   '7,169'],
  ['2002',
   '2',
   'USL A-League',
   '2nd, Pacific',
   '1st Round',
   'Did not qualify',
   '6,260'],
  ['2003',
   '2',
   'USL A-League',
   '3rd, Pacific',
   'Did not qualify',
   'Did not qualify',
   '5,871'],
  ['2004',
   '2',
   'USL A-League',
   '1st, Western',
   'Quarterfinals',
   '4th Round',
   '5,628'],
  ['2005',
   '2',
   'USL First Division',
   '5th',
   'Quarterfinals',
   '4th Round',
   '6,028'],
  ['2006',
   '2',
   'USL First Division',
   '11th',
   'Did not qualify',
   '3rd Round',
   '5,575'],
  ['2007',
   '2',
   'USL First Division',
   '2nd',
   'Semifinals',
   '2nd Round',
   '6,851'],
  ['2008',
   '2',
   'USL First Division',
   '11th',
   'Did not qualify',
   '1st Round',
   '8,567'],
  ['2009',
   '2',
  

In [6]:
print(str(train['table'][0]))

{'header': ['Year', 'Division', 'League', 'Regular Season', 'Playoffs', 'Open Cup', 'Avg. Attendance'], 'rows': [['2001', '2', 'USL A-League', '4th, Western', 'Quarterfinals', 'Did not qualify', '7,169'], ['2002', '2', 'USL A-League', '2nd, Pacific', '1st Round', 'Did not qualify', '6,260'], ['2003', '2', 'USL A-League', '3rd, Pacific', 'Did not qualify', 'Did not qualify', '5,871'], ['2004', '2', 'USL A-League', '1st, Western', 'Quarterfinals', '4th Round', '5,628'], ['2005', '2', 'USL First Division', '5th', 'Quarterfinals', '4th Round', '6,028'], ['2006', '2', 'USL First Division', '11th', 'Did not qualify', '3rd Round', '5,575'], ['2007', '2', 'USL First Division', '2nd', 'Semifinals', '2nd Round', '6,851'], ['2008', '2', 'USL First Division', '11th', 'Did not qualify', '1st Round', '8,567'], ['2009', '2', 'USL First Division', '1st', 'Semifinals', '3rd Round', '9,734'], ['2010', '2', 'USSF D-2 Pro League', '3rd, USL (3rd)', 'Quarterfinals', '3rd Round', '10,727']], 'name': 'csv/20

In [79]:
train['question'][0]

'what was the last year where this team was a part of the usl a-league?'

In [80]:
train['answers'][0]

['2004']

In [91]:
def concat_str(table: dict, question: str, answers: list, mode='train') -> str:
    target_str = ''
    target_str += 'Table: | '
    target_str += ' '.join(table['header'])
    target_str += ' | '
    
    target_str += ' | '.join([' '.join(row) for row in table['rows']])
    target_str += ' |'
    target_str += '\tQuestion:'
    target_str += f' {question}\t'
    if mode == 'train':
        target_str += 'Answers: '
        target_str += ' | '
        target_str += ' | '.join(answers)
        target_str += ' |'
    
    return target_str

In [92]:
print(concat_str(train['table'][0], train['question'][0], train['answers'][0]))

Table: | Year Division League Regular Season Playoffs Open Cup Avg. Attendance | 2001 2 USL A-League 4th, Western Quarterfinals Did not qualify 7,169 | 2002 2 USL A-League 2nd, Pacific 1st Round Did not qualify 6,260 | 2003 2 USL A-League 3rd, Pacific Did not qualify Did not qualify 5,871 | 2004 2 USL A-League 1st, Western Quarterfinals 4th Round 5,628 | 2005 2 USL First Division 5th Quarterfinals 4th Round 6,028 | 2006 2 USL First Division 11th Did not qualify 3rd Round 5,575 | 2007 2 USL First Division 2nd Semifinals 2nd Round 6,851 | 2008 2 USL First Division 11th Did not qualify 1st Round 8,567 | 2009 2 USL First Division 1st Semifinals 3rd Round 9,734 | 2010 2 USSF D-2 Pro League 3rd, USL (3rd) Quarterfinals 3rd Round 10,727 |	Question: what was the last year where this team was a part of the usl a-league?	Answers:  | 2004 |


In [103]:
SAVE_EPOCH = 10
PATH_TO_train = 'train_str.txt'
PATH_TO_test = 'test_str.txt'
PATH_TO_eval = 'eval_str.txt'

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger('parse data')

def dataset_parser(dataset, n) -> None:
    train, test, eval = dataset['train'], dataset['test'], dataset['validation']
    train_string = ''
    test_string = ''
    eval_string = ''
    
    logger.info('Begin to parse train data')
    for i in tqdm(range(train.num_rows)):
        train_string += concat_str(train['table'][i], train['question'][i], train['answers'][i])
        train_string += '\n'

        
        if i % SAVE_EPOCH == 0:
            with open(f'{n}_{PATH_TO_train}', 'a', encoding='utf-8') as train_file:
                train_file.write(train_string)
            
            train_string = ''
    
    with open(f'{n}_{PATH_TO_train}', 'a', encoding='utf-8') as train_file:
        train_file.write(train_string)
    logger.info('End to parse train data')
    
    logger.info('Begin to parse test data')
    for i in tqdm(range(test.num_rows)):
        test_string += concat_str(test['table'][i], test['question'][i], test['answers'][i], 'test')
        test_string += '\n'
        
        if i % SAVE_EPOCH == 0:
            with open(f'{n}_{PATH_TO_test}', 'a', encoding='utf-8') as test_file:
                test_file.write(test_string)
                
            test_string = ''

    with open(f'{n}_{PATH_TO_test}', 'a', encoding='utf-8') as test_file:
        test_file.write(test_string)
    logger.info('End to parse test data')
            
    logger.info('Begin to parse eval data')
    for i in tqdm(range(eval.num_rows)):
        eval_string += concat_str(eval['table'][i], eval['question'][i], eval['answers'][i], 'eval')
        eval_string += '\n'
        
        if i % SAVE_EPOCH == 0:
            with open(f'{n}_{PATH_TO_eval}', 'a', encoding='utf-8') as eval_file:
                eval_file.write(eval_string)
                
            eval_string = ''
            
    with open(f'{n}_{PATH_TO_eval}', 'a', encoding='utf-8') as eval_file:
        eval_file.write(eval_string)
    logger.info('End to parse eval data')

In [104]:
dataset = load_dataset("wikitablequestions", "random-split-1", trust_remote_code=True)

In [107]:
dataset_parser(dataset, 1)

INFO:parse data:Begin to parse train data


  1%|▏         | 164/11321 [06:10<7:00:02,  2.26s/it] 


KeyboardInterrupt: 