In [None]:
import csv
import warnings

import pandas as pd
# Load the MTSamples corpus once; every model section below copies df_base
# rather than mutating it.
# Missing cells are mapped to '' before the string cast: astype(str) alone
# turns NaN into the literal text "nan", which would then be sent to the
# models as if it were a transcription.
df_base = pd.read_csv('./data/mtsamples.csv')
df_base = df_base.assign(
    transcription = df_base.transcription.fillna('').astype(str),
    description = df_base.description.fillna('').astype(str),
)

# BioBART v2 Large

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

## Load Model

In [None]:
# BioBART v2 (large) checkpoint; dtype and device placement are left to
# transformers ("auto") so the cell works on CPU or GPU.
MODEL_ID = "GanjinZero/biobart-v2-large"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast = True)
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_ID,
    torch_dtype = "auto",
    device_map = "auto",
)

## Inference

In [None]:
def summarize_biobart(medical_text, max_new_tokens = 40, num_beams = 4):
    """Summarize one text with the BioBART seq2seq model.

    Uses the module-level ``tokenizer`` and ``model`` loaded in the BioBART
    "Load Model" cell. The input is truncated to 1024 tokens before beam search.

    NOTE(review): biobart-v2-large looks like a pretrained (not
    summarization-fine-tuned) checkpoint; confirm outputs are real summaries.

    Args:
        medical_text: Source text to summarize.
        max_new_tokens: Upper bound on generated tokens (default 40, as before).
        num_beams: Beam-search width (default 4, as before).

    Returns:
        The decoded summary. Returns "" when the text is empty or whitespace-only,
        or when generation fails; in the failure case a warning is emitted.
    """
    try:
        text = medical_text.strip()
        if not text:
            # Nothing to summarize; skip the model call entirely.
            return ""

        inputs = tokenizer(
            text,
            return_tensors = "pt",
            max_length = 1024,
            truncation = True,
        ).to(model.device)

        with torch.no_grad():
            generated = model.generate(
                **inputs,
                max_new_tokens = max_new_tokens,
                num_beams = num_beams,
                length_penalty = 1.5,
                early_stopping = False,
                no_repeat_ngram_size = 3,
                encoder_no_repeat_ngram_size = 3,
            )

        return tokenizer.decode(generated[0], skip_special_tokens = True)
    except Exception as exc:
        # Best-effort per row so one bad row does not abort the whole
        # DataFrame.apply. The failure is surfaced instead of hidden, and
        # KeyboardInterrupt is no longer swallowed.
        warnings.warn(f"summarize_biobart failed: {exc!r}")
        return ""

In [None]:
# Build a new frame (df_base stays untouched) holding the BioBART summaries
# plus a constant label identifying the model.
df = df_base.assign(**{
    'model-summary': df_base['transcription'].apply(summarize_biobart),
    'model-name': 'biobart',
})

## Persistence

In [None]:
# Quote every non-numeric field so multi-line transcriptions round-trip safely.
df.to_csv(
    './data/mtsamples_with_biobart.csv',
    quoting = csv.QUOTE_NONNUMERIC,
    index = False,
)

# MedGemma 4B

In [None]:
import torch, textwrap
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

## Load Model

In [None]:
torch.set_float32_matmul_precision('high')

MODEL_ID = "google/medgemma-4b-it"

# Preferred GPU for MedGemma. Fall back to automatic placement when the host
# does not have that many CUDA devices; pinning GPU 1 unconditionally fails on
# single-GPU or CPU-only machines.
MEDGEMMA_GPU = 1
medgemma_device_map = (
    {"": MEDGEMMA_GPU} if torch.cuda.device_count() > MEDGEMMA_GPU else "auto"
)

# padding / truncation / return_tensors are per-call tokenizer options, not
# loading options; they are passed where the tokenizer is actually called.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map = medgemma_device_map,
    torch_dtype = "auto",
    # NOTE(review): trust_remote_code runs code from the model repo; confirm it is needed.
    trust_remote_code = True,
)
model.config.pad_token_id = tokenizer.pad_token_id

# NOTE(review): the torch.compile wrapper forwards .generate() to the original
# module, so generation likely does not use the compiled graph. Compiling
# model.forward instead may be required for any speed-up; confirm before relying on it.
model = torch.compile(model, mode = "reduce-overhead",
                      fullgraph = False,
                      dynamic = True)
