In [None]:
# Imports
import os
import re
import sys
import json
import torch
import base_utils
import random
import string
import numpy as np
import pandas as pd
from types import SimpleNamespace
from tqdm.notebook import tqdm
from collections import Counter, defaultdict

# Importing DL libraries
import torch
from torch import nn
import torch.nn.functional as F
from torch import cuda
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration, Adafactor
from sentence_transformers import SentenceTransformer
from sentence_transformers import util as st_util
from sentence_transformers import evaluation as st_evaluation
from sentence_transformers import SentenceTransformer, InputExample, losses

from retrieval_utils import convert_datapoint_to_sent_to_text, sent_text_as_counter

In [None]:
# Dataset, model and path settings read by the cells below.
params = SimpleNamespace(
    # Dataset selector; options: "arc_entail" (entailment bank), "eqasc".
    dataset_name='arc_entail',
    # One of: task_1, task_2, task_3.
    task_name='task_3',
    # Evaluate on the test partition instead of dev.
    use_test_data=True,
    # Both paths keep a '%s' slot that is filled with `sent_trans_name`.
    encoder_model_path='../data/arc_entail/models/%s_fine_tuned_v6/',
    encoder_checkpoint_path='../data/arc_entail/models/%s_checkpoint_v6',
    # Use the GPU whenever torch can see one.
    device='cuda' if cuda.is_available() else 'cpu',
    # Full list of sentence transformers:
    # https://www.sbert.net/docs/pretrained_models.html
    sent_trans_name='all-mpnet-base-v2',
    wt_corpus_file_path='../data/arc_entail/supporting_data/worldtree_corpus_sentences_extended.json',
    max_retrieved_sentences=25,
)

# Optimisation hyper-parameters for encoder fine-tuning.
config = SimpleNamespace(
    TRAIN_EPOCHS=5,        # number of epochs to train
    LEARNING_RATE=4e-5,    # learning rate
    SEED=39,               # random seed
)

In [None]:
def set_random_seed(seed=None):
    """Seed every RNG this notebook draws from and make cuDNN deterministic.

    Args:
        seed: integer seed. When None (the default) the notebook-wide
            ``config.SEED`` is used, so a bare ``set_random_seed()`` call
            keeps working.
    """
    if seed is None:
        seed = config.SEED
    # Python's own RNG must be seeded too: random.sample() picks the random
    # negatives when the training examples are built.
    random.seed(seed)
    np.random.seed(seed)          # numpy random seed
    torch.manual_seed(seed)       # pytorch random seed
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels; benchmark mode may pick different
    # algorithms from run to run, so keep it off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed once, right after the configuration cell, before any model or data is created.
set_random_seed()

In [None]:
class CustomDataset(Dataset):
    """Torch dataset that tokenizes (source, target) text pairs on access.

    Args:
        source_text: indexable collection of input texts.
        target_text: indexable collection of target texts, aligned with
            ``source_text``.
        tokenizer: object exposing ``batch_encode_plus`` (called with
            ``max_length``, ``padding``, ``return_tensors`` and ``truncation``).
        source_len: max token length used for the source side.
        summ_len: max token length used for the target side.
    """

    def __init__(self, source_text, target_text, tokenizer, source_len, summ_len):
        self.tokenizer = tokenizer
        self.source_len = source_len
        self.summ_len = summ_len
        self.source_text = source_text
        self.target_text = target_text

    def __len__(self):
        """Number of (source, target) pairs."""
        return len(self.source_text)

    def _encode(self, text, max_length):
        """Tokenize one text; returns its (input_ids, attention_mask) as long tensors."""
        # Collapse any run of whitespace into a single space before tokenizing.
        normalized = ' '.join(str(text).split())
        encoded = self.tokenizer.batch_encode_plus(
            [normalized], max_length=max_length, padding='max_length',
            return_tensors='pt', truncation=True)
        # squeeze(0) removes only the batch axis of the one-item batch; a bare
        # squeeze() would also collapse the sequence axis when max_length == 1.
        input_ids = encoded['input_ids'].squeeze(0).to(dtype=torch.long)
        attention_mask = encoded['attention_mask'].squeeze(0).to(dtype=torch.long)
        return input_ids, attention_mask

    def __getitem__(self, index):
        """Return the tokenized pair at ``index``.

        Returns:
            dict with keys ``source_ids``, ``source_mask``, ``target_ids`` and
            ``target_ids_y`` (the last two hold the same target token ids).
        """
        source_ids, source_mask = self._encode(self.source_text[index], self.source_len)
        # The target attention mask is not part of the returned item.
        target_ids, _ = self._encode(self.target_text[index], self.summ_len)

        return {
            'source_ids': source_ids,
            'source_mask': source_mask,
            'target_ids': target_ids,
            'target_ids_y': target_ids,
        }

