In [None]:
# /home/sysadm/Music/MedXpert/app.py


import streamlit as st
import os
import sys

# Page-wide Streamlit settings (title, icon, wide layout, sidebar open).
st.set_page_config(page_title="MedXpert", page_icon="🧠", layout="wide", initial_sidebar_state="expanded")

# Working directories the app expects on disk; created on first run.
for required_dir in (
    "data/reports",
    "data/processed/texts",
    "models/clip/fine_tuned",
    "data/embeddings",
):
    os.makedirs(required_dir, exist_ok=True)

# Put the launch directory first on the import path so `src.*` imports resolve.
sys.path.insert(0, os.path.abspath("."))

# Sidebar label -> file name of the page script.
pages = {
    "Home": "home.py",
    "Direct Diagnosis": "diagnosis.py",
    "Compare X-rays": "compare.py",
    "Visual Search": "search.py",
    "Reports": "reports.py",
}

# Sidebar branding.
st.sidebar.title("🧠 MedXpert")
st.sidebar.caption("Medical Visual Question Answering & Diagnosis Assistant")

# The radio offers the labels in declaration order; its value keys into `pages`.
selected_page = st.sidebar.radio("Navigation", list(pages))

# Import and run the selected page.
#
# The page script is loaded from src/ui/pages/ and executed as a module; if the
# module exposes a `show()` callable it is invoked afterwards. Any failure is
# reported in the main area instead of crashing the app shell.
try:
    with st.spinner(f"Loading {selected_page}..."):
        page_path = f"src/ui/pages/{pages[selected_page]}"

        # Check if the page exists
        if os.path.exists(page_path):
            # Try to import as a module first
            try:
                # Import the module using importlib for more flexibility
                import importlib.util

                # Module name is "pages.<file name without .py>"
                spec = importlib.util.spec_from_file_location(
                    f"pages.{pages[selected_page][:-3]}",
                    page_path
                )

                # Create the module
                module = importlib.util.module_from_spec(spec)

                # Execute the module (this runs the page's top-level code)
                spec.loader.exec_module(module)

                # Call show function if it exists
                if hasattr(module, 'show'):
                    module.show()

            except Exception:
                # Fall back to direct execution in this script's namespace.
                # NOTE(review): if the module got part-way through rendering
                # before failing, this runs the page a second time — confirm
                # the fallback is still wanted.
                # Page sources contain non-ASCII text (emoji), so read them as
                # UTF-8 rather than the platform default encoding.
                with open(page_path, encoding="utf-8") as f:
                    code = compile(f.read(), page_path, 'exec')
                    exec(code, globals())
        else:
            st.error(f"Page file not found: {page_path}")
except Exception as e:
    st.error(f"Error loading page: {str(e)}")
    # st.exception expects the exception object (it renders type, message and
    # traceback itself), not a pre-formatted traceback string.
    st.exception(e)

# Add footer
st.sidebar.divider()
st.sidebar.caption("© 2025 MedXpert - All Rights Reserved")



# MedXpert/src/ui/pages/compare.py



import streamlit as st
import os
import tempfile
import json
from datetime import datetime

from src.pipeline.blip_captioning import generate_blip_captions
from src.pipeline.llm_report_generation import generate_report
from src.llm_providers import llm_fn

# Make sure the directory holding saved_reports.json exists before any report
# is written from this page.
os.makedirs("data/reports", exist_ok=True)

# Page title.
st.header("🔍 Compare Two X-ray Images")

def save_report(report_data):
    """Append ``report_data`` to data/reports/saved_reports.json.

    The store is a JSON list of report dicts. A missing, unreadable-as-JSON or
    non-list store is treated as empty. The updated list is written to a
    temporary file and then moved over the store, so a failure while
    serializing (for example a non-JSON-serializable value in ``report_data``)
    leaves the existing reports untouched.

    Args:
        report_data: JSON-serializable dict describing one report.

    Raises:
        TypeError: if ``report_data`` cannot be serialized to JSON.
        OSError: if the store cannot be read or written.
    """
    reports_path = "data/reports/saved_reports.json"
    os.makedirs(os.path.dirname(reports_path), exist_ok=True)

    # Load existing reports
    try:
        with open(reports_path, "r") as f:
            reports = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        reports = []
    if not isinstance(reports, list):
        # Valid JSON but not a list (e.g. "{}"): cannot be appended to.
        reports = []

    # Add new report
    reports.append(report_data)

    # Write to a sibling temp file first, then replace the store in one step.
    tmp_path = reports_path + ".tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(reports, f)
    except Exception:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, reports_path)

