In [2]:
import sqlite3
import json
import yaml
from datetime import datetime
import shutil
import pandas as pd
from pathlib import Path
from typing import Dict, List, Union, Any, Optional, Tuple
import uuid
import uproot
import numpy as np
from tqdm import tqdm
import requests
import tensorflow as tf
from abc import ABC, abstractmethod
from tensorflow import keras
from qkeras import QDense, QActivation, quantized_bits, quantized_relu
import platform
import os
import gc
import matplotlib.pyplot as plt
import psutil
import time

2025-02-03 15:25:33.318917: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-03 15:25:33.399456: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-03 15:25:33.400517: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
ATLAS_RUN_NUMBERS = ["00296939", "00296942", "00297447", "00297170", "00297041", "00297730", "00298591", "00298595", "00298609", "00298633", "00298687", "00298690", "00298771", "00298773", "00298862", "00298967", "00299055", "00299144", "00299147", "00299184", "00299241", "00299243", "00299278", "00299288", "00299315", "00299340", "00299343", "00299390", "00299584", "00300279", "00300287", "00300345", "00300415", "00300418", "00300487", "00300540", "00300571", "00300600", "00300655", "00300687", "00300784", "00300800", "00300863", "00300908", "00301912", "00301915", "00301918", "00301932", "00301973", "00302053", "00302137", "00302265", "00302269", "00302300", "00302347", "00302380", "00302391", "00302393", "00302737", "00302829", "00302831", "00302872", "00302919", "00302925", "00302956", "00303007", "00303059", "00303079", "00303201", "00303208", "00303264", "00303266", "00303291", "00303304", "00303338", "00303421", "00303499", "00303560", "00303638", "00303726", "00303811", "00303817", "00303819", "00303832", "00303846", "00303892", "00303943", "00304006", "00304008", "00304128", "00304178", "00304198", "00304211", "00304243", "00304308", "00304337", "00304409", "00304431", "00304494", "00305291", "00305293", "00305359", "00305380", "00305543", "00305571", "00305618", "00305671", "00305674", "00305723", "00305727", "00305735", "00305777", "00305811", "00305920", "00306247", "00306269", "00306278", "00306310", "00306384", "00306419", "00306442", "00306448", "00306451", "00306556", "00306655", "00306657", "00306714", "00307124", "00307126", "00307195", "00307259", "00307306", "00307354", "00307358", "00307394", "00307454", "00307514", "00307539", "00307569", "00307601", "00307619", "00307656", "00307710", "00307716", "00307732", "00307861", "00307935", "00308047", "00308084", "00309311", "00309314", "00309346", "00309375", "00309390", "00309440", "00309516", "00309640", "00309674", "00309759", "00310015", "00310210", "00310247", "00310249", "00310341", "00310370", "00310405", "00310468", "00310473", "00310574", "00310634", "00310691", "00310738", "00310781", "00310809", "00310863", "00310872", "00310969", "00311071", "00311170", "00311244", "00311287", "00311321", "00311365", "00311402", "00311473", "00311481"]

In [4]:
def get_system_usage():
        """Get current system memory and CPU usage with proper cleanup"""
        # Force garbage collection before checking memory
        gc.collect()
        
        # Get memory info
        memory = psutil.virtual_memory()
        cpu_percent = psutil.cpu_percent(interval=0.1)  # Shorter interval
        
        # Calculate memory in GB
        total_gb = memory.total / (1024**3)
        available_gb = memory.available / (1024**3)
        used_gb = memory.used / (1024**3)
        
        return {
            'memory': {
                'total_gb': total_gb,
                'available_gb': available_gb,
                'used_gb': used_gb,
                'percent': memory.percent
            },
            'cpu': {
                'percent': cpu_percent
            }
        }

def print_system_usage(prefix=""):
    """Print current system usage with optional prefix"""
    usage = get_system_usage()
    print(f"\n{prefix}System Usage:")
    print(f"Memory: {usage['memory']['used_gb']:.1f}GB / {usage['memory']['total_gb']:.1f}GB ({usage['memory']['percent']}%)")
    print(f"Available Memory: {usage['memory']['available_gb']:.1f}GB")
    print(f"CPU Usage: {usage['cpu']['percent']}%")

