## **Medical Report Summarisation using Medical Knowledge**

### **References**

**Main Reference**
- Radiology report generation with medical knowledge and multilevel image-report alignment: A new method and its verification
https://www.sciencedirect.com/science/article/pii/S0933365723002282



## **Data Collection**

### **Collect Datasets**

In [14]:
'''Libraries Installation and Import'''

# installling necessary libraries
!pip -q install --user requests numpy pandas matplotlib tqdm Pillow opencv-python nltk pyspellchecker torch torchvision torchaudio transformers scikit-learn sentence-transformers

# importing required libraries
import os
import re
import csv
import requests
import tarfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from tqdm import tqdm
from PIL import Image
import cv2

import nltk
from nltk.corpus import stopwords
from spellchecker import SpellChecker
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as dist
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision import models, transforms

import transformers
from transformers import BertTokenizer, BertModel
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AutoTokenizer
from transformers import DistilBertModel, DistilBertTokenizer
from sentence_transformers import SentenceTransformer

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split

In [2]:
'''Setting Paths'''

# project directory
# from google.colab import drive
# drive.mount('/content/drive')
# project_directory = '/content/drive/Othercomputers/My Laptop/CS550_ASMT_MRSMK/datasets'
# project_directory = '/content/drive/MyDrive/Academics/CS550 Machine Learning/CS550 ASMT MRSMK/datasets'

project_directory = "./datasets"
dataset = 'iu_xray/'
iu_xray_dataset = os.path.join(project_directory, dataset)


# input directory
input_directory = os.path.join(iu_xray_dataset, "input")

images_dir = os.path.join(input_directory, "images")
reports_dir = os.path.join(input_directory, "reports")
iu_xray_images = images_dir
iu_xray_reports = os.path.join(reports_dir, 'ecgen-radiology')


# output directory 
output_directory = os.path.join(iu_xray_dataset, "output")
os.makedirs(output_directory, exist_ok=True)

In [13]:
'''Setup - Generalized'''

# setup to download the IU X-Ray Dataset
images_url = "https://openi.nlm.nih.gov/imgs/collections/NLMCXR_png.tgz"
reports_url = "https://openi.nlm.nih.gov/imgs/collections/NLMCXR_reports.tgz"


# function to check the file size of a given URL
def get_file_size(url):
    response = requests.head(url)
    size_in_bytes = int(response.headers.get('Content-Length', 0))
    size_in_mb = size_in_bytes / (1024 * 1024)
    return size_in_mb


# function to download and extract from a given url to a given directory
def download_and_extract(url, save_dir):
    file_name = url.split('/')[-1]
    file_path = os.path.join(save_dir, file_name)

    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('Content-Length', 0))
    downloaded_size = 0

    with open(file_path, 'wb') as file:
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                file.write(chunk)
                downloaded_size += len(chunk)
                percent_complete = (downloaded_size / total_size) * 100
                print(f"Downloaded {downloaded_size / (1024*1024):.2f} MB out of {total_size / (1024*1024):.2f} MB: {percent_complete:.2f}% complete", end="\r")

    print("\nDownload complete!")

    with tarfile.open(file_path, 'r:gz') as tar:
        members = tar.getmembers()
        total_files = len(members)

        for idx, member in enumerate(members, start=1):
            tar.extract(member, path=save_dir)
            print(f"Extracting File {idx} out of {total_files}: {member.name}", end="\r")

    os.remove(file_path)


# downloading  IU X-Ray dataset
if not os.path.exists(images_dir):
    images_size = get_file_size(images_url)
    print(f"Downloading {images_url} to: {images_dir} ({images_size:.2f} MB)")
    os.makedirs(images_dir, exist_ok=True)
    download_and_extract(images_url, images_dir)
    print(f"Downloaded {images_url} to: {images_dir}")
else:
    print(f"{images_url} already exists at: {images_dir}")

if not os.path.exists(reports_dir):
    reports_size = get_file_size(reports_url)
    print(f"Downloading {reports_url} to: {reports_dir} ({reports_size:.2f} MB)")
    os.makedirs(reports_dir, exist_ok=True)
    download_and_extract(reports_url, reports_dir)
    print(f"Downloaded {reports_url} to: {reports_dir}")
else:
    print(f"{reports_url} already exists at: {reports_dir}")

https://openi.nlm.nih.gov/imgs/collections/NLMCXR_png.tgz already exists at: ./datasets/iu_xray/input/images
https://openi.nlm.nih.gov/imgs/collections/NLMCXR_reports.tgz already exists at: ./datasets/iu_xray/input/reports


In [21]:
'''Exploring the IU X-Ray Dataset Contents'''

# displaying directory and subdirectory contents
print("\nPath: ", iu_xray_images)
print(f"Directory Contents: {len(os.listdir(iu_xray_images))} Images")

print("\nPath: ", iu_xray_reports)
print(f"Directory Contents: {len(os.listdir(iu_xray_reports))} Reports")


Path:  ./datasets/iu_xray/input/images
Directory Contents: 7471 Images

Path:  ./datasets/iu_xray/input/reports/ecgen-radiology
Directory Contents: 3955 Reports


In [22]:
'''Processing Textual Data from each .xml Report File and Storing it in a .csv File'''

# function to iterate through all .xml report files and storing them in a dataframe
def save_images_df():
    data = []
    cnt = 0
    for file in os.listdir(iu_xray_reports):
        if file.endswith(".xml"):
            cnt += 1
            print(f"Processing .xml File {cnt} out of {len(os.listdir(iu_xray_reports))}: {file}", end="\r")

            file_path = os.path.join(iu_xray_reports, file)
            try:
                tree = ET.parse(file_path)
                root = tree.getroot()

                pmc_id = root.find('.//pmcId').attrib.get('id')

                comparison = indication = findings = impression = None

                for abstract in root.findall('.//AbstractText'):
                    if abstract.attrib.get('Label') == 'COMPARISON':
                        comparison = abstract.text
                    elif abstract.attrib.get('Label') == 'INDICATION':
                        indication = abstract.text
                    elif abstract.attrib.get('Label') == 'FINDINGS':
                        findings = abstract.text
                    elif abstract.attrib.get('Label') == 'IMPRESSION':
                        impression = abstract.text

                for parent_image in root.findall('parentImage'):
                    image_file = parent_image.attrib['id'] + ".png"
                    image_path = os.path.join(iu_xray_images, image_file)
                    image = cv2.imread(image_path)

                    if image is not None:
                        height, width, channels = image.shape
                        caption = parent_image.find('caption').text if parent_image.find('caption') is not None else None
                        data.append([pmc_id, image_file, caption, comparison, indication, findings, impression, height, width])
                    else:
                        print(f"Warning: Unable to read image {image_path}")

            except Exception as e:
                print(f"Error processing file {file}: {e}")

    return data


# creating a dataframe and saving it as .csv
iu_xray_images_df_path = os.path.join(output_directory, 'iu_xray_images_df.csv')
if os.path.exists(iu_xray_images_df_path):
    data = save_images_df()
    columns = ['pmc_id', 'image_filename', 'caption', 'comparison', 'indication', 'findings', 'impression', 'height', 'width']
    iu_xray_images_df = pd.DataFrame(data, columns=columns)
    iu_xray_images_df.to_csv(iu_xray_images_df_path, index=False)
    print(f"Dataframe saved to {iu_xray_images_df_path}")
else:
    print(f"Dataframe already exists at {iu_xray_images_df_path}")
    iu_xray_images_df = pd.read_csv(iu_xray_images_df_path)


# displaying the stored dataframe
print("\n\nDataframe Shape:", iu_xray_images_df.shape)

print("\n\nDataframe Information:\n")
display(iu_xray_images_df.info())

print("\n\nDisplaying Dataframe:\n")
display(iu_xray_images_df.head())

Dataframe saved to ./datasets/iu_xray/output/iu_xray_images_df.csv


Dataframe Shape: (7470, 9)


