In [None]:
# 1. Setup and Imports

import os
import json
import pandas as pd
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import shutil
from datetime import datetime
import time
import logging
from collections import Counter, defaultdict
import gc
import random
import warnings
# NOTE(review): this silences every warning category for the rest of the kernel
# session (including numeric RuntimeWarnings); consider narrowing the filter.
warnings.filterwarnings('ignore')

# Set up logging: INFO level, records formatted as "<time> - <level> - <message>"
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger; QualityIdentifier uses it for status, warning and error messages
logger = logging.getLogger(__name__)

print("Imports completed")

In [None]:
# 2. Path and Configuration

# Base folder that contains every dataset.
# Set the DREAM_DATASET_PATH environment variable to point the notebook at another
# location without editing this cell; when it is unset the original default is used.
BASE_PATH = os.environ.get("DREAM_DATASET_PATH", r"D:/Dream_dataset")

# Dataset name -> dataset folder. Every dataset is expected to live under BASE_PATH;
# adjust the sub-paths below if your folder structure differs.
datasets_config = {
    'APTOS2019': f'{BASE_PATH}/APTOS 2019',
    'Diabetic_Retinopathy_V03': f'{BASE_PATH}/Diabetic Retinopathy_V03',
    'IDRiD': f'{BASE_PATH}/IDRiD',
    'Messidor2': f'{BASE_PATH}/Messidor 2',
    'SUSTech_SYSU': f'{BASE_PATH}/SUSTech_SYSU',
    'DeepDRiD': f'{BASE_PATH}/DeepDRiD/DeepDRiD/DR/Original'
}

# Other configuration settings
OUTPUT_DIR = 'quality_review'      # Folder that receives sample and flagged images
RANDOM_SEED = 42                   # Seed handed to QualityIdentifier for sampling
N_SAMPLES_PER_DATASET = 300  # Number of images to sample for characterization

print("Configuration set")

In [None]:
# 3. Validate Dataset Paths

def validate_dataset_paths(datasets_config):
    """Report on each configured dataset folder and return the usable ones.

    A dataset is usable when its folder exists, contains at least one
    sub-folder named after a DR grade (0-4), and those grade folders hold at
    least one image file (searched recursively, extension matched
    case-insensitively).

    Args:
        datasets_config: mapping of dataset name -> dataset folder path.

    Returns:
        dict: the subset of ``datasets_config`` that passed every check.
    """
    grade_labels = ('No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR')
    # File names are lower-cased before matching, so lower-case suffixes suffice.
    image_suffixes = ('.jpg', '.jpeg', '.png', '.tiff', '.bmp')

    def count_images(folder):
        """Count image files anywhere below ``folder``."""
        return sum(
            1
            for _root, _dirs, filenames in os.walk(folder)
            for filename in filenames
            if filename.lower().endswith(image_suffixes)
        )

    usable = {}

    for name, path in datasets_config.items():
        print(f"\nChecking {name}:")
        print(f"  Path: {path}")

        if not os.path.exists(path):
            print("Path does not exist")
            continue

        try:
            grades_found = []
            counts_by_grade = {}

            # Only direct children named 0..4 count as DR class folders
            for entry in os.listdir(path):
                entry_path = os.path.join(path, entry)
                if not os.path.isdir(entry_path):
                    continue
                if not (entry.isdigit() and int(entry) in (0, 1, 2, 3, 4)):
                    continue
                grade = int(entry)
                grades_found.append(grade)
                counts_by_grade[grade] = count_images(entry_path)

            if not grades_found:
                print("No DR class folders (0,1,2,3,4) found")
                subfolders = [
                    entry for entry in os.listdir(path)
                    if os.path.isdir(os.path.join(path, entry))
                ]
                print(f"Available folders: {subfolders}")
                continue

            grades_found.sort()
            total_images = sum(counts_by_grade.values())
            print(f" Found DR classes: {grades_found}")
            print("Image counts per class:")
            for grade in grades_found:
                print(f"     {grade_labels[grade]} (Class {grade}): {counts_by_grade[grade]:,} images")
            print(f"Total images: {total_images:,}")

            if total_images > 0:
                usable[name] = path
            else:
                print("No images found")

        except Exception as e:
            print(f"Error accessing path: {e}")

    return usable

# Check every configured dataset folder; only the usable ones are kept
valid_datasets = validate_dataset_paths(datasets_config)

if valid_datasets:
    print(f"\n{len(valid_datasets)} Datasets loaded successfully")
else:
    # Nothing usable: tell the user what to fix before re-running
    print("\nNo valid datasets found!")
    print("1. Update BASE_PATH in cell [2] to your actual dataset location")
    print("2. Ensure your datasets have folders named 0, 1, 2, 3, 4 containing images")

In [None]:
# 4. Quality Identifier Class Definition

