In [1]:
import os
# Move the working directory up one level. Later cells open files relative to
# the working directory (e.g. './configs/...'), so this must run first.
os.chdir('../')

In [2]:
import yaml
import gluonnlp as nlp
import torch

from kobert import get_pytorch_kobert_model
from kobert.utils import get_tokenizer

from models import create_model

import ipywidgets as widgets
from IPython.display import display, clear_output
import time

# Output widget used (via its capture() decorator) as the sink for anything
# printed or raised by the model-loading callback.
# NOTE(review): this widget is never passed to display() in the visible code,
# so whatever it captures would not be shown -- confirm whether that is intended.
output_model = widgets.Output()

@output_model.capture()
def on_click_callback_model(b: widgets.Button) -> None:
    """Load the model and dataset selected in the ``model_select`` dropdown.

    Reads ``./configs/<name>/<name>-test.yaml`` (where ``<name>`` is the
    dropdown value), builds the tokenizer, the model and the matching
    ``<DATASET.name>Dataset`` object, and publishes them as the module-level
    globals ``model`` and ``dataset``.

    The globals are only assigned after *both* objects were built
    successfully, so a failure half-way through never leaves a freshly loaded
    model paired with a stale (or missing) dataset.

    Args:
        b: The button that triggered the callback (unused).
    """
    global model
    global dataset

    _, vocab = get_pytorch_kobert_model(cachedir=".cache")
    tokenizer = nlp.data.BERTSPTokenizer(get_tokenizer(), vocab, lower=False)

    selected = model_select.value
    # Use a context manager so the config file handle is always closed.
    # NOTE(review): FullLoader is fine for trusted project configs; switch to
    # yaml.safe_load if these files can ever come from an untrusted source.
    with open(f'./configs/{selected}/{selected}-test.yaml', 'r') as cfg_file:
        cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)

    new_model = create_model(
        modelname=cfg['MODEL']['modelname'],
        hparams=cfg['MODEL']['PARAMETERS'],
        tokenizer=tokenizer,
        checkpoint_path=cfg['MODEL']['CHECKPOINT']['checkpoint_path'],
    )

    # The dataset class is looked up by name in the project's `dataset` module.
    dataset_cls = __import__('dataset').__dict__[f"{cfg['DATASET']['name']}Dataset"]
    new_dataset = dataset_cls(
        tokenizer=tokenizer,
        vocab=vocab,
        **cfg['DATASET']['PARAMETERS']
    )

    # Switch to inference mode before exposing the model to the Run callback.
    new_model.eval()

    # Publish both objects together so they always belong to the same config.
    model = new_model
    dataset = new_dataset


# Output widget (500px wide, black border) that receives everything printed by
# the Run callback; it is displayed under the "Result" heading further below.
output = widgets.Output(layout=widgets.Layout(width='500px', border='1px solid black'))

# When used as a decorator, the Output widget becomes the default destination
# for everything the function prints.
@output.capture()
def on_click_callback(b: widgets.Button) -> None:
    """Run the loaded model on the text area content and print the result.

    The text is split on '.' into sentences, preprocessed with
    ``dataset.single_preprocessor`` and passed to ``model``. The sentences are
    then printed back; before every sentence ``i > 0`` whose preceding
    prediction ``preds[i - 1]`` equals 1, a ``--- fake (<prob>) ---`` separator
    with the class-1 probability is printed.

    Requires the globals ``model`` and ``dataset`` to have been set by the
    model-selection callback.

    Args:
        b: The button that triggered the callback (unused).
    """
    clear_output(wait=True)

    # Drop empty fragments (e.g. the one produced after a trailing '.', or a
    # blank text area) so they are not fed to the model as sentences.
    doc = [s for s in (part.strip() for part in text.value.split('.')) if s]
    if not doc:
        print('No text to analyze.')
        return

    inputs = dataset.single_preprocessor(doc=doc)

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments.
    start = time.perf_counter()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    run_time = time.perf_counter() - start

    outputs = torch.nn.functional.softmax(outputs, dim=1)
    preds = outputs.argmax(dim=-1)

    print(f'Run time: {time.strftime("%H:%M:%S", time.gmtime(run_time))}\n')

    for i, t in enumerate(doc):
        # Check the index first so preds[-1] is never read for the first
        # sentence (which has no preceding prediction).
        if i != 0 and preds[i - 1] == 1:
            print()
            print(f'\n--- fake ({outputs[i-1,1]:.2%}) ---\n')

        print(t, end=' ')


# ======================
# layout
# ======================

# --- build: model selection row ---
title = widgets.HTML(value="<h1>Select Model</h1>")
model_select = widgets.Dropdown(
    options=['BTS', 'KoBERTSeg'],
    value='KoBERTSeg',
    description='Model:',
    disabled=False,
)
button_model = widgets.Button(description='Select Model')

# --- build: text input row ---
text = widgets.Textarea(
    value=' ',
    placeholder='Type something',
    description='Text:',
    layout=widgets.Layout(width='500px', height='500px'),
)
button = widgets.Button(description='Run')
run = widgets.HBox([text, button])

# --- build: result heading ---
result = widgets.HTML(value="<h1>Result</h1>")

# --- wire: connect each button to its callback ---
button_model.on_click(on_click_callback_model)
button.on_click(on_click_callback)

# --- render: top to bottom in the order the UI should appear ---
display(title)
display(widgets.HBox([model_select, button_model]))
display(run)
display(widgets.VBox([result, output]))

HTML(value='<h1>Select Model</h1>')

HBox(children=(Dropdown(description='Model:', index=1, options=('BTS', 'KoBERTSeg'), value='KoBERTSeg'), Butto…

HBox(children=(Textarea(value=' ', description='Text:', layout=Layout(height='500px', width='500px'), placehol…

VBox(children=(HTML(value='<h1>Result</h1>'), Output(layout=Layout(border='1px solid black', width='500px'))))