In [None]:
!pip install transformers -U
!pip install fastapi nest-asyncio pyngrok uvicorn bitsandbytes huggingface_hub accelerate
!ngrok authtoken xxx  # replace xxx with your ngrok token

# Only need this part for models with restriced access
from huggingface_hub import login
login()

In [1]:
import asyncio
import hashlib
import hmac
import json
import os
from threading import Thread

import nest_asyncio
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pyngrok import conf, ngrok
from starlette.requests import Request
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model loading and streaming generation

In [None]:
model_name = "google/gemma-7b-it"

# Quantize the weights to 4 bits with bitsandbytes to reduce GPU memory use.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)

# Module-level text streamer bound to the tokenizer above.
streamer = TextIteratorStreamer(tokenizer)


async def model_call(text, max_new_tokens=512):
    """Asynchronously stream text generated by ``model`` for the prompt ``text``.

    ``model.generate`` runs in a background thread and pushes decoded chunks into a
    TextIteratorStreamer created for this call only. Each chunk is pulled from the
    streamer in an executor thread, so the event loop is not blocked while waiting
    for the next token.

    Args:
        text: Prompt string passed to the tokenizer as-is.
            NOTE(review): no chat template is applied; instruction-tuned models
            usually expect ``tokenizer.apply_chat_template`` — confirm whether callers
            pre-format the prompt.
        max_new_tokens: Upper bound on the number of generated tokens. Without it,
            generation falls back to the model's generation config, whose
            ``max_length`` defaults to only 20 tokens when the model does not set one.

    Yields:
        Decoded text chunks as they are produced. The prompt and special tokens are
        not echoed back.

    Errors during tokenization or generation are printed and end the stream early;
    once streaming has started, the HTTP status can no longer be changed.
    """
    try:
        # Put the inputs on whatever device the (possibly quantized) model was loaded to.
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
    except Exception as e:
        print(e)
        return

    # Use a fresh streamer per call: one shared streamer would interleave tokens from concurrent requests.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)

    def _generate():
        # Runs in the worker thread.
        try:
            model.generate(**generation_kwargs)
        except Exception as e:
            print(e)
            # generate() only signals end-of-stream on success; without this the
            # consumer loop below would wait on the streamer forever.
            streamer.end()

    Thread(target=_generate, daemon=True).start()

    loop = asyncio.get_running_loop()
    chunks = iter(streamer)
    done = object()  # sentinel: StopIteration cannot be propagated through an executor future
    while True:
        new_text = await loop.run_in_executor(None, next, chunks, done)
        if new_text is done:
            break
        yield new_text

# API key authentication and HTTP server

In [3]:
def hash_key(key):
    """Return the lowercase hex SHA-256 digest of ``key`` (encoded with str.encode's default UTF-8)."""
    digest = hashlib.sha256(key.encode())
    return digest.hexdigest()

async def get_hashed_key(api_key: str):
    """Hash the client-supplied API key so it can be compared with ``ALLOWED``.

    Raises:
        HTTPException: status 401 when ``api_key`` is ``None`` or empty.
    """
    if api_key:
        return hash_key(api_key)
    raise HTTPException(status_code=401, detail="Missing X-API-KEY header")

# SHA-256 hex digest of the accepted API key; the key itself is never stored here.
# Set the ALLOWED_API_KEY_SHA256 environment variable, or replace the "" default.
# With the "" default no key can match, because a SHA-256 hexdigest is always 64
# characters, so every request is rejected.
# The value is lowercased because hashlib's hexdigest() is lowercase.
ALLOWED = os.environ.get("ALLOWED_API_KEY_SHA256", "").strip().lower()

In [None]:
app = FastAPI()

# Authentication uses the X-API-KEY header, not cookies, so cross-origin callers do
# not need credentialed requests. Combining allow_origins=['*'] with
# allow_credentials=True makes Starlette reflect the caller's Origin and allow
# credentialed requests from any site. With credentials off, it answers with a plain
# "Access-Control-Allow-Origin: *" instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=False,
    allow_methods=['*'],
    allow_headers=['*'],
)

def _extract_content(raw_body):
    """Parse a request body (bytes) and return its ``"content"`` prompt string.

    Accepts either a JSON object or a JSON string that itself encodes the object
    (the double-encoded form the original handler required).

    Raises:
        HTTPException: status 400 if the body is not JSON, is not a JSON object,
            or has no string ``"content"`` field.
    """
    try:
        payload = json.loads(raw_body)
        if isinstance(payload, str):
            # Double-encoded body: unwrap the inner JSON document.
            payload = json.loads(payload)
    except ValueError as exc:  # covers JSONDecodeError and UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Request body must be JSON") from exc
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    content = payload.get("content")
    if not isinstance(content, str):
        raise HTTPException(status_code=400, detail='Missing or non-string "content" field')
    return content


# POST is accepted alongside GET because the prompt travels in the request body,
# which many HTTP clients and proxies drop from GET requests.
@app.api_route('/', methods=['GET', 'POST'])
async def root(request: Request):
    """Stream a model completion for an authenticated request.

    Expects an ``X-API-KEY`` header whose SHA-256 hex digest equals ``ALLOWED`` and
    a JSON body with a ``"content"`` prompt string.

    Returns:
        A ``text/event-stream`` StreamingResponse of generated text chunks, or a
        401 JSON response ``{"message": "Unauthorized access!"}`` for a wrong key.

    Raises:
        HTTPException: 401 when the X-API-KEY header is missing (from
            get_hashed_key); 400 when the body is malformed.
    """
    api_key = request.headers.get("X-API-KEY")
    hashed_key = await get_hashed_key(api_key)

    # Constant-time comparison so response timing does not reveal partial matches.
    if not hmac.compare_digest(hashed_key.encode(), ALLOWED.encode()):
        # Use 401 (not 200) so clients can distinguish a rejected key from success.
        return JSONResponse(status_code=401, content={"message": "Unauthorized access!"})

    msg = _extract_content(await request.body())
    return StreamingResponse(model_call(msg), media_type="text/event-stream")

PORT = 8000  # local port uvicorn listens on and ngrok forwards to

ngrok_tunnel = ngrok.connect(PORT)
print('Public URL:', ngrok_tunnel.public_url)

# uvicorn.run starts its own event loop while the notebook kernel already runs one;
# nest_asyncio patches asyncio so the nested loop is allowed.
nest_asyncio.apply()
try:
    uvicorn.run(app, port=PORT)
finally:
    # Close the public tunnel when the server stops (e.g. on a kernel interrupt), so
    # re-running this cell does not leave the previous tunnel open.
    ngrok.disconnect(ngrok_tunnel.public_url)
