In [1]:
#@title Gemma 3-4B-it installation

# IMPORTANT: set env vars BEFORE importing torch_xla/torch for them to take effect
import os
os.environ.setdefault("PJRT_DEVICE", "TPU")
os.environ.setdefault("XLA_USE_BF16", "1")
# avoid aggressive preallocation of HBM on XLA
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.99")

import gc
import warnings
warnings.filterwarnings("ignore", message=".*deprecated.*will be removed in Transformers v5.*")

import torch
import torch_xla.core.xla_model as xm
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

model_id = "google/gemma-3-4b-it"

print("XLA devices:", xm.get_xla_supported_devices())
device = xm.xla_device()
print("Using device:", device)

# ---------------------------
# Load tokenizer & model
# ---------------------------
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

print("Loading model (this can take a while)...")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="cpu",
)

# Safety: disable HF caching logic that triggers cache_utils.update() slicing on XLA
model.config.use_cache = False
try:
    model.generation_config.use_cache = False
except Exception:
    pass

# Make sliding_window safe (1 is conservative)
try:
    model.config.sliding_window = 1
    model.generation_config.sliding_window = 1
except Exception:
    pass

# eager attention implementation if available for TPU if available
try:
    model.config.attn_implementation = "eager"
except Exception:
    pass

# Move model to XLA device and eval
model.to(device)
model.eval()

# ---------------------------
# Quick sanity checks
# ---------------------------
first_param = next(model.parameters())
print("Param device:", first_param.device)
print("Param dtype:", first_param.dtype)
print("Model num params (approx):", sum(p.numel() for p in model.parameters()))

# Prepare tiny input
prompt = "Hello"
inputs = tokenizer(prompt, return_tensors="pt")
print("Input ids shape (cpu):", inputs["input_ids"].shape)

# Move inputs to device
inputs = {k: v.to(device) for k, v in inputs.items()}
print("Input device after move:", inputs["input_ids"].device)

# 1) Quick forward to check forward pass (explicitly disable cache)
try:
    print("Running single forward (no generate), passing use_cache=False ...")
    with torch.no_grad():
        out = model(**inputs, use_cache=False)
    xm.mark_step()
    logits = out.logits if hasattr(out, "logits") else out[0]
    print("Forward succeeded. Logits shape:", logits.shape)
except Exception as e:
    print("Forward FAILED:", repr(e))
    # try to free memory and re-raise to surface the stacktrace
    gc.collect()
    xm.mark_step()
    raise

# 2) Minimal generate for test (use_cache=False)
try:
    print("Starting minimal generate with use_cache=False, max_new_tokens=8 ...")
    with torch.no_grad():
        outids = model.generate(
            **inputs,
            max_new_tokens=8,
            do_sample=False,
            num_beams=1,
            use_cache=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=False,
        )
    xm.mark_step()
    print("Decode:", tokenizer.decode(outids[0], skip_special_tokens=True))
except Exception as e:
    print("Generate FAILED or blocked:", repr(e))
    gc.collect()
    xm.mark_step()
    raise

print("Done small test.")

  * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or


XLA devices: ['xla:0']
Using device: xla:0


  device = xm.xla_device()


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

Loading model (this can take a while)...


config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]



model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

Param device: xla:0
Param dtype: torch.bfloat16
Model num params (approx): 4971331952
Input ids shape (cpu): torch.Size([1, 2])
Input device after move: xla:0
Running single forward (no generate), passing use_cache=False ...


  xm.mark_step()
The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Forward succeeded. Logits shape: torch.Size([1, 2, 262208])
Starting minimal generate with use_cache=False, max_new_tokens=8 ...


  xm.mark_step()


Decode: Hello 2024!

I
Done small test.


In [2]:
#@title Gemma 3-4B-it generation speed on TPU v5e 1

import os, time, gc, statistics
import torch
import torch_xla.core.xla_model as xm
from transformers import AutoModelForCausalLM, AutoTokenizer

