In [None]:
from langchain.embeddings import OllamaEmbeddings
from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOllama
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.chains import LLMChain

from langchain.retrievers.document_compressors import EmbeddingsFilter

import os
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import densenet121
import torch.nn as nn
import torchvision.models as models
import cv2
import numpy as np
from PIL import Image
import csv

## Classification: Best Model Inference

In [None]:
# Label names, listed in the order used during training. predict_image pairs
# the i-th model output with the i-th entry, so this order must not change.
labels = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Enlarged Cardiomediastinum',
    'Fracture',
    'Lung Lesion',
    'Lung Opacity',
    'No Finding',
    'Pleural Effusion',
    'Pleural Other',
    'Pneumonia',
    'Pneumothorax',
    'Support Devices',
]

class ResNet50(nn.Module):
    """ResNet-50 whose final layer is replaced by a sigmoid-activated linear head.

    Args:
        out_size: Number of output units of the replacement head. Defaults to 14.
    """

    def __init__(self, out_size=14):
        super().__init__()
        # Start from torchvision's ResNet-50 with the IMAGENET1K_V1 weights.
        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Replace the stock classifier with Linear -> Sigmoid, so every output
        # unit is squashed independently into (0, 1).
        backbone.fc = nn.Sequential(
            nn.Linear(backbone.fc.in_features, out_size),
            nn.Sigmoid(),
        )
        # The attribute keeps the name `resnet50` so state-dict keys stay the same.
        self.resnet50 = backbone

    def forward(self, x):
        """Return the wrapped network's output for input `x`."""
        return self.resnet50(x)
        
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ResNet50().to(device)
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# saved on a GPU machine can still be loaded on a CPU-only machine.
model.load_state_dict(
    torch.load('../best_save_models/lateral_best_model.pth', map_location=device)
)
# Switch to evaluation mode before running inference.
model.eval()

# CLAHE transform class
class CLAHETransform:
    """Callable transform that applies OpenCV CLAHE to an image.

    Args:
        clip_limit: Value forwarded to cv2.createCLAHE as `clipLimit`.
        tile_grid_size: Value forwarded to cv2.createCLAHE as `tileGridSize`.
    """

    def __init__(self, clip_limit=0.10, tile_grid_size=(8, 8)):
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size
        # Build the CLAHE operator once and reuse it for every image.
        self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)

    def __call__(self, img):
        """Equalize `img` (PIL image or array) and return the result as a PIL image."""
        pixels = np.array(img) if isinstance(img, Image.Image) else img

        if pixels.ndim != 3:
            # Single-channel input: equalize the array directly.
            equalized = self.clahe.apply(pixels)
        else:
            # Three-dimensional input is treated as RGB: equalize only the L
            # channel in LAB space and leave the A/B channels untouched.
            lab = cv2.cvtColor(pixels, cv2.COLOR_RGB2LAB)
            lightness, chroma_a, chroma_b = cv2.split(lab)
            lab = cv2.merge((self.clahe.apply(lightness), chroma_a, chroma_b))
            equalized = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

        return Image.fromarray(equalized.astype('uint8'))

