# MedGemma 1.5 4B → TFLite Conversion

This notebook converts the language-model component of MedGemma 1.5 4B into an INT8-quantized TFLite model (falling back to FP16 if the INT8 recipe is unavailable).

**Requirements:**
- HuggingFace account with access to MedGemma
- HuggingFace token

**Runtime:** Use a GPU runtime (Runtime → Change runtime type → T4 GPU)

In [None]:
# Step 1: Install dependencies
!pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu118
!pip install -q torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html 2>/dev/null || pip install -q torch_xla
!pip install -q "transformers>=4.44.0,<4.46.0" accelerate tqdm
!pip install -q ai-edge-torch-nightly tf-nightly
!pip install -q kagglehub sentencepiece protobuf
!pip uninstall -y torchao 2>/dev/null || true
print("‚úÖ Dependencies installed")

In [None]:
# Step 2: Login to HuggingFace
# Authenticates this session with the HuggingFace token listed under
# Requirements, so that Step 4 can download the MedGemma checkpoint.
from huggingface_hub import login
login()  # Enter your HuggingFace token when prompted

In [None]:
# Step 3: Conversion Script
import sys
import os
import types
from unittest.mock import MagicMock
import importlib.machinery

# Detect if we have real torch_xla
def has_real_torch_xla():
    """Report whether a real, file-backed ``torch_xla`` package is importable.

    Returns:
        bool: ``True`` only when ``import torch_xla`` succeeds and the module's
        ``__file__`` is an actual path string. ``False`` when the import fails
        for any reason, or when the module found is a stand-in without a real
        ``__file__`` (for example a mock left in ``sys.modules`` by an earlier
        run of this cell).
    """
    try:
        import torch_xla
    except Exception:
        # A broken install can fail with more than ImportError (e.g. OSError or
        # RuntimeError while loading native libraries). Whatever the cause, the
        # real package is unusable, so report False and let the caller mock it.
        return False
    # Require a genuine string: a mock module whose __getattr__ fabricates
    # attributes would otherwise pass a plain "is not None" check.
    return isinstance(getattr(torch_xla, "__file__", None), str)

USE_REAL_XLA = has_real_torch_xla()
print(f"üîß torch_xla available: {USE_REAL_XLA}")

if not USE_REAL_XLA:
    print("‚ö†Ô∏è Mocking torch_xla...")

    class AttributeAwareMock(types.ModuleType):
        """Stand-in module: any attribute that was never set resolves to a MagicMock."""

        def __init__(self, name):
            super().__init__(name)
            # Empty __path__ plus a loader-less spec, so the object carries the
            # attributes of a package.
            self.__path__ = []
            self.__spec__ = importlib.machinery.ModuleSpec(name, None)

        def __getattr__(self, name):
            # Only reached when normal attribute lookup fails.
            return MagicMock(name=f"mock.{name}")

    class ForgeryMock(AttributeAwareMock):
        """Fake ``torch_xla.stablehlo`` whose two entry points return canned mocks."""

        def exported_program_to_stablehlo(self, *args, **kwargs):
            """Return a mock whose ``_bundle`` holds a single empty "forward" function."""
            forward_fn = MagicMock()
            forward_fn.meta.name = "forward"
            forward_fn.meta.input_locations = []
            forward_fn.meta.input_signature = []
            forward_fn.meta.output_signature = [
                MagicMock(dtype="float32", shape=[1, 256000])
            ]
            forward_fn.bytecode = b""

            fake_bundle = MagicMock()
            fake_bundle.state_dict = {}
            fake_bundle.additional_constants = []
            fake_bundle.stablehlo_funcs = [forward_fn]

            exported = MagicMock()
            exported._bundle = fake_bundle
            return exported

        def merge_stablehlo_bundles(self, *args, **kwargs):
            """Return a mock whose ``_bundle`` is the first item of the first argument."""
            merged = MagicMock()
            if args and args[0]:
                merged._bundle = args[0][0]
            else:
                # Nothing to merge: fall back to a fresh mock bundle.
                merged._bundle = MagicMock()
            return merged

    # Build the fake package tree.
    mock_xla = AttributeAwareMock("torch_xla")
    shlo_forgery = ForgeryMock("torch_xla.stablehlo")
    mock_xla.stablehlo = shlo_forgery
    mock_xla.core = AttributeAwareMock("torch_xla.core")
    mock_xla.utils = AttributeAwareMock("torch_xla.utils")
    mock_xla.experimental = AttributeAwareMock("torch_xla.experimental")

    # Register every dotted name in one go so later imports resolve to the mocks.
    sys.modules.update({
        "torch_xla": mock_xla,
        "torch_xla.core": mock_xla.core,
        "torch_xla.core.xla_model": AttributeAwareMock("torch_xla.core.xla_model"),
        "torch_xla.utils": mock_xla.utils,
        "torch_xla.utils.utils": AttributeAwareMock("torch_xla.utils.utils"),
        "torch_xla.experimental": mock_xla.experimental,
        "torch_xla.experimental.xla_marker": AttributeAwareMock("torch_xla.experimental.xla_marker"),
        "torch_xla.experimental.xla_mlir_debuginfo": AttributeAwareMock("torch_xla.experimental.xla_mlir_debuginfo"),
        "torch_xla.experimental.mark_pattern_utils": AttributeAwareMock("torch_xla.experimental.mark_pattern_utils"),
        "torch_xla.stablehlo": shlo_forgery,
    })

