## ContextEvaluator

In [1]:
import torch
import numpy as np
from transformers import BertTokenizer, BertForNextSentencePrediction

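class ContextEvaluator:
    def __init__(self, nsp_limit):
        # Threshold on the IsNSP probability; stored here, not applied in evaluate_context
        self.nsp_limit = nsp_limit
        self.tokenizer = BertTokenizer.from_pretrained('klue/bert-base')
        self.model = BertForNextSentencePrediction.from_pretrained('klue/bert-base')

    def evaluate_context(self, text1, text2):
        """Return the probability that text2 follows text1."""
        # NSP logits: index 0 = IsNSP (text2 follows text1), index 1 = NotNSP

        def cal_softmax(x):
            # Numerically stable softmax: subtract the max before exponentiating
            softmax_x = np.exp(x - np.max(x))
            return softmax_x / softmax_x.sum()

        # Encode the two texts as one sentence pair
        input_tensor = self.tokenizer(text1, text2, return_tensors='pt')
        with torch.no_grad():   # inference only, so skip gradient tracking
            predict = self.model(**input_tensor)
        predict = predict.logits.numpy()[0]   # tensor -> NumPy array of the two logits

        softmax = cal_softmax(predict)
        return softmax[0]   # IsNSP probability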
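
The softmax step inside `evaluate_context` can be checked on its own, without loading the model. The sketch below mirrors `cal_softmax` under an illustrative name, `stable_softmax`, and feeds it made-up logits (not model output) that are large enough to overflow a naive `np.exp` call.

```python
import numpy as np

def stable_softmax(logits):
    """Softmax with the maximum subtracted first, so np.exp cannot overflow."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = np.exp(logits - np.max(logits))
    return shifted / shifted.sum()

# Hypothetical logits in the NSP order used above: index 0 = IsNSP, index 1 = NotNSP.
example_logits = [1000.0, 1001.0]
probs = stable_softmax(example_logits)

# Same indexing as evaluate_context: the first entry is the IsNSP probability.
is_nsp_probability = probs[0]
```

Subtracting the maximum leaves the result unchanged, because softmax depends only on the differences between logits, and it keeps every exponent at or below zero.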

In [2]:
# Instantiate the evaluator with nsp_limit = 0.5 (downloads klue/bert-base on first use)
contextEvaluator = ContextEvaluator(0.5)

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertForNextSentencePrediction: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Test

In [18]:
# Korean inputs: "What did you eat today?" followed by "It's so cold"
contextEvaluator.evaluate_context("오늘은 뭐 먹고 왔어?", "너무 춥다")

0.34925595
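
The constructor stores `nsp_limit`, but `evaluate_context` returns only the raw probability and never compares it against that threshold. One possible way to use it is sketched below. `follows_context` is a hypothetical helper that is not part of the original class, and the `>=` comparison is an assumption about how the threshold was meant to be applied.

```python
def follows_context(is_nsp_probability, nsp_limit=0.5):
    """Return True when the IsNSP probability reaches the threshold.

    Hypothetical helper: the original class stores nsp_limit but never
    applies it, so the >= comparison here is an assumption.
    """
    return is_nsp_probability >= nsp_limit

# Score reported by the test cell, checked against the threshold passed in In [2].
score = 0.34925595
verdict = follows_context(score, nsp_limit=0.5)
print(verdict)  # False, since 0.349 < 0.5
```

Under this reading, the second sentence would not count as a continuation of the first.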