In [None]:
# Check if running in Colab.
# Only a failed import means "not Colab"; catching ImportError (instead of a
# bare `except:`) keeps KeyboardInterrupt/SystemExit and unrelated errors visible.
try:
    import google.colab  # noqa: F401 -- imported only to probe availability
    IN_COLAB = True
    print("🔍 Running in Google Colab")
except ImportError:
    IN_COLAB = False
    print("🔍 Running locally")

import os
import sys
from pathlib import Path

# Set up paths
if IN_COLAB:
    # Mount Google Drive so the project folder becomes visible under /content/drive
    from google.colab import drive
    drive.mount('/content/drive')

    # Set project path (adjust as needed)
    PROJECT_PATH = "/content/drive/MyDrive/abr-cvae-project"  # Change this to your project path

    # Alternative: Clone from GitHub
    # !git clone https://github.com/your-username/abr-cvae-project.git
    # PROJECT_PATH = "/content/abr-cvae-project"
else:
    PROJECT_PATH = os.getcwd()

print(f"📁 Project path: {PROJECT_PATH}")

# Fail with an actionable message instead of the bare error os.chdir would raise.
if not os.path.isdir(PROJECT_PATH):
    raise FileNotFoundError(
        f"Project path not found: {PROJECT_PATH}. "
        "Adjust PROJECT_PATH in this cell to point at your project folder."
    )

os.chdir(PROJECT_PATH)

# Guard against duplicate entries when this cell is re-run in the same kernel.
if PROJECT_PATH not in sys.path:
    sys.path.append(PROJECT_PATH)


In [None]:
# Install required packages
if IN_COLAB:
    !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    !pip install watchdog PyYAML scipy scikit-learn matplotlib seaborn tqdm tensorboard openpyxl
else:
    !pip install -r requirements.txt
    !pip install watchdog  # Add watchdog if not in requirements

print("✅ Dependencies installed successfully!")


In [None]:
# Import core libraries
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import yaml
import json
import time
import subprocess
import threading
from datetime import datetime
from pathlib import Path
from IPython.display import display, HTML, clear_output
import warnings
warnings.filterwarnings('ignore')

# Plotting defaults shared by every figure in this notebook
plt.style.use('default')
sns.set_palette("husl")

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"🚀 Using device: {device}")