import torch

# Register XLA ops
# Defines two operators in the "xla" namespace and backs each with a
# pass-through kernel. Any failure is reported and otherwise ignored.
try:
    from torch.library import Library, impl

    lib = Library("xla", "DEF")
    lib.define("mark_tensor(Tensor self) -> Tensor")
    lib.define("write_mlir_debuginfo(Tensor self, Tensor other, int index) -> Tensor")

    @impl(lib, "mark_tensor", "CompositeExplicitAutograd")
    def mark_tensor(self):
        """Pass-through kernel: return the input tensor as-is."""
        return self

    @impl(lib, "write_mlir_debuginfo", "CompositeExplicitAutograd")
    def write_mlir_debuginfo(self, other, idx):
        """Pass-through kernel: return the first tensor; the other arguments are unused."""
        return self
except Exception as err:
    print(f"Info: XLA ops: {err}")

from transformers import AutoModelForImageTextToText
import ai_edge_torch

# Patch PassBase
# Wraps PassBase.__call__ so that any pass whose class is named
# "InjectMlirDebuginfoPass" is skipped; every other pass runs normally.
try:
    from torch.fx.passes.infra.pass_base import PassBase, PassResult

    original_pass_call = PassBase.__call__

    def patched_pass_call(self, *args, **kwargs):
        """Bypass InjectMlirDebuginfoPass; otherwise defer to the saved ``__call__``."""
        if type(self).__name__ != "InjectMlirDebuginfoPass":
            return original_pass_call(self, *args, **kwargs)
        # The graph module arrives either positionally or as `graph_module=`.
        if args:
            graph_module = args[0]
        else:
            graph_module = kwargs.get('graph_module')
        return PassResult(graph_module, True)

    PassBase.__call__ = patched_pass_call
    print("‚úÖ PassBase patched")
except Exception as err:
    print(f"Warning: {err}")

# Patch autocast
class dummy_autocast:
    """Do-nothing stand-in for autocast, usable as a context manager or a decorator."""

    def __init__(self, *args, **kwargs):
        # Accept and discard any constructor arguments.
        pass

    def __enter__(self):
        return None

    def __exit__(self, *exc_info):
        # Returning None means exceptions raised inside the block propagate.
        return None

    def __call__(self, f):
        # Decorator form: hand the function back untouched.
        return f


# Install the stand-in at each autocast entry point that exists on this build.
torch.autocast = dummy_autocast
if hasattr(torch, 'amp'):
    torch.amp.autocast = dummy_autocast
if hasattr(torch, 'cuda'):
    if hasattr(torch.cuda, 'amp'):
        torch.cuda.amp.autocast = dummy_autocast

from ai_edge_torch.generative.quantize import quant_recipes, quant_recipe_utils
from tqdm import tqdm

# Patch Gemma 3
# Monkey-patches transformers with simplified stand-ins: a fixed causal mask,
# tuple-returning output "classes", and a rotary-embedding forward written with
# explicit float casts. Any failure in this block is only printed as a warning.
try:
    from transformers import masking_utils
    # Replace the mask-function registry's mapping (when the registry exists)
    # with a plain dict that maps each implementation name to itself.
    # NOTE(review): the values are strings, not callables - presumably nothing
    # looks them up once the mask builders below are replaced; verify.
    simple_mapping = {"eager": "eager", "sdpa": "sdpa", "flash_attention_2": "flash_attention_2"}
    if hasattr(masking_utils, "ALL_MASK_ATTENTION_FUNCTIONS"):
        masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping = simple_mapping
    # Stand-in for both mask builders: returns a lower-triangular boolean mask
    # on CPU with shape (1, 1, seq_len, seq_len), where seq_len is the last
    # dimension of `input_ids`, or 1 when `input_ids` is None. All other
    # positional and keyword arguments are ignored.
    # NOTE(review): if transformers calls the mask builders with keyword
    # arguments only (e.g. input_embeds=...), `input_ids` stays None and the
    # mask collapses to 1x1. The "eager" attention path may also expect an
    # additive float mask rather than a boolean one. Confirm both against the
    # installed transformers version.
    def dummy_create_causal_mask(input_ids=None, *args, **kwargs):
        device = torch.device('cpu')
        seq_len = input_ids.shape[-1] if input_ids is not None else 1
        return torch.tril(torch.ones((seq_len, seq_len), device=device, dtype=torch.bool))[None, None, :, :]
    masking_utils.create_causal_mask = dummy_create_causal_mask
    masking_utils.create_sliding_window_causal_mask = dummy_create_causal_mask
    # Imported only after the mask builders were swapped - presumably so that
    # modeling_gemma3 binds the replacements when it imports them; verify.
    import transformers.models.gemma3.modeling_gemma3 as g3_mod
    # Replacement for the model-output classes: returns a 1-tuple holding the
    # `logits` keyword if it was passed, otherwise `last_hidden_state`; if that
    # value is None, the first keyword value is used instead. Returns an empty
    # tuple when no non-None value is found.
    def TraceableOutput(**kwargs):
        val = kwargs.get('logits', kwargs.get('last_hidden_state'))
        if val is None and kwargs:
            val = next(iter(kwargs.values()))
        return (val,) if val is not None else ()
    g3_mod.BaseModelOutputWithPast = TraceableOutput
    g3_mod.Gemma3ModelOutputWithPast = TraceableOutput
    g3_mod.Gemma3CausalLMOutputWithPast = TraceableOutput
    # Replacement rotary-embedding forward. Multiplies `inv_freq` (expanded to
    # one copy per row of `position_ids`) by the positions in float32,
    # duplicates the result along the last dimension, and returns (cos, sin)
    # scaled by `attention_scaling` and cast back to x's dtype.
    # Assumes `inv_freq` is 1-D and `position_ids` is 2-D - TODO confirm.
    def patched_rope_forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos() * self.attention_scaling
        sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
    g3_mod.Gemma3RotaryEmbedding.forward = patched_rope_forward
    print("‚úÖ Gemma 3 patches applied")
