# MedGemma Image-Text-to-Text Batch Inference Jobs

This notebook demonstrates how to run batch inference jobs with the MedGemma multimodal model in Snowflake. It walks through the complete workflow, from model deployment to analyzing medical images with AI-generated descriptions.

## Overview

MedGemma is a specialized medical AI model that can process both images and text to provide medical insights. This notebook shows how to:

1. Establish a Snowflake connection
2. Load and log the MedGemma image-text-to-text model from Hugging Face
3. Upload medical images (e.g., X-rays) to a Snowflake stage
4. Prepare input data in the OpenAI-compatible chat format
5. Run batch inference on GPU compute pools using the vLLM engine
6. Inspect the inference output with medical descriptions

## Prerequisites

- Snowflake account with appropriate privileges
- `snowflake-ml-python>=1.26.0` (for batch inference support)
- A valid Snowflake connection configuration
- If you need to log MedGemma, a Hugging Face token with access to the MedGemma model
- (Optional) Access to GPU compute pools
- (Optional) A Snowflake database, schema, and stage already created
- (Optional) A MedGemma model already logged in Snowflake registry

## Running the Notebook

Run the cells in order to follow the quickstart guide. The notebook includes steps to log the model, prepare medical image data, and run batch inference. Make sure to update the Snowflake connection name, the Hugging Face token, and the database/schema names as needed.

# Make Sure snowflake-ml-python Has the Right Version

In [1]:
from importlib.metadata import version
# Batch inference is in public preview (PuPr) starting with snowflake-ml-python 1.26.0
print(version('snowflake-ml-python'))

1.26.0
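Printing the version leaves the comparison to the reader. A minimal sketch of a fail-fast guard follows, assuming plain `X.Y.Z` release strings: the simple parser ignores pre-release suffixes such as `rc1` (the `packaging` library handles those properly if it is installed). The names `parse_release`, `meets_minimum`, and `require_min_version` are illustrative helpers, not part of snowflake-ml-python.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def parse_release(text: str) -> tuple[int, ...]:
    """Return the leading numeric release segment, e.g. '1.26.0' -> (1, 26, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", text.strip())
    if match is None:
        raise ValueError(f"Unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare release tuples after padding the shorter one with zeros."""
    a, b = parse_release(installed), parse_release(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def require_min_version(package: str, minimum: str) -> str:
    """Raise if `package` is missing or older than `minimum`; return the installed version."""
    try:
        installed = version(package)
    except PackageNotFoundError as exc:
        raise RuntimeError(f"{package} is not installed") from exc
    if not meets_minimum(installed, minimum):
        raise RuntimeError(f"{package} {installed} found; batch inference needs >= {minimum}")
    return installed
```

In the notebook this would be called as `require_min_version("snowflake-ml-python", "1.26.0")`, so an outdated kernel stops here instead of failing later inside `run_batch`.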


# Establish a Connection

In [27]:
from snowflake.snowpark import Session
# Replace "preprod8" with the name of a connection defined in your connections.toml
session = Session.builder.config("connection_name", "preprod8").create()

# Create Resources

If you already have a stage, a compute pool, and a secret, fill in DB_NAME, SCHEMA_NAME, STAGE_NAME, COMPUTE_POOL_NAME, and SECRET_NAME in the next cell and skip the resource-creation cell that follows it.

Otherwise, leave the default names as they are and run both cells to create the resources.

In [None]:
DB_NAME = "MEDGEMMA_DB"
SCHEMA_NAME = "PUBLIC"
STAGE_NAME = "MEDGEMMA_STAGE"
COMPUTE_POOL_NAME = "MEDGEMMA_COMPUTE_POOL"
HF_TOKEN = "hf_xxxx" # Please fill in your Hugging Face token if you need to log MedGemma
SECRET_NAME = "HF_TOKEN_SECRET"
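The following cells splice these names into SQL with f-strings, so a typo or stray character only surfaces as a confusing SQL error. A small guard, sketched here as a hypothetical helper `fq_name`, checks each part against Snowflake's rules for unquoted identifiers (a letter or underscore, followed by letters, digits, underscores, or dollar signs) before joining them. It deliberately rejects quoted, case-sensitive identifiers; adjust it if your objects use them.

```python
import re

# Unquoted Snowflake identifier: letter or underscore, then letters, digits, "_" or "$".
_UNQUOTED_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_$]*")


def fq_name(*parts: str) -> str:
    """Validate each identifier part and join them as DB.SCHEMA.OBJECT."""
    if not parts:
        raise ValueError("at least one identifier part is required")
    for part in parts:
        if not _UNQUOTED_IDENTIFIER.fullmatch(part):
            raise ValueError(f"Not a plain unquoted identifier: {part!r}")
    return ".".join(parts)
```

For example, `fq_name(DB_NAME, SCHEMA_NAME, STAGE_NAME)` yields `MEDGEMMA_DB.PUBLIC.MEDGEMMA_STAGE` with the defaults above.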

In [32]:
# Create database
session.sql(f"CREATE DATABASE IF NOT EXISTS {DB_NAME}").collect()

# Create schema
session.sql(f"CREATE SCHEMA IF NOT EXISTS {DB_NAME}.{SCHEMA_NAME}").collect()

# Create stage with SSE encryption
session.sql(f"""
    CREATE STAGE IF NOT EXISTS {DB_NAME}.{SCHEMA_NAME}.{STAGE_NAME}
    ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')
""").collect()

# Create a compute pool backed by GPU nodes (GPU_NV_S) for the vLLM batch job
# (use a CPU instance family instead for CPU-only workloads)
session.sql(f"""
    CREATE COMPUTE POOL IF NOT EXISTS {COMPUTE_POOL_NAME}
    MIN_NODES = 1
    MAX_NODES = 2
    INSTANCE_FAMILY = GPU_NV_S
""").collect()

# Set the session to use the newly created database and schema
session.use_database(DB_NAME)
session.use_schema(SCHEMA_NAME)

print(f"Created database: {DB_NAME}")
print(f"Created schema: {DB_NAME}.{SCHEMA_NAME}")
print(f"Created stage: {DB_NAME}.{SCHEMA_NAME}.{STAGE_NAME}")
print(f"Created compute pool: {COMPUTE_POOL_NAME}")

session.sql(f"""
    CREATE OR REPLACE SECRET {DB_NAME}.{SCHEMA_NAME}.{SECRET_NAME}
    TYPE = GENERIC_STRING
    SECRET_STRING = '{HF_TOKEN}'
""").collect()

print(f"Created secret: {DB_NAME}.{SCHEMA_NAME}.{SECRET_NAME}")

Created database: MEDGEMMA_DB
Created schema: MEDGEMMA_DB.PUBLIC
Created stage: MEDGEMMA_DB.PUBLIC.MEDGEMMA_STAGE
Created compute pool: MEDGEMMA_COMPUTE_POOL
Created secret: MEDGEMMA_DB.PUBLIC.HF_TOKEN_SECRET
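Interpolating `HF_TOKEN` straight into the `CREATE SECRET` statement works for ordinary tokens, but a single quote or backslash in the value would break or alter the SQL. A minimal sketch of quoting the value first is below, assuming Snowflake's rules for single-quoted string constants (a quote is written as `''`, and backslash acts as an escape character). `sql_string_literal` and `create_secret_sql` are hypothetical helpers, not Snowpark APIs.

```python
def sql_string_literal(value: str) -> str:
    """Quote `value` as a Snowflake single-quoted string constant."""
    # Double backslashes first (backslash is an escape character inside
    # single-quoted strings), then double any single quotes.
    escaped = value.replace("\\", "\\\\").replace("'", "''")
    return f"'{escaped}'"


def create_secret_sql(fq_secret_name: str, secret_value: str) -> str:
    """Build a CREATE OR REPLACE SECRET statement with a safely quoted value."""
    return (
        f"CREATE OR REPLACE SECRET {fq_secret_name}\n"
        f"    TYPE = GENERIC_STRING\n"
        f"    SECRET_STRING = {sql_string_literal(secret_value)}"
    )
```

Usage would be `session.sql(create_secret_sql(f"{DB_NAME}.{SCHEMA_NAME}.{SECRET_NAME}", HF_TOKEN)).collect()`, which produces the same statement as the cell above for a typical `hf_...` token.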


### If you don't already have a model logged, load and log the MedGemma model

In [34]:
from snowflake.ml.model import openai_signatures
from snowflake.ml.model.models import huggingface
from snowflake.ml.registry import registry
reg = registry.Registry(session=session)

model = huggingface.TransformersPipeline(
    model="google/medgemma-4b-it",
    task="image-text-to-text",
    token_or_secret=f"{DB_NAME}.{SCHEMA_NAME}.{SECRET_NAME}",
)

mv = reg.log_model(
    model=model,
    model_name="medgemma_demo",
    target_platforms=["SNOWPARK_CONTAINER_SERVICES"],
)

Model logged successfully.: 100%|██████████| 6/6 [08:02<00:00, 80.50s/it]                      


### If you already have a model logged, please fill in MODEL_NAME and VERSION_NAME

In [None]:
from snowflake.ml.registry import registry

MODEL_NAME = "MODEL_NAME"
VERSION_NAME = "V1"

reg = registry.Registry(session=session, database_name=DB_NAME, schema_name=SCHEMA_NAME)
mv = reg.get_model(MODEL_NAME).get_version(VERSION_NAME)

# Inspect the Functions of the Model

In [14]:
mv.show_functions()

[{'name': '__CALL__',
  'target_method': '__call__',
  'target_method_function_type': 'FUNCTION',
  'signature': ModelSignature(
      inputs=[
          FeatureGroupSpec(
              name='messages',
              specs=[
                  FeatureGroupSpec(
                      name='content',
                      specs=[
                          FeatureSpec(dtype=DataType.STRING, name='type', nullable=True),
                          FeatureSpec(dtype=DataType.STRING, name='text', nullable=True),
                          FeatureGroupSpec(
                              name='image_url',
                              specs=[
                                  FeatureSpec(dtype=DataType.STRING, name='url', nullable=True),
                                  FeatureSpec(dtype=DataType.STRING, name='detail', nullable=True)
                              ]
                          ),
                          FeatureGroupSpec(
                              name='video_url',
            

# Upload Images to the Stage

In [20]:
import requests
import os

# List of image URLs to download
image_urls = [
    "https://www.shutterstock.com/shutterstock/photos/2508050577/display_1500/stock-photo-x-ray-image-of-hand-showing-fifth-metacarpal-fracture-2508050577.jpg",
    "https://www.radiologyinfo.org/-/media/radinfo/gallery-items/images/u-to-z/y-cxr.jpg",
    "https://media.istockphoto.com/id/1212285302/photo/human-thoracic-cavity-x-ray-film.jpg?s=1024x1024&w=is&k=20&c=v6nMYROFPp3HN7jya10S51fl4etb9QkhLtLHQZfIrGs=",
    "https://media.istockphoto.com/id/684931694/photo/bones-of-hands.jpg?s=1024x1024&w=is&k=20&c=oQSAM1-DwJsSvzZOoYZLHEtsD0OeYjiQ3Bpkj9cQ2us=",
    "https://as1.ftcdn.net/jpg/02/16/47/88/1000_F_216478859_vKIowhAdEDU9vAmI9ZetrxbkK9HTvpIh.jpg",
]

stage_location = f"@{DB_NAME}.{SCHEMA_NAME}.{STAGE_NAME}"
image_stage_paths = []

# Download and upload each image
for idx, image_url in enumerate(image_urls):
    print(f"Processing image {idx + 1}/{len(image_urls)}: {image_url}")
    
    # Download the image (with a timeout so a stalled request does not hang the cell)
    response = requests.get(image_url, timeout=60)
    response.raise_for_status()
    
    # Save to /tmp which is always writable
    local_image_path = f"/tmp/medical_image_{idx}.png"
    with open(local_image_path, "wb") as f:
        f.write(response.content)
    
    # Upload to Snowflake stage
    put_result = session.file.put(
        local_file_name=local_image_path,
        stage_location=stage_location,
        auto_compress=False,
        overwrite=True
    )
    
    # Clean up local file
    os.remove(local_image_path)
    
    # Store the stage path
    image_stage_path = f"{stage_location}/medical_image_{idx}.png"
    image_stage_paths.append(image_stage_path)
    print(f"  Uploaded to: {image_stage_path}")

print(f"\nAll {len(image_stage_paths)} images uploaded successfully!")

# Show all files in stage
session.sql(f"LS {stage_location}").show()

Processing image 1/5: https://www.shutterstock.com/shutterstock/photos/2508050577/display_1500/stock-photo-x-ray-image-of-hand-showing-fifth-metacarpal-fracture-2508050577.jpg
  Uploaded to: @MEDGEMMA_DB.PUBLIC.MEDGEMMA_STAGE/medical_image_0.png
Processing image 2/5: https://www.radiologyinfo.org/-/media/radinfo/gallery-items/images/u-to-z/y-cxr.jpg
  Uploaded to: @MEDGEMMA_DB.PUBLIC.MEDGEMMA_STAGE/medical_image_1.png
Processing image 3/5: https://media.istockphoto.com/id/1212285302/photo/human-thoracic-cavity-x-ray-film.jpg?s=1024x1024&w=is&k=20&c=v6nMYROFPp3HN7jya10S51fl4etb9QkhLtLHQZfIrGs=
  Uploaded to: @MEDGEMMA_DB.PUBLIC.MEDGEMMA_STAGE/medical_image_2.png
Processing image 4/5: https://media.istockphoto.com/id/684931694/photo/bones-of-hands.jpg?s=1024x1024&w=is&k=20&c=oQSAM1-DwJsSvzZOoYZLHEtsD0OeYjiQ3Bpkj9cQ2us=
  Uploaded to: @MEDGEMMA_DB.PUBLIC.MEDGEMMA_STAGE/medical_image_3.png
Processing image 5/5: https://as1.ftcdn.net/jpg/02/16/47/88/1000_F_216478859_vKIowhAdEDU9vAmI9Zetrxbk
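The loop above names every download `.png`, although the source URLs serve JPEG files, and a site that blocks the request may return an HTML page instead of an image. A minimal sketch of picking the extension from the file's leading "magic" bytes is below; `image_extension` is a hypothetical helper, and it covers only a few common formats.

```python
# Leading "magic" bytes of common image formats mapped to file extensions.
_MAGIC_TO_EXTENSION = {
    b"\xff\xd8\xff": ".jpg",
    b"\x89PNG\r\n\x1a\n": ".png",
    b"GIF87a": ".gif",
    b"GIF89a": ".gif",
}


def image_extension(content: bytes) -> str:
    """Return the file extension matching the image bytes, or raise ValueError."""
    for magic, extension in _MAGIC_TO_EXTENSION.items():
        if content.startswith(magic):
            return extension
    # WebP files start with "RIFF", a 4-byte size, then "WEBP".
    if content[:4] == b"RIFF" and content[8:12] == b"WEBP":
        return ".webp"
    raise ValueError("Unrecognized image format (the response may not be an image)")
```

Inside the loop, `suffix = image_extension(response.content)` could replace the hard-coded `.png` in both `local_image_path` and `image_stage_path`, so the staged file name matches its contents and non-image responses fail before upload.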

# Prepare Data in the OpenAI Chat Format

In [23]:
import json

messages_list = []
for image_path in image_stage_paths:
    # One conversation per image: a system prompt plus a user turn holding the question and the staged image
    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a medical expert."}]},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe the x-ray to me"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": image_path,
                    },
                },
            ],
        },
    ]
    messages_list.append(messages)

# Each row holds one conversation, serialized as a JSON string in the MESSAGES column
schema = ["MESSAGES"]
data = [json.dumps(m) for m in messages_list]
input_df = session.create_dataframe(data, schema=schema)

print(f"Created input dataframe with {len(messages_list)} messages")
input_df.show(1,max_width=300)

Created input dataframe with 5 messages
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"MESSAGES"                                                                                                                                                                                                                                                                            |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|[{"role": "system", "content": [{"type": "text", "text": "You are a medical expert."}]}, {"role": "user", "content":
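Because each conversation travels as a JSON string, a malformed message only shows up once the GPU job is running. A minimal sketch of checking each payload locally first is below. `validate_messages` is a hypothetical helper that checks only the fields this notebook uses (`role`, `content`, `text`, `image_url.url`); the model signature also lists fields such as `detail` and `video_url`, which it does not cover.

```python
import json

ALLOWED_ROLES = {"system", "user", "assistant"}


def validate_messages(payload: str) -> list:
    """Parse one MESSAGES value and raise ValueError on the first problem found."""
    messages = json.loads(payload)
    if not isinstance(messages, list) or not messages:
        raise ValueError("messages must be a non-empty list")
    for i, message in enumerate(messages):
        if not isinstance(message, dict) or message.get("role") not in ALLOWED_ROLES:
            raise ValueError(f"message {i}: missing or unexpected role")
        content = message.get("content")
        if not isinstance(content, list) or not content:
            raise ValueError(f"message {i}: content must be a non-empty list of parts")
        for j, part in enumerate(content):
            kind = part.get("type") if isinstance(part, dict) else None
            if kind == "text":
                if not isinstance(part.get("text"), str):
                    raise ValueError(f"message {i}, part {j}: text part needs a 'text' string")
            elif kind == "image_url":
                image = part.get("image_url")
                url = image.get("url") if isinstance(image, dict) else None
                if not isinstance(url, str) or not url:
                    raise ValueError(f"message {i}, part {j}: image_url part needs a 'url'")
            else:
                raise ValueError(f"message {i}, part {j}: unsupported part type {kind!r}")
    return messages
```

Running `for payload in data: validate_messages(payload)` before `session.create_dataframe(...)` catches shape errors without spending compute-pool time.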

# Run the Batch Inference Job

In [24]:
from snowflake.ml.model import JobSpec, OutputSpec, SaveMode, InputSpec
from snowflake.ml.model.inference_engine import InferenceEngine


job = mv.run_batch(
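    compute_pool=COMPUTE_POOL_NAME,  # the GPU pool created above, or "SYSTEM_COMPUTE_POOL_GPU" if your account provides it
    X=input_df,
    input_spec=InputSpec(params={"temperature": 0.5}),  # sampling temperature passed to the model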
    output_spec=OutputSpec(stage_location=f"@{DB_NAME}.{SCHEMA_NAME}.{STAGE_NAME}/demo/output/", mode=SaveMode.OVERWRITE),
    job_spec=JobSpec(gpu_requests="1"),
    inference_engine_options={
        "engine": InferenceEngine.VLLM,
        "engine_args_override": [
            "--max-model-len=7048",
            "--gpu-memory-utilization=0.9",
        ]
    }
)

job.wait()

'DONE'
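`engine_args_override` is passed as a list of vLLM command-line flags in `--flag=value` form. When experimenting with several settings (for example, a different context length), it can be easier to keep them in a dict and render the flags. `engine_args` below is a hypothetical helper; the flag names must still be ones your vLLM version accepts.

```python
def engine_args(options: dict[str, object]) -> list[str]:
    """Render {"max-model-len": 7048} as ["--max-model-len=7048"].

    Underscores in names become dashes. Booleans become bare flags when True
    and are omitted when False.
    """
    args = []
    for name, value in options.items():
        flag = "--" + name.replace("_", "-")
        if isinstance(value, bool):  # checked first: bool is a subclass of int
            if value:
                args.append(flag)
        else:
            args.append(f"{flag}={value}")
    return args
```

With this, the override above could be written as `"engine_args_override": engine_args({"max-model-len": 7048, "gpu-memory-utilization": 0.9})`.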

# Inspect the Inference Output

In [25]:
output_df = session.read.option("pattern", ".*\\.parquet").parquet(f"@{DB_NAME}.{SCHEMA_NAME}.{STAGE_NAME}/demo/output/")
output_df.show(100, max_width=2000)  # show all columns, so the generated descriptions appear alongside "id"

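Which columns the output Parquet files contain depends on the model signature, so check `output_df.columns` first. Assuming a column holds an OpenAI-style chat-completion response (a JSON object whose `choices` list carries `message.content`), a sketch like the following pulls out the generated description. `extract_reply` is a hypothetical helper, and the response layout is an assumption to verify against your actual output.

```python
import json
from typing import Any, Optional


def extract_reply(response: Any) -> Optional[str]:
    """Return choices[0].message.content from a chat-completion response.

    Accepts a dict or its JSON string form; returns None when the expected
    fields are missing or the string is not valid JSON.
    """
    if isinstance(response, str):
        try:
            response = json.loads(response)
        except json.JSONDecodeError:
            return None
    if not isinstance(response, dict):
        return None
    choices = response.get("choices")
    if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
        return None
    message = choices[0].get("message")
    if not isinstance(message, dict):
        return None
    content = message.get("content")
    return content if isinstance(content, str) else None
```

For example, `[extract_reply(v) for v in output_df.to_pandas()["<RESPONSE_COLUMN>"]]`, where `<RESPONSE_COLUMN>` is a placeholder for whichever column holds the response.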

# Clean Up

In [None]:
# Uncomment to clean up: dropping the database also drops the schema, stage, secret, and any model registered in it
# session.sql(f'DROP DATABASE IF EXISTS {DB_NAME}').collect()
# session.sql(f'DROP COMPUTE POOL IF EXISTS {COMPUTE_POOL_NAME}').collect()

[Row(status='MEDGEMMA_COMPUTE_POOL successfully dropped.')]