In [None]:
# 1. Setup and Imports

import os
import json
import pandas as pd
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import shutil
from datetime import datetime
import time
import logging
from collections import Counter, defaultdict
import gc
import random

# Notebook-wide logging: INFO and above, each line prefixed with time and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # shared by every cell below

logger.info("Required packages loaded successfully")

In [None]:
# 2. Path and Configuration

BASE_PATH = "Path/to/your/dataset"  # Specify the base path to your dataset here

# Every dataset is expected as a sub-folder of BASE_PATH.
# Map each dataset key used throughout the notebook to its folder name.
datasets_config = {
    key: f'{BASE_PATH}/{folder}'
    for key, folder in (
        ('APTOS2019', 'APTOS 2019'),
        ('Diabetic_Retinopathy_V03', 'Diabetic Retinopathy_V03'),
        ('IDRiD', 'IDRiD'),
        ('Messidor2', 'Messidor 2'),
        ('SUSTech_SYSU', 'SUSTech_SYSU'),
        ('DeepDRiD', 'DeepDRiD'),
    )
}

# Other configuration settings
OUTPUT_DIR = 'quality_review'  # root folder for all review artefacts
RANDOM_SEED = 42  # passed to QualityIdentifier as its random seed
N_SAMPLES_PER_DATASET = 300  # Number of images to sample for characterization

print("Configuration set")

In [None]:
# 3. Validate Dataset Paths

def validate_dataset_paths(datasets_config):
    """Keep only the datasets whose folder holds DR grade sub-folders with images.

    A dataset qualifies when its path exists, has at least one immediate
    sub-folder named 0-4, and those grade folders (searched recursively)
    contain at least one file with a recognised image extension. Per-grade
    image counts are logged along the way.

    Args:
        datasets_config: mapping of dataset name -> dataset root path.

    Returns:
        dict: the subset of ``datasets_config`` that passed validation.
    """
    grade_labels = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']
    extensions = ('.jpg', '.jpeg', '.png', '.tiff', '.bmp', '.JPG', '.JPEG', '.PNG')

    def count_images(folder):
        # Recursive count of files whose lower-cased name ends in a known extension.
        return sum(
            1
            for _, _, filenames in os.walk(folder)
            for filename in filenames
            if filename.lower().endswith(extensions)
        )

    accepted = {}
    for name, path in datasets_config.items():
        logger.info(f"Checking {name}")
        logger.info(f"Path: {path}")

        if not os.path.exists(path):
            logger.error(f"Path does not exist: {path}")
            continue

        grades_found = []
        counts_by_grade = {}
        try:
            for entry in os.listdir(path):
                entry_path = os.path.join(path, entry)
                if not os.path.isdir(entry_path):
                    continue
                if not (entry.isdigit() and int(entry) in [0, 1, 2, 3, 4]):
                    continue
                grade = int(entry)
                grades_found.append(grade)
                counts_by_grade[grade] = count_images(entry_path)

            if not grades_found:
                logger.warning(f"No DR class folders (0,1,2,3,4) found in {name}")
                subfolders = [entry for entry in os.listdir(path) if os.path.isdir(os.path.join(path, entry))]
                logger.info(f"Available folders: {subfolders}")
                continue

            grades_found.sort()
            total = sum(counts_by_grade.values())
            logger.info(f"Found DR classes: {grades_found}")
            logger.info("Image counts per class:")
            for grade in grades_found:
                label = grade_labels[grade] if 0 <= grade < 5 else f'Class {grade}'
                logger.info(f"  {label} (Class {grade}): {counts_by_grade[grade]:,} images")
            logger.info(f"Total images: {total:,}")

            if total > 0:
                accepted[name] = path
            else:
                logger.warning(f"No images found in {name}")
        except Exception as e:
            logger.error(f"Error accessing path {path}: {e}")

    return accepted

# Run validation
logger.info("Starting dataset validation")
valid_datasets = validate_dataset_paths(datasets_config)

if valid_datasets:
    logger.info(f"Dataset validation completed - {len(valid_datasets)} datasets loaded successfully")
else:
    logger.error("No valid datasets found")
    # Tell the user how to fix the setup before re-running.
    for hint in (
        "Setup instructions:",
        "1. Update BASE_PATH in cell [2] to your actual dataset location",
        "2. Ensure your datasets have folders named 0, 1, 2, 3, 4 containing images",
    ):
        logger.info(hint)

In [None]:
# 4. Quality Identifiers Class

