# Data Pretreatment

In [1]:
from fuzzywuzzy import fuzz, process as fuzzy_process
import re
import numpy as np
from bisect import bisect_left
import torch

def fuzzy_retrieve(entity, pool, setting, threshold = 50):
    """Map an extracted entity mention onto the closest candidate name.

    The mention found in the text can differ slightly from the title of the
    corresponding wiki page, so every candidate is compared with the mention
    using ``fuzz.ratio`` after a trailing `` (...)`` annotation has been
    removed from the candidate.

    Args:
        entity (string): The entity name extracted from the text.
        pool: In the 'distractor' setting, a mapping whose keys are the
            candidate names. In any other setting, a (title, sentence number)
            tuple; candidates are then read from the Redis lists
            ``edges:###<i>###<title>`` for i in 0..sentence number.
        setting (string): 'distractor' or anything else (treated as fullwiki).
        threshold (int, optional): The best score must be strictly greater
            than this value. Defaults to 50.

    Returns:
        string: The best candidate (annotation included), or None when no
        candidate scores above the threshold.
    """
    if setting == 'distractor':
        candidates = pool.keys()
    else:
        # The Redis client is created on first use and cached on the function.
        if not hasattr(fuzzy_retrieve, 'db'):
            from redis import StrictRedis
            fuzzy_retrieve.db = StrictRedis(host='localhost', port=6379, db=0)
        assert isinstance(pool, tuple)
        title, sen_num = pool
        candidates = set()
        for sentence_idx in range(sen_num + 1):
            key = 'edges:###{}###{}'.format(sentence_idx, title)
            # Each stored element looks like "<name>###..."; keep the name part.
            candidates |= set(
                raw.decode().split('###')[0]
                for raw in fuzzy_retrieve.db.lrange(key, 0, -1)
            )

    top_score, top_name = 0, -1
    for name in candidates:
        bare_name = re.sub(r' \(.*?\)$', '', name)
        current = fuzz.ratio(bare_name, entity)
        if current > top_score:
            top_score, top_name = current, name
    if top_score > threshold:
        return top_name
    return None

def get_context_fullwiki(title):
    """Read the Redis list stored under ``title``.

    The Redis client (created with ``decode_responses=True``) is built on the
    first call and cached as an attribute of this function.

    Args:
        title (string): Entity name, used directly as the Redis key.

    Returns:
        list: The full content of the list, i.e. the sentences (string) of the
        page about ``title``.
    """
    try:
        client = get_context_fullwiki.db
    except AttributeError:
        from redis import StrictRedis
        client = StrictRedis(host='localhost', port=6379, db=0, decode_responses=True)
        get_context_fullwiki.db = client
    return client.lrange(title, 0, -1)

def dp(a, b):
    """A basic dynamic programming for edit-distance based fuzzy matching.

    Finds the substring of ``b`` that aligns best with the whole of ``a``.
    ``f[i, j]`` is the cheapest cost of aligning ``a[:i + 1]`` with a substring
    of ``b`` ending at position ``j``; ``start[i, j]`` is where that substring
    begins.

    Costs used by the recurrence:
        * 1 for a mismatching character pair,
        * 1 for a character of ``a`` that is left unmatched,
        * 0.5 for an extra character of ``b`` inside the matched span,
        * an extra 10 when the match would begin right after an alphanumeric
          character of ``b`` (i.e. in the middle of a word).

    Args:
        a (string): source (must be non-empty).
        b (string): the long text (must be non-empty).

    Returns:
        tuple: ``([start, end], score)`` where ``b[start:end]`` is the best
        matching span and ``score`` is its cost divided by ``len(a)``
        (0 means a perfect match, lower is better).
    """
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
    # documented replacement and yields the same default integer dtype.
    f, start = np.zeros((len(a), len(b))), np.zeros((len(a), len(b)), dtype = int)
    for j in range(len(b)):
        f[0, j] = int(a[0] != b[j])
        if j > 0 and b[j - 1].isalnum():
            # Discourage matches that begin in the middle of a word.
            f[0, j] += 10
        start[0, j] = j
    for i in range(1, len(a)):
        for j in range(len(b)):
            # Leave a[i] unmatched: (0, i-1) + del(i) ~ (start[j], j)
            f[i, j] = f[i - 1, j] + 1
            start[i, j] = start[i - 1, j]
            if j == 0:
                continue
            # Align a[i] with b[j] (free when the characters are equal).
            if f[i, j] > f[i - 1, j - 1] + int(a[i] != b[j]):
                f[i, j] = f[i - 1, j - 1] + int(a[i] != b[j])
                start[i, j] = start[i-1, j - 1]

            # Absorb b[j] as an extra character inside the span.
            if f[i, j] > f[i, j - 1] + 0.5:
                f[i, j] = f[i, j - 1] + 0.5
                start[i, j] = start[i, j - 1]
    # The best alignment of the whole of ``a`` may end anywhere in ``b``.
    r = np.argmin(f[len(a) - 1])
    ret = [start[len(a) - 1, r], r + 1]
    score = f[len(a) - 1, r] / len(a)
    return (ret, score)

def fuzzy_find(entities, sentence, ratio = 80):
    """Try to locate as many of the entities as possible in the sentence.

    Every entity is aligned against the sentence with ``dp``. A candidate match
    is then validated from both ends: while the last (resp. first) word of the
    entity is not found in the matched text, that word is dropped and the
    alignment is redone with a small score penalty. Finally, candidates that
    overlap a better-scoring candidate are discarded.

    Args:
        entities (list): Candidates. A trailing `` (...)`` annotation is
            ignored for matching purposes.
        sentence (string): The sentence to examine.
        ratio (int, optional): Minimum ``fuzz.partial_ratio`` for a boundary
            word to count as present in the matched text. Defaults to 80.

    Returns:
        List of tuples: (entity, matched span text, start position, end position).
        The internal score is used for filtering only and is not returned.
    """
    ret = []
    for entity in entities:
        # Match on the name without its trailing "(annotation)".
        item = re.sub(r' \(.*?\)$', '', entity).strip()
        if item == '':
            # Nothing left after stripping: fall back to the raw entity.
            item = entity
            print(item)
        r, score = dp(item, sentence)
        # Lower is better; 0.5 is the acceptance cut-off used throughout.
        if score < 0.5:
            matched = sentence[r[0]: r[1]].lower()
            final_word = item.split()[-1]
            retry = False
            # Drop trailing words until the last word is found in the match.
            while fuzz.partial_ratio(final_word.lower(), matched) < ratio:
                retry = True
                end = len(item) - len(final_word)
                while end > 0 and item[end - 1].isspace():
                    end -= 1
                if end == 0:
                    # Every word has been dropped: reject this entity.
                    retry = False
                    score = 1
                    break
                item = item[:end]
                final_word = item.split()[-1]
            if retry:
                # Re-align the shortened item; the trimming costs 0.1.
                # NOTE(review): ``matched`` is not recomputed after this
                # re-alignment, so the check below still uses the first span.
                r, score = dp(item, sentence)
                score += 0.1
            if score >= 0.5:
                continue
            del final_word
            # from start
            retry = False
            first_word = item.split()[0]
            # Same procedure from the left: drop leading words until the first
            # word is found in the match.
            while fuzz.partial_ratio(first_word.lower(), matched) < ratio:
                retry = True
                start = len(first_word)
                while start < len(item) and item[start].isspace():
                    start += 1
                if start == len(item):
                    # Every word has been dropped: reject this entity.
                    retry = False
                    score = 1
                    break
                item = item[start:]
                first_word = item.split()[0]
            if retry:
                r, score = dp(item, sentence)
                # Also penalise spans that are short relative to the full entity.
                score = max(score, 1 - ((r[1] - r[0]) / len(entity)))
                score += 0.1
            if score < 0.5:
                # Purely numeric items must match exactly.
                if item.isdigit() and sentence[r[0]: r[1]] != item:
                    continue
                ret.append((entity, sentence[r[0]: r[1]], int(r[0]), int(r[1]), score))
    non_intersection = []
    for i in range(len(ret)):
        ok = True
        for j in range(len(ret)):
            if j != i:
                # Candidate i overlaps candidate j and j has the better (lower) score.
                if not (ret[i][2] >= ret[j][3] or ret[j][2] >= ret[i][3]) and ret[j][4] < ret[i][4]:
                    ok = False
                    break
                # A weak, short, non-capitalised match is dropped as soon as
                # some other candidate is a near-perfect match.
                if ret[i][4] > 0.2 and ret[j][4] < 0.1 and not ret[i][1][0].isupper() and len(ret[i][1].split()) <= 3:
                    ok = False
                    break
        if ok:
            # Strip the score from the tuple that is returned.
            non_intersection.append(ret[i][:4])
    return non_intersection

# Auxiliary and modal verbs that open a general (yes/no) question.
GENERAL_WD = ['is', 'are', 'am', 'was', 'were', 'have', 'has', 'had', 'can', 'could', 
              'shall', 'will', 'should', 'would', 'do', 'does', 'did', 'may', 'might', 'must', 'ought', 'need', 'dare']
GENERAL_WD += [x.capitalize() for x in GENERAL_WD]
# Every alternative has to be followed by a space. Joining with ' |' left the
# last word ("Dare") without one, so e.g. "Daredevil ..." used to match.
GENERAL_WD = re.compile('(?:' + '|'.join(GENERAL_WD) + ') ')

def judge_question_type(q : str, G = GENERAL_WD) -> int:
    """Classify a question by its surface form.

    Args:
        q (str): The question text.
        G: Compiled pattern matched against the beginning of ``q`` to detect
            a general question. Defaults to GENERAL_WD.

    Returns:
        int: 2 if the question contains ' or ' (alternative question),
        1 if it starts with one of the GENERAL_WD verbs followed by a space
        (general question), 0 otherwise (special / wh- question).
    """
    if ' or ' in q:
        return 2
    if G.match(q):
        return 1
    return 0

def warmup_linear(x, warmup=0.002):
    """Learning-rate multiplier: linear ramp-up, then linear decay.

    Args:
        x: Progress value compared against ``warmup``.
        warmup: Length of the ramp-up phase. Defaults to 0.002.

    Returns:
        ``x / warmup`` while ``x < warmup``, otherwise ``1.0 - x``.
    """
    return x / warmup if x < warmup else 1.0 - x


