In [None]:
import torch
import json
import pickle
import numpy as np
from pytorch_pretrained_bert import BertTokenizer

In [None]:
# NOTE(review): pickle.load can execute arbitrary code — only load a trusted file.
with open("../fever_processed.pickle", "rb") as pickle_file:
    processed_data = pickle.load(pickle_file)

# Hold out the last ceil(len / 10) rows for validation and keep the rest for training.
# The boundary is computed once so the two slices are exactly complementary.
n_test = (len(processed_data) + 9) // 10
n_train = len(processed_data) - n_test
testing_data = np.asarray(processed_data[n_train:])
training_data = np.asarray(processed_data[:n_train])

In [None]:
# Vocabulary of the pretrained 'bert-base-uncased' checkpoint; getBatch uses it to map
# the already-tokenized claim/evidence tokens to ids.
BERT_VOCAB = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BERT_VOCAB)

In [None]:
def getSamples(data, positive_label="SUPPORTS", negative_label="REFUTES"):
    """Return the row indices of positive and negative examples in ``data``.

    Args:
        data: Iterable of example dicts, each with a ``"label"`` key.
        positive_label: Label treated as the positive class (default ``"SUPPORTS"``).
        negative_label: Label treated as the negative class (default ``"REFUTES"``).
            Rows carrying any other label appear in neither output.

    Returns:
        ``(positive_samples, negative_samples)``: int64 arrays of indices into
        ``data``. They are integer-typed even when empty, so they can always be
        used for fancy indexing.

    Side effect: prints the two counts.
    """
    labels = [line["label"] for line in data]
    # Explicit int64 dtype: an empty list would otherwise become a float64 array,
    # which numpy refuses as an index array.
    positive_samples = np.array(
        [i for i, label in enumerate(labels) if label == positive_label], dtype=np.int64
    )
    negative_samples = np.array(
        [i for i, label in enumerate(labels) if label == negative_label], dtype=np.int64
    )
    print(len(positive_samples), len(negative_samples))
    return positive_samples, negative_samples

In [None]:
# SUPPORTS / REFUTES row indices for each split (getSamples also prints the counts);
# the training split is processed first.
(tr_p, tr_n), (te_p, te_n) = getSamples(training_data), getSamples(testing_data)

In [None]:
# max_length: number of token positions every example is padded/truncated to in getBatch.
# max_claim_length: number of leading positions FaVer.forward keeps as the claim slice.
max_length, max_claim_length = 300, 15

In [None]:
def getBatch(bs = 64, validation = False, device = "cuda"):
    """Sample a class-balanced batch of BERT inputs from the train or test split.

    ``bs // 2`` examples are drawn from the negative (REFUTES) pool and the rest
    from the positive (SUPPORTS) pool, uniformly *with replacement*. Each example is
    encoded as ``[CLS] claim [SEP] evidence... [SEP]`` and then padded with
    ``[PAD]`` or truncated to ``max_length`` tokens. Truncation always keeps the
    closing ``[SEP]``.

    Args:
        bs: Number of examples in the batch.
        validation: If True sample from ``testing_data`` (``te_p``/``te_n``),
            otherwise from ``training_data`` (``tr_p``/``tr_n``).
        device: Device the returned tensors are moved to (default ``"cuda"``).

    Returns:
        ``(text, segments, att, classes)``. The first three are int64 tensors of
        shape ``(bs, max_length)``: token ids, segment ids (0 for ``[CLS]`` + claim
        + first ``[SEP]``, 1 elsewhere), and attention mask (1 for non-``[PAD]``).
        ``classes`` is a float32 tensor of shape ``(bs,)``: 1.0 for SUPPORTS, else
        0.0. Positives come first, then negatives.

    Raises:
        ValueError: If examples of a class are requested but its pool is empty.
    """
    if validation:
        source, positive_samples, negative_samples = testing_data, te_p, te_n
    else:
        source, positive_samples, negative_samples = training_data, tr_p, tr_n

    def _sample(pool, count, label):
        # Draw `count` row indices from `pool` uniformly with replacement.
        if count <= 0:
            return np.zeros((0,), dtype=np.int64)
        if len(pool) == 0:
            raise ValueError("getBatch: no " + label + " examples to sample from")
        return np.asarray(pool)[np.random.randint(0, len(pool), (count,))]

    n_samples = bs // 2
    p_samples = bs - n_samples
    all_indices = np.concatenate([
        _sample(positive_samples, p_samples, "SUPPORTS"),
        _sample(negative_samples, n_samples, "REFUTES"),
    ])

    n_rows = len(all_indices)
    text = np.zeros((n_rows, max_length), dtype=np.int64)
    segments = np.ones((n_rows, max_length), dtype=np.int64)
    classes = np.zeros((n_rows,), dtype=np.float32)

    for row, index in enumerate(all_indices):
        example = source[index]
        claim = example["processed"]["claim"]
        tokens = ["[CLS]"]
        tokens.extend(claim)
        tokens.append("[SEP]")
        for evid in example["processed"]["evidentiary"]:
            tokens.extend(evid)
        # Truncate *before* appending the closing [SEP] so that over-long examples
        # still end with it.
        tokens = tokens[:max_length - 1]
        tokens.append("[SEP]")
        tokens.extend(["[PAD]"] * (max_length - len(tokens)))
        text[row] = tokenizer.convert_tokens_to_ids(tokens)
        # Segment 0 covers [CLS] + claim + first [SEP]; evidence and padding stay 1.
        segments[row, :len(claim) + 2] = 0
        classes[row] = 1.0 if example["label"] == "SUPPORTS" else 0.0

    # Attend to every position whose id is > 0 (assumes the [PAD] id is 0).
    att = (text > 0).astype(np.int64)

    return (
        torch.from_numpy(text).to(device),
        torch.from_numpy(segments).to(device),
        torch.from_numpy(att).to(device),
        torch.from_numpy(classes).to(device),
    )