model.eval()

## Inference

In [None]:
gen_cfg = GenerationConfig(
    max_new_tokens = 128,
    temperature = 0.1,
    top_p = 0.9,
    repetition_penalty = 1.1,
    do_sample = True,
    no_repeat_ngram_size = 6,
)

# User-turn template, dedented *before* the transcription is inserted.
# Dedenting an f-string after interpolation silently does nothing whenever the
# transcription contains unindented continuation lines.
# NOTE(review): the wording says "abstract from a medical paper" although the
# inputs are MTSamples transcriptions; kept verbatim to avoid changing outputs.
MEDGEMMA_PROMPT = textwrap.dedent("""
    Below is an abstract from a medical paper.

    ```text
    {text}
    ```

    **Task:** Produce a 20-word summary **and end with a full stop (.) when you are done.**
    Use clear, professional medical language.
    Don't include a greeting or introduction.
""")


def summarize_medgemma(medical_text, generation_config = gen_cfg, max_input_tokens = 1024):
    """Summarize one transcription with MedGemma through its chat template.

    Uses the module-level ``tokenizer`` and ``model`` from the Med-Gemma
    "Load Model" cell.

    Args:
        medical_text: Source text to summarize.
        generation_config: Decoding settings. Bound to this cell's ``gen_cfg``
            when the function is defined, so a later cell that rebinds
            ``gen_cfg`` does not change MedGemma's decoding.
        max_input_tokens: Token budget for ``medical_text`` alone. Only the
            transcription is truncated; the task instructions and the
            generation prompt are always kept.

    Returns:
        The generated summary (prompt excluded), stripped of surrounding whitespace.
    """
    text = medical_text.strip()

    # Truncate the transcription itself, not the rendered chat prompt.
    # Right-truncating the prompt would drop the task instructions and the
    # model-turn marker for long inputs.
    text_ids = tokenizer(text, add_special_tokens = False)["input_ids"]
    if len(text_ids) > max_input_tokens:
        text = tokenizer.decode(text_ids[:max_input_tokens], skip_special_tokens = True)

    messages = [{"role": "user", "content": MEDGEMMA_PROMPT.format(text = text)}]

    # return_dict=True yields input_ids *and* attention_mask for generate().
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = True,
        tokenize = True,
        return_dict = True,
        return_tensors = "pt",
    ).to(model.device)
    prompt_len = encoded["input_ids"].shape[-1]

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            generation_config = generation_config,
            return_dict_in_generate = False,
        )

    # Decode only the newly generated tokens. Splitting the fully decoded
    # sequence on "\nmodel\n" breaks when that marker is missing or also
    # appears inside the transcription.
    return tokenizer.decode(generated[0, prompt_len:], skip_special_tokens = True).strip()

In [None]:
# Build a new frame (df_base stays untouched) holding the MedGemma summaries
# plus a constant label identifying the model.
df = df_base.assign(**{
    'model-summary': df_base['transcription'].apply(summarize_medgemma),
    'model-name': 'med-gemma',
})

## Persistence

In [None]:
# Quote every non-numeric field so multi-line transcriptions round-trip safely.
df.to_csv(
    './data/mtsamples_with_gemma.csv',
    quoting = csv.QUOTE_NONNUMERIC,
    index = False,
)

# Med-Llama 3 8B

In [None]:
import torch, textwrap
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

## Load Model

In [None]:
MODEL_ID = "johnsnowlabs/JSL-MedLlama-3-8B-v2.0"

# Preferred GPU for Med-Llama. Fall back to automatic placement when the host
# does not have that many CUDA devices; pinning GPU 1 unconditionally fails on
# single-GPU or CPU-only machines.
MEDLLAMA_GPU = 1
medllama_device_map = (
    {"": MEDLLAMA_GPU} if torch.cuda.device_count() > MEDLLAMA_GPU else "auto"
)

# padding / truncation / return_tensors are per-call tokenizer options, not
# loading options; they are passed where the tokenizer is actually called.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype = torch.float16,
    device_map = medllama_device_map,
    # NOTE(review): trust_remote_code runs code from the model repo; confirm it is needed.
    trust_remote_code = True,
)

