In [None]:
import numpy as np
import torch
import pickle
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

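- Only the first 512 tokens of the concatenated evidence (as stored in the dataset) are used, since BERT accepts at most 512 tokens per input. **TODO:** replace this truncation with a proper sentence-level formation of the evidence.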

In [None]:
DATA_PATH = "../Data/usable_verifiable_fever_data.pickle"
# NOTE: pickle.load can execute arbitrary code -- only load this file from a trusted source.
# The `with` block guarantees the file handle is closed (the bare open() leaked it).
with open(DATA_PATH, "rb") as data_file:
    data = pickle.load(data_file)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print(str(len(data)) + " data points.")

In [None]:
class Verifier(torch.nn.Module):
    """Binary claim-verification model: BERT -> shared LSTM -> claim-to-evidence attention.

    Claim and evidence token ids both go through the same BERT model. The top
    `num_layers_to_take_from_bert` encoder layers are concatenated on the feature
    axis and fed through one shared LSTM. The projected last LSTM state of the
    claim attends over every evidence time step. The attended evidence vector is
    concatenated with the projected claim vector and mapped to one sigmoid
    probability per example.

    Args:
        attention_dim: Currently unused; kept for backward compatibility.
        internal_lstm_dim: Hidden size of the shared encoder LSTM.
        num_layers_to_take_from_bert: How many of BERT's top encoder layers to
            concatenate (each contributes 768 features for bert-base).

    Raises:
        ValueError: if `num_layers_to_take_from_bert` < 1.
    """
    def __init__(self, attention_dim = 512, internal_lstm_dim = 128, num_layers_to_take_from_bert=4):
        super(Verifier, self).__init__()
        if num_layers_to_take_from_bert < 1:
            raise ValueError("num_layers_to_take_from_bert must be >= 1")
        # Stored so forward() slices exactly as many BERT layers as the LSTM expects.
        self.num_layers_to_take_from_bert = num_layers_to_take_from_bert
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # 768 = hidden size of bert-base; one 768-wide slice per concatenated layer.
        self.encoder_lstm = torch.nn.LSTM(num_layers_to_take_from_bert*768,
                                          internal_lstm_dim, batch_first=True)
        self.attention_internal = torch.nn.Linear(internal_lstm_dim, internal_lstm_dim)
        # Attention weights must be normalised over the evidence time axis (dim=1 of the
        # (batch, seq, 1) score tensor). Softmax without `dim` picks dim=0 for 3-D input,
        # which normalised across the batch and coupled unrelated examples.
        self.softmax = torch.nn.Softmax(dim=1)
        self.sigmoid = torch.nn.Sigmoid()
        self.output = torch.nn.Linear(internal_lstm_dim*2, 1)
        print("Output is passed through sigmoid: use BCELoss, not BCEWithLogitsLoss")

    def forward(self, claims, evidences):
        """Score a batch of (claim, evidence) pairs.

        Args:
            claims: LongTensor of claim token ids, presumably (batch, claim_len).
            evidences: LongTensor of evidence token ids, presumably (batch, evidence_len).

        Returns:
            Tensor of shape (batch,) holding sigmoid probabilities.
        """
        # NOTE(review): no attention_mask is given to BERT, so padding tokens (if the
        # data is padded) are attended to like real tokens -- confirm the padding scheme.
        claims_ = self._encode(claims)
        evidences_ = self._encode(evidences)

        # Project the claim's last LSTM state into a query vector: (batch, 1, hidden).
        claim_query = self.attention_internal(claims_[:, -1, :]).unsqueeze(dim=1)
        # Dot-product score per evidence step: (batch, evidence_len, 1).
        scores = torch.matmul(evidences_, claim_query.transpose(1, 2))
        weights = self.softmax(scores)
        attended_evidence = torch.sum(evidences_ * weights, dim=1)

        # NOTE(review): the *projected* claim vector (not the raw LSTM state) is what
        # gets concatenated here -- preserved from the original design.
        combined = torch.cat([claim_query.squeeze(dim=1), attended_evidence], dim=-1)
        logits = self.output(combined)
        return self.sigmoid(logits.squeeze(dim=-1))

    def _encode(self, token_ids):
        """Run token ids through BERT, concat the top layers, and encode with the shared LSTM."""
        encoded_layers, _ = self.bert(token_ids)
        features = torch.cat(encoded_layers[-self.num_layers_to_take_from_bert:], dim=-1)
        features, _ = self.encoder_lstm(features)
        return features

In [None]:
# Seed torch (LSTM / attention / output-layer init) and numpy (getBatch sampling)
# so that training runs are reproducible.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
# Fall back to CPU when no GPU is present instead of crashing on .cuda().
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
verifier = Verifier().to(DEVICE)

In [None]:
def getBatch(bs=5, max_len=512, claim_len=30):
    """Sample `bs` examples uniformly at random (with replacement) from the global `data`.

    Args:
        bs: Number of examples to draw.
        max_len: NOTE(review): accepted but not used -- no truncation happens here.
        claim_len: NOTE(review): accepted but not used -- no truncation happens here.

    Returns:
        Tuple of three lists (claims, evidences, labels), read from each sampled
        record's "claim", "evidence" and "class" keys.
    """
    picks = np.random.randint(0, len(data), (bs,))
    claims, evidences, labels = [], [], []
    for idx in picks:
        record = data[idx]
        claims.append(record["claim"])
        evidences.append(record["evidence"])
        labels.append(record["class"])
    return claims, evidences, labels

In [None]:
# The Verifier returns sigmoid probabilities, so plain BCELoss is the matching loss.
loss_fn = torch.nn.BCELoss()

# Adam at lr=0.01 over the whole pretrained BERT is far above the 2e-5..5e-5 range the
# BERT authors recommend for fine-tuning, and tends to wreck the pretrained weights.
# Use a small LR for BERT and Adam's default LR for the freshly initialised head
# (LSTM, attention projection, output layer).
BERT_LEARNING_RATE = 2e-5
HEAD_LEARNING_RATE = 1e-3
bert_param_ids = {id(p) for p in verifier.bert.parameters()}
optimizer = torch.optim.Adam([
    {"params": list(verifier.bert.parameters()), "lr": BERT_LEARNING_RATE},
    {"params": [p for p in verifier.parameters() if id(p) not in bert_param_ids],
     "lr": HEAD_LEARNING_RATE},
])

In [None]:
def train(batch_size=4, total_batches=100):
    """Run `total_batches` optimisation steps on randomly sampled training batches.

    Uses the notebook-level `verifier`, `optimizer`, `loss_fn` and `getBatch`.
    Progress is printed on a single line that is refreshed with a carriage return.
    It shows the batch index and the mean loss over the last 100 batches.

    Args:
        batch_size: Number of examples drawn per step.
        total_batches: Number of optimisation steps to run.
    """
    verifier.train()  # explicitly put the module in training mode
    # Move batches to wherever the model lives instead of hard-coding CUDA.
    device = next(verifier.parameters()).device
    losses = []
    for i in range(total_batches):
        claim, evidence, y = getBatch(bs=batch_size)
        claim = torch.LongTensor(claim).to(device)
        evidence = torch.LongTensor(evidence).to(device)
        y = torch.FloatTensor(y).to(device)
        # Clear the previous step's gradients; without this they accumulate across
        # every batch and each update is driven by the running sum.
        optimizer.zero_grad()
        output = verifier(claim, evidence)  # module __call__ (runs hooks), not .forward()
        loss = loss_fn(output, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        losses = losses[-100:]  # rolling window for the progress average
        print("Batch:", str(i), "; Average Loss:", str(np.round(np.mean(losses), 4)), end="\r")

In [None]:
# Launch the full training run: 10000 optimisation steps at the default batch size of 4.
train(total_batches=10000, batch_size=4)