# ADNI Dataset Model Cross-Validation Pipeline
## Features: GPU acceleration, per-batch subprocess timeouts, error handling with a saved run report, automatic result archiving, and configurable per-model classifiers

In [None]:
# Candidate backbones, grouped by how expensive they are to train.
# get_model_batches() turns these lists into the batches that are each run in
# a separate subprocess.
heavy_models = [
    'vit_base_patch16_224.augreg2_in21k_ft_in1k',
    'vit_base_patch16_224',
    "vit_tiny_patch16_224.augreg_in21k_ft_in1k",
    "vit_tiny_patch16_224",
    'swin_base_patch4_window7_224.ms_in22k_ft_in1k',
    'swin_base_patch4_window7_224',
    'maxvit_tiny_224',
    'tf_efficientnet_b4.ns_jft_in1k',
    'convnext_small.fb_in22k_ft_in1k'
]

medium_models = [
    'tf_efficientnetv2_s.in21k_ft_in1k',
    'convnext_tiny.fb_in22k_ft_in1k',
    'coatnet_0_rw_224.sw_in1k',
    'resnet50.a1_in1k',
    'resnext50_32x4d.a1h_in1k',
    'densenet121.ra_in1k',
    'inception_v3',
    'xception',
    'vgg16_bn'
]

light_models = [
    'mobilevit_s.cvnets_in1k',
    'efficientformer_l1.snap_dist_in1k',
    'poolformer_s12.sail_in1k',
    'resnet18',
    'efficientnet_b0',
    'mobilenetv3_large_100.ra_in1k',
    'ghostnet_100.in1k'
]

def get_model_batches():
    """Split the model lists into consecutive batches sized by cost tier.

    Heavy models are grouped 3 per batch, medium models 5 per batch and light
    models 8 per batch. Model order inside each tier is preserved, and the
    batches are returned heavy tier first, then medium, then light. The last
    batch of a tier may be shorter than the tier's batch size.

    Returns:
        list[list[str]]: the model-name batches, in execution order.
    """
    # (model list, models per batch) for each tier, in execution order.
    tier_plan = (
        (heavy_models, 3),
        (medium_models, 5),
        (light_models, 8),
    )
    return [
        tier[start:start + size]
        for tier, size in tier_plan
        for start in range(0, len(tier), size)
    ]

In [None]:
# Classifier Configuration
# Map specific models to specific classifiers.
# Any model not listed in MODEL_CLASSIFIER_MAP uses DEFAULT_CLASSIFIER.

# Available classifiers:
# - 'baseline': Standard CrossEntropy
# - 'progressive': 3-phase discriminative fine-tuning (RECOMMENDED)
# - 'evidential': Uncertainty quantification
# - 'metric_learning': Prototypes + Triplet + Center Loss
# - 'regularized': Manifold Mixup + Label Smoothing
# - 'attention_enhanced': SE Blocks + Cosine Classifier
# - 'progressive_evidential': Progressive + Evidential
# - 'clinical_grade': Clinical deployment (5 techniques + SAM)
# - 'hybrid_transformer': CNN + Transformer hybrid
# - 'ultimate': All 10 techniques (maximum recall)
# - 'all': Test all classifiers on this model

# Names accepted in MODEL_CLASSIFIER_MAP / DEFAULT_CLASSIFIER, mirroring the
# list above. NOTE(review): keep in sync with what module.cross_validation
# actually supports.
AVAILABLE_CLASSIFIERS = frozenset({
    'baseline',
    'progressive',
    'evidential',
    'metric_learning',
    'regularized',
    'attention_enhanced',
    'progressive_evidential',
    'clinical_grade',
    'hybrid_transformer',
    'ultimate',
    'all',
})

MODEL_CLASSIFIER_MAP = {
    # Example: Use progressive fine-tuning for ResNet models
    'resnet18': 'progressive',
    'resnet50.a1_in1k': 'progressive',
    
    # Example: Use clinical-grade for EfficientNet (high accuracy needed)
    'tf_efficientnet_b4.ns_jft_in1k': 'clinical_grade',
    
    # Example: Use ultimate for your best model
    # 'efficientnet_b0': 'ultimate',
    
    # Example: Test ALL classifiers on a specific model
    # 'vit_tiny_patch16_224': 'all',
}

# Default classifier for models not in the map
DEFAULT_CLASSIFIER = 'baseline'

# Fail fast on typos. Otherwise an unknown name would presumably only be
# detected inside a training subprocess, possibly after earlier batches have
# already consumed hours.
if DEFAULT_CLASSIFIER not in AVAILABLE_CLASSIFIERS:
    raise ValueError(
        f"Unknown DEFAULT_CLASSIFIER {DEFAULT_CLASSIFIER!r}; "
        f"expected one of {sorted(AVAILABLE_CLASSIFIERS)}"
    )
