Prepare (install dependencies, load the model, and define the inference helper):

In [None]:
# install libraries
# These are IPython `!` shell escapes: this cell only runs inside a notebook,
# not as a plain .py file. torch and numpy are not installed here; presumably
# the runtime (e.g. Colab) provides them — confirm.
!pip install flask
!pip install pyngrok
!pip install --upgrade -q accelerate bitsandbytes
# NOTE(review): transformers comes from the GitHub main branch, not a pinned
# release, so behavior can change between runs; consider pinning a version.
!pip install git+https://github.com/huggingface/transformers.git
!pip install python-dotenv
# NOTE(review): the `flash_attn` package is not installed here, although the
# model-loading code may request FlashAttention-2 — confirm it is available.

# Load model:
import importlib.util

import torch
from transformers import BitsAndBytesConfig, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor

# 4-bit NF4 weight quantization with fp16 compute.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model_id = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"

# FlashAttention-2 needs the separate `flash_attn` package, which the install
# cell does not install. Fall back to PyTorch SDPA instead of failing at load
# time. `attn_implementation` also replaces the deprecated
# `use_flash_attention_2` flag.
attn_implementation = (
    "flash_attention_2" if importlib.util.find_spec("flash_attn") is not None else "sdpa"
)

model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
    # The 4-bit config above only takes effect when it is passed here.
    quantization_config=quantization_config,
    attn_implementation=attn_implementation,
)
processor = LlavaOnevisionProcessor.from_pretrained(model_id)
# Pad on the left for batched generation. The default is 'right', which is
# what training uses.
processor.tokenizer.padding_side = "left"


# process image
def comment_on_image(images, conversations, strip_prompt=False):
    """Generate the model's reply for a batch of chat conversations.

    Args:
        images: Image input passed straight to ``processor``, or ``None`` for a
            text-only request. Presumably one image per image placeholder in
            the conversations — TODO confirm against the processor docs.
        conversations: Non-empty list of conversations in the format accepted
            by ``processor.apply_chat_template``.
        strip_prompt: When True, decode only the newly generated tokens. By
            default the decoded text also contains the templated prompt,
            because ``model.generate`` returns prompt + completion token ids.

    Returns:
        The decoded text of the FIRST conversation only. Any other
        conversations in the batch are generated but not returned.

    Raises:
        ValueError: If ``conversations`` is empty.
    """
    if not conversations:
        raise ValueError("conversations must contain at least one conversation")

    generate_kwargs = {"max_new_tokens": 512, "do_sample": True, "top_p": 0.9}
    prompts = [
        processor.apply_chat_template(conversation, add_generation_prompt=True)
        for conversation in conversations
    ]
    # Move the batch to the model's device; fp16 applies to floating-point
    # tensors such as pixel values (token ids stay integer).
    inputs = processor(images=images, text=prompts, padding=True, return_tensors="pt").to(
        model.device, torch.float16
    )
    output = model.generate(**inputs, **generate_kwargs)
    if strip_prompt:
        # With left padding every row's prompt ends at the same column, so
        # dropping the first input_ids.shape[1] tokens leaves only the reply.
        output = output[:, inputs["input_ids"].shape[1]:]
    return processor.decode(output[0], skip_special_tokens=True)

Run server (Flask in a background thread, exposed publicly through an ngrok tunnel):

In [None]:
import threading
from flask import Flask, request, jsonify
from pyngrok import ngrok
import numpy as np
import os
from dotenv import load_dotenv

load_dotenv()
NGROK_AUTH_TOKEN = os.getenv("NGROK_AUTH_TOKEN")
# os.getenv returns None when the variable is unset; only forward a real token.
# Without one, ngrok uses whatever auth token it already has configured.
if NGROK_AUTH_TOKEN:
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)
else:
    print("NGROK_AUTH_TOKEN is not set; using the existing ngrok configuration.")

# Flask app setup
app = Flask(__name__)
# Keep `port` across re-runs of this cell. The stop cell increments it and
# does not stop the Flask thread bound to the previous port.
if 'port' not in globals():
    port = 5000

# Open an ngrok tunnel to the local Flask port.
public_url = ngrok.connect(port).public_url
print(f"Public URL: {public_url}")

# Flask route
@app.route('/data', methods=['POST'])
def process_data():
    """Run the model on a JSON request and return its reply as JSON.

    Expected body (as read below):
      - "conversations": passed unchanged to ``comment_on_image``.
      - "images" (optional): list of nested pixel lists, each converted to a
        ``np.uint8`` array. A missing or empty list means a text-only request.

    Returns ``{"response": <text>}``. A body that is not a JSON object with a
    "conversations" field gets a 400 error instead of an unhandled exception.
    """
    print("processing")
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'conversations' not in data:
        return jsonify({"error": "expected a JSON object with a 'conversations' field"}), 400
    raw_images = data.get('images') or []
    # Log a summary only: printing the raw pixel lists floods the output.
    print(f"Received data: keys={sorted(data)}, {len(raw_images)} image(s)")
    images = [np.array(image, dtype=np.uint8) for image in raw_images]
    reply = comment_on_image(images or None, data['conversations'])
    return jsonify({"response": reply})

# Serve the Flask app from a background thread so this cell returns right
# away. The auto-reloader is turned off for this in-notebook, threaded use.
server_kwargs = dict(port=port, use_reloader=False)
thread = threading.Thread(target=app.run, kwargs=server_kwargs)
thread.start()

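Stop server (closes the ngrok tunnel only; the Flask thread keeps running, so the port is bumped for the next run):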

In [None]:
# Close the public tunnel. Nothing here stops the Flask thread, which keeps
# its port, so advance `port` for the next run of the server cell.
ngrok.disconnect(public_url)
port = port + 1
print(f"Tunnel {public_url} closed.")