Dataframe Information:

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7470 entries, 0 to 7469
Data columns (total 9 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   pmc_id          7470 non-null   object
 1   image_filename  7470 non-null   object
 2   caption         7468 non-null   object
 3   comparison      6313 non-null   object
 4   indication      7311 non-null   object
 5   findings        6473 non-null   object
 6   impression      7418 non-null   object
 7   height          7470 non-null   int64 
 8   width           7470 non-null   int64 
dtypes: int64(2), object(7)
memory usage: 525.4+ KB


None



Displaying Dataframe:



Unnamed: 0,pmc_id,image_filename,caption,comparison,indication,findings,impression,height,width
0,718,CXR718_IM-2280-1001.png,Xray Chest PA and Lateral,,XXXX for 3 weeks,The lungs are clear. There is no pleural effus...,No acute pulmonary disease.,512,512
1,718,CXR718_IM-2280-3001.png,Xray Chest PA and Lateral,,XXXX for 3 weeks,The lungs are clear. There is no pleural effus...,No acute pulmonary disease.,513,512
2,2793,CXR2793_IM-1226-1001.png,"Radiographs of the chest, 2 views, dated XXXX,...",None.,XXXX-year-old female. Dyspnea.,The cardiomediastinal silhouette is normal in ...,"Mild lung hyperexpansion, otherwise clear.",512,512
3,2793,CXR2793_IM-1226-2001.png,"Radiographs of the chest, 2 views, dated XXXX,...",None.,XXXX-year-old female. Dyspnea.,The cardiomediastinal silhouette is normal in ...,"Mild lung hyperexpansion, otherwise clear.",512,512
4,3773,CXR3773_IM-1891-1001.png,Xray Chest PA and Lateral,None.,"Numbness and tingling in the left arm, nausea ...",The lungs and pleural spaces show no acute abn...,1. No acute pulmonary abnormality.,511,512


In [6]:
'''Processing Textual Data from each .xml Report File and Storing it in a .csv File'''

# function to iterate through all .xml report files and storing them in a dataframe
def save_reports_df():
    data = []
    cnt = 0
    for file in os.listdir(iu_xray_reports):
        if file.endswith(".xml"):
            cnt += 1
            print(f"Processing .xml File {cnt} out of {len(os.listdir(iu_xray_reports))}: {file}", end="\r")

            file_path = os.path.join(iu_xray_reports, file)
            try:
                tree = ET.parse(file_path)
                root = tree.getroot()

                pmc_id = root.find('.//pmcId').attrib.get('id')

                comparison = indication = findings = impression = None

                for abstract in root.findall('.//AbstractText'):
                    if abstract.attrib.get('Label') == 'COMPARISON':
                        comparison = abstract.text
                    elif abstract.attrib.get('Label') == 'INDICATION':
                        indication = abstract.text
                    elif abstract.attrib.get('Label') == 'FINDINGS':
                        findings = abstract.text
                    elif abstract.attrib.get('Label') == 'IMPRESSION':
                        impression = abstract.text

                report_data = {
                    'pmc_id': pmc_id,
                    'findings': findings,
                    'impression': impression,
                    'comparison': comparison,
                    'indication': indication,
                }

                parent_images = root.findall('parentImage')
                report_data['image_count'] = len(parent_images)

                for i, parent_image in enumerate(parent_images, start=1):
                    image_file = parent_image.attrib['id'] + ".jpg"
                    caption = parent_image.find('caption').text if parent_image.find('caption') is not None else None
                    report_data[f'image_{i}'] = f"{image_file}: {caption}" if caption else image_file

                data.append(report_data)

            except Exception as e:
                print(f"Error processing file {file}: {e}")

    return data


# creating a dataframe and saving it as .csv
iu_xray_reports_df_path = os.path.join(output_directory, 'iu_xray_reports_df.csv')
if not os.path.exists(iu_xray_reports_df_path):
    data = save_reports_df()
    iu_xray_reports_df = pd.DataFrame(data)
    iu_xray_reports_df.to_csv(iu_xray_reports_df_path, index=False)
    print(f"Dataframe saved to {iu_xray_reports_df_path}")
else:
    print(f"Dataframe already exists at {iu_xray_reports_df_path}")
    iu_xray_reports_df = pd.read_csv(iu_xray_reports_df_path)


# displaying the stored dataframe
print("\n\nDataframe Shape:", iu_xray_reports_df.shape)

print("\n\nDataframe Information:\n")
display(iu_xray_reports_df.info())

print("\n\nDisplaying Dataframe:\n")
display(iu_xray_reports_df.head())

Dataframe saved to ./datasets/iu_xray/output/iu_xray_reports_df.csv


Dataframe Shape: (3955, 11)


Dataframe Information:

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3955 entries, 0 to 3954
Data columns (total 11 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   pmc_id       3955 non-null   object
 1   findings     3425 non-null   object
 2   impression   3921 non-null   object
 3   comparison   3333 non-null   object
 4   indication   3865 non-null   object
 5   image_count  3955 non-null   int64 
 6   image_1      3851 non-null   object
 7   image_2      3405 non-null   object
 8   image_3      197 non-null    object
 9   image_4      16 non-null     object
 10  image_5      1 non-null      object
dtypes: int64(1), object(10)
memory usage: 340.0+ KB


None



Displaying Dataframe:



Unnamed: 0,pmc_id,findings,impression,comparison,indication,image_count,image_1,image_2,image_3,image_4,image_5
0,718,The lungs are clear. There is no pleural effus...,No acute pulmonary disease.,,XXXX for 3 weeks,2,CXR718_IM-2280-1001.jpg: Xray Chest PA and Lat...,CXR718_IM-2280-3001.jpg: Xray Chest PA and Lat...,,,
1,2793,The cardiomediastinal silhouette is normal in ...,"Mild lung hyperexpansion, otherwise clear.",None.,XXXX-year-old female. Dyspnea.,2,CXR2793_IM-1226-1001.jpg: Radiographs of the c...,CXR2793_IM-1226-2001.jpg: Radiographs of the c...,,,
2,3773,The lungs and pleural spaces show no acute abn...,1. No acute pulmonary abnormality.,None.,"Numbness and tingling in the left arm, nausea ...",2,CXR3773_IM-1891-1001.jpg: Xray Chest PA and La...,CXR3773_IM-1891-2001.jpg: Xray Chest PA and La...,,,
3,3785,No pneumothorax or large pleural effusion. Mil...,No acute cardiopulmonary disease.,None available,XXXX-year-old female with one-XXXX history of ...,2,CXR3785_IM-1898-1001.jpg: Chest XXXX and lateral,CXR3785_IM-1898-2001.jpg: Chest XXXX and lateral,,,
4,7,The cardiac contours are normal. XXXX basilar ...,Basilar atelectasis. No confluent lobar consol...,"XXXX, XXXX",Preop lumbar surgery,2,CXR7_IM-2263-1001.jpg: Xray Chest PA and Lateral,CXR7_IM-2263-2001.jpg: Xray Chest PA and Lateral,,,


In [7]:
'''Displaying the Number of Images per Report'''

# displaying the distribution of number of images per report
reports_count = iu_xray_reports_df['image_count'].value_counts().rename_axis('images_qty').reset_index(name='reports_count')
print("\n\nNumber of Images per Report:\n")
display(reports_count)



Number of Images per Report:



Unnamed: 0,images_qty,reports_count
0,2,3208
1,1,446
2,3,181
3,0,104
4,4,15
5,5,1


In [8]:
'''Checking for Duplicates'''

# Check for duplicate values in the 'pmc_id' column
duplicates_in_pmc_id = iu_xray_reports_df['pmc_id'].duplicated()
num_duplicates = duplicates_in_pmc_id.sum()

# Display the duplicated rows
duplicated_rows = iu_xray_reports_df[duplicates_in_pmc_id]
print(f"Number of duplicates in 'pmc_id' column: {num_duplicates}")
print("Duplicated rows in 'pmc_id' column:")
print(duplicated_rows)

Number of duplicates in 'pmc_id' column: 0
Duplicated rows in 'pmc_id' column:
Empty DataFrame
Columns: [pmc_id, findings, impression, comparison, indication, image_count, image_1, image_2, image_3, image_4, image_5]
Index: []


## **Data Preprocessing**

### **Preprocess Images**

In [23]:
'''Preprocessing Images - Resizing, Tensor Conversion and Normalization'''

# function to preprocess and save images
def preprocess_images(input_dir, output_dir):
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    os.makedirs(output_dir, exist_ok=True)

    cnt = 0
    for filename in os.listdir(input_dir):
        if filename.endswith('.png'):
            cnt += 1
            print(f"Preprocessing File {cnt} out of {len(os.listdir(input_dir))}: {filename}", end="\r")

            image_path = os.path.join(input_dir, filename)
            image = Image.open(image_path).convert('RGB')
            processed_image = preprocess(image)

            processed_image_path = os.path.join(output_dir, filename)

            processed_image_pil = transforms.ToPILImage()(processed_image)
            processed_image_pil.save(processed_image_path)


# preprocessing images
iu_xray_images_preprocessed = os.path.join(output_directory, 'images_preprocessed')
if not os.path.exists(iu_xray_images_preprocessed):
    print(f"Preprocessing Images to: {iu_xray_images_preprocessed}")
    preprocess_images(iu_xray_images, iu_xray_images_preprocessed)
    print(f"Preprocessed Images saved to: {iu_xray_images_preprocessed}")
else:
    print(f"Preprocessed Images already exist at: {iu_xray_images_preprocessed}")

Preprocessing Images to: ./datasets/iu_xray/output/images_preprocessed
Preprocessed Images saved to: ./datasets/iu_xray/output/images_preprocessed


### **Preprocess Text**

In [24]:
'''Preprocessing Text - Lowercasing, Decontracting, Punctuation Removal, Number Removal, Two-Letter Word Removal, Stop Word Removal, Spell Checking, Extra Space Removal'''

# download nltk resources and initialize spell checker
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('stopwords')
spell = SpellChecker()


# function to convert text to lowercase
def lowercase(text):
    return text.lower() if isinstance(text, str) else text


# function to decontract words
def decontracted(text):
    if not isinstance(text, str):
        return text
    contractions = {
        "won't": "will not", "can't": "can not", "couldn't": "could not",
        "shouldn't": "should not", "wouldn't": "would not", "n't": " not",
        "'re": " are", "'s": " is", "'d": " would", "'ll": " will",
        "'t": " not", "'ve": " have", "'m": " am"
    }
    for contraction, full_form in contractions.items():
        text = text.replace(contraction, full_form)
    return text


# function to remove punctuations
def rem_punctuations(text):
    return re.sub(r'[^\w\s]', ' ', text) if isinstance(text, str) else text


# function to remove numbers
def rem_numbers(text):
    if not isinstance(text, str):
        return text
    text = re.sub(r'[xX]{2,}', '', text)
    return re.sub(r'\d+', '', text)


# function to remove two-letter words except "no" and "ct"
def rem_two_letter_words(text):
    if not isinstance(text, str):
        return text
    return ' '.join(word for word in text.split() if len(word) > 2 or word in ["no", "ct"])


# function to remove stop words
def rem_stop_words(text):
    if not isinstance(text, str):
        return text
    stop_words = set(stopwords.words('english'))
    return ' '.join(word for word in text.split() if word not in stop_words)


# function to correct spelling
def correct_spelling(text):
    if not isinstance(text, str):
        return text
    corrected = []
    for word in text.split():
        corrected_word = list(spell.candidates(word))[0] if spell.candidates(word) else word
        corrected.append(corrected_word)
    return ' '.join(corrected)


# function to remove extra spaces
def rem_extra_spaces(text):
    return ' '.join(text.split()) if isinstance(text, str) else text


# function to handle full stops
def handle_fullstops(text):
    if not isinstance(text, str):
        return text
    text = re.sub(r'\.\.+', '.', text) 
    return re.sub(r'\.', ' . ', text) 


# function to remove apostrophes
def rem_apostrophes(text):
    return re.sub("'", '', text) if isinstance(text, str) else text


# function to preprocess text
def preprocess_text(data):
    preprocessed = []
    for sentence in tqdm(data.values):
        sentence = str(sentence)
        sentence = lowercase(sentence)
        sentence = decontracted(sentence)
        sentence = rem_punctuations(sentence)
        sentence = rem_numbers(sentence)
        sentence = rem_two_letter_words(sentence)
        sentence = rem_stop_words(sentence)
        sentence = correct_spelling(sentence)
        sentence = rem_extra_spaces(sentence)
        sentence = handle_fullstops(sentence)
        sentence = rem_apostrophes(sentence)
        
        preprocessed.append(sentence)

    return preprocessed

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [6]:
'''Preprocessing Text - Lowercasing, Decontracting, Punctuation Removal, Number Removal, Two-Letter Word Removal, Stop Word Removal, Spell Checking, Extra Space Removal'''

# load your DataFrame (replace with actual path)
iu_xray_reports_df = os.path.join(output_directory, 'iu_xray_reports_df.csv')
df = pd.read_csv(iu_xray_reports_df)


# apply preprocessing on specific columns if they exist
preprocess_columns = ['findings']
for column in preprocess_columns:
    if column in df.columns:
        print(f"Preprocessing Column: {column}")
        df[column] = df[column].fillna('none').astype(str)
        df[column] = preprocess_text(df[column])
        output_path = os.path.join(output_directory, f'preprocessed_{column}.csv')
        df.to_csv(output_path, index=False)
        print(f"Saved preprocessed '{column}' column to: {output_path}")
        
        
# split into Train/Validation/Test (70%/10%/20%)
output_path = os.path.join(output_directory, f'preprocessed_findings.csv')
df = pd.read_csv(output_path)        
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=(2/3), random_state=42)


# save the splits
train_path = os.path.join(output_directory, 'train_data.csv')
train_df.to_csv(train_path, index=False)
display(train_df.head())
print(f"Train data saved to: {train_path}")

val_path = os.path.join(output_directory, 'val_data.csv')
val_df.to_csv(val_path, index=False)
display(val_df.head())
print(f"Validation data saved to: {val_path}")

test_path = os.path.join(output_directory, 'test_data.csv')
test_df.to_csv(test_path, index=False)
display(test_df.head())
print(f"Test data saved to: {test_path}")

Unnamed: 0,pmc_id,findings,impression,comparison,indication,image_count,image_1,image_2,image_3,image_4,image_5
3185,755,none,Lungs clear. Heart size normal. Flowing syndes...,"XXXX, XXXX",heart murmur?,2,CXR755_IM-2306-1001.jpg: PA and lateral views ...,CXR755_IM-2306-2001.jpg: PA and lateral views ...,,,
3187,2093,lungs clear cardiomediastinal silhouette withi...,No focal lung infiltrates.,CT chest XXXX,,2,CXR2093_IM-0723-1001.jpg: PA and lateral view...,CXR2093_IM-0723-2001.jpg: PA and lateral view...,,,
111,1464,heart pulmonary mediastinum within normal limi...,No acute cardiopulmonary disease. Evidence of ...,None.,XXXX-year-old with initiation of XXXX medicati...,2,CXR1464_IM-0301-1001.jpg: Xray Chest PA and La...,CXR1464_IM-0301-2001.jpg: Xray Chest PA and La...,,,
2320,2419,none,Normal heart size. Normal pulmonary vasculatur...,XXXX.,PNUEMONIA exsmoker x 10 yrs hx XXXX onset asth...,2,CXR2419_IM-0963-1001.jpg: Xray Chest PA and La...,CXR2419_IM-0963-2001.jpg: Xray Chest PA and La...,,,
1089,1233,enlarged cardiomediastinal silhouette low lung...,Cardiomegaly without heart failure. Minimal XX...,"XXXX, XXXX.",XXXX-year-old. Chest pain.,2,CXR1233_IM-0157-1001.jpg: PA and lateral views.,CXR1233_IM-0157-2001.jpg: PA and lateral views.,,,


Train data saved to: ./datasets/iu_xray/output/train_data.csv


Unnamed: 0,pmc_id,findings,impression,comparison,indication,image_count,image_1,image_2,image_3,image_4,image_5
299,3070,heart size within normal limits focal airspace...,No acute cardiopulmonary findings.,None available.,"XXXX-year-old male, reactive PPD.",2,CXR3070_IM-1432-1001.jpg: PA and lateral views...,CXR3070_IM-1432-1002.jpg: PA and lateral views...,,,
3900,345,cardiomediastinal silhouette pulmonary muscula...,No acute cardiopulmonary findings.,,Costochondral chest pain,2,"CXR345_IM-1672-1001.jpg: Chest, 2 views, XXXX ...","CXR345_IM-1672-2001.jpg: Chest, 2 views, XXXX ...",,,
442,1137,none,,,,2,CXR1137_IM-0093-12012.jpg: Xray Chest PA and L...,CXR1137_IM-0093-4004.jpg: Xray Chest PA and La...,,,
3406,935,none,"Comparison XXXX, XXXX. Well-expanded and clear...",,"found unresponsive,",2,CXR935_IM-2432-1001.jpg: Xray Chest PA and Lat...,CXR935_IM-2432-2001.jpg: Xray Chest PA and Lat...,,,
3023,303,heart normal size mediastinum stable rectal ba...,"Mild bilateral streaky opacities, XXXX atelect...",XXXX,"XXXX, asthma, preop hip replacement",1,CXR303_IM-1404-1001.jpg: CHEST 2V FRONTAL/LATE...,,,,


Validation data saved to: ./datasets/iu_xray/output/val_data.csv


Unnamed: 0,pmc_id,findings,impression,comparison,indication,image_count,image_1,image_2,image_3,image_4,image_5
134,3790,low lung volumes elevation right hemidiaphragm...,Right lower lobe airspace disease. .,PA and lateral views of the chest dated XXXX. ...,"XXXX-year-old female, pain, seen on XXXX for r...",2,CXR3790_IM-1904-0001-0001.jpg: Xray Chest PA a...,CXR3790_IM-1904-0001-0002.jpg: Xray Chest PA a...,,,
3415,2282,lungs clear bilaterally specifically evidence ...,No acute cardiopulmonary abnormality..,"Two-view chest radiograph dated XXXX, XXXX..","XXXX-year-old male, XXXX 2 XXXX ago, rib pain..",2,CXR2282_IM-0869-1001.jpg: PA and lateral chest...,CXR2282_IM-0869-2001.jpg: PA and lateral chest...,,,
921,2841,heart normal size contour lungs clear without ...,No acute cardiopulmonary disease.,None.,XXXX year old mid to lower back pain since XXXX.,1,CXR2841_IM-1253-2001.jpg: Xray Chest PA and La...,,,,
1029,2192,focal lung consolidation pneumothorax pleural ...,No acute cardiopulmonary process.,XXXX performed XXXX/XXXX,"XXXX-year-old with XXXX, history of lung nodules.",2,CXR2192_IM-0802-2002.jpg: Xray Chest PA and La...,CXR2192_IM-0802-3003.jpg: Xray Chest PA and La...,,,
2509,3149,lungs hyperexpanded cardiomediastinal silhouet...,Lung hyperexpansion. No focal air space disease.,"XXXX, XXXX.",XXXX-year-old male with XXXX and asthma.,2,CXR3149_IM-1480-1001.jpg: PA and Lateral views...,CXR3149_IM-1480-2001.jpg: PA and Lateral views...,,,


Test data saved to: ./datasets/iu_xray/output/test_data.csv


In [11]:
'''Indexing the Training Data'''

# indexing the train data
df = pd.read_csv(train_path)     
df = df[['pmc_id', 'findings', 'image_1', 'image_2']]
df['index'] = range(1, len(df) + 1)


# display the modified DataFrame to check the output
print("Shape of the DataFrame:", df.shape)
display(df.head(20))
df.to_csv(train_path, index=False)
print(f"Train data saved to: {train_path}")

Shape of the DataFrame: (2768, 5)


Unnamed: 0,pmc_id,findings,image_1,image_2,index
0,755,none,CXR755_IM-2306-1001.jpg: PA and lateral views ...,CXR755_IM-2306-2001.jpg: PA and lateral views ...,1
1,2093,lungs clear cardiomediastinal silhouette withi...,CXR2093_IM-0723-1001.jpg: PA and lateral view...,CXR2093_IM-0723-2001.jpg: PA and lateral view...,2
2,1464,heart pulmonary mediastinum within normal limi...,CXR1464_IM-0301-1001.jpg: Xray Chest PA and La...,CXR1464_IM-0301-2001.jpg: Xray Chest PA and La...,3
3,2419,none,CXR2419_IM-0963-1001.jpg: Xray Chest PA and La...,CXR2419_IM-0963-2001.jpg: Xray Chest PA and La...,4
4,1233,enlarged cardiomediastinal silhouette low lung...,CXR1233_IM-0157-1001.jpg: PA and lateral views.,CXR1233_IM-0157-2001.jpg: PA and lateral views.,5
5,847,trachea midline heart slightly large low lung ...,CXR847_IM-2369-1001.jpg: CHEST 2V FRONTAL/LATE...,CXR847_IM-2369-1002.jpg: CHEST 2V FRONTAL/LATE...,6
6,2914,small area scarring lateral right mid lower lu...,,,7
7,2637,moderate marked enlargement cardiac silhouette...,CXR2637_IM-1122-1001.jpg: Xray Chest PA and La...,CXR2637_IM-1122-2001.jpg: Xray Chest PA and La...,8
8,2197,heart size normal parities appear right fissur...,"CXR2197_IM-0807-1001.jpg: Chest x-XXXX, 2 view...","CXR2197_IM-0807-2001.jpg: Chest x-XXXX, 2 view...",9
9,2532,stable enlargement cardiac silhouette lateral ...,CXR2532_IM-1046-1001.jpg: Xray Chest PA and La...,CXR2532_IM-1046-2001.jpg: Xray Chest PA and La...,10


Train data saved to: ./datasets/iu_xray/output/train_data.csv


In [13]:
'''Creating Filtered Dataframes'''

# filtering the data
filtered_train_path = os.path.join(output_directory, 'filtered_train_data.csv')
df = pd.read_csv(train_path)  
print("Shape of the DataFrame Before:", df.shape)

filtered_df = df.dropna(subset=['image_1', 'image_2'], how='all')
print("Shape of the DataFrame After:", filtered_df.shape)

filtered_df.to_csv(filtered_train_path, index=False)
print(f"Dataframe saved to: {filtered_train_path}")


# filtering the data
filtered_val_path = os.path.join(output_directory, 'filtered_val_data.csv')
df = pd.read_csv(val_path)  
print("Shape of the DataFrame Before:", df.shape)

filtered_df = df.dropna(subset=['image_1', 'image_2'], how='all')
print("Shape of the DataFrame After:", filtered_df.shape)

filtered_df.to_csv(filtered_val_path, index=False)
print(f"Dataframe saved to: {filtered_val_path}")


# filtering the data
filtered_test_path = os.path.join(output_directory, 'filtered_test_data.csv')
df = pd.read_csv(test_path)  
print("Shape of the DataFrame Before:", df.shape)

filtered_df = df.dropna(subset=['image_1', 'image_2'], how='all')
print("Shape of the DataFrame After:", filtered_df.shape)

filtered_df.to_csv(filtered_test_path, index=False)
print(f"Dataframe saved to: {filtered_test_path}")

Shape of the DataFrame Before: (2768, 5)
Shape of the DataFrame After: (2700, 5)
Dataframe saved to: ./datasets/iu_xray/output/filtered_train_data.csv
Shape of the DataFrame Before: (395, 11)
Shape of the DataFrame After: (385, 11)
Dataframe saved to: ./datasets/iu_xray/output/filtered_val_data.csv
Shape of the DataFrame Before: (792, 11)
Shape of the DataFrame After: (766, 11)
Dataframe saved to: ./datasets/iu_xray/output/filtered_test_data.csv


### **Create Data Loaders**

In [16]:
'''Image Data Loaders to Supply Dataset to Model in Batches'''

# classes in dataset
class CustomImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image


# function to load image data with transformation and batching
def load_preprocessed_images(image_dir, batch_size=32):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = CustomImageDataset(image_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataloader

In [18]:
'''Text Data Loaders to Supply Dataset to Model in Batches'''

# classes in dataset
class CustomTextDataset(Dataset):
    def __init__(self, text_list, tokenizer, max_length=512):
        self.text_list = text_list
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.text_list)

    def __getitem__(self, idx):
        text = self.text_list[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {key: tensor.squeeze(0) for key, tensor in encoding.items()}


# function to load text data with batching
def load_preprocessed_texts(text_list, tokenizer, batch_size=32, max_length=512):
    dataset = CustomTextDataset(text_list, tokenizer, max_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataloader

## **Model Implementation**

### **Visual Extractor**

In [27]:
'''Visual Extractor to Extract Data from Image and Encode it Accordingly'''

# Define paths for saving features
patch_feats_file = os.path.join(output_directory, 'patch_feats.pt')
avg_feats_file = os.path.join(output_directory, 'avg_feats.pt')
final_embeddings_file = os.path.join(output_directory, 'final_embeddings.pt')
iu_xray_images_preprocessed = os.path.join(output_directory, "images_preprocessed")


# Define the transform for image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  
])


# Define VisualExtractor class
class VisualExtractor(nn.Module):
    def __init__(self, visual_extractor='resnet101', pretrained=True):
        super(VisualExtractor, self).__init__()

        model = getattr(models, visual_extractor)(pretrained=pretrained)
        
        # Remove the last fully connected layer
        modules = list(model.children())[:-2]  
        self.model = nn.Sequential(*modules)
        
        # Average pooling and a fully connected layer to transform the features
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc_layer = nn.Linear(model.fc.in_features, 2048)
        
    def forward(self, x):
        x = self.model(x)
        avg_feats = self.avg_pool(x).view(x.size(0), -1)
        patch_feats = self.fc_layer(avg_feats)
        return patch_feats, avg_feats

    
# Load images function that handles the new DataFrame structure
def load_images(report_row, img_folder):
    images = []
    for i in range(1, 3):  # image_1 and image_2
        image = str(report_row[f'image_{i}']).strip()
        img_filename = image.split('.')[0] + '.png'  
        img_path = os.path.join(img_folder, img_filename)
        
        # Check if the image file exists before attempting to open it
        if os.path.exists(img_path):
            img = Image.open(img_path).convert("RGB")
            images.append(transform(img))
        else:
            print(f"Warning: Image file not found: {img_path}")  # Warning if file not found

    # If only one image was loaded, duplicate it
    if len(images) == 1:
        images.append(images[0].clone())  # Duplicate the single available image
    
    return torch.stack(images) if images else torch.tensor([])  # Return empty tensor if no images loaded


# Initialize visual extractor
visual_extractor = VisualExtractor()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
visual_extractor.to(device)


# Set learning rates and other parameters
learning_rate = 5e-5  
other_parameter = 1e-4


# Create an optimizer
optimizer = optim.Adam(visual_extractor.parameters(), lr=learning_rate, weight_decay=other_parameter)


# function to extract features
def extract_features(data_loader):
    patch_feats, avg_feats, final_embeddings = [], [], []
    
    for idx, row in data_loader.iterrows():
        images = load_images(row, iu_xray_images_preprocessed).to(device)  # Load up to 2 images
        
        if images.numel() == 0:  # Skip if no images were loaded
            continue
        
        report_patch_feats, report_avg_feats = [], []
        
        # Iterate directly over the images tensor
        for image in images:  # Process each image in the loaded images tensor
            image = image.unsqueeze(0).to(device)  # Move single image to device
            pf, af = visual_extractor(image)  # Extract patch_feats, avg_feats
            print(f"{row['index']}, {row['pmc_id']} Patch Features Shape: {pf.shape}, Avg Features Shape: {af.shape}", end="\r")
            
            # Store the output temporarily, releasing earlier references
            report_patch_feats.append(pf.detach())
            report_avg_feats.append(af.detach())

        # Concatenate patch features and average the avg_feats
        concatenated_patch_feats = torch.cat(report_patch_feats, dim=1)  # Concatenate along feature dimension
        averaged_avg_feats = torch.mean(torch.stack(report_avg_feats), dim=0)  # Average the avg_feats

        # Combine avg_feats with patch_feats to create final embedding
        final_embedding = torch.cat((averaged_avg_feats, concatenated_patch_feats), dim=1)
        patch_feats.append(concatenated_patch_feats)
        avg_feats.append(averaged_avg_feats)
        final_embeddings.append(final_embedding)
        
    return torch.stack(patch_feats), torch.stack(avg_feats), torch.stack(final_embeddings)


# Functions to load and save features
def load_features(file_path):
    return torch.load(file_path)

def save_features(file_path, features):
    torch.save(features, file_path)

    
# Load data
report_data_path = os.path.join(output_directory, 'filtered_train_data.csv')
iu_xray_reports_df = pd.read_csv(report_data_path)
display(iu_xray_reports_df.head(10))


# Check if features are already saved; if not, extract and save them
if not os.path.exists(patch_feats_file) and os.path.exists(avg_feats_file) and os.path.exists(final_embeddings_file):
    print("All features are already precomputed and will be loaded.")
    patch_feats = load_features(patch_feats_file)
    avg_feats = load_features(avg_feats_file)
    final_embeddings = load_features(final_embeddings_file)
else:
    print("Extracting features since they are not precomputed...")
    
    visual_extractor.train()  
    
    patch_feats, avg_feats, final_embeddings = extract_features(iu_xray_reports_df)
    patch_feats = patch_feats.squeeze(1)  # Shape: (num_reports, 4096)
    avg_feats = avg_feats.squeeze(1)      # Shape: (num_reports, 2048)
    final_embeddings = final_embeddings.squeeze(1)  # Shape: (num_reports, 6144)

    save_features(patch_feats_file, patch_feats)
    save_features(avg_feats_file, avg_feats)
    save_features(final_embeddings_file, final_embeddings)

    
# Displaying the shapes of the feature dataframes
print("Patch Features Shape:", patch_feats.shape)
print("Average Features Shape:", avg_feats.shape)
print("Final Embedding Shape:", final_embeddings.shape)



Unnamed: 0,pmc_id,findings,image_1,image_2,index
0,755,none,CXR755_IM-2306-1001.jpg: PA and lateral views ...,CXR755_IM-2306-2001.jpg: PA and lateral views ...,1
1,2093,lungs clear cardiomediastinal silhouette withi...,CXR2093_IM-0723-1001.jpg: PA and lateral view...,CXR2093_IM-0723-2001.jpg: PA and lateral view...,2
2,1464,heart pulmonary mediastinum within normal limi...,CXR1464_IM-0301-1001.jpg: Xray Chest PA and La...,CXR1464_IM-0301-2001.jpg: Xray Chest PA and La...,3
3,2419,none,CXR2419_IM-0963-1001.jpg: Xray Chest PA and La...,CXR2419_IM-0963-2001.jpg: Xray Chest PA and La...,4
4,1233,enlarged cardiomediastinal silhouette low lung...,CXR1233_IM-0157-1001.jpg: PA and lateral views.,CXR1233_IM-0157-2001.jpg: PA and lateral views.,5
5,847,trachea midline heart slightly large low lung ...,CXR847_IM-2369-1001.jpg: CHEST 2V FRONTAL/LATE...,CXR847_IM-2369-1002.jpg: CHEST 2V FRONTAL/LATE...,6
6,2637,moderate marked enlargement cardiac silhouette...,CXR2637_IM-1122-1001.jpg: Xray Chest PA and La...,CXR2637_IM-1122-2001.jpg: Xray Chest PA and La...,8
7,2197,heart size normal parities appear right fissur...,"CXR2197_IM-0807-1001.jpg: Chest x-XXXX, 2 view...","CXR2197_IM-0807-2001.jpg: Chest x-XXXX, 2 view...",9
8,2532,stable enlargement cardiac silhouette lateral ...,CXR2532_IM-1046-1001.jpg: Xray Chest PA and La...,CXR2532_IM-1046-2001.jpg: Xray Chest PA and La...,10
9,3658,cardiomediastinal silhouette normal size conto...,CXR3658_IM-1819-1001.jpg: Xray Chest PA and La...,CXR3658_IM-1819-2001.jpg: Xray Chest PA and La...,11


Extracting features since they are not precomputed...
Patch Features Shape: torch.Size([2700, 4096]), 2048]), Avg Features Shape: torch.Size([1, 2048])
Average Features Shape: torch.Size([2700, 2048])
Final Embedding Shape: torch.Size([2700, 6144])


