In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# List every file available under the Kaggle input directory.
# Note: dirname / filenames / filename remain bound at module scope after the loop.
for dirname, _, filenames in os.walk('/kaggle/input'):
    # Print the full path of each file found in this directory.
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [3]:
# Install required libraries
!pip install monai torch torchvision transformers gradio pydicom SimpleITK opencv-python-headless pylibjpeg pylibjpeg-libjpeg
!pip install -q datasets accelerate

import os
import logging
import re
import pydicom
import numpy as np
from datetime import datetime
from PIL import Image
import torch
import cv2
import SimpleITK as sitk
from monai.transforms import (
    Compose,
    EnsureChannelFirst,
    ScaleIntensity,
    Resize,
    ToTensor
)
from monai.networks.nets import DenseNet121
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    AutoTokenizer,
    BertLMHeadModel
)
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import json
from typing import Tuple, Dict, Optional

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Run models on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info("Using device: %s", device)

# Constants shared by the image loader, transformer and analysis pipeline.
CONFIG = {
    # Spatial size every image is resized to before inference.
    "image_size": [512, 512],
    "blip_model": "Salesforce/blip-image-captioning-large",
    "llm_model": "emilyalsentzer/Bio_ClinicalBERT",
    "monai_model": "DenseNet121",
    # Classifier label order: output index i is reported as class_names[i].
    "class_names": [
        "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly",
        "Lung Opacity", "Lung Lesion", "Edema", "Consolidation",
        "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion",
        "Pleural Other", "Fracture", "Support Devices"
    ],
}
# Derive the classifier head size from the label list so the two cannot drift
# apart (a mismatch would mislabel outputs or make label lookups fail).
CONFIG["num_classes"] = len(CONFIG["class_names"])

# Initialize assets
def initialize_assets():
    """Write the default JSON asset files into the current working directory.

    A file is only written when nothing with that name exists yet, so copies
    already on disk (including hand-edited ones) are never overwritten.
    """
    default_assets = {
        "medical_terminology.json": {
            "opacity": "radiographic opacity",
            "shadow": "opacity",
            "heart": "cardiac silhouette",
            "bone": "osseous structures",
            "lung": "pulmonary parenchyma"
        },
        "report_templates.json": {
            "full_report": {
                # NOTE(review): this template embeds {findings}, so the rendered
                # comparison section repeats the findings text.
                "comparison": "Comparison is made with prior studies when available.\n\n{findings}",
                "technique": "Standard radiographic technique was employed.",
                "findings": "{findings}",
                "impression": "{impression}"
            }
        }
    }

    for asset_name, payload in default_assets.items():
        if os.path.exists(asset_name):
            continue  # keep whatever is already on disk
        with open(asset_name, "w") as handle:
            json.dump(payload, handle)
        logger.info(f"Created {asset_name}")

# Create the default terminology/template JSON files when this cell runs;
# MedicalAnalysisPipeline._load_assets reads them back from the working directory.
initialize_assets()

