# Setup

In [None]:
# virtualenv venv -p /usr/bin/python3.7
# source venv/bin/activate
# pip3 install torch==1.8.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
!pip3 install transformers sentencepiece

In [None]:
!mkdir "data"
%cd "data"
!wget "https://aristo-data-public.s3.amazonaws.com/proofwriter/proofwriter-dataset-V2020.12.3.zip"
# wget "https://drive.google.com/uc?export=download&id=1kVr-YsUVFisceiIklvpWEe0kHNSIFtNh"
# mv "uc?export=download&id=1kVr-YsUVFisceiIklvpWEe0kHNSIFtNh" "entailment_trees_emnlp2021_data_v3.zip"
!unzip "proofwriter-dataset-V2020.12.3.zip"
# unzip "entailment_trees_emnlp2021_data_v3.zip"
%cd ".."

In [None]:
# pip3 install datasets numpy gsutil
# Copy the checkpoint directory rr_owa_d3plus_infstage_ma1_mixture_large_hf from
# the AI2 GCS bucket into the current directory (needs gsutil; see the commented
# install line above). Presumably a T5 model converted to Hugging Face format,
# judging by the "hf-conversions" path — TODO confirm.
!gsutil cp -r gs://ai2-oyvindt/t5-models/hf-conversions/rr_owa_d3plus_infstage_ma1_mixture_large_hf .

In [None]:
!python -m spacy download en_core_web_md
!python -m spacy download en_core_web_lg

# Imports

In [5]:
import gc
import random
import os
import re
import time
import datetime
import json
import torch
import spacy
import numpy as np
import transformers
import joblib
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration
from torch.nn import Module
from sklearn import tree, metrics, model_selection
from matplotlib import pyplot as plt

In [6]:
# Render matplotlib figures (e.g. GuidanceTree.draw) inline in the notebook.
%matplotlib inline

# Utils

In [7]:
# Reference time for the "elapsed" stamp printed by my_print (reset with reset=True).
utils_since = time.time()

def clear_cache():
    """Free unreferenced Python objects and PyTorch's cached GPU memory.

    The garbage collector runs first, so tensors kept alive only by reference
    cycles are released. Then the unused blocks held by PyTorch's CUDA caching
    allocator are returned. `torch.cuda.empty_cache` is a no-op when CUDA has not
    been initialised, so this is also safe on CPU-only machines.
    """
    gc.collect()
    torch.cuda.empty_cache()

def set_device(gpu):
    """Pick the torch device string.

    Returns 'cuda' only when `gpu` is truthy and a CUDA device is available,
    otherwise 'cpu'.
    """
    use_cuda = bool(gpu) and torch.cuda.is_available()
    return 'cuda' if use_cuda else 'cpu'

def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch with the same value.

    When CUDA is available, all GPUs are seeded too and cuDNN is switched to
    deterministic kernels. The seed is returned unchanged.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return seed
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(seed)
    return seed

def my_print(header=None, text=None, mirror=None, timestamp=True, reset=False):
    """Print one log line of the form "<header> -> <text> | elapsed: H:MM:SS".

    Each part is omitted when its argument is None (header/text) or False
    (timestamp). The elapsed time is measured from the module-level
    `utils_since`. With reset=True the clock restarts first and, if `mirror` is
    given, that file is emptied. When `mirror` is given, the line is also
    appended to that file.
    """
    global utils_since
    if reset:
        utils_since = time.time()
        if mirror is not None:
            # Write mode already empties the file; truncate(0) makes it explicit.
            with open(mirror, 'wt') as log_file:
                log_file.truncate(0)
    line = "" if header is None else f"{header} -> "
    if text is not None:
        line += text
    if timestamp:
        separator = "" if text is None else " | "
        elapsed = datetime.timedelta(seconds=round(time.time() - utils_since))
        line += f"{separator}elapsed: {elapsed}"
    print(line)
    if mirror is not None:
        with open(mirror, 'at') as log_file:
            print(line, file=log_file)

def save_model(model, tokenizer, optimizer, label, metadata):
    path = os.path.join(metadata['SAVE_PATH'], f"{metadata['MODEL_NAME']}_{label}")
    model.save_pretrained(path)
    tokenizer.save_pretrained(path)
    torch.save(optimizer.state_dict(), os.path.join(path, 'optim.pt'))

def load_model(label, metadata):
    path = os.path.join(metadata['SAVE_PATH'], f"{metadata['MODEL_NAME']}_{label}")
    model = T5ForConditionalGeneration.from_pretrained(path, local_files_only=True)
    tokenizer = T5Tokenizer.from_pretrained(path, local_files_only=True)
    optimizer = transformers.Adafactor(params=model.parameters())
    optimizer.load_state_dict(torch.load(os.path.join(path, 'optim.pt')))
    return model, tokenizer, optimizer

# Datasets

In [8]:
def split_prediction(prediction):
    """Split a generated "$answer$ = A ; $proof$ = P" string into (A, P).

    The answer is the text after the last '=' of the first ';'-separated field.
    If it is empty and there are at least three fields, the second field is used
    instead. The proof always comes from the last field.
    """
    fields = prediction.split(';')

    def value_of(field):
        return field.split('=')[-1].strip()

    answer = value_of(fields[0])
    if not answer and len(fields) >= 3:
        answer = value_of(fields[1])
    return answer, value_of(fields[-1])

def check_prediction(prediction, possible_targets):
    """Check whether the parsed (answer, proof) of `prediction` is an allowed target.

    `possible_targets` encodes the options as "answer+proof|answer+proof|...".
    """
    candidate = list(split_prediction(prediction))
    return any(candidate == option.split('+') for option in possible_targets.split('|'))

