# Phase 2: Data Preprocessing and Training Preparation
This notebook implements preprocessing methods, parasitemia scoring, and data preparation for MTTL training

## 2.1 Setup and Imports
*We start by importing all necessary libraries and setting up paths for the preprocessing pipeline.*

In [None]:
import sys
import os
sys.path.append('..')
import glob
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_context('notebook')
from PIL import Image
import cv2
from collections import Counter, defaultdict
from typing import List, Dict, Tuple
from tqdm.notebook import tqdm
import json
import pickle
from sklearn.model_selection import train_test_split
from scipy.ndimage import gaussian_filter
import warnings
warnings.filterwarnings('ignore')

# staintools is optional: record whether it could be imported so later code can
# choose between it and our own custom (fallback) implementations.
try:
    import staintools
except ImportError:
    STAINTOOLS_AVAILABLE = False
    print("staintools library not available - using alternative implementations")
else:
    STAINTOOLS_AVAILABLE = True
    print("staintools library available")
    
from src.utils.DataUtils import set_seeds, MalariaPreprocessor, MalariaDataPreparator
# Seed RNGs once, up front, so the random test-image sampling later in the
# notebook is reproducible. set_seeds comes from src.utils.DataUtils
# (presumably seeds `random` and NumPy -- confirm there).
SEED = 12
set_seeds(SEED)

# Target image size handed to MalariaPreprocessor and ParasitemiaScorer
TARGET_SIZE = (512, 512)

In [None]:
# Dataset input and preprocessing output both live under ../data (relative to
# the notebook's working directory).
NLM_ROOT, OUTPUT_DIR = (
    os.path.join('..', 'data', leaf)
    for leaf in ('NIH-NLM-ThinBloodSmearsPf', 'preprocessed_NLM')
)
POINT_SET_DIR = os.path.join(NLM_ROOT, 'Point Set')

# Create OUTPUT_DIR/metadata up front; exist_ok keeps re-runs of this cell safe
os.makedirs(os.path.join(OUTPUT_DIR, 'metadata'), exist_ok=True)

## 2.2 Data Loading and Path Setup
*We then re-parse the Point Set annotations (the same annotation format used in Phase 1) and group them by image.*

In [None]:
# Load annotations (same Point Set format as in Phase 1)
def _parse_annotation_line(line, img_path):
    """Parse one comma-separated annotation row into a dict, or return None.

    Columns used (0-based): 1 = cell type, 3 = shape ('Point' or 'Polygon'),
    4 = number of polygon vertices, 5.. = coordinates. Fields are stripped of
    surrounding whitespace. Rows that are too short, use another shape, have
    non-numeric numbers, or describe a polygon without a single complete
    vertex are skipped (None) instead of aborting the whole parse.
    """
    parts = [field.strip() for field in line.split(',')]
    if len(parts) < 7:
        return None
    cell_type, shape = parts[1], parts[3]

    try:
        if shape == 'Point':
            return {
                'image_path': img_path,
                'cell_type': cell_type,
                'shape': shape,
                'x': float(parts[5]),
                'y': float(parts[6]),
            }
        if shape != 'Polygon':
            return None
        n_points = int(parts[4])
        coords = [float(v) for v in parts[5:5 + 2 * n_points]]
    except ValueError:
        # Non-numeric vertex count or coordinate: skip this row, keep parsing
        return None

    # Pair consecutive values as (x, y) vertices (a dangling odd value is dropped)
    xy = list(zip(coords[::2], coords[1::2]))
    if not xy:
        # e.g. a 0-vertex polygon: no bounding box can be derived
        return None
    xs, ys = zip(*xy)
    return {
        'image_path': img_path,
        'cell_type': cell_type,
        'shape': shape,
        'polygon': xy,
        'bbox': [min(xs), min(ys), max(xs), max(ys)],
    }


def parse_point_set_annotations(point_set_dir, img_ext='.jpg'):
    """Parse every ``<folder>/GT/*.txt`` annotation file under ``point_set_dir``.

    Each sub-folder of ``point_set_dir`` must contain a ``GT`` directory of
    annotation files and an ``Img`` directory holding the matching image
    (same stem, extension ``img_ext``). Files without a matching image are
    skipped. The first line of each annotation file is treated as a header.

    Folders and annotation files are visited in sorted order, so the returned
    list (and anything keyed from it) has a filesystem-independent order.

    Args:
        point_set_dir: root directory of the Point Set data.
        img_ext: extension of the image files in ``Img`` (default '.jpg').

    Returns:
        list of dicts with keys image_path, cell_type, shape and either
        x, y (Point) or polygon [(x, y), ...] and bbox [xmin, ymin, xmax, ymax]
        (Polygon).
    """
    annotations = []
    point_set_folders = sorted(
        os.path.join(point_set_dir, d) for d in os.listdir(point_set_dir)
        if os.path.isdir(os.path.join(point_set_dir, d))
    )

    for folder in tqdm(point_set_folders, desc="Parsing annotations"):
        gt_dir = os.path.join(folder, 'GT')
        img_dir = os.path.join(folder, 'Img')
        if not os.path.isdir(gt_dir) or not os.path.isdir(img_dir):
            continue

        # sorted(): os.listdir order is filesystem-dependent, and annotation
        # order decides the key order of the per-image grouping that the
        # seeded test-image sampling draws from.
        for ann_file in sorted(os.listdir(gt_dir)):
            stem, ext = os.path.splitext(ann_file)
            if ext.lower() != '.txt':
                continue
            img_path = os.path.join(img_dir, stem + img_ext)
            if not os.path.exists(img_path):
                continue

            with open(os.path.join(gt_dir, ann_file), 'r', encoding='utf-8') as f:
                # splitlines() also handles '\r\n' line endings
                lines = f.read().strip().splitlines()

            # lines[0] is the header; header-only files contribute nothing
            for line in lines[1:]:
                ann = _parse_annotation_line(line, img_path)
                if ann is not None:
                    annotations.append(ann)
    return annotations

print("Loading annotations...")
annotations = parse_point_set_annotations(POINT_SET_DIR)
print(f"Loaded {len(annotations)} annotations")

# Bucket annotations by source image; keys keep the order of first appearance
image_to_anns = defaultdict(list)
for record in annotations:
    source_image = record['image_path']
    image_to_anns[source_image].append(record)

print(f"Found {len(image_to_anns)} unique images")


## 2.3 Test Image Selection
*Next, we randomly sample (with a fixed seed) up to three test images for the preprocessing method comparison.*

In [None]:
print("Selecting test images for preprocessing comparison...")
# Number of images sampled for the preprocessing/parasitemia sample checks
N_TEST_IMAGES = 3
# Sort before sampling: the grouping's key order follows annotation-parsing
# order, which can depend on os.listdir ordering. Sorting makes the seeded
# random.sample pick the same images on every machine.
available_images = sorted(image_to_anns)
test_images = random.sample(available_images, min(N_TEST_IMAGES, len(available_images)))
print(f"Selected {len(test_images)} test images:")
for img in test_images:
    print(f"  - {os.path.basename(img)}")

## 2.4 Preprocessing Methods Implementation and Testing

### 2.4.1 Preprocessing Methods Comparison
*Test different preprocessing methods on sample images to determine the best approach*

**Methods Tested:**
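- **Original**: Raw microscopy images
- **Resized**: Resized to `TARGET_SIZE` (512×512) without preserving the aspect ratio
- **CLAHE**: Contrast Limited Adaptive Histogram Equalization, applied to the resized image
- **Macenko**: Macenko stain normalization for consistent color (applied to the CLAHE output)
- **Reinhard**: Reinhard color normalization (applied to the CLAHE output)
- **Color Deconvolution**: Enhanced structure separation (applied to the CLAHE output)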

In [None]:
# We'll compare preprocessing methods and their effects on original images
def compare_preprocessing_methods(image_path, preprocessor):
    """Run every preprocessing variant on one image and return them by name.

    The three stain/colour methods (Macenko, Reinhard, colour deconvolution)
    are each applied to a fresh copy of the CLAHE output, not to the plain
    resized image, so their results are effectively "CLAHE + <method>".

    Args:
        image_path: path of the image file to load (converted to RGB).
        preprocessor: object exposing resize_image, clahe_normalization,
            macenko_normalization, reinhard_normalization and
            enhanced_color_deconvolution (e.g. MalariaPreprocessor).

    Returns:
        dict of display name -> image array, in the order Original, Resized,
        CLAHE, Macenko, Reinhard, Color Deconv.
    """
    original = np.array(Image.open(image_path).convert('RGB'))

    separator = "-" * 60
    print(f"\nProcessing image: {os.path.basename(image_path)}")
    print(f"Original image shape: {original.shape}, dtype: {original.dtype}")
    print(f"Original intensity range: [{original.min()}, {original.max()}]")
    print(separator)

    resized = preprocessor.resize_image(original, maintain_aspect=False)
    clahe_img = preprocessor.clahe_normalization(resized.copy())

    variants = {'Original': original, 'Resized': resized, 'CLAHE': clahe_img}
    # Each stain method gets its own copy so none can mutate the CLAHE base
    variants['Macenko'] = preprocessor.macenko_normalization(clahe_img.copy())
    variants['Reinhard'] = preprocessor.reinhard_normalization(clahe_img.copy())
    variants['Color Deconv'] = preprocessor.enhanced_color_deconvolution(clahe_img.copy())

    print(separator)
    return variants

# Run the comparison on (at most) the first two sampled test images
print("Testing preprocessing methods...")
preprocessor = MalariaPreprocessor(target_size=TARGET_SIZE)
preprocessing_results = {
    os.path.basename(path): compare_preprocessing_methods(path, preprocessor)
    for path in test_images[:2]
}

### 2.4.2 Visual Preprocessing Comparison
*Generate side-by-side visual comparisons of all preprocessing methods*

In [None]:
# comparison plot of preprocessing methods
def plot_preprocessing_comparison(results_dict, methods_order=None):
    """Show, per image, a 2x3 grid of the preprocessing variants with a grey frame.

    Args:
        results_dict: {image name: {method name: RGB image array}}, as built
            by compare_preprocessing_methods.
        methods_order: display order of at most six method names; defaults to
            the six variants produced by compare_preprocessing_methods.
    """
    if methods_order is None:
        methods_order = ['Original', 'Resized', 'CLAHE', 'Macenko', 'Reinhard', 'Color Deconv']

    for img_name, results in results_dict.items():
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle(f'Preprocessing Methods Comparison - {img_name}', fontsize=20, fontweight='bold')

        for ax, method in zip(axes.flat, methods_order):
            if method not in results:
                ax.axis('off')  # nothing to show for this panel
                continue
            ax.imshow(results[method])
            ax.set_title(method, fontsize=16, fontweight='bold', pad=15)
            # Hide ticks but keep the frame: ax.axis('off') would also hide
            # the spines, so the grey border below would never be drawn.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(True)
                spine.set_linewidth(2)
                spine.set_color('gray')

        # Hide grid cells left unused when fewer than six methods are given
        for ax in list(axes.flat)[len(methods_order):]:
            ax.axis('off')

        # Leave room at the top for the suptitle
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()
        plt.close(fig)  # free the figure; avoids accumulating open figures
        print(f"Completed visualization for: {img_name}\n")

# Visualization
print("Visualizing preprocessing methods ...")
plot_preprocessing_comparison(preprocessing_results)

### 2.4.3 RGB Histogram Analysis
*Analyze color distribution changes across different preprocessing methods*

In [None]:
# RGB histograms for preprocessing comparison
def plot_histogram_comparison(results_dict):
    """Plot per-channel (R, G, B) intensity histograms for each preprocessing variant.

    One 2x3 figure is produced per image in ``results_dict``, using the same
    panel layout as the visual comparison. Histograms use 256 bins over
    [0, 256) via cv2.calcHist, i.e. pixel values are assumed to be in the
    0-255 range (TODO confirm every method returns uint8-range images).
    Panels for methods missing from an image's results are left blank.
    """
    channel_colors = ('red', 'green', 'blue')
    panel_order = ('Original', 'Resized', 'CLAHE', 'Macenko', 'Reinhard', 'Color Deconv')

    for img_name, results in results_dict.items():
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle(f'RGB Histograms Comparison - {img_name}', fontsize=20, fontweight='bold')

        for idx, method in enumerate(panel_order):
            if method not in results:
                continue
            ax = axes[divmod(idx, 3)]  # (row, col) in the 2x3 grid
            picture = results[method]

            for channel, color in enumerate(channel_colors):
                counts = cv2.calcHist([picture], [channel], None, [256], [0, 256])
                ax.plot(counts, color=color, alpha=0.8, linewidth=2,
                        label=f'{color.upper()} channel')

            ax.set_title(method, fontsize=14, fontweight='bold')
            ax.set_xlabel('Pixel Intensity', fontsize=12)
            ax.set_ylabel('Frequency', fontsize=12)
            ax.legend(loc='upper right', fontsize=10)
            ax.grid(True, alpha=0.3)
            # Anchor the y-axis at zero; the upper limit stays automatic per panel
            ax.set_ylim(0, None)

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()
        print(f"Completed histogram analysis for: {img_name}\n")

print("\nGenerating histogram comparisons...")
plot_histogram_comparison(preprocessing_results)

### 2.4.4 Quantitative Preprocessing Metrics
*Calculate objective metrics to compare preprocessing effectiveness*

**Metrics Calculated:**
- **Contrast Ratio**: Standard deviation / mean intensity
- **RMS Contrast**: Root mean square contrast measure
- **Information Entropy**: Measures information content preservation
- **Edge Strength**: Sobel gradient magnitude for cell boundary definition
- **Dynamic Range**: Full intensity range utilization
- **Signal-to-Noise Ratio**: Heuristic estimate — squared mean grey level divided by the variance of the Laplacian (used as a noise proxy)
- **Color Consistency**: Coefficient of variation across channels



In [None]:
# Calculate quantitative metrics for preprocessing comparison
def _image_quality_metrics(img):
    """Return the ordered metric dict for one RGB image (used per method/image)."""
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # 1. Basic intensity statistics (over all pixels and channels)
    mean_intensity = np.mean(img)
    std_intensity = np.std(img)

    # 2. Contrast metrics
    contrast_ratio = std_intensity / mean_intensity if mean_intensity > 0 else 0
    rms_contrast = np.sqrt(np.mean((gray - np.mean(gray)) ** 2))

    # 3. Shannon entropy (bits) of the 256-bin grey-level histogram
    hist, _ = np.histogram(gray, bins=256, range=(0, 256))
    hist = hist + 1e-10  # avoid log2(0)
    hist_norm = hist / np.sum(hist)
    entropy = -np.sum(hist_norm * np.log2(hist_norm))

    # 4. Spread of each colour channel
    channels = [img[:, :, c] for c in range(3)]
    avg_color_std = np.mean([np.std(ch) for ch in channels])

    # 5. Edge content (mean Sobel gradient magnitude)
    sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
    sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
    avg_edge_strength = np.mean(np.sqrt(sobelx ** 2 + sobely ** 2))

    # 6. Dynamic range
    dynamic_range = np.max(img) - np.min(img)

    # 7. Heuristic SNR: squared mean grey level over Laplacian variance (noise proxy)
    laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
    snr_estimate = np.mean(gray) ** 2 / (laplacian_var + 1e-10)

    # 8. Colour consistency: coefficient of variation (std / mean) computed for
    #    each channel separately, then averaged over the three channels.
    color_consistency = [np.std(ch) / (np.mean(ch) + 1e-10) for ch in channels]
    avg_color_consistency = np.mean(color_consistency)

    return {
        'Mean_Intensity': mean_intensity,
        'Std_Intensity': std_intensity,
        'Contrast_Ratio': contrast_ratio,
        'RMS_Contrast': rms_contrast,
        'Entropy': entropy,
        'Avg_Color_Std': avg_color_std,
        'Edge_Strength': avg_edge_strength,
        'Dynamic_Range': dynamic_range,
        'SNR_Estimate': snr_estimate,
        'Color_Consistency': avg_color_consistency,
    }


def calculate_preprocessing_metrics(results_dict):
    """Compute image-quality metrics for every preprocessing variant.

    The 'Original' entry of each image is skipped; every other method is
    measured with the same set of metrics.

    Args:
        results_dict: {image name: {method name: RGB image array}}. Images
            must be accepted by cv2.cvtColor(..., COLOR_RGB2GRAY) and have at
            least three channels (assumed uint8 -- TODO confirm per method).

    Returns:
        pandas.DataFrame with one row per (image, method) and columns
        Image, Method, Mean_Intensity, Std_Intensity, Contrast_Ratio,
        RMS_Contrast, Entropy, Avg_Color_Std, Edge_Strength, Dynamic_Range,
        SNR_Estimate, Color_Consistency.
    """
    metrics_data = []

    print("Calculating preprocessing metrics...")
    print("=" * 70)

    for img_name, results in results_dict.items():
        print(f"\nAnalyzing {img_name}:")

        for method_name, img in results.items():
            if method_name == 'Original':
                continue

            row = _image_quality_metrics(img)
            metrics_data.append({'Image': img_name, 'Method': method_name, **row})

            print(f"  {method_name:12} | Contrast: {row['Contrast_Ratio']:.3f} | "
                  f"Entropy: {row['Entropy']:.2f} | Edge: {row['Edge_Strength']:.1f}")

    return pd.DataFrame(metrics_data)

# Method selection and recommendation system
def recommend_best_preprocessing_method(metrics_df, weights=None):
    """Rank preprocessing methods by a weighted quality score and pick the best.

    For each method the metrics are averaged over images, mapped onto [0, 1]
    by dividing by a fixed reference value (capped at 1), and combined as a
    weighted sum:

    - contrast:    mean Contrast_Ratio / 0.4
    - entropy:     mean Entropy / 8.0 (log2(256), the maximum for 256 bins)
    - edge:        mean Edge_Strength / 50.0
    - consistency: (1 / mean Color_Consistency) / 10.0  (lower CV is better)
    - snr:         mean SNR_Estimate / 100.0
    - range:       mean Dynamic_Range / 255.0

    Args:
        metrics_df: DataFrame as returned by calculate_preprocessing_metrics;
            must contain Method, Contrast_Ratio, Entropy, Edge_Strength,
            Color_Consistency, SNR_Estimate and Dynamic_Range columns.
        weights: optional dict overriding any subset of the default weights
            (contrast_weight, entropy_weight, edge_weight, consistency_weight,
            snr_weight, dynamic_range_weight). Keys that are not given keep
            their default value.

    Returns:
        (best_method_preprocessing, method_scores, detailed_scores):
        method_scores maps method -> total score; detailed_scores maps
        method -> per-component scores plus 'total'.

    Raises:
        ValueError: if metrics_df has no rows.
    """
    # Default weights for our malaria microscopy based on what is important
    default_weights = {
        'contrast_weight': 0.25,      # Good contrast is important
        'entropy_weight': 0.20,       # Information preservation
        'edge_weight': 0.20,          # Cell boundary definition
        'consistency_weight': 0.15,   # Color stability
        'snr_weight': 0.10,           # Noise reduction
        'dynamic_range_weight': 0.10  # Full intensity utilization
    }
    # Merge partial overrides onto the defaults instead of failing with KeyError
    weights = {**default_weights, **(weights or {})}

    if metrics_df.empty:
        raise ValueError("metrics_df is empty: there are no preprocessing metrics to rank")

    print("\n" + "="*80)
    print("PREPROCESSING METHOD EVALUATION")
    print("="*80)

    # Summary statistics by method
    method_stats = metrics_df.groupby('Method').agg({
        'Contrast_Ratio': ['mean', 'std'],
        'Entropy': ['mean', 'std'],
        'Edge_Strength': ['mean', 'std'],
        'Color_Consistency': ['mean', 'std'],
        'SNR_Estimate': ['mean', 'std'],
        'Dynamic_Range': ['mean', 'std']
    }).round(4)

    print("\nSUMMARY STATISTICS BY METHOD:")
    print("-" * 50)
    print(method_stats)

    method_scores = {}
    detailed_scores = {}

    print(f"\nDETAILED SCORING ANALYSIS:")
    print("-" * 50)

    for method in metrics_df['Method'].unique():
        method_data = metrics_df[metrics_df['Method'] == method]

        # Per-method averages
        avg_contrast = method_data['Contrast_Ratio'].mean()
        avg_entropy = method_data['Entropy'].mean()
        avg_edge = method_data['Edge_Strength'].mean()
        avg_consistency = 1 / (method_data['Color_Consistency'].mean() + 1e-6)  # Lower CV is better
        avg_snr = method_data['SNR_Estimate'].mean()
        avg_dynamic_range = method_data['Dynamic_Range'].mean()

        # Normalize to 0-1 against fixed reference values, capped at 1
        contrast_score = min(avg_contrast / 0.4, 1.0)
        entropy_score = min(avg_entropy / 8.0, 1.0)
        edge_score = min(avg_edge / 50.0, 1.0)
        consistency_score = min(avg_consistency / 10.0, 1.0)
        snr_score = min(avg_snr / 100.0, 1.0)
        range_score = min(avg_dynamic_range / 255.0, 1.0)

        total_score = (
            contrast_score * weights['contrast_weight'] +
            entropy_score * weights['entropy_weight'] +
            edge_score * weights['edge_weight'] +
            consistency_score * weights['consistency_weight'] +
            snr_score * weights['snr_weight'] +
            range_score * weights['dynamic_range_weight']
        )

        method_scores[method] = total_score
        detailed_scores[method] = {
            'contrast': contrast_score,
            'entropy': entropy_score,
            'edge': edge_score,
            'consistency': consistency_score,
            'snr': snr_score,
            'range': range_score,
            'total': total_score
        }

        print(f"\n{method.upper()}:")
        print(f"Contrast Score:    {contrast_score:.3f} (avg ratio: {avg_contrast:.3f})")
        print(f"Entropy Score:     {entropy_score:.3f} (avg entropy: {avg_entropy:.2f})")
        print(f"Edge Score:        {edge_score:.3f} (avg strength: {avg_edge:.1f})")
        print(f"Consistency Score: {consistency_score:.3f} (consistency: {1/avg_consistency:.3f})")
        print(f"SNR Score:         {snr_score:.3f} (avg SNR: {avg_snr:.1f})")
        print(f"Range Score:       {range_score:.3f} (avg range: {avg_dynamic_range:.0f})")
        print(f"TOTAL SCORE:       {total_score:.3f}")

    # Best method; ties resolve to the first method encountered
    best_method_preprocessing = max(method_scores, key=method_scores.get)
    best_score = method_scores[best_method_preprocessing]
    ranked_methods = sorted(method_scores.items(), key=lambda x: x[1], reverse=True)

    print(f"\n" + "="*50)
    print("FINAL RANKING:")
    print("="*50)

    for i, (method, score) in enumerate(ranked_methods, 1):
        status = "RECOMMENDED" if method == best_method_preprocessing else f"#{i}"
        print(f"{status:15} {method:15} Score: {score:.3f}")

    print(f"\nBEST METHOD FOR MALARIA DETECTION: {best_method_preprocessing}")
    print(f"   Final Score: {best_score:.3f}")

    # Method-specific recommendations
    print(f"\nANALYSIS SUMMARY:")
    print("-" * 30)

    summaries = {
        'CLAHE': "CLAHE provides excellent contrast enhancement while preserving detail, Ideal for enhancing cell structures in microscopy images",
        'Macenko': "Macenko normalization provides consistent stain appearance, Excellent for standardizing images across different preparations",
        'Reinhard': "Reinhard normalization provides stable color distribution, Good for maintaining consistent color profiles",
        'Color Deconv': "Color deconvolution enhances cellular structures, Excellent for separating different staining components",
    }
    if best_method_preprocessing in summaries:
        print(summaries[best_method_preprocessing])

    return best_method_preprocessing, method_scores, detailed_scores

# Visualization of metrics comparison
def plot_metrics_comparison(metrics_df):
    """Plot per-method box plots of five quality metrics plus a summary bar chart.

    Args:
        metrics_df: DataFrame from calculate_preprocessing_metrics; needs
            Method, Contrast_Ratio, Entropy, Edge_Strength, SNR_Estimate and
            Dynamic_Range columns.
    """
    # plotting style
    plt.style.use('seaborn-v0_8-whitegrid')

    fig, axes = plt.subplots(2, 3, figsize=(20, 12))

    # (column, panel title, y-label) for the five box-plot panels
    box_panels = [
        ('Contrast_Ratio', 'Contrast Ratio Distribution', 'Contrast Ratio'),
        ('Entropy', 'Information Content (Entropy)', 'Entropy'),
        ('Edge_Strength', 'Edge Definition Strength', 'Edge Strength'),
        ('SNR_Estimate', 'Signal-to-Noise Ratio', 'SNR Estimate'),
        ('Dynamic_Range', 'Dynamic Range Utilization', 'Dynamic Range'),
    ]
    for ax, (column, title, ylabel) in zip(axes.flat, box_panels):
        metrics_df.boxplot(column=column, by='Method', ax=ax)
        ax.set_title(title)
        ax.set_xlabel('Method')
        ax.set_ylabel(ylabel)
        plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

    # 6. Key metrics overview. The metrics live on very different scales, so
    #    each is divided by its maximum across methods, which makes the
    #    "Normalized Values" label accurate and keeps all bars readable.
    ax = axes[1, 2]
    key_cols = ['Contrast_Ratio', 'Entropy', 'Edge_Strength']
    method_means = metrics_df.groupby('Method')[key_cols].mean()
    col_max = method_means.abs().max().replace(0, 1)  # avoid division by zero
    (method_means / col_max).plot(kind='bar', ax=ax)
    ax.set_title('Key Metrics Overview')
    ax.set_xlabel('Method')
    ax.set_ylabel('Normalized Values')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

    # Set the suptitle LAST: DataFrame.boxplot(by=...) replaces the figure
    # suptitle with "Boxplot grouped by Method".
    fig.suptitle('Preprocessing Methods - Quantitative Metrics Comparison',
                 fontsize=16, fontweight='bold', y=0.98)

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

### 2.4.5 Preprocessing Method Recommendation
*Systematic evaluation and selection of optimal preprocessing method for malaria microscopy*

**Evaluation Criteria:**
- **Contrast Enhancement (25%)**: Good contrast crucial for cell differentiation
- **Information Preservation (20%)**: Maintain diagnostic details
- **Edge Definition (20%)**: Clear cell boundaries for detection
- **Color Consistency (15%)**: Stable color profiles across images
- **Noise Reduction (10%)**: Improved signal quality
- **Dynamic Range (10%)**: Full intensity utilization

In [None]:
# Execute the analysis
def run_complete_preprocessing_analysis(results_dict=None):
    """Evaluate all preprocessing methods and save a JSON summary.

    Steps: compute metrics, print the metrics table, plot the metrics, rank
    the methods, and write ``preprocessing_analysis_results.json`` into
    OUTPUT_DIR.

    Args:
        results_dict: {image name: {method name: image}} to analyse. Defaults
            to the notebook-level ``preprocessing_results`` (backward
            compatible with the original zero-argument call).

    Returns:
        (best_method_preprocessing, metrics_df, method_scores)
    """
    if results_dict is None:
        results_dict = preprocessing_results

    print("MALARIA PREPROCESSING ANALYSIS...")
    print("="*60)

    # 1. Calculate metrics
    print("\n1. Calculating metrics...")
    metrics_df = calculate_preprocessing_metrics(results_dict)

    # 2. Display detailed metrics table; widen pandas output only for this
    #    print instead of changing global display options for the session.
    print(f"\n2. Detailed Metrics Table:")
    print("-" * 60)
    with pd.option_context('display.max_columns', None, 'display.width', None):
        print(metrics_df.round(3))

    # 3. Create metrics visualization
    print(f"\n3. Generating metrics visualization...")
    plot_metrics_comparison(metrics_df)

    # 4. Recommend best method
    print(f"\n4. Method recommendation analysis...")
    best_method_preprocessing, method_scores, detailed_scores = recommend_best_preprocessing_method(metrics_df)

    # 5. Save results (default=str covers any non-JSON-native values)
    results_summary = {
        'best_method_preprocessing': best_method_preprocessing,
        'method_scores': method_scores,
        'detailed_scores': detailed_scores,
        'metrics_table': metrics_df.to_dict('records')
    }

    results_path = os.path.join(OUTPUT_DIR, 'preprocessing_analysis_results.json')
    with open(results_path, 'w') as f:
        json.dump(results_summary, f, indent=2, default=str)

    print(f"\nResults saved to: {results_path}")
    print(f"\nPREPROCESSING ANALYSIS COMPLETE!")
    print(f"Recommended method: {best_method_preprocessing}")
    print(f"Analysis based on {len(metrics_df)} measurements")

    return best_method_preprocessing, metrics_df, method_scores

# Run the analysis only if the preprocessing comparison produced results
if globals().get('preprocessing_results'):
    best_method_preprocessing, metrics_df, scores = run_complete_preprocessing_analysis()
else:
    print("Error: preprocessing_results not found. Run the preprocessing comparison first.")

## 2.5 Parasitemia Scoring Implementation

### 2.5.1 ParasitemiaScorer Initialization
*Initialize the parasitemia scoring system with target image size*


In [None]:
from src.utils.DataUtils import ParasitemiaScorer

# Initialize the parasitemia scorer used by the scoring cells that follow.
# It receives the same TARGET_SIZE that MalariaPreprocessor resizes images to;
# how the scorer uses it (e.g. rescaling annotation coordinates) is defined in
# src.utils.DataUtils -- NOTE(review): confirm there.
print("\nInitializing ParasitemiaScorer...")
parasitemia_scorer = ParasitemiaScorer(target_image_size=TARGET_SIZE)

### 2.5.2 Parasitemia Scoring Methods Testing
*Test all three parasitemia scoring approaches on sample images*

**Three Scoring Approaches Implemented:**

1. **Count-based Method (Clinical Standard)**:
   ```
   Parasitemia = (Infected_Cells / Total_RBCs) × 100%
   ```
   - Most clinically relevant approach
   - Direct percentage of infected red blood cells
   - Standard medical practice metric from WHO

2. **Area-based Method (Severity Weighting)**:
   ```
   Parasitemia = (Infected_Area / Total_Cell_Area) × 100%
   ```
   - Accounts for infection severity (larger parasites = higher score)
   - Good for detecting advanced infections
   - Considers parasite size variations

3. **Density-based Method (Spatial Distribution)**:
   ```
   Parasitemia = (Infected_Objects / Image_Area) × 1000 per 1000px²
   ```
   - Normalizes by image area for consistency
   - Spatial distribution aware
   - Good for different cell density images

In [None]:
# STEP 2: Test parasitemia scoring on sample data
print(f"\n{'=' * 50}")
print("TESTING PARASITEMIA SCORING ON SAMPLE DATA")
print('=' * 50)

def test_parasitemia_scoring_on_samples():
    """Score up to three sampled test images and print a per-image report.

    Reads the notebook-level ``test_images``, ``image_to_anns`` and
    ``parasitemia_scorer``. Returns the list of result dicts produced by
    ``parasitemia_scorer.calculate_all_scores``, one per image that has
    annotations.
    """
    print(f"Testing on {len(test_images)} sample images...")

    collected = []
    for idx, path in enumerate(test_images[:3], start=1):  # Test on up to 3 images
        print(f"\nProcessing Image {idx}: {os.path.basename(path)}")

        if path not in image_to_anns:
            print(f"No annotations found for {path}")
            continue

        image_anns = image_to_anns[path]
        print(f"Found {len(image_anns)} annotations")

        # Break down annotation labels for this image
        label_counts = Counter(a.get('cell_type', 'unknown') for a in image_anns)
        print(f"Cell types found: {dict(label_counts)}")

        scores = parasitemia_scorer.calculate_all_scores(image_anns, path)
        collected.append(scores)

        print(f"\nPARASITEMIA SCORES:")
        print(f"Count-based:   {scores['count_based']['score']:.2f}%")
        print(f"Area-based:    {scores['area_based']['score']:.2f}%")
        print(f"Density-based: {scores['density_based']['score']:.2f} per 1000px²")

        print(f"\nDETAILED BREAKDOWN:")
        print(f"Infected cells: {scores['summary']['infected_cells']}")
        print(f"Healthy cells:  {scores['summary']['healthy_cells']}")
        print(f"Unknown cells:  {scores['summary']['unknown_cells']}")
        print(f"Total annotations: {scores['summary']['total_annotations']}")

    return collected

# Running the test
sample_results = test_parasitemia_scoring_on_samples()

### 2.5.3 Comprehensive Parasitemia Analysis
*Calculate parasitemia scores for all images using all three methods*

In [None]:
# STEP 3: Calculate parasitemia for ALL images
print(f"\n{'=' * 50}")
print("CALCULATING PARASITEMIA FOR ALL IMAGES")
print('=' * 50)

def calculate_parasitemia_for_all_images():
    """Score every annotated image, collecting failures instead of stopping.

    Returns:
        (all_results, failed_images): the list of result dicts from
        ``parasitemia_scorer.calculate_all_scores`` and a list of
        ``(image_path, error_message)`` tuples for images that raised.
    """
    print(f"Processing {len(image_to_anns)} total images...")

    all_results, failed_images = [], []
    progress = tqdm(image_to_anns.items(), desc="Calculating parasitemia")
    for img_path, annotations in progress:
        try:
            scored = parasitemia_scorer.calculate_all_scores(annotations, img_path)
        except Exception as err:
            # Deliberate best effort: record the failure and keep going
            message = str(err)
            print(f"Failed to process {os.path.basename(img_path)}: {message}")
            failed_images.append((img_path, message))
        else:
            all_results.append(scored)

    print(f"\nSuccessfully processed {len(all_results)} images")
    if failed_images:
        print(f"Failed to process {len(failed_images)} images")
        # Show only the first few failures to keep the output short
        for img_path, error in failed_images[:5]:
            print(f"  - {os.path.basename(img_path)}: {error}")

    return all_results, failed_images

# Calculate for all images
print("Starting comprehensive parasitemia calculation...")
all_parasitemia_results, failed_images = calculate_parasitemia_for_all_images()

### 2.5.4 Parasitemia Methods Comparison and Visualization
*Generate comprehensive analysis plots comparing all scoring methods*

**Analysis Includes:**
- Score distribution histograms
- Box plot comparisons
- Correlation analysis between methods
- Statistical summary tables

In [None]:
# STEP 4: Analyze and compare all scoring methods
print(f"\n{'=' * 50}")
print("SCORING METHODS ANALYSIS")
print('=' * 50)

def create_parasitemia_analysis_plots(analysis_results):
    if not analysis_results:
        print("No analysis results to plot!")
        return
    
    stats = analysis_results['statistics']
    all_results = analysis_results['all_results']
    
    # Data for plotting
    count_scores = [r['count_based']['score'] for r in all_results]
    area_scores = [r['area_based']['score'] for r in all_results]
    density_scores = [r['density_based']['score'] for r in all_results]
    
    # Create plot
    fig, axes = plt.subplots(2, 3, figsize=(20, 12))
    fig.suptitle('Parasitemia Scoring Methods - Comprehensive Analysis', 
                 fontsize=16, fontweight='bold')
    
    # 1. Score distributions
    ax = axes[0, 0]
    ax.hist(count_scores, bins=20, alpha=0.7, label='Count-based', color='blue')
    ax.hist(area_scores, bins=20, alpha=0.7, label='Area-based', color='red')
    ax.hist(density_scores, bins=20, alpha=0.7, label='Density-based', color='green')
    ax.set_title('Score Distributions')
    ax.set_xlabel('Parasitemia Score')
    ax.set_ylabel('Frequency')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 2. Box plots comparison
    ax = axes[0, 1]
    data_to_plot = [count_scores, area_scores, density_scores]
    labels = ['Count-based', 'Area-based', 'Density-based']
    box_plot = ax.boxplot(data_to_plot, labels=labels, patch_artist=True)
    colors = ['lightblue', 'lightcoral', 'lightgreen']
    for patch, color in zip(box_plot['boxes'], colors):
        patch.set_facecolor(color)
    ax.set_title('Score Distributions Comparison')
    ax.set_ylabel('Parasitemia Score')
    ax.grid(True, alpha=0.3)
    
    # 3. Count vs Area correlation
    ax = axes[0, 2]
    ax.scatter(count_scores, area_scores, alpha=0.6, color='purple')
    ax.set_title('Count-based vs Area-based Scores')
    ax.set_xlabel('Count-based Score (%)')
    ax.set_ylabel('Area-based Score (%)')
    ax.grid(True, alpha=0.3)
    
    # correlation coefficient
    corr = analysis_results['correlations']['count_vs_area']['correlation']
    ax.text(0.05, 0.95, f'r = {corr:.3f}', transform=ax.transAxes, 
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    # 4. Count vs Density correlation
    ax = axes[1, 0]
    ax.scatter(count_scores, density_scores, alpha=0.6, color='orange')
    ax.set_title('Count-based vs Density-based Scores')
    ax.set_xlabel('Count-based Score (%)')
    ax.set_ylabel('Density-based Score (per 1000px²)')
    ax.grid(True, alpha=0.3)
    
    corr = analysis_results['correlations']['count_vs_density']['correlation']
    ax.text(0.05, 0.95, f'r = {corr:.3f}', transform=ax.transAxes,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    # 5. Area vs Density correlation
    ax = axes[1, 1]
    ax.scatter(area_scores, density_scores, alpha=0.6, color='brown')
    ax.set_title('Area-based vs Density-based Scores')
    ax.set_xlabel('Area-based Score (%)')
    ax.set_ylabel('Density-based Score (per 1000px²)')
    ax.grid(True, alpha=0.3)
    
    corr = analysis_results['correlations']['area_vs_density']['correlation']
    ax.text(0.05, 0.95, f'r = {corr:.3f}', transform=ax.transAxes,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    # 6. Summary stats table
    ax = axes[1, 2]
    ax.axis('off')
    
    table_data = []
    for method, stat in stats.items():
        table_data.append([
            method,
            f"{stat['mean']:.2f}",
            f"{stat['std']:.2f}",
            f"{stat['min']:.2f}",
            f"{stat['max']:.2f}"
        ])
    
    table = ax.table(cellText=table_data,
                    colLabels=['Method', 'Mean', 'Std', 'Min', 'Max'],
                    cellLoc='center',
                    loc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2)
    ax.set_title('Summary Statistics')
    
    plt.tight_layout()
    plt.show()

# Run analysis: `analyze_scoring_methods` aggregates the per-image results.
# The plotting function below reads its 'statistics', 'all_results' and
# 'correlations' entries; a falsy result skips plotting.
print("Analysis of all scoring methods...")
analysis_results = parasitemia_scorer.analyze_scoring_methods(all_parasitemia_results)

if analysis_results:
    print("\nGenerating analysis plots...")
    create_parasitemia_analysis_plots(analysis_results)
else:
    print("Analysis failed - no results to analyze")

### 2.5.5 Parasitemia Method Selection
*Systematic evaluation and recommendation of best parasitemia scoring method*

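**Selection Criteria:**
- **Dynamic Range (25%)**: Ability to discriminate infection levels, measured as $(\max - \min)/\max$ of the method's scores
- **Coefficient of Variation (25%)**: Rewards a moderate spread: the sub-score $\max(0,\ 1 - |CV - 0.5|)$, with $CV = \sigma/\mu$, peaks at $CV = 0.5$ (neither too uniform nor too variable)
- **Detection Rate (25%)**: Fraction of images for which the method gives a non-zero score
- **Clinical Relevance (25%)**: Fixed prior reflecting alignment with medical practice (count-based 1.0, area-based 0.8, density-based 0.6)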


In [None]:
# STEP 5: Choose the best method and create final dataset
print("\n" + "="*50)
print("METHOD SELECTION AND FINAL DATASET CREATION")
print("="*50)

def recommend_best_parasitemia_method(analysis_results):
    """Rank parasitemia scoring methods with a weighted heuristic and pick the best.

    Each method in ``analysis_results['statistics']`` gets four sub-scores,
    each weighted 0.25:
      * dynamic range  ``(max - min) / max`` (0 when ``max <= 0``)
      * CV score       ``max(0, 1 - |CV - 0.5|)`` with ``CV = std / mean``
                       (CV taken as 0 when ``mean <= 0``)
      * detection rate fraction of images whose score for the method is > 0
      * clinical prior 1.0 if the name contains "count", 0.8 if "area", else 0.6

    Args:
        analysis_results: dict with ``'statistics'`` (method name -> dict with
            ``mean``/``std``/``min``/``max``) and ``'all_results'`` (per-image
            dicts keyed by the method name lower-cased with ``-`` replaced by
            ``_``, e.g. ``'Count-based'`` -> ``'count_based'``).

    Returns:
        ``(best_method_name, method_scores)``. When there is nothing to rank,
        returns ``(None, {})`` so callers can always unpack two values.
    """
    if not analysis_results:
        print("No analysis results available!")
        return None, {}

    stats = analysis_results.get('statistics') or {}
    all_results = analysis_results.get('all_results') or []
    if not stats:
        # max() below would raise on an empty mapping.
        print("No per-method statistics available!")
        return None, {}

    n_results = len(all_results)

    print("METHOD RECOMMENDATION ANALYSIS:")
    print("="*40)

    method_scores = {}

    for method, stat in stats.items():
        print(f"\n{method.upper()} EVALUATION:")
        method_lower = method.lower()
        result_key = method_lower.replace('-', '_')

        # 1. Dynamic Range (higher is better for discrimination)
        range_score = (stat['max'] - stat['min']) / stat['max'] if stat['max'] > 0 else 0
        print(f"Dynamic Range: {range_score:.3f} (range: {stat['min']:.2f}-{stat['max']:.2f})")

        # 2. Coefficient of Variation (moderate is best: optimal CV around 0.5)
        cv = stat['std'] / stat['mean'] if stat['mean'] > 0 else 0
        cv_score = max(0, 1 - abs(cv - 0.5))
        print(f"Coefficient of Variation: {cv:.3f} (score: {cv_score:.3f})")

        # 3. Detection rate: share of images with a non-zero score.
        #    Guarded so an empty result list yields 0 instead of ZeroDivisionError.
        non_zero_count = sum(1 for r in all_results if r[result_key]['score'] > 0)
        non_zero_ratio = non_zero_count / n_results if n_results else 0.0
        print(f"Detection Rate: {non_zero_ratio:.3f} ({non_zero_count}/{n_results})")

        # 4. Clinical relevance (fixed prior per method family)
        if 'count' in method_lower:
            clinical_score = 1.0
        elif 'area' in method_lower:
            clinical_score = 0.8
        else:
            clinical_score = 0.6
        print(f"Clinical Relevance: {clinical_score:.3f}")

        # 5. Equal-weight total
        total_score = (range_score * 0.25 + cv_score * 0.25 +
                       non_zero_ratio * 0.25 + clinical_score * 0.25)

        method_scores[method] = {
            'total_score': total_score,
            'range_score': range_score,
            'cv_score': cv_score,
            'detection_rate': non_zero_ratio,
            'clinical_score': clinical_score
        }

        print(f"TOTAL SCORE: {total_score:.3f}")

    best_method_parasitemia = max(method_scores, key=lambda m: method_scores[m]['total_score'])
    best_score = method_scores[best_method_parasitemia]['total_score']

    print(f"\n{'='*40}")
    print("FINAL RECOMMENDATION:")
    print(f"{'='*40}")
    print(f"BEST METHOD: {best_method_parasitemia}")
    print(f"SCORE: {best_score:.3f}")

    # Methods interpretation
    best_lower = best_method_parasitemia.lower()
    print("\nREASONING:")
    if 'count' in best_lower:
        print("Count-based method is most clinically relevant, Directly represents percentage of infected cells, Standard approach in medical practice, Easy to interpret and validate")
    elif 'area' in best_lower:
        print("Area-based method accounts for infection severity, Larger parasites contribute more to score, Good for detecting advanced infections")
    else:
        print("Density-based method normalizes by image area, Good for comparing images with different cell densities, Spatial distribution aware")

    return best_method_parasitemia, method_scores

# Get recommendation. `or (None, {})` keeps the two-value unpacking safe if the
# recommender returns a bare None when there is nothing to rank.
best_method_parasitemia, method_scores = recommend_best_parasitemia_method(analysis_results) or (None, {})

In [None]:
#best_method_preprocessing = 'Color Deconv'
#best_method_preprocessing = 'CLAHE'
#best_method_preprocessing = 'Resized'

## 2.6 Final Dataset Creation Pipeline

### 2.6.1 Final Dataset Creation with Chosen Methods
*Create the complete dataset using the selected preprocessing and scoring methods*

**Pipeline Steps:**
1. **Image Processing**: Apply chosen preprocessing method
2. **Annotation Resizing**: Scale annotations to the target size (`TARGET_SIZE`, currently 512×512)
3. **Parasitemia Calculation**: Apply chosen scoring method
4. **Infection Level Categorization**:
   - **Negative**: 0% parasitemia
   - **Low**: >0–2% parasitemia  
   - **Moderate**: >2–10% parasitemia
   - **High**: >10% parasitemia

In [None]:
# STEP 6: We will create final dataset with chosen method
print("\n" + "="*60)
print("CREATING FINAL DATASET WITH CHOSEN METHOD")
print("="*60)

def _infection_level_for_score(parasitemia_score):
    """Map a parasitemia score to ``(infection_level, infection_category)``.

    Thresholds: 0 -> negative/0, (0, 2] -> low/1, (2, 10] -> moderate/2, >10 -> high/3.
    """
    if parasitemia_score == 0:
        return 'negative', 0
    if parasitemia_score <= 2:
        return 'low', 1
    if parasitemia_score <= 10:
        return 'moderate', 2
    return 'high', 3


def _apply_preprocessing(preprocessor, resized_img, method):
    """Apply the named preprocessing pipeline to an already-resized image.

    'CLAHE' applies CLAHE only; 'Macenko', 'Reinhard' and 'Color Deconv' apply
    CLAHE first and then the corresponding normalization step. Any other name
    returns the resized image unchanged.
    """
    if method == 'CLAHE':
        return preprocessor.clahe_normalization(resized_img.copy())
    if method in ('Macenko', 'Reinhard', 'Color Deconv'):
        clahe_img = preprocessor.clahe_normalization(resized_img.copy())
        if method == 'Macenko':
            return preprocessor.macenko_normalization(clahe_img.copy())
        if method == 'Reinhard':
            return preprocessor.reinhard_normalization(clahe_img.copy())
        return preprocessor.enhanced_color_deconvolution(clahe_img.copy())
    return resized_img


def create_final_dataset(best_method, all_results, best_preprocessing_method):
    """Build the final dataset records from the scored images.

    For each result: load the image as RGB, resize it to ``TARGET_SIZE``, apply
    the chosen preprocessing, rescale its original annotations (when present
    in ``image_to_anns``), and label it with the chosen parasitemia score.

    Args:
        best_method: Scoring method name such as ``'Count-based'``; lower-cased
            with ``-`` replaced by ``_`` to index each result dict.
        all_results: Per-image scorer outputs, each with ``'image_path'``,
            ``'summary'`` and ``count_based``/``area_based``/``density_based``
            entries holding a ``'score'``.
        best_preprocessing_method: 'CLAHE', 'Macenko', 'Reinhard' or
            'Color Deconv'; any other value keeps the plain resized image.

    Returns:
        List of record dicts, including the processed image array under
        ``'processed_image'``. Images that raise are skipped and reported.

    Raises:
        ValueError: if ``best_method`` is empty/None.

    Side effects:
        Writes one JSON file per annotated image to
        ``OUTPUT_DIR/resized_annotations``.
    """
    print(f"Using {best_method} parasitemia scoring method")
    print(f"Using {best_preprocessing_method} preprocessing method")

    if not best_method:
        raise ValueError("best_method must name a scoring method, e.g. 'Count-based'")
    # Key of the chosen method inside each result ('Count-based' -> 'count_based').
    method_key = best_method.lower().replace('-', '_')

    preprocessor = MalariaPreprocessor(target_size=TARGET_SIZE)

    dataset_records = []
    failed_processing = []

    resized_annotations_dir = os.path.join(OUTPUT_DIR, 'resized_annotations')
    os.makedirs(resized_annotations_dir, exist_ok=True)

    print(f"\nProcessing {len(all_results)} images for final dataset...")

    for i, result in enumerate(tqdm(all_results, desc="Creating dataset")):
        # Resolve the path before the try block so the failure handler can always
        # report it. Previously a result without 'image_path' left `img_path`
        # unbound (or stale from the previous iteration) inside `except`.
        img_path = result.get('image_path', f"<result {i}>")
        try:
            # The context manager releases the file handle even if decoding fails.
            with Image.open(img_path) as pil_img:
                original_img = np.array(pil_img.convert('RGB'))
            original_size = original_img.shape

            resized_img = preprocessor.resize_image(original_img, maintain_aspect=False)
            processed_img = _apply_preprocessing(preprocessor, resized_img,
                                                 best_preprocessing_method)

            # Resize and save annotations
            resized_annotations = []
            resized_annotations_path = None
            if img_path in image_to_anns:
                resized_annotations = preprocessor.resize_annotations_to_target_size(
                    image_to_anns[img_path], original_size, TARGET_SIZE
                )
                resized_annotations_path = os.path.join(
                    resized_annotations_dir, f"nlm_{i:04d}_resized_annotations.json")
                with open(resized_annotations_path, 'w') as f:
                    json.dump(resized_annotations, f, indent=2, default=str)

            parasitemia_score = result[method_key]['score']
            infection_level, infection_category = _infection_level_for_score(parasitemia_score)
            summary = result['summary']

            dataset_records.append({
                'image_id': f"nlm_{i:04d}",
                'original_path': img_path,
                'image_name': os.path.basename(img_path),
                'processed_image': processed_img,
                'parasitemia_score': parasitemia_score,
                'infection_level': infection_level,
                'infection_category': infection_category,
                'cell_counts': {
                    'infected': summary['infected_cells'],
                    'healthy': summary['healthy_cells'],
                    'wbc': summary['wbc_cells'],
                    'total_annotations': summary['total_annotations']
                },
                'preprocessing_method': best_preprocessing_method,
                'scoring_method': best_method,
                'all_scores': {
                    'count_based': result['count_based']['score'],
                    'area_based': result['area_based']['score'],
                    'density_based': result['density_based']['score']
                },
                'resized_annotations': resized_annotations,
                'resized_annotations_path': resized_annotations_path
            })

        except Exception as e:
            # Best-effort: skip this image and keep the error for the report below.
            failed_processing.append((img_path, str(e)))

    print(f"\nDataset creation complete!")
    print(f"Successfully processed: {len(dataset_records)} images")
    print(f"Failed to process: {len(failed_processing)} images")
    print(f"Resized annotations saved to: {resized_annotations_dir}")

    if failed_processing:
        print("\nFirst 3 processing failures:")
        for failed_path, error in failed_processing[:3]:
            print(f"  - {os.path.basename(failed_path)}: {error}")

    return dataset_records


def create_resized_annotation_lookup(final_dataset, output_dir=None):
    """Build and persist an ``image_id -> resized annotations`` lookup.

    Only records with a non-empty ``'resized_annotations'`` entry are included.
    The lookup is written as JSON to
    ``<output_dir>/resized_image_to_annotations_lookup.json``.

    Args:
        final_dataset: Records as produced by ``create_final_dataset``.
        output_dir: Destination directory (created if missing); defaults to the
            notebook-level ``OUTPUT_DIR``.

    Returns:
        The lookup dict.
    """
    print("\n" + "="*50)
    print("CREATING RESIZED ANNOTATION LOOKUP")
    print("="*50)

    if output_dir is None:
        output_dir = OUTPUT_DIR

    # Keyed by image_id, which is also the stem of the saved processed image.
    resized_image_to_anns = {
        record['image_id']: record['resized_annotations']
        for record in final_dataset
        if record.get('resized_annotations')
    }

    print(f"Created resized annotation lookup for {len(resized_image_to_anns)} images")

    os.makedirs(output_dir, exist_ok=True)
    lookup_path = os.path.join(output_dir, 'resized_image_to_annotations_lookup.json')
    with open(lookup_path, 'w') as f:
        json.dump(resized_image_to_anns, f, indent=2, default=str)

    print(f"Saved resized annotation lookup: {lookup_path}")

    return resized_image_to_anns


def analyze_final_dataset(dataset_records):
    """Print summary statistics for the final dataset and plot a 2x2 overview.

    Panels: infection-level pie chart, parasitemia-score histogram, per-level
    score box plots, and total cell counts by type.

    Args:
        dataset_records: Records from ``create_final_dataset`` (each with
            ``parasitemia_score``, ``infection_level`` and ``cell_counts``).

    Returns:
        Dict with ``total_images``, ``level_distribution`` (Counter),
        ``score_stats`` (mean/std/min/max/median) and ``cell_totals``.
        For an empty dataset the score statistics are ``None`` and no figure
        is drawn.
    """
    print("\n" + "="*50)
    print("FINAL DATASET ANALYSIS")
    print("="*50)

    # Canonical level order and per-level colours: colours follow the level name,
    # not the order in which levels happen to first appear in the records.
    level_order = ['negative', 'low', 'moderate', 'high']
    pie_colors = {'negative': 'green', 'low': 'yellow', 'moderate': 'orange', 'high': 'red'}
    box_colors = {'negative': 'lightgreen', 'low': 'lightyellow', 'moderate': 'orange', 'high': 'lightcoral'}

    n_images = len(dataset_records)
    parasitemia_scores = [r['parasitemia_score'] for r in dataset_records]
    level_counts = Counter(r['infection_level'] for r in dataset_records)
    # Known levels first in canonical order, then any unexpected labels.
    ordered_levels = ([lvl for lvl in level_order if lvl in level_counts]
                      + [lvl for lvl in level_counts if lvl not in level_order])

    total_infected = sum(r['cell_counts']['infected'] for r in dataset_records)
    total_healthy = sum(r['cell_counts']['healthy'] for r in dataset_records)
    total_wbc = sum(r['cell_counts']['wbc'] for r in dataset_records)
    total_annotations = sum(r['cell_counts']['total_annotations'] for r in dataset_records)
    cell_totals = {
        'infected': total_infected,
        'healthy': total_healthy,
        'wbc': total_wbc,
        'total': total_annotations
    }

    print(f"DATASET SIZE: {n_images} images")
    if n_images == 0:
        # np.min / np.max raise on empty input; report and skip plotting instead.
        print("Dataset is empty - nothing to analyze.")
        return {
            'total_images': 0,
            'level_distribution': level_counts,
            'score_stats': dict.fromkeys(['mean', 'std', 'min', 'max', 'median']),
            'cell_totals': cell_totals
        }

    print("\nINFECTION LEVEL DISTRIBUTION:")
    for level in ordered_levels:
        count = level_counts[level]
        percentage = (count / n_images) * 100
        print(f"  {level:10}: {count:4d} images ({percentage:5.1f}%)")

    score_stats = {
        'mean': np.mean(parasitemia_scores),
        'std': np.std(parasitemia_scores),
        'min': np.min(parasitemia_scores),
        'max': np.max(parasitemia_scores),
        'median': np.median(parasitemia_scores)
    }
    print("\nPARASITEMIA SCORE STATISTICS:")
    print(f"  Mean:     {score_stats['mean']:.3f}")
    print(f"  Std:      {score_stats['std']:.3f}")
    print(f"  Min:      {score_stats['min']:.3f}")
    print(f"  Max:      {score_stats['max']:.3f}")
    print(f"  Median:   {score_stats['median']:.3f}")

    print("\nCELL COUNT STATISTICS:")
    print(f"  Total infected cells:  {total_infected:,}")
    print(f"  Total healthy cells:   {total_healthy:,}")
    print(f"  Total WBC cells:       {total_wbc:,}")
    print(f"  Total annotations:     {total_annotations:,}")

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Final Dataset Analysis', fontsize=16, fontweight='bold')

    # 1. Infection level distribution
    ax = axes[0, 0]
    ax.pie([level_counts[lvl] for lvl in ordered_levels], labels=ordered_levels,
           autopct='%1.1f%%', colors=[pie_colors.get(lvl, 'gray') for lvl in ordered_levels],
           startangle=90)
    ax.set_title('Infection Level Distribution')

    # 2. Parasitemia score histogram
    ax = axes[0, 1]
    ax.hist(parasitemia_scores, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
    ax.set_title('Parasitemia Score Distribution')
    ax.set_xlabel('Parasitemia Score (%)')
    ax.set_ylabel('Frequency')
    ax.grid(True, alpha=0.3)

    # 3. Score vs infection level boxplot (canonical levels only)
    ax = axes[1, 0]
    box_levels = [lvl for lvl in level_order if lvl in level_counts]
    level_data = [[r['parasitemia_score'] for r in dataset_records if r['infection_level'] == lvl]
                  for lvl in box_levels]
    box_plot = ax.boxplot(level_data, patch_artist=True)
    # boxplot's `labels=` is deprecated since Matplotlib 3.9; boxes sit at x = 1..n.
    ax.set_xticks(range(1, len(box_levels) + 1))
    ax.set_xticklabels(box_levels)
    for patch, level in zip(box_plot['boxes'], box_levels):
        patch.set_facecolor(box_colors[level])
    ax.set_title('Parasitemia Scores by Infection Level')
    ax.set_ylabel('Parasitemia Score (%)')
    ax.grid(True, alpha=0.3)

    # 4. Cell count distribution
    ax = axes[1, 1]
    cell_types = ['Infected', 'Healthy', 'WBC']
    type_totals = [total_infected, total_healthy, total_wbc]
    bars = ax.bar(cell_types, type_totals, color=['red', 'blue', 'purple'], alpha=0.7)
    ax.set_title('Total Cell Counts by Type')
    ax.set_ylabel('Cell Count')
    ax.grid(True, alpha=0.3)
    for bar, count in zip(bars, type_totals):
        ax.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                f'{count:,}', ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

    return {
        'total_images': n_images,
        'level_distribution': level_counts,
        'score_stats': score_stats,
        'cell_totals': cell_totals
    }


### 2.6.2 Resized Annotation Lookup Creation
*Create lookup tables for resized annotations to preserve spatial relationships*

### 2.6.3 Final Dataset Analysis and Visualization
*Comprehensive analysis of the created dataset*

**Analysis Components:**
- **Size Distribution**: Images per infection level
- **Parasitemia Statistics**: Mean, std, min, max, median
- **Cell Count Analysis**: Total infected, healthy, WBC counts
- **Visual Distribution**: Pie charts, histograms, box plots

In [None]:
def _as_uint8_image(image):
    """Return ``image`` as a uint8 array that ``PIL.Image.fromarray`` can save.

    uint8 input is returned unchanged. Otherwise, following the same heuristic
    as ``draw_annotations_on_image``: arrays whose max is <= 1.0 are treated as
    [0, 1] and scaled by 255; anything else is clipped to [0, 255].
    """
    image = np.asarray(image)
    if image.dtype == np.uint8:
        return image
    if image.size and image.max() <= 1.0:
        return np.clip(image * 255, 0, 255).astype(np.uint8)
    return np.clip(image, 0, 255).astype(np.uint8)


def save_final_dataset(dataset_records, dataset_stats):
    """Write processed images, metadata, statistics and a CSV summary to disk.

    Layout under ``OUTPUT_DIR``:
      * ``processed_images/<image_id>.png``    processed image per record
      * ``metadata/dataset_metadata.json``     records without the image array,
        plus ``processed_image_path``
      * ``metadata/dataset_statistics.json``   ``dataset_stats``
      * ``metadata/dataset_summary.csv``       one flat row per record

    Args:
        dataset_records: Records from ``create_final_dataset``.
        dataset_stats: Statistics dict (e.g. from ``analyze_final_dataset``).

    Returns:
        Dict with ``images_dir``, ``metadata_path``, ``csv_path`` and ``stats_path``.
    """
    print("\n" + "="*50)
    print("SAVING FINAL DATASET")
    print("="*50)

    images_dir = os.path.join(OUTPUT_DIR, 'processed_images')
    metadata_dir = os.path.join(OUTPUT_DIR, 'metadata')
    os.makedirs(images_dir, exist_ok=True)
    # Don't rely on an earlier cell having created the metadata directory.
    os.makedirs(metadata_dir, exist_ok=True)

    metadata_records = []

    print("Saving processed images...")
    for record in tqdm(dataset_records, desc="Saving images"):
        img_path = os.path.join(images_dir, f"{record['image_id']}.png")
        # Non-uint8 arrays (e.g. float output of a normalization step) cannot be
        # saved as RGB PNG by PIL directly, so convert them first.
        Image.fromarray(_as_uint8_image(record['processed_image'])).save(img_path, 'PNG')

        # Metadata record without the (large) image array itself
        metadata_record = {k: v for k, v in record.items() if k != 'processed_image'}
        metadata_record['processed_image_path'] = img_path
        metadata_records.append(metadata_record)

    metadata_path = os.path.join(metadata_dir, 'dataset_metadata.json')
    with open(metadata_path, 'w') as f:
        json.dump(metadata_records, f, indent=2, default=str)

    stats_path = os.path.join(metadata_dir, 'dataset_statistics.json')
    with open(stats_path, 'w') as f:
        json.dump(dataset_stats, f, indent=2, default=str)

    # Flat CSV for easy analysis
    csv_rows = [
        {
            'image_id': record['image_id'],
            'image_name': record['image_name'],
            'parasitemia_score': record['parasitemia_score'],
            'infection_level': record['infection_level'],
            'infection_category': record['infection_category'],
            'infected_cells': record['cell_counts']['infected'],
            'healthy_cells': record['cell_counts']['healthy'],
            'wbc_cells': record['cell_counts']['wbc'],
            'total_annotations': record['cell_counts']['total_annotations'],
            'preprocessing_method': record['preprocessing_method'],
            'scoring_method': record['scoring_method']
        }
        for record in metadata_records
    ]
    csv_path = os.path.join(metadata_dir, 'dataset_summary.csv')
    pd.DataFrame(csv_rows).to_csv(csv_path, index=False)

    print(f"\nDataset saved successfully!")
    print(f"  Images: {images_dir}")
    print(f"  Metadata: {metadata_path}")
    print(f"  Statistics: {stats_path}")
    print(f"  CSV Summary: {csv_path}")
    print(f"  Total files: {len(dataset_records)} images + metadata")

    return {
        'images_dir': images_dir,
        'metadata_path': metadata_path,
        'csv_path': csv_path,
        'stats_path': stats_path
    }

# Dataset creation
print("Starting final dataset creation pipeline...")

# Fall back to defaults when an earlier cell did not set a method. `globals()`
# is explicit (at notebook top level it is the namespace `locals()` returns),
# and a falsy value such as None also triggers the fallback instead of being
# passed on to create_final_dataset.
if not globals().get('best_method_parasitemia'):
    best_method_parasitemia = 'Count-based'
if not globals().get('best_method_preprocessing'):
    best_method_preprocessing = 'CLAHE'

print(f"Using parasitemia method: {best_method_parasitemia}")
print(f"Using preprocessing method: {best_method_preprocessing}")

### 2.6.4 Dataset Saving and Metadata Creation
*Save processed dataset with comprehensive metadata*

**Saved Components:**
- **Processed Images**: Enhanced microscopy images (PNG format)
- **Metadata JSON**: Complete dataset information
- **Statistics JSON**: Dataset statistics and distributions  
- **CSV Summary**: Tabular data for easy analysis

In [None]:
# Create final dataset: in-memory records (processed image arrays included)
# for every scored image, using the selected scoring and preprocessing methods.
# NOTE(review): `create_resized_annotation_lookup` is defined above but not
# called in this pipeline; confirm whether the lookup JSON should be written here.
final_dataset = create_final_dataset(best_method_parasitemia, all_parasitemia_results, best_method_preprocessing)

# Analyze dataset: prints statistics, draws the overview figure and returns
# the statistics dict that is saved alongside the data.
dataset_stats = analyze_final_dataset(final_dataset)

# Save everything: PNG images, metadata/statistics JSON and a CSV summary
# under OUTPUT_DIR; returns the written paths.
save_paths = save_final_dataset(final_dataset, dataset_stats)

print("\n" + "="*60)
print("DATASET CREATION COMPLETE!")
print("="*60)
print(f"Final dataset contains {len(final_dataset)} processed images")
print(f"Ready for MTTL training pipeline")
print(f"All files saved to: {os.path.abspath(OUTPUT_DIR)}")

## 2.7 Training Data Preparation (Multi-Task & Single-Task Compatible)

### 2.7.1 Training Pipeline Setup
*Set up the training data preparation pipeline with PyTorch integration*

### 2.7.2 Multi-Task Data Preparation
*Prepare data for all three MTTL tasks simultaneously*

**Task Preparation:**
1. **Detection Task**: Bounding box annotation processing
2. **Regression Task**: Parasitemia score normalization  
3. **Localization Task**: Infection-only heatmap generation

### 2.7.3 Train/Validation/Test Split Creation
*Create stratified splits maintaining infection level distribution*

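**Split Strategy:**
- **Training**: 70% (stratified by infection level)
- **Validation**: 15% (stratified by infection level)
- **Testing**: 15% (stratified by infection level)

*Note: `create_train_val_test_split` defaults to an 80/10/10 split; pass `train_ratio=0.7, val_ratio=0.15, test_ratio=0.15` to obtain the split described above.*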

### 2.7.4 Task Mode Testing
*Verify all MTTL task modes work correctly with the dataset*

**Task Modes Tested:**
- **Detection Mode**: Object detection with bounding boxes
- **Regression Mode**: Parasitemia score prediction
- **Localization Mode**: Infection heatmap generation
- **Multi-Task Mode**: All tasks combined
- **Segmentation / Severity Modes**: Also exercised by `test_task_modes`

### 2.7.5 Infection Localization Heatmap Visualization
*Visualize and analyze the infection-only heatmap generation*

**Heatmap Features:**
- **Infection-Only Focus**: Only infected regions highlighted (red colormap)
- **Spatial Accuracy**: Precise infection localization
- **Intensity Mapping**: Variable intensity based on infection density
- **Clinical Relevance**: Direct visual feedback for diagnosis

### 2.7.6 Heatmap Analysis by Infection Level
*Comprehensive analysis of heatmap quality across different parasitemia levels*

**Analysis Categories:**
- **Negative (0%)**: No infection heatmap (validation)
- **Low (0-2%)**: Sparse infection patterns
- **Moderate (2-10%)**: Moderate infection coverage
- **High (>10%)**: Dense infection patterns

### 2.7.7 Comprehensive Heatmap Statistical Analysis
*Quantitative analysis of infection heatmap characteristics*

**Statistical Metrics:**
- **Coverage Percentage**: Infection pixel ratio
- **Intensity Statistics**: Max, mean, standard deviation
- **Precision Analysis**: Pixels per infected cell ratio
- **Spatial Distribution**: Infection pattern analysis

In [None]:
# STEP 7: Prepare Data for Training (Multi-Task & Single-Task Compatible) ===
print("\n" + "="*70)
print("PREPARING DATA FOR TRAINING - MULTI-TASK & SINGLE-TASK COMPATIBLE")
print("="*70)

# NOTE(review): most imports below repeat the setup cell at the top of the
# notebook, and `random`, `numpy` and `tqdm` are imported twice within this
# cell. Only the torch/torchvision imports and `FlexibleMalariaDataset` are new.
# Presumably kept so this section can run on its own; TODO consolidate into
# the top imports cell.
#import sys
#import os
#sys.path.append('..')
import glob
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_context('notebook')
from PIL import Image
import cv2
from collections import Counter, defaultdict
from typing import List, Dict, Tuple
from tqdm.notebook import tqdm
import json
import pickle
from sklearn.model_selection import train_test_split
from scipy.ndimage import gaussian_filter
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import random
import numpy as np
from tqdm.notebook import tqdm
from src.utils.DataUtils import MalariaDataPreparator, FlexibleMalariaDataset, set_seeds
# Same value as the earlier set_seeds(12) call. SEED is also captured as the
# default `random_state` of create_train_val_test_split when that function is
# defined, so it must be set before that cell runs.
SEED = 12
set_seeds(SEED)

In [None]:

def draw_annotations_on_image(ax, image, annotations):
    """Show ``image`` on ``ax`` and overlay point / polygon cell annotations.

    Non-uint8 images are converted for display. Arrays whose max is <= 1.0 are
    scaled by 255; anything else is clipped to [0, 255]. Annotation coordinates
    are clamped to the image bounds. Each annotation is coloured and labelled by
    its ``cell_type``: P = parasitized (red), U = uninfected (lime), W = white
    blood cell (blue), and unknown types are drawn in yellow with "?". The axis
    frame is switched off at the end.

    Args:
        ax: Matplotlib axes to draw on.
        image: Image array indexed as (rows, cols[, channels]).
        annotations: Iterable of dicts with ``shape`` ('Point' or 'Polygon'),
            ``cell_type`` and either ``x``/``y`` or ``polygon`` coordinates.
    """
    if image.dtype != np.uint8:
        scaled = (image * 255) if image.max() <= 1.0 else np.clip(image, 0, 255)
        image = scaled.astype(np.uint8)

    height, width = image.shape[0], image.shape[1]
    max_x, max_y = width - 1, height - 1

    ax.imshow(image)
    ax.set_xlim(0, width)
    ax.set_ylim(height, 0)

    # (colour, short label) per cell type; lower-case and capitalised spellings
    # are both accepted.
    styles = {
        'parasitized': ('red', 'P'),
        'Parasitized': ('red', 'P'),
        'uninfected': ('lime', 'U'),
        'Uninfected': ('lime', 'U'),
        'white_blood_cell': ('blue', 'W'),
        'White_Blood_Cell': ('blue', 'W'),
    }

    for ann in annotations:
        color, short_label = styles.get(ann.get('cell_type', 'unknown'), ('yellow', '?'))
        shape = ann.get('shape')

        if shape == 'Point':
            px = max(0, min(ann.get('x', 0), max_x))
            py = max(0, min(ann.get('y', 0), max_y))
            ax.plot(px, py, 'o', color=color, markersize=6, markeredgewidth=1.5, markeredgecolor='black')
            # Offset the label from the marker but keep it inside the frame.
            ax.text(min(px + 5, width - 15), max(py - 5, 15), short_label,
                    color=color, fontsize=8, weight='bold',
                    bbox=dict(facecolor='white', alpha=0.8, edgecolor=color, boxstyle='round,pad=0.2'))

        elif shape == 'Polygon' and 'polygon' in ann:
            clipped = [(max(0, min(x, max_x)), max(0, min(y, max_y))) for x, y in ann['polygon']]
            if not clipped:
                continue
            ax.add_patch(patches.Polygon(np.array(clipped), closed=True, fill=False,
                                         edgecolor=color, linewidth=1.5))
            xs, ys = zip(*clipped)
            # Label at the vertex mean of the clipped outline.
            ax.text(np.mean(xs), np.mean(ys), short_label, color=color, fontsize=8, weight='bold',
                    bbox=dict(facecolor='white', alpha=0.8, edgecolor=color, boxstyle='round,pad=0.2'),
                    ha='center', va='center')

    ax.axis('off')

def create_train_val_test_split(training_samples, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, random_state=SEED):
    """Split samples into train/val/test sets, stratified by parasitemia level.

    Samples are grouped by ``sample['regression']['parasitemia_score']`` using
    the dataset's level thresholds (0 -> negative, <=2 -> low, <=10 -> moderate,
    otherwise high). Within each group the indices are shuffled with a seeded
    RNG and sliced: ``int(n * train_ratio)`` go to train, ``int(n * val_ratio)``
    go to val, and the remainder goes to test. ``test_ratio`` is therefore only
    reported, and very small groups can end up entirely in test.

    NOTE(review): the defaults are 80/10/10, whereas the section 2.7.3 text
    describes 70/15/15. Pass the ratios explicitly if that split is intended.

    Args:
        training_samples: Sequence of sample dicts with a ``'regression'`` entry.
        train_ratio, val_ratio, test_ratio: Split fractions (non-negative;
            ``train_ratio + val_ratio`` must not exceed 1).
        random_state: Seed used for every group's shuffle.

    Returns:
        ``(train_samples, val_samples, test_samples)`` lists.

    Raises:
        ValueError: if a ratio is negative or ``train_ratio + val_ratio > 1``.
    """
    if min(train_ratio, val_ratio, test_ratio) < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            f"Invalid split ratios: train={train_ratio}, val={val_ratio}, test={test_ratio}")

    print(f"Creating train/val/test split...")
    print(f"Split ratios: train={train_ratio}, val={val_ratio}, test={test_ratio}")

    score_groups = defaultdict(list)
    for i, sample in enumerate(training_samples):
        score = sample['regression']['parasitemia_score']
        if score == 0:
            group = 'negative'
        elif score <= 2:
            group = 'low'
        elif score <= 10:
            group = 'moderate'
        else:
            group = 'high'
        score_groups[group].append(i)

    train_indices, val_indices, test_indices = [], [], []

    for group, indices in score_groups.items():
        n_samples = len(indices)
        n_train = int(n_samples * train_ratio)
        n_val = int(n_samples * val_ratio)
        n_test = n_samples - n_train - n_val

        # A fresh, identically seeded legacy RandomState per group gives the same
        # permutation as `np.random.seed(random_state); np.random.shuffle(indices)`
        # without resetting the global NumPy RNG for later cells.
        np.random.RandomState(random_state).shuffle(indices)

        train_indices.extend(indices[:n_train])
        val_indices.extend(indices[n_train:n_train+n_val])
        test_indices.extend(indices[n_train+n_val:])

        print(f"  {group:10}: {n_samples:3d} total -> train: {n_train:3d}, val: {n_val:3d}, test: {n_test:3d}")

    train_samples = [training_samples[i] for i in train_indices]
    val_samples = [training_samples[i] for i in val_indices]
    test_samples = [training_samples[i] for i in test_indices]

    total = len(training_samples)

    def _pct(part):
        # Avoid ZeroDivisionError when no samples are given.
        return len(part) / total * 100 if total else 0.0

    print(f"Final split sizes:")
    print(f"Train: {len(train_samples):3d} samples ({_pct(train_samples):.1f}%)")
    print(f"Val:   {len(val_samples):3d} samples ({_pct(val_samples):.1f}%)")
    print(f"Test:  {len(test_samples):3d} samples ({_pct(test_samples):.1f}%)")

    return train_samples, val_samples, test_samples

def _describe_tensor(value, show_values=False):
    """Return a one-line summary (shape, dtype, values or range) of a tensor.

    Empty tensors are reported as ``empty`` because ``min()``/``max()`` raise
    on them (e.g. a sample with no bounding boxes).
    """
    desc = f"Tensor {tuple(value.shape)} | dtype: {value.dtype}"
    if show_values:
        return f"{desc} | values: {value.tolist()}"
    if value.numel() == 0:
        return f"{desc} | empty"
    return f"{desc} | range: [{value.min():.3f}, {value.max():.3f}]"


def _count_boxes(bboxes):
    """Number of boxes: rows of a 2-D tensor, else 1 for a non-empty flat tensor, else 0."""
    if len(bboxes.shape) > 1:
        return len(bboxes)
    return 1 if bboxes.numel() > 0 else 0


def _keep_batch_as_list(batch):
    """Collate function that leaves a batch as a plain list of per-sample dicts."""
    return batch


def test_task_modes(train_samples):
    """Smoke-test every FlexibleMalariaDataset task mode on the first five samples.

    For each mode, a DataLoader (batch_size=2, no shuffling) is built and one
    batch is drawn. Its keys, tensor shapes/dtypes/ranges and a task-specific
    summary are printed. A failing mode prints its traceback and the loop
    continues with the next mode.

    Args:
        train_samples: list of prepared training-sample dicts.
    """
    print("TESTING MTTL TASK MODES")
    print("="*60)

    if not train_samples:
        print("No training samples to test.")
        return

    task_modes = ['detection', 'regression', 'localization', 'multi_task', 'segmentation', 'severity']

    for task_mode in task_modes:
        print(f"\nTesting {task_mode.upper()} mode...")

        try:
            dataset = FlexibleMalariaDataset(train_samples[:5], task_mode=task_mode, augment=False)

            # Detection-style targets hold a variable number of boxes per image,
            # so those batches are kept as lists instead of default-collated.
            collate_fn = _keep_batch_as_list if task_mode in ('detection', 'multi_task') else None
            dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0, collate_fn=collate_fn)

            sample_batch = next(iter(dataloader))

            print(f"Success: {task_mode} mode working!")

            # Samples inspection
            if isinstance(sample_batch, list):
                print(f"Batch type: List of {len(sample_batch)} samples")
                sample = sample_batch[0]
                print(f"Sample keys: {list(sample.keys())}")

                for key, value in sample.items():
                    if isinstance(value, torch.Tensor):
                        print(f"{key:15}: {_describe_tensor(value)}")
                    else:
                        print(f"{key:15}: {type(value).__name__} | value: {value}")
            else:
                print("Batch type: Single tensor batch")
                print(f"Batch keys: {list(sample_batch.keys())}")

                for key, value in sample_batch.items():
                    if isinstance(value, torch.Tensor):
                        # 0-D/1-D tensors (per-sample scalars) are printed in full.
                        print(f"{key:15}: {_describe_tensor(value, show_values=len(value.shape) <= 1)}")
                    elif isinstance(value, list):
                        print(f"{key:15}: List of {len(value)} items | sample: {value[:3] if len(value) > 3 else value}")
                    else:
                        print(f"{key:15}: {type(value).__name__} | value: {value}")

            # Task-specific info
            if task_mode == 'detection':
                sample = sample_batch[0] if isinstance(sample_batch, list) else sample_batch
                print(f"Detection info: {_count_boxes(sample['bboxes'])} bounding boxes detected")

            elif task_mode == 'multi_task':
                sample = sample_batch[0] if isinstance(sample_batch, list) else sample_batch
                num_boxes = _count_boxes(sample['bboxes'])
                heatmap_coverage = (sample['heatmap'] > 0).sum().item() / sample['heatmap'].numel() * 100
                print(f"Multi-task info: {num_boxes} boxes, {heatmap_coverage:.1f}% heatmap coverage")

            elif task_mode == 'localization':
                heatmap = sample_batch['heatmap'][0] if len(sample_batch['heatmap'].shape) > 2 else sample_batch['heatmap']
                coverage = (heatmap > 0).sum().item() / heatmap.numel() * 100
                print(f"Localization info: {coverage:.1f}% infection coverage")

            elif task_mode == 'regression':
                # Flatten to 1-D so a 0-D score tensor also supports min()/max()
                # (a Python list wrapper, as before, has no .min()).
                scores = torch.as_tensor(sample_batch['parasitemia_score']).reshape(-1)
                print(f"Regression info: parasitemia scores range [{scores.min():.2f}, {scores.max():.2f}]")

            elif task_mode == 'segmentation':
                mask = sample_batch['mask'] if isinstance(sample_batch, dict) else sample_batch[0]['mask']
                print(f"Segmentation info: mask shape {mask.shape}, unique values: {np.unique(mask)}")

            elif task_mode == 'severity':
                severity_class = sample_batch['severity_class'] if isinstance(sample_batch, dict) else sample_batch[0]['severity_class']
                severity_label = sample_batch['severity_label'] if isinstance(sample_batch, dict) else sample_batch[0]['severity_label']
                print(f"Severity info: class={severity_class}, label={severity_label}")

        except Exception as e:
            print(f"Error: {task_mode} mode failed: {str(e)}")
            import traceback
            print(f"Full error: {traceback.format_exc()}")
            
def visualize_heatmaps(samples, num_samples=4, annotation_records=None):
    """Plot a 4-row diagnostic grid for randomly chosen samples, plus a severity bar chart.

    Rows: original image with annotations, infection heatmap, image/heatmap
    overlay, and binary segmentation mask. A second figure shows the severity
    level of each displayed sample.

    Args:
        samples: list of training-sample dicts.
        num_samples: maximum number of samples (columns) to show; capped at
            ``len(samples)``.
        annotation_records: records with ``image_id`` and optional
            ``resized_annotations`` used to draw the annotations. Defaults to
            the notebook-level ``final_dataset``.
    """
    print("VISUALIZING...")
    print("="*60)

    # Never ask for more columns than there are samples; otherwise the bar
    # plot below has a length mismatch and the colorbar is never drawn.
    n_show = min(num_samples, len(samples))
    if n_show <= 0:
        print("No samples to visualize.")
        return

    if annotation_records is None:
        annotation_records = final_dataset
    # Build image_id -> annotations once instead of scanning all records per sample.
    # setdefault keeps the first match, like the original break-on-first-hit loop.
    annotations_by_id = {}
    for record in annotation_records:
        annotations_by_id.setdefault(record.get('image_id'), record.get('resized_annotations', []))

    # squeeze=False keeps `axes` 2-D even when only one column is shown.
    fig, axes = plt.subplots(4, n_show, figsize=(n_show*5, 18), squeeze=False,
                             gridspec_kw={'hspace': 0.3, 'wspace': 0.15})
    fig.suptitle('Infection-Only Localization Heatmaps', fontsize=18, fontweight='bold', y=0.98)

    selected_samples = np.random.choice(len(samples), n_show, replace=False)
    severity_labels = []

    for i, idx in enumerate(selected_samples):
        sample = samples[idx]
        image = sample['image'].copy()
        if image.dtype != np.uint8:
            if image.max() <= 1.0:
                image = (image * 255).astype(np.uint8)
            else:
                image = np.clip(image, 0, 255).astype(np.uint8)

        resized_annotations = annotations_by_id.get(sample['image_id'], [])

        # Row 0: Original + Annotations
        ax = axes[0, i]
        draw_annotations_on_image(ax, image, resized_annotations)
        infected_count = sample['cell_counts']['infected']
        healthy_count = sample['cell_counts']['healthy']
        ax.set_title(f"Original + Annotations\nInfected: {infected_count}, Healthy: {healthy_count}", fontsize=12, fontweight='bold')

        # Row 1: Infection Heatmap
        ax = axes[1, i]
        heatmap = sample['localization']['heatmap']
        im = ax.imshow(heatmap, cmap='Reds', interpolation='bilinear', aspect='equal', vmin=0, vmax=1)
        ax.set_title(f"Infection Heatmap\nRange: {heatmap.min():.3f} - {heatmap.max():.3f}", fontsize=12, fontweight='bold')
        ax.axis('off')
        # Only add colorbar to the last displayed heatmap
        if i == n_show - 1:
            cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, shrink=0.8)
            cbar.set_label('Infection Intensity', rotation=270, labelpad=20, fontsize=11)
            tick_positions = [0, 0.3, 0.5, 0.7, 1.0]
            tick_labels = ['None', 'Low', 'Medium', 'High', 'Critical']
            cbar.set_ticks(tick_positions)
            cbar.set_ticklabels(tick_labels, fontsize=9)

        # Row 2: Combined View
        ax = axes[2, i]
        ax.imshow(image, alpha=0.7)
        ax.imshow(heatmap, cmap='Reds', alpha=0.6, interpolation='bilinear')
        parasitemia = sample['regression']['parasitemia_score']
        ax.set_title(f"Combined View\nParasitemia: {parasitemia:.1f}%", fontsize=12, fontweight='bold')
        ax.axis('off')

        # Row 3: Binary Segmentation Mask
        ax = axes[3, i]
        mask = sample['segmentation']['mask'] if 'segmentation' in sample else sample['multi_task']['mask']
        ax.imshow(mask, cmap='gray', interpolation='nearest', vmin=0, vmax=1)
        ax.set_title("Binary Segmentation Mask\n(1=cell, 0=background)", fontsize=12, fontweight='bold')
        ax.axis('off')

        # Collect severity for bar plot
        severity = sample['severity']['severity_label'] if 'severity' in sample else sample['multi_task']['severity_label']
        severity_labels.append(severity)

    fig.tight_layout(rect=[0, 0.05, 1, 0.95])

    # Severity bar plot in its own figure; unknown labels map to -1 (gray).
    plt.figure(figsize=(n_show*5, 2.5))
    severity_map = {'negative': 0, 'low': 1, 'moderate': 2, 'high': 3}
    severity_numeric = [severity_map.get(s, -1) for s in severity_labels]
    severity_colors = ['green', 'yellow', 'orange', 'red']
    plt.bar(range(n_show), severity_numeric, color=[severity_colors[s] if s >= 0 else 'gray' for s in severity_numeric])
    plt.xticks(range(n_show), [f"Sample {i+1}" for i in range(n_show)], fontsize=12)
    plt.yticks([0,1,2,3], ['Negative', 'Low', 'Moderate', 'High'], fontsize=12)
    plt.title('Severity Level per Sample', fontsize=15, fontweight='bold')
    plt.xlabel('Sample')
    plt.ylabel('Severity')
    plt.tight_layout()
    plt.show()
    
def plot_heatmaps_by_infection_level(training_samples, samples_per_level=3):
    """Show random heatmap overlays, one row per parasitemia level.

    Levels follow ``sample['regression']['parasitemia_score']``:
    Negative (== 0), Low (<= 2), Moderate (<= 10), High (> 10). Each row shows
    up to ``samples_per_level`` randomly chosen samples. A row with no samples
    shows a placeholder message.

    Args:
        training_samples: list of training-sample dicts.
        samples_per_level: number of columns (samples drawn per level).
    """
    print("INFECTION-ONLY HEATMAP BY PARASITEMIA LEVELS")
    print("="*70)

    infection_categories = {
        'Negative (0%)': [],
        'Low (0-2%)': [],
        'Moderate (2-10%)': [],
        'High (>10%)': []
    }

    for sample in training_samples:
        score = sample['regression']['parasitemia_score']
        if score == 0:
            infection_categories['Negative (0%)'].append(sample)
        elif score <= 2:
            infection_categories['Low (0-2%)'].append(sample)
        elif score <= 10:
            infection_categories['Moderate (2-10%)'].append(sample)
        else:
            infection_categories['High (>10%)'].append(sample)

    print("INFECTION LEVEL DISTRIBUTION:")
    for level, samples in infection_categories.items():
        print(f"  {level:15}: {len(samples):3d} samples")

    # squeeze=False keeps `axes` 2-D even when samples_per_level == 1.
    fig, axes = plt.subplots(4, samples_per_level, figsize=(samples_per_level*5, 20), squeeze=False)
    fig.suptitle('Infection-Only Localization Heatmaps by Parasitemia Level',
                 fontsize=16, fontweight='bold', y=0.98)

    for row, (level, samples) in enumerate(infection_categories.items()):
        # Row label is drawn for every level, including empty ones.
        axes[row, 0].text(-0.15, 0.5, level,
                          transform=axes[row, 0].transAxes,
                          rotation=90, ha='center', va='center',
                          fontsize=14, weight='bold')

        if len(samples) == 0:
            for col in range(samples_per_level):
                axes[row, col].text(0.5, 0.5, f'No {level} samples',
                                    transform=axes[row, col].transAxes,
                                    ha='center', va='center', fontsize=12)
                axes[row, col].axis('off')
            continue

        selected_indices = np.random.choice(len(samples),
                                            min(samples_per_level, len(samples)),
                                            replace=False)

        for col, idx in enumerate(selected_indices):
            sample = samples[idx]
            ax = axes[row, col]

            image = sample['image'].copy()
            if image.dtype != np.uint8:
                if image.max() <= 1.0:
                    image = (image * 255).astype(np.uint8)
                else:
                    image = np.clip(image, 0, 255).astype(np.uint8)

            heatmap = sample['localization']['heatmap']
            parasitemia = sample['regression']['parasitemia_score']
            infected_count = sample['cell_counts']['infected']

            # Show original image
            ax.imshow(image, alpha=0.6)

            # Overlay the heatmap (and, in the last column, its colorbar) only
            # when there are infected regions.
            if np.any(heatmap > 0):
                im = ax.imshow(heatmap, cmap='Reds', alpha=0.8,
                               interpolation='bilinear', vmin=0, vmax=1)
                if col == samples_per_level - 1:
                    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, shrink=0.8)
                    cbar.set_label('Infection', rotation=270,
                                   labelpad=15, fontsize=10)
                    cbar.set_ticks([0, 0.5, 1.0])
                    cbar.set_ticklabels(['None', 'Medium', 'High'], fontsize=8)

            coverage = np.sum(heatmap > 0) / heatmap.size * 100
            ax.set_title(f'Score: {parasitemia:.1f}%\nInfected: {infected_count}\nCoverage: {coverage:.1f}%',
                         fontsize=9, pad=8)
            ax.axis('off')

        for col in range(len(selected_indices), samples_per_level):
            axes[row, col].axis('off')

    plt.tight_layout()
    plt.subplots_adjust(left=0.1, top=0.95)
    plt.show()

def create_comprehensive_analysis(training_samples):
    """Compute per-sample heatmap statistics, print a per-level summary and plot them.

    Each sample is assigned a level from its parasitemia score
    (Negative == 0, Low <= 2, Moderate <= 10, High > 10). The function records
    intensity statistics, infection/background pixel counts, coverage and
    "precision" (heatmap pixels > 0 per infected cell).

    Args:
        training_samples: list of training-sample dicts.

    Returns:
        pandas.DataFrame: one row per sample (empty when there are no samples).
    """
    print("COMPREHENSIVE INFECTION-ONLY HEATMAP ANALYSIS")
    print("="*70)

    if not training_samples:
        print("No samples to analyze.")
        return pd.DataFrame()

    analysis_data = []

    for sample in training_samples:
        score = sample['regression']['parasitemia_score']
        heatmap = sample['localization']['heatmap']

        if score == 0:
            level = 'Negative'
        elif score <= 2:
            level = 'Low'
        elif score <= 10:
            level = 'Moderate'
        else:
            level = 'High'

        # Only analyze infection pixels
        infection_pixels = np.sum(heatmap > 0)
        background_pixels = np.sum(heatmap == 0)
        # max(..., 1) avoids division by zero for samples without infected cells.
        pixels_per_cell = infection_pixels / max(sample['cell_counts']['infected'], 1)

        analysis_data.append({
            'level': level,
            'parasitemia_score': score,
            'max_intensity': heatmap.max(),
            'mean_intensity': heatmap.mean(),
            'std_intensity': heatmap.std(),
            'infection_pixels': infection_pixels,
            'background_pixels': background_pixels,
            'coverage_percent': infection_pixels / heatmap.size * 100,
            'precision_pixels_per_cell': pixels_per_cell,
            'infected_count': sample['cell_counts']['infected'],
            'healthy_count': sample['cell_counts']['healthy']
        })

    df = pd.DataFrame(analysis_data)

    summary = df.groupby('level').agg({
        'parasitemia_score': ['mean', 'std', 'count'],
        'max_intensity': ['mean', 'std'],
        'mean_intensity': ['mean', 'std'],
        'coverage_percent': ['mean', 'std'],
        'precision_pixels_per_cell': ['mean', 'std'],
        'infected_count': ['mean', 'std'],
        'healthy_count': ['mean', 'std']
    }).round(3)

    print("STATISTICAL SUMMARY BY INFECTION LEVEL:")
    print("=" * 80)
    print(summary)

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Top row: one grouped boxplot per metric.
    boxplot_panels = [
        (axes[0, 0], 'max_intensity', 'Maximum Infection Intensity', 'Max Intensity'),
        (axes[0, 1], 'precision_pixels_per_cell', 'Precision (Pixels per Infected Cell)', 'Precision (px/cell)'),
        (axes[0, 2], 'coverage_percent', 'Infection Coverage Percentage', 'Coverage (%)'),
    ]
    for ax, column, title, ylabel in boxplot_panels:
        df.boxplot(column=column, by='level', ax=ax)
        ax.set_title(title)
        ax.set_xlabel('Infection Level')
        ax.set_ylabel(ylabel)

    # pandas' grouped boxplot replaces the figure suptitle with
    # "Boxplot grouped by level", so the real title is set afterwards.
    fig.suptitle('Infection-Only Heatmap Statistics by Infection Level', fontsize=16, fontweight='bold')

    ax = axes[1, 0]
    level_means = df.groupby('level')[['infection_pixels', 'background_pixels']].mean()
    level_means.plot(kind='bar', ax=ax, stacked=True, color=["#F32323", "#8C17F2"])
    ax.set_title('Average Pixel Distribution')
    ax.set_xlabel('Infection Level')
    ax.set_ylabel('Average Pixel Count')
    # Legend labels must match the bar colours above (red, purple).
    ax.legend(['Infection (Red)', 'Background (Purple)'])
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

    colors = {'Negative': 'green', 'Low': 'yellow', 'Moderate': 'orange', 'High': 'red'}

    ax = axes[1, 1]
    for level in df['level'].unique():
        level_data = df[df['level'] == level]
        ax.scatter(level_data['parasitemia_score'], level_data['precision_pixels_per_cell'],
                   c=colors.get(level, 'blue'), label=level, alpha=0.7)

    ax.set_xlabel('Parasitemia Score (%)')
    ax.set_ylabel('Precision (px/cell)')
    ax.set_title('Parasitemia vs Heatmap Precision')
    ax.legend()
    ax.grid(True, alpha=0.3)

    ax = axes[1, 2]
    for level in df['level'].unique():
        level_data = df[df['level'] == level]
        ax.scatter(level_data['healthy_count'], level_data['infected_count'],
                   c=colors.get(level, 'blue'), label=level, alpha=0.7)

    ax.set_xlabel('Healthy Cell Count')
    ax.set_ylabel('Infected Cell Count')
    ax.set_title('Healthy vs Infected Cell Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return df

### 2.7.8 Final Training Dataset Creation and Saving
*Save the complete training-ready dataset with all task components*

**Dataset Structure:**
```
mttl_training_data/
├── train/
│   ├── images/          # Processed images (PNG)
│   ├── heatmaps/        # Infection heatmaps (NPY)
│   ├── masks/           # Binary segmentation masks (NPY)
│   ├── train_metadata.json
│   └── train_samples.pkl
├── val/
│   ├── images/
│   ├── heatmaps/
│   ├── masks/
│   ├── val_metadata.json
│   └── val_samples.pkl
├── test/
│   ├── images/
│   ├── heatmaps/
│   ├── masks/
│   ├── test_metadata.json
│   └── test_samples.pkl
└── dataset_info.json   # Complete dataset information
```

In [None]:

def _image_to_uint8(image):
    """Return ``image`` as a uint8 array that PIL can write to PNG.

    Arrays with max <= 1.0 are treated as [0, 1] floats and scaled by 255;
    anything else is clipped to [0, 255]. This is the same rule the plotting
    cells use, plus clipping so out-of-range values cannot wrap around.
    """
    image = np.asarray(image)
    if image.dtype == np.uint8:
        return image
    if image.max() <= 1.0:
        return np.clip(image * 255, 0, 255).astype(np.uint8)
    return np.clip(image, 0, 255).astype(np.uint8)


def _to_plain_list(values):
    """Convert an array/tensor/sequence to a plain Python list ([] for None or empty)."""
    if values is None:
        return []
    if hasattr(values, 'tolist'):
        return values.tolist()
    return list(values)


def _json_default(obj):
    """``json.dump`` fallback: array/NumPy/torch values become native Python, everything else ``str``.

    Plain ``default=str`` would store e.g. ``np.int64(3)`` as the string ``"3"``.
    """
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    return str(obj)


def save_training_dataset(train_samples, val_samples, test_samples, save_dir):
    """Write the train/val/test splits and a dataset summary to ``save_dir``.

    Each split directory gets ``images/`` (PNG), ``heatmaps/`` and ``masks/``
    (``.npy``), ``<split>_metadata.json`` and ``<split>_samples.pkl``.
    A top-level ``dataset_info.json`` summarises sizes, classes and score range.

    Args:
        train_samples, val_samples, test_samples: lists of training-sample dicts.
        save_dir: output root directory (created if missing).

    Returns:
        dict: the summary that is also written to ``dataset_info.json``.
    """
    print(f"Saving training dataset to {save_dir}...")

    split_dirs = {
        'train': os.path.join(save_dir, 'train'),
        'val': os.path.join(save_dir, 'val'),
        'test': os.path.join(save_dir, 'test'),
    }
    for dir_path in split_dirs.values():
        for sub_dir in ('images', 'heatmaps', 'masks'):
            os.makedirs(os.path.join(dir_path, sub_dir), exist_ok=True)

    def save_split(samples, split_dir, split_name):
        """Write one split's arrays, JSON metadata and pickled samples."""
        print(f"  Saving {split_name} split ({len(samples)} samples)...")

        for sample in tqdm(samples, desc=f"Saving {split_name}"):
            image_id = sample['image_id']
            # PIL cannot write float arrays as PNG, so normalise to uint8 first.
            Image.fromarray(_image_to_uint8(sample['image'])).save(
                os.path.join(split_dir, 'images', f"{image_id}.png"))
            np.save(os.path.join(split_dir, 'heatmaps', f"{image_id}.npy"),
                    sample['localization']['heatmap'])
            np.save(os.path.join(split_dir, 'masks', f"{image_id}.npy"),
                    sample['segmentation']['mask'])

        metadata = []
        for sample in samples:
            image_id = sample['image_id']
            mask_path = f"masks/{image_id}.npy"
            mask_shape = list(sample['segmentation']['mask'].shape)
            severity_label = sample['severity']['severity_label']
            metadata.append({
                'image_id': image_id,
                'detection': {
                    'bboxes': _to_plain_list(sample['detection']['bboxes']),
                    'labels': _to_plain_list(sample['detection']['labels']),
                    'num_objects': sample['detection']['num_objects']
                },
                'regression': {
                    'parasitemia_score': float(sample['regression']['parasitemia_score'])
                },
                'localization': {
                    'heatmap_path': f"heatmaps/{image_id}.npy",
                    'max_intensity': float(sample['localization']['heatmap'].max())
                },
                'segmentation': {
                    'mask_path': mask_path,
                    'mask_shape': mask_shape
                },
                'severity': {
                    'severity_class': sample['severity']['severity_class'],
                    'severity_label': severity_label
                },
                'multi_task': {
                    'parasitemia_score': float(sample['multi_task']['parasitemia_score']),
                    'bboxes': _to_plain_list(sample['multi_task']['bboxes']),
                    'bbox_labels': _to_plain_list(sample['multi_task']['bbox_labels']),
                    'mask_path': mask_path,
                    'mask_shape': mask_shape,
                    'severity_label': severity_label
                },
                'cell_counts': sample['cell_counts'],
                'metadata': sample['metadata']
            })

        with open(os.path.join(split_dir, f'{split_name}_metadata.json'), 'w') as f:
            json.dump(metadata, f, indent=2, default=_json_default)

        # Pickle of our own in-memory samples (trusted data only; never load untrusted pickles).
        with open(os.path.join(split_dir, f'{split_name}_samples.pkl'), 'wb') as f:
            pickle.dump(samples, f)

    for split_name, split_samples in (('train', train_samples), ('val', val_samples), ('test', test_samples)):
        save_split(split_samples, split_dirs[split_name], split_name)

    all_samples = train_samples + val_samples + test_samples

    # Image/heatmap sizes are taken from the first sample (empty dataset -> []).
    actual_image_size = list(all_samples[0]['image'].shape) if all_samples else []
    actual_heatmap_size = list(all_samples[0]['localization']['heatmap'].shape) if all_samples else []

    # Count detection objects per label id actually present in the data.
    class_counts = Counter()
    for sample in all_samples:
        class_counts.update(_to_plain_list(sample['detection']['labels']))

    unique_labels = sorted(class_counts)
    actual_num_classes = len(unique_labels)

    label_to_name = {0: 'Infected', 1: 'Healthy', 2: 'WBC', 3: 'Unknown'}
    actual_class_names = [label_to_name.get(label, f'Class_{label}') for label in unique_labels]
    class_distribution = {name: class_counts[label]
                          for name, label in zip(actual_class_names, unique_labels)}

    all_scores = [sample['regression']['parasitemia_score'] for sample in all_samples]

    dataset_info = {
        'total_samples': len(all_samples),
        'splits': {
            'train_samples': len(train_samples),
            'val_samples': len(val_samples),
            'test_samples': len(test_samples)
        },
        'image_size': actual_image_size,
        'heatmap_size': actual_heatmap_size,
        'num_classes': actual_num_classes,
        'class_names': actual_class_names,
        'class_labels': {name: label for name, label in zip(actual_class_names, unique_labels)},
        'class_distribution': class_distribution,
        'parasitemia_range': [min(all_scores), max(all_scores)] if all_scores else [],
        'mask_shapes': [list(sample['segmentation']['mask'].shape) for sample in all_samples],
        # NOTE: these paths are relative to each sample's split directory, not to save_dir.
        'mask_paths': [f"masks/{sample['image_id']}.npy" for sample in all_samples],
        'severity_classes': sorted(set(sample['severity']['severity_class'] for sample in all_samples)),
        'severity_labels': sorted(set(sample['severity']['severity_label'] for sample in all_samples)),
        'heatmap_type': 'infection_only',
        'supported_tasks': {
            'detection': 'Object detection with bounding boxes',
            'regression': 'Parasitemia score prediction',
            'localization': 'Infection-only localization heatmaps',
            'segmentation': 'Pixel-wise cell segmentation',
            'severity': 'Infection severity classification',
            'multi_task': 'All tasks combined'
        },
        'created_at': pd.Timestamp.now().isoformat()
    }

    with open(os.path.join(save_dir, 'dataset_info.json'), 'w') as f:
        json.dump(dataset_info, f, indent=2, default=_json_default)

    print(f"Dataset saved successfully!")
    print(f"  Location: {os.path.abspath(save_dir)}")
    print(f"  Train: {len(train_samples)} | Val: {len(val_samples)} | Test: {len(test_samples)}")
    print(f"  Actual Classes: {actual_num_classes} - {actual_class_names}")
    print(f"  Class Distribution: {class_distribution}")

    return dataset_info

def execute_complete_pipeline():
    """Run every MTTL preparation step on the notebook-level ``final_dataset``.

    Steps: build training samples, make the stratified split, smoke-test the
    dataset task modes, draw the diagnostic plots, and save everything under
    ``OUTPUT_DIR/mttl_training_data``.

    Returns:
        dict: all samples, the three splits, the saved ``dataset_info`` and
        the per-sample analysis DataFrame.
    """
    print("RUNNING FULL PIPELINE FOR MTTL DATASET PREPARATION...")
    print("="*60)

    preparator = MalariaDataPreparator(image_size=TARGET_SIZE)

    print("Step 1: Preparing training data...")
    all_samples = preparator.prepare_training_data(final_dataset)

    print("Step 2: Creating train/val/test splits...")
    train_split, val_split, test_split = create_train_val_test_split(all_samples)

    print("Step 3: Testing task modes...")
    test_task_modes(train_split)

    # Diagnostic plots use the full sample list, not just the training split.
    print("Step 4: Visualizing infection-only heatmaps...")
    visualize_heatmaps(all_samples, num_samples=4)

    print("Step 5: Plotting heatmaps by infection level...")
    plot_heatmaps_by_infection_level(all_samples, samples_per_level=3)

    print("Step 6: Creating comprehensive analysis...")
    analysis = create_comprehensive_analysis(all_samples)

    print("Step 7: Saving dataset...")
    out_dir = os.path.join(OUTPUT_DIR, 'mttl_training_data')
    info = save_training_dataset(train_split, val_split, test_split, out_dir)

    print("PIPELINE COMPLETE!")
    print(f"Training data ready for MTTL at: {out_dir}")

    return dict(
        training_samples=all_samples,
        train_samples=train_split,
        val_samples=val_split,
        test_samples=test_split,
        dataset_info=info,
        analysis_df=analysis,
    )

# Run the pipeline only once the preprocessing cells have produced `final_dataset`.
# `results` is always bound, so later cells that receive it (e.g.
# get_actual_dataset_info(results), which treats a falsy value as "no results")
# do not fail with NameError and never see stale results from an earlier run.
results = None
if 'final_dataset' in globals():
    results = execute_complete_pipeline()
else:
    print("Error: final_dataset not found. Run preprocessing first.")

### 2.7.9 Dataset Information Extraction and Validation
*Extract and validate the actual dataset characteristics from the prepared samples (no hard-coded values)*

**Validated Information:**
- **Image Dimensions**: Dynamic extraction from data
- **Class Count**: Based on real detection labels
- **Class Names**: Mapped from label ids (0 = Infected, 1 = Healthy, 2 = WBC)
- **Split Sizes**: Real train/val/test counts
- **Parasitemia Range**: Min/max from calculated scores
- **Class Distribution**: Real object counts per class

In [None]:
def get_actual_dataset_info(results):
    
    if not results or 'train_samples' not in results:
        print("No results available!")
        return None
    
    train_samples = results['train_samples']
    val_samples = results['val_samples']
    test_samples = results['test_samples']
    all_samples = train_samples + val_samples + test_samples
    
    # Get actual image size from first sample
    sample_image = train_samples[0]['image']
    actual_image_size = list(sample_image.shape)
    
    # Get actual class names from detection labels
    all_labels = []
    for sample in train_samples + val_samples + test_samples:
        labels = sample['detection']['labels']
        all_labels.extend(labels.tolist() if hasattr(labels, 'tolist') else labels)
    
    unique_labels = sorted(set(all_labels))
    actual_num_classes = len(unique_labels)
    
    # Map labels to actual class names from your cell type mapping
    cell_type_mapping = {
        'parasitized': 0, 'Parasitized': 0,    # Infected
        'uninfected': 1, 'Uninfected': 1,     # Healthy  
        'white_blood_cell': 2, 'White_Blood_Cell': 2  # WBC
    }
    
    # Reverse mapping to get class names
    label_to_name = {0: 'Infected', 1: 'Healthy', 2: 'WBC'}
    actual_class_names = [label_to_name.get(label, f'Class_{label}') for label in unique_labels]
    
    # Get actual heatmap type
    sample_heatmap = train_samples[0]['localization']['heatmap']
    actual_heatmap_size = list(sample_heatmap.shape)
    
    # Get actual parasitemia score range
    all_scores = [sample['regression']['parasitemia_score'] for sample in train_samples + val_samples + test_samples]
    
    # Comprehensive heatmap analysis
    print("Analyzing heatmap characteristics across all samples...")
    
    heatmap_stats = {
        'min_values': [],
        'max_values': [],
        'mean_values': [],
        'std_values': [],
        'coverage_percentages': [],
        'non_zero_pixels': [],
        'intensity_ranges': []
    }
    
    # Analyze by infection level with enhanced metrics
    infection_level_heatmaps = {
        'negative': [],
        'low': [],
        'moderate': [],
        'high': []
    }
    
    for sample in train_samples + val_samples + test_samples:
        heatmap = sample['localization']['heatmap']
        parasitemia = sample['regression']['parasitemia_score']
        
        # Overall statistics
        heatmap_stats['min_values'].append(float(heatmap.min()))
        heatmap_stats['max_values'].append(float(heatmap.max()))
        heatmap_stats['mean_values'].append(float(heatmap.mean()))
        heatmap_stats['std_values'].append(float(heatmap.std()))
        
        # Coverage analysis
        non_zero_pixels = np.sum(heatmap > 0.01)  # Threshold for actual infection
        total_pixels = heatmap.size
        coverage = (non_zero_pixels / total_pixels) * 100
        
        heatmap_stats['coverage_percentages'].append(coverage)
        heatmap_stats['non_zero_pixels'].append(int(non_zero_pixels))
        heatmap_stats['intensity_ranges'].append(float(heatmap.max() - heatmap.min()))
        
        # Categorize by infection level with enhanced metrics
        heatmap_data = {
            'min_intensity': float(heatmap.min()),
            'max_intensity': float(heatmap.max()),
            'mean_intensity': float(heatmap.mean()),
            'std_intensity': float(heatmap.std()),
            'coverage': coverage,
            'non_zero_pixels': int(non_zero_pixels),
            'intensity_range': float(heatmap.max() - heatmap.min())
        }
        
        if parasitemia == 0:
            infection_level_heatmaps['negative'].append(heatmap_data)
        elif parasitemia <= 2:
            infection_level_heatmaps['low'].append(heatmap_data)
        elif parasitemia <= 10:
            infection_level_heatmaps['moderate'].append(heatmap_data)
        else:
            infection_level_heatmaps['high'].append(heatmap_data)
    
    # Calculate comprehensive summary statistics
    heatmap_summary = {
        'overall': {
            'min_range': [min(heatmap_stats['min_values']), max(heatmap_stats['min_values'])],
            'max_range': [min(heatmap_stats['max_values']), max(heatmap_stats['max_values'])],
            'avg_min': np.mean(heatmap_stats['min_values']),
            'avg_max': np.mean(heatmap_stats['max_values']),
            'avg_mean': np.mean(heatmap_stats['mean_values']),
            'avg_std': np.mean(heatmap_stats['std_values']),
            'avg_coverage': np.mean(heatmap_stats['coverage_percentages']),
            'avg_non_zero_pixels': np.mean(heatmap_stats['non_zero_pixels']),
            'intensity_range': [min(heatmap_stats['intensity_ranges']), max(heatmap_stats['intensity_ranges'])],
            'avg_intensity_range': np.mean(heatmap_stats['intensity_ranges'])
        },
        'by_infection_level': {}
    }
    
    # Calculate detailed stats by infection level
    for level, heatmaps in infection_level_heatmaps.items():
        if heatmaps:
            min_intensities = [h['min_intensity'] for h in heatmaps]
            max_intensities = [h['max_intensity'] for h in heatmaps]
            mean_intensities = [h['mean_intensity'] for h in heatmaps]
            std_intensities = [h['std_intensity'] for h in heatmaps]
            coverages = [h['coverage'] for h in heatmaps]
            non_zero_pixels = [h['non_zero_pixels'] for h in heatmaps]
            intensity_ranges = [h['intensity_range'] for h in heatmaps]
            
            heatmap_summary['by_infection_level'][level] = {
                'count': len(heatmaps),
                'avg_min_intensity': np.mean(min_intensities),
                'min_intensity_range': [min(min_intensities), max(min_intensities)],
                'avg_max_intensity': np.mean(max_intensities),
                'max_intensity_range': [min(max_intensities), max(max_intensities)],
                'avg_mean_intensity': np.mean(mean_intensities),
                'mean_intensity_range': [min(mean_intensities), max(mean_intensities)],
                'avg_std_intensity': np.mean(std_intensities),
                'std_intensity_range': [min(std_intensities), max(std_intensities)],
                'avg_coverage': np.mean(coverages),
                'coverage_range': [min(coverages), max(coverages)],
                'avg_non_zero_pixels': np.mean(non_zero_pixels),
                'non_zero_pixels_range': [min(non_zero_pixels), max(non_zero_pixels)],
                'avg_intensity_range': np.mean(intensity_ranges),
                'intensity_range_span': [min(intensity_ranges), max(intensity_ranges)]
            }
        else:
            heatmap_summary['by_infection_level'][level] = {
                'count': 0,
                'avg_min_intensity': 0.0,
                'min_intensity_range': [0.0, 0.0],
                'avg_max_intensity': 0.0,
                'max_intensity_range': [0.0, 0.0],
                'avg_mean_intensity': 0.0,
                'mean_intensity_range': [0.0, 0.0],
                'avg_std_intensity': 0.0,
                'std_intensity_range': [0.0, 0.0],
                'avg_coverage': 0.0,
                'coverage_range': [0.0, 0.0],
                'avg_non_zero_pixels': 0.0,
                'non_zero_pixels_range': [0.0, 0.0],
                'avg_intensity_range': 0.0,
                'intensity_range_span': [0.0, 0.0]
            }
    
    # Get actual split sizes
    actual_splits = {
        'train_samples': len(train_samples),
        'val_samples': len(val_samples), 
        'test_samples': len(test_samples),
        'total_samples': len(train_samples) + len(val_samples) + len(test_samples)
    }
    
    # Count actual objects per class
    class_counts = {label: 0 for label in unique_labels}
    for sample in train_samples + val_samples + test_samples:
        labels = sample['detection']['labels']
        for label in (labels.tolist() if hasattr(labels, 'tolist') else labels):
            class_counts[label] += 1
            
    # Mask details (dynamic, no hard code)
    mask_shapes = [list(sample['segmentation']['mask'].shape) for sample in all_samples if 'segmentation' in sample]
    mask_paths = [f"masks/{sample['image_id']}.npy" for sample in all_samples if 'segmentation' in sample]

    # Severity details (dynamic, no hard code)
    severity_classes = sorted(set(sample['severity']['severity_class'] for sample in all_samples if 'severity' in sample))
    severity_labels = sorted(set(sample['severity']['severity_label'] for sample in all_samples if 'severity' in sample))
    
    actual_info = {
        'total_samples': actual_splits['total_samples'],
        'image_size': actual_image_size,
        'num_classes': actual_num_classes,
        'class_names': actual_class_names,
        'class_labels': {name: idx for idx, name in enumerate(actual_class_names)},
        'heatmap_size': actual_heatmap_size,
        'heatmap_type': 'infection_only',
        'heatmap_analysis': heatmap_summary,
        'splits': actual_splits,
        'parasitemia_range': [min(all_scores), max(all_scores)],
        'class_distribution': {actual_class_names[i]: class_counts.get(i, 0) for i in unique_labels},
        'mask_shapes': mask_shapes,
        'mask_paths': mask_paths,
        'severity_classes': severity_classes,
        'severity_labels': severity_labels,
        'supported_tasks': ['detection', 'regression', 'localization', 'segmentation', 'severity', 'multi_task']
    }
    
    return actual_info

def print_actual_info(info):
    """Print the dataset info dict as a text report with pandas tables.

    The scalar metadata (sizes, classes, splits, masks, severity, tasks) is
    printed first. If ``info`` contains a ``'heatmap_analysis'`` entry, three
    more sections follow:
    - an overall heatmap statistics table,
    - per-infection-level intensity and coverage tables,
    - a short analysis summary.

    Args:
        info: Dataset info dict, as built by ``get_actual_dataset_info``.
            A falsy value (``None`` or ``{}``) prints nothing. The
            ``mask_shapes`` and ``mask_paths`` entries may be missing or
            empty; in that case they are shown as ``N/A``.

    Returns:
        None. All output goes to stdout.
    """
    if not info:
        return

    # Guard the optional mask metadata. max() on an empty list raises
    # ValueError, and applying max()/slicing to the 'N/A' fallback string
    # would print just 'N'.
    mask_shapes = info.get('mask_shapes') or []
    mask_paths = info.get('mask_paths') or []
    # max() over shape lists is lexicographic. It reports one representative
    # shape (presumably all masks share a single shape -- TODO confirm).
    mask_shape_text = max(mask_shapes) if mask_shapes else 'N/A'
    mask_path_text = mask_paths[:1] if mask_paths else 'N/A'

    print("ACTUAL DATASET INFO:")
    print("="*70)
    print(f"Total Samples: {info['total_samples']}")
    print(f"Image Size: {info['image_size']}")
    print(f"Number of Classes: {info['num_classes']}")
    print(f"Class Names: {info['class_names']}")
    print(f"Class Labels: {info['class_labels']}")
    print(f"Heatmap Size: {info['heatmap_size']}")
    print(f"Heatmap Type: {info['heatmap_type']}")
    print(f"Splits: {info['splits']}")
    print(f"Parasitemia Range: {info['parasitemia_range']}")
    print(f"Class Distribution: {info['class_distribution']}")
    print(f"Mask Shapes: {mask_shape_text}")
    print(f"Mask Paths: {mask_path_text}")
    print(f"Severity Classes: {info.get('severity_classes', 'N/A')}")
    print(f"Severity Labels: {info.get('severity_labels', 'N/A')}")
    print(f"Supported Tasks: {info['supported_tasks']}")

    # The heatmap tables are only printed when the analysis is present.
    if 'heatmap_analysis' not in info:
        return

    print("\n" + "="*80)
    print("HEATMAP ANALYSIS - COMPREHENSIVE STATISTICS")
    print("="*80)

    overall = info['heatmap_analysis']['overall']

    # Overall Statistics Table
    print("\nOVERALL HEATMAP STATISTICS:")
    print("-" * 50)

    overall_data = {
        'Metric': ['Average Min Intensity', 'Average Max Intensity', 'Average Mean Intensity',
                   'Average Std Intensity', 'Average Coverage (%)', 'Avg Non-Zero Pixels',
                   'Min Value Range', 'Max Value Range', 'Avg Intensity Range'],
        'Value': [f"{overall['avg_min']:.3f}", f"{overall['avg_max']:.3f}", f"{overall['avg_mean']:.3f}",
                  f"{overall['avg_std']:.3f}", f"{overall['avg_coverage']:.1f}%", f"{overall['avg_non_zero_pixels']:.0f}",
                  f"{overall['min_range'][0]:.3f} - {overall['min_range'][1]:.3f}",
                  f"{overall['max_range'][0]:.3f} - {overall['max_range'][1]:.3f}",
                  f"{overall['avg_intensity_range']:.3f}"],
        'Description': ['Baseline infection intensity', 'Peak infection intensity', 'Overall infection density',
                        'Intensity variation', 'Infected area coverage', 'Infection pixel count',
                        'Global minimum bounds', 'Global maximum bounds', 'Average intensity spread']
    }

    overall_df = pd.DataFrame(overall_data)
    print(overall_df.to_string(index=False, max_colwidth=25))

    # Infection Level Statistics Table
    print("\nHEATMAP STATISTICS BY INFECTION LEVEL:")
    print("-" * 80)

    by_level = info['heatmap_analysis']['by_infection_level']

    # Single source of truth for the per-level table's column names and order.
    level_cols = ['Infection Level', 'Samples', 'Avg Min Intensity', 'Min Range',
                  'Avg Max Intensity', 'Max Range', 'Avg Mean Intensity', 'Mean Range',
                  'Avg Coverage (%)', 'Coverage Range', 'Avg Non-Zero Pixels', 'Pixels Range']

    level_data = []
    for level, stats in by_level.items():
        if stats['count'] > 0:
            row = {
                'Infection Level': level.upper(),
                'Samples': stats['count'],
                'Avg Min Intensity': f"{stats['avg_min_intensity']:.3f}",
                'Min Range': f"{stats['min_intensity_range'][0]:.3f}-{stats['min_intensity_range'][1]:.3f}",
                'Avg Max Intensity': f"{stats['avg_max_intensity']:.3f}",
                'Max Range': f"{stats['max_intensity_range'][0]:.3f}-{stats['max_intensity_range'][1]:.3f}",
                'Avg Mean Intensity': f"{stats['avg_mean_intensity']:.3f}",
                'Mean Range': f"{stats['mean_intensity_range'][0]:.3f}-{stats['mean_intensity_range'][1]:.3f}",
                'Avg Coverage (%)': f"{stats['avg_coverage']:.1f}%",
                'Coverage Range': f"{stats['coverage_range'][0]:.1f}%-{stats['coverage_range'][1]:.1f}%",
                'Avg Non-Zero Pixels': f"{stats['avg_non_zero_pixels']:.0f}",
                'Pixels Range': f"{stats['non_zero_pixels_range'][0]:.0f}-{stats['non_zero_pixels_range'][1]:.0f}"
            }
        else:
            # A level with no samples shows 'N/A' in every metric column.
            row = dict.fromkeys(level_cols, 'N/A')
            row.update({'Infection Level': level.upper(), 'Samples': 0})
        level_data.append(row)

    # Passing columns= fixes the column order. It also keeps the column
    # selections below valid when by_level is empty (otherwise KeyError).
    level_df = pd.DataFrame(level_data, columns=level_cols)

    # Display in sections for better readability
    print("\nINTENSITY METRICS:")
    intensity_cols = ['Infection Level', 'Samples', 'Avg Min Intensity', 'Min Range',
                      'Avg Max Intensity', 'Max Range', 'Avg Mean Intensity', 'Mean Range']
    print(level_df[intensity_cols].to_string(index=False, max_colwidth=15))

    print("\nCOVERAGE METRICS:")
    coverage_cols = ['Infection Level', 'Samples', 'Avg Coverage (%)', 'Coverage Range',
                     'Avg Non-Zero Pixels', 'Pixels Range']
    print(level_df[coverage_cols].to_string(index=False, max_colwidth=15))

    # Summary Analysis
    print("\nANALYSIS SUMMARY:")
    print("-" * 40)

    # Only levels with samples take part in the "highest" comparisons.
    active_levels = [(level, stats) for level, stats in by_level.items() if stats['count'] > 0]
    if active_levels:
        max_coverage_level = max(active_levels, key=lambda x: x[1]['avg_coverage'])
        max_intensity_level = max(active_levels, key=lambda x: x[1]['avg_max_intensity'])

        print(f"Highest Coverage:     {max_coverage_level[0].upper()} ({max_coverage_level[1]['avg_coverage']:.1f}%)")
        print(f"Highest Intensity:    {max_intensity_level[0].upper()} ({max_intensity_level[1]['avg_max_intensity']:.3f})")
        # Denominator is the number of levels actually reported, not a hard-coded 4.
        print(f"Total Active Levels:  {len(active_levels)}/{len(by_level)}")
        print(f"Total Samples:        {sum(stats['count'] for _, stats in active_levels)}")

        # Quality assessment: heuristic thresholds on the overall averages.
        overall_quality = "EXCELLENT" if overall['avg_max'] > 0.8 else "GOOD" if overall['avg_max'] > 0.6 else "MODERATE"
        coverage_quality = "OPTIMAL" if 5 <= overall['avg_coverage'] <= 15 else "ACCEPTABLE" if overall['avg_coverage'] < 25 else "HIGH"

        print(f"Intensity Quality:    {overall_quality} (avg max: {overall['avg_max']:.3f})")
        print(f"Coverage Quality:     {coverage_quality} ({overall['avg_coverage']:.1f}%)")


# Build and print the dataset report, but only after the pipeline cell has
# defined `results` in this notebook's namespace.
if 'results' not in locals():
    print("Run the pipeline first to get results!")
else:
    actual_dataset_info = get_actual_dataset_info(results)
    print_actual_info(actual_dataset_info)