# Patient ID input (optional free text; may be left empty).
patient_id = st.text_input("Patient ID (Optional):", placeholder="e.g., P12345")

# Two side-by-side columns, one uploader per X-ray.
col1, col2 = st.columns(2)

with col1:
    st.markdown("### First X-ray")
    # Distinct widget keys keep the two uploaders independent.
    img1 = st.file_uploader("Upload first image", type=["png", "jpg", "jpeg"], key="img1")
    if img1:
        # Preview the upload as soon as it is available.
        # NOTE(review): newer Streamlit releases deprecate use_column_width —
        # confirm against the pinned Streamlit version.
        st.image(img1, caption="First X-ray", use_column_width=True)

with col2:
    st.markdown("### Second X-ray")
    img2 = st.file_uploader("Upload second image", type=["png", "jpg", "jpeg"], key="img2")
    if img2:
        st.image(img2, caption="Second X-ray", use_column_width=True)

def save_temp_image(uploaded_file):
    """Write an uploaded file to a unique path in the system temp directory.

    Only the extension of the client-supplied file name is reused. The rest of
    the name is generated, so two uploads that share a name (e.g. both called
    "xray.png") no longer overwrite each other, and path components in the
    name cannot steer the write outside the temp directory.

    Args:
        uploaded_file: object exposing ``name`` and ``getbuffer()``, or None.

    Returns:
        Path of the written file, or None when ``uploaded_file`` is None.
    """
    if uploaded_file is None:
        return None

    extension = os.path.splitext(os.path.basename(uploaded_file.name))[1]
    fd, temp_img_path = tempfile.mkstemp(
        prefix="medxpert_", suffix=extension, dir=tempfile.gettempdir()
    )
    # mkstemp returns an already-open descriptor; wrap it so it is closed.
    with os.fdopen(fd, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return temp_img_path

# Placeholders: progress widgets live in the first, results in the second.
progress_placeholder = st.empty()
results_container = st.container()

if st.button("Compare & Analyze"):
    # A new run invalidates whatever comparison was shown before.
    st.session_state.pop("compare_result", None)

    if not img1 or not img2:
        st.warning("Please upload both images.")
    else:
        try:
            with progress_placeholder.container():
                progress_bar = st.progress(0)
                status_text = st.empty()

                # Save images
                status_text.text("Processing images...")
                progress_bar.progress(20)
                img1_path = save_temp_image(img1)
                img2_path = save_temp_image(img2)

                # Generate captions
                status_text.text("Generating captions...")
                progress_bar.progress(50)
                captions = generate_blip_captions([img1_path, img2_path])

                # Generate comparison report
                status_text.text("Creating comparison report...")
                progress_bar.progress(80)
                prompt = f"""
Compare the following radiology findings from two X-rays:

Image 1: {captions[0]}
Image 2: {captions[1]}

What are the differences or changes observed? Provide a detailed analysis of:
1. Changes in anatomical structures
2. Development or resolution of abnormalities
3. Progression or improvement of any condition
4. Technical differences between the images (if relevant)
"""
                comparison_report = llm_fn(prompt)

                progress_bar.progress(100)
                status_text.text("Complete!")

            # Clear progress indicators
            progress_placeholder.empty()

            # Keep the outcome across reruns. Clicking "Save Report" reruns
            # the script with this button reporting False, so results rendered
            # only inside this branch would vanish before the save could run.
            st.session_state["compare_result"] = {
                "image_paths": [img1_path, img2_path],
                "captions": captions,
                "report": comparison_report,
            }

        except Exception as e:
            # Do not leave a half-finished progress bar next to the error.
            progress_placeholder.empty()
            st.error(f"An error occurred: {str(e)}")
            # st.exception takes the exception object, not a traceback string.
            st.exception(e)

# Display results (on the run that produced them and on every later rerun).
compare_result = st.session_state.get("compare_result")
if compare_result:
    with results_container:
        st.success("Comparison complete!")

        st.subheader("🖼️ BLIP Captions")
        caption_col1, caption_col2 = st.columns(2)
        with caption_col1:
            st.markdown(f"**Image 1:** {compare_result['captions'][0]}")
        with caption_col2:
            st.markdown(f"**Image 2:** {compare_result['captions'][1]}")

        st.subheader("📋 Comparative Diagnosis")
        # Editable; the saved report uses whatever text is in the box.
        comparison_text = st.text_area("Comparative Analysis", compare_result["report"], height=300)

        # Save report button
        if st.button("Save Report"):
            saved_at = datetime.now()
            report_data = {
                "patient_id": patient_id if patient_id else f"Unknown-{saved_at.strftime('%Y%m%d%H%M%S')}",
                "timestamp": saved_at.strftime("%Y-%m-%d %H:%M:%S"),
                "report_type": "comparison",
                "image_paths": compare_result["image_paths"],
                "captions": compare_result["captions"],
                "report_text": comparison_text
            }

            save_report(report_data)
            st.success("Comparison report saved successfully! View it in the Reports section.")



# /home/sysadm/Music/MedXpert/src/ui/pages/diagnosis.py

import streamlit as st
import tempfile
import os
import json
from datetime import datetime

from src.pipeline.blip_captioning import generate_blip_captions
from src.pipeline.llm_report_generation import generate_report
from src.llm_providers import llm_fn  # Real LLM interface

# Page title.
st.header("🩻 Direct Diagnosis from X-ray")

# Make sure the directory holding saved_reports.json exists before any report
# is written from this page.
os.makedirs("data/reports", exist_ok=True)

def save_report(report_data):
    """Append ``report_data`` to data/reports/saved_reports.json.

    The store is a JSON list of report dicts. A missing, unreadable-as-JSON or
    non-list store is treated as empty. The updated list is written to a
    temporary file and then moved over the store, so a failure while
    serializing (for example a non-JSON-serializable value in ``report_data``)
    leaves the existing reports untouched.

    Args:
        report_data: JSON-serializable dict describing one report.

    Raises:
        TypeError: if ``report_data`` cannot be serialized to JSON.
        OSError: if the store cannot be read or written.
    """
    reports_path = "data/reports/saved_reports.json"
    os.makedirs(os.path.dirname(reports_path), exist_ok=True)

    # Load existing reports
    try:
        with open(reports_path, "r") as f:
            reports = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        reports = []
    if not isinstance(reports, list):
        # Valid JSON but not a list (e.g. "{}"): cannot be appended to.
        reports = []

    # Add new report
    reports.append(report_data)

    # Write to a sibling temp file first, then replace the store in one step.
    tmp_path = reports_path + ".tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(reports, f)
    except Exception:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, reports_path)