# Smoke test: build a 5-example training batch and show each tensor's shape.
t, s, a, c = getBatch(bs=5, validation=False)
print(*(tensor.size() for tensor in (t, s, a, c)))

In [None]:
import torch
import torch.nn.functional as F
from pytorch_pretrained_bert import BertModel
from QA_Attentions import *

In [None]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
class FaVer(torch.nn.Module):
    """Claim-verification model: BERT encoder + biDAF attention + single-logit head.

    BERT's last-layer hidden states are split by segment id into a claim part
    (segment 0, first ``max_claim_length`` positions) and an evidence part
    (segment 1, non-padding). ``biDAF`` combines them and ``InnerAttention``
    processes the result; both come from ``QA_Attentions``. A linear layer then
    produces the logit(s).
    """

    def __init__(self, bert_model = "bert-base-uncased"):
        """Load the pretrained encoder and create the attention / output parameters.

        Args:
            bert_model: Name or path accepted by ``BertModel.from_pretrained``.
        """
        super(FaVer, self).__init__()
        self.bert_model = bert_model
        self.bert = BertModel.from_pretrained(bert_model)
        # Take the hidden size from the loaded model's config instead of guessing it
        # from the checkpoint name (a local path or renamed checkpoint would get the
        # wrong width). The name heuristic is kept only as a fallback.
        fallback_width = 1024 if "-large-" in self.bert_model else 768
        self.bert_width = getattr(getattr(self.bert, "config", None), "hidden_size", fallback_width)
        # Weight vector of length 3 * width passed to biDAF (its exact use is defined
        # in QA_Attentions — presumably a trilinear similarity; TODO confirm).
        self.wd = torch.nn.Parameter(torch.FloatTensor(np.random.uniform(0, 1, (3*self.bert_width,))))
        # Parameter of shape (4 * width, 512) passed to InnerAttention.
        self.innerAttDoc = torch.nn.Parameter(torch.FloatTensor(np.random.uniform(0, 1, (self.bert_width*4, 512))))
        self.out = torch.nn.Linear((self.bert_width*4),1)
        self.dropout = torch.nn.Dropout(0.1)

    def forward(self, t, s, a):
        """Score a batch of packed claim/evidence sequences.

        Args:
            t: Token ids.
            s: Segment ids (as built by getBatch: 0 for [CLS] + claim + [SEP],
                1 for evidence and padding).
            a: Attention mask (1 = real token, 0 = padding).

        Returns:
            Output of ``self.out`` — presumably one logit per example with shape
            ``(batch, 1)`` (TODO confirm InnerAttention's output shape).
        """
        # Last encoder layer only; the pooled [CLS] output is not used.
        text, pooled = self.bert(t,
                        token_type_ids=s,
                        attention_mask=a,
                        output_all_encoded_layers=False)

        text = self.dropout(text)
        cl_ = s == 0
        ev_ = s == 1
        # Zero out non-claim positions, then keep a fixed-width claim window.
        claims = text * cl_.unsqueeze(-1).float()
        claims = claims[:,:max_claim_length, :]
        # Zero out claim positions and padding, leaving only real evidence tokens.
        evidences = text * ev_.unsqueeze(-1).float()
        evidences = evidences * a.unsqueeze(-1).float()
        bdaf, ad2q, aq2d = biDAF(evidences, claims, self.wd)
        _f = self.out(InnerAttention(bdaf, self.innerAttDoc))
        return _f