class ProofStageDataset(Dataset):
    """Single-hop inference examples read from a JSON-lines file.

    Every line of metadata['DATASET_HOME']/<split> is parsed with json.loads.
    Each record is accessed as having 'triples' and 'rules' (dicts of
    {id: {'text': ...}}) and 'allInferences' (list of {'text': ..., 'proofs': ...});
    this schema is inferred from the accesses below — TODO confirm against the data.
    __getitem__ is randomised: every access reshuffles the theory and picks one
    (answer, proof) target at random.
    """

    def __init__(self, tokenizer, split, metadata, limit=None, offset=None, alt_format=False, check_length=False, legacy_random=False, legacy_bias=False):
        # limit keeps the first `limit` records; offset is applied afterwards and
        # drops the first `offset` of those (records [offset:limit] remain).
        # alt_format selects the order of sentence tags in the proof string.
        # check_length makes __getitem__ assert that the last source/target token
        #   id is 0 — presumably the pad id, i.e. nothing was truncated (TODO confirm).
        # legacy_random / legacy_bias appear to reproduce an older shuffling and
        #   duplicate-sentence tagging scheme — NOTE(review): confirm intent.
        self.max_source_length = metadata['MAX_SOURCE_LENGTH']
        self.max_target_length = metadata['MAX_TARGET_LENGTH']
        self.tokenizer = tokenizer
        with open(os.path.join(metadata['DATASET_HOME'], split), 'rt') as data:
            self.data = [json.loads(line) for line in data]
        if limit is not None:
            self.data = self.data[:limit]
        if offset is not None:
            self.data = self.data[offset:]
        self.alt_format = alt_format
        self.check_length = check_length
        self.legacy_random = legacy_random
        self.legacy_bias = legacy_bias

    def __len__(self):
        # One item per record.
        return len(self.data)

    def __getitem__(self, index):
        """Build one tokenised (source, target) pair for record `index`.

        The theory (triples then rules) is shuffled and rendered as the context
        "sent1: ... sent2: ...". Every proof of every inference becomes a
        candidate target (answer text, "# ..." proof string over the new sentN
        tags). One candidate is chosen at random, or ("Nothing.", "None") when
        there are no inferences. Source and target are padded to
        max_length and truncated.

        Returns a dict with 'source_ids', 'source_mask', 'target_ids',
        'target_ids_y' (the target attention mask, despite the name),
        'possible_targets' (all candidates as "answer+proof|answer+proof|...")
        and 'source_text'.
        """
        question = "What is one singlehop inference?"
        if self.legacy_random:
            # Shuffle the sentence texts, then map each text back to the first
            # not-yet-used (id, text) pair with the same text.
            pre_theory = [t['text'] for l in [self.data[index]['triples'].values(),self.data[index]['rules'].values()] for t in l]
            random.shuffle(pre_theory)
            alt_theory = [(k,v['text']) for l in [self.data[index]['triples'].items(),self.data[index]['rules'].items()] for k, v in l]
            theory = []
            mask = [False for _ in pre_theory]
            for t1 in pre_theory:
                flag = False
                for i, t2 in enumerate(alt_theory):
                    if t2[1] == t1 and mask[i] == False and flag == False:
                        theory.append(t2)
                        mask[i] = True
                        flag = True
        else:
            # Shuffle the (id, text) pairs directly.
            theory = [(k,v['text']) for l in [self.data[index]['triples'].items(),self.data[index]['rules'].items()] for k, v in l]
            random.shuffle(theory)
        # Sentence i of the shuffled theory is called "sent{i+1}".
        context = ' '.join(f"sent{i+1}: {t[1]}" for i, t in enumerate(theory))
        targets = []
        inferences = self.data[index]['allInferences']
        for inf in inferences:
            ans = inf['text']
            # Alternative proofs are separated by the substring 'OR'.
            proofs = inf['proofs'].split('OR')
            for prf in proofs:
                # Space-split the proof, strip the characters - > ( ) [ ] and keep
                # the non-empty ids, then map each id to its sentN tag.
                pre_tags = list(filter(None, [re.sub(r"[->()\[\]]", '', s) for s in re.split(' ', prf)]))
                tags = [f"sent{i+1}" for x in pre_tags for i, t in enumerate(theory) if x == t[0]]
                if self.legacy_bias:
                    # For sentences with identical text, rewrite a tag of the later
                    # one (first occurrence) to the tag of the earlier one.
                    for i in range(0, len(theory)-1):
                        for j in range(i+1, len(theory)):
                            if theory[i][1] == theory[j][1] and f"sent{j+1}" in tags:
                                tags[tags.index(f"sent{j+1}")] = f"sent{i+1}"
                assert len(tags) >= 2 and len(tags) <= 3
                if self.alt_format:
                    format = f"# {tags[0]} {tags[1]}" if len(tags) == 2 else f"# {tags[0]} & {tags[1]} {tags[2]}"
                else:
                    format = f"# {tags[1]} {tags[0]}" if len(tags) == 2 else f"# {tags[2]} & {tags[0]} {tags[1]}"
                targets.append((ans,format))
        if len(targets) == 0:
            targets.append(("Nothing.","None"))
        if self.legacy_random:
            if len(inferences) > 0:
                # Draw an inference, then one of its proofs, and convert that to a
                # flat index into `targets` (number of proofs of earlier inferences
                # plus the proof index). This local `offset` is unrelated to the
                # constructor argument.
                answer_choice = random.randint(0, len(inferences)-1)
                proof_choice = random.randint(0, len(inferences[answer_choice]['proofs'].split('OR'))-1)
                offset = sum(1 for i in range(answer_choice) for prf in inferences[i]['proofs'].split('OR'))
                choice = offset + proof_choice
            else:
                choice = 0
        else:
            # Uniform choice over all candidate targets.
            choice = random.randint(0, len(targets)-1)
        answer = targets[choice][0]
        proof = targets[choice][1]
        source_text = f"$answer$ ; $proof$ ; $question$ = {question} ; $context$ = {context}"
        target_text = f"$answer$ = {answer} ; $proof$ = {proof}"
        source = self.tokenizer(
            source_text,
            padding = 'max_length',
            max_length = self.max_source_length,
            pad_to_max_length = True,
            truncation = True,
            return_tensors = 'pt',
        )
        target = self.tokenizer(
            target_text,
            padding = 'max_length',
            max_length = self.max_target_length,
            pad_to_max_length = True,
            truncation = True,
            return_tensors = 'pt',
        )
        if self.check_length:
            assert source['input_ids'].squeeze().to(dtype=torch.long)[-1] == 0
            assert target['input_ids'].squeeze().to(dtype=torch.long)[-1] == 0
        # NOTE(review): 'possible_targets' is split on '|' and '+' by
        # check_prediction, so answers containing those characters would not round-trip.
        return {
            'source_ids': source['input_ids'].squeeze().to(dtype=torch.long),
            'source_mask': source['attention_mask'].squeeze().to(dtype=torch.long),
            'target_ids': target['input_ids'].squeeze().to(dtype=torch.long),
            'target_ids_y': target['attention_mask'].squeeze().to(dtype=torch.long),
            'possible_targets': '|'.join(['+'.join(x) for x in targets]),
            'source_text': source_text,
        }

class ProofIterativeDataset(Dataset):
    """Question-answering items for iterative (multi-step) inference.

    Reads a JSON-lines file from metadata['DATASET_HOME']/<split>. Every question
    of every record becomes one item (context, question, answer):

    * context  -- the record's triple and rule texts, shuffled and numbered
                  "sent1: ... sent2: ..." (a fresh shuffle per question);
    * question -- the question text with its last character replaced by '?';
    * answer   -- str() of the record's answer value.

    limit/offset select items [offset:limit] (limit is applied first, as in
    ProofStageDataset). offset is applied to the items only, so it is not
    counted twice.
    """

    def __init__(self, tokenizer, split, metadata, limit=None, offset=None):
        # The tokenizer and length limits are stored but not used by this class.
        self.max_source_length = metadata['MAX_SOURCE_LENGTH']
        self.max_target_length = metadata['MAX_TARGET_LENGTH']
        self.tokenizer = tokenizer
        with open(os.path.join(metadata['DATASET_HOME'], split), 'rt') as data:
            self.data = [json.loads(line) for line in data]
        if limit is not None:
            # Pre-truncating the records only saves work: if every record has at
            # least one question, the first `limit` items all come from the first
            # `limit` records.
            self.data = self.data[:limit]
        self.items = list()
        for record in self.data:
            theory = [v['text'] for v in record['triples'].values()] + [v['text'] for v in record['rules'].values()]
            questions = [(v['question'][:-1] + '?', str(v['answer'])) for v in record['questions'].values()]
            for question, answer in questions:
                random.shuffle(theory)
                context = ' '.join(f'sent{i+1}: {x}' for i, x in enumerate(theory))
                self.items.append((context, question, answer))
        if limit is not None:
            self.items = self.items[:limit]
        if offset is not None:
            self.items = self.items[offset:]

    def __len__(self):
        """Number of (context, question, answer) items."""
        return len(self.items)

    def __getitem__(self, index):
        """Return item `index` as a dict with 'context', 'question' and 'answer'."""
        context, question, answer = self.items[index]
        return {
            'context': context,
            'question': question,
            'answer': answer,
        }

# Training

In [9]:
def train_epoch(model, tokenizer, optimizer, loader, metadata):
    """Run one training epoch and return the mean per-batch loss as a float.

    Each batch must provide 'source_ids', 'source_mask' and 'target_ids'.
    Target padding positions (tokenizer.pad_token_id) are set to -100 so the
    loss ignores them. Only `labels` is passed to the model, because
    T5ForConditionalGeneration builds its decoder inputs by shifting the labels
    right behind the decoder start token. That is also how `generate` starts
    decoding. Shifting by hand (decoder_input_ids=y[:, :-1], labels=y[:, 1:])
    skips the start token and never trains the first target token.
    """
    device = metadata['DEVICE']
    model.to(device)
    model.train()
    cumulative_loss = 0.0
    for data in loader:
        ids = data['source_ids'].to(device, dtype=torch.long)
        mask = data['source_mask'].to(device, dtype=torch.long)
        # clone(): `.to` may return the batch tensor itself, which must not be edited.
        labels = data['target_ids'].to(device, dtype=torch.long).clone()
        labels[labels == tokenizer.pad_token_id] = -100
        outputs = model(input_ids=ids, attention_mask=mask, labels=labels)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() yields a plain float; summing the tensors would keep every
        # batch's autograd graph nodes referenced for the whole epoch.
        cumulative_loss += loss.item()
    return cumulative_loss / len(loader)