In [None]:
class SemanticSearch():
    """Dense retriever over a sentence corpus using a sentence-transformer encoder.

    Args:
        corpus: optional list of sentences to index immediately.
        encoder_model: optional encoder exposing ``encode``; when None a
            ``SentenceTransformer(params.sent_trans_name)`` is created.
        params: settings namespace; ``sent_trans_name``, ``device`` and
            ``wt_corpus_file_path`` are read from it.
    """

    def __init__(self, corpus = None, encoder_model = None, params = None):
        self.params = params
        # Stay None until update_corpus_embeddings() has indexed a corpus.
        self.corpus = None
        self.corpus_embeddings = None
        self.encoder_model = encoder_model
        if encoder_model is None:
            self.encoder_model = SentenceTransformer(self.params.sent_trans_name)
        if corpus is not None:
            self.update_corpus_embeddings(corpus)

    def load_wt_corpus_file(self):
        """Load the WorldTree corpus mapping from ``params.wt_corpus_file_path``.

        Only the first line of the file is parsed, so the whole JSON object
        must be stored on a single line.
        """
        with open(self.params.wt_corpus_file_path, 'r', encoding='utf8') as f:
            return json.loads(f.readline())

    def load_wt_corpus(self, extra_facts = None):
        """Index the WorldTree corpus values, optionally extended with ``extra_facts``.

        When ``extra_facts`` is given the combined list is de-duplicated.
        """
        wt_corpus = self.load_wt_corpus_file()
        corpus = list(wt_corpus.values())
        if extra_facts is not None:
            corpus.extend(extra_facts)
            # dict.fromkeys de-duplicates while keeping first-seen order; a set
            # would make the corpus order (and hence corpus ids) differ
            # between runs.
            corpus = list(dict.fromkeys(corpus))
        self.update_corpus_embeddings(corpus)

    def update_corpus_embeddings(self, corpus):
        """Replace the indexed corpus and recompute its normalized embeddings."""
        self.corpus = corpus
        # Encode all sentences in corpus
        self.corpus_embeddings = self.encoder_model.encode(
            corpus, convert_to_tensor=True, show_progress_bar = False)
        self.corpus_embeddings = self.corpus_embeddings.to(self.params.device)
        self.corpus_embeddings = st_util.normalize_embeddings(self.corpus_embeddings)

    def search_with_id_and_scores(self, queries, top_k = 1):
        '''
        Search for best semantically similar sentences in corpus.

        ``queries`` may be a single string or an iterable of strings.

        returns corpus ids (index in input corpus) and scores
        '''
        if getattr(self, 'corpus_embeddings', None) is None:
            raise RuntimeError(
                'no corpus indexed: call update_corpus_embeddings() or load_wt_corpus() first')
        if isinstance(queries, str):
            queries = [queries]
        elif not isinstance(queries, list):
            # tuples / generators are a batch of queries, not a single query
            queries = list(queries)

        # Encode all queries
        query_embeddings = self.encoder_model.encode(queries, convert_to_tensor=True)
        query_embeddings = query_embeddings.to(self.params.device)
        query_embeddings = st_util.normalize_embeddings(query_embeddings)
        # Both sides are normalized, so the dot product is used as the score.
        hits = st_util.semantic_search(query_embeddings, self.corpus_embeddings,
                                       top_k=top_k, score_function=st_util.dot_score)
        return hits

    def search(self, *args, **kwargs):
        '''
        Search for best semantically similar sentences in corpus.

        Only returns elements from corpus (no score or id)
        '''
        hits = self.search_with_id_and_scores(*args, **kwargs)
        elements = [[self.corpus[ret['corpus_id']] for ret in hit] for hit in hits]
        return elements

    def run_test(self):
        """Smoke test: index a tiny corpus (replacing the current one) and print the top-2 hits."""
        corpus = [
            'A man is eating food.', 'A man is eating a piece of bread.',
            'The girl is carrying a baby.', 'A man is riding a horse.', 'A woman is playing violin.',
            'Two men pushed carts through the woods.', 'A man is riding a white horse on an enclosed ground.',
            'A monkey is playing drums.', 'Someone in a gorilla costume is playing a set of drums.',
            'matter in the gas phase has variable shape',
        ]
        self.update_corpus_embeddings(corpus)
        queries = ['A woman enjoys her meal', 'A primate is performing at a concert', 'matter in gas phase has no definite volume and no definite shape']
        results = self.search(queries, top_k = 2)
        for query, result in zip(queries, results):
            print('Query:', query)
            print('Best results:', result)
            print()


