In [1]:
# !pip install -U mlx mlx-lm
# !pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3

In [2]:
from types import SimpleNamespace

import ipywidgets as widgets
from IPython.display import clear_output, display
from mlx_lm import generate, load
from mlx_lm.chat import load, make_prompt_cache, make_sampler, stream_generate

In [3]:
# Model identifier; stored as `args.model` and handed to `load` further down.
model_name = "mlx-community/gemma-3-4b-it-8bit"

# Chat Template

In [4]:
# Settings for loading and sampling, collected in one attribute-style namespace
# so the later cells can read them as `args.<name>`.
args = SimpleNamespace(
    model=model_name,    # identifier passed to load()
    adapter_path=None,   # forwarded to load() as adapter_path
    temp=0.7,            # first argument to make_sampler()
    top_p=0.9,           # second argument to make_sampler()
    seed=None,           # NOTE(review): not read by any cell visible in this notebook
    max_kv_size=None,    # forwarded to make_prompt_cache()
    max_tokens=2000,     # max_tokens given to stream_generate()
)

In [5]:
# Initialize model and tokenizer.
# `load` is imported twice in the import cell; the later
# `from mlx_lm.chat import load` rebinds the name, so that is the one called here.
# NOTE(review): trust_remote_code=True is passed through as tokenizer config —
# presumably it allows tokenizer code shipped with the checkpoint to run; confirm
# it is actually required for this model.
model, tokenizer = load(
    args.model,
    adapter_path=args.adapter_path,
    tokenizer_config={"trust_remote_code": True}
)

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

In [6]:
# One module-level prompt cache: get_response() passes this same object to every
# stream_generate() call, so all queries in this kernel session share it.
# NOTE(review): max_kv_size=None presumably means the cache size is uncapped — confirm.
prompt_cache = make_prompt_cache(model, args.max_kv_size)

In [7]:
def get_response(query):
    """Stream the model's reply to ``query`` to stdout and return it as one string.

    Only the new user turn is run through the tokenizer's chat template; the
    module-level ``prompt_cache`` is passed to ``stream_generate`` on every call
    (presumably this is what carries earlier turns — verify against mlx-lm).

    Args:
        query: The user's message text.

    Returns:
        The concatenation of every streamed ``response.text`` piece.
    """
    user_turn = [{"role": "user", "content": query}]
    prompt = tokenizer.apply_chat_template(user_turn, add_generation_prompt=True)
    sampler = make_sampler(args.temp, args.top_p)

    pieces = []
    for chunk in stream_generate(
        model,
        tokenizer,
        prompt,
        max_tokens=args.max_tokens,
        sampler=sampler,
        prompt_cache=prompt_cache,
    ):
        # Echo each piece as it arrives so the reply appears incrementally.
        print(chunk.text, end='', flush=True)
        pieces.append(chunk.text)

    return "".join(pieces)

In [8]:
def chat_interface():
    """Build and display a small ipywidgets chat UI around ``get_response``.

    The UI is an output area stacked above a text area and a Submit button.
    Each submit clears the output area, re-prints the earlier exchanges kept in
    a list local to this call, then streams the new answer underneath.

    Note that the list of exchanges is created fresh on every call, while
    ``get_response`` uses the module-level ``prompt_cache``; calling this
    function again starts with an empty display but does not reset that cache.

    Returns:
        None. The widgets are shown via ``display``.
    """
    # Textarea for better paste support
    text_input = widgets.Textarea(
        placeholder='Type or paste your question here...',
        description='Query:',
        layout=widgets.Layout(
            width='50%',
            height='100px'
        )
    )

    # Add a Submit button since Textarea doesn't have on_submit
    submit_button = widgets.Button(
        description='Submit',
        layout=widgets.Layout(width='100px')
    )

    output_area = widgets.Output()
    conversation_history = []

    def on_submit(_button):
        """Handle a Submit click: render history, stream the answer, record it."""
        query = text_input.value
        if not query.strip():
            # Nothing to send; leave the output and the input box untouched.
            return

        # Keep the button disabled while the answer streams so extra clicks
        # cannot queue further submissions; always re-enable it afterwards.
        submit_button.disabled = True
        try:
            with output_area:
                clear_output()

                for exchange in conversation_history:
                    print(f"\nQ: {exchange['Q']}\n")
                    print(f"A: {exchange['A']}\n")
                    print("-" * 50)

                print(f"\nQ: {query}\n")
                print("A: ", end='', flush=True)
                response_text = get_response(query)
                print("\n" + "-" * 50 + "\n")
                conversation_history.append({"Q": query, "A": response_text})

            text_input.value = ''
        finally:
            submit_button.disabled = False

    submit_button.on_click(on_submit)

    # Create input container with textarea and button side by side
    input_container = widgets.HBox([text_input, submit_button])

    # Stack output above input
    vbox = widgets.VBox([output_area, input_container])
    display(vbox)

In [None]:
# Launch the chat UI; chat_interface() calls display() on the widget container it builds.
chat_interface()

VBox(children=(Output(), HBox(children=(Textarea(value='', description='Query:', layout=Layout(height='100px',…