_unknown_classifiers = {
    model: clf
    for model, clf in MODEL_CLASSIFIER_MAP.items()
    if clf not in AVAILABLE_CLASSIFIERS
}
if _unknown_classifiers:
    raise ValueError(
        f"Unknown classifier(s) in MODEL_CLASSIFIER_MAP: {_unknown_classifiers}; "
        f"expected one of {sorted(AVAILABLE_CLASSIFIERS)}"
    )

print("Classifier Configuration:")
print(f"  Default: {DEFAULT_CLASSIFIER}")
print(f"  Custom mappings: {len(MODEL_CLASSIFIER_MAP)}")
for model, clf in MODEL_CLASSIFIER_MAP.items():
    print(f"    {model}: {clf}")

In [None]:
import subprocess
import sys
import os
from datetime import datetime
import traceback
import warnings
import shutil

# Suppress all Python warnings for the rest of this session.
warnings.filterwarnings("ignore")

# Directory archived by zip_output_directory() at the end of the run.
OUTPUT_DIR = "output"

# The `module` package must sit in the current working directory:
# run_subprocess() writes module/temp_runner.py relative to it and launches
# `python -m module.temp_runner` from it. So put the cwd itself (not its
# parent) on sys.path, and only once, so re-running this cell does not keep
# growing sys.path.
_project_root = os.path.abspath('.')
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)

try:
    from module.config import SUBPROCESS_TIMEOUT
    print(f"Loaded timeout: {SUBPROCESS_TIMEOUT}s ({SUBPROCESS_TIMEOUT/3600:.1f}h)")
except ImportError:
    # module/config.py missing or lacks SUBPROCESS_TIMEOUT: fall back to 12 hours.
    SUBPROCESS_TIMEOUT = 12 * 3600
    print(f"Default timeout: {SUBPROCESS_TIMEOUT}s ({SUBPROCESS_TIMEOUT/3600:.1f}h)")

In [None]:
# Source of the throw-away runner script executed by run_subprocess().
# Before the script is written to disk, the two placeholder names inside the
# string are replaced with the repr() of the batch's model list and of its
# {model: classifier} dict, so both must render as valid Python literals.
SUBPROCESS_TEMPLATE = r"""
import hashlib
import sys
from module.cross_validation import Cross_Validator
from module.utils import Logger

def run_batch():
    models = __models_list__
    classifier_map = __classifier_map__

    # Stable per-batch log name. Built-in hash() of a str is salted per
    # process (PYTHONHASHSEED) and may be negative, so it changed on every run.
    batch_id = hashlib.sha256(str(models).encode("utf-8")).hexdigest()[:8]
    logger = Logger("batch_" + batch_id)
    logger.info(f"Starting validation for {models}")
    logger.info(f"Classifier mapping: {classifier_map}")
    
    try:
        validator = Cross_Validator(
            models,
            logger,
            model_classifier_map=classifier_map
        )
        validator.run()
        logger.info("Validation complete")
    except Exception as e:
        logger.error(f"Batch failed: {e}")
        raise

if __name__ == "__main__":
    run_batch()
"""

In [None]:
def build_batch_classifier_map(models_list, classifier_map=None, default=None):
    """Resolve the classifier to use for every model in one batch.

    Args:
        models_list: Model names in the batch.
        classifier_map: Optional dict of model_name -> classifier_type
            overrides. Entries for models outside the batch are ignored.
        default: Classifier for models without an override; None means
            DEFAULT_CLASSIFIER.

    Returns:
        dict: model_name -> classifier_type, with exactly one entry per model in
        models_list, in batch order.
    """
    if default is None:
        default = DEFAULT_CLASSIFIER
    overrides = classifier_map or {}
    return {model: overrides.get(model, default) for model in models_list}


def run_subprocess(models_list, classifier_map=None):
    """Run a batch of models in a subprocess with classifier mapping.

    Renders SUBPROCESS_TEMPLATE into module/temp_runner.py, runs it with
    `python -m module.temp_runner`, and deletes the script afterwards, whether
    or not the run succeeded.

    Args:
        models_list: List of model names to train
        classifier_map: Dict mapping model_name -> classifier_type
                       If None, uses DEFAULT_CLASSIFIER for all models

    Raises:
        subprocess.CalledProcessError: the runner exited with a non-zero status.
        subprocess.TimeoutExpired: the runner exceeded SUBPROCESS_TIMEOUT
            seconds (subprocess.run kills the child before raising).
    """
    script_filename = "temp_runner.py"
    # Placed inside the `module` package so `python -m` can run it next to
    # module.cross_validation / module.utils.
    script_path = os.path.join("module", script_filename)

    batch_classifier_map = build_batch_classifier_map(models_list, classifier_map)

    script_content = (
        SUBPROCESS_TEMPLATE
        .replace("__models_list__", str(models_list))
        .replace("__classifier_map__", str(batch_classifier_map))
    )

    # Python decodes source files as UTF-8, so write them as UTF-8 regardless
    # of the platform's locale encoding.
    with open(script_path, "w", encoding="utf-8") as f:
        f.write(script_content)

    print(f"🚀 Launching: {models_list}")
    print(f"   Classifiers: {batch_classifier_map}")
    print(f"   Timeout: {SUBPROCESS_TIMEOUT/3600:.1f}h")

    try:
        module_path = f"module.{os.path.splitext(script_filename)[0]}"
        subprocess.run(
            [sys.executable, "-m", module_path],
            check=True,
            timeout=SUBPROCESS_TIMEOUT
        )
    finally:
        if os.path.exists(script_path):
            os.remove(script_path)

