# Run local predictions

Adapted from: https://www.kaggle.com/code/nulldata/fine-tuning-gpt-2-to-generate-netlfix-descriptions/notebook

That notebook was in turn based on: https://medium.com/geekculture/fine-tune-eleutherai-gpt-neo-to-generate-netflix-movie-descriptions-in-only-47-lines-of-code-40c9b4c32475

## Setup

In [1]:
# Detect whether the notebook is running on Google Colab so that later cells
# can branch on IN_COLAB.
try:
    import google.colab  # noqa: F401 -- imported only to probe availability
    IN_COLAB = True
except ImportError:
    # Only a missing module means "not on Colab". A bare `except:` would also
    # swallow KeyboardInterrupt and unrelated failures.
    IN_COLAB = False

In [2]:
%%capture
# Colab-only bootstrap: free disk space, install dependencies, mount Google
# Drive and load environment variables. Nothing below runs when IN_COLAB is False.
if IN_COLAB:
    
    # Remove unneeded Python 2.7 installs to free disk space.
    !rm -rf "/usr/local/lib/python2.7"
    !rm -rf "/usr/lib/python2.7"

    # Clone the repo.
    # !git clone ""

    # Change the working directory to the repo root.
    # %cd

    # Add the repo root to the Python path.
    # import sys, os
    # sys.path.append(os.getcwd())
    
    # Install packages that are not preinstalled on Colab.
    !pip install python-dotenv
    !pip install transformers
    !pip install transformers[onnx]
    !pip install optimum --upgrade
    !pip install optimum[onnxruntime] --upgrade
    !pip install datasets
    !pip install wandb --upgrade
    !pip install fastapi pyngrok nest_asyncio uvicorn httpx
    # !pip install pandas-profiling --upgrade

    # Mount Google Drive to access the env file.
    from google.colab import drive
    drive.mount('/content/gdrive')

    # Load the env file.
    # NOTE: Google Drive will not let you mount dotfiles, hence the non-dot file name.
    # NOTE(review): the path below is relative while the mount point is absolute;
    # presumably the working directory is /content -- confirm.
    from dotenv import load_dotenv
    load_dotenv("./gdrive/MyDrive/my_env_file")

In [3]:
# Log in to Weights & Biases before any run is created.
import wandb
wandb.login()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33ma-sh0ts[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
# Configuration dict attached to the W&B run; currently empty.
_model_conf = dict()

In [5]:
# Identity of the W&B run: the project it is logged under, the job type,
# and the display name of this particular run.
project_name = "gpt2-netflix"
run_type = "inference"
run_name = "infer-generative-netflix"

In [6]:
# Start the W&B run that the artifact downloads and the logging below attach to.
run = wandb.init(
    project=project_name,
    name=run_name,
    job_type=run_type,
    config=_model_conf,
)

In [7]:
from google.colab import auth
auth.authenticate_user()

# https://cloud.google.com/resource-manager/docs/creating-managing-projects
project_id = 'wandb-growth'
!gcloud config set project {project_id}

Updated property [core/project].


In [8]:
# Declare the model artifact as an input of this run and download it locally.
model_artifact = run.use_artifact("generative-netflix:latest")
model_art_path = model_artifact.download()

# Same for the tokenizer artifact.
tokenizer_artifact = run.use_artifact("gpt2-netflix-tokenizer:latest")
tokenizer_path = tokenizer_artifact.download()

[34m[1mwandb[0m: Downloading large artifact generative-netflix:latest, 540.41MB. 2 files... Done. 0:0:8.3


In [9]:
import os
# File locations inside the downloaded model artifact directory.
# NOTE(review): `model_path` names "model.onnx", while the model-loading cell
# requests "model_quantized.onnx" and never reads `model_path` -- confirm which
# file is intended.
config_path = os.path.join(model_art_path, "ort_config.json")
model_path = os.path.join(model_art_path, "model.onnx")

In [11]:
# BUG/workaround: the config file is expected to be named config.json, but the
# artifact ships ort_config.json. For now, parse it into a dict here so the
# config can be passed to the model loader manually.
import json

with open(config_path, 'r') as config_file:
    config = json.load(config_file)

In [12]:
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM
from optimum.pipelines import pipeline


# Load the ONNX model file from the artifact directory, passing the config dict
# parsed from ort_config.json explicitly.
model = ORTModelForCausalLM.from_pretrained(
    model_art_path,
    file_name="model_quantized.onnx",
    config=config,
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

# Combine model and tokenizer into a text-generation pipeline used for all
# predictions below.
inference_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


Moving 0 files to the new cache system


0it [00:00, ?it/s]



In [13]:
# Prompt that the model is asked to continue.
query = "Weights and Biases is about"

In [14]:
# Run the text-generation pipeline on the prompt.
result = inference_pipeline(query)



In [15]:
# Display the pipeline output: a list of dicts with a 'generated_text' entry.
result

[{'generated_text': 'Weights and Biases is about the people who make the world a better place, and the people'}]

In [16]:
# Record the prompt and the generated output on the W&B run.
# NOTE(review): the run summary in the saved output lists only `query`; confirm
# that the list-of-dicts `result` is stored the way you expect.
inference_record = {"query": query, "result": result}
run.log(inference_record)

In [17]:
# Close the W&B run; the web service below does not log to it.
run.finish()

VBox(children=(Label(value='0.169 MB of 0.169 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
query,Weights and Biases i...


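Let's make a web service! The cells below wrap the inference pipeline in a FastAPI app and expose it publicly through an ngrok tunnel.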

In [18]:
%%writefile models.py
from pydantic import BaseModel

# Request body schema for the prediction API.
class Input(BaseModel):
    # Prompt text that the API handler forwards to the text-generation pipeline.
    QUERY: str

Writing models.py


In [19]:
from models import Input

from concurrent.futures import ProcessPoolExecutor

import nest_asyncio
import uvicorn
from fastapi import FastAPI
from pyngrok import ngrok

# FastAPI application that the route handlers below are registered on.
app = FastAPI()
# NOTE(review): `exc` is not referenced by any handler in the visible cells.
exc = ProcessPoolExecutor(max_workers=1)  # To handle pred-time logs out of process

# Example (disabled): how a model could be loaded in a startup event instead of
# relying on the notebook-level `inference_pipeline`.
# NOTE(review): the helper and artifact names below do not exist in this
# notebook; the snippet appears to be carried over from another project.
# @app.on_event("startup")
# def load_model():
#     global model

#     prod_model_path = get_prod_model_from_wandb(dataset_name, group_id)
#     model_path = Path(prod_model_path)
#     if model_path:
#         print(model_path)
#         model = load_model("./artifacts/credit_model_artifacts:v1/model")
#         # model = load_model(model_path.replace(".pkl", ""))
#     else:
#         print(model_path)

# Root endpoint: returns a fixed payload, handy for checking that the server is up.
@app.get("/")
def read_root():
    greeting = {"hello": "world"}
    return greeting

@app.post("/api", tags=["prediction"])
async def get_predictions(input_dict: Input):
    """Generate text for the submitted prompt.

    Returns ``{"result": <pipeline output>}`` on success and
    ``{"result": "error"}`` if generation fails.
    """
    try:
        # Read the validated field directly; no need to round-trip through a dict.
        query = input_dict.QUERY
        result = inference_pipeline(query)
        return {"result": result}
    except Exception as e:
        # Python 3 exceptions have no `.message` attribute; reading it raised
        # AttributeError inside this handler, so the error payload below was
        # never returned. Print the exception itself instead.
        print(e)
        print(e.args)
        print("Something went wrong!")
        return {"result": "error"}

In [23]:
# Register the ngrok auth token with the ngrok CLI.
# NOTE(review): NGROK_AUTH is presumably provided by the env file loaded in the
# setup cell -- confirm it is set before running this cell.
!ngrok authtoken $NGROK_AUTH

Authtoken saved to configuration file: /root/.ngrok2/ngrok.yml


In [24]:
# Open an ngrok tunnel to local port 8000, the port the server is started on below.
tunnel = ngrok.connect(8000)
# Display the tunnel so its public URL is visible in the cell output.
tunnel

<NgrokTunnel: "http://c5f9-34-82-231-49.ngrok.io" -> "http://localhost:8000">

In [None]:
# Allow nested use of the event loop so uvicorn can start from inside the notebook.
nest_asyncio.apply()
# Serve the app on port 8000 (the port the ngrok tunnel forwards to).
# This call blocks the cell until the server is stopped.
uvicorn.run(app, port=8000)

INFO:     Started server process [69]
INFO:uvicorn.error:Started server process [69]
INFO:     Waiting for application startup.
INFO:uvicorn.error:Waiting for application startup.
INFO:     Application startup complete.
INFO:uvicorn.error:Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
INFO:uvicorn.error:Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)


INFO:     88.30.27.250:0 - "GET / HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "GET /favicon.ico HTTP/1.1" 404 Not Found
INFO:     88.30.27.250:0 - "GET / HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "GET /docs HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "GET /openapi.json HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "POST /api HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "POST /api HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "POST /api HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "POST /api HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "POST /api HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "POST /api HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "POST /api HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "POST /api HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "POST /api HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "POST /api HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "GET /docs HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "GET /openapi.json HTTP/1.1" 200 OK
INFO:     88.30.27.250:0 - "POST /api HTTP/