In [1]:
# Intended upper bounds (in tokens) for the source text and the target logical form.
MAX_SOURCE_LENGTH = 1024
MAX_TARGET_LENGTH = 1024
# Forwarded as the tokenizer's `padding=` argument during preprocessing.
USES_PADDING = "max_length"  # alternative previously tried: False
# When True, pad positions in the labels are replaced by -100 during preprocessing.
DO_IGNORE_PAD_TOKEN_FOR_LOSS = True

In [2]:
from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
# Checkpoint identifier shared by the tokenizer and the model configuration.
model_name = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
# Instantiated from the configuration alone via `from_config` (not `from_pretrained`).
# NOTE(review): this presumably leaves the weights randomly initialised; the
# fine-tuned checkpoint is loaded separately further down the notebook.
model = AutoModelForSeq2SeqLM.from_config(config)

In [3]:
def preprocess_function(data):
    """Tokenise a batch of (text, logical_form) pairs for seq2seq training.

    Args:
        data: A batch mapping with the keys ``'text'`` and ``'logical_form'``,
            each holding a list of strings (the shape ``Dataset.map`` passes
            when ``batched=True``).

    Returns:
        dict with ``'input_ids'`` and ``'attention_mask'`` for the source text
        and ``'labels'`` holding the target token ids. When
        ``DO_IGNORE_PAD_TOKEN_FOR_LOSS`` is set, pad ids in the labels are
        replaced by -100.
    """
    # Pass the configured length limits explicitly instead of silently relying
    # on whatever default maximum the tokenizer happens to carry.
    input_encodings = tokenizer.batch_encode_plus(
        data['text'],
        padding=USES_PADDING,
        truncation=True,
        max_length=MAX_SOURCE_LENGTH,
    )
    target_encodings = tokenizer.batch_encode_plus(
        data['logical_form'],
        padding=USES_PADDING,
        truncation=True,
        max_length=MAX_TARGET_LENGTH,
    )

    labels = target_encodings['input_ids']
    if DO_IGNORE_PAD_TOKEN_FOR_LOSS:
        # Mask pad positions for every padding mode; when no padding was
        # applied there are no pad ids and this is a no-op.
        pad_id = tokenizer.pad_token_id
        labels = [
            [(token if token != pad_id else -100) for token in sequence]
            for sequence in labels
        ]

    return {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'labels': labels,
    }

In [10]:
# Predicate names (the part of a logical form before its first "(") whose
# examples are routed to the held-out "unseen" set instead of train/val.
EXCLUDE_TYPES = [
    'designates_at_that',
    'is_a_party_of',
    'is_a_member_of',
    'is_a_small_business_entity_as_informed_by',
    'it_is_the_end_of',
    'has_paid_itself',
    'end_of_the_world',
]
# Text file read line by line to obtain the (logical form, text) pairs.
SATOH_DATAPATH = '../Datasets/pair_synset_by_lesk.txt'

import re
from time import sleep
import torch
from datasets.arrow_dataset import Dataset
from datasets import concatenate_datasets


def get_satoh_data_unseen(val_percentage=0.1, seed=None):
    """Load the Satoh pairs and split them into train / val / unseen-test sets.

    Every usable line of ``SATOH_DATAPATH`` must contain at least four double
    quotes: the span between the 1st and 2nd quote is the logical form, the
    span between the 3rd and 4th quote is the natural-language text. Lines
    with fewer quotes (e.g. blank lines) are skipped.

    Pairs whose predicate name (the logical form up to its first ``"("``) is
    listed in ``EXCLUDE_TYPES`` are kept out of train/val and become the
    ``"test"`` split.

    Args:
        val_percentage: Fraction of the remaining ("seen") pairs used for
            validation.
        seed: Optional seed forwarded to ``train_test_split`` so the splits
            can be reproduced. ``None`` keeps the splits random.

    Returns:
        A dataset dict with the keys ``"train"``, ``"test"`` (unseen
        predicates only) and ``"val"``; each split has the columns ``"text"``
        and ``"logical_form"``.

    Raises:
        ValueError: If there are too few seen pairs to carve out the
            validation share and the discarded hold-out.
    """
    seen_x, seen_y = [], []
    unseen_x, unseen_y = [], []

    with open(SATOH_DATAPATH, 'r') as datastream:
        for line in datastream:
            fields = line.split('"')
            if len(fields) < 5:
                # Fewer than four quotes: no complete pair on this line.
                continue
            logical_form, text = fields[1], fields[3]

            # Route each pair exactly once, keeping text and logical form
            # together, so duplicates can never misalign the two columns.
            predicate = logical_form.partition("(")[0]
            if predicate in EXCLUDE_TYPES:
                unseen_x.append(text)
                unseen_y.append(logical_form)
            else:
                seen_x.append(text)
                seen_y.append(logical_form)

    dataset = Dataset.from_dict({"text": seen_x, "logical_form": seen_y})
    unseen_dataset = Dataset.from_dict({"text": unseen_x, "logical_form": unseen_y})

    val_size = int(val_percentage * len(dataset))
    # A fixed number of seen pairs is held out and then dropped (the test
    # split consists of unseen predicates only); kept so that the train and
    # val sizes stay what they have always been.
    test_size = 50
    train_size = len(dataset) - val_size - test_size
    if train_size <= 0:
        raise ValueError(
            "Not enough seen examples (%d) for a validation share of %d plus "
            "a hold-out of %d." % (len(dataset), val_size, test_size)
        )

    dataset_dict = dataset.train_test_split(
        train_size=train_size, test_size=(val_size + test_size), seed=seed
    )
    held_out = dataset_dict["test"].train_test_split(
        train_size=val_size, test_size=test_size, seed=seed
    )

    dataset_dict["val"] = held_out["train"]
    dataset_dict["test"] = unseen_dataset

    return dataset_dict