In [28]:
'''Displaying Tensor Shapes for All Embeddings'''

# Define paths for the saved features
patch_feats_file = os.path.join(output_directory, 'patch_feats.pt')
avg_feats_file = os.path.join(output_directory, 'avg_feats.pt')
final_embeddings_file = os.path.join(output_directory, 'final_embeddings.pt')


# Function to load features and print their shapes
def print_tensor_shapes():
    # Load patch features
    if os.path.exists(patch_feats_file):
        patch_feats = torch.load(patch_feats_file, map_location=torch.device('cpu'))
        print("Patch Features Shape:", patch_feats.shape)
    else:
        print(f"Patch features file not found: {patch_feats_file}")

    # Load average features
    if os.path.exists(avg_feats_file):
        avg_feats = torch.load(avg_feats_file, map_location=torch.device('cpu'))
        print("Average Features Shape:", avg_feats.shape)
    else:
        print(f"Average features file not found: {avg_feats_file}")

    # Load final embeddings
    if os.path.exists(final_embeddings_file):
        final_embeddings = torch.load(final_embeddings_file, map_location=torch.device('cpu'))
        print("Final Embedding Shape:", final_embeddings.shape)
    else:
        print(f"Final embeddings file not found: {final_embeddings_file}")

        
# Call the function to print shapes
print_tensor_shapes()