# ---------- Config ----------
MODEL_ID = "google/gemma-3-4b-it"   # used only if model/tokenizer/device are not present
REPEATS = 2                         # repeat measurements and report average+std
WARMUP_NEW_TOKENS = 1               # short warm-up generation
USE_CACHE_FOR_MEASUREMENT = False   # default False to avoid XLA cache slicing issues
VERBOSE = True
# ----------------------------

# --- New: explicit prompts requested by the user
PROMPTS = [
    "Hello, how are you",
    "Hello. What do you think about Greek philosophy?"
]

# Ensure model/tokenizer/device available (use existing if in globals)
globals_ref = globals()
if "model" in globals_ref and "tokenizer" in globals_ref and "device" in globals_ref:
    model = globals_ref["model"]
    tokenizer = globals_ref["tokenizer"]
    device = globals_ref["device"]
    if VERBOSE:
        print("Using existing model/tokenizer/device from session.")
else:
    if VERBOSE:
        print("Loading model/tokenizer (this will take time)...")
    # ensure XLA envs set externally earlier in session; just load
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map="cpu",
    )
    # Make safe for XLA
    model.config.use_cache = False
    try:
        model.generation_config.use_cache = False
    except Exception:
        pass
    try:
        model.config.sliding_window = 1
        model.generation_config.sliding_window = 1
    except Exception:
        pass
    device = xm.xla_device()
    model.to(device)
    model.eval()

# Safety: ensure model on device
first_param = next(model.parameters())
if first_param.device != device:
    model.to(device)

print("Model device:", next(model.parameters()).device, "dtype:", next(model.parameters()).dtype)

results = []

for prompt in PROMPTS:
    # tokenize to find actual prompt token length
    input_enc = tokenizer(prompt, return_tensors="pt")
    input_ids = input_enc["input_ids"]
    actual_prompt_tokens = int(input_ids.shape[-1])

    if VERBOSE:
        print(f"\n=== Experiment: prompt (text)='{prompt}'")
        print(f"    actual prompt tokens={actual_prompt_tokens}")

    # prepare inputs on device
    inputs = {k: v.to(device) for k, v in input_enc.items()}
    input_len = int(inputs["input_ids"].shape[-1])

    # warm-up (short generation) to reduce first-call compilation overhead
    try:
        if VERBOSE:
            print("Warm-up generate (short)...")
        with torch.no_grad():
            _ = model.generate(
                **inputs,
                max_new_tokens=WARMUP_NEW_TOKENS,
                do_sample=False,
                num_beams=1,
                use_cache=False,   # keep warm-up safe
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                return_dict_in_generate=False,
            )
        xm.mark_step()
    except Exception as e:
        print("Warm-up failed:", e)
        # continue — we will still attempt timed runs

    # measurement runs
    gen_new_tokens = actual_prompt_tokens  # mirror original behavior: generate same number as input tokens
    times = []
    for rep in range(REPEATS):
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
        t0 = time.perf_counter()
        try:
            with torch.no_grad():
                out = model.generate(
                    **inputs,
                    max_new_tokens=gen_new_tokens,
                    do_sample=False,
                    num_beams=1,
                    use_cache=USE_CACHE_FOR_MEASUREMENT,  # False by default for TPU safety
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    return_dict_in_generate=False,
                )
            xm.mark_step()
        except Exception as e:
            # if a runtime error occurs (e.g. cache issue), fallback to safe mode (use_cache=False)
            if USE_CACHE_FOR_MEASUREMENT:
                print("Runtime error with use_cache=True; falling back to use_cache=False for this run:", repr(e))
                with torch.no_grad():
                    out = model.generate(
                        **inputs,
                        max_new_tokens=gen_new_tokens,
                        do_sample=False,
                        num_beams=1,
                        use_cache=False,
                        pad_token_id=tokenizer.pad_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                        return_dict_in_generate=False,
                    )
                xm.mark_step()
            else:
                raise
        t1 = time.perf_counter()

        # compute number of output tokens (difference between output length and input length)
        out_len = int(out.shape[-1])
        generated = max(0, out_len - input_len)
        elapsed = t1 - t0
        tps = generated / elapsed if elapsed > 0 else float("inf")
        times.append(tps)
        if VERBOSE:
            print(f"Run {rep+1}/{REPEATS}: generated={generated} tokens, elapsed={elapsed:.3f}s, tps={tps:.2f}")

    mean_tps = statistics.mean(times)
    std_tps = statistics.stdev(times) if len(times) > 1 else 0.0

    results.append({
        "prompt_text": prompt,
        "prompt_actual_tokens": actual_prompt_tokens,
        "generated_tokens_requested": gen_new_tokens,
        "mean_tps": mean_tps,
        "std_tps": std_tps,
        "raw_tps": times,
    })