# Inference preprocessing pipeline (stated to be the same as the validation
# transform used during training).
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
preprocessing_steps = [
    transforms.Resize(256),
    CLAHETransform(clip_limit=0.35, tile_grid_size=(8, 8)),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
transform = transforms.Compose(preprocessing_steps)
# Function to predict an image
def predict_image(image_path):
    """Run the classifier on one image file and return a score for each label.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        dict mapping each name in `labels` to the model's output for that
        position, as a Python float.
    """
    # Open inside a context manager so the file handle is released immediately
    # instead of lingering until garbage collection. convert() returns an
    # independent, fully loaded image, so it remains usable after the close.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    # Preprocess, add the batch dimension and move to the model's device.
    batch = transform(image).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(batch)

    # Drop only the batch dimension (a blanket squeeze() would also collapse a
    # single-output head) and pair each score with its label by position.
    scores = output.squeeze(0).cpu().tolist()
    return dict(zip(labels, scores))

# # Example usage
# image_path = r"'path_to_directory'" # Real labels for this image {Consolidation, Pleural Effusion, Supporting Device}
# pred_scores = predict_image(image_path)
# # pred_scores
# # Print the scores for each label
# for label, score in pred_scores.items():
#     print(f'{label}: {score:.4f}')

def get_top_classifications(classifications: dict, top_n: int = 3, threshold: float = 0.50) -> dict:
    """Return the highest-scoring classifications that exceed a threshold.

    Args:
        classifications: Mapping of label name to score.
        top_n: Maximum number of entries to keep.
        threshold: Only scores strictly greater than this value are kept.

    Returns:
        dict of at most `top_n` label/score pairs, ordered from highest to
        lowest score (labels with equal scores keep their input order).
    """
    ranked = sorted(
        ((name, score) for name, score in classifications.items() if score > threshold),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return dict(ranked[:top_n])

## RAG

In [None]:
# Set up the embedding model and vector database
# Embeddings come from the "mxbai-embed-large" model served through Ollama.
oembed = OllamaEmbeddings(model="mxbai-embed-large")
# Open the Chroma store persisted under text_reports/all_embed_db and embed
# queries with `oembed`.
# NOTE(review): presumably the store was built with this same embedding model - confirm.
vectordb = Chroma(
    persist_directory="text_reports/all_embed_db",
    embedding_function=oembed
)

# Retriever over the vector store, created with as_retriever()'s default settings.
retriever = vectordb.as_retriever()

In [None]:
# Set up the Ollama LLM
# Chat model "llama3.2", reached through the Ollama server at http://localhost:11434.
ollama_llm = ChatOllama(base_url="http://localhost:11434", model="llama3.2")

In [None]:
# Example radiology report; it fills the {sample_report} slot of the prompt so
# the LLM has a reference for section layout and tone.
sample_report = """
EXAMINATION: CHEST (PA AND LAT)

INDICATION: ___ with new onset ascites

TECHNIQUE: Chest PA and lateral

COMPARISON: None.

FINDINGS: 
There is no focal consolidation, pleural effusion or pneumothorax. Bilateral nodular opacities that most likely represent nipple shadows. The cardiomediastinal silhouette is normal. Clips project over the left lung, potentially within the breast. The imaged upper abdomen is unremarkable. Chronic deformity of the posterior left sixth and seventh ribs are noted.

IMPRESSION: 
No acute cardiopulmonary process.
"""

# Section skeleton that the generated report must follow. It is passed to the
# prompt as a plain value, so the {examination}-style markers reach the LLM as
# literal text rather than being substituted here.
report_template = """
1. EXAMINATION: {examination}
2. INDICATION: {indication}
3. TECHNIQUE: {technique}
4. COMPARISON: {comparison}
5. FINDINGS: {findings}
6. IMPRESSION: {impression}
"""

# Prompt for report generation. It takes four inputs: the predicted
# classification names, the retrieved context, the section skeleton the report
# must follow, and a sample report for reference.
# Guideline 2 previously read "as link sample report"; the typo is corrected so
# the instruction sent to the LLM is unambiguous.
prompt_template = PromptTemplate(
    input_variables=["classification", "context", "report_template", "sample_report"],
    template="""
    You are an expert radiologist. Generate a concise and accurate radiology report based on the provided context and classification. Follow these strict guidelines:

    1. Use ONLY the given template structure. Do not add any sections or text outside this structure.
    2. Keep EXAMINATION, INDICATION, TECHNIQUE, and COMPARISON brief, as in the sample report.
    3. FINDINGS should be detailed but focused, describing only relevant observations.
    4. IMPRESSION should summarize key findings and their clinical significance concisely.
    5. Do not include any introductory or concluding statements, notifications, or recommendations unless explicitly part of the findings or impression.
    6. Use appropriate medical terminology and maintain a professional tone.
    7. Base your report solely on the given context and classification.
    
    Classification: {classification}
    
    Context: {context}
    
    Generate the report now, strictly following the template below:
    {report_template}

    Here's a sample report for reference:
    {sample_report}
    """
)

In [None]:
def pretty_print_docs(docs):
    """Print each document's page_content under a numbered heading.

    Documents are numbered from 1 and separated by a line of 100 dashes.
    Nothing is returned.

    Args:
        docs: Iterable of objects exposing a string `page_content` attribute.
    """
    divider = f"\n{'-' * 100}\n"
    sections = []
    for position, doc in enumerate(docs, start=1):
        sections.append(f"Document {position}:\n\n" + doc.page_content)
    print(divider.join(sections))

# embeddings_filter = EmbeddingsFilter(embeddings=oembed, similarity_threshold=0.60)
# compression_retriever = ContextualCompressionRetriever(
#             base_compressor=embeddings_filter, base_retriever=retriever
#         )
        
# query = f"""Retrieve relevant information for a chest radiograph report focusing on Support Devices, Cardiomegaly. 
#     Include details about:
#     1. Radiological appearances of Support Devices, Cardiomegaly
#     2. Associated findings and complications
#     3. Typical locations and distributions
#     4. Differentiation between Support Devices, Cardiomegaly
#     5. Severity indicators and extent assessment
#     6. Any other relevant chest radiograph findings that might co-occur
    
#     Provide specific medical terminology and descriptions used in radiology reports for these conditions."""
    
# compressed_docs = compression_retriever.invoke(query)
# retrieved_docs = pretty_print_docs(compressed_docs)

In [None]:
def rag_compression_retrieval(vectordb, llm, classifications):
    """Retrieve supporting documents for the classifications and generate a report.

    The classification names are embedded into a retrieval query, the retrieved
    documents are filtered by embedding similarity, and the surviving text is
    passed to the LLM together with the report template and sample report.

    Args:
        vectordb: Vector store to search; its `as_retriever()` supplies the
            base retriever.
        llm: Language model used to write the report.
        classifications: Mapping whose keys are the classification names.

    Returns:
        Tuple `(result, retrieved_docs)`: the generated report text and the
        list of filtered documents that were used as context.
    """
    # Extract only classification names
    classification_names = ", ".join(classifications.keys())

    print(f"\nClassifications name: {classification_names}")

    # Keep only retrieved documents whose embedding similarity to the query is
    # at least 0.50. Uses the module-level embedding model `oembed`.
    embeddings_filter = EmbeddingsFilter(embeddings=oembed, similarity_threshold=0.50)
    # Build the base retriever from the store that was passed in, rather than
    # silently ignoring the `vectordb` argument in favour of a global.
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=embeddings_filter,
        base_retriever=vectordb.as_retriever(),
    )

    query = f"""Retrieve relevant information for a chest radiograph report focusing on {classification_names}. 
    Include details about:
    1. Radiological appearances of {classification_names}
    2. Associated findings and complications
    3. Typical locations and distributions
    4. Differentiation between {classification_names}
    5. Severity indicators and extent assessment
    6. Any other relevant chest radiograph findings that might co-occur
    
    Provide specific medical terminology and descriptions used in radiology reports for these conditions."""

    compressed_docs = compression_retriever.invoke(query)
    # pretty_print_docs only prints and returns None, so the documents
    # themselves are what gets returned to the caller.
    pretty_print_docs(compressed_docs)
    retrieved_docs = compressed_docs

    # Combine the retrieved documents into a single context string
    context = "\n".join(doc.page_content for doc in compressed_docs)

    print(f"context: {context}")
    # Create an LLMChain
    llm_chain = LLMChain(llm=llm, prompt=prompt_template)

    # Generate the report
    result = llm_chain.run(
        classification=classification_names,
        context=context,
        report_template=report_template,
        sample_report=sample_report
    )

    print(f"\n Final Report \n {result}")
    return result, retrieved_docs

## llama3-2

In [None]:
def process_image_directory(image_dir, report_dir, vectordb, llm, output_csv_path):
    """Classify every image in a directory, generate a report for each, and log to CSV.

    For each .png/.jpg/.jpeg file in `image_dir` the per-label scores are
    printed, the top classifications are selected, a report is generated, and
    one row is written containing the image name, the classification names,
    the generated report and the original report text.

    Args:
        image_dir: Directory containing the images to process.
        report_dir: Directory containing reference reports named
            `<image stem>.txt`; a missing report produces a warning and an
            empty `original_report` field.
        vectordb: Vector store forwarded to rag_compression_retrieval.
        llm: Language model forwarded to rag_compression_retrieval.
        output_csv_path: Path of the CSV file to create (overwritten if present).
    """
    image_extensions = ('.png', '.jpg', '.jpeg')

    with open(output_csv_path, mode='w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['image_name', 'Classifications', 'generated_report', 'original_report'])

        # os.listdir() returns entries in arbitrary order; sorting makes the
        # processing order and the CSV row order reproducible.
        for image_filename in sorted(os.listdir(image_dir)):
            # Ensure the file is an image
            if not image_filename.lower().endswith(image_extensions):
                continue

            image_path = os.path.join(image_dir, image_filename)

            # Predict the classifications
            pred_scores = predict_image(image_path)
            for label, score in pred_scores.items():
                print(f'{label}: {score:.4f}')

            classifications = get_top_classifications(pred_scores)

            # Retrieve the original report
            report_filename = os.path.splitext(image_filename)[0] + '.txt'
            original_report_path = os.path.join(report_dir, report_filename)
            original_report = ''
            try:
                # Explicit UTF-8 keeps reading consistent with the UTF-8 CSV
                # output instead of depending on the platform default encoding.
                with open(original_report_path, 'r', encoding='utf-8') as file:
                    original_report = file.read()
            except FileNotFoundError:
                print(f"Warning: Original report not found for {report_filename}")

            # Generate the RAG-based report
            generated_report, _ = rag_compression_retrieval(vectordb, llm, classifications)

            # Save the results to the CSV file
            writer.writerow([
                image_filename,
                ', '.join(classifications.keys()),
                generated_report,
                original_report
            ])

            print(f"Processed: {image_filename}")

# Example usage
# The three values below are placeholders and must be replaced with real paths
# before this cell is run.
image_dir = 'path_to_directory'  # directory scanned for .png/.jpg/.jpeg images
report_dir = 'path_to_directory'  # directory searched for the matching <image stem>.txt reports
output_csv_path = 'path_to_directory'  # opened for writing as the output CSV, so this must be a file path

# Runs the full pipeline over every image in image_dir as soon as this cell executes.
process_image_directory(image_dir, report_dir, vectordb, ollama_llm, output_csv_path)


# Example usage
# image_path = r'D:\CODES\classification_and_llm\testing\frontal_classification\images\p10773491_s54418703.jpg'
# pred_scores = predict_image(image_path)
# # pred_scores
# # Print the scores for each label
# for label, score in pred_scores.items():
#     print(f'{label}: {score:.4f}')

# classifications = get_top_classifications(pred_scores)
# print(f"\nTop classifications: {classifications}")

# final_report, retrieved_docs = rag_compression_retrieval(vectordb, ollama_llm, classifications)

# print("Final Report:")
# print(final_report)