# Summarization

In [2]:
import pandas as pd
import csv
df_base = pd.read_csv('./data/mtsamples.csv')
df_base['transcription'] = df_base.transcription.astype(str)
df_base['description'] = df_base.description.astype(str)

---
## BioBART v2 Large

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

  from .autonotebook import tqdm as notebook_tqdm
2025-08-07 10:18:38.532350: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-07 10:18:38.563076: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-08-07 10:18:38.563101: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-08-07 10:18:38.563943: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-08-07 10:18:38.5

### Load Model

In [4]:
MODEL_ID = "GanjinZero/biobart-v2-large"
tokenizer_biobart = AutoTokenizer.from_pretrained(
    MODEL_ID, 
    use_fast = True)
model_biobart = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_ID, 
    device_map = "auto", 
    torch_dtype = "auto")
model_biobart.eval()

We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.


BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(85401, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(85401, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    

### Inference

In [5]:
def summarize_biobart(medical_text):
    try:
        prompt = f"{medical_text.strip()}"
        inputs = tokenizer_biobart(
            prompt, 
            return_tensors = "pt", 
            max_length = 1024, 
            truncation = True, 
            padding = True
        ).to(model_biobart.device)
        
        with torch.no_grad():
            generated = model_biobart.generate(**inputs,
                max_new_tokens = 40, 
                num_beams = 4, 
                length_penalty = 1.5, 
                early_stopping = False, 
                no_repeat_ngram_size = 3, 
                encoder_no_repeat_ngram_size = 3
              )

        return tokenizer_biobart.decode(generated[0], skip_special_tokens = True)
    except Exception as e:
        return ""

In [6]:
df = df_base.copy(deep = True)
df['model-summary'] = df.transcription.apply(summarize_biobart)
df['model-name'] = 'biobart'

### Persistence

In [7]:
df.to_csv('./data/mtsamples_with_biobart.csv', index = False, quoting = csv.QUOTE_NONNUMERIC)
torch.cuda.empty_cache()

---
## Med-Gemma 4b

In [8]:
import torch, textwrap
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

### Load Model

In [9]:
torch.set_float32_matmul_precision('high')

MODEL_ID = "google/medgemma-4b-it"
tokenizer_gemma = AutoTokenizer.from_pretrained(
    MODEL_ID,
    return_tensors = "pt",
    padding = True,
    truncation = True)
model_gemma = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map = {"": 1},
    torch_dtype = "auto",
    trust_remote_code = True
)
model_gemma.config.pad_token_id = tokenizer_gemma.pad_token_id
model_gemma = torch.compile(model_gemma, mode = "reduce-overhead",
                      fullgraph = False,
                      dynamic = True)
model_gemma.eval()

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.76it/s]