In [None]:
# Binary cross-entropy applied to raw logits (the sigmoid happens inside the loss).
lossFn = torch.nn.BCEWithLogitsLoss()
def getLoss(pred, actual, lossFn, e_weight=0.6, ne_weight=0.4):
    """Compute ``lossFn`` between model logits and 0/1 targets.

    Args:
        pred: Logits with a trailing singleton dimension, e.g. ``(batch, 1)``.
        actual: Float targets with shape ``(batch,)``.
        lossFn: Loss module to apply, e.g. the module-level ``BCEWithLogitsLoss``.
        e_weight, ne_weight: Currently unused; accepted for call compatibility.

    Returns:
        The scalar loss tensor.
    """
    logits = pred.squeeze(-1)
    return lossFn(logits, actual)

In [None]:
# Adam learning rate.
lr = 3e-5
# Build the model on the GPU and optimise all of its parameters, including the
# BERT encoder registered as `network.bert`.
network = FaVer().cuda()
optimizer = torch.optim.Adam(params=network.parameters(), lr=lr)

In [None]:
# Forward-pass smoke test on a 5-example training batch, without tracking gradients.
with torch.no_grad():
    t, s, a, y = getBatch(5)
    y_ = network(t, s, a)

In [None]:
# Per-epoch metric histories appended to by Train and read by the checkpointing helpers:
# mean training loss, validation loss, overall validation accuracy, and accuracy on
# SUPPORTS examples only.
epoch_losses, epoch_vals, epoch_accs, epoch_evid = [], [], [], []

In [None]:
def _save(cause, network):
    """Checkpoint ``network`` and the metric history under the label ``cause``.

    Writes the state dict with ``torch.save`` to ``./FaVer_<cause>_save.h5`` (the
    ``.h5`` suffix does not indicate HDF5). Writes the global ``epoch_losses`` /
    ``epoch_vals`` / ``epoch_accs`` / ``epoch_evid`` lists as JSON to
    ``./<cause>_training_cycle.json``. Existing files are overwritten.
    """
    print("\tSaving Model for Cause:", cause)
    checkpoint_path = "./FaVer_" + cause + "_save.h5"
    torch.save(network.state_dict(), checkpoint_path)
    history_path = "./" + cause + "_training_cycle.json"
    with open(history_path, "w") as f:
        history = {
            "training_losses": epoch_losses,
            "validation_losses": epoch_vals,
            "validation_accuracy": epoch_accs,
            "evidence_accuracy": epoch_evid,
        }
        f.write(json.dumps(history))
    
def chooseModelSave(network):
    """Checkpoint ``network`` for every metric whose latest value is its best so far.

    Validation loss is checked first (lower is better), then overall accuracy and
    evidence accuracy (higher is better). ``_save`` is called once per criterion that
    matches, so one epoch can produce several checkpoints.
    """
    criteria = (
        ("BestValidationLoss", epoch_vals, np.min),
        ("BestValidationOverallAccuracy", epoch_accs, np.max),
        ("BestValidationEvidentiaryAccuracy", epoch_evid, np.max),
    )
    for cause, history, best_of in criteria:
        if best_of(history) == history[-1]:
            _save(cause, network)

