# Summarization with JSL-MedLlama-3-8B

In [None]:
import torch, textwrap
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

### Load model

In [None]:
# Hugging Face Hub checkpoint used for tokenizer and model below.
MODEL_ID = "johnsnowlabs/JSL-MedLlama-3-8B-v2.0"

# `return_tensors`, `padding` and `truncation` are arguments of the tokenizer
# *call*, not of `from_pretrained`; they belong on the call that encodes text.
# Decoder-only models continue from the last position of the input, so any
# padding has to go on the left.
tokenizer_llama = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")

# Reuse EOS as the padding token only when the checkpoint does not define one.
if tokenizer_llama.pad_token is None:
    tokenizer_llama.pad_token = tokenizer_llama.eos_token

# NOTE(review): trust_remote_code lets the model repository execute its own
# Python code at load time. Llama architectures ship with transformers, so the
# flag can probably be dropped -- confirm the repo has no custom modeling code.
model_llama = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map={"": 0},  # place the whole model on device 0
    trust_remote_code=True,
)

# Keep model config and generation config in agreement with the tokenizer so
# generate() does not have to guess the padding id.
model_llama.config.pad_token_id = tokenizer_llama.pad_token_id
model_llama.generation_config.pad_token_id = tokenizer_llama.pad_token_id

# Switch to inference mode (disables dropout etc.) before compiling.
model_llama.eval()

# Allow TF32 kernels for float32 matmuls / cuDNN ops.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Compile with dynamic shapes enabled (prompt lengths vary from call to call);
# fullgraph=False permits graph breaks instead of raising on them.
model_llama = torch.compile(
    model_llama,
    mode="reduce-overhead",
    fullgraph=False,
    dynamic=True,
)

In [None]:
# Sampling settings handed to generate(): low-temperature nucleus sampling,
# capped at 64 new tokens, with penalties against repeated text.
gen_cfg = GenerationConfig(
    do_sample=True,
    temperature=0.1,
    top_p=0.9,
    max_new_tokens=64,
    repetition_penalty=1.1,
    no_repeat_ngram_size=6,
)

# Install a ChatML-style template (fixed system message, then every message,
# then an optional assistant header) only when the tokenizer has no chat
# template of its own.
if tokenizer_llama.chat_template is None:
    tokenizer_llama.chat_template = "\n".join(
        (
            "<|im_start|>system",
            "You are a concise, professional medical writing assistant. <|im_end|>",
            "{% for m in messages %}",
            "<|im_start|>{{ m['role'] }}",
            "{{ m['content'] }}<|im_end|>",
            "{% endfor %}",
            "{% if add_generation_prompt %}<|im_start|>assistant",
            "{% endif %}",
        )
    )

def summarize_medllama(medical_text):
    """Ask the loaded MedLlama model for a 20-word summary of a medical abstract.

    Args:
        medical_text: The abstract to summarize. Leading and trailing
            whitespace is stripped before it is placed in the prompt.

    Returns:
        The text the model generated after the prompt, cut at the first
        ``<|im_end|>`` marker (if any), with every run of whitespace collapsed
        to a single space.
    """
    # Dedent the instruction template *before* inserting the abstract. A
    # multi-line abstract whose lines start at column 0 would otherwise
    # eliminate the common margin, turning dedent into a no-op and leaving
    # every instruction line indented inside the prompt.
    instructions = textwrap.dedent("""
        Below is an abstract from a medical paper.

        ```text
        {abstract}
        ```

        **Task:** Produce a 20-word summary **and end with a full stop (.) when you are done.**
        Use clear, professional medical language.
        Don't include a greeting or introduction.
    """)
    messages = [
        {
            "role": "user",
            "content": instructions.format(abstract=medical_text.strip()),
        }
    ]

    prompt = tokenizer_llama.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )

    encoded = tokenizer_llama(prompt, return_tensors="pt").to(model_llama.device)
    prompt_length = encoded["input_ids"].shape[-1]

    # Length and sampling settings come from gen_cfg alone, so the two cannot
    # silently disagree.
    with torch.no_grad():
        generated = model_llama.generate(
            **encoded,
            generation_config=gen_cfg,
            return_dict_in_generate=False,
        )

    # Decode only the tokens produced after the prompt. This works with any
    # chat template, unlike searching the decoded text for a template marker.
    reply = tokenizer_llama.decode(
        generated[0][prompt_length:],
        skip_special_tokens=True,
    )

    # When the ChatML markers are not registered as special tokens they
    # survive decoding as plain text; keep only what precedes the first
    # end-of-turn marker.
    reply = reply.split("<|im_end|>", 1)[0]

    # Collapse newlines and repeated spaces into single spaces.
    return " ".join(reply.split())

In [8]:
# Smoke test: a short free-text clinical note, passed to the summarizer
# verbatim (misspellings such as "migranes" included).
text = """
This 25 year old white female.  This young woman presents with complaint of migranes.
She had trying many medications.
She has frequent sinus infections.
She has tried many drugs to help the migranes.
"""
summary = summarize_medllama(text)
# Bare expression so the notebook displays the returned summary.
summary

'A 25-year-old white female presents with migraines despite trying various medications and experiencing recurrent sinus infections..'

In [9]:
# Character counts of the input note and of the generated summary.
len(text), len(summary)

(202, 131)