# Patient ID input (optional free text; may be left empty).
patient_id = st.text_input("Patient ID (Optional):", placeholder="e.g., P12345")

# Single image upload, restricted to PNG/JPEG by extension.
uploaded_file = st.file_uploader("Upload a chest X-ray image", type=["png", "jpg", "jpeg"])

def save_temp_image(file):
    """Write an uploaded file to a unique path in the system temp directory.

    Only the extension of the client-supplied file name is reused. The rest of
    the name is generated, so uploads that share a name no longer overwrite
    each other (the returned path is also stored in saved reports), and path
    components in the name cannot steer the write outside the temp directory.

    Args:
        file: object exposing ``name`` and ``getbuffer()``.

    Returns:
        Path of the written file.
    """
    extension = os.path.splitext(os.path.basename(file.name))[1]
    fd, img_path = tempfile.mkstemp(
        prefix="medxpert_", suffix=extension, dir=tempfile.gettempdir()
    )
    # mkstemp returns an already-open descriptor; wrap it so it is closed.
    with os.fdopen(fd, "wb") as f:
        f.write(file.getbuffer())
    return img_path

if uploaded_file:
    # Display preview image
    st.image(uploaded_file, caption="Preview", width=300)

# Placeholders: progress widgets live in the first, the report in the second.
progress_placeholder = st.empty()
report_container = st.container()