def test_epoch(model, tokenizer, loader, metadata, nb=4, rp=2.5, lp=0.5, es=False):
    model.to(metadata['DEVICE'])
    model.eval()
    predictions = []
    possible_targets = []
    with torch.no_grad():
        for data in loader:
            y = data['target_ids'].to(metadata['DEVICE'], dtype=torch.long)
            ids = data['source_ids'].to(metadata['DEVICE'], dtype=torch.long)
            mask = data['source_mask'].to(metadata['DEVICE'], dtype=torch.long)
            generated_ids = model.generate(
                input_ids = ids,
                attention_mask = mask,
                max_length = metadata['MAX_TARGET_LENGTH'],
                num_beams = nb,
                repetition_penalty = rp,
                length_penalty = lp,
                early_stopping = es,
            )
            preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
            predictions.extend(preds)
            possible_targets.extend(data['possible_targets'])
    results = [check_prediction(p, t) for p, t in zip(predictions, possible_targets)]
    accuracy = sum([x for x in results]) / len(results)
    return accuracy

def training_loop(model, tokenizer, optimizer, training_loader, validation_loader, testing_loader, metadata, save_best=True, save_last=True, base_accuracy=-1.0):
    """Train for metadata['EPOCHS'] epochs with checkpointing, then evaluate on test.

    After every epoch the training loss and validation accuracy are logged. The
    'last' checkpoint is overwritten when save_last is set. The 'best'
    checkpoint is overwritten when the validation accuracy strictly beats the
    best seen so far (starting at base_accuracy) and save_best is set. Finally
    the in-memory model and, with save_best, the reloaded 'best' checkpoint are
    both scored on testing_loader.

    Returns (last_model, best_model); best_model is None when save_best is False.
    """
    def log(header, text):
        my_print(header, text, mirror=metadata['MIRROR_OUTPUT'])

    best_accuracy = base_accuracy
    for epoch in range(1, metadata['EPOCHS'] + 1):
        loss = train_epoch(model, tokenizer, optimizer, training_loader, metadata)
        log("TRAIN", f"epoch: {epoch} | loss: {loss:.5f}")
        accuracy = test_epoch(model, tokenizer, validation_loader, metadata)
        log("VAL  ", f"epoch: {epoch} | accuracy: {accuracy:.3f}")
        if save_last:
            save_model(model, tokenizer, optimizer, 'last', metadata)
            log("SAVE ", "model: last")
        if not accuracy > best_accuracy:
            continue
        if save_best:
            save_model(model, tokenizer, optimizer, 'best', metadata)
            log("SAVE ", "model: best")
        best_accuracy = accuracy
    last_model = model
    test_accuracy = test_epoch(last_model, tokenizer, testing_loader, metadata)
    log("TEST ", f"model: last | accuracy: {test_accuracy:.3f}")
    best_model = None
    if save_best:
        best_model, _, _ = load_model('best', metadata)
        test_accuracy = test_epoch(best_model, tokenizer, testing_loader, metadata)
        log("TEST ", f"model: best | accuracy: {test_accuracy:.3f}")
    return last_model, best_model

def single_inference(model, tokenizer, context, metadata, nb=4, rp=2.5, lp=0.5, es=False):
    """Ask the model for single-hop inferences over one `context` string.

    Beam search keeps `nb` beams and all of them are returned. The result is a
    list of (decoded_text, confidence) tuples in generation order, where the
    confidences are a softmax over the beams' sequence scores (they sum to 1).
    """
    device = metadata['DEVICE']
    model.to(device)
    model.eval()
    prompt = f"$answer$ ; $proof$ ; $question$ = What is one singlehop inference? ; $context$ = {context}"
    with torch.no_grad():
        encoded = tokenizer(
            prompt,
            padding='max_length',
            max_length=metadata['MAX_SOURCE_LENGTH'],
            pad_to_max_length=True,
            truncation=True,
            return_tensors='pt',
        )
        output = model.generate(
            input_ids=encoded['input_ids'].to(device, dtype=torch.long),
            attention_mask=encoded['attention_mask'].to(device, dtype=torch.long),
            max_length=metadata['MAX_TARGET_LENGTH'],
            num_beams=nb,
            repetition_penalty=rp,
            length_penalty=lp,
            early_stopping=es,
            num_return_sequences=nb,
            output_scores=True,
            return_dict_in_generate=True,
        )
        texts = [tokenizer.decode(seq, skip_special_tokens=True, clean_up_tokenization_spaces=True) for seq in output.sequences]
        confidences = torch.softmax(output.sequences_scores, dim=-1).tolist()
    return list(zip(texts, confidences))

def check_ids_length(tokenizer, metadata):
    """Iterate once over the train/val/test ProofStageDataset splits with check_length=True.

    With check_length the dataset asserts that the last token id of every
    source and target is 0, so this raises on any sample whose text presumably
    filled (and was truncated to) the maximum length.
    """
    loaders = [
        DataLoader(
            ProofStageDataset(tokenizer, metadata[split_key], metadata, check_length=True),
            batch_size=metadata['BATCH_SIZE'],
            shuffle=False,
        )
        for split_key in ('TRAIN_DATASET', 'VAL_DATASET', 'TEST_DATASET')
    ]
    for loader in loaders:
        for _ in loader:
            pass

# Inference

In [10]:
def tokenize_sentence(string, lemmifier, pattern='[^a-zA-Z0-9àèéìòùÀÈÉÌÒÙ ]'):
    stripped_sentence = re.sub(pattern, '', string).lower()
    tokenized_sentence = [t.lemma_ for t in lemmifier(stripped_sentence)]
    return tokenized_sentence

def search_answer(question, theory, spacy_module='en_core_web_md'):
    lemmifier = spacy.load(spacy_module)
    answer = 'Unknown'
    proof = 'None'
    tokenized_question = tokenize_sentence(question, lemmifier)
    opposite_questions = list()
    if 'not' in tokenized_question:
        not_index = tokenized_question.index('not')
        opposite_question_1 = tokenized_question.copy()
        opposite_question_1.remove('not')
        opposite_questions.append(opposite_question_1)
        if not_index > 0 and tokenized_question[not_index-1] == 'do':
            opposite_question_2 = opposite_question_1.copy()
            opposite_question_2.pop(not_index-1)
            opposite_questions.append(opposite_question_2)
    tokenized_theory = [tokenize_sentence(t[1], lemmifier) for t in theory]
    for i, t in enumerate(tokenized_theory):
        if t == tokenized_question:
            answer = 'True'
            proof = theory[i][0]
        elif t in opposite_questions:
            answer = 'False'
            proof = theory[i][0]
        elif 'not' in t:
            not_index = t.index('not')
            opposite_t = t.copy()
            opposite_t.remove('not')
            if opposite_t == tokenized_question:
                answer = 'False'
                proof = theory[i][0]
            elif not_index > 0 and opposite_t[not_index-1] == 'do':
                opposite_t.pop(not_index-1)
                if opposite_t == tokenized_question:
                    answer = 'False'
                    proof = theory[i][0]
    return answer, proof

