# BioBART v2 Large

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

## Load Model

In [None]:
# Hugging Face Hub id of the BioBART v2 (large) checkpoint used in this section.
MODEL_ID = "GanjinZero/biobart-v2-large"

# Request the fast tokenizer implementation for this checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# "auto" for both options: device placement and dtype are left to the library
# rather than being pinned in the notebook.
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)

## Some Medical Text

In [None]:
# Sample medical abstract (with its keyword line) used below to try out the
# BioBART summarizer on a single input before running it over the CSV.
medical_text = """
Non-steroidal anti-inflammatory drugs are not only potent analgesics and antipyretics but also nephrotoxins, and may cause 
electrolyte disarray. In addition to the commonly expected effects, including hyperkalemia, hyponatremia, acute renal injury, 
renal cortical necrosis, and volume retention, glomerular disease with or without nephrotic syndrome or nephritis can occur as 
well including after years of seemingly safe administration. Minimal change disease, secondary membranous glomerulonephritis, 
and acute interstitial nephritis are all reported glomerular lesions seen with non-steroidal anti-inflammatory use. We report a 
patient who used non-steroidal anti-inflammatory drugs for years without diabetes, chronic kidney disease, or proteinuria; he 
then developed severe nephrotic range proteinuria with 7 g of daily urinary protein excretion. Renal biopsy showed minimal 
change nephropathy, a likely secondary membranous glomerulonephritis, and acute interstitial nephritis present simultaneously
in one biopsy. Cessation of non-steroidal anti-inflammatory drug use along with steroid treatment resulted in a moderate
improvement in renal function, though residual impairment remained. Urine heavy metal screen returned with elevated levels of
urine copper, but with normal ceruloplasmin level. Workup suggested that the elevated copper levels were due to cirrhosis from
non-alcoholic fatty liver disease. The membranous glomerulonephritis is possibly linked to non-steroidal anti-inflammatory drug
exposure, and possibly to heavy metal exposure, and is clinically and pathologically much less likely to be a primary membranous
glomerulonephritis with negative serological markers.

Keywords: Minimal change disease, podocytopathy, secondary membranous glomerulonephritis, acute interstitial nephritis, non-steroidal anti-inflammatory drugs
"""

## Inference

In [None]:
def summarize(medical_text):
    """Summarize a piece of medical text with the BioBART seq2seq model.

    Relies on the module-level ``tokenizer`` and ``model`` loaded earlier in
    this section. The stripped text is truncated to its first 1024 tokens and
    decoded with 4-beam search for at most 40 new tokens.

    Args:
        medical_text: Raw text to summarize.

    Returns:
        The decoded summary with special tokens removed, or ``""`` when the
        input is blank or anything goes wrong. The function is deliberately
        best-effort so that one bad row does not abort a ``Series.apply`` run.
    """
    import logging  # local import keeps this notebook cell self-contained

    try:
        text = medical_text.strip()
        if not text:
            # Nothing to summarize: skip the (expensive) model call.
            return ""

        inputs = tokenizer(
            text,
            return_tensors="pt",
            max_length=1024,
            truncation=True,
        ).to(model.device)

        with torch.no_grad():
            generated = model.generate(
                **inputs,
                max_new_tokens=40,
                num_beams=4,
                length_penalty=1.5,
                early_stopping=False,
                no_repeat_ngram_size=3,
                encoder_no_repeat_ngram_size=3,
            )

        return tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception:
        # Previously a bare `except:`, which also swallowed KeyboardInterrupt /
        # SystemExit (so a long batch run could not be interrupted) and hid
        # every failure. Keep the best-effort "" result, but leave a trace.
        logging.getLogger(__name__).warning(
            "BioBART summarization failed; returning an empty summary",
            exc_info=True,
        )
        return ""

In [None]:
# Try the BioBART summarizer on the sample abstract and print the result.
summary = summarize(medical_text)
print(summary)

In [None]:
import pandas as pd
import csv
# Read the mtsamples CSV and cast both free-text columns to `str` in one step.
# NOTE: casting with `str` also turns missing values into the literal "nan".
df = pd.read_csv('./data/mtsamples.csv').astype(
    {'transcription': str, 'description': str}
)

