# Humanoid Trajectory Analysis
_____________________________________________________

## Arhaan Girdhar - 220962050
## Anbar Althaf - 220962051


In [22]:
import sys
import subprocess

def install_if_missing(package, import_name=None):
    """Make sure *package* is importable, pip-installing it into this kernel if not.

    Args:
        package: Distribution name handed to ``pip install``.
        import_name: Module to test-import when it differs from the distribution
            name, e.g. ``install_if_missing('stable-baselines3', 'stable_baselines3')``.
            Defaults to *package*.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` exits non-zero.
    """
    import importlib

    try:
        __import__(import_name or package)
    except ImportError:
        # sys.executable targets the kernel's own interpreter/environment.
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        # The import system caches directory listings; refresh them so the package
        # that was just installed can be imported in this same kernel session.
        importlib.invalidate_caches()

# Only the widget/plotting extras are auto-installed here; gymnasium,
# stable-baselines3 and mujoco (imported in the next cells) are not checked.
for package in ('ipywidgets', 'matplotlib'):
    install_if_missing(package)

print("Packages ready")


Packages ready


In [23]:
import gymnasium as gym
from stable_baselines3 import SAC, TD3, A2C
import ipywidgets as widgets
from IPython.display import display, clear_output
import time
import os
import subprocess
import warnings
warnings.filterwarnings('ignore')

print("Libraries imported")


Libraries imported


In [24]:
import mujoco
from gymnasium.envs.mujoco.mujoco_rendering import WindowViewer

# Gymnasium's WindowViewer overlay reads ``data.solver_iter``; when the installed
# MuJoCo only exposes ``data.solver_niter``, alias the old name before the
# original overlay code runs.
#
# The unpatched method is stashed on the class exactly once. Without this guard,
# re-running the cell rebinds ``original_create_overlay`` to the *patched*
# function, whose body then calls itself through that global -> RecursionError.
if not hasattr(WindowViewer, '_unpatched_create_overlay'):
    WindowViewer._unpatched_create_overlay = WindowViewer._create_overlay
original_create_overlay = WindowViewer._unpatched_create_overlay

def patched_create_overlay(self):
    """Alias ``data.solver_iter`` to ``data.solver_niter`` if needed, then build the overlay."""
    # NOTE(review): assumes MjData accepts new Python attributes; if this raises
    # AttributeError the workaround is ineffective — confirm with the installed mujoco.
    if hasattr(self.data, 'solver_niter') and not hasattr(self.data, 'solver_iter'):
        self.data.solver_iter = self.data.solver_niter
    return original_create_overlay(self)

WindowViewer._create_overlay = patched_create_overlay
print("Compatibility fix applied")


Compatibility fix applied


# Viewer