def single_iterative_inference(model, tokenizer, metadata, file_name):
    """Forward-chain over a hand-written problem file, then answer its question.

    File format: blank lines and lines starting with '#' are ignored. The first
    remaining line is the question; every later line is a theory sentence.
    The model is queried repeatedly with single_inference (top beam only), and
    each new fact is appended to the context as the next "sentN". The loop stops
    when the model answers "Nothing." or proposes a fact that is already in the
    theory. This is the same stop rule as iterative_test; without it, a repeated
    fact could keep the loop running indefinitely.

    Returns (theory, question, answer, proof). theory is a list of
    (sent_id, text, proof) tuples, with "None" as proof for given sentences;
    answer and proof come from search_answer.
    """
    with open(file_name, 'r') as input_file:
        lines = input_file.read().splitlines()
    question = None
    theory = []
    input_data = ''
    n = 1
    for line in lines:
        if line == '' or line[0] == '#':
            continue
        if question is None:
            question = line
            continue
        theory.append((f"sent{n}", line, "None"))
        input_data += f" sent{n}: {line}"
        n += 1
    while True:
        results = single_inference(model, tokenizer, input_data, metadata)
        fact, proof = split_prediction(results[0][0])
        if fact == 'Nothing.' or fact in {t[1] for t in theory}:
            break
        input_data += f" sent{n}: {fact}"
        theory.append((f"sent{n}", fact, proof))
        n += 1
    answer, proof = search_answer(question, theory)
    return theory, question, answer, proof