class QualityIdentifier:
    """Samples fundus images from a dataset and measures per-image quality metrics.

    Attributes:
        output_dir: folder that receives review artefacts (created on init).
        dataset_profiles: dataset name -> profile dict, filled in by callers.
        identification_results: list reserved for per-image results.
        random_seed: seed applied to NumPy's global RNG on construction.
    """

    # File names are lower-cased before matching, so lower-case suffixes suffice.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.tiff', '.bmp')

    # Folder-name aliases recognised when no numeric grade folder is present.
    _DR_PATTERNS = {
        'no_dr': 0, 'normal': 0, 'grade_0': 0, 'class_0': 0,
        'mild': 1, 'grade_1': 1, 'class_1': 1,
        'moderate': 2, 'grade_2': 2, 'class_2': 2,
        'severe': 3, 'grade_3': 3, 'class_3': 3,
        'proliferative': 4, 'grade_4': 4, 'class_4': 4
    }

    _DR_NAMES = ('No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR')

    def __init__(self, output_dir='quality_review', random_seed=42):
        """Create the output folders and seed NumPy's global random generator.

        Args:
            output_dir: root folder for ``sample_images`` and ``flagged_samples``.
            random_seed: seed passed to ``np.random.seed``.
        """
        self.output_dir = output_dir
        self.dataset_profiles = {}
        self.identification_results = []
        self.random_seed = random_seed

        np.random.seed(random_seed)

        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f'{output_dir}/sample_images', exist_ok=True)
        os.makedirs(f'{output_dir}/flagged_samples', exist_ok=True)

        logger.info(f"Quality identification output directory: {output_dir}")

    def extract_dr_label_from_path(self, image_path):
        """Infer the DR grade (0-4) from the components of ``image_path``.

        The first component (from the left) that is a number in 0..4 wins; if
        there is none, the first component matching a known alias such as
        ``mild`` or ``grade_3`` (case-insensitive) is used.

        Both ``/`` and ``\\`` are treated as separators so that forward-slash
        and mixed-separator paths resolve the same way on every platform.

        Args:
            image_path: path as ``str`` or ``os.PathLike``.

        Returns:
            int grade in 0..4, or ``None`` when no component identifies a grade.
        """
        parts = os.fspath(image_path).replace('\\', '/').split('/')

        for part in parts:
            # isdecimal() (unlike isdigit()) rejects characters such as '²'
            # that int() cannot parse.
            if part.isdecimal() and int(part) in (0, 1, 2, 3, 4):
                return int(part)

        for part in parts:
            label = self._DR_PATTERNS.get(part.lower())
            if label is not None:
                return label
        return None

    def sample_images_strategically(self, dataset_path, n_samples):
        """Draw a class-stratified random sample of image paths.

        Each DR class contributes up to ``max(1, n_samples // 5)`` images, drawn
        without replacement with ``np.random.choice``. Classes with fewer images
        contribute all of them, so the result may hold fewer than ``n_samples``
        paths.

        Args:
            dataset_path: folder searched recursively for labelled images.
            n_samples: target sample size across all five classes.

        Returns:
            list[str]: sampled paths grouped by class in ascending grade order;
            empty when no labelled image is found.
        """
        class_images = {0: [], 1: [], 2: [], 3: [], 4: []}

        for root, dirs, files in os.walk(dataset_path):
            for file in files:
                if file.lower().endswith(self._IMAGE_EXTENSIONS):
                    image_path = os.path.join(root, file)
                    dr_class = self.extract_dr_label_from_path(image_path)
                    if dr_class is not None:
                        class_images[dr_class].append(image_path)

        for dr_class, images in class_images.items():
            # os.walk yields entries in arbitrary order; sorting makes the draw
            # below depend only on the RNG state, i.e. reproducible per seed.
            images.sort()
            if images:
                print(f"    {self._DR_NAMES[dr_class]}: {len(images)} images")

        total_available = sum(len(images) for images in class_images.values())
        if total_available == 0:
            print("No images found with valid DR labels")
            return []

        samples_per_class = max(1, n_samples // 5)

        sampled_images = []
        for dr_class, images in class_images.items():
            if images:
                n_class_samples = min(samples_per_class, len(images))
                sampled = np.random.choice(images, n_class_samples, replace=False)
                # Hand back plain str rather than numpy string scalars
                sampled_images.extend(str(path) for path in sampled)

        print(f"Selected {len(sampled_images)} stratified samples")
        return sampled_images

    def safe_calculate_brightness(self, image):
        """Mean pixel value clamped to [0, 255]; 127.5 if it cannot be computed."""
        try:
            brightness = float(np.mean(image))
            return max(0.0, min(255.0, brightness))
        except Exception:
            return 127.5

    def safe_calculate_contrast(self, image):
        """Standard deviation of pixel values (>= 0); 50.0 if it cannot be computed."""
        try:
            contrast = float(np.std(image))
            return max(0.0, contrast)
        except Exception:
            return 50.0

    def safe_calculate_sharpness(self, gray):
        """Variance of the Laplacian of ``gray``; 500.0 if it cannot be computed."""
        try:
            laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
            return max(0.0, float(laplacian_var))
        except Exception:
            return 500.0

    def safe_calculate_entropy(self, gray):
        """Shannon entropy (bits) of the 256-bin histogram, clamped to [0, 8].

        Returns 4.0 when the histogram is empty or the computation fails.
        """
        try:
            hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
            hist = hist.flatten()
            hist = hist[hist > 0]  # drop empty bins so log2 is defined
            if len(hist) == 0:
                return 4.0
            prob = hist / hist.sum()
            entropy = float(-np.sum(prob * np.log2(prob)))
            return max(0.0, min(8.0, entropy))
        except Exception:
            return 4.0

    def assess_vessel_visibility_improved(self, image):
        """Fraction of pixels with a strong oriented-line response in channel 1.

        Four 3x3 line kernels (horizontal, vertical, two diagonals) are applied
        to ``image[:, :, 1]``; a pixel counts when its maximum response exceeds
        the image's own 95th percentile. The result is capped at 0.1; 0.0 is
        returned on any error.

        NOTE(review): the threshold is a percentile of the same response map, so
        the fraction cannot exceed roughly 0.05 and varies little between
        images - confirm this is the intended metric.
        """
        try:
            green = image[:, :, 1]

            kernels = {
                'horizontal': np.array([[-1, -1, -1], [2, 2, 2], [-1, -1, -1]], dtype=np.float32),
                'vertical': np.array([[-1, 2, -1], [-1, 2, -1], [-1, 2, -1]], dtype=np.float32),
                'diagonal1': np.array([[2, -1, -1], [-1, 2, -1], [-1, -1, 2]], dtype=np.float32),
                'diagonal2': np.array([[-1, -1, 2], [-1, 2, -1], [2, -1, -1]], dtype=np.float32)
            }

            vessel_responses = [
                cv2.filter2D(green, cv2.CV_32F, kernel) for kernel in kernels.values()
            ]

            # Strongest response over all orientations, per pixel
            max_response = np.maximum.reduce(vessel_responses)
            threshold = np.percentile(max_response, 95)
            vessel_pixels = np.sum(max_response > threshold)
            total_pixels = max_response.shape[0] * max_response.shape[1]

            visibility_score = vessel_pixels / max(total_pixels, 1)
            return float(min(0.1, visibility_score))

        except Exception:
            return 0.0

    def assess_optic_disc_visibility(self, image):
        """Fraction of grayscale pixels brighter than the image's 90th percentile.

        Returns 0.0 on any error.

        NOTE(review): because the threshold is the image's own percentile, the
        value cannot exceed roughly 0.10 - confirm this is the intended metric.
        """
        try:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            bright_pixels = np.sum(gray > np.percentile(gray, 90))
            total_pixels = gray.shape[0] * gray.shape[1]
            return float(bright_pixels / max(total_pixels, 1))
        except Exception:
            return 0.0

    def assess_illumination_uniformity(self, gray):
        """Uniformity of mean intensity over a 3x3 grid, in [0, 1].

        Computed as ``1 - std/mean`` of the nine region means, clamped to
        [0, 1]. Returns 0.0 for an all-black image or on any error.
        """
        try:
            h, w = gray.shape
            regions = []

            for i in range(3):
                for j in range(3):
                    start_y, end_y = i * h // 3, (i + 1) * h // 3
                    start_x, end_x = j * w // 3, (j + 1) * w // 3
                    region = gray[start_y:end_y, start_x:end_x]
                    regions.append(np.mean(region))

            mean_intensity = np.mean(regions)
            if mean_intensity > 0:
                cv_score = np.std(regions) / mean_intensity
                return float(max(0, min(1, 1 - cv_score)))
            return 0.0
        except Exception:
            return 0.0

    def detect_extreme_pixels(self, gray):
        """Fraction of pixels darker than 10 or brighter than 245; 0.0 on error."""
        try:
            very_dark = np.sum(gray < 10)
            very_bright = np.sum(gray > 245)
            total_pixels = gray.shape[0] * gray.shape[1]
            return float((very_dark + very_bright) / max(total_pixels, 1))
        except Exception:
            return 0.0

    def assess_motion_blur(self, gray):
        """Mean Sobel gradient magnitude of ``gray`` (higher = more edge detail).

        Returns 0.0 on any error.
        """
        try:
            grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
            magnitude = np.sqrt(grad_x**2 + grad_y**2)
            return float(np.mean(magnitude))
        except Exception:
            return 0.0

    def analyze_single_image(self, image_path):
        """Load one image and compute all quality characteristics.

        Args:
            image_path: path of the image file to analyse.

        Returns:
            dict with ``image_path``, ``filename``, ``resolution`` (w, h),
            ``file_size_mb`` and the numeric quality metrics, or ``None`` when
            the file is missing, unreadable, not 3-dimensional, smaller than
            100x100 pixels, or analysis raises.
        """
        try:
            if not os.path.exists(image_path):
                logger.warning(f"File not found: {image_path}")
                return None

            image = cv2.imread(image_path)
            if image is None:
                logger.warning(f"Could not read image: {image_path}")
                return None

            if len(image.shape) != 3:
                logger.warning(f"Invalid image format: {image_path}")
                return None

            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            h, w = image.shape[:2]

            if h < 100 or w < 100:
                logger.warning(f"Image too small ({w}x{h}): {image_path}")
                return None

            characteristics = {
                'image_path': image_path,
                'filename': os.path.basename(image_path),
                'resolution': (w, h),
                # Floor at 0.001 MB so the value is never zero
                'file_size_mb': max(0.001, os.path.getsize(image_path) / (1024*1024)),

                'brightness': self.safe_calculate_brightness(image),
                'contrast': self.safe_calculate_contrast(image),
                'sharpness': self.safe_calculate_sharpness(gray),
                'entropy': self.safe_calculate_entropy(gray),

                # Spread of the three per-channel means (0 = perfectly balanced)
                'color_balance': float(np.std([np.mean(image[:,:,0]), np.mean(image[:,:,1]), np.mean(image[:,:,2])])),

                'vessel_visibility': self.assess_vessel_visibility_improved(image),
                'optic_disc_visibility': self.assess_optic_disc_visibility(image),
                'illumination_uniformity': self.assess_illumination_uniformity(gray),

                'extreme_brightness_pixels': self.detect_extreme_pixels(gray),
                'motion_blur_score': self.assess_motion_blur(gray)
            }
            return characteristics

        except Exception as e:
            logger.error(f"Error analyzing {image_path}: {e}")
            return None

print("QualityIdentifier class defined")

In [None]:
# 4.5. Parameter Validation

class QualityParameters:
    """Container for the tunable quality-screening thresholds.

    NOTE(review): BASIC_WEIGHT / MEDICAL_WEIGHT (0.3 / 0.7) differ from the
    0.25 / 0.55 / 0.20 composite weights hard-coded in the scoring functions
    of this notebook - confirm which set is authoritative.
    """

    def __init__(self):
        # Validated thresholds
        self.BRIGHTNESS_BOUNDS = (10, 245)
        self.BLACK_THRESHOLD = 30
        self.VESSEL_PERCENTILE = 95
        self.OPTIC_DISC_PERCENTILE = 90

        # Severity-specific adaptive thresholds (DR grade -> percentile)
        self.SEVERITY_PERCENTILES = {
            0: 15,  # No DR - strictest
            1: 12,  # Mild DR
            2: 10,  # Moderate DR
            3: 8,   # Severe DR
            4: 5    # PDR - most relaxed
        }

        # Quality composition weights (must sum to 1)
        self.BASIC_WEIGHT = 0.3
        self.MEDICAL_WEIGHT = 0.7

        # Minimum image requirements
        self.MIN_RESOLUTION = (100, 100)
        self.MIN_FILE_SIZE_MB = 0.001

    def validate_bounds(self):
        """Check that every parameter lies in its valid range.

        Explicit checks are used instead of ``assert`` statements so that the
        validation still runs when Python is started with ``-O``.

        Returns:
            True when all checks pass.

        Raises:
            AssertionError: naming the first parameter that is out of range.
        """
        checks = (
            (0 <= self.BRIGHTNESS_BOUNDS[0] < self.BRIGHTNESS_BOUNDS[1] <= 255,
             f"BRIGHTNESS_BOUNDS must satisfy 0 <= low < high <= 255, got {self.BRIGHTNESS_BOUNDS}"),
            (0 < self.BLACK_THRESHOLD < 255,
             f"BLACK_THRESHOLD must be in (0, 255), got {self.BLACK_THRESHOLD}"),
            (0 < self.VESSEL_PERCENTILE <= 100,
             f"VESSEL_PERCENTILE must be in (0, 100], got {self.VESSEL_PERCENTILE}"),
            (0 < self.OPTIC_DISC_PERCENTILE <= 100,
             f"OPTIC_DISC_PERCENTILE must be in (0, 100], got {self.OPTIC_DISC_PERCENTILE}"),
            (all(0 < p <= 100 for p in self.SEVERITY_PERCENTILES.values()),
             f"SEVERITY_PERCENTILES values must be in (0, 100], got {self.SEVERITY_PERCENTILES}"),
            (abs(self.BASIC_WEIGHT + self.MEDICAL_WEIGHT - 1.0) < 0.001,
             f"BASIC_WEIGHT + MEDICAL_WEIGHT must equal 1.0, got {self.BASIC_WEIGHT + self.MEDICAL_WEIGHT}"),
        )
        for passed, message in checks:
            if not passed:
                raise AssertionError(message)
        return True

# Build the shared parameter object and fail fast if any threshold is out of range
quality_params = QualityParameters()
quality_params.validate_bounds()
print("Quality parameters validated")

In [None]:
# 5. Initialize Quality Identifier

# Proceed only if at least one dataset was validated in the cells above
if not valid_datasets:
    print("Cannot initialize - no valid datasets found")
else:
    # Creates the output folders and seeds NumPy's RNG as a side effect
    identifier = QualityIdentifier(random_seed=RANDOM_SEED, output_dir=OUTPUT_DIR)
    print("QualityIdentifier initialized!")
    print(f"Output directory: {OUTPUT_DIR}")

In [None]:
# 6. Dataset Characterization & Adaptive Thresholding

# Characterization function
def characterize_dataset(identifier, dataset_path, dataset_name, n_samples=300):
    """Sample a dataset, analyse the samples and build its quality profile.

    Args:
        identifier: object providing ``sample_images_strategically`` and
            ``analyze_single_image`` (a ``QualityIdentifier``).
        dataset_path: folder of the dataset to characterise.
        dataset_name: label used in messages and in the profile.
        n_samples: target number of images to sample.

    Returns:
        The profile dict from ``calculate_dataset_profile``, or ``None`` when
        no image could be sampled or none could be analysed.
    """
    print(f"Characterizing {dataset_name}...")

    sample_images = identifier.sample_images_strategically(dataset_path, n_samples)

    if not sample_images:
        print(f"No images found in {dataset_path}")
        return None

    total = len(sample_images)
    print(f"Analyzing {total} sample images...")

    # Analyze characteristics with progress updates
    characteristics = []
    for i, img_path in enumerate(sample_images):
        if i % 50 == 0 and i > 0:
            print(f"    Progress: {i}/{total} ({i/total*100:.1f}%)")
        char = identifier.analyze_single_image(img_path)
        if char:
            characteristics.append(char)

        # Memory management: periodic collection (nothing to collect at i == 0)
        if i > 0 and i % 100 == 0:
            gc.collect()

    if not characteristics:
        print(f"No valid characteristics extracted from {dataset_name}")
        return None

    print(f"Analyzed {len(characteristics)} valid images")

    # Calculate dataset profile
    profile = calculate_dataset_profile(dataset_name, characteristics)

    # Save up to 20 sample images for review. The sample list is grouped by DR
    # class, so taking the first 20 would show a single class; picking evenly
    # spaced entries keeps every sampled class represented.
    review_count = min(20, total)
    stride = total / review_count
    review_paths = [sample_images[int(k * stride)] for k in range(review_count)]
    save_sample_images(identifier, review_paths, dataset_name)
    return profile

# Calculate dataset profile
def _composite_quality_score(char, char_stats):
    """Composite quality score of one image, relative to the dataset statistics.

    Uses the same normalisation and weights as ``assess_image_quality_corrected``
    so that thresholds derived from these scores are comparable with the scores
    that function later tests against them. Keep the two in sync.

    Args:
        char: characteristics dict of one image.
        char_stats: ``profile['characteristics_stats']`` of the dataset.

    Returns:
        float: 0.25 * basic + 0.55 * medical + 0.20 * technical.
    """
    normalized = {}

    # Basic metrics: z-score against the dataset, mapped from [-3, 3] to [0, 1]
    for metric in ('brightness', 'contrast', 'sharpness', 'entropy'):
        value = char.get(metric)
        stats = char_stats.get(metric)
        if value is None or stats is None:
            continue
        if stats['std'] > 0:
            z_score = (float(value) - stats['mean']) / stats['std']
            normalized[metric] = max(0.0, min(1.0, (z_score + 3) / 6))
        else:
            normalized[metric] = 0.5

    # Medical and technical metrics: fixed-scale normalisation
    direct_normalizers = {
        'illumination_uniformity': lambda x: min(1.0, max(0.0, x)),
        'vessel_visibility': lambda x: min(1.0, max(0.0, x * 100)),
        'optic_disc_visibility': lambda x: min(1.0, max(0.0, x * 50)),
        'extreme_brightness_pixels': lambda x: max(0.0, min(1.0, 1 - (x * 2))),
        'motion_blur_score': lambda x: min(1.0, max(0.0, x / 50.0)),
        'color_balance': lambda x: max(0.0, min(1.0, 1 - (x / 50.0))),
    }
    for metric, normalizer in direct_normalizers.items():
        value = char.get(metric)
        if value is not None:
            normalized[metric] = normalizer(float(value))

    def component(metrics):
        # A metric that could not be normalised counts as neutral (0.5)
        return float(np.mean([normalized.get(m, 0.5) for m in metrics]))

    basic = component(('brightness', 'contrast', 'sharpness', 'entropy'))
    medical = component(('illumination_uniformity', 'vessel_visibility', 'optic_disc_visibility'))
    technical = component(('extreme_brightness_pixels', 'motion_blur_score', 'color_balance'))

    # COMPOSITE FORMULA: CQ = 0.25 × BasicQuality + 0.55 × MedicalQuality + 0.20 × TechnicalQuality
    return 0.25 * basic + 0.55 * medical + 0.20 * technical


def calculate_dataset_profile(dataset_name, characteristics):
    """Summarise sampled image characteristics into a dataset profile.

    Args:
        dataset_name: label stored in the profile.
        characteristics: list of per-image dicts as produced by
            ``QualityIdentifier.analyze_single_image``.

    Returns:
        dict with ``dataset_name``, ``analysis_date`` (ISO string),
        ``n_samples_analyzed``, ``characteristics_stats`` (per numeric metric:
        mean, std, min, max and percentiles keyed '5'..'95') and
        ``adaptive_thresholds`` (DR grade 0-4 -> composite-score threshold).
        With an empty ``characteristics`` list the stats are empty and every
        threshold is 0.3.
    """
    profile = {
        'dataset_name': dataset_name,
        'analysis_date': datetime.now().isoformat(),
        'n_samples_analyzed': len(characteristics),
        'characteristics_stats': {},
        'adaptive_thresholds': {}
    }

    non_numeric_keys = ('image_path', 'filename', 'resolution')
    numeric_keys = (
        [key for key in characteristics[0].keys() if key not in non_numeric_keys]
        if characteristics else []
    )
    percentile_points = (5, 10, 25, 50, 75, 90, 95)

    for key in numeric_keys:
        # Keep numeric values only (bool counts as 0/1); None and other types are skipped
        values = [
            float(char[key]) for char in characteristics
            if isinstance(char.get(key), (int, float, np.integer, np.floating))
        ]

        if values:
            percentile_values = np.percentile(values, percentile_points)
            profile['characteristics_stats'][key] = {
                'mean': float(np.mean(values)),
                'std': float(np.std(values)),
                'min': float(np.min(values)),
                'max': float(np.max(values)),
                'percentiles': {
                    str(point): float(value)
                    for point, value in zip(percentile_points, percentile_values)
                }
            }

    # Share of lowest-scoring images targeted for removal, per DR grade
    removal_percentiles = {
        0: 15,
        1: 12,
        2: 10,
        3: 8,
        4: 5
    }

    # Score every sampled image exactly the way images are scored at decision
    # time, so the percentile thresholds refer to the same score distribution.
    char_stats = profile['characteristics_stats']
    quality_scores = []
    for char in characteristics:
        try:
            quality_scores.append(_composite_quality_score(char, char_stats))
        except (TypeError, ValueError) as e:
            print(f"Warning: Error calculating quality score for image, using default: {e}")
            quality_scores.append(0.3)

    # Set percentile-based thresholds
    for dr_severity, percentile in removal_percentiles.items():
        try:
            threshold = np.percentile(quality_scores, percentile) if quality_scores else 0.3
            profile['adaptive_thresholds'][dr_severity] = float(threshold)
        except Exception as e:
            print(f"Warning: Error calculating threshold for DR severity {dr_severity}: {e}")
            profile['adaptive_thresholds'][dr_severity] = 0.3

    return profile

# Save sample images for visual review
def save_sample_images(identifier, sample_paths, dataset_name):
    """Copy sample images into ``<output_dir>/sample_images/<dataset_name>``.

    Each copy is named ``sample_<NN>_<original file name>`` where ``NN`` is the
    zero-padded position in ``sample_paths``. A file that cannot be copied is
    reported and skipped; the number of successful copies is printed at the end.

    Args:
        identifier: object exposing ``output_dir``.
        sample_paths: iterable of source image paths.
        dataset_name: name of the destination sub-folder.
    """
    target_dir = f'{identifier.output_dir}/sample_images/{dataset_name}'
    os.makedirs(target_dir, exist_ok=True)

    copied = []
    for position, source in enumerate(sample_paths):
        try:
            base_name = os.path.basename(source)
            target = f'{target_dir}/sample_{position:02d}_{base_name}'
            shutil.copy2(source, target)
        except Exception as e:
            print(f"Warning: Error copying sample {source}: {e}")
        else:
            copied.append(target)

    print(f"Saved {len(copied)} sample images to {target_dir}")

# Run characterization for all valid datasets
if not valid_datasets:
    print("Skipping characterization - no valid datasets")
else:
    dataset_profiles = {}

    # (label, stats key, decimals) of the metrics summarised after each dataset
    key_metric_rows = (
        ("Brightness", "brightness", 1),
        ("Sharpness", "sharpness", 1),
        ("Illumination uniformity", "illumination_uniformity", 3),
    )

    for dataset_name, dataset_path in valid_datasets.items():
        print("\n" + "=" * 60)
        profile = characterize_dataset(identifier, dataset_path, dataset_name, N_SAMPLES_PER_DATASET)
        if not profile:
            print(f"Failed characterization {dataset_name}")
            continue

        # Keep the profile both in the notebook namespace and on the identifier
        dataset_profiles[dataset_name] = profile
        identifier.dataset_profiles[dataset_name] = profile
        print(f"{dataset_name} characterized")

        # Show key metrics
        stats = profile['characteristics_stats']
        print("Key metrics (mean ± std):")
        for metric_label, metric_key, decimals in key_metric_rows:
            if metric_key in stats:
                metric_mean = stats[metric_key]['mean']
                metric_std = stats[metric_key]['std']
                print(f"      {metric_label}: {metric_mean:.{decimals}f} ± {metric_std:.{decimals}f}")

    print(f"\nCharacterization complete! Profiles created for {len(dataset_profiles)} datasets.")

In [None]:
# 7. Quality Issue Identification - Updated with 3-Component Scoring

def assess_image_quality_corrected(characteristics, profile, dr_severity):
    """Score one image with the 3-component system and decide KEEP or REMOVE.

    Args:
        characteristics: per-image dict from ``QualityIdentifier.analyze_single_image``.
        profile: dataset profile holding ``characteristics_stats`` and
            ``adaptive_thresholds``.
        dr_severity: DR grade of the image. The threshold is looked up by this
            key and, failing that, by its string form, so profiles that were
            round-tripped through JSON (where integer keys become strings)
            still resolve. 0.3 is used when neither key exists.

    Returns:
        dict with ``overall_score``, ``basic_score``, ``medical_score``,
        ``technical_score``, ``threshold``, ``action`` ('KEEP' / 'REMOVE'),
        ``reasons`` (list of hard-failure tags), ``confidence``
        ('HIGH' / 'MEDIUM') and ``normalized_scores``.

    Raises:
        KeyError: if ``characteristics`` lacks a metric the hard checks read
            directly (e.g. 'sharpness', 'brightness', 'resolution').
    """
    char_stats = profile['characteristics_stats']
    normalized_scores = {}

    basic_metrics = ('brightness', 'contrast', 'sharpness', 'entropy')
    medical_metrics = ('illumination_uniformity', 'vessel_visibility', 'optic_disc_visibility')
    technical_metrics = ('extreme_brightness_pixels', 'motion_blur_score', 'color_balance')

    # BasicQuality: dataset-relative z-score, mapped from [-3, 3] onto [0, 1]
    for metric in basic_metrics:
        if metric in char_stats and metric in characteristics:
            stats = char_stats[metric]
            value = characteristics[metric]

            if stats['std'] > 0:
                z_score = (value - stats['mean']) / stats['std']
                normalized_scores[metric] = max(0, min(1, (z_score + 3) / 6))
            else:
                normalized_scores[metric] = 0.5

    # MedicalQuality and TechnicalQuality: fixed-scale normalisation to [0, 1]
    direct_normalizers = {
        'illumination_uniformity': lambda x: min(1.0, max(0.0, x)),
        'vessel_visibility': lambda x: min(1.0, max(0.0, x * 100)),
        'optic_disc_visibility': lambda x: min(1.0, max(0.0, x * 50)),
        # Fewer extreme pixels = higher score
        'extreme_brightness_pixels': lambda x: max(0, min(1, 1 - (x * 2))),
        # Higher gradient magnitude (less blur) = higher score
        'motion_blur_score': lambda x: min(1.0, max(0.0, x / 50.0)),
        # Smaller spread between channel means = higher score
        'color_balance': lambda x: max(0, min(1, 1 - (x / 50.0))),
    }
    for metric, normalizer in direct_normalizers.items():
        if metric in characteristics:
            normalized_scores[metric] = normalizer(characteristics[metric])

    def component(metrics):
        # A metric without a normalised value counts as neutral (0.5)
        return float(np.mean([normalized_scores.get(m, 0.5) for m in metrics]))

    basic_score = component(basic_metrics)
    medical_score = component(medical_metrics)
    technical_score = component(technical_metrics)

    # COMPOSITE FORMULA: CQ = 0.25 × BasicQuality + 0.55 × MedicalQuality + 0.20 × TechnicalQuality
    overall_score = 0.25 * basic_score + 0.55 * medical_score + 0.20 * technical_score

    thresholds = profile['adaptive_thresholds']
    threshold = 0.3
    for key in (dr_severity, str(dr_severity)):
        if key in thresholds:
            threshold = thresholds[key]
            break

    # Hard quality checks; two or more failures force removal
    removal_reasons = []

    if characteristics['sharpness'] < char_stats.get('sharpness', {}).get('percentiles', {}).get('5', 0):
        removal_reasons.append('extremely_blurry')

    if characteristics['brightness'] < 20 or characteristics['brightness'] > 240:
        removal_reasons.append('extreme_brightness')

    if characteristics['illumination_uniformity'] < 0.1:
        removal_reasons.append('poor_illumination')

    if characteristics['vessel_visibility'] < char_stats.get('vessel_visibility', {}).get('percentiles', {}).get('10', 0):
        removal_reasons.append('poor_vessel_visibility')

    if characteristics['extreme_brightness_pixels'] > 0.3:
        removal_reasons.append('too_many_extreme_pixels')

    if characteristics['file_size_mb'] < 0.1:
        removal_reasons.append('file_too_small')

    w, h = characteristics['resolution']
    if w < 224 or h < 224:
        removal_reasons.append('resolution_too_low')

    # Very low gradient magnitude indicates severe blur
    if characteristics.get('motion_blur_score', 0) < 5:
        removal_reasons.append('severe_motion_blur')

    # Large spread between channel means indicates a severe colour cast
    if characteristics.get('color_balance', 0) > 40:
        removal_reasons.append('severe_color_imbalance')

    # Decision logic
    if len(removal_reasons) >= 2:
        action = 'REMOVE'
        confidence = 'HIGH'
    elif overall_score < threshold:
        action = 'REMOVE'
        confidence = 'MEDIUM'
    else:
        action = 'KEEP'
        confidence = 'HIGH' if overall_score > threshold + 0.1 else 'MEDIUM'

    return {
        'overall_score': overall_score,
        'basic_score': basic_score,
        'medical_score': medical_score,
        'technical_score': technical_score,
        'threshold': threshold,
        'action': action,
        'reasons': removal_reasons,
        'confidence': confidence,
        'normalized_scores': normalized_scores
    }

def identify_quality_issues(identifier, dataset_path, dataset_name, profile):
    """Score every labelled image below ``dataset_path`` and return one record per image.

    Args:
        identifier: Object providing ``extract_dr_label_from_path(path)`` (returns a
            DR label or ``None``) and ``analyze_single_image(path)`` (returns a dict of
            image characteristics, or a falsy value on failure).
        dataset_path: Root directory that is walked recursively.
        dataset_name: Name stored in each record's ``'dataset'`` field and used in
            progress messages.
        profile: Dataset profile forwarded unchanged to
            ``assess_image_quality_corrected``.

    Returns:
        list[dict]: One dict per successfully analysed image containing the
        assessment summary, every raw characteristic, and every normalized score
        under a ``normalized_<name>`` key.
    """
    print(f"Identifying quality issues in {dataset_name}")

    # Filenames are lower-cased before matching, so only lower-case suffixes are needed.
    image_extensions = ('.jpg', '.jpeg', '.png', '.tiff', '.bmp')

    # Single pass over the tree: remember each labelled image together with its label
    # so the label extractor runs once per file instead of once for counting and
    # once more for processing.
    candidates = []
    for root, _dirs, files in os.walk(dataset_path):
        for file in files:
            if not file.lower().endswith(image_extensions):
                continue
            image_path = os.path.join(root, file)
            dr_severity = identifier.extract_dr_label_from_path(image_path)
            if dr_severity is not None:
                candidates.append((image_path, file, dr_severity))

    total_images = len(candidates)
    print(f"  Found {total_images:,} images to analyze")

    results = []
    processed_count = 0
    error_count = 0

    for image_path, file, dr_severity in candidates:
        characteristics = identifier.analyze_single_image(image_path)
        if not characteristics:
            # Analysis failed for this image; count it and move on.
            error_count += 1
            continue

        quality_assessment = assess_image_quality_corrected(characteristics, profile, dr_severity)

        result = {
            'dataset': dataset_name,
            'image_path': image_path,
            'filename': file,
            'dr_severity': dr_severity,
            'overall_quality_score': quality_assessment['overall_score'],
            'basic_quality_score': quality_assessment['basic_score'],
            'medical_quality_score': quality_assessment['medical_score'],
            'technical_quality_score': quality_assessment['technical_score'],
            'threshold_used': quality_assessment['threshold'],
            'recommended_action': quality_assessment['action'],
            'removal_reasons': quality_assessment['reasons'],
            'confidence': quality_assessment['confidence']
        }

        # Raw characteristics are merged after the summary fields, so a
        # characteristic sharing a key with the summary would overwrite it.
        result.update(characteristics)
        result.update({f'normalized_{k}': v for k, v in quality_assessment['normalized_scores'].items()})

        results.append(result)
        processed_count += 1

        if processed_count % 1000 == 0:
            progress = processed_count / total_images * 100 if total_images > 0 else 0
            print(f"    Progress: {processed_count:,}/{total_images:,} ({progress:.1f}%)")
            gc.collect()

    print(f"  Analysis completed: {processed_count:,} images processed")
    if error_count > 0:
        print(f"  Images with errors: {error_count}")

    return results

# Run quality issue identification
# Driver cell: scores every dataset that has a profile, prints per-dataset and
# overall statistics, and leaves the combined per-image records in `all_results`
# for the cells below. `valid_datasets`, `dataset_profiles` and `identifier` must
# already exist from earlier cells; on a fresh kernel the condition below raises
# NameError instead of reaching the `else` branch.
if valid_datasets and dataset_profiles:
    # Banner only: the weights printed here describe the formula; the scoring
    # itself happens inside identify_quality_issues.
    print("Starting quality issue identification with 3-component scoring")
    print("NEW FORMULA: CQ = 0.25 × BasicQuality + 0.55 × MedicalQuality + 0.20 × TechnicalQuality")
    print("=" * 80)
    print("BasicQuality: brightness, contrast, sharpness, entropy")
    print("MedicalQuality: illumination uniformity, vessel visibility, optic disc visibility")
    print("TechnicalQuality: extreme pixels, motion blur, color balance")
    print("=" * 80)
    
    # Accumulates the records of every dataset; read by the later cells.
    all_results = []
    
    for dataset_name, profile in dataset_profiles.items():
        print(f"\n{'='*60}")
        # Raises KeyError if a profile exists for a dataset missing from valid_datasets.
        dataset_path = valid_datasets[dataset_name]
        results = identify_quality_issues(identifier, dataset_path, dataset_name, profile)
        all_results.extend(results)
        
        flagged = [r for r in results if r['recommended_action'] == 'REMOVE']
        flagged_count = len(flagged)
        total_count = len(results)
        removal_rate = flagged_count / total_count * 100 if total_count > 0 else 0
        
        print(f"  Results for {dataset_name}:")
        print(f"     Total analyzed: {total_count:,}")
        print(f"     Flagged for removal: {flagged_count:,} ({removal_rate:.1f}%)")
        
        # Show average component scores
        if results:
            avg_basic = np.mean([r['basic_quality_score'] for r in results])
            avg_medical = np.mean([r['medical_quality_score'] for r in results])
            avg_technical = np.mean([r['technical_quality_score'] for r in results])
            avg_overall = np.mean([r['overall_quality_score'] for r in results])
            
            print(f"     Average quality scores:")
            print(f"       Basic (25%): {avg_basic:.3f}")
            print(f"       Medical (55%): {avg_medical:.3f}")
            print(f"       Technical (20%): {avg_technical:.3f}")
            print(f"       Overall: {avg_overall:.3f}")
        
        # Per-DR-class tally of analysed vs. flagged images for this dataset.
        dr_stats = defaultdict(lambda: {'total': 0, 'flagged': 0})
        for result in results:
            dr_class = result['dr_severity']
            dr_stats[dr_class]['total'] += 1
            if result['recommended_action'] == 'REMOVE':
                dr_stats[dr_class]['flagged'] += 1
        
        print(f"     DR class breakdown:")
        dr_names = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']
        for dr_class in sorted(dr_stats.keys()):
            stats = dr_stats[dr_class]
            class_removal_rate = stats['flagged'] / stats['total'] * 100 if stats['total'] > 0 else 0
            # Labels outside 0-4 fall back to a generic "Class N" name.
            dr_name = dr_names[dr_class] if 0 <= dr_class < 5 else f'Class {dr_class}'
            print(f"       {dr_name}: {stats['flagged']:,}/{stats['total']:,} ({class_removal_rate:.1f}%)")
    
    print(f"\n{'='*80}")
    print(f"QUALITY IDENTIFICATION COMPLETED")
    print(f"{'='*80}")
    print(f"Total images analyzed: {len(all_results):,}")
    
    # Overall statistics with component breakdown
    if all_results:
        total_flagged = len([r for r in all_results if r['recommended_action'] == 'REMOVE'])
        overall_removal_rate = total_flagged / len(all_results) * 100
        
        # Image-weighted means across all datasets (each image counts once).
        avg_basic_all = np.mean([r['basic_quality_score'] for r in all_results])
        avg_medical_all = np.mean([r['medical_quality_score'] for r in all_results])
        avg_technical_all = np.mean([r['technical_quality_score'] for r in all_results])
        avg_overall_all = np.mean([r['overall_quality_score'] for r in all_results])
        
        print(f"Overall removal rate: {overall_removal_rate:.1f}%")
        print(f"Average quality scores across all datasets:")
        print(f"  Basic Quality (25%): {avg_basic_all:.3f}")
        print(f"  Medical Quality (55%): {avg_medical_all:.3f}")
        print(f"  Technical Quality (20%): {avg_technical_all:.3f}")
        print(f"  Overall Quality: {avg_overall_all:.3f}")
        
        print(f"\n3-component scoring system applied successfully!")
        print(f"All 14 quality metrics now properly utilized in assessment.")
        
else:
    # Reached when either mapping is empty; `all_results` is not created in this case.
    print("Quality identification skipped - no valid datasets or profiles available")

In [None]:
# 8. Create Flagged Samples

# Create visual samples of flagged images for manual review
def create_flagged_samples(identifier, results, n_samples_per_dataset=20):
    """Copy a small sample of flagged images per dataset into a review folder.

    Images are grouped by removal reason, and up to five of the lowest-scoring
    images per reason are copied to
    ``<identifier.output_dir>/flagged_samples/<dataset>/``, capped at
    ``n_samples_per_dataset`` copies per dataset. Images flagged purely because
    their overall score fell below the threshold carry no specific reason; they
    are grouped under ``below_threshold`` so they are reviewed too.

    Args:
        identifier: Object exposing ``output_dir``.
        results: Iterable of per-image result dicts with the keys ``dataset``,
            ``recommended_action``, ``removal_reasons``, ``overall_quality_score``,
            ``image_path`` and ``filename``.
        n_samples_per_dataset: Maximum number of files copied per dataset.

    Returns:
        int: Total number of files copied across all datasets.
    """
    print("Creating flagged image samples for review")

    # Group by dataset (insertion order is preserved for the progress output).
    by_dataset = defaultdict(list)
    for result in results:
        by_dataset[result['dataset']].append(result)

    total_samples_created = 0

    for dataset_name, dataset_results in by_dataset.items():
        print(f"\n  Processing {dataset_name}")

        flagged = [r for r in dataset_results if r['recommended_action'] == 'REMOVE']
        if not flagged:
            print(f"    No flagged images found")
            continue

        print(f"    Found {len(flagged)} flagged images")

        sample_dir = f'{identifier.output_dir}/flagged_samples/{dataset_name}'
        os.makedirs(sample_dir, exist_ok=True)

        # Group by removal reason; an image with several reasons appears in each group.
        by_reason = defaultdict(list)
        for result in flagged:
            # Score-only removals have an empty reason list and would otherwise
            # never be sampled.
            for reason in (result['removal_reasons'] or ['below_threshold']):
                by_reason[reason].append(result)

        print(f"    Issue types found: {list(by_reason.keys())}")

        samples_copied = 0
        for reason, reason_results in by_reason.items():
            reason_samples = min(5, len(reason_results), n_samples_per_dataset - samples_copied)
            if reason_samples <= 0:
                continue

            # Lowest overall score first, i.e. the clearest removal candidates.
            worst_first = sorted(reason_results, key=lambda r: r['overall_quality_score'])

            for i, result in enumerate(worst_first[:reason_samples]):
                # Looked up outside the try so the error message can always name the file.
                src_path = result.get('image_path')
                try:
                    dst_filename = f'{reason}_{i:02d}_{result["filename"]}'
                    dst_path = os.path.join(sample_dir, dst_filename)
                    shutil.copy2(src_path, dst_path)
                    samples_copied += 1
                except Exception as e:
                    # Best effort: one unreadable file must not abort the sampling.
                    print(f"      Error copying flagged sample {src_path}: {e}")

        print(f"    Created {samples_copied} flagged samples")
        total_samples_created += samples_copied

    print(f"\nTotal flagged samples created: {total_samples_created}")
    return total_samples_created

# Copy review samples only when the scoring cell actually produced results
if 'all_results' not in locals() or not all_results:
    print("Flagged sample creation skipped - no results available")
else:
    create_flagged_samples(identifier, all_results)

In [None]:
# 9. Generate Comprehensive Report

# Generate comprehensive identification report with statistics and recommendations
def generate_identification_report(all_results):
    """Summarise per-image quality results into a JSON-friendly report dict.

    Args:
        all_results: List of per-image result dicts with at least the keys
            ``dataset``, ``dr_severity``, ``recommended_action``,
            ``removal_reasons`` and ``overall_quality_score``.

    Returns:
        dict with the keys ``analysis_summary``, ``dataset_breakdown``,
        ``removal_reasons_summary`` (ordered from most to least frequent),
        ``quality_score_statistics`` and ``recommendations``. For empty input
        the summary counts are zero and the remaining sections are empty.
    """
    print("Generating identification report")

    df = pd.DataFrame(all_results)
    total_images = len(df)

    report = {
        'analysis_summary': {
            'total_images_analyzed': total_images,
            'images_flagged_for_removal': 0,
            'overall_removal_rate': 0,
            'analysis_date': datetime.now().isoformat()
        },
        'dataset_breakdown': {},
        'removal_reasons_summary': {},
        'quality_score_statistics': {},
        'recommendations': []
    }

    if total_images == 0:
        # An empty frame has no columns, so every lookup below would raise KeyError.
        return report

    is_flagged = df['recommended_action'] == 'REMOVE'
    flagged_for_removal = int(is_flagged.sum())
    removal_rate = flagged_for_removal / total_images
    report['analysis_summary']['images_flagged_for_removal'] = flagged_for_removal
    report['analysis_summary']['overall_removal_rate'] = removal_rate

    # Per-dataset breakdown
    for dataset in df['dataset'].unique():
        dataset_data = df[df['dataset'] == dataset]
        dataset_total = len(dataset_data)
        dataset_flagged = int((dataset_data['recommended_action'] == 'REMOVE').sum())

        # Per-class breakdown; only DR grades 0-4 that actually occur are reported.
        class_breakdown = {}
        for dr_class in range(5):
            class_data = dataset_data[dataset_data['dr_severity'] == dr_class]
            class_total = len(class_data)
            if class_total > 0:
                class_flagged = int((class_data['recommended_action'] == 'REMOVE').sum())
                class_breakdown[dr_class] = {
                    'total': class_total,
                    'flagged': class_flagged,
                    'removal_rate': class_flagged / class_total
                }

        dataset_scores = dataset_data['overall_quality_score']
        report['dataset_breakdown'][dataset] = {
            'total_images': dataset_total,
            'flagged_images': dataset_flagged,
            'removal_rate': dataset_flagged / dataset_total if dataset_total > 0 else 0,
            'class_breakdown': class_breakdown,
            'avg_quality_score': float(dataset_scores.mean()),
            # pandas uses the sample standard deviation, so this is NaN for a single image.
            'quality_score_std': float(dataset_scores.std())
        }

    # Removal reasons of flagged images, most frequent first so that consumers
    # taking the first N entries really get the top N.
    reason_counts = Counter()
    for reasons in df.loc[is_flagged, 'removal_reasons']:
        reason_counts.update(reasons)
    report['removal_reasons_summary'] = dict(reason_counts.most_common())

    # Quality score statistics
    scores = df['overall_quality_score']
    report['quality_score_statistics'] = {
        'mean': float(scores.mean()),
        'std': float(scores.std()),
        'min': float(scores.min()),
        'max': float(scores.max()),
        'percentiles': {
            '10': float(scores.quantile(0.1)),
            '25': float(scores.quantile(0.25)),
            '50': float(scores.quantile(0.5)),
            '75': float(scores.quantile(0.75)),
            '90': float(scores.quantile(0.9))
        }
    }

    # Recommendations derived from the overall and per-dataset removal rates.
    if removal_rate > 0.4:
        report['recommendations'].append("HIGH removal rate detected. Consider relaxing quality thresholds.")

    if removal_rate < 0.05:
        report['recommendations'].append("LOW removal rate detected. Consider tightening quality thresholds.")

    for dataset, stats in report['dataset_breakdown'].items():
        if stats['removal_rate'] > 0.5:
            report['recommendations'].append(f"Very high removal rate for {dataset}. Review dataset-specific thresholds.")

        # A spread above 30 percentage points between classes is reported as uneven.
        class_rates = [info['removal_rate'] for info in stats['class_breakdown'].values()]
        if class_rates and max(class_rates) - min(class_rates) > 0.3:
            report['recommendations'].append(f"Uneven removal rates across DR classes in {dataset}. Consider class-specific adjustments.")

    return report

# Build the report only when the scoring cell produced results
if 'all_results' not in locals() or not all_results:
    print("Report generation skipped - no results available")
else:
    report = generate_identification_report(all_results)
    print("Report generated successfully")

In [None]:
# 10. Save All Results

# Save all identification results to files
def save_identification_results(identifier, results, report):
    """Write the detailed results, dataset profiles, report and flagged summary to disk.

    All files are written into ``identifier.output_dir``, which is created if
    it does not exist yet.

    Args:
        identifier: Object exposing ``output_dir`` and ``dataset_profiles``.
        results: List of per-image result dicts.
        report: Report dict; values that are not JSON-serialisable are written
            via ``str``.

    Returns:
        tuple: ``(results_file, summary_file, profiles_file, report_file)`` paths.
    """
    print("Saving identification results")

    os.makedirs(identifier.output_dir, exist_ok=True)

    # Save detailed results
    df = pd.DataFrame(results)
    results_file = f'{identifier.output_dir}/quality_identification_results.csv'
    df.to_csv(results_file, index=False)
    print(f"  Detailed results saved: {results_file}")

    # Save dataset profiles
    profiles_file = f'{identifier.output_dir}/dataset_profiles.json'
    with open(profiles_file, 'w') as f:
        json.dump(identifier.dataset_profiles, f, indent=2)
    print(f"  Dataset profiles saved: {profiles_file}")

    # Save identification report
    report_file = f'{identifier.output_dir}/identification_report.json'
    with open(report_file, 'w') as f:
        json.dump(report, f, indent=2, default=str)
    print(f"  Analysis report saved: {report_file}")

    # Summary CSV for easy review. The column list is fixed so the header is
    # written even when no image was flagged; an empty frame without columns
    # would produce a CSV that cannot be read back.
    summary_columns = ['dataset', 'filename', 'dr_severity', 'quality_score',
                       'confidence', 'reasons', 'image_path']
    summary_data = [
        {
            'dataset': row['dataset'],
            'filename': row['filename'],
            'dr_severity': row['dr_severity'],
            'quality_score': row['overall_quality_score'],
            'confidence': row['confidence'],
            'reasons': ', '.join(row['removal_reasons']),
            'image_path': row['image_path']
        }
        for row in df.to_dict('records')
        if row['recommended_action'] == 'REMOVE'
    ]

    summary_df = pd.DataFrame(summary_data, columns=summary_columns)
    summary_file = f'{identifier.output_dir}/flagged_images_summary.csv'
    summary_df.to_csv(summary_file, index=False)
    print(f"  Flagged summary saved: {summary_file}")

    return results_file, summary_file, profiles_file, report_file

# Persist everything once both the results and the report exist
if 'all_results' not in locals() or 'report' not in locals() or not all_results:
    print("Results saving skipped - no data available")
else:
    results_file, summary_file, profiles_file, report_file = save_identification_results(identifier, all_results, report)
    print("\nAll results saved successfully")

In [None]:
# 11. Final Summary and Visualization

# Display comprehensive final summary
# Console-only recap: recomputes the counts directly from `all_results` and reads
# the reasons and recommendations from `report`. Nothing is written to disk here.
if 'all_results' in locals() and 'report' in locals() and all_results:
    print("="*80)
    print("QUALITY IDENTIFICATION COMPLETE")
    print("="*80)
   
    # Overall statistics
    total_images = len(all_results)
    flagged_images = len([r for r in all_results if r['recommended_action'] == 'REMOVE'])
    removal_percentage = flagged_images / total_images * 100 if total_images > 0 else 0
   
    print(f"\nOVERALL STATISTICS:")
    print(f"   Total images analyzed: {total_images:,}")
    print(f"   Images flagged for removal: {flagged_images:,}")
    print(f"   Overall removal rate: {removal_percentage:.1f}%")
   
    # Per-dataset breakdown
    print(f"\nPER-DATASET BREAKDOWN:")
    dataset_stats = defaultdict(lambda: {'total': 0, 'flagged': 0})
   
    for result in all_results:
        dataset = result['dataset']
        dataset_stats[dataset]['total'] += 1
        if result['recommended_action'] == 'REMOVE':
            dataset_stats[dataset]['flagged'] += 1
   
    for dataset, stats in dataset_stats.items():
        removal_rate = stats['flagged'] / stats['total'] * 100 if stats['total'] > 0 else 0
        print(f"   {dataset}: {stats['flagged']:,}/{stats['total']:,} ({removal_rate:.1f}%)")
   
    # Per-DR class breakdown
    # Pooled over all datasets, unlike the per-dataset class tables printed earlier.
    print(f"\nPER-DR CLASS BREAKDOWN:")
    dr_stats = defaultdict(lambda: {'total': 0, 'flagged': 0})
   
    for result in all_results:
        dr_class = result['dr_severity']
        dr_stats[dr_class]['total'] += 1
        if result['recommended_action'] == 'REMOVE':
            dr_stats[dr_class]['flagged'] += 1
   
    dr_names = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']
    for dr_class in sorted(dr_stats.keys()):
        stats = dr_stats[dr_class]
        removal_rate = stats['flagged'] / stats['total'] * 100 if stats['total'] > 0 else 0
        # Labels outside 0-4 fall back to a generic "Class N" name.
        dr_name = dr_names[dr_class] if 0 <= dr_class < 5 else f'Class {dr_class}'
        print(f"   {dr_name}: {stats['flagged']:,}/{stats['total']:,} ({removal_rate:.1f}%)")
   
    # Top removal reasons
    print(f"\nTOP REMOVAL REASONS:")
    if 'removal_reasons_summary' in report:
        sorted_reasons = sorted(report['removal_reasons_summary'].items(), key=lambda x: x[1], reverse=True)
        for reason, count in sorted_reasons[:10]:  # Top 10 reasons
            # One image can carry several reasons, so these percentages may sum to more than 100.
            percentage = count / flagged_images * 100 if flagged_images > 0 else 0
            print(f"   {reason}: {count:,} ({percentage:.1f}% of flagged images)")
   
    # Recommendations
    if 'recommendations' in report and report['recommendations']:
        print(f"\nRECOMMENDATIONS:")
        for i, recommendation in enumerate(report['recommendations'], 1):
            print(f"   {i}. {recommendation}")
   
    # File locations
    # Paths are built from OUTPUT_DIR and only printed; their existence is not checked here.
    print(f"\nRESULTS SAVED TO:")
    print(f"   Flagged images summary: {OUTPUT_DIR}/flagged_images_summary.csv")
    print(f"   Detailed results: {OUTPUT_DIR}/quality_identification_results.csv")
    print(f"   Analysis report: {OUTPUT_DIR}/identification_report.json")
    print(f"   Sample images: {OUTPUT_DIR}/sample_images/")
    print(f"   Flagged samples: {OUTPUT_DIR}/flagged_samples/")
   
    print(f"\nNEXT STEPS:")
    print(f"   1. Review flagged samples in: {OUTPUT_DIR}/flagged_samples/")
    print(f"   2. Check the summary CSV for a list of all flagged images")
    print(f"   3. Adjust thresholds if needed (modify dataset profiles)")
    print(f"   4. Use the detailed results for actual image removal when ready")
   
    print(f"\nQuality identification process completed successfully")
else:
    print("No results to summarize - please run all previous cells first")

In [None]:
# 12. Quick Visualization

# Create basic visualizations for quality analysis results
# Draws a 2x2 summary figure (removal rate per dataset, removal rate per DR class,
# score distribution for KEEP vs. REMOVE, most frequent removal reasons) and saves
# it as a PNG under OUTPUT_DIR.
if 'all_results' in locals() and all_results:
    print("Creating basic visualizations")
    
    # Set up matplotlib
    plt.style.use('default')
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('DR Dataset Quality Analysis Results', fontsize=16, fontweight='bold')
    
    df = pd.DataFrame(all_results)
    
    # 1. Removal rates by dataset
    ax1 = axes[0, 0]
    dataset_removal_rates = []
    dataset_names = []
    
    for dataset in df['dataset'].unique():
        dataset_data = df[df['dataset'] == dataset]
        removal_rate = len(dataset_data[dataset_data['recommended_action'] == 'REMOVE']) / len(dataset_data) * 100
        dataset_removal_rates.append(removal_rate)
        dataset_names.append(dataset)
    
    bars1 = ax1.bar(range(len(dataset_names)), dataset_removal_rates, color='lightcoral', alpha=0.7)
    ax1.set_xlabel('Dataset')
    ax1.set_ylabel('Removal Rate (%)')
    ax1.set_title('Removal Rates by Dataset')
    ax1.set_xticks(range(len(dataset_names)))
    ax1.set_xticklabels(dataset_names, rotation=45, ha='right')
    
    # Add value labels on bars
    for bar, rate in zip(bars1, dataset_removal_rates):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height + 0.5,
                f'{rate:.1f}%', ha='center', va='bottom')
    
    # 2. Removal rates by DR class
    ax2 = axes[0, 1]
    dr_removal_rates = []
    dr_labels = []
    dr_names = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']
    
    for dr_class in sorted(df['dr_severity'].unique()):
        dr_data = df[df['dr_severity'] == dr_class]
        removal_rate = len(dr_data[dr_data['recommended_action'] == 'REMOVE']) / len(dr_data) * 100
        dr_removal_rates.append(removal_rate)
        # Labels outside 0-4 fall back to a generic "Class N" name.
        dr_label = dr_names[dr_class] if 0 <= dr_class < 5 else f'Class {dr_class}'
        dr_labels.append(dr_label)
    
    bars2 = ax2.bar(range(len(dr_labels)), dr_removal_rates, color='lightblue', alpha=0.7)
    ax2.set_xlabel('DR Severity')
    ax2.set_ylabel('Removal Rate (%)')
    ax2.set_title('Removal Rates by DR Severity')
    ax2.set_xticks(range(len(dr_labels)))
    ax2.set_xticklabels(dr_labels, rotation=45, ha='right')
    
    # Add value labels on bars
    for bar, rate in zip(bars2, dr_removal_rates):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height + 0.5,
                f'{rate:.1f}%', ha='center', va='bottom')
    
    # 3. Quality score distribution
    ax3 = axes[1, 0]
    keep_scores = df[df['recommended_action'] == 'KEEP']['overall_quality_score']
    remove_scores = df[df['recommended_action'] == 'REMOVE']['overall_quality_score']
    
    # density=True normalises each histogram separately, so the two shapes are
    # comparable even when the KEEP and REMOVE groups differ greatly in size.
    ax3.hist(keep_scores, bins=30, alpha=0.7, label='Keep', color='lightgreen', density=True)
    ax3.hist(remove_scores, bins=30, alpha=0.7, label='Remove', color='lightcoral', density=True)
    ax3.set_xlabel('Quality Score')
    ax3.set_ylabel('Density')
    ax3.set_title('Quality Score Distribution')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    
    # 4. Top removal reasons
    ax4 = axes[1, 1]
    # `report` comes from the report cell; guard it so this cell still renders the
    # other three panels when that cell was not run.
    if 'report' in locals() and 'removal_reasons_summary' in report:
        # Rank by count before truncating; the summary dict is not guaranteed to
        # be ordered by frequency, so slicing its keys would not yield the top 10.
        ranked_reasons = sorted(report['removal_reasons_summary'].items(),
                                key=lambda item: item[1], reverse=True)[:10]
        reasons = [reason for reason, _ in ranked_reasons]
        counts = [count for _, count in ranked_reasons]
        
        bars4 = ax4.barh(range(len(reasons)), counts, color='orange', alpha=0.7)
        ax4.set_ylabel('Removal Reason')
        ax4.set_xlabel('Count')
        ax4.set_title('Top Removal Reasons')
        ax4.set_yticks(range(len(reasons)))
        ax4.set_yticklabels([reason.replace('_', ' ').title() for reason in reasons])
        # barh draws index 0 at the bottom; flip so the most frequent reason is on top.
        ax4.invert_yaxis()
        
        # Add value labels on bars
        for bar, count in zip(bars4, counts):
            width = bar.get_width()
            ax4.text(width + max(counts)*0.01, bar.get_y() + bar.get_height()/2.,
                    f'{count}', ha='left', va='center')
    
    plt.tight_layout()
    
    # Save the plot
    plot_file = f'{OUTPUT_DIR}/analysis_summary_plots.png'
    plt.savefig(plot_file, dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"Visualization saved to: {plot_file}")

