In [None]:
# Imports
import os
import re
import sys
import csv
import json
import torch
import random
import numpy as np
import pandas as pd
from types import SimpleNamespace
from tqdm.notebook import tqdm

# Importing DL libs
import torch
import torch.nn.functional as F
from torch import cuda
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration, Adafactor
from transformers import RobertaTokenizerFast, RobertaForSequenceClassification
from transformers import Trainer as TransformersTrainer
from transformers import TrainingArguments as TransformersTrainingArguments
from sentence_transformers import SentenceTransformer
from sentence_transformers import util as st_util

# Importing entailment bank code
from entailment_baseline import CustomDataset, Trainer, EntailmentARCDataset, SemanticSearch #, PrefixConstrainedGenerator
from entailment_bank.eval import run_scorer
from retrieval_utils import convert_datapoint_sent_to_uuid
from retrieval_utils import convert_datapoint_to_sent_to_text, sent_text_as_counter
import base_utils

In [None]:
params = SimpleNamespace(**{
    # dataset; options: "arc_entail" (entailment bank)
    'dataset_name': 'arc_entail',
    # approach; options: "baseline", "iter" (iterative)
    'approach_name': 'iter',
    # options: "t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"
    'model_name': 't5-small',
    # options: "task_1", "task_2", "task_3"
    'task_name': 'task_2',
    # evaluate the model on test data instead of dev data
    'use_test_data': True,
    # path templates; the {placeholders} are filled in elsewhere
    'results_file_path': 'results/{dataset_name}/test_{approach_name}_{model_name}_{task_name}{suffix}.{extension}',
    'logs_file_path': 'logs/{dataset_name}/{prefix}{approach_name}_{model_name}_{task_name}_logs{suffix}.txt',
    'model_file_path': '../data/{dataset_name}/models/fine_tuned_{approach_name}_{model_name}_{task_name}{suffix}.pth',
    'device': 'cuda' if cuda.is_available() else 'cpu',
    # criterion for saving the best model on validation: None, 'loss' or 'acc'
    'save_min_val_error_type': 'acc',
    'wt_corpus_file_path': '../data/arc_entail/supporting_data/worldtree_corpus_sentences_extended.json',
    # upper bound on the number of proof steps predicted per example (iterative approach)
    'max_number_steps': 8,
    # add one training data point per intermediate proof step
    'use_data_step_aug': False,
    # ---------------- retrieval params ----------------
    'use_custom_retrieval': True,
    'max_retrieved_sentences': 25,
    # full list of sentence transformers: https://www.sbert.net/docs/pretrained_models.html
    'sent_trans_name': 'all-mpnet-base-v2',
    # set to None to skip the custom (fine-tuned) encoder; '%s' is a placeholder
    # (presumably for sent_trans_name — verify where it is formatted)
    'sent_trans_encoder_model_path': '../data/arc_entail/models/%s_fine_tuned_v6',
})

config = SimpleNamespace(**{
    'TRAIN_EPOCHS': 15,        # number of training epochs
    'LEARNING_RATE': 5e-5,     # optimizer learning rate
    'TRAIN_BATCH_SIZE': 4,     # batch size for training
    'VALID_BATCH_SIZE': 4,     # batch size for validation / step-wise prediction
    'SEED': 40,                # random seed (not applied in this cell)
    'MAX_LEN': 512,            # source length limit given to CustomDataset
    'SUMMARY_LEN': 256,        # target length limit; also used as generation max_length
})

