# Imports

In [5]:
import sys
import os
from getpass import getpass
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Markdown, Image, HTML
import warnings
# NOTE(review): this silences *all* warnings (pandas, sklearn, ...); consider narrowing it.
warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('default')
sns.set_palette("husl")

# Size limit passed to the backend via the environment: 104857600 = 100 * 1024 * 1024
# (presumably interpreted as bytes, i.e. 100 MiB — TODO confirm against backend settings).
os.environ['MAX_FILE_SIZE'] = '104857600'

# ⚠️ IMPORTANT: Set API keys BEFORE importing backend modules
# This ensures the settings object reads them correctly.
#
# Never hardcode credentials in a notebook: the source and its saved outputs are
# routinely shared and committed. Export GOOGLE_API_KEY in the shell that launches
# Jupyter; if it is missing we prompt for it interactively instead.
# NOTE(review): a real-looking key was previously hardcoded in this cell — revoke/rotate it.
if not os.environ.get('GOOGLE_API_KEY'):
    os.environ['GOOGLE_API_KEY'] = getpass('Enter your GOOGLE_API_KEY (Gemini): ')

# Blank out the other providers' keys so only Gemini is configured
# (presumably intentional, to force the Gemini client — confirm).
os.environ['OPENAI_API_KEY'] = ''
os.environ['ANTHROPIC_API_KEY'] = ''

# Add backend to path so the `app.*` package resolves; the path is relative to the
# current working directory, so launch Jupyter from the repository root.
backend_path = Path("./backend").resolve()
if str(backend_path) not in sys.path:
    sys.path.insert(0, str(backend_path))

print("📦 Importing backend modules...")

# Import backend modules
from app.workflows.state_management import state_manager
from app.agents.data_cleaning.enhanced_data_cleaning_agent import EnhancedDataCleaningAgent
from app.agents.data_analysis.eda_agent import EDAAgent
from app.agents.ml_pipeline.feature_engineering_agent import FeatureEngineeringAgent
from app.agents.ml_pipeline.ml_builder_agent import MLBuilderAgent
from app.agents.ml_pipeline.model_evaluation_agent import ModelEvaluationAgent

# Force the LLM service to be rebuilt so it reflects the keys set above, even if an
# instance already exists in this kernel (e.g. when this cell is re-run).
# Presumably get_llm_service caches an instance and force_reinit discards it — confirm.
from app.services.llm_service import get_llm_service, LLMProvider
llm_service = get_llm_service(force_reinit=True)
print(f"✅ LLM Service reinitialized - Gemini client available: {LLMProvider.GEMINI in llm_service.clients}")

print("✅ Setup complete!")

📦 Importing backend modules...
✅ LLM Service reinitialized - Gemini client available: True
✅ Setup complete!


# Configuration and Dataset Loading

In [6]:
# ============================================================================
# CONFIGURATION
# ============================================================================

display(Markdown("# 🎯 Multi-Agent ML Pipeline Visualization"))
display(Markdown("*Visualizing outputs from each agent in the classification workflow*\n"))

# Dataset location. Override with the DATASET_PATH environment variable. The default
# is relative to the working directory (the same directory `./backend` is resolved
# from), so the notebook is not tied to one user's home directory.
DATASET_PATH = os.environ.get("DATASET_PATH", "test_data/spotify_churn_dataset.csv")
TARGET_COLUMN = "is_churned"        # label column in the CSV; edit for another dataset
WORKFLOW_ID = "spotify_churn_demo"  # used as the workflow session id / dataset id suffix

# Load dataset
df = pd.read_csv(DATASET_PATH)

# Fail fast here, rather than deep inside an agent, if the target column is misspelled.
assert TARGET_COLUMN in df.columns, (
    f"Target column {TARGET_COLUMN!r} not found in {DATASET_PATH}; "
    f"available columns: {list(df.columns)}"
)

print(f"✅ Dataset loaded: {df.shape[0]:,} rows × {df.shape[1]} columns")
print(f"🎯 Target column: {TARGET_COLUMN}")
print(f"📊 Columns: {list(df.columns)}")

# 🎯 Multi-Agent ML Pipeline Visualization

*Visualizing outputs from each agent in the classification workflow*


✅ Dataset loaded: 8,000 rows × 12 columns
🎯 Target column: is_churned
📊 Columns: ['user_id', 'gender', 'age', 'country', 'subscription_type', 'listening_time', 'songs_played_per_day', 'skip_rate', 'device_type', 'ads_listened_per_week', 'offline_listening', 'is_churned']


# Initialize State

In [7]:
# ============================================================================
# INITIALIZE WORKFLOW STATE
# ============================================================================
# Build the shared state object that each agent cell below passes through
# `agent.execute(state)`.