# Report the GPU model and its total memory when one is in use
if device.type == 'cuda':
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
    print(f"📊 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

print("✅ Core libraries imported successfully!")


In [None]:
# File synchronization with local watchdog
import importlib

def check_sync_status():
    """Report whether a `.sync_status` marker exists in the project folder.

    Reads `<PROJECT_PATH>/.sync_status` and prints its stripped content.
    If the marker is missing or cannot be read, prints instructions for
    starting the local watchdog script instead.

    Returns:
        bool: True when the marker was read successfully, False otherwise.
    """
    sync_indicator = Path(PROJECT_PATH) / ".sync_status"

    if sync_indicator.exists():
        try:
            # Decode explicitly and tolerate stray bytes: the marker is only
            # displayed, so a partially garbled status beats a failure.
            status = sync_indicator.read_text(encoding="utf-8", errors="replace").strip()
        except OSError as exc:
            # Surface the reason rather than silently pretending no marker exists.
            print(f"⚠️  Could not read sync status file: {exc}")
        else:
            print(f"✅ Sync status: {status}")
            return True

    print("⚠️  No sync status found")
    print("💡 Make sure to run the local watchdog script:")
    print(f"   python local_watchdog_sync.py --local-path /path/to/your/project --project-name {os.path.basename(PROJECT_PATH)}")
    return False

def reload_modules():
    """Reload every already-imported module whose file lives under PROJECT_PATH.

    Modules located in a `site-packages` / `dist-packages` directory inside
    the project (e.g. a virtualenv created in the project folder) are skipped,
    so third-party packages are never reloaded.

    Returns:
        int: Number of modules that were reloaded successfully.
    """
    project_root = Path(PROJECT_PATH).resolve()
    vendored_dirs = ('site-packages', 'dist-packages')
    modules_to_reload = []

    # Snapshot sys.modules: attribute access on lazy modules can trigger
    # imports, which would otherwise mutate the dict during iteration.
    for module_name, module in list(sys.modules.items()):
        try:
            module_file = getattr(module, '__file__', None)
            if not module_file:
                continue
            module_path = Path(module_file).resolve()
        except Exception:
            # Some entries in sys.modules are not ordinary modules; skip them.
            continue

        # Real path containment; a plain string prefix test would also match
        # sibling folders such as "<project>_backup".
        if project_root not in module_path.parents:
            continue

        relative_parts = module_path.relative_to(project_root).parts
        if any(part in vendored_dirs for part in relative_parts):
            continue

        modules_to_reload.append(module_name)

    # Reload modules
    reloaded_count = 0
    for module_name in modules_to_reload:
        try:
            importlib.reload(sys.modules[module_name])
            reloaded_count += 1
        except Exception as e:
            print(f"⚠️  Could not reload {module_name}: {e}")

    if reloaded_count > 0:
        print(f"🔄 Reloaded {reloaded_count} modules")
    else:
        print("📋 No modules to reload")

    return reloaded_count

def setup_colab_sync():
    """Print file-sync guidance for the current environment.

    In Colab, prints the steps for running the local watchdog script and then
    calls `check_sync_status()`. Outside Colab, only notes that no sync is needed.
    """
    print("🔧 Setting up file synchronization...")
    print("=" * 50)

    if not IN_COLAB:
        print("📋 Running locally - no sync needed")
    else:
        instructions = (
            "📋 Running in Google Colab",
            "💡 Local Watchdog Setup Instructions:",
            "   1. On your local machine, install watchdog:",
            "      pip install watchdog",
            "   2. Run the sync script:",
            f"      python local_watchdog_sync.py --local-path /path/to/your/project --project-name {os.path.basename(PROJECT_PATH)}",
            "   3. Keep the script running while developing",
            "   4. Your changes will automatically sync to Google Drive",
            "   5. Use reload_modules() in this notebook to get the latest changes",
        )
        for instruction in instructions:
            print(instruction)

        # Show whether the watchdog has already written its marker file
        print("\n🔍 Checking sync status...")
        check_sync_status()

    print("\n✅ Sync setup complete!")
# Setup sync system: print the sync guidance (and, in Colab, the current
# sync status) as soon as this cell runs
setup_colab_sync()


In [None]:
# Re-import project modules so edits synced from the local machine take effect
print("🔄 Reloading modules...")
reloaded_count = reload_modules()

if not reloaded_count > 0:
    print("📋 No modules needed reloading")
else:
    print(f"✅ Successfully reloaded {reloaded_count} modules")
    print("💡 Your latest local changes are now available!")

# Report whether the sync marker file is present
print("\n🔍 Checking sync status...")
check_sync_status()


In [None]:
# Configuration management
def list_available_configs():
    """List the YAML configuration files in `<PROJECT_PATH>/configs`.

    Prints a numbered list of the `*.yaml` files found.

    Returns:
        list[pathlib.Path]: Matching files sorted by file name, or an empty
        list when the directory is missing or contains no `*.yaml` file.
    """
    config_dir = Path(PROJECT_PATH) / 'configs'
    if not config_dir.is_dir():
        print("❌ No configs directory found")
        return []

    # glob() order is filesystem-dependent; sort so that "the first config"
    # (used as the default by callers) is the same on every run and machine.
    configs = sorted(config_dir.glob('*.yaml'), key=lambda path: path.name)
    if not configs:
        print("❌ No configuration files found in configs directory")
        return []

    print("📁 Available configurations:")
    for i, config in enumerate(configs, start=1):
        print(f"  {i}. {config.name}")
    return configs

def load_config(config_path):
    """Load a configuration from a YAML file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The parsed YAML content, or None when the file cannot be read or
        parsed, or when it contains no YAML document. A message describing
        the outcome is printed in every case.
    """
    try:
        # Read as UTF-8 explicitly rather than relying on the platform default.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
        print(f"❌ Error loading config: {e}")
        return None

    # safe_load returns None for an empty file; report that as a failure
    # instead of announcing success and handing back None.
    if config is None:
        print(f"❌ Error loading config: {config_path} is empty")
        return None

    print(f"✅ Configuration loaded from {config_path}")
    return config

def display_config(config):
    """Pretty-print a configuration mapping.

    Each top-level key is printed as an upper-cased section header. Dict
    values are expanded up to two levels deep; anything else is printed as-is.

    Args:
        config: Mapping of section name to section content.
    """
    print("📋 Current Configuration:")
    print("=" * 50)

    for section_name, section_body in config.items():
        print(f"\n🔧 {section_name.upper()}:")

        # Scalar (non-dict) section: print the value itself
        if not isinstance(section_body, dict):
            print(f"  {section_body}")
            continue

        for option, option_value in section_body.items():
            if not isinstance(option_value, dict):
                print(f"  {option}: {option_value}")
                continue

            # Nested group: header line followed by indented entries
            print(f"  {option}:")
            for nested_name, nested_value in option_value.items():
                print(f"    {nested_name}: {nested_value}")

def create_colab_config():
    """Build the default configuration used when no YAML config is available.

    Returns:
        dict: A fresh nested dict with 'data', 'model', 'training' and
        'evaluation' sections. The training batch size is 16 when CUDA is
        available and 8 otherwise.
    """
    data_section = dict(
        sequence_length=200,
        channels=1,
        train_split=0.7,
        val_split=0.15,
        test_split=0.15,
        random_state=42,
        augment_train=True,
        augment_prob=0.5,
        noise_std=0.01,
        time_shift_max=50,
    )

    model_section = dict(
        type='original',  # or 'advanced'
        static_dim=4,
        input_dim=1,
        latent_dim=128,
        condition_dim=128,
        hidden_dim=256,
        num_encoder_layers=4,
        num_decoder_layers=4,
        num_heads=8,
        dropout=0.1,
        sequence_length=200,
        beta=1.0,
    )

    optimizer_section = dict(
        type='adamw',
        lr=0.001,
        weight_decay=0.0001,
        betas=[0.9, 0.999],
        eps=1e-8,
    )
    scheduler_section = dict(enabled=True, type='cosine', min_lr=1e-6)
    clipping_section = dict(enabled=True, max_norm=1.0)

    # Smaller batch when no GPU is present
    if torch.cuda.is_available():
        batch_size = 16
    else:
        batch_size = 8

    training_section = dict(
        epochs=50,  # Reduced for Colab
        batch_size=batch_size,
        initial_beta=0.0,
        final_beta=1.0,
        beta_annealing_epochs=25,
        early_stopping_patience=10,
        checkpoint_interval=5,
        sample_interval=10,
        log_interval=50,
        output_dir='outputs_colab',
        optimizer=optimizer_section,
        scheduler=scheduler_section,
        gradient_clipping=clipping_section,
    )

    evaluation_section = dict(
        num_reconstruction_samples=500,
        num_generation_samples=500,
        samples_per_condition=3,
        num_interpolations=5,
        num_latent_samples=500,
    )

    return dict(
        data=data_section,
        model=model_section,
        training=training_section,
        evaluation=evaluation_section,
    )

# Discover YAML configs shipped with the project
available_configs = list_available_configs()

# Fall back to the built-in Colab configuration when no YAML config exists;
# otherwise load the first listed file as the default
if not available_configs:
    print("\n🎯 Creating Colab-optimized configuration")
    config = create_colab_config()
else:
    default_config_path = available_configs[0]
    print(f"\n🎯 Using default config: {default_config_path.name}")
    config = load_config(default_config_path)

# load_config returns None on failure, so only display a usable config
if not config:
    print("❌ Failed to load configuration")
else:
    display_config(config)


In [None]:
# Training setup and execution
import subprocess
import threading
import queue
import re
from IPython.display import clear_output

class TrainingMonitor:
    """Collect training metrics from log lines and visualise them.

    Attributes:
        metrics: Dict of metric name -> list of parsed float values, in the
            order they appeared in the log. The 'epochs' list is created for
            completeness but is not filled by this class.
        current_epoch: Last epoch number seen in a log line (0 before any).
        best_val_loss: Smallest validation loss parsed so far (inf before any).
    """

    # A number as it may appear in a log: integer, decimal, scientific
    # notation (e.g. 1.5e-05), optionally signed, or nan/inf.
    _NUMBER = r'([-+]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|(?:[nN]a[nN]|[iI]nf)\b))'

    # Tried in order; the first pattern that matches sets the current epoch.
    _EPOCH_PATTERNS = (
        r'Epoch (\d+)/(\d+)',
        r'Epoch (\d+):',
        r'Starting epoch (\d+)',
    )

    # metric key -> label prefixes tried in order; the first match wins.
    _METRIC_PREFIXES = (
        ('train_loss', ('Train Loss: ', 'Training Loss: ', 'train_loss: ')),
        ('val_loss', ('Val Loss: ', 'Validation Loss: ', 'val_loss: ')),
        ('kl_loss', ('KL Loss: ', 'KL: ', 'kl_loss: ')),
        ('recon_loss', ('Recon Loss: ', 'Reconstruction Loss: ', 'recon_loss: ', 'Recon: ')),
        ('beta', ('Beta: ', 'beta: ')),
    )

    def __init__(self):
        self.metrics = {
            'epochs': [],
            'train_loss': [],
            'val_loss': [],
            'kl_loss': [],
            'recon_loss': [],
            'beta': []
        }
        self.current_epoch = 0
        self.best_val_loss = float('inf')

    @classmethod
    def _extract_value(cls, line, prefixes):
        """Return the float following the first matching prefix in `line`, or None."""
        for prefix in prefixes:
            match = re.search(prefix + cls._NUMBER, line)
            if match:
                try:
                    return float(match.group(1))
                except ValueError:
                    return None
        return None

    def parse_log_line(self, line):
        """Parse one training log line and record any metrics it contains.

        Each metric is parsed independently, so a value that cannot be read
        for one metric does not prevent the others on the same line from
        being recorded. Non-string input is ignored.

        Args:
            line: A single line of training output.
        """
        if not isinstance(line, str):
            return

        # Extract epoch info
        for pattern in self._EPOCH_PATTERNS:
            epoch_match = re.search(pattern, line)
            if epoch_match:
                self.current_epoch = int(epoch_match.group(1))
                break

        # Extract every known metric
        for key, prefixes in self._METRIC_PREFIXES:
            value = self._extract_value(line, prefixes)
            if value is None:
                continue
            self.metrics[key].append(value)
            # A nan compares False here, so it never becomes the "best" loss.
            if key == 'val_loss' and value < self.best_val_loss:
                self.best_val_loss = value

    @staticmethod
    def _plot_series(ax, title, ylabel, series):
        """Plot each non-empty (values, label, color) series on `ax`.

        Every series gets its own x-range because metrics may be logged at
        different rates and therefore differ in length.
        """
        plotted = False
        for values, label, color in series:
            if not values:
                continue
            ax.plot(range(1, len(values) + 1), values, label=label, color=color)
            plotted = True
        if plotted:
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True)

    def plot_progress(self):
        """Draw a 2x2 figure with the loss curves, beta curve and summary text.

        Does nothing until at least one training loss has been recorded.
        """
        if not self.metrics['train_loss']:
            return

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle(f'Training Progress - Epoch {self.current_epoch}', fontsize=16)

        # Loss curves
        self._plot_series(axes[0, 0], 'Training & Validation Loss', 'Loss', [
            (self.metrics['train_loss'], 'Train Loss', 'blue'),
            (self.metrics['val_loss'], 'Val Loss', 'red'),
        ])

        # KL and Reconstruction Loss
        self._plot_series(axes[0, 1], 'KL & Reconstruction Loss', 'Loss', [
            (self.metrics['kl_loss'], 'KL Loss', 'green'),
            (self.metrics['recon_loss'], 'Recon Loss', 'orange'),
        ])

        # Beta annealing
        self._plot_series(axes[1, 0], 'Beta Annealing', 'Beta Value', [
            (self.metrics['beta'], 'Beta', 'purple'),
        ])

        # Summary stats
        axes[1, 1].axis('off')

        latest_train = f"{self.metrics['train_loss'][-1]:.4f}" if self.metrics['train_loss'] else 'N/A'
        latest_val = f"{self.metrics['val_loss'][-1]:.4f}" if self.metrics['val_loss'] else 'N/A'
        current_beta = f"{self.metrics['beta'][-1]:.4f}" if self.metrics['beta'] else 'N/A'

        stats_text = f"""
        Current Epoch: {self.current_epoch}
        Best Val Loss: {self.best_val_loss:.4f}
        Latest Train Loss: {latest_train}
        Latest Val Loss: {latest_val}
        Current Beta: {current_beta}
        """
        axes[1, 1].text(0.1, 0.5, stats_text, fontsize=12, verticalalignment='center')

        plt.tight_layout()
        plt.show()
        # This method is called repeatedly during training; release the figure
        # so open figures do not accumulate.
        plt.close(fig)

    def print_summary(self):
        """Print a simple text summary of training progress"""
        print(f"📊 Training Summary:")
        print(f"   Current Epoch: {self.current_epoch}")
        print(f"   Best Val Loss: {self.best_val_loss:.4f}")
        if self.metrics['train_loss']:
            print(f"   Latest Train Loss: {self.metrics['train_loss'][-1]:.4f}")
        if self.metrics['val_loss']:
            print(f"   Latest Val Loss: {self.metrics['val_loss'][-1]:.4f}")
        if self.metrics['beta']:
            print(f"   Current Beta: {self.metrics['beta'][-1]:.4f}")
        print(f"   Total Epochs Logged: {len(self.metrics['train_loss'])}")

