In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [148]:
## self
import config
from preprocess import *
from utils import *

from transformers import BertModel
import torch.nn as nn
import math # for sqrt

In [11]:
dev_path = config.DEV_MA_FILE
print(dev_path)

/work/2020-IIS-NLU-internship/MNLI/data/multinli_1.0/multinli_1.0_dev_matched.jsonl


In [12]:
tokenizer = BertTokenizer.from_pretrained(config.BERT_EMBEDDING)

In [13]:
dev_set = MNLI_Raw_Dataset(dev_path, mode="develop")
print(len(dev_set))

9815


In [14]:
# how to create batch for graph???

In [46]:
dev_loader = DataLoader(
    dataset=dev_set, 
    batch_size=config.BATCH_SIZE,
    shuffle=True,  
    collate_fn=create_mini_batch
)
if config.DEBUG:
    print(next(iter(dev_loader)))

{'sentence1': {'input_ids': tensor([[  101,  5898, 16283,  2015,   102,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [  101,  3930,  1004, 23713,   102,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [  101,  2748,  1010, 17319,  1996,  2192,  1997,  2720,  1012,  2829,
           999,  2720,  1012,  5708,  5864,  1012,   102,     0,     

In [47]:
if config.DEBUG:
    print(tensor_to_sent(next(iter(dev_loader))[config.h_field]['input_ids'][0], tokenizer))
    print(type(next(iter(dev_loader))[config.h_field]['input_ids'][0]))

[CLS] nobody expects that the devil would take the form of a lawyer . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
<class 'torch.Tensor'>


In [169]:
class CrossBERTModel(nn.Module):
    """
    bert cross attention model
    h, p go through bert and get their contexulized embedding saparately
    and do soft alignment and prediction as in decomp-att paper
    this is a embedding enhanced version of decomp-att
    """
    def __init__(self, bert_encoder=None, cross_attention_hidden=392):
        super().__init__()
        #bert encoder
        if bert_encoder == None or not isinstance(bert_encoder, BertModel):
            print("unkown bert model choice, init with config.BERT_EMBEDDING")
            bert_encoder = BertModel.from_pretrained(config.BERT_EMBEDDING)
        self.bert_encoder = bert_encoder
        # dropouts
        self.dropout = nn.Dropout(p=bert_encoder.config.hidden_dropout_prob)
        self.activation = nn.ReLU(inplace=True)
        # linear layers for cross attention, with biased?
        self.cross_attention_hidden = cross_attention_hidden
        self.Wq = nn.Parameter(torch.Tensor(bert_encoder.config.hidden_size, self.cross_attention_hidden))
        self.Wk = nn.Parameter(torch.Tensor(bert_encoder.config.hidden_size, self.cross_attention_hidden))
        self.Wv = nn.Parameter(torch.Tensor(bert_encoder.config.hidden_size, bert_encoder.config.hidden_size))
        self.Wo = nn.Parameter(torch.Tensor(bert_encoder.config.hidden_size, bert_encoder.config.hidden_size))
        self.classifier = nn.Linear(bert_encoder.config.hidden_size, 1)
        
        forward_expansion = 1 # can change
        self.fnn = nn.Sequential(
            nn.Linear(bert_encoder.config.hidden_size, forward_expansion*bert_encoder.config.hidden_size),
            nn.ReLU(),
            nn.Linear(forward_expansion*bert_encoder.config.hidden_size, bert_encoder.config.hidden_size),
        )
        # critrion add positive weight
        self.criterion = nn.BCEWithLogitsLoss()
    
    """
    cross attention, similar to Decomp-Att
    but no fowrad nn, use Wk Wq Wv
    input: query vector(b*n*d), content vector(b*m*d)
    ouput: sof aligned content vector to query vector(b*n*d)
    """
    def cross_attention(self, h1, h2):
        Q = torch.matmul(h1, self.Wq)
        K = torch.matmul(h2, self.Wk)
        V = torch.matmul(h2, self.Wv)
        #Kt = torch.matmul(h2, self.Wk).permute(0,2,1)
        #E = torch.matmul(Q, Kt)
        E = torch.einsum("bnd,bmd->bnm", [Q, K]) # batch, n/m, dimension
        A = torch.softmax(E / (math.sqrt(self.cross_attention_hidden)), dim=2) #soft max dim = 2
        # attention shape: (N, heads, query_len, key_len)
        aligned_2_for_1 = torch.einsum("bnm,bmd->bnd", [A, V])
        if(config.DEBUG):
            print(Q.size())
            print(K.size())
            print(E.size())
            print(A.size())
            print(aligned_2_for_1.size())
            
        return aligned_2_for_1
    
    def forward(self, batch):
        """
        'sentence1' : {'input_ids', 'token_type_ids', 'attention_mask'}
        'sentence2'
        'gold_label'
        """
        # get bert contextualized embedding
        # the _ here is the last hidden states
        # q_poolout is a 768-d vector of [CLS]
        # q_poolout = self.dropout(q_poolout), MT Wu : no dropout better, without 
        hidden = [None, None]
        q_poolout = [None, None]
        hidden[0], q_poolout[0] = self.bert_encoder(input_ids=batch[config.h_field]['input_ids'],
                                         token_type_ids=batch[config.h_field]['token_type_ids'],
                                         attention_mask=batch[config.h_field]['attention_mask'])
        hidden[1], q_poolout[1] = self.bert_encoder(input_ids=batch[config.p_field]['input_ids'],
                                         token_type_ids=batch[config.p_field]['token_type_ids'],
                                         attention_mask=batch[config.p_field]['attention_mask'])
        # soft alignment
        aligned_p_for_h = self.cross_attention(hidden[0], hidden[1])
        aligned_h_for_p = self.cross_attention(hidden[1], hidden[0])
        if(config.DEBUG):
            print("hidden[0] (hypothesis) size : " + str(hidden[0].size()))
            print("aligned_p_for_h size : " + str(aligned_p_for_h.size()))
        logits = self.classifier(q_poolout)
        # can apply nn.module.Sigmoid here to convert to p-distribution
        # score is indeed better (and more stable)
        logits = logits.squeeze(-1)
        return logits
    
    # the nn.Module method
    def forward_2(self, batch):
        logits = self.forward_nn(batch)
        batch[3] = batch[3].to(dtype=torch.float)
        loss = self.criterion(logits, batch[3])
        return loss
    
    # return sigmoded score
    def _predict_score(self, batch):
        logits = self.forward_nn(batch)
        scores = torch.sigmoid(logits)
        scores = scores.detach().cpu().numpy().tolist()
        return scores
    
    # return True False based on score + threshold
    def _predict(self, batch, threshold=0.5):
        scores = self._predict_score(batch)
        return [ 1 if score >= threshold else 0 for score in scores]
    
    # return result with assigned threshold, default = 0.5
    def predict_fgc(self, q_batch, threshold=0.5):
        scores = self._predict(q_batch)

        max_i = 0
        max_score = 0
        sp = []
        for i, score in enumerate(scores):
            if score > max_score:
                max_i = i
                max_score = score
            if score >= threshold:
                sp.append(i)

        if not sp:
            sp.append(max_i)

        return {'sp': sp, 'sp_scores': scores}

In [170]:
model = CrossBERTModel()

unkown bert model choice, init with config.BERT_EMBEDDING


In [171]:
test_batch = next(iter(dev_loader))

In [172]:
print(len(test_batch[config.h_field]["input_ids"]))
print(len(test_batch[config.h_field]["input_ids"][5]))
print(len(test_batch[config.p_field]["input_ids"]))
print(len(test_batch[config.p_field]["input_ids"][5]))

8
18
8
54


In [173]:
model(test_batch)

torch.Size([8, 18, 392])
torch.Size([8, 54, 392])
torch.Size([8, 18, 54])
torch.Size([8, 18, 54])
torch.Size([8, 18, 768])
hidden[0] (hypothesis) size : torch.Size([8, 18, 768])


NameError: name 'aligned_h' is not defined