In [None]:
import numpy as np
from scipy.integrate import solve_ivp
import plotly.graph_objects as go
import plotly.express as px
import ipywidgets as widgets
from IPython.display import display

# Right-hand side of the Lorenz system
def lorenz_system(t, state, sigma=10, rho=28, beta=8/3):
    """Evaluate the Lorenz equations at ``state``.

    ``t`` is accepted but unused, so the function fits the ``f(t, y)`` call
    signature that ``scipy.integrate.solve_ivp`` expects.

    Returns ``[dx/dt, dy/dt, dz/dt]`` as a list.
    """
    x, y, z = state
    return [
        sigma * (y - x),
        x * (rho - z) - y,
        x * y - beta * z,
    ]

# Integrate an ODE with SciPy and return the trajectory sampled on a uniform grid
def ode_solution_points(function, state0, time, dt=0.01, **solver_kwargs):
    """Integrate ``function`` from t=0 to t=``time`` and sample it every ``dt``.

    Parameters
    ----------
    function : callable
        Right-hand side ``f(t, state)`` in the form ``scipy.integrate.solve_ivp``
        expects.
    state0 : array_like
        State at t=0.
    time : float
        End of the integration interval.
    dt : float, optional
        Spacing of the returned samples (``np.arange(0, time, dt)``). This is
        not the solver's internal step size.
    **solver_kwargs
        Extra keyword arguments forwarded to ``solve_ivp`` (e.g. ``method``,
        ``rtol``, ``atol``, ``max_step``). SciPy's defaults apply otherwise.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(np.arange(0, time, dt)), len(state0))``. Row ``k`` is the
        state at the ``k``-th sample time.

    Raises
    ------
    RuntimeError
        If the solver stops before reaching ``time``. Without this check the
        returned trajectory would be silently shorter than the sample grid.
    """
    t_eval = np.arange(0, time, dt)
    solution = solve_ivp(
        function,
        t_span=(0, time),
        y0=state0,
        t_eval=t_eval,
        **solver_kwargs,
    )
    if not solution.success:
        raise RuntimeError(f"ODE integration failed: {solution.message}")
    # solve_ivp returns y with one row per state component; transpose so that
    # each row is one sample in time.
    return solution.y.T