In [11]:
# Build the splits and tokenise them batch-wise in one chained expression.
datasets = get_satoh_data_unseen().map(preprocess_function, batched=True)
# Expose only the tokenised columns, formatted as torch tensors.
columns = ['input_ids', 'labels', 'attention_mask']
datasets.set_format(type='torch', columns=columns)

  0%|          | 0/10 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

In [12]:
def batch_predict(dataset, trainer, tokenizer, limit=50):
    """Predict logical forms for up to ``limit`` examples, one at a time.

    Examples are taken in dataset order and fed to ``trainer.predict``
    individually, so only one example's logits are held at any moment. The
    three returned lists always have the same length and are index-aligned.

    Args:
        dataset: Dataset exposing ``__len__``, ``select`` and column access
            for ``"text"`` and ``"logical_form"``.
        trainer: Object whose ``predict(dataset)`` result has a
            ``predictions`` attribute; ``predictions[0][0]`` is read as the
            (position, vocabulary) score matrix of the single example.
        tokenizer: Object providing ``convert_ids_to_tokens(int)``.
        limit: Maximum number of examples to process; ``None`` means all.

    Returns:
        Tuple ``(original_texts, original_logical_forms,
        predicted_logical_forms)``. Each prediction is the concatenation of
        the raw tokens up to (excluding) the first ``"</s>"``.
    """
    import numpy as np  # local import: the notebook does not import numpy at top level

    total = len(dataset)
    count = total if limit is None else max(0, min(limit, total))

    original_texts = []
    original_logical_forms = []
    predicted_logical_forms = []

    for index in range(count):
        example = dataset.select([index])
        original_texts.append(example["text"][0])
        original_logical_forms.append(example["logical_form"][0])

        # NOTE(review): when the dataset carries a `labels` column,
        # trainer.predict presumably runs a teacher-forced forward pass, so
        # this is a per-position argmax rather than free-running generation.
        # Confirm whether `generate` was intended.
        predictions = trainer.predict(example).predictions
        scores = np.asarray(predictions[0][0])

        # Skip the first position, then take the best-scoring id per position.
        # tolist() yields plain Python ints, which convert_ids_to_tokens expects.
        token_ids = scores[1:].argmax(axis=-1).tolist()

        tokens = []
        for token_id in token_ids:
            token = tokenizer.convert_ids_to_tokens(token_id)
            if token == "</s>":
                break
            tokens.append(token)

        predicted_logical_forms.append("".join(tokens))

    return original_texts, original_logical_forms, predicted_logical_forms

In [13]:
from transformers import Trainer, TrainingArguments

# Trainer configuration; the output directory is derived from the checkpoint name.
training_args = TrainingArguments(
    output_dir=f"./models/{model_name}-legal-unseen",
    num_train_epochs=5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir='./logs',
    save_steps=5000,
)

In [16]:
# Load the fine-tuned checkpoint from disk. `from_pretrained` is invoked through
# the existing `model` instance, so the loaded model has that instance's class.
our_model = model.from_pretrained("../trained_models/bart-satoh-unseen")

# Wrap the loaded model together with the training arguments and the
# train / validation splits.
trainer = Trainer(
    model=our_model,
    args=training_args,
    train_dataset=datasets['train'],
    eval_dataset=datasets['val'],
)

In [17]:
# Carve two 37-example parts out of the test split; this cell predicts the first part.
split = datasets["test"].train_test_split(train_size=37, test_size=37)

(
    original_texts1,
    original_logical_forms1,
    predicted_logical_forms1,
) = batch_predict(split["train"], trainer, tokenizer, limit=37)

***** Running Prediction *****
  Num examples = 1
  Batch size = 1


***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size

In [19]:
# Predict the second 37-example part of the split.
(
    original_texts2,
    original_logical_forms2,
    predicted_logical_forms2,
) = batch_predict(split["test"], trainer, tokenizer, limit=37)

***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size = 1
***** Running Prediction *****
  Num examples = 1
  Batch size

In [26]:
# Stitch the two prediction runs back together, first part followed by second.
original_texts = original_texts1 + original_texts2
original_logical_forms = original_logical_forms1 + original_logical_forms2

# Replace the "Ġ" marker in the predicted token strings with a plain space.
predicted_logical_forms = [
    pred.replace("Ġ", " ")
    for pred in predicted_logical_forms1 + predicted_logical_forms2
]

In [27]:
import pandas as pd

# One row per example: source text, reference logical form, predicted logical form.
df = pd.DataFrame(
    list(zip(original_texts, original_logical_forms, predicted_logical_forms)),
    columns=['Original Texts', 'Original logical forms', 'Predicted logical forms'],
)

In [28]:
# Write the comparison table to an Excel workbook; the path is relative to the
# current working directory.
df.to_excel("predicted_test_dataset.xlsx")