* This notebook performs data cleaning, chunked summarization, and evaluation
  on the MTSamples dataset using BART-based transformer models.

In [None]:
# Install and Import Dependencies
!pip install -q pandas transformers nltk datasets evaluate rouge_score bert_score

In [None]:
# Import Libraries
import pandas as pd
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import torch
import evaluate
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from IPython.display import display

# Recent NLTK releases load the Punkt sentence model from 'punkt_tab' instead of
# 'punkt'; fetch both quietly so word_tokenize/sent_tokenize work on either.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

In [None]:
# Load Dataset
# Read the raw MTSamples CSV, then report how many rows it has and how many
# nulls each column contains.
df = pd.read_csv("mtsamples.csv")
null_counts = df.isna().sum()
print("Original record count:", df.shape[0])
print("Missing values per column:\n", null_counts)

In [None]:
# Data Cleaning
def clean_text(text):
    """Normalize whitespace in a transcription string.

    Args:
        text: raw transcription string.

    Returns:
        The text with leading/trailing whitespace removed and every internal
        run of whitespace (spaces, tabs, CR, LF, ...) replaced by one space.
    """
    # str.split() with no separator already splits on '\n' and '\r' and drops
    # empty pieces, so no separate replace()/strip() calls are needed.
    return " ".join(text.split())

def preprocess_data(df):
    """Drop incomplete records and add a whitespace-normalized transcription.

    Args:
        df: raw MTSamples frame with 'transcription' and 'description' columns.

    Returns:
        A new frame (the input is not modified) containing only rows where both
        'transcription' and 'description' are non-null, plus a
        'cleaned_transcription' column produced by clean_text.
    """
    return (
        df.dropna(subset=['transcription', 'description'])
          .assign(
              cleaned_transcription=lambda frame: frame['transcription'].astype(str).apply(clean_text)
          )
    )

df_cleaned = preprocess_data(df)
# Later cells summarize 'cleaned_transcription' and score against 'description',
# so neither may contain nulls after cleaning.
assert df_cleaned[['description', 'cleaned_transcription']].notna().all().all()
print("Cleaned record count:", len(df_cleaned))
print("Records dropped for missing transcription/description:", len(df) - len(df_cleaned))
# Rich display renders an HTML table instead of plain-text print() output.
display(df_cleaned[['description', 'cleaned_transcription']].head(3))


In [None]:
# Load Summarization Model
MODEL_NAME = "sshleifer/distilbart-cnn-12-6"

# Pipeline device index: 0 selects the first CUDA GPU, -1 runs on CPU.
device = -1
if torch.cuda.is_available():
    device = 0

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    device=device,
)

In [None]:
# Text Chunking
MAX_INPUT_LENGTH = 1024
CHUNK_OVERLAP = 200

def chunk_text(text, tokenizer, max_length=MAX_INPUT_LENGTH, overlap=CHUNK_OVERLAP):
    """Split text into overlapping token windows that each fit the model input.

    Each returned chunk, once re-encoded by the tokenizer (which re-adds its
    special tokens), is at most `max_length` tokens long, up to any small
    decode/re-encode drift.

    Args:
        text: input string.
        tokenizer: tokenizer providing encode/decode/num_special_tokens_to_add.
        max_length: token budget per chunk, including the special tokens the
            summarizer adds when it encodes the chunk.
        overlap: number of tokens shared by consecutive chunks.

    Returns:
        List of decoded chunk strings; empty if `text` has no tokens.

    Raises:
        ValueError: if `overlap` leaves no room for the window to advance.
    """
    # Reserve room for the special tokens (e.g. <s> ... </s>) that are added
    # again when each chunk is fed to the summarizer.
    window = max_length - tokenizer.num_special_tokens_to_add()
    step = window - overlap
    if step <= 0:
        # Without this check the window would never advance (infinite loop).
        raise ValueError(
            f"overlap ({overlap}) must be smaller than the usable window ({window} tokens)"
        )

    tokens = tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    for start in range(0, len(tokens), step):
        end = min(start + window, len(tokens))
        chunks.append(tokenizer.decode(tokens[start:end], skip_special_tokens=True))
        if end == len(tokens):
            # The text is fully covered; a further window would only repeat overlap.
            break
    return chunks

In [None]:
# Summarization Function
SUMMARY_MAX = 150
SUMMARY_MIN = 40

