In [None]:
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity

# RAG
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.chains import LLMChain, RetrievalQA
from langchain.embeddings import OllamaEmbeddings
from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOllama
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from adjustText import adjust_text

import os
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import densenet121
import torch.nn as nn
import torchvision.models as models
import cv2
import numpy as np
from PIL import Image
import csv
import clip

## Classification Results

In [None]:
# Finding names reported by the classifier. Position i in this list is paired
# with output unit i of the model, so the order must match the training order.
labels = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Enlarged Cardiomediastinum',
    'Fracture',
    'Lung Lesion',
    'Lung Opacity',
    'No Finding',
    'Pleural Effusion',
    'Pleural Other',
    'Pneumonia',
    'Pneumothorax',
    'Support Devices',
]

class ResNet50(nn.Module):
    """ResNet-50 backbone with its final layer replaced by a Linear -> Sigmoid head.

    Each of the ``out_size`` outputs is an independent score in [0, 1]
    (multi-label style), not a softmax distribution.

    Args:
        out_size: Number of output units produced by the replacement head.
    """

    def __init__(self, out_size=14):
        super().__init__()
        # Start from the torchvision IMAGENET1K_V1 weights.
        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Replace the stock classifier. The Sequential layout (index 0 = Linear,
        # index 1 = Sigmoid) is part of the state-dict key names.
        backbone.fc = nn.Sequential(
            nn.Linear(backbone.fc.in_features, out_size),
            nn.Sigmoid(),
        )
        # The attribute name `resnet50` is the state-dict key prefix; keep it.
        self.resnet50 = backbone

    def forward(self, x):
        """Return the sigmoid scores of the backbone for the input batch ``x``."""
        return self.resnet50(x)
        
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ResNet50().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# saved on a GPU also loads on a CPU-only machine (without it, torch.load tries
# to restore the tensors onto the device they were saved from and fails).
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(
    torch.load('best_save_models/lateral_best_model.pth', map_location=device)
)
# Switch to inference behaviour (e.g. batch-norm uses its running statistics).
model.eval()

# Contrast-limited adaptive histogram equalisation as a torchvision-style transform
class CLAHETransform:
    """Callable transform that applies OpenCV CLAHE to an image.

    Args:
        clip_limit: Value forwarded to ``cv2.createCLAHE`` as ``clipLimit``.
        tile_grid_size: Value forwarded to ``cv2.createCLAHE`` as ``tileGridSize``.
    """

    def __init__(self, clip_limit=0.10, tile_grid_size=(8, 8)):
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size
        self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)

    def __call__(self, img):
        """Equalise ``img`` (PIL image or numpy array) and return a PIL image.

        Three-dimensional arrays are treated as RGB: only the L channel of the
        LAB representation is equalised, leaving the colour channels untouched.
        Any other array is passed to CLAHE directly.
        """
        pixels = np.array(img) if isinstance(img, Image.Image) else img

        if pixels.ndim != 3:
            equalised = self.clahe.apply(pixels)
        else:
            lab = cv2.cvtColor(pixels, cv2.COLOR_RGB2LAB)
            lightness, chroma_a, chroma_b = cv2.split(lab)
            lab = cv2.merge((self.clahe.apply(lightness), chroma_a, chroma_b))
            equalised = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

        return Image.fromarray(equalised.astype('uint8'))

