## **Medical Report Summarisation using Medical Knowledge**

### **References**

**Main Reference**
- Radiology report generation with medical knowledge and multilevel image-report alignment: A new method and its verification
https://www.sciencedirect.com/science/article/pii/S0933365723002282



## **Data Collection**

### **Collect Datasets**

In [2]:
'''Libraries Installation and Import'''

# installling necessary libraries
!pip -q install --user requests numpy pandas matplotlib tqdm Pillow opencv-python nltk pyspellchecker torch torchvision torchaudio transformers scikit-learn sentence-transformers

# importing required libraries
import os
import re
import csv
import requests
import tarfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from tqdm import tqdm
from PIL import Image
import cv2

import nltk
from nltk.corpus import stopwords
from spellchecker import SpellChecker

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from torchvision import datasets
import torchvision.transforms as transforms
import torch.nn.functional as F

import transformers
from transformers import BertTokenizer, BertModel
from sentence_transformers import SentenceTransformer
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AutoTokenizer

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans

In [3]:
'''Setting Paths'''

# project directory
# from google.colab import drive
# drive.mount('/content/drive')
# project_directory = '/content/drive/Othercomputers/My Laptop/CS550_ASMT_MRSMK/datasets'
# project_directory = '/content/drive/MyDrive/Academics/CS550 Machine Learning/CS550 ASMT MRSMK/datasets'

project_directory = "./datasets"
dataset = 'iu_xray/'
iu_xray_dataset = os.path.join(project_directory, dataset)


# input directory
input_directory = os.path.join(iu_xray_dataset, "input")

images_dir = os.path.join(input_directory, "images")
reports_dir = os.path.join(input_directory, "reports")
iu_xray_images = images_dir
iu_xray_reports = os.path.join(reports_dir, 'ecgen-radiology')


# output directory 
output_directory = os.path.join(iu_xray_dataset, "output")
os.makedirs(output_directory, exist_ok=True)

In [7]:
'''Setup - Generalized'''

# setup to download the IU X-Ray Dataset
images_url = "https://openi.nlm.nih.gov/imgs/collections/NLMCXR_png.tgz"
reports_url = "https://openi.nlm.nih.gov/imgs/collections/NLMCXR_reports.tgz"


# function to check the file size of a given URL
def get_file_size(url):
    response = requests.head(url)
    size_in_bytes = int(response.headers.get('Content-Length', 0))
    size_in_mb = size_in_bytes / (1024 * 1024)
    return size_in_mb


# function to download and extract from a given url to a given directory
def download_and_extract(url, save_dir):
    file_name = url.split('/')[-1]
    file_path = os.path.join(save_dir, file_name)

    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('Content-Length', 0))
    downloaded_size = 0

    with open(file_path, 'wb') as file:
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                file.write(chunk)
                downloaded_size += len(chunk)
                percent_complete = (downloaded_size / total_size) * 100
                print(f"Downloaded {downloaded_size / (1024*1024):.2f} MB out of {total_size / (1024*1024):.2f} MB: {percent_complete:.2f}% complete")

    print("\nDownload complete!")

    with tarfile.open(file_path, 'r:gz') as tar:
        members = tar.getmembers()
        total_files = len(members)

        for idx, member in enumerate(members, start=1):
            tar.extract(member, path=save_dir)
            print(f"Extracting File {idx} out of {total_files}: {member.name}")

    os.remove(file_path)


# downloading  IU X-Ray dataset
if not os.path.exists(images_dir):
    images_size = get_file_size(images_url)
    print(f"Downloading {images_url} to: {images_dir} ({images_size:.2f} MB)")
    os.makedirs(images_dir, exist_ok=True)
    download_and_extract(images_url, images_dir)
    print(f"Downloaded {images_url} to: {images_dir}")
else:
    print(f"{images_url} already exists at: {images_dir}")

if not os.path.exists(reports_dir):
    reports_size = get_file_size(reports_url)
    print(f"Downloading {reports_url} to: {reports_dir} ({reports_size:.2f} MB)")
    os.makedirs(reports_dir, exist_ok=True)
    download_and_extract(reports_url, reports_dir)
    print(f"Downloaded {reports_url} to: {reports_dir}")
else:
    print(f"{reports_url} already exists at: {reports_dir}")

https://openi.nlm.nih.gov/imgs/collections/NLMCXR_png.tgz already exists at: ./datasets/iu_xray/input/images
https://openi.nlm.nih.gov/imgs/collections/NLMCXR_reports.tgz already exists at: ./datasets/iu_xray/input/reports


In [8]:
'''Exploring the IU X-Ray Dataset Contents'''

# displaying directory and subdirectory contents
print("\nPath: ", iu_xray_images)
print(f"Directory Contents: {len(os.listdir(iu_xray_images))} Images")

print("\nPath: ", iu_xray_reports)
print(f"Directory Contents: {len(os.listdir(iu_xray_reports))} Reports")


Path:  ./datasets/iu_xray/input/images
Directory Contents: 7471 Images

Path:  ./datasets/iu_xray/input/reports/ecgen-radiology
Directory Contents: 3955 Reports


In [9]:
'''Processing Textual Data from each .xml Report File and Storing it in a .csv File'''

# function to iterate through all .xml report files and storing them in a dataframe
def save_images_df():
    data = []
    cnt = 0
    for file in os.listdir(iu_xray_reports):
        if file.endswith(".xml"):
            cnt += 1
            print(f"Processing .xml File {cnt} out of {len(os.listdir(iu_xray_reports))}: {file}")

            file_path = os.path.join(iu_xray_reports, file)
            try:
                tree = ET.parse(file_path)
                root = tree.getroot()

                pmc_id = root.find('.//pmcId').attrib.get('id')

                comparison = indication = findings = impression = None

                for abstract in root.findall('.//AbstractText'):
                    if abstract.attrib.get('Label') == 'COMPARISON':
                        comparison = abstract.text
                    elif abstract.attrib.get('Label') == 'INDICATION':
                        indication = abstract.text
                    elif abstract.attrib.get('Label') == 'FINDINGS':
                        findings = abstract.text
                    elif abstract.attrib.get('Label') == 'IMPRESSION':
                        impression = abstract.text

                for parent_image in root.findall('parentImage'):
                    image_file = parent_image.attrib['id'] + ".png"
                    image_path = os.path.join(iu_xray_images, image_file)
                    image = cv2.imread(image_path)

                    if image is not None:
                        height, width, channels = image.shape
                        caption = parent_image.find('caption').text if parent_image.find('caption') is not None else None
                        data.append([pmc_id, image_file, caption, comparison, indication, findings, impression, height, width])
                    else:
                        print(f"Warning: Unable to read image {image_path}")

            except Exception as e:
                print(f"Error processing file {file}: {e}")

    return data


# creating a dataframe and saving it as .csv
iu_xray_images_df_path = os.path.join(output_directory, 'iu_xray_images_df.csv')
if not os.path.exists(iu_xray_images_df_path):
    data = save_images_df()
    columns = ['pmc_id', 'image_filename', 'caption', 'comparison', 'indication', 'findings', 'impression', 'height', 'width']
    iu_xray_images_df = pd.DataFrame(data, columns=columns)
    iu_xray_images_df.to_csv(iu_xray_images_df_path, index=False)
    print(f"Dataframe saved to {iu_xray_images_df_path}")