display(Markdown("## 🔄 Initializing Workflow State"))

state_kwargs = {
    "session_id": WORKFLOW_ID,
    "dataset_id": f"dataset_{WORKFLOW_ID}",
    "target_column": TARGET_COLUMN,
    "user_description": f"Classification task for {TARGET_COLUMN}",
    # NOTE(review): placeholder value, presumably not a real credential — confirm backend usage.
    "api_key": "demo",
    "original_dataset": df,
}
state = state_manager.initialize_state(**state_kwargs)

print("✅ State initialized")
n_state_fields = len(state)
print(f"📦 State contains {n_state_fields} fields for tracking agent outputs")

## 🔄 Initializing Workflow State

✅ State initialized
📦 State contains 82 fields for tracking agent outputs


# Data Cleaning Agent

In [14]:
# ============================================================================
# DATA CLEANING AGENT
# ============================================================================

display(Markdown("# 🧹 Data Cleaning Agent"))

cleaning_agent = EnhancedDataCleaningAgent()
state = await cleaning_agent.execute(state)

print(f"✅ Status: {state.get('agent_statuses', {}).get('data_cleaning')}")

# Quality Score
quality_score = state.get('data_quality_score')
if quality_score:
    print(f"\n📊 Data Quality Score: {quality_score:.2%}")
    
    # Visualize quality score
    fig, ax = plt.subplots(1, 1, figsize=(8, 3))
    color = 'green' if quality_score > 0.8 else 'orange' if quality_score > 0.6 else 'red'
    ax.barh(['Quality Score'], [quality_score], color=color, alpha=0.7)
    ax.set_xlim([0, 1])
    ax.set_xlabel('Score')
    ax.set_title('Data Quality Score')
    ax.axvline(x=0.8, color='green', linestyle='--', alpha=0.3, label='Good (>0.8)')
    ax.legend()
    plt.tight_layout()
    plt.show()

# Issues Found
issues = state.get('cleaning_issues_found', [])
print(f"\n🔍 Issues Found ({len(issues)}):")
if issues:
    for i, issue in enumerate(issues, 1):
        print(f"  {i}. {issue}")
else:
    print("  ✅ No issues detected!")

# Actions Taken
actions = state.get('cleaning_actions_taken', [])
print(f"\n✅ Actions Taken ({len(actions)}):")
for i, action in enumerate(actions, 1):
    print(f"  {i}. {action}")

# Cleaning Summary
cleaning_summary = state.get('cleaning_summary')
if cleaning_summary:
    print(f"\n📋 Cleaning Summary:")
    print(cleaning_summary[:500])  # First 500 chars

# 🧹 Data Cleaning Agent

2025-10-29 16:50:07,624 - enhanced_data_cleaning - INFO - CodeValidator initialized successfully
2025-10-29 16:50:07,624 - enhanced_data_cleaning - INFO - CodeValidator initialized successfully
2025-10-29 16:50:07,624 - enhanced_data_cleaning - INFO - CodeValidator initialized successfully
2025-10-29 16:50:07,624 - enhanced_data_cleaning - INFO - CodeValidator initialized successfully
2025-10-29 16:50:07,624 - enhanced_data_cleaning - INFO - CodeValidator initialized successfully
2025-10-29 16:50:07,624 - enhanced_data_cleaning - INFO - CodeValidator initialized successfully
2025-10-29 16:50:07,637 - enhanced_data_cleaning - INFO - SandboxExecutor initialized (timeout=60s, memory=2g)
2025-10-29 16:50:07,637 - enhanced_data_cleaning - INFO - SandboxExecutor initialized (timeout=60s, memory=2g)
2025-10-29 16:50:07,637 - enhanced_data_cleaning - INFO - SandboxExecutor initialized (timeout=60s, memory=2g)
2025-10-29 16:50:07,637 - enhanced_data_cleaning - INFO - SandboxExecutor initialized

✅ Status: pending

🔍 Issues Found (0):
  ✅ No issues detected!

✅ Actions Taken (0):


# EDA Agent

In [12]:
# ============================================================================
# EDA AGENT
# ============================================================================

display(Markdown("# 📊 Exploratory Data Analysis Agent"))

# Store dataset for EDA
state_manager.store_dataset(state, df, dataset_type="cleaned")

eda_agent = EDAAgent()
state = await eda_agent.execute(state)

print(f"✅ Status: {state.get('agent_statuses', {}).get('eda_analysis')}")

# Statistical Summary
display(Markdown("\n## 📈 Dataset Statistics"))