Patch Features Shape: torch.Size([2700, 4096])
Average Features Shape: torch.Size([2700, 2048])
Final Embedding Shape: torch.Size([2700, 6144])


  patch_feats = torch.load(patch_feats_file, map_location=torch.device('cpu'))
  avg_feats = torch.load(avg_feats_file, map_location=torch.device('cpu'))
  final_embeddings = torch.load(final_embeddings_file, map_location=torch.device('cpu'))


### **Text Encoder**

In [None]:
'''Text Encoder'''

# function to embed text
def embed_text(text_dataloader, model):
    all_embeddings = []
    
    try:
        for batch in text_dataloader:
            if isinstance(batch, str):
                inputs = tokenizer([batch], return_tensors="pt", padding=True, truncation=True).to(device)
            elif isinstance(batch, list):
                inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(device)
            else:
                raise ValueError("Batch must be of type str or List[str]")

            outputs = model(**inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1)
            all_embeddings.append(embeddings)

        if all_embeddings: 
            return torch.cat(all_embeddings, dim=0)
        else:
            if isinstance(model, SentenceTransformer):
                return torch.empty(0, model.get_sentence_embedding_dimension()).to(device)
            else:
                return torch.empty(0, model.config.hidden_size).to(device)

    except Exception as e:
        print(f"Error in embedding: {e}")
        if isinstance(model, SentenceTransformer):
            return torch.empty(0, model.get_sentence_embedding_dimension()).to(device)
        else:
            return torch.empty(0, model.config.hidden_size).to(device)


