```
bokeh serve   shaji_deploy_4137_xgboost_200_sample_per_class.ipynb  --allow-websocket-origin=34.217.44.96:5007 --port 5007
```

In [1]:
import panel as pn
import holoviews as hv
import pandas as pd


import random
from bokeh.palettes import Category20b_20
# Palette copied into a plain list; its entries colour the query words in the results table.
colors = list(Category20b_20)

In [2]:
import sys
# Initialise Panel for this notebook before any Panel objects are created.
pn.extension()
import pickle

### Create a pipe stream

In [3]:
# Stream that carries the submitted query to the inference callback; starts with an empty list.
pipe = hv.streams.Pipe(data=[])
# HTML pane that shows the prediction table; empty until a query is submitted.
tb = pn.pane.HTML("")

### Create the query text box and the submit/clear buttons

In [4]:
submit_button = pn.widgets.Button(name='Enter Query', button_type='primary')
submit_text = pn.widgets.TextInput(value=None)

def b_submit(event):
    """Push the current text-box content into the pipe when 'Enter Query' is clicked.

    The text box is created with value=None, so a click before anything has
    been typed would send None. Consumers that stringify the query would then
    see the non-empty string "None"; send "" instead so "nothing typed" is
    always represented as an empty query.
    """
    pipe.send(submit_text.value or "")

submit_button.on_click(b_submit)

reset_button = pn.widgets.Button(name='Clear Query', button_type='primary')

def b_reset(event):
    """Empty the text box and push the empty query into the pipe."""
    submit_text.value = ""
    pipe.send(submit_text.value)

reset_button.on_click(b_reset)

### Make infrastructure for inference

In [5]:
# Empty placeholder dataset; it exists so a callback can be attached to it via `.apply`.
dataset = hv.Dataset([])

In [6]:
import json
import re
sys.path.insert(0, '.')
import TransformerModel
import train
import torch
from tokenizers import Tokenizer
import numpy as np


# Read the configuration inside a context manager so the file handle is closed.
# Everything from a '#' up to the end of its line is stripped before parsing,
# because JSON itself has no comment syntax.
with open('config.json', 'r') as config_file:
    config = json.loads(re.sub(r'#.*?\n', '', config_file.read()))

# Restore the trained classifier from its checkpoint and switch it to eval mode.
model  = train.HTSClassifier.load_from_checkpoint(config['lm_save_file']).eval()
tokenizer = Tokenizer.from_file(config['token_config'])
padding_length = int(config['padding_length'])
# Both columns are read as strings (this keeps the `hs` values exactly as written in the file).
hts_map = pd.read_csv("hts_train.csv", dtype={'hs': str, 'desc' : str})

# NOTE(security): pickle.load can execute arbitrary code; only load a label
# encoder file that was produced by a trusted training run.
with open(config['save_dir'] + '/label_enc.pkl', 'rb') as f:
    label_enc = pickle.load(f)


Using configuration file : config.json


In [7]:
# Remove pandas' column-width limit so long text cells are rendered in full.
pd.set_option('display.max_colwidth', None)

In [8]:
# NOTE(review): this is the same read that the model-loading cell already performs;
# it looks redundant and can probably be dropped -- confirm nothing relies on the reload.
hts_map = pd.read_csv("hts_train.csv", dtype={'hs': str, 'desc' : str})

In [9]:
def get_sample_prediction(text, num_samples=10):
    """Rank the model's class predictions for `text` and return the best ones.

    The text is tokenized, cut to at most `padding_length` token ids, passed
    through the model, and the softmax probabilities are sorted in descending
    order.

    Args:
        text: Query string to classify.
        num_samples: How many of the highest-probability classes to keep.

    Returns:
        A DataFrame with the columns `hs` (decoded class label) and
        `probablity`, left-joined with `hts_map` on `hs`; any missing value
        after the join is filled with 'No description'.
    """
    with torch.no_grad():
        enc = tokenizer.encode(text)
        ids = np.array(enc.ids[:padding_length])
        # Map every token id 0 to 1 in one vectorised step (instead of calling
        # a Python lambda per element through np.vectorize).
        ids = torch.from_numpy(np.where(ids == 0, 1, ids))
        # NOTE(review): zeros were replaced just above, so this mask is always
        # all-False. Kept as-is to leave the model inputs unchanged -- confirm
        # whether the mask was meant to be computed before the replacement.
        mask = ids == 0
        # The reshape requires exactly `padding_length` ids; presumably the
        # tokenizer is configured to pad to that length -- TODO confirm.
        y = model.forward(ids.reshape(1, padding_length), mask.reshape(1, padding_length))
        # Softmax turns the model output into per-class probabilities.
        probs = torch.softmax(y, dim=1)
        sorted_prob, indices = torch.sort(probs, descending=True)
        # Decode the top class indices back to their original labels.
        codes = label_enc.inverse_transform(indices[0].numpy()[:num_samples])
        top_prob = sorted_prob[0].numpy()[:num_samples]
        df_rank = (
            pd.DataFrame([{'hs' : c, 'probablity' : p} for c, p in zip(codes, top_prob)])
            .merge(hts_map, on='hs', how='left')
            .fillna('No description')
        )
    return df_rank