# Summary table
print("\n=== SUMMARY (tokens/sec measured for GENERATED tokens only) ===")
for r in results:
    print(f"Prompt (text)='{r['prompt_text']}' "
          f"(actual {r['prompt_actual_tokens']} tokens), "
          f"generated={r['generated_tokens_requested']} tokens -> "
          f"{r['mean_tps']:.2f} ± {r['std_tps']:.2f} tokens/sec (repeats={len(r['raw_tps'])})")

Using existing model/tokenizer/device from session.
Model device: xla:0 dtype: torch.bfloat16

=== Experiment: prompt (text)='Hello, how are you'
    actual prompt tokens=6
Warm-up generate (short)...


  xm.mark_step()
  xm.mark_step()


Run 1/2: generated=6 tokens, elapsed=199.303s, tps=0.03
Run 2/2: generated=6 tokens, elapsed=1.809s, tps=3.32

=== Experiment: prompt (text)='Hello. What do you think about Greek philosophy?'
    actual prompt tokens=11
Warm-up generate (short)...
Run 1/2: generated=11 tokens, elapsed=647.963s, tps=0.02
Run 2/2: generated=11 tokens, elapsed=3.796s, tps=2.90

=== SUMMARY (tokens/sec measured for GENERATED tokens only) ===
Prompt (text)='Hello, how are you' (actual 6 tokens), generated=6 tokens -> 1.67 ± 2.32 tokens/sec (repeats=2)
Prompt (text)='Hello. What do you think about Greek philosophy?' (actual 11 tokens), generated=11 tokens -> 1.46 ± 2.04 tokens/sec (repeats=2)


In [3]:
first_param = next(model.parameters())
print("First param device:", first_param.device)
print("First param dtype:", first_param.dtype)
# compter params par dtype
from collections import Counter
dtypes = Counter([p.dtype for p in model.parameters()])
print("Param dtypes distribution:", dtypes)

First param device: xla:0
First param dtype: torch.bfloat16
Param dtypes distribution: Counter({torch.bfloat16: 884})


In [4]:
for n, m in model.named_modules():
     print(n, type(m).__name__)

 Gemma3ForConditionalGeneration
model Gemma3Model
model.vision_tower SiglipVisionModel
model.vision_tower.vision_model SiglipVisionTransformer
model.vision_tower.vision_model.embeddings SiglipVisionEmbeddings
model.vision_tower.vision_model.embeddings.patch_embedding Conv2d
model.vision_tower.vision_model.embeddings.position_embedding Embedding
model.vision_tower.vision_model.encoder SiglipEncoder
model.vision_tower.vision_model.encoder.layers ModuleList
model.vision_tower.vision_model.encoder.layers.0 SiglipEncoderLayer
model.vision_tower.vision_model.encoder.layers.0.layer_norm1 LayerNorm
model.vision_tower.vision_model.encoder.layers.0.self_attn SiglipAttention
model.vision_tower.vision_model.encoder.layers.0.self_attn.k_proj Linear
model.vision_tower.vision_model.encoder.layers.0.self_attn.v_proj Linear
model.vision_tower.vision_model.encoder.layers.0.self_attn.q_proj Linear
model.vision_tower.vision_model.encoder.layers.0.self_attn.out_proj Linear
model.vision_tower.vision_model.e