In [179]:
from transformers import BertTokenizer, BertForSequenceClassification, AdamW, BertForTokenClassification, BertModel
import torch, tqdm, json
import numpy as np
from torch.utils.data import DataLoader, Dataset 
from torch import nn
import torch.nn.functional as F
import random
# from scorer impbort fever_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
NUM_EPOCHS = 2

In [180]:
def process_data(fname, train=True):
    X = []
    mask = []
    token_type_ids = []
    y = []

    label_dict = {}
    label_dict['UNK'] = -1
    label_dict['NOT ENOUGH INFO'] = 0
    label_dict['SUPPORTS'] = 1
    label_dict['REFUTES'] = 2
    claim_ids = []

    predicted_evidence = []
    f = open(fname, encoding='utf8')
    f.readline()
    for line in f:
        line = json.loads(line)
        claim_ids.append(line['id'])
        predicted_evidence.append([line['doc'], line['sid']])

        emb = tokenizer.encode_plus(line['claim'], line["sentence"], pad_to_max_length=True)
        input_ids, sent_ids, m = emb['input_ids'], emb['token_type_ids'], emb['attention_mask']
        
        X.append(input_ids[:128])
        mask.append(m[:128])
        token_type_ids.append(sent_ids[:128])
        
        y.append(label_dict[line['label']])
    f.close()
        
    return torch.LongTensor(X), torch.LongTensor(y), torch.LongTensor(mask), torch.LongTensor(token_type_ids), claim_ids, predicted_evidence

In [138]:
X_train, y_train, mask_train, token_type_ids_train, ids_train, predicted_evidence_train = process_data("NN-NLP-Project-Data/train_sent_results.txt")
X_dev, y_dev, mask_dev, token_type_ids_dev, ids_dev, predicted_evidence_dev = process_data("NN-NLP-Project-Data/dev_sent_results.txt")
# X_test, y_test, mask_test, ids_test, predicted_evidence_test = process_data("NN-NLP-Project-Data/test_sent_results.txt")

train_dataset = TensorDataset(X_train, y_train, mask_train, token_type_ids)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32, num_workers=8)

dev_dataset = TensorDataset(X_dev, y_dev, mask_dev, token_type_ids)
dev_loader = DataLoader(dev_dataset, shuffle=False, batch_size=32, num_workers=8)

FileNotFoundError: [Errno 2] No such file or directory: 'NN-NLP-Project-Data/train_sent_results.txt'

In [217]:
class LinearSelfAttn(nn.Module):
    """Self attention over a sequence:
    * o_i = softmax(Wx_i) for x_i in X.
    """
    def __init__(self, input_size, drop_rate=0.1):
        super(LinearSelfAttn, self).__init__()
        self.linear = nn.Linear(input_size, 1)
        self.dropout = torch.nn.Dropout(drop_rate)

    def forward(self, x, x_mask):
        x = self.dropout(x)
        x_flat = x.contiguous().view(-1, x.size(-1))
        scores = self.linear(x_flat).view(x.size(0), x.size(1))
        scores.data.masked_fill_(x_mask.data, -float('inf'))
        alpha = F.softmax(scores, 1)
        return alpha.unsqueeze(1).bmm(x).squeeze(1)

In [218]:
class BilinearFlatSim(nn.Module):
    """A bilinear attention layer over a sequence X w.r.t y:
    * o_i = x_i'Wy for x_i in X.
    """
    def __init__(self, x_size, y_size, drop_rate=0.1):
        super(BilinearFlatSim, self).__init__()
  
        self.linear = nn.Linear(y_size, x_size)
        self.dropout = torch.nn.Dropout(drop_rate)

    def forward(self, x, y, x_mask):
        """
        x = batch * len * h1
        y = batch * h2
        x_mask = batch * len
        """
        x = self.dropout(x)
        y = self.dropout(y)

        Wy = self.linear(y)
        xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2)
        xWy.data.masked_fill_(x_mask.data, -float('inf'))
        return xWy

