In [None]:
# 🔧 Environment Setup v1
import os
import sys
from pathlib import Path

# Detect the runtime: importing google.colab only succeeds inside Colab.
try:
    import google.colab
    IN_COLAB = True
    print("🔍 Running in Google Colab")

    # Colab: attach Google Drive so project files can be read from it.
    from google.colab import drive
    drive.mount('/content/drive')

    # Where the project is expected locally and on Drive.
    LOCAL_PROJECT_PATH = "/content/abr_project"
    DRIVE_PROJECT_PATH = "/content/drive/MyDrive/abr_project"

except ImportError:
    # Anywhere else: treat the current working directory as the project root.
    IN_COLAB = False
    print("🔍 Running locally")
    DRIVE_PROJECT_PATH = None
    LOCAL_PROJECT_PATH = os.getcwd()

print("✅ Environment detected: " + ("Colab" if IN_COLAB else "Local"))


In [None]:
# 📁 Project Setup - Find project in Colab environment
# Locates the project root (the directory containing train.py), makes it the
# working directory in Colab, and puts it on sys.path so project modules import.
if IN_COLAB:
    # Directories directly under /content that contain train.py take priority.
    # Sorted so the pick is deterministic: os.listdir order is arbitrary.
    content_dirs = sorted(d for d in os.listdir('/content') if os.path.isdir(f'/content/{d}'))
    detected_paths = [
        f'/content/{dir_name}' for dir_name in content_dirs
        if os.path.exists(f'/content/{dir_name}/train.py')
    ]
    known_paths = [
        "/content/abr-project-v1",   # Common clone name
        "/content/abr_project",      # Alternative name
        "/content/abr-cvae-project", # Another alternative
        DRIVE_PROJECT_PATH,          # Google Drive location
    ]
    # dict.fromkeys removes duplicates while keeping priority order; drop None.
    possible_paths = [p for p in dict.fromkeys(detected_paths + known_paths) if p]

    project_found = False
    for path in possible_paths:
        if os.path.exists(f"{path}/train.py"):
            LOCAL_PROJECT_PATH = path
            project_found = True
            print(f"✅ Project found at: {LOCAL_PROJECT_PATH}")
            break

    if not project_found:
        print("❌ Project not found in any expected location!")
        print("🔍 Searched locations:")
        for path in possible_paths:
            print(f"  - {path}")
        print("\n💡 Solutions:")
        print("1. Clone your repository: !git clone <your-repo-url>")
        print("2. Upload to Google Drive at: /MyDrive/abr_project")
        print("📋 Required files: train.py, evaluate.py, src/, configs/, data/")

        # List current content directory to help debug
        print("\n📁 Current /content directory contents:")
        for item in sorted(os.listdir('/content')):
            item_path = f'/content/{item}'
            if os.path.isdir(item_path):
                print(f"  📂 {item}/")
                # Check if it might be the project
                if os.path.exists(f'{item_path}/train.py'):
                    print("    ✅ Contains train.py - This might be your project!")
            else:
                print(f"  📄 {item}")

        raise FileNotFoundError("Project files not found")

    # Set working directory
    os.chdir(LOCAL_PROJECT_PATH)

# Prepend the project root only once: re-running this cell previously stacked
# duplicate entries onto sys.path.
if LOCAL_PROJECT_PATH not in sys.path:
    sys.path.insert(0, LOCAL_PROJECT_PATH)

print(f"📂 Working directory: {os.getcwd()}")
print("✅ Project setup complete!")


In [None]:
# 📦 Install Dependencies
if IN_COLAB:
    print("📦 Installing dependencies for Colab...")
    !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    !pip install PyYAML scipy scikit-learn matplotlib seaborn tqdm tensorboard openpyxl
else:
    print("📦 Installing from requirements.txt...")
    if os.path.exists("requirements.txt"):
        !pip install -r requirements.txt

print("✅ Dependencies installed!")


In [None]:
# 🚀 Import Libraries and Setup
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import yaml
import json
import subprocess
import time
import re
from datetime import datetime
from IPython.display import clear_output
import warnings
warnings.filterwarnings('ignore')  # NOTE(review): silences *all* warnings, including deprecations

# Pick the compute device once; later cells reference the global `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Using device: {device}")

if not torch.cuda.is_available():
    print("⚠️ CUDA not available! Training will be very slow on CPU.")
    print("💡 Make sure to enable GPU in Colab: Runtime > Change runtime type > GPU")
