In [1]:
import json

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import T5Tokenizer, T5ForConditionalGeneration

from preprocess_utils import get_highlighted_subtable, linearize_subtable

In [2]:
# Run on the first GPU when one is available; otherwise fall back to the CPU so
# the notebook still executes (slowly) on machines without CUDA.
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Number of dev examples per generation batch.
batch_size=8

In [3]:
# Table-structure tags registered as special tokens so the tokenizer never
# splits them (presumably the markup emitted by linearize_subtable).
table_tag_tokens=[
    '<page_title>', '</page_title>',
    '<section_title>', '</section_title>',
    '<table>', '</table>',
    '<cell>', '</cell>',
    '<col_header>', '</col_header>',
    '<row_header>', '</row_header>',
]

# Pre-trained T5 tokenizer, extended with the tags above.
tokenizer=T5Tokenizer.from_pretrained('t5-large')
tokenizer.add_special_tokens({'additional_special_tokens': table_tag_tokens})

# Pre-trained T5 LM on `device`; its embedding matrix is grown to the new
# vocabulary size (the cell displays the resized Embedding module).
pretrained=T5ForConditionalGeneration.from_pretrained('t5-large').to(device)
pretrained.resize_token_embeddings(len(tokenizer))

Embedding(32112, 1024)

In [4]:
class ToTToDataset(Dataset):
    """
    ToTTo dataset for evaluation (dev set): encoded model inputs only, no targets.

    Each non-blank line of the JSONL file is parsed. Its highlighted cells are
    extracted with get_highlighted_subtable (with_heuristic_headers=True) and
    linearized with the page and section titles by linearize_subtable. The
    result is encoded with `tokenizer`. Encodings longer than `max_len` tokens
    are cut to `max_len - 1` tokens plus `tokenizer.eos_token_id`, so a
    truncated sequence still ends with EOS.

    Args:
        path_data: Path to a JSONL file, one ToTTo example per line.
        tokenizer: Tokenizer providing `encode` and `eos_token_id`.
        max_len: Maximum encoded length in tokens (default 512).
    """
    def __init__(self, path_data, tokenizer, max_len=512):
        # Encoded inputs (lists of token ids), in file order.
        self.data=[]

        # Load Dataset. Explicit UTF-8 so non-ASCII table text does not depend
        # on the platform's default encoding; `with` closes the file.
        with open(path_data, 'r', encoding='utf-8') as f:
            lines=f.read().splitlines()

        for line in lines:
            # Skip blank lines instead of failing inside json.loads.
            if not line.strip():
                continue
            data=json.loads(line)

            # Preprocess: highlighted subtable -> linearized string
            subtable=get_highlighted_subtable(table=data['table'], cell_indices=data['highlighted_cells'], with_heuristic_headers=True)
            cells_linearized=linearize_subtable(
                subtable=subtable,
                table_page_title=data['table_page_title'],
                table_section_title=data['table_section_title']
            )

            # Encode
            encoded=tokenizer.encode(cells_linearized)
            if len(encoded)>max_len:
                # Truncate, keeping room for a final EOS token.
                encoded=encoded[:max_len-1]+[tokenizer.eos_token_id]
            self.data.append(encoded)

        print(len(self.data), 'datas')

    def __getitem__(self, idx):
        # Return a copy so consumers that pad/extend the list in place (e.g. a
        # collate_fn) cannot alter the cached encoding for later passes.
        return list(self.data[idx])

    def __len__(self):
        return len(self.data)

In [5]:
def collate_fn(batch, pad_token_id=None):
    """
    Right-pad a batch of token-id lists to the length of its longest member.

    The input lists are NOT modified: padding is applied to copies. This keeps
    samples cached by the dataset at their original length across epochs and
    repeated generation runs.

    Args:
        batch: List of token-id lists (one per example).
        pad_token_id: Id used for padding. Defaults to the global
            `tokenizer.pad_token_id`, matching the previous behavior.

    Returns:
        (input_ids, attention_mask): two integer tensors of shape
        (len(batch), longest_length). The mask is 1 on real tokens and 0 on
        padding.
    """
    if pad_token_id is None:
        pad_token_id=tokenizer.pad_token_id

    max_len=max((len(seq) for seq in batch), default=0)

    input_ids=[]
    attn_masks=[]
    for seq in batch:
        n_pad=max_len-len(seq)
        input_ids.append(list(seq)+[pad_token_id]*n_pad)
        # Mask from the true length, so it is independent of token values.
        attn_masks.append([1]*len(seq)+[0]*n_pad)

    return torch.tensor(input_ids), torch.tensor(attn_masks)