if st.button("Generate Diagnosis"):
    if not uploaded_file:
        st.warning("Please upload an image.")
    else:
        try:
            with progress_placeholder.container():
                progress_bar = st.progress(0)
                status_text = st.empty()

                # Save image
                status_text.text("Processing image...")
                progress_bar.progress(10)
                image_path = save_temp_image(uploaded_file)

                # Generate BLIP caption
                status_text.text("Generating image caption...")
                progress_bar.progress(30)
                caption = generate_blip_captions([image_path])[0]

                # Generate report (no retrieved texts on this page)
                status_text.text("Creating diagnostic report...")
                progress_bar.progress(60)
                report = generate_report([caption], [], llm_fn)

                # Save the report
                status_text.text("Saving results...")
                progress_bar.progress(90)

                # One clock reading, so the fallback ID and the timestamp
                # cannot straddle a second boundary.
                saved_at = datetime.now()
                report_data = {
                    "patient_id": patient_id if patient_id else f"Unknown-{saved_at.strftime('%Y%m%d%H%M%S')}",
                    "timestamp": saved_at.strftime("%Y-%m-%d %H:%M:%S"),
                    "image_path": image_path,
                    "caption": caption,
                    "report_text": report
                }

                save_report(report_data)

                progress_bar.progress(100)
                status_text.text("Complete!")

            # Remove progress indicators
            progress_placeholder.empty()

            with report_container:
                st.success("Diagnosis complete!")
                st.subheader("🧠 Image Caption (via BLIP)")
                st.info(f"📝 Caption: {caption}")

                st.subheader("📋 Diagnostic Report (via LLM)")
                st.text_area("Generated Diagnosis", report, height=300)

                st.info("✅ Report saved! You can view all reports in the Reports section.")

        except Exception as e:
            # Do not leave a half-finished progress bar next to the error.
            progress_placeholder.empty()
            st.error(f"An error occurred: {str(e)}")
            # st.exception takes the exception object, not a traceback string.
            st.exception(e)


# /home/sysadm/Music/MedXpert/src/ui/pages/reports.py

import streamlit as st
import json
import os
import pandas as pd
from datetime import datetime

from src.pipeline.blip_captioning import generate_blip_captions
from src.pipeline.llm_report_generation import generate_report
from src.llm_providers import llm_fn

# Page title.
st.header("📋 Diagnostic Reports")

def load_reports():
    """Return the list of saved reports from data/reports/saved_reports.json.

    The file is read on every call. It is deliberately not wrapped in
    ``st.cache_data``: the store is rewritten whenever a report is saved or
    deleted, and a cached copy would keep returning the old list, so new
    reports would not appear and deleted ones would come back.

    Returns:
        The stored list of report dicts, or an empty list when the file is
        missing, is not valid JSON, or does not contain a JSON list.
    """
    try:
        with open("data/reports/saved_reports.json", "r") as f:
            reports = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return []
    # Callers index, append to and pop from the result, so it must be a list.
    return reports if isinstance(reports, list) else []

def save_report(reports):
    """Replace data/reports/saved_reports.json with ``reports``.

    The list is written to a temporary file and then moved over the store, so
    a failure while serializing leaves the previous contents intact instead of
    a truncated file.

    Args:
        reports: the complete, JSON-serializable list of report dicts.

    Raises:
        TypeError: if ``reports`` cannot be serialized to JSON.
        OSError: if the store cannot be written.
    """
    reports_path = "data/reports/saved_reports.json"
    os.makedirs("data/reports", exist_ok=True)

    tmp_path = reports_path + ".tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(reports, f)
    except Exception:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, reports_path)

def _show_report_image(report):
    """Render the report's stored image if its file is still on disk.

    Saved reports keep a file path (written to the system temp directory by
    the pages that create them). The file may be gone by the time the report
    is viewed, so a missing file is reported inline instead of being handed
    to ``st.image``.
    """
    image_path = report.get('image_path')
    if not image_path:
        return
    if os.path.exists(image_path):
        st.image(image_path, width=300)
    else:
        st.caption(f"Image no longer available: {image_path}")


# Initialize or load existing reports
reports = load_reports()

# View mode selection
view_mode = st.radio("Choose mode:", ["View Reports", "Search Reports", "Export Reports"], horizontal=True)

if view_mode == "View Reports":
    if not reports:
        st.info("No reports available. Generate reports from the Diagnosis or Search pages first.")
    else:
        # Display reports in reverse chronological order (newest first)
        for i, report in enumerate(reversed(reports)):
            # Position of this report in the stored (oldest-first) list.
            stored_index = len(reports) - i - 1
            with st.expander(f"Report #{len(reports)-i}: {report['timestamp']} - {report['patient_id']}"):
                st.markdown(f"**Patient ID:** {report['patient_id']}")
                st.markdown(f"**Timestamp:** {report['timestamp']}")

                _show_report_image(report)

                if 'caption' in report and report['caption']:
                    st.markdown(f"**BLIP Caption:** {report['caption']}")

                st.markdown("### Diagnostic Report")
                st.markdown(report['report_text'])

                # Delete button for each report
                if st.button("Delete Report", key=f"delete_{i}"):
                    reports.pop(stored_index)
                    save_report(reports)
                    st.success("Report deleted successfully!")
                    # Rerun so the list is redrawn without the deleted entry.
                    st.rerun()

