In [1]:
import numpy as np
import json
from collections import Counter
from tqdm import tqdm
import nltk
import re
import string
from nltk.corpus import stopwords
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import os
import json

def word_tokenize(tokens):
    """Tokenize text with NLTK and normalise its quote tokens.

    ``tokens`` is the raw text handed to ``nltk.word_tokenize``. In every
    token that comes back, each occurrence of two single quotes or two
    backticks is replaced by one double-quote character.

    Returns:
        list of token strings.
    """
    normalised = []
    for raw_token in nltk.word_tokenize(tokens):
        without_closing = raw_token.replace("''", '"')
        normalised.append(without_closing.replace("``", '"'))
    return normalised


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace.

    The steps are applied in this order: lowercase, strip every character in
    ``string.punctuation``, replace the standalone words a/an/the with a
    space, then collapse runs of whitespace.

    Args:
        s: answer text to normalise.

    Returns:
        The normalised string (possibly empty).
    """
    # Built once per call instead of once per helper invocation.
    punctuation = set(string.punctuation)

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        return ''.join(ch for ch in text if ch not in punctuation)

    def lower(text):
        return text.lower()

    # Lowercasing must come first: the article pattern is case-sensitive, so
    # "The" would otherwise survive.
    return white_space_fix(remove_articles(remove_punc(lower(s))))




def get_2d_spans(text, tokenss):
    """Locate every token of a tokenised text as a character span.

    Tokens are searched left to right: each search starts where the previous
    match ended, so repeated tokens map to successive occurrences.

    Args:
        text: the original string the tokens were produced from.
        tokenss: list of sentences, each a list of token strings.

    Returns:
        A list parallel to ``tokenss``; each inner list holds one
        ``(start, end)`` pair per token, with ``end`` exclusive.

    Raises:
        ValueError: a token cannot be found in ``text`` at or after the
            current search offset.
    """
    spanss = []
    cur_idx = 0
    for tokens in tokenss:
        spans = []
        for token in tokens:
            # Search once and reuse the offset for both the check and the span.
            start = text.find(token, cur_idx)
            if start < 0:
                raise ValueError(
                    "token {!r} not found at or after offset {} "
                    "(sentence tokens: {}; text: {!r})".format(
                        token, cur_idx, tokens, text))
            end = start + len(token)
            spans.append((start, end))
            cur_idx = end
        spanss.append(spans)
    return spanss


def get_word_span(context, wordss, start, stop):
    """Map a character range of ``context`` onto token positions.

    Token character spans come from ``get_2d_spans(context, wordss)``. A
    token is selected when its span overlaps the half-open range
    ``[start, stop)``.

    Returns:
        ``(first, end)`` where ``first`` is the ``(sentence, word)`` index of
        the first overlapping token, and ``end`` is the index of the last
        overlapping token with its word index advanced by one.

    Fails an assertion when no token overlaps the range.
    """
    spanss = get_2d_spans(context, wordss)
    idxs = [
        (sent_idx, word_idx)
        for sent_idx, spans in enumerate(spanss)
        for word_idx, span in enumerate(spans)
        if stop > span[0] and start < span[1]
    ]
    assert len(idxs) > 0, "{} {} {} {}".format(context, spanss, start, stop)
    first_hit = idxs[0]
    last_hit = idxs[-1]
    return first_hit, (last_hit[0], last_hit[1] + 1)
def get_word_idx(context, wordss, idx):
    """Return the character offset in ``context`` where a token starts.

    ``idx`` is a ``(sentence_index, word_index)`` pair into ``wordss``; the
    offset is the first element of that token's span as computed by
    ``get_2d_spans``.
    """
    all_spans = get_2d_spans(context, wordss)
    sentence_spans = all_spans[idx[0]]
    token_span = sentence_spans[idx[1]]
    return token_span[0]



def process_tokens1(tokens):
    """Drop tokens that are a single ASCII punctuation character.

    Args:
        tokens: iterable of token strings.

    Returns:
        A new list with every token that is a member of
        ``string.punctuation`` removed; multi-character tokens such as
        ``"..."`` are kept.
    """
    # Build the lookup set once rather than once per token.
    punctuation = set(string.punctuation)
    return [w for w in tokens if w not in punctuation]



def process_tokens(temp_tokens):
    """Split tokens further on dash, slash, tilde, quote and degree characters.

    Each token is split with a capturing regex, so the delimiter characters
    themselves are kept as separate tokens. A token that starts or ends with
    a delimiter also yields an empty string at that position (standard
    ``re.split`` behaviour); those are left in the result.

    Args:
        temp_tokens: iterable of token strings.

    Returns:
        A flat list with the pieces of every input token, in order.
    """
    # "-" stays first so it is a literal inside the character class.
    # \u2013 is the en-dash (number-to-number ranges), \u2212 the minus sign,
    # \u2014 the em-dash.
    delimiters = ("-", "\u2212", "\u2014", "\u2013", "/", "~", '"', "'",
                  "\u201C", "\u2019", "\u201D", "\u2018", "\u00B0")
    # Compile once per call instead of rebuilding the pattern per token.
    splitter = re.compile("([{}])".format("".join(delimiters)))

    tokens = []
    for token in temp_tokens:
        tokens.extend(splitter.split(token))
    return tokens

def generatedateset(sentence_list, question_list):
    """Expand each question into one labelled example per paragraph sentence.

    Args:
        sentence_list: nested list indexed as
            ``sentence_list[article][paragraph][sentence]``.
        question_list: one list per question. Only the first entry of each
            list is read, laid out as ``[article_idx, paragraph_idx,
            sentence_idx, question, label, question_id]``. Empty lists are
            skipped.

    Returns:
        A list of ``[question, sentence, label, question_id]`` rows. The row
        for ``sentence_idx`` carries the entry's own label; every other
        sentence of the same paragraph gets label ``0``.
    """
    res = []
    for q in question_list:
        if len(q) == 0:
            continue
        first = q[0]
        ai, pi, si = first[0], first[1], first[2]
        question, label, q_id = first[3], first[4], first[5]

        for i, sentence in enumerate(sentence_list[ai][pi]):
            res.append([question, sentence, label if i == si else 0, q_id])
    return res
def create_index_dict(embed, word_counter):
    """Build word <-> index lookup tables.

    Index 0 is reserved for ``'<UNK>'``. Every key of ``embed`` is numbered
    next, in iteration order and with its original casing. After that, every
    key of ``word_counter`` is lowercased and given a fresh index if it is
    not already registered.

    Args:
        embed: mapping whose keys are the words that have an embedding.
        word_counter: mapping whose keys are the corpus words.

    Returns:
        ``(word2idx, idx2word, vocab)`` where ``vocab`` is the set of
        ``word2idx`` keys. Being a set, ``vocab`` has no index order.
    """
    word2idx = {'<UNK>': 0}
    idx2word = {0: '<UNK>'}
    next_idx = 1

    for key in embed:
        word2idx[key] = next_idx
        idx2word[next_idx] = key
        next_idx += 1

    for key in word_counter:
        key = key.lower()
        # Embedding words were all registered above, so one membership test
        # covers both "has an embedding" and "already numbered".
        if key in word2idx:
            continue
        word2idx[key] = next_idx
        idx2word[next_idx] = key
        next_idx += 1

    return word2idx, idx2word, set(word2idx.keys())

def process(data):
    """Split example rows into four parallel lists.

    Each row of ``data`` is read positionally: item 0 is the question
    tokens, item 1 the sentence tokens, item 2 the label and item 3 the
    question id. Falsy tokens (such as empty strings) are removed from the
    question and the sentence.

    Returns:
        ``(questions, sentences, labels, question_ids)``.
    """
    questions, sentences, labels, question_ids = [], [], [], []
    for row in data:
        questions.append([tok for tok in row[0] if tok])
        sentences.append([tok for tok in row[1] if tok])
        labels.append(row[2])
        question_ids.append(row[3])
    return questions, sentences, labels, question_ids

def read_embedding(filename, dim=None):
    """Load whitespace-separated word vectors from a UTF-8 text file.

    A usable line looks like ``word v1 v2 ... vN``. Lines whose field count
    differs from ``dim + 1``, or whose values are not all parseable as
    floats, are skipped.

    NOTE(review): this file defines ``read_embedding`` a second time further
    down; when the file runs top to bottom the later definition is the one
    in effect.

    Args:
        filename: path of the embedding file.
        dim: number of vector components per word. When ``None`` (default)
            it is taken from the first line that has at least two fields, so
            files of any dimensionality load without a hard-coded width.

    Returns:
        dict mapping each word to a 1-D ``numpy`` float array.
    """
    embed = {}
    expected_fields = None if dim is None else dim + 1

    # The context manager closes the file even if parsing raises.
    with open(filename, encoding='utf8') as fp:
        for raw_line in fp:
            fields = raw_line.strip().split()
            if len(fields) < 2:
                continue
            if expected_fields is None:
                expected_fields = len(fields)
            if len(fields) != expected_fields:
                continue
            try:
                embed[fields[0]] = np.array(list(map(float, fields[1:])))
            except ValueError:
                # A component is not a number: skip this line only.
                continue

    print('[%s]\n\tEmbedding size: %d' % (filename, len(embed)), end='\n')
    return embed

def load_data(data_type):
    """Tokenise a SQuAD-style JSON file and cache the result on disk.

    Reads ``<data_type>.json``, tokenises every paragraph into sentences and
    words, locates the sentence that contains each answer, and writes
    ``<data_type>_saved.json`` with three keys:

    * ``word_counter``: token -> count over all paragraph and question tokens.
    * ``sentence_list``: tokens indexed as ``[article][paragraph][sentence]``.
    * ``answer_index``: one list per question, holding one
      ``[article_idx, paragraph_idx, sentence_idx, question_tokens, label,
      question_id]`` entry per answer. ``label`` is 1 for answerable
      questions and 0 for impossible ones.

    Relies on the module-level name ``sent_tokenize`` being assigned before
    this function is called.

    Args:
        data_type: file name stem, e.g. ``'training'``.

    Returns:
        None; the output is the JSON file.
    """
    source_path = "{}.json".format(data_type)
    # The context manager closes the input file as soon as it is parsed.
    with open(source_path, 'r') as fp:
        source_data = json.load(fp)

    sentence_list = []
    answer_index = []
    word_counter = Counter()

    for ai, article in enumerate(tqdm(source_data['data'])):
        article_sentences = []
        for pi, para in enumerate(article['paragraphs']):
            context = para['context']
            # Both replacements keep the string length, so the character
            # offsets stored in answer_start stay valid.
            context = context.replace("''", '" ')
            context = context.replace("``", '" ')

            si = list(map(word_tokenize, sent_tokenize(context)))
            si = [process_tokens(tokens) for tokens in si]
            for sentence in si:
                word_counter.update(sentence)
            article_sentences.append(si)

            for qa in para['qas']:
                qi = word_tokenize(qa['question'])
                # Data without the is_impossible flag is treated as answerable.
                labeli = 0 if qa.get('is_impossible', False) else 1
                q_id = qa['id']
                ans = qa['answers'] if labeli == 1 else qa['plausible_answers']

                yi = []
                for answer in ans:
                    answer_start = answer['answer_start']
                    answer_stop = answer_start + len(answer['text'])
                    # Only the sentence index of the answer's first token is kept.
                    yi0, _ = get_word_span(context, si, answer_start, answer_stop)
                    yi.append([ai, pi, yi0[0], qi, labeli, q_id])

                word_counter.update(qi)
                answer_index.append(yi)

        sentence_list.append(article_sentences)

    a = {'word_counter': word_counter, 'sentence_list': sentence_list, 'answer_index': answer_index}
    with open("{}_saved.json".format(data_type), "w") as fp:
        json.dump(a, fp)
    print("saved in json")
    return
def load_from_json(data_type):
    """Read ``<data_type>_saved.json`` and return its three cached entries.

    Returns:
        ``(word_counter, sentence_list, answer_index)`` exactly as decoded
        from JSON, so ``word_counter`` is a plain dict.
    """
    saved_path = "{}_saved.json".format(data_type)
    with open(saved_path, "r") as fp:
        saved = json.load(fp)
    word_counter = saved['word_counter']
    sentence_list = saved['sentence_list']
    answer_index = saved['answer_index']
    return word_counter, sentence_list, answer_index
def read_embedding(filename, dim=None):
    """Load whitespace-separated word vectors from a UTF-8 text file.

    A usable line looks like ``word v1 v2 ... vN``. Lines whose field count
    differs from ``dim + 1``, or whose values are not all parseable as
    floats, are skipped.

    NOTE(review): an earlier definition of ``read_embedding`` exists in this
    file; this later one replaces it when the file runs top to bottom.

    Args:
        filename: path of the embedding file.
        dim: number of vector components per word. When ``None`` (default)
            it is taken from the first line that has at least two fields, so
            files of any dimensionality load without a hard-coded width.

    Returns:
        dict mapping each word to a 1-D ``numpy`` float array.
    """
    embed = {}
    expected_fields = None if dim is None else dim + 1

    # The context manager closes the file even if parsing raises.
    with open(filename, encoding='utf8') as fp:
        for raw_line in fp:
            fields = raw_line.strip().split()
            if len(fields) < 2:
                continue
            if expected_fields is None:
                expected_fields = len(fields)
            if len(fields) != expected_fields:
                continue
            try:
                embed[fields[0]] = np.array(list(map(float, fields[1:])))
            except ValueError:
                # A component is not a number: skip this line only.
                continue

    print('[%s]\n\tEmbedding size: %d' % (filename, len(embed)), end='\n')
    return embed

def make_output(output):
    """Wrap a label as a one-element tensor holding 0 or 1.

    A value equal to 0 maps to ``tensor([0])``; any other value maps to
    ``tensor([1])``.
    """
    target_class = 0 if output == 0 else 1
    return torch.tensor([target_class])
def create_emb_layer(weights_matrix, num_embeddings, trainable=False):
    """Create an ``nn.Embedding`` initialised from a weight matrix.

    Args:
        weights_matrix: 2-D tensor; its second dimension is used as the
            embedding width and its values are copied into the layer.
        num_embeddings: number of rows of the embedding table.
        trainable: whether the embedding weights require gradients.

    Returns:
        ``(embedding_layer, embedding_dim)``.
    """
    (_, embedding_dim) = weights_matrix.size()
    emb_layer = nn.Embedding(num_embeddings, embedding_dim)

    weight = emb_layer.weight
    weight.requires_grad = True if trainable else False
    # Copy through .data so the in-place write is not tracked by autograd.
    weight.data.copy_(weights_matrix)
    return emb_layer, embedding_dim


def evaluate_data(preds, correct_mapping):
    """Score binary predictions against gold labels.

    Walks ``correct_mapping`` (id -> gold label), looks the same id up in
    ``preds`` and tallies the confusion counts for 0/1 pairs; any other
    value combination is ignored. The four counts are printed in the order
    TP, TN, FP, FN.

    Returns:
        ``(precision, recall, f1)`` as computed by the ``calculate_*``
        helpers.
    """
    true_positives = true_negatives = 0
    false_positives = false_negatives = 0

    for q_id, gold in correct_mapping.items():
        if gold == 0:
            guess = preds[q_id]
            if guess == 0:
                true_negatives += 1
            elif guess == 1:
                false_positives += 1
        elif gold == 1:
            guess = preds[q_id]
            if guess == 1:
                true_positives += 1
            elif guess == 0:
                false_negatives += 1

    print( true_positives,true_negatives,false_positives,false_negatives)
    precision = calculate_precision(true_positives, false_positives)
    recall = calculate_recall(true_positives, false_negatives)
    return precision, recall, calculate_f1(precision, recall)

def get_acc(preds, gold_labels):
    """Return the fraction of predictions equal to their gold label.

    Args:
        preds: sequence of predicted labels.
        gold_labels: sequence of gold labels, indexed in parallel with
            ``preds`` and at least as long.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` when ``preds`` is empty, so an empty
        evaluation set does not raise ``ZeroDivisionError``.
    """
    total = len(preds)
    if total == 0:
        return 0.0
    correct = sum(1 for i in range(total) if preds[i] == gold_labels[i])
    return correct / total


def calculate_precision(true_positives, false_positives):
    """Return TP / (TP + FP), or ``0.0`` when nothing was predicted positive."""
    predicted_positives = true_positives + false_positives
    if predicted_positives == 0:
        # Undefined ratio: report 0.0 rather than raising ZeroDivisionError.
        return 0.0
    return true_positives / predicted_positives


def calculate_recall(true_positives, false_negatives):
    """Return TP / (TP + FN), or ``0.0`` when there are no gold positives."""
    actual_positives = true_positives + false_negatives
    if actual_positives == 0:
        # Undefined ratio: report 0.0 rather than raising ZeroDivisionError.
        return 0.0
    return true_positives / actual_positives


def calculate_f1(precision, recall):
    """Return the harmonic mean of precision and recall.

    Returns ``0.0`` when both inputs are zero instead of raising
    ``ZeroDivisionError``.
    """
    denominator = precision + recall
    if denominator == 0:
        return 0.0
    return (2 * precision * recall) / denominator


def get_we_slist(slist, word2idx):
    """Convert tokenised sentences to lists of vocabulary indices.

    Every word is lowercased before the lookup; a word missing from
    ``word2idx`` is mapped to the index stored under ``'<UNK>'``.

    Args:
        slist: list of sentences, each a list of token strings.
        word2idx: mapping from lowercased word to integer index. It must
            contain ``'<UNK>'`` if any word can be unknown.

    Returns:
        A list parallel to ``slist`` holding integer indices.
    """
    res_slist = []
    for sent in slist:
        indices = []
        for word in sent:
            # Catch only the missing-word case so other errors stay visible.
            try:
                indices.append(word2idx[word.lower()])
            except KeyError:
                indices.append(word2idx['<UNK>'])
        res_slist.append(indices)
    return res_slist
 
def create_prediction_file(preds, q_ids, fname):
    """Write an id -> prediction mapping to ``fname`` as JSON.

    ``preds`` and ``q_ids`` are paired by position. When their lengths
    differ, a message is printed and nothing is written. If an id repeats,
    the prediction at the last occurrence is the one stored.
    """
    if len(preds) != len(q_ids):
        print("Error in pred len")
        return

    result = dict(zip(q_ids, preds))
    with open(fname, 'w') as outfile:
        json.dump(result, outfile)
    
def get_prediction_dict(preds, q_ids):
    """Pair ids with predictions by position and return them as a dict.

    When the two sequences differ in length, a message is printed and
    ``None`` is returned. If an id repeats, the prediction at the last
    occurrence is kept.
    """
    if len(preds) != len(q_ids):
        print("Error in pred len")
        return

    return {q_id: pred for q_id, pred in zip(q_ids, preds)}
#create_embedding_matrix
def create_embeddingmatrix(vocab, embeddings=None, dim=None):
    """Build the initial embedding weight matrix for a vocabulary.

    Rows start from a Xavier-normal random initialisation. Every word that
    has a pretrained vector gets that vector copied into its row.

    Row ``i`` belongs to the ``i``-th element produced by iterating
    ``vocab``, so callers must pass ``vocab`` in the same order as the
    indices they later feed to the embedding layer. A ``set`` does not give
    that guarantee.

    Args:
        vocab: sized iterable of words.
        embeddings: mapping word -> vector. Defaults to the module-level
            ``embed`` table.
        dim: number of columns. When ``None`` it is taken from the length of
            the first vector in ``embeddings``; if there are no vectors, the
            previous fixed width of 301 is used.

    Returns:
        A float tensor of shape ``(len(vocab), dim)``.

    Raises:
        RuntimeError: a pretrained vector's length does not match ``dim``
            (raised by the tensor assignment). This used to be swallowed,
            which silently left every row random.
    """
    if embeddings is None:
        embeddings = embed  # table loaded elsewhere in this file
    if dim is None:
        dim = 301
        for vector in embeddings.values():
            dim = len(vector)
            break

    initW = torch.nn.init.xavier_normal_(torch.randn([len(vocab), dim]))
    for i, v in enumerate(vocab):
        vector = embeddings.get(v)
        if vector is None:
            # No pretrained vector for this word: keep the random row.
            continue
        initW[i] = torch.as_tensor(vector, dtype=initW.dtype)
    return initW



In [None]:
# Load the pretrained word vectors.
embed_file="glove.6B.300d.txt"
embed=read_embedding(embed_file)
print("finish loading embedding file")

split=True
nltk.data.path.append("/afs/cs.stanford.edu/u/tianzhao/tz/nltk_data")
sent_tokenize = nltk.sent_tokenize

# Load the cached tokenised data written by load_data().
data_type='training'
word_counter,sentence_list,answer_index=load_from_json(data_type)
data_type='development'
d_word_counter,d_sentence_list,d_answer_index=load_from_json(data_type)

# Training set: one (question, sentence, label) example per paragraph sentence.
data_type='training'
X=generatedateset(sentence_list,answer_index)
word2idx,idx2word,vocab=create_index_dict(embed,word_counter)
q_trn,s_trn,y_trn, q_id=process(X)
# Direct lookups are safe here: word2idx was built from the training counter.
q_trn_i=[ [word2idx[j.lower()] for j in i] for i in q_trn]
s_trn_i=[ [word2idx[j.lower()] for j in i] for i in s_trn]
vocab_size=len(vocab)
print(vocab_size)
# `vocab` is a set, so iterating it does not follow the word2idx numbering.
# Pass the words in index order so row i of the matrix belongs to word i.
embedding_matrix=create_embeddingmatrix([idx2word[i] for i in range(vocab_size)])



# Development set: unseen words fall back to <UNK> inside get_we_slist.
X_Tst=generatedateset(d_sentence_list,d_answer_index)
q_tst,s_tst,y_tst, q_id_tst=process(X_Tst)
q_tst_i=get_we_slist(q_tst, word2idx)
s_tst_i=get_we_slist(s_tst, word2idx)


# NOTE(review): question ids repeat once per sentence, so these dicts keep
# only the label of the last example seen for each id.
y_trn_d = get_prediction_dict(y_trn, q_id)
y_tst_d = get_prediction_dict(y_tst, q_id_tst)

In [None]:

class model(nn.Module):
    """Two-encoder LSTM classifier over a (question, sentence) pair.

    The question and the sentence are embedded with a shared embedding
    layer, encoded by two separate LSTMs, and the last output step of each
    encoder is concatenated and projected to two class scores.
    """

    def __init__(self,weights_matrix,vocab_size,hidden_dim=128, trainable=False):
        """Build the layers.

        Args:
            weights_matrix: initial embedding weights, passed to
                ``create_emb_layer``.
            vocab_size: number of rows of the embedding table.
            hidden_dim: hidden size of both LSTM encoders.
            trainable: whether the embedding weights are updated in training.
        """
        super(model, self).__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size=vocab_size
        self.trainable=trainable
        self.embeds,self.embedding_dim = create_emb_layer(weights_matrix, self.vocab_size,trainable=self.trainable)
        self.encoder1 = nn.LSTM( self.embedding_dim,  self.hidden_dim)
        self.encoder2 = nn.LSTM( self.embedding_dim,  self.hidden_dim)

        self.loss = nn.CrossEntropyLoss()
        self.out1 = nn.Linear(2*self.hidden_dim,2)
        # out2, out3 and softmax are not used by forward(); they are kept so
        # the parameter set and state_dict keys stay unchanged.
        self.out2 = nn.Linear(1024,2048)

        self.out3 = nn.Linear(2048,2)

        self.softmax=nn.Softmax(dim=0)


    def compute_Loss(self, pred_vec, gold_seq):
        """Return the cross-entropy loss of ``pred_vec`` against ``gold_seq``."""
        return self.loss(pred_vec, gold_seq)

    def forward(self,input_question, input_sen):
        """Score one question/sentence pair.

        Args:
            input_question: word indices of the question, as a list of ints
                or an integer tensor.
            input_sen: word indices of the sentence, same forms.

        Returns:
            ``(scores, predicted_class)``: the two class scores after the
            final squeeze, and the index of the larger score as a Python int.
        """
        # as_tensor reuses an existing tensor instead of copy-constructing
        # it, and places list inputs on the same device as the embeddings.
        device = self.embeds.weight.device
        question_ids = torch.as_tensor(input_question, dtype=torch.long, device=device)
        sen_ids = torch.as_tensor(input_sen, dtype=torch.long, device=device)

        question_vectors = self.embeds(question_ids)
        sen_vectors = self.embeds(sen_ids)

        # Reshape to (sequence_length, batch=1, embedding_dim) for the LSTMs.
        encoder_outputs1, _ = self.encoder1(question_vectors.view(len(question_vectors),1,-1))
        encoder_outputs2, _ = self.encoder2(sen_vectors.view(len(sen_vectors),1,-1))

        # Last output step of each encoder, joined along the feature axis.
        combined=torch.cat((encoder_outputs1[-1], encoder_outputs2[-1]), 1)

        prediction = self.out1(combined)
        # NOTE(review): ReLU is applied to the final class scores before the
        # cross-entropy loss; kept as is to preserve model behaviour.
        prediction=torch.nn.functional.relu(prediction)

        prediction = prediction.squeeze()
        _, idx = torch.max(prediction, 0)

        return prediction, idx.item()


In [None]:
#main
trainable=True

m = model(weights_matrix=embedding_matrix, vocab_size=vocab_size,hidden_dim = 128,trainable = trainable)

use_gpu = torch.cuda.is_available()
if use_gpu:
    m = m.cuda()
    print ('USE GPU')
else:
    print ('USE CPU')

# Build the optimizer after the model is on its final device, as the
# torch.optim documentation recommends.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, m.parameters()), lr=0.01)

minibatch_size = 2

# Examples beyond the last full minibatch are not used for training.
num_minibatches = len(q_trn_i) // minibatch_size

for epoch in (range(3)):
    # Training
    print("Training")
    # Put the model in training mode. The device was chosen above, so there
    # is no unconditional m.cuda() here; it crashed on CPU-only hosts.
    m.train()
    start_train = time.time()
    for group in tqdm(range(num_minibatches)):
        predictions = []
        gold_outputs = []
        optimizer.zero_grad()
        for i in range(group * minibatch_size, (group + 1) * minibatch_size):
            question_vectors = q_trn_i[i]
            sen_vectors = s_trn_i[i]
            gold_output = make_output(y_trn[i])

            if use_gpu:
                question_vectors  = torch.cuda.LongTensor(question_vectors)
                sen_vectors = torch.cuda.LongTensor(sen_vectors)
                gold_output = gold_output.cuda()
            prediction_vec, prediction = m(question_vectors,sen_vectors)

            predictions.append(prediction_vec)
            gold_outputs.append(gold_output)

        if(minibatch_size>1):
            # Stack the per-example scores into one (batch, 2) tensor and the
            # one-element targets into a (batch,) tensor.
            loss = m.compute_Loss(torch.stack(predictions), torch.stack(gold_outputs).squeeze())
        else:
            loss = m.compute_Loss(prediction_vec.view(1,-1), gold_output)

        loss.backward()
        optimizer.step()
    print("Training time: {} for epoch {}".format(time.time() - start_train, epoch))


    print("Evaluation")
    # Put the model in evaluation mode
    m.eval()
    start_eval = time.time()

    # Inference phase: gradients are not needed, so skip autograd tracking.
    predictions = []
    with torch.no_grad():
        for q_vec, s_vec  in zip(q_tst_i, s_tst_i):
            if use_gpu:
                q_vec  = torch.cuda.LongTensor(q_vec)
                s_vec = torch.cuda.LongTensor(s_vec)

            _, predicted_output = m(q_vec,s_vec)
            predictions.append(predicted_output)

    accu=get_acc(predictions,y_tst)
    print("Evaluation time: {} for epoch {}, acc ：{},".format(time.time() - start_eval, epoch,accu))

print("Done training")

PATH="model_1_%s.pth"%(str(trainable))
torch.save(m.state_dict(), PATH)
print("save model completed")