In [None]:
import textwrap

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
)

# Load model

In [None]:
# Checkpoint to load and the device it should run on.
MODEL_ID = "Henrychur/MMed-Llama-3-8B-EnIns"    # instruction-tuned
# torch device strings are lowercase: "CPU" is rejected by torch.device().
DEVICE   = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# 8-bit weights are requested through a BitsAndBytesConfig (the bare
# `load_in_8bit=` keyword is deprecated). Quantisation is only enabled when a
# CUDA device is present; on a CPU-only machine the model loads unquantised.
quantization_config = BitsAndBytesConfig(load_in_8bit=True) if DEVICE == "cuda" else None
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",   # let accelerate place the weights on the available device(s)
)

# Sample medical passage to summarise

In [None]:
# Passage that the model is asked to summarise. The triple-quoted literal
# starts with a newline and ends with blank lines; the prompt cell applies
# .strip() before embedding it, so that padding does not reach the model.
medical_text = """
Non-steroidal anti-inflammatory drugs are not only potent analgesics and antipyretics but also nephrotoxins, and may cause 
electrolyte disarray. In addition to the commonly expected effects, including hyperkalemia, hyponatremia, acute renal injury, 
renal cortical necrosis, and volume retention, glomerular disease with or without nephrotic syndrome or nephritis can occur as 
well including after years of seemingly safe administration. Minimal change disease, secondary membranous glomerulonephritis, 
and acute interstitial nephritis are all reported glomerular lesions seen with non-steroidal anti-inflammatory use. We report a 
patient who used non-steroidal anti-inflammatory drugs for years without diabetes, chronic kidney disease, or proteinuria; he 
then developed severe nephrotic range proteinuria with 7 g of daily urinary protein excretion. Renal biopsy showed minimal 
change nephropathy, a likely secondary membranous glomerulonephritis, and acute interstitial nephritis present simultaneously
in one biopsy. 

"""

In [None]:
# Prompt pieces: the passage, the system role, and the user request that
# embeds the whitespace-trimmed passage after the instruction line.
TEXT = medical_text
SYSTEM = "You are a helpful medical AI that produces clear, accurate summaries."
USER = (
    "Summarise the following passage in no more than three sentences:\n"
    + TEXT.strip()
)

# Inference

In [None]:
# Seed for the sampling run so Restart & Run All reproduces the same summary.
SEED = 0

# Assemble the conversation from the SYSTEM / USER strings defined in the
# prompt cell (the notebook previously referenced an undefined `prompt`).
messages = [
    {"role": "system", "content": SYSTEM},
    {"role": "user", "content": USER},
]
if getattr(tokenizer, "chat_template", None):
    # Render with the tokenizer's own template and open the assistant turn.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # If the rendered text already starts with BOS, do not add a second one.
    has_bos = bool(tokenizer.bos_token) and prompt.startswith(tokenizer.bos_token)
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=not has_bos)
else:
    # No chat template shipped with the tokenizer: fall back to plain text.
    prompt = f"{SYSTEM}\n\n{USER}"
    inputs = tokenizer(prompt, return_tensors="pt")
inputs = inputs.to(model.device)
prompt_len = inputs["input_ids"].shape[1]

# Carry the model's stop token(s) into the custom config so generation can
# end before max_new_tokens; reuse EOS for padding when no pad token is set.
eos_token_id = model.generation_config.eos_token_id
if eos_token_id is None:
    eos_token_id = tokenizer.eos_token_id
pad_token_id = tokenizer.pad_token_id
if pad_token_id is None:
    pad_token_id = eos_token_id[0] if isinstance(eos_token_id, (list, tuple)) else eos_token_id

gen_cfg = GenerationConfig(
    max_new_tokens      = 200,
    do_sample           = True,   # temperature / top_p only apply when sampling
    temperature         = 0.7,
    top_p               = 0.9,
    repetition_penalty  = 1.1,
    eos_token_id        = eos_token_id,
    pad_token_id        = pad_token_id,
)

torch.manual_seed(SEED)
with torch.no_grad():
    # Pass the config object itself instead of unpacking every field as a kwarg.
    out_ids = model.generate(**inputs, generation_config=gen_cfg)

# The returned sequence begins with the prompt tokens; decode only what was generated.
summary = tokenizer.decode(out_ids[0, prompt_len:], skip_special_tokens=True).strip()
print(textwrap.fill(summary, 90))