def generate_medical_summary(text):
    """Summarize one cleaned transcription with the global `summarizer` pipeline.

    Inputs that fit in MAX_INPUT_LENGTH tokens are summarized in one pass.
    Longer inputs are split with `chunk_text`, each chunk is summarized, and the
    chunk summaries are summarized again into the final result.

    Args:
        text: cleaned transcription string.

    Returns:
        The summary string. Returns "" for blank or non-string input, or when
        every summarization attempt fails. If only the final merge pass fails,
        the joined chunk summaries are returned instead.
    """
    if not isinstance(text, str) or not text.strip():
        return ""

    def summarize(passage, max_len, min_len):
        # Never request more summary tokens than the passage has, and keep
        # min_length <= max_length so the generation settings are consistent.
        n_tokens = len(tokenizer.encode(passage))
        max_len = min(max_len, n_tokens)
        min_len = min(min_len, max_len)
        # truncation=True trims over-length inputs (e.g. a long merged summary
        # or decode/re-encode drift on a chunk) instead of failing in the model.
        output = summarizer(passage, max_length=max_len, min_length=min_len,
                            do_sample=False, truncation=True)
        return output[0]['summary_text']

    input_len = len(tokenizer.encode(text))
    if input_len <= MAX_INPUT_LENGTH:
        try:
            return summarize(text, SUMMARY_MAX, SUMMARY_MIN)
        except Exception as e:
            print("Single-pass error:", e)
            return ""

    summaries = []
    for chunk in chunk_text(text, tokenizer):
        try:
            summaries.append(summarize(chunk, SUMMARY_MAX, SUMMARY_MIN))
        except Exception as e:
            print("Chunk error:", e)

    if not summaries:
        # Every chunk failed; summarizing an empty string would just error.
        return ""
    if len(summaries) == 1:
        return summaries[0]
    try:
        return summarize(" ".join(summaries), SUMMARY_MAX * 2, SUMMARY_MIN * 2)
    except Exception as e:
        print("Final pass error:", e)
        return " ".join(summaries)


In [None]:
# Batch summarization of the first N_BATCH reports
N_BATCH = 30  # number of reports to summarize; increase for a larger evaluation
df_batch = df_cleaned.head(N_BATCH).copy()
print(f"\nGenerating summaries for {len(df_batch)} reports...")
# tqdm already shows per-report progress, so no per-row print() is needed;
# total follows the real row count (head() may return fewer than N_BATCH rows).
summaries = [
    generate_medical_summary(transcription)
    for transcription in tqdm(df_batch['cleaned_transcription'], total=len(df_batch), desc="Summarizing")
]
df_batch['generated_summary'] = summaries
batch_path = f"batch_{N_BATCH}_summaries.csv"
df_batch.to_csv(batch_path, index=False)
print(f"Saved to '{batch_path}'")

In [None]:
# Display all 30 descriptions and their generated summaries
# Reference description next to the generated summary, one row per report.
df_batch.loc[:, ['description', 'generated_summary']].head(30)

In [None]:
# Evaluation (ROUGE + BERTScore)
rouge = evaluate.load('rouge')
bertscore = evaluate.load('bertscore')

# NOTE(review): the 'description' column is used as the reference summary;
# confirm it is an appropriate reference for these transcriptions.
refs = df_batch['description'].tolist()
preds = df_batch['generated_summary'].tolist()
# Pairs with an empty prediction (failed or blank input) are excluded, so the
# scores cover only successfully summarized reports. Print the count so the
# exclusion is visible.
valid = [(p, r) for p, r in zip(preds, refs) if p and r]
print(f"Evaluating {len(valid)} of {len(preds)} prediction/reference pairs")

if valid:
    preds_valid, refs_valid = map(list, zip(*valid))
    rouge_results = rouge.compute(predictions=preds_valid, references=refs_valid)
    bert_results = bertscore.compute(predictions=preds_valid, references=refs_valid, model_type="bert-base-uncased")
    bert_f1_mean = sum(bert_results['f1']) / len(bert_results['f1'])
    print("\n ROUGE:", rouge_results)
    print("BERTScore F1 (mean):", bert_f1_mean)
    pd.DataFrame([{
        "rouge1": rouge_results['rouge1'],
        "rouge2": rouge_results['rouge2'],
        "rougeL": rouge_results['rougeL'],
        "bertscore_f1": bert_f1_mean,
        "n_evaluated": len(preds_valid),
    }]).to_csv("evaluation_metrics.csv", index=False)
else:
    print("No valid references/predictions to evaluate.")

In [None]:
# Evaluation Metrics Result

# Hard-coded scores (presumably from an earlier run). They are used only when
# evaluation_metrics.csv has not been written, e.g. because the evaluation
# cell was skipped or found no valid pairs.
RECORDED_METRICS = {
    "ROUGE-1": 0.2328,
    "ROUGE-2": 0.1759,
    "ROUGE-L": 0.2176,
    "BERTScore F1": 0.5597
}