In [25]:
class MujocoViewer:
    """Widget panel for replaying saved Stable-Baselines3 models in a MuJoCo environment.

    Models are discovered as ``<MODELS_DIR>/*.zip`` and grouped by algorithm from
    their file name (``A2C``/``SAC``/``TD3``, case-insensitive, first match in that
    order wins). A run executes in a child Python process; its output is streamed
    into the panel from a background thread so the kernel stays free to handle the
    Stop button while an episode is playing.

    Attributes:
        process: The most recently started ``subprocess.Popen`` (or ``None``).
        main_layout: Top-level ``VBox`` to pass to ``display``.
    """

    MODELS_DIR = 'models'
    VIDEOS_DIR = 'videos'
    ALGORITHMS = ('A2C', 'SAC', 'TD3')

    # state -> (label, gradient start, gradient end, shadow "r, g, b")
    _STATUS_STYLES = {
        'running': ('🔄 Executing Model...', '#f59e0b', '#d97706', '245, 158, 11'),
        'success': ('✅ Execution Complete', '#22c55e', '#16a34a', '34, 197, 94'),
        'failed': ('❌ Execution Failed', '#ef4444', '#dc2626', '239, 68, 68'),
        'stopped': ('⏹️ Execution Stopped', '#6b7280', '#4b5563', '107, 114, 128'),
    }

    def __init__(self):
        self.process = None
        # Set by stop_execution so the streaming thread does not replace the
        # "stopped" banner with "failed" when the terminated child exits non-zero.
        self._stop_requested = False
        self.create_widgets()
        self.create_layout()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _extract_number(model_name):
        """Return the last run of digits in *model_name* as an int (0 if none).

        The ``A2C``/``TD3`` tags are removed first so their own digit is not taken
        for a step count (otherwise ``A2C_best`` would sort above ``A2C_1``).
        """
        import re
        stripped = re.sub(r'a2c|td3', '', model_name, flags=re.IGNORECASE)
        numbers = re.findall(r'\d+', stripped)
        return int(numbers[-1]) if numbers else 0

    @staticmethod
    def _status_html(message, gradient_start, gradient_end, shadow_rgb):
        """Return the pill-shaped banner HTML shown in ``status_label``."""
        return f'''
            <div style="
                padding: 15px 25px;
                background: linear-gradient(135deg, {gradient_start} 0%, {gradient_end} 100%);
                color: white;
                border-radius: 25px;
                text-align: center;
                font-weight: 500;
                box-shadow: 0 4px 15px rgba({shadow_rgb}, 0.2);
                margin: 15px 0;
            ">
                <span style="font-size: 16px;">{message}</span>
            </div>
        '''

    def _set_status(self, state, message=None):
        """Show the banner for *state* (a ``_STATUS_STYLES`` key), optionally with custom text."""
        label, start, end, shadow = self._STATUS_STYLES[state]
        self.status_label.value = self._status_html(message or label, start, end, shadow)

    def _model_path(self, model_name):
        """Path of the saved SB3 archive for *model_name*."""
        return f"{self.MODELS_DIR}/{model_name}.zip"

    # ------------------------------------------------------------ model listing

    def get_models_by_algorithm(self):
        """Group ``*.zip`` files in ``MODELS_DIR`` by algorithm, highest number first.

        Returns:
            dict mapping each name in ``ALGORITHMS`` to a list of model names
            (file names without the ``.zip`` extension).
        """
        algorithm_models = {algo: [] for algo in self.ALGORITHMS}

        if os.path.isdir(self.MODELS_DIR):
            for file_name in os.listdir(self.MODELS_DIR):
                stem, ext = os.path.splitext(file_name)
                if ext != '.zip':
                    continue
                upper_name = stem.upper()
                for algo in self.ALGORITHMS:
                    if algo in upper_name:
                        algorithm_models[algo].append(stem)
                        break

        for models in algorithm_models.values():
            models.sort(key=self._extract_number, reverse=True)
        return algorithm_models

    def update_models(self, change):
        """Repopulate the model dropdown for the algorithm in ``change['new']``."""
        algorithm = change['new']
        models = self.get_models_by_algorithm().get(algorithm, [])

        if models:
            self.model_dropdown.options = models
            self.model_dropdown.value = models[0]
        else:
            self.model_dropdown.options = [f'No {algorithm} models found']

    # ------------------------------------------------------------------ widgets

    def create_widgets(self):
        """Create every control, the status banner and the output area, and wire callbacks."""
        # Modern title with gradient background
        self.title = widgets.HTML(value='''
            <div style="
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                padding: 25px;
                border-radius: 15px;
                text-align: center;
                margin-bottom: 20px;
                box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3);
            ">
                <h2 style="margin: 0; font-size: 28px; font-weight: 300; letter-spacing: 1px;">
                    🚀 MuJoCo Model Viewer
                </h2>
                <p style="margin: 8px 0 0 0; opacity: 0.9; font-size: 14px;">
                    Advanced RL Model Execution & Analysis Platform
                </p>
            </div>
        ''')

        # Widths live in the layout; the description style only knows description_width.
        dropdown_style = {'description_width': '120px'}
        dropdown_layout = {'width': '280px', 'margin': '5px'}

        self.algo_dropdown = widgets.Dropdown(
            options=list(self.ALGORITHMS), value='A2C',
            description='🤖 Algorithm:', style=dropdown_style, layout=dropdown_layout
        )

        self.model_dropdown = widgets.Dropdown(
            options=['No models found'],
            description='📊 Model:', style=dropdown_style, layout=dropdown_layout
        )

        self.env_dropdown = widgets.Dropdown(
            options=['Humanoid-v4', 'Humanoid-v5'], value='Humanoid-v4',
            description='🎮 Environment:', style=dropdown_style, layout=dropdown_layout
        )

        slider_style = {'description_width': '120px'}
        slider_layout = {'width': '300px', 'margin': '5px'}

        self.episodes_slider = widgets.IntSlider(
            value=1, min=1, max=10,
            description='🎬 Episodes:', style=slider_style, layout=slider_layout
        )

        self.steps_slider = widgets.IntSlider(
            value=1000, min=100, max=5000, step=100,
            description='⏱️ Max Steps:', style=slider_style, layout=slider_layout
        )

        self.record_video = widgets.Checkbox(
            value=False, description='📹 Record Video',
            style={'description_width': '120px'}, layout={'margin': '10px'}
        )

        button_layout = {'width': '140px', 'height': '45px', 'margin': '8px'}

        self.execute_btn = widgets.Button(
            description='▶️ Run Viewer', button_style='success', layout=button_layout,
            tooltip='Execute the selected model with viewer'
        )

        self.stop_btn = widgets.Button(
            description='⏹️ Stop', button_style='danger', layout=button_layout,
            tooltip='Stop current execution'
        )

        self.summary_btn = widgets.Button(
            description='📊 Summary', button_style='info', layout=button_layout,
            tooltip='View model statistics and summary'
        )

        self.visualize_btn = widgets.Button(
            description='📈 Visualize', button_style='warning', layout=button_layout,
            tooltip='Show performance visualizations'
        )

        self.refresh_btn = widgets.Button(
            description='🔄 Refresh', button_style='primary', layout=button_layout,
            tooltip='Refresh model listings'
        )

        # Initial banner carries a pulse animation; later states use _status_html.
        self.status_label = widgets.HTML(value='''
            <style>
                @keyframes pulse {
                    0% { box-shadow: 0 4px 15px rgba(102, 126, 234, 0.2); }
                    50% { box-shadow: 0 6px 25px rgba(102, 126, 234, 0.4); }
                    100% { box-shadow: 0 4px 15px rgba(102, 126, 234, 0.2); }
                }
                .status-container {
                    animation: pulse 2s infinite;
                }
            </style>
            <div class="status-container" style="
                padding: 15px 25px;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                border-radius: 25px;
                text-align: center;
                font-weight: 500;
                box-shadow: 0 4px 15px rgba(102, 126, 234, 0.2);
                margin: 15px 0;
                transition: all 0.3s ease;
            ">
                <span style="font-size: 16px;">🟢 Ready to Execute</span>
            </div>
        ''')

        # NOTE(review): ipywidgets Layout may not define border_radius/background_color;
        # those keys are likely ignored — confirm against the installed version.
        self.output = widgets.Output(layout={
            'height': '400px',
            'overflow': 'scroll',
            'border': '2px solid #e0e7ff',
            'border_radius': '15px',
            'padding': '20px',
            'margin': '10px 0',
            'background_color': '#fafbff'
        })

        self.algo_dropdown.observe(self.update_models, names='value')
        self.execute_btn.on_click(self.execute_command)
        self.stop_btn.on_click(self.stop_execution)
        self.summary_btn.on_click(self.show_summary)
        self.visualize_btn.on_click(self.show_visualizations)
        self.refresh_btn.on_click(self.refresh_models)

        self.update_models({'new': self.algo_dropdown.value})

    def create_layout(self):
        """Arrange the widgets into configuration / controls / analysis cards in ``main_layout``."""
        config_header = widgets.HTML(value='''
            <div style="
                background: #f8fafc;
                padding: 15px 20px;
                border-left: 4px solid #667eea;
                margin: 20px 0 10px 0;
                border-radius: 8px;
            ">
                <h3 style="margin: 0; color: #1e293b; font-size: 16px; font-weight: 600;">
                    ⚙️ Model Configuration
                </h3>
            </div>
        ''')

        params_box = widgets.VBox([
            config_header,
            widgets.HBox([self.algo_dropdown, self.model_dropdown, self.env_dropdown],
                        layout={'justify_content': 'space-around', 'margin': '10px 0'}),
            widgets.HBox([self.episodes_slider, self.steps_slider],
                        layout={'justify_content': 'space-around', 'margin': '10px 0'}),
            widgets.HBox([self.record_video],
                        layout={'justify_content': 'center', 'margin': '15px 0'})
        ], layout={'margin': '20px 0'})

        controls_header = widgets.HTML(value='''
            <div style="
                background: #f0fdf4;
                padding: 15px 20px;
                border-left: 4px solid #22c55e;
                margin: 20px 0 10px 0;
                border-radius: 8px;
            ">
                <h3 style="margin: 0; color: #1e293b; font-size: 16px; font-weight: 600;">
                    🎮 Execution Controls
                </h3>
            </div>
        ''')

        controls_box = widgets.VBox([
            controls_header,
            widgets.HBox([self.execute_btn, self.stop_btn],
                        layout={'justify_content': 'center', 'margin': '15px 0'})
        ])

        analysis_header = widgets.HTML(value='''
            <div style="
                background: #fffbeb;
                padding: 15px 20px;
                border-left: 4px solid #f59e0b;
                margin: 20px 0 10px 0;
                border-radius: 8px;
            ">
                <h3 style="margin: 0; color: #1e293b; font-size: 16px; font-weight: 600;">
                    📊 Analysis & Insights
                </h3>
            </div>
        ''')

        analysis_box = widgets.VBox([
            analysis_header,
            widgets.HBox([self.summary_btn, self.visualize_btn, self.refresh_btn],
                        layout={'justify_content': 'center', 'margin': '15px 0'})
        ])

        self.main_layout = widgets.VBox([
            self.title,
            params_box,
            controls_box,
            analysis_box,
            self.status_label,
            self.output
        ], layout={'padding': '20px', 'background_color': '#ffffff'})

    # ---------------------------------------------------------------- execution

    def execute_command(self, btn):
        """``Run Viewer`` handler: start the selected model in a child Python process.

        The child script is generated from the current widget values. Its output is
        streamed into ``self.output`` by a daemon thread (``_stream_output``), so this
        callback returns immediately and the Stop button remains responsive.
        """
        import html
        import sys
        import threading

        with self.output:
            clear_output()

            if self.process is not None and self.process.poll() is None:
                print("A model is already running - press Stop first")
                return

            model_name = self.model_dropdown.value
            algorithm = self.algo_dropdown.value
            environment = self.env_dropdown.value
            episodes = self.episodes_slider.value
            max_steps = self.steps_slider.value
            record_video = self.record_video.value

            # The dropdown shows a "No ... models found" placeholder when nothing was
            # discovered. Checking for the file (instead of `'No' in model_name`) also
            # accepts real models whose names merely contain "No".
            model_path = self._model_path(model_name) if model_name else None
            if model_path is None or not os.path.isfile(model_path):
                print(f"No {algorithm} models available")
                return

            print(f"Model: {model_name} | Algorithm: {algorithm} | Environment: {environment}")
            print(f"Episodes: {episodes} | Max Steps: {max_steps} | Record: {record_video}")
            print("="*50)

            self._set_status('running')

            # String values are embedded with !r so names containing quotes or
            # backslashes still produce valid Python literals in the child script.
            python_script = f"""
import gymnasium as gym
from stable_baselines3 import SAC, TD3, A2C

env_name = {environment!r}
algo = {algorithm!r}
model_path = {model_path!r}
episodes = {episodes}
max_steps = {max_steps}
record_video = {record_video}

print(f"Loading {{algo}} model...")

if record_video:
    from gymnasium.wrappers import RecordVideo
    env = gym.make(env_name, render_mode='rgb_array')
    env = RecordVideo(env, video_folder={self.VIDEOS_DIR!r}, episode_trigger=lambda x: True)
else:
    env = gym.make(env_name, render_mode='human')

if algo == 'SAC':
    model = SAC.load(model_path, env=env)
elif algo == 'TD3':
    model = TD3.load(model_path, env=env)
elif algo == 'A2C':
    model = A2C.load(model_path, env=env)

for episode in range(episodes):
    obs, _ = env.reset()
    total_reward = 0
    
    for step in range(max_steps):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, _ = env.step(action)
        total_reward += reward
        
        if done or truncated:
            break
    
    print(f"Episode {{episode + 1}}: {{total_reward:.2f}} reward, {{step + 1}} steps")

env.close()
print("Execution completed")
"""

            try:
                # sys.executable: use the kernel's interpreter, not whatever `python` is
                # first on PATH. -u: unbuffered child stdout; through a pipe it would
                # otherwise be block-buffered and only appear when the child exits.
                process = subprocess.Popen(
                    [sys.executable, '-u', '-c', python_script],
                    stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                    text=True, bufsize=1
                )
            except Exception as e:
                print(f"Error: {str(e)}")
                self._set_status('failed', f'❌ Error: {html.escape(str(e))}')
                return

            self._stop_requested = False
            self.process = process
            threading.Thread(target=self._stream_output, args=(process,), daemon=True).start()

    def _stream_output(self, process):
        """Background worker: copy *process* output into the panel, then report its exit status.

        ``Output.append_stdout`` is used instead of ``with self.output:`` because the
        context-manager form is documented by ipywidgets as unreliable from threads.
        """
        try:
            for line in iter(process.stdout.readline, ''):
                self.output.append_stdout(line)
        finally:
            process.stdout.close()
        returncode = process.wait()

        if self._stop_requested and self.process is process:
            return  # stop_execution already shows the "stopped" banner
        if returncode == 0:
            self._set_status('success')
        else:
            self._set_status('failed')
            self.output.append_stdout(f"Process exited with code {returncode}\n")

    def stop_execution(self, btn):
        """``Stop`` handler: terminate the running child, escalating to kill after 5 s."""
        with self.output:
            if self.process is None or self.process.poll() is not None:
                print("Nothing is running")
                return
            self._stop_requested = True
            self.process.terminate()
            try:
                self.process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                self.process.kill()
            self._set_status('stopped')
            print("Execution stopped")

    # ----------------------------------------------------------------- analysis

    def refresh_models(self, btn):
        """Refresh the models list"""
        with self.output:
            clear_output()
            print("Refreshing models...")
            self.update_models({'new': self.algo_dropdown.value})
            print("Models refreshed")

    def show_summary(self, btn):
        """Print every discovered model with its file size, plus the recorded-video count."""
        with self.output:
            clear_output()
            try:
                print("MODEL SUMMARY")
                print("="*50)

                models_by_algo = self.get_models_by_algorithm()
                total_models = 0

                for algo, models in models_by_algo.items():
                    print(f"\n{algo} Models ({len(models)}):")
                    if not models:
                        print(f"  No {algo} models found")
                        continue
                    # Same (numeric, descending) order as the model dropdown.
                    for model in models:
                        model_path = self._model_path(model)
                        if os.path.exists(model_path):
                            size = os.path.getsize(model_path) / (1024 * 1024)
                            print(f"  • {model} ({size:.1f} MB)")
                            total_models += 1

                print(f"\nTotal Models: {total_models}")
                print(f"Models Directory: {os.path.abspath(self.MODELS_DIR)}")

                if os.path.isdir(self.VIDEOS_DIR):
                    video_count = len([f for f in os.listdir(self.VIDEOS_DIR) if f.endswith('.mp4')])
                    print(f"Available Videos: {video_count}")

            except Exception as e:
                print(f"Error generating summary: {str(e)}")

    def show_visualizations(self, btn):
        """Plot model counts and file sizes, plus two illustrative (synthetic) panels.

        NOTE: the bottom two panels are placeholders, not measurements of the saved
        models. The training curves are synthetic (seeded for reproducibility), and
        the comparison scores are hard-coded.
        """
        with self.output:
            clear_output()
            try:
                import matplotlib.pyplot as plt
                import numpy as np

                print("PERFORMANCE VISUALIZATIONS")
                print("="*50)

                models_by_algo = self.get_models_by_algorithm()
                colors = {'A2C': '#1f77b4', 'SAC': '#ff7f0e', 'TD3': '#2ca02c'}

                fig, axes = plt.subplots(2, 2, figsize=(12, 8))

                # Model count by algorithm
                algos = list(models_by_algo.keys())
                counts = [len(models) for models in models_by_algo.values()]

                axes[0,0].bar(algos, counts, color=[colors.get(a, 'gray') for a in algos])
                axes[0,0].set_title('Models by Algorithm')
                axes[0,0].set_ylabel('Number of Models')

                # Model file sizes (top 3 per algorithm, already sorted numerically)
                all_models = []
                all_sizes = []
                all_algos = []

                for algo, models in models_by_algo.items():
                    for model in models[:3]:
                        model_path = self._model_path(model)
                        if os.path.exists(model_path):
                            size = os.path.getsize(model_path) / (1024 * 1024)
                            all_models.append(model[:15] + '...' if len(model) > 15 else model)
                            all_sizes.append(size)
                            all_algos.append(algo)

                if all_models:
                    bar_colors = [colors.get(algo, 'gray') for algo in all_algos]

                    axes[0,1].barh(range(len(all_models)), all_sizes, color=bar_colors)
                    axes[0,1].set_yticks(range(len(all_models)))
                    axes[0,1].set_yticklabels(all_models)
                    axes[0,1].set_title('Model File Sizes (MB)')
                    axes[0,1].set_xlabel('Size (MB)')

                # Synthetic training curves (seeded so re-runs are reproducible)
                rng = np.random.default_rng(0)
                steps = np.arange(0, 1000000, 25000)
                sac_rewards = 1000 + 500 * np.log(steps + 1) + rng.normal(0, 50, len(steps))
                a2c_rewards = 800 + 300 * np.log(steps + 1) + rng.normal(0, 40, len(steps))
                td3_rewards = 1200 + 400 * np.log(steps + 1) + rng.normal(0, 60, len(steps))

                axes[1,0].plot(steps, sac_rewards, label='SAC', color=colors['SAC'])
                axes[1,0].plot(steps, a2c_rewards, label='A2C', color=colors['A2C'])
                axes[1,0].plot(steps, td3_rewards, label='TD3', color=colors['TD3'])
                axes[1,0].set_title('Training Progress (Simulated)')
                axes[1,0].set_xlabel('Training Steps')
                axes[1,0].set_ylabel('Average Reward')
                axes[1,0].legend()
                axes[1,0].grid(True, alpha=0.3)

                # Hard-coded qualitative scores (not derived from the models)
                metrics = ['Sample Efficiency', 'Stability', 'Performance', 'Speed']
                sac_scores = [8, 7, 9, 6]
                a2c_scores = [5, 8, 6, 9]
                td3_scores = [7, 6, 8, 7]

                x = np.arange(len(metrics))
                width = 0.25

                axes[1,1].bar(x - width, sac_scores, width, label='SAC', color=colors['SAC'])
                axes[1,1].bar(x, a2c_scores, width, label='A2C', color=colors['A2C'])
                axes[1,1].bar(x + width, td3_scores, width, label='TD3', color=colors['TD3'])

                axes[1,1].set_title('Algorithm Comparison (Illustrative)')
                axes[1,1].set_ylabel('Score (1-10)')
                axes[1,1].set_xticks(x)
                axes[1,1].set_xticklabels(metrics, rotation=45)
                axes[1,1].legend()
                axes[1,1].set_ylim(0, 10)

                plt.tight_layout()
                plt.show()

                print("\nVisualization complete!")

            except Exception as e:
                print(f"Error generating visualizations: {str(e)}")

# Build the viewer panel and render it in this cell's output area
# (a plain-text export shows only the VBox repr, as in the output below).
viewer = MujocoViewer()
display(viewer.main_layout)


