In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM,
    T5ForConditionalGeneration,
    BartForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)
from sklearn.model_selection import train_test_split
import evaluate
import nltk
from nltk.tokenize import sent_tokenize
import time
import re

# Download required NLTK data
nltk.download('punkt')

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

2025-04-16 03:15:12.208031: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-16 03:15:12.219710: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744753512.233693   25489 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744753512.238122   25489 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744753512.251040   25489 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

Using device: cuda


[nltk_data] Downloading package punkt to /home/geetheswar/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
dataset = load_dataset("cbasu/Med-EASi")

In [3]:
dataset

DatasetDict({
    train: Dataset({
        features: ['Expert', 'Simple', 'Annotation', 'sim', 'sentence_sim', 'compression', 'expert_fk_grade', 'expert_ari', 'layman_fk_grade', 'layman_ari', 'umls_expert', 'umls_layman', 'expert_terms', 'layman_terms', 'idx'],
        num_rows: 1397
    })
    validation: Dataset({
        features: ['Expert', 'Simple', 'Annotation', 'sim', 'sentence_sim', 'compression', 'expert_fk_grade', 'expert_ari', 'layman_fk_grade', 'layman_ari', 'umls_expert', 'umls_layman', 'expert_terms', 'layman_terms', 'idx'],
        num_rows: 196
    })
    test: Dataset({
        features: ['Expert', 'Simple', 'Annotation', 'sim', 'sentence_sim', 'compression', 'expert_fk_grade', 'expert_ari', 'layman_fk_grade', 'layman_ari', 'umls_expert', 'umls_layman', 'expert_terms', 'layman_terms', 'idx'],
        num_rows: 300
    })
})

In [4]:
def preprocess_dataset(dataset):
    print("Preprocessing dataset...")
    # Convert dataset to pandas DataFrame for easier manipulation
    train_df = pd.DataFrame(dataset['train'])
    test_df = pd.DataFrame(dataset['test'])
    val_df = pd.DataFrame(dataset['validation'])
    
    print(f"Train set: {len(train_df)} examples")
    print(f"Validation set: {len(val_df)} examples")
    print(f"Test set: {len(test_df)} examples")
    
    return train_df, val_df, test_df

train_df, val_df, test_df = preprocess_dataset(dataset)

Preprocessing dataset...
Train set: 1397 examples
Validation set: 196 examples
Test set: 300 examples


In [5]:
def prepare_dataset_for_model(train_df, val_df, test_df, tokenizer, max_input_length=256, max_target_length=256):
    # Process individual examples directly
    train_encodings = []
    val_encodings = []
    test_encodings = []
    
    # Process training data
    for _, row in train_df.iterrows():
        model_inputs = tokenizer(
            row["Expert"], 
            max_length=max_input_length,
            padding="max_length",
            truncation=True,
        )
        
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                row["Simple"],
                max_length=max_target_length,
                padding="max_length",
                truncation=True,
            )
        
        model_inputs["labels"] = labels["input_ids"]
        train_encodings.append(model_inputs)
    
    # Process validation data
    for _, row in val_df.iterrows():
        model_inputs = tokenizer(
            row["Expert"], 
            max_length=max_input_length,
            padding="max_length",
            truncation=True,
        )
        
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                row["Simple"],
                max_length=max_target_length,
                padding="max_length",
                truncation=True,
            )
        
        model_inputs["labels"] = labels["input_ids"]
        val_encodings.append(model_inputs)
    
    # Process test data
    for _, row in test_df.iterrows():
        model_inputs = tokenizer(
            row["Expert"], 
            max_length=max_input_length,
            padding="max_length",
            truncation=True,
        )
        
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                row["Simple"],
                max_length=max_target_length,
                padding="max_length",
                truncation=True,
            )
        
        model_inputs["labels"] = labels["input_ids"]
        test_encodings.append(model_inputs)
    
    # Convert to PyTorch datasets with simplified __getitem__
    class TextSimplificationDataset(torch.utils.data.Dataset):
        def __init__(self, encodings):
            self.encodings = encodings
        
        def __getitem__(self, idx):
            # Simply return the pre-processed encodings
            return self.encodings[idx]
        
        def __len__(self):
            return len(self.encodings)
    
    train_dataset = TextSimplificationDataset(train_encodings)
    val_dataset = TextSimplificationDataset(val_encodings)
    test_dataset = TextSimplificationDataset(test_encodings)
    
    return train_dataset, val_dataset, test_dataset

