## **Medical Report Summarisation using Medical Knowledge**

### **References**

**Main Reference**
- Radiology report generation with medical knowledge and multilevel image-report alignment: A new method and its verification
https://www.sciencedirect.com/science/article/pii/S0933365723002282



## **Data Collection**

### **Collect Datasets**

In [15]:
'''Libraries Installation and Import'''

# installling necessary libraries
!pip -q install --user requests numpy pandas matplotlib tqdm Pillow opencv-python nltk pyspellchecker torch torchvision torchaudio transformers scikit-learn sentence-transformers

# importing required libraries
import os
import re
import csv
import requests
import tarfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from tqdm import tqdm
from PIL import Image
import cv2

import nltk
from nltk.corpus import stopwords
from spellchecker import SpellChecker
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as dist
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision import models, transforms

import transformers
from transformers import BertTokenizer, BertModel
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AutoTokenizer
from transformers import DistilBertModel, DistilBertTokenizer
from sentence_transformers import SentenceTransformer

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optims
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import BertTokenizer, BertModel
from sentence_transformers import SentenceTransformer
import torch.nn.functional as F
import torch.distributions as dist
import numpy as np
from tqdm.notebook import tqdm
import pandas as pd
from transformers import DistilBertModel, DistilBertTokenizer
from collections import Counter
import os
import re
import pandas as pd
from tqdm import tqdm
import nltk
from spellchecker import SpellChecker
from sklearn.model_selection import train_test_split
from nltk.corpus import stopwords


In [2]:
'''Setting Paths'''

# project directory
# from google.colab import drive
# drive.mount('/content/drive')
# project_directory = '/content/drive/Othercomputers/My Laptop/CS550_ASMT_MRSMK/datasets'
# project_directory = '/content/drive/MyDrive/Academics/CS550 Machine Learning/CS550 ASMT MRSMK/datasets'

project_directory = "./datasets"
dataset = 'iu_xray/'
iu_xray_dataset = os.path.join(project_directory, dataset)


# input directory
input_directory = os.path.join(iu_xray_dataset, "input")

images_dir = os.path.join(input_directory, "images")
reports_dir = os.path.join(input_directory, "reports")
iu_xray_images = images_dir
iu_xray_reports = os.path.join(reports_dir, 'ecgen-radiology')


# output directory 
output_directory = os.path.join(iu_xray_dataset, "output")
os.makedirs(output_directory, exist_ok=True)

In [3]:
'''Setup - Generalized'''

# setup to download the IU X-Ray Dataset
images_url = "https://openi.nlm.nih.gov/imgs/collections/NLMCXR_png.tgz"
reports_url = "https://openi.nlm.nih.gov/imgs/collections/NLMCXR_reports.tgz"


# function to check the file size of a given URL
def get_file_size(url):
    response = requests.head(url)
    size_in_bytes = int(response.headers.get('Content-Length', 0))
    size_in_mb = size_in_bytes / (1024 * 1024)
    return size_in_mb


# function to download and extract from a given url to a given directory
def download_and_extract(url, save_dir):
    file_name = url.split('/')[-1]
    file_path = os.path.join(save_dir, file_name)

    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('Content-Length', 0))
    downloaded_size = 0

    with open(file_path, 'wb') as file:
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                file.write(chunk)
                downloaded_size += len(chunk)
                percent_complete = (downloaded_size / total_size) * 100
                print(f"Downloaded {downloaded_size / (1024*1024):.2f} MB out of {total_size / (1024*1024):.2f} MB: {percent_complete:.2f}% complete", end="\r")

    print("\nDownload complete!")

    with tarfile.open(file_path, 'r:gz') as tar:
        members = tar.getmembers()
        total_files = len(members)

        for idx, member in enumerate(members, start=1):
            tar.extract(member, path=save_dir)
            print(f"Extracting File {idx} out of {total_files}: {member.name}", end="\r")

    os.remove(file_path)


# downloading  IU X-Ray dataset
if not os.path.exists(images_dir):
    images_size = get_file_size(images_url)
    print(f"Downloading {images_url} to: {images_dir} ({images_size:.2f} MB)")
    os.makedirs(images_dir, exist_ok=True)
    download_and_extract(images_url, images_dir)
    print(f"Downloaded {images_url} to: {images_dir}")
else:
    print(f"{images_url} already exists at: {images_dir}")

if not os.path.exists(reports_dir):
    reports_size = get_file_size(reports_url)
    print(f"Downloading {reports_url} to: {reports_dir} ({reports_size:.2f} MB)")
    os.makedirs(reports_dir, exist_ok=True)
    download_and_extract(reports_url, reports_dir)
    print(f"Downloaded {reports_url} to: {reports_dir}")
else:
    print(f"{reports_url} already exists at: {reports_dir}")

https://openi.nlm.nih.gov/imgs/collections/NLMCXR_png.tgz already exists at: ./datasets/iu_xray/input/images
https://openi.nlm.nih.gov/imgs/collections/NLMCXR_reports.tgz already exists at: ./datasets/iu_xray/input/reports


In [4]:
'''Exploring the IU X-Ray Dataset Contents'''

# displaying directory and subdirectory contents
print("\nPath: ", iu_xray_images)
print(f"Directory Contents: {len(os.listdir(iu_xray_images))} Images")

print("\nPath: ", iu_xray_reports)
print(f"Directory Contents: {len(os.listdir(iu_xray_reports))} Reports")


Path:  ./datasets/iu_xray/input/images
Directory Contents: 7471 Images

Path:  ./datasets/iu_xray/input/reports/ecgen-radiology
Directory Contents: 3955 Reports


In [None]:
'''Processing Textual Data from each .xml Report File and Storing it in a .csv File'''

# function to iterate through all .xml report files and storing them in a dataframe
def save_images_df():
    data = []
    cnt = 0
    for file in os.listdir(iu_xray_reports):
        if file.endswith(".xml"):
            cnt += 1
            print(f"Processing .xml File {cnt} out of {len(os.listdir(iu_xray_reports))}: {file}", end="\r")

            file_path = os.path.join(iu_xray_reports, file)
            try:
                tree = ET.parse(file_path)
                root = tree.getroot()

                pmc_id = root.find('.//pmcId').attrib.get('id')

                comparison = indication = findings = impression = None

                for abstract in root.findall('.//AbstractText'):
                    if abstract.attrib.get('Label') == 'COMPARISON':
                        comparison = abstract.text
                    elif abstract.attrib.get('Label') == 'INDICATION':
                        indication = abstract.text
                    elif abstract.attrib.get('Label') == 'FINDINGS':
                        findings = abstract.text
                    elif abstract.attrib.get('Label') == 'IMPRESSION':
                        impression = abstract.text

                for parent_image in root.findall('parentImage'):
                    image_file = parent_image.attrib['id'] + ".png"
                    image_path = os.path.join(iu_xray_images, image_file)
                    image = cv2.imread(image_path)

                    if image is not None:
                        height, width, channels = image.shape
                        caption = parent_image.find('caption').text if parent_image.find('caption') is not None else None
                        data.append([pmc_id, image_file, caption, comparison, indication, findings, impression, height, width])
                    else:
                        print(f"Warning: Unable to read image {image_path}")

            except Exception as e:
                print(f"Error processing file {file}: {e}")

    return data


# creating a dataframe and saving it as .csv
iu_xray_images_df_path = os.path.join(output_directory, 'iu_xray_images_df.csv')
if os.path.exists(iu_xray_images_df_path):
    data = save_images_df()
    columns = ['pmc_id', 'image_filename', 'caption', 'comparison', 'indication', 'findings', 'impression', 'height', 'width']
    iu_xray_images_df = pd.DataFrame(data, columns=columns)
    iu_xray_images_df.to_csv(iu_xray_images_df_path, index=False)
    print(f"Dataframe saved to {iu_xray_images_df_path}")
else:
    print(f"Dataframe already exists at {iu_xray_images_df_path}")
    iu_xray_images_df = pd.read_csv(iu_xray_images_df_path)


# displaying the stored dataframe
print("\n\nDataframe Shape:", iu_xray_images_df.shape)

print("\n\nDataframe Information:\n")
display(iu_xray_images_df.info())

print("\n\nDisplaying Dataframe:\n")
display(iu_xray_images_df.head())

In [None]:
'''Processing Textual Data from each .xml Report File and Storing it in a .csv File'''

# function to iterate through all .xml report files and storing them in a dataframe
def save_reports_df():
    data = []
    cnt = 0
    for file in os.listdir(iu_xray_reports):
        if file.endswith(".xml"):
            cnt += 1
            print(f"Processing .xml File {cnt} out of {len(os.listdir(iu_xray_reports))}: {file}", end="\r")

            file_path = os.path.join(iu_xray_reports, file)
            try:
                tree = ET.parse(file_path)
                root = tree.getroot()

                pmc_id = root.find('.//pmcId').attrib.get('id')

                comparison = indication = findings = impression = None

                for abstract in root.findall('.//AbstractText'):
                    if abstract.attrib.get('Label') == 'COMPARISON':
                        comparison = abstract.text
                    elif abstract.attrib.get('Label') == 'INDICATION':
                        indication = abstract.text
                    elif abstract.attrib.get('Label') == 'FINDINGS':
                        findings = abstract.text
                    elif abstract.attrib.get('Label') == 'IMPRESSION':
                        impression = abstract.text

                report_data = {
                    'pmc_id': pmc_id,
                    'findings': findings,
                    'impression': impression,
                    'comparison': comparison,
                    'indication': indication,
                }

                parent_images = root.findall('parentImage')
                report_data['image_count'] = len(parent_images)

                for i, parent_image in enumerate(parent_images, start=1):
                    image_file = parent_image.attrib['id'] + ".jpg"
                    caption = parent_image.find('caption').text if parent_image.find('caption') is not None else None
                    report_data[f'image_{i}'] = f"{image_file}: {caption}" if caption else image_file

                data.append(report_data)

            except Exception as e:
                print(f"Error processing file {file}: {e}")

    return data