In [5]:
class ModelRegistry:
    """
    Enhanced central registry for managing ML experiments, models, and metrics
    Tracks detailed dataset configurations and training metrics
    """
    def __init__(self, base_path: str):
        self.base_path = Path(base_path)
        self.db_path = self.base_path / "registry.db"
        self.model_store = self.base_path / "model_store"
        
        # Create directories if they don't exist
        self.model_store.mkdir(parents=True, exist_ok=True)
        
        # Initialize database
        self._initialize_db()
        
    def _initialize_db(self):
        """Create database tables if they don't exist"""
        with sqlite3.connect(self.db_path) as conn:
            # Main experiments table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS experiments (
                    experiment_id TEXT PRIMARY KEY,
                    timestamp DATETIME,
                    name TEXT,
                    description TEXT,
                    status TEXT,
                    environment_info JSON
                )
            """)
            
            # Dataset configuration table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS dataset_configs (
                    experiment_id TEXT PRIMARY KEY,
                    run_numbers JSON,
                    track_selections JSON,
                    event_selections JSON,
                    max_tracks_per_event INTEGER,
                    min_tracks_per_event INTEGER,
                    normalization_params JSON,
                    train_fraction FLOAT,
                    validation_fraction FLOAT,
                    test_fraction FLOAT,
                    batch_size INTEGER,
                    shuffle_buffer INTEGER,
                    data_quality_metrics JSON,
                    FOREIGN KEY(experiment_id) REFERENCES experiments(experiment_id)
                )
            """)
            
            # Model configuration table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS model_configs (
                    experiment_id TEXT PRIMARY KEY,
                    model_type TEXT,
                    architecture JSON,
                    hyperparameters JSON,
                    FOREIGN KEY(experiment_id) REFERENCES experiments(experiment_id)
                )
            """)
            
            # Training configuration and results
            conn.execute("""
                CREATE TABLE IF NOT EXISTS training_info (
                    experiment_id TEXT PRIMARY KEY,
                    config JSON,
                    start_time DATETIME,
                    end_time DATETIME,
                    epochs_completed INTEGER,
                    training_history JSON,
                    final_metrics JSON,
                    hardware_metrics JSON,
                    FOREIGN KEY(experiment_id) REFERENCES experiments(experiment_id)
                )
            """)
            
            # Checkpoints table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS checkpoints (
                    checkpoint_id TEXT PRIMARY KEY,
                    experiment_id TEXT,
                    name TEXT,
                    timestamp DATETIME,
                    metadata JSON,
                    FOREIGN KEY(experiment_id) REFERENCES experiments(experiment_id)
                )
            """)
            
    def register_experiment(self,
                          name: str,
                          dataset_config: dict,
                          model_config: dict,
                          training_config: dict,
                          description: str = "") -> str:
        """
        Register new experiment with enhanced configuration tracking
        
        Args:
            name: Human readable experiment name
            dataset_config: Dataset parameters including:
                - run_numbers: List of ATLAS run numbers
                - track_selections: Dictionary of track-level selection criteria
                - event_selections: Dictionary of event-level selection criteria
                - max_tracks_per_event: Maximum number of tracks to keep per event
                - min_tracks_per_event: Minimum number of tracks required per event
                - normalization_params: Dictionary of normalization parameters
                - train_fraction: Fraction of data for training
                - validation_fraction: Fraction for validation
                - test_fraction: Fraction for testing
                - batch_size: Batch size used
                - shuffle_buffer: Shuffle buffer size
                - data_quality_metrics: Results of data validation
            model_config: Model configuration including:
                - model_type: Type of model (e.g., "autoencoder")
                - architecture: Network architecture details
                - hyperparameters: Model hyperparameters
            training_config: Training parameters
            description: Optional experiment description
        """
        experiment_id = str(uuid.uuid4())
        timestamp = datetime.now()
        
        # Get environment info
        environment_info = {
            "python_version": platform.python_version(),
            "tensorflow_version": tf.__version__,
            "platform": platform.platform(),
            "cpu_count": os.cpu_count()
        }
        try:
            environment_info["gpu_devices"] = tf.config.list_physical_devices('GPU')
        except:
            environment_info["gpu_devices"] = []
            
        with sqlite3.connect(self.db_path) as conn:
            # Insert main experiment info
            conn.execute(
                """
                INSERT INTO experiments 
                (experiment_id, timestamp, name, description, status, environment_info)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    experiment_id,
                    timestamp,
                    name,
                    description,
                    "registered",
                    json.dumps(environment_info)
                )
            )
            
            # Insert dataset configuration
            conn.execute(
                """
                INSERT INTO dataset_configs
                (experiment_id, run_numbers, track_selections, event_selections,
                max_tracks_per_event, min_tracks_per_event, normalization_params,
                train_fraction, validation_fraction, test_fraction,
                batch_size, shuffle_buffer, data_quality_metrics)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    experiment_id,
                    json.dumps(dataset_config['run_numbers']),
                    json.dumps(dataset_config['track_selections']),
                    json.dumps(dataset_config['event_selections']),
                    dataset_config['max_tracks_per_event'],
                    dataset_config['min_tracks_per_event'],
                    json.dumps(dataset_config.get('normalization_params', {})),
                    dataset_config['train_fraction'],
                    dataset_config['validation_fraction'],
                    dataset_config['test_fraction'],
                    dataset_config['batch_size'],
                    dataset_config['shuffle_buffer'],
                    json.dumps(dataset_config.get('data_quality_metrics', {}))
                )
            )
            
            # Insert model configuration
            conn.execute(
                """
                INSERT INTO model_configs
                (experiment_id, model_type, architecture, hyperparameters)
                VALUES (?, ?, ?, ?)
                """,
                (
                    experiment_id,
                    model_config['model_type'],
                    json.dumps(model_config['architecture']),
                    json.dumps(model_config.get('hyperparameters', {}))
                )
            )
            
            # Insert initial training info
            conn.execute(
                """
                INSERT INTO training_info
                (experiment_id, config, start_time, epochs_completed, training_history, final_metrics)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    experiment_id,
                    json.dumps(training_config),
                    None,
                    0,
                    "{}",
                    "{}"
                )
            )
        
        # Create experiment directory structure
        exp_dir = self.model_store / experiment_id
        exp_dir.mkdir(parents=True, exist_ok=True)
        
        # Save configurations as YAML for easy reading
        configs_dir = exp_dir / "configs"
        configs_dir.mkdir(exist_ok=True)
        
        with open(configs_dir / "dataset_config.yaml", 'w') as f:
            yaml.dump(dataset_config, f)
        with open(configs_dir / "model_config.yaml", 'w') as f:
            yaml.dump(model_config, f)
        with open(configs_dir / "training_config.yaml", 'w') as f:
            yaml.dump(training_config, f)
            
        return experiment_id
        
    def update_training_progress(self,
                               experiment_id: str,
                               epoch: int,
                               metrics: Dict[str, float],
                               hardware_metrics: Optional[Dict] = None):
        """Update training progress and metrics"""
        with sqlite3.connect(self.db_path) as conn:
            current = conn.execute(
                "SELECT training_history FROM training_info WHERE experiment_id = ?",
                (experiment_id,)
            ).fetchone()
            
            if current is None:
                raise ValueError(f"No experiment found with id {experiment_id}")
                
            history = json.loads(current[0])
            
            # Update history
            if str(epoch) not in history:
                history[str(epoch)] = {}
            history[str(epoch)].update(metrics)
            
            # Update training info
            updates = {
                "epochs_completed": epoch,
                "training_history": json.dumps(history)
            }
            
            if hardware_metrics:
                updates["hardware_metrics"] = json.dumps(hardware_metrics)
                
            update_sql = "UPDATE training_info SET " + \
                        ", ".join(f"{k} = ?" for k in updates.keys()) + \
                        " WHERE experiment_id = ?"
            
            conn.execute(update_sql, list(updates.values()) + [experiment_id])
            
    def complete_training(self,
                         experiment_id: str,
                         final_metrics: Dict[str, float]):
        """Mark training as complete and save final metrics"""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute(
                """
                UPDATE training_info 
                SET end_time = ?, final_metrics = ?
                WHERE experiment_id = ?
                """,
                (datetime.now(), json.dumps(final_metrics), experiment_id)
            )
            
            conn.execute(
                "UPDATE experiments SET status = ? WHERE experiment_id = ?",
                ("completed", experiment_id)
            )

    def save_checkpoint(self, 
                       experiment_id: str,
                       models: Dict[str, Any],
                       checkpoint_name: str = "latest",
                       metadata: Optional[Dict] = None):
        """
        Save model checkpoints for an experiment
        
        Args:
            experiment_id: Experiment identifier
            models: Dictionary of named models to save
            checkpoint_name: Name for this checkpoint
            metadata: Optional metadata about the checkpoint
        """
        exp_dir = self.model_store / experiment_id
        checkpoint_dir = exp_dir / "checkpoints" / checkpoint_name
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        
        checkpoint_id = str(uuid.uuid4())
        
        # Save each model
        for name, model in models.items():
            model_path = checkpoint_dir / name
            model.save(model_path)
            
        # Save checkpoint metadata
        if metadata is None:
            metadata = {}
        metadata.update({
            "saved_models": list(models.keys()),
            "checkpoint_path": str(checkpoint_dir)
        })
            
        # Record checkpoint in database
        with sqlite3.connect(self.db_path) as conn:
            conn.execute(
                """
                INSERT INTO checkpoints
                (checkpoint_id, experiment_id, name, timestamp, metadata)
                VALUES (?, ?, ?, ?, ?)
                """,
                (
                    checkpoint_id,
                    experiment_id,
                    checkpoint_name,
                    datetime.now(),
                    json.dumps(metadata)
                )
            )
            
            # Update experiment status
            conn.execute(
                "UPDATE experiments SET status = ? WHERE experiment_id = ?",
                ("checkpoint_saved", experiment_id)
            )
            
    def load_checkpoint(self, 
                       experiment_id: str,
                       checkpoint_name: str = "latest") -> Dict[str, str]:
        """
        Get paths to saved model checkpoints
        
        Returns:
            Dictionary of model names to their saved paths
        """
        with sqlite3.connect(self.db_path) as conn:
            result = conn.execute(
                """
                SELECT metadata FROM checkpoints 
                WHERE experiment_id = ? AND name = ?
                ORDER BY timestamp DESC LIMIT 1
                """,
                (experiment_id, checkpoint_name)
            ).fetchone()
            
        if result is None:
            raise ValueError(
                f"No checkpoint '{checkpoint_name}' found for experiment {experiment_id}"
            )
            
        metadata = json.loads(result[0])
        checkpoint_path = Path(metadata["checkpoint_path"])
        
        if not checkpoint_path.exists():
            raise ValueError(f"Checkpoint directory not found: {checkpoint_path}")
            
        return {
            model_name: str(checkpoint_path / model_name)
            for model_name in metadata["saved_models"]
            if (checkpoint_path / model_name).exists()
        }

    def query_experiments(self, 
                         filters: Optional[Dict] = None,
                         metrics: Optional[List[str]] = None,
                         sort_by: Optional[str] = None,
                         ascending: bool = True) -> pd.DataFrame:
        """
        Query experiments with enhanced filtering and sorting
        
        Args:
            filters: Dictionary of column:value pairs to filter on
            metrics: List of specific metrics to include
            sort_by: Column to sort by
            ascending: Sort order
        
        Returns:
            DataFrame of matching experiments
        """
        # Build base query joining all tables
        query = """
        SELECT e.*, dc.*, mc.*, ti.* 
        FROM experiments e
        LEFT JOIN dataset_configs dc ON e.experiment_id = dc.experiment_id
        LEFT JOIN model_configs mc ON e.experiment_id = mc.experiment_id
        LEFT JOIN training_info ti ON e.experiment_id = ti.experiment_id
        """
        
        params = []
        if filters:
            conditions = []
            for column, value in filters.items():
                table = {
                    'name': 'e',
                    'status': 'e',
                    'model_type': 'mc',
                    'batch_size': 'dc'
                }.get(column.split('.')[0], 'e')
                
                conditions.append(f"{table}.{column} = ?")
                params.append(value)
                
            if conditions:
                query += " WHERE " + " AND ".join(conditions)
                
        with sqlite3.connect(self.db_path) as conn:
            df = pd.read_sql_query(query, conn, params=params)
            
        # Parse JSON columns
        json_columns = {
            'environment_info': 'experiments',
            'run_numbers': 'dataset_configs',
            'selections': 'dataset_configs',
            'normalization_params': 'dataset_configs',
            'data_quality_metrics': 'dataset_configs',
            'architecture': 'model_configs',
            'hyperparameters': 'model_configs',
            'config': 'training_info',
            'training_history': 'training_info',
            'final_metrics': 'training_info',
            'hardware_metrics': 'training_info'
        }
        
        for col in json_columns:
            if col in df.columns:
                df[col] = df[col].apply(lambda x: json.loads(x) if pd.notna(x) else {})
        
        # Extract specific metrics if requested
        if metrics:
            for metric in metrics:
                df[f"metric_{metric}"] = df['final_metrics'].apply(
                    lambda x: x.get(metric, None) if isinstance(x, dict) else None
                )
                
        # Sort if requested
        if sort_by:
            df = df.sort_values(sort_by, ascending=ascending)
            
        return df
        
    def delete_experiment(self, experiment_id: str):
        """Delete an experiment and all associated files"""
        with sqlite3.connect(self.db_path) as conn:
            # Delete from all tables
            for table in ['training_info', 'model_configs', 'dataset_configs', 
                         'checkpoints', 'experiments']:
                conn.execute(f"DELETE FROM {table} WHERE experiment_id = ?", 
                           (experiment_id,))
            
        # Delete files
        exp_dir = self.model_store / experiment_id
        if exp_dir.exists():
            shutil.rmtree(exp_dir)
            
    def get_experiment_details(self, experiment_id: str) -> dict:
        """Get complete experiment information including all configs and results"""
        with sqlite3.connect(self.db_path) as conn:
            # Get all information from all tables
            experiment = conn.execute(
                "SELECT * FROM experiments WHERE experiment_id = ?",
                (experiment_id,)
            ).fetchone()
            
            if experiment is None:
                raise ValueError(f"No experiment found with id {experiment_id}")
                
            dataset_config = conn.execute(
                "SELECT * FROM dataset_configs WHERE experiment_id = ?",
                (experiment_id,)
            ).fetchone()
            
            model_config = conn.execute(
                "SELECT * FROM model_configs WHERE experiment_id = ?",
                (experiment_id,)
            ).fetchone()
            
            training_info = conn.execute(
                "SELECT * FROM training_info WHERE experiment_id = ?",
                (experiment_id,)
            ).fetchone()
            
            checkpoints = conn.execute(
                "SELECT * FROM checkpoints WHERE experiment_id = ?",
                (experiment_id,)
            ).fetchall()
            
        # Combine all information
        details = {
            "experiment_info": dict(zip(
                ['experiment_id', 'timestamp', 'name', 'description', 'status', 'environment_info'],
                experiment
            )),
            "dataset_config": dict(zip(
                ['experiment_id', 'run_numbers', 'selections', 'normalization_params',
                 'train_fraction', 'validation_fraction', 'test_fraction',
                 'batch_size', 'shuffle_buffer', 'data_quality_metrics'],
                dataset_config
            )),
            "model_config": dict(zip(
                ['experiment_id', 'model_type', 'architecture', 'hyperparameters'],
                model_config
            )),
            "training_info": dict(zip(
                ['experiment_id', 'config', 'start_time', 'end_time', 
                 'epochs_completed', 'training_history', 'final_metrics', 'hardware_metrics'],
                training_info
            )),
            "checkpoints": [
                dict(zip(
                    ['checkpoint_id', 'experiment_id', 'name', 'timestamp', 'metadata'],
                    checkpoint
                ))
                for checkpoint in checkpoints
            ]
        }
        
        # Parse JSON fields
        details['experiment_info']['environment_info'] = json.loads(details['experiment_info']['environment_info'])
        details['dataset_config']['run_numbers'] = json.loads(details['dataset_config']['run_numbers'])
        details['dataset_config']['selections'] = json.loads(details['dataset_config']['selections'])
        details['dataset_config']['normalization_params'] = json.loads(details['dataset_config']['normalization_params'])
        details['dataset_config']['data_quality_metrics'] = json.loads(details['dataset_config']['data_quality_metrics'])
        details['model_config']['architecture'] = json.loads(details['model_config']['architecture'])
        details['model_config']['hyperparameters'] = json.loads(details['model_config']['hyperparameters'])
        details['training_info']['config'] = json.loads(details['training_info']['config'])
        details['training_info']['training_history'] = json.loads(details['training_info']['training_history'])
        details['training_info']['final_metrics'] = json.loads(details['training_info']['final_metrics'])
        if details['training_info']['hardware_metrics']:
            details['training_info']['hardware_metrics'] = json.loads(details['training_info']['hardware_metrics'])
        
        for checkpoint in details['checkpoints']:
            checkpoint['metadata'] = json.loads(checkpoint['metadata'])
            
        return details
    
    def get_performance_summary(self, experiment_id: str) -> Dict[str, Any]:
        """Get summary of model performance metrics
        
        Args:
            experiment_id: Experiment identifier
            
        Returns:
            Dictionary containing performance metrics and training time
        """
        experiment = self.get_experiment_details(experiment_id)
        training_info = experiment['training_info']
        
        # Calculate training time
        if training_info['start_time'] and training_info['end_time']:
            start = datetime.fromisoformat(training_info['start_time'])
            end = datetime.fromisoformat(training_info['end_time'])
            training_duration = (end - start).total_seconds()
        else:
            training_duration = None
            
        # Get final metrics
        final_metrics = training_info['final_metrics']
        
        # Get training history progression
        history = training_info['training_history']
        metric_progression = {
            metric: [epoch_data.get(metric) for epoch_data in history.values()]
            for metric in set().union(*[epoch_data.keys() for epoch_data in history.values()])
        }
        
        return {
            'training_duration': training_duration,
            'epochs_completed': training_info['epochs_completed'],
            'final_metrics': final_metrics,
            'metric_progression': metric_progression,
            'hardware_metrics': training_info.get('hardware_metrics', {})
        }

    def compare_experiments(self, 
                          experiment_ids: List[str],
                          metrics: List[str] = None) -> pd.DataFrame:
        """Compare multiple experiments
        
        Args:
            experiment_ids: List of experiment IDs to compare
            metrics: Optional list of specific metrics to compare
            
        Returns:
            DataFrame with experiment comparison
        """
        experiments = []
        for exp_id in experiment_ids:
            details = self.get_experiment_details(exp_id)
            exp_summary = {
                'experiment_id': exp_id,
                'name': details['experiment_info']['name'],
                'status': details['experiment_info']['status'],
                'model_type': details['model_config']['model_type'],
                'run_numbers': details['dataset_config']['run_numbers'],
                'batch_size': details['dataset_config']['batch_size'],
                'epochs_completed': details['training_info']['epochs_completed']
            }
            
            # Add final metrics
            if metrics:
                for metric in metrics:
                    exp_summary[f'final_{metric}'] = details['training_info']['final_metrics'].get(metric)
                    
            # Add training duration if available
            if details['training_info']['start_time'] and details['training_info']['end_time']:
                start = datetime.fromisoformat(details['training_info']['start_time'])
                end = datetime.fromisoformat(details['training_info']['end_time'])
                exp_summary['training_duration'] = (end - start).total_seconds()
                
            experiments.append(exp_summary)
            
        return pd.DataFrame(experiments)

In [6]:
class ATLASDataManager:
    """Manages ATLAS PHYSLITE data access"""
    
    def __init__(self, base_dir: str = "atlas_data"):
        self.base_dir = Path(base_dir)
        self.base_url = "https://opendata.cern.ch/record/80001/files"
        self._setup_directories()
        self.catalog_counts = {}  # Cache for number of catalogs per run
    
    def get_catalog_count(self, run_number: str) -> int:
        """
        Discover how many catalog files exist for a run by probing the server
        
        Args:
            run_number: ATLAS run number
            
        Returns:
            Number of available catalog files
        """
        if run_number in self.catalog_counts:
            return self.catalog_counts[run_number]
            
        padded_run = run_number.zfill(8)
        index = 0
        
        while True:
            url = f"/record/80001/files/data16_13TeV_Run_{padded_run}_file_index.json_{index}"
            response = requests.head(f"https://opendata.cern.ch{url}")
            
            if response.status_code != 200:
                break
                
            index += 1
        
        self.catalog_counts[run_number] = index
        return index
    
    def download_run_catalog(self, run_number: str, index: int = 0) -> Optional[Path]:
        """
        Download a specific run catalog file.
        
        Args:
            run_number: ATLAS run number
            index: Catalog index
            
        Returns:
            Path to the downloaded catalog file or None if file doesn't exist
        """
        padded_run = run_number.zfill(8)
        url = f"/record/80001/files/data16_13TeV_Run_{padded_run}_file_index.json_{index}"
        output_path = self.base_dir / "catalogs" / f"Run_{run_number}_catalog_{index}.root"
        
        try:
            if self._download_file(url, output_path, f"Downloading catalog {index} for Run {run_number}"):
                return output_path
        except Exception as e:
            print(f"Failed to download catalog {index} for run {run_number}: {str(e)}")
            if output_path.exists():
                output_path.unlink()  # Clean up partial download
            return None
    
    # def get_run_info(self, run_number: str) -> Dict:
    #     """
    #     Get information about available catalogs for a run
        
    #     Returns:
    #         Dictionary containing run information including number of catalogs
    #     """
    #     n_catalogs = self.get_catalog_count(run_number)
        
    #     # Test download of first catalog to get sample info
    #     sample_path = self.download_run_catalog(run_number, 0)
    #     sample_info = {}
        
    #     if sample_path and sample_path.exists():
    #         try:
    #             with uproot.open(sample_path) as file:
    #                 tree = file["CollectionTree;1"]
    #                 sample_info = {
    #                     'events_per_catalog': len(tree),
    #                     'estimated_total_events': len(tree) * n_catalogs
    #                 }
    #         finally:
    #             sample_path.unlink()  # Clean up sample file
        
    #     return {
    #         'run_number': run_number,
    #         'n_catalogs': n_catalogs,
    #         **sample_info
    #     }
    
    # def iter_catalogs(self, run_number: str, start_index: int = 0, end_index: Optional[int] = None):
    #     """
    #     Iterator that yields catalogs for a run one at a time, cleaning up after each
        
    #     Args:
    #         run_number: Run number to process
    #         start_index: Starting catalog index
    #         end_index: Optional ending catalog index (exclusive)
            
    #     Yields:
    #         Tuple of (index, catalog_path)
    #     """
    #     if end_index is None:
    #         end_index = self.get_catalog_count(run_number)
        
    #     for index in range(start_index, end_index):
    #         catalog_path = self.download_run_catalog(run_number, index)
    #         if catalog_path and catalog_path.exists():
    #             try:
    #                 yield index, catalog_path
    #             finally:
    #                 catalog_path.unlink()  # Clean up after processing
    
    def _setup_directories(self):
        """Create necessary directory structure"""
        self.base_dir.mkdir(exist_ok=True)
        (self.base_dir / "catalogs").mkdir(exist_ok=True)

    
    def _download_file(self, url: str, output_path: Path, desc: str = None) -> bool:
        """Download a single file if it doesn't exist"""
        if output_path.exists():
            return False
        
        print(f"Downloading file: {url}")    
        response = requests.get(f"https://opendata.cern.ch{url}", stream=True)
        if response.status_code == 200:
            total_size = int(response.headers.get('content-length', 0))
            
            with open(output_path, 'wb') as f, tqdm(
                desc=desc,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                for data in response.iter_content(chunk_size=1024):
                    size = f.write(data)
                    pbar.update(size)
            return True
        else:
            raise Exception(f"Download failed with status code: {response.status_code}")
    
    def get_run_catalog_path(self, run_number: str, index: int = 0) -> Path:
        """Get path to a run catalog file"""
        return self.base_dir / "catalogs" / f"Run_{run_number}_catalog_{index}.root"
    
    # def verify_catalog_file(self, file_path: Path) -> bool:
    #     """Verify that a catalog file contains valid track data"""
    #     try:
    #         with uproot.open(file_path) as file:
    #             if "CollectionTree;1" not in file:
    #                 print(f"No CollectionTree found in {file_path}")
    #                 return False
                
    #             tree = file["CollectionTree;1"]
    #             required_branches = [
    #                 "InDetTrackParticlesAuxDyn.d0",
    #                 "InDetTrackParticlesAuxDyn.z0",
    #                 "InDetTrackParticlesAuxDyn.phi",
    #                 "InDetTrackParticlesAuxDyn.theta",
    #                 "InDetTrackParticlesAuxDyn.qOverP"
    #             ]
                
    #             for branch in required_branches:
    #                 if branch not in tree:
    #                     print(f"Missing required branch: {branch}")
    #                     return False
                
    #             # Try reading some data
    #             try:
    #                 data = tree["InDetTrackParticlesAuxDyn.d0"].array(library="np")
    #                 if len(data) == 0:
    #                     print(f"No events found in {file_path}")
    #                     return False
    #             except Exception as e:
    #                 print(f"Error reading data: {str(e)}")
    #                 return False
                
    #             return True
                
    #     except Exception as e:
    #         print(f"Error verifying file {file_path}: {str(e)}")
    #         return False
    
    # def get_stats(self) -> Dict:
    #     """
    #     Get statistics about downloaded catalog files.
        
    #     Returns:
    #         Dictionary containing catalog statistics
    #     """
    #     catalog_dir = self.base_dir / "catalogs"
    #     catalog_files = list(catalog_dir.glob("*.root"))
        
    #     total_events = 0
    #     total_tracks = 0
        
    #     # Calculate events and tracks from catalogs
    #     for catalog in catalog_files:
    #         try:
    #             with uproot.open(catalog) as file:
    #                 if "CollectionTree;1" in file:
    #                     tree = file["CollectionTree;1"]
    #                     total_events += len(tree)
                        
    #                     # Get track count from d0 branch
    #                     try:
    #                         tracks = tree["InDetTrackParticlesAuxDyn.d0"].array(library="np")
    #                         total_tracks += sum(len(t) for t in tracks)
    #                     except Exception:
    #                         pass
    #         except Exception as e:
    #             print(f"Warning: Could not read stats from {catalog}: {e}")
        
    #     stats = {
    #         "catalogs": len(catalog_files),
    #         "total_events": total_events,
    #         "total_tracks": total_tracks,
    #         "total_size": sum(f.stat().st_size for f in catalog_files) / (1024 * 1024 * 1024)  # in GB
    #     }
    #     return stats
    
    # def print_status(self):
    #     """Print current status of catalogs"""
    #     stats = self.get_stats()
    #     print(f"=== ATLAS Data Status ===")
    #     print(f"Number of catalogs: {stats['catalogs']}")
    #     print(f"Total events: {stats['total_events']:,}")
    #     print(f"Total tracks: {stats['total_tracks']:,}")
    #     print(f"Total data size: {stats['total_size']:.2f} GB")



In [7]:
class SelectionConfig:
    """Configuration for track and event selections"""
    def __init__(self,
                 # Required parameters
                 max_tracks_per_event: int,
                 min_tracks_per_event: int,
                 # Optional selection criteria
                 track_selections: Optional[Dict[str, Union[float, Tuple[float, float]]]] = None,
                 event_selections: Optional[Dict[str, Union[float, Tuple[float, float]]]] = None):
        # Required parameters
        self.max_tracks_per_event = max_tracks_per_event
        self.min_tracks_per_event = min_tracks_per_event
        
        # Optional selections
        self.track_selections = track_selections or {}
        self.event_selections = event_selections or {}
        
    def apply_track_selections(self, track_features: Dict[str, np.ndarray]) -> np.ndarray:
        """Apply selections to individual tracks"""
        # Start with all tracks
        mask = np.ones(len(next(iter(track_features.values()))), dtype=bool)
        
        # Map feature names to indices
        feature_map = {
            'pt': track_features['pt'],
            'eta': track_features['eta'],
            'phi': track_features['phi'],
            'd0': track_features['d0'],
            'z0': track_features['z0'],
            'chi2_per_ndof': track_features['chi2_per_ndof']
        }
        
        # Apply each selection
        for feature, criteria in self.track_selections.items():
            if feature not in feature_map:
                continue
                
            feature_values = feature_map[feature]
            if isinstance(criteria, tuple):
                min_val, max_val = criteria
                if min_val is not None:
                    mask &= (feature_values >= min_val)
                if max_val is not None:
                    mask &= (feature_values <= max_val)
            else:
                mask &= (feature_values >= criteria)
                
        return mask
    
    def apply_event_selections(self, event_features: Dict[str, np.ndarray]) -> bool:
        """Apply selections to entire events"""
        # First check minimum tracks requirement
        n_tracks = len(event_features.get('pt', []))
        if n_tracks < self.min_tracks_per_event:
            return False
            
        # Apply event-level selections
        for feature, criteria in self.event_selections.items():
            if feature not in event_features:
                continue
                
            value = event_features[feature]
            if isinstance(criteria, tuple):
                min_val, max_val = criteria
                if min_val is not None and value < min_val:
                    return False
                if max_val is not None and value > max_val:
                    return False
            else:
                if value < criteria:
                    return False
        
        return True

In [8]:
class IntegratedDatasetManager:
    """Manages dataset creation from ATLAS PHYSLITE data"""
    def __init__(self, 
                 run_numbers: List[str],
                 catalog_limit: Optional[int] = None,
                 base_dir: str = "atlas_data",
                 batch_size: int = 1000,
                 cache_size: Optional[int] = None,
                 shuffle_buffer: int = 10000,
                 validation_fraction: float = 0.2):
        self.run_numbers = run_numbers
        self.catalog_limit = catalog_limit
        self.batch_size = batch_size
        self.cache_size = cache_size
        self.shuffle_buffer = shuffle_buffer
        self.validation_fraction = validation_fraction
        self.atlas_manager = ATLASDataManager(base_dir)
        self.normalization_params = None
        
    # def _read_catalog_data(self, catalog_path: Path) -> Dict[str, np.ndarray]:
    #     """Read track data from a catalog file, maintaining event structure"""
    #     try:
    #         print(f"\nReading catalog: {catalog_path}")
    #         with uproot.open(catalog_path) as file:
    #             tree = file["CollectionTree;1"]
    #             events = []
                
    #             # Process in chunks to manage memory
    #             for arrays in tree.iterate([
    #                 "InDetTrackParticlesAuxDyn.d0",
    #                 "InDetTrackParticlesAuxDyn.z0",
    #                 "InDetTrackParticlesAuxDyn.phi",
    #                 "InDetTrackParticlesAuxDyn.theta",
    #                 "InDetTrackParticlesAuxDyn.qOverP",
    #                 "InDetTrackParticlesAuxDyn.chiSquared",
    #                 "InDetTrackParticlesAuxDyn.numberDoF"
    #             ], library="np"):
    #                 for evt_idx in range(len(arrays["InDetTrackParticlesAuxDyn.d0"])):
    #                     event_tracks = {
    #                         'd0': arrays["InDetTrackParticlesAuxDyn.d0"][evt_idx],
    #                         'z0': arrays["InDetTrackParticlesAuxDyn.z0"][evt_idx],
    #                         'phi': arrays["InDetTrackParticlesAuxDyn.phi"][evt_idx],
    #                         'theta': arrays["InDetTrackParticlesAuxDyn.theta"][evt_idx],
    #                         'qOverP': arrays["InDetTrackParticlesAuxDyn.qOverP"][evt_idx],
    #                         'chiSquared': arrays["InDetTrackParticlesAuxDyn.chiSquared"][evt_idx],
    #                         'numberDoF': arrays["InDetTrackParticlesAuxDyn.numberDoF"][evt_idx]
    #                     }
                        
    #                     # Calculate derived quantities for this event's tracks
    #                     pt, eta, chi2_per_ndof = self._calculate_track_quantities(event_tracks)
                        
    #                     # Select top N tracks by pT
    #                     track_features = np.column_stack([pt, eta, event_tracks['phi'], 
    #                                                     event_tracks['d0'], event_tracks['z0'], 
    #                                                     chi2_per_ndof])
                        
    #                     events.append(track_features)
                
    #             return events
            
    #     except Exception as e:
    #         print(f"Error reading catalog {catalog_path}: {str(e)}")
    #         return None
    
    def event_generator(self, selection_config: SelectionConfig):
        """
        Generator that yields processed events one at a time
        
        Args:
            selection_config: SelectionConfig object containing selection criteria
            
        Yields:
            Numpy array of shape (max_tracks_per_event, n_features)
        """
        for run_number in self.run_numbers:
            print(f"\nProcessing run {run_number}")
            catalog_idx = 0
            
            while True:
                try:
                    # Get or download catalog
                    catalog_path = self.atlas_manager.get_run_catalog_path(run_number, catalog_idx)
                    if not catalog_path.exists():
                        catalog_path = self.atlas_manager.download_run_catalog(run_number, catalog_idx)
                        if catalog_path is None:
                            break  # No more catalogs for this run
                    
                    print(f"Processing catalog {catalog_idx}")
                    
                    # Open file and process events
                    with uproot.open(catalog_path) as file:
                        tree = file["CollectionTree;1"]
                        
                        # Define all branches we need
                        branches = [
                            "InDetTrackParticlesAuxDyn.d0",
                            "InDetTrackParticlesAuxDyn.z0",
                            "InDetTrackParticlesAuxDyn.phi",
                            "InDetTrackParticlesAuxDyn.theta",
                            "InDetTrackParticlesAuxDyn.qOverP",
                            "InDetTrackParticlesAuxDyn.chiSquared",
                            "InDetTrackParticlesAuxDyn.numberDoF"
                        ]
                        
                        # Process in chunks to manage memory
                        for arrays in tree.iterate(branches, library="np", step_size=1000):
                            n_events = len(arrays["InDetTrackParticlesAuxDyn.d0"])
                            
                            for evt_idx in range(n_events):
                                # Extract raw event data - directly use the numpy array for this event
                                raw_event = {
                                    'd0': arrays["InDetTrackParticlesAuxDyn.d0"][evt_idx],  # Already a numpy array
                                    'z0': arrays["InDetTrackParticlesAuxDyn.z0"][evt_idx],
                                    'phi': arrays["InDetTrackParticlesAuxDyn.phi"][evt_idx],
                                    'theta': arrays["InDetTrackParticlesAuxDyn.theta"][evt_idx],
                                    'qOverP': arrays["InDetTrackParticlesAuxDyn.qOverP"][evt_idx],
                                    'chiSquared': arrays["InDetTrackParticlesAuxDyn.chiSquared"][evt_idx],
                                    'numberDoF': arrays["InDetTrackParticlesAuxDyn.numberDoF"][evt_idx]
                                }
                                
                                # Skip empty events (no tracks)
                                if len(raw_event['d0']) == 0:
                                    continue
                                    
                                processed_event = self._process_event(raw_event, selection_config)
                                if processed_event is not None:
                                    yield processed_event
                    
                    catalog_idx += 1

                    if self.catalog_limit and catalog_idx > self.catalog_limit:
                        break
                    
                except Exception as e:
                    print(f"Error processing catalog {catalog_idx} of run {run_number}: {str(e)}")
                    break
                    
                finally:
                    # Clean up catalog file
                    print("Not cleaning up catalog files for faster testing")
                    # if catalog_path and catalog_path.exists():
                    #     catalog_path.unlink()

    def _process_event(self, 
                    event_tracks: Dict[str, np.ndarray], 
                    selection_config: SelectionConfig) -> Optional[np.ndarray]:
        """Process a single event's tracks with selections"""
        # Calculate derived quantities
        track_features = {
            'pt': np.abs(1.0 / (event_tracks['qOverP'] * 1000)) * np.sin(event_tracks['theta']),
            'eta': -np.log(np.tan(event_tracks['theta'] / 2)),
            'phi': event_tracks['phi'],
            'd0': event_tracks['d0'],
            'z0': event_tracks['z0'],
            'chi2_per_ndof': event_tracks['chiSquared'] / event_tracks['numberDoF']
        }
        
        # Add event-level features
        event_features = {
            'n_total_tracks': len(track_features['pt']),
            'mean_pt': np.mean(track_features['pt']),
            'max_pt': np.max(track_features['pt'])
        }
        
        # Apply event-level selections first (cheaper)
        if not selection_config.apply_event_selections(event_features):
            return None
        
        # Apply track-level selections
        good_tracks_mask = selection_config.apply_track_selections(track_features)
        good_tracks = np.where(good_tracks_mask)[0]
        
        # Check if we still have enough tracks
        if len(good_tracks) < selection_config.min_tracks_per_event:
            return None
        
        # Sort by pT and take top N tracks
        track_pts = track_features['pt'][good_tracks]
        sorted_indices = np.argsort(track_pts)[::-1]
        top_tracks = good_tracks[sorted_indices[:selection_config.max_tracks_per_event]]
        
        # Create feature array
        features = np.column_stack([
            track_features['pt'][top_tracks],
            track_features['eta'][top_tracks],
            track_features['phi'][top_tracks],
            track_features['d0'][top_tracks],
            track_features['z0'][top_tracks],
            track_features['chi2_per_ndof'][top_tracks]
        ])
        
        # Pad if necessary
        if len(features) < selection_config.max_tracks_per_event:
            padding = np.zeros((selection_config.max_tracks_per_event - len(features), 6))
            features = np.vstack([features, padding])
        
        return features

    def create_streaming_dataset(self, 
                               selection_config: SelectionConfig,
                               compute_normalizing_stats: bool = True) -> tf.data.Dataset:
        """
        Create a streaming dataset from the event generator
        
        Args:
            selection_config: Optional dictionary of selection criteria
            compute_normalizing_stats: Whether to compute normalization parameters
                                     from a sample of the data
        """
        print("\nCreating streaming dataset")
        # Create dataset from generator
        dataset = tf.data.Dataset.from_generator(
            lambda: self.event_generator(selection_config),
            output_signature=tf.TensorSpec(
                shape=(selection_config.max_tracks_per_event, 6),
                dtype=tf.float32
            )
        )

        print("Normalizing datatset")
        
        # Compute normalization parameters if needed
        if compute_normalizing_stats and self.normalization_params is None:
            print("\nComputing normalization parameters from sample...")
            sample_size = 10000
            sample_events = []
            for event in self.event_generator(selection_config):
                sample_events.append(event)
                if len(sample_events) >= sample_size:
                    break
            
            if sample_events:
                sample_data = np.vstack(sample_events)
                self.normalization_params = {
                    'means': np.mean(sample_data, axis=0),
                    'stds': np.std(sample_data, axis=0)
                }
        
        # Add normalization if parameters are available
        if self.normalization_params is not None:
            dataset = dataset.map(
                lambda x: (x - self.normalization_params['means']) / self.normalization_params['stds']
            )
        
        # Add shuffling and batching
        if self.cache_size:
            dataset = dataset.cache()
        dataset = dataset.shuffle(self.shuffle_buffer)
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        
        return dataset
    
    # def create_train_val_test_datasets(self, 
    #                                  selection_config: SelectionConfig,
    #                                  val_fraction: float = 0.1,
    #                                  test_fraction: float = 0.1) -> Tuple[tf.data.Dataset, 
    #                                                                     tf.data.Dataset, 
    #                                                                     tf.data.Dataset]:
    #     """Create training, validation, and test datasets"""
    #     # Create base dataset
    #     full_dataset = self.create_streaming_dataset(
    #         selection_config=selection_config,
    #         compute_normalizing_stats=True
    #     )

    #     print("Splitting dataset")
        
    #     # Calculate split sizes
    #     val_size = int(val_fraction * 100)
    #     test_size = int(test_fraction * 100)
        
    #     # Split the dataset
    #     test_dataset = full_dataset.take(test_size)
    #     remaining = full_dataset.skip(test_size)
    #     val_dataset = remaining.take(val_size)
    #     train_dataset = remaining.skip(val_size)
        
    #     return train_dataset, val_dataset, test_dataset
        
    # def create_datasets(self, 
    #                    max_tracks_per_event: int = 30,
    #                    min_tracks_per_event: int = 5,
    #                    track_selections: Optional[Dict] = None,
    #                    event_selections: Optional[Dict] = None):
    #     """Main entry point"""
    #     # Create SelectionConfig object
    #     selection_config = SelectionConfig(
    #         max_tracks_per_event=max_tracks_per_event,
    #         min_tracks_per_event=min_tracks_per_event,
    #         track_selections=track_selections,
    #         event_selections=event_selections
    #     )
        
    #     # Create streaming dataset
    #     return self.create_train_val_test_datasets(selection_config)
    
    # def _calculate_track_features(self, raw_event: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    #     """
    #     Calculate all track-level features from raw event data, safely handling edge cases
        
    #     Args:
    #         raw_event: Dictionary containing raw track parameters
            
    #     Returns:
    #         Dictionary of calculated track features
    #     """
    #     try:
    #         # Check for empty or invalid input
    #         if any(len(arr) == 0 for arr in raw_event.values()):
    #             return {
    #                 'pt': np.array([]),
    #                 'eta': np.array([]),
    #                 'phi': np.array([]),
    #                 'd0': np.array([]),
    #                 'z0': np.array([]),
    #                 'chi2_per_ndof': np.array([])
    #             }
                
    #         # Handle potential division by zero in qOverP
    #         qOverP = raw_event['qOverP']
    #         mask = qOverP != 0
    #         pt = np.zeros_like(qOverP)
    #         pt[mask] = np.abs(1.0 / (qOverP[mask] * 1000)) * np.sin(raw_event['theta'][mask])
            
    #         # Calculate eta safely
    #         theta = raw_event['theta']
    #         eta = np.zeros_like(theta)
    #         valid_theta = (theta > 0) & (theta < np.pi)  # Avoid theta = 0 or pi
    #         eta[valid_theta] = -np.log(np.tan(theta[valid_theta] / 2))
            
    #         # Calculate chi2/ndof safely
    #         ndof = raw_event['numberDoF']
    #         chi2 = raw_event['chiSquared']
    #         chi2_per_ndof = np.zeros_like(ndof)
    #         valid_ndof = ndof > 0
    #         chi2_per_ndof[valid_ndof] = chi2[valid_ndof] / ndof[valid_ndof]
            
    #         return {
    #             'pt': pt,
    #             'eta': eta,
    #             'phi': raw_event['phi'],
    #             'd0': raw_event['d0'],
    #             'z0': raw_event['z0'],
    #             'chi2_per_ndof': chi2_per_ndof
    #         }
            
    #     except Exception as e:
    #         print(f"Warning: Error calculating track features: {str(e)}")
    #         print(f"Raw event shapes: {[(k, v.shape) for k, v in raw_event.items()]}")
    #         # Return empty arrays
    #         return {
    #             'pt': np.array([]),
    #             'eta': np.array([]),
    #             'phi': np.array([]),
    #             'd0': np.array([]),
    #             'z0': np.array([]),
    #             'chi2_per_ndof': np.array([])
    #         }

    # def _calculate_event_features(self, track_features: Dict[str, np.ndarray]) -> Dict[str, Union[float, int]]:
    #     """
    #     Calculate event-level features from track features, safely handling empty events
        
    #     Args:
    #         track_features: Dictionary of track-level features
            
    #     Returns:
    #         Dictionary of event-level features
    #     """
    #     pt = track_features['pt']
    #     eta = track_features['eta']
        
    #     # Handle empty events safely
    #     if len(pt) == 0:
    #         return {
    #             'n_total_tracks': 0,
    #             'mean_pt': 0.0,
    #             'max_pt': 0.0,
    #             'min_pt': 0.0,
    #             'eta_spread': 0.0,
    #             'total_pt': 0.0,
    #             'n_high_pt_tracks': 0,
    #             'n_central_tracks': 0
    #         }
        
    #     # Calculate statistics for non-empty events
    #     try:
    #         return {
    #             'n_total_tracks': len(pt),
    #             'mean_pt': float(np.mean(pt)),
    #             'max_pt': float(np.max(pt)),
    #             'min_pt': float(np.min(pt)),
    #             'eta_spread': float(np.std(eta)) if len(eta) > 1 else 0.0,
    #             'total_pt': float(np.sum(pt)),
    #             'n_high_pt_tracks': int(np.sum(pt > 10.0)),
    #             'n_central_tracks': int(np.sum(np.abs(eta) < 1.5))
    #         }
    #     except Exception as e:
    #         print(f"Warning: Error calculating event features: {str(e)}")
    #         print(f"Track features shapes: {[(k, v.shape) for k, v in track_features.items()]}")
    #         # Return safe default values
    #         return {
    #             'n_total_tracks': len(pt),
    #             'mean_pt': 0.0,
    #             'max_pt': 0.0,
    #             'min_pt': 0.0,
    #             'eta_spread': 0.0,
    #             'total_pt': 0.0,
    #             'n_high_pt_tracks': 0,
    #             'n_central_tracks': 0
    #         }
    
    # def compute_normalization(self, features: np.ndarray) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
    #     """Compute normalization parameters from training data"""
    #     print("\nComputing normalization parameters...")
    #     means = np.mean(features, axis=0)
    #     stds = np.std(features, axis=0)
    #     self.normalization_params = {'means': means, 'stds': stds}
    #     return self.normalization_params
    
    # def apply_normalization(self, features: np.ndarray) -> np.ndarray:
    #     """Apply normalization using computed parameters"""
    #     if self.normalization_params is None:
    #         raise ValueError("Normalization parameters not computed. Run compute_normalization first.")
        
    #     normalized = (features - self.normalization_params['means']) / self.normalization_params['stds']
    #     return normalized
    
    # def split_dataset(self, 
    #                     features: np.ndarray,
    #                     train_fraction: float = 0.8,
    #                     validation_fraction: float = 0.1
    #                     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    #     """Split data into training, validation, and test sets
        
    #     Args:
    #         features: Input feature array
    #         train_fraction: Fraction of data for training (default: 0.8)
    #         validation_fraction: Fraction of data for validation (default: 0.1)
            
    #     Returns:
    #         Tuple of (train_features, val_features, test_features)
    #     """
    #     print("\nSplitting into train/validation/test sets...")
    #     total_samples = len(features)
        
    #     # Calculate split sizes
    #     train_size = int(total_samples * train_fraction)
    #     val_size = int(total_samples * validation_fraction)
        
    #     # Shuffle indices
    #     indices = np.random.permutation(total_samples)
        
    #     # Split indices
    #     train_indices = indices[:train_size]
    #     val_indices = indices[train_size:train_size + val_size]
    #     test_indices = indices[train_size + val_size:]
        
    #     # Create splits
    #     train_features = features[train_indices]
    #     val_features = features[val_indices]
    #     test_features = features[test_indices]
        
    #     print(f"Training samples: {len(train_features)}")
    #     print(f"Validation samples: {len(val_features)}")
    #     print(f"Test samples: {len(test_features)}")
        
    #     return train_features, val_features, test_features

    
    def create_datasets(self, 
                    max_tracks_per_event: int = 30,
                    min_tracks_per_event: int = 5,
                    track_selections: Optional[Dict] = None,
                    event_selections: Optional[Dict] = None,
                    validation_fraction: float = 0.1,
                    test_fraction: float = 0.1) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
        """
        Create train/validation/test datasets with streaming and selections
        
        Args:
            max_tracks_per_event: Maximum number of tracks to keep per event
            min_tracks_per_event: Minimum number of tracks required per event
            track_selections: Dictionary of track-level selection criteria
            event_selections: Dictionary of event-level selection criteria
            validation_fraction: Fraction of data for validation
            test_fraction: Fraction of data for test set
        
        Returns:
            Tuple of (train_dataset, val_dataset, test_dataset)
        """

        print("\nInitial system state:")
        print_system_usage("Pre-processing: ")

        # Create selection configuration
        selection_config = SelectionConfig(
            max_tracks_per_event=max_tracks_per_event,
            min_tracks_per_event=min_tracks_per_event,
            track_selections=track_selections,
            event_selections=event_selections
        )
        
        print("Creating base dataset...")
        base_dataset = self.create_streaming_dataset(
            selection_config=selection_config,
            compute_normalizing_stats=True
        )

        print("\nSystem usage after dataset creation:")
        print_system_usage("Post-processing: ")

        # Debug print - check if base dataset has data
        print("Checking base dataset...")
        try:
            for i, batch in enumerate(base_dataset):
                print(f"Base dataset batch {i} shape: {batch.shape}")
                if i == 0:  # Just check first batch
                    break
        except Exception as e:
            print(f"Error checking base dataset: {e}")
        
        # Calculate split sizes (in number of batches)
        val_size = int(validation_fraction * 100)  # Take 100 batches as reference
        test_size = int(test_fraction * 100)
        
        # Split dataset
        print("\nSplitting datasets...")
        test_dataset = base_dataset.take(test_size)
        print("Created test dataset")
        
        remaining = base_dataset.skip(test_size)
        print("Created remaining dataset")
        
        val_dataset = remaining.take(val_size)
        print("Created validation dataset")
        
        train_dataset = remaining.skip(val_size)
        print("Created training dataset")
        
        # Cache each dataset (with unique identifiers)
        print("\nCaching datasets...")
        cache_dir = "/tmp/tf_cache"
        os.makedirs(cache_dir, exist_ok=True)
        test_dataset = test_dataset.cache(f"{cache_dir}/test_cache")
        val_dataset = val_dataset.cache(f"{cache_dir}/val_cache")
        train_dataset = train_dataset.cache(f"{cache_dir}/train_cache")

        # Force cache loading
        print("\nForcing cache load...")
        for dataset, name in [(train_dataset, "train"), 
                            (val_dataset, "validation"), 
                            (test_dataset, "test")]:
            print(f"Loading {name} cache...")
            try:
                for i, batch in enumerate(dataset):
                    if i == 0:
                        print(f"First {name} batch shape: {batch.shape}")
            except Exception as e:
                print(f"Error loading {name} cache: {e}")
        
        # Add shuffling and prefetching for training
        train_dataset = train_dataset.shuffle(
            buffer_size=self.shuffle_buffer,
            reshuffle_each_iteration=True
        )
        
        # Add prefetching to all datasets
        train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
        val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
        test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)

        print("\nFinal system state after caching:")
        print_system_usage("Post-caching: ")
        
        # Print dataset sizes
        print("\nDataset splits:")
        print(f"Training batches: {100 - val_size - test_size}")
        print(f"Validation batches: {val_size}")
        print(f"Test batches: {test_size}")
        
        return train_dataset, val_dataset, test_dataset
    
    # def validate_data_quality(self, data) -> Dict:
        # """Run basic data quality checks on either tf.data.Dataset or numpy array
        
        # Args:
        #     data: Either a tf.data.Dataset or numpy array to validate
            
        # Returns:
        #     Dictionary containing validation results and statistics
        # """
        # feature_info = {
        #     'pt': {'range': (0, 5000)},
        #     'eta': {'range': (-5, 5)},
        #     'phi': {'range': (-np.pi, np.pi)}, 
        #     'd0': {'range': (-10, 10)},
        #     'z0': {'range': (-200, 200)},
        #     'chi2_per_ndof': {'range': (0, 100)}
        # }
        
        # validation_results = {
        #     'total_tracks': 0,
        #     'out_of_range_tracks': {name: 0 for name in feature_info},
        #     'null_values': {name: 0 for name in feature_info},
        #     'status': 'pass'
        # }

        # print(f"\nStarting data validation...")
        # print(f"Input type: {type(data)}")
        
        # # Handle both tf.data.Dataset and numpy array inputs
        # if isinstance(data, tf.data.Dataset):
        #     print("Validating TensorFlow dataset...")
        #     for batch in data:
        #         if isinstance(batch, tuple):  # (input, target) pairs
        #             batch = batch[0]  # Take input part
        #         batch_np = batch.numpy()
        #         validation_results['total_tracks'] += len(batch_np)
        #         self._validate_batch(batch_np, validation_results, feature_info)
        # else:  # Assuming numpy array
        #     print(f"Validating numpy array with shape: {data.shape}")
        #     validation_results['total_tracks'] = len(data)
        #     self._validate_batch(data, validation_results, feature_info)
            
        # # Set status based on validation results
        # for name in feature_info:
        #     if validation_results['null_values'][name] > 0:
        #         print(f"Found null values in {name}")
        #         validation_results['status'] = 'fail'
        #         break
                
        #     # Allow a small fraction (0.1%) of tracks to be out of range
        #     out_of_range_fraction = (validation_results['out_of_range_tracks'][name] / 
        #                         validation_results['total_tracks'])
        #     if out_of_range_fraction > 0.001:
        #         print(f"Too many out of range values in {name}: {out_of_range_fraction:.3%}")
        #         validation_results['status'] = 'warning'
        
        # # Add summary statistics
        # validation_results['summary'] = {
        #     'null_fraction': {
        #         name: count / validation_results['total_tracks']
        #         for name, count in validation_results['null_values'].items()
        #     },
        #     'out_of_range_fraction': {
        #         name: count / validation_results['total_tracks']
        #         for name, count in validation_results['out_of_range_tracks'].items()
        #     }
        # }
        
        # print(f"Validation complete. Status: {validation_results['status']}")
        # return validation_results

    # def _validate_batch(self, batch_data: np.ndarray, validation_results: Dict, feature_info: Dict):
        # """Helper method to validate a batch of data
        
        # Args:
        #     batch_data: Numpy array of shape (batch_size, n_features)
        #     validation_results: Dictionary to store validation results
        #     feature_info: Dictionary containing valid ranges for each feature
        # """
        # for i, (name, info) in enumerate(feature_info.items()):
        #     feature_data = batch_data[:, i]
            
        #     # Check for null/nan values
        #     null_mask = np.isnan(feature_data) | np.isinf(feature_data)
        #     n_nulls = np.sum(null_mask)
        #     validation_results['null_values'][name] += n_nulls
        #     if n_nulls > 0:
        #         print(f"Found {n_nulls} null/inf values in {name}")
            
        #     # Check value ranges
        #     if 'range' in info:
        #         min_val, max_val = info['range']
        #         range_mask = (feature_data < min_val) | (feature_data > max_val)
        #         n_out_of_range = np.sum(range_mask)
        #         validation_results['out_of_range_tracks'][name] += n_out_of_range
        #         if n_out_of_range > 0:
        #             print(f"Found {n_out_of_range} out-of-range values in {name}")
        
    # def get_data_status(self) -> Dict:
        # """
        # Get status of catalog files and processing configuration
        
        # Returns:
        #     Dictionary containing data and processing status
        # """
        # # Get basic stats from atlas manager
        # stats = self.atlas_manager.get_stats()
        
        # # Add processing configuration
        # stats.update({
        #     "runs": len(self.run_numbers),
        #     "cache_size": self.cache_size,
        #     "batch_size": self.batch_size,
        #     "shuffle_buffer": self.shuffle_buffer,
        #     "run_numbers": self.run_numbers
        # })
        
        # # Add average tracks per event
        # if stats['total_events'] > 0:
        #     stats['avg_tracks_per_event'] = stats['total_tracks'] / stats['total_events']
        
        # # Add data quality metrics if a dataset has been created
        # try:
        #     dataset = self.create_inner_track_dataset()
        #     validation_results = self.validate_data_quality(dataset)
        #     stats['data_quality'] = {
        #         'status': validation_results['status'],
        #         'null_fractions': validation_results['summary']['null_fraction'],
        #         'out_of_range_fractions': validation_results['summary']['out_of_range_fraction']
        #     }
        # except Exception as e:
        #     stats['data_quality'] = {'status': 'unknown', 'error': str(e)}
        
        # return stats
        
    # def get_feature_names(self) -> List[str]:
    #     """Get list of feature names in order"""
    #     return [
    #         'pT',
    #         'eta',
    #         'phi',
    #         'd0',
    #         'z0',
    #         'chi2_per_ndof'
    #     ]
    
    # def get_feature_info(self) -> Dict:
    #     """Get information about features"""
    #     return {
    #         'pT': {'units': 'GeV', 'range': (0, 5000), 'description': 'Transverse momentum'},
    #         'eta': {'units': None, 'range': (-5, 5), 'description': 'Pseudorapidity'},
    #         'phi': {'units': 'rad', 'range': (-np.pi, np.pi), 'description': 'Azimuthal angle'},
    #         'd0': {'units': 'mm', 'range': (-10, 10), 'description': 'Transverse impact parameter'},
    #         'z0': {'units': 'mm', 'range': (-200, 200), 'description': 'Longitudinal impact parameter'},
    #         'chi2_per_ndof': {'units': None, 'range': (0, 100), 'description': 'Track fit quality'}
    #     }

In [9]:
class BaseModel(ABC):
    """Base class for all models"""
    def __init__(self):
        self.model = None
        
    @abstractmethod
    def build(self, input_shape: tuple) -> None:
        """Build the model architecture"""
        pass
        
    @abstractmethod
    def get_config(self) -> dict:
        """Get model configuration"""
        pass
        
    def save(self, path: str) -> None:
        """Save model weights and config"""
        if self.model is None:
            raise ValueError("Model not built yet")
        self.model.save(path)
        
    def load(self, path: str) -> None:
        """Load model weights"""
        if self.model is None:
            raise ValueError("Model not built yet")
        self.model.load_weights(path)

# class AutoEncoder(BaseModel):
#     """Basic autoencoder with configurable architecture"""
#     def __init__(
#         self,
#         input_dim: int,
#         latent_dim: int,
#         encoder_layers: List[int],
#         decoder_layers: List[int],
#         quant_bits: Optional[int] = None,
#         activation: str = 'relu',
#         name: str = 'autoencoder'
#     ):
#         """
#         Initialize autoencoder model
        
#         Args:
#             input_dim: Dimension of input data
#             latent_dim: Dimension of latent space
#             encoder_layers: List of layer sizes for encoder
#             decoder_layers: List of layer sizes for decoder
#             quant_bits: Number of bits for quantization (optional)
#             activation: Activation function to use
#             name: Model name
#         """
#         super().__init__()
#         self.input_dim = input_dim
#         self.latent_dim = latent_dim
#         self.encoder_layers = encoder_layers
#         self.decoder_layers = decoder_layers
#         self.quant_bits = quant_bits
#         self.activation = activation
#         self.name = name

class AutoEncoder(BaseModel):
    def __init__(
        self,
        input_dim: int,
        latent_dim: int,
        encoder_layers: List[int],
        decoder_layers: List[int],
        quant_bits: Optional[int] = None,
        activation: str = 'relu',
        name: str = 'autoencoder'
    ):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.quant_bits = quant_bits
        self.activation = activation
        self.name = name
        
    def build(self, input_shape: tuple = None) -> None:
        """Build encoder and decoder networks"""
        if input_shape is None:
            input_shape = (self.input_dim,)
        
        # Input layer
        inputs = keras.Input(shape=input_shape, name='input_layer')
        
        # Create encoder layers
        encoder_layers = []
        for i, units in enumerate(self.encoder_layers):
            if self.quant_bits:
                # Dense layer
                dense = QDense(
                    units,
                    kernel_quantizer=quantized_bits(self.quant_bits, 1, alpha=1.0),
                    bias_quantizer=quantized_bits(self.quant_bits, 1, alpha=1.0),
                    name=f'encoder_dense_{i}'
                )
                encoder_layers.append(dense)
                
                # Activation layer
                activation = QActivation(
                    quantized_relu(self.quant_bits),
                    name=f'encoder_activation_{i}'
                )
                encoder_layers.append(activation)
                
                # Batch normalization
                batch_norm = keras.layers.BatchNormalization(
                    name=f'encoder_bn_{i}'
                )
                encoder_layers.append(batch_norm)
            else:
                dense = keras.layers.Dense(
                    units,
                    name=f'encoder_dense_{i}'
                )
                encoder_layers.append(dense)
                
                activation = keras.layers.Activation(
                    self.activation,
                    name=f'encoder_activation_{i}'
                )
                encoder_layers.append(activation)
                
                batch_norm = keras.layers.BatchNormalization(
                    name=f'encoder_bn_{i}'
                )
                encoder_layers.append(batch_norm)
        
        # Create latent layer
        if self.quant_bits:
            latent_layer = QDense(
                self.latent_dim,
                kernel_quantizer=quantized_bits(self.quant_bits, 1, alpha=1.0),
                bias_quantizer=quantized_bits(self.quant_bits, 1, alpha=1.0),
                name='latent_layer'
            )
        else:
            latent_layer = keras.layers.Dense(
                self.latent_dim,
                name='latent_layer'
            )
        
        # Create decoder layers
        decoder_layers = []
        for i, units in enumerate(self.decoder_layers):
            if self.quant_bits:
                # Dense layer
                dense = QDense(
                    units,
                    kernel_quantizer=quantized_bits(self.quant_bits, 1, alpha=1.0),
                    bias_quantizer=quantized_bits(self.quant_bits, 1, alpha=1.0),
                    name=f'decoder_dense_{i}'
                )
                decoder_layers.append(dense)
                
                # Activation layer
                activation = QActivation(
                    quantized_relu(self.quant_bits),
                    name=f'decoder_activation_{i}'
                )
                decoder_layers.append(activation)
                
                # Batch normalization
                batch_norm = keras.layers.BatchNormalization(
                    name=f'decoder_bn_{i}'
                )
                decoder_layers.append(batch_norm)
            else:
                dense = keras.layers.Dense(
                    units,
                    name=f'decoder_dense_{i}'
                )
                decoder_layers.append(dense)
                
                activation = keras.layers.Activation(
                    self.activation,
                    name=f'decoder_activation_{i}'
                )
                decoder_layers.append(activation)
                
                batch_norm = keras.layers.BatchNormalization(
                    name=f'decoder_bn_{i}'
                )
                decoder_layers.append(batch_norm)
        
        # Create output layer
        if self.quant_bits:
            output_layer = QDense(
                self.input_dim,
                kernel_quantizer=quantized_bits(self.quant_bits, 1, alpha=1.0),
                bias_quantizer=quantized_bits(self.quant_bits, 1, alpha=1.0),
                name='output_layer'
            )
        else:
            output_layer = keras.layers.Dense(
                self.input_dim,
                name='output_layer'
            )
        
        # Build the model by applying layers sequentially
        # Encoder
        x = inputs
        for layer in encoder_layers:
            x = layer(x)
        
        # Latent space
        latent = latent_layer(x)
        
        # Decoder
        x = latent
        for layer in decoder_layers:
            x = layer(x)
        
        # Output
        outputs = output_layer(x)
        
        # Create model
        self.model = keras.Model(inputs=inputs, outputs=outputs, name=self.name)
        
        print("\nModel layer structure:")
        for layer in self.model.layers:
            print(f"Layer: {layer.name}, Type: {type(layer)}")
        
    def get_config(self) -> dict:
        return {
            "model_type": "autoencoder",
            "input_dim": self.input_dim,
            "latent_dim": self.latent_dim,
            "encoder_layers": self.encoder_layers,
            "decoder_layers": self.decoder_layers,
            "quant_bits": self.quant_bits,
            "activation": self.activation,
            "name": self.name
        }

class ModelFactory:
    """Factory for creating different model types"""
    @staticmethod
    def create_model(model_type: str, config: dict) -> BaseModel:
        if model_type == "autoencoder":
            config_copy = config.copy()
            config_copy.pop('model_type', None)  # Remove model_type key
            return AutoEncoder(
                input_dim=config['input_dim'],
                latent_dim=config['latent_dim'],
                encoder_layers=config['encoder_layers'],
                decoder_layers=config['decoder_layers'],
                quant_bits=config.get('quant_bits', None),
                activation=config.get('activation', 'relu'),
                name=config.get('name', 'autoencoder')
            )
        else:
            raise ValueError(f"Unknown model type: {model_type}")
            
    @staticmethod
    def from_config(config: dict) -> BaseModel:
        """Create model from config dictionary"""
        model_type = config.pop("model_type")
        return ModelFactory.create_model(model_type, config)


In [10]:
class ModelTrainer:
    """Handles model training and evaluation"""
    def __init__(
        self,
        model: BaseModel,
        training_config: dict,
        optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
        loss: Optional[tf.keras.losses.Loss] = None
    ):
        self.model = model
        self.config = training_config
        
        # Set up training parameters
        self.batch_size = training_config.get("batch_size", 32)
        self.epochs = training_config.get("epochs", 10)
        self.validation_split = training_config.get("validation_split", 0.2)
        
        # Set up optimizer and loss
        self.optimizer = optimizer or tf.keras.optimizers.Adam(
            learning_rate=training_config.get("learning_rate", 0.001)
        )
        self.loss = loss or tf.keras.losses.MeanSquaredError()
        
        # Training history
        self.history = None
        
    def compile_model(self):
        """Compile the model with optimizer and loss"""
        if self.model.model is None:
            raise ValueError("Model not built yet")
            
        self.model.model.compile(
            optimizer=self.optimizer,
            loss=self.loss,
            metrics=['mse'],
            run_eagerly=True 
        )
        
    def train(
        self,
        dataset: tf.data.Dataset,
        validation_data: Optional[tf.data.Dataset] = None,
        callbacks: List[tf.keras.callbacks.Callback] = None
    ) -> Dict[str, Any]:
        """Train the model"""
        if self.model.model is None:
            raise ValueError("Model not built yet")
        
        print("\nChecking datasets before training:")
        # Check if datasets have any data
        try:
            print("Checking training dataset...")
            for i, batch in enumerate(dataset):
                print(f"Training batch {i} shape: {batch.shape}")
                if i == 0:  # Just check first batch
                    break
        except Exception as e:
            print(f"Error checking training dataset: {e}")
        
        if validation_data is not None:
            try:
                print("\nChecking validation dataset...")
                for i, batch in enumerate(validation_data):
                    print(f"Validation batch {i} shape: {batch.shape}")
                    if i == 0:  # Just check first batch
                        break
            except Exception as e:
                print(f"Error checking validation dataset: {e}")
        
        # Compile model
        self.compile_model()
        
        # Setup callbacks
        if callbacks is None:
            callbacks = []
        
        # Train the model
        print("\nStarting model.fit...")
        self.history = self.model.model.fit(
            dataset,
            epochs=self.epochs,
            validation_data=validation_data,
            callbacks=callbacks,
            shuffle=True
        )
        
        return self.get_training_summary()
        
    def evaluate(
        self,
        dataset: tf.data.Dataset
    ) -> Dict[str, float]:
        """
        Evaluate the model
        
        Returns:
            Dictionary of evaluation metrics
        """
        if self.model.model is None:
            raise ValueError("Model not built yet")
            
        results = self.model.model.evaluate(dataset, return_dict=True)
        return results
        
    def get_training_summary(self) -> Dict[str, Any]:
        """Get summary of training results"""
        if self.history is None:
            raise ValueError("Model not trained yet")
            
        return {
            "training_config": self.config,
            "final_loss": float(self.history.history['loss'][-1]),
            "final_val_loss": float(self.history.history['val_loss'][-1]),
            "history": {
                metric: [float(val) for val in values]
                for metric, values in self.history.history.items()
            }
        }

In [11]:
def test_model_pipeline():
    """Test the complete model pipeline including factory, trainer, and registry"""

    print("\n" + "="*50)
    print("Starting Model Pipeline Test")
    print("="*50)
    print(f"TensorFlow: {tf.__version__} (Eager: {tf.executing_eagerly()})")
    
    try:
        # Helper function for JSON serialization
        def ensure_serializable(obj):
            """Recursively convert numpy types to Python native types"""
            if isinstance(obj, dict):
                return {key: ensure_serializable(value) for key, value in obj.items()}
            elif isinstance(obj, list):
                return [ensure_serializable(item) for item in obj]
            elif isinstance(obj, np.integer):
                return int(obj)
            elif isinstance(obj, np.floating):
                return float(obj)
            elif isinstance(obj, np.ndarray):
                return obj.tolist()
            return obj

        # 1. Initialize Registry
        print("Initializing model registry...")
        registry = ModelRegistry("experiments")
        
        # 2. Setup Data Pipeline
        print("Setting up data pipeline...")
        data_manager = IntegratedDatasetManager(
            run_numbers=["00296939", "00296942"],
            catalog_limit=1,
            batch_size=1000,
            cache_size=10000,
            shuffle_buffer=10000
        )

        # Define selections
        track_selections = {
            'eta': (-2.5, 2.5),
            'chi2_per_ndof': (0.0, 3.0),
            # 'pt': (1.0, None)  # Minimum pT of 1 GeV
        }

        event_selections = {
            # 'mean_pt': (2.0, None)  # Mean pT > 2 GeV
        }

        # Create datasets with new selection system
        train_dataset, val_dataset, test_dataset = data_manager.create_datasets(
            max_tracks_per_event=30,
            min_tracks_per_event=3,
            track_selections=track_selections,
            event_selections=event_selections,
            validation_fraction=0.15,
            test_fraction=0.15
        )

        # Update dataset config for registry
        dataset_config = ensure_serializable({
            "run_numbers": data_manager.run_numbers,
            "track_selections": track_selections,
            "event_selections": event_selections,
            "max_tracks_per_event": 30,
            "min_tracks_per_event": 3,
            "normalization_params": data_manager.normalization_params,
            "train_fraction": 0.7,
            "validation_fraction": 0.15,
            "test_fraction": 0.15,
            "batch_size": data_manager.batch_size,
            "shuffle_buffer": data_manager.shuffle_buffer,
            # "data_quality_metrics": data_manager.validate_data_quality(train_dataset)
        })

        # The model configuration also needs to be updated to reflect the fixed input size
        model_config_flat = {
            "model_type": "autoencoder",
            "input_dim": 30 * 6,
            "latent_dim": 3,
            "encoder_layers": [64, 32],
            "decoder_layers": [32, 64],
            "quant_bits": 8,
            "activation": "relu",
            "name": "track_autoencoder"
        }
    
        
        # Model config - nested version for registry
        model_config_registry = {
            "model_type": "autoencoder",
            "architecture": {
                "input_dim": model_config_flat["input_dim"],
                "latent_dim": model_config_flat["latent_dim"],
                "encoder_layers": model_config_flat["encoder_layers"],
                "decoder_layers": model_config_flat["decoder_layers"]
            },
            "hyperparameters": {
                "activation": model_config_flat["activation"],
                "quant_bits": model_config_flat["quant_bits"]
            }
        }
        
        # Training config
        training_config = {
            "batch_size": 1000,
            "epochs": 50,
            "learning_rate": 0.001,
            "early_stopping": {
                "patience": 3,
                "min_delta": 1e-4
            }
        }
        
        # 4. Register Experiment
        print("Registering experiment...")
        experiment_id = registry.register_experiment(
            name="autoencoder_test",
            dataset_config=dataset_config,
            model_config=model_config_registry,
            training_config=training_config,
            description="Testing autoencoder on track data with enhanced monitoring"
        )
        print(f"Created experiment: {experiment_id}")
        
        # 5. Create and Build Model
        print("Creating model...")
        try:
            model = ModelFactory.create_model(
                model_type="autoencoder",
                config=model_config_flat
            )
            model.build(input_shape=(6,))
        except Exception as e:
            print(f"Model creation failed: {str(e)}")
            print(f"Model config used: {json.dumps(model_config_flat, indent=2)}")
            raise
        
        # 6. Setup Training
        print("Setting up training...")
        trainer = ModelTrainer(
            model=model,
            training_config=training_config
        )
        
        # Setup callbacks
        class RegistryCallback(tf.keras.callbacks.Callback):
            def on_epoch_end(self, epoch, logs=None):
                logs = ensure_serializable(logs or {})
                registry.update_training_progress(
                    experiment_id=experiment_id,
                    epoch=epoch,
                    metrics=logs,
                    # hardware_metrics=get_hardware_metrics()
                )
                
        callbacks = [
            tf.keras.callbacks.EarlyStopping(
                patience=training_config["early_stopping"]["patience"],
                min_delta=training_config["early_stopping"]["min_delta"],
                restore_best_weights=True
            ),
            RegistryCallback()
        ]
        
        # 7. Train Model
        print("Starting training...")
        training_start_time = datetime.now()  # Record start time
        training_results = trainer.train(
            dataset=train_dataset,
            validation_data=val_dataset,
            callbacks=callbacks
        )
        training_end_time = datetime.now()  # Record end time
        training_duration = (training_end_time - training_start_time).total_seconds()
        
        # 8. Evaluate Model
        print("Evaluating model...")
        test_results = trainer.evaluate(test_dataset)
        
        # Record final results
        registry.complete_training(
            experiment_id=experiment_id,
            final_metrics=ensure_serializable({
                **training_results,
                **test_results,
                'test_loss': test_results['loss'],
                'training_duration': training_duration
            })
        )
        
        # 9. Save Model
        print("Saving model checkpoint...")
        registry.save_checkpoint(
            experiment_id=experiment_id,
            models={"autoencoder": model.model},
            checkpoint_name="final",
            metadata=ensure_serializable({
                "test_loss": test_results['loss'],
                "final_train_loss": training_results['final_loss']
            })
        )
        
        # 10. Display Results
        print("\n" + "="*50)
        print("Experiment Results")
        print("="*50)

        details = registry.get_experiment_details(experiment_id)
        performance = registry.get_performance_summary(experiment_id)

        print(f"\nExperiment ID: {experiment_id}")
        print(f"Status: {details['experiment_info']['status']}")

        # Handle potential None values for duration
        duration = performance.get('training_duration')
        if duration is not None:
            print(f"Training Duration: {duration:.2f}s")
        else:
            print("Training Duration: Not available")

        print(f"Epochs Completed: {performance['epochs_completed']}")

        print("\nMetrics:")
        def print_metrics(metrics, indent=2):
            """Helper function to print metrics with proper formatting"""
            for key, value in metrics.items():
                indent_str = " " * indent
                if isinstance(value, dict):
                    print(f"{indent_str}{key}:")
                    print_metrics(value, indent + 2)
                elif isinstance(value, (float, int)):
                    print(f"{indent_str}{key}: {value:.6f}")
                else:
                    print(f"{indent_str}{key}: {value}")

        # Print metrics using the helper function
        print_metrics(performance['final_metrics'])
        
        # 11. Visualize Results
        if True:  # Change to control visualization
            plt.figure(figsize=(12, 6))
            history = performance['metric_progression']
            plt.plot(history['loss'], label='Training Loss')
            plt.plot(history['val_loss'], label='Validation Loss')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Training History')
            plt.legend()
            plt.grid(True)
            plt.show()
        
        print("Pipeline test completed successfully")
        return True
        
    except Exception as e:
        print(f"Pipeline test failed: {type(e).__name__}: {str(e)}")
        print(f"Error context:")
        raise

def get_gpu_memory_usage():
    """
    Get GPU memory usage as a percentage.
    Returns a dictionary with memory usage for each GPU.
    
    Returns:
        Dict: GPU index to memory usage percentage
        or None if no GPU is available
    """
    try:
        # First try using TensorFlow's device API
        gpus = tf.config.list_physical_devices('GPU')
        if not gpus:
            return None
            
        memory_usage = {}
        for gpu_id, gpu in enumerate(gpus):
            # Get memory info for this GPU
            memory_info = tf.config.experimental.get_memory_info(f'GPU:{gpu_id}')
            
            # Calculate percentage (note: this might not be available on all systems)
            if hasattr(memory_info, 'peak') and hasattr(memory_info, 'total'):
                memory_usage[f'gpu_{gpu_id}'] = (memory_info.peak / memory_info.total) * 100
                
        return memory_usage if memory_usage else None
        
    except:
        try:
            # Fallback to nvidia-smi command
            import subprocess
            import re
            
            output = subprocess.check_output(
                ['nvidia-smi', '--query-gpu=index,memory.used,memory.total', '--format=csv,nounits,noheader'],
                encoding='utf-8'
            )
            
            memory_usage = {}
            for line in output.strip().split('\n'):
                gpu_id, memory_used, memory_total = map(int, line.split(','))
                memory_percentage = (memory_used / memory_total) * 100
                memory_usage[f'gpu_{gpu_id}'] = memory_percentage
                
            return memory_usage
            
        except:
            # If both methods fail, return None
            return None


In [12]:
# Modify code to specify which catalogs to run for each run
# Put each catalog code in a try statement and if problem continue to next catalog
# Note: processing 65 total catalogs over two runs took 106 minutes, about 2 minutes per catalog


try:
    success = test_model_pipeline()
    if success:
        print("\nAll tests passed successfully!")
except Exception as e:
    print("\nTest failed with error:")
    print(str(e))


Starting Model Pipeline Test
TensorFlow: 2.13.1 (Eager: True)
Initializing model registry...
Setting up data pipeline...

Initial system state:

Pre-processing: System Usage:
Memory: 90.5GB / 376.2GB (26.7%)
Available Memory: 275.8GB
CPU Usage: 23.4%
Creating base dataset...

Creating streaming dataset
Normalizing datatset

Computing normalization parameters from sample...

Processing run 00296939
Processing catalog 0
Not cleaning up catalog files for faster testing
Processing catalog 1
Not cleaning up catalog files for faster testing

Processing run 00296942
Processing catalog 0
Not cleaning up catalog files for faster testing
Processing catalog 1
Not cleaning up catalog files for faster testing

System usage after dataset creation:

Post-processing: System Usage:
Memory: 90.6GB / 376.2GB (26.7%)
Available Memory: 275.8GB
CPU Usage: 23.8%
Checking base dataset...

Processing run 00296939
Processing catalog 0
Not cleaning up catalog files for faster testing
Processing catalog 1
Not cl