## Gemma 3n Fine-Tune Evaluation: Multimodal

Kaggle submission:

1) Fine-tune Gemma 3n E4B on conversational data
2) Serve the fine-tuned model locally, so it keeps working without network access
3) Answer first-aid questions multimodally, using text, image, and audio input
4) Support both batch and streaming modes
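Goal 1 trains on conversations. As a sketch only, assuming each training record is a hypothetical question/answer pair with an optional image path (the actual dataset schema is not shown here), one record can be mapped into the same role/content message structure this notebook uses at inference time, with the reference answer as the assistant turn:

```python
from typing import Optional


def to_conversation(question: str, answer: str, image: Optional[str] = None) -> dict:
    """Hypothetical helper: turn one Q&A record into a user/assistant chat sample."""
    user_content = []
    if image is not None:
        # Same part shape as the inference messages used later in this notebook
        user_content.append({"type": "image", "image": image})
    user_content.append({"type": "text", "text": question})
    return {
        "messages": [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]
    }


# Placeholder strings, not real training data
sample = to_conversation("<first-aid question>", "<reference answer>", image="bite.jpg")
print([turn["role"] for turn in sample["messages"]])  # ['user', 'assistant']
```

A `messages` list of role/content turns is a common layout for conversational fine-tuning data, but the field names here are assumptions, not the submission's actual format.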

In [1]:
from transformers import AutoProcessor, AutoModelForImageTextToText, TextStreamer, pipeline
from PIL import Image
import requests
import torch

# Base model; set model_id = "alfredcs/gemma-3N-finetune" to evaluate the fine-tuned checkpoint instead
model_id = "google/gemma-3n-e4b-it"

pipe = pipeline(
    "image-text-to-text",
    model=model_id,
    device="cuda:3",
    torch_dtype=torch.bfloat16,
)

In [None]:
from IPython.display import Audio, display
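# Wrap in display(): a bare expression only renders when it is the cell's last line
display(Audio("https://www.nasa.gov/wp-content/uploads/2015/01/591240main_JFKmoonspeech.mp3"))

# Download a local copy of the same clip
!wget -qqq https://www.nasa.gov/wp-content/uploads/2015/01/591240main_JFKmoonspeech.mp3 -O audio.mp3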

In [4]:
messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://img.wattpad.com/cover/296265693-256-k674393.jpg"},
            #{"type": "image", "image": "https://p.turbosquid.com/ts-thumb/R7/OZS3Pv/Uqsz7sMj/rattlesnake_rigged_c4d_00/jpg/1565713390/1920x1080/fit_q87/080f2cb9f6455db4f8bca2483bc3f04446b73a2a/rattlesnake_rigged_c4d_00.jpg"},
            #{"type": "audio", "audio" : "https://www.nasa.gov/wp-content/uploads/2015/01/591240main_JFKmoonspeech.mp3" },
            {"type": "text", "text": "I got a bite by the animal as shown in the pictures, please explain what action I should take. After that, transcribe the audio."}
        ]
    }
]
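The commented-out lines above switch between image and audio inputs by hand. A small sketch of a hypothetical `build_messages` helper that assembles the same structure from optional image and audio sources, so each modality can be toggled per query, plus a hypothetical `count_parts` to confirm what a message list actually contains. Passing a local path such as the downloaded `audio.mp3` instead of a URL is an assumption about what the processor's loader accepts:

```python
from typing import Optional

SYSTEM_PROMPT = "You are a helpful assistant."


def build_messages(text: str, image: Optional[str] = None, audio: Optional[str] = None) -> list:
    """Hypothetical helper: build a chat in the format above, media parts before the text part."""
    content = []
    if image is not None:
        content.append({"type": "image", "image": image})
    if audio is not None:
        content.append({"type": "audio", "audio": audio})
    content.append({"type": "text", "text": text})
    return [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": content},
    ]


def count_parts(messages: list, part_type: str) -> int:
    """Hypothetical helper: count content parts of one type (e.g. "image") across all turns."""
    return sum(
        1 for m in messages for part in m["content"] if part.get("type") == part_type
    )


msgs = build_messages("Transcribe the audio.", audio="audio.mp3")
print(count_parts(msgs, "audio"), count_parts(msgs, "image"))  # 1 0
```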

### Use Pipeline

%%time

output = pipe(text=messages, max_new_tokens=200)  # a list of dicts, so there is no .to() to call
print(output[0]["generated_text"][-1]["content"])
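The indexing `output[0]["generated_text"][-1]["content"]` relies on the pipeline returning the whole chat with the model's reply appended as the last turn. A defensive sketch of the same lookup, assuming that output shape; `last_reply` is a hypothetical helper, exercised here on a hand-written stand-in rather than real pipeline output, and it also copes with content given as a list of typed parts:

```python
def last_reply(output: list) -> str:
    """Hypothetical helper: return the text of the final turn in a chat-style pipeline result."""
    turns = output[0]["generated_text"]
    if not isinstance(turns, list):
        # Plain-string generations have no turn structure; return them unchanged
        return turns
    content = turns[-1]["content"]
    if isinstance(content, str):
        return content
    # Content given as typed parts: keep only the text pieces
    return "".join(part.get("text", "") for part in content if part.get("type") == "text")


# Hand-written stand-in for the pipeline output (assumed shape)
fake_output = [{"generated_text": [
    {"role": "user", "content": [{"type": "text", "text": "Hi"}]},
    {"role": "assistant", "content": "Wash the wound with soap and water."},
]}]
print(last_reply(fake_output))
```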

### Use the model directly (Gemma 3n conditional generation)

In [3]:
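processor = AutoProcessor.from_pretrained(model_id)  # processors hold no weights, so no device_map
model = AutoModelForImageTextToText.from_pretrained(
    model_id, torch_dtype="auto", device_map="cuda:3")

# With return_dict=True this is a dict-like BatchFeature (input_ids, attention_mask, pixel values, ...)
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True, return_dict=True,
    return_tensors="pt",
)
# Move to the model's device; dtype is applied only to floating-point tensors such as pixel values
inputs = inputs.to(model.device, dtype=model.dtype)

outputs = model.generate(**inputs, max_new_tokens=256)

# Decoding the full sequence also echoes the prompt, as in the output below
text = processor.batch_decode(
    outputs,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)
print(text[0])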

user
You are a helpful assistant.

I got a bite by the animal as shown in the pictures, please explain what action I should take. After that, transcribe the audio.
model
I understand you've been bitten by an animal. It's important to take action immediately to prevent infection. Here's what you should do:

1. **Wash the wound thoroughly** with soap and water for at least 5-10 minutes. This helps remove any bacteria from the bite.
2. **Apply an antiseptic** like povidone-iodine or chlorhexidine solution. This further helps prevent infection.
3. **Cover the wound** with a sterile bandage. This protects it from further contamination.
4. **Seek medical attention immediately.** Animal bites can lead to serious infections like rabies. A doctor will assess the wound and determine if rabies prophylaxis is needed.
5. **Report the bite** to your local animal control agency or health department. This helps them track potentially rabid animals and prevent future bites.
6. **Monitor the wound** for
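The decoded text above starts with the echoed prompt (`user` ... `model`) because `batch_decode` was applied to the whole sequence. The usual fix is to slice `outputs` past the prompt length (the second dimension of the tokenized prompt's `input_ids`) before decoding. When only the decoded string is available, a rough fallback is to keep what follows the first line reading exactly `model`. A sketch with a hypothetical `strip_echoed_prompt`, assuming single-turn chats and role labels rendered as bare lines exactly as in the output above:

```python
def strip_echoed_prompt(decoded: str, reply_marker: str = "model") -> str:
    """Hypothetical helper: return the text after the first line equal to reply_marker."""
    lines = decoded.split("\n")
    try:
        # A user message consisting only of the word "model" would fool this heuristic
        start = [line.strip() for line in lines].index(reply_marker) + 1
    except ValueError:
        # No marker found: assume the text is already reply-only
        return decoded.strip()
    return "\n".join(lines[start:]).strip()


decoded = "user\nYou are a helpful assistant.\n\nI got a bite...\nmodel\nI understand you've been bitten."
print(strip_echoed_prompt(decoded))  # I understand you've been bitten.
```

Token slicing remains the more reliable option, since it does not depend on how the chat template renders role names.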

## Streaming

In [5]:
from transformers import TextStreamer

# Helper for streaming inference: text is printed as it is generated
def do_gemma_3n_inference(messages, max_new_tokens=128):
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,  # required so the model starts its own turn
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)
    model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # temperature/top_p/top_k only take effect when sampling
        temperature=0.2, top_p=0.95, top_k=64,
        streamer=TextStreamer(processor, skip_prompt=True, skip_special_tokens=True,
                              clean_up_tokenization_spaces=False),
    )
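`TextStreamer` prints to stdout, which suits a notebook but not a local server that has to forward chunks to a client. Transformers also offers `TextIteratorStreamer`, which is consumed as an iterator while `generate` runs in a background thread. The stdlib-only sketch below imitates that producer/consumer pattern with a queue; `QueueStreamer` and `fake_generate` are hypothetical stand-ins (text chunks instead of token ids, no model), not transformers classes:

```python
import threading
from queue import Queue


class QueueStreamer:
    """Hypothetical stand-in for an iterator streamer: producer put()s text, consumer iterates."""

    _END = object()  # sentinel marking the end of generation

    def __init__(self):
        self._queue = Queue()

    def put(self, text: str) -> None:
        self._queue.put(text)

    def end(self) -> None:
        self._queue.put(self._END)

    def __iter__(self):
        while True:
            item = self._queue.get()  # blocks until the producer sends something
            if item is self._END:
                return
            yield item


def fake_generate(chunks, streamer):
    """Hypothetical producer standing in for model.generate(..., streamer=streamer)."""
    for chunk in chunks:
        streamer.put(chunk)
    streamer.end()


streamer = QueueStreamer()
worker = threading.Thread(target=fake_generate, args=(["Wash ", "the ", "wound."], streamer))
worker.start()
received = [chunk for chunk in streamer]
worker.join()
print("".join(received))  # Wash the wound.
```

In a server, the loop body would forward each chunk instead of collecting it; batch mode is the same generation call without a streamer, returning the full output at once.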

In [6]:
do_gemma_3n_inference(messages, max_new_tokens = 512)

ValueError: Number of images does not match number of special image tokens in the input text. Got 0 image tokens in the text and 256 tokens from image embeddings.
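The error reports 256 tokens from image embeddings for the one image in `messages`, but zero image placeholder tokens in the prompt, so the tokenized prompt and the image features disagree. The root cause is not established here. One way to narrow it down is to count image placeholder tokens in the tokenized prompt before calling `generate`. A sketch of that check on plain lists with a hypothetical `check_image_placeholders`; the token id below is made up and the 256-tokens-per-image figure is read off the error message, so look up the real placeholder id for your checkpoint rather than trusting these values:

```python
def check_image_placeholders(input_ids, image_token_id, num_images, tokens_per_image=256):
    """Hypothetical check: compare placeholder count with what num_images should produce."""
    found = sum(1 for tok in input_ids if tok == image_token_id)
    expected = num_images * tokens_per_image  # 256 per image is an assumption from the error text
    return found == expected, found, expected


IMAGE_TOKEN_ID = 999  # hypothetical id, for illustration only
ok, found, expected = check_image_placeholders([1, 2, 3], IMAGE_TOKEN_ID, num_images=1)
print(ok, found, expected)  # False 0 256
```

With the real processor output, `input_ids` would be the first row of the tokenized prompt's `input_ids` tensor, converted with `.tolist()`.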