In [10]:
# Extra CSS for the page: imports Bootstrap from a CDN and sets the font size
# of elements with the `table` class.
css = '''
@import url("https://cdn.jsdelivr.net/npm/bootstrap@5.0.0-beta2/dist/css/bootstrap.min.css");

.table {
   font-size: 18px;
}

'''
# Call the Panel extension again, this time registering the raw CSS above.
pn.extension(raw_css=[css])

In [11]:
def _highlight_terms(table_html, terms, palette):
    """Wrap each occurrence of the given terms in the table's text in a coloured span.

    Only the text between tags is touched, and all terms are substituted in a
    single pass, so a term such as "span", "td" or "39" can no longer corrupt
    the table markup or the span markup inserted for another term.

    Args:
        table_html: HTML produced by DataFrame.to_html().
        terms: Query words to highlight (matching is case-sensitive).
        palette: Colour strings; reused cyclically when there are more terms
            than colours.

    Returns:
        The HTML string with the terms highlighted.
    """
    import html

    if not terms or not palette:
        return table_html

    span_for = {}
    for i, term in enumerate(terms):
        # to_html() escapes &, < and > in cell text, so escape the term the
        # same way before looking for it.
        escaped = html.escape(term, quote=False)
        # setdefault: a repeated word keeps the colour of its first occurrence.
        span_for.setdefault(
            escaped,
            '<span style="color:' + palette[i % len(palette)]
            + ';font-weight: bold;text-decoration: underline;">' + escaped + '</span>',
        )

    # Longest alternatives first so a word is not pre-empted by one of its prefixes.
    pattern = re.compile(
        '|'.join(re.escape(t) for t in sorted(span_for, key=len, reverse=True))
    )

    # Splitting on a capturing group keeps the tags: even indices hold text,
    # odd indices hold the tags themselves.
    pieces = re.split(r'(<[^>]*>)', table_html)
    for i in range(0, len(pieces), 2):
        pieces[i] = pattern.sub(lambda m: span_for[m.group(0)], pieces[i])
    return ''.join(pieces)


def select_data(ds, query):
    """Pipe callback: refresh the results pane `tb` for the submitted query.

    Args:
        ds: The dataset the callback is attached to (not used).
        query: Data received from the pipe -- the query text, or a list / None /
            blank string when there is nothing to classify.

    Returns:
        An empty hv.Div; the visible result is written to `tb.object` as a
        side effect.
    """
    # A list is the pipe's initial data; None or blank text means nothing was typed.
    if query is None or isinstance(query, list) or not str(query).strip():
        tb.object = ""
        return hv.Div("")

    text = str(query)
    table_html = (get_sample_prediction(text, 10)
                  .to_html()
                  .replace("dataframe", "table table-bordered table-hover thead-light"))

    # random.sample returns a shuffled copy, so the shared palette is not mutated.
    palette = random.sample(colors, len(colors))
    tb.object = _highlight_terms(table_html, text.split(), palette)

    return hv.Div("")

# Attach select_data to the dataset with the pipe's data as its `query` argument.
# Presumably the returned object has to be part of the served layout for the
# callback to run on new pipe data -- TODO confirm.
filtered_ds = dataset.apply(select_data,  query=pipe.param.data)

### Lay out the dashboard

In [12]:
# Default sizing mode for Panel components.
pn.config.sizing_mode="stretch_width"

MAX_WIDTH=1140

# Fixed-height vertical gap, placed above the heading and above the results.
spacer = pn.Spacer(height=30, margin=0)

# Page sections, built in display order.
heading = pn.Column('# Do Inference')
query_controls = pn.Column(submit_text, pn.Row(submit_button, reset_button))
results = pn.Column(tb)
# 1x1 container for the object returned by dataset.apply.
callback_holder = pn.Column(filtered_ds, width=1, height=1)

main_area = pn.Column(
    spacer,
    heading,
    query_controls,
    spacer,
    results,
    callback_holder,
    sizing_mode="stretch_both",
)

# Mark the layout as servable, with 'inf' as the page title.
main_area.servable(title='inf')