# ADNI Dataset Model Cross-Validation Pipeline
## Features: GPU acceleration, timeout support, comprehensive error handling, and automatic result archiving

In [None]:
# Model identifiers grouped into three tiers. The tier only decides how many
# models get_model_batches() places in one batch (heavy: 3, medium: 5,
# light: 8). The names look like timm model identifiers -- TODO confirm
# against Cross_Validator.

# Tier batched three at a time.
heavy_models = [
    "vit_base_patch16_224.augreg2_in21k_ft_in1k",
    "vit_base_patch16_224",
    "vit_tiny_patch16_224.augreg_in21k_ft_in1k",
    "vit_tiny_patch16_224",
    "swin_base_patch4_window7_224.ms_in22k_ft_in1k",
    "swin_base_patch4_window7_224",
    "maxvit_tiny_224",
    "tf_efficientnet_b4.ns_jft_in1k",
    "convnext_small.fb_in22k_ft_in1k",
]

# Tier batched five at a time.
medium_models = [
    "tf_efficientnetv2_s.in21k_ft_in1k",
    "convnext_tiny.fb_in22k_ft_in1k",
    "coatnet_0_rw_224.sw_in1k",
    "resnet50.a1_in1k",
    "resnext50_32x4d.a1h_in1k",
    "densenet121.ra_in1k",
    "inception_v3",
    "xception",
    "vgg16_bn",
]

# Tier batched eight at a time (currently seven entries, so a single batch).
light_models = [
    "mobilevit_s.cvnets_in1k",
    "efficientformer_l1.snap_dist_in1k",
    "poolformer_s12.sail_in1k",
    "resnet18",
    "efficientnet_b0",
    "mobilenetv3_large_100.ra_in1k",
    "ghostnet_100.in1k",
]

def get_model_batches():
    """Split the three model tiers into the ordered list of batches to run.

    Heavy models are grouped three per batch, medium models five per batch
    and light models eight per batch; the final batch of a tier may be
    shorter. Batches are returned heavy first, then medium, then light.

    Returns:
        A list of lists of model names.
    """
    tiers = (
        (heavy_models, 3),
        (medium_models, 5),
        (light_models, 8),
    )

    batches = []
    for tier_models, batch_size in tiers:
        # Consecutive slices of at most batch_size names each.
        batches.extend(
            tier_models[start:start + batch_size]
            for start in range(0, len(tier_models), batch_size)
        )
    return batches

In [None]:
import subprocess
import sys
import os
from datetime import datetime
import traceback
import warnings

# Suppress every Python warning raised in this process from here on.
warnings.filterwarnings("ignore")

# Directory that zip_output_directory() archives once all runs have finished.
OUTPUT_DIR = "output"

# Put the parent of the current working directory first on the import path.
# NOTE(review): run_subprocess() addresses the `module` package relative to
# the working directory itself -- confirm the parent directory is the one
# intended here.
sys.path.insert(0, os.path.dirname(os.path.abspath('.')))

# Per-batch subprocess timeout in seconds: taken from the project config when
# it can be imported, otherwise 14400 s (4 h).
try:
    from module.config import SUBPROCESS_TIMEOUT
except ImportError:
    SUBPROCESS_TIMEOUT = 14400
    print(f"Default timeout: {SUBPROCESS_TIMEOUT}s ({SUBPROCESS_TIMEOUT/3600:.1f}h)")
else:
    print(f"Loaded timeout: {SUBPROCESS_TIMEOUT}s ({SUBPROCESS_TIMEOUT/3600:.1f}h)")

In [None]:
# Source text of the throwaway runner script executed once per batch.
# run_subprocess() fills in the two placeholders with str() of its arguments
# before writing the script to disk:
#   __models_list__   -> the batch's list of model names
#   __augmentation__  -> the augmentation flag
# The generated script builds a Logger and a Cross_Validator from the
# project's `module` package and calls validator.run(). Any exception is
# logged and re-raised, so the child interpreter exits with a non-zero status.
# NOTE(review): the logger name comes from hash(str(models)). String hashes
# are salted per interpreter process unless PYTHONHASHSEED is fixed, so the
# name is presumably not reproducible between runs and may start with "-".
# Confirm whether a stable name is expected.
SUBPROCESS_TEMPLATE = r"""
import sys
from module.cross_validation import Cross_Validator
from module.utils import Logger

def run_batch():
    models = __models_list__
    use_aug = __augmentation__

    logger = Logger("batch_" + str(hash(str(models)))[:8])
    logger.info(f"Starting validation for {models}")
    
    try:
        validator = Cross_Validator(models, logger, use_aug=use_aug)
        validator.run()
        logger.info("Validation complete")
    except Exception as e:
        logger.error(f"Batch failed: {e}")
        raise

if __name__ == "__main__":
    run_batch()
"""

In [None]:
def run_subprocess(models_list, use_aug):
    """Run one cross-validation batch in a fresh Python interpreter.

    The placeholders in SUBPROCESS_TEMPLATE are replaced with ``models_list``
    and ``use_aug``, the result is written to ``module/temp_runner.py``
    (relative to the current working directory) and executed with
    ``python -m module.temp_runner``. The temporary script is removed
    afterwards whether the run succeeded, failed, or could not be started.

    Args:
        models_list: Model names for this batch; inserted via ``str()``.
        use_aug: Augmentation flag; inserted via ``str()``.

    Raises:
        subprocess.CalledProcessError: The child exited with a non-zero status.
        subprocess.TimeoutExpired: The child exceeded SUBPROCESS_TIMEOUT seconds.
        OSError: The temporary script could not be written.
    """
    script_filename = "temp_runner.py"
    script_path = os.path.join("module", script_filename)
    module_path = f"module.{os.path.splitext(script_filename)[0]}"

    script_content = (
        SUBPROCESS_TEMPLATE
        .replace("__models_list__", str(models_list))
        .replace("__augmentation__", str(use_aug))
    )

    try:
        # The write lives inside the try so that a partially written script
        # is removed as well. UTF-8 is the encoding the interpreter assumes
        # for source files, independent of the platform locale.
        with open(script_path, "w", encoding="utf-8") as f:
            f.write(script_content)

        print(f"üöÄ Launching: {models_list}")
        print(f"   Timeout: {SUBPROCESS_TIMEOUT/3600:.1f}h")

        # check=True turns a non-zero exit status into CalledProcessError.
        subprocess.run(
            [sys.executable, "-m", module_path],
            check=True,
            timeout=SUBPROCESS_TIMEOUT,
        )
    finally:
        if os.path.exists(script_path):
            os.remove(script_path)

