In [None]:
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity

# RAG
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.chains import LLMChain, RetrievalQA
from langchain.embeddings import OllamaEmbeddings
from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOllama
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from adjustText import adjust_text

import os
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import densenet121
import torch.nn as nn
import torchvision.models as models
import cv2
import numpy as np
from PIL import Image
import csv

## Classification Results

In [None]:
# Label names in the order used during training; entry i names the model's
# i-th output score.
labels = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Enlarged Cardiomediastinum',
    'Fracture',
    'Lung Lesion',
    'Lung Opacity',
    'No Finding',
    'Pleural Effusion',
    'Pleural Other',
    'Pneumonia',
    'Pneumothorax',
    'Support Devices',
]

class DenseNet121(nn.Module):
    """ImageNet-pretrained DenseNet-121 used as a feature extractor.

    The torchvision classification head is replaced with ``nn.Identity``, so
    ``forward`` returns the features that would have been fed to that head
    instead of class logits.

    Attributes:
        densenet121: The wrapped torchvision model (name is part of the
            state_dict keys; do not rename).
        out_features: Width of the feature vector returned by ``forward``
            (the ``in_features`` of the replaced classifier).
    """

    def __init__(self):
        super().__init__()
        self.densenet121 = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        # Previously computed into an unused local; keep it so callers can
        # query the feature width instead of hard-coding it. A plain int
        # attribute is not a parameter/buffer, so state_dict is unaffected.
        self.out_features = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Identity()  # No classifier yet, only features

    def forward(self, x):
        """Return backbone features for the image batch ``x``."""
        return self.densenet121(x)

class ResNet50(nn.Module):
    """ImageNet-pretrained ResNet-50 used as a feature extractor.

    The torchvision ``fc`` head is replaced with ``nn.Identity``, so
    ``forward`` returns the features that would have been fed to that head
    instead of class logits.

    Attributes:
        resnet50: The wrapped torchvision model (name is part of the
            state_dict keys; do not rename).
        out_features: Width of the feature vector returned by ``forward``
            (the ``in_features`` of the replaced ``fc`` layer).
    """

    def __init__(self):
        super().__init__()
        self.resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Previously computed into an unused local; keep it so callers can
        # query the feature width instead of hard-coding it. A plain int
        # attribute is not a parameter/buffer, so state_dict is unaffected.
        self.out_features = self.resnet50.fc.in_features
        self.resnet50.fc = nn.Identity()  # No classifier yet, only features

    def forward(self, x):
        """Return backbone features for the image batch ``x``."""
        return self.resnet50(x)

class CombinedModel(nn.Module):
    """Two-view (frontal + lateral) multi-label chest X-ray classifier.

    Each view is encoded by its own DenseNet-121 and ResNet-50 backbone. The
    four feature vectors are concatenated and mapped by a linear layer plus
    sigmoid to ``out_size`` independent per-label scores.

    Args:
        out_size: Number of output labels.
    """

    def __init__(self, out_size=14):
        super().__init__()
        # NOTE: these attribute names (and the ``classifier`` Sequential layout)
        # form the state_dict keys; renaming them breaks checkpoint loading.
        self.densenet_frontal = DenseNet121()
        self.resnet_frontal = ResNet50()

        self.densenet_lateral = DenseNet121()
        self.resnet_lateral = ResNet50()

        # Per-view width assumes DenseNet-121 yields 1024 features and
        # ResNet-50 yields 2048; both views are concatenated.
        per_view_width = 1024 + 2048
        self.classifier = nn.Sequential(
            nn.Linear(2 * per_view_width, out_size),
            nn.Sigmoid(),  # independent probability per label (multi-label)
        )

    def forward(self, x_frontal, x_lateral):
        """Return per-label sigmoid scores for a batch of frontal/lateral pairs."""
        # Feature order must stay frontal-DenseNet, frontal-ResNet,
        # lateral-DenseNet, lateral-ResNet for the classifier weights to line up.
        stacked = torch.cat(
            [
                self.densenet_frontal(x_frontal),
                self.resnet_frontal(x_frontal),
                self.densenet_lateral(x_lateral),
                self.resnet_lateral(x_lateral),
            ],
            dim=1,
        )
        return self.classifier(stacked)
        
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = CombinedModel().to(device)
# map_location remaps saved tensors onto the selected device; without it a
# checkpoint saved from a GPU fails to load on a CPU-only machine.
model.load_state_dict(torch.load('path_to_directory', map_location=device))
# Switch BatchNorm/Dropout layers to inference behavior.
model.eval()

