In [3]:
import pandas as pd

# Login using e.g. `huggingface-cli login` to access this dataset
df = pd.read_csv("hf://datasets/vishnukantshukla/medical-complex-to-simple-10k/medical_simplified.csv")

In [4]:
df['Standard_English'][0]


'The lungs are clear of focal consolidation, pleural effusion or pneumothorax. The heart size is normal. The mediastinal contours are normal. Multiple surgical clips project over the left breast, and old left rib fractures are noted.'

In [5]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

In [6]:
from datasets import Dataset

# if your data is a pandas dataframe
dataset = Dataset.from_pandas(df)

def tokenize_function(examples):
    model_inputs = tokenizer(
        examples["Standard_English"],
        padding="max_length",
        truncation=True
    )
    # Tokenize labels
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            examples["Simplified_English"],
            padding="max_length",
            truncation=True
        )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_datasets = dataset.map(tokenize_function, batched=True)

Map:   0%|          | 0/8985 [00:00<?, ? examples/s]



In [7]:
pd.DataFrame(tokenized_datasets).head()

Unnamed: 0,Standard_English,Simplified_English,input_ids,attention_mask,labels
0,"The lungs are clear of focal consolidation, pl...","The lungs look healthy, with no signs of infec...","[37, 3, 17454, 33, 964, 13, 15949, 16690, 6, 3...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[37, 3, 17454, 320, 1695, 6, 28, 150, 3957, 13..."
1,Lung volumes remain low. There are innumerable...,The lungs don't have their full capacity. Ther...,"[301, 425, 13548, 2367, 731, 5, 290, 33, 16, 5...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[37, 3, 17454, 278, 31, 17, 43, 70, 423, 2614,..."
2,Lung volumes are low. This results in crowding...,"The lungs don't have their full capacity, whic...","[301, 425, 13548, 33, 731, 5, 100, 772, 16, 43...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[37, 3, 17454, 278, 31, 17, 43, 70, 423, 2614,..."
3,There is mild pulmonary edema with small bilat...,There is some fluid buildup in the lungs. The ...,"[290, 19, 8248, 3, 26836, 3, 15, 1778, 9, 28, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[290, 19, 128, 5798, 918, 413, 16, 8, 3, 17454..."
4,The right costophrenic angle is not imaged. Ot...,The right side of the chest is not fully visib...,"[37, 269, 583, 10775, 60, 2532, 7669, 19, 59, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[37, 269, 596, 13, 8, 5738, 19, 59, 1540, 5183..."


In [8]:
train_test_split = tokenized_datasets.train_test_split(test_size=0.2)

train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]

print("Training dataset size:", len(train_dataset))
print("Evaluation dataset size:", len(eval_dataset))

Training dataset size: 7188
Evaluation dataset size: 1797


In [35]:
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig,AutoTokenizer
import torch

model_name = "google/flan-t5-base"

# Load tokenizer and fix pad token


# Configure 4-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # or torch.bfloat16
    bnb_4bit_use_double_quant=True,       # Optional: for better compression
    bnb_4bit_quant_type="nf4"             # Optional: normalized float 4-bit
)

# Load model with quantization config
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",  # Optional: for automatic device placement
    torch_dtype=torch.float16  # Optional: for consistency
)

In [36]:
from peft import LoraConfig, get_peft_model
from transformers import  DataCollatorForSeq2Seq
# LoRA configuration
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v", "k", "o"],
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_2_SEQ_LM")
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()

trainable params: 1,769,472 || all params: 249,347,328 || trainable%: 0.7096


In [37]:
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
    return_tensors="pt"
)

In [38]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer



training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=2,      # Reduced for memory safety
    per_device_eval_batch_size=2,       # Reduced for memory safety
    gradient_accumulation_steps=8,      # Increased to maintain effective batch size
    num_train_epochs=2,                 # Slightly more epochs since batch size is smaller
    logging_steps=50,                   # More frequent logging for shorter runs
    save_steps=250,                     # More frequent saves for Colab
    learning_rate=3e-4,                 # Slightly higher LR for smaller batches
    fp16=False,                          # Changed from bf16 (better Colab compatibility)
    optim="adamw_torch",
)

In [39]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    
)

In [42]:
trainer.train()

Step,Training Loss
50,4.8053
100,4.3003
150,4.1393
200,3.9873
250,3.885
300,3.783
350,3.6643
400,3.5967
450,3.5483
500,3.4246


TrainOutput(global_step=900, training_loss=3.0201713053385415, metrics={'train_runtime': 2207.2638, 'train_samples_per_second': 6.513, 'train_steps_per_second': 0.408, 'total_flos': 1.0363940647206912e+16, 'train_loss': 3.0201713053385415, 'epoch': 2.0})

In [43]:
results = trainer.evaluate()
print(results)

{'eval_loss': 0.7353515625, 'eval_runtime': 142.6279, 'eval_samples_per_second': 12.599, 'eval_steps_per_second': 6.303, 'epoch': 2.0}


In [44]:
import pickle
with open("lora_flan_t5_base_medical_simplification.pkl", "wb") as f:
    pickle.dump(model.state_dict(), f)

In [45]:
peft_model = trainer.model

peft_model.save_pretrained("./medical_lora_adapters")
tokenizer.save_pretrained("./medical_lora_adapters")

('./medical_lora_adapters\\tokenizer_config.json',
 './medical_lora_adapters\\special_tokens_map.json',
 './medical_lora_adapters\\spiece.model',
 './medical_lora_adapters\\added_tokens.json',
 './medical_lora_adapters\\tokenizer.json')

In [51]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load your fine-tuned model
model_name = "./medical_lora_adapters" # path or HF repo
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cuda")

# Input sentence
sentence = "Complex tear of the posterior horn of the medial meniscus with a displaced bucket-handle component."
# Encode & generate
inputs = tokenizer(sentence, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=100)

# Decode
simplified = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Simplified:", simplified)

Simplified: The horn of the medial meniscus is a complex tear in the lower part of the horn. The lungs are displaced.