def find_start_end_after_tokenized(tokenizer, tokenized_text, spans: ['Obama Care', '2006']):
    """Find start and end positions of untokenized spans in tokenized text.

    Both the text and each span are reduced to the plain concatenation of
    their word pieces, the span is located inside the text (exactly if
    possible, otherwise with ``fuzzy_find``), and the character positions are
    mapped back to word-piece indices.

    Args:
        tokenizer (Tokenizer): Word-Piece tokenizer (only ``tokenize`` is used).
        tokenized_text (list): List of word pieces(string).
        spans (list): list of untokenized spans(string).

    Returns:
        list: List of (start position, end position), both inclusive indices
        into ``tokenized_text``.

    Raises:
        ValueError: If a span cannot be located even by fuzzy matching.
    """
    # end_offset[k] is the index, in the concatenated text, of the last
    # character of word piece k.
    end_offset, ret = [], []
    for x in tokenized_text:
        offset = len(x) + (end_offset[-1] if len(end_offset) > 0 else -1)
        end_offset.append(offset)
    text = ''.join(tokenized_text)
    for span in spans:
        t = ''.join(tokenizer.tokenize(span))
        start = text.find(t)
        if start >= 0:
            end = start + len(t) - 1 # include end
        else:
            result = fuzzy_find([t], text)
            if len(result) == 0:
                # Retry without unknown-token markers. This must be a literal
                # replacement: as a regex, '[UNK]' is a character class that
                # would delete every 'U', 'N' and 'K' instead.
                result = fuzzy_find([t.replace('[UNK]', '')], text)
                if len(result) == 0:
                    raise ValueError('Cannot find an exact match.')
            _, _, start, end = result[0]
            end -= 1  # fuzzy_find returns an exclusive end
        # Character position -> index of the word piece containing it.
        ret.append((bisect_left(end_offset, start), bisect_left(end_offset, end)))
    return ret
    
def find_start_end_before_tokenized(orig_text, spans: [['Oba', '##ma', 'Care'], ['2006']]):
    """Find start and end positions of tokenized spans in untokenized text.

    Each span is rebuilt by concatenating its word pieces and removing the
    '##' continuation markers, then searched for in the lower-cased original
    text, exactly if possible and with ``fuzzy_find`` otherwise.

    Args:
        orig_text (string): Original text.
        spans (list): List of list of word pieces, as showed in example.

    Returns:
        list: List of (start position, end position), end exclusive. Empty
        spans and spans that cannot be located yield (0, 0).
    """
    ret = []
    orig_text = orig_text.lower()
    for span_pieces in spans:
        if len(span_pieces) == 0:
            ret.append((0, 0))
            continue
        span = ''.join(span_pieces).replace('##', '')
        start = orig_text.find(span)
        if start >= 0:
            end = start + len(span) # exclude end
        else:
            result = fuzzy_find([span], orig_text)
            # Retry without unknown-token markers. The membership test also
            # covers a marker at index 0, which ``find(...) > 0`` skipped.
            if len(result) == 0 and '[UNK]' in span:
                span = span.replace('[UNK]', '')
                result = fuzzy_find([span], orig_text)
            if len(result) == 0:
                ret.append((0,0))
                continue
            _, _, start, end = result[0]
        ret.append((start, end))
    return ret

def bundle_part_to_batch(all_bundle, l = None, r = None):
    """Convert all_bundle[l:r] to a batch of inputs.

    Every per-sample list is right-padded with zeros to the longest sample of
    the selected interval; ``input_mask`` is 1 on real positions and 0 on
    padding.

    Args:
        all_bundle (list of Bundles): Data in ``Bundle'' format. The attributes
            ids, sep_positions, segment_ids and the four *_weights lists are read.
        l (int, optional): Left endpoint of the interval. Defaults to None,
            meaning 0.
        r (int, optional): Right endpoint of the interval (exclusive).
            Defaults to None, meaning the number of samples.

    Returns:
        tuple: A batch of inputs, in the order (ids, segment_ids, input_mask,
        sep_positions, hop_start_weights, hop_end_weights, ans_start_weights,
        ans_end_weights).
    """
    # Resolve each endpoint on its own: previously ``r`` was only filled in
    # when ``l`` was None, so passing ``l`` alone failed on ``r - l`` and an
    # explicit ``r`` was overwritten whenever ``l`` was omitted.
    if l is None:
        l = 0
    if r is None:
        r = len(all_bundle.ids)
    num_samples = r - l
    max_length = max([len(x) for x in all_bundle.ids[l:r]])
    max_seps = max([len(x) for x in all_bundle.sep_positions[l:r]])
    ids = torch.zeros((num_samples, max_length), dtype = torch.long)
    sep_positions = torch.zeros((num_samples, max_seps), dtype = torch.long)
    hop_start_weights = torch.zeros((num_samples, max_length))
    hop_end_weights = torch.zeros((num_samples, max_length))
    ans_start_weights = torch.zeros((num_samples, max_length))
    ans_end_weights = torch.zeros((num_samples, max_length))
    segment_ids = torch.zeros((num_samples, max_length), dtype = torch.long)
    input_mask = torch.zeros((num_samples, max_length), dtype = torch.long)
    for i in range(l, r):
        row = i - l  # position of sample i inside the batch
        length = len(all_bundle.ids[i])
        sep_num = len(all_bundle.sep_positions[i])
        ids[row, :length] = torch.tensor(all_bundle.ids[i], dtype = torch.long)
        sep_positions[row, :sep_num] = torch.tensor(all_bundle.sep_positions[i])
        hop_start_weights[row, :length] = torch.tensor(all_bundle.hop_start_weights[i])
        hop_end_weights[row, :length] = torch.tensor(all_bundle.hop_end_weights[i])
        ans_start_weights[row, :length] = torch.tensor(all_bundle.ans_start_weights[i])
        ans_end_weights[row, :length] = torch.tensor(all_bundle.ans_end_weights[i])
        segment_ids[row, :length] = torch.tensor(all_bundle.segment_ids[i], dtype = torch.long)
        input_mask[row, :length] = 1
    return ids, segment_ids, input_mask, sep_positions, hop_start_weights, hop_end_weights, ans_start_weights, ans_end_weights

class WindowMean:
    """Running mean over the most recent ``window_size`` values."""

    def __init__(self, window_size = 50):
        self.window_size = window_size
        self.array = []  # values currently inside the window, oldest first
        self.sum = 0     # sum of the values in ``self.array``

    def update(self, x):
        """Add ``x`` to the window and return the mean of the window."""
        self.sum += x
        self.array.append(x)
        window_overflows = len(self.array) > self.window_size
        if window_overflows:
            oldest = self.array.pop(0)
            self.sum -= oldest
        return self.sum / len(self.array)


In [30]:

# coding: utf-8

# %pdb on
import json
import re
import numpy as np
import copy
from tqdm import tqdm 
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
from redis import StrictRedis
# from utils import fuzzy_find


# Redis handle used below to read the 'edges:###<sentence>###<title>' lists.
db = StrictRedis(host='localhost', port=6379, db=0)


# Read the training set. The encoding is given explicitly so that the result
# does not depend on the platform's default text encoding.
with open('./hotpot_train_v1_1.1.json', 'r', encoding='utf-8') as fin:
    train_set = json.load(fin)
print('Finish Reading! len = ', len(train_set))


from hotpot_evaluate_v1 import normalize_answer, f1_score
from fuzzywuzzy import fuzz, process as fuzzy_process

def fuzzy_retrive(entity, pool):
    """Pick the entry of ``pool`` that corresponds to ``entity``.

    Args:
        entity (string): Name to look up.
        pool: Candidate names. With more than 100 entries it is treated as a
            fullwiki mapping and only an exact key with a truthy value
            matches. Otherwise it is a list of names, or a mapping whose keys
            are the names (distractor mode / links of the original wiki page).

    Returns:
        The matched name, or None when nothing matches (small pools: no
        candidate has a positive F1 against ``entity``).
    """
    if len(pool) <= 100:
        # Small pool: no need to consider ``entity (annotation)''.
        candidates = pool if isinstance(pool, list) else pool.keys()
        best_f1, best_title = 0, None
        for candidate in candidates:
            f1, _precision, _recall = f1_score(entity, candidate)
            if f1 > best_f1:
                best_f1, best_title = f1, candidate
        return best_title
    # fullwiki, exact match
    # TODO: test ``entity (annotation)'' and find the most like one
    return entity if pool.get(entity) else None

def find_near_matches(w, sentence):
    """Return the spans of the words of ``sentence`` most similar to ``w``.

    Each whitespace-separated word is scored with the mean of ``fuzz.ratio``
    and ``fuzz.partial_ratio`` against ``w``.

    Args:
        w (string): The word to look for.
        sentence (string): The text to scan.

    Returns:
        list: (start, end) character spans (end exclusive) of all words tied
        for the best score, or an empty list when the best score is not
        above 85.
    """
    spans, best_score, cursor = [], 0, 0
    for token in sentence.split():
        # Skip the whitespace in front of the token.
        cursor = sentence.index(token[0], cursor)
        token_end = cursor + len(token)
        score = (fuzz.ratio(w, token) + fuzz.partial_ratio(w, token)) / 2
        if score > best_score:
            best_score, spans = score, [(cursor, token_end)]
        elif score == best_score:
            spans.append((cursor, token_end))
        cursor = token_end
    if best_score > 85:
        return spans
    return []

# Smoke test of fuzzy_find on one example sentence; the result is printed.
print(list(fuzzy_find(['Miami Gardens, Florida', 'WSCV', 'Hard Rock Stadium'], r"Hard Rock Stadium is a multipurpose football stadium located in Miami Gardens, a city north of Miami. It is the home stadium of the Miami Dolphins of the National Football League (NFL).")))


# construct cognitive graph in training data    
from utils import judge_question_type
def find_fact_content(bundle, title, sen_num):
    """Return sentence ``sen_num`` of the first context paragraph named ``title``.

    Args:
        bundle (dict): A sample whose 'context' entry is a sequence of
            (title, sentences) pairs.
        title (string): Paragraph title to look for.
        sen_num (int): Index of the sentence inside that paragraph.

    Returns:
        The sentence, or None when no paragraph has this title.

    Raises:
        IndexError: If the paragraph exists but has no sentence ``sen_num``.
    """
    matching_sentences = (
        paragraph[1][sen_num]
        for paragraph in bundle['context']
        if paragraph[0] == title
    )
    return next(matching_sentences, None)