# creating a dataframe and saving it as .csv
iu_xray_reports_df_path = os.path.join(output_directory, 'iu_xray_reports_df.csv')
if not os.path.exists(iu_xray_reports_df_path):
    data = save_reports_df()
    iu_xray_reports_df = pd.DataFrame(data)
    iu_xray_reports_df.to_csv(iu_xray_reports_df_path, index=False)
    print(f"Dataframe saved to {iu_xray_reports_df_path}")
else:
    print(f"Dataframe already exists at {iu_xray_reports_df_path}")
    iu_xray_reports_df = pd.read_csv(iu_xray_reports_df_path)


# displaying the stored dataframe
print("\n\nDataframe Shape:", iu_xray_reports_df.shape)

print("\n\nDataframe Information:\n")
display(iu_xray_reports_df.info())

print("\n\nDisplaying Dataframe:\n")
display(iu_xray_reports_df.head())

In [None]:
'''Displaying the Number of Images per Report'''

# displaying the distribution of number of images per report
reports_count = iu_xray_reports_df['image_count'].value_counts().rename_axis('images_qty').reset_index(name='reports_count')
print("\n\nNumber of Images per Report:\n")
display(reports_count)

In [None]:
'''Checking for Duplicates'''

# Check for duplicate values in the 'pmc_id' column
duplicates_in_pmc_id = iu_xray_reports_df['pmc_id'].duplicated()
num_duplicates = duplicates_in_pmc_id.sum()

# Display the duplicated rows
duplicated_rows = iu_xray_reports_df[duplicates_in_pmc_id]
print(f"Number of duplicates in 'pmc_id' column: {num_duplicates}")
print("Duplicated rows in 'pmc_id' column:")
print(duplicated_rows)

## **Data Preprocessing**

### **Preprocess Images**

In [5]:
'''Preprocessing Images - Resizing, Tensor Conversion and Normalization'''

# function to preprocess and save images
def preprocess_images(input_dir, output_dir):
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    os.makedirs(output_dir, exist_ok=True)

    cnt = 0
    for filename in os.listdir(input_dir):
        if filename.endswith('.png'):
            cnt += 1
            print(f"Preprocessing File {cnt} out of {len(os.listdir(input_dir))}: {filename}", end="\r")

            image_path = os.path.join(input_dir, filename)
            image = Image.open(image_path).convert('RGB')
            processed_image = preprocess(image)

            processed_image_path = os.path.join(output_dir, filename)

            processed_image_pil = transforms.ToPILImage()(processed_image)
            processed_image_pil.save(processed_image_path)


# preprocessing images
iu_xray_images_preprocessed = os.path.join(output_directory, 'images_preprocessed')
if not os.path.exists(iu_xray_images_preprocessed):
    print(f"Preprocessing Images to: {iu_xray_images_preprocessed}")
    preprocess_images(iu_xray_images, iu_xray_images_preprocessed)
    print(f"Preprocessed Images saved to: {iu_xray_images_preprocessed}")
else:
    print(f"Preprocessed Images already exist at: {iu_xray_images_preprocessed}")

Preprocessed Images already exist at: ./datasets/iu_xray/output/images_preprocessed


### **Preprocess Text**

In [None]:
'''Preprocessing Text - Lowercasing, Decontracting, Punctuation Removal, Number Removal, Two-Letter Word Removal, Stop Word Removal, Spell Checking, Extra Space Removal'''

# download nltk resources and initialize spell checker
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('stopwords')
spell = SpellChecker()


# function to convert text to lowercase
def lowercase(text):
    return text.lower() if isinstance(text, str) else text


# function to decontract words
def decontracted(text):
    if not isinstance(text, str):
        return text
    contractions = {
        "won't": "will not", "can't": "can not", "couldn't": "could not",
        "shouldn't": "should not", "wouldn't": "would not", "n't": " not",
        "'re": " are", "'s": " is", "'d": " would", "'ll": " will",
        "'t": " not", "'ve": " have", "'m": " am"
    }
    for contraction, full_form in contractions.items():
        text = text.replace(contraction, full_form)
    return text


# function to remove punctuations
def rem_punctuations(text):
    return re.sub(r'[^\w\s]', ' ', text) if isinstance(text, str) else text


# function to remove numbers
def rem_numbers(text):
    if not isinstance(text, str):
        return text
    text = re.sub(r'[xX]{2,}', '', text)
    return re.sub(r'\d+', '', text)


# function to remove two-letter words except "no" and "ct"
def rem_two_letter_words(text):
    if not isinstance(text, str):
        return text
    return ' '.join(word for word in text.split() if len(word) > 2 or word in ["no", "ct"])


# function to remove stop words
def rem_stop_words(text):
    if not isinstance(text, str):
        return text
    stop_words = set(stopwords.words('english'))
    return ' '.join(word for word in text.split() if word not in stop_words)


# function to correct spelling
def correct_spelling(text):
    if not isinstance(text, str):
        return text
    corrected = []
    for word in text.split():
        corrected_word = list(spell.candidates(word))[0] if spell.candidates(word) else word
        corrected.append(corrected_word)
    return ' '.join(corrected)


# function to remove extra spaces
def rem_extra_spaces(text):
    return ' '.join(text.split()) if isinstance(text, str) else text


# function to handle full stops
def handle_fullstops(text):
    if not isinstance(text, str):
        return text
    text = re.sub(r'\.\.+', '.', text) 
    return re.sub(r'\.', ' . ', text) 


# function to remove apostrophes
def rem_apostrophes(text):
    return re.sub("'", '', text) if isinstance(text, str) else text


# function to preprocess text
def preprocess_text(data):
    preprocessed = []
    for sentence in tqdm(data.values):
        sentence = str(sentence)
        sentence = lowercase(sentence)
        sentence = decontracted(sentence)
        sentence = rem_punctuations(sentence)
        sentence = rem_numbers(sentence)
        sentence = rem_two_letter_words(sentence)
        sentence = rem_stop_words(sentence)
        sentence = correct_spelling(sentence)
        sentence = rem_extra_spaces(sentence)
        sentence = handle_fullstops(sentence)
        sentence = rem_apostrophes(sentence)
        
        preprocessed.append(sentence)

    return preprocessed

In [None]:
'''Preprocessing Text - Lowercasing, Decontracting, Punctuation Removal, Number Removal, Two-Letter Word Removal, Stop Word Removal, Spell Checking, Extra Space Removal'''

# load your DataFrame (replace with actual path)
iu_xray_reports_df = os.path.join(output_directory, 'iu_xray_reports_df.csv')
df = pd.read_csv(iu_xray_reports_df)


# apply preprocessing on specific columns if they exist
preprocess_columns = ['findings']
for column in preprocess_columns:
    if column in df.columns:
        print(f"Preprocessing Column: {column}")
        df[column] = df[column].fillna('none').astype(str)
        df[column] = preprocess_text(df[column])
        output_path = os.path.join(output_directory, f'preprocessed_{column}.csv')
        df.to_csv(output_path, index=False)
        print(f"Saved preprocessed '{column}' column to: {output_path}")
        
        
# split into Train/Validation/Test (70%/10%/20%)
output_path = os.path.join(output_directory, f'preprocessed_findings.csv')
df = pd.read_csv(output_path)        
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=(2/3), random_state=42)


# save the splits
train_path = os.path.join(output_directory, 'train_data.csv')
train_df.to_csv(train_path, index=False)
display(train_df.head())
print(f"Train data saved to: {train_path}")

val_path = os.path.join(output_directory, 'val_data.csv')
val_df.to_csv(val_path, index=False)
display(val_df.head())
print(f"Validation data saved to: {val_path}")

test_path = os.path.join(output_directory, 'test_data.csv')
test_df.to_csv(test_path, index=False)
display(test_df.head())
print(f"Test data saved to: {test_path}")

In [None]:
'''Indexing the Training Data'''

# indexing the train data
df = pd.read_csv(train_path)     
df = df[['pmc_id', 'findings', 'image_1', 'image_2']]
df['index'] = range(1, len(df) + 1)


# display the modified DataFrame to check the output
print("Shape of the DataFrame:", df.shape)
display(df.head(20))
df.to_csv(train_path, index=False)
print(f"Train data saved to: {train_path}")

In [None]:
'''Creating Filtered Dataframes'''

# filtering the data
filtered_train_path = os.path.join(output_directory, 'filtered_train_data.csv')
df = pd.read_csv(train_path)  
print("Shape of the DataFrame Before:", df.shape)

filtered_df = df.dropna(subset=['image_1', 'image_2'], how='all')
print("Shape of the DataFrame After:", filtered_df.shape)

filtered_df.to_csv(filtered_train_path, index=False)
print(f"Dataframe saved to: {filtered_train_path}")


# filtering the data
filtered_val_path = os.path.join(output_directory, 'filtered_val_data.csv')
df = pd.read_csv(val_path)  
print("Shape of the DataFrame Before:", df.shape)

filtered_df = df.dropna(subset=['image_1', 'image_2'], how='all')
print("Shape of the DataFrame After:", filtered_df.shape)

filtered_df.to_csv(filtered_val_path, index=False)
print(f"Dataframe saved to: {filtered_val_path}")


# filtering the data
filtered_test_path = os.path.join(output_directory, 'filtered_test_data.csv')
df = pd.read_csv(test_path)  
print("Shape of the DataFrame Before:", df.shape)

filtered_df = df.dropna(subset=['image_1', 'image_2'], how='all')
print("Shape of the DataFrame After:", filtered_df.shape)

filtered_df.to_csv(filtered_test_path, index=False)
print(f"Dataframe saved to: {filtered_test_path}")

In [None]:
'''Preprocessing Text - Lowercasing, Decontracting, Punctuation Removal, Number Removal, Two-Letter Word Removal, Stop Word Removal, Spell Checking, Extra Space Removal'''