In [6]:
# ToTTo dev split, batched in file order (shuffle=False) so generations are
# written in the same order as the examples in the JSONL file.
dev_data_path='../totto_data/totto_dev_data.jsonl'
dataset_dev=ToTToDataset(path_data=dev_data_path, tokenizer=tokenizer)
dataloader_dev=DataLoader(
    dataset_dev,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

Token indices sequence length is longer than the specified maximum sequence length for this model (1344 > 512). Running this sequence through the model will result in indexing errors


7700 datas


In [7]:
class PromptTuning(nn.Module):
    """
    Trainable soft prompt with an MLP reparameterization.

    `prompt_len` learned vectors of size `d_model` pass through a tanh MLP
    (d_model -> hidden_dim -> hidden_dim -> d_model). The result is repeated
    for every example in the batch.

    Args:
        pretrained_config: Config of the pre-trained LM; only `d_model` is read.
        prompt_len: Number of soft-prompt positions.
        hidden_dim: Width of the MLP's hidden layers.
    """
    def __init__(self, pretrained_config, prompt_len=20, hidden_dim=256):
        super().__init__()

        # Config of Pre-Trained LM
        self.pretrained_config=pretrained_config

        # Position ids torch.tensor([0, 1, .., prompt_len-1]). Registered as a
        # non-persistent buffer so module.to(device) moves it, while keeping it
        # out of state_dict (same checkpoint keys as a plain attribute).
        self.register_buffer('pre_prompt', torch.arange(prompt_len), persistent=False)
        # Embedding
        self.embd=nn.Embedding(num_embeddings=prompt_len, embedding_dim=pretrained_config.d_model)
        # Reparameterization
        self.reparam=nn.Sequential(
            nn.Linear(pretrained_config.d_model, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, pretrained_config.d_model)
        )

    def forward(self, batch_size, device):
        """Return the soft prompt, shape (batch_size, prompt_len, d_model), on `device`."""
        # `.to(device)` is kept: instances unpickled from checkpoints saved with
        # the earlier definition hold pre_prompt as a plain attribute that
        # module.to() does not move. It is a no-op when already on `device`.
        # Shape: batch_size, prompt_len
        prompt_ids=self.pre_prompt.to(device).unsqueeze(0).expand(batch_size, -1)
        # Shape: batch_size, prompt_len, d_model
        prompt=self.embd(prompt_ids)
        # Shape: batch_size, prompt_len, d_model
        prompt=self.reparam(prompt)

        return prompt

In [8]:
# Trained soft-prompt module; torch.load returns the module object itself.
# NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
# map_location lets a checkpoint saved on GPU load when `device` is the CPU.
model=torch.load(
    '../model/T5_Prompt-Tuning_on_ToTTo(Subtable)_promptlen20_batch96_epoch3of3_lr0.3.pt',
    map_location=device,
)
model=model.to(device)
model.eval()

# Generation: one line per dev example, in dataloader order. The file is opened
# with 'w' (not 'a') so re-running this cell overwrites the predictions instead
# of appending a second copy, which would misalign them with the targets.
# `with` guarantees the file is closed even if generation fails midway.
with torch.no_grad(), open('../totto_data/generation_dev_prompt_tuning.txt', 'w', encoding='utf-8') as f:
    for idx, (data, attn_mask) in enumerate(dataloader_dev):
        if (idx+1)%100==0: print(batch_size*(idx+1), 'generated')

        data=data.to(device)
        attn_mask=attn_mask.to(device)

        # Get Prompt: one copy per example (the last batch may be smaller).
        prompt=model(batch_size=data.shape[0], device=device)

        # Beam Search. The collate_fn mask is passed explicitly so padded
        # positions are masked, rather than relying on generate() to infer a
        # mask. `prompt=` is forwarded through generate()'s **model_kwargs;
        # NOTE(review): presumably the T5 model in use is patched to consume it.
        outputs=pretrained.generate(
            data,
            attention_mask=attn_mask,
            max_length=300,
            num_beams=5,
            early_stopping=True,
            prompt=prompt
        )

        for generation in tokenizer.batch_decode(outputs, skip_special_tokens=True):
            f.write(generation+'\n')

800 generated
1600 generated
2400 generated
3200 generated
4000 generated
4800 generated
5600 generated
6400 generated
7200 generated


In [9]:
# Evaluation: run language/totto/totto_eval.sh on the generated predictions
# against the dev JSONL (BLEU and PARENT overall and per subset, see output).
# The command first cd's into ../language_repo/, so both paths resolve relative
# to it; ../language_repo/../totto_data is the same ../totto_data used above.
!cd ../language_repo/ && bash language/totto/totto_eval.sh --prediction_path ../totto_data/generation_dev_prompt_tuning.txt --target_path ../totto_data/totto_dev_data.jsonl

Running with the following variables:
PREDICTION_PATH   : ../totto_data/generation_dev_prompt_tuning.txt
TARGET_PATH       : ../totto_data/totto_dev_data.jsonl 
BLEURT_CKPT       : unset 
OUTPUT_DIR        : temp
MODE              : test
Writing references.
Writing tables in PARENT format.
Preparing predictions.
Writing predictions.
Running detokenizers.
Computing BLEU (overall)
BLEU+case.mixed+numrefs.3+smooth.exp+tok.13a+version.1.5.1 = 35.5 71.2/44.3/30.6/21.8 (BP = 0.932 ratio = 0.934 hyp_len = 120337 ref_len = 128780)
Computing PARENT (overall)
Evaluated 7700 examples.
Precision = 79.36 Recall = 41.00 F-score = 49.33
Computing BLEU (overlap subset)
BLEU+case.mixed+numrefs.3+smooth.exp+tok.13a+version.1.5.1 = 40.0 74.2/48.7/35.4/26.3 (BP = 0.933 ratio = 0.935 hyp_len = 58729 ref_len = 62783)
Computing PARENT (overlap subset)
Evaluated 3784 examples.
Precision = 80.96 Recall = 43.72 F-score = 52.18
Computing BLEU (non-overlap subset)
BLEU+case.mixed+numrefs.3+smooth.exp+tok.13a+vers