In [None]:
class TrainerIterative(Trainer):
    '''
    Trainer that predicts entailment proofs iteratively: each model call predicts
    a single proof step, which is then written back into the context used to
    predict the following step.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_next_step_source_from_search(self, semantic_search, source,
                                         pred, sents_used = None):
        '''
        Hook that builds the source (context) used for the next proof step.

        This implementation returns ``source`` unchanged; trainers that retrieve
        a new context for every step override it.

        semantic_search: semantic search object (unused here)
        source: source text built from the current step
        pred: the current step's prediction (unused here)
        sents_used: sentences used in the proof so far (unused here)
        '''
        return source

    def create_new_context_from_retrieval(self, context, sents_used_lst):
        '''
        Rebuilds every context as "<text before first sentN:> <sentences used>".

        Used when the per-step context came from retrieval, so that the returned
        context lists the sentences each proof actually consumed instead of the
        original input sentences.

        context: list of source strings, one per example
        sents_used_lst: list (aligned with ``context``) of lists of used sentences
        Returns the list of rebuilt context strings.
        '''
        new_context = []
        print('combining context and sents used...')
        print('len(context) =', len(context), ' len(sents_used_lst) =', len(sents_used_lst))
        for context_iter, (context_item, sents_used) in enumerate(zip(context, sents_used_lst)):
            first_sent = re.search("sent[0-9]+: ", context_item)
            # None slices to the end, i.e. keeps the whole item when it has no sentences
            first_sent_pos = first_sent.span()[0] if first_sent else None
            hypothesis = context_item[:first_sent_pos].strip()
            new_context_item = hypothesis
            if len(sents_used) > 0:
                new_context_item += ' ' + ' '.join(sents_used)
            new_context.append(new_context_item)
            if context_iter < 5:
                # show the first few rewrites for inspection
                print('first_sent', first_sent)
                print('first_sent_pos', first_sent_pos)
                print('hypothesis', hypothesis)
                print('OLD CONTEXT:', context_item)
                print('NEW CONTEXT:', new_context_item)
        return new_context

    def predict_full_proof(self, loader, generation_args = None, verbose=True,
                           prefix_constrained_generator = None, semantic_search = None):
        '''
        Predicts proofs iteratively.

        For each step the model is called to predict the next proof step; the
        context is then updated for the following step (the predicted step is
        appended and the "sentN" premises it used are removed). A proof is
        finished when its predicted step does not contain exactly one '->',
        concludes the hypothesis, or was predicted from a context without any
        "sentN: " sentence; otherwise it continues for up to
        ``self.params.max_number_steps`` steps.

        loader: DataLoader over the full-proof inputs
        generation_args: generation keyword arguments passed to ``self.predict``
        verbose: print per-step diagnostics
        prefix_constrained_generator: optional generator; before each step its
            source text is set to the inputs of the loader being predicted
        semantic_search: optional search object; when given, the next step's
            source comes from ``get_next_step_source_from_search`` and the
            returned context is rebuilt from the sentences used in each proof

        Returns (predictions, actuals, context) with one entry per example of
        ``loader``; each prediction is the predicted steps joined by spaces.
        '''
        step_params = {
            'batch_size': self.config.VALID_BATCH_SIZE,
            'shuffle': False,
            'num_workers': 0
        }
        step_loader = loader
        dataset_len = len(loader.dataset)
        context = []; predictions = [''] * dataset_len; actuals = []
        # position in the current step's loader -> index of the original example
        unfinished_proof_idx = list(range(dataset_len))
        sents_used_lst = [[] for _ in range(dataset_len)]

        # iterate over multiple proof steps
        for step_it in range(self.params.max_number_steps):

            # the generator must see the inputs of the loader predicted in this
            # step; after the first step that is the rebuilt step_loader
            if prefix_constrained_generator is not None:
                prefix_constrained_generator.set_source_text(step_loader.dataset.source_text)

            # predict following entailment step
            step_predictions, step_actuals, step_context = self.predict(
                step_loader, generation_args)

            if verbose:
                print(f'RUNNING STEP: {step_it}')
                print(f'LEN unfinished_proof_idx: {len(unfinished_proof_idx)}')
                print(f'LEN step_predictions: {len(step_predictions)}')
                print(f'LEN step_context: {len(step_context)}')
                print(f'LEN predictions: {len(predictions)}')

            if step_it == 0:
                # save unmodified context and target
                actuals = step_actuals
                context = step_context
            step_source = []; step_target = []
            for pred_it, pred in enumerate(step_predictions):
                # remove any prefix before the first sent/int symbol of the prediction
                first_pos = re.search('sent|int', pred)
                if first_pos is not None:
                    old_pred = pred
                    pred = pred[first_pos.start():]
                    if verbose:
                        print(f'pred prefix removal: "{old_pred}" => "{pred}"')
                        print('---------')

                # update existing prediction
                pred_idx = unfinished_proof_idx[pred_it]
                predictions[pred_idx] += (' ' if step_it > 0 else '') + pred

                if verbose:
                    print(step_context[pred_it])
                    print(step_it, pred_it, pred_idx, pred)
                    print('FULL PRED:', predictions[pred_idx])
                    print('---------')

                # only a step with exactly one '->' ("premises -> conclusion") is applied
                premises = entailment = None
                if pred.count('->') == 1:
                    premises, entailment = pred.split('->')
                # test if there are no more premises (forces stop on following step in such case)
                no_more_sent = re.search("(sent)[0-9]+(: )", step_context[pred_it]) is None

                # create next source for step; endswith() also copes with an
                # empty prediction (indexing pred[-1] would raise)
                new_end_context = pred[:-1] if pred.endswith(';') else pred
                source = '%s, %s' % (step_context[pred_it].rstrip(), new_end_context)
                if premises is not None:
                    source, removed_sentences = EntailmentARCDatasetIterative.remove_premises_from_context(
                        source, premises, return_removed_sents = True)
                    if verbose:
                        print(f'removed_sentences: {" || ".join(removed_sentences)}')
                    sents_used_lst[pred_idx].extend(removed_sentences)

                if entailment is None or 'hypothesis' in entailment or no_more_sent:
                    # reached end of proof, remove from unfinished_proof_idx
                    unfinished_proof_idx[pred_it] = None
                    continue

                # did not reach end of proof, continue producing more steps
                if semantic_search is not None:
                    old_source = source
                    source = self.get_next_step_source_from_search(
                        semantic_search, source, pred,
                        sents_used_lst[pred_idx])
                    if verbose:
                        print('source change:')
                        print(old_source)
                        print('***********')
                        print(source)

                step_source.append(source)
                step_target.append(step_actuals[pred_it])

                if verbose:
                    print()

            if len(step_source) == 0:
                # produced the full proof for all inputs
                break

            # remove indexes of finished proofs
            unfinished_proof_idx = [i for i in unfinished_proof_idx if i is not None]

            # create new loader for next entailment step
            step_set = CustomDataset(step_source, step_target, self.tokenizer,
                                     self.config.MAX_LEN, self.config.SUMMARY_LEN)
            step_loader = DataLoader(step_set, **step_params)

        if semantic_search is not None:
            # the per-step context was replaced by retrieved sentences, so rebuild
            # the returned context from the sentences each proof used
            context = self.create_new_context_from_retrieval(context, sents_used_lst)

        return predictions, actuals, context
    
class TrainerIterativeMultiStepRet(TrainerIterative):
    '''
    Iterative trainer that retrieves a new set of context sentences (possible new leaves) for each step
    in the entailment process.
    '''

    def get_next_step_source_from_search(self, semantic_search, source,
                                         pred, sents_used = None):
        '''
        Replaces the "sentN: ..." sentences of ``source`` with newly retrieved ones.

        The retrieval query is "<intermediate conclusion>; <hypothesis prefix>".
        The conclusion is the text after the single "intN: " marker in ``pred``.
        The hypothesis prefix is the text of ``source`` before its first "sentN: ".
        If either part is missing or empty, ``source`` is returned unchanged.

        semantic_search: Semantic search object
        source: source (context) that needs to be changed
        pred: the current step's prediction (e.g. 'sent1 & sent2 -> int1: ...;')
        sents_used: list of sentences used in proof so far (e.g. 'sent1: sky is blue')
        Returns the new source string. It keeps the hypothesis prefix and the steps
        already applied (", ... -> ..."), and holds at most as many retrieved
        sentences as ``source`` had, skipping results already used in the proof.
        '''
        # lowercased so that the "already used" checks below are case-insensitive
        sents_used_txt = ' '.join(sents_used or []).lower()

        int_text = None
        if pred.count('->') == 1:
            _, entailment = pred.split('->')
            int_markers = list(re.finditer('int[0-9]+: ', entailment))
            if len(int_markers) == 1:
                # conclusion text that follows the "intN: " marker
                int_text = entailment[int_markers[0].end():].strip(' ;.,')

        hypot_text = None
        sent_matches = list(re.finditer("(sent)[0-9]+: ", source))
        if len(sent_matches) > 0:
            hypot_text = source[:sent_matches[0].start()].strip(' ;.,')

        if not int_text or not hypot_text:
            return source

        query = '; '.join([int_text, hypot_text])

        # perform the regular search of the parent class of semantic_search's
        # class (not the conditioned search) for this iteration step
        top_results = super(semantic_search.__class__, semantic_search).search(
            [query], top_k = 2 * self.params.max_retrieved_sentences)
        # make sure search results were not already used in proof
        top_results = [top for top in top_results[0]
                       if top.lower() not in sents_used_txt]

        # keep the steps already applied to the context (", sent1 & int1 -> ...")
        implication_text = ''
        match_implication = re.search(",(sent|[0-9]+| |int|&)+ -> ", source)
        if match_implication is not None:
            implication_text = source[match_implication.start():].strip()

        # replace all sents of the context with an equal number of retrieved ones,
        # avoiding symbol numbers used in previous proof steps
        num_sents = len(sent_matches)
        new_sents_symbs = [f'sent{sent_id}' for sent_id in range(1, 50)
                           if f'sent{sent_id}:' not in sents_used_txt][:num_sents]
        new_sents_text = ' '.join([f'{ns_symb}: {ns_text}'
                                   for ns_symb, ns_text in zip(new_sents_symbs,
                                                               top_results[:num_sents])])

        return hypot_text + ', ' + new_sents_text + implication_text

In [None]:
class EntailmentARCDatasetIterative(EntailmentARCDataset):
    '''
    Entailment bank dataset for iterative (step-by-step) proof generation.

    Step-level data (train/dev) pairs the context at each proof step with that
    single step as target. Proof-level data (dev/test) pairs the original context
    with the full proof as target.
    '''

    @staticmethod
    def remove_premises_from_context(context, premises, return_removed_sents = False):
        '''
        Removes the "sentN: ..." sentences referenced in ``premises`` from ``context``.

        context: source text with "sentN: <text>" sentences, possibly followed by
            already-applied steps (", sent1 & sent2 -> int1: ...")
        premises: premise part of a step, symbols joined by '&' (e.g. 'sent1 & int2 ')
        return_removed_sents: if True, also return the removed sentences
        Returns the new context, or (new_context, removed_sentences) when
        ``return_removed_sents`` is True. Only "sentN" premises are removed;
        "intN" premises are left in place.
        '''
        removed_sentences = []
        new_context = context
        for premise in premises.split('&'):
            premise = premise.strip()
            if not re.search('sent[0-9]+', premise):
                continue
            matches = list(re.finditer("(sent)[0-9]+: ", context))
            for match_idx, match in enumerate(matches):
                if f'{premise}: ' not in match.group():
                    continue
                # the sentence spans from its marker up to the next sentence marker,
                # else up to the first applied step (", ... -> "), else to the end
                s_s = match.span()[0]
                if match_idx + 1 < len(matches):
                    s_e = matches[match_idx + 1].span()[0]
                else:
                    match_implication = re.search(",(sent|[0-9]+| |int|&)+ -> ", context)
                    s_e = match_implication.span()[0] if match_implication is not None else len(context)
                sentence = context[s_s:s_e]
                new_context = new_context.replace(sentence, '')
                removed_sentences.append(sentence.strip(' ,;.'))
        new_context = new_context.replace(',,', ',')
        if return_removed_sents:
            return new_context, removed_sentences
        return new_context

    def get_source_text_step(self, data_points):
        '''
        Creates source text, breaking down each proof step by step and adding the
        steps to the working memory (context).

        Returns one 'hypothesis: <hypothesis>, <context>' string per proof step,
        aligned with ``get_target_text_step``.
        '''
        source_text = []
        for data_point in data_points:
            context = data_point['context']
            hypothesis = data_point['hypothesis']
            for step in data_point['proof'].split(';')[:-1]:
                step = step.strip()
                source_text.append('hypothesis: %s, %s' % (hypothesis, context))
                premises, entailment = step.split('->')
                if 'hypothesis' not in entailment:
                    # the next step sees this step's conclusion instead of the
                    # sentences it used as premises
                    context = context.rstrip() + ', ' + step
                    context = self.remove_premises_from_context(context, premises)
        return source_text

    def get_target_text_step(self, data_points):
        '''
        Creates target text, breaking down each proof step by step.

        Returns one '$step$ = <step>;' string per proof step.
        '''
        target_text = []
        for data_point in data_points:
            for step in data_point['proof'].split(';')[:-1]:
                target_text.append('$step$ = %s;' % (step.strip(),))
        return target_text

    def get_source_text_proof(self, data_points):
        '''
        Creates source text for the full proof input: one
        'hypothesis: <hypothesis>, <context>' string per data point.
        '''
        return ['hypothesis: %s, %s' % (data_point['hypothesis'], data_point['context'])
                for data_point in data_points]

    def get_target_text_proof(self, data_points):
        '''
        Creates target text for the full proof output: one '$proof$ = <proof>;'
        string per data point.
        '''
        return ['$proof$ = %s;' % (data_point['proof'],) for data_point in data_points]

    def get_step_augmented_dataset(self, data_points):
        '''
        Modifies the original entailment data by automatically creating augmented
        data points.

        Every intermediate step is treated as a possible new data point: its
        conclusion becomes the hypothesis and the proof is truncated at that step.
        The original data points are kept.
        '''
        new_data_points = []
        for data_point in data_points:
            # keep existing datapoint in augmented dataset
            new_data_points.append(data_point)
            steps = [s.strip() for s in data_point['proof'].split(';')]
            # assuming the proof ends with ';', the last two items are the step that
            # proves the hypothesis and an empty string; neither yields a new point
            for step_it, step in enumerate(steps[:-2]):
                new_data_point = dict(data_point)
                antecedent, consequent = step.split(' -> ')
                # maxsplit=1 keeps conclusions that themselves contain ': '
                new_data_point['hypothesis'] = consequent.split(': ', 1)[1]
                new_steps = steps[:step_it] + [f'{antecedent} -> hypothesis']
                new_data_point['proof'] = ' '.join([s + ';' for s in new_steps])
                new_data_points.append(new_data_point)
        return new_data_points

    def _make_custom_dataset(self, source_text, target_text, tokenizer):
        '''Wraps source/target texts in a CustomDataset with the configured lengths.'''
        return CustomDataset(source_text, target_text, tokenizer,
                             self.config.MAX_LEN, self.config.SUMMARY_LEN)

    def get_torch_dataloaders(self, task_name, tokenizer):
        '''
        Creation of Dataset and Dataloader for a certain entailment task.

        Sets ``self.train_data``, ``self.dev_data`` and ``self.test_data``.
        Returns (training_loader, val_loader, val_proof_loader, test_proof_loader):
        step-level train/dev loaders followed by full-proof dev/test loaders.
        '''
        self.train_data = self.data[task_name]['train']
        self.dev_data = self.data[task_name]['dev']
        self.test_data = self.data[task_name]['test']

        if task_name == 'task_3':
            # task_3 has no proof data, so use the training data from task_1
            self.train_data = self.data['task_1']['train']
            if self.semantic_search is not None:
                # updates context and proofs using custom search results
                self.dev_data = self.data['task_1']['dev']
                self.train_data = self.update_dataset_with_search(
                    self.train_data, include_existing_context = True
                )
                self.dev_data = self.update_dataset_with_search(
                    self.dev_data, include_existing_context = True
                )
                self.test_data = self.update_dataset_with_search(
                    self.test_data, include_existing_context = False
                )

        if self.params.use_data_step_aug:
            self.train_data = self.get_step_augmented_dataset(self.train_data)

        training_set = self._make_custom_dataset(
            self.get_source_text_step(self.train_data),
            self.get_target_text_step(self.train_data), tokenizer)
        val_set = self._make_custom_dataset(
            self.get_source_text_step(self.dev_data),
            self.get_target_text_step(self.dev_data), tokenizer)
        val_proof_set = self._make_custom_dataset(
            self.get_source_text_proof(self.dev_data),
            self.get_target_text_proof(self.dev_data), tokenizer)
        test_proof_set = self._make_custom_dataset(
            self.get_source_text_proof(self.test_data),
            self.get_target_text_proof(self.test_data), tokenizer)

        # Defining the parameters for creation of dataloaders
        train_params = {
            'batch_size': self.config.TRAIN_BATCH_SIZE,
            'shuffle': True,
            'num_workers': 0
            }

        val_params = {
            'batch_size': self.config.VALID_BATCH_SIZE,
            'shuffle': False,
            'num_workers': 0
            }

        # Creation of Dataloaders for training, validation and testing.
        training_loader = DataLoader(training_set, **train_params)
        val_loader = DataLoader(val_set, **val_params)
        val_proof_loader = DataLoader(val_proof_set, **val_params)
        test_proof_loader = DataLoader(test_proof_set, **val_params)

        return training_loader, val_loader, val_proof_loader, test_proof_loader

In [None]:
class PrefixConstrainedGenerator:
    '''
    Constrains the beam search to allowed tokens only at each step.
    Enforces the entailment dataset's expected output format (important for the
    evaluation code).
    '''

    def __init__(self, tokenizer, source_text, batch_size):
        # tokenizer used to encode symbol text and decode generated prefixes
        self.tokenizer = tokenizer
        # list of source (input) texts; indexed per example by the iterative fn
        self.source_text = source_text
        self.batch_size = batch_size
        self.batch_num = 0

    def get_first_token_id(self, text):
        '''Returns the id of the first token that ``text`` is tokenized into.'''
        toks = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
        return toks[0]

    def get_last_token_id(self, text):
        '''Returns the id of the last token that ``text`` is tokenized into.'''
        toks = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
        return toks[-1]

    def set_batch_number(self, batch_num):
        self.batch_num = batch_num

    def set_source_text(self, source_text):
        self.source_text = source_text

    def _all_token_ids(self):
        '''Every token id of the tokenizer vocabulary (i.e. no constraint).'''
        return list(range(self.tokenizer.vocab_size))

    def prefix_allowed_tokens_fn(self, batch_id, inputs_ids):
        '''
        Constrains the next token for beam search depending on the currently
        generated prefix (input_ids).
        The output is loosely formatted according to the dataset specification.
        '''
        prefix = self.tokenizer.decode(inputs_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        if prefix.strip() == '':
            return [self.get_first_token_id('sent')]
        if prefix.endswith(' & ') or prefix.endswith(' ; '):
            return [self.get_first_token_id('sent'), self.get_first_token_id('int')]
        if prefix.endswith('sent') or prefix.endswith('int'):
            return [self.get_last_token_id('sent' + str(num)) for num in range(10)]
        if prefix.endswith(' -> '):
            return [self.get_first_token_id('hypothesis'), self.get_first_token_id('int')]
        return self._all_token_ids()

    def iterative_prefix_allowed_tokens_fn(self, batch_id, inputs_ids):
        '''
        Constrains the next token for beam search depending on the currently
        generated prefix (input_ids), allowing only the sentN / intN symbols that
        are defined in this example's source text.
        The output is loosely formatted according to the dataset specification.

        When the source defines no symbol of the requested kind, the full
        vocabulary is returned: an empty list would make the constraint
        unsatisfiable (newer transformers versions raise a ValueError).
        '''
        prefix = self.tokenizer.decode(inputs_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        # assumes batch_num was set via set_batch_number for the current batch — verify
        source_idx = self.batch_size * self.batch_num + batch_id
        source_text = self.source_text[source_idx]
        # symbol numbers defined in the source, e.g. "sent3:" -> '3'
        available_sent_nums = [m.group(1) for m in re.finditer("sent([0-9]+):", source_text)]
        available_int_nums = [m.group(1) for m in re.finditer("int([0-9]+):", source_text)]

        if prefix.strip() == 'in':
            return [self.get_last_token_id('int')]
        if prefix.strip() == '' or prefix.endswith(' & ') or prefix.endswith(' ; '):
            return [self.get_first_token_id('sent'), self.get_first_token_id('int')]
        if prefix.endswith('sent'):
            allowed = list(set([self.get_last_token_id('sent' + num) for num in available_sent_nums]))
            return allowed or self._all_token_ids()
        if prefix.endswith('int') and not prefix.endswith('-> int'):
            allowed = list(set([self.get_last_token_id('int' + num) for num in available_int_nums]))
            return allowed or self._all_token_ids()
        if prefix.endswith(' -> '):
            return [self.get_first_token_id('hypothesis'), self.get_first_token_id('int')]
        if ' -> ' not in prefix:
            # before the '->' of a step, forbid the last token of 'int1:'
            # (presumably the ':' token for the tokenizer in use)
            all_toks = self._all_token_ids()
            all_toks.remove(self.get_last_token_id('int1:'))
            return all_toks
        return self._all_token_ids()

In [None]:
# %%writefile -a entailment_baseline.py

class IterativeSemanticSearch(SemanticSearch):
    '''
    Search that not only uses the query but also partial results as probes.
    '''

    ################################################

    def counter_jaccard_similarity(self, c1, c2):
        '''
        Multiset Jaccard similarity of two Counters: sum(c1 & c2) / sum(c1 | c2).

        Returns 0.0 when both counters are empty instead of dividing by zero.
        '''
        union_size = sum((c1 | c2).values())
        if union_size == 0:
            return 0.0
        return sum((c1 & c2).values()) / float(union_size)

    def construct_sent_context_mapping(self, matching_texts, sent_to_text,
                                       new_mapping_key, matching_uuids = None):
        '''
        Greedily aligns each sentence of ``sent_to_text`` with one of ``matching_texts``.

        All (sentence, matching text) pairs are ranked by the Jaccard similarity of
        their counters (from ``sent_text_as_counter``). Walking from most to least
        similar, a pair is accepted when neither the sentence nor the matching text
        has been assigned yet. The matched text is stored under ``new_mapping_key``
        and its uuid, if given, under ``new_mapping_key + '_uuid'``.

        matching_texts: iterable of candidate texts
        sent_to_text: dict of sentence key -> dict with 'text' and 'text_counter'
        new_mapping_key: key under which the matched text is stored
        matching_uuids: optional list of uuids aligned with ``matching_texts``
        Returns ``sent_to_text`` (updated in place).
        Raises AssertionError if some sentence could not be matched.
        '''
        alignment = []
        for match_it, matching_text in enumerate(matching_texts):
            text_counter = sent_text_as_counter(matching_text)
            matching_uuid = None if matching_uuids is None else matching_uuids[match_it]
            for sent_k, sent_v in sent_to_text.items():
                alignment.append((sent_k, sent_v['text'], matching_text, match_it, matching_uuid,
                                  self.counter_jaccard_similarity(sent_v['text_counter'], text_counter)))
        sorted_alignment = sorted(alignment, key=lambda x: x[-1], reverse=True)
        matches_it_used = set()
        for sent_key, _, matching_text, match_it, matching_uuid, _ in sorted_alignment:
            if new_mapping_key in sent_to_text[sent_key] or match_it in matches_it_used:
                continue
            sent_to_text[sent_key][new_mapping_key] = matching_text
            if matching_uuid is not None:
                sent_to_text[sent_key][new_mapping_key + '_uuid'] = matching_uuid
            matches_it_used.add(match_it)

        assert all([new_mapping_key in v for v in sent_to_text.values()])
        return sent_to_text

    def construct_datapoint_context_mapping(self, datapoint):
        '''
        Maps each context sentence of ``datapoint`` to its closest triple
        ('triple_text') and WorldTree provenance text ('wt_p_text' /
        'wt_p_text_uuid'), then drops the temporary 'text_counter' entries.
        '''
        sent_to_text = convert_datapoint_to_sent_to_text(datapoint)
        triples = datapoint['meta']['triples'].values()
        assert len(sent_to_text.keys()) == len(triples)
        sent_to_text = self.construct_sent_context_mapping(
            triples, sent_to_text, new_mapping_key = 'triple_text')

        wt_provenance = datapoint['meta']['worldtree_provenance'].values()
        wt_p_items = [wt_p_item['original_text'] for wt_p_item in wt_provenance]
        wt_p_uuids = [wt_p_item['uuid'] for wt_p_item in wt_provenance]
        assert len(sent_to_text.keys()) == len(wt_p_items)
        sent_to_text = self.construct_sent_context_mapping(
            wt_p_items, sent_to_text, new_mapping_key = 'wt_p_text', matching_uuids = wt_p_uuids)

        for sent_to_text_v in sent_to_text.values():
            del sent_to_text_v['text_counter']
        return sent_to_text

    def create_context_mapping(self, dataset, verbose = False):
        '''
        Builds one context mapping per data point. Sentences whose provenance
        uuid is in the WorldTree corpus also get a 'wt_corpus_text' entry.
        '''
        context_mapping = []
        wt_corpus = self.load_wt_corpus_file()

        for datapoint in dataset:
            datapoint_context_mapping = self.construct_datapoint_context_mapping(datapoint)
            context_mapping.append(datapoint_context_mapping)

            for k, v in datapoint_context_mapping.items():
                if 'wt_p_text_uuid' in v and v['wt_p_text_uuid'] in wt_corpus:
                    datapoint_context_mapping[k]['wt_corpus_text'] = wt_corpus[v['wt_p_text_uuid']]

            if verbose:
                for k, v in datapoint_context_mapping.items():
                    for item_k, item_v in v.items():
                        print(item_k, '=', item_v)
                    print()
                print('======')
        return context_mapping

    def load_wt_corpus_with_dataset_context(self, dataset):
        '''
        Uses the WorldTree corpus minus every text aligned to the dataset's context
        sentences, plus the context sentences themselves, as the search corpus.
        '''
        context_mapping = self.create_context_mapping(dataset)
        removal_sents = [v for cm in context_mapping for p in cm.values()
                         for k, v in p.items() if k != 'text']
        include_sents = [v for cm in context_mapping for p in cm.values()
                         for k, v in p.items() if k == 'text']
        print(removal_sents[:20])
        wt_corpus = self.load_wt_corpus_file()
        removal_set = set(removal_sents)
        # dict.fromkeys de-duplicates while keeping a deterministic (first-seen)
        # order; iterating a set of strings depends on hash seeding
        corpus = [sent for sent in dict.fromkeys(wt_corpus.values())
                  if sent not in removal_set]
        corpus.extend(include_sents)
        self.update_corpus_embeddings(corpus)
        print('corpus size = ', len(self.corpus))

    ################################################

    def load_wt_corpus(self, extra_facts = None):
        '''
        Uses the WorldTree corpus (plus optional ``extra_facts``, de-duplicated
        in first-seen order) as the search corpus.
        '''
        wt_corpus = self.load_wt_corpus_file()
        corpus = list(wt_corpus.values())
        if extra_facts is not None:
            corpus.extend(extra_facts)
            corpus = list(dict.fromkeys(corpus))
        self.update_corpus_embeddings(corpus)

    def search(self, dataset, top_k = 25, keep_top_from_hyp = 15):
        '''
        Retrieves ``top_k`` sentences per data point in two stages.

        1. ``top_k - keep_top_from_hyp`` rounds of iterative search. Each round
           probes with question + answer plus every sentence retrieved so far,
           and adds the first returned sentence not selected yet.
        2. The remaining slots are filled from a search probed by the hypothesis only.

        dataset: list of dicts with 'question', 'answer' and 'hypothesis'
        Returns a list (one per data point) of lists of ``top_k`` sentences.
        Raises AssertionError if fewer than ``top_k`` distinct sentences were found.
        '''
        elements = [[] for _ in dataset]

        probes = [d['question'] + ' ' + d['answer'] for d in dataset]
        for _ in range(top_k - keep_top_from_hyp):
            temp_elements = super().search(probes, top_k = top_k * 2)
            for ret_it, rets in enumerate(temp_elements):
                for ret in rets:
                    if ret not in elements[ret_it]:
                        elements[ret_it].append(ret)
                        probes[ret_it] += ' ' + ret
                        break

        # now gather "keep_top_from_hyp" by using only hypothesis as probe
        probes = [d['hypothesis'] for d in dataset]
        temp_elements = super().search(probes, top_k = top_k * 3)
        for ret_it, rets in enumerate(temp_elements):
            for ret in rets:
                if ret not in elements[ret_it]:
                    elements[ret_it].append(ret)
                    if len(elements[ret_it]) == top_k:
                        break

        assert all([len(x) == top_k for x in elements])
        return elements

In [None]:
class EntailmentEvaluator():
    """Runs full-proof generation with the global ``trainer`` and writes the result files."""

    @staticmethod
    def _to_proof_format(predictions):
        """Prefix each prediction with '$proof$ = ' unless it already contains it
        (the format written to the result-only files)."""
        return [p if '$proof$ = ' in p else '$proof$ = ' + p for p in predictions]

    def run_prediction(self, params, proof_loader, use_prefix_const_generator = False,
                       verbose = False, temp_results = False, semantic_search = None):
        """Generate proofs for ``proof_loader`` and save them to disk.

        Args:
            params: run configuration (``use_test_data`` and result-path settings).
            proof_loader: DataLoader whose examples are decoded into proofs.
            use_prefix_const_generator: constrain decoding with a
                ``PrefixConstrainedGenerator`` built from ``proof_loader``'s source text.
            verbose: forwarded to ``trainer.predict_full_proof``.
            temp_results: if True, write only the result-only TSV to the temp
                results path (used for per-epoch validation). Otherwise write the
                full Input/Generated/Actual CSV and the result-only TSV.
            semantic_search: optional retriever forwarded to ``trainer.predict_full_proof``.

        Returns:
            (predictions, actuals, context) as returned by ``trainer.predict_full_proof``.
        """
        generation_args = {
            'max_length': config.SUMMARY_LEN,
            'num_beams': 5,
            'repetition_penalty': 2.5,
            'length_penalty': 1.0,
            'early_stopping': True,
        }
        # A branching-factor option is currently disabled: when
        # params.branching_factor_beta > 1 it would set both 'num_beams' and
        # 'num_return_sequences' to that value.

        prefix_constrained_generator = None
        if use_prefix_const_generator:
            # Build the constraints from the loader actually being decoded, so
            # they line up with proof_loader's examples (e.g. the test loader).
            # NOTE(review): the PrefixConstrainedGenerator import is commented
            # out at the top of the notebook; enable it before using this flag.
            prefix_constrained_generator = PrefixConstrainedGenerator(
                trainer.tokenizer, proof_loader.dataset.source_text,
                config.VALID_BATCH_SIZE)
            generation_args['prefix_allowed_tokens_fn'] = prefix_constrained_generator.prefix_allowed_tokens_fn

        predictions, actuals, context = trainer.predict_full_proof(
            proof_loader, generation_args, verbose = verbose,
            prefix_constrained_generator = prefix_constrained_generator,
            semantic_search = semantic_search)

        # entailment_bank evaluation format: one '$proof$ = ...' line per example
        results_only_df = pd.DataFrame(self._to_proof_format(predictions))

        if temp_results:
            prediction_file = base_utils.get_results_file_path(params, result_only=True, temp=True)
            print(f'writing results to file {prediction_file}')
            results_only_df.to_csv(prediction_file, sep='\t', header=False, index=False)
        else:
            # Write file including input and expected output
            final_df = pd.DataFrame({'Input': context, 'Generated Text': predictions, 'Actual Text': actuals})

            output_path = base_utils.get_results_file_path(params, test_split=params.use_test_data)
            print(f'writing full results to file {output_path}')
            final_df.to_csv(output_path)

            # Write file using entailment_bank evaluation format
            output_path = base_utils.get_results_file_path(
                params, test_split=params.use_test_data, result_only=True)
            print(f'writing results only to file {output_path}')
            results_only_df.to_csv(output_path, sep='\t', header=False, index=False)

        print('Output Files generated for review')

        return predictions, actuals, context

### Update Task-3 results

In [None]:
def update_task_3_results(use_test_data = None, results_path = None, 
                          uuid_results_path = None, verbose = True,
                          retrieved_context = None):
    """Rewrite a task_3 result-only file so that proofs use sentence UUIDs.

    Reads the result-only file (one proof per line, first tab-separated
    column) and pairs each proof with the matching dev/test datapoint. Each
    pair is converted with ``convert_datapoint_sent_to_uuid`` and the
    converted proofs are written to ``uuid_results_path``. The datapoints in
    ``entail_dataset`` are not modified; each one is copied first.

    Args:
        use_test_data: use ``entail_dataset.test_data`` instead of
            ``dev_data``. Defaults to ``params.use_test_data``.
        results_path: result-only file to read. Defaults to the result-only
            path for the chosen split.
        uuid_results_path: file to write. Defaults to the UUID result-only
            path for the chosen split.
        verbose: print details for the first datapoints.
        retrieved_context: optional per-datapoint contexts that replace each
            datapoint's 'context' before conversion. Must have one entry per
            datapoint.
    """
    print('updating task_3 results')
    if use_test_data is None:
        use_test_data = params.use_test_data

    if use_test_data:
        datapoints = entail_dataset.test_data
    else:
        datapoints = entail_dataset.dev_data

    # min() guards splits with fewer than 5 datapoints against IndexError
    for i in range(min(5, len(datapoints))):
        print('new_contexts =', datapoints[i]['context'])
        print()

    # Default paths follow the split resolved above (not params.use_test_data
    # directly), so an explicit use_test_data argument is honoured.
    if results_path is None:
        results_path = base_utils.get_results_file_path(
                params, test_split=use_test_data, result_only=True)

    print('modifying file:', results_path, '\n\n')

    with open(results_path) as csvfile:
        existing_results = [row[0] for row in csv.reader(csvfile, delimiter='\t')]

    print('len datapoints', len(datapoints))
    print('len existing_results', len(existing_results))
    if len(existing_results) != len(datapoints):
        # zip() below stops at the shorter sequence; make that visible
        print('WARNING: results/datapoints length mismatch; only the first',
              min(len(existing_results), len(datapoints)), 'proofs will be converted')
    if retrieved_context:
        assert len(retrieved_context) == len(datapoints)
        print('len retrieved_context', len(retrieved_context))
    
    print(existing_results[:10])

    new_proofs = []
    for dp_it, (source_datapoint, existing_proof) in enumerate(zip(datapoints, existing_results)):
        # Work on a shallow copy. Assigning 'proof'/'context' on the dataset's
        # own dicts would overwrite the values stored in entail_dataset for
        # every later call (e.g. each training epoch).
        datapoint = dict(source_datapoint)
        datapoint['proof'] = existing_proof
        if retrieved_context:
            old_context = datapoint['context']
            datapoint['context'] = retrieved_context[dp_it]
            if verbose and dp_it < 5:
                print('OLD CONTEXT:', old_context)
                print('NEW CONTEXT:', datapoint['context'])
                print('~~~=======~~~')
        
        new_datapoint = convert_datapoint_sent_to_uuid(datapoint)
        new_proof = new_datapoint['proof']
        if verbose and dp_it < 20:
            print('hypothesis = ', datapoint['hypothesis'])
            print('context = ', datapoint['context'])
            print('old_proof = ', existing_proof)
            print('new_proof = ', new_proof)
            print()
        new_proofs.append(new_proof)

    if uuid_results_path is None:
        uuid_results_path = base_utils.get_results_file_path(
            params, test_split=use_test_data, result_only=True, uuid=True)

    print('writing to file:', uuid_results_path, '\n\n')

    results_only_df = pd.DataFrame(new_proofs)
    results_only_df.to_csv(uuid_results_path, sep='\t', header=False, index=False)

## Loading Data and Initializing Model

In [None]:
# Create the trainer for the configured approach ('baseline' or 'iter').
# task_3 uses the multi-step trainer that works with retrieval.
if params.approach_name == 'baseline':
    trainer = Trainer(params = params, config = config)
elif params.approach_name == 'iter':
    if params.task_name == 'task_3':
        trainer = TrainerIterativeMultiStepRet(params = params, config = config)
    else:
        trainer = TrainerIterative(params = params, config = config)
else:
    # fail loudly instead of leaving `trainer` undefined (or stale from an
    # earlier kernel run)
    raise ValueError(
        f"Unknown approach_name {params.approach_name!r}; expected 'baseline' or 'iter'")

# parallelize() distributes the model over the visible CUDA devices and needs
# at least one. Skip it when running on CPU so the model stays on
# params.device == 'cpu'.
if cuda.is_available() and params.device != 'cpu':
    trainer.model.parallelize()

In [None]:
# Retrieval is only built for task_3 with custom retrieval enabled; in every
# other configuration semantic_search stays None.
semantic_search = None
if params.task_name == 'task_3' and params.use_custom_retrieval:
    # Optionally load a sentence-transformer encoder. The configured path is a
    # %-format template that is filled with params.sent_trans_name.
    encoder_model = None
    encoder_template = params.sent_trans_encoder_model_path
    if encoder_template is not None:
        encoder_model_path = encoder_template % params.sent_trans_name
        print('loading sentence transformer model from:', encoder_model_path)
        encoder_model = SentenceTransformer(encoder_model_path)
    if params.dataset_name == 'arc_entail':
        semantic_search = IterativeSemanticSearch(
            encoder_model=encoder_model, params=params, config=config)

In [None]:
# approach names can be: 'baseline', 'iter' (iterative)
if params.dataset_name == 'arc_entail':
    if params.approach_name == 'baseline':
        print('Creating EntailmentARCDataset dataset')
        entail_dataset = EntailmentARCDataset(
            semantic_search = semantic_search, params = params, config = config)
    elif params.approach_name == 'iter':
        print('Creating EntailmentARCDatasetIterative dataset')
        entail_dataset = EntailmentARCDatasetIterative(
            semantic_search = semantic_search, params = params, config = config)
    else:
        raise ValueError(
            f"Unknown approach_name {params.approach_name!r}; expected 'baseline' or 'iter'")
    evaluator = EntailmentEvaluator()
else:
    # Without this, `entail_dataset` would be undefined (or silently reused
    # from an earlier kernel run) in the line below.
    raise ValueError(f"Unsupported dataset_name {params.dataset_name!r}; expected 'arc_entail'")

print(entail_dataset.data[params.task_name]['train'][0])

In [None]:
# Load the retrieval corpus (only needed when a semantic search engine exists).
if entail_dataset.semantic_search is not None:
    split = 'dev'
    if params.use_test_data:
        split = 'test'
    if params.dataset_name == 'arc_entail':
        # Some leaf-node facts appear in the task contexts but not in the
        # WorldTree corpus; passing the task_1 split adds them to the corpus.
        dataset = entail_dataset.data['task_1'][split]
        semantic_search.load_wt_corpus_with_dataset_context(dataset)

    print('corpus size = ', len(semantic_search.corpus))

In [None]:
# Print split sizes for the current task, then build the torch data loaders.
task_splits = entail_dataset.data[params.task_name]
for split_name in ('train', 'dev', 'test'):
    print(f'{split_name} dataset len =', len(task_splits[split_name]))
print('----')
training_loader, val_loader, val_proof_loader, test_proof_loader = entail_dataset.get_torch_dataloaders(
    params.task_name, trainer.tokenizer)

In [None]:
def print_in_out_samples(num, loader):
    """Print the first ``num`` source/target text pairs of ``loader.dataset``.

    Pairs are rendered as 'SOURCE: ...' / 'TARGET: ...' and separated by a
    blank line.
    """
    dataset = loader.dataset
    pairs = zip(dataset.source_text[:num], dataset.target_text[:num])
    rendered = []
    for source, target in pairs:
        rendered.append('SOURCE: %s\nTARGET: %s' % (source, target))
    print('\n\n'.join(rendered))

# Show a few decoded examples from each loader.
sample_sections = [
    ('TRAINING', 15, training_loader),
    ('VALIDATION', 10, val_loader),
    ('VALIDATION FULL PROOF', 10, val_proof_loader),
    ('TEST FULL PROOF', 10, test_proof_loader),
]
for section_title, num_samples, sample_loader in sample_sections:
    print('\n======== ' + section_title)
    print_in_out_samples(num = num_samples, loader = sample_loader)

# Number of batches per loader, then the tensor shapes of each first batch.
print('----')
named_loaders = {
    'training_loader': training_loader,
    'val_loader': val_loader,
    'val_proof_loader': val_proof_loader,
    'test_proof_loader': test_proof_loader,
}
for loader_name, named_loader in named_loaders.items():
    print(f'len({loader_name})', len(named_loader))
_no_batch = object()
for named_loader in named_loaders.values():
    print('----')
    first_batch = next(iter(named_loader), _no_batch)
    if first_batch is not _no_batch:
        print('source_ids', first_batch['source_ids'].size(), 'source_mask', first_batch['source_mask'].size())
        print('target_ids', first_batch['target_ids'].size(), 'target_ids_y', first_batch['target_ids_y'].size())

## Training Model

In [None]:
# Training loop: train epoch by epoch and checkpoint the best model according
# to params.save_min_val_error_type ('loss', 'acc' or None).
print(f'Initiating Fine-Tuning for the model: {params.model_name} on dataset: {params.dataset_name}')
print(f'Run training for {config.TRAIN_EPOCHS} epochs')

min_val_loss = 1e10
max_val_acc = -1

for epoch in range(config.TRAIN_EPOCHS):
    print(f'\nRUNNING EPOCH #{epoch}...')
    _, val_avg_loss = trainer.train(epoch, training_loader, val_loader)
    
    if params.save_min_val_error_type == 'acc':
        # Generate dev-set full proofs with the same retriever that
        # generate_predictions uses for the final run, so model selection
        # scores the same pipeline that is evaluated at the end.
        predictions, actuals, context = evaluator.run_prediction(
            params, val_proof_loader, temp_results = True,
            semantic_search = semantic_search)

        # stays None when no scorer is available for this dataset
        metric_value = None
        if params.dataset_name == 'arc_entail':
            # compute the proof "all correct" accuracy; the model is saved below
            # if it is the highest seen so far
            prediction_file = base_utils.get_results_file_path(params, result_only=True, temp=True)
            
            # task_3 predictions are converted to sentence UUIDs before scoring
            if params.task_name == 'task_3':
                uuid_prediction_file = prediction_file.replace('_result_only_temp', '_result_only_uuid_temp')
                update_task_3_results(use_test_data = False,
                                      results_path = prediction_file, 
                                      uuid_results_path = uuid_prediction_file,
                                      retrieved_context = context)
                prediction_file = uuid_prediction_file

            eval_args = SimpleNamespace(
                task = params.task_name, # Task name: task_1, task_2, task_3
                output_dir = 'logs/arc_entail/', # Directory to store scores.
                split = 'dev', # Which split (train/dev/test) to evaluate.
                prediction_file = prediction_file,
                bleurt_checkpoint = 'entailment_bank/bleurt-large-512'
            )
            logs_file_path = base_utils.get_logs_file_path(params, temp=True, epoch_num = epoch)
            
            aggr_metrics = base_utils.run_funtion_redirect_stdout(
                run_scorer.main, args=[eval_args], filename=logs_file_path)
            metric_value = aggr_metrics['logs/arc_entail/scores-dev']['QAHC->P']['proof-overall']['acc']

        print('metric_value', metric_value)
        if metric_value is not None and metric_value > max_val_acc:
            suffix = '_max_val_acc'
            max_val_acc = metric_value
            print('max_val_acc (metric) =', max_val_acc)
            # Saving trained model with the highest validation proof accuracy
            trainer.save_model(file_path_suffix = suffix)
        
    if params.save_min_val_error_type == 'loss' and val_avg_loss < min_val_loss:
        suffix = '_min_val_loss'
        min_val_loss = val_avg_loss
        # Saving trained model with lowest validation loss
        trainer.save_model(file_path_suffix = suffix)

    # Saving trained model (latest epoch)
    trainer.save_model()

print('min_val_loss', min_val_loss)
print('max_val_acc', max_val_acc)

## Validation and Testing

In [None]:
# Reload the best checkpoint saved during training, chosen by the same
# criterion (params.save_min_val_error_type) used to save it. Nothing is
# reloaded when no criterion was configured.
best_checkpoint_suffixes = {'loss': '_min_val_loss', 'acc': '_max_val_acc'}
if params.save_min_val_error_type in best_checkpoint_suffixes:
    suffix = best_checkpoint_suffixes[params.save_min_val_error_type]
    trainer.load_model(suffix)

In [None]:
def generate_predictions(params):
    """Generate and save full proofs for the dev or test split.

    For tasks other than task_3, the validation loss is reported first. The
    test split is used when ``params.use_test_data`` is truthy, otherwise the
    validation split is used. Generation and file writing are delegated to
    ``evaluator.run_prediction``, with the global ``semantic_search``.

    Returns:
        (predictions, actuals, context) from ``evaluator.run_prediction``.
    """
    if params.task_name != 'task_3':
        print('Computing validation loss..')
        trainer.validate(val_loader, verbose=True)

    if not params.use_test_data:
        print('Now generating output for VALIDATION dataset and saving it in a dataframe')
        proof_loader = val_proof_loader
    else:
        print('Now generating output for TEST dataset and saving it in a dataframe')
        proof_loader = test_proof_loader

    return evaluator.run_prediction(
        params = params,
        proof_loader = proof_loader,
        semantic_search = semantic_search,
        verbose = True,
    )

# Validation loop and saving the resulting file with predictions and actuals in a dataframe.
predictions, actuals, context = generate_predictions(params)

In [None]:
def is_correct_step(target, pred):
    """Return True if ``pred`` and ``target`` contain the same step tokens.

    Both strings are normalized in the same way before comparing:
      * a leading '$step$ = ' prefix is removed (only when the string starts
        with it);
      * everything after the first 'int<N>:' marker is dropped.
    The remaining whitespace-separated tokens are compared as sets, so token
    order does not matter.
    """
    def _step_tokens(text):
        # Normalize one side of the comparison.
        prefix = '$step$ = '
        if text.startswith(prefix):
            text = text[len(prefix):]
        marker = re.search(r'int[0-9]+:', text)
        if marker:
            text = text[:marker.end()]
        return set(text.split())

    return _step_tokens(target) == _step_tokens(pred)

# tests beam search, creates multiple output sequences.
# tests how many are correct compared to gold
# can also save output so it can be used to train proof ranking models
def run_predictions_test(params, config, split='dev'):
    """Measure the best accuracy reachable by choosing among beam-search outputs.

    For each input of the selected loader, ``num_return_sequences`` beams are
    generated. The input counts as correct if ANY of its beams matches the
    target according to ``is_correct_step``. The first three batches are
    printed in detail, followed by the final "Best Possible Accuracy".

    Args:
        params: run configuration (``device``, ``task_name``).
        config: model configuration (``SUMMARY_LEN``).
        split: 'train' (uses ``training_loader``) or 'dev' (uses ``val_loader``).

    Raises:
        ValueError: if ``split`` is neither 'train' nor 'dev'.
    """
    if split == 'train':
        print('using training_loader')
        data_loader = training_loader
    elif split == 'dev':
        print('using val_loader')
        data_loader = val_loader
    else:
        raise ValueError(f"split must be 'train' or 'dev', got {split!r}")

    # To constrain decoding, build a PrefixConstrainedGenerator from
    # (trainer.tokenizer, data_loader.dataset.source_text, config.VALID_BATCH_SIZE)
    # and pass its iterative_prefix_allowed_tokens_fn as 'prefix_allowed_tokens_fn'.
    prefix_constrained_generator = None

    trainer.model.eval()
    generation_args = {
        'max_length': config.SUMMARY_LEN,
        'num_beams': 10,
        'repetition_penalty': 2.5,
        'length_penalty': 1.0,
        'early_stopping': True,
        'num_return_sequences': 10
    }
    num_return = generation_args.get('num_return_sequences', 1)

    dataset_idx = 0
    tot_all = 0
    tot_correct = 0
    with torch.no_grad():
        for batch_idx, data in enumerate(tqdm(data_loader), 0):
            if prefix_constrained_generator is not None:
                prefix_constrained_generator.set_batch_number(batch_idx)
            y = data['target_ids'].to(params.device, dtype = torch.long)
            ids = data['source_ids'].to(params.device, dtype = torch.long)
            mask = data['source_mask'].to(params.device, dtype = torch.long)
            generation_args.update({
                'input_ids': ids,
                'attention_mask': mask,
            })
            generated_ids = trainer.model.generate(**generation_args)
            inputs = [trainer.tokenizer.decode(i, skip_special_tokens=True, 
                                               clean_up_tokenization_spaces=True) for i in ids]
            preds = [trainer.tokenizer.decode(g, skip_special_tokens=True, 
                                              clean_up_tokenization_spaces=True) for g in generated_ids]
            target = [trainer.tokenizer.decode(t, skip_special_tokens=True, 
                                               clean_up_tokenization_spaces=True) for t in y]

            correct_set = set()
            for it, pred in enumerate(preds):
                # The beams of one input are treated as consecutive in `preds`,
                # so it // num_return is the index of the input the beam belongs to.
                it_real = it // num_return
                is_correct = is_correct_step(target[it_real], pred)
                if batch_idx < 3:
                    if it % num_return == 0:
                        print('-' * 10)
                    data_point = entail_dataset.data[params.task_name][split][dataset_idx]
                    print('--> id:', data_point['id'])
                    print('--> inputs:', inputs[it_real])
                    print('--> target:', target[it_real])
                    print('--> predict:', pred)
                    print('--> is correct:', is_correct)
                    print()
                # This is a bit of a hack: move to the next datapoint after the last
                # beam of an input whose target contains 'hypothesis'.
                # NOTE(review): assumes the loader visits datapoints in order with
                # one 'hypothesis' step each; likely wrong for a shuffled loader.
                if 'hypothesis' in target[it_real] and it % num_return == num_return - 1:
                    dataset_idx += 1
                if is_correct:
                    correct_set.add(it_real)
            tot_correct += len(correct_set)
            # Count the inputs actually in this batch. The last batch can be
            # smaller, and training_loader need not use VALID_BATCH_SIZE.
            tot_all += len(target)

    if tot_all == 0:
        print('Best Possible Accuracy: n/a (no inputs)')
    else:
        print('Best Possible Accuracy: %.4f' % (float(tot_correct) / tot_all,))

run_predictions_test(params, config, split='dev')

In [None]:
# uses entailment_bank evaluation (original evaluation reported on paper)

def run_ent_bank_eval(params):
    """Score the saved result-only predictions with the entailment_bank scorer.

    The prediction file is the result-only file for the configured split; for
    task_3 it is the UUID version. ``run_scorer.main`` is invoked through
    ``base_utils.run_funtion_redirect_stdout`` with the run's log file as
    ``filename`` (presumably the scorer's stdout is redirected there; confirm
    in base_utils).

    Args:
        params: run configuration (``task_name``, ``use_test_data`` and path settings).

    Returns:
        The value returned by ``run_funtion_redirect_stdout``. The caller
        indexes it as ``[scores_path]['QAHC->P']['proof-overall']['acc']``.
    """
    prediction_file = base_utils.get_results_file_path(
        params, test_split=params.use_test_data, result_only=True, 
        uuid=params.task_name == 'task_3')
    
    split = 'test' if params.use_test_data else 'dev'
    eval_args = SimpleNamespace(
        task = params.task_name, # Task name: task_1, task_2, task_3
        output_dir = 'logs/arc_entail/', # Directory to store scores.
        split = split, # Which split (train/dev/test) to evaluate.
        prediction_file = prediction_file,
        bleurt_checkpoint = 'entailment_bank/bleurt-large-512'
    )
    print()
    print('prediction_file', eval_args.prediction_file)

    logs_file_path = base_utils.get_logs_file_path(params, test_split = params.use_test_data)
    print()
    print('logs_file_path', logs_file_path)
    aggr_metrics = base_utils.run_funtion_redirect_stdout(
        run_scorer.main, args=[eval_args], filename=logs_file_path)
    return aggr_metrics

########################################################
# Execute evaluation and compute metrics:
# task_3 proofs are first converted to sentence UUIDs, then the
# entailment_bank scorer runs and the overall proof accuracy is printed.

if params.dataset_name == 'arc_entail':
    if params.task_name == 'task_3':
        update_task_3_results(retrieved_context = context)
    aggr_metrics = run_ent_bank_eval(params)
    eval_split = 'test' if params.use_test_data else 'dev'
    logs_scores = 'logs/arc_entail/scores-' + eval_split
    print(aggr_metrics[logs_scores]['QAHC->P']['proof-overall']['acc'])

---
# *NOTE*: The following code blocks are for debugging only
## Optional Scripts (update instance state)

In [None]:
# Save the trained model.
# (A comment rather than a bare string, so the cell does not echo a string as its output.)
# trainer.save_model()

In [None]:
# Rebuild the trainer around the existing tokenizer/model/optimizer
# (only uncomment if the Trainer class code changed).
# NOTE(review): Trainer is constructed earlier in this notebook as
# Trainer(params=..., config=...); confirm it still accepts these keyword
# arguments before uncommenting.
# trainer_old = trainer
# trainer = Trainer(tokenizer=trainer_old.tokenizer, model=trainer_old.model, optimizer=trainer_old.optimizer)

In [None]:
# Update the model from saved weights.
# (A comment rather than a bare string, so the cell does not echo a string as its output.)
# trainer.load_model()

In [None]:
# Update the model from the saved weights of the checkpoint with the minimum validation loss.
# (A comment rather than a bare string, so the cell does not echo a string as its output.)
# trainer.load_model('_min_val_loss')