In [None]:
!pip install scikit-learn

In [None]:
import torch
import json
import os
import pandas as pd
from utils import *
from sklearn.metrics import confusion_matrix

In [None]:
# Load the notebook settings from config.json in the current working directory.
# The keys read below are "data_dir_path" and "label_dict".
with open("./config.json", "r") as fp:
    config = json.load(fp)

In [None]:
# Directory holding the dataset files; joined with DF_NAME when the TSV is read.
DATA_DIR_PATH = config["data_dir_path"]
# Maps a gold-label string to its class id (used to encode gold labels below).
LABEL_DICT = config["label_dict"]

In [None]:
# File name of the dataset to evaluate, relative to DATA_DIR_PATH.
DF_NAME = "snli_test.tsv"

# Load Dataset

In [None]:
# Tab-separated file; its first column becomes the DataFrame index.
df = pd.read_csv(os.path.join(DATA_DIR_PATH, DF_NAME), delimiter='\t', index_col=0)
df.head()

In [None]:
# "Decomposed" samples: both the entailment and the contradiction tableau have size > 2.
# NOTE(review): presumably a tableau of size 2 holds only the two input sentences — confirm.
decomposed_df = df[(df.entailment_tableau_size > 2) & (df.contradiction_tableau_size > 2)]
decomposed_df

In [None]:
# "Undecomposed" samples: both tableaux have size exactly 2.
undecomposed_df = df[(df.entailment_tableau_size == 2) & (df.contradiction_tableau_size == 2)]
undecomposed_df

In [None]:
# Fractions of the full dataset. They need not sum to 1: a sample where only one
# of the two tableaux has size 2 belongs to neither subset.
print("Decomposed Sample Rate:", len(decomposed_df) / len(df))
print("Undecomposed Sample Rate:", len(undecomposed_df) / len(df))

# Define The Model

In [None]:
# Subset that every evaluation cell below runs on.
TARGET_DF = decomposed_df

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Number of sentence pairs sent through the model per forward pass in predict().
batch_size = 32

In [None]:
# Saved vocabulary; it is assigned to preprocessor.worddict in the next cell.
# NOTE(review): presumably maps word -> embedding index (per the file name) — confirm.
with open("./data/word_index_map.json", "r") as worddict_file:
    worddict = json.load(worddict_file)

In [None]:
from esim.data import Preprocessor
# Build a Preprocessor with the notebook's label mapping, then replace its
# vocabulary with the saved one loaded above.
# NOTE(review): presumably this keeps word indices aligned with the checkpoint's
# embedding table — confirm.
preprocessor = Preprocessor(lowercase=False,
                            ignore_punctuation=False,
                            num_words=None,
                            stopwords={},
                            labeldict=LABEL_DICT,
                            bos=None,
                            eos=None)
preprocessor.worddict = worddict
preprocessor

In [None]:
# map_location=device lets a checkpoint that was saved on a GPU be opened on a
# CPU-only machine (device falls back to "cpu" when CUDA is unavailable).
checkpoint = torch.load("./data/checkpoints/best.pth.tar", map_location=device)

# Retrieving model parameters from the shapes of the saved weight tensors.
vocab_size = checkpoint["model"]["_word_embedding.weight"].size(0)
embedding_dim = checkpoint["model"]['_word_embedding.weight'].size(1)
hidden_size = checkpoint["model"]["_projection.0.weight"].size(0)
num_classes = checkpoint["model"]["_classification.4.weight"].size(0)

In [None]:
from esim.model import ESIM

# Rebuild ESIM with the sizes read from the checkpoint, move it to `device`,
# and load the saved weights into it.
model = ESIM(vocab_size,
             embedding_dim,
             hidden_size,
             num_classes=num_classes,
             device=device).to(device)
model.load_state_dict(checkpoint["model"])

In [None]:
import numpy as np

def _split_tokens(sentence):
    """Return `sentence` as a list of tokens.

    A list is passed through unchanged (it is treated as already tokenized);
    anything else is treated as a string and split on whitespace.
    """
    if isinstance(sentence, list):
        return sentence
    return sentence.rstrip().split()


def _pad_index_batch(index_sequences):
    """Right-pad index sequences with 0 into one LongTensor.

    Returns:
        (padded, lengths): `padded` has one row per sequence and as many
        columns as the longest sequence; `lengths` lists the original lengths.
    """
    lengths = [len(sequence) for sequence in index_sequences]
    padded = torch.zeros((len(index_sequences), max(lengths)), dtype=torch.long)
    for row, sequence in enumerate(index_sequences):
        padded[row, :lengths[row]] = torch.tensor(sequence, dtype=torch.long)
    return padded, lengths