def iterative_test(model, tokenizer, loader, metadata, nb=4, rp=2.5, lp=0.5, es=False, verbose=False):
    """Accuracy of iterative inference followed by search_answer.

    Only element 0 of each batch's 'context', 'question' and 'answer' is read,
    so the loader is presumably built with batch_size=1 — TODO confirm callers.
    For each item the model is queried repeatedly on the growing context. Every
    new fact is appended as the next "sentN" until the model returns "Nothing."
    or repeats a known fact. The final theory is searched with search_answer and
    the result compared with the gold answer string.

    With verbose=True, progress is logged every 10 items, and each wrong answer is
    printed together with the derived theory. Returns 0.0 for an empty loader.
    """
    device = metadata['DEVICE']
    model.to(device)
    model.eval()
    predictions = []
    targets = []
    with torch.no_grad():
        for n, data in enumerate(loader):
            if verbose and (n + 1) % 10 == 0:
                my_print("TEST ", f"n: {n+1}/{len(loader)}")
            context = data['context'][0]
            question = data['question'][0]
            answer = data['answer'][0]
            # Rebuild (id, text, proof) triples from the "sentN: text." segments;
            # splitting on '.' drops the periods, so they are added back.
            theory = [(t.split(':')[0].strip(), t.split(':')[1].strip() + '.', "None") for t in context.split('.')[:-1]]
            while True:
                source_text = f"$answer$ ; $proof$ ; $question$ = What is one singlehop inference? ; $context$ = {context}"
                source = tokenizer(
                    source_text,
                    padding = 'max_length',
                    max_length = metadata['MAX_SOURCE_LENGTH'],
                    pad_to_max_length = True,
                    truncation = True,
                    return_tensors = 'pt',
                )
                generated_ids = model.generate(
                    input_ids = source['input_ids'].to(device, dtype=torch.long),
                    attention_mask = source['attention_mask'].to(device, dtype=torch.long),
                    max_length = metadata['MAX_TARGET_LENGTH'],
                    num_beams = nb,
                    repetition_penalty = rp,
                    length_penalty = lp,
                    early_stopping = es,
                )
                # Only the first generated sequence is used, so only it is decoded.
                top = tokenizer.decode(generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
                fact, proof = split_prediction(top)
                # Stop on "Nothing." or on a fact that is already known.
                if fact == 'Nothing.' or fact in {t[1] for t in theory}:
                    break
                number = len(theory) + 1
                theory.append((f'sent{number}', fact, proof))
                context += f' sent{number}: {fact}'
            prediction, proof = search_answer(question, theory)
            predictions.append(prediction)
            targets.append(answer)
            if verbose and prediction != answer:
                my_print("WRONG", f"question: {question} | prediction: {prediction} ({proof}) | answer: {answer}")
                for t in theory:
                    my_print("     ", f"{t[0]}: {t[1]} ({t[2]})", timestamp=False)
    results = [p == t for p, t in zip(predictions, targets)]
    if not results:
        return 0.0
    return sum(results) / len(results)

def multibeam_test(model, tokenizer, loader, metadata, nb=4, rp=2.5, lp=0.5, es=False):
    """Top-`nb` accuracy of beam search over `loader`.

    generate is asked for all nb beams (num_return_sequences=nb), which come
    back as nb consecutive sequences per input. A sample counts as correct when
    any of its beams matches one of its 'possible_targets' according to
    check_prediction. nb, rp, lp and es map to num_beams, repetition_penalty,
    length_penalty and early_stopping. Returns 0.0 for an empty loader.
    """
    device = metadata['DEVICE']
    model.to(device)
    model.eval()
    results = []
    with torch.no_grad():
        for data in loader:
            ids = data['source_ids'].to(device, dtype=torch.long)
            mask = data['source_mask'].to(device, dtype=torch.long)
            generated_ids = model.generate(
                input_ids=ids,
                attention_mask=mask,
                max_length=metadata['MAX_TARGET_LENGTH'],
                num_beams=nb,
                repetition_penalty=rp,
                length_penalty=lp,
                early_stopping=es,
                num_return_sequences=nb,
            )
            preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
            chunks = [preds[i:i + nb] for i in range(0, len(preds), nb)]
            for beams, targets in zip(chunks, data['possible_targets']):
                results.append(any(check_prediction(beam, targets) for beam in beams))
    if not results:
        return 0.0
    return sum(results) / len(results)

# Guidance

In [11]:
# spaCy pipelines used by get_tags below; each sentence is tagged by all three.
# They are loaded once, when this cell runs, and must already be installed
# (python -m spacy download en_core_web_sm / en_core_web_md / en_core_web_lg).
tree_nlp_sm = spacy.load('en_core_web_sm')
tree_nlp_md = spacy.load('en_core_web_md')
tree_nlp_lg = spacy.load('en_core_web_lg')

def get_tags(sentence, filter):
    """Collect the distinct lemmas whose POS tag is in `filter`.

    The sentence is run through the sm, md and lg pipelines and their tokens are
    pooled, so a lemma counts if any of the three models tags it with an
    accepted POS. Returns (lemmas as a list, number of distinct lemmas), or
    ([], 0) when `sentence` is None.
    """
    if sentence is None:
        return [], 0
    lemmas = set()
    for pipeline in (tree_nlp_sm, tree_nlp_md, tree_nlp_lg):
        for token in pipeline(sentence):
            if token.pos_ in filter:
                lemmas.add(token.lemma_)
    return list(lemmas), len(lemmas)

def _strip_period(text):
    """Drop one trailing '.' so answers compare equal to the period-less theory texts."""
    return text[:-1] if text.endswith('.') else text

def _parse_proof(text):
    """Parse a generated proof string into sentence ids.

    "# r & f1 f2" -> (r, f1, f2); "# r f" -> (r, f); "None" -> 'None';
    anything else -> ''.
    """
    tokens = text.split(' ')
    if len(tokens) == 5:
        return (tokens[1], tokens[3], tokens[4])
    if len(tokens) == 3:
        return (tokens[1], tokens[2])
    if len(tokens) == 1 and tokens[0] == 'None':
        return tokens[0]
    return ''

def unpack_beams(beams, source):
    """Parse generated beams and the prompt they were generated from.

    source: the prompt string. Everything after its last '=' is the context
        "sent1: text. sent2: text. ...", parsed into {sent_id: text} without the
        final periods.
    beams: generated strings of the form "$answer$ = A ; $proof$ = P".

    Returns (candidates, theory). candidates has one (answer, proof) per beam.
    The answer loses one trailing '.' on both the normal and the fallback path
    (answer in the second field, as in split_prediction), so it compares
    consistently with the theory texts and with 'Nothing'. The proof is parsed
    by _parse_proof.
    """
    context_sentences = source.split('=')[-1].split('.')
    theory = {t.split(':')[0].strip(): t.split(':')[1].strip() for t in context_sentences[:-1]}
    candidates = []
    for beam in beams:
        fields = beam.split(';')
        answer = _strip_period(fields[0].split('=')[-1].strip())
        if answer == '' and len(fields) >= 3:
            answer = _strip_period(fields[1].split('=')[-1].strip())
        proof = _parse_proof(fields[-1].split('=')[-1].strip())
        candidates.append((answer, proof))
    return candidates, theory

def compute_plausibility(rule, answer, fact1, fact2=None, filter=['ADJ', 'VERB']):
    """Heuristic check that a proposed inference shares content words with its rule.

    Every text is reduced to its lemma set via get_tags (POS in `filter`). For
    the answer and each fact, the number of its lemmas that also occur among
    the rule's lemmas is counted. The step is plausible when the answer and
    fact1 both share at least one lemma with the rule and, if fact2 is given,
    fact2 does too.

    Returns (plausible, matches_answer, matches_fact1, matches_fact2); the last
    is 0 when fact2 is None.
    """
    rule_tags, _ = get_tags(rule, filter)

    def overlap(text):
        tags, _ = get_tags(text, filter)
        return sum(1 for tag in tags if tag in rule_tags)

    matches_a = overlap(answer)
    matches_f1 = overlap(fact1)
    matches_f2 = overlap(fact2)
    if matches_a <= 0 or matches_f1 <= 0:
        score = False
    elif fact2 is None:
        score = True
    else:
        score = matches_f2 > 0
    return score, matches_a, matches_f1, matches_f2

def compute_matches(beams, source):
    """Classify each beam's (answer, proof) candidate for GuidanceTree.

    Returns (matches, reasonable), with one entry per beam:
      * 'None'      -- the stop answer ("Nothing" with proof 'None'); matches [].
      * 'Yes'/'No'  -- a proof citing 2 or 3 distinct known sentence ids whose
                       answer is not already a theory sentence. 'Yes' when
                       compute_plausibility accepts it; matches holds the
                       [answer, fact1, fact2] overlap counts, with fact2 = -1
                       for two-id proofs.
      * 'Invalid'   -- anything else; matches [].
    The first id of a proof is used as the rule and the rest as facts.
    """
    candidates, theory = unpack_beams(beams, source)
    reasonable = []
    matches = []
    for answer, proof in candidates:
        if answer == 'Nothing' and proof == 'None':
            verdict, counts = 'None', []
        elif (len(proof) in (2, 3)
              and len(set(proof)) == len(proof)
              and all(sent_id in theory for sent_id in proof)
              and answer not in theory.values()):
            rule, *facts = (theory[sent_id] for sent_id in proof)
            if len(facts) == 1:
                plausible, m_answer, m_fact1, _ = compute_plausibility(rule, answer, facts[0], None)
                counts = [m_answer, m_fact1, -1]
            else:
                plausible, m_answer, m_fact1, m_fact2 = compute_plausibility(rule, answer, facts[0], facts[1])
                counts = [m_answer, m_fact1, m_fact2]
            verdict = 'Yes' if plausible else 'No'
        else:
            verdict, counts = 'Invalid', []
        reasonable.append(verdict)
        matches.append(counts)
    return matches, reasonable

class GuidanceTree:
    """Decision tree that decides whether to continue past a "Nothing" stop beam.

    When the first usable beam (per compute_matches) is the stop answer and the
    very next beam is a plausible inference ('Yes'), the tree chooses between
    stopping (class 0, 'STOP') and taking the next beam (class 1, 'GOON').
    Feature rows are [conf_stop, conf_goon, conf_stop - conf_goon,
    match_answ, match_fac1, match_fac2].
    """

    def __init__(self, model, tokenizer, metadata, load=None, verbose=False):
        self.model = model
        self.tokenizer = tokenizer
        self.metadata = metadata
        # joblib.load unpickles the file: only load trusted files. Without `load`
        # the attribute still exists (None) until learn() fits a tree.
        self.tree = joblib.load(load) if load is not None else None
        self.verbose = verbose

    def save(self, save='tree.joblib'):
        """Persist the fitted tree to `save` with joblib."""
        joblib.dump(self.tree, save)

    @staticmethod
    def _features(confidences, matches, index):
        """Feature row comparing beam `index` (stop) with beam `index + 1` (go on)."""
        return [confidences[index], confidences[index + 1], confidences[index] - confidences[index + 1]] + list(matches[index + 1])

    @staticmethod
    def _guidance_index(reasonable):
        """Index of a leading 'None' beam directly followed by a 'Yes' beam, else None.

        Only the first beam whose verdict is 'Yes' or 'None' is considered.
        """
        for i, verdict in enumerate(reasonable):
            if verdict == 'Yes':
                return None
            if verdict == 'None':
                if i + 1 < len(reasonable) and reasonable[i + 1] == 'Yes':
                    return i
                return None
        return None

    def _tune(self, train_X, train_y, val_X, val_y, cross_val):
        """Choose max_depth in 1..15 and return a tree fitted on train_X/train_y.

        Depths are scored by 5-fold cross-validation on the training rows when
        cross_val is set, otherwise by accuracy on the validation rows. Ties go
        to the smallest depth.
        """
        avg_scores = list()
        for md in range(1, 15 + 1):
            candidate = tree.DecisionTreeClassifier(criterion='gini', max_depth=md)
            if cross_val:
                fold_scores = model_selection.cross_val_score(candidate, train_X, train_y, scoring='accuracy', cv=5)
                avg_scores.append(np.mean(fold_scores))
            else:
                candidate.fit(train_X, train_y)
                avg_scores.append(metrics.accuracy_score(val_y, candidate.predict(val_X)))
        md = int(np.argmax(avg_scores)) + 1
        if self.verbose:
            my_print("TUNE ", f"md: {md} | avg_scores: {avg_scores}")
        best = tree.DecisionTreeClassifier(criterion='gini', max_depth=md)
        best.fit(train_X, train_y)
        return best

    def learn(self, loader, nb=4, cross_val=True):
        """Collect STOP/GOON rows from `loader` and fit the tree.

        For each sample, nb beams are generated and classified with
        compute_matches. A row is emitted only when the first usable beam is
        the stop answer and the next beam is 'Yes'. It is labelled 1 if that
        next beam is a correct prediction (check_prediction) and 0 otherwise.
        With cross_val, all rows are used for training and cross-validated
        tuning. Otherwise rows alternate between a training and a validation
        split.
        """
        train_X = list()
        train_y = list()
        val_X = list()
        val_y = list()
        device = self.metadata['DEVICE']
        self.model.to(device)
        self.model.eval()
        with torch.no_grad():
            for n, batch in enumerate(loader):
                if self.verbose and (n + 1) % 10 == 0:
                    my_print("LEARN", f"n: {n+1}/{len(loader)} | samples: {len(train_y)+len(val_y)}")
                ids = batch['source_ids'].to(device, dtype=torch.long)
                mask = batch['source_mask'].to(device, dtype=torch.long)
                generated_ids = self.model.generate(
                    input_ids = ids,
                    attention_mask = mask,
                    max_length = self.metadata['MAX_TARGET_LENGTH'],
                    num_beams = nb,
                    repetition_penalty = 2.5,
                    length_penalty = 0.5,
                    early_stopping = False,
                    num_return_sequences = nb,
                    output_scores = True,
                    return_dict_in_generate = True,
                )
                preds = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids.sequences]
                beams = [preds[i:i+nb] for i in range(0, len(preds), nb)]
                beam_scores = [generated_ids.sequences_scores[i:i+nb] for i in range(0, len(generated_ids.sequences_scores), nb)]
                # Softmax over each input's nb beam scores -> relative confidences.
                confidences = [torch.softmax(s, dim=-1).tolist() for s in beam_scores]
                for b, c, t, s in zip(beams, confidences, batch['possible_targets'], batch['source_text']):
                    matches, reasonable = compute_matches(b, s)
                    found = self._guidance_index(reasonable)
                    if found is None:
                        continue
                    features = self._features(c, matches, found)
                    label = 1 if check_prediction(b[found + 1], t) else 0
                    # Without cross-validation, alternate rows by the total number
                    # collected so far; counting only training rows would send
                    # every row after the first to validation.
                    if cross_val or (len(train_y) + len(val_y)) % 2 == 0:
                        train_X.append(features)
                        train_y.append(label)
                    else:
                        val_X.append(features)
                        val_y.append(label)
        self.tree = self._tune(train_X, train_y, val_X, val_y, cross_val)

    def infer(self, confidences, matches, index):
        """Return `index` if the tree predicts STOP (0) for beams index/index+1, else `index + 1`."""
        prediction = self.tree.predict([self._features(confidences, matches, index)])
        return index if prediction[0] == 0 else index + 1

    def draw(self):
        """Plot the fitted tree with labelled features and classes."""
        plt.figure(figsize=(20,20))
        tree.plot_tree(
            self.tree,
            fontsize = 10,
            filled = True,
            rounded = True,
            proportion = True,
            feature_names = ['conf_stop','conf_goon', 'conf_diff', 'match_answ','match_fac1','match_fac2'],
            class_names = ['STOP','GOON'],
        )
        plt.show()