class InteractiveLorenzAttractor:
    """Jupyter widget app for exploring the Lorenz attractor.

    Sliders set the Lorenz parameters (sigma, rho, beta), the integration
    time and the number of trajectories. Buttons redraw the trajectories,
    clear the figure, and start/stop an animation that moves one marker
    ("particle") along each trajectory.

    The animation is driven by an asyncio task on the kernel's event loop
    instead of a blocking loop, so widget events such as the Stop button are
    still processed while it plays.
    """

    def __init__(self):
        # Default parameters (used as the sliders' initial values)
        self.sigma = 10
        self.rho = 28
        self.beta = 8/3
        self.time = 15
        self.dt = 0.02  # sampling interval of the trajectories
        self.num_curves = 5
        self.initial_conditions = [[1 + i, 1 + i, 1 + i] for i in range(self.num_curves)]

        # Animation state. It must exist before the update_plot() call below,
        # which stops any running animation.
        self.animation_running = False
        self.current_frame = 0
        self.max_frames = 100  # NOTE(review): not read anywhere in this class
        self.frame_delay = 0.02  # seconds between animation frames
        self.particle_indices = []  # positions in self.fig.data of the particle traces
        self.animation_data = []  # curves the particles are following
        self.animation_length = 0  # number of frames before the animation wraps
        self.curves_data = None  # curves currently drawn; None after clear_plot()
        self._animation_generation = 0  # bumped on each play so stale tasks exit

        # Create the widgets
        self.sigma_slider = widgets.FloatSlider(
            value=self.sigma, min=1, max=20, step=0.5,
            description='Sigma:', continuous_update=False
        )
        self.rho_slider = widgets.FloatSlider(
            value=self.rho, min=0, max=50, step=1,
            description='Rho:', continuous_update=False
        )
        self.beta_slider = widgets.FloatSlider(
            value=self.beta, min=0, max=5, step=0.1,
            description='Beta:', continuous_update=False
        )
        self.time_slider = widgets.FloatSlider(
            value=self.time, min=5, max=30, step=1,
            description='Time:', continuous_update=False
        )
        self.curves_slider = widgets.IntSlider(
            value=self.num_curves, min=1, max=15, step=1,
            description='Curves:', continuous_update=False
        )
        self.update_button = widgets.Button(description='Update Plot')
        self.clear_button = widgets.Button(description='Clear Plot')
        self.play_button = widgets.Button(description='Play Animation', icon='play')
        self.stop_button = widgets.Button(description='Stop Animation', icon='stop')

        # Set up the output
        self.output = widgets.Output()  # NOTE(review): not included in show()

        # Connect widgets to callbacks
        self.update_button.on_click(self.update_plot)
        self.clear_button.on_click(self.clear_plot)
        self.play_button.on_click(self.play_animation)
        self.stop_button.on_click(self.stop_animation)

        # Create initial figure
        self.fig = go.FigureWidget()
        self.setup_figure()
        self.update_plot(None)  # Initial plot

    def setup_figure(self):
        """Set fixed 3D axis ranges, titles and the figure size."""
        self.fig.update_layout(
            scene=dict(
                xaxis=dict(range=[-30, 30], title='X'),
                yaxis=dict(range=[-30, 30], title='Y'),
                zaxis=dict(range=[0, 50], title='Z'),
                aspectmode='cube'
            ),
            margin=dict(l=0, r=0, b=0, t=0),
            height=600,
            width=800,
            title="Interactive Lorenz Attractor"
        )

    def generate_lorenz_data(self):
        """Integrate one Lorenz trajectory per initial condition from the sliders.

        Updates ``self.num_curves`` and, if the curve count changed, regenerates
        the default initial conditions ``[1+i, 1+i, 1+i]``.

        Returns
        -------
        list of dict
            One dict per curve. The keys are ``'x'``, ``'y'`` and ``'z'`` (1-D
            arrays of sampled coordinates), ``'color'`` and ``'name'``.
        """
        num_curves = self.curves_slider.value
        self.num_curves = num_curves

        # Keep exactly one initial condition per requested curve
        if len(self.initial_conditions) != num_curves:
            self.initial_conditions = [[1 + i, 1 + i, 1 + i] for i in range(num_curves)]

        # A single curve gets a fixed mid-range Viridis colour. The colorscale
        # is only sampled when there are two or more curves.
        if num_curves == 1:
            colors = [px.colors.sequential.Viridis[5]]
        else:
            colors = px.colors.sample_colorscale(
                px.colors.sequential.Viridis, num_curves
            )

        # Read the sliders once. Their values cannot change while this
        # synchronous computation runs.
        sigma = self.sigma_slider.value
        rho = self.rho_slider.value
        beta = self.beta_slider.value
        t_end = self.time_slider.value

        def rhs(t, state):
            return lorenz_system(t, state, sigma, rho, beta)

        curves_data = []
        for i, state0 in enumerate(self.initial_conditions):
            points = ode_solution_points(rhs, state0, t_end, dt=self.dt)
            curves_data.append({
                'x': points[:, 0],
                'y': points[:, 1],
                'z': points[:, 2],
                'color': colors[i],
                'name': f'Curve {i+1}'
            })
        return curves_data

    def update_plot(self, b):
        """Recompute the trajectories from the sliders and redraw the figure.

        Any running animation is stopped, because its particle traces are
        removed along with everything else. ``b`` is the clicked button (or
        None) and is ignored.
        """
        self.stop_animation(None)

        curves_data = self.generate_lorenz_data()
        self.curves_data = curves_data

        traces = []
        for curve in curves_data:
            # The trajectory itself
            traces.append(
                go.Scatter3d(
                    x=curve['x'],
                    y=curve['y'],
                    z=curve['z'],
                    mode='lines',
                    name=curve['name'],
                    line=dict(color=curve['color'], width=3),
                    opacity=0.7
                )
            )
            # Marker at its starting point
            traces.append(
                go.Scatter3d(
                    x=[curve['x'][0]],
                    y=[curve['y'][0]],
                    z=[curve['z'][0]],
                    mode='markers',
                    marker=dict(size=5, color=curve['color']),
                    name=f"{curve['name']} Start",
                    showlegend=False
                )
            )

        self.fig.data = []
        self.particle_indices = []
        self.fig.add_traces(traces)

    def clear_plot(self, b):
        """Remove every trace (curves, start markers, particles) from the figure.

        Any running animation is stopped first, since its particles disappear.
        ``b`` is the clicked button and is ignored.
        """
        self.stop_animation(None)
        self.fig.data = []
        self.particle_indices = []
        self.curves_data = None

    def _remove_particles(self):
        """Delete the traces listed in ``particle_indices`` from the figure."""
        if self.particle_indices:
            stale = set(self.particle_indices)
            # Plotly allows assigning a subset of the existing traces to .data
            self.fig.data = tuple(
                trace for idx, trace in enumerate(self.fig.data) if idx not in stale
            )
        self.particle_indices = []

    def play_animation(self, b):
        """Start moving one particle along each trajectory.

        Particles follow the curves that are currently drawn. If the plot has
        been cleared, fresh trajectories are integrated from the sliders
        without drawing the curves. Particles from a previous run are removed
        first so they do not accumulate. Frames are advanced by an asyncio task
        (``_run_animation``), which keeps the kernel responsive. Does nothing if
        an animation is already running. ``b`` is ignored.
        """
        import asyncio

        if self.animation_running:
            return

        curves_data = self.curves_data or self.generate_lorenz_data()
        if not curves_data:
            return
        # All curves share one time grid; min() guards against unequal lengths.
        animation_length = min(len(curve['x']) for curve in curves_data)
        if animation_length == 0:
            return

        self.animation_data = curves_data
        self.animation_length = animation_length
        self.current_frame = 0

        self._remove_particles()
        first_index = len(self.fig.data)
        self.fig.add_traces([
            go.Scatter3d(
                x=[curve['x'][0]],
                y=[curve['y'][0]],
                z=[curve['z'][0]],
                mode='markers',
                marker=dict(
                    size=8,
                    color=curve['color'],
                    symbol='circle',
                    line=dict(width=2, color='white')
                ),
                name=f'Particle {i+1}',
                showlegend=False
            )
            for i, curve in enumerate(curves_data)
        ])
        self.particle_indices = list(range(first_index, len(self.fig.data)))

        self.animation_running = True
        self._animation_generation += 1
        # NOTE(review): assumes the kernel runs an asyncio event loop (as
        # current ipykernel releases do); confirm for the target environment.
        asyncio.ensure_future(self._run_animation(self._animation_generation))

    async def _run_animation(self, generation):
        """Advance one frame every ``frame_delay`` seconds until stopped.

        ``generation`` identifies the play_animation() call that started this
        task. The loop exits once the animation is stopped or a newer run has
        started, so at most one task moves the particles.
        """
        import asyncio

        while self.animation_running and generation == self._animation_generation:
            if not self.animate_step():
                # Nothing left to animate (e.g. the particles were removed)
                self.animation_running = False
                break
            # Yielding here lets the kernel handle widget events between frames
            await asyncio.sleep(self.frame_delay)

    def animate_step(self):
        """Advance every particle by one frame, wrapping to the start at the end.

        Returns
        -------
        bool
            False when there is nothing to animate (not running, no data, or no
            particles); True after the particles have been moved.
        """
        if not self.animation_running or not self.animation_data or not self.particle_indices:
            return False

        self.current_frame = (self.current_frame + 1) % self.animation_length
        frame = self.current_frame
        n_traces = len(self.fig.data)

        # Send all particle moves for this frame as one batched update
        with self.fig.batch_update():
            for curve, particle_idx in zip(self.animation_data, self.particle_indices):
                if particle_idx < n_traces:
                    particle = self.fig.data[particle_idx]
                    particle.x = [curve['x'][frame]]
                    particle.y = [curve['y'][frame]]
                    particle.z = [curve['z'][frame]]
        return True

    def stop_animation(self, b):
        """Stop the animation; particles stay where they are.

        The driving task sees the flag after its current sleep and exits.
        ``b`` is the clicked button (or None) and is ignored.
        """
        self.animation_running = False

    def show(self):
        """Display the sliders, the buttons and the figure in one vertical box."""
        parameter_widgets = widgets.VBox([
            self.sigma_slider,
            self.rho_slider,
            self.beta_slider,
            self.time_slider,
            self.curves_slider
        ])

        button_row = widgets.HBox([
            self.update_button,
            self.clear_button,
            self.play_button,
            self.stop_button
        ])

        controls = widgets.VBox([parameter_widgets, button_row])

        # Display everything
        display(widgets.VBox([controls, self.fig]))
        
# Create and display the interactive plot.
# The instance is bound to a module-level name so later notebook cells can
# still reach it (e.g. to call stop_animation()).
lorenz_interactive = InteractiveLorenzAttractor()
lorenz_interactive.show()