In [1]:
import json
import os

import torch
from torch.utils.data import Dataset, DataLoader

from transformers import T5Tokenizer

from preprocess_utils import get_highlighted_subtable, linearize_subtable

In [2]:
# Inference device: use the GPU index of the original run when it exists, otherwise fall
# back to the default GPU or the CPU so the notebook still runs on smaller machines.
if torch.cuda.device_count() > 2:
    device = torch.device('cuda:2')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

batch_size = 24  # 10 for 't5-large'

In [3]:
# Pre-trained T5 tokenizer extended with table-markup tags.
# NOTE(review): presumably these are the tags produced by linearize_subtable, and they must
# be added in the same order as when the checkpoint was fine-tuned (order decides the new
# token ids) — confirm against the training notebook.
tokenizer = T5Tokenizer.from_pretrained('t5-base')

TABLE_TAG_PAIRS = [
    ('<page_title>', '</page_title>'),
    ('<section_title>', '</section_title>'),
    ('<table>', '</table>'),
    ('<cell>', '</cell>'),
    ('<col_header>', '</col_header>'),
    ('<row_header>', '</row_header>'),
]

# Flatten to [open, close, open, close, ...]; the returned count of added tokens is displayed.
tokenizer.add_special_tokens({
    'additional_special_tokens': [tag for pair in TABLE_TAG_PAIRS for tag in pair]
})

12

In [4]:
class ToTToDataset(Dataset):
    """
    For Evaluation (Dev Set): encoded, linearized highlighted subtables.

    Each non-blank line of the JSONL file at ``path_data`` is parsed as one example.
    Its highlighted cells are extracted into a subtable (with heuristic headers) and
    linearized together with the page/section titles. The result is then encoded with
    ``tokenizer``. Items are lists of token ids; no padding is applied here.

    Args:
        path_data: path to a JSONL file, one JSON example per line.
        tokenizer: object providing ``encode`` and ``eos_token_id``.
        max_length: maximum number of token ids kept per example (default 512).
            Longer encodings are cut and terminated with ``eos_token_id``.
    """
    def __init__(self, path_data, tokenizer, max_length=512):
        self.data = []

        # Explicit UTF-8 so reading does not depend on the platform's default locale.
        with open(path_data, 'r', encoding='utf-8') as f:
            lines = f.read().splitlines()

        for line in lines:
            if not line.strip():
                # Tolerate blank lines instead of crashing in json.loads.
                continue
            example = json.loads(line)

            # Preprocess
            subtable = get_highlighted_subtable(
                table=example['table'],
                cell_indices=example['highlighted_cells'],
                with_heuristic_headers=True,
            )
            cells_linearized = linearize_subtable(
                subtable=subtable,
                table_page_title=example['table_page_title'],
                table_section_title=example['table_section_title'],
            )

            # Encode, then cap the length while keeping EOS as the final token.
            encoded = tokenizer.encode(cells_linearized)
            self.data.append(self._truncate(encoded, max_length, tokenizer.eos_token_id))

        print(len(self.data), 'datas')

    @staticmethod
    def _truncate(ids, max_length, eos_token_id):
        """Return ``ids`` capped at ``max_length`` tokens; when cut, the last token is EOS."""
        if len(ids) > max_length:
            return ids[:max_length - 1] + [eos_token_id]
        return ids

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