else:
    print(f"Dataframe already exists at {iu_xray_images_df_path}")
    iu_xray_images_df = pd.read_csv(iu_xray_images_df_path)


# displaying the stored dataframe
print("\n\nDataframe Shape:", iu_xray_images_df.shape)

print("\n\nDataframe Information:\n")
display(iu_xray_images_df.info())

print("\n\nDisplaying Dataframe:\n")
display(iu_xray_images_df.head())

Dataframe already exists at ./datasets/iu_xray/output/iu_xray_images_df.csv


Dataframe Shape: (7470, 9)


Dataframe Information:

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7470 entries, 0 to 7469
Data columns (total 9 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   pmc_id          7470 non-null   int64 
 1   image_filename  7470 non-null   object
 2   caption         7468 non-null   object
 3   comparison      5210 non-null   object
 4   indication      7311 non-null   object
 5   findings        6473 non-null   object
 6   impression      7418 non-null   object
 7   height          7470 non-null   int64 
 8   width           7470 non-null   int64 
dtypes: int64(3), object(6)
memory usage: 525.4+ KB


None



Displaying Dataframe:



Unnamed: 0,pmc_id,image_filename,caption,comparison,indication,findings,impression,height,width
0,1,CXR1_1_IM-0001-3001.png,Xray Chest PA and Lateral,None.,Positive TB test,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,624,512
1,1,CXR1_1_IM-0001-4001.png,Xray Chest PA and Lateral,None.,Positive TB test,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,420,512
2,10,CXR10_IM-0002-1001.png,PA and lateral chest x-XXXX XXXX.,Chest radiographs XXXX.,"XXXX-year-old male, chest pain.",The cardiomediastinal silhouette is within nor...,No acute cardiopulmonary process.,624,512
3,10,CXR10_IM-0002-2001.png,PA and lateral chest x-XXXX XXXX.,Chest radiographs XXXX.,"XXXX-year-old male, chest pain.",The cardiomediastinal silhouette is within nor...,No acute cardiopulmonary process.,420,512
4,100,CXR100_IM-0002-1001.png,"CHEST 2V FRONTAL/LATERAL XXXX, XXXX XXXX PM",None.,,Both lungs are clear and expanded. Heart and m...,No active disease.,420,512


In [10]:
'''Processing Textual Data from each .xml Report File and Storing it in a .csv File'''

# function to iterate through all .xml report files and storing them in a dataframe
def save_reports_df():
    data = []
    cnt = 0
    for file in os.listdir(iu_xray_reports):
        if file.endswith(".xml"):
            cnt += 1
            print(f"Processing .xml File {cnt} out of {len(os.listdir(iu_xray_reports))}: {file}")

            file_path = os.path.join(iu_xray_reports, file)
            try:
                tree = ET.parse(file_path)
                root = tree.getroot()

                pmc_id = root.find('.//pmcId').attrib.get('id')

                comparison = indication = findings = impression = None

                for abstract in root.findall('.//AbstractText'):
                    if abstract.attrib.get('Label') == 'COMPARISON':
                        comparison = abstract.text
                    elif abstract.attrib.get('Label') == 'INDICATION':
                        indication = abstract.text
                    elif abstract.attrib.get('Label') == 'FINDINGS':
                        findings = abstract.text
                    elif abstract.attrib.get('Label') == 'IMPRESSION':
                        impression = abstract.text

                report_data = {
                    'pmc_id': pmc_id,
                    'findings': findings,
                    'impression': impression,
                    'comparison': comparison,
                    'indication': indication,
                }

                parent_images = root.findall('parentImage')
                report_data['image_count'] = len(parent_images)

                for i, parent_image in enumerate(parent_images, start=1):
                    image_file = parent_image.attrib['id'] + ".jpg"
                    caption = parent_image.find('caption').text if parent_image.find('caption') is not None else None
                    report_data[f'image_{i}'] = f"{image_file}: {caption}" if caption else image_file

                data.append(report_data)

            except Exception as e:
                print(f"Error processing file {file}: {e}")

    return data


# creating a dataframe and saving it as .csv
iu_xray_reports_df_path = os.path.join(output_directory, 'iu_xray_reports_df.csv')
if not os.path.exists(iu_xray_reports_df_path):
    data = save_reports_df()
    iu_xray_reports_df = pd.DataFrame(data)
    iu_xray_reports_df.to_csv(iu_xray_reports_df_path, index=False)
    print(f"Dataframe saved to {iu_xray_reports_df_path}")
else:
    print(f"Dataframe already exists at {iu_xray_reports_df_path}")
    iu_xray_reports_df = pd.read_csv(iu_xray_reports_df_path)


# displaying the stored dataframe
print("\n\nDataframe Shape:", iu_xray_reports_df.shape)

print("\n\nDataframe Information:\n")
display(iu_xray_reports_df.info())

print("\n\nDisplaying Dataframe:\n")
display(iu_xray_reports_df.head())

Dataframe already exists at ./datasets/iu_xray/output/iu_xray_reports_df.csv


Dataframe Shape: (3955, 11)


Dataframe Information:

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3955 entries, 0 to 3954
Data columns (total 11 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   pmc_id       3955 non-null   int64 
 1   findings     3425 non-null   object
 2   impression   3921 non-null   object
 3   comparison   2757 non-null   object
 4   indication   3865 non-null   object
 5   image_count  3955 non-null   int64 
 6   image_1      3851 non-null   object
 7   image_2      3405 non-null   object
 8   image_3      197 non-null    object
 9   image_4      16 non-null     object
 10  image_5      1 non-null      object
dtypes: int64(2), object(9)
memory usage: 340.0+ KB


None



Displaying Dataframe:



Unnamed: 0,pmc_id,findings,impression,comparison,indication,image_count,image_1,image_2,image_3,image_4,image_5
0,1,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,None.,Positive TB test,2,CXR1_1_IM-0001-3001.jpg: Xray Chest PA and Lat...,CXR1_1_IM-0001-4001.jpg: Xray Chest PA and Lat...,,,
1,10,The cardiomediastinal silhouette is within nor...,No acute cardiopulmonary process.,Chest radiographs XXXX.,"XXXX-year-old male, chest pain.",2,CXR10_IM-0002-1001.jpg: PA and lateral chest x...,CXR10_IM-0002-2001.jpg: PA and lateral chest x...,,,
2,100,Both lungs are clear and expanded. Heart and m...,No active disease.,None.,,2,CXR100_IM-0002-1001.jpg: CHEST 2V FRONTAL/LAT...,CXR100_IM-0002-2001.jpg: CHEST 2V FRONTAL/LAT...,,,
3,1000,There is XXXX increased opacity within the rig...,1. Increased opacity in the right upper lobe w...,XXXX PA and lateral chest radiographs,"XXXX-year-old male, XXXX.",3,CXR1000_IM-0003-1001.jpg: PA and lateral chest...,CXR1000_IM-0003-2001.jpg: PA and lateral chest...,CXR1000_IM-0003-3001.jpg: PA and lateral chest...,,
4,1001,Interstitial markings are diffusely prominent ...,Diffuse fibrosis. No visible focal acute disease.,,"dyspnea, subjective fevers, arthritis, immigra...",2,CXR1001_IM-0004-1001.jpg: CHEST 2V FRONTAL/LAT...,CXR1001_IM-0004-1002.jpg: CHEST 2V FRONTAL/LAT...,,,


In [11]:
'''Displaying the Number of Images per Report'''

# displaying the distribution of number of images per report
reports_count = iu_xray_reports_df['image_count'].value_counts().rename_axis('images_qty').reset_index(name='reports_count')
print("\n\nNumber of Images per Report:\n")
display(reports_count)



Number of Images per Report:



Unnamed: 0,images_qty,reports_count
0,2,3208
1,1,446
2,3,181
3,0,104
4,4,15
5,5,1


## **Data Preprocessing**

### **Preprocess Images**

In [12]:
'''Preprocessing Images - Resizing, Tensor Conversion and Normalization'''

# function to preprocess and save images
def preprocess_images(input_dir, output_dir):
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    os.makedirs(output_dir, exist_ok=True)

    cnt = 0
    for filename in os.listdir(input_dir):
        if filename.endswith('.png'):
            cnt += 1
            print(f"Preprocessing File {cnt} out of {len(os.listdir(input_dir))}: {filename}")

            image_path = os.path.join(input_dir, filename)
            image = Image.open(image_path).convert('RGB')
            processed_image = preprocess(image)

            processed_image_path = os.path.join(output_dir, filename)

            processed_image_pil = transforms.ToPILImage()(processed_image)
            processed_image_pil.save(processed_image_path)


# preprocessing images
iu_xray_images_preprocessed = os.path.join(output_directory, 'images_preprocessed')
if not os.path.exists(iu_xray_images_preprocessed):
    print(f"Preprocessing Images to: {iu_xray_images_preprocessed}")
    preprocess_images(iu_xray_images, iu_xray_images_preprocessed)
    print(f"Preprocessed Images saved to: {iu_xray_images_preprocessed}")
else:
    print(f"Preprocessed Images already exist at: {iu_xray_images_preprocessed}")

Preprocessed Images already exist at: ./datasets/iu_xray/output/images_preprocessed


### **Preprocess Text**

In [10]:
'''Preprocessing Text - Lowercasing, Decontracting, Punctuation Removal, Number Removal, Two-Letter Word Removal, Stop Word Removal, Spell Checking, Extra Space Removal'''

# download nltk resources and initialize spell checker
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('stopwords')
spell = SpellChecker()


# function to convert text to lowercase
def lowercase(text):
    return text.lower() if isinstance(text, str) else text


# function to decontract words
def decontracted(text):
    if not isinstance(text, str):
        return text
    contractions = {
        "won't": "will not", "can't": "can not", "couldn't": "could not",
        "shouldn't": "should not", "wouldn't": "would not", "n't": " not",
        "'re": " are", "'s": " is", "'d": " would", "'ll": " will",
        "'t": " not", "'ve": " have", "'m": " am"
    }
    for contraction, full_form in contractions.items():
        text = text.replace(contraction, full_form)
    return text


# function to remove punctuations
def rem_punctuations(text):
    return re.sub(r'[^\w\s]', ' ', text) if isinstance(text, str) else text


# function to remove numbers
def rem_numbers(text):
    return re.sub(r'\d+', ' ', text) if isinstance(text, str) else text


# function to remove two-letter words except "no" and "ct"
def rem_two_letter_words(text):
    if not isinstance(text, str):
        return text
    return ' '.join(word for word in text.split() if len(word) > 2 or word in ["no", "ct"])


# function to remove stop words
def rem_stop_words(text):
    if not isinstance(text, str):
        return text
    stop_words = set(stopwords.words('english'))
    return ' '.join(word for word in text.split() if word not in stop_words)


# function to correct spelling
def correct_spelling(text):
    if not isinstance(text, str):
        return text
    corrected = []
    for word in text.split():
        corrected_word = list(spell.candidates(word))[0] if spell.candidates(word) else word
        corrected.append(corrected_word)
    return ' '.join(corrected)


# function to remove extra spaces
def rem_extra_spaces(text):
    return ' '.join(text.split()) if isinstance(text, str) else text


# function to preprocess text
def preprocess_text(data):
    preprocessed = []
    for sentence in tqdm(data.values):
        sentence = str(sentence)
        sentence = lowercase(sentence)
        sentence = decontracted(sentence)
        sentence = rem_punctuations(sentence)
        sentence = rem_numbers(sentence)
        sentence = rem_two_letter_words(sentence)
        sentence = rem_stop_words(sentence)
        sentence = correct_spelling(sentence)
        sentence = rem_extra_spaces(sentence)
        
        preprocessed.append(sentence)

    return preprocessed


# path to the preprocessed dataframe
iu_xray_reports_preprocessed_df_path = os.path.join(output_directory, 'iu_xray_reports_preprocessed_df.csv')
iu_xray_reports_preprocessed_df = iu_xray_reports_df.copy()


# preprocessing text columns in the dataframe
if not os.path.exists(iu_xray_reports_preprocessed_df_path):
    print(f"Preprocessing Text of DataFrame {iu_xray_reports_df_path} to: {iu_xray_reports_preprocessed_df_path}")
    
    preprocess_caption = True
    preprocess_comparison = True
    preprocess_indication = True
    preprocess_findings = True
    preprocess_impression = True
    
    if preprocess_caption and 'caption' in iu_xray_reports_preprocessed_df.columns:
        print("Preprocessing Column: caption")
        iu_xray_reports_preprocessed_df['caption'] = iu_xray_reports_preprocessed_df['caption'].fillna('unknown').astype(str)
        iu_xray_reports_preprocessed_df['caption'] = preprocess_text(iu_xray_reports_preprocessed_df['caption'])
        iu_xray_reports_preprocessed_df.to_csv(iu_xray_reports_preprocessed_df_path, index=False)
        print(f"Saved preprocessed 'caption' column to: {iu_xray_reports_preprocessed_df_path}")
    
    if preprocess_comparison and 'comparison' in iu_xray_reports_preprocessed_df.columns:
        print("Preprocessing Column: comparison")
        iu_xray_reports_preprocessed_df['comparison'] = iu_xray_reports_preprocessed_df['comparison'].fillna('none').astype(str)
        iu_xray_reports_preprocessed_df['comparison'] = preprocess_text(iu_xray_reports_preprocessed_df['comparison'])
        iu_xray_reports_preprocessed_df.to_csv(iu_xray_reports_preprocessed_df_path, index=False)
        print(f"Saved preprocessed 'comparison' column to: {iu_xray_reports_preprocessed_df_path}")
    
    if preprocess_indication and 'indication' in iu_xray_reports_preprocessed_df.columns:
        print("Preprocessing Column: indication")
        iu_xray_reports_preprocessed_df['indication'] = iu_xray_reports_preprocessed_df['indication'].fillna('none').astype(str)
        iu_xray_reports_preprocessed_df['indication'] = preprocess_text(iu_xray_reports_preprocessed_df['indication'])
        iu_xray_reports_preprocessed_df.to_csv(iu_xray_reports_preprocessed_df_path, index=False)
        print(f"Saved preprocessed 'indication' column to: {iu_xray_reports_preprocessed_df_path}")
    
    if preprocess_findings and 'findings' in iu_xray_reports_preprocessed_df.columns:
        print("Preprocessing Column: findings")
        iu_xray_reports_preprocessed_df['findings'] = iu_xray_reports_preprocessed_df['findings'].fillna('none').astype(str)
        iu_xray_reports_preprocessed_df['findings'] = preprocess_text(iu_xray_reports_preprocessed_df['findings'])
        iu_xray_reports_preprocessed_df.to_csv(iu_xray_reports_preprocessed_df_path, index=False)
        print(f"Saved preprocessed 'findings' column to: {iu_xray_reports_preprocessed_df_path}")
    
    if preprocess_impression and 'impression' in iu_xray_reports_preprocessed_df.columns:
        print("Preprocessing Column: impression")
        iu_xray_reports_preprocessed_df['impression'] = iu_xray_reports_preprocessed_df['impression'].fillna('none').astype(str)
        iu_xray_reports_preprocessed_df['impression'] = preprocess_text(iu_xray_reports_preprocessed_df['impression'])
        iu_xray_reports_preprocessed_df.to_csv(iu_xray_reports_preprocessed_df_path, index=False)
        print(f"Saved preprocessed 'impression' column to: {iu_xray_reports_preprocessed_df_path}")
else:
    print(f"Preprocessed Text of DataFrame {iu_xray_reports_df_path} already exists at: {iu_xray_reports_preprocessed_df_path}")
    

# displaying the preprocessed dataframe
iu_xray_reports_preprocessed_df = pd.read_csv(iu_xray_reports_preprocessed_df_path)
display(iu_xray_reports_preprocessed_df.head())

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
Preprocessed Text of DataFrame ./datasets/iu_xray/output/iu_xray_reports_df.csv already exists at: ./datasets/iu_xray/output/iu_xray_reports_preprocessed_df.csv


Unnamed: 0,pmc_id,findings,impression,comparison,indication,image_count,image_1,image_2,image_3,image_4,image_5
0,3967,cardiomediastinal silhouette within normal lim...,acute radiographic cardiopulmonary process,none,chest pain,2,CXR3967_IM-2028-1001.jpg: PA and lateral chest...,CXR3967_IM-2028-2001.jpg: PA and lateral chest...,,,
1,3332,heart normal size contour mediastinal widening...,acute cardiopulmonary abnormalities,radiograph chest lateral xxxx xxxx,weakness,3,CXR3332_IM-1596-1001.jpg: Radiograph Chest PA ...,CXR3332_IM-1596-2001.jpg: Radiograph Chest PA ...,CXR3332_IM-1596-3001.jpg: Radiograph Chest PA ...,,
2,30,lungs clear without focal consolidation effusi...,negative acute cardiopulmonary abnormality,none,xxxx year old male chest pain,2,CXR30_IM-1385-1001.jpg: Chest x-XXXX XXXX and ...,CXR30_IM-1385-2001.jpg: Chest x-XXXX XXXX and ...,,,
3,2593,mild cardiomegaly unchanged stable superior me...,mild cardiomegaly interstitial prominence coul...,none,back pain,2,"CXR2593_IM-1084-1001.jpg: Chest, 2 views, XXXX...","CXR2593_IM-1084-2001.jpg: Chest, 2 views, XXXX...",,,
4,1165,frontal lateral views chest show normal size c...,acute active cardiac pulmonary pleural disease,none,chest pain shortness breath patient lower abdo...,1,CXR1165_IM-0110-1001.jpg: Xray Chest PA and La...,,,,


### **Create Data Loaders**

In [13]:
'''Image Data Loaders to Supply Dataset to Model in Batches'''

# classes in dataset
class CustomImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image


# function to load image data with transformation and batching
def load_preprocessed_images(image_dir, batch_size=32):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = CustomImageDataset(image_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataloader

In [14]:
'''Text Data Loaders to Supply Dataset to Model in Batches'''

# classes in dataset
class CustomTextDataset(Dataset):
    def __init__(self, text_list, tokenizer, max_length=512):
        self.text_list = text_list
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.text_list)

    def __getitem__(self, idx):
        text = self.text_list[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {key: tensor.squeeze(0) for key, tensor in encoding.items()}


# function to load text data with batching
def load_preprocessed_texts(text_list, tokenizer, batch_size=32, max_length=512):
    dataset = CustomTextDataset(text_list, tokenizer, max_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataloader

## **Model Implementation**

### **Visual Extractor**

In [18]:
'''Visual Extractor to Extract Data from Image and Encode it Accordingly'''

# defining device for gpu support
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# defining the visual extractor model using ResNet101 
class VisualExtractor(nn.Module):
    def __init__(self, args):
        super(VisualExtractor, self).__init__()
        self.visual_extractor = args.visual_extractor
        weights = models.ResNet101_Weights.DEFAULT if args.visual_extractor_pretrained else None  
        model = getattr(models, self.visual_extractor)(weights=weights)
        modules = list(model.children())[:-2]  
        self.model = nn.Sequential(*modules)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc_layer = nn.Linear(model.fc.in_features, 2048) 
        
    def forward(self, images):
        patch_feats = self.model(images)
        avg_feats = self.avg_pool(patch_feats).squeeze() 
        avg_feats = self.fc_layer(avg_feats)

        batch_size, feat_size, _, _ = patch_feats.shape
        patch_feats = patch_feats.view(batch_size, feat_size, -1).permute(0, 2, 1)
        
        final_embedding = torch.cat((avg_feats.unsqueeze(1), patch_feats), dim=1) 
        
        return patch_feats, avg_feats, final_embedding


# arguments for the visual extractor
class Args:
    visual_extractor = 'resnet101' 
    visual_extractor_pretrained = True


# initializing the model
args = Args()
visual_extractor = VisualExtractor(args).to(device)


# function to extract features from images
def extract_features(images_dataloader):
    all_patch_feats = []
    all_avg_feats = []
    all_final_embeddings = []

    visual_extractor.eval() 
    with torch.no_grad():
        for images in tqdm(images_dataloader):
            images = images.to(device)
            patch_feats, avg_feats, final_embedding = visual_extractor(images)
            all_patch_feats.append(patch_feats.cpu()) 
            all_avg_feats.append(avg_feats.cpu())
            all_final_embeddings.append(final_embedding.cpu())

    all_patch_feats = torch.cat(all_patch_feats, dim=0)
    all_avg_feats = torch.cat(all_avg_feats, dim=0)
    all_final_embeddings = torch.cat(all_final_embeddings, dim=0)

    return all_patch_feats, all_avg_feats, all_final_embeddings


# function to save extracted features
def save_features(file_path, features):
    print(f"Saving features to {file_path}")
    torch.save(features, file_path)


# function to lead the extracted features
def load_features(file_path):
    if os.path.exists(file_path):
        print(f"Loading features from {file_path}")
        return torch.load(file_path, weights_only=True)
    return None


# initializing paths
feature_dir = output_directory
patch_feats_file = os.path.join(output_directory, 'patch_feats.pt')
avg_feats_file = os.path.join(output_directory, 'avg_feats.pt')
image_embeddings_file = os.path.join(output_directory, 'image_embeddings.pt')
images_dataloader = load_preprocessed_images(iu_xray_images_preprocessed)


# extracting and saving the extracted features
if os.path.exists(patch_feats_file) and os.path.exists(avg_feats_file) and os.path.exists(image_embeddings_file):
    print("All features are already precomputed and will be loaded.")
    patch_feats = load_features(patch_feats_file)
    avg_feats = load_features(avg_feats_file)
    image_embeddings = load_features(image_embeddings_file)
else:
    print("Extracting features since they are not precomputed...")
    patch_feats, avg_feats, image_embeddings = extract_features(images_dataloader) 
        
    os.makedirs(feature_dir, exist_ok=True)
    save_features(patch_feats_file, patch_feats)
    save_features(avg_feats_file, avg_feats)
    save_features(image_embeddings_file, image_embeddings)


# displaying sizes of the feature dataframes
print("Patch Features Shape:", patch_feats.shape)
print("Average Features Shape:", avg_feats.shape)
print("Final Embedding Shape:", image_embeddings.shape)

Extracting features since they are not precomputed...


  4%|█▋                                         | 9/234 [00:21<09:09,  2.44s/it]


KeyboardInterrupt: 

In [14]:
'''Visualizing Extracted Features using Plots'''

# function to visualize features using PCA and t-SNE, including K-Means clustering
def visualize_features(features, title):
    print(f"Original feature shape: {features.shape}")
    
    features = features.reshape(features.shape[0], -1) 
    
    pca = PCA(n_components=2)
    pca_result = pca.fit_transform(features)

    tsne = TSNE(n_components=2)
    tsne_result = tsne.fit_transform(features)

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.scatter(pca_result[:, 0], pca_result[:, 1], alpha=0.5)
    plt.title(f"PCA - {title}")
    plt.xlabel("PCA Component 1")
    plt.ylabel("PCA Component 2")

    plt.subplot(1, 2, 2)
    plt.scatter(tsne_result[:, 0], tsne_result[:, 1], alpha=0.5)
    plt.title(f"t-SNE - {title}")
    plt.xlabel("t-SNE Component 1")
    plt.ylabel("t-SNE Component 2")

    plt.show()

    num_clusters = 3
    kmeans = KMeans(n_clusters=num_clusters)
    labels = kmeans.fit_predict(features)
    
    plt.figure(figsize=(8, 8))
    plt.scatter(tsne_result[:, 0], tsne_result[:, 1], c=labels, cmap='viridis', alpha=0.5)
    plt.title('t-SNE Result with K-Means Clusters')
    plt.colorbar()
    plt.show()
    plt.figure(figsize=(8, 8))
    plt.scatter(pca_result[:, 0], pca_result[:, 1], c=labels, cmap='viridis', alpha=0.5)
    plt.title('PCA Result with K-Means Clusters')
    plt.colorbar()
    plt.show()


# function to visualize features using PCA and t-SNE, without clustering
def visualize_features_2(features, title):
    print(f"Original feature shape: {features.shape}")
    
    flattened_features = features.reshape(features.shape[0], -1) 
    
    pca = PCA(n_components=2)
    pca_result = pca.fit_transform(flattened_features)

    tsne = TSNE(n_components=2)
    tsne_result = tsne.fit_transform(flattened_features)

    plt.scatter(pca_result[:, 0], pca_result[:, 1], alpha=0.7)
    plt.title(f'PCA: {title}')
    plt.show()
    
    plt.scatter(tsne_result[:, 0], tsne_result[:, 1], alpha=0.7)
    plt.title(f't-SNE: {title}')
    plt.show()


# visualizing average features
# visualize_features(avg_feats.numpy(), "Average Features")
# visualize_features(patch_feats.numpy(), "Patch Features")
# visualize_features(image_embeddings.numpy(), "Image Embeddings")

# visualize_features_2(avg_feats.numpy(), "Average Features")
# visualize_features_2(patch_feats.numpy(), "Patch Features")
# visualize_features_2(image_embeddings.numpy(), "Image Embeddings")

### **Text Encoder**

In [15]:
'''Text Encoder'''

# function to embed text
def embed_text(text_dataloader, model):
    all_embeddings = []
    
    try:
        for batch in text_dataloader:
            if isinstance(batch, str):
                inputs = tokenizer([batch], return_tensors="pt", padding=True, truncation=True).to(device)
            elif isinstance(batch, list):
                inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(device)
            else:
                raise ValueError("Batch must be of type str or List[str]")

            outputs = model(**inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1)
            all_embeddings.append(embeddings)

        if all_embeddings: 
            return torch.cat(all_embeddings, dim=0)
        else:
            if isinstance(model, SentenceTransformer):
                return torch.empty(0, model.get_sentence_embedding_dimension()).to(device)
            else:
                return torch.empty(0, model.config.hidden_size).to(device)

    except Exception as e:
        print(f"Error in embedding: {e}")
        if isinstance(model, SentenceTransformer):
            return torch.empty(0, model.get_sentence_embedding_dimension()).to(device)
        else:
            return torch.empty(0, model.config.hidden_size).to(device)


# function to compute cosine similarity
def compute_cosine_similarity(embeddings1, embeddings2):
    embeddings1 = F.normalize(embeddings1, p=2, dim=1)
    embeddings2 = F.normalize(embeddings2, p=2, dim=1)
    cosine = torch.mm(embeddings1, embeddings2.t())
    return cosine


# defining device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# loading tokenizer, bert model and sentence-bert model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)
sentence_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]



1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [18]:
'''Text Encoder - Medical Knowledge encoder using TextEncoder -> note shape'''

# defining radiology dictionary
radiology_dictionary = {
    "pleural": ['hemithorax', 'effusion', 'pneumothorax', 'parenchymal'],
    "lung": ['lungs', 'pulmonary', 'hilar', 'lobe', 'consolidation', 'atelectasis', 'edema', 'opacity', 'pneumonia'],
    "mediastinal": ['mediastinum', 'diaphragm', 'hemidiaphragm'],
    "cardiac": ['heart', 'cardiomegaly', 'cardiomediastinal', 'atrium', 'ventricle', 'retrocardiac'],
    "vascular": ['aorta', 'venous', 'jugular', 'aortic', 'vasculature', 'cabg'],
    "osseous": ['rib', 'sternal', 'subclavian', 'thoracic'],
    "trachea": ['endotrachea'],
    "stomach": [],
    "abdomen": [],
    "tube": ['clips'],
    "spine": ['vertebral', 'degenerative'],
    "nodule": ['mass'],
    "chest": ['small', 'enlarged', 'unchanged', 'stable', 'silhouette', 'contours', 'size', 'focal', 'mild', 'acute']
}


# Ensure the output directory exists
os.makedirs(output_directory, exist_ok=True)


# define the filename and full path
filename = 'radiology_terms.csv'
dictionary_csv = os.path.join(output_directory, filename) 


# Open the CSV file in write mode using the full path
if not os.path.exists(dictionary_csv):    
    with open(dictionary_csv, 'w', newline='') as file:  
        writer = csv.writer(file)  
        writer.writerow(['Category', 'Term']) 
    
        for category, terms in radiology_dictionary.items():
            for term in terms:
                if term: 
                    writer.writerow([category, term]) 
                    
    print(f"Dictionary saved to {dictionary_csv}")  
else:     
    print(f"File already exists at {dictionary_csv}. No changes made.")

text_list = []
with open(dictionary_csv, 'r') as file:
    reader = csv.reader(file)
    next(reader) 
    for row in reader:
        text_list.append(row[1]) 

        
# load the SentenceTransformer model
model = SentenceTransformer('all-MiniLM-L6-v2')

# Generate embeddings
embeddings = model.encode(text_list)

# Display the shape of the embeddings
print(f"Embeddings Shape: {embeddings.shape}")

dictionary_pt = os.path.join(output_directory, 'dictionary_embeddings.pt')
if not os.path.exists(dictionary_pt):    
    torch.save(embeddings, dictionary_pt)
    print(f"Saved at : {dictionary_pt}")
else : print(f"Already saved at path : {dictionary_pt}")

    
# # embedding dictionary and reports using bert, and historical medical reports using sentence transformer
# dictionary_dataloader = load_preprocessed_texts(radiology_dictionary, tokenizer)
# dictionary_embeddings = embed_text(dictionary_dataloader, bert_model)


# # saving embeddings
# embeddings_file_path = os.path.join(output_directory, 'dictionary_embeddings.pt')
# if not os.path.exists(embeddings_file_path):
#     torch.save(dictionary_embeddings.cpu(), embeddings_file_path)
#     print(f"Dictionary embeddings saved to {embeddings_file_path}")


# # displaying shape of embeddings
# print(f"Dictionary Embeddings Shape: {dictionary_embeddings.shape}")

File already exists at ./datasets/iu_xray/output/radiology_terms.csv. No changes made.


StopIteration: 

In [86]:
'''Text Encoder - Medical History encoding using sentence encoder'''

# reading and preprocessing medical history reports
iu_xray_reports_preprocessed_df_path = os.path.join(output_directory, 'iu_xray_reports_preprocessed_df.csv')
medical_history = pd.read_csv(iu_xray_reports_preprocessed_df_path)["findings"].dropna().tolist()
medical_history = [str(report) for report in medical_history if isinstance(report, str) or pd.notna(report)]

batch_size = 32  # Set your desired batch size
tokenizer = AutoTokenizer.from_pretrained('all-MiniLM-L6-v2')  # Ensure you have your tokenizer defined

# Load the data into a DataLoader
medical_reports_dataloader = load_preprocessed_texts(medical_history, tokenizer, batch_size)

# Step 3: Print the shapes of the batches
for batch in medical_reports_dataloader:
    # Printing shapes of input tensors
    print({key: tensor.shape for key, tensor in batch.items()})
    break  # Remove this break to see all batches

OSError: all-MiniLM-L6-v2 is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`

In [120]:
'''Text Encoder -  Sentence-Bert Encoder using Medical History'''

# reading and preprocessing medical history reports
iu_xray_reports_preprocessed_df_path = os.path.join(output_directory, 'iu_xray_reports_preprocessed_df.csv')
medical_history = pd.read_csv(iu_xray_reports_preprocessed_df_path)["findings"].dropna().tolist()
medical_history = [str(report) for report in medical_history if isinstance(report, str) or pd.notna(report)]

invalid_entries = [report for report in medical_history if not isinstance(report, (str, list))]

# Print if there are any invalid entries
if invalid_entries:
    print(f"Found {len(invalid_entries)} invalid entries: {invalid_entries}")
else:
    print("No invalid entries found.")

# encoding medical history using Sentence-BERT
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
historical_dataloader_b = load_preprocessed_texts(medical_history, tokenizer)
historical_embeddings_b = embed_text(historical_dataloader, bert_model)

# # Check for invalid entries after batching
# invalid_entries_after_batches = [
#     report for batch in historical_dataloader_b for report in batch.values() if not isinstance(report, (str, list))
# ]

# for batch in historical_dataloader:
#     print(type(batch))  # Check the type
#     print(batch)  # Print the batch content
#     break



# saving embeddings
historical_pt_b = os.path.join(output_directory, 'historical_embeddings_b.pt')
if not os.path.exists(historical_pt_b):
    torch.save(historical_embeddings_b.cpu(), historical_pt_b)
    print(f"Historical embeddings saved to {historical_pt_b}")
else : print(f"Already at {historical_pt_b}")


# displaying shape of embeddings
print(f"Historical Embeddings Shape: {historical_embeddings_b.shape}")

No invalid entries found.

Error in embedding: Batch must be of type str or List[str]

Already at ./datasets/iu_xray/output/historical_embeddings_b.pt

Historical Embeddings Shape: torch.Size([0, 768])


In [25]:
'''Text Encoder - Finding Reports Similar to Current Report'''

# current report embedding
current_report = 'Heart size pulmonary vascularity appear within normal limits mild tortuosity descending thoracic aorta lungs free focal airspace disease pleural effusion pneumothorax seen discrete nodules adenopathy noted degenerative changes present spine'
current_embeddings = embed_text(current_report, sentence_model)


# computing cosine similarity
historical_embeddings_normalized = historical_embeddings / historical_embeddings.norm(dim=1, keepdim=True)
current_embeddings_normalized = current_embeddings / current_embeddings.norm(dim=1, keepdim=True)

similarity_matrix = compute_cosine_similarity(dictionary_embeddings, historical_embeddings)
print(f"Similarity Matrix Shape: {similarity_matrix.shape}")


# finding top-k relevant entries
k = 5
top_k_indices = similarity_matrix.topk(k=k, dim=1).indices


# preparing relevant entries based on indices
relevant_entries = []
for row in top_k_indices:
    relevant_entries.append(medical_history[row.item()])


# printing relevant entries for each report
print(f"Relevant Entries for the Current Report: {relevant_entries}")


# update the historical embeddingx
historical_embeddings.append(current_embeddings)

Error in embedding: SentenceTransformer.forward() missing 1 required positional argument: 'input'


RuntimeError: mat1 and mat2 shapes cannot be multiplied (0x768 and 384x0)

In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel
import pickle

# Check if GPU is available and set the device accordingly
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Function to convert findings to list of lists
def get_findings_list(df):
    findings_list = df['findings'].fillna("").tolist()  # Replace NaN with empty string
    return [[finding] for finding in findings_list]  # Convert each finding into a list of strings

# Dataset class to load the findings from the reports
class FindingsDataset(Dataset):
    def __init__(self, findings_list):
        self.findings_list = findings_list

    def __len__(self):
        return len(self.findings_list)

    def __getitem__(self, idx):
        return self.findings_list[idx]

# Define the Sentence Encoder (trainable)
class SentenceEncoder(nn.Module):
    def __init__(self, hidden_size=768):
        super(SentenceEncoder, self).__init__()
        self.encoder = nn.Linear(hidden_size, hidden_size)
    
    def forward(self, x):
        return self.encoder(x)

# Cosine similarity loss based on the image equation
def cosine_similarity_loss(H, H_b):
    cos_sim_H = F.cosine_similarity(H, H.unsqueeze(1))
    cos_sim_H_b = F.cosine_similarity(H_b, H_b.unsqueeze(1))
    loss = torch.mean((cos_sim_H_b - cos_sim_H) ** 2)
    return loss

# Load pre-trained BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Move the model to the appropriate device
bert_model.to(device)

# Load your dataframe
# iu_xray_reports_preprocessed_df_path = '/kaggle/input/preprocessed-text/iu_xray_reports_df_preprocessed.csv'
iu_xray_reports_preprocessed_df_path = os.path.join(output_directory, 'iu_xray_reports_preprocessed_df.csv')
medical_history = pd.read_csv(iu_xray_reports_preprocessed_df_path)
findings_list = get_findings_list(medical_history)

# Custom collate function to tokenize the findings in batches
def collate_fn(batch):
    batch = [item[0] for item in batch]  # Flatten the batch list
    return tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device)  # Move to GPU

# Create dataset and dataloader
findings_dataset = FindingsDataset(findings_list)
dataloader = DataLoader(findings_dataset, batch_size=16, collate_fn=collate_fn)

# Initialize Sentence Encoder and move to device
sentence_encoder = SentenceEncoder().to(device)

# Historical Knowledge Storage
#historical_knowledge_file = 'historical_knowledge.pkl'  # File to save the historical encodings
historical_knowledge = []  # List to store historical encodings

historical_knowledge_file = os.path.join(output_directory, 'historical_knowledge.pkl')
# if not os.path.exists(historical_pt_b):
#     torch.save(historical_embeddings_b.cpu(), historical_pt_b)
#     print(f"Historical embeddings saved to {historical_pt_b}")

# Load historical knowledge if it exists
if os.path.exists(historical_knowledge_file):
    with open(historical_knowledge_file, 'rb') as f:
        historical_knowledge = pickle.load(f)

# Optimizer
optimizer = torch.optim.Adam(sentence_encoder.parameters(), lr=1e-4)
num_epochs = 10  # Set the number of epochs

# Training loop
for epoch in range(num_epochs):
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)  # Move input_ids to GPU
        attention_mask = batch['attention_mask'].to(device)  # Move attention_mask to GPU

        # Get historical encodings from BERT
        with torch.no_grad():
            bert_output = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            H_b = bert_output.last_hidden_state.mean(dim=1).to(device)  # Average pooling of BERT embeddings and move to GPU

        # Get encodings from the trainable Sentence Encoder
        H = sentence_encoder(H_b)  # Feeding the pre-trained embeddings to the trainable encoder

        # final_H = H.detach().cpu()
        # Store the current encodings in historical knowledge
        historical_knowledge.append(H.detach().cpu())  # Detach tensor and move to CPU

        # Calculate cosine similarity loss
        loss = cosine_similarity_loss(H, H_b)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

# final_H_file = 'final_H.pkl'
# with open(final_H_file, 'wb') as f:
#     pickle.dump(final_H.numpy(), f)


# Now you can utilize historical_knowledge for further inference or report generation.
with open(historical_knowledge_file, 'wb') as f:
    # Convert tensors to NumPy arrays before saving
    historical_knowledge_np = [h.numpy() for h in historical_knowledge]
    pickle.dump(historical_knowledge_np, f)


Epoch 1, Loss: 0.2381143867969513


### **Multilevel Alignment**

In [21]:
'''Multilevel Alignment based on BLIP Architecture'''

# function to extract image embeddings of a given file name
def extract_image_embeddings(image_name):
    image_embeddings = torch.load(output_directory + "image_embeddings.pt")
    if image_name in image_embeddings:
        return image_embeddings[image_name]
    else:
        return None
    

# function to extract text embeddings of a given text
def extract_historical_text_embeddings(text):
    historical_text_embeddings = torch.load(output_directory + "historical_embeddings.pt")
    if text in historical_text_embeddings:
        return historical_text_embeddings[text]
    else :
        return None

def extract_dictionary_text_embeddings(text):
    dictionary_text_embeddings = torch.load(output_directory + "dictionary_embeddings.pt")
    if text in dictionary_text_embeddings:
        return dictionary_text_embeddings[text]
    else :
        return None
    

# function to find cosine similarity between a text and an image 
def cosine_similarity(embedding1, embedding2):
    embedding1 = np.array(embedding1)
    embedding2 = np.array(embedding2)

    dot_product = np.dot(embedding1, embedding2)
    norm_embedding1 = np.linalg.norm(embedding1)
    norm_embedding2 = np.linalg.norm(embedding2)
    
    if norm_embedding1 == 0 or norm_embedding2 == 0:
        return None 
    
    return dot_product / (norm_embedding1 * norm_embedding2)
    

# function to compute Image-Text-Contrastive Loss
def batch_itc_loss(image_embeddings, text_embeddings, temperature=0.1):
    image_embeddings = F.normalize(image_embeddings, dim=1)
    text_embeddings = F.normalize(text_embeddings, dim=1)

    similarity_matrix = torch.matmul(image_embeddings, text_embeddings.T) / temperature

    labels = torch.arange(len(image_embeddings), device=image_embeddings.device)

    loss = F.cross_entropy(similarity_matrix, labels)
    return labels, loss


# function to compute Image-Text-Matching loss
def batch_itm_loss(image_embeddings, text_embeddings, match_labels):
    logits = torch.matmul(image_embeddings, text_embeddings.t())
    probabilities = torch.sigmoid(logits) 
    positive_probs = probabilities[torch.arange(len(match_labels)), match_labels]

    loss = -torch.log(positive_probs + 1e-12).mean()
    return loss


# function to get embeddings()
def get_embeddings(file_path):
    df = pd.read_csv(file_path)
    report_embeddings = []
    image_embeddings = []

    for _, row in df.iterrows():
        report_embedding = extract_historical_text_embeddings(row['findings'])

        for i in range(1, 6):
            image_col = f'image_{i}'
            if image_col in row and pd.notnull(row[image_col]):
                image_embedding = extract_image_embeddings(row[image_col])
                report_embeddings.append(report_embedding)
                image_embeddings.append(image_embedding)

    return report_embeddings, image_embeddings


# function to do training step
def train_step(file, optimizer):
    blip_model.train()
    text_embeddings, image_embeddings = get_embeddings(file_path)
    
    match_labels, itc_loss_value = batch_itc_loss(image_embeddings, text_embeddings)
    itm_loss_value = batch_itm_loss(image_embeddings, text_embeddings, match_labels)
    
    total_loss = itc_loss_value + itm_loss_value
    
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    return total_loss.item()

In [22]:
'''Multilevel Alignment based on BLIP Architecture - Coarse Grained Alignment (Image Text Contrastive and Image Text Matching)'''

# BLIP architecture
processor = BlipProcessor.from_pretrained("Salesforce/blip-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-base")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
blip_model.to(device)
blip_model.eval()


# training loop
num_epochs = 10  
optimizer = torch.optim.Adam(blip_model.parameters(), lr=1e-5)


# using a data loader defined that provides (image, text) pairs
for epoch in range(num_epochs):
    loss = train_step(file_path, optimizer)
    print(f"Epoch {epoch}, Loss: {loss}")

OSError: Salesforce/blip-base is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`

In [None]:
'''Multilevel Alignment based on BLIP Architecture - Fine Grained Alignment'''

In [None]:
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist


output_aligned_features_dir = output_directory
if(os.path.exists(output_aligned_features_dir + "aligned_outputs.pt")):
    print(f"Already at {historical_pt_b}")

batch_size = 32 

final_embeddings = torch.load(output_directory + "/final_embeddings.pt")  # Shape: [7470, 50, 2048]
#patch_features = torch.load(output_directory + 'patch_feats.pt')          # Shape: [7470, 49, 2048]
dictionary_embeddings = torch.load(output_directory + "/dictionary_embeddings.pt")  # Shape: [47, 384]
dictionary_embeddings = torch.tensor(dictionary_embeddings)

# Project dictionary embeddings (V) to match the dimensionality of the image embeddings (2048)
projection_layer = nn.Linear(384, 2048)
V_projected = projection_layer(dictionary_embeddings)  # Shape: [47, 2048]
# V_projected = F.normalize(V_projected, p=2, dim=1)

# Normalize final embeddings (this is already being done)
# I_prime = F.normalize(final_embeddings, p=2, dim=1)

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # weight matrices for query, key, value
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V):
        # Compute QK^T / sqrt(d_k)
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k).float())
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights

    def forward(self, V, I_prime):
        # V:  projected dictionary embeddings
        # I_prime: final_embeddings
        
        Q = self.W_q(V)  # Dictionary embeddings
        K = self.W_k(I_prime)  # Image embeddings 
        V = self.W_v(I_prime)

        Q = Q.view(Q.size(0), Q.size(1), self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(K.size(0), K.size(1), self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(V.size(0), V.size(1), self.num_heads, self.d_k).transpose(1, 2)
       
        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V) # apply scaled dot-product attention

        attn_output = attn_output.transpose(1, 2).contiguous().view(V.size(0), -1, self.d_model) # concatenate heads and apply final linear transformation
        output = self.W_o(attn_output)
        
        return output, attn_weights

# Feed-forward network
class FeedForwardNetwork(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForwardNetwork, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

# KL Divergence Loss (for training)
# def kl_divergence(mu1, logvar1, mu2, logvar2):
#     kl_loss = 0.5 * torch.sum(logvar2 - logvar1 + (torch.exp(logvar1) + (mu1 - mu2)**2) / torch.exp(logvar2) - 1)
#     return kl_loss
def kl_divergence(mu1, logvar1, mu2, logvar2):
    # kl_loss = 0.5 * torch.sum(
    #     logvar2 - logvar1 +
    #     (torch.exp(logvar1) + (mu1 - mu2) ** 2) / torch.exp(logvar2) - 1
    # )
    # def compute_kl_loss(mu1, logvar1, mu2, logvar2):
    # Create normal distributions from the means (mu) and log-variances (logvar)
    normal1 = dist.Normal(mu1, torch.exp(0.5 * logvar1))  # exp(0.5 * logvar) gives standard deviation
    normal2 = dist.Normal(mu2, torch.exp(0.5 * logvar2))
    
    # Compute the KL divergence between the two distributions
    kl_loss = dist.kl.kl_divergence(normal1, normal2).mean()  # Mean over the batch
    return kl_loss

    # return kl_loss
    
# Model setup
d_model = 2048
num_heads = 8
d_ff = 4096  

mha = MultiHeadAttention(d_model, num_heads)
ffn = FeedForwardNetwork(d_model, d_ff)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mha = mha.to(device)
ffn = ffn.to(device)
V_projected = V_projected.to(device)
final_embeddings = final_embeddings.to(device)
#patch_features = patch_features.to(device)

dataset_size = final_embeddings.size(0)  # 7470
aligned_outputs = []

params = list(mha.parameters()) + list(ffn.parameters())  
optimizer = torch.optim.Adam(params, lr=1e-4)  

for i in range(0, dataset_size, batch_size):
    
    batch_final_embeddings = final_embeddings[i:i+batch_size]  # Shape: [batch_size, 50, 2048]
    #batch_patch_features = patch_features[i:i+batch_size]  # Shape: [batch_size, 49, 2048]

    #batch_I_prime = torch.cat((batch_final_embeddings, batch_patch_features), dim=1)  # Shape: [batch_size, 99, 2048]
    batch_I_prime = batch_final_embeddings
    batch_V_repeated = V_projected.unsqueeze(0).repeat(batch_final_embeddings.size(0), 1, 1)  # Shape: [batch_size, 47, 2048]

    batch_I_prime = batch_I_prime.to(device)
    batch_V_repeated = batch_V_repeated.to(device)

    aligned_output, _ = mha(batch_V_repeated, batch_I_prime) #multi-head attention with dictionary embeddings and image embeddings

    aligned_output_ffn = ffn(aligned_output) #feed-forward network
    
    # print("mu1 (alignment result):", mu1.min().item(), mu1.max().item())
    # print("logvar1 (alignment result):", logvar1.min().item(), logvar1.max().item())
    # print("mu2 (text encoding):", mu2.min().item(), mu2.max().item())
    # print("logvar2 (text encoding):", logvar2.min().item(), logvar2.max().item())

    
    mu1, logvar1 = torch.randn_like(aligned_output_ffn,requires_grad=True), torch.randn_like(aligned_output_ffn,requires_grad=True)  # Priors for V'
    mu2, logvar2 = torch.randn_like(batch_V_repeated,requires_grad=True), torch.randn_like(batch_V_repeated,requires_grad=True)  # Priors for V_label
    kl_loss = kl_divergence(mu1, logvar1, mu2, logvar2) # KL divergence

    print(f"KL Loss for batch {i // batch_size}: {kl_loss.item()}")

    # If training, accumulate gradients and perform optimization steps
    optimizer.zero_grad()
    kl_loss.backward()
    optimizer.step()
    
    aligned_outputs.append(aligned_output_ffn.cpu()) 

#single tensor
aligned_outputs_tensor = torch.cat(aligned_outputs, dim=0)  # Shape: [7470, 50, 2048]

torch.save(aligned_outputs_tensor, os.path.join(output_aligned_features_dir, "aligned_outputs.pt"))

print("Aligned features saved successfully.")

    


KL Loss for batch 0: 2.50960111618042
KL Loss for batch 1: 2.5075905323028564
KL Loss for batch 2: 2.5028076171875
KL Loss for batch 3: 2.509035348892212
KL Loss for batch 4: 2.510709762573242
KL Loss for batch 5: 2.5108795166015625
KL Loss for batch 6: 2.510974645614624
KL Loss for batch 7: 2.5042688846588135
KL Loss for batch 8: 2.5072593688964844
KL Loss for batch 9: 2.500169277191162
KL Loss for batch 10: 2.5004351139068604


### **Report Generator**

### **Complete Model**

## **Training**

### **Training**

## **Testing**

### **Testing**