try:
    # Plot the scores the evaluation cell actually wrote. The file may be
    # stale if it was left over from a different run.
    scores = pd.read_csv("evaluation_metrics.csv").iloc[0]
    metrics = {
        "ROUGE-1": float(scores["rouge1"]),
        "ROUGE-2": float(scores["rouge2"]),
        "ROUGE-L": float(scores["rougeL"]),
        "BERTScore F1": float(scores["bertscore_f1"]),
    }
except FileNotFoundError:
    print("evaluation_metrics.csv not found; plotting the hard-coded scores instead.")
    metrics = dict(RECORDED_METRICS)

# Plot
fig, ax = plt.subplots(figsize=(8, 6))
bars = ax.bar(list(metrics.keys()), list(metrics.values()), color=["#3498db", "#9b59b6", "#2ecc71", "#e67e22"])
ax.set_ylim(0, 1)
ax.set_title("Evaluation Metrics for Medical Report Summarization")
ax.set_ylabel("Score")
ax.set_xlabel("Metric")

# Annotate each bar with its score
for bar in bars:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width() / 2, height + 0.02, f"{height:.3f}", ha='center', fontsize=10)

fig.tight_layout()
plt.show()

In [None]:
def _build_patient_preamble(age=None, gender=None, allergies=None, diagnosis=None, medications=None):
    """Return a one-line patient context string, or "" if nothing was supplied.

    Gender is stripped and lower-cased. Values other than "male", "female" or
    "other" are treated as unspecified and omitted, because an "unspecified"
    label adds no information to the prompt.
    """
    demographics = []
    if age:
        demographics.append(f"{age}-year-old")
    if gender:
        gender_label = str(gender).strip().lower()
        if gender_label in ("male", "female", "other"):
            demographics.append(gender_label)

    parts = []
    if demographics:
        parts.append(f"Patient: {' '.join(demographics)}.")
    if diagnosis:
        parts.append(f"Diagnosis: {diagnosis}.")
    if medications:
        parts.append(f"Medications: {medications}.")
    if allergies:
        parts.append(f"Allergies: {allergies}.")
    return " ".join(parts)