In [None]:
def run_queue(classifier_map=None):
    """Run all model batches sequentially, then print and save a summary report.

    Each batch from get_model_batches() is run via run_subprocess(). A batch
    that times out or fails is recorded and the queue continues with the next
    one. A KeyboardInterrupt stops the queue, but the report is still produced.

    Args:
        classifier_map: Dict mapping model_name -> classifier_type
                       If None, uses DEFAULT_CLASSIFIER for all models

    Returns:
        dict: 'run_id', batch counts 'completed', 'timeout', 'failed', 'total',
        and 'interrupted' (True if the queue was stopped by KeyboardInterrupt).

    Side effects:
        Writes REPORT_<run_id>.txt (UTF-8) to the current working directory.
    """
    run_id = f"RUN_{datetime.now().strftime('%Y%m%d_%H%M')}"
    total_batches = completed = timeout = failed = 0
    interrupted = False
    errors = []

    print(f"\n{'='*80}")
    print(f"STARTING BATCH PROCESSING: {run_id}")
    print(f"{'='*80}")

    if classifier_map:
        print(f"\nClassifier Mapping ({len(classifier_map)} custom):")
        for model, clf in classifier_map.items():
            print(f"  {model}: {clf}")
    print(f"Default Classifier: {DEFAULT_CLASSIFIER}")
    print(f"{'='*80}\n")

    try:
        batches = get_model_batches()
        total_batches = len(batches)

        for i, batch in enumerate(batches, start=1):
            print(f"\n{'>'*80}")
            print(f">>> Batch {i}/{total_batches}")
            print(f">>> Models: {batch}")
            print(f"{'>'*80}")

            # Error entries include the model names so the report is usable
            # without re-deriving the batch layout.
            try:
                run_subprocess(batch, classifier_map)
            except subprocess.TimeoutExpired:
                timeout += 1
                errors.append(f"Batch {i} {batch} TIMEOUT after {SUBPROCESS_TIMEOUT/3600:.1f}h")
                print(f"\n⏰ Batch {i} TIMEOUT\n")
            except subprocess.CalledProcessError as e:
                failed += 1
                errors.append(f"Batch {i} {batch} ERROR: {e}")
                print(f"\n❌ Batch {i} failed with error\n")
            except Exception:
                failed += 1
                errors.append(f"Batch {i} {batch}: {traceback.format_exc()}")
                print(f"\n❌ Batch {i} failed with exception\n")
            else:
                completed += 1
                print(f"\n✅ Batch {i}/{total_batches} completed successfully\n")

    except KeyboardInterrupt:
        # Record the interruption so the report explains the missing batches.
        interrupted = True
        errors.append("Run interrupted by user (KeyboardInterrupt); unfinished batches were not run")
        print("\n⚠️ User interrupt detected\n")

    success_rate = 100 * completed / total_batches if total_batches > 0 else 0
    interrupted_text = "yes" if interrupted else "no"

    # Summary Report
    summary = f"""
{'='*80}
EXECUTION SUMMARY
{'='*80}
Run ID: {run_id}
Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

Results:
  ✓ Completed: {completed}/{total_batches}
  ⏰ Timeout:   {timeout}
  ❌ Failed:    {failed}
  ⚠️ Interrupted: {interrupted_text}

Success Rate: {success_rate:.1f}%
{'='*80}
"""
    print(summary)

    # Save detailed report. UTF-8 is explicit: the summary contains characters
    # (e.g. ⏰, ⚠️) that locale encodings such as cp1252 cannot encode, which
    # would raise UnicodeEncodeError and lose the report.
    report_file = f"REPORT_{run_id}.txt"
    with open(report_file, "w", encoding="utf-8") as f:
        f.write(summary)
        f.write("\n\nDETAILED ERRORS:\n")
        f.write("\n".join(errors) if errors else "No errors")

    print(f"📄 Detailed report saved to: {report_file}\n")

    return {
        'run_id': run_id,
        'completed': completed,
        'timeout': timeout,
        'failed': failed,
        'total': total_batches,
        'interrupted': interrupted,
    }

