## **Medical Report Summarisation using Medical Knowledge**

### **References**

**Main Reference**
- [Radiology report generation with medical knowledge and multilevel image-report alignment: A new method and its verification](https://www.sciencedirect.com/science/article/pii/S0933365723002282)



## **Data Collection**

### **Collect Datasets**

In [None]:
'''Libraries Installation and Import'''

# install necessary libraries
!pip install Pillow
!pip install torchvision
!pip install nltk
!pip install pyspellchecker
!pip install tqdm
!pip install opencv-python

# importing required libraries
import os
import requests
import tarfile
import numpy as np
import pandas as pd
import xml.etree.ElementTree as ET
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transforms
import cv2
import re
import nltk
from nltk.stem import PorterStemmer, WordNetLemmatizer
from nltk.corpus import stopwords
from spellchecker import SpellChecker

In [None]:
'''Setup - Generalized'''

# setup to download the IU X-Ray Dataset
# `dataset` keeps its trailing slash, so `download_path` also ends with a path separator
dataset = 'iu_xray/'
download_path = os.path.join('./datasets', dataset)

# optional Google Colab variant: mount Drive and keep the dataset there instead of './datasets'
# from google.colab import drive
# drive.mount('/content/drive')
# download_path = os.path.join('/content/drive/MyDrive/Academics/CS550 Machine Learning/CS550 ASMT MRSMK/datasets', dataset)

# local target directories for the two archives (checked for existence before downloading)
images_dir = os.path.join(download_path, "images")
reports_dir = os.path.join(download_path, "reports")

# gzip-compressed tar archives: PNG images and the accompanying reports
images_url = "https://openi.nlm.nih.gov/imgs/collections/NLMCXR_png.tgz"
reports_url = "https://openi.nlm.nih.gov/imgs/collections/NLMCXR_reports.tgz"


# function to check the file size of a given URL
def get_file_size(url, timeout=30):
    """Return the size of the resource at ``url`` in megabytes (bytes / 1024**2).

    The size is read from the ``Content-Length`` header of a HEAD request.

    Args:
        url: Address of the file to inspect.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The advertised size in megabytes, or ``0.0`` when the server answers
        with an error status or does not send a ``Content-Length`` header.
    """
    # requests.head() does not follow redirects unless asked to; without it the
    # header of the redirect response would be measured instead of the file
    response = requests.head(url, allow_redirects=True, timeout=timeout)

    # an error page has its own Content-Length; do not report it as the file size
    if not response.ok:
        return 0.0

    size_in_bytes = int(response.headers.get('Content-Length', 0))
    return size_in_bytes / (1024 * 1024)


# function to download and extract from a given url to a given directory
def download_and_extract(url, save_dir):
    """Download the ``.tgz`` archive at ``url`` and unpack it into ``save_dir``.

    The archive is streamed to ``save_dir`` under its URL file name, extracted
    in place and then deleted. Progress is printed while downloading (one line
    per received chunk) and while extracting (one line per archive member).

    Args:
        url: Address of a gzip-compressed tar archive.
        save_dir: Directory that receives the archive and its contents; it is
            created when missing.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        tarfile.TarError: If the archive cannot be read or a member is rejected
            by the extraction filter.
    """
    file_name = url.split('/')[-1]
    file_path = os.path.join(save_dir, file_name)
    os.makedirs(save_dir, exist_ok=True)

    try:
        with requests.get(url, stream=True, timeout=60) as response:
            # fail early instead of saving an HTML error page as the archive
            response.raise_for_status()
            total_size = int(response.headers.get('Content-Length', 0))
            downloaded_size = 0

            with open(file_path, 'wb') as file:
                # 1 MiB chunks keep the progress log to one line per megabyte
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if not chunk:
                        continue
                    file.write(chunk)
                    downloaded_size += len(chunk)
                    if total_size:
                        percent_complete = (downloaded_size / total_size) * 100
                        print(f"Downloaded {downloaded_size / (1024*1024):.2f} MB out of {total_size / (1024*1024):.2f} MB: {percent_complete:.2f}% complete")
                    else:
                        # no Content-Length header: a percentage cannot be computed
                        print(f"Downloaded {downloaded_size / (1024*1024):.2f} MB (total size unknown)")

        print("\nDownload complete!")

        with tarfile.open(file_path, 'r:gz') as tar:
            members = tar.getmembers()
            total_files = len(members)

            # the archive comes from the network: where the interpreter supports
            # extraction filters, refuse members that would escape save_dir
            extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}

            for idx, member in enumerate(members, start=1):
                tar.extract(member, path=save_dir, **extract_kwargs)
                print(f"Extracting File {idx} out of {total_files}: {member.name}")
    finally:
        # remove the archive after success and also after a failed/partial download
        if os.path.exists(file_path):
            os.remove(file_path)