except Exception as e:
    print(f"Warning: {e}")

class LanguageModelWrapper(torch.nn.Module):
    """Expose a model as a single-input module that maps ``input_ids`` to one output.

    If the given model has a ``language_model`` attribute, that sub-module is
    wrapped; otherwise the model itself is wrapped.
    """

    def __init__(self, model):
        super().__init__()
        self.model = getattr(model, 'language_model', model)

    def forward(self, input_ids):
        """Run the wrapped model without cache and return its first output.

        A non-empty list/tuple result is reduced to its first element; any
        other result (including an empty sequence) is returned unchanged.
        """
        result = self.model(input_ids=input_ids, use_cache=False)
        if not isinstance(result, (list, tuple)):
            return result
        return result[0] if len(result) > 0 else result

# Reached only if everything above in this cell ran without an uncaught error.
print("‚úÖ Conversion script ready")

In [None]:
# Step 4: Run Conversion
# Loads the checkpoint, wraps its language model, picks a quantization recipe,
# converts with ai_edge_torch and writes the TFLite file to OUTPUT_PATH.
MODEL_ID = "google/medgemma-1.5-4b-it"
# Named after the INT8 recipe selected below (the FP16 fallback reuses this path).
OUTPUT_PATH = "/content/medgemma_int8.tflite"

print(f"üöÄ Converting {MODEL_ID}...")

# The context manager closes the progress bar even when a step raises.
with tqdm(total=100, desc="Progress") as pbar:
    # Load model
    pbar.set_description("Loading model")
    pbar.update(5)
    # NOTE(review): trust_remote_code=True allows code shipped in the model
    # repository to run locally - confirm it is required for this checkpoint.
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        attn_implementation="eager"
    )
    model.eval()
    pbar.update(10)

    # Wrap model
    lm_wrapper = LanguageModelWrapper(model)
    # Example input used for tracing: one sequence of 64 random token ids.
    dummy_input = torch.randint(0, 256000, (1, 64))
    pbar.update(5)

    # Configure quantization
    pbar.set_description("Configuring quantization")
    try:
        quant_cfg = quant_recipes.full_linear_int8_dynamic_recipe()
    except Exception as err:
        # Best-effort fallback. Catching Exception (not a bare except) lets
        # KeyboardInterrupt/SystemExit through, and the reason is reported.
        print(f"\nINT8 recipe unavailable: {err!r}")
        quant_cfg = quant_recipe_utils.create_config_from_recipe(
            quant_recipe_utils.create_layer_quant_fp16()
        )
        print("\n‚ö†Ô∏è Falling back to FP16")
    else:
        print("\n‚úÖ Using INT8 quantization")
    pbar.update(5)

    # Convert
    pbar.set_description("Converting to TFLite")
    with torch.no_grad():
        edge_model = ai_edge_torch.convert(
            lm_wrapper,
            (dummy_input,),
            quant_config=quant_cfg
        )
    pbar.update(60)

    # Export
    pbar.set_description("Exporting")
    edge_model.export(OUTPUT_PATH)
    pbar.update(15)

size_gb = os.path.getsize(OUTPUT_PATH) / (1024**3)
print("\n‚úÖ Conversion complete!")
print(f"üìÅ Output: {OUTPUT_PATH}")
print(f"üì¶ Size: {size_gb:.2f} GB")

In [None]:
# Step 5: Download the converted model
# Sends the file at OUTPUT_PATH (written in Step 4) to the browser.
# The google.colab module is only available inside Google Colab.
from google.colab import files
files.download(OUTPUT_PATH)