# download nltk resources and initialize spell checker
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('stopwords')
spell = SpellChecker()


# function to convert text to lowercase
def lowercase(text):
    return text.lower() if isinstance(text, str) else text


# function to decontract words
def decontracted(text):
    if not isinstance(text, str):
        return text
    contractions = {
        "won't": "will not", "can't": "can not", "couldn't": "could not",
        "shouldn't": "should not", "wouldn't": "would not", "n't": " not",
        "'re": " are", "'s": " is", "'d": " would", "'ll": " will",
        "'t": " not", "'ve": " have", "'m": " am"
    }
    for contraction, full_form in contractions.items():
        text = text.replace(contraction, full_form)
    return text


# function to remove punctuations
def rem_punctuations(text):
    return re.sub(r'[^\w\s]', ' ', text) if isinstance(text, str) else text


# function to remove numbers
def rem_numbers(text):
    return re.sub(r'\d+', ' ', text) if isinstance(text, str) else text


# function to remove two-letter words except "no" and "ct"
def rem_two_letter_words(text):
    if not isinstance(text, str):
        return text
    return ' '.join(word for word in text.split() if len(word) > 2 or word in ["no", "ct"])


# function to remove stop words
def rem_stop_words(text):
    if not isinstance(text, str):
        return text
    stop_words = set(stopwords.words('english'))
    return ' '.join(word for word in text.split() if word not in stop_words)


# function to correct spelling
def correct_spelling(text):
    if not isinstance(text, str):
        return text
    corrected = []
    for word in text.split():
        corrected_word = list(spell.candidates(word))[0] if spell.candidates(word) else word
        corrected.append(corrected_word)
    return ' '.join(corrected)


# function to remove extra spaces
def rem_extra_spaces(text):
    return ' '.join(text.split()) if isinstance(text, str) else text


# function to preprocess text
def preprocess_text(data):
    preprocessed = []
    for sentence in tqdm(data.values):
        sentence = str(sentence)
        sentence = lowercase(sentence)
        sentence = decontracted(sentence)
        sentence = rem_punctuations(sentence)
        sentence = rem_numbers(sentence)
        sentence = rem_two_letter_words(sentence)
        sentence = rem_stop_words(sentence)
        sentence = correct_spelling(sentence)
        sentence = rem_extra_spaces(sentence)
        
        preprocessed.append(sentence)

    return preprocessed


# path to the preprocessed dataframe
iu_xray_reports_preprocessed_df_path = os.path.join(output_directory, 'iu_xray_reports_preprocessed_sorted_df.csv')
# report_data_path = os.path.join(output_directory,'iu_xray_reports_preprocessed_sorted_df.csv')
iu_xray_reports_preprocessed_df = iu_xray_reports_df.copy()


# preprocessing text columns in the dataframe
if os.path.exists(iu_xray_reports_preprocessed_df_path):
    print(f"Preprocessing Text of DataFrame {iu_xray_reports_df_path} to: {iu_xray_reports_preprocessed_df_path}")
    
    preprocess_caption = True
    preprocess_comparison = True
    preprocess_indication = True
    preprocess_findings = True
    preprocess_impression = True
    
    if preprocess_caption and 'caption' in iu_xray_reports_preprocessed_df.columns:
        print("Preprocessing Column: caption")
        iu_xray_reports_preprocessed_df['caption'] = iu_xray_reports_preprocessed_df['caption'].fillna('unknown').astype(str)
        iu_xray_reports_preprocessed_df['caption'] = preprocess_text(iu_xray_reports_preprocessed_df['caption'])
        iu_xray_reports_preprocessed_df.to_csv(iu_xray_reports_preprocessed_df_path, index=False)
        print(f"Saved preprocessed 'caption' column to: {iu_xray_reports_preprocessed_df_path}")
    
    if preprocess_comparison and 'comparison' in iu_xray_reports_preprocessed_df.columns:
        print("Preprocessing Column: comparison")
        iu_xray_reports_preprocessed_df['comparison'] = iu_xray_reports_preprocessed_df['comparison'].fillna('none').astype(str)
        iu_xray_reports_preprocessed_df['comparison'] = preprocess_text(iu_xray_reports_preprocessed_df['comparison'])
        iu_xray_reports_preprocessed_df.to_csv(iu_xray_reports_preprocessed_df_path, index=False)
        print(f"Saved preprocessed 'comparison' column to: {iu_xray_reports_preprocessed_df_path}")
    
    if preprocess_indication and 'indication' in iu_xray_reports_preprocessed_df.columns:
        print("Preprocessing Column: indication")
        iu_xray_reports_preprocessed_df['indication'] = iu_xray_reports_preprocessed_df['indication'].fillna('none').astype(str)
        iu_xray_reports_preprocessed_df['indication'] = preprocess_text(iu_xray_reports_preprocessed_df['indication'])
        iu_xray_reports_preprocessed_df.to_csv(iu_xray_reports_preprocessed_df_path, index=False)
        print(f"Saved preprocessed 'indication' column to: {iu_xray_reports_preprocessed_df_path}")
    
    if preprocess_findings and 'findings' in iu_xray_reports_preprocessed_df.columns:
        print("Preprocessing Column: findings")
        iu_xray_reports_preprocessed_df['findings'] = iu_xray_reports_preprocessed_df['findings'].fillna('none').astype(str)
        iu_xray_reports_preprocessed_df['findings'] = preprocess_text(iu_xray_reports_preprocessed_df['findings'])
        iu_xray_reports_preprocessed_df.to_csv(iu_xray_reports_preprocessed_df_path, index=False)
        print(f"Saved preprocessed 'findings' column to: {iu_xray_reports_preprocessed_df_path}")
    
    if preprocess_impression and 'impression' in iu_xray_reports_preprocessed_df.columns:
        print("Preprocessing Column: impression")
        iu_xray_reports_preprocessed_df['impression'] = iu_xray_reports_preprocessed_df['impression'].fillna('none').astype(str)
        iu_xray_reports_preprocessed_df['impression'] = preprocess_text(iu_xray_reports_preprocessed_df['impression'])
        iu_xray_reports_preprocessed_df.to_csv(iu_xray_reports_preprocessed_df_path, index=False)
        print(f"Saved preprocessed 'impression' column to: {iu_xray_reports_preprocessed_df_path}")
else:
    print(f"Preprocessed Text of DataFrame {iu_xray_reports_df_path} already exists at: {iu_xray_reports_preprocessed_df_path}")
    

# displaying the preprocessed dataframe
iu_xray_reports_preprocessed_df = pd.read_csv(iu_xray_reports_preprocessed_df_path)
display(iu_xray_reports_preprocessed_df.head())

In [None]:
import os
import re
import pandas as pd
from tqdm import tqdm
import nltk
from spellchecker import SpellChecker
from sklearn.model_selection import train_test_split
from nltk.corpus import stopwords

# Download necessary NLTK resources
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('stopwords')

# Initialize spell checker
spell = SpellChecker()

# Combined preprocessing functions
def lowercase(text):
    """Convert text to lowercase."""
    return text.lower() if isinstance(text, str) else text

def decontracted(text):
    """Decontract phrases in the text."""
    if not isinstance(text, str):
        return text
    contractions = {
        "won't": "will not", "can't": "can not", "couldn't": "could not",
        "shouldn't": "should not", "wouldn't": "would not", "n't": " not",
        "'re": " are", "'s": " is", "'d": " would", "'ll": " will",
        "'t": " not", "'ve": " have", "'m": " am"
    }
    for contraction, full_form in contractions.items():
        text = text.replace(contraction, full_form)
    return text

def rem_punctuations(text):
    """Remove punctuations except for full stops."""
    return re.sub(r'[^\w\s.]', '', text) if isinstance(text, str) else text

def rem_numbers(text):
    """Remove numbers and irrelevant text like 'XXXX'."""
    if not isinstance(text, str):
        return text
    text = re.sub(r'[xX]{2,}', '', text)  # Removes sequences like 'XXXX'
    return re.sub(r'\d+', '', text)

def rem_two_letter_words(text):
    """Remove words with fewer than 2 characters except 'no' and 'ct'."""
    if not isinstance(text, str):
        return text
    return ' '.join(word for word in text.split() if len(word) > 2 or word in ["no", "ct"])

def rem_stop_words(text):
    """Remove stop words."""
    if not isinstance(text, str):
        return text
    stop_words = set(stopwords.words('english'))
    return ' '.join(word for word in text.split() if word not in stop_words)

def correct_spelling(text):
    """Correct spelling using a spell checker."""
    if not isinstance(text, str):
        return text
    corrected = []
    for word in text.split():
        corrected_word = list(spell.candidates(word))[0] if spell.candidates(word) else word
        corrected.append(corrected_word)
    return ' '.join(corrected)

def rem_extra_spaces(text):
    """Remove extra spaces."""
    return ' '.join(text.split()) if isinstance(text, str) else text

def handle_fullstops(text):
    """Handle full stops, spacing around them, and remove multiple consecutive stops."""
    if not isinstance(text, str):
        return text
    text = re.sub(r'\.\.+', '.', text)  # Convert multiple full stops to single
    return re.sub(r'\.', ' . ', text)  # Add space around full stops

def rem_apostrophes(text):
    """Remove apostrophes."""
    return re.sub("'", '', text) if isinstance(text, str) else text

# Combined text preprocessing function
def preprocess_text(data):
    """Apply combined preprocessing steps."""
    preprocessed = []
    for sentence in tqdm(data.values):
        sentence = str(sentence)
        sentence = lowercase(sentence)
        sentence = decontracted(sentence)
        sentence = rem_punctuations(sentence)
        sentence = rem_numbers(sentence)
        sentence = rem_two_letter_words(sentence)
        sentence = rem_stop_words(sentence)
        sentence = correct_spelling(sentence)
        sentence = rem_apostrophes(sentence)
        sentence = handle_fullstops(sentence)
        sentence = rem_extra_spaces(sentence)
        
        preprocessed.append(sentence)

    return preprocessed