In [None]:
def _print_execution_summary(summary):
    """Print the batch outcome counts contained in a run_queue() summary dict."""
    print(f"\n{'='*80}")
    print("EXECUTION SUMMARY")
    print(f"{'='*80}")

    status = "✅ SUCCESS" if summary['completed'] == summary['total'] else "⚠️ PARTIAL"
    print(f"\nStatus: {status}")
    print(f"Run ID: {summary['run_id']}")
    print(f"Completed: {summary['completed']}/{summary['total']} batches")
    print(f"Timeout: {summary['timeout']}")
    print(f"Failed: {summary['failed']}")

    print(f"\n{'='*80}")
    print("🎉 ALL PROCESSING COMPLETE!")
    print(f"{'='*80}\n")


def zip_output_directory(summary, output_dir=None):
    """Archive the output directory, then print the run's execution summary.

    The archive is written to ``Results_<YYYYmmdd_HHMM>.zip`` in the current
    working directory. The execution summary is printed in every case, even
    when there is nothing to archive or archiving fails.

    Args:
        summary: Dict returned by run_queue(); reads 'run_id', 'completed',
            'total', 'timeout' and 'failed'.
        output_dir: Directory to archive. Defaults to OUTPUT_DIR.

    Returns:
        str | None: path of the created archive, or None if the directory does
        not exist or archiving failed.
    """
    import zipfile

    if output_dir is None:
        output_dir = OUTPUT_DIR

    if not os.path.exists(output_dir):
        print(f"⚠️ Output directory '{output_dir}' not found. Nothing to zip.")
        _print_execution_summary(summary)
        return None

    timestamp = datetime.now().strftime('%Y%m%d_%H%M')
    zip_name = f"Results_{timestamp}.zip"

    print(f"\n{'='*80}")
    print("ARCHIVING RESULTS")
    print(f"{'='*80}\n")

    # Store entries relative to the directory that *contains* output_dir, so
    # every entry starts with its basename ("output/...") even when output_dir
    # is absolute or outside the cwd (relpath to '.' would yield "../" entries).
    archive_root = os.path.dirname(os.path.abspath(output_dir))
    zip_abspath = os.path.abspath(zip_name)

    try:
        files = 0
        with zipfile.ZipFile(zip_name, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, _dirs, filelist in os.walk(output_dir):
                for file in filelist:
                    filepath = os.path.join(root, file)
                    if os.path.abspath(filepath) == zip_abspath:
                        continue  # never add the archive being written to itself
                    zipf.write(filepath, os.path.relpath(filepath, archive_root))
                    files += 1
    except Exception as e:
        # Archiving is best-effort: report it, but still print the run summary.
        print(f"❌ Archive creation failed: {e}")
        traceback.print_exc()
        if os.path.exists(zip_name):
            os.remove(zip_name)  # don't leave a truncated archive behind
        zip_name = None
    else:
        size_mb = os.path.getsize(zip_name) / (1024*1024)
        print(f"✅ Archive created: {zip_name}")
        print(f"   Size: {size_mb:.1f} MB")
        print(f"   Files: {files}")

    _print_execution_summary(summary)
    return zip_name

In [None]:
# MAIN EXECUTION
# Print an overview of the planned work, run every batch with the classifier
# configuration, then archive OUTPUT_DIR and print the summary.

rule = "=" * 80
all_models = heavy_models + medium_models + light_models

print("\n" + rule)
print("ADNI CROSS-VALIDATION PIPELINE")
print(rule)
print(f"Total Models: {len(all_models)}")
print(f"Total Batches: {len(get_model_batches())}")
print(f"Timeout per Batch: {SUBPROCESS_TIMEOUT/3600:.1f}h")
print(rule + "\n")

# Each model uses its MODEL_CLASSIFIER_MAP entry, or DEFAULT_CLASSIFIER.
summary = run_queue(classifier_map=MODEL_CLASSIFIER_MAP)

# Bundle the results into a timestamped zip archive.
zip_output_directory(summary)

In [None]:
# OPTIONAL: Test single batch (for debugging)
# Uncomment to run a single batch instead of all batches

# test_batch = ['resnet18']
# test_classifier_map = {
#     'resnet18': 'progressive'  # or 'all' to test all classifiers
# }

# print("\n🧪 TESTING SINGLE BATCH\n")
# run_subprocess(test_batch, test_classifier_map)

In [None]:
# OPTIONAL: Quick test with all classifiers on one model
# Uncomment to compare all classifiers on a single model

# quick_test_batch = ['resnet18']
# quick_test_map = {'resnet18': 'all'}  # Test ALL classifiers

# print("\n🔬 QUICK TEST: All Classifiers on ResNet18\n")
# run_subprocess(quick_test_batch, quick_test_map)