# download each IU X-Ray archive unless its target directory is already present
if os.path.exists(images_dir):
    print(f"{images_url} already exists at: {images_dir}")
else:
    images_size = get_file_size(images_url)
    print(f"Downloading {images_url} to: {images_dir} ({images_size:.2f} MB)")
    os.makedirs(images_dir, exist_ok=True)
    download_and_extract(images_url, images_dir)
    print(f"Downloaded {images_url} to: {images_dir}")

if os.path.exists(reports_dir):
    print(f"{reports_url} already exists at: {reports_dir}")
else:
    reports_size = get_file_size(reports_url)
    print(f"Downloading {reports_url} to: {reports_dir} ({reports_size:.2f} MB)")
    os.makedirs(reports_dir, exist_ok=True)
    download_and_extract(reports_url, reports_dir)
    print(f"Downloaded {reports_url} to: {reports_dir}")

In [None]:
'''Exploring the IU X-Ray Dataset Contents'''

# displaying directory and subdirectory contents
# the three names defined here (iu_xray, iu_xray_images, iu_xray_reports) are reused by later cells

# dataset root: list its entries
iu_xray = download_path
print("\nPath: ", iu_xray)
print(f"Directory Contents: {os.listdir(iu_xray)}")

# images directory: the count covers every directory entry, not only .png files
iu_xray_images = images_dir
print("\nPath: ", iu_xray_images)
print(f"Directory Contents: {len(os.listdir(iu_xray_images))} Images")

# reports live in the 'ecgen-radiology' subdirectory of the reports directory
iu_xray_reports = os.path.join(reports_dir, 'ecgen-radiology')
print("\nPath: ", iu_xray_reports)
print(f"Directory Contents: {len(os.listdir(iu_xray_reports))} Reports")

In [None]:
'''Processing Textual Data from each .xml Report File and Storing it in a .csv File'''

# function to iterate through all .xml report files and collect one row per referenced image
def save_images_df():
    """Build one row per readable image referenced by the .xml reports.

    Every ``.xml`` file in ``iu_xray_reports`` is parsed. For each
    ``parentImage`` element of a report, the matching ``.png`` file in
    ``iu_xray_images`` is read to obtain its dimensions.

    Returns:
        A list of rows ``[pmc_id, image_filename, caption, comparison,
        indication, findings, impression, height, width]``. Images that cannot
        be read and reports that cannot be processed are reported on stdout
        and skipped.
    """
    data = []

    # list the directory once (not once per file) and sort it so that the
    # resulting rows come out in the same order on every run
    xml_files = sorted(name for name in os.listdir(iu_xray_reports) if name.endswith(".xml"))
    total = len(xml_files)
    section_labels = ('COMPARISON', 'INDICATION', 'FINDINGS', 'IMPRESSION')

    for cnt, file in enumerate(xml_files, start=1):
        print(f"Processing .xml File {cnt} out of {total}: {file}")

        file_path = os.path.join(iu_xray_reports, file)
        try:
            root = ET.parse(file_path).getroot()

            pmc_id = root.find('.//pmcId').attrib.get('id')

            # sections missing from the report stay None; a repeated label keeps
            # the text of its last occurrence
            sections = dict.fromkeys(section_labels)
            for abstract in root.findall('.//AbstractText'):
                label = abstract.attrib.get('Label')
                if label in sections:
                    sections[label] = abstract.text

            for parent_image in root.findall('parentImage'):
                image_file = parent_image.attrib['id'] + ".png"
                image_path = os.path.join(iu_xray_images, image_file)
                image = cv2.imread(image_path)

                # cv2.imread signals an unreadable/missing file by returning None
                if image is None:
                    print(f"Warning: Unable to read image {image_path}")
                    continue

                height, width = image.shape[:2]
                caption_node = parent_image.find('caption')
                caption = caption_node.text if caption_node is not None else None
                data.append([
                    pmc_id, image_file, caption,
                    sections['COMPARISON'], sections['INDICATION'],
                    sections['FINDINGS'], sections['IMPRESSION'],
                    height, width,
                ])

        except Exception as e:
            # deliberate best-effort: one broken report must not abort the whole run
            print(f"Error processing file {file}: {e}")

    return data