elif view_mode == "Search Reports":
    search_query = st.text_input("Search reports by keyword:")
    patient_id_filter = st.text_input("Filter by patient ID (optional):")

    if search_query or patient_id_filter:
        # Case-insensitive substring match; an empty filter matches everything.
        filtered_reports = []
        for report in reports:
            matches_keyword = not search_query or search_query.lower() in report['report_text'].lower()
            matches_patient = not patient_id_filter or patient_id_filter.lower() in report['patient_id'].lower()

            if matches_keyword and matches_patient:
                filtered_reports.append(report)

        if filtered_reports:
            st.success(f"Found {len(filtered_reports)} matching reports")
            for report in filtered_reports:
                with st.expander(f"Report: {report['timestamp']} - {report['patient_id']}"):
                    st.markdown(f"**Patient ID:** {report['patient_id']}")
                    st.markdown(f"**Timestamp:** {report['timestamp']}")

                    _show_report_image(report)

                    st.markdown("### Diagnostic Report")
                    st.markdown(report['report_text'])
        else:
            st.info("No matching reports found.")

elif view_mode == "Export Reports":
    if not reports:
        st.info("No reports available to export.")
    else:
        # Convert to DataFrame for easy export
        df_data = []
        for report in reports:
            report_data = {
                "patient_id": report['patient_id'],
                "timestamp": report['timestamp'],
                # Flatten newlines so each report stays on one CSV row.
                "report_text": report['report_text'].replace("\n", " ")
            }
            if 'caption' in report:
                report_data["caption"] = report['caption']
            df_data.append(report_data)

        df = pd.DataFrame(df_data)

        # Create download button
        csv = df.to_csv(index=False)
        st.download_button(
            label="Download Reports as CSV",
            data=csv,
            file_name=f"medxpert_reports_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
            mime="text/csv"
        )

        # Also offer JSON export (full report dicts, not just the CSV columns)
        json_data = json.dumps(reports, indent=2)
        st.download_button(
            label="Download Reports as JSON",
            data=json_data,
            file_name=f"medxpert_reports_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
            mime="application/json"
        )

# Form for adding a report manually (useful for testing or adding external reports)
st.divider()
with st.expander("Add Report Manually (For Testing)"):
    with st.form("add_report_form"):
        patient_id = st.text_input("Patient ID:", placeholder="e.g., P12345")
        report_text = st.text_area("Report Text:", placeholder="Enter diagnostic report content...")

        submitted = st.form_submit_button("Save Report")

        if submitted:
            if patient_id and report_text:
                new_report = {
                    "patient_id": patient_id,
                    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    "report_text": report_text
                }
                reports.append(new_report)
                save_report(reports)
                st.success("Report added successfully!")
                # Rerun so the new report shows up in the views above.
                st.rerun()
            else:
                # Previously an incomplete submission was dropped silently.
                st.warning("Please enter both a patient ID and report text.")

# MedXpert/src/ui/pages/search.py

import streamlit as st
import os
import tempfile
import json
from datetime import datetime

from src.pipeline.clip_retrieval import retrieve_top_k
from src.pipeline.blip_captioning import generate_blip_captions
from src.pipeline.llm_report_generation import generate_report
from src.llm_providers import llm_fn  # your actual LLM API function

# Make sure the on-disk locations this page reads from and writes to exist.
for required_dir in ("data/processed/texts", "data/reports"):
    os.makedirs(required_dir, exist_ok=True)

# Page title.
st.header("🔎 Visual + Text Search")