In [12]:
def guided_test(model, tokenizer, pilot, loader, metadata, nb=4, rp=2.5, lp=0.5, es=False, verbose=False):
    """Evaluate single-step inference accuracy with pilot-guided beam selection.

    For every example the model returns ``nb`` beams. The sequence scores of
    each example's beams are softmax-normalised into confidences,
    ``compute_matches`` labels every beam, and the first acceptable beam wins:

    * a beam labelled ``'Yes'`` is taken immediately;
    * a beam labelled ``'None'`` followed by a ``'Yes'`` beam is arbitrated by
      ``pilot.infer`` (keep this beam or move to the next one);
    * any other ``'None'`` beam is taken as-is;
    * beams with any other label are skipped.

    If no beam qualifies, the top beam is used. The chosen text is judged by
    ``check_prediction`` against the example's ``possible_targets``.

    Args:
        model: seq2seq model exposing ``generate``; moved to ``metadata['DEVICE']``.
        tokenizer: tokenizer used to decode the generated ids.
        pilot: object whose ``infer(confidences, matches, index)`` returns a beam index.
        loader: iterable of batches with ``source_ids``, ``source_mask``,
            ``possible_targets`` and ``source_text`` entries.
        metadata: dict providing ``'DEVICE'`` and ``'MAX_TARGET_LENGTH'``.
        nb: number of beams (also the number of returned sequences per input).
        rp: repetition penalty passed to ``generate``.
        lp: length penalty passed to ``generate``.
        es: early-stopping flag passed to ``generate``.
        verbose: log progress every 10 batches and details of wrong predictions.

    Returns:
        Fraction of examples judged correct, or 0.0 when ``loader`` is empty.
    """
    device = metadata['DEVICE']
    model.to(device)
    model.eval()
    results = []
    with torch.no_grad():
        for n, data in enumerate(loader):
            if verbose and (n + 1) % 10 == 0:
                my_print("TEST ", f"n: {n+1}/{len(loader)}")
            # target_ids are not needed: correctness is judged on decoded text.
            ids = data['source_ids'].to(device, dtype=torch.long)
            mask = data['source_mask'].to(device, dtype=torch.long)
            generated_ids = model.generate(
                input_ids = ids,
                attention_mask = mask,
                max_length = metadata['MAX_TARGET_LENGTH'],
                num_beams = nb,
                repetition_penalty = rp,
                length_penalty = lp,
                early_stopping = es,
                num_return_sequences = nb,
                output_scores = True,
                return_dict_in_generate = True,
            )
            preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids.sequences]
            # Sequences come back as nb consecutive rows per input example.
            beams = [preds[i:i+nb] for i in range(0, len(preds), nb)]
            scores = [generated_ids.sequences_scores[i:i+nb] for i in range(0, len(generated_ids.sequences_scores), nb)]
            confidences = [torch.softmax(s, dim=-1).tolist() for s in scores]
            for b, c, t, s in zip(beams, confidences, data['possible_targets'], data['source_text']):
                matches, reasonable = compute_matches(b, s)
                prediction = None
                for i in range(len(b)):
                    if reasonable[i] == 'Yes':
                        prediction = b[i]
                    elif reasonable[i] == 'None':
                        if i + 1 < len(reasonable) and reasonable[i+1] == 'Yes':
                            # Let the guidance tree decide between beam i and i+1.
                            prediction = b[pilot.infer(c, matches, i)]
                        else:
                            prediction = b[i]
                    if prediction is not None:
                        break
                if prediction is None:
                    prediction = b[0]
                result = check_prediction(prediction, t)
                results.append(result)
                if verbose and not result:
                    my_print("WRONG", f"beam: {b}")
                    my_print("     ", f"conf: {c}", timestamp=False)
                    my_print("     ", f"mtch: {matches} | reas: {reasonable}", timestamp=False)
                    my_print("     ", f"pred: {prediction}", timestamp=False)
                    my_print("     ", f"targ: {t}", timestamp=False)
    # Guard against an empty loader instead of raising ZeroDivisionError.
    return sum(results) / len(results) if results else 0.0