# create the per-image dataframe and cache it as csv; later runs reload the csv
iu_xray_images_df_path = os.path.join(iu_xray, 'iu_xray_images_df.csv')
if not os.path.exists(iu_xray_images_df_path):
    data = save_images_df()
    columns = ['pmc_id', 'image_filename', 'caption', 'comparison', 'indication', 'findings', 'impression', 'height', 'width']
    iu_xray_images_df = pd.DataFrame(data, columns=columns)
    iu_xray_images_df.to_csv(iu_xray_images_df_path, index=False)
    print(f"Dataframe saved to {iu_xray_images_df_path}")
else:
    print(f"Dataframe already exists at {iu_xray_images_df_path}")
    iu_xray_images_df = pd.read_csv(iu_xray_images_df_path)


# display the stored dataframe
print("\n\nDataframe Shape:", iu_xray_images_df.shape)

print("\n\nDataframe Information:\n")
# DataFrame.info() prints its summary itself and returns None, so wrapping it
# in display() only added a stray "None" to the output
iu_xray_images_df.info()

print("\n\nDisplaying Dataframe:\n")
display(iu_xray_images_df.head())

In [None]:
'''Processing Textual Data from each .xml Report File and Storing it in a .csv File'''

# function to iterate through all .xml report files and collect one record per report
def save_reports_df():
    """Build one record per .xml report found in ``iu_xray_reports``.

    Returns:
        A list of dicts with the keys ``pmc_id``, ``findings``, ``impression``,
        ``comparison``, ``indication`` and ``image_count``, plus one
        ``image_<n>`` entry per ``parentImage`` element holding
        ``"<file>.png: <caption>"`` (or just ``"<file>.png"`` when the caption
        is missing or empty). Reports that cannot be processed are reported on
        stdout and skipped.
    """
    data = []

    # list the directory once (not once per file) and sort it so that the
    # resulting records come out in the same order on every run
    xml_files = sorted(name for name in os.listdir(iu_xray_reports) if name.endswith(".xml"))
    total = len(xml_files)
    section_labels = ('COMPARISON', 'INDICATION', 'FINDINGS', 'IMPRESSION')

    for cnt, file in enumerate(xml_files, start=1):
        print(f"Processing .xml File {cnt} out of {total}: {file}")

        file_path = os.path.join(iu_xray_reports, file)
        try:
            root = ET.parse(file_path).getroot()

            pmc_id = root.find('.//pmcId').attrib.get('id')

            # sections missing from the report stay None; a repeated label keeps
            # the text of its last occurrence
            sections = dict.fromkeys(section_labels)
            for abstract in root.findall('.//AbstractText'):
                label = abstract.attrib.get('Label')
                if label in sections:
                    sections[label] = abstract.text

            report_data = {
                'pmc_id': pmc_id,
                'findings': sections['FINDINGS'],
                'impression': sections['IMPRESSION'],
                'comparison': sections['COMPARISON'],
                'indication': sections['INDICATION'],
            }

            parent_images = root.findall('parentImage')
            report_data['image_count'] = len(parent_images)

            for i, parent_image in enumerate(parent_images, start=1):
                # the downloaded images are PNG files (the per-image table uses
                # the same suffix), so the recorded name must end in .png
                image_file = parent_image.attrib['id'] + ".png"
                caption_node = parent_image.find('caption')
                caption = caption_node.text if caption_node is not None else None
                report_data[f'image_{i}'] = f"{image_file}: {caption}" if caption else image_file

            data.append(report_data)

        except Exception as e:
            # deliberate best-effort: one broken report must not abort the whole run
            print(f"Error processing file {file}: {e}")

    return data


# create the per-report dataframe and cache it as csv; later runs reload the csv
iu_xray_reports_df_path = os.path.join(iu_xray, 'iu_xray_reports_df.csv')
if not os.path.exists(iu_xray_reports_df_path):
    data = save_reports_df()
    iu_xray_reports_df = pd.DataFrame(data)
    iu_xray_reports_df.to_csv(iu_xray_reports_df_path, index=False)
    print(f"Dataframe saved to {iu_xray_reports_df_path}")
else:
    print(f"Dataframe already exists at {iu_xray_reports_df_path}")
    iu_xray_reports_df = pd.read_csv(iu_xray_reports_df_path)


# display the stored dataframe
print("\n\nDataframe Shape:", iu_xray_reports_df.shape)