In [None]:
class EntailmentARCDataset():
    """In-memory copy of the entailment dataset, keyed by task name and partition.

    After construction ``self.data[task_name][partition]`` is the list of
    datapoints parsed from the matching ``.jsonl`` file, for tasks 1-3 and
    partitions train/dev/test.
    """

    ROOT_PATH = "../data/arc_entail"
    DATASET_PATH = os.path.join(ROOT_PATH, "dataset")
    TASK_PATH = os.path.join(DATASET_PATH, "task_{task_num}")
    PARTITION_DATA_PATH = os.path.join(TASK_PATH, "{partition}.jsonl")

    def __init__(self, semantic_search = None, params = None, config = None):
        self.params = params
        self.config = config
        self.semantic_search = semantic_search
        self.data = {self.get_task_name(task_num):
                     {partition: [] for partition in ['train', 'dev', 'test']}
                     for task_num in range(1, 4)}
        self.load_dataset()

    def get_task_name(self, task_num):
        """Task number -> task key, e.g. 3 -> 'task_3'."""
        return "task_" + str(task_num)

    def get_task_number(self, task_name):
        """Task key -> task number, e.g. 'task_3' -> 3."""
        # Parse everything after the last '_' so multi-digit numbers work too
        # (taking only the final character would turn 'task_12' into 2).
        return int(task_name.rsplit('_', 1)[-1])

    def get_dataset_path(self, task_num = 1, partition = 'train'):
        """Path of the JSONL file for the given task number and partition."""
        return self.PARTITION_DATA_PATH.format(task_num = task_num, partition = partition)

    def load_dataset(self):
        """Read every task/partition file and append its datapoints to ``self.data``."""
        for task_name, partitions in self.data.items():
            task_num = self.get_task_number(task_name)
            for partition, datapoints in partitions.items():
                path = self.get_dataset_path(task_num, partition)
                with open(path, 'r', encoding='utf8') as f:
                    for line in f:
                        # tolerate blank lines (e.g. a trailing newline at EOF)
                        if not line.strip():
                            continue
                        datapoints.append(json.loads(line))

# Loading Data

### **NOTE**: if you are not fine-tuning the encoder, set ``load_model = True`` to load an existing trained encoder model

In [None]:
# Flip to True to reuse an encoder saved by a previous fine-tuning run;
# with False the stock `params.sent_trans_name` model is used instead.
load_model = False
encoder_model = None

if load_model:
    # encoder_model_path keeps a '%s' slot that is filled with the model name
    encoder_model_path = params.encoder_model_path % params.sent_trans_name
    print(f'loading model from: {encoder_model_path}')
    encoder_model = SentenceTransformer(encoder_model_path)

# With encoder_model=None, SemanticSearch builds SentenceTransformer(params.sent_trans_name) itself.
semantic_search = SemanticSearch(encoder_model = encoder_model, params = params)

In [None]:
# Constructing the dataset reads every task/partition JSONL file from disk.
entail_dataset = EntailmentARCDataset(
    semantic_search = semantic_search, params = params, config = config)
# Sanity check: show the first training datapoint of the configured task.
print(entail_dataset.data[params.task_name]['train'][0])

# Retrieval Utils

In [None]:
def counter_jaccard_similarity(c1, c2):
    """Multiset Jaccard similarity of two ``collections.Counter`` objects.

    Counts are compared element-wise: the intersection keeps the minimum
    count per key and the union the maximum.

    Returns:
        float in [0, 1]; 0.0 when both counters are empty (the ratio is
        undefined there and would otherwise raise ZeroDivisionError).
    """
    union_size = sum((c1 | c2).values())
    if union_size == 0:
        return 0.0
    return sum((c1 & c2).values()) / float(union_size)