else:
    print("Visualization creation skipped - no results available")

In [None]:
# 13. Final Checkpoint - Verification

# Final verification that all components completed successfully
# Checks kernel variables and output paths only; it does not re-validate file contents.
print("FINAL VERIFICATION:")
print("=" * 50)

# Check if all major components completed
# Each entry maps a label to a boolean. The first four look at kernel variables,
# the last three at paths below OUTPUT_DIR.
checks = {
    'Valid datasets found': 'valid_datasets' in locals() and bool(valid_datasets),
    'Dataset profiles created': 'dataset_profiles' in locals() and bool(dataset_profiles),
    'Quality analysis completed': 'all_results' in locals() and bool(all_results),
    'Report generated': 'report' in locals() and bool(report),
    'Results saved': os.path.exists(f'{OUTPUT_DIR}/flagged_images_summary.csv'),
    'Sample images created': os.path.exists(f'{OUTPUT_DIR}/sample_images'),
    'Flagged samples created': os.path.exists(f'{OUTPUT_DIR}/flagged_samples')
}

all_good = True
for check_name, check_result in checks.items():
    status = "PASS" if check_result else "FAIL"
    print(f"{status}: {check_name}")
    if not check_result:
        all_good = False

if all_good:
    print(f"\nSUCCESS: All components completed successfully")
    print(f"Check the '{OUTPUT_DIR}' directory for all results")
   
    if 'all_results' in locals():
        # all_good implies all_results is non-empty, so the division below is safe.
        # Note: `flagged` is rebound here to an int count.
        total = len(all_results)
        flagged = len([r for r in all_results if r['recommended_action'] == 'REMOVE'])
        print(f"Final count: {flagged:,} of {total:,} images flagged for removal ({flagged/total*100:.1f}%)")