# function to compute cosine similarity
def compute_cosine_similarity(embeddings1, embeddings2):
    embeddings1 = F.normalize(embeddings1, p=2, dim=1)
    embeddings2 = F.normalize(embeddings2, p=2, dim=1)
    cosine = torch.mm(embeddings1, embeddings2.t())
    return cosine


# defining device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# loading tokenizer, bert model and sentence-bert model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)
sentence_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)

In [None]:
'''Text Encoder - Medical Knowledge encoder using TextEncoder'''

# defining radiology dictionary
radiology_dictionary = {
    "pleural": ['hemithorax', 'effusion', 'pneumothorax', 'parenchymal'],
    "lung": ['lungs', 'pulmonary', 'hilar', 'lobe', 'consolidation', 'atelectasis', 'edema', 'opacity', 'pneumonia'],
    "mediastinal": ['mediastinum', 'diaphragm', 'hemidiaphragm'],
    "cardiac": ['heart', 'cardiomegaly', 'cardiomediastinal', 'atrium', 'ventricle', 'retrocardiac'],
    "vascular": ['aorta', 'venous', 'jugular', 'aortic', 'vasculature', 'cabg'],
    "osseous": ['rib', 'sternal', 'subclavian', 'thoracic'],
    "trachea": ['endotrachea'],
    "stomach": [],
    "abdomen": [],
    "tube": ['clips'],
    "spine": ['vertebral', 'degenerative'],
    "nodule": ['mass'],
    "chest": ['small', 'enlarged', 'unchanged', 'stable', 'silhouette', 'contours', 'size', 'focal', 'mild', 'acute']
}


# define the filename and full path
filename = 'radiology_terms.csv'
dictionary_csv = os.path.join(output_directory, filename) 


# Open the CSV file in write mode using the full path
if not os.path.exists(dictionary_csv):    
    with open(dictionary_csv, 'w', newline='') as file:  
        writer = csv.writer(file)  
        writer.writerow(['Category', 'Term']) 
    
        for category, terms in radiology_dictionary.items():
            for term in terms:
                if term: 
                    writer.writerow([category, term]) 
                    
    print(f"Dictionary saved to {dictionary_csv}")  
else:     
    print(f"File already exists at {dictionary_csv}. No changes made.")

text_list = []
with open(dictionary_csv, 'r') as file:
    reader = csv.reader(file)
    next(reader) 
    for row in reader:
        text_list.append(row[1]) 

        
# load the SentenceTransformer model
model = SentenceTransformer('all-MiniLM-L6-v2')

# Generate embeddings
embeddings = model.encode(text_list)

# Display the shape of the embeddings
print(f"Embeddings Shape: {embeddings.shape}")

dictionary_pt = os.path.join(output_directory, 'dictionary_embeddings.pt')
if not os.path.exists(dictionary_pt):    
    torch.save(embeddings, dictionary_pt)
    print(f"Saved at : {dictionary_pt}")
else : print(f"Already saved at path : {dictionary_pt}")

    
# # embedding dictionary and reports using bert, and historical medical reports using sentence transformer
# dictionary_dataloader = load_preprocessed_texts(radiology_dictionary, tokenizer)
# dictionary_embeddings = embed_text(dictionary_dataloader, bert_model)


# # saving embeddings
# embeddings_file_path = os.path.join(output_directory, 'dictionary_embeddings.pt')
# if not os.path.exists(embeddings_file_path):
#     torch.save(dictionary_embeddings.cpu(), embeddings_file_path)
#     print(f"Dictionary embeddings saved to {embeddings_file_path}")


# # displaying shape of embeddings
# print(f"Dictionary Embeddings Shape: {dictionary_embeddings.shape}")

In [None]:
# '''Text Encoder - Medical History encoding using sentence encoder'''

# # reading and preprocessing medical history reports
# iu_xray_reports_preprocessed_df_path = os.path.join(output_directory, 'iu_xray_reports_preprocessed_df.csv')
# medical_history = pd.read_csv(iu_xray_reports_preprocessed_df_path)["findings"].dropna().tolist()
# medical_history = [str(report) for report in medical_history if isinstance(report, str) or pd.notna(report)]

# batch_size = 32  # Set your desired batch size
# tokenizer = AutoTokenizer.from_pretrained('all-MiniLM-L6-v2')  # Ensure you have your tokenizer defined

# # Load the data into a DataLoader
# medical_reports_dataloader = load_preprocessed_texts(medical_history, tokenizer, batch_size)

# # Step 3: Print the shapes of the batches
# for batch in medical_reports_dataloader:
#     # Printing shapes of input tensors
#     print({key: tensor.shape for key, tensor in batch.items()})
#     break  # Remove this break to see all batches

In [None]:
# '''Text Encoder -  Sentence-Bert Encoder using Medical History'''

# # reading and preprocessing medical history reports
# iu_xray_reports_preprocessed_df_path = os.path.join(output_directory, 'iu_xray_reports_preprocessed_df.csv')
# medical_history = pd.read_csv(iu_xray_reports_preprocessed_df_path)["findings"].dropna().tolist()
# medical_history = [str(report) for report in medical_history if isinstance(report, str) or pd.notna(report)]

# invalid_entries = [report for report in medical_history if not isinstance(report, (str, list))]

# # Print if there are any invalid entries
# if invalid_entries:
#     print(f"Found {len(invalid_entries)} invalid entries: {invalid_entries}")
# else:
#     print("No invalid entries found.")

# # encoding medical history using Sentence-BERT
# bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# historical_dataloader_b = load_preprocessed_texts(medical_history, tokenizer)
# historical_embeddings_b = embed_text(historical_dataloader, bert_model)

# # # Check for invalid entries after batching
# # invalid_entries_after_batches = [
# #     report for batch in historical_dataloader_b for report in batch.values() if not isinstance(report, (str, list))
# # ]

# # for batch in historical_dataloader:
# #     print(type(batch))  # Check the type
# #     print(batch)  # Print the batch content
# #     break



# # saving embeddings
# historical_pt_b = os.path.join(output_directory, 'historical_embeddings_b.pt')
# if not os.path.exists(historical_pt_b):
#     torch.save(historical_embeddings_b.cpu(), historical_pt_b)
#     print(f"Historical embeddings saved to {historical_pt_b}")
# else : print(f"Already at {historical_pt_b}")


# # displaying shape of embeddings
# print(f"Historical Embeddings Shape: {historical_embeddings_b.shape}")

In [None]:
# '''Text Encoder - Finding Reports Similar to Current Report'''

# # current report embedding
# current_report = 'Heart size pulmonary vascularity appear within normal limits mild tortuosity descending thoracic aorta lungs free focal airspace disease pleural effusion pneumothorax seen discrete nodules adenopathy noted degenerative changes present spine'
# current_embeddings = embed_text(current_report, sentence_model)


# # computing cosine similarity
# historical_embeddings_normalized = historical_embeddings / historical_embeddings.norm(dim=1, keepdim=True)
# current_embeddings_normalized = current_embeddings / current_embeddings.norm(dim=1, keepdim=True)

# similarity_matrix = compute_cosine_similarity(dictionary_embeddings, historical_embeddings)
# print(f"Similarity Matrix Shape: {similarity_matrix.shape}")


# # finding top-k relevant entries
# k = 5
# top_k_indices = similarity_matrix.topk(k=k, dim=1).indices


# # preparing relevant entries based on indices
# relevant_entries = []
# for row in top_k_indices:
#     relevant_entries.append(medical_history[row.item()])


# # printing relevant entries for each report
# print(f"Relevant Entries for the Current Report: {relevant_entries}")


# # update the historical embeddingx
# historical_embeddings.append(current_embeddings)

In [None]:
# import os
# import pandas as pd
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.utils.data import DataLoader, Dataset
# from transformers import BertTokenizer, BertModel
# import pickle

# # Check if GPU is available and set the device accordingly
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# # Function to convert findings to list of lists
# def get_findings_list(df):
#     findings_list = df['findings'].fillna("").tolist()  # Replace NaN with empty string
#     return [[finding] for finding in findings_list]  # Convert each finding into a list of strings

# # Dataset class to load the findings from the reports
# class FindingsDataset(Dataset):
#     def __init__(self, findings_list):
#         self.findings_list = findings_list

#     def __len__(self):
#         return len(self.findings_list)

#     def __getitem__(self, idx):
#         return self.findings_list[idx]

# # Define the Sentence Encoder (trainable)
# class SentenceEncoder(nn.Module):
#     def __init__(self, hidden_size=768):
#         super(SentenceEncoder, self).__init__()
#         self.encoder = nn.Linear(hidden_size, hidden_size)
    
#     def forward(self, x):
#         return self.encoder(x)

# # Cosine similarity loss based on the image equation
# def cosine_similarity_loss(H, H_b):
#     cos_sim_H = F.cosine_similarity(H, H.unsqueeze(1))
#     cos_sim_H_b = F.cosine_similarity(H_b, H_b.unsqueeze(1))
#     loss = torch.mean((cos_sim_H_b - cos_sim_H) ** 2)
#     return loss

# # Load pre-trained BERT model and tokenizer
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# bert_model = BertModel.from_pretrained('bert-base-uncased')

# # Move the model to the appropriate device
# bert_model.to(device)

# # Load your dataframe
# # iu_xray_reports_preprocessed_df_path = '/kaggle/input/preprocessed-text/iu_xray_reports_df_preprocessed.csv'
# iu_xray_reports_preprocessed_df_path = os.path.join(output_directory, 'iu_xray_reports_preprocessed_df.csv')
# medical_history = pd.read_csv(iu_xray_reports_preprocessed_df_path)
# findings_list = get_findings_list(medical_history)

# # Custom collate function to tokenize the findings in batches
# def collate_fn(batch):
#     batch = [item[0] for item in batch]  # Flatten the batch list
#     return tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device)  # Move to GPU

# # Create dataset and dataloader
# findings_dataset = FindingsDataset(findings_list)
# dataloader = DataLoader(findings_dataset, batch_size=16, collate_fn=collate_fn)

