
# Streaming Multimodal Outputs with Hugging Face Transformers
**Based on:** *How to Stream Multimodal Outputs with Hugging Face Transformers (Bonus: Gradio Integration)* — Youssef Ghaoui (Medium)


This notebook reproduces the core examples from the article: a blocking (traditional) generation example, a streaming generation example using `TextIteratorStreamer`, and a Gradio demo integration. Set `HF_TOKEN` in your environment (or in a `.env` file) and adjust the model/device settings as needed for your environment.



## Environment / Install
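Run the following in a notebook cell to install the required packages (in a terminal, drop the leading `%` and run plain `pip install ...`). Install the `torch` build that matches your CUDA version separately.
```bash
# Example (adjust versions for your setup); python-dotenv provides the `dotenv` module imported below
%pip install transformers==4.50.1 accelerate==0.26.0 gradio==4.44.1 Pillow requests python-dotenv
```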
**Note:** The notebook uses `google/gemma-3-4b-it` as an example; ensure you have access and sufficient GPU memory. Consider using a smaller model for testing.


In [None]:

# Imports and Hugging Face token setup
import os
import dotenv
dotenv.load_dotenv()  # load HF_TOKEN from .env if present

# Hugging Face access token, read from the process environment; None when the
# variable is not set at all (an empty string is kept as-is and does not warn).
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    print("Warning: HF_TOKEN not found in environment. Set HF_TOKEN in your environment or .env to access private models if needed.")

# Common imports
import torch
from PIL import Image
from io import BytesIO
import requests


## 1) Traditional (blocking) generation — example

In [None]:

# Traditional generation (blocks until finished)
from transformers import Gemma3ForConditionalGeneration, AutoProcessor

model_id = "google/gemma-3-4b-it"  # example
# NOTE: this cell is multimodal (image + text) and loads the checkpoint through
# Gemma3ForConditionalGeneration. For quick testing, swap in a smaller checkpoint that
# this class and its processor can load; a text-only model such as 'gpt2' will not work
# with the image message below.
try:
    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id, device_map="auto", token=HF_TOKEN, torch_dtype=torch.bfloat16
    ).eval()
    # Pass the same token as for the weights: the processor files are fetched from the
    # same repository, so a gated repo rejects this call without it.
    processor = AutoProcessor.from_pretrained(model_id, token=HF_TOKEN)
except Exception as e:
    # Deliberate best-effort: the notebook keeps running and later cells skip themselves.
    print("Model load error (expected in small/demo env):", e)
    model = None
    processor = None

# Prepare a sample multimodal message (image url + instruction)
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [
        {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
        {"type": "text", "text": "Describe this image in detail."}
    ]}
]

if model is not None and processor is not None:
    # Only chat-template/processing options belong here. `do_sample` is a generate()
    # option (set below) and `use_fast` is a loading option, so neither is passed.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors='pt',
    )
    # Move every tensor to the model's device; cast only floating-point tensors
    # (e.g. pixel values) to bfloat16 so integer token ids keep their dtype.
    inputs = {k: (v.to(model.device, dtype=torch.bfloat16) if v.is_floating_point() else v.to(model.device)) for k, v in inputs.items()}
    with torch.inference_mode():
        # Blocks until all new tokens have been generated (greedy decoding).
        generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    # Slice off the prompt tokens so only the newly generated continuation is decoded.
    decoded = processor.decode(generation[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
    print('Decoded output:\n', decoded)
else:
    print('Model not loaded. Replace model_id with a test-friendly model or run this on a GPU instance with enough memory.')


## 2) Streaming generation with `TextIteratorStreamer`

In [None]:

# Streaming generation example using TextIteratorStreamer
from transformers import TextIteratorStreamer, AutoTokenizer
from threading import Thread

# Ensure tokenizer is available for streaming decoding
try:
    # Same token as the model load, so a gated repository is reachable here too.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
except Exception as e:
    tokenizer = None
    print("Tokenizer load warning:", e)

# Load a real image into PIL
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = None
try:
    # A timeout keeps the cell from hanging on a stalled connection; raise_for_status
    # turns an HTTP error page into a clear error instead of an undecodable "image".
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB")
    print("Loaded image size:", image.size)
except Exception as e:
    print("Could not load image from URL:", e, "— continuing with example code.")

messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": "Describe this image in detail."}]}
]

if image is None:
    # Without an image the user message above carries `None`; skip rather than
    # hand that to the processor.
    print('Streaming demo skipped because the example image could not be loaded.')
elif model is not None and processor is not None and tokenizer is not None:
    # Only chat-template/processing options belong here; sampling options go to generate().
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors='pt',
    )
    # Move tensors to correct device/dtype
    inputs = {k: (v.to(model.device, dtype=torch.bfloat16) if v.is_floating_point() else v.to(model.device)) for k, v in inputs.items()}

    # skip_special_tokens is forwarded to the tokenizer's decode call so control
    # tokens are not printed as part of the streamed text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # temperature only has an effect when sampling is enabled, so enable it explicitly.
    generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=256, do_sample=True, temperature=0.7)

    def _run_generation():
        """Run generate() in the worker thread; always release the consumer loop on failure."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            print("\nGeneration failed:", exc)
            # Without this the `for ... in streamer` loop below would wait forever.
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    print('Streaming output:')
    for token in streamer:
        # Each item is a decoded text chunk made available as generation progresses.
        print(token, end='', flush=True)
    # The stream is exhausted, so the worker is finishing; wait for it to exit cleanly.
    thread.join()
    print('\n--- stream finished ---')
else:
    print('Streaming demo skipped because model/processor/tokenizer not loaded in this environment.')


## 3) Gradio integration (real-time UI)

In [None]:

# Gradio streaming integration example (blocks when launched)
import gradio as gr
import numpy as np

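# IMPORTANT: Running this cell only builds the Gradio UI. The server starts when `demo.launch(...)`
# (commented out at the end of this cell) is executed; launching blocks the kernel until the server is stopped.
# The code below follows the structure in the original article and uses streaming via TextIteratorStreamer.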

# Shared flag polled by the streaming loop in process_image; set to True to end a stream early.
stop_flag = False
# Longest allowed image side in pixels; larger images are downscaled before inference.
max_size = 1024


def _fit_within(image, limit):
    """Return `image` downscaled so neither side exceeds `limit`, keeping the aspect ratio."""
    width, height = image.size
    if width <= limit and height <= limit:
        return image
    if width > height:
        new_width = limit
        # max(1, ...) guards against a 0-pixel side for extremely elongated images.
        new_height = max(1, int((limit / width) * height))
    else:
        new_height = limit
        new_width = max(1, int((limit / height) * width))
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter under its current name.
    return image.resize((new_width, new_height), Image.LANCZOS)


def process_image(image_input, image_url, stop=False):
    """Stream a description of an image as a generator of ``(text, image)`` pairs.

    The image is taken from `image_url` when that is non-empty, otherwise from
    `image_input`. It is downscaled to at most `max_size` pixels per side, sent to
    the module-level `model` through `processor`, and the generated text is
    yielded incrementally as it becomes available.

    Because this function is a generator, every status message (stop, load
    error, cleared input, model missing) is *yielded* before returning; a plain
    ``return value`` inside a generator is never delivered to the consumer.

    Args:
        image_input: Image as a numpy array, or None when nothing was uploaded.
        image_url: URL of an image to download; takes precedence when truthy.
        stop: Pass the string ``'Stop'`` to raise the stop flag and finish
            immediately without running inference.

    Yields:
        Tuples ``(report, image)`` where `report` is the text accumulated so far
        (or a status message) and `image` is the processed PIL image, or None
        when no image was produced.
    """
    global stop_flag
    stop_flag = False

    if stop == 'Stop':
        stop_flag = True
        print("Inference stopped. (stop_flag True)")
        yield "Inference stopped.", None
        return

    # Load image from URL or the uploaded numpy image
    if image_url:
        try:
            # Timeout avoids hanging the UI on a stalled download; raise_for_status
            # reports HTTP errors instead of trying to decode an error page.
            response = requests.get(image_url, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
        except Exception as e:
            yield f"Error loading image from URL: {e}", None
            return
    elif image_input is None:
        yield "Cleared output", None
        return
    else:
        # Normalise to RGB so uploads and URL images take the same path.
        image = Image.fromarray(image_input.astype('uint8')).convert('RGB')

    # Resize if too large
    image = _fit_within(image, max_size)

    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": "Describe this image in detail."}]}
    ]

    if model is None or processor is None or tokenizer is None:
        yield "Model not loaded in this environment. See notebook instructions.", image
        return

    # Only chat-template/processing options belong here; sampling options go to generate().
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors='pt',
    )
    # Move tensors to the model's device; cast only floating-point tensors to bfloat16.
    inputs = {k: (v.to(model.device, dtype=torch.bfloat16) if v.is_floating_point() else v.to(model.device)) for k, v in inputs.items()}

    # skip_special_tokens is forwarded to the tokenizer's decode call so control
    # tokens do not show up in the rendered report.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=150)
    generation_errors = []

    def _run_generation():
        """Run generate() in the worker thread; always release the consumer loop on failure."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            generation_errors.append(exc)
            print("Generation failed:", exc)
            # Without this the streaming loop below would wait forever.
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    report = ''
    for new_text in streamer:
        if stop_flag:
            # NOTE: this only stops consuming the stream; nothing here signals the
            # worker thread, which keeps generating until it reaches max_new_tokens.
            report += '\n ## Inference stopped.'
            break
        report += new_text
        # Yield partial result to allow streaming in Gradio
        yield report, image
    if generation_errors:
        report += f"\n\nGeneration failed: {generation_errors[0]}"
    yield report, image

def stopper():
    """Set the module-level `stop_flag` to True; takes no arguments and returns None."""
    globals()["stop_flag"] = True

# Two-column layout: inputs and controls on the left, streamed text and processed image on the right.
with gr.Blocks() as demo:
    gr.Markdown('# Streaming Multi Model Generation')
    gr.Markdown('Upload an image or provide an image URL for VLM analysis.')
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type='numpy', label='Upload Image')
            image_url_input = gr.Textbox(label='Image URL')
            stop_button = gr.Button(value='Stop')
            clear_button = gr.ClearButton([image_input, image_url_input])
        with gr.Column():
            output_text = gr.Markdown(label='Image Analysis Result')
            output_image = gr.Image(label='Processed Image')

    # Only the image and the URL are handed to process_image. Listing stop_button as
    # an input would pass the button's value ('Stop') as the `stop` argument, which
    # process_image treats as a stop request, so no change event would ever run
    # inference. Stopping is handled solely through the click handler below.
    handler_inputs = [image_input, image_url_input]
    handler_outputs = [output_text, output_image]

    stop_button.click(fn=stopper)
    image_input.change(fn=process_image, inputs=handler_inputs, outputs=handler_outputs)
    image_url_input.change(fn=process_image, inputs=handler_inputs, outputs=handler_outputs)

# Launching the demo in this notebook will block the kernel until stopped.
# Uncomment the next line to run locally (not recommended in ephemeral cloud notebooks):
# demo.launch(share=True, debug=True)



## Final notes & tips
- Streaming is valuable for responsive UIs (chatbots, assistants).
- For development, test with smaller models to save time and memory.
- If you plan to run the full `google/gemma-3-4b-it`, ensure you have access and a GPU with enough memory (or use device_map/accelerate to offload).
- Consider adding a `requirements.txt` and a short README in the repo linking back to your Medium article.


---

