In [None]:
import os
import pandas as pd
from PIL import Image
import skimage.io
import torchxrayvision as xrv
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models import densenet121
import torchvision.models as models
import torchvision.models as models
import cv2

from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser
from langchain.load import dumps, loads
from operator import itemgetter
from langchain.embeddings import OllamaEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain_community.llms import Ollama
from langchain import hub
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOllama
from langchain_core.runnables import RunnablePassthrough
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
from adjustText import adjust_text

### Classification: Best Model Inference

In [None]:
# Label names as used during training; position i names the i-th model output score.
labels = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Enlarged Cardiomediastinum',
    'Fracture',
    'Lung Lesion',
    'Lung Opacity',
    'No Finding',
    'Pleural Effusion',
    'Pleural Other',
    'Pneumonia',
    'Pneumothorax',
    'Support Devices',
]

class DenseNetModel(nn.Module):
    """DenseNet-121 whose classifier is replaced by a linear layer followed by a sigmoid.

    Args:
        out_size: Number of output units of the replacement classifier (default 14).
    """

    def __init__(self, out_size=14):
        super().__init__()
        backbone = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        head_inputs = backbone.classifier.in_features
        # Swap the stock classifier for a sigmoid-activated head of the requested width.
        backbone.classifier = nn.Sequential(
            nn.Linear(head_inputs, out_size),
            nn.Sigmoid(),
        )
        # The attribute name prefixes every state_dict key, so it must stay `densenet121`.
        self.densenet121 = backbone

    def forward(self, x):
        """Pass ``x`` through the modified DenseNet-121 and return its output."""
        return self.densenet121(x)
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DenseNetModel().to(device)
# map_location remaps the checkpoint's tensors onto `device`; without it a
# checkpoint saved from a GPU cannot be loaded on a CPU-only machine.
model.load_state_dict(torch.load('path_to_directory', map_location=device))
# Evaluation mode: dropout / batch-norm layers use their inference behaviour.
model.eval()

# CLAHE transform class
class CLAHETransform:
    """Callable transform that applies OpenCV CLAHE to an image and returns a PIL image.

    Args:
        clip_limit: Value forwarded to ``cv2.createCLAHE`` as ``clipLimit``.
        tile_grid_size: Value forwarded to ``cv2.createCLAHE`` as ``tileGridSize``.
    """

    def __init__(self, clip_limit=0.10, tile_grid_size=(8, 8)):
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size
        # Build the CLAHE operator once; it is reused for every image.
        self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)

    def __call__(self, img):
        """Equalize ``img`` (PIL image or array) and return the result as a uint8 PIL image."""
        pixels = np.array(img) if isinstance(img, Image.Image) else img

        if pixels.ndim != 3:
            # Non-3-D input: equalize the array directly.
            equalized = self.clahe.apply(pixels)
        else:
            # 3-D input is treated as RGB: equalize only the L channel in LAB
            # space, leaving the a/b channels untouched, then convert back.
            lab = cv2.cvtColor(pixels, cv2.COLOR_RGB2LAB)
            lightness, chroma_a, chroma_b = cv2.split(lab)
            lightness = self.clahe.apply(lightness)
            lab = cv2.merge((lightness, chroma_a, chroma_b))
            equalized = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

        return Image.fromarray(equalized.astype('uint8'))

