In [17]:
# Copyright (c) 2021, Hyunwoong Ko. Modified 2022 by Billy Cao.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import torch
import numpy as np
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from flask import Flask, request, jsonify
import onnx
from onnxruntime import InferenceSession


# Path to the exported ONNX model; used for both the runtime session and the raw graph.
MODEL_PATH = "ctrl-sum/onnx_model/model.onnx"

# Prefer CUDA, but list CPU explicitly as a fallback so the notebook still runs
# on machines without a usable GPU.
session = InferenceSession(MODEL_PATH, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
# NOTE(review): `model` is not used in this cell; kept in case later cells inspect the graph.
model = onnx.load(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained('bert-large-cased')

source_text = "the sky is blue"

inputs = tokenizer(source_text, return_tensors="np")
# Cast every tokenizer array to int64 (presumably the dtype the exported graph declares).
inputs = {k: v.astype(np.int64) for k, v in inputs.items()}
# Feed exactly the inputs the graph declares, so a model exported without e.g.
# `token_type_ids` does not reject an unexpected feed name.
outputs = session.run(
    output_names=["logits"],
    input_feed={node.name: inputs[node.name] for node in session.get_inputs()},
)

In [18]:
import scipy.special
labels = ["0", "1"]
label_map = {i: label for i, label in enumerate(labels)}
num_labels = len(labels)

def align_predictions(predictions, batch_index=0, positive_label="1"):
    """Return per-token probabilities of ``positive_label`` for one batch item.

    Args:
        predictions: logits indexed as ``[batch, token, label]``; the label axis
            is assumed to follow the order of the module-level ``labels`` list.
        batch_index: which sequence of the batch to read (default: the first).
        positive_label: label (an entry of ``labels``) whose softmax probability
            is returned for each token.

    Returns:
        A list with one probability per token, excluding the first and last
        token (the start/end special tokens). Sequences of length < 3 yield [].
    """
    label2id = {label: i for i, label in enumerate(labels)}
    # Softmax over the label axis of the selected sequence only.
    token_probs = scipy.special.softmax(np.asarray(predictions)[batch_index], axis=-1)
    # Drop the start and end tokens so positions line up with the real tokens.
    return list(token_probs[1:-1, label2id[positive_label]])


In [19]:
# Minimum P(label "1") for a token to flag its source word as a keyword.
KEYWORD_THRESHOLD = 0.03

probs = align_predictions(outputs[0])

# `probs` has one entry per sub-word token (start/end tokens already stripped),
# which does not line up with whitespace-split words once WordPiece splits a
# word or separates punctuation. Map each token back to the span of source text
# it came from instead. word_ids()/word_to_chars() require a fast tokenizer
# (the AutoTokenizer default).
encoding = tokenizer(source_text)
token_word_ids = encoding.word_ids()[1:-1]  # drop start/end tokens to align with `probs`

keywords = []
seen_word_ids = set()  # a word split into several tokens is reported once
for word_id, prob in zip(token_word_ids, probs):
    if word_id is None or prob <= KEYWORD_THRESHOLD or word_id in seen_word_ids:
        continue
    seen_word_ids.add(word_id)
    span = encoding.word_to_chars(word_id)
    keywords.append(source_text[span.start:span.end])