def construct_sent_context_mapping(matching_texts, sent_to_text, new_mapping_key, matching_uuids = None):
    """Greedily pair each sentence in ``sent_to_text`` with one of ``matching_texts``.

    Every (sentence, candidate text) pair is scored with the Jaccard
    similarity of their token counters; pairs are then visited from best to
    worst and a pair is accepted when neither its sentence nor its candidate
    has been used yet.

    Args:
        matching_texts: candidate texts to assign.
        sent_to_text: dict of sentence key -> dict holding at least 'text'
            and 'text_counter'; it is updated in place.
        new_mapping_key: key under which the chosen candidate text is stored.
        matching_uuids: optional sequence aligned with ``matching_texts``; a
            non-None uuid is stored under ``new_mapping_key + '_uuid'``.

    Returns:
        The same ``sent_to_text`` dict. An AssertionError is raised when some
        sentence ends up without an assignment.
    """
    scored_pairs = []
    for match_idx, candidate_text in enumerate(matching_texts):
        candidate_counter = sent_text_as_counter(candidate_text)
        candidate_uuid = matching_uuids[match_idx] if matching_uuids is not None else None
        scored_pairs.extend(
            (sent_key, sent_info['text'], candidate_text, match_idx, candidate_uuid,
             counter_jaccard_similarity(sent_info['text_counter'], candidate_counter))
            for sent_key, sent_info in sent_to_text.items())

    # Best-scoring pairs first; the sort is stable, so ties keep build order.
    scored_pairs.sort(key=lambda pair: pair[-1], reverse=True)

    used_candidates = set()
    for sent_key, _, candidate_text, match_idx, candidate_uuid, _ in scored_pairs:
        sent_entry = sent_to_text[sent_key]
        if new_mapping_key in sent_entry or match_idx in used_candidates:
            continue
        sent_entry[new_mapping_key] = candidate_text
        if candidate_uuid is not None:
            sent_entry[new_mapping_key + '_uuid'] = candidate_uuid
        used_candidates.add(match_idx)

    assert all(new_mapping_key in sent_entry for sent_entry in sent_to_text.values())
    return sent_to_text


def construct_datapoint_context_mapping(datapoint):
    """Align a datapoint's context sentences with its triples and WorldTree provenance.

    Each sentence entry gets a 'triple_text' and a 'wt_p_text' (plus a
    'wt_p_text_uuid' when the provenance uuid is not None); the helper
    'text_counter' field is removed before returning.

    Raises:
        AssertionError: when the number of triples or of provenance items
            differs from the number of context sentences.
    """
    sent_to_text = convert_datapoint_to_sent_to_text(datapoint)

    # First pass: align with the triples stored in the metadata.
    triple_texts = datapoint['meta']['triples'].values()
    assert len(sent_to_text.keys()) == len(triple_texts)
    sent_to_text = construct_sent_context_mapping(
        triple_texts, sent_to_text, new_mapping_key = 'triple_text')

    # Second pass: align with the original WorldTree texts, carrying uuids.
    provenance_items = datapoint['meta']['worldtree_provenance'].values()
    wt_p_items = [item['original_text'] for item in provenance_items]
    wt_p_uuids = [item['uuid'] for item in provenance_items]
    assert len(sent_to_text.keys()) == len(wt_p_items)
    sent_to_text = construct_sent_context_mapping(
        wt_p_items, sent_to_text, new_mapping_key = 'wt_p_text', matching_uuids = wt_p_uuids)

    # The counters were only needed for the similarity scoring.
    for sent_entry in sent_to_text.values():
        del sent_entry['text_counter']
    return sent_to_text

def create_context_mapping(dataset = None, verbose = False):
    """Build the per-datapoint sentence mapping and attach WorldTree corpus texts.

    Args:
        dataset: list of datapoints. When None, the task_1 test partition of
            the current ``entail_dataset`` is used; it is looked up at call
            time so a re-created ``entail_dataset`` is picked up.
        verbose: print every mapping entry as it is built.

    Returns:
        list with one mapping dict per datapoint (same order as ``dataset``).
        An entry gets a 'wt_corpus_text' field when its 'wt_p_text_uuid' is a
        key of the WorldTree corpus file.
    """
    if dataset is None:
        dataset = entail_dataset.data['task_1']['test']

    # Only the first line of the corpus file is parsed.
    with open(params.wt_corpus_file_path, 'r', encoding='utf8') as f:
        wt_corpus = json.loads(f.readline())

    context_mapping = []
    for datapoint in dataset:
        datapoint_context_mapping = construct_datapoint_context_mapping(datapoint)
        context_mapping.append(datapoint_context_mapping)

        for sent_entry in datapoint_context_mapping.values():
            wt_uuid = sent_entry.get('wt_p_text_uuid')
            if wt_uuid is not None and wt_uuid in wt_corpus:
                sent_entry['wt_corpus_text'] = wt_corpus[wt_uuid]

        if verbose:
            for sent_entry in datapoint_context_mapping.values():
                for item_k, item_v in sent_entry.items():
                    print(item_k, '=', item_v)
                print()
            print('======')
    return context_mapping