def run_training_with_monitoring(config, output_dir="outputs_colab"):
    """Run `train.py` as a subprocess and monitor its output in real time.

    The config is written to `<output_dir>/colab_config.yaml`, and `train.py`
    is launched with that file, `output_dir` and the notebook's `device`.
    Output lines are filtered for display and fed to a `TrainingMonitor`,
    whose summary and plots are refreshed periodically.

    Args:
        config: Configuration mapping to dump to YAML for the training script.
        output_dir: Directory for the dumped config, passed as `--output-dir`.

    Returns:
        bool: True when the training process exits with code 0, False when
        it exits with another code or an error occurs while running it.
    """
    from collections import deque  # only needed here, so imported locally

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Save config for training
    config_path = os.path.join(output_dir, "colab_config.yaml")
    with open(config_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False)

    # Initialize monitor
    monitor = TrainingMonitor()

    # Prepare training command
    cmd = [
        sys.executable, "train.py",
        "--config", config_path,
        "--output-dir", output_dir,
        "--device", str(device)
    ]

    print(f"🚀 Starting training with command: {' '.join(cmd)}")
    print("📊 Training progress will be displayed below...")

    # Only lines containing one of these are echoed, to avoid flooding the cell.
    important_keywords = (
        'Epoch', 'Loss:', 'Starting', 'Model', 'Training', 'Validation',
        'Best', 'Saved', 'Early stopping', 'completed', 'ERROR', 'WARNING'
    )

    # The filter above hides most of a Python traceback, so keep the tail of
    # the raw output to show what went wrong if the process fails.
    recent_lines = deque(maxlen=30)

    process = None
    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
            bufsize=1
        )

        # Monitor output in real-time
        line_count = 0
        last_plot_update = 0

        for line in iter(process.stdout.readline, ''):
            line = line.strip()
            if not line:
                continue

            line_count += 1
            recent_lines.append(line)

            if any(keyword in line for keyword in important_keywords):
                print(line)

            # Parse metrics from all lines
            monitor.parse_log_line(line)

            # Update plots less frequently to avoid spam
            if ('Validation Loss:' in line or 'Val Loss:' in line) and (line_count - last_plot_update > 50):
                last_plot_update = line_count
                clear_output(wait=True)
                print("🔄 Updating training progress...")
                monitor.print_summary()
                monitor.plot_progress()

            # Show progress every 200 lines for very verbose output
            elif line_count % 200 == 0:
                print(f"📊 Training in progress... ({line_count} lines processed)")
                if monitor.current_epoch > 0:
                    monitor.print_summary()

        # Check the exit status before announcing anything.
        return_code = process.wait()

        if return_code == 0:
            if monitor.metrics['train_loss']:
                clear_output(wait=True)
                print("🎯 Training completed! Final results:")
                monitor.plot_progress()
            print("✅ Training completed successfully!")
            return True

        # Failure: do not clear the cell, so earlier output stays visible.
        if monitor.metrics['train_loss']:
            monitor.plot_progress()
        print(f"❌ Training failed with return code: {return_code}")
        if recent_lines:
            print("📋 Last lines of training output:")
            for recent in recent_lines:
                print(f"   {recent}")
        return False

    except Exception as e:
        print(f"❌ Error during training: {e}")
        return False

    finally:
        # Also runs when the cell is interrupted (KeyboardInterrupt), so the
        # training process is not left running in the background.
        if process is not None:
            if process.poll() is None:
                process.terminate()
                try:
                    process.wait(timeout=10)
                except subprocess.TimeoutExpired:
                    process.kill()
                    process.wait()
            if process.stdout is not None:
                process.stdout.close()

