# Fine-Tuning Falconsai/medical_summarization Model

This notebook demonstrates how to fine-tune the `Falconsai/medical_summarization` model on a medical dataset. Each example in the dataset pairs a `transcription` (the model input) with a `description` (the target summary), and we will use both ROUGE and BLEU metrics to evaluate the model's performance.

In [ ]:
import pandas as pd
import transformers
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
import torch
from datasets import load_metric, Dataset, DatasetDict
import sacrebleu

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load datasets: one CSV per split, read in train / val / test order.
train_df, val_df, test_df = map(
    pd.read_csv,
    (
        'preprocessed_data/train.csv',
        'preprocessed_data/val.csv',
        'preprocessed_data/test.csv',
    ),
)

# Extract transcriptions and descriptions from each split as plain lists.
train_transcriptions, train_descriptions = (
    train_df[column].tolist() for column in ('transcription', 'description')
)
val_transcriptions, val_descriptions = (
    val_df[column].tolist() for column in ('transcription', 'description')
)
test_transcriptions, test_descriptions = (
    test_df[column].tolist() for column in ('transcription', 'description')
)

In [ ]:
# Initialize tokenizer from the same checkpoint that is fine-tuned below
tokenizer = AutoTokenizer.from_pretrained('Falconsai/medical_summarization')


# Tokenize the datasets
def preprocess_function(examples):
    """Tokenize one batch of transcription/description pairs.

    Args:
        examples: Mapping with a 'transcription' and a 'description' entry,
            as passed by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output for the transcriptions (truncated/padded to 512
        tokens) with an extra 'labels' entry holding the description token
        ids (truncated/padded to 128 tokens).
    """
    model_inputs = tokenizer(
        examples['transcription'],
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    # NOTE(review): the label ids keep the tokenizer's padding ids as-is;
    # presumably those positions count toward the training loss -- confirm
    # whether they should be masked before training.
    target_ids = tokenizer(
        examples['description'],
        truncation=True,
        padding="max_length",
        max_length=128,
    )['input_ids']
    model_inputs['labels'] = target_ids
    return model_inputs

# Create Hugging Face datasets: one per split, each with the raw text columns.
train_data = Dataset.from_dict(
    {'transcription': train_transcriptions, 'description': train_descriptions}
)
val_data = Dataset.from_dict(
    {'transcription': val_transcriptions, 'description': val_descriptions}
)
test_data = Dataset.from_dict(
    {'transcription': test_transcriptions, 'description': test_descriptions}
)

# Tokenize every split in batches and gather them under their split names.
dataset = DatasetDict({
    split_name: split_data.map(preprocess_function, batched=True)
    for split_name, split_data in (
        ('train', train_data),
        ('validation', val_data),
        ('test', test_data),
    )
})

In [ ]:
# Initialize model from the pretrained checkpoint and move it to `device`
model = AutoModelForSeq2SeqLM.from_pretrained('Falconsai/medical_summarization')
model = model.to(device)

# Training arguments
training_args = Seq2SeqTrainingArguments(
    # Where outputs and logs are written, and how many checkpoints to keep.
    output_dir='./results',
    logging_dir='./logs',
    save_total_limit=3,
    # Optimisation settings.
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Evaluate once per epoch, using generated sequences as predictions.
    evaluation_strategy='epoch',
    predict_with_generate=True,
)

# Data collator used by the trainer to assemble batches
data_collator = transformers.DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [ ]:
# Metrics
rouge_metric = load_metric("rouge")


def compute_metrics(eval_pred):
    """Compute ROUGE and BLEU for a batch of generated summaries.

    Args:
        eval_pred: A ``(predictions, labels)`` pair of token-id arrays as
            handed over by the trainer. ``predictions`` may itself be a
            tuple, in which case its first element is used.

    Returns:
        dict mapping each ROUGE key (mid F-measure, scaled to 0-100) and
        ``"bleu"`` (sacreBLEU corpus score) to a value rounded to 4 decimals.
    """
    # Local import keeps this cell self-contained; numpy is only needed here.
    import numpy as np

    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # -100 is the padding/ignore sentinel used for label padding and for
    # padding gathered predictions. It is not a valid vocabulary id, so it
    # has to be mapped back to the pad token before decoding. When no -100
    # is present this is a no-op.
    pad_id = tokenizer.pad_token_id
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # ROUGE scores
    rouge_result = rouge_metric.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True,
    )
    rouge_result = {key: value.mid.fmeasure * 100 for key, value in rouge_result.items()}

    # BLEU scores: a single reference stream aligned with the predictions
    bleu = sacrebleu.corpus_bleu(decoded_preds, [decoded_labels])
    bleu_result = {"bleu": bleu.score}

    # Combine both results
    result = {**rouge_result, **bleu_result}

    return {k: round(v, 4) for k, v in result.items()}

# Initialize trainer: wire the model, data splits, tokenizer, collator and
# metric function together.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    compute_metrics=compute_metrics,
)

# Train the model
trainer.train()

In [ ]:
# Evaluate the model on test data (the held-out 'test' split) and show the
# resulting metrics dictionary.
results = trainer.evaluate(
    eval_dataset=dataset['test'],
)
print(results)