In [None]:
!pip install -U mlx-vlm gradio

In [12]:
import mlx.core as mx
from mlx_vlm import load, generate
from mlx_vlm.utils import generate_step, sample, prepare_inputs, load_config, load_image_processor, get_model_path

In [2]:
# Resolve the model identifier once; the resolved path is reused for both the
# weights/processor and the config below.
model_path = get_model_path("mlx-community/nanoLLaVA")

# Model and processor come from the same resolved path.
model, processor = load(model_path)

# The image processor is built from the model's config.
config = load_config(model_path)
image_processor = load_image_processor(config)

Fetching 11 files:   0%|          | 0/11 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# Format a single user turn with the processor's chat template. The "<image>"
# tag precedes the question text. No interpolation is needed here, so this is
# a plain string rather than an f-string.
prompt = processor.apply_chat_template(
    [{"role": "user", "content": "<image>\nWhat's so funny about this image?"}],
    tokenize=False,
    add_generation_prompt=True,
)

# NOTE: this call passes `verbose`, so it expects `generate` to be the function
# imported from `mlx_vlm`. The streaming `generate` defined later in this
# notebook does not accept `verbose`; run this cell before that definition.
output = generate(model, processor, "./assets/image.png", prompt, image_processor, verbose=False)

Image: ./assets/image.png 

Prompt: <|im_start|>system
Answer the questions.<|im_end|><|im_start|>user
<image>
What's so funny about this image?<|im_end|><|im_start|>assistant

This image is quite amusing. It's a humorous drawing of a man holding a large, yellow, and black balloon. The man is wearing a white shirt and has a white beard. The balloon is so large that it almost covers the man's entire body. The man's hand is also visible, and it seems like he's holding the balloon with both hands.
Prompt: 77.345 tokens-per-sec
Generation: 46.756 tokens-per-sec


"This image is quite amusing. It's a humorous drawing of a man holding a large, yellow, and black balloon. The man is wearing a white shirt and has a white beard. The balloon is so large that it almost covers the man's entire body. The man's hand is also visible, and it seems like he's holding the balloon with both hands."

In [14]:
import time
from typing import Optional

def generate(
    model,
    processor,
    image: str,
    prompt: str,
    image_processor = None,
    temp: float = 0.0,
    max_tokens: int = 100,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    top_p: float = 1.0,
):
    """Stream the model's answer about ``image`` as incremental text segments.

    This is a generator: iterate over it and concatenate the yielded strings
    to obtain the full response. Note that defining it rebinds the name
    ``generate`` that was imported from ``mlx_vlm`` at the top of the notebook.

    Args:
        model: Callable as ``model(input_ids, pixel_values)`` returning
            ``(logits, cache)``; its ``language_model`` attribute is handed
            to ``generate_step``.
        processor: Provides ``detokenizer``. It is also used as the tokenizer
            when ``image_processor`` is given; otherwise ``processor.tokenizer``
            is used.
        image: Image reference forwarded to ``prepare_inputs``.
        prompt: Prompt text forwarded to ``prepare_inputs``.
        image_processor: Optional separate image processor.
        temp: Sampling temperature forwarded to ``sample``/``generate_step``.
        max_tokens: Maximum number of tokens taken from ``generate_step``.
        repetition_penalty: Forwarded to ``generate_step``.
        repetition_context_size: Forwarded to ``generate_step``.
        top_p: Forwarded to ``sample``/``generate_step``.

    Yields:
        str: Non-empty text decoded since the previous yield.
    """
    # With a separate image processor, `processor` itself acts as the tokenizer.
    tokenizer = processor if image_processor is not None else processor.tokenizer

    input_ids, pixel_values = prepare_inputs(image_processor, processor, image, prompt)
    logits, cache = model(input_ids, pixel_values)
    logits = logits[:, -1, :]  # keep only the last position's logits
    first_token, _ = sample(logits, temp, top_p)

    detokenizer = processor.detokenizer
    detokenizer.reset()
    detokenizer.add_token(first_token.item())

    steps = generate_step(
        model.language_model,
        logits,
        cache,
        temp,
        repetition_penalty,
        repetition_context_size,
        top_p,
    )

    # `range` goes first so that zip stops *before* pulling another token from
    # the step generator once the budget is used up; with the arguments the
    # other way round one extra generation step is computed and thrown away.
    for _step, (token, _prob) in zip(range(max_tokens), steps):
        token = token.item()

        if token == tokenizer.eos_token_id:
            break

        detokenizer.add_token(token)
        segment = detokenizer.last_segment
        if segment:
            yield segment

    # Finalize exactly once, after the last token, then flush whatever text is
    # still pending. Finalizing after every token forces each token to be
    # decoded on its own, which can corrupt characters that span several
    # tokens; it also meant the first sampled token's text was never emitted
    # when the loop ended immediately.
    detokenizer.finalize()
    tail = detokenizer.last_segment
    if tail:
        yield tail

In [19]:
import gradio as gr

def generate_response(message, history):
    """Gradio chat callback: stream the model's answer about an uploaded image.

    Args:
        message: Multimodal chat message; ``message["text"]`` is the user's
            question and ``message["files"]`` the list of uploaded files.
            Only the first file is used.
        history: Previous conversation turns (not used; each message is
            answered on its own).

    Yields:
        str: The response accumulated so far, so the UI can update as text
        is generated.
    """
    files = message["files"]
    if not files:
        # A text-only message would otherwise fail with IndexError on files[0].
        yield "Please attach an image to your message."
        return

    user_text = message["text"]
    prompt = processor.apply_chat_template(
        [{"role": "user", "content": f"<image>\n{user_text}"}],
        tokenize=False,
        add_generation_prompt=True,
    )

    response = ""
    for chunk in generate(model, processor, files[0], prompt, image_processor):
        response += chunk
        yield response

# Multimodal chat UI: each submitted message is handled by `generate_response`.
demo = gr.ChatInterface(
    fn=generate_response,
    title="MLX-VLM Bot",
    multimodal=True,
)

# Start the app server.
demo.launch()

Running on local URL:  http://127.0.0.1:7869

To create a public link, set `share=True` in `launch()`.


