In [3]:
from datasets import load_dataset

# Load the dataset
dataset = load_dataset("cbasu/Med-EASi")

# T5 works best with an explicit instruction
prefix = "simplify: "

# Prepare columns (source = Expert, target = Simple)
def preprocess(example):
    return {
        "input_text": prefix + example["Expert"],
        "target_text": example["Simple"]
    }

dataset = dataset.map(preprocess)

# Train/test split
train_dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
train_data = train_dataset["train"]
eval_data = train_dataset["test"]


In [4]:
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Max. parameters
MAX_INPUT = 512
MAX_TARGET = 128

def tokenize_function(example):
    model_inputs = tokenizer(
        example["input_text"], max_length=MAX_INPUT, padding="max_length", truncation=True
    )
    labels = tokenizer(
        example["target_text"], max_length=MAX_TARGET, padding="max_length", truncation=True
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

train_tokenized = train_data.map(tokenize_function, batched=True)
eval_tokenized = eval_data.map(tokenize_function, batched=True)


In [5]:
from transformers import T5ForConditionalGeneration, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

model = T5ForConditionalGeneration.from_pretrained("t5-small")

training_args = Seq2SeqTrainingArguments(
    output_dir="./t5-med-simplify",
    save_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    learning_rate=5e-5,
    weight_decay=0.01,
    predict_with_generate=True,
    logging_dir="./logs",
    logging_steps=50,
    save_total_limit=2,
    fp16=True,
    report_to="none"  # évite les erreurs si wandb n'est pas installé
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=eval_tokenized,
    tokenizer=tokenizer,
    data_collator=data_collator,
)


config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

  trainer = Seq2SeqTrainer(


In [None]:
trainer.train()

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss
50,6.1057
100,0.9856
150,0.6721
200,0.5654
250,0.5339
300,0.5278
350,0.4781
400,0.489
450,0.4418
500,0.4473


TrainOutput(global_step=790, training_loss=0.8696312747424162, metrics={'train_runtime': 35.7225, 'train_samples_per_second': 175.939, 'train_steps_per_second': 22.115, 'total_flos': 850623222251520.0, 'train_loss': 0.8696312747424162, 'epoch': 5.0})

In [7]:
trainer.save_model("./t5-med-simplify-trainer")
tokenizer.save_pretrained("./t5-med-simplify-trainer")

('./t5-med-simplify-trainer/tokenizer_config.json',
 './t5-med-simplify-trainer/special_tokens_map.json',
 './t5-med-simplify-trainer/spiece.model',
 './t5-med-simplify-trainer/added_tokens.json')

In [19]:
def simplify_text(text, model, tokenizer, max_length=128):
    """
    Simplify medical text using the trained model
    """
    # Add the same prefix used during training
    input_text = "Summarize the medical report: " + text
    
    # Tokenize
    inputs = tokenizer(
        input_text, 
        return_tensors="pt", 
        max_length=512, 
        truncation=True, 
        padding=True
    )
    
    # Generate
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=max_length,
            num_beams=4,
            early_stopping=True,
            do_sample=False
        )
    
    # Decode
    simplified_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return simplified_text

In [20]:
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch

# Load the model and tokenizer
model_path = "./t5-med-simplify-trainer"  # or "./t5-med-simplify-trainer"
loaded_model = T5ForConditionalGeneration.from_pretrained(model_path)
loaded_tokenizer = T5Tokenizer.from_pretrained(model_path)

print("Model and tokenizer loaded successfully!")

# ===== TESTING THE LOADED MODEL =====

def simplify_text(text, model, tokenizer, max_length=128):
    """
    Simplify medical text using the trained model
    """
    # Add the same prefix used during training
    input_text = "simplify: " + text
    
    # Tokenize
    inputs = tokenizer(
        input_text, 
        return_tensors="pt", 
        max_length=512, 
        truncation=True, 
        padding=True
    )
    
    # Generate
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=max_length,
            num_beams=4,
            early_stopping=True,
            do_sample=False
        )
    
    # Decode
    simplified_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return simplified_text

# Test example
test_text = "CT also is required to accurately assess skull base bony changes, which are less visible on MRI."
simplified = simplify_text(test_text, loaded_model, loaded_tokenizer)

print(f"Original: {test_text}")
print(f"Simplified: {simplified}")


Model and tokenizer loaded successfully!
Original: CT also is required to accurately assess skull base bony changes, which are less visible on MRI.
Simplified: CT also is required to accurately assess skull base bony changes, which are less visible on MRI.