In [219]:
class Classifier(nn.Module):
    def __init__(self, x_size, y_size, drop_rate=0.1):
        super(Classifier, self).__init__()
        
        self.dropout = torch.nn.Dropout(drop_rate)
        self.proj = nn.Linear(x_size * 4, y_size)

    def forward(self, x1, x2, mask=None):
        x = torch.cat([x1, x2, (x1 - x2).abs(), x1 * x2], 1)
        x = self.dropout(x)
        scores = self.proj(x)
        return scores


In [220]:
def generate_mask(new_data, dropout_p=0.0, is_training=False):
    if not is_training: dropout_p = 0.0
    new_data = (1-dropout_p) * (new_data.zero_() + 1)
    for i in range(new_data.size(0)):
        one = random.randint(0, new_data.size(1)-1)
        new_data[i][one] = 1
    mask = 1.0/(1 - dropout_p) * torch.bernoulli(new_data)
    mask.requires_grad = False
    return mask

In [221]:
class BertSAN(nn.Module):
    """BERT model with SAN for entailment.
    """
    def __init__(self, K, x_size, h_size, drop_rate=0.0, num_labels=3):
        super(BertSAN, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        self.query_wsum = LinearSelfAttn(x_size, drop_rate=0.1)
        self.attn = BilinearFlatSim(x_size, h_size)
        self.rnn = torch.nn.GRUCell(input_size=768, hidden_size=768)

        self.K = K
        
        self.dropout = torch.nn.Dropout(drop_rate)
        self.classifier = Classifier(x_size, num_labels)
        
        self.alpha = nn.Parameter(torch.zeros(1, 1), requires_grad=False)

    def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
        
        hidden_states, pooled_output = self.bert(input_ids)
        
        x_mask = torch.BoolTensor((token_type_ids==0) + (attention_mask==0))
        h_mask = torch.BoolTensor((token_type_ids==1) + (attention_mask==0))
        
        h0 = x = hidden_states
        
        h0 = self.query_wsum(h0, h_mask)
        scores_list = []

        for turn in range(self.K):
          att_scores = self.attn(x, h0, x_mask)
          x_sum = torch.bmm(F.softmax(att_scores, 1).unsqueeze(1), x).squeeze(1)
          scores = self.classifier(x_sum, h0)
          scores_list.append(scores)

          h0 = self.dropout(h0)
          h0 = self.rnn(x_sum, h0)

        mask = generate_mask(self.alpha.data.new(x.size(0), self.K), 0.1, self.training)
        mask = [m.contiguous() for m in torch.unbind(mask, 1)]
        tmp_scores_list = [mask[idx].view(x.size(0), 1).expand_as(inp) * F.softmax(inp, 1) for idx, inp in enumerate(scores_list)]
        scores = torch.stack(tmp_scores_list, 2)
        scores = torch.mean(scores, 2)
        scores = torch.log(scores)
        
#         scores = scores_list[-1]

        return scores

In [222]:
# sent_ip = torch.zeros(size=(1, 512), dtype=torch.long)

# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# emb = tokenizer.encode(concat, add_special_tokens=True)[:128]

# dic = tokenizer.encode_plus(claim, hyp, pad_to_max_length=True)
# print(dic['attention_mask'])

In [223]:
model = BertSAN(10, 768, 768)

claim = "The world is big"
hyp = "The world is big"

emb = tokenizer.encode_plus(claim, hyp, pad_to_max_length=True)
input_ids, sent_ids, mask = torch.LongTensor(emb['input_ids'][:128]).unsqueeze(0), torch.LongTensor(emb['token_type_ids'][:128]).unsqueeze(0), torch.LongTensor(emb['attention_mask'][:128]).unsqueeze(0)

out = model(input_ids, sent_ids, mask)    

In [224]:
print(out)

tensor([[-1.5544, -1.2000, -0.9772]], grad_fn=<LogBackward>)