class MedicalImageLoader:
    """Load DICOM and standard raster images as float32 RGB arrays.

    A successful load returns ``(array, None)`` where ``array`` is channel-last
    with three channels and values in ``[0, 1]``. A failed load returns
    ``(None, error_message)`` instead of raising.
    """

    @staticmethod
    def load_image(image_path: str) -> Tuple[Optional[np.ndarray], Optional[str]]:
        """
        Load a medical image, dispatching on the file extension.

        Args:
            image_path: Path to the image file; ``.dcm`` files are read as DICOM,
                everything else as a standard raster image.

        Returns:
            tuple: (image_array, error_message). Exactly one of them is None.
        """
        try:
            if not os.path.exists(image_path):
                return None, f"File not found: {image_path}"

            if image_path.lower().endswith('.dcm'):
                return MedicalImageLoader._load_dicom(image_path)
            return MedicalImageLoader._load_standard_image(image_path)

        except Exception as e:
            error_msg = f"Image loading failed: {str(e)}"
            logger.error(error_msg)
            return None, error_msg

    @staticmethod
    def _load_dicom(path: str) -> Tuple[Optional[np.ndarray], Optional[str]]:
        """Load a DICOM file with SimpleITK, falling back to pydicom."""
        # Method 1: SimpleITK
        try:
            img_sitk = sitk.ReadImage(path)
            img_np = sitk.GetArrayFromImage(img_sitk)
            # Vector (e.g. RGB) images carry their components on the last axis.
            n_components = img_sitk.GetNumberOfComponentsPerPixel()
            img_np = MedicalImageLoader._first_frame(img_np, n_components)
            return MedicalImageLoader._normalize_dicom(img_np), None
        except Exception as e:
            logger.warning(f"SimpleITK failed, trying pydicom: {e}")

        # Method 2: pydicom
        try:
            ds = pydicom.dcmread(path)
            if not hasattr(ds, 'pixel_array'):
                return None, "DICOM file does not contain image data"
            samples_per_pixel = int(ds.get("SamplesPerPixel", 1) or 1)
            pixels = MedicalImageLoader._first_frame(ds.pixel_array, samples_per_pixel)
            return MedicalImageLoader._normalize_dicom(pixels), None
        except Exception as e:
            error_msg = f"DICOM loading failed with both methods: {str(e)}"
            logger.error(error_msg)
            return None, error_msg

    @staticmethod
    def _load_standard_image(path: str) -> Tuple[Optional[np.ndarray], Optional[str]]:
        """Load a PNG/JPEG-style image with OpenCV, falling back to PIL."""
        # Method 1: OpenCV
        try:
            img = cv2.imread(path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
            if img is None:
                raise ValueError("OpenCV returned None")
            # OpenCV decodes colour images in BGR(A) order, while the rest of the
            # pipeline (e.g. PIL.Image.fromarray) treats channels as RGB(A).
            if img.ndim == 3 and img.shape[-1] == 3:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            elif img.ndim == 3 and img.shape[-1] == 4:
                img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
            return MedicalImageLoader._normalize_standard(img), None
        except Exception as e:
            logger.warning(f"OpenCV failed, trying PIL: {e}")

        # Method 2: PIL
        try:
            # The context manager closes the file handle PIL keeps open.
            with Image.open(path) as img:
                if img.mode not in ['L', 'RGB', 'RGBA']:
                    return None, f"Unsupported image mode: {img.mode}"
                pixels = np.array(img)
            return MedicalImageLoader._normalize_standard(pixels), None
        except Exception as e:
            error_msg = f"Standard image loading failed with both methods: {str(e)}"
            logger.error(error_msg)
            return None, error_msg

    @staticmethod
    def _first_frame(pixels: np.ndarray, n_components: int = 1) -> np.ndarray:
        """Reduce a stack of slices/frames to its first 2D image.

        ``n_components`` is the number of samples per pixel. When it is greater
        than one, the trailing axis holds colour channels and is not counted as
        a spatial axis, so ``(frames, H, W, C)`` -> ``(H, W, C)`` while
        ``(H, W, C)`` is left untouched.
        """
        spatial_ndim = pixels.ndim - (1 if n_components > 1 else 0)
        if spatial_ndim == 3:
            pixels = pixels[0]
        return pixels

    @staticmethod
    def _to_rgb(img: np.ndarray) -> np.ndarray:
        """Return ``img`` with exactly three trailing channels.

        2D and single-channel arrays are replicated, two-channel (grey + alpha)
        arrays replicate the grey channel, and four-channel arrays drop alpha.
        """
        if img.ndim == 2:
            return np.stack([img] * 3, axis=-1)
        channels = img.shape[-1]
        if channels == 1:
            return np.repeat(img, 3, axis=-1)
        if channels == 2:
            return np.repeat(img[..., :1], 3, axis=-1)
        if channels == 4:
            return img[..., :3]
        return img

    @staticmethod
    def _normalize_dicom(img: np.ndarray) -> np.ndarray:
        """Min-max scale DICOM pixels to float32 ``[0, 1]`` and expand to 3 channels.

        DICOM pixel data may use only part of its container dtype or be signed,
        so the observed value range is used instead of the dtype maximum.
        A constant image maps to all zeros.
        """
        img = img.astype(np.float32)
        lo, hi = float(img.min()), float(img.max())
        img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
        return MedicalImageLoader._to_rgb(img)

    @staticmethod
    def _normalize_standard(img: np.ndarray) -> np.ndarray:
        """Scale a standard raster image to float32 ``[0, 1]`` and expand to 3 channels."""
        if img.dtype == np.uint16:
            img = img.astype(np.float32) / (2**16 - 1)
        elif img.dtype == np.uint8:
            img = img.astype(np.float32) / (2**8 - 1)
        else:
            img = img.astype(np.float32)
            lo, hi = float(img.min()), float(img.max())
            # Other dtypes have no fixed full-scale value; rescale only when the
            # data is not already inside [0, 1].
            if lo < 0.0 or hi > 1.0:
                img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
        return MedicalImageLoader._to_rgb(img)

class MedicalImageTransformer:
    """Convert loader output into the channel-first tensor the models consume."""

    def __init__(self):
        self.transform = Compose([
            # MedicalImageLoader returns channel-last (H, W, 3) numpy arrays with
            # no metadata. EnsureChannelFirst needs either metadata or an explicit
            # channel_dim to find the channel axis, and raises without one.
            EnsureChannelFirst(channel_dim=-1),
            ScaleIntensity(),  # rescale intensities to [0, 1]
            Resize(CONFIG["image_size"]),
            ToTensor()
        ])

    def transform_image(self, image_array: np.ndarray) -> Tuple[Optional[torch.Tensor], Optional[str]]:
        """Apply the MONAI transform chain to a loaded image.

        Args:
            image_array: Channel-last image array as produced by MedicalImageLoader.

        Returns:
            (tensor, None) on success, or (None, error_message) on failure.
        """
        try:
            if image_array is None:
                raise ValueError("no image array provided")

            # MONAI transforms operate in float; cast once up front.
            if image_array.dtype != np.float32:
                image_array = image_array.astype(np.float32)

            return self.transform(image_array), None
        except Exception as e:
            error_msg = f"Image transformation failed: {str(e)}"
            logger.error(error_msg)
            return None, error_msg

class MedicalAnalysisPipeline:
    """End-to-end analysis: load an image, caption it with BLIP, score it with the
    DenseNet121 classifier, and assemble a structured report and probability plot."""

    def __init__(self):
        self.image_loader = MedicalImageLoader()
        self.image_transformer = MedicalImageTransformer()
        self._init_models()
        self._load_assets()

    def _init_models(self):
        """Load all models onto ``device``.

        Raises:
            RuntimeError: if any model cannot be initialised.
        """
        try:
            # BLIP model
            self.blip_processor = BlipProcessor.from_pretrained(CONFIG["blip_model"])
            self.blip_model = BlipForConditionalGeneration.from_pretrained(
                CONFIG["blip_model"]
            ).to(device)

            # Clinical LLM
            # NOTE(review): llm_tokenizer / llm_model are never used by this class;
            # they only cost download time and memory. Kept in case other cells
            # rely on the attributes.
            self.llm_tokenizer = AutoTokenizer.from_pretrained(CONFIG["llm_model"])
            self.llm_model = BertLMHeadModel.from_pretrained(
                CONFIG["llm_model"],
                is_decoder=True
            ).to(device)

            # MONAI model
            self.monai_model = DenseNet121(
                spatial_dims=2,
                in_channels=3,
                out_channels=CONFIG["num_classes"]
            ).to(device)

            weights_path = "pretrained_chexpert.pth"
            if os.path.exists(weights_path):
                # weights_only=True restricts unpickling to tensors/containers so a
                # tampered checkpoint file cannot execute arbitrary code.
                self.monai_model.load_state_dict(
                    torch.load(weights_path, map_location=device, weights_only=True)
                )
                logger.info("Loaded pretrained weights")
            else:
                # Without a checkpoint the classifier keeps its initial weights, so
                # its probabilities (and the impression built from them) are not
                # clinically meaningful.
                logger.warning(
                    "pretrained_chexpert.pth not found; DenseNet121 is untrained and "
                    "its pathology predictions are meaningless"
                )

            self.monai_model.eval()

        except Exception as e:
            logger.error(f"Model initialization failed: {e}")
            raise RuntimeError(f"Could not initialize models: {e}") from e

    def _load_assets(self):
        """Load the terminology map and report templates from the working directory.

        Raises:
            RuntimeError: if either JSON file cannot be read or parsed.
        """
        try:
            with open("medical_terminology.json", "r") as f:
                self.med_terms = json.load(f)

            with open("report_templates.json", "r") as f:
                self.report_templates = json.load(f)

        except Exception as e:
            logger.error(f"Failed to load assets: {e}")
            raise RuntimeError(f"Could not load required assets: {e}") from e

    def analyze(self, image_path: str, clinical_context: str = "") -> Dict:
        """Run the full pipeline on one image file.

        Args:
            image_path: Path to a DICOM (``.dcm``) or standard image file.
            clinical_context: Optional free text inserted into the BLIP prompt.

        Returns:
            On success: a dict with ``success``, ``findings``, ``pathologies``,
            ``report``, ``visualization`` (plot path, ``""`` if plotting failed),
            ``clinical_context`` and ``timestamp``. On failure: ``{"error": message}``.
        """
        if not image_path:
            return {"error": "No image file provided"}

        try:
            # Step 1: Load image
            img_array, error = self.image_loader.load_image(image_path)
            if error:
                return {"error": error}

            # Step 2: Transform image
            img_tensor, error = self.image_transformer.transform_image(img_array)
            if error:
                return {"error": error}

            # Step 3: Generate findings
            findings = self._generate_findings(img_tensor, clinical_context)

            # Step 4: Detect pathologies
            pathologies = self._detect_pathologies(img_tensor)

            # Step 5: Generate report
            report = self._generate_report(findings, pathologies)

            # Step 6: Create visualization
            visualization = self._create_visualization(pathologies)

            return {
                "success": True,
                "findings": findings,
                "pathologies": pathologies,
                "report": report,
                "visualization": visualization,
                "clinical_context": clinical_context,
                "timestamp": datetime.now().isoformat()
            }

        except Exception as e:
            error_msg = f"Analysis pipeline failed: {str(e)}"
            logger.error(error_msg)
            return {"error": error_msg}

    def _generate_findings(self, image_tensor: torch.Tensor, context: str) -> str:
        """Caption the image with BLIP (prompted with the clinical context) and
        normalise its terminology. Returns a fallback sentence on failure."""
        try:
            img_pil = self._tensor_to_pil(image_tensor)
            prompt = f"Analyze this medical image. Clinical context: {context or 'None'}"
            # NOTE(review): BLIP conditional captioning treats the text as a caption
            # prefix, so the decoded output presumably begins with the prompt
            # itself. Confirm and strip it if that is undesired.

            inputs = self.blip_processor(
                img_pil,
                text=prompt,
                return_tensors="pt"
            ).to(device)

            # Beam search without sampling is deterministic. Sampling knobs such as
            # temperature are ignored unless do_sample=True, so none are passed.
            outputs = self.blip_model.generate(
                **inputs,
                max_new_tokens=300,
                num_beams=4,
            )

            findings = self.blip_processor.decode(outputs[0], skip_special_tokens=True)
            return self._enhance_terminology(findings)

        except Exception as e:
            logger.error(f"Findings generation failed: {e}")
            return "Could not generate detailed findings"

    def _detect_pathologies(self, image_tensor: torch.Tensor) -> Dict:
        """Score each class in CONFIG["class_names"] with an independent sigmoid.

        Classes with probability above 0.5 are listed under ``detected``. On
        failure all probabilities are 0.0 and an ``error`` key is set.
        """
        try:
            with torch.no_grad():
                batch = image_tensor.unsqueeze(0).to(device)  # add a batch axis
                outputs = self.monai_model(batch)
                probs = torch.sigmoid(outputs).cpu().numpy()[0]

            return {
                "labels": CONFIG["class_names"],
                "probabilities": probs.tolist(),
                "detected": [
                    CONFIG["class_names"][i]
                    for i, p in enumerate(probs)
                    if p > 0.5
                ]
            }

        except Exception as e:
            logger.error(f"Pathology detection failed: {e}")
            return {
                "labels": CONFIG["class_names"],
                "probabilities": [0.0] * len(CONFIG["class_names"]),
                "detected": [],
                "error": "Pathology detection failed"
            }

    def _generate_report(self, findings: str, pathologies: Dict) -> Dict:
        """Fill the ``full_report`` template sections; falls back to placeholders on error."""
        try:
            impression = self._get_impression(pathologies)

            return {
                "comparison": self.report_templates["full_report"]["comparison"].format(findings=findings),
                "technique": self.report_templates["full_report"]["technique"],
                "findings": self._format_findings(findings, pathologies),
                "impression": impression
            }

        except Exception as e:
            logger.error(f"Report generation failed: {e}")
            return {
                "comparison": "Could not generate comparison",
                "technique": "Standard technique",
                "findings": findings,
                "impression": "Could not generate full impression"
            }

    def _create_visualization(self, pathologies: Dict) -> str:
        """Save a bar chart of per-class probabilities.

        Returns:
            ``"pathology_plot.png"`` (relative to the working directory) on
            success, or ``""`` if plotting failed.
        """
        fig = None
        try:
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.barplot(
                x=pathologies["probabilities"],
                y=pathologies["labels"],
                palette="viridis",
                ax=ax,
            )
            ax.set_title("Pathology Probability Scores")
            ax.set_xlim(0, 1)
            fig.tight_layout()

            plot_path = "pathology_plot.png"
            fig.savefig(plot_path)
            return plot_path

        except Exception as e:
            logger.error(f"Visualization failed: {e}")
            return ""
        finally:
            # Release the figure on every path, including when plotting raised;
            # otherwise pyplot keeps failed figures alive.
            if fig is not None:
                plt.close(fig)

    def _tensor_to_pil(self, tensor: torch.Tensor) -> "Image.Image":
        """Convert a channel-first float tensor (expected in [0, 1]) to an 8-bit PIL image.

        Accepts ``(C, H, W)`` or a batched ``(N, C, H, W)`` tensor, in which case
        the first item is used. Values are clipped so the uint8 cast cannot wrap.
        """
        if tensor.dim() == 4:
            tensor = tensor[0]
        img = tensor.detach().permute(1, 2, 0).cpu().numpy()
        img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
        if img.shape[-1] == 1:
            img = img[..., 0]  # PIL cannot build an image from an (H, W, 1) array
        return Image.fromarray(img)

    def _get_impression(self, pathologies: Dict) -> str:
        """Summarise the detected pathologies in one sentence."""
        if pathologies.get("error"):
            return "Pathology analysis unavailable"

        if pathologies["detected"]:
            return f"Findings consistent with: {', '.join(pathologies['detected'])}. Clinical correlation recommended."
        return "No acute cardiopulmonary abnormality detected."

    def _enhance_terminology(self, text: str) -> str:
        """Replace lay terms from ``self.med_terms`` with their medical equivalents.

        Matching is whole-word and case-insensitive. All terms are replaced in a
        single pass, so the output of one rule is never rewritten by another.
        Terms are regex-escaped, so punctuation inside a term matches literally.
        """
        if not text or not self.med_terms:
            return text
        replacements = {term.lower(): value for term, value in self.med_terms.items()}
        # Try longer terms first so a multi-word term beats a shorter one.
        alternation = "|".join(
            re.escape(term) for term in sorted(replacements, key=len, reverse=True)
        )
        pattern = re.compile(rf"\b(?:{alternation})\b", flags=re.IGNORECASE)
        return pattern.sub(
            lambda m: replacements.get(m.group(0).lower(), m.group(0)), text
        )

    def _format_findings(self, findings: str, pathologies: Dict) -> str:
        """Mark each detected pathology name occurring in ``findings`` as ``**NAME**``.

        Matching is a case-insensitive substring match, and the marker uses the
        upper-cased pathology name.
        """
        if pathologies.get("error"):
            return findings

        for path in pathologies["detected"]:
            marker = f"**{path.upper()}**"
            # A callable replacement keeps the marker literal (no backslash escapes).
            findings = re.sub(
                re.escape(path), lambda _m: marker, findings, flags=re.IGNORECASE
            )
        return findings

def render_report_html(result: Dict) -> str:
    """Render a ``MedicalAnalysisPipeline.analyze`` result dict as an HTML report.

    Text that can come from the user or from model output is HTML-escaped,
    because the page is served through Gradio (possibly on a public share link)
    and must not execute injected markup. ``**NAME**`` highlight markers are
    turned into ``<strong>`` tags after escaping. The probability plot is
    inlined as a base64 data URI so it displays without relying on Gradio's
    file-serving routes.
    """
    import base64
    import html

    if result.get("error"):
        return (
            "<div style='color:red;padding:20px;'><h3>Error</h3>"
            f"<p>{html.escape(str(result['error']))}</p></div>"
        )

    def _rich_text(value) -> str:
        # Escape everything first, then re-introduce the only markup we emit.
        escaped = html.escape(str(value))
        return re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", escaped)

    plot_path = result.get("visualization")
    if plot_path and os.path.isfile(plot_path):
        with open(plot_path, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("ascii")
        plot_html = f'<img src="data:image/png;base64,{encoded}" style="max-width: 100%;">'
    else:
        plot_html = "<p><em>Pathology plot unavailable.</em></p>"

    context = html.escape(str(result.get("clinical_context") or "None provided"))
    timestamp = html.escape(str(result.get("timestamp", "")))
    findings_html = _rich_text(result["report"]["findings"])
    impression_html = _rich_text(result["report"]["impression"])

    return f"""
        <div style="font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto;">
            <h1 style="color: #2c3e50;">Medical Imaging Analysis Report</h1>
            <p><strong>Clinical Context:</strong> {context}</p>
            <p><em>Generated at: {timestamp}</em></p>
            
            <h2 style="color: #2c3e50; border-bottom: 1px solid #3498db;">Findings</h2>
            <div style="background: #f8f9fa; padding: 15px; border-radius: 5px;">
                {findings_html}
            </div>
            
            <h2 style="color: #2c3e50; border-bottom: 1px solid #3498db;">Pathology Analysis</h2>
            <div style="margin: 20px 0;">
                {plot_html}
            </div>
            
            <h2 style="color: #2c3e50; border-bottom: 1px solid #3498db;">Impression</h2>
            <div style="background: #f8f9fa; padding: 15px; border-radius: 5px;">
                {impression_html}
            </div>
            
            <div style="margin-top: 30px; font-size: 0.9em; color: #666;">
                <em>This AI-generated report requires verification by a qualified radiologist.</em>
            </div>
        </div>
        """


def create_gradio_interface():
    """Build the Gradio Blocks app around a single MedicalAnalysisPipeline.

    The pipeline (and all of its models) is constructed once, when this
    function is called, and shared by every request.
    """
    analyzer = MedicalAnalysisPipeline()

    def analyze_wrapper(image_path: str, clinical_context: str) -> str:
        """Analyse the uploaded file and return the report as HTML."""
        # With type="filepath" gr.File yields a path string; some older Gradio
        # releases pass a tempfile wrapper exposing the path as ``.name``.
        image_path = getattr(image_path, "name", image_path)
        if not image_path:
            return render_report_html({"error": "Please upload an image first."})

        result = analyzer.analyze(image_path, clinical_context)
        return render_report_html(result)

    with gr.Blocks() as app:
        gr.Markdown("# 🏥 Medical Imaging Analysis System")

        with gr.Row():
            with gr.Column():
                image_input = gr.File(
                    label="Upload Medical Image",
                    file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                    type="filepath",
                )
                context_input = gr.Textbox(
                    label="Clinical Context (optional)",
                    placeholder="Patient symptoms or history"
                )
                analyze_btn = gr.Button("Analyze", variant="primary")

            with gr.Column():
                report_output = gr.HTML(
                    label="Analysis Report",
                    value="<div style='text-align:center;padding:20px;'>Results will appear here</div>"
                )

        analyze_btn.click(
            fn=analyze_wrapper,
            inputs=[image_input, context_input],
            outputs=[report_output]
        )

    return app

if __name__ == "__main__":
    app = create_gradio_interface()
    # share=True publishes the UI, and every image uploaded to it, through a
    # public gradio.live URL. It stays the default because the notebook server
    # is presumably not reachable from a browser otherwise. Set GRADIO_SHARE=0
    # to serve on the local URL only.
    app.launch(share=os.environ.get("GRADIO_SHARE", "1") != "0")

* Running on local URL:  http://127.0.0.1:7861
* Running on public URL: https://42d52c42971eca57fa.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