In [5]:
def collate_fn(batch, pad_token_id=None):
    """
    Pad a batch of token-id lists to the same sequence length.

    Args:
        batch: list of token-id lists (e.g. items of ``ToTToDataset``). Not modified.
        pad_token_id: id used for padding. Defaults to the global ``tokenizer``'s
            ``pad_token_id``, so the function can be passed directly to ``DataLoader``.

    Returns:
        (input_ids, attention_mask): for a non-empty batch, two integer tensors of shape
        (len(batch), longest sequence length). The mask is 1 for real tokens and 0 for padding.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    max_len = max((len(ids) for ids in batch), default=0)

    input_ids = []
    attention_mask = []
    for ids in batch:
        n_pad = max_len - len(ids)
        # Build new lists: extending `ids` in place would permanently pad the dataset's
        # stored examples, because the dataset hands them out by reference.
        input_ids.append(list(ids) + [pad_token_id] * n_pad)
        # Mask by position rather than by comparing to pad_token_id, so a real token that
        # happens to share the pad id is never masked out.
        attention_mask.append([1] * len(ids) + [0] * n_pad)

    return torch.tensor(input_ids), torch.tensor(attention_mask)

In [6]:
# Dev split: every example is encoded once up front. Iteration order is kept
# (shuffle=False) so generated lines follow the order of the source JSONL file.
DEV_DATA_PATH = '../totto_data/totto_dev_data.jsonl'

dataset_dev = ToTToDataset(path_data=DEV_DATA_PATH, tokenizer=tokenizer)
dataloader_dev = DataLoader(
    dataset_dev,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

Token indices sequence length is longer than the specified maximum sequence length for this model (1344 > 512). Running this sequence through the model will result in indexing errors


7700 datas


In [7]:
# Trained model: the checkpoint holds the whole model object (it is used directly below).
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True, which
# rejects full-model pickles; pass weights_only=False there if loading fails.
MODEL_PATH = '../model/T5-base_Fine-Tuning_lr0.0001_batch24_epoch9of10.pt'
GENERATION_PATH = '../totto_data/generation_dev.txt'

# map_location loads tensors straight onto `device` instead of the GPU they were saved from.
model = torch.load(MODEL_PATH, map_location=device)
model = model.to(device)
model.eval()

# Generation: 'w' truncates any previous output, and the context manager closes the file
# even if generation raises part-way through.
with open(GENERATION_PATH, 'w', encoding='utf-8') as f, torch.no_grad():
    for idx, (data, attn_mask) in enumerate(dataloader_dev):
        if (idx + 1) % 100 == 0:
            print(batch_size * (idx + 1), 'generated')

        data = data.to(device)
        attn_mask = attn_mask.to(device)

        # Beam search; the attention mask keeps padded positions from being attended to.
        outputs = model.generate(
            data,
            attention_mask=attn_mask,
            max_length=300,
            num_beams=5,
            early_stopping=True,
        )

        # One line per example, in dataloader order.
        for generation in tokenizer.batch_decode(outputs, skip_special_tokens=True):
            f.write(generation + '\n')

2400 generated
4800 generated
7200 generated


In [8]:
# Evaluation: run language/totto/totto_eval.sh on the generated predictions against the dev
# references. The output below reports BLEU and PARENT (overall, overlap, non-overlap).
# BLEURT_CKPT is reported as unset, so presumably no BLEURT score is computed.
# The relative paths are resolved from ../language_repo/ because of the leading `cd`.
!cd ../language_repo/ && bash language/totto/totto_eval.sh --prediction_path ../totto_data/generation_dev.txt --target_path ../totto_data/totto_dev_data.jsonl

Running with the following variables:
PREDICTION_PATH   : ../totto_data/generation_dev.txt
TARGET_PATH       : ../totto_data/totto_dev_data.jsonl 
BLEURT_CKPT       : unset 
OUTPUT_DIR        : temp
MODE              : test
Writing references.
Writing tables in PARENT format.
Preparing predictions.
Writing predictions.
Running detokenizers.
Computing BLEU (overall)
BLEU+case.mixed+numrefs.3+smooth.exp+tok.13a+version.1.5.1 = 48.8 78.4/55.9/43.0/34.1 (BP = 0.970 ratio = 0.971 hyp_len = 125969 ref_len = 129793)
Computing PARENT (overall)
Evaluated 7700 examples.
Precision = 81.09 Recall = 50.54 F-score = 58.50
Computing BLEU (overlap subset)
BLEU+case.mixed+numrefs.3+smooth.exp+tok.13a+version.1.5.1 = 56.5 82.3/62.9/51.3/42.7 (BP = 0.973 ratio = 0.974 hyp_len = 61204 ref_len = 62867)
Computing PARENT (overlap subset)
Evaluated 3784 examples.
Precision = 83.20 Recall = 54.42 F-score = 62.43
Computing BLEU (non-overlap subset)
BLEU+case.mixed+numrefs.3+smooth.exp+tok.13a+version.1.5.1 = 41