# Reuse EOS as the pad token only when the tokenizer defines none, so an
# existing pad token is not clobbered.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# NOTE(review): the torch.compile wrapper forwards .generate() to the original
# module, so generation likely does not use the compiled graph. Compiling
# model.forward instead may be required for any speed-up; confirm before relying on it.
model = torch.compile(model, mode = "reduce-overhead",
                      fullgraph = False,
                      dynamic = True)
model.eval()

In [None]:
gen_cfg = GenerationConfig(
    max_new_tokens = 64,
    temperature = 0.1,
    top_p = 0.9,
    repetition_penalty = 1.1,
    do_sample = True,
    no_repeat_ngram_size = 6,
)

# Fallback ChatML-style template, used only when the tokenizer ships without one.
if tokenizer.chat_template is None:
    tokenizer.chat_template = textwrap.dedent("""
    <|im_start|>system
    You are a concise, professional medical writing assistant. <|im_end|>
    {% for m in messages %}
    <|im_start|>{{ m['role'] }}
    {{ m['content'] }}<|im_end|>
    {% endfor %}
    {% if add_generation_prompt %}<|im_start|>assistant
    {% endif %}
    """).strip()

# User-turn template, dedented *before* the transcription is inserted.
# Dedenting an f-string after interpolation silently does nothing whenever the
# transcription contains unindented continuation lines.
# NOTE(review): the wording says "abstract from a medical paper" although the
# inputs are MTSamples transcriptions; kept verbatim to avoid changing outputs.
MEDLLAMA_PROMPT = textwrap.dedent("""
    Below is an abstract from a medical paper.

    ```text
    {text}
    ```

    **Task:** Produce a 20-word summary **and end with a full stop (.) when you are done.**
    Use clear, professional medical language.
    Don't include a greeting or introduction.
""")


def summarize_medllama(medical_text, generation_config = gen_cfg):
    """Summarize one transcription with Med-Llama through its chat template.

    Uses the module-level ``tokenizer`` and ``model`` from the Med-Llama
    "Load Model" cell.

    Args:
        medical_text: Source text to summarize.
        generation_config: Decoding settings. Bound to this cell's ``gen_cfg``
            when the function is defined, so a later cell that rebinds
            ``gen_cfg`` does not change Med-Llama's decoding. It alone controls
            the generation length.

    Returns:
        The generated summary with whitespace collapsed to single spaces and
        anything from the first "<|im_end|>" marker onwards removed.
    """
    messages = [{"role": "user", "content": MEDLLAMA_PROMPT.format(text = medical_text.strip())}]

    prompt = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = True,
        tokenize = False,
    )

    # Templates that already render the BOS token (for example Llama-3-style
    # templates) would receive a second BOS if tokenizer() added specials again.
    bos = tokenizer.bos_token
    add_special = not (bos and prompt.startswith(bos))
    encoded = tokenizer(
        prompt,
        return_tensors = "pt",
        add_special_tokens = add_special,
    ).to(model.device)
    prompt_len = encoded["input_ids"].shape[-1]

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            generation_config = generation_config,
            return_dict_in_generate = False,
        )

    # Decode only the newly generated tokens. Splitting the fully decoded
    # sequence on "im_start|>assistant" returned the whole prompt whenever the
    # tokenizer's own (non-ChatML) template was in use.
    response = tokenizer.decode(generated[0, prompt_len:], skip_special_tokens = True)
    # With the ChatML fallback, "<|im_end|>" may survive decoding as plain text;
    # anything after it is not part of this answer.
    response = response.split('<|im_end|>', 1)[0]
    return " ".join(response.split())

In [None]:
# Build a new frame (df_base stays untouched) holding the Med-Llama summaries
# plus a constant label identifying the model.
df = df_base.assign(**{
    'model-summary': df_base['transcription'].apply(summarize_medllama),
    'model-name': 'med-llama',
})

## Persistence

In [None]:
# Quote every non-numeric field so multi-line transcriptions round-trip safely.
df.to_csv(
    './data/mtsamples_with_llama.csv',
    quoting = csv.QUOTE_NONNUMERIC,
    index = False,
)