In [None]:
# Summarize every transcription (one `summarize` call per row) and record
# which model produced the summaries.
df = df.assign(**{
    'model-summary': df['transcription'].apply(summarize),
    'model-name': 'biobart',
})

In [None]:
# Write the augmented table without the index column; QUOTE_NONNUMERIC
# quotes every non-numeric field.
df.to_csv(
    './data/mtsamples_with_biobart.csv',
    quoting=csv.QUOTE_NONNUMERIC,
    index=False,
)

# Med-Gemma 4b

In [None]:
import torch, textwrap
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

In [None]:
# Matmul precision setting for float32 operations ("high" trades a little
# precision for speed on supporting hardware).
torch.set_float32_matmul_precision('high')

MODEL_ID = "google/medgemma-4b-it"

# `return_tensors`, `padding` and `truncation` are options of the tokenizer
# *call*; passing them to `from_pretrained` had no effect, so they are dropped.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# NOTE(review): `trust_remote_code=True` allows code shipped in the Hub repo to
# run locally; kept as in the original, confirm it is actually required.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # Keep the original placement (everything on GPU index 1) when a second
    # GPU exists; otherwise fall back to automatic placement so the notebook
    # still loads on single-GPU or CPU-only machines.
    device_map={"": 1} if torch.cuda.device_count() > 1 else "auto",
    torch_dtype="auto",
    trust_remote_code=True,
)
model.config.pad_token_id = tokenizer.pad_token_id

# Compile the model; `dynamic=True` because prompt lengths vary per call.
model = torch.compile(
    model,
    mode="reduce-overhead",
    fullgraph=False,
    dynamic=True,
)
model.eval()

In [None]:
# Sample medical abstract excerpt used below to try out the Med-Gemma
# summarizer on a single input before running it over the CSV.
medical_text = """
Non-steroidal anti-inflammatory drugs are not only potent analgesics and antipyretics but also nephrotoxins, and may cause 
electrolyte disarray. In addition to the commonly expected effects, including hyperkalemia, hyponatremia, acute renal injury, 
renal cortical necrosis, and volume retention, glomerular disease with or without nephrotic syndrome or nephritis can occur as 
well including after years of seemingly safe administration. Minimal change disease, secondary membranous glomerulonephritis, 
and acute interstitial nephritis are all reported glomerular lesions seen with non-steroidal anti-inflammatory use. We report a 
patient who used non-steroidal anti-inflammatory drugs for years without diabetes, chronic kidney disease, or proteinuria; he 
then developed severe nephrotic range proteinuria with 7 g of daily urinary protein excretion. Renal biopsy showed minimal 
change nephropathy, a likely secondary membranous glomerulonephritis, and acute interstitial nephritis present simultaneously
in one biopsy. 

"""

In [None]:
# Decoding settings shared by every Med-Gemma `summarize` call: low-temperature
# nucleus sampling with repetition controls, capped at 128 new tokens.
gen_cfg = GenerationConfig(
    do_sample=True,
    temperature=0.1,
    top_p=0.9,
    repetition_penalty=1.1,
    no_repeat_ngram_size=6,
    max_new_tokens=128,
)