OptimizedModule(
  (_orig_mod): Gemma3ForConditionalGeneration(
    (model): Gemma3Model(
      (vision_tower): SiglipVisionModel(
        (vision_model): SiglipVisionTransformer(
          (embeddings): SiglipVisionEmbeddings(
            (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
            (position_embedding): Embedding(4096, 1152)
          )
          (encoder): SiglipEncoder(
            (layers): ModuleList(
              (0-26): 27 x SiglipEncoderLayer(
                (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
                (self_attn): SiglipAttention(
                  (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
                  (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
                  (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
                  (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
                )
  

### Inference

In [10]:
gen_cfg = GenerationConfig(
    max_new_tokens = 128,
    temperature = 0.1,
    top_p = 0.9,
    repetition_penalty = 1.1,
    do_sample = True,
    no_repeat_ngram_size = 6,
)    

def summarize_medgemma(medical_text):
    messages = [
        {
            "role": "user",
            "content": textwrap.dedent(f"""
                Below is an abstract from a medical paper.
    
                ```text
                {medical_text.strip()}
                ```
    
                **Task:** Produce a 20-word summary **and end with a full stop (.) when you are done.**
                Use clear, professional medical language.
                Don't include a greeting or introduction.
            """),
        }
    ]

    encoded = tokenizer_gemma.apply_chat_template(
        messages,
        add_generation_prompt = True,
        tokenize = True,
        padding = True,
        max_length = 1024,
        truncation = True,
        return_tensors = "pt",
    ).to(model_gemma.device)

    with torch.no_grad():
        generated = model_gemma.generate(
            encoded,  
            generation_config = gen_cfg, 
            return_dict_in_generate = False,
        )
    
    summary = tokenizer_gemma.decode(generated[0], skip_special_tokens = True)
    summary = summary.split('\nmodel\n')[-1]
    
    return summary

In [11]:
df = df_base.copy(deep = True)
df['model-summary'] = df.transcription.apply(summarize_medgemma)
df['model-name'] = 'med-gemma'

`generation_config` default values have been modified to match model-specific defaults: {'pad_token_id': 0, 'bos_token_id': 2, 'eos_token_id': [1, 106]}. If this is not desired, please set these values explicitly.


### Persistence

In [12]:
df.to_csv('./data/mtsamples_with_gemma.csv', index = False, quoting = csv.QUOTE_NONNUMERIC)
torch.cuda.empty_cache()

---
## Med-Llama 3.8b

In [13]:
import torch, textwrap
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

### Load model

In [14]:
MODEL_ID = "johnsnowlabs/JSL-MedLlama-3-8B-v2.0"
tokenizer_llama = AutoTokenizer.from_pretrained(
    MODEL_ID,
    return_tensors="pt",
    padding = True,
    truncation = True)
model_llama = AutoModelForCausalLM.from_pretrained(MODEL_ID, 
                                            torch_dtype = torch.float16,
                                            device_map = {"": 0},
                                            trust_remote_code = True)
tokenizer_llama.pad_token = tokenizer_llama.eos_token
model_llama.config.pad_token_id = tokenizer_llama.pad_token_id
model_llama.generation_config.pad_token_id = tokenizer_llama.pad_token_id

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

model_llama = torch.compile(model_llama, mode = "reduce-overhead",
                      fullgraph = False,
                      dynamic = True)
model_llama.eval()

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.13it/s]


OptimizedModule(
  (_orig_mod): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
   

In [15]:
gen_cfg = GenerationConfig(
    max_new_tokens = 64,
    temperature = 0.1,
    top_p = 0.9,
    repetition_penalty = 1.1,
    do_sample = True,
    no_repeat_ngram_size = 6,
)

if tokenizer_llama.chat_template is None:
    tokenizer_llama.chat_template = textwrap.dedent("""
    <|im_start|>system
    You are a concise, professional medical writing assistant. <|im_end|>
    {% for m in messages %}
    <|im_start|>{{ m['role'] }}
    {{ m['content'] }}<|im_end|>
    {% endfor %}
    {% if add_generation_prompt %}<|im_start|>assistant
    {% endif %}
    """).strip()

def summarize_medllama(medical_text):
    messages = [
        {
            "role": "user",
            "content": textwrap.dedent(f"""
                Below is an abstract from a medical paper.
    
                ```text
                {medical_text.strip()}
                ```
    
                **Task:** Produce a 20-word summary **and end with a full stop (.) when you are done.**
                Use clear, professional medical language.
                Don't include a greeting or introduction.
            """),
        }
    ]
    
    prompt = tokenizer_llama.apply_chat_template(
        messages,
        add_generation_prompt = True,
        tokenize = False
    )

    encoded = tokenizer_llama(
        prompt,
        return_tensors = "pt",
        padding = True).to(model_llama.device)
    
    with torch.no_grad():
        generated = model_llama.generate(
            **encoded,
            generation_config = gen_cfg, 
            return_dict_in_generate = False, 
            max_new_tokens = 64,
        )
    
    summary = tokenizer_llama.decode(generated[0], skip_special_tokens = True)
    response_text = textwrap.fill(summary, 90).split('im_start|>assistant')[-1]
    response_text = response_text.replace('<|im_end|>', '').replace("\n", " ").strip()
    return response_text

In [16]:
df = df_base.copy(deep = True)
df['model-summary'] = df.transcription.apply(summarize_medllama)
df['model-name'] = 'med-llama'

`generation_config` default values have been modified to match model-specific defaults: {'use_cache': False, 'pad_token_id': 128001, 'bos_token_id': 128000, 'eos_token_id': 128001}. If this is not desired, please set these values explicitly.


### Persistence

In [17]:
df.to_csv('./data/mtsamples_with_llama.csv', index = False, quoting = csv.QUOTE_NONNUMERIC)
torch.cuda.empty_cache()