else:
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
    print(f"📊 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # cuDNN autotuning picks the fastest kernels for recurring input shapes,
    # trading away bitwise reproducibility.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    # Release cached-but-unused allocator blocks before reporting usage.
    torch.cuda.empty_cache()

    allocated = torch.cuda.memory_allocated(0) / 1e9
    cached = torch.cuda.memory_reserved(0) / 1e9  # memory reserved by the caching allocator
    print(f"📈 GPU Memory - Allocated: {allocated:.2f} GB, Cached: {cached:.2f} GB")

    # Cap this process's caching allocator at 90% of device memory; allocations
    # beyond the cap raise an out-of-memory error instead of using the remainder.
    torch.cuda.set_per_process_memory_fraction(0.9)

    print("✅ GPU optimizations applied!")

print("✅ Libraries imported successfully!")


In [None]:
# 📊 Enhanced Training Monitor with Real-time Epoch Progress
from tqdm.notebook import tqdm

class TrainingMonitor:
    """Track and display CVAE training progress parsed from log lines.

    Feed each line of the training process's output to ``parse_log_line``.
    The monitor extracts the epoch number, train/val/KL/reconstruction loss
    and beta. It drives two tqdm bars: overall progress and the current epoch.
    ``plot_progress`` draws the collected history.

    Args:
        total_epochs: Expected number of epochs, used as the overall bar's
            total. When falsy, the overall bar just counts epochs.

    Note:
        Each metric list grows once per matching log line, not strictly once
        per epoch. If the training script logs a metric more often, the plots'
        "Epoch" axis is really "log occurrence".
    """

    # Numeric capture group: decimals, signed values and scientific notation
    # ("0.25", "-1.3", "1.5e-03"). The former `[\d.]+` read "1e-4" as 1.0.
    _NUMBER = r'([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'

    # Log labels recognised for each metric, tried in order.
    _TRAIN_LABELS = ('Train Loss: ', 'Training Loss: ', 'train_loss: ')
    _VAL_LABELS = ('Val Loss: ', 'Validation Loss: ', 'val_loss: ')
    _KL_LABELS = ('KL Loss: ', 'KL: ', 'kl_loss: ')
    _RECON_LABELS = ('Recon Loss: ', 'Reconstruction Loss: ', 'recon_loss: ', 'Recon: ')
    _BETA_LABELS = ('Beta: ', 'beta: ')

    def __init__(self, total_epochs=None):
        self.metrics = {
            'train_loss': [], 'val_loss': [], 'kl_loss': [], 'recon_loss': [],
            'beta': [], 'epochs': []
        }
        self.current_epoch = 0
        self.total_epochs = total_epochs
        self.best_val_loss = float('inf')
        self.start_time = time.time()

        # Latest value of each metric in the current epoch (None = not seen yet)
        self.current_train_loss = None
        self.current_val_loss = None
        self.current_kl_loss = None
        self.current_recon_loss = None
        self.current_beta = None

        # Track what we've already printed to avoid duplicate "NEW BEST" lines
        self.last_printed_epoch = 0
        self.epoch_results_printed = False

        # Most recent exception raised while parsing a line (parsing is best-effort)
        self.last_parse_error = None

        # Progress bars
        self.main_pbar = None  # Overall training progress
        self.epoch_pbar = None  # Current epoch progress

        if total_epochs:
            self.main_pbar = tqdm(total=total_epochs, desc="🚀 Training",
                                  position=0, leave=True,
                                  bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}] Best: {postfix}')
        else:
            self.main_pbar = tqdm(desc="🚀 Training", position=0, leave=True,
                                  bar_format='{desc}: {n} epochs [{elapsed}] Best: {postfix}')

    @classmethod
    def _search_metric(cls, labels, line):
        """Return the number following the first label found in ``line``, else None."""
        for label in labels:
            match = re.search(re.escape(label) + cls._NUMBER, line)
            if match:
                return float(match.group(1))
        return None

    @staticmethod
    def _format_metric(value):
        """Format a metric to 4 decimals, or 'N/A' when missing (None or inf).

        Uses an explicit None check so a genuine 0.0 prints as 0.0000.
        """
        if value is None or value == float('inf'):
            return 'N/A'
        return f"{value:.4f}"

    def parse_log_line(self, line):
        """Update monitor state from one line of training output.

        Parsing is best-effort: an unexpected error is stored in
        ``self.last_parse_error`` instead of aborting the caller's log loop.
        Only ``Exception`` is caught, so KeyboardInterrupt still propagates
        (the previous bare ``except:`` swallowed it).
        """
        try:
            self._handle_epoch(line)

            train_loss = self._search_metric(self._TRAIN_LABELS, line)
            if train_loss is not None:
                self.current_train_loss = train_loss
                self.metrics['train_loss'].append(train_loss)
                # Training loss available: treat the epoch as ~50% done
                if self.epoch_pbar:
                    self.epoch_pbar.n = 50
                    self.update_epoch_progress()

            val_loss = self._search_metric(self._VAL_LABELS, line)
            if val_loss is not None:
                self._handle_val_loss(val_loss)

            # Auxiliary metrics only refresh the epoch bar's postfix.
            for metric_key, labels, attr in (
                ('kl_loss', self._KL_LABELS, 'current_kl_loss'),
                ('recon_loss', self._RECON_LABELS, 'current_recon_loss'),
                ('beta', self._BETA_LABELS, 'current_beta'),
            ):
                value = self._search_metric(labels, line)
                if value is not None:
                    setattr(self, attr, value)
                    self.metrics[metric_key].append(value)
                    if self.epoch_pbar:
                        self.update_epoch_progress()
        except Exception as exc:
            self.last_parse_error = exc

    def _handle_epoch(self, line):
        """Start a fresh per-epoch bar when ``line`` announces a later epoch."""
        if 'Epoch' not in line or ('/' not in line and ':' not in line):
            return
        epoch_match = re.search(r'Epoch (\d+)', line)
        if not epoch_match:
            return
        new_epoch = int(epoch_match.group(1))
        if new_epoch <= self.current_epoch:
            return

        # Close previous epoch progress bar
        if self.epoch_pbar:
            self.epoch_pbar.close()

        self.current_epoch = new_epoch
        if self.main_pbar:  # may already be closed via close()
            self.main_pbar.n = self.current_epoch
        self.epoch_results_printed = False  # Reset for new epoch

        self.epoch_pbar = tqdm(total=100, desc=f"📊 Epoch {self.current_epoch}",
                               position=1, leave=False,
                               bar_format='{desc}: {percentage:3.0f}%|{bar}| {postfix}')

        # Reset current metrics for new epoch
        self.current_train_loss = None
        self.current_val_loss = None
        self.current_kl_loss = None
        self.current_recon_loss = None
        self.current_beta = None

    def _handle_val_loss(self, val_loss):
        """Record a validation loss, which marks the end of the current epoch."""
        self.current_val_loss = val_loss
        self.metrics['val_loss'].append(val_loss)
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss

        # Complete and close the epoch bar
        if self.epoch_pbar:
            self.epoch_pbar.n = 100
            self.update_epoch_progress()
            self.epoch_pbar.close()
            self.epoch_pbar = None

        self.update_main_progress_bar()

        # Announce a new best at most once per epoch
        if not self.epoch_results_printed and self.current_epoch > self.last_printed_epoch:
            if val_loss == self.best_val_loss:
                tqdm.write(f"   ⭐ NEW BEST! Val Loss: {self.best_val_loss:.4f}")
            self.epoch_results_printed = True
            self.last_printed_epoch = self.current_epoch

    def update_epoch_progress(self):
        """Refresh the current epoch bar's postfix with the metrics seen so far."""
        if not self.epoch_pbar:
            return

        postfix_parts = []
        if self.current_train_loss is not None:
            postfix_parts.append(f"Train: {self.current_train_loss:.4f}")
        if self.current_val_loss is not None:
            postfix_parts.append(f"Val: {self.current_val_loss:.4f}")
        if self.current_kl_loss is not None:
            postfix_parts.append(f"KL: {self.current_kl_loss:.4f}")
        if self.current_recon_loss is not None:
            postfix_parts.append(f"Recon: {self.current_recon_loss:.4f}")
        if self.current_beta is not None:
            postfix_parts.append(f"Beta: {self.current_beta:.4f}")

        postfix_str = " | ".join(postfix_parts) if postfix_parts else "Processing..."
        self.epoch_pbar.set_postfix_str(postfix_str)
        self.epoch_pbar.refresh()

    def update_main_progress_bar(self):
        """Show the best validation loss so far on the overall bar."""
        if not self.main_pbar:
            return

        if self.best_val_loss != float('inf'):
            self.main_pbar.set_postfix_str(f"{self.best_val_loss:.4f}")
        self.main_pbar.refresh()

    def close(self):
        """Close all progress bars."""
        if self.epoch_pbar:
            self.epoch_pbar.close()
            self.epoch_pbar = None
        if self.main_pbar:
            self.main_pbar.close()
            self.main_pbar = None

    @staticmethod
    def _plot_series(ax, series, title, ylabel):
        """Plot each non-empty (values, fmt, label) series against its own 1-based index."""
        drawn = False
        for values, fmt, label in series:
            if values:
                ax.plot(range(1, len(values) + 1), values, fmt, label=label, linewidth=2)
                drawn = True
        if drawn:
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True, alpha=0.3)

    def plot_progress(self):
        """Plot loss curves, beta schedule and a statistics panel."""
        if not self.metrics['train_loss']:
            return

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle(f'Training Progress - Epoch {self.current_epoch}', fontsize=16)

        # Each series gets its own x-range. Train/val (and KL/recon) can be
        # logged a different number of times; a shared x raised ValueError.
        self._plot_series(axes[0, 0], [
            (self.metrics['train_loss'], 'b-', 'Train Loss'),
            (self.metrics['val_loss'], 'r-', 'Val Loss'),
        ], 'Training & Validation Loss', 'Loss')
        self._plot_series(axes[0, 1], [
            (self.metrics['kl_loss'], 'g-', 'KL Loss'),
            (self.metrics['recon_loss'], 'orange', 'Recon Loss'),
        ], 'KL & Reconstruction Loss', 'Loss')
        self._plot_series(axes[1, 0], [
            (self.metrics['beta'], 'purple', 'Beta'),
        ], 'Beta Annealing', 'Beta Value')

        # Summary stats
        axes[1, 1].axis('off')
        elapsed = time.time() - self.start_time
        hours, remainder = divmod(elapsed, 3600)
        minutes, seconds = divmod(remainder, 60)

        # The former `{x:.4f if x else 'N/A'}` was an invalid format spec and
        # crashed here; values are now formatted via _format_metric.
        stats_text = "\n".join([
            "Training Statistics:",
            "",
            f"Current Epoch: {self.current_epoch}",
            f"Total Epochs: {self.total_epochs or 'Unknown'}",
            f"Best Val Loss: {self._format_metric(self.best_val_loss)}",
            f"Training Time: {int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}",
            "",
            "Latest Metrics:",
            f"Train Loss: {self._format_metric(self.current_train_loss)}",
            f"Val Loss: {self._format_metric(self.current_val_loss)}",
            f"KL Loss: {self._format_metric(self.current_kl_loss)}",
            f"Recon Loss: {self._format_metric(self.current_recon_loss)}",
            f"Beta: {self._format_metric(self.current_beta)}",
        ])
        axes[1, 1].text(0.1, 0.5, stats_text, fontsize=11, verticalalignment='center',
                        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7))

        plt.tight_layout()
        plt.show()
        plt.close(fig)  # avoid accumulating open figures when called repeatedly

print("✅ Enhanced training monitor with real-time updates ready!")


In [None]:
# 🎮 GPU Monitoring and Optimization Functions
def check_gpu_status():
    """Print GPU identity, memory usage and a quick matmul timing.

    Returns:
        bool: False when CUDA is unavailable, True otherwise.
    """
    if not torch.cuda.is_available():
        print("❌ CUDA not available!")
        print("💡 Enable GPU: Runtime > Change runtime type > Hardware accelerator > GPU")
        return False

    print("🎮 GPU Status:")
    print(f"  Device: {torch.cuda.get_device_name(0)}")
    print(f"  CUDA Version: {torch.version.cuda}")
    # cudnn.version() is the cuDNN library version (was mislabelled "PyTorch CUDA").
    print(f"  cuDNN Version: {torch.backends.cudnn.version()}")

    # Memory info
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    allocated = torch.cuda.memory_allocated(0) / 1e9
    cached = torch.cuda.memory_reserved(0) / 1e9
    # Device-wide free memory as reported by the driver. "total - allocated"
    # overstated it by ignoring the allocator's reserved blocks and other processes.
    free = torch.cuda.mem_get_info(0)[0] / 1e9

    print(f"  Total Memory: {total_memory:.1f} GB")
    print(f"  Allocated: {allocated:.2f} GB ({allocated/total_memory*100:.1f}%)")
    print(f"  Cached: {cached:.2f} GB")
    print(f"  Free: {free:.2f} GB")

    # Performance check
    print("\n🔥 Performance Test:")
    x = torch.randn(1000, 1000, device='cuda')
    # Warm-up: the first matmul pays one-off kernel/cuBLAS initialisation,
    # which made healthy GPUs trip the "seems slow" warning.
    torch.mm(x, x)
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    torch.mm(x, x)
    torch.cuda.synchronize()  # wait for the asynchronous kernel before stopping the clock
    gpu_time = time.perf_counter() - start_time
    print(f"  GPU Matrix Multiply (1000x1000): {gpu_time:.3f}s")

    if gpu_time > 0.1:
        print("⚠️ GPU seems slow. Check if GPU is properly enabled.")
    else:
        print("✅ GPU performance looks good!")

    return True