# Work on a copy so that train_set itself stays untouched.
test = copy.deepcopy(train_set)
for bundle in tqdm(test,mininterval=1000):
    # Titles of all supporting facts of this question.
    entities = set([title for title, sen_num in bundle['supporting_facts']])
    # Mentions of those titles inside the question itself.
    bundle['Q_edge'] = fuzzy_find(entities, bundle['question'])
    question_type = judge_question_type(bundle['question'])
    for fact in bundle['supporting_facts']:
        try:
            title, sen_num = fact
            # Collect the names stored in Redis for sentences 0..sen_num of
            # this page (the part before '###' of every list element).
            pool = set()
            for i in range(sen_num + 1):
                name = 'edges:###{}###{}'.format(i, title)
                tmp = set([x.decode().split('###')[0] for x in db.lrange(name, 0, -1)])
                pool |= tmp
            # Keep only names that are supporting-fact titles themselves.
            pool &= entities
            stripped = [re.sub(r' \(.*?\)$', '', x) for x in pool] + ['yes', 'no']
            if bundle['answer'] not in stripped:
                # The answer is not among the kept names: search for the page
                # title when it is close to the answer, else for the answer text.
                if fuzz.ratio(re.sub(r'\(.*?\)$', '', title), bundle['answer']) > 80:
                    pool.add(title)
                else:
                    pool.add(bundle['answer'])
            # For yes/no answers and non-wh comparison questions, also search
            # for the page's own title.
            if bundle['answer'] == 'yes' or bundle['answer'] == 'no' \
                    or (question_type > 0 and bundle['type'] == 'comparison'):
                pool.add(title)
            r = fuzzy_find(pool, find_fact_content(bundle, title, sen_num))
            # The fact becomes [title, sen_num, matches].
            fact.append(r)
        except IndexError as e: 
            # Report the sample id and leave this fact without matches.
            print(bundle['_id'])
# Persist the annotated copy.
with open('./hotpot_train_v1.1_refined.json', 'w') as fout:
    json.dump(test, fout)