# Load your DataFrame (replace with actual path)
input_path = os.path.join(output_directory, 'iu_xray_reports_sorted_df.csv')
# output_directory = os.path.join(output_directory, 'iu_xray_reports_sorted_preprocessed_df.csv')
df = pd.read_csv(input_path)

# Select only the columns 'pmc_id', 'findings', 'image_1', and 'image_2'

# Apply preprocessing on specific columns if they exist
preprocess_columns = ['findings']
for column in preprocess_columns:
    if column in df.columns:
        print(f"Preprocessing Column: {column}")
        df[column] = df[column].fillna('none').astype(str)
        df[column] = preprocess_text(df[column])
        output_path = os.path.join(output_directory, f'preprocessed_{column}.csv')
        df.to_csv(output_path, index=False)
        print(f"Saved preprocessed '{column}' column to: {output_path}")

# Split into Train/Validation/Test (70%/10%/20%)
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=(2/3), random_state=42)

# Save the splits
train_path = os.path.join(output_directory, 'train_data.csv')
val_path = os.path.join(output_directory, 'val_data.csv')
test_path = os.path.join(output_directory, 'test_data.csv')

train_df.to_csv(train_path, index=False)
val_df.to_csv(val_path, index=False)
test_df.to_csv(test_path, index=False)

print(f"Train data saved to: {train_path}")
print(f"Validation data saved to: {val_path}")
print(f"Test data saved to: {test_path}")


In [None]:
df = pd.read_csv("./datasets/iu_xray/output/test_data.csv")     
df = df[['pmc_id', 'findings', 'image_1', 'image_2']]

# Display the modified DataFrame to check the output

df['index'] = range(1, len(df) + 1)
print("Shape of the DataFrame:", df.shape)
print(df.head())
df.to_csv(test_path, index=False)

print(f"Train data saved to: {test_path}")


In [None]:
filtered_train_path = os.path.join(output_directory, 'filtered_train_data.csv')
filtered_val_path = os.path.join(output_directory, 'filtered_val_data.csv')
filtered_test_path = os.path.join(output_directory, 'filtered_test_data.csv')



df = pd.read_csv("./datasets/iu_xray/output/val_data.csv")  
print("Shape of the DataFrame:", df.shape)
filtered_df = df.dropna(subset=['image_1', 'image_2'], how='all')
print("Shape of the DataFrame:", filtered_df.shape)
filtered_df.to_csv(filtered_val_path, index=False)

print(f"Train data saved to: {filtered_val_path}")

### **Create Data Loaders**

In [6]:
'''Image Data Loaders to Supply Dataset to Model in Batches'''

# classes in dataset
class CustomImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image


# function to load image data with transformation and batching
def load_preprocessed_images(image_dir, batch_size=32):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = CustomImageDataset(image_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataloader

In [None]:
'''Text Data Loaders to Supply Dataset to Model in Batches'''

# classes in dataset
class CustomTextDataset(Dataset):
    def __init__(self, text_list, tokenizer, max_length=512):
        self.text_list = text_list
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.text_list)

    def __getitem__(self, idx):
        text = self.text_list[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {key: tensor.squeeze(0) for key, tensor in encoding.items()}


# function to load text data with batching
def load_preprocessed_texts(text_list, tokenizer, batch_size=32, max_length=512):
    dataset = CustomTextDataset(text_list, tokenizer, max_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataloader

## **Model Implementation**

### **Visual Extractor**

In [13]:
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import pandas as pd
import torch.optim as optim

output_directory="./datasets/iu_xray/output/"
# Define paths for saving features
patch_feats_file = os.path.join(output_directory, 'final_test_patch_feats.pt')
avg_feats_file = os.path.join(output_directory, 'final_test_avg_feats.pt')
final_embeddings_file = os.path.join(output_directory, 'final_test_final_embeddings.pt')
# iu_xray_images_preprocessed="/kaggle/working/extracted_images/images_preprocessed"
# # Define paths for saving features
# patch_feats_file = os.path.join(output_directory, 'patch_feats.pt')
# avg_feats_file = os.path.join(output_directory, 'avg_feats.pt')
# final_embeddings_file = os.path.join(output_directory, 'final_embeddings.pt')

# Define the transform for image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Adjust size based on model input requirement
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalization for ResNet
])

# Define VisualExtractor class
class VisualExtractor(nn.Module):
    def __init__(self, visual_extractor='resnet101', pretrained=True):
        super(VisualExtractor, self).__init__()

        model = getattr(models, visual_extractor)(pretrained=pretrained)
        
        # Remove the last fully connected layer
        modules = list(model.children())[:-2]  
        self.model = nn.Sequential(*modules)
        
        # Average pooling and a fully connected layer to transform the features
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc_layer = nn.Linear(model.fc.in_features, 512)
        
    def forward(self, x):
        x = self.model(x)
        avg_feats = self.avg_pool(x).view(x.size(0), -1)
        patch_feats = self.fc_layer(avg_feats)
        return patch_feats, avg_feats

# Load images function that handles the new DataFrame structure
def load_images(report_row, img_folder):
    images = []
    for i in range(1, 3):  # For image_1 and image_2
        img_value = report_row[f'image_{i}']
        
        # Check if img_value is a string before splitting
        if isinstance(img_value, str):
            img_filename = img_value.split('.')[0] + '.png'  # Replace with .png
            img_path = os.path.join(img_folder, img_filename)
            
            # Check if the image file exists before attempting to open it
            if os.path.exists(img_path):
                img = Image.open(img_path).convert("RGB")
                images.append(transform(img))
            else:
                print(f"Warning: Image file not found: {img_path}")  # Warning if file not found
        else:
            print(f"Warning: Missing or invalid image reference in {f'image_{i}'} column.")
    
    # If only one image was loaded, duplicate it
    if len(images) == 1:
        images.append(images[0].clone())  # Duplicate the single available image
    
    return torch.stack(images) if images else torch.tensor([])  # Return empty tensor if no images loaded

# Initialize visual extractor
visual_extractor = VisualExtractor()

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
visual_extractor.to(device)

# Set learning rates and other parameters
learning_rate = 5e-5  # 5 × 10^−5
other_parameter = 1e-4  # Example of another parameter (like weight decay)

# Create an optimizer
optimizer = optim.Adam(visual_extractor.parameters(), lr=learning_rate, weight_decay=other_parameter)

def extract_features(data_loader):
    patch_feats, avg_feats, final_embeddings = [], [], []
    
    for idx, row in data_loader.iterrows():
        images = load_images(row, iu_xray_images_preprocessed).to(device)  # Load up to 2 images
        
        if images.numel() == 0:  # Skip if no images were loaded
            continue
        
        report_patch_feats, report_avg_feats = [], []
        
        # Iterate directly over the images tensor
        for image in images:  # Process each image in the loaded images tensor
            image = image.unsqueeze(0).to(device)  # Move single image to device
            pf, af = visual_extractor(image)  # Extract patch_feats, avg_feats
            print(f"{row['index']}, {row['pmc_id']} Patch Features Shape: {pf.shape}, Avg Features Shape: {af.shape}")
            
            # Store the output temporarily, releasing earlier references
            report_patch_feats.append(pf.detach())
            report_avg_feats.append(af.detach())

        # Concatenate patch features and average the avg_feats
        concatenated_patch_feats = torch.cat(report_patch_feats, dim=1)  # Concatenate along feature dimension
        averaged_avg_feats = torch.mean(torch.stack(report_avg_feats), dim=0)  # Average the avg_feats

        # Combine avg_feats with patch_feats to create final embedding
        final_embedding = torch.cat((averaged_avg_feats, concatenated_patch_feats), dim=1)
        patch_feats.append(concatenated_patch_feats)
        avg_feats.append(averaged_avg_feats)
        final_embeddings.append(final_embedding)
        
    return torch.stack(patch_feats), torch.stack(avg_feats), torch.stack(final_embeddings)

# Functions to load and save features
def load_features(file_path):
    return torch.load(file_path)

def save_features(file_path, features):
    torch.save(features, file_path)

# # Load data
# report_data_path = os.path.join(output_directory, 'filtered_train_data.csv')
# iu_xray_reports_df = pd.read_csv(report_data_path)
# display(iu_xray_reports_df.head(10))
# # Load data
report_data_path =  os.path.join(output_directory, "final_filtered_test_data.csv")
iu_xray_reports_df = pd.read_csv(report_data_path)
display(iu_xray_reports_df.head(10))

# Check if features are already saved; if not, extract and save them
if not os.path.exists(patch_feats_file) and os.path.exists(avg_feats_file) and os.path.exists(final_embeddings_file):
    print("All features are already precomputed and will be loaded.")
    patch_feats = load_features(patch_feats_file)
    avg_feats = load_features(avg_feats_file)
    final_embeddings = load_features(final_embeddings_file)
else:
    print("Extracting features since they are not precomputed...")
    
    # Set the model to training mode
    visual_extractor.train()  
    
    patch_feats, avg_feats, final_embeddings = extract_features(iu_xray_reports_df)
    patch_feats = patch_feats.squeeze(1)  
    avg_feats = avg_feats.squeeze(1)      
    final_embeddings = final_embeddings.squeeze(1) 

    save_features(patch_feats_file, patch_feats)
    save_features(avg_feats_file, avg_feats)
    save_features(final_embeddings_file, final_embeddings)

# Displaying the shapes of the feature dataframes
print("Patch Features Shape:", patch_feats.shape)
print("Average Features Shape:", avg_feats.shape)
print("Final Embedding Shape:", final_embeddings.shape)