def optimize_for_colab():
    """Apply GPU-friendly PyTorch settings and suggest a batch size.

    Clears the CUDA cache, enables cuDNN autotuning (non-deterministic) and
    caps this process's caching allocator at 90% of device memory.

    Returns:
        int | None: Suggested batch size (32 for >= 15 GB, 16 for >= 8 GB,
        otherwise 8), or None when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        print("🚀 Applying Colab GPU optimizations...")

        torch.cuda.empty_cache()

        # Autotuned cuDNN kernels; reproducibility is not enforced.
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

        torch.cuda.set_per_process_memory_fraction(0.9)

        memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        # (minimum GB, batch size) tiers, largest first; 8 is the fallback.
        recommended_batch = next(
            (batch for min_gb, batch in ((15, 32), (8, 16)) if memory_gb >= min_gb),
            8,
        )

        print(f"✅ Optimizations applied!")
        print(f"💡 Recommended batch size: {recommended_batch}")
        return recommended_batch
    return None

def monitor_gpu_during_training():
    """Snapshot this process's CUDA memory usage on device 0.

    Returns:
        dict | None: None when CUDA is unavailable. Otherwise a dict with
        'allocated_gb', 'cached_gb' (reserved by the caching allocator),
        'total_gb' and 'utilization_pct'. 'utilization_pct' is allocated
        memory as a percentage of total, not compute utilisation.
    """
    if torch.cuda.is_available():
        allocated_gb = torch.cuda.memory_allocated(0) / 1e9
        reserved_gb = torch.cuda.memory_reserved(0) / 1e9
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        return {
            'allocated_gb': allocated_gb,
            'cached_gb': reserved_gb,
            'total_gb': total_gb,
            'utilization_pct': (allocated_gb / total_gb) * 100,
        }
    return None

print("\n".join([
    "✅ GPU monitoring functions ready!",
    "📋 Available functions:",
    "  - check_gpu_status(): Check GPU availability and performance",
    "  - optimize_for_colab(): Apply Colab-specific optimizations",
    "  - monitor_gpu_during_training(): Monitor GPU usage",
]))


In [None]:
# 🔄 Complete Data Preprocessing Function
def run_preprocessing(script_timeout=600):
    """Create processed ABR data (a pickle) from the Excel workbook.

    Steps, stopping at the first that yields processed data:
      1. Reuse an existing pickle at one of the known processed-data paths.
      2. Locate the Excel workbook (Drive project folder, Drive root, then
         local paths) and copy it to ``data/abr_data_preprocessed.xlsx``.
      3. Run the first project preprocessing script found
         (``preprocess.py``, ``process_data.py``, ``src/preprocess.py``).
      4. Otherwise (no script, script failed, timed out, or wrote no known
         output), run a built-in pipeline. It pickles the cleaned DataFrame
         with numeric and keyword-matched column lists to
         ``data/processed/abr_processed_data.pkl``.

    Relies on the notebook global ``DRIVE_PROJECT_PATH`` from the environment cell.

    Args:
        script_timeout: Seconds allowed for the external preprocessing script
            before falling back to the built-in pipeline.

    Returns:
        bool: True if processed data exists afterwards, False otherwise.
    """
    import pickle
    import shutil
    import traceback

    print("🔄 Starting complete data preprocessing pipeline...")

    processed_data_paths = [
        "data/processed/abr_processed_data.pkl",
        "data/processed/abr_data_preprocessed.pkl",
        "data/abr_processed_data.pkl"
    ]

    existing = next((p for p in processed_data_paths if os.path.exists(p)), None)
    if existing:
        print(f"✅ Preprocessed data found at: {existing}")
        return True

    print("📊 No preprocessed data found. Starting from Excel file...")

    # Google Drive locations first (only when Drive is mounted), then local ones.
    drive_paths = [
        f"{DRIVE_PROJECT_PATH}/data/abr_data_preprocessed.xlsx",
        f"{DRIVE_PROJECT_PATH}/abr_data_preprocessed.xlsx",
    ] if DRIVE_PROJECT_PATH else []
    excel_paths = drive_paths + [
        "/content/drive/MyDrive/abr_data_preprocessed.xlsx",
        "/content/drive/MyDrive/data/abr_data_preprocessed.xlsx",
        "data/abr_data_preprocessed.xlsx",
        "data/raw/abr_data.xlsx",
        "abr_data_preprocessed.xlsx"
    ]

    excel_file = next((p for p in excel_paths if os.path.exists(p)), None)
    if excel_file is None:
        print("❌ Excel file not found!")
        print("🔍 Searched locations:")
        for path in excel_paths:
            print(f"  - {path}")
        print("\n💡 Please upload your Excel file to one of these locations:")
        print("  - /content/drive/MyDrive/abr_data_preprocessed.xlsx")
        print("  - /content/drive/MyDrive/data/abr_data_preprocessed.xlsx")
        return False
    print(f"📊 Found Excel file at: {excel_file}")

    try:
        os.makedirs("data/processed", exist_ok=True)
        os.makedirs("data/raw", exist_ok=True)

        print("📖 Loading Excel file...")
        df = pd.read_excel(excel_file)
        print(f"✅ Loaded Excel file with shape: {df.shape}")
        print(f"📋 Columns: {list(df.columns)}")

        # Keep a local copy so later runs do not depend on Drive being mounted.
        local_excel_path = "data/abr_data_preprocessed.xlsx"
        if not os.path.exists(local_excel_path):
            shutil.copy2(excel_file, local_excel_path)
            print(f"📁 Copied Excel file to: {local_excel_path}")

        # Prefer the project's own preprocessing script when present.
        preprocess_scripts = ["preprocess.py", "process_data.py", "src/preprocess.py"]
        script_found = next((s for s in preprocess_scripts if os.path.exists(s)), None)

        if script_found:
            print(f"🔄 Using existing preprocessing script: {script_found}")
            try:
                result = subprocess.run([sys.executable, script_found],
                                        capture_output=True, text=True, timeout=script_timeout)
            except subprocess.TimeoutExpired:
                # Previously a timeout escaped to the outer handler and aborted
                # preprocessing instead of falling back.
                result = None
                print(f"⚠️ Preprocessing script timed out after {script_timeout}s. "
                      "Using built-in preprocessing...")

            if result is not None and result.returncode == 0:
                print("✅ Preprocessing script completed successfully!")
                created = next((p for p in processed_data_paths if os.path.exists(p)), None)
                if created:
                    print(f"✅ Preprocessed data created at: {created}")
                    return True
                print("⚠️ Script completed but no output found. Falling back to built-in preprocessing...")
            elif result is not None:
                print("⚠️ Preprocessing script failed. Using built-in preprocessing...")
                if result.stderr:
                    print(f"Script error: {result.stderr}")

        print("🔄 Running built-in preprocessing pipeline...")
        print("🧹 Cleaning data...")

        # Remove any completely empty rows/columns
        df = df.dropna(how='all').dropna(axis=1, how='all')
        print(f"📊 Data shape after cleaning: {df.shape}")

        # Heuristic: numeric columns are treated as potential waveform samples.
        numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
        print(f"📈 Found {len(numeric_cols)} numeric columns (potential waveform data)")

        # Heuristic: columns whose header contains a clinical keyword.
        # str(col): Excel headers may be numbers (e.g. sample indices); calling
        # .lower() on them raised AttributeError and aborted preprocessing.
        clinical_keywords = ['latency', 'amplitude', 'threshold', 'wave', 'peak']
        clinical_cols = [
            col for col in df.columns
            if any(keyword in str(col).lower() for keyword in clinical_keywords)
        ]
        print(f"🏥 Found {len(clinical_cols)} potential clinical columns: {clinical_cols}")

        processed_data = {
            'raw_data': df,
            'waveform_columns': numeric_cols,
            'clinical_columns': clinical_cols,
            'metadata': {
                'original_shape': df.shape,
                'processing_date': datetime.now().isoformat(),
                'source_file': excel_file
            }
        }

        output_path = "data/processed/abr_processed_data.pkl"
        print(f"💾 Saving processed data to: {output_path}")
        with open(output_path, 'wb') as f:
            pickle.dump(processed_data, f)

        print("✅ Built-in preprocessing completed successfully!")
        print(f"📊 Processed data saved with {len(df)} samples")
        print(f"📈 Waveform columns: {len(numeric_cols)}")
        print(f"🏥 Clinical columns: {len(clinical_cols)}")

        return True

    except Exception as e:
        print(f"❌ Preprocessing failed with error: {e}")
        print("🔍 Full error traceback:")
        traceback.print_exc()
        return False

def check_data_status():
    """Print which Excel inputs, processed outputs and preprocessing scripts exist.

    The Excel search list mirrors the one ``run_preprocessing`` uses, so a
    workbook that preprocessing would pick up is never reported as missing.
    Relies on the notebook global ``DRIVE_PROJECT_PATH``.

    Returns:
        tuple[bool, bool]: (any Excel file found, any processed file found).
    """
    print("📊 Data Status Check:")
    print("=" * 50)

    # Same candidates, same order as run_preprocessing (two were missing before).
    excel_paths = [
        f"{DRIVE_PROJECT_PATH}/data/abr_data_preprocessed.xlsx" if DRIVE_PROJECT_PATH else None,
        f"{DRIVE_PROJECT_PATH}/abr_data_preprocessed.xlsx" if DRIVE_PROJECT_PATH else None,
        "/content/drive/MyDrive/abr_data_preprocessed.xlsx",
        "/content/drive/MyDrive/data/abr_data_preprocessed.xlsx",
        "data/abr_data_preprocessed.xlsx",
        "data/raw/abr_data.xlsx",
        "abr_data_preprocessed.xlsx"
    ]

    print("📁 Excel Files:")
    excel_found = False
    for path in filter(None, excel_paths):
        if os.path.exists(path):
            size = os.path.getsize(path) / (1024*1024)  # MB
            print(f"  ✅ {path} ({size:.1f} MB)")
            excel_found = True
        else:
            print(f"  ❌ {path}")

    if not excel_found:
        print("  ⚠️ No Excel files found!")

    processed_paths = [
        "data/processed/abr_processed_data.pkl",
        "data/processed/abr_data_preprocessed.pkl",
        "data/abr_processed_data.pkl"
    ]

    print("\n🔄 Processed Files:")
    processed_found = False
    for path in processed_paths:
        if os.path.exists(path):
            size = os.path.getsize(path) / (1024*1024)  # MB
            print(f"  ✅ {path} ({size:.1f} MB)")
            processed_found = True
        else:
            print(f"  ❌ {path}")

    if not processed_found:
        print("  ⚠️ No processed files found!")

    scripts = ["preprocess.py", "process_data.py", "src/preprocess.py"]
    print("\n🔧 Preprocessing Scripts:")
    for script in scripts:
        marker = "✅" if os.path.exists(script) else "❌"
        print(f"  {marker} {script}")

    return excel_found, processed_found

print("✅ Complete preprocessing pipeline ready!")
print("📋 Available functions:")
print("  - run_preprocessing(): Complete preprocessing from Excel to pickle")
print("  - check_data_status(): Check current data file status")


In [None]:
# 🎯 Training Function using existing train.py
def _training_batch_size(batch_size):
    """Return the batch size run_training actually uses for a requested `batch_size`.

    With CUDA, GPUs with >= 15 GB raise the request to at least 32, GPUs with
    >= 8 GB raise it to at least 16, and smaller GPUs cap it at 8. Without CUDA
    the request is capped at 4.
    """
    if not torch.cuda.is_available():
        return min(batch_size, 4)  # CPU fallback
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if gpu_memory_gb >= 15:  # High-end GPU
        return max(batch_size, 32)
    if gpu_memory_gb >= 8:  # Mid-range GPU
        return max(batch_size, 16)
    return min(batch_size, 8)  # Lower-end GPU


def _build_training_config(model_type, epochs, batch_size, output_dir):
    """Build the nested config dict that run_training writes to config.yaml.

    `model_type == 'advanced'` switches on the larger architecture, the
    hierarchical-latent settings and the extra loss weights. CUDA availability
    toggles worker count, pinned memory, mixed precision and model compilation.
    """
    cuda = torch.cuda.is_available()
    advanced = model_type == 'advanced'
    return {
        'data': {
            'sequence_length': 256,
            'train_split': 0.75,
            'val_split': 0.15,
            'test_split': 0.10,
            'num_workers': 6 if cuda else 2,
            'pin_memory': cuda,
            'persistent_workers': cuda,
            'prefetch_factor': 4,
            'drop_last': True,       # keep every batch the same size
            'augmentation': {
                'noise_std': 0.01,
                'time_shift': 0.05,
                'amplitude_scale': [0.95, 1.05]
            },
            'normalization': 'robust',
            'feature_scaling': 'standard'
        },
        'model': {
            'type': model_type,
            'static_dim': 4,
            'latent_dim': 256 if advanced else 128,
            'hidden_dim': 512 if advanced else 256,
            'num_layers': 4 if advanced else 3,
            'hierarchical_latents': {
                'global_dim': 128,
                'local_dim': 128,
                'num_levels': 3
            } if advanced else None,
            'dropout': 0.15,
            'layer_norm': True,
            'spectral_norm': False,
            'activation': 'gelu',
            'final_activation': 'tanh',
            'beta_schedule': 'cyclical',
            'beta_min': 0.1,
            'beta_max': 4.0,
            'beta_cycles': 4,
            'kl_tolerance': 0.5,
            'mixed_precision': cuda,
            'compile_model': cuda,
            'gradient_checkpointing': True,
        },
        'training': {
            'epochs': epochs,
            'batch_size': batch_size,
            'output_dir': output_dir,
            'device': str(device),
            'gradient_accumulation_steps': 2,
            'gradient_clip_norm': 1.0,
            'mixed_precision': cuda,
            'compile_model': cuda,
            'optimizer': {
                'type': 'adamw',
                'lr': 3e-4,
                'weight_decay': 0.05,
                'eps': 1e-8,
                'betas': [0.9, 0.95],
                'amsgrad': True
            },
            'scheduler': {
                'type': 'cosine_with_restarts',
                # NOTE(review): compares a step count (200) with epochs // 10 —
                # confirm which unit train.py expects for warmup.
                'warmup_steps': max(200, epochs // 10),
                # Cosine-restart schedulers require a positive T_0; epochs // 4
                # alone is 0 for short runs (epochs < 4).
                'T_0': max(1, epochs // 4),
                'T_mult': 2,
                'eta_min': 1e-6,
                'last_epoch': -1
            },
            'early_stopping': {
                'patience': 15,
                'min_delta': 1e-4,
                'monitor': 'val_loss',
                'mode': 'min',
                'restore_best_weights': True
            },
            'checkpointing': {
                'save_best': True,
                'save_last': True,
                'save_every_n_epochs': 10,
                'monitor': 'val_loss'
            },
            'loss_weights': {
                'reconstruction': 1.0,
                'kl_divergence': 1.0,
                'hierarchical_reg': 0.1 if advanced else 0.0,
                'consistency_reg': 0.05 if advanced else 0.0
            },
            'validation_freq': 1,
            'log_freq': 50,
            'plot_freq': 5,
            'label_smoothing': 0.1,
            'ema_decay': 0.999,
            'stochastic_weight_avg': True,
        },
        'evaluation': {
            'metrics': [
                'reconstruction_error', 'kl_divergence', 'elbo',
                'fid_score', 'inception_score', 'lpips',
                'ssim', 'psnr', 'mse', 'mae',
                'latent_traversal', 'interpolation_quality',
                'disentanglement_score', 'mutual_info_gap'
            ],
            'num_samples': 1000,
            'batch_size': 64,
            'save_reconstructions': True,
            'save_generations': True,
            'save_latent_space': True
        }
    }


def _build_train_command(config_path, output_dir, model_type, epochs, batch_size, help_text):
    """Assemble the train.py command line, adding only flags listed in `help_text`.

    `help_text` is the stdout of ``train.py --help``.
    """
    import re
    import sys

    def supports(flag):
        # Whole-token match. A plain substring test would treat "--model-dir" as
        # "--model", and argparse's prefix matching would then accept the
        # value for the wrong option.
        return re.search(rf"(?<![\w-]){re.escape(flag)}(?![\w-])", help_text) is not None

    cmd = [sys.executable, "train.py", "--config", config_path]
    # output_dir and device are also in the config file, so they are only
    # repeated on the CLI when train.py accepts them there.
    if supports("--output-dir"):
        cmd.extend(["--output-dir", output_dir])
    if supports("--device"):
        cmd.extend(["--device", str(device)])
    if supports("--model"):
        cmd.extend(["--model", model_type])
    if supports("--epochs"):
        cmd.extend(["--epochs", str(epochs)])
    if supports("--batch-size"):
        cmd.extend(["--batch-size", str(batch_size)])
    elif supports("--batch_size"):
        cmd.extend(["--batch_size", str(batch_size)])

    if torch.cuda.is_available():
        for flag in ("--mixed-precision", "--compile-model", "--pin-memory"):
            if supports(flag):
                cmd.append(flag)
        if supports("--num-workers"):
            # NOTE(review): the config asks for 6 workers on GPU; this CLI value
            # (4) presumably takes precedence — confirm which one is intended.
            cmd.extend(["--num-workers", "4"])
    return cmd


def _stop_process(process, grace_seconds=5):
    """Terminate `process`, escalating to kill() if it is still alive after `grace_seconds`."""
    import subprocess

    process.terminate()
    try:
        process.wait(timeout=grace_seconds)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()


def _stream_training_output(process, monitor, idle_timeout_seconds):
    """Echo train.py output, feed each line to `monitor`, and return error-looking lines.

    `process.stdout` must carry the merged stdout+stderr of train.py. A daemon
    thread performs the blocking ``readline()``. That keeps this loop free to
    notice when no line has arrived for `idle_timeout_seconds`, in which case a
    still-running process is stopped.
    """
    import queue
    import threading

    # Lines echoed even after the first 100 lines of output.
    highlights = ('Starting', 'Loading', 'Epoch', 'Best model saved', 'Early stopping',
                  'completed', 'ERROR', 'WARNING', 'Traceback', 'Exception')
    # Case-insensitive markers for lines collected into the failure summary.
    error_markers = ('error', 'exception', 'traceback', 'failed')

    pending = queue.Queue()
    eof = object()

    def pump():
        try:
            for raw_line in iter(process.stdout.readline, ''):
                pending.put(raw_line)
        finally:
            pending.put(eof)

    threading.Thread(target=pump, name="train-log-reader", daemon=True).start()

    error_lines = []
    line_count = 0
    last_plot_epoch = 0
    while True:
        try:
            raw_line = pending.get(timeout=idle_timeout_seconds)
        except queue.Empty:
            if process.poll() is None:
                print(f"\n⏰ No output for {idle_timeout_seconds} seconds. Training might be stuck.")
                print("⚠️ Process is still running but not producing output.")
                print("💡 This might indicate:")
                print("   - Very slow data loading")
                print("   - GPU memory issues")
                print("   - Model compilation taking too long")
                print("   - Deadlock in training loop")
                print("🛑 Terminating stuck process...")
                _stop_process(process)
            return error_lines
        if raw_line is eof:
            return error_lines

        line = raw_line.strip()
        if not line:
            continue
        line_count += 1

        monitor.parse_log_line(line)  # updates the progress bar / metric history

        if line_count < 100 or any(keyword in line for keyword in highlights):
            tqdm.write(f"[{line_count:3d}] {line}")
        if any(marker in line.lower() for marker in error_markers):
            error_lines.append(line)

        # Show a progress plot once per 10th epoch.
        epoch = monitor.current_epoch
        if epoch > 0 and epoch != last_plot_epoch and epoch % 10 == 0:
            last_plot_epoch = epoch
            if monitor.metrics['train_loss']:
                print(f"\n📊 Training Progress Update (Epoch {epoch}):")
                monitor.plot_progress()


def _copy_tree_to_drive(src_dir, dest_dir):
    """Copy `src_dir` to `dest_dir`, merging into an existing destination.

    Unlike ``cp -r src dest``, re-running with the same name refreshes `dest_dir`
    instead of nesting a second copy inside it. Returns True on success.
    """
    import os
    import shutil

    try:
        os.makedirs(os.path.dirname(dest_dir), exist_ok=True)
        shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
    except OSError as e:  # shutil.Error is an OSError subclass
        print(f"⚠️ Backup to {dest_dir} failed: {e}")
        return False
    return True


def run_training(model_type='original', epochs=50, batch_size=16, experiment_name=None,
                 idle_timeout_seconds=300):
    """Train a CVAE by running the project's train.py in a subprocess.

    Steps:
      1. Run preprocessing if no processed pickle exists.
      2. Write a YAML config into ``outputs_<experiment_name>/``.
      3. Run ``train.py --help`` to confirm the script loads and to learn which
         CLI flags it accepts.
      4. Stream its output into a TrainingMonitor.

    Args:
        model_type: 'original' or 'advanced'; selects the architecture settings.
        epochs: Number of training epochs.
        batch_size: Requested batch size. It is adjusted to the available device
            memory (see _training_batch_size).
        experiment_name: Output-folder suffix. Defaults to
            '<model_type>_cvae_<YYYYmmdd_HHMMSS>'.
        idle_timeout_seconds: Stop train.py if it prints nothing for this long.

    Returns:
        The output directory on success, otherwise None.
    """
    import os
    import subprocess
    import sys
    import traceback
    from datetime import datetime

    processed_data_paths = [
        "data/processed/abr_processed_data.pkl",
        "data/processed/abr_data_preprocessed.pkl",
        "data/abr_processed_data.pkl"
    ]
    if not any(os.path.exists(path) for path in processed_data_paths):
        print("🔄 Preprocessed data not found. Running preprocessing first...")
        if not run_preprocessing():
            print("❌ Preprocessing failed. Cannot proceed with training.")
            return None
        print("✅ Preprocessing completed. Starting training...")

    if experiment_name is None:
        experiment_name = f"{model_type}_cvae_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    output_dir = f"outputs_{experiment_name}"

    optimal_batch_size = _training_batch_size(batch_size)
    config = _build_training_config(model_type, epochs, optimal_batch_size, output_dir)
    print(f"🎯 Optimized batch size: {optimal_batch_size} (requested: {batch_size})")
    print(f"🔧 GPU optimizations: {'Enabled' if torch.cuda.is_available() else 'Disabled'}")

    os.makedirs(output_dir, exist_ok=True)
    config_path = f"{output_dir}/config.yaml"
    with open(config_path, 'w') as f:
        yaml.dump(config, f)

    # One --help probe both proves train.py loads and lists the flags it accepts.
    print("🔍 Testing train.py script...")
    try:
        help_result = subprocess.run([sys.executable, "train.py", "--help"],
                                     capture_output=True, text=True, timeout=30)
    except (OSError, subprocess.SubprocessError) as e:
        print(f"❌ Could not run train.py --help: {e}")
        return None
    if help_result.returncode != 0:
        print("❌ train.py script has issues:")
        print("STDOUT:", help_result.stdout)
        print("STDERR:", help_result.stderr)
        return None
    print("✅ train.py script is accessible")

    cmd = _build_train_command(config_path, output_dir, model_type, epochs,
                               optimal_batch_size, help_result.stdout)
    print(f"🚀 Starting {model_type} CVAE training...")
    print(f"📋 Command: {' '.join(cmd)}")

    monitor = TrainingMonitor(total_epochs=epochs)
    process = None
    try:
        # stderr is merged into stdout. Draining only one pipe while the other
        # fills up (warnings, progress bars) would block train.py forever.
        process = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
            text=True, bufsize=1
        )
        error_lines = _stream_training_output(process, monitor, idle_timeout_seconds)
        return_code = process.wait()

        monitor.close()  # close the progress bar before printing the summary

        if return_code != 0:
            print(f"\n❌ Training failed with return code: {return_code}")
            if error_lines:
                print("\n🔍 Error Summary:")
                for error in error_lines[-10:]:  # Show last 10 errors
                    print(f"  {error}")

            print("\n🔧 Troubleshooting:")
            print("1. Check if all required files exist:")
            print(f"   - train.py: {'✅' if os.path.exists('train.py') else '❌'}")
            print(f"   - Config: {'✅' if os.path.exists(config_path) else '❌'}")
            print(f"   - Data: {'✅' if any(os.path.exists(p) for p in processed_data_paths) else '❌'}")
            print("2. Try running with simpler arguments:")
            simple_cmd = [sys.executable, "train.py", "--config", config_path]
            print(f"   {' '.join(simple_cmd)}")
            return None

        print("\n✅ Training completed successfully!")
        if monitor.metrics['train_loss']:
            print("📊 Final Training Results:")
            monitor.plot_progress()

        # Copy results to Drive if in Colab
        if IN_COLAB and DRIVE_PROJECT_PATH:
            drive_results_dir = f"{DRIVE_PROJECT_PATH}/results/{experiment_name}"
            if _copy_tree_to_drive(output_dir, drive_results_dir):
                print(f"💾 Results backed up to: {drive_results_dir}")
        return output_dir

    except Exception as e:
        print(f"❌ Training error: {e}")
        traceback.print_exc()
        return None
    finally:
        monitor.close()  # always release the progress bar, even on error
        if process is not None and process.poll() is None:
            # e.g. the cell was interrupted: don't leave train.py holding the GPU.
            _stop_process(process)

print("✅ Training function ready!")


In [None]:
# 🔬 Evaluation Function using existing evaluate.py
def run_evaluation(model_path=None, output_dir=None):
    """Run evaluation using the existing evaluate.py script.

    Args:
        model_path: Checkpoint to evaluate. When None, the most recently modified
            ``outputs_*`` training directory is searched. Its
            ``best_checkpoint.pth`` is used if present; otherwise its most
            recently modified ``*.pth`` file is used.
        output_dir: Folder evaluate.py writes into. Defaults to
            ``evaluation_<YYYYmmdd_HHMMSS>``.

    Returns:
        The evaluation output directory on success, otherwise None. Evaluation
        is capped at 30 minutes.
    """
    import os
    import shutil
    import subprocess
    import sys
    from datetime import datetime

    if model_path is None:
        # Only directories are training runs. A stray file named outputs_*
        # would otherwise be chosen and crash os.listdir below.
        output_dirs = [d for d in os.listdir('.')
                       if d.startswith('outputs_') and os.path.isdir(d)]
        if not output_dirs:
            print("❌ No training outputs found. Please run training first.")
            return None

        latest_output_dir = max(output_dirs, key=os.path.getmtime)

        best_checkpoint = os.path.join(latest_output_dir, "best_checkpoint.pth")
        if os.path.exists(best_checkpoint):
            model_path = best_checkpoint
        else:
            checkpoints = [os.path.join(latest_output_dir, f)
                           for f in os.listdir(latest_output_dir) if f.endswith('.pth')]
            if not checkpoints:
                print(f"❌ No model checkpoints found in {latest_output_dir}")
                return None
            # os.listdir order is arbitrary; pick the newest checkpoint explicitly.
            model_path = max(checkpoints, key=os.path.getmtime)
    elif not os.path.exists(model_path):
        print(f"❌ Model checkpoint not found: {model_path}")
        return None

    if output_dir is None:
        output_dir = f"evaluation_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    cmd = [
        sys.executable, "evaluate.py",
        "--model", model_path,
        "--output-dir", output_dir,
        "--comprehensive"
    ]

    print(f"🔬 Starting evaluation...")
    print(f"📋 Model: {model_path}")
    print(f"📋 Command: {' '.join(cmd)}")

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=1800)  # 30 min timeout

        if result.stdout:
            print("📋 Evaluation Output:")
            print(result.stdout)

        if result.stderr:
            print("⚠️ Evaluation Warnings:")
            print(result.stderr)

        if result.returncode != 0:
            print(f"❌ Evaluation failed with return code: {result.returncode}")
            return None

        print("✅ Evaluation completed successfully!")

        # Copy results to Drive if in Colab
        if IN_COLAB and DRIVE_PROJECT_PATH:
            drive_eval_dir = f"{DRIVE_PROJECT_PATH}/evaluations/{output_dir}"
            try:
                os.makedirs(f"{DRIVE_PROJECT_PATH}/evaluations", exist_ok=True)
                # dirs_exist_ok merges into an existing folder. `cp -r` would
                # instead nest a second copy inside it when an output_dir such
                # as "eval_original" is reused.
                shutil.copytree(output_dir, drive_eval_dir, dirs_exist_ok=True)
                print(f"💾 Evaluation results backed up to: {drive_eval_dir}")
            except OSError as e:
                print(f"⚠️ Backup to {drive_eval_dir} failed: {e}")

        return output_dir

    except subprocess.TimeoutExpired:
        print("❌ Evaluation timed out after 30 minutes")
        return None
    except Exception as e:
        print(f"❌ Evaluation error: {e}")
        return None

print("✅ Evaluation function ready!")


In [None]:
# 🚀 Quick Start Functions
def quick_train_original(epochs=50):
    """Train the original CVAE under the fixed experiment name 'original_quick'."""
    settings = dict(model_type='original', epochs=epochs, experiment_name='original_quick')
    return run_training(**settings)

def quick_train_advanced(epochs=100):
    """Train the advanced CVAE under the fixed experiment name 'advanced_quick'."""
    settings = dict(model_type='advanced', epochs=epochs, experiment_name='advanced_quick')
    return run_training(**settings)

def quick_evaluate():
    """Evaluate the newest training run; run_evaluation picks the checkpoint."""
    evaluation_dir = run_evaluation()
    return evaluation_dir

def compare_models():
    """Train the original (30 epochs) and advanced (50 epochs) CVAEs, evaluating each.

    Each successful run is evaluated from ``<output>/best_checkpoint.pth`` into
    ``eval_<model_type>``.

    Returns:
        (original_output, advanced_output): training output folders, or None
        for a run that failed.
    """
    print("🔬 Starting model comparison...")

    plan = (
        ('original', 30, "\n🚀 Training Original CVAE (30 epochs)...", "\n🔬 Evaluating Original CVAE..."),
        ('advanced', 50, "\n🚀 Training Advanced CVAE (50 epochs)...", "\n🔬 Evaluating Advanced CVAE..."),
    )
    run_dirs = []
    for model_type, n_epochs, train_banner, eval_banner in plan:
        print(train_banner)
        run_dir = run_training(model_type, epochs=n_epochs,
                               experiment_name=f"comparison_{model_type}")
        if run_dir:
            print(eval_banner)
            run_evaluation(f"{run_dir}/best_checkpoint.pth", f"eval_{model_type}")
        run_dirs.append(run_dir)

    print("\n📊 Model comparison complete!")
    if IN_COLAB:
        print(f"📁 Check your Google Drive at: {DRIVE_PROJECT_PATH}/results/")

    original_output, advanced_output = run_dirs
    return original_output, advanced_output

def list_results():
    """Print local and (in Colab) Google Drive training/evaluation result folders.

    Local evaluation folders are matched under both prefixes this notebook
    creates: 'evaluation_' (run_evaluation's default) and 'eval_' (used by
    compare_models).
    """
    import os

    def local_dirs(prefixes):
        # Only directories: a stray file such as 'outputs_notes.txt' is not a run.
        return sorted(d for d in os.listdir('.')
                      if d.startswith(prefixes) and os.path.isdir(d))

    def print_section(title, names):
        if names:
            print(title)
            for name in names:
                print(f"  - {name}")

    print("📁 Local Results:")
    print_section("🚀 Training Results:", local_dirs(('outputs_',)))
    print_section("🔬 Evaluation Results:", local_dirs(('evaluation_', 'eval_')))

    # Drive results (if in Colab)
    if IN_COLAB and DRIVE_PROJECT_PATH:
        print(f"\n📁 Google Drive Results: {DRIVE_PROJECT_PATH}/")
        for subfolder, title in (("results", "🚀 Drive Training Results:"),
                                 ("evaluations", "🔬 Drive Evaluation Results:")):
            drive_dir = f"{DRIVE_PROJECT_PATH}/{subfolder}"
            if os.path.exists(drive_dir):
                print_section(title, sorted(os.listdir(drive_dir)))

print("\n".join((
    "✅ Quick start functions ready!",
    "\n🚀 Available functions:",
    "  - quick_train_original(epochs=50)",
    "  - quick_train_advanced(epochs=100)",
    "  - quick_evaluate()",
    "  - compare_models()",
    "  - list_results()",
    "\n💡 Example usage:",
    "  quick_train_original()  # Train original model",
    "  quick_evaluate()        # Evaluate latest model",
)))


In [None]:
# 🚀 START TRAINING - run this cell to begin.
#
# Keep exactly one of the options below active.

# Option 1 (recommended for a first run): the original CVAE for 30 epochs.
# run_training() preprocesses the data first if no processed pickle exists.
quick_train_original(30)

# Option 2: the advanced CVAE — more complex and slower.
# quick_train_advanced(epochs=50)

# Option 3: train and evaluate both models — slowest, most comprehensive.
# compare_models()


In [None]:
# 🔄 DATA STATUS & PREPROCESSING
# Report which data files exist, then rebuild the processed data from the Excel file.
# NOTE: run_preprocessing() runs unconditionally here, even if processed files exist.

print("🔍 Checking current data status...")
check_data_status()

print(f"\n{'=' * 60}")
print("🔄 Starting preprocessing...")

run_preprocessing()


In [None]:
# 🎮 GPU STATUS CHECK - run this to verify the GPU is working.
# With a GPU, apply the Colab optimizations and print training tips;
# otherwise print how to switch the runtime to a GPU.

print("🔍 Checking GPU status and performance...")
gpu_available = check_gpu_status()

if not gpu_available:
    print("\n".join((
        "\n❌ GPU not available!",
        "🔧 To enable GPU:",
        "  1. Go to Runtime > Change runtime type",
        "  2. Set Hardware accelerator to 'GPU'",
        "  3. Click Save and restart the runtime",
        "  4. Re-run all cells",
    )))
else:
    print("\n🚀 Applying optimizations...")
    recommended_batch = optimize_for_colab()
    print("\n".join((
        "\n💡 Training Tips:",
        f"  - Use batch size: {recommended_batch} for optimal performance",
        "  - Mixed precision training: Enabled",
        "  - Model compilation: Enabled",
        "  - Expected speedup: 2-4x faster than CPU",
    )))


In [None]:
# 🔍 DEBUG TRAINING SCRIPT - Run this if training fails

def debug_training_setup():
    """Print a step-by-step diagnosis of the training setup.

    Checks, in order:
      1. train.py exists and ``train.py --help`` succeeds (10 s limit).
      2. A processed data pickle exists.
      3. A CUDA GPU is visible.
      4. Key packages import.
      5. A 1-epoch run with a minimal config exits cleanly (60 s limit).
    The temporary config file and output folder are removed afterwards.
    """
    import os
    import shutil
    import subprocess
    import sys

    print("🔍 Debugging training setup...")

    # Check if train.py exists and is runnable
    print("\n1. 📄 Checking train.py:")
    if os.path.exists("train.py"):
        print("   ✅ train.py exists")
        try:
            result = subprocess.run([sys.executable, "train.py", "--help"],
                                    capture_output=True, text=True, timeout=10)
            if result.returncode == 0:
                print("   ✅ train.py is runnable")
                print("   📋 Available arguments:")
                for line in result.stdout.split('\n'):
                    if '--' in line:
                        print(f"     {line.strip()}")
            else:
                print("   ❌ train.py has issues:")
                print(f"     STDOUT: {result.stdout}")
                print(f"     STDERR: {result.stderr}")
        except Exception as e:
            print(f"   ❌ Error running train.py: {e}")
    else:
        print("   ❌ train.py not found")

    # Check data files
    print("\n2. 📊 Checking data files:")
    data_paths = [
        "data/processed/abr_processed_data.pkl",
        "data/processed/abr_data_preprocessed.pkl",
        "data/abr_processed_data.pkl"
    ]

    data_found = False
    for path in data_paths:
        if os.path.exists(path):
            size = os.path.getsize(path) / (1024*1024)
            print(f"   ✅ {path} ({size:.1f} MB)")
            data_found = True
        else:
            print(f"   ❌ {path}")

    if not data_found:
        print("   ⚠️ No processed data found! Run preprocessing first.")

    # Check GPU
    print("\n3. 🎮 Checking GPU:")
    if torch.cuda.is_available():
        print(f"   ✅ GPU available: {torch.cuda.get_device_name(0)}")
        print(f"   📊 Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("   ❌ GPU not available")

    # Check dependencies
    print("\n4. 📦 Checking key dependencies:")
    deps = ['torch', 'numpy', 'pandas', 'yaml', 'sklearn']
    for dep in deps:
        try:
            __import__(dep)
            print(f"   ✅ {dep}")
        except ImportError:
            print(f"   ❌ {dep} not found")

    # Try a minimal training command
    print("\n5. 🧪 Testing minimal training command:")
    if not (os.path.exists("train.py") and data_found):
        print("   ⏭️ Skipped: needs both train.py and processed data (see above).")
    else:
        test_config_path = "test_config.yaml"
        test_output_dir = "debug_test_output"
        minimal_config = {
            'data': {'sequence_length': 200},
            'model': {'type': 'original'},
            # Send output to a throwaway folder that is deleted below. Without an
            # output_dir, train.py presumably writes to its default location and
            # leaves artifacts behind (test_training_script does the same).
            'training': {'epochs': 1, 'batch_size': 2, 'output_dir': test_output_dir}
        }
        test_cmd = [sys.executable, "train.py", "--config", test_config_path]

        try:
            # Written inside the try so the finally-cleanup also covers a failed write.
            with open(test_config_path, 'w') as f:
                yaml.dump(minimal_config, f)
            print(f"   🔧 Command: {' '.join(test_cmd)}")

            result = subprocess.run(test_cmd, capture_output=True, text=True, timeout=60)
            if result.returncode == 0:
                print("   ✅ Minimal training works!")
            else:
                print("   ❌ Minimal training failed:")
                print(f"     Return code: {result.returncode}")
                if result.stdout:
                    print("     STDOUT:", result.stdout[-500:])  # Last 500 chars
                if result.stderr:
                    print("     STDERR:", result.stderr[-500:])  # Last 500 chars
        except subprocess.TimeoutExpired:
            print("   ⏰ Test timed out (might be working but slow)")
        except Exception as e:
            print(f"   ❌ Test error: {e}")
        finally:
            if os.path.exists(test_config_path):
                os.remove(test_config_path)
            shutil.rmtree(test_output_dir, ignore_errors=True)

    print("\n✅ Debug complete!")

# Run the debug
debug_training_setup()

def test_training_script():
    """Smoke-test train.py with one epoch of a tiny model, within 2 minutes.

    Writes a throwaway config (output folder 'test_output') and prints the return
    code and the last 1000 characters of stdout/stderr. The config file and the
    output folder are deleted afterwards.
    """
    import os
    import shutil
    import subprocess
    import sys

    print("🧪 Testing train.py with minimal arguments...")

    config_file = "minimal_test_config.yaml"
    scratch_dir = "test_output"
    smoke_config = {
        'data': {'sequence_length': 200},
        'model': {'type': 'original', 'latent_dim': 32, 'hidden_dim': 64},
        'training': {'epochs': 1, 'batch_size': 2, 'output_dir': scratch_dir},
    }
    with open(config_file, 'w') as handle:
        yaml.dump(smoke_config, handle)

    command = [sys.executable, "train.py", "--config", config_file]
    print(f"🔧 Test command: {' '.join(command)}")

    try:
        outcome = subprocess.run(command, capture_output=True, text=True, timeout=120)
        print(f"📋 Return code: {outcome.returncode}")

        for label, stream_text in (("STDOUT", outcome.stdout), ("STDERR", outcome.stderr)):
            if stream_text:
                print(f"📋 {label} (last 1000 chars):")
                print(stream_text[-1000:])

        print("✅ Basic training script works!" if outcome.returncode == 0
              else "❌ Training script has issues")
    except subprocess.TimeoutExpired:
        print("⏰ Test timed out after 2 minutes")
    except Exception as e:
        print(f"❌ Test error: {e}")
    finally:
        if os.path.exists(config_file):
            os.remove(config_file)
        if os.path.exists(scratch_dir):
            shutil.rmtree(scratch_dir, ignore_errors=True)

print("\n🧪 Run test_training_script() to test the basic training functionality")


In [None]:
# 🚀 OPTIMIZED QUICK START FUNCTIONS - Best configurations for advanced CVAE

def quick_train_original_optimized(epochs=50):
    """Train the original CVAE with a batch-size request of 16 ('original_optimized')."""
    model_type, requested_batch, run_name = 'original', 16, 'original_optimized'
    return run_training(model_type, epochs=epochs, batch_size=requested_batch,
                        experiment_name=run_name)

# (minimum GPU memory in GB, batch size) tiers for the GPU-adaptive presets.
# A GPU below every threshold falls back to `gpu_fallback` in _select_batch_size.
_BEST_BATCH_TIERS = ((15, 32), (10, 24))
_PRODUCTION_BATCH_TIERS = ((15, 32), (10, 28), (8, 20))


def _select_batch_size(tiers, gpu_fallback=16, cpu_batch=8):
    """Pick a batch size from the total memory of CUDA device 0.

    Args:
        tiers: Iterable of ``(min_memory_gb, batch_size)`` pairs. The pair with
            the largest threshold the GPU meets wins, so the order of
            ``tiers`` does not matter.
        gpu_fallback: Batch size when a GPU is present but below every threshold.
        cpu_batch: Batch size when CUDA is not available.

    Returns:
        int: The selected batch size.
    """
    if not torch.cuda.is_available():
        return cpu_batch
    # Divide by 1e9 (decimal GB) to match the thresholds in the tier tables.
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    for min_gb, batch in sorted(tiers, reverse=True):
        if gpu_memory_gb >= min_gb:
            return batch
    return gpu_fallback


def quick_train_advanced_best(epochs=100):
    """Train advanced CVAE with best possible settings for optimal results.

    The batch size adapts to GPU memory: 32 (>=15 GB), 24 (>=10 GB),
    16 (smaller GPU), or 8 without CUDA.

    Args:
        epochs: Number of training epochs (default 100).

    Returns:
        Whatever ``run_training`` returns for this run.
    """
    print("🎯 Starting BEST advanced CVAE training with optimal hyperparameters:")
    print("   ✨ Hierarchical latent spaces (256D)")
    print("   🔄 Cyclical beta annealing")
    print("   📈 Cosine learning rate with restarts")
    print("   🎮 GPU optimizations enabled")
    print("   🛡️ Advanced regularization")
    print("   📊 Comprehensive evaluation metrics")

    optimal_batch = _select_batch_size(_BEST_BATCH_TIERS)

    print(f"🎮 Using optimal batch size: {optimal_batch}")
    return run_training('advanced', epochs=epochs, batch_size=optimal_batch, experiment_name='advanced_best')


def quick_train_advanced_fast(epochs=50):
    """Train advanced CVAE with good settings (faster for testing).

    Uses a fixed batch size of 16 regardless of hardware.

    Args:
        epochs: Number of training epochs (default 50).

    Returns:
        Whatever ``run_training`` returns for this run.
    """
    return run_training('advanced', epochs=epochs, batch_size=16, experiment_name='advanced_fast')


def train_advanced_production(epochs=200):
    """Production-quality training with all optimizations for publication-ready results.

    The batch size adapts to GPU memory: 32 (>=15 GB), 28 (>=10 GB),
    20 (>=8 GB), 16 (smaller GPU), or 8 without CUDA.

    Args:
        epochs: Number of training epochs (default 200).

    Returns:
        Whatever ``run_training`` returns for this run.
    """
    print("🏆 PRODUCTION TRAINING - Publication-ready advanced CVAE")
    print("🎯 This configuration is optimized for:")
    print("   📊 Maximum reconstruction quality")
    print("   🧠 Best latent space disentanglement") 
    print("   🔬 Comprehensive evaluation metrics")
    print("   💾 Automatic checkpointing and backup")
    print("   ⚡ GPU acceleration with mixed precision")
    print("   🛡️ Advanced regularization techniques")
    print("   📈 Adaptive learning rate scheduling")
    print("   🔄 Cyclical beta annealing for better KL balance")

    optimal_batch = _select_batch_size(_PRODUCTION_BATCH_TIERS)

    print(f"🎮 Production batch size: {optimal_batch}")
    # Rough heuristic of ~2 minutes per epoch.
    print(f"⏱️ Estimated training time: {epochs * 2} minutes")

    return run_training('advanced', epochs=epochs, batch_size=optimal_batch, experiment_name='advanced_production')

def compare_models_optimized(epochs=75):
    """Train the original CVAE, then the advanced CVAE, for the same epoch count.

    Each model uses its own preset (``quick_train_original_optimized`` and
    ``quick_train_advanced_best``), so their batch sizes may differ.

    Args:
        epochs: Number of epochs for each of the two runs (default 75).

    Returns:
        tuple: ``(original_results, advanced_results)``, the return values of
        the two training calls in that order.
    """
    print("🔄 Training Original CVAE with optimized settings...")
    baseline = quick_train_original_optimized(epochs)

    print("\n🔄 Training Advanced CVAE with optimized settings...")
    improved = quick_train_advanced_best(epochs)

    comparison = (baseline, improved)
    return comparison

# Summary of the training helpers defined in this cell; one line per entry.
print("\n".join([
    "✅ OPTIMIZED training functions ready!",
    "\n🚀 Available optimized functions:",
    "  🟢 quick_train_original_optimized(epochs=50) - Original CVAE with best settings",
    "  🔥 quick_train_advanced_best(epochs=100) - Advanced CVAE optimized for best results",
    "  ⚡ quick_train_advanced_fast(epochs=50) - Advanced CVAE for quick testing",
    "  🏆 train_advanced_production(epochs=200) - Production-quality training",
    "  📊 compare_models_optimized(epochs=75) - Compare both with optimal settings",
    "\n🎯 RECOMMENDED FOR BEST RESULTS:",
    "  train_advanced_production(epochs=200)",
    "\n⚡ RECOMMENDED FOR QUICK TESTING:",
    "  quick_train_advanced_fast(epochs=50)",
]))


In [None]:
# 🚀 OPTIMIZED TRAINING EXECUTION - Choose your training strategy
#
# Select the strategy with TRAINING_OPTION instead of commenting/uncommenting
# calls. Exactly one option runs, a typo fails fast before any training
# starts, and the run's return value is kept in `training_results`.
# Set TRAINING_OPTION = None to print the menu without training.
TRAINING_OPTION = "fast"  # "production" | "best" | "fast" | "compare" | None

# option key -> (training function, epochs); every function accepts `epochs=`.
TRAINING_STRATEGIES = {
    "production": (train_advanced_production, 200),
    "best": (quick_train_advanced_best, 100),
    "fast": (quick_train_advanced_fast, 50),
    "compare": (compare_models_optimized, 75),
}

if TRAINING_OPTION is not None and TRAINING_OPTION not in TRAINING_STRATEGIES:
    raise ValueError(
        f"Unknown TRAINING_OPTION {TRAINING_OPTION!r}; "
        f"expected one of {list(TRAINING_STRATEGIES)} or None"
    )

print("🎯 OPTIMIZED TRAINING OPTIONS:")
print("="*60)

# Option 1: Production-quality training (BEST RESULTS)
print("🏆 Option 1: Production Training (RECOMMENDED)")
print("   - 200 epochs with all optimizations")
print("   - Publication-ready results")
print("   - Estimated time: ~6-7 hours")
print("   - Best reconstruction quality")
print("   - Optimal latent space disentanglement")

print("\n🔥 Option 2: Best Training (GOOD BALANCE)")
print("   - 100 epochs with optimal settings")
print("   - High-quality results")
print("   - Estimated time: ~3-4 hours")
print("   - Excellent performance")

print("\n⚡ Option 3: Fast Training (QUICK TESTING)")
print("   - 50 epochs for rapid iteration")
print("   - Good results for testing")
print("   - Estimated time: ~1-2 hours")
print("   - Perfect for experimentation")

print("\n📊 Option 4: Model Comparison")
print("   - Compare original vs advanced")
print("   - 75 epochs each")
print("   - Estimated time: ~4-5 hours")

print("\n" + "="*60)
print("💡 TIPS:")
print(f"  - Set TRAINING_OPTION to the option you want to run: {', '.join(TRAINING_STRATEGIES)} (or None)")
print("  - Only one option runs per execution of this cell")
print("  - Results are automatically saved to Google Drive")
print("  - GPU acceleration is automatically enabled")

# Run the selected strategy last so the menu above is fully printed first.
training_results = None
if TRAINING_OPTION is not None:
    selected_train_fn, selected_epochs = TRAINING_STRATEGIES[TRAINING_OPTION]
    print(f"\n▶️ Running option '{TRAINING_OPTION}' ({selected_epochs} epochs)...")
    training_results = selected_train_fn(epochs=selected_epochs)


In [None]:
# 🔄 RUN PREPROCESSING - Run this if you get data preprocessing errors

# This will automatically run when needed, but you can also run it manually
# NOTE(review): the automatic trigger is not visible in this part of the
# notebook — confirm where run_preprocessing() is invoked during training.
# This cell sits after the training cell, so under "Run All" it executes only
# after training has finished; run it before training if the data has not
# been preprocessed yet.
run_preprocessing()


In [None]:
# 🔬 EVALUATE MODEL - Run this after training completes

# Evaluate the most recent trained model
# `quick_evaluate` is defined in an earlier cell (not shown here); presumably
# it locates the latest training output on its own — TODO confirm.
# As the cell's last expression, its return value (if not None) is displayed.
quick_evaluate()


In [None]:
# 🔬 EVALUATE MODEL (duplicate cell)
#
# This cell used to repeat the `quick_evaluate()` call from the evaluation cell
# directly above, so "Run All" ran the evaluation twice. The redundant call was
# removed; re-run the previous cell if you want to evaluate again.