In [None]:
def fix_wt_corpus_with_task_1_data(split = 'test'):
    """Rebuild the retrieval corpus so it contains the task_1 sentence wording.

    Every value of the context mapping other than 'text' (triple text,
    WorldTree text and uuid, corpus text) is removed from the WorldTree
    corpus, and the task_1 'text' sentences of ``split`` are appended
    instead. The result is indexed in ``semantic_search``.
    """
    dataset = entail_dataset.data['task_1'][split]
    context_mapping = create_context_mapping(dataset)
    removal_sents = [v for cm in context_mapping for p in cm.values()
                     for k, v in p.items() if k != 'text']
    include_sents = [v for cm in context_mapping for p in cm.values()
                     for k, v in p.items() if k == 'text']
    print(removal_sents[:20])
    wt_corpus = semantic_search.load_wt_corpus_file()
    removal_set = set(removal_sents)
    # De-duplicate the corpus values in their loaded order instead of via set
    # arithmetic: set iteration order changes between interpreter runs, which
    # would make corpus ids and the seeded random.sample() draws over the
    # corpus irreproducible.
    corpus = [sent for sent in dict.fromkeys(wt_corpus.values())
              if sent not in removal_set]
    corpus.extend(include_sents)
    semantic_search.update_corpus_embeddings(corpus)


fix_wt_corpus_with_task_1_data()
print('corpus size = ', len(semantic_search.corpus))

In [None]:
def get_sents_height_on_tree(sent_text_lst, data_point):
    """Compute how far each given sentence sits from the hypothesis in the proof.

    The proof string is a ';'-terminated list of steps of the form
    ``"<antecedents> -> <conclusion>"`` where antecedents mention ``sentN`` /
    ``intN`` symbols and the conclusion is either ``hypothesis`` or starts
    with an ``intN`` symbol. The hypothesis has height 0, an intermediate
    feeding it has height 1, and so on; a sentence gets the height of the
    conclusion it supports plus one.

    Args:
        sent_text_lst: sentence texts to look up in ``data_point['context']``.
        data_point: dict with 'proof' and 'context' strings.

    Returns:
        (heights, conclusion_to_antecedent_int_map, sents_to_int_map,
        int_to_height_map) where ``heights`` is a list of
        ``{'sent', 'text', 'height'}`` dicts. Sentences that are not found in
        the context (reported on stdout), that the proof never uses, or whose
        conclusion is unreachable from the hypothesis are left out.
    """
    proof = data_point['proof']
    context = data_point['context']
    conclusion_to_antecedent_int_map = {}
    sents_to_int_map = {}
    int_to_height_map = {}

    # Steps are ';'-terminated, so the piece after the last ';' is dropped.
    for step in proof.split(';')[:-1]:
        # maxsplit=1: the conclusion text itself may contain ' -> '
        antecedent, conclusion = step.split(' -> ', 1)
        if conclusion.strip() == 'hypothesis':
            conclusion_int = 'hypothesis'
        else:
            conclusion_int = re.findall(r'int[0-9]+', conclusion)[0]
        for ant_sent in re.findall(r'sent[0-9]+', antecedent):
            sents_to_int_map[ant_sent] = conclusion_int
        conclusion_to_antecedent_int_map[conclusion_int] = re.findall(r'int[0-9]+', antecedent)

    # Walk level by level from the hypothesis down through its intermediates.
    cur_height = 0
    current_ints = ['hypothesis']
    while len(current_ints) > 0:
        next_ints = []
        for cur_int in current_ints:
            int_to_height_map[cur_int] = cur_height
            next_ints.extend(conclusion_to_antecedent_int_map.get(cur_int, []))
        current_ints = next_ints
        cur_height += 1

    heights = []
    for sent_text in sent_text_lst:
        sent_match = re.findall(
            '(sent[0-9]+): (%s)' % re.escape(sent_text.strip()), context)
        if len(sent_match) == 0:
            print('MISSING!!!')
            print('sent_text =', sent_text)
            print('context =', context)
            continue
        sent_symb = sent_match[0][0]
        # .get(): a context sentence the proof never references has no
        # conclusion; skip it instead of raising KeyError.
        int_symb = sents_to_int_map.get(sent_symb)
        if int_symb not in int_to_height_map:
            # this might also happen when proof has antecedent missing "int"
            continue
        heights.append({
            'sent': sent_symb, 'text': sent_text,
            'height': int_to_height_map[int_symb] + 1,
        })

    return heights, conclusion_to_antecedent_int_map, sents_to_int_map, int_to_height_map