def predict(premises, hypothesises):
    """Run the ESIM model on aligned premise/hypothesis pairs.

    Uses the notebook globals `preprocessor`, `model`, `device` and
    `batch_size`. The model is switched to eval mode and gradients are
    disabled for the duration of the call.

    Args:
        premises: sequence of sentences, each either a whitespace-separated
            string or a list of tokens.
        hypothesises: sequence of sentences in the same formats; item i is
            paired with premises[i].

    Returns:
        numpy array with one row per sentence pair, holding the second output
        of `model` (named `probs` here) for that pair. For empty input the
        result is an empty array.

    Raises:
        ValueError: if the two inputs contain a different number of sentences
            (previously this could silently drop pairs or fail inside the model).
    """
    premise_tokens = [_split_tokens(premise) for premise in premises]
    hypothesis_tokens = [_split_tokens(hypothesis) for hypothesis in hypothesises]
    if len(premise_tokens) != len(hypothesis_tokens):
        raise ValueError(
            "premises and hypothesises must have the same length, got {} and {}".format(
                len(premise_tokens), len(hypothesis_tokens)))

    premise_indices = [preprocessor.words_to_indices(tokens) for tokens in premise_tokens]
    hypothesis_indices = [preprocessor.words_to_indices(tokens) for tokens in hypothesis_tokens]

    results = []

    model.eval()
    with torch.no_grad():
        for start in range(0, len(premise_indices), batch_size):
            stop = start + batch_size
            premise_tensor, premise_lengths = _pad_index_batch(premise_indices[start:stop])
            hypothesis_tensor, hypothesis_lengths = _pad_index_batch(hypothesis_indices[start:stop])

            _, probs = model(
                premise_tensor.to(device),
                torch.tensor(premise_lengths).to(device),
                hypothesis_tensor.to(device),
                torch.tensor(hypothesis_lengths).to(device)
            )
            results.extend(prob.cpu().numpy() for prob in probs)
    return np.array(results)
            
# Smoke test: the same sentence pair is passed once as raw strings and once pre-tokenized.
predict(["I like tomatos", ["I", "like", "tomatos"]],
        ["I do n't like tomatos", ["I", "do", "n't", "like", "tomatos"]])


# ANSWER WITH NORMAL ESIM

In [None]:
# Build whitespace-joined token strings from the token lists that tree2tokenlist
# extracts from the udtree1 (premise) and udtree2 (hypothesis) columns.
premises = [" ".join(tree2tokenlist(sample.udtree1)) for sample in TARGET_DF.itertuples()]
hypothesises = [" ".join(tree2tokenlist(sample.udtree2)) for sample in TARGET_DF.itertuples()]

# Gold labels encoded with LABEL_DICT; predictions are the argmax over each row of predict()'s output.
gold_labels = np.array([LABEL_DICT[sample.gold_label] for sample in TARGET_DF.itertuples()])
simple_predicted_labels = predict(premises, hypothesises).argmax(axis=1)


In [None]:
# Accuracy of the plain ESIM predictions on TARGET_DF, as a percentage.
print("acc: {:.3f}%".format(100 * (simple_predicted_labels == gold_labels).sum() / len(TARGET_DF)))

In [None]:
# Rows are gold labels, columns are predicted labels.
confusion_matrix(gold_labels, simple_predicted_labels)

# ANSWER WITH TABLEAU WITH ESIM

