# In [None]:
import os
import numpy as np
import pydicom
import tensorflow as tf
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
import requests
from tqdm import tqdm
import zipfile
import logging
from tensorflow.keras.models import load_model
from matplotlib.patches import Rectangle

# Root logging configuration: INFO and above, each record prefixed with a
# timestamp and the level name.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)

class DicomDownloader:
    """
    Handles downloading of sample DICOM images for testing.

    Each source is a zip archive. The archive is downloaded into
    ``output_dir``, extracted there, and then deleted.
    """

    # Placeholder archives; replace with your specific DICOM sources or pass
    # ``sources=`` to the constructor.
    DEFAULT_SOURCES = {
        "mini_MIAS": "https://example.com/mini_mias.zip",
        "CBIS-DDSM": "https://example.com/sample_ddsm.zip",
    }

    def __init__(self, output_dir="test_images", sources=None, timeout=30):
        """
        Parameters:
        - output_dir: Directory that receives the extracted files (created if missing)
        - sources: Optional mapping of source name -> zip URL; defaults to DEFAULT_SOURCES
        - timeout: Seconds passed to requests as the connect/read timeout
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        # Copy so that per-instance edits never mutate the class-level defaults.
        self.sources = dict(self.DEFAULT_SOURCES if sources is None else sources)
        self.timeout = timeout

    def download_samples(self):
        """Downloads and extracts every configured source.

        A failing source is logged and skipped, so one bad URL does not stop
        the others.
        """
        for source_name, url in self.sources.items():
            try:
                logger.info(f"Downloading {source_name} dataset...")
                self._download_and_extract(source_name, url)
            except Exception as e:
                logger.error(f"Error downloading {source_name}: {e}")

    def _download_and_extract(self, source_name, url):
        """Download ``url`` to a temporary zip, extract it into output_dir, and delete the zip.

        Raises:
        - requests.HTTPError for a non-2xx response (instead of saving the
          error page as if it were an archive)
        - requests exceptions on connection problems or timeout
        - zipfile.BadZipFile if the payload is not a zip archive
        """
        zip_path = os.path.join(self.output_dir, f"{source_name}.zip")

        try:
            # Use a context manager so the streamed connection is always released.
            with requests.get(url, stream=True, timeout=self.timeout) as response:
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))

                # total=None lets tqdm show plain counters when the size is unknown.
                with open(zip_path, 'wb') as file, tqdm(
                    desc=source_name,
                    total=total_size or None,
                    unit='iB',
                    unit_scale=True
                ) as progress_bar:
                    for chunk in response.iter_content(chunk_size=8192):
                        progress_bar.update(file.write(chunk))

            logger.info(f"Extracting {source_name} files...")
            # NOTE: the archive comes from the network. zipfile tries to keep
            # members inside output_dir, but only extract from trusted sources.
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(self.output_dir)
        finally:
            # Clean up the temporary archive whether or not extraction succeeded.
            if os.path.exists(zip_path):
                os.remove(zip_path)

class EnhancedVisualizer:
    """
    Handles advanced visualization of mammogram analysis results.

    Renders a 2 x 4 dashboard with five panels:
    - the image framed by a confidence-coloured border
    - per-class probability bars
    - a confidence gauge
    - a prediction summary
    - a DICOM metadata panel
    """

    # Confidence cut-offs shared by the image border and the gauge bands, so
    # both panels agree on what counts as high / medium / low confidence.
    HIGH_CONFIDENCE = 0.8
    MEDIUM_CONFIDENCE = 0.6
    # Width of the black confidence marker on the gauge, in axis units.
    _MARKER_WIDTH = 0.02

    def __init__(self, class_names):
        """
        Parameters:
        - class_names: Ordered class labels; fixes the bar order and the
          number of palette colours
        """
        self.class_names = class_names
        self.colors = sns.color_palette("husl", len(class_names))

    def create_detailed_visualization(self, img, results, save_path=None):
        """
        Creates a comprehensive visualization of analysis results

        Parameters:
        - img: Processed mammogram image, 2-D or H x W x C (channels are
          averaged for display)
        - results: Dictionary with keys 'confidence', 'probabilities'
          (class name -> probability), 'predicted_class' and 'dicom_metadata'
        - save_path: Optional path to save the visualization
        """
        self._apply_style()
        fig = plt.figure(figsize=(20, 10))
        grid = (2, 4)

        try:
            # Top row: image (left half) and probability bars (right half)
            self._plot_mammogram(
                plt.subplot2grid(grid, (0, 0), colspan=2, fig=fig), img, results)
            self._plot_probabilities(
                plt.subplot2grid(grid, (0, 2), colspan=2, fig=fig), results)

            # Bottom row: confidence gauge, key metrics, DICOM metadata
            self._plot_confidence_gauge(
                plt.subplot2grid(grid, (1, 0), fig=fig), results['confidence'])
            self._plot_metrics_summary(
                plt.subplot2grid(grid, (1, 1), fig=fig), results)
            self._plot_metadata_summary(
                plt.subplot2grid(grid, (1, 2), colspan=2, fig=fig),
                results['dicom_metadata'])

            fig.tight_layout()

            if save_path:
                fig.savefig(save_path, bbox_inches='tight', dpi=300)
            plt.show()
        finally:
            # In non-interactive sessions (scripts, Agg backend) plt.show()
            # does not free the figure. Batch runs would otherwise keep one
            # open figure per analysed image.
            if not plt.isinteractive():
                plt.close(fig)

    @staticmethod
    def _apply_style():
        """Apply matplotlib's bundled seaborn style if it is available.

        matplotlib 3.6 renamed the 'seaborn' style to 'seaborn-v0_8', and later
        releases dropped the old name, so plt.style.use('seaborn') fails there.
        Prefer the new name and fall back to the old one. If neither exists,
        keep the current style.
        """
        for style in ('seaborn-v0_8', 'seaborn'):
            if style in plt.style.available:
                plt.style.use(style)
                return

    @classmethod
    def _confidence_color(cls, confidence):
        """Map a confidence value to 'green', 'yellow' or 'red' (thresholds are exclusive)."""
        if confidence > cls.HIGH_CONFIDENCE:
            return 'green'
        if confidence > cls.MEDIUM_CONFIDENCE:
            return 'yellow'
        return 'red'

    @staticmethod
    def _to_display_image(img):
        """Return ``img`` as a 2-D array suitable for imshow(cmap='gray').

        A 3-D (H x W x C) input is averaged over its last axis. For
        replicated-grey images (identical channels) this equals any single
        channel.
        """
        img = np.asarray(img)
        if img.ndim == 3:
            img = img.mean(axis=-1)
        return img

    def _ordered_probabilities(self, results):
        """Return results['probabilities'] values in self.class_names order."""
        # Look up by name so a differently ordered dict cannot mislabel bars.
        probabilities = results['probabilities']
        return [probabilities[name] for name in self.class_names]

    def _plot_mammogram(self, ax, img, results):
        """Plots mammogram with confidence-based border"""
        # cmap only applies to 2-D data. An RGB float array would instead be
        # clipped to [0, 1], hiding the negative half of a [-1, 1] image.
        ax.imshow(self._to_display_image(img), cmap='gray')
        ax.set_title('Processed Mammogram')
        ax.axis('off')

        # Border drawn in axes coordinates so it frames the whole panel.
        color = self._confidence_color(results['confidence'])
        rect = Rectangle((0, 0), 1, 1, transform=ax.transAxes,
                         facecolor='none', edgecolor=color, linewidth=3)
        ax.add_patch(rect)

    def _plot_probabilities(self, ax, results):
        """Plots probability distribution"""
        probabilities = self._ordered_probabilities(results)
        bars = ax.bar(self.class_names, probabilities, color=self.colors)
        ax.set_title('Prediction Probabilities')
        ax.set_ylim([0, 1])
        # Rotate on this axes explicitly rather than via pyplot's "current axes".
        ax.tick_params(axis='x', labelrotation=45)

        # Add value labels above each bar
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height,
                    f'{height:.2%}',
                    ha='center', va='bottom')

    def _plot_confidence_gauge(self, ax, confidence):
        """Plots confidence gauge chart"""
        ax.set_title('Confidence Level')

        # Background bands use the same cut-offs as the image border colour.
        bands = [
            (0.0, self.MEDIUM_CONFIDENCE, 'red'),
            (self.MEDIUM_CONFIDENCE, self.HIGH_CONFIDENCE, 'yellow'),
            (self.HIGH_CONFIDENCE, 1.0, 'green'),
        ]
        for start, stop, color in bands:
            ax.barh(0, stop - start, left=start, color=color, alpha=0.3)

        # Centre the marker on the value but keep it inside [0, 1], so a
        # confidence of 1.0 is still visible.
        width = self._MARKER_WIDTH
        left = float(np.clip(confidence - width / 2, 0.0, 1.0 - width))
        ax.barh(0, width, left=left, color='black')
        ax.text(0.5, -0.5, f'{confidence:.1%}', ha='center', va='center')

        ax.set_xlim(0, 1)
        ax.set_ylim(-1, 1)
        ax.axis('off')

    def _plot_metrics_summary(self, ax, results):
        """Plots key metrics summary"""
        ax.text(0.5, 0.8, f"Prediction:\n{results['predicted_class']}",
                ha='center', va='center', fontsize=10)
        ax.text(0.5, 0.2, f"Confidence:\n{results['confidence']:.1%}",
                ha='center', va='center', fontsize=10)
        ax.axis('off')

    def _plot_metadata_summary(self, ax, metadata):
        """Plots DICOM metadata summary, one 'key: value' line per entry"""
        ax.text(0.05, 0.95, "DICOM Metadata:", fontsize=10, va='top')
        y_pos = 0.8
        for key, value in metadata.items():
            ax.text(0.1, y_pos, f"{key}: {value}", fontsize=8)
            y_pos -= 0.15
        ax.axis('off')

class MammogramAnalyzer:
    """
    Main class for mammogram analysis.

    The pipeline has four steps:
    1. Load a Keras model.
    2. Preprocess a DICOM file into an ``img_size`` x ``img_size`` x 3 float
       image scaled to [-1, 1].
    3. Classify the image.
    4. Optionally render the result with EnhancedVisualizer.
    """

    # Labels for the model's output units, in output order.
    DEFAULT_CLASS_NAMES = ('BENIGN_WITHOUT_CALLBACK', 'BENIGN', 'MALIGNANT')
    # Images with fewer rows or columns than this are rejected.
    MIN_DIMENSION = 100

    def __init__(self, model_path, img_size=224, class_names=None):
        """
        Parameters:
        - model_path: Path passed to tensorflow.keras.models.load_model
        - img_size: Side length in pixels the image is resized to before inference
        - class_names: Optional labels for the model outputs, in output order;
          defaults to DEFAULT_CLASS_NAMES
        """
        self.img_size = img_size
        self.model = self.load_model(model_path)
        self.class_names = list(
            self.DEFAULT_CLASS_NAMES if class_names is None else class_names
        )
        self.visualizer = EnhancedVisualizer(self.class_names)

    def load_model(self, model_path):
        """Loads the trained model. Any loading error is logged and re-raised."""
        try:
            logger.info(f"Loading model from {model_path}")
            # Refers to the module-level tensorflow.keras load_model, not this method.
            return load_model(model_path)
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    @staticmethod
    def _window_to_uint8(img, low_pct=1, high_pct=99):
        """Clip ``img`` to its [low_pct, high_pct] percentile window and rescale to uint8 0-255.

        A flat image (an empty window) has no contrast to stretch. It yields
        all zeros instead of the 0/0 NaNs the plain formula would produce.
        """
        lo, hi = np.percentile(img, (low_pct, high_pct))
        if hi <= lo:
            return np.zeros(np.shape(img), dtype=np.uint8)
        img = np.clip(img, lo, hi)
        return ((img - lo) / (hi - lo) * 255).astype(np.uint8)

    def enhance_mammogram(self, img):
        """Enhances mammogram image quality.

        Applies a 1st-99th percentile contrast window, rescales to uint8, then
        applies CLAHE (clipLimit=2.0, 8x8 tiles).

        Returns the enhanced uint8 image, or None if enhancement fails (the
        error is logged).
        """
        try:
            img = self._window_to_uint8(img)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            return clahe.apply(img)
        except Exception as e:
            logger.error(f"Error enhancing image: {e}")
            return None

    def preprocess_dicom(self, file_path):
        """Preprocesses a DICOM image for analysis.

        Returns ``(img, dicom)``:
        - img is an (img_size, img_size, 3) float32 array in [-1, 1] with
          three identical channels
        - dicom is the parsed pydicom dataset

        Returns ``(None, None)`` on any failure (the error is logged).
        """
        try:
            dicom = pydicom.dcmread(file_path, force=True)
            img = dicom.pixel_array.astype(np.float32)

            # Only single-frame greyscale pixel data is supported. pydicom
            # returns multi-frame and colour data as 3-D arrays, which the
            # resize/CLAHE steps below cannot handle.
            if img.ndim != 2:
                raise ValueError(f"Expected a 2-D greyscale image, got shape {img.shape}")
            if img.shape[0] < self.MIN_DIMENSION or img.shape[1] < self.MIN_DIMENSION:
                raise ValueError("Image dimensions too small")
            # NOTE(review): PhotometricInterpretation (e.g. MONOCHROME1, where
            # larger values are darker) is not applied here. Confirm this
            # matches the preprocessing used when the model was trained.

            img = cv2.resize(img, (self.img_size, self.img_size))
            img = self.enhance_mammogram(img)
            if img is None:
                raise ValueError("Image enhancement failed")

            # Map uint8 [0, 255] to [-1, 1], then replicate into 3 channels.
            img = (img.astype(np.float32) - 127.5) / 127.5
            img = np.stack([img] * 3, axis=-1)

            return img, dicom
        except Exception as e:
            logger.error(f"Error preprocessing DICOM {file_path}: {e}")
            return None, None

    def analyze_image(self, file_path, visualization=True, save_path=None):
        """
        Analyzes a single mammogram image

        Parameters:
        - file_path: Path to DICOM file
        - visualization: Whether to show visualization
        - save_path: Path to save visualization

        Returns a dict with these keys, or None if analysis failed (logged):
        - 'predicted_class'
        - 'confidence'
        - 'probabilities' (class name -> probability)
        - 'dicom_metadata'

        A visualization failure is logged and does not discard the result.
        """
        try:
            img, dicom = self.preprocess_dicom(file_path)
            if img is None:
                return None

            # verbose=0 suppresses Keras' per-call progress bar.
            prediction = self.model.predict(np.expand_dims(img, axis=0), verbose=0)[0]
            if len(prediction) != len(self.class_names):
                raise ValueError(
                    f"Model returned {len(prediction)} scores but "
                    f"{len(self.class_names)} class names are configured"
                )

            best = int(np.argmax(prediction))
            predicted_class = self.class_names[best]
            confidence = float(prediction[best])

            results = {
                'predicted_class': predicted_class,
                'confidence': confidence,
                'probabilities': {
                    class_name: float(prob)
                    for class_name, prob in zip(self.class_names, prediction)
                },
                'dicom_metadata': {
                    'PatientID': getattr(dicom, 'PatientID', 'Unknown'),
                    'Modality': getattr(dicom, 'Modality', 'Unknown'),
                    'StudyDate': getattr(dicom, 'StudyDate', 'Unknown'),
                    'ImageLaterality': getattr(dicom, 'ImageLaterality', 'Unknown'),
                    'ViewPosition': getattr(dicom, 'ViewPosition', 'Unknown'),
                }
            }
        except Exception as e:
            logger.error(f"Error analyzing image: {e}")
            return None

        if visualization:
            try:
                # The three channels are identical copies, so channel 0 is the
                # greyscale image. A 2-D array lets imshow apply cmap='gray'
                # and auto-scale [-1, 1] instead of clipping RGB floats to [0, 1].
                self.visualizer.create_detailed_visualization(img[..., 0], results, save_path)
            except Exception as e:
                # A plotting failure should not throw away a valid prediction.
                logger.error(f"Error visualizing results for {file_path}: {e}")

        return results

def test_multiple_images(analyzer, test_directory, visualization=True):
    """
    Tests the model on multiple images in a directory.

    Parameters:
    - analyzer: Object exposing analyze_image(file_path, visualization=...)
      that returns a result dict or None
    - test_directory: Directory scanned (non-recursively) for regular files
      ending in '.dcm'; the extension match is case-insensitive
    - visualization: Forwarded to analyze_image; pass False to suppress the
      per-image figure during batch runs

    Returns:
    - Result dicts for the files analysed successfully, each with an added
      'filename' key, in sorted filename order
    """
    results = []

    # sorted() gives a deterministic processing order; os.listdir's is arbitrary.
    dicom_files = sorted(
        name for name in os.listdir(test_directory)
        if name.lower().endswith('.dcm')
        and os.path.isfile(os.path.join(test_directory, name))
    )

    for filename in dicom_files:
        file_path = os.path.join(test_directory, filename)
        logger.info(f"Analyzing {filename}")

        result = analyzer.analyze_image(file_path, visualization=visualization)
        if not result:
            continue

        result['filename'] = filename
        results.append(result)

        logger.info(f"Results for {filename}:")
        logger.info(f"Prediction: {result['predicted_class']}")
        logger.info(f"Confidence: {result['confidence']:.2%}")
        logger.info("-" * 50)

    return results

# Example usage scenarios
def main():
    """
    Example usage of the mammogram analysis system.

    Steps:
    1. Download the sample archives, but only when TEST_DIR contains no
       top-level .dcm files yet.
    2. Analyze TEST_DIR/sample.dcm on its own, if that file exists.
    3. Run a batch analysis over TEST_DIR and log summary statistics.

    Any error is logged with its traceback rather than raised.
    """
    from collections import Counter

    # Configuration
    MODEL_PATH = 'final_model.keras'
    TEST_DIR = 'test_images'

    try:
        # 1. Download sample images only if none are present yet
        have_samples = os.path.isdir(TEST_DIR) and any(
            name.lower().endswith('.dcm') for name in os.listdir(TEST_DIR)
        )
        if have_samples:
            logger.info(f"Found existing DICOM files in {TEST_DIR}; skipping download")
        else:
            downloader = DicomDownloader(TEST_DIR)
            downloader.download_samples()

        # 2. Initialize analyzer
        analyzer = MammogramAnalyzer(MODEL_PATH)

        # 3. Example 1: Analyze single image
        logger.info("\nExample 1: Single Image Analysis")
        single_image_path = os.path.join(TEST_DIR, 'sample.dcm')
        if os.path.exists(single_image_path):
            result = analyzer.analyze_image(
                single_image_path,
                visualization=True,
                save_path='analysis_result.png'
            )
            if result:
                logger.info(f"Prediction: {result['predicted_class']}")
                logger.info(f"Confidence: {result['confidence']:.2%}")

        # 4. Example 2: Batch Analysis
        logger.info("\nExample 2: Batch Analysis")
        if os.path.exists(TEST_DIR):
            results = test_multiple_images(analyzer, TEST_DIR)

            # Summary statistics
            if results:
                confidence_levels = [r['confidence'] for r in results]
                # One pass over the results instead of one scan per class.
                class_counts = Counter(r['predicted_class'] for r in results)
                logger.info("\nBatch Analysis Summary:")
                logger.info(f"Total images processed: {len(results)}")
                logger.info(f"Average confidence: {np.mean(confidence_levels):.2%}")
                logger.info("Predictions by class:")
                for class_name in analyzer.class_names:
                    logger.info(f"  {class_name}: {class_counts[class_name]}")

    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"Error in main execution: {e}")


# Run the example workflow only when executed directly, not on import.
if __name__ == "__main__":
    main()