# CLAHE transform class
class CLAHETransform:
    """Contrast Limited Adaptive Histogram Equalization as a transform step.

    Accepts a PIL image or a numpy array and returns a PIL image. 3-D inputs
    are treated as RGB: CLAHE is applied only to the L channel in LAB space,
    so color is preserved. Any other input is equalized directly.

    Args:
        clip_limit: CLAHE contrast clip limit passed to OpenCV.
        tile_grid_size: Grid of tiles, passed to OpenCV, over which
            histograms are equalized.
    """

    def __init__(self, clip_limit=0.10, tile_grid_size=(8, 8)):
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size
        self.clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)

    def __call__(self, img):
        arr = np.array(img) if isinstance(img, Image.Image) else img
        if arr.ndim != 3:
            equalized = self.clahe.apply(arr)
        else:
            equalized = self._equalize_lightness(arr)
        return Image.fromarray(equalized.astype('uint8'))

    def _equalize_lightness(self, rgb):
        """Apply CLAHE to the L channel of an RGB array and convert back to RGB."""
        lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB))
        lab = cv2.merge((self.clahe.apply(lightness), chan_a, chan_b))
        return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

# Evaluation-time preprocessing. It must mirror the training pipeline so the
# loaded weights see inputs with the same statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(256),  # shorter side -> 256 px
        CLAHETransform(clip_limit=0.35, tile_grid_size=(8, 8)),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)


def _load_view(image_path):
    """Load one X-ray view as a batched ``(1, C, H, W)`` tensor on ``device``."""
    # The context manager closes the underlying file; otherwise PIL keeps the
    # handle open until the Image object is garbage-collected.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')  # independent copy, safe after close
    return transform(rgb).unsqueeze(0).to(device)


def predict_image(frontal_image_path, lateral_image_path):
    """Run the two-view classifier on one frontal/lateral image pair.

    Args:
        frontal_image_path: Path to the frontal-view image.
        lateral_image_path: Path to the lateral-view image.

    Returns:
        dict mapping each name in ``labels`` to its sigmoid score (float).

    Raises:
        ValueError: if the model's output width differs from ``len(labels)``.
    """
    frontal_image = _load_view(frontal_image_path)
    lateral_image = _load_view(lateral_image_path)

    with torch.no_grad():
        output = model(frontal_image, lateral_image)

    # Drop only the batch dimension so a single-label model still yields a 1-D array.
    predictions = output.cpu().numpy().squeeze(0)
    if len(predictions) != len(labels):
        raise ValueError(
            f"Model produced {len(predictions)} scores but {len(labels)} labels are defined"
        )
    return {label: float(score) for label, score in zip(labels, predictions)}


# Example usage
# frontal_image_path = 'text_reports/testing/IMAGES/p13952691_s54551451.jpg'
# lateral_image_path = 'text_reports/testing/IMAGES/p13952691_s54551451.jpg'
# pred_scores = predict_image(frontal_image_path, lateral_image_path)

# pred_scores
# Print the scores for each label
# for label, score in pred_scores.items():
#     print(f'{label}: {score:.4f}')

def get_top_classifications(classifications: dict, top_n: int = 3, threshold: float = 0.50) -> dict:
    """Return the ``top_n`` highest-scoring labels whose score exceeds ``threshold``.

    Args:
        classifications: Mapping of label -> score.
        top_n: Maximum number of labels to keep.
        threshold: Scores must be strictly greater than this to be kept.

    Returns:
        dict of the kept labels, ordered by descending score (ties keep their
        original order).
    """
    passing = [(label, score) for label, score in classifications.items() if score > threshold]
    passing.sort(key=itemgetter(1), reverse=True)
    return dict(passing[:top_n])

## RAG

In [None]:
# Embedding model and the persisted Chroma vector store it indexes.
oembed = OllamaEmbeddings(model="mxbai-embed-large")
vectordb = Chroma(embedding_function=oembed, persist_directory="text_reports/all_embed_db")
retriever = vectordb.as_retriever()

# Chat model served by the local Ollama instance.
ollama_llm = ChatOllama(model="llama3.2", base_url="http://localhost:11434")

In [None]:
# Reference report shown to the LLM as a formatting example.
sample_report = """
EXAMINATION: CHEST (PA AND LAT)

INDICATION: ___F with new onset ascites

TECHNIQUE: Chest PA and lateral

COMPARISON: None.

FINDINGS: 
There is no focal consolidation, pleural effusion or pneumothorax. Bilateral nodular opacities that most likely represent nipple shadows. The cardiomediastinal silhouette is normal. Clips project over the left lung, potentially within the breast. The imaged upper abdomen is unremarkable. Chronic deformity of the posterior left sixth and seventh ribs are noted.

IMPRESSION: 
No acute cardiopulmonary process.
"""