print("\n\nDataframe Information:\n")
# DataFrame.info() prints its summary itself and returns None, so wrapping it
# in display() only added a stray "None" to the output
iu_xray_reports_df.info()

print("\n\nDisplaying Dataframe:\n")
display(iu_xray_reports_df.head())

In [None]:
'''Displaying the Number of Images per Report'''

# tabulate how many reports reference each number of images
reports_count = (
    iu_xray_reports_df['image_count']
    .value_counts()
    .rename_axis('images_qty')
    .reset_index(name='reports_count')
)
print("\n\nNumber of Images per Report:\n")
display(reports_count)

## **Data Preprocessing**

### **Preprocess Images**

In [None]:
'''Preprocessing Images - Resizing, Tensor Conversion and Normalization'''

# function to find minimum dimensions of given set of images
def find_min_dimensions(image_dir):
    """Return the smallest width and the smallest height among the .png files.

    The two minima are tracked independently, so they may come from different
    images. When the directory holds no ``.png`` file both values stay at
    ``float('inf')``.
    """
    smallest_w = float('inf')
    smallest_h = float('inf')

    png_names = (name for name in os.listdir(image_dir) if name.endswith('.png'))
    for name in png_names:
        # the size is read inside the context manager; the file is closed right after
        with Image.open(os.path.join(image_dir, name)) as picture:
            w, h = picture.size
        if w < smallest_w:
            smallest_w = w
        if h < smallest_h:
            smallest_h = h

    return smallest_w, smallest_h


# function to preprocess and save images
def preprocess_images(input_dir, output_dir):
    """Resize every .png in ``input_dir`` to a common size and save it to ``output_dir``.

    The common size is the minimum width and minimum height found among the
    images of ``input_dir``. Each image is converted to RGB, resized and written
    under its original file name.

    Only the resize is persisted. Tensor conversion and mean/std normalisation
    are intentionally not applied here: normalised tensors contain negative and
    greater-than-one values that an 8-bit PNG cannot store, so writing them back
    to disk corrupts the pixels. Apply ``ToTensor``/``Normalize`` when the images
    are loaded for the model instead.

    Args:
        input_dir: Directory containing the source ``.png`` images.
        output_dir: Directory that receives the resized images; created if missing.
    """
    # measure the directory that is actually being processed, not a global one
    min_width, min_height = find_min_dimensions(input_dir)
    print(f'Minimum Width: {min_width}, Minimum Height: {min_height}\n')

    os.makedirs(output_dir, exist_ok=True)

    png_files = [name for name in os.listdir(input_dir) if name.endswith('.png')]
    if not png_files:
        # nothing to resize; the minimum dimensions are not finite in this case
        return

    # torchvision's Resize takes the target size as (height, width)
    resize = transforms.Resize((min_height, min_width))

    total = len(png_files)
    for cnt, filename in enumerate(png_files, start=1):
        print(f"Preprocessing File {cnt} out of {total}: {filename}")

        # the context manager closes the source file once the pixels are converted
        with Image.open(os.path.join(input_dir, filename)) as image:
            resized_image = resize(image.convert('RGB'))

        resized_image.save(os.path.join(output_dir, filename))


# resize the images once; an existing output directory is taken as "already done"
iu_xray_images_preprocessed = os.path.join(iu_xray, 'images_preprocessed')
if os.path.exists(iu_xray_images_preprocessed):
    print(f"Preprocessed Images already exist at: {iu_xray_images_preprocessed}")
else:
    print(f"Preprocessing Images to: {iu_xray_images_preprocessed}")
    preprocess_images(iu_xray_images, iu_xray_images_preprocessed)
    print(f"Preprocessed Images saved to: {iu_xray_images_preprocessed}")

### **Preprocess Text**

In [None]:
'''Preprocessing Text - Lowercasing, Decontracting, Punctuation Removal, Number Removal, Two-Letter Word Removal, Stop Word Removal, Negation Handling, Spell Checking, Extra Space Removal'''

# fetch the NLTK resources in the same order as before, then create the shared spell checker
for _nltk_resource in ('punkt', 'wordnet', 'stopwords'):
    nltk.download(_nltk_resource)
spell = SpellChecker()


# function to convert text to lowercase
def lowercase(text):
    """Return ``text`` in lower case; non-string values are returned unchanged."""
    if not isinstance(text, str):
        return text
    return text.lower()