def generate_medical_summary(text, age=None, gender=None, allergies=None, diagnosis=None, medications=None):
    """Summarize a report, prefixing optional patient context to the model input.

    Args:
        text: cleaned transcription string.
        age, gender: optional demographics included in the context line.
        allergies, diagnosis, medications: optional structured clinical facts.

    Returns:
        The summary string. Returns "" for blank or non-string input, or when
        every summarization attempt fails. If only the final merge pass fails,
        the joined chunk summaries are returned instead.
    """
    if not isinstance(text, str) or not text.strip():
        return ""

    def summarize(passage, max_len, min_len):
        # Clamp the generation window to the input size and keep min <= max;
        # truncation=True trims over-length inputs instead of erroring.
        n_tokens = len(tokenizer.encode(passage))
        max_len = min(max_len, n_tokens)
        min_len = min(min_len, max_len)
        output = summarizer(passage, max_length=max_len, min_length=min_len,
                            do_sample=False, truncation=True)
        return output[0]['summary_text']

    preamble = _build_patient_preamble(age=age, gender=gender, allergies=allergies,
                                       diagnosis=diagnosis, medications=medications)
    # Without any context, skip the preamble line entirely instead of emitting
    # a dangling "Patient is a ." sentence.
    context = f"{preamble}\n" if preamble else ""
    prompt = f"{context}Summarize the following medical report:\n{text}"

    if len(tokenizer.encode(prompt)) <= MAX_INPUT_LENGTH:
        try:
            return summarize(prompt, SUMMARY_MAX, SUMMARY_MIN)
        except Exception as e:
            print("Single-pass error:", e)
            return ""

    # Shrink each chunk by the context length so context + chunk still fits
    # the model. The floor keeps the chunk window well above CHUNK_OVERLAP even
    # when the context is very long.
    context_len = len(tokenizer.encode(context, add_special_tokens=False)) if context else 0
    chunk_budget = max(MAX_INPUT_LENGTH - context_len, MAX_INPUT_LENGTH // 2)

    summaries = []
    for chunk in chunk_text(text, tokenizer, max_length=chunk_budget):
        try:
            summaries.append(summarize(context + chunk, SUMMARY_MAX, SUMMARY_MIN))
        except Exception as e:
            print("Chunk error:", e)

    if not summaries:
        return ""
    if len(summaries) == 1:
        return summaries[0]
    try:
        return summarize(" ".join(summaries), SUMMARY_MAX * 2, SUMMARY_MIN * 2)
    except Exception as e:
        print("Final pass error:", e)
        return " ".join(summaries)


In [None]:
# Original Medical Report vs Summary
PREVIEW_CHARS = 1000  # limit the printed report for readability
text = df_cleaned.iloc[0]['cleaned_transcription']
# NOTE(review): age/gender are hard-coded example values, not read from this
# record; confirm they match the report before interpreting the summary.
summary = generate_medical_summary(text, age="23", gender="female")

# Only show an ellipsis when the report was actually cut off.
suffix = "...\n" if len(text) > PREVIEW_CHARS else "\n"
print("\n Original Report:\n", text[:PREVIEW_CHARS], suffix)
print(" Generated Summary:\n", summary)

In [None]:
def generate_medical_summary(text, age=None, gender=None):
    """Summarize a report with an instruction prompt that may mention demographics.

    When both `age` and `gender` are given, they are embedded in the
    instruction text; otherwise a generic instruction is used. Long reports are
    chunked, each chunk is summarized with a shorter instruction prefix, and the
    chunk summaries are merged in a final pass.

    Args:
        text: cleaned transcription string.
        age, gender: optional demographics; used only when both are provided.

    Returns:
        The summary string. Returns "" for blank or non-string input, or when
        every summarization attempt fails. If only the final merge pass fails,
        the joined chunk summaries are returned instead.
    """
    if not isinstance(text, str) or not text.strip():
        return ""

    def summarize(passage, max_len, min_len):
        # Clamp the generation window to the input size and keep min <= max;
        # truncation=True trims over-length inputs instead of erroring.
        n_tokens = len(tokenizer.encode(passage))
        max_len = min(max_len, n_tokens)
        min_len = min(min_len, max_len)
        output = summarizer(passage, max_length=max_len, min_length=min_len,
                            do_sample=False, truncation=True)
        return output[0]['summary_text']

    # NOTE(review): the summarizer is a DistilBART CNN checkpoint, likely not
    # instruction-tuned, so this instruction text may be echoed in the output.
    if age and gender:
        prompt = f"Summarize the following medical report written for a {gender} aged {age}:\n{text}"
        chunk_prefix = f"Summarize this report for a {gender} aged {age}:\n"
    else:
        prompt = f"Summarize the following medical report:\n{text}"
        chunk_prefix = "Summarize this medical report:\n"

    try:
        input_ids = tokenizer.encode(prompt, truncation=False)
    except Exception as e:
        print("Encoding error:", e)
        return ""

    if len(input_ids) <= MAX_INPUT_LENGTH:
        try:
            return summarize(prompt, SUMMARY_MAX, SUMMARY_MIN)
        except Exception as e:
            print("Single pass error:", e)
            return ""

    # Reserve room for the instruction prefix. Without this, prefix + full
    # chunk overflows the model input and every chunk fails.
    prefix_len = len(tokenizer.encode(chunk_prefix, add_special_tokens=False))
    chunk_budget = max(MAX_INPUT_LENGTH - prefix_len, MAX_INPUT_LENGTH // 2)

    summaries = []
    for chunk in chunk_text(text, tokenizer, max_length=chunk_budget):
        try:
            summaries.append(summarize(chunk_prefix + chunk, SUMMARY_MAX, SUMMARY_MIN))
        except Exception as e:
            print("Chunk error:", e)

    if not summaries:
        return ""
    if len(summaries) == 1:
        return summaries[0]
    try:
        return summarize(" ".join(summaries), SUMMARY_MAX * 2, SUMMARY_MIN * 2)
    except Exception as e:
        print("Final pass error:", e)
        return " ".join(summaries)


In [None]:
#  Comparison Table for Multiple Records
# To evaluate or showcase the summarization pipeline across multiple samples.
# NOTE(review): these placeholder demographics are applied to EVERY record
# below and are not read from the data, so they may contradict the report
# text. Pass None (or per-record values) before drawing conclusions.
DISPLAY_AGE = "50"
DISPLAY_GENDER = "female"

df_display = df_cleaned.head(30).copy()
# One model call per record; tqdm shows progress for this slow cell.
df_display['generated_summary'] = [
    generate_medical_summary(transcription, age=DISPLAY_AGE, gender=DISPLAY_GENDER)
    for transcription in tqdm(df_display['cleaned_transcription'], total=len(df_display), desc="Summarizing")
]
display(df_display[['description', 'cleaned_transcription', 'generated_summary']])