Finish Reading! len =  90447
[('Miami Gardens, Florida', 'Miami Gardens,', 64, 78), ('Hard Rock Stadium', 'Hard Rock Stadium', 0, 17)]





  0%|          | 0/90447 [00:00<?, ?it/s][A[A[A

5a7b23ca554299042af8f703
5abed6d45542990832d3a0ef
5ab6b2fb5542995eadef0060
5ae0e2df5542990adbacf6b1





 19%|█▉        | 17393/90447 [16:40<1:10:00, 17.39it/s][A[A[A


 19%|█▉        | 17393/90447 [16:52<1:10:00, 17.39it/s][A[A[A

5a8d6138554299585d9e37c7
5ab740165542992aa3b8c7fa
5ab2f812554299545a2cfaee





 39%|███▊      | 34968/90447 [33:20<52:59, 17.45it/s]  [A[A[A


 39%|███▊      | 34968/90447 [33:33<52:59, 17.45it/s][A[A[A

5ae7e8ef5542994a481bbe05
5ab273ee5542997061209606
5a84517355429933447460d5
5a7e5b2455429934daa2fc10
5a846921554299123d8c2243
5add66475542992200553af1
5a847c91554299123d8c2268





 58%|█████▊    | 52465/90447 [50:00<36:15, 17.46it/s][A[A[A


 58%|█████▊    | 52465/90447 [50:13<36:15, 17.46it/s][A[A[A

5a7b629555429927d897bfa4
5a80577a5542996402f6a4e9





 77%|███████▋  | 69701/90447 [1:06:40<19:52, 17.39it/s][A[A[A


 77%|███████▋  | 69701/90447 [1:06:54<19:52, 17.39it/s][A[A[A

5a8164fb5542995ce29dcbf6
5abe4bb855429965af743eb8
5a90abc355429933b8a2058a
5ab6460b5542995eadeeff96





 96%|█████████▋| 87259/90447 [1:23:20<03:02, 17.44it/s][A[A[A


 96%|█████████▋| 87259/90447 [1:23:34<03:02, 17.44it/s][A[A[A

5a8c7b125542995e66a47614
5a8a317455429930ff3c0cef


100%|██████████| 90447/90447 [1:26:22<00:00, 17.45it/s]


# Model

In [6]:
from pytorch_pretrained_bert.modeling import (
    BertPreTrainedModel as PreTrainedBertModel, # The name was changed in the new versions of pytorch_pretrained_bert
    BertModel,
    BertLayerNorm,
    gelu,
    BertEncoder,
    BertPooler,
)
import torch
from torch import nn
from utils import (
    fuzzy_find,
    find_start_end_after_tokenized,
    find_start_end_before_tokenized,
    bundle_part_to_batch,
)
from pytorch_pretrained_bert.tokenization import (
    whitespace_tokenize,
    BasicTokenizer,
    BertTokenizer,
)
import re
import pdb


class MLP(nn.Module):
    """Stack of linear layers with dropout on every layer input.

    All layers but the last are followed by GELU and, when hidden layers
    exist, by a LayerNorm of the hidden width.

    Args:
        input_sizes: Layer widths, from the input size to the output size.
        dropout_prob (float): Dropout probability. Defaults to 0.2.
        bias (bool): Whether the linear layers have a bias. Defaults to False.
    """

    def __init__(self, input_sizes, dropout_prob=0.2, bias=False):
        super().__init__()
        self.layers = nn.ModuleList()
        for fan_in, fan_out in zip(input_sizes[:-1], input_sizes[1:]):
            self.layers.append(nn.Linear(fan_in, fan_out, bias=bias))
        # One LayerNorm per hidden width (none for a single linear layer).
        self.norm_layers = nn.ModuleList()
        for hidden_width in input_sizes[1:-1]:
            self.norm_layers.append(nn.LayerNorm(hidden_width))
        self.drop_out = nn.Dropout(p=dropout_prob)

    def forward(self, x):
        last_index = len(self.layers) - 1
        for index, linear in enumerate(self.layers):
            x = linear(self.drop_out(x))
            if index == last_index:
                # The output layer stays linear.
                continue
            x = gelu(x)
            if len(self.norm_layers):
                x = self.norm_layers[index](x)
        return x


class GCN(nn.Module):
    """Two rounds of graph propagation followed by a per-node scalar score.

    The same ``diffusion`` and ``retained`` weights are used in both rounds.
    """

    def init_weights(self, module):
        """ Initialize the weights.
        """
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=0.05)

    def __init__(self, input_size):
        super().__init__()
        self.diffusion = nn.Linear(input_size, input_size, bias=False)
        self.retained = nn.Linear(input_size, input_size, bias=False)
        self.predict = MLP(input_sizes=(input_size, input_size, 1))
        self.apply(self.init_weights)

    def forward(self, A, x):
        """Score every node.

        Args:
            A: Adjacency matrix; its transpose is multiplied with the node features.
            x: Node features, one row per node.

        Returns:
            One score per node (the trailing size-1 dimension is squeezed).
        """
        transposed_adj = A.t()
        for _ in range(2):
            # Keep a transformed copy of each node and add what flows in
            # along the edges.
            x = gelu(self.retained(x) + transposed_adj.mm(gelu(self.diffusion(x))))
        return self.predict(x).squeeze(-1)


class BertEmbeddingsPlus(nn.Module):
    """Word, position and token-type embeddings plus a sentence-type embedding.

    ``token_type_ids`` is used twice: reduced to 0/1 for the token-type
    table, and as-is for the sentence-type table.
    """

    def __init__(self, config, max_sentence_type=30):
        super().__init__()
        hidden = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden)
        self.sentence_type_embeddings = nn.Embedding(max_sentence_type, hidden)
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(hidden, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Positions 0..seq_length-1, repeated for every row of the batch.
        positions = torch.arange(
            input_ids.size(1), dtype=torch.long, device=input_ids.device
        )
        positions = positions.unsqueeze(0).expand_as(input_ids)
        binary_segments = (token_type_ids > 0).long()

        summed = self.word_embeddings(input_ids)
        summed = summed + self.position_embeddings(positions)
        summed = summed + self.token_type_embeddings(binary_segments)
        summed = summed + self.sentence_type_embeddings(token_type_ids)
        return self.dropout(self.LayerNorm(summed))


class BertModelPlus(BertModel):
    """BertModel using BertEmbeddingsPlus that also returns an intermediate layer."""

    def __init__(self, config):
        # Start the initialisation chain *after* BertModel, so that
        # BertModel.__init__ itself is skipped and the sub-modules are
        # created here instead.
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddingsPlus(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(
        self, input_ids, token_type_ids=None, attention_mask=None, output_hidden=-4
    ):
        """Encode the input.

        Returns:
            tuple: (last encoder layer, encoder layer number ``output_hidden``).
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # [batch_size, to_seq_length] -> [batch_size, 1, 1, to_seq_length], so
        # that the mask broadcasts over heads and query positions. The cast
        # to the parameter dtype keeps fp16 working.
        param_dtype = next(self.parameters()).dtype
        additive_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=param_dtype)
        # 0.0 where attention is allowed, -10000.0 on masked positions; added
        # to the raw scores before the softmax this effectively removes them.
        additive_mask = (1.0 - additive_mask) * -10000.0

        all_layers = self.encoder(
            self.embeddings(input_ids, token_type_ids),
            additive_mask,
            output_all_encoded_layers=True,
        )
        return all_layers[-1], all_layers[output_hidden]

class BertForMultiHopQuestionAnswering(PreTrainedBertModel):
    """BERT with four span heads: hop start/end and answer start/end (System 1)."""

    def __init__(self, config):
        super(BertForMultiHopQuestionAnswering, self).__init__(config)
        self.bert = BertModelPlus(config)
        # One logit per token for each of: hop start, hop end, ans start, ans end.
        self.qa_outputs = nn.Linear(config.hidden_size, 4)
        self.apply(self.init_bert_weights)

    def forward(
        self,
        input_ids,
        token_type_ids=None,
        attention_mask=None,
        sep_positions=None,
        hop_start_weights=None,
        hop_end_weights=None,
        ans_start_weights=None,
        ans_end_weights=None,
        B_starts=None,
        allow_limit=(0, 0),
    ):
        """ Extract spans by System 1.
        
        Args:
            input_ids (LongTensor): Token ids of word-pieces. (batch_size * max_length)
            token_type_ids (LongTensor): The A/B Segmentation in BERTs. (batch_size * max_length)
            attention_mask (LongTensor): Indicating whether the position is a token or padding. (batch_size * max_length)
            sep_positions (LongTensor): Positions of [SEP] tokens, mainly used in finding the num_sen of supporting facts. (batch_size * max_seps)
            hop_start_weights (Tensor): The ground truth of the probability of hop start positions. The weight of sample has been added on the ground truth. 
                (You can verify it by examining the gradient of binary cross entropy.)
            hop_end_weights ([Tensor]): The ground truth of the probability of hop end positions.
            ans_start_weights ([Tensor]): The ground truth of the probability of ans start positions.
            ans_end_weights ([Tensor]): The ground truth of the probability of ans end positions.
            B_starts (LongTensor): Start positions of sentence B.
            allow_limit (tuple, optional): An Offset for negative threshold. Defaults to (0, 0).
        
        Returns:
            Depends on the arguments:
            * ``sep_positions`` is None: ``semantics`` only.
            * ``sep_positions`` has no columns: ``(empty, empty, semantics, empty)``.
            * ``hop_start_weights`` is given (train mode): ``(hop_loss, ans_loss, semantics)``.
            * otherwise (eval mode): ``(hop_preds, ans_preds, semantics, ans_start_gap)``,
              where each prediction row is (start, end, sen_num).
        """
        batch_size = input_ids.size()[0]
        device = input_ids.get_device() if input_ids.is_cuda else torch.device("cpu")
        sequence_output, hidden_output = self.bert(
            input_ids, token_type_ids, attention_mask
        )
        # Representation at position 0 of the intermediate layer returned by BertModelPlus.
        semantics = hidden_output[:, 0]
        # Some shapes: sequence_output [batch_size, max_length, hidden_size], semantics [batch_size, hidden_size]
        if sep_positions is None:
            return semantics  # Only semantics, used in bundle forward
        else:
            max_sep = sep_positions.size()[-1]
        if max_sep == 0:
            empty = torch.zeros(batch_size, 0, dtype=torch.long, device=device)
            return (
                empty,
                empty,
                semantics,
                empty,
            )  # Only semantics, used in eval, the same ``empty'' variable is a mistake in general cases but simple

        # Predict spans
        logits = self.qa_outputs(sequence_output)
        hop_start_logits, hop_end_logits, ans_start_logits, ans_end_logits = logits.split(
            1, dim=-1
        )
        hop_start_logits = hop_start_logits.squeeze(-1)
        hop_end_logits = hop_end_logits.squeeze(-1)
        ans_start_logits = ans_start_logits.squeeze(-1)
        ans_end_logits = ans_end_logits.squeeze(-1)  # Shape: [batch_size, max_length]

        if hop_start_weights is not None:  # Train mode
            lgsf = torch.nn.LogSoftmax(
                dim=1
            )  # If there is no targeted span in the sentence, start_weights = end_weights = 0(vec)
            # Weighted negative log-likelihood over positions, one value per sample.
            hop_start_loss = -torch.sum(
                hop_start_weights * lgsf(hop_start_logits), dim=-1
            )
            hop_end_loss = -torch.sum(hop_end_weights * lgsf(hop_end_logits), dim=-1)
            ans_start_loss = -torch.sum(
                ans_start_weights * lgsf(ans_start_logits), dim=-1
            )
            ans_end_loss = -torch.sum(ans_end_weights * lgsf(ans_end_logits), dim=-1)
            # Batch mean of the averaged start/end losses.
            hop_loss = torch.mean((hop_start_loss + hop_end_loss)) / 2
            ans_loss = torch.mean((ans_start_loss + ans_end_loss)) / 2
        else:
            # In eval mode, find the exact top K spans.
            K_hop, K_ans = 3, 1
            hop_preds = torch.zeros(
                batch_size, K_hop, 3, dtype=torch.long, device=device
            )  # (start, end, sen_num)
            ans_preds = torch.zeros(
                batch_size, K_ans, 3, dtype=torch.long, device=device
            )
            ans_start_gap = torch.zeros(batch_size, device=device)
            # The same decoding runs twice: u == 0 for hop spans, u == 1 for answer spans.
            for u, (start_logits, end_logits, preds, K, allow) in enumerate(
                (
                    (
                        hop_start_logits,
                        hop_end_logits,
                        hop_preds,
                        K_hop,
                        allow_limit[0],
                    ),
                    (
                        ans_start_logits,
                        ans_end_logits,
                        ans_preds,
                        K_ans,
                        allow_limit[1],
                    ),
                )
            ):
                for i in range(batch_size):
                    # Samples whose first sep position is not positive are skipped.
                    if sep_positions[i, 0] > 0:
                        # Best K start positions, searched from B_starts[i] onwards.
                        values, indices = start_logits[i, B_starts[i] :].topk(K)
                        for k, index in enumerate(indices):
                            # The logit at position 0 (shifted by ``allow``) is the
                            # negative threshold: starts scoring at or below it are rejected.
                            if values[k] <= start_logits[i, 0] - allow:  # not golden
                                if u == 1: # For ans spans
                                    ans_start_gap[i] = start_logits[i, 0] - values[k]
                                break
                            # topk indices are relative to the slice; shift them back.
                            start = index + B_starts[i]
                            # find ending
                            # ``ending`` and ``j`` keep their values after this loop:
                            # the first sep position after ``start`` and its index.
                            for j, ending in enumerate(sep_positions[i]):
                                if ending > start or ending <= 0:
                                    break
                            if ending <= start:
                                break
                            # The end is searched in at most 10 positions from the start.
                            ending = min(ending, start + 10)
                            end = torch.argmax(end_logits[i, start:ending]) + start
                            preds[i, k, 0] = start
                            preds[i, k, 1] = end
                            preds[i, k, 2] = j
        return (
            (hop_loss, ans_loss, semantics)
            if hop_start_weights is not None
            else (hop_preds, ans_preds, semantics, ans_start_gap)
        )


class CognitiveGNN(nn.Module):
    """Reasoning module on top of the span extractor.

    A GCN scores the answer node for wh- questions (type 0); ``both_net`` and
    ``select_net`` classify the difference between two nodes for question
    types 1 and 2 respectively.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.gcn = GCN(hidden_size)
        self.both_net = MLP((hidden_size, hidden_size, 1))
        self.select_net = MLP((hidden_size, hidden_size, 1))

    def forward(self, bundle, model, device):
        """Compute the three training losses for one bundle.

        Args:
            bundle: Training sample; the fields consumed by
                ``bundle_part_to_batch`` plus ``additional_nodes``, ``adj``,
                ``question_type`` and ``answer_id`` are read.
            model: The span-extraction model, called once on the batch and
                once more on the additional nodes, if any.
            device: Device on which all tensors are placed.

        Returns:
            tuple: (hop_loss, ans_loss, final_loss).
        """
        batch = tuple(t.to(device) for t in bundle_part_to_batch(bundle))
        # Shape of semantics: [num_para, hidden_size]
        hop_loss, ans_loss, semantics = model(*batch)

        extra_nodes = bundle.additional_nodes
        extra_count = len(extra_nodes)
        if extra_count > 0:
            # Nodes without a paragraph: encode them only for their semantics.
            padded_shape = (extra_count, max(len(node) for node in extra_nodes))
            ids = torch.zeros(padded_shape, dtype=torch.long, device=device)
            segment_ids = torch.zeros(padded_shape, dtype=torch.long, device=device)
            input_mask = torch.zeros(padded_shape, dtype=torch.long, device=device)
            for row, node in enumerate(extra_nodes):
                node_length = len(node)
                ids[row, :node_length] = torch.tensor(node, dtype=torch.long)
                input_mask[row, :node_length] = 1
            extra_semantics = model(ids, segment_ids, input_mask)
            semantics = torch.cat((semantics, extra_semantics), dim=0)

        # One semantic vector per node of the graph.
        assert semantics.size()[0] == bundle.adj.size()[0]

        if bundle.question_type == 0:  # Wh-
            node_scores = self.gcn(bundle.adj.to(device), semantics)
            gold = torch.tensor([bundle.answer_id], dtype=torch.long, device=device)
            final_loss = torch.nn.CrossEntropyLoss()(node_scores.unsqueeze(0), gold)
        else:
            x, y, ans = bundle.answer_id
            target = torch.tensor(ans, dtype=torch.float, device=device)
            if bundle.question_type == 1:
                classifier = self.both_net
            else:
                classifier = self.select_net
            logit = classifier(semantics[x] - semantics[y]).squeeze(-1)
            final_loss = 0.2 * torch.nn.functional.binary_cross_entropy_with_logits(
                logit, target.to(device)
            )
        return hop_loss, ans_loss, final_loss


# Small demo: tokenize a question followed by a paragraph and print the
# number of word pieces produced.
if __name__ == "__main__":
    BERT_MODEL = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, do_lower_case=True)
    # The two strings are concatenated into a single text before tokenizing.
    orig_text = "".join(
        [
            "Theatre Centre is a UK-based theatre company touring new plays for young audiences aged 4 to 18, founded in 1953 by Brian Way, the company has developed plays by writers including which British writer, dub poet and Rastafarian?",
            " It is the largest urban not-for-profit theatre company in the country and the largest in Western Canada, with productions taking place at the 650-seat Stanley Industrial Alliance Stage, the 440-seat Granville Island Stage, the 250-seat Goldcorp Stage at the BMO Theatre Centre, and on tour around the province.",
        ]
    )
    tokenized_text = tokenizer.tokenize(orig_text)
    print(len(tokenized_text))



112


In [7]:
import re
import json
from tqdm import tqdm, trange
import pdb
import random
from collections import namedtuple
import numpy as np
import copy
import traceback
import torch
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.optimization import BertAdam



class Bundle:
    """Plain attribute container for all the data of one training question.

    The class body is intentionally empty: the attributes are the names listed
    in FIELDS and are attached at runtime with ``setattr``.
    """


# Attribute names carried by a Bundle. The order matters to code that slices
# this list (e.g. ``FIELDS[:7]``).
FIELDS = [
    'ids',
    'hop_start_weights',
    'hop_end_weights',
    'ans_start_weights',
    'ans_end_weights',
    'segment_ids',
    'sep_positions',
    'additional_nodes',
    'adj',
    'answer_id',
    'question_type',
    '_id',
]


# Judge question type with interrogative words
GENERAL_WD = ['is', 'are', 'am', 'was', 'were', 'have', 'has', 'had', 'can', 'could',
              'shall', 'will', 'should', 'would', 'do', 'does', 'did', 'may', 'might', 'must', 'ought', 'need', 'dare']
GENERAL_WD += [x.capitalize() for x in GENERAL_WD]
# Every alternative must be followed by a space. Joining with ' |' left the last
# word ('Dare') without one, so questions such as "Daredevil ..." were
# mistaken for general (yes/no) questions.
GENERAL_WD = re.compile('(?:' + '|'.join(GENERAL_WD) + ') ')


def judge_question_type(q : str, G = GENERAL_WD) -> int:
    """Guess the question type from its surface form.

    Args:
        q (str): The question text.
        G (re.Pattern, optional): Pattern matched against the start of ``q`` to
            recognise a leading auxiliary / modal verb. Defaults to GENERAL_WD.

    Returns:
        int: 2 if the question contains ' or ' (alternative question),
            1 if it starts with an auxiliary verb followed by a space (general question),
            0 otherwise (special / Wh- question).
    """
    if ' or ' in q:
        return 2
    if G.match(q):
        return 1
    return 0

def improve_question_type_and_answer(data, e2i):
    '''Refine the heuristic question type of a training sample and derive its answer target.

    The type guessed from the question text is corrected with the gold answer
    and the number of compared entities in ``data['Q_edge']``. For a special
    question (type 0) the target is the index of the answer node; otherwise it
    is the pair of compared node indices followed by the comparison result (0 / 1).

    Args:
        data (Json): Refined distractor-setting sample. Reads 'question', 'answer'
            and, for non-special questions, 'Q_edge'.
        e2i (dict): entity2index dict.

    Raises:
        ValueError: If the answer cannot be matched to any entity, if a type 1
            question does not have exactly two entries in 'Q_edge', or if the
            answer of a type 2 question matches neither option.

    Returns:
        (int, int or [int, int, 0 / 1], string): question_type, answer_id and answer_entity.
    '''
    question_type = judge_question_type(data['question'])
    answer = data['answer']

    # A yes/no answer always means a general question, whatever the heuristic said.
    if answer in ('yes', 'no'):
        question_type = 1
        answer_entity = answer
    else:
        # the answer must be recoverable as (a fuzzy match of) a known entity
        answer_entity = fuzzy_retrieve(answer, e2i, 'distractor', 80)
        if answer_entity is None:
            raise ValueError('Cannot find answer: {}'.format(answer))

    # Special question: the target is simply the answer node.
    if question_type == 0:
        return question_type, e2i[answer_entity], answer_entity

    compared = data['Q_edge']
    if len(compared) != 2:
        if question_type == 1:
            raise ValueError('There must be 2 entities in "Q_edge" for type 1 question.')
        # Heuristic said "alternative question" but nothing is compared: treat as type 0.
        return 0, e2i[answer_entity], answer_entity

    first, second = compared[0], compared[1]
    answer_id = [e2i[first[0]], e2i[second[0]]]  # compared nodes
    if question_type == 1:
        answer_id.append(int(answer == 'yes'))
    elif question_type == 2:
        if answer == first[1]:
            answer_id.append(0)
        elif answer == second[1]:
            answer_id.append(1)
        else:
            # no exact option match: fall back to fuzzy similarity with each option
            score_first = fuzz.partial_ratio(answer, first[1])
            score_second = fuzz.partial_ratio(answer, second[1])
            if score_first < 50 and score_second < 50:
                raise ValueError('There is no exact match in selecting question. answer: {}'.format(answer))
            answer_id.append(0 if score_first > score_second else 1)
    return question_type, answer_id, answer_entity

def convert_question_to_samples_bundle(tokenizer, data: 'Json refined', neg = 2):
    '''Make training samples.
    
    Convert one distractor-setting sample (question + paragraphs + answer + supporting facts)
    into a Bundle. The bundle holds one BERT input per paragraph of ``data['context']``
    (with start/end weights for next-hop and answer spans), the token ids of extra graph
    nodes (the answer node when it is not a paragraph title, plus sampled negative answer
    nodes), and a column-normalized adjacency matrix over all nodes.
    
    Args:
        tokenizer (BertTokenizer): BERT Tokenizer to transform sentences to a list of word pieces.
        data (Json): Refined distractor-setting sample with gold-only cognitive graph. Keys read
            here: 'context', 'supporting_facts', 'question', 'answer', '_id'
            ('Q_edge' is read by improve_question_type_and_answer).
        neg (int, optional): Defaults to 2. Maximum number of negative answer nodes to add
            (only used for question type 0).
    
    Raises:
        ValueError: Propagated from improve_question_type_and_answer.
        AssertionError: If an edge target that is not a paragraph title differs from
            data['answer'], or if an internal size invariant is violated.

    Returns:
        Bundle: An object carrying every attribute named in FIELDS.
    '''

    context = dict(data['context']) # paragraph title -> list of sentences
    # (paragraph title, sentence index) -> edges leaving that gold sentence
    gold_sentences_set = dict([((para, sen), edges) for para, sen, edges in data['supporting_facts']]) 
    e2i, i2e = {}, [] # entity2index, index2entity
    for entity, sens in context.items():
        e2i[entity] = len(i2e)
        i2e.append(entity)
    # pre-extracted clues, one token list per entity.
    # NOTE: the inner list object is shared by all slots; this is safe only because the
    # code below rebinds clues[y] to a new list instead of mutating it in place.
    clues = [[]] * len(i2e)

    # Parallel lists: position k in each of them describes the k-th paragraph sample.
    ids, hop_start_weights, hop_end_weights, ans_start_weights, ans_end_weights, segment_ids, sep_positions, additional_nodes = [], [], [], [], [], [], [], []
    tokenized_question = ['[CLS]'] + tokenizer.tokenize(data['question']) + ['[SEP]']

    # Extract clues for entities in the gold-only cognitive graph:
    # the clue of entity y is every gold sentence (of another paragraph x) that links to y.
    for entity_x, sen, edges in data['supporting_facts']:
        for entity_y, _, _, _ in edges:
            if entity_y not in e2i: # entity y must be the answer
                assert data['answer'] == entity_y
                e2i[entity_y] = len(i2e)
                i2e.append(entity_y)
                clues.append([])
            if entity_x != entity_y:
                y = e2i[entity_y]
                clues[y] = clues[y] + tokenizer.tokenize(context[entity_x][sen]) + ['[SEP]']
    
    question_type, answer_id, answer_entity = improve_question_type_and_answer(data, e2i)
    
    # Construct training samples: question + clues + as many paragraph sentences as fit
    for entity, para in context.items():
        num_hop, num_ans = 0, 0 # counters of marked spans (not read again in this function)
        tokenized_all = tokenized_question + clues[e2i[entity]]
        if len(tokenized_all) > 512: # BERT-base accepts at most 512 tokens
            tokenized_all = tokenized_all[:512]
            print('CLUES TOO LONG, id: {}'.format(data['_id']))
        # initialize a sample for ``entity''; question and clue tokens get segment 0 and zero weights
        sep_position = [] 
        segment_id = [0] * len(tokenized_all)
        hop_start_weight = [0] * len(tokenized_all)
        hop_end_weight = [0] * len(tokenized_all)
        ans_start_weight = [0] * len(tokenized_all)
        ans_end_weight = [0] * len(tokenized_all)

        for sen_num, sen in enumerate(para):
            tokenized_sen = tokenizer.tokenize(sen) + ['[SEP]']
            # stop when the next sentence would overflow 512 tokens or after sentence index 15
            if len(tokenized_all) + len(tokenized_sen) > 512 or sen_num > 15:
                break
            tokenized_all += tokenized_sen
            segment_id += [sen_num + 1] * len(tokenized_sen) # segment id = sentence index + 1
            sep_position.append(len(tokenized_all) - 1) # position of this sentence's [SEP]
            # per-sentence weights, appended to the per-sample weights below
            hs_weight = [0] * len(tokenized_sen)
            he_weight = [0] * len(tokenized_sen)
            as_weight = [0] * len(tokenized_sen)
            ae_weight = [0] * len(tokenized_sen)
            if (entity, sen_num) in gold_sentences_set:
                edges = gold_sentences_set[(entity, sen_num)]
                # presumably one (start, end) token interval inside tokenized_sen per matched
                # string, in the same order as edges -- helper is defined elsewhere, TODO confirm
                intervals = find_start_end_after_tokenized(tokenizer, tokenized_sen,
                    [matched  for _, matched, _, _ in edges])
                for j, (l, r) in enumerate(intervals):
                    # for non-special questions (type > 0) every gold span is treated as an answer span
                    if edges[j][0] == answer_entity or question_type > 0: # successive node edges[j][0] is answer node
                        as_weight[l] = ae_weight[r] = 1
                        num_ans += 1
                    else: # edges[j][0] is next-hop node
                        hs_weight[l] = he_weight[r] = 1
                        num_hop += 1
            hop_start_weight += hs_weight
            hop_end_weight += he_weight
            ans_start_weight += as_weight
            ans_end_weight += ae_weight
            
        assert len(tokenized_all) <= 512
        # if entity is a negative node, train negative threshold at [CLS] 
        if 1 not in hop_start_weight:
            hop_start_weight[0] = 0.1
        if 1 not in ans_start_weight:
            ans_start_weight[0] = 0.1

        ids.append(tokenizer.convert_tokens_to_ids(tokenized_all))
        sep_positions.append(sep_position)
        segment_ids.append(segment_id)
        hop_start_weights.append(hop_start_weight)
        hop_end_weights.append(hop_end_weight)
        ans_start_weights.append(ans_start_weight)
        ans_end_weights.append(ans_end_weight)

    # Construct negative answer nodes for task #2(answer node prediction)
    n = len(context) # number of graph nodes so far (one per paragraph)
    edges_in_bundle = [] # (source node index, target node index) pairs
    if question_type == 0:
        # find all edges and prepare forbidden set(containing answer) for negative sampling
        forbidden = set([])
        for para, sen, edges in data['supporting_facts']:
            for x, matched, l, r in edges:
                edges_in_bundle.append((e2i[para], e2i[x]))
                if x == answer_entity:
                    forbidden.add((para, sen))
        # the answer is not a paragraph title: it becomes an extra node described by its clues only
        if answer_entity not in context and answer_entity in e2i:
            n += 1
            tokenized_all = tokenized_question + clues[e2i[answer_entity]]
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('ANSWER TOO LONG! id: {}'.format(data['_id']))
            additional_nodes.append(tokenizer.convert_tokens_to_ids(tokenized_all))

        for i in range(neg):
            # build negative answer node n+i
            # pick a random sentence; retry once if it is a sentence that points to the answer
            father_para = random.choice(list(context.keys()))
            father_sen = random.randrange(len(context[father_para]))
            if (father_para, father_sen) in forbidden:
                father_para = random.choice(list(context.keys()))
                father_sen = random.randrange(len(context[father_para]))
            if (father_para, father_sen) in forbidden:
                # give up on this negative node; the decrement does not shorten the loop
                # because range(neg) was already evaluated
                neg -= 1
                continue
            tokenized_all = tokenized_question + tokenizer.tokenize(context[father_para][father_sen]) + ['[SEP]']
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('NEG TOO LONG! id: {}'.format(data['_id']))
            additional_nodes.append(tokenizer.convert_tokens_to_ids(tokenized_all))
            edges_in_bundle.append((e2i[father_para], n))
            n += 1

    if question_type >= 1:
        # no extra nodes here: keep only the gold edges whose both ends are paragraph nodes
        for para, sen, edges in data['supporting_facts']:
            for x, matched, l, r in edges:
                if e2i[para] < n and  e2i[x] < n:
                    edges_in_bundle.append((e2i[para], e2i[x]))
                    
    assert n == len(additional_nodes) + len(context)
    # adjacency with weight 2 on the diagonal and 1 on every edge, then each column is scaled to sum to 1
    adj = torch.eye(n) * 2
    for x, y in edges_in_bundle:
        adj[x, y] = 1
    adj /= torch.sum(adj, dim=0, keepdim=True)

    _id = data['_id']
    ret = Bundle()
    # copy the local variable named like each field onto the bundle (eval looks the name up here)
    for field in FIELDS:
        setattr(ret, field, eval(field))
    return ret
    
def homebrew_data_loader(bundles, mode : 'bundle or tensors' = 'tensors', batch_size = 8):
    '''Build a single-pass batch generator, similar in spirit to a pytorch DataLoader.

    Task #1 and task #2 consume different data. For task #1 ('tensors') the per-paragraph
    samples of all bundles are pooled, shuffled and cut into batches. For task #2
    ('bundle') every item is a whole bundle (a whole graph), so the bundles themselves
    are shuffled and yielded one by one.

    Args:
        bundles (list): List of bundles for questions. Shuffled in place in 'bundle' mode.
        mode (string, optional): Defaults to 'tensors'. 'tensors' represents dataloader for task #1,
            'bundle' represents dataloader for task #2.
        batch_size (int, optional): Defaults to 8. Only used in 'tensors' mode.

    Raises:
        ValueError: Invalid mode

    Returns:
        (int, Generator): number of batches and a generator to generate batches.
    '''
    if mode not in ('bundle', 'tensors'):
        raise ValueError('mode must be "bundle" or "tensors"!')

    if mode == 'bundle':
        random.shuffle(bundles)

        def iterate_bundles():
            yield from bundles

        return len(bundles), iterate_bundles()

    # mode == 'tensors': pool the first seven fields (the per-sample ones) of every bundle
    sample_fields = FIELDS[:7]
    pooled = Bundle()
    for name in sample_fields:
        column = []
        setattr(pooled, name, column)
        for bundle in bundles:
            column.extend(getattr(bundle, name))
    n = len(column)

    # apply one common random permutation to every pooled field
    orders = np.random.permutation(n)
    for name in sample_fields:
        column = getattr(pooled, name)
        setattr(pooled, name, [column[x] for x in orders])

    num_batch = (n - 1) // batch_size + 1

    def iterate_batches():
        for batch_idx in range(num_batch):
            start = batch_idx * batch_size
            stop = min(start + batch_size, n)
            yield bundle_part_to_batch(pooled, start, stop)

    return num_batch, iterate_batches()
        

# Train

# 1. Span Extraction

In [None]:
import re
import json
from tqdm import tqdm, trange
import pdb
import random
from collections import namedtuple
import numpy as np
import copy
import torch
import traceback
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.optimization import BertAdam


def train(bundles, model1, device, mode, model2, batch_size, num_epoch, gradient_accumulation_steps, lr1, lr2, alpha):
    '''Train Sys1 and Sys2 models.
    
    Train models by task #1 ('tensors': only Sys1 is optimized) or task #2 ('bundle':
    Sys2 is optimized from the start and Sys1 joins once the warm-up condition is met).
    Exceptions raised while processing a batch are printed and the batch is skipped.
    
    Args:
        bundles (list): List of bundles.
        model1 (BertForMultiHopQuestionAnswering): System 1 model. It must expose ``.module``
            (e.g. be wrapped in ``torch.nn.DataParallel``) because checkpoints are saved
            from ``model1.module``.
        device (torch.device): The device which models and data are on.
        mode (str): Task identifier('tensors' or 'bundle').
        model2 (CognitiveGNN): System 2 model.
        batch_size (int): Batch size passed to homebrew_data_loader.
        num_epoch (int): Number of passes over the data.
        gradient_accumulation_steps (int): Number of batches between optimizer steps.
        lr1 (float): Learning rate for Sys1.
        lr2 (float): Learning rate for Sys2.
        alpha (float): Balance factor for loss of two systems.
    
    Returns:
        (model1, model2): The trained models (the same objects that were passed in).
    '''

    # Prepare optimizer for Sys1
    param_optimizer = list(model1.named_parameters())
    # hack to remove pooler, which is not used.
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    num_batch, dataloader = homebrew_data_loader(bundles, mode = mode, batch_size=batch_size)
    num_steps = num_batch * num_epoch
    global_step = 0
    opt1 = BertAdam(optimizer_grouped_parameters, lr = lr1, warmup = 0.1, t_total=num_steps)
    model1.to(device)
    model1.train()

    # Prepare optimizer for Sys2
    if mode == 'bundle':
        opt2 = Adam(model2.parameters(), lr=lr2)
        model2.to(device)
        model2.train()
        warmed = False # warmup for jointly training

    for epoch in trange(num_epoch, desc = 'Epoch'):
        if epoch > 0:
            # homebrew_data_loader returns a one-shot generator: after the first epoch it is
            # exhausted, so a fresh (re-shuffled) one has to be built for every later epoch.
            _, dataloader = homebrew_data_loader(bundles, mode = mode, batch_size=batch_size)
        ans_mean, hop_mean = WindowMean(), WindowMean()
        opt1.zero_grad()
        if mode == 'bundle':
            final_mean = WindowMean()
            opt2.zero_grad()
        tqdm_obj = tqdm(dataloader, total = num_batch)

        for step, batch in enumerate(tqdm_obj):
            try:
                if mode == 'tensors':
                    batch = tuple(t.to(device) for t in batch)
                    # the third output (pooled semantics) is not needed for task #1
                    hop_loss, ans_loss, pooled_output = model1(*batch)
                    hop_loss, ans_loss = hop_loss.mean(), ans_loss.mean()
                    loss = ans_loss + hop_loss
                elif mode == 'bundle':
                    hop_loss, ans_loss, final_loss = model2(batch, model1, device)
                    hop_loss, ans_loss = hop_loss.mean(), ans_loss.mean()
                    loss = ans_loss + hop_loss + alpha * final_loss
                loss.backward()

                if (step + 1) % gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses. From BERT pytorch examples
                    lr_this_step = lr1 * warmup_linear(global_step/num_steps, warmup = 0.1)
                    for param_group in opt1.param_groups:
                        param_group['lr'] = lr_this_step
                    global_step += 1
                    if mode == 'bundle':
                        opt2.step()
                        opt2.zero_grad()
                        final_mean_loss = final_mean.update(final_loss.item())
                        tqdm_obj.set_description('ans_loss: {:.2f}, hop_loss: {:.2f}, final_loss: {:.2f}'.format(
                            ans_mean.update(ans_loss.item()), hop_mean.update(hop_loss.item()), final_mean_loss))
                        # During warming period, model1 is frozen and model2 is trained to normal weights
                        if final_mean_loss < 0.9 and step > 100: # ugly manual hyperparam
                            warmed = True
                        if warmed:
                            opt1.step()
                        # gradients of model1 are discarded either way
                        opt1.zero_grad()
                    else:
                        opt1.step()
                        opt1.zero_grad()
                        tqdm_obj.set_description('ans_loss: {:.2f}, hop_loss: {:.2f}'.format(
                            ans_mean.update(ans_loss.item()), hop_mean.update(hop_loss.item())))
                    if step % 1000 == 0:
                        # periodic temporary checkpoint of both models
                        output_model_file = './models/bert-base-uncased.bin.tmp'
                        saved_dict = {'params1' : model1.module.state_dict()}
                        saved_dict['params2'] = model2.state_dict()
                        torch.save(saved_dict, output_model_file)
            except Exception:
                # best effort: report the failing batch and keep training
                traceback.print_exc()
                if mode == 'bundle':   
                    print(batch._id) 
    return (model1, model2)


def main(output_model_file = './models/bert-base-uncased.bin', load = False, mode = 'tensors', batch_size = 12, 
            num_epoch = 1, gradient_accumulation_steps = 1, lr1 = 1e-4, lr2 = 1e-4, alpha = 0.2):
    '''Build training bundles, train the two models and save them.

    Args:
        output_model_file (str): Checkpoint path. Written at the end and, when ``load``
            is true, also read at the start.
        load (bool): Resume from ``output_model_file`` instead of starting from the
            pretrained BERT weights.
        mode (str): 'tensors' (task #1) or 'bundle' (task #2), forwarded to train.
        batch_size (int): Forwarded to train.
        num_epoch (int): Forwarded to train.
        gradient_accumulation_steps (int): Forwarded to train.
        lr1 (float): Learning rate for Sys1.
        lr2 (float): Learning rate for Sys2.
        alpha (float): Balance factor for loss of two systems.
    '''
    BERT_MODEL = 'bert-base-uncased' # bert-large is too large for ordinary GPU on task #2
    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, do_lower_case=True)
    with open('./hotpot_train_v1.1_refined_1.json' ,'r') as fin:
        dataset = json.load(fin)
    bundles = []
    skipped = 0
    for data in tqdm(dataset):
        try:
            bundles.append(convert_question_to_samples_bundle(tokenizer, data))
        except ValueError:
            # samples that cannot be converted are dropped on purpose, but counted
            skipped += 1
    if skipped:
        print('Skipped {} of {} samples that could not be converted'.format(skipped, len(dataset)))
    device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')
    if load:
        print('Loading model from {}'.format(output_model_file))
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine
        model_state_dict = torch.load(output_model_file, map_location=device)
        model1 = BertForMultiHopQuestionAnswering.from_pretrained(BERT_MODEL, state_dict=model_state_dict['params1'])
        model2 = CognitiveGNN(model1.config.hidden_size)
        model2.load_state_dict(model_state_dict['params2'])

    else:
        model1 = BertForMultiHopQuestionAnswering.from_pretrained(BERT_MODEL,
                cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(-1))
        model2 = CognitiveGNN(model1.config.hidden_size)

    print('Start Training... on {} GPUs'.format(torch.cuda.device_count()))
    # train() saves checkpoints through model1.module, so model1 is always wrapped
    model1 = torch.nn.DataParallel(model1, device_ids = range(torch.cuda.device_count()))
    model1, model2 = train(bundles, model1=model1, device=device, mode=mode, model2=model2, # Then pass hyperparams
        batch_size=batch_size, num_epoch=num_epoch, gradient_accumulation_steps=gradient_accumulation_steps,lr1=lr1, lr2=lr2, alpha=alpha)
    
    print('Saving model to {}'.format(output_model_file))
    saved_dict = {'params1' : model1.module.state_dict()}
    saved_dict['params2'] = model2.state_dict()
    torch.save(saved_dict, output_model_file)

import fire
# Script entry point: hand main to python-fire, which builds the command line from it.
if __name__ == "__main__":
    fire.Fire(main)

# 2. Answer Prediction

In [None]:
# Second training stage: reload the saved checkpoint and train in 'bundle' mode (task #2).
main(load = True,mode = "bundle",batch_size = 2)

# Evaluation

In [None]:
import re
import json
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from torch.optim import Adam
from tqdm import tqdm, trange
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
import pdb
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from pytorch_pretrained_bert.optimization import BertAdam
import random
from collections import namedtuple
import numpy as np
import copy
# from line_profiler import LineProfiler

def cognitive_graph_propagate(tokenizer, data: 'Json eval(Context as pool)', model1, model2, device, setting:'distractor / fullwiki' = 'fullwiki', max_new_nodes = 5):
    """Answer the question in ``data'' by trained CogQA model.

    Starting from the paragraphs in ``data['context']``, System 1 (model1) is run on a
    frontier of nodes. Every predicted span either links to another node (hop span) or
    names an answer candidate (answer span). Nodes that receive a new incoming edge are
    visited again with the linking sentence as extra clue. When the frontier is empty,
    System 2 (model2) picks the answer: its GCN scores the answer candidates for special
    questions, otherwise one of its comparison classifiers is applied to two paragraphs.
    
    Args:
        tokenizer (Tokenizer): Word-Piece tokenizer.
        data (Json): Unrefined sample. Keys read here: 'context', 'question', '_id'.
        model1 (nn.Module): System 1 model.
        model2 (nn.Module): System 2 model (its ``gcn``, ``both_net`` and ``select_net`` are used).
        device (torch.device): Selected device.
        setting (string, optional): 'distractor / fullwiki'. Defaults to 'fullwiki'.
        max_new_nodes (int, optional): Maximum number of new nodes in cognitive graph. Defaults to 5.
            New nodes are only created while the graph has fewer than ``10 + max_new_nodes`` nodes.
    
    Returns:
        tuple: (gold_ret, ans_ret, graph_ret, ans_nodes_ret) --
            gold_ret: predicted supporting facts as (title, sentence index) pairs
                (a set; a one-element list on the single-paragraph shortcut),
            ans_ret: the predicted answer string,
            graph_ret: one '(title, sen_num) --> entity' string per edge of the graph,
            ans_nodes_ret: names of the answer candidate nodes.
    """
    context = dict(data['context'])
    e2i = dict([(entity, id) for id, entity in enumerate(context.keys())])
    n = len(context) # current number of nodes in the graph
    i2e = [''] * n
    for k, v in e2i.items():
        i2e[v] = k  
    prev = [[] for i in range(n)] # elements: (title, sen_num)
    queue = range(n) # the first frontier is every given paragraph
    semantics = [None] * n # latest semantic vector of every node, filled when the node is visited

    tokenized_question = ['[CLS]'] + tokenizer.tokenize(data['question']) + ['[SEP]']

    def construct_infer_batch(queue):
        """Construct next batch (frontier nodes to visit).

        Each input is question + clue sentences (from ``prev``) + the sentences of the
        node's own paragraph, zero-padded to the longest input of the batch.
        
        Args:
            queue (list): A queue containing frontier nodes.
        
        Returns:
            tuple: (ids, segment_ids, input_mask, sep_positions, tokenized_alls, B_starts), where
                the first four are padded tensors on ``device``, tokenized_alls holds the token
                lists and B_starts the offset at which each node's own paragraph begins.
        """
        ids, sep_positions, segment_ids, tokenized_alls, B_starts = [], [], [], [], []
        max_length, max_seps, num_samples = 0, 0, len(queue)
        for x in queue:
            tokenized_all = copy.copy(tokenized_question)
            # clues: every sentence known to link to node x
            for title, sen_num in prev[x]:
                tokenized_all += tokenizer.tokenize(context[title][sen_num]) + ['[SEP]']
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('PREV TOO LONG, id: {}'.format(data['_id']))
            segment_id = [0] * len(tokenized_all)
            sep_position = [] 
            B_starts.append(len(tokenized_all))
            for sen_num, sen in enumerate(context[i2e[x]]):
                tokenized_sen = tokenizer.tokenize(sen) + ['[SEP]']
                # same limits as in training: at most 512 tokens and sentence indices up to 15
                if len(tokenized_all) + len(tokenized_sen) > 512 or sen_num > 15:
                    break
                tokenized_all += tokenized_sen
                segment_id += [sen_num + 1] * len(tokenized_sen)
                sep_position.append(len(tokenized_all) - 1)
            max_length = max(max_length, len(tokenized_all))
            max_seps = max(max_seps, len(sep_position))
            tokenized_alls.append(tokenized_all)
            ids.append(tokenizer.convert_tokens_to_ids(tokenized_all))
            sep_positions.append(sep_position)
            segment_ids.append(segment_id)

        # zero-padded tensors; input_mask is 1 on real tokens
        ids_tensor = torch.zeros((num_samples, max_length), dtype = torch.long, device = device)
        sep_positions_tensor = torch.zeros((num_samples, max_seps), dtype = torch.long, device = device)
        segment_ids_tensor = torch.zeros((num_samples, max_length), dtype = torch.long, device = device)
        input_mask = torch.zeros((num_samples, max_length), dtype = torch.long, device = device)
        B_starts = torch.tensor(B_starts, dtype = torch.long, device = device)
        for i in range(num_samples):
            length = len(ids[i])
            ids_tensor[i, :length] = torch.tensor(ids[i], dtype = torch.long)
            sep_positions_tensor[i, :len(sep_positions[i])] = torch.tensor(sep_positions[i], dtype = torch.long)
            segment_ids_tensor[i, :length] = torch.tensor(segment_ids[i], dtype = torch.long)
            input_mask[i, :length] = 1
        return ids_tensor, segment_ids_tensor, input_mask, sep_positions_tensor, tokenized_alls, B_starts
    
    gold_ret, ans_nodes = set([]), set([])
    # passed to model1 as-is; only the second entry is ever changed (see the end of the loop).
    # NOTE(review): presumably offsets for the hop / answer negative thresholds -- confirm in model1
    allow_limit = [0, 0]
    while len(queue) > 0:
        # visit all nodes in the frontier queue
        ids, segment_ids, input_mask, sep_positions, tokenized_alls, B_starts = construct_infer_batch(queue)
        hop_preds, ans_preds, semantics_preds, no_ans_logits = model1(ids, segment_ids, input_mask, sep_positions,
            None, None, None, None, 
            B_starts, allow_limit)  
        new_queue = []
        for i, x in enumerate(queue):
            semantics[x] = semantics_preds[i]
            # for hop spans
            for k in range(hop_preds.size()[1]):
                # (start token, end token, sentence index) of the k-th predicted span
                l, r, j = hop_preds[i, k]
                j = j.item()
                if l == 0: # a start of 0 ends the list of spans for this node
                    break
                gold_ret.add((i2e[x], j)) # supporting facts
                orig_text = context[i2e[x]][j]
                pred_slice = tokenized_alls[i][l : r + 1]
                # map the word-piece span back to character offsets in the original sentence
                l, r = find_start_end_before_tokenized(orig_text, [pred_slice])[0]
                if l == r == 0:
                    continue    
                recovered_matched = orig_text[l: r]
                pool = context if setting == 'distractor' else (i2e[x], j)
                matched = fuzzy_retrieve(recovered_matched, pool, setting)    
                if matched is not None:
                    if setting == 'fullwiki' and matched not in e2i and n < 10 + max_new_nodes:
                        context_new = get_context_fullwiki(matched)
                        if len(context_new) > 0: # skip pages with no content (e.g. an unresolved redirection)
                            # create new nodes in the cognitive graph
                            context[matched] = context_new
                            prev.append([])
                            semantics.append(None)
                            e2i[matched] = n
                            i2e.append(matched)
                            n += 1
                    if matched in e2i and e2i[matched] != x:
                        y = e2i[matched]
                        if y not in new_queue and (i2e[x], j) not in prev[y]:
                            # new edge means new clues! update the successor as frontier nodes.
                            new_queue.append(y)
                            prev[y].append(((i2e[x], j)))
            # for ans spans
            for k in range(ans_preds.size()[1]):
                l, r, j = ans_preds[i, k]
                j = j.item()
                if l == 0:
                    break
                gold_ret.add((i2e[x], j))
                orig_text = context[i2e[x]][j]
                pred_slice = tokenized_alls[i][l : r + 1]
                l, r = find_start_end_before_tokenized(orig_text, [pred_slice])[0]
                if l == r == 0:
                    continue    
                recovered_matched = orig_text[l: r]
                # answer spans are matched against the nodes already in the graph, in both settings
                matched = fuzzy_retrieve(recovered_matched, context, 'distractor', threshold=70)
                if matched is not None:
                    y = e2i[matched]
                    ans_nodes.add(y)
                    if (i2e[x], j) not in prev[y]:
                        prev[y].append(((i2e[x], j)))
                elif n < 10 + max_new_nodes:
                    # unknown answer text: add it as a new node with an empty paragraph,
                    # described only by the sentence it was found in
                    context[recovered_matched] = []
                    e2i[recovered_matched] = n
                    i2e.append(recovered_matched)
                    new_queue.append(n)
                    ans_nodes.add(n)
                    prev.append([(i2e[x], j)])
                    semantics.append(None)
                    n += 1
        if len(new_queue) == 0 and len(ans_nodes) == 0 and allow_limit[1] < 0.1: # must find one answer
            # ``allow'' is an offset of negative threshold. 
            # If no ans span is valid, make the minimal gap between negative threshold and probability of ans spans -0.1, and try again.
            prob, pos_in_queue = torch.min(no_ans_logits, dim = 0)
            new_queue.append(queue[pos_in_queue])
            allow_limit[1] = prob.item() + 0.1
        queue = new_queue

    question_type = judge_question_type(data['question'])

    # degenerate graphs: nothing to compare or propagate
    if n == 0:
        return set([]), 'yes', [], []
    if n == 1 and question_type > 0:
        ans_ret = 'yes' if question_type == 1 else i2e[0]
        return [(i2e[0], 0)], ans_ret, [], []
    # GCN || CompareNets
    semantics = torch.stack(semantics)
    if question_type == 0:
        # same adjacency construction as in training: 2 on the diagonal, 1 per edge, columns scaled to sum to 1
        adj = torch.eye(n, device = device) * 2
        for x in range(n):
            for title, sen_num in prev[x]:
                adj[e2i[title], x] = 1
        adj /= torch.sum(adj, dim=0, keepdim=True)
        pred = model2.gcn(adj, semantics)
        # only answer candidates may win the argmax
        for x in range(n):
            if x not in ans_nodes:
                pred[x] -= 10000.
        ans_ret = i2e[torch.argmax(pred).item()]
    else:
        # Take the most golden paragraphs as x,y
        gold_num = torch.zeros(n)
        for title, sen_num in gold_ret:
            gold_num[e2i[title]] += 1
        x, y = gold_num.topk(2)[1].tolist()
        diff_sem = semantics[x] - semantics[y]
        classifier = model2.both_net if question_type == 1 else model2.select_net
        pred = int(torch.sigmoid(classifier(diff_sem)).item() > 0.5)
        ans_ret = ['no', 'yes'][pred] if question_type == 1 else [i2e[x], i2e[y]][pred] 
    
    # drop a trailing parenthesised qualifier from the answer
    ans_ret = re.sub(r' \(.*?\)$', '', ans_ret)

    graph_ret = []
    for x in range(n):
        for title, sen_num in prev[x]:
            graph_ret.append('({}, {}) --> {}'.format(title, sen_num, i2e[x]))    

    ans_nodes_ret = [i2e[x] for x in ans_nodes]
    return gold_ret, ans_ret, graph_ret, ans_nodes_ret

def main(BERT_MODEL='bert-base-uncased', model_file='./models/bert-base-uncased.bin', data_file='./hotpot_dev_fullwiki_v1_merge.json', max_new_nodes=5):
    """Run the trained models over a dataset and write the predictions to disk.

    Args:
        BERT_MODEL (str): Name of the pretrained BERT model / vocabulary.
        model_file (str): Checkpoint holding 'params1' (System 1) and 'params2' (System 2).
        data_file (str): Dataset to answer. The setting is 'distractor' when the path
            contains that word, otherwise 'fullwiki'. Predictions are written next to it,
            with '.json' replaced by '_pred.json'.
        max_new_nodes (int): Forwarded to cognitive_graph_propagate.
    """
    setting = 'distractor' if 'distractor' in data_file else 'fullwiki'
    with open(data_file, 'r') as fin:
        dataset = json.load(fin)
    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, do_lower_case=True)
    # Inference is pinned to the CPU. To use a GPU when available, replace with:
    # torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')
    device = torch.device('cpu')
    print('Loading model from {}'.format(model_file))
    # map_location lets a checkpoint saved on GPU be restored on the selected device
    model_state_dict = torch.load(model_file, map_location=device)
    model1 = BertForMultiHopQuestionAnswering.from_pretrained(BERT_MODEL, state_dict=model_state_dict['params1'])
    model2 = CognitiveGNN(model1.config.hidden_size)
    model2.load_state_dict(model_state_dict['params2'])
    sp, answer, graphs = {}, {}, {}
    # this function only evaluates; the message used to say "Training"
    print('Start Evaluating... on {} GPUs'.format(torch.cuda.device_count()))
    model1 = torch.nn.DataParallel(model1, device_ids = range(torch.cuda.device_count()))
    model1.to(device).eval()
    model2.to(device).eval()

    with torch.no_grad():
        for data in tqdm(dataset):
            gold, ans, graph_ret, ans_nodes = cognitive_graph_propagate(tokenizer, data, model1, model2, device, setting = setting, max_new_nodes=max_new_nodes)
            sp[data['_id']] = list(gold)
            answer[data['_id']] = ans
            graphs[data['_id']] = graph_ret + ['answer_nodes: ' + ', '.join(ans_nodes)]
    pred_file = data_file.replace('.json', '_pred.json')
    with open(pred_file, 'w') as fout:
        json.dump({'answer': answer, 'sp': sp, 'graphs': graphs}, fout)
    
import fire
if __name__ == "__main__":

In [21]:
import sys
import ujson as json
import re
import string
from collections import Counter
import pickle

def normalize_answer(s):
    """Canonicalize an answer string before comparison.

    The steps are applied in this order: lowercase, delete every character of
    ``string.punctuation``, replace the standalone words "a", "an" and "the"
    with a space, and finally collapse all whitespace runs to single spaces
    (which also trims both ends).

    Args:
        s (string): Raw answer text.

    Returns:
        string: The normalized text.
    """
    lowered = s.lower()
    # One pass that drops every punctuation character.
    without_punctuation = lowered.translate(str.maketrans('', '', string.punctuation))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punctuation)
    return ' '.join(without_articles.split())


def f1_score(prediction, ground_truth):
    """Compute token-level F1, precision and recall between two answers.

    Both strings are first passed through normalize_answer. If either
    normalized string is one of 'yes', 'no' or 'noanswer' and the two differ,
    or if they share no token at all, the result is (0, 0, 0).

    Args:
        prediction (string): Predicted answer.
        ground_truth (string): Reference answer.

    Returns:
        tuple: (f1, precision, recall).
    """
    pred_text = normalize_answer(prediction)
    truth_text = normalize_answer(ground_truth)

    no_credit = (0, 0, 0)
    special_answers = ['yes', 'no', 'noanswer']

    # Special answers only count when they match exactly; no partial credit.
    if pred_text != truth_text and (pred_text in special_answers or truth_text in special_answers):
        return no_credit

    pred_tokens = pred_text.split()
    truth_tokens = truth_text.split()
    # Multiset intersection: each shared token counts as often as it occurs in both.
    overlap = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if not overlap:
        return no_credit

    precision = 1.0 * overlap / len(pred_tokens)
    recall = 1.0 * overlap / len(truth_tokens)
    return (2 * precision * recall) / (precision + recall), precision, recall


def exact_match_score(prediction, ground_truth):
    """Tell whether two answers are equal once both went through normalize_answer.

    Args:
        prediction (string): Predicted answer.
        ground_truth (string): Reference answer.

    Returns:
        bool: True if the normalized strings are identical.
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth

def update_answer(metrics, prediction, gold):
    """Add the answer scores of one example to the running totals.

    Args:
        metrics (dict): Running totals; the entries 'em', 'f1', 'prec' and 'recall' are incremented in place.
        prediction (string): Predicted answer.
        gold (string): Reference answer.

    Returns:
        tuple: (em, prec, recall) of this example, where em is the value returned by exact_match_score.
    """
    is_exact = exact_match_score(prediction, gold)
    f1_value, precision, recall_value = f1_score(prediction, gold)
    increments = (
        ('em', float(is_exact)),
        ('f1', f1_value),
        ('prec', precision),
        ('recall', recall_value),
    )
    for key, amount in increments:
        metrics[key] += amount
    return is_exact, precision, recall_value

def update_sp(metrics, prediction, gold):
    """Add the supporting-fact scores of one example to the running totals.

    Every item of prediction and gold is converted to a tuple and both
    collections are compared as sets, so duplicates count once.

    Args:
        metrics (dict): Running totals; the entries 'sp_em', 'sp_f1', 'sp_prec' and 'sp_recall' are incremented in place.
        prediction (iterable): Predicted facts, each convertible to a tuple.
        gold (iterable): Reference facts, each convertible to a tuple.

    Returns:
        tuple: (em, prec, recall) of this example as floats.
    """
    predicted_facts = {tuple(fact) for fact in prediction}
    gold_facts = {tuple(fact) for fact in gold}

    tp = len(predicted_facts & gold_facts)
    fp = len(predicted_facts - gold_facts)
    fn = len(gold_facts - predicted_facts)

    # Each ratio falls back to 0.0 when its denominator would be zero.
    if tp + fp > 0:
        prec = 1.0 * tp / (tp + fp)
    else:
        prec = 0.0
    if tp + fn > 0:
        recall = 1.0 * tp / (tp + fn)
    else:
        recall = 0.0
    if prec + recall > 0:
        f1 = 2 * prec * recall / (prec + recall)
    else:
        f1 = 0.0
    # Exact match means nothing spurious and nothing missing.
    em = 0.0 if fp + fn != 0 else 1.0

    metrics['sp_em'] += em
    metrics['sp_f1'] += f1
    metrics['sp_prec'] += prec
    metrics['sp_recall'] += recall
    return em, prec, recall

def eval(prediction_file, gold_file):
    """Score a prediction file against a gold file and print the averaged metrics.

    Answer metrics come from update_answer, supporting-fact metrics from
    update_sp. Joint metrics are only accumulated for examples that have both
    an answer and supporting facts in the prediction file; a missing entry is
    reported on stdout. All totals are divided by the number of gold examples.

    Args:
        prediction_file (string): JSON file holding a dict with the keys 'answer' and 'sp', each keyed by example id.
        gold_file (string): JSON file holding a list of examples with the keys '_id', 'answer' and 'supporting_facts'.

    Returns:
        dict: The averaged metrics (the same dict that is printed). When the
        gold file holds no example, every metric stays 0.
    """
    with open(prediction_file) as f:
        prediction = json.load(f)
    with open(gold_file) as f:
        gold = json.load(f)

    metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
        'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
        'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
    for dp in gold:
        cur_id = dp['_id']
        can_eval_joint = True
        if cur_id not in prediction['answer']:
            print('missing answer {}'.format(cur_id))
            can_eval_joint = False
        else:
            em, prec, recall = update_answer(
                metrics, prediction['answer'][cur_id], dp['answer'])
        if cur_id not in prediction['sp']:
            print('missing sp fact {}'.format(cur_id))
            can_eval_joint = False
        else:
            sp_em, sp_prec, sp_recall = update_sp(
                metrics, prediction['sp'][cur_id], dp['supporting_facts'])

        if can_eval_joint:
            # Joint scores multiply the answer score with the supporting-fact score.
            joint_prec = prec * sp_prec
            joint_recall = recall * sp_recall
            if joint_prec + joint_recall > 0:
                joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
            else:
                joint_f1 = 0.
            joint_em = em * sp_em

            metrics['joint_em'] += joint_em
            metrics['joint_f1'] += joint_f1
            metrics['joint_prec'] += joint_prec
            metrics['joint_recall'] += joint_recall

    N = len(gold)
    # An empty gold file would otherwise divide by zero; keep the zero totals instead.
    if N > 0:
        for k in metrics:
            metrics[k] /= N

    print(metrics)
    return metrics

if __name__ == '__main__':
    # Score the stored fullwiki dev predictions against the merged gold file.
    eval(
        prediction_file="hotpot_dev_fullwiki_v1_pred_2020730.json",
        gold_file="hotpot_dev_fullwiki_v1_merge.json",
    )

{'em': 0.2922659862777795, 'f1': 0.3670088365378471, 'prec': 0.48387104854078733, 'recall': 0.3963252165799103, 'sp_em': 0.19636927285200542, 'sp_f1': 0.5167430116764522, 'sp_prec': 0.5502988632469649, 'sp_recall': 0.5692590852832177, 'joint_em': 0.10283703478544845, 'joint_f1': 0.27465057113200253, 'joint_prec': 0.38052119985128563, 'joint_recall': 0.290684510066337}
