In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from transformers_model import transformers_model
from pytorch_pretrained_bert import BertTokenizer

import fire
import json
from collections import defaultdict

from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction
from nltk.translate.meteor_score import meteor_score
from nltk.translate.nist_score import sentence_nist
from nltk.util import ngrams

ModuleNotFoundError: No module named 'transformers_model'

In [22]:
def bleu(predict, target, n):
    return sentence_bleu([target], predict, weights=tuple(1 / n for i in range(n)))


def nist(predict, target, n):
    if len(predict) < n or len(target) < n:
        return 0
    return sentence_nist([target], predict, n)


def cal_entropy(generated):
    etp_score = [0.0, 0.0, 0.0, 0.0]
    div_score = [0.0, 0.0, 0.0, 0.0]
    counter = [defaultdict(int), defaultdict(int),
               defaultdict(int), defaultdict(int)]
    for gg in generated:
        g = gg.rstrip().split()
        for n in range(4):
            for idx in range(len(g)-n):
                ngram = ' '.join(g[idx:idx+n+1])
                counter[n][ngram] += 1
    for n in range(4):
        total = sum(counter[n].values()) + 1e-10
        for v in counter[n].values():
            etp_score[n] += - (v+0.0) / total * (np.log(v+0.0) - np.log(total))
        div_score[n] = (len(counter[n].values())+0.0) / total
    return etp_score, div_score


def cal_length(sentences):
    sen_length = [len(s.split()) for s in sentences]
    return np.mean(sen_length), np.var(sen_length)


def calculate_metrics(predict, reference):
    reference_len = len(reference)
    predict_len = len(predict)

    #-------------------bleu----------
    bleu_2 = bleu(predict, reference, 2)
    bleu_4 = bleu(predict, reference, 4)
    #-------------------nist----------
    nist_2 = nist(predict, reference, 2)
    nist_4 = nist(predict, reference, 4)
    #-------------------meteor----------
    predict = " ".join(predict)
    reference = " ".join(reference)
    meteor_scores = meteor_score([reference], predict)
    return bleu_2, bleu_4, nist_2, nist_4, meteor_scores


def top_k_logits(logits, k):
    """Mask logits so that only top-k logits remain
    """
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1].unsqueeze(1).repeat(1, logits.shape[-1])
    return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)



In [14]:
top_k = 50
temperature = 1.0
decoder_path='best_model.pth'
gpu_id=0

In [15]:
# load model
print('load the model....')
device = torch.device(f"cuda:{gpu_id}")

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

model = transformers_model()
model.load_state_dict(torch.load(decoder_path))

device = torch.device(f"cuda:0")
model.to(device)
model.eval()

print('load success')

load the model....
load success


In [16]:
# load test data
test_data = torch.load("../test_data.pth")
test_dataset = TensorDataset(*test_data)
test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=1)

In [23]:
# start generate samples
update_count = 0

bleu_2scores = 0
bleu_4scores = 0
nist_2scores = 0
nist_4scores = 0

meteor_scores = 0
sentences = []
print('start generating....')

for batch in test_dataloader:
        with torch.no_grad():
            batch = [item.to(device) for item in batch]
            encoder_input, decoder_input, mask_encoder_input, _ = batch
            past, _ = model.encoder(encoder_input, mask_encoder_input)

            prev_pred = decoder_input[:, :1]
            sentence = prev_pred

            # decoding loop
            for i in range(100):
                logits, _ = model.decoder(sentence, encoder_hidden_states=past)
                logits = model.linear(logits)
                logits = logits[:, -1]
                logits = logits.squeeze(1) / temperature
                
                logits = top_k_logits(logits, k=top_k)
                probs = F.softmax(logits, dim=-1)
                prev_pred = torch.multinomial(probs, num_samples=1)
                sentence= torch.cat([sentence, prev_pred], dim=-1)
                if prev_pred[0][0] == 102:
                    break

            predict = tokenizer.convert_ids_to_tokens(sentence[0].tolist())

            encoder_input = encoder_input.squeeze(dim=0)
            encoder_input_num = (encoder_input != 0).sum()
            inputs = tokenizer.convert_ids_to_tokens(encoder_input[:encoder_input_num].tolist())

            decoder_input = decoder_input.squeeze(dim=0)
            decoder_input_num = (decoder_input != 0).sum()

            reference = tokenizer.convert_ids_to_tokens(decoder_input[:decoder_input_num].tolist())
            
            temp_bleu_2, temp_bleu_4, temp_nist_2, temp_nist_4, temp_meteor_scores = calculate_metrics(predict[1:-1], reference[1:-1])

            bleu_2scores += temp_bleu_2
            bleu_4scores += temp_bleu_4
            nist_2scores += temp_nist_2
            nist_4scores += temp_nist_4

            meteor_scores += temp_meteor_scores
            sentences.append(" ".join(predict[1:-1]))
            
            # print some samples
            if update_count % 10 == 0:
                patient = ' '.join(inputs[1:-1])
                reference = ' '.join(reference[1:-1])
                predict = ' '.join(predict[1:-1])
                
                print('-'*20 + f"example {update_count}" + '-'*20)
                print(f"Patient: {patient.replace(' ##','')}")
                print(f"Reference: {reference.replace(' ##','')}")
                print(f"Predict: {predict.replace(' ##','')}")
            
            update_count += 1

start generating....


The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


--------------------example 0--------------------
Patient: i have had mild chest pain for over a week . it now seems more persistent and pronounced . i don ' t have shortness of breath or any other covid - 19 symptoms , except some fatigue . i have been traveling a lot in high risk areas . should i get tested ?
Reference: brief opinion : yes i would advise screening due to your exposure . fever is very commonly associated with covid - 19 . stay at home , rest , drink fluids and monitor your temperature . arrange the testing which also may include a chest image with your pcp . since your have been traveling , a pulmonary embolism is another possible cause of your shortness of breath . would you like to video or text chat with me ?
Predict: possible . are your corona
--------------------example 10--------------------
Patient: i have mild irritation in my chest but i am not coughing , i just feel a tingling sensation on my thought . i have no fever , no aches should i be worried about cov

In [24]:
entro, dist = cal_entropy(sentences)
mean_len, var_len = cal_length(sentences)
print(f'avg: {mean_len}, var: {var_len}')
print(f'entro: {entro}')
print(f'dist: {dist}')
print(f'test bleu_2scores: {bleu_2scores / update_count}')
print(f'test bleu_4scores: {bleu_4scores / update_count}')
print(f'test nist_2scores: {nist_2scores / update_count}')
print(f'test nist_4scores: {nist_4scores / update_count}')
print(f'test meteor_scores: {meteor_scores / update_count}')

avg: 30.345864661654137, var: 330.9630844027362
entro: [4.160529352343808, 6.247525325712235, 7.326991431963631, 7.765953666454775]
dist: [0.036917740336966376, 0.3141173456315574, 0.6714399363563863, 0.8318219291013786]
test bleu_2scores: 0.024377008970693708
test bleu_4scores: 0.006792492305409994
test nist_2scores: 0.27158298558775
test nist_4scores: 0.26783892938356935
test meteor_scores: 0.0991544472086127
