In [None]:
import sys

import gradio as gr
import pyrootutils
import torch
import yaml

# Locate the project root (adds it to sys.path and loads .env).
# `__vsc_ipynb_file__` is only injected by VS Code's notebook runner; fall back to
# searching from the current directory so this cell does not raise NameError
# under Jupyter Lab or a plain IPython kernel.
root = pyrootutils.setup_root(
    globals().get("__vsc_ipynb_file__", "."), dotenv=True, pythonpath=True
)
sys.path.append("/data/deeplearning-all-in-one/src")

from src.models.dist_classify import DistClassify

In [None]:
# Path to the trained Lightning checkpoint. load_from_checkpoint expects a path
# (or file-like object); pre-loading it with torch.load would hand it a dict,
# which it cannot open.
path = "/data/deeplearning-all-in-one/logs/train/runs/2022-08-03/18-17-12/checkpoints/model.ckpt"

In [None]:
# Device that checkpoint tensors are mapped onto while loading. torch.load accepts
# a str / torch.device / dict / callable here; a set such as {"cuda:1"} is not valid.
# Fall back to CPU so the checkpoint still loads on machines without CUDA.
map_location = "cuda:1" if torch.cuda.is_available() else "cpu"

In [None]:
# Read the Hydra config saved with the training run. A `with` block ensures the
# file handle is closed (the previous inline open() leaked it).
with open(
    "/data/deeplearning-all-in-one/logs/train/runs/2022-08-03/18-17-12/.hydra/config.yaml",
    "r",
    encoding="utf-8",
) as cfg_file:
    run_config = yaml.safe_load(cfg_file)

# Rebuild the model from the checkpoint. Per Lightning's docs, extra keyword
# arguments to load_from_checkpoint are forwarded to the module's __init__, so the
# run config is supplied as `hparams`.
# NOTE(review): assumes DistClassify.__init__ accepts an `hparams` argument.
model = DistClassify.load_from_checkpoint(
    path,
    hparams=run_config,
    map_location=map_location,
)

In [None]:
from transformers import BertTokenizer

In [None]:
import os

# Run inference on the second GPU when CUDA is available, otherwise on CPU.
device = "cuda:1" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()  # eval mode for inference (affects layers such as dropout / batch-norm)
# Expand "~" explicitly: the Hugging Face download code does not expand it, so
# "~/.cache" would otherwise create a literal "./~/.cache" under the CWD.
tokenizer = BertTokenizer.from_pretrained(
    "hfl/chinese-roberta-wwm-ext", cache_dir=os.path.expanduser("~/.cache")
)

In [None]:
# Read the raw lines of the label map. The `with` block closes the file, and an
# explicit UTF-8 encoding avoids depending on the system locale for non-ASCII labels.
with open("/data/clean_raw_text/street_label_map.txt", "r", encoding="utf-8") as label_file:
    label = label_file.readlines()

In [None]:
# Keep only the text before the first space on each (stripped) line: the label name.
# fn() below looks labels up by the model's predicted class index.
label = list(map(lambda entry: entry.strip().partition(" ")[0], label))

In [None]:
def fn(text: str):
    """Classify ``text`` and return its most likely labels with probabilities.

    Uses the module-level ``tokenizer``, ``model``, ``device`` and ``label``.

    Args:
        text: Raw input string (truncated to 200 tokens by the tokenizer).

    Returns:
        Dict mapping label name -> softmax probability for the top 5 classes
        (or all classes if the model has fewer than 5), ordered from most to
        least likely. This is suitable for a Gradio ``label`` output.
    """
    encoded = tokenizer([text], return_tensors="pt", truncation=True, max_length=200).to(device)
    with torch.no_grad():
        # Single-item batch: take the first row of logits -> (num_classes,)
        logits = model(encoded)[0]
        probs = torch.softmax(logits, dim=0)
    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(5, probs.numel())
    top_probs, top_idx = torch.topk(probs, k)
    # NOTE(review): assumes `label` holds one entry per output class, in class-index order.
    return dict(zip((label[i] for i in top_idx.tolist()), top_probs.tolist()))

In [None]:
# Smoke test on a sample query ("the access-control door won't open; what's the
# repair phone number? thanks"). The resulting dict is displayed as cell output.
sample_query = "门禁开不了,请问维修电话,谢谢"
fn(sample_query)

In [None]:
# Wrap the classifier in a text-in / label-out web UI.
demo = gr.Interface(fn=fn, inputs="text", outputs="label")
# NOTE: share=True creates a publicly shareable link to this model; debug=True
# blocks the cell until interrupted.
demo.launch(server_port=6006, share=True, debug=True)

In [None]:
# Stop the Gradio server started by demo.launch().
demo.close()