In [None]:
def transform_tableau(tableau, premise_list, hypothesis_list):
    """Flatten one tableau into index-based structures used for closure checking.

    Args:
        tableau: dict with a "root" node; every node has "entries" and
            "child_nodes". Each entry is read for the keys "tree", "sign",
            "origin" and "exist_eq_entries".
        premise_list: list of token lists, extended in place with the first
            sentence of every newly seen sentence pair.
        hypothesis_list: list of token lists, extended in place in lockstep
            with `premise_list`.

    Returns:
        dict with
        - "branches": list of sets of entry indices, one per root-to-leaf path;
        - "pairs": list of (entry index, entry index) tuples that may contradict;
        - "sentence_pair_row": for each pair, its row index in
          `premise_list` / `hypothesis_list`;
        - "contradiction_labels": tensor on `device` holding, for each pair,
          the label under which the pair counts as contradictory.
    """
    entry_list = []
    child_entries_list = []
    contradictable_entries_pair_list = []
    all_branches = []
    
    def append_entry_list(node):
        """Append `node`'s entries, then those of its child nodes, to entry_list.

        Fills child_entries_list so that child_entries_list[i] lists the
        entries directly following entry i. Returns the index the node's first
        entry received.
        """
        entry_offset = len(entry_list)
        entry_size = 0
        for entry in node["entries"]:
            entry_list.append(entry)
            # Provisional successor: the next entry of the same node.
            child_entries_list.append([entry_offset + entry_size + 1])
            entry_size += 1

        childtree = []
        for child_node in node["child_nodes"]:
            childtree.append(append_entry_list(child_node))
        # The node's last entry is followed by the first entry of each child node
        # (an empty list when the node is a leaf).
        child_entries_list[entry_offset + entry_size - 1] = childtree
        return entry_offset

    def append_contradictable_entries_pair_list(entry_index):
        """Collect candidate contradiction pairs in the subtree under `entry_index`.

        Returns the indices in that subtree (including `entry_index` itself)
        whose "exist_eq_entries" flag is False.
        """
        subtree_entry_indices = []
        for child_entry_index in child_entries_list[entry_index]:
            subtree_entry_indices.extend(append_contradictable_entries_pair_list(child_entry_index))

        if entry_list[entry_index]["exist_eq_entries"] == False:
            for subtree_entry_index in subtree_entry_indices:
                # Only entries with different "origin" values are paired.
                if entry_list[entry_index]["origin"] != entry_list[subtree_entry_index]["origin"]:
                    # Pairs are ordered so the first element always has sign True;
                    # a pair where both signs are False is not recorded.
                    if entry_list[entry_index]["sign"] == True and entry_list[subtree_entry_index]["sign"] == False:
                        contradictable_entries_pair_list.append((entry_index, subtree_entry_index))
                    elif entry_list[entry_index]["sign"] == False and entry_list[subtree_entry_index]["sign"] == True:
                        contradictable_entries_pair_list.append((subtree_entry_index, entry_index))
                    elif entry_list[entry_index]["sign"] == True and entry_list[subtree_entry_index]["sign"] == True:
                        contradictable_entries_pair_list.append((entry_index, subtree_entry_index))

            subtree_entry_indices.append(entry_index)
        return subtree_entry_indices

    def calculate_branch(entry_index):
        """Return one set of entry indices per path from `entry_index` down to a leaf."""
        if len(child_entries_list[entry_index]) == 0:
            return [{entry_index}]

        branches = []
        for child_entry_index in child_entries_list[entry_index]:
            branches.extend(calculate_branch(child_entry_index))
        for branch in branches:
            branch.add(entry_index)
        return branches

    append_entry_list(tableau["root"])
    append_contradictable_entries_pair_list(0)
    all_branches = calculate_branch(0)
    # entry_list, child_entries_list, contradictable_entries_pair_list and all_branches are now computed.

    # Token list for every entry, parsed from the entry's "tree" string.
    all_sentence_list = [tree2tokenlist(ET.fromstring(entry["tree"])) for entry in entry_list]
    
    def findadd_sentence_pair(premise, hypothesis):
        """Return the row index of (premise, hypothesis), appending the pair if it is new."""
        for i, _premise in enumerate(premise_list):
            _hypothesis = hypothesis_list[i]
            if premise == _premise and hypothesis == _hypothesis:
                return i
        premise_list.append(premise)
        hypothesis_list.append(hypothesis)
        return len(premise_list) - 1
    
    contradiction_labels = []
    sentence_pair_row = []

    for i, pair in enumerate(contradictable_entries_pair_list):
        # pair[0] always has sign True (see the pair construction above), so
        # exactly one of the two branches below applies to every pair.
        if entry_list[pair[0]]["sign"] == True and entry_list[pair[1]]["sign"] == False:
            contradiction_labels.append(LABEL_DICT["entailment"])
        elif entry_list[pair[0]]["sign"] == True and entry_list[pair[1]]["sign"] == True:
            contradiction_labels.append(LABEL_DICT["contradiction"])
        sentence_pair_row.append(findadd_sentence_pair(all_sentence_list[pair[0]],
                                                       all_sentence_list[pair[1]]))
    
    return {
        "branches": all_branches,
        "pairs": contradictable_entries_pair_list,
        "sentence_pair_row": sentence_pair_row,
        "contradiction_labels": torch.Tensor(contradiction_labels).to(device)
    }

def transform_sample(df):
    """Convert every row of `df` into the dict consumed by the tableau-based predictor.

    For each row the entailment and contradiction tableaux are transformed with
    `transform_tableau`, sharing one premise list and one hypothesis list so a
    sentence pair that occurs in both tableaux gets a single row index. The
    per-pair row indices of each tableau are then replaced by one-hot rows
    (one column per distinct sentence pair) placed on `device`.

    Returns:
        list of dicts with the keys "gold_label", "entailment_tableau",
        "contradiction_tableau", "premises", "hypothesises", "premise",
        "hypothesis" and "sentence_pair_size".
    """
    def as_one_hot(row_indices, width):
        # Row k of the result is row row_indices[k] of the width x width identity.
        return (torch.eye(width)[row_indices]).to(device)

    dataset = []
    for row in df.itertuples():
        shared_premises, shared_hypotheses = [], []

        gold = LABEL_DICT[row.gold_label]
        entailment = transform_tableau(json.loads(row.entailment_tableau),
                                       shared_premises,
                                       shared_hypotheses)
        contradiction = transform_tableau(json.loads(row.contradiction_tableau),
                                          shared_premises,
                                          shared_hypotheses)

        record = {
            "gold_label": gold,
            "entailment_tableau": entailment,
            "contradiction_tableau": contradiction,
            "premises": shared_premises,
            "hypothesises": shared_hypotheses,
            "premise": row.sentence1,
            "hypothesis": row.sentence2,
            "sentence_pair_size": len(shared_premises),
        }

        # Both tableaux index into the same shared pair lists, so they use the same width.
        for tableau in (entailment, contradiction):
            tableau["sentence_pair_row"] = as_one_hot(tableau["sentence_pair_row"],
                                                      record["sentence_pair_size"])

        dataset.append(record)
    return dataset
        