Unnamed: 0,pmc_id,findings,image_1,image_2,index
0,3790,low lung volumes . elevation the right hemidi...,CXR3790_IM-1904-0001-0001.jpg: Xray Chest PA a...,CXR3790_IM-1904-0001-0002.jpg: Xray Chest PA a...,1
1,2282,the lungs are clear bilaterally . specificall...,CXR2282_IM-0869-1001.jpg: PA and lateral chest...,CXR2282_IM-0869-2001.jpg: PA and lateral chest...,2
2,2841,the heart normal size and contour . the lungs...,CXR2841_IM-1253-2001.jpg: Xray Chest PA and La...,,3
3,2192,no focal lung consolidation . no pneumothora ...,CXR2192_IM-0802-2002.jpg: Xray Chest PA and La...,CXR2192_IM-0802-3003.jpg: Xray Chest PA and La...,4
4,3149,the lungs are hyperepanded . cardiomediastina...,CXR3149_IM-1480-1001.jpg: PA and Lateral views...,CXR3149_IM-1480-2001.jpg: PA and Lateral views...,5
5,3673,the heart normal size . the mediastinal conto...,CXR3673_IM-1828-1001.jpg: CHEST 2V FRONTAL/LAT...,CXR3673_IM-1828-1002.jpg: CHEST 2V FRONTAL/LAT...,6
6,471,the heart normal size . the mediastinum unrem...,CXR471_IM-2099-2002.jpg: Xray Chest PA and Lat...,CXR471_IM-2099-3003.jpg: Xray Chest PA and Lat...,7
7,1226,the heart size within normal limits . trachea...,CXR1226_IM-0150-1001.jpg: The chest 2 views PA...,CXR1226_IM-0150-1002.jpg: The chest 2 views PA...,8
8,1697,lungs are clear bilaterally . cardiac and med...,CXR1697_IM-0458-1001.jpg: PA and lateral chest...,CXR1697_IM-0458-2001.jpg: PA and lateral chest...,9
9,845,minimal subsegmental atelectasis posteriorly ....,CXR845_IM-2367-1001.jpg: PA and lateral views ...,CXR845_IM-2367-1002.jpg: PA and lateral views ...,10


Extracting features since they are not precomputed...
1, 3790 Patch Features Shape: torch.Size([1, 512]), Avg Features Shape: torch.Size([1, 2048])
1, 3790 Patch Features Shape: torch.Size([1, 512]), Avg Features Shape: torch.Size([1, 2048])
2, 2282 Patch Features Shape: torch.Size([1, 512]), Avg Features Shape: torch.Size([1, 2048])
2, 2282 Patch Features Shape: torch.Size([1, 512]), Avg Features Shape: torch.Size([1, 2048])
3, 2841 Patch Features Shape: torch.Size([1, 512]), Avg Features Shape: torch.Size([1, 2048])
3, 2841 Patch Features Shape: torch.Size([1, 512]), Avg Features Shape: torch.Size([1, 2048])
4, 2192 Patch Features Shape: torch.Size([1, 512]), Avg Features Shape: torch.Size([1, 2048])
4, 2192 Patch Features Shape: torch.Size([1, 512]), Avg Features Shape: torch.Size([1, 2048])
5, 3149 Patch Features Shape: torch.Size([1, 512]), Avg Features Shape: torch.Size([1, 2048])
5, 3149 Patch Features Shape: torch.Size([1, 512]), Avg Features Shape: torch.Size([1, 2048])
6, 367

In [None]:
patch_feats = patch_feats.squeeze(1)  # Shape: (3851, 4096)
avg_feats = avg_feats.squeeze(1)      # Shape: (3851, 2048)
final_embeddings = final_embeddings.squeeze(1)  # Shape: (3851, 6144)


In [None]:
import os
import torch

# Define paths for the saved features
patch_feats_file = os.path.join(output_directory, 'patch_feats.pt')
avg_feats_file = os.path.join(output_directory, 'avg_feats.pt')
final_embeddings_file = os.path.join(output_directory, 'final_embeddings.pt')

# Function to load features and print their shapes
def print_tensor_shapes():
    # Load patch features
    if os.path.exists(patch_feats_file):
        patch_feats = torch.load(patch_feats_file, map_location=torch.device('cpu'))
        print("Patch Features Shape:", patch_feats.shape)
    else:
        print(f"Patch features file not found: {patch_feats_file}")

    # Load average features
    if os.path.exists(avg_feats_file):
        avg_feats = torch.load(avg_feats_file, map_location=torch.device('cpu'))
        print("Average Features Shape:", avg_feats.shape)
    else:
        print(f"Average features file not found: {avg_feats_file}")

    # Load final embeddings
    if os.path.exists(final_embeddings_file):
        final_embeddings = torch.load(final_embeddings_file, map_location=torch.device('cpu'))
        print("Final Embedding Shape:", final_embeddings.shape)
    else:
        print(f"Final embeddings file not found: {final_embeddings_file}")

# Call the function to print shapes
print_tensor_shapes()


In [None]:
'''Displaying Tensor Shapes for All Embeddings'''

# Define paths for the saved features
patch_feats_file = os.path.join(output_directory, 'patch_feats.pt')
avg_feats_file = os.path.join(output_directory, 'avg_feats.pt')
final_embeddings_file = os.path.join(output_directory, 'final_embeddings.pt')


# Function to load features and print their shapes
def print_tensor_shapes():
    # Load patch features
    if os.path.exists(patch_feats_file):
        patch_feats = torch.load(patch_feats_file, map_location=torch.device('cpu'))
        print("Patch Features Shape:", patch_feats.shape)
    else:
        print(f"Patch features file not found: {patch_feats_file}")

    # Load average features
    if os.path.exists(avg_feats_file):
        avg_feats = torch.load(avg_feats_file, map_location=torch.device('cpu'))
        print("Average Features Shape:", avg_feats.shape)
    else:
        print(f"Average features file not found: {avg_feats_file}")

    # Load final embeddings
    if os.path.exists(final_embeddings_file):
        final_embeddings = torch.load(final_embeddings_file, map_location=torch.device('cpu'))
        print("Final Embedding Shape:", final_embeddings.shape)
    else:
        print(f"Final embeddings file not found: {final_embeddings_file}")

        
# Call the function to print shapes
print_tensor_shapes()

### **Text Encoder**

