- `MODEL` is the Hugging Face repo that the processor and the model are downloaded from.
- `CUDA` defines whether the processed inputs (`input_ids` and any image tensors) are moved to the `"cuda"` device before generation (significant speedup on NVIDIA GPUs).
- `USER` defines the user tag (usage defined in `ROLE`).
- `ASSISTANT` defines the assistant tag (usage defined in `ROLE`).
- `START_HEADER` defines the start of the header tag (usage defined in `ROLE`).
- `END_HEADER` defines the end of the header tag (usage defined in `ROLE`).
- `ROLE` is a function for concatenating the header tags.
- `START_TEXT` defines the start of the text tag (usage defined in `TEXT`).
- `END_TEXT` defines the end of the text tag (usage defined in `TEXT`).
- `TEXT` is a function for concatenating the text tags.
- `IMAGE` defines the image tag; leave it blank for non-vision models.
- `prompt` defines the beginning of the prompt; leave it blank for no start tag.
- `image` is the list of images accumulated over the conversation; it should be left untouched (an empty list).

In [None]:
# Hugging Face repo id passed to both `from_pretrained` calls.
MODEL = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit"
# When true, the processed inputs are moved to the "cuda" device before generation.
CUDA = True

# Role names placed between the header tags.
USER = "user"
ASSISTANT = "assistant"

# Delimiters wrapped around a role name.
START_HEADER = "<|start_header_id|>"
END_HEADER = "<|end_header_id|>\n\n"


def ROLE(str):
    """Return the role name `str` wrapped in the header tags.

    The parameter name is kept as-is so existing callers are unaffected.
    """
    return START_HEADER + str + END_HEADER


# Delimiters wrapped around the body of a message.
START_TEXT = ""
END_TEXT = "<|eot_id|>"


def TEXT(str):
    """Return the message body `str` wrapped in the text tags.

    The parameter name is kept as-is so existing callers are unaffected.
    """
    return START_TEXT + str + END_TEXT


# Placeholder inserted into the prompt once per attached image.
IMAGE = "<|image|>"

# Conversation state: the running prompt text and the images attached so far.
prompt = "<|begin_of_text|>"
image = []

In [None]:
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import re

In [None]:
# Load the processor first, then the model, both from the repo named by MODEL.
processor, model = (
    loader.from_pretrained(MODEL)
    for loader in (AutoProcessor, AutoModelForImageTextToText)
)

In [None]:
def generate(msg, images):
    """Run one chat turn and return the model's decoded reply.

    The user turn (role header, one IMAGE tag per new image, the message
    text) and an assistant header are appended to the running conversation,
    the model generates up to 256 new tokens, and the reply is recorded in
    the conversation history.

    The global ``prompt`` and ``image`` are only updated after generation
    succeeds, so an exception raised by the processor or the model leaves
    the conversation history exactly as it was before the call.

    Args:
        msg: Text of the user's message.
        images: Iterable of image objects to attach to this turn; may be
            empty.

    Returns:
        The decoded text of the newly generated tokens. No
        ``skip_special_tokens`` option is passed to the decoder, so the
        string may still contain special tokens such as END_TEXT.
    """
    global prompt

    # Work on local copies; the globals are committed only on success.
    new_images = list(images)
    all_images = image + new_images
    turn_prompt = "".join(
        (prompt, ROLE(USER), IMAGE * len(new_images), TEXT(msg), ROLE(ASSISTANT))
    )

    # The `images` argument is only passed when at least one image exists.
    if all_images:
        inputs = processor(text=turn_prompt, images=all_images, return_tensors="pt")
    else:
        inputs = processor(text=turn_prompt, return_tensors="pt")

    if CUDA:
        inputs = inputs.to("cuda")

    output = model.generate(**inputs, max_new_tokens=256)

    # Keep only the tokens that follow the prompt in the generated sequence.
    prompt_len = inputs.input_ids.shape[-1]
    generated_ids = output[:, prompt_len:]
    reply = processor.batch_decode(
        generated_ids, clean_up_tokenization_spaces=False
    )[-1]

    # Commit the turn. `image` is extended in place so that any other
    # reference to the same list keeps seeing the full history.
    image.extend(new_images)
    prompt = turn_prompt + reply
    if not reply.endswith(END_TEXT):
        # A reply cut off by max_new_tokens has no end tag; close the
        # assistant turn so the next user header starts a well-formed turn.
        prompt += END_TEXT

    return reply

In [None]:
# Matches a Windows-style path on drives A-H that has a file extension,
# optionally starting with a double quote. The match ends at the first double
# quote or space after the extension, or at the end of the message so that a
# path typed last (with nothing after it) is still detected.
REGEXPR = r"\"?[ABCDEFGH]:[\\/].+?\..+?(?:[\" ]|$)"
_IMAGE_PATH = re.compile(REGEXPR)

# Interactive chat loop: read a message, pull image paths out of it, send the
# remaining text plus the opened images to the model, and print the reply.
while True:
    try:
        msg = input()
    except (EOFError, KeyboardInterrupt):
        # Closed stdin or an interrupt ends the chat without a traceback.
        break

    # A match can include the surrounding quotes and a trailing space;
    # remove them before treating it as a file path.
    imagepaths = [
        found.replace("\"", "").strip() for found in _IMAGE_PATH.findall(msg)
    ]
    try:
        images = [Image.open(path) for path in imagepaths]
    except OSError as err:
        # Missing or unreadable image: report it and ask for a new message
        # instead of terminating the whole session.
        print(f"Could not open image: {err}")
        continue
    msg = _IMAGE_PATH.sub("", msg)

    output = generate(msg, images)

    print(output)