In [None]:
def run_queue(use_aug):
    """Run every model batch in sequence and write a plain-text run report.

    Each batch from get_model_batches() is handed to run_subprocess(). A
    timeout or failure of one batch is recorded and the queue moves on to the
    next batch. A KeyboardInterrupt stops the queue early; the summary and
    report are still produced for the batches processed so far.

    Args:
        use_aug: Augmentation flag forwarded to every batch; also embedded in
            the run identifier.

    Returns:
        A dict with the keys ``'run_id'``, ``'completed'``, ``'timeout'``,
        ``'failed'`` and ``'total'``.

    Side effects:
        Prints progress to stdout and writes ``REPORT_<run_id>.txt`` to the
        current working directory.
    """
    run_id = f"AUG_{use_aug}_{datetime.now().strftime('%Y%m%d_%H%M')}"
    total_batches = completed = timeout = failed = 0
    errors = []

    print(f"--- Queue {run_id} ---")

    try:
        batches = get_model_batches()
        total_batches = len(batches)

        for number, batch in enumerate(batches, start=1):
            print(f"\n>>> Batch {number}/{total_batches}: {batch}")
            try:
                run_subprocess(batch, use_aug)
                completed += 1
                print(f"‚úÖ Batch {number} done")
            except subprocess.TimeoutExpired:
                timeout += 1
                errors.append(f"Batch {number} TIMEOUT")
                print("‚è∞ Timeout")
            except subprocess.CalledProcessError as e:
                failed += 1
                errors.append(f"Batch {number} ERROR: {e}")
                print("‚ùå Failed")
            except Exception:
                # Anything else (e.g. the runner script could not be written)
                # is recorded with its full traceback.
                failed += 1
                errors.append(f"Batch {number}: {traceback.format_exc()}")
                print("‚ùå Error")
    except KeyboardInterrupt:
        print("‚ö†Ô∏è User interrupt")

    summary = f"""
=== RUN SUMMARY ===
ID: {run_id}
Completed: {completed}/{total_batches}
Timeout: {timeout}
Failed: {failed}
Errors: {len(errors)}
"""
    print(summary)

    # Explicit UTF-8: the error list holds arbitrary exception and traceback
    # text, which a locale-default encoding may not be able to represent.
    with open(f"REPORT_{run_id}.txt", "w", encoding="utf-8") as f:
        f.write(summary + "\n\n" + "\n".join(errors))

    return {'run_id': run_id, 'completed': completed, 'timeout': timeout, 'failed': failed, 'total': total_batches}

In [None]:
def zip_output_directory(summaries):
    """Archive OUTPUT_DIR into a timestamped zip and print the run overview.

    Returns immediately, printing nothing, when OUTPUT_DIR does not exist.
    Otherwise every file below OUTPUT_DIR is added to
    ``Results_<YYYYmmdd_HHMM>.zip`` in the current working directory, stored
    under its path relative to that directory. The archive size and file
    count are printed, followed by one line per entry of ``summaries``.

    Any exception raised while zipping or printing is caught and reported as
    a "Zip failed" message rather than propagated.

    Args:
        summaries: Iterable of dicts as returned by run_queue(); the keys
            'run_id', 'completed' and 'total' are read.
    """
    import zipfile

    if not os.path.exists(OUTPUT_DIR):
        return

    timestamp = datetime.now().strftime('%Y%m%d_%H%M')
    zip_name = f"Results_{timestamp}.zip"

    print("\nüì¶ ZIPPING OUTPUT")

    try:
        with zipfile.ZipFile(zip_name, 'w', zipfile.ZIP_DEFLATED) as archive:
            archived = 0
            for folder, _subdirs, names in os.walk(OUTPUT_DIR):
                for name in names:
                    member = os.path.join(folder, name)
                    archive.write(member, os.path.relpath(member, '.'))
                    archived += 1

        # Size is read after the archive has been closed and flushed to disk.
        size_mb = os.path.getsize(zip_name) / (1024*1024)
        print(f"‚úÖ {zip_name} ({size_mb:.1f}MB, {archived} files)")

        print("\nüìä SUMMARY:")

        for entry in summaries:
            # A run counts as clean only when every batch completed.
            if entry['completed'] == entry['total']:
                status = "‚úì"
            else:
                status = "‚ö†"
            print(f"\n{status} {entry['run_id']}: {entry['completed']}/{entry['total']}")

        print("\nüéâ COMPLETE!")
    except Exception as exc:
        print(f"‚ùå Zip failed: {exc}")

In [None]:
# Driver: run the full queue once without and once with augmentation, then
# archive the output directory and print the per-run overview.
summaries = []

for banner, augment in (
    ("üöÄ NON-AUGMENTED RUN", False),
    ("\nüöÄ AUGMENTED RUN", True),
):
    print(banner)
    summaries.append(run_queue(augment))

zip_output_directory(summaries)