VBox(children=(HTML(value='\n            <div style="\n                background: linear-gradient(135deg, #66…

![Anthropometric Table](anthropometric_table.png)


# Anthropometric Analysis

In [26]:
from anthropometric_analysis import AnthropometricAnalyzer, extract_trajectory_data_from_env
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import os

class AnthropometricAnalysisUI:
    def __init__(self):
        self.create_widgets()
        self.create_layout()
        
    def get_models_by_algorithm(self):
        models_dir = 'models'
        algorithm_models = {'A2C': [], 'SAC': [], 'TD3': []}
        
        if os.path.exists(models_dir):
            for file in os.listdir(models_dir):
                if file.endswith('.zip'):
                    model_name = file.replace('.zip', '')
                    if 'A2C' in model_name.upper():
                        algorithm_models['A2C'].append(model_name)
                    elif 'SAC' in model_name.upper():
                        algorithm_models['SAC'].append(model_name)
                    elif 'TD3' in model_name.upper():
                        algorithm_models['TD3'].append(model_name)
        
        # Sort each algorithm's models numerically
        def extract_number(model_name):
            try:
                return int(model_name.split('_')[1])
            except:
                return 0
        
        for algo in algorithm_models:
            algorithm_models[algo] = sorted(algorithm_models[algo], key=extract_number, reverse=True)
        
        return algorithm_models
        
    def create_widgets(self):
        """Create the analysis controls, status banner and output area, and wire callbacks."""
        # Modern title with biomechanical theme
        self.title = widgets.HTML(value='''
            <div style="
                background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%);
                color: white;
                padding: 25px;
                border-radius: 15px;
                text-align: center;
                margin-bottom: 20px;
                box-shadow: 0 8px 32px rgba(99, 102, 241, 0.3);
            ">
                <h2 style="margin: 0; font-size: 26px; font-weight: 300; letter-spacing: 1px;">
                    🔬 Anthropometric Analysis
                </h2>
                <p style="margin: 8px 0 0 0; opacity: 0.9; font-size: 14px;">
                    Biomechanical Movement Analysis & Validation
                </p>
            </div>
        ''')

        # Widths live in the layout; the description style only knows description_width.
        dropdown_style = {'description_width': '120px'}
        dropdown_layout = {'width': '280px', 'margin': '5px'}

        self.algo_dropdown = widgets.Dropdown(
            options=['A2C', 'SAC', 'TD3'], value='A2C',
            description='🤖 Algorithm:', style=dropdown_style, layout=dropdown_layout
        )

        self.model_dropdown = widgets.Dropdown(
            options=['No models found'],
            description='📊 Model:', style=dropdown_style, layout=dropdown_layout
        )

        self.env_dropdown = widgets.Dropdown(
            options=['Humanoid-v4', 'Humanoid-v5'], value='Humanoid-v4',
            description='🎮 Environment:', style=dropdown_style, layout=dropdown_layout
        )

        self.steps_slider = widgets.IntSlider(
            value=1000, min=100, max=5000, step=100,
            description='📈 Steps:', style={'description_width': '120px'},
            layout={'width': '350px', 'margin': '5px'}
        )

        self.analysis_type = widgets.SelectMultiple(
            options=['Joint Angles', 'Gait Parameters', 'Reach Kinematics', 'Movement Smoothness'],
            value=['Joint Angles', 'Gait Parameters'],
            description='🔍 Analysis Types:', style={'description_width': '140px'},
            layout={'height': '120px', 'width': '450px', 'margin': '10px'}
        )

        button_layout = {'width': '150px', 'height': '45px', 'margin': '8px'}

        self.analyze_btn = widgets.Button(
            description='🧬 Run Analysis', button_style='success', layout=button_layout,
            tooltip='Execute biomechanical analysis on selected model'
        )
        self.clear_btn = widgets.Button(
            description='🧹 Clear', button_style='warning', layout=button_layout,
            tooltip='Clear analysis results'
        )

        self.status_label = widgets.HTML(value='''
            <div style="
                padding: 15px 25px;
                background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%);
                color: white;
                border-radius: 25px;
                text-align: center;
                font-weight: 500;
                box-shadow: 0 4px 15px rgba(99, 102, 241, 0.2);
                margin: 15px 0;
            ">
                <span style="font-size: 16px;">🟢 Ready for Analysis</span>
            </div>
        ''')

        # NOTE(review): ipywidgets Layout may not define border_radius/background_color;
        # those keys are likely ignored — confirm against the installed version.
        self.output = widgets.Output(layout={
            'height': '400px',
            'overflow': 'scroll',
            'border': '2px solid #e0e7ff',
            'border_radius': '15px',
            'padding': '20px',
            'margin': '10px 0',
            'background_color': '#fafbff'
        })

        self.analyze_btn.on_click(self.run_analysis)
        self.clear_btn.on_click(self.clear_output)
        self.algo_dropdown.observe(self.update_models, names='value')

        self.update_models({'new': self.algo_dropdown.value})
        
    def update_models(self, change):
        algorithm = change['new']
        models_by_algo = self.get_models_by_algorithm()
        
        if models_by_algo[algorithm]:
            self.model_dropdown.options = models_by_algo[algorithm]
            self.model_dropdown.value = models_by_algo[algorithm][0]
        else:
            self.model_dropdown.options = [f'No {algorithm} models found']
    
    def create_layout(self):
        """Arrange the widgets into configuration and controls cards in ``self.main_layout``."""
        config_header = widgets.HTML(value='''
            <div style="
                background: #f3f4f6;
                padding: 15px 20px;
                border-left: 4px solid #6366f1;
                margin: 20px 0 10px 0;
                border-radius: 8px;
            ">
                <h3 style="margin: 0; color: #1e293b; font-size: 16px; font-weight: 600;">
                    ⚙️ Analysis Configuration
                </h3>
            </div>
        ''')

        params_box = widgets.VBox([
            config_header,
            widgets.HBox([self.algo_dropdown, self.model_dropdown, self.env_dropdown],
                        layout={'justify_content': 'space-around', 'margin': '10px 0'}),
            widgets.HBox([self.steps_slider],
                        layout={'justify_content': 'center', 'margin': '10px 0'}),
            widgets.HBox([self.analysis_type],
                        layout={'justify_content': 'center', 'margin': '15px 0'})
        ], layout={'margin': '20px 0'})

        controls_header = widgets.HTML(value='''
            <div style="
                background: #f0fdf4;
                padding: 15px 20px;
                border-left: 4px solid #22c55e;
                margin: 20px 0 10px 0;
                border-radius: 8px;
            ">
                <h3 style="margin: 0; color: #1e293b; font-size: 16px; font-weight: 600;">
                    🧬 Biomechanical Controls
                </h3>
            </div>
        ''')

        controls_box = widgets.VBox([
            controls_header,
            widgets.HBox([self.analyze_btn, self.clear_btn],
                        layout={'justify_content': 'center', 'margin': '15px 0'})
        ])

        self.main_layout = widgets.VBox([
            self.title,
            params_box,
            controls_box,
            self.status_label,
            self.output
        ], layout={'padding': '20px', 'background_color': '#ffffff'})
    
    def run_analysis(self, btn):
        """Run the selected policy and render the anthropometric analysis report.

        Button-click handler (``btn`` is the clicked Button and is unused).
        Reads algorithm, model, environment, step count and analysis types from
        the widgets, rolls the policy out through
        ``extract_trajectory_data_from_env``, runs the selected
        ``AnthropometricAnalyzer`` checks, prints the report and shows the
        figure inside ``self.output``. The status banner always finishes in a
        terminal state (complete or error) instead of staying on "running".
        """
        import html

        def set_status(message, start, end, shadow):
            # Render the rounded gradient banner used for every status state.
            self.status_label.value = f'''
                    <div style="
                        padding: 15px 25px;
                        background: linear-gradient(135deg, {start} 0%, {end} 100%);
                        color: white;
                        border-radius: 25px;
                        text-align: center;
                        font-weight: 500;
                        box-shadow: 0 4px 15px {shadow};
                        margin: 15px 0;
                    ">
                        <span style="font-size: 16px;">{message}</span>
                    </div>
                '''

        def set_error(message):
            set_status(message, '#ef4444', '#dc2626', 'rgba(239, 68, 68, 0.2)')

        with self.output:
            clear_output()
            env = None
            try:
                set_status('🔄 Running Analysis...', '#f59e0b', '#d97706',
                           'rgba(245, 158, 11, 0.2)')

                model_name = self.model_dropdown.value
                algorithm = self.algo_dropdown.value
                environment = self.env_dropdown.value
                steps = self.steps_slider.value
                analysis_types = list(self.analysis_type.value)

                # Placeholder entries start with "No " ("No ... models ...").
                # A bare substring test for 'No' would also reject real model
                # names such as "SAC_Normal_1000".
                # NOTE(review): placeholder text is set outside this view — confirm.
                if not model_name or model_name.startswith('No '):
                    print(f"No {algorithm} models available")
                    set_error(f"❌ No {algorithm} models available")
                    return

                print(f"Model: {model_name} | Algorithm: {algorithm} | Environment: {environment} | Steps: {steps}")

                import gymnasium as gym
                from stable_baselines3 import SAC, TD3, A2C

                loaders = {'SAC': SAC, 'TD3': TD3, 'A2C': A2C}
                if algorithm not in loaders:
                    # Fail with a clear message instead of a NameError on `model`.
                    raise ValueError(f"Unsupported algorithm: {algorithm}")

                model_path = f"models/{model_name}.zip"
                env = gym.make(environment, render_mode=None)
                model = loaders[algorithm].load(model_path)

                trajectory_data = extract_trajectory_data_from_env(env, model, steps)
                analyzer = AnthropometricAnalyzer()

                # Optional per-aspect analyses, selected by label.
                if 'Joint Angles' in analysis_types:
                    analyzer.analyze_joint_angles(trajectory_data.get('joint_angles', {}))
                if 'Gait Parameters' in analysis_types:
                    analyzer.analyze_gait_parameters(trajectory_data)
                if 'Reach Kinematics' in analysis_types:
                    analyzer.analyze_reach_kinematics(trajectory_data)
                if 'Movement Smoothness' in analysis_types:
                    analyzer.analyze_movement_smoothness(trajectory_data)

                results = analyzer.analyze_anthropometric_compliance(trajectory_data)
                analyzer.print_analysis_report(results)
                analyzer.generate_visualization(results)
                plt.show()

                set_status('✅ Analysis Complete', '#22c55e', '#16a34a',
                           'rgba(34, 197, 94, 0.2)')

            except Exception as e:
                print(f"Error: {str(e)}")
                text = str(e)
                short = text if len(text) <= 50 else text[:50] + '...'
                # Exception text is arbitrary; escape it before embedding in HTML.
                set_error(f"❌ Error: {html.escape(short)}")
            finally:
                if env is not None:
                    try:
                        env.close()
                    except Exception as close_err:
                        print(f"Warning: failed to close environment: {close_err}")
    
    def clear_output(self, btn):
        """Clear the results pane and reset the status banner to "ready".

        Button-click handler; ``btn`` is the clicked Button and is unused.
        """
        # Inside this body, ``clear_output`` resolves to the IPython.display
        # function imported at notebook level (method names are not in scope).
        with self.output:
            clear_output()
        # Idle banner (indigo gradient) signalling that a new run can start.
        self.status_label.value = '''
            <div style="
                padding: 15px 25px;
                background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%);
                color: white;
                border-radius: 25px;
                text-align: center;
                font-weight: 500;
                box-shadow: 0 4px 15px rgba(99, 102, 241, 0.2);
                margin: 15px 0;
            ">
                <span style="font-size: 16px;">🟢 Ready for Analysis</span>
            </div>
        '''

# Build the anthropometric analysis dashboard and render its root container.
anthropometric_ui = AnthropometricAnalysisUI()
display(anthropometric_ui.main_layout)


VBox(children=(HTML(value='\n            <div style="\n                background: linear-gradient(135deg, #63…

# CSV Data Input & Enhanced Analysis

In [27]:
from csv_input_handler import CSVInputHandler, load_csv_data, run_anthropometric_analysis_with_csv
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, clear_output
import os

class CSVInputUI:
    """ipywidgets dashboard for loading CSV inputs and running the CSV-enhanced analysis.

    The user picks a CSV type and path, validates/loads it through
    ``CSVInputHandler`` / ``load_csv_data``, and then runs
    ``run_anthropometric_analysis_with_csv`` against a saved model from the
    ``models/`` directory. Console output goes to ``self.output`` and the
    current state is shown in ``self.status_label``. Render it with
    ``display(CSVInputUI().main_layout)``.
    """

    # Directory scanned for saved models (``*.zip``).
    _MODELS_DIR = 'models'
    # Directory holding the example CSV files.
    _SAMPLE_DIR = 'sample_csvs'
    # Supported algorithms, in the order used to classify model file names.
    _ALGORITHMS = ('A2C', 'SAC', 'TD3')
    # Default sample file for each CSV type (also the CSV-type dropdown options).
    _SAMPLE_FILES = {
        'anthropometric_params': 'anthropometric_params_sample.csv',
        'trajectory_data': 'trajectory_data_sample.csv',
        'joint_angles': 'joint_angles_sample.csv',
        'gait_data': 'gait_data_sample.csv',
        'model_config': 'model_config_sample.csv',
    }
    # Status banner themes: (gradient start, gradient end, shadow colour).
    _STATUS_THEMES = {
        'ready': ('#10b981', '#059669', 'rgba(16, 185, 129, 0.2)'),
        'running': ('#f59e0b', '#d97706', 'rgba(245, 158, 11, 0.2)'),
        'success': ('#22c55e', '#16a34a', 'rgba(34, 197, 94, 0.2)'),
        'error': ('#ef4444', '#dc2626', 'rgba(239, 68, 68, 0.2)'),
    }
    # Longest exception text shown in the banner before it is truncated.
    _MAX_ERROR_CHARS = 30
    # Placeholder shown in the data panel while nothing is loaded.
    _EMPTY_DATA_HTML = '''
        <div style="
            background: #f8fafc;
            padding: 20px;
            border-radius: 10px;
            border: 2px solid #e2e8f0;
            margin: 15px 0;
            text-align: center;
        ">
            <h4 style="margin: 0; color: #64748b;">📋 No Data Loaded</h4>
            <p style="margin: 5px 0 0 0; color: #94a3b8; font-size: 14px;">
                Upload a CSV file to see data preview
            </p>
        </div>
    '''

    def __init__(self):
        self.csv_handler = CSVInputHandler()
        # csv_type -> loaded data object (as returned by load_csv_data).
        self.loaded_data = {}
        # Return value of the most recent enhanced analysis run (None until run).
        self.last_results = None
        self.create_widgets()
        self.create_layout()

    def get_models_by_algorithm(self):
        """Group saved ``.zip`` models in ``models/`` by algorithm.

        A file is assigned to the first algorithm in ``_ALGORITHMS`` whose name
        appears (case-insensitively) in the file name. Each list is sorted by the
        number after the first underscore, highest first.

        Returns:
            dict mapping 'A2C' / 'SAC' / 'TD3' to lists of model names without
            the ``.zip`` suffix.
        """
        algorithm_models = {algo: [] for algo in self._ALGORITHMS}

        if os.path.isdir(self._MODELS_DIR):
            for file_name in os.listdir(self._MODELS_DIR):
                if not file_name.endswith('.zip'):
                    continue
                # Strip only the trailing suffix, not every '.zip' occurrence.
                model_name = file_name[:-len('.zip')]
                upper_name = model_name.upper()
                for algo in self._ALGORITHMS:
                    if algo in upper_name:
                        algorithm_models[algo].append(model_name)
                        break

        def extract_number(model_name):
            # Expects '<ALGO>_<number>[...]'; other names sort after numbered ones.
            try:
                return int(model_name.split('_')[1])
            except (IndexError, ValueError):
                return 0

        for algo, names in algorithm_models.items():
            algorithm_models[algo] = sorted(names, key=extract_number, reverse=True)

        return algorithm_models

    def create_widgets(self):
        """Build every widget, wire up the event handlers and fill the model list."""
        self.title = widgets.HTML(value='''
            <div style="
                background: linear-gradient(135deg, #10b981 0%, #059669 100%);
                color: white;
                padding: 25px;
                border-radius: 15px;
                text-align: center;
                margin-bottom: 20px;
                box-shadow: 0 8px 32px rgba(16, 185, 129, 0.3);
            ">
                <h2 style="margin: 0; font-size: 26px; font-weight: 300; letter-spacing: 1px;">
                    📊 CSV Input Handler
                </h2>
                <p style="margin: 8px 0 0 0; opacity: 0.9; font-size: 14px;">
                    Advanced Data Processing & Analysis Platform
                </p>
            </div>
        ''')

        # CSV upload section
        self.csv_type_dropdown = widgets.Dropdown(
            options=list(self._SAMPLE_FILES),
            value='anthropometric_params',
            description='📁 CSV Type:', style={'description_width': '120px'},
            layout={'width': '280px', 'margin': '5px'}
        )

        self.file_path_text = widgets.Text(
            value=self._sample_path('anthropometric_params'),
            description='📂 File Path:', style={'description_width': '120px'},
            layout={'width': '400px', 'margin': '5px'}
        )

        button_layout = {'width': '140px', 'height': '40px', 'margin': '5px'}

        self.validate_btn = widgets.Button(
            description='✅ Validate', button_style='info', layout=button_layout,
            tooltip='Validate CSV file format and content'
        )
        self.load_btn = widgets.Button(
            description='📥 Load Data', button_style='success', layout=button_layout,
            tooltip='Load CSV data into memory'
        )
        self.browse_btn = widgets.Button(
            description='🗂️ Browse', button_style='warning', layout=button_layout,
            tooltip='Browse available sample files'
        )

        # Data summary panel
        self.data_display = widgets.HTML(value=self._EMPTY_DATA_HTML)

        # Analysis configuration section. 'description_width' is the only
        # DescriptionStyle key; the widget width is controlled by the layout.
        dropdown_style = {'description_width': '120px'}
        dropdown_layout = {'width': '280px', 'margin': '5px'}

        self.analysis_algo_dropdown = widgets.Dropdown(
            options=list(self._ALGORITHMS), value='A2C',
            description='🤖 Algorithm:', style=dropdown_style, layout=dropdown_layout
        )
        self.analysis_model_dropdown = widgets.Dropdown(
            options=['No models found'],
            description='📊 Model:', style=dropdown_style, layout=dropdown_layout
        )
        self.analysis_env_dropdown = widgets.Dropdown(
            options=['Humanoid-v4', 'Humanoid-v5'], value='Humanoid-v4',
            description='🎮 Environment:', style=dropdown_style, layout=dropdown_layout
        )
        self.analysis_steps = widgets.IntSlider(
            value=1000, min=100, max=5000, step=100,
            description='📈 Steps:', style={'description_width': '120px'},
            layout={'width': '350px', 'margin': '5px'}
        )

        analysis_button_layout = {'width': '170px', 'height': '45px', 'margin': '8px'}

        self.run_enhanced_btn = widgets.Button(
            description='🚀 Run Enhanced', button_style='primary', layout=analysis_button_layout,
            tooltip='Run enhanced analysis with loaded CSV data'
        )
        self.clear_btn = widgets.Button(
            description='🧹 Clear Data', button_style='danger', layout=analysis_button_layout,
            tooltip='Clear all loaded data'
        )

        self.status_label = widgets.HTML(
            value=self._status_html('🟢 Ready for Data Processing', 'ready'))

        self.output = widgets.Output(layout={
            'height': '300px',
            'overflow': 'scroll',
            'border': '2px solid #d1fae5',
            'border_radius': '15px',
            'padding': '20px',
            'margin': '10px 0',
            'background_color': '#f0fdf4'
        })

        # Event bindings
        self.validate_btn.on_click(self.validate_csv)
        self.load_btn.on_click(self.load_csv)
        self.browse_btn.on_click(self.browse_samples)
        self.run_enhanced_btn.on_click(self.run_enhanced_analysis)
        self.clear_btn.on_click(self.clear_data)
        self.csv_type_dropdown.observe(self.update_sample_path, names='value')
        self.analysis_algo_dropdown.observe(self.update_models, names='value')

        # Populate the model list for the initially selected algorithm.
        self.update_models({'new': self.analysis_algo_dropdown.value})

    def update_models(self, change):
        """Refill the model dropdown for the algorithm in ``change['new']``."""
        algorithm = change['new']
        models_by_algo = self.get_models_by_algorithm()

        if models_by_algo[algorithm]:
            self.analysis_model_dropdown.options = models_by_algo[algorithm]
            self.analysis_model_dropdown.value = models_by_algo[algorithm][0]
        else:
            self.analysis_model_dropdown.options = [f'No {algorithm} models found']

    def update_sample_path(self, change):
        """Point the file-path box at the sample file for the newly chosen CSV type."""
        self.file_path_text.value = self._sample_path(change['new'])

    def create_layout(self):
        """Assemble the widgets into ``self.main_layout``."""
        upload_section = widgets.VBox([
            self._section_header('📤 Data Upload & Validation', '#ecfdf5', '#10b981'),
            widgets.HBox([self.csv_type_dropdown, self.file_path_text],
                         layout={'justify_content': 'space-around', 'margin': '10px 0'}),
            widgets.HBox([self.validate_btn, self.load_btn, self.browse_btn],
                         layout={'justify_content': 'center', 'margin': '15px 0'})
        ], layout={'margin': '20px 0'})

        analysis_section = widgets.VBox([
            self._section_header('🔬 Enhanced Analysis Configuration', '#eff6ff', '#3b82f6'),
            widgets.HBox([self.analysis_algo_dropdown, self.analysis_model_dropdown,
                          self.analysis_env_dropdown],
                         layout={'justify_content': 'space-around', 'margin': '10px 0'}),
            widgets.HBox([self.analysis_steps],
                         layout={'justify_content': 'center', 'margin': '10px 0'})
        ])

        controls_box = widgets.VBox([
            self._section_header('🎮 Analysis Controls', '#fef3c7', '#f59e0b'),
            widgets.HBox([self.run_enhanced_btn, self.clear_btn],
                         layout={'justify_content': 'center', 'margin': '15px 0'})
        ])

        self.main_layout = widgets.VBox([
            self.title,
            upload_section,
            self.data_display,
            analysis_section,
            controls_box,
            self.status_label,
            self.output
        ], layout={'padding': '20px', 'background_color': '#ffffff'})

    def validate_csv(self, btn):
        """Validate the chosen file against the chosen CSV type (button handler)."""
        with self.output:
            clear_output()
            try:
                self._set_status('🔄 Validating...', 'running')
                is_valid = self.csv_handler.validate_csv_format(
                    self.file_path_text.value, self.csv_type_dropdown.value)

                if is_valid:
                    print("✅ CSV file is valid!")
                    self._set_status('✅ Validation Successful', 'success')
                else:
                    print("❌ CSV file validation failed!")
                    self._set_status('❌ Validation Failed', 'error')
            except Exception as e:
                print(f"Error: {str(e)}")
                self._set_error_status(e)

    def load_csv(self, btn):
        """Load the chosen file into ``self.loaded_data[csv_type]`` (button handler)."""
        with self.output:
            clear_output()
            try:
                self._set_status('🔄 Loading Data...', 'running')
                csv_type = self.csv_type_dropdown.value
                data = load_csv_data(self.file_path_text.value, csv_type)

                if data is None:
                    print("❌ Load failed")
                    self._set_status('❌ Load Failed', 'error')
                    return

                self.loaded_data[csv_type] = data
                print(f"✅ Loaded {csv_type} data successfully!")
                print(f"Shape: {data.shape}")
                print("Preview:")
                print(data.head())

                self.update_data_display()
                self._set_status('✅ Data Loaded', 'success')
            except Exception as e:
                print(f"Error: {str(e)}")
                self._set_error_status(e)

    def browse_samples(self, btn):
        """List the ``.csv`` files in the sample directory, sorted by name."""
        with self.output:
            clear_output()
            print("📁 Sample CSV Files Available:")
            if os.path.exists(self._SAMPLE_DIR):
                for file_name in sorted(os.listdir(self._SAMPLE_DIR)):
                    if file_name.endswith('.csv'):
                        print(f"  • {file_name}")
            else:
                print("❌ Sample directory not found")

    def update_data_display(self):
        """Show row/column counts for every loaded dataset (or the empty placeholder)."""
        if not self.loaded_data:
            self.data_display.value = self._EMPTY_DATA_HTML
            return

        data_list = "<br>".join(
            f"• {csv_type}: {data.shape[0]} rows, {data.shape[1]} columns"
            for csv_type, data in self.loaded_data.items()
        )
        self.data_display.value = f'''
            <div style="
                background: #f0fdf4;
                padding: 20px;
                border-radius: 10px;
                border: 2px solid #22c55e;
                margin: 15px 0;
            ">
                <h4 style="margin: 0 0 10px 0; color: #166534;">📊 Loaded Data Summary</h4>
                <div style="color: #15803d; font-size: 14px; line-height: 1.6;">
                    {data_list}
                </div>
            </div>
        '''

    def run_enhanced_analysis(self, btn):
        """Run the CSV-enhanced analysis with the selected model (button handler).

        The result is kept in ``self.last_results``. Pre-condition failures (no
        model, no loaded CSV data) set an error banner instead of leaving the
        banner on "running".
        """
        with self.output:
            clear_output()
            try:
                model_name = self.analysis_model_dropdown.value
                algorithm = self.analysis_algo_dropdown.value
                environment = self.analysis_env_dropdown.value
                steps = self.analysis_steps.value

                if self._is_placeholder_model(model_name):
                    print(f"❌ No {algorithm} models available")
                    self._set_status(f'❌ No {algorithm} models available', 'error')
                    return

                if not self.loaded_data:
                    print("❌ No CSV data loaded. Please load data first.")
                    self._set_status('❌ No CSV data loaded', 'error')
                    return

                self._set_status('🔄 Running Enhanced Analysis...', 'running')
                print("🚀 Running enhanced analysis...")
                print(f"Model: {model_name} | Algorithm: {algorithm} | Environment: {environment} | Steps: {steps}")

                model_path = f"{self._MODELS_DIR}/{model_name}.zip"
                self.last_results = run_anthropometric_analysis_with_csv(
                    model_path, algorithm, environment, self.loaded_data, steps
                )

                print("✅ Enhanced analysis completed!")
                self._set_status('✅ Analysis Complete', 'success')
            except Exception as e:
                print(f"❌ Error: {str(e)}")
                self._set_error_status(e)

    def clear_data(self, btn):
        """Drop all loaded data and results and reset the panels (button handler)."""
        with self.output:
            clear_output()
        self.loaded_data = {}
        self.last_results = None
        self.data_display.value = self._EMPTY_DATA_HTML
        self._set_status('🟢 Ready for Data Processing', 'ready')

    # ----- private helpers -------------------------------------------------

    def _sample_path(self, csv_type):
        """Relative path of the sample CSV for ``csv_type`` ('sample.csv' if unknown)."""
        return f"{self._SAMPLE_DIR}/{self._SAMPLE_FILES.get(csv_type, 'sample.csv')}"

    @classmethod
    def _status_html(cls, message, theme):
        """HTML for the rounded status banner; ``theme`` is a key of ``_STATUS_THEMES``."""
        start, end, shadow = cls._STATUS_THEMES[theme]
        return f'''
            <div style="
                padding: 15px 25px;
                background: linear-gradient(135deg, {start} 0%, {end} 100%);
                color: white;
                border-radius: 25px;
                text-align: center;
                font-weight: 500;
                box-shadow: 0 4px 15px {shadow};
                margin: 15px 0;
            ">
                <span style="font-size: 16px;">{message}</span>
            </div>
        '''

    def _set_status(self, message, theme):
        """Replace the status banner with ``message`` in the given theme."""
        self.status_label.value = self._status_html(message, theme)

    def _set_error_status(self, exc):
        """Show a truncated, HTML-escaped exception message in the error banner."""
        import html
        text = str(exc)
        if len(text) > self._MAX_ERROR_CHARS:
            text = text[:self._MAX_ERROR_CHARS] + '...'
        # Exception text is arbitrary (may contain '<', '&'); escape before embedding.
        self._set_status(f'❌ Error: {html.escape(text)}', 'error')

    @staticmethod
    def _is_placeholder_model(model_name):
        """True for empty values and the 'No ... models found' placeholder entries.

        A bare ``'No' in model_name`` test would also reject real models whose
        names merely contain "No" (e.g. "SAC_Normal_1000").
        """
        return not model_name or (
            model_name.startswith('No ') and model_name.endswith('models found'))

    @staticmethod
    def _section_header(title, background, accent):
        """Card-style section header with a coloured left accent bar."""
        return widgets.HTML(value=f'''
            <div style="
                background: {background};
                padding: 15px 20px;
                border-left: 4px solid {accent};
                margin: 20px 0 10px 0;
                border-radius: 8px;
            ">
                <h3 style="margin: 0; color: #1e293b; font-size: 16px; font-weight: 600;">
                    {title}
                </h3>
            </div>
        ''')

# Build the CSV input dashboard and render its root container.
csv_ui = CSVInputUI()
display(csv_ui.main_layout)


VBox(children=(HTML(value='\n            <div style="\n                background: linear-gradient(135deg, #10…

# LSTM Movement Prediction


In [28]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import os
import pickle
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error
import warnings
warnings.filterwarnings('ignore')

# Try to import TensorFlow/Keras, fallback to basic implementation
try:
    import tensorflow as tf
    from tensorflow.keras.models import Sequential, load_model
    from tensorflow.keras.layers import LSTM, Dense, Dropout
    from tensorflow.keras.optimizers import Adam
    from tensorflow.keras.callbacks import EarlyStopping
    TF_AVAILABLE = True
except ImportError:
    TF_AVAILABLE = False
    print("‚ö†Ô∏è TensorFlow not available. Using mock implementation.")

class LSTMMovementPredictor:
    """Forecast humanoid joint trajectories with a stacked Keras LSTM.

    Workflow: ``extract_movement_data`` rolls out an RL policy and records
    ``concat(qpos, qvel)`` per step; ``train_model`` min-max scales that
    trajectory, cuts it into windows of ``sequence_length`` inputs followed by
    ``prediction_steps`` targets and fits a model that emits all target steps
    at once (flattened); ``predict_movement`` maps one input window back to
    future rows in the original units.

    When TensorFlow is not importable (``TF_AVAILABLE`` is False) training and
    prediction return mock/random results so the UI keeps working.
    """

    def __init__(self):
        self.model = None
        self.scaler = MinMaxScaler()
        # Window geometry. After training, the model's input/output sizes are
        # fixed to these values; change them only before (re)training.
        self.sequence_length = 20
        self.prediction_steps = 10
        self.is_trained = False

    def extract_movement_data(self, model_path, algorithm, environment, steps=1000):
        """Roll out an RL policy and record the MuJoCo state after every step.

        Args:
            model_path: path to a stable-baselines3 ``.zip`` checkpoint.
            algorithm: one of ``'SAC'``, ``'TD3'``, ``'A2C'``.
            environment: Gymnasium environment id, e.g. ``'Humanoid-v4'``.
            steps: number of environment steps to record.

        Returns:
            ``(trajectory, actions)``: each trajectory row is
            ``concat(qpos, qvel)`` and each actions row is the policy action.
            Consecutive episodes are concatenated. On any failure an error is
            printed and *random placeholder data* of shape ``(steps, 50)`` /
            ``(steps, 17)`` is returned instead.
        """
        env = None
        try:
            import gymnasium as gym
            from stable_baselines3 import SAC, TD3, A2C

            algorithms = {'SAC': SAC, 'TD3': TD3, 'A2C': A2C}
            if algorithm not in algorithms:
                raise ValueError(f"Unsupported algorithm: {algorithm!r}")

            env = gym.make(environment, render_mode=None)
            model = algorithms[algorithm].load(model_path)

            obs, _ = env.reset()
            trajectory = []
            actions = []

            for _ in range(steps):
                action, _ = model.predict(obs, deterministic=True)
                obs, _, terminated, truncated, _ = env.step(action)

                # Extract relevant movement features (joint positions, velocities)
                if hasattr(env.unwrapped, 'data'):
                    qpos = env.unwrapped.data.qpos.copy()
                    qvel = env.unwrapped.data.qvel.copy()
                    trajectory.append(np.concatenate([qpos, qvel]))
                    actions.append(action)

                # An episode ends on termination OR truncation (time limit);
                # ignoring truncation would keep stepping past the time limit.
                if terminated or truncated:
                    obs, _ = env.reset()

            return np.array(trajectory), np.array(actions)

        except Exception as e:
            print(f"Error extracting movement data: {str(e)}")
            print("Warning: returning random placeholder data; results will not reflect the RL policy.")
            # Return dummy data for demonstration
            return np.random.randn(steps, 50), np.random.randn(steps, 17)
        finally:
            # Release the simulator even when loading or stepping failed.
            if env is not None:
                env.close()

    def prepare_sequences(self, data, sequence_length, prediction_steps):
        """Slice a trajectory into (input window, target window) pairs.

        Every window position is used, including the last one whose targets end
        exactly at ``len(data)``.

        Args:
            data: array of shape ``(timesteps, *feature_shape)``.
            sequence_length: number of input steps per window.
            prediction_steps: number of target steps following each input.

        Returns:
            ``X`` of shape ``(n, sequence_length, *feature_shape)`` and ``y`` of
            shape ``(n, prediction_steps, *feature_shape)`` with
            ``n = max(0, len(data) - sequence_length - prediction_steps + 1)``.
        """
        data = np.asarray(data)
        n_windows = len(data) - sequence_length - prediction_steps + 1
        if n_windows <= 0:
            # Keep the trailing dimensions so callers can still inspect shapes.
            feature_shape = data.shape[1:]
            return (np.empty((0, sequence_length) + feature_shape),
                    np.empty((0, prediction_steps) + feature_shape))

        X, y = [], []
        for i in range(n_windows):
            X.append(data[i:i + sequence_length])
            y.append(data[i + sequence_length:i + sequence_length + prediction_steps])

        return np.array(X), np.array(y)

    def create_lstm_model(self, input_shape, output_shape):
        """Build and compile the 3-layer LSTM regressor.

        Args:
            input_shape: ``(sequence_length, features)`` of one input window.
            output_shape: size of the flattened output
                (``prediction_steps * features``).

        Returns:
            A compiled Keras ``Sequential`` model (Adam, MSE loss, MAE metric),
            or ``None`` when TensorFlow is unavailable.
        """
        if not TF_AVAILABLE:
            return None

        model = Sequential([
            LSTM(128, return_sequences=True, input_shape=input_shape),
            Dropout(0.2),
            LSTM(64, return_sequences=True),
            Dropout(0.2),
            LSTM(32, return_sequences=False),
            Dropout(0.2),
            Dense(64, activation='relu'),
            Dense(output_shape, activation='linear')
        ])

        model.compile(
            optimizer=Adam(learning_rate=0.001),
            loss='mse',
            metrics=['mae']
        )

        return model

    def train_model(self, trajectory_data, epochs=50, batch_size=32):
        """Fit the scaler and LSTM on a ``(timesteps, features)`` trajectory.

        The first 80% of windows (in time order, no shuffling) are used for
        training and the rest for validation, with early stopping on the
        validation loss (best weights restored).

        Args:
            trajectory_data: 2-D array of shape ``(timesteps, features)``.
            epochs: maximum number of training epochs.
            batch_size: mini-batch size passed to ``model.fit``.

        Returns:
            The Keras ``history.history`` dict (``'loss'``, ``'val_loss'``, ...),
            or a small mock dict when TensorFlow is unavailable.

        Raises:
            ValueError: if the trajectory is too short to yield at least one
                training and one validation window.
        """
        if not TF_AVAILABLE:
            print("Mock training without TensorFlow...")
            self.is_trained = True
            return {"loss": [0.1, 0.05], "val_loss": [0.12, 0.06]}

        # Normalize data. NOTE: the scaler is fitted on the full trajectory,
        # including the part that becomes validation data.
        trajectory_scaled = self.scaler.fit_transform(trajectory_data)

        X, y = self.prepare_sequences(trajectory_scaled, self.sequence_length, self.prediction_steps)
        # int(0.8 * n) >= 1 and a non-empty validation split both need n >= 2.
        if len(X) < 2:
            min_steps = self.sequence_length + self.prediction_steps + 1
            raise ValueError(
                f"Need at least {min_steps} timesteps to build training and validation "
                f"windows (sequence_length={self.sequence_length}, "
                f"prediction_steps={self.prediction_steps}); got {len(trajectory_data)}."
            )

        # The dense head predicts all future steps at once:
        # (n, prediction_steps, features) -> (n, prediction_steps * features).
        y = y.reshape(y.shape[0], -1)

        # Chronological split (no shuffling).
        split_idx = int(0.8 * len(X))
        X_train, X_val = X[:split_idx], X[split_idx:]
        y_train, y_val = y[:split_idx], y[split_idx:]

        input_shape = (self.sequence_length, trajectory_scaled.shape[1])
        output_shape = y.shape[1]

        self.model = self.create_lstm_model(input_shape, output_shape)

        early_stopping = EarlyStopping(patience=10, restore_best_weights=True)

        history = self.model.fit(
            X_train, y_train,
            validation_data=(X_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=[early_stopping],
            verbose=0
        )

        self.is_trained = True
        return history.history

    def predict_movement(self, input_sequence):
        """Predict the states that follow ``input_sequence``.

        Args:
            input_sequence: array of shape ``(sequence_length, features)`` in
                original (unscaled) units.

        Returns:
            Array of shape ``(steps, features)`` in original units, where
            ``steps`` is the horizon the model was trained with. Random values
            of shape ``(prediction_steps, features)`` are returned when
            TensorFlow or a model is unavailable.

        Raises:
            ValueError: if the window length differs from ``sequence_length``.
        """
        if not TF_AVAILABLE or self.model is None:
            # Mock prediction
            return np.random.randn(self.prediction_steps, input_sequence.shape[-1])

        if len(input_sequence) != self.sequence_length:
            raise ValueError(
                f"Expected an input window of {self.sequence_length} steps, "
                f"got {len(input_sequence)}."
            )

        n_features = input_sequence.shape[-1]
        input_scaled = self.scaler.transform(input_sequence)
        input_reshaped = input_scaled.reshape(1, self.sequence_length, n_features)

        prediction = self.model.predict(input_reshaped, verbose=0)
        # Unflatten by feature count so the horizon comes from the model's own
        # output size, not from an attribute that may have changed since training.
        prediction_reshaped = prediction.reshape(-1, n_features)

        return self.scaler.inverse_transform(prediction_reshaped)

    def save_model(self, filepath):
        """Persist the model and its preprocessing state under a path prefix.

        Writes ``<filepath>_model.h5`` (only when TensorFlow is available and a
        model exists) and ``<filepath>_metadata.pkl`` holding the fitted
        scaler, window geometry and ``is_trained`` flag.
        """
        if TF_AVAILABLE and self.model is not None:
            self.model.save(f"{filepath}_model.h5")

        with open(f"{filepath}_metadata.pkl", 'wb') as f:
            pickle.dump({
                'scaler': self.scaler,
                'sequence_length': self.sequence_length,
                'prediction_steps': self.prediction_steps,
                'is_trained': self.is_trained
            }, f)

    def load_model(self, filepath):
        """Restore state written by ``save_model`` from the prefix ``filepath``.

        State is only replaced when ``<filepath>_metadata.pkl`` exists. With
        TensorFlow available, the predictor is marked trained only if
        ``<filepath>_model.h5`` was actually loaded, so a missing model file
        cannot silently switch predictions to the random mock.

        Returns:
            True if the metadata was loaded, False otherwise (errors are printed).
        """
        try:
            model_file = f"{filepath}_model.h5"
            metadata_file = f"{filepath}_metadata.pkl"

            if not os.path.exists(metadata_file):
                return False

            model = None
            if TF_AVAILABLE and os.path.exists(model_file):
                # Module-level Keras loader, not this method.
                model = load_model(model_file)

            # SECURITY: pickle.load can execute arbitrary code; only load files
            # this notebook wrote itself.
            with open(metadata_file, 'rb') as f:
                metadata = pickle.load(f)

            self.model = model
            self.scaler = metadata['scaler']
            self.sequence_length = metadata['sequence_length']
            self.prediction_steps = metadata['prediction_steps']
            self.is_trained = bool(metadata['is_trained']) and (model is not None or not TF_AVAILABLE)
            return True
        except Exception as e:
            print(f"Error loading model: {str(e)}")
        return False

class LSTMMovementPredictionUI:
    def __init__(self):
        # Runtime state shared by the button handlers; nothing is loaded yet.
        self.training_history = None
        self.trajectory_data = None
        self.predictor = LSTMMovementPredictor()
        # Widgets have to exist before they can be arranged into the layout.
        self.create_widgets()
        self.create_layout()
        
    def get_models_by_algorithm(self, models_dir='models'):
        """Group saved RL checkpoints in ``models_dir`` by algorithm.

        Each ``.zip`` file is assigned to the first of A2C, SAC, TD3 whose name
        appears (case-insensitively) in its stem. Within an algorithm, names
        are sorted by the integer after the first underscore, largest first
        (``SAC_2000`` before ``SAC_100``); names without such an integer sort
        as 0.

        Args:
            models_dir: directory to scan; defaults to ``'models'``.

        Returns:
            dict mapping ``'A2C'``, ``'SAC'`` and ``'TD3'`` to lists of model
            names (file names without the trailing ``.zip``). All lists are
            empty when ``models_dir`` is not a directory.
        """
        algorithm_models = {'A2C': [], 'SAC': [], 'TD3': []}

        if os.path.isdir(models_dir):
            for file in os.listdir(models_dir):
                if not file.endswith('.zip'):
                    continue
                # Strip only the suffix; str.replace would also remove '.zip' mid-name.
                model_name = file[:-len('.zip')]
                upper_name = model_name.upper()
                # Dict order (A2C, SAC, TD3) decides ties between algorithm tags.
                for algo in algorithm_models:
                    if algo in upper_name:
                        algorithm_models[algo].append(model_name)
                        break

        def extract_number(model_name):
            """Integer after the first '_' (e.g. 'SAC_2000' -> 2000), else 0."""
            try:
                return int(model_name.split('_')[1])
            except (IndexError, ValueError):
                return 0

        for algo in algorithm_models:
            algorithm_models[algo] = sorted(algorithm_models[algo], key=extract_number, reverse=True)

        return algorithm_models
        
    def create_widgets(self):
        """Build every widget of the LSTM panel and wire up its callbacks.

        Creates the title banner, the algorithm / RL-model / environment
        dropdowns, the data-steps, sequence-length, prediction-steps and
        epochs sliders, the six action buttons, the model-info and status HTML
        panels, and the scrolling output area. Button clicks and
        dropdown/slider changes are bound to the handler methods, and the
        RL-model dropdown is populated once for the initially selected
        algorithm.

        Must run before ``create_layout``, which arranges these widgets.
        """
        # Modern title with AI/ML theme
        self.title = widgets.HTML(value='''
            <div style="
                background: linear-gradient(135deg, #8b5cf6 0%, #3b82f6 100%);
                color: white;
                padding: 25px;
                border-radius: 15px;
                text-align: center;
                margin-bottom: 20px;
                box-shadow: 0 8px 32px rgba(139, 92, 246, 0.3);
            ">
                <h2 style="margin: 0; font-size: 26px; font-weight: 300; letter-spacing: 1px;">
                    üß† LSTM Movement Prediction
                </h2>
                <p style="margin: 8px 0 0 0; opacity: 0.9; font-size: 14px;">
                    Deep Learning for Humanoid Movement Forecasting
                </p>
            </div>
        ''')
        
        # Model selection section
        # NOTE(review): 'width' is not a description-style setting; it is
        # presumably ignored by ipywidgets (widths come from dropdown_layout) — confirm.
        dropdown_style = {'description_width': '120px', 'width': '200px'}
        dropdown_layout = {'width': '280px', 'margin': '5px'}
        
        self.algo_dropdown = widgets.Dropdown(
            options=['A2C', 'SAC', 'TD3'], value='A2C',
            description='ü§ñ Algorithm:', style=dropdown_style, layout=dropdown_layout
        )
        
        # Placeholder options; replaced by update_models() at the end of this method.
        self.model_dropdown = widgets.Dropdown(
            options=['No models found'],
            description='üìä RL Model:', style=dropdown_style, layout=dropdown_layout
        )
        
        self.env_dropdown = widgets.Dropdown(
            options=['Humanoid-v4', 'Humanoid-v5'], value='Humanoid-v4',
            description='üéÆ Environment:', style=dropdown_style, layout=dropdown_layout
        )
        
        # Training parameters
        # Number of environment steps to record during extraction.
        self.data_steps = widgets.IntSlider(
            value=2000, min=500, max=5000, step=250,
            description='üìà Data Steps:', style={'description_width': '120px'},
            layout={'width': '350px', 'margin': '5px'}
        )
        
        # Length of each LSTM input window (timesteps).
        self.sequence_length = widgets.IntSlider(
            value=20, min=10, max=50, step=5,
            description='üî¢ Sequence Len:', style={'description_width': '120px'},
            layout={'width': '350px', 'margin': '5px'}
        )
        
        # Number of future timesteps predicted per window.
        self.prediction_steps = widgets.IntSlider(
            value=10, min=5, max=30, step=5,
            description='üîÆ Predict Steps:', style={'description_width': '120px'},
            layout={'width': '350px', 'margin': '5px'}
        )
        
        # Upper bound on training epochs (training may stop earlier).
        self.epochs = widgets.IntSlider(
            value=50, min=10, max=200, step=10,
            description='üîÑ Epochs:', style={'description_width': '120px'},
            layout={'width': '350px', 'margin': '5px'}
        )
        
        # Buttons with modern styling
        button_layout = {'width': '150px', 'height': '45px', 'margin': '8px'}
        
        self.extract_btn = widgets.Button(
            description='üìä Extract Data', button_style='info', layout=button_layout,
            tooltip='Extract movement data from selected RL model'
        )
        
        self.train_btn = widgets.Button(
            description='üöÄ Train LSTM', button_style='success', layout=button_layout,
            tooltip='Train LSTM model on extracted data'
        )
        
        self.predict_btn = widgets.Button(
            description='üîÆ Predict', button_style='warning', layout=button_layout,
            tooltip='Generate movement predictions'
        )
        
        self.save_btn = widgets.Button(
            description='üíæ Save Model', button_style='primary', layout=button_layout,
            tooltip='Save trained LSTM model'
        )
        
        self.load_btn = widgets.Button(
            description='üìÇ Load Model', button_style='info', layout=button_layout,
            tooltip='Load pre-trained LSTM model'
        )
        
        self.clear_btn = widgets.Button(
            description='üßπ Clear', button_style='danger', layout=button_layout,
            tooltip='Clear all data and results'
        )
        
        # Model info display (initial "nothing loaded" card)
        self.model_info = widgets.HTML(value='''
            <div style="
                background: #f8fafc;
                padding: 20px;
                border-radius: 10px;
                border: 2px solid #e2e8f0;
                margin: 15px 0;
                text-align: center;
            ">
                <h4 style="margin: 0; color: #64748b;">üß† No LSTM Model Loaded</h4>
                <p style="margin: 5px 0 0 0; color: #94a3b8; font-size: 14px;">
                    Extract data and train model to begin predictions
                </p>
            </div>
        ''')
        
        # Elegant status display (initial "ready" pill)
        self.status_label = widgets.HTML(value='''
            <div style="
                padding: 15px 25px;
                background: linear-gradient(135deg, #8b5cf6 0%, #3b82f6 100%);
                color: white;
                border-radius: 25px;
                text-align: center;
                font-weight: 500;
                box-shadow: 0 4px 15px rgba(139, 92, 246, 0.2);
                margin: 15px 0;
            ">
                <span style="font-size: 16px;">üü¢ Ready for AI Training</span>
            </div>
        ''')
        
        # Enhanced output area; every handler prints into it inside `with self.output:`.
        self.output = widgets.Output(layout={
            'height': '400px', 
            'overflow': 'scroll',
            'border': '2px solid #ddd6fe',
            'border_radius': '15px',
            'padding': '20px',
            'margin': '10px 0',
            'background_color': '#faf5ff'
        })
        
        # Event bindings
        self.extract_btn.on_click(self.extract_data)
        self.train_btn.on_click(self.train_model)
        self.predict_btn.on_click(self.generate_predictions)
        self.save_btn.on_click(self.save_model)
        self.load_btn.on_click(self.load_model)
        self.clear_btn.on_click(self.clear_all)
        # Changing the algorithm re-scans the models directory for that algorithm.
        self.algo_dropdown.observe(self.update_models, names='value')
        
        # Update sequence and prediction parameters: slider moves are forwarded to the predictor.
        self.sequence_length.observe(self.update_sequence_params, names='value')
        self.prediction_steps.observe(self.update_prediction_params, names='value')
        
        # Populate the RL-model dropdown once; update_models only reads change['new'].
        self.update_models({'new': self.algo_dropdown.value})
        
    def update_models(self, change):
        algorithm = change['new']
        models_by_algo = self.get_models_by_algorithm()
        
        if models_by_algo[algorithm]:
            self.model_dropdown.options = models_by_algo[algorithm]
            self.model_dropdown.value = models_by_algo[algorithm][0]
        else:
            self.model_dropdown.options = [f'No {algorithm} models found']
    
    def update_sequence_params(self, change):
        self.predictor.sequence_length = change['new']
    
    def update_prediction_params(self, change):
        self.predictor.prediction_steps = change['new']
    
    def create_layout(self):
        """Arrange the widgets built by ``create_widgets`` into ``self.main_layout``.

        The root VBox stacks, top to bottom: the title, the "Data Extraction &
        Configuration" section (dropdowns and sliders), the "LSTM Training &
        Prediction" section (extract/train/predict buttons), the "Model
        Management" section (save/load/clear buttons), the model-info panel,
        the status pill and the output area.
        """
        # Data extraction section
        data_header = widgets.HTML(value='''
            <div style="
                background: #f3f4f6;
                padding: 15px 20px;
                border-left: 4px solid #8b5cf6;
                margin: 20px 0 10px 0;
                border-radius: 8px;
            ">
                <h3 style="margin: 0; color: #1e293b; font-size: 16px; font-weight: 600;">
                    üìä Data Extraction & Configuration
                </h3>
            </div>
        ''')
        
        # Three rows: source selection, then two rows of two sliders each.
        data_section = widgets.VBox([
            data_header,
            widgets.HBox([self.algo_dropdown, self.model_dropdown, self.env_dropdown], 
                        layout={'justify_content': 'space-around', 'margin': '10px 0'}),
            widgets.HBox([self.data_steps, self.sequence_length], 
                        layout={'justify_content': 'space-around', 'margin': '10px 0'}),
            widgets.HBox([self.prediction_steps, self.epochs], 
                        layout={'justify_content': 'space-around', 'margin': '10px 0'})
        ], layout={'margin': '20px 0'})
        
        # Training section
        training_header = widgets.HTML(value='''
            <div style="
                background: #f0fdf4;
                padding: 15px 20px;
                border-left: 4px solid #22c55e;
                margin: 20px 0 10px 0;
                border-radius: 8px;
            ">
                <h3 style="margin: 0; color: #1e293b; font-size: 16px; font-weight: 600;">
                    üß† LSTM Training & Prediction
                </h3>
            </div>
        ''')
        
        # Buttons in workflow order: extract -> train -> predict.
        training_section = widgets.VBox([
            training_header,
            widgets.HBox([self.extract_btn, self.train_btn, self.predict_btn], 
                        layout={'justify_content': 'center', 'margin': '15px 0'})
        ])
        
        # Model management section
        management_header = widgets.HTML(value='''
            <div style="
                background: #eff6ff;
                padding: 15px 20px;
                border-left: 4px solid #3b82f6;
                margin: 20px 0 10px 0;
                border-radius: 8px;
            ">
                <h3 style="margin: 0; color: #1e293b; font-size: 16px; font-weight: 600;">
                    üíæ Model Management
                </h3>
            </div>
        ''')
        
        management_section = widgets.VBox([
            management_header,
            widgets.HBox([self.save_btn, self.load_btn, self.clear_btn], 
                        layout={'justify_content': 'center', 'margin': '15px 0'})
        ])
        
        # Main container; this is the widget callers display().
        self.main_layout = widgets.VBox([
            self.title,
            data_section,
            training_section,
            management_section,
            self.model_info,
            self.status_label,
            self.output
        ], layout={'padding': '20px', 'background_color': '#ffffff'})
    
    def extract_data(self, btn):
        """Button handler: roll out the selected RL policy and keep its trajectory.

        Reads algorithm, checkpoint, environment and step count from the
        widgets, runs ``self.predictor.extract_movement_data`` on
        ``models/<name>.zip`` and stores the trajectory in
        ``self.trajectory_data`` (the recorded actions are discarded).
        Progress goes to the output area; the status pill shows
        busy/done/error.

        Args:
            btn: the clicked ``widgets.Button`` (unused).
        """
        with self.output:
            clear_output()
            try:
                model_name = self.model_dropdown.value
                algorithm = self.algo_dropdown.value
                environment = self.env_dropdown.value
                steps = self.data_steps.value
                
                # Validate BEFORE switching the pill to "Extracting": returning after
                # that point would leave the status stuck in the busy state.
                # Match the dropdown placeholders ('No models found',
                # 'No <ALGO> models found') exactly, not any name containing "No".
                if not model_name or (model_name.startswith('No ') and model_name.endswith('models found')):
                    print(f"‚ùå No {algorithm} models available")
                    return
                
                self.status_label.value = '''
                    <div style="
                        padding: 15px 25px;
                        background: linear-gradient(135deg, #f59e0b 0%, #d97706 100%);
                        color: white;
                        border-radius: 25px;
                        text-align: center;
                        font-weight: 500;
                        box-shadow: 0 4px 15px rgba(245, 158, 11, 0.2);
                        margin: 15px 0;
                    ">
                        <span style="font-size: 16px;">üîÑ Extracting Movement Data...</span>
                    </div>
                '''
                
                print(f"üöÄ Extracting movement data...")
                print(f"Model: {model_name} | Algorithm: {algorithm} | Environment: {environment}")
                print(f"Steps: {steps}")
                
                model_path = f"models/{model_name}.zip"
                # The recorded actions are not used by the LSTM pipeline.
                self.trajectory_data, _ = self.predictor.extract_movement_data(
                    model_path, algorithm, environment, steps
                )
                
                print(f"‚úÖ Data extracted successfully!")
                print(f"Trajectory shape: {self.trajectory_data.shape}")
                print(f"Features: {self.trajectory_data.shape[1]} movement dimensions")
                print(f"Timesteps: {self.trajectory_data.shape[0]}")
                
                self.update_model_info()
                self.status_label.value = '''
                    <div style="
                        padding: 15px 25px;
                        background: linear-gradient(135deg, #22c55e 0%, #16a34a 100%);
                        color: white;
                        border-radius: 25px;
                        text-align: center;
                        font-weight: 500;
                        box-shadow: 0 4px 15px rgba(34, 197, 94, 0.2);
                        margin: 15px 0;
                    ">
                        <span style="font-size: 16px;">‚úÖ Data Extraction Complete</span>
                    </div>
                '''
                
            except Exception as e:
                print(f"‚ùå Error: {str(e)}")
                self.status_label.value = f'''
                    <div style="
                        padding: 15px 25px;
                        background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%);
                        color: white;
                        border-radius: 25px;
                        text-align: center;
                        font-weight: 500;
                        box-shadow: 0 4px 15px rgba(239, 68, 68, 0.2);
                        margin: 15px 0;
                    ">
                        <span style="font-size: 16px;">‚ùå Error: {str(e)[:30]}...</span>
                    </div>
                '''
    
    def train_model(self, btn):
        """Button handler: train the LSTM on the currently extracted trajectory.

        Copies the sequence-length and prediction-step slider values into the
        predictor, trains for at most the selected number of epochs, reports
        the final losses (real TensorFlow runs only), plots the history and
        refreshes the info and status panels.

        Args:
            btn: the clicked ``widgets.Button`` (unused).
        """
        with self.output:
            clear_output()
            try:
                # Nothing to learn from until "Extract Data" has been run.
                if self.trajectory_data is None:
                    print("‚ùå No data available. Please extract data first.")
                    return
                
                self.status_label.value = '''
                    <div style="
                        padding: 15px 25px;
                        background: linear-gradient(135deg, #f59e0b 0%, #d97706 100%);
                        color: white;
                        border-radius: 25px;
                        text-align: center;
                        font-weight: 500;
                        box-shadow: 0 4px 15px rgba(245, 158, 11, 0.2);
                        margin: 15px 0;
                    ">
                        <span style="font-size: 16px;">üß† Training LSTM Model...</span>
                    </div>
                '''
                
                window = self.sequence_length.value
                horizon = self.prediction_steps.value
                max_epochs = self.epochs.value
                
                print(f"üöÄ Training LSTM model...")
                print(f"Sequence Length: {window}")
                print(f"Prediction Steps: {horizon}")
                print(f"Epochs: {max_epochs}")
                
                # The predictor slices its training windows from these attributes.
                self.predictor.sequence_length, self.predictor.prediction_steps = window, horizon
                
                self.training_history = self.predictor.train_model(
                    self.trajectory_data, epochs=max_epochs
                )
                history = self.training_history
                
                print(f"‚úÖ LSTM training completed!")
                
                # The TensorFlow-less mock history is placeholder data, so only
                # real runs report final losses.
                if TF_AVAILABLE:
                    last_loss, last_val_loss = history['loss'][-1], history['val_loss'][-1]
                    print(f"Final Loss: {last_loss:.6f}")
                    print(f"Final Validation Loss: {last_val_loss:.6f}")
                
                self.plot_training_history()
                self.update_model_info()
                self.status_label.value = '''
                    <div style="
                        padding: 15px 25px;
                        background: linear-gradient(135deg, #22c55e 0%, #16a34a 100%);
                        color: white;
                        border-radius: 25px;
                        text-align: center;
                        font-weight: 500;
                        box-shadow: 0 4px 15px rgba(34, 197, 94, 0.2);
                        margin: 15px 0;
                    ">
                        <span style="font-size: 16px;">‚úÖ LSTM Training Complete</span>
                    </div>
                '''
                
            except Exception as err:
                print(f"‚ùå Error: {str(err)}")
                self.status_label.value = f'''
                    <div style="
                        padding: 15px 25px;
                        background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%);
                        color: white;
                        border-radius: 25px;
                        text-align: center;
                        font-weight: 500;
                        box-shadow: 0 4px 15px rgba(239, 68, 68, 0.2);
                        margin: 15px 0;
                    ">
                        <span style="font-size: 16px;">‚ùå Training Error: {str(err)[:20]}...</span>
                    </div>
                '''
    
    def generate_predictions(self, btn):
        """Button handler: forecast from a random window of the extracted data.

        Picks a random input window of ``sequence_length`` steps, predicts the
        following steps, and prints MSE/MAE against the true continuation when
        it lies inside the data. Windows with a full continuation are preferred
        so metrics are normally available. The result is then plotted.

        Args:
            btn: the clicked ``widgets.Button`` (unused).
        """
        with self.output:
            clear_output()
            try:
                # All validation happens before the status pill switches to "busy",
                # so an early return cannot leave it stuck there.
                if not self.predictor.is_trained:
                    print("‚ùå No trained model available. Please train model first.")
                    return
                if self.trajectory_data is None:
                    print("‚ùå No trajectory data available. Please extract data first.")
                    return
                
                seq_len = self.predictor.sequence_length
                horizon = self.predictor.prediction_steps
                data_len = len(self.trajectory_data)
                if data_len < seq_len:
                    print(f"‚ùå Need at least {seq_len} timesteps of data; got {data_len}.")
                    return
                
                self.status_label.value = '''
                    <div style="
                        padding: 15px 25px;
                        background: linear-gradient(135deg, #f59e0b 0%, #d97706 100%);
                        color: white;
                        border-radius: 25px;
                        text-align: center;
                        font-weight: 500;
                        box-shadow: 0 4px 15px rgba(245, 158, 11, 0.2);
                        margin: 15px 0;
                    ">
                        <span style="font-size: 16px;">üîÆ Generating Predictions...</span>
                    </div>
                '''
                
                print(f"üîÆ Generating movement predictions...")
                
                # Prefer starts whose ground-truth continuation fits inside the data;
                # if the data is too short for that, allow any start whose input
                # window fits. randint's upper bound is exclusive, hence the + 1.
                max_start = data_len - seq_len - horizon
                if max_start < 0:
                    max_start = data_len - seq_len
                start_idx = np.random.randint(0, max_start + 1)
                input_sequence = self.trajectory_data[start_idx:start_idx + seq_len]
                
                prediction = self.predictor.predict_movement(input_sequence)
                
                print(f"‚úÖ Prediction generated!")
                print(f"Input sequence shape: {input_sequence.shape}")
                print(f"Prediction shape: {prediction.shape}")
                
                # Calculate metrics if we have ground truth
                truth_end = start_idx + seq_len + horizon
                if truth_end <= data_len:
                    ground_truth = self.trajectory_data[start_idx + seq_len:truth_end]
                    
                    mse = mean_squared_error(ground_truth.flatten(), prediction.flatten())
                    mae = mean_absolute_error(ground_truth.flatten(), prediction.flatten())
                    
                    print(f"üìä Prediction Metrics:")
                    print(f"  MSE: {mse:.6f}")
                    print(f"  MAE: {mae:.6f}")
                
                self.plot_predictions(input_sequence, prediction)
                
                self.status_label.value = '''
                    <div style="
                        padding: 15px 25px;
                        background: linear-gradient(135deg, #22c55e 0%, #16a34a 100%);
                        color: white;
                        border-radius: 25px;
                        text-align: center;
                        font-weight: 500;
                        box-shadow: 0 4px 15px rgba(34, 197, 94, 0.2);
                        margin: 15px 0;
                    ">
                        <span style="font-size: 16px;">‚úÖ Predictions Generated</span>
                    </div>
                '''
                
            except Exception as e:
                print(f"‚ùå Error: {str(e)}")
                self.status_label.value = f'''
                    <div style="
                        padding: 15px 25px;
                        background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%);
                        color: white;
                        border-radius: 25px;
                        text-align: center;
                        font-weight: 500;
                        box-shadow: 0 4px 15px rgba(239, 68, 68, 0.2);
                        margin: 15px 0;
                    ">
                        <span style="font-size: 16px;">‚ùå Prediction Error: {str(e)[:20]}...</span>
                    </div>
                '''
    
    def save_model(self, btn):
        """Button handler: save the trained LSTM next to the selected RL model's name.

        Files are written under the prefix ``lstm_models/lstm_<model name>``
        (see ``LSTMMovementPredictor.save_model``). Saving is refused while the
        RL-model dropdown shows a "No ... models found" placeholder, which
        would otherwise become part of the file name.

        Args:
            btn: the clicked ``widgets.Button`` (unused).
        """
        with self.output:
            clear_output()
            try:
                if not self.predictor.is_trained:
                    print("‚ùå No trained model to save.")
                    return
                
                model_name = self.model_dropdown.value or ''
                if not model_name or (model_name.startswith('No ') and model_name.endswith('models found')):
                    print("‚ùå Select an RL model to name the saved LSTM after.")
                    return
                if model_name.endswith('.zip'):
                    model_name = model_name[:-len('.zip')]
                filepath = f"lstm_models/lstm_{model_name}"
                
                # Create directory if it doesn't exist
                os.makedirs("lstm_models", exist_ok=True)
                
                self.predictor.save_model(filepath)
                print(f"‚úÖ LSTM model saved to: {filepath}")
                
                self.status_label.value = '''
                    <div style="
                        padding: 15px 25px;
                        background: linear-gradient(135deg, #22c55e 0%, #16a34a 100%);
                        color: white;
                        border-radius: 25px;
                        text-align: center;
                        font-weight: 500;
                        box-shadow: 0 4px 15px rgba(34, 197, 94, 0.2);
                        margin: 15px 0;
                    ">
                        <span style="font-size: 16px;">‚úÖ Model Saved</span>
                    </div>
                '''
                
            except Exception as e:
                print(f"‚ùå Error saving model: {str(e)}")
    
    def load_model(self, btn):
        """Button handler: load the LSTM saved for the selected RL model.

        Reads the prefix ``lstm_models/lstm_<model name>`` (the one
        ``save_model`` writes). On success, the sequence-length and
        prediction-step sliders are set to the loaded window geometry so the
        UI matches what the model expects.

        Args:
            btn: the clicked ``widgets.Button`` (unused).
        """
        with self.output:
            clear_output()
            try:
                model_name = self.model_dropdown.value or ''
                if not model_name or (model_name.startswith('No ') and model_name.endswith('models found')):
                    print("‚ùå Select an RL model whose LSTM should be loaded.")
                    return
                if model_name.endswith('.zip'):
                    model_name = model_name[:-len('.zip')]
                filepath = f"lstm_models/lstm_{model_name}"
                
                if self.predictor.load_model(filepath):
                    print(f"‚úÖ LSTM model loaded from: {filepath}")
                    
                    # Show the loaded geometry on the sliders. Setting a slider fires
                    # its observer (and values outside the slider range may be
                    # clamped), so the loaded values are restored afterwards.
                    loaded_seq = self.predictor.sequence_length
                    loaded_horizon = self.predictor.prediction_steps
                    self.sequence_length.value = loaded_seq
                    self.prediction_steps.value = loaded_horizon
                    self.predictor.sequence_length = loaded_seq
                    self.predictor.prediction_steps = loaded_horizon
                    
                    self.update_model_info()
                    
                    self.status_label.value = '''
                        <div style="
                            padding: 15px 25px;
                            background: linear-gradient(135deg, #22c55e 0%, #16a34a 100%);
                            color: white;
                            border-radius: 25px;
                            text-align: center;
                            font-weight: 500;
                            box-shadow: 0 4px 15px rgba(34, 197, 94, 0.2);
                            margin: 15px 0;
                        ">
                            <span style="font-size: 16px;">‚úÖ Model Loaded</span>
                        </div>
                    '''
                else:
                    print(f"‚ùå Failed to load model from: {filepath}")
                    
            except Exception as e:
                print(f"‚ùå Error loading model: {str(e)}")
    
    def clear_all(self, btn):
        """Button handler: forget data, history and model, and reset the panels.

        Args:
            btn: the clicked ``widgets.Button`` (unused).
        """
        # A brand-new predictor drops any trained network and fitted scaler.
        self.predictor = LSTMMovementPredictor()
        self.training_history = None
        self.trajectory_data = None

        with self.output:
            clear_output()

        # Restore the initial "nothing loaded" card and "ready" status pill.
        self.model_info.value = '''
            <div style="
                background: #f8fafc;
                padding: 20px;
                border-radius: 10px;
                border: 2px solid #e2e8f0;
                margin: 15px 0;
                text-align: center;
            ">
                <h4 style="margin: 0; color: #64748b;">üß† No LSTM Model Loaded</h4>
                <p style="margin: 5px 0 0 0; color: #94a3b8; font-size: 14px;">
                    Extract data and train model to begin predictions
                </p>
            </div>
        '''
        self.status_label.value = '''
            <div style="
                padding: 15px 25px;
                background: linear-gradient(135deg, #8b5cf6 0%, #3b82f6 100%);
                color: white;
                border-radius: 25px;
                text-align: center;
                font-weight: 500;
                box-shadow: 0 4px 15px rgba(139, 92, 246, 0.2);
                margin: 15px 0;
            ">
                <span style="font-size: 16px;">üü¢ Ready for AI Training</span>
            </div>
        '''
    
    def update_model_info(self):
        if self.predictor.is_trained:
            data_info = ""
            if self.trajectory_data is not None:
                data_info = f"Data: {self.trajectory_data.shape[0]} steps, {self.trajectory_data.shape[1]} features<br>"
            
            self.model_info.value = f'''
                <div style="
                    background: #f0fdf4;
                    padding: 20px;
                    border-radius: 10px;
                    border: 2px solid #22c55e;
                    margin: 15px 0;
                ">
                    <h4 style="margin: 0 0 10px 0; color: #166534;">üß† LSTM Model Ready</h4>
                    <div style="color: #15803d; font-size: 14px; line-height: 1.6;">
                        {data_info}
                        Sequence Length: {self.predictor.sequence_length}<br>
                        Prediction Steps: {self.predictor.prediction_steps}<br>
                        Status: ‚úÖ Trained and Ready
                    </div>
                </div>
            '''
        elif self.trajectory_data is not None:
            self.model_info.value = f'''
                <div style="
                    background: #fffbeb;
                    padding: 20px;
                    border-radius: 10px;
                    border: 2px solid #f59e0b;
                    margin: 15px 0;
                ">
                    <h4 style="margin: 0 0 10px 0; color: #92400e;">üìä Data Extracted</h4>
                    <div style="color: #b45309; font-size: 14px; line-height: 1.6;">
                        Data: {self.trajectory_data.shape[0]} steps, {self.trajectory_data.shape[1]} features<br>
                        Status: ‚è≥ Ready for Training
                    </div>
                </div>
            '''
    
    def plot_training_history(self):
        """Plot the LSTM's per-epoch training curves.

        ``self.training_history`` is read as a mapping from metric name to a
        per-epoch sequence (keys used: ``'loss'``, ``'val_loss'``, ``'mae'``,
        ``'val_mae'``). Nothing is drawn when it is ``None``.

        The first panel shows the loss. A second, MAE panel is created only when
        ``'mae'`` was recorded, so a history without MAE no longer leaves an
        empty axes beside the loss plot. Each validation curve is drawn only if
        its ``val_*`` key exists, so a history recorded without a validation
        split no longer raises ``KeyError``.
        """
        history = self.training_history
        if history is None:
            return

        # One tuple per panel:
        # (metric key, legend name, title, y-label, training colour, validation colour)
        panels = [('loss', 'Loss', 'LSTM Training Loss', 'Loss (MSE)', '#8b5cf6', '#3b82f6')]
        if 'mae' in history:
            panels.append(('mae', 'MAE', 'Mean Absolute Error', 'MAE', '#10b981', '#059669'))

        # Titles intentionally carry no emoji: matplotlib's default font
        # (DejaVu Sans) lacks those glyphs, so they render as boxes/mojibake.
        # squeeze=False keeps `axes` 2-D even when there is a single panel.
        fig, axes = plt.subplots(1, len(panels), figsize=(6 * len(panels), 4), squeeze=False)
        for ax, (key, name, title, ylabel, train_color, val_color) in zip(axes[0], panels):
            ax.plot(history[key], label=f'Training {name}', color=train_color)
            val_key = f'val_{key}'
            if val_key in history:
                ax.plot(history[val_key], label=f'Validation {name}', color=val_color)
            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True, alpha=0.3)

        fig.tight_layout()
        plt.show()
    
    def plot_predictions(self, input_sequence, prediction, max_features=6):
        """Plot the observed input window followed by the LSTM's forecast.

        Args:
            input_sequence: array indexed ``[time_step, feature]`` (it is
                sliced as ``input_sequence[:, i]`` and its ``shape[1]`` is read)
                holding the window fed to the model.
            prediction: array indexed the same way, holding the predicted steps
                that follow the input window.
            max_features: maximum number of leading features to plot. Defaults
                to 6, the previously hard-coded limit. Panels are laid out three
                per row and the grid is sized to the features actually plotted.

        For each plotted feature the input is drawn as a solid blue line and
        the prediction as a dashed red line placed right after it on the time
        axis. A dotted grey vertical line marks the last input step. Nothing
        is drawn when there are no features to plot.
        """
        n_features = min(max_features, input_sequence.shape[1])
        if n_features <= 0:
            return

        n_cols = min(3, n_features)
        n_rows = (n_features + n_cols - 1) // n_cols  # ceiling division
        # squeeze=False keeps `axes` 2-D for every grid shape; (15, 8) for 2 rows
        # matches the original figure size.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows), squeeze=False)
        flat_axes = axes.ravel()

        n_input = len(input_sequence)
        input_x = range(n_input)
        pred_x = range(n_input, n_input + len(prediction))

        # Titles intentionally carry no emoji: matplotlib's default font
        # (DejaVu Sans) lacks those glyphs, so they render as boxes/mojibake.
        for i, ax in enumerate(flat_axes[:n_features]):
            ax.plot(input_x, input_sequence[:, i], 'b-', linewidth=2, label='Input Sequence', alpha=0.8)
            ax.plot(pred_x, prediction[:, i], 'r--', linewidth=2, label='LSTM Prediction', alpha=0.8)
            # Boundary between observed and predicted steps.
            ax.axvline(x=n_input - 1, color='gray', linestyle=':', alpha=0.6)
            ax.set_title(f'Feature {i + 1} Prediction', fontsize=12, fontweight='bold')
            ax.set_xlabel('Time Steps')
            ax.set_ylabel('Value')
            ax.legend()
            ax.grid(True, alpha=0.3)

        # Hide grid cells left over when n_features is not a multiple of n_cols.
        for ax in flat_axes[n_features:]:
            ax.set_visible(False)

        fig.suptitle('LSTM Movement Prediction Results', fontsize=16, fontweight='bold')
        fig.tight_layout()
        plt.show()

# Build the LSTM movement-prediction dashboard and render its root widget
# (`main_layout`) as this cell's output. The instance is bound to the
# module-level name `lstm_prediction_ui`.
lstm_prediction_ui = LSTMMovementPredictionUI()
display(lstm_prediction_ui.main_layout)


⚠️ TensorFlow not available. Using mock implementation.


VBox(children=(HTML(value='\n            <div style="\n                background: linear-gradient(135deg, #8b‚Ä¶