# # Initialize Sentence Encoder and move to device
# sentence_encoder = SentenceEncoder().to(device)

# # Historical Knowledge Storage
# #historical_knowledge_file = 'historical_knowledge.pkl'  # File to save the historical encodings
# historical_knowledge = []  # List to store historical encodings

# historical_knowledge_file = os.path.join(output_directory, 'historical_knowledge.pkl')
# # if not os.path.exists(historical_pt_b):
# #     torch.save(historical_embeddings_b.cpu(), historical_pt_b)
# #     print(f"Historical embeddings saved to {historical_pt_b}")

# # Load historical knowledge if it exists
# if os.path.exists(historical_knowledge_file):
#     with open(historical_knowledge_file, 'rb') as f:
#         historical_knowledge = pickle.load(f)

# # Optimizer
# optimizer = torch.optim.Adam(sentence_encoder.parameters(), lr=1e-4)
# num_epochs = 10  # Set the number of epochs

# # Training loop
# for epoch in range(num_epochs):
#     for batch in dataloader:
#         input_ids = batch['input_ids'].to(device)  # Move input_ids to GPU
#         attention_mask = batch['attention_mask'].to(device)  # Move attention_mask to GPU

#         # Get historical encodings from BERT
#         with torch.no_grad():
#             bert_output = bert_model(input_ids=input_ids, attention_mask=attention_mask)
#             H_b = bert_output.last_hidden_state.mean(dim=1).to(device)  # Average pooling of BERT embeddings and move to GPU

#         # Get encodings from the trainable Sentence Encoder
#         H = sentence_encoder(H_b)  # Feeding the pre-trained embeddings to the trainable encoder

#         # final_H = H.detach().cpu()
#         # Store the current encodings in historical knowledge
#         historical_knowledge.append(H.detach().cpu())  # Detach tensor and move to CPU

#         # Calculate cosine similarity loss
#         loss = cosine_similarity_loss(H, H_b)

#         # Backpropagation and optimization
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#     print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

# # final_H_file = 'final_H.pkl'
# # with open(final_H_file, 'wb') as f:
# #     pickle.dump(final_H.numpy(), f)


# # Now you can utilize historical_knowledge for further inference or report generation.
# with open(historical_knowledge_file, 'wb') as f:
#     # Convert tensors to NumPy arrays before saving
#     historical_knowledge_np = [h.numpy() for h in historical_knowledge]
#     pickle.dump(historical_knowledge_np, f)


### **Sentence Encoder**

In [None]:
'''Sentence Encoder - Classes'''

# class for Sentence Encoder
class SentenceEncoder(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=512, output_dim=512):
        super(SentenceEncoder, self).__init__()
        # Encoder network
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Networks to generate mean and variance as mentioned in paper
        self.mean_layer = nn.Linear(hidden_dim, output_dim)
        self.logvar_layer = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, x):
        encoded = self.encoder(x)
        mean = self.mean_layer(encoded)
        logvar = self.logvar_layer(encoded)
        return mean, logvar
        
    def encode(self, x):
        """Get only the mean for inference"""
        mean, _ = self.forward(x)
        return mean


# class for Sentence BERT
class SentenceBERT:
    def __init__(self):
        # Initialize the BERT tokenizer and model
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert_model = BertModel.from_pretrained('bert-base-uncased')

    def encode_reports(self, reports):
        embeddings = []
        for report in reports:
            # Tokenize and encode the report
            inputs = self.tokenizer(report, return_tensors='pt', padding=True, truncation=True)

            # Get BERT outputs
            with torch.no_grad():  # Disable gradient calculation for inference
                outputs = self.bert_model(**inputs)
                
            # Use the pooled output as the embedding
            embeddings.append(outputs.pooler_output)

        # Stack all embeddings into a single tensor
        return torch.stack(embeddings)

In [None]:
'''Sentence Encoder - Generating Sentence Embeddings'''

# initialising paths
input_train_path = os.path.join(output_directory, 'filtered_train_data.csv')
input_test_path = os.path.join(output_directory, 'filtered_test_data.csv')


# reading csv files
train_df = pd.read_csv(input_train_path)
test_df = pd.read_csv(input_test_path)


# initialize models
sentence_bert = SentenceBERT()
sentence_encoder = SentenceEncoder()


# Parameters for batching
df = train_df
batch_size = 64
num_batches = (len(df) + batch_size - 1) // batch_size  


# Initialize matrices to store embeddings
sbert_matrix = []
sentence_encoder_matrix = []


# Loop through batches with progress bar
for i in tqdm(range(num_batches), desc="Processing Batches", unit="batch"):
    batch_reports = df['findings'].iloc[i * batch_size:(i + 1) * batch_size].tolist()
    
    # Compute SBERT embeddings
    sbert_embeddings = sentence_bert.encode_reports(batch_reports)
    
    # Compute Sentence Encoder embeddings
    sentence_embeddings = sentence_encoder.encode(sbert_embeddings)
    
    # Store embeddings in the matrices
    sbert_matrix.append(sbert_embeddings)
    sentence_encoder_matrix.append(sentence_embeddings)


# Stack the matrices to create final output tensors
sbert_matrix = torch.cat(sbert_matrix, dim=0).squeeze(1)
sentence_encoder_matrix = torch.cat(sentence_encoder_matrix, dim=0).squeeze(1)


# Print shapes of the final matrices
print(f"SBERT Matrix Shape: {sbert_matrix.shape}")
print(f"Sentence Encoder Matrix Shape: {sentence_encoder_matrix.shape}")


# define file paths for each embedding
embeddings_dir = output_directory
sentence_encoder_embeddings_path = os.path.join(embeddings_dir, "sentence_encoder_embeddings.pt")
sbert_embeddings_path = os.path.join(embeddings_dir, "sbert_embeddings.pt")

# save the embeddings
torch.save(sentence_encoder_matrix, sentence_encoder_embeddings_path)
torch.save(sbert_matrix, sbert_embeddings_path)
print("Embeddings saved successfully.")


# load the embeddings
loaded_sentence_encoder_matrix = torch.load(sentence_encoder_embeddings_path)
loaded_sbert_matrix = torch.load(sbert_embeddings_path)
print("Embeddings loaded successfully.")

In [None]:
'''Sentence Encoder - Training the Sentence Encoder using Sentence BERT'''

# function to compute kl divergence loss
def kl_divergence_loss(mean1: torch.Tensor, logvar1: torch.Tensor, mean2: torch.Tensor, logvar2: torch.Tensor) -> torch.Tensor:
    normal1 = dist.Normal(mean1, torch.exp(0.5 * logvar1))  # Standard deviation is sqrt of variance
    normal2 = dist.Normal(mean2, torch.exp(0.5 * logvar2))
    kl_loss = dist.kl.kl_divergence(normal1, normal2).mean()
    return kl_loss