def iterative_guided_test(model, tokenizer, pilot, loader, metadata, nb=4, rp=2.5, lp=0.5, es=False, verbose=False, max_steps=None):
    """Evaluate iterative forward chaining with pilot-guided beam selection.

    For each example (only element ``[0]`` of every batch field is read, so
    ``batch_size=1`` is presumably expected) the model is asked repeatedly for
    one single-hop inference over the current context. The selected beam is
    split into ``(fact, proof)`` by ``split_prediction``; a new fact is appended
    to both the theory and the context as ``sentN``. Chaining stops when the
    model answers ``'Nothing.'``, proposes an already-known fact, or
    ``max_steps`` generation rounds have run. The question is then answered
    with ``search_answer`` over the accumulated theory.

    Beam selection per round: a ``'Yes'`` beam is taken; a ``'None'`` beam
    followed by a ``'Yes'`` beam is arbitrated by ``pilot.infer``; other
    ``'None'`` beams are taken as-is; other labels are skipped; the top beam
    is the fallback.

    Args:
        model, tokenizer, pilot: generator, its tokenizer, and the guidance tree.
        loader: iterable of batches with ``context``, ``question`` and ``answer``.
        metadata: dict providing ``'DEVICE'``, ``'MAX_SOURCE_LENGTH'`` and
            ``'MAX_TARGET_LENGTH'``.
        nb, rp, lp, es: beams, repetition penalty, length penalty and early
            stopping passed to ``generate``.
        verbose: log progress every 10 examples and details of wrong answers.
        max_steps: optional cap on generation rounds per example; ``None``
            (default) keeps chaining until no new fact is produced.

    Returns:
        Fraction of examples whose predicted answer equals the gold answer,
        or 0.0 when ``loader`` yields nothing.
    """
    device = metadata['DEVICE']
    model.to(device)
    model.eval()
    predictions = []
    targets = []
    with torch.no_grad():
        for n, data in enumerate(loader):
            if verbose and (n + 1) % 10 == 0:
                my_print("TEST ", f"n: {n+1}/{len(loader)}")
            context = data['context'][0]
            question = data['question'][0]
            answer = data['answer'][0]
            # "sentK: text. sentL: text." -> [('sentK', 'text.', 'None'), ...]
            theory = [(t.split(':')[0].strip(), t.split(':')[1].strip()+'.', "None") for t in context.split('.')[:-1]]
            # Set mirror of theory texts for O(1) duplicate checks.
            known_facts = {entry[1] for entry in theory}
            # Diagnostics of the last round, reported for wrong answers.
            b = c = matches = reasonable = None
            fact = None
            steps = 0
            while fact != 'Nothing.' and (max_steps is None or steps < max_steps):
                steps += 1
                source_text = f"$answer$ ; $proof$ ; $question$ = What is one singlehop inference? ; $context$ = {context}"
                # padding='max_length' already pads; pad_to_max_length was redundant.
                source = tokenizer(
                    source_text,
                    padding = 'max_length',
                    max_length = metadata['MAX_SOURCE_LENGTH'],
                    truncation = True,
                    return_tensors = 'pt',
                )
                ids = source['input_ids'].to(device, dtype=torch.long)
                mask = source['attention_mask'].to(device, dtype=torch.long)
                generated_ids = model.generate(
                    input_ids = ids,
                    attention_mask = mask,
                    max_length = metadata['MAX_TARGET_LENGTH'],
                    num_beams = nb,
                    repetition_penalty = rp,
                    length_penalty = lp,
                    early_stopping = es,
                    num_return_sequences = nb,
                    output_scores = True,
                    return_dict_in_generate = True,
                )
                preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids.sequences]
                beams = [preds[i:i+nb] for i in range(0, len(preds), nb)]
                scores = [generated_ids.sequences_scores[i:i+nb] for i in range(0, len(generated_ids.sequences_scores), nb)]
                confidences = [torch.softmax(s, dim=-1).tolist() for s in scores]
                step_prediction = None
                for b, c in zip(beams, confidences):
                    matches, reasonable = compute_matches(b, source_text)
                    step_prediction = None
                    for i in range(len(b)):
                        if reasonable[i] == 'Yes':
                            step_prediction = b[i]
                        elif reasonable[i] == 'None':
                            if i + 1 < len(reasonable) and reasonable[i+1] == 'Yes':
                                step_prediction = b[pilot.infer(c, matches, i)]
                            else:
                                step_prediction = b[i]
                        if step_prediction is not None:
                            break
                    if step_prediction is None:
                        step_prediction = b[0]
                fact, proof = split_prediction(step_prediction)
                if fact != 'Nothing.':
                    if fact not in known_facts:
                        number = len(theory) + 1
                        theory.append((f'sent{number}', fact, proof))
                        known_facts.add(fact)
                        context += f' sent{number}: {fact}'
                    else:
                        # A repeated fact means no progress: stop chaining.
                        fact = 'Nothing.'
            prediction, proof = search_answer(question, theory)
            predictions.append(prediction)
            targets.append(answer)
            if verbose and prediction != answer:
                my_print("WRONG", f"question: {question} | prediction: {prediction} ({proof}) | answer: {answer}")
                for t in theory:
                    my_print("     ", f"{t[0]}: {t[1]} ({t[2]})", timestamp=False)
                my_print("     ", f"beam: {b} | conf: {c} | mtch: {matches} | reas: {reasonable}", timestamp=False)
    results = [p == t for p, t in zip(predictions, targets)]
    # Guard against an empty loader instead of raising ZeroDivisionError.
    return sum(results) / len(results) if results else 0.0

def single_guided_inference(model, tokenizer, pilot, context, metadata, nb=4, rp=2.5, lp=0.5, es=False):
    """Generate one single-hop inference for ``context`` with pilot-guided beam selection.

    The context is wrapped in the "What is one singlehop inference?" prompt,
    ``nb`` beams are generated, and their sequence scores are softmax-normalised
    into confidences. ``compute_matches`` labels each beam and the first
    acceptable one is returned: a ``'Yes'`` beam is taken; a ``'None'`` beam
    followed by a ``'Yes'`` beam is arbitrated by ``pilot.infer``; other
    ``'None'`` beams are taken as-is; other labels are skipped; the top beam
    is the fallback.

    Args:
        model: seq2seq model exposing ``generate``; moved to ``metadata['DEVICE']``.
        tokenizer: callable tokenizer that also provides ``decode``.
        pilot: object whose ``infer(confidences, matches, index)`` returns a beam index.
        context: theory text, e.g. ``"sent1: ... sent2: ..."``.
        metadata: dict providing ``'DEVICE'``, ``'MAX_SOURCE_LENGTH'`` and
            ``'MAX_TARGET_LENGTH'``.
        nb, rp, lp, es: beams, repetition penalty, length penalty and early
            stopping passed to ``generate``.

    Returns:
        The decoded text of the selected beam.
    """
    device = metadata['DEVICE']
    model.to(device)
    model.eval()
    with torch.no_grad():
        source_text = f"$answer$ ; $proof$ ; $question$ = What is one singlehop inference? ; $context$ = {context}"
        # padding='max_length' already pads; pad_to_max_length was redundant.
        source = tokenizer(
            source_text,
            padding = 'max_length',
            max_length = metadata['MAX_SOURCE_LENGTH'],
            truncation = True,
            return_tensors = 'pt',
        )
        ids = source['input_ids'].to(device, dtype=torch.long)
        mask = source['attention_mask'].to(device, dtype=torch.long)
        generated_ids = model.generate(
            input_ids = ids,
            attention_mask = mask,
            max_length = metadata['MAX_TARGET_LENGTH'],
            num_beams = nb,
            repetition_penalty = rp,
            length_penalty = lp,
            early_stopping = es,
            num_return_sequences = nb,
            output_scores = True,
            return_dict_in_generate = True,
        )
        preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids.sequences]
        # Single input, so all returned sequence scores belong to one example.
        scores = torch.softmax(generated_ids.sequences_scores, dim=-1).tolist()
        matches, reasonable = compute_matches(preds, source_text)
        prediction = None
        for i in range(len(preds)):
            if reasonable[i] == 'Yes':
                prediction = preds[i]
            elif reasonable[i] == 'None':
                if i + 1 < len(reasonable) and reasonable[i+1] == 'Yes':
                    prediction = preds[pilot.infer(scores, matches, i)]
                else:
                    prediction = preds[i]
            if prediction is not None:
                break
        if prediction is None:
            prediction = preds[0]
    return prediction

def single_guided_iterative_inference(model, tokenizer, pilot, metadata, file_name, nb=4, rp=2.5, lp=0.5, es=False, max_steps=None):
    """Forward-chain over a theory file with guided inference, then answer its question.

    The file is read as UTF-8. Each line is stripped; blank lines and lines
    starting with ``#`` are ignored. The first remaining line is the question
    and every later one a theory sentence labelled ``sent1``, ``sent2``, ...

    ``single_guided_inference`` is called repeatedly on the growing context;
    each new fact is appended as the next ``sentN``. Chaining stops when the
    model answers ``'Nothing.'``, repeats a known fact, or ``max_steps``
    rounds have run.

    Args:
        model, tokenizer, pilot, metadata: forwarded to ``single_guided_inference``.
        file_name: path of the theory file.
        nb, rp, lp, es: generation settings forwarded to ``single_guided_inference``
            (defaults match that function's defaults).
        max_steps: optional cap on inference rounds; ``None`` (default) is unbounded.

    Returns:
        ``(theory, question, answer, proof)`` where ``theory`` is a list of
        ``(label, sentence, proof)`` tuples; file sentences carry proof ``"None"``.
    """
    with open(file_name, 'r', encoding='utf-8') as input_file:
        lines = [line.strip() for line in input_file.read().splitlines()]
    question = None
    theory = []
    input_data = ''
    n = 1
    for line in lines:
        if not line or line.startswith('#'):
            continue
        if question is None:
            question = line
        else:
            theory.append((f"sent{n}", line, "None"))
            input_data += f" sent{n}: {line}"
            n += 1
    # Set mirror of theory texts for O(1) duplicate checks.
    known_facts = {entry[1] for entry in theory}
    fact = None
    steps = 0
    while fact != 'Nothing.' and (max_steps is None or steps < max_steps):
        steps += 1
        decoded_output = single_guided_inference(model, tokenizer, pilot, input_data, metadata, nb=nb, rp=rp, lp=lp, es=es)
        fact, proof = split_prediction(decoded_output)
        if fact != 'Nothing.':
            if fact not in known_facts:
                input_data += f" sent{n}: {fact}"
                theory.append((f"sent{n}", fact, proof))
                known_facts.add(fact)
                n += 1
            else:
                # A repeated fact means no progress: stop chaining.
                fact = 'Nothing.'
    answer, proof = search_answer(question, theory)
    return theory, question, answer, proof