In [6]:
def compute_metrics(eval_pred, tokenizer):
    predictions, labels = eval_pred
    
    # Replace -100 in the labels as we can't decode them
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    
    # Clip predictions to valid token range to prevent overflow errors
    # Get tokenizer vocabulary size
    vocab_size = tokenizer.vocab_size
    
    # Clip predictions to be within valid token range
    predictions = np.clip(predictions, 0, vocab_size - 1)
    
    try:
        # Decode generated summaries and reference texts
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        
        # Rouge expects newline after each sentence
        decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]
        decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]
        
        # Compute ROUGE scores
        rouge_metric = evaluate.load("rouge")
        result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        
        # Extract ROUGE f1 scores
        result = {k: round(v * 100, 2) for k, v in result.items()}
        
        # Compute BLEU score
        bleu_metric = evaluate.load("bleu")
        bleu_result = bleu_metric.compute(predictions=decoded_preds, references=decoded_labels)
        result["bleu"] = round(bleu_result["bleu"] * 100, 2)
        
        # Compute readability metrics
        result["fk_grade_diff"] = calculate_readability_improvement(decoded_preds, decoded_labels)
        
    except Exception as e:
        print(f"Error in compute_metrics: {e}")
        # Return default metrics in case of failure
        result = {
            "rouge1": 0.0,
            "rouge2": 0.0,
            "rougeL": 0.0,
            "bleu": 0.0,
            "fk_grade_diff": 0.0
        }
    
    return result

# Calculate readability improvement using Flesch-Kincaid Grade Level
def calculate_readability_improvement(simplified_texts, original_texts):
    def flesch_kincaid_grade(text):
        # Simple implementation of Flesch-Kincaid Grade Level
        sentences = sent_tokenize(text)
        num_sentences = len(sentences)
        if num_sentences == 0:
            return 0
            
        words = re.findall(r'\b\w+\b', text.lower())
        num_words = len(words)
        if num_words == 0:
            return 0
            
        syllables = 0
        for word in words:
            syllables += count_syllables(word)
            
        fk_grade = 0.39 * (num_words / num_sentences) + 11.8 * (syllables / num_words) - 15.59
        return max(0, fk_grade)  # Grade level should not be negative
    
    # Calculate the average grade level difference
    grade_diffs = []
    for original, simplified in zip(original_texts, simplified_texts):
        original_grade = flesch_kincaid_grade(original)
        simplified_grade = flesch_kincaid_grade(simplified)
        grade_diff = original_grade - simplified_grade  # Positive value means simplified text is easier to read
        grade_diffs.append(grade_diff)
    
    return sum(grade_diffs) / len(grade_diffs) if grade_diffs else 0

# Syllable counting helper function
def count_syllables(word):
    # Simple syllable counting - this is a basic implementation
    word = word.lower()
    if len(word) <= 3:
        return 1
    
    # Remove ending e
    if word.endswith('e'):
        word = word[:-1]
        
    # Count vowel groups
    vowels = "aeiouy"
    count = 0
    prev_is_vowel = False
    
    for char in word:
        is_vowel = char in vowels
        if is_vowel and not prev_is_vowel:
            count += 1
        prev_is_vowel = is_vowel
        
    return max(1, count)  # Every word has at least one syllable

# Fine-tune model
def fine_tune_model(model_name, tokenizer, train_dataset, val_dataset, output_dir):
    print(f"Fine-tuning {model_name}...")
    
    # Load pre-trained model and tokenizer
    if "t5" in model_name.lower():
        model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
    elif "bart" in model_name.lower():
        model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    
    # Set up the data collator
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        padding=True,
    )
    
    # Define training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        learning_rate=3e-5,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=3,
        predict_with_generate=True,
        generation_max_length=512,
        report_to="none",  # Disable wandb reporting
        save_strategy="epoch",
    )
    
    # Create Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics= lambda p: compute_metrics(p, tokenizer),
    )
    
    # Train model
    trainer.train()
    
    # Save best model
    trainer.save_model(output_dir)
    
    return model, trainer