# Define the validation transform (same as during training)
# Order of operations: resize -> CLAHE contrast enhancement -> 256x256 center
# crop -> tensor conversion -> per-channel normalization.
# NOTE(review): the mean/std values look like the usual ImageNet statistics —
# confirm they match what was used for training.
transform = transforms.Compose([
    transforms.Resize(256),
    CLAHETransform(clip_limit=0.35, tile_grid_size=(8, 8)),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Function to predict an image
def predict_image(image_path):
    """Run the classifier on one image file and return a score for every label.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        dict mapping ``labels[i]`` to the i-th model output score.
    """
    # Open inside a context manager so the file handle is always released;
    # convert() returns an independent RGB copy that outlives the file.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    # Preprocess, add the batch dimension and move to the model's device.
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(batch)

    # Remove only the batch axis. A bare squeeze() would also collapse a
    # single-output head to a 0-d array, which cannot be iterated or indexed.
    predictions = output.squeeze(0).cpu().numpy()

    return {labels[i]: score for i, score in enumerate(predictions)}

# Example usage
# image_path = r"'path_to_directory'" # Real labels for this image {Consolidation, Pleural Effusion, Supporting Device}
# pred_scores = predict_image(image_path)
# # pred_scores
# # Print the scores for each label
# for label, score in pred_scores.items():
#     print(f'{label}: {score:.4f}')

In [None]:
def get_top_classifications(classifications: dict, top_n: int = 3, threshold: float = 0.50) -> dict:
    """Return the highest-scoring labels whose score is strictly above ``threshold``.

    Args:
        classifications: Mapping of label -> score.
        top_n: Maximum number of entries to keep.
        threshold: Scores must be greater than this value to be kept.

    Returns:
        dict of at most ``top_n`` entries, ordered from highest to lowest score.
    """
    above_threshold = [
        (label, score)
        for label, score in classifications.items()
        if score > threshold
    ]
    # Stable sort: labels with equal scores keep their original relative order.
    above_threshold.sort(key=lambda pair: pair[1], reverse=True)
    return dict(above_threshold[:top_n])

In [None]:
# Example radiology report; it is inserted into the generation prompt as a
# reference for the expected level of detail and tone.
sample_report = """
EXAMINATION: CHEST (PA AND LAT)

INDICATION: ___ with new onset ascites

TECHNIQUE: Chest PA and lateral

COMPARISON: None.

FINDINGS: 
There is no focal consolidation, pleural effusion or pneumothorax. Bilateral nodular opacities that most likely represent nipple shadows. The cardiomediastinal silhouette is normal. Clips project over the left lung, potentially within the breast. The imaged upper abdomen is unremarkable. Chronic deformity of the posterior left sixth and seventh ribs are noted.

IMPRESSION: 
No acute cardiopulmonary process.
"""

# Section skeleton the model is asked to follow. In the code visible here the
# {placeholders} are never filled in with str.format; the text is handed to the
# prompt as-is.
report_template = """
1. EXAMINATION: {examination}
2. INDICATION: {indication}
3. TECHNIQUE: {technique}
4. COMPARISON: {comparison}
5. FINDINGS: {findings}
6. IMPRESSION: {impression}
"""

# def generate_prompt(classification_list):
#     classification = ", ".join([f"{k}: {v:.2f}" for k, v in classification_list.items()])
#     return f"""
#     You are an expert radiologist. Based on the following context, query, and sample report, generate a detailed radiology report.

#     classification: {classification}
    
#     Here's a sample report for reference:
#     {sample_report}

#     Please structure your report using the following template, maintaining a similar level of detail and professional tone as the sample report:
#     {report_template}
    
#     Important guidelines:
#     1. Ensure that each section of the report is filled with relevant and detailed information.
#     2. Use clear, concise medical terminology appropriate for radiology reports.
#     3. In the FINDINGS section, describe observations systematically, from most to least significant.
#     4. In the IMPRESSION section, summarize the key findings and their clinical significance.
#     5. Maintain a professional and objective tone throughout the report.
#     6. Adapt the level of detail to match the complexity of the examination and findings.
#     """

In [None]:
def reciprocal_rank_fusion(results, k=5):
    """Fuse several ranked document lists into one list ordered by summed reciprocal rank.

    Each occurrence of a document at zero-based position ``rank`` contributes
    ``1 / (rank + k)`` to that document's score. Documents are identified by
    their serialized form, so identical documents from different lists are merged.

    Args:
        results: Iterable of ranked document lists.
        k: Offset added to the zero-based rank in the denominator.

    Returns:
        List of ``(document, fused_score)`` tuples, highest score first.
    """
    score_by_doc = {}
    for ranked_docs in results:
        for position, document in enumerate(ranked_docs):
            key = dumps(document)
            score_by_doc[key] = score_by_doc.setdefault(key, 0) + 1 / (position + k)

    ranked = sorted(score_by_doc.items(), key=lambda entry: entry[1], reverse=True)
    return [(loads(key), score) for key, score in ranked]

In [None]:
def calculate_cosine_similarity(query_embedding, doc_embeddings):
    """Return the cosine similarity between one query vector and each document vector.

    Args:
        query_embedding: Query vector; it is reshaped to a single row.
        doc_embeddings: Sequence of document vectors, one per row.

    Returns:
        1-D array with one similarity value per document row.
    """
    query_row = np.array(query_embedding).reshape(1, -1)
    doc_matrix = np.array(doc_embeddings)
    # cosine_similarity yields a (1, n_docs) matrix; flatten it to n_docs values.
    similarity_matrix = cosine_similarity(query_row, doc_matrix)
    return similarity_matrix.flatten()

def plot_tsne(query_embedding, doc_embeddings, cosine_similarities):
    """Project the query and document embeddings to 2-D with t-SNE and plot them.

    Documents are drawn as circles coloured by their cosine similarity and
    labelled with that value; the query is drawn as a green star. The figure is
    displayed with ``plt.show()`` and nothing is returned.

    Args:
        query_embedding: Embedding vector of the query.
        doc_embeddings: Sequence of document embedding vectors.
        cosine_similarities: One similarity per document, in the order of ``doc_embeddings``.
    """
    # The query is stacked first so it can be separated again by position.
    stacked = np.vstack([query_embedding, doc_embeddings])
    projector = TSNE(
        n_components=2,
        random_state=42,
        # Capped at (number of points - 1).
        perplexity=min(5, len(stacked) - 1),
    )
    projected = projector.fit_transform(stacked)
    query_xy, docs_xy = projected[0], projected[1:]

    plt.figure(figsize=(12, 8))

    # Documents, colour-coded by cosine similarity.
    doc_points = plt.scatter(
        docs_xy[:, 0], docs_xy[:, 1],
        marker='o', c=cosine_similarities, cmap="coolwarm", edgecolor="k", s=100,
    )
    plt.colorbar(doc_points, label="Cosine Similarity")

    # The query gets its own, visually distinct marker.
    plt.scatter(query_xy[0], query_xy[1], marker='*', color='green', s=300, label='Query', edgecolor="k")

    # One similarity label per document point.
    annotations = [
        plt.text(
            x, y, f"Sim: {cosine_similarities[idx]:.2f}",
            ha='center', va='center', fontsize=8,
            bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.7, ec='none'),
        )
        for idx, (x, y) in enumerate(docs_xy)
    ]

    # Label for the query point, added last to the list handed to adjust_text.
    annotations.append(
        plt.text(
            query_xy[0], query_xy[1], "Query",
            ha='center', va='center', fontsize=10, fontweight='bold',
            bbox=dict(boxstyle='round,pad=0.5', fc='lightgreen', alpha=0.7, ec='none'),
        )
    )

    # Nudge the labels apart so they do not overlap.
    adjust_text(
        annotations,
        arrowprops=dict(arrowstyle='->', color='red', lw=0.5),
        expand_points=(1.2, 1.2),
        force_points=(0.1, 0.1),
    )

    # Titles, axes and final layout.
    plt.title('t-SNE Visualization of Query and Retrieved Document Embeddings')
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()
    plt.tight_layout()
    plt.autoscale()
    plt.margins(0.1)
    plt.show()

In [None]:
# Embedding client backed by the Ollama "mxbai-embed-large" model; used both by
# the vector store and for the similarity diagnostics in generate_report.
oembed = OllamaEmbeddings(model="mxbai-embed-large")

# Chroma vector store opened from the given persist directory.
# NOTE(review): presumably this directory was populated beforehand with the
# report embeddings — nothing in this notebook builds it.
vectordb = Chroma(
    persist_directory = "text_reports/all_embed_db",
    embedding_function = oembed
)

# Retriever over the vector store, created with the default as_retriever() settings.
retriever = vectordb.as_retriever()

In [None]:
# Chat model served by an Ollama instance at localhost:11434; used for both
# query expansion and report generation.
Ollamallm = ChatOllama(base_url="http://localhost:11434", model="gemma2:2b")

In [None]:
# Multi Query: Different Perspectives
query_template = """You are an AI language model assistant. Your task is to generate five 
different versions of the given user question to retrieve relevant documents from a vector 
database. By generating multiple perspectives on the user question, your goal is to help
the user overcome some of the limitations of the distance-based similarity search. 
Provide these alternative questions separated by newlines. Original question: {question}"""
prompt_rag_fusion = ChatPromptTemplate.from_template(query_template)


def _parse_generated_queries(text):
    """Split the LLM answer into one query per non-blank line, whitespace-stripped.

    Blank lines are dropped so the retriever is never invoked with an empty query.
    """
    return [line.strip() for line in text.split("\n") if line.strip()]


# Chain: question -> prompt -> LLM -> plain string -> list of query strings.
generate_queries = (
    prompt_rag_fusion
    | Ollamallm
    | StrOutputParser()
    | _parse_generated_queries
)

In [None]:
def generate_report(question, classifications):
    """Generate a radiology report via RAG-fusion retrieval, with similarity diagnostics.

    The question is expanded into several queries, documents are retrieved for
    each and fused into one ranking. The retrieved texts and the question are
    embedded, their cosine similarities are plotted (t-SNE) and printed, and
    the LLM is finally prompted with the retrieved text and the classification.

    Args:
        question: Retrieval question text.
        classifications: Mapping of label -> score; rendered into the prompt
            with two decimals.

    Returns:
        ``(generated_report, cosine_similarities)`` on success, or
        ``(None, None)`` when retrieval, embedding or similarity computation
        yields nothing.
    """
    classification = ", ".join(f"{k}: {v:.2f}" for k, v in classifications.items())

    retrieval_chain = generate_queries | retriever.map() | reciprocal_rank_fusion
    retrieved_docs = retrieval_chain.invoke({"question": question})

    if not retrieved_docs:
        print("No documents retrieved. Please check the retrieval process.")
        return None, None

    # Each fused result is a (document, score) pair; keep only the text.
    doc_contents = [doc.page_content for doc, _score in retrieved_docs]
    doc_embeddings = oembed.embed_documents(doc_contents)

    if not doc_embeddings:
        print("Failed to generate document embeddings.")
        return None, None

    # Embed the query
    query_embedding = oembed.embed_query(question)

    if query_embedding is None:
        print("Failed to generate query embedding.")
        return None, None

    # Calculate cosine similarities
    cosine_similarities = calculate_cosine_similarity(query_embedding, doc_embeddings)

    if cosine_similarities is None or len(cosine_similarities) == 0:
        print("Failed to calculate cosine similarities.")
        return None, None

    # Plotting is diagnostic only: a failure is reported but must not abort the report.
    try:
        plot_tsne(query_embedding, doc_embeddings, cosine_similarities)
    except Exception as e:
        print(f"Error in plotting t-SNE: {e}")

    # Print cosine similarities
    print("Cosine Similarities between Query and Retrieved Documents:")
    for i, sim in enumerate(cosine_similarities):
        print(f"Document {i+1}: Cosine Similarity = {sim:.4f}")

    prompt = ChatPromptTemplate.from_template("""
    You are an expert radiologist. Generate a concise and accurate radiology report based on the provided context and classification. Follow these strict guidelines:
    
    1. Use ONLY the given template structure. Do not add any sections or text outside this structure.
    2. Keep EXAMINATION, INDICATION, TECHNIQUE, and COMPARISON brief, as in the sample report.
    3. FINDINGS should be detailed but focused, describing only relevant observations.
    4. IMPRESSION should summarize key findings and their clinical significance concisely.
    5. Do not include any introductory or concluding statements, notifications, or recommendations unless explicitly part of the findings or impression.
    6. Use appropriate medical terminology and maintain a professional tone.
    7. Base your report solely on the given context and classification.
    
    Classification: {classification}
    
    Context: {context}
    
    Generate the report now, strictly following the template below:
    {report_template}

    Here's a sample report for reference:
    {sample_report}
    """)

    # Give the model the retrieved report text itself, not the Python repr of
    # (Document, score) tuples, which would fill the prompt with metadata noise.
    context = "\n\n".join(doc_contents)

    final_rag_chain = prompt | Ollamallm | StrOutputParser()

    # Values are substituted verbatim, so the braces inside report_template are
    # not treated as further template variables.
    generated_report = final_rag_chain.invoke({
        "context": context,
        "classification": classification,
        "sample_report": sample_report,
        "report_template": report_template,
    })

    return generated_report, cosine_similarities

In [None]:
def process_all_images_in_directory(image_dir, report_dir, output_csv=None):
    """Classify every image in a directory, generate a report for each, and collect the results.

    For each ``.jpg`` / ``.jpeg`` / ``.png`` file in ``image_dir`` the image is
    classified, a report is generated from its top classifications, and the
    matching ``<image stem>.txt`` in ``report_dir`` (if present) is read as the
    original report.

    Args:
        image_dir: Directory containing the images.
        report_dir: Directory containing the original ``.txt`` reports.
        output_csv: Where to save the two CSV files. If it ends in ``.csv`` the
            reports go to that path and the similarities to
            ``<stem>_similarities.csv``. If it is an existing directory the
            files ``gemma2_reports.csv`` and ``gemma2_similarities.csv`` are
            written inside it. Otherwise it is used as a file-name prefix for
            those two names. Nothing is written when it is empty/None.

    Returns:
        ``(results_df, similarities_df)``: one row per processed image, and one
        row per (image, retrieved document) similarity.
    """
    results = []
    similarities_data = []

    # Sorted so that row order is reproducible; os.listdir order is arbitrary.
    for image_filename in sorted(os.listdir(image_dir)):
        if not image_filename.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue

        image_path = os.path.join(image_dir, image_filename)

        classifications = predict_image(image_path)
        top_classifications = get_top_classifications(classifications)

        question = f"Retrieve radiology reports that diagnose or mention {', '.join([f'{k}: {v:.2f}' for k, v in top_classifications.items()])}. Based on the provided chest X-ray images, detail the relevant radiographic findings and their clinical significance."

        generated_report, cosine_similarities = generate_report(question, top_classifications)

        if generated_report is None or cosine_similarities is None:
            print(f"Skipping image {image_filename} due to error in report generation.")
            continue

        # Save cosine similarities for each document retrieved
        for i, sim in enumerate(cosine_similarities):
            similarities_data.append({
                'image_name': image_filename,
                'document_number': i + 1,
                'cosine_similarity': sim
            })

        # Find and read the original report (if available)
        report_filename = os.path.splitext(image_filename)[0] + '.txt'
        original_report_path = os.path.join(report_dir, report_filename)

        original_report = ''
        try:
            with open(original_report_path, 'r') as file:
                original_report = file.read()
        except FileNotFoundError:
            print(f"Warning: Original report not found for {report_filename}")

        # Append results to the list
        results.append({
            'image_name': image_filename,
            'Classification': top_classifications,
            'generated_report': generated_report,
            'original_report': original_report
        })

    # Convert results to DataFrames and save as CSV
    results_df = pd.DataFrame(results)
    similarities_df = pd.DataFrame(similarities_data)

    if output_csv:
        if output_csv.endswith('.csv'):
            # The two frames must not share a path, or the second write would
            # silently overwrite the first.
            results_csv_path = output_csv
            similarities_csv_path = output_csv[:-len('.csv')] + '_similarities.csv'
        elif os.path.isdir(output_csv):
            results_csv_path = os.path.join(output_csv, 'gemma2_reports.csv')
            similarities_csv_path = os.path.join(output_csv, 'gemma2_similarities.csv')
        else:
            # Treat the value as a file-name prefix.
            results_csv_path = output_csv + 'gemma2_reports.csv'
            similarities_csv_path = output_csv + 'gemma2_similarities.csv'

        results_df.to_csv(results_csv_path, index=False)
        similarities_df.to_csv(similarities_csv_path, index=False)
        print(f"Reports saved to CSV file: {results_csv_path}")
        print(f"Similarities saved to CSV file: {similarities_csv_path}")
    else:
        print("Output CSV path is not provided.")

    return results_df, similarities_df

# Example usage
# Placeholder values: replace with the image directory, the directory holding
# the original .txt reports, and the CSV output path/prefix.
image_directory ='path_to_directory'
report_directory = 'path_to_directory'
output_csv_base = 'path_to_directory'

# Runs the whole pipeline when this cell executes; the returned
# (results_df, similarities_df) tuple is the cell's output.
process_all_images_in_directory(image_directory, report_directory, output_csv_base)