def compute_retrieval_metrics(retrieved_sentences_lst, split = 'test', verbose = False):
    """Score retrieved sentences against the gold context of task_1 ``split``.

    A retrieved sentence counts as correct when it equals any value stored
    for a gold sentence in the context mapping (its text, triple text,
    WorldTree text, ...). Recall and related statistics are printed.

    Args:
        retrieved_sentences_lst: one iterable of retrieved sentences per
            datapoint, aligned with the task_1 ``split`` partition.
        split: partition of task_1 to score against.
        verbose: print hypothesis / question / retrieved / missing for every
            datapoint that has missing sentences.

    Returns:
        (recall, correct_retrieved_lst, errors_lst, missing_lst); the three
        lists hold one list per datapoint.
    """
    dataset = entail_dataset.data['task_1'][split]
    context_mapping = create_context_mapping(dataset)

    assert len(retrieved_sentences_lst) == len(dataset)

    tot_sent = 0
    tot_sent_correct = 0
    tot_sent_missing = 0
    tot_no_missing = 0
    tot_sent_not_in_wt = 0
    tot_missing_sent_height = 0
    tot_correct_sent_height = 0

    correct_retrieved_lst = []
    errors_lst = []   # in retrieved but not in gold
    missing_lst = []  # in gold but not in retrieved

    for ret_sentences, dp_context_mapping, datapoint in zip(retrieved_sentences_lst, context_mapping, dataset):
        correct_retrieved = []
        errors = []
        for ret_sentence in ret_sentences:
            is_correct = False
            for mapping_texts in dp_context_mapping.values():
                if ret_sentence in mapping_texts.values():
                    is_correct = True
                    if mapping_texts['text'] not in correct_retrieved:
                        correct_retrieved.append(mapping_texts['text'])
                    # The uuid key is only stored for non-None uuids, so read
                    # it defensively; an absent or very short uuid is counted
                    # as "not in corpus".
                    if len(mapping_texts.get('wt_p_text_uuid') or '') < 2:
                        tot_sent_not_in_wt += 1
                    break
            if not is_correct:
                errors.append(ret_sentence)
        all_sents = [v['text'] for v in dp_context_mapping.values()]
        # Unique gold sentences that were not retrieved, in context order
        # (a set difference would give a run-dependent order).
        retrieved_set = set(correct_retrieved)
        missing = [sent for sent in dict.fromkeys(all_sents) if sent not in retrieved_set]

        correct_retrieved_lst.append(correct_retrieved)
        errors_lst.append(errors)
        missing_lst.append(missing)

        tot_sent += len(dp_context_mapping.keys())
        tot_sent_correct += len(correct_retrieved)
        tot_sent_missing += len(missing)
        tot_no_missing += 0 if len(missing) > 0 else 1

        missing_heights, _, _, _ = get_sents_height_on_tree(missing, datapoint)
        tot_missing_sent_height += sum([mh['height'] for mh in missing_heights])
        correct_heights, _, _, _ = get_sents_height_on_tree(correct_retrieved, datapoint)
        tot_correct_sent_height += sum([ch['height'] for ch in correct_heights])

        if verbose and len(missing) > 0:
            hypothesis = datapoint['hypothesis']
            question = datapoint['question']
            answer = datapoint['answer']

            print('hypothesis', hypothesis)
            print('Q + A', question + ' -> ' + answer)
            print('=====')
            print('retrieved:', correct_retrieved)
            print('missing:', missing)
            print()

    # Guard the ratios so an empty partition reports 0.0 instead of raising.
    recall = tot_sent_correct / float(tot_sent) if tot_sent else 0.0
    all_correct = tot_no_missing / float(len(dataset)) if len(dataset) else 0.0
    avg_correct_sent_height = tot_correct_sent_height / (float(tot_sent_correct) + 1e-9)
    avg_missing_sent_height = tot_missing_sent_height / (float(tot_sent_missing) + 1e-9)
    print('recall:', recall)
    print('all correct:', all_correct)
    print('number of retrieved not in corpus:', tot_sent_not_in_wt)
    print('avg height of correct sentences:', avg_correct_sent_height)
    print('avg height of missing sentences:', avg_missing_sent_height)

    return recall, correct_retrieved_lst, errors_lst, missing_lst