# function to train the model based on embeddings
def train_embeddings(sentence_encoder, projection_layer, sentence_encoder_matrix, sbert_matrix, 
                    optimizer, batch_size, num_epochs):
    """
    Training loop with explicit memory management and graph cleanup
    """
    device = next(sentence_encoder.parameters()).device
    
    for epoch in range(num_epochs):
        total_loss = 0
        sentence_encoder.train()
        projection_layer.train()
        
        n_batches = sentence_encoder_matrix.size(0) // batch_size
        
        with tqdm(total=n_batches, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as pbar:
            for i in range(0, sentence_encoder_matrix.size(0), batch_size):
                try:
                    # Clear gradients
                    optimizer.zero_grad(set_to_none=True)  # More efficient than zero_grad()
                    
                    # Get batch data with detached copies
                    batch_embedding1 = sentence_encoder_matrix[i:i+batch_size].to(device).detach().clone()
                    batch_embedding2 = sbert_matrix[i:i+batch_size].to(device).detach().clone()
                    
                    # Project SBERT embeddings
                    with torch.set_grad_enabled(True):
                        batch_embedding2_projected = projection_layer(batch_embedding2)
                        
                        # Compute means and variances
                        mean1 = batch_embedding1.mean(dim=0, keepdim=True)
                        var1 = batch_embedding1.var(dim=0, keepdim=True, unbiased=False)
                        mean2 = batch_embedding2_projected.mean(dim=0, keepdim=True)
                        var2 = batch_embedding2_projected.var(dim=0, keepdim=True, unbiased=False)
                        
                        # KL divergence loss
                        kl_loss = kl_divergence_loss(mean1, var1, mean2, var2)
                        
                        # Similarity loss
                        similarities = F.cosine_similarity(batch_embedding1, batch_embedding2_projected, dim=1)
                        sim_loss = 1 - similarities.mean()
                        
                        # Combined loss
                        loss = kl_loss + sim_loss
                    
                    # Backward pass with no graph retention
                    loss.backward(retain_graph=False)
                    
                    # Clip gradients
                    torch.nn.utils.clip_grad_norm_(sentence_encoder.parameters(), max_norm=1.0)
                    torch.nn.utils.clip_grad_norm_(projection_layer.parameters(), max_norm=1.0)
                    
                    # Optimizer step
                    optimizer.step()
                    
                    # Update metrics
                    current_loss = loss.item()
                    total_loss += current_loss
                    
                    # Clean up tensors explicitly
                    del batch_embedding1, batch_embedding2, batch_embedding2_projected
                    del mean1, var1, mean2, var2
                    del kl_loss, similarities, sim_loss, loss
                    torch.cuda.empty_cache() if torch.cuda.is_available() else None
                    
                    # Update progress bar
                    pbar.update(1)
                    pbar.set_postfix(loss=f"{current_loss:.4f}")
                    
                except RuntimeError as e:
                    print(f"\nError in batch {i//batch_size}:")
                    print(str(e))
                    # Try to recover
                    torch.cuda.empty_cache() if torch.cuda.is_available() else None
                    continue
        
        # Compute average loss for the epoch
        avg_loss = total_loss / n_batches
        print(f"\nEpoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}")


# Set up training
batch_size = 64
num_epochs = 20


# Initialize optimizer with parameters from both models
optimizer = optim.Adam([
    {'params': sentence_encoder.parameters()},
    {'params': projection_layer.parameters()}
], lr=1e-3)


# Move models to the same device if using GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    sentence_encoder = sentence_encoder.to(device)
    projection_layer = projection_layer.to(device)
    sentence_encoder_matrix = sentence_encoder_matrix.to(device)
    sbert_matrix = sbert_matrix.to(device)
else:
    device = torch.device('cpu')


# Run training
train_embeddings(
    sentence_encoder=sentence_encoder,
    projection_layer=projection_layer,
    sentence_encoder_matrix=sentence_encoder_matrix,
    sbert_matrix=sbert_matrix,
    optimizer=optimizer,
    batch_size=batch_size,
    num_epochs=num_epochs
)

In [None]:
'''Sentence Encoder - Compute Similarity Scores'''

# function to compute cosine similarity
def compute_cosine_similarities(batch_embeddings: torch.Tensor, all_embeddings: torch.Tensor) -> torch.Tensor:
    # Normalize embeddings
    batch_norm = F.normalize(batch_embeddings, p=2, dim=1)
    all_norm = F.normalize(all_embeddings, p=2, dim=1)
    
    # Compute similarities
    similarities = torch.mm(batch_norm, all_norm.t())
    return similarities


# Compute similarity matrices
sbert_similarities = compute_cosine_similarities(sbert_matrix, sbert_matrix)
sentence_encoder_similarities = compute_cosine_similarities(sentence_encoder_matrix, sentence_encoder_matrix)


# Print similarity matrix shapes
print(f"SBERT Similarity Matrix Shape: {sbert_similarities.shape}")  # Should be [2705, 2705]
print(f"Sentence Encoder Similarity Matrix Shape: {sentence_encoder_similarities.shape}")  # Should be [2705, 2705]

In [None]:
'''Sentence Encoder - Saving/Loading our Model'''

# define model path
model_path = os.path.join(output_directory, "model.pt")  


# save `sentence_encoder` model instance
torch.save(sentence_encoder.state_dict(), model_path)
print(f"Model saved to {model_path}")


# load model instance
sentence_encoder = SentenceEncoder()
sentence_encoder.load_state_dict(torch.load(model_path))
sentence_encoder.eval()
print("Model loaded successfully.")

In [None]:
'''Sentence Encoder - Finding Top K Similar Reports for each Report'''

# function to compute similarity score
def compute_similarity_score(sim1, sim2):
    return torch.mean((sim1 - sim2) ** 2)


# function to find top k similar reports
def find_top_k_similar_reports(custom_similarities, sbert_similarities, k):
    num_reports = custom_similarities.shape[0]
    
    similarity_scores = []

    for i in range(num_reports):
        for j in range(i + 1, num_reports):
            custom_score = compute_similarity_score(custom_similarities[i], custom_similarities[j])
            sbert_score = compute_similarity_score(sbert_similarities[i], sbert_similarities[j])
            combined_score = (custom_score**2 - sbert_score**2)

            similarity_scores.append((combined_score.item(), i, j))

    similarity_scores.sort(key=lambda x: x[0], reverse=True)

    return similarity_scores[:k]


# find top k similar reports for each report 
k_most_similar_reports = find_top_k_similar_reports(sentence_encoder_similarities, sbert_similarities, 7)


# store and display the results
ans = []
for report_id, data in k_most_similar_reports.items():
    ans.append(data['most_similar_reports'])
    print(data['most_similar_reports'])

In [None]:
'''Sentence Encoder - Finding Top N Similar Indices'''

# convert to numpy array
ans = np.array(ans)
print(type(ans))


# function to return n similar indexes
def return_n_sim_indexes(a, n):
    indexes = []
    flattened_array = a.flatten()
    count = Counter(flattened_array)
    most_common_numbers = count.most_common(n)
    indexes = [num for num, _ in most_common_numbers]
    
    return indexes


# display n similar indexes
print(return_n_sim_indexes(ans, 7))

### **Multilevel Alignment**

In [None]:
'''Multilevel Alignment based on BLIP Architecture'''

# function to extract image embeddings of a given file name
def extract_image_embeddings(image_name):
    image_embeddings = torch.load(output_directory + "image_embeddings.pt")
    if image_name in image_embeddings:
        return image_embeddings[image_name]
    else:
        return None
    

# function to extract text embeddings of a given text
def extract_historical_text_embeddings(text):
    historical_text_embeddings = torch.load(output_directory + "historical_embeddings.pt")
    if text in historical_text_embeddings:
        return historical_text_embeddings[text]
    else :
        return None

def extract_dictionary_text_embeddings(text):
    dictionary_text_embeddings = torch.load(output_directory + "dictionary_embeddings.pt")
    if text in dictionary_text_embeddings:
        return dictionary_text_embeddings[text]
    else :
        return None
    

# function to find cosine similarity between a text and an image 
def cosine_similarity(embedding1, embedding2):
    embedding1 = np.array(embedding1)
    embedding2 = np.array(embedding2)

    dot_product = np.dot(embedding1, embedding2)
    norm_embedding1 = np.linalg.norm(embedding1)
    norm_embedding2 = np.linalg.norm(embedding2)
    
    if norm_embedding1 == 0 or norm_embedding2 == 0:
        return None 
    
    return dot_product / (norm_embedding1 * norm_embedding2)
    

# function to compute Image-Text-Contrastive Loss
def batch_itc_loss(image_embeddings, text_embeddings, temperature=0.1):
    image_embeddings = F.normalize(image_embeddings, dim=1)
    text_embeddings = F.normalize(text_embeddings, dim=1)

    similarity_matrix = torch.matmul(image_embeddings, text_embeddings.T) / temperature

    labels = torch.arange(len(image_embeddings), device=image_embeddings.device)

    loss = F.cross_entropy(similarity_matrix, labels)
    return labels, loss


# function to compute Image-Text-Matching loss
def batch_itm_loss(image_embeddings, text_embeddings, match_labels):
    logits = torch.matmul(image_embeddings, text_embeddings.t())
    probabilities = torch.sigmoid(logits) 
    positive_probs = probabilities[torch.arange(len(match_labels)), match_labels]

    loss = -torch.log(positive_probs + 1e-12).mean()
    return loss


# function to get embeddings()
def get_embeddings(file_path):
    df = pd.read_csv(file_path)
    report_embeddings = []
    image_embeddings = []

    for _, row in df.iterrows():
        report_embedding = extract_historical_text_embeddings(row['findings'])

        for i in range(1, 6):
            image_col = f'image_{i}'
            if image_col in row and pd.notnull(row[image_col]):
                image_embedding = extract_image_embeddings(row[image_col])
                report_embeddings.append(report_embedding)
                image_embeddings.append(image_embedding)

    return report_embeddings, image_embeddings


# function to do training step
def train_step(file, optimizer):
    blip_model.train()
    text_embeddings, image_embeddings = get_embeddings(file_path)
    
    match_labels, itc_loss_value = batch_itc_loss(image_embeddings, text_embeddings)
    itm_loss_value = batch_itm_loss(image_embeddings, text_embeddings, match_labels)
    
    total_loss = itc_loss_value + itm_loss_value
    
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    return total_loss.item()

In [None]:
'''Multilevel Alignment based on BLIP Architecture - Coarse Grained Alignment (Image Text Contrastive and Image Text Matching)'''

# BLIP architecture
processor = BlipProcessor.from_pretrained("Salesforce/blip-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-base")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
blip_model.to(device)
blip_model.eval()


# training loop
num_epochs = 10  
optimizer = torch.optim.Adam(blip_model.parameters(), lr=1e-5)


# using a data loader defined that provides (image, text) pairs
for epoch in range(num_epochs):
    loss = train_step(file_path, optimizer)
    print(f"Epoch {epoch}, Loss: {loss}")

In [None]:
'''Multilevel Alignment based on BLIP Architecture - Fine Grained Alignment'''

In [None]:
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist


output_aligned_features_dir = output_directory
if(os.path.exists(output_aligned_features_dir + "aligned_outputs.pt")):
    print(f"Already at {historical_pt_b}")

batch_size = 32 

final_embeddings = torch.load(output_directory + "/final_embeddings.pt")  # Shape: [7470, 50, 2048]
#patch_features = torch.load(output_directory + 'patch_feats.pt')          # Shape: [7470, 49, 2048]
dictionary_embeddings = torch.load(output_directory + "/dictionary_embeddings.pt")  # Shape: [47, 384]
dictionary_embeddings = torch.tensor(dictionary_embeddings)

# Project dictionary embeddings (V) to match the dimensionality of the image embeddings (2048)
projection_layer = nn.Linear(384, 2048)
V_projected = projection_layer(dictionary_embeddings)  # Shape: [47, 2048]
# V_projected = F.normalize(V_projected, p=2, dim=1)

# Normalize final embeddings (this is already being done)
# I_prime = F.normalize(final_embeddings, p=2, dim=1)

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # weight matrices for query, key, value
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V):
        # Compute QK^T / sqrt(d_k)
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k).float())
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights

    def forward(self, V, I_prime):
        # V:  projected dictionary embeddings
        # I_prime: final_embeddings
        
        Q = self.W_q(V)  # Dictionary embeddings
        K = self.W_k(I_prime)  # Image embeddings 
        V = self.W_v(I_prime)

        Q = Q.view(Q.size(0), Q.size(1), self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(K.size(0), K.size(1), self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(V.size(0), V.size(1), self.num_heads, self.d_k).transpose(1, 2)
       
        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V) # apply scaled dot-product attention

        attn_output = attn_output.transpose(1, 2).contiguous().view(V.size(0), -1, self.d_model) # concatenate heads and apply final linear transformation
        output = self.W_o(attn_output)
        
        return output, attn_weights

# Feed-forward network
class FeedForwardNetwork(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForwardNetwork, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

# KL Divergence Loss (for training)
# def kl_divergence(mu1, logvar1, mu2, logvar2):
#     kl_loss = 0.5 * torch.sum(logvar2 - logvar1 + (torch.exp(logvar1) + (mu1 - mu2)**2) / torch.exp(logvar2) - 1)
#     return kl_loss
def kl_divergence(mu1, logvar1, mu2, logvar2):
    # kl_loss = 0.5 * torch.sum(
    #     logvar2 - logvar1 +
    #     (torch.exp(logvar1) + (mu1 - mu2) ** 2) / torch.exp(logvar2) - 1
    # )
    # def compute_kl_loss(mu1, logvar1, mu2, logvar2):
    # Create normal distributions from the means (mu) and log-variances (logvar)
    normal1 = dist.Normal(mu1, torch.exp(0.5 * logvar1))  # exp(0.5 * logvar) gives standard deviation
    normal2 = dist.Normal(mu2, torch.exp(0.5 * logvar2))
    
    # Compute the KL divergence between the two distributions
    kl_loss = dist.kl.kl_divergence(normal1, normal2).mean()  # Mean over the batch
    return kl_loss

    # return kl_loss
    
# Model setup
d_model = 2048
num_heads = 8
d_ff = 4096  

mha = MultiHeadAttention(d_model, num_heads)
ffn = FeedForwardNetwork(d_model, d_ff)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mha = mha.to(device)
ffn = ffn.to(device)
V_projected = V_projected.to(device)
final_embeddings = final_embeddings.to(device)
#patch_features = patch_features.to(device)

dataset_size = final_embeddings.size(0)  # 7470
aligned_outputs = []

params = list(mha.parameters()) + list(ffn.parameters())  
optimizer = torch.optim.Adam(params, lr=1e-4)  

for i in range(0, dataset_size, batch_size):
    
    batch_final_embeddings = final_embeddings[i:i+batch_size]  # Shape: [batch_size, 50, 2048]
    #batch_patch_features = patch_features[i:i+batch_size]  # Shape: [batch_size, 49, 2048]

    #batch_I_prime = torch.cat((batch_final_embeddings, batch_patch_features), dim=1)  # Shape: [batch_size, 99, 2048]
    batch_I_prime = batch_final_embeddings
    batch_V_repeated = V_projected.unsqueeze(0).repeat(batch_final_embeddings.size(0), 1, 1)  # Shape: [batch_size, 47, 2048]

    batch_I_prime = batch_I_prime.to(device)
    batch_V_repeated = batch_V_repeated.to(device)

    aligned_output, _ = mha(batch_V_repeated, batch_I_prime) #multi-head attention with dictionary embeddings and image embeddings

    aligned_output_ffn = ffn(aligned_output) #feed-forward network

    mu1, logvar1 = torch.randn_like(aligned_output_ffn,requires_grad=True), torch.randn_like(aligned_output_ffn,requires_grad=True)  # Priors for V'
    mu2, logvar2 = torch.randn_like(batch_V_repeated,requires_grad=True), torch.randn_like(batch_V_repeated,requires_grad=True)  # Priors for V_label
    kl_loss = kl_divergence(mu1, logvar1, mu2, logvar2) # KL divergence

    print(f"KL Loss for batch {i // batch_size}: {kl_loss.item()}")

    # If training, accumulate gradients and perform optimization steps
    optimizer.zero_grad()
    kl_loss.backward()
    optimizer.step()
    
    aligned_outputs.append(aligned_output_ffn.cpu()) 

#single tensor
aligned_outputs_tensor = torch.cat(aligned_outputs, dim=0)  # Shape: [7470, 50, 2048]

torch.save(aligned_outputs_tensor, os.path.join(output_aligned_features_dir, "aligned_outputs.pt"))

print("Aligned features saved successfully.")

    


In [None]:
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
from transformers import BertModel, BertTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TextEncoder(nn.Module):
    def __init__(self, bert_model="bert-base-uncased", output_dim=384):
        super(TextEncoder, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.projection = nn.Linear(self.bert.config.hidden_size, output_dim)
        
    def encode_dictionary(self, dictionary):
        """
        Encodes dictionary entries using BERT and combines key-value pairs
        dictionary: Dict with medical terms as keys and list of related terms as values
        Returns: Tensor of shape [num_entries, output_dim]
        """
        encoded_entries = []
        
        for key, values in dictionary.items():
            # Combine key with its values into a single text
            if values:  # If values list is not empty
                text = key + ": " + ", ".join(values)
            else:
                text = key
                
            # Tokenize and encode
            inputs = self.tokenizer(text, 
                                  padding=True, 
                                  truncation=True, 
                                  return_tensors="pt")
            
            with torch.no_grad():
                outputs = self.bert(**inputs)
                # Use [CLS] token embedding
                embedding = outputs.last_hidden_state[:, 0, :]
                
            # Project to desired dimension
            projected = self.projection(embedding)
            encoded_entries.append(projected)
            
        return torch.cat(encoded_entries, dim=0)

# Your original alignment code
output_aligned_features_dir = output_directory
if(os.path.exists(output_aligned_features_dir + "aligned_outputs.pt")):
    print(f"Already at {historical_pt_b}")

batch_size = 32 

# Load the medical dictionary
medical_dict = {
    "pleural": ["hemithorax", "effusion", "pneumothorax", "parenchymal"],
    "lung": ["lungs", "pulmonary", "hilar", "lobe", "consolidation", 
             "atelectasis", "edema", "opacity", "pneumonia"],
    "mediastinal": ["mediastinum", "diaphragm", "hemidiaphragm"],
    "cardiac": ["heart", "cardiomegaly", "cardiomediastinal", "atrium",
                "ventricle", "retrocardiac"],
    "vascular": ["aorta", "venous", "jugular", "aortic", "vasculature", "cabg"],
    "osseous": ["rib", "sternal", "subclavian", "thoracic"],
    "trachea": ["endotrachea"],
    "stomach": [],
    "abdomen": [],
    "tube": ["clips"],
    "spine": ["vertebral", "degenerative"],
    "nodule": ["mass"],
    "chest": ["small", "enlarged", "unchanged", "stable", "silhouette",
              "contours", "size", "focal", "mild", "acute"]
}

# Initialize text encoder and encode dictionary
text_encoder = TextEncoder()
dictionary_embeddings = text_encoder.encode_dictionary(medical_dict)

# final_embeddings = torch.load(output_directory + "/final_embeddings.pt")
final_embeddings = torch.load(output_directory + "/final_embeddings.pt", map_location=torch.device('cpu'))# Shape: [7470, 50, 2048]
dictionary_embeddings = dictionary_embeddings.to(torch.float32)  # Convert to float32

# Project dictionary embeddings (V) to match the dimensionality of the image embeddings (2048)
projection_layer = nn.Linear(384, 6144)
V_projected = projection_layer(dictionary_embeddings)  # Shape: [47, 2048]

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # weight matrices for query, key, value
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V):
        # Compute QK^T / sqrt(d_k)
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k).float())
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights

    def forward(self, V, I_prime):
        # V:  projected dictionary embeddings
        # I_prime: final_embeddings
        # print("Shape of I_prime:", I_prime.shape)
        I_prime = I_prime.unsqueeze(1) 
        Q = self.W_q(V)  # Dictionary embeddings
        K = self.W_k(I_prime)  # Image embeddings 
        V = self.W_v(I_prime)
        batch_size, seq_len, d_model = K.size()
        Q = Q.view(Q.size(0), Q.size(1), self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(K.size(0), K.size(1), self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(V.size(0), V.size(1), self.num_heads, self.d_k).transpose(1, 2)
       
        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V)
        attn_output = attn_output.transpose(1, 2).contiguous().view(V.size(0), -1, self.d_model)
        output = self.W_o(attn_output)
        
        return output, attn_weights

class FeedForwardNetwork(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForwardNetwork, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

def kl_divergence(mu1, logvar1, mu2, logvar2):
    normal1 = dist.Normal(mu1, torch.exp(0.5 * logvar1))
    normal2 = dist.Normal(mu2, torch.exp(0.5 * logvar2))
    kl_loss = dist.kl.kl_divergence(normal1, normal2).mean()
    return kl_loss

class Piror(nn.Module):
    """Fully connected layer to convert encodings to mean and variance"""
    def __init__(self, input_dim=6144, hidden_dim=512):
        super(Piror, self).__init__()
        self.fc_mu = nn.Linear(input_dim, hidden_dim)
        self.fc_var = nn.Linear(input_dim, hidden_dim)
        
    def forward(self, x):
        # Generate mean and log variance
        mu = self.fc_mu(x)
        logvar = self.fc_var(x)
        return mu, logvar

# def compute_kl_loss(mu1, logvar1, mu2, logvar2):
#     """
#     Compute KL divergence between two normal distributions N1(mu1, sigma1) and N2(mu2, sigma2)
#     as per equation (12) in the paper
#     """
#     # Convert log variance to variance
#     var1 = torch.exp(logvar1)
#     var2 = torch.exp(logvar2)
    
#     # Compute KL divergence according to equation (12)
#     kl_div = 0.5 * torch.sum(
#         logvar2 - logvar1 + 
#         (var1 + (mu1 - mu2).pow(2)) / var2 - 1
#     )
    
#     return kl_div    
# Model setup
d_model = 6144
num_heads = 8
d_ff = 4096  

mha = MultiHeadAttention(d_model, num_heads)
ffn = FeedForwardNetwork(d_model, d_ff)
piror = Piror(d_model)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mha = mha.to(device)
ffn = ffn.to(device)
V_projected = V_projected.to(device)
final_embeddings = final_embeddings.to(device)
piror = piror.to(device)
dataset_size = final_embeddings.size(0)
aligned_outputs = []

params = list(mha.parameters()) + list(ffn.parameters())  
optimizer = torch.optim.Adam(params, lr=1e-4)  

for i in range(0, dataset_size, batch_size):
    optimizer.zero_grad()
    batch_final_embeddings = final_embeddings[i:i+batch_size]
    batch_I_prime = batch_final_embeddings
    batch_V_repeated = V_projected.unsqueeze(0).repeat(batch_final_embeddings.size(0), 1, 1)

    batch_I_prime = batch_I_prime.to(device)
    batch_V_repeated = batch_V_repeated.to(device)

    aligned_output, _ = mha(batch_V_repeated, batch_I_prime)
    aligned_output_ffn = ffn(aligned_output)
    V_prime = aligned_output_ffn
    V_label = batch_V_repeated 
    mu1, logvar1 = piror(V_prime)
    mu2, logvar2 = piror(V_label)
    
    # Compute KL divergence loss according to equation (12)
    # kl_loss = compute_kl_loss(mu1, logvar1, mu2, logvar2)

    # print(f"KL Loss for batch {i // batch_size}: {kl_loss.item()}")

    #     mu1, logvar1 = torch.randn_like(aligned_output_ffn,requires_grad=True), torch.randn_like(aligned_output_ffn,requires_grad=True)
    # mu2, logvar2 = torch.randn_like(batch_V_repeated,requires_grad=True), torch.randn_like(batch_V_repeated,requires_grad=True)
    kl_loss = kl_divergence(mu1, logvar1, mu2, logvar2)

    print(f"KL Loss for batch {i // batch_size}: {kl_loss.item()}")

    
    kl_loss.backward(retain_graph=True)
    optimizer.step()
    
    aligned_outputs.append(aligned_output_ffn.cpu())

aligned_outputs_tensor = torch.cat(aligned_outputs, dim=0)
torch.save(aligned_outputs_tensor, os.path.join(output_aligned_features_dir, "aligned_outputs.pt"))
# Saving each model's state_dict separately
torch.save(mha.state_dict(), "mha.pth")
torch.save(ffn.state_dict(), "ffn.pth")
torch.save(piror.state_dict(), "piror.pth")

# Or saving all models together
torch.save({
    'mha': mha.state_dict(),
    'ffn': ffn.state_dict(),
    'piror': piror.state_dict(),
}, "model_parameters.pth")

print("Aligned features saved successfully.")

### **Report Generator**

### **Complete Model**

## **Training**

### **Training**

## **Testing**

### **Testing**