In [2]:
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
from transformers import BertModel, BertTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TextEncoder(nn.Module):
    def __init__(self, bert_model="bert-base-uncased", output_dim=384):
        super(TextEncoder, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.projection = nn.Linear(self.bert.config.hidden_size, output_dim)
        
    def encode_dictionary(self, dictionary):
        """
        Encodes dictionary entries using BERT and combines key-value pairs
        dictionary: Dict with medical terms as keys and list of related terms as values
        Returns: Tensor of shape [num_entries, output_dim]
        """
        encoded_entries = []
        
        for key, values in dictionary.items():
            # Combine key with its values into a single text
            if values:  # If values list is not empty
                text = key + ": " + ", ".join(values)
            else:
                text = key
                
            # Tokenize and encode
            inputs = self.tokenizer(text, 
                                  padding=True, 
                                  truncation=True, 
                                  return_tensors="pt")
            
            with torch.no_grad():
                outputs = self.bert(**inputs)
                # Use [CLS] token embedding
                embedding = outputs.last_hidden_state[:, 0, :]
                
            # Project to desired dimension
            projected = self.projection(embedding)
            encoded_entries.append(projected)
            
        return torch.cat(encoded_entries, dim=0)
        
    def encode_reports(self, reports):
            """
            Encodes medical reports using BERT
            reports: List of report texts
            Returns: Tensor of shape [batch_size, output_dim]
            """
            # Tokenize all reports in batch
            inputs = self.tokenizer(reports,
                                  padding=True,
                                  truncation=True,
                                  return_tensors="pt")
            
            with torch.no_grad():
                outputs = self.bert(**inputs)
                # Use [CLS] token embeddings
                embeddings = outputs.last_hidden_state[:, 0, :]
                
            # Project to desired dimension
            projected = self.projection(embeddings)
            
            return projected

medical_dict = {
    "pleural": ["hemithorax", "effusion", "pneumothorax", "parenchymal"],
    "lung": ["lungs", "pulmonary", "hilar", "lobe", "consolidation", 
             "atelectasis", "edema", "opacity", "pneumonia"],
    "mediastinal": ["mediastinum", "diaphragm", "hemidiaphragm"],
    "cardiac": ["heart", "cardiomegaly", "cardiomediastinal", "atrium",
                "ventricle", "retrocardiac"],
    "vascular": ["aorta", "venous", "jugular", "aortic", "vasculature", "cabg"],
    "osseous": ["rib", "sternal", "subclavian", "thoracic"],
    "trachea": ["endotrachea"],
    "stomach": [],
    "abdomen": [],
    "tube": ["clips"],
    "spine": ["vertebral", "degenerative"],
    "nodule": ["mass"],
    "chest": ["small", "enlarged", "unchanged", "stable", "silhouette",
              "contours", "size", "focal", "mild", "acute"]
}

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # weight matrices for query, key, value
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V):
        # Compute QK^T / sqrt(d_k)
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k).float())
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights

    def forward(self, V, I_prime):
        I_prime = I_prime.unsqueeze(1)
        Q = self.W_q(V)  # Dictionary embeddings
        K = self.W_k(I_prime)  # Image embeddings 
        V = self.W_v(I_prime)
        batch_size, seq_len, d_model = K.size()
        Q = Q.view(Q.size(0), -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(K.size(0), -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(V.size(0), -1, self.num_heads, self.d_k).transpose(1, 2)
        Q = Q.repeat(64 // 13 + 1, 1, 1, 1)  # Repeat enough times
        Q = Q[:64]
        
        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V)
        attn_output = attn_output.transpose(1, 2).contiguous().view(V.size(0), -1, self.d_model)
        output = self.W_o(attn_output)
        
        return output, attn_weights

class FeedForwardNetwork(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForwardNetwork, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

def kl_divergence(mu1, logvar1, mu2, logvar2):
    normal1 = dist.Normal(mu1, torch.exp(0.5 * logvar1))
    normal2 = dist.Normal(mu2, torch.exp(0.5 * logvar2))
    kl_loss = dist.kl.kl_divergence(normal1, normal2).mean()
    return kl_loss

class Piror(nn.Module):
    """Fully connected layer to convert encodings to mean and variance"""
    def __init__(self, input_dim=3072, hidden_dim=512):
        super(Piror, self).__init__()
        self.fc_mu = nn.Linear(input_dim, hidden_dim)
        self.fc_var = nn.Linear(input_dim, hidden_dim)
        
    def forward(self, x):
        # Generate mean and log variance
        mu = self.fc_mu(x)
        logvar = self.fc_var(x)
        return mu, logvar



### **Sentence Encoder**

In [3]:
class SentenceEncoder(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=512, output_dim=512):
        super(SentenceEncoder, self).__init__()
        # Encoder network
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Networks to generate mean and variance as mentioned in paper
        self.mean_layer = nn.Linear(hidden_dim, output_dim)
        self.logvar_layer = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, x):
        encoded = self.encoder(x)
        mean = self.mean_layer(encoded)
        logvar = self.logvar_layer(encoded)
        return mean, logvar
        
    def encode(self, x):
        """Get only the mean for inference"""
        mean, _ = self.forward(x)
        return mean

class SentenceBERT:
    def __init__(self):
        # Initialize the BERT tokenizer and model
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert_model = BertModel.from_pretrained('bert-base-uncased')

    def encode_reports(self, reports):
        embeddings = []
        for report in reports:
            # Tokenize and encode the report
            inputs = self.tokenizer(report, return_tensors='pt', padding=True, truncation=True)

            # Get BERT outputs
            with torch.no_grad():  # Disable gradient calculation for inference
                outputs = self.bert_model(**inputs)
                
            # Use the pooled output as the embedding
            embeddings.append(outputs.pooler_output)

        # Stack all embeddings into a single tensor
        return torch.stack(embeddings)
def kl_divergence_loss(mean1: torch.Tensor, logvar1: torch.Tensor, mean2: torch.Tensor, logvar2: torch.Tensor) -> torch.Tensor:
    normal1 = dist.Normal(mean1, torch.exp(0.5 * logvar1))  # Standard deviation is sqrt of variance
    normal2 = dist.Normal(mean2, torch.exp(0.5 * logvar2))
    kl_loss = dist.kl.kl_divergence(normal1, normal2).mean()
    return kl_loss


In [4]:
def compute_cosine_similarities(batch_embeddings: torch.Tensor, all_embeddings: torch.Tensor) -> torch.Tensor:
    # Normalize embeddings
    batch_norm = F.normalize(batch_embeddings, p=2, dim=1)
    all_norm = F.normalize(all_embeddings, p=2, dim=1)
    
    # Compute similarities
    similarities = torch.mm(batch_norm, all_norm.t())
    return similarities



In [5]:
import torch

def screen_historical_knowledge(current_embedding, historical_embeddings, top_k=5):
    """
    Screen historical knowledge by selecting the top-K most similar historical embeddings.
    
    Parameters:
    - current_embedding: Tensor, the current report embedding (1, d_model).
    - historical_embeddings: deque, containing historical embeddings of shape (num_historical, d_model).
    - top_k: int, number of most similar historical embeddings to select.

    Returns:
    - screened_knowledge: Tensor, containing the top-K most similar historical embeddings.
    """
    # Calculate cosine similarities between the current embedding and each historical embedding
    similarities = [
        F.cosine_similarity(current_embedding, hist_embedding.unsqueeze(0), dim=1)
        for hist_embedding in historical_embeddings
    ]
    similarities = torch.stack(similarities).squeeze()  # Shape: (num_historical,)

    # Get indices of top-K most similar historical embeddings
    top_k_indices = torch.topk(similarities, top_k, largest=True).indices
    # print(top_k_indices)
    # Select the top-K most similar historical embeddings
    screened_knowledge = torch.stack([historical_embeddings[idx] for idx in top_k_indices])

    return screened_knowledge

# Contrastive loss function
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=0.2):
        super().__init__()
        self.margin = margin

    def forward(self, pos_sim, neg_sim):
        return F.relu(self.margin - pos_sim + neg_sim).mean()


In [6]:
# Define projection layers
avg_projection = nn.Linear(2048, 512).to(device)
hist_projection = nn.Identity().to(device)

# Initialize contrastive loss and classifier
itc_loss_fn = ContrastiveLoss().to(device)
itm_classifier = nn.Linear(512 * 2, 1).to(device)

# Coarse alignment loop
itc_loss_total = 0
itm_loss_total = 0

# Project embeddings
# proj_avg_embeddings = avg_projection(average_embeddings)      # (2700, 512)
# proj_hist_embeddings = hist_projection(historical_embeddings)  # (2700, 512)

### **Multilevel Alignment**

### **Decoder**

In [7]:

class ReportDecoder(nn.Module):
    def __init__(self, input_dim, d_model, vocab_size, num_layers, num_heads, d_ff, dropout=0.1):
        super(ReportDecoder, self).__init__()
        
        # Add a linear projection to convert concatenated features to model dimension
        self.input_projection = nn.Linear(input_dim, d_model)
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, num_heads, d_ff, dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)
    
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        # Project input memory to model dimension
        # memory shape: [batch_size, total_features]
        memory = self.input_projection(memory)  # [batch_size, d_model]
        
        # Add sequence dimension and transpose for transformer
        memory = memory.unsqueeze(0)  # [1, batch_size, d_model]
        
        # Handle target sequence
        if len(tgt.shape) == 2:
            tgt = tgt.transpose(0, 1)  # [seq_len, batch_size]
            
        # Embed target
        tgt_embed = self.embedding(tgt)  # [seq_len, batch_size, d_model]
        
        # Apply transformer decoder
        output = self.transformer_decoder(tgt_embed, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
        
        # Project to vocabulary size
        return self.fc_out(output)  # [seq_len, batch_size, vocab_size]


## **Training**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from transformers import BertTokenizer
from tqdm import tqdm
import pandas as pd
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load final_embeddings as per your code setup
output_aligned_features_dir = "./datasets/iu_xray/output/final_train_final_embeddings.pt"
final_embeddings = torch.load(output_aligned_features_dir, map_location=device)
print(final_embeddings.size())
# Parameters and model setup
vocab_size = 30522  # Example vocab size (based on BERT)
d_model = 3072
num_layers = 3
num_heads = 8
d_ff = 4096
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# decoder = ReportDecoder(d_model, vocab_size, num_layers, num_heads, d_ff).to(device)
sentence_encoder = SentenceEncoder().to(device)
text_encoder = TextEncoder()

# projection_layer = ProjectionLayer(input_dim=768, output_dim=d_model).to(device)  # Example dimensions
projection_layer = nn.Linear(768, 512).to(device)
avg_projection = nn.Linear(2048, 512).to(device)
vprime_layer = nn.Linear(13 * 3072, 3072).to(device)
text_projection_layer = nn.Linear(384, 3072)
hist_projection = nn.Identity().to(device)
itm_classifier = nn.Sequential(
    nn.Linear(512 * 2, 1),  # Example dimensions for a concatenated input
    # nn.ReLU(),
    # nn.Linear(512, 1)
).to(device)
MHA = nn.MultiheadAttention(512, num_heads=8).to(device)
mha = MultiHeadAttention(d_model, num_heads).to(device)
sentence_bert = SentenceBERT()
sentence_encoder = SentenceEncoder()
ffn = FeedForwardNetwork(d_model, d_ff).to(device)
piror_1 = Piror(d_model).to(device)
piror_2 = Piror(384).to(device)
# alignment_model = ImageTextAlignment().to(device)
# Initialize the decoder
# decoder = ReportDecoder(d_model, vocab_size, num_layers, num_heads, d_ff).to(device)
total_feature_dim = 3072 + 3072 + 25600  # 8704
# d_model = 512  # or whatever dimension you want to use

decoder = ReportDecoder(
    input_dim=total_feature_dim,  # 8704
    d_model=512,
    vocab_size=vocab_size,
    num_layers=3,
    num_heads=8,
    d_ff=2048,
    dropout=0.1
).to(device)

# Define a negative log-likelihood loss for the report generation
criterion_nll = nn.CrossEntropyLoss().to(device)

# Combine the optimizer for both encoder and decoder
params = (
    list(sentence_encoder.parameters()) + 
    list(text_encoder.parameters()) +
    list(projection_layer.parameters()) + 
    list(avg_projection.parameters()) + 
    list(itm_classifier.parameters()) +
    list(mha.parameters()) + 
    list(ffn.parameters()) + 
    # list(piror.parameters()) +
    list(decoder.parameters())+
    list(vprime_layer.parameters())
)
optimizer = Adam(params, lr=1e-4)

# Hyperparameters
batch_size = 64
num_epochs = 1
output_dir = './model_checkpoints'

average_embeddings = final_embeddings[:, :2048]
patch_embeddings = final_embeddings[:, 2048:2560]

# Updated Training Loop with Report Generation
def train_with_decoder(final_embeddings, batch_size, num_epochs, output_dir, target_reports_df):
    dataset_size = final_embeddings.size(0)
    training_log = []
    

    for epoch in range(num_epochs):
        total_loss = 0
        n_batches = dataset_size // batch_size + (0 if dataset_size % batch_size != 0 else -1)
        screened_historical_embedding= None
        historical_embeddings=[]

        with tqdm(total=n_batches, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as pbar:
            for i in range(0, dataset_size, batch_size):
                if final_embeddings[i:i + batch_size].size(0) != batch_size:
                    print(f"Skipping batch {i} due to size mismatch (expected {batch_size}, got {final_embeddings[i:i + batch_size].size(0)})")
                    continue
                
                optimizer.zero_grad()

                
                #TARGET REPORTS
                batch = df.iloc[i:i+batch_size]  # Adjust slicing to include valid rows
                batch_reports_text = batch['findings'].tolist()
                # print(f"Batch {i} to {i+batch_size}: {batch_reports_text}")
                # Tokenize the target reports using the BERT tokenizer
                tokenized_reports = tokenizer(batch_reports_text, padding='longest', return_tensors='pt', truncation=True)
                
                # Extract input IDs and attention mask for the decoder
                batch_report_input_ids = tokenized_reports['input_ids'].to(device)
                batch_report_attention_mask = tokenized_reports['attention_mask'].to(device)

                # Shift the input IDs for the decoder
                report_seq = batch_report_input_ids[:, :-1]  # Input sequence for decoder
                tgt_seq = batch_report_input_ids[:, 1:]      # Target sequence for NLL loss

                
                #SENTENCE ENCODER
                # Get batch data
                batch_final_embeddings = final_embeddings[i:i + batch_size].to(device)
                current_batch_size = batch_final_embeddings.size(0)
                batch_reports = tokenized_reports[i:i + batch_size] #.to(device)
                
                # # Fine alignment (cosine similarity and KL divergence)
                # batch_embedding1 = sentence_encoder_matrix[i:i + current_batch_size].to(device)
                # batch_embedding2 = sbert_matrix[i:i + current_batch_size].to(device)
                

                batch_embedding2 = sentence_bert.encode_reports(batch_reports_text)
                batch_embedding1 = sentence_encoder.encode(batch_embedding2)
                batch_embedding1 = batch_embedding1.squeeze(1)

                batch_embedding2_projected = projection_layer(batch_embedding2)
                
                mean1 = batch_embedding1.mean(dim=0, keepdim=True)
                var1 = batch_embedding1.var(dim=0, keepdim=True, unbiased=False)
                mean2 = batch_embedding2_projected.mean(dim=0, keepdim=True)
                var2 = batch_embedding2_projected.var(dim=0, keepdim=True, unbiased=False)

                kl_loss_fine = kl_divergence_loss(mean1, var1, mean2, var2)
                sim_loss = 1 - F.cosine_similarity(batch_embedding1, batch_embedding2_projected, dim=1).mean()
                fine_alignment_loss = kl_loss_fine + sim_loss

                # Fprint(f"batch_embedding1: {batch_embedding1.size()}")
                for embedding in batch_embedding1:
                    historical_embeddings.append(embedding.detach())
                    # historical_embeddings_2.append(embedding.detach()) # Add embeddings without gradient tracking
                historical_embeddings_2 = batch_embedding1


                #SCREENED HISTORICAL KNOWLEDGE
                screened_knowledge_batch = []
                for embedding in batch_embedding1[:current_batch_size]:
                    screened_knowledge = screen_historical_knowledge(embedding, historical_embeddings_2, top_k=50)
                    screened_knowledge_batch.append(screened_knowledge)
                screened_knowledge_batch = torch.stack(screened_knowledge_batch).to(device)  # Shape: (batch_size, top_k, d_model)
                
                
                #BLIP ARCHITECTURE
                
                # Coarse alignment (ITC and ITM loss)
                proj_avg_embeddings = avg_projection(average_embeddings[i:i + current_batch_size].to(device))
                proj_hist_embeddings = hist_projection(batch_embedding1)
                itc_loss_total = 0
                itm_loss_total = 0
                for j in range(proj_avg_embeddings.size(0)):
                    img_embed = proj_avg_embeddings[j]
                    txt_embed = batch_embedding1[j]
                    txt_embed = txt_embed.squeeze(0)
                    pos_sim = F.cosine_similarity(img_embed, txt_embed, dim=0)
                    neg_index = torch.randint(0, historical_embeddings_2.size(0), (1,), device=device)
                    neg_txt_embed = proj_hist_embeddings[neg_index]
                    neg_sim = F.cosine_similarity(img_embed, neg_txt_embed, dim=0)

                    itc_loss = itc_loss_fn(pos_sim, neg_sim)
                    itc_loss_total += itc_loss
                    # print(f"Shape of combined tensor: {img_embed.shape}  {txt_embed.shape}")
                    combined = torch.cat((img_embed, txt_embed), dim=-1)
                    # print(f"Shape of combined tensor: {combined.shape}")
                    itm_pred = itm_classifier(combined)
                    itm_label = torch.tensor([1.0], dtype=torch.float, device=device)
                    itm_loss = F.binary_cross_entropy_with_logits(itm_pred, itm_label)
                    itm_loss_total += itm_loss

                itc_loss_avg = itc_loss_total / proj_avg_embeddings.size(0)
                itm_loss_avg = itm_loss_total / proj_avg_embeddings.size(0)
                coarse_alignment_loss = itc_loss_avg + itm_loss_avg


                
                
                #TEXT ENCODER WITH ALIGNMENT

                # dictionary_embeddings, label_embeddings = alignment_model.encode_dictionary_and_labels(medical_dict, batch_reports_text)
                # kl_loss_mha, attention_maps, V_prime = alignment_model(medical_dict, batch_reports_text, batch_final_embeddings, batch_size)
                # Multi-Head Attention and Feed-Forward Network
                dictionary_embeddings = text_encoder.encode_dictionary(medical_dict)

                # Load final embeddings and project dictionary embeddings
                dictionary_embeddings = dictionary_embeddings.to(torch.float32)  
                
                V_projected = text_projection_layer(dictionary_embeddings).to(device)
                # batch_patch_embeddings = average_embeddings[i:i + current_batch_size].to(device)
                # batch_V_projected = V_projected.unsqueeze(0).repeat(batch_final_embeddings.size(0), 1, 1).to(device)
                aligned_output, _ = mha(V_projected, batch_final_embeddings)
                aligned_output_ffn = ffn(aligned_output)
                V_prime = aligned_output_ffn
                V_label = text_encoder.encode_reports(batch_reports_text)
                
                # print(f"V_prime shape: {V_prime.shape}")
                # print(f"V_label shape: {V_prime.shape}")
                # print(f"V_prime reduced shape: {V_prime_reduced.shape}")
                # V_prime_reduced = V_prime.max(dim=1).values
                # attention_weights = F.softmax(vprime_layer(V_prime), dim=1)  
                # v_prime_reduced = (attention_weights * V_prime).sum(dim=1)  

                mu1, logvar1 = piror_1(V_prime)
                mu2, logvar2 = piror_2(V_label)
                kl_loss_mha = kl_divergence_loss(mu1, logvar1.exp(), mu2, logvar2.exp())
        
                # Total Loss
                total_encoder_loss = fine_alignment_loss + coarse_alignment_loss + kl_loss_mha
                # total_batch_loss.backward()
                
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(sentence_encoder.parameters(), max_norm=1.0)
                torch.nn.utils.clip_grad_norm_(text_encoder.parameters(), max_norm=1.0)
                torch.nn.utils.clip_grad_norm_(projection_layer.parameters(), max_norm=1.0)
                torch.nn.utils.clip_grad_norm_(avg_projection.parameters(), max_norm=1.0) 
                torch.nn.utils.clip_grad_norm(itm_classifier.parameters(), max_norm=1.0) 
                torch.nn.utils.clip_grad_norm_(mha.parameters(), max_norm=1.0) 
                torch.nn.utils.clip_grad_norm_(ffn.parameters(), max_norm=1.0) 
                # torch.nn.utils.clip_grad_norm_(piror.parameters(), max_norm=1.0)
               
                
                #DECODER
                screened_knowledge_batch_flat = screened_knowledge_batch.view(current_batch_size, -1)  # Flatten along the second dimension
                # memory = torch.cat([batch_final_embeddings, V_prime], dim=1)
                
                # print(f"batch_final_embeddings shape: {batch_final_embeddings.shape}")
                # print(f"V_prime shape: {V_prime.shape}")
                # print(f"V_prime reduced shape: {V_prime_reduced.shape}")
                # print(f"screened_knowledge_batch shape: {screened_knowledge_batch_flat.shape}")
                # print(f"report_seq shape: {report_seq.shape}")
                # print(f"tgt_seq shape: {tgt_seq.shape}")
                V_prime_reduced = V_prime.squeeze(1)
                memory = torch.cat([batch_final_embeddings, V_prime_reduced, screened_knowledge_batch_flat], dim=1).to(device)
                report_seq = report_seq.to(device)
                tgt_seq = tgt_seq.to(device)
                
                output_seq = decoder(report_seq, memory)
                output_seq = output_seq.transpose(0, 1)

                output_flat = output_seq.reshape(-1, vocab_size)  # (batch_size * seq_len, vocab_size)
                target_flat = tgt_seq.reshape(-1)  # (batch_size * seq_len)
                # Compute report generation (NLL) loss
                nll_loss = criterion_nll(output_flat.reshape(-1, vocab_size), target_flat.reshape(-1))

                # Compute other losses (KL, similarity, etc.)
                # kl_loss = kl_divergence_loss(mu1, logvar1, mu2, logvar2)
                # sim_loss = 1 - F.cosine_similarity(batch_final_embeddings, V_prime, dim=-1).mean()

                # Combine all losses
                total_batch_loss = nll_loss + total_encoder_loss
                
                # Backward pass and optimization step
                total_batch_loss.backward(retain_graph=True)
                torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
                optimizer.step()
                total_loss += total_batch_loss.item()

                # Update progress
                pbar.update(1)
                pbar.set_postfix(loss=f"{total_loss / (i + batch_size):.4f}")

            if len(historical_embeddings) > 0:
                historical_embeddings_tensor = torch.stack(historical_embeddings)
                
                # Compute similarity for all historical embeddings
                similarity_scores = F.cosine_similarity(
                    historical_embeddings_tensor.unsqueeze(1),  # All embeddings (H x 1 x D)
                    historical_embeddings_tensor.unsqueeze(0),  # All embeddings (1 x H x D)
                    dim=2
                )  # Resulting tensor size: (H x H)
                
                # Mask self-similarity
                similarity_scores.fill_diagonal_(-float('inf'))
                
                # Flatten the similarity scores to find global top 50
                flat_scores = similarity_scores.view(-1)  # Flatten to 1D tensor
                top_50_values, top_50_indices = torch.topk(flat_scores, k=50, largest=True)
                
                # Convert flat indices back to 2D indices
                num_reports = similarity_scores.size(0)
                row_indices = top_50_indices // num_reports  # Source report indices
                col_indices = top_50_indices % num_reports   # Target report indices
                
                # Get the embeddings for the top 50 most similar reports
                top_n_embeddings = []
                top_n_similar_reports = []
                
                for score, col_idx in zip(top_50_values, col_indices):
                    top_n_embeddings.append(historical_embeddings_tensor[col_idx])
                    top_n_similar_reports.append((score.item(), col_idx.item()))
                
                # Stack the embeddings into a tensor
                top_n_embeddings_tensor = torch.stack(top_n_embeddings)  # Shape: [50, 512]

                # with open('top_n_similar_reports.pkl', 'wb') as pickle_file:
                #     pickle.dump(top_n_similar_reports, pickle_file)

            # Save model parameters and results per epoch
            torch.save({
                'sentence_encoder': sentence_encoder.state_dict(),
                'text_encoder': text_encoder.state_dict(),
                'projection_layer': projection_layer.state_dict(),
                'avg_projection': avg_projection.state_dict(),
                'itm_classifier': itm_classifier.state_dict(),
                'mha': mha.state_dict(),
                'ffn': ffn.state_dict(),
                # 'piror': piror.state_dict(),
                'decoder': decoder.state_dict(),
            }, os.path.join(output_dir, f"final_model_pafinal_filtered_train_datarameters_epoch_{epoch + 1}.pth"))
            torch.save(top_n_embeddings_tensor , "screened_historical_embedding.pt")
            print(f"\nEpoch {epoch + 1}/{num_epochs} completed. Average Loss: {total_loss / n_batches:.4f}")
            print(f"size of saved screened_historical_embedding: {top_n_embeddings_tensor .shape}")

    print("Training complete. Model parameters saved for each epoch.")

output_dir = "./datasets/iu_xray/output/"
filter_report_path = "./datasets/iu_xray/output/final_filtered_train_data.csv"
df = pd.read_csv(filter_report_path)
# Number of rows
num_rows = len(df)
print(f'Number of rows: {num_rows}')

# Number of columns
num_columns = len(df.columns)
print(f'Number of columns: {num_columns}')

df['findings'].isnull().sum()

# Training loop
train_with_decoder(final_embeddings, batch_size, num_epochs, output_dir, df)

torch.Size([2700, 3072])
Number of rows: 2700
Number of columns: 5


  torch.nn.utils.clip_grad_norm(itm_classifier.parameters(), max_norm=1.0)
Epoch 1/1:  31%|████▉           | 13/42 [07:13<19:42, 40.79s/batch, loss=0.0914]

### **Report Generator**

In [None]:
import torch
import torch.nn as nn
from transformers import BertTokenizer
import heapq
def load_trained_model(checkpoint_path, device='cuda'):
    """
    Load the trained model components from a checkpoint
    """
    # Initialize model components
    sentence_encoder = SentenceEncoder().to(device)
    text_encoder = TextEncoder().to(device)
    projection_layer = nn.Linear(768, 512).to(device)
    avg_projection = nn.Linear(2048, 512).to(device)
    itm_classifier = nn.Sequential(
        nn.Linear(512 * 2, 1)
    ).to(device)
    mha = MultiHeadAttention(3072, 8).to(device)
    ffn = FeedForwardNetwork(3072, 4096).to(device)
    piror = Piror(3072).to(device)
    
    # Load the saved state dictionaries
    checkpoint = torch.load(checkpoint_path)
    sentence_encoder.load_state_dict(checkpoint['sentence_encoder'])
    text_encoder.load_state_dict(checkpoint['text_encoder'])
    projection_layer.load_state_dict(checkpoint['projection_layer'])
    avg_projection.load_state_dict(checkpoint['avg_projection'])
    itm_classifier.load_state_dict(checkpoint['itm_classifier'])
    mha.load_state_dict(checkpoint['mha'])
    ffn.load_state_dict(checkpoint['ffn'])
    # piror.load_state_dict(checkpoint['piror'])
    decoder.load_state_dict(checkpoint['decoder'])
    
    return {
        'sentence_encoder': sentence_encoder,
        'text_encoder': text_encoder,
        'projection_layer': projection_layer,
        'avg_projection': avg_projection,
        'itm_classifier': itm_classifier,
        'mha': mha,
        'ffn': ffn,
        # 'piror': piror,
        'decoder': decoder
    }
def beam_search_generate_report(models, image_embedding, historical_embeddings, tokenizer, device='cuda', beam_size=3, max_length=100):
    """
    Generate a medical report using beam search
    """
    with torch.no_grad():
        # Prepare image embedding
        image_embedding = image_embedding.unsqueeze(0).to(device)  # Add batch dimension
        
        # Process through sentence encoder and projections
        avg_embeddings = image_embedding[:, :2048]
        proj_avg_embeddings = models['avg_projection'](avg_embeddings)
        
        dictionary_embeddings = models['text_encoder'].encode_dictionary(medical_dict)
        dictionary_embeddings = dictionary_embeddings.to(torch.float32)  
                
        V_projected = text_projection_layer(dictionary_embeddings).to(device)
                
        # Multi-head attention and feed-forward processing
        aligned_output, _ = models['mha'](V_projected, image_embedding)
        aligned_output_ffn = models['ffn'](aligned_output)
        V_prime = aligned_output_ffn
        V_prime_reduced = V_prime.squeeze(1)
        V_prime_reduced = V_prime.max(dim=1).values
        # print(f"image embeddings {image_embedding.size()}")
        # print(f"V_prime_reduced,{V_prime_reduced.size()}")
        # print(f"historical_embeddings {historical_embeddings.size()}")
        # Prepare memory
        memory = torch.cat([image_embedding, V_prime_reduced, historical_embeddings], dim=1)
        
        # Initialize beam search
        start_token = tokenizer.cls_token_id
        end_token = tokenizer.sep_token_id
        
        # Beam search candidates: (score, sequence, last_token)
        initial_sequence = torch.tensor([[start_token]]).to(device)
        candidates = [(0.0, initial_sequence, start_token)]
        completed_sequences = []
        
        for step in range(max_length):
            next_candidates = []
            
            # Expand each candidate
            for score, sequence, prev_token in candidates:
                # print(prev_token)
                # print(end_token)
                # if prev_token.eq(end_token).any() or sequence.shape[1] >= max_length:
                #     heapq.heappush(completed_sequences, (score, sequence))
                #     continue
                # If prev_token is a tensor, use .eq() and .any()
                if isinstance(prev_token, torch.Tensor):
                    if prev_token.eq(end_token).any() or sequence.shape[1] >= max_length:
                        heapq.heappush(completed_sequences, (score, sequence))
                        continue
                # If prev_token is a single token (int), compare directly
                elif prev_token == end_token or sequence.shape[1] >= max_length:
                    heapq.heappush(completed_sequences, (score, sequence))
                    continue
                # Decoder step
                output = models['decoder'](sequence, memory)
                next_token_logits = output[:, -1:]
                
                # Get top K tokens
                topk_probs, topk_indices = torch.topk(
                    torch.softmax(next_token_logits, dim=-1), 
                    k=beam_size, 
                    dim=-1
                )
                
                for prob, token_id in zip(topk_probs[0], topk_indices[0]):
                    
                    token_id = token_id.squeeze(0)
                    # print(token_id.size())
                    # print(sequence.size())
                    # print(token_id.unsqueeze(1).size())
                    # token_id = token_id.view(1)
                    new_sequence = torch.cat([sequence, token_id.unsqueeze(0)], dim=1)
                    new_score = score - torch.log(prob) #.item()  # Negative log likelihood
                    
                    next_candidates.append((new_score, new_sequence, token_id))
            
            # Sort and select top beam_size candidates
            next_candidates.sort(key=lambda x: x[0])
            candidates = next_candidates[:beam_size]
            
            # Check if all candidates are end tokens
            # if all(candidate[2].item() == end_token for candidate in candidates):
            #     break
            if any((candidate[2] == end_token).any() for candidate in candidates):
                break
            # for candidate in candidates:
            #     print(candidate[2])
            #     if (candidate[2] == end_token):
            #         break
        # Select best sequence
        if completed_sequences:
            _, best_sequence = heapq.heappop(completed_sequences)
        else:
            _, best_sequence, _ = min(candidates, key=lambda x: x[0])
        
        # Decode the best sequence
        generated_ids = best_sequence[0].cpu().numpy().tolist()
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        
        return generated_text

def test_model(test_image_embeddings, checkpoint_path, historical_embeddings, beam_size=3):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
    # Load models
    models = load_trained_model(checkpoint_path, device)
    
    # Generate reports for all test images
    reports = []
    for embedding in test_image_embeddings:
        report = beam_search_generate_report(
            models,
            embedding,
            historical_embeddings,
            tokenizer,
            device,
            beam_size=beam_size
        )
        print(report)
        reports.append(report)
    
    return reports

# Example usage
checkpoint_path = "/kaggle/working/model_checkpoints/new_model_parameters_epoch_1.pth"
test_image_embeddings = torch.load("/kaggle/input/test-data/test_final_embeddings.pt")
historical_embeddings = torch.load("/kaggle/working/screened_historical_embedding.pt")
historical_flat = historical_embeddings.reshape(1, -1)

reports = test_model(
    test_image_embeddings,
    checkpoint_path,
    historical_flat,
    beam_size=3
)

### **Evaluation**