In [18]:
# Traditional Rules vs. AI Learning 🔄

In [21]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import warnings
warnings.filterwarnings('ignore')

class BalancedAILearning:
    """Interactive "discover the hidden rule" demo built on ipywidgets and plotly.

    Synthetic training data is generated from a hidden non-linear rule of three
    features. The user tunes a hand-built model

        w1 * f1**p1 + w2 * f2**p2 + w3 * f3**p3 + mix * f1 * f2

    with sliders; every change re-scores the model (RMSE) on three fixed test
    cases and redraws a 2x2 dashboard next to the controls.

    Args:
        seed: Seed of the private random generator used for the training data
            and for the "Random" button (default 42). A private generator is
            used so the notebook's global numpy RNG is left untouched.
        n_samples: Number of synthetic training points (default 80).
    """

    # Number of most recent errors kept for the "Learning Progress" panel.
    HISTORY_LENGTH = 30

    # One entry per model slider: (attribute, label, default, min, max, step).
    # The order is also the order in which "Random" draws new values.
    _SLIDER_SPECS = (
        ('w1', 'W₁:', 1.0, -5, 5, 0.1),
        ('w2', 'W₂:', 1.0, -5, 5, 0.1),
        ('w3', 'W₃:', 1.0, -5, 5, 0.1),
        ('power1', 'P₁:', 1.0, 0.5, 2.5, 0.1),
        ('power2', 'P₂:', 1.0, 0.5, 2.5, 0.1),
        ('power3', 'P₃:', 1.0, 0.5, 2.5, 0.1),
        ('interaction', 'Mix:', 0.0, -1, 1, 0.05),
    )

    # True while several sliders are changed programmatically, so that each
    # individual change does not trigger its own re-score and redraw.
    _suspend_updates = False

    def __init__(self, seed=42, n_samples=80):
        # RandomState(seed) yields the same stream as np.random.seed(seed) did,
        # so the generated data and "Random" draws are unchanged, but the
        # global numpy RNG of the notebook is no longer reset as a side effect.
        self._rng = np.random.RandomState(seed)
        self._generate_data(n_samples)

        self.learning_history = []
        self.best_error = float('inf')

        self.create_interface()

    @staticmethod
    def _hidden_rule(f1, f2, f3):
        """Noise-free target the user is trying to rediscover."""
        return 2 * f1 + 5 * f2**1.5 - 0.8 * f3**2 + 0.3 * f1 * f2

    def _generate_data(self, n):
        """Draw ``n`` noisy training points and compute the three test targets."""
        rng = self._rng
        self.feature1 = rng.uniform(0, 10, n)
        self.feature2 = rng.uniform(1, 5, n)
        self.feature3 = rng.uniform(0, 8, n)
        # Training targets carry Gaussian noise (std 2); test targets do not.
        self.true_output = (
            self._hidden_rule(self.feature1, self.feature2, self.feature3)
            + rng.normal(0, 2, n)
        )

        self.test_f1 = [3, 7, 5]
        self.test_f2 = [2, 4, 3]
        self.test_f3 = [1, 6, 4]
        self.test_true = [
            self._hidden_rule(f1, f2, f3)
            for f1, f2, f3 in zip(self.test_f1, self.test_f2, self.test_f3)
        ]

    @staticmethod
    def _make_slider(description, value, lo, hi, step):
        """Create one compact FloatSlider in the panel's shared style."""
        return widgets.FloatSlider(
            value=value, min=lo, max=hi, step=step,
            description=description,
            style={'description_width': '30px'},
            layout=widgets.Layout(width='220px')
        )

    def create_interface(self):
        """Build and display the control panel and plot area, then draw once.

        Injects the page CSS and header, creates one slider per model
        parameter plus the Random/Reset buttons, places them beside an Output
        widget for the plotly dashboard, wires every slider to ``_refresh`` and
        renders the initial state.
        """
        display(HTML("""
        <style>
            .ai-header {
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                padding: 15px;
                border-radius: 8px;
                margin-bottom: 15px;
                text-align: center;
            }
            .ai-header h3 {
                margin: 0 0 5px 0;
            }
            .ai-header p {
                margin: 0;
                font-size: 14px;
                opacity: 0.95;
            }
            .control-section {
                background: #f8f9fa;
                padding: 12px;
                border-radius: 8px;
                margin-bottom: 10px;
            }
            .control-title {
                font-weight: bold;
                color: #495057;
                font-size: 13px;
                margin-bottom: 8px;
                padding-bottom: 5px;
                border-bottom: 1px solid #dee2e6;
            }
            .error-card {
                background: white;
                padding: 10px;
                border-radius: 8px;
                box-shadow: 0 2px 4px rgba(0,0,0,0.1);
                text-align: center;
                margin: 10px 0;
            }
            .error-value {
                font-size: 28px;
                font-weight: bold;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                -webkit-background-clip: text;
                -webkit-text-fill-color: transparent;
            }
            .accuracy-bar {
                background: #e9ecef;
                height: 20px;
                border-radius: 10px;
                overflow: hidden;
                margin: 5px 0;
            }
            .accuracy-fill {
                height: 100%;
                background: linear-gradient(90deg, #ff6b6b 0%, #ffd93d 50%, #6bcf7f 100%);
                transition: width 0.5s ease;
            }
            .ai-control-panel {
                background: white;
                border-radius: 8px;
            }
        </style>
        <div class="ai-header">
            <h3>🧠 AI Learning Interface</h3>
            <p>Adjust parameters to minimize error and discover the hidden pattern</p>
        </div>
        """))

        # Creates self.w1 ... self.interaction from the shared spec table.
        for name, label, default, lo, hi, step in self._SLIDER_SPECS:
            setattr(self, name, self._make_slider(label, default, lo, hi, step))

        self.random_btn = widgets.Button(
            description='🎲 Random',
            button_style='primary',
            layout=widgets.Layout(width='105px', height='30px')
        )
        self.reset_btn = widgets.Button(
            description='🔄 Reset',
            button_style='warning',
            layout=widgets.Layout(width='105px', height='30px')
        )
        self.random_btn.on_click(self.randomize)
        self.reset_btn.on_click(self.reset)

        self.error_display = widgets.HTML()
        self.accuracy_display = widgets.HTML()

        controls = widgets.VBox([
            widgets.HTML('<div class="control-title">Linear Weights</div>'),
            self.w1, self.w2, self.w3,
            widgets.HTML('<div class="control-title">Non-linearity Powers</div>'),
            self.power1, self.power2, self.power3,
            widgets.HTML('<div class="control-title">Feature Interaction</div>'),
            self.interaction,
            widgets.HBox([self.random_btn, self.reset_btn],
                         layout=widgets.Layout(margin='10px 0')),
            self.error_display,
            self.accuracy_display
        ], layout=widgets.Layout(
            padding='15px',
            width='250px',
            border='1px solid #dee2e6',
            # Spacing to the plot area; `gap` is not an ipywidgets Layout option.
            margin='0 15px 0 0'
        ))
        # Layout has no background / border-radius options (passing them is
        # silently ignored), so the `.ai-control-panel` CSS rule supplies them.
        controls.add_class('ai-control-panel')

        self.plot_output = widgets.Output()
        main_layout = widgets.HBox([controls, self.plot_output])

        for name, _label, _default, _lo, _hi, _step in self._SLIDER_SPECS:
            getattr(self, name).observe(self._refresh, 'value')

        display(main_layout)
        self._refresh()

    def _refresh(self, *_):
        """Slider observer: re-score and redraw with the current slider values.

        Does nothing while a batch of programmatic slider changes is in
        progress (see ``_apply_params``).
        """
        if self._suspend_updates:
            return
        self.update_model(
            self.w1.value, self.w2.value, self.w3.value,
            self.power1.value, self.power2.value, self.power3.value,
            self.interaction.value
        )

    def _apply_params(self, values):
        """Assign several slider values, then re-score and redraw exactly once.

        Args:
            values: Mapping of slider attribute name to new value; assigned in
                iteration order.
        """
        self._suspend_updates = True
        try:
            for name, value in values.items():
                getattr(self, name).value = value
        finally:
            self._suspend_updates = False
        # Always redraw, even if no slider value actually changed.
        self._refresh()

    def predict(self, f1, f2, f3, w1, w2, w3, p1, p2, p3, mix):
        """Evaluate the user's model on one feature triple.

        Returns ``w1*f1**p1 + w2*f2**p2 + w3*f3**p3 + mix*f1*f2``. Works on
        scalars or broadcastable numpy arrays.
        """
        result = w1 * (f1 ** p1) + w2 * (f2 ** p2) + w3 * (f3 ** p3) + mix * f1 * f2
        return result

    def update_model(self, w1, w2, w3, p1, p2, p3, mix):
        """Score the model defined by the given parameters and redraw everything.

        Args:
            w1, w2, w3: Weights applied to each powered feature.
            p1, p2, p3: Powers applied to each feature before weighting.
            mix: Coefficient of the ``f1 * f2`` interaction term.

        Updates the error/accuracy widgets, appends the error to
        ``learning_history`` (keeping at most ``HISTORY_LENGTH`` entries),
        re-renders the dashboard in ``plot_output`` and, when the error beats
        ``best_error``, records it and prints a short feedback message.
        """
        with self.plot_output:
            clear_output(wait=True)

            predictions = [
                self.predict(f1, f2, f3, w1, w2, w3, p1, p2, p3, mix)
                for f1, f2, f3 in zip(self.test_f1, self.test_f2, self.test_f3)
            ]

            # Root-mean-square error over the test cases.
            error = np.mean([(p - t)**2 for p, t in zip(predictions, self.test_true)])**0.5
            # Map the error onto a 0-100 score: each unit of RMSE costs 5 points.
            accuracy = max(0, min(100, 100 - error * 5))

            self._update_metric_displays(error, accuracy)

            self.learning_history.append(error)
            # Drop everything but the most recent HISTORY_LENGTH errors.
            del self.learning_history[:-self.HISTORY_LENGTH]

            fig = self._build_figure(predictions, w1, w2, w3, p1, p2, p3)
            fig.show()

            if error < self.best_error:
                self.best_error = error
                if error < 3:
                    print("🎉 Excellent! You've discovered the pattern!")
                elif error < 8:
                    print("📈 Good progress! Keep fine-tuning...")
                else:
                    print("🎯 Getting warmer...")

    def _update_metric_displays(self, error, accuracy):
        """Render the error card and the accuracy bar into their HTML widgets."""
        self.error_display.value = f'''
            <div class="error-card">
                <div class="error-value">{error:.1f}</div>
                <div style="color: #6c757d; font-size: 12px;">
                    Model Error {'✨' if error < 3 else ''}
                </div>
            </div>
            '''

        self.accuracy_display.value = f'''
            <div style="text-align: center; margin-top: 5px;">
                <div style="font-size: 12px; color: #6c757d; margin-bottom: 3px;">
                    Accuracy: {accuracy:.0f}%
                </div>
                <div class="accuracy-bar">
                    <div class="accuracy-fill" style="width: {accuracy}%"></div>
                </div>
            </div>
            '''

    def _build_figure(self, predictions, w1, w2, w3, p1, p2, p3):
        """Assemble the 2x2 plotly dashboard for the current model state."""
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=(
                '<b>3D Training Data Cloud</b>',
                '<b>Predictions vs Reality</b>',
                '<b>Feature Processing</b>',
                '<b>Learning Progress</b>'
            ),
            specs=[
                [{"type": "scatter3d"}, {"type": "scatter"}],
                [{"type": "bar"}, {"type": "scatter"}]
            ],
            vertical_spacing=0.18,
            horizontal_spacing=0.15,
            row_heights=[0.5, 0.5],
            column_widths=[0.5, 0.5]
        )

        # 1. Training data in feature space, coloured by the noisy target.
        fig.add_trace(
            go.Scatter3d(
                x=self.feature1,
                y=self.feature2,
                z=self.feature3,
                mode='markers',
                marker=dict(
                    size=4,
                    color=self.true_output,
                    colorscale='Viridis',
                    showscale=False,
                    opacity=0.7
                ),
                showlegend=False,
                hovertemplate='F1: %{x:.1f}<br>F2: %{y:.1f}<br>F3: %{z:.1f}<extra></extra>'
            ),
            row=1, col=1
        )

        # 2. Test-case predictions against their true values.
        fig.add_trace(
            go.Scatter(
                x=self.test_true,
                y=predictions,
                mode='markers+text',
                text=['Test 1', 'Test 2', 'Test 3'],
                textposition='top center',
                textfont=dict(size=10),
                marker=dict(
                    size=18,
                    color=['#667eea', '#764ba2', '#8b5cf6'],
                    line=dict(color='white', width=2)
                ),
                showlegend=False,
                hovertemplate='True: %{x:.1f}<br>Predicted: %{y:.1f}<extra></extra>'
            ),
            row=1, col=2
        )

        # Diagonal "perfect prediction" reference line, padded by 5 units.
        min_val = min(min(self.test_true), min(predictions)) - 5
        max_val = max(max(self.test_true), max(predictions)) + 5
        fig.add_trace(
            go.Scatter(
                x=[min_val, max_val],
                y=[min_val, max_val],
                mode='lines',
                line=dict(dash='dash', color='#4caf50', width=2),
                showlegend=False,
                hoverinfo='skip'
            ),
            row=1, col=2
        )

        # 3. Mean of each raw test feature next to its weighted/powered value
        #    (the interaction term is not part of this panel).
        raw_means = [np.mean(self.test_f1), np.mean(self.test_f2), np.mean(self.test_f3)]
        processed_means = [
            np.mean([w1 * (f ** p1) for f in self.test_f1]),
            np.mean([w2 * (f ** p2) for f in self.test_f2]),
            np.mean([w3 * (f ** p3) for f in self.test_f3])
        ]

        fig.add_trace(
            go.Bar(
                x=['F1', 'F2', 'F3'],
                y=raw_means,
                name='Input',
                marker_color='rgba(102, 126, 234, 0.5)',
                showlegend=False,
                text=[f'{v:.1f}' for v in raw_means],
                textposition='outside',
                hovertemplate='Input %{x}: %{y:.2f}<extra></extra>'
            ),
            row=2, col=1
        )
        fig.add_trace(
            go.Bar(
                x=['F1', 'F2', 'F3'],
                y=processed_means,
                name='Processed',
                marker_color='#667eea',
                showlegend=False,
                text=[f'{v:.1f}' for v in processed_means],
                textposition='outside',
                hovertemplate='Processed %{x}: %{y:.2f}<extra></extra>'
            ),
            row=2, col=1
        )

        # 4. Recent error history.
        if self.learning_history:
            fig.add_trace(
                go.Scatter(
                    x=list(range(len(self.learning_history))),
                    y=self.learning_history,
                    mode='lines',
                    line=dict(color='#764ba2', width=3, shape='spline'),
                    fill='tozeroy',
                    fillcolor='rgba(118, 75, 162, 0.1)',
                    showlegend=False,
                    hovertemplate='Step %{x}<br>Error: %{y:.2f}<extra></extra>'
                ),
                row=2, col=2
            )

        fig.update_layout(
            height=650,
            width=1000,
            showlegend=False,
            margin=dict(l=40, r=40, t=60, b=40),
            paper_bgcolor='#fafafa',
            plot_bgcolor='white',
            title_text="",
            title_font_size=16,
            font=dict(size=11),
            barmode='group'
        )

        fig.update_xaxes(title='True Value', gridcolor='#e0e0e0', row=1, col=2)
        fig.update_yaxes(title='Predicted', gridcolor='#e0e0e0', row=1, col=2)
        fig.update_xaxes(title='Features', gridcolor='#e0e0e0', row=2, col=1)
        fig.update_yaxes(title='Average Value', gridcolor='#e0e0e0', row=2, col=1)
        fig.update_xaxes(title='Training Steps', gridcolor='#e0e0e0', row=2, col=2)
        fig.update_yaxes(title='Error', gridcolor='#e0e0e0', row=2, col=2)

        fig.update_scenes(
            xaxis=dict(title='Feature 1', gridcolor='#e0e0e0', showbackground=False),
            yaxis=dict(title='Feature 2', gridcolor='#e0e0e0', showbackground=False),
            zaxis=dict(title='Feature 3', gridcolor='#e0e0e0', showbackground=False),
            camera=dict(
                eye=dict(x=1.8, y=1.8, z=1.2),
                center=dict(x=0, y=0, z=-0.1)
            ),
            aspectmode='cube'
        )

        fig.update_xaxes(categoryorder='array', categoryarray=['F1', 'F2', 'F3'], row=2, col=1)
        return fig

    def randomize(self, b):
        """Button callback: set every slider to a uniform random value in its range.

        All sliders change as one batch, so the dashboard is redrawn once.
        """
        self._apply_params({
            name: self._rng.uniform(lo, hi)
            for name, _label, _default, lo, hi, _step in self._SLIDER_SPECS
        })

    def reset(self, b):
        """Button callback: restore default slider values and restart the session.

        Clears the error history and the best-error record, so the dashboard and
        the improvement feedback start fresh, then redraws once.
        """
        self.learning_history = []
        self.best_error = float('inf')
        self._apply_params({
            name: default
            for name, _label, default, _lo, _hi, _step in self._SLIDER_SPECS
        })

# Create the balanced AI learning demo. The constructor itself displays the
# control panel and renders the first dashboard; `ai_demo` keeps a handle to
# the instance (sliders, learning_history, best_error) for later inspection.
ai_demo = BalancedAILearning()

HBox(children=(VBox(children=(HTML(value='<div class="control-title">Linear Weights</div>'), FloatSlider(value…