In [8]:
import torch, textwrap
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

In [9]:
torch.set_float32_matmul_precision('high')

MODEL_ID = "google/medgemma-4b-it"
DEVICE   = "cuda" if torch.cuda.is_available() else "CPU"
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    return_tensors = "pt",
    padding = True,
    truncation = True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map = {"": 1},
    torch_dtype = "auto",
    trust_remote_code = True
)
#tokenizer.pad_token = tokenizer.eos_token
#model.config.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

import os
os.environ['TORCH_LOGS'] = "recompiles"

model = torch.compile(model, mode = "reduce-overhead",
                      fullgraph = False,
                      dynamic = True)
model.eval()

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.78it/s]


OptimizedModule(
  (_orig_mod): Gemma3ForConditionalGeneration(
    (model): Gemma3Model(
      (vision_tower): SiglipVisionModel(
        (vision_model): SiglipVisionTransformer(
          (embeddings): SiglipVisionEmbeddings(
            (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
            (position_embedding): Embedding(4096, 1152)
          )
          (encoder): SiglipEncoder(
            (layers): ModuleList(
              (0-26): 27 x SiglipEncoderLayer(
                (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
                (self_attn): SiglipAttention(
                  (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
                  (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
                  (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
                  (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
                )
  

In [10]:
medical_text = """
Non-steroidal anti-inflammatory drugs are not only potent analgesics and antipyretics but also nephrotoxins, and may cause 
electrolyte disarray. In addition to the commonly expected effects, including hyperkalemia, hyponatremia, acute renal injury, 
renal cortical necrosis, and volume retention, glomerular disease with or without nephrotic syndrome or nephritis can occur as 
well including after years of seemingly safe administration. Minimal change disease, secondary membranous glomerulonephritis, 
and acute interstitial nephritis are all reported glomerular lesions seen with non-steroidal anti-inflammatory use. We report a 
patient who used non-steroidal anti-inflammatory drugs for years without diabetes, chronic kidney disease, or proteinuria; he 
then developed severe nephrotic range proteinuria with 7 g of daily urinary protein excretion. Renal biopsy showed minimal 
change nephropathy, a likely secondary membranous glomerulonephritis, and acute interstitial nephritis present simultaneously
in one biopsy. 

"""

In [11]:
gen_cfg = GenerationConfig(
    max_new_tokens = 128,
    temperature = 0.1,
    top_p = 0.9,
    repetition_penalty = 1.1,
    do_sample = True,
    no_repeat_ngram_size = 6,
)    

def summarize(medical_text):
    messages = [
        {
            "role": "user",
            "content": textwrap.dedent(f"""
                Below is an abstract from a medical paper.
    
                ```text
                {medical_text.strip()}
                ```
    
                **Task:** Produce a 20-word summary **and end with a full stop (.) when you are done.**
                Use clear, professional medical language.
                Don't include a greeting or introduction.
            """),
        }
    ]

    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = True,
        tokenize = True,
        padding = True,
        max_length = 1024,
        truncation = True
    )

    batch = tokenizer.pad(
        [{"input_ids": encoded}],
        return_tensors = "pt",
        padding = True).to(model.device)

    with torch.no_grad():
        generated = model.generate(
            input_ids = batch['input_ids'],  
            generation_config = gen_cfg, 
            return_dict_in_generate = False,
            attention_mask  = batch["attention_mask"])
    
    summary = tokenizer.decode(generated[0], skip_special_tokens = True)
    summary = summary.split('\nmodel\n')[-1]
    
    return summary

In [12]:
summarize(medical_text)

You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
V0712 16:20:57.093000 180773 site-packages/torch/_dynamo/guards.py:3006] [0/1] [__recompiles] Recompiling function forward in /home/kilnaar/anaconda3/envs/ai574-pocs/lib/python3.11/site-packages/transformers/models/gemma3/modeling_gemma3.py:1275
V0712 16:20:57.093000 180773 site-packages/torch/_dynamo/guards.py:3006] [0/1] [__recompiles]     triggered by the following guard failure(s):
V0712 16:20:57.093000 180773 site-packages/torch/_dynamo/guards.py:3006] [0/1] [__recompiles]     - 0/0: tensor 'attention_mask' size mismatch at index 3. expected 1151, actual 402


'NSAID use can trigger diverse glomerular diseases, including minimal change disease, membranous glomerulonephritis and acute interstitial nephritis, even in patients without pre-existing conditions.\n'

In [13]:
import pandas as pd
import csv
df = pd.read_csv('./mtsamples.csv')
df['transcription'] = df.transcription.astype(str)
df['description'] = df.description.astype(str)

In [14]:
df['med-gemma-summary'] = df.transcription.apply(summarize)

V0712 16:21:20.517000 180773 site-packages/torch/_dynamo/guards.py:3006] [0/2] [__recompiles] Recompiling function forward in /home/kilnaar/anaconda3/envs/ai574-pocs/lib/python3.11/site-packages/transformers/models/gemma3/modeling_gemma3.py:1275
V0712 16:21:20.517000 180773 site-packages/torch/_dynamo/guards.py:3006] [0/2] [__recompiles]     triggered by the following guard failure(s):
V0712 16:21:20.517000 180773 site-packages/torch/_dynamo/guards.py:3006] [0/2] [__recompiles]     - 0/1: ___check_obj_id(past_key_values.key_cache[0], 136810302758672)
V0712 16:21:20.517000 180773 site-packages/torch/_dynamo/guards.py:3006] [0/2] [__recompiles]     - 0/0: tensor 'attention_mask' size mismatch at index 3. expected 1151, actual 526
V0712 16:21:29.067000 180773 site-packages/torch/_dynamo/guards.py:3006] [0/3] [__recompiles] Recompiling function forward in /home/kilnaar/anaconda3/envs/ai574-pocs/lib/python3.11/site-packages/transformers/models/gemma3/modeling_gemma3.py:1275
V0712 16:21:29.0

In [15]:
df.to_csv('mtsamples_with_gemma.csv', index = False, quoting = csv.QUOTE_NONNUMERIC)