# Section skeleton the LLM must fill in. It is substituted into the prompt as a
# value, so its {placeholders} are not treated as prompt variables.
report_template = """
1. EXAMINATION: {examination}
2. INDICATION: {indication}
3. TECHNIQUE: {technique}
4. COMPARISON: {comparison}
5. FINDINGS: {findings}
6. IMPRESSION: {impression}
"""

# Report-generation prompt. Variables: classifications, image_caption, context,
# query, report_template, sample_report. The image caption is included so the
# LLM sees it directly, not only through the documents it helped retrieve
# (callers already pass "image_caption" to the chain).
prompt_template = ChatPromptTemplate.from_template("""
    You are an expert radiologist. Generate a concise and accurate radiology report based on the provided context, classification, and image caption. Follow these strict guidelines:
    
    1. Use ONLY the given template structure. Do not add any sections or text outside this structure.
    2. Keep EXAMINATION, INDICATION, TECHNIQUE, and COMPARISON brief, as in the sample report.
    3. FINDINGS should be detailed but focused, describing only relevant observations.
    4. IMPRESSION should summarize key findings and their clinical significance concisely.
    5. Do not include any introductory or concluding statements, notifications, or recommendations unless explicitly part of the findings or impression.
    6. Use appropriate medical terminology and maintain a professional tone.
    7. Base your report solely on the given context, classification, and image caption.
    
    Classification: {classifications}
    Image Caption: {image_caption}
    Context: {context}
    Query: {query}

    Generate the report now, strictly following the template below:
    {report_template}

    Here's a sample report for reference:
    {sample_report}
""")

In [None]:
# Chain that formats prompt_template and sends it to the local Ollama chat model.
llm_chain = LLMChain(prompt=prompt_template, llm=ollama_llm)

In [None]:
def embed_query_and_caption(query, image_caption):
    """Embed the query and the image caption and return their element-wise mean.

    Both texts are embedded with ``oembed`` (query first, then caption).
    Averaging is one simple way to fuse them into a single retrieval vector;
    other combinations could be tried.

    Returns:
        The averaged embedding as a plain Python list.
    """
    query_vec, caption_vec = (np.array(oembed.embed_query(text)) for text in (query, image_caption))
    fused = (query_vec + caption_vec) / 2
    return fused.tolist()
    
def retrieve_similar_documents(combined_embedding, k=5):
    """Return the ``k`` documents in ``vectordb`` most similar to ``combined_embedding``."""
    search_by_vector = vectordb.similarity_search_by_vector
    return search_by_vector(combined_embedding, k=k)

def calculate_cosine_similarity(vec1, vec2):
    """Return the cosine similarity between two vectors.

    Args:
        vec1: First vector (list or numpy array).
        vec2: Second vector (list or numpy array). It is transposed before the
            dot product, so ``(1, d)`` row vectors also work; in that case the
            result is a ``(1, 1)`` array.

    Returns:
        ``dot(vec1, vec2.T) / (||vec1|| * ||vec2||)``. When either vector has
        zero norm the similarity is undefined, and 0 (same shape as the dot
        product) is returned instead of dividing by zero.
    """
    # np.asarray lets plain Python lists (e.g. embedding outputs) be passed;
    # the original `.T` attribute access failed on lists.
    vec1 = np.asarray(vec1, dtype=float)
    vec2 = np.asarray(vec2, dtype=float)
    dot_product = np.dot(vec1, vec2.T)
    denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denom == 0:
        return dot_product * 0.0  # zero with the same shape as the dot product
    return dot_product / denom

In [None]:
def generate_medical_report(classifications, image_caption):
    """Generate a radiology report with retrieval-augmented generation.

    A query built from the classification labels is embedded together with
    the image caption. The fused vector retrieves similar reports from
    ``vectordb``, and their text is passed as context to ``llm_chain``. The
    retrieved documents and the generated report are printed.

    Args:
        classifications: Iterable of label names. A dict such as the one from
            ``get_top_classifications`` works; its keys are used. May be
            empty when no label cleared the threshold.
        image_caption: Free-text caption describing the image.

    Returns:
        The generated report text.
    """
    labels_text = ", ".join(classifications)
    if labels_text:
        focus = f"the presence of {labels_text}"
    else:
        # Nothing cleared the threshold. Without this branch the query would
        # read "presence of . " and the prompt would show an empty classification.
        focus = "any abnormalities"
        labels_text = "No classification above threshold"

    query = (
        f"Examine the chest X-ray for {focus}. "
        "Provide a comprehensive analysis of the observed radiographic features and their clinical implications."
    )

    # Fuse query + caption embeddings and retrieve similar reports as context.
    combined_embedding = embed_query_and_caption(query, image_caption)
    similar_docs = retrieve_similar_documents(combined_embedding)
    context = "\n".join(doc.page_content for doc in similar_docs)

    print("Retrieved Documents:")
    for i, doc in enumerate(similar_docs, 1):
        print(f"Document {i}:\n{doc.page_content}\n")

    result = llm_chain.run({
        "context": context,
        "query": query,
        "classifications": labels_text,
        "sample_report": sample_report,
        "image_caption": image_caption,
        "report_template": report_template,
    })

    parsed_result = StrOutputParser().parse(result)

    print("Generated Medical Report:")
    print(parsed_result)

    return parsed_result