def summarize(medical_text, max_document_tokens=900):
    """Ask Med-Gemma for a ~20-word summary of a medical text.

    Relies on the module-level ``tokenizer``, ``model`` and ``gen_cfg``.

    Args:
        medical_text: Raw text to summarize.
        max_document_tokens: Upper bound on the number of tokens of
            ``medical_text`` placed in the prompt. Longer inputs are cut at
            this length so that the task instruction and the generation
            prompt that follow the document are never truncated away. The
            default keeps the whole prompt close to the previous 1024-token
            limit.

    Returns:
        The model's reply only (prompt excluded), stripped of surrounding
        whitespace.
    """
    # Dedent the template BEFORE inserting the document. Dedenting after
    # interpolation is a no-op for multi-line documents (their continuation
    # lines start at column 0), which left the instructions indented.
    template = textwrap.dedent("""
        Below is an abstract from a medical paper.

        ```text
        {document}
        ```

        **Task:** Produce a 20-word summary **and end with a full stop (.) when you are done.**
        Use clear, professional medical language.
        Don't include a greeting or introduction.
    """)

    document = medical_text.strip()

    # Truncate the document itself rather than the finished chat prompt:
    # right-truncating the whole prompt dropped the task instruction and the
    # generation prompt for long documents.
    document_ids = tokenizer.encode(document, add_special_tokens=False)
    if len(document_ids) > max_document_tokens:
        document = tokenizer.decode(
            document_ids[:max_document_tokens],
            skip_special_tokens=True,
        )

    messages = [
        {
            "role": "user",
            "content": template.format(document=document),
        }
    ]

    # `return_dict=True` also yields the attention mask for `generate`.
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            generation_config=gen_cfg,
            return_dict_in_generate=False,
        )

    # Decode only the newly generated tokens instead of splitting the decoded
    # text on a '\nmodel\n' marker, which broke whenever that marker was
    # missing from (or repeated in) the decoded string.
    prompt_length = encoded["input_ids"].shape[-1]
    reply_ids = generated[0][prompt_length:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()

In [None]:
# Try the Med-Gemma summarizer on the sample abstract; as the cell's last
# expression, the returned string is what the notebook shows as output.
summarize(medical_text)

In [None]:
import pandas as pd
import csv
# Reload the mtsamples CSV for the Med-Gemma run and cast both free-text
# columns to `str` (missing values become the literal "nan").
df = pd.read_csv('./data/mtsamples.csv').astype(
    {'transcription': str, 'description': str}
)

In [None]:
# Summarize every transcription with Med-Gemma (one `summarize` call per row)
# and record which model produced the summaries.
df = df.assign(**{
    'model-summary': df['transcription'].apply(summarize),
    'model-name': 'med-gemma',
})

In [None]:
# Write the Med-Gemma results without the index column; QUOTE_NONNUMERIC
# quotes every non-numeric field.
df.to_csv(
    './data/mtsamples_with_gemma.csv',
    quoting=csv.QUOTE_NONNUMERIC,
    index=False,
)

# Med-Llama 3 8B

In [None]:
import torch, textwrap
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

## Load model

In [7]:
MODEL_ID = "johnsnowlabs/JSL-MedLlama-3-8B-v2.0"

# `return_tensors`, `padding` and `truncation` are options of the tokenizer
# *call*; passing them to `from_pretrained` had no effect, so they are dropped.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# NOTE(review): `trust_remote_code=True` allows code shipped in the Hub repo to
# run locally; kept as in the original, confirm it is actually required.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    # Keep the original placement (everything on GPU index 1) when a second
    # GPU exists; otherwise fall back to automatic placement so the notebook
    # still loads on single-GPU or CPU-only machines.
    device_map={"": 1} if torch.cuda.device_count() > 1 else "auto",
    trust_remote_code=True,
)

# Reuse EOS as the padding token and propagate its id to both the model
# config and the generation config.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id

# Allow TF32 for CUDA matmuls and cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Compile the model; `dynamic=True` because prompt lengths vary per call.
model = torch.compile(
    model,
    mode="reduce-overhead",
    fullgraph=False,
    dynamic=True,
)
model.eval()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
W0806 21:41:56.799000 67140 site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
Downloading shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [02:50<00:00, 85.23s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:17<00:00,  8.88s/it]


