In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory and print the full path of every file found.
input_file_paths = (
    os.path.join(folder, name)
    for folder, _, names in os.walk('/kaggle/input')
    for name in names
)
for input_file_path in input_file_paths:
    print(input_file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import string
from transformers import BertTokenizer, BertForMaskedLM

In [3]:
!pip install fastapi nest-asyncio pyngrok uvicorn

Collecting fastapi
  Downloading fastapi-0.115.11-py3-none-any.whl.metadata (27 kB)
Collecting pyngrok
  Downloading pyngrok-7.2.3-py3-none-any.whl.metadata (8.7 kB)
Collecting uvicorn
  Downloading uvicorn-0.34.0-py3-none-any.whl.metadata (6.5 kB)
Collecting starlette<0.47.0,>=0.40.0 (from fastapi)
  Downloading starlette-0.46.1-py3-none-any.whl.metadata (6.2 kB)
Downloading fastapi-0.115.11-py3-none-any.whl (94 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m94.9/94.9 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyngrok-7.2.3-py3-none-any.whl (23 kB)
Downloading uvicorn-0.34.0-py3-none-any.whl (62 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.3/62.3 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading starlette-0.46.1-py3-none-any.whl (71 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.0/72.0 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: uvicorn, pyngrok, s

In [4]:
# Minimum number of raw candidates taken from the model at the masked position
# before punctuation/padding tokens are filtered out.
TOP_K = 10


class Predictor:
    """Next-word predictor built on a BERT masked language model.

    A ``<mask>`` placeholder is appended to the input text and the model's
    highest-scoring tokens for that position are returned as candidate
    next words.
    """

    def __init__(self, model_path=None):
        """Load the tokenizer and the masked-LM model (in eval mode).

        Args:
            model_path: Optional checkpoint name or directory passed to
                ``from_pretrained``. ``None`` (the default) loads
                ``'bert-base-uncased'``.
        """
        checkpoint = model_path or 'bert-base-uncased'
        self.bert_tokenizer = BertTokenizer.from_pretrained(checkpoint)
        self.bert_model = BertForMaskedLM.from_pretrained(checkpoint).eval()

    def decode(self, tokenizer, pred_idx, top_clean):
        """Turn candidate token ids into a newline-separated list of words.

        Args:
            tokenizer: Object whose ``decode(token_id)`` returns the token text.
            pred_idx: Iterable of candidate token ids, best candidate first.
            top_clean: Maximum number of words to keep after filtering.

        Returns:
            The first ``top_clean`` surviving words joined with ``'\\n'``
            (an empty string when nothing survives). Empty tokens, the
            ``[PAD]`` token and tokens made up only of punctuation are
            dropped; the ``##`` word-piece marker is removed.
        """
        punctuation = set(string.punctuation)
        tokens = []
        for token_id in pred_idx:
            # Remove any whitespace contained in the decoded text.
            token = ''.join(tokenizer.decode(token_id).split())
            # Exact comparisons: a substring test against one long string of
            # ignored characters would also discard real words such as "A".
            if not token or token == '[PAD]':
                continue
            if all(char in punctuation for char in token):
                continue
            tokens.append(token.replace('##', ''))
        return '\n'.join(tokens[:top_clean])

    def encode(self, tokenizer, text_sentence, add_special_tokens=True):
        """Tokenize ``text_sentence`` and locate its mask position.

        Args:
            tokenizer: Tokenizer providing ``mask_token``, ``mask_token_id``
                and ``encode``.
            text_sentence: Text containing a ``<mask>`` placeholder (or the
                tokenizer's own mask token).
            add_special_tokens: Forwarded to ``tokenizer.encode``.

        Returns:
            ``(input_ids, mask_idx)``: a ``(1, sequence_length)`` tensor of
            token ids and the index of the first mask token in the sequence.

        Raises:
            ValueError: If the text contains no mask token.
        """
        text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
        words = text_sentence.split()
        # A trailing mask gets a closing period appended after it.
        if words and words[-1] == tokenizer.mask_token:
            text_sentence += ' .'
        input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
        mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
        if not mask_positions:
            raise ValueError("text_sentence must contain a '<mask>' placeholder")
        return input_ids, mask_positions[0]

    def get_all_predictions(self, text_sentence, top_clean=5):
        """Predict candidate words for the masked position of ``text_sentence``.

        Args:
            text_sentence: Text containing a ``<mask>`` placeholder.
            top_clean: Maximum number of words to return.

        Returns:
            Up to ``top_clean`` predicted words, best first, joined with
            ``'\\n'``.

        Raises:
            ValueError: If the text contains no mask token.
        """
        input_ids, mask_idx = self.encode(self.bert_tokenizer, text_sentence)
        with torch.no_grad():
            predict = self.bert_model(input_ids)[0]
        mask_scores = predict[0, mask_idx, :]
        # Fetch at least ``top_clean`` candidates so requests larger than
        # TOP_K are not silently capped, but never more than there are scores.
        k = min(max(TOP_K, top_clean), mask_scores.shape[-1])
        candidate_ids = mask_scores.topk(k).indices.tolist()
        return self.decode(self.bert_tokenizer, candidate_ids, top_clean)

    def gen_m_words_n_predictions(self, m, n, input_text):
        """Generate up to ``n`` continuations of ``input_text``, ``m`` words each.

        The first word of every continuation is one of the ``n`` best
        predictions for the next word; each continuation is then extended
        greedily (single best prediction) until it has ``m`` added words.

        Args:
            m: Number of words to add per continuation (values below 1 still
                add the first predicted word).
            n: Number of alternative continuations requested.
            input_text: Text to continue.

        Returns:
            A list of strings, each ``input_text`` followed by the generated
            words. It may hold fewer than ``n`` items when fewer clean
            candidates were predicted.
        """
        output = []
        first_words = self.get_all_predictions(input_text + ' <mask>', top_clean=n).split('\n')
        for first_word in first_words:
            if not first_word:
                # ''.split('\n') gives [''] when no clean candidate survived.
                continue
            sentence = input_text + ' ' + first_word
            for _ in range(m - 1):
                next_word = self.get_all_predictions(sentence + ' <mask>', top_clean=1).split('\n')[0]
                if not next_word:
                    # Nothing usable predicted; stop instead of appending blanks.
                    break
                sentence = sentence + ' ' + next_word
            output.append(sentence)
        return output

In [6]:
from fastapi import HTTPException
from pydantic import BaseModel

# Create one module-level Predictor instance; its constructor loads the BERT
# tokenizer and model, and the request handler below reuses this instance.
nextWord = Predictor()

# Define a Pydantic model for input data validation
class NextWordInput(BaseModel):
    # Request body for the prediction endpoint.
    # text: the text to continue.
    text: str
    # predictions: how many alternative continuations to generate
    # (passed as ``n`` to Predictor.gen_m_words_n_predictions).
    predictions: int
    # tokens: how many words each continuation should add
    # (passed as ``m`` to Predictor.gen_m_words_n_predictions).
    tokens: int

# GET method to check service status
def get_service_status():
    """Return the fixed payload that reports the service as running."""
    status_payload = dict(status="success", message="Service is running")
    return status_payload

# POST method to generate next word predictions
def get_next_words(data: NextWordInput):
    """Generate next-word continuations for a validated request.

    Args:
        data: Request body with the input ``text``, the number of
            ``predictions`` (alternative continuations) and the number of
            ``tokens`` (words per continuation).

    Returns:
        ``{"status": "success", "words": [...]}`` where ``words`` is the list
        of generated continuations.

    Raises:
        HTTPException: 422 when ``tokens`` or ``predictions`` is below 1;
            500 when prediction fails for any other reason.
    """
    # Validate outside the try block so this client error is not caught by
    # the generic handler below and re-wrapped as a 500.
    if data.tokens < 1 or data.predictions < 1:
        raise HTTPException(
            status_code=422,
            detail="'tokens' and 'predictions' must both be at least 1",
        )
    try:
        result = nextWord.gen_m_words_n_predictions(data.tokens, data.predictions, data.text)
    except Exception as e:
        # Chain the original exception so the server-side traceback keeps the root cause.
        raise HTTPException(status_code=500, detail=f"Something went wrong: {type(e).__name__} {e}") from e
    return {"status": "success", "words": result}


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another archite

In [10]:
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import nest_asyncio
from pyngrok import ngrok
import uvicorn

app = FastAPI()

# The API is open to any origin. Credentialed cross-origin requests are
# disabled: wildcard origins/headers/methods must not be combined with
# allow_credentials=True, and none of the routes use cookies or
# authentication headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins = ["*"],
    allow_credentials = False,
    allow_headers = ["*"],
    allow_methods = ["*"],
)

@app.get("/")
def read_root():
    # Health-check route: respond with the shared service-status payload.
    status_payload = get_service_status()
    return status_payload

@app.post("/predict")
def post_next_word(data: NextWordInput):
    # Declared with plain `def` on purpose: model inference is blocking,
    # CPU-bound work, and FastAPI runs non-async handlers in its threadpool
    # instead of on the event loop, so other requests are not stalled.
    return get_next_words(data)


In [12]:
# Print the ngrok CLI usage text (informational only).
!ngrok help

NAME:
  ngrok - tunnel local ports to public URLs and inspect traffic

USAGE:
  ngrok [command] [flags]

DESCRIPTION: 
  ngrok exposes local networked services behinds NATs and firewalls to the
  public internet over a secure tunnel. Share local websites, build/test
  webhook consumers and self-host personal services.
  Detailed help for each command is available by adding '--help' to any command or with
  the 'ngrok help' command.
  Open https://dashboard.ngrok.com/obs/traffic-inspector to inspect traffic.


TERMS OF SERVICE: https://ngrok.com/tos

EXAMPLES: 
  ngrok http 80                           # secure public URL for port 80 web server
  ngrok http --url baz.ngrok.dev 8080     # port 8080 available at baz.ngrok.dev
  ngrok http foo.dev:80                   # tunnel to host:port instead of localhost
  ngrok http https://localhost            # expose a local https server
  ngrok tcp 22                            # tunnel arbitrary TCP traffic to port 22
  ngrok tls --url=foo.com 

In [13]:
# SECURITY: never commit the ngrok authtoken with the notebook. It is read
# from the NGROK_AUTHTOKEN environment variable, falling back to a hidden
# interactive prompt. The token that was previously hard-coded in this cell
# has been exposed and should be revoked in the ngrok dashboard.
import getpass

ngrok_authtoken = os.environ.get("NGROK_AUTHTOKEN") or getpass.getpass("ngrok authtoken: ")
ngrok.set_auth_token(ngrok_authtoken)

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml


In [14]:
port = 8000

# Open the public tunnel before starting the server, because uvicorn.run
# blocks this cell until the server stops.
ngrok_tunnel = ngrok.connect(port)
print('Public URL:', ngrok_tunnel.public_url)

# Patch asyncio so uvicorn can run inside the notebook's already-running event loop.
nest_asyncio.apply()

try:
    uvicorn.run(app, port=port)
finally:
    # Close the tunnel once the server stops (or fails to start) so the
    # public URL does not stay open pointing at a dead port.
    ngrok.disconnect(ngrok_tunnel.public_url)

Public URL: https://8cb5-34-32-170-53.ngrok-free.app


INFO:     Started server process [31]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)


INFO:     105.115.0.175:0 - "POST /predict HTTP/1.1" 422 Unprocessable Entity
INFO:     105.115.0.175:0 - "POST /predict HTTP/1.1" 422 Unprocessable Entity
INFO:     105.115.0.175:0 - "POST /predict HTTP/1.1" 200 OK
INFO:     105.115.0.175:0 - "POST /predict HTTP/1.1" 200 OK
INFO:     105.115.0.175:0 - "POST /predict HTTP/1.1" 200 OK
INFO:     105.115.0.175:0 - "POST /predict HTTP/1.1" 200 OK
INFO:     105.115.0.175:0 - "POST /predict HTTP/1.1" 200 OK


INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [31]