# Transform the evaluated subset, then show the first sample's entailment tableau
# next to its transformed form as a sanity check.
target_dataset = transform_sample(TARGET_DF)
print(view_tableau(TARGET_DF.iloc[0].entailment_tableau))
target_dataset[0]

In [None]:
import copy

def is_close_tableau(tableau, r):
    """Return True when every branch of `tableau` is closed under predictions `r`.

    Args:
        tableau: dict as produced by `transform_tableau` / `transform_sample`,
            with "sentence_pair_row" (2-D tensor, one row per pair),
            "contradiction_labels" (1-D tensor, one value per pair),
            "pairs" (list of index pairs) and "branches" (list of sets of
            entry indices).
        r: 1-D tensor with the predicted label of every distinct sentence pair.

    Returns:
        bool. False when the tableau has no pairs at all; otherwise True iff
        each branch contains both members of at least one pair whose predicted
        label equals its contradiction label. `tableau` is not modified.
    """
    if len(tableau["sentence_pair_row"]) == 0:
        return False

    # Entry i is True when the prediction selected for pair i equals the label
    # under which that pair counts as a contradiction.
    is_contradiction_pairs = (
        torch.mv(tableau["sentence_pair_row"], r) == tableau["contradiction_labels"]
    ).tolist()

    open_branches = list(tableau["branches"])

    for i, pair in enumerate(tableau["pairs"]):
        if is_contradiction_pairs[i]:
            # Rebuild the list instead of calling remove() while iterating it:
            # removing during iteration skipped the branch after each removed
            # one, so some closed branches were left in place.
            open_branches = [
                branch for branch in open_branches
                if not (pair[0] in branch and pair[1] in branch)
            ]
    return len(open_branches) == 0

def predict_label(sample, r):
    """Derive a label for `sample` from which of its two tableaux close under `r`.

    Returns:
        0 when only the entailment tableau closes,
        1 when neither tableau closes,
        2 when only the contradiction tableau closes,
        -1 when both close.
        NOTE(review): 0/1/2 presumably match the entailment/neutral/contradiction
        ids in LABEL_DICT — confirm against config.json.
    """
    closure = (
        is_close_tableau(sample["entailment_tableau"], r),
        is_close_tableau(sample["contradiction_tableau"], r),
    )
    # Keyed by (entailment tableau closed, contradiction tableau closed).
    label_by_closure = {
        (True, False): 0,
        (False, False): 1,
        (False, True): 2,
    }
    return label_by_closure.get(closure, -1)


In [None]:
from tqdm import tqdm

# predict() already switches the model to eval mode and disables gradients on
# every call; doing it here as well is harmless.
model.eval()

predicted_labels = []
with torch.no_grad():
    for sample in tqdm(target_dataset):
        if sample["sentence_pair_size"] > 0:
            # One predicted label per distinct sentence pair of this sample.
            pairs_probs = torch.from_numpy(predict(sample["premises"], sample["hypothesises"])).to(device)
            # Cast to float so it can be multiplied with the float one-hot rows in is_close_tableau.
            r = torch.argmax(pairs_probs, dim=1).float()
            predicted_label = predict_label(sample, r)
        else:
            # No sentence pairs to test: use label 1, the same value
            # predict_label returns when neither tableau closes.
            predicted_label = 1
        
        predicted_labels.append(predicted_label)

tableau_predicted_labels = np.array(predicted_labels)
tableau_predicted_labels

In [None]:
# Accuracy of the tableau-based predictions on TARGET_DF, as a percentage.
print("acc: {:.3f}%".format(100 * (tableau_predicted_labels == gold_labels).sum() / len(TARGET_DF)))

In [None]:
# Share of samples labelled -1, i.e. where both the entailment and the contradiction tableau closed.
print("err: {:.3f}%".format(100 * (tableau_predicted_labels == -1).sum() / len(TARGET_DF)))

In [None]:
# Rows are gold labels, columns are predicted labels. If any prediction is -1,
# the matrix gains an extra row/column for it.
confusion_matrix(gold_labels, tableau_predicted_labels)