In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread

# Checkpoint identifier shared by the tokenizer and the model
model_name = "google/gemma-2-2b-it"

# Tokenizer matching the checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Causal language model with float16 weights; device_map="auto" lets the
# library decide where the weights are placed
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Generating content

In [None]:
# Prepare the input prompt
prompt = "What do you think of AIUK?"
# Move the encoded prompt to the device the model was actually loaded on.
# The model is loaded with device_map="auto", so hard-coding a device such as
# "mps" breaks on machines where the model ended up on CUDA or CPU.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Set up the streamer: it yields decoded text chunks as they are generated,
# leaving out the prompt and any special tokens
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Keyword arguments for model.generate, which runs in a separate thread so the
# main thread can consume the streamer while generation is still in progress
generation_kwargs = dict(
    input_ids=inputs.input_ids,
    # Pass the tokenizer's attention mask along with the ids so generate()
    # does not have to infer it
    attention_mask=inputs.attention_mask,
    max_new_tokens=100,  # Increase this for longer output
    temperature=0.7,
    do_sample=True,  # Allows creative output
    top_k=50,        # Limits sampling to top 50 tokens
    top_p=0.95,      # Nucleus sampling for diversity
    streamer=streamer,
)

# Start generation in a new thread
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

# Print the prompt (the streamer skips it, so echo it once here)
print(prompt, end="", flush=True)

# Stream output chunk by chunk; flush so each chunk appears immediately
for new_text in streamer:
    print(new_text, end="", flush=True)

# Wait for the thread to finish
thread.join()
