In [None]:
# 1. Setup and Imports

import os
import json
import pandas as pd
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import shutil
from datetime import datetime
import time
import logging
from collections import Counter, defaultdict
import gc
import random
import warnings
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings(action='ignore')

# Configure logging once: INFO and above, each line prefixed with timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

print("Imports completed")

In [None]:
# 2. Path and Configuration

BASE_PATH = "Path\\to\\your\\dataset"  # Specify the base path to your dataset here

# Each dataset lives in its own sub-folder directly under BASE_PATH.
# The pairs below are (dataset identifier, folder name on disk).
datasets_config = {
    dataset_name: f'{BASE_PATH}/{folder_name}'
    for dataset_name, folder_name in (
        ('APTOS2019', 'APTOS 2019'),
        ('Diabetic_Retinopathy_V03', 'Diabetic Retinopathy_V03'),
        ('IDRiD', 'IDRiD'),
        ('Messidor2', 'Messidor 2'),
        ('SUSTech_SYSU', 'SUSTech_SYSU'),
        ('DeepDRiD', 'DeepDRiD'),
    )
}

# Remaining notebook-wide settings
N_SAMPLES_PER_DATASET = 300  # Number of images to sample for characterization
RANDOM_SEED = 42
OUTPUT_DIR = 'quality_review'

print("Configuration set")

In [None]:
# 3. Validate Dataset Paths

def validate_dataset_paths(datasets_config):
    """Check each configured dataset folder and keep those containing labelled images.

    A dataset passes when its path exists, it has at least one immediate
    sub-folder whose name is a decimal number in 0-4 (the DR class), and those
    class folders together contain at least one image file. Images are found by
    walking each class folder recursively and matching the file extension
    case-insensitively.

    A report is printed for every dataset. Any exception raised while reading a
    dataset directory is printed and that dataset is skipped; it is not re-raised.

    Args:
        datasets_config: Mapping of dataset name -> dataset directory path.

    Returns:
        dict: The entries of ``datasets_config`` (name -> path) that passed.
    """
    # Compared against the lower-cased file name, so lower-case entries suffice.
    image_extensions = ('.jpg', '.jpeg', '.png', '.tiff', '.bmp')
    dr_names = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']
    valid_datasets = {}

    for name, path in datasets_config.items():
        print(f"\nChecking {name}:")
        print(f"  Path: {path}")

        if not os.path.exists(path):
            print("Path does not exist")
            continue

        # DR class -> number of image files found below its folder(s)
        image_counts = {}

        try:
            subfolders = [
                item for item in os.listdir(path)
                if os.path.isdir(os.path.join(path, item))
            ]

            for item in subfolders:
                # isdecimal() rather than isdigit(): isdigit() also accepts
                # characters such as superscript digits that int() rejects,
                # which would abort the scan of an otherwise valid dataset.
                if not (item.isdecimal() and int(item) in (0, 1, 2, 3, 4)):
                    continue

                dr_class = int(item)
                image_count = sum(
                    1
                    for _root, _dirs, files in os.walk(os.path.join(path, item))
                    for file in files
                    if file.lower().endswith(image_extensions)
                )
                # Accumulate so that e.g. "0" and "00" both contribute to class 0
                # instead of the later folder overwriting the earlier one.
                image_counts[dr_class] = image_counts.get(dr_class, 0) + image_count

            if image_counts:
                dr_folders = sorted(image_counts)
                total_images = sum(image_counts.values())
                print(f" Found DR classes: {dr_folders}")
                print("Image counts per class:")
                for dr_class in dr_folders:
                    dr_name = dr_names[dr_class]
                    print(f"     {dr_name} (Class {dr_class}): {image_counts[dr_class]:,} images")
                print(f"Total images: {total_images:,}")

                if total_images > 0:
                    valid_datasets[name] = path
                else:
                    print("No images found")
            else:
                print("No DR class folders (0,1,2,3,4) found")
                print(f"Available folders: {subfolders}")

        except Exception as e:
            # Best-effort validation: report the problem and move on to the next dataset.
            print(f"Error accessing path: {e}")

    return valid_datasets