#Main

In [13]:
def main():
    """Load the best checkpoint and guidance tree, then run the iterative guided test.

    All stage loaders and the iterative test loader are built up front so the
    alternative experiments listed in the comments can be switched on without
    extra setup; by default only ``iterative_guided_test`` runs and its
    accuracy is logged.
    """
    # Keyword arguments are evaluated left to right, so set_device still runs
    # before set_seed exactly as in a dict literal.
    metadata = dict(
        MODEL_NAME='t5-small',
        SAVE_PATH='/content/drive/MyDrive/ai_cloud_playground/',
        DATASET_HOME='./data/proofwriter-dataset-V2020.12.3/OWA/depth-3/',
        TRAIN_DATASET='meta-stage-train.jsonl',
        VAL_DATASET='meta-stage-dev.jsonl',
        TEST_DATASET='meta-stage-test.jsonl',
        BATCH_SIZE=8,
        EPOCHS=10,
        MAX_SOURCE_LENGTH=640,
        MAX_TARGET_LENGTH=64,
        DEVICE=set_device(gpu=True),
        SEED=set_seed(2147),
        MIRROR_OUTPUT=None,
    )

    # Alternative model sources:
    # model = T5ForConditionalGeneration.from_pretrained(metadata['MODEL_NAME'])
    # tokenizer = T5Tokenizer.from_pretrained(metadata['MODEL_NAME'], model_max_length=metadata['MAX_SOURCE_LENGTH'])
    # model = T5ForConditionalGeneration.from_pretrained('rr_owa_d3plus_infstage_ma1_mixture_large_hf', local_files_only=True)
    # tokenizer = T5Tokenizer.from_pretrained('t5-large', model_max_length=512)
    # optimizer = transformers.Adafactor(params=model.parameters())
    model, tokenizer, optimizer = load_model('best', metadata)
    tree_file = os.path.join(metadata['SAVE_PATH'], 't5-small_tree_validation.joblib')
    pilot = GuidanceTree(model, tokenizer, metadata, load=tree_file, verbose=True)

    def stage_loader(split_key, shuffle):
        # One ProofStageDataset per split, batched with the configured size.
        stage_set = ProofStageDataset(tokenizer, metadata[split_key], metadata, limit=None)
        return DataLoader(stage_set, batch_size=metadata['BATCH_SIZE'], shuffle=shuffle)

    training_loader = stage_loader('TRAIN_DATASET', True)
    validation_loader = stage_loader('VAL_DATASET', False)
    testing_loader = stage_loader('TEST_DATASET', False)
    iter_test_set = ProofIterativeDataset(tokenizer, 'meta-test.jsonl', metadata, limit=3600)
    iter_test_loader = DataLoader(iter_test_set, shuffle=True)

    my_print("START", mirror=metadata['MIRROR_OUTPUT'], reset=True)

    # Training:
    # _, model = training_loop(model, tokenizer, optimizer, training_loader, validation_loader, testing_loader, metadata)

    # Fit / persist / inspect the guidance tree:
    # pilot.learn(validation_loader, cross_val=False)
    # pilot.save(os.path.join(metadata['SAVE_PATH'], 'new_tree.joblib'))
    # pilot.draw()

    # Other evaluators:
    # accuracy = test_epoch(model, tokenizer, testing_loader, metadata)
    # accuracy = multibeam_test(model, tokenizer, testing_loader, metadata, nb=4)
    # accuracy = iterative_test(model, tokenizer, iter_test_loader, metadata)
    # accuracy = guided_test(model, tokenizer, pilot, testing_loader, metadata, verbose=True)
    accuracy = iterative_guided_test(model, tokenizer, pilot, iter_test_loader, metadata, verbose=True)
    my_print("TEST ", f"accuracy: {accuracy:.3f}", mirror=metadata['MIRROR_OUTPUT'])

    # Single-context inference:
    # sample_context = 'sent1: Joe is green. sent2: Purple things are evil. sent3: Green people are funny.'
    # result = single_inference(model, tokenizer, sample_context, metadata)
    # result = single_guided_inference(model, tokenizer, pilot, sample_context, metadata)
    # for p, s in result: print(f"{p} {s*100:.1f}%")

    # Whole-theory inference from a file:
    # file = os.path.join(metadata['SAVE_PATH'], 'theories', 'th_fantasy.txt')
    # theory, question, answer, proof = single_iterative_inference(model, tokenizer, metadata, file)
    # theory, question, answer, proof = single_guided_iterative_inference(model, tokenizer, pilot, metadata, file)
    # for t in theory: print(f"{t[0]}: {t[1]} ({t[2]})")
    # print(f"{question} {answer} ({proof})")

    my_print("DONE ", mirror=metadata['MIRROR_OUTPUT'])

In [14]:
# clear_cache()

In [15]:
# Smoke test: ask the converted ProofWriter checkpoint for one single-hop
# inference over a toy theory and print each returned beam with its
# softmax-normalised sequence score.
context = "sent1: The cat is black. sent2: The dog is not big. sent3: Black cats are scary. sent4: Big dogs are scary."
nb = 4
# Fall back to CPU instead of crashing when no GPU is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = T5ForConditionalGeneration.from_pretrained('rr_owa_d3plus_infstage_ma1_mixture_large_hf')
model.to(device)
model.eval()
tokenizer = T5Tokenizer.from_pretrained('t5-large', model_max_length=640)
# padding='max_length' already pads; the deprecated pad_to_max_length was redundant.
source = tokenizer(
    f"$answer$ ; $proof$ ; $question$ = What is one singlehop inference? ; $context$ = {context}",
    padding = 'max_length',
    max_length = 640,
    truncation = True,
    return_tensors = 'pt' )
source_ids = source['input_ids'].to(dtype=torch.long)
source_mask = source['attention_mask'].to(dtype=torch.long)
ids = source_ids.to(device, dtype=torch.long)
mask = source_mask.to(device, dtype=torch.long)
with torch.no_grad():
    generated_ids = model.generate(
        input_ids = ids,
        attention_mask = mask,
        max_length = 64,
        num_beams = nb,
        repetition_penalty = 2.5,
        length_penalty = 0.5,
        early_stopping = False,
        num_return_sequences = nb,
        output_scores = True,
        return_dict_in_generate = True )
predictions = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids.sequences]
scores = torch.softmax(generated_ids.sequences_scores, dim=-1).tolist()
for p, s in zip(predictions, scores):
    print(f"{p} ; $confidence$ = {s*100:.1f}%")

Downloading:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.20k [00:00<?, ?B/s]

$answer$ = The cat is scary. ; $proof$ = # sent3 sent1 ; $confidence$ = 50.5%
$answer$ = The cat is scared. ; $proof$ = # sent3 sent1 ; $confidence$ = 20.3%
$answer$ = The cat is quiet. ; $proof$ = # sent1 sent1 ; $confidence$ = 16.8%
$answer$ = The cat is furry. ; $proof$ = # sent1 sent1 ; $confidence$ = 12.4%