## Batch report generation (Llama 3.2)

In [None]:
def _load_captions(caption_csv):
    """Return a mapping of frontal image file name -> caption from *caption_csv*.

    Expects the columns ``frontal_image`` and ``final_output``.
    """
    # newline='' is what the csv module expects for correct quoted-field handling.
    with open(caption_csv, 'r', encoding='utf-8', newline='') as csvfile:
        return {row['frontal_image']: row['final_output'] for row in csv.DictReader(csvfile)}


def _read_original_report(report_dir, image_name):
    """Return the ground-truth report ``<report_dir>/<image stem>.txt``, or '' if missing."""
    report_filename = image_name.rsplit('.', 1)[0] + '.txt'
    try:
        with open(os.path.join(report_dir, report_filename), 'r', encoding='utf-8') as file:
            return file.read()
    except FileNotFoundError:
        print(f"Warning: Original report not found for {report_filename}")
        return ''


def process_directory(frontal_image_dir, lateral_image_dir, caption_csv, report_csv, report_dir):
    """Generate a report for every captioned frontal image and write results to a CSV.

    For each file in *frontal_image_dir* that has a caption in *caption_csv*:
    1. The frontal image and the same-named file in *lateral_image_dir* are classified.
    2. A report is generated with ``generate_medical_report``.
    3. A row ``[Image, Top Classifications, Generated Report, Original Report]``
       is written to *report_csv*.

    The original report is read from ``<report_dir>/<image stem>.txt`` and is
    left empty if missing. Images without a caption are skipped, and so are
    images with no matching lateral file (with a warning).
    """
    captions = _load_captions(caption_csv)

    with open(report_csv, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['Image', 'Top Classifications', 'Generated Report', 'Original Report'])

        # os.listdir order is arbitrary; sort for a reproducible row order.
        for image_name in sorted(os.listdir(frontal_image_dir)):
            if image_name not in captions:  # only captioned images are processed
                continue

            frontal_image_path = os.path.join(frontal_image_dir, image_name)
            lateral_image_path = os.path.join(lateral_image_dir, image_name)
            if not os.path.isfile(lateral_image_path):
                # Skip instead of aborting the whole batch on one missing view.
                print(f"Warning: Lateral image not found for {image_name}; skipping")
                continue

            image_caption = captions[image_name]

            pred_score = predict_image(frontal_image_path, lateral_image_path)
            for label, score in pred_score.items():
                print(f'{label}: {score:.4f}')

            classifications = get_top_classifications(pred_score)
            print(f"\nTop classification: {classifications}. Image Caption: {image_caption}\n")

            generated_report = generate_medical_report(classifications, image_caption)
            original_report = _read_original_report(report_dir, image_name)

            writer.writerow([
                image_name,
                ', '.join(classifications.keys()),
                generated_report,
                original_report,
            ])

            print(f"\nProcessed: {image_name}")

    print(f"Processing complete. Results saved to {report_csv}")

# Input/output locations; replace the placeholders before running.
frontal_image_dir = 'path_to_directory'  # frontal-view images
lateral_image_dir = 'path_to_directory'  # lateral-view images, same file names as frontal
caption_csv = 'path_to_directory'        # CSV with 'frontal_image' and 'final_output' columns
report_dir = 'path_to_directory'         # ground-truth reports named <image stem>.txt
report_csv = 'path_to_directory'         # output CSV

process_directory(
    frontal_image_dir=frontal_image_dir,
    lateral_image_dir=lateral_image_dir,
    caption_csv=caption_csv,
    report_csv=report_csv,
    report_dir=report_dir,
)


# # Example usage
# classifications = ["Atelectasis", "Cardiomegaly", "Pneumonia"]

# image_caption = "The patient is status post median sternotomy and cabg. The heart is moderately enlarged. The aorta is tortuous. There is mild pulmonary vascular congestion. Small bilateral pleural effusions are noted. Streaky opacities in the lung bases likely reflect atelectasis. No pneumothorax is identified. There are no acute osseous abnormalities."
# generate_medical_report(classifications, image_caption)