In [None]:
def test_task_3_paper_recall():
    """Print retrieval metrics for the sentences already stored with the task_3 test data.

    The triples of every task_3 test datapoint are treated as the retrieved
    sentences and scored with ``compute_retrieval_metrics`` on the 'test'
    split.
    """
    eval_split = 'test'
    retrieved_sentences_lst = []
    for t3_datapoint in entail_dataset.data['task_3'][eval_split]:
        retrieved_sentences_lst.append(t3_datapoint['meta']['triples'].values())
    compute_retrieval_metrics(retrieved_sentences_lst, split = eval_split, verbose=False)

In [None]:
def test_sent_transformer_recall(split = 'test', verbose = True, top_k = 25):
    """Single-shot retrieval: query the corpus with each hypothesis and score the hits.

    Args:
        split: task_1 partition to evaluate.
        verbose: forwarded to ``compute_retrieval_metrics`` (prints the
            datapoints that have missing sentences).
        top_k: number of corpus sentences retrieved per hypothesis.

    Returns:
        The ``(recall, correct_retrieved_lst, errors_lst, missing_lst)`` tuple
        of ``compute_retrieval_metrics``.
    """
    print('top_k', top_k)

    t1_data = entail_dataset.data['task_1'][split]
    ret_data = semantic_search.search([t1['hypothesis'] for t1 in t1_data], top_k = top_k)
    assert len(t1_data) == len(ret_data)

    # verbose was accepted but silently dropped before; pass it through.
    return compute_retrieval_metrics(ret_data, split = split, verbose = verbose)

# Retrieval Training

In [None]:
def get_dataset_examples(split = 'train', use_hard_negative = True, hn_top_k = 25):
    """Build sentence-transformer examples (hypothesis, sentence, label) for ``split``.

    For every task_1 datapoint, in this order:
      * each gold triple is a positive (label 1.0);
      * with ``use_hard_negative``, each sentence the current model retrieved
        within ``hn_top_k`` that is not gold is a hard negative (label 0.75);
      * ``4 * len(gold)`` sentences are sampled from the corpus and those that
        are neither gold nor hard negatives become negatives (label 0.0).
    """
    if use_hard_negative:
        # retrieved by current model but that is not part of the dataset
        _recall, _correct, errors, _missing = test_sent_transformer_recall(
            split = split, verbose = False, top_k = hn_top_k)

    examples = []
    for t1_it, t1 in enumerate(tqdm(entail_dataset.data['task_1'][split])):
        hyp = t1['hypothesis']
        gold_sents = t1['meta']['triples'].values()

        examples.extend(
            InputExample(texts=[hyp, gold_sent], label=1.0) for gold_sent in gold_sents)

        hard_negs = []
        if use_hard_negative:
            hard_negs = [err for err in errors[t1_it] if err not in gold_sents]
            examples.extend(
                InputExample(texts=[hyp, hard_neg], label=0.75) for hard_neg in hard_negs)

        for neg_sent in random.sample(semantic_search.corpus, len(gold_sents) * 4):
            if neg_sent in gold_sents or neg_sent in hard_negs:
                continue
            examples.append(InputExample(texts=[hyp, neg_sent], label=0.0))
    return examples