else:
    print(f"\nWARNING: Some components did not complete successfully")
    print(f"Please review the error messages above and re-run the failed sections")

# Printed unconditionally: the cells visible here only copy files into review folders.
print(f"\nIMPORTANT: This analysis only IDENTIFIED potential issues")
print(f"NO IMAGES WERE ACTUALLY REMOVED from your datasets")
print(f"Review the flagged samples before deciding on actual removal")

In [None]:
# 14. Check what sample images were saved
# Lists, per dataset sub-folder, how many sample images exist under
# <identifier.output_dir>/sample_images.
import os
sample_base_dir = f'{identifier.output_dir}/sample_images'
if os.path.isdir(sample_base_dir):
    print("Sample images saved:")
    # Sorted so the listing is stable between runs (os.listdir order is arbitrary).
    for dataset_name in sorted(os.listdir(sample_base_dir)):
        dataset_sample_dir = os.path.join(sample_base_dir, dataset_name)
        if os.path.isdir(dataset_sample_dir):
            # Compare on the lower-cased name so files such as IMG.JPG are counted too.
            num_samples = len([f for f in os.listdir(dataset_sample_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
            print(f"   {dataset_name}: {num_samples} sample images in {dataset_sample_dir}")
else:
    # Without this guard a missing folder aborts the cell with FileNotFoundError.
    print(f"No sample images directory found at: {sample_base_dir}")

In [None]:
# 15. Problematic Images Review

# Find and organize problematic images for manual review
def _metric_or_default(char, name, fallback):
    """Return ``char[name]`` as a float, or ``fallback`` when the stored value is None."""
    raw = char[name]
    return float(raw) if raw is not None else fallback


def find_and_review_problematic_images(identifier, dataset_profiles,
                                     review_dir='problematic_images_review_all'):
    """Copy the lowest-scoring images of every dataset into a manual-review folder.

    For each dataset a two-component quality score (0.3 * basic + 0.7 * medical)
    is computed per image and compared with the dataset's adaptive threshold for
    the image's DR severity. Per severity, the 50 worst images below threshold
    are copied to ``<review_dir>/<dataset>/DR_severity_<n>/`` together with a
    ``.txt`` sidecar listing the score, threshold, original path and all
    measured characteristics.

    Relies on the module-level ``valid_datasets`` mapping (dataset name -> path)
    and on ``get_dr_severity_from_path`` being defined in the notebook.

    Args:
        identifier: Object providing ``analyze_single_image(path)``.
        dataset_profiles: Mapping of dataset name to profile; each profile must
            contain ``'adaptive_thresholds'`` keyed by DR severity.
        review_dir: Directory that receives the review copies.
    """
    # Create review directory
    os.makedirs(review_dir, exist_ok=True)

    print("Finding problematic images based on quality thresholds")

    for dataset_name, profile in dataset_profiles.items():
        print(f"\nProcessing {dataset_name}")

        # Get adaptive thresholds for this dataset
        thresholds = profile['adaptive_thresholds']

        # Create subdirectories for this dataset
        dataset_review_dir = os.path.join(review_dir, dataset_name)
        os.makedirs(dataset_review_dir, exist_ok=True)

        # Collect all images from ALL subdirectories (all DR severity levels)
        dataset_path = valid_datasets[dataset_name]
        all_images = []
        for root, dirs, files in os.walk(dataset_path):
            for file in files:
                if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    all_images.append(os.path.join(root, file))

        print(f"   Analyzing {len(all_images)} images")

        problematic_by_severity = {0: [], 1: [], 2: [], 3: [], 4: []}

        # Analyze each image
        for i, img_path in enumerate(all_images):
            if i % 500 == 0 and i > 0:
                print(f"      Progress: {i}/{len(all_images)} ({i/len(all_images)*100:.1f}%)")

            try:
                # Get image characteristics
                char = identifier.analyze_single_image(img_path)
                if not char:
                    continue

                # Missing (None) metrics fall back to fixed defaults.
                brightness = _metric_or_default(char, 'brightness', 127.5)
                contrast = _metric_or_default(char, 'contrast', 50.0)
                sharpness = _metric_or_default(char, 'sharpness', 500.0)
                entropy = _metric_or_default(char, 'entropy', 4.0)
                illumination_uniformity = _metric_or_default(char, 'illumination_uniformity', 0.5)
                vessel_visibility = _metric_or_default(char, 'vessel_visibility', 0.1)
                optic_disc_visibility = _metric_or_default(char, 'optic_disc_visibility', 0.1)

                # Normalize the basic metrics into [0, 1]
                brightness_norm = min(1.0, max(0.0, brightness / 255.0))
                contrast_norm = min(1.0, max(0.0, contrast / 100.0))
                sharpness_norm = min(1.0, max(0.0, sharpness / 1000.0))
                entropy_norm = min(1.0, max(0.0, entropy / 8.0))

                basic_quality = np.mean([brightness_norm, contrast_norm, sharpness_norm, entropy_norm])
                medical_quality = np.mean([
                    illumination_uniformity,
                    min(1.0, vessel_visibility * 10),
                    min(1.0, optic_disc_visibility * 10)
                ])

                # NOTE(review): this is a 0.3/0.7 two-component score, not the
                # 0.25/0.55/0.20 formula announced by the main scoring cell -
                # confirm which one the adaptive thresholds were derived from.
                combined_quality = 0.3 * basic_quality + 0.7 * medical_quality

                # Determine DR severity from path or filename
                dr_severity = get_dr_severity_from_path(img_path)

                # Check if image is below threshold for its DR severity
                if dr_severity in thresholds and combined_quality < thresholds[dr_severity]:
                    problematic_by_severity[dr_severity].append({
                        'path': img_path,
                        'quality_score': combined_quality,
                        'threshold': thresholds[dr_severity],
                        'characteristics': char
                    })

            except Exception as e:
                # Best effort: report the failing image and keep going.
                print(f"      Error analyzing {img_path}: {e}")
                continue

        # Save problematic images by severity
        total_problematic = 0
        for severity, images in problematic_by_severity.items():
            if images:
                severity_dir = os.path.join(dataset_review_dir, f'DR_severity_{severity}')
                os.makedirs(severity_dir, exist_ok=True)

                # Sort by quality score (worst first)
                images.sort(key=lambda x: x['quality_score'])

                print(f"   DR Severity {severity}: {len(images)} problematic images")

                # Copy worst 50 images for manual review
                for i, img_info in enumerate(images[:50]):
                    try:
                        src_path = img_info['path']
                        dst_name = f"quality_{img_info['quality_score']:.3f}_{os.path.basename(src_path)}"
                        dst_path = os.path.join(severity_dir, dst_name)
                        shutil.copy2(src_path, dst_path)

                        # Swap only the real extension for the sidecar. A
                        # case-sensitive str.replace leaves e.g. "IMG.JPG"
                        # unchanged, so the text file would overwrite the
                        # image that was just copied.
                        txt_path = os.path.splitext(dst_path)[0] + '.txt'
                        with open(txt_path, 'w') as f:
                            f.write(f"Quality Score: {img_info['quality_score']:.3f}\n")
                            f.write(f"Threshold: {img_info['threshold']:.3f}\n")
                            f.write(f"Original Path: {src_path}\n\n")
                            f.write("Characteristics:\n")
                            for key, value in img_info['characteristics'].items():
                                f.write(f"  {key}: {value}\n")

                    except Exception as e:
                        print(f"      Error copying {src_path}: {e}")

                # Counts every image below threshold, not just the 50 that were copied.
                total_problematic += len(images)

        print(f"   Found {total_problematic} total problematic images in {dataset_name}")
        print(f"   Review images saved to: {dataset_review_dir}")

    print(f"\nReview complete. Check the '{review_dir}' directory")

# Extract DR severity from image path or filename
def get_dr_severity_from_path(img_path):
    """Infer the DR severity grade (0-4) from an image path.

    Heuristics are tried in this order and the first hit wins:
      1. a single-digit directory component such as ``/2/``;
      2. per-grade substrings (the grade word, ``grade_N`` or ``_N_``),
         tested for grade 0 first and grade 4 last;
      3. ``grade`` or ``severity`` followed by an optional ``_``/``-`` and a digit;
      4. a digit with ``_``/``-`` on its left and ``_``/``-``/``.`` on its right.

    Args:
        img_path: Path (or bare filename) of the image.

    Returns:
        int: The detected grade, or 0 (after printing a notice) when no
        heuristic matches.
    """
    import re

    lowered = img_path.lower()

    # 1. Class folder somewhere in the path (either separator style)
    by_folder = re.search(r'[/\\]([0-4])[/\\]', img_path)
    if by_folder is not None:
        return int(by_folder.group(1))

    # 2. Keyword table: the position of each word is its grade
    grade_words = ('no_dr', 'mild', 'moderate', 'severe', 'proliferative')
    for grade, word in enumerate(grade_words):
        markers = (word, f'grade_{grade}', f'_{grade}_')
        if any(marker in lowered for marker in markers):
            return grade

    # 3./4. Regex fallbacks, from most to least specific
    fallback_patterns = (
        r'grade[_\-]?([0-4])',
        r'severity[_\-]?([0-4])',
        r'[_\-]([0-4])[_\-\.]',
    )
    for pattern in fallback_patterns:
        hit = re.search(pattern, lowered)
        if hit is not None:
            return int(hit.group(1))

    # Nothing matched: report it and fall back to grade 0
    print(f"   Could not determine DR severity for: {img_path}")
    return 0

# Create an HTML file for easy image review
def create_review_html(review_dir='problematic_images_review_all'):
    """Write ``review.html`` inside ``review_dir`` as a browsable image gallery.

    The directory is expected to be laid out as
    ``<review_dir>/<dataset>/<severity folder>/<image files>``. For every
    severity folder the first 20 images (sorted by file name) are shown. The
    quality score displayed under each image is the second ``_``-separated
    token of the file name, or ``unknown`` when the name contains no ``_``.

    File and folder names are HTML-escaped, image links are URL-quoted, and the
    page is written as UTF-8, so names with spaces, ``&``, quotes or non-ASCII
    characters render correctly.

    Args:
        review_dir: Root directory holding the per-dataset review folders.

    Returns:
        None. Prints a notice and writes nothing when ``review_dir`` does not
        exist.
    """
    # Local imports keep this cell self-contained (same style as get_dr_severity_from_path)
    from html import escape
    from urllib.parse import quote

    if not os.path.isdir(review_dir):
        print(f"Review directory not found: {review_dir}")
        return

    image_extensions = ('.jpg', '.jpeg', '.png')
    max_images_per_folder = 20

    # Collect fragments and join once at the end instead of growing one string
    fragments = ["""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <title>Problematic Images Review</title>
        <style>
            body { font-family: Arial, sans-serif; margin: 20px; }
            .dataset { margin-bottom: 30px; border: 1px solid #ccc; padding: 15px; }
            .severity { margin-bottom: 20px; }
            .image-grid { display: grid; grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)); gap: 10px; }
            .image-item { border: 1px solid #ddd; padding: 10px; text-align: center; }
            .image-item img { max-width: 100%; height: 150px; object-fit: cover; }
            .quality-score { font-weight: bold; color: red; }
        </style>
    </head>
    <body>
        <h1>Problematic Images Review</h1>
    """]

    # Sorted so the page layout is the same on every run and every platform
    for dataset_name in sorted(os.listdir(review_dir)):
        dataset_path = os.path.join(review_dir, dataset_name)
        if not os.path.isdir(dataset_path):
            continue  # skips review.html itself and any other stray file

        fragments.append(f'<div class="dataset"><h2>{escape(dataset_name)}</h2>')

        for severity_dir in sorted(os.listdir(dataset_path)):
            severity_path = os.path.join(dataset_path, severity_dir)
            if not os.path.isdir(severity_path):
                continue

            fragments.append(
                f'<div class="severity"><h3>{escape(severity_dir)}</h3><div class="image-grid">'
            )

            # Case-insensitive match so files such as IMG_001.JPG are not dropped
            images = sorted(
                name for name in os.listdir(severity_path)
                if name.lower().endswith(image_extensions)
            )

            for img_file in images[:max_images_per_folder]:
                name_tokens = img_file.split('_')
                quality_score = name_tokens[1] if len(name_tokens) > 1 else 'unknown'

                # Link relative to review.html, always with forward slashes
                img_src = quote('/'.join((dataset_name, severity_dir, img_file)))

                fragments.append(f'''
                <div class="image-item">
                    <img src="{escape(img_src)}" alt="{escape(img_file)}">
                    <div class="quality-score">Quality: {escape(quality_score)}</div>
                    <div>{escape(img_file)}</div>
                </div>
                ''')

            fragments.append('</div></div>')

        fragments.append('</div>')

    fragments.append('</body></html>')

    html_path = os.path.join(review_dir, 'review.html')
    # Explicit UTF-8 matches the declared charset regardless of the OS default
    with open(html_path, 'w', encoding='utf-8') as f:
        f.write(''.join(fragments))

    print(f"HTML review file created: {html_path}")
    print("   Open this file in your web browser to easily review problematic images")

# Scan the datasets and populate the review directory
print("Starting problematic image identification")
find_and_review_problematic_images(identifier, dataset_profiles)

# Build the browsable gallery on top of the copied files
create_review_html()

# Closing instructions for the reviewer, emitted in a single write
print("\n".join([
    "\nReview setup complete",
    "\nTo review the images:",
    "1. Check the 'problematic_images_review_all' directory",
    "2. Open 'problematic_images_review_all/review.html' in your web browser",
    "3. Each image has a quality score and characteristics file (.txt)",
]))

In [None]:
# 16. Manual Threshold Adjustment (Optional)

dr_names = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']


def format_dr_class(dr_class):
    """Return the readable name for DR class 0-4, or ``'Class <key>'`` for any other key."""
    try:
        return dr_names[dr_class] if 0 <= dr_class < 5 else f'Class {dr_class}'
    except TypeError:
        # Non-numeric keys (e.g. strings) can be neither range-checked nor used as an index
        return f'Class {dr_class}'


def print_thresholds(profiles):
    """Print every dataset's per-class threshold, one dataset per paragraph."""
    for dataset_name, profile in profiles.items():
        print(f"\n{dataset_name}:")
        for dr_class, threshold in profile['adaptive_thresholds'].items():
            print(f"  {format_dr_class(dr_class)} (Class {dr_class}): {threshold:.3f}")


def apply_manual_thresholds(profiles, overrides):
    """Return copies of ``profiles`` with the non-``None`` ``overrides`` applied.

    Args:
        profiles: ``{dataset_name: profile}`` where each profile holds an
            ``'adaptive_thresholds'`` dict keyed by DR class.
        overrides: ``{dataset_name: {dr_class: threshold_or_None}}``. ``None``
            keeps the current value; datasets absent from ``profiles`` are ignored.

    Returns:
        tuple: ``(updated_profiles, changes)``. ``changes`` is a list of
        ``(dataset_name, dr_class, old_threshold, new_threshold)`` in the order
        applied; ``old_threshold`` is ``None`` when the profile had no entry for
        that class. The input ``profiles`` and their threshold dicts are not
        modified.
    """
    updated = {}
    changes = []

    for dataset_name, profile in profiles.items():
        # Shallow copy of the profile plus a private copy of the dict we edit
        new_thresholds = profile['adaptive_thresholds'].copy()
        new_profile = profile.copy()
        new_profile['adaptive_thresholds'] = new_thresholds
        updated[dataset_name] = new_profile

        for dr_class, manual_threshold in overrides.get(dataset_name, {}).items():
            if manual_threshold is None:
                continue
            # .get(): an override may introduce a class the profile never had
            changes.append((dataset_name, dr_class, new_thresholds.get(dr_class), manual_threshold))
            new_thresholds[dr_class] = manual_threshold

    return updated, changes


print("MANUAL THRESHOLD ADJUSTMENT")
print("=" * 50)
print("Use this cell to fine-tune removal thresholds for each dataset")
# An image is kept when its quality score is >= the threshold, so raising a
# threshold removes more images and lowering it removes fewer.
print("Higher thresholds = more images flagged for removal")
print("Lower thresholds = fewer images flagged for removal")
print("Current thresholds are based on your dataset analysis")

# Check if dataset_profiles exists (it is created by an earlier cell)
if 'dataset_profiles' not in globals() or not dataset_profiles:
    print("\nERROR: Dataset profiles not found")
    print("Please run the previous cells first to create dataset profiles.")
    print("Required steps:")
    print("1. Run dataset validation (Cell 3)")
    print("2. Run dataset characterization (Cell 6)")
    print("3. Then come back to this cell")
else:
    # Display current thresholds
    print("\nCURRENT THRESHOLDS:")
    print_thresholds(dataset_profiles)

    print("\n" + "=" * 50)
    print("MANUAL ADJUSTMENT SECTION")
    print("Edit the values below to adjust thresholds")
    print("Set to None to keep current value")
    print("=" * 50)

    # Manual threshold override dictionary
    # Users can edit these values
    manual_thresholds = {
        'APTOS2019': {
            0: None,  # No DR - set to None to keep current, or set value like 0.250
            1: None,  # Mild DR
            2: None,  # Moderate DR
            3: None,  # Severe DR
            4: None   # Proliferative DR
        },
        'Diabetic_Retinopathy_V03': {
            0: None,
            1: None,
            2: None,
            3: None,
            4: None
        },
        'IDRiD': {
            0: None,
            1: None,
            2: None,
            3: None,
            4: None
        },
        'Messidor2': {
            0: None,
            1: None,
            2: None,
            3: None,
            4: None
        },
        'SUSTech_SYSU': {
            0: None,
            1: None,
            2: None,
            3: None,
            4: None
        },
        'DeepDRiD': {
            0: None,
            1: None,
            2: None,
            3: None,
            4: None
        }
    }

    # Apply manual overrides
    updated_profiles, threshold_changes = apply_manual_thresholds(dataset_profiles, manual_thresholds)
    changes_made = bool(threshold_changes)

    for dataset_name, dr_class, old_threshold, manual_threshold in threshold_changes:
        old_text = 'not set' if old_threshold is None else f'{old_threshold:.3f}'
        print(f"Updated {dataset_name} - {format_dr_class(dr_class)}: {old_text} → {manual_threshold:.3f}")

    if changes_made:
        print("\nThreshold adjustments applied")
        print("Updated profiles will be used for final dataset creation.")
        # Update the global profiles
        dataset_profiles = updated_profiles
    else:
        print("\nNo manual adjustments made. Using original thresholds")

    print("\nFINAL THRESHOLDS AFTER ADJUSTMENT:")
    print_thresholds(dataset_profiles)

In [None]:
# 17. Dataset Cleaning with 3-Component Quality Scoring

def create_cleaned_dataset(source_datasets, profiles, copy_output_dir='DREAM_dataset_cleaned',
                           analyzer=None, default_threshold=0.3):
    """
    Create cleaned dataset using 3-component quality scoring system.

    For every dataset that has a profile, each image found in the class
    sub-folders ``0`` .. ``4`` is analysed, scored with
    ``calculate_composite_quality`` and copied to
    ``<copy_output_dir>/<dataset>/<class>/`` when its score is greater than or
    equal to the class threshold. Source files are only read, never modified.

    Note: the output directory is not emptied first, so images copied by an
    earlier run (for example with looser thresholds) stay in place. Use a fresh
    ``copy_output_dir`` when re-running with different thresholds.

    Args:
        source_datasets: Dictionary mapping dataset names to source paths
        profiles: Dictionary containing dataset quality profiles and thresholds
            (each profile must hold an ``'adaptive_thresholds'`` dict)
        copy_output_dir: Output directory for cleaned dataset
        analyzer: Object exposing ``analyze_single_image(path)``. Defaults to the
            notebook-level ``identifier`` created in an earlier cell.
        default_threshold: Threshold used when a profile has no entry for a
            DR class (looked up by int key first, then by its string form).

    Returns:
        tuple: (statistics_dict, processing_time). ``total_processed`` counts
        images that received a score (kept + removed); ``total_failed`` counts
        images that could not be analysed or copied and were therefore left out.
        The same dictionary is written to
        ``<copy_output_dir>/cleaning_statistics.json``.
    """

    if analyzer is None:
        # Resolved at call time; raises NameError immediately if the earlier
        # cell that builds the analyzer has not been run.
        analyzer = identifier

    image_extensions = ('.jpg', '.jpeg', '.png', '.tiff', '.bmp')
    dr_names = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']

    print(f"Creating cleaned dataset: {copy_output_dir}")
    print("Quality scoring: CQ = 0.25×Basic + 0.55×Medical + 0.20×Technical")

    os.makedirs(copy_output_dir, exist_ok=True)

    # Initialize statistics tracking
    copy_stats = {
        'total_processed': 0, 'total_kept': 0, 'total_removed': 0, 'total_failed': 0,
        'by_dataset': {}, 'by_dr_class': {i: {'kept': 0, 'removed': 0} for i in range(5)}
    }

    start_time = time.time()

    # Process each dataset
    for dataset_name, dataset_path in source_datasets.items():
        print(f"\nProcessing {dataset_name}")

        if dataset_name not in profiles:
            print(f"  Profile not found for {dataset_name}, skipping")
            continue

        thresholds = profiles[dataset_name]['adaptive_thresholds']

        # Create dataset output directory
        copy_dataset_dir = os.path.join(copy_output_dir, dataset_name)
        os.makedirs(copy_dataset_dir, exist_ok=True)

        # Initialize dataset-level statistics
        copy_dataset_stats = {
            'total_processed': 0, 'total_kept': 0, 'total_removed': 0, 'total_failed': 0,
            'by_dr_class': {i: {'kept': 0, 'removed': 0} for i in range(5)}
        }

        # Process each DR severity class
        for dr_class in range(5):
            dr_class_path = os.path.join(dataset_path, str(dr_class))

            # isdir (not exists): a plain file named "0".."4" must not reach listdir
            if not os.path.isdir(dr_class_path):
                continue

            # Create output directory for this DR class
            copy_dr_dir = os.path.join(copy_dataset_dir, str(dr_class))
            os.makedirs(copy_dr_dir, exist_ok=True)

            # Sorted so progress output and processing order are reproducible
            image_files = sorted(
                name for name in os.listdir(dr_class_path)
                if name.lower().endswith(image_extensions)
            )

            if not image_files:
                continue

            print(f"  DR Class {dr_class}: Processing {len(image_files)} images")

            # Profiles reloaded from JSON carry string keys, so try both forms
            threshold = thresholds.get(dr_class, thresholds.get(str(dr_class), default_threshold))
            kept_count = 0
            removed_count = 0
            failed_count = 0

            # Process images in this class
            for i, image_file in enumerate(image_files):
                if i % 500 == 0 and i > 0:
                    progress = (i / len(image_files)) * 100
                    print(f"    Progress: {i}/{len(image_files)} ({progress:.1f}%)")

                source_image_path = os.path.join(dr_class_path, image_file)

                try:
                    # Analyze image quality
                    char = analyzer.analyze_single_image(source_image_path)
                    if not char:
                        # No characteristics returned: the image cannot be scored
                        failed_count += 1
                        continue

                    # Calculate 3-component quality score
                    combined_quality = calculate_composite_quality(char)

                    # Apply quality threshold
                    if combined_quality >= threshold:
                        shutil.copy2(source_image_path, os.path.join(copy_dr_dir, image_file))
                        kept_count += 1
                    else:
                        removed_count += 1

                except Exception as e:
                    # Best effort: one bad file must not abort the whole run,
                    # but it is counted so the loss is visible in the report.
                    print(f"    Error processing {image_file}: {e}")
                    failed_count += 1

            # Update statistics
            scored_count = kept_count + removed_count
            for stats in (copy_dataset_stats, copy_stats):
                stats['by_dr_class'][dr_class]['kept'] += kept_count
                stats['by_dr_class'][dr_class]['removed'] += removed_count
                stats['total_processed'] += scored_count
            copy_dataset_stats['total_kept'] += kept_count
            copy_dataset_stats['total_removed'] += removed_count
            copy_dataset_stats['total_failed'] += failed_count

            # Report class results
            if scored_count > 0:
                removal_rate = (removed_count / scored_count) * 100
                print(f"    {dr_names[dr_class]}: Kept {kept_count}, Removed {removed_count} "
                      f"({removal_rate:.1f}% removed)")
            if failed_count > 0:
                print(f"    {dr_names[dr_class]}: {failed_count} images could not be analysed "
                      f"and were left out")

        # Update total statistics
        copy_stats['total_kept'] += copy_dataset_stats['total_kept']
        copy_stats['total_removed'] += copy_dataset_stats['total_removed']
        copy_stats['total_failed'] += copy_dataset_stats['total_failed']
        copy_stats['by_dataset'][dataset_name] = copy_dataset_stats

        # Dataset summary
        if copy_dataset_stats['total_processed'] > 0:
            dataset_removal_rate = (copy_dataset_stats['total_removed'] / copy_dataset_stats['total_processed']) * 100
            print(f"  {dataset_name} Summary: {copy_dataset_stats['total_processed']:,} processed, "
                  f"{copy_dataset_stats['total_kept']:,} kept, "
                  f"{copy_dataset_stats['total_removed']:,} removed ({dataset_removal_rate:.1f}%)")
        if copy_dataset_stats['total_failed'] > 0:
            print(f"  {dataset_name}: {copy_dataset_stats['total_failed']:,} images could not be "
                  f"analysed and were left out")

    processing_time = time.time() - start_time

    # Save statistics (integer class keys become strings in the JSON file)
    copy_stats_file = os.path.join(copy_output_dir, 'cleaning_statistics.json')
    with open(copy_stats_file, 'w') as f:
        json.dump(copy_stats, f, indent=2)

    return copy_stats, processing_time

def calculate_composite_quality(characteristics):
    """
    Combine per-image metrics into one composite quality score.

    Three sub-scores are each the mean of metrics clipped to [0, 1]:
      * Basic: brightness/255, contrast/100, sharpness/1000, entropy/8
      * Medical: illumination uniformity, vessel visibility x10, optic disc visibility x10
      * Technical: 1 - 2 x extreme pixel fraction, motion blur/50, 1 - color balance/50

    Metrics missing from ``characteristics`` fall back to fixed defaults.

    Args:
        characteristics: Dictionary of image quality characteristics

    Returns:
        float: 0.25 x Basic + 0.55 x Medical + 0.20 x Technical
    """

    # Default for every metric, listed in the order the values are read
    fallbacks = {
        'brightness': 127.5,
        'contrast': 50.0,
        'sharpness': 500.0,
        'entropy': 4.0,
        'illumination_uniformity': 0.5,
        'vessel_visibility': 0.1,
        'optic_disc_visibility': 0.1,
        'extreme_brightness_pixels': 0.1,
        'motion_blur_score': 20.0,
        'color_balance': 15.0,
    }
    metric = {name: float(characteristics.get(name, default)) for name, default in fallbacks.items()}

    def clip_unit(value):
        # "Higher is better" metrics: clamp into [0, 1]
        return min(1.0, max(0.0, value))

    def clip_inverted(penalty):
        # "Lower is better" metrics: score is 1 - penalty, clamped into [0, 1]
        return max(0, min(1, 1 - penalty))

    # Basic image statistics scaled by their nominal full-scale values
    basic_quality = np.mean([
        clip_unit(metric['brightness'] / 255.0),
        clip_unit(metric['contrast'] / 100.0),
        clip_unit(metric['sharpness'] / 1000.0),
        clip_unit(metric['entropy'] / 8.0),
    ])

    # Fundus-specific visibility measures
    medical_quality = np.mean([
        clip_unit(metric['illumination_uniformity']),
        clip_unit(metric['vessel_visibility'] * 10),
        clip_unit(metric['optic_disc_visibility'] * 10),
    ])

    # Acquisition artefacts
    technical_quality = np.mean([
        clip_inverted(metric['extreme_brightness_pixels'] * 2),
        clip_unit(metric['motion_blur_score'] / 50.0),
        clip_inverted(metric['color_balance'] / 50.0),
    ])

    # Composite formula: CQ = 0.25×Basic + 0.55×Medical + 0.20×Technical
    return 0.25 * basic_quality + 0.55 * medical_quality + 0.20 * technical_quality

def generate_cleaning_report(stats, processing_time, output_dir):
    """
    Print a short summary of a cleaning run.

    Sections printed: overall totals, one line per dataset that processed at
    least one image, one line per DR class that has at least one kept or
    removed image, and the output location. When nothing was processed only
    "No images processed" is printed.

    Args:
        stats: Cleaning statistics dictionary
        processing_time: Total processing time in seconds
        output_dir: Output directory path
    """

    processed = stats['total_processed']
    kept = stats['total_kept']
    removed = stats['total_removed']

    # Guard: nothing to report (also avoids dividing by zero below)
    if processed == 0:
        print("No images processed")
        return

    def percent(part, whole):
        return (part / whole) * 100

    # Overall totals
    print("\nCleaning Summary:")
    print(f"Processing time: {processing_time/60:.1f} minutes")
    print(f"Total processed: {processed:,}")
    print(f"Images kept: {kept:,}")
    print(f"Images removed: {removed:,} ({percent(removed, processed):.1f}%)")

    # One line per dataset with any processed image
    print("\nDataset breakdown:")
    for name, per_dataset in stats['by_dataset'].items():
        if per_dataset['total_processed'] <= 0:
            continue
        rate = percent(per_dataset['total_removed'], per_dataset['total_processed'])
        print(f"  {name}: {per_dataset['total_kept']:,} kept, "
              f"{per_dataset['total_removed']:,} removed ({rate:.1f}%)")

    # One line per DR class with any kept or removed image
    print("\nDR class breakdown:")
    class_labels = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']
    for class_id, label in enumerate(class_labels):
        per_class = stats['by_dr_class'][class_id]
        class_total = per_class['kept'] + per_class['removed']
        if class_total <= 0:
            continue
        rate = percent(per_class['removed'], class_total)
        print(f"  {label}: {per_class['kept']:,} kept, "
              f"{per_class['removed']:,} removed ({rate:.1f}%)")

    print(f"\nOutput location: {output_dir}")
    print(f"Statistics saved: {output_dir}/cleaning_statistics.json")

# Run the cleaning step (the guard is true in a notebook, where __name__ is "__main__")
if __name__ == "__main__" or 'valid_datasets' in locals():
    print("Dataset Cleaning with 3-Component Quality Scoring")

    # Both inputs are produced by earlier cells; proceed only when each one is defined
    if 'valid_datasets' in locals() and 'dataset_profiles' in locals():
        print("Processing may take time depending on dataset size")

        # Copy the images that pass their thresholds and collect the statistics
        final_stats, total_processing_time = create_cleaned_dataset(valid_datasets, dataset_profiles)

        # Summarise the run for the default output directory
        generate_cleaning_report(final_stats, total_processing_time, 'DREAM_dataset_cleaned')

        print("Dataset cleaning completed")
    else:
        print("Error: Required variables not found (valid_datasets, dataset_profiles)")
        print("Ensure previous analysis steps are completed")

In [None]:
# 18. Quality Analysis CSV Generation

import pandas as pd
import numpy as np

def create_quality_csv(all_results, output_file='quality_analysis.csv'):
    """
    Generate comprehensive quality analysis CSV with all metrics and assessments.

    Each result dict becomes one row. Missing numeric metrics default to 0,
    missing text fields to ``''`` or ``'UNKNOWN'``, and a missing or ``None``
    ``resolution`` to ``(0, 0)``. Rows are sorted by dataset, then DR severity,
    then descending overall quality score. An empty ``all_results`` yields an
    empty frame (and CSV) that still carries every column.

    Args:
        all_results: List of quality analysis results from image processing
        output_file: Output CSV file path

    Returns:
        pandas.DataFrame: Complete analysis dataset
    """

    dr_names = {0: 'No DR', 1: 'Mild DR', 2: 'Moderate DR', 3: 'Severe DR', 4: 'Proliferative DR'}

    # Fixed column order; also gives an empty result set a usable header
    columns = [
        'dataset_name', 'filename', 'dr_severity', 'dr_class_name',
        'overall_quality_score', 'basic_quality_score',
        'medical_quality_score', 'technical_quality_score',
        'brightness', 'contrast', 'sharpness', 'entropy', 'illumination_uniformity',
        'vessel_visibility', 'optic_disc_visibility', 'color_balance',
        'recommended_action', 'confidence', 'removal_reasons', 'threshold_used',
        'image_width', 'image_height', 'file_size_mb',
        'extreme_brightness_pixels', 'motion_blur_score',
    ]

    # Numeric columns copied straight from the result with this many decimals
    rounding = {
        'overall_quality_score': 4, 'basic_quality_score': 4,
        'medical_quality_score': 4, 'technical_quality_score': 4,
        'brightness': 2, 'contrast': 2, 'sharpness': 1, 'entropy': 3,
        'illumination_uniformity': 4, 'vessel_visibility': 4,
        'optic_disc_visibility': 4, 'color_balance': 2,
        'threshold_used': 4, 'file_size_mb': 2,
        'extreme_brightness_pixels': 4, 'motion_blur_score': 2,
    }

    analysis_data = []

    for result in all_results:
        # "or" also covers an explicit None, which cannot be unpacked
        width, height = result.get('resolution') or (0, 0)

        # Format removal reasons as one ';'-separated string
        removal_reasons = result.get('removal_reasons', [])
        if removal_reasons is None:
            removal_reasons_str = ''
        elif isinstance(removal_reasons, (list, tuple)):
            # str() so non-string reasons (codes, enums) do not break the join
            removal_reasons_str = ';'.join(str(reason) for reason in removal_reasons)
        else:
            removal_reasons_str = str(removal_reasons)

        dr_severity = result.get('dr_severity', 0)

        analysis_record = {
            # Image identification
            'dataset_name': result.get('dataset', ''),
            'filename': result.get('filename', ''),
            'dr_severity': dr_severity,
            'dr_class_name': dr_names.get(dr_severity, 'Unknown'),

            # Assessment results
            'recommended_action': result.get('recommended_action', 'UNKNOWN'),
            'confidence': result.get('confidence', 'UNKNOWN'),
            'removal_reasons': removal_reasons_str,

            # Technical metadata
            'image_width': width,
            'image_height': height,
        }

        # Quality scores and metrics
        for key, digits in rounding.items():
            analysis_record[key] = round(result.get(key, 0), digits)

        analysis_data.append(analysis_record)

    # columns= fixes the order and keeps sort_values valid when there are no rows
    df = pd.DataFrame(analysis_data, columns=columns)
    df = df.sort_values(['dataset_name', 'dr_severity', 'overall_quality_score'],
                        ascending=[True, True, False])

    df.to_csv(output_file, index=False)

    print(f"Quality analysis saved: {output_file} ({len(df):,} records, {len(df.columns)} columns)")
    return df

def create_summary_statistics(df, output_file='summary_statistics.csv'):
    """
    Generate comprehensive summary statistics from quality analysis data.

    Four tables are built: overall counts, per-dataset counts, per-DR-severity
    counts and descriptive statistics of the raw quality metrics. They are
    written to one Excel workbook named after ``output_file`` with an ``.xlsx``
    extension; if the ``openpyxl`` engine cannot be imported, four CSV files
    sharing the same base name are written instead.

    "Kept" always means ``recommended_action == 'KEEP'`` and "removed" means
    ``== 'REMOVE'``; rows with any other action count towards the totals only.
    An empty frame produces a 0.0 removal rate instead of raising.

    Args:
        df: Quality analysis DataFrame
        output_file: Base filename for output files

    Returns:
        tuple: (dataset_stats, dr_stats, metric_stats), each a list of dicts
    """

    def action_counts(frame):
        """Return (kept, removed) row counts for a frame."""
        actions = frame['recommended_action']
        return int((actions == 'KEEP').sum()), int((actions == 'REMOVE').sum())

    def removal_percent(removed, total):
        """Percentage removed, or 0.0 when there are no rows."""
        return round(removed / total * 100, 1) if total else 0.0

    def group_summary(frame):
        """Count and score columns shared by the dataset and severity tables."""
        kept, removed = action_counts(frame)
        total = len(frame)
        scores = frame['overall_quality_score']
        return {
            'Total_Images': total,
            'Images_Kept': kept,
            'Images_Removed': removed,
            'Removal_Rate_Percent': removal_percent(removed, total),
            'Mean_Quality_Score': round(scores.mean(), 3),
            'Std_Quality_Score': round(scores.std(), 3)
        }

    # Overall statistics
    total_images = len(df)
    images_kept, images_removed = action_counts(df)

    overall_stats = {
        'Metric': ['Total Images', 'Images Kept', 'Images Removed', 'Removal Rate (%)'],
        'Value': [total_images, images_kept, images_removed,
                  removal_percent(images_removed, total_images)]
    }

    # Dataset-level statistics (sort=False keeps first-appearance order)
    dataset_stats = [
        {'Dataset': dataset, **group_summary(dataset_data)}
        for dataset, dataset_data in df.groupby('dataset_name', sort=False)
    ]

    # DR severity statistics in ascending severity; rows with a missing
    # severity are left out by groupby
    dr_stats = [
        {
            'DR_Severity': dr_class,
            'DR_Class_Name': dr_data['dr_class_name'].iloc[0],
            **group_summary(dr_data)
        }
        for dr_class, dr_data in df.groupby('dr_severity', sort=True)
    ]

    # Quality metric descriptive statistics
    quality_metrics = ['brightness', 'contrast', 'sharpness', 'entropy',
                       'illumination_uniformity', 'vessel_visibility', 'optic_disc_visibility']

    metric_stats = []
    for metric in quality_metrics:
        if metric not in df.columns:
            continue
        values = df[metric]
        metric_stats.append({
            'Metric': metric,
            'Mean': round(values.mean(), 4),
            'Std': round(values.std(), 4),
            'Min': round(values.min(), 4),
            'Max': round(values.max(), 4),
            'Q25': round(values.quantile(0.25), 4),
            'Q50': round(values.quantile(0.50), 4),
            'Q75': round(values.quantile(0.75), 4)
        })

    # Strip only a trailing ".csv"; str.replace would also rewrite directory names
    stem, extension = os.path.splitext(output_file)
    base_name = stem if extension.lower() == '.csv' else output_file
    output_excel = f'{base_name}.xlsx'

    tables = [
        ('Overall_Stats', 'overall', overall_stats),
        ('Dataset_Stats', 'datasets', dataset_stats),
        ('DR_Severity_Stats', 'dr_severity', dr_stats),
        ('Quality_Metrics_Stats', 'quality_metrics', metric_stats),
    ]

    # Save statistics
    try:
        with pd.ExcelWriter(output_excel, engine='openpyxl') as writer:
            for sheet_name, _, table in tables:
                pd.DataFrame(table).to_excel(writer, sheet_name=sheet_name, index=False)

        print(f"Summary statistics saved: {output_excel}")

    except ImportError:
        # Fallback to separate CSV files
        for _, suffix, table in tables:
            pd.DataFrame(table).to_csv(f'{base_name}_{suffix}.csv', index=False)
        print("Summary statistics saved as separate CSV files (openpyxl not available)")

    return dataset_stats, dr_stats, metric_stats

def generate_analysis_files(results=None):
    """
    Main function to generate quality analysis files.

    Writes ``quality_analysis.csv`` and the summary statistics (see
    ``create_summary_statistics``) into the current working directory, then
    prints a short per-dataset and per-severity overview.

    Args:
        results: List of per-image quality analysis results. When omitted, the
            notebook-level variable ``all_results`` is used.

    Returns:
        tuple: (analysis_df, dataset_stats, dr_stats, metric_stats), or None
        when no results are available.
    """

    if results is None:
        # Notebook fallback: pick up the list produced by an earlier cell.
        # Only globals() is consulted; this function has no local of that name.
        if 'all_results' not in globals():
            print("Error: Quality analysis results not found")
            print("Required steps:")
            print("1. Run dataset validation")
            print("2. Run dataset characterization")
            print("3. Run quality issue identification")
            return None
        results = globals()['all_results']

    if not results:
        print("Error: No analysis results available")
        return None

    print(f"Processing {len(results):,} quality analysis results")

    # Generate main analysis CSV
    analysis_df = create_quality_csv(results, 'quality_analysis.csv')

    # Generate summary statistics
    dataset_stats, dr_stats, metric_stats = create_summary_statistics(
        analysis_df, 'summary_statistics.csv'
    )

    # Display summary information
    print("\nDataset Summary:")
    for stat in dataset_stats:
        print(f"  {stat['Dataset']}: {stat['Total_Images']:,} images, "
              f"{stat['Removal_Rate_Percent']}% removed")

    print("\nDR Severity Summary:")
    for stat in dr_stats:
        print(f"  {stat['DR_Class_Name']}: {stat['Total_Images']:,} images, "
              f"{stat['Removal_Rate_Percent']}% removed")

    print("\nFiles generated:")
    print("- quality_analysis.csv: Complete analysis dataset")
    # The summary writer falls back to CSV files when openpyxl is missing
    print("- summary_statistics.xlsx (or summary_statistics_*.csv without openpyxl): "
          "Statistical summaries")

    return analysis_df, dataset_stats, dr_stats, metric_stats

# Produce the analysis files
if __name__ != "__main__":
    # Reached only when this code is imported as a module rather than run directly
    generate_analysis_files()
else:
    print("Quality Analysis File Generation")

    # Results come from the earlier quality-analysis cells
    if 'all_results' not in locals() or not all_results:
        print("Quality analysis results not found. Ensure previous analysis steps are completed.")
    else:
        generate_analysis_files()