# function to decontract words
def decontracted(text):
    """Expand English contractions in ``text``; non-string values pass through.

    The replacements are plain substring substitutions applied in the listed
    order, so the specific forms (e.g. "won't") are handled before the generic
    suffixes (e.g. "n't").
    """
    if not isinstance(text, str):
        return text

    replacements = (
        ("won't", "will not"),
        ("can't", "can not"),
        ("couldn't", "could not"),
        ("shouldn't", "should not"),
        ("wouldn't", "would not"),
        ("n't", " not"),
        ("'re", " are"),
        ("'s", " is"),
        ("'d", " would"),
        ("'ll", " will"),
        ("'t", " not"),
        ("'ve", " have"),
        ("'m", " am"),
    )

    expanded = text
    for short_form, long_form in replacements:
        expanded = expanded.replace(short_form, long_form)
    return expanded


# function to remove punctuations
def rem_punctuations(text):
    """Delete every character that is neither a word character nor whitespace.

    Non-string values are returned unchanged.
    """
    if not isinstance(text, str):
        return text
    return re.sub(r'[^\w\s]', '', text)


# function to remove numbers
def rem_numbers(text):
    """Delete every run of digits from ``text``; non-string values pass through."""
    if not isinstance(text, str):
        return text
    return re.sub(r'\d+', '', text)


# function to remove words of one or two letters, except "no" and "ct"
def rem_two_letter_words(text):
    """Keep only words longer than two characters, plus "no" and "ct".

    The result is re-joined with single spaces. Non-string values are returned
    unchanged.
    """
    if not isinstance(text, str):
        return text

    kept = []
    for token in text.split():
        if token in ("no", "ct") or len(token) >= 3:
            kept.append(token)
    return ' '.join(kept)


# function to remove stop words
def rem_stop_words(text):
    """Remove English stop words from ``text`` while keeping negations.

    NLTK's English stop-word list contains "no" and "not". Dropping them would
    turn "no pneumothorax" into "pneumothorax" and leave nothing for the later
    negation-handling step to work on, so these two words are always preserved.

    Non-string values are returned unchanged; the result is re-joined with
    single spaces.
    """
    if not isinstance(text, str):
        return text
    stop_words = _stop_words_without_negations()
    return ' '.join(word for word in text.split() if word not in stop_words)


# filled on first use so the stop-word corpus is read once, not once per sentence
_STOP_WORDS_CACHE = None


def _stop_words_without_negations():
    """Return the cached NLTK English stop words minus "no" and "not"."""
    global _STOP_WORDS_CACHE
    if _STOP_WORDS_CACHE is None:
        _STOP_WORDS_CACHE = frozenset(stopwords.words('english')) - {"no", "not"}
    return _STOP_WORDS_CACHE


# function to handle negations
def handle_negations(text):
    """Normalise the negation words "no" and "not" to "not".

    Matching is done on whole whitespace-separated words; the result is
    re-joined with single spaces. Non-string values are returned unchanged.
    """
    if not isinstance(text, str):
        return text

    normalised = []
    for token in text.split():
        normalised.append("not" if token in ("no", "not") else token)
    return ' '.join(normalised)


# function to correct spelling
def correct_spelling(text):
    """Replace each word of ``text`` by the spell checker's best correction.

    ``spell.correction`` is used instead of taking the first element of
    ``spell.candidates``: the candidates form an unordered set, so its "first"
    element is arbitrary and can change between runs. A word for which the
    checker has no suggestion is kept as it is.

    Non-string values are returned unchanged; the result is re-joined with
    single spaces.
    """
    if not isinstance(text, str):
        return text

    # repeated words within one text are looked up only once
    suggestions = {}
    corrected = []
    for word in text.split():
        if word not in suggestions:
            # correction() may return None when nothing suitable is found
            suggestions[word] = spell.correction(word) or word
        corrected.append(suggestions[word])
    return ' '.join(corrected)


# function to remove extra spaces
def rem_extra_spaces(text):
    """Collapse whitespace runs to single spaces and trim both ends.

    Non-string values are returned unchanged.
    """
    if not isinstance(text, str):
        return text
    pieces = text.split()
    return ' '.join(pieces)


# function to preprocess text
def preprocess_text(data):
    """Run the full text-cleaning pipeline over every value of ``data``.

    ``data`` must expose ``.values`` (e.g. a pandas Series). Each value is
    converted with ``str()`` and passed through the cleaning steps in the order
    listed below, with a tqdm progress bar over the values.

    Returns:
        A list with one cleaned string per input value.
    """
    # the order of the steps is significant and mirrors the cell title
    pipeline = (
        lowercase,
        decontracted,
        rem_punctuations,
        rem_numbers,
        rem_two_letter_words,
        rem_stop_words,
        handle_negations,
        correct_spelling,
        rem_extra_spaces,
    )

    cleaned = []
    for raw_value in tqdm(data.values):
        current = str(raw_value)
        for step in pipeline:
            current = step(current)
        cleaned.append(current)

    return cleaned