def train_sent_trans(get_examples_fn = get_dataset_examples):
    """Fine-tune the encoder held by ``semantic_search`` with a cosine-similarity loss.

    Args:
        get_examples_fn: callable taking ``split=`` and returning a list of
            ``InputExample`` objects; called once for 'train' and once for
            'dev'.

    Side effects: the model is updated in place; the best model and the
    checkpoints are written under the paths derived from ``params``, and the
    model as it stands after the last step is saved to a separate
    ``*_fine_tuned_all_steps_v6`` directory.
    """
    # The encoder to tune: whatever semantic_search currently holds.
    model = semantic_search.encoder_model

    train_examples = get_examples_fn(split = 'train')
    dev_examples = get_examples_fn(split = 'dev')

    print('train_examples sz =', len(train_examples))
    print('dev_examples sz =', len(dev_examples))

    for example in train_examples[:10]:
        print(example)
        print()

    # Define the train dataloader and the train loss. The dev examples are
    # consumed by the evaluator below, so they need no dataloader.
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
    train_loss = losses.CosineSimilarityLoss(model)

    eval_sent_1 = [e.texts[0] for e in dev_examples]
    eval_sent_2 = [e.texts[1] for e in dev_examples]
    eval_scores = [e.label for e in dev_examples]

    evaluator = st_evaluation.EmbeddingSimilarityEvaluator(eval_sent_1, eval_sent_2, eval_scores)

    print('evaluating pre-trained model')
    results = model.evaluate(evaluator)
    print('results = ', results)

    print('fine-tunning model')
    # Tune the model

    model_trained_path = params.encoder_model_path % params.sent_trans_name
    checkpoint_path = params.encoder_checkpoint_path % params.sent_trans_name

    model.fit(train_objectives=[(train_dataloader, train_loss)],
              epochs=config.TRAIN_EPOCHS, warmup_steps=1000,
              save_best_model=True, output_path=model_trained_path,
              optimizer_params = {'lr': config.LEARNING_RATE},
              evaluator=evaluator, evaluation_steps=500,
              checkpoint_path=checkpoint_path,
              checkpoint_save_total_limit = 2,
              )
    # Also keep the weights as they are after the final training step.
    model_trained_path_final = '../data/arc_entail/models/%s_fine_tuned_all_steps_v6' % params.sent_trans_name
    model.save(model_trained_path_final)

    print('evaluating fine-tuned model')
    results = model.evaluate(evaluator)
    print('results = ', results)

In [None]:
# Fine-tunes semantic_search.encoder_model in place and saves it under the paths derived from params.
train_sent_trans(get_examples_fn = get_dataset_examples)

# Retrieval Evaluation

## EntailmentWriter evaluation

In [None]:
# Prints retrieval metrics for the sentences stored with the task_3 test datapoints.
test_task_3_paper_recall()

## Retrieval evaluation (single)

In [None]:
# Only the printed metrics are of interest here; the returned tuple is discarded.
_, _, _, _ = test_sent_transformer_recall(split = 'test', verbose = False)

## Multi-step retrieval evaluation (conditional)

In [None]:
def test_sent_transformer_recall(split = 'test', verbose = True, top_k = 25, keep_top_from_hyp = 15):
    """Multi-step (conditional) retrieval evaluation on task_1 ``split``.

    Stage 1 runs ``top_k - keep_top_from_hyp`` rounds. Each round queries the
    corpus with "question + answer" followed by every sentence picked so far,
    and keeps the best hit that was not picked before. Stage 2 fills the
    remaining slots, up to ``top_k`` per datapoint, with the best hits for the
    hypothesis alone.

    This definition replaces the single-shot function of the same name, so it
    returns the same tuple; ``get_dataset_examples`` unpacks that result.

    Args:
        split: task_1 partition to evaluate.
        verbose: forwarded to ``compute_retrieval_metrics``.
        top_k: total number of sentences retrieved per datapoint.
        keep_top_from_hyp: number of slots reserved for hypothesis-only hits.

    Returns:
        The ``(recall, correct_retrieved_lst, errors_lst, missing_lst)`` tuple
        of ``compute_retrieval_metrics``.
    """
    t1_data = entail_dataset.data['task_1'][split]

    ret_data = [[] for _ in t1_data]
    # Stage 1 probe; the hypothesis is used on its own in stage 2.
    probes = [t1['question'] + ' ' + t1['answer'] for t1 in t1_data]

    for k_step in range(1, top_k - keep_top_from_hyp + 1):
        temp_ret_data = semantic_search.search(probes, top_k = k_step)
        for ret_it, rets in enumerate(temp_ret_data):
            for ret in rets:
                if ret not in ret_data[ret_it]:
                    ret_data[ret_it].append(ret)
                    # condition the next round on what was just picked
                    probes[ret_it] += ' ' + ret
                    break

    # now gather "keep_top_from_hyp" by using only hypothesis as probe;
    # over-fetch (top_k * 3) so enough unseen sentences are available
    probes = [t1['hypothesis'] for t1 in t1_data]
    temp_ret_data = semantic_search.search(probes, top_k = top_k * 3)
    for ret_it, rets in enumerate(temp_ret_data):
        for ret in rets:
            # check before appending so a list that is already full is never overshot
            if len(ret_data[ret_it]) >= top_k:
                break
            if ret not in ret_data[ret_it]:
                ret_data[ret_it].append(ret)

    assert all([len(x) == top_k for x in ret_data])

    return compute_retrieval_metrics(ret_data, split = split, verbose = verbose)

# The result is unpacked so the (large) returned tuple is not echoed as cell output.
_, _, _, _ = test_sent_transformer_recall(split = 'test', verbose = False, top_k = 25)