def save_report(report_data):
    """Append ``report_data`` to data/reports/saved_reports.json.

    The store is a JSON list of report dicts. A missing, unreadable-as-JSON or
    non-list store is treated as empty. The updated list is written to a
    temporary file and then moved over the store, so a failure while
    serializing (for example a non-JSON-serializable value in ``report_data``)
    leaves the existing reports untouched.

    Args:
        report_data: JSON-serializable dict describing one report.

    Raises:
        TypeError: if ``report_data`` cannot be serialized to JSON.
        OSError: if the store cannot be read or written.
    """
    reports_path = "data/reports/saved_reports.json"
    os.makedirs(os.path.dirname(reports_path), exist_ok=True)

    # Load existing reports
    try:
        with open(reports_path, "r") as f:
            reports = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        reports = []
    if not isinstance(reports, list):
        # Valid JSON but not a list (e.g. "{}"): cannot be appended to.
        reports = []

    # Add new report
    reports.append(report_data)

    # Write to a sibling temp file first, then replace the store in one step.
    tmp_path = reports_path + ".tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(reports, f)
    except Exception:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, reports_path)

@st.cache_data
def load_dataset():
    """Return the retrieval dataset, falling back to generated placeholder data.

    Reads data/processed/texts/test.json. When that file is absent or is not
    valid JSON, the result of ``create_dummy_dataset()`` is returned instead.
    The result is cached by Streamlit.
    """
    dataset_path = "data/processed/texts/test.json"

    # No dataset on disk: use the placeholder set.
    if not os.path.exists(dataset_path):
        return create_dummy_dataset()

    try:
        with open(dataset_path) as dataset_file:
            return json.load(dataset_file)
    except json.JSONDecodeError:
        # Corrupted file: use the placeholder set.
        return create_dummy_dataset()

def create_dummy_dataset():
    """Build a three-entry placeholder dataset, persist it, and return it.

    Each entry is a dict with an ``image_id`` path and a ``text`` finding. The
    ``examples`` directory is created if missing, and the entries are written
    as JSON to data/processed/texts/test.json.
    """
    placeholder_rows = (
        ("examples/example1.jpg", "Normal chest X-ray with no significant findings."),
        ("examples/example2.jpg", "Bilateral lung opacities consistent with pneumonia."),
        ("examples/example3.jpg", "Left lower lobe consolidation."),
    )
    dummy_data = [
        {"image_id": image_id, "text": finding}
        for image_id, finding in placeholder_rows
    ]

    # Directory the placeholder image paths point into.
    os.makedirs("examples", exist_ok=True)

    # Persist the placeholder entries as the dataset file.
    with open("data/processed/texts/test.json", "w") as out_file:
        json.dump(dummy_data, out_file)

    return dummy_data

# Patient ID input (optional free text; may be left empty).
patient_id = st.text_input("Patient ID (Optional):", placeholder="e.g., P12345")

# Retrieval direction and number of results to retrieve (1-10, default 3).
mode = st.radio("Choose retrieval mode:", ["Text → Image/Text", "Image → Text"])
top_k = st.slider("How many results to retrieve?", 1, 10, 3)

# Load the dataset that retrieved indices are looked up in. On failure the
# error is shown and an empty dataset is used.
try:
    dataset = load_dataset()
except Exception as e:
    st.error(f"Error loading dataset: {str(e)}")
    dataset = []