# function to preprocess text and save the corresponding dataframe
def preprocess_and_save_dataframe(dataframe, path):
    """Clean the text columns of ``dataframe`` and write the result to ``path``.

    The columns 'caption', 'comparison', 'indication', 'findings' and
    'impression' are processed when present: missing values are replaced by a
    per-column placeholder, the values are cast to ``str`` and then passed
    through ``preprocess_text``.

    The work is done on a copy, so the DataFrame passed in keeps its raw text.

    Args:
        dataframe: Source DataFrame with the raw report text.
        path: Destination of the CSV file (written without the index).

    Returns:
        The new, preprocessed DataFrame.
    """
    # placeholder used for missing values, per column
    columns_to_preprocess = {
        'caption': 'unknown',
        'comparison': 'no comparison',
        'indication': 'no indication',
        'findings': 'no findings',
        'impression': 'no impression'
    }

    # do not overwrite the caller's raw data
    preprocessed = dataframe.copy()

    for column, fill_value in columns_to_preprocess.items():
        if column in preprocessed.columns:
            print(f"Preprocessing Column: {column}")
            filled = preprocessed[column].fillna(fill_value).astype(str)
            # plain column assignment replaces the column and its dtype, so a
            # column loaded as all-NaN floats can receive the cleaned strings
            preprocessed[column] = preprocess_text(filled)

    preprocessed.to_csv(path, index=False)

    return preprocessed


# save and display the preprocessed dataframe for images
# iu_xray_images_preprocessed_df_path = os.path.join(iu_xray, 'iu_xray_images_preprocessed_df.csv')
# if not os.path.exists(iu_xray_images_preprocessed_df_path):
#     print(f"Preprocessing Text DataFrame {iu_xray_images_df_path} to: {iu_xray_images_preprocessed_df_path}")
#     iu_xray_images_preprocessed_df = preprocess_and_save_dataframe(iu_xray_images_df, iu_xray_images_preprocessed_df_path)
#     print(f"Preprocessed Text DataFrame {iu_xray_images_df_path} saved to: {iu_xray_images_preprocessed_df_path}")
# else:
#     print(f"Preprocessed Text DataFrame {iu_xray_images_df_path} already exists at: {iu_xray_images_preprocessed_df_path}")
#     iu_xray_images_preprocessed_df = pd.read_csv(iu_xray_images_preprocessed_df_path)
# display(iu_xray_images_preprocessed_df.head())


# preprocess the report text once and cache it as csv; later runs reload the csv
iu_xray_reports_preprocessed_df_path = os.path.join(iu_xray, 'iu_xray_reports_preprocessed_df.csv')
if os.path.exists(iu_xray_reports_preprocessed_df_path):
    print(f"Preprocessed Text DataFrame {iu_xray_reports_df_path} already exists at: {iu_xray_reports_preprocessed_df_path}")
    iu_xray_reports_preprocessed_df = pd.read_csv(iu_xray_reports_preprocessed_df_path)
else:
    print(f"Preprocessing Text DataFrame {iu_xray_reports_df_path} to: {iu_xray_reports_preprocessed_df_path}")
    iu_xray_reports_preprocessed_df = preprocess_and_save_dataframe(iu_xray_reports_df, iu_xray_reports_preprocessed_df_path)
    print(f"Preprocessed Text DataFrame {iu_xray_reports_df_path} saved to: {iu_xray_reports_preprocessed_df_path}")
display(iu_xray_reports_preprocessed_df.head())

### **Create Data Loaders**

## **Model Implementation**

### **Visual Extractor**

### **Text Encoder**

### **Multilevel Alignment**

### **Report Generator**

### **Complete Model**

## **Training**

### **Training**

## **Testing**

### **Testing**

## **Dataset Download as Zip File**

In [None]:
'''Downloading Dataset Directory with all Changes'''

# importing required libraries
import shutil
import os
from google.colab import files


# zipping and downloading the archive
zip_filename = 'IUXR.zip'
# make_archive appends the '.zip' suffix itself, so it gets the name without it
shutil.make_archive(
    base_name=zip_filename[:-len('.zip')],
    format='zip',
    root_dir=download_path,
)
files.download(zip_filename)