In [7]:
model1_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model1_name)
train_dataset, val_dataset, test_dataset = prepare_dataset_for_model(train_df, val_df, test_df, tokenizer)



In [8]:
flan_t5_model, flan_t5_trainer = fine_tune_model(model1_name, tokenizer, train_dataset, val_dataset, "flan_t5_model")

Fine-tuning google/flan-t5-small...


  trainer = Seq2SeqTrainer(
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Bleu,Fk Grade Diff
1,No log,1.625526,0.55,0.04,0.55,0.54,0.0,-33.689968
2,6.900900,0.552572,1.15,0.18,0.99,1.08,0.0,-6.22476
3,0.628100,0.464969,1.63,0.38,1.42,1.49,0.09,-66.285841


Downloading builder script:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

In [7]:
model2_name = "facebook/bart-base"
tokenizer2 = AutoTokenizer.from_pretrained(model2_name)
 
train_dataset2, val_dataset2, test_dataset2 = prepare_dataset_for_model(
    train_df, val_df, test_df, tokenizer2
)

model2, trainer2 = fine_tune_model(
    model2_name, tokenizer2, train_dataset2, val_dataset2, "./results_bart"
)



Fine-tuning facebook/bart-base...


  trainer = Seq2SeqTrainer(


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Bleu,Fk Grade Diff
1,No log,0.337914,26.66,13.29,23.34,23.97,8.1,4.996868
2,1.072300,0.329956,32.03,17.28,28.44,28.92,11.01,4.114712
3,0.150500,0.331769,33.36,18.53,29.49,30.41,14.22,3.291828




In [11]:
def test_model(model, test_dataset, model_name, tokenizer):
    print(f"Testing {model_name}...")
    
    # Set up the data collator
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        padding=True,
    )
    
    # Define test arguments
    test_args = Seq2SeqTrainingArguments(
        output_dir=f"./test_results_{model_name}",
        per_device_eval_batch_size=4,
        predict_with_generate=True,
        generation_max_length=512,
        report_to="none",
    )
    
    # Create Trainer for evaluation
    trainer = Seq2SeqTrainer(
        model=model,
        args=test_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=lambda p: compute_metrics(p, tokenizer),
    )
    
    # Evaluate on test dataset
    test_results = trainer.evaluate(test_dataset)
    
    return test_results

In [12]:
test_results2 = test_model(model2, test_dataset2, "bart", tokenizer2)

Testing bart...


  trainer = Seq2SeqTrainer(


In [13]:
test_results2

{'eval_loss': 0.27668920159339905,
 'eval_model_preparation_time': 0.0016,
 'eval_rouge1': 32.35,
 'eval_rouge2': 18.08,
 'eval_rougeL': 28.72,
 'eval_rougeLsum': 29.35,
 'eval_bleu': 15.45,
 'eval_fk_grade_diff': 2.3486921533359393,
 'eval_runtime': 54.1928,
 'eval_samples_per_second': 5.536,
 'eval_steps_per_second': 1.384}

In [14]:
bart = model2

In [15]:
def simplify(model, tokenizer, expert_text):
    inputs = tokenizer(expert_text, return_tensors="pt", padding=True, truncation=True).to(device)
    
    with torch.no_grad():
        outputs = model.generate(**inputs)
    
    simplified_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    
    return simplified_text

In [18]:
expert_text = "The patient was diagnosed with a severe case of pneumonia, requiring immediate hospitalization and intravenous antibiotics."
simplified_text = simplify(bart, tokenizer2, expert_text)
simplified_text

'The patient was diagnosed with pneumonia, requiring immediate hospitalization and intravenous antibiotics.'

In [None]:
expert_text = "Desmopressin"
simplified_text = simplify(bart, tokenizer2, expert_text)
simplified_text

Sometimes, the drug desmopressin


## Problems
- Due to small input context length, the text is masking

In [24]:
expert_text = "Some patients have weight loss, rarely enough to become underweight. Anemia, glossitis, angular stomatitis, and aphthous ulcers are usually seen in these patients."
simplified_text = simplify(bart, tokenizer2, expert_text)
simplified_text

'Some people have weight loss, rarely enough to become underweight. Anemia, glossitis'