# Preprocessing pipeline applied to every image before it is fed to `model`.
# Steps run in order: resize -> CLAHE contrast enhancement -> centre crop ->
# tensor conversion -> per-channel normalisation.
# NOTE(review): described as identical to the validation transform used during
# training — confirm against the training code, since a mismatch would silently
# degrade predictions.
transform = transforms.Compose([
    transforms.Resize(256),
    # Overrides CLAHETransform's default clip_limit (0.10) with 0.35.
    CLAHETransform(clip_limit=0.35, tile_grid_size=(8, 8)),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    # Mean/std values commonly used with ImageNet-pretrained backbones.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Function to predict an image
def predict_image(image_path):
    """Run the classifier on one image and return a score per label.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        dict mapping each entry of ``labels`` to its score as a plain Python
        ``float`` (sigmoid output of the model, so in [0, 1]).

    Raises:
        ValueError: If the number of model outputs differs from ``len(labels)``.
    """
    # Open inside a context manager so the file handle is released promptly;
    # convert() returns an independent, fully loaded RGB copy.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    # Add the batch dimension and move the tensor to the model's device.
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(batch)

    # squeeze(0) removes only the batch axis; a bare squeeze() would also
    # collapse any other size-1 axis.
    predictions = output.cpu().numpy().squeeze(0)

    if len(predictions) != len(labels):
        raise ValueError(
            f"Model produced {len(predictions)} outputs but {len(labels)} labels are defined"
        )

    # Convert NumPy scalars to Python floats so the scores print and serialise
    # cleanly (e.g. when the dict is written to CSV).
    return {label: float(score) for label, score in zip(labels, predictions)}

# # Example usage
# image_path = 'text_reports/testing/IMAGES/p13952691_s54551451.jpg' # Real labels for this image {Consolidation, Pleural Effusion, Supporting Device}
# pred_scores = predict_image(image_path)
# # pred_scores
# # Print the scores for each label
# for label, score in pred_scores.items():
#     print(f'{label}: {score:.4f}')

def get_top_classifications(classifications: dict, top_n: int = 3, threshold: float = 0.50) -> dict:
    """Select the highest-scoring labels whose score exceeds ``threshold``.

    Args:
        classifications: Mapping of label -> score.
        top_n: Maximum number of labels to keep.
        threshold: Scores must be strictly greater than this value to qualify.

    Returns:
        dict of at most ``top_n`` label -> score pairs, ordered from highest to
        lowest score. Empty when no score exceeds the threshold.
    """
    # Keep only the (label, score) pairs that clear the threshold.
    candidates = [
        (label, score)
        for label, score in classifications.items()
        if score > threshold
    ]

    # Highest score first; ties keep their original relative order.
    candidates.sort(key=lambda pair: pair[1], reverse=True)

    return dict(candidates[:top_n])

## RAG 

In [None]:
# Set up the embedding model and vector database.
# Embeddings are produced by the Ollama-served "mxbai-embed-large" model; the
# same embedding function is handed to Chroma for the persisted collection.
oembed = OllamaEmbeddings(model="mxbai-embed-large")
# Opens the Chroma store persisted under text_reports/all_embed_db.
# NOTE(review): presumably this directory was populated beforehand with report
# embeddings from the same model — confirm, since nothing here builds it.
vectordb = Chroma(
    persist_directory="text_reports/all_embed_db",
    embedding_function=oembed
)

# Default retriever over the store (not referenced by the functions below, which
# query `vectordb` directly by vector).
retriever = vectordb.as_retriever()

# Chat model served by a local Ollama instance on its default port.
ollama_llm = ChatOllama(base_url="http://localhost:11434", model="llava")

In [None]:
# Example report injected verbatim into the prompt as a style/length reference
# (passed as the `sample_report` prompt variable).
sample_report = """
EXAMINATION: CHEST (PA AND LAT)

INDICATION: ___F with new onset ascites

TECHNIQUE: Chest PA and lateral

COMPARISON: None.

FINDINGS: 
There is no focal consolidation, pleural effusion or pneumothorax. Bilateral nodular opacities that most likely represent nipple shadows. The cardiomediastinal silhouette is normal. Clips project over the left lung, potentially within the breast. The imaged upper abdomen is unremarkable. Chronic deformity of the posterior left sixth and seventh ribs are noted.

IMPRESSION: 
No acute cardiopulmonary process.
"""

# Skeleton of the six report sections the model is asked to fill in.
# This string is passed as the *value* of the `report_template` prompt variable,
# so the {examination}, {indication}, ... markers are not substituted by the
# prompt template; the LLM sees them literally as slots to complete.
report_template = """
1. EXAMINATION: {examination}
2. INDICATION: {indication}
3. TECHNIQUE: {technique}
4. COMPARISON: {comparison}
5. FINDINGS: {findings}
6. IMPRESSION: {impression}
"""

# Set up the prompt template.
# Input variables: classifications, image_caption, context, query,
# report_template, sample_report. The image caption is passed by
# generate_medical_report; it needs a placeholder here, otherwise the caption
# is accepted by the chain but never shown to the model.
prompt_template = ChatPromptTemplate.from_template("""
    You are an expert radiologist. Generate a concise and accurate radiology report based on the provided context and classification. Follow these strict guidelines:
    
    1. Use ONLY the given template structure. Do not add any sections or text outside this structure.
    2. Keep EXAMINATION, INDICATION, TECHNIQUE, and COMPARISON brief, as in the sample report.
    3. FINDINGS should be detailed but focused, describing only relevant observations.
    4. IMPRESSION should summarize key findings and their clinical significance concisely.
    5. Do not include any introductory or concluding statements, notifications, or recommendations unless explicitly part of the findings or impression.
    6. Use appropriate medical terminology and maintain a professional tone.
    7. Base your report solely on the given context, classification, and image caption.
    
    Classification: {classifications}
    Image caption: {image_caption}
    Context: {context}
    Query: {query}

    Generate the report now, strictly following the template below:
    {report_template}

    Here's a sample report for reference:
    {sample_report}
""")

In [None]:
# Create LLMChain with the prompt template and Ollama LLM.
# The chain formats `prompt_template` with the supplied variables and sends the
# result to `ollama_llm`; generate_medical_report calls it via `.run(...)`.
llm_chain = LLMChain(prompt=prompt_template, llm=ollama_llm)

In [None]:
# Load the CLIP model and preprocessing function.
# NOTE: this rebinds the module-level `device` (earlier set to a torch.device
# for the classifier) to a plain string. predict_image reads `device` at call
# time, so after this cell runs it uses this value as well.
device = "cuda" if torch.cuda.is_available() else "cpu"
# `preprocess` is the image transform returned alongside the ViT-B/32 model.
clip_model, preprocess = clip.load("ViT-B/32", device=device)

In [None]:
# Function to build one retrieval vector from the query text, image, and caption
def embed_query_image_and_caption(query, image_path, image_caption, target_dim=1024):
    """Embed the query, the image, and its caption, and average them into one vector.

    The query text and the image are embedded with CLIP; the caption is embedded
    with the Ollama embedding model. All three are brought to ``target_dim``
    columns with ``pad_embedding`` (zero-padded or truncated) and averaged
    element-wise.

    NOTE(review): the CLIP and Ollama vectors come from different embedding
    spaces and are not normalised before averaging — confirm this combination
    gives meaningful neighbours in the vector store.

    Args:
        query: Text query to embed with CLIP.
        image_path: Path to the image to embed with CLIP.
        image_caption: Caption text to embed with the Ollama embedding model.
        target_dim: Length of the returned vector; should equal the dimension of
            the vectors stored in the vector database.

    Returns:
        list of ``target_dim`` floats.
    """
    # Embed the query text using CLIP
    text = clip.tokenize([query]).to(device)
    with torch.no_grad():
        query_embedding = clip_model.encode_text(text).cpu().numpy()

    # Embed the image using CLIP; the context manager releases the file handle
    # once preprocessing has read the pixels.
    with Image.open(image_path) as raw_image:
        image = preprocess(raw_image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_embedding = clip_model.encode_image(image).cpu().numpy()

    # Embed the image caption using Ollama (returned as a flat list -> one row)
    caption_embedding = np.array(oembed.embed_query(image_caption)).reshape(1, -1)

    # Bring all three embeddings to the same width. Padding the caption as well
    # makes `target_dim` effective for any value; it is a no-op when the caption
    # embedding already has `target_dim` columns.
    query_embedding_padded = pad_embedding(query_embedding, target_dim)
    image_embedding_padded = pad_embedding(image_embedding, target_dim)
    caption_embedding_padded = pad_embedding(caption_embedding, target_dim)

    # Combine embeddings (you can experiment with different combination methods)
    combined_embedding = (
        query_embedding_padded + image_embedding_padded + caption_embedding_padded
    ) / 3

    return combined_embedding.flatten().tolist()  # Convert to list

In [None]:
# Resize the column dimension of a 2D embedding array
def pad_embedding(embedding, target_dim=1024):
    """Return ``embedding`` with exactly ``target_dim`` columns.

    Args:
        embedding: 2D numpy array of shape (rows, width).
        target_dim: Desired number of columns.

    Returns:
        The array right-padded with zeros when it is narrower than
        ``target_dim``, truncated to the first ``target_dim`` columns when it is
        wider, or the input object itself when the width already matches.
    """
    rows, width = embedding.shape[0], embedding.shape[1]
    shortfall = target_dim - width

    if shortfall > 0:
        # Too narrow: append zero columns on the right.
        return np.hstack((embedding, np.zeros((rows, shortfall))))
    if shortfall < 0:
        # Too wide: keep only the leading columns.
        return embedding[:, :target_dim]
    return embedding

# Nearest-neighbour lookup in the vector store
def retrieve_similar_documents(combined_embedding, k=5):
    """Return the ``k`` documents in ``vectordb`` closest to ``combined_embedding``.

    Args:
        combined_embedding: Query vector as a list of floats.
        k: Number of documents to retrieve.
    """
    matches = vectordb.similarity_search_by_vector(combined_embedding, k=k)
    return matches

# Cosine similarity between one query vector and a set of document vectors
def calculate_cosine_similarity(query_embedding, doc_embeddings):
    """Compute the cosine similarity of a query against each document embedding.

    Args:
        query_embedding: Single embedding vector (any array-like); it is
            reshaped to one row.
        doc_embeddings: Array-like of document embeddings, one per row.

    Returns:
        1D numpy array with one similarity value per document.
    """
    query_row = np.array(query_embedding).reshape(1, -1)
    doc_matrix = np.array(doc_embeddings)
    similarity_matrix = cosine_similarity(query_row, doc_matrix)
    return similarity_matrix.flatten()

In [None]:
# Function to generate the medical report
def generate_medical_report(classifications, image_path, image_caption):
    """Generate a radiology report for one image via retrieval-augmented generation.

    Builds a text query from the classifier labels, embeds it together with the
    image and its caption, retrieves similar documents from the vector store,
    and asks the LLM chain to write a report using those documents as context.
    The retrieved documents and the generated report are printed.

    Args:
        classifications: Iterable of label names (e.g. the dict returned by
            ``get_top_classifications``; only its keys are used). May be empty.
        image_path: Path to the chest X-ray image.
        image_caption: Caption text describing the image.

    Returns:
        The generated report as a string.
    """
    label_text = ", ".join(classifications)

    # Construct a more detailed query. When no label cleared the threshold the
    # label list is empty; say so explicitly instead of emitting the malformed
    # sentence "for the presence of .".
    if label_text:
        focus = f"Examine the chest X-ray for the presence of {label_text}. "
    else:
        focus = "Examine the chest X-ray; the classifier reported no findings above its threshold. "
    query = (
        focus
        + "Provide a comprehensive analysis of the observed radiographic features and their clinical implications."
    )

    # Combine query, image, and image caption embeddings
    combined_embedding = embed_query_image_and_caption(query, image_path, image_caption)

    # Retrieve similar documents using the combined embedding
    similar_docs = retrieve_similar_documents(combined_embedding)

    # Extract the content from similar documents
    context = "\n".join(doc.page_content for doc in similar_docs)

    # Show the retrieved documents
    print("Retrieved Documents:")
    for i, doc in enumerate(similar_docs, 1):
        print(f"Document {i}:\n{doc.page_content}\n")

    # Run the LLM chain with context, query, classifications, and image caption
    result = llm_chain.run({
        "context": context,
        "query": query,
        "classifications": label_text or "None above threshold",
        "sample_report": sample_report,
        "image_caption": image_caption,
        "report_template": report_template
    })

    # Parse the output for clean presentation
    parsed_result = StrOutputParser().parse(result)

    print("Generated Medical Report:")
    print(parsed_result)

    return parsed_result

## LLAVA

In [None]:
def process_directory(image_dir, caption_csv, report_csv, report_dir):
    """Classify every captioned image in a directory and write generated reports to CSV.

    For each file in ``image_dir`` that has a caption, the function predicts
    label scores, selects the top labels, generates a report, looks up the
    matching original report, and appends one row to ``report_csv``.

    Args:
        image_dir: Directory containing the image files.
        caption_csv: Path to a CSV file with an ``Image`` column (image file
            name) and a ``final_output`` column (caption text).
        report_csv: Path of the CSV file to create; overwritten if it exists.
        report_dir: Directory holding the original reports, named like the image
            but with a ``.txt`` extension.

    Output columns: image_name, classifications, generated_report, original_report.
    Images without a caption are skipped; a missing original report is logged
    and written as an empty string.
    """
    # Load the captions: image name as key, final_output as value
    captions = {}
    with open(caption_csv, 'r', encoding='utf-8') as caption_file:
        for row in csv.DictReader(caption_file):
            captions[row['Image']] = row['final_output']

    # Prepare the output CSV
    with open(report_csv, 'w', newline='', encoding='utf-8') as out_file:
        writer = csv.writer(out_file)
        writer.writerow(['image_name', 'classifications', 'generated_report', 'original_report'])

        # os.listdir returns entries in arbitrary order; sort so that runs are
        # reproducible and the output rows have a stable order.
        for image_name in sorted(os.listdir(image_dir)):
            if image_name not in captions:  # Ensure the image has a caption
                continue

            image_path = os.path.join(image_dir, image_name)
            image_caption = captions[image_name]

            # Get predictions and classifications
            pred_score = predict_image(image_path)
            for label, score in pred_score.items():
                print(f'{label}: {score:.4f}')

            classifications = get_top_classifications(pred_score)

            print(f"\nTop classification: {classifications}. Image Caption: {image_caption}\n")

            # Generate a medical report
            generated_report = generate_medical_report(classifications, image_path, image_caption)

            # Retrieve the original report
            report_filename = image_name.rsplit('.', 1)[0] + '.txt'
            original_report_path = os.path.join(report_dir, report_filename)
            original_report = ''
            try:
                # Explicit encoding, consistent with the other files opened here.
                with open(original_report_path, 'r', encoding='utf-8') as report_file:
                    original_report = report_file.read()
            except FileNotFoundError:
                print(f"Warning: Original report not found for {report_filename}")

            # Save results to the CSV
            writer.writerow([
                image_name,
                classifications,
                generated_report,
                original_report
            ])
            # Each row is expensive to produce (model + LLM call); flush so that
            # completed rows survive a failure later in the run.
            out_file.flush()

            print(f"\nProcessed: {image_name}")

    print(f"Processing complete. Results saved to {report_csv}")

# Paths — placeholders that must be replaced before running this cell.
# image_dir and report_dir are directories; caption_csv and report_csv are CSV
# *file* paths (report_csv is created/overwritten by process_directory).
image_dir = 'path_to_directory'
caption_csv = 'path_to_directory'
report_dir = 'path_to_directory'
report_csv = 'path_to_directory'

# Process directory: runs classification + report generation for every
# captioned image and writes the results to report_csv.
process_directory(image_dir, caption_csv, report_csv, report_dir)