OptimizedModule(
  (_orig_mod): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm()
          (post_attention_layernorm): L

## Some medical text

In [None]:
# Sample medical abstract excerpt used below to try out the Med-Llama
# summarizer on a single input before running it over the CSV.
medical_text = """
Non-steroidal anti-inflammatory drugs are not only potent analgesics and antipyretics but also nephrotoxins, and may cause 
electrolyte disarray. In addition to the commonly expected effects, including hyperkalemia, hyponatremia, acute renal injury, 
renal cortical necrosis, and volume retention, glomerular disease with or without nephrotic syndrome or nephritis can occur as 
well including after years of seemingly safe administration. Minimal change disease, secondary membranous glomerulonephritis, 
and acute interstitial nephritis are all reported glomerular lesions seen with non-steroidal anti-inflammatory use. We report a 
patient who used non-steroidal anti-inflammatory drugs for years without diabetes, chronic kidney disease, or proteinuria; he 
then developed severe nephrotic range proteinuria with 7 g of daily urinary protein excretion. Renal biopsy showed minimal 
change nephropathy, a likely secondary membranous glomerulonephritis, and acute interstitial nephritis present simultaneously
in one biopsy. 

"""

In [None]:
# Decoding settings shared by every Med-Llama `summarize` call: low-temperature
# nucleus sampling with repetition controls, capped at 64 new tokens.
gen_cfg = GenerationConfig(
    do_sample=True,
    temperature=0.1,
    top_p=0.9,
    repetition_penalty=1.1,
    no_repeat_ngram_size=6,
    max_new_tokens=64,
)

# Fallback only: when the tokenizer ships without a chat template, install a
# Jinja template that emits a fixed system turn, then every message wrapped in
# <|im_start|>{role} ... <|im_end|>, and finally (when a generation prompt is
# requested) an opened assistant turn. A template already present on the
# tokenizer is left untouched.
# NOTE(review): whether this checkpoint was trained on these <|im_start|> /
# <|im_end|> markers is not visible here; confirm against the model card.
if tokenizer.chat_template is None:
    tokenizer.chat_template = textwrap.dedent("""
    <|im_start|>system
    You are a concise, professional medical writing assistant. <|im_end|>
    {% for m in messages %}
    <|im_start|>{{ m['role'] }}
    {{ m['content'] }}<|im_end|>
    {% endfor %}
    {% if add_generation_prompt %}<|im_start|>assistant
    {% endif %}
    """).strip()

def summarize(medical_text):
    """Ask Med-Llama for a ~20-word summary of a medical text.

    Relies on the module-level ``tokenizer``, ``model`` and ``gen_cfg``.

    Args:
        medical_text: Raw text to summarize.

    Returns:
        The model's reply only (prompt excluded) as a single line: literal
        ``<|im_end|>`` markers are removed and whitespace runs are collapsed
        to single spaces.
    """
    # Dedent the template BEFORE inserting the document. Dedenting after
    # interpolation is a no-op for multi-line documents (their continuation
    # lines start at column 0), which left the instructions indented.
    template = textwrap.dedent("""
        Below is an abstract from a medical paper.

        ```text
        {document}
        ```

        **Task:** Produce a 20-word summary **and end with a full stop (.) when you are done.**
        Use clear, professional medical language.
        Don't include a greeting or introduction.
    """)

    messages = [
        {
            "role": "user",
            "content": template.format(document=medical_text.strip()),
        }
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        # Generation length comes from `gen_cfg`; the extra
        # `max_new_tokens=64` kwarg silently overrode that config.
        generated = model.generate(
            **encoded,
            generation_config=gen_cfg,
            return_dict_in_generate=False,
        )

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on 'im_start|>assistant' returned the whole prompt whenever the
    # tokenizer's own chat template (without that marker) was in use.
    prompt_length = encoded["input_ids"].shape[-1]
    reply = tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)
    reply = reply.replace('<|im_end|>', '')

    # Collapse all whitespace into single spaces. This replaces the former
    # `textwrap.fill(..., 90)` + newline replacement, which could also split
    # words longer than 90 characters.
    return " ".join(reply.split())

## Inference

In [None]:
# Try the Med-Llama summarizer on the sample abstract; as the cell's last
# expression, the returned string is what the notebook shows as output.
summarize(medical_text)

## Persistence

In [None]:
import pandas as pd
import csv
# Reload the mtsamples CSV for the Med-Llama run and cast both free-text
# columns to `str` (missing values become the literal "nan").
df = pd.read_csv('./data/mtsamples.csv').astype(
    {'transcription': str, 'description': str}
)

In [None]:
# Summarize every transcription with Med-Llama (one `summarize` call per row)
# and record which model produced the summaries.
df = df.assign(**{
    'model-summary': df['transcription'].apply(summarize),
    'model-name': 'med-llama',
})

In [None]:
# Write the Med-Llama results without the index column; QUOTE_NONNUMERIC
# quotes every non-numeric field.
df.to_csv(
    './data/mtsamples_with_llama.csv',
    quoting=csv.QUOTE_NONNUMERIC,
    index=False,
)