In [None]:
def validate(network, bs=100, num_batches=5):
    """Estimate loss and accuracy on ``num_batches`` random validation batches.

    The network is switched to eval mode for the duration of the call, so dropout
    is disabled. It is restored to its previous mode afterwards, even if an
    exception is raised. Batches come from ``getBatch(..., validation=True)``, so
    examples are sampled with replacement and the result is stochastic.

    Args:
        network: Module mapping ``(text, segments, attention_mask)`` to logits with
            a trailing singleton dimension.
        bs: Examples per validation batch.
        num_batches: Number of batches to draw; must be at least 1.

    Returns:
        ``(loss, accuracy, positive_accuracy)``: the loss from ``getLoss`` over all
        sampled examples, the fraction of correct 0.5-threshold predictions, and
        the same fraction restricted to label-1 (SUPPORTS) examples. Each
        accuracy is ``nan`` if its denominator is zero.

    Raises:
        ValueError: If ``num_batches`` < 1.
    """
    if num_batches < 1:
        raise ValueError("validate: num_batches must be >= 1")

    was_training = network.training
    network.eval()
    try:
        with torch.no_grad():
            batch_classes, batch_logits = [], []
            for _ in range(num_batches):
                et, es, ea, classes_ = getBatch(bs=bs, validation=True)
                batch_logits.append(network(et, es, ea))
                batch_classes.append(classes_)
            classes = torch.cat(batch_classes, dim=0)
            preds = torch.cat(batch_logits, dim=0)

            f_loss = getLoss(preds, classes, lossFn)
            pred = torch.round(torch.sigmoid(preds)).squeeze(-1)
            correct = pred == classes
            total = classes.numel()
            acc = correct.sum().item() / total if total else float("nan")

            # Accuracy on positive (label 1) examples only.
            evidences = classes >= 1
            n_evidence = evidences.sum().item()
            evid_acc = correct[evidences].sum().item() / n_evidence if n_evidence else float("nan")

            return f_loss.item(), acc, evid_acc
    finally:
        network.train(was_training)

In [None]:
def Train(network, bs = 5, epochs = 10, batches_per_epoch = 20):
    """Train ``network`` with the notebook-level ``optimizer``, validating every epoch.

    Each epoch runs ``batches_per_epoch`` optimisation steps on balanced training
    batches of size ``bs`` (via ``getBatch`` / ``getLoss``). It then evaluates with
    ``validate(network, num_batches=10)``. The epoch's mean training loss and the
    validation metrics are appended to the global ``epoch_losses``, ``epoch_vals``,
    ``epoch_accs`` and ``epoch_evid`` lists. ``chooseModelSave`` then checkpoints
    any new best, and the full history is rewritten to
    ``./FaVer_training_cycle.json``.

    Args:
        network: The model being trained; it should own the parameters held by the
            global ``optimizer``.
        bs: Training batch size.
        epochs: Number of epochs to run.
        batches_per_epoch: Optimisation steps per epoch.
    """

    def _write_history(path):
        # Dump the four metric histories as one JSON object (same keys _save uses).
        with open(path, "w") as f:
            json.dump(
                {
                    "training_losses": epoch_losses,
                    "validation_losses": epoch_vals,
                    "validation_accuracy": epoch_accs,
                    "evidence_accuracy": epoch_evid,
                },
                f,
            )

    for k in range(epochs):
        # Ensure dropout is active for the optimisation steps even if a previous cell
        # or evaluation pass left the network in eval mode.
        network.train()
        losses = []
        for i in range(batches_per_epoch):
            t, s, a, y = getBatch(bs)
            loss = getLoss(network(t, s, a), y, lossFn)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            # "\r" keeps the running-loss report on a single console line.
            print("epoch:", k+1, "batch:", i+1, "loss:", np.round(np.mean(losses),5), end="\r")
        epoch_losses.append(np.mean(losses))

        val_loss, acc, evid_acc = validate(network, num_batches=10)
        epoch_vals.append(val_loss)
        epoch_accs.append(acc)
        epoch_evid.append(evid_acc)

        print("\n\tValidation Loss:", np.round(val_loss,5))
        print("\tOverall Validation Accuracy:", np.round(acc,2), "; and for evidence only:", np.round(evid_acc,2))

        chooseModelSave(network)
        _write_history("./FaVer_training_cycle.json")


In [None]:
# Full training run: 15 epochs x 1000 steps with batch size 16.
Train(network, epochs=15, bs=16, batches_per_epoch=1000)