# Validate every configured dataset and keep the usable ones
valid_datasets = validate_dataset_paths(datasets_config)

# Report the outcome; an empty result comes with setup hints.
if valid_datasets:
    print(f"\n{len(valid_datasets)} Datasets loaded successfully")
else:
    print("\nNo valid datasets found!")
    print("1. Update BASE_PATH in cell [2] to your actual dataset location")
    print("2. Ensure your datasets have folders named 0, 1, 2, 3, 4 containing images")

In [None]:
# 4. Quality Identifier Class Definition

class QualityIdentifier:
    """Sample labelled images from a dataset folder and measure per-image quality characteristics.

    Labels are the diabetic-retinopathy (DR) classes 0-4, derived from folder
    names in the image path. Creating an instance seeds NumPy's global random
    generator and creates the output directory tree.
    """

    # Compared against lower-cased file names, so lower-case entries suffice.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.tiff', '.bmp')
    DR_NAMES = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']
    _DR_CLASSES = (0, 1, 2, 3, 4)
    # Folder names (lower-cased) that are accepted as textual class labels.
    _DR_PATTERNS = {
        'no_dr': 0, 'normal': 0, 'grade_0': 0, 'class_0': 0,
        'mild': 1, 'grade_1': 1, 'class_1': 1,
        'moderate': 2, 'grade_2': 2, 'class_2': 2,
        'severe': 3, 'grade_3': 3, 'class_3': 3,
        'proliferative': 4, 'grade_4': 4, 'class_4': 4
    }

    def __init__(self, output_dir='quality_review', random_seed=42):
        """Store settings, seed NumPy's global RNG and create the output folders.

        Args:
            output_dir: Directory for results; ``sample_images`` and
                ``flagged_samples`` sub-folders are created inside it.
            random_seed: Seed passed to ``np.random.seed``.
        """
        self.output_dir = output_dir
        self.dataset_profiles = {}
        self.identification_results = []
        self.random_seed = random_seed

        # sample_images_strategically() draws from np.random's global state.
        np.random.seed(random_seed)

        os.makedirs(output_dir, exist_ok=True)
        for subdir in ('sample_images', 'flagged_samples'):
            os.makedirs(os.path.join(output_dir, subdir), exist_ok=True)

        logger.info(f"Quality identification output directory: {output_dir}")

    def extract_dr_label_from_path(self, image_path, base_path=None):
        """Derive the DR class (0-4) from the folder names in ``image_path``.

        Path components are checked in order from the start of the path: first
        for a decimal number in 0-4, then (case-insensitively) for one of the
        known textual labels such as ``mild`` or ``grade_3``.

        Args:
            image_path: Path of the image file.
            base_path: Optional dataset root. When given and ``image_path`` lies
                below it, only the components below ``base_path`` are examined,
                so folders above the dataset (e.g. ``/data/2/...``) cannot be
                mistaken for a class label.

        Returns:
            int or None: The class label, or ``None`` if no component matches.
        """
        # Normalise first so redundant or alternative separators do not hide components.
        path_for_label = os.path.normpath(str(image_path))

        if base_path is not None:
            try:
                relative = os.path.relpath(path_for_label, str(base_path))
            except ValueError:
                # e.g. paths on different drives on Windows; fall back to the full path
                relative = None
            if relative is not None and relative.split(os.sep)[0] != os.pardir:
                path_for_label = relative

        parts = path_for_label.split(os.sep)

        for part in parts:
            # isdecimal() rather than isdigit(): isdigit() also accepts characters
            # such as superscript digits, for which int() raises ValueError.
            if part.isdecimal() and int(part) in self._DR_CLASSES:
                return int(part)

        for part in parts:
            label = self._DR_PATTERNS.get(part.lower())
            if label is not None:
                return label
        return None

    def sample_images_strategically(self, dataset_path, n_samples):
        """Draw a class-stratified random sample of image paths from a dataset folder.

        The folder is walked recursively; files with an image extension are
        grouped by the DR class found in their path below ``dataset_path``.
        Up to ``max(1, n_samples // 5)`` images are then drawn without
        replacement from every non-empty class using NumPy's global RNG.
        Per-class counts and the sample size are printed.

        Args:
            dataset_path: Root folder of the dataset.
            n_samples: Target sample size across the five classes.

        Returns:
            list[str]: Sampled image paths; empty if no labelled image was found.
        """
        class_images = {0: [], 1: [], 2: [], 3: [], 4: []}

        for root, dirs, files in os.walk(dataset_path):
            for file in files:
                if file.lower().endswith(self.IMAGE_EXTENSIONS):
                    image_path = os.path.join(root, file)
                    dr_class = self.extract_dr_label_from_path(image_path, base_path=dataset_path)
                    if dr_class is not None:
                        class_images[dr_class].append(image_path)

        # os.walk order depends on the filesystem; sorting makes a seeded draw repeatable.
        for images in class_images.values():
            images.sort()

        for dr_class, images in class_images.items():
            if images:
                print(f"    {self.DR_NAMES[dr_class]}: {len(images)} images")

        total_available = sum(len(images) for images in class_images.values())
        if total_available == 0:
            print("No images found with valid DR labels")
            return []

        samples_per_class = max(1, n_samples // 5)

        sampled_images = []
        for dr_class, images in class_images.items():
            if images:
                n_class_samples = min(samples_per_class, len(images))
                sampled = np.random.choice(images, n_class_samples, replace=False)
                # tolist() turns NumPy string scalars back into plain Python str.
                sampled_images.extend(sampled.tolist())

        print(f"Selected {len(sampled_images)} stratified samples")
        return sampled_images

    def safe_calculate_brightness(self, image):
        """Return the mean pixel value clamped to [0, 255]; 127.5 if the computation fails."""
        try:
            brightness = float(np.mean(image))
            return max(0.0, min(255.0, brightness))
        except Exception:
            return 127.5

    def safe_calculate_contrast(self, image):
        """Return the standard deviation of the pixel values (>= 0); 50.0 if the computation fails."""
        try:
            contrast = float(np.std(image))
            return max(0.0, contrast)
        except Exception:
            return 50.0

    def safe_calculate_sharpness(self, gray):
        """Return the variance of the Laplacian of ``gray`` (>= 0); 500.0 if the computation fails."""
        try:
            laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
            return max(0.0, float(laplacian_var))
        except Exception:
            return 500.0

    def safe_calculate_entropy(self, gray):
        """Return the Shannon entropy (bits) of the 256-bin histogram of ``gray``.

        The result is clamped to [0, 8]. 4.0 is returned when the histogram is
        empty or the computation fails.
        """
        try:
            hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
            hist = hist.flatten()
            hist = hist[hist > 0]  # drop empty bins so log2 is defined
            if len(hist) == 0:
                return 4.0
            prob = hist / hist.sum()
            entropy = float(-np.sum(prob * np.log2(prob)))
            return max(0.0, min(8.0, entropy))
        except Exception:
            return 4.0

    def assess_vessel_visibility_improved(self, image):
        """Return the fraction of pixels with a strong line-filter response, capped at 0.1.

        Channel index 1 of ``image`` (green for the BGR arrays produced by
        ``cv2.imread``) is filtered with four 3x3 line kernels (horizontal,
        vertical, two diagonals). The per-pixel maximum response is thresholded
        at its own 95th percentile and the fraction of pixels strictly above
        that threshold is returned. Returns 0.0 on any error.

        NOTE(review): because the threshold is a percentile of the same array,
        the fraction is bounded by roughly 0.05 regardless of image content, so
        the 0.1 cap never binds and the score varies little between images.
        Confirm this is the intended measure.
        """
        try:
            green = image[:, :, 1]

            kernels = {
                'horizontal': np.array([[-1, -1, -1], [2, 2, 2], [-1, -1, -1]], dtype=np.float32),
                'vertical': np.array([[-1, 2, -1], [-1, 2, -1], [-1, 2, -1]], dtype=np.float32),
                'diagonal1': np.array([[2, -1, -1], [-1, 2, -1], [-1, -1, 2]], dtype=np.float32),
                'diagonal2': np.array([[-1, -1, 2], [-1, 2, -1], [2, -1, -1]], dtype=np.float32)
            }

            vessel_responses = [
                cv2.filter2D(green, cv2.CV_32F, kernel) for kernel in kernels.values()
            ]

            max_response = np.maximum.reduce(vessel_responses)
            threshold = np.percentile(max_response, 95)
            vessel_pixels = np.sum(max_response > threshold)
            total_pixels = max_response.shape[0] * max_response.shape[1]

            visibility_score = vessel_pixels / max(total_pixels, 1)
            return float(min(0.1, visibility_score))

        except Exception:
            return 0.0

    def assess_optic_disc_visibility(self, image):
        """Return the fraction of grayscale pixels strictly above the image's own 90th percentile.

        Returns 0.0 on any error.

        NOTE(review): as the threshold is a percentile of the same image, the
        result is bounded by roughly 0.10 for any input. Confirm this is the
        intended measure.
        """
        try:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            bright_pixels = np.sum(gray > np.percentile(gray, 90))
            total_pixels = gray.shape[0] * gray.shape[1]
            return float(bright_pixels / max(total_pixels, 1))
        except Exception:
            return 0.0

    def assess_illumination_uniformity(self, gray):
        """Return ``1 - (std / mean)`` of the mean intensities of a 3x3 grid of regions.

        The value is clamped to [0, 1]; 1.0 means all nine regions have the same
        mean. Returns 0.0 when the overall mean is not positive or on any error.
        """
        try:
            h, w = gray.shape
            regions = []

            for i in range(3):
                for j in range(3):
                    start_y, end_y = i * h // 3, (i + 1) * h // 3
                    start_x, end_x = j * w // 3, (j + 1) * w // 3
                    region = gray[start_y:end_y, start_x:end_x]
                    regions.append(np.mean(region))

            mean_intensity = np.mean(regions)
            if mean_intensity > 0:
                cv_score = np.std(regions) / mean_intensity  # coefficient of variation
                return float(max(0, min(1, 1 - cv_score)))
            return 0.0
        except Exception:
            return 0.0

    def detect_extreme_pixels(self, gray):
        """Return the fraction of pixels with value < 10 or > 245; 0.0 on any error."""
        try:
            very_dark = np.sum(gray < 10)
            very_bright = np.sum(gray > 245)
            total_pixels = gray.shape[0] * gray.shape[1]
            return float((very_dark + very_bright) / max(total_pixels, 1))
        except Exception:
            return 0.0

    def assess_motion_blur(self, gray):
        """Return the mean Sobel gradient magnitude of ``gray``; 0.0 on any error.

        Note that this is an edge-strength measure: larger values mean stronger
        gradients, not more blur.
        """
        try:
            grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
            magnitude = np.sqrt(grad_x**2 + grad_y**2)
            return float(np.mean(magnitude))
        except Exception:
            return 0.0

    def analyze_single_image(self, image_path):
        """Load one image and compute all quality characteristics for it.

        Args:
            image_path: Path of the image file to analyse.

        Returns:
            dict or None: A dictionary with the file metadata (``image_path``,
            ``filename``, ``resolution`` as ``(width, height)``,
            ``file_size_mb``) and one entry per quality measure. ``None`` is
            returned, with a logged warning or error, when the file is missing,
            cannot be decoded, is not a 3-dimensional array, is smaller than
            100 pixels in either dimension, or any unexpected error occurs.
        """
        try:
            if not os.path.exists(image_path):
                logger.warning(f"File not found: {image_path}")
                return None

            image = cv2.imread(image_path)
            if image is None:
                logger.warning(f"Could not read image: {image_path}")
                return None

            if len(image.shape) != 3:
                logger.warning(f"Invalid image format: {image_path}")
                return None

            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            h, w = image.shape[:2]

            if h < 100 or w < 100:
                logger.warning(f"Image too small ({w}x{h}): {image_path}")
                return None

            characteristics = {
                'image_path': image_path,
                'filename': os.path.basename(image_path),
                'resolution': (w, h),
                'file_size_mb': max(0.001, os.path.getsize(image_path) / (1024*1024)),

                'brightness': self.safe_calculate_brightness(image),
                'contrast': self.safe_calculate_contrast(image),
                'sharpness': self.safe_calculate_sharpness(gray),
                'entropy': self.safe_calculate_entropy(gray),

                # Spread of the three per-channel means
                'color_balance': float(np.std([np.mean(image[:,:,0]), np.mean(image[:,:,1]), np.mean(image[:,:,2])])),

                'vessel_visibility': self.assess_vessel_visibility_improved(image),
                'optic_disc_visibility': self.assess_optic_disc_visibility(image),
                'illumination_uniformity': self.assess_illumination_uniformity(gray),

                'extreme_brightness_pixels': self.detect_extreme_pixels(gray),
                'motion_blur_score': self.assess_motion_blur(gray)
            }
            return characteristics

        except Exception as e:
            logger.error(f"Error analyzing {image_path}: {e}")
            return None

print("QualityIdentifier class defined")

In [None]:
# 4.5. Parameter Validation

class QualityParameters:
    """Container for the thresholds and weights used by the quality checks.

    All values are set in ``__init__`` as instance attributes;
    ``validate_bounds`` checks that the brightness bounds, black threshold,
    percentiles and weights are mutually consistent.
    """

    def __init__(self):
        """Populate the parameters with their default values."""
        # Validated thresholds
        self.BRIGHTNESS_BOUNDS = (10, 245)  # (low, high) on the 0-255 scale
        self.BLACK_THRESHOLD = 30
        self.VESSEL_PERCENTILE = 95
        self.OPTIC_DISC_PERCENTILE = 90

        # Severity-specific adaptive thresholds, keyed by DR class
        self.SEVERITY_PERCENTILES = {
            0: 15,  # No DR - strictest
            1: 12,  # Mild DR
            2: 10,  # Moderate DR
            3: 8,   # Severe DR
            4: 5    # PDR - most relaxed
        }

        # Quality composition weights (must sum to 1)
        self.BASIC_WEIGHT = 0.3
        self.MEDICAL_WEIGHT = 0.7

        # Minimum image requirements
        self.MIN_RESOLUTION = (100, 100)
        self.MIN_FILE_SIZE_MB = 0.001

    def validate_bounds(self):
        """Verify that the configured parameters are internally consistent.

        Returns:
            bool: ``True`` when every check passes.

        Raises:
            AssertionError: If a check fails; the message names the offending
                parameter. The error is raised explicitly rather than through
                ``assert`` statements so the checks still run under ``python -O``.
        """
        low, high = self.BRIGHTNESS_BOUNDS[0], self.BRIGHTNESS_BOUNDS[1]
        checks = (
            (0 <= low < high <= 255,
             f"BRIGHTNESS_BOUNDS must satisfy 0 <= low < high <= 255, got {self.BRIGHTNESS_BOUNDS}"),
            (0 < self.BLACK_THRESHOLD < 255,
             f"BLACK_THRESHOLD must be in (0, 255), got {self.BLACK_THRESHOLD}"),
            (0 < self.VESSEL_PERCENTILE <= 100,
             f"VESSEL_PERCENTILE must be in (0, 100], got {self.VESSEL_PERCENTILE}"),
            (0 < self.OPTIC_DISC_PERCENTILE <= 100,
             f"OPTIC_DISC_PERCENTILE must be in (0, 100], got {self.OPTIC_DISC_PERCENTILE}"),
            (abs(self.BASIC_WEIGHT + self.MEDICAL_WEIGHT - 1.0) < 0.001,
             f"BASIC_WEIGHT + MEDICAL_WEIGHT must equal 1.0, got "
             f"{self.BASIC_WEIGHT} + {self.MEDICAL_WEIGHT}"),
        )
        for passed, message in checks:
            if not passed:
                raise AssertionError(message)
        return True

# Initialize and validate parameters.
# validate_bounds() raises AssertionError when a bound is inconsistent, so this
# cell stops early on a bad configuration instead of printing the message below.
quality_params = QualityParameters()
quality_params.validate_bounds()
print("Quality parameters validated")