def save_temp_image(uploaded_file):
    """Write an uploaded file to a unique path in the system temp directory.

    Only the extension of the client-supplied file name is reused. The rest of
    the name is generated, so uploads that share a name no longer overwrite
    each other (the returned path is also stored in saved reports), and path
    components in the name cannot steer the write outside the temp directory.

    Args:
        uploaded_file: object exposing ``name`` and ``getbuffer()``.

    Returns:
        Path of the written file.
    """
    extension = os.path.splitext(os.path.basename(uploaded_file.name))[1]
    fd, path = tempfile.mkstemp(
        prefix="medxpert_", suffix=extension, dir=tempfile.gettempdir()
    )
    # mkstemp returns an already-open descriptor; wrap it so it is closed.
    with os.fdopen(fd, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return path

def display_results(indices, query_image=None):
    """Render retrieved cases, their BLIP captions and the LLM-generated report.

    Looks up ``indices`` in the module-level ``dataset``, shows the images that
    exist on disk with generated captions, produces a report from the captions
    and the retrieved texts, and offers a "Save Report" button.

    Generated captions and report are kept in ``st.session_state`` for the
    current result set, so later reruns (editing the report text, clicking
    "Save Report") reuse them instead of calling the models again.

    Args:
        indices: iterable of integer positions into ``dataset``.
        query_image: path of the uploaded query image, stored with the saved
            report when given.
    """
    try:
        # Get samples from dataset. Negative positions are skipped as well:
        # they would otherwise silently select entries from the end of the list.
        samples = [dataset[i] for i in indices if 0 <= i < len(dataset)]
        if not samples:
            st.warning("No matching results found. Try adjusting your query or upload a different image.")
            return

        image_paths = [s.get("image_id", "") for s in samples]
        texts = [s.get("text", "") for s in samples]

        # Check if image paths exist
        valid_image_paths = []
        for path in image_paths:
            if os.path.exists(path):
                valid_image_paths.append(path)
            else:
                st.warning(f"Image path not found: {path}")

        if not valid_image_paths:
            st.warning("No valid image paths found in results.")
            return

        # Reuse earlier model output only if it belongs to this exact result set.
        generation_key = [valid_image_paths, texts, query_image]
        generated = st.session_state.get("search_generated")
        if not generated or generated.get("key") != generation_key:
            generated = {"key": generation_key}

        st.subheader("📸 Retrieved X-ray Images + Captions")

        # Generate captions for valid images only
        if "captions" not in generated:
            with st.spinner("Generating captions..."):
                generated["captions"] = generate_blip_captions(valid_image_paths)
            st.session_state["search_generated"] = generated
        captions = generated["captions"]

        # Display images in columns
        cols = st.columns(len(valid_image_paths))
        for i, col in enumerate(cols):
            col.image(valid_image_paths[i], caption=captions[i] if i < len(captions) else "", use_column_width=True)

        # Generate report
        st.subheader("📝 AI-Generated Diagnostic Report")

        if "report" not in generated:
            with st.spinner("Generating diagnostic report..."):
                generated["report"] = generate_report(captions, texts, llm_fn)
            st.session_state["search_generated"] = generated
        report = generated["report"]

        # Editable; the saved report uses whatever text is in the box.
        report_text = st.text_area("Report Output", report, height=250)

        if st.button("Save Report"):
            # Prepare report data
            saved_at = datetime.now()
            report_data = {
                "patient_id": patient_id if patient_id else f"Unknown-{saved_at.strftime('%Y%m%d%H%M%S')}",
                "timestamp": saved_at.strftime("%Y-%m-%d %H:%M:%S"),
                "report_text": report_text,
                "query_mode": mode,
                "captions": captions,
                "texts": texts
            }

            # Add query image if available
            if query_image:
                report_data["image_path"] = query_image

            # Save report
            save_report(report_data)
            st.success("Report saved successfully! View it in the Reports section.")

    except Exception as e:
        st.error(f"Error displaying results: {str(e)}")
        # st.exception takes the exception object, not a traceback string.
        st.exception(e)

# Progress placeholder
progress_placeholder = st.empty()

# User input section.
#
# A search only stores its outcome in st.session_state; rendering happens once
# at the end of the script. Clicking "Save Report" reruns the script with the
# search button reporting False, so results rendered inside the button branch
# would disappear before the save could run.
if mode == "Text → Image/Text":
    query = st.text_input("Enter medical query:", "What abnormality is present?")
    if st.button("Search & Generate Report"):
        # A new search invalidates earlier results and generated text.
        st.session_state.pop("search_results", None)
        st.session_state.pop("search_generated", None)

        if not query:
            st.warning("Please enter a valid query.")
        else:
            try:
                with progress_placeholder.container():
                    progress_bar = st.progress(0)
                    status_text = st.empty()

                    status_text.text("Retrieving similar cases...")
                    progress_bar.progress(30)

                    indices, _ = retrieve_top_k(query, mode="text", k=top_k)

                    status_text.text("Processing results...")
                    progress_bar.progress(80)

                    progress_bar.progress(100)
                    status_text.text("Complete!")

                # Clear progress indicators
                progress_placeholder.empty()

                st.session_state["search_results"] = {
                    "mode": mode,
                    "indices": list(indices),
                    "query_image": None,
                }

            except Exception as e:
                progress_placeholder.empty()
                st.error(f"Search error: {str(e)}")
                st.exception(e)

elif mode == "Image → Text":
    uploaded_file = st.file_uploader("Upload chest X-ray:", type=["png", "jpg", "jpeg"])

    if uploaded_file:
        # Display preview image
        st.image(uploaded_file, caption="Preview", width=300)

    if st.button("Search & Generate Report"):
        # A new search invalidates earlier results and generated text.
        st.session_state.pop("search_results", None)
        st.session_state.pop("search_generated", None)

        if not uploaded_file:
            st.warning("Please upload a file.")
        else:
            try:
                with progress_placeholder.container():
                    progress_bar = st.progress(0)
                    status_text = st.empty()

                    status_text.text("Processing uploaded image...")
                    progress_bar.progress(20)

                    image_path = save_temp_image(uploaded_file)

                    status_text.text("Retrieving similar cases...")
                    progress_bar.progress(50)

                    indices, _ = retrieve_top_k(image_path, mode="image", k=top_k)

                    status_text.text("Processing results...")
                    progress_bar.progress(80)

                    progress_bar.progress(100)
                    status_text.text("Complete!")

                # Clear progress indicators
                progress_placeholder.empty()

                st.session_state["search_results"] = {
                    "mode": mode,
                    "indices": list(indices),
                    "query_image": image_path,
                }

            except Exception as e:
                progress_placeholder.empty()
                st.error(f"Search error: {str(e)}")
                st.exception(e)

# Display results of the latest search, provided it was made in the mode that
# is currently selected.
search_results = st.session_state.get("search_results")
if search_results and search_results["mode"] == mode:
    display_results(search_results["indices"], query_image=search_results["query_image"])

  from .autonotebook import tqdm as notebook_tqdm
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [None]:
# /home/sysadm/Music/MedXpert/src/pipeline/blip_captioning.py

def generate_blip_captions(image_paths):
    """Generate one BLIP caption per image path.

    Uses the "Salesforce/blip-image-captioning-base" checkpoint. The model
    runs on CUDA when available and on the CPU otherwise (it used to require
    CUDA unconditionally).

    Args:
        image_paths: iterable of image file paths.

    Returns:
        List of caption strings, in the same order as ``image_paths``. An
        empty input returns an empty list without loading the model.
    """
    image_paths = list(image_paths)
    if not image_paths:
        return []

    from transformers import BlipProcessor, BlipForConditionalGeneration
    from PIL import Image
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load BLIP model and processor.
    # NOTE(review): both are loaded on every call; consider caching them if
    # call latency matters.
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

    captions = []

    for img_path in image_paths:
        # Load each image; the context manager closes the file handle once the
        # RGB copy has been made.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        inputs = processor(image, return_tensors="pt").to(device)

        # Generate caption
        output = model.generate(**inputs)
        caption = processor.decode(output[0], skip_special_tokens=True)
        captions.append(caption)

    return captions


# /home/sysadm/Music/MedXpert/src/pipeline/clip_retrieval.py
def retrieve_top_k(query, mode="text", k=5):
    """Dispatch a top-k retrieval to the search engine for the given mode.

    Args:
        query: the query handed to the search function unchanged.
        mode: "text" routes to ``search_image_by_text``; "image" routes to
            ``search_text_by_image``.
        k: number of results requested, forwarded as ``k``.

    Returns:
        Whatever the selected search function returns.

    Raises:
        ValueError: if ``mode`` is neither "text" nor "image".
    """
    from src.core.search_engine import search_image_by_text, search_text_by_image

    # Reject unknown modes up front.
    if mode not in ("text", "image"):
        raise ValueError("mode must be 'text' or 'image'")

    search = search_image_by_text if mode == "text" else search_text_by_image
    return search(query, k=k)


# /home/sysadm/Music/MedXpert/src/pipeline/llm_report_generation.py
def generate_report(blip_captions, retrieved_texts, llm_fn):
    """Assemble a report prompt from captions and texts and pass it to ``llm_fn``.

    The prompt lists each caption and each retrieved text as a "- " bullet
    under its own heading and ends with the instruction to generate a
    summarized radiology report.

    Args:
        blip_captions: iterable of caption strings.
        retrieved_texts: iterable of retrieved report texts.
        llm_fn: callable taking the prompt string.

    Returns:
        Whatever ``llm_fn`` returns for the assembled prompt.
    """
    sections = [
        "\nBelow are findings extracted from multiple images and related radiology texts.\n"
        "\nImage Findings (via BLIP):\n"
    ]
    sections.extend(f"- {caption}\n" for caption in blip_captions)
    sections.append("\nReport Texts (via CLIP):\n")
    sections.extend(f"- {text}\n" for text in retrieved_texts)
    sections.append("\nGenerate a summarized radiology report:")

    return llm_fn("".join(sections))
