In [7]:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load model
https://huggingface.co/luqh/ClinicalT5-large

In [None]:
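MODEL_ID = "luqh/ClinicalT5-large"
# T5Tokenizer needs the sentencepiece package installed
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID, model_max_length=1024)
# from_flax=True converts the Flax checkpoint to PyTorch on load (requires flax/jax)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID, from_flax=True)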

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
All Flax model weights were used when initializing T5ForConditionalGeneration.

Some weights of T5ForConditionalGeneration were not initialized from the Flax model and are newly initialized: ['decoder.embed_tokens.weight', 'lm_head.weight', 'encoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Some medical text

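Note that this will break down if the input has too many tokens. The tokenizer is configured with `model_max_length = 1024`, and the prompt below is tokenized with `truncation = False`, so an over-long note is not cut off; it is passed to the model in full. We'll have to do some chunking or RAG to deal with the small context window.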
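A minimal sketch of token-based chunking with overlap, assuming fixed-size windows are good enough for these notes (the helper name `chunk_token_ids` and the window sizes used below are illustrative choices, not anything from the model card). It splits a list of token ids into windows that share a few tokens with their neighbour, so a sentence cut at a boundary still appears whole in one of the chunks.

```python
from typing import List


def chunk_token_ids(token_ids: List[int], max_tokens: int, overlap: int = 0) -> List[List[int]]:
    """Split token ids into windows of at most max_tokens ids, each sharing
    `overlap` ids with the previous window (hypothetical helper)."""
    if max_tokens <= 0:
        raise ValueError("max_tokens must be positive")
    if not 0 <= overlap < max_tokens:
        raise ValueError("overlap must be in [0, max_tokens)")

    step = max_tokens - overlap  # how far each new window advances
    chunks: List[List[int]] = []
    for start in range(0, len(token_ids), step):
        chunks.append(token_ids[start:start + max_tokens])
        # Stop once this window already reaches the end of the sequence.
        if start + max_tokens >= len(token_ids):
            break
    return chunks
```

In the notebook this would be applied to the tokenizer's ids, e.g. `ids = tokenizer(medical_text.strip(), add_special_tokens=False)["input_ids"]`, then `[tokenizer.decode(c) for c in chunk_token_ids(ids, 512, 64)]` to get text chunks. The 512/64 sizes are assumptions; leave room in `max_tokens` for the prompt prefix that is prepended to each chunk.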

In [9]:
medical_text = """
After induction of general anesthesia, the patient was placed prone on the operating room table 
resting on chest rolls.  Her face was resting in a pink foam headrest.  Extreme care was taken positioning her because she 
weighs 92 kg.  There was a lot of extra padding for her limbs and her limbs were positioned comfortably.  The arms were not 
hyperextended.  Great care was taken with positioning of the head and making sure there was no pressure on her eyes especially 
since she already has visual disturbance.  A Foley catheter was in place.  She received IV Cipro 400 mg because she is 
allergic to most antibiotics.,Fluoroscopy was used to locate the lower end of the fractured catheter and the skin was marked.  
It was also marked where we would try to insert the new catheter at the L4 or L3 interspinous space.,
"""

In [10]:
len(medical_text)

825
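`len(medical_text)` counts characters, not tokens; the context limit is in tokens, so `len(tokenizer(medical_text)["input_ids"])` is the number to compare against `model_max_length`. The sample text also contains stray `.,` separators, doubled spaces and hard line breaks, and the `.,` reappears in the model's output below. A minimal cleanup sketch, assuming those are extraction artifacts rather than meaningful punctuation (`clean_note` is a hypothetical helper):

```python
import re


def clean_note(text: str) -> str:
    """Normalize a clinical note before prompting (hypothetical helper).

    Assumes '.,' is an extraction artifact standing in for a sentence break.
    """
    text = re.sub(r"\.,", ". ", text)  # '.,' -> ordinary sentence break
    text = re.sub(r"\s+", " ", text)   # collapse newlines and repeated spaces
    return text.strip()
```

It would be used as `prompt = "what procedure was performed in: " + clean_note(medical_text)`. Cleaning changes the model input, so the output shown below (produced from the raw text) would not necessarily be reproduced.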

# Inference

In [11]:
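prompt = "what procedure was performed in: " + medical_text.strip()
# truncation=False: inputs longer than model_max_length are not cut off
inputs = tokenizer(prompt, return_tensors="pt", truncation=False)

with torch.no_grad():
    summary_ids = model.generate(
        **inputs,
        max_new_tokens=20,       # hard cap on generated tokens
        min_length=60,           # longer than the cap allows, so EOS stays blocked for the whole generation
        num_beams=8,
        length_penalty=0.1,      # below the default 1.0: weak length normalization, favors shorter beams
        early_stopping=False,    # stop via a heuristic once better candidates are unlikely
        no_repeat_ngram_size=2,  # no 2-gram may repeat within the output; does not prevent copying the input
    )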
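Two of these settings interact in ways that are easy to miss. For an encoder-decoder model like T5, `min_length` is measured on the decoder sequence, which starts with one decoder-start token, so with `max_new_tokens = 20` that sequence never exceeds 21 tokens and `min_length = 60` can never be satisfied; the end-of-sequence token stays masked and every beam runs to the cap. `length_penalty` is the exponent in the length normalization of finished beams. The sketch below mirrors both rules in plain Python, assuming the Hugging Face beam-search score `sum_logprobs / length ** length_penalty`; the function names and the example log-probabilities are hypothetical, and this is not the library's code.

```python
def max_decoder_len(max_new_tokens: int, start_tokens: int = 1) -> int:
    """Longest decoder sequence reachable under max_new_tokens."""
    return start_tokens + max_new_tokens


def eos_allowed(decoder_len: int, min_length: int) -> bool:
    """Mirror of the min_length rule: EOS is masked while the decoder
    sequence (including the start token) is shorter than min_length."""
    return decoder_len >= min_length


def beam_score(sum_logprobs: float, length: int, length_penalty: float) -> float:
    """Length-normalized score of a finished beam (higher is better)."""
    return sum_logprobs / (length ** length_penalty)


# With the settings above, EOS is never allowed before the cap is hit.
longest = max_decoder_len(max_new_tokens=20)
print(longest, eos_allowed(longest, min_length=60))  # 21 False

# A short and a long hypothesis (made-up log-probabilities): the default
# penalty 1.0 prefers the longer one, 0.1 barely normalizes and prefers the shorter.
short, long_ = (-4.0, 5), (-6.0, 10)
for lp in (1.0, 0.1):
    print(lp, beam_score(*short, lp) > beam_score(*long_, lp))
# 1.0 False
# 0.1 True
```

A consistent configuration would either drop `min_length` (or use `min_new_tokens`) or raise `max_new_tokens` above it; either change will likely alter the output shown below.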



In [12]:
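# one row per returned sequence; num_return_sequences defaults to 1, so this prints only the best beam
for summary_id in summary_ids:
    summary = tokenizer.decode(summary_id, skip_special_tokens=True)
    print(summary)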

fluoroscopy was used to locate the lower end of the fractured catheter.,
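The answer is lifted from the input rather than composed: it is the start of the note's own sentence about fluoroscopy. A rough way to quantify this kind of copying is the fraction of the output's word n-grams that also occur in the source. The sketch below assumes lowercase alphanumeric word tokens and n = 3, both illustrative choices, and `copied_ngram_fraction` is a hypothetical helper.

```python
import re
from typing import Set, Tuple


def word_ngrams(text: str, n: int) -> Set[Tuple[str, ...]]:
    """Set of lowercase word n-grams; punctuation is dropped."""
    words = re.findall(r"[a-z0-9]+", text.lower())
    return {tuple(words[i:i + n]) for i in range(len(words) - n + 1)}


def copied_ngram_fraction(output: str, source: str, n: int = 3) -> float:
    """Share of the output's n-grams that also appear in the source
    (1.0 = fully extractive, 0.0 = no n-gram overlap)."""
    out = word_ngrams(output, n)
    if not out:
        return 0.0
    return len(out & word_ngrams(source, n)) / len(out)
```

For the output above, `copied_ngram_fraction(summary, medical_text)` gives 1.0: every word trigram of the answer also appears in the note.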