class QualityIdentifier:
    def __init__(self, output_dir='quality_review', random_seed=42):
        self.output_dir = output_dir
        self.dataset_profiles = {}
        self.identification_results = []
        self.random_seed = random_seed
        
        np.random.seed(random_seed)
        
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f'{output_dir}/sample_images', exist_ok=True)
        os.makedirs(f'{output_dir}/flagged_samples', exist_ok=True)
        
        logger.info(f"Quality identification initialized - output directory: {output_dir}")
    
    def extract_dr_label_from_path(self, image_path):
        parts = image_path.split(os.sep)
        
        for part in parts:
            if part.isdigit() and int(part) in [0, 1, 2, 3, 4]:
                return int(part)
        
        dr_patterns = {
            'no_dr': 0, 'normal': 0, 'grade_0': 0, 'class_0': 0,
            'mild': 1, 'grade_1': 1, 'class_1': 1,
            'moderate': 2, 'grade_2': 2, 'class_2': 2,
            'severe': 3, 'grade_3': 3, 'class_3': 3,
            'proliferative': 4, 'grade_4': 4, 'class_4': 4
        }
        
        for part in parts:
            part_lower = part.lower()
            if part_lower in dr_patterns:
                return dr_patterns[part_lower]
        return None
    
    def sample_images_strategically(self, dataset_path, n_samples):
        all_images = []
        class_images = {0: [], 1: [], 2: [], 3: [], 4: []}
        
        image_extensions = ('.jpg', '.jpeg', '.png', '.tiff', '.bmp', '.JPG', '.JPEG', '.PNG')
        
        for root, dirs, files in os.walk(dataset_path):
            for file in files:
                if file.lower().endswith(image_extensions):
                    image_path = os.path.join(root, file)
                    dr_class = self.extract_dr_label_from_path(image_path)
                    if dr_class is not None:
                        class_images[dr_class].append(image_path)
        
        for dr_class, images in class_images.items():
            dr_names = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']
            dr_name = dr_names[dr_class] if 0 <= dr_class < 5 else f'Class {dr_class}'
            if images:
                logger.info(f"    {dr_name}: {len(images)} images")
        
        sampled_images = []
        total_available = sum(len(images) for images in class_images.values())
        
        if total_available == 0:
            logger.warning("No images found with valid DR labels")
            return []
        
        samples_per_class = max(1, n_samples // 5)
        
        for dr_class, images in class_images.items():
            if images:
                n_class_samples = min(samples_per_class, len(images))
                sampled = np.random.choice(images, n_class_samples, replace=False)
                sampled_images.extend(sampled)

        logger.info(f"Selected {len(sampled_images)} stratified samples")
        return sampled_images
    
    def safe_calculate_brightness(self, image):
        try:
            brightness = float(np.mean(image))
            return max(0.0, min(255.0, brightness))
        except:
            return 127.5
    
    def safe_calculate_contrast(self, image):
        try:
            contrast = float(np.std(image))
            return max(0.0, contrast)
        except:
            return 50.0
    
    def safe_calculate_sharpness(self, gray):
        try:
            laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
            return max(0.0, float(laplacian_var))
        except:
            return 500.0
    
    def safe_calculate_entropy(self, gray):
        try:
            hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
            hist = hist.flatten()
            hist = hist[hist > 0]
            if len(hist) == 0:
                return 4.0
            prob = hist / hist.sum()
            entropy = float(-np.sum(prob * np.log2(prob)))
            return max(0.0, min(8.0, entropy))
        except:
            return 4.0

    def assess_vessel_visibility_improved(self, image):
        try:
            green = image[:, :, 1]
            
            kernels = {
                'horizontal': np.array([[-1, -1, -1], [2, 2, 2], [-1, -1, -1]], dtype=np.float32),
                'vertical': np.array([[-1, 2, -1], [-1, 2, -1], [-1, 2, -1]], dtype=np.float32),
                'diagonal1': np.array([[2, -1, -1], [-1, 2, -1], [-1, -1, 2]], dtype=np.float32),
                'diagonal2': np.array([[-1, -1, 2], [-1, 2, -1], [2, -1, -1]], dtype=np.float32)
            }
            
            vessel_responses = []
            for kernel in kernels.values():
                filtered = cv2.filter2D(green, cv2.CV_32F, kernel)
                vessel_responses.append(filtered)
            
            max_response = np.maximum.reduce(vessel_responses)
            threshold = np.percentile(max_response, 95)
            vessel_pixels = np.sum(max_response > threshold)
            total_pixels = max_response.shape[0] * max_response.shape[1]
            
            visibility_score = vessel_pixels / max(total_pixels, 1)
            return float(min(0.1, visibility_score))
            
        except Exception:
            return 0.0

    def assess_optic_disc_visibility(self, image):
        try:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            bright_pixels = np.sum(gray > np.percentile(gray, 90))
            total_pixels = gray.shape[0] * gray.shape[1]
            return float(bright_pixels / max(total_pixels, 1))
        except Exception:
            return 0.0

    def assess_illumination_uniformity(self, gray):
        try:
            h, w = gray.shape
            regions = []
            
            for i in range(3):
                for j in range(3):
                    start_y, end_y = i * h // 3, (i + 1) * h // 3
                    start_x, end_x = j * w // 3, (j + 1) * w // 3
                    region = gray[start_y:end_y, start_x:end_x]
                    regions.append(np.mean(region))
            
            mean_intensity = np.mean(regions)
            if mean_intensity > 0:
                cv_score = np.std(regions) / mean_intensity
                return float(max(0, min(1, 1 - cv_score)))
            return 0.0
        except Exception:
            return 0.0

    def detect_extreme_pixels(self, gray):
        try:
            very_dark = np.sum(gray < 10)
            very_bright = np.sum(gray > 245)
            total_pixels = gray.shape[0] * gray.shape[1]
            return float((very_dark + very_bright) / max(total_pixels, 1))
        except Exception:
            return 0.0

    def assess_motion_blur(self, gray):
        try:
            grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
            magnitude = np.sqrt(grad_x**2 + grad_y**2)
            return float(np.mean(magnitude))
        except Exception:
            return 0.0

    def analyze_single_image(self, image_path):
        try:
            if not os.path.exists(image_path):
                logger.error(f"File not found: {image_path}")
                return None
                
            image = cv2.imread(image_path)
            if image is None:
                logger.error(f"Could not read image: {image_path}")
                return None
            
            if len(image.shape) != 3:
                logger.error(f"Invalid image format: {image_path}")
                return None
                
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            h, w = image.shape[:2]

            if h < 100 or w < 100:
                logger.error(f"Image too small ({w}x{h}): {image_path}")
                return None
            
            characteristics = {
                'image_path': image_path,
                'filename': os.path.basename(image_path),
                'resolution': (w, h),
                'file_size_mb': max(0.001, os.path.getsize(image_path) / (1024*1024)),
                
                'brightness': self.safe_calculate_brightness(image),
                'contrast': self.safe_calculate_contrast(image),
                'sharpness': self.safe_calculate_sharpness(gray),
                'entropy': self.safe_calculate_entropy(gray),
                
                'color_balance': float(np.std([np.mean(image[:,:,0]), np.mean(image[:,:,1]), np.mean(image[:,:,2])])),
                
                'vessel_visibility': self.assess_vessel_visibility_improved(image),
                'optic_disc_visibility': self.assess_optic_disc_visibility(image),
                'illumination_uniformity': self.assess_illumination_uniformity(gray),
                
                'extreme_brightness_pixels': self.detect_extreme_pixels(gray),
                'motion_blur_score': self.assess_motion_blur(gray)
            }
            return characteristics
            
        except Exception as e:
            logger.error(f"Error analyzing {image_path}: {e}")
            return None

logger.info("QualityIdentifier class defined")

In [None]:
# 5. Initialize Quality Identifier

# Only build the identifier when the validation cell found at least one dataset.
if not valid_datasets:
    print("Cannot initialize - no valid datasets found")
else:
    identifier = QualityIdentifier(
        output_dir=OUTPUT_DIR,
        random_seed=RANDOM_SEED,
    )
    print("QualityIdentifier initialized!")
    print(f"Output directory: {OUTPUT_DIR}")

In [None]:
# 6. Dataset Characterization & Adaptive Thresholding

def _spread_subset(items, k):
    """Return up to ``k`` items taken at evenly spaced positions across ``items``."""
    if k <= 0 or not items:
        return []
    if len(items) <= k:
        return list(items)
    step = len(items) / k  # > 1, so the picked indices are distinct
    return [items[int(i * step)] for i in range(k)]


def characterize_dataset(identifier, dataset_path, dataset_name, n_samples=300, n_saved_samples=20):
    """Profile one dataset from a stratified sample of its images.

    Draws a sample via ``identifier.sample_images_strategically`` and extracts
    per-image characteristics with ``identifier.analyze_single_image``. It then
    builds the dataset profile (statistics plus adaptive thresholds) and copies
    a few sampled images for manual review.

    Args:
        identifier: object providing sampling, image analysis and ``output_dir``.
        dataset_path: root folder of the dataset.
        dataset_name: display name, also used for the sample-image folder.
        n_samples: requested sample size (split across DR grades by the sampler).
        n_saved_samples: how many sampled images to copy for visual review.

    Returns:
        dict | None: the dataset profile, or None if no image could be sampled
        or analyzed.
    """
    logger.info(f"Characterizing {dataset_name}")

    sample_images = identifier.sample_images_strategically(dataset_path, n_samples)

    if not sample_images:
        logger.warning(f"No images found in {dataset_path}")
        return None

    logger.info(f"Analyzing {len(sample_images)} sample images")

    characteristics = []
    for i, img_path in enumerate(sample_images):
        if i % 50 == 0 and i > 0:
            progress = (i / len(sample_images)) * 100
            logger.info(f"Analysis progress: {progress:.1f}% ({i}/{len(sample_images)})")

        char = identifier.analyze_single_image(img_path)
        if char:
            characteristics.append(char)

        if i % 100 == 0:
            gc.collect()

    if not characteristics:
        logger.error(f"No valid characteristics extracted from {dataset_name}")
        return None

    logger.info(f"Successfully analyzed {len(characteristics)} images")

    profile = calculate_dataset_profile(dataset_name, characteristics)
    # The sampler returns paths grouped by grade (grade 0 first), so a plain
    # prefix such as sample_images[:20] would show only the lowest grade(s).
    # Pick evenly spaced samples so every sampled grade is represented.
    save_sample_images(identifier, _spread_subset(sample_images, n_saved_samples), dataset_name)
    return profile

def _standardized_quality_score(char, char_stats):
    """Composite 0-1 quality score for one image, on the per-image assessment's scale.

    Mirrors the normalization in ``assess_image_quality_corrected``:
    - basic metrics (brightness, contrast, sharpness, entropy) are z-scored
      against ``char_stats`` and mapped from [-3, 3] onto [0, 1];
    - medical and technical metrics use the same fixed scalings;
    - a missing metric, or a basic metric without spread (std <= 0),
      contributes 0.5.

    Args:
        char: per-image feature dict.
        char_stats: the profile's ``characteristics_stats`` mapping.

    Returns:
        float: 0.25 * basic + 0.55 * medical + 0.20 * technical.
    """
    def clamp01(x):
        return max(0.0, min(1.0, x))

    def basic(metric):
        value = char.get(metric)
        stats = char_stats.get(metric)
        if value is None or stats is None or not stats['std'] > 0:
            return 0.5
        z_score = (float(value) - stats['mean']) / stats['std']
        return clamp01((z_score + 3) / 6)

    def scaled(metric, normalize):
        value = char.get(metric)
        return 0.5 if value is None else normalize(float(value))

    basic_quality = np.mean([basic(m) for m in ('brightness', 'contrast', 'sharpness', 'entropy')])
    medical_quality = np.mean([
        scaled('illumination_uniformity', clamp01),
        scaled('vessel_visibility', lambda v: clamp01(v * 100)),
        scaled('optic_disc_visibility', lambda v: clamp01(v * 50)),
    ])
    technical_quality = np.mean([
        scaled('extreme_brightness_pixels', lambda v: clamp01(1 - v * 2)),
        scaled('motion_blur_score', lambda v: clamp01(v / 50.0)),
        scaled('color_balance', lambda v: clamp01(1 - v / 50.0)),
    ])
    # Composite Quality Score: 25% Basic + 55% Medical + 20% Technical
    return float(0.25 * basic_quality + 0.55 * medical_quality + 0.20 * technical_quality)


def calculate_dataset_profile(dataset_name, characteristics):
    """Calculate dataset profile using standardized 3-component quality scoring.

    Args:
        dataset_name: name stored in the profile.
        characteristics: list of per-image feature dicts (as produced by
            ``QualityIdentifier.analyze_single_image``); may be empty.

    Returns:
        dict with 'dataset_name', 'analysis_date', 'n_samples_analyzed',
        'characteristics_stats' (mean/std/min/max/percentiles per numeric
        feature) and 'adaptive_thresholds' (DR grade -> minimum composite score).
    """
    profile = {
        'dataset_name': dataset_name,
        'analysis_date': datetime.now().isoformat(),
        'n_samples_analyzed': len(characteristics),
        'characteristics_stats': {},
        'adaptive_thresholds': {}
    }

    non_numeric = ('image_path', 'filename', 'resolution')
    numeric_keys = [key for key in characteristics[0] if key not in non_numeric] if characteristics else []

    for key in numeric_keys:
        # bool is a subclass of int, so it is covered by the numeric check;
        # None and other types are skipped.
        values = [float(char[key]) for char in characteristics
                  if isinstance(char.get(key), (int, float))]

        if values:
            profile['characteristics_stats'][key] = {
                'mean': float(np.mean(values)),
                'std': float(np.std(values)),
                'min': float(np.min(values)),
                'max': float(np.max(values)),
                'percentiles': {str(p): float(np.percentile(values, p)) for p in (5, 10, 25, 50, 75, 90, 95)}
            }

    # Percentile of the sample's composite scores used as the removal cut-off
    # per DR grade; higher grades use a lower percentile, so proportionally
    # fewer of them fall below the cut-off.
    removal_percentiles = {0: 15, 1: 12, 2: 10, 3: 8, 4: 5}

    # Score the sample with the same normalization assess_image_quality_corrected
    # applies per image. Otherwise the thresholds would sit on a different scale
    # from the scores they are later compared against.
    quality_scores = []
    for char in characteristics:
        try:
            quality_scores.append(_standardized_quality_score(char, profile['characteristics_stats']))
        except (TypeError, ValueError) as e:
            # Skip rather than inject a placeholder score that would skew the percentiles.
            logger.warning(f"Error calculating quality score: {e}")

    for dr_severity, percentile in removal_percentiles.items():
        threshold = np.percentile(quality_scores, percentile) if quality_scores else 0.3
        profile['adaptive_thresholds'][dr_severity] = float(threshold)

    return profile

def save_sample_images(identifier, sample_paths, dataset_name):
    """Copy images into ``<output_dir>/sample_images/<dataset_name>/`` for manual review.

    Each copy is named ``sample_<NN>_<original file name>``, where NN is the
    image's position in ``sample_paths``. Copy failures are logged and skipped.
    """
    target_dir = f'{identifier.output_dir}/sample_images/{dataset_name}'
    os.makedirs(target_dir, exist_ok=True)

    n_copied = 0
    for position, source in enumerate(sample_paths):
        try:
            destination = f'{target_dir}/sample_{position:02d}_{os.path.basename(source)}'
            shutil.copy2(source, destination)
        except Exception as e:
            logger.error(f"Error copying sample {source}: {e}")
        else:
            n_copied += 1

    logger.info(f"Saved {n_copied} sample images to {target_dir}")

# Run characterization for all datasets
if not valid_datasets:
    logger.error("Cannot proceed - no valid datasets found")
else:
    dataset_profiles = {}

    for dataset_name, dataset_path in valid_datasets.items():
        profile = characterize_dataset(identifier, dataset_path, dataset_name, N_SAMPLES_PER_DATASET)
        if not profile:
            logger.error(f"Failed characterization for {dataset_name}")
            continue

        dataset_profiles[dataset_name] = profile
        identifier.dataset_profiles[dataset_name] = profile
        logger.info(f"{dataset_name} characterization completed")

        # Headline statistics (mean ± std) for a few key features, when present.
        stats = profile['characteristics_stats']
        for metric, label, digits in (
            ('brightness', 'Brightness', 1),
            ('sharpness', 'Sharpness', 1),
            ('illumination_uniformity', 'Illumination uniformity', 3),
        ):
            if metric in stats:
                summary = stats[metric]
                logger.info(f"{label}: {summary['mean']:.{digits}f} ± {summary['std']:.{digits}f}")

    logger.info(f"Characterization completed for {len(dataset_profiles)} datasets")

In [None]:
# 7. Quality Issue Identification

def assess_image_quality_corrected(characteristics, profile, dr_severity):
    """
    Assess image quality using standardized 3-component scoring system.
    Components: Basic Quality (25%), Medical Quality (55%), Technical Quality (20%)

    Args:
        characteristics: per-image feature dict (needs 'resolution',
            'file_size_mb' and the metric keys used by the hard rules).
        profile: dataset profile with 'characteristics_stats' and
            'adaptive_thresholds' (DR grade -> threshold).
        dr_severity: DR grade of the image; selects the threshold (default 0.3).

    Returns:
        dict with the overall and component scores, the threshold used,
        'action' ('REMOVE' or 'KEEP'), the triggered hard-rule 'reasons',
        a 'confidence' label and the per-metric 'normalized_scores'.
    """
    char_stats = profile['characteristics_stats']
    normalized_scores = {}

    def percentile_of(metric, pct):
        # Dataset percentile for a metric, or 0 if the profile lacks it.
        return char_stats.get(metric, {}).get('percentiles', {}).get(pct, 0)

    # Basic metrics: z-score against the dataset profile, mapped from [-3, 3] onto [0, 1].
    for metric in ('brightness', 'contrast', 'sharpness', 'entropy'):
        if metric not in char_stats or metric not in characteristics:
            continue
        metric_stats = char_stats[metric]
        if metric_stats['std'] > 0:
            z = (characteristics[metric] - metric_stats['mean']) / metric_stats['std']
            normalized_scores[metric] = max(0, min(1, (z + 3) / 6))
        else:
            normalized_scores[metric] = 0.5

    # Medical then technical metrics: fixed scalings, clamped to [0, 1].
    fixed_normalizers = (
        ('illumination_uniformity', lambda x: min(1.0, max(0.0, x))),
        ('vessel_visibility', lambda x: min(1.0, max(0.0, x * 100))),
        ('optic_disc_visibility', lambda x: min(1.0, max(0.0, x * 50))),
        ('extreme_brightness_pixels', lambda x: max(0, min(1, 1 - (x * 2)))),
        ('motion_blur_score', lambda x: min(1.0, max(0.0, x / 50.0))),
        ('color_balance', lambda x: max(0, min(1, 1 - (x / 50.0)))),
    )
    for metric, normalize in fixed_normalizers:
        if metric in characteristics:
            normalized_scores[metric] = normalize(characteristics[metric])

    def component(metrics):
        # Unscored metrics count as neutral (0.5).
        return np.mean([normalized_scores.get(m, 0.5) for m in metrics])

    basic_score = component(('brightness', 'contrast', 'sharpness', 'entropy'))
    medical_score = component(('illumination_uniformity', 'vessel_visibility', 'optic_disc_visibility'))
    technical_score = component(('extreme_brightness_pixels', 'motion_blur_score', 'color_balance'))

    # Composite Quality Score: 25% Basic + 55% Medical + 20% Technical
    overall_score = 0.25 * basic_score + 0.55 * medical_score + 0.20 * technical_score

    threshold = profile['adaptive_thresholds'].get(dr_severity, 0.3)

    def resolution_too_low(resolution):
        w, h = resolution
        return w < 224 or h < 224

    # Hard rules, evaluated in this order; each failing rule contributes its reason.
    hard_rules = (
        ('extremely_blurry', lambda c: c['sharpness'] < percentile_of('sharpness', '5')),
        ('extreme_brightness', lambda c: c['brightness'] < 20 or c['brightness'] > 240),
        ('poor_illumination', lambda c: c['illumination_uniformity'] < 0.1),
        ('poor_vessel_visibility', lambda c: c['vessel_visibility'] < percentile_of('vessel_visibility', '10')),
        ('too_many_extreme_pixels', lambda c: c['extreme_brightness_pixels'] > 0.3),
        ('file_too_small', lambda c: c['file_size_mb'] < 0.1),
        ('resolution_too_low', lambda c: resolution_too_low(c['resolution'])),
        ('severe_motion_blur', lambda c: c.get('motion_blur_score', 0) < 5),
        ('severe_color_imbalance', lambda c: c.get('color_balance', 0) > 40),
    )
    removal_reasons = [reason for reason, failed in hard_rules if failed(characteristics)]

    # Two or more hard failures remove outright; otherwise fall back to the score threshold.
    if len(removal_reasons) >= 2:
        action, confidence = 'REMOVE', 'HIGH'
    elif overall_score < threshold:
        action, confidence = 'REMOVE', 'MEDIUM'
    else:
        action = 'KEEP'
        confidence = 'HIGH' if overall_score > threshold + 0.1 else 'MEDIUM'

    return {
        'overall_score': overall_score,
        'basic_score': basic_score,
        'medical_score': medical_score,
        'technical_score': technical_score,
        'threshold': threshold,
        'action': action,
        'reasons': removal_reasons,
        'confidence': confidence,
        'normalized_scores': normalized_scores
    }

def identify_quality_issues(identifier, dataset_path, dataset_name, profile):
    """Assess every labelled image of a dataset against the dataset's profile.

    Images are found by walking ``dataset_path``. An image is processed when
    its lower-cased name has a known extension and
    ``identifier.extract_dr_label_from_path`` yields a grade. Images the
    identifier cannot analyze are counted as errors and skipped.

    Args:
        identifier: object providing ``extract_dr_label_from_path`` and
            ``analyze_single_image``.
        dataset_path: root folder of the dataset.
        dataset_name: stored in every result record.
        profile: dataset profile passed to ``assess_image_quality_corrected``.

    Returns:
        list[dict]: one record per analyzed image, containing:
        - dataset, path, filename and grade;
        - component and overall scores, threshold, recommended action,
          reasons and confidence;
        - all raw characteristics;
        - the normalized scores, prefixed with 'normalized_'.
    """
    logger.info(f"Identifying quality issues in {dataset_name}")

    # Names are lower-cased before matching, so lower-case extensions suffice.
    image_extensions = ('.jpg', '.jpeg', '.png', '.tiff', '.bmp')

    # Walk and label the tree once. The list length is the progress total, so
    # the total can never disagree with what is actually processed.
    labelled_images = []
    for root, _dirs, files in os.walk(dataset_path):
        for file in files:
            if file.lower().endswith(image_extensions):
                image_path = os.path.join(root, file)
                dr_severity = identifier.extract_dr_label_from_path(image_path)
                if dr_severity is not None:
                    labelled_images.append((image_path, file, dr_severity))

    total_images = len(labelled_images)
    logger.info(f"Processing {total_images:,} images")

    results = []
    processed_count = 0
    error_count = 0
    for image_path, file, dr_severity in labelled_images:
        characteristics = identifier.analyze_single_image(image_path)
        if not characteristics:
            error_count += 1
            continue

        quality_assessment = assess_image_quality_corrected(characteristics, profile, dr_severity)

        result = {
            'dataset': dataset_name,
            'image_path': image_path,
            'filename': file,
            'dr_severity': dr_severity,
            'overall_quality_score': quality_assessment['overall_score'],
            'basic_quality_score': quality_assessment['basic_score'],
            'medical_quality_score': quality_assessment['medical_score'],
            'technical_quality_score': quality_assessment['technical_score'],
            'threshold_used': quality_assessment['threshold'],
            'recommended_action': quality_assessment['action'],
            'removal_reasons': quality_assessment['reasons'],
            'confidence': quality_assessment['confidence']
        }

        result.update(characteristics)
        result.update({f'normalized_{k}': v for k, v in quality_assessment['normalized_scores'].items()})

        results.append(result)
        processed_count += 1

        if processed_count % 1000 == 0:
            # processed_count <= total_images here, so the division is safe.
            progress = processed_count / total_images * 100
            logger.info(f"Processing progress: {progress:.1f}% ({processed_count:,}/{total_images:,})")
            gc.collect()

    logger.info(f"Analysis completed: {processed_count:,} images processed")
    if error_count > 0:
        logger.warning(f"Images with errors: {error_count}")

    return results

# Execute quality issue identification
if not (valid_datasets and dataset_profiles):
    logger.error("Quality identification skipped - missing datasets or profiles")
else:
    all_results = []
    dr_names = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']

    for dataset_name, profile in dataset_profiles.items():
        logger.info(f"Processing quality assessment for {dataset_name}")
        dataset_path = valid_datasets[dataset_name]
        results = identify_quality_issues(identifier, dataset_path, dataset_name, profile)
        all_results.extend(results)

        # Per-dataset removal summary.
        total_count = len(results)
        flagged = [r for r in results if r['recommended_action'] == 'REMOVE']
        flagged_count = len(flagged)
        removal_rate = 0 if total_count == 0 else flagged_count / total_count * 100
        logger.info(f"{dataset_name} - Total: {total_count:,}, Flagged: {flagged_count:,} ({removal_rate:.1f}%)")

        if results:
            avg_basic, avg_medical, avg_technical, avg_overall = (
                np.mean([r[key] for r in results])
                for key in ('basic_quality_score', 'medical_quality_score',
                            'technical_quality_score', 'overall_quality_score')
            )
            logger.info(f"Average scores - Basic: {avg_basic:.3f}, Medical: {avg_medical:.3f}, Technical: {avg_technical:.3f}, Overall: {avg_overall:.3f}")

        # DR class breakdown
        dr_stats = defaultdict(lambda: {'total': 0, 'flagged': 0})
        for result in results:
            bucket = dr_stats[result['dr_severity']]
            bucket['total'] += 1
            bucket['flagged'] += int(result['recommended_action'] == 'REMOVE')

        for dr_class in sorted(dr_stats):
            stats = dr_stats[dr_class]
            class_removal_rate = 0 if stats['total'] == 0 else stats['flagged'] / stats['total'] * 100
            dr_name = dr_names[dr_class] if 0 <= dr_class < 5 else f'Class {dr_class}'
            logger.info(f"{dr_name}: {stats['flagged']:,}/{stats['total']:,} ({class_removal_rate:.1f}%)")

    logger.info(f"Quality identification completed for {len(all_results):,} images")

    if all_results:
        total_flagged = sum(1 for r in all_results if r['recommended_action'] == 'REMOVE')
        overall_removal_rate = total_flagged / len(all_results) * 100

        avg_basic_all, avg_medical_all, avg_technical_all, avg_overall_all = (
            np.mean([r[key] for r in all_results])
            for key in ('basic_quality_score', 'medical_quality_score',
                        'technical_quality_score', 'overall_quality_score')
        )

        logger.info(f"Overall removal rate: {overall_removal_rate:.1f}%")
        logger.info(f"Average quality scores - Basic: {avg_basic_all:.3f}, Medical: {avg_medical_all:.3f}, Technical: {avg_technical_all:.3f}, Overall: {avg_overall_all:.3f}")

In [None]:
# 8. Create Flagged Samples

def create_flagged_samples(identifier, results, n_samples_per_dataset=20, max_per_reason=5):
    """Copy a small, reason-stratified sample of flagged images for manual review.

    Images whose ``recommended_action`` is ``'REMOVE'`` are grouped per dataset
    and per removal reason. For each reason, up to ``max_per_reason`` images with
    the lowest ``overall_quality_score`` are copied into
    ``<identifier.output_dir>/flagged_samples/<dataset>/`` as
    ``<reason>_<index>_<original filename>``. At most ``n_samples_per_dataset``
    copies are made per dataset; an image flagged for several reasons may be
    copied once per reason. Source images are never moved or deleted.

    Args:
        identifier: Object exposing an ``output_dir`` attribute.
        results: Iterable of per-image result dicts with at least the keys
            ``dataset``, ``recommended_action``, ``removal_reasons``,
            ``overall_quality_score``, ``image_path`` and ``filename``.
        n_samples_per_dataset: Upper bound on successful copies per dataset.
        max_per_reason: Upper bound on copies per removal reason.

    Returns:
        int: Total number of sample files copied across all datasets.
    """
    by_dataset = defaultdict(list)
    for result in results:
        by_dataset[result['dataset']].append(result)

    total_samples_created = 0

    for dataset_name, dataset_results in by_dataset.items():
        logger.info(f"Processing flagged samples for {dataset_name}")

        flagged = [r for r in dataset_results if r['recommended_action'] == 'REMOVE']
        if not flagged:
            logger.info(f"No flagged images found for {dataset_name}")
            continue

        logger.info(f"Found {len(flagged)} flagged images")

        sample_dir = os.path.join(identifier.output_dir, 'flagged_samples', dataset_name)
        os.makedirs(sample_dir, exist_ok=True)

        # Group by removal reason; one image may appear under several reasons.
        by_reason = defaultdict(list)
        for result in flagged:
            for reason in result['removal_reasons']:
                by_reason[reason].append(result)

        logger.info(f"Issue types found: {list(by_reason.keys())}")

        samples_copied = 0
        for reason, reason_results in by_reason.items():
            budget = min(max_per_reason, len(reason_results), n_samples_per_dataset - samples_copied)
            if budget <= 0:
                continue

            # Lowest overall quality first, i.e. the most clear-cut removals.
            # sorted() leaves the grouped list untouched.
            worst_first = sorted(reason_results, key=lambda r: r['overall_quality_score'])

            for i, result in enumerate(worst_first[:budget]):
                # Resolve the path before the try so the error message never
                # refers to a previous iteration's (or an undefined) path.
                src_path = result.get('image_path')
                try:
                    dst_filename = f'{reason}_{i:02d}_{result["filename"]}'
                    shutil.copy2(src_path, os.path.join(sample_dir, dst_filename))
                    samples_copied += 1
                except Exception as e:
                    logger.error(f"Error copying flagged sample {src_path}: {e}")

        logger.info(f"Created {samples_copied} flagged samples for {dataset_name}")
        total_samples_created += samples_copied

    logger.info(f"Total flagged samples created: {total_samples_created}")
    return total_samples_created

# Copy review samples only when the identification step produced results.
if not ('all_results' in locals() and all_results):
    logger.warning("Flagged sample creation skipped - no results available")
else:
    create_flagged_samples(identifier, all_results)

In [None]:
# 9. Flagged Images Report Generation

def generate_identification_report(all_results, high_removal_rate=0.4, low_removal_rate=0.05,
                                   dataset_high_removal_rate=0.5, class_rate_spread=0.3):
    """Summarise per-image quality results into a JSON-serialisable report.

    Args:
        all_results: Non-empty list of per-image result dicts with at least
            ``dataset``, ``dr_severity``, ``recommended_action``,
            ``removal_reasons`` and ``overall_quality_score``.
        high_removal_rate: Overall removal fraction above which a
            "reduce thresholds" recommendation is added.
        low_removal_rate: Overall removal fraction below which a
            "tighten thresholds" recommendation is added.
        dataset_high_removal_rate: Per-dataset removal fraction above which a
            dataset-specific review recommendation is added.
        class_rate_spread: Largest tolerated gap between the highest and lowest
            per-DR-class removal fraction within one dataset.

    Returns:
        dict with keys ``analysis_summary``, ``dataset_breakdown``,
        ``removal_reasons_summary`` (reason -> count, ordered by descending
        count), ``quality_score_statistics`` and ``recommendations``.
    """
    df = pd.DataFrame(all_results)
    is_removed = df['recommended_action'] == 'REMOVE'

    # Overall statistics
    total_images = len(df)
    flagged_for_removal = int(is_removed.sum())
    removal_rate = flagged_for_removal / total_images if total_images > 0 else 0

    report = {
        'analysis_summary': {
            'total_images_analyzed': total_images,
            'images_flagged_for_removal': flagged_for_removal,
            'overall_removal_rate': removal_rate,
            'analysis_date': datetime.now().isoformat()
        },
        'dataset_breakdown': {},
        'removal_reasons_summary': {},
        'quality_score_statistics': {},
        'recommendations': []
    }

    # Per-dataset and per-DR-class breakdown (datasets in order of first appearance)
    for dataset in df['dataset'].unique():
        in_dataset = df['dataset'] == dataset
        dataset_total = int(in_dataset.sum())
        dataset_flagged = int((in_dataset & is_removed).sum())

        class_breakdown = {}
        for dr_class in range(5):
            in_class = in_dataset & (df['dr_severity'] == dr_class)
            class_total = int(in_class.sum())
            if class_total > 0:
                class_flagged = int((in_class & is_removed).sum())
                class_breakdown[dr_class] = {
                    'total': class_total,
                    'flagged': class_flagged,
                    'removal_rate': class_flagged / class_total
                }

        dataset_scores = df.loc[in_dataset, 'overall_quality_score']
        report['dataset_breakdown'][dataset] = {
            'total_images': dataset_total,
            'flagged_images': dataset_flagged,
            'removal_rate': dataset_flagged / dataset_total if dataset_total > 0 else 0,
            'class_breakdown': class_breakdown,
            'avg_quality_score': float(dataset_scores.mean()),
            'quality_score_std': float(dataset_scores.std())
        }

    # Reasons summary: every reason listed on a REMOVE row is counted once per row.
    reason_counts = Counter()
    if 'removal_reasons' in df.columns:
        for reasons in df.loc[is_removed, 'removal_reasons']:
            reason_counts.update(reasons)
    # most_common() orders by descending count, so consumers that take the
    # first N entries really get the top-N reasons.
    report['removal_reasons_summary'] = dict(reason_counts.most_common())

    # Score statistics
    scores = df['overall_quality_score']
    report['quality_score_statistics'] = {
        'mean': float(scores.mean()),
        'std': float(scores.std()),
        'min': float(scores.min()),
        'max': float(scores.max()),
        'percentiles': {str(p): float(scores.quantile(p / 100)) for p in (10, 25, 50, 75, 90)}
    }

    # Recommendations
    if removal_rate > high_removal_rate:
        report['recommendations'].append("High removal rate detected - consider reducing quality thresholds")

    if removal_rate < low_removal_rate:
        report['recommendations'].append("Low removal rate detected - consider tightening quality thresholds")

    for dataset, stats in report['dataset_breakdown'].items():
        if stats['removal_rate'] > dataset_high_removal_rate:
            report['recommendations'].append(f"Very high removal rate for {dataset} - review dataset-specific thresholds")

        # Check for class imbalance in removal
        class_rates = [info['removal_rate'] for info in stats['class_breakdown'].values()]
        if class_rates and max(class_rates) - min(class_rates) > class_rate_spread:
            report['recommendations'].append(f"Uneven removal rates across DR classes in {dataset} - consider class-specific adjustments")

    return report

# Build the summary report from the identification results, if there are any.
if not ('all_results' in locals() and all_results):
    logger.warning("Report generation skipped - no results available")
else:
    report = generate_identification_report(all_results)
    logger.info("Analysis report generated successfully")

In [None]:
# 10. Save Results

def save_identification_results(identifier, results, report):
    """Write all identification outputs to ``identifier.output_dir``.

    Files written:
        * ``quality_identification_results.csv`` - one row per analysed image.
        * ``dataset_profiles.json`` - ``identifier.dataset_profiles``.
        * ``identification_report.json`` - ``report``.
        * ``flagged_images_summary.csv`` - one row per image whose
          ``recommended_action`` is ``'REMOVE'`` (header only when none are).

    Args:
        identifier: Object exposing ``output_dir`` and ``dataset_profiles``.
        results: Iterable of per-image result dicts.
        report: Report dict (e.g. from ``generate_identification_report``).

    Returns:
        Tuple ``(results_file, summary_file, profiles_file, report_file)``.
    """
    results = list(results)  # iterated twice below
    output_dir = identifier.output_dir
    # Do not rely on an earlier cell having created the directory.
    os.makedirs(output_dir, exist_ok=True)

    # Complete per-image results
    results_file = f'{output_dir}/quality_identification_results.csv'
    pd.DataFrame(results).to_csv(results_file, index=False)

    # Profile report. NumPy scalars (e.g. float32 thresholds) are not JSON
    # serialisable by default, so route them through the converter.
    profiles_file = f'{output_dir}/dataset_profiles.json'
    with open(profiles_file, 'w') as f:
        json.dump(identifier.dataset_profiles, f, indent=2, default=_to_json_compatible)

    # Analysis report
    report_file = f'{output_dir}/identification_report.json'
    with open(report_file, 'w') as f:
        json.dump(report, f, indent=2, default=_to_json_compatible)

    # Summary of flagged images; explicit columns keep the CSV header even
    # when nothing was flagged.
    summary_columns = ['dataset', 'filename', 'dr_severity', 'quality_score',
                       'confidence', 'reasons', 'image_path']
    summary_rows = [
        {
            'dataset': r.get('dataset'),
            'filename': r.get('filename'),
            'dr_severity': r.get('dr_severity'),
            'quality_score': r.get('overall_quality_score'),
            'confidence': r.get('confidence'),
            'reasons': ', '.join(map(str, r.get('removal_reasons') or [])),
            'image_path': r.get('image_path'),
        }
        for r in results
        if r.get('recommended_action') == 'REMOVE'
    ]
    summary_df = pd.DataFrame(summary_rows, columns=summary_columns)
    summary_file = f'{output_dir}/flagged_images_summary.csv'
    summary_df.to_csv(summary_file, index=False)

    logger.info("Results saved successfully")
    logger.info(f"Flagged summary: {summary_file}")
    logger.info(f"Complete results: {results_file}")
    logger.info(f"Dataset profiles: {profiles_file}")
    logger.info(f"Analysis report: {report_file}")

    return results_file, summary_file, profiles_file, report_file


def _to_json_compatible(obj):
    """``json.dump`` fallback: NumPy scalars/arrays become native Python values; anything else becomes ``str(obj)``."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return str(obj)

# Persist CSV/JSON outputs once both the per-image results and the report exist.
if not ('all_results' in locals() and 'report' in locals() and all_results):
    logger.warning("Saving skipped - no data available")
else:
    (results_file, summary_file,
     profiles_file, report_file) = save_identification_results(identifier, all_results, report)
    logger.info("All results saved successfully")

In [None]:
# 11. Summary

if 'all_results' in locals() and 'report' in locals() and all_results:
    logger.info("Quality identification analysis completed")

    # --- Overall counts ---
    total_images = len(all_results)
    flagged_images = sum(1 for r in all_results if r['recommended_action'] == 'REMOVE')
    removal_percentage = flagged_images / total_images * 100 if total_images > 0 else 0

    logger.info("Overall Statistics:")
    logger.info(f"   Total images analyzed: {total_images:,}")
    logger.info(f"   Images flagged: {flagged_images:,}")
    logger.info(f"   Removal rate: {removal_percentage:.1f}%")

    # --- Per-dataset counts (datasets in order of first appearance) ---
    logger.info("Dataset breakdown:")
    dataset_stats = defaultdict(lambda: {'total': 0, 'flagged': 0})
    for r in all_results:
        entry = dataset_stats[r['dataset']]
        entry['total'] += 1
        entry['flagged'] += int(r['recommended_action'] == 'REMOVE')

    for dataset, stats in dataset_stats.items():
        removal_rate = stats['flagged'] / stats['total'] * 100 if stats['total'] > 0 else 0
        logger.info(f"   {dataset}: {stats['flagged']:,}/{stats['total']:,} ({removal_rate:.1f}%)")

    # --- Per-DR-class counts (ascending class) ---
    logger.info("DR class breakdown:")
    dr_stats = defaultdict(lambda: {'total': 0, 'flagged': 0})
    for r in all_results:
        entry = dr_stats[r['dr_severity']]
        entry['total'] += 1
        entry['flagged'] += int(r['recommended_action'] == 'REMOVE')

    dr_names = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']
    for dr_class in sorted(dr_stats):
        stats = dr_stats[dr_class]
        removal_rate = stats['flagged'] / stats['total'] * 100 if stats['total'] > 0 else 0
        dr_name = dr_names[dr_class] if 0 <= dr_class < 5 else f'Class {dr_class}'
        logger.info(f"   {dr_name}: {stats['flagged']:,}/{stats['total']:,} ({removal_rate:.1f}%)")

    # --- Most frequent removal reasons (top 10 by count) ---
    logger.info("Top removal reasons:")
    if 'removal_reasons_summary' in report:
        ranked_reasons = sorted(report['removal_reasons_summary'].items(), key=lambda kv: kv[1], reverse=True)
        for reason, count in ranked_reasons[:10]:
            share = count / flagged_images * 100 if flagged_images > 0 else 0
            logger.info(f"   {reason}: {count:,} ({share:.1f}% of flagged images)")

    # --- Recommendations ---
    if report.get('recommendations'):
        logger.info("Recommendations:")
        for position, recommendation in enumerate(report['recommendations'], start=1):
            logger.info(f"   {position}. {recommendation}")

    # --- File locations ---
    logger.info("Results saved to:")
    for label, relative_path in (
        ("Flagged images summary", "flagged_images_summary.csv"),
        ("Detailed results", "quality_identification_results.csv"),
        ("Analysis report", "identification_report.json"),
        ("Sample images", "sample_images/"),
        ("Flagged samples", "flagged_samples/"),
    ):
        logger.info(f"   {label}: {OUTPUT_DIR}/{relative_path}")

    logger.info("Quality identification process completed successfully")
else:
    logger.error("No results available for summary - ensure previous analysis steps completed")

In [None]:
# 12. Visualization

if 'all_results' in locals() and all_results:
    logger.info("Creating analysis visualizations")

    plt.style.use('default')
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('DR Dataset Quality Analysis Results', fontsize=16, fontweight='bold')

    df = pd.DataFrame(all_results)
    removed_mask = df['recommended_action'] == 'REMOVE'
    dr_names = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']

    def _label_bars(ax, bars, rates):
        """Write each bar's percentage value just above the bar."""
        for bar, rate in zip(bars, rates):
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.5,
                    f'{rate:.1f}%', ha='center', va='bottom')

    # 1. Removal rates by dataset (datasets in order of first appearance)
    ax1 = axes[0, 0]
    dataset_rates = removed_mask.groupby(df['dataset'], sort=False).mean() * 100
    dataset_names = list(dataset_rates.index)
    dataset_removal_rates = list(dataset_rates.values)

    bars1 = ax1.bar(range(len(dataset_names)), dataset_removal_rates, color='lightcoral', alpha=0.7)
    ax1.set_xlabel('Dataset')
    ax1.set_ylabel('Removal Rate (%)')
    ax1.set_title('Removal Rates by Dataset')
    ax1.set_xticks(range(len(dataset_names)))
    ax1.set_xticklabels(dataset_names, rotation=45, ha='right')
    _label_bars(ax1, bars1, dataset_removal_rates)

    # 2. Removal rates by DR class (ascending severity)
    ax2 = axes[0, 1]
    dr_rates = removed_mask.groupby(df['dr_severity']).mean() * 100
    dr_removal_rates = list(dr_rates.values)
    dr_labels = [dr_names[int(c)] if 0 <= c < 5 else f'Class {c}' for c in dr_rates.index]

    bars2 = ax2.bar(range(len(dr_labels)), dr_removal_rates, color='lightblue', alpha=0.7)
    ax2.set_xlabel('DR Severity')
    ax2.set_ylabel('Removal Rate (%)')
    ax2.set_title('Removal Rates by DR Severity')
    ax2.set_xticks(range(len(dr_labels)))
    ax2.set_xticklabels(dr_labels, rotation=45, ha='right')
    _label_bars(ax2, bars2, dr_removal_rates)

    # 3. Quality score distribution. Empty groups are skipped because a density
    # histogram of zero samples is all-NaN.
    ax3 = axes[1, 0]
    keep_scores = df.loc[df['recommended_action'] == 'KEEP', 'overall_quality_score']
    remove_scores = df.loc[removed_mask, 'overall_quality_score']
    for scores, label, color in ((keep_scores, 'Keep', 'lightgreen'),
                                 (remove_scores, 'Remove', 'lightcoral')):
        if len(scores) > 0:
            ax3.hist(scores, bins=30, alpha=0.7, label=label, color=color, density=True)
    ax3.set_xlabel('Quality Score')
    ax3.set_ylabel('Density')
    ax3.set_title('Quality Score Distribution')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # 4. Top removal reasons. Sorted explicitly by count: the summary dict's
    # order is not guaranteed to be by frequency.
    ax4 = axes[1, 1]
    reason_summary = report.get('removal_reasons_summary', {}) if 'report' in locals() else {}
    top_reasons = sorted(reason_summary.items(), key=lambda kv: kv[1], reverse=True)[:10]
    ax4.set_title('Top Removal Reasons')
    if top_reasons:
        reasons = [reason for reason, _count in top_reasons]
        counts = [count for _reason, count in top_reasons]

        bars4 = ax4.barh(range(len(reasons)), counts, color='orange', alpha=0.7)
        ax4.set_ylabel('Removal Reason')
        ax4.set_xlabel('Count')
        ax4.set_yticks(range(len(reasons)))
        ax4.set_yticklabels([str(reason).replace('_', ' ').title() for reason in reasons])
        ax4.invert_yaxis()  # most frequent reason at the top

        text_offset = max(counts) * 0.01
        for bar, count in zip(bars4, counts):
            ax4.text(bar.get_width() + text_offset, bar.get_y() + bar.get_height() / 2.,
                     f'{count}', ha='left', va='center')
    else:
        ax4.text(0.5, 0.5, 'No removal reasons recorded', ha='center', va='center',
                 transform=ax4.transAxes)

    plt.tight_layout()

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    plot_file = f'{OUTPUT_DIR}/analysis_summary_plots.png'
    fig.savefig(plot_file, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # release the figure's memory once it has been shown

    logger.info(f"Visualization saved to: {plot_file}")

else:
    logger.warning("Visualization creation skipped - no results available")

In [None]:
# 13. Final Checkpoint - Verification

# Final verification that all components completed successfully
print("FINAL VERIFICATION:")
print("=" * 50)

# Each entry pairs a human-readable label with evidence that the stage ran:
# either a non-empty notebook variable or an output path on disk.
checks = {
    'Valid datasets found': 'valid_datasets' in locals() and bool(valid_datasets),
    'Dataset profiles created': 'dataset_profiles' in locals() and bool(dataset_profiles),
    'Quality analysis completed': 'all_results' in locals() and bool(all_results),
    'Report generated': 'report' in locals() and bool(report),
    'Results saved': os.path.exists(f'{OUTPUT_DIR}/flagged_images_summary.csv'),
    'Sample images created': os.path.exists(f'{OUTPUT_DIR}/sample_images'),
    'Flagged samples created': os.path.exists(f'{OUTPUT_DIR}/flagged_samples')
}

for check_name, check_result in checks.items():
    print(f"{'PASS' if check_result else 'FAIL'}: {check_name}")

all_good = all(checks.values())

if not all_good:
    print("\nWARNING: Some components did not complete successfully")
    print("Please review the error messages above and re-run the failed sections")
else:
    print("\nSUCCESS: All components completed successfully")
    print(f"Check the '{OUTPUT_DIR}' directory for all results")

    if 'all_results' in locals():
        total = len(all_results)
        flagged = sum(1 for r in all_results if r['recommended_action'] == 'REMOVE')
        print(f"Final count: {flagged:,} of {total:,} images flagged for removal ({flagged/total*100:.1f}%)")

print("\nIMPORTANT: This analysis only IDENTIFIED potential issues")
print("NO IMAGES WERE ACTUALLY REMOVED from your datasets")
print("Review the flagged samples before deciding on actual removal")

In [None]:
# 14. Check what sample images were saved
import os

# Fall back to OUTPUT_DIR when the identifier from earlier cells is not in the
# namespace (e.g. after a kernel restart).
sample_root = identifier.output_dir if 'identifier' in locals() else OUTPUT_DIR
sample_base_dir = f'{sample_root}/sample_images'
SAMPLE_IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png', '.tif', '.tiff', '.bmp')

print("Sample images saved:")
if not os.path.isdir(sample_base_dir):
    print(f"   No sample images directory found at {sample_base_dir}")
else:
    for dataset_name in sorted(os.listdir(sample_base_dir)):
        dataset_sample_dir = os.path.join(sample_base_dir, dataset_name)
        if not os.path.isdir(dataset_sample_dir):
            continue
        # Case-insensitive match so '.JPG' / '.TIF' samples are counted too.
        num_samples = sum(
            1 for f in os.listdir(dataset_sample_dir)
            if f.lower().endswith(SAMPLE_IMAGE_SUFFIXES)
        )
        print(f"   {dataset_name}: {num_samples} sample images in {dataset_sample_dir}")

In [None]:
# 15. Manual Threshold Adjustment (Optional)

logger.info("Manual threshold adjustment interface")

if not ('dataset_profiles' in locals() and dataset_profiles):
    logger.error("Dataset profiles not found - run characterization first")
else:
    dr_names = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']

    def _dr_label(dr_class):
        """Human-readable DR grade name, falling back to 'Class N' outside 0-4."""
        return dr_names[dr_class] if 0 <= dr_class < 5 else f'Class {dr_class}'

    def _log_thresholds(profiles):
        """Log every dataset's per-class adaptive threshold."""
        for ds_name, ds_profile in profiles.items():
            logger.info(f"{ds_name}:")
            for dr_class, threshold in ds_profile['adaptive_thresholds'].items():
                logger.info(f"  {_dr_label(dr_class)} (Class {dr_class}): {threshold:.3f}")

    logger.info("Current thresholds:")
    _log_thresholds(dataset_profiles)

    logger.info("Manual threshold override section")
    logger.info("Edit values below to adjust thresholds")
    logger.info("Set to None to keep current value")

    # Override template applied to every dataset: replace None with a value
    # such as 0.250 to pin that class's threshold.
    override_template = {
        0: None,  # No DR
        1: None,  # Mild DR
        2: None,  # Moderate DR
        3: None,  # Severe DR
        4: None,  # Proliferative DR
    }
    manual_thresholds = {name: dict(override_template) for name in dataset_profiles.keys()}

    updated_profiles = {}
    changes_made = False

    for dataset_name, profile in dataset_profiles.items():
        # Copy the profile and its threshold mapping so the originals stay intact.
        new_profile = profile.copy()
        new_profile['adaptive_thresholds'] = profile['adaptive_thresholds'].copy()
        updated_profiles[dataset_name] = new_profile

        for dr_class, manual_threshold in manual_thresholds.get(dataset_name, {}).items():
            if manual_threshold is None:
                continue
            old_threshold = profile['adaptive_thresholds'][dr_class]
            new_profile['adaptive_thresholds'][dr_class] = manual_threshold
            changes_made = True
            logger.info(f"Updated {dataset_name} - {_dr_label(dr_class)}: {old_threshold:.3f} → {manual_threshold:.3f}")

    if changes_made:
        logger.info("Threshold adjustments applied")
        dataset_profiles = updated_profiles
    else:
        logger.info("No manual adjustments made - using original thresholds")

    logger.info("Final thresholds after adjustment:")
    _log_thresholds(dataset_profiles)

In [None]:
# 16. Dataset Cleaning with Standardized Quality Scoring

def create_cleaned_dataset(source_datasets, profiles, copy_output_dir='DREAM_dataset_cleaned',
                           image_extensions=('.jpg', '.jpeg', '.png', '.tiff', '.bmp'),
                           default_threshold=0.3):
    """
    Create cleaned dataset using standardized 3-component quality scoring.
    Components: Basic Quality (25%), Medical Quality (55%), Technical Quality (20%)

    For every ``<dataset_path>/<dr_class>/`` folder (classes 0-4), each image is
    re-analysed and scored with ``calculate_standardized_quality_score``. It is
    copied to ``<copy_output_dir>/<dataset_name>/<dr_class>/`` when its score
    is at least that class's adaptive threshold. Source files are never modified.

    NOTE: relies on a global ``identifier`` (its ``analyze_single_image``
    method) created in an earlier cell.

    Args:
        source_datasets: Mapping dataset name -> dataset root directory.
        profiles: Mapping dataset name -> profile dict with an
            ``'adaptive_thresholds'`` mapping keyed by DR class. Int keys are
            expected; their string form (as left by a JSON round-trip) is also
            accepted.
        copy_output_dir: Destination root for the cleaned copy.
        image_extensions: File suffixes (case-insensitive) treated as images.
        default_threshold: Threshold used when a class has no adaptive threshold.

    Returns:
        Tuple ``(copy_stats, processing_time_seconds)``. ``copy_stats`` is also
        written to ``<copy_output_dir>/cleaning_statistics.json``. Images that
        could not be analysed or copied are counted in ``total_failed``; they
        are neither kept nor included in ``total_processed``.
    """

    logger.info(f"Creating cleaned dataset: {copy_output_dir}")
    logger.info("Quality scoring: CQ = 0.25*Basic + 0.55*Medical + 0.20*Technical")

    # Copies are additive: files kept by an earlier run (e.g. with looser
    # thresholds) are NOT deleted, so warn instead of silently mixing runs.
    if os.path.isdir(copy_output_dir) and os.listdir(copy_output_dir):
        logger.warning(f"{copy_output_dir} is not empty - files from a previous run will remain alongside new copies")
    os.makedirs(copy_output_dir, exist_ok=True)

    extensions = tuple(ext.lower() for ext in image_extensions)
    dr_names = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']

    copy_stats = {
        'total_processed': 0, 'total_kept': 0, 'total_removed': 0, 'total_failed': 0,
        'by_dataset': {}, 'by_dr_class': {i: {'kept': 0, 'removed': 0} for i in range(5)}
    }

    start_time = time.time()

    for dataset_name, dataset_path in source_datasets.items():
        logger.info(f"Processing {dataset_name}")

        if dataset_name not in profiles:
            logger.error(f"Profile not found for {dataset_name}")
            continue

        thresholds = profiles[dataset_name]['adaptive_thresholds']

        copy_dataset_dir = os.path.join(copy_output_dir, dataset_name)
        os.makedirs(copy_dataset_dir, exist_ok=True)

        copy_dataset_stats = {
            'total_processed': 0, 'total_kept': 0, 'total_removed': 0, 'total_failed': 0,
            'by_dr_class': {i: {'kept': 0, 'removed': 0} for i in range(5)}
        }

        for dr_class in range(5):
            dr_class_path = os.path.join(dataset_path, str(dr_class))
            if not os.path.exists(dr_class_path):
                continue

            copy_dr_dir = os.path.join(copy_dataset_dir, str(dr_class))
            os.makedirs(copy_dr_dir, exist_ok=True)

            threshold = _lookup_threshold(thresholds, dr_class, default_threshold)
            kept_count, removed_count, failed_count = _clean_class_folder(
                dr_class_path, copy_dr_dir, dr_class, threshold, extensions
            )

            # Per-dataset and global per-class tallies move in lockstep.
            for class_bucket in (copy_dataset_stats['by_dr_class'][dr_class],
                                 copy_stats['by_dr_class'][dr_class]):
                class_bucket['kept'] += kept_count
                class_bucket['removed'] += removed_count

            copy_dataset_stats['total_processed'] += kept_count + removed_count
            copy_dataset_stats['total_kept'] += kept_count
            copy_dataset_stats['total_removed'] += removed_count
            copy_dataset_stats['total_failed'] += failed_count

            if kept_count + removed_count > 0:
                removal_rate = (removed_count / (kept_count + removed_count)) * 100
                logger.info(f"{dr_names[dr_class]}: Kept {kept_count}, Removed {removed_count} ({removal_rate:.1f}% removed)")
            if failed_count:
                logger.warning(f"{dr_names[dr_class]}: {failed_count} images could not be analysed or copied")

            gc.collect()

        for key in ('total_processed', 'total_kept', 'total_removed', 'total_failed'):
            copy_stats[key] += copy_dataset_stats[key]
        copy_stats['by_dataset'][dataset_name] = copy_dataset_stats

        if copy_dataset_stats['total_processed'] > 0:
            dataset_removal_rate = (copy_dataset_stats['total_removed'] / copy_dataset_stats['total_processed']) * 100
            logger.info(f"{dataset_name} Summary: {copy_dataset_stats['total_processed']:,} processed, "
                       f"{copy_dataset_stats['total_kept']:,} kept, "
                       f"{copy_dataset_stats['total_removed']:,} removed ({dataset_removal_rate:.1f}%)")

    processing_time = time.time() - start_time

    copy_stats_file = os.path.join(copy_output_dir, 'cleaning_statistics.json')
    with open(copy_stats_file, 'w') as f:
        json.dump(copy_stats, f, indent=2)

    return copy_stats, processing_time


def _lookup_threshold(thresholds, dr_class, default):
    """Return the threshold for ``dr_class``; accepts int keys or their str form (JSON stringifies keys)."""
    if dr_class in thresholds:
        return thresholds[dr_class]
    return thresholds.get(str(dr_class), default)


def _clean_class_folder(src_dir, dst_dir, dr_class, threshold, extensions):
    """Score every image in one DR-class folder and copy those meeting ``threshold``.

    Returns:
        Tuple ``(kept, removed, failed)`` image counts.
    """
    # Sorted for a reproducible processing order across runs and platforms.
    image_files = sorted(f for f in os.listdir(src_dir) if f.lower().endswith(extensions))
    if not image_files:
        return 0, 0, 0

    logger.info(f"DR Class {dr_class}: Processing {len(image_files)} images")

    kept = removed = failed = 0
    for i, image_file in enumerate(image_files):
        if i > 0 and i % 500 == 0:
            progress = (i / len(image_files)) * 100
            logger.info(f"Progress: {progress:.1f}% ({i}/{len(image_files)})")
        if i > 0 and i % 1000 == 0:
            gc.collect()

        source_image_path = os.path.join(src_dir, image_file)
        try:
            char = identifier.analyze_single_image(source_image_path)
            if not char:
                # Unreadable/unanalysable image: count it instead of dropping it silently.
                failed += 1
                continue

            if calculate_standardized_quality_score(char) >= threshold:
                shutil.copy2(source_image_path, os.path.join(dst_dir, image_file))
                kept += 1
            else:
                removed += 1
        except Exception as e:
            logger.error(f"Error processing {image_file}: {e}")
            failed += 1

    return kept, removed, failed

def calculate_standardized_quality_score(characteristics):
    """
    Composite quality score: CQ = 0.25*Basic + 0.55*Medical + 0.20*Technical.

    Each component is the unweighted mean of its normalised metrics, each
    clamped to [0, 1], so the result lies in [0, 1]. Metrics missing from
    ``characteristics`` fall back to fixed defaults. The original docstring
    describes this as consistent with the Cell 7 methodology.

    Args:
        characteristics: Mapping of metric name -> numeric value (anything
            ``float()`` accepts).

    Returns:
        NumPy float64 composite score.
    """

    def metric(name, default):
        return float(characteristics.get(name, default))

    def clamp_low_first(value):
        # Lower bound applied first: NaN ends up as 0.0.
        return min(1.0, max(0.0, value))

    def clamp_high_first(value):
        # Upper bound applied first: NaN ends up as 1. The two orderings are
        # kept distinct on purpose so NaN inputs score exactly as before.
        return max(0, min(1, value))

    # Basic: brightness, contrast, sharpness, entropy scaled by their nominal maxima.
    basic_parts = [
        clamp_low_first(metric('brightness', 127.5) / 255.0),
        clamp_low_first(metric('contrast', 50.0) / 100.0),
        clamp_low_first(metric('sharpness', 500.0) / 1000.0),
        clamp_low_first(metric('entropy', 4.0) / 8.0),
    ]

    # Medical: uniformity as-is, vessel / optic-disc visibility scaled by 10.
    medical_parts = [
        clamp_low_first(metric('illumination_uniformity', 0.5)),
        clamp_low_first(metric('vessel_visibility', 0.1) * 10),
        clamp_low_first(metric('optic_disc_visibility', 0.1) * 10),
    ]

    # Technical: fewer extreme pixels and lower colour imbalance score higher.
    technical_parts = [
        clamp_high_first(1 - metric('extreme_brightness_pixels', 0.1) * 2),
        clamp_low_first(metric('motion_blur_score', 20.0) / 50.0),
        clamp_high_first(1 - metric('color_balance', 15.0) / 50.0),
    ]

    basic = np.mean(basic_parts)
    medical = np.mean(medical_parts)
    technical = np.mean(technical_parts)

    return 0.25 * basic + 0.55 * medical + 0.20 * technical

def generate_cleaning_report(stats, processing_time, output_dir):
    """Log a human-readable summary of the statistics produced by the cleaning step.

    Args:
        stats: Dict with ``total_processed``, ``total_kept``, ``total_removed``,
            ``by_dataset`` (name -> dict with the same three totals) and
            ``by_dr_class`` (class 0-4 -> dict with ``kept`` / ``removed``).
        processing_time: Elapsed time in seconds (reported in minutes).
        output_dir: Cleaned-dataset root; only echoed in the log.

    Returns:
        None. Logs a warning and returns early when nothing was processed.
    """
    processed = stats['total_processed']
    kept = stats['total_kept']
    removed = stats['total_removed']

    if processed == 0:
        logger.warning("No images processed")
        return

    def percent(part, whole):
        return (part / whole) * 100

    logger.info("Cleaning Summary:")
    logger.info(f"Processing time: {processing_time/60:.1f} minutes")
    logger.info(f"Total processed: {processed:,}")
    logger.info(f"Images kept: {kept:,}")
    logger.info(f"Images removed: {removed:,} ({percent(removed, processed):.1f}%)")

    logger.info("Dataset breakdown:")
    for name, ds in stats['by_dataset'].items():
        ds_processed = ds['total_processed']
        if not ds_processed > 0:
            continue
        logger.info(f"  {name}: {ds['total_kept']:,} kept, "
                    f"{ds['total_removed']:,} removed ({percent(ds['total_removed'], ds_processed):.1f}%)")

    logger.info("DR class breakdown:")
    class_labels = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']
    for dr_class, label in enumerate(class_labels):
        counts = stats['by_dr_class'][dr_class]
        class_total = counts['kept'] + counts['removed']
        if class_total > 0:
            logger.info(f"  {label}: {counts['kept']:,} kept, "
                        f"{counts['removed']:,} removed ({percent(counts['removed'], class_total):.1f}%)")

    logger.info(f"Output location: {output_dir}")
    logger.info(f"Statistics saved: {output_dir}/cleaning_statistics.json")

# Execute dataset cleaning. The output directory is defined once so the
# cleaning step and its report cannot drift apart.
CLEANED_OUTPUT_DIR = 'DREAM_dataset_cleaned'

# `identifier` is used inside create_cleaned_dataset as a global, so check it too.
missing_cleaning_inputs = [name for name in ('valid_datasets', 'dataset_profiles', 'identifier')
                           if name not in globals()]

if missing_cleaning_inputs:
    logger.error(f"Cannot proceed - missing required variables ({', '.join(missing_cleaning_inputs)})")
elif not valid_datasets or not dataset_profiles:
    logger.error("Cannot proceed - valid_datasets or dataset_profiles is empty")
else:
    logger.info("Starting dataset cleaning with standardized quality scoring")

    final_stats, total_processing_time = create_cleaned_dataset(
        valid_datasets, dataset_profiles, copy_output_dir=CLEANED_OUTPUT_DIR
    )

    generate_cleaning_report(final_stats, total_processing_time, CLEANED_OUTPUT_DIR)

    logger.info("Dataset cleaning completed successfully")

In [None]:
# 17. Quality Analysis CSV Generation

def create_quality_csv(all_results, output_file='quality_analysis.csv'):
    """Generate comprehensive quality analysis CSV with all metrics and assessments.

    Args:
        all_results: Iterable of per-image result dicts produced by the quality
            analysis. Missing keys, or keys explicitly set to None, fall back to
            neutral defaults (0, '' or 'UNKNOWN'), so partially populated records
            still produce a row.
        output_file: Path of the CSV file to write.

    Returns:
        pandas.DataFrame with one row per result, sorted by dataset name and DR
        severity (ascending), then overall quality score (descending). An empty
        ``all_results`` yields an empty frame that still carries every column, so
        the sort and downstream column access do not raise KeyError.
    """

    dr_names = {0: 'No DR', 1: 'Mild DR', 2: 'Moderate DR', 3: 'Severe DR', 4: 'Proliferative DR'}

    # Column order of the exported CSV; also gives an empty result set a schema.
    columns = [
        'dataset_name', 'filename', 'dr_severity', 'dr_class_name',
        'overall_quality_score', 'basic_quality_score', 'medical_quality_score',
        'technical_quality_score',
        'brightness', 'contrast', 'sharpness', 'entropy', 'illumination_uniformity',
        'vessel_visibility', 'optic_disc_visibility', 'color_balance',
        'recommended_action', 'confidence', 'removal_reasons', 'threshold_used',
        'image_width', 'image_height', 'file_size_mb',
        'extreme_brightness_pixels', 'motion_blur_score',
    ]

    def _rounded(result, key, ndigits):
        # dict.get(key, 0) still returns None when the key exists with a None value,
        # and round(None, n) raises TypeError; treat None like a missing value.
        value = result.get(key)
        return round(value if value is not None else 0, ndigits)

    analysis_data = []

    for result in all_results:
        # 'resolution' may be absent or None; (0, 0) marks an unknown size.
        resolution = result.get('resolution')
        width, height = resolution if resolution is not None else (0, 0)

        removal_reasons = result.get('removal_reasons')
        if removal_reasons is None:
            removal_reasons_str = ''
        elif isinstance(removal_reasons, list):
            # str() each entry so non-string reasons cannot break the join.
            removal_reasons_str = ';'.join(str(reason) for reason in removal_reasons)
        else:
            removal_reasons_str = str(removal_reasons)

        dr_severity = result.get('dr_severity', 0)

        analysis_record = {
            'dataset_name': result.get('dataset', ''),
            'filename': result.get('filename', ''),
            'dr_severity': dr_severity,
            'dr_class_name': dr_names.get(dr_severity, 'Unknown'),

            'overall_quality_score': _rounded(result, 'overall_quality_score', 4),
            'basic_quality_score': _rounded(result, 'basic_quality_score', 4),
            'medical_quality_score': _rounded(result, 'medical_quality_score', 4),
            'technical_quality_score': _rounded(result, 'technical_quality_score', 4),

            'brightness': _rounded(result, 'brightness', 2),
            'contrast': _rounded(result, 'contrast', 2),
            'sharpness': _rounded(result, 'sharpness', 1),
            'entropy': _rounded(result, 'entropy', 3),
            'illumination_uniformity': _rounded(result, 'illumination_uniformity', 4),

            'vessel_visibility': _rounded(result, 'vessel_visibility', 4),
            'optic_disc_visibility': _rounded(result, 'optic_disc_visibility', 4),
            'color_balance': _rounded(result, 'color_balance', 2),

            'recommended_action': result.get('recommended_action', 'UNKNOWN'),
            'confidence': result.get('confidence', 'UNKNOWN'),
            'removal_reasons': removal_reasons_str,
            'threshold_used': _rounded(result, 'threshold_used', 4),

            'image_width': width,
            'image_height': height,
            'file_size_mb': _rounded(result, 'file_size_mb', 2),
            'extreme_brightness_pixels': _rounded(result, 'extreme_brightness_pixels', 4),
            'motion_blur_score': _rounded(result, 'motion_blur_score', 2)
        }

        analysis_data.append(analysis_record)

    # Passing `columns` keeps the schema even when analysis_data is empty.
    df = pd.DataFrame(analysis_data, columns=columns)
    df = df.sort_values(['dataset_name', 'dr_severity', 'overall_quality_score'],
                        ascending=[True, True, False])

    df.to_csv(output_file, index=False)

    logger.info(f"Quality analysis CSV saved: {output_file} ({len(df):,} records)")
    return df

def create_summary_statistics(df, output_file='summary_statistics.csv'):
    """Generate comprehensive summary statistics from quality analysis data.

    Builds four tables:
      * overall counts;
      * per-dataset counts and quality;
      * per-DR-severity counts and quality;
      * descriptive statistics of the quality metrics.

    They are written as sheets of ``<stem>.xlsx``, where ``<stem>`` is
    ``output_file`` with its extension stripped. If the openpyxl engine cannot be
    imported, the tables are written instead as ``<stem>_overall.csv``,
    ``<stem>_datasets.csv``, ``<stem>_dr_severity.csv`` and
    ``<stem>_quality_metrics.csv``.

    Args:
        df: Frame shaped like the output of ``create_quality_csv``. It needs at least
            'dataset_name', 'dr_severity', 'dr_class_name', 'recommended_action' and
            'overall_quality_score'; metric columns are optional.
        output_file: Nominal output path; only its stem is used (see above).

    Returns:
        (dataset_stats, dr_stats, metric_stats), each a list of dicts.
    """

    def _percent(part, whole):
        # An empty frame (or empty group) would otherwise raise ZeroDivisionError.
        return round(part / whole * 100, 1) if whole else 0.0

    total_images = len(df)
    images_kept = len(df[df['recommended_action'] == 'KEEP'])
    images_removed = len(df[df['recommended_action'] == 'REMOVE'])
    removal_rate = _percent(images_removed, total_images)

    overall_stats = {
        'Metric': ['Total Images', 'Images Kept', 'Images Removed', 'Removal Rate (%)'],
        'Value': [total_images, images_kept, images_removed, removal_rate]
    }

    # NOTE(review): per-group 'Images_Kept' below counts every non-REMOVE row, whereas
    # the overall 'Images Kept' counts only 'KEEP' rows. They differ if other actions
    # (e.g. the 'UNKNOWN' default) occur — confirm which definition is intended.
    dataset_stats = []
    for dataset in df['dataset_name'].unique():
        dataset_data = df[df['dataset_name'] == dataset]
        removed = len(dataset_data[dataset_data['recommended_action'] == 'REMOVE'])
        total = len(dataset_data)

        dataset_stats.append({
            'Dataset': dataset,
            'Total_Images': total,
            'Images_Kept': total - removed,
            'Images_Removed': removed,
            'Removal_Rate_Percent': _percent(removed, total),
            'Mean_Quality_Score': round(dataset_data['overall_quality_score'].mean(), 3),
            'Std_Quality_Score': round(dataset_data['overall_quality_score'].std(), 3)
        })

    dr_stats = []
    for dr_class in sorted(df['dr_severity'].unique()):
        dr_data = df[df['dr_severity'] == dr_class]
        removed = len(dr_data[dr_data['recommended_action'] == 'REMOVE'])
        total = len(dr_data)

        dr_stats.append({
            'DR_Severity': dr_class,
            'DR_Class_Name': dr_data['dr_class_name'].iloc[0],
            'Total_Images': total,
            'Images_Kept': total - removed,
            'Images_Removed': removed,
            'Removal_Rate_Percent': _percent(removed, total),
            'Mean_Quality_Score': round(dr_data['overall_quality_score'].mean(), 3),
            'Std_Quality_Score': round(dr_data['overall_quality_score'].std(), 3)
        })

    quality_metrics = ['brightness', 'contrast', 'sharpness', 'entropy',
                       'illumination_uniformity', 'vessel_visibility', 'optic_disc_visibility']

    metric_stats = []
    for metric in quality_metrics:
        if metric in df.columns:
            # Coerce to numeric so empty/object-dtype columns yield NaN stats instead of raising.
            values = pd.to_numeric(df[metric], errors='coerce')
            metric_stats.append({
                'Metric': metric,
                'Mean': round(values.mean(), 4),
                'Std': round(values.std(), 4),
                'Min': round(values.min(), 4),
                'Max': round(values.max(), 4),
                'Q25': round(values.quantile(0.25), 4),
                'Q50': round(values.quantile(0.50), 4),
                'Q75': round(values.quantile(0.75), 4)
            })

    # splitext only strips the real extension; str.replace('.csv', ...) would also
    # rewrite '.csv' elsewhere in the path and leave non-.csv names unchanged.
    base_name = os.path.splitext(output_file)[0]
    output_excel = f'{base_name}.xlsx'

    try:
        with pd.ExcelWriter(output_excel, engine='openpyxl') as writer:
            pd.DataFrame(overall_stats).to_excel(writer, sheet_name='Overall_Stats', index=False)
            pd.DataFrame(dataset_stats).to_excel(writer, sheet_name='Dataset_Stats', index=False)
            pd.DataFrame(dr_stats).to_excel(writer, sheet_name='DR_Severity_Stats', index=False)
            pd.DataFrame(metric_stats).to_excel(writer, sheet_name='Quality_Metrics_Stats', index=False)

        logger.info(f"Summary statistics saved: {output_excel}")

    except ImportError:
        # openpyxl is not installed: fall back to one CSV per table.
        pd.DataFrame(overall_stats).to_csv(f'{base_name}_overall.csv', index=False)
        pd.DataFrame(dataset_stats).to_csv(f'{base_name}_datasets.csv', index=False)
        pd.DataFrame(dr_stats).to_csv(f'{base_name}_dr_severity.csv', index=False)
        pd.DataFrame(metric_stats).to_csv(f'{base_name}_quality_metrics.csv', index=False)
        logger.info(f"Summary statistics saved as separate CSV files")

    return dataset_stats, dr_stats, metric_stats

def generate_analysis_files(results=None, analysis_file='quality_analysis.csv',
                            summary_file='summary_statistics.csv'):
    """Main function to generate quality analysis files.

    Writes the per-image CSV via ``create_quality_csv`` and the summary tables via
    ``create_summary_statistics``. It then logs per-dataset and per-DR-class
    removal rates.

    Args:
        results: Per-image quality result dicts. When None (the default), the
            module-level ``all_results`` variable, expected from an earlier cell,
            is used.
        analysis_file: Path for the detailed quality analysis CSV.
        summary_file: Nominal path handed to ``create_summary_statistics``, which
            derives the ``.xlsx`` (or fallback CSV) file names from it.

    Returns:
        ``(analysis_df, dataset_stats, dr_stats, metric_stats)`` on success. Returns
        None, after logging an error, when no results are available.
    """

    if results is None:
        # Inside a function, locals() only sees this function's own variables, so
        # notebook-level results must be looked up in the module globals.
        if 'all_results' not in globals():
            logger.error("Quality analysis results not found")
            return None
        results = globals()['all_results']

    if not results:
        logger.error("No analysis results available")
        return None

    logger.info(f"Processing {len(results):,} quality analysis results")

    analysis_df = create_quality_csv(results, analysis_file)

    dataset_stats, dr_stats, metric_stats = create_summary_statistics(
        analysis_df, summary_file
    )

    logger.info("Dataset Summary:")
    for stat in dataset_stats:
        logger.info(f"  {stat['Dataset']}: {stat['Total_Images']:,} images, "
                    f"{stat['Removal_Rate_Percent']}% removed")

    logger.info("DR Severity Summary:")
    for stat in dr_stats:
        logger.info(f"  {stat['DR_Class_Name']}: {stat['Total_Images']:,} images, "
                    f"{stat['Removal_Rate_Percent']}% removed")

    # Report the file names actually used rather than hard-coded defaults; the summary
    # falls back to per-table CSVs when openpyxl is unavailable.
    summary_stem = os.path.splitext(summary_file)[0]
    logger.info("Files generated:")
    logger.info(f"- {analysis_file}: Complete analysis dataset")
    logger.info(f"- {summary_stem}.xlsx: Statistical summaries "
                f"(or {summary_stem}_*.csv if openpyxl is unavailable)")

    return analysis_df, dataset_stats, dr_stats, metric_stats

# Execute analysis file generation.
# Runs only when an earlier cell produced a non-empty `all_results`; otherwise it logs a warning.
if not locals().get('all_results'):
    logger.warning("Analysis file generation skipped - no results available")
else:
    generate_analysis_files()