# Training execution
def start_training():
    """Validate prerequisites and launch training with the global `config`.

    Returns:
        bool: False when no config is loaded or `<PROJECT_PATH>/data/processed`
        is missing; otherwise the result of `run_training_with_monitoring`.
    """
    # Guard: a configuration must have been loaded by an earlier cell
    if not config:
        print("❌ No configuration loaded. Please run the configuration cell first.")
        return False

    # Echo the key settings before starting
    print("🎯 Starting CVAE training...")
    print(f"📋 Model type: {config['model'].get('type', 'original')}")
    print(f"📋 Epochs: {config['training']['epochs']}")
    print(f"📋 Batch size: {config['training']['batch_size']}")
    print(f"📋 Device: {device}")

    # Guard: the preprocessed data folder must exist
    processed_dir = os.path.join(PROJECT_PATH, "data", "processed")
    if os.path.exists(processed_dir):
        return run_training_with_monitoring(config)

    print(f"❌ Data directory not found: {processed_dir}")
    print("📋 Please ensure your preprocessed data is available.")
    return False

# Defining the helpers above does not start training; tell the user how to launch it
print("✅ Training pipeline ready!")
print("📋 Run start_training() to begin training")


In [None]:
# Evaluation setup and execution
def find_best_checkpoint(output_dir="outputs_colab"):
    """Locate the checkpoint to evaluate inside `output_dir`.

    Prefers `best_checkpoint.pth`; otherwise picks the most recently modified
    `checkpoint_epoch_*.pth` file.

    Args:
        output_dir: Directory the training run wrote its checkpoints to.

    Returns:
        str | None: Path of the chosen checkpoint, or None when the directory
        or any checkpoint file is missing.
    """
    root = Path(output_dir)

    if not root.exists():
        print(f"❌ Output directory not found: {output_dir}")
        return None

    # First choice: the explicitly saved best checkpoint
    preferred = root / "best_checkpoint.pth"
    if preferred.exists():
        print(f"✅ Found best checkpoint: {preferred}")
        return str(preferred)

    # Fallback: newest per-epoch checkpoint by modification time
    newest = None
    newest_mtime = None
    for candidate in root.glob("checkpoint_epoch_*.pth"):
        candidate_mtime = candidate.stat().st_mtime
        if newest is None or candidate_mtime > newest_mtime:
            newest = candidate
            newest_mtime = candidate_mtime

    if newest is None:
        print("❌ No checkpoint files found")
        return None

    print(f"✅ Found latest checkpoint: {newest}")
    return str(newest)

def run_evaluation(checkpoint_path, output_dir="results_colab"):
    """Run `evaluate.py --comprehensive` on a checkpoint and print its output.

    Args:
        checkpoint_path: Path of the model checkpoint to evaluate.
        output_dir: Directory passed to the script as `--output-dir`; created
            if it does not exist.

    Returns:
        bool: True when the script exits with code 0; False when the
        checkpoint is missing, the script fails, times out after one hour,
        or cannot be run.
    """
    checkpoint_available = bool(checkpoint_path) and os.path.exists(checkpoint_path)
    if not checkpoint_available:
        print(f"❌ Checkpoint not found: {checkpoint_path}")
        return False

    # Make sure the results folder exists before the script writes to it
    os.makedirs(output_dir, exist_ok=True)

    cmd = [sys.executable, "evaluate.py"]
    cmd += ["--model", checkpoint_path]
    cmd += ["--output-dir", output_dir]
    cmd += ["--comprehensive"]

    print(f"🔬 Starting evaluation with command: {' '.join(cmd)}")
    print("📊 Evaluation progress will be displayed below...")

    try:
        # Output is captured and shown once the process has finished
        completed = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=3600  # 1 hour timeout
        )
    except subprocess.TimeoutExpired:
        print("❌ Evaluation timed out after 1 hour")
        return False
    except Exception as e:
        print(f"❌ Error during evaluation: {e}")
        return False

    try:
        for header, captured in (
            ("📋 Evaluation Output:", completed.stdout),
            ("⚠️ Evaluation Warnings/Errors:", completed.stderr),
        ):
            if captured:
                print(header)
                print(captured)

        if completed.returncode != 0:
            print(f"❌ Evaluation failed with return code: {completed.returncode}")
            return False

        print("✅ Evaluation completed successfully!")
        return True
    except Exception as e:
        print(f"❌ Error during evaluation: {e}")
        return False

def display_evaluation_results(results_dir="results_colab"):
    """Print a summary of the most recent evaluation run in `results_dir`.

    The most recently modified `eval_*` sub-directory is summarised: the
    content of `evaluation_metadata.json`, the names of up to five `*.png`
    files, and a preview (10 rows or fewer) or shape/columns of each `*.csv`.

    Args:
        results_dir: Directory containing the `eval_*` sub-directories.

    Returns:
        None. Everything is printed.
    """
    results_path = Path(results_dir)

    if not results_path.exists():
        print(f"❌ Results directory not found: {results_dir}")
        return

    print("📊 Evaluation Results:")
    print("=" * 50)

    # Look for evaluation subdirectories
    eval_dirs = [d for d in results_path.iterdir() if d.is_dir() and d.name.startswith('eval_')]

    if not eval_dirs:
        print("❌ No evaluation results found")
        return

    # Use the most recent evaluation
    latest_eval_dir = max(eval_dirs, key=lambda d: d.stat().st_mtime)
    print(f"📁 Latest evaluation: {latest_eval_dir.name}")

    # Display metadata
    metadata_file = latest_eval_dir / "evaluation_metadata.json"
    if metadata_file.exists():
        metadata = None
        try:
            with open(metadata_file, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
        except (OSError, ValueError) as e:
            # A malformed metadata file should not hide the plots and CSVs below.
            print(f"\n⚠️ Could not read evaluation metadata: {e}")

        if isinstance(metadata, dict):
            print("\n📋 Evaluation Metadata:")
            for key, value in metadata.items():
                if isinstance(value, dict):
                    print(f"  {key}:")
                    for subkey, subvalue in value.items():
                        print(f"    {subkey}: {subvalue}")
                else:
                    print(f"  {key}: {value}")
        elif metadata is not None:
            # Valid JSON that is not an object: show it verbatim.
            print("\n📋 Evaluation Metadata:")
            print(f"  {metadata}")

    # Look for generated plots (sorted so the listing is stable across runs)
    plot_files = sorted(latest_eval_dir.glob("*.png"))
    if plot_files:
        print(f"\n🖼️ Generated plots ({len(plot_files)} files):")
        for plot_file in plot_files[:5]:  # Show first 5
            print(f"  📊 {plot_file.name}")
        if len(plot_files) > 5:
            print(f"  ... and {len(plot_files) - 5} more")

    # Look for CSV results
    csv_files = sorted(latest_eval_dir.glob("*.csv"))
    if csv_files:
        print(f"\n📈 CSV results ({len(csv_files)} files):")
        for csv_file in csv_files:
            print(f"  📄 {csv_file.name}")

            # Display first few rows if it's a small file
            try:
                df = pd.read_csv(csv_file)
                if len(df) <= 10:
                    print(f"    Preview:")
                    print(df.to_string(index=False, max_rows=5))
                else:
                    print(f"    Shape: {df.shape}")
                    print(f"    Columns: {list(df.columns)}")
            except Exception as e:
                print(f"    Error reading CSV: {e}")

def start_evaluation():
    """Evaluate the checkpoint found by `find_best_checkpoint()`.

    Runs `run_evaluation` on that checkpoint and, when it succeeds, prints
    the results via `display_evaluation_results()`.

    Returns:
        bool: False when no checkpoint is found; otherwise the result of
        `run_evaluation`.
    """
    # Find the best checkpoint
    checkpoint_path = find_best_checkpoint()

    if not checkpoint_path:
        print("❌ No checkpoint found. Please ensure training has completed.")
        return False

    print(f"🎯 Starting evaluation of checkpoint: {checkpoint_path}")

    # Run evaluation
    success = run_evaluation(checkpoint_path)

    if success:
        # A real newline: the previous "\\n" printed a literal backslash-n.
        print("\n📊 Displaying evaluation results...")
        display_evaluation_results()

    return success

# Defining the helpers above does not run an evaluation; tell the user how to launch it
print("✅ Evaluation pipeline ready!")
print("📋 Run start_evaluation() to evaluate the trained model")


In [None]:
# Quick training execution: runs start_training() immediately when this cell
# executes and keeps its boolean result in `training_success`
print("🚀 Starting training with current configuration...")
training_success = start_training()


In [None]:
# Quick evaluation execution: runs start_evaluation() immediately when this
# cell executes and keeps its boolean result in `evaluation_success`
print("🔬 Starting evaluation of trained model...")
evaluation_success = start_evaluation()


In [None]:
# Advanced configuration options
def create_custom_config():
    """Interactively build a training configuration.

    Prompts for the model type, number of epochs, batch size, learning rate,
    latent dimension and hidden dimension. A blank answer selects the default
    shown in the prompt; an unparsable or non-positive number also falls back
    to the default, with a warning, instead of raising.

    Returns:
        dict: Nested configuration with 'data', 'model', 'training' and
        'evaluation' sections. Epoch-dependent training settings are derived
        from the chosen number of epochs.
    """

    def _ask_number(prompt, default, cast):
        """Prompt for a positive number of type `cast`, falling back to `default`."""
        raw = input(prompt).strip()
        if not raw:
            return default
        try:
            value = cast(raw)
        except ValueError:
            print(f"⚠️  Invalid value '{raw}', using default {default}")
            return default
        if not value > 0:
            print(f"⚠️  Value must be positive, using default {default}")
            return default
        return value

    print("🔧 Custom Configuration Builder")
    print("=" * 40)

    # Model type selection
    print("\n1. Model Type:")
    print("   a) Original CVAE")
    print("   b) Advanced CVAE (hierarchical latents)")

    # Strip so that an answer like "b " still selects the advanced model.
    choice = input("Choose model type (a/b): ").strip().lower()
    model_type = 'advanced' if choice == 'b' else 'original'

    # Training parameters
    print("\n2. Training Parameters:")
    epochs = _ask_number("Number of epochs (default 50): ", 50, int)
    batch_size = _ask_number("Batch size (default 16): ", 16, int)
    learning_rate = _ask_number("Learning rate (default 0.001): ", 0.001, float)

    # Model architecture
    print("\n3. Model Architecture:")
    latent_dim = _ask_number("Latent dimension (default 128): ", 128, int)
    hidden_dim = _ask_number("Hidden dimension (default 256): ", 256, int)

    # Create custom config
    custom_config = {
        'data': {
            'sequence_length': 200,
            'channels': 1,
            'train_split': 0.7,
            'val_split': 0.15,
            'test_split': 0.15,
            'random_state': 42,
            'augment_train': True,
            'augment_prob': 0.5,
            'noise_std': 0.01,
            'time_shift_max': 50
        },
        'model': {
            'type': model_type,
            'static_dim': 4,
            'input_dim': 1,
            'latent_dim': latent_dim,
            'condition_dim': 128,
            'hidden_dim': hidden_dim,
            'num_encoder_layers': 4,
            'num_decoder_layers': 4,
            'num_heads': 8,
            'dropout': 0.1,
            'sequence_length': 200,
            'beta': 1.0
        },
        'training': {
            'epochs': epochs,
            'batch_size': batch_size,
            'initial_beta': 0.0,
            'final_beta': 1.0,
            # Schedule lengths scale with the number of epochs
            'beta_annealing_epochs': epochs // 2,
            'early_stopping_patience': max(10, epochs // 5),
            'checkpoint_interval': max(5, epochs // 10),
            'sample_interval': max(10, epochs // 5),
            'log_interval': 50,
            'output_dir': 'outputs_custom',
            'optimizer': {
                'type': 'adamw',
                'lr': learning_rate,
                'weight_decay': 0.0001,
                'betas': [0.9, 0.999],
                'eps': 1e-8
            },
            'scheduler': {
                'enabled': True,
                'type': 'cosine',
                'min_lr': 1e-6
            },
            'gradient_clipping': {
                'enabled': True,
                'max_norm': 1.0
            }
        },
        'evaluation': {
            'num_reconstruction_samples': 500,
            'num_generation_samples': 500,
            'samples_per_condition': 3,
            'num_interpolations': 5,
            'num_latent_samples': 500
        }
    }

    print("\n✅ Custom configuration created!")
    return custom_config

def compare_models():
    """Train and evaluate the 'original' and 'advanced' models back to back.

    Each model starts from `create_colab_config()` with its own model type
    and output directory. After a successful training run the checkpoint
    found by `find_best_checkpoint` is evaluated into
    `results_<model name, lower-cased, spaces replaced by underscores>`.

    Side effect: the module-level `config` is rebound to the configuration
    of the model currently being trained, so it holds the last model's
    configuration when the function returns.

    Returns:
        dict: Model name -> {'training_success': bool,
        'evaluation_success': bool, 'checkpoint_path': str | None}.
    """
    global config

    def _make_config(model_type, output_dir):
        """Return the Colab config adjusted for one model variant."""
        variant = create_colab_config()
        variant['model']['type'] = model_type
        variant['training']['output_dir'] = output_dir
        return variant

    configs = {
        'Original CVAE': _make_config('original', 'outputs_original'),
        'Advanced CVAE': _make_config('advanced', 'outputs_advanced'),
    }

    print("🔬 Model Comparison Mode")
    print("=" * 30)

    results = {}

    for model_name, model_config in configs.items():
        print(f"\n🚀 Training {model_name}...")

        # Update global config
        config = model_config

        output_dir = model_config['training']['output_dir']
        outcome = {
            'training_success': False,
            'evaluation_success': False,
            'checkpoint_path': None
        }

        # Train model
        if run_training_with_monitoring(model_config, output_dir):
            print(f"✅ {model_name} training completed!")
            outcome['training_success'] = True

            # Evaluate model (only possible when a checkpoint was written)
            checkpoint_path = find_best_checkpoint(output_dir)
            if checkpoint_path:
                results_dir = f"results_{model_name.lower().replace(' ', '_')}"
                outcome['evaluation_success'] = run_evaluation(checkpoint_path, results_dir)
                outcome['checkpoint_path'] = checkpoint_path
        else:
            print(f"❌ {model_name} training failed!")

        results[model_name] = outcome

    # Display comparison results
    print("\n📊 Model Comparison Results:")
    print("=" * 40)

    for model_name, result in results.items():
        print(f"\n🔬 {model_name}:")
        print(f"  Training: {'✅' if result['training_success'] else '❌'}")
        print(f"  Evaluation: {'✅' if result['evaluation_success'] else '❌'}")
        if result['checkpoint_path']:
            print(f"  Checkpoint: {result['checkpoint_path']}")

    return results

# Nothing above runs on its own; list the helper functions this cell defined
print("✅ Advanced configuration tools ready!")
print("📋 Available functions:")
print("  - create_custom_config(): Build custom configuration")
print("  - compare_models(): Compare original vs advanced models")


In [None]:
# Google Drive integration
import shutil
from datetime import datetime

def backup_to_drive():
    """Copy the known output/result directories into a timestamped Drive folder.

    The destination is
    ``/content/drive/MyDrive/abr_cvae_backups/backup_<YYYYmmdd_HHMMSS>``.
    Only directories from the fixed lists below that exist relative to the
    current working directory are copied; a ``backup_metadata.json`` file
    describing the run is written alongside them.

    Returns:
        The backup directory path on success, or ``None`` when not running in
        Colab or when any step raises (the error is printed, not re-raised).
    """
    if not IN_COLAB:
        print("📋 Drive backup is only available in Google Colab")
        return

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_dir = f"/content/drive/MyDrive/abr_cvae_backups/backup_{timestamp}"

    print(f"💾 Creating backup in Google Drive: {backup_dir}")

    # (log icon, candidate directories): training outputs first, then results.
    copy_groups = (
        ("📁", ("outputs_colab", "outputs_original", "outputs_advanced", "outputs_custom")),
        ("📊", ("results_colab", "results_original_cvae", "results_advanced_cvae")),
    )

    try:
        os.makedirs(backup_dir, exist_ok=True)

        for icon, candidates in copy_groups:
            for dir_name in candidates:
                if not os.path.exists(dir_name):
                    continue
                print(f"{icon} Backing up {dir_name}...")
                shutil.copytree(dir_name, os.path.join(backup_dir, dir_name))

        # Record enough context to identify where the backup came from.
        metadata = dict(
            backup_timestamp=timestamp,
            project_path=PROJECT_PATH,
            device_used=str(device),
            torch_version=torch.__version__,
            python_version=sys.version,
        )
        with open(os.path.join(backup_dir, "backup_metadata.json"), 'w') as meta_file:
            json.dump(metadata, meta_file, indent=2)

        print("✅ Backup completed successfully!")
        print(f"📁 Backup location: {backup_dir}")
    except Exception as e:
        # Best-effort: report the failure and signal it with None.
        print(f"❌ Backup failed: {e}")
        return None

    return backup_dir

def restore_from_drive(backup_base="/content/drive/MyDrive/abr_cvae_backups"):
    """Interactively restore one backup from Google Drive into the working directory.

    Lists the newest backups (at most five, newest first), asks for a choice via
    ``input()``, then copies every sub-directory of the chosen backup into the
    current working directory. A local directory with the same name is removed
    first, so its previous contents are replaced.

    Args:
        backup_base: Directory that holds the ``backup_*`` folders. The default
            is the Drive folder that ``backup_to_drive()`` writes into.

    Returns:
        None. Progress and failures are reported with ``print``; exceptions
        raised while reading the choice or copying are caught and printed.
    """
    if not IN_COLAB:
        print("📋 Drive restore is only available in Google Colab")
        return

    if not os.path.exists(backup_base):
        print("❌ No backups found in Google Drive")
        return

    # Only directories are restorable; a stray file named "backup_*" is ignored.
    # Names embed a YYYYmmdd_HHMMSS stamp, so reverse lexical order is newest first.
    backups = sorted(
        (
            entry for entry in os.listdir(backup_base)
            if entry.startswith("backup_")
            and os.path.isdir(os.path.join(backup_base, entry))
        ),
        reverse=True,
    )

    if not backups:
        print("❌ No backups found")
        return

    max_listed = 5
    listed = backups[:max_listed]

    print("📁 Available backups:")
    for number, backup in enumerate(listed, start=1):
        print(f"  {number}. {backup}")

    try:
        # The prompt shows the real range, and validation is against the listed
        # entries only, so a backup that was never displayed cannot be picked.
        choice = int(input(f"Choose backup to restore (1-{len(listed)}): ")) - 1
        if 0 <= choice < len(listed):
            backup_path = os.path.join(backup_base, listed[choice])

            print(f"🔄 Restoring from: {backup_path}")

            # Restore directories (backup_metadata.json and other files are skipped)
            for item in os.listdir(backup_path):
                item_path = os.path.join(backup_path, item)
                if os.path.isdir(item_path) and not item.endswith('.json'):
                    print(f"📁 Restoring {item}...")
                    # copytree requires the destination not to exist.
                    if os.path.exists(item):
                        shutil.rmtree(item)
                    shutil.copytree(item_path, item)

            print("✅ Restore completed successfully!")

        else:
            print("❌ Invalid choice")

    except ValueError:
        # int() could not parse what the user typed.
        print("❌ Invalid input")
    except Exception as e:
        print(f"❌ Restore failed: {e}")

def sync_to_drive(source_dir, drive_subdir="abr_cvae_sync", drive_root="/content/drive/MyDrive"):
    """Mirror one local directory into a folder on Google Drive.

    The copy lands at ``<drive_root>/<drive_subdir>/<name of source_dir>``.
    An existing copy at that location is deleted first, so the result matches
    ``source_dir`` exactly.

    Args:
        source_dir: Local directory to copy. Trailing separators and relative
            forms such as ``"."`` are resolved to the directory's real name.
        drive_subdir: Folder under ``drive_root`` that collects synced directories.
        drive_root: Root of the mounted Drive; the default matches the mount
            point used at the top of the notebook.

    Returns:
        The destination path on success, otherwise ``None`` (not in Colab,
        source missing or unnamed, or the copy raised; errors are printed).
    """
    if not IN_COLAB:
        print("📋 Drive sync is only available in Google Colab")
        return

    if not os.path.isdir(source_dir):
        print(f"❌ Source directory not found: {source_dir}")
        return

    # os.path.basename("outputs/") is "", which would make sync_path equal
    # drive_path and let the rmtree below wipe every previously synced
    # directory. Resolve the real directory name before building the target.
    dir_name = os.path.basename(os.path.abspath(source_dir))
    if not dir_name:
        print(f"❌ Cannot determine a directory name for: {source_dir}")
        return

    drive_path = f"{drive_root}/{drive_subdir}"
    sync_path = os.path.join(drive_path, dir_name)

    try:
        os.makedirs(drive_path, exist_ok=True)

        print(f"🔄 Syncing {source_dir} to {sync_path}...")

        # copytree needs a fresh destination; removing the old copy also drops
        # files that no longer exist in source_dir.
        if os.path.exists(sync_path):
            shutil.rmtree(sync_path)

        shutil.copytree(source_dir, sync_path)

        print(f"✅ Sync completed: {sync_path}")
        return sync_path

    except Exception as e:
        print(f"❌ Sync failed: {e}")
        return None

# Auto-backup function
def setup_auto_backup(interval_seconds=3600):
    """Start a background daemon thread that runs ``backup_to_drive()`` periodically.

    The call is idempotent: the worker thread carries a fixed name, and if a
    live thread with that name already exists (for example because this cell
    was re-run) no second worker is started. A worker that is already running
    keeps the interval it was started with.

    Args:
        interval_seconds: Pause between backups, in seconds. Defaults to one hour.

    Returns:
        None.

    Raises:
        ValueError: If ``interval_seconds`` is not positive.
    """
    if not IN_COLAB:
        print("📋 Auto-backup is only available in Google Colab")
        return

    if interval_seconds <= 0:
        # Fail here rather than inside the thread, where the error would only
        # surface as a background traceback.
        raise ValueError("interval_seconds must be positive")

    worker_name = "abr-cvae-auto-backup"
    # threading.enumerate() lists only live threads, so a match means a worker
    # is still running and another one would duplicate every backup.
    if any(thread.name == worker_name for thread in threading.enumerate()):
        print("⏰ Auto-backup is already running")
        return

    def backup_worker():
        while True:
            time.sleep(interval_seconds)
            print("⏰ Performing automatic backup...")
            backup_to_drive()

    # daemon=True so the worker never keeps the interpreter alive on shutdown.
    backup_thread = threading.Thread(target=backup_worker, name=worker_name, daemon=True)
    backup_thread.start()

    if interval_seconds == 3600:
        print("⏰ Auto-backup enabled (every hour)")
    else:
        print(f"⏰ Auto-backup enabled (every {interval_seconds} seconds)")

# Announce the Drive helpers defined in this cell (one line of output per string).
print(
    "✅ Google Drive integration ready!",
    "📋 Available functions:",
    "  - backup_to_drive(): Backup all results to Google Drive",
    "  - restore_from_drive(): Restore from previous backup",
    "  - sync_to_drive(dir): Sync specific directory to Drive",
    "  - setup_auto_backup(): Enable automatic hourly backups",
    sep="\n",
)