stats = state.get('statistical_summary', {})
if stats:
    print(f"Dataset Overview:")
    dataset_shape = stats.get('dataset_shape', {})
    print(f"  - Shape: {dataset_shape.get('rows', 0):,} rows × {dataset_shape.get('columns', 0)} columns")
    
    data_types = stats.get('data_types', {})
    print(f"  - Numeric Features: {data_types.get('numeric', 0)}")
    print(f"  - Categorical Features: {data_types.get('categorical', 0)}")
    
    missing = stats.get('missing_values', {})
    print(f"  - Missing Values: {missing.get('total', 0)} ({missing.get('percentage', 0):.2f}%)")
    
    duplicates = stats.get('duplicates', {})
    print(f"  - Duplicates: {duplicates.get('exact_duplicates', 0)} ({duplicates.get('duplicate_percentage', 0):.2f}%)")

# Target Analysis
display(Markdown("\n## 🎯 Target Variable Analysis"))

target_analysis = state.get('target_analysis', {})
if target_analysis:
    distribution = target_analysis.get('distribution', {})
    print(f"Target Variable: {TARGET_COLUMN}")
    print(f"  - Unique Values: {distribution.get('unique_values', 'N/A')}")
    print(f"  - Missing: {distribution.get('missing_count', 0)} ({distribution.get('missing_percentage', 0):.2f}%)")
    
    # Class balance
    if 'class_balance' in target_analysis:
        balance = target_analysis['class_balance']
        balanced = "✅ Balanced" if balance.get('is_balanced') else "⚠️ Imbalanced"
        print(f"  - Class Balance: {balanced}")
        print(f"  - Balance Ratio: {balance.get('balance_ratio', 0):.2f}")
        print(f"  - Majority Class: {balance.get('majority_class')} ({balance.get('majority_count', 0):,} samples)")
        print(f"  - Minority Class: {balance.get('minority_class')} ({balance.get('minority_count', 0):,} samples)")
    
    # Value counts visualization
    value_counts = distribution.get('value_counts', {})
    if value_counts:
        fig, ax = plt.subplots(figsize=(8, 5))
        classes = list(value_counts.keys())
        counts = list(value_counts.values())
        ax.bar([str(c) for c in classes], counts, color=['skyblue', 'coral'])
        ax.set_xlabel('Class')
        ax.set_ylabel('Count')
        ax.set_title(f'Target Distribution: {TARGET_COLUMN}')
        for i, (c, count) in enumerate(zip(classes, counts)):
            ax.text(i, count, f'{count:,}', ha='center', va='bottom', fontweight='bold')
        plt.tight_layout()
        plt.show()
    
    # Target insights
    insights = target_analysis.get('insights', [])
    if insights:
        print(f"\n💡 Key Insights:")
        for insight in insights:
            print(f"  • {insight}")

# Distribution Analysis
display(Markdown("\n## 📊 Feature Distributions"))

dist_analysis = state.get('distribution_analysis', {})
if dist_analysis:
    dist_stats = dist_analysis.get('distribution_statistics', {})
    if dist_stats:
        print(f"Distribution Statistics for {len(dist_stats)} numeric features:")
        
        # Create summary table
        dist_df = pd.DataFrame.from_dict(dist_stats, orient='index')
        dist_df = dist_df.round(2)
        display(dist_df.head(10))  # Show first 10
    
    insights = dist_analysis.get('insights', [])
    if insights:
        print(f"\n💡 Distribution Insights:")
        for insight in insights[:5]:  # Top 5
            print(f"  • {insight}")

# Correlation Analysis
display(Markdown("\n## 🔗 Correlation Analysis"))

corr_analysis = state.get('correlation_analysis', {})
if corr_analysis:
    target_corr = corr_analysis.get('target_correlations', {})
    
    if target_corr:
        strong_pos = target_corr.get('strong_positive', {})
        strong_neg = target_corr.get('strong_negative', {})
        moderate_pos = target_corr.get('moderate_positive', {})
        moderate_neg = target_corr.get('moderate_negative', {})
        weak = target_corr.get('weak', {})
        
        print(f"Correlations with {TARGET_COLUMN}:")
        
        if strong_pos:
            print(f"\n  🔴 Strong Positive (>0.7): {len(strong_pos)} features")
            for feat, corr in list(strong_pos.items())[:3]:
                print(f"    - {feat}: {corr:.3f}")
        
        if strong_neg:
            print(f"\n  🔵 Strong Negative (<-0.7): {len(strong_neg)} features")
            for feat, corr in list(strong_neg.items())[:3]:
                print(f"    - {feat}: {corr:.3f}")
        
        if not strong_pos and not strong_neg:
            print(f"\n  ⚠️ No stro

SyntaxError: EOL while scanning string literal (4121664803.py, line 122)