In [1]:
# !pip install ratsnlp

In [2]:
from google.colab import drive

drive.mount('/gdrive', force_remount = True)

Mounted at /gdrive


In [3]:
from ratsnlp.nlpbook.classification import ClassificationDeployArguments

args = ClassificationDeployArguments(
    pretrained_model_name = 'beomi/kcbert-base',
    downstream_model_dir = '/gdrive/MyDrive/nlpbook/checkpoint-doccls',
    max_seq_length = 128,
)

downstream_model_checkpoint_fpath: /gdrive/MyDrive/nlpbook/checkpoint-doccls/epoch=0-val_loss=0.27.ckpt


In [4]:
import torch
from transformers import BertConfig, BertForSequenceClassification

fine_tuned_model_ckpt = torch.load(
    args.downstream_model_checkpoint_fpath,
    map_location = torch.device('cpu')
)

pretrained_model_config = BertConfig.from_pretrained(
    args.pretrained_model_name,
    num_labels = fine_tuned_model_ckpt['state_dict']['model.classifier.bias'].shape.numel(),
)

model = BertForSequenceClassification(pretrained_model_config)

model.load_state_dict({k.replace('model.', '') : v for k , v in fine_tuned_model_ckpt['state_dict'].items()})

<All keys matched successfully>

In [5]:
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30000, 768, padding_idx=0)
      (position_embeddings): Embedding(300, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [6]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained(
    args.pretrained_model_name,
    do_lower_case = False
)

In [7]:
def inference_fn(sentence) :
  inputs = tokenizer(
      [sentence],
      max_length = args.max_seq_length,
      padding = 'max_length',
      truncation = True,
  )

  with torch.no_grad() :
    outputs = model(**{k: torch.tensor(v) for k, v in inputs.items()})
    prob = outputs.logits.softmax(dim = 1)
    positive_prob = round(prob[0][1].item(), 4)
    negative_prob = round(prob[0][0].item(), 4)
    pred = '긍정 (positive)' if torch.argmax(prob) == 1 else '부정 (negtive)'

  return {
      'sentence' : sentence,
      'prediction' : pred,
      'positive_data' : f'긍정 {positive_prob}',
      'negative_data' : f'부정 {negative_prob}',
      'positive_width' : f'{positive_prob * 100}%',
      'negative_width' : f'{round(negative_prob * 100.2)}%'
  }

In [8]:
sentence = '기린이 맞다!'

inference_fn(sentence)

{'sentence': '기린이 맞다!',
 'prediction': '긍정 (positive)',
 'positive_data': '긍정 0.6876',
 'negative_data': '부정 0.3124',
 'positive_width': '68.76%',
 'negative_width': '31%'}

In [9]:
!mkdir /root/.ngrok2 && echo 'authtoken: 2T5dj5MGKLVe9gdeTmRn96rqUGK_6bLTw98KhQnQENd1C9FCt' > /root/.ngrok2/ngrok.yml

mkdir: cannot create directory ‘/root/.ngrok2’: File exists


In [10]:
from ratsnlp.nlpbook.classification import get_web_service_app

app = get_web_service_app(inference_fn)

app.run()

 * Serving Flask app 'ratsnlp.nlpbook.classification.deploy'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
INFO:werkzeug:[33mPress CTRL+C to quit[0m


 * Running on http://4e69-34-87-82-132.ngrok-free.app
 * Traffic stats available on http://127.0.0.1:4040


INFO:werkzeug:127.0.0.1 - - [28/Jul/2023 00:30:07] "GET / HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [28/Jul/2023 00:30:08] "[33mGET /favicon.ico HTTP/1.1[0m" 404 -
INFO:werkzeug:127.0.0.1 - - [28/Jul/2023 00:30:16] "POST /api HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [28/Jul/2023 00:30:17] "POST /api HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [28/Jul/2023 00:30:17] "POST /api HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [28/Jul/2023 00:30:22] "POST /api HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [28/Jul/2023 00:30:25] "POST /api HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [28/Jul/2023 00:30:28] "POST /api HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [28/Jul/2023 00:30:31] "POST /api HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [28/Jul/2023 00:30:35] "POST /api HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [28/